Commit 4561f139 (unverified), authored by Michael Goin, committed by GitHub
Browse files

[Refactor] Rename `gptq_marlin` to `marlin` to match MoE (#32952)


Signed-off-by: mgoin <mgoin64@gmail.com>
parent 6cc6d92b
...@@ -377,7 +377,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") ...@@ -377,7 +377,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# preselected input type pairs and schedules. # preselected input type pairs and schedules.
# Generate sources: # Generate sources:
set(MARLIN_GEN_SCRIPT set(MARLIN_GEN_SCRIPT
${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/gptq_marlin/generate_kernels.py) ${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/marlin/generate_kernels.py)
file(MD5 ${MARLIN_GEN_SCRIPT} MARLIN_GEN_SCRIPT_HASH) file(MD5 ${MARLIN_GEN_SCRIPT} MARLIN_GEN_SCRIPT_HASH)
list(JOIN CUDA_ARCHS "," CUDA_ARCHS_STR) list(JOIN CUDA_ARCHS "," CUDA_ARCHS_STR)
set(MARLIN_GEN_SCRIPT_HASH_AND_ARCH "${MARLIN_GEN_SCRIPT_HASH}(ARCH:${CUDA_ARCHS_STR})") set(MARLIN_GEN_SCRIPT_HASH_AND_ARCH "${MARLIN_GEN_SCRIPT_HASH}(ARCH:${CUDA_ARCHS_STR})")
...@@ -412,7 +412,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") ...@@ -412,7 +412,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif() endif()
if (MARLIN_ARCHS) if (MARLIN_ARCHS)
file(GLOB MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_float16.cu") file(GLOB MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/marlin/sm80_kernel_*_float16.cu")
set_gencode_flags_for_srcs( set_gencode_flags_for_srcs(
SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}" SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}"
CUDA_ARCHS "${MARLIN_ARCHS}") CUDA_ARCHS "${MARLIN_ARCHS}")
...@@ -422,7 +422,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") ...@@ -422,7 +422,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif() endif()
list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC}) list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC})
file(GLOB MARLIN_TEMPLATE_BF16_KERNEL_SRC "csrc/quantization/gptq_marlin/sm80_kernel_*_bfloat16.cu") file(GLOB MARLIN_TEMPLATE_BF16_KERNEL_SRC "csrc/quantization/marlin/sm80_kernel_*_bfloat16.cu")
set_gencode_flags_for_srcs( set_gencode_flags_for_srcs(
SRCS "${MARLIN_TEMPLATE_BF16_KERNEL_SRC}" SRCS "${MARLIN_TEMPLATE_BF16_KERNEL_SRC}"
CUDA_ARCHS "${MARLIN_BF16_ARCHS}") CUDA_ARCHS "${MARLIN_BF16_ARCHS}")
...@@ -434,7 +434,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") ...@@ -434,7 +434,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif() endif()
if (MARLIN_SM75_ARCHS) if (MARLIN_SM75_ARCHS)
file(GLOB MARLIN_TEMPLATE_SM75_KERNEL_SRC "csrc/quantization/gptq_marlin/sm75_kernel_*.cu") file(GLOB MARLIN_TEMPLATE_SM75_KERNEL_SRC "csrc/quantization/marlin/sm75_kernel_*.cu")
set_gencode_flags_for_srcs( set_gencode_flags_for_srcs(
SRCS "${MARLIN_TEMPLATE_SM75_KERNEL_SRC}" SRCS "${MARLIN_TEMPLATE_SM75_KERNEL_SRC}"
CUDA_ARCHS "${MARLIN_SM75_ARCHS}") CUDA_ARCHS "${MARLIN_SM75_ARCHS}")
...@@ -446,7 +446,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") ...@@ -446,7 +446,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif() endif()
if (MARLIN_FP8_ARCHS) if (MARLIN_FP8_ARCHS)
file(GLOB MARLIN_TEMPLATE_FP8_KERNEL_SRC "csrc/quantization/gptq_marlin/sm89_kernel_*.cu") file(GLOB MARLIN_TEMPLATE_FP8_KERNEL_SRC "csrc/quantization/marlin/sm89_kernel_*.cu")
set_gencode_flags_for_srcs( set_gencode_flags_for_srcs(
SRCS "${MARLIN_TEMPLATE_FP8_KERNEL_SRC}" SRCS "${MARLIN_TEMPLATE_FP8_KERNEL_SRC}"
CUDA_ARCHS "${MARLIN_FP8_ARCHS}") CUDA_ARCHS "${MARLIN_FP8_ARCHS}")
...@@ -459,10 +459,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") ...@@ -459,10 +459,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
set(MARLIN_SRCS set(MARLIN_SRCS
"csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu" "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
"csrc/quantization/gptq_marlin/gptq_marlin.cu" "csrc/quantization/marlin/marlin.cu"
"csrc/quantization/gptq_marlin/marlin_int4_fp8_preprocess.cu" "csrc/quantization/marlin/marlin_int4_fp8_preprocess.cu"
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu" "csrc/quantization/marlin/gptq_marlin_repack.cu"
"csrc/quantization/gptq_marlin/awq_marlin_repack.cu") "csrc/quantization/marlin/awq_marlin_repack.cu")
set_gencode_flags_for_srcs( set_gencode_flags_for_srcs(
SRCS "${MARLIN_SRCS}" SRCS "${MARLIN_SRCS}"
CUDA_ARCHS "${MARLIN_OTHER_ARCHS}") CUDA_ARCHS "${MARLIN_OTHER_ARCHS}")
......
...@@ -231,7 +231,7 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable: ...@@ -231,7 +231,7 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
assert bt.w_tok_s is None assert bt.w_tok_s is None
assert bt.group_size is not None assert bt.group_size is not None
fn = lambda: ops.gptq_marlin_gemm( fn = lambda: ops.marlin_gemm(
a=bt.a, a=bt.a,
c=None, c=None,
b_q_weight=w_q, b_q_weight=w_q,
......
...@@ -239,7 +239,7 @@ def bench_run( ...@@ -239,7 +239,7 @@ def bench_run(
"sm_version": sm_version, "sm_version": sm_version,
"CUBLAS_M_THRESHOLD": CUBLAS_M_THRESHOLD, "CUBLAS_M_THRESHOLD": CUBLAS_M_THRESHOLD,
# Kernels # Kernels
"gptq_marlin_gemm": ops.gptq_marlin_gemm, "marlin_gemm": ops.marlin_gemm,
"gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm, "gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm,
"gptq_marlin_repack": ops.gptq_marlin_repack, "gptq_marlin_repack": ops.gptq_marlin_repack,
"allspark_w8a16_gemm": ops.allspark_w8a16_gemm, "allspark_w8a16_gemm": ops.allspark_w8a16_gemm,
...@@ -263,21 +263,21 @@ def bench_run( ...@@ -263,21 +263,21 @@ def bench_run(
results.append( results.append(
benchmark.Timer( benchmark.Timer(
stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, None, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501 stmt="output = marlin_gemm(a, None, marlin_q_w, marlin_s, None, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501
globals=globals, globals=globals,
label=label, label=label,
sub_label=sub_label, sub_label=sub_label,
description="gptq_marlin_gemm", description="marlin_gemm",
).blocked_autorange(min_run_time=min_run_time) ).blocked_autorange(min_run_time=min_run_time)
) )
results.append( results.append(
benchmark.Timer( benchmark.Timer(
stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, None, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501 stmt="output = marlin_gemm(a, None, marlin_q_w, marlin_s, None, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501
globals=globals, globals=globals,
label=label, label=label,
sub_label=sub_label, sub_label=sub_label,
description="gptq_marlin_gemm_fp32", description="marlin_gemm_fp32",
).blocked_autorange(min_run_time=min_run_time) ).blocked_autorange(min_run_time=min_run_time)
) )
......
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16 #define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#endif #endif
#include "quantization/gptq_marlin/marlin.cuh" #include "quantization/marlin/marlin.cuh"
#include "quantization/gptq_marlin/marlin_dtypes.cuh" #include "quantization/marlin/marlin_dtypes.cuh"
#include "core/scalar_type.hpp" #include "core/scalar_type.hpp"
#define MARLIN_KERNEL_PARAMS \ #define MARLIN_KERNEL_PARAMS \
......
...@@ -23,10 +23,10 @@ ...@@ -23,10 +23,10 @@
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16 #define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#endif #endif
#include "quantization/gptq_marlin/marlin.cuh" #include "quantization/marlin/marlin.cuh"
#include "quantization/gptq_marlin/marlin_dtypes.cuh" #include "quantization/marlin/marlin_dtypes.cuh"
#include "quantization/gptq_marlin/dequant.h" #include "quantization/marlin/dequant.h"
#include "quantization/gptq_marlin/marlin_mma.h" #include "quantization/marlin/marlin_mma.h"
#include "core/scalar_type.hpp" #include "core/scalar_type.hpp"
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <cuda_bf16.h> #include <cuda_bf16.h>
#include <iostream> #include <iostream>
#include "../gptq_marlin/marlin_dtypes.cuh" #include "../marlin/marlin_dtypes.cuh"
using marlin::MarlinScalarType2; using marlin::MarlinScalarType2;
namespace allspark { namespace allspark {
......
...@@ -46,7 +46,7 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, ...@@ -46,7 +46,7 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
} // namespace marlin } // namespace marlin
torch::Tensor gptq_marlin_gemm( torch::Tensor marlin_gemm(
torch::Tensor& a, std::optional<torch::Tensor> c_or_none, torch::Tensor& a, std::optional<torch::Tensor> c_or_none,
torch::Tensor& b_q_weight, torch::Tensor& b_q_weight,
std::optional<torch::Tensor> const& b_bias_or_none, torch::Tensor& b_scales, std::optional<torch::Tensor> const& b_bias_or_none, torch::Tensor& b_scales,
...@@ -528,7 +528,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, ...@@ -528,7 +528,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias,
} // namespace marlin } // namespace marlin
torch::Tensor gptq_marlin_gemm( torch::Tensor marlin_gemm(
torch::Tensor& a, std::optional<torch::Tensor> c_or_none, torch::Tensor& a, std::optional<torch::Tensor> c_or_none,
torch::Tensor& b_q_weight, torch::Tensor& b_q_weight,
std::optional<torch::Tensor> const& b_bias_or_none, torch::Tensor& b_scales, std::optional<torch::Tensor> const& b_bias_or_none, torch::Tensor& b_scales,
...@@ -856,5 +856,5 @@ torch::Tensor gptq_marlin_gemm( ...@@ -856,5 +856,5 @@ torch::Tensor gptq_marlin_gemm(
#endif #endif
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("gptq_marlin_gemm", &gptq_marlin_gemm); m.impl("marlin_gemm", &marlin_gemm);
} }
...@@ -303,9 +303,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -303,9 +303,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor"); ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor");
ops.impl("permute_cols", torch::kCUDA, &permute_cols); ops.impl("permute_cols", torch::kCUDA, &permute_cols);
// gptq_marlin Optimized Quantized GEMM for GPTQ. // Marlin Optimized Quantized GEMM (supports GPTQ, AWQ, FP8, NVFP4, MXFP4).
ops.def( ops.def(
"gptq_marlin_gemm(Tensor a, Tensor? c_or_none, Tensor b_q_weight, " "marlin_gemm(Tensor a, Tensor? c_or_none, Tensor b_q_weight, "
"Tensor? b_bias_or_none,Tensor b_scales, " "Tensor? b_bias_or_none,Tensor b_scales, "
"Tensor? a_scales, Tensor? global_scale, Tensor? b_zeros_or_none, " "Tensor? a_scales, Tensor? global_scale, Tensor? b_zeros_or_none, "
"Tensor? " "Tensor? "
......
...@@ -59,7 +59,7 @@ if current_platform.is_rocm(): ...@@ -59,7 +59,7 @@ if current_platform.is_rocm():
pytest.skip( pytest.skip(
"These tests require gptq_marlin_repack," "These tests require gptq_marlin_repack,"
"marlin_int4_fp8_preprocess, gptq_marlin_24_gemm," "marlin_int4_fp8_preprocess, gptq_marlin_24_gemm,"
"or gptq_marlin_gemm which are not supported on ROCm.", "or marlin_gemm which are not supported on ROCm.",
allow_module_level=True, allow_module_level=True,
) )
...@@ -417,7 +417,7 @@ def marlin_generate_valid_test_cases(): ...@@ -417,7 +417,7 @@ def marlin_generate_valid_test_cases():
), ),
marlin_generate_valid_test_cases(), marlin_generate_valid_test_cases(),
) )
def test_gptq_marlin_gemm( def test_marlin_gemm(
a_type, a_type,
b_type, b_type,
c_type, c_type,
...@@ -511,7 +511,7 @@ def test_gptq_marlin_gemm( ...@@ -511,7 +511,7 @@ def test_gptq_marlin_gemm(
output = torch.empty((size_m, size_n), dtype=dtype, device=a_input.device) output = torch.empty((size_m, size_n), dtype=dtype, device=a_input.device)
output = ops.gptq_marlin_gemm( output = ops.marlin_gemm(
a_input, a_input,
output, output,
marlin_q_w, marlin_q_w,
...@@ -646,7 +646,7 @@ def test_marlin_gemm_subset_input(): ...@@ -646,7 +646,7 @@ def test_marlin_gemm_subset_input():
marlin_zp = marlin_make_empty_g_idx(marlin_s.device) marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
workspace = marlin_make_workspace_new(a_input.device) workspace = marlin_make_workspace_new(a_input.device)
output = ops.gptq_marlin_gemm( output = ops.marlin_gemm(
a_input, a_input,
None, None,
marlin_q_w, marlin_q_w,
...@@ -695,7 +695,7 @@ def test_marlin_gemm_with_bias(size_m): ...@@ -695,7 +695,7 @@ def test_marlin_gemm_with_bias(size_m):
marlin_zp = marlin_make_empty_g_idx(marlin_s.device) marlin_zp = marlin_make_empty_g_idx(marlin_s.device)
workspace = marlin_make_workspace_new(a_input.device) workspace = marlin_make_workspace_new(a_input.device)
output = ops.gptq_marlin_gemm( output = ops.marlin_gemm(
a_input, a_input,
None, None,
marlin_q_w, marlin_q_w,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment