Unverified Commit aeb37c2a authored by Lucas Wilkinson's avatar Lucas Wilkinson Committed by GitHub
Browse files

[CI/Build] Per file CUDA Archs (improve wheel size and dev build times) (#8845)

parent 3dbb215b
......@@ -143,6 +143,19 @@ else()
message(FATAL_ERROR "Can't find CUDA or HIP installation.")
endif()
#
# For cuda we want to be able to control which architectures we compile for on
# a per-file basis in order to cut down on compile time. So here we extract
# the set of architectures we want to compile for and remove the from the
# CMAKE_CUDA_FLAGS so that they are not applied globally.
#
if(VLLM_GPU_LANG STREQUAL "CUDA")
clear_cuda_arches(CUDA_ARCH_FLAGS)
extract_unique_cuda_archs_ascending(CUDA_ARCHS "${CUDA_ARCH_FLAGS}")
message(STATUS "CUDA target architectures: ${CUDA_ARCHS}")
endif()
#
# Override the GPU architectures detected by cmake/torch and filter them by
# the supported versions for the current language.
......@@ -223,30 +236,89 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/mamba/causal_conv1d/causal_conv1d.cu"
"csrc/quantization/aqlm/gemm_kernels.cu"
"csrc/quantization/awq/gemm_kernels.cu"
"csrc/quantization/gguf/gguf_kernel.cu"
"csrc/custom_all_reduce.cu"
"csrc/permute_cols.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu")
set_gencode_flags_for_srcs(
SRCS "${VLLM_EXT_SRC}"
CUDA_ARCHS "${CUDA_ARCHS}")
# Only build Marlin kernels if we are building for at least some compatible archs.
# Keep building Marlin for 9.0 as there are some group sizes and shapes that
# are not supported by Machete yet.
cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.9;9.0" ${CUDA_ARCHS})
if (MARLIN_ARCHS)
set(MARLIN_SRCS
"csrc/quantization/fp8/fp8_marlin.cu"
"csrc/quantization/marlin/dense/marlin_cuda_kernel.cu"
"csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
"csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu"
"csrc/quantization/gptq_marlin/gptq_marlin.cu"
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
"csrc/quantization/gptq_marlin/awq_marlin_repack.cu"
"csrc/quantization/gguf/gguf_kernel.cu"
"csrc/quantization/fp8/fp8_marlin.cu"
"csrc/custom_all_reduce.cu"
"csrc/permute_cols.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu")
"csrc/quantization/gptq_marlin/awq_marlin_repack.cu")
set_gencode_flags_for_srcs(
SRCS "${MARLIN_SRCS}"
CUDA_ARCHS "${MARLIN_ARCHS}")
list(APPEND VLLM_EXT_SRC "${MARLIN_SRCS}")
message(STATUS "Building Marlin kernels for archs: ${MARLIN_ARCHS}")
else()
message(STATUS "Not building Marlin kernels as no compatible archs found"
"in CUDA target architectures")
endif()
#
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
# CUDA 12.0 or later (and only work on Hopper, 9.0/9.0a for now).
cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0;9.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_3X_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_C3X=1")
message(STATUS "Building scaled_mm_c3x for archs: ${SCALED_MM_3X_ARCHS}")
else()
# clear SCALED_MM_3X_ARCHS so the scaled_mm_c2x kernels know we didn't
# build any 3x kernels
set(SCALED_MM_3X_ARCHS)
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS)
message(STATUS "Not building scaled_mm_c3x as CUDA Compiler version is "
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
"later if you intend on running FP8 quantized models on "
"Hopper.")
else()
message(STATUS "Not building scaled_mm_c3x as no compatible archs found "
"in CUDA target architectures")
endif()
endif()
#
# The CUTLASS kernels for Hopper require sm90a to be enabled.
# This is done via the below gencode option, BUT that creates kernels for both sm90 and sm90a.
# That adds an extra 17MB to compiled binary, so instead we selectively enable it.
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0)
set_source_files_properties(
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
PROPERTIES
COMPILE_FLAGS
"-gencode arch=compute_90a,code=sm_90a")
# For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x)
# kernels for the remaining archs that are not already built for 3x.
cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS
"7.5;8.0;8.6;8.9;9.0;9.0a" "${CUDA_ARCHS}")
# subtract out the archs that are already built for 3x
list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS})
if (SCALED_MM_2X_ARCHS)
set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu")
set_gencode_flags_for_srcs(
SRCS "${SRCS}"
CUDA_ARCHS "${SCALED_MM_2X_ARCHS}")
list(APPEND VLLM_EXT_SRC "${SRCS}")
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_C2X=1")
message(STATUS "Building scaled_mm_c2x for archs: ${SCALED_MM_2X_ARCHS}")
else()
if (SCALED_MM_3X_ARCHS)
message(STATUS "Not building scaled_mm_c2x as all archs are already built"
" for and covered by scaled_mm_c3x")
else()
message(STATUS "Not building scaled_mm_c2x as no compatible archs found "
"in CUDA target architectures")
endif()
endif()
......@@ -254,15 +326,26 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# Machete kernels
# The machete kernels only work on hopper and require CUDA 12.0 or later.
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0)
# Only build Machete kernels if we are building for something compatible with sm90a
cuda_archs_loose_intersection(MACHETE_ARCHS "9.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND MACHETE_ARCHS)
#
# For the Machete kernels we automatically generate sources for various
# preselected input type pairs and schedules.
# Generate sources:
set(MACHETE_GEN_SCRIPT
${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/machete/generate.py)
file(MD5 ${MACHETE_GEN_SCRIPT} MACHETE_GEN_SCRIPT_HASH)
message(STATUS "Machete generation script hash: ${MACHETE_GEN_SCRIPT_HASH}")
message(STATUS "Last run machete generate script hash: $CACHE{MACHETE_GEN_SCRIPT_HASH}")
if (NOT DEFINED CACHE{MACHETE_GEN_SCRIPT_HASH}
OR NOT $CACHE{MACHETE_GEN_SCRIPT_HASH} STREQUAL ${MACHETE_GEN_SCRIPT_HASH})
execute_process(
COMMAND ${CMAKE_COMMAND} -E env
PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:${VLLM_PYTHON_PATH}:$PYTHONPATH
${Python_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/machete/generate.py
${Python_EXECUTABLE} ${MACHETE_GEN_SCRIPT}
RESULT_VARIABLE machete_generation_result
OUTPUT_VARIABLE machete_generation_output
OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log
......@@ -275,26 +358,40 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"\nCheck the log for details: "
"${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log")
else()
set(MACHETE_GEN_SCRIPT_HASH ${MACHETE_GEN_SCRIPT_HASH}
CACHE STRING "Last run machete generate script hash" FORCE)
message(STATUS "Machete generation completed successfully.")
endif()
else()
message(STATUS "Machete generation script has not changed, skipping generation.")
endif()
# Add machete generated sources
file(GLOB MACHETE_GEN_SOURCES "csrc/quantization/machete/generated/*.cu")
list(APPEND VLLM_EXT_SRC ${MACHETE_GEN_SOURCES})
message(STATUS "Machete generated sources: ${MACHETE_GEN_SOURCES}")
set_source_files_properties(
${MACHETE_GEN_SOURCES}
PROPERTIES
COMPILE_FLAGS
"-gencode arch=compute_90a,code=sm_90a")
endif()
# forward compatible
set_gencode_flags_for_srcs(
SRCS "${MACHETE_GEN_SOURCES}"
CUDA_ARCHS "${MACHETE_ARCHS}")
# Add pytorch binding for machete (add on even CUDA < 12.0 so that we can
# raise an error if the user that this was built with an incompatible
# CUDA version)
list(APPEND VLLM_EXT_SRC
csrc/quantization/machete/machete_pytorch.cu)
message(STATUS "Building Machete kernels for archs: ${MACHETE_ARCHS}")
else()
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0
AND MACHETE_ARCHS)
message(STATUS "Not building Machete kernels as CUDA Compiler version is "
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
"later if you intend on running w4a16 quantized models on "
"Hopper.")
else()
message(STATUS "Not building Machete kernels as no compatible archs "
"found in CUDA target architectures")
endif()
endif()
# if CUDA endif
endif()
message(STATUS "Enabling C extension.")
......@@ -323,14 +420,31 @@ set(VLLM_MOE_EXT_SRC
"csrc/moe/torch_bindings.cpp"
"csrc/moe/topk_softmax_kernels.cu")
set_gencode_flags_for_srcs(
SRCS "${VLLM_MOE_EXT_SRC}"
CUDA_ARCHS "${CUDA_ARCHS}")
if(VLLM_GPU_LANG STREQUAL "CUDA")
list(APPEND VLLM_MOE_EXT_SRC
cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.9;9.0" "${CUDA_ARCHS}")
if (MARLIN_MOE_ARCHS)
set(MARLIN_MOE_SRC
"csrc/moe/marlin_kernels/marlin_moe_kernel.h"
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h"
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu"
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h"
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu"
"csrc/moe/marlin_moe_ops.cu")
set_gencode_flags_for_srcs(
SRCS "${MARLIN_MOE_SRC}"
CUDA_ARCHS "${MARLIN_MOE_ARCHS}")
list(APPEND VLLM_MOE_EXT_SRC "${MARLIN_MOE_SRC}")
message(STATUS "Building Marlin MOE kernels for archs: ${MARLIN_MOE_ARCHS}")
else()
message(STATUS "Not building Marlin MOE kernels as no compatible archs found"
"in CUDA target architectures")
endif()
endif()
message(STATUS "Enabling moe extension.")
......
......@@ -133,10 +133,181 @@ macro(string_to_ver OUT_VER IN_STR)
string(REGEX REPLACE "\([0-9]+\)\([0-9]\)" "\\1.\\2" ${OUT_VER} ${IN_STR})
endmacro()
#
# Clear all `-gencode` flags from `CMAKE_CUDA_FLAGS` and store them in
# `CUDA_ARCH_FLAGS`.
#
# Example:
# CMAKE_CUDA_FLAGS="-Wall -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75"
# clear_cuda_arches(CUDA_ARCH_FLAGS)
# CUDA_ARCH_FLAGS="-gencode arch=compute_70,code=sm_70;-gencode arch=compute_75,code=sm_75"
# CMAKE_CUDA_FLAGS="-Wall"
#
macro(clear_cuda_arches CUDA_ARCH_FLAGS)
# Extract all `-gencode` flags from `CMAKE_CUDA_FLAGS`
string(REGEX MATCHALL "-gencode arch=[^ ]+" CUDA_ARCH_FLAGS
${CMAKE_CUDA_FLAGS})
# Remove all `-gencode` flags from `CMAKE_CUDA_FLAGS` since they will be modified
# and passed back via the `CUDA_ARCHITECTURES` property.
string(REGEX REPLACE "-gencode arch=[^ ]+ *" "" CMAKE_CUDA_FLAGS
${CMAKE_CUDA_FLAGS})
endmacro()
#
# Extract unique CUDA architectures from a list of compute capabilities codes in
# the form `<major><minor>[<letter>]`, convert them to the form sort
# `<major>.<minor>`, dedupes them and then sorts them in ascending order and
# stores them in `OUT_ARCHES`.
#
# Example:
# CUDA_ARCH_FLAGS="-gencode arch=compute_75,code=sm_75;...;-gencode arch=compute_90a,code=sm_90a"
# extract_unique_cuda_archs_ascending(OUT_ARCHES CUDA_ARCH_FLAGS)
# OUT_ARCHES="7.5;...;9.0"
function(extract_unique_cuda_archs_ascending OUT_ARCHES CUDA_ARCH_FLAGS)
set(_CUDA_ARCHES)
foreach(_ARCH ${CUDA_ARCH_FLAGS})
string(REGEX MATCH "arch=compute_\([0-9]+a?\)" _COMPUTE ${_ARCH})
if (_COMPUTE)
set(_COMPUTE ${CMAKE_MATCH_1})
endif()
string_to_ver(_COMPUTE_VER ${_COMPUTE})
list(APPEND _CUDA_ARCHES ${_COMPUTE_VER})
endforeach()
list(REMOVE_DUPLICATES _CUDA_ARCHES)
list(SORT _CUDA_ARCHES COMPARE NATURAL ORDER ASCENDING)
set(${OUT_ARCHES} ${_CUDA_ARCHES} PARENT_SCOPE)
endfunction()
#
# For a specific file set the `-gencode` flag in compile options conditionally
# for the CUDA language.
#
# Example:
# set_gencode_flag_for_srcs(
# SRCS "foo.cu"
# ARCH "compute_75"
# CODE "sm_75")
# adds: "-gencode arch=compute_75,code=sm_75" to the compile options for
# `foo.cu` (only for the CUDA language).
#
macro(set_gencode_flag_for_srcs)
set(options)
set(oneValueArgs ARCH CODE)
set(multiValueArgs SRCS)
cmake_parse_arguments(arg "${options}" "${oneValueArgs}"
"${multiValueArgs}" ${ARGN} )
set(_FLAG -gencode arch=${arg_ARCH},code=${arg_CODE})
set_property(
SOURCE ${arg_SRCS}
APPEND PROPERTY
COMPILE_OPTIONS "$<$<COMPILE_LANGUAGE:CUDA>:${_FLAG}>"
)
message(DEBUG "Setting gencode flag for ${arg_SRCS}: ${_FLAG}")
endmacro(set_gencode_flag_for_srcs)
#
# For a list of source files set the `-gencode` flags in the files specific
# compile options (specifically for the CUDA language).
#
# arguments are:
# SRCS: list of source files
# CUDA_ARCHS: list of CUDA architectures in the form `<major>.<minor>[letter]`
# BUILD_PTX_FOR_ARCH: if set to true, then the PTX code will be built
# for architecture `BUILD_PTX_FOR_ARCH` if there is a CUDA_ARCH in CUDA_ARCHS
# that is larger than BUILD_PTX_FOR_ARCH.
#
macro(set_gencode_flags_for_srcs)
set(options)
set(oneValueArgs BUILD_PTX_FOR_ARCH)
set(multiValueArgs SRCS CUDA_ARCHS)
cmake_parse_arguments(arg "${options}" "${oneValueArgs}"
"${multiValueArgs}" ${ARGN} )
foreach(_ARCH ${arg_CUDA_ARCHS})
string(REPLACE "." "" _ARCH "${_ARCH}")
set_gencode_flag_for_srcs(
SRCS ${arg_SRCS}
ARCH "compute_${_ARCH}"
CODE "sm_${_ARCH}")
endforeach()
if (${arg_BUILD_PTX_FOR_ARCH})
list(SORT arg_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
list(GET arg_CUDA_ARCHS -1 _HIGHEST_ARCH)
if (_HIGHEST_ARCH VERSION_GREATER_EQUAL ${arg_BUILD_PTX_FOR_ARCH})
string(REPLACE "." "" _PTX_ARCH "${arg_BUILD_PTX_FOR_ARCH}")
set_gencode_flag_for_srcs(
SRCS ${arg_SRCS}
ARCH "compute_${_PTX_ARCH}"
CODE "compute_${_PTX_ARCH}")
endif()
endif()
endmacro()
#
# For the given `SRC_CUDA_ARCHS` list of gencode versions in the form
# `<major>.<minor>[letter]` compute the "loose intersection" with the
# `TGT_CUDA_ARCHS` list of gencodes.
# The loose intersection is defined as:
# { max{ x \in tgt | x <= y } | y \in src, { x \in tgt | x <= y } != {} }
# where `<=` is the version comparison operator.
# In other words, for each version in `TGT_CUDA_ARCHS` find the highest version
# in `SRC_CUDA_ARCHS` that is less or equal to the version in `TGT_CUDA_ARCHS`.
# We have special handling for 9.0a, if 9.0a is in `SRC_CUDA_ARCHS` and 9.0 is
# in `TGT_CUDA_ARCHS` then we should remove 9.0a from `SRC_CUDA_ARCHS` and add
# 9.0a to the result.
# The result is stored in `OUT_CUDA_ARCHS`.
#
# Example:
# SRC_CUDA_ARCHS="7.5;8.0;8.6;9.0;9.0a"
# TGT_CUDA_ARCHS="8.0;8.9;9.0"
# cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
# OUT_CUDA_ARCHS="8.0;8.6;9.0;9.0a"
#
function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
list(REMOVE_DUPLICATES SRC_CUDA_ARCHS)
# if 9.0a is in SRC_CUDA_ARCHS and 9.0 is in CUDA_ARCHS then we should
# remove 9.0a from SRC_CUDA_ARCHS and add 9.0a to _CUDA_ARCHS
set(_CUDA_ARCHS)
if ("9.0a" IN_LIST SRC_CUDA_ARCHS)
list(REMOVE_ITEM SRC_CUDA_ARCHS "9.0a")
if ("9.0" IN_LIST TGT_CUDA_ARCHS)
set(_CUDA_ARCHS "9.0a")
endif()
endif()
list(SORT SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
# for each ARCH in CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that is
# less or eqault to ARCH
foreach(_ARCH ${CUDA_ARCHS})
set(_TMP_ARCH)
foreach(_SRC_ARCH ${SRC_CUDA_ARCHS})
if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH)
set(_TMP_ARCH ${_SRC_ARCH})
else()
break()
endif()
endforeach()
if (_TMP_ARCH)
list(APPEND _CUDA_ARCHS ${_TMP_ARCH})
endif()
endforeach()
list(REMOVE_DUPLICATES _CUDA_ARCHS)
set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE)
endfunction()
#
# Override the GPU architectures detected by cmake/torch and filter them by
# `GPU_SUPPORTED_ARCHES`. Sets the final set of architectures in
# `GPU_ARCHES`.
# `GPU_ARCHES`. This only applies to the HIP language since for CUDA we set
# the architectures on a per file basis.
#
# Note: this is defined as a macro since it updates `CMAKE_CUDA_FLAGS`.
#
......@@ -174,109 +345,7 @@ macro(override_gpu_arches GPU_ARCHES GPU_LANG GPU_SUPPORTED_ARCHES)
"None of the detected ROCm architectures: ${HIP_ARCHITECTURES} is"
" supported. Supported ROCm architectures are: ${_GPU_SUPPORTED_ARCHES_LIST}.")
endif()
elseif(${GPU_LANG} STREQUAL "CUDA")
#
# Setup/process CUDA arch flags.
#
# The torch cmake setup hardcodes the detected architecture flags in
# `CMAKE_CUDA_FLAGS`. Since `CMAKE_CUDA_FLAGS` is a "global" variable, it
# can't modified on a per-target basis.
# So, all the `-gencode` flags need to be extracted and removed from
# `CMAKE_CUDA_FLAGS` for processing so they can be passed by another method.
# Since it's not possible to use `target_compiler_options` for adding target
# specific `-gencode` arguments, the target's `CUDA_ARCHITECTURES` property
# must be used instead. This requires repackaging the architecture flags
# into a format that cmake expects for `CUDA_ARCHITECTURES`.
#
# This is a bit fragile in that it depends on torch using `-gencode` as opposed
# to one of the other nvcc options to specify architectures.
#
# Note: torch uses the `TORCH_CUDA_ARCH_LIST` environment variable to override
# detected architectures.
#
message(DEBUG "initial CMAKE_CUDA_FLAGS: ${CMAKE_CUDA_FLAGS}")
# Extract all `-gencode` flags from `CMAKE_CUDA_FLAGS`
string(REGEX MATCHALL "-gencode arch=[^ ]+" _CUDA_ARCH_FLAGS
${CMAKE_CUDA_FLAGS})
# Remove all `-gencode` flags from `CMAKE_CUDA_FLAGS` since they will be modified
# and passed back via the `CUDA_ARCHITECTURES` property.
string(REGEX REPLACE "-gencode arch=[^ ]+ *" "" CMAKE_CUDA_FLAGS
${CMAKE_CUDA_FLAGS})
# If this error is triggered, it might mean that torch has changed how it sets
# up nvcc architecture code generation flags.
if (NOT _CUDA_ARCH_FLAGS)
message(FATAL_ERROR
"Could not find any architecture related code generation flags in "
"CMAKE_CUDA_FLAGS. (${CMAKE_CUDA_FLAGS})")
endif()
message(DEBUG "final CMAKE_CUDA_FLAGS: ${CMAKE_CUDA_FLAGS}")
message(DEBUG "arch flags: ${_CUDA_ARCH_FLAGS}")
# Initialize the architecture lists to empty.
set(${GPU_ARCHES})
# Process each `gencode` flag.
foreach(_ARCH ${_CUDA_ARCH_FLAGS})
# For each flag, extract the version number and whether it refers to PTX
# or native code.
# Note: if a regex matches then `CMAKE_MATCH_1` holds the binding
# for that match.
string(REGEX MATCH "arch=compute_\([0-9]+a?\)" _COMPUTE ${_ARCH})
if (_COMPUTE)
set(_COMPUTE ${CMAKE_MATCH_1})
endif()
string(REGEX MATCH "code=sm_\([0-9]+a?\)" _SM ${_ARCH})
if (_SM)
set(_SM ${CMAKE_MATCH_1})
endif()
string(REGEX MATCH "code=compute_\([0-9]+a?\)" _CODE ${_ARCH})
if (_CODE)
set(_CODE ${CMAKE_MATCH_1})
endif()
# Make sure the virtual architecture can be matched.
if (NOT _COMPUTE)
message(FATAL_ERROR
"Could not determine virtual architecture from: ${_ARCH}.")
endif()
# One of sm_ or compute_ must exist.
if ((NOT _SM) AND (NOT _CODE))
message(FATAL_ERROR
"Could not determine a codegen architecture from: ${_ARCH}.")
endif()
if (_SM)
# -real suffix let CMake to only generate elf code for the kernels.
# we want this, otherwise the added ptx (default) will increase binary size.
set(_VIRT "-real")
set(_CODE_ARCH ${_SM})
else()
# -virtual suffix let CMake to generate ptx code for the kernels.
set(_VIRT "-virtual")
set(_CODE_ARCH ${_CODE})
endif()
# Check if the current version is in the supported arch list.
string_to_ver(_CODE_VER ${_CODE_ARCH})
if (NOT _CODE_VER IN_LIST _GPU_SUPPORTED_ARCHES_LIST)
message(STATUS "discarding unsupported CUDA arch ${_VER}.")
continue()
endif()
# Add it to the arch list.
list(APPEND ${GPU_ARCHES} "${_CODE_ARCH}${_VIRT}")
endforeach()
endif()
message(STATUS "${GPU_LANG} target arches: ${${GPU_ARCHES}}")
endmacro()
#
......
......@@ -12,6 +12,11 @@
// could be a macro instead of a literal token.
#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
// A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME
// could be a macro instead of a literal token.
#define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \
TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE)
// REGISTER_EXTENSION allows the shared library to be loaded and initialized
// via python's import statement.
#define REGISTER_EXTENSION(NAME) \
......
......@@ -27,6 +27,7 @@
#include "core/exception.hpp"
#include "core/scalar_type.hpp"
#include "core/registration.h"
#include "marlin_kernels/marlin_moe_kernel_ku4b8.h"
#include "marlin_kernels/marlin_moe_kernel_ku8b128.h"
......@@ -552,3 +553,7 @@ torch::Tensor marlin_gemm_moe(
thread_n, sms, max_par, replicate_input, apply_weights);
return c;
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("marlin_gemm_moe", &marlin_gemm_moe);
}
#pragma once
#include <torch/all.h>
#include "core/scalar_type.hpp"
torch::Tensor marlin_gemm_moe(
const torch::Tensor& a, const torch::Tensor& b_q_weights,
const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights,
const torch::Tensor& topk_ids, const torch::Tensor& b_scales,
const torch::Tensor& g_idx, const torch::Tensor& perm,
torch::Tensor& workspace, vllm::ScalarTypeTorchPtr const& b_q_type,
int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full,
int64_t num_experts, int64_t topk, int64_t moe_block_size,
bool replicate_input, bool apply_weights);
#include "core/registration.h"
#include "moe_ops.h"
#include "marlin_moe_ops.h"
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
// Apply topk softmax to the gating outputs.
......@@ -18,7 +17,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
"int size_n, int size_k, bool is_k_full, int num_experts, int topk, "
"int moe_block_size, bool replicate_input, bool apply_weights)"
" -> Tensor");
m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe);
// conditionally compiled so impl registration is in source file
#endif
}
......
......@@ -90,63 +90,8 @@ torch::Tensor awq_dequantize(torch::Tensor _kernel,
torch::Tensor _zeros, int64_t split_k_iters,
int64_t thx, int64_t thy);
torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_scales, torch::Tensor& workspace,
int64_t size_m, int64_t size_n, int64_t size_k);
namespace machete {
std::vector<std::string> supported_schedules(
vllm::ScalarTypeTorchPtr const& btype);
torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
vllm::ScalarTypeTorchPtr const& btype,
c10::optional<torch::Tensor> const& scales,
c10::optional<torch::Tensor> const& zeros,
c10::optional<int64_t> group_size,
c10::optional<torch::Tensor> const& C,
c10::optional<double> alpha, c10::optional<double> beta,
c10::optional<std::string> schedule);
torch::Tensor prepack_B(torch::Tensor const& B,
vllm::ScalarTypeTorchPtr const& btype);
}; // namespace machete
torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm);
torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_meta,
torch::Tensor& b_scales,
torch::Tensor& workspace,
vllm::ScalarTypeTorchPtr const& b_q_type,
int64_t size_m, int64_t size_n,
int64_t size_k);
torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_scales, torch::Tensor& b_zeros,
torch::Tensor& g_idx, torch::Tensor& perm,
torch::Tensor& workspace,
vllm::ScalarTypeTorchPtr const& b_q_type,
int64_t size_m, int64_t size_n, int64_t size_k,
bool is_k_full, bool has_zp,
bool use_fp32_reduce);
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n,
int64_t num_bits);
torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight,
torch::Tensor& perm, c10::SymInt size_k,
c10::SymInt size_n, int64_t num_bits);
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
int64_t size_n, int64_t num_bits);
torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight,
c10::SymInt size_k, c10::SymInt size_n,
int64_t num_bits);
torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m,
int64_t n);
......@@ -156,11 +101,6 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X,
torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type,
int64_t row);
torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_scales, torch::Tensor& workspace,
int64_t num_bits, int64_t size_m, int64_t size_n,
int64_t size_k);
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
......@@ -175,14 +115,6 @@ void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& azp_adj,
c10::optional<torch::Tensor> const& azp,
c10::optional<torch::Tensor> const& bias);
torch::Tensor marlin_qqq_gemm(torch::Tensor const& a,
torch::Tensor const& b_q_weight,
torch::Tensor const& s_tok,
torch::Tensor const& s_ch,
torch::Tensor const& s_group,
torch::Tensor& workspace, int64_t size_m,
int64_t size_n, int64_t size_k);
#endif
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
......
......@@ -21,7 +21,7 @@ void cutlass_scaled_mm_sm89(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias);
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
......@@ -114,26 +114,39 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
int32_t version_num = get_sm_version_num();
if (version_num >= 90) {
// Hopper
// Guard against compilation issues for sm90 kernels
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
if (version_num >= 90) {
cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
#else
cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias);
return;
}
#endif
} else if (version_num == 89) {
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
if (version_num == 89) {
// Ada Lovelace
cutlass_scaled_mm_sm89(c, a, b, a_scales, b_scales, bias);
} else if (version_num >= 80) {
return;
}
if (version_num >= 80) {
// Ampere
cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias);
} else {
return;
}
// Turing
TORCH_CHECK(version_num >= 75);
cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias);
}
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_scaled_mm for a compute capability less than "
"CUDA device capability: ",
version_num);
}
void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
......@@ -174,25 +187,38 @@ void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
"currently bias dtype must match output dtype ", c.dtype());
at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
int32_t version_num = get_sm_version_num();
if (version_num >= 90) {
// Hopper
// Guard against compilation issues for sm90 kernels
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
if (version_num >= 90) {
cutlass_scaled_mm_azp_sm90(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
#else
cutlass_scaled_mm_azp_sm80(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
return;
}
#endif
} else if (version_num == 89) {
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
if (version_num == 89) {
// Ada Lovelace
cutlass_scaled_mm_azp_sm89(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
} else if (version_num >= 80) {
return;
}
if (version_num >= 80) {
// Ampere
cutlass_scaled_mm_azp_sm80(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
} else {
return;
}
// Turing
TORCH_CHECK(version_num >= 75);
cutlass_scaled_mm_azp_sm75(c, a, b, a_scales, b_scales, azp_adj, azp, bias);
}
return;
#endif
TORCH_CHECK_NOT_IMPLEMENTED(
false,
"No compiled cutlass_scaled_mm_azp for a compute capability less than "
"CUDA device capability: ",
version_num);
}
\ No newline at end of file
......@@ -22,6 +22,8 @@
#include "../gptq_marlin/marlin.cuh"
#include "../gptq_marlin/marlin_dtypes.cuh"
#include "core/registration.h"
using namespace marlin;
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
......@@ -1303,3 +1305,7 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
}
#endif
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("fp8_marlin_gemm", &fp8_marlin_gemm);
}
\ No newline at end of file
#include "marlin.cuh"
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
namespace marlin {
template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void awq_marlin_repack_kernel(
uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr,
int size_k, int size_n) {}
} // namespace marlin
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n,
int64_t num_bits) {
TORCH_CHECK_NOT_IMPLEMENTED(
false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0");
return torch::empty({1, 1});
}
#else
#include "core/registration.h"
namespace marlin {
......@@ -122,7 +103,7 @@ __global__ void awq_marlin_repack_kernel(
}
uint32_t vals[8];
#pragma unroll
#pragma unroll
for (int i = 0; i < 4; i++) {
int cur_elem = tc_row + tc_offsets[i];
......@@ -143,7 +124,7 @@ __global__ void awq_marlin_repack_kernel(
constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
uint32_t res = 0;
#pragma unroll
#pragma unroll
for (int i = 0; i < 8; i++) {
res |= vals[pack_idx[i]] << (i * 4);
}
......@@ -155,7 +136,7 @@ __global__ void awq_marlin_repack_kernel(
uint32_t res1 = 0;
uint32_t res2 = 0;
#pragma unroll
#pragma unroll
for (int i = 0; i < 4; i++) {
res1 |= vals[pack_idx[i]] << (i * 8);
res2 |= vals[4 + pack_idx[i]] << (i * 8);
......@@ -167,21 +148,21 @@ __global__ void awq_marlin_repack_kernel(
};
auto start_pipes = [&](int k_tile_id, int n_tile_id) {
#pragma unroll
#pragma unroll
for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
}
wait_for_stage();
};
#pragma unroll
#pragma unroll
for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
int n_tile_id = 0;
start_pipes(k_tile_id, n_tile_id);
while (n_tile_id < n_tiles) {
#pragma unroll
#pragma unroll
for (int pipe = 0; pipe < repack_stages; pipe++) {
fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,
n_tile_id + pipe + repack_stages - 1);
......@@ -195,7 +176,7 @@ __global__ void awq_marlin_repack_kernel(
} // namespace marlin
#define CALL_IF(NUM_BITS) \
#define CALL_IF(NUM_BITS) \
else if (num_bits == NUM_BITS) { \
cudaFuncSetAttribute( \
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>, \
......@@ -266,8 +247,6 @@ torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
return out;
}
#endif
torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight,
c10::SymInt size_k, c10::SymInt size_n,
int64_t num_bits) {
......@@ -279,3 +258,11 @@ torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight,
{size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
options);
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("awq_marlin_repack", &awq_marlin_repack);
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, m) {
m.impl("awq_marlin_repack", &awq_marlin_repack_meta);
}
\ No newline at end of file
......@@ -23,6 +23,8 @@
#include "marlin_dtypes.cuh"
#include "core/scalar_type.hpp"
#include "core/registration.h"
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
static_assert(std::is_same<scalar_t, half>::value || \
std::is_same<scalar_t, nv_bfloat16>::value, \
......@@ -2297,3 +2299,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
}
#endif
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("gptq_marlin_gemm", &gptq_marlin_gemm);
}
\ No newline at end of file
#include "marlin.cuh"
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
namespace marlin {
template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void gptq_marlin_repack_kernel(
uint32_t const* __restrict__ b_q_weight_ptr,
uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
int size_k, int size_n) {}
} // namespace marlin
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n,
int64_t num_bits) {
TORCH_CHECK_NOT_IMPLEMENTED(
false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0");
return torch::empty({1, 1});
}
#else
#include "core/registration.h"
namespace marlin {
......@@ -174,13 +154,13 @@ __global__ void gptq_marlin_repack_kernel(
uint32_t b1_vals[tile_ints];
uint32_t b2_vals[tile_ints];
#pragma unroll
#pragma unroll
for (int i = 0; i < tile_ints; i++) {
b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i];
b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i];
}
#pragma unroll
#pragma unroll
for (int i = 0; i < 4; i++) {
int cur_elem = tc_row + tc_offsets[i];
int cur_int = cur_elem / pack_factor;
......@@ -200,7 +180,7 @@ __global__ void gptq_marlin_repack_kernel(
constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
uint32_t res = 0;
#pragma unroll
#pragma unroll
for (int i = 0; i < 8; i++) {
res |= vals[pack_idx[i]] << (i * 4);
}
......@@ -212,7 +192,7 @@ __global__ void gptq_marlin_repack_kernel(
uint32_t res1 = 0;
uint32_t res2 = 0;
#pragma unroll
#pragma unroll
for (int i = 0; i < 4; i++) {
res1 |= vals[pack_idx[i]] << (i * 8);
res2 |= vals[4 + pack_idx[i]] << (i * 8);
......@@ -224,14 +204,14 @@ __global__ void gptq_marlin_repack_kernel(
};
auto start_pipes = [&](int k_tile_id, int n_tile_id) {
#pragma unroll
#pragma unroll
for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
}
wait_for_stage();
};
#pragma unroll
#pragma unroll
for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
int n_tile_id = 0;
......@@ -242,7 +222,7 @@ __global__ void gptq_marlin_repack_kernel(
start_pipes(k_tile_id, n_tile_id);
while (n_tile_id < n_tiles) {
#pragma unroll
#pragma unroll
for (int pipe = 0; pipe < repack_stages; pipe++) {
fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,
n_tile_id + pipe + repack_stages - 1);
......@@ -256,7 +236,7 @@ __global__ void gptq_marlin_repack_kernel(
} // namespace marlin
#define CALL_IF(NUM_BITS, HAS_PERM) \
#define CALL_IF(NUM_BITS, HAS_PERM) \
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
cudaFuncSetAttribute( \
marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
......@@ -341,8 +321,6 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
return out;
}
#endif
torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight,
torch::Tensor& perm, c10::SymInt size_k,
c10::SymInt size_n, int64_t num_bits) {
......@@ -354,3 +332,11 @@ torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight,
{size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
options);
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("gptq_marlin_repack", &gptq_marlin_repack);
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, m) {
m.impl("gptq_marlin_repack", &gptq_marlin_repack_meta);
}
\ No newline at end of file
......@@ -284,7 +284,7 @@ mm_impl_template = create_template(IMPL_TEMPLATE)
prepack_dispatch_template = create_template(PREPACK_TEMPLATE)
def create_sources(impl_config: ImplConfig, num_impl_files=2):
def create_sources(impl_config: ImplConfig, num_impl_files=1):
sources = []
type_name = generate_type_signature(impl_config.type_config)
......
......@@ -34,10 +34,9 @@ static __global__ void prepack_B_kernel(BInTensor B_in,
}
template <typename PrepackedLayoutB, typename InLayout>
static void prepack_B(cudaStream_t stream,
typename PrepackedLayoutB::ElementB const* B_in_ptr,
InLayout B_layout,
typename PrepackedLayoutB::ElementB* B_out_ptr) {
static void prepack_B_template(
cudaStream_t stream, typename PrepackedLayoutB::ElementB const* B_in_ptr,
InLayout B_layout, typename PrepackedLayoutB::ElementB* B_out_ptr) {
using TileShapeNKL =
decltype(append(typename PrepackedLayoutB::PPBlockShape_NK{}, _1{}));
auto ilvd_NKbNbKL_to_offset =
......
......@@ -55,8 +55,8 @@ torch::Tensor prepack_impl(torch::Tensor const B) {
// Allocate output
torch::Tensor D = torch::empty_like(B, {}, at::MemoryFormat::Contiguous);
prepack_B<PrepackedLayoutB>(stream, B_ptr, layout_Bt,
static_cast<ElementB*>(D.mutable_data_ptr()));
prepack_B_template<PrepackedLayoutB>(
stream, B_ptr, layout_Bt, static_cast<ElementB*>(D.mutable_data_ptr()));
return D;
};
......
......@@ -2,6 +2,8 @@
#include "machete_prepack_launcher.cuh"
#include "core/scalar_type.hpp"
#include "core/registration.h"
namespace machete {
using namespace vllm;
......@@ -78,14 +80,16 @@ torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
}
torch::Tensor prepack_B(torch::Tensor const& B,
ScalarTypeTorchPtr const& btype) {
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12
vllm::ScalarTypeTorchPtr const& btype) {
return scalar_type_dispatch(*btype, [&](auto BType) {
return PrepackBDispatcher<half_t, decltype(BType), half_t>::dispatch(B);
});
#else
TORCH_CHECK(false, "Machete requires CUDA 12.0 or later");
#endif
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("machete_prepack_B", &prepack_B);
m.impl("machete_gemm", &gemm);
m.impl("machete_supported_schedules", &supported_schedules);
}
}; // namespace machete
......@@ -26,6 +26,7 @@
#include <iostream>
#include "common/base.h"
#include "core/registration.h"
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include "common/mem.h"
......@@ -1066,3 +1067,7 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
return c;
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("marlin_gemm", &marlin_gemm);
}
......@@ -30,6 +30,7 @@
#include <iostream>
#include "../dense/common/base.h"
#include "core/registration.h"
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include "../dense/common/mem.h"
......@@ -1241,3 +1242,7 @@ torch::Tensor marlin_qqq_gemm(torch::Tensor const& a,
return d;
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("marlin_qqq_gemm", &marlin_qqq_gemm);
}
......@@ -28,6 +28,7 @@
#include "common/base.h"
#include "core/scalar_type.hpp"
#include "core/registration.h"
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
......@@ -1134,3 +1135,7 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
return c;
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("gptq_marlin_24_gemm", &gptq_marlin_24_gemm);
}
......@@ -167,7 +167,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def(
"marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
"Tensor! workspace, int size_m, int size_n, int size_k) -> Tensor");
ops.impl("marlin_gemm", torch::kCUDA, &marlin_gemm);
// conditionally compiled so impl in source file
// Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ.
ops.def(
......@@ -175,22 +175,24 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor b_scales, Tensor workspace, "
"__torch__.torch.classes._core_C.ScalarType b_q_type, "
"int size_m, int size_n, int size_k) -> Tensor");
ops.impl("gptq_marlin_24_gemm", torch::kCUDA, &gptq_marlin_24_gemm);
// conditionally compiled so impl in source file
// Machete (Dense) Optimized Mixed Precision GEMM for Hopper.
ops.def("machete_supported_schedules", &machete::supported_schedules);
ops.def(
"machete_supported_schedules("
" __torch__.torch.classes._core_C.ScalarType btype"
") -> str[]");
ops.def(
"machete_gemm(Tensor A, Tensor B,"
" __torch__.torch.classes._core_C.ScalarType btype,"
" Tensor? scales, Tensor? zeros, int? group_size,"
" Tensor? C, float? alpha, float? beta, str? schedule)"
"-> Tensor");
ops.impl("machete_gemm", torch::kCUDA, &machete::gemm);
ops.def(
"machete_prepack_B(Tensor B,"
" __torch__.torch.classes._core_C.ScalarType btype)"
"-> Tensor");
ops.impl("machete_prepack_B", torch::kCUDA, &machete::prepack_B);
// conditionally compiled so impl registration is in source file
ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor");
ops.impl("permute_cols", torch::kCUDA, &permute_cols);
......@@ -202,21 +204,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"__torch__.torch.classes._core_C.ScalarType b_q_type, "
"int size_m, int size_n, int size_k, bool is_k_full, "
"bool has_zp, bool use_fp32_reduce) -> Tensor");
ops.impl("gptq_marlin_gemm", torch::kCUDA, &gptq_marlin_gemm);
// conditionally compiled so impl registration is in source file
// gptq_marlin repack from GPTQ.
ops.def(
"gptq_marlin_repack(Tensor b_q_weight, Tensor perm, "
"SymInt size_k, SymInt size_n, int num_bits) -> Tensor");
ops.impl("gptq_marlin_repack", torch::kCUDA, &gptq_marlin_repack);
ops.impl("gptq_marlin_repack", torch::kMeta, &gptq_marlin_repack_meta);
// conditionally compiled so impl registrations are in source file
// awq_marlin repack from AWQ.
ops.def(
"awq_marlin_repack(Tensor b_q_weight, SymInt size_k, "
"SymInt size_n, int num_bits) -> Tensor");
ops.impl("awq_marlin_repack", torch::kCUDA, &awq_marlin_repack);
ops.impl("awq_marlin_repack", torch::kMeta, &awq_marlin_repack_meta);
// conditionally compiled so impl registrations are in source file
// Dequantization for GGML.
ops.def("ggml_dequantize(Tensor W, int type, int m, int n) -> Tensor");
......@@ -237,7 +237,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
"Tensor! workspace, int num_bits, int size_m, int size_n, "
"int size_k) -> Tensor");
ops.impl("fp8_marlin_gemm", torch::kCUDA, &fp8_marlin_gemm);
// conditionally compiled so impl registration is in source file
// marlin_qqq_gemm for QQQ.
ops.def(
......@@ -245,7 +245,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor s_tok, Tensor s_ch, Tensor s_group, "
"Tensor! workspace, int size_m, int size_n, "
"int size_k) -> Tensor");
ops.impl("marlin_qqq_gemm", torch::kCUDA, &marlin_qqq_gemm);
// conditionally compiled so impl registration is in source file
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
// quantization, as well as bias
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment