Commit 711aa9d5 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.10.0' into v0.10.0-dev

parents 751c492c 6d8d0a24
......@@ -324,6 +324,9 @@ class RandomDataset(BenchmarkDataset):
input_low = int(real_input_len * (1 - range_ratio))
input_high = int(real_input_len * (1 + range_ratio))
output_low = int(output_len * (1 - range_ratio))
# Ensure the lower bound for output length is at least 1 to prevent
# sampling 0 tokens, which can cause request failures.
output_low = max(output_low, 1)
output_high = int(output_len * (1 + range_ratio))
# Add logging for debugging
......@@ -701,6 +704,7 @@ class HuggingFaceDataset(BenchmarkDataset):
self,
dataset_path: str,
dataset_split: str,
no_stream: bool = False,
dataset_subset: Optional[str] = None,
**kwargs,
) -> None:
......@@ -708,6 +712,7 @@ class HuggingFaceDataset(BenchmarkDataset):
self.dataset_split = dataset_split
self.dataset_subset = dataset_subset
self.load_stream = not no_stream
self.load_data()
def load_data(self) -> None:
......@@ -716,7 +721,7 @@ class HuggingFaceDataset(BenchmarkDataset):
self.dataset_path,
name=self.dataset_subset,
split=self.dataset_split,
streaming=True,
streaming=self.load_stream,
)
self.data = self.data.shuffle(seed=self.random_seed)
......
......@@ -30,7 +30,7 @@ import os
import random
import time
import warnings
from collections.abc import AsyncGenerator, Iterable
from collections.abc import Iterable
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Literal, Optional
......@@ -73,6 +73,7 @@ from benchmark_dataset import (
VisionArenaDataset,
)
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
from vllm.benchmarks.serve import get_request
MILLISECONDS_TO_SECONDS_CONVERSION = 1000
......@@ -107,101 +108,6 @@ class BenchmarkMetrics:
percentiles_e2el_ms: list[tuple[float, float]]
def _get_current_request_rate(
ramp_up_strategy: Optional[Literal["linear", "exponential"]],
ramp_up_start_rps: Optional[int],
ramp_up_end_rps: Optional[int],
request_index: int,
total_requests: int,
request_rate: float,
) -> float:
if (
ramp_up_strategy
and ramp_up_start_rps is not None
and ramp_up_end_rps is not None
):
progress = request_index / max(total_requests - 1, 1)
if ramp_up_strategy == "linear":
increase = (ramp_up_end_rps - ramp_up_start_rps) * progress
return ramp_up_start_rps + increase
elif ramp_up_strategy == "exponential":
ratio = ramp_up_end_rps / ramp_up_start_rps
return ramp_up_start_rps * (ratio**progress)
else:
raise ValueError(f"Unknown ramp-up strategy: {ramp_up_strategy}")
return request_rate
async def get_request(
    input_requests: list[SampleRequest],
    request_rate: float,
    burstiness: float = 1.0,
    ramp_up_strategy: Optional[Literal["linear", "exponential"]] = None,
    ramp_up_start_rps: Optional[int] = None,
    ramp_up_end_rps: Optional[int] = None,
) -> AsyncGenerator[tuple[SampleRequest, float], None]:
    """
    Yield each request together with the rate in effect when it is issued,
    pausing between requests so arrivals follow the configured rate.

    Args:
        input_requests:
            The requests to emit. A non-list iterable is materialized into a
            list first so its length is known for ramp-up.
        request_rate:
            Constant rate in requests/s, used when no ramp-up is configured.
            With an infinite rate no pause is inserted between requests.
        burstiness (optional):
            Shape of the gamma distribution the pauses are drawn from.
            Must be positive. A value of 1 yields exponentially distributed
            intervals (a Poisson process); values below 1 give burstier
            arrivals and values above 1 give more uniform arrivals.
        ramp_up_strategy (optional):
            "linear" or "exponential"; None keeps the rate constant.
        ramp_up_start_rps (optional):
            The request rate at the start of the ramp.
        ramp_up_end_rps (optional):
            The request rate at the end of the ramp.
    """
    assert burstiness > 0, (
        f"A positive burstiness factor is expected, but given {burstiness}."
    )

    # The total count is needed to compute ramp-up progress.
    if not isinstance(input_requests, list) and isinstance(input_requests, Iterable):
        input_requests = list(input_requests)
    num_requests = len(input_requests)

    for position, sample in enumerate(input_requests):
        rate = _get_current_request_rate(
            ramp_up_strategy,
            ramp_up_start_rps,
            ramp_up_end_rps,
            position,
            num_requests,
            request_rate,
        )
        yield sample, rate

        if rate != float("inf"):
            # Gamma-distributed gap whose mean is 1 / rate; with
            # burstiness == 1 this is an exponential distribution.
            scale = 1.0 / (rate * burstiness)
            delay = np.random.gamma(shape=burstiness, scale=scale)
            await asyncio.sleep(delay)
def calculate_metrics(
input_requests: list[SampleRequest],
outputs: list[RequestFuncOutput],
......@@ -825,6 +731,7 @@ def main(args: argparse.Namespace):
dataset_subset=args.hf_subset,
dataset_split=args.hf_split,
random_seed=args.seed,
no_stream=args.no_stream,
).sample(
num_requests=args.num_prompts,
tokenizer=tokenizer,
......@@ -1033,6 +940,11 @@ def create_argument_parser():
help="Path to the sharegpt/sonnet dataset. "
"Or the huggingface dataset ID if using HF dataset.",
)
parser.add_argument(
"--no-stream",
action="store_true",
help="Do not load the dataset in streaming mode.",
)
parser.add_argument(
"--max-concurrency",
type=int,
......
......@@ -410,6 +410,7 @@ def get_requests(args, tokenizer):
elif args.dataset_name == "burstgpt":
dataset_cls = BurstGPTDataset
elif args.dataset_name == "hf":
common_kwargs["no_stream"] = args.no_stream
if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
dataset_cls = VisionArenaDataset
common_kwargs["dataset_subset"] = None
......@@ -666,6 +667,11 @@ def create_argument_parser():
help="Name of the dataset to benchmark on.",
default="sharegpt",
)
parser.add_argument(
"--no-stream",
action="store_true",
help="Do not load the dataset in streaming mode.",
)
parser.add_argument(
"--dataset",
type=str,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import copy
import itertools
import torch
from weight_shapes import WEIGHT_SHAPES
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.triton_utils import triton
# Abort at import time when the platform does not report device capability 100,
# which the error message below states NVFP4 requires.
if not current_platform.has_device_capability(100):
raise RuntimeError("NVFP4 requires compute capability of 10.0 (Blackwell)")
# Maximum values of the FP4 (E2M1) and FP8 (E4M3) formats; the global
# quantization scales in this file are derived from their product.
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
# Benchmark providers keyed by name. "enabled" selects whether a provider is
# plotted; "no_a_quant" is read by build_nvfp4_runner to decide whether the
# activation is quantized once up front instead of inside the timed call.
PROVIDER_CFGS = {
"torch-bf16": dict(enabled=True),
"nvfp4": dict(no_a_quant=False, enabled=True),
"nvfp4-noquant": dict(no_a_quant=True, enabled=True),
}
# Names of the providers whose "enabled" flag is set.
_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]]
def _quant_weight_nvfp4(b: torch.Tensor, device: str):
    """Quantize the weight tensor ``b`` to FP4 using a per-tensor global scale.

    The global scale is ``FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX`` divided by the
    largest absolute value of ``b`` (converted to float32). ``device`` is
    accepted but not used here.

    Returns:
        A tuple ``(b_fp4, scale_b_fp4, b_global_scale)``; the first two items
        are the outputs of ``ops.scaled_fp4_quant``.
    """
    weight_amax = b.abs().max().to(torch.float32)
    global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / weight_amax
    quantized, block_scales = ops.scaled_fp4_quant(b, global_scale)
    return quantized, block_scales, global_scale
def build_nvfp4_runner(cfg, a, b, dtype, device):
    """Build a zero-argument callable that runs one NVFP4 GEMM of ``a`` and ``b``.

    The weight ``b`` is always quantized once, up front. When
    ``cfg["no_a_quant"]`` is truthy the activation ``a`` is also quantized up
    front, so the returned callable performs only the GEMM; otherwise the
    callable quantizes ``a`` on every invocation before the GEMM.

    Returns:
        A callable returning the result of ``ops.cutlass_scaled_fp4_mm``.
    """
    b_fp4, scale_b_fp4, b_global_scale = _quant_weight_nvfp4(b, device)

    # Global scale for the activation.
    # NOTE: This is generally provided ahead-of-time by the model checkpoint.
    a_global_scale = (
        FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a.abs().max().to(torch.float32)
    )

    # Alpha for the GEMM operation: the inverse of both global scales.
    alpha = 1.0 / (a_global_scale * b_global_scale)

    def gemm(a_fp4, scale_a_fp4):
        return ops.cutlass_scaled_fp4_mm(
            a_fp4, b_fp4, scale_a_fp4, scale_b_fp4, alpha, dtype
        )

    if cfg["no_a_quant"]:
        # Activation quantized once here, outside the returned callable.
        fixed_a_fp4, fixed_scale_a = ops.scaled_fp4_quant(a, a_global_scale)

        def run_prequantized():
            return gemm(fixed_a_fp4, fixed_scale_a)

        return run_prequantized

    def run_with_quant():
        # Activation quantized on the fly, inside the returned callable.
        a_fp4, scale_a_fp4 = ops.scaled_fp4_quant(a, a_global_scale)
        return gemm(a_fp4, scale_a_fp4)

    return run_with_quant
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size"],
        x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384],
        x_log=False,
        line_arg="provider",
        line_vals=_enabled,
        line_names=_enabled,
        ylabel="TFLOP/s (larger is better)",
        plot_name="BF16 vs NVFP4 GEMMs",
        args={},
    )
)
def benchmark(batch_size, provider, N, K):
    """Time one (batch_size x K) by (N x K) GEMM for ``provider``.

    "torch-bf16" times ``torch.nn.functional.linear`` on bfloat16 inputs; any
    other provider is looked up in ``PROVIDER_CFGS`` and timed through the
    callable produced by ``build_nvfp4_runner``.

    Returns:
        Three TFLOP/s figures computed from the timings requested at the
        0.5, 0.8 and 0.2 quantiles, in that order.
    """
    M = batch_size
    device, dtype = "cuda", torch.bfloat16

    a = torch.randn((M, K), device=device, dtype=dtype)
    b = torch.randn((N, K), device=device, dtype=dtype)

    if provider == "torch-bf16":

        def timed():
            return torch.nn.functional.linear(a, b)

    else:
        run_quant = build_nvfp4_runner(PROVIDER_CFGS[provider], a, b, dtype, device)

        def timed():
            return run_quant()

    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
        timed, quantiles=[0.5, 0.2, 0.8]
    )

    def to_tflops(t_ms):
        # 2*M*N*K floating point operations per GEMM; t_ms is in milliseconds.
        return (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)

    # The slower timing maps to the lower throughput, hence max before min.
    return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms)
def prepare_shapes(args):
    """List the ``[K, N, model]`` entries to benchmark.

    For every combination of ``args.models`` and ``args.tp_sizes``, each
    ``(shape, tp_dim)`` entry of ``WEIGHT_SHAPES[model]`` is copied, the
    dimension at ``tp_dim`` is floor-divided by the tensor-parallel size, and
    the model name is appended.
    """
    shapes = []
    for model, tp_size in itertools.product(args.models, args.tp_sizes):
        # Work on a deep copy so the shared WEIGHT_SHAPES table is not mutated.
        for dims, split_axis in copy.deepcopy(WEIGHT_SHAPES[model]):
            dims[split_axis] = dims[split_axis] // tp_size
            shapes.append([*dims, model])
    return shapes
# Script entry point: parse the model / tensor-parallel selection, then run
# the benchmark once per resulting weight shape.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models",
        nargs="+",
        type=str,
        default=["meta-llama/Llama-3.1-8B-Instruct"],
        choices=list(WEIGHT_SHAPES),
    )
    parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1])
    args = parser.parse_args()

    for K, N, model in prepare_shapes(args):
        print(f"{model}, N={N} K={K}, BF16 vs NVFP4 GEMMs TFLOP/s:")
        run_options = dict(
            print_data=True,
            show_plots=True,
            save_path=f"bench_nvfp4_res_n{N}_k{K}",
            N=N,
            K=K,
        )
        benchmark.run(**run_options)

    print("Benchmark finished!")
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
from typing import Callable
import torch
from vllm import _custom_ops as ops
from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.triton_utils import triton
# TODO(luka): use standalone_compile utility
def with_dyn_arg(fn: Callable, arg_index: int, dim_index: int):
def inner(*args):
torch._dynamo.mark_dynamic(args[arg_index], dim_index)
return fn(*args)
return inner
# Raise the dynamo recompile limit; the compiled module below is built with
# dynamic=False, so it recompiles for different input shapes.
torch._dynamo.config.recompile_limit = 8888
# custom_ops=["none"] is passed through to the vLLM config that is active
# while the QuantFP8 module is constructed and compiled.
# NOTE(review): presumably this selects the native torch implementation
# instead of the custom CUDA op - confirm against CompilationConfig.
compilation_config = CompilationConfig(custom_ops=["none"])
with set_current_vllm_config(VllmConfig(compilation_config=compilation_config)):
# torch.compile'd per-token FP8 quantization module used as the "torch" provider.
torch_per_token_quant_fp8 = torch.compile(
QuantFP8(False, GroupShape.PER_TOKEN),
fullgraph=True,
dynamic=False, # recompile for different shapes
)
# First dim is explicitly dynamic to simulate vLLM usage
torch_per_token_quant_fp8 = with_dyn_arg(torch_per_token_quant_fp8, 0, 0)
def cuda_per_token_quant_fp8(
    input: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``input`` with ``ops.scaled_fp8_quant`` and return its result.

    Thin wrapper so the custom op has the same call shape as the
    torch-compiled implementation it is benchmarked against.
    """
    quantized_and_scale = ops.scaled_fp8_quant(input)
    return quantized_and_scale
def calculate_diff(batch_size: int, seq_len: int):
    """Check that the torch-compiled and CUDA quantizers agree on random input.

    A float16 tensor with ``batch_size * seq_len`` rows and 4096 columns is
    quantized by both implementations; the quantized outputs (compared in
    float32) and the scales must each be close within rtol=1e-3, atol=1e-5.
    The verdict is printed, not returned.
    """
    num_rows = batch_size * seq_len
    x = torch.rand((num_rows, 4096), dtype=torch.float16, device=torch.device("cuda"))

    torch_out, torch_scale = torch_per_token_quant_fp8(x)
    cuda_out, cuda_scale = cuda_per_token_quant_fp8(x)

    tolerances = dict(rtol=1e-3, atol=1e-5)
    outputs_match = torch.allclose(
        cuda_out.to(torch.float32), torch_out.to(torch.float32), **tolerances
    )
    # Scales are only compared once the quantized outputs already agree.
    if outputs_match and torch.allclose(cuda_scale, torch_scale, **tolerances):
        print("✅ All implementations match")
        return
    print("❌ Implementations differ")
# Benchmark sweep: every (batch_size, seq_len) pair from the two ranges below.
batch_size_range = [1, 16, 32, 64, 128]
seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]
configs = list(itertools.product(batch_size_range, seq_len_range))
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["batch_size", "seq_len"],
        x_vals=configs,
        line_arg="provider",
        line_vals=["torch", "cuda"],
        line_names=["Torch", "CUDA"],
        styles=[("blue", "-"), ("green", "-")],
        ylabel="us",
        plot_name="per-token-dynamic-quant-fp8-performance",
        args={},
    )
)
def benchmark_quantization(batch_size, seq_len, provider):
    """Time per-token FP8 quantization of a random float16 tensor.

    The input has ``batch_size * seq_len`` rows and 4096 columns. ``provider``
    selects the torch-compiled ("torch") or custom-op ("cuda") implementation;
    each timed call quantizes a fresh clone of the input.

    Returns:
        The timings requested at the 0.5, 0.8 and 0.2 quantiles, in that
        order, each multiplied by 1000 (the plot's y-axis is labelled "us").
    """
    num_rows = batch_size * seq_len
    x = torch.randn(num_rows, 4096, device=torch.device("cuda"), dtype=torch.float16)

    if provider == "torch":

        def fn():
            return torch_per_token_quant_fp8(x.clone())

    elif provider == "cuda":

        def fn():
            return cuda_per_token_quant_fp8(x.clone())

    ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
        fn, quantiles=[0.5, 0.2, 0.8]
    )
    scaled = [1000 * t for t in (ms, max_ms, min_ms)]
    return scaled[0], scaled[1], scaled[2]
# Script entry point: run one agreement check between the two
# implementations, then the full timing sweep with results printed.
if __name__ == "__main__":
calculate_diff(batch_size=4, seq_len=4096)
benchmark_quantization.run(print_data=True)
......@@ -7,19 +7,19 @@ import time
from contextlib import nullcontext
from datetime import datetime
from itertools import product
from typing import Any, TypedDict
from typing import Any, TypedDict, Optional
import ray
import torch
from ray.experimental.tqdm_ray import tqdm
from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.platforms import current_platform
from vllm.transformers_utils.config import get_config
from vllm.triton_utils import triton
from vllm.utils import FlexibleArgumentParser
FP8_DTYPE = current_platform.fp8_dtype()
# Removed the module-level current_platform import; it is imported locally where needed instead
# FP8_DTYPE = current_platform.fp8_dtype()
class BenchmarkConfig(TypedDict):
......@@ -47,8 +47,11 @@ def benchmark_config(
use_deep_gemm: bool = False,
nn_moe: Optional[bool] = False
) -> float:
from vllm.platforms import current_platform
device = torch.cuda.current_device()
init_dtype = torch.float16 if use_fp8_w8a8 else dtype
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device)
if use_int8_w8a16:
if not nn_moe:
w1 = torch.randint(
......@@ -60,6 +63,7 @@ def benchmark_config(
hidden_size,
),
dtype=torch.int8,
device=device,
)
w2 = torch.randint(
-127,
......@@ -70,6 +74,7 @@ def benchmark_config(
shard_intermediate_size // 2,
),
dtype=torch.int8,
device=device,
)
else:
w1 = torch.randint(
......@@ -81,6 +86,7 @@ def benchmark_config(
shard_intermediate_size,
),
dtype=torch.int8,
device=device,
)
w2 = torch.randint(
-127,
......@@ -91,23 +97,24 @@ def benchmark_config(
hidden_size,
),
dtype=torch.int8,
device=device,
)
else:
if not nn_moe:
w1 = torch.randn(
num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype
num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype, device=device
)
w2 = torch.randn(
num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype
num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype, device=device
)
else:
w1 = torch.randn(
num_experts, hidden_size, shard_intermediate_size, dtype=init_dtype
num_experts, hidden_size, shard_intermediate_size, dtype=init_dtype, device=device
)
w2 = torch.randn(
num_experts, shard_intermediate_size // 2, hidden_size, dtype=init_dtype
num_experts, shard_intermediate_size // 2, hidden_size, dtype=init_dtype, device=device
)
gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32)
gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32, device=device)
w1_scale = None
w2_scale = None
......@@ -115,9 +122,12 @@ def benchmark_config(
a2_scale = None
if use_int8_w8a16:
w1_scale = torch.randn(
(num_experts, 2 * shard_intermediate_size), dtype=torch.float32
(num_experts, 2 * shard_intermediate_size), dtype=torch.float32, device=device
)
w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32, device=device)
if use_deep_gemm:
# we use the default block shape for deepgemm
block_quant_shape = [128, 128]
if use_fp8_w8a8:
if block_quant_shape:
block_n, block_k = block_quant_shape[0], block_quant_shape[1]
......@@ -130,24 +140,26 @@ def benchmark_config(
k_tiles_w1 = (K + block_k - 1) // block_k
k_tiles_w2 = (N + block_k - 1) // block_k
w1_scale = (
torch.rand((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32)
torch.rand((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32, device=device)
* factor_for_scale
)
w2_scale = (
torch.rand((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32)
torch.rand((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32, device=device)
* factor_for_scale
)
else:
w1_scale = torch.randn(num_experts, dtype=torch.float32)
w2_scale = torch.randn(num_experts, dtype=torch.float32)
w1_scale = torch.randn(num_experts, dtype=torch.float32, device=device)
w2_scale = torch.randn(num_experts, dtype=torch.float32, device=device)
a1_scale = torch.randn(1, dtype=torch.float32)
a2_scale = torch.randn(1, dtype=torch.float32)
a1_scale = torch.randn(1, dtype=torch.float32, device=device)
a2_scale = torch.randn(1, dtype=torch.float32, device=device)
# Get FP8_DTYPE
FP8_DTYPE = current_platform.fp8_dtype()
w1 = w1.to(FP8_DTYPE)
w2 = w2.to(FP8_DTYPE)
input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32)
input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32, device=device)
def prepare(i: int):
input_gating.copy_(gating_output[i])
......@@ -266,6 +278,9 @@ def get_rocm_tuning_space(use_fp16, nn_moe: Optional[bool] = False):
def get_configs_compute_bound(use_fp16, block_quant_shape, nn_moe: Optional[bool] = False) -> list[dict[str, int]]:
configs: list[BenchmarkConfig] = []
# Import current_platform locally
from vllm.platforms import current_platform
if current_platform.is_rocm():
param_ranges = get_rocm_tuning_space(use_fp16, nn_moe)
......@@ -426,12 +441,18 @@ def merge_unique_dicts(list1, list2):
@ray.remote(num_gpus=1)
class BenchmarkWorker:
def __init__(self, seed: int, device_id: int) -> None:
torch.set_default_device("cuda:"+ str(device_id))
from vllm.platforms import current_platform
import os
if current_platform.is_rocm():
# In ROCm environment with Ray, let Ray handle device assignment
# Don't manually set default device as it may conflict with Ray's device mapping
pass
else:
torch.set_default_device("cuda:"+ str(device_id))
current_platform.seed_everything(seed)
self.seed = seed
# Get the device ID to allocate tensors and kernels
# on the respective GPU. This is required for Ray to work
# correctly with multi-GPU tuning on the ROCm platform.
# Store the logical device ID for Ray
self.device_id = device_id
def benchmark(
......@@ -448,7 +469,13 @@ class BenchmarkWorker:
use_deep_gemm: bool = False,
nn_moe: Optional[bool] = False,
) -> tuple[dict[str, int], float]:
# Import current_platform locally
from vllm.platforms import current_platform
current_platform.seed_everything(self.seed)
from vllm.model_executor.layers.fused_moe.fused_moe import (
get_config_dtype_str, get_moe_configs, get_default_config
)
dtype_str = get_config_dtype_str(
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
)
......@@ -502,6 +529,9 @@ class BenchmarkWorker:
use_deep_gemm: bool,
nn_moe: Optional[bool] = False,
) -> dict[str, int]:
from vllm.platforms import current_platform
import os
best_config = None
best_time = float("inf")
if current_platform.is_rocm():
......@@ -515,10 +545,16 @@ class BenchmarkWorker:
topk,
)
# In ROCm environments with Ray, device context is already handled by Ray
# Using torch.cuda.device() may cause device ordinal conflicts
need_device_guard = False
if current_platform.is_rocm():
visible_device = os.environ.get("ROCR_VISIBLE_DEVICES", None)
if visible_device != f"{self.device_id}":
# For ROCm with Ray, skip additional device context management
need_device_guard = False
else:
# For other platforms, use device guard if needed
visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
if visible_devices is not None and len(visible_devices.split(',')) > 1:
need_device_guard = True
with torch.cuda.device(self.device_id) if need_device_guard else nullcontext():
......@@ -587,6 +623,10 @@ def save_configs(
block_quant_shape: list[int],
use_nn_moe: Optional[bool] = False,
) -> None:
from vllm.model_executor.layers.fused_moe.fused_moe import (
get_config_dtype_str, get_config_file_name
)
dtype_str = get_config_dtype_str(
dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
)
......@@ -611,6 +651,13 @@ def get_weight_block_size_safety(config, default_value=None):
def main(args: argparse.Namespace):
import os
import logging
from vllm.platforms import current_platform
logger = logging.getLogger(__name__)
print(args)
tp_size = args.tp_size
......@@ -628,7 +675,11 @@ def main(args: argparse.Namespace):
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] in ("DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM", "Glm4MoeForCausalLM"):
elif config.architectures[0] in (
"DeepseekV3ForCausalLM",
"DeepseekV2ForCausalLM",
"Glm4MoeForCausalLM",
):
E = config.n_routed_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
......@@ -638,6 +689,11 @@ def main(args: argparse.Namespace):
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // tp_size
elif config.architectures[0] in ("HunYuanMoEV1ForCausalLM"):
E = config.num_experts
topk = config.moe_topk[0]
intermediate_size = config.moe_intermediate_size[0]
shard_intermediate_size = 2 * intermediate_size // args.tp_size
else:
# Support for llama4
config = config.get_text_config()
......
......@@ -33,15 +33,13 @@ def check_correctness(num_tokens, num_experts=256, block_size=256, topk=8):
sorted_ids_triton = torch.empty(
(max_num_tokens_padded,), dtype=torch.int32, device="cuda"
)
sorted_ids_triton.fill_(topk_ids.numel()) # fill with sentinel value
expert_ids_triton = torch.zeros(
expert_ids_triton = torch.empty(
(max_num_tokens_padded // block_size,), dtype=torch.int32, device="cuda"
)
num_tokens_post_pad_triton = torch.empty((1,), dtype=torch.int32, device="cuda")
sorted_ids_vllm = torch.empty_like(sorted_ids_triton)
sorted_ids_vllm.fill_(topk_ids.numel())
expert_ids_vllm = torch.zeros_like(expert_ids_triton)
expert_ids_vllm = torch.empty_like(expert_ids_triton)
num_tokens_post_pad_vllm = torch.empty_like(num_tokens_post_pad_triton)
# 2. run implementations
......@@ -102,7 +100,6 @@ def benchmark(num_tokens, num_experts, topk, provider):
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device="cuda")
sorted_ids.fill_(topk_ids.numel())
max_num_m_blocks = max_num_tokens_padded // block_size
expert_ids = torch.empty((max_num_m_blocks,), dtype=torch.int32, device="cuda")
num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device="cuda")
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import csv
import os
import random
from datetime import datetime
import flashinfer
import torch
# Size of one torch.float value in bytes (its bit width divided by 8).
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
# KV Cache Layout for TRT-LLM
# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim)
def to_float8(x, dtype=torch.float8_e4m3fn):
finfo = torch.finfo(dtype)
min_val, max_val = x.aminmax()
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
scale = finfo.max / amax * 0.1
x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
return x_scl_sat.to(dtype), scale.float().reciprocal()
@torch.no_grad()
def benchmark_decode(
    num_seqs,
    max_seq_len,
    page_size=16,
    dtype=torch.bfloat16,
    kv_layout="HND",
    num_kv_heads=8,
    kv_cache_dtype="auto",
    head_dim=128,
    warmup=10,
    trials=20,
):
    """Time the TRT-LLM decode kernel against the FlashInfer paged decode wrapper.

    Random queries, a random paged KV cache and random block tables are
    created on CUDA. Both kernels are timed with CUDA events on the same
    inputs, one result line is printed, and the measurements are returned.

    Args:
        num_seqs: Number of sequences (one decode token each).
        max_seq_len: Upper bound for the randomly drawn KV length per sequence.
        page_size: Number of tokens per KV-cache page.
        dtype: Data type of the query and of the unquantized KV cache.
        kv_layout: Layout string handed to the FlashInfer wrapper.
        num_kv_heads: Number of KV heads; query heads are 8x this value.
        kv_cache_dtype: Any value starting with "fp8" converts the cache via
            ``to_float8``; other values leave it in ``dtype``.
        head_dim: Size of each attention head.
        warmup: Untimed calls made before measuring each kernel.
        trials: Timed calls per kernel.

    Returns:
        A dict with the mean/std (ms) of both kernels, the relative speedup,
        and the configuration used. ``speedup_percent`` is
        ``(baseline_mean - trt_mean) / baseline_mean``, i.e. a fraction that
        is not multiplied by 100.

    NOTE: KV lengths come from the ``random`` module, which this function
    does not seed; only the torch RNG is seeded.
    """
    torch.set_default_device("cuda")
    device = "cuda"
    torch.manual_seed(0)

    # Currently only HEAD_GRP_SIZE == 8 is supported
    HEAD_GRP_SIZE = 8
    MAX_SEQ_LEN = max_seq_len

    # large number to reduce kv_cache reuse
    NUM_BLOCKS = int(256000 / page_size)

    workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.int8, device=device)

    # For decode, batch_size is num_decode_token
    num_qo_heads = num_kv_heads * HEAD_GRP_SIZE
    sm_scale = float(1.0 / (head_dim**0.5))
    q = torch.randn(num_seqs, num_qo_heads, head_dim, device=device, dtype=dtype)
    kv_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]

    max_kv_len = max(kv_lens)
    kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int, device=device)
    max_num_blocks_per_seq = (max_kv_len + page_size - 1) // page_size

    # No explicit device: the default device was set to CUDA above.
    block_tables = torch.randint(
        0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
    )

    kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, page_size, head_dim)
    kv_cache = torch.randn(size=kv_cache_shape, device=device, dtype=dtype)
    k_scale = v_scale = 1.0

    if kv_cache_dtype.startswith("fp8"):
        kv_cache, _ = to_float8(kv_cache)

    # Benchmark TRT decode
    def trt_decode():
        return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
            q,
            kv_cache,
            workspace_buffer,
            num_qo_heads,
            num_kv_heads,
            sm_scale,
            block_tables,
            kv_lens_tensor,
            page_size,
            max_kv_len,
            kv_cache_dtype,
            k_scale,
            v_scale,
        )

    def time_fn(fn, warmup=10, trials=20):
        """Return (mean, std) of ``fn``'s runtime in ms over ``trials`` calls."""
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        times = []
        for _ in range(warmup):
            fn()
        for _ in range(trials):
            start.record()
            fn()
            end.record()
            torch.cuda.synchronize()
            times.append(start.elapsed_time(end))  # ms
        return sum(times) / len(times), torch.std(torch.tensor(times))

    # TRT Decode; forward the caller's warmup/trials instead of silently
    # falling back to the timing helper's own defaults.
    trt_mean, trt_std = time_fn(trt_decode, warmup=warmup, trials=trials)

    # Flatten the block tables into the indptr / indices / last-page-length
    # arrays that the FlashInfer wrapper's plan() call takes.
    kv_indptr = [0]
    kv_indices = []
    kv_last_page_lens = []
    for i in range(num_seqs):
        seq_len = kv_lens[i]
        assert seq_len > 0
        num_blocks = (seq_len + page_size - 1) // page_size
        # tolist() yields plain ints in one transfer instead of a list of
        # 0-dim tensors that torch.tensor() would convert one by one.
        kv_indices.extend(block_tables[i, :num_blocks].tolist())
        kv_indptr.append(kv_indptr[-1] + num_blocks)
        kv_last_page_len = seq_len % page_size
        if kv_last_page_len == 0:
            # A sequence that exactly fills its last page uses the whole page.
            kv_last_page_len = page_size
        kv_last_page_lens.append(kv_last_page_len)

    kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
    kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
    kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)

    wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
        workspace_buffer,
        kv_layout,
        use_tensor_cores=((num_qo_heads // num_kv_heads) > 4),
    )

    wrapper.plan(
        kv_indptr,
        kv_indices,
        kv_last_page_lens,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        page_size,
        "NONE",
        q_data_type=dtype,
        kv_data_type=torch.float8_e4m3fn if kv_cache_dtype.startswith("fp8") else dtype,
    )

    def baseline_decode():
        return wrapper.run(q, kv_cache, sm_scale, k_scale, v_scale)

    baseline_mean, baseline_std = time_fn(
        baseline_decode, warmup=warmup, trials=trials
    )

    # Calculate percentage speedup (positive means TRT is faster)
    speedup_percent = (baseline_mean - trt_mean) / baseline_mean

    print(
        f"\t{num_seqs}\t{max_seq_len}\t{trt_mean:.3f}\t{trt_std.item():.3f}"
        f"\t{baseline_mean:.3f}\t{baseline_std.item():.3f}\t{speedup_percent:.3f}"
    )

    # Return results for CSV writing
    return {
        "num_seqs": num_seqs,
        "trt_mean": trt_mean,
        "trt_std": trt_std.item(),
        "baseline_mean": baseline_mean,
        "baseline_std": baseline_std.item(),
        "speedup_percent": speedup_percent,
        "q_dtype": str(dtype),
        "kv_cache_dtype": kv_cache_dtype,
        "page_size": page_size,
        "num_kv_heads": num_kv_heads,
        "head_dim": head_dim,
        "max_seq_len": max_seq_len,
    }
def write_results_to_csv(results, filename=None):
    """Append benchmark result rows to a CSV file.

    Args:
        results: Iterable of dicts keyed by the column names listed below.
        filename: Destination path. When None, a timestamped name of the form
            ``flashinfer_trtllm_benchmark_<YYYYmmdd_HHMMSS>.csv`` is used.

    The header row is written only when the file did not exist before this
    call; rows are always appended.
    """
    if filename is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv"

    columns = (
        "num_seqs",
        "trt_mean",
        "trt_std",
        "baseline_mean",
        "baseline_std",
        "speedup_percent",
        "q_dtype",
        "kv_cache_dtype",
        "page_size",
        "num_kv_heads",
        "head_dim",
        "max_seq_len",
    )

    # Decide about the header before open() creates the file.
    needs_header = not os.path.exists(filename)

    with open(filename, "a", newline="") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=columns)
        if needs_header:
            writer.writeheader()
        writer.writerows(results)

    print(f"Results written to {filename}")
# Script entry point: sweep batch size and maximum sequence length for each
# KV-cache dtype, then write every measurement to one CSV file.
if __name__ == "__main__":
    num_seqs = [1, 4, 8, 16, 32, 64, 128, 256]
    max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
    all_results = []

    column_banner = (
        "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\tbaseline_std\tspeedup_percent"
    )
    # (banner printed before the sweep, kv_cache_dtype handed to the benchmark)
    sweeps = [
        ("Running benchmark for kv_cache_dtype: bfloat16", "auto"),
        ("Running benchmark for q_dtype = bfloat16, kv_cache_dtype: fp8", "fp8"),
    ]

    for banner, cache_dtype in sweeps:
        print(banner)
        print(column_banner)
        # Sequence length is the outer loop, batch size the inner one.
        for max_seq_len in max_seq_lens:
            for bs in num_seqs:
                result = benchmark_decode(
                    bs, max_seq_len, dtype=torch.bfloat16, kv_cache_dtype=cache_dtype
                )
                all_results.append(result)

    # Write all results to CSV
    write_results_to_csv(all_results)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc
import time
from typing import Optional
from tabulate import tabulate
from vllm.utils import FlexibleArgumentParser
from vllm.v1.core.block_pool import BlockPool
class Metric:
    """Running count, sum and maximum of integer samples."""

    def __init__(self) -> None:
        # Number of samples recorded so far.
        self.cnt: int = 0
        # Sum of all recorded samples.
        self.sum_v: int = 0
        # Largest sample seen; None until the first update.
        self.max_v: Optional[int] = None

    def update(self, v: int) -> None:
        """Record one sample."""
        self.cnt += 1
        self.sum_v += v
        self.max_v = v if self.max_v is None else max(self.max_v, v)

    def avg_v(self) -> float:
        """Return the mean of the recorded samples as a float.

        Raises ZeroDivisionError when no sample has been recorded.
        """
        return float(self.sum_v) / self.cnt
def main(args):
    """Benchmark BlockPool allocation and release, then print a results table.

    For every entry of ``args.allocate_blocks`` a fresh ``BlockPool`` with
    ``args.num_gpu_blocks`` blocks is created, and ``get_new_blocks`` followed
    by ``free_blocks`` is timed ``args.num_iteration`` times. Average and
    maximum latencies are reported in milliseconds.
    """
    ns_per_ms = 1000000.0
    table_headers = [
        "Iterations",
        "Total\nBlocks",
        "Allocated\nBlocks",
        "Get Blocks\nAvg (ms)",
        "Get Blocks\nMax (ms)",
        "Free Blocks\nAvg (ms)",
        "Free Blocks\nMax (ms)",
    ]

    rows = []
    for allocate_block in args.allocate_blocks:
        # Enforce a GC collect ahead to minimize the impact among runs
        gc.collect()
        block_pool = BlockPool(num_gpu_blocks=args.num_gpu_blocks, enable_caching=True)

        get_blocks_metric: Metric = Metric()
        free_blocks_metric: Metric = Metric()
        for _ in range(args.num_iteration):
            started = time.monotonic_ns()
            allocated = block_pool.get_new_blocks(allocate_block)
            after_get = time.monotonic_ns()
            block_pool.free_blocks(allocated)
            after_free = time.monotonic_ns()
            get_blocks_metric.update(after_get - started)
            free_blocks_metric.update(after_free - after_get)

        # max_v stays None when no iteration ran, so there is nothing to report.
        if get_blocks_metric.max_v is None or free_blocks_metric.max_v is None:
            print(
                "No valid metrics found."
                f" {get_blocks_metric.max_v=} {free_blocks_metric.max_v=}"
            )
            continue

        rows.append(
            [
                get_blocks_metric.cnt,
                args.num_gpu_blocks,
                allocate_block,
                get_blocks_metric.avg_v() / ns_per_ms,
                get_blocks_metric.max_v / ns_per_ms,
                free_blocks_metric.avg_v() / ns_per_ms,
                free_blocks_metric.max_v / ns_per_ms,
            ]
        )

    report = tabulate(
        rows,
        headers=table_headers,
        tablefmt="grid",
        floatfmt=".6f",
    )
    print(report)
def invoke_main() -> None:
    """Parse the command-line options and run the BlockPool benchmark."""
    cli = FlexibleArgumentParser(
        description="Benchmark the performance of BlockPool for KV Cache."
    )
    # Size of the pool created for each benchmarked allocation size.
    cli.add_argument("--num-gpu-blocks", type=int, default=100000)
    cli.add_argument(
        "--num-iteration",
        type=int,
        default=1000,
        help="Number of iterations to run to stablize final data readings",
    )
    # Zero or more allocation sizes; each one is benchmarked separately.
    cli.add_argument(
        "--allocate-blocks",
        type=int,
        nargs="*",
        default=[10, 50, 100, 500, 1000],
        help="Number of blocks to allocate",
    )
    main(cli.parse_args())
# Run the benchmark CLI only when this file is executed as a script.
if __name__ == "__main__":
invoke_main() # pragma: no cover
......@@ -165,17 +165,32 @@ else()
endif()
#
# Build oneDNN for W8A8 GEMM kernels (only for x86-AVX512 platforms)
#
if (AVX512_FOUND AND NOT AVX512_DISABLED)
# Build oneDNN for W8A8 GEMM kernels (only for x86-AVX512 /ARM platforms)
# Flag to enable ACL kernels for AARCH64 platforms
if ( VLLM_BUILD_ACL STREQUAL "ON")
set(USE_ACL ON)
else()
set(USE_ACL OFF)
endif()
if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND)
FetchContent_Declare(
oneDNN
GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git
GIT_TAG v3.7.1
GIT_TAG v3.8.1
GIT_PROGRESS TRUE
GIT_SHALLOW TRUE
)
if(USE_ACL)
find_library(ARM_COMPUTE_LIBRARY NAMES arm_compute PATHS $ENV{ACL_ROOT_DIR}/build/)
if(NOT ARM_COMPUTE_LIBRARY)
message(FATAL_ERROR "Could not find ARM Compute Library: please set ACL_ROOT_DIR")
endif()
set(ONEDNN_AARCH64_USE_ACL "ON")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ENV{ACL_ROOT_DIR}/build/")
endif()
set(ONEDNN_LIBRARY_TYPE "STATIC")
set(ONEDNN_BUILD_DOC "OFF")
set(ONEDNN_BUILD_EXAMPLES "OFF")
......@@ -264,6 +279,11 @@ elseif(POWER10_FOUND)
"csrc/cpu/quant.cpp"
${VLLM_EXT_SRC})
endif()
if (ASIMD_FOUND)
set(VLLM_EXT_SRC
"csrc/cpu/quant.cpp"
${VLLM_EXT_SRC})
endif()
message(STATUS "CPU extension source files: ${VLLM_EXT_SRC}")
......
......@@ -24,6 +24,7 @@
#include "attention_dtypes.h"
#include "attention_utils.cuh"
#include "cuda_compat.h"
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
......@@ -35,12 +36,6 @@ typedef __hip_bfloat16 __nv_bfloat16;
#include "../quantization/int8_kvcache/quant_utils.cuh"
#ifndef USE_ROCM
#define WARP_SIZE 32
#else
#define WARP_SIZE warpSize
#endif
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
......@@ -684,7 +679,6 @@ __global__ void paged_attention_v2_reduce_kernel(
} // namespace vllm
#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP
\ No newline at end of file
/***************************************************************************************************
* Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*
* Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929
* by Alcanderian JieXin Liang
*/
/*!
\file
\brief A universal device layer for cutlass 3.x-style kernels.
*/
// clang-format off
#pragma once
// common
#include "cutlass/cutlass.h"
#include "cutlass/device_kernel.h"
#if !defined(__CUDACC_RTC__)
#include "cutlass/cluster_launch.hpp"
#include "cutlass/trace.h"
#endif // !defined(__CUDACC_RTC__)
#include "../kernel/sm100_fmha_mla_tma_warpspecialized.hpp"
#include "../kernel/sm100_fmha_mla_reduction.hpp"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass::fmha::device {
using namespace cute;
using namespace cutlass::fmha::kernel;
////////////////////////////////////////////////////////////////////////////////
////////////////////////////// CUTLASS 3.x API /////////////////////////////////
////////////////////////////////////////////////////////////////////////////////
template<
class Kernel_
>
class MLA {
public:
using Kernel = Kernel_;
using ReductionKernel = cutlass::fmha::kernel::Sm100FmhaMlaReductionKernel<
typename Kernel::ElementOut,
typename Kernel::ElementAcc,
typename Kernel::ElementAcc,
Kernel::TileShapeH::value,
Kernel::TileShapeL::value,
256 /*Max split*/
>;
/// Argument structure: User API
using KernelArguments = typename Kernel::Arguments;
using ReductionArguments = typename ReductionKernel::Arguments;
using Arguments = KernelArguments;
/// Argument structure: Kernel API
using KernelParams = typename Kernel::Params;
using ReductionParams = typename ReductionKernel::Params;
/// Launch parameters for one MLA invocation: the main FMHA kernel's params
/// plus the params of the split-KV reduction kernel that may follow it.
struct Params {
// Passed to device_kernel<Kernel> by run().
KernelParams fmha_params;
// Passed to device_kernel<ReductionKernel> by run() when split_kv > 1.
ReductionParams reduction_params;
};
private:
/// Kernel API parameters object
Params params_;
/// Returns the sticky "one-time setup done" flag, optionally setting it first.
/// The flag is a function-local static, so it is shared by every object that
/// calls this function (per template instantiation), and it is never cleared.
bool is_initialized(bool set = false) {
  static bool initialized = false;
  initialized = initialized || set;
  return initialized;
}
/// Derives the reduction kernel's arguments from the user-facing arguments.
/// Both accumulator pointers are left null here; initialize() and update()
/// patch them from the main kernel's params when split_kv > 1.
static ReductionArguments to_reduction_args(Arguments const& args) {
  // problem_shape unpacks as (H, K, D, B); only K and B are forwarded.
  auto [num_heads, seqlen_k, head_dims, batch_count] = args.problem_shape;

  ReductionArguments reduction{};
  reduction.ptr_oaccum = nullptr;
  reduction.ptr_o = args.epilogue.ptr_o;
  reduction.ptr_lseaccum = nullptr;
  reduction.ptr_lse = args.epilogue.ptr_lse;
  reduction.scale = args.mainloop.softmax_scale;
  reduction.num_batches = batch_count;
  reduction.split_kv = args.split_kv;
  reduction.dim_k = seqlen_k;
  reduction.ptr_seq = args.mainloop.ptr_seq;
  reduction.ptr_split_kv = args.ptr_split_kv;
  reduction.tile_shape_s = Kernel::TileShapeS::value;
  return reduction;
}
public:
/// Read-only access to the Params most recently stored by initialize() or
/// update().
Params const& params() const {
return params_;
}
/// Picks args.split_kv heuristically when the caller has not chosen one.
///
/// A caller-supplied value (split_kv >= 1) is left untouched. Otherwise:
///   1. the split count is capped at min(16, ceil(K / 128));
///   2. it is further limited to the SMs available per batch entry
///      (sm_count / B, at least 1);
///   3. the cap from step 1 is redistributed evenly over the number of
///      rounds that limit implies (ceil(max_splits / ceil(max_splits / heur))).
///
/// Degenerate shapes (B <= 0 or K <= 0) get split_kv = 1; previously they
/// reached an integer division by zero.
static void set_split_kv (KernelArguments& args) {
  if (args.split_kv >= 1) return;

  // H and D are not needed by the heuristic, but the binding must name them.
  auto [H, K, D, B] = args.problem_shape;

  // With B == 0, `sm_count / B` divides by zero; with K == 0 both max_splits
  // and split_heur become 0 and the ceil_div below divides by zero.
  if (B <= 0 || K <= 0) {
    args.split_kv = 1;
    return;
  }

  int sm_count = args.hw_info.sm_count;
  int max_splits = ceil_div(K, 128);
  max_splits = min(16, max_splits);
  // max(1, ...) also covers an unpopulated hw_info (sm_count == 0).
  int sms_per_batch = max(1, sm_count / B);
  int split_heur = min(max_splits, sms_per_batch);
  int k_waves = ceil_div(max_splits, split_heur);
  int split_wave_aware = ceil_div(max_splits, k_waves);
  args.split_kv = split_wave_aware;
}
/// Reports whether both the main kernel and the reduction kernel accept `args`.
/// Returns Status::kSuccess when both do, Status::kInvalid otherwise.
static Status
can_implement(Arguments const& args) {
  // Short-circuit evaluation keeps the original order: the reduction check
  // only runs once the main kernel has accepted the arguments.
  bool const supported = Kernel::can_implement(args)
      && ReductionKernel::can_implement(to_reduction_args(args));
  return supported ? Status::kSuccess : Status::kInvalid;
}
/// Returns the workspace size in bytes: the main kernel's requirement plus
/// the reduction kernel's requirement.
static size_t
get_workspace_size(Arguments const& args) {
  size_t const kernel_bytes = Kernel::get_workspace_size(args);
  size_t const reduction_bytes = ReductionKernel::get_workspace_size(to_reduction_args(args));
  return kernel_bytes + reduction_bytes;
}
/// Computes the maximum number of active blocks per multiprocessor for the
/// main kernel (the reduction kernel is not considered here).
/// The unnamed int parameter is accepted but ignored.
/// Returns the occupancy reported by the CUDA runtime, or -1 if either CUDA
/// call fails.
static int maximum_active_blocks(int /* smem_capacity */ = -1) {
CUTLASS_TRACE_HOST("MLA::maximum_active_blocks()");
int max_active_blocks = -1;
int smem_size = Kernel::SharedStorageSize;
// first, account for dynamic smem capacity if needed
// (only attempted when the kernel needs 48 KiB or more of shared memory)
cudaError_t result;
if (smem_size >= (48 << 10)) {
CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
result = cudaFuncSetAttribute(
device_kernel<Kernel>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size);
if (cudaSuccess != result) {
result = cudaGetLastError(); // to clear the error bit
CUTLASS_TRACE_HOST(
" cudaFuncSetAttribute() returned error: "
<< cudaGetErrorString(result));
return -1;
}
}
// query occupancy after setting smem size
result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks,
device_kernel<Kernel>,
Kernel::MaxThreadsPerBlock,
smem_size);
if (cudaSuccess != result) {
result = cudaGetLastError(); // to clear the error bit
CUTLASS_TRACE_HOST(
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: "
<< cudaGetErrorString(result));
return -1;
}
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
return max_active_blocks;
}
/// Initializes the kernel state (params_) from arguments.
///
/// Steps: initialize both kernels' workspaces, convert `args` into the main
/// kernel's params and the reduction kernel's params, store them in params_,
/// and - only the first time, as tracked by is_initialized() - raise the main
/// kernel's dynamic shared-memory limit when it needs 48 KiB or more.
///
/// Returns the first non-success Status from workspace initialization,
/// Status::kErrorInternal if cudaFuncSetAttribute fails, else kSuccess.
Status
initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
CUTLASS_TRACE_HOST("MLA::initialize() - workspace "
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
// Initialize the workspace
// NOTE(review): unlike update(), no null check is made here for a required
// workspace; presumably Kernel::initialize_workspace handles that - confirm.
Status status = Kernel::initialize_workspace(args, workspace, stream);
if (status != Status::kSuccess) {
return status;
}
// The same workspace pointer is handed to the reduction kernel.
status = ReductionKernel::initialize_workspace(to_reduction_args(args), workspace, stream);
if (status != Status::kSuccess) {
return status;
}
KernelParams kernel_params = Kernel::to_underlying_arguments(args, workspace);
ReductionArguments reduction_args = to_reduction_args(args);
// With more than one split, the reduction reads the partial results from
// the accumulator buffers the main kernel's epilogue params point at.
if (reduction_args.split_kv > 1) {
reduction_args.ptr_oaccum = kernel_params.epilogue.ptr_o_acc;
reduction_args.ptr_lseaccum = kernel_params.epilogue.ptr_lse_acc;
}
ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace);
// Initialize the Params structure
params_ = Params {kernel_params, reduction_params};
// The shared-memory attribute below is set once; later calls stop here.
if (is_initialized()) return Status::kSuccess;
// account for dynamic smem capacity if needed
// no dynamic smem is needed for reduction kernel
int smem_size = Kernel::SharedStorageSize;
if (smem_size >= (48 << 10)) {
CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
cudaError_t result = cudaFuncSetAttribute(
device_kernel<Kernel>,
cudaFuncAttributeMaxDynamicSharedMemorySize,
smem_size);
if (cudaSuccess != result) {
result = cudaGetLastError(); // to clear the error bit
CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result));
return Status::kErrorInternal;
}
}
// Reached only on success, so a failed attribute call is retried next time.
is_initialized(true);
return Status::kSuccess;
}
/// Update API is preserved in 3.0, but does not guarantee a lightweight update of params.
/// Rebuilds params_ from `args` without touching the workspace contents or
/// the kernel's shared-memory attribute.
/// Returns Status::kErrorWorkspaceNull when a workspace is required but
/// `workspace` is null, else Status::kSuccess.
Status
update(Arguments const& args, void* workspace = nullptr) {
CUTLASS_TRACE_HOST("MLA()::update() - workspace: " << workspace);
size_t workspace_bytes = get_workspace_size(args);
if (workspace_bytes > 0 && nullptr == workspace) {
return Status::kErrorWorkspaceNull;
}
auto fmha_params = Kernel::to_underlying_arguments(args, workspace);
ReductionArguments reduction_args = to_reduction_args(args);
// With more than one split, point the reduction at the main kernel's
// accumulator buffers.
if (reduction_args.split_kv > 1) {
reduction_args.ptr_oaccum = fmha_params.epilogue.ptr_o_acc;
reduction_args.ptr_lseaccum = fmha_params.epilogue.ptr_lse_acc;
}
ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace);
// Initialize the Params structure
params_ = Params {fmha_params, reduction_params};
return Status::kSuccess;
}
/// Primary run() entry point API; it is static so that users can create and
/// manage their own params.
/// params.fmha_params must have been constructed by
/// Kernel::to_underlying_arguments() and params.reduction_params by
/// ReductionKernel::to_underlying_arguments().
///
/// Launches the main kernel on `stream` and, when
/// params.reduction_params.split_kv > 1, the reduction kernel after it.
/// Returns Status::kErrorInternal if either launch reports an error,
/// Status::kSuccess otherwise.
static Status
run(Params& params, cudaStream_t stream = nullptr) {
  CUTLASS_TRACE_HOST("MLA::run()");
  dim3 const fmha_block = Kernel::get_block_shape();
  dim3 const fmha_grid = Kernel::get_grid_shape(params.fmha_params);

  // configure smem size and carveout
  int smem_size = Kernel::SharedStorageSize;

  // A plain <<<>>> launch has no Status of its own, so success is assumed
  // unless the cluster launcher reports otherwise.
  Status launch_result = Status::kSuccess;

  // Use extended launch API only for mainloops that use it
  if constexpr(Kernel::ArchTag::kMinComputeCapability >= 90) {
    using ClusterDims = typename Kernel::ClusterShape;
    dim3 cluster(cute::size<0>(ClusterDims{}),
                 cute::size<1>(ClusterDims{}),
                 cute::size<2>(ClusterDims{}));
    void const* kernel = (void const*) device_kernel<Kernel>;
    void* kernel_params[] = {&params.fmha_params};
    launch_result = ClusterLauncher::launch(fmha_grid, cluster, fmha_block, smem_size, stream, kernel, kernel_params);
  }
  else {
    device_kernel<Kernel><<<fmha_grid, fmha_block, smem_size, stream>>>(params.fmha_params);
  }

  cudaError_t fmha_error = cudaGetLastError();
  if (cudaSuccess != fmha_error or Status::kSuccess != launch_result) {
    CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << fmha_error);
    return Status::kErrorInternal;
  }

  // A single split leaves nothing to reduce.
  if (params.reduction_params.split_kv <= 1) {
    return Status::kSuccess;
  }

  // launch reduction kernel
  dim3 const reduction_block = ReductionKernel::get_block_shape();
  dim3 const reduction_grid = ReductionKernel::get_grid_shape(params.reduction_params);
  device_kernel<ReductionKernel><<<reduction_grid, reduction_block, 0, stream>>>(params.reduction_params);

  cudaError_t reduction_error = cudaGetLastError();
  if (cudaSuccess != reduction_error) {
    CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << reduction_error);
    return Status::kErrorInternal;
  }
  return Status::kSuccess;
}
//
// Non-static launch overloads that first create and set the internal params struct of this kernel handle.
//
/// Builds the internal Params from `args` via initialize() and, if that
/// succeeds, launches with them. Returns initialize()'s Status on failure,
/// otherwise the Status of the launch.
Status
run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
  Status const init_status = initialize(args, workspace, stream);
  if (Status::kSuccess != init_status) {
    return init_status;
  }
  return run(params_, stream);
}
/// Call-operator form of run(args, workspace, stream): constructs the internal
/// Params from the supplied arguments, then launches. Returns that call's Status.
Status
operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
return run(args, workspace, stream);
}
/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
/// Uses whatever params_ currently holds, so initialize() or update() should
/// have been called beforehand.
Status
run(cudaStream_t stream = nullptr) {
return run(params_, stream);
}
/// Call-operator form of run(stream): re-launches with the stored params_
/// without updating them.
Status
operator()(cudaStream_t stream = nullptr) {
return run(params_, stream);
}
};
////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::fmha::device
////////////////////////////////////////////////////////////////////////////////
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*
* Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929
* by Alcanderian JieXin Liang
*/
// clang-format off
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/arch/arch.h"
#include "cute/tensor.hpp"
namespace cutlass::fmha::kernel {
using namespace cute;
template<
class ElementOut,
class ElementAcc,
class ElementScale,
size_t kNumHeads,
size_t kHeadDimLatent,
int kMaxSplits
>
struct Sm100FmhaMlaReductionKernel {
static const int SharedStorageSize = 0;
static const int MaxThreadsPerBlock = 128;
static const int MinBlocksPerMultiprocessor = 1;
using ArchTag = cutlass::arch::Sm100;
static_assert(kHeadDimLatent % MaxThreadsPerBlock == 0);
/// Inputs of the split-KV reduction. The indexing notes below describe how
/// operator() in this struct addresses each buffer (head = blockIdx.x,
/// batch = blockIdx.z).
struct Arguments {
// Per-split partial outputs; read starting at
// kHeadDimLatent * split_kv * (head + kNumHeads * batch), one
// kHeadDimLatent-sized row per split.
ElementAcc* ptr_oaccum = nullptr;
// Final output; written at (head + batch * kNumHeads) * kHeadDimLatent.
ElementOut* ptr_o = nullptr;
// Per-split log-sum-exp values; read starting at
// head + kNumHeads * split_kv * batch with a stride of kNumHeads per split.
ElementAcc* ptr_lseaccum = nullptr;
// Optional combined log-sum-exp output at head + kNumHeads * batch;
// skipped when null.
ElementAcc* ptr_lse = nullptr;
// NOTE(review): not read by operator() in this file - confirm whether it
// is still needed.
ElementScale scale = 1.f;
// Grid z extent (see get_grid_shape); must be > 0 (see can_implement).
int num_batches = 0;
// Number of split slots per (head, batch) in the accumulator buffers;
// operator() returns immediately when this is <= 1.
int split_kv = -1;
// Sequence length used when ptr_seq is null.
int dim_k = -1;
// Optional per-batch sequence lengths, indexed by batch.
int* ptr_seq = nullptr;
// Optional per-batch split counts, indexed by batch; split_kv is used
// when null.
int* ptr_split_kv = nullptr;
// Tile size used to turn a sequence length into a number of K tiles.
int tile_shape_s = 128;
};
using Params = Arguments;
/// Converts user arguments into kernel params. Params is an alias of
/// Arguments, so this is a plain field-for-field copy; `workspace` is unused.
static Params to_underlying_arguments(Arguments const& args, void* workspace) {
  return args;
}
/// The reduction kernel needs no workspace of its own: always returns 0.
static size_t get_workspace_size(Arguments const& /*args*/) {
return 0;
}
/// No-op counterpart to get_workspace_size(): ignores all of its parameters
/// and returns Status::kSuccess.
static Status initialize_workspace(
Arguments const& /*args*/, void* /*ws*/, cudaStream_t /*stream*/) {
return Status::kSuccess;
}
/// Launch grid: kNumHeads blocks in x and params.num_batches blocks in z,
/// i.e. one block per (head, batch) pair.
static dim3 get_grid_shape(Params const& params) {
return dim3(kNumHeads, 1, params.num_batches);
}
/// Launch block: a 1-D block of MaxThreadsPerBlock threads.
static dim3 get_block_shape() {
return dim3(MaxThreadsPerBlock, 1, 1);
}
/// Host-side validation of the reduction arguments.
///
/// Rejects a non-positive batch count, a non-positive split count, and a
/// split count above kMaxSplits. The upper bound matters because operator()
/// stages one scale factor per split in a shared array of kMaxSplits entries
/// and loads at most ceil(kMaxSplits / 32) * 32 LSE values, so a larger
/// split_kv would read past that array and drop splits from the result.
///
/// Values behind args.ptr_split_kv are not inspected here.
static bool can_implement(Arguments const& args) {
  if (args.num_batches <= 0) return false;
  if (args.split_kv <= 0) return false;
  if (args.split_kv > kMaxSplits) return false;
  return true;
}
/// Device entry point: combines the per-split partial outputs and
/// log-sum-exp (LSE) values of one (head, batch) pair into the final output.
///
/// Block mapping: blockIdx.x selects the head, blockIdx.z the batch entry.
/// Phase 1 (warp 0 only): compute the combined LSE over the active splits and
/// a per-split scale exp(lse_split - lse_global), stored in shared memory.
/// Phase 2 (all threads): accumulate scale * partial_output over the splits
/// and write the result to ptr_o.
/// `smem_raw` is not used; the only shared memory is the static sLseScale.
CUTLASS_DEVICE void operator() (Params const& params, char* smem_raw) {
// Nothing to combine with a single split.
if (params.split_kv <= 1) return;
auto blk_coord = make_coord(blockIdx.x, _0{}, blockIdx.z);
// One scale factor per split, shared between the two phases.
__shared__ ElementAcc sLseScale[kMaxSplits];
const size_t offset_lseaccum = get<0>(blk_coord) + kNumHeads * params.split_kv * get<2>(blk_coord);
const size_t offset_lse = get<0>(blk_coord) + kNumHeads * get<2>(blk_coord);
// Per-split LSE values of this (head, batch); consecutive splits are
// kNumHeads elements apart.
Tensor gLSEaccum = make_tensor(make_gmem_ptr(params.ptr_lseaccum + offset_lseaccum),
make_shape(params.split_kv), Stride<Int<kNumHeads>>{});
Tensor gLSE = make_tensor(make_gmem_ptr(params.ptr_lse + offset_lse),
Shape<_1>{}, Stride<_1>{});
// Per-batch overrides, when provided, replace the global length / split count.
auto dim_k = params.ptr_seq == nullptr ? params.dim_k : params.ptr_seq[get<2>(blk_coord)];
auto local_split_kv = params.ptr_split_kv == nullptr ? params.split_kv : params.ptr_split_kv[get<2>(blk_coord)];
// Recompute how many splits are actually needed: tiles are dealt out
// k_tile_per_cta at a time, which can need fewer splits than requested.
// NOTE(review): for dim_k == 0 k_tile_per_cta is 0 and the last ceil_div
// would divide by zero - confirm callers never pass empty sequences.
auto k_tile_total = ceil_div(dim_k, params.tile_shape_s);
auto k_tile_per_cta = ceil_div(k_tile_total, local_split_kv);
local_split_kv = ceil_div(k_tile_total, k_tile_per_cta);
int warp_idx = cutlass::canonical_warp_idx_sync();
if (warp_idx == 0) {
// Each lane holds every 32nd split: split = i * 32 + threadIdx.x.
constexpr int kNLsePerThread = cute::ceil_div(kMaxSplits, 32);
ElementAcc local_lse[kNLsePerThread];
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kNLsePerThread; ++i) {
const int split = i * 32 + threadIdx.x;
// Inactive splits contribute -inf, i.e. exp(...) == 0 below.
local_lse[i] = split < local_split_kv ? gLSEaccum(split) : -std::numeric_limits<ElementAcc>::infinity();
}
// Maximum LSE: first per lane, then across the warp via xor shuffles.
ElementAcc lse_max = -std::numeric_limits<ElementAcc>::infinity();
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kNLsePerThread; ++i) {
lse_max = max(lse_max, local_lse[i]);
}
CUTLASS_PRAGMA_UNROLL
for (int offset = 16; offset >= 1; offset /= 2) {
lse_max = max(lse_max, __shfl_xor_sync(0xffffffff, lse_max, offset));
}
lse_max = lse_max == -std::numeric_limits<ElementAcc>::infinity() ? 0.0f : lse_max; // In case all local LSEs are -inf
lse_max = __shfl_sync(0xffffffff, lse_max, 0);
// Sum of exp(lse - max): per lane, then across the warp.
ElementAcc sum_lse = 0;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kNLsePerThread; ++i) {
sum_lse = sum_lse + expf(local_lse[i] - lse_max);
}
CUTLASS_PRAGMA_UNROLL
for (int offset = 16; offset >= 1; offset /= 2) {
sum_lse = sum_lse + __shfl_xor_sync(0xffffffff, sum_lse, offset);
}
sum_lse = __shfl_sync(0xffffffff, sum_lse, 0);
// A zero or NaN sum maps to +inf, which makes every scale below exp(-inf) = 0
// for finite inputs.
ElementAcc global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? std::numeric_limits<ElementAcc>::infinity() : logf(sum_lse) + lse_max;
if (threadIdx.x == 0 and params.ptr_lse != nullptr) {
gLSE(0) = global_lse;
}
// Publish the per-split weights for phase 2.
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < kNLsePerThread; ++i) {
const int split = i * 32 + threadIdx.x;
if (split < local_split_kv) {
sLseScale[split] = expf(local_lse[i] - global_lse);
}
}
}
// Make warp 0's sLseScale writes visible to the whole block.
__syncthreads();
// Each thread owns Elements output values, MaxThreadsPerBlock apart.
constexpr int Elements = kHeadDimLatent / MaxThreadsPerBlock;
const size_t offset_oaccum = kHeadDimLatent * params.split_kv * (get<0>(blk_coord) + kNumHeads * get<2>(blk_coord));
Tensor gOaccum = make_tensor(make_gmem_ptr(params.ptr_oaccum + offset_oaccum),
Shape<Int<kHeadDimLatent>>{}, Stride<_1>{});
ElementAcc local_val[Elements] = {0};
for (int split = 0; split < local_split_kv; ++split) {
ElementAcc lse_scale = sLseScale[split];
CUTLASS_PRAGMA_UNROLL
for(int i = 0; i < Elements; ++i) {
local_val[i] += lse_scale * gOaccum(threadIdx.x + MaxThreadsPerBlock * i);
}
// Step to the next split's row of partial outputs.
gOaccum.data() = gOaccum.data() + kHeadDimLatent;
}
auto ptr_o_local = params.ptr_o + (get<0>(blk_coord) + get<2>(blk_coord) * kNumHeads) * kHeadDimLatent;
Tensor gO = make_tensor(make_gmem_ptr(ptr_o_local), Shape<Int<kHeadDimLatent>>{}, Stride<_1>{});
// Convert from the accumulator type to the output type on store.
CUTLASS_PRAGMA_UNROLL
for(int i = 0; i < Elements; ++i) {
gO(threadIdx.x + MaxThreadsPerBlock * i) = static_cast<ElementOut>(local_val[i]);
}
}
};
} // namespace cutlass::fmha::kernel
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*
* Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929
* by Alcanderian JieXin Liang
*/
// clang-format off
#pragma once
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cute/arch/simd_sm100.hpp"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/memory_sm80.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "gather_tensor.hpp" // from examples/common
#include "common/pow_2.hpp"
namespace cutlass::fmha::kernel {
using namespace cute;
template<
class TileShape,
class Element_,
class ElementAcc_,
class ElementOut_,
class ElementLSE_,
class TileScheduler,
#ifdef CPASYNC
bool kIsCpAsync = true
#else
bool kIsCpAsync = false
#endif
>
struct Sm100FmhaMlaKernelTmaWarpspecialized {
using Element = Element_;
using ElementAcc = ElementAcc_;
using ElementOut = ElementOut_;
using ElementLSE = ElementLSE_;
// only 2Sm mode is supported
static const bool kIs2Sm = true;
static const int MaxThreadsPerBlock = 256;
static const int MinBlocksPerMultiprocessor = 1;
static const int TotalSNum = 2;
static const int TotalPNum = 2;
using ArchTag = cutlass::arch::Sm100;
using ClusterShape = cute::conditional_t<kIs2Sm, Shape<_2, _1, _1>, Shape<_1, _1, _1>>;
using TileShapeH = tuple_element_t<0, TileShape>;
using TileShapeS = tuple_element_t<1, TileShape>;
using TileShapeD = tuple_element_t<2, TileShape>;
using TileShapeL = tuple_element_t<0, TileShapeD>;
using TileShapeR = tuple_element_t<1, TileShapeD>;
static_assert(TileShapeL{} % TileShapeR{} == 0, "Rope head dim must divide latent head dim");
using ProblemShape = Shape<TileShapeH, int, TileShapeD, int>;
using TensorStride = Stride<int64_t, _1, int64_t>;
using TmemAllocator = cute::conditional_t<kIs2Sm, cute::TMEM::Allocator2Sm, cute::TMEM::Allocator1Sm>;
static_assert(TileShapeH{} == 128);
static const int kWarpsInN = kIs2Sm ? 2 : 1;
static const int kNumComputeWarps = 4;
static const int kNumLoadWarps = kIsCpAsync ? 2 : 1;
enum class WarpRole {
kMma = 0x1, kLoad = 0x2, kCompute = 0x3, kLoadPageTable = 0x4, kEmpty=0x0
};
static const long long unsigned int kWarpAssignment = kIsCpAsync ? 0x4221'3333ull : 0x0021'3333ull;
/// Looks up the role assigned to warp `warp_idx`.
/// kWarpAssignment packs one 4-bit WarpRole code per warp, with warp 0 in the
/// lowest nibble.
static CUTLASS_DEVICE WarpRole warp_idx_to_role(int warp_idx) {
  int const shift = 4 * warp_idx;
  auto const role_bits = (kWarpAssignment >> shift) & 0xF;
  return static_cast<WarpRole>(role_bits);
}
static const int Alignment = 128 / sizeof_bits_v<Element>;
static const int AlignmentOut = 128 / sizeof_bits_v<ElementOut>;
using TileShapeQK = Shape<TileShapeH, TileShapeS, decltype(TileShapeR{} / _1{})>;
static const int StagesQK = 24 / sizeof(Element); // free parameter
static const int IterationsQKLatent = decltype(TileShapeL{} / get<2>(TileShapeQK{}))::value;
static const int IterationsQKRope = decltype(TileShapeR{} / get<2>(TileShapeQK{}))::value;
static const int IterationsQK = IterationsQKLatent + IterationsQKRope;
using Schedule = cute::conditional_t<kIs2Sm, cutlass::gemm::KernelTmaWarpSpecialized2SmSm100, cutlass::gemm::KernelTmaWarpSpecialized1SmSm100>;
using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
Element, TensorStride, Alignment,
Element, TensorStride, Alignment,
ElementAcc,
TileShapeQK, ClusterShape, cutlass::gemm::collective::StageCount<StagesQK>,
Schedule>::CollectiveOp;
using TiledMmaQK = typename CollectiveMmaQK::TiledMma;
using CtaShapeQK = typename CollectiveMmaQK::CtaShape_MNK;
// chosen for unified smem staging between K and V
using TileShapePV = Shape<TileShapeH, _256, _32>;
using TransposeTensorStride = decltype(select<1,0,2>(TensorStride{}));
static const int StagesPV = StagesQK; // not sure why, but must be at least two. check pipes
static const int IterationsPV_K = decltype(TileShapeS{} / get<2>(TileShapePV{}))::value;
static const int IterationsPV_N = decltype(TileShapeL{} / get<1>(TileShapePV{}))::value;
using CollectiveMmaPV = typename cutlass::gemm::collective::CollectiveBuilder<
cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,
Element, TensorStride, Alignment,
Element, TransposeTensorStride, Alignment,
ElementAcc,
TileShapePV, ClusterShape, cutlass::gemm::collective::StageCount<StagesPV>,
Schedule>::CollectiveOp;
using CtaShapePV = typename CollectiveMmaPV::CtaShape_MNK;
static_assert(std::is_same_v<TransposeTensorStride, typename CollectiveMmaPV::StrideB>);
using TiledMmaPV = typename CollectiveMmaPV::TiledMma;
using AtomThrShapeMNK = typename CollectiveMmaQK::AtomThrShapeMNK;
static_assert(typename CollectiveMmaQK::AtomThrShapeMNK{} == typename CollectiveMmaPV::AtomThrShapeMNK{}, "schedule must match");
static const int StagesPageTable = kIsCpAsync ? StagesPV : 1;
// pipelines from load to mma, PipelineTmaUmmaAsync, stages tbd
// use expect_tx for Q load
using PipelineLoadQK = cute::conditional_t<kIsCpAsync, PipelineUmmaConsumerAsync<StagesQK, AtomThrShapeMNK>, PipelineTmaUmmaAsync<StagesQK, ClusterShape, AtomThrShapeMNK>>;
using PipelineLoadPV = PipelineLoadQK;
// pipeline from mma (Q@K) to softmax, PipelineUmmaAsync, 2 stages
using PipelineS = PipelineUmmaAsync<TotalSNum, AtomThrShapeMNK>;
// pipeline from softmax (P) to mma (bmm2), PipelineUmmaAsync, 2 stages
using PipelineP = PipelineUmmaConsumerAsync<TotalPNum, AtomThrShapeMNK>;
// pipeline from mma to softmax (for rescale), PipelineUmmaAsync, 1 stage
using PipelineO = PipelineUmmaAsync<1, AtomThrShapeMNK>;
using PipelinePT = PipelineAsync<StagesPageTable>;
struct PipelineStorage {
alignas(16) typename PipelineLoadQK::SharedStorage load_qk;
alignas(16) typename PipelineS::SharedStorage mma_s;
alignas(16) typename PipelineP::SharedStorage p_mma;
alignas(16) typename PipelineO::SharedStorage mma_o;
alignas(16) typename PipelinePT::SharedStorage load_page_table;
};
// Returns `layout` composed with (_, _, _, make_layout(stages)): the first
// three modes pass through unchanged and the fourth mode is re-indexed
// through a layout built from `stages` (a single element by default).
// Used below to derive SmemLayoutQ and SmemLayoutP from the collectives'
// SmemLayoutA - presumably to replace their pipeline-stage mode; confirm
// against the collective builder's layout convention.
template<class Layout, class Stages = _1>
static CUTE_DEVICE constexpr auto unstageSmemLayout(Layout const& layout, Stages stages = {}) {
return composition(layout, make_tuple(_, _, _, make_layout(stages)));
}
using SmemLayoutQ = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutA{}, Int<IterationsQK>{}));
using SmemLayoutKC = typename CollectiveMmaQK::SmemLayoutB;
using SmemLayoutVC = typename CollectiveMmaPV::SmemLayoutB;
using SmemLayoutP = decltype(unstageSmemLayout(typename CollectiveMmaPV::SmemLayoutA{}, make_shape(Int<IterationsPV_K>{}, _2{})));
static const int kBytesLoadQ = size(AtomThrShapeMNK{}) * cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutQ{})) * cute::sizeof_bits_v<Element>);
static const int kBytesLoadKC = size(AtomThrShapeMNK{}) * cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutKC{})) * cute::sizeof_bits_v<Element>);
static const int kBytesLoadVC = size(AtomThrShapeMNK{}) * cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutVC{})) * cute::sizeof_bits_v<Element>);
// pre-condition for overlapped smem staging
static_assert(kBytesLoadKC == kBytesLoadVC);
static_assert(StagesQK == StagesPV);
static const int kTransactionsBytesLoadQK = kBytesLoadKC;
static const int kTransactionsBytesLoadExtraQ = kBytesLoadQ;
static const int kTransactionsBytesLoadPV = kBytesLoadVC;
static const int kNamedBarrierExchange = (int) cutlass::arch::ReservedNamedBarriers::TransformBarrier;
// This Named Barrier is introduced to solve Q tile loading overwritten issue when enable persistent
// tile scheduler for FP8 MLA.
static const int kNamedBarrierEpilogue = (int) cutlass::arch::ReservedNamedBarriers::EpilogueBarrier;
//
static const int kNamedBarrierTmemDealloc = (int) cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier;
enum class TmemAllocation : uint32_t {
kSizeS = TileShapeS::value / kWarpsInN,
// Overall
kSizeO = TileShapeL::value / kWarpsInN,
// Between accumulators we loop over
kSizeAccO = decltype(get<1>(TileShapePV{}))::value / kWarpsInN,
kNumS = TotalSNum,
kNumP = TotalPNum,
kNumO = 1,
kS0 = 0,
kS1 = kS0 + kSizeS,
kO0 = kS1 + kSizeS,
kTotal = kO0 + kSizeO
};
static_assert(static_cast<int>(TmemAllocation::kTotal) <= TmemAllocator::Sm100TmemCapacityColumns, "using too much tmem");
struct TensorStorage {
// to communicate max and row_sum
cute::array<ElementAcc, kNumComputeWarps * cutlass::NumThreadsPerWarp> smem_exchange;
cute::array<int, StagesPageTable * TileShapeS::value> smem_page_table;
alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutQ>> smem_q;
union {
alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutKC>> smem_kc;
alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutVC>> smem_vc;
};
alignas(2048) cute::array<Element, cute::cosize_v<SmemLayoutP>> smem_p;
};
struct SharedStorage {
PipelineStorage pipelines;
TensorStorage tensors;
uint32_t tmem_base_ptr;
};
static const int SharedStorageSize = sizeof(SharedStorage);
static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "using too much smem");
struct MainloopArguments {
ElementAcc softmax_scale;
// all tensors strides are (num_heads or seqlen, head_dim, batch)
// head_dim stride is always 1
Element* ptr_q_latent;
TensorStride stride_q_latent;
Element* ptr_q_rope;
TensorStride stride_q_rope;
Element* ptr_c_latent;
TensorStride stride_c_latent;
Element* ptr_k_rope;
TensorStride stride_k_rope;
// for paged attention, we interpret what was previously [batch, seqlen]
// as [page_count, page_size], and index according to page_table
int* ptr_seq = nullptr;
int* ptr_page_table = nullptr;
// page table is [batch, seqlen or similar]
Stride<_1, int> stride_page_table = {};
int page_count = 0;
int page_size = TileShapeS{}; // powers of two if kIsCpAsync, otherwise TileShapeS
};
struct EpilogueArguments {
ElementOut* ptr_o = nullptr;
TensorStride stride_o;
ElementLSE* ptr_lse = nullptr;
Stride<_1, int> stride_lse;
ElementAcc output_scale = 1.0f;
};
struct Arguments {
// (num_heads=128, seqlen, (d_latent=512, d_rope=64), batch_count)
// for paged attention, seqlen is max seqlen
ProblemShape problem_shape;
MainloopArguments mainloop;
EpilogueArguments epilogue;
KernelHardwareInfo hw_info;
int split_kv = -1;
int* ptr_split_kv = nullptr;
};
// TMA descriptor types borrowed from the collective MMAs: both Q operands use
// the QK collective's A-side descriptor, C-latent and K-rope use its B-side
// descriptor, and the transposed C-latent view uses the PV collective's
// B-side descriptor.
using TmaLoadQLatent = typename CollectiveMmaQK::Params::TMA_A;
using TmaLoadQRope = typename CollectiveMmaQK::Params::TMA_A;
using TmaLoadCLatent = typename CollectiveMmaQK::Params::TMA_B;
using TmaLoadKRope = typename CollectiveMmaQK::Params::TMA_B;
using TmaLoadCLatentTranspose = typename CollectiveMmaPV::Params::TMA_B;
// Device-side TMA descriptors built by to_underlying_arguments().
// NOTE: the struct is brace-initialized positionally there, so the member
// order below must not change.
struct MainloopParams {
// Q descriptors are built from the non-paged problem shape.
TmaLoadQLatent tma_load_q_latent;
TmaLoadQRope tma_load_q_rope;
// C-latent / K-rope descriptors are built from the paged shape
// (identical to the plain shape when no page table is supplied).
TmaLoadCLatent tma_load_c_latent;
TmaLoadKRope tma_load_k_rope;
// C-latent again, with its first two stride modes swapped, as the PV B operand.
TmaLoadCLatentTranspose tma_load_c_latent_transpose;
};
// Device-side epilogue parameters: the user's output/LSE destinations plus the
// workspace-backed partial accumulators used when the sequence is split.
struct EpilogueParams {
ElementOut* ptr_o = nullptr;
// Partial-output accumulator inside the workspace; only set when split_kv > 1.
ElementAcc* ptr_o_acc = nullptr;
TensorStride stride_o;
// Only set when split_kv > 1.
TensorStride stride_o_acc;
ElementLSE* ptr_lse = nullptr;
// Partial-LSE buffer placed directly after the partial-output accumulator;
// only set when split_kv > 1.
ElementLSE* ptr_lse_acc = nullptr;
Stride<_1, int> stride_lse;
// Only set when split_kv > 1.
Stride<_1, int> stride_lse_acc;
ElementAcc output_scale = 1.0f;
};
// Everything the device kernel receives. Produced by to_underlying_arguments(),
// which brace-initializes this struct positionally — keep the member order.
struct Params {
ProblemShape problem_shape;
// The user's mainloop arguments, stored unchanged (pointers, strides, paging info).
MainloopArguments mainloop;
EpilogueParams epilogue;
// TMA descriptors derived from `mainloop`.
MainloopParams mainloop_params;
typename TileScheduler::Params tile_scheduler;
// Copied from Arguments::split_kv.
int split_kv = -1;
// Copied from Arguments::ptr_split_kv.
int* ptr_split_kv = nullptr;
};
// Converts host-side Arguments into the Params consumed by the kernel.
//
// Builds the TMA descriptors for the Q operands (latent and rope), for the
// C-latent / K-rope operands (using the paged shape when a page table is
// supplied) and for the transposed C-latent view used by the PV MMA. When
// args.split_kv > 1 the workspace is carved into a partial-output accumulator
// followed by a partial-LSE buffer; otherwise `workspace` is not used.
//
// args      - user-facing kernel arguments.
// workspace - buffer of at least get_workspace_size(args) bytes; only used to
//             derive accumulator pointers when args.split_kv > 1.
// Returns the fully populated Params.
static Params to_underlying_arguments(Arguments const& args, void* workspace) {
  auto [H, K, D, B] = args.problem_shape;
  auto [L, R] = D;

  // With a page table, the C/K operands are described as
  // [page_size, page_count] instead of [seqlen, batch].
  int paged_B = B;
  int paged_K = K;
  if (args.mainloop.ptr_page_table != nullptr) {
    paged_B = args.mainloop.page_count;
    paged_K = args.mainloop.page_size;
  }

  auto params_qk_latent = CollectiveMmaQK::to_underlying_arguments(
      make_shape(H, K, L, B),
      typename CollectiveMmaQK::Arguments {
        args.mainloop.ptr_q_latent, args.mainloop.stride_q_latent,
        args.mainloop.ptr_c_latent, args.mainloop.stride_c_latent,
      }, nullptr);

  auto params_qk_latent_paged = CollectiveMmaQK::to_underlying_arguments(
      make_shape(H, paged_K, L, paged_B),
      typename CollectiveMmaQK::Arguments {
        args.mainloop.ptr_q_latent, args.mainloop.stride_q_latent,
        args.mainloop.ptr_c_latent, args.mainloop.stride_c_latent,
      }, nullptr);

  auto params_qk_rope = CollectiveMmaQK::to_underlying_arguments(
      make_shape(H, K, R, B),
      typename CollectiveMmaQK::Arguments {
        args.mainloop.ptr_q_rope, args.mainloop.stride_q_rope,
        args.mainloop.ptr_k_rope, args.mainloop.stride_k_rope,
      }, nullptr);

  auto params_qk_rope_paged = CollectiveMmaQK::to_underlying_arguments(
      make_shape(H, paged_K, R, paged_B),
      typename CollectiveMmaQK::Arguments {
        args.mainloop.ptr_q_rope, args.mainloop.stride_q_rope,
        args.mainloop.ptr_k_rope, args.mainloop.stride_k_rope,
      }, nullptr);

  // The PV MMA consumes C-latent with its first two modes swapped.
  auto stride_c_latent_transpose = select<1,0,2>(args.mainloop.stride_c_latent);
  auto params_pv_latent = CollectiveMmaPV::to_underlying_arguments(
      make_shape(H, L, paged_K, paged_B),
      typename CollectiveMmaPV::Arguments {
        args.mainloop.ptr_q_latent, args.mainloop.stride_q_latent, // dummy, never used
        args.mainloop.ptr_c_latent, stride_c_latent_transpose,
      }, nullptr);

  // Q descriptors come from the plain shapes, C/K descriptors from the paged
  // ones. Order must match the member order of MainloopParams.
  MainloopParams mainloop_params {
    params_qk_latent.tma_load_a,
    params_qk_rope.tma_load_a,
    params_qk_latent_paged.tma_load_b,
    params_qk_rope_paged.tma_load_b,
    params_pv_latent.tma_load_b
  };

  EpilogueParams epilogue_params;
  epilogue_params.ptr_o = args.epilogue.ptr_o;
  epilogue_params.stride_o = args.epilogue.stride_o;
  epilogue_params.ptr_lse = args.epilogue.ptr_lse;
  epilogue_params.stride_lse = args.epilogue.stride_lse;
  epilogue_params.output_scale = args.epilogue.output_scale;

  if (args.split_kv > 1) {
    ElementAcc* ptr_o_acc = reinterpret_cast<ElementAcc*>(workspace);
    // Widen before multiplying: H * L * split_kv * B is evaluated in 32-bit
    // int otherwise and can overflow for large batch / split counts, while
    // get_workspace_size() and the strides below already use wider types.
    ElementLSE* ptr_lse_acc = reinterpret_cast<ElementLSE*>(
        ptr_o_acc + static_cast<int64_t>(0 + H * L) * args.split_kv * B);
    epilogue_params.ptr_o_acc = ptr_o_acc;
    epilogue_params.ptr_lse_acc = ptr_lse_acc;
    epilogue_params.stride_o_acc = make_tuple(static_cast<int64_t>(0 + L) * args.split_kv, _1{}, static_cast<int64_t>(0 + H * L) * args.split_kv);
    epilogue_params.stride_lse_acc = make_tuple(_1{}, (0 + H) * args.split_kv);
  }

  // Order must match the member order of Params.
  return {args.problem_shape, args.mainloop, epilogue_params, mainloop_params,
          TileScheduler::to_underlying_arguments(args.problem_shape, args.hw_info, ClusterShape{}, args.split_kv), args.split_kv, args.ptr_split_kv};
}
// Returns the workspace size in bytes: for every (head, split, batch) triple,
// one row of D_latent ElementAcc values plus one ElementLSE value.
static size_t get_workspace_size(Arguments const& args) {
  auto [H, K, D, B] = args.problem_shape;
  auto [D_latent, D_rope] = D;
  // Bytes needed by a single (head, split, batch) entry.
  auto bytes_per_entry = sizeof(ElementAcc) * D_latent + sizeof(ElementLSE);
  return bytes_per_entry * H * args.split_kv * B;
}
// Performs no work: all three arguments are ignored and success is always
// reported.
static Status initialize_workspace(
    Arguments const& /*args*/,
    void* /*ws*/,
    cudaStream_t /*stream*/) {
  Status const result = Status::kSuccess;
  return result;
}
// Returns the launch grid, as decided by the tile scheduler from its own params.
static dim3 get_grid_shape(Params const& params) {
  dim3 const grid = TileScheduler::get_grid_shape(params.tile_scheduler);
  return grid;
}
// Returns the launch block shape: a 1-D block of MaxThreadsPerBlock threads.
static dim3 get_block_shape() {
  return dim3(MaxThreadsPerBlock, 1, 1);
}
// Returns true when this kernel configuration supports `args`.
//
// Checks performed:
//  - cp.async loads (kIsCpAsync): page_size must be a positive power of two no
//    larger than TileShapeS.
//  - otherwise: when a page table is supplied, page_size must equal TileShapeS.
//  - the head count (mode 0 of problem_shape) must be 128.
//  - the sequence length (mode 1 of problem_shape) must be positive.
//  - split_kv must be positive.
static bool can_implement(Arguments const& args) {
  if (kIsCpAsync) {
    // Reject non-positive sizes first: 0 would pass the power-of-two bit
    // test below and later be used as a divisor (TileShapeS / page_size).
    if (args.mainloop.page_size <= 0) {
      return false;
    }
    if ((args.mainloop.page_size & (args.mainloop.page_size - 1)) != 0) {
      return false;
    }
    if (args.mainloop.page_size > TileShapeS{}) {
      return false;
    }
  }
  else {
    if (args.mainloop.ptr_page_table != nullptr && args.mainloop.page_size != TileShapeS{}) {
      return false;
    }
  }
  if (get<0>(args.problem_shape) != 128) {
    return false;
  }
  if (get<1>(args.problem_shape) <= 0) {
    return false;
  }
  if (args.split_kv <= 0) {
    return false;
  }
  return true;
}
// Kernel body.
//
// Derives this warp's role from its warp index, constructs the five inter-warp
// pipelines (load_qk, mma_s, p_mma, mma_o, page_table) on top of the
// shared-memory storage, and then runs the loop for that role over every tile
// the tile scheduler yields:
//   - kLoadPageTable: load_page_table()
//   - kLoad:          load_cpasync() (kIsCpAsync) or load_tma<paged>()
//   - kMma:           mma(), on the leader CTA only, after allocating tensor memory
//   - kCompute:       compute()
// Every role loop applies the same per-tile prologue: when
// params.mainloop.ptr_seq is set, the sequence length is replaced by the
// per-batch value (and the split count by params.ptr_split_kv[batch] if that
// array is also set); tiles whose split index get<3>(blk_coord) is not below
// the effective split count are skipped.
//
// params   - kernel parameters produced by to_underlying_arguments().
// smem_raw - raw shared-memory pointer, reinterpreted as SharedStorage.
CUTLASS_DEVICE void operator()(Params const& params, char* smem_raw) {
TileScheduler tile_scheduler(params.tile_scheduler);
int warp_idx = cutlass::canonical_warp_idx_sync();
auto role = warp_idx_to_role(warp_idx);
uint32_t lane_predicate = cute::elect_one_sync();
uint32_t cta_rank_in_cluster = cute::block_rank_in_cluster();
int cta_coord_v = cta_rank_in_cluster % size<0>(AtomThrShapeMNK{});
bool is_mma_leader_cta = cta_coord_v == 0;
// One elected lane of the load warp prefetches all TMA descriptors; skipped
// entirely when loads are done with cp.async.
if (role == WarpRole::kLoad && lane_predicate && ! kIsCpAsync) {
prefetch_tma_descriptor(params.mainloop_params.tma_load_q_latent.get_tma_descriptor());
prefetch_tma_descriptor(params.mainloop_params.tma_load_c_latent.get_tma_descriptor());
prefetch_tma_descriptor(params.mainloop_params.tma_load_q_rope.get_tma_descriptor());
prefetch_tma_descriptor(params.mainloop_params.tma_load_k_rope.get_tma_descriptor());
prefetch_tma_descriptor(params.mainloop_params.tma_load_c_latent_transpose.get_tma_descriptor());
}
SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_raw);
// load_qk pipeline: load warp produces, MMA warp consumes.
typename PipelineLoadQK::Params pipeline_load_qk_params;
if (role == WarpRole::kLoad) {
pipeline_load_qk_params.role = PipelineLoadQK::ThreadCategory::Producer;
}
if (role == WarpRole::kMma) {
pipeline_load_qk_params.role = PipelineLoadQK::ThreadCategory::Consumer;
}
if constexpr (kIsCpAsync) {
// we can make our life easier by unconditionally loading blocks
// since we know it'll always be legal
pipeline_load_qk_params.producer_arv_count = kNumLoadWarps * cutlass::NumThreadsPerWarp * size(AtomThrShapeMNK{});
}
else {
// TMA path: a single elected lane of the load warp on the leader CTA is the
// pipeline leader, and completion is tracked by transaction byte count.
pipeline_load_qk_params.is_leader = lane_predicate && (role == WarpRole::kLoad) && is_mma_leader_cta;
pipeline_load_qk_params.transaction_bytes = kTransactionsBytesLoadQK;
}
// Each pipeline names a different initializing warp (0..4).
pipeline_load_qk_params.initializing_warp = 0;
PipelineLoadQK pipeline_load_qk(shared_storage.pipelines.load_qk, pipeline_load_qk_params,
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
// mma_s pipeline: MMA warp produces, compute warps consume.
typename PipelineS::Params pipeline_mma_s_params;
if (role == WarpRole::kMma) {
pipeline_mma_s_params.role = PipelineS::ThreadCategory::Producer;
}
if (role == WarpRole::kCompute) {
pipeline_mma_s_params.role = PipelineS::ThreadCategory::Consumer;
}
pipeline_mma_s_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp * size(AtomThrShapeMNK{});
pipeline_mma_s_params.initializing_warp = 1;
PipelineS pipeline_mma_s(
shared_storage.pipelines.mma_s,
pipeline_mma_s_params,
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
// p_mma pipeline: compute warps produce, MMA warp consumes.
typename PipelineP::Params pipeline_p_mma_params;
if (role == WarpRole::kMma) {
pipeline_p_mma_params.role = PipelineP::ThreadCategory::Consumer;
}
if (role == WarpRole::kCompute) {
pipeline_p_mma_params.role = PipelineP::ThreadCategory::Producer;
}
pipeline_p_mma_params.producer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp * size(AtomThrShapeMNK{});
pipeline_p_mma_params.consumer_arv_count = 1;
pipeline_p_mma_params.initializing_warp = 2;
PipelineP pipeline_p_mma(
shared_storage.pipelines.p_mma,
pipeline_p_mma_params,
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
// mma_o pipeline: MMA warp produces, compute warps consume.
typename PipelineO::Params pipeline_mma_o_params;
if (role == WarpRole::kMma) {
pipeline_mma_o_params.role = PipelineO::ThreadCategory::Producer;
}
if (role == WarpRole::kCompute) {
pipeline_mma_o_params.role = PipelineO::ThreadCategory::Consumer;
}
pipeline_mma_o_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp * size(AtomThrShapeMNK{});
pipeline_mma_o_params.initializing_warp = 3;
PipelineO pipeline_mma_o(
shared_storage.pipelines.mma_o,
pipeline_mma_o_params,
ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{});
// page_table pipeline: page-table warp produces, load warps consume.
typename PipelinePT::Params pipeline_pt_params;
if (role == WarpRole::kLoad) {
pipeline_pt_params.role = PipelinePT::ThreadCategory::Consumer;
}
if (role == WarpRole::kLoadPageTable) {
pipeline_pt_params.role = PipelinePT::ThreadCategory::Producer;
}
pipeline_pt_params.consumer_arv_count = kNumLoadWarps * cutlass::NumThreadsPerWarp;
pipeline_pt_params.producer_arv_count = cutlass::NumThreadsPerWarp;
pipeline_pt_params.initializing_warp = 4;
PipelinePT pipeline_page_table(
shared_storage.pipelines.load_page_table,
pipeline_pt_params);
TmemAllocator tmem_allocator;
// The arrive here and the pipeline_init_wait() below bracket mask setup and
// state construction; presumably this makes barrier initialization visible
// across the cluster before any pipeline is used — confirm against the
// CUTLASS pipeline helpers.
pipeline_init_arrive_relaxed(size(ClusterShape{}));
pipeline_load_qk.init_masks(ClusterShape{}); // do we need an update here for 2Sm?
pipeline_mma_s.init_masks(ClusterShape{});
pipeline_p_mma.init_masks(ClusterShape{});
pipeline_mma_o.init_masks(ClusterShape{});
// Consumer states are default-constructed; producer states use
// make_producer_start_state for the corresponding pipeline.
typename PipelineLoadQK::PipelineState pipeline_load_qk_consumer_state;
typename PipelineLoadQK::PipelineState pipeline_load_qk_producer_state = cutlass::make_producer_start_state<PipelineLoadQK>();
typename PipelineS::PipelineState pipeline_mma_s_consumer_state;
typename PipelineS::PipelineState pipeline_mma_s_producer_state = cutlass::make_producer_start_state<PipelineS>();
typename PipelineP::PipelineState pipeline_p_mma_consumer_state;
typename PipelineP::PipelineState pipeline_p_mma_producer_state = cutlass::make_producer_start_state<PipelineP>();
typename PipelineO::PipelineState pipeline_mma_o_consumer_state;
typename PipelineO::PipelineState pipeline_mma_o_producer_state = cutlass::make_producer_start_state<PipelineO>();
typename PipelinePT::PipelineState pipeline_pt_consumer_state;
typename PipelinePT::PipelineState pipeline_pt_producer_state = cutlass::make_producer_start_state<PipelinePT>();
pipeline_init_wait(size(ClusterShape{}));
// ---- Role: page-table loader ----
if (role == WarpRole::kLoadPageTable) {
CUTLASS_PRAGMA_NO_UNROLL
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
auto blk_coord = tile_scheduler.get_block_coord();
auto problem_shape = params.problem_shape;
auto local_split_kv = params.split_kv;
// Per-batch overrides; ptr_split_kv is only consulted when ptr_seq is set.
if (params.mainloop.ptr_seq != nullptr) {
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
if (params.ptr_split_kv != nullptr) {
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
}
}
// This tile's split index is beyond the effective split count: nothing to do.
if (local_split_kv <= get<3>(blk_coord))
continue;
load_page_table(
blk_coord,
problem_shape,
params.mainloop,
shared_storage.tensors,
pipeline_page_table, pipeline_pt_producer_state,
local_split_kv
);
}
}
// ---- Role: operand loader ----
else if (role == WarpRole::kLoad) {
if constexpr (kIsCpAsync) {
CUTLASS_PRAGMA_NO_UNROLL
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
auto blk_coord = tile_scheduler.get_block_coord();
auto problem_shape = params.problem_shape;
auto local_split_kv = params.split_kv;
if (params.mainloop.ptr_seq != nullptr) {
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
if (params.ptr_split_kv != nullptr) {
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
}
}
if (local_split_kv <= get<3>(blk_coord))
continue;
load_cpasync(
blk_coord,
problem_shape,
params.mainloop,
params.mainloop_params,
shared_storage.tensors,
pipeline_load_qk, pipeline_load_qk_producer_state,
local_split_kv,
/* must be shared pipe */
pipeline_page_table, pipeline_pt_consumer_state
);
// Rendezvous sized for the load and compute warps together; the compute-side
// arrival is presumably inside compute() (not visible here). Skipped tiles
// (the `continue` above) do not reach this barrier.
cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait();
}
}
else {
// TMA path: paged vs. non-paged is a runtime choice that selects the
// load_tma template instantiation.
if (params.mainloop.ptr_page_table != nullptr) {
CUTLASS_PRAGMA_NO_UNROLL
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
auto blk_coord = tile_scheduler.get_block_coord();
auto problem_shape = params.problem_shape;
auto local_split_kv = params.split_kv;
if (params.mainloop.ptr_seq != nullptr) {
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
if (params.ptr_split_kv != nullptr) {
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
}
}
if (local_split_kv <= get<3>(blk_coord))
continue;
// The same pipeline object and producer state are passed for both the QK
// and the PV pipeline parameters of load_tma.
load_tma</* paged= */ true>(
blk_coord,
problem_shape,
params.mainloop,
params.mainloop_params,
shared_storage.tensors,
pipeline_load_qk, pipeline_load_qk_producer_state,
pipeline_load_qk, pipeline_load_qk_producer_state,
local_split_kv
);
cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait();
}
}
else {
CUTLASS_PRAGMA_NO_UNROLL
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
auto blk_coord = tile_scheduler.get_block_coord();
auto problem_shape = params.problem_shape;
auto local_split_kv = params.split_kv;
if (params.mainloop.ptr_seq != nullptr) {
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
if (params.ptr_split_kv != nullptr) {
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
}
}
if (local_split_kv <= get<3>(blk_coord))
continue;
load_tma<false>(
blk_coord,
problem_shape,
params.mainloop,
params.mainloop_params,
shared_storage.tensors,
pipeline_load_qk, pipeline_load_qk_producer_state,
pipeline_load_qk, pipeline_load_qk_producer_state,
local_split_kv
);
cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait();
}
}
}
}
// ---- Role: MMA issuer ----
else if (role == WarpRole::kMma) {
// Allocate Sm100TmemCapacityColumns columns of tensor memory; the base address
// is written into shared storage, from which it is read back for the free below.
tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr);
__syncwarp();
// Only the leader CTA (cta_coord_v == 0) runs the MMA loop.
if (is_mma_leader_cta) {
CUTLASS_PRAGMA_NO_UNROLL
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
auto blk_coord = tile_scheduler.get_block_coord();
auto problem_shape = params.problem_shape;
auto local_split_kv = params.split_kv;
if (params.mainloop.ptr_seq != nullptr) {
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
if (params.ptr_split_kv != nullptr) {
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
}
}
if (local_split_kv <= get<3>(blk_coord))
continue;
// As in the load role, pipeline_load_qk is passed twice (two pipeline
// parameter pairs of mma() share one pipeline and consumer state).
mma(blk_coord,
problem_shape,
shared_storage.tensors,
pipeline_load_qk, pipeline_load_qk_consumer_state,
pipeline_load_qk, pipeline_load_qk_consumer_state,
pipeline_mma_s, pipeline_mma_s_producer_state,
pipeline_p_mma, pipeline_p_mma_consumer_state,
pipeline_mma_o, pipeline_mma_o_producer_state,
local_split_kv
);
}
}
//cutlass::arch::NamedBarrier((kNumComputeWarps + 1) * NumThreadsPerWarp, kNamedBarrierTmemDealloc).arrive_and_wait();
//uint32_t free_stage_ptr = shared_storage.tmem_base_ptr;
//tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns);
}
// ---- Role: compute ----
else if (role == WarpRole::kCompute) {
CUTLASS_PRAGMA_NO_UNROLL
for (; tile_scheduler.is_valid(); ++tile_scheduler) {
auto blk_coord = tile_scheduler.get_block_coord();
auto problem_shape = params.problem_shape;
auto split_kv = params.split_kv;
auto local_split_kv = split_kv;
if (params.mainloop.ptr_seq != nullptr) {
get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)];
if (params.ptr_split_kv != nullptr) {
local_split_kv = params.ptr_split_kv[get<2>(blk_coord)];
}
}
if (local_split_kv <= get<3>(blk_coord))
continue;
compute(
blk_coord,
problem_shape,
params.mainloop, // for softmax_scale
params.epilogue,
shared_storage.tensors, // for smem_comm
pipeline_mma_s, pipeline_mma_s_consumer_state,
pipeline_p_mma, pipeline_p_mma_producer_state,
pipeline_mma_o, pipeline_mma_o_consumer_state,
local_split_kv
);
}
//cutlass::arch::NamedBarrier((kNumComputeWarps + 1) * NumThreadsPerWarp, kNamedBarrierTmemDealloc).arrive();
}
// All roles converge here before tensor memory is released.
cute::cluster_sync();
// NOTE(review): this arrive() is reached by every thread regardless of role,
// while the barrier is sized for (kNumComputeWarps + 1) warps — confirm the
// arrival count is intended.
cutlass::arch::NamedBarrier((kNumComputeWarps + 1) * NumThreadsPerWarp, kNamedBarrierTmemDealloc).arrive();
// The MMA warp, which performed the allocation, frees the tensor memory.
if (role == WarpRole::kMma) {
uint32_t free_stage_ptr = shared_storage.tmem_base_ptr;
tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns);
}
}
// Producer side of the page-table pipeline (written assuming a single warp).
//
// For every K tile assigned to this block's split, acquires a page-table
// pipeline stage, issues cp.async copies of that tile's page-table entries for
// the current batch into the stage's slice of smem_page_table, and commits the
// stage via cpasync_barrier_arrive.
//
// blk_coord                  - block coordinate; mode 2 is the batch, mode 3 the split index.
// problem_shape              - (H, K, D, B); K is the sequence length used for tiling.
// mainloop_args              - supplies the page table pointer/stride, page_count and page_size.
// shared_tensors             - shared storage holding smem_page_table.
// pipeline_page_table        - pipeline this warp produces into.
// pipeline_pt_producer_state - producer state, advanced once per tile.
// split_kv                   - number of splits the K tiles are divided into.
template<class BlkCoord>
CUTLASS_DEVICE void load_page_table(
    BlkCoord const& blk_coord,
    ProblemShape const& problem_shape,
    MainloopArguments const& mainloop_args,
    TensorStorage& shared_tensors,
    PipelinePT& pipeline_page_table,
    typename PipelinePT::PipelineState& pipeline_pt_producer_state, int const& split_kv) {
  auto [H, K, D, B] = problem_shape;

  // Page-table column belonging to this block's batch entry.
  int batch_idx = get<2>(blk_coord);
  auto page_table_all = make_tensor(make_gmem_ptr(mainloop_args.ptr_page_table),
                                    make_shape(mainloop_args.page_count, B),
                                    mainloop_args.stride_page_table);
  auto page_table = page_table_all(_, batch_idx);

  // Range of K tiles handled by this split.
  int tiles_total = ceil_div(K, TileShapeS{});
  int tiles_per_split = ceil_div(tiles_total, split_kv);
  int tile = get<3>(blk_coord) * tiles_per_split; // lower limit
  int tiles_left = max(0, min(tiles_total, tile + tiles_per_split) - tile);
  if (tiles_left == 0) {
    return;
  }

  auto page_size = Pow2{mainloop_args.page_size};
  auto pages_per_tile = Pow2{TileShapeS{} / page_size};
  int lane = threadIdx.x % cutlass::NumThreadsPerWarp;

  while (tiles_left > 0) {
    pipeline_page_table.producer_acquire(pipeline_pt_producer_state);

    // Each lane handles entries lane, lane + 32, ...; entries at or beyond
    // pages_per_tile are issued with a false guard.
    CUTLASS_PRAGMA_UNROLL
    for (int base = 0; base < TileShapeS{}; base += cutlass::NumThreadsPerWarp) {
      int entry = base + lane;
      bool in_range = entry < pages_per_tile;
      int smem_slot = pipeline_pt_producer_state.index() * TileShapeS::value + entry;
      int table_slot = pages_per_tile * tile + entry;
      cutlass::arch::cp_async_zfill<sizeof(int), cutlass::arch::CacheOperation::Always>(
          &shared_tensors.smem_page_table[smem_slot], &page_table(table_slot), in_range
      );
    }

    pipeline_page_table.producer_commit(pipeline_pt_producer_state, cutlass::arch::cpasync_barrier_arrive);
    ++pipeline_pt_producer_state;

    ++tile;
    --tiles_left;
  }
}
// Index functor that resolves a page slot through the page table staged in
// shared memory. `page_table_stage` is held by reference, so the owner can
// switch which stage is consulted without rebuilding the functor.
// Member order matters: the struct is aggregate-initialized by its user.
struct Gather {
  int& page_table_stage;
  Pow2 pages_per_tile;
  const int * __restrict__ smem_page_table;

  // Returns the entry at slot (idx % pages_per_tile) of the current stage.
  CUTLASS_DEVICE int operator()(int idx) const {
    int const stage_base = page_table_stage * TileShapeS::value;
    return smem_page_table[stage_base + idx % pages_per_tile];
  }

  // Debug-print hook; prints a fixed placeholder.
  CUTLASS_DEVICE friend void print(Gather const&) {
    printf("<gather>");
  }
};
// cp.async producer for the load_qk pipeline with page-table gathering.
//
// For the K tiles assigned to this block's split, copies Q (first tile only),
// the C-latent / K-rope operands and the transposed C-latent (PV) operand from
// global to shared memory with 128-bit cp.async, resolving the page of every
// row through the page table staged in shared memory by load_page_table().
// Stage commits to the load_qk pipeline trail stage acquires so that several
// cp.async groups stay in flight.
//
// blk_coord                    - block coordinate; mode 2 is the batch, mode 3 the split index.
// problem_shape                - (H, K, D, B) with D = (D_latent, D_rope).
// mainloop_args                - operand pointers/strides and paging info.
// mainloop_params              - unused in this function body.
// shared_tensors               - shared storage holding smem_q / smem_kc / smem_vc / smem_page_table.
// pipeline_load                - pipeline this function produces into.
// pipeline_load_producer_state - producer state; advanced once per acquired stage.
// split_kv                     - number of splits the K tiles are divided into.
// pipeline_page_table          - page-table pipeline this function consumes from.
// pipeline_pt_consumer_state   - consumer state; advanced once per K tile.
template<class BlkCoord>
CUTLASS_DEVICE void load_cpasync(
BlkCoord const& blk_coord,
ProblemShape const& problem_shape,
MainloopArguments const& mainloop_args,
MainloopParams const& mainloop_params,
TensorStorage& shared_tensors,
PipelineLoadQK& pipeline_load,
typename PipelineLoadQK::PipelineState& pipeline_load_producer_state,
int const& split_kv,
PipelinePT& pipeline_page_table,
typename PipelinePT::PipelineState& pipeline_pt_consumer_state) {
auto [H, K, D, B] = problem_shape;
auto [D_latent, D_rope] = D;
using X = Underscore;
// Range of K tiles handled by this split (same formula as load_page_table).
int k_tile_total = ceil_div(K, TileShapeS{});
int k_tile_per_cta = ceil_div(k_tile_total, split_kv);
int k_index = get<3>(blk_coord) * k_tile_per_cta; // lower limit
int k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index);
if (k_tile_count == 0) {
return;
}
// partition all tensors
auto mQL = make_tensor(make_gmem_ptr(mainloop_args.ptr_q_latent), make_shape(H, D_latent, B), mainloop_args.stride_q_latent);
auto mQR = make_tensor(make_gmem_ptr(mainloop_args.ptr_q_rope), make_shape(H, D_rope, B), mainloop_args.stride_q_rope);
int paged_B = mainloop_args.page_count;
auto paged_K = Pow2{mainloop_args.page_size};
auto mPT_l = make_tensor(make_gmem_ptr(mainloop_args.ptr_page_table), make_shape(paged_B, B), mainloop_args.stride_page_table);
int batch_coord = get<2>(blk_coord);
auto mPT = mPT_l(_, batch_coord);
auto gQL = local_tile(mQL, TileShapeQK{}, make_coord(_,_,_), Step<_1, X, _1>{});
auto gQR = local_tile(mQR, TileShapeQK{}, make_coord(_,_,_), Step<_1, X, _1>{});
ThrMMA cta_mma_qk = TiledMmaQK{}.get_slice(get<0>(blk_coord) % size(AtomThrShapeMNK{}));
ThrMMA cta_mma_pv = TiledMmaPV{}.get_slice(get<0>(blk_coord) % size(AtomThrShapeMNK{}));
auto tSgQL = cta_mma_qk.partition_A(gQL);
auto tSgQR = cta_mma_qk.partition_A(gQR);
Tensor sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{});
Tensor sKC = make_tensor(make_smem_ptr(shared_tensors.smem_kc.begin()), SmemLayoutKC{});
Tensor sVC = make_tensor(make_smem_ptr(shared_tensors.smem_vc.begin()), SmemLayoutVC{});
// Builds a tiled 128-bit cp.async copy for one stage of the shared-memory
// tensor sT, spread over all kNumLoadWarps * NumThreadsPerWarp load threads,
// each moving sizeof(uint128_t) / sizeof(Element) elements per copy.
auto make_copy_for = [](auto sT) {
auto rT_a = sT.layout()(_, _, _, _0{});
auto rT = make_ordered_layout(shape(rT_a), stride(rT_a));
auto threads = Int<kNumLoadWarps * cutlass::NumThreadsPerWarp>{};
auto values = Int<sizeof(uint128_t) / sizeof(Element)>{};
return make_cotiled_copy(
Copy_Atom<SM80_CP_ASYNC_CACHEALWAYS<uint128_t>, Element>{},
make_ordered_layout(
make_shape(threads, values),
make_stride(_1{}, _0{})),
rT);
};
// like cute::copy, but makes sure we do all page table lookups first
// (first loop resolves every source address, second loop issues the copies)
auto copy_split = [](auto atom, auto src, auto dst) {
auto src_v = group_modes<1, rank_v<decltype(src)>>(src);
auto dst_v = group_modes<1, rank_v<decltype(dst)>>(dst);
auto src_v_ptrs = make_tensor<Element*>(size<1>(src_v));
for (int i = 0; i < size<1>(src_v); i++) {
src_v_ptrs(i) = &src_v(_0{}, i);
}
for (int i = 0; i < size<1>(src_v); i++) {
auto src_v_i = make_tensor(
make_gmem_ptr(src_v_ptrs(i)),
make_shape(shape<0>(src_v)),
make_stride(make_stride(_1{}, _0{}))
);
atom.call(src_v_i, dst_v(_, i));
}
};
auto tiled_copy_q = make_copy_for(sQ);
auto tiled_copy_kc = make_copy_for(sKC);
auto tiled_copy_vc = make_copy_for(sVC);
auto thr_copy_q = tiled_copy_q.get_thread_slice(threadIdx.x % (kNumLoadWarps * cutlass::NumThreadsPerWarp));
auto thr_copy_kc = tiled_copy_kc.get_thread_slice(threadIdx.x % (kNumLoadWarps * cutlass::NumThreadsPerWarp));
auto thr_copy_vc = tiled_copy_vc.get_thread_slice(threadIdx.x % (kNumLoadWarps * cutlass::NumThreadsPerWarp));
auto tQsQ = thr_copy_q.partition_D(sQ);
auto tQgQL = thr_copy_q.partition_S(tSgQL);
auto tQgQR = thr_copy_q.partition_S(tSgQR);
auto tKCsKC = thr_copy_kc.partition_D(sKC);
auto tVCsVC = thr_copy_vc.partition_D(sVC);
// Page-table stages are released later than they are waited on, so a second
// state tracks the release position.
auto pipeline_pt_release_state = pipeline_pt_consumer_state;
// Gather holds page_table_stage by reference: assigning to this variable
// switches which staged page table the gathered tensors below read through.
int page_table_stage = -1;
Pow2 pages_per_tile{TileShapeS{} / paged_K};
const int * __restrict__ smem_page_table = shared_tensors.smem_page_table.begin();
Gather gather{page_table_stage, pages_per_tile, smem_page_table};
// Global views whose page mode is strided through `gather` (CustomStride),
// so indexing a row performs the page-table lookup.
auto mCL = make_tensor(
make_gmem_ptr(mainloop_args.ptr_c_latent),
ComposedLayout{
make_layout(
make_shape(make_shape(paged_K, paged_B), _1{}),
make_stride(make_stride(get<0>(mainloop_args.stride_c_latent), example::CustomStride(gather, get<2>(mainloop_args.stride_c_latent))), get<1>(mainloop_args.stride_c_latent))),
make_coord(_0{}, _0{}),
make_identity_layout(make_shape(paged_K * paged_B, D_latent))});
// NOTE(review): the identity-layout extent below uses D_latent although this
// is the rope operand (D_rope) — confirm this is intentional.
auto mKR = make_tensor(
make_gmem_ptr(mainloop_args.ptr_k_rope),
ComposedLayout{
make_layout(
make_shape(make_shape(paged_K, paged_B), _1{}),
make_stride(make_stride(get<0>(mainloop_args.stride_k_rope), example::CustomStride(gather, get<2>(mainloop_args.stride_k_rope))), get<1>(mainloop_args.stride_k_rope))),
make_coord(_0{}, _0{}),
make_identity_layout(make_shape(paged_K * paged_B, D_latent))});
// Transposed view of C-latent for the PV operand.
auto mCLT = make_tensor(
make_gmem_ptr(mainloop_args.ptr_c_latent),
ComposedLayout{
make_layout(
make_shape(_1{}, make_shape(paged_K, paged_B)),
make_stride(get<1>(mainloop_args.stride_c_latent), make_stride(get<0>(mainloop_args.stride_c_latent), example::CustomStride(gather, get<2>(mainloop_args.stride_c_latent))))),
make_coord(_0{}, _0{}),
make_identity_layout(make_shape(D_latent, paged_K * paged_B))});
auto gCL = local_tile(mCL, TileShapeQK{}, make_coord(_,_,_), Step<X, _1, _1>{});
auto gKR = local_tile(mKR, TileShapeQK{}, make_coord(_,_,_), Step<X, _1, _1>{});
auto gCLT = local_tile(mCLT, TileShapePV{}, make_coord(_,_,_), Step<X, _1, _1>{});
auto tSgCL = cta_mma_qk.partition_B(gCL);
auto tSgKR = cta_mma_qk.partition_B(gKR);
auto tOgCLT = cta_mma_pv.partition_B(gCLT);
auto tKCgCL = thr_copy_kc.partition_S(tSgCL);
auto tKCgKR = thr_copy_kc.partition_S(tSgKR);
auto tVCgCLT = thr_copy_vc.partition_S(tOgCLT);
// latent is first in memory, so let's load it first always
// startup: alternate Q and K, set tx count appropriately, for k_idx = 0
// The acquire state aliases the caller's producer state; the commit state is
// a local copy that trails it by pipeline_offset stages.
auto& pipeline_acquire_state = pipeline_load_producer_state;
auto pipeline_commit_state = pipeline_acquire_state;
int pipeline_offset = 0;
// Issue StagesPV empty cp.async groups up front; presumably this primes the
// group count so the cp_async_wait<StagesPV - 1> in load_stage behaves
// uniformly from the first call — confirm.
for (int i = 0; i < StagesPV; i++) {
cutlass::arch::cp_async_fence();
}
// Acquires the next stage, lets `fn` issue its copies into that stage index,
// and closes the cp.async group. Once StagesPV - 1 stages are pending, waits
// for the oldest group and commits the oldest uncommitted stage.
auto load_stage = [&](auto fn) {
pipeline_load.producer_acquire(pipeline_acquire_state);
fn(pipeline_acquire_state.index());
cutlass::arch::cp_async_fence();
++pipeline_acquire_state;
++pipeline_offset;
if (pipeline_offset == StagesPV - 1) {
cutlass::arch::cp_async_wait<StagesPV - 1>();
pipeline_load.producer_commit(pipeline_commit_state);
++pipeline_commit_state;
--pipeline_offset;
}
};
// First tile: wait for its page table, then load Q together with the
// C-latent / K-rope operand.
pipeline_page_table.consumer_wait(pipeline_pt_consumer_state);
page_table_stage = pipeline_pt_consumer_state.index();
++pipeline_pt_consumer_state;
// each Q/K tile consists of rope and latent
for (int i = 0; i < IterationsQKLatent; i++) {
load_stage([&](int index) {
cute::copy(tiled_copy_q, tQgQL(_, _, _, _, _0{}, i, batch_coord), tQsQ(_, _, _, _, i));
copy_split(tiled_copy_kc, tKCgCL(_, _, _, _, k_index, i), tKCsKC(_, _, _, _, index));
});
}
for (int i = 0; i < IterationsQKRope; i++) {
load_stage([&](int index) {
cute::copy(tiled_copy_q, tQgQR(_, _, _, _, _0{}, i, batch_coord), tQsQ(_, _, _, _, IterationsQKLatent + i));
copy_split(tiled_copy_kc, tKCgKR(_, _, _, _, k_index, i), tKCsKC(_, _, _, _, index));
});
}
k_index += 1;
k_tile_count -= 1;
// assume k_tile_count >= 1
// perform K+Q load here
// Steady state: load the C-latent / K-rope operand of tile k_index, then the
// PV operand of the previous tile (k_index - 1).
CUTLASS_PRAGMA_NO_UNROLL
while (k_tile_count > 0) {
pipeline_page_table.consumer_wait(pipeline_pt_consumer_state);
page_table_stage = pipeline_pt_consumer_state.index();
++pipeline_pt_consumer_state;
for (int i = 0; i < IterationsQKLatent; i++) {
load_stage([&](int index) {
copy_split(tiled_copy_kc, tKCgCL(_, _, _, _, k_index, i), tKCsKC(_, _, _, _, index));
});
}
for (int i = 0; i < IterationsQKRope; i++) {
load_stage([&](int index) {
copy_split(tiled_copy_kc, tKCgKR(_, _, _, _, k_index, i), tKCsKC(_, _, _, _, index));
});
}
// Switch lookups back to the previous tile's page table for its PV loads.
page_table_stage = pipeline_pt_release_state.index();
for (int i = 0; i < IterationsPV_K; i++) {
for (int j = 0; j < IterationsPV_N; j++) {
load_stage([&](int index) {
copy_split(tiled_copy_vc, tVCgCLT(_, _, _, _, j, IterationsPV_K * (k_index - 1) + i), tVCsVC(_, _, _, _, index));
});
}
}
// The previous tile's page-table stage is no longer needed.
pipeline_page_table.consumer_release(pipeline_pt_release_state);
++pipeline_pt_release_state;
k_index += 1;
k_tile_count -= 1;
}
// Tail: PV operand of the last tile.
page_table_stage = pipeline_pt_release_state.index();
for (int i = 0; i < IterationsPV_K; i++) {
for (int j = 0; j < IterationsPV_N; j++) {
load_stage([&](int index) {
copy_split(tiled_copy_vc, tVCgCLT(_, _, _, _, j, IterationsPV_K * (k_index - 1) + i), tVCsVC(_, _, _, _, index));
});
}
}
pipeline_page_table.consumer_release(pipeline_pt_release_state);
++pipeline_pt_release_state;
// Drain: commit the stages that were acquired but not yet committed.
while (pipeline_offset > 0) {
cutlass::arch::cp_async_fence();
cutlass::arch::cp_async_wait<StagesPV - 1>();
pipeline_load.producer_commit(pipeline_commit_state);
++pipeline_commit_state;
--pipeline_offset;
}
// Wait for every outstanding cp.async group before returning.
cutlass::arch::cp_async_wait<0>();
}
template<bool kIsPaged = false, class BlkCoord>
CUTLASS_DEVICE void load_tma(
BlkCoord const& blk_coord,
ProblemShape const& problem_shape,
MainloopArguments const& mainloop_args,
MainloopParams const& mainloop_params,
TensorStorage& shared_tensors,
PipelineLoadQK& pipeline_load_qk,
typename PipelineLoadQK::PipelineState& pipeline_load_qk_producer_state,
PipelineLoadPV& pipeline_load_pv,
typename PipelineLoadPV::PipelineState& pipeline_load_pv_producer_state,
int const& split_kv) {
auto [H, K, D, B] = problem_shape;
auto [D_latent, D_rope] = D;
int k_tile_total = ceil_div(K, TileShapeS{});
int k_tile_per_cta = ceil_div(k_tile_total, split_kv);
int k_index = get<3>(blk_coord) * k_tile_per_cta; // lower limit
int k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index);
if (k_tile_count == 0) {
return;
}
using X = Underscore;
// partition all tensors
auto mQL = mainloop_params.tma_load_q_latent.get_tma_tensor(make_shape(H, D_latent, B));
auto mQR = mainloop_params.tma_load_q_rope.get_tma_tensor(make_shape(H, D_rope, B));
int paged_B = B;
int paged_K = K;
if constexpr (kIsPaged) {
paged_B = mainloop_args.page_count;
paged_K = mainloop_args.page_size;
}
auto mPT_l = make_tensor(make_gmem_ptr(mainloop_args.ptr_page_table), make_shape(paged_B, B), mainloop_args.stride_page_table);
auto mCL = mainloop_params.tma_load_c_latent.get_tma_tensor(make_shape(paged_K, D_latent, paged_B));
auto mKR = mainloop_params.tma_load_k_rope.get_tma_tensor(make_shape(paged_K, D_rope, paged_B));
auto mCLT = mainloop_params.tma_load_c_latent_transpose.get_tma_tensor(make_shape(D_latent, paged_K, paged_B));
auto gQL = local_tile(mQL, TileShapeQK{}, make_coord(_,_,_), Step<_1, X, _1>{});
auto gQR = local_tile(mQR, TileShapeQK{}, make_coord(_,_,_), Step<_1, X, _1>{});
auto gCL = local_tile(mCL, TileShapeQK{}, make_coord(_,_,_), Step<X, _1, _1>{});
auto gKR = local_tile(mKR, TileShapeQK{}, make_coord(_,_,_), Step<X, _1, _1>{});
auto gCLT = local_tile(mCLT, TileShapePV{}, make_coord(_,_,_), Step<X, _1, _1>{});
ThrMMA cta_mma_qk = TiledMmaQK{}.get_slice(get<0>(blk_coord) % size(AtomThrShapeMNK{}));
ThrMMA cta_mma_pv = TiledMmaPV{}.get_slice(get<0>(blk_coord) % size(AtomThrShapeMNK{}));
auto tSgQL = cta_mma_qk.partition_A(gQL);
auto tSgQR = cta_mma_qk.partition_A(gQR);
auto tSgCL = cta_mma_qk.partition_B(gCL);
auto tSgKR = cta_mma_qk.partition_B(gKR);
auto tOgCLT = cta_mma_pv.partition_B(gCLT);
Tensor sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{});
Tensor sKC = make_tensor(make_smem_ptr(shared_tensors.smem_kc.begin()), SmemLayoutKC{});
Tensor sVC = make_tensor(make_smem_ptr(shared_tensors.smem_vc.begin()), SmemLayoutVC{});
auto [tQLgQL_mkl, tQsQ] = tma_partition(
mainloop_params.tma_load_q_latent, _0{}, make_layout(_1{}),
group_modes<0,3>(sQ), group_modes<0,3>(tSgQL));
auto [tQRgQR_mkl, tQsQ_ignore] = tma_partition(
mainloop_params.tma_load_q_rope, _0{}, make_layout(_1{}),
group_modes<0,3>(sQ), group_modes<0,3>(tSgQR));
auto [tCLgCL_nkl, tKCsKC] = tma_partition(
mainloop_params.tma_load_c_latent, _0{}, make_layout(_1{}),
group_modes<0,3>(sKC), group_modes<0,3>(tSgCL));
auto [tKRgKR_nkl, tKCsKC_ignore] = tma_partition(
mainloop_params.tma_load_k_rope, _0{}, make_layout(_1{}),
group_modes<0,3>(sKC), group_modes<0,3>(tSgKR));
auto [tCLTgCLT_nkl, tVCsVC] = tma_partition(
mainloop_params.tma_load_c_latent_transpose, _0{}, make_layout(_1{}),
group_modes<0,3>(sVC), group_modes<0,3>(tOgCLT));
uint16_t mcast_mask = 0;
int batch_coord = get<2>(blk_coord);
Tensor tQLgQL = tQLgQL_mkl(_, _, _, batch_coord);
Tensor tQRgQR = tQRgQR_mkl(_, _, _, batch_coord);
auto mPT = mPT_l(_, batch_coord);
Tensor tCLgCL = tCLgCL_nkl(_, _, _, _);
Tensor tKRgKR = tKRgKR_nkl(_, _, _, _);
// careful: stage and k are swapped here!
Tensor tCLTgCLT = tCLTgCLT_nkl(_, _, _, _);
// latent is first in memory, so let's load it first always
// startup: alternate Q and K, set tx count appropriately, for k_idx = 0
// each Q/K tile consists of rope and latent
for (int i = 0; i < IterationsQKLatent; i++) {
pipeline_load_qk.producer_expect_transaction(pipeline_load_qk_producer_state, kTransactionsBytesLoadExtraQ);
pipeline_load_qk.producer_acquire(pipeline_load_qk_producer_state);
auto tma_barrier = pipeline_load_qk.producer_get_barrier(pipeline_load_qk_producer_state);
if (cute::elect_one_sync()) {
// expect the extra bytes
// load_qk ql
cute::copy(mainloop_params.tma_load_q_latent.with(*tma_barrier, mcast_mask), tQLgQL(_, _0{}, i), tQsQ(_, i));
// load_qk cl
if constexpr (kIsPaged) {
cute::copy(
mainloop_params.tma_load_c_latent.with(*tma_barrier, mcast_mask),
tCLgCL(_, _0{}, i, mPT(k_index)),
tKCsKC(_, pipeline_load_qk_producer_state.index())
);
}
else {
cute::copy(
mainloop_params.tma_load_c_latent.with(*tma_barrier, mcast_mask),
tCLgCL(_, k_index, i, batch_coord),
tKCsKC(_, pipeline_load_qk_producer_state.index()));
}
}
++pipeline_load_qk_producer_state;
}
for (int i = 0; i < IterationsQKRope; i++) {
pipeline_load_qk.producer_expect_transaction(pipeline_load_qk_producer_state, kTransactionsBytesLoadExtraQ);
pipeline_load_qk.producer_acquire(pipeline_load_qk_producer_state);
auto tma_barrier = pipeline_load_qk.producer_get_barrier(pipeline_load_qk_producer_state);
if (cute::elect_one_sync()) {
// expect the extra bytes
// load_qk ql
cute::copy(mainloop_params.tma_load_q_rope.with(*tma_barrier, mcast_mask), tQRgQR(_, _0{}, i), tQsQ(_, i + IterationsQKLatent));
// load_qk cl
if constexpr (kIsPaged) {
cute::copy(
mainloop_params.tma_load_k_rope.with(*tma_barrier, mcast_mask),
tKRgKR(_, _0{}, i, mPT(k_index)),
tKCsKC(_, pipeline_load_qk_producer_state.index())
);
}
else {
cute::copy(
mainloop_params.tma_load_k_rope.with(*tma_barrier, mcast_mask),
tKRgKR(_, k_index, i, batch_coord),
tKCsKC(_, pipeline_load_qk_producer_state.index()));
}
}
++pipeline_load_qk_producer_state;
}
k_index += 1;
k_tile_count -= 1;
// assume k_tile_count >= 1
// perform K+Q load here
CUTLASS_PRAGMA_NO_UNROLL
while (k_tile_count > 0) {
// perform K load
for (int i = 0; i < IterationsQKLatent; i++) {
pipeline_load_qk.producer_acquire(pipeline_load_qk_producer_state);
auto tma_barrier = pipeline_load_qk.producer_get_barrier(pipeline_load_qk_producer_state);
if (cute::elect_one_sync()) {
// load_qk cl
if constexpr (kIsPaged) {
cute::copy(
mainloop_params.tma_load_c_latent.with(*tma_barrier, mcast_mask),
tCLgCL(_, _0{}, i, mPT(k_index)),
tKCsKC(_, pipeline_load_qk_producer_state.index())
);
}
else {
cute::copy(
mainloop_params.tma_load_c_latent.with(*tma_barrier, mcast_mask),
tCLgCL(_, k_index, i, batch_coord),
tKCsKC(_, pipeline_load_qk_producer_state.index()));
}
}
++pipeline_load_qk_producer_state;
}
for (int i = 0; i < IterationsQKRope; i++) {
pipeline_load_qk.producer_acquire(pipeline_load_qk_producer_state);
auto tma_barrier = pipeline_load_qk.producer_get_barrier(pipeline_load_qk_producer_state);
if (cute::elect_one_sync()) {
// load_qk cl
if constexpr (kIsPaged) {
cute::copy(
mainloop_params.tma_load_k_rope.with(*tma_barrier, mcast_mask),
tKRgKR(_, _0{}, i, mPT(k_index)),
tKCsKC(_, pipeline_load_qk_producer_state.index())
);
}
else {
cute::copy(
mainloop_params.tma_load_k_rope.with(*tma_barrier, mcast_mask),
tKRgKR(_, k_index, i, batch_coord),
tKCsKC(_, pipeline_load_qk_producer_state.index()));
}
}
++pipeline_load_qk_producer_state;
}
// prefetch next K load to keep busy while we transpose-load from cache
const int kPrefetchDistance = 1;
for (int i = 0; i < IterationsQKLatent; i++) {
if (cute::elect_one_sync()) {
if constexpr (kIsPaged) {
if (k_tile_count > kPrefetchDistance) {
cute::prefetch(
mainloop_params.tma_load_c_latent,
tCLgCL(_, _0{}, i, mPT(k_index + kPrefetchDistance))
);
}
}
else {
cute::prefetch(
mainloop_params.tma_load_c_latent,
tCLgCL(_, k_index + kPrefetchDistance, i, batch_coord)
);
}
}
}
for (int i = 0; i < IterationsQKRope; i++) {
if (cute::elect_one_sync()) {
if constexpr (kIsPaged) {
if (k_tile_count > kPrefetchDistance) {
cute::prefetch(
mainloop_params.tma_load_k_rope,
tKRgKR(_, _0{}, i, mPT(k_index + kPrefetchDistance))
);
}
}
else {
cute::prefetch(
mainloop_params.tma_load_k_rope,
tKRgKR(_, k_index + kPrefetchDistance, i, batch_coord)
);
}
}
}
// perform V load (k_idx - 1)
for (int i = 0; i < IterationsPV_K; i++) {
for (int j = 0; j < IterationsPV_N; j++) {
pipeline_load_pv.producer_acquire(pipeline_load_pv_producer_state);
auto tma_barrier = pipeline_load_pv.producer_get_barrier(pipeline_load_pv_producer_state);
if (cute::elect_one_sync()) {
// load_pv cl
// note the transpose in indices!
// note we are off-by-one on k_index
if constexpr (kIsPaged) {
cute::copy(
mainloop_params.tma_load_c_latent_transpose.with(*tma_barrier, mcast_mask, cute::TMA::CacheHintSm100::EVICT_FIRST),
tCLTgCLT(_, j, i, mPT(k_index - 1)),
tVCsVC(_, pipeline_load_pv_producer_state.index())
);
}
else {
cute::copy(
mainloop_params.tma_load_c_latent_transpose.with(*tma_barrier, mcast_mask, cute::TMA::CacheHintSm100::EVICT_FIRST),
tCLTgCLT(_, j, IterationsPV_K * (k_index - 1) + i, batch_coord),
tVCsVC(_, pipeline_load_pv_producer_state.index())
);
}
}
++pipeline_load_pv_producer_state;
}
}
k_index += 1;
k_tile_count -= 1;
}
for (int i = 0; i < IterationsPV_K; i++) {
for (int j = 0; j < IterationsPV_N; j++) {
pipeline_load_pv.producer_acquire(pipeline_load_pv_producer_state);
auto tma_barrier = pipeline_load_pv.producer_get_barrier(pipeline_load_pv_producer_state);
if (cute::elect_one_sync()) {
// load_pv cl
// note the transpose in indices
// note we are off-by-one on k_index
if constexpr (kIsPaged) {
cute::copy(
mainloop_params.tma_load_c_latent_transpose.with(*tma_barrier, mcast_mask, cute::TMA::CacheHintSm100::EVICT_FIRST),
tCLTgCLT(_, j, i, mPT(k_index - 1)),
tVCsVC(_, pipeline_load_pv_producer_state.index())
);
}
else {
cute::copy(
mainloop_params.tma_load_c_latent_transpose.with(*tma_barrier, mcast_mask, cute::TMA::CacheHintSm100::EVICT_FIRST),
tCLTgCLT(_, j, IterationsPV_K * (k_index - 1) + i, batch_coord),
tVCsVC(_, pipeline_load_pv_producer_state.index())
);
}
}
++pipeline_load_pv_producer_state;
}
}
}
// Issues the QK GEMM (S = Q * KC) and the PV GEMM (O += P * VC) for every KV
// tile assigned to this CTA's split.
//
// Pipeline roles as used in this function:
//   pipeline_load_qk : consumer - waits for a loaded KC stage, releases it after the QK GEMM.
//   pipeline_load_pv : consumer - waits for a loaded VC stage, releases it after the PV GEMM.
//   pipeline_mma_s   : producer - acquires an S buffer, commits it once the QK GEMM is issued.
//   pipeline_p_mma   : consumer - waits for P, releases it after the PV GEMM.
//   pipeline_mma_o   : producer - acquires O, commits it once the PV GEMM is issued.
//
// Mode 3 of blk_coord is used as the split index and split_kv as the number of
// splits. All pipeline state arguments are taken by reference and advanced.
// Returns immediately, touching no pipeline, when this split has no KV tiles.
template<class BlkCoord>
CUTLASS_DEVICE void mma(
BlkCoord const& blk_coord,
ProblemShape const& problem_shape,
TensorStorage& shared_tensors,
PipelineLoadQK& pipeline_load_qk,
typename PipelineLoadQK::PipelineState& pipeline_load_qk_consumer_state,
PipelineLoadPV& pipeline_load_pv,
typename PipelineLoadPV::PipelineState& pipeline_load_pv_consumer_state,
PipelineS& pipeline_mma_s,
typename PipelineS::PipelineState& pipeline_mma_s_producer_state,
PipelineP& pipeline_p_mma,
typename PipelineP::PipelineState& pipeline_p_mma_consumer_state,
PipelineO& pipeline_mma_o,
typename PipelineO::PipelineState& pipeline_mma_o_producer_state,
int const& split_kv) {
auto [H, K, D, B] = problem_shape;
// Total number of TileShapeS-sized tiles needed to cover K, rounded up.
int k_tile_total = ceil_div(K, TileShapeS{});
// Tiles handled per split, rounded up (trailing splits may get fewer or none).
int k_tile_per_cta = ceil_div(k_tile_total, split_kv);
int k_index = get<3>(blk_coord) * k_tile_per_cta; // lower limit
// Number of tiles this split actually owns, clamped to the total and to >= 0.
int k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index);
if (k_tile_count == 0) {
return;
}
// mma init
// Shared-memory views of the operands, followed by their MMA fragments.
Tensor sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{});
Tensor sKC = make_tensor(make_smem_ptr(shared_tensors.smem_kc.begin()), SmemLayoutKC{});
Tensor sVC = make_tensor(make_smem_ptr(shared_tensors.smem_vc.begin()), SmemLayoutVC{});
Tensor sP = make_tensor(make_smem_ptr((Element*) shared_tensors.smem_p.begin()), SmemLayoutP{});
Tensor tSrQ = TiledMmaQK::make_fragment_A(sQ);
Tensor tSrKC = TiledMmaQK::make_fragment_B(sKC);
Tensor tOrP = TiledMmaPV::make_fragment_A(sP);
Tensor tOrVC = TiledMmaPV::make_fragment_B(sVC);
TiledMmaQK tiled_mma_qk;
TiledMmaPV tiled_mma_pv;
// Accumulator fragments; their data() address is re-pointed before each use.
Tensor tStS = partition_fragment_C(tiled_mma_qk, select<0,1>(TileShapeQK{}));
Tensor tItI = partition_fragment_C(tiled_mma_pv, select<0,1>(TileShapePV{}));
// The first PV GEMM must overwrite O instead of accumulating into it.
tiled_mma_pv.accumulate_ = UMMA::ScaleOut::Zero;
pipeline_mma_s.producer_acquire(pipeline_mma_s_producer_state);
// Mma S0 S1 O0 S2 O1 ... Sn On-1 On
// S0 ownership -- ----- -- --
// S1 ownership -- ----- ----
// O ownership -- -- ---- --
// Prologue: QK GEMM for the first tile only (no P is available yet for a PV GEMM).
tiled_mma_qk.accumulate_ = UMMA::ScaleOut::Zero;
for (int i = 0; i < IterationsQK; i++) {
pipeline_load_qk.consumer_wait(pipeline_load_qk_consumer_state);
int read_stage = pipeline_load_qk_consumer_state.index();
// Select the S buffer from the producer stage index: kS0 for index 0, kS1 otherwise.
tStS.data() = uint32_t(pipeline_mma_s_producer_state.index() == 0 ? TmemAllocation::kS0 : TmemAllocation::kS1);
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tSrQ); ++k_block) {
cute::gemm(tiled_mma_qk,
tSrQ(_,_,k_block,i),
tSrKC(_,_,k_block,read_stage),
tStS);
// After the first issued MMA, all further ones accumulate into S.
tiled_mma_qk.accumulate_ = UMMA::ScaleOut::One;
}
pipeline_load_qk.consumer_release(pipeline_load_qk_consumer_state);
++pipeline_load_qk_consumer_state;
}
pipeline_mma_s.producer_commit(pipeline_mma_s_producer_state);
++pipeline_mma_s_producer_state;
k_tile_count -= 1;
// Steady state: each iteration issues the QK GEMM for the next tile and then
// the PV GEMM for the previous tile (the "S(n) O(n-1)" pattern in the diagram above).
CUTLASS_PRAGMA_NO_UNROLL
while (k_tile_count > 0) {
pipeline_mma_s.producer_acquire(pipeline_mma_s_producer_state);
// Start a fresh S accumulation for this tile.
tiled_mma_qk.accumulate_ = UMMA::ScaleOut::Zero;
for (int i = 0; i < IterationsQK; i++) {
pipeline_load_qk.consumer_wait(pipeline_load_qk_consumer_state);
int read_stage = pipeline_load_qk_consumer_state.index();
tStS.data() = uint32_t(pipeline_mma_s_producer_state.index() == 0 ? TmemAllocation::kS0 : TmemAllocation::kS1);
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tSrQ); ++k_block) {
cute::gemm(tiled_mma_qk,
tSrQ(_,_,k_block,i),
tSrKC(_,_,k_block,read_stage),
tStS);
tiled_mma_qk.accumulate_ = UMMA::ScaleOut::One;
}
pipeline_load_qk.consumer_release(pipeline_load_qk_consumer_state);
++pipeline_load_qk_consumer_state;
}
pipeline_mma_s.producer_commit(pipeline_mma_s_producer_state);
++pipeline_mma_s_producer_state;
// PV GEMM for the previous tile: needs ownership of O and a ready P.
pipeline_mma_o.producer_acquire(pipeline_mma_o_producer_state);
pipeline_p_mma.consumer_wait(pipeline_p_mma_consumer_state);
for (int i = 0; i < IterationsPV_K; i++) {
// Snapshot the accumulate mode so that each of the IterationsPV_N accumulators
// starts this i-step with the same mode; the k_block loop below flips it to One.
auto acc_flag = tiled_mma_pv.accumulate_;
for (int j = 0; j < IterationsPV_N; j++) {
pipeline_load_pv.consumer_wait(pipeline_load_pv_consumer_state);
int read_stage = pipeline_load_pv_consumer_state.index();
// Point at the j-th O accumulator: base kO0 plus j strides of kSizeAccO.
tItI.data() = uint32_t(TmemAllocation::kO0) + j * uint32_t(TmemAllocation::kSizeAccO);
tiled_mma_pv.accumulate_ = acc_flag;
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tOrP); ++k_block) {
cute::gemm(tiled_mma_pv,
tOrP(_,_,k_block, make_coord(i, pipeline_p_mma_consumer_state.index())),
tOrVC(_,_,k_block,read_stage),
tItI);
tiled_mma_pv.accumulate_ = UMMA::ScaleOut::One;
}
pipeline_load_pv.consumer_release(pipeline_load_pv_consumer_state);
++pipeline_load_pv_consumer_state;
}
}
pipeline_p_mma.consumer_release(pipeline_p_mma_consumer_state);
++pipeline_p_mma_consumer_state;
pipeline_mma_o.producer_commit(pipeline_mma_o_producer_state);
++pipeline_mma_o_producer_state;
--k_tile_count;
}
// Drain: PV GEMM for the last tile (same body as the PV part of the loop above).
pipeline_mma_o.producer_acquire(pipeline_mma_o_producer_state);
pipeline_p_mma.consumer_wait(pipeline_p_mma_consumer_state);
for (int i = 0; i < IterationsPV_K; i++) {
auto acc_flag = tiled_mma_pv.accumulate_;
for (int j = 0; j < IterationsPV_N; j++) {
pipeline_load_pv.consumer_wait(pipeline_load_pv_consumer_state);
int read_stage = pipeline_load_pv_consumer_state.index();
tItI.data() = uint32_t(TmemAllocation::kO0) + j * uint32_t(TmemAllocation::kSizeAccO);
tiled_mma_pv.accumulate_ = acc_flag;
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tOrP); ++k_block) {
cute::gemm(tiled_mma_pv,
tOrP(_,_,k_block, make_coord(i, pipeline_p_mma_consumer_state.index())),
tOrVC(_,_,k_block,read_stage),
tItI);
tiled_mma_pv.accumulate_ = UMMA::ScaleOut::One;
}
pipeline_load_pv.consumer_release(pipeline_load_pv_consumer_state);
++pipeline_load_pv_consumer_state;
}
}
pipeline_p_mma.consumer_release(pipeline_p_mma_consumer_state);
++pipeline_p_mma_consumer_state;
pipeline_mma_o.producer_commit(pipeline_mma_o_producer_state);
++pipeline_mma_o_producer_state;
}
// One online-softmax step over a single S tile.
//
// Steps, in order:
//   1. Load the S accumulator addressed by tmem_s into a thread-local fragment.
//   2. If is_last_tile is true, set entries whose column index (tile-local
//      column + TileShapeS * k_index) is >= problem_shape mode 1 to -infinity.
//   3. Fold the tile into a new running maximum (row_max_new); when
//      kWarpsInN > 1 the maximum is also combined with a peer thread's value
//      exchanged through shared_tensors.smem_exchange.
//   4. Unless B2B is defined: set correction_factor to
//      exp2(softmax_scale * log2(e) * (old row_max - new row_max)), update
//      row_max, and replace each entry x by
//      exp2(softmax_scale * log2(e) * x - row_max * softmax_scale * log2(e)).
//   5. Convert the fragment to Element and store it into stage smem_p_index of
//      shared_tensors.smem_p.
//   6. row_sum = row_sum * correction_factor + (sum of this thread's entries).
//
// row_max, row_sum and correction_factor are in/out running state owned by the
// caller. When B2B is defined, steps 4's updates are skipped, so row_max and
// correction_factor keep the values passed in.
template<class IsLastTile>
CUTLASS_DEVICE void softmax(
IsLastTile const& is_last_tile,
ElementAcc& row_max,
ElementAcc& row_sum,
ElementAcc& correction_factor,
ProblemShape const& problem_shape,
MainloopArguments const& mainloop_args,
TensorStorage& shared_tensors,
int k_index,
uint32_t tmem_s,
int smem_p_index) {
auto load_op = cute::SM100_TMEM_LOAD_32dp32b32x{};
TiledMmaQK tiled_mma_qk;
Tensor tStS = partition_fragment_C(tiled_mma_qk, select<0,1>(TileShapeQK{}));
// Point the accumulator fragment at the S buffer chosen by the caller.
tStS.data() = tmem_s;
CUTE_STATIC_ASSERT_V(shape<1>(tStS) == _1{});
CUTE_STATIC_ASSERT_V(shape<2>(tStS) == _1{});
Tensor tAcc = tStS(make_coord(_,_),_0{},_0{});
// Identity (coordinate) tensor: lets each thread recover the coordinates of its elements.
Tensor cS = make_identity_tensor(take<0,2>(CtaShapeQK{}));
auto tiled_t2r = make_tmem_copy(load_op, tAcc);
auto thread_idx = threadIdx.x % size(tiled_t2r);
auto thread_t2r = tiled_t2r.get_slice(thread_idx);
Tensor tTR_cS = thread_t2r.partition_D(cS);
// Thread-local accumulator values (ElementAcc) and their converted (Element) copy.
Tensor tTR_rAcc = make_tensor<ElementAcc>(shape(tTR_cS));
Tensor tTR_rS_frag = make_tensor<Element>(shape(tTR_rAcc));
// The conversion below operates on groups of AlignmentS elements.
const int AlignmentS = 4;
Tensor tTR_tAcc = thread_t2r.partition_S(tAcc);
Tensor tTR_rAcc_vec = recast<Array<ElementAcc, AlignmentS>>(tTR_rAcc);
Tensor tTR_rS_vec = recast<Array<Element, AlignmentS>>(tTR_rS_frag);
// load s
copy(tiled_t2r, tTR_tAcc, tTR_rAcc);
// Mask out-of-range columns; only the last tile can extend past the sequence end.
if (is_last_tile) {
for (int i = 0; i < size(tTR_rAcc); i++) {
if (get<1>(tTR_cS(i)) + TileShapeS{} * k_index >= get<1>(problem_shape)) {
tTR_rAcc(i) = -std::numeric_limits<ElementAcc>::infinity();
}
}
}
// max
ElementAcc row_max_new = row_max;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(tTR_rAcc); i += 1) {
row_max_new = ::fmax(row_max_new, tTR_rAcc(i));
}
// for 2x2 dp, reduce here
if constexpr (kWarpsInN > 1) {
// Publish this thread's partial max, synchronize on the exchange barrier, then
// combine with the value written by the thread 64 positions away (modulo 128).
shared_tensors.smem_exchange[threadIdx.x] = row_max_new;
cutlass::arch::NamedBarrier(kNumComputeWarps*NumThreadsPerWarp, kNamedBarrierExchange).sync();
// (64, 2) shape
int peer_index = (threadIdx.x + 64) % 128;
row_max_new = cutlass::max(row_max_new, shared_tensors.smem_exchange[peer_index]);
}
#ifndef B2B
// find correction factor
// The scale is pre-multiplied by log2(e) so that exp2f can be used in place of exp.
ElementAcc softmax_scale_log2 = mainloop_args.softmax_scale * static_cast<ElementAcc>(M_LOG2E);
correction_factor = ::exp2f(softmax_scale_log2 * (row_max - row_max_new));
row_max = row_max_new;
// softmax
ElementAcc row_max_scale_log2 = row_max * softmax_scale_log2;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(tTR_rAcc); i++) {
tTR_rAcc(i) = ::exp2f(softmax_scale_log2 * tTR_rAcc(i) - row_max_scale_log2);
}
#endif
// quantize
// Convert ElementAcc -> Element, AlignmentS values at a time.
cutlass::NumericArrayConverter<Element, ElementAcc, AlignmentS> epilogue_op;
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(tTR_rAcc_vec); i++) {
tTR_rS_vec(i) = epilogue_op(tTR_rAcc_vec(i));
}
// Slice of smem_p belonging to pipeline stage smem_p_index.
Tensor sP = make_tensor(make_smem_ptr((Element*) shared_tensors.smem_p.begin()), SmemLayoutP{})(_, _, _, make_coord(_, smem_p_index));
// NOTE(review): tOcP is not referenced again in this function.
Tensor tOcP = TiledMmaPV{}.get_slice(_0{}).partition_A(cS);
// have a mapping for each thread to coord
// find identical mapping to coords for the MMA
auto l = make_ordered_layout(make_shape(make_shape(_64{}, _2{}), make_shape(_16{}, TileShapeS{} / _32{})), make_stride(make_stride(_0{}, _3{}), make_stride(_1{}, _2{})));
auto sP_ = as_position_independent_swizzle_tensor(sP);
// Store this thread's converted values to its slice of P (indexed by threadIdx.x through l).
copy_aligned(tTR_rS_frag, sP_.compose(l)(threadIdx.x, _));
// sum
// Rescale the running sum to the new maximum before adding this tile.
row_sum *= correction_factor;
static_assert(cute::is_same_v<ElementAcc, float>);
// Sum as float2 pairs using four independent partial sums.
auto tTR_rAcc_float2 = recast<float2>(tTR_rAcc);
auto sums = make_tensor<float2>(_4{});
static_assert(size(tTR_rAcc_float2) % size(sums) == 0);
// Seed the partial sums with the first size(sums) pairs.
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(sums); i++) {
sums(i) = tTR_rAcc_float2(i);
}
// Accumulate the remaining pairs, size(sums) at a time.
CUTLASS_PRAGMA_UNROLL
for (int i = size(sums); i < size(tTR_rAcc_float2); i += size(sums)) {
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < size(sums); j++) {
cute::add(sums(j), sums(j), tTR_rAcc_float2(i + j));
}
}
// Tree-reduce the partial sums into sums(0).
CUTLASS_PRAGMA_UNROLL
for (int i = 1; i < size(sums); i *= 2) {
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < size(sums); j += 2*i) {
cute::add(sums(j), sums(j), sums(j+i));
}
}
// Fold the two float2 lanes into the scalar running sum.
row_sum += sums(0).x + sums(0).y;
}
// Multiplies the O accumulator addressed by tmem_o by correction_factor in
// place: load the accumulator into a thread-local fragment, scale every
// element, store it back to the same location.
// The whole body is compiled out when B2B is defined, making this a no-op.
CUTLASS_DEVICE void rescale(
ElementAcc correction_factor,
uint32_t tmem_o) {
// for b2b gemm, do nothing
#ifndef B2B
auto load_op = cute::SM100_TMEM_LOAD_32dp32b32x{};
// Store operation derived from the load operation, so both use the same partitioning.
auto store_op = TMEM::tmem_load_to_store(load_op);
TiledMmaPV tiled_mma_pv;
Tensor tItI = partition_fragment_C(tiled_mma_pv, select<0,1>(TileShapePV{}));
// Point the accumulator fragment at the O buffer chosen by the caller.
tItI.data() = tmem_o;
CUTE_STATIC_ASSERT_V(shape<1>(tItI) == _1{});
CUTE_STATIC_ASSERT_V(shape<2>(tItI) == _1{});
Tensor tAcc = tItI(make_coord(_,_),_0{},_0{});
auto cta_tiler_pv = take<0,2>(typename CollectiveMmaPV::CtaShape_MNK{});
// Placeholder tensor over a null pointer: it is never dereferenced here, only its
// partitioned shape is used to size the thread-local fragment below.
Tensor gO = make_tensor(make_gmem_ptr((ElementAcc*) nullptr), cta_tiler_pv, make_stride(0, 0));
auto tiled_t2r = make_tmem_copy(load_op, tAcc);
auto tiled_r2t = make_tmem_copy(store_op, tAcc);
auto thread_idx = threadIdx.x % size(tiled_t2r);
auto thread_t2r = tiled_t2r.get_slice(thread_idx);
// NOTE(review): thread_r2t is not referenced again in this function.
auto thread_r2t = tiled_r2t.get_slice(thread_idx);
Tensor tTR_gO = thread_t2r.partition_D(gO);
Tensor tTR_rAcc = make_tensor<ElementAcc>(shape(tTR_gO));
Tensor tTR_tAcc = thread_t2r.partition_S(tAcc);
// load o
copy(tiled_t2r, tTR_tAcc, tTR_rAcc);
// multiply by correction factor
float2 correction_factor_vec = make_float2(correction_factor, correction_factor);
// Elements are processed two at a time as float2.
// NOTE(review): assumes size(tTR_rAcc) is even - TODO confirm.
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(tTR_rAcc); i += 2) {
float2 in = make_float2(tTR_rAcc(i + 0), tTR_rAcc(i + 1));
float2 out;
cute::mul(out, in, correction_factor_vec);
tTR_rAcc(i + 0) = out.x;
tTR_rAcc(i + 1) = out.y;
}
// store o
copy(tiled_r2t, tTR_rAcc, tTR_tAcc);
#endif
}
// Writes one O accumulator tile (addressed by tmem_o) to global memory, scaled
// by epilogue_args.output_scale / row_sum, and stores the log-sum-exp (LSE).
//
// Two output paths, selected at run time:
//   * epilogue_args.ptr_o_acc != nullptr: output is written as ElementAcc to
//     ptr_o_acc, offset by (mode 3 of cta_coord) * D_latent, and the LSE is
//     always written to ptr_lse_acc, offset by H * (mode 3 of cta_coord).
//   * otherwise: output is converted to ElementOut and written to ptr_o; the
//     LSE is written to ptr_lse only if that pointer is non-null.
// LSE is computed as log(row_sum) + softmax_scale * row_max. All LSE code is
// compiled out when B2B is defined.
//
// shared_tensors and split_kv are accepted but not used in this function body.
template<class BlkCoord>
CUTLASS_DEVICE void epilogue(
ElementAcc& row_max,
ElementAcc& row_sum,
BlkCoord const& cta_coord,
ProblemShape const& problem_shape,
MainloopArguments const& mainloop_args,
EpilogueParams const& epilogue_args,
TensorStorage& shared_tensors,
uint32_t tmem_o,
int const& split_kv) {
auto load_op = cute::SM100_TMEM_LOAD_32dp32b32x{};
TiledMmaPV tiled_mma_pv;
Tensor tItI = TiledMmaPV::make_fragment_C(partition_shape_C(TiledMmaPV{}, take<0, 2>(TileShapePV{})));
// Point the accumulator fragment at the O buffer chosen by the caller.
tItI.data() = tmem_o;
CUTE_STATIC_ASSERT_V(shape<1>(tItI) == _1{});
CUTE_STATIC_ASSERT_V(shape<2>(tItI) == _1{});
Tensor tAcc = tItI(make_coord(_,_),_0{},_0{});
auto [H, K, D, B] = problem_shape;
auto [D_latent, D_rope] = D;
if (epilogue_args.ptr_o_acc != nullptr) {
// Path 1: write ElementAcc results into the ptr_o_acc buffer.
using ElementOutAcc = ElementAcc;
// Number of ElementOutAcc values in one 128-bit vector.
constexpr auto AlignmentOutAcc = 128 / cute::sizeof_bits_v<ElementOutAcc>;
Tensor mO = make_tensor(make_gmem_ptr(epilogue_args.ptr_o_acc + get<3>(cta_coord) * D_latent), make_shape(H, D_latent, B), epilogue_args.stride_o_acc);
auto cta_tiler_pv = take<0,2>(typename CollectiveMmaPV::CtaShape_MNK{});
Tensor gO = local_tile(mO, cta_tiler_pv, take<0,3>(cta_coord));
auto tiled_t2r = make_tmem_copy(load_op, tAcc);
auto thread_idx = threadIdx.x % size(tiled_t2r);
auto thread_t2r = tiled_t2r.get_slice(thread_idx);
Tensor tTR_gO = thread_t2r.partition_D(gO);
Tensor tTR_rAcc = make_tensor<ElementAcc>(shape(tTR_gO));
Tensor tTR_rO_frag = make_tensor<ElementOutAcc>(shape(tTR_rAcc));
// Vector views of the source fragment and the global destination for the final copy.
Tensor tTR_rO_src = recast<Array<ElementOutAcc, AlignmentOutAcc>>(coalesce(tTR_rO_frag));
Tensor tR2G_rO_dst = recast<Array<ElementOutAcc, AlignmentOutAcc>>(coalesce(tTR_gO));
Tensor tTR_tAcc = thread_t2r.partition_S(tAcc);
// Load the accumulator into the thread-local fragment.
copy(tiled_t2r, tTR_tAcc, tTR_rAcc);
// Alpha-only scaling: each element is multiplied by output_scale / row_sum.
cutlass::epilogue::thread::LinearCombination<ElementOutAcc, 1, ElementAcc, ElementAcc, cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling> epilogue_op({epilogue_args.output_scale / row_sum});
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(tTR_rAcc); i++) {
tTR_rO_frag(i) = epilogue_op(tTR_rAcc(i));
}
// Write the scaled values to global memory.
copy(tTR_rO_src, tR2G_rO_dst);
#ifndef B2B
// compute LSE
ElementAcc lse = cutlass::fast_log(row_sum) + mainloop_args.softmax_scale * row_max;
// store LSE
Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_args.ptr_lse_acc + H * get<3>(cta_coord)), make_shape(H, B), epilogue_args.stride_lse_acc);
Tensor gLSE = local_tile(mLSE, append<3>(cta_tiler_pv, _1{}), take<0,3>(cta_coord), Step<_1, Underscore, _1>{});
// for 2x2 dp, this must be conditional and the index is wrong
// When kIs2Sm is set, only threads with threadIdx.x < 64 store an LSE value.
if (! kIs2Sm || (threadIdx.x < 64))
{
gLSE(threadIdx.x) = lse;
}
#endif
}
else {
// Path 2: convert to ElementOut and write to ptr_o.
Tensor mO = make_tensor(make_gmem_ptr(epilogue_args.ptr_o), make_shape(H, D_latent, B), epilogue_args.stride_o);
auto cta_tiler_pv = take<0,2>(typename CollectiveMmaPV::CtaShape_MNK{});
Tensor gO = local_tile(mO, cta_tiler_pv, take<0,3>(cta_coord));
auto tiled_t2r = make_tmem_copy(load_op, tAcc);
auto thread_idx = threadIdx.x % size(tiled_t2r);
auto thread_t2r = tiled_t2r.get_slice(thread_idx);
Tensor tTR_gO = thread_t2r.partition_D(gO);
Tensor tTR_rAcc = make_tensor<ElementAcc>(shape(tTR_gO));
Tensor tTR_rO_frag = make_tensor<ElementOut>(shape(tTR_rAcc));
Tensor tTR_rO_src = recast<Array<ElementOut, AlignmentOut>>(coalesce(tTR_rO_frag));
Tensor tR2G_rO_dst = recast<Array<ElementOut, AlignmentOut>>(coalesce(tTR_gO));
Tensor tTR_tAcc = thread_t2r.partition_S(tAcc);
// Load the accumulator into the thread-local fragment.
copy(tiled_t2r, tTR_tAcc, tTR_rAcc);
// Alpha-only scaling by output_scale / row_sum, with conversion to ElementOut.
cutlass::epilogue::thread::LinearCombination<ElementOut, 1, ElementAcc, ElementAcc, cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling> epilogue_op({epilogue_args.output_scale / row_sum});
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(tTR_rAcc); i++) {
tTR_rO_frag(i) = epilogue_op(tTR_rAcc(i));
}
// Write the converted values to global memory.
copy(tTR_rO_src, tR2G_rO_dst);
#ifndef B2B
// The LSE output is optional on this path.
if (epilogue_args.ptr_lse != nullptr) {
// compute LSE
ElementAcc lse = cutlass::fast_log(row_sum) + mainloop_args.softmax_scale * row_max;
// store LSE
Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_args.ptr_lse), make_shape(H, B), epilogue_args.stride_lse);
Tensor gLSE = local_tile(mLSE, append<3>(cta_tiler_pv, _1{}), take<0,3>(cta_coord), Step<_1, Underscore, _1>{});
// for 2x2 dp, this must be conditional and the index is wrong
// When kIs2Sm is set, only threads with threadIdx.x < 64 store an LSE value.
if (! kIs2Sm || (threadIdx.x < 64))
{
gLSE(threadIdx.x) = lse;
}
}
#endif
}
}
// Softmax / rescale / epilogue driver for the KV tiles assigned to this CTA's split.
//
// Pipeline roles as used in this function:
//   pipeline_mma_s : consumer - waits for an S tile, releases it after softmax.
//   pipeline_p_mma : producer - acquires a P stage, commits it after softmax wrote P.
//   pipeline_mma_o : consumer - waits for O, rescales it (or runs the epilogue), releases it.
//
// Flow: softmax on the first tile; then for every further tile, softmax followed
// by a rescale of all IterationsPV_N O accumulators with the new
// correction_factor; finally the row_sum reduction and the epilogue for each
// O accumulator.
//
// On every exit path the function arrives once on kNamedBarrierEpilogue, whose
// participant count covers the compute and load warps. Mode 3 of cta_coord is
// used as the split index and split_kv as the number of splits.
template<class CtaCoord>
CUTLASS_DEVICE void compute(
CtaCoord const& cta_coord,
ProblemShape const& problem_shape,
MainloopArguments const& mainloop_args,
EpilogueParams const& epilogue_args,
TensorStorage& shared_tensors,
PipelineS& pipeline_mma_s,
typename PipelineS::PipelineState& pipeline_mma_s_consumer_state,
PipelineP& pipeline_p_mma,
typename PipelineP::PipelineState& pipeline_p_mma_producer_state,
PipelineO& pipeline_mma_o,
typename PipelineO::PipelineState& pipeline_mma_o_consumer_state,
int const& split_kv) {
auto [H, K, D, B] = problem_shape;
// Same tile partitioning arithmetic as in mma(): tiles in total, tiles per split,
// first tile of this split, and the clamped number of tiles this split owns.
int k_tile_total = ceil_div(K, TileShapeS{});
int k_tile_per_cta = ceil_div(k_tile_total, split_kv);
int k_index = get<3>(cta_coord) * k_tile_per_cta; // lower limit
int k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index);
if (k_tile_count == 0) {
// if we return early, we have to make sure we release the load warp
cutlass::arch::NamedBarrier(
(kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp,
kNamedBarrierEpilogue
).arrive();
return;
}
// Index of the globally last tile; softmax applies its out-of-range masking only there.
int k_index_final = k_tile_total - 1;
// Running online-softmax state for this thread.
ElementAcc row_max = -std::numeric_limits<ElementAcc>::infinity();
ElementAcc row_sum = 0;
ElementAcc correction_factor = 1;
pipeline_p_mma.producer_acquire(pipeline_p_mma_producer_state);
pipeline_mma_s.consumer_wait(pipeline_mma_s_consumer_state);
// Turns a run-time bool into a compile-time true_type/false_type argument, so the
// callee can be instantiated separately for each case.
auto dispatch_bool = [](bool b, auto fn) {
if (b) {
fn(cute::true_type{});
}
else {
fn(cute::false_type{});
}
};
// softmax s0 -> p0
dispatch_bool(k_index == k_index_final, [&](auto is_last_tile) {
softmax(
is_last_tile,
row_max, row_sum, correction_factor,
problem_shape, mainloop_args, shared_tensors, k_index,
uint32_t(pipeline_mma_s_consumer_state.index() == 0 ? TmemAllocation::kS0 : TmemAllocation::kS1),
pipeline_p_mma_producer_state.index()
);
});
k_index += 1;
// Fences are issued before S is released and P is committed to the other side.
cutlass::arch::fence_view_async_tmem_load();
cutlass::arch::fence_view_async_shared();
pipeline_mma_s.consumer_release(pipeline_mma_s_consumer_state);
++pipeline_mma_s_consumer_state;
pipeline_p_mma.producer_commit(pipeline_p_mma_producer_state);
++pipeline_p_mma_producer_state;
k_tile_count -= 1;
// Remaining tiles: softmax on the next S tile, then rescale O by the
// correction factor that softmax just produced.
CUTLASS_PRAGMA_NO_UNROLL
while (k_tile_count > 0) {
pipeline_p_mma.producer_acquire(pipeline_p_mma_producer_state);
pipeline_mma_s.consumer_wait(pipeline_mma_s_consumer_state);
// softmax s1 -> p1
dispatch_bool(k_index == k_index_final, [&](auto is_last_tile) {
softmax(
is_last_tile,
row_max, row_sum, correction_factor,
problem_shape, mainloop_args, shared_tensors, k_index,
uint32_t(pipeline_mma_s_consumer_state.index() == 0 ? TmemAllocation::kS0 : TmemAllocation::kS1),
pipeline_p_mma_producer_state.index()
);
});
cutlass::arch::fence_view_async_tmem_load();
cutlass::arch::fence_view_async_shared();
pipeline_mma_s.consumer_release(pipeline_mma_s_consumer_state);
++pipeline_mma_s_consumer_state;
pipeline_p_mma.producer_commit(pipeline_p_mma_producer_state);
++pipeline_p_mma_producer_state;
// Wait until O is available to this stage before modifying it.
pipeline_mma_o.consumer_wait(pipeline_mma_o_consumer_state);
// rescale
// Apply the correction factor to each of the IterationsPV_N O accumulators.
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < IterationsPV_N; j++) {
rescale(correction_factor, uint32_t(TmemAllocation::kO0) + j * uint32_t(TmemAllocation::kSizeAccO));
}
cutlass::arch::fence_view_async_tmem_store();
pipeline_mma_o.consumer_release(pipeline_mma_o_consumer_state);
++pipeline_mma_o_consumer_state;
--k_tile_count;
k_index += 1;
}
// Wait for the final O before normalizing and writing it out.
pipeline_mma_o.consumer_wait(pipeline_mma_o_consumer_state);
#ifdef B2B
// With B2B defined the output is not normalized: dividing by row_sum == 1 is a no-op.
row_sum = 1;
#else
if constexpr (kWarpsInN > 1) {
// reduce row_sum if needed (for 2x2 dp)
// Same exchange scheme as the row_max reduction in softmax(): publish, sync, add the peer's value.
shared_tensors.smem_exchange[threadIdx.x] = row_sum;
cutlass::arch::NamedBarrier(kNumComputeWarps*NumThreadsPerWarp, kNamedBarrierExchange).sync();
// (64, 2) shape
int peer_index = (threadIdx.x + 64) % 128;
row_sum += shared_tensors.smem_exchange[peer_index];
}
#endif
// Counterpart of the arrive() on the early-return path above.
cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive();
// epilogue
// One epilogue call per O accumulator; mode 1 of the coordinate is replaced by j.
CUTLASS_PRAGMA_UNROLL
for (int j = 0; j < IterationsPV_N; j++) {
epilogue(
row_max, row_sum,
replace<1>(cta_coord, j), problem_shape,
mainloop_args, epilogue_args, shared_tensors,
uint32_t(TmemAllocation::kO0) + j * uint32_t(TmemAllocation::kSizeAccO), split_kv
);
}
cutlass::arch::fence_view_async_tmem_load();
pipeline_mma_o.consumer_release(pipeline_mma_o_consumer_state);
++pipeline_mma_o_consumer_state;
}
};
///////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::fmha::kernel
/***************************************************************************************************
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*
* Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929
* by Alcanderian JieXin Liang
*/
// clang-format off
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/kernel_hardware_info.h"
namespace cutlass::fmha::kernel {
////////////////////////////////////////////////////////////////////////////////
struct Sm100MlaIndividualTileScheduler {
struct Params {
dim3 grid;
};
bool valid_ = true;
CUTLASS_DEVICE
Sm100MlaIndividualTileScheduler(Params const&) {}
template<class ProblemShape, class ClusterShape>
static Params to_underlying_arguments(
ProblemShape const& problem_shape, KernelHardwareInfo hw_info,
ClusterShape const& cluster_shape, int const& split_kv) {
using namespace cute;
dim3 grid(get<0>(cluster_shape), get<3>(problem_shape) /* Batch */, split_kv /*Maximum Split KV*/);
return Params{ grid };
}
static dim3 get_grid_shape(Params const& params) {
return params.grid;
}
CUTLASS_DEVICE
bool is_valid() {
return valid_;
}
CUTLASS_DEVICE
auto get_block_coord() {
using namespace cute;
return make_coord(blockIdx.x, _0{}, blockIdx.y, blockIdx.z);
}
CUTLASS_DEVICE
Sm100MlaIndividualTileScheduler& operator++() {
valid_ = false;
return *this;
}
};
////////////////////////////////////////////////////////////////////////////////
// Persistent scheduler: launches at most sm_count CTAs in a 1-D grid; each CTA
// walks the linear work-item space with a stride of gridDim.x.
struct Sm100MlaPersistentTileScheduler {

  struct Params {
    int num_blocks;              // total number of work items
    FastDivmod divmod_m_block;   // divisor: cluster_shape mode 0
    FastDivmod divmod_b;         // divisor: problem_shape mode 3 (batch)
    FastDivmod divmod_split_kv;  // divisor: split_kv
    KernelHardwareInfo hw_info;  // carries the SM count used to size the grid
  };

  // Linear index of the work item currently assigned to this CTA.
  int block_idx = 0;
  Params params;

  // Starts at this CTA's own index within the grid.
  CUTLASS_DEVICE
  Sm100MlaPersistentTileScheduler(Params const& params_in)
      : block_idx(blockIdx.x), params(params_in) {}

  // Computes the total work-item count (cluster M blocks * batch * split_kv)
  // and the divisors used later to decode a linear index.
  template<class ProblemShape, class ClusterShape>
  static Params to_underlying_arguments(
      ProblemShape const& problem_shape, KernelHardwareInfo hw_info,
      ClusterShape const& cluster_shape, int const& split_kv) {
    using namespace cute;

    // Get SM count if needed, otherwise use user supplied SM count
    int sm_count = hw_info.sm_count;
    bool const sm_count_usable =
        (sm_count > 1) && (sm_count % size<0>(cluster_shape) == 0);
    if (!sm_count_usable) {
      CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
          " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
      sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
    }
    CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
    hw_info.sm_count = sm_count;

    int const num_m_blocks = size<0>(cluster_shape);
    int const batch_count = get<3>(problem_shape);  // Batch
    // split_kv is the maximum number of KV splits.
    int const num_blocks = num_m_blocks * batch_count * split_kv;

    Params result{num_blocks, {num_m_blocks}, {batch_count}, {split_kv}, hw_info};
    return result;
  }

  // 1-D grid with min(total work items, SM count) CTAs.
  static dim3 get_grid_shape(Params const& params) {
    return dim3(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1);
  }

  // True while this CTA's current index still names an existing work item.
  CUTLASS_DEVICE
  bool is_valid() {
    return block_idx < params.num_blocks;
  }

  // Decodes the linear index into (m_block, 0, batch, split) by peeling off one
  // remainder per divisor, fastest-varying first.
  CUTLASS_DEVICE
  auto get_block_coord() {
    using namespace cute;
    int remaining = block_idx;
    int m_idx;
    int batch_idx;
    int split_idx;
    // Each call stores the quotient back into `remaining` and the remainder
    // into the second argument.
    params.divmod_m_block(remaining, m_idx, remaining);
    params.divmod_b(remaining, batch_idx, remaining);
    params.divmod_split_kv(remaining, split_idx, remaining);
    return make_coord(m_idx, _0{}, batch_idx, split_idx);
  }

  // Moves to this CTA's next work item, one grid width further on.
  CUTLASS_DEVICE
  Sm100MlaPersistentTileScheduler& operator++() {
    block_idx += gridDim.x;
    return *this;
  }
};
////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::fmha::kernel
/*
Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
/*
* Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929
* by Alcanderian JieXin Liang
*/
#include "core/registration.h"
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cutlass/cutlass.h>
#include <cutlass/kernel_hardware_info.h>
#include <torch/all.h>
#include <cute/tensor.hpp>
#include <iostream>
#include "cutlass_sm100_mla/device/sm100_mla.hpp"
#include "cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp"
// clang-format off
#if !defined(CUDA_VERSION) || CUDA_VERSION < 12040
// Stub compiled when CUDA_VERSION is undefined or below 12040 (see the
// enclosing #if): keeps the symbol available but always fails via TORCH_CHECK.
// None of the arguments are read.
void sm100_cutlass_mla_decode(
torch::Tensor const& out,
torch::Tensor const& q_nope,
torch::Tensor const& q_pe,
torch::Tensor const& kv_c_and_k_pe_cache,
torch::Tensor const& seq_lens,
torch::Tensor const& page_table,
torch::Tensor const& workspace,
int64_t num_kv_splits) {
// Unconditional failure: the condition is the literal `false`.
TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_decode");
}
// Stub compiled when CUDA_VERSION is undefined or below 12040 (see the
// enclosing #if): always fails via TORCH_CHECK and never reaches a return
// statement. None of the arguments are read.
int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count, int64_t num_kv_splits) {
// Unconditional failure: the condition is the literal `false`.
TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_get_workspace_size");
}
#else
// Evaluates `status` exactly once and fails via TORCH_CHECK, with the CUTLASS
// status string as the message, unless it equals cutlass::Status::kSuccess.
// NOTE(review): the expansion is a bare `{ ... }` block rather than
// `do { ... } while (0)`, so `if (c) CUTLASS_CHECK(x); else ...` will not compile.
#define CUTLASS_CHECK(status) \
{ \
cutlass::Status error = status; \
TORCH_CHECK(error == cutlass::Status::kSuccess, cutlassGetStatusString(error)); \
}
using namespace cute;
using namespace cutlass::fmha::kernel;
// Compile-time boolean tag; `value` simply exposes the template argument.
// Used as the PersistenceOption parameter of MlaSm100.
template <bool v>
struct IsPersistent {
  static constexpr bool value = v;
};
template <typename T, bool IsPaged128, typename PersistenceOption = IsPersistent<true>>
struct MlaSm100 {
using Element = T;
using ElementAcc = float;
using ElementOut = T;
using TileShape = Shape<_128, _128, Shape<_512, _64>>;
using TileShapeH = cute::tuple_element_t<0, TileShape>;
using TileShapeD = cute::tuple_element_t<2, TileShape>;
// H K (D_latent D_rope) B
using ProblemShape = cute::tuple<TileShapeH, int, TileShapeD, int>;
using StrideQ = cute::tuple<int64_t, _1, int64_t>; // H D B
using StrideK = cute::tuple<int64_t, _1, int64_t>; // K D B
using StrideO = StrideK; // H D B
using StrideLSE = cute::tuple<_1, int>; // H B
using TileScheduler =
std::conditional_t<PersistenceOption::value, Sm100MlaPersistentTileScheduler, Sm100MlaIndividualTileScheduler>;
using FmhaKernel = cutlass::fmha::kernel::Sm100FmhaMlaKernelTmaWarpspecialized<
TileShape,
Element,
ElementAcc,
ElementOut,
ElementAcc,
TileScheduler,
/*kIsCpAsync=*/!IsPaged128>;
using Fmha = cutlass::fmha::device::MLA<FmhaKernel>;
};
// Builds the `T::Fmha::Arguments` structure for one MLA decode launch from
// the given tensors and scalars, then lets `T::Fmha::set_split_kv` finalize
// the split-KV setting before returning it.
//
// T is expected to be an MlaSm100<...> instantiation (it must provide Fmha,
// TileShapeH/TileShapeD, the Stride* aliases and the Element* aliases).
//
// Sizes are read from the tensors as follows:
//   batches            = q_nope.sizes()[0]
//   page_count_per_seq = page_table.sizes()[1]
//   page_count_total   = kv_c_and_k_pe_cache.sizes()[0]
//   page_size          = kv_c_and_k_pe_cache.sizes()[1]
// NOTE(review): no shape, dtype or contiguity validation happens here; the
// caller is presumably responsible for it - confirm.
template <typename T>
typename T::Fmha::Arguments args_from_options(
at::Tensor const& out,
at::Tensor const& q_nope,
at::Tensor const& q_pe,
at::Tensor const& kv_c_and_k_pe_cache,
at::Tensor const& seq_lens,
at::Tensor const& page_table,
double sm_scale,
int64_t num_kv_splits) {
// Hardware info: device index comes from q_nope, SM count is queried for it.
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = q_nope.device().index();
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
int batches = q_nope.sizes()[0];
int page_count_per_seq = page_table.sizes()[1];
int page_count_total = kv_c_and_k_pe_cache.sizes()[0];
int page_size = kv_c_and_k_pe_cache.sizes()[1];
// Upper bound on the sequence length addressable through one page-table row.
int max_seq_len = page_size * page_count_per_seq;
using TileShapeH = typename T::TileShapeH;
using TileShapeD = typename T::TileShapeD;
// Problem shape is (H, K, (D_latent, D_rope), B); H and D are the static
// tile extents, K is max_seq_len and B is the batch count.
auto problem_shape = cute::make_tuple(TileShapeH{}, max_seq_len, TileShapeD{}, batches);
auto [H, K, D, B] = problem_shape;
auto [D_latent, D_rope] = D;
float scale = float(sm_scale);
using StrideQ = typename T::StrideQ;
using StrideK = typename T::StrideK;
using StrideO = typename T::StrideO;
using StrideLSE = typename T::StrideLSE;
// Query strides are taken from the tensors: stride(1) first, unit stride in
// the middle, stride(0) last.
StrideQ stride_Q_nope = cute::make_tuple(
static_cast<int64_t>(q_nope.stride(1)), _1{}, static_cast<int64_t>(q_nope.stride(0)));
StrideQ stride_Q_pe = cute::make_tuple(
static_cast<int64_t>(q_pe.stride(1)), _1{}, static_cast<int64_t>(q_pe.stride(0)));
// Cache strides are computed from the static D extents and page_size rather
// than read from the tensor.
// NOTE(review): this presumes kv_c_and_k_pe_cache is densely packed with a
// last dimension of D_latent + D_rope - confirm against callers.
StrideK stride_C = cute::make_tuple(
static_cast<int64_t>(0 + D_latent + D_rope), _1{}, static_cast<int64_t>(page_size * (D_latent + D_rope)));
StrideLSE stride_PT = cute::make_stride(_1{}, page_count_per_seq);
StrideLSE stride_LSE = cute::make_tuple(_1{}, 0 + H);
// Output strides are likewise computed (D_latent per row, H * D_latent per
// batch entry), not read from `out`.
StrideO stride_O = cute::make_tuple(static_cast<int64_t>(0 + D_latent), _1{}, static_cast<int64_t>(0 + H * D_latent));
using Element = typename T::Element;
using ElementOut = typename T::ElementOut;
using ElementAcc = typename T::ElementAcc;
auto Q_nope_ptr = static_cast<Element*>(q_nope.data_ptr());
auto Q_pe_ptr = static_cast<Element*>(q_pe.data_ptr());
auto C_ptr = static_cast<Element*>(kv_c_and_k_pe_cache.data_ptr());
// The cache pointer is passed twice: once as-is and once offset by D_latent
// elements, both with stride_C.
// NOTE(review): the offset pointer presumably addresses the rope part of
// each cache row - confirm against the kernel's argument definition.
// NOTE(review): seq_lens and page_table are reinterpreted as int*, which
// assumes 32-bit integer tensors - confirm.
typename T::Fmha::Arguments arguments{
problem_shape,
{scale,
Q_nope_ptr,
stride_Q_nope,
Q_pe_ptr,
stride_Q_pe,
C_ptr,
stride_C,
C_ptr + D_latent,
stride_C,
static_cast<int*>(seq_lens.data_ptr()),
static_cast<int*>(page_table.data_ptr()),
stride_PT,
page_count_total,
page_size},
// Output pointer and stride; the ElementAcc pointer that accompanies
// stride_LSE is passed as nullptr.
{static_cast<ElementOut*>(out.data_ptr()), stride_O, static_cast<ElementAcc*>(nullptr), stride_LSE},
hw_info,
// TODO(trevor-m): Change split_kv back to -1 when
// https://github.com/NVIDIA/cutlass/issues/2274 is fixed. Split_kv=1 will
// perform worse with larger context length and smaller batch sizes.
num_kv_splits, // split_kv
nullptr, // is_var_split_kv
};
// TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute
// split_kv automatically based on batch size and sequence length to balance
// workload across available SMs. Consider using var_split_kv for manual
// control if needed.
T::Fmha::set_split_kv(arguments);
return arguments;
}
// Instantiates the MLA operator for the given element type / paging mode /
// persistence tag, builds its arguments from the tensors, and then checks,
// initializes and runs it on `stream`, using `workspace` as scratch memory.
// Any non-success CUTLASS status is raised through CUTLASS_CHECK.
template <typename Element, bool IsPaged128, typename PersistenceOption>
void runMla(
    at::Tensor const& out,
    at::Tensor const& q_nope,
    at::Tensor const& q_pe,
    at::Tensor const& kv_c_and_k_pe_cache,
    at::Tensor const& seq_lens,
    at::Tensor const& page_table,
    at::Tensor const& workspace,
    double sm_scale,
    int64_t num_kv_splits,
    cudaStream_t stream) {
  using Config = MlaSm100<Element, IsPaged128, PersistenceOption>;
  using FmhaOp = typename Config::Fmha;

  FmhaOp op;
  auto op_args = args_from_options<Config>(
      out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, sm_scale, num_kv_splits);

  // The same scratch buffer is handed to both initialize() and run().
  void* workspace_ptr = workspace.data_ptr();

  CUTLASS_CHECK(op.can_implement(op_args));
  CUTLASS_CHECK(op.initialize(op_args, workspace_ptr, stream));
  CUTLASS_CHECK(op.run(op_args, workspace_ptr, stream));
}
// Turns the runtime boolean `expr` into a compile-time constant.
// Expands to an immediately-invoked lambda returning bool; inside each branch
// a `constexpr bool` named `const_expr` is declared (true or false) and the
// callable passed as the variadic argument is invoked, so that callable can
// use `const_expr` as a template argument. The callable must return a value
// convertible to bool, which becomes the value of the whole expression.
#define DISPATCH_BOOL(expr, const_expr, ...) \
[&]() -> bool { \
if (expr) { \
constexpr bool const_expr = true; \
return __VA_ARGS__(); \
} else { \
constexpr bool const_expr = false; \
return __VA_ARGS__(); \
} \
}()
// Entry point for the SM100 CUTLASS MLA decode kernel.
//
// Selects the kernel instantiation from three runtime properties and then
// launches it on the current CUDA stream of q_nope's device:
//   * page size (kv_c_and_k_pe_cache.sizes()[1]) == 128 -> IsPaged128
//   * num_kv_splits <= 1 -> persistent scheduler, otherwise individual
//   * q_nope dtype: Half, BFloat16 or Float8_e4m3fn
// Any other dtype raises "Unsupported input data type of MLA".
// Results are written into `out`; `workspace` is used as kernel scratch.
void sm100_cutlass_mla_decode(
    torch::Tensor const& out,
    torch::Tensor const& q_nope,
    torch::Tensor const& q_pe,
    torch::Tensor const& kv_c_and_k_pe_cache,
    torch::Tensor const& seq_lens,
    torch::Tensor const& page_table,
    torch::Tensor const& workspace,
    double sm_scale,
    int64_t num_kv_splits) {
  auto in_dtype = q_nope.dtype();
  // Build the guard directly from the tensor's Device instead of squeezing
  // the device index through a C-style `(char)` cast, whose signedness is
  // implementation-defined and which would silently truncate the index.
  const at::cuda::CUDAGuard device_guard(q_nope.device());
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(q_nope.get_device());
  const int page_size = kv_c_and_k_pe_cache.sizes()[1];

  // NOTE(alcanderian): IsPersistent has bug with manual split_kv.
  // Kernel will hang if batch is too large with large num_kv_splits. (for example bs=8, num_kv_splits=8)
  // Maybe per batch split kv will fix this.
  DISPATCH_BOOL(page_size == 128, IsPaged128, [&] {
    DISPATCH_BOOL(num_kv_splits <= 1, NotManualSplitKV, [&] {
      if (in_dtype == at::ScalarType::Half) {
        runMla<cutlass::half_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
            out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
      } else if (in_dtype == at::ScalarType::BFloat16) {
        runMla<cutlass::bfloat16_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
            out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
      } else if (in_dtype == at::ScalarType::Float8_e4m3fn) {
        runMla<cutlass::float_e4m3_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
            out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
      } else {
        TORCH_CHECK(false, "Unsupported input data type of MLA");
      }
      // The lambdas only return bool because DISPATCH_BOOL requires it.
      return true;
    });
    return true;
  });
}
int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count, int64_t num_kv_splits) {
// Workspace size depends on ElementAcc and ElementLSE (same as ElementAcc)
// which are float, so Element type here doesn't matter.
using MlaSm100Type = MlaSm100<cutlass::half_t, true>;
// Get split kv. Requires problem shape and sm_count only.
typename MlaSm100Type::Fmha::Arguments arguments;
using TileShapeH = typename MlaSm100Type::TileShapeH;
using TileShapeD = typename MlaSm100Type::TileShapeD;
arguments.problem_shape =
cute::make_tuple(TileShapeH{}, static_cast<int>(max_seq_len), TileShapeD{}, static_cast<int>(num_batches));
// Assumes device 0 when getting sm_count.
arguments.hw_info.sm_count =
sm_count <= 0 ? cutlass::KernelHardwareInfo::query_device_multiprocessor_count(/*device_id=*/0) : sm_count;
arguments.split_kv = num_kv_splits;
MlaSm100Type::Fmha::set_split_kv(arguments);
return MlaSm100Type::Fmha::get_workspace_size(arguments);
}
#endif
// Registers the decode entry point as the implementation of the
// "sm100_cutlass_mla_decode" operator for the CUDA dispatch key.
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("sm100_cutlass_mla_decode", &sm100_cutlass_mla_decode);
}
// The workspace-size query takes only integer arguments (no tensors), and is
// registered under the CatchAll key rather than a device-specific one.
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CatchAll, m) {
m.impl("sm100_cutlass_mla_get_workspace_size", &sm100_cutlass_mla_get_workspace_size);
}
// clang-format on
......@@ -18,12 +18,7 @@
*/
#include "attention_kernels.cuh"
#ifndef USE_ROCM
#define WARP_SIZE 32
#else
#define WARP_SIZE warpSize
#endif
#include "cuda_compat.h"
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
......@@ -187,7 +182,6 @@ void paged_attention_v1(
CALL_V1_LAUNCHER_BLOCK_SIZE)
}
#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP
......@@ -18,12 +18,7 @@
*/
#include "attention_kernels.cuh"
#ifndef USE_ROCM
#define WARP_SIZE 32
#else
#define WARP_SIZE warpSize
#endif
#include "cuda_compat.h"
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
......@@ -197,7 +192,6 @@ void paged_attention_v2(
CALL_V2_LAUNCHER_BLOCK_SIZE)
}
#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP
......@@ -33,6 +33,8 @@ namespace vec_op {
#endif
// Marks a function `inline` and adds the GCC/Clang `always_inline` attribute.
#define FORCE_INLINE __attribute__((always_inline)) inline
// Number of elements in a single ASIMD vector of the given datatype,
// computed as sizeof(whole vector) / sizeof(first element). The argument is
// only used inside sizeof, so it is never evaluated; it is parenthesized
// before subscripting so that arbitrary expressions can be passed safely.
#define NUM_ELEMENTS_REG(vec) (sizeof(vec) / sizeof((vec)[0]))
namespace {
template <typename T, T... indexes, typename F>
......@@ -86,8 +88,8 @@ struct FP16Vec16 : public Vec<FP16Vec16> {
}
void save(void* ptr, const int elem_num) const {
int full_blocks = elem_num / 8;
int remainder = elem_num % 8;
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
if (full_blocks > 0) {
vst1q_f16(reinterpret_cast<__fp16*>(ptr), reg.val[0]);
......@@ -197,6 +199,25 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[2]), v.val[3])}) {};
// Writes the whole register pair (2 x 8 bf16 lanes) to `ptr` through a single
// structure assignment.
// NOTE(review): dereferencing the cast pointer presumes `ptr` is suitably
// aligned for bfloat16x8x2_t - confirm against callers.
void save(void* ptr) const { *reinterpret_cast<bfloat16x8x2_t*>(ptr) = reg; };
// Writes only the first `elem_num` bf16 lanes to `ptr`.
// Registers that are completely covered are stored whole; the trailing,
// partially covered register is spilled to a local buffer and its leading
// lanes are copied one by one. Nothing beyond `elem_num` elements is written.
void save(void* ptr, const int elem_num) const {
  const auto lanes_per_reg = NUM_ELEMENTS_REG(reg.val[0]);
  const int whole_regs = elem_num / lanes_per_reg;
  const int tail = elem_num % lanes_per_reg;

  for (int r = 0; r < whole_regs; ++r) {
    vst1q_bf16(reinterpret_cast<__bf16*>(ptr) + lanes_per_reg * r, reg.val[r]);
  }

  if (tail > 0) {
    bfloat16_t spill[8];
    vst1q_bf16(spill, reg.val[whole_regs]);
    bfloat16_t* dst = reinterpret_cast<bfloat16_t*>(ptr) + whole_regs * 8;
    for (int lane = 0; lane < tail; ++lane) {
      dst[lane] = spill[lane];
    }
  }
};
};
struct BF16Vec32 : public Vec<BF16Vec32> {
......@@ -213,6 +234,25 @@ struct BF16Vec32 : public Vec<BF16Vec32> {
: reg({vec8_data.reg, vec8_data.reg, vec8_data.reg, vec8_data.reg}) {};
// Writes the whole register quad (4 x 8 bf16 lanes) to `ptr` through a single
// structure assignment.
// NOTE(review): dereferencing the cast pointer presumes `ptr` is suitably
// aligned for bfloat16x8x4_t - confirm against callers.
void save(void* ptr) const { *reinterpret_cast<bfloat16x8x4_t*>(ptr) = reg; };
// Writes only the first `elem_num` bf16 lanes to `ptr`.
// Registers that are completely covered are stored whole; the trailing,
// partially covered register is spilled to a local buffer and its leading
// lanes are copied one by one. Nothing beyond `elem_num` elements is written.
void save(void* ptr, const int elem_num) const {
  const auto lanes_per_reg = NUM_ELEMENTS_REG(reg.val[0]);
  const int whole_regs = elem_num / lanes_per_reg;
  const int tail = elem_num % lanes_per_reg;

  for (int r = 0; r < whole_regs; ++r) {
    vst1q_bf16(reinterpret_cast<__bf16*>(ptr) + lanes_per_reg * r, reg.val[r]);
  }

  if (tail > 0) {
    bfloat16_t spill[8];
    vst1q_bf16(spill, reg.val[whole_regs]);
    bfloat16_t* dst = reinterpret_cast<bfloat16_t*>(ptr) + whole_regs * 8;
    for (int lane = 0; lane < tail; ++lane) {
      dst[lane] = spill[lane];
    }
  }
};
};
#endif
......@@ -372,6 +412,48 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
}
};
// Sixteen 32-bit signed integers held in four 128-bit NEON registers.
struct INT32Vec16 : public Vec<INT32Vec16> {
  constexpr static int VEC_ELEM_NUM = 16;
  // Lets the register quad be viewed as a plain array of 16 int32 values.
  union AliasReg {
    int32x4x4_t reg;
    int32_t values[VEC_ELEM_NUM];
  };
  int32x4x4_t reg;

  // Loads 16 consecutive int32 values starting at `ptr`.
  explicit INT32Vec16(const void* ptr) {
    reg.val[0] = vld1q_s32(reinterpret_cast<const int32_t*>(ptr));
    reg.val[1] = vld1q_s32(reinterpret_cast<const int32_t*>(ptr) + 4);
    reg.val[2] = vld1q_s32(reinterpret_cast<const int32_t*>(ptr) + 8);
    reg.val[3] = vld1q_s32(reinterpret_cast<const int32_t*>(ptr) + 12);
  }

  // Stores all 16 values to `ptr`.
  void save(int32_t* ptr) const {
    vst1q_s32(ptr, reg.val[0]);
    vst1q_s32(ptr + 4, reg.val[1]);
    vst1q_s32(ptr + 8, reg.val[2]);
    vst1q_s32(ptr + 12, reg.val[3]);
  };

  // Stores only the first `elem_num` values to `ptr`: fully covered registers
  // are written whole, the trailing partial register lane by lane.
  void save(int32_t* ptr, const int elem_num) const {
    int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
    int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);

    // `ptr` is already int32_t*, so no cast (in particular not to the
    // non-standard `__int32_t`) is needed.
    for (int i = 0; i < full_blocks; i++)
      vst1q_s32(ptr + NUM_ELEMENTS_REG(reg.val[0]) * i, reg.val[i]);

    if (remainder > 0) {
      int32x4_t temp = reg.val[full_blocks];
      int32_t* base = ptr + full_blocks * NUM_ELEMENTS_REG(reg.val[0]);
      // remainder is in [1, 3] here, so lane 0 is always written and lane 3
      // never is.
      base[0] = vgetq_lane_s32(temp, 0);
      if (remainder > 1) base[1] = vgetq_lane_s32(temp, 1);
      if (remainder > 2) base[2] = vgetq_lane_s32(temp, 2);
    }
  }
};
struct FP32Vec16 : public Vec<FP32Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
union AliasReg {
......@@ -434,7 +516,12 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
reg.val[2] = vcvt_f32_f16(vget_low_f16(v.reg.val[1]));
reg.val[3] = vcvt_f32_f16(vget_high_f16(v.reg.val[1]));
};
// Converts every int32 lane of `v` to float32, one 128-bit register at a time.
explicit FP32Vec16(const INT32Vec16& v) {
  for (int blk = 0; blk < 4; ++blk) {
    reg.val[blk] = vcvtq_f32_s32(v.reg.val[blk]);
  }
};
FP32Vec16 operator+(const FP32Vec16& b) const {
return FP32Vec16(float32x4x4_t({vaddq_f32(reg.val[0], b.reg.val[0]),
vaddq_f32(reg.val[1], b.reg.val[1]),
......@@ -463,6 +550,85 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
vdivq_f32(reg.val[3], b.reg.val[3])}));
};
// Lane-wise clamp: each lane is first raised to at least the matching lane of
// `min` (vmaxq_f32) and then lowered to at most the matching lane of `max`
// (vminq_f32).
FP32Vec16 clamp(const FP32Vec16& min, const FP32Vec16& max) const {
  float32x4x4_t clamped;
  for (int blk = 0; blk < 4; ++blk) {
    const float32x4_t floored = vmaxq_f32(min.reg.val[blk], reg.val[blk]);
    clamped.val[blk] = vminq_f32(max.reg.val[blk], floored);
  }
  return FP32Vec16(clamped);
};
// Lane-wise maximum of this vector and `b` over all 16 lanes.
FP32Vec16 max(const FP32Vec16& b) const {
  float32x4x4_t larger;
  for (int blk = 0; blk < 4; ++blk) {
    larger.val[blk] = vmaxq_f32(b.reg.val[blk], reg.val[blk]);
  }
  return FP32Vec16(larger);
};
// Lane-wise maximum of this vector and `b`, restricted to the first
// `elem_num` lanes. Lanes at index >= elem_num are copied unchanged from this
// vector, so the returned value is fully defined (previously those lanes were
// read from an uninitialized temporary).
FP32Vec16 max(const FP32Vec16& b, const int elem_num) const {
  int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
  int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);

  // Start from this vector so untouched lanes hold well-defined values.
  float32x4x4_t temp = reg;

  for (int i = 0; i < full_blocks; i++)
    temp.val[i] = vmaxq_f32(b.reg.val[i], reg.val[i]);

  if (remainder > 0) {
    // Partially covered register: combine only its leading `remainder`
    // lanes. Lane indices must be compile-time constants, hence the ladder.
    const float32x4_t lhs = reg.val[full_blocks];
    const float32x4_t rhs = b.reg.val[full_blocks];
    float32x4_t part = temp.val[full_blocks];

    part = vsetq_lane_f32(
        std::max(vgetq_lane_f32(lhs, 0), vgetq_lane_f32(rhs, 0)), part, 0);
    if (remainder > 1) {
      part = vsetq_lane_f32(
          std::max(vgetq_lane_f32(lhs, 1), vgetq_lane_f32(rhs, 1)), part, 1);
    }
    if (remainder > 2) {
      part = vsetq_lane_f32(
          std::max(vgetq_lane_f32(lhs, 2), vgetq_lane_f32(rhs, 2)), part, 2);
    }
    temp.val[full_blocks] = part;
  }
  return FP32Vec16(temp);
};
// Lane-wise minimum of this vector and `b` over all 16 lanes.
FP32Vec16 min(const FP32Vec16& b) const {
  float32x4x4_t smaller;
  for (int blk = 0; blk < 4; ++blk) {
    smaller.val[blk] = vminq_f32(b.reg.val[blk], reg.val[blk]);
  }
  return FP32Vec16(smaller);
};
// Lane-wise minimum of this vector and `b`, restricted to the first
// `elem_num` lanes. Lanes at index >= elem_num are copied unchanged from this
// vector, so the returned value is fully defined (previously those lanes were
// read from an uninitialized temporary).
FP32Vec16 min(const FP32Vec16& b, const int elem_num) const {
  int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
  const int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);

  // Start from this vector so untouched lanes hold well-defined values.
  float32x4x4_t temp = reg;

  for (int i = 0; i < full_blocks; i++)
    temp.val[i] = vminq_f32(b.reg.val[i], reg.val[i]);

  if (remainder > 0) {
    // Partially covered register: combine only its leading `remainder`
    // lanes. Lane indices must be compile-time constants, hence the ladder.
    const float32x4_t lhs = reg.val[full_blocks];
    const float32x4_t rhs = b.reg.val[full_blocks];
    float32x4_t part = temp.val[full_blocks];

    part = vsetq_lane_f32(
        std::min(vgetq_lane_f32(lhs, 0), vgetq_lane_f32(rhs, 0)), part, 0);
    if (remainder > 1) {
      part = vsetq_lane_f32(
          std::min(vgetq_lane_f32(lhs, 1), vgetq_lane_f32(rhs, 1)), part, 1);
    }
    if (remainder > 2) {
      part = vsetq_lane_f32(
          std::min(vgetq_lane_f32(lhs, 2), vgetq_lane_f32(rhs, 2)), part, 2);
    }
    temp.val[full_blocks] = part;
  }
  return FP32Vec16(temp);
};
// Lane-wise absolute value over all 16 lanes.
FP32Vec16 abs() const {
  float32x4x4_t magnitude;
  for (int blk = 0; blk < 4; ++blk) {
    magnitude.val[blk] = vabsq_f32(reg.val[blk]);
  }
  return FP32Vec16(magnitude);
}
float reduce_sum() const {
AliasReg ar;
ar.reg = reg;
......@@ -473,6 +639,24 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
return answer;
};
// Horizontal maximum: views the registers as 16 scalars through AliasReg and
// folds them, starting from the lowest finite float.
float reduce_max() const {
  AliasReg lanes;
  lanes.reg = reg;
  float best = std::numeric_limits<float>::lowest();
  unroll_loop<int, VEC_ELEM_NUM>([&best, &lanes](int i) {
    // Same selection rule as std::max(best, lanes.values[i]).
    if (best < lanes.values[i]) best = lanes.values[i];
  });
  return best;
}
// Horizontal minimum: views the registers as 16 scalars through AliasReg and
// folds them, starting from the largest finite float.
float reduce_min() const {
  AliasReg lanes;
  lanes.reg = reg;
  float best = std::numeric_limits<float>::max();
  unroll_loop<int, VEC_ELEM_NUM>([&best, &lanes](int i) {
    // Same selection rule as std::min(best, lanes.values[i]).
    if (lanes.values[i] < best) best = lanes.values[i];
  });
  return best;
}
template <int group_size>
float reduce_sub_sum(int idx) {
static_assert(VEC_ELEM_NUM % group_size == 0);
......@@ -493,6 +677,83 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
vst1q_f32(ptr + 8, reg.val[2]);
vst1q_f32(ptr + 12, reg.val[3]);
};
// Writes only the first `elem_num` float lanes to `ptr`.
// Registers that are completely covered are stored whole; the trailing,
// partially covered register is spilled to a local buffer and its leading
// lanes are copied one by one. Nothing beyond `elem_num` elements is written.
void save(float* ptr, const int elem_num) const {
  const auto lanes_per_reg = NUM_ELEMENTS_REG(reg.val[0]);
  const int whole_regs = elem_num / lanes_per_reg;
  const int tail = elem_num % lanes_per_reg;

  for (int r = 0; r < whole_regs; ++r) {
    vst1q_f32(ptr + lanes_per_reg * r, reg.val[r]);
  }

  if (tail > 0) {
    float spill[4];
    vst1q_f32(spill, reg.val[whole_regs]);
    float* dst = ptr + whole_regs * lanes_per_reg;
    for (int lane = 0; lane < tail; ++lane) {
      dst[lane] = spill[lane];
    }
  }
}
};
// Sixteen 8-bit signed integers held in one 128-bit NEON register.
struct INT8Vec16 : public Vec<INT8Vec16> {
  constexpr static int VEC_ELEM_NUM = 16;
  // Lets the register be viewed as a plain array of 16 int8 values.
  union AliasReg {
    int8x16_t reg;
    int8_t values[VEC_ELEM_NUM];
  };
  int8x16_t reg;

  // Builds the vector from 16 floats: each 128-bit float block is converted
  // to int32 (vcvtq_s32_f32), narrowed to int16 and then to int8 with the
  // vqmovn intrinsics, and the two 8-lane halves are joined.
  explicit INT8Vec16(const FP32Vec16& vec) {
    const int16x4_t q0 = vqmovn_s32(vcvtq_s32_f32(vec.reg.val[0]));
    const int16x4_t q1 = vqmovn_s32(vcvtq_s32_f32(vec.reg.val[1]));
    const int16x4_t q2 = vqmovn_s32(vcvtq_s32_f32(vec.reg.val[2]));
    const int16x4_t q3 = vqmovn_s32(vcvtq_s32_f32(vec.reg.val[3]));

    const int8x8_t low_half = vqmovn_s16(vcombine_s16(q0, q1));
    const int8x8_t high_half = vqmovn_s16(vcombine_s16(q2, q3));
    reg = vcombine_s8(low_half, high_half);
  }

  // Stores all 16 values to `ptr`.
  void save(int8_t* ptr) const { vst1q_s8(ptr, reg); };

  // Stores only the first `elem_num` values to `ptr`. The register is written
  // whole once per full group of 16 requested elements; the remaining
  // elements are copied lane by lane from a local spill buffer.
  void save(int8_t* ptr, const int elem_num) const {
    const auto lanes_per_reg = NUM_ELEMENTS_REG(reg);
    const int whole_regs = elem_num / lanes_per_reg;
    const int tail = elem_num % lanes_per_reg;

    for (int r = 0; r < whole_regs; ++r) {
      vst1q_s8(ptr + lanes_per_reg * r, reg);
    }

    if (tail > 0) {
      int8_t spill[16];
      vst1q_s8(spill, reg);
      int8_t* dst = ptr + whole_regs * lanes_per_reg;
      for (int lane = 0; lane < tail; ++lane) {
        dst[lane] = spill[lane];
      }
    }
  };
};
template <typename T>
......
......@@ -57,6 +57,7 @@ class DNNLPrimitiveHelper {
// Note: Due to the limitation of oneDNN
// (https://github.com/oneapi-src/oneDNN/issues/1636), the quantized bias is
// not supported.
template <typename OutputT, typename BiasT>
static void gemm_s8s8_jit(const int8_t* a, const int8_t* b, OutputT* c,
const BiasT* bias, dnnl_dim_t M, dnnl_dim_t N,
......@@ -90,6 +91,27 @@ class DNNLPrimitiveHelper {
}
dnnl::matmul::primitive_desc matmul_pd;
// Create memory descriptors with format_tag::any for the primitive. This
// enables the matmul primitive to choose memory layouts for an
// optimized primitive implementation, and these layouts may differ from the
// ones provided by the user.
#ifdef __aarch64__
auto mat_src_md = dnnl::memory::desc({M, K}, dnnl::memory::data_type::s8,
dnnl::memory::format_tag::any);
auto mat_weights_md = dnnl::memory::desc(
{K, N}, dnnl::memory::data_type::s8, dnnl::memory::format_tag::any);
auto mat_dst_md =
dnnl::memory::desc({M, N}, OutputType, dnnl::memory::format_tag::any);
if (bias) {
dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1});
matmul_pd = dnnl::matmul::primitive_desc(default_engine(), mat_src_md,
mat_weights_md, bias_md,
mat_dst_md, attr);
} else {
matmul_pd = dnnl::matmul::primitive_desc(
default_engine(), mat_src_md, mat_weights_md, mat_dst_md, attr);
}
#else
if (bias) {
dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1});
matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md,
......@@ -98,6 +120,7 @@ class DNNLPrimitiveHelper {
matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md,
c_md, attr);
}
#endif
dnnl::matmul matmul(matmul_pd);
auto& engine = default_engine();
......@@ -111,24 +134,34 @@ class DNNLPrimitiveHelper {
(void*)b_scales);
auto& stream = default_stream();
auto mat_src_mem = a_m;
auto mat_weights_mem = b_m;
auto mat_dst_mem = c_m;
#ifdef __aarch64__
if (matmul_pd.weights_desc() != b_m.get_desc()) {
mat_weights_mem = dnnl::memory(matmul_pd.weights_desc(), engine);
dnnl::reorder(b_m, mat_weights_mem).execute(stream, b_m, mat_weights_mem);
}
#endif
if constexpr (InputNoScale) {
if (bias) {
dnnl::memory::desc bias_md({N}, BiasType, {1});
dnnl::memory bias_m(bias_md, engine, (void*)bias);
matmul.execute(
stream, {
{DNNL_ARG_SRC, a_m},
{DNNL_ARG_WEIGHTS, b_m},
{DNNL_ARG_SRC, mat_src_mem},
{DNNL_ARG_WEIGHTS, mat_weights_mem},
{DNNL_ARG_BIAS, bias_m},
{DNNL_ARG_DST, c_m},
{DNNL_ARG_DST, mat_dst_mem},
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
});
} else {
matmul.execute(
stream, {
{DNNL_ARG_SRC, a_m},
{DNNL_ARG_WEIGHTS, b_m},
{DNNL_ARG_DST, c_m},
{DNNL_ARG_SRC, mat_src_mem},
{DNNL_ARG_WEIGHTS, mat_weights_mem},
{DNNL_ARG_DST, mat_dst_mem},
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
});
}
......@@ -138,19 +171,19 @@ class DNNLPrimitiveHelper {
dnnl::memory bias_m(bias_md, engine, (void*)bias);
matmul.execute(
stream, {
{DNNL_ARG_SRC, a_m},
{DNNL_ARG_WEIGHTS, b_m},
{DNNL_ARG_SRC, mat_src_mem},
{DNNL_ARG_WEIGHTS, mat_weights_mem},
{DNNL_ARG_BIAS, bias_m},
{DNNL_ARG_DST, c_m},
{DNNL_ARG_DST, mat_dst_mem},
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m},
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
});
} else {
matmul.execute(
stream, {
{DNNL_ARG_SRC, a_m},
{DNNL_ARG_WEIGHTS, b_m},
{DNNL_ARG_DST, c_m},
{DNNL_ARG_SRC, mat_src_mem},
{DNNL_ARG_WEIGHTS, mat_weights_mem},
{DNNL_ARG_DST, mat_dst_mem},
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m},
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
});
......@@ -170,5 +203,4 @@ class DNNLPrimitiveHelper {
return stream;
}
};
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment