Commit af7f4372 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.5.5' into v0.5.5-dtk24.04.1

parents 5e19cdef 09c77926
......@@ -32,7 +32,6 @@ def to_int8(tensor: torch.Tensor) -> torch.Tensor:
def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
k: int) -> Tuple[torch.Tensor, torch.Tensor]:
a = torch.randn((m, k), device='cuda') * 5
b = torch.randn((n, k), device='cuda').t() * 5
......@@ -44,59 +43,18 @@ def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
raise ValueError("unsupported dtype")
# impl
def pytorch_mm_impl(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: torch.dtype) -> torch.Tensor:
return torch.mm(a, b)
def pytorch_fp8_impl(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: torch.dtype) -> torch.Tensor:
return torch._scaled_mm(a,
b,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=out_dtype)
def pytorch_fp8_impl_fast_accum(a: torch.Tensor, b: torch.Tensor,
scale_a: torch.Tensor, scale_b: torch.Tensor,
out_dtype: torch.dtype) -> torch.Tensor:
return torch._scaled_mm(a,
b,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=out_dtype,
use_fast_accum=True)
def cutlass_impl(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: torch.dtype) -> torch.Tensor:
return ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype=out_dtype)
# bench
def bench_fn(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
scale_b: torch.Tensor, out_dtype: torch.dtype, label: str,
sub_label: str, fn: Callable, description: str) -> TMeasurement:
def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args,
**kwargs) -> TMeasurement:
min_run_time = 1
globals = {
"a": a,
"b": b,
"scale_a": scale_a,
"scale_b": scale_b,
"out_dtype": out_dtype,
"args": args,
"kwargs": kwargs,
"fn": fn,
}
return TBenchmark.Timer(
stmt="fn(a, b, scale_a, scale_b, out_dtype)",
stmt="fn(*args, **kwargs)",
globals=globals,
label=label,
sub_label=sub_label,
......@@ -110,26 +68,58 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
a, b = make_rand_tensors(torch.int8, m, n, k)
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
azp = torch.zeros((m, ), device="cuda", dtype=torch.int32)
azp_adj = torch.zeros((n, ), device="cuda", dtype=torch.int32)
timers = []
# pytorch impl - bfloat16
timers.append(
bench_fn(a.to(dtype=torch.bfloat16, device="cuda"),
b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b,
torch.bfloat16, label, sub_label, pytorch_mm_impl,
"pytorch_bf16_bf16_bf16_matmul-no-scales"))
bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales",
torch.mm, a.to(dtype=torch.bfloat16),
b.to(dtype=torch.bfloat16)))
# pytorch impl - float16
timers.append(
bench_fn(a.to(dtype=torch.float16, device="cuda"),
b.to(dtype=torch.float16, device="cuda"), scale_a, scale_b,
torch.float16, label, sub_label, pytorch_mm_impl,
"pytorch_fp16_fp16_fp16_matmul-no-scales"))
bench_fn(label, sub_label,
"pytorch_fp16_fp16_fp16_matmul-no-scales", torch.mm,
a.to(dtype=torch.float16), b.to(dtype=torch.float16)))
# cutlass impl
timers.append(
bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
cutlass_impl, "cutlass_i8_i8_bf16_scaled_mm"))
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm",
ops.cutlass_scaled_mm, a, b, scale_a, scale_b,
torch.bfloat16))
# cutlass with bias
timers.append(
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_bias",
ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16,
bias))
# cutlass with azp per-tensor
timers.append(
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp",
ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b,
torch.bfloat16, azp_adj))
# cutlass with azp per-tensor + bias
timers.append(
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp_bias",
ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b,
torch.bfloat16, azp_adj, None, bias))
# cutlass with azp per-token
timers.append(
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp_pt",
ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b,
torch.bfloat16, azp_adj, azp))
# cutlass with azp per-token + bias
timers.append(
bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias",
ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b,
torch.bfloat16, azp_adj, azp, bias))
return timers
......@@ -140,46 +130,88 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k)
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16)
timers = []
# pytorch impl w. bf16
timers.append(
bench_fn(a.to(dtype=torch.bfloat16, device="cuda"),
b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b,
torch.bfloat16, label, sub_label, pytorch_mm_impl,
"pytorch_bf16_bf16_bf16_matmul-no-scales"))
bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales",
torch.mm, a.to(dtype=torch.bfloat16, device="cuda"),
b.to(dtype=torch.bfloat16, device="cuda")))
# pytorch impl: bf16 output, without fp8 fast accum
timers.append(
bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
pytorch_fp8_impl, "pytorch_fp8_fp8_bf16_scaled_mm"))
bench_fn(label,
sub_label,
"pytorch_fp8_fp8_bf16_scaled_mm",
torch._scaled_mm,
a,
b,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=torch.bfloat16))
# pytorch impl: bf16 output, with fp8 fast accum
timers.append(
bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
pytorch_fp8_impl_fast_accum,
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum"))
bench_fn(label,
sub_label,
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum",
torch._scaled_mm,
a,
b,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=torch.bfloat16,
use_fast_accum=True))
# pytorch impl: fp16 output, without fp8 fast accum
timers.append(
bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label,
pytorch_fp8_impl, "pytorch_fp8_fp8_fp16_scaled_mm"))
bench_fn(label,
sub_label,
"pytorch_fp8_fp8_fp16_scaled_mm",
torch._scaled_mm,
a,
b,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=torch.float16))
# pytorch impl: fp16 output, with fp8 fast accum
timers.append(
bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label,
pytorch_fp8_impl_fast_accum,
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum"))
bench_fn(label,
sub_label,
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum",
torch._scaled_mm,
a,
b,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=torch.float16,
use_fast_accum=True))
# cutlass impl: bf16 output
timers.append(
bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
cutlass_impl, "cutlass_fp8_fp8_bf16_scaled_mm"))
bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm",
ops.cutlass_scaled_mm, a, b, scale_a, scale_b,
torch.bfloat16))
# cutlass impl: fp16 output
timers.append(
bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label,
cutlass_impl, "cutlass_fp8_fp8_fp16_scaled_mm"))
bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_mm",
ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.float16))
# cutlass impl: bf16 output, with bias
timers.append(
bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm_bias",
ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16,
bias))
# cutlass impl: fp16 output, with bias
timers.append(
bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_mm_bias",
ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.float16,
bias.to(dtype=torch.float16)))
return timers
......@@ -200,7 +232,6 @@ def print_timers(timers: Iterable[TMeasurement]):
def run(dtype: torch.dtype,
MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]:
results = []
for m, k, n in MKNs:
timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm",
......@@ -216,7 +247,6 @@ def make_output(data: Iterable[TMeasurement],
MKNs: Iterable[Tuple[int, int, int]],
base_description: str,
timestamp=None):
print(f"== All Results {base_description} ====")
print_timers(data)
......@@ -251,7 +281,6 @@ def run_range_bench(args):
def run_model_bench(args):
print("Benchmarking models:")
for i, model in enumerate(args.models):
print(f"[{i}] {model}")
......
import random
import time
import torch
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
@torch.inference_mode()
def main(num_tokens: int,
hidden_size: int,
add_residual: bool,
dtype: torch.dtype,
seed: int = 0,
do_profile: bool = False,
num_warmup_iters: int = 5,
num_iters: int = 100) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.set_default_device("cuda")
layer = RMSNorm(hidden_size).to(dtype=dtype)
layer.weight.data.normal_(mean=1.0, std=0.1)
scale = 1 / (2 * hidden_size)
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
x *= scale
residual = torch.randn_like(x) * scale if add_residual else None
def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
torch.cuda.synchronize()
if profile:
torch.cuda.cudart().cudaProfilerStart()
start_time = time.perf_counter()
for _ in range(num_iters):
layer(x, residual)
torch.cuda.synchronize()
end_time = time.perf_counter()
if profile:
torch.cuda.cudart().cudaProfilerStart()
return (end_time - start_time) / num_iters
# Warmup.
print("Warming up...")
run_benchmark = run_cuda_benchmark
run_benchmark(num_iters=num_warmup_iters, profile=False)
# Benchmark.
if do_profile:
latency = run_benchmark(num_iters=1, profile=True)
else:
latency = run_benchmark(num_iters=num_iters, profile=False)
print(f"Kernel running time: {latency * 1000000:.3f} us")
if __name__ == '__main__':
parser = FlexibleArgumentParser(
description="Benchmark the layernorm kernel.")
parser.add_argument("--num-tokens", type=int, default=4096)
parser.add_argument("--hidden-size", type=int, default=8192)
parser.add_argument("--add-residual", action="store_true")
parser.add_argument("--dtype",
type=str,
choices=["half", "bfloat16", "float"],
default="half")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--profile", action="store_true")
parser.add_argument("--num-warmup-iters", type=int, default=5)
parser.add_argument("--num-iters",
type=int,
default=100,
help="Number of benchmark iterations. "
"If --profile is set, this number is ignored")
args = parser.parse_args()
print(args)
main(num_tokens=args.num_tokens,
hidden_size=args.hidden_size,
add_residual=args.add_residual,
dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
seed=args.seed,
do_profile=args.profile,
num_warmup_iters=args.num_warmup_iters,
num_iters=args.num_iters)
import argparse
import copy
import itertools
import math
import pickle as pkl
import time
from typing import Callable, Iterable, List, Tuple
import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
from weight_shapes import WEIGHT_SHAPES
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, marlin_permute_scales)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
MarlinWorkspace)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
gptq_pack, pack_rows, quantize_weights)
from vllm.scalar_type import ScalarType, scalar_types
from vllm.utils import FlexibleArgumentParser
DEFAULT_MODELS = ["meta-llama/Llama-3-8b", "meta-llama/Llama-2-70b-hf"]
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024]
DEFAULT_TP_SIZES = [1]
def machete_pack_weights(w_q: torch.tensor, wtype: ScalarType) -> torch.tensor:
w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape)
w_q = w_q.t().contiguous().t() # make col major
return ops.machete_prepack_B(w_q, wtype)
def make_bench_tensors(
atype: torch.dtype, wtype: ScalarType, group_size: int, m: int, n: int,
k: int
) -> Tuple[torch.tensor, List[Tuple[torch.tensor, torch.tensor, torch.tensor,
torch.tensor]]]:
assert wtype.is_integer(), "TODO: support floating point weights"
# we want to make sure that weights don't fit into L2 cache between runs so
# we construct enough weights to exceed L2 cache, which is 50mb on a H100
# so we target total weight size > 2*50mb
num_weights = math.ceil(2 * 50 * 1024**2 * 8 / (k * n * wtype.size_bits))
a = torch.randn((m, k), device="cuda", dtype=atype) * 5
weights = [
torch.randn((k, n), device="cuda", dtype=atype)
for _ in range(num_weights)
]
quanitized_weights = [
quantize_weights(w, wtype, group_size) for w in weights
]
return a, quanitized_weights
# impl
# bench
def bench_fn(label: str, sub_label: str, description: str,
fn: Callable) -> TMeasurement:
min_run_time = 1
return TBenchmark.Timer(
stmt="fn()",
globals={
"fn": fn
},
label=label,
sub_label=sub_label,
description=description,
).blocked_autorange(min_run_time=min_run_time)
def loop_over_weights(
a: torch.tensor, weights: List[Tuple[torch.tensor, torch.tensor,
torch.tensor, torch.tensor]],
fn: Callable[[torch.tensor, torch.tensor, torch.tensor, torch.tensor],
None]):
for w_ref, w_q, w_s, _ in weights:
fn(a, w_ref, w_q, w_s)
def bench(atype: torch.dtype,
wtype: ScalarType,
group_size: int,
m: int,
k: int,
n: int,
label: str,
sub_label: str,
benchmark_marlinv1: bool = True,
sweep_schedules: bool = True) -> Iterable[TMeasurement]:
a, weights = make_bench_tensors(atype, wtype, group_size, m, n, k)
sub_label += f", L={len(weights)}"
weights_machete = [(w_ref, machete_pack_weights(w_q, wtype), w_s, w_zp)
for w_ref, w_q, w_s, w_zp in weights]
timers = []
# pytorch impl
timers.append(
bench_fn(
label, sub_label, "torch.matmul", lambda: loop_over_weights(
a,
weights,
lambda a, w_ref, w_q, w_s: torch.matmul(a, w_ref),
)))
if benchmark_marlinv1:
w_ref = weights[0][0]
w_zp_empty = torch.empty(0, dtype=torch.int, device=w_ref.device)
sort_indices = torch.empty(0, dtype=torch.int, device=w_ref.device)
g_idx = torch.empty(0, dtype=torch.int, device=w_ref.device)
def marlinv1_pack_weights(w_q: torch.tensor) -> torch.tensor:
w_q_gptq = gptq_pack(w_q, wtype.size_bits, *w_ref.shape)
return ops.gptq_marlin_repack(w_q_gptq, sort_indices, *w_ref.shape,
wtype.size_bits)
def marlinv1_permute_scales(w_s: torch.tensor) -> torch.tensor:
return marlin_permute_scales(w_s, *w_ref.shape, group_size)
weights_marlinv1 = [(w_ref, marlinv1_pack_weights(w_q),
marlinv1_permute_scales(w_s), w_zp)
for w_ref, w_q, w_s, w_zp in weights]
workspace = MarlinWorkspace(w_ref.shape[1], GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_MAX_PARALLEL)
# marlinv1
timers.append(
bench_fn(
label, sub_label, "marlin_orig", lambda: loop_over_weights(
a, weights_marlinv1, lambda a, w_ref, w_q, w_s: ops.
gptq_marlin_gemm(a,
w_q,
w_s,
w_zp_empty,
g_idx,
sort_indices,
workspace.scratch,
wtype,
size_m=a.shape[0],
size_n=w_ref.shape[1],
size_k=w_ref.shape[0],
is_k_full=True))))
# machete
timers.append(
bench_fn(
label, sub_label, "machete_heuristic", lambda: loop_over_weights(
a, weights_machete, lambda a, _, w_q, w_s: ops.machete_gemm(
a, w_q, wtype, b_scales=w_s, b_group_size=group_size))))
if sweep_schedules:
print("Finding best schedule for machete")
best = None
best_schedule = None
schedules = ops.machete_supported_schedules(wtype)
for schedule in reversed(schedules):
def run(a, _, w_q, w_s, schedule=schedule):
ops.machete_gemm(a,
w_q,
wtype,
w_s,
b_group_size=group_size,
schedule=schedule)
res = bench_fn(label, sub_label, "machete_best",
lambda: loop_over_weights(a, weights_machete, run))
print(f" {res.median:5.5} ", schedule)
if not best or res.median < best.median:
best = res
best_schedule = schedule
print("Best schedule:", best_schedule)
timers.append(best)
return timers
# runner
def print_timers(timers: Iterable[TMeasurement]):
compare = TBenchmark.Compare(timers)
compare.print()
def run(dtype: torch.dtype, sweep_schedules: bool,
MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]:
results = []
for m, k, n in MKNs:
timers = bench(dtype,
scalar_types.uint4b8,
128,
m,
k,
n,
f"{dtype}-gemm",
f"MKN=({m}x{k}x{n})",
sweep_schedules=sweep_schedules)
print_timers(timers)
results.extend(timers)
return results
# output makers
def make_output(
data: Iterable[TMeasurement],
MKNs: Iterable[Tuple[int, int, int]],
base_description: str,
timestamp=None,
):
print(f"== All Results {base_description} ====")
print_timers(data)
# pickle all the results
timestamp = int(time.time()) if timestamp is None else timestamp
with open(f"{base_description}-{timestamp}.pkl", "wb") as f:
pkl.dump(data, f)
# argparse runners
def run_square_bench(args):
dim_sizes = list(
range(args.dim_start, args.dim_end + 1, args.dim_increment))
MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
data = run(args.dtype, args.sweep_schedules, MKNs)
make_output(data, MKNs, f"square_bench-{args.dtype}")
def run_range_bench(args):
dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment))
n = len(dim_sizes)
Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes
Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes
Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes
MKNs = list(zip(Ms, Ks, Ns))
data = run(args.dtype, args.sweep_schedules, MKNs)
make_output(data, MKNs, f"range_bench-{args.dtype}")
def run_model_bench(args):
print("Benchmarking models:")
for i, model in enumerate(args.models):
print(f"[{i}] {model}")
def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]:
KNs = []
for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
KN[tp_split_dim] = KN[tp_split_dim] // tp_size
KNs.append(KN)
return KNs
model_bench_data = []
models_tps = list(itertools.product(args.models, args.tp_sizes))
for model, tp_size in models_tps:
Ms = args.batch_sizes
KNs = model_shapes(model, tp_size)
MKNs = []
for m in Ms:
for k, n in KNs:
MKNs.append((m, k, n))
data = run(args.dtype, args.sweep_schedules, MKNs)
model_bench_data.append(data)
# Print all results
for data, model_tp in zip(model_bench_data, models_tps):
model, tp_size = model_tp
print(f"== Results {args.dtype} {model}-TP{tp_size} ====")
print_timers(data)
timestamp = int(time.time())
all_data = []
for d in model_bench_data:
all_data.extend(d)
# pickle all data
with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f:
pkl.dump(all_data, f)
if __name__ == "__main__":
def to_torch_dtype(dt):
if dt == "bfloat16":
return torch.bfloat16
if dt == "float16":
return torch.float16
raise ValueError("unsupported dtype")
parser = FlexibleArgumentParser(
description="""
Benchmark Machete GEMM.
To run square GEMMs:
python3 ./benchmarks/kernels/benchmark_machete.py --dtype float16 square_bench --dim-start 128 --dim-end 512 --dim-increment 64
To run constant N and K and sweep M:
python3 ./benchmarks/kernels/benchmark_machete.py --dtype float16 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384
To run dimensions from a model:
python3 ./benchmarks/kernels/benchmark_machete.py --dtype float16 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1
Output:
- a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs.
""", # noqa: E501
formatter_class=argparse.RawTextHelpFormatter,
)
parser.add_argument(
"--dtype",
type=to_torch_dtype,
required=True,
help="Available options are ['bfloat16', 'float16']",
)
parser.add_argument(
"--sweep-schedules",
action="store_true",
help="Run a sweep over all supported schedules",
)
subparsers = parser.add_subparsers(dest="cmd", required=True)
square_parser = subparsers.add_parser("square_bench")
square_parser.add_argument("--dim-start", type=int, required=True)
square_parser.add_argument("--dim-end", type=int, required=True)
square_parser.add_argument("--dim-increment", type=int, required=True)
square_parser.set_defaults(func=run_square_bench)
range_parser = subparsers.add_parser("range_bench")
range_parser.add_argument("--dim-start", type=int, required=True)
range_parser.add_argument("--dim-end", type=int, required=True)
range_parser.add_argument("--dim-increment", type=int, required=True)
range_parser.add_argument("--m-constant", type=int, default=None)
range_parser.add_argument("--n-constant", type=int, default=None)
range_parser.add_argument("--k-constant", type=int, default=None)
range_parser.set_defaults(func=run_range_bench)
model_parser = subparsers.add_parser("model_bench")
model_parser.add_argument(
"--models",
nargs="+",
type=str,
default=DEFAULT_MODELS,
choices=WEIGHT_SHAPES.keys(),
)
model_parser.add_argument("--tp-sizes",
nargs="+",
type=int,
default=DEFAULT_TP_SIZES)
model_parser.add_argument("--batch-sizes",
nargs="+",
type=int,
default=DEFAULT_BATCH_SIZES)
model_parser.set_defaults(func=run_model_bench)
args = parser.parse_args()
args.func(args)
......@@ -30,11 +30,28 @@ def benchmark_config(
hidden_size: int,
topk: int,
dtype: torch.dtype,
use_fp8: bool,
use_fp8_w8a8: bool,
use_int8_w8a16: bool,
num_iters: int = 100,
) -> float:
init_dtype = torch.float16 if use_fp8 else dtype
init_dtype = torch.float16 if use_fp8_w8a8 else dtype
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
if use_int8_w8a16:
w1 = torch.randint(-127,
127, (
num_experts,
shard_intermediate_size,
hidden_size,
),
dtype=torch.int8)
w2 = torch.randint(-127,
127, (
num_experts,
hidden_size,
shard_intermediate_size // 2,
),
dtype=torch.int8)
else:
w1 = torch.randn(num_experts,
shard_intermediate_size,
hidden_size,
......@@ -52,7 +69,11 @@ def benchmark_config(
w2_scale = None
a1_scale = None
a2_scale = None
if use_fp8:
if use_int8_w8a16:
w1_scale = torch.randn((num_experts, 2 * shard_intermediate_size),
dtype=torch.float32)
w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
if use_fp8_w8a8:
w1_scale = torch.randn(num_experts, dtype=torch.float32)
w2_scale = torch.randn(num_experts, dtype=torch.float32)
a1_scale = torch.randn(1, dtype=torch.float32)
......@@ -76,7 +97,8 @@ def benchmark_config(
renormalize=True,
inplace=True,
override_config=config,
use_fp8=use_fp8,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
......@@ -155,11 +177,13 @@ class BenchmarkWorker:
hidden_size: int,
topk: int,
dtype: torch.dtype,
use_fp8: bool,
use_fp8_w8a8: bool,
use_int8_w8a16: bool,
) -> Tuple[Dict[str, int], float]:
torch.cuda.manual_seed_all(self.seed)
dtype_str = "float8" if use_fp8 else None
dtype_str = get_config_dtype_str(dtype,
use_int8_w8a16=use_int8_w8a16,
use_fp8_w8a8=use_fp8_w8a8)
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
# is the intermediate size after silu_and_mul.
op_config = get_moe_configs(num_experts, shard_intermediate_size // 2,
......@@ -173,7 +197,8 @@ class BenchmarkWorker:
key=lambda x: abs(x - num_tokens))]
kernel_time = benchmark_config(config, num_tokens, num_experts,
shard_intermediate_size, hidden_size,
topk, dtype, use_fp8)
topk, dtype, use_fp8_w8a8,
use_int8_w8a16)
return config, kernel_time
def tune(
......@@ -184,9 +209,10 @@ class BenchmarkWorker:
hidden_size: int,
topk: int,
dtype: torch.dtype,
use_fp8: bool,
search_space: List[BenchmarkConfig],
) -> BenchmarkConfig:
use_fp8_w8a8: bool,
use_int8_w8a16: bool,
search_space: List[Dict[str, int]],
) -> Dict[str, int]:
best_config = None
best_time = float("inf")
for config in tqdm(search_space):
......@@ -198,7 +224,8 @@ class BenchmarkWorker:
hidden_size,
topk,
dtype,
use_fp8,
use_fp8_w8a8,
use_int8_w8a16,
num_iters=10)
except triton.runtime.autotuner.OutOfResources:
# Some configurations may be invalid and fail to compile.
......@@ -224,20 +251,19 @@ def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
}
def save_configs(
configs: Dict[int, BenchmarkConfig],
num_experts: int,
shard_intermediate_size: int,
hidden_size: int,
topk: int,
dtype: torch.dtype,
use_fp8: bool,
) -> None:
dtype_str = "float8" if use_fp8 else None
def save_configs(configs: Dict[int, BenchmarkConfig], num_experts: int,
shard_intermediate_size: int, hidden_size: int, topk: int,
dtype: torch.dtype, use_fp8_w8a8: bool,
use_int8_w8a16: bool) -> None:
dtype_str = get_config_dtype_str(dtype,
use_int8_w8a16=use_int8_w8a16,
use_fp8_w8a8=use_fp8_w8a8)
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
# is the intermediate size after silu_and_mul.
filename = get_config_file_name(num_experts, shard_intermediate_size // 2,
dtype_str)
print(f"Writing best config to {filename}...")
with open(filename, "w") as f:
json.dump(configs, f, indent=4)
......@@ -253,6 +279,11 @@ def main(args: argparse.Namespace):
topk = config.ffn_config.moe_top_k
intermediate_size = config.ffn_config.ffn_hidden_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] == "JambaForCausalLM":
E = config.num_experts
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
else:
# Default: Mixtral.
E = config.num_local_experts
......@@ -262,7 +293,8 @@ def main(args: argparse.Namespace):
hidden_size = config.hidden_size
dtype = config.torch_dtype
use_fp8 = args.dtype == "fp8"
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
use_int8_w8a16 = args.dtype == "int8_w8a16"
if args.batch_size is None:
batch_sizes = [
......@@ -294,20 +326,20 @@ def main(args: argparse.Namespace):
start = time.time()
configs = _distribute(
"tune", [(batch_size, E, shard_intermediate_size, hidden_size,
topk, dtype, use_fp8, search_space)
topk, dtype, use_fp8_w8a8, use_int8_w8a16, search_space)
for batch_size in batch_sizes])
best_configs = {
M: sort_config(config)
for M, config in zip(batch_sizes, configs)
}
save_configs(best_configs, E, shard_intermediate_size, hidden_size,
topk, dtype, use_fp8)
topk, dtype, use_fp8_w8a8, use_int8_w8a16)
end = time.time()
print(f"Tuning took {end - start:.2f} seconds")
else:
outputs = _distribute("benchmark",
[(batch_size, E, shard_intermediate_size,
hidden_size, topk, dtype, use_fp8)
outputs = _distribute(
"benchmark", [(batch_size, E, shard_intermediate_size, hidden_size,
topk, dtype, use_fp8_w8a8, use_int8_w8a16)
for batch_size in batch_sizes])
for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
......@@ -323,7 +355,7 @@ if __name__ == "__main__":
parser.add_argument("--tp-size", "-tp", type=int, default=2)
parser.add_argument("--dtype",
type=str,
choices=["auto", "fp8"],
choices=["auto", "fp8_w8a8", "int8_w8a16"],
default="auto")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--batch-size", type=int, required=False)
......
import random
import time
import torch
from vllm import _custom_ops as ops
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser
@torch.inference_mode()
def main(num_tokens: int,
hidden_size: int,
static_scale: bool,
quant_dtype: torch.dtype,
dtype: torch.dtype,
seed: int = 0,
do_profile: bool = False,
num_warmup_iters: int = 5,
num_iters: int = 100) -> None:
random.seed(seed)
torch.random.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.set_default_device("cuda")
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
scale = torch.randn(1, 1, dtype=torch.float32) if static_scale else None
def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
torch.cuda.synchronize()
if profile:
torch.cuda.cudart().cudaProfilerStart()
start_time = time.perf_counter()
for _ in range(num_iters):
if quant_dtype == torch.int8:
ops.scaled_int8_quant(x, scale)
else:
ops.scaled_fp8_quant(x, scale)
torch.cuda.synchronize()
end_time = time.perf_counter()
if profile:
torch.cuda.cudart().cudaProfilerStart()
return (end_time - start_time) / num_iters
# Warmup.
print("Warming up...")
run_benchmark = run_cuda_benchmark
run_benchmark(num_iters=num_warmup_iters, profile=False)
# Benchmark.
if do_profile:
latency = run_benchmark(num_iters=1, profile=True)
else:
latency = run_benchmark(num_iters=num_iters, profile=False)
print(f"Kernel running time: {latency * 1000000:.3f} us")
if __name__ == '__main__':
def to_torch_dtype(dt):
if dt == "int8":
return torch.int8
if dt == "fp8":
return torch.float8_e4m3fn
raise ValueError(f"Unsupported dtype: {dt}")
parser = FlexibleArgumentParser(
description="Benchmark the quantization (fp8 or int8) kernel.")
parser.add_argument("--num-tokens", type=int, default=4096)
parser.add_argument("--hidden-size", type=int, default=8192)
parser.add_argument("--static-scale", action="store_true")
parser.add_argument("--quant-dtype",
type=str,
choices=["fp8", "int8"],
default="int8")
parser.add_argument("--dtype",
type=str,
choices=["half", "bfloat16", "float"],
default="half")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--profile", action="store_true")
parser.add_argument("--num-warmup-iters", type=int, default=5)
parser.add_argument("--num-iters",
type=int,
default=100,
help="Number of benchmark iterations. "
"If --profile is set, this number is ignored")
args = parser.parse_args()
print(args)
main(num_tokens=args.num_tokens,
hidden_size=args.hidden_size,
static_scale=args.static_scale,
quant_dtype=to_torch_dtype(args.quant_dtype),
dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
seed=args.seed,
do_profile=args.profile,
num_warmup_iters=args.num_warmup_iters,
num_iters=args.num_iters)
import math
import pickle
import re
from collections import defaultdict
from typing import List
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from torch.utils.benchmark import Measurement as TMeasurement
from vllm.utils import FlexibleArgumentParser
if __name__ == "__main__":
parser = FlexibleArgumentParser(
description='Benchmark the latency of processing a single batch of '
'requests till completion.')
parser.add_argument('filename', type=str)
args = parser.parse_args()
with open(args.filename, 'rb') as f:
data: List[TMeasurement] = pickle.load(f)
results = defaultdict(lambda: list())
for v in data:
result = re.search(r"MKN=\(\d+x(\d+x\d+)\)", v.task_spec.sub_label)
if result is not None:
KN = result.group(1)
else:
raise Exception("MKN not found")
result = re.search(r"MKN=\((\d+)x\d+x\d+\)", v.task_spec.sub_label)
if result is not None:
M = result.group(1)
else:
raise Exception("MKN not found")
kernel = v.task_spec.description
results[KN].append({
"kernel": kernel,
"batch_size": M,
"median": v.median
})
rows = int(math.ceil(len(results) / 2))
fig, axs = plt.subplots(rows, 2, figsize=(12, 5 * rows))
axs = axs.flatten()
axs_idx = 0
for shape, data in results.items():
plt.sca(axs[axs_idx])
df = pd.DataFrame(data)
sns.lineplot(data=df,
x="batch_size",
y="median",
hue="kernel",
style="kernel",
markers=True,
dashes=False,
palette="Dark2")
plt.title(f"Shape: {shape}")
plt.ylabel("time (median, s)")
axs_idx += 1
plt.tight_layout()
plt.savefig("graph_machete_bench.pdf")
# Weight Shapes are in the format
# ([K, N], TP_SPLIT_DIM)
# Example:
# A shape of ([14336, 4096], 0) indicates the following GEMM shape,
# - TP1 : K = 14336, N = 4096
# - TP2 : K = 7168, N = 4096
# A shape of ([4096, 6144], 1) indicates the following GEMM shape,
# - TP1 : K = 4096, N = 6144
# - TP4 : K = 4096, N = 1536
# TP1 shapes
WEIGHT_SHAPES = {
"mistralai/Mistral-7B-v0.1": [
([4096, 6144], 1),
([4096, 4096], 0),
([4096, 28672], 1),
([14336, 4096], 0),
],
"meta-llama/Llama-2-7b-hf": [
([4096, 12288], 1),
([4096, 4096], 0),
([4096, 22016], 1),
([11008, 4096], 0),
],
"meta-llama/Llama-3-8b": [
([4096, 6144], 1),
([4096, 4096], 0),
([4096, 28672], 1),
([14336, 4096], 0),
],
"meta-llama/Llama-2-13b-hf": [
([5120, 15360], 1),
([5120, 5120], 0),
([5120, 27648], 1),
([13824, 5120], 0),
],
"meta-llama/Llama-2-70b-hf": [
([8192, 10240], 1),
([8192, 8192], 0),
([8192, 57344], 1),
([28672, 8192], 0),
],
}
......@@ -66,6 +66,8 @@ DEFAULT_CONDA_PATTERNS = {
"nccl",
"transformers",
"zmq",
"nvidia",
"pynvml",
}
DEFAULT_PIP_PATTERNS = {
......@@ -79,6 +81,8 @@ DEFAULT_PIP_PATTERNS = {
"nccl",
"transformers",
"zmq",
"nvidia",
"pynvml",
}
......@@ -265,8 +269,9 @@ def get_neuron_sdk_version(run_lambda):
def get_vllm_version():
try:
import vllm
return vllm.__version__
except ImportError:
return vllm.__version__ + "@" + vllm.__commit__
except Exception:
# old version of vllm does not have __commit__
return 'N/A'
......
......@@ -122,7 +122,7 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
A_vec qk_vec = mul<A_vec, Vec, Vec>(q[0], k[0]);
#pragma unroll
for (int ii = 1; ii < N; ++ii) {
qk_vec = fma(q[ii], k[ii], qk_vec);
qk_vec = vllm::fma(q[ii], k[ii], qk_vec);
}
// Finalize the reduction across lanes.
float qk = sum(qk_vec);
......
......@@ -21,7 +21,7 @@ namespace vllm {
//
class ScalarType {
public:
enum NanRepr : int64_t {
enum NanRepr : uint8_t {
NAN_NONE = 0, // nans are not supported
NAN_IEEE_754 = 1, // nans are: exp all 1s, mantissa not all 0s
NAN_EXTD_RANGE_MAX_MIN = 2, // nans are: exp all 1s, mantissa all 1s
......@@ -29,33 +29,33 @@ class ScalarType {
NAN_REPR_ID_MAX
};
constexpr ScalarType(bool signed_, int64_t exponent, int64_t mantissa,
int64_t bias, bool finite_values_only = false,
constexpr ScalarType(uint8_t exponent, uint8_t mantissa, bool signed_,
int32_t bias, bool finite_values_only = false,
NanRepr nan_repr = NAN_IEEE_754)
: exponent(exponent),
mantissa(mantissa),
bias(bias),
signed_(signed_),
bias(bias),
finite_values_only(finite_values_only),
nan_repr(nan_repr){};
static constexpr ScalarType int_(int64_t size_bits, int64_t bias = 0) {
return ScalarType(true, 0, size_bits - 1, bias);
static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) {
return ScalarType(0, size_bits - 1, true, bias);
}
static constexpr ScalarType uint(int64_t size_bits, int64_t bias = 0) {
return ScalarType(false, 0, size_bits, bias);
static constexpr ScalarType uint(uint8_t size_bits, int32_t bias = 0) {
return ScalarType(0, size_bits, false, bias);
}
// IEEE 754 compliant floating point type
static constexpr ScalarType float_IEEE754(int64_t exponent,
int64_t mantissa) {
static constexpr ScalarType float_IEEE754(uint8_t exponent,
uint8_t mantissa) {
TORCH_CHECK(mantissa > 0 && exponent > 0);
return ScalarType(true, exponent, mantissa, 0, false, NAN_IEEE_754);
return ScalarType(exponent, mantissa, true, 0, false, NAN_IEEE_754);
}
// IEEE 754 non-compliant floating point type
static constexpr ScalarType float_(int64_t exponent, int64_t mantissa,
static constexpr ScalarType float_(uint8_t exponent, uint8_t mantissa,
bool finite_values_only,
NanRepr nan_repr) {
TORCH_CHECK(nan_repr < NAN_REPR_ID_MAX, "Invalid NanRepr");
......@@ -63,36 +63,121 @@ class ScalarType {
TORCH_CHECK(nan_repr != NAN_IEEE_754,
"use `float_IEEE754` constructor for floating point types that "
"follow IEEE 754 conventions");
return ScalarType(true, exponent, mantissa, 0, finite_values_only,
return ScalarType(exponent, mantissa, true, 0, finite_values_only,
nan_repr);
}
int64_t const exponent; // size of the exponent field (0 for integer types)
int64_t const mantissa; // size of the mantissa field (size of the integer
uint8_t const exponent; // size of the exponent field (0 for integer types)
uint8_t const mantissa; // size of the mantissa field (size of the integer
// excluding the sign bit for integer types)
int64_t const bias; // stored values equal value + bias,
// used for quantized type
bool const signed_; // flag if the type supports negative numbers (i.e. has a
// sign bit)
int32_t const bias; // stored values equal value + bias,
// used for quantized type
// Extra Floating point info
bool const finite_values_only; // i.e. no +/-inf if true
NanRepr const nan_repr; // how NaNs are represented
// (not applicable for integer types)
int64_t size_bits() const { return mantissa + exponent + is_signed(); }
bool is_signed() const { return signed_; }
bool is_integer() const { return exponent == 0; }
bool is_floating_point() const { return exponent > 0; }
bool is_ieee_754() const {
using Id = int64_t;
private:
// Field size in id
template <typename T_>
static constexpr size_t member_id_field_width() {
using T = std::decay_t<T_>;
return std::is_same_v<T, bool> ? 1 : sizeof(T) * 8;
}
template <typename Fn, typename Init, typename Member, typename... Rest>
static constexpr auto reduce_members_helper(Fn f, Init val, Member member,
Rest... rest) {
auto new_val = f(val, member);
if constexpr (sizeof...(rest) > 0) {
return reduce_members_helper(f, new_val, rest...);
} else {
return new_val;
};
}
template <typename Fn, typename Init>
constexpr auto reduce_members(Fn f, Init init) const {
// Should be in constructor order for `from_id`
return reduce_members_helper(f, init, exponent, mantissa, signed_, bias,
finite_values_only, nan_repr);
};
template <typename Fn, typename Init>
static constexpr auto reduce_member_types(Fn f, Init init) {
constexpr auto dummy_type = ScalarType(0, 0, false, 0, false, NAN_NONE);
return dummy_type.reduce_members(f, init);
};
static constexpr auto id_size_bits() {
return reduce_member_types(
[](int acc, auto member) -> int {
return acc + member_id_field_width<decltype(member)>();
},
0);
}
public:
// unique id for this scalar type that can be computed at compile time for
// c++17 template specialization this is not needed once we migrate to
// c++20 and can pass literal classes as template parameters
constexpr Id id() const {
static_assert(id_size_bits() <= sizeof(Id) * 8,
"ScalarType id is too large to be stored");
auto or_and_advance = [](std::pair<Id, uint32_t> result,
auto member) -> std::pair<Id, uint32_t> {
auto [id, bit_offset] = result;
auto constexpr bits = member_id_field_width<decltype(member)>();
return {id | (int64_t(member) & ((uint64_t(1) << bits) - 1))
<< bit_offset,
bit_offset + bits};
};
return reduce_members(or_and_advance, std::pair<Id, uint32_t>{}).first;
}
// create a ScalarType from an id, for c++17 template specialization,
// this is not needed once we migrate to c++20 and can pass literal
// classes as template parameters
static constexpr ScalarType from_id(Id id) {
auto extract_and_advance = [id](auto result, auto member) {
using T = decltype(member);
auto [tuple, bit_offset] = result;
auto constexpr bits = member_id_field_width<T>();
auto extracted_val = static_cast<T>((int64_t(id) >> bit_offset) &
((uint64_t(1) << bits) - 1));
auto new_tuple = std::tuple_cat(tuple, std::make_tuple(extracted_val));
return std::pair<decltype(new_tuple), int>{new_tuple, bit_offset + bits};
};
auto [tuple_args, _] = reduce_member_types(extract_and_advance,
std::pair<std::tuple<>, int>{});
return std::apply([](auto... args) { return ScalarType(args...); },
tuple_args);
}
constexpr int64_t size_bits() const {
return mantissa + exponent + is_signed();
}
constexpr bool is_signed() const { return signed_; }
constexpr bool is_integer() const { return exponent == 0; }
constexpr bool is_floating_point() const { return exponent > 0; }
constexpr bool is_ieee_754() const {
return is_floating_point() && finite_values_only == false &&
nan_repr == NAN_IEEE_754;
}
bool has_nans() const { return is_floating_point() && nan_repr != NAN_NONE; }
bool has_infs() const {
constexpr bool has_nans() const {
return is_floating_point() && nan_repr != NAN_NONE;
}
constexpr bool has_infs() const {
return is_floating_point() && finite_values_only == false;
}
bool has_bias() const { return bias != 0; }
constexpr bool has_bias() const { return bias != 0; }
private:
double _floating_point_max() const {
......@@ -132,7 +217,7 @@ class ScalarType {
return *reinterpret_cast<double*>(&double_raw);
}
std::variant<int64_t, double> _raw_max() const {
constexpr std::variant<int64_t, double> _raw_max() const {
if (is_floating_point()) {
return {_floating_point_max()};
} else {
......@@ -142,7 +227,7 @@ class ScalarType {
}
}
std::variant<int64_t, double> _raw_min() const {
constexpr std::variant<int64_t, double> _raw_min() const {
if (is_floating_point()) {
TORCH_CHECK(is_signed(),
"We currently assume all floating point types are signed");
......@@ -169,7 +254,7 @@ class ScalarType {
public:
// Max representable value for this scalar type.
// (accounting for bias if there is one)
std::variant<int64_t, double> max() const {
constexpr std::variant<int64_t, double> max() const {
return std::visit(
[this](auto x) -> std::variant<int64_t, double> { return {x - bias}; },
_raw_max());
......@@ -177,7 +262,7 @@ class ScalarType {
// Min representable value for this scalar type.
// (accounting for bias if there is one)
std::variant<int64_t, double> min() const {
constexpr std::variant<int64_t, double> min() const {
return std::visit(
[this](auto x) -> std::variant<int64_t, double> { return {x - bias}; },
_raw_min());
......@@ -216,7 +301,7 @@ class ScalarType {
}
}
bool operator==(ScalarType const& other) const {
constexpr bool operator==(ScalarType const& other) const {
return mantissa == other.mantissa && exponent == other.exponent &&
bias == other.bias && signed_ == other.signed_ &&
finite_values_only == other.finite_values_only &&
......@@ -229,6 +314,8 @@ class ScalarType {
// have ScalarType inherit from torch::CustomClassHolder and have a constexpr
// constructor at the same time (torch::CustomClassHolder does not have a
// constexpr destructor)
// See also:
// https://docs.google.com/document/d/18fBMPuOJ0fY5ZQ6YyrHUppw9FA332CpNtgB6SOIgyuA
class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
public:
ScalarTypeTorch(int64_t exponent, int64_t mantissa, int64_t bias,
......@@ -241,31 +328,90 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
using Self = ScalarTypeTorch;
using SelfPtr = c10::intrusive_ptr<Self>;
static void check_size_bits(int64_t size_bits, bool signed_) {
TORCH_CHECK(
size_bits <=
std::numeric_limits<decltype(std::declval<Self>().mantissa)>::max(),
"size_bits bit width is too large to be represented");
}
static void check_bias(int64_t bias) {
using Bias = decltype(std::declval<Self>().bias);
TORCH_CHECK(bias <= std::numeric_limits<Bias>::max() &&
bias >= std::numeric_limits<Bias>::min(),
"bias too large or small to be represented");
}
static void check_exponent(int64_t exponent) {
TORCH_CHECK(
exponent <=
std::numeric_limits<decltype(std::declval<Self>().exponent)>::max(),
"exponent bit width is too large to be represented");
}
static void check_mantissa(int64_t mantissa) {
TORCH_CHECK(
mantissa <=
std::numeric_limits<decltype(std::declval<Self>().mantissa)>::max(),
"mantissa bit width is too large to be represented");
}
static SelfPtr int_(int64_t size_bits, c10::optional<int64_t> bias) {
check_size_bits(size_bits, true);
check_bias(bias.value_or(0));
return c10::make_intrusive<Self>(
ScalarType::int_(size_bits, bias.value_or(0)));
}
static SelfPtr uint(int64_t size_bits, c10::optional<int64_t> bias) {
check_size_bits(size_bits, true);
check_bias(bias.value_or(0));
return c10::make_intrusive<Self>(
ScalarType::uint(size_bits, bias.value_or(0)));
}
static SelfPtr float_IEEE754(int64_t exponent, int64_t mantissa) {
check_mantissa(mantissa);
check_exponent(exponent);
return c10::make_intrusive<Self>(
ScalarType::float_IEEE754(exponent, mantissa));
}
static SelfPtr float_(int64_t exponent, int64_t mantissa,
bool finite_values_only, int64_t nan_repr) {
check_mantissa(mantissa);
check_exponent(exponent);
return c10::make_intrusive<Self>(ScalarType::float_(
exponent, mantissa, finite_values_only, NanRepr(nan_repr)));
}
// This needs to be implemented and throw a TypeError in order for
// PyTorch's opcheck to work on ops that use ScalarTypes.
int64_t len() const {
throw c10::TypeError("__len__ not implemented");
return 0;
}
// Serialize a ScalarType into a tuple of pairs. Where each pair
// is a (fieldname, value).
// For simplicity, we are just going to convert to a ScalarTypeId.
std::tuple<std::tuple<std::string, int64_t>> obj_flatten() const {
return {{"ScalarType", id()}};
}
// Deserialize a scalar type that has been serialized by obj_flatten,
// ostensibly from a tuple of (member name, value) pairs, but in reality
// just a ScalarTypeId.
static SelfPtr obj_unflatten(
std::tuple<std::tuple<std::string, int64_t>> const& flat_type) {
return c10::make_intrusive<Self>(
from_id(std::get<1>(std::get<0>(flat_type))));
}
template <typename T>
static void bind_readonly_property(torch::class_<Self>& cls,
std::string const& name, T Base::*field) {
auto getter_func = [field = std::move(field)](SelfPtr const& self) {
auto getter_func_helper = [field = std::move(field)](SelfPtr const& self) {
if constexpr (std::is_member_function_pointer_v<decltype(field)>) {
return (self.get()->*field)();
} else {
......@@ -273,6 +419,18 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
}
};
auto getter_func = [field = std::move(field),
getter_func_helper = std::move(getter_func_helper)](
SelfPtr const& self) {
auto val = getter_func_helper(self);
// upconvert uint8_t, int32_t etc. to int64_t for python
if constexpr (std::is_integral_v<T>) {
return static_cast<int64_t>(val);
} else {
return val;
}
};
cls.def_property(name, getter_func);
}
......@@ -325,6 +483,7 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
self.get()->min());
});
bind_function(cls, "__len__", &ScalarTypeTorch::len);
bind_function(cls, "__str__", &Base::str);
bind_function(cls, "__eq__", [](SelfPtr const& self, SelfPtr const& other) {
return *self == *other;
......@@ -333,6 +492,10 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
return "ScalarType." + self.get()->str();
});
bind_function(cls, "__obj_flatten__", &ScalarTypeTorch::obj_flatten);
bind_static_function(cls, "__obj_unflatten__",
&ScalarTypeTorch::obj_unflatten);
// Bind static functions (convenience constructors)
bind_static_function(cls, "int_", &ScalarTypeTorch::int_);
bind_static_function(cls, "uint", &ScalarTypeTorch::uint);
......@@ -341,6 +504,7 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
}
};
using ScalarTypeId = int64_t;
using ScalarTypeTorchPtr = c10::intrusive_ptr<ScalarTypeTorch>;
// "rust style" names generally following:
......@@ -380,4 +544,5 @@ static inline constexpr auto kHalf = kFE5M10;
static inline constexpr auto kFloat16 = kHalf;
static inline constexpr auto kBFloat16 = kFE8M7;
static inline constexpr auto kFloat16Id = kFloat16.id();
}; // namespace vllm
#pragma once
#if defined(__CUDACC__) || defined(_NVHPC_CUDA)
#define HOST_DEVICE_INLINE __forceinline__ __host__ __device__
#define DEVICE_INLINE __forceinline__ __device__
#define HOST_INLINE __forceinline__ __host__
#else
#define HOST_DEVICE_INLINE inline
#define DEVICE_INLINE inline
#define HOST_INLINE inline
#endif
int64_t get_device_attribute(int64_t attribute, int64_t device_id);
int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id);
#pragma once
#include <cute/tensor.hpp>
#include <torch/all.h>
namespace cute {
////////////////////////////////////////////////////////////////////
// layout utils
////////////////////////////////////////////////////////////////////
// Permute layout based on indices, example:
// permute_layout<1, 0>(layout) will swap the two dimensions
// permute_layout<0, 2, 1>(layout) will swap the last two dimensions
template <size_t... I, typename Layout>
CUTE_HOST_DEVICE static constexpr auto permute_layout(Layout l) {
static_assert(rank(l) == sizeof...(I), "Invalid permutation, rank mismatch");
return cute::make_layout(cute::get<I>(l)...);
}
// is the layout f(x) = x
template <typename Layout>
CUTE_HOST_DEVICE static constexpr bool is_identity_layout() {
if constexpr (std::is_same_v<Layout, void>)
return true;
else {
constexpr auto coalesced_layout = coalesce(Layout{});
if constexpr (rank(coalesced_layout) == 1 &&
stride<0>(coalesced_layout) == 1) {
return true;
}
return false;
}
}
////////////////////////////////////////////////////////////////////
// Pointer utils
////////////////////////////////////////////////////////////////////
template <class PointerType>
static constexpr auto get_logical_ptr(PointerType* ptr) {
if constexpr (cute::sizeof_bits_v<PointerType> < 8) {
return cute::subbyte_iterator<PointerType>(ptr);
} else {
return ptr;
}
}
////////////////////////////////////////////////////////////////////
// Misc utils
////////////////////////////////////////////////////////////////////
template <typename T, typename Elements>
CUTE_HOST_DEVICE static constexpr auto create_auto_vectorizing_copy() {
constexpr auto bits = sizeof_bits_v<T> * Elements{};
if constexpr (bits % 128 == 0) {
return AutoVectorizingCopyWithAssumedAlignment<128>{};
} else if constexpr (bits % 64 == 0) {
return AutoVectorizingCopyWithAssumedAlignment<64>{};
} else if constexpr (bits % 32 == 0) {
return AutoVectorizingCopyWithAssumedAlignment<32>{};
} else if constexpr (bits % 16 == 0) {
return AutoVectorizingCopyWithAssumedAlignment<16>{};
} else {
return AutoVectorizingCopyWithAssumedAlignment<8>{};
}
}
}; // namespace cute
#pragma once
#include <torch/all.h>
#include "cute/layout.hpp"
#include "cutlass/layout/matrix.h"
#include "cutlass/bfloat16.h"
#include "cutlass/half.h"
using ColumnMajor = typename cutlass::layout::ColumnMajor;
using RowMajor = typename cutlass::layout::RowMajor;
namespace cute {
namespace detail {
template <class T, class F, class G, int... I>
CUTE_HOST_DEVICE constexpr auto tapply_with_idx(T&& t, F&& f, G&& g,
seq<I...>) {
return g(f(cute::get<I>(static_cast<T&&>(t)), I)...);
}
template <class F, int... I>
CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f, seq<I...>) {
return make_shape(f(I)...);
}
}; // namespace detail
template <class T, class F>
CUTE_HOST_DEVICE constexpr auto transform_with_idx(T const& t, F&& f) {
if constexpr (cute::is_tuple<T>::value) {
return detail::tapply_with_idx(
t, f, [](auto const&... a) { return cute::make_tuple(a...); },
tuple_seq<T>{});
} else {
return f(t);
}
CUTE_GCC_UNREACHABLE;
}
// calls: make_shape(f(0), f(1), ..., f(N-1))
template <int N, class F>
CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f) {
return detail::make_shape_from_idx(f, make_seq<N>{});
}
}; // namespace cute
// Make a layout from a tensor with `rank(Stride{})`, where the shape is the
// shape of the passed in tensor and the strides are of type `Stride` and
// contain the strides of the passed in tensor, checking that any static strides
// in `Stride{}` match the strides of the passed in tensor.
// If `tensor.dim() < rank(Stride{})`, the shape is padded with 1s and the extra
// strides are set to be 0 or 1.
template <typename Stride>
static inline auto make_cute_layout(torch::Tensor const& tensor,
std::string_view name = "tensor") {
TORCH_CHECK(tensor.dim() <= rank(Stride{}));
auto stride = cute::transform_with_idx(
Stride{}, [&](auto const& stride_ele, auto const& idx) {
using StrideEle = std::decay_t<decltype(stride_ele)>;
if (idx < tensor.dim()) {
if constexpr (cute::is_static_v<StrideEle>) {
TORCH_CHECK(StrideEle::value == tensor.stride(idx), "Expected ",
name, ".stride(", idx, ") to be ", StrideEle::value);
return StrideEle{};
} else {
return tensor.stride(idx);
}
} else {
// Extra strides are assumed to be 0 or 1
if constexpr (cute::is_static_v<StrideEle>) {
static_assert(StrideEle::value == 0 || StrideEle::value == 1);
}
return StrideEle{};
}
});
auto shape = cute::make_shape_from_idx<rank(Stride{})>([&](auto const& idx) {
if (idx < tensor.dim())
return tensor.size(idx);
else
return int64_t(1);
});
return make_layout(shape, stride);
}
template <typename Stride>
static inline auto maybe_make_cute_layout(
c10::optional<torch::Tensor> const& tensor,
std::string_view name = "tensor") {
using Layout = decltype(make_cute_layout<Stride>(*tensor));
if (tensor) {
return std::optional<Layout>{make_cute_layout<Stride>(*tensor, name)};
} else {
return std::optional<Layout>{};
}
}
//
// Torch Type to Cutlass Type (equivalent_cutlass_type)
//
template <typename T>
struct equivalent_cutlass_type {
using type = T;
};
template <typename T>
using equivalent_cutlass_type_t = typename equivalent_cutlass_type<T>::type;
template <>
struct equivalent_cutlass_type<c10::Half> {
using type = cutlass::half_t;
};
template <>
struct equivalent_cutlass_type<c10::BFloat16> {
using type = cutlass::bfloat16_t;
};
//
// equivalent_scalar_t (basically inverse of equivalent_cutlass_type)
//
// Return a `c10::CppTypeToScalarType<T>` compatible type, i.e. get the C++ from
// c10 that is equivalent to T, e.g.: `cutlass::half_t -> c10::Half`
template <typename T>
struct equivalent_scalar_type {
using type = T;
};
template <typename T>
using equivalent_scalar_type_t = typename equivalent_scalar_type<T>::type;
template <>
struct equivalent_scalar_type<cutlass::half_t> {
using type = c10::Half;
};
template <>
struct equivalent_scalar_type<cutlass::bfloat16_t> {
using type = c10::BFloat16;
};
// get equivalent c10::ScalarType tag from compile time type
template <typename T>
static inline constexpr c10::ScalarType equivalent_scalar_type_v =
c10::CppTypeToScalarType<equivalent_scalar_type_t<T>>::value;
\ No newline at end of file
#pragma once
#include "cutlass/gemm/collective/collective_builder.hpp"
namespace cutlass::gemm::collective {
using namespace cute;
//
// VLLMCollectiveBuilder is a wrapper around CollectiveBuilder that allows for
// for custom kernel tags, allowing you to build custom collectives. Without
// touching the cutlass library headers, using `CutlassKernelTag` will mean it
// will resort to using the standard cutlass collective builder.
//
// Use the default Cutlass collective builder, i.e. use an unmodified cutless
// collective
struct CutlassKernelTag {};
template <class KernelTag, class ArchTag, class OpClass, class ElementA,
class GmemLayoutA, int AlignmentA, class ElementB, class GmemLayoutB,
int AlignmentB, class ElementAccumulator, class TileShape_MNK,
class ClusterShape_MNK, class StageCountType,
class KernelScheduleType, class Enable = void>
struct VLLMCollectiveBuilder {
static_assert(sizeof(ElementA) == 0,
"Could not build a collective for given parameters.");
};
template <class ArchTag, class OpClass, class ElementA, class GmemLayoutA,
int AlignmentA, class ElementB, class GmemLayoutB, int AlignmentB,
class ElementAccumulator, class TileShape_MNK, class ClusterShape_MNK,
class StageCountType, class KernelScheduleType>
struct VLLMCollectiveBuilder<
CutlassKernelTag, ArchTag, OpClass, ElementA, GmemLayoutA, AlignmentA,
ElementB, GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK,
ClusterShape_MNK, StageCountType, KernelScheduleType> {
using CollectiveOp = typename CollectiveBuilder<
ArchTag, OpClass, ElementA, GmemLayoutA, AlignmentA, ElementB,
GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK,
ClusterShape_MNK, StageCountType, KernelScheduleType>::CollectiveOp;
};
}; // namespace cutlass::gemm::collective
\ No newline at end of file
#pragma once
#include "cutlass/integer_subbyte.h"
namespace cutlass {
///////////////////////////////////////////////////////////////////////////////////////////////////
template <int Bits, int Bias, bool Signed = false>
struct vllm_biased_integer_subbyte : public integer_subbyte<Bits, Signed> {
using Base = integer_subbyte<Bits, Signed>;
using Storage = typename Base::Storage;
using xint_t = typename Base::xint_t;
using Base::bits_mask_;
using Base::sign_mask_;
using Base::storage;
//
// Methods
//
/// No operation
vllm_biased_integer_subbyte() = default;
/// Conversion from integer type
CUTLASS_HOST_DEVICE explicit vllm_biased_integer_subbyte(int value)
: Base(value) {}
CUTLASS_HOST_DEVICE explicit vllm_biased_integer_subbyte(unsigned value)
: Base(value) {}
CUTLASS_HOST_DEVICE explicit vllm_biased_integer_subbyte(double value)
: Base(value) {}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
// "GPTQ" types, i.e. symmetric quantization
using vllm_uint4b8_t = vllm_biased_integer_subbyte<4, 8>; // u4b8
using vllm_uint8b128_t = vllm_biased_integer_subbyte<8, 128>; // u8b128
///////////////////////////////////////////////////////////////////////////////////////////////////
template <int Bits, int Bias, bool Signed>
struct sizeof_bits<vllm_biased_integer_subbyte<Bits, Bias, Signed>> {
static constexpr int value = Bits;
};
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass
import enum
from typing import Dict, Union
from cutlass_library import *
#
# Extend cutlass library with custom types, and missing values
#
class VLLMDataType(enum.Enum):
u4b8 = enum_auto()
u8b128 = enum_auto()
class MixedInputKernelScheduleType(enum.Enum):
TmaWarpSpecializedMixedInput = enum_auto()
TmaWarpSpecializedPingpongMixedInput = enum_auto()
TmaWarpSpecializedCooperativeMixedInput = enum_auto()
VLLMDataTypeNames: Dict[Union[VLLMDataType, DataType], str] = {
**DataTypeNames, # type: ignore
**{
VLLMDataType.u4b8: "u4b8",
VLLMDataType.u8b128: "u8b128",
}
}
VLLMDataTypeTag: Dict[Union[VLLMDataType, DataType], str] = {
**DataTypeTag, # type: ignore
**{
VLLMDataType.u4b8: "cutlass::vllm_uint4b8_t",
VLLMDataType.u8b128: "cutlass::vllm_uint8b128_t",
}
}
VLLMKernelScheduleTag: Dict[Union[
MixedInputKernelScheduleType, KernelScheduleType], str] = {
**KernelScheduleTag, # type: ignore
**{
MixedInputKernelScheduleType.TmaWarpSpecializedMixedInput:
"cutlass::gemm::KernelTmaWarpSpecializedMixedInput",
MixedInputKernelScheduleType.TmaWarpSpecializedPingpongMixedInput:
"cutlass::gemm::KernelTmaWarpSpecializedPingpongMixedInput",
MixedInputKernelScheduleType.TmaWarpSpecializedCooperativeMixedInput:
"cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput",
}
}
#pragma once
#include "cutlass/numeric_conversion.h"
#include "cutlass_extensions/vllm_custom_types.cuh"
#include "cutlass_extensions/cute_utils.cuh"
// this file extends:
// https://github.com/NVIDIA/cutlass/blob/cutlass-3.5.0/include/cutlass/numeric_conversion.h
// with vllm specific type conversions, namely: vllm_uint4b8_t, vllm_uint8b128_t
// as well as adds interleaved numeric array converters for specific types.
// (interleaved numeric array converters can be more efficient for subbyte
// types)
namespace cutlass {
// InterleavedNumericArrayConverter is like NumericArrayConverter but also
// deinterleaves converted elements based on IlvBlkLayout, interleaving can
// make subbyte converts more efficient by allowing for efficient extraction
// of subbyte elements from a 32bit register.
template <typename IlvBlkLayout, typename T, typename S, int N,
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest,
class Enable = void>
struct InterleavedNumericArrayConverter {
using Converter = NumericArrayConverter<T, S, N, Round>;
using result_type = typename Converter::result_type;
using source_type = typename Converter::source_type;
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
CUTE_INVALID_CONTROL_PATH(
"InterleavedNumericArrayConverter not implemented\n");
return {};
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
template <typename IlvBlkLayout, typename T, typename S, int N,
FloatRoundStyle Round>
struct InterleavedNumericArrayConverter<
IlvBlkLayout, T, S, N, Round,
std::enable_if_t<is_identity_layout<IlvBlkLayout>()>> {
using Converter = NumericArrayConverter<T, S, N, Round>;
using result_type = typename Converter::result_type;
using source_type = typename Converter::source_type;
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
return Converter::convert(source);
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
// TODO (LucasWilkinson): Implement
// for Array<cutlass::float8_e4m3fn, N> <= Array<vllm_uint4b8_t, N>
// ....
template <typename RegConvert32bit, typename T, typename S, int N>
struct ArrayConverterPacked32Bit {
using result_type = Array<T, N>;
using source_type = Array<S, N>;
using result_packed_8_t = Array<T, 8>;
using result_packed_4_t = Array<T, 4>;
using result_packed_2_t = Array<T, 2>;
using src_packed_8_t = Array<S, 8>;
using src_packed_4_t = Array<S, 4>;
using src_packed_2_t = Array<S, 2>;
static_assert(N % 2 == 0, "N must be a multiple of 2");
static_assert(cutlass::sizeof_bits_v<S> >= 4); // TODO: add 16 packed sources
static_assert(32 % cutlass::sizeof_bits_v<S> == 0);
static constexpr auto src_elems_per_32bit_reg =
32 / cutlass::sizeof_bits_v<S>;
// Maybe not Valid. ScalarConverter will not actually work unless
// NumericConverter<T, S, Round> is implemented. However it won't be used
// anyways since we assert N % 2 == 0, just here for compliance with
// VectorizedConverter.
using ScalarConverter = NumericConverter<T, S>;
template <typename PackedSrc>
CUTLASS_DEVICE static uint32_t to_reg(PackedSrc const& source) {
if constexpr (sizeof(PackedSrc) == 1) {
return static_cast<uint32_t>(reinterpret_cast<const uint8_t&>(source));
} else if constexpr (sizeof(PackedSrc) == 2) {
return static_cast<uint32_t>(reinterpret_cast<const uint16_t&>(source));
} else {
static_assert(sizeof(PackedSrc) == 4);
return reinterpret_cast<const uint32_t&>(source);
}
}
// The core converter uses bit tricks to construct a known FP16 number, then
// does a subtraction in FP16 for the final result.
template <typename PackedResultType, typename PackedSrcType>
CUTLASS_DEVICE static PackedResultType packed_convert(
PackedSrcType const& source) {
static_assert(PackedSrcType::kElements == PackedResultType::kElements);
static_assert(PackedResultType::kElements == 2 ||
PackedResultType::kElements == 4 ||
PackedResultType::kElements == 8,
"Invalid PackedResultType must be 2, 4 or 8.");
static_assert(std::is_same_v<typename PackedSrcType::Element, S>);
static_assert(std::is_same_v<typename PackedResultType::Element, T>);
return RegConvert32bit::template convert<PackedResultType>(to_reg(source));
}
friend class detail::VectorizedConverter;
public:
CUTLASS_DEVICE static result_type convert(source_type const& source) {
result_type result;
using ConverterType =
ArrayConverterPacked32Bit<RegConvert32bit,
typename result_type::Element,
typename source_type::Element, N>;
if constexpr (src_elems_per_32bit_reg >= 8) {
detail::VectorizedConverter::convert<
ConverterType, result_packed_8_t, src_packed_8_t, result_packed_4_t,
src_packed_4_t, result_packed_2_t, src_packed_2_t>(result, source);
} else if constexpr (src_elems_per_32bit_reg >= 4) {
detail::VectorizedConverter::convert<ConverterType, result_packed_4_t,
src_packed_4_t, result_packed_2_t,
src_packed_2_t>(result, source);
} else {
detail::VectorizedConverter::convert<ConverterType, result_packed_2_t,
src_packed_2_t>(result, source);
}
return result;
}
};
// for Array<cutlass::half_t, N> <= Array<vllm_uint4b8_t, N>
template <FloatRoundStyle Round, int N>
struct NumericArrayConverter<cutlass::half_t, vllm_uint4b8_t, N, Round> {
using result_type = Array<cutlass::half_t, N>;
using source_type = Array<vllm_uint4b8_t, N>;
struct RegConvert {
template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
using RegArray =
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
sizeof(PackedResultType)>;
RegArray r;
// Below constructs the following temporary:
// fp16s_01 = {0x00, i4_01, 0x00, i4_01}
// fp16s_23 = {0x00, i4_23, 0x00, i4_23}
// fp16s_45 = {0x00, i4_45, 0x00, i4_45}
// fp16s_67 = {0x00, i4_67, 0x00, i4_67}
// We use inline asm instead of __byte_perm intrinsic since we don't want
// the documented (& 0x7) on the index. NVCC might be able to optimize it
// out since the index is a constexpr, but we choose to be safe about it
// here.
uint32_t prmt_indices[4] = {0x4040, 0x4141, 0x4242, 0x4343};
static_assert(RegArray::kElements <= 4,
"Too many inputs for F16 -> I4 vector converter");
CUTLASS_PRAGMA_UNROLL
for (int ii = 0; ii < RegArray::kElements; ++ii) {
asm volatile(
"{\n"
" prmt.b32 %0, %1, %2, %3;\n"
"}\n"
: "=r"(r[ii])
: "r"(src), "n"(0), "r"(prmt_indices[ii]));
}
// Since the stored 4bit values are biased by 8 we get stored_val = (x+8)
// we are trying to construct x and a fp16 value
// The below XOR does the following:
// 1) Sets the exponent bits of the FP16 to the correct value for the
// FP16 magic_num. We will be constructing {1024+16*(x1+8), 1024+(x0+8)},
// where x1 in the high nibble and x0 is the low nibble then using hfma
// to subtract 1032 from that
// The AND does the following:
// 1) Clear the set bits for the int4 we will ignore.
// We use lop3 so that we can use 1 instruction for AND and XOR.
static constexpr uint32_t xor_mask = 0x64006400;
static constexpr uint32_t and_mask = 0xFFF0FF0F;
static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa;
// For each operand, computes:
// r[i] = (r[i] & and_mask) ^ xor_mask
CUTLASS_PRAGMA_UNROLL
for (int ii = 0; ii < RegArray::kElements; ++ii) {
asm volatile(
"{\n"
" lop3.b32 %0, %0, %1, %2, %3;\n"
"}\n"
: "+r"(r[ii])
: "n"(and_mask), "n"(xor_mask), "n"(immLut));
}
// We will issue 2 hfmas that do the following:
// {x1, x0} = {1024+16*(x1+8), 1024+(x0+8)} * {1/16, 1} - {72, 1032}
// = {x1 + 1152, x0 + 1032} * {1/16, 1} - {72, 1032}
static constexpr uint32_t hfma_bias_rep = 0xD480E408; // {72, 1032}
static constexpr uint32_t hfma_scale_rep = 0x2C003C00; // {1 / 16, 1}
const half2& hfma_bias = reinterpret_cast<const half2&>(hfma_bias_rep);
const half2& hfma_scale = reinterpret_cast<const half2&>(hfma_scale_rep);
CUTLASS_PRAGMA_UNROLL
for (int ii = 0; ii < RegArray::kElements; ++ii) {
half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]);
fp16x2_val = __hfma2(hfma_scale, fp16x2_val, hfma_bias);
}
return reinterpret_cast<PackedResultType&>(r);
};
};
public:
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
typename source_type::Element,
N>::convert(source);
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
// for Array<cutlass::half_t, N> <= Array<vllm_uint4b8_t, N>
// for IlvdLayout: (2, 4):(4, 1)
template <FloatRoundStyle Round, int N>
struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
cutlass::half_t, vllm_uint4b8_t, N,
Round, void> {
using IlvdLayout = Layout<Shape<_2, _4>, Stride<_4, _1>>;
static_assert(N % size(IlvdLayout{}) == 0);
using result_type = Array<cutlass::half_t, N>;
using source_type = Array<vllm_uint4b8_t, N>;
static FloatRoundStyle const round_style = Round;
private:
struct RegConvert {
template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
using RegArray =
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
sizeof(PackedResultType)>;
RegArray r;
static_assert(PackedResultType::kElements <= size(IlvdLayout{}));
static constexpr uint32_t xor_mask = 0x64006400;
for (int ii = 0; ii < RegArray::kElements; ii += 2) {
auto src_ = src >> (4 * (ii));
r[ii + 0] = src_;
r[ii + 1] = src_;
static constexpr uint32_t and_xor_imm_lut = (0xf0 & 0xcc) ^ 0xaa;
static constexpr uint32_t low_nib_mask = 0x000F000F;
static constexpr uint32_t high_nib_mask = 0x00F000F0;
asm volatile(
"{\n"
" lop3.b32 %0, %0, %1, %2, %3;\n"
"}\n"
: "+r"(r[ii + 0])
: "n"(low_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut));
asm volatile(
"{\n"
" lop3.b32 %0, %0, %1, %2, %3;\n"
"}\n"
: "+r"(r[ii + 1])
: "n"(high_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut));
// For low nibble:
// {x1, x0} = {1024+(x1+8), 1024+(x0+8)} * {1, 1} - {1032, 1032}
// For high nibble:
// {x1, x0} = {1024+16*(x1+8), 1024+16*(x0+8)} * {1/16, 1/16}
// - {72, 72}
static constexpr uint32_t low_nib_bias = 0x64086408; // {1032, 1032}
static constexpr uint32_t high_nib_scale = 0x2C002C00; // {1/16, 1/16}
static constexpr uint32_t high_nib_bias = 0xD480D480; // {-72, -72}
{
half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 0]);
fp16x2_val =
__hsub2(fp16x2_val, reinterpret_cast<const half2&>(low_nib_bias));
}
{
half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 1]);
fp16x2_val = __hfma2(fp16x2_val,
reinterpret_cast<const half2&>(high_nib_scale),
reinterpret_cast<const half2&>(high_nib_bias));
}
}
return reinterpret_cast<PackedResultType&>(r);
};
};
public:
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
typename source_type::Element,
N>::convert(source);
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
// for Array<cutlass::half_t, N> <= Array<uint4_t, N>
// for IlvdLayout: (2, 4):(4, 1)
template <FloatRoundStyle Round, int N>
struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
cutlass::half_t, uint4_t, N, Round,
void> {
using IlvdLayout = Layout<Shape<_2, _4>, Stride<_4, _1>>;
static_assert(N % size(IlvdLayout{}) == 0);
using result_type = Array<cutlass::half_t, N>;
using source_type = Array<uint4_t, N>;
static FloatRoundStyle const round_style = Round;
private:
struct RegConvert {
template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
using RegArray =
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
sizeof(PackedResultType)>;
RegArray r;
static_assert(PackedResultType::kElements <= size(IlvdLayout{}));
static constexpr uint32_t xor_mask = 0x64006400;
for (int ii = 0; ii < RegArray::kElements; ii += 2) {
auto src_ = src >> (4 * (ii));
r[ii + 0] = src_;
r[ii + 1] = src_;
static constexpr uint32_t and_xor_imm_lut = (0xf0 & 0xcc) ^ 0xaa;
static constexpr uint32_t low_nib_mask = 0x000F000F;
static constexpr uint32_t high_nib_mask = 0x00F000F0;
asm volatile(
"{\n"
" lop3.b32 %0, %0, %1, %2, %3;\n"
"}\n"
: "+r"(r[ii + 0])
: "n"(low_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut));
asm volatile(
"{\n"
" lop3.b32 %0, %0, %1, %2, %3;\n"
"}\n"
: "+r"(r[ii + 1])
: "n"(high_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut));
// For low nibble:
// {x1, x0} = {1024+x1, 1024+x0} - {1024, 1024}
// For high nibble:
// {x1, x0} = {1024+16*x1, 1024+16*x0} * {1/16, 1/16} - {64, 64}
static constexpr uint32_t low_nib_bias = 0x64006400; // {1024, 1024}
static constexpr uint32_t high_nib_scale = 0x2C002C00; // {1/16, 1/16}
static constexpr uint32_t high_nib_bias = 0xD400D400; // {-64, -64}
{
half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 0]);
fp16x2_val =
__hsub2(fp16x2_val, reinterpret_cast<const half2&>(low_nib_bias));
}
{
half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 1]);
fp16x2_val = __hfma2(fp16x2_val,
reinterpret_cast<const half2&>(high_nib_scale),
reinterpret_cast<const half2&>(high_nib_bias));
}
}
return reinterpret_cast<PackedResultType&>(r);
};
};
public:
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
typename source_type::Element,
N>::convert(source);
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
// for Array<cutlass::half_t, N> <= Array<vllm_uint8b128_t, N>
template <FloatRoundStyle Round, int N>
struct NumericArrayConverter<cutlass::half_t, vllm_uint8b128_t, N, Round> {
using result_type = Array<cutlass::half_t, N>;
using source_type = Array<vllm_uint8b128_t, N>;
struct RegConvert {
template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
// Hold output FP16s in reg. We need 1 reg for every 2 elements
using RegArray =
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
sizeof(PackedResultType)>;
RegArray r;
uint32_t const prmt_indices[2] = {0x5150, 0x5352};
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
for (int ii = 0; ii < RegArray::kElements; ++ii) {
asm volatile("prmt.b32 %0,%1,%2,%3;\n"
: "=r"(r[ii])
: "r"(src), "n"(start_byte_for_fp16),
"r"(prmt_indices[ii]));
}
// -128 is folded into bias subtraction, i.e. the 0x80 in the low bytes
static constexpr uint32_t bias_rep = 0x64806480;
const half2& bias = reinterpret_cast<const half2&>(bias_rep);
CUTLASS_PRAGMA_UNROLL
for (int ii = 0; ii < RegArray::kElements; ++ii) {
half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]);
fp16x2_val = __hsub2(fp16x2_val, bias);
}
return reinterpret_cast<PackedResultType&>(r);
};
};
public:
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
typename source_type::Element,
N>::convert(source);
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
// for Array<cutlass::float, N> <= Array<vllm_uint8b128_t, N>
template <FloatRoundStyle Round, int N>
struct NumericArrayConverter<float, vllm_uint8b128_t, N, Round> {
using result_type = Array<float, N>;
using source_type = Array<vllm_uint8b128_t, N>;
static FloatRoundStyle const round_style = Round;
private:
struct RegConvert {
template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
PackedResultType r;
// __byte_perm simulates the add.u32 0x4B000000 to every u8 element of
// u8x4 source and stores the result in r (without introducing extra
// cvt.u32.u8 instruction)
uint32_t const prmt_indices[4] = {0x7650, 0x7651, 0x7652, 0x7653};
uint32_t* result_as_int = reinterpret_cast<uint32_t*>(&r);
for (int ii = 0; ii < PackedResultType::kElements; ++ii) {
result_as_int[ii] = __byte_perm(src, 0x4B000000, prmt_indices[ii]);
// Subtract the magic number 0x4B000000 from tmp in floating-point
// arithmetic to obtain final result
r[ii] -= (8388608.f + 128.f); // fold in -128 bias
}
return r;
};
};
public:
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
typename source_type::Element,
N>::convert(source);
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
// for Array<cutlass::bfloat16_t, N> <= Array<vllm_uint4b8_t, N>
template <FloatRoundStyle Round, int N>
struct NumericArrayConverter<cutlass::bfloat16_t, vllm_uint4b8_t, N, Round> {
using result_type = Array<cutlass::bfloat16_t, N>;
using source_type = Array<vllm_uint4b8_t, N>;
static FloatRoundStyle const round_style = Round;
private:
struct RegConvert {
template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(uint32_t src_reg) {
// Hold output BF16s in reg. We need 1 reg for every 2 elements
using RegArray =
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
sizeof(PackedResultType)>;
RegArray r;
uint32_t src_reg_shifted = src_reg >> 4;
// Below constructs the following temporary:
uint32_t const prmt_indices[4] = {0xF4F0, 0xF5F1, 0xF6F2, 0xF7F3};
static_assert(RegArray::kElements <= 4,
"Too many inputs for uint4b8_t -> BF16 vector converter");
CUTLASS_PRAGMA_UNROLL
for (int ii = 0; ii < RegArray::kElements; ++ii) {
asm volatile(
"{\n"
" prmt.b32 %0, %1, %2, %3;\n"
"}\n"
: "=r"(r[ii])
: "r"(src_reg), "r"(src_reg_shifted), "r"(prmt_indices[ii]));
}
// Since the stored 4bit values are biased by 8 we get stored_val = (x+8)
// we are trying to construct x and a BF16 value
// The below XOR does the following:
// 1) Sets the exponent bits of the BF16 to the correct value for the
// BF16 magic_num. We will be constructing {128 + (x1+8), 128 + (x0+8)}
// and subtracting 136 to get {x1, x0}
static constexpr uint32_t xor_mask = 0x43004300;
static constexpr uint32_t and_mask = 0x000F000F;
static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa;
// For each operand, computes:
// r[i] = (r[i] & and_mask) ^ xor_mask
CUTLASS_PRAGMA_UNROLL
for (int ii = 0; ii < RegArray::kElements; ++ii) {
asm volatile(
"{\n"
" lop3.b32 %0, %0, %1, %2, %3;\n"
"}\n"
: "+r"(r[ii])
: "n"(and_mask), "n"(xor_mask), "n"(immLut));
}
// We will issue 2 bfmas that do the following:
// high BF16:
// hi_bf16 - 136, lo_bf16 - 136
// This is the BF16 {136, 136} represented as an integer.
static constexpr uint32_t bias_rep = 0x43084308;
const __nv_bfloat162& bias =
reinterpret_cast<const __nv_bfloat162&>(bias_rep);
CUTLASS_PRAGMA_UNROLL
for (int ii = 0; ii < RegArray::kElements; ++ii) {
__nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]);
bf16x2_val = __hsub2(bf16x2_val, bias);
}
return reinterpret_cast<PackedResultType&>(r);
}
};
public:
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
typename source_type::Element,
N>::convert(source);
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
// for Array<cutlass::bfloat16_t, N> <= Array<vllm_uint4b8_t, N>
// for IlvdLayout: (2, 4):(4, 1)
template <FloatRoundStyle Round, int N>
struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
cutlass::bfloat16_t, vllm_uint4b8_t, N,
Round, void> {
using IlvdLayout = Layout<Shape<_2, _4>, Stride<_4, _1>>;
static_assert(N % size(IlvdLayout{}) == 0);
using result_type = Array<cutlass::bfloat16_t, N>;
using source_type = Array<vllm_uint4b8_t, N>;
private:
struct RegConvert {
template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
using RegArray =
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
sizeof(PackedResultType)>;
RegArray r;
static_assert(PackedResultType::kElements <= size(IlvdLayout{}));
static constexpr uint32_t or_mask = 0x43004300;
// Unlike float16 where the mantissa is large enough to contain 2
// nibbles, bfloat16 can only fit one, so we can only convert one
// nibble at a time
for (int ii = 0; ii < RegArray::kElements; ++ii) {
r[ii] = src >> (4 * ii);
static constexpr uint32_t and_or_imm_lut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint32_t low_nib_mask = 0x000F000F;
asm volatile(
"{\n"
" lop3.b32 %0, %0, %1, %2, %3;\n"
"}\n"
: "+r"(r[ii + 0])
: "n"(low_nib_mask), "n"(or_mask), "n"(and_or_imm_lut));
// For low nibble:
// {x1, x0} = {128+(x1+8), 128+(x0+8)} * {1, 1} - {136, 136}
static constexpr uint32_t low_nib_bias = 0x43084308; // {136, 136}
{
__nv_bfloat162& fp16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]);
fp16x2_val =
__hsub2(fp16x2_val,
reinterpret_cast<const __nv_bfloat162&>(low_nib_bias));
}
}
return reinterpret_cast<PackedResultType&>(r);
};
};
public:
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
typename source_type::Element,
N>::convert(source);
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
// for Array<cutlass::bfloat16_t, N> <= Array<uint4_t, N>
// for IlvdLayout: (2, 4):(4, 1)
template <FloatRoundStyle Round, int N>
struct InterleavedNumericArrayConverter<Layout<Shape<_2, _4>, Stride<_4, _1>>,
cutlass::bfloat16_t, uint4_t, N, Round,
void> {
using IlvdLayout = Layout<Shape<_2, _4>, Stride<_4, _1>>;
static_assert(N % size(IlvdLayout{}) == 0);
using result_type = Array<cutlass::bfloat16_t, N>;
using source_type = Array<uint4_t, N>;
private:
struct RegConvert {
template <typename PackedResultType>
CUTLASS_DEVICE static PackedResultType convert(uint32_t src) {
using RegArray =
cutlass::AlignedArray<uint32_t, PackedResultType::kElements / 2,
sizeof(PackedResultType)>;
RegArray r;
static_assert(PackedResultType::kElements <= size(IlvdLayout{}));
static constexpr uint32_t or_mask = 0x43004300;
// Unlike float16 where the mantissa is large enough to contain 2
// nibbles, bfloat16 can only fit one, so we can only convert one
// nibble at a time
for (int ii = 0; ii < RegArray::kElements; ++ii) {
r[ii] = src >> (4 * ii);
static constexpr uint32_t and_or_imm_lut = (0xf0 & 0xcc) | 0xaa;
static constexpr uint32_t low_nib_mask = 0x000F000F;
asm volatile(
"{\n"
" lop3.b32 %0, %0, %1, %2, %3;\n"
"}\n"
: "+r"(r[ii])
: "n"(low_nib_mask), "n"(or_mask), "n"(and_or_imm_lut));
// For low nibble:
// {x1, x0} = {128 + x1, 128 + x0} * {1, 1} - {128, 128}
static constexpr uint32_t low_nib_bias = 0x43004300; // {128, 128}
{
__nv_bfloat162& fp16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]);
fp16x2_val =
__hsub2(fp16x2_val,
reinterpret_cast<const __nv_bfloat162&>(low_nib_bias));
}
}
return reinterpret_cast<PackedResultType&>(r);
};
};
public:
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
return ArrayConverterPacked32Bit<RegConvert, typename result_type::Element,
typename source_type::Element,
N>::convert(source);
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
// for Array<cutlass::bfloat16_t, N> <= Array<vllm_uint8b128_t, N>
template <FloatRoundStyle Round, int N>
struct NumericArrayConverter<cutlass::bfloat16_t, vllm_uint8b128_t, N, Round> {
using result_type = Array<cutlass::bfloat16_t, N>;
using source_type = Array<vllm_uint8b128_t, N>;
static FloatRoundStyle const round_style = Round;
private:
using result_packed_4_t = Array<cutlass::bfloat16_t, 4>;
using result_packed_2_t = Array<cutlass::bfloat16_t, 2>;
using src_packed_4_t = Array<vllm_uint8b128_t, 4>;
using src_packed_2_t = Array<vllm_uint8b128_t, 2>;
// Not Valid, not supported, only here to satisfy the interface and to avoid
// a compile error. ScalarConverter will not actually work until
// NumericConverter<cutlass::bfloat16_t, vllm_uint8b128_t, Round> is
// implemented
using ScalarConverter =
NumericConverter<cutlass::bfloat16_t, vllm_uint8b128_t, Round>;
template <typename PackedResultType, typename PackedSrcType>
CUTLASS_DEVICE static PackedResultType packed_convert(
PackedSrcType const& source) {
static_assert(
(platform::is_same<PackedSrcType, src_packed_2_t>::value &&
platform::is_same<PackedResultType, result_packed_2_t>::value) ||
(platform::is_same<PackedSrcType, src_packed_4_t>::value &&
platform::is_same<PackedResultType, result_packed_4_t>::value),
"Invalid PackedSrcType/PackedResultType must be 2 or 4 to use private "
"convert dispatch.");
NumericArrayConverter<float, vllm_uint8b128_t, PackedResultType::kElements,
Round>
convert_uint8_to_f32;
Array<float, PackedResultType::kElements> tmp =
convert_uint8_to_f32(source);
NumericArrayConverter<cutlass::bfloat16_t, float,
PackedResultType::kElements, Round>
convert_f32_to_bf16_;
return convert_f32_to_bf16_(tmp);
}
friend class detail::VectorizedConverter;
public:
CUTLASS_DEVICE
static result_type convert(source_type const& source) {
result_type result;
using ConverterType =
NumericArrayConverter<typename result_type::Element,
typename source_type::Element, N, Round>;
detail::VectorizedConverter::convert<ConverterType, result_packed_4_t,
src_packed_4_t, result_packed_2_t,
src_packed_2_t>(result, source);
return result;
}
CUTLASS_DEVICE
result_type operator()(source_type const& s) const { return convert(s); }
};
#endif
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////
......@@ -3,13 +3,16 @@
#include <c10/cuda/CUDAGuard.h>
#include "dispatch_utils.h"
#include "reduction_utils.cuh"
#ifndef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cub/util_type.cuh>
#include <cub/cub.cuh>
#else
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
#include <hipcub/util_type.hpp>
#include <hipcub/hipcub.hpp>
using __nv_bfloat16 = __hip_bfloat16;
using __nv_bfloat162 = __hip_bfloat162;
......@@ -31,7 +34,11 @@ __global__ void rms_norm_kernel(
const float x = (float)input[blockIdx.x * hidden_size + idx];
variance += x * x;
}
variance = blockReduceSum<float>(variance);
using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
......@@ -228,12 +235,11 @@ fused_add_rms_norm_kernel(
variance += temp.sum_squares();
residual_v[id] = temp;
}
/* Keep the following if-else block in sync with the
calculation of max_block_size in fused_add_rms_norm */
if (num_tokens < 256) {
variance = blockReduceSum<float, 1024>(variance);
} else
variance = blockReduceSum<float, 256>(variance);
using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
......@@ -268,12 +274,11 @@ fused_add_rms_norm_kernel(
variance += x * x;
residual[blockIdx.x * hidden_size + idx] = z;
}
/* Keep the following if-else block in sync with the
calculation of max_block_size in fused_add_rms_norm */
if (num_tokens < 256) {
variance = blockReduceSum<float, 1024>(variance);
} else
variance = blockReduceSum<float, 256>(variance);
using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
......
......@@ -105,12 +105,12 @@ void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size,
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
const torch::Tensor& codebooks,
const torch::Tensor& scales,
const torch::Tensor& codebook_partition_sizes,
const std::vector<int64_t>& codebook_partition_sizes,
const std::optional<torch::Tensor>& bias);
torch::Tensor aqlm_dequant(const torch::Tensor& codes,
const torch::Tensor& codebooks,
const torch::Tensor& codebook_partition_sizes);
torch::Tensor aqlm_dequant(
const torch::Tensor& codes, const torch::Tensor& codebooks,
const std::vector<int64_t>& codebook_partition_sizes);
torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel,
torch::Tensor _scaling_factors, torch::Tensor _zeros,
......@@ -125,6 +125,25 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_scales, torch::Tensor& workspace,
int64_t size_m, int64_t size_n, int64_t size_k);
namespace machete {
std::vector<std::string> supported_schedules(
vllm::ScalarTypeTorchPtr const& btype);
torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
vllm::ScalarTypeTorchPtr const& btype,
c10::optional<torch::Tensor> const& scales,
c10::optional<torch::Tensor> const& zeros,
c10::optional<int64_t> group_size,
c10::optional<torch::Tensor> const& C,
c10::optional<double> alpha, c10::optional<double> beta,
c10::optional<std::string> schedule);
torch::Tensor prepack_B(torch::Tensor const& B,
vllm::ScalarTypeTorchPtr const& btype);
}; // namespace machete
torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_meta,
torch::Tensor& b_scales,
......@@ -149,6 +168,15 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
int64_t size_n, int64_t num_bits);
torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m,
int64_t n);
torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X,
int64_t type, int64_t row);
torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type,
int64_t row);
torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_scales, torch::Tensor& workspace,
int64_t num_bits, int64_t size_m, int64_t size_n,
......@@ -161,6 +189,14 @@ void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias);
void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& azp_adj,
c10::optional<torch::Tensor> const& azp,
c10::optional<torch::Tensor> const& bias);
torch::Tensor marlin_qqq_gemm(torch::Tensor const& a,
torch::Tensor const& b_q_weight,
torch::Tensor const& s_tok,
......
......@@ -6,13 +6,17 @@
#include <ATen/AccumulateType.h>
#include <THC/THCDeviceUtils.cuh>
#include "../dispatch_utils.h"
#include "../reduction_utils.cuh"
#ifndef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cub/util_type.cuh>
#include <cub/cub.cuh>
#else
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
#include <hipcub/util_type.hpp>
#include <hipcub/hipcub.hpp>
using __nv_bfloat16 = __hip_bfloat16;
using __nv_bfloat162 = __hip_bfloat162;
......@@ -34,7 +38,11 @@ __global__ void rms_norm_kernel(
const float x = (float)input[blockIdx.x * hidden_size + idx];
variance += x * x;
}
variance = blockReduceSum<float>(variance);
using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
......@@ -231,12 +239,11 @@ fused_add_rms_norm_kernel(
variance += temp.sum_squares();
residual_v[id] = temp;
}
/* Keep the following if-else block in sync with the
calculation of max_block_size in fused_add_rms_norm */
if (num_tokens < 256) {
variance = blockReduceSum<float, 1024>(variance);
} else
variance = blockReduceSum<float, 256>(variance);
using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
......@@ -271,12 +278,11 @@ fused_add_rms_norm_kernel(
variance += x * x;
residual[blockIdx.x * hidden_size + idx] = z;
}
/* Keep the following if-else block in sync with the
calculation of max_block_size in fused_add_rms_norm */
if (num_tokens < 256) {
variance = blockReduceSum<float, 1024>(variance);
} else
variance = blockReduceSum<float, 256>(variance);
using BlockReduce = cub::BlockReduce<float, 1024>;
__shared__ typename BlockReduce::TempStorage reduceStore;
variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x);
if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment