Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
aeb37c2a
Unverified
Commit
aeb37c2a
authored
Oct 03, 2024
by
Lucas Wilkinson
Committed by
GitHub
Oct 03, 2024
Browse files
[CI/Build] Per file CUDA Archs (improve wheel size and dev build times) (#8845)
parent
3dbb215b
Changes
22
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
508 additions
and
370 deletions
+508
-370
CMakeLists.txt
CMakeLists.txt
+169
-55
cmake/utils.cmake
cmake/utils.cmake
+172
-103
csrc/core/registration.h
csrc/core/registration.h
+5
-0
csrc/moe/marlin_moe_ops.cu
csrc/moe/marlin_moe_ops.cu
+5
-0
csrc/moe/marlin_moe_ops.h
csrc/moe/marlin_moe_ops.h
+0
-15
csrc/moe/torch_bindings.cpp
csrc/moe/torch_bindings.cpp
+1
-2
csrc/ops.h
csrc/ops.h
+0
-68
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
+51
-25
csrc/quantization/fp8/fp8_marlin.cu
csrc/quantization/fp8/fp8_marlin.cu
+6
-0
csrc/quantization/gptq_marlin/awq_marlin_repack.cu
csrc/quantization/gptq_marlin/awq_marlin_repack.cu
+24
-37
csrc/quantization/gptq_marlin/gptq_marlin.cu
csrc/quantization/gptq_marlin/gptq_marlin.cu
+6
-0
csrc/quantization/gptq_marlin/gptq_marlin_repack.cu
csrc/quantization/gptq_marlin/gptq_marlin_repack.cu
+27
-41
csrc/quantization/machete/generate.py
csrc/quantization/machete/generate.py
+1
-1
csrc/quantization/machete/machete_prepack_kernel.cuh
csrc/quantization/machete/machete_prepack_kernel.cuh
+3
-4
csrc/quantization/machete/machete_prepack_launcher.cuh
csrc/quantization/machete/machete_prepack_launcher.cuh
+2
-2
csrc/quantization/machete/machete_pytorch.cu
csrc/quantization/machete/machete_pytorch.cu
+9
-5
csrc/quantization/marlin/dense/marlin_cuda_kernel.cu
csrc/quantization/marlin/dense/marlin_cuda_kernel.cu
+5
-0
csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu
csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu
+5
-0
csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu
csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu
+5
-0
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+12
-12
No files found.
CMakeLists.txt
View file @
aeb37c2a
...
...
@@ -143,6 +143,19 @@ else()
message
(
FATAL_ERROR
"Can't find CUDA or HIP installation."
)
endif
()
#
# For cuda we want to be able to control which architectures we compile for on
# a per-file basis in order to cut down on compile time. So here we extract
# the set of architectures we want to compile for and remove the from the
# CMAKE_CUDA_FLAGS so that they are not applied globally.
#
if
(
VLLM_GPU_LANG STREQUAL
"CUDA"
)
clear_cuda_arches
(
CUDA_ARCH_FLAGS
)
extract_unique_cuda_archs_ascending
(
CUDA_ARCHS
"
${
CUDA_ARCH_FLAGS
}
"
)
message
(
STATUS
"CUDA target architectures:
${
CUDA_ARCHS
}
"
)
endif
()
#
# Override the GPU architectures detected by cmake/torch and filter them by
# the supported versions for the current language.
...
...
@@ -223,30 +236,89 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/mamba/causal_conv1d/causal_conv1d.cu"
"csrc/quantization/aqlm/gemm_kernels.cu"
"csrc/quantization/awq/gemm_kernels.cu"
"csrc/quantization/marlin/dense/marlin_cuda_kernel.cu"
"csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
"csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu"
"csrc/quantization/gptq_marlin/gptq_marlin.cu"
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
"csrc/quantization/gptq_marlin/awq_marlin_repack.cu"
"csrc/quantization/gguf/gguf_kernel.cu"
"csrc/quantization/fp8/fp8_marlin.cu"
"csrc/custom_all_reduce.cu"
"csrc/permute_cols.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
)
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
)
set_gencode_flags_for_srcs
(
SRCS
"
${
VLLM_EXT_SRC
}
"
CUDA_ARCHS
"
${
CUDA_ARCHS
}
"
)
# Only build Marlin kernels if we are building for at least some compatible archs.
# Keep building Marlin for 9.0 as there are some group sizes and shapes that
# are not supported by Machete yet.
cuda_archs_loose_intersection
(
MARLIN_ARCHS
"8.0;8.6;8.9;9.0"
${
CUDA_ARCHS
}
)
if
(
MARLIN_ARCHS
)
set
(
MARLIN_SRCS
"csrc/quantization/fp8/fp8_marlin.cu"
"csrc/quantization/marlin/dense/marlin_cuda_kernel.cu"
"csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
"csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu"
"csrc/quantization/gptq_marlin/gptq_marlin.cu"
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
"csrc/quantization/gptq_marlin/awq_marlin_repack.cu"
)
set_gencode_flags_for_srcs
(
SRCS
"
${
MARLIN_SRCS
}
"
CUDA_ARCHS
"
${
MARLIN_ARCHS
}
"
)
list
(
APPEND VLLM_EXT_SRC
"
${
MARLIN_SRCS
}
"
)
message
(
STATUS
"Building Marlin kernels for archs:
${
MARLIN_ARCHS
}
"
)
else
()
message
(
STATUS
"Not building Marlin kernels as no compatible archs found"
"in CUDA target architectures"
)
endif
()
#
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
# CUDA 12.0 or later (and only work on Hopper, 9.0/9.0a for now).
cuda_archs_loose_intersection
(
SCALED_MM_3X_ARCHS
"9.0;9.0a"
"
${
CUDA_ARCHS
}
"
)
if
(
${
CMAKE_CUDA_COMPILER_VERSION
}
VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS
)
set
(
SRCS
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
)
set_gencode_flags_for_srcs
(
SRCS
"
${
SRCS
}
"
CUDA_ARCHS
"
${
SCALED_MM_3X_ARCHS
}
"
)
list
(
APPEND VLLM_EXT_SRC
"
${
SRCS
}
"
)
list
(
APPEND VLLM_GPU_FLAGS
"-DENABLE_SCALED_MM_C3X=1"
)
message
(
STATUS
"Building scaled_mm_c3x for archs:
${
SCALED_MM_3X_ARCHS
}
"
)
else
()
# clear SCALED_MM_3X_ARCHS so the scaled_mm_c2x kernels know we didn't
# build any 3x kernels
set
(
SCALED_MM_3X_ARCHS
)
if
(
NOT
${
CMAKE_CUDA_COMPILER_VERSION
}
VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS
)
message
(
STATUS
"Not building scaled_mm_c3x as CUDA Compiler version is "
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
"later if you intend on running FP8 quantized models on "
"Hopper."
)
else
()
message
(
STATUS
"Not building scaled_mm_c3x as no compatible archs found "
"in CUDA target architectures"
)
endif
()
endif
()
#
# The CUTLASS kernels for Hopper require sm90a to be enabled.
# This is done via the below gencode option, BUT that creates kernels for both sm90 and sm90a.
# That adds an extra 17MB to compiled binary, so instead we selectively enable it.
if
(
${
CMAKE_CUDA_COMPILER_VERSION
}
VERSION_GREATER 12.0
)
set_source_files_properties
(
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
PROPERTIES
COMPILE_FLAGS
"-gencode arch=compute_90a,code=sm_90a"
)
# For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x)
# kernels for the remaining archs that are not already built for 3x.
cuda_archs_loose_intersection
(
SCALED_MM_2X_ARCHS
"7.5;8.0;8.6;8.9;9.0;9.0a"
"
${
CUDA_ARCHS
}
"
)
# subtract out the archs that are already built for 3x
list
(
REMOVE_ITEM SCALED_MM_2X_ARCHS
${
SCALED_MM_3X_ARCHS
}
)
if
(
SCALED_MM_2X_ARCHS
)
set
(
SRCS
"csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu"
)
set_gencode_flags_for_srcs
(
SRCS
"
${
SRCS
}
"
CUDA_ARCHS
"
${
SCALED_MM_2X_ARCHS
}
"
)
list
(
APPEND VLLM_EXT_SRC
"
${
SRCS
}
"
)
list
(
APPEND VLLM_GPU_FLAGS
"-DENABLE_SCALED_MM_C2X=1"
)
message
(
STATUS
"Building scaled_mm_c2x for archs:
${
SCALED_MM_2X_ARCHS
}
"
)
else
()
if
(
SCALED_MM_3X_ARCHS
)
message
(
STATUS
"Not building scaled_mm_c2x as all archs are already built"
" for and covered by scaled_mm_c3x"
)
else
()
message
(
STATUS
"Not building scaled_mm_c2x as no compatible archs found "
"in CUDA target architectures"
)
endif
()
endif
()
...
...
@@ -254,47 +326,72 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# Machete kernels
# The machete kernels only work on hopper and require CUDA 12.0 or later.
if
(
${
CMAKE_CUDA_COMPILER_VERSION
}
VERSION_GREATER 12.0
)
# Only build Machete kernels if we are building for something compatible with sm90a
cuda_archs_loose_intersection
(
MACHETE_ARCHS
"9.0a"
"
${
CUDA_ARCHS
}
"
)
if
(
${
CMAKE_CUDA_COMPILER_VERSION
}
VERSION_GREATER 12.0 AND MACHETE_ARCHS
)
#
# For the Machete kernels we automatically generate sources for various
# preselected input type pairs and schedules.
# Generate sources:
execute_process
(
COMMAND
${
CMAKE_COMMAND
}
-E env
PYTHONPATH=
${
CMAKE_CURRENT_SOURCE_DIR
}
/csrc/cutlass_extensions/:
${
CUTLASS_DIR
}
/python/:
${
VLLM_PYTHON_PATH
}
:$PYTHONPATH
${
Python_EXECUTABLE
}
${
CMAKE_CURRENT_SOURCE_DIR
}
/csrc/quantization/machete/generate.py
RESULT_VARIABLE machete_generation_result
OUTPUT_VARIABLE machete_generation_output
OUTPUT_FILE
${
CMAKE_CURRENT_BINARY_DIR
}
/machete_generation.log
ERROR_FILE
${
CMAKE_CURRENT_BINARY_DIR
}
/machete_generation.log
)
if
(
NOT machete_generation_result EQUAL 0
)
message
(
FATAL_ERROR
"Machete generation failed."
" Result:
\"
${
machete_generation_result
}
\"
"
"
\n
Check the log for details: "
"
${
CMAKE_CURRENT_BINARY_DIR
}
/machete_generation.log"
)
set
(
MACHETE_GEN_SCRIPT
${
CMAKE_CURRENT_SOURCE_DIR
}
/csrc/quantization/machete/generate.py
)
file
(
MD5
${
MACHETE_GEN_SCRIPT
}
MACHETE_GEN_SCRIPT_HASH
)
message
(
STATUS
"Machete generation script hash:
${
MACHETE_GEN_SCRIPT_HASH
}
"
)
message
(
STATUS
"Last run machete generate script hash: $CACHE{MACHETE_GEN_SCRIPT_HASH}"
)
if
(
NOT DEFINED CACHE{MACHETE_GEN_SCRIPT_HASH}
OR NOT $CACHE{MACHETE_GEN_SCRIPT_HASH} STREQUAL
${
MACHETE_GEN_SCRIPT_HASH
}
)
execute_process
(
COMMAND
${
CMAKE_COMMAND
}
-E env
PYTHONPATH=
${
CMAKE_CURRENT_SOURCE_DIR
}
/csrc/cutlass_extensions/:
${
CUTLASS_DIR
}
/python/:
${
VLLM_PYTHON_PATH
}
:$PYTHONPATH
${
Python_EXECUTABLE
}
${
MACHETE_GEN_SCRIPT
}
RESULT_VARIABLE machete_generation_result
OUTPUT_VARIABLE machete_generation_output
OUTPUT_FILE
${
CMAKE_CURRENT_BINARY_DIR
}
/machete_generation.log
ERROR_FILE
${
CMAKE_CURRENT_BINARY_DIR
}
/machete_generation.log
)
if
(
NOT machete_generation_result EQUAL 0
)
message
(
FATAL_ERROR
"Machete generation failed."
" Result:
\"
${
machete_generation_result
}
\"
"
"
\n
Check the log for details: "
"
${
CMAKE_CURRENT_BINARY_DIR
}
/machete_generation.log"
)
else
()
set
(
MACHETE_GEN_SCRIPT_HASH
${
MACHETE_GEN_SCRIPT_HASH
}
CACHE STRING
"Last run machete generate script hash"
FORCE
)
message
(
STATUS
"Machete generation completed successfully."
)
endif
()
else
()
message
(
STATUS
"Machete generation
complet
ed s
uccessfully
."
)
message
(
STATUS
"Machete generation
script has not chang
ed
,
s
kipping generation
."
)
endif
()
# Add machete generated sources
file
(
GLOB MACHETE_GEN_SOURCES
"csrc/quantization/machete/generated/*.cu"
)
list
(
APPEND VLLM_EXT_SRC
${
MACHETE_GEN_SOURCES
}
)
message
(
STATUS
"Machete generated sources:
${
MACHETE_GEN_SOURCES
}
"
)
set_source_files_properties
(
${
MACHETE_GEN_SOURCES
}
PROPERTIES
COMPILE_FLAGS
"-gencode arch=compute_90a,code=sm_90a"
)
# forward compatible
set_gencode_flags_for_srcs
(
SRCS
"
${
MACHETE_GEN_SOURCES
}
"
CUDA_ARCHS
"
${
MACHETE_ARCHS
}
"
)
list
(
APPEND VLLM_EXT_SRC
csrc/quantization/machete/machete_pytorch.cu
)
message
(
STATUS
"Building Machete kernels for archs:
${
MACHETE_ARCHS
}
"
)
else
()
if
(
NOT
${
CMAKE_CUDA_COMPILER_VERSION
}
VERSION_GREATER 12.0
AND MACHETE_ARCHS
)
message
(
STATUS
"Not building Machete kernels as CUDA Compiler version is "
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
"later if you intend on running w4a16 quantized models on "
"Hopper."
)
else
()
message
(
STATUS
"Not building Machete kernels as no compatible archs "
"found in CUDA target architectures"
)
endif
()
endif
()
# Add pytorch binding for machete (add on even CUDA < 12.0 so that we can
# raise an error if the user that this was built with an incompatible
# CUDA version)
list
(
APPEND VLLM_EXT_SRC
csrc/quantization/machete/machete_pytorch.cu
)
# if CUDA endif
endif
()
message
(
STATUS
"Enabling C extension."
)
...
...
@@ -323,14 +420,31 @@ set(VLLM_MOE_EXT_SRC
"csrc/moe/torch_bindings.cpp"
"csrc/moe/topk_softmax_kernels.cu"
)
set_gencode_flags_for_srcs
(
SRCS
"
${
VLLM_MOE_EXT_SRC
}
"
CUDA_ARCHS
"
${
CUDA_ARCHS
}
"
)
if
(
VLLM_GPU_LANG STREQUAL
"CUDA"
)
list
(
APPEND VLLM_MOE_EXT_SRC
"csrc/moe/marlin_kernels/marlin_moe_kernel.h"
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h"
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu"
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h"
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu"
"csrc/moe/marlin_moe_ops.cu"
)
cuda_archs_loose_intersection
(
MARLIN_MOE_ARCHS
"8.0;8.6;8.9;9.0"
"
${
CUDA_ARCHS
}
"
)
if
(
MARLIN_MOE_ARCHS
)
set
(
MARLIN_MOE_SRC
"csrc/moe/marlin_kernels/marlin_moe_kernel.h"
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h"
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu"
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h"
"csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu"
"csrc/moe/marlin_moe_ops.cu"
)
set_gencode_flags_for_srcs
(
SRCS
"
${
MARLIN_MOE_SRC
}
"
CUDA_ARCHS
"
${
MARLIN_MOE_ARCHS
}
"
)
list
(
APPEND VLLM_MOE_EXT_SRC
"
${
MARLIN_MOE_SRC
}
"
)
message
(
STATUS
"Building Marlin MOE kernels for archs:
${
MARLIN_MOE_ARCHS
}
"
)
else
()
message
(
STATUS
"Not building Marlin MOE kernels as no compatible archs found"
"in CUDA target architectures"
)
endif
()
endif
()
message
(
STATUS
"Enabling moe extension."
)
...
...
cmake/utils.cmake
View file @
aeb37c2a
...
...
@@ -133,10 +133,181 @@ macro(string_to_ver OUT_VER IN_STR)
string
(
REGEX REPLACE
"
\(
[0-9]+
\)\(
[0-9]
\)
"
"
\\
1.
\\
2"
${
OUT_VER
}
${
IN_STR
}
)
endmacro
()
#
# Clear all `-gencode` flags from `CMAKE_CUDA_FLAGS` and store them in
# `CUDA_ARCH_FLAGS`.
#
# Example:
# CMAKE_CUDA_FLAGS="-Wall -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75"
# clear_cuda_arches(CUDA_ARCH_FLAGS)
# CUDA_ARCH_FLAGS="-gencode arch=compute_70,code=sm_70;-gencode arch=compute_75,code=sm_75"
# CMAKE_CUDA_FLAGS="-Wall"
#
macro
(
clear_cuda_arches CUDA_ARCH_FLAGS
)
# Extract all `-gencode` flags from `CMAKE_CUDA_FLAGS`
string
(
REGEX MATCHALL
"-gencode arch=[^ ]+"
CUDA_ARCH_FLAGS
${
CMAKE_CUDA_FLAGS
}
)
# Remove all `-gencode` flags from `CMAKE_CUDA_FLAGS` since they will be modified
# and passed back via the `CUDA_ARCHITECTURES` property.
string
(
REGEX REPLACE
"-gencode arch=[^ ]+ *"
""
CMAKE_CUDA_FLAGS
${
CMAKE_CUDA_FLAGS
}
)
endmacro
()
#
# Extract unique CUDA architectures from a list of compute capabilities codes in
# the form `<major><minor>[<letter>]`, convert them to the form sort
# `<major>.<minor>`, dedupes them and then sorts them in ascending order and
# stores them in `OUT_ARCHES`.
#
# Example:
# CUDA_ARCH_FLAGS="-gencode arch=compute_75,code=sm_75;...;-gencode arch=compute_90a,code=sm_90a"
# extract_unique_cuda_archs_ascending(OUT_ARCHES CUDA_ARCH_FLAGS)
# OUT_ARCHES="7.5;...;9.0"
function
(
extract_unique_cuda_archs_ascending OUT_ARCHES CUDA_ARCH_FLAGS
)
set
(
_CUDA_ARCHES
)
foreach
(
_ARCH
${
CUDA_ARCH_FLAGS
}
)
string
(
REGEX MATCH
"arch=compute_
\(
[0-9]+a?
\)
"
_COMPUTE
${
_ARCH
}
)
if
(
_COMPUTE
)
set
(
_COMPUTE
${
CMAKE_MATCH_1
}
)
endif
()
string_to_ver
(
_COMPUTE_VER
${
_COMPUTE
}
)
list
(
APPEND _CUDA_ARCHES
${
_COMPUTE_VER
}
)
endforeach
()
list
(
REMOVE_DUPLICATES _CUDA_ARCHES
)
list
(
SORT _CUDA_ARCHES COMPARE NATURAL ORDER ASCENDING
)
set
(
${
OUT_ARCHES
}
${
_CUDA_ARCHES
}
PARENT_SCOPE
)
endfunction
()
#
# For a specific file set the `-gencode` flag in compile options conditionally
# for the CUDA language.
#
# Example:
# set_gencode_flag_for_srcs(
# SRCS "foo.cu"
# ARCH "compute_75"
# CODE "sm_75")
# adds: "-gencode arch=compute_75,code=sm_75" to the compile options for
# `foo.cu` (only for the CUDA language).
#
macro
(
set_gencode_flag_for_srcs
)
set
(
options
)
set
(
oneValueArgs ARCH CODE
)
set
(
multiValueArgs SRCS
)
cmake_parse_arguments
(
arg
"
${
options
}
"
"
${
oneValueArgs
}
"
"
${
multiValueArgs
}
"
${
ARGN
}
)
set
(
_FLAG -gencode arch=
${
arg_ARCH
}
,code=
${
arg_CODE
}
)
set_property
(
SOURCE
${
arg_SRCS
}
APPEND PROPERTY
COMPILE_OPTIONS
"$<$<COMPILE_LANGUAGE:CUDA>:
${
_FLAG
}
>"
)
message
(
DEBUG
"Setting gencode flag for
${
arg_SRCS
}
:
${
_FLAG
}
"
)
endmacro
(
set_gencode_flag_for_srcs
)
#
# For a list of source files set the `-gencode` flags in the files specific
# compile options (specifically for the CUDA language).
#
# arguments are:
# SRCS: list of source files
# CUDA_ARCHS: list of CUDA architectures in the form `<major>.<minor>[letter]`
# BUILD_PTX_FOR_ARCH: if set to true, then the PTX code will be built
# for architecture `BUILD_PTX_FOR_ARCH` if there is a CUDA_ARCH in CUDA_ARCHS
# that is larger than BUILD_PTX_FOR_ARCH.
#
macro
(
set_gencode_flags_for_srcs
)
set
(
options
)
set
(
oneValueArgs BUILD_PTX_FOR_ARCH
)
set
(
multiValueArgs SRCS CUDA_ARCHS
)
cmake_parse_arguments
(
arg
"
${
options
}
"
"
${
oneValueArgs
}
"
"
${
multiValueArgs
}
"
${
ARGN
}
)
foreach
(
_ARCH
${
arg_CUDA_ARCHS
}
)
string
(
REPLACE
"."
""
_ARCH
"
${
_ARCH
}
"
)
set_gencode_flag_for_srcs
(
SRCS
${
arg_SRCS
}
ARCH
"compute_
${
_ARCH
}
"
CODE
"sm_
${
_ARCH
}
"
)
endforeach
()
if
(
${
arg_BUILD_PTX_FOR_ARCH
}
)
list
(
SORT arg_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING
)
list
(
GET arg_CUDA_ARCHS -1 _HIGHEST_ARCH
)
if
(
_HIGHEST_ARCH VERSION_GREATER_EQUAL
${
arg_BUILD_PTX_FOR_ARCH
}
)
string
(
REPLACE
"."
""
_PTX_ARCH
"
${
arg_BUILD_PTX_FOR_ARCH
}
"
)
set_gencode_flag_for_srcs
(
SRCS
${
arg_SRCS
}
ARCH
"compute_
${
_PTX_ARCH
}
"
CODE
"compute_
${
_PTX_ARCH
}
"
)
endif
()
endif
()
endmacro
()
#
# For the given `SRC_CUDA_ARCHS` list of gencode versions in the form
# `<major>.<minor>[letter]` compute the "loose intersection" with the
# `TGT_CUDA_ARCHS` list of gencodes.
# The loose intersection is defined as:
# { max{ x \in tgt | x <= y } | y \in src, { x \in tgt | x <= y } != {} }
# where `<=` is the version comparison operator.
# In other words, for each version in `TGT_CUDA_ARCHS` find the highest version
# in `SRC_CUDA_ARCHS` that is less or equal to the version in `TGT_CUDA_ARCHS`.
# We have special handling for 9.0a, if 9.0a is in `SRC_CUDA_ARCHS` and 9.0 is
# in `TGT_CUDA_ARCHS` then we should remove 9.0a from `SRC_CUDA_ARCHS` and add
# 9.0a to the result.
# The result is stored in `OUT_CUDA_ARCHS`.
#
# Example:
# SRC_CUDA_ARCHS="7.5;8.0;8.6;9.0;9.0a"
# TGT_CUDA_ARCHS="8.0;8.9;9.0"
# cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
# OUT_CUDA_ARCHS="8.0;8.6;9.0;9.0a"
#
function
(
cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS
)
list
(
REMOVE_DUPLICATES SRC_CUDA_ARCHS
)
# if 9.0a is in SRC_CUDA_ARCHS and 9.0 is in CUDA_ARCHS then we should
# remove 9.0a from SRC_CUDA_ARCHS and add 9.0a to _CUDA_ARCHS
set
(
_CUDA_ARCHS
)
if
(
"9.0a"
IN_LIST SRC_CUDA_ARCHS
)
list
(
REMOVE_ITEM SRC_CUDA_ARCHS
"9.0a"
)
if
(
"9.0"
IN_LIST TGT_CUDA_ARCHS
)
set
(
_CUDA_ARCHS
"9.0a"
)
endif
()
endif
()
list
(
SORT SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING
)
# for each ARCH in CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that is
# less or eqault to ARCH
foreach
(
_ARCH
${
CUDA_ARCHS
}
)
set
(
_TMP_ARCH
)
foreach
(
_SRC_ARCH
${
SRC_CUDA_ARCHS
}
)
if
(
_SRC_ARCH VERSION_LESS_EQUAL _ARCH
)
set
(
_TMP_ARCH
${
_SRC_ARCH
}
)
else
()
break
()
endif
()
endforeach
()
if
(
_TMP_ARCH
)
list
(
APPEND _CUDA_ARCHS
${
_TMP_ARCH
}
)
endif
()
endforeach
()
list
(
REMOVE_DUPLICATES _CUDA_ARCHS
)
set
(
${
OUT_CUDA_ARCHS
}
${
_CUDA_ARCHS
}
PARENT_SCOPE
)
endfunction
()
#
# Override the GPU architectures detected by cmake/torch and filter them by
# `GPU_SUPPORTED_ARCHES`. Sets the final set of architectures in
# `GPU_ARCHES`.
# `GPU_ARCHES`. This only applies to the HIP language since for CUDA we set
# the architectures on a per file basis.
#
# Note: this is defined as a macro since it updates `CMAKE_CUDA_FLAGS`.
#
...
...
@@ -174,109 +345,7 @@ macro(override_gpu_arches GPU_ARCHES GPU_LANG GPU_SUPPORTED_ARCHES)
"None of the detected ROCm architectures:
${
HIP_ARCHITECTURES
}
is"
" supported. Supported ROCm architectures are:
${
_GPU_SUPPORTED_ARCHES_LIST
}
."
)
endif
()
elseif
(
${
GPU_LANG
}
STREQUAL
"CUDA"
)
#
# Setup/process CUDA arch flags.
#
# The torch cmake setup hardcodes the detected architecture flags in
# `CMAKE_CUDA_FLAGS`. Since `CMAKE_CUDA_FLAGS` is a "global" variable, it
# can't modified on a per-target basis.
# So, all the `-gencode` flags need to be extracted and removed from
# `CMAKE_CUDA_FLAGS` for processing so they can be passed by another method.
# Since it's not possible to use `target_compiler_options` for adding target
# specific `-gencode` arguments, the target's `CUDA_ARCHITECTURES` property
# must be used instead. This requires repackaging the architecture flags
# into a format that cmake expects for `CUDA_ARCHITECTURES`.
#
# This is a bit fragile in that it depends on torch using `-gencode` as opposed
# to one of the other nvcc options to specify architectures.
#
# Note: torch uses the `TORCH_CUDA_ARCH_LIST` environment variable to override
# detected architectures.
#
message
(
DEBUG
"initial CMAKE_CUDA_FLAGS:
${
CMAKE_CUDA_FLAGS
}
"
)
# Extract all `-gencode` flags from `CMAKE_CUDA_FLAGS`
string
(
REGEX MATCHALL
"-gencode arch=[^ ]+"
_CUDA_ARCH_FLAGS
${
CMAKE_CUDA_FLAGS
}
)
# Remove all `-gencode` flags from `CMAKE_CUDA_FLAGS` since they will be modified
# and passed back via the `CUDA_ARCHITECTURES` property.
string
(
REGEX REPLACE
"-gencode arch=[^ ]+ *"
""
CMAKE_CUDA_FLAGS
${
CMAKE_CUDA_FLAGS
}
)
# If this error is triggered, it might mean that torch has changed how it sets
# up nvcc architecture code generation flags.
if
(
NOT _CUDA_ARCH_FLAGS
)
message
(
FATAL_ERROR
"Could not find any architecture related code generation flags in "
"CMAKE_CUDA_FLAGS. (
${
CMAKE_CUDA_FLAGS
}
)"
)
endif
()
message
(
DEBUG
"final CMAKE_CUDA_FLAGS:
${
CMAKE_CUDA_FLAGS
}
"
)
message
(
DEBUG
"arch flags:
${
_CUDA_ARCH_FLAGS
}
"
)
# Initialize the architecture lists to empty.
set
(
${
GPU_ARCHES
}
)
# Process each `gencode` flag.
foreach
(
_ARCH
${
_CUDA_ARCH_FLAGS
}
)
# For each flag, extract the version number and whether it refers to PTX
# or native code.
# Note: if a regex matches then `CMAKE_MATCH_1` holds the binding
# for that match.
string
(
REGEX MATCH
"arch=compute_
\(
[0-9]+a?
\)
"
_COMPUTE
${
_ARCH
}
)
if
(
_COMPUTE
)
set
(
_COMPUTE
${
CMAKE_MATCH_1
}
)
endif
()
string
(
REGEX MATCH
"code=sm_
\(
[0-9]+a?
\)
"
_SM
${
_ARCH
}
)
if
(
_SM
)
set
(
_SM
${
CMAKE_MATCH_1
}
)
endif
()
string
(
REGEX MATCH
"code=compute_
\(
[0-9]+a?
\)
"
_CODE
${
_ARCH
}
)
if
(
_CODE
)
set
(
_CODE
${
CMAKE_MATCH_1
}
)
endif
()
# Make sure the virtual architecture can be matched.
if
(
NOT _COMPUTE
)
message
(
FATAL_ERROR
"Could not determine virtual architecture from:
${
_ARCH
}
."
)
endif
()
# One of sm_ or compute_ must exist.
if
((
NOT _SM
)
AND
(
NOT _CODE
))
message
(
FATAL_ERROR
"Could not determine a codegen architecture from:
${
_ARCH
}
."
)
endif
()
if
(
_SM
)
# -real suffix let CMake to only generate elf code for the kernels.
# we want this, otherwise the added ptx (default) will increase binary size.
set
(
_VIRT
"-real"
)
set
(
_CODE_ARCH
${
_SM
}
)
else
()
# -virtual suffix let CMake to generate ptx code for the kernels.
set
(
_VIRT
"-virtual"
)
set
(
_CODE_ARCH
${
_CODE
}
)
endif
()
# Check if the current version is in the supported arch list.
string_to_ver
(
_CODE_VER
${
_CODE_ARCH
}
)
if
(
NOT _CODE_VER IN_LIST _GPU_SUPPORTED_ARCHES_LIST
)
message
(
STATUS
"discarding unsupported CUDA arch
${
_VER
}
."
)
continue
()
endif
()
# Add it to the arch list.
list
(
APPEND
${
GPU_ARCHES
}
"
${
_CODE_ARCH
}${
_VIRT
}
"
)
endforeach
()
endif
()
message
(
STATUS
"
${
GPU_LANG
}
target arches:
${${
GPU_ARCHES
}}
"
)
endmacro
()
#
...
...
csrc/core/registration.h
View file @
aeb37c2a
...
...
@@ -12,6 +12,11 @@
// could be a macro instead of a literal token.
#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
// A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME
// could be a macro instead of a literal token.
#define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \
TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE)
// REGISTER_EXTENSION allows the shared library to be loaded and initialized
// via python's import statement.
#define REGISTER_EXTENSION(NAME) \
...
...
csrc/moe/marlin_moe_ops.cu
View file @
aeb37c2a
...
...
@@ -27,6 +27,7 @@
#include "core/exception.hpp"
#include "core/scalar_type.hpp"
#include "core/registration.h"
#include "marlin_kernels/marlin_moe_kernel_ku4b8.h"
#include "marlin_kernels/marlin_moe_kernel_ku8b128.h"
...
...
@@ -552,3 +553,7 @@ torch::Tensor marlin_gemm_moe(
thread_n
,
sms
,
max_par
,
replicate_input
,
apply_weights
);
return
c
;
}
TORCH_LIBRARY_IMPL_EXPAND
(
TORCH_EXTENSION_NAME
,
CUDA
,
m
)
{
m
.
impl
(
"marlin_gemm_moe"
,
&
marlin_gemm_moe
);
}
csrc/moe/marlin_moe_ops.h
deleted
100644 → 0
View file @
3dbb215b
#pragma once
#include <torch/all.h>
#include "core/scalar_type.hpp"
torch
::
Tensor
marlin_gemm_moe
(
const
torch
::
Tensor
&
a
,
const
torch
::
Tensor
&
b_q_weights
,
const
torch
::
Tensor
&
sorted_ids
,
const
torch
::
Tensor
&
topk_weights
,
const
torch
::
Tensor
&
topk_ids
,
const
torch
::
Tensor
&
b_scales
,
const
torch
::
Tensor
&
g_idx
,
const
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
workspace
,
vllm
::
ScalarTypeTorchPtr
const
&
b_q_type
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
bool
is_k_full
,
int64_t
num_experts
,
int64_t
topk
,
int64_t
moe_block_size
,
bool
replicate_input
,
bool
apply_weights
);
csrc/moe/torch_bindings.cpp
View file @
aeb37c2a
#include "core/registration.h"
#include "moe_ops.h"
#include "marlin_moe_ops.h"
TORCH_LIBRARY_EXPAND
(
TORCH_EXTENSION_NAME
,
m
)
{
// Apply topk softmax to the gating outputs.
...
...
@@ -18,7 +17,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
"int size_n, int size_k, bool is_k_full, int num_experts, int topk, "
"int moe_block_size, bool replicate_input, bool apply_weights)"
" -> Tensor"
);
m
.
impl
(
"marlin_gemm_moe"
,
torch
::
kCUDA
,
&
marlin_gemm_moe
);
// conditionally compiled so impl registration is in source file
#endif
}
...
...
csrc/ops.h
View file @
aeb37c2a
...
...
@@ -90,63 +90,8 @@ torch::Tensor awq_dequantize(torch::Tensor _kernel,
torch
::
Tensor
_zeros
,
int64_t
split_k_iters
,
int64_t
thx
,
int64_t
thy
);
torch
::
Tensor
marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
workspace
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
);
namespace
machete
{
std
::
vector
<
std
::
string
>
supported_schedules
(
vllm
::
ScalarTypeTorchPtr
const
&
btype
);
torch
::
Tensor
gemm
(
torch
::
Tensor
const
&
A
,
torch
::
Tensor
const
&
B
,
vllm
::
ScalarTypeTorchPtr
const
&
btype
,
c10
::
optional
<
torch
::
Tensor
>
const
&
scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
zeros
,
c10
::
optional
<
int64_t
>
group_size
,
c10
::
optional
<
torch
::
Tensor
>
const
&
C
,
c10
::
optional
<
double
>
alpha
,
c10
::
optional
<
double
>
beta
,
c10
::
optional
<
std
::
string
>
schedule
);
torch
::
Tensor
prepack_B
(
torch
::
Tensor
const
&
B
,
vllm
::
ScalarTypeTorchPtr
const
&
btype
);
};
// namespace machete
torch
::
Tensor
permute_cols
(
torch
::
Tensor
const
&
A
,
torch
::
Tensor
const
&
perm
);
torch
::
Tensor
gptq_marlin_24_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_meta
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
workspace
,
vllm
::
ScalarTypeTorchPtr
const
&
b_q_type
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
);
torch
::
Tensor
gptq_marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
b_zeros
,
torch
::
Tensor
&
g_idx
,
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
workspace
,
vllm
::
ScalarTypeTorchPtr
const
&
b_q_type
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
bool
is_k_full
,
bool
has_zp
,
bool
use_fp32_reduce
);
torch
::
Tensor
gptq_marlin_repack
(
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
perm
,
int64_t
size_k
,
int64_t
size_n
,
int64_t
num_bits
);
torch
::
Tensor
gptq_marlin_repack_meta
(
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
perm
,
c10
::
SymInt
size_k
,
c10
::
SymInt
size_n
,
int64_t
num_bits
);
torch
::
Tensor
awq_marlin_repack
(
torch
::
Tensor
&
b_q_weight
,
int64_t
size_k
,
int64_t
size_n
,
int64_t
num_bits
);
torch
::
Tensor
awq_marlin_repack_meta
(
torch
::
Tensor
&
b_q_weight
,
c10
::
SymInt
size_k
,
c10
::
SymInt
size_n
,
int64_t
num_bits
);
torch
::
Tensor
ggml_dequantize
(
torch
::
Tensor
W
,
int64_t
type
,
int64_t
m
,
int64_t
n
);
...
...
@@ -156,11 +101,6 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X,
torch
::
Tensor
ggml_mul_mat_a8
(
torch
::
Tensor
W
,
torch
::
Tensor
X
,
int64_t
type
,
int64_t
row
);
torch
::
Tensor
fp8_marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
workspace
,
int64_t
num_bits
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
);
bool
cutlass_scaled_mm_supports_fp8
(
int64_t
cuda_device_capability
);
void
cutlass_scaled_mm
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
...
...
@@ -175,14 +115,6 @@ void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
torch
::
Tensor
const
&
azp_adj
,
c10
::
optional
<
torch
::
Tensor
>
const
&
azp
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
torch
::
Tensor
marlin_qqq_gemm
(
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b_q_weight
,
torch
::
Tensor
const
&
s_tok
,
torch
::
Tensor
const
&
s_ch
,
torch
::
Tensor
const
&
s_group
,
torch
::
Tensor
&
workspace
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
);
#endif
void
static_scaled_int8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
...
...
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
View file @
aeb37c2a
...
...
@@ -21,7 +21,7 @@ void cutlass_scaled_mm_sm89(torch::Tensor& c, torch::Tensor const& a,
torch
::
Tensor
const
&
b_scales
,
c10
::
optional
<
torch
::
Tensor
>
const
&
bias
);
#if defined
CUDA_VERSION && CUDA_VERSION >= 12000
#if defined
ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
void
cutlass_scaled_mm_sm90
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
...
...
@@ -114,26 +114,39 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
at
::
cuda
::
OptionalCUDAGuard
const
device_guard
(
device_of
(
a
));
int32_t
version_num
=
get_sm_version_num
();
if
(
version_num
>=
90
)
{
// Hopper
// Hopper
// Guard against compilation issues for sm90 kernels
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
// Guard against compilation issues for sm90 kernels
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
if
(
version_num
>=
90
)
{
cutlass_scaled_mm_sm90
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
#else
cutlass_scaled_mm_sm80
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
return
;
}
#endif
}
else
if
(
version_num
==
89
)
{
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
if
(
version_num
==
89
)
{
// Ada Lovelace
cutlass_scaled_mm_sm89
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
}
else
if
(
version_num
>=
80
)
{
return
;
}
if
(
version_num
>=
80
)
{
// Ampere
cutlass_scaled_mm_sm80
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
}
else
{
// Turing
TORCH_CHECK
(
version_num
>=
75
);
cutlass_scaled_mm_sm75
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
return
;
}
// Turing
TORCH_CHECK
(
version_num
>=
75
);
cutlass_scaled_mm_sm75
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
#endif
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"No compiled cutlass_scaled_mm for a compute capability less than "
"CUDA device capability: "
,
version_num
);
}
void
cutlass_scaled_mm_azp
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
...
...
@@ -174,25 +187,38 @@ void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
"currently bias dtype must match output dtype "
,
c
.
dtype
());
at
::
cuda
::
OptionalCUDAGuard
const
device_guard
(
device_of
(
a
));
int32_t
version_num
=
get_sm_version_num
();
if
(
version_num
>=
90
)
{
// Hopper
// Guard against compilation issues for sm90 kernels
#
if
defined CUDA_VERSION && CUDA_VERSION >= 12000
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
if
(
version_num
>=
90
)
{
cutlass_scaled_mm_azp_sm90
(
c
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
azp
,
bias
);
#else
cutlass_scaled_mm_azp_sm80
(
c
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
azp
,
bias
);
return
;
}
#endif
}
else
if
(
version_num
==
89
)
{
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
if
(
version_num
==
89
)
{
// Ada Lovelace
cutlass_scaled_mm_azp_sm89
(
c
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
azp
,
bias
);
}
else
if
(
version_num
>=
80
)
{
return
;
}
if
(
version_num
>=
80
)
{
// Ampere
cutlass_scaled_mm_azp_sm80
(
c
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
azp
,
bias
);
}
else
{
// Turing
TORCH_CHECK
(
version_num
>=
75
);
cutlass_scaled_mm_azp_sm75
(
c
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
azp
,
bias
);
return
;
}
// Turing
TORCH_CHECK
(
version_num
>=
75
);
cutlass_scaled_mm_azp_sm75
(
c
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
azp
,
bias
);
return
;
#endif
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"No compiled cutlass_scaled_mm_azp for a compute capability less than "
"CUDA device capability: "
,
version_num
);
}
\ No newline at end of file
csrc/quantization/fp8/fp8_marlin.cu
View file @
aeb37c2a
...
...
@@ -22,6 +22,8 @@
#include "../gptq_marlin/marlin.cuh"
#include "../gptq_marlin/marlin_dtypes.cuh"
#include "core/registration.h"
using
namespace
marlin
;
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
...
...
@@ -1303,3 +1305,7 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
}
#endif
TORCH_LIBRARY_IMPL_EXPAND
(
TORCH_EXTENSION_NAME
,
CUDA
,
m
)
{
m
.
impl
(
"fp8_marlin_gemm"
,
&
fp8_marlin_gemm
);
}
\ No newline at end of file
csrc/quantization/gptq_marlin/awq_marlin_repack.cu
View file @
aeb37c2a
#include "marlin.cuh"
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
namespace
marlin
{
template
<
int
const
num_threads
,
int
const
num_bits
,
bool
const
has_perm
>
__global__
void
awq_marlin_repack_kernel
(
uint32_t
const
*
__restrict__
b_q_weight_ptr
,
uint32_t
*
__restrict__
out_ptr
,
int
size_k
,
int
size_n
)
{}
}
// namespace marlin
torch
::
Tensor
awq_marlin_repack
(
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
perm
,
int64_t
size_k
,
int64_t
size_n
,
int64_t
num_bits
)
{
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0"
);
return
torch
::
empty
({
1
,
1
});
}
#else
#include "core/registration.h"
namespace
marlin
{
...
...
@@ -122,7 +103,7 @@ __global__ void awq_marlin_repack_kernel(
}
uint32_t
vals
[
8
];
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
int
cur_elem
=
tc_row
+
tc_offsets
[
i
];
...
...
@@ -143,7 +124,7 @@ __global__ void awq_marlin_repack_kernel(
constexpr
int
pack_idx
[
8
]
=
{
0
,
2
,
4
,
6
,
1
,
3
,
5
,
7
};
uint32_t
res
=
0
;
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
res
|=
vals
[
pack_idx
[
i
]]
<<
(
i
*
4
);
}
...
...
@@ -155,7 +136,7 @@ __global__ void awq_marlin_repack_kernel(
uint32_t
res1
=
0
;
uint32_t
res2
=
0
;
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
res1
|=
vals
[
pack_idx
[
i
]]
<<
(
i
*
8
);
res2
|=
vals
[
4
+
pack_idx
[
i
]]
<<
(
i
*
8
);
...
...
@@ -167,21 +148,21 @@ __global__ void awq_marlin_repack_kernel(
};
auto
start_pipes
=
[
&
](
int
k_tile_id
,
int
n_tile_id
)
{
#pragma unroll
#pragma unroll
for
(
int
pipe
=
0
;
pipe
<
repack_stages
-
1
;
pipe
++
)
{
fetch_to_shared
(
pipe
,
k_tile_id
,
n_tile_id
+
pipe
);
}
wait_for_stage
();
};
#pragma unroll
#pragma unroll
for
(
int
k_tile_id
=
start_k_tile
;
k_tile_id
<
finish_k_tile
;
k_tile_id
++
)
{
int
n_tile_id
=
0
;
start_pipes
(
k_tile_id
,
n_tile_id
);
while
(
n_tile_id
<
n_tiles
)
{
#pragma unroll
#pragma unroll
for
(
int
pipe
=
0
;
pipe
<
repack_stages
;
pipe
++
)
{
fetch_to_shared
((
pipe
+
repack_stages
-
1
)
%
repack_stages
,
k_tile_id
,
n_tile_id
+
pipe
+
repack_stages
-
1
);
...
...
@@ -195,15 +176,15 @@ __global__ void awq_marlin_repack_kernel(
}
// namespace marlin
#define CALL_IF(NUM_BITS) \
else if (num_bits == NUM_BITS) { \
cudaFuncSetAttribute( \
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS> \
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, out_ptr, size_k, size_n); \
}
#define CALL_IF(NUM_BITS) \
else if (num_bits == NUM_BITS) { \
cudaFuncSetAttribute( \
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS> \
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, out_ptr, size_k, size_n); \
}
torch
::
Tensor
awq_marlin_repack
(
torch
::
Tensor
&
b_q_weight
,
int64_t
size_k
,
int64_t
size_n
,
int64_t
num_bits
)
{
...
...
@@ -266,8 +247,6 @@ torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
return
out
;
}
#endif
torch
::
Tensor
awq_marlin_repack_meta
(
torch
::
Tensor
&
b_q_weight
,
c10
::
SymInt
size_k
,
c10
::
SymInt
size_n
,
int64_t
num_bits
)
{
...
...
@@ -279,3 +258,11 @@ torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight,
{
size_k
/
marlin
::
tile_size
,
size_n
*
marlin
::
tile_size
/
pack_factor
},
options
);
}
TORCH_LIBRARY_IMPL_EXPAND
(
TORCH_EXTENSION_NAME
,
CUDA
,
m
)
{
m
.
impl
(
"awq_marlin_repack"
,
&
awq_marlin_repack
);
}
TORCH_LIBRARY_IMPL_EXPAND
(
TORCH_EXTENSION_NAME
,
Meta
,
m
)
{
m
.
impl
(
"awq_marlin_repack"
,
&
awq_marlin_repack_meta
);
}
\ No newline at end of file
csrc/quantization/gptq_marlin/gptq_marlin.cu
View file @
aeb37c2a
...
...
@@ -23,6 +23,8 @@
#include "marlin_dtypes.cuh"
#include "core/scalar_type.hpp"
#include "core/registration.h"
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
static_assert(std::is_same<scalar_t, half>::value || \
std::is_same<scalar_t, nv_bfloat16>::value, \
...
...
@@ -2297,3 +2299,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
}
#endif
TORCH_LIBRARY_IMPL_EXPAND
(
TORCH_EXTENSION_NAME
,
CUDA
,
m
)
{
m
.
impl
(
"gptq_marlin_gemm"
,
&
gptq_marlin_gemm
);
}
\ No newline at end of file
csrc/quantization/gptq_marlin/gptq_marlin_repack.cu
View file @
aeb37c2a
#include "marlin.cuh"
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
namespace
marlin
{
template
<
int
const
num_threads
,
int
const
num_bits
,
bool
const
has_perm
>
__global__
void
gptq_marlin_repack_kernel
(
uint32_t
const
*
__restrict__
b_q_weight_ptr
,
uint32_t
const
*
__restrict__
perm_ptr
,
uint32_t
*
__restrict__
out_ptr
,
int
size_k
,
int
size_n
)
{}
}
// namespace marlin
torch
::
Tensor
gptq_marlin_repack
(
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
perm
,
int64_t
size_k
,
int64_t
size_n
,
int64_t
num_bits
)
{
TORCH_CHECK_NOT_IMPLEMENTED
(
false
,
"marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0"
);
return
torch
::
empty
({
1
,
1
});
}
#else
#include "core/registration.h"
namespace
marlin
{
...
...
@@ -174,13 +154,13 @@ __global__ void gptq_marlin_repack_kernel(
uint32_t
b1_vals
[
tile_ints
];
uint32_t
b2_vals
[
tile_ints
];
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
tile_ints
;
i
++
)
{
b1_vals
[
i
]
=
sh_stage_int_ptr
[
cur_n
+
sh_stride
*
i
];
b2_vals
[
i
]
=
sh_stage_int_ptr
[
cur_n
+
8
+
sh_stride
*
i
];
}
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
int
cur_elem
=
tc_row
+
tc_offsets
[
i
];
int
cur_int
=
cur_elem
/
pack_factor
;
...
...
@@ -200,7 +180,7 @@ __global__ void gptq_marlin_repack_kernel(
constexpr
int
pack_idx
[
8
]
=
{
0
,
2
,
4
,
6
,
1
,
3
,
5
,
7
};
uint32_t
res
=
0
;
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
res
|=
vals
[
pack_idx
[
i
]]
<<
(
i
*
4
);
}
...
...
@@ -212,7 +192,7 @@ __global__ void gptq_marlin_repack_kernel(
uint32_t
res1
=
0
;
uint32_t
res2
=
0
;
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
res1
|=
vals
[
pack_idx
[
i
]]
<<
(
i
*
8
);
res2
|=
vals
[
4
+
pack_idx
[
i
]]
<<
(
i
*
8
);
...
...
@@ -224,14 +204,14 @@ __global__ void gptq_marlin_repack_kernel(
};
auto
start_pipes
=
[
&
](
int
k_tile_id
,
int
n_tile_id
)
{
#pragma unroll
#pragma unroll
for
(
int
pipe
=
0
;
pipe
<
repack_stages
-
1
;
pipe
++
)
{
fetch_to_shared
(
pipe
,
k_tile_id
,
n_tile_id
+
pipe
);
}
wait_for_stage
();
};
#pragma unroll
#pragma unroll
for
(
int
k_tile_id
=
start_k_tile
;
k_tile_id
<
finish_k_tile
;
k_tile_id
++
)
{
int
n_tile_id
=
0
;
...
...
@@ -242,7 +222,7 @@ __global__ void gptq_marlin_repack_kernel(
start_pipes
(
k_tile_id
,
n_tile_id
);
while
(
n_tile_id
<
n_tiles
)
{
#pragma unroll
#pragma unroll
for
(
int
pipe
=
0
;
pipe
<
repack_stages
;
pipe
++
)
{
fetch_to_shared
((
pipe
+
repack_stages
-
1
)
%
repack_stages
,
k_tile_id
,
n_tile_id
+
pipe
+
repack_stages
-
1
);
...
...
@@ -256,17 +236,17 @@ __global__ void gptq_marlin_repack_kernel(
}
// namespace marlin
#define CALL_IF(NUM_BITS, HAS_PERM) \
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
cudaFuncSetAttribute( \
marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
HAS_PERM>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
HAS_PERM> \
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
}
#define CALL_IF(NUM_BITS, HAS_PERM) \
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
cudaFuncSetAttribute( \
marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
HAS_PERM>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
HAS_PERM> \
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
}
torch
::
Tensor
gptq_marlin_repack
(
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
perm
,
int64_t
size_k
,
int64_t
size_n
,
...
...
@@ -341,8 +321,6 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
return
out
;
}
#endif
torch
::
Tensor
gptq_marlin_repack_meta
(
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
perm
,
c10
::
SymInt
size_k
,
c10
::
SymInt
size_n
,
int64_t
num_bits
)
{
...
...
@@ -354,3 +332,11 @@ torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight,
{
size_k
/
marlin
::
tile_size
,
size_n
*
marlin
::
tile_size
/
pack_factor
},
options
);
}
TORCH_LIBRARY_IMPL_EXPAND
(
TORCH_EXTENSION_NAME
,
CUDA
,
m
)
{
m
.
impl
(
"gptq_marlin_repack"
,
&
gptq_marlin_repack
);
}
TORCH_LIBRARY_IMPL_EXPAND
(
TORCH_EXTENSION_NAME
,
Meta
,
m
)
{
m
.
impl
(
"gptq_marlin_repack"
,
&
gptq_marlin_repack_meta
);
}
\ No newline at end of file
csrc/quantization/machete/generate.py
View file @
aeb37c2a
...
...
@@ -284,7 +284,7 @@ mm_impl_template = create_template(IMPL_TEMPLATE)
prepack_dispatch_template
=
create_template
(
PREPACK_TEMPLATE
)
def
create_sources
(
impl_config
:
ImplConfig
,
num_impl_files
=
2
):
def
create_sources
(
impl_config
:
ImplConfig
,
num_impl_files
=
1
):
sources
=
[]
type_name
=
generate_type_signature
(
impl_config
.
type_config
)
...
...
csrc/quantization/machete/machete_prepack_kernel.cuh
View file @
aeb37c2a
...
...
@@ -34,10 +34,9 @@ static __global__ void prepack_B_kernel(BInTensor B_in,
}
template
<
typename
PrepackedLayoutB
,
typename
InLayout
>
static
void
prepack_B
(
cudaStream_t
stream
,
typename
PrepackedLayoutB
::
ElementB
const
*
B_in_ptr
,
InLayout
B_layout
,
typename
PrepackedLayoutB
::
ElementB
*
B_out_ptr
)
{
static
void
prepack_B_template
(
cudaStream_t
stream
,
typename
PrepackedLayoutB
::
ElementB
const
*
B_in_ptr
,
InLayout
B_layout
,
typename
PrepackedLayoutB
::
ElementB
*
B_out_ptr
)
{
using
TileShapeNKL
=
decltype
(
append
(
typename
PrepackedLayoutB
::
PPBlockShape_NK
{},
_1
{}));
auto
ilvd_NKbNbKL_to_offset
=
...
...
csrc/quantization/machete/machete_prepack_launcher.cuh
View file @
aeb37c2a
...
...
@@ -55,8 +55,8 @@ torch::Tensor prepack_impl(torch::Tensor const B) {
// Allocate output
torch
::
Tensor
D
=
torch
::
empty_like
(
B
,
{},
at
::
MemoryFormat
::
Contiguous
);
prepack_B
<
PrepackedLayoutB
>
(
stream
,
B_ptr
,
layout_Bt
,
static_cast
<
ElementB
*>
(
D
.
mutable_data_ptr
()));
prepack_B
_template
<
PrepackedLayoutB
>
(
stream
,
B_ptr
,
layout_Bt
,
static_cast
<
ElementB
*>
(
D
.
mutable_data_ptr
()));
return
D
;
};
...
...
csrc/quantization/machete/machete_pytorch.cu
View file @
aeb37c2a
...
...
@@ -2,6 +2,8 @@
#include "machete_prepack_launcher.cuh"
#include "core/scalar_type.hpp"
#include "core/registration.h"
namespace
machete
{
using
namespace
vllm
;
...
...
@@ -78,14 +80,16 @@ torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
}
torch
::
Tensor
prepack_B
(
torch
::
Tensor
const
&
B
,
ScalarTypeTorchPtr
const
&
btype
)
{
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12
vllm
::
ScalarTypeTorchPtr
const
&
btype
)
{
return
scalar_type_dispatch
(
*
btype
,
[
&
](
auto
BType
)
{
return
PrepackBDispatcher
<
half_t
,
decltype
(
BType
),
half_t
>::
dispatch
(
B
);
});
#else
TORCH_CHECK
(
false
,
"Machete requires CUDA 12.0 or later"
);
#endif
}
TORCH_LIBRARY_IMPL_EXPAND
(
TORCH_EXTENSION_NAME
,
CUDA
,
m
)
{
m
.
impl
(
"machete_prepack_B"
,
&
prepack_B
);
m
.
impl
(
"machete_gemm"
,
&
gemm
);
m
.
impl
(
"machete_supported_schedules"
,
&
supported_schedules
);
}
};
// namespace machete
csrc/quantization/marlin/dense/marlin_cuda_kernel.cu
View file @
aeb37c2a
...
...
@@ -26,6 +26,7 @@
#include <iostream>
#include "common/base.h"
#include "core/registration.h"
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include "common/mem.h"
...
...
@@ -1066,3 +1067,7 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
return
c
;
}
TORCH_LIBRARY_IMPL_EXPAND
(
TORCH_EXTENSION_NAME
,
CUDA
,
m
)
{
m
.
impl
(
"marlin_gemm"
,
&
marlin_gemm
);
}
csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu
View file @
aeb37c2a
...
...
@@ -30,6 +30,7 @@
#include <iostream>
#include "../dense/common/base.h"
#include "core/registration.h"
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include "../dense/common/mem.h"
...
...
@@ -1241,3 +1242,7 @@ torch::Tensor marlin_qqq_gemm(torch::Tensor const& a,
return
d
;
}
TORCH_LIBRARY_IMPL_EXPAND
(
TORCH_EXTENSION_NAME
,
CUDA
,
m
)
{
m
.
impl
(
"marlin_qqq_gemm"
,
&
marlin_qqq_gemm
);
}
csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu
View file @
aeb37c2a
...
...
@@ -28,6 +28,7 @@
#include "common/base.h"
#include "core/scalar_type.hpp"
#include "core/registration.h"
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
...
...
@@ -1134,3 +1135,7 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
return
c
;
}
TORCH_LIBRARY_IMPL_EXPAND
(
TORCH_EXTENSION_NAME
,
CUDA
,
m
)
{
m
.
impl
(
"gptq_marlin_24_gemm"
,
&
gptq_marlin_24_gemm
);
}
csrc/torch_bindings.cpp
View file @
aeb37c2a
...
...
@@ -167,7 +167,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops
.
def
(
"marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
"Tensor! workspace, int size_m, int size_n, int size_k) -> Tensor"
);
ops
.
impl
(
"marlin_gemm"
,
torch
::
kCUDA
,
&
marlin_gemm
);
// conditionally compiled so impl in source file
// Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ.
ops
.
def
(
...
...
@@ -175,22 +175,24 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor b_scales, Tensor workspace, "
"__torch__.torch.classes._core_C.ScalarType b_q_type, "
"int size_m, int size_n, int size_k) -> Tensor"
);
ops
.
impl
(
"gptq_marlin_24_gemm"
,
torch
::
kCUDA
,
&
gptq_marlin_24_gemm
);
// conditionally compiled so impl in source file
// Machete (Dense) Optimized Mixed Precision GEMM for Hopper.
ops
.
def
(
"machete_supported_schedules"
,
&
machete
::
supported_schedules
);
ops
.
def
(
"machete_supported_schedules("
" __torch__.torch.classes._core_C.ScalarType btype"
") -> str[]"
);
ops
.
def
(
"machete_gemm(Tensor A, Tensor B,"
" __torch__.torch.classes._core_C.ScalarType btype,"
" Tensor? scales, Tensor? zeros, int? group_size,"
" Tensor? C, float? alpha, float? beta, str? schedule)"
"-> Tensor"
);
ops
.
impl
(
"machete_gemm"
,
torch
::
kCUDA
,
&
machete
::
gemm
);
ops
.
def
(
"machete_prepack_B(Tensor B,"
" __torch__.torch.classes._core_C.ScalarType btype)"
"-> Tensor"
);
ops
.
impl
(
"machete_prepack_B"
,
torch
::
kCUDA
,
&
machete
::
prepack_B
);
// conditionally compiled so impl registration is in source file
ops
.
def
(
"permute_cols(Tensor A, Tensor perm) -> Tensor"
);
ops
.
impl
(
"permute_cols"
,
torch
::
kCUDA
,
&
permute_cols
);
...
...
@@ -202,21 +204,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"__torch__.torch.classes._core_C.ScalarType b_q_type, "
"int size_m, int size_n, int size_k, bool is_k_full, "
"bool has_zp, bool use_fp32_reduce) -> Tensor"
);
ops
.
impl
(
"gptq_marlin_gemm"
,
torch
::
kCUDA
,
&
gptq_marlin_gemm
);
// conditionally compiled so impl registration is in source file
// gptq_marlin repack from GPTQ.
ops
.
def
(
"gptq_marlin_repack(Tensor b_q_weight, Tensor perm, "
"SymInt size_k, SymInt size_n, int num_bits) -> Tensor"
);
ops
.
impl
(
"gptq_marlin_repack"
,
torch
::
kCUDA
,
&
gptq_marlin_repack
);
ops
.
impl
(
"gptq_marlin_repack"
,
torch
::
kMeta
,
&
gptq_marlin_repack_meta
);
// conditionally compiled so impl registrations are in source file
// awq_marlin repack from AWQ.
ops
.
def
(
"awq_marlin_repack(Tensor b_q_weight, SymInt size_k, "
"SymInt size_n, int num_bits) -> Tensor"
);
ops
.
impl
(
"awq_marlin_repack"
,
torch
::
kCUDA
,
&
awq_marlin_repack
);
ops
.
impl
(
"awq_marlin_repack"
,
torch
::
kMeta
,
&
awq_marlin_repack_meta
);
// conditionally compiled so impl registrations are in source file
// Dequantization for GGML.
ops
.
def
(
"ggml_dequantize(Tensor W, int type, int m, int n) -> Tensor"
);
...
...
@@ -237,7 +237,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, "
"Tensor! workspace, int num_bits, int size_m, int size_n, "
"int size_k) -> Tensor"
);
ops
.
impl
(
"fp8_marlin_gemm"
,
torch
::
kCUDA
,
&
fp8_marlin_gemm
);
// conditionally compiled so impl registration is in source file
// marlin_qqq_gemm for QQQ.
ops
.
def
(
...
...
@@ -245,7 +245,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"Tensor s_tok, Tensor s_ch, Tensor s_group, "
"Tensor! workspace, int size_m, int size_n, "
"int size_k) -> Tensor"
);
ops
.
impl
(
"marlin_qqq_gemm"
,
torch
::
kCUDA
,
&
marlin_qqq_gemm
);
// conditionally compiled so impl registration is in source file
// CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column
// quantization, as well as bias
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment