Commit 7a985548 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.9.0' into v0.9.0-ori

parents 45d3785c dc1440cf
...@@ -9,8 +9,11 @@ import torch ...@@ -9,8 +9,11 @@ import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser, from vllm.utils import (
create_kv_caches_with_random) STR_DTYPE_TO_TORCH_DTYPE,
FlexibleArgumentParser,
create_kv_caches_with_random,
)
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -38,19 +41,15 @@ def main( ...@@ -38,19 +41,15 @@ def main(
current_platform.seed_everything(seed) current_platform.seed_everything(seed)
scale = float(1.0 / (head_size**0.5)) scale = float(1.0 / (head_size**0.5))
query = torch.empty(num_seqs, query = torch.empty(
num_query_heads, num_seqs, num_query_heads, head_size, dtype=dtype, device=device
head_size, )
dtype=dtype,
device=device)
query.uniform_(-scale, scale) query.uniform_(-scale, scale)
assert num_query_heads % num_kv_heads == 0 assert num_query_heads % num_kv_heads == 0
alibi_slopes = None alibi_slopes = None
if use_alibi: if use_alibi:
alibi_slopes = torch.randn(num_query_heads, alibi_slopes = torch.randn(num_query_heads, dtype=torch.float, device=device)
dtype=torch.float,
device=device)
seq_lens = [seq_len for _ in range(num_seqs)] seq_lens = [seq_len for _ in range(num_seqs)]
max_seq_len = max(seq_lens) max_seq_len = max(seq_lens)
...@@ -61,24 +60,23 @@ def main( ...@@ -61,24 +60,23 @@ def main(
block_tables_lst: list[list[int]] = [] block_tables_lst: list[list[int]] = []
for _ in range(num_seqs): for _ in range(num_seqs):
block_table = [ block_table = [
random.randint(0, NUM_BLOCKS - 1) random.randint(0, NUM_BLOCKS - 1) for _ in range(max_num_blocks_per_seq)
for _ in range(max_num_blocks_per_seq)
] ]
block_tables_lst.append(block_table) block_tables_lst.append(block_table)
block_tables = torch.tensor(block_tables_lst, block_tables = torch.tensor(block_tables_lst, dtype=torch.int, device=device)
dtype=torch.int,
device=device)
# Create the KV cache. # Create the KV cache.
key_caches, value_caches = create_kv_caches_with_random(NUM_BLOCKS, key_caches, value_caches = create_kv_caches_with_random(
block_size, NUM_BLOCKS,
1, block_size,
num_kv_heads, 1,
head_size, num_kv_heads,
kv_cache_dtype, head_size,
dtype, kv_cache_dtype,
device=device) dtype,
device=device,
)
key_cache, value_cache = key_caches[0], value_caches[0] key_cache, value_cache = key_caches[0], value_caches[0]
# Prepare for the paged attention kernel. # Prepare for the paged attention kernel.
...@@ -86,11 +84,8 @@ def main( ...@@ -86,11 +84,8 @@ def main(
if version == "v2": if version == "v2":
if current_platform.is_rocm(): if current_platform.is_rocm():
global PARTITION_SIZE global PARTITION_SIZE
if not args.custom_paged_attn: PARTITION_SIZE = 1024 if not args.custom_paged_attn else PARTITION_SIZE_ROCM
PARTITION_SIZE = 1024 num_partitions = (max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE
else:
PARTITION_SIZE = PARTITION_SIZE_ROCM
num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE)
tmp_output = torch.empty( tmp_output = torch.empty(
size=(num_seqs, num_query_heads, num_partitions, head_size), size=(num_seqs, num_query_heads, num_partitions, head_size),
dtype=output.dtype, dtype=output.dtype,
...@@ -110,9 +105,7 @@ def main( ...@@ -110,9 +105,7 @@ def main(
start_time = time.perf_counter() start_time = time.perf_counter()
# Using default kv_scale # Using default kv_scale
k_scale = v_scale = torch.tensor(1.0, k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
dtype=torch.float32,
device=device)
for _ in range(num_iters): for _ in range(num_iters):
if version == "v1": if version == "v1":
...@@ -195,30 +188,29 @@ def main( ...@@ -195,30 +188,29 @@ def main(
print(f"Kernel running time: {latency * 1000000:.3f} us") print(f"Kernel running time: {latency * 1000000:.3f} us")
if __name__ == '__main__': if __name__ == "__main__":
logger.warning("This script benchmarks the paged attention kernel. " logger.warning(
"By default this is no longer used in vLLM inference.") "This script benchmarks the paged attention kernel. "
"By default this is no longer used in vLLM inference."
)
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(description="Benchmark the paged attention kernel.")
description="Benchmark the paged attention kernel.") parser.add_argument("--version", type=str, choices=["v1", "v2"], default="v2")
parser.add_argument("--version",
type=str,
choices=["v1", "v2"],
default="v2")
parser.add_argument("--batch-size", type=int, default=8) parser.add_argument("--batch-size", type=int, default=8)
parser.add_argument("--seq-len", type=int, default=4096) parser.add_argument("--seq-len", type=int, default=4096)
parser.add_argument("--num-query-heads", type=int, default=64) parser.add_argument("--num-query-heads", type=int, default=64)
parser.add_argument("--num-kv-heads", type=int, default=8) parser.add_argument("--num-kv-heads", type=int, default=8)
parser.add_argument("--head-size", parser.add_argument(
type=int, "--head-size",
choices=[64, 80, 96, 112, 120, 128, 192, 256], type=int,
default=128) choices=[64, 80, 96, 112, 120, 128, 192, 256],
default=128,
)
parser.add_argument("--block-size", type=int, choices=[16, 32], default=16) parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
parser.add_argument("--use-alibi", action="store_true") parser.add_argument("--use-alibi", action="store_true")
parser.add_argument("--dtype", parser.add_argument(
type=str, "--dtype", type=str, choices=["half", "bfloat16", "float"], default="half"
choices=["half", "bfloat16", "float"], )
default="half")
parser.add_argument("--seed", type=int, default=0) parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--profile", action="store_true") parser.add_argument("--profile", action="store_true")
parser.add_argument( parser.add_argument(
...@@ -228,10 +220,11 @@ if __name__ == '__main__': ...@@ -228,10 +220,11 @@ if __name__ == '__main__':
default="auto", default="auto",
help="Data type for kv cache storage. If 'auto', will use model " help="Data type for kv cache storage. If 'auto', will use model "
"data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. " "data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
"ROCm (AMD GPU) supports fp8 (=fp8_e4m3)") "ROCm (AMD GPU) supports fp8 (=fp8_e4m3)",
parser.add_argument("--custom-paged-attn", )
action="store_true", parser.add_argument(
help="Use custom paged attention") "--custom-paged-attn", action="store_true", help="Use custom paged attention"
)
args = parser.parse_args() args = parser.parse_args()
print(args) print(args)
......
...@@ -10,15 +10,17 @@ from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser ...@@ -10,15 +10,17 @@ from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
@torch.inference_mode() @torch.inference_mode()
def main(num_tokens: int, def main(
hidden_size: int, num_tokens: int,
static_scale: bool, hidden_size: int,
quant_dtype: torch.dtype, static_scale: bool,
dtype: torch.dtype, quant_dtype: torch.dtype,
seed: int = 0, dtype: torch.dtype,
do_profile: bool = False, seed: int = 0,
num_warmup_iters: int = 5, do_profile: bool = False,
num_iters: int = 100) -> None: num_warmup_iters: int = 5,
num_iters: int = 100,
) -> None:
current_platform.seed_everything(seed) current_platform.seed_everything(seed)
torch.set_default_device("cuda") torch.set_default_device("cuda")
...@@ -56,7 +58,7 @@ def main(num_tokens: int, ...@@ -56,7 +58,7 @@ def main(num_tokens: int,
print(f"Kernel running time: {latency * 1000000:.3f} us") print(f"Kernel running time: {latency * 1000000:.3f} us")
if __name__ == '__main__': if __name__ == "__main__":
def to_torch_dtype(dt): def to_torch_dtype(dt):
if dt == "int8": if dt == "int8":
...@@ -66,37 +68,40 @@ if __name__ == '__main__': ...@@ -66,37 +68,40 @@ if __name__ == '__main__':
raise ValueError(f"Unsupported dtype: {dt}") raise ValueError(f"Unsupported dtype: {dt}")
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
description="Benchmark the quantization (fp8 or int8) kernel.") description="Benchmark the quantization (fp8 or int8) kernel."
)
parser.add_argument("--num-tokens", type=int, default=4096) parser.add_argument("--num-tokens", type=int, default=4096)
parser.add_argument("--hidden-size", type=int, default=8192) parser.add_argument("--hidden-size", type=int, default=8192)
parser.add_argument("--static-scale", action="store_true") parser.add_argument("--static-scale", action="store_true")
parser.add_argument("--quant-dtype", parser.add_argument(
type=str, "--quant-dtype", type=str, choices=["fp8", "int8"], default="int8"
choices=["fp8", "int8"], )
default="int8") parser.add_argument(
parser.add_argument("--dtype", "--dtype", type=str, choices=["half", "bfloat16", "float"], default="half"
type=str, )
choices=["half", "bfloat16", "float"],
default="half")
parser.add_argument("--seed", type=int, default=0) parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--profile", action="store_true") parser.add_argument("--profile", action="store_true")
parser.add_argument("--num-warmup-iters", type=int, default=5) parser.add_argument("--num-warmup-iters", type=int, default=5)
parser.add_argument("--num-iters", parser.add_argument(
type=int, "--num-iters",
default=100, type=int,
help="Number of benchmark iterations. " default=100,
"If --profile is set, this number is ignored") help="Number of benchmark iterations. "
"If --profile is set, this number is ignored",
)
args = parser.parse_args() args = parser.parse_args()
print(args) print(args)
main(num_tokens=args.num_tokens, main(
hidden_size=args.hidden_size, num_tokens=args.num_tokens,
static_scale=args.static_scale, hidden_size=args.hidden_size,
quant_dtype=to_torch_dtype(args.quant_dtype), static_scale=args.static_scale,
dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype], quant_dtype=to_torch_dtype(args.quant_dtype),
seed=args.seed, dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
do_profile=args.profile, seed=args.seed,
num_warmup_iters=args.num_warmup_iters, do_profile=args.profile,
num_iters=args.num_iters) num_warmup_iters=args.num_warmup_iters,
num_iters=args.num_iters,
)
...@@ -4,15 +4,14 @@ import itertools ...@@ -4,15 +4,14 @@ import itertools
from typing import Optional, Union from typing import Optional, Union
import torch import torch
import triton
from flashinfer.norm import fused_add_rmsnorm, rmsnorm from flashinfer.norm import fused_add_rmsnorm, rmsnorm
from torch import nn from torch import nn
from vllm import _custom_ops as vllm_ops from vllm import _custom_ops as vllm_ops
from vllm.triton_utils import triton
class HuggingFaceRMSNorm(nn.Module): class HuggingFaceRMSNorm(nn.Module):
def __init__(self, hidden_size: int, eps: float = 1e-6) -> None: def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
super().__init__() super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size)) self.weight = nn.Parameter(torch.ones(hidden_size))
...@@ -114,23 +113,19 @@ def rmsnorm_vllm( ...@@ -114,23 +113,19 @@ def rmsnorm_vllm(
def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True): def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
dtype = torch.bfloat16 dtype = torch.bfloat16
x = torch.randn(batch_size, x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")
seq_len,
hidden_size,
dtype=dtype,
device="cuda")
weight = torch.ones(hidden_size, dtype=dtype, device="cuda") weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
residual = torch.randn_like(x) if use_residual else None residual = torch.randn_like(x) if use_residual else None
output_naive = rmsnorm_naive( output_naive = rmsnorm_naive(
x.clone(), weight, x.clone(), weight, residual.clone() if residual is not None else None
residual.clone() if residual is not None else None) )
output_flashinfer = rmsnorm_flashinfer( output_flashinfer = rmsnorm_flashinfer(
x.clone(), weight, x.clone(), weight, residual.clone() if residual is not None else None
residual.clone() if residual is not None else None) )
output_vllm = rmsnorm_vllm( output_vllm = rmsnorm_vllm(
x.clone(), weight, x.clone(), weight, residual.clone() if residual is not None else None
residual.clone() if residual is not None else None) )
if use_residual: if use_residual:
output_naive = output_naive[0] output_naive = output_naive[0]
...@@ -141,9 +136,9 @@ def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True): ...@@ -141,9 +136,9 @@ def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
print(f"FlashInfer output={output_flashinfer}") print(f"FlashInfer output={output_flashinfer}")
print(f"vLLM output={output_vllm}") print(f"vLLM output={output_vllm}")
if torch.allclose(output_naive, output_flashinfer, atol=1e-2, if torch.allclose(
rtol=1e-2) and torch.allclose( output_naive, output_flashinfer, atol=1e-2, rtol=1e-2
output_naive, output_vllm, atol=1e-2, rtol=1e-2): ) and torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2):
print("✅ All implementations match") print("✅ All implementations match")
else: else:
print("❌ Implementations differ") print("❌ Implementations differ")
...@@ -152,12 +147,10 @@ def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True): ...@@ -152,12 +147,10 @@ def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True):
batch_size_range = [2**i for i in range(0, 7, 2)] batch_size_range = [2**i for i in range(0, 7, 2)]
seq_length_range = [2**i for i in range(6, 11, 1)] seq_length_range = [2**i for i in range(6, 11, 1)]
head_num_range = [32, 48] head_num_range = [32, 48]
configs = list( configs = list(itertools.product(head_num_range, batch_size_range, seq_length_range))
itertools.product(head_num_range, batch_size_range, seq_length_range))
def get_benchmark(use_residual): def get_benchmark(use_residual):
@triton.testing.perf_report( @triton.testing.perf_report(
triton.testing.Benchmark( triton.testing.Benchmark(
x_names=["head_num", "batch_size", "seq_len"], x_names=["head_num", "batch_size", "seq_len"],
...@@ -167,19 +160,15 @@ def get_benchmark(use_residual): ...@@ -167,19 +160,15 @@ def get_benchmark(use_residual):
line_names=["HuggingFace", "FlashInfer", "vLLM"], line_names=["HuggingFace", "FlashInfer", "vLLM"],
styles=[("blue", "-"), ("green", "-"), ("red", "-")], styles=[("blue", "-"), ("green", "-"), ("red", "-")],
ylabel="us", ylabel="us",
plot_name= plot_name=f"rmsnorm-perf-{'with' if use_residual else 'without'}-residual",
f"rmsnorm-perf-{'with' if use_residual else 'without'}-residual",
args={}, args={},
)) )
)
def benchmark(head_num, batch_size, seq_len, provider): def benchmark(head_num, batch_size, seq_len, provider):
dtype = torch.bfloat16 dtype = torch.bfloat16
hidden_size = head_num * 128 # assuming head_dim = 128 hidden_size = head_num * 128 # assuming head_dim = 128
x = torch.randn(batch_size, x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda")
seq_len,
hidden_size,
dtype=dtype,
device="cuda")
weight = torch.ones(hidden_size, dtype=dtype, device="cuda") weight = torch.ones(hidden_size, dtype=dtype, device="cuda")
residual = torch.randn_like(x) if use_residual else None residual = torch.randn_like(x) if use_residual else None
...@@ -240,9 +229,9 @@ if __name__ == "__main__": ...@@ -240,9 +229,9 @@ if __name__ == "__main__":
default=4096, default=4096,
help="Hidden size (2nd dimension) of the sequence", help="Hidden size (2nd dimension) of the sequence",
) )
parser.add_argument("--use-residual", parser.add_argument(
action="store_true", "--use-residual", action="store_true", help="Whether to use residual connection"
help="Whether to use residual connection") )
parser.add_argument( parser.add_argument(
"--save-path", "--save-path",
type=str, type=str,
...@@ -253,10 +242,12 @@ if __name__ == "__main__": ...@@ -253,10 +242,12 @@ if __name__ == "__main__":
args = parser.parse_args() args = parser.parse_args()
# Run correctness test # Run correctness test
calculate_diff(batch_size=args.batch_size, calculate_diff(
seq_len=args.seq_len, batch_size=args.batch_size,
hidden_size=args.hidden_size, seq_len=args.seq_len,
use_residual=args.use_residual) hidden_size=args.hidden_size,
use_residual=args.use_residual,
)
# Get the benchmark function with proper use_residual setting # Get the benchmark function with proper use_residual setting
benchmark = get_benchmark(args.use_residual) benchmark = get_benchmark(args.use_residual)
......
...@@ -6,8 +6,7 @@ from typing import Optional ...@@ -6,8 +6,7 @@ from typing import Optional
import nvtx import nvtx
import torch import torch
from vllm.model_executor.layers.rotary_embedding import (RotaryEmbedding, from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding, get_rope
get_rope)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
...@@ -32,40 +31,49 @@ def benchmark_rope_kernels_multi_lora( ...@@ -32,40 +31,49 @@ def benchmark_rope_kernels_multi_lora(
# silulating serving 4 LoRAs # silulating serving 4 LoRAs
scaling_factors = [1, 2, 4, 8] scaling_factors = [1, 2, 4, 8]
# batched RoPE can take multiple scaling factors # batched RoPE can take multiple scaling factors
batched_rope = get_rope(head_size, rotary_dim, max_position, base, batched_rope = get_rope(
is_neox_style, { head_size,
"rope_type": "linear", rotary_dim,
"factor": tuple(scaling_factors) max_position,
}) base,
is_neox_style,
{"rope_type": "linear", "factor": tuple(scaling_factors)},
)
# non-batched RoPE takes only one scaling factor, we create multiple # non-batched RoPE takes only one scaling factor, we create multiple
# instances to simulate the same behavior # instances to simulate the same behavior
non_batched_ropes: list[RotaryEmbedding] = [] non_batched_ropes: list[RotaryEmbedding] = []
for scaling_factor in scaling_factors: for scaling_factor in scaling_factors:
non_batched_ropes.append( non_batched_ropes.append(
get_rope(head_size, rotary_dim, max_position, base, is_neox_style, get_rope(
{ head_size,
"rope_type": "linear", rotary_dim,
"factor": (scaling_factor, ) max_position,
})) base,
is_neox_style,
{"rope_type": "linear", "factor": (scaling_factor,)},
)
)
positions = torch.randint(0, max_position, (batch_size, seq_len)) positions = torch.randint(0, max_position, (batch_size, seq_len))
query = torch.randn(batch_size, query = torch.randn(batch_size, seq_len, num_heads * head_size, dtype=dtype)
seq_len,
num_heads * head_size,
dtype=dtype)
key = torch.randn_like(query) key = torch.randn_like(query)
# create query offsets for batched RoPE, we concat multiple kv cache # create query offsets for batched RoPE, we concat multiple kv cache
# together and each query needs to find the right kv cache of its type # together and each query needs to find the right kv cache of its type
offset_map = torch.tensor( offset_map = torch.tensor(
list( list(
accumulate([0] + [ accumulate(
max_position * scaling_factor * 2 [0]
for scaling_factor in scaling_factors[:-1] + [
]))) max_position * scaling_factor * 2
query_types = torch.randint(0, for scaling_factor in scaling_factors[:-1]
len(scaling_factors), (batch_size, seq_len), ]
device=device) )
)
)
query_types = torch.randint(
0, len(scaling_factors), (batch_size, seq_len), device=device
)
# map query types to offsets # map query types to offsets
query_offsets = offset_map[query_types] query_offsets = offset_map[query_types]
# the kernel takes flattened offsets # the kernel takes flattened offsets
...@@ -86,27 +94,28 @@ def benchmark_rope_kernels_multi_lora( ...@@ -86,27 +94,28 @@ def benchmark_rope_kernels_multi_lora(
torch.cuda.synchronize() torch.cuda.synchronize()
if __name__ == '__main__': if __name__ == "__main__":
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
description="Benchmark the rotary embedding kernels.") description="Benchmark the rotary embedding kernels."
)
parser.add_argument("--is-neox-style", type=bool, default=True) parser.add_argument("--is-neox-style", type=bool, default=True)
parser.add_argument("--batch-size", type=int, default=16) parser.add_argument("--batch-size", type=int, default=16)
parser.add_argument("--seq-len", type=int, default=512) parser.add_argument("--seq-len", type=int, default=512)
parser.add_argument("--num-heads", type=int, default=8) parser.add_argument("--num-heads", type=int, default=8)
parser.add_argument("--head-size", parser.add_argument(
type=int, "--head-size",
choices=[64, 80, 96, 112, 120, 128, 192, 256], type=int,
default=128) choices=[64, 80, 96, 112, 120, 128, 192, 256],
default=128,
)
parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32) parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32)
parser.add_argument("--dtype", parser.add_argument(
type=str, "--dtype", type=str, choices=["bfloat16", "float"], default="float"
choices=["bfloat16", "float"], )
default="float")
parser.add_argument("--seed", type=int, default=0) parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--device", parser.add_argument(
type=str, "--device", type=str, choices=["cuda:0", "cuda:1"], default="cuda:0"
choices=["cuda:0", "cuda:1"], )
default="cuda:0")
args = parser.parse_args() args = parser.parse_args()
print(args) print(args)
......
...@@ -14,14 +14,16 @@ import tqdm ...@@ -14,14 +14,16 @@ import tqdm
import triton import triton
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
_w8a8_block_fp8_matmul) _w8a8_block_fp8_matmul,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
mp.set_start_method("spawn", force=True) mp.set_start_method("spawn", force=True)
assert current_platform.is_cuda( assert current_platform.is_cuda(), (
), "Only support tune w8a8 block fp8 kernel on CUDA device." "Only support tune w8a8 block fp8 kernel on CUDA device."
)
DTYPE_MAP = { DTYPE_MAP = {
"float32": torch.float32, "float32": torch.float32,
...@@ -40,7 +42,7 @@ def w8a8_block_matmul( ...@@ -40,7 +42,7 @@ def w8a8_block_matmul(
config: dict[str, Any], config: dict[str, Any],
output_dtype: torch.dtype = torch.float16, output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor: ) -> torch.Tensor:
"""This function performs matrix multiplication with """This function performs matrix multiplication with
block-wise quantization. block-wise quantization.
It takes two input tensors `A` and `B` with scales `As` and `Bs`. It takes two input tensors `A` and `B` with scales `As` and `Bs`.
...@@ -51,7 +53,7 @@ def w8a8_block_matmul( ...@@ -51,7 +53,7 @@ def w8a8_block_matmul(
B: The input tensor, e.g., weight. B: The input tensor, e.g., weight.
As: The per-token-group quantization scale for `A`. As: The per-token-group quantization scale for `A`.
Bs: The per-block quantization scale for `B`. Bs: The per-block quantization scale for `B`.
block_size: The block size for per-block quantization. block_size: The block size for per-block quantization.
It should be 2-dim, e.g., [128, 128]. It should be 2-dim, e.g., [128, 128].
output_dytpe: The dtype of the returned tensor. output_dytpe: The dtype of the returned tensor.
...@@ -71,18 +73,18 @@ def w8a8_block_matmul( ...@@ -71,18 +73,18 @@ def w8a8_block_matmul(
assert triton.cdiv(N, block_n) == Bs.shape[0] assert triton.cdiv(N, block_n) == Bs.shape[0]
assert triton.cdiv(K, block_k) == Bs.shape[1] assert triton.cdiv(K, block_k) == Bs.shape[1]
C_shape = A.shape[:-1] + (N, ) C_shape = A.shape[:-1] + (N,)
C = A.new_empty(C_shape, dtype=output_dtype) C = A.new_empty(C_shape, dtype=output_dtype)
def grid(META): def grid(META):
return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * return (
triton.cdiv(N, META["BLOCK_SIZE_N"]), ) triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
)
if A.dtype == torch.float8_e4m3fn: if A.dtype == torch.float8_e4m3fn:
kernel = _w8a8_block_fp8_matmul kernel = _w8a8_block_fp8_matmul
else: else:
raise RuntimeError( raise RuntimeError("Currently, only support tune w8a8 block fp8 kernel.")
"Currently, only support tune w8a8 block fp8 kernel.")
kernel[grid]( kernel[grid](
A, A,
...@@ -119,14 +121,16 @@ def get_configs_compute_bound(): ...@@ -119,14 +121,16 @@ def get_configs_compute_bound():
for block_n in [32, 64, 128, 256]: for block_n in [32, 64, 128, 256]:
for num_warps in [4, 8]: for num_warps in [4, 8]:
for group_size in [1, 16, 32, 64]: for group_size in [1, 16, 32, 64]:
configs.append({ configs.append(
"BLOCK_SIZE_M": block_m, {
"BLOCK_SIZE_N": block_n, "BLOCK_SIZE_M": block_m,
"BLOCK_SIZE_K": block_k, "BLOCK_SIZE_N": block_n,
"GROUP_SIZE_M": group_size, "BLOCK_SIZE_K": block_k,
"num_warps": num_warps, "GROUP_SIZE_M": group_size,
"num_stages": num_stages, "num_warps": num_warps,
}) "num_stages": num_stages,
}
)
return configs return configs
...@@ -165,15 +169,9 @@ def get_weight_shapes(tp_size): ...@@ -165,15 +169,9 @@ def get_weight_shapes(tp_size):
return weight_shapes return weight_shapes
def benchmark_config(A, def benchmark_config(
B, A, B, As, Bs, block_size, config, out_dtype=torch.float16, num_iters=10
As, ):
Bs,
block_size,
config,
out_dtype=torch.float16,
num_iters=10):
def run(): def run():
w8a8_block_matmul(A, B, As, Bs, block_size, config, out_dtype) w8a8_block_matmul(A, B, As, Bs, block_size, config, out_dtype)
...@@ -206,26 +204,26 @@ def tune(M, N, K, block_size, out_dtype, search_space, input_type): ...@@ -206,26 +204,26 @@ def tune(M, N, K, block_size, out_dtype, search_space, input_type):
fp8_max, fp8_min = fp8_info.max, fp8_info.min fp8_max, fp8_min = fp8_info.max, fp8_info.min
A_fp32 = ( A_fp32 = (
(torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * (torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
fp8_max) )
A = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) A = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
B_fp32 = ( B_fp32 = (
(torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * (torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
fp8_max) )
B = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) B = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
else: else:
raise RuntimeError( raise RuntimeError("Currently, only support tune w8a8 block fp8 kernel.")
"Currently, only support tune w8a8 block fp8 kernel.")
block_n, block_k = block_size[0], block_size[1] block_n, block_k = block_size[0], block_size[1]
n_tiles = (N + block_n - 1) // block_n n_tiles = (N + block_n - 1) // block_n
k_tiles = (K + block_k - 1) // block_k k_tiles = (K + block_k - 1) // block_k
As = torch.rand(M, k_tiles, dtype=torch.float32, As = torch.rand(M, k_tiles, dtype=torch.float32, device="cuda") * factor_for_scale
device="cuda") * factor_for_scale Bs = (
Bs = (torch.rand(n_tiles, k_tiles, dtype=torch.float32, device="cuda") * torch.rand(n_tiles, k_tiles, dtype=torch.float32, device="cuda")
factor_for_scale) * factor_for_scale
)
best_config = None best_config = None
best_time = float("inf") best_time = float("inf")
...@@ -267,7 +265,8 @@ def save_configs( ...@@ -267,7 +265,8 @@ def save_configs(
device_name = current_platform.get_device_name().replace(" ", "_") device_name = current_platform.get_device_name().replace(" ", "_")
json_file_name = ( json_file_name = (
f"N={N},K={K},device_name={device_name},dtype={input_type}_w8a8," f"N={N},K={K},device_name={device_name},dtype={input_type}_w8a8,"
f"block_shape=[{block_n},{block_k}].json") f"block_shape=[{block_n},{block_k}].json"
)
config_file_path = os.path.join(save_path, json_file_name) config_file_path = os.path.join(save_path, json_file_name)
print(f"Writing best config to {config_file_path}...") print(f"Writing best config to {config_file_path}...")
...@@ -295,8 +294,7 @@ def tune_on_gpu(args_dict): ...@@ -295,8 +294,7 @@ def tune_on_gpu(args_dict):
search_space = get_configs_compute_bound() search_space = get_configs_compute_bound()
search_space = [ search_space = [
config for config in search_space config for config in search_space if block_k % config["BLOCK_SIZE_K"] == 0
if block_k % config["BLOCK_SIZE_K"] == 0
] ]
start = time.time() start = time.time()
...@@ -312,15 +310,11 @@ def tune_on_gpu(args_dict): ...@@ -312,15 +310,11 @@ def tune_on_gpu(args_dict):
out_dtype, out_dtype,
search_space, search_space,
input_type, input_type,
) for batch_size in tqdm(batch_sizes, )
desc=f"GPU {gpu_id} - Batch sizes") for batch_size in tqdm(batch_sizes, desc=f"GPU {gpu_id} - Batch sizes")
] ]
best_configs = { best_configs = {M: config for M, config in zip(batch_sizes, benchmark_results)}
M: config save_configs(N, K, block_n, block_k, best_configs, save_path, input_type)
for M, config in zip(batch_sizes, benchmark_results)
}
save_configs(N, K, block_n, block_k, best_configs, save_path,
input_type)
end = time.time() end = time.time()
print(f"Tuning on GPU {gpu_id} took {end - start:.2f} seconds") print(f"Tuning on GPU {gpu_id} took {end - start:.2f} seconds")
...@@ -376,13 +370,14 @@ def main(args): ...@@ -376,13 +370,14 @@ def main(args):
process_args = [] process_args = []
for gpu_id in range(num_gpus): for gpu_id in range(num_gpus):
process_args.append({ process_args.append(
"gpu_id": gpu_id, {
"batch_sizes": batches_per_gpu[gpu_id], "gpu_id": gpu_id,
"weight_shapes": "batch_sizes": batches_per_gpu[gpu_id],
weight_shapes, # Each GPU processes all weight shapes "weight_shapes": weight_shapes, # Each GPU processes all weight shapes
"args": args, "args": args,
}) }
)
ctx = mp.get_context("spawn") ctx = mp.get_context("spawn")
with ctx.Pool(num_gpus) as pool: with ctx.Pool(num_gpus) as pool:
...@@ -398,13 +393,11 @@ Tune triton w8a8 block fp8 for DeepSeek-V3/DeepSeek-R1: ...@@ -398,13 +393,11 @@ Tune triton w8a8 block fp8 for DeepSeek-V3/DeepSeek-R1:
python3 benchmark_w8a8_block_fp8.py --tp-size 8 --input-type fp8 python3 benchmark_w8a8_block_fp8.py --tp-size 8 --input-type fp8
Then copy to model_executor/layers/quantization/utils/configs Then copy to model_executor/layers/quantization/utils/configs
""", """,
formatter_class=argparse.RawTextHelpFormatter) formatter_class=argparse.RawTextHelpFormatter,
)
parser.add_argument("--tp-size", "-tp", type=int, default=8) parser.add_argument("--tp-size", "-tp", type=int, default=8)
parser.add_argument("--input-type", parser.add_argument("--input-type", type=str, choices=["fp8"], default="fp8")
type=str,
choices=["fp8"],
default="fp8")
parser.add_argument( parser.add_argument(
"--out-dtype", "--out-dtype",
type=str, type=str,
......
...@@ -6,13 +6,15 @@ import time ...@@ -6,13 +6,15 @@ import time
# Import DeepGEMM functions # Import DeepGEMM functions
import deep_gemm import deep_gemm
import torch import torch
import triton
from deep_gemm import calc_diff, ceil_div, get_col_major_tma_aligned_tensor from deep_gemm import calc_diff, ceil_div, get_col_major_tma_aligned_tensor
# Import vLLM functions # Import vLLM functions
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8, w8a8_block_fp8_matmul) per_token_group_quant_fp8,
w8a8_block_fp8_matmul,
)
from vllm.triton_utils import triton
# Copied from # Copied from
......
...@@ -14,13 +14,14 @@ from vllm.utils import FlexibleArgumentParser ...@@ -14,13 +14,14 @@ from vllm.utils import FlexibleArgumentParser
if __name__ == "__main__": if __name__ == "__main__":
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
description='Benchmark the latency of processing a single batch of ' description="Benchmark the latency of processing a single batch of "
'requests till completion.') "requests till completion."
parser.add_argument('filename', type=str) )
parser.add_argument("filename", type=str)
args = parser.parse_args() args = parser.parse_args()
with open(args.filename, 'rb') as f: with open(args.filename, "rb") as f:
data = pickle.load(f) data = pickle.load(f)
raw_results: list[TMeasurement] = data["results"] raw_results: list[TMeasurement] = data["results"]
...@@ -38,11 +39,7 @@ if __name__ == "__main__": ...@@ -38,11 +39,7 @@ if __name__ == "__main__":
raise Exception("MKN not found") raise Exception("MKN not found")
kernel = v.task_spec.description kernel = v.task_spec.description
results[KN].append({ results[KN].append({"kernel": kernel, "batch_size": M, "median": v.median})
"kernel": kernel,
"batch_size": M,
"median": v.median
})
rows = int(math.ceil(len(results) / 2)) rows = int(math.ceil(len(results) / 2))
fig, axs = plt.subplots(rows, 2, figsize=(12, 5 * rows)) fig, axs = plt.subplots(rows, 2, figsize=(12, 5 * rows))
...@@ -50,14 +47,16 @@ if __name__ == "__main__": ...@@ -50,14 +47,16 @@ if __name__ == "__main__":
for axs_idx, (shape, data) in enumerate(results.items()): for axs_idx, (shape, data) in enumerate(results.items()):
plt.sca(axs[axs_idx]) plt.sca(axs[axs_idx])
df = pd.DataFrame(data) df = pd.DataFrame(data)
sns.lineplot(data=df, sns.lineplot(
x="batch_size", data=df,
y="median", x="batch_size",
hue="kernel", y="median",
style="kernel", hue="kernel",
markers=True, style="kernel",
dashes=False, markers=True,
palette="Dark2") dashes=False,
palette="Dark2",
)
plt.title(f"Shape: {shape}") plt.title(f"Shape: {shape}")
plt.ylabel("time (median, s)") plt.ylabel("time (median, s)")
plt.tight_layout() plt.tight_layout()
......
...@@ -23,6 +23,7 @@ class ArgPool: ...@@ -23,6 +23,7 @@ class ArgPool:
For every invocation during a benchmarking run, it will choose a For every invocation during a benchmarking run, it will choose a
different value from the list. different value from the list.
""" """
values: Iterable[Any] values: Iterable[Any]
def __getitem__(self, index): def __getitem__(self, index):
...@@ -30,9 +31,7 @@ class ArgPool: ...@@ -30,9 +31,7 @@ class ArgPool:
class Bench: class Bench:
class ArgsIterator: class ArgsIterator:
def __init__(self, args_list, kwargs_list): def __init__(self, args_list, kwargs_list):
assert len(args_list) == len(kwargs_list) assert len(args_list) == len(kwargs_list)
self.args_list = args_list self.args_list = args_list
...@@ -53,10 +52,16 @@ class Bench: ...@@ -53,10 +52,16 @@ class Bench:
def n_args(self): def n_args(self):
return self.n return self.n
def __init__(self, cuda_graph_params: Optional[CudaGraphBenchParams], def __init__(
label: str, sub_label: str, description: str, fn: Callable, self,
*args, **kwargs): cuda_graph_params: Optional[CudaGraphBenchParams],
label: str,
sub_label: str,
description: str,
fn: Callable,
*args,
**kwargs,
):
self.cuda_graph_params = cuda_graph_params self.cuda_graph_params = cuda_graph_params
self.use_cuda_graph = self.cuda_graph_params is not None self.use_cuda_graph = self.cuda_graph_params is not None
self.label = label self.label = label
...@@ -67,10 +72,8 @@ class Bench: ...@@ -67,10 +72,8 @@ class Bench:
# Process args # Process args
self._args = args self._args = args
self._kwargs = kwargs self._kwargs = kwargs
self.args_list, self.kwargs_list = self.collapse_argpool( self.args_list, self.kwargs_list = self.collapse_argpool(*args, **kwargs)
*args, **kwargs) self.args_iterator = self.ArgsIterator(self.args_list, self.kwargs_list)
self.args_iterator = self.ArgsIterator(self.args_list,
self.kwargs_list)
# Cudagraph runner # Cudagraph runner
self.g = None self.g = None
...@@ -100,16 +103,13 @@ class Bench: ...@@ -100,16 +103,13 @@ class Bench:
for i in range(argpool_size): for i in range(argpool_size):
# collapse args; Just pick the ith value # collapse args; Just pick the ith value
args_list[i] = tuple([ args_list[i] = tuple(
arg[i] if isinstance(arg, ArgPool) else arg [arg[i] if isinstance(arg, ArgPool) else arg for arg in args_list[i]]
for arg in args_list[i] )
])
# collapse kwargs # collapse kwargs
kwargs_i = kwargs_list[i] kwargs_i = kwargs_list[i]
arg_pool_keys = [ arg_pool_keys = [k for k, v in kwargs_i.items() if isinstance(v, ArgPool)]
k for k, v in kwargs_i.items() if isinstance(v, ArgPool)
]
for k in arg_pool_keys: for k in arg_pool_keys:
# again just pick the ith value # again just pick the ith value
kwargs_i[k] = kwargs_i[k][i] kwargs_i[k] = kwargs_i[k][i]
...@@ -142,7 +142,7 @@ class Bench: ...@@ -142,7 +142,7 @@ class Bench:
def run_cudagrah(self) -> TMeasurement: def run_cudagrah(self) -> TMeasurement:
assert self.use_cuda_graph assert self.use_cuda_graph
globals = {'g': self.g} globals = {"g": self.g}
return TBenchmark.Timer( return TBenchmark.Timer(
stmt="g.replay()", stmt="g.replay()",
...@@ -162,15 +162,15 @@ class Bench: ...@@ -162,15 +162,15 @@ class Bench:
has_arg_pool = self.args_iterator.n_args > 1 has_arg_pool = self.args_iterator.n_args > 1
if has_arg_pool: if has_arg_pool:
setup = ''' setup = """
args_iterator.reset() args_iterator.reset()
args_it = args_iterator.__next__() args_it = args_iterator.__next__()
''' """
stmt = ''' stmt = """
args, kwargs = next(args_it) args, kwargs = next(args_it)
fn(*args, **kwargs) fn(*args, **kwargs)
''' """
globals = {'fn': self.fn, 'args_iterator': self.args_iterator} globals = {"fn": self.fn, "args_iterator": self.args_iterator}
else: else:
# no arg pool. Just use the args and kwargs directly # no arg pool. Just use the args and kwargs directly
self.args_iterator.reset() self.args_iterator.reset()
...@@ -178,10 +178,10 @@ class Bench: ...@@ -178,10 +178,10 @@ class Bench:
args, kwargs = next(args_it) args, kwargs = next(args_it)
setup = "" setup = ""
stmt = ''' stmt = """
fn(*args, **kwargs) fn(*args, **kwargs)
''' """
globals = {'fn': self.fn, 'args': args, 'kwargs': kwargs} globals = {"fn": self.fn, "args": args, "kwargs": kwargs}
return TBenchmark.Timer( return TBenchmark.Timer(
stmt=stmt, stmt=stmt,
......
...@@ -7,9 +7,8 @@ from vllm import LLM, SamplingParams ...@@ -7,9 +7,8 @@ from vllm import LLM, SamplingParams
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
# A very long prompt, total number of tokens is about 15k. # A very long prompt, total number of tokens is about 15k.
LONG_PROMPT = ["You are an expert in large language models, aren't you?" LONG_PROMPT = ["You are an expert in large language models, aren't you?"] * 1000
] * 1000 LONG_PROMPT = " ".join(LONG_PROMPT)
LONG_PROMPT = ' '.join(LONG_PROMPT)
def main(args): def main(args):
...@@ -30,32 +29,35 @@ def main(args): ...@@ -30,32 +29,35 @@ def main(args):
print("------start generating------") print("------start generating------")
for i in range(3): for i in range(3):
profiler.runctx('llm.generate(LONG_PROMPT, sampling_params)', profiler.runctx(
globals(), locals()) "llm.generate(LONG_PROMPT, sampling_params)", globals(), locals()
)
# analyze the runtime of hashing function # analyze the runtime of hashing function
stats = pstats.Stats(profiler) stats = pstats.Stats(profiler)
stats.sort_stats('cumulative') stats.sort_stats("cumulative")
total_time = 0 total_time = 0
total_calls = 0 total_calls = 0
for func in stats.stats: for func in stats.stats:
if 'hash_of_block' in func[2]: if "hash_of_block" in func[2]:
total_time = stats.stats[func][3] total_time = stats.stats[func][3]
total_calls = stats.stats[func][0] total_calls = stats.stats[func][0]
percentage = (total_time / stats.total_tt) * 100 percentage = (total_time / stats.total_tt) * 100
print(f"Hashing took {total_time:.2f} seconds," print(
f"{percentage:.2f}% of the total runtime.") f"Hashing took {total_time:.2f} seconds,{percentage:.2f}% of the total runtime."
)
if __name__ == "__main__": if __name__ == "__main__":
parser = FlexibleArgumentParser( parser = FlexibleArgumentParser(
description='Benchmark the performance of hashing function in' description="Benchmark the performance of hashing function in"
'automatic prefix caching.') "automatic prefix caching."
parser.add_argument('--model', type=str, default='lmsys/longchat-7b-16k') )
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1) parser.add_argument("--model", type=str, default="lmsys/longchat-7b-16k")
parser.add_argument('--output-len', type=int, default=10) parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
parser.add_argument('--enable-prefix-caching', parser.add_argument("--output-len", type=int, default=10)
action='store_true', parser.add_argument(
help='enable prefix caching') "--enable-prefix-caching", action="store_true", help="enable prefix caching"
)
args = parser.parse_args() args = parser.parse_args()
main(args) main(args)
# This local pyproject file is part of the migration from yapf to ruff format.
# It uses the same core rules as the main pyproject.toml file, but with the
# following differences:
# - ruff line length is overridden to 88
# - deprecated typing ignores (UP006, UP035) have been removed
[tool.ruff]
line-length = 88
exclude = [
# External file, leaving license intact
"examples/other/fp8/quantizer/quantize.py",
"vllm/vllm_flash_attn/flash_attn_interface.pyi"
]
[tool.ruff.lint.per-file-ignores]
"vllm/third_party/**" = ["ALL"]
"vllm/version.py" = ["F401"]
"vllm/_version.py" = ["ALL"]
[tool.ruff.lint]
select = [
# pycodestyle
"E",
# Pyflakes
"F",
# pyupgrade
"UP",
# flake8-bugbear
"B",
# flake8-simplify
"SIM",
# isort
"I",
# flake8-logging-format
"G",
]
ignore = [
# star imports
"F405", "F403",
# lambda expression assignment
"E731",
# Loop control variable not used within loop body
"B007",
# f-string format
"UP032",
# Can remove once 3.10+ is the minimum Python version
"UP007",
]
[tool.ruff.lint.isort]
known-first-party = ["vllm"]
[tool.ruff.format]
docstring-code-format = true
\ No newline at end of file
#!/bin/bash #!/bin/bash
# Define the model to use # default values
MODEL=${1:-"Qwen/Qwen2.5-7B-Instruct"} MODEL=${MODEL:-"Qwen/Qwen2.5-7B-Instruct"}
BACKEND=${BACKEND:-"vllm"}
# Define the backend to use DATASET=${DATASET:-"xgrammar_bench"}
BACKEND=${2:-"vllm"}
# Define the dataset to use
DATASET=${3:-"xgrammar_bench"}
# Define the guided decoding backend
GUIDED_BACKEND=${4:-"xgrammar"}
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
OUTPUT_DIR=${5:-"$SCRIPT_DIR/structured_output_benchmark_results"} OUTPUT_DIR=${OUTPUT_DIR:-"$SCRIPT_DIR/structured_output_benchmark_results"}
PORT=${PORT:-8000}
GUIDED_RATIO=${6:-0.5} STRUCTURED_OUTPUT_RATIO=${STRUCTURED_OUTPUT_RATIO:-1}
TOTAL_SECONDS=${TOTAL_SECONDS:-90}
MAX_NEW_TOKENS=${MAX_NEW_TOKENS:-300}
TOKENIZER_MODE=${TOKENIZER_MODE:-"auto"}
usage() {
echo "Usage: $0 [options]"
echo "Options:"
echo " --model MODEL Model to benchmark (default: $MODEL)"
echo " --backend BACKEND Backend to use (default: $BACKEND)"
echo " --dataset DATASET Dataset to use (default: $DATASET)"
echo " --max-new-tokens N Maximum number of tokens to generate (default: $MAX_NEW_TOKENS)"
echo " --output-dir DIR Output directory for results (default: $OUTPUT_DIR)"
echo " --port PORT Port to use (default: $PORT)"
echo " --structured-output-ratio N Ratio of structured outputs (default: $STRUCTURED_OUTPUT_RATIO)"
echo " --tokenizer-mode MODE Tokenizer mode to use (default: $TOKENIZER_MODE)"
echo " --total-seconds N Total seconds to run the benchmark (default: $TOTAL_SECONDS)"
echo " -h, --help Show this help message and exit"
exit 0
}
# parse command line arguments
while [[ $# -gt 0 ]]; do
case $1 in
--model)
MODEL="$2"
shift 2
;;
--backend)
BACKEND="$2"
shift 2
;;
--dataset)
DATASET="$2"
shift 2
;;
--max-new-tokens)
MAX_NEW_TOKENS="$2"
shift 2
;;
--output-dir)
OUTPUT_DIR="$2"
shift 2
;;
--port)
PORT="$2"
shift 2
;;
--structured-output-ratio)
STRUCTURED_OUTPUT_RATIO="$2"
shift 2
;;
--tokenizer-mode)
TOKENIZER_MODE="$2"
shift 2
;;
--total-seconds)
TOTAL_SECONDS="$2"
shift 2
;;
-h|--help)
usage
;;
*)
echo "Unknown argument: $1\n"
usage
;;
esac
done
# Create output directory if it doesn't exist # Create output directory if it doesn't exist
mkdir -p "$OUTPUT_DIR" mkdir -p "$OUTPUT_DIR"
# Define QPS values to test # Define QPS values to test
QPS_VALUES=(70 60 50 25 20 15 10) QPS_VALUES=(25 20 15 10 5 1)
# Common parameters # Common parameters
COMMON_PARAMS="--backend $BACKEND \ COMMON_PARAMS="--backend $BACKEND \
--model $MODEL \ --model $MODEL \
--dataset $DATASET \ --dataset $DATASET \
--structured-output-backend $GUIDED_BACKEND \ --structured-output-ratio $STRUCTURED_OUTPUT_RATIO \
--structured-output-ratio $GUIDED_RATIO \
--save-results \ --save-results \
--result-dir $OUTPUT_DIR" --result-dir $OUTPUT_DIR \
--output-len $MAX_NEW_TOKENS \
--port $PORT \
--tokenizer-mode $TOKENIZER_MODE"
echo "Starting structured output benchmark with model: $MODEL" echo "Starting structured output benchmark with model: $MODEL"
echo "Backend: $BACKEND" echo "Backend: $BACKEND"
echo "Dataset: $DATASET" echo "Dataset: $DATASET"
echo "Structured output backend: $GUIDED_BACKEND"
echo "Results will be saved to: $OUTPUT_DIR" echo "Results will be saved to: $OUTPUT_DIR"
echo "----------------------------------------" echo "----------------------------------------"
...@@ -48,14 +109,17 @@ for qps in "${QPS_VALUES[@]}"; do ...@@ -48,14 +109,17 @@ for qps in "${QPS_VALUES[@]}"; do
GIT_BRANCH=$(git rev-parse --abbrev-ref HEAD 2>/dev/null || echo "unknown") GIT_BRANCH=$(git rev-parse --abbrev-ref HEAD 2>/dev/null || echo "unknown")
# Construct filename for this run # Construct filename for this run
FILENAME="${GUIDED_BACKEND}_${BACKEND}_${qps}qps_$(basename $MODEL)_${DATASET}_${GIT_HASH}.json" FILENAME="${BACKEND}_${qps}qps_$(basename $MODEL)_${DATASET}_${GIT_HASH}.json"
NUM_PROMPTS=$(echo "$TOTAL_SECONDS * $qps" | bc)
NUM_PROMPTS=${NUM_PROMPTS%.*} # Remove fractional part
echo "Running benchmark with $NUM_PROMPTS prompts"
# Run the benchmark # Run the benchmark
python "$SCRIPT_DIR/benchmark_serving_structured_output.py" $COMMON_PARAMS \ python "$SCRIPT_DIR/benchmark_serving_structured_output.py" $COMMON_PARAMS \
--request-rate $qps \ --request-rate $qps \
--result-filename "$FILENAME" \ --result-filename "$FILENAME" \
--tokenizer-mode ${TOKENIZER_MODE:-"auto"} \ --num-prompts $NUM_PROMPTS
--port ${PORT:-8000}
echo "Completed benchmark with QPS: $qps" echo "Completed benchmark with QPS: $qps"
echo "----------------------------------------" echo "----------------------------------------"
......
...@@ -167,6 +167,33 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED) ...@@ -167,6 +167,33 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED)
FetchContent_MakeAvailable(oneDNN) FetchContent_MakeAvailable(oneDNN)
list(APPEND LIBS dnnl)
elseif(POWER10_FOUND)
FetchContent_Declare(
oneDNN
GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git
GIT_TAG v3.7.2
GIT_PROGRESS TRUE
GIT_SHALLOW TRUE
)
set(ONEDNN_LIBRARY_TYPE "STATIC")
set(ONEDNN_BUILD_DOC "OFF")
set(ONEDNN_BUILD_EXAMPLES "OFF")
set(ONEDNN_BUILD_TESTS "OFF")
set(ONEDNN_ENABLE_WORKLOAD "INFERENCE")
set(ONEDNN_ENABLE_PRIMITIVE "MATMUL;REORDER")
set(ONEDNN_BUILD_GRAPH "OFF")
set(ONEDNN_ENABLE_JIT_PROFILING "OFF")
set(ONEDNN_ENABLE_ITT_TASKS "OFF")
set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF")
set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF")
set(CMAKE_POLICY_DEFAULT_CMP0077 NEW)
set(DNNL_CPU_RUNTIME "OMP")
FetchContent_MakeAvailable(oneDNN)
list(APPEND LIBS dnnl) list(APPEND LIBS dnnl)
endif() endif()
...@@ -197,6 +224,10 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED) ...@@ -197,6 +224,10 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED)
"csrc/cpu/quant.cpp" "csrc/cpu/quant.cpp"
"csrc/cpu/shm.cpp" "csrc/cpu/shm.cpp"
${VLLM_EXT_SRC}) ${VLLM_EXT_SRC})
elseif(POWER10_FOUND)
set(VLLM_EXT_SRC
"csrc/cpu/quant.cpp"
${VLLM_EXT_SRC})
endif() endif()
# #
...@@ -214,4 +245,4 @@ define_gpu_extension_target( ...@@ -214,4 +245,4 @@ define_gpu_extension_target(
WITH_SOABI WITH_SOABI
) )
message(STATUS "Enabling C extension.") message(STATUS "Enabling C extension.")
\ No newline at end of file
...@@ -229,11 +229,26 @@ macro(set_gencode_flags_for_srcs) ...@@ -229,11 +229,26 @@ macro(set_gencode_flags_for_srcs)
"${multiValueArgs}" ${ARGN} ) "${multiValueArgs}" ${ARGN} )
foreach(_ARCH ${arg_CUDA_ARCHS}) foreach(_ARCH ${arg_CUDA_ARCHS})
string(REPLACE "." "" _ARCH "${_ARCH}") # handle +PTX suffix: generate both sm and ptx codes if requested
set_gencode_flag_for_srcs( string(FIND "${_ARCH}" "+PTX" _HAS_PTX)
SRCS ${arg_SRCS} if(NOT _HAS_PTX EQUAL -1)
ARCH "compute_${_ARCH}" string(REPLACE "+PTX" "" _BASE_ARCH "${_ARCH}")
CODE "sm_${_ARCH}") string(REPLACE "." "" _STRIPPED_ARCH "${_BASE_ARCH}")
set_gencode_flag_for_srcs(
SRCS ${arg_SRCS}
ARCH "compute_${_STRIPPED_ARCH}"
CODE "sm_${_STRIPPED_ARCH}")
set_gencode_flag_for_srcs(
SRCS ${arg_SRCS}
ARCH "compute_${_STRIPPED_ARCH}"
CODE "compute_${_STRIPPED_ARCH}")
else()
string(REPLACE "." "" _STRIPPED_ARCH "${_ARCH}")
set_gencode_flag_for_srcs(
SRCS ${arg_SRCS}
ARCH "compute_${_STRIPPED_ARCH}"
CODE "sm_${_STRIPPED_ARCH}")
endif()
endforeach() endforeach()
if (${arg_BUILD_PTX_FOR_ARCH}) if (${arg_BUILD_PTX_FOR_ARCH})
...@@ -252,7 +267,10 @@ endmacro() ...@@ -252,7 +267,10 @@ endmacro()
# #
# For the given `SRC_CUDA_ARCHS` list of gencode versions in the form # For the given `SRC_CUDA_ARCHS` list of gencode versions in the form
# `<major>.<minor>[letter]` compute the "loose intersection" with the # `<major>.<minor>[letter]` compute the "loose intersection" with the
# `TGT_CUDA_ARCHS` list of gencodes. # `TGT_CUDA_ARCHS` list of gencodes. We also support the `+PTX` suffix in
# `SRC_CUDA_ARCHS` which indicates that the PTX code should be built when there
# is a CUDA_ARCH in `TGT_CUDA_ARCHS` that is equal to or larger than the
# architecture in `SRC_CUDA_ARCHS`.
# The loose intersection is defined as: # The loose intersection is defined as:
# { max{ x \in tgt | x <= y } | y \in src, { x \in tgt | x <= y } != {} } # { max{ x \in tgt | x <= y } | y \in src, { x \in tgt | x <= y } != {} }
# where `<=` is the version comparison operator. # where `<=` is the version comparison operator.
...@@ -269,44 +287,63 @@ endmacro() ...@@ -269,44 +287,63 @@ endmacro()
# cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS) # cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
# OUT_CUDA_ARCHS="8.0;8.6;9.0;9.0a" # OUT_CUDA_ARCHS="8.0;8.6;9.0;9.0a"
# #
# Example With PTX:
# SRC_CUDA_ARCHS="8.0+PTX"
# TGT_CUDA_ARCHS="9.0"
# cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
# OUT_CUDA_ARCHS="8.0+PTX"
#
function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS) function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
list(REMOVE_DUPLICATES SRC_CUDA_ARCHS) set(_SRC_CUDA_ARCHS "${SRC_CUDA_ARCHS}")
set(TGT_CUDA_ARCHS_ ${TGT_CUDA_ARCHS}) set(_TGT_CUDA_ARCHS ${TGT_CUDA_ARCHS})
# handle +PTX suffix: separate base arch for matching, record PTX requests
set(_PTX_ARCHS)
foreach(_arch ${_SRC_CUDA_ARCHS})
if(_arch MATCHES "\\+PTX$")
string(REPLACE "+PTX" "" _base "${_arch}")
list(APPEND _PTX_ARCHS "${_base}")
list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_arch}")
list(APPEND _SRC_CUDA_ARCHS "${_base}")
endif()
endforeach()
list(REMOVE_DUPLICATES _PTX_ARCHS)
list(REMOVE_DUPLICATES _SRC_CUDA_ARCHS)
# if x.0a is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should # if x.0a is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should
# remove x.0a from SRC_CUDA_ARCHS and add x.0a to _CUDA_ARCHS # remove x.0a from SRC_CUDA_ARCHS and add x.0a to _CUDA_ARCHS
set(_CUDA_ARCHS) set(_CUDA_ARCHS)
if ("9.0a" IN_LIST SRC_CUDA_ARCHS) if ("9.0a" IN_LIST _SRC_CUDA_ARCHS)
list(REMOVE_ITEM SRC_CUDA_ARCHS "9.0a") list(REMOVE_ITEM _SRC_CUDA_ARCHS "9.0a")
if ("9.0" IN_LIST TGT_CUDA_ARCHS_) if ("9.0" IN_LIST TGT_CUDA_ARCHS)
list(REMOVE_ITEM TGT_CUDA_ARCHS_ "9.0") list(REMOVE_ITEM _TGT_CUDA_ARCHS "9.0")
set(_CUDA_ARCHS "9.0a") set(_CUDA_ARCHS "9.0a")
endif() endif()
endif() endif()
if ("10.0a" IN_LIST SRC_CUDA_ARCHS) if ("10.0a" IN_LIST _SRC_CUDA_ARCHS)
list(REMOVE_ITEM SRC_CUDA_ARCHS "10.0a") list(REMOVE_ITEM _SRC_CUDA_ARCHS "10.0a")
if ("10.0" IN_LIST TGT_CUDA_ARCHS) if ("10.0" IN_LIST TGT_CUDA_ARCHS)
list(REMOVE_ITEM TGT_CUDA_ARCHS_ "10.0") list(REMOVE_ITEM _TGT_CUDA_ARCHS "10.0")
set(_CUDA_ARCHS "10.0a") set(_CUDA_ARCHS "10.0a")
endif() endif()
endif() endif()
list(SORT SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING) list(SORT _SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
# for each ARCH in TGT_CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that # for each ARCH in TGT_CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that
# is less or equal to ARCH (but has the same major version since SASS binary # is less or equal to ARCH (but has the same major version since SASS binary
# compatibility is only forward compatible within the same major version). # compatibility is only forward compatible within the same major version).
foreach(_ARCH ${TGT_CUDA_ARCHS_}) foreach(_ARCH ${_TGT_CUDA_ARCHS})
set(_TMP_ARCH) set(_TMP_ARCH)
# Extract the major version of the target arch # Extract the major version of the target arch
string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" TGT_ARCH_MAJOR "${_ARCH}") string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" TGT_ARCH_MAJOR "${_ARCH}")
foreach(_SRC_ARCH ${SRC_CUDA_ARCHS}) foreach(_SRC_ARCH ${_SRC_CUDA_ARCHS})
# Extract the major version of the source arch # Extract the major version of the source arch
string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" SRC_ARCH_MAJOR "${_SRC_ARCH}") string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" SRC_ARCH_MAJOR "${_SRC_ARCH}")
# Check major-version match AND version-less-or-equal # Check version-less-or-equal, and allow PTX arches to match across majors
if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH) if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH)
if (SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR) if (_SRC_ARCH IN_LIST _PTX_ARCHS OR SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR)
set(_TMP_ARCH "${_SRC_ARCH}") set(_TMP_ARCH "${_SRC_ARCH}")
endif() endif()
else() else()
...@@ -322,6 +359,18 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR ...@@ -322,6 +359,18 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR
endforeach() endforeach()
list(REMOVE_DUPLICATES _CUDA_ARCHS) list(REMOVE_DUPLICATES _CUDA_ARCHS)
# reapply +PTX suffix to architectures that requested PTX
set(_FINAL_ARCHS)
foreach(_arch ${_CUDA_ARCHS})
if(_arch IN_LIST _PTX_ARCHS)
list(APPEND _FINAL_ARCHS "${_arch}+PTX")
else()
list(APPEND _FINAL_ARCHS "${_arch}")
endif()
endforeach()
set(_CUDA_ARCHS ${_FINAL_ARCHS})
set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE) set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE)
endfunction() endfunction()
......
...@@ -70,6 +70,9 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) { ...@@ -70,6 +70,9 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
int64_t num_tokens = input.numel() / input.size(-1); \ int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens); \ dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \ dim3 block(std::min(d, 1024)); \
if (num_tokens == 0) { \
return; \
} \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \ VLLM_DISPATCH_FLOATING_TYPES( \
......
...@@ -17,660 +17,660 @@ ...@@ -17,660 +17,660 @@
* limitations under the License. * limitations under the License.
*/ */
#include <torch/all.h> #include <torch/all.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h> #include <c10/cuda/CUDAGuard.h>
#include <algorithm> #include <algorithm>
#include "attention_dtypes.h" #include "attention_dtypes.h"
#include "attention_utils.cuh" #include "attention_utils.cuh"
#ifdef USE_ROCM #ifdef USE_ROCM
#include <hip/hip_bf16.h> #include <hip/hip_bf16.h>
#include "../quantization/fp8/amd/quant_utils.cuh" #include "../quantization/fp8/amd/quant_utils.cuh"
typedef __hip_bfloat16 __nv_bfloat16; typedef __hip_bfloat16 __nv_bfloat16;
#else #else
#include "../quantization/fp8/nvidia/quant_utils.cuh" #include "../quantization/fp8/nvidia/quant_utils.cuh"
#endif #endif
#ifndef USE_ROCM #ifndef USE_ROCM
#define WARP_SIZE 32 #define WARP_SIZE 32
#else #else
#define WARP_SIZE warpSize #define WARP_SIZE warpSize
#endif #endif
#define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
namespace vllm { namespace vllm {
// Utility function for attention softmax. // Utility function for attention softmax.
template <int NUM_WARPS> template <int NUM_WARPS>
inline __device__ float block_sum(float* red_smem, float sum) { inline __device__ float block_sum(float* red_smem, float sum) {
// Decompose the thread index into warp / lane. // Decompose the thread index into warp / lane.
int warp = threadIdx.x / WARP_SIZE; int warp = threadIdx.x / WARP_SIZE;
int lane = threadIdx.x % WARP_SIZE; int lane = threadIdx.x % WARP_SIZE;
// Compute the sum per warp. // Compute the sum per warp.
#pragma unroll #pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
sum += VLLM_SHFL_XOR_SYNC(sum, mask); sum += VLLM_SHFL_XOR_SYNC(sum, mask);
} }
// Warp leaders store the data to shared memory. // Warp leaders store the data to shared memory.
if (lane == 0) { if (lane == 0) {
red_smem[warp] = sum; red_smem[warp] = sum;
} }
// Make sure the data is in shared memory. // Make sure the data is in shared memory.
__syncthreads(); __syncthreads();
// The warps compute the final sums. // The warps compute the final sums.
if (lane < NUM_WARPS) { if (lane < NUM_WARPS) {
sum = red_smem[lane]; sum = red_smem[lane];
} }
// Parallel reduction inside the warp. // Parallel reduction inside the warp.
#pragma unroll #pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
sum += VLLM_SHFL_XOR_SYNC(sum, mask); sum += VLLM_SHFL_XOR_SYNC(sum, mask);
} }
// Broadcast to other threads. // Broadcast to other threads.
return VLLM_SHFL_SYNC(sum, 0); return VLLM_SHFL_SYNC(sum, 0);
} }
// TODO(woosuk): Merge the last two dimensions of the grid. // TODO(woosuk): Merge the last two dimensions of the grid.
// Grid: (num_heads, num_seqs, max_num_partitions). // Grid: (num_heads, num_seqs, max_num_partitions).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE, template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE, int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
bool IS_BLOCK_SPARSE, bool IS_BLOCK_SPARSE,
int PARTITION_SIZE = 0> // Zero means no partitioning. int PARTITION_SIZE = 0> // Zero means no partitioning.
__device__ void paged_attention_kernel( __device__ void paged_attention_kernel(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads, float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions] // max_num_partitions]
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions,
// head_size] // head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x] // head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size] // head_size, block_size]
const int num_kv_heads, // [num_heads] const int num_kv_heads, // [num_heads]
const float scale, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ seq_lens, // [num_seqs] const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float* k_scale, const float* v_scale, const int tp_rank, const float* k_scale, const float* v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) { const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
const int seq_idx = blockIdx.y; const int seq_idx = blockIdx.y;
const int partition_idx = blockIdx.z; const int partition_idx = blockIdx.z;
const int max_num_partitions = gridDim.z; const int max_num_partitions = gridDim.z;
constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0; constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0;
const int seq_len = seq_lens[seq_idx]; const int seq_len = seq_lens[seq_idx];
if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) { if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) {
// No work to do. Terminate the thread block. // No work to do. Terminate the thread block.
return; return;
} }
const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE); const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE);
const int num_blocks_per_partition = const int num_blocks_per_partition =
USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks; USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks;
// [start_block_idx, end_block_idx) is the range of blocks to process. // [start_block_idx, end_block_idx) is the range of blocks to process.
const int start_block_idx = const int start_block_idx =
USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0; USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0;
const int end_block_idx = const int end_block_idx =
MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks); MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks);
const int num_blocks = end_block_idx - start_block_idx; const int num_blocks = end_block_idx - start_block_idx;
// [start_token_idx, end_token_idx) is the range of tokens to process. // [start_token_idx, end_token_idx) is the range of tokens to process.
const int start_token_idx = start_block_idx * BLOCK_SIZE; const int start_token_idx = start_block_idx * BLOCK_SIZE;
const int end_token_idx = const int end_token_idx =
MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len); MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len);
const int num_tokens = end_token_idx - start_token_idx; const int num_tokens = end_token_idx - start_token_idx;
constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1);
constexpr int NUM_THREAD_GROUPS = constexpr int NUM_THREAD_GROUPS =
NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE
// divides NUM_THREADS // divides NUM_THREADS
assert(NUM_THREADS % THREAD_GROUP_SIZE == 0); assert(NUM_THREADS % THREAD_GROUP_SIZE == 0);
constexpr int NUM_TOKENS_PER_THREAD_GROUP = constexpr int NUM_TOKENS_PER_THREAD_GROUP =
DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE); DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE);
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
const int thread_idx = threadIdx.x; const int thread_idx = threadIdx.x;
const int warp_idx = thread_idx / WARP_SIZE; const int warp_idx = thread_idx / WARP_SIZE;
const int lane = thread_idx % WARP_SIZE; const int lane = thread_idx % WARP_SIZE;
const int head_idx = blockIdx.x; const int head_idx = blockIdx.x;
const int num_heads = gridDim.x; const int num_heads = gridDim.x;
const int num_queries_per_kv = num_heads / num_kv_heads; const int num_queries_per_kv = num_heads / num_kv_heads;
const int kv_head_idx = head_idx / num_queries_per_kv; const int kv_head_idx = head_idx / num_queries_per_kv;
const float alibi_slope = const float alibi_slope =
alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
// A vector type to store a part of a key or a query. // A vector type to store a part of a key or a query.
// The vector size is configured in such a way that the threads in a thread // The vector size is configured in such a way that the threads in a thread
// group fetch or compute 16 bytes at a time. For example, if the size of a // group fetch or compute 16 bytes at a time. For example, if the size of a
// thread group is 4 and the data type is half, then the vector size is 16 / // thread group is 4 and the data type is half, then the vector size is 16 /
// (4 * sizeof(half)) == 2. // (4 * sizeof(half)) == 2.
constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1); constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type; using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type; using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type; using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;
constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE; constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE; const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE;
const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE; const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE;
// Load the query to registers. // Load the query to registers.
// Each thread in a thread group has a different part of the query. // Each thread in a thread group has a different part of the query.
// For example, if the the thread group size is 4, then the first thread in // For example, if the thread group size is 4, then the first thread in
// the group has 0, 4, 8, ... th vectors of the query, and the second thread // the group has 0, 4, 8, ... th vectors of the query, and the second thread
// has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because // has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because
// q is split from a qkv tensor, it may not be contiguous. // q is split from a qkv tensor, it may not be contiguous.
const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE;
__shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD]; __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD];
#pragma unroll #pragma unroll
for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD;
i += NUM_THREAD_GROUPS) { i += NUM_THREAD_GROUPS) {
const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE;
q_vecs[thread_group_offset][i] = q_vecs[thread_group_offset][i] =
*reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE); *reinterpret_cast<const Q_vec*>(q_ptr + vec_idx * VEC_SIZE);
} }
__syncthreads(); // TODO(naed90): possible speedup if this is replaced with a __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a
// memory wall right before we use q_vecs // memory wall right before we use q_vecs
// Memory planning. // Memory planning.
extern __shared__ char shared_mem[]; extern __shared__ char shared_mem[];
// NOTE(woosuk): We use FP32 for the softmax logits for better accuracy. // NOTE(woosuk): We use FP32 for the softmax logits for better accuracy.
float* logits = reinterpret_cast<float*>(shared_mem); float* logits = reinterpret_cast<float*>(shared_mem);
// Workspace for reduction. // Workspace for reduction.
__shared__ float red_smem[2 * NUM_WARPS]; __shared__ float red_smem[2 * NUM_WARPS];
// x == THREAD_GROUP_SIZE * VEC_SIZE // x == THREAD_GROUP_SIZE * VEC_SIZE
// Each thread group fetches x elements from the key at a time. // Each thread group fetches x elements from the key at a time.
constexpr int x = 16 / sizeof(cache_t); constexpr int x = 16 / sizeof(cache_t);
float qk_max = -FLT_MAX; float qk_max = -FLT_MAX;
// Iterate over the key blocks. // Iterate over the key blocks.
// Each warp fetches a block of keys for each iteration. // Each warp fetches a block of keys for each iteration.
// Each thread group in a warp fetches a key from the block, and computes // Each thread group in a warp fetches a key from the block, and computes
// dot product with the query. // dot product with the query.
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
// blocksparse specific vars // blocksparse specific vars
int bs_block_offset; int bs_block_offset;
int q_bs_block_id; int q_bs_block_id;
if constexpr (IS_BLOCK_SPARSE) { if constexpr (IS_BLOCK_SPARSE) {
// const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len, // const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len,
// blocksparse_block_size); // blocksparse_block_size);
q_bs_block_id = (seq_len - 1) / blocksparse_block_size; q_bs_block_id = (seq_len - 1) / blocksparse_block_size;
if (blocksparse_head_sliding_step >= 0) if (blocksparse_head_sliding_step >= 0)
// sliding on q heads // sliding on q heads
bs_block_offset = bs_block_offset =
(tp_rank * num_heads + head_idx) * blocksparse_head_sliding_step + 1; (tp_rank * num_heads + head_idx) * blocksparse_head_sliding_step + 1;
else else
// sliding on kv heads // sliding on kv heads
bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) * bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) *
(-blocksparse_head_sliding_step) + (-blocksparse_head_sliding_step) +
1; 1;
} }
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
block_idx += NUM_WARPS) { block_idx += NUM_WARPS) {
// NOTE(woosuk): The block number is stored in int32. However, we cast it to // NOTE(woosuk): The block number is stored in int32. However, we cast it to
// int64 because int32 can lead to overflow when this variable is multiplied // int64 because int32 can lead to overflow when this variable is multiplied
// by large numbers (e.g., kv_block_stride). // by large numbers (e.g., kv_block_stride).
// For blocksparse attention: skip computation on blocks that are not // For blocksparse attention: skip computation on blocks that are not
// attended // attended
if constexpr (IS_BLOCK_SPARSE) { if constexpr (IS_BLOCK_SPARSE) {
const int k_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size; const int k_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
const bool is_remote = const bool is_remote =
((k_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0); ((k_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0);
const bool is_local = const bool is_local =
(k_bs_block_id > q_bs_block_id - blocksparse_local_blocks); (k_bs_block_id > q_bs_block_id - blocksparse_local_blocks);
if (!is_remote && !is_local) { if (!is_remote && !is_local) {
for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
const int physical_block_offset = const int physical_block_offset =
(thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
if (thread_group_offset == 0) { if (thread_group_offset == 0) {
// NOTE(linxihui): assign very large number to skipped tokens to // NOTE(linxihui): assign very large number to skipped tokens to
// avoid contribution to the sumexp softmax normalizer. This will // avoid contribution to the sumexp softmax normalizer. This will
// not be used at computing sum(softmax*v) as the blocks will be // not be used at computing sum(softmax*v) as the blocks will be
// skipped. // skipped.
logits[token_idx - start_token_idx] = -FLT_MAX; logits[token_idx - start_token_idx] = -FLT_MAX;
} }
} }
continue; continue;
} }
} }
const int64_t physical_block_number = const int64_t physical_block_number =
static_cast<int64_t>(block_table[block_idx]); static_cast<int64_t>(block_table[block_idx]);
// Load a key to registers. // Load a key to registers.
// Each thread in a thread group has a different part of the key. // Each thread in a thread group has a different part of the key.
// For example, if the the thread group size is 4, then the first thread in // For example, if the thread group size is 4, then the first thread in
// the group has 0, 4, 8, ... th vectors of the key, and the second thread // the group has 0, 4, 8, ... th vectors of the key, and the second thread
// has 1, 5, 9, ... th vectors of the key, and so on. // has 1, 5, 9, ... th vectors of the key, and so on.
for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) {
const int physical_block_offset = const int physical_block_offset =
(thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
K_vec k_vecs[NUM_VECS_PER_THREAD]; K_vec k_vecs[NUM_VECS_PER_THREAD];
#pragma unroll #pragma unroll
for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
const cache_t* k_ptr = const cache_t* k_ptr =
k_cache + physical_block_number * kv_block_stride + k_cache + physical_block_number * kv_block_stride +
kv_head_idx * kv_head_stride + physical_block_offset * x; kv_head_idx * kv_head_stride + physical_block_offset * x;
const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
const int offset1 = (vec_idx * VEC_SIZE) / x; const int offset1 = (vec_idx * VEC_SIZE) / x;
const int offset2 = (vec_idx * VEC_SIZE) % x; const int offset2 = (vec_idx * VEC_SIZE) % x;
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
k_vecs[j] = *reinterpret_cast<const K_vec*>( k_vecs[j] = *reinterpret_cast<const K_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2); k_ptr + offset1 * BLOCK_SIZE * x + offset2);
} else { } else {
// Vector conversion from Quant_vec to K_vec. // Vector conversion from Quant_vec to K_vec.
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>( Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2); k_ptr + offset1 * BLOCK_SIZE * x + offset2);
k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>( k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
k_vec_quant, *k_scale); k_vec_quant, *k_scale);
} }
} }
// Compute dot product. // Compute dot product.
// This includes a reduction across the threads in the same thread group. // This includes a reduction across the threads in the same thread group.
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot( float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(
q_vecs[thread_group_offset], k_vecs); q_vecs[thread_group_offset], k_vecs);
// Add the ALiBi bias if slopes are given. // Add the ALiBi bias if slopes are given.
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0; qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0;
if (thread_group_offset == 0) { if (thread_group_offset == 0) {
// Store the partial reductions to shared memory. // Store the partial reductions to shared memory.
// NOTE(woosuk): It is required to zero out the masked logits. // NOTE(woosuk): It is required to zero out the masked logits.
const bool mask = token_idx >= seq_len; const bool mask = token_idx >= seq_len;
logits[token_idx - start_token_idx] = mask ? 0.f : qk; logits[token_idx - start_token_idx] = mask ? 0.f : qk;
// Update the max value. // Update the max value.
qk_max = mask ? qk_max : fmaxf(qk_max, qk); qk_max = mask ? qk_max : fmaxf(qk_max, qk);
} }
} }
} }
// Perform reduction across the threads in the same warp to get the // Perform reduction across the threads in the same warp to get the
// max qk value for each "warp" (not across the thread block yet). // max qk value for each "warp" (not across the thread block yet).
// The 0-th thread of each thread group already has its max qk value. // The 0-th thread of each thread group already has its max qk value.
#pragma unroll #pragma unroll
for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
} }
if (lane == 0) { if (lane == 0) {
red_smem[warp_idx] = qk_max; red_smem[warp_idx] = qk_max;
} }
__syncthreads(); __syncthreads();
// TODO(woosuk): Refactor this part. // TODO(woosuk): Refactor this part.
// Get the max qk value for the sequence. // Get the max qk value for the sequence.
qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll #pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
} }
// Broadcast the max qk value to all threads. // Broadcast the max qk value to all threads.
qk_max = VLLM_SHFL_SYNC(qk_max, 0); qk_max = VLLM_SHFL_SYNC(qk_max, 0);
// Get the sum of the exp values. // Get the sum of the exp values.
float exp_sum = 0.f; float exp_sum = 0.f;
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
float val = __expf(logits[i] - qk_max); float val = __expf(logits[i] - qk_max);
logits[i] = val; logits[i] = val;
exp_sum += val; exp_sum += val;
} }
exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum); exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], exp_sum);
// Compute softmax. // Compute softmax.
const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f);
for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) {
logits[i] *= inv_sum; logits[i] *= inv_sum;
} }
__syncthreads(); __syncthreads();
// If partitioning is enabled, store the max logit and exp_sum. // If partitioning is enabled, store the max logit and exp_sum.
if (USE_PARTITIONING && thread_idx == 0) { if (USE_PARTITIONING && thread_idx == 0) {
float* max_logits_ptr = max_logits + float* max_logits_ptr = max_logits +
seq_idx * num_heads * max_num_partitions + seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions + partition_idx; head_idx * max_num_partitions + partition_idx;
*max_logits_ptr = qk_max; *max_logits_ptr = qk_max;
float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions + float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions + partition_idx; head_idx * max_num_partitions + partition_idx;
*exp_sums_ptr = exp_sum; *exp_sums_ptr = exp_sum;
} }
// Each thread will fetch 16 bytes from the value cache at a time. // Each thread will fetch 16 bytes from the value cache at a time.
constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE); constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type; using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type; using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type; using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type;
using Float_L_vec = typename FloatVec<L_vec>::Type; using Float_L_vec = typename FloatVec<L_vec>::Type;
constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW;
constexpr int NUM_ROWS_PER_THREAD = constexpr int NUM_ROWS_PER_THREAD =
DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER); DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER);
// NOTE(woosuk): We use FP32 for the accumulator for better accuracy. // NOTE(woosuk): We use FP32 for the accumulator for better accuracy.
float accs[NUM_ROWS_PER_THREAD]; float accs[NUM_ROWS_PER_THREAD];
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
accs[i] = 0.f; accs[i] = 0.f;
} }
scalar_t zero_value; scalar_t zero_value;
zero(zero_value); zero(zero_value);
for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx;
block_idx += NUM_WARPS) { block_idx += NUM_WARPS) {
// NOTE(woosuk): The block number is stored in int32. However, we cast it to // NOTE(woosuk): The block number is stored in int32. However, we cast it to
// int64 because int32 can lead to overflow when this variable is multiplied // int64 because int32 can lead to overflow when this variable is multiplied
// by large numbers (e.g., kv_block_stride). // by large numbers (e.g., kv_block_stride).
// For blocksparse attention: skip computation on blocks that are not // For blocksparse attention: skip computation on blocks that are not
// attended // attended
if constexpr (IS_BLOCK_SPARSE) { if constexpr (IS_BLOCK_SPARSE) {
int v_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size; int v_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size;
if (!((v_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0) && if (!((v_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0) &&
!((v_bs_block_id > q_bs_block_id - blocksparse_local_blocks))) { !((v_bs_block_id > q_bs_block_id - blocksparse_local_blocks))) {
continue; continue;
} }
} }
const int64_t physical_block_number = const int64_t physical_block_number =
static_cast<int64_t>(block_table[block_idx]); static_cast<int64_t>(block_table[block_idx]);
const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE;
const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset;
L_vec logits_vec; L_vec logits_vec;
from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx - from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx -
start_token_idx)); start_token_idx));
const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride + const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride +
kv_head_idx * kv_head_stride; kv_head_idx * kv_head_stride;
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE) { if (row_idx < HEAD_SIZE) {
const int offset = row_idx * BLOCK_SIZE + physical_block_offset; const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
V_vec v_vec; V_vec v_vec;
if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) {
v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset); v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
} else { } else {
V_quant_vec v_quant_vec = V_quant_vec v_quant_vec =
*reinterpret_cast<const V_quant_vec*>(v_ptr + offset); *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
// Vector conversion from V_quant_vec to V_vec. // Vector conversion from V_quant_vec to V_vec.
v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec, v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec,
*v_scale); *v_scale);
} }
if (block_idx == num_seq_blocks - 1) { if (block_idx == num_seq_blocks - 1) {
// NOTE(woosuk): When v_vec contains the tokens that are out of the // NOTE(woosuk): When v_vec contains the tokens that are out of the
// context, we should explicitly zero out the values since they may // context, we should explicitly zero out the values since they may
// contain NaNs. See // contain NaNs. See
// https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 // https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472
scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec); scalar_t* v_vec_ptr = reinterpret_cast<scalar_t*>(&v_vec);
#pragma unroll #pragma unroll
for (int j = 0; j < V_VEC_SIZE; j++) { for (int j = 0; j < V_VEC_SIZE; j++) {
v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value; v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value;
} }
} }
accs[i] += dot(logits_vec, v_vec); accs[i] += dot(logits_vec, v_vec);
} }
} }
} }
// Perform reduction within each warp. // Perform reduction within each warp.
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
float acc = accs[i]; float acc = accs[i];
#pragma unroll #pragma unroll
for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
acc += VLLM_SHFL_XOR_SYNC(acc, mask); acc += VLLM_SHFL_XOR_SYNC(acc, mask);
} }
accs[i] = acc; accs[i] = acc;
} }
// NOTE(woosuk): A barrier is required because the shared memory space for // NOTE(woosuk): A barrier is required because the shared memory space for
// logits is reused for the output. // logits is reused for the output.
__syncthreads(); __syncthreads();
// Perform reduction across warps. // Perform reduction across warps.
float* out_smem = reinterpret_cast<float*>(shared_mem); float* out_smem = reinterpret_cast<float*>(shared_mem);
#pragma unroll #pragma unroll
for (int i = NUM_WARPS; i > 1; i /= 2) { for (int i = NUM_WARPS; i > 1; i /= 2) {
int mid = i / 2; int mid = i / 2;
// Upper warps write to shared memory. // Upper warps write to shared memory.
if (warp_idx >= mid && warp_idx < i) { if (warp_idx >= mid && warp_idx < i) {
float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE]; float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE];
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
dst[row_idx] = accs[i]; dst[row_idx] = accs[i];
} }
} }
} }
__syncthreads(); __syncthreads();
// Lower warps update the output. // Lower warps update the output.
if (warp_idx < mid) { if (warp_idx < mid) {
const float* src = &out_smem[warp_idx * HEAD_SIZE]; const float* src = &out_smem[warp_idx * HEAD_SIZE];
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
accs[i] += src[row_idx]; accs[i] += src[row_idx];
} }
} }
} }
__syncthreads(); __syncthreads();
} }
// Write the final output. // Write the final output.
if (warp_idx == 0) { if (warp_idx == 0) {
scalar_t* out_ptr = scalar_t* out_ptr =
out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE; head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE;
#pragma unroll #pragma unroll
for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) {
from_float(*(out_ptr + row_idx), accs[i]); from_float(*(out_ptr + row_idx), accs[i]);
} }
} }
} }
} }
// Grid: (num_heads, num_seqs, 1). // Grid: (num_heads, num_seqs, 1).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE, template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE, int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
bool IS_BLOCK_SPARSE> bool IS_BLOCK_SPARSE>
__global__ void paged_attention_v1_kernel( __global__ void paged_attention_v1_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x] // head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size] // head_size, block_size]
const int num_kv_heads, // [num_heads] const int num_kv_heads, // [num_heads]
const float scale, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ seq_lens, // [num_seqs] const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float* k_scale, const float* v_scale, const int tp_rank, const float* k_scale, const float* v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) { const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE>( KV_DTYPE, IS_BLOCK_SPARSE>(
/* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache, /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
v_cache, num_kv_heads, scale, block_tables, seq_lens, v_cache, num_kv_heads, scale, block_tables, seq_lens,
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks, kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks,
blocksparse_vert_stride, blocksparse_block_size, blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step); blocksparse_head_sliding_step);
} }
// Grid: (num_heads, num_seqs, max_num_partitions). // Grid: (num_heads, num_seqs, max_num_partitions).
template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE, template <typename scalar_t, typename cache_t, int HEAD_SIZE, int BLOCK_SIZE,
int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE, int NUM_THREADS, vllm::Fp8KVCacheDataType KV_DTYPE,
bool IS_BLOCK_SPARSE, bool IS_BLOCK_SPARSE,
int PARTITION_SIZE> int PARTITION_SIZE>
__global__ void paged_attention_v2_kernel( __global__ void paged_attention_v2_kernel(
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads, float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions]
scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
const int num_kv_heads, // [num_heads]
const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ seq_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float* k_scale, const float* v_scale, const int tp_rank,
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE, PARTITION_SIZE>(
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank,
blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step);
}
// Grid: (num_heads, num_seqs).
template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS,
int PARTITION_SIZE>
__global__ void paged_attention_v2_reduce_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const float* __restrict__ exp_sums, // [num_seqs, num_heads,
// max_num_partitions] // max_num_partitions]
scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, const float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions]
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size] // max_num_partitions, head_size]
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const int* __restrict__ seq_lens, // [num_seqs]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, const int max_num_partitions) {
// head_size/x, block_size, x] const int num_heads = gridDim.x;
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, const int head_idx = blockIdx.x;
// head_size, block_size] const int seq_idx = blockIdx.y;
const int num_kv_heads, // [num_heads] const int seq_len = seq_lens[seq_idx];
const float scale, const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE);
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] if (num_partitions == 1) {
const int* __restrict__ seq_lens, // [num_seqs] // No need to reduce. Only copy tmp_out to out.
const int max_num_blocks_per_seq, scalar_t* out_ptr =
const float* __restrict__ alibi_slopes, // [num_heads] out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
const int q_stride, const int kv_block_stride, const int kv_head_stride, const scalar_t* tmp_out_ptr =
const float* k_scale, const float* v_scale, const int tp_rank, tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
const int blocksparse_local_blocks, const int blocksparse_vert_stride, head_idx * max_num_partitions * HEAD_SIZE;
const int blocksparse_block_size, const int blocksparse_head_sliding_step) { for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, out_ptr[i] = tmp_out_ptr[i];
KV_DTYPE, IS_BLOCK_SPARSE, PARTITION_SIZE>( }
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, // Terminate the thread block.
block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, return;
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, }
blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
blocksparse_head_sliding_step); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
} const int warp_idx = threadIdx.x / WARP_SIZE;
const int lane = threadIdx.x % WARP_SIZE;
// Grid: (num_heads, num_seqs).
template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS, // Size: 2 * num_partitions.
int PARTITION_SIZE> extern __shared__ char shared_mem[];
__global__ void paged_attention_v2_reduce_kernel( // Workspace for reduction.
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] __shared__ float red_smem[2 * NUM_WARPS];
const float* __restrict__ exp_sums, // [num_seqs, num_heads,
// max_num_partitions] // Load max logits to shared memory.
const float* __restrict__ max_logits, // [num_seqs, num_heads, float* shared_max_logits = reinterpret_cast<float*>(shared_mem);
// max_num_partitions] const float* max_logits_ptr = max_logits +
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, seq_idx * num_heads * max_num_partitions +
// max_num_partitions, head_size] head_idx * max_num_partitions;
const int* __restrict__ seq_lens, // [num_seqs] float max_logit = -FLT_MAX;
const int max_num_partitions) { for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
const int num_heads = gridDim.x; const float l = max_logits_ptr[i];
const int head_idx = blockIdx.x; shared_max_logits[i] = l;
const int seq_idx = blockIdx.y; max_logit = fmaxf(max_logit, l);
const int seq_len = seq_lens[seq_idx]; }
const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE); __syncthreads();
if (num_partitions == 1) {
// No need to reduce. Only copy tmp_out to out. // Get the global max logit.
scalar_t* out_ptr = // Reduce within the warp.
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; #pragma unroll
const scalar_t* tmp_out_ptr = for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
head_idx * max_num_partitions * HEAD_SIZE; }
for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) { if (lane == 0) {
out_ptr[i] = tmp_out_ptr[i]; red_smem[warp_idx] = max_logit;
} }
// Terminate the thread block. __syncthreads();
return; // Reduce across warps.
} max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
const int warp_idx = threadIdx.x / WARP_SIZE; max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
const int lane = threadIdx.x % WARP_SIZE; }
// Broadcast the max value to all threads.
// Size: 2 * num_partitions. max_logit = VLLM_SHFL_SYNC(max_logit, 0);
extern __shared__ char shared_mem[];
// Workspace for reduction. // Load rescaled exp sums to shared memory.
__shared__ float red_smem[2 * NUM_WARPS]; float* shared_exp_sums =
reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
// Load max logits to shared memory. const float* exp_sums_ptr = exp_sums +
float* shared_max_logits = reinterpret_cast<float*>(shared_mem); seq_idx * num_heads * max_num_partitions +
const float* max_logits_ptr = max_logits + head_idx * max_num_partitions;
seq_idx * num_heads * max_num_partitions + float global_exp_sum = 0.0f;
head_idx * max_num_partitions; for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
float max_logit = -FLT_MAX; float l = shared_max_logits[i];
for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit);
const float l = max_logits_ptr[i]; global_exp_sum += rescaled_exp_sum;
shared_max_logits[i] = l; shared_exp_sums[i] = rescaled_exp_sum;
max_logit = fmaxf(max_logit, l); }
} __syncthreads();
__syncthreads(); global_exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], global_exp_sum);
const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);
// Get the global max logit.
// Reduce within the warp. // Aggregate tmp_out to out.
#pragma unroll const scalar_t* tmp_out_ptr =
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask)); head_idx * max_num_partitions * HEAD_SIZE;
} scalar_t* out_ptr =
if (lane == 0) { out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
red_smem[warp_idx] = max_logit; #pragma unroll
} for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
__syncthreads(); float acc = 0.0f;
// Reduce across warps. for (int j = 0; j < num_partitions; ++j) {
max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] *
#pragma unroll inv_global_exp_sum;
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { }
max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask)); from_float(out_ptr[i], acc);
} }
// Broadcast the max value to all threads. }
max_logit = VLLM_SHFL_SYNC(max_logit, 0);
} // namespace vllm
// Load rescaled exp sums to shared memory.
float* shared_exp_sums = #undef WARP_SIZE
reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions); #undef MAX
const float* exp_sums_ptr = exp_sums + #undef MIN
seq_idx * num_heads * max_num_partitions + #undef DIVIDE_ROUND_UP
head_idx * max_num_partitions; \ No newline at end of file
float global_exp_sum = 0.0f;
for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) {
float l = shared_max_logits[i];
float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit);
global_exp_sum += rescaled_exp_sum;
shared_exp_sums[i] = rescaled_exp_sum;
}
__syncthreads();
global_exp_sum = block_sum<NUM_WARPS>(&red_smem[NUM_WARPS], global_exp_sum);
const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f);
// Aggregate tmp_out to out.
const scalar_t* tmp_out_ptr =
tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
head_idx * max_num_partitions * HEAD_SIZE;
scalar_t* out_ptr =
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
#pragma unroll
for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) {
float acc = 0.0f;
for (int j = 0; j < num_partitions; ++j) {
acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] *
inv_global_exp_sum;
}
from_float(out_ptr[i], acc);
}
}
} // namespace vllm
#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP
\ No newline at end of file
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
#include <assert.h>
#include <cuda.h>
#include <torch/all.h>
__device__ int64_t save_blocks(int* block_offset, int64_t range_start,
int64_t range_end, int64_t block_size,
int64_t input_block_count, int64_t kv_seqlen) {
if (range_start >= kv_seqlen) {
return input_block_count;
}
if (range_end > kv_seqlen) {
range_end = kv_seqlen;
}
int64_t current_block_count = input_block_count;
for (int idx = range_start; idx < range_end; idx += block_size) {
block_offset[current_block_count++] = idx;
}
return current_block_count;
}
__global__ void convert_vertical_slash_indexes_kernel(
const int* q_seqlens, // [BATCH, ]
const int* kv_seqlens, // [BATCH, ]
const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S]
int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
int64_t NNZ_V, int64_t NNZ_S,
bool causal // True for intra, False for succ
) {
const int batch_idx = blockIdx.y;
const int head_idx = blockIdx.x;
const int group_idx = blockIdx.z;
int64_t q_seqlen = q_seqlens[batch_idx];
int64_t kv_seqlen = kv_seqlens[batch_idx];
int64_t block_idx_m = group_idx * blockDim.x + threadIdx.x;
int64_t start_m = block_idx_m * BLOCK_SIZE_M;
if (start_m >= q_seqlen) {
return;
}
int64_t end_m = start_m + BLOCK_SIZE_M;
vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V;
slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S;
int64_t row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m;
block_count += row_offset;
block_offset += row_offset * NNZ_S;
column_count += row_offset;
column_index += row_offset * NNZ_V;
bool has_slash = true;
int64_t tmp_col_cnt = 0, tmp_blk_cnt = 0;
int64_t s = 0, v = 0;
int64_t v_idx = vertical_indexes[v++];
int64_t s_idx = slash_indexes[s++];
if (causal) {
while (s_idx >= end_m + (kv_seqlen - q_seqlen) && s < NNZ_S) {
s_idx = slash_indexes[s++];
}
if (s_idx > end_m + (kv_seqlen - q_seqlen)) has_slash = false;
s_idx = max((kv_seqlen - q_seqlen) + end_m - s_idx, BLOCK_SIZE_M);
} else {
while (s_idx >= end_m + kv_seqlen && s < NNZ_S) {
s_idx = slash_indexes[s++];
}
if (s_idx > end_m + kv_seqlen) has_slash = false;
s_idx = max(kv_seqlen + end_m - s_idx, BLOCK_SIZE_M);
}
int64_t range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx;
if (!has_slash) {
if (causal) {
range_start = (kv_seqlen - q_seqlen) + end_m;
range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
} else {
range_start = kv_seqlen;
range_end = kv_seqlen + BLOCK_SIZE_N;
}
}
bool slash_finished = false;
while (1) {
if (v_idx < range_end) {
if (v_idx < range_start) {
column_index[tmp_col_cnt++] = v_idx;
}
if (v < NNZ_V) {
v_idx = vertical_indexes[v++];
} else {
if (causal)
v_idx = end_m + BLOCK_SIZE_N + (kv_seqlen - q_seqlen);
else
v_idx = end_m + BLOCK_SIZE_N + kv_seqlen;
}
} else {
if ((s < NNZ_S && causal) ||
(s < NNZ_S && !causal && slash_indexes[s] >= start_m)) {
if (causal)
s_idx = max((kv_seqlen - q_seqlen) + end_m - slash_indexes[s++],
BLOCK_SIZE_M);
else
s_idx = max(kv_seqlen + end_m - slash_indexes[s++], BLOCK_SIZE_M);
} else {
if (v == NNZ_V || (v_idx > range_start && causal)) {
// add the last vertical if no more slash
if (v == NNZ_V && !causal && v_idx < kv_seqlen) {
column_index[tmp_col_cnt++] = v_idx;
}
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
break;
} else {
if (causal) {
range_start = (kv_seqlen - q_seqlen) + end_m;
range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
} else {
// if slash_finished but there are vertical left, save current
// blocks
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
range_start = kv_seqlen;
range_end = kv_seqlen + BLOCK_SIZE_N;
}
slash_finished = true;
}
}
if (!slash_finished) {
if (s_idx > range_end + BLOCK_SIZE_M) {
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
range_start = s_idx - BLOCK_SIZE_M;
range_end = s_idx;
} else if (s_idx > range_end) {
range_end += BLOCK_SIZE_M;
}
}
}
}
block_count[0] = tmp_blk_cnt;
column_count[0] = tmp_col_cnt;
}
void convert_vertical_slash_indexes_64x64(
const int* q_seqlens, // [BATCH, ]
const int* kv_seqlens, // [BATCH, ]
const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S]
int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
int64_t BATCH_SIZE, int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M,
int64_t BLOCK_SIZE_N, int64_t NNZ_V, int64_t NNZ_S, bool causal) {
const int N_THREADS = 64;
const dim3 dimBlock(N_THREADS);
const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS);
convert_vertical_slash_indexes_kernel<<<dimGrid, dimBlock>>>(
q_seqlens, kv_seqlens, vertical_indexes, slash_indexes, block_count,
block_offset, column_count, column_index, N_HEADS, N_ROWS, BLOCK_SIZE_M,
BLOCK_SIZE_N, NNZ_V, NNZ_S, causal);
}
/**
* Implements the Algorithm 4 in paper https://arxiv.org/abs/2407.02490.
*
* This function builds the index of each row of blocks from vertical indices
* and slash indices. The vertical indices are treated as points, while the
* slash indices are converted as ranges. The output consists of the merged
* ranges and separate column indices, where the ranges are represented by
* block indices.
*
* The implementation is referenced from the original MInference repo:
* https://github.com/microsoft/MInference/blob/main/csrc/vertical_slash_index.cu.
*/
void convert_vertical_slash_indexes(
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
torch::Tensor q_seqlens, // [BATCH, ]
torch::Tensor kv_seqlens, // [BATCH, ]
torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S]
int64_t context_size, int64_t block_size_M, int64_t block_size_N,
bool causal) {
cudaSetDevice(q_seqlens.get_device());
int batch_size = slash_indexes.size(0);
int num_heads = slash_indexes.size(1);
int nnz_slash = slash_indexes.size(2);
int nnz_vertical = vertical_indexes.size(2);
int num_rows = (context_size + block_size_M - 1) / block_size_M;
convert_vertical_slash_indexes_64x64(
q_seqlens.data_ptr<int>(), kv_seqlens.data_ptr<int>(),
vertical_indexes.data_ptr<int>(), slash_indexes.data_ptr<int>(),
block_count.data_ptr<int>(), block_offset.data_ptr<int>(),
column_count.data_ptr<int>(), column_index.data_ptr<int>(), batch_size,
num_heads, num_rows, block_size_M, block_size_N, nnz_vertical, nnz_slash,
causal);
}
__global__ void convert_vertical_slash_indexes_kernel_mergehead(
const int* q_seqlens, // [BATCH, ]
const int* kv_seqlens, // [BATCH, ]
const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S]
const int* per_head_vertical_topkv, const int* per_head_slash_topkv,
int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N,
int64_t NNZ_V, int64_t NNZ_S,
bool causal // True for intra, False for succ
) {
const int batch_idx = blockIdx.y;
const int head_idx = blockIdx.x;
const int group_idx = blockIdx.z;
int64_t q_seqlen = q_seqlens[batch_idx];
int64_t kv_seqlen = kv_seqlens[batch_idx];
int64_t block_idx_m = group_idx * blockDim.x + threadIdx.x;
int64_t start_m = block_idx_m * BLOCK_SIZE_M;
if (start_m >= q_seqlen) {
return;
}
int64_t end_m = start_m + BLOCK_SIZE_M;
vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V;
slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S;
int64_t row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m;
block_count += row_offset;
block_offset += row_offset * NNZ_S;
column_count += row_offset;
column_index += row_offset * NNZ_V;
// MergeHead: each head has it's unique max topk NNZ_V,NNZ_S. (NNZ_V,NNZ_S
// above is buffer size, use to compute offset)
NNZ_S = per_head_slash_topkv[head_idx];
NNZ_V = per_head_vertical_topkv[head_idx];
bool has_slash = true;
int64_t tmp_col_cnt = 0, tmp_blk_cnt = 0;
int64_t s = 0, v = 0;
int64_t v_idx = vertical_indexes[v++];
int64_t s_idx = slash_indexes[s++];
if (causal) {
while (s_idx >= end_m + (kv_seqlen - q_seqlen) && s < NNZ_S) {
s_idx = slash_indexes[s++];
}
if (s_idx > end_m + (kv_seqlen - q_seqlen)) has_slash = false;
s_idx = max((kv_seqlen - q_seqlen) + end_m - s_idx, BLOCK_SIZE_M);
} else {
while (s_idx >= end_m + kv_seqlen && s < NNZ_S) {
s_idx = slash_indexes[s++];
}
if (s_idx > end_m + kv_seqlen) has_slash = false;
s_idx = max(kv_seqlen + end_m - s_idx, BLOCK_SIZE_M);
}
int64_t range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx;
if (!has_slash) {
if (causal) {
range_start = (kv_seqlen - q_seqlen) + end_m;
range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
} else {
range_start = kv_seqlen;
range_end = kv_seqlen + BLOCK_SIZE_N;
}
}
bool slash_finished = false;
while (1) {
if (v_idx < range_end) {
if (v_idx < range_start) {
column_index[tmp_col_cnt++] = v_idx;
}
if (v < NNZ_V) {
v_idx = vertical_indexes[v++];
} else {
if (causal)
v_idx = end_m + BLOCK_SIZE_N + (kv_seqlen - q_seqlen);
else
v_idx = end_m + BLOCK_SIZE_N + kv_seqlen;
}
} else {
if ((s < NNZ_S && causal) ||
(s < NNZ_S && !causal && slash_indexes[s] >= start_m)) {
if (causal)
s_idx = max((kv_seqlen - q_seqlen) + end_m - slash_indexes[s++],
BLOCK_SIZE_M);
else
s_idx = max(kv_seqlen + end_m - slash_indexes[s++], BLOCK_SIZE_M);
} else {
if (v == NNZ_V || (v_idx > range_start && causal)) {
// add the last vertical if no more slash
if (v == NNZ_V && !causal && v_idx < kv_seqlen) {
column_index[tmp_col_cnt++] = v_idx;
}
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
break;
} else {
if (causal) {
range_start = (kv_seqlen - q_seqlen) + end_m;
range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N;
} else {
// if slash_finished but there are vertical left, save current
// blocks
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
range_start = kv_seqlen;
range_end = kv_seqlen + BLOCK_SIZE_N;
}
slash_finished = true;
}
}
if (!slash_finished) {
if (s_idx > range_end + BLOCK_SIZE_M) {
tmp_blk_cnt = save_blocks(block_offset, range_start, range_end,
BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen);
range_start = s_idx - BLOCK_SIZE_M;
range_end = s_idx;
} else if (s_idx > range_end) {
range_end += BLOCK_SIZE_M;
}
}
}
}
block_count[0] = tmp_blk_cnt;
column_count[0] = tmp_col_cnt;
}
void convert_vertical_slash_indexes_64x64_mergehead(
const int* q_seqlens, // [BATCH, ]
const int* kv_seqlens, // [BATCH, ]
const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S]
int* per_head_vertical_topkv, int* per_head_slash_topkv,
int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
int64_t BATCH_SIZE, int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M,
int64_t BLOCK_SIZE_N, int64_t NNZ_V, int64_t NNZ_S, bool causal) {
const int N_THREADS = 64;
const dim3 dimBlock(N_THREADS);
const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS);
convert_vertical_slash_indexes_kernel_mergehead<<<dimGrid, dimBlock>>>(
q_seqlens, kv_seqlens, vertical_indexes, slash_indexes,
per_head_vertical_topkv, per_head_slash_topkv, block_count, block_offset,
column_count, column_index, N_HEADS, N_ROWS, BLOCK_SIZE_M, BLOCK_SIZE_N,
NNZ_V, NNZ_S, causal);
}
/**
* Implements the Algorithm 4 in paper https://arxiv.org/abs/2407.02490.
*
* Like the above convert_vertical_slash_indexes, but with
* pre-computed vertical and slash counts.
*/
void convert_vertical_slash_indexes_mergehead(
torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S]
torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS]
torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V]
torch::Tensor q_seqlens, // [BATCH, ]
torch::Tensor kv_seqlens, // [BATCH, ]
torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S]
torch::Tensor vertical_indices_count, // [N_HEADS, ]
torch::Tensor slash_indices_count, // [N_HEADS, ]
int64_t context_size, int64_t block_size_M, int64_t block_size_N,
bool causal) {
cudaSetDevice(q_seqlens.get_device());
int batch_size = slash_indexes.size(0);
int num_heads = slash_indexes.size(1);
int nnz_slash = slash_indexes.size(2);
int nnz_vertical = vertical_indexes.size(2);
int num_rows = (context_size + block_size_M - 1) / block_size_M;
convert_vertical_slash_indexes_64x64_mergehead(
q_seqlens.data_ptr<int>(), kv_seqlens.data_ptr<int>(),
vertical_indexes.data_ptr<int>(), slash_indexes.data_ptr<int>(),
vertical_indices_count.data_ptr<int>(),
slash_indices_count.data_ptr<int>(), block_count.data_ptr<int>(),
block_offset.data_ptr<int>(), column_count.data_ptr<int>(),
column_index.data_ptr<int>(), batch_size, num_heads, num_rows,
block_size_M, block_size_N, nnz_vertical, nnz_slash, causal);
}
...@@ -7,3 +7,22 @@ inline constexpr uint32_t next_pow_2(uint32_t const num) { ...@@ -7,3 +7,22 @@ inline constexpr uint32_t next_pow_2(uint32_t const num) {
if (num <= 1) return num; if (num <= 1) return num;
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1)); return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
} }
template <typename A, typename B>
static inline constexpr auto div_ceil(A a, B b) {
return (a + b - 1) / b;
}
// Round a down to the next multiple of b. The caller is responsible for making
// sure that b is non-zero
template <typename T>
inline constexpr T round_to_previous_multiple_of(T a, T b) {
return a % b == 0 ? a : (a / b) * b;
}
// Round a up to the next multiple of b. The caller is responsible for making
// sure that b is non-zero
template <typename T>
inline constexpr T round_to_next_multiple_of(T a, T b) {
return a % b == 0 ? a : ((a / b) + 1) * b;
}
...@@ -315,6 +315,8 @@ static inline constexpr auto kS8 = ScalarType::int_(8); ...@@ -315,6 +315,8 @@ static inline constexpr auto kS8 = ScalarType::int_(8);
static inline constexpr auto kU8 = ScalarType::uint(8); static inline constexpr auto kU8 = ScalarType::uint(8);
static inline constexpr auto kU8B128 = ScalarType::uint(8, 128); static inline constexpr auto kU8B128 = ScalarType::uint(8, 128);
static inline constexpr auto kFE2M1f =
ScalarType::float_(2, 1, true, ScalarType::NAN_NONE);
static inline constexpr auto kFE3M2f = static inline constexpr auto kFE3M2f =
ScalarType::float_(3, 2, true, ScalarType::NAN_NONE); ScalarType::float_(3, 2, true, ScalarType::NAN_NONE);
static inline constexpr auto kFE4M3fn = static inline constexpr auto kFE4M3fn =
...@@ -332,6 +334,7 @@ static inline constexpr auto kInt8 = kS8; ...@@ -332,6 +334,7 @@ static inline constexpr auto kInt8 = kS8;
static inline constexpr auto kUint8 = kU8; static inline constexpr auto kUint8 = kU8;
static inline constexpr auto kUint8b128 = kU8B128; static inline constexpr auto kUint8b128 = kU8B128;
static inline constexpr auto kFloat4_e2m1f = kFE2M1f;
static inline constexpr auto kFloat6_e3m2f = kFE3M2f; static inline constexpr auto kFloat6_e3m2f = kFE3M2f;
static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn; static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn;
static inline constexpr auto kFloat8_e5m2 = kFE5M2; static inline constexpr auto kFloat8_e5m2 = kFE5M2;
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#include <altivec.h> #include <altivec.h>
#include <cmath> #include <cmath>
#include <algorithm>
#include <torch/all.h> #include <torch/all.h>
namespace vec_op { namespace vec_op {
...@@ -62,6 +63,10 @@ typedef struct f32x4x4_t { ...@@ -62,6 +63,10 @@ typedef struct f32x4x4_t {
__vector float val[4]; __vector float val[4];
} f32x4x4_t; } f32x4x4_t;
typedef struct i32x4x4_t {
__vector int32_t val[4];
} i32x4x4_t;
struct FP32Vec8; struct FP32Vec8;
struct FP32Vec16; struct FP32Vec16;
...@@ -98,6 +103,28 @@ struct BF16Vec16 : public Vec<BF16Vec16> { ...@@ -98,6 +103,28 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
vec_xst(reg.val[0], 0, (signed short*)ptr); vec_xst(reg.val[0], 0, (signed short*)ptr);
vec_xst(reg.val[1], 16, (signed short*)ptr); vec_xst(reg.val[1], 16, (signed short*)ptr);
} }
void save(void* ptr, const int elem_num) const {
const int clamped_elem = std::max(0, std::min(elem_num, 16));
// Calculate elements to store in each 128-bit part (8 elements each)
const int elements_val0 = std::min(clamped_elem, 8);
const int elements_val1 = std::max(clamped_elem - 8, 0);
// Convert elements to bytes (2 bytes per element)
const size_t bytes_val0 = elements_val0 * sizeof(signed short);
const size_t bytes_val1 = elements_val1 * sizeof(signed short);
signed short* dest = static_cast<signed short*>(ptr);
// Store the first part using vec_xst_len
if (bytes_val0 > 0) {
vec_xst_len(reg.val[0], dest, bytes_val0);
}
// Store the second part if needed
if (bytes_val1 > 0) {
vec_xst_len(reg.val[1], dest + elements_val0, bytes_val1);
}
}
}; };
const static __vector signed short zero = vec_splats((signed short)0); const static __vector signed short zero = vec_splats((signed short)0);
...@@ -257,6 +284,64 @@ struct FP32Vec8 : public Vec<FP32Vec8> { ...@@ -257,6 +284,64 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
} }
}; };
struct INT32Vec16 : public Vec<INT32Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
union AliasReg {
i32x4x4_t reg;
int32_t values[VEC_ELEM_NUM];
};
i32x4x4_t reg;
explicit INT32Vec16(const void* data_ptr) {
reg.val[0] = vec_xl(0, reinterpret_cast<const __vector int32_t*>(data_ptr));
reg.val[1] =
vec_xl(16, reinterpret_cast<const __vector int32_t*>(data_ptr));
reg.val[2] =
vec_xl(32, reinterpret_cast<const __vector int32_t*>(data_ptr));
reg.val[3] =
vec_xl(48, reinterpret_cast<const __vector int32_t*>(data_ptr));
}
void save(int32_t* ptr) const {
vec_xst(reg.val[0], 0, reinterpret_cast<__vector int32_t*>(ptr));
vec_xst(reg.val[1], 16, reinterpret_cast<__vector int32_t*>(ptr));
vec_xst(reg.val[2], 32, reinterpret_cast<__vector int32_t*>(ptr));
vec_xst(reg.val[3], 48, reinterpret_cast<__vector int32_t*>(ptr));
}
void save(int32_t* ptr, const int elem_num) const {
const int elements_in_chunk1 =
(elem_num >= 0) ? ((elem_num >= 4) ? 4 : elem_num) : 0;
const int elements_in_chunk2 =
(elem_num > 4) ? ((elem_num >= 8) ? 4 : elem_num - 4) : 0;
const int elements_in_chunk3 =
(elem_num > 8) ? ((elem_num >= 12) ? 4 : elem_num - 8) : 0;
const int elements_in_chunk4 =
(elem_num > 12) ? ((elem_num >= 16) ? 4 : elem_num - 12) : 0;
const size_t bytes_chunk1 =
static_cast<size_t>(elements_in_chunk1 * sizeof(int32_t));
const size_t bytes_chunk2 =
static_cast<size_t>(elements_in_chunk2 * sizeof(int32_t));
const size_t bytes_chunk3 =
static_cast<size_t>(elements_in_chunk3 * sizeof(int32_t));
const size_t bytes_chunk4 =
static_cast<size_t>(elements_in_chunk4 * sizeof(int32_t));
vec_xst_len(reg.val[0], reinterpret_cast<int32_t*>(ptr), bytes_chunk1);
vec_xst_len(reg.val[1],
reinterpret_cast<int32_t*>(reinterpret_cast<char*>(ptr) + 16),
bytes_chunk2);
vec_xst_len(reg.val[2],
reinterpret_cast<int32_t*>(reinterpret_cast<char*>(ptr) + 32),
bytes_chunk3);
vec_xst_len(reg.val[3],
reinterpret_cast<int32_t*>(reinterpret_cast<char*>(ptr) + 48),
bytes_chunk4);
}
};
struct FP32Vec16 : public Vec<FP32Vec16> { struct FP32Vec16 : public Vec<FP32Vec16> {
constexpr static int VEC_ELEM_NUM = 16; constexpr static int VEC_ELEM_NUM = 16;
union AliasReg { union AliasReg {
...@@ -319,6 +404,13 @@ struct FP32Vec16 : public Vec<FP32Vec16> { ...@@ -319,6 +404,13 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {} explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {}
explicit FP32Vec16(const INT32Vec16& v) {
reg.val[0] = vec_ctf(v.reg.val[0], 0);
reg.val[1] = vec_ctf(v.reg.val[1], 0);
reg.val[2] = vec_ctf(v.reg.val[2], 0);
reg.val[3] = vec_ctf(v.reg.val[3], 0);
}
FP32Vec16 operator*(const FP32Vec16& b) const { FP32Vec16 operator*(const FP32Vec16& b) const {
return FP32Vec16(f32x4x4_t({vec_mul(reg.val[0], b.reg.val[0]), return FP32Vec16(f32x4x4_t({vec_mul(reg.val[0], b.reg.val[0]),
vec_mul(reg.val[1], b.reg.val[1]), vec_mul(reg.val[1], b.reg.val[1]),
...@@ -347,6 +439,117 @@ struct FP32Vec16 : public Vec<FP32Vec16> { ...@@ -347,6 +439,117 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
vec_div(reg.val[3], b.reg.val[3])})); vec_div(reg.val[3], b.reg.val[3])}));
} }
FP32Vec16 clamp(const FP32Vec16& min, const FP32Vec16& max) const {
return FP32Vec16(f32x4x4_t(
{vec_min(max.reg.val[0], vec_max(min.reg.val[0], reg.val[0])),
vec_min(max.reg.val[1], vec_max(min.reg.val[1], reg.val[1])),
vec_min(max.reg.val[2], vec_max(min.reg.val[2], reg.val[2])),
vec_min(max.reg.val[3], vec_max(min.reg.val[3], reg.val[3]))}));
}
FP32Vec16 max(const FP32Vec16& b) const {
return FP32Vec16(f32x4x4_t({vec_max(reg.val[0], b.reg.val[0]),
vec_max(reg.val[1], b.reg.val[1]),
vec_max(reg.val[2], b.reg.val[2]),
vec_max(reg.val[3], b.reg.val[3])}));
}
FP32Vec16 max(const FP32Vec16& b, int elem_num) const {
FP32Vec16 result;
// Create a vector of element indices for each chunk
__vector unsigned int indices = {0, 1, 2, 3};
__vector unsigned int elem_num_vec =
vec_splats(static_cast<unsigned int>(elem_num));
// Compute masks for each chunk
__vector unsigned int chunk_offset0 = {0, 0, 0,
0}; // Chunk 0: Elements 0-3
__vector unsigned int chunk_offset1 = {4, 4, 4,
4}; // Chunk 1: Elements 4-7
__vector unsigned int chunk_offset2 = {8, 8, 8,
8}; // Chunk 2: Elements 8-11
__vector unsigned int chunk_offset3 = {12, 12, 12,
12}; // Chunk 3: Elements 12-15
// Compute masks for each chunk
__vector bool int mask0 = vec_cmplt(indices + chunk_offset0, elem_num_vec);
__vector bool int mask1 = vec_cmplt(indices + chunk_offset1, elem_num_vec);
__vector bool int mask2 = vec_cmplt(indices + chunk_offset2, elem_num_vec);
__vector bool int mask3 = vec_cmplt(indices + chunk_offset3, elem_num_vec);
// Apply masks to compute the result for each chunk
result.reg.val[0] = vec_sel(this->reg.val[0],
vec_max(this->reg.val[0], b.reg.val[0]), mask0);
result.reg.val[1] = vec_sel(this->reg.val[1],
vec_max(this->reg.val[1], b.reg.val[1]), mask1);
result.reg.val[2] = vec_sel(this->reg.val[2],
vec_max(this->reg.val[2], b.reg.val[2]), mask2);
result.reg.val[3] = vec_sel(this->reg.val[3],
vec_max(this->reg.val[3], b.reg.val[3]), mask3);
return FP32Vec16(result.reg);
}
FP32Vec16 min(const FP32Vec16& b) const {
return FP32Vec16(f32x4x4_t({vec_min(reg.val[0], b.reg.val[0]),
vec_min(reg.val[1], b.reg.val[1]),
vec_min(reg.val[2], b.reg.val[2]),
vec_min(reg.val[3], b.reg.val[3])}));
}
FP32Vec16 min(const FP32Vec16& b, int elem_num) const {
FP32Vec16 result;
vector unsigned int indices = {0, 1, 2, 3};
vector unsigned int elem_num_vec =
vec_splats(static_cast<unsigned int>(elem_num));
vector unsigned int chunk_offset0 = {0, 0, 0, 0};
vector unsigned int chunk_offset1 = {4, 4, 4, 4};
vector unsigned int chunk_offset2 = {8, 8, 8, 8};
vector unsigned int chunk_offset3 = {12, 12, 12, 12};
vector bool int mask0 = vec_cmplt(indices + chunk_offset0, elem_num_vec);
vector bool int mask1 = vec_cmplt(indices + chunk_offset1, elem_num_vec);
vector bool int mask2 = vec_cmplt(indices + chunk_offset2, elem_num_vec);
vector bool int mask3 = vec_cmplt(indices + chunk_offset3, elem_num_vec);
result.reg.val[0] = vec_sel(this->reg.val[0],
vec_min(this->reg.val[0], b.reg.val[0]), mask0);
result.reg.val[1] = vec_sel(this->reg.val[1],
vec_min(this->reg.val[1], b.reg.val[1]), mask1);
result.reg.val[2] = vec_sel(this->reg.val[2],
vec_min(this->reg.val[2], b.reg.val[2]), mask2);
result.reg.val[3] = vec_sel(this->reg.val[3],
vec_min(this->reg.val[3], b.reg.val[3]), mask3);
return FP32Vec16(result.reg);
}
FP32Vec16 abs() const {
return FP32Vec16(f32x4x4_t({vec_abs(reg.val[0]), vec_abs(reg.val[1]),
vec_abs(reg.val[2]), vec_abs(reg.val[3])}));
}
float reduce_max() {
__vector float max01 = vec_max(reg.val[0], reg.val[1]);
__vector float max23 = vec_max(reg.val[2], reg.val[3]);
__vector float max_all = vec_max(max01, max23);
__vector float temp = vec_max(max_all, vec_sld(max_all, max_all, 8));
temp = vec_max(temp, vec_sld(temp, temp, 4));
return vec_extract(temp, 0);
}
float reduce_min() {
__vector float min01 = vec_min(reg.val[0], reg.val[1]);
__vector float min23 = vec_min(reg.val[2], reg.val[3]);
__vector float min_all = vec_min(min01, min23);
__vector float temp = vec_min(min_all, vec_sld(min_all, min_all, 8));
temp = vec_min(temp, vec_sld(temp, temp, 4));
return vec_extract(temp, 0);
}
float reduce_sum() const { float reduce_sum() const {
AliasReg ar; AliasReg ar;
ar.reg = reg; ar.reg = reg;
...@@ -377,6 +580,68 @@ struct FP32Vec16 : public Vec<FP32Vec16> { ...@@ -377,6 +580,68 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
vec_xst(reg.val[2], 32, ptr); vec_xst(reg.val[2], 32, ptr);
vec_xst(reg.val[3], 48, ptr); vec_xst(reg.val[3], 48, ptr);
} }
void save(float* ptr, const int elem_num) const {
const int elements_in_chunk1 =
(elem_num >= 0) ? ((elem_num >= 4) ? 4 : elem_num) : 0;
const int elements_in_chunk2 =
(elem_num > 4) ? ((elem_num >= 8) ? 4 : elem_num - 4) : 0;
const int elements_in_chunk3 =
(elem_num > 8) ? ((elem_num >= 12) ? 4 : elem_num - 8) : 0;
const int elements_in_chunk4 =
(elem_num > 12) ? ((elem_num >= 16) ? 4 : elem_num - 12) : 0;
const size_t bytes_chunk1 =
static_cast<size_t>(elements_in_chunk1 * sizeof(float));
const size_t bytes_chunk2 =
static_cast<size_t>(elements_in_chunk2 * sizeof(float));
const size_t bytes_chunk3 =
static_cast<size_t>(elements_in_chunk3 * sizeof(float));
const size_t bytes_chunk4 =
static_cast<size_t>(elements_in_chunk4 * sizeof(float));
vec_xst_len(reg.val[0], ptr, bytes_chunk1);
vec_xst_len(reg.val[1],
reinterpret_cast<float*>(reinterpret_cast<char*>(ptr) + 16),
bytes_chunk2);
vec_xst_len(reg.val[2],
reinterpret_cast<float*>(reinterpret_cast<char*>(ptr) + 32),
bytes_chunk3);
vec_xst_len(reg.val[3],
reinterpret_cast<float*>(reinterpret_cast<char*>(ptr) + 48),
bytes_chunk4);
}
};
struct INT8Vec16 : public Vec<INT8Vec16> {
constexpr static int VEC_NUM_ELEM = 16; // 128 bits / 8 bits = 16
union AliasReg {
__vector signed char reg;
int8_t values[VEC_NUM_ELEM];
};
__vector signed char reg;
explicit INT8Vec16(const FP32Vec16& vec) {
__vector signed int ret[4];
ret[0] = vec_cts(vec.reg.val[0], 0);
ret[1] = vec_cts(vec.reg.val[1], 0);
ret[2] = vec_cts(vec.reg.val[2], 0);
ret[3] = vec_cts(vec.reg.val[3], 0);
__vector signed short packed1 = vec_packs(ret[0], ret[1]);
__vector signed short packed2 = vec_packs(ret[2], ret[3]);
reg = vec_packs(packed1, packed2);
}
void save(void* ptr) const {
*reinterpret_cast<__vector signed char*>(ptr) = reg;
}
void save(signed char* ptr, const int elem_num) {
vec_xst_len(reg, ptr, static_cast<size_t>(elem_num));
}
}; };
template <typename T> template <typename T>
......
...@@ -9,7 +9,8 @@ void rotary_embedding_impl( ...@@ -9,7 +9,8 @@ void rotary_embedding_impl(
scalar_t* __restrict__ query, /// [batch_size, seq_len, num_heads, scalar_t* __restrict__ query, /// [batch_size, seq_len, num_heads,
/// head_size] or [num_tokens, num_heads, /// head_size] or [num_tokens, num_heads,
/// head_size] /// head_size]
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, scalar_t* __restrict__ key, // nullptr (optional) or
// [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads, // head_size] or [num_tokens, num_kv_heads,
// head_size] // head_size]
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
...@@ -85,10 +86,13 @@ void rotary_embedding_impl( ...@@ -85,10 +86,13 @@ void rotary_embedding_impl(
compute_loop(token_head, cache_ptr, query); compute_loop(token_head, cache_ptr, query);
} }
for (int i = 0; i < num_kv_heads; ++i) { if (key != nullptr) {
const int head_idx = i; for (int i = 0; i < num_kv_heads; ++i) {
const int64_t token_head = token_idx * key_stride + head_idx * head_size; const int head_idx = i;
compute_loop(token_head, cache_ptr, key); const int64_t token_head =
token_idx * key_stride + head_idx * head_size;
compute_loop(token_head, cache_ptr, key);
}
} }
} }
} }
...@@ -100,7 +104,8 @@ void rotary_embedding_gptj_impl( ...@@ -100,7 +104,8 @@ void rotary_embedding_gptj_impl(
scalar_t* __restrict__ query, /// [batch_size, seq_len, num_heads, scalar_t* __restrict__ query, /// [batch_size, seq_len, num_heads,
/// head_size] or [num_tokens, num_heads, /// head_size] or [num_tokens, num_heads,
/// head_size] /// head_size]
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, scalar_t* __restrict__ key, // nullptr (optional) or
// [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads, // head_size] or [num_tokens, num_kv_heads,
// head_size] // head_size]
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim //
...@@ -138,6 +143,10 @@ void rotary_embedding_gptj_impl( ...@@ -138,6 +143,10 @@ void rotary_embedding_gptj_impl(
} }
} }
if (key == nullptr) {
return;
}
#pragma omp parallel for collapse(2) #pragma omp parallel for collapse(2)
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
for (int i = 0; i < num_kv_heads; ++i) { for (int i = 0; i < num_kv_heads; ++i) {
...@@ -168,13 +177,13 @@ void rotary_embedding_gptj_impl( ...@@ -168,13 +177,13 @@ void rotary_embedding_gptj_impl(
}; // namespace }; // namespace
void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
torch::Tensor& key, int64_t head_size, std::optional<torch::Tensor> key, int64_t head_size,
torch::Tensor& cos_sin_cache, bool is_neox) { torch::Tensor& cos_sin_cache, bool is_neox) {
int num_tokens = positions.numel(); int num_tokens = positions.numel();
int rot_dim = cos_sin_cache.size(1); int rot_dim = cos_sin_cache.size(1);
int num_heads = query.size(-1) / head_size; int num_heads = query.size(-1) / head_size;
int num_kv_heads = key.size(-1) / head_size; int num_kv_heads = key.has_value() ? key->size(-1) / head_size : num_heads;
int64_t key_stride = key.stride(-2); int64_t key_stride = key.has_value() ? key->stride(-2) : 0;
int64_t query_stride = query.stride(-2); int64_t query_stride = query.stride(-2);
VLLM_DISPATCH_FLOATING_TYPES( VLLM_DISPATCH_FLOATING_TYPES(
...@@ -183,15 +192,15 @@ void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, ...@@ -183,15 +192,15 @@ void rotary_embedding(torch::Tensor& positions, torch::Tensor& query,
if (is_neox) { if (is_neox) {
rotary_embedding_impl( rotary_embedding_impl(
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(), positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(), key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
rot_dim, query_stride, key_stride, num_heads, num_kv_heads, cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride,
head_size, num_tokens); key_stride, num_heads, num_kv_heads, head_size, num_tokens);
} else { } else {
rotary_embedding_gptj_impl( rotary_embedding_gptj_impl(
positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(), positions.data_ptr<int64_t>(), query.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(), cos_sin_cache.data_ptr<scalar_t>(), key.has_value() ? key->data_ptr<scalar_t>() : nullptr,
rot_dim, query_stride, key_stride, num_heads, num_kv_heads, cos_sin_cache.data_ptr<scalar_t>(), rot_dim, query_stride,
head_size, num_tokens); key_stride, num_heads, num_kv_heads, head_size, num_tokens);
} }
CPU_KERNEL_GUARD_OUT(rotary_embedding_impl) CPU_KERNEL_GUARD_OUT(rotary_embedding_impl)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment