Commit 6d2051cc authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.6.3.post1' into v0.6.3.post1-dev

parents 2c7f740a a2c71c54
include(FetchContent)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS ON)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON) set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
set(CMAKE_CXX_STANDARD 17)
# #
# Define environment variables for special configurations # Define environment variables for special configurations
...@@ -82,9 +85,39 @@ else() ...@@ -82,9 +85,39 @@ else()
message(FATAL_ERROR "vLLM CPU backend requires AVX512 or AVX2 or Power9+ ISA support.") message(FATAL_ERROR "vLLM CPU backend requires AVX512 or AVX2 or Power9+ ISA support.")
endif() endif()
#
# oneDNN provides the W8A8 GEMM kernels. It is only fetched and built when
# AVX512 support was found and has not been explicitly disabled.
#
if (AVX512_FOUND AND NOT AVX512_DISABLED)
  # Make `option()` calls in the fetched project honour the normal variables
  # set below (policy CMP0077).
  set(CMAKE_POLICY_DEFAULT_CMP0077 NEW)

  # Build configuration handed to oneDNN: a static library restricted to the
  # inference workload and the MATMUL/REORDER primitives, with docs, examples,
  # tests, graph API and profiling hooks switched off.
  set(ONEDNN_LIBRARY_TYPE STATIC)
  set(ONEDNN_ENABLE_WORKLOAD INFERENCE)
  set(ONEDNN_ENABLE_PRIMITIVE MATMUL REORDER)
  set(ONEDNN_BUILD_DOC OFF)
  set(ONEDNN_BUILD_EXAMPLES OFF)
  set(ONEDNN_BUILD_TESTS OFF)
  set(ONEDNN_BUILD_GRAPH OFF)
  set(ONEDNN_ENABLE_JIT_PROFILING OFF)
  set(ONEDNN_ENABLE_ITT_TASKS OFF)
  set(ONEDNN_ENABLE_MAX_CPU_ISA OFF)
  set(ONEDNN_ENABLE_CPU_ISA_HINTS OFF)

  # Pin the release tag and use a shallow clone.
  FetchContent_Declare(oneDNN
    GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git
    GIT_TAG v3.5.3
    GIT_PROGRESS TRUE
    GIT_SHALLOW TRUE)
  FetchContent_MakeAvailable(oneDNN)

  # Link the extension against the oneDNN target.
  list(APPEND LIBS dnnl)
endif()
message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}") message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}")
list(APPEND LIBS dnnl numa) list(APPEND LIBS numa)
# #
# _C extension # _C extension
......
...@@ -138,10 +138,181 @@ macro(string_to_ver OUT_VER IN_STR) ...@@ -138,10 +138,181 @@ macro(string_to_ver OUT_VER IN_STR)
string(REGEX REPLACE "\([0-9]+\)\([0-9]\)" "\\1.\\2" ${OUT_VER} ${IN_STR}) string(REGEX REPLACE "\([0-9]+\)\([0-9]\)" "\\1.\\2" ${OUT_VER} ${IN_STR})
endmacro() endmacro()
#
# Strip every `-gencode arch=...` flag from `CMAKE_CUDA_FLAGS` and store the
# removed flags, as a list, in the variable whose name is passed as
# `CUDA_ARCH_FLAGS`.
#
# Arguments:
#   CUDA_ARCH_FLAGS: name of the output variable receiving the `-gencode` flags.
#
# Note: this is a macro (not a function) because it rewrites the caller's
# `CMAKE_CUDA_FLAGS` in place.
#
# Example:
#   CMAKE_CUDA_FLAGS="-Wall -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75"
#   clear_cuda_arches(CUDA_ARCH_FLAGS)
#   CUDA_ARCH_FLAGS="-gencode arch=compute_70,code=sm_70;-gencode arch=compute_75,code=sm_75"
#   CMAKE_CUDA_FLAGS="-Wall " (the space that preceded the first flag is kept)
#
macro(clear_cuda_arches CUDA_ARCH_FLAGS)
  # Collect all `-gencode` flags. `${CUDA_ARCH_FLAGS}` expands to the name the
  # caller passed, so the result lands in the caller's variable rather than in
  # one literally called `CUDA_ARCH_FLAGS`. The input is quoted so that an
  # empty `CMAKE_CUDA_FLAGS` is still passed as an (empty) argument.
  string(REGEX MATCHALL "-gencode arch=[^ ]+" ${CUDA_ARCH_FLAGS}
    "${CMAKE_CUDA_FLAGS}")

  # Remove the same flags from `CMAKE_CUDA_FLAGS`; the architectures are
  # expected to be re-applied by the caller through another mechanism.
  string(REGEX REPLACE "-gencode arch=[^ ]+ *" "" CMAKE_CUDA_FLAGS
    "${CMAKE_CUDA_FLAGS}")
endmacro()
#
# Turn a list of `-gencode` flags into a de-duplicated, ascending list of
# compute capability versions.
#
# For every entry of `CUDA_ARCH_FLAGS` the `<major><minor>[a]` code following
# `arch=compute_` is extracted and rewritten by `string_to_ver` (e.g. `75`
# becomes `7.5`). The collected versions are de-duplicated, sorted in natural
# ascending order and stored in the variable named by `OUT_ARCHES`.
#
# Arguments:
#   OUT_ARCHES:      name of the output variable (set in the caller's scope).
#   CUDA_ARCH_FLAGS: the list of `-gencode` flags itself (pass the value).
#
# Example:
#   CUDA_ARCH_FLAGS="-gencode arch=compute_75,code=sm_75;...;-gencode arch=compute_90a,code=sm_90a"
#   extract_unique_cuda_archs_ascending(OUT_ARCHES "${CUDA_ARCH_FLAGS}")
#   OUT_ARCHES="7.5;...;9.0a"
#
function(extract_unique_cuda_archs_ascending OUT_ARCHES CUDA_ARCH_FLAGS)
  set(_versions)
  foreach(_flag ${CUDA_ARCH_FLAGS})
    # On a match CMAKE_MATCH_1 holds the captured compute capability code;
    # otherwise the code is left empty.
    if (_flag MATCHES "arch=compute_\([0-9]+a?\)")
      set(_code ${CMAKE_MATCH_1})
    else()
      set(_code "")
    endif()
    string_to_ver(_version ${_code})
    list(APPEND _versions ${_version})
  endforeach()

  list(REMOVE_DUPLICATES _versions)
  list(SORT _versions COMPARE NATURAL ORDER ASCENDING)
  set(${OUT_ARCHES} ${_versions} PARENT_SCOPE)
endfunction()
#
# Append one `-gencode arch=<ARCH>,code=<CODE>` pair to the `COMPILE_OPTIONS`
# source property of every file in `SRCS`. The option is wrapped in a
# generator expression so it only applies when the file is compiled as CUDA.
#
# Keyword arguments:
#   SRCS: source files to annotate.
#   ARCH: value placed after `arch=`, e.g. "compute_75".
#   CODE: value placed after `code=`, e.g. "sm_75".
#
# Example:
#   set_gencode_flag_for_srcs(SRCS "foo.cu" ARCH "compute_75" CODE "sm_75")
#   appends "-gencode arch=compute_75,code=sm_75" to the CUDA compile options
#   of `foo.cu`.
#
# Note: this is a macro, so the helper variables it sets (`options`,
# `oneValueArgs`, `multiValueArgs`, `arg_*`, `_FLAG`) remain set in the
# caller's scope.
#
macro(set_gencode_flag_for_srcs)
  set(options)
  set(oneValueArgs ARCH CODE)
  set(multiValueArgs SRCS)
  cmake_parse_arguments(
    arg
    "${options}"
    "${oneValueArgs}"
    "${multiValueArgs}"
    ${ARGN})

  # Two-element list: the `-gencode` switch followed by its value.
  set(_FLAG -gencode arch=${arg_ARCH},code=${arg_CODE})

  set_property(SOURCE ${arg_SRCS} APPEND
    PROPERTY COMPILE_OPTIONS "$<$<COMPILE_LANGUAGE:CUDA>:${_FLAG}>")

  message(DEBUG "Setting gencode flag for ${arg_SRCS}: ${_FLAG}")
endmacro()
#
# For a list of source files, add one `-gencode` flag per requested CUDA
# architecture to each file's compile options (CUDA language only), and
# optionally one extra PTX (`code=compute_XX`) flag.
#
# Keyword arguments:
# SRCS: list of source files
# CUDA_ARCHS: list of CUDA architectures in the form `<major>.<minor>[letter]`
# BUILD_PTX_FOR_ARCH: optional architecture in the form `<major>.<minor>`.
# When given, and the highest entry of CUDA_ARCHS is greater than or equal
# to it, PTX is additionally built for that architecture
# (`arch=compute_XX,code=compute_XX`).
#
# NOTE(review): this is a macro that calls `set_gencode_flag_for_srcs`, which
# is also a macro and parses its arguments with the same `arg` prefix. Each
# nested call therefore re-sets `arg_SRCS`, `arg_ARCH` and `arg_CODE` in this
# scope; `arg_CUDA_ARCHS` and `arg_BUILD_PTX_FOR_ARCH` appear to be left
# untouched by it - confirm before renaming or reordering anything here.
#
macro(set_gencode_flags_for_srcs)
set(options)
set(oneValueArgs BUILD_PTX_FOR_ARCH)
set(multiValueArgs SRCS CUDA_ARCHS)
cmake_parse_arguments(arg "${options}" "${oneValueArgs}"
"${multiValueArgs}" ${ARGN} )
# One native-code flag per architecture: "8.6" -> arch=compute_86,code=sm_86.
foreach(_ARCH ${arg_CUDA_ARCHS})
string(REPLACE "." "" _ARCH "${_ARCH}")
set_gencode_flag_for_srcs(
SRCS ${arg_SRCS}
ARCH "compute_${_ARCH}"
CODE "sm_${_ARCH}")
endforeach()
# The value is expanded directly into `if()`: when BUILD_PTX_FOR_ARCH was not
# passed the expansion is empty and the condition is false.
if (${arg_BUILD_PTX_FOR_ARCH})
# After the ascending sort the last element is the highest architecture.
# NOTE(review): `list(GET ... -1)` is an error on an empty list, so this
# presumably relies on CUDA_ARCHS being non-empty whenever
# BUILD_PTX_FOR_ARCH is given - confirm against callers.
list(SORT arg_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
list(GET arg_CUDA_ARCHS -1 _HIGHEST_ARCH)
if (_HIGHEST_ARCH VERSION_GREATER_EQUAL ${arg_BUILD_PTX_FOR_ARCH})
string(REPLACE "." "" _PTX_ARCH "${arg_BUILD_PTX_FOR_ARCH}")
# `code=compute_XX` (rather than `sm_XX`) requests PTX output.
set_gencode_flag_for_srcs(
SRCS ${arg_SRCS}
ARCH "compute_${_PTX_ARCH}"
CODE "compute_${_PTX_ARCH}")
endif()
endif()
endmacro()
#
# For the given `SRC_CUDA_ARCHS` list of gencode versions in the form
# `<major>.<minor>[letter]` compute the "loose intersection" with the
# `TGT_CUDA_ARCHS` list of gencodes.
# The loose intersection is defined as:
#   { max{ x \in src | x <= y } | y \in tgt, { x \in src | x <= y } != {} }
# where `<=` is the version comparison operator.
# In other words, for each version in `TGT_CUDA_ARCHS` find the highest version
# in `SRC_CUDA_ARCHS` that is less or equal to the version in `TGT_CUDA_ARCHS`.
# We have special handling for 9.0a: if 9.0a is in `SRC_CUDA_ARCHS` it is
# removed from `SRC_CUDA_ARCHS` before matching, and it is added to the result
# only if 9.0 is in `TGT_CUDA_ARCHS`.
#
# Arguments:
#   OUT_CUDA_ARCHS: name of the output variable (set in the caller's scope).
#   SRC_CUDA_ARCHS: list of architectures the sources support (pass the value).
#   TGT_CUDA_ARCHS: list of architectures being built for (pass the value).
#
# The result holds 9.0a first (when kept), followed by the matches in the
# order of `TGT_CUDA_ARCHS`, with duplicates removed.
#
# Example:
#   SRC_CUDA_ARCHS="7.5;8.0;8.6;9.0;9.0a"
#   TGT_CUDA_ARCHS="8.0;8.9;9.0"
#   cuda_archs_loose_intersection(OUT_CUDA_ARCHS "${SRC_CUDA_ARCHS}" "${TGT_CUDA_ARCHS}")
#   OUT_CUDA_ARCHS="9.0a;8.0;8.6;9.0"
#
function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
  list(REMOVE_DUPLICATES SRC_CUDA_ARCHS)

  # if 9.0a is in SRC_CUDA_ARCHS and 9.0 is in TGT_CUDA_ARCHS then we should
  # remove 9.0a from SRC_CUDA_ARCHS and add 9.0a to _CUDA_ARCHS
  set(_CUDA_ARCHS)
  if ("9.0a" IN_LIST SRC_CUDA_ARCHS)
    list(REMOVE_ITEM SRC_CUDA_ARCHS "9.0a")
    if ("9.0" IN_LIST TGT_CUDA_ARCHS)
      set(_CUDA_ARCHS "9.0a")
    endif()
  endif()

  # Sorted ascending so the inner loop can stop at the first arch that is too
  # new for the current target.
  list(SORT SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)

  # for each ARCH in TGT_CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS
  # that is less or equal to ARCH. Iterate over the function's own
  # TGT_CUDA_ARCHS argument, not over whatever `CUDA_ARCHS` variable happens to
  # be visible from the caller's scope.
  foreach(_ARCH ${TGT_CUDA_ARCHS})
    set(_TMP_ARCH)
    foreach(_SRC_ARCH ${SRC_CUDA_ARCHS})
      if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH)
        set(_TMP_ARCH ${_SRC_ARCH})
      else()
        break()
      endif()
    endforeach()
    if (_TMP_ARCH)
      list(APPEND _CUDA_ARCHS ${_TMP_ARCH})
    endif()
  endforeach()

  list(REMOVE_DUPLICATES _CUDA_ARCHS)
  set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE)
endfunction()
# #
# Override the GPU architectures detected by cmake/torch and filter them by # Override the GPU architectures detected by cmake/torch and filter them by
# `GPU_SUPPORTED_ARCHES`. Sets the final set of architectures in # `GPU_SUPPORTED_ARCHES`. Sets the final set of architectures in
# `GPU_ARCHES`. # `GPU_ARCHES`. This only applies to the HIP language since for CUDA we set
# the architectures on a per file basis.
# #
# Note: this is defined as a macro since it updates `CMAKE_CUDA_FLAGS`. # Note: this is defined as a macro since it updates `CMAKE_CUDA_FLAGS`.
# #
...@@ -179,109 +350,7 @@ macro(override_gpu_arches GPU_ARCHES GPU_LANG GPU_SUPPORTED_ARCHES) ...@@ -179,109 +350,7 @@ macro(override_gpu_arches GPU_ARCHES GPU_LANG GPU_SUPPORTED_ARCHES)
"None of the detected ROCm architectures: ${HIP_ARCHITECTURES} is" "None of the detected ROCm architectures: ${HIP_ARCHITECTURES} is"
" supported. Supported ROCm architectures are: ${_GPU_SUPPORTED_ARCHES_LIST}.") " supported. Supported ROCm architectures are: ${_GPU_SUPPORTED_ARCHES_LIST}.")
endif() endif()
elseif(${GPU_LANG} STREQUAL "CUDA")
#
# Setup/process CUDA arch flags.
#
# The torch cmake setup hardcodes the detected architecture flags in
# `CMAKE_CUDA_FLAGS`. Since `CMAKE_CUDA_FLAGS` is a "global" variable, it
# can't modified on a per-target basis.
# So, all the `-gencode` flags need to be extracted and removed from
# `CMAKE_CUDA_FLAGS` for processing so they can be passed by another method.
# Since it's not possible to use `target_compiler_options` for adding target
# specific `-gencode` arguments, the target's `CUDA_ARCHITECTURES` property
# must be used instead. This requires repackaging the architecture flags
# into a format that cmake expects for `CUDA_ARCHITECTURES`.
#
# This is a bit fragile in that it depends on torch using `-gencode` as opposed
# to one of the other nvcc options to specify architectures.
#
# Note: torch uses the `TORCH_CUDA_ARCH_LIST` environment variable to override
# detected architectures.
#
message(DEBUG "initial CMAKE_CUDA_FLAGS: ${CMAKE_CUDA_FLAGS}")
# Extract all `-gencode` flags from `CMAKE_CUDA_FLAGS`
string(REGEX MATCHALL "-gencode arch=[^ ]+" _CUDA_ARCH_FLAGS
${CMAKE_CUDA_FLAGS})
# Remove all `-gencode` flags from `CMAKE_CUDA_FLAGS` since they will be modified
# and passed back via the `CUDA_ARCHITECTURES` property.
string(REGEX REPLACE "-gencode arch=[^ ]+ *" "" CMAKE_CUDA_FLAGS
${CMAKE_CUDA_FLAGS})
# If this error is triggered, it might mean that torch has changed how it sets
# up nvcc architecture code generation flags.
if (NOT _CUDA_ARCH_FLAGS)
message(FATAL_ERROR
"Could not find any architecture related code generation flags in "
"CMAKE_CUDA_FLAGS. (${CMAKE_CUDA_FLAGS})")
endif()
message(DEBUG "final CMAKE_CUDA_FLAGS: ${CMAKE_CUDA_FLAGS}")
message(DEBUG "arch flags: ${_CUDA_ARCH_FLAGS}")
# Initialize the architecture lists to empty.
set(${GPU_ARCHES})
# Process each `gencode` flag.
foreach(_ARCH ${_CUDA_ARCH_FLAGS})
# For each flag, extract the version number and whether it refers to PTX
# or native code.
# Note: if a regex matches then `CMAKE_MATCH_1` holds the binding
# for that match.
string(REGEX MATCH "arch=compute_\([0-9]+a?\)" _COMPUTE ${_ARCH})
if (_COMPUTE)
set(_COMPUTE ${CMAKE_MATCH_1})
endif()
string(REGEX MATCH "code=sm_\([0-9]+a?\)" _SM ${_ARCH})
if (_SM)
set(_SM ${CMAKE_MATCH_1})
endif()
string(REGEX MATCH "code=compute_\([0-9]+a?\)" _CODE ${_ARCH})
if (_CODE)
set(_CODE ${CMAKE_MATCH_1})
endif()
# Make sure the virtual architecture can be matched.
if (NOT _COMPUTE)
message(FATAL_ERROR
"Could not determine virtual architecture from: ${_ARCH}.")
endif()
# One of sm_ or compute_ must exist.
if ((NOT _SM) AND (NOT _CODE))
message(FATAL_ERROR
"Could not determine a codegen architecture from: ${_ARCH}.")
endif()
if (_SM)
# -real suffix let CMake to only generate elf code for the kernels.
# we want this, otherwise the added ptx (default) will increase binary size.
set(_VIRT "-real")
set(_CODE_ARCH ${_SM})
else()
# -virtual suffix let CMake to generate ptx code for the kernels.
set(_VIRT "-virtual")
set(_CODE_ARCH ${_CODE})
endif()
# Check if the current version is in the supported arch list.
string_to_ver(_CODE_VER ${_CODE_ARCH})
if (NOT _CODE_VER IN_LIST _GPU_SUPPORTED_ARCHES_LIST)
message(STATUS "discarding unsupported CUDA arch ${_VER}.")
continue()
endif()
# Add it to the arch list.
list(APPEND ${GPU_ARCHES} "${_CODE_ARCH}${_VIRT}")
endforeach()
endif() endif()
message(STATUS "${GPU_LANG} target arches: ${${GPU_ARCHES}}")
endmacro() endmacro()
# #
......
...@@ -267,13 +267,16 @@ def get_neuron_sdk_version(run_lambda): ...@@ -267,13 +267,16 @@ def get_neuron_sdk_version(run_lambda):
def get_vllm_version(): def get_vllm_version():
try: from vllm import __version__, __version_tuple__
import vllm
return vllm.__version__ + "@" + vllm.__commit__ if __version__ == "dev":
except Exception: return "N/A (dev)"
# old version of vllm does not have __commit__
return 'N/A' if len(__version_tuple__) == 4: # dev build
git_sha = __version_tuple__[-1][1:] # type: ignore
return f"{__version__} (git sha: {git_sha}"
return __version__
def summarize_vllm_build_flags(): def summarize_vllm_build_flags():
# This could be a static method if the flags are constant, or dynamic if you need to check environment variables, etc. # This could be a static method if the flags are constant, or dynamic if you need to check environment variables, etc.
......
#pragma once
// Logical implication "p implies q": evaluates to false only when `p` is true
// and `q` is false. `q` is not evaluated when `p` is false (short-circuit).
#define VLLM_IMPLIES(p, q) (!(p) || (q))
...@@ -12,6 +12,11 @@ ...@@ -12,6 +12,11 @@
// could be a macro instead of a literal token. // could be a macro instead of a literal token.
#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE) #define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
// A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME
// could be a macro instead of a literal token. NAME, DEVICE and MODULE are
// forwarded to TORCH_LIBRARY_IMPL unchanged; the extra macro layer only exists
// so that the arguments are macro-expanded first.
#define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \
TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE)
// REGISTER_EXTENSION allows the shared library to be loaded and initialized // REGISTER_EXTENSION allows the shared library to be loaded and initialized
// via python's import statement. // via python's import statement.
#define REGISTER_EXTENSION(NAME) \ #define REGISTER_EXTENSION(NAME) \
......
...@@ -265,6 +265,30 @@ struct FP32Vec8 : public Vec<FP32Vec8> { ...@@ -265,6 +265,30 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
void save(float *ptr) const { _mm256_storeu_ps(ptr, reg); } void save(float *ptr) const { _mm256_storeu_ps(ptr, reg); }
}; };
#ifdef __AVX512F__
// Sixteen packed 32-bit integers held in a single 512-bit AVX-512 register.
struct INT32Vec16 : public Vec<INT32Vec16> {
  constexpr static int VEC_ELEM_NUM = 16;

  // Overlays the register with a plain array for element-wise access.
  union AliasReg {
    __m512i reg;
    int32_t values[VEC_ELEM_NUM];
  };

  __m512i reg;

  // Loads VEC_ELEM_NUM 32-bit integers starting at data_ptr (unaligned load).
  explicit INT32Vec16(const void* data_ptr)
      : reg(_mm512_loadu_epi32(data_ptr)) {}

  // Stores all VEC_ELEM_NUM elements to ptr (unaligned store).
  void save(int32_t* ptr) const { _mm512_storeu_epi32(ptr, reg); }

  // Stores only the first elem_num elements to ptr; the remaining lanes are
  // masked out and memory behind them is left untouched. elem_num must be
  // non-negative; values of VEC_ELEM_NUM or more store the whole vector.
  void save(int32_t* ptr, const int elem_num) const {
    // Build the lane mask in 64 bits. The previous form,
    // `0xFFFFFFFF >> (32 - elem_num)`, shifts a 32-bit value by 32 when
    // elem_num == 0, which is undefined behaviour; this form yields an empty
    // mask for 0 and the same low 16 bits for every elem_num in 1..32.
    const uint64_t low_bits = (uint64_t{1} << elem_num) - 1;
    __mmask16 mask = _cvtu32_mask16(static_cast<uint32_t>(low_bits & 0xFFFF));
    _mm512_mask_storeu_epi32(ptr, mask, reg);
  }
};
#endif
#ifdef __AVX512F__ #ifdef __AVX512F__
struct FP32Vec16 : public Vec<FP32Vec16> { struct FP32Vec16 : public Vec<FP32Vec16> {
constexpr static int VEC_ELEM_NUM = 16; constexpr static int VEC_ELEM_NUM = 16;
...@@ -283,8 +307,6 @@ struct FP32Vec16 : public Vec<FP32Vec16> { ...@@ -283,8 +307,6 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
explicit FP32Vec16(__m512 data) : reg(data) {} explicit FP32Vec16(__m512 data) : reg(data) {}
explicit FP32Vec16(const FP32Vec16 &data) : reg(data.reg) {}
explicit FP32Vec16(const FP32Vec4 &data) explicit FP32Vec16(const FP32Vec4 &data)
: reg((__m512)_mm512_inserti32x4( : reg((__m512)_mm512_inserti32x4(
_mm512_inserti32x4( _mm512_inserti32x4(
...@@ -303,6 +325,9 @@ struct FP32Vec16 : public Vec<FP32Vec16> { ...@@ -303,6 +325,9 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {} explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {}
explicit FP32Vec16(const INT32Vec16 &v)
: reg(_mm512_cvt_roundepi32_ps(v.reg, _MM_FROUND_TO_NEAREST_INT |_MM_FROUND_NO_EXC)) {}
FP32Vec16 operator*(const FP32Vec16 &b) const { FP32Vec16 operator*(const FP32Vec16 &b) const {
return FP32Vec16(_mm512_mul_ps(reg, b.reg)); return FP32Vec16(_mm512_mul_ps(reg, b.reg));
} }
...@@ -333,6 +358,16 @@ struct FP32Vec16 : public Vec<FP32Vec16> { ...@@ -333,6 +358,16 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
return FP32Vec16(_mm512_mask_max_ps(reg, mask, reg, b.reg)); return FP32Vec16(_mm512_mask_max_ps(reg, mask, reg, b.reg));
} }
FP32Vec16 min(const FP32Vec16& b) const {
return FP32Vec16(_mm512_min_ps(reg, b.reg));
}
FP32Vec16 min(const FP32Vec16& b, const int elem_num) const {
constexpr uint32_t M = 0xFFFFFFFF;
__mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num));
return FP32Vec16(_mm512_mask_min_ps(reg, mask, reg, b.reg));
}
FP32Vec16 abs() const { FP32Vec16 abs() const {
return FP32Vec16(_mm512_abs_ps(reg)); return FP32Vec16(_mm512_abs_ps(reg));
} }
...@@ -341,6 +376,8 @@ struct FP32Vec16 : public Vec<FP32Vec16> { ...@@ -341,6 +376,8 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
float reduce_max() const { return _mm512_reduce_max_ps(reg); } float reduce_max() const { return _mm512_reduce_max_ps(reg); }
float reduce_min() const { return _mm512_reduce_min_ps(reg); }
template <int group_size> float reduce_sub_sum(int idx) { template <int group_size> float reduce_sub_sum(int idx) {
static_assert(VEC_ELEM_NUM % group_size == 0); static_assert(VEC_ELEM_NUM % group_size == 0);
constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size)); constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size));
......
This diff is collapsed.
...@@ -11,6 +11,13 @@ void int8_scaled_mm(torch::Tensor& c, const torch::Tensor& a, ...@@ -11,6 +11,13 @@ void int8_scaled_mm(torch::Tensor& c, const torch::Tensor& a,
const torch::Tensor& b_scales, const torch::Tensor& b_scales,
const c10::optional<torch::Tensor>& bias); const c10::optional<torch::Tensor>& bias);
// w8a8 GEMM entry point with asymmetric-quantization ("azp") inputs. It is
// registered further down in this file as the CPU implementation of the
// `cutlass_scaled_mm_azp` op. `c` is the output tensor; `azp` and `bias` are
// optional. NOTE(review): the exact meaning of `azp_adj` / `azp` is defined by
// the implementation, which is not visible here - confirm there.
void int8_scaled_mm_azp(torch::Tensor& c, const torch::Tensor& a,
const torch::Tensor& b, const torch::Tensor& a_scales,
const torch::Tensor& b_scales,
const torch::Tensor& azp_adj,
const c10::optional<torch::Tensor>& azp,
const c10::optional<torch::Tensor>& bias);
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// vLLM custom ops // vLLM custom ops
...@@ -111,6 +118,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -111,6 +118,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor b, Tensor a_scales," " Tensor b, Tensor a_scales,"
" Tensor b_scales, Tensor? bias) -> ()"); " Tensor b_scales, Tensor? bias) -> ()");
ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm); ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm);
// w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
// quantization.
ops.def(
"cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
" Tensor b, Tensor a_scales,"
" Tensor b_scales, Tensor azp_adj,"
" Tensor? azp, Tensor? bias) -> ()");
ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp);
#endif #endif
} }
......
This diff is collapsed.
...@@ -13,6 +13,7 @@ struct ConvParamsBase { ...@@ -13,6 +13,7 @@ struct ConvParamsBase {
using index_t = uint32_t; using index_t = uint32_t;
int batch, dim, seqlen, width; int batch, dim, seqlen, width;
int64_t pad_slot_id;
bool silu_activation; bool silu_activation;
index_t x_batch_stride; index_t x_batch_stride;
...@@ -24,6 +25,7 @@ struct ConvParamsBase { ...@@ -24,6 +25,7 @@ struct ConvParamsBase {
index_t out_c_stride; index_t out_c_stride;
index_t out_l_stride; index_t out_l_stride;
int conv_state_len;
index_t conv_state_batch_stride; index_t conv_state_batch_stride;
index_t conv_state_c_stride; index_t conv_state_c_stride;
index_t conv_state_l_stride; index_t conv_state_l_stride;
...@@ -35,6 +37,10 @@ struct ConvParamsBase { ...@@ -35,6 +37,10 @@ struct ConvParamsBase {
void *__restrict__ out_ptr; void *__restrict__ out_ptr;
void *__restrict__ conv_state_ptr; void *__restrict__ conv_state_ptr;
void *__restrict__ query_start_loc_ptr;
void *__restrict__ has_initial_state_ptr;
void *__restrict__ cache_indices_ptr;
int32_t *__restrict__ cache_seqlens;
// For the continuous batching case. Makes it so that the mamba state for // For the continuous batching case. Makes it so that the mamba state for
// the current batch doesn't need to be a contiguous tensor. // the current batch doesn't need to be a contiguous tensor.
...@@ -52,6 +58,11 @@ struct ConvParamsBase { ...@@ -52,6 +58,11 @@ struct ConvParamsBase {
index_t final_states_batch_stride; index_t final_states_batch_stride;
index_t final_states_l_stride; index_t final_states_l_stride;
index_t final_states_c_stride; index_t final_states_c_stride;
void * conv_states_ptr;
index_t conv_states_batch_stride;
index_t conv_states_l_stride;
index_t conv_states_c_stride;
}; };
......
...@@ -21,6 +21,7 @@ struct SSMParamsBase { ...@@ -21,6 +21,7 @@ struct SSMParamsBase {
int dim_ngroups_ratio; int dim_ngroups_ratio;
bool is_variable_B; bool is_variable_B;
bool is_variable_C; bool is_variable_C;
int64_t pad_slot_id;
bool delta_softplus; bool delta_softplus;
...@@ -54,10 +55,14 @@ struct SSMParamsBase { ...@@ -54,10 +55,14 @@ struct SSMParamsBase {
void *__restrict__ delta_ptr; void *__restrict__ delta_ptr;
void *__restrict__ delta_bias_ptr; void *__restrict__ delta_bias_ptr;
void *__restrict__ out_ptr; void *__restrict__ out_ptr;
void *__restrict__ x_ptr; void *__restrict__ ssm_states_ptr;
void *__restrict__ z_ptr; void *__restrict__ z_ptr;
void *__restrict__ out_z_ptr; void *__restrict__ out_z_ptr;
void *__restrict__ index_ptr;
void *__restrict__ query_start_loc_ptr;
void *__restrict__ cache_indices_ptr;
void *__restrict__ has_initial_state_ptr;
}; };
...@@ -201,7 +206,7 @@ inline __device__ void load_input(typename Ktraits::input_t *u, ...@@ -201,7 +206,7 @@ inline __device__ void load_input(typename Ktraits::input_t *u,
typename Ktraits::input_t (&u_vals)[Ktraits::kNItems], typename Ktraits::input_t (&u_vals)[Ktraits::kNItems],
typename Ktraits::BlockLoadT::TempStorage &smem_load, typename Ktraits::BlockLoadT::TempStorage &smem_load,
int seqlen) { int seqlen) {
if constexpr (Ktraits::kIsEvenLen) { if constexpr (Ktraits::kIsEvenLen && !Ktraits::kVarlen) {
auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_load); auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_load);
using vec_t = typename Ktraits::vec_t; using vec_t = typename Ktraits::vec_t;
typename Ktraits::BlockLoadVecT(smem_load_vec).Load( typename Ktraits::BlockLoadVecT(smem_load_vec).Load(
...@@ -217,21 +222,6 @@ inline __device__ void load_input(typename Ktraits::input_t *u, ...@@ -217,21 +222,6 @@ inline __device__ void load_input(typename Ktraits::input_t *u,
} }
} }
// Loads this thread's `Ktraits::kNItems` integer indices from `u` into
// `u_vals`, using the block-load type selected by `Ktraits`.
//
//   u               - source pointer to the indices
//   u_vals          - per-thread destination array
//   smem_load_index - temp storage for the block load
//   seqlen          - number of valid items, used only on the non-even path
//
// NOTE(review): BlockLoadIndexT / BlockLoadIndexVecT are declared in Ktraits,
// outside this view; presumably they are cub::BlockLoad instantiations - confirm.
template<typename Ktraits>
inline __device__ void load_index(int *u,
                                  int (&u_vals)[Ktraits::kNItems],
                                  typename Ktraits::BlockLoadIndexT::TempStorage &smem_load_index,
                                  int seqlen) {
    if constexpr (Ktraits::kIsEvenLen) {
        // Even-length case: reinterpret storage and data as uint4 and load
        // kNLoadsIndex vector elements per thread.
        auto& smem_load_index_vec = reinterpret_cast<typename Ktraits::BlockLoadIndexVecT::TempStorage&>(smem_load_index);
        // `typename` is required because BlockLoadIndexVecT is a dependent
        // type name; this matches the sibling loaders in this file.
        typename Ktraits::BlockLoadIndexVecT(smem_load_index_vec).Load(
            reinterpret_cast<uint4*>(u),
            reinterpret_cast<uint4(&)[Ktraits::kNLoadsIndex]>(u_vals)
        );
    } else {
        // Otherwise pass `seqlen` and a fill value of 0 to the load
        // (presumably a guarded load with a default for out-of-range items).
        typename Ktraits::BlockLoadIndexT(smem_load_index).Load(u, u_vals, seqlen, 0);
    }
}
template<typename Ktraits> template<typename Ktraits>
inline __device__ void load_weight(typename Ktraits::input_t *Bvar, inline __device__ void load_weight(typename Ktraits::input_t *Bvar,
...@@ -240,7 +230,7 @@ inline __device__ void load_weight(typename Ktraits::input_t *Bvar, ...@@ -240,7 +230,7 @@ inline __device__ void load_weight(typename Ktraits::input_t *Bvar,
int seqlen) { int seqlen) {
constexpr int kNItems = Ktraits::kNItems; constexpr int kNItems = Ktraits::kNItems;
typename Ktraits::input_t B_vals_load[kNItems]; typename Ktraits::input_t B_vals_load[kNItems];
if constexpr (Ktraits::kIsEvenLen) { if constexpr (Ktraits::kIsEvenLen && !Ktraits::kVarlen) {
auto& smem_load_weight_vec = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight); auto& smem_load_weight_vec = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight);
using vec_t = typename Ktraits::vec_t; using vec_t = typename Ktraits::vec_t;
typename Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load( typename Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load(
...@@ -263,7 +253,7 @@ inline __device__ void store_output(typename Ktraits::input_t *out, ...@@ -263,7 +253,7 @@ inline __device__ void store_output(typename Ktraits::input_t *out,
typename Ktraits::input_t write_vals[Ktraits::kNItems]; typename Ktraits::input_t write_vals[Ktraits::kNItems];
#pragma unroll #pragma unroll
for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; } for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; }
if constexpr (Ktraits::kIsEvenLen) { if constexpr (Ktraits::kIsEvenLen && !Ktraits::kVarlen) {
auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_store); auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_store);
using vec_t = typename Ktraits::vec_t; using vec_t = typename Ktraits::vec_t;
typename Ktraits::BlockStoreVecT(smem_store_vec).Store( typename Ktraits::BlockStoreVecT(smem_store_vec).Store(
......
This diff is collapsed.
This diff is collapsed.
#include "marlin_moe_kernel_ku4.h"
namespace marlin_moe {
// Tries the marlin MoE kernel configurations listed below for the vllm::kU4
// type. Returns true if one of the AWQ_CALL_IF_MOE branches was taken and
// false if none applied.
//
// NOTE(review): AWQ_CALL_IF_MOE is defined outside this view (presumably in
// the included header). It presumably expands to an `else if (...) { ... }`
// branch that refers to this function's parameters and to `has_zp` by name -
// confirm before renaming any of them.
// We return bool so we can create these different kernel calls as a sequence
// of if-elseif's.
bool call_marlin_moe_kernel_ku4(
vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
int m_block, int max_par, int cfg_max_m_blocks) {
// Not referenced directly in this body; presumably read by the macro below.
bool has_zp = true;
// Empty leading `if` so that every macro use below can chain onto it.
if (false) {
}
AWQ_CALL_IF_MOE(vllm::kU4, 16, 4, 256)
AWQ_CALL_IF_MOE(vllm::kU4, 8, 8, 256)
AWQ_CALL_IF_MOE(vllm::kU4, 8, 4, 128)
AWQ_CALL_IF_MOE(vllm::kU4, 4, 8, 128)
else {
// No listed configuration applied.
return false;
}
return true;
}
} // namespace marlin_moe
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment