Unverified Commit 5288c06a authored by Lucas Wilkinson's avatar Lucas Wilkinson Committed by GitHub
Browse files

[Kernel] (1/N) Machete - Hopper Optimized Mixed Precision Linear Kernel (#7174)

parent b6f99a6f
...@@ -87,6 +87,9 @@ target/ ...@@ -87,6 +87,9 @@ target/
profile_default/ profile_default/
ipython_config.py ipython_config.py
# generated files
**/generated/**
# pyenv # pyenv
# For a library or package, you might want to ignore these files since the code is # For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in: # intended to run in multiple environments; otherwise, check them in:
......
...@@ -227,6 +227,46 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") ...@@ -227,6 +227,46 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"-gencode arch=compute_90a,code=sm_90a") "-gencode arch=compute_90a,code=sm_90a")
endif() endif()
#
# For the Machete kernels we automatically generate sources for various
# preselected input type pairs and schedules.
# Generate sources:
execute_process(
COMMAND ${CMAKE_COMMAND} -E env
PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:$PYTHONPATH
${Python_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/machete/generate.py
RESULT_VARIABLE machete_generation_result
OUTPUT_VARIABLE machete_generation_output
OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log
ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log
)
if (NOT machete_generation_result EQUAL 0)
message(FATAL_ERROR "Machete generation failed."
" Result: \"${machete_generation_result}\""
"\nCheck the log for details: "
"${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log")
else()
message(STATUS "Machete generation completed successfully.")
endif()
# Add machete generated sources
file(GLOB MACHETE_GEN_SOURCES "csrc/quantization/machete/generated/*.cu")
list(APPEND VLLM_EXT_SRC ${MACHETE_GEN_SOURCES})
message(STATUS "Machete generated sources: ${MACHETE_GEN_SOURCES}")
# See comment above for scaled_mm_c3x (same if condition)
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0)
set_source_files_properties(
${MACHETE_GEN_SOURCES}
PROPERTIES
COMPILE_FLAGS
"-gencode arch=compute_90a,code=sm_90a")
endif()
# Add pytorch binding
list(APPEND VLLM_EXT_SRC
csrc/quantization/machete/machete_pytorch.cu)
endif() endif()
define_gpu_extension_target( define_gpu_extension_target(
......
import argparse
import copy
import itertools
import math
import pickle as pkl
import time
from typing import Callable, Iterable, List, Tuple
import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
from weight_shapes import WEIGHT_SHAPES
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, marlin_permute_scales)
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
MarlinWorkspace)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
gptq_pack, pack_rows, quantize_weights)
from vllm.scalar_type import ScalarType, scalar_types
from vllm.utils import FlexibleArgumentParser
DEFAULT_MODELS = ["meta-llama/Llama-3-8b", "meta-llama/Llama-2-70b-hf"]
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024]
DEFAULT_TP_SIZES = [1]
def machete_pack_weights(w_q: torch.tensor, wtype: ScalarType) -> torch.tensor:
w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape)
w_q = w_q.t().contiguous().t() # make col major
return ops.machete_prepack_B(w_q, wtype)
def make_bench_tensors(
atype: torch.dtype, wtype: ScalarType, group_size: int, m: int, n: int,
k: int
) -> Tuple[torch.tensor, List[Tuple[torch.tensor, torch.tensor, torch.tensor,
torch.tensor]]]:
assert wtype.is_integer(), "TODO: support floating point weights"
# we want to make sure that weights don't fit into L2 cache between runs so
# we construct enough weights to exceed L2 cache, which is 50mb on a H100
# so we target total weight size > 2*50mb
num_weights = math.ceil(2 * 50 * 1024**2 * 8 / (k * n * wtype.size_bits))
a = torch.randn((m, k), device="cuda", dtype=atype) * 5
weights = [
torch.randn((k, n), device="cuda", dtype=atype)
for _ in range(num_weights)
]
quanitized_weights = [
quantize_weights(w, wtype, group_size) for w in weights
]
return a, quanitized_weights
# impl
# bench
def bench_fn(label: str, sub_label: str, description: str,
fn: Callable) -> TMeasurement:
min_run_time = 1
return TBenchmark.Timer(
stmt="fn()",
globals={
"fn": fn
},
label=label,
sub_label=sub_label,
description=description,
).blocked_autorange(min_run_time=min_run_time)
def loop_over_weights(
a: torch.tensor, weights: List[Tuple[torch.tensor, torch.tensor,
torch.tensor, torch.tensor]],
fn: Callable[[torch.tensor, torch.tensor, torch.tensor, torch.tensor],
None]):
for w_ref, w_q, w_s, _ in weights:
fn(a, w_ref, w_q, w_s)
def bench(atype: torch.dtype,
wtype: ScalarType,
group_size: int,
m: int,
k: int,
n: int,
label: str,
sub_label: str,
benchmark_marlinv1: bool = True,
sweep_schedules: bool = True) -> Iterable[TMeasurement]:
a, weights = make_bench_tensors(atype, wtype, group_size, m, n, k)
sub_label += f", L={len(weights)}"
weights_machete = [(w_ref, machete_pack_weights(w_q, wtype), w_s, w_zp)
for w_ref, w_q, w_s, w_zp in weights]
timers = []
# pytorch impl
timers.append(
bench_fn(
label, sub_label, "torch.matmul", lambda: loop_over_weights(
a,
weights,
lambda a, w_ref, w_q, w_s: torch.matmul(a, w_ref),
)))
if benchmark_marlinv1:
w_ref = weights[0][0]
w_zp_empty = torch.empty(0, dtype=torch.int, device=w_ref.device)
sort_indices = torch.empty(0, dtype=torch.int, device=w_ref.device)
g_idx = torch.empty(0, dtype=torch.int, device=w_ref.device)
def marlinv1_pack_weights(w_q: torch.tensor) -> torch.tensor:
w_q_gptq = gptq_pack(w_q, wtype.size_bits, *w_ref.shape)
return ops.gptq_marlin_repack(w_q_gptq, sort_indices, *w_ref.shape,
wtype.size_bits)
def marlinv1_permute_scales(w_s: torch.tensor) -> torch.tensor:
return marlin_permute_scales(w_s, *w_ref.shape, group_size)
weights_marlinv1 = [(w_ref, marlinv1_pack_weights(w_q),
marlinv1_permute_scales(w_s), w_zp)
for w_ref, w_q, w_s, w_zp in weights]
workspace = MarlinWorkspace(w_ref.shape[1], GPTQ_MARLIN_MIN_THREAD_N,
GPTQ_MARLIN_MAX_PARALLEL)
# marlinv1
timers.append(
bench_fn(
label, sub_label, "marlin_orig", lambda: loop_over_weights(
a, weights_marlinv1, lambda a, w_ref, w_q, w_s: ops.
gptq_marlin_gemm(a,
w_q,
w_s,
w_zp_empty,
g_idx,
sort_indices,
workspace.scratch,
wtype,
size_m=a.shape[0],
size_n=w_ref.shape[1],
size_k=w_ref.shape[0],
is_k_full=True))))
# machete
timers.append(
bench_fn(
label, sub_label, "machete_heuristic", lambda: loop_over_weights(
a, weights_machete, lambda a, _, w_q, w_s: ops.machete_gemm(
a, w_q, wtype, b_scales=w_s, b_group_size=group_size))))
if sweep_schedules:
print("Finding best schedule for machete")
best = None
best_schedule = None
schedules = ops.machete_supported_schedules(wtype)
for schedule in reversed(schedules):
def run(a, _, w_q, w_s, schedule=schedule):
ops.machete_gemm(a,
w_q,
wtype,
w_s,
b_group_size=group_size,
schedule=schedule)
res = bench_fn(label, sub_label, "machete_best",
lambda: loop_over_weights(a, weights_machete, run))
print(f" {res.median:5.5} ", schedule)
if not best or res.median < best.median:
best = res
best_schedule = schedule
print("Best schedule:", best_schedule)
timers.append(best)
return timers
# runner
def print_timers(timers: Iterable[TMeasurement]):
compare = TBenchmark.Compare(timers)
compare.print()
def run(dtype: torch.dtype, sweep_schedules: bool,
MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]:
results = []
for m, k, n in MKNs:
timers = bench(dtype,
scalar_types.uint4b8,
128,
m,
k,
n,
f"{dtype}-gemm",
f"MKN=({m}x{k}x{n})",
sweep_schedules=sweep_schedules)
print_timers(timers)
results.extend(timers)
return results
# output makers
def make_output(
data: Iterable[TMeasurement],
MKNs: Iterable[Tuple[int, int, int]],
base_description: str,
timestamp=None,
):
print(f"== All Results {base_description} ====")
print_timers(data)
# pickle all the results
timestamp = int(time.time()) if timestamp is None else timestamp
with open(f"{base_description}-{timestamp}.pkl", "wb") as f:
pkl.dump(data, f)
# argparse runners
def run_square_bench(args):
dim_sizes = list(
range(args.dim_start, args.dim_end + 1, args.dim_increment))
MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
data = run(args.dtype, args.sweep_schedules, MKNs)
make_output(data, MKNs, f"square_bench-{args.dtype}")
def run_range_bench(args):
dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment))
n = len(dim_sizes)
Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes
Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes
Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes
MKNs = list(zip(Ms, Ks, Ns))
data = run(args.dtype, args.sweep_schedules, MKNs)
make_output(data, MKNs, f"range_bench-{args.dtype}")
def run_model_bench(args):
print("Benchmarking models:")
for i, model in enumerate(args.models):
print(f"[{i}] {model}")
def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]:
KNs = []
for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
KN[tp_split_dim] = KN[tp_split_dim] // tp_size
KNs.append(KN)
return KNs
model_bench_data = []
models_tps = list(itertools.product(args.models, args.tp_sizes))
for model, tp_size in models_tps:
Ms = args.batch_sizes
KNs = model_shapes(model, tp_size)
MKNs = []
for m in Ms:
for k, n in KNs:
MKNs.append((m, k, n))
data = run(args.dtype, args.sweep_schedules, MKNs)
model_bench_data.append(data)
# Print all results
for data, model_tp in zip(model_bench_data, models_tps):
model, tp_size = model_tp
print(f"== Results {args.dtype} {model}-TP{tp_size} ====")
print_timers(data)
timestamp = int(time.time())
all_data = []
for d in model_bench_data:
all_data.extend(d)
# pickle all data
with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f:
pkl.dump(all_data, f)
if __name__ == "__main__":
def to_torch_dtype(dt):
if dt == "bfloat16":
return torch.bfloat16
if dt == "float16":
return torch.float16
raise ValueError("unsupported dtype")
parser = FlexibleArgumentParser(
description="""
Benchmark Machete GEMM.
To run square GEMMs:
python3 ./benchmarks/kernels/benchmark_machete.py --dtype float16 square_bench --dim-start 128 --dim-end 512 --dim-increment 64
To run constant N and K and sweep M:
python3 ./benchmarks/kernels/benchmark_machete.py --dtype float16 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384
To run dimensions from a model:
python3 ./benchmarks/kernels/benchmark_machete.py --dtype float16 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1
Output:
- a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs.
""", # noqa: E501
formatter_class=argparse.RawTextHelpFormatter,
)
parser.add_argument(
"--dtype",
type=to_torch_dtype,
required=True,
help="Available options are ['bfloat16', 'float16']",
)
parser.add_argument(
"--sweep-schedules",
action="store_true",
help="Run a sweep over all supported schedules",
)
subparsers = parser.add_subparsers(dest="cmd", required=True)
square_parser = subparsers.add_parser("square_bench")
square_parser.add_argument("--dim-start", type=int, required=True)
square_parser.add_argument("--dim-end", type=int, required=True)
square_parser.add_argument("--dim-increment", type=int, required=True)
square_parser.set_defaults(func=run_square_bench)
range_parser = subparsers.add_parser("range_bench")
range_parser.add_argument("--dim-start", type=int, required=True)
range_parser.add_argument("--dim-end", type=int, required=True)
range_parser.add_argument("--dim-increment", type=int, required=True)
range_parser.add_argument("--m-constant", type=int, default=None)
range_parser.add_argument("--n-constant", type=int, default=None)
range_parser.add_argument("--k-constant", type=int, default=None)
range_parser.set_defaults(func=run_range_bench)
model_parser = subparsers.add_parser("model_bench")
model_parser.add_argument(
"--models",
nargs="+",
type=str,
default=DEFAULT_MODELS,
choices=WEIGHT_SHAPES.keys(),
)
model_parser.add_argument("--tp-sizes",
nargs="+",
type=int,
default=DEFAULT_TP_SIZES)
model_parser.add_argument("--batch-sizes",
nargs="+",
type=int,
default=DEFAULT_BATCH_SIZES)
model_parser.set_defaults(func=run_model_bench)
args = parser.parse_args()
args.func(args)
import math
import pickle
import re
from collections import defaultdict
from typing import List
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from torch.utils.benchmark import Measurement as TMeasurement
from vllm.utils import FlexibleArgumentParser
if __name__ == "__main__":
parser = FlexibleArgumentParser(
description='Benchmark the latency of processing a single batch of '
'requests till completion.')
parser.add_argument('filename', type=str)
args = parser.parse_args()
with open(args.filename, 'rb') as f:
data: List[TMeasurement] = pickle.load(f)
results = defaultdict(lambda: list())
for v in data:
result = re.search(r"MKN=\(\d+x(\d+x\d+)\)", v.task_spec.sub_label)
if result is not None:
KN = result.group(1)
else:
raise Exception("MKN not found")
result = re.search(r"MKN=\((\d+)x\d+x\d+\)", v.task_spec.sub_label)
if result is not None:
M = result.group(1)
else:
raise Exception("MKN not found")
kernel = v.task_spec.description
results[KN].append({
"kernel": kernel,
"batch_size": M,
"median": v.median
})
rows = int(math.ceil(len(results) / 2))
fig, axs = plt.subplots(rows, 2, figsize=(12, 5 * rows))
axs = axs.flatten()
axs_idx = 0
for shape, data in results.items():
plt.sca(axs[axs_idx])
df = pd.DataFrame(data)
sns.lineplot(data=df,
x="batch_size",
y="median",
hue="kernel",
style="kernel",
markers=True,
dashes=False,
palette="Dark2")
plt.title(f"Shape: {shape}")
plt.ylabel("time (median, s)")
axs_idx += 1
plt.tight_layout()
plt.savefig("graph_machete_bench.pdf")
# Weight Shapes are in the format
# ([K, N], TP_SPLIT_DIM)
# Example:
# A shape of ([14336, 4096], 0) indicates the following GEMM shape,
# - TP1 : K = 14336, N = 4096
# - TP2 : K = 7168, N = 4096
# A shape of ([4096, 6144], 1) indicates the following GEMM shape,
# - TP1 : K = 4096, N = 6144
# - TP4 : K = 4096, N = 1536
# TP1 shapes
WEIGHT_SHAPES = {
"mistralai/Mistral-7B-v0.1": [
([4096, 6144], 1),
([4096, 4096], 0),
([4096, 28672], 1),
([14336, 4096], 0),
],
"meta-llama/Llama-2-7b-hf": [
([4096, 12288], 1),
([4096, 4096], 0),
([4096, 22016], 1),
([11008, 4096], 0),
],
"meta-llama/Llama-3-8b": [
([4096, 6144], 1),
([4096, 4096], 0),
([4096, 28672], 1),
([14336, 4096], 0),
],
"meta-llama/Llama-2-13b-hf": [
([5120, 15360], 1),
([5120, 5120], 0),
([5120, 27648], 1),
([13824, 5120], 0),
],
"meta-llama/Llama-2-70b-hf": [
([8192, 10240], 1),
([8192, 8192], 0),
([8192, 57344], 1),
([28672, 8192], 0),
],
}
#pragma once #pragma once
#if defined(__CUDACC__) || defined(_NVHPC_CUDA)
#define HOST_DEVICE_INLINE __forceinline__ __host__ __device__
#define DEVICE_INLINE __forceinline__ __device__
#define HOST_INLINE __forceinline__ __host__
#else
#define HOST_DEVICE_INLINE inline
#define DEVICE_INLINE inline
#define HOST_INLINE inline
#endif
int64_t get_device_attribute(int64_t attribute, int64_t device_id); int64_t get_device_attribute(int64_t attribute, int64_t device_id);
int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id); int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id);
#pragma once
#include <cute/tensor.hpp>
#include <torch/all.h>
namespace cute {
////////////////////////////////////////////////////////////////////
// layout utils
////////////////////////////////////////////////////////////////////
// Permute layout based on indices, example:
// permute_layout<1, 0>(layout) will swap the two dimensions
// permute_layout<0, 2, 1>(layout) will swap the last two dimensions
template <size_t... I, typename Layout>
CUTE_HOST_DEVICE static constexpr auto permute_layout(Layout l) {
static_assert(rank(l) == sizeof...(I), "Invalid permutation, rank mismatch");
return cute::make_layout(cute::get<I>(l)...);
}
// is the layout f(x) = x
template <typename Layout>
CUTE_HOST_DEVICE static constexpr bool is_identity_layout() {
if constexpr (std::is_same_v<Layout, void>)
return true;
else {
constexpr auto coalesced_layout = coalesce(Layout{});
if constexpr (rank(coalesced_layout) == 1 &&
stride<0>(coalesced_layout) == 1) {
return true;
}
return false;
}
}
////////////////////////////////////////////////////////////////////
// Pointer utils
////////////////////////////////////////////////////////////////////
template <class PointerType>
static constexpr auto get_logical_ptr(PointerType* ptr) {
if constexpr (cute::sizeof_bits_v<PointerType> < 8) {
return cute::subbyte_iterator<PointerType>(ptr);
} else {
return ptr;
}
}
////////////////////////////////////////////////////////////////////
// Misc utils
////////////////////////////////////////////////////////////////////
template <typename T, typename Elements>
CUTE_HOST_DEVICE static constexpr auto create_auto_vectorizing_copy() {
constexpr auto bits = sizeof_bits_v<T> * Elements{};
if constexpr (bits % 128 == 0) {
return AutoVectorizingCopyWithAssumedAlignment<128>{};
} else if constexpr (bits % 64 == 0) {
return AutoVectorizingCopyWithAssumedAlignment<64>{};
} else if constexpr (bits % 32 == 0) {
return AutoVectorizingCopyWithAssumedAlignment<32>{};
} else if constexpr (bits % 16 == 0) {
return AutoVectorizingCopyWithAssumedAlignment<16>{};
} else {
return AutoVectorizingCopyWithAssumedAlignment<8>{};
}
}
}; // namespace cute
#pragma once
#include <torch/all.h>
#include "cute/layout.hpp"
#include "cutlass/layout/matrix.h"
#include "cutlass/bfloat16.h"
#include "cutlass/half.h"
using ColumnMajor = typename cutlass::layout::ColumnMajor;
using RowMajor = typename cutlass::layout::RowMajor;
namespace cute {
namespace detail {
template <class T, class F, class G, int... I>
CUTE_HOST_DEVICE constexpr auto tapply_with_idx(T&& t, F&& f, G&& g,
seq<I...>) {
return g(f(cute::get<I>(static_cast<T&&>(t)), I)...);
}
template <class F, int... I>
CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f, seq<I...>) {
return make_shape(f(I)...);
}
}; // namespace detail
template <class T, class F>
CUTE_HOST_DEVICE constexpr auto transform_with_idx(T const& t, F&& f) {
if constexpr (cute::is_tuple<T>::value) {
return detail::tapply_with_idx(
t, f, [](auto const&... a) { return cute::make_tuple(a...); },
tuple_seq<T>{});
} else {
return f(t);
}
CUTE_GCC_UNREACHABLE;
}
// calls: make_shape(f(0), f(1), ..., f(N-1))
template <int N, class F>
CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f) {
return detail::make_shape_from_idx(f, make_seq<N>{});
}
}; // namespace cute
// Make a layout from a tensor with `rank(Stride{})`, where the shape is the
// shape of the passed in tensor and the strides are of type `Stride` and
// contain the strides of the passed in tensor, checking that any static strides
// in `Stride{}` match the strides of the passed in tensor.
// If `tensor.dim() < rank(Stride{})`, the shape is padded with 1s and the extra
// strides are set to be 0 or 1.
template <typename Stride>
static inline auto make_cute_layout(torch::Tensor const& tensor,
std::string_view name = "tensor") {
TORCH_CHECK(tensor.dim() <= rank(Stride{}));
auto stride = cute::transform_with_idx(
Stride{}, [&](auto const& stride_ele, auto const& idx) {
using StrideEle = std::decay_t<decltype(stride_ele)>;
if (idx < tensor.dim()) {
if constexpr (cute::is_static_v<StrideEle>) {
TORCH_CHECK(StrideEle::value == tensor.stride(idx), "Expected ",
name, ".stride(", idx, ") to be ", StrideEle::value);
return StrideEle{};
} else {
return tensor.stride(idx);
}
} else {
// Extra strides are assumed to be 0 or 1
if constexpr (cute::is_static_v<StrideEle>) {
static_assert(StrideEle::value == 0 || StrideEle::value == 1);
}
return StrideEle{};
}
});
auto shape = cute::make_shape_from_idx<rank(Stride{})>([&](auto const& idx) {
if (idx < tensor.dim())
return tensor.size(idx);
else
return int64_t(1);
});
return make_layout(shape, stride);
}
template <typename Stride>
static inline auto maybe_make_cute_layout(
c10::optional<torch::Tensor> const& tensor,
std::string_view name = "tensor") {
using Layout = decltype(make_cute_layout<Stride>(*tensor));
if (tensor) {
return std::optional<Layout>{make_cute_layout<Stride>(*tensor, name)};
} else {
return std::optional<Layout>{};
}
}
//
// Torch Type to Cutlass Type (equivalent_cutlass_type)
//
template <typename T>
struct equivalent_cutlass_type {
using type = T;
};
template <typename T>
using equivalent_cutlass_type_t = typename equivalent_cutlass_type<T>::type;
template <>
struct equivalent_cutlass_type<c10::Half> {
using type = cutlass::half_t;
};
template <>
struct equivalent_cutlass_type<c10::BFloat16> {
using type = cutlass::bfloat16_t;
};
//
// equivalent_scalar_t (basically inverse of equivalent_cutlass_type)
//
// Return a `c10::CppTypeToScalarType<T>` compatible type, i.e. get the C++ from
// c10 that is equivalent to T, e.g.: `cutlass::half_t -> c10::Half`
template <typename T>
struct equivalent_scalar_type {
using type = T;
};
template <typename T>
using equivalent_scalar_type_t = typename equivalent_scalar_type<T>::type;
template <>
struct equivalent_scalar_type<cutlass::half_t> {
using type = c10::Half;
};
template <>
struct equivalent_scalar_type<cutlass::bfloat16_t> {
using type = c10::BFloat16;
};
// get equivalent c10::ScalarType tag from compile time type
template <typename T>
static inline constexpr c10::ScalarType equivalent_scalar_type_v =
c10::CppTypeToScalarType<equivalent_scalar_type_t<T>>::value;
\ No newline at end of file
#pragma once
#include "cutlass/gemm/collective/collective_builder.hpp"
namespace cutlass::gemm::collective {
using namespace cute;
//
// VLLMCollectiveBuilder is a wrapper around CollectiveBuilder that allows for
// for custom kernel tags, allowing you to build custom collectives. Without
// touching the cutlass library headers, using `CutlassKernelTag` will mean it
// will resort to using the standard cutlass collective builder.
//
// Use the default Cutlass collective builder, i.e. use an unmodified cutless
// collective
struct CutlassKernelTag {};
template <class KernelTag, class ArchTag, class OpClass, class ElementA,
class GmemLayoutA, int AlignmentA, class ElementB, class GmemLayoutB,
int AlignmentB, class ElementAccumulator, class TileShape_MNK,
class ClusterShape_MNK, class StageCountType,
class KernelScheduleType, class Enable = void>
struct VLLMCollectiveBuilder {
static_assert(sizeof(ElementA) == 0,
"Could not build a collective for given parameters.");
};
template <class ArchTag, class OpClass, class ElementA, class GmemLayoutA,
int AlignmentA, class ElementB, class GmemLayoutB, int AlignmentB,
class ElementAccumulator, class TileShape_MNK, class ClusterShape_MNK,
class StageCountType, class KernelScheduleType>
struct VLLMCollectiveBuilder<
CutlassKernelTag, ArchTag, OpClass, ElementA, GmemLayoutA, AlignmentA,
ElementB, GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK,
ClusterShape_MNK, StageCountType, KernelScheduleType> {
using CollectiveOp = typename CollectiveBuilder<
ArchTag, OpClass, ElementA, GmemLayoutA, AlignmentA, ElementB,
GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK,
ClusterShape_MNK, StageCountType, KernelScheduleType>::CollectiveOp;
};
}; // namespace cutlass::gemm::collective
\ No newline at end of file
#pragma once
#include "cutlass/integer_subbyte.h"
namespace cutlass {
///////////////////////////////////////////////////////////////////////////////////////////////////
template <int Bits, int Bias, bool Signed = false>
struct vllm_biased_integer_subbyte : public integer_subbyte<Bits, Signed> {
using Base = integer_subbyte<Bits, Signed>;
using Storage = typename Base::Storage;
using xint_t = typename Base::xint_t;
using Base::bits_mask_;
using Base::sign_mask_;
using Base::storage;
//
// Methods
//
/// No operation
vllm_biased_integer_subbyte() = default;
/// Conversion from integer type
CUTLASS_HOST_DEVICE explicit vllm_biased_integer_subbyte(int value)
: Base(value) {}
CUTLASS_HOST_DEVICE explicit vllm_biased_integer_subbyte(unsigned value)
: Base(value) {}
CUTLASS_HOST_DEVICE explicit vllm_biased_integer_subbyte(double value)
: Base(value) {}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
// "GPTQ" types, i.e. symmetric quantization
using vllm_uint4b8_t = vllm_biased_integer_subbyte<4, 8>; // u4b8
using vllm_uint8b128_t = vllm_biased_integer_subbyte<8, 128>; // u8b128
///////////////////////////////////////////////////////////////////////////////////////////////////
template <int Bits, int Bias, bool Signed>
struct sizeof_bits<vllm_biased_integer_subbyte<Bits, Bias, Signed>> {
static constexpr int value = Bits;
};
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass
import enum
from typing import Dict, Union
from cutlass_library import *
#
# Extend cutlass library with custom types, and missing values
#
class VLLMDataType(enum.Enum):
u4b8 = enum_auto()
u8b128 = enum_auto()
class MixedInputKernelScheduleType(enum.Enum):
TmaWarpSpecializedMixedInput = enum_auto()
TmaWarpSpecializedPingpongMixedInput = enum_auto()
TmaWarpSpecializedCooperativeMixedInput = enum_auto()
VLLMDataTypeNames: Dict[Union[VLLMDataType, DataType], str] = {
**DataTypeNames, # type: ignore
**{
VLLMDataType.u4b8: "u4b8",
VLLMDataType.u8b128: "u8b128",
}
}
VLLMDataTypeTag: Dict[Union[VLLMDataType, DataType], str] = {
**DataTypeTag, # type: ignore
**{
VLLMDataType.u4b8: "cutlass::vllm_uint4b8_t",
VLLMDataType.u8b128: "cutlass::vllm_uint8b128_t",
}
}
VLLMKernelScheduleTag: Dict[Union[
MixedInputKernelScheduleType, KernelScheduleType], str] = {
**KernelScheduleTag, # type: ignore
**{
MixedInputKernelScheduleType.TmaWarpSpecializedMixedInput:
"cutlass::gemm::KernelTmaWarpSpecializedMixedInput",
MixedInputKernelScheduleType.TmaWarpSpecializedPingpongMixedInput:
"cutlass::gemm::KernelTmaWarpSpecializedPingpongMixedInput",
MixedInputKernelScheduleType.TmaWarpSpecializedCooperativeMixedInput:
"cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput",
}
}
This diff is collapsed.
...@@ -83,6 +83,25 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, ...@@ -83,6 +83,25 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_scales, torch::Tensor& workspace, torch::Tensor& b_scales, torch::Tensor& workspace,
int64_t size_m, int64_t size_n, int64_t size_k); int64_t size_m, int64_t size_n, int64_t size_k);
namespace machete {
std::vector<std::string> supported_schedules(
vllm::ScalarTypeTorchPtr const& btype);
torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
vllm::ScalarTypeTorchPtr const& btype,
c10::optional<torch::Tensor> const& scales,
c10::optional<torch::Tensor> const& zeros,
c10::optional<int64_t> group_size,
c10::optional<torch::Tensor> const& C,
c10::optional<double> alpha, c10::optional<double> beta,
c10::optional<std::string> schedule);
torch::Tensor prepack_B(torch::Tensor const& B,
vllm::ScalarTypeTorchPtr const& btype);
}; // namespace machete
torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_meta, torch::Tensor& b_meta,
torch::Tensor& b_scales, torch::Tensor& b_scales,
......
# Machete (Mixed Precision Cutlass-Based GEMM)
Machete is a spiritual successor to the Marlin kernel but optimized for Hopper architectures and based on Cutlass. Being based on Cutlass, new type pairs and epilogues are easier to add compared to Marlin.
## Overview
Machete effectively performs
```
scale_type = w_s.dtype
compute_type = a.dtype
out = (w_q.to(scale_type) * w_s - w_z.to(scale_type)) @ a
```
Where `w_q` is a quantized weight matrix, `w_s` is the quantization scales, and
`w_z` is the quantization zeropoints.
> **_NOTE:_** `w_z` is added after the scales so we can
use FMA operations, but this means they must have the scales pre-applied if the
supplied zeropoints assume that they will be subtracted before the scales are
applied.
## API
The main optimization within Machete is prepacking the weight matrix to more closely match the tensor core layouts, allowing for wider shared memory loads when loading the weight matrix. This means that the weight matrix must be prepacked before calling `machete_gemm`. The flow looks something like:
```
from vllm import _custom_ops as ops
...
W_q_packed = ops.machete_prepack_B(w_q, wtype)
output = ops.machete_gemm(
a,
b_q=W_q_packed,
b_type=wtype,
b_scales=w_s,
b_group_size=group_size
)
```
## Code Generation
Since Machete is based on Cutlass, we can generate multiple type pairs and different tile shapes using the same kernel template. We generate multiple instantiations of this template using `generate.py`.
New type pairs (`TypeConfig`s) can be appended to `impl_configs` (in `generate()`), and these will get automatically generated (assuming they can be supported without issues). For each `TypeConfig`, you must also provide an `ImplConfig`, which bundles a `TypeConfig` with a list of `ScheduleConfig`s, `Specialization`s, and a default heuristic. The `ScheduleConfig`s (which contain info on tile shapes, tile scheduler, etc.) can perform differently for different problem shapes, and there is almost never one `ScheduleConfig` that works well for all problem shapes, so it is generally beneficial to generate different `ScheduleConfig`s for different potential problem shapes. This is where the heuristic comes in. For each `TypeConfig`, a default heuristic should be provided. This maps different problem shapes to different `ScheduleConfig`s and is used when the user does not provide the `schedule` parameter to `machete_gemm`. The `Specialization`s define what feature combinations to generate, i.e., `with_zeropoints`, `with_scales`, etc. We can reduce compile times and the final binary size by limiting the set of feature combinations we generate.
\ No newline at end of file
import itertools
import math
import os
import shutil
from collections.abc import Iterable
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import jinja2
# yapf conflicts with isort for this block
# yapf: disable
from vllm_cutlass_library_extension import (DataType, EpilogueScheduleTag,
EpilogueScheduleType,
MixedInputKernelScheduleType,
TileSchedulerTag,
TileSchedulerType, VLLMDataType,
VLLMDataTypeNames, VLLMDataTypeTag,
VLLMKernelScheduleTag)
# yapf: enable
#
# Generator templating
#
DISPATCH_TEMPLATE = """
#include "../machete_mm_launcher.cuh"
namespace machete {
using GemmDispatcher_ = GemmDispatcher<
{{DataTypeTag[type_config.element_a]}}, // ElementA
{{DataTypeTag[type_config.element_b]}}, // ElementB
{{DataTypeTag[type_config.element_d]}}, // ElementD
{{DataTypeTag[type_config.accumulator]}}, // Accumulator
{{DataTypeTag[type_config.element_b_scale]}}, // Scales
{{DataTypeTag[type_config.element_b_zeropoint]}}>; // Zeropoints
{% for s in schedules %}extern torch::Tensor
impl_{{type_name}}_sch_{{ gen_sch_name(s) }}(PyTorchArguments args);
{% endfor %}
template <>
torch::Tensor GemmDispatcher_::dispatch(PyTorchArguments args) {
[[maybe_unused]] auto M = args.A.size(0);
[[maybe_unused]] auto N = args.B.size(1);
[[maybe_unused]] auto K = args.A.size(1);
if (!args.schedule) {
{%- for cond, s in heuristic %}
{%if cond is not none%}if ({{cond}})
{%- else %}else
{%- endif %}
return impl_{{ type_name }}_sch_{{ gen_sch_name(s) }}(args);{% endfor %}
}
{% for s in schedules %}
if (*args.schedule == "{{ gen_sch_name(s) }}") {
return impl_{{ type_name }}_sch_{{ gen_sch_name(s) }}(args);
}
{% endfor %}
TORCH_CHECK_NOT_IMPLEMENTED(false, "machete_gemm(..) is not implemented for "
"schedule = ", *args.schedule);
}
template <>
std::vector<std::string> GemmDispatcher_::supported_schedules() {
return {
{% for s in schedules -%}
"{{ gen_sch_name(s) }}"{{ ",
" if not loop.last }}{%- endfor %}
};
}
}; // namespace machete
"""
IMPL_TEMPLATE = """
#include "../machete_mm_launcher.cuh"
namespace machete {
template <typename Config, bool with_C, bool with_scales, bool with_zeropoints>
using Kernel = MacheteKernelTemplate<
{{DataTypeTag[type_config.element_a]}}, // ElementA
{{DataTypeTag[type_config.element_b]}}, // ElementB
{{DataTypeTag[type_config.element_d]}}, // ElementD
{{DataTypeTag[type_config.accumulator]}}, // Accumulator
{{DataTypeTag[type_config.element_b_scale]}}, // Scales
{{DataTypeTag[type_config.element_b_zeropoint]}}, // Zeropoints
cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput,
Config, with_C, with_scales, with_zeropoints>;
{% for sch in schedules %}
{% set schedule_name = gen_sch_name(sch) -%}
struct sch_{{schedule_name}} {
using TileShapeNM = Shape<{{
to_cute_constant(sch.tile_shape_mn)|join(', ')}}>;
using ClusterShape = Shape<{{
to_cute_constant(sch.cluster_shape_mnk)|join(', ')}}>;
// TODO: Reimplement
// using KernelSchedule = {{KernelScheduleTag[sch.kernel_schedule]}};
using EpilogueSchedule = {{EpilogueScheduleTag[sch.epilogue_schedule]}};
using TileScheduler = {{TileSchedulerTag[sch.tile_scheduler]}};
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
};
torch::Tensor
impl_{{type_name}}_sch_{{schedule_name}}(PyTorchArguments args) {
bool with_C = args.C.has_value(), with_scales = args.scales.has_value(),
with_zeropoints = args.zeros.has_value();
{% for s in specializations %}
if (with_C == {{s.with_C|lower}}
&& with_zeropoints == {{s.with_zeropoints|lower}}
&& with_scales == {{s.with_scales|lower}}) {
return run_impl<Kernel<sch_{{schedule_name}}, {{s.with_C|lower}},
{{s.with_scales|lower}}, {{s.with_zeropoints|lower}}>>(args);
}{% endfor %}
TORCH_CHECK_NOT_IMPLEMENTED(
false, "for the sake of compile times and binary size machete_mm(..) is "
" not implemented for with_C=", with_C, ", with_scales=", with_scales,
", with_zeropoints=", with_zeropoints,
" (for {{type_name}}_sch_{{schedule_name}})");
}
{% endfor %}
}; // namespace machete
"""
PREPACK_TEMPLATE = """
#include "../machete_prepack_launcher.cuh"
namespace machete {
using PrepackBDispatcher_ = PrepackBDispatcher<
{{DataTypeTag[type_config.element_a]}}, // ElementA
{{DataTypeTag[type_config.element_b]}}, // ElementB
{{DataTypeTag[type_config.element_d]}}, // ElementD
{{DataTypeTag[type_config.accumulator]}}, // Accumulator
{{DataTypeTag[type_config.element_b_scale]}}, // Scales
{{DataTypeTag[type_config.element_b_zeropoint]}}>; // Zeropoints
using PrepackedLayoutB = PrepackedLayoutBTemplate<
{{DataTypeTag[type_config.element_a]}}, // ElementA
{{DataTypeTag[type_config.element_b]}}, // ElementB
{{DataTypeTag[type_config.element_d]}}, // ElementD
{{DataTypeTag[type_config.accumulator]}}, // Accumulator
cutlass::layout::ColumnMajor,
cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput>;
template <>
torch::Tensor PrepackBDispatcher_::dispatch(torch::Tensor B) {
return prepack_impl<PrepackedLayoutB>(B);
}
}; // namespace machete
"""
TmaMI = MixedInputKernelScheduleType.TmaWarpSpecializedCooperativeMixedInput
TmaCoop = EpilogueScheduleType.TmaWarpSpecializedCooperative
@dataclass
class ScheduleConfig:
tile_shape_mn: Tuple[int, int]
cluster_shape_mnk: Tuple[int, int, int]
kernel_schedule: MixedInputKernelScheduleType
epilogue_schedule: EpilogueScheduleType
tile_scheduler: TileSchedulerType
@dataclass
class TypeConfig:
element_a: DataType
element_b: Union[DataType, VLLMDataType]
element_b_scale: DataType
element_b_zeropoint: DataType
element_d: DataType
accumulator: DataType
@dataclass
class Specialization:
with_C: bool
with_zeropoints: bool
with_scales: bool
@dataclass
class ImplConfig:
type_config: TypeConfig
schedule_configs: List[ScheduleConfig]
specializations: List[Specialization]
heuristic: List[Tuple[Optional[str], ScheduleConfig]]
def generate_schedule_name(schedule_config: ScheduleConfig) -> str:
tile_shape = (
f"{schedule_config.tile_shape_mn[0]}x{schedule_config.tile_shape_mn[1]}"
)
cluster_shape = (f"{schedule_config.cluster_shape_mnk[0]}" +
f"x{schedule_config.cluster_shape_mnk[1]}" +
f"x{schedule_config.cluster_shape_mnk[2]}")
kernel_schedule = VLLMKernelScheduleTag[schedule_config.kernel_schedule]\
.split("::")[-1]
epilogue_schedule = EpilogueScheduleTag[
schedule_config.epilogue_schedule].split("::")[-1]
tile_scheduler = TileSchedulerTag[schedule_config.tile_scheduler]\
.split("::")[-1]
return (f"{tile_shape}_{cluster_shape}_{kernel_schedule}" +
f"_{epilogue_schedule}_{tile_scheduler}")
# mostly unique shorter schedule_name
def generate_terse_schedule_name(schedule_config: ScheduleConfig) -> str:
kernel_terse_names_replace = {
"KernelTmaWarpSpecializedCooperativeMixedInput_": "TmaMI_",
"TmaWarpSpecializedCooperative_": "TmaCoop_",
"StreamKScheduler": "streamK",
}
schedule_name = generate_schedule_name(schedule_config)
for orig, terse in kernel_terse_names_replace.items():
schedule_name = schedule_name.replace(orig, terse)
return schedule_name
# unique type_name
def generate_type_signature(kernel_type_config: TypeConfig):
element_a = VLLMDataTypeNames[kernel_type_config.element_a]
element_b = VLLMDataTypeNames[kernel_type_config.element_b]
element_d = VLLMDataTypeNames[kernel_type_config.element_d]
accumulator = VLLMDataTypeNames[kernel_type_config.accumulator]
element_scale = VLLMDataTypeNames[kernel_type_config.element_b_scale]
element_zeropoint = VLLMDataTypeNames[
kernel_type_config.element_b_zeropoint]
return (f"{element_a}{element_b}{element_d}"
f"{accumulator}{element_scale}{element_zeropoint}")
# non-unique shorter type_name
def generate_terse_type_signature(kernel_type_config: TypeConfig):
element_a = VLLMDataTypeNames[kernel_type_config.element_a]
element_b = VLLMDataTypeNames[kernel_type_config.element_b]
return f"{element_a}{element_b}"
def is_power_of_two(n):
return (n != 0) and (n & (n - 1) == 0)
def to_cute_constant(value: List[int]):
def _to_cute_constant(value: int):
if is_power_of_two(value):
return f"_{value}"
else:
return f"Int<{value}>"
if isinstance(value, Iterable):
return [_to_cute_constant(value) for value in value]
else:
return _to_cute_constant(value)
template_globals = {
"DataTypeTag": VLLMDataTypeTag,
"KernelScheduleTag": VLLMKernelScheduleTag,
"EpilogueScheduleTag": EpilogueScheduleTag,
"TileSchedulerTag": TileSchedulerTag,
"to_cute_constant": to_cute_constant,
"gen_sch_name": generate_terse_schedule_name,
}
def create_template(template_str):
template = jinja2.Template(template_str)
template.globals.update(template_globals)
return template
mm_dispatch_template = create_template(DISPATCH_TEMPLATE)
mm_impl_template = create_template(IMPL_TEMPLATE)
prepack_dispatch_template = create_template(PREPACK_TEMPLATE)
def create_sources(impl_config: ImplConfig, num_impl_files=2):
sources = []
type_name = generate_type_signature(impl_config.type_config)
terse_type_name = generate_terse_type_signature(impl_config.type_config)
sources.append((
f"machete_mm_{terse_type_name}",
mm_dispatch_template.render(type_name=type_name,
type_config=impl_config.type_config,
schedules=impl_config.schedule_configs,
heuristic=impl_config.heuristic),
))
sources.append((
f"machete_prepack_{terse_type_name}",
prepack_dispatch_template.render(
type_name=type_name,
type_config=impl_config.type_config,
),
))
num_schedules = len(impl_config.schedule_configs)
schedules_per_file = math.ceil(num_schedules / num_impl_files)
for part, i in enumerate(range(0, num_schedules, schedules_per_file)):
file_schedules = impl_config.schedule_configs[i:i + schedules_per_file]
sources.append((
f"machete_mm_{terse_type_name}_impl_part{part}",
mm_impl_template.render(
type_name=type_name,
type_config=impl_config.type_config,
schedules=file_schedules,
specializations=impl_config.specializations,
),
))
return sources
def generate():
# See csrc/quantization/machete/Readme.md, the Codegeneration for more info
# about how this works
SCRIPT_DIR = os.path.dirname(__file__)
schedules = [
ScheduleConfig(
tile_shape_mn=tile_shape_mn,
cluster_shape_mnk=cluster_shape_mnk,
kernel_schedule=kernel_schedule,
epilogue_schedule=epilogue_schedule,
tile_scheduler=tile_scheduler,
) for tile_shape_mn, cluster_shape_mnk in (
((128, 16), (1, 1, 1)),
((128, 32), (1, 1, 1)),
((128, 64), (1, 1, 1)),
((128, 128), (1, 1, 1)),
) for kernel_schedule in (TmaMI, ) for epilogue_schedule in (TmaCoop, )
for tile_scheduler in (TileSchedulerType.StreamK, )
]
# For now we use the same heuristic for all types
default_heuristic = [
("M > 64",
ScheduleConfig(
tile_shape_mn=(128, 128),
cluster_shape_mnk=(1, 1, 1),
kernel_schedule=TmaMI,
epilogue_schedule=TmaCoop,
tile_scheduler=TileSchedulerType.StreamK,
)),
("M > 32",
ScheduleConfig(
tile_shape_mn=(128, 64),
cluster_shape_mnk=(1, 1, 1),
kernel_schedule=TmaMI,
epilogue_schedule=TmaCoop,
tile_scheduler=TileSchedulerType.StreamK,
)),
("M > 16",
ScheduleConfig(
tile_shape_mn=(128, 32),
cluster_shape_mnk=(1, 1, 1),
kernel_schedule=TmaMI,
epilogue_schedule=TmaCoop,
tile_scheduler=TileSchedulerType.StreamK,
)),
(None,
ScheduleConfig(tile_shape_mn=(128, 16),
cluster_shape_mnk=(1, 1, 1),
kernel_schedule=TmaMI,
epilogue_schedule=TmaCoop,
tile_scheduler=TileSchedulerType.StreamK))
]
impl_configs = []
GPTQ_kernel_type_configs = list(
(TypeConfig(
element_a=element_a,
element_b=element_b,
element_b_scale=element_a,
element_b_zeropoint=element_a,
element_d=element_a,
accumulator=DataType.f32,
) for element_b in (VLLMDataType.u4b8, VLLMDataType.u8b128)
for element_a in (DataType.f16, DataType.bf16)))
GPTQ_kernel_specializations = [
Specialization(with_C=False, with_zeropoints=False, with_scales=True)
]
impl_configs += [
ImplConfig(x[0], x[1], x[2], x[3])
for x in zip(GPTQ_kernel_type_configs, itertools.repeat(schedules),
itertools.repeat(GPTQ_kernel_specializations),
itertools.repeat(default_heuristic))
]
AWQ_kernel_type_configs = list(
(TypeConfig(
element_a=element_a,
element_b=element_b,
element_b_scale=element_a,
element_b_zeropoint=element_a,
element_d=element_a,
accumulator=DataType.f32,
) for element_b in (DataType.u4, DataType.u8)
for element_a in (DataType.f16, DataType.bf16)))
AWQ_kernel_specializations = [
Specialization(with_C=False, with_zeropoints=True, with_scales=True)
]
impl_configs += [
ImplConfig(x[0], x[1], x[2], x[3])
for x in zip(AWQ_kernel_type_configs, itertools.repeat(schedules),
itertools.repeat(AWQ_kernel_specializations),
itertools.repeat(default_heuristic))
]
output_dir = os.path.join(SCRIPT_DIR, "generated")
# Delete the "generated" directory if it exists
if os.path.exists(output_dir):
shutil.rmtree(output_dir)
# Create the "generated" directory
os.makedirs(output_dir)
# Render each group of configurations into separate files
for impl_config in impl_configs:
for filename, code in create_sources(impl_config):
filepath = os.path.join(output_dir, f"{filename}.cu")
with open(filepath, "w") as output_file:
output_file.write(code)
print(f"Rendered template to {filepath}")
if __name__ == "__main__":
generate()
#pragma once
#include "cutlass_extensions/vllm_collective_builder.cuh"
#include "machete_mainloop.cuh"
namespace cutlass::gemm::collective {
using namespace cute;
struct MacheteKernelTag {};
template <class ElementPairA_, class GmemLayoutA_, int AlignmentA,
class ElementPairB_, class GmemLayoutB_, int AlignmentB,
class ElementAccumulator, class TileShape_MNK, class ClusterShape_MNK,
class StageCountType, class KernelScheduleType>
struct VLLMCollectiveBuilder<
MacheteKernelTag, arch::Sm90, arch::OpClassTensorOp, ElementPairA_,
GmemLayoutA_, AlignmentA, ElementPairB_, GmemLayoutB_, AlignmentB,
ElementAccumulator, TileShape_MNK, ClusterShape_MNK, StageCountType,
KernelScheduleType,
cute::enable_if_t<(
cute::is_same_v<KernelScheduleType,
KernelTmaWarpSpecializedMixedInput> ||
cute::is_same_v<KernelScheduleType,
KernelTmaWarpSpecializedPingpongMixedInput> ||
cute::is_same_v<KernelScheduleType,
KernelTmaWarpSpecializedCooperativeMixedInput>)>> {
using CollectiveOp = machete::MacheteCollectiveMma<
ElementPairA_, GmemLayoutA_, AlignmentA, ElementPairB_, GmemLayoutB_,
AlignmentB, ElementAccumulator, TileShape_MNK, ClusterShape_MNK,
StageCountType, KernelScheduleType>;
};
}; // namespace cutlass::gemm::collective
\ No newline at end of file
#pragma once
#include "cutlass/cutlass.h"
#include "cute/layout.hpp"
namespace machete {
using namespace cute;
// get an interleaved block layout where each element consecutive element has a
// stride of bit_stride and the block width is blk_bit_width,
// examples:
// size_bits<T> = 8, bit_stride = 8, blk_bit_width = 32 -> 4:1
// size_bits<T> = 8, bit_stride = 16, blk_bit_width = 32 -> (2, 2):(2, 1)
// size_bits<T> = 4, bit_stride = 8, blk_bit_width = 32 -> (4, 2):(2, 1)
// size_bits<T> = 4, bit_stride = 16, blk_bit_width = 32 -> (2, 4):(4, 1)
template <typename T, int bit_stride, int blk_bit_width>
CUTE_HOST_DEVICE static constexpr auto get_interleaved_blk_layout() {
static_assert(blk_bit_width % bit_stride == 0);
static_assert(bit_stride % cute::sizeof_bits_v<T> == 0);
constexpr auto elems_per_blk = blk_bit_width / cute::sizeof_bits_v<T>;
if constexpr (cute::sizeof_bits_v<T> == bit_stride) {
// identity layout
return Layout<Shape<Int<elems_per_blk>>>{};
} else {
constexpr auto elems_per_stride = bit_stride / cute::sizeof_bits_v<T>;
constexpr auto num_strides = elems_per_blk / elems_per_stride;
return Layout<Shape<Int<num_strides>, Int<elems_per_stride>>,
Stride<Int<elems_per_stride>, Int<1>>>{};
}
}
}; // namespace machete
This diff is collapsed.
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
// clang-format off
// The cutlass include order matters (annoyingly)
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
// clang-format on
#include "cutlass_extensions/cute_utils.cuh"
#include "cutlass_extensions/vllm_numeric_conversion.cuh"
#include "machete_collective_builder.cuh"
#include "machete_prepacked_layout.cuh"
#include "machete_interleaving_utils.cuh"
namespace machete {
using namespace cute;
// NOTE This kernel computes D = alpha * A * B + beta * C by computing
// D^t = alpha * B^t * A^t + beta * C^t, this is because the wgmma
// instructions only support sourcing from registers for the left-hand
// operand, we want to upconvert/decompress the quantized operand in
// register. Since the primary use case we want to support is Y = XW^t where
// W is quantized, in this situation or right-hand operand is quantized so
// we compute the transpose to move it to the left-hand side.
template <typename ElementA_, typename ElementB_, typename ElementD_,
typename AccumulatorT, typename ScaleT, typename ZeroT,
class KernelSchedule, typename ScheduleConfig, bool with_C,
bool with_scales, bool with_zeropoints>
struct MacheteKernelTemplate {
using MmaType = ElementA_;
using ElementA = ElementA_;
using ElementB = ElementB_;
using ElementD = ElementD_;
using ElementC = cute::conditional_t<with_C, ElementD, void>;
using ElementZ = ZeroT;
using ElementS = ScaleT;
using ElementAccumulator =
AccumulatorT; // Element type for internal accumulation
using ElementCompute = AccumulatorT; // For Epilogue
using BTypeTuple = cute::conditional_t<
with_scales,
cute::conditional_t<with_zeropoints,
cute::tuple<ElementB, ElementS, ElementZ>,
cute::tuple<ElementB, ElementS>>,
ElementB>;
using LayoutA = cutlass::layout::RowMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = LayoutC;
using LayoutScale = cutlass::layout::RowMajor;
// not actually used since B has the prepacked layout, but required by cutlass
using _LayoutB = cutlass::layout::ColumnMajor;
// Interface strides expected by create_arguments (will get transposed)
using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;
using StrideC = cutlass::detail::TagToStrideA_t<LayoutC>;
using StrideD = cutlass::detail::TagToStrideA_t<LayoutD>;
using StrideS = cutlass::detail::TagToStrideA_t<LayoutScale>;
using StrideZ = StrideS;
using LayoutA_Transpose =
typename cutlass::layout::LayoutTranspose<LayoutA>::type;
using LayoutC_Transpose =
typename cutlass::layout::LayoutTranspose<LayoutC>::type;
using LayoutD_Transpose =
typename cutlass::layout::LayoutTranspose<LayoutD>::type;
using ArchTag = cutlass::arch::Sm90;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using PrepackedLayoutB =
PrepackedLayoutBTemplate<ElementA_, ElementB_, ElementD_, AccumulatorT,
LayoutA_Transpose, KernelSchedule>;
static int constexpr TileShapeK =
128 * 8 / cutlass::sizeof_bits<MmaType>::value;
static int constexpr AlignmentA = 128 / cutlass::sizeof_bits_v<ElementA>;
static int constexpr AlignmentB = 128 / cutlass::sizeof_bits_v<ElementB>;
static int constexpr AlignmentC =
(with_C) ? 128 / cutlass::sizeof_bits_v<ElementC> : 0;
static int constexpr AlignmentD = 128 / cutlass::sizeof_bits_v<ElementD>;
using TileShape = decltype(append(typename ScheduleConfig::TileShapeNM{},
cute::Int<TileShapeK>{}));
using ClusterShape = typename ScheduleConfig::ClusterShape;
using EpilogueSchedule = typename ScheduleConfig::EpilogueSchedule;
using EpilogueTileType = typename ScheduleConfig::EpilogueTileType;
using TileScheduler = typename ScheduleConfig::TileScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType,
ElementAccumulator, ElementAccumulator, ElementC, LayoutC_Transpose,
AlignmentC, ElementD, LayoutD_Transpose, AlignmentD,
EpilogueSchedule>::CollectiveOp;
using CollectiveMainloop =
typename cutlass::gemm::collective::VLLMCollectiveBuilder<
cutlass::gemm::collective::MacheteKernelTag, ArchTag, OperatorClass,
BTypeTuple, PrepackedLayoutB, AlignmentB, ElementA, LayoutA_Transpose,
AlignmentA, ElementAccumulator, TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>, // Indicates ProblemShape
CollectiveMainloop, CollectiveEpilogue, TileScheduler>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
// stride_B is unused (since B is prepacked), but still required by cutlass
using _StrideB = cutlass::detail::TagToStrideB_t<_LayoutB>;
using Arguments = typename Gemm::Arguments;
using MainloopArguments = typename GemmKernel::MainloopArguments;
using EpilogueArguments = typename GemmKernel::EpilogueArguments;
template <typename ShapeA, typename ShapeC, typename ShapeD, typename ShapeS,
typename ShapeZ>
static Arguments create_arguments(
cudaStream_t stream,
ElementA const* A_ptr, // A is an MxK matrix
Layout<ShapeA, StrideA> const& layout_A,
ElementB const* B_ptr, // B is an KxN prepacked matrix
ElementD* D_ptr, // D is an MxN matrix
Layout<ShapeD, StrideD> const& layout_D,
ElementC const* C_ptr, // C is an MxN matrix
std::optional<Layout<ShapeC, StrideC>> const& layout_C,
ElementS const* S_ptr, // S is an scale_KxN matrix
std::optional<Layout<ShapeS, StrideS>> const& layout_S,
ElementZ const* Z_ptr, // Z is an scale_KxN matrix
std::optional<Layout<ShapeZ, StrideZ>> const& layout_Z,
ElementCompute alpha, ElementCompute beta,
std::optional<int> maybe_group_size) {
static_assert(!with_zeropoints || with_scales);
int M = size<0>(layout_A), N = size<1>(layout_D), K = size<1>(layout_A);
int const group_size = maybe_group_size.value_or(K);
int const scale_k = (K + group_size - 1) / group_size;
TORCH_CHECK(size<0>(layout_A) == M && size<1>(layout_A) == K);
TORCH_CHECK(size<0>(layout_D) == M && size<1>(layout_D) == N);
if constexpr (with_C) {
TORCH_CHECK(C_ptr && layout_C);
} else {
TORCH_CHECK(!C_ptr, "C not supported");
}
if constexpr (with_scales) {
TORCH_CHECK(S_ptr && layout_S);
TORCH_CHECK((size<0>(*layout_S) == scale_k && size<1>(*layout_S) == N));
} else {
TORCH_CHECK(!S_ptr, "Scales not supported");
}
if constexpr (with_zeropoints) {
TORCH_CHECK(Z_ptr && layout_Z);
TORCH_CHECK((size<0>(*layout_Z) == scale_k && size<1>(*layout_Z) == N));
TORCH_CHECK(layout_S && *layout_Z == *layout_S,
"Scales and zeros must have the same layout");
} else {
TORCH_CHECK(!Z_ptr, "Zeropoints not supported");
}
// Transpose A and D
// A doesn't need to be transposed since cutlass expects a NxK matrix
// for B (which is At)
auto stride_At = layout_A.stride();
auto stride_Dt = permute_layout<1, 0, 2>(layout_D).stride();
auto stride_Ct = stride_Dt;
if (layout_C) {
stride_Ct = permute_layout<1, 0, 2>(*layout_C).stride();
}
MainloopArguments mainloop_arguments{};
EpilogueArguments epilogue_arguments{
{alpha, beta}, C_ptr, stride_Ct, D_ptr, stride_Dt};
if constexpr (with_scales && with_zeropoints) {
auto stride_S = permute_layout<1, 0, 2>(*layout_S).stride();
mainloop_arguments =
MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At,
S_ptr, stride_S, group_size, Z_ptr};
} else if constexpr (with_scales) {
auto stride_S = permute_layout<1, 0, 2>(*layout_S).stride();
mainloop_arguments = MainloopArguments{
B_ptr, _StrideB{}, A_ptr, stride_At, S_ptr, stride_S, group_size};
} else {
mainloop_arguments =
MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At};
}
return Arguments{cutlass::gemm::GemmUniversalMode::kGemm,
{N, M, K, 1},
mainloop_arguments,
epilogue_arguments};
};
static size_t get_workspace_size(Arguments const& args) {
return Gemm::get_workspace_size(args);
}
static bool can_implement(Arguments const& args) {
return Gemm::can_implement(args) == cutlass::Status::kSuccess;
}
static void run(Arguments const& args, void* workspace, cudaStream_t stream) {
Gemm gemm_op;
cutlass::Status status = gemm_op.initialize(args, workspace, stream);
TORCH_CHECK(status == cutlass::Status::kSuccess,
"Machete kernel failed to initialize workspace");
status = gemm_op.run(stream);
TORCH_CHECK(status == cutlass::Status::kSuccess, "Machete kernel failed");
}
};
}; // namespace machete
#pragma once
#include <torch/all.h>
#include <Python.h>
#include "machete_mm_kernel.cuh"
#include "cutlass_extensions/torch_utils.hpp"
namespace machete {
struct PyTorchArguments {
torch::Tensor const& A;
torch::Tensor const& B;
c10::optional<torch::Tensor> const& scales;
c10::optional<torch::Tensor> const& zeros;
c10::optional<int64_t> group_size;
c10::optional<torch::Tensor> const& C;
c10::optional<double> alpha;
c10::optional<double> beta;
c10::optional<std::string> schedule;
};
template <typename MacheteKernel>
torch::Tensor run_impl(PyTorchArguments args) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(args.A));
auto device = args.A.device();
auto stream = at::cuda::getCurrentCUDAStream(device.index());
using EleA = typename MacheteKernel::ElementA;
using EleB = typename MacheteKernel::ElementB;
using EleC = typename MacheteKernel::ElementC;
using EleD = typename MacheteKernel::ElementD;
using EleScale = typename MacheteKernel::ElementS;
using EleZero = typename MacheteKernel::ElementZ;
using StrideA = typename MacheteKernel::StrideA;
using StrideC = typename MacheteKernel::StrideC;
using StrideD = typename MacheteKernel::StrideD;
using StrideS = typename MacheteKernel::StrideS;
using StrideZ = typename MacheteKernel::StrideZ;
int M = args.A.size(0);
int N = args.B.size(1);
int K = args.A.size(1);
// Allocate output
torch::Tensor D =
torch::empty({M, N}, torch::TensorOptions()
.dtype(equivalent_scalar_type_v<EleD>)
.device(device));
auto const &A = args.A, &B = args.B;
auto const &C = args.C, &scales = args.scales, &zeros = args.zeros;
auto layout_A = make_cute_layout<StrideA>(A, "A");
auto layout_D = make_cute_layout<StrideD>(D, "D");
auto layout_C = maybe_make_cute_layout<StrideC>(C, "C");
auto layout_S = maybe_make_cute_layout<StrideS>(scales, "scales");
auto layout_Z = maybe_make_cute_layout<StrideZ>(zeros, "zeros");
auto A_ptr = static_cast<EleA const*>(A.const_data_ptr());
auto B_ptr = static_cast<EleB const*>(B.const_data_ptr());
auto D_ptr = static_cast<EleD*>(D.mutable_data_ptr());
auto C_ptr = static_cast<EleC const*>(C ? C->const_data_ptr() : nullptr);
auto S_ptr =
static_cast<EleScale const*>(scales ? scales->const_data_ptr() : nullptr);
auto Z_ptr =
static_cast<EleZero const*>(zeros ? zeros->const_data_ptr() : nullptr);
auto arguments = MacheteKernel::create_arguments(
stream, A_ptr, layout_A, B_ptr, D_ptr, layout_D, C_ptr, layout_C, S_ptr,
layout_S, Z_ptr, layout_Z, args.alpha.value_or(1), args.beta.value_or(0),
args.group_size.value_or(K));
TORCH_CHECK(MacheteKernel::can_implement(arguments),
"Machete kernel cannot be run with these arguments");
size_t workspace_size = MacheteKernel::get_workspace_size(arguments);
torch::Tensor workspace = torch::empty(
workspace_size, torch::TensorOptions().dtype(torch::kU8).device(device));
MacheteKernel::run(arguments, workspace.mutable_data_ptr(), stream);
return D;
};
template <typename ElementA, typename ElementB, typename ElementD = ElementA,
typename AccumulatorT = float, typename ScaleT = ElementA,
typename ZeroT = ElementA>
struct GemmDispatcher {
static torch::Tensor dispatch(PyTorchArguments args);
static std::vector<std::string> supported_schedules();
};
}; // namespace machete
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment