Commit 86e9c8df (unverified), authored by Lucas Wilkinson, committed by GitHub
Browse files

[Kernel] (2/N) Machete - Integrate into CompressedTensorsWNA16 and GPTQMarlin (#7701)


Co-authored-by: mgoin <michael@neuralmagic.com>
Co-authored-by: Divakar Verma <137818590+divakar-amd@users.noreply.github.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
parent ee5f34b1
...@@ -223,6 +223,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") ...@@ -223,6 +223,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/quantization/gguf/gguf_kernel.cu" "csrc/quantization/gguf/gguf_kernel.cu"
"csrc/quantization/fp8/fp8_marlin.cu" "csrc/quantization/fp8/fp8_marlin.cu"
"csrc/custom_all_reduce.cu" "csrc/custom_all_reduce.cu"
"csrc/permute_cols.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu" "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu" "csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu") "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu")
......
...@@ -4,8 +4,10 @@ import itertools ...@@ -4,8 +4,10 @@ import itertools
import math import math
import pickle as pkl import pickle as pkl
import time import time
from typing import Callable, Iterable, List, Tuple from itertools import product
from typing import Callable, Iterable, List, Optional, Tuple
import pandas as pd
import torch import torch
import torch.utils.benchmark as TBenchmark import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement from torch.utils.benchmark import Measurement as TMeasurement
...@@ -84,6 +86,10 @@ def loop_over_weights( ...@@ -84,6 +86,10 @@ def loop_over_weights(
fn(a, w_ref, w_q, w_s) fn(a, w_ref, w_q, w_s)
_SWEEP_SCHEDULES_RESULTS: Optional[pd.DataFrame] = None
_SWEEP_SCHEDULES_RESULTS_CSV: Optional[str] = None
def bench(atype: torch.dtype, def bench(atype: torch.dtype,
wtype: ScalarType, wtype: ScalarType,
group_size: int, group_size: int,
...@@ -94,6 +100,8 @@ def bench(atype: torch.dtype, ...@@ -94,6 +100,8 @@ def bench(atype: torch.dtype,
sub_label: str, sub_label: str,
benchmark_marlinv1: bool = True, benchmark_marlinv1: bool = True,
sweep_schedules: bool = True) -> Iterable[TMeasurement]: sweep_schedules: bool = True) -> Iterable[TMeasurement]:
global _SWEEP_SCHEDULES_RESULTS
a, weights = make_bench_tensors(atype, wtype, group_size, m, n, k) a, weights = make_bench_tensors(atype, wtype, group_size, m, n, k)
sub_label += f", L={len(weights)}" sub_label += f", L={len(weights)}"
...@@ -163,6 +171,11 @@ def bench(atype: torch.dtype, ...@@ -163,6 +171,11 @@ def bench(atype: torch.dtype,
best_schedule = None best_schedule = None
schedules = ops.machete_supported_schedules(wtype) schedules = ops.machete_supported_schedules(wtype)
for schedule in reversed(schedules): for schedule in reversed(schedules):
schedule_M = int(schedule.split("_")[0].split("x")[1])
# Prune known bad schedules
if schedule_M >= 2 * max(m, 16) or schedule_M < m // 4:
continue
def run(a, _, w_q, w_s, schedule=schedule): def run(a, _, w_q, w_s, schedule=schedule):
ops.machete_gemm(a, ops.machete_gemm(a,
...@@ -175,6 +188,20 @@ def bench(atype: torch.dtype, ...@@ -175,6 +188,20 @@ def bench(atype: torch.dtype,
res = bench_fn(label, sub_label, "machete_best", res = bench_fn(label, sub_label, "machete_best",
lambda: loop_over_weights(a, weights_machete, run)) lambda: loop_over_weights(a, weights_machete, run))
results_row = {
"M": m,
"K": k,
"N": n,
"group_size": group_size,
"schedule": schedule,
"median": res.median,
}
if _SWEEP_SCHEDULES_RESULTS is None:
_SWEEP_SCHEDULES_RESULTS = pd.DataFrame(
columns=results_row.keys())
_SWEEP_SCHEDULES_RESULTS.\
loc[len(_SWEEP_SCHEDULES_RESULTS)] = results_row
print(f" {res.median:5.5} ", schedule) print(f" {res.median:5.5} ", schedule)
if not best or res.median < best.median: if not best or res.median < best.median:
best = res best = res
...@@ -235,18 +262,22 @@ def run_square_bench(args): ...@@ -235,18 +262,22 @@ def run_square_bench(args):
dim_sizes = list( dim_sizes = list(
range(args.dim_start, args.dim_end + 1, args.dim_increment)) range(args.dim_start, args.dim_end + 1, args.dim_increment))
MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes)) MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))
data = run(args.dtype, args.sweep_schedules, MKNs) data = run(args.dtype, args.sweep_schedules, MKNs)
make_output(data, MKNs, f"square_bench-{args.dtype}") make_output(data, MKNs, f"square_bench-{args.dtype}")
def run_range_bench(args): def run_range_bench(args):
dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment)) m_start, k_start, n_start = [int(x) for x in args.dim_start.split(",")]
n = len(dim_sizes) m_end, k_end, n_end = [int(x) for x in args.dim_end.split(",")]
Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes m_increment, k_increment, n_increment = \
Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes [int(x) for x in args.dim_increment.split(",")]
Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes Ms = list(range(m_start, m_end + 1, m_increment))
MKNs = list(zip(Ms, Ks, Ns)) Ks = list(range(k_start, k_end + 1, k_increment))
Ns = list(range(n_start, n_end + 1, n_increment))
MKNs = list(product(Ms, Ks, Ns))
data = run(args.dtype, args.sweep_schedules, MKNs) data = run(args.dtype, args.sweep_schedules, MKNs)
make_output(data, MKNs, f"range_bench-{args.dtype}") make_output(data, MKNs, f"range_bench-{args.dtype}")
...@@ -333,6 +364,9 @@ Benchmark Machete GEMM. ...@@ -333,6 +364,9 @@ Benchmark Machete GEMM.
action="store_true", action="store_true",
help="Run a sweep over all supported schedules", help="Run a sweep over all supported schedules",
) )
parser.add_argument("--sweep-csv-out",
help="CSV to store sweep results",
default="sch_sweep_results.csv")
subparsers = parser.add_subparsers(dest="cmd", required=True) subparsers = parser.add_subparsers(dest="cmd", required=True)
square_parser = subparsers.add_parser("square_bench") square_parser = subparsers.add_parser("square_bench")
...@@ -342,12 +376,21 @@ Benchmark Machete GEMM. ...@@ -342,12 +376,21 @@ Benchmark Machete GEMM.
square_parser.set_defaults(func=run_square_bench) square_parser.set_defaults(func=run_square_bench)
range_parser = subparsers.add_parser("range_bench") range_parser = subparsers.add_parser("range_bench")
range_parser.add_argument("--dim-start", type=int, required=True) range_parser.add_argument(
range_parser.add_argument("--dim-end", type=int, required=True) "--dim-start",
range_parser.add_argument("--dim-increment", type=int, required=True) type=str,
range_parser.add_argument("--m-constant", type=int, default=None) required=True,
range_parser.add_argument("--n-constant", type=int, default=None) help="Start value for M,K,N as common separated list")
range_parser.add_argument("--k-constant", type=int, default=None) range_parser.add_argument(
"--dim-end",
type=str,
required=True,
help="End value (inclusive) for M,K,N as common separated list")
range_parser.add_argument(
"--dim-increment",
type=str,
required=True,
help="Increment value for M,K,N as common separated list")
range_parser.set_defaults(func=run_range_bench) range_parser.set_defaults(func=run_range_bench)
model_parser = subparsers.add_parser("model_bench") model_parser = subparsers.add_parser("model_bench")
...@@ -369,4 +412,9 @@ Benchmark Machete GEMM. ...@@ -369,4 +412,9 @@ Benchmark Machete GEMM.
model_parser.set_defaults(func=run_model_bench) model_parser.set_defaults(func=run_model_bench)
args = parser.parse_args() args = parser.parse_args()
_SWEEP_SCHEDULES_RESULTS_CSV = args.sweep_csv_out
args.func(args) args.func(args)
if _SWEEP_SCHEDULES_RESULTS is not None:
_SWEEP_SCHEDULES_RESULTS.to_csv(_SWEEP_SCHEDULES_RESULTS_CSV)
pandas
\ No newline at end of file
...@@ -68,7 +68,13 @@ static inline auto make_cute_layout(torch::Tensor const& tensor, ...@@ -68,7 +68,13 @@ static inline auto make_cute_layout(torch::Tensor const& tensor,
name, ".stride(", idx, ") to be ", StrideEle::value); name, ".stride(", idx, ") to be ", StrideEle::value);
return StrideEle{}; return StrideEle{};
} else { } else {
return tensor.stride(idx); if (tensor.size(idx) == 1) {
// use 0 stride for dim with size 1, this is easier for
// cute/cutlass to optimize (helps the TMA code flatten dims)
return StrideEle{0};
} else {
return tensor.stride(idx);
}
} }
} else { } else {
// Extra strides are assumed to be 0 or 1 // Extra strides are assumed to be 0 or 1
......
...@@ -113,6 +113,8 @@ torch::Tensor prepack_B(torch::Tensor const& B, ...@@ -113,6 +113,8 @@ torch::Tensor prepack_B(torch::Tensor const& B,
}; // namespace machete }; // namespace machete
torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm);
torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_meta, torch::Tensor& b_meta,
torch::Tensor& b_scales, torch::Tensor& b_scales,
......
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_fp16.h>
static constexpr int default_threads = 256;
// Integer division of `a` by `b`, rounded up for the non-negative `a` /
// positive `b` values it is used with here.
static constexpr int div_ceil(int a, int b) {
  int const biased_numerator = a + b - 1;
  return biased_numerator / b;
}
// For a given "a" of size [M,K] performs a permutation of the K columns based
// on the given "perm" indices, i.e. out[row][k] = a[row][perm[k]].
// Currently only supports 16bit types (since we permute half types); the
// elements are only copied, never interpreted, so any 16-bit payload works.
//
// Work split: block `blockIdx.x` owns the `block_rows` consecutive rows that
// start at `block_rows * blockIdx.x` (clamped to `size_m`); inside a row the
// columns are distributed over threads with a step of `default_threads`.
// NOTE: written for a launch with `default_threads` threads per block.
__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
                                    int const* __restrict__ perm_int_ptr,
                                    int4* __restrict__ out_int4_ptr, int size_m,
                                    int size_k, int block_rows) {
  // Half-open row range [first_row, last_row) handled by this block.
  int const first_row = block_rows * blockIdx.x;
  int last_row = first_row + block_rows;
  if (last_row > size_m) {
    last_row = size_m;
  }

  // Length of one row measured in int4 (16-byte) units, since the base
  // pointers are typed as int4.
  int const int4_per_row = size_k * sizeof(half) / 16;

  // Blocks whose range starts at or past size_m simply skip this loop.
  for (int row = first_row; row < last_row; ++row) {
    int const row_offset = row * int4_per_row;
    half const* src_row =
        reinterpret_cast<half const*>(a_int4_ptr + row_offset);
    half* dst_row = reinterpret_cast<half*>(out_int4_ptr + row_offset);

    // Each thread gathers the columns congruent to its index modulo
    // default_threads; the bound check covers the ragged tail of the row.
    for (int k = threadIdx.x; k < size_k; k += default_threads) {
      dst_row[k] = src_row[perm_int_ptr[k]];
    }
  }
}
// More efficient version of A[..., perm]
// taken from gptq_marlin.cu
//
// Returns a new tensor D with the same shape and dtype as A such that
// D[..., k] = A[..., perm[k]].
//
// Checked preconditions:
//   - A is a contiguous CUDA tensor of a 16-bit float type (half / bfloat16)
//     whose last dimension is a multiple of 8 (rows are addressed as int4).
//   - perm is a contiguous int32 tensor on the same device as A with exactly
//     one entry per column of A.
// Unchecked precondition: every value in perm must lie in [0, A.size(-1));
// verifying that would need a device sync, so it is left to the caller.
torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm) {
  // Fail early with a readable message instead of inside the device guard.
  TORCH_CHECK(A.is_cuda(), "A must be a CUDA tensor");

  const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
  auto dev = A.get_device();
  auto stream = at::cuda::getCurrentCUDAStream(dev);

  TORCH_CHECK(A.scalar_type() == at::kHalf || A.scalar_type() == at::kBFloat16,
              "Currently only 16bit types are supported");
  TORCH_CHECK(A.dim() >= 1, "A must have at least one dimension");
  TORCH_CHECK(A.is_contiguous(), "A must be contiguous");
  TORCH_CHECK(A.size(-1) % 8 == 0,
              "A columns must be a multiple of 8 (128bits)");

  // The kernel indexes perm through a raw int pointer for every column of A,
  // so a wrong dtype/device/layout/length would read invalid memory.
  TORCH_CHECK(perm.scalar_type() == at::kInt, "perm must be an int32 tensor");
  TORCH_CHECK(perm.device() == A.device(),
              "perm must be on the same device as A");
  TORCH_CHECK(perm.is_contiguous(), "perm must be contiguous");
  TORCH_CHECK(perm.numel() == A.size(-1),
              "perm must have one entry per column of A (expected ",
              A.size(-1), ", got ", perm.numel(), ")");

  torch::Tensor D = torch::empty_like(A);

  // Nothing to permute; also avoids the ambiguous view({-1, 0}) when K == 0.
  if (A.numel() == 0) {
    return D;
  }

  auto A_2d = A.view({-1, A.size(-1)});

  int sms = 0;
  cudaError_t const attr_status =
      cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
  TORCH_CHECK(attr_status == cudaSuccess && sms > 0,
              "Failed to query the multiprocessor count: ",
              cudaGetErrorString(attr_status));

  // One block per SM; each block processes a contiguous slab of rows.
  int block_rows = div_ceil(A_2d.size(0), sms);
  permute_cols_kernel<<<sms, default_threads, 0, stream>>>(
      reinterpret_cast<int4 const*>(A_2d.const_data_ptr()),
      perm.const_data_ptr<int>(), reinterpret_cast<int4*>(D.mutable_data_ptr()),
      A_2d.size(0), A_2d.size(1), block_rows);
  return D;
}
\ No newline at end of file
...@@ -157,7 +157,7 @@ TmaMI = MixedInputKernelScheduleType.TmaWarpSpecializedCooperativeMixedInput ...@@ -157,7 +157,7 @@ TmaMI = MixedInputKernelScheduleType.TmaWarpSpecializedCooperativeMixedInput
TmaCoop = EpilogueScheduleType.TmaWarpSpecializedCooperative TmaCoop = EpilogueScheduleType.TmaWarpSpecializedCooperative
@dataclass @dataclass(frozen=True)
class ScheduleConfig: class ScheduleConfig:
tile_shape_mn: Tuple[int, int] tile_shape_mn: Tuple[int, int]
cluster_shape_mnk: Tuple[int, int, int] cluster_shape_mnk: Tuple[int, int, int]
...@@ -328,56 +328,137 @@ def generate(): ...@@ -328,56 +328,137 @@ def generate():
# about how this works # about how this works
SCRIPT_DIR = os.path.dirname(__file__) SCRIPT_DIR = os.path.dirname(__file__)
schedules = [ schedule_common_params = dict(
ScheduleConfig( kernel_schedule=TmaMI,
tile_shape_mn=tile_shape_mn, epilogue_schedule=TmaCoop,
cluster_shape_mnk=cluster_shape_mnk, tile_scheduler=TileSchedulerType.StreamK,
kernel_schedule=kernel_schedule, )
epilogue_schedule=epilogue_schedule,
tile_scheduler=tile_scheduler,
) for tile_shape_mn, cluster_shape_mnk in (
((128, 16), (1, 1, 1)),
((128, 32), (1, 1, 1)),
((128, 64), (1, 1, 1)),
((128, 128), (1, 1, 1)),
) for kernel_schedule in (TmaMI, ) for epilogue_schedule in (TmaCoop, )
for tile_scheduler in (TileSchedulerType.StreamK, )
]
# For now we use the same heuristic for all types # For now we use the same heuristic for all types
# Heuristic is currently tuned for H100s
default_heuristic = [ default_heuristic = [
("M > 64", #### M = 257+
ScheduleConfig( (
tile_shape_mn=(128, 128), "M > 256 && K <= 16384 && N <= 4096",
cluster_shape_mnk=(1, 1, 1), ScheduleConfig(
kernel_schedule=TmaMI, tile_shape_mn=(128, 128),
epilogue_schedule=TmaCoop, cluster_shape_mnk=(2, 1, 1),
tile_scheduler=TileSchedulerType.StreamK, **schedule_common_params # type: ignore
)), )),
("M > 32", (
ScheduleConfig( "M > 256",
tile_shape_mn=(128, 64), ScheduleConfig(
cluster_shape_mnk=(1, 1, 1), tile_shape_mn=(128, 256),
kernel_schedule=TmaMI, cluster_shape_mnk=(2, 1, 1),
epilogue_schedule=TmaCoop, **schedule_common_params # type: ignore
tile_scheduler=TileSchedulerType.StreamK, )),
)), #### M = 129-256
("M > 16", (
ScheduleConfig( "M > 128 && K <= 4096 && N <= 4096",
tile_shape_mn=(128, 32), ScheduleConfig(
cluster_shape_mnk=(1, 1, 1), tile_shape_mn=(128, 64),
kernel_schedule=TmaMI, cluster_shape_mnk=(2, 1, 1),
epilogue_schedule=TmaCoop, **schedule_common_params # type: ignore
tile_scheduler=TileSchedulerType.StreamK, )),
)), (
(None, "M > 128 && K <= 8192 && N <= 8192",
ScheduleConfig(tile_shape_mn=(128, 16), ScheduleConfig(
cluster_shape_mnk=(1, 1, 1), tile_shape_mn=(128, 128),
kernel_schedule=TmaMI, cluster_shape_mnk=(2, 1, 1),
epilogue_schedule=TmaCoop, **schedule_common_params # type: ignore
tile_scheduler=TileSchedulerType.StreamK)) )),
(
"M > 128",
ScheduleConfig(
tile_shape_mn=(128, 256),
cluster_shape_mnk=(2, 1, 1),
**schedule_common_params # type: ignore
)),
#### M = 65-128
(
"M > 64 && K <= 4069 && N <= 4069",
ScheduleConfig(
tile_shape_mn=(128, 32),
cluster_shape_mnk=(2, 1, 1),
**schedule_common_params # type: ignore
)),
(
"M > 64 && K <= 4069 && N <= 8192",
ScheduleConfig(
tile_shape_mn=(128, 64),
cluster_shape_mnk=(2, 1, 1),
**schedule_common_params # type: ignore
)),
(
"M > 64 && K >= 8192 && N >= 12288",
ScheduleConfig(
tile_shape_mn=(256, 128),
cluster_shape_mnk=(2, 1, 1),
**schedule_common_params # type: ignore
)),
(
"M > 64",
ScheduleConfig(
tile_shape_mn=(128, 128),
cluster_shape_mnk=(2, 1, 1),
**schedule_common_params # type: ignore
)),
#### M = 33-64
(
"M > 32 && K <= 6144 && N <= 6144",
ScheduleConfig(
tile_shape_mn=(128, 16),
cluster_shape_mnk=(1, 1, 1),
**schedule_common_params # type: ignore
)),
(
"M > 32 && K >= 16384 && N >= 12288",
ScheduleConfig(
tile_shape_mn=(256, 64),
cluster_shape_mnk=(2, 1, 1),
**schedule_common_params # type: ignore
)),
(
"M > 32",
ScheduleConfig(
tile_shape_mn=(128, 64),
cluster_shape_mnk=(2, 1, 1),
**schedule_common_params # type: ignore
)),
#### M = 17-32
(
"M > 16 && K <= 12288 && N <= 8192",
ScheduleConfig(
tile_shape_mn=(128, 32),
cluster_shape_mnk=(2, 1, 1),
**schedule_common_params # type: ignore
)),
(
"M > 16",
ScheduleConfig(
tile_shape_mn=(256, 32),
cluster_shape_mnk=(2, 1, 1),
**schedule_common_params # type: ignore
)),
#### M = 1-16
(
"N >= 26624",
ScheduleConfig(
tile_shape_mn=(256, 16),
cluster_shape_mnk=(1, 1, 1),
**schedule_common_params # type: ignore
)),
(
None,
ScheduleConfig(
tile_shape_mn=(128, 16),
cluster_shape_mnk=(1, 1, 1),
**schedule_common_params # type: ignore
)),
] ]
schedules = list(set([x[1] for x in default_heuristic]))
impl_configs = [] impl_configs = []
GPTQ_kernel_type_configs = list( GPTQ_kernel_type_configs = list(
......
...@@ -152,7 +152,8 @@ struct MacheteKernelTemplate { ...@@ -152,7 +152,8 @@ struct MacheteKernelTemplate {
int M = size<0>(layout_A), N = size<1>(layout_D), K = size<1>(layout_A); int M = size<0>(layout_A), N = size<1>(layout_D), K = size<1>(layout_A);
int const group_size = maybe_group_size.value_or(K); int const group_size =
maybe_group_size == -1 ? K : maybe_group_size.value_or(K);
int const scale_k = (K + group_size - 1) / group_size; int const scale_k = (K + group_size - 1) / group_size;
TORCH_CHECK(size<0>(layout_A) == M && size<1>(layout_A) == K); TORCH_CHECK(size<0>(layout_A) == M && size<1>(layout_A) == K);
......
...@@ -71,7 +71,7 @@ torch::Tensor run_impl(PyTorchArguments args) { ...@@ -71,7 +71,7 @@ torch::Tensor run_impl(PyTorchArguments args) {
auto arguments = MacheteKernel::create_arguments( auto arguments = MacheteKernel::create_arguments(
stream, A_ptr, layout_A, B_ptr, D_ptr, layout_D, C_ptr, layout_C, S_ptr, stream, A_ptr, layout_A, B_ptr, D_ptr, layout_D, C_ptr, layout_C, S_ptr,
layout_S, Z_ptr, layout_Z, args.alpha.value_or(1), args.beta.value_or(0), layout_S, Z_ptr, layout_Z, args.alpha.value_or(1), args.beta.value_or(0),
args.group_size.value_or(K)); args.group_size);
TORCH_CHECK(MacheteKernel::can_implement(arguments), TORCH_CHECK(MacheteKernel::can_implement(arguments),
"Machete kernel cannot be run with these arguments"); "Machete kernel cannot be run with these arguments");
......
...@@ -53,7 +53,7 @@ torch::Tensor prepack_impl(torch::Tensor const B) { ...@@ -53,7 +53,7 @@ torch::Tensor prepack_impl(torch::Tensor const B) {
// clang-format on // clang-format on
// Allocate output // Allocate output
torch::Tensor D = torch::empty_like(B); torch::Tensor D = torch::empty_like(B, {}, at::MemoryFormat::Contiguous);
prepack_B<PrepackedLayoutB>(stream, B_ptr, layout_Bt, prepack_B<PrepackedLayoutB>(stream, B_ptr, layout_Bt,
static_cast<ElementB*>(D.mutable_data_ptr())); static_cast<ElementB*>(D.mutable_data_ptr()));
......
...@@ -192,6 +192,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -192,6 +192,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"-> Tensor"); "-> Tensor");
ops.impl("machete_prepack_B", torch::kCUDA, &machete::prepack_B); ops.impl("machete_prepack_B", torch::kCUDA, &machete::prepack_B);
ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor");
ops.impl("permute_cols", torch::kCUDA, &permute_cols);
// gptq_marlin Optimized Quantized GEMM for GPTQ. // gptq_marlin Optimized Quantized GEMM for GPTQ.
ops.def( ops.def(
"gptq_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, " "gptq_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
......
...@@ -31,6 +31,8 @@ MNK_SHAPES = [ ...@@ -31,6 +31,8 @@ MNK_SHAPES = [
(257, 4224, 4160), (257, 4224, 4160),
(257, 4096, 4096), (257, 4096, 4096),
(64, 4096, 4096), (64, 4096, 4096),
(1024, 4096, 8192),
(1024, 8192, 4096),
] ]
ACT_TYPES = [torch.float16, torch.bfloat16] ACT_TYPES = [torch.float16, torch.bfloat16]
...@@ -139,6 +141,7 @@ def test_machete_all_schedules(shape, atype: torch.dtype, ...@@ -139,6 +141,7 @@ def test_machete_all_schedules(shape, atype: torch.dtype,
output_ref = torch.matmul(a, w_ref) output_ref = torch.matmul(a, w_ref)
for schedule in ops.machete_supported_schedules(wtype): for schedule in ops.machete_supported_schedules(wtype):
print(f"Testing schedule {schedule}")
output = ops.machete_gemm( output = ops.machete_gemm(
a, a,
b_q=w_q_machete, b_q=w_q_machete,
......
import pytest
import torch
from tests.kernels.utils import opcheck
from vllm._custom_ops import permute_cols
@pytest.mark.parametrize('shape', [(1, 512), (544, 4096), (67, 8192)])
@pytest.mark.parametrize('dtype', [torch.bfloat16, torch.float16])
def test_permute_cols(shape, dtype):
    """Check permute_cols against plain column indexing on CUDA tensors."""
    num_cols = shape[1]
    inp = torch.randn(shape, dtype=dtype).cuda()
    col_order = torch.randperm(num_cols).to(torch.int).cuda()

    # Run the project's opcheck helper on the raw custom op first.
    opcheck(torch.ops._C.permute_cols, (inp, col_order))

    # Reference result: advanced indexing along the column dimension.
    expected = inp[:, col_order]
    actual = permute_cols(inp, col_order)
    torch.testing.assert_close(actual, expected)
\ No newline at end of file
...@@ -438,7 +438,8 @@ try: ...@@ -438,7 +438,8 @@ try:
@torch.library.register_fake("_C::machete_prepack_B") @torch.library.register_fake("_C::machete_prepack_B")
def machete_prepack_B_fake(b_q_weight: torch.Tensor, def machete_prepack_B_fake(b_q_weight: torch.Tensor,
b_type: ScalarType) -> torch.Tensor: b_type: ScalarType) -> torch.Tensor:
return torch.empty_like(b_q_weight) return torch.empty_like(b_q_weight,
memory_format=torch.contiguous_format)
@torch.library.register_fake("_C::causal_conv1d_fwd") @torch.library.register_fake("_C::causal_conv1d_fwd")
def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor, def causal_conv1d_fwd_fake(x: torch.Tensor, weight: torch.Tensor,
...@@ -625,6 +626,22 @@ def machete_prepack_B(b_q_weight: torch.Tensor, ...@@ -625,6 +626,22 @@ def machete_prepack_B(b_q_weight: torch.Tensor,
return torch.ops._C.machete_prepack_B(b_q_weight, b_type) return torch.ops._C.machete_prepack_B(b_q_weight, b_type)
# TODO: has to be a better way to do this
# Register a fake implementation for `_C::permute_cols`, but only when the
# op can be looked up in this build.
try:
    # Bare attribute access used purely as an existence probe; presumably it
    # raises when the extension lacks the op, which skips the registration.
    torch.ops._C.permute_cols  # noqa B018

    @torch.library.register_fake("_C::permute_cols")
    def _permute_cols_fake(a: torch.Tensor,
                           perm: torch.Tensor) -> torch.Tensor:
        # The result mirrors `a` (empty_like); no data is actually permuted.
        return torch.empty_like(a)
except Exception:
    # Best-effort: any failure above just leaves the op without a fake impl.
    pass
def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
    """Forward ``a`` and ``perm`` to the ``_C::permute_cols`` custom op.

    Thin Python entry point; all validation and the actual column
    permutation happen inside the compiled op.
    """
    permute_op = torch.ops._C.permute_cols
    return permute_op(a, perm)
# fp8 # fp8
def scaled_fp8_quant( def scaled_fp8_quant(
input: torch.Tensor, input: torch.Tensor,
......
...@@ -7,10 +7,11 @@ from vllm.logger import init_logger ...@@ -7,10 +7,11 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported, apply_awq_marlin_linear, awq_to_marlin_zero_points, check_marlin_supported,
marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales, marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales,
replace_tensor, verify_marlin_supported, verify_marlin_supports_shape) verify_marlin_supported, verify_marlin_supports_shape)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (GroupQuantScaleParameter, from vllm.model_executor.parameter import (GroupQuantScaleParameter,
PackedvLLMParameter) PackedvLLMParameter)
...@@ -231,7 +232,7 @@ class AWQMarlinLinearMethod(LinearMethodBase): ...@@ -231,7 +232,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
size_k=layer.input_size_per_partition, size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition, size_n=layer.output_size_per_partition,
num_bits=self.quant_config.quant_type.size_bits) num_bits=self.quant_config.quant_type.size_bits)
replace_tensor(layer, "qweight", marlin_qweight) replace_parameter(layer, "qweight", marlin_qweight)
# Permute scales from AWQ format to marlin format. # Permute scales from AWQ format to marlin format.
marlin_scales = marlin_permute_scales( marlin_scales = marlin_permute_scales(
...@@ -239,7 +240,7 @@ class AWQMarlinLinearMethod(LinearMethodBase): ...@@ -239,7 +240,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
size_k=layer.input_size_per_partition, size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition, size_n=layer.output_size_per_partition,
group_size=self.quant_config.group_size) group_size=self.quant_config.group_size)
replace_tensor(layer, "scales", marlin_scales) replace_parameter(layer, "scales", marlin_scales)
# Permute zero-points from AWQ format to marlin format. # Permute zero-points from AWQ format to marlin format.
marlin_zp = awq_to_marlin_zero_points( marlin_zp = awq_to_marlin_zero_points(
...@@ -247,7 +248,7 @@ class AWQMarlinLinearMethod(LinearMethodBase): ...@@ -247,7 +248,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
size_k=layer.num_groups, size_k=layer.num_groups,
size_n=layer.output_size_per_partition, size_n=layer.output_size_per_partition,
num_bits=self.quant_config.quant_type.size_bits) num_bits=self.quant_config.quant_type.size_bits)
replace_tensor(layer, "qzeros", marlin_zp) replace_parameter(layer, "qzeros", marlin_zp)
# Not-used # Not-used
layer.g_idx = marlin_make_empty_g_idx(device) layer.g_idx = marlin_make_empty_g_idx(device)
......
from typing import Callable, List, Optional from typing import Callable, List, Optional, Set
import torch import torch
from vllm import _custom_ops as ops from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsScheme) CompressedTensorsScheme)
from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
ActivationOrdering) ActivationOrdering)
from vllm.model_executor.layers.quantization.kernels import (
MPLinearLayerConfig, choose_mp_linear_kernel)
from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_gptq_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace, marlin_repeat_scales_on_all_ranks)
marlin_permute_scales, marlin_repeat_scales_on_all_ranks,
marlin_sort_g_idx, replace_tensor, verify_marlin_supported,
verify_marlin_supports_shape)
from vllm.model_executor.parameter import (BasevLLMParameter, from vllm.model_executor.parameter import (BasevLLMParameter,
ChannelQuantScaleParameter, ChannelQuantScaleParameter,
GroupQuantScaleParameter, GroupQuantScaleParameter,
...@@ -19,6 +18,8 @@ from vllm.model_executor.parameter import (BasevLLMParameter, ...@@ -19,6 +18,8 @@ from vllm.model_executor.parameter import (BasevLLMParameter,
RowvLLMParameter) RowvLLMParameter)
from vllm.scalar_type import scalar_types from vllm.scalar_type import scalar_types
logger = init_logger(__name__)
__all__ = ["CompressedTensorsWNA16"] __all__ = ["CompressedTensorsWNA16"]
WNA16_SUPPORTED_TYPES_MAP = { WNA16_SUPPORTED_TYPES_MAP = {
4: scalar_types.uint4b8, 4: scalar_types.uint4b8,
...@@ -28,6 +29,7 @@ WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys()) ...@@ -28,6 +29,7 @@ WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys())
class CompressedTensorsWNA16(CompressedTensorsScheme): class CompressedTensorsWNA16(CompressedTensorsScheme):
_kernel_backends_being_used: Set[str] = set()
def __init__(self, def __init__(self,
strategy: str, strategy: str,
...@@ -52,35 +54,43 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): ...@@ -52,35 +54,43 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
self.quant_type = WNA16_SUPPORTED_TYPES_MAP[num_bits] self.quant_type = WNA16_SUPPORTED_TYPES_MAP[num_bits]
# Verify supported on platform.
verify_marlin_supported(quant_type=self.quant_type,
group_size=self.group_size)
@classmethod @classmethod
def get_min_capability(cls) -> int: def get_min_capability(cls) -> int:
# ampere and up # ampere and up
return 80 return 80
def create_weights(self, layer: torch.nn.Module, input_size: int, def create_weights(self, layer: torch.nn.Module, output_size: int,
output_partition_sizes: List[int], input_size: int, output_partition_sizes: List[int],
input_size_per_partition: int, input_size_per_partition: int,
params_dtype: torch.dtype, weight_loader: Callable, params_dtype: torch.dtype, weight_loader: Callable,
**kwargs): **kwargs):
output_size_per_partition = sum(output_partition_sizes) output_size_per_partition = sum(output_partition_sizes)
mp_linear_kernel_config = MPLinearLayerConfig(
full_weight_shape=(input_size, output_size),
partition_weight_shape=\
(input_size_per_partition, output_size_per_partition),
weight_type=self.quant_type,
act_type=params_dtype,
group_size=self.group_size,
zero_points=False,
has_g_idx=self.has_g_idx
)
kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)
if kernel_type.__name__ not in self._kernel_backends_being_used:
logger.info("Using %s for CompressedTensorsWNA16",
kernel_type.__name__)
self._kernel_backends_being_used.add(kernel_type.__name__)
# If group_size is -1, we are in channelwise case. # If group_size is -1, we are in channelwise case.
group_size = self.group_size if self.group_size != -1 else input_size group_size = self.group_size if self.group_size != -1 else input_size
row_parallel = (input_size != input_size_per_partition) row_parallel = (input_size != input_size_per_partition)
partition_scales = not marlin_repeat_scales_on_all_ranks( partition_scales = not marlin_repeat_scales_on_all_ranks(
self.has_g_idx, self.group_size, row_parallel) self.has_g_idx, self.group_size, row_parallel)
verify_marlin_supports_shape(
output_size_per_partition=output_size_per_partition,
input_size_per_partition=input_size_per_partition,
input_size=input_size,
group_size=group_size)
scales_and_zp_size = input_size // group_size scales_and_zp_size = input_size // group_size
if partition_scales: if partition_scales:
...@@ -137,69 +147,17 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): ...@@ -137,69 +147,17 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
weight_loader=weight_loader) weight_loader=weight_loader)
layer.register_parameter("weight_g_idx", weight_g_idx) layer.register_parameter("weight_g_idx", weight_g_idx)
layer.input_size_per_partition = input_size_per_partition self.kernel = kernel_type(mp_linear_kernel_config,
layer.output_size_per_partition = output_size_per_partition w_q_param_name="weight_packed",
layer.input_size = input_size w_s_param_name="weight_scale",
layer.group_size = group_size w_zp_param_name=None,
w_gidx_param_name="weight_g_idx")
# Checkpoints are serialized in compressed-tensors format, which is # Checkpoints are serialized in compressed-tensors format, which is
# different from marlin format. Handle repacking here. # different from the format the kernel may want. Handle repacking here.
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
device = layer.weight_packed.device self.kernel.process_weights_after_loading(layer)
# Allocate marlin workspace.
layer.workspace = marlin_make_workspace(
layer.output_size_per_partition, device)
# Handle sorting for activation reordering if needed.
if self.has_g_idx:
g_idx, g_idx_sort_indices = marlin_sort_g_idx(layer.weight_g_idx)
layer.g_idx_sort_indices = g_idx_sort_indices
replace_tensor(layer, "weight_g_idx", g_idx)
else:
layer.weight_g_idx = marlin_make_empty_g_idx(device)
layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
# No zero-point
layer.weight_zp = marlin_make_empty_g_idx(device)
# Update for kernel
layer.weight_packed = torch.nn.Parameter(
layer.weight_packed.t().contiguous(), requires_grad=False)
layer.weight_scale = torch.nn.Parameter(
layer.weight_scale.squeeze().t().contiguous(), requires_grad=False)
# Repack weights from compressed-tensors format to marlin format.
marlin_qweight = ops.gptq_marlin_repack(
layer.weight_packed,
perm=layer.g_idx_sort_indices,
size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition,
num_bits=self.quant_type.size_bits)
replace_tensor(layer, "weight_packed", marlin_qweight)
# Permute scales from compressed-tensors format to marlin format.
# scale is required on all partitions if activation reordering
marlin_scales = marlin_permute_scales(
layer.weight_scale,
size_k=(layer.input_size
if self.has_g_idx else layer.input_size_per_partition),
size_n=layer.output_size_per_partition,
group_size=layer.group_size)
replace_tensor(layer, "weight_scale", marlin_scales)
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor: bias: Optional[torch.Tensor]) -> torch.Tensor:
return self.kernel.apply_weights(layer, x, bias)
return apply_gptq_marlin_linear(
input=x,
weight=layer.weight_packed,
weight_scale=layer.weight_scale,
weight_zp=layer.weight_zp,
g_idx=layer.weight_g_idx,
g_idx_sort_indices=layer.g_idx_sort_indices,
workspace=layer.workspace,
wtype=self.quant_type,
output_size_per_partition=layer.output_size_per_partition,
input_size_per_partition=layer.input_size_per_partition,
is_k_full=True,
bias=bias)
from typing import Any, Callable, Dict, List, Optional, Union from typing import Any, Callable, Dict, List, Optional, Set, Union
import torch import torch
from torch.nn import Parameter
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -11,12 +10,12 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, ...@@ -11,12 +10,12 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
set_weight_attrs) set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.quantization.kernels import (
MPLinearLayerConfig, choose_mp_linear_kernel)
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_gptq_marlin_linear, check_marlin_supported, marlin_is_k_full, check_marlin_supported, marlin_moe_permute_scales,
marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales, marlin_repeat_scales_on_all_ranks, verify_marlin_supported)
marlin_permute_scales, marlin_repeat_scales_on_all_ranks,
marlin_sort_g_idx, replace_tensor, verify_marlin_supported,
verify_marlin_supports_shape)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (ChannelQuantScaleParameter, from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
GroupQuantScaleParameter, GroupQuantScaleParameter,
...@@ -159,6 +158,8 @@ class GPTQMarlinLinearMethod(LinearMethodBase): ...@@ -159,6 +158,8 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
quant_config: The GPTQ Marlin quantization config. quant_config: The GPTQ Marlin quantization config.
""" """
_kernel_backends_being_used: Set[str] = set()
def __init__(self, quant_config: GPTQMarlinConfig) -> None: def __init__(self, quant_config: GPTQMarlinConfig) -> None:
self.quant_config = quant_config self.quant_config = quant_config
...@@ -176,25 +177,34 @@ class GPTQMarlinLinearMethod(LinearMethodBase): ...@@ -176,25 +177,34 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
params_dtype: torch.dtype, params_dtype: torch.dtype,
**extra_weight_attrs, **extra_weight_attrs,
) -> None: ) -> None:
del output_size
output_size_per_partition = sum(output_partition_sizes) output_size_per_partition = sum(output_partition_sizes)
is_row_parallel = input_size != input_size_per_partition is_row_parallel = input_size != input_size_per_partition
weight_loader = extra_weight_attrs.get("weight_loader") weight_loader = extra_weight_attrs.get("weight_loader")
mp_linear_kernel_config = MPLinearLayerConfig(
full_weight_shape=(input_size, output_size),
partition_weight_shape=\
(input_size_per_partition, output_size_per_partition),
weight_type=self.quant_config.quant_type,
act_type=params_dtype,
group_size=self.quant_config.group_size,
zero_points=False,
has_g_idx=self.quant_config.desc_act
)
kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)
if kernel_type.__name__ not in self._kernel_backends_being_used:
logger.info("Using %s for GPTQMarlinLinearMethod",
kernel_type.__name__)
self._kernel_backends_being_used.add(kernel_type.__name__)
# Normalize group_size # Normalize group_size
if self.quant_config.group_size != -1: if self.quant_config.group_size != -1:
group_size = self.quant_config.group_size group_size = self.quant_config.group_size
else: else:
group_size = input_size group_size = input_size
verify_marlin_supports_shape(
output_size_per_partition=output_size_per_partition,
input_size_per_partition=input_size_per_partition,
input_size=input_size,
group_size=group_size,
)
# Determine sharding # Determine sharding
if marlin_repeat_scales_on_all_ranks(self.quant_config.desc_act, if marlin_repeat_scales_on_all_ranks(self.quant_config.desc_act,
self.quant_config.group_size, self.quant_config.group_size,
...@@ -275,57 +285,15 @@ class GPTQMarlinLinearMethod(LinearMethodBase): ...@@ -275,57 +285,15 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
layer.register_parameter("g_idx", g_idx) layer.register_parameter("g_idx", g_idx)
layer.register_parameter("scales", scales) layer.register_parameter("scales", scales)
layer.register_parameter("qzeros", qzeros) layer.register_parameter("qzeros", qzeros)
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.input_size = input_size
layer.is_k_full = marlin_is_k_full(self.quant_config.desc_act,
is_row_parallel)
# Checkpoints are serialized in AutoGPTQ format, which is different from the
# marlin format. This function is called after the weights are loaded.
# Here, we handle the repacking, including the activation reordering case.
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
device = layer.qweight.device
# required by torch.compile self.kernel = kernel_type(mp_linear_kernel_config,
layer.qweight = Parameter(layer.qweight.data, requires_grad=False) w_q_param_name="qweight",
layer.scales = Parameter(layer.scales.data, requires_grad=False) w_s_param_name="scales",
w_zp_param_name="qzeros",
w_gidx_param_name="g_idx")
# Allocate marlin workspace def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.workspace = marlin_make_workspace( self.kernel.process_weights_after_loading(layer)
layer.output_size_per_partition, device)
# Handle sorting for activation reordering if needed.
if self.quant_config.desc_act:
g_idx, g_idx_sort_indices = marlin_sort_g_idx(layer.g_idx)
layer.g_idx_sort_indices = g_idx_sort_indices
replace_tensor(layer, "g_idx", g_idx)
else:
layer.g_idx = marlin_make_empty_g_idx(device)
layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
# No zero-point
layer.zp = marlin_make_empty_g_idx(device)
# Repack weights from autogptq format to marlin format.
marlin_qweight = ops.gptq_marlin_repack(
layer.qweight,
perm=layer.g_idx_sort_indices,
size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition,
num_bits=self.quant_config.quant_type.size_bits,
)
replace_tensor(layer, "qweight", marlin_qweight)
# Permute scales from autogptq format to marlin format.
marlin_scales = marlin_permute_scales(
layer.scales,
size_k=(layer.input_size if self.quant_config.desc_act else
layer.input_size_per_partition),
size_n=layer.output_size_per_partition,
group_size=self.quant_config.group_size,
)
replace_tensor(layer, "scales", marlin_scales)
def apply( def apply(
self, self,
...@@ -333,20 +301,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase): ...@@ -333,20 +301,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
x: torch.Tensor, x: torch.Tensor,
bias: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
return apply_gptq_marlin_linear( return self.kernel.apply_weights(layer, x, bias)
input=x,
weight=layer.qweight,
weight_scale=layer.scales,
weight_zp=layer.zp,
g_idx=layer.g_idx,
g_idx_sort_indices=layer.g_idx_sort_indices,
workspace=layer.workspace,
wtype=self.quant_config.quant_type,
output_size_per_partition=layer.output_size_per_partition,
input_size_per_partition=layer.input_size_per_partition,
is_k_full=layer.is_k_full,
bias=bias,
)
class GPTQMarlinMoEMethod(FusedMoEMethodBase): class GPTQMarlinMoEMethod(FusedMoEMethodBase):
...@@ -506,12 +461,12 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -506,12 +461,12 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
w13_g_idx_sort_indices[e]] w13_g_idx_sort_indices[e]]
w2_sorted_g_idx[e] = layer.w2_g_idx[e][ w2_sorted_g_idx[e] = layer.w2_g_idx[e][
w2_g_idx_sort_indices[e]] w2_g_idx_sort_indices[e]]
replace_tensor(layer, "w13_g_idx", w13_sorted_g_idx) replace_parameter(layer, "w13_g_idx", w13_sorted_g_idx)
replace_tensor(layer, "w2_g_idx", w2_sorted_g_idx) replace_parameter(layer, "w2_g_idx", w2_sorted_g_idx)
replace_tensor(layer, "w13_g_idx_sort_indices", replace_parameter(layer, "w13_g_idx_sort_indices",
w13_g_idx_sort_indices) w13_g_idx_sort_indices)
replace_tensor(layer, "w2_g_idx_sort_indices", replace_parameter(layer, "w2_g_idx_sort_indices",
w2_g_idx_sort_indices) w2_g_idx_sort_indices)
else: else:
# Reset g_idx related tensors # Reset g_idx related tensors
num_experts = layer.w13_g_idx.shape[0] num_experts = layer.w13_g_idx.shape[0]
...@@ -544,7 +499,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -544,7 +499,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
layer.w13_qweight.shape[2], layer.w13_qweight.shape[2],
self.quant_config.quant_type.size_bits, self.quant_config.quant_type.size_bits,
) )
replace_tensor(layer, "w13_qweight", marlin_w13_qweight) replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
marlin_w2_qweight = ops.gptq_marlin_moe_repack( marlin_w2_qweight = ops.gptq_marlin_moe_repack(
layer.w2_qweight, layer.w2_qweight,
layer.w2_g_idx_sort_indices, layer.w2_g_idx_sort_indices,
...@@ -552,7 +507,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -552,7 +507,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
layer.w2_qweight.shape[2], layer.w2_qweight.shape[2],
self.quant_config.quant_type.size_bits, self.quant_config.quant_type.size_bits,
) )
replace_tensor(layer, "w2_qweight", marlin_w2_qweight) replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
# Repack scales # Repack scales
marlin_w13_scales = marlin_moe_permute_scales( marlin_w13_scales = marlin_moe_permute_scales(
s=layer.w13_scales, s=layer.w13_scales,
...@@ -560,14 +515,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase): ...@@ -560,14 +515,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
size_n=layer.w13_scales.shape[2], size_n=layer.w13_scales.shape[2],
group_size=self.quant_config.group_size, group_size=self.quant_config.group_size,
) )
replace_tensor(layer, "w13_scales", marlin_w13_scales) replace_parameter(layer, "w13_scales", marlin_w13_scales)
marlin_w2_scales = marlin_moe_permute_scales( marlin_w2_scales = marlin_moe_permute_scales(
s=layer.w2_scales, s=layer.w2_scales,
size_k=layer.w2_scales.shape[1] * self.quant_config.pack_factor, size_k=layer.w2_scales.shape[1] * self.quant_config.pack_factor,
size_n=layer.w2_scales.shape[2], size_n=layer.w2_scales.shape[2],
group_size=self.quant_config.group_size, group_size=self.quant_config.group_size,
) )
replace_tensor(layer, "w2_scales", marlin_w2_scales) replace_parameter(layer, "w2_scales", marlin_w2_scales)
def apply( def apply(
self, self,
......
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, Optional, Tuple
import torch
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.scalar_type import ScalarType
@dataclass
class MPLinearLayerConfig:
    """Static description of a mixed-precision linear layer.

    Kernel classes receive this in ``can_implement`` to decide whether they
    support the layer, and keep it as ``self.config`` once constructed.
    """
    # Shape of the whole weight, ordered [in, out].
    full_weight_shape: Tuple[int, int]  # [in, out]
    # Shape of the slice held by this partition, same [in, out] ordering
    # (callers pass the per-partition input/output sizes here).
    partition_weight_shape: Tuple[int, int]
    # Scalar type of the quantized weights.
    weight_type: ScalarType
    # dtype of the activations fed to the layer.
    act_type: torch.dtype
    # Quantization group size, forwarded as-is from the quantization config.
    # NOTE(review): the GPTQ caller special-cases -1 elsewhere, so -1 may
    # reach kernels un-normalized -- confirm each kernel handles it.
    group_size: int
    # True when the weights carry zero points.
    zero_points: bool
    # True when a group-index tensor (activation reordering) is present.
    has_g_idx: bool
class MPLinearKernel(ABC):
    """Abstract interface for a mixed-precision linear kernel backend.

    A concrete kernel declares the minimum compute capability it needs and
    which layer configs it can implement, repacks a layer's weights once they
    are loaded, and runs the matmul. The weights live on the layer module
    under the parameter names handed to the constructor, so the same kernel
    can serve checkpoint formats that name their tensors differently.
    """

    @classmethod
    @abstractmethod
    def get_min_capability(cls) -> int:
        """Return the minimum compute capability this kernel requires."""
        raise NotImplementedError

    @classmethod
    @abstractmethod
    def can_implement(cls,
                      c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
        """Report whether this kernel supports the layer described by ``c``.

        Returns:
            ``(True, None)`` when supported, otherwise ``(False, reason)``
            with a human-readable explanation.
        """
        raise NotImplementedError

    def __init__(self,
                 c: MPLinearLayerConfig,
                 w_q_param_name: str,
                 w_s_param_name: str,
                 w_zp_param_name: Optional[str] = None,
                 w_gidx_param_name: Optional[str] = None) -> None:
        """Bind the kernel to a layer config and to the layer's parameter names.

        Args:
            c: Description of the layer this kernel instance will serve.
            w_q_param_name: Attribute name of the quantized weight.
            w_s_param_name: Attribute name of the weight scales.
            w_zp_param_name: Attribute name of the zero points, if any.
            w_gidx_param_name: Attribute name of the group index, if any.

        Raises:
            ValueError: If this kernel cannot implement ``c``.
        """
        # `can_implement` returns a (flag, reason) tuple. A non-empty tuple is
        # always truthy, so the flag must be unpacked before it is checked.
        supported, reason = self.can_implement(c)
        if not supported:
            raise ValueError(
                f"{type(self).__name__} cannot implement the given layer "
                f"config: {reason}")

        self.config = c
        self.w_q_name = w_q_param_name
        self.w_s_name = w_s_param_name
        self.w_zp_name = w_zp_param_name
        self.w_gidx_name = w_gidx_param_name

    @abstractmethod
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Convert the loaded weights on ``layer`` to this kernel's format."""
        raise NotImplementedError

    @abstractmethod
    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the linear layer on ``x`` using the weights held by ``layer``."""
        raise NotImplementedError

    def _transform_param(self, layer: torch.nn.Module, name: Optional[str],
                         fn: Callable) -> None:
        """Replace ``layer.<name>`` with ``fn(layer.<name>)``.

        Does nothing when ``name`` is None or the layer has no such (non-None)
        attribute, so optional tensors can be passed unconditionally.
        """
        if name is None:
            return
        old_param = getattr(layer, name, None)
        if old_param is None:
            return

        new_param = fn(old_param)
        # replace the parameter with torch.nn.Parameter for TorchDynamo
        # compatibility
        replace_parameter(
            layer, name,
            torch.nn.Parameter(new_param.data, requires_grad=False))

    def _get_weight_params(
        self, layer: torch.nn.Module
    ) -> Tuple[torch.Tensor,  # w_q
               torch.Tensor,  # w_s
               Optional[torch.Tensor],  # w_zp,
               Optional[torch.Tensor]  # w_gidx
               ]:
        """Fetch ``(w_q, w_s, w_zp, w_gidx)`` from ``layer``.

        The optional entries are None when no name was configured for them
        or the layer lacks that attribute.
        """
        return (
            getattr(layer, self.w_q_name),
            getattr(layer, self.w_s_name),
            # An unset name falls back to "", which never resolves to an
            # attribute, so the lookup yields the None default.
            getattr(layer, self.w_zp_name or "", None),
            getattr(layer, self.w_gidx_name or "", None),
        )
import os
from typing import List, Optional, Type
from vllm.model_executor.layers.quantization.kernels.machete import (
MacheteLinearKernel)
from vllm.model_executor.layers.quantization.kernels.marlin import (
MarlinLinearKernel)
from vllm.model_executor.layers.quantization.kernels.MPLinearKernel import (
MPLinearKernel, MPLinearLayerConfig)
from vllm.platforms import current_platform
# Candidate kernels in priority/performance order (when available).
# `choose_mp_linear_kernel` walks this list front to back and returns the
# first kernel that passes its checks, so earlier entries are preferred.
_POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [
    MacheteLinearKernel,
    MarlinLinearKernel,
]
def choose_mp_linear_kernel(
        config: MPLinearLayerConfig,
        compute_capability: Optional[int] = None) -> Type[MPLinearKernel]:
    """
    Choose an MPLinearKernel that can implement the given config for the given
    compute capability. Attempts to choose the best kernel in terms of
    performance.

    Kernels whose class name appears in the comma-separated
    `VLLM_DISABLED_KERNELS` environment variable are skipped.

    Args:
        config (MPLinearLayerConfig): Description of the linear layer to be
          implemented.
        compute_capability (Optional[int], optional): The compute capability of
          the target device, if None uses `current_platform` to get the compute
          capability. Defaults to None.

    Raises:
        ValueError: If the compute capability cannot be determined, or if no
          kernel can implement the given config.

    Returns:
        Type[MPLinearKernel]: Chosen kernel.
    """
    if compute_capability is None:
        if current_platform is None:
            raise ValueError("Cannot determine compute capability")
        _cc = current_platform.get_device_capability()
        # NOTE(review): guards against a platform reporting no capability;
        # confirm whether `get_device_capability` can actually return None.
        if _cc is None:
            raise ValueError("Cannot determine compute capability")
        # Encode (major, minor) as a single integer, e.g. (9, 0) -> 90.
        compute_capability = _cc[0] * 10 + _cc[1]

    # Parse the disable list once; surrounding whitespace is ignored so that
    # "A, B" is treated the same as "A,B".
    disabled_kernels = {
        name.strip()
        for name in os.environ.get("VLLM_DISABLED_KERNELS", "").split(",")
    }

    failure_reasons = []
    for kernel in _POSSIBLE_KERNELS:
        kernel_name = kernel.__name__

        if kernel_name in disabled_kernels:
            failure_reasons.append(
                f' {kernel_name} disabled by environment variable')
            continue

        min_capability = kernel.get_min_capability()
        if min_capability > compute_capability:
            failure_reasons.append(
                f" {kernel_name} requires capability "
                f"{min_capability}, current compute capability "
                f"is {compute_capability}")
            continue

        can_implement, failure_reason = kernel.can_implement(config)
        if can_implement:
            return kernel
        failure_reasons.append(
            f' {kernel_name} cannot implement due to: {failure_reason}')

    raise ValueError(
        "Failed to find a kernel that can implement the "
        "WNA16 linear layer. Reasons: \n" + '\n'.join(failure_reasons))
from functools import partial
from typing import Optional, Tuple
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.machete_utils import (
MACHETE_SUPPORTED_GROUP_SIZES, check_machete_supports_shape,
query_machete_supported_quant_types)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
pack_weights_into_int32, unpack_weights_into_int32)
from vllm.model_executor.parameter import (BasevLLMParameter,
permute_param_layout_)
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
class MacheteLinearKernel(MPLinearKernel):
    """Mixed-precision linear kernel backed by the Machete ops.

    Weights are prepacked with `ops.machete_prepack_B` after loading and the
    matmul is executed with `ops.machete_gemm`. Activation reordering is
    handled by permuting the weight rows once at load time and permuting the
    activation columns on every forward pass.
    """

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum compute capability required (90)."""
        return 90

    @classmethod
    def can_implement(cls,
                      c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
        """Check whether Machete supports the layer described by ``c``.

        Returns:
            ``(True, ...)`` / ``(False, reason)``; the final shape check is
            delegated to `check_machete_supports_shape`.
        """
        # Act reordering is only handled when this partition holds the whole
        # input dimension.
        if c.has_g_idx and\
                c.partition_weight_shape[0] != c.full_weight_shape[0]:
            return False, "Act reordering currently not supported by Machete, "\
                          "when the input features are partitioned across "\
                          "devices"

        if c.zero_points:
            return False, "Zero points currently not supported by "\
                          "Compressed Tensors + Machete. (Kernel supports it"\
                          " but CompressedTensorsWNA16 does not so support has"\
                          " not been added to MacheteLinearKernel yet)"

        supported_types = query_machete_supported_quant_types(c.zero_points)
        if c.weight_type not in supported_types:
            return False, f"Quant type ({c.weight_type}) not supported by "\
                          "Machete, supported types are: "\
                          f"{supported_types}"

        if c.group_size not in MACHETE_SUPPORTED_GROUP_SIZES:
            return False, f"Group size ({c.group_size}) not supported by "\
                          "Machete, supported group sizes are: "\
                          f"{MACHETE_SUPPORTED_GROUP_SIZES}"

        return check_machete_supports_shape(c.partition_weight_shape[0],
                                            c.partition_weight_shape[1])

    # note assumes that
    #  `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
    #  `weight_scale`  is: {input_dim = 0, output_dim = 1}
    def process_weights_after_loading(self, layer: torch.nn.Module):
        """Repack the layer's quantized weight and scales for Machete.

        When the config has a group index, this also derives the permutation
        that sorts it and stores ``self.act_perm``, the callable used by
        `apply_weights` to reorder activation columns.
        """
        c = self.config

        if c.has_g_idx:
            assert self.w_gidx_name is not None
            perm = torch.argsort(getattr(layer, self.w_gidx_name))\
                .to(torch.int)

            # Generic fallback: plain column indexing.
            self.act_perm = lambda x: x[:, perm]
            # use `ops.permute_cols` if possible
            if c.act_type in [torch.float16, torch.bfloat16] \
                    and c.partition_weight_shape[0] % 8 == 0:
                self.act_perm = partial(ops.permute_cols, perm=perm)

        def transform_w_q(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
            if c.has_g_idx:
                # Apply the same permutation to the weight rows: unpack,
                # reorder along the input dimension, then pack again.
                x_unpacked = unpack_weights_into_int32(x.data,
                                                       c.weight_type,
                                                       packed_dim=0)
                x_perm = x_unpacked[perm, :]
                x.data = pack_weights_into_int32(x_perm,
                                                 c.weight_type,
                                                 packed_dim=0)
            # `.t().contiguous().t()` keeps the logical shape while changing
            # the memory layout (presumably to the column-major form the
            # prepack op expects -- confirm against the op).
            x.data = ops.machete_prepack_B(x.data.t().contiguous().t(),
                                           self.config.weight_type)
            return x

        def transform_w_s(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1)
            x.data = x.data.contiguous()
            return x

        # Repack weights and scales for Machete
        self._transform_param(layer, self.w_q_name, transform_w_q)
        self._transform_param(layer, self.w_s_name, transform_w_s)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute the linear layer output for ``x``.

        Args:
            layer: Module holding the processed weight and scales.
            x: Activations; all leading dimensions are flattened for the GEMM.
            bias: Optional bias, added in place to the GEMM output.

        Returns:
            Tensor with the leading dimensions of ``x`` and a last dimension
            equal to the partition's output size.
        """
        c = self.config
        w_q, w_s, _, _ = self._get_weight_params(layer)

        x_2d = x.reshape(-1, x.shape[-1])
        out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )

        if c.has_g_idx:
            # Match the row permutation applied to the weights at load time.
            x_2d = self.act_perm(x_2d)

        output = ops.machete_gemm(a=x_2d,
                                  b_q=w_q,
                                  b_type=c.weight_type,
                                  b_zeros=None,
                                  b_scales=w_s,
                                  b_group_size=c.group_size)

        if bias is not None:
            output.add_(bias)  # In-place add

        return output.reshape(out_shape)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment