Unverified Commit 38364a7e authored by Kyle Sayers's avatar Kyle Sayers Committed by GitHub
Browse files

[Sparse24] [Deprecation] Remove Sparse24 CT integration and kernels (#36799)


Signed-off-by: default avatarKyle Sayers <kylesayrs@gmail.com>
parent fafe76b4
# For vllm script, with -t option (tensor parallel size).
# bash ./run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM -b "auto" -t 2
model_name: "nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM"
tasks:
- name: "gsm8k"
metrics:
- name: "exact_match,strict-match"
value: 0.6353
- name: "exact_match,flexible-extract"
value: 0.637
limit: null
num_fewshot: null
...@@ -343,7 +343,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") ...@@ -343,7 +343,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu" "csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu"
"csrc/quantization/fp4/nvfp4_quant_entry.cu" "csrc/quantization/fp4/nvfp4_quant_entry.cu"
"csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu" "csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu"
"csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
"csrc/cutlass_extensions/common.cpp" "csrc/cutlass_extensions/common.cpp"
"csrc/quantization/w8a8/fp8/per_token_group_quant.cu" "csrc/quantization/w8a8/fp8/per_token_group_quant.cu"
"csrc/quantization/w8a8/int8/per_token_group_quant.cu") "csrc/quantization/w8a8/int8/per_token_group_quant.cu")
...@@ -619,31 +618,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") ...@@ -619,31 +618,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif() endif()
endif() endif()
#
# 2:4 Sparse Kernels
# The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor
# require CUDA 12.2 or later (and only work on Hopper).
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.2 AND SCALED_MM_ARCHS)
set(SRCS "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SPARSE_SCALED_MM_C3X=1")
message(STATUS "Building sparse_scaled_mm_c3x for archs: ${SCALED_MM_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.2 AND SCALED_MM_ARCHS)
message(STATUS "Not building sparse_scaled_mm_c3x kernels as CUDA Compiler version is "
"not >= 12.2, we recommend upgrading to CUDA 12.2 or later "
"if you intend on running FP8 sparse quantized models on Hopper.")
else()
message(STATUS "Not building sparse_scaled_mm_c3x as no compatible archs found "
"in CUDA target architectures")
endif()
endif()
# The nvfp4_scaled_mm_sm120 kernels for Geforce Blackwell SM120 require # The nvfp4_scaled_mm_sm120 kernels for Geforce Blackwell SM120 require
# CUDA 12.8 or later # CUDA 12.8 or later
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0)
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import copy
import itertools
import pickle as pkl
import time
from collections.abc import Callable, Iterable
import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
from utils import make_rand_sparse_tensors
from weight_shapes import WEIGHT_SHAPES
from vllm import _custom_ops as ops
from vllm.utils.argparse_utils import FlexibleArgumentParser
DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
DEFAULT_TP_SIZES = [1]
# bench
def bench_fn(
label: str, sub_label: str, description: str, fn: Callable, *args, **kwargs
) -> TMeasurement:
min_run_time = 1
globals = {
"args": args,
"kwargs": kwargs,
"fn": fn,
}
return TBenchmark.Timer(
stmt="fn(*args, **kwargs)",
globals=globals,
label=label,
sub_label=sub_label,
description=description,
).blocked_autorange(min_run_time=min_run_time)
def bench_int8(
dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str
) -> Iterable[TMeasurement]:
assert dtype == torch.int8
b_compressed, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k)
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16)
out = ops.cutlass_scaled_sparse_mm(
a, b_compressed, e, scale_a, scale_b, torch.bfloat16
)
out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)
if not torch.allclose(out, out_ref):
print("Incorrect results")
print(out)
print(out_ref)
else:
print("Correct results")
timers = []
# pytorch impl - bfloat16
timers.append(
bench_fn(
label,
sub_label,
"pytorch_bf16_bf16_bf16_matmul-no-scales",
torch.mm,
a.to(dtype=torch.bfloat16),
b.to(dtype=torch.bfloat16),
)
)
# pytorch impl - float16
timers.append(
bench_fn(
label,
sub_label,
"pytorch_fp16_fp16_fp16_matmul-no-scales",
torch.mm,
a.to(dtype=torch.float16),
b.to(dtype=torch.float16),
)
)
# cutlass impl
timers.append(
bench_fn(
label,
sub_label,
"cutlass_i8_i8_bf16_scaled_mm",
ops.cutlass_scaled_mm,
a,
b,
scale_a,
scale_b,
torch.bfloat16,
)
)
# cutlass with bias
timers.append(
bench_fn(
label,
sub_label,
"cutlass_i8_i8_bf16_scaled_mm_bias",
ops.cutlass_scaled_mm,
a,
b,
scale_a,
scale_b,
torch.bfloat16,
bias,
)
)
# cutlass sparse impl
timers.append(
bench_fn(
label,
sub_label,
"cutlass_i8_i8_bf16_scaled_sparse_mm",
ops.cutlass_scaled_sparse_mm,
a,
b_compressed,
e,
scale_a,
scale_b,
torch.bfloat16,
)
)
# cutlass sparse with bias
timers.append(
bench_fn(
label,
sub_label,
"cutlass_i8_i8_bf16_scaled_sparse_mm_bias",
ops.cutlass_scaled_sparse_mm,
a,
b_compressed,
e,
scale_a,
scale_b,
torch.bfloat16,
bias,
)
)
return timers
def bench_fp8(
dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str
) -> Iterable[TMeasurement]:
assert dtype == torch.float8_e4m3fn
b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k)
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16)
out = ops.cutlass_scaled_sparse_mm(
a, b_compressed, e, scale_a, scale_b, torch.bfloat16
)
out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)
if not torch.allclose(out, out_ref):
print("Incorrect results")
print(out)
print(out_ref)
else:
print("Correct results")
timers = []
# pytorch impl w. bf16
timers.append(
bench_fn(
label,
sub_label,
"pytorch_bf16_bf16_bf16_matmul-no-scales",
torch.mm,
a.to(dtype=torch.bfloat16, device="cuda"),
b.to(dtype=torch.bfloat16, device="cuda"),
)
)
# pytorch impl: bf16 output, without fp8 fast accum
timers.append(
bench_fn(
label,
sub_label,
"pytorch_fp8_fp8_bf16_scaled_mm",
torch._scaled_mm,
a,
b,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=torch.bfloat16,
)
)
# pytorch impl: bf16 output, with fp8 fast accum
timers.append(
bench_fn(
label,
sub_label,
"pytorch_fp8_fp8_bf16_scaled_mm_fast_accum",
torch._scaled_mm,
a,
b,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=torch.bfloat16,
use_fast_accum=True,
)
)
# pytorch impl: fp16 output, without fp8 fast accum
timers.append(
bench_fn(
label,
sub_label,
"pytorch_fp8_fp8_fp16_scaled_mm",
torch._scaled_mm,
a,
b,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=torch.float16,
)
)
# pytorch impl: fp16 output, with fp8 fast accum
timers.append(
bench_fn(
label,
sub_label,
"pytorch_fp8_fp8_fp16_scaled_mm_fast_accum",
torch._scaled_mm,
a,
b,
scale_a=scale_a,
scale_b=scale_b,
out_dtype=torch.float16,
use_fast_accum=True,
)
)
# cutlass impl: bf16 output
timers.append(
bench_fn(
label,
sub_label,
"cutlass_fp8_fp8_bf16_scaled_mm",
ops.cutlass_scaled_mm,
a,
b,
scale_a,
scale_b,
torch.bfloat16,
)
)
# cutlass impl: bf16 output
timers.append(
bench_fn(
label,
sub_label,
"cutlass_fp8_fp8_bf16_scaled_sparse_mm",
ops.cutlass_scaled_sparse_mm,
a,
b_compressed,
e,
scale_a,
scale_b,
torch.bfloat16,
)
)
# cutlass impl: fp16 output
timers.append(
bench_fn(
label,
sub_label,
"cutlass_fp8_fp8_fp16_scaled_sparse_mm",
ops.cutlass_scaled_sparse_mm,
a,
b_compressed,
e,
scale_a,
scale_b,
torch.float16,
)
)
# cutlass impl: bf16 output, with bias
timers.append(
bench_fn(
label,
sub_label,
"cutlass_fp8_fp8_bf16_scaled_sparse_mm_bias",
ops.cutlass_scaled_sparse_mm,
a,
b_compressed,
e,
scale_a,
scale_b,
torch.bfloat16,
bias,
)
)
# cutlass impl: fp16 output, with bias
timers.append(
bench_fn(
label,
sub_label,
"cutlass_fp8_fp8_fp16_scaled_sparse_mm_bias",
ops.cutlass_scaled_sparse_mm,
a,
b_compressed,
e,
scale_a,
scale_b,
torch.float16,
bias.to(dtype=torch.float16),
)
)
return timers
def bench(
dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str
) -> Iterable[TMeasurement]:
if dtype == torch.int8:
return bench_int8(dtype, m, k, n, label, sub_label)
if dtype == torch.float8_e4m3fn:
return bench_fp8(dtype, m, k, n, label, sub_label)
raise ValueError(
f"Unsupported dtype {dtype}: should be one of torch.int8, torch.float8_e4m3fn."
)
# runner
def print_timers(timers: Iterable[TMeasurement]):
compare = TBenchmark.Compare(timers)
compare.print()
def run(
dtype: torch.dtype, MKNs: Iterable[tuple[int, int, int]]
) -> Iterable[TMeasurement]:
results = []
for m, k, n in MKNs:
timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm", f"MKN=({m}x{k}x{n})")
print_timers(timers)
results.extend(timers)
return results
# output makers
def make_output(
data: Iterable[TMeasurement],
MKNs: Iterable[tuple[int, int, int]],
base_description: str,
timestamp=None,
):
print(f"== All Results {base_description} ====")
print_timers(data)
# pickle all the results
timestamp = int(time.time()) if timestamp is None else timestamp
with open(f"{base_description}-{timestamp}.pkl", "wb") as f:
pkl.dump(data, f)
# argparse runners
def run_square_bench(args):
dim_sizes = list(range(args.dim_start, args.dim_end + 1, args.dim_increment))
MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
data = run(args.dtype, MKNs)
make_output(data, MKNs, f"square_bench-{args.dtype}")
def run_range_bench(args):
dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment))
n = len(dim_sizes)
Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes
Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes
Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes
MKNs = list(zip(Ms, Ks, Ns))
data = run(args.dtype, MKNs)
make_output(data, MKNs, f"range_bench-{args.dtype}")
def run_model_bench(args):
print("Benchmarking models:")
for i, model in enumerate(args.models):
print(f"[{i}] {model}")
def model_shapes(model_name: str, tp_size: int) -> list[tuple[int, int]]:
KNs = []
for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
KN[tp_split_dim] = KN[tp_split_dim] // tp_size
KNs.append(KN)
return KNs
model_bench_data = []
models_tps = list(itertools.product(args.models, args.tp_sizes))
for model, tp_size in models_tps:
Ms = args.batch_sizes
KNs = model_shapes(model, tp_size)
MKNs = []
for m in Ms:
for k, n in KNs:
MKNs.append((m, k, n))
data = run(args.dtype, MKNs)
model_bench_data.append(data)
# Print all results
for data, model_tp in zip(model_bench_data, models_tps):
model, tp_size = model_tp
print(f"== Results {args.dtype} {model}-TP{tp_size} ====")
print_timers(data)
timestamp = int(time.time())
all_data = []
for d in model_bench_data:
all_data.extend(d)
# pickle all data
with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f:
pkl.dump(all_data, f)
if __name__ == "__main__":
def to_torch_dtype(dt):
if dt == "int8":
return torch.int8
if dt == "fp8":
return torch.float8_e4m3fn
raise ValueError("unsupported dtype")
parser = FlexibleArgumentParser(
description="""
Benchmark Cutlass GEMM.
To run square GEMMs:
python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 square_bench --dim-start 128 --dim-end 512 --dim-increment 64
To run constant N and K and sweep M:
python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384
To run dimensions from a model:
python3 ./benchmarks/cutlass_benchmarks/sparse_benchmarks.py --dtype fp8 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1
Output:
- a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs.
""", # noqa: E501
formatter_class=argparse.RawTextHelpFormatter,
)
parser.add_argument(
"--dtype",
type=to_torch_dtype,
required=True,
help="Available options are ['int8', 'fp8']",
)
subparsers = parser.add_subparsers(dest="cmd")
square_parser = subparsers.add_parser("square_bench")
square_parser.add_argument("--dim-start", type=int, required=True)
square_parser.add_argument("--dim-end", type=int, required=True)
square_parser.add_argument("--dim-increment", type=int, required=True)
square_parser.set_defaults(func=run_square_bench)
range_parser = subparsers.add_parser("range_bench")
range_parser.add_argument("--dim-start", type=int, required=True)
range_parser.add_argument("--dim-end", type=int, required=True)
range_parser.add_argument("--dim-increment", type=int, required=True)
range_parser.add_argument("--m-constant", type=int, default=None)
range_parser.add_argument("--n-constant", type=int, default=None)
range_parser.add_argument("--k-constant", type=int, default=None)
range_parser.set_defaults(func=run_range_bench)
model_parser = subparsers.add_parser("model_bench")
model_parser.add_argument(
"--models",
nargs="+",
type=str,
default=DEFAULT_MODELS,
choices=WEIGHT_SHAPES.keys(),
)
model_parser.add_argument(
"--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES
)
model_parser.add_argument(
"--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES
)
model_parser.set_defaults(func=run_model_bench)
args = parser.parse_args()
args.func(args)
...@@ -5,8 +5,6 @@ ...@@ -5,8 +5,6 @@
import torch import torch
import vllm._custom_ops as ops
def to_fp8(tensor: torch.Tensor) -> torch.Tensor: def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
finfo = torch.finfo(torch.float8_e4m3fn) finfo = torch.finfo(torch.float8_e4m3fn)
...@@ -39,49 +37,3 @@ def make_rand_tensors( ...@@ -39,49 +37,3 @@ def make_rand_tensors(
return to_fp8(a), to_fp8(b) return to_fp8(a), to_fp8(b)
raise ValueError("unsupported dtype") raise ValueError("unsupported dtype")
def prune_to_2_4(tensor):
# Reshape tensor to [N, 4] where N is number of groups of 4
original_shape = tensor.shape
reshaped = tensor.reshape(-1, 4)
# Get indices of top 2 absolute values in each group of 4
_, indices = torch.topk(torch.abs(reshaped), k=2, dim=1)
# Create binary mask
mask = torch.zeros_like(reshaped)
mask.scatter_(dim=1, index=indices, src=torch.ones_like(indices, dtype=mask.dtype))
# Apply mask and reshape back
pruned = reshaped * mask
# Turn all -0.0 to 0.0
pruned[pruned == -0.0] = 0.0
return pruned.reshape(original_shape)
def make_rand_sparse_tensors(
dtype: torch.dtype, m: int, n: int, k: int
) -> tuple[torch.Tensor, torch.Tensor]:
a = torch.randn((m, k), device="cuda") * 5
b = torch.randn((n, k), device="cuda").t() * 5
b = prune_to_2_4(b.t()).t()
if dtype == torch.int8:
a, b = to_int8(a), to_int8(b)
elif dtype == torch.float8_e4m3fn:
a, b = to_fp8(a), to_fp8(b)
elif dtype == torch.float16:
a, b = to_fp16(a), to_fp16(b)
elif dtype == torch.bfloat16:
a, b = to_bf16(a), to_bf16(b)
else:
raise ValueError("unsupported dtype")
b_compressed, e = ops.cutlass_sparse_compress(b.t())
# Compressed B, Metadata, Original A, B
return b_compressed, e, a, b
...@@ -285,16 +285,6 @@ void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a, ...@@ -285,16 +285,6 @@ void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
std::optional<torch::Tensor> const& azp, std::optional<torch::Tensor> const& azp,
std::optional<torch::Tensor> const& bias); std::optional<torch::Tensor> const& bias);
bool cutlass_sparse_scaled_mm_supported(int64_t cuda_device_capability);
void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& e,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias);
std::vector<torch::Tensor> cutlass_sparse_compress(torch::Tensor const& a);
std::tuple<torch::Tensor, torch::Tensor> scaled_fp4_quant_func( std::tuple<torch::Tensor, torch::Tensor> scaled_fp4_quant_func(
torch::Tensor const& input, torch::Tensor const& input_scale, torch::Tensor const& input, torch::Tensor const& input_scale,
bool is_sf_swizzled_layout); bool is_sf_swizzled_layout);
......
#pragma once
// clang-format will break include orders
// clang-format off
#include <cudaTypedefs.h>
#if defined CUDA_VERSION && CUDA_VERSION >= 12020
#include "sparse_scaled_mm_c3x.cuh"
#include "cutlass/numeric_conversion.h"
#include "cutlass/transform/device/transform_universal_adapter.hpp"
#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
// clang-format on
using namespace cute;
using namespace vllm;
using CompressorResult = std::tuple<torch::Tensor, torch::Tensor>;
/// Make A structured sparse by replacing elements with 0 and compress it
template <typename Gemm>
CompressorResult cutlass_sparse_compress(torch::Tensor const& a) {
// Checks for conformality
TORCH_CHECK(a.dtype() == torch::kInt8 || a.dtype() == torch::kFloat8_e4m3fn ||
a.dtype() == torch::kFloat16 || a.dtype() == torch::kBFloat16);
TORCH_CHECK(a.dim() == 2)
// Check for strides and alignment
TORCH_CHECK(a.stride(0) % 4 == 0) // Required for semi-structured sparsity
TORCH_CHECK(a.stride(1) == 1)
using GemmKernel = typename Gemm::KernelType;
using ElementA = typename Gemm::ElementAB;
using ElementE = typename GemmKernel::CollectiveMainloop::ElementE;
int m = a.size(0);
int k = a.size(1);
using ProblemShape = typename GemmKernel::ProblemShape;
ProblemShape prob_shape{m, 1, k, 1};
int64_t lda = a.stride(0);
using StrideA = Stride<int64_t, Int<1>, int64_t>;
StrideA a_stride{lda, Int<1>{}, 0};
using CompressorUtility = typename Gemm::CompressorUtility;
CompressorUtility compressor_utility(prob_shape, a_stride);
// Allocate buffers for the metadata E and the compressed matrix A
int ME = compressor_utility.get_metadata_m_physical();
int KE = compressor_utility.get_metadata_k_physical();
int MC = compressor_utility.get_tensorA_m_physical();
int KC = compressor_utility.get_tensorA_k_physical();
auto const a_meta_options =
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
auto const a_nzs_options =
torch::TensorOptions().dtype(a.dtype()).device(a.device());
auto a_meta = torch::zeros({ME, KE}, a_meta_options);
auto a_nzs = torch::zeros({MC, KC}, a_nzs_options);
auto a_ptr = static_cast<ElementA*>(a.data_ptr());
auto a_nzs_ptr = static_cast<ElementA*>(a_nzs.data_ptr());
auto a_meta_ptr = static_cast<ElementE*>(a_meta.data_ptr());
cutlass::KernelHardwareInfo hw_info;
hw_info.device_id = a.device().index();
hw_info.sm_count =
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
hw_info.device_id);
using Compressor = typename Gemm::Compressor;
typename Compressor::Arguments arguments{
prob_shape, {a_ptr, a_stride, a_nzs_ptr, a_meta_ptr}, {hw_info}};
Compressor compressor_op;
size_t workspace_size = Compressor::get_workspace_size(arguments);
auto const workspace_options =
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
auto workspace = torch::empty(workspace_size, workspace_options);
CUTLASS_CHECK(compressor_op.can_implement(arguments));
CUTLASS_CHECK(compressor_op.initialize(arguments, workspace.data_ptr()));
CUTLASS_CHECK(compressor_op.run());
CUDA_CHECK(cudaDeviceSynchronize());
return {a_meta, a_nzs};
}
#endif
// clang-format will break include orders
// clang-format off
#include <cudaTypedefs.h>
#if defined CUDA_VERSION && CUDA_VERSION >= 12020
#include "sparse_scaled_mm_c3x.cuh"
// clang-format on
using namespace cute;
using namespace vllm;
struct GemmCallerTraits {
using return_type = void;
template <typename GemmConfig, typename... Args>
static return_type invoke(Args&&... args) {
return cutlass_sparse_gemm_caller<GemmConfig>(std::forward<Args>(args)...);
}
};
struct GemmCompressorTraits {
using return_type = CompressorResult;
template <typename GemmConfig, typename... Args>
static return_type invoke(Args&&... args) {
return cutlass_sparse_compress<GemmConfig>(std::forward<Args>(args)...);
}
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue,
typename DispatchFunc, typename... Args>
typename DispatchFunc::return_type cutlass_gemm_sm90_fp8_dispatch(
uint32_t m, uint32_t n, Args&&... args) {
static_assert(std::is_same_v<InType, cutlass::float_e4m3_t>);
using Cutlass3xGemmDefault =
typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM64 =
typename sm90_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM128 =
typename sm90_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM256 =
typename sm90_fp8_config_M256<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM512 =
typename sm90_fp8_config_M512<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemm1 =
typename sm90_fp8_config_1<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemm2 =
typename sm90_fp8_config_2<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemm3 =
typename sm90_fp8_config_3<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemm4 =
typename sm90_fp8_config_4<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemm5 =
typename sm90_fp8_config_5<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemm6 =
typename sm90_fp8_config_6<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemm7 =
typename sm90_fp8_config_7<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemm8 =
typename sm90_fp8_config_8<InType, OutType, Epilogue>::Cutlass3xGemm;
uint32_t const mp2 =
std::max(static_cast<uint32_t>(64), next_pow_2(m)); // next power of 2
if (mp2 <= 64) {
if (n == 28672) {
return DispatchFunc::template invoke<Cutlass3xGemm2>(
std::forward<Args>(args)...);
} else if (n == 4096 || n == 6144) {
return DispatchFunc::template invoke<Cutlass3xGemm1>(
std::forward<Args>(args)...);
}
} else if (mp2 <= 128) {
if (n == 4096) {
return DispatchFunc::template invoke<Cutlass3xGemm3>(
std::forward<Args>(args)...);
} else if (n == 28672) {
return DispatchFunc::template invoke<Cutlass3xGemm5>(
std::forward<Args>(args)...);
} else if (n == 6144) {
return DispatchFunc::template invoke<Cutlass3xGemm4>(
std::forward<Args>(args)...);
}
} else if (mp2 <= 256) {
if (n == 4096) {
return DispatchFunc::template invoke<Cutlass3xGemm6>(
std::forward<Args>(args)...);
} else if (n == 28672) {
return DispatchFunc::template invoke<Cutlass3xGemm8>(
std::forward<Args>(args)...);
} else if (n == 6144) {
return DispatchFunc::template invoke<Cutlass3xGemm7>(
std::forward<Args>(args)...);
}
} else {
if (n == 6144 || n == 28672) {
return DispatchFunc::template invoke<Cutlass3xGemm8>(
std::forward<Args>(args)...);
} else if (n == 4096) {
return DispatchFunc::template invoke<Cutlass3xGemm7>(
std::forward<Args>(args)...);
}
}
// Otherwise the default heuristic
if (mp2 <= 64) {
// n in [1, 64]
return DispatchFunc::template invoke<Cutlass3xGemmM64>(
std::forward<Args>(args)...);
} else if (mp2 <= 128) {
// n in (64, 128]
return DispatchFunc::template invoke<Cutlass3xGemmM128>(
std::forward<Args>(args)...);
} else if (mp2 <= 256) {
// n in (128, 256]
return DispatchFunc::template invoke<Cutlass3xGemmM256>(
std::forward<Args>(args)...);
} else {
// n in (256, inf)
return DispatchFunc::template invoke<Cutlass3xGemmM512>(
std::forward<Args>(args)...);
}
}
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue,
typename DispatchFunc, typename... Args>
typename DispatchFunc::return_type cutlass_gemm_sm90_16bit_dispatch(
uint32_t m, uint32_t n, Args&&... args) {
using Cutlass3xGemmDefault =
typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;
return DispatchFunc::template invoke<Cutlass3xGemmDefault>(
std::forward<Args>(args)...);
}
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue,
typename DispatchFunc, typename... Args>
typename DispatchFunc::return_type cutlass_gemm_sm90_int8_dispatch(
uint32_t m, uint32_t n, Args&&... args) {
static_assert(std::is_same_v<InType, int8_t>);
using Cutlass3xGemmDefault =
typename sm90_config_default<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM128 =
typename sm90_int8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM64 =
typename sm90_int8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM32NBig =
typename sm90_int8_config_M32_NBig<InType, OutType,
Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM32NSmall =
typename sm90_int8_config_M32_NSmall<InType, OutType,
Epilogue>::Cutlass3xGemm;
bool const is_small_n = n < 8192;
uint32_t const mp2 =
std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2
if (mp2 <= 32) {
// m in [1, 32]
if (is_small_n) {
return DispatchFunc::template invoke<Cutlass3xGemmM32NSmall>(
std::forward<Args>(args)...);
} else {
return DispatchFunc::template invoke<Cutlass3xGemmM32NBig>(
std::forward<Args>(args)...);
}
} else if (mp2 <= 64) {
// m in (32, 64]
return DispatchFunc::template invoke<Cutlass3xGemmM64>(
std::forward<Args>(args)...);
} else if (mp2 <= 128) {
// m in (64, 128]
return DispatchFunc::template invoke<Cutlass3xGemmM128>(
std::forward<Args>(args)...);
} else {
// m in (128, inf)
return DispatchFunc::template invoke<Cutlass3xGemmDefault>(
std::forward<Args>(args)...);
}
}
// Dispatch to GEMM implementations based on element types
template <template <typename, typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_scaled_sparse_mm_sm90_epilogue(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& bt_nzs,
torch::Tensor const& bt_meta,
EpilogueArgs&&... epilogue_args) {
uint32_t const m = out.size(0);
uint32_t const n = out.size(1);
// TODO: add dispatch functions to all of these
TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
if (a.dtype() == torch::kInt8) {
TORCH_CHECK(bt_nzs.dtype() == torch::kInt8);
if (out.dtype() == torch::kBFloat16) {
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
Epilogue, GemmCallerTraits>(
m, n, out, a, bt_nzs, bt_meta,
std::forward<EpilogueArgs>(epilogue_args)...);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t, Epilogue,
GemmCallerTraits>(
m, n, out, a, bt_nzs, bt_meta,
std::forward<EpilogueArgs>(epilogue_args)...);
}
} else if (a.dtype() == torch::kFloat8_e4m3fn) {
TORCH_CHECK(bt_nzs.dtype() == torch::kFloat8_e4m3fn);
if (out.dtype() == torch::kBFloat16) {
return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::bfloat16_t, Epilogue,
GemmCallerTraits>(
m, n, out, a, bt_nzs, bt_meta,
std::forward<EpilogueArgs>(epilogue_args)...);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_gemm_sm90_fp8_dispatch<
cutlass::float_e4m3_t, cutlass::half_t, Epilogue, GemmCallerTraits>(
m, n, out, a, bt_nzs, bt_meta,
std::forward<EpilogueArgs>(epilogue_args)...);
}
} else if (a.dtype() == torch::kFloat16) {
TORCH_CHECK(bt_nzs.dtype() == torch::kFloat16);
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_gemm_sm90_16bit_dispatch<cutlass::half_t, cutlass::half_t,
Epilogue, GemmCallerTraits>(
m, n, out, a, bt_nzs, bt_meta,
std::forward<EpilogueArgs>(epilogue_args)...);
} else { // a.dtype() == torch::kBFloat16
TORCH_CHECK(a.dtype() == torch::kBFloat16);
TORCH_CHECK(bt_nzs.dtype() == torch::kBFloat16);
TORCH_CHECK(out.dtype() == torch::kBFloat16);
return cutlass_gemm_sm90_16bit_dispatch<
cutlass::bfloat16_t, cutlass::bfloat16_t, Epilogue, GemmCallerTraits>(
m, n, out, a, bt_nzs, bt_meta,
std::forward<EpilogueArgs>(epilogue_args)...);
}
}
void cutlass_scaled_sparse_mm_sm90(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& bt_nzs,
torch::Tensor const& bt_meta,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias) {
TORCH_CHECK(bt_meta.dtype() == torch::kUInt8);
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
if (bias) {
TORCH_CHECK(bias->dtype() == out.dtype(),
"CUTLASS scaled_mm bias dtype must match output dtype ",
out.dtype());
return cutlass_scaled_sparse_mm_sm90_epilogue<
c3x::ScaledEpilogueColumnBias>(out, a, bt_nzs, bt_meta, b_scales,
a_scales, *bias);
} else {
return cutlass_scaled_sparse_mm_sm90_epilogue<c3x::ScaledEpilogue>(
out, a, bt_nzs, bt_meta, b_scales, a_scales);
}
}
CompressorResult cutlass_sparse_compress_sm90(torch::Tensor const& a) {
// These m and n variables are fordispatching to different GEMM algorithms.
uint32_t const m = 1; // Set M to 1 for compression
uint32_t const n = a.size(1);
// Note: For correctness, the compressed format must be invariant in:
// - M, the flattened number of tokens
// - Whether output dtype is fp16 or bf16
// - CUTLASS epilogues
if (a.dtype() == torch::kInt8) {
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
c3x::TrivialEpilogue,
GemmCompressorTraits>(m, n, a);
} else if (a.dtype() == torch::kFloat8_e4m3fn) {
return cutlass_gemm_sm90_fp8_dispatch<
cutlass::float_e4m3_t, cutlass::bfloat16_t, c3x::TrivialEpilogue,
GemmCompressorTraits>(m, n, a);
} else if (a.dtype() == torch::kFloat16) {
return cutlass_gemm_sm90_16bit_dispatch<
cutlass::bfloat16_t, cutlass::bfloat16_t, c3x::TrivialEpilogue,
GemmCompressorTraits>(m, n, a);
} else {
TORCH_CHECK(a.dtype() == torch::kBFloat16,
"cutlass_sparse_compress only supports int8, fp8_e4m3, fp16, "
"and bf16 datatypes");
return cutlass_gemm_sm90_16bit_dispatch<cutlass::half_t, cutlass::half_t,
c3x::TrivialEpilogue,
GemmCompressorTraits>(m, n, a);
}
}
#endif
This diff is collapsed.
#include <cudaTypedefs.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include "cutlass_extensions/common.hpp"
bool cutlass_sparse_scaled_mm_supported(int64_t cuda_device_capability) {
// sparse CUTLASS kernels need exactly hopper and are not forward compatible
// CUDA 12.2 and SM90 (Hopper)
#if defined CUDA_VERSION
return CUDA_VERSION >= 12020 && cuda_device_capability == 90;
#endif
return false;
}
#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
void cutlass_scaled_sparse_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& e,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias);
using CompressorResult = std::tuple<torch::Tensor, torch::Tensor>;
CompressorResult cutlass_sparse_compress_sm90(torch::Tensor const& a);
#endif
void cutlass_scaled_sparse_mm(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& bt_nzs,
torch::Tensor const& bt_meta,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
std::optional<torch::Tensor> const& bias) {
// Checks for conformality
TORCH_CHECK(a.dim() == 2 && bt_nzs.dim() == 2 && c.dim() == 2);
TORCH_CHECK(c.size(1) == bt_nzs.size(0) && bt_nzs.size(1) * 2 == a.size(1) &&
a.size(0) == c.size(0));
TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == bt_nzs.size(0));
// Check for strides and alignment
TORCH_CHECK(a.stride(1) == 1 && bt_nzs.stride(1) == 1 &&
c.stride(1) == 1); // Row-major
TORCH_CHECK(c.stride(0) % 16 == 0); // 16 Byte Alignment
TORCH_CHECK(bt_nzs.stride(0) % 16 == 0); // 16 Byte Alignment
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (bias) {
TORCH_CHECK(bias->numel() == bt_nzs.size(0) && bias->is_contiguous() &&
bias->dim() == 1);
}
at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
int32_t version_num = get_sm_version_num();
// Guard against compilation issues for sm90 kernels
#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
// We build for 9.0a which is not forward compatible, so restrict this to
// Hopper only
if (version_num == 90) {
cutlass_scaled_sparse_mm_sm90(c, a, bt_nzs, bt_meta, a_scales, b_scales,
bias);
return;
}
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_scaled_sparse_mm for a compute capability less than "
"CUDA device capability: ",
version_num);
}
std::vector<torch::Tensor> cutlass_sparse_compress(torch::Tensor const& a) {
// Check for strides and alignment
TORCH_CHECK(a.stride(1) == 1); // Row-major
TORCH_CHECK(a.stride(0) % 8 == 0); // 8 Byte Alignment for Compression
at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
int32_t version_num = get_sm_version_num();
// Guard against compilation issues for sm90 kernels
#if defined ENABLE_SPARSE_SCALED_MM_C3X && ENABLE_SPARSE_SCALED_MM_C3X
// We build for 9.0a which is not forward compatible, so restrict this to
// Hopper only
if (version_num == 90) {
std::vector<torch::Tensor> result_tensors;
auto [a_meta, a_nzs] = cutlass_sparse_compress_sm90(a);
result_tensors.push_back(std::move(a_nzs));
result_tensors.push_back(std::move(a_meta));
return result_tensors;
}
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_sparse_compress for a compute capability equal to "
"CUDA device capability: ",
version_num);
}
...@@ -523,26 +523,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -523,26 +523,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.impl("cutlass_scaled_mm_supports_block_fp8", ops.impl("cutlass_scaled_mm_supports_block_fp8",
&cutlass_scaled_mm_supports_block_fp8); &cutlass_scaled_mm_supports_block_fp8);
// Check if cutlass sparse scaled_mm is supported for CUDA devices of the
// given capability
ops.def(
"cutlass_sparse_scaled_mm_supported(int cuda_device_capability) -> bool");
ops.impl("cutlass_sparse_scaled_mm_supported",
&cutlass_sparse_scaled_mm_supported);
// CUTLASS sparse GEMM, supporting symmetric per-tensor or per-row/column
// quantization, as well as bias
ops.def(
"cutlass_scaled_sparse_mm(Tensor! out, Tensor a,"
" Tensor bt_nzs,"
" Tensor bt_meta, Tensor a_scales,"
" Tensor b_scales, Tensor? bias) -> ()");
ops.impl("cutlass_scaled_sparse_mm", torch::kCUDA, &cutlass_scaled_sparse_mm);
// CUTLASS sparse matrix compressor
ops.def("cutlass_sparse_compress(Tensor a) -> Tensor[]");
ops.impl("cutlass_sparse_compress", &cutlass_sparse_compress);
// SM100 CUTLASS MLA decode // SM100 CUTLASS MLA decode
ops.def( ops.def(
"sm100_cutlass_mla_decode(Tensor! out, Tensor! lse, Tensor q_nope," "sm100_cutlass_mla_decode(Tensor! out, Tensor! lse, Tensor q_nope,"
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for sparse cutlass kernels
Run `pytest tests/kernels/quantization/test_cutlass_2of4_sparse.py`.
"""
import pytest
import torch
from tests.kernels.utils import baseline_scaled_mm, to_fp8, to_int8
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
sparse_cutlass_supported,
)
from vllm.platforms import current_platform
CUDA_DEVICES = [
f"cuda:{i}" for i in range(1 if torch.accelerator.device_count() == 1 else 2)
]
capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]
def to_bf16(tensor: torch.Tensor) -> torch.Tensor:
return tensor.to(dtype=torch.bfloat16)
def to_fp16(tensor: torch.Tensor) -> torch.Tensor:
return tensor.to(dtype=torch.float16)
def prune_to_2_4(tensor):
# Reshape tensor to [N, 4] where N is number of groups of 4
original_shape = tensor.shape
reshaped = tensor.reshape(-1, 4)
# Get indices of top 2 absolute values in each group of 4
_, indices = torch.topk(torch.abs(reshaped), k=2, dim=1)
# Create binary mask
mask = torch.zeros_like(reshaped)
mask.scatter_(dim=1, index=indices, src=torch.ones_like(indices, dtype=mask.dtype))
# Apply mask and reshape back
pruned = reshaped * mask
# Turn all -0.0 to 0.0
pruned[pruned == -0.0] = 0.0
return pruned.reshape(original_shape)
# This function checks that applying an identity matrix multiplication
# to the compressed weights yields the original uncompressed weights.
def check_compress_decompress_invariance(
dtype: torch.dtype,
b: torch.Tensor,
b_compressed: torch.Tensor,
b_metadata: torch.Tensor,
):
# For float16 and bfloat16, cutlass_scaled_sparse_mm's output must be the
# same dtype as its inputs. This line addresses that constraint while
# arbitrarily using bfloat16 for the int8/fp8 cases.
out_dtype = torch.float16 if dtype is torch.float16 else torch.bfloat16
eye = torch.eye(b.shape[0], device="cuda", dtype=dtype)
eye_scale = torch.ones(1, device="cuda", dtype=torch.float32)
b_decomp = ops.cutlass_scaled_sparse_mm(
eye, b_compressed, b_metadata, eye_scale, eye_scale, out_dtype=out_dtype
)
torch.testing.assert_close(b.to(dtype=out_dtype), b_decomp)
def make_rand_sparse_tensors(
dtype: torch.dtype, m: int, n: int, k: int
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
a = torch.randn((m, k), device="cuda")
b = torch.randn((n, k), device="cuda").t()
if dtype == torch.int8:
# ensure A and B aren't all zeros after rounding
a = a * 5.0
b = b * 5.0
b = prune_to_2_4(b.t()).t()
if dtype == torch.int8:
a, b = to_int8(a), to_int8(b)
elif dtype == torch.float8_e4m3fn:
a, b = to_fp8(a), to_fp8(b)
elif dtype == torch.float16:
a, b = to_fp16(a), to_fp16(b)
elif dtype == torch.bfloat16:
a, b = to_bf16(a), to_bf16(b)
else:
raise ValueError("unsupported dtype")
b_compressed, e = ops.cutlass_sparse_compress(b.t())
check_compress_decompress_invariance(dtype, b, b_compressed, e)
# Compressed B, Metadata, Original A, B
return b_compressed, e, a, b
@pytest.mark.skipif(
not sparse_cutlass_supported(),
reason="Sparse CUTLASS is not supported on this GPU type.",
)
# Test working with a subset of A and B for sparse matmul
def test_cutlass_sparse_subset():
big_m = 1024
m, n, k = 512, 512, 512
# Create tensors
b_comp, e, whole_a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, big_m, n, k)
a = whole_a[0:m, 0:k]
scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
out = ops.cutlass_scaled_sparse_mm(
a, b_comp, e, scale_a, scale_b, out_dtype=torch.bfloat16
)
baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)
torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)
MNK_FACTORS = [
(1, 256, 128),
(1, 16384, 1024),
(1, 24576, 512),
(16, 256, 512),
(16, 16384, 128),
(16, 24576, 4096),
(32, 8192, 4096),
(32, 16384, 4096),
(33, 1024, 1024),
(33, 8192, 128),
(64, 2048, 512),
(64, 16384, 1024),
(100, 8192, 512),
(128, 32768, 4096),
(256, 4096, 4096),
(512, 256, 1024),
(512, 8192, 4096),
(512, 16384, 128),
(512, 24576, 128),
]
# Test working with a subset of A and B for sparse matmul
@pytest.mark.skipif(
not sparse_cutlass_supported(),
reason="Sparse CUTLASS is not supported on this GPU type.",
)
@pytest.mark.parametrize("m, n, k", MNK_FACTORS)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_sparse_gemm(
m: int, k: int, n: int, dtype: type[torch.dtype], use_bias: bool
):
# Create tensors
b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k)
scale_a = torch.ones((1, 1), device="cuda", dtype=torch.float32)
scale_b = torch.ones((1, 1), device="cuda", dtype=torch.float32)
bias = torch.rand((n,), device="cuda", dtype=dtype) if use_bias else None
out = ops.cutlass_scaled_sparse_mm(
a, b_comp, e, scale_a, scale_b, out_dtype=dtype, bias=bias
)
baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype=dtype, bias=bias)
torch.testing.assert_close(out, baseline, rtol=1e-2, atol=3e-1)
@pytest.mark.skipif(
not sparse_cutlass_supported(),
reason="Sparse CUTLASS is not supported on this GPU type.",
)
@pytest.mark.parametrize("m, k, n", MNK_FACTORS)
@pytest.mark.skipif(
not current_platform.has_device_capability(89),
reason="FP8 is not supported on this GPU type.",
)
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_sparse_fp8_gemm(m: int, n: int, k: int, use_bias: bool):
# Create tensors
b_comp, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k)
scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32)
scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32)
out_dtype = torch.bfloat16
bias = torch.rand((n,), device="cuda", dtype=out_dtype) * 10 if use_bias else None
out = ops.cutlass_scaled_sparse_mm(
a, b_comp, e, scale_a, scale_b, out_dtype=out_dtype, bias=bias
)
baseline = baseline_scaled_mm(
a, b, scale_a, scale_b, out_dtype=out_dtype, bias=bias
)
torch.testing.assert_close(out, baseline, rtol=1e-2, atol=3e-1)
@pytest.mark.skipif(
not sparse_cutlass_supported(),
reason="Sparse CUTLASS is not supported on this GPU type.",
)
@pytest.mark.parametrize("m,k,n", MNK_FACTORS)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_sparse_int8_gemm(
m: int, n: int, k: int, per_act_token: bool, per_out_ch: bool, use_bias: bool
):
# Create tensors
b_comp, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k)
scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32)
scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32)
out_dtype = torch.bfloat16
bias = torch.rand((n,), device="cuda", dtype=out_dtype) * 10 if use_bias else None
out = ops.cutlass_scaled_sparse_mm(
a, b_comp, e, scale_a, scale_b, out_dtype=out_dtype, bias=bias
)
baseline = baseline_scaled_mm(
a, b, scale_a, scale_b, out_dtype=out_dtype, bias=bias
)
torch.testing.assert_close(out, baseline, rtol=1e0, atol=2e0)
...@@ -12,7 +12,6 @@ from compressed_tensors.quantization import QuantizationType ...@@ -12,7 +12,6 @@ from compressed_tensors.quantization import QuantizationType
from tests.models.utils import check_logprobs_close from tests.models.utils import check_logprobs_close
from vllm.model_executor.layers.fused_moe import UnquantizedFusedMoEMethod from vllm.model_executor.layers.fused_moe import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensors24,
CompressedTensorsLinearMethod, CompressedTensorsLinearMethod,
CompressedTensorsW4A4Fp4, CompressedTensorsW4A4Fp4,
CompressedTensorsW4A8Fp8, CompressedTensorsW4A8Fp8,
...@@ -27,9 +26,6 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8 ...@@ -27,9 +26,6 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import ( from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
cutlass_fp4_supported, cutlass_fp4_supported,
) )
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
sparse_cutlass_supported,
)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.v1.attention.backends.fa_utils import get_flash_attn_version from vllm.v1.attention.backends.fa_utils import get_flash_attn_version
...@@ -362,283 +358,6 @@ def test_compressed_tensors_kv_cache_fp8_per_attn_head(vllm_runner): ...@@ -362,283 +358,6 @@ def test_compressed_tensors_kv_cache_fp8_per_attn_head(vllm_runner):
assert output assert output
@pytest.mark.skipif(
not sparse_cutlass_supported(),
reason="Sparse FP8 is not yet supported on this GPU type.",
)
def _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy, format="dense"):
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensors24)
assert qkv_proj.scheme.weight_quant.strategy == weight_strategy
assert qkv_proj.scheme.input_quant.strategy == input_strategy
assert qkv_proj.scheme.quantized
assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map
sparsity_map = qkv_proj.quant_method.quantization_config.sparsity_scheme_map # noqa: E501
assert sparsity_map.get("Linear").format == format
assert sparsity_map.get("Linear").sparsity_structure == "2:4"
@pytest.mark.skipif(
not current_platform.is_cuda() or not current_platform.has_device_capability(90),
reason="Sparse FP8 is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
"args_2of4",
[
(
"nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-2of4-testing",
"channel",
"token",
),
(
"nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-Per-Tensor-testing",
"channel",
"tensor",
),
(
"nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-testing",
"tensor",
"tensor",
),
(
"nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-IA-Per-Tensor-Weight-testing",
"tensor",
"token",
),
],
)
def test_compressed_tensors_2of4_quant_fp8(vllm_runner, args_2of4):
model, weight_strategy, input_strategy = args_2of4
with vllm_runner(model, enforce_eager=True) as llm:
def check_model(model):
layer = model.model.layers[0]
qkv_proj = layer.self_attn.qkv_proj
assert qkv_proj.scheme.weights_dtype == torch.float8_e4m3fn
_test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy)
llm.apply_model(check_model)
output = llm.generate_greedy("Hello my name is", max_tokens=4)
print(output)
assert output
@pytest.mark.skipif(
not current_platform.is_cuda() or not current_platform.has_device_capability(90),
reason="Sparse FP8 is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
"args_2of4",
[
(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM",
"channel",
"token",
),
(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_tensor_act_fp8-BitM",
"channel",
"tensor",
),
(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_per_tok_dyn_act_fp8-BitM",
"tensor",
"token",
),
(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_tensor_act_fp8-BitM",
"tensor",
"tensor",
),
],
)
def test_compressed_tensors_2of4_quant_fp8_compressed(vllm_runner, args_2of4):
model, weight_strategy, input_strategy = args_2of4
with vllm_runner(model, enforce_eager=True) as llm:
def check_model(model):
layer = model.model.layers[0]
qkv_proj = layer.self_attn.qkv_proj
assert qkv_proj.scheme.weights_dtype == torch.float8_e4m3fn
_test_2of4_quant_models(
qkv_proj,
weight_strategy,
input_strategy,
format="sparse-24-bitmask",
)
llm.apply_model(check_model)
output = llm.generate_greedy("Hello my name is", max_tokens=4)
print(output)
assert output
@pytest.mark.skipif(
not sparse_cutlass_supported(),
reason="cutlass is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
"args_2of4",
[
(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_int8-BitM",
"channel",
"token",
),
(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_tensor_act_int8-BitM",
"channel",
"tensor",
),
(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_per_tok_dyn_act_int8-BitM",
"tensor",
"token",
),
(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_tensor_act_int8-BitM",
"tensor",
"tensor",
),
],
)
def test_compressed_tensors_2of4_quant_int8_compressed(vllm_runner, args_2of4):
model, weight_strategy, input_strategy = args_2of4
with vllm_runner(model, enforce_eager=True) as llm:
def check_model(model):
layer = model.model.layers[0]
qkv_proj = layer.self_attn.qkv_proj
assert qkv_proj.scheme.weights_dtype == torch.int8
_test_2of4_quant_models(
qkv_proj,
weight_strategy,
input_strategy,
format="sparse-24-bitmask",
)
llm.apply_model(check_model)
output = llm.generate_greedy("Hello my name is", max_tokens=4)
print(output)
assert output
@pytest.mark.skipif(
not sparse_cutlass_supported(),
reason="Sparse FP8 is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
"args_2of4",
[
(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Channel-Weight-testing",
"channel",
"token",
),
(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Static-testing",
"tensor",
"tensor",
),
(
"nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Tensor-Weight-testing",
"tensor",
"token",
),
],
)
def test_compressed_tensors_2of4_quant_int8(vllm_runner, args_2of4):
model, weight_strategy, input_strategy = args_2of4
with vllm_runner(model, enforce_eager=True) as llm:
def check_model(model):
layer = model.model.layers[0]
qkv_proj = layer.self_attn.qkv_proj
assert qkv_proj.scheme.weights_dtype == torch.int8
_test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy)
llm.apply_model(check_model)
output = llm.generate_greedy("Hello my name is", max_tokens=4)
print(output)
assert output
@pytest.mark.skipif(
not sparse_cutlass_supported(),
reason="2of4 Sparse is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
"args_2of4",
[("nm-testing/TinyLlama-1.1B-Chat-v1.0-2of4-Sparse-Dense-Compressor")],
)
def test_compressed_tensors_2of4_sparse(vllm_runner, args_2of4):
model = args_2of4
with vllm_runner(model, enforce_eager=True) as llm:
def check_model(model):
layer = model.model.layers[0]
qkv_proj = layer.self_attn.qkv_proj
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensors24)
assert qkv_proj.scheme.weight_quant is None
assert qkv_proj.scheme.input_quant is None
assert not qkv_proj.scheme.quantized
assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map
sparsity_map = qkv_proj.quant_method.quantization_config.sparsity_scheme_map # noqa: E501
assert sparsity_map.get("Linear").format == "dense"
assert sparsity_map.get("Linear").sparsity_structure == "2:4"
llm.apply_model(check_model)
output = llm.generate_greedy("Hello my name is", max_tokens=4)
print(output)
assert output
@pytest.mark.skipif(
not sparse_cutlass_supported(),
reason="Cutlass is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
"args_2of4", [("nm-testing/llama2.c-stories42M-pruned2.4-compressed")]
)
def test_compressed_tensors_2of4_sparse_compressed(vllm_runner, args_2of4):
model = args_2of4
with vllm_runner(model, enforce_eager=True) as llm:
def check_model(model):
layer = model.model.layers[0]
qkv_proj = layer.self_attn.qkv_proj
assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
assert isinstance(qkv_proj.scheme, CompressedTensors24)
assert qkv_proj.scheme.weight_quant is None
assert qkv_proj.scheme.input_quant is None
assert not qkv_proj.scheme.quantized
assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map
sparsity_map = qkv_proj.quant_method.quantization_config.sparsity_scheme_map # noqa: E501
assert sparsity_map.get("Linear").format == "sparse-24-bitmask"
assert sparsity_map.get("Linear").sparsity_structure == "2:4"
llm.apply_model(check_model)
output = llm.generate_greedy("Hello my name is", max_tokens=4)
print(output)
assert output
@pytest.mark.skipif( @pytest.mark.skipif(
not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform." not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
) )
......
...@@ -20,8 +20,6 @@ compressed-tensors, nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test, main ...@@ -20,8 +20,6 @@ compressed-tensors, nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test, main
compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
#compressed-tensors, mgoin/DeepSeek-Coder-V2-Lite-Instruct-FP8, main #compressed-tensors, mgoin/DeepSeek-Coder-V2-Lite-Instruct-FP8, main
compressed-tensors, nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-FP8-Dynamic-testing, main, 90
compressed-tensors, nm-testing/SparseLlama-3.1-8B-gsm8k-pruned.2of4-W8A8-testing, main, 90
awq, casperhansen/mixtral-instruct-awq, main awq, casperhansen/mixtral-instruct-awq, main
awq_marlin, casperhansen/mixtral-instruct-awq, main awq_marlin, casperhansen/mixtral-instruct-awq, main
fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main
......
...@@ -876,10 +876,6 @@ def cutlass_scaled_mm_azp( ...@@ -876,10 +876,6 @@ def cutlass_scaled_mm_azp(
return out.view(*target_shape) return out.view(*target_shape)
def cutlass_sparse_scaled_mm_supported(cuda_device_capability: int) -> bool:
return torch.ops._C.cutlass_sparse_scaled_mm_supported(cuda_device_capability)
def cutlass_group_gemm_supported(cuda_device_capability: int) -> bool: def cutlass_group_gemm_supported(cuda_device_capability: int) -> bool:
if cuda_device_capability < 90 or cuda_device_capability >= 110: if cuda_device_capability < 90 or cuda_device_capability >= 110:
return False return False
...@@ -890,94 +886,6 @@ def cutlass_group_gemm_supported(cuda_device_capability: int) -> bool: ...@@ -890,94 +886,6 @@ def cutlass_group_gemm_supported(cuda_device_capability: int) -> bool:
return False return False
def cutlass_sparse_compress(a: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""
Compresses a sparse matrix for use with Cutlass sparse operations.
This function takes a dense tensor and compresses it into two components:
non-zero elements and metadata. The compressed representation is compatible
with Cutlass sparse kernels.
Args:
a (torch.Tensor):
The input tensor to be compressed. Must have one of the following data types:
- `torch.int8`
- `torch.float8_e4m3fn`
- `torch.bfloat16`
- `torch.float16`
Returns:
tuple[torch.Tensor, torch.Tensor]:
A tuple containing:
- `a_nzs` (torch.Tensor): A tensor containing non-zero elements of `a`.
- `a_meta` (torch.Tensor): A tensor containing metadata for the sparse representation.
Raises:
ValueError: If the compression operation fails.
Notes:
- The `a_meta` tensor has a data type of `torch.uint8`.
- Each metadata element encodes the sparsity of 4 non-zero elements (i.e., `elemsPerMetaElem = 4`).
- The shape of `a_nzs` is `(m, k // 2)`, where `m` and `k` are the dimensions of the input tensor.
- The shape of `a_meta` is `(m, k // 2 // elemsPerMetaElem)`.
"""
assert a.dtype in [torch.int8, torch.float8_e4m3fn, torch.bfloat16, torch.float16]
assert a.is_contiguous()
# a_meta.dtype: torch.uint8 so elemsPerMetaElem = 8b / 2b_per_nz = 4
elemsPerMetaElem = 4
assert a.shape[1] % (2 * elemsPerMetaElem) == 0
return torch.ops._C.cutlass_sparse_compress(a)
def cutlass_scaled_sparse_mm(
a: torch.Tensor,
bt_nzs: torch.Tensor,
bt_meta: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out_dtype: torch.dtype,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
"""
Performs a scaled sparse matrix multiplication using Cutlass.
Steps:
1. Create a dense matrix `a` of shape (m, k) on the CUDA device:
`a = torch.randn((m, k), device='cuda')`.
2. Create a dense matrix `b` of shape (k, n) on the CUDA device:
`b = torch.randn((k, n), device='cuda')`.
3. Prune matrix `b` to 2:4 sparsity along the specified dimension:
`b = prune_to_2_4(b, dim=0)`.
4. Compress the transposed sparse matrix `b.t()`:
`bt_nzs, bt_meta = cutlass_sparse_compress(b.t())`.
5. Perform sparse matrix multiplication using the compressed matrix,
applying scaling factors for `a` and `b`, and the output data type:
`out = cutlass_scaled_sparse_mm(a, bt_nzs, bt_meta, scale_a, scale_b, out_dtype)`.
Returns:
- The result of the scaled sparse matrix multiplication.
"""
assert bt_nzs.shape[0] % 16 == 0 and bt_nzs.shape[1] % 16 == 0
assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
assert bias is None or bias.shape[0] == bt_nzs.shape[0] and bias.dtype == out_dtype
m = a.shape[0]
n = bt_nzs.shape[0]
out = torch.empty((m, n), dtype=out_dtype, device=a.device)
torch.ops._C.cutlass_scaled_sparse_mm(
out, a, bt_nzs, bt_meta, scale_a, scale_b, bias
)
return out
def get_cutlass_moe_mm_data( def get_cutlass_moe_mm_data(
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
expert_offsets: torch.Tensor, expert_offsets: torch.Tensor,
......
...@@ -5,39 +5,16 @@ from collections.abc import Callable ...@@ -5,39 +5,16 @@ from collections.abc import Callable
from typing import Any from typing import Any
import torch import torch
from compressed_tensors import CompressionFormat, ModelCompressor
from compressed_tensors.quantization import ( from compressed_tensors.quantization import (
QuantizationArgs, QuantizationArgs,
QuantizationStrategy,
QuantizationType,
) )
from compressed_tensors.utils import combine_shards
from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme, CompressedTensorsScheme,
) )
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
convert_to_channelwise,
sparse_cutlass_supported,
)
from vllm.model_executor.parameter import (
BasevLLMParameter,
ChannelQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter,
)
__all__ = ["CompressedTensors24"] __all__ = ["CompressedTensors24"]
from vllm.platforms import current_platform
class CompressedTensors24(CompressedTensorsScheme): class CompressedTensors24(CompressedTensorsScheme):
def __init__( def __init__(
...@@ -47,33 +24,11 @@ class CompressedTensors24(CompressedTensorsScheme): ...@@ -47,33 +24,11 @@ class CompressedTensors24(CompressedTensorsScheme):
input_quant: QuantizationArgs | None = None, input_quant: QuantizationArgs | None = None,
model_compression_config: dict[str, Any] | None = None, model_compression_config: dict[str, Any] | None = None,
): ):
self.quantized = quantized raise NotImplementedError("Sparse24 models are no longer supported by vLLM")
self.weight_quant = weight_quant
self.input_quant = input_quant
model_compressor = ModelCompressor.from_compression_config(
model_compression_config
)
self.do_sparse_decompress = (
model_compressor is not None
and model_compressor.sparsity_config.format
== CompressionFormat.sparse_24_bitmask.value
)
if self.do_sparse_decompress:
self.model_compressor = model_compressor
if (
quantized
and input_quant is not None
and self._get_quant_dtype() == current_platform.fp8_dtype()
):
static = not input_quant.dynamic
g_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN
self.quant_fp8 = QuantFP8(static, g_shape)
@classmethod @classmethod
def get_min_capability(cls) -> int: def get_min_capability(cls) -> int:
# Only cutlass 3.x kernels are implemented so far raise NotImplementedError("Sparse24 models are no longer supported by vLLM")
return 90
def create_weights( def create_weights(
self, self,
...@@ -85,164 +40,10 @@ class CompressedTensors24(CompressedTensorsScheme): ...@@ -85,164 +40,10 @@ class CompressedTensors24(CompressedTensorsScheme):
weight_loader: Callable, weight_loader: Callable,
**kwargs, **kwargs,
): ):
if not sparse_cutlass_supported(): raise NotImplementedError("Sparse24 models are no longer supported by vLLM")
raise ValueError(
"Sparse CUTLASS not supported. vLLM must be built with "
"CUDA 12.2 or later to use this feature"
)
layer.logical_widths = output_partition_sizes
layer.input_size = input_size
layer.input_size_per_partition = input_size_per_partition
self.weights_dtype: torch.dtype = self._get_params_dtype(params_dtype)
# parameter to store uncompressed weight
weight = ModelWeightParameter(
data=torch.empty(
sum(output_partition_sizes),
input_size_per_partition,
dtype=self.weights_dtype,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
if self.do_sparse_decompress:
assert all(
partition_size % 8 == 0 for partition_size in output_partition_sizes
), "All partitions must be divisible by 8 for "
"2:4 sparse compressed models"
shape = BasevLLMParameter(
data=torch.empty(2, 1, dtype=torch.int64),
weight_loader=weight_loader,
)
compressed_weight = ModelWeightParameter(
data=torch.empty(
sum(output_partition_sizes),
input_size_per_partition // 2,
dtype=self.weights_dtype,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
bitmask = ModelWeightParameter(
data=torch.empty(
sum(output_partition_sizes),
input_size_per_partition // 8,
dtype=torch.uint8,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("shape", shape)
layer.register_parameter("compressed", compressed_weight)
layer.register_parameter("bitmask", bitmask)
# Check if quantized, not just 2:4 Sparse
if self.quantized:
if (
self.weight_quant
and self.weight_quant.strategy == QuantizationStrategy.CHANNEL.value
):
weight_scale = ChannelQuantScaleParameter(
data=torch.empty(
(sum(output_partition_sizes), 1), dtype=torch.float32
),
output_dim=0,
weight_loader=weight_loader,
)
else:
assert (
self.weight_quant
and self.weight_quant.strategy == QuantizationStrategy.TENSOR.value
)
weight_scale = PerTensorScaleParameter(
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("weight_scale", weight_scale)
# input quant will be non-none
if self.input_quant and not self.input_quant.dynamic:
# register input quant scale
assert self.input_quant.strategy == QuantizationStrategy.TENSOR.value
input_scale = BasevLLMParameter(
data=torch.empty(1, dtype=torch.float32),
weight_loader=weight_loader,
)
layer.register_parameter("input_scale", input_scale)
else:
# for sparse-only, pass in 1 for weight/input scales
weight_scale = torch.nn.Parameter(
data=torch.ones(1, dtype=torch.float32), requires_grad=False
)
input_scale = torch.nn.Parameter(
data=torch.ones(1, dtype=torch.float32), requires_grad=False
)
layer.register_parameter("input_scale", input_scale)
layer.register_parameter("weight_scale", weight_scale)
layer.register_parameter("weight", weight)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
""" raise NotImplementedError("Sparse24 models are no longer supported by vLLM")
Compress weights after loading. Store compressed weight and meta
tensor
:post-condition: layer.w_compressed and layer.meta are
set to the compressed weight and meta tensor in the
format expected by the Cutlass kernels
:param layer: The layer with the weights to be processed
"""
if self.do_sparse_decompress:
layer.weight.data = self._decompress_bitmask_compressed_weight(
compressed=layer.compressed,
bitmask=layer.bitmask,
layer=layer,
)
# compressed and bitmask tensors
# are no longer needed after decompression
del layer.compressed
del layer.bitmask
# torch.compile workaround
if hasattr(layer, "input_scale"):
layer.input_scale = torch.nn.Parameter(
layer.input_scale.data, requires_grad=False
)
if self.weight_quant:
if self.weight_quant.strategy == QuantizationStrategy.TENSOR.value:
layer.weight_scale = torch.nn.Parameter(
convert_to_channelwise(
weight_scale=layer.weight_scale,
logical_widths=layer.logical_widths,
),
requires_grad=False,
)
else:
# torch.compile workaround
layer.weight_scale = torch.nn.Parameter(
layer.weight_scale.data, requires_grad=False
)
# Set all negative zero values to 0 prior to compression
if layer.weight.dtype.is_floating_point and layer.weight.dtype.itemsize >= 2:
layer.weight.data[layer.weight.data == -0.0] = 0.0
w_compressed, meta = ops.cutlass_sparse_compress(layer.weight.data)
layer.weight = torch.nn.Parameter(w_compressed, requires_grad=False)
layer.meta = torch.nn.Parameter(meta, requires_grad=False)
def apply_weights( def apply_weights(
self, self,
...@@ -250,143 +51,4 @@ class CompressedTensors24(CompressedTensorsScheme): ...@@ -250,143 +51,4 @@ class CompressedTensors24(CompressedTensorsScheme):
x: torch.Tensor, x: torch.Tensor,
bias: torch.Tensor | None = None, bias: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
""" raise NotImplementedError("Sparse24 models are no longer supported by vLLM")
Returns the output tensor for the layer with 2:4
sparse compressed weights, given the input tensor
and bias
:param layer: The layer with 2:4 sparse compressed
weights to be used for the computation
:param x: The input tensor to the layer
:param bias: The bias to be added to the output tensor
:return: The output tensor of the layer
"""
if self.quantized:
scale = getattr(layer, "input_scale", None)
if self.weights_dtype == torch.int8:
ops_output = ops.scaled_int8_quant(x, scale=scale)
q_input = ops_output[0]
input_scale = ops_output[1]
else:
assert self.weights_dtype == torch.float8_e4m3fn
q_input, input_scale = self.quant_fp8(x, scale=scale)
else:
# Not quantized, nothing to do with the input_scales, use as is
input_scale = layer.input_scale
q_input = x
out = ops.cutlass_scaled_sparse_mm(
a=q_input,
bt_nzs=layer.weight,
bt_meta=layer.meta,
scale_a=input_scale,
scale_b=layer.weight_scale,
out_dtype=x.dtype,
bias=bias,
)
assert out.is_contiguous()
return out
def _get_params_dtype(self, params_dtype: torch.dtype) -> torch.dtype:
if not self.quantized:
return params_dtype
return self._get_quant_dtype()
def _get_quant_dtype(self) -> torch.dtype:
assert self.quantized
assert self.weight_quant is not None
assert self.input_quant is not None
is_8_bits = self.weight_quant.num_bits == self.input_quant.num_bits == 8
if not is_8_bits:
raise ValueError("Cutlass only supports 8-bit quantization")
if (
self.weight_quant.type == QuantizationType.FLOAT
and self.input_quant.type == QuantizationType.FLOAT
):
return torch.float8_e4m3fn
if (
self.weight_quant.type == QuantizationType.INT
and self.input_quant.type == QuantizationType.INT
):
return torch.int8
raise ValueError("Quantization type not supported by Cutlass")
def _decompress_bitmask_compressed_weight(
self,
compressed: torch.Tensor,
bitmask: torch.Tensor,
layer: torch.nn.Module,
) -> torch.Tensor:
"""
Decompress a compressed 2:4 sparse weight tensor using the bitmask and
return the result.
This function also supports sharded decompression.
:param compressed: The 2:4 sparse weight tensor compressed using the
sparse-24-bitmask compressor. This is different from
`cutlass_sparse_compress` which uses a different scheme (2 bits for
every nonzero element that represent the coordinate within the block
of 4). The bitmask compression here uses a bitmask to indicate the
positions of non-zero elements.
:param bitmask: The 2:4 bitmask associated with the compressed weights,
representing the positions of non-zero elements in the compressed
tensor.
:param layer: The layer whose weights need to be processed after
loading.
:return: The decompressed 2:4 sparse weight tensor.
"""
sparsity_compressor = self.model_compressor.sparsity_compressor
def _process_split(
bitmask_compressed_weight: torch.Tensor,
shape,
bitmask: torch.Tensor,
) -> torch.Tensor:
weight_data = dict(
compressed=bitmask_compressed_weight,
shape=shape,
bitmask=bitmask,
)
return sparsity_compressor.decompress_weight(weight_data)
split_weights: list[torch.Tensor] = []
split_bitmask: list[torch.Tensor] = []
split_shape: list[tuple[int, int]] = []
if isinstance(layer, (QKVParallelLinear, MergedColumnParallelLinear)):
split_weights = torch.split(compressed, layer.logical_widths)
split_bitmask = torch.split(bitmask, layer.logical_widths)
split_shape = [
(out, layer.input_size_per_partition) for out in layer.logical_widths
]
if split_weights:
decompressed_shards = [
_process_split(compressed_weight, shape, bitmask)
for compressed_weight, shape, bitmask in zip(
split_weights, split_shape, split_bitmask
)
]
decompressed = combine_shards(decompressed_shards)
else:
decompressed = sparsity_compressor.decompress_weight(
dict(
compressed=compressed,
shape=(
layer.logical_widths[0],
layer.input_size_per_partition,
),
bitmask=bitmask,
)
)
return decompressed
...@@ -20,7 +20,7 @@ class CompressedTensorsScheme(ABC): ...@@ -20,7 +20,7 @@ class CompressedTensorsScheme(ABC):
""" """
Get minimum device capability. Get minimum device capability.
""" """
raise NotImplementedError raise NotImplementedError()
@abstractmethod @abstractmethod
def create_weights(self, *args, **kwargs): def create_weights(self, *args, **kwargs):
...@@ -28,7 +28,7 @@ class CompressedTensorsScheme(ABC): ...@@ -28,7 +28,7 @@ class CompressedTensorsScheme(ABC):
Weight creation for the particular scheme. Inputs to this function Weight creation for the particular scheme. Inputs to this function
""" """
raise NotImplementedError raise NotImplementedError()
@abstractmethod @abstractmethod
def apply_weights( def apply_weights(
...@@ -44,7 +44,7 @@ class CompressedTensorsScheme(ABC): ...@@ -44,7 +44,7 @@ class CompressedTensorsScheme(ABC):
:param bias: bias parameter :param bias: bias parameter
""" """
raise NotImplementedError raise NotImplementedError()
@abstractmethod @abstractmethod
def process_weights_after_loading(self, layer: torch.nn.Module): def process_weights_after_loading(self, layer: torch.nn.Module):
...@@ -52,4 +52,4 @@ class CompressedTensorsScheme(ABC): ...@@ -52,4 +52,4 @@ class CompressedTensorsScheme(ABC):
Called after weight loading is complete for any cleanup that Called after weight loading is complete for any cleanup that
needs to occur. needs to occur.
""" """
raise NotImplementedError raise NotImplementedError()
...@@ -8,16 +8,6 @@ from vllm import _custom_ops as ops ...@@ -8,16 +8,6 @@ from vllm import _custom_ops as ops
from vllm.platforms import current_platform from vllm.platforms import current_platform
def sparse_cutlass_supported() -> bool:
if not current_platform.is_cuda():
return False
capability_tuple = current_platform.get_device_capability()
capability = -1 if capability_tuple is None else capability_tuple.to_int()
return ops.cutlass_sparse_scaled_mm_supported(capability)
def cutlass_fp8_supported() -> bool: def cutlass_fp8_supported() -> bool:
if not current_platform.is_cuda(): if not current_platform.is_cuda():
return False return False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment