Commit 6d2051cc authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.6.3.post1' into v0.6.3.post1-dev

parents 2c7f740a a2c71c54
include(FetchContent)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS ON)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
set(CMAKE_CXX_STANDARD 17)
#
# Define environment variables for special configurations
......@@ -82,9 +85,39 @@ else()
message(FATAL_ERROR "vLLM CPU backend requires AVX512 or AVX2 or Power9+ ISA support.")
endif()
#
# Build oneDNN for W8A8 GEMM kernels (only for x86-AVX512 platforms)
#
if (AVX512_FOUND AND NOT AVX512_DISABLED)
FetchContent_Declare(
oneDNN
GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git
GIT_TAG v3.5.3
GIT_PROGRESS TRUE
GIT_SHALLOW TRUE
)
set(ONEDNN_LIBRARY_TYPE "STATIC")
set(ONEDNN_BUILD_DOC "OFF")
set(ONEDNN_BUILD_EXAMPLES "OFF")
set(ONEDNN_BUILD_TESTS "OFF")
set(ONEDNN_ENABLE_WORKLOAD "INFERENCE")
set(ONEDNN_ENABLE_PRIMITIVE "MATMUL;REORDER")
set(ONEDNN_BUILD_GRAPH "OFF")
set(ONEDNN_ENABLE_JIT_PROFILING "OFF")
set(ONEDNN_ENABLE_ITT_TASKS "OFF")
set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF")
set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF")
set(CMAKE_POLICY_DEFAULT_CMP0077 NEW)
FetchContent_MakeAvailable(oneDNN)
list(APPEND LIBS dnnl)
endif()
message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}")
list(APPEND LIBS dnnl numa)
list(APPEND LIBS numa)
#
# _C extension
......
......@@ -138,10 +138,181 @@ macro(string_to_ver OUT_VER IN_STR)
string(REGEX REPLACE "\([0-9]+\)\([0-9]\)" "\\1.\\2" ${OUT_VER} ${IN_STR})
endmacro()
#
# Clear all `-gencode` flags from `CMAKE_CUDA_FLAGS` and store them in
# `CUDA_ARCH_FLAGS`.
#
# Example:
# CMAKE_CUDA_FLAGS="-Wall -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75"
# clear_cuda_arches(CUDA_ARCH_FLAGS)
# CUDA_ARCH_FLAGS="-gencode arch=compute_70,code=sm_70;-gencode arch=compute_75,code=sm_75"
# CMAKE_CUDA_FLAGS="-Wall"
#
macro(clear_cuda_arches CUDA_ARCH_FLAGS)
# Extract all `-gencode` flags from `CMAKE_CUDA_FLAGS`
string(REGEX MATCHALL "-gencode arch=[^ ]+" CUDA_ARCH_FLAGS
${CMAKE_CUDA_FLAGS})
# Remove all `-gencode` flags from `CMAKE_CUDA_FLAGS` since they will be modified
# and passed back via the `CUDA_ARCHITECTURES` property.
string(REGEX REPLACE "-gencode arch=[^ ]+ *" "" CMAKE_CUDA_FLAGS
${CMAKE_CUDA_FLAGS})
endmacro()
#
# Extract unique CUDA architectures from a list of compute capabilities codes in
# the form `<major><minor>[<letter>]`, convert them to the form sort
# `<major>.<minor>`, dedupes them and then sorts them in ascending order and
# stores them in `OUT_ARCHES`.
#
# Example:
# CUDA_ARCH_FLAGS="-gencode arch=compute_75,code=sm_75;...;-gencode arch=compute_90a,code=sm_90a"
# extract_unique_cuda_archs_ascending(OUT_ARCHES CUDA_ARCH_FLAGS)
# OUT_ARCHES="7.5;...;9.0"
function(extract_unique_cuda_archs_ascending OUT_ARCHES CUDA_ARCH_FLAGS)
set(_CUDA_ARCHES)
foreach(_ARCH ${CUDA_ARCH_FLAGS})
string(REGEX MATCH "arch=compute_\([0-9]+a?\)" _COMPUTE ${_ARCH})
if (_COMPUTE)
set(_COMPUTE ${CMAKE_MATCH_1})
endif()
string_to_ver(_COMPUTE_VER ${_COMPUTE})
list(APPEND _CUDA_ARCHES ${_COMPUTE_VER})
endforeach()
list(REMOVE_DUPLICATES _CUDA_ARCHES)
list(SORT _CUDA_ARCHES COMPARE NATURAL ORDER ASCENDING)
set(${OUT_ARCHES} ${_CUDA_ARCHES} PARENT_SCOPE)
endfunction()
#
# For a specific file set the `-gencode` flag in compile options conditionally
# for the CUDA language.
#
# Example:
# set_gencode_flag_for_srcs(
# SRCS "foo.cu"
# ARCH "compute_75"
# CODE "sm_75")
# adds: "-gencode arch=compute_75,code=sm_75" to the compile options for
# `foo.cu` (only for the CUDA language).
#
macro(set_gencode_flag_for_srcs)
set(options)
set(oneValueArgs ARCH CODE)
set(multiValueArgs SRCS)
cmake_parse_arguments(arg "${options}" "${oneValueArgs}"
"${multiValueArgs}" ${ARGN} )
set(_FLAG -gencode arch=${arg_ARCH},code=${arg_CODE})
set_property(
SOURCE ${arg_SRCS}
APPEND PROPERTY
COMPILE_OPTIONS "$<$<COMPILE_LANGUAGE:CUDA>:${_FLAG}>"
)
message(DEBUG "Setting gencode flag for ${arg_SRCS}: ${_FLAG}")
endmacro(set_gencode_flag_for_srcs)
#
# For a list of source files set the `-gencode` flags in the files specific
# compile options (specifically for the CUDA language).
#
# arguments are:
# SRCS: list of source files
# CUDA_ARCHS: list of CUDA architectures in the form `<major>.<minor>[letter]`
# BUILD_PTX_FOR_ARCH: if set to true, then the PTX code will be built
# for architecture `BUILD_PTX_FOR_ARCH` if there is a CUDA_ARCH in CUDA_ARCHS
# that is larger than BUILD_PTX_FOR_ARCH.
#
macro(set_gencode_flags_for_srcs)
set(options)
set(oneValueArgs BUILD_PTX_FOR_ARCH)
set(multiValueArgs SRCS CUDA_ARCHS)
cmake_parse_arguments(arg "${options}" "${oneValueArgs}"
"${multiValueArgs}" ${ARGN} )
foreach(_ARCH ${arg_CUDA_ARCHS})
string(REPLACE "." "" _ARCH "${_ARCH}")
set_gencode_flag_for_srcs(
SRCS ${arg_SRCS}
ARCH "compute_${_ARCH}"
CODE "sm_${_ARCH}")
endforeach()
if (${arg_BUILD_PTX_FOR_ARCH})
list(SORT arg_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
list(GET arg_CUDA_ARCHS -1 _HIGHEST_ARCH)
if (_HIGHEST_ARCH VERSION_GREATER_EQUAL ${arg_BUILD_PTX_FOR_ARCH})
string(REPLACE "." "" _PTX_ARCH "${arg_BUILD_PTX_FOR_ARCH}")
set_gencode_flag_for_srcs(
SRCS ${arg_SRCS}
ARCH "compute_${_PTX_ARCH}"
CODE "compute_${_PTX_ARCH}")
endif()
endif()
endmacro()
#
# For the given `SRC_CUDA_ARCHS` list of gencode versions in the form
# `<major>.<minor>[letter]` compute the "loose intersection" with the
# `TGT_CUDA_ARCHS` list of gencodes.
# The loose intersection is defined as:
# { max{ x \in tgt | x <= y } | y \in src, { x \in tgt | x <= y } != {} }
# where `<=` is the version comparison operator.
# In other words, for each version in `TGT_CUDA_ARCHS` find the highest version
# in `SRC_CUDA_ARCHS` that is less or equal to the version in `TGT_CUDA_ARCHS`.
# We have special handling for 9.0a, if 9.0a is in `SRC_CUDA_ARCHS` and 9.0 is
# in `TGT_CUDA_ARCHS` then we should remove 9.0a from `SRC_CUDA_ARCHS` and add
# 9.0a to the result.
# The result is stored in `OUT_CUDA_ARCHS`.
#
# Example:
# SRC_CUDA_ARCHS="7.5;8.0;8.6;9.0;9.0a"
# TGT_CUDA_ARCHS="8.0;8.9;9.0"
# cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
# OUT_CUDA_ARCHS="8.0;8.6;9.0;9.0a"
#
function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS)
list(REMOVE_DUPLICATES SRC_CUDA_ARCHS)
# if 9.0a is in SRC_CUDA_ARCHS and 9.0 is in CUDA_ARCHS then we should
# remove 9.0a from SRC_CUDA_ARCHS and add 9.0a to _CUDA_ARCHS
set(_CUDA_ARCHS)
if ("9.0a" IN_LIST SRC_CUDA_ARCHS)
list(REMOVE_ITEM SRC_CUDA_ARCHS "9.0a")
if ("9.0" IN_LIST TGT_CUDA_ARCHS)
set(_CUDA_ARCHS "9.0a")
endif()
endif()
list(SORT SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
# for each ARCH in CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that is
# less or eqault to ARCH
foreach(_ARCH ${CUDA_ARCHS})
set(_TMP_ARCH)
foreach(_SRC_ARCH ${SRC_CUDA_ARCHS})
if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH)
set(_TMP_ARCH ${_SRC_ARCH})
else()
break()
endif()
endforeach()
if (_TMP_ARCH)
list(APPEND _CUDA_ARCHS ${_TMP_ARCH})
endif()
endforeach()
list(REMOVE_DUPLICATES _CUDA_ARCHS)
set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE)
endfunction()
#
# Override the GPU architectures detected by cmake/torch and filter them by
# `GPU_SUPPORTED_ARCHES`. Sets the final set of architectures in
# `GPU_ARCHES`.
# `GPU_ARCHES`. This only applies to the HIP language since for CUDA we set
# the architectures on a per file basis.
#
# Note: this is defined as a macro since it updates `CMAKE_CUDA_FLAGS`.
#
......@@ -179,109 +350,7 @@ macro(override_gpu_arches GPU_ARCHES GPU_LANG GPU_SUPPORTED_ARCHES)
"None of the detected ROCm architectures: ${HIP_ARCHITECTURES} is"
" supported. Supported ROCm architectures are: ${_GPU_SUPPORTED_ARCHES_LIST}.")
endif()
elseif(${GPU_LANG} STREQUAL "CUDA")
#
# Setup/process CUDA arch flags.
#
# The torch cmake setup hardcodes the detected architecture flags in
# `CMAKE_CUDA_FLAGS`. Since `CMAKE_CUDA_FLAGS` is a "global" variable, it
# can't modified on a per-target basis.
# So, all the `-gencode` flags need to be extracted and removed from
# `CMAKE_CUDA_FLAGS` for processing so they can be passed by another method.
# Since it's not possible to use `target_compiler_options` for adding target
# specific `-gencode` arguments, the target's `CUDA_ARCHITECTURES` property
# must be used instead. This requires repackaging the architecture flags
# into a format that cmake expects for `CUDA_ARCHITECTURES`.
#
# This is a bit fragile in that it depends on torch using `-gencode` as opposed
# to one of the other nvcc options to specify architectures.
#
# Note: torch uses the `TORCH_CUDA_ARCH_LIST` environment variable to override
# detected architectures.
#
message(DEBUG "initial CMAKE_CUDA_FLAGS: ${CMAKE_CUDA_FLAGS}")
# Extract all `-gencode` flags from `CMAKE_CUDA_FLAGS`
string(REGEX MATCHALL "-gencode arch=[^ ]+" _CUDA_ARCH_FLAGS
${CMAKE_CUDA_FLAGS})
# Remove all `-gencode` flags from `CMAKE_CUDA_FLAGS` since they will be modified
# and passed back via the `CUDA_ARCHITECTURES` property.
string(REGEX REPLACE "-gencode arch=[^ ]+ *" "" CMAKE_CUDA_FLAGS
${CMAKE_CUDA_FLAGS})
# If this error is triggered, it might mean that torch has changed how it sets
# up nvcc architecture code generation flags.
if (NOT _CUDA_ARCH_FLAGS)
message(FATAL_ERROR
"Could not find any architecture related code generation flags in "
"CMAKE_CUDA_FLAGS. (${CMAKE_CUDA_FLAGS})")
endif()
message(DEBUG "final CMAKE_CUDA_FLAGS: ${CMAKE_CUDA_FLAGS}")
message(DEBUG "arch flags: ${_CUDA_ARCH_FLAGS}")
# Initialize the architecture lists to empty.
set(${GPU_ARCHES})
# Process each `gencode` flag.
foreach(_ARCH ${_CUDA_ARCH_FLAGS})
# For each flag, extract the version number and whether it refers to PTX
# or native code.
# Note: if a regex matches then `CMAKE_MATCH_1` holds the binding
# for that match.
string(REGEX MATCH "arch=compute_\([0-9]+a?\)" _COMPUTE ${_ARCH})
if (_COMPUTE)
set(_COMPUTE ${CMAKE_MATCH_1})
endif()
string(REGEX MATCH "code=sm_\([0-9]+a?\)" _SM ${_ARCH})
if (_SM)
set(_SM ${CMAKE_MATCH_1})
endif()
string(REGEX MATCH "code=compute_\([0-9]+a?\)" _CODE ${_ARCH})
if (_CODE)
set(_CODE ${CMAKE_MATCH_1})
endif()
# Make sure the virtual architecture can be matched.
if (NOT _COMPUTE)
message(FATAL_ERROR
"Could not determine virtual architecture from: ${_ARCH}.")
endif()
# One of sm_ or compute_ must exist.
if ((NOT _SM) AND (NOT _CODE))
message(FATAL_ERROR
"Could not determine a codegen architecture from: ${_ARCH}.")
endif()
if (_SM)
# -real suffix let CMake to only generate elf code for the kernels.
# we want this, otherwise the added ptx (default) will increase binary size.
set(_VIRT "-real")
set(_CODE_ARCH ${_SM})
else()
# -virtual suffix let CMake to generate ptx code for the kernels.
set(_VIRT "-virtual")
set(_CODE_ARCH ${_CODE})
endif()
# Check if the current version is in the supported arch list.
string_to_ver(_CODE_VER ${_CODE_ARCH})
if (NOT _CODE_VER IN_LIST _GPU_SUPPORTED_ARCHES_LIST)
message(STATUS "discarding unsupported CUDA arch ${_VER}.")
continue()
endif()
# Add it to the arch list.
list(APPEND ${GPU_ARCHES} "${_CODE_ARCH}${_VIRT}")
endforeach()
endif()
message(STATUS "${GPU_LANG} target arches: ${${GPU_ARCHES}}")
endmacro()
#
......
......@@ -267,13 +267,16 @@ def get_neuron_sdk_version(run_lambda):
def get_vllm_version():
try:
import vllm
return vllm.__version__ + "@" + vllm.__commit__
except Exception:
# old version of vllm does not have __commit__
return 'N/A'
from vllm import __version__, __version_tuple__
if __version__ == "dev":
return "N/A (dev)"
if len(__version_tuple__) == 4: # dev build
git_sha = __version_tuple__[-1][1:] # type: ignore
return f"{__version__} (git sha: {git_sha}"
return __version__
def summarize_vllm_build_flags():
# This could be a static method if the flags are constant, or dynamic if you need to check environment variables, etc.
......
#pragma once
#define VLLM_IMPLIES(p, q) (!(p) || (q))
......@@ -12,6 +12,11 @@
// could be a macro instead of a literal token.
#define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
// A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME
// could be a macro instead of a literal token.
#define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \
TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE)
// REGISTER_EXTENSION allows the shared library to be loaded and initialized
// via python's import statement.
#define REGISTER_EXTENSION(NAME) \
......
......@@ -265,6 +265,30 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
void save(float *ptr) const { _mm256_storeu_ps(ptr, reg); }
};
#ifdef __AVX512F__
struct INT32Vec16: public Vec<INT32Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
union AliasReg {
__m512i reg;
int32_t values[VEC_ELEM_NUM];
};
__m512i reg;
explicit INT32Vec16(const void* data_ptr) : reg(_mm512_loadu_epi32(data_ptr)) {}
void save(int32_t* ptr) const {
_mm512_storeu_epi32(ptr, reg);
}
void save(int32_t* ptr, const int elem_num) const {
constexpr uint32_t M = 0xFFFFFFFF;
__mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num));
_mm512_mask_storeu_epi32(ptr, mask, reg);
}
};
#endif
#ifdef __AVX512F__
struct FP32Vec16 : public Vec<FP32Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
......@@ -283,8 +307,6 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
explicit FP32Vec16(__m512 data) : reg(data) {}
explicit FP32Vec16(const FP32Vec16 &data) : reg(data.reg) {}
explicit FP32Vec16(const FP32Vec4 &data)
: reg((__m512)_mm512_inserti32x4(
_mm512_inserti32x4(
......@@ -303,6 +325,9 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {}
explicit FP32Vec16(const INT32Vec16 &v)
: reg(_mm512_cvt_roundepi32_ps(v.reg, _MM_FROUND_TO_NEAREST_INT |_MM_FROUND_NO_EXC)) {}
FP32Vec16 operator*(const FP32Vec16 &b) const {
return FP32Vec16(_mm512_mul_ps(reg, b.reg));
}
......@@ -333,6 +358,16 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
return FP32Vec16(_mm512_mask_max_ps(reg, mask, reg, b.reg));
}
FP32Vec16 min(const FP32Vec16& b) const {
return FP32Vec16(_mm512_min_ps(reg, b.reg));
}
FP32Vec16 min(const FP32Vec16& b, const int elem_num) const {
constexpr uint32_t M = 0xFFFFFFFF;
__mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num));
return FP32Vec16(_mm512_mask_min_ps(reg, mask, reg, b.reg));
}
FP32Vec16 abs() const {
return FP32Vec16(_mm512_abs_ps(reg));
}
......@@ -341,6 +376,8 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
float reduce_max() const { return _mm512_reduce_max_ps(reg); }
float reduce_min() const { return _mm512_reduce_min_ps(reg); }
template <int group_size> float reduce_sub_sum(int idx) {
static_assert(VEC_ELEM_NUM % group_size == 0);
constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size));
......
......@@ -5,25 +5,29 @@ namespace {
template <typename scalar_t>
struct KernelVecType {
using load_vec_type = void;
using azp_adj_load_vec_type = void;
using cvt_vec_type = void;
};
template <>
struct KernelVecType<float> {
using load_vec_type = vec_op::FP32Vec16;
using azp_adj_load_vec_type = vec_op::INT32Vec16;
using cvt_vec_type = vec_op::FP32Vec16;
};
template <>
struct KernelVecType<c10::BFloat16> {
using load_vec_type = vec_op::BF16Vec16;
using azp_adj_load_vec_type = vec_op::INT32Vec16;
using cvt_vec_type = vec_op::FP32Vec16;
};
#ifdef __AVX512F__
template <typename scalar_t>
template <bool AZP, typename scalar_t>
void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
const float* scale, const int num_tokens,
const float* scale, const int32_t* azp,
const int num_tokens,
const int hidden_size) {
using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
......@@ -37,62 +41,110 @@ void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
const cvt_vec_t i8_min_vec(i8_min);
const cvt_vec_t i8_max_vec(i8_max);
cvt_vec_t zp_vec;
if constexpr (AZP) {
zp_vec = cvt_vec_t(static_cast<float>(*azp));
}
#pragma omp parallel for
for (int i = 0; i < num_tokens; ++i) {
int j = 0;
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
load_vec_t elems(input + i * hidden_size + j);
cvt_vec_t elems_fp32(elems);
elems_fp32 = (elems_fp32 * inv_scale).clamp(i8_min_vec, i8_max_vec);
elems_fp32 = elems_fp32 * inv_scale;
if constexpr (AZP) {
elems_fp32 = elems_fp32 + zp_vec;
}
elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
vec_op::INT8Vec16 elems_int8(elems_fp32);
elems_int8.save(output + i * hidden_size + j);
}
load_vec_t elems(input + i * hidden_size + j);
cvt_vec_t elems_fp32(elems);
elems_fp32 = (elems_fp32 * inv_scale).clamp(i8_min_vec, i8_max_vec);
vec_op::INT8Vec16 elems_int8(elems_fp32);
elems_fp32 = elems_fp32 * inv_scale;
if (j + vec_elem_num == hidden_size) {
elems_int8.save(output + i * hidden_size + j);
} else {
elems_int8.save(output + i * hidden_size + j, hidden_size - j);
if constexpr (AZP) {
elems_fp32 = elems_fp32 + zp_vec;
}
elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
vec_op::INT8Vec16 elems_int8(elems_fp32);
elems_int8.save(output + i * hidden_size + j, hidden_size - j);
}
}
template <typename scalar_t>
template <bool AZP, typename scalar_t>
void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
float* scale, const int num_tokens,
float* scale, int32_t* azp,
const int num_tokens,
const int hidden_size) {
using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
constexpr float i8_min =
static_cast<float>(std::numeric_limits<int8_t>::min());
constexpr float i8_max =
static_cast<float>(std::numeric_limits<int8_t>::max());
const cvt_vec_t i8_min_vec(i8_min);
const cvt_vec_t i8_max_vec(i8_max);
#pragma omp parallel for
for (int i = 0; i < num_tokens; ++i) {
cvt_vec_t max_abs(0.0);
cvt_vec_t max_value(std::numeric_limits<float>::lowest());
cvt_vec_t min_value(std::numeric_limits<float>::max());
{
int j = 0;
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
load_vec_t elems(input + i * hidden_size + j);
cvt_vec_t elems_fp32(elems);
max_abs = max_abs.max(elems_fp32.abs());
if constexpr (AZP) {
max_value = max_value.max(elems_fp32);
min_value = min_value.min(elems_fp32);
} else {
max_value = max_value.max(elems_fp32.abs());
}
}
load_vec_t elems(input + i * hidden_size + j);
cvt_vec_t elems_fp32(elems);
if (j + vec_elem_num == hidden_size) {
max_abs = max_abs.max(elems_fp32.abs());
if constexpr (AZP) {
max_value = max_value.max(elems_fp32);
min_value = min_value.min(elems_fp32);
} else {
max_value = max_value.max(elems_fp32.abs());
}
} else {
max_abs = max_abs.max(elems_fp32.abs(), hidden_size - j);
if constexpr (AZP) {
max_value = max_value.max(elems_fp32, hidden_size - j);
min_value = min_value.min(elems_fp32, hidden_size - j);
} else {
max_value = max_value.max(elems_fp32.abs(), hidden_size - j);
}
}
}
float scale_val = max_abs.reduce_max() / 127.0f;
scale[i] = scale_val;
float scale_val, azp_val;
if constexpr (AZP) {
float max_scalar = max_value.reduce_max();
float min_scalar = min_value.reduce_min();
scale_val = (max_scalar - min_scalar) / 255.0f;
azp_val = std::nearbyint(-128.0f - min_scalar / scale_val);
azp[i] = static_cast<int32_t>(azp_val);
scale[i] = scale_val;
} else {
scale_val = max_value.reduce_max() / 127.0f;
scale[i] = scale_val;
}
const cvt_vec_t inv_scale(1.0 / scale_val);
const cvt_vec_t azp_vec(azp_val);
{
int j = 0;
......@@ -100,6 +152,11 @@ void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
load_vec_t elems(input + i * hidden_size + j);
cvt_vec_t elems_fp32(elems);
elems_fp32 = (elems_fp32 * inv_scale);
if constexpr (AZP) {
elems_fp32 = elems_fp32 + azp_vec;
}
elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
vec_op::INT8Vec16 elems_int8(elems_fp32);
elems_int8.save(output + i * hidden_size + j);
}
......@@ -107,34 +164,111 @@ void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
load_vec_t elems(input + i * hidden_size + j);
cvt_vec_t elems_fp32(elems);
elems_fp32 = (elems_fp32 * inv_scale);
vec_op::INT8Vec16 elems_int8(elems_fp32);
if (j + vec_elem_num == hidden_size) {
elems_int8.save(output + i * hidden_size + j);
} else {
elems_int8.save(output + i * hidden_size + j, hidden_size - j);
if constexpr (AZP) {
elems_fp32 = elems_fp32 + azp_vec;
}
elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec);
vec_op::INT8Vec16 elems_int8(elems_fp32);
elems_int8.save(output + i * hidden_size + j, hidden_size - j);
}
}
}
template <bool Bias, typename scalar_t>
void dynamic_output_scale_impl(const float* input, scalar_t* output,
const float* scale, const scalar_t* bias,
const int num_tokens, const int hidden_size) {
template <bool PerChannel, typename scalar_t>
void static_quant_epilogue(const float* input, scalar_t* output,
const float a_scale, const float* b_scale,
const int32_t* azp_with_adj, const int num_tokens,
const int hidden_size) {
CPU_KERNEL_GUARD_IN(dynamic_output_scale_impl)
using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
using azp_adj_load_vec_t =
typename KernelVecType<scalar_t>::azp_adj_load_vec_type;
using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
#pragma omp parallel for
for (int i = 0; i < num_tokens; ++i) {
cvt_vec_t a_scale_vec(a_scale);
cvt_vec_t b_scale_vec(*b_scale);
cvt_vec_t scale_vec = a_scale_vec * b_scale_vec;
int j = 0;
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
cvt_vec_t elems_fp32(input + i * hidden_size + j);
azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j);
cvt_vec_t azp_adj_fp32(azp_adj_vec);
if constexpr (PerChannel) {
b_scale_vec = cvt_vec_t(b_scale + j);
scale_vec = b_scale_vec * a_scale_vec;
}
elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32;
load_vec_t elems_out(elems_fp32);
elems_out.save(output + i * hidden_size + j);
}
cvt_vec_t elems_fp32(input + i * hidden_size + j);
azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j);
cvt_vec_t azp_adj_fp32(azp_adj_vec);
if constexpr (PerChannel) {
b_scale_vec = cvt_vec_t(b_scale + j);
scale_vec = b_scale_vec * a_scale_vec;
}
elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32;
load_vec_t elems_out(elems_fp32);
elems_out.save(output + i * hidden_size + j, hidden_size - j);
}
}
template <bool AZP, bool PerChannel, bool Bias, typename scalar_t>
void dynamic_quant_epilogue(const float* input, scalar_t* output,
const float* a_scale, const float* b_scale,
const int32_t* azp, const int32_t* azp_adj,
const scalar_t* bias, const int num_tokens,
const int hidden_size) {
CPU_KERNEL_GUARD_IN(dynamic_quant_epilogue)
using load_vec_t = typename KernelVecType<scalar_t>::load_vec_type;
using azp_adj_load_vec_t =
typename KernelVecType<scalar_t>::azp_adj_load_vec_type;
using cvt_vec_t = typename KernelVecType<scalar_t>::cvt_vec_type;
constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM;
#pragma omp parallel for
for (int i = 0; i < num_tokens; ++i) {
int j = 0;
cvt_vec_t token_scale_vec(scale[i]);
cvt_vec_t token_scale_vec(a_scale[i]);
cvt_vec_t token_zp_scale_vec;
if constexpr (AZP) {
float zp_scale_val = a_scale[i] * static_cast<float>(azp[i]);
if constexpr (!PerChannel) {
zp_scale_val *= *b_scale;
}
token_zp_scale_vec = cvt_vec_t(zp_scale_val);
}
for (; j < hidden_size - vec_elem_num; j += vec_elem_num) {
cvt_vec_t elems_fp32(input + i * hidden_size + j);
elems_fp32 = elems_fp32 * token_scale_vec;
if constexpr (AZP) {
azp_adj_load_vec_t azp_adj_vec(azp_adj + j);
cvt_vec_t azp_adj_fp32(azp_adj_vec);
azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec;
if constexpr (PerChannel) {
cvt_vec_t b_scale_vec(b_scale + j);
azp_adj_fp32 = azp_adj_fp32 * b_scale_vec;
}
elems_fp32 = elems_fp32 - azp_adj_fp32;
}
if constexpr (Bias) {
load_vec_t bias_vec(bias + j);
cvt_vec_t bias_vec_fp32(bias_vec);
......@@ -148,6 +282,19 @@ void dynamic_output_scale_impl(const float* input, scalar_t* output,
cvt_vec_t elems_fp32(input + i * hidden_size + j);
elems_fp32 = elems_fp32 * token_scale_vec;
if constexpr (AZP) {
azp_adj_load_vec_t azp_adj_vec(azp_adj + j);
cvt_vec_t azp_adj_fp32(azp_adj_vec);
azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec;
if constexpr (PerChannel) {
cvt_vec_t b_scale_vec(b_scale + j);
azp_adj_fp32 = azp_adj_fp32 * b_scale_vec;
}
elems_fp32 = elems_fp32 - azp_adj_fp32;
}
if constexpr (Bias) {
load_vec_t bias_vec(bias + j);
cvt_vec_t bias_vec_fp32(bias_vec);
......@@ -155,32 +302,41 @@ void dynamic_output_scale_impl(const float* input, scalar_t* output,
}
load_vec_t elems_out(elems_fp32);
if (j + vec_elem_num == hidden_size) {
elems_out.save(output + i * hidden_size + j);
} else {
elems_out.save(output + i * hidden_size + j, hidden_size - j);
}
elems_out.save(output + i * hidden_size + j, hidden_size - j);
}
}
#else
template <typename scalar_t>
void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
const float* scale, const int num_tokens,
const float* scale, const int32_t* azp,
const int num_tokens,
const int hidden_size) {
TORCH_CHECK(false, "static_scaled_int8_quant_impl requires AVX512 support.")
}
template <typename scalar_t>
void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
float* scale, const int num_tokens,
float* scale, int32_t* azp,
const int num_tokens,
const int hidden_size) {
TORCH_CHECK(false, "dynamic_scaled_int8_quant_impl requires AVX512 support.")
}
template <bool PerChannel, typename scalar_t>
void static_quant_epilogue(const float* input, scalar_t* output,
const float a_scale, const float* b_scale,
const int32_t* azp_with_adj, const int num_tokens,
const int hidden_size) {
TORCH_CHECK(false, "static_quant_epilogue requires AVX512 support.")
}
template <typename scalar_t>
void dynamic_output_scale_impl() {
TORCH_CHECK(false, "dynamic_output_scale_impl requires AVX512 support.")
void dynamic_quant_epilogue(const float* input, scalar_t* output,
const float* a_scale, const float* b_scale,
const int32_t* azp, const int32_t* azp_with_adj,
const scalar_t* bias, const int num_tokens,
const int hidden_size) {
TORCH_CHECK(false, "dynamic_quant_epilogue requires AVX512 support.")
}
#endif
} // namespace
......@@ -214,39 +370,52 @@ void int8_scaled_mm(torch::Tensor& c, // [M, OC], row-major
bias->dim() == 1);
}
VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "cutlass_scaled_mm", [&] {
VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm", [&] {
if (a_scales.numel() != 1) {
// per-token
// Note: oneDNN doesn't support per-token activation quantization
// Ideally we want to fuse the GEMM and the scale procedure with oneDNN
// JIT, the intermediate data is cached in registers or L1. But for now
// the oneDNN GEMM code generation only supports two quantization
// patterns: per-tensor or per-output-channel of weight.
// So we have to apply the per-token scale with a 'epilogue'. In C=s_a *
// s_b * (A@B) + bias, the C_inter = s_b * (A@B) is computed by oneDNN
// GEMM, then the per-token scale (and bias) is applied with the epilogue
// C=s_a * C_inter + bias.
torch::Tensor tmp_fp32_out =
torch::empty_like(c, ::at::ScalarType::Float);
DNNLPrimitiveHelper<true>::gemm_s8s8_jit(
// Compute C_inter=s_b * (A@B)
DNNLPrimitiveHelper<true>::gemm_s8s8_jit<float, void>(
a.data_ptr<int8_t>(), b.data_ptr<int8_t>(),
tmp_fp32_out.data_ptr<float>(), (void*)(0), a.size(0), b.size(1),
a.size(1), (float*)(0), b_scales.data_ptr<float>(), 0,
b_scales.numel());
tmp_fp32_out.data_ptr<float>(), nullptr, a.size(0), b.size(1),
a.size(1), nullptr, b_scales.data_ptr<float>(), 0, b_scales.numel());
if (bias.has_value()) {
dynamic_output_scale_impl<true>(
// Compute C=s_a * C_inter + bias
dynamic_quant_epilogue<false, true, true>(
tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
a_scales.data_ptr<float>(), bias->data_ptr<scalar_t>(), c.size(0),
c.size(1));
a_scales.data_ptr<float>(), nullptr, nullptr, nullptr,
bias->data_ptr<scalar_t>(), c.size(0), c.size(1));
} else {
dynamic_output_scale_impl<false>(
// Compute C=s_a * C_inter
dynamic_quant_epilogue<false, true, false, scalar_t>(
tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
a_scales.data_ptr<float>(), (scalar_t*)(0), c.size(0), c.size(1));
a_scales.data_ptr<float>(), nullptr, nullptr, nullptr, nullptr,
c.size(0), c.size(1));
}
} else {
// per-tensor
if (bias.has_value()) {
// Compute C=s_a * s_b * (A@B) + bias
DNNLPrimitiveHelper<false>::gemm_s8s8_jit(
a.data_ptr<int8_t>(), b.data_ptr<int8_t>(), c.data_ptr<scalar_t>(),
bias->data_ptr<scalar_t>(), a.size(0), b.size(1), a.size(1),
a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
a_scales.numel(), b_scales.numel());
} else {
DNNLPrimitiveHelper<false>::gemm_s8s8_jit(
// Compute C=s_a * s_b * (A@B)
DNNLPrimitiveHelper<false>::gemm_s8s8_jit<scalar_t, void>(
a.data_ptr<int8_t>(), b.data_ptr<int8_t>(), c.data_ptr<scalar_t>(),
(void*)(0), a.size(0), b.size(1), a.size(1),
nullptr, a.size(0), b.size(1), a.size(1),
a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
a_scales.numel(), b_scales.numel());
}
......@@ -254,6 +423,127 @@ void int8_scaled_mm(torch::Tensor& c, // [M, OC], row-major
});
}
void int8_scaled_mm_azp(torch::Tensor& c, // [M, OC], row-major
const torch::Tensor& a, // [M, IC], row-major
const torch::Tensor& b, // [IC, OC], column-major
const torch::Tensor& a_scales, // [1] or [M]
const torch::Tensor& b_scales, // [1] or [OC]
const torch::Tensor& azp_adj, // [OC]
const c10::optional<torch::Tensor>& azp, // [1] or [M]
const c10::optional<torch::Tensor>& bias // [OC]
) {
CPU_KERNEL_GUARD_IN(cutlass_scaled_mm_azp)
// Checks for conformality
TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8,
"int8_scaled_mm_azp only supports INT8 inputs.")
TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2);
TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) &&
b.size(1) == c.size(1));
TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0));
TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1));
// Check for strides and alignment
TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major
TORCH_CHECK(b.stride(0) == 1); // Column-major
TORCH_CHECK(c.stride(0) % 16 == 0 &&
b.stride(1) % 16 == 0); // 16 Byte Alignment
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (bias) {
TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous());
}
if (azp) {
TORCH_CHECK(azp->numel() == a.size(0) && azp->is_contiguous());
}
TORCH_CHECK(azp_adj.numel() == b.size(1) && azp_adj.is_contiguous());
// azp & bias types
TORCH_CHECK(azp_adj.dtype() == torch::kInt32);
TORCH_CHECK(!azp || azp->dtype() == torch::kInt32);
TORCH_CHECK(!bias || bias->dtype() == c.dtype(),
"currently bias dtype must match output dtype ", c.dtype());
VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm_azp", [&] {
torch::Tensor tmp_fp32_out = torch::empty_like(c, ::at::ScalarType::Float);
if (a_scales.numel() != 1) {
// per-token
// Note: oneDNN doesn't support per-token activation quantization
// Compute C_inter=s_b * (A@B)
DNNLPrimitiveHelper<true>::gemm_s8s8_jit<float, void>(
a.data_ptr<int8_t>(), b.data_ptr<int8_t>(),
tmp_fp32_out.data_ptr<float>(), nullptr, a.size(0), b.size(1),
a.size(1), nullptr, b_scales.data_ptr<float>(), 0, b_scales.numel());
if (bias.has_value()) {
// Compute C=s_a * C_inter - s_a * s_b * azp * azp_adj + bias
if (b_scales.numel() != 1) {
// Per-Channel
dynamic_quant_epilogue<true, true, true>(
tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
azp->data_ptr<int32_t>(), azp_adj.data_ptr<int32_t>(),
bias->data_ptr<scalar_t>(), c.size(0), c.size(1));
} else {
// Per-Tensor
dynamic_quant_epilogue<true, false, true>(
tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
azp->data_ptr<int32_t>(), azp_adj.data_ptr<int32_t>(),
bias->data_ptr<scalar_t>(), c.size(0), c.size(1));
}
} else {
// Compute C=s_a * C_inter - s_a * s_b * azp * azp_adj
if (b_scales.numel() != 1) {
// Per-Channel
dynamic_quant_epilogue<true, true, false, scalar_t>(
tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
azp->data_ptr<int32_t>(), azp_adj.data_ptr<int32_t>(), nullptr,
c.size(0), c.size(1));
} else {
// Per-Tensor
dynamic_quant_epilogue<true, false, false, scalar_t>(
tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
azp->data_ptr<int32_t>(), azp_adj.data_ptr<int32_t>(), nullptr,
c.size(0), c.size(1));
}
}
} else {
// per-tensor
if (bias.has_value()) {
// Compute C_inter=s_a * s_b * (A@B) + bias
DNNLPrimitiveHelper<false>::gemm_s8s8_jit(
a.data_ptr<int8_t>(), b.data_ptr<int8_t>(),
tmp_fp32_out.data_ptr<float>(), bias->data_ptr<scalar_t>(),
a.size(0), b.size(1), a.size(1), a_scales.data_ptr<float>(),
b_scales.data_ptr<float>(), a_scales.numel(), b_scales.numel());
} else {
// Compute C_inter=s_a * s_b * (A@B)
DNNLPrimitiveHelper<false>::gemm_s8s8_jit<float, void>(
a.data_ptr<int8_t>(), b.data_ptr<int8_t>(),
tmp_fp32_out.data_ptr<float>(), nullptr, a.size(0), b.size(1),
a.size(1), a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
a_scales.numel(), b_scales.numel());
}
// Compute C=C_inter - s_a * s_b * azp_adj
if (b_scales.numel() != 1) {
// Per-Channel
static_quant_epilogue<true>(
tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
*a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
azp_adj.data_ptr<int32_t>(), a.size(0), b.size(1));
} else {
// Per-Tensor
static_quant_epilogue<false>(
tmp_fp32_out.data_ptr<float>(), c.data_ptr<scalar_t>(),
*a_scales.data_ptr<float>(), b_scales.data_ptr<float>(),
azp_adj.data_ptr<int32_t>(), a.size(0), b.size(1));
}
}
});
}
// static-per-tensor quantization.
void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
const torch::Tensor& input, // [..., hidden_size]
......@@ -263,15 +553,22 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.is_contiguous());
TORCH_CHECK(scale.numel() == 1);
TORCH_CHECK(!azp.has_value(), "Zero point is not supported on CPU.");
TORCH_CHECK(!azp.has_value() || azp->numel() == 1);
const int hidden_size = input.size(-1);
const int num_tokens = input.numel() / hidden_size;
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "static_scaled_int8_quant_impl", [&] {
static_scaled_int8_quant_impl(
input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
scale.data_ptr<float>(), num_tokens, hidden_size);
if (azp.has_value()) {
static_scaled_int8_quant_impl<true>(
input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
scale.data_ptr<float>(), azp->data_ptr<int32_t>(), num_tokens,
hidden_size);
} else {
static_scaled_int8_quant_impl<false>(
input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
scale.data_ptr<float>(), nullptr, num_tokens, hidden_size);
}
});
}
......@@ -284,14 +581,20 @@ void dynamic_scaled_int8_quant(
CPU_KERNEL_GUARD_IN(dynamic_scaled_int8_quant)
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.is_contiguous());
TORCH_CHECK(!azp.has_value(), "Zero point is not supported on CPU.");
int const hidden_size = input.size(-1);
int const num_tokens = input.numel() / hidden_size;
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "dynamic_scaled_int8_quant_impl", [&] {
dynamic_scaled_int8_quant_impl(
input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
scale.data_ptr<float>(), num_tokens, hidden_size);
if (azp.has_value()) {
dynamic_scaled_int8_quant_impl<true>(
input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
scale.data_ptr<float>(), azp->data_ptr<int32_t>(), num_tokens,
hidden_size);
} else {
dynamic_scaled_int8_quant_impl<false>(
input.data_ptr<scalar_t>(), out.data_ptr<int8_t>(),
scale.data_ptr<float>(), nullptr, num_tokens, hidden_size);
}
});
}
......@@ -11,6 +11,13 @@ void int8_scaled_mm(torch::Tensor& c, const torch::Tensor& a,
const torch::Tensor& b_scales,
const c10::optional<torch::Tensor>& bias);
void int8_scaled_mm_azp(torch::Tensor& c, const torch::Tensor& a,
const torch::Tensor& b, const torch::Tensor& a_scales,
const torch::Tensor& b_scales,
const torch::Tensor& azp_adj,
const c10::optional<torch::Tensor>& azp,
const c10::optional<torch::Tensor>& bias);
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// vLLM custom ops
......@@ -111,6 +118,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor b, Tensor a_scales,"
" Tensor b_scales, Tensor? bias) -> ()");
ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm);
// w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
// quantization.
ops.def(
"cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
" Tensor b, Tensor a_scales,"
" Tensor b_scales, Tensor azp_adj,"
" Tensor? azp, Tensor? bias) -> ()");
ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp);
#endif
}
......
......@@ -39,8 +39,6 @@
template<typename input_t, typename weight_t>
void causal_conv1d_fwd_cuda(ConvParamsBase &params, cudaStream_t stream);
template <typename input_t, typename weight_t>
void causal_conv1d_channellast_fwd_cuda(ConvParamsBase &params, cudaStream_t stream);
template<typename input_t, typename weight_t>
void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream);
......@@ -55,8 +53,12 @@ void set_conv_params_fwd(ConvParamsBase &params,
const at::Tensor x,
const at::Tensor weight,
const at::Tensor out,
void* bias_ptr,
bool silu_activation) {
const c10::optional<at::Tensor>& bias,
bool silu_activation,
int64_t pad_slot_id,
const c10::optional<at::Tensor>& query_start_loc = std::nullopt,
const c10::optional<at::Tensor>& cache_indices = std::nullopt,
const c10::optional<at::Tensor>& has_initial_state = std::nullopt) {
// Reset the parameters
memset(&params, 0, sizeof(params));
......@@ -65,33 +67,41 @@ void set_conv_params_fwd(ConvParamsBase &params,
params.dim = dim;
params.seqlen = seqlen;
params.width = width;
params.pad_slot_id = pad_slot_id;
params.silu_activation = silu_activation;
// Set the pointers and strides.
params.x_ptr = x.data_ptr();
params.weight_ptr = weight.data_ptr();
params.bias_ptr = bias_ptr;
params.bias_ptr = bias.has_value() ? bias.value().data_ptr() : nullptr;
params.out_ptr = out.data_ptr();
// All stride are in elements, not bytes.
params.x_batch_stride = x.stride(0);
params.x_c_stride = x.stride(1);
params.x_l_stride = x.stride(-1);
params.query_start_loc_ptr = query_start_loc.has_value() ? query_start_loc.value().data_ptr() : nullptr;
params.cache_indices_ptr = cache_indices.has_value() ? cache_indices.value().data_ptr() : nullptr;
params.has_initial_state_ptr = has_initial_state.has_value() ? has_initial_state.value().data_ptr() : nullptr;
const bool varlen = params.query_start_loc_ptr != nullptr;
params.x_batch_stride = x.stride(varlen ? 1 : 0);
params.x_c_stride = x.stride(varlen ? 0 : 1);
params.x_l_stride = x.stride(varlen ? 1 : -1);
params.weight_c_stride = weight.stride(0);
params.weight_width_stride = weight.stride(1);
params.out_batch_stride = out.stride(0);
params.out_c_stride = out.stride(1);
params.out_l_stride = out.stride(-1);
params.out_batch_stride = out.stride(varlen ? 1 : 0);
params.out_c_stride = out.stride(varlen ? 0 : 1);
params.out_l_stride = out.stride(varlen ? 1 : -1);
}
at::Tensor
causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
void causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
const c10::optional<at::Tensor> &bias_,
const c10::optional<at::Tensor> &seq_idx_,
const c10::optional<at::Tensor> &initial_states_,
const c10::optional<at::Tensor> &final_states_out_,
bool silu_activation) {
const c10::optional<at::Tensor> &conv_states,
const c10::optional<at::Tensor> &query_start_loc,
const c10::optional<at::Tensor> &cache_indices,
const c10::optional<at::Tensor> &has_initial_state,
bool silu_activation,
// used to identify padding entries if cache_indices provided
// in case of padding, the kernel will return early
int64_t pad_slot_id) {
auto input_type = x.scalar_type();
auto weight_type = weight.scalar_type();
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
......@@ -99,24 +109,22 @@ causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
TORCH_CHECK(x.is_cuda());
TORCH_CHECK(weight.is_cuda());
const bool varlen = query_start_loc.has_value() ? true : false;
const auto sizes = x.sizes();
const int batch_size = sizes[0];
const int dim = sizes[1];
const int seqlen = sizes[2];
const int batch_size = varlen ? query_start_loc.value().sizes()[0] - 1 : sizes[0];
const int dim = varlen ? sizes[0] : sizes[1];
const int seqlen = varlen ? sizes[1] : sizes[2];
const int width = weight.size(-1);
CHECK_SHAPE(x, batch_size, dim, seqlen);
if (varlen){
CHECK_SHAPE(x, dim, seqlen);
}
else {
CHECK_SHAPE(x, batch_size, dim, seqlen);
}
CHECK_SHAPE(weight, dim, width);
TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1);
const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1;
if (is_channel_last) {
TORCH_CHECK(dim % 8 == 0, "causal_conv1d only supports channel dimension divisible by 8 for now");
TORCH_CHECK(x.stride(2) % 8 == 0 and x.stride(0) % 8 == 0, "causal_conv1d with channel last layout requires strides (x.stride(0) and x.stride(2)) to be multiples of 8");
}
TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
if (bias_.has_value()) {
auto bias = bias_.value();
......@@ -126,56 +134,51 @@ causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
CHECK_SHAPE(bias, dim);
}
if (seq_idx_.has_value()) {
TORCH_CHECK(is_channel_last, "seq_idx is only supported for channel last layout");
auto seq_idx = seq_idx_.value();
TORCH_CHECK(seq_idx.scalar_type() == torch::kInt32);
TORCH_CHECK(seq_idx.is_cuda());
TORCH_CHECK(seq_idx.is_contiguous());
CHECK_SHAPE(seq_idx, batch_size, seqlen);
}
at::Tensor out = torch::empty_like(x);
if (has_initial_state.has_value()) {
auto has_initial_state_ = has_initial_state.value();
TORCH_CHECK(has_initial_state_.scalar_type() == at::ScalarType::Bool);
TORCH_CHECK(has_initial_state_.is_cuda());
CHECK_SHAPE(has_initial_state_, batch_size);
}
ConvParamsBase params;
set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
bias_.has_value() ? bias_.value().data_ptr() : nullptr,
silu_activation);
if (seq_idx_.has_value()) {
params.seq_idx_ptr = seq_idx_.value().data_ptr();
} else {
params.seq_idx_ptr = nullptr;
if (query_start_loc.has_value()) {
auto query_start_loc_ = query_start_loc.value();
TORCH_CHECK(query_start_loc_.scalar_type() == at::ScalarType::Int);
TORCH_CHECK(query_start_loc_.is_cuda());
}
if (initial_states_.has_value()) {
TORCH_CHECK(is_channel_last, "initial_states is only supported for channel last layout");
auto initial_states = initial_states_.value();
TORCH_CHECK(initial_states.scalar_type() == input_type);
TORCH_CHECK(initial_states.is_cuda());
CHECK_SHAPE(initial_states, batch_size, dim, width - 1);
TORCH_CHECK(initial_states.stride(1) == 1);
params.initial_states_ptr = initial_states.data_ptr();
params.initial_states_batch_stride = initial_states.stride(0);
params.initial_states_c_stride = initial_states.stride(1);
params.initial_states_l_stride = initial_states.stride(2);
} else {
params.initial_states_ptr = nullptr;
if (cache_indices.has_value()) {
auto cache_indices_ = cache_indices.value();
TORCH_CHECK(cache_indices_.scalar_type() == at::ScalarType::Int);
TORCH_CHECK(cache_indices_.is_cuda());
CHECK_SHAPE(cache_indices_, batch_size);
}
if (final_states_out_.has_value()) {
TORCH_CHECK(is_channel_last, "final_states is only supported for channel last layout");
auto final_states = final_states_out_.value();
TORCH_CHECK(final_states.scalar_type() == input_type);
TORCH_CHECK(final_states.is_cuda());
CHECK_SHAPE(final_states, batch_size, dim, width - 1);
TORCH_CHECK(final_states.stride(1) == 1);
params.final_states_ptr = final_states.data_ptr();
params.final_states_batch_stride = final_states.stride(0);
params.final_states_c_stride = final_states.stride(1);
params.final_states_l_stride = final_states.stride(2);
at::Tensor out = x;
ConvParamsBase params;
set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
bias_,
silu_activation,
pad_slot_id,
query_start_loc,
cache_indices,
has_initial_state
);
if (conv_states.has_value()) {
auto conv_states_ = conv_states.value();
TORCH_CHECK(conv_states_.scalar_type() == input_type);
TORCH_CHECK(conv_states_.is_cuda());
params.conv_states_ptr = conv_states_.data_ptr();
params.conv_states_batch_stride = conv_states_.stride(0);
params.conv_states_c_stride = conv_states_.stride(1);
params.conv_states_l_stride = conv_states_.stride(2);
} else {
params.final_states_ptr = nullptr;
params.conv_states_ptr = nullptr;
}
// Otherwise the kernel will be launched from cuda:0 device
......@@ -183,23 +186,21 @@ causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
at::cuda::CUDAGuard device_guard{(char)x.get_device()};
auto stream = at::cuda::getCurrentCUDAStream().stream();
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] {
if (!is_channel_last) {
causal_conv1d_fwd_cuda<input_t, weight_t>(params, stream);
} else {
causal_conv1d_channellast_fwd_cuda<input_t, weight_t>(params, stream);
}
causal_conv1d_fwd_cuda<input_t, weight_t>(params, stream);
});
return out;
}
at::Tensor
causal_conv1d_update(const at::Tensor &x,
void causal_conv1d_update(const at::Tensor &x,
const at::Tensor &conv_state,
const at::Tensor &weight,
const c10::optional<at::Tensor> &bias_,
bool silu_activation,
const c10::optional<at::Tensor> &conv_state_indices_) {
const c10::optional<at::Tensor> &cache_seqlens_,
const c10::optional<at::Tensor> &conv_state_indices_,
// used to identify padding entries if cache_indices provided
// in case of padding, the kernel will return early
int64_t pad_slot_id) {
auto input_type = x.scalar_type();
auto weight_type = weight.scalar_type();
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
......@@ -214,9 +215,12 @@ causal_conv1d_update(const at::Tensor &x,
const auto sizes = x.sizes();
const int batch_size = sizes[0];
const int dim = sizes[1];
const int seqlen = sizes[2];
const int width = weight.size(-1);
const int conv_state_len = conv_state.size(2);
TORCH_CHECK(conv_state_len >= width - 1);
CHECK_SHAPE(x, batch_size, dim);
CHECK_SHAPE(x, batch_size, dim, seqlen);
CHECK_SHAPE(weight, dim, width);
TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
......@@ -229,18 +233,31 @@ causal_conv1d_update(const at::Tensor &x,
CHECK_SHAPE(bias, dim);
}
at::Tensor out = torch::empty_like(x);
at::Tensor out = x;
ConvParamsBase params;
set_conv_params_fwd(params, batch_size, dim, /*seqlen=*/1, width, x, weight, out,
bias_.has_value() ? bias_.value().data_ptr() : nullptr,
silu_activation);
set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
bias_,
silu_activation,
pad_slot_id);
params.conv_state_ptr = conv_state.data_ptr();
params.conv_state_len = conv_state_len;
// All stride are in elements, not bytes.
params.conv_state_batch_stride = conv_state.stride(0);
params.conv_state_c_stride = conv_state.stride(1);
params.conv_state_l_stride = conv_state.stride(2);
if (cache_seqlens_.has_value()) {
auto cache_seqlens = cache_seqlens_.value();
TORCH_CHECK(cache_seqlens.scalar_type() == torch::kInt32);
TORCH_CHECK(cache_seqlens.is_cuda());
TORCH_CHECK(cache_seqlens.stride(-1) == 1);
CHECK_SHAPE(cache_seqlens, batch_size);
params.cache_seqlens = cache_seqlens.data_ptr<int32_t>();
} else {
params.cache_seqlens = nullptr;
}
if (conv_state_indices_.has_value()) {
auto conv_state_indices = conv_state_indices_.value();
TORCH_CHECK(conv_state_indices.scalar_type() == torch::kInt32)
......@@ -249,11 +266,11 @@ causal_conv1d_update(const at::Tensor &x,
CHECK_SHAPE(conv_state_indices, batch_size);
int conv_state_entries = conv_state.size(0);
CHECK_SHAPE(conv_state, conv_state_entries, dim, width);
CHECK_SHAPE(conv_state, conv_state_entries, dim, conv_state_len);
params.conv_state_indices_ptr = conv_state_indices.data_ptr<int32_t>();
} else {
CHECK_SHAPE(conv_state, batch_size, dim, width);
CHECK_SHAPE(conv_state, batch_size, dim, conv_state_len);
params.conv_state_indices_ptr = nullptr;
}
......@@ -264,7 +281,6 @@ causal_conv1d_update(const at::Tensor &x,
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] {
causal_conv1d_update_cuda<input_t, weight_t>(params, stream);
});
return out;
}
template<int kNThreads_, int kWidth_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
......@@ -296,7 +312,7 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) {
constexpr int kWidth = Ktraits::kWidth;
constexpr int kNThreads = Ktraits::kNThreads;
constexpr int kNElts = Ktraits::kNElts;
static constexpr bool kIsVecLoad = Ktraits::kIsVecLoad;
constexpr bool kIsVecLoad = Ktraits::kIsVecLoad;
using input_t = typename Ktraits::input_t;
using vec_t = typename Ktraits::vec_t;
using weight_t = typename Ktraits::weight_t;
......@@ -309,20 +325,42 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) {
auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_);
vec_t *smem_exchange = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize);
const bool kVarlen = params.query_start_loc_ptr != nullptr;
const int tidx = threadIdx.x;
const int batch_id = blockIdx.x;
const int channel_id = blockIdx.y;
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
const int *query_start_loc = kVarlen ? reinterpret_cast<int *>(params.query_start_loc_ptr) : nullptr;
const int sequence_start_index = kVarlen ? query_start_loc[batch_id] : batch_id;
const int seqlen = kVarlen ? query_start_loc[batch_id + 1] - sequence_start_index : params.seqlen;
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + sequence_start_index * params.x_batch_stride
+ channel_id * params.x_c_stride;
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + sequence_start_index * params.out_batch_stride
+ channel_id * params.out_c_stride;
float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
bool has_initial_state = params.has_initial_state_ptr == nullptr ? false
: reinterpret_cast<bool *>(params.has_initial_state_ptr)[batch_id];
int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr
: reinterpret_cast<int *>(params.cache_indices_ptr);
int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id];
// cache_index == params.pad_slot_id is defined as padding, so we exit early
if (cache_index == params.pad_slot_id){
return;
}
input_t *conv_states = params.conv_states_ptr == nullptr ? nullptr
: reinterpret_cast<input_t *>(params.conv_states_ptr) + cache_index * params.conv_states_batch_stride + channel_id * params.conv_states_c_stride;
// Thread 0 will load the last elements of the previous chunk, so we initialize those to 0.
if (tidx == 0) {
input_t zeros[kNElts] = {0};
smem_exchange[kNThreads - 1] = reinterpret_cast<vec_t *>(zeros)[0];
input_t initial_state[kNElts] = {0};
if (has_initial_state) {
#pragma unroll
for (int w = 0; w < kWidth - 1; ++w){ initial_state[kNElts - 1 - (kWidth - 2) + w ] = conv_states[w]; }
}
smem_exchange[kNThreads - 1] = reinterpret_cast<vec_t *>(initial_state)[0];
}
float weight_vals[kWidth];
......@@ -330,14 +368,14 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) {
for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
constexpr int kChunkSize = kNThreads * kNElts;
const int n_chunks = (params.seqlen + kChunkSize - 1) / kChunkSize;
const int n_chunks = (seqlen + kChunkSize - 1) / kChunkSize;
for (int chunk = 0; chunk < n_chunks; ++chunk) {
input_t x_vals_load[2 * kNElts] = {0};
if constexpr(kIsVecLoad) {
typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(x), *reinterpret_cast<vec_t (*)[1]>(&x_vals_load[kNElts]), (params.seqlen - chunk * kChunkSize) / kNElts);
typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(x), *reinterpret_cast<vec_t (*)[1]>(&x_vals_load[kNElts]), (seqlen - chunk * kChunkSize) / kNElts);
} else {
__syncthreads();
typename Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast<input_t (*)[kNElts]>(&x_vals_load[kNElts]), params.seqlen - chunk * kChunkSize);
typename Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast<input_t (*)[kNElts]>(&x_vals_load[kNElts]), seqlen - chunk * kChunkSize);
}
x += kChunkSize;
__syncthreads();
......@@ -375,19 +413,57 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) {
#pragma unroll
for (int i = 0; i < kNElts; ++i) { out_vals_store[i] = out_vals[i]; }
if constexpr(kIsVecLoad) {
typename Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast<vec_t*>(out), reinterpret_cast<vec_t (&)[1]>(out_vals_store), (params.seqlen - chunk * kChunkSize) / kNElts);
typename Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast<vec_t*>(out), reinterpret_cast<vec_t (&)[1]>(out_vals_store), (seqlen - chunk * kChunkSize) / kNElts);
} else {
typename Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, params.seqlen - chunk * kChunkSize);
typename Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, seqlen - chunk * kChunkSize);
}
out += kChunkSize;
}
// Final state is stored in the smem_exchange last token slot,
// in case seqlen < kWidth, we would need to take the final state from the
// initial state which is stored in conv_states
// in case seqlen > kWidth, we would need to load the last kWidth - 1 data
// and load it into conv_state accordingly
int last_thread = ((seqlen - (kWidth - 1)) - (n_chunks - 1) * kChunkSize) / kNElts;
if (conv_states != nullptr && tidx == last_thread) {
input_t x_vals_load[kNElts * 2] = {0};
// in case we are on the first kWidth tokens
if (last_thread == 0 && seqlen < kWidth){
// Need to take the initial state
reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[0];
const int offset = seqlen - (kWidth - 1);
#pragma unroll
for (int w = 0; w < kWidth - 1; ++w){
// pad the existing state
if ((w - seqlen) >= 0 && has_initial_state) { conv_states[w - seqlen] = conv_states[w]; }
else if ((w - seqlen) >= 0 && !has_initial_state) { conv_states[w - seqlen] = input_t(0.0f); }
}
#pragma unroll
for (int w = 0; w < kWidth - 1; ++w){
if (offset + w >= 0)
conv_states[w] = x_vals_load[offset + w ];
}
}
else {
// in case the final state is in between the threads data
reinterpret_cast<vec_t *>(x_vals_load)[1] = smem_exchange[last_thread + 1];
reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[last_thread];
const int offset = ((seqlen - (kWidth - 1)) % (kNElts));
#pragma unroll
for (int w = 0; w < kWidth - 1; ++w){
conv_states[w] = x_vals_load[offset + w ];
}
}
}
}
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
void causal_conv1d_fwd_launch(ConvParamsBase &params, cudaStream_t stream) {
static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8;
BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] {
const bool kVarlen = params.query_start_loc_ptr != nullptr;
BOOL_SWITCH(params.seqlen % kNElts == 0 && !kVarlen, kIsVecLoad, [&] {
using Ktraits = Causal_conv1d_fwd_kernel_traits<kNThreads, kWidth, kIsVecLoad, input_t, weight_t>;
constexpr int kSmemSize = Ktraits::kSmemSize;
dim3 grid(params.batch, params.dim);
......@@ -422,220 +498,11 @@ void causal_conv1d_fwd_cuda(ConvParamsBase &params, cudaStream_t stream) {
}
}
template<int kNThreads_, int kWidth_, int kChunkSizeL_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
struct Causal_conv1d_channellast_fwd_kernel_traits {
// The cache line is 128 bytes, and we try to read 16 bytes per thread.
// So we have 8 threads per "row", so 32 or 64 elements in the channel dimension.
// That leaves 4 columns per warp, and so 16 columns per block (assuming each block has 128
// threads). Each each load is 16 x 32|64 elements in the L x C dimensions.
using input_t = input_t_;
using weight_t = weight_t_;
static constexpr int kNThreads = kNThreads_;
static_assert(kNThreads % 32 == 0);
static constexpr int kNWarps = kNThreads / 32;
static constexpr int kWidth = kWidth_;
static constexpr int kChunkSizeL = kChunkSizeL_;
static constexpr int kNBytes = sizeof(input_t);
static_assert(kNBytes == 2 || kNBytes == 4);
static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
static constexpr int kNEltsPerRow = 128 / kNBytes;
static constexpr int kNThreadsPerRow = kNEltsPerRow / kNElts; // Always 8 for now
static_assert(kNThreadsPerRow * kNBytes * kNElts == 128);
static constexpr int kNColsPerWarp = 32 / kNThreadsPerRow; // Always 4 for now
static_assert(kNColsPerWarp * kNThreadsPerRow == 32);
static constexpr int kNColsPerLoad = kNColsPerWarp * kNWarps;
static constexpr int kNLoads = kChunkSizeL / kNColsPerLoad;
static_assert(kNLoads * kNColsPerLoad == kChunkSizeL);
static constexpr bool kIsVecLoad = kIsVecLoad_;
using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
// using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
// using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
// static constexpr int kSmemSize = std::max({sizeof(typename BlockLoadT::TempStorage),
// sizeof(typename BlockStoreT::TempStorage)});
// static constexpr int kSmemSize = kChunkSizeL * kNEltsPerRow * kNBytes;
};
template<typename Ktraits, bool kHasSeqIdx>
__global__ __launch_bounds__(Ktraits::kNThreads)
void causal_conv1d_channellast_fwd_kernel(ConvParamsBase params) {
constexpr int kWidth = Ktraits::kWidth;
constexpr int kNThreads = Ktraits::kNThreads;
constexpr int kNElts = Ktraits::kNElts;
constexpr int kNThreadsPerC = Ktraits::kNThreadsPerRow;
constexpr int kLPerLoad = Ktraits::kNColsPerLoad;
constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
using input_t = typename Ktraits::input_t;
using vec_t = typename Ktraits::vec_t;
using weight_t = typename Ktraits::weight_t;
// Shared memory.
__shared__ input_t x_smem[kWidth - 1 + kChunkSizeL][kChunkSizeC + kNElts];
const int batch_id = blockIdx.x;
const int chunk_l_id = blockIdx.y;
const int chunk_c_id = blockIdx.z;
const int tid = threadIdx.x;
const int l_idx = tid / kNThreadsPerC;
const int c_idx = tid % kNThreadsPerC;
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
+ (chunk_l_id * kChunkSizeL + l_idx) * params.x_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr)
+ chunk_c_id * kChunkSizeC * params.weight_c_stride;
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
+ (chunk_l_id * kChunkSizeL + l_idx) * params.out_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
int *seq_idx = !kHasSeqIdx ? nullptr : reinterpret_cast<int *>(params.seq_idx_ptr)
+ batch_id * params.seqlen + chunk_l_id * kChunkSizeL;
input_t *initial_states = params.initial_states_ptr == nullptr || chunk_l_id > 0 ? nullptr
: reinterpret_cast<input_t *>(params.initial_states_ptr) + batch_id * params.initial_states_batch_stride + l_idx * params.initial_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
// The last L-chunk will also have enough info to write to final states, since it also contain a few x values
// from the previous L-chunk.
input_t *final_states = params.final_states_ptr == nullptr || chunk_l_id < gridDim.y - 1 ? nullptr
: reinterpret_cast<input_t *>(params.final_states_ptr) + batch_id * params.final_states_batch_stride + l_idx * params.final_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
#pragma unroll
for (int l = 0; l < Ktraits::kNLoads; ++l) {
input_t x_vals_load[kNElts] = {0};
if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x + l * kLPerLoad * params.x_l_stride);
}
reinterpret_cast<vec_t *>(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
}
// Load the elements from the previous chunk that are needed for convolution.
if (l_idx < kWidth - 1) {
input_t x_vals_load[kNElts] = {0};
if (chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) >= 0
&& chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < params.seqlen
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x - (kWidth - 1) * params.x_l_stride);
} else if (initial_states != nullptr
&& chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < 0
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(initial_states);
}
reinterpret_cast<vec_t *>(x_smem[l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
}
__syncthreads();
if (final_states != nullptr
&& l_idx < kWidth - 1
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
// x_smem[0] contains element at index chunk_l_id * kChunkSizeL - (kWidth - 1)
// So last few elements (index params.seqlen - kWidth + 1 + l_idx) are stored in x_smem[params.seqlen - kWidth + 1 + l_idx - (chunk_l_id * kChunkSizeL - kWidth + 1)][c_idx]
*reinterpret_cast<vec_t *>(final_states) = reinterpret_cast<vec_t *>(x_smem[params.seqlen + l_idx - chunk_l_id * kChunkSizeL])[c_idx];
}
constexpr int kLPerThread = constexpr_min(kChunkSizeL * kChunkSizeC / kNThreads, kChunkSizeL);
static_assert(kLPerThread * kNThreads == kChunkSizeL * kChunkSizeC);
constexpr int kNThreadsPerRow = kChunkSizeL / kLPerThread;
static_assert(kNThreadsPerRow * kLPerThread == kChunkSizeL);
// kChunkSizeL, kLPerThread, kNThreadsPerRow should be powers of 2 for simplicity
static_assert((kChunkSizeL & (kChunkSizeL - 1)) == 0);
static_assert((kLPerThread & (kLPerThread - 1)) == 0);
static_assert((kNThreadsPerRow & (kNThreadsPerRow - 1)) == 0);
static_assert(kNThreadsPerRow <= 32);
const int row_idx = tid / kNThreadsPerRow;
const int col_idx = tid % kNThreadsPerRow;
float bias_val = params.bias_ptr == nullptr || chunk_c_id * kChunkSizeC + row_idx >= params.dim ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[chunk_c_id * kChunkSizeC + row_idx]);
float weight_vals[kWidth] = {0};
if (chunk_c_id * kChunkSizeC + row_idx < params.dim) {
#pragma unroll
for (int w = 0; w < kWidth; ++w) {
weight_vals[w] = weight[row_idx * params.weight_c_stride + w * params.weight_width_stride];
}
}
float x_vals[kWidth - 1 + kLPerThread];
#pragma unroll
for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) {
x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]);
}
int seq_idx_thread[kWidth - 1 + kLPerThread];
if constexpr (kHasSeqIdx) {
#pragma unroll
for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) {
seq_idx_thread[i] = chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i - (kWidth - 1) >= 0 ? seq_idx[col_idx * kLPerThread + i - (kWidth - 1)] : -1;
}
}
float out_vals[kLPerThread];
#pragma unroll
for (int i = 0; i < kLPerThread; ++i) {
out_vals[i] = bias_val;
const int seq_idx_cur = !kHasSeqIdx ? 0 : seq_idx_thread[i + kWidth - 1];
#pragma unroll
for (int w = 0; w < kWidth; ++w) {
if constexpr (!kHasSeqIdx) {
out_vals[i] += weight_vals[w] * x_vals[i + w];
} else {
out_vals[i] += seq_idx_thread[i + w] == seq_idx_cur ? weight_vals[w] * x_vals[i + w] : 0.f;
}
}
if (params.silu_activation) {out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i])); }
}
__syncthreads();
#pragma unroll
for (int i = 0; i < kLPerThread; ++i) { x_smem[col_idx * kLPerThread + i][row_idx] = out_vals[i]; }
__syncthreads();
#pragma unroll
for (int l = 0; l < Ktraits::kNLoads; ++l) {
input_t out_vals_store[kNElts];
reinterpret_cast<vec_t *>(out_vals_store)[0] = reinterpret_cast<vec_t *>(x_smem[l * kLPerLoad + l_idx])[c_idx];
if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
*reinterpret_cast<vec_t *>(out + l * kLPerLoad * params.out_l_stride) = reinterpret_cast<vec_t *>(out_vals_store)[0];
}
}
}
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
void causal_conv1d_channellast_fwd_launch(ConvParamsBase &params, cudaStream_t stream) {
BOOL_SWITCH(params.seq_idx_ptr != nullptr, kHasSeqIdx, [&] {
using Ktraits = Causal_conv1d_channellast_fwd_kernel_traits<kNThreads, kWidth, 64, true, input_t, weight_t>;
// constexpr int kSmemSize = Ktraits::kSmemSize;
constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
const int n_chunks_L = (params.seqlen + kChunkSizeL - 1) / kChunkSizeL;
const int n_chunks_C = (params.dim + kChunkSizeC - 1) / kChunkSizeC;
dim3 grid(params.batch, n_chunks_L, n_chunks_C);
dim3 block(Ktraits::kNThreads);
auto kernel = &causal_conv1d_channellast_fwd_kernel<Ktraits, kHasSeqIdx>;
// if (kSmemSize >= 48 * 1024) {
// C10_CUDA_CHECK(cudaFuncSetAttribute(
// kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
// }
// kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
}
template<typename input_t, typename weight_t>
void causal_conv1d_channellast_fwd_cuda(ConvParamsBase &params, cudaStream_t stream) {
if (params.width == 2) {
causal_conv1d_channellast_fwd_launch<128, 2, input_t, weight_t>(params, stream);
} else if (params.width == 3) {
causal_conv1d_channellast_fwd_launch<128, 3, input_t, weight_t>(params, stream);
} else if (params.width == 4) {
causal_conv1d_channellast_fwd_launch<128, 4, input_t, weight_t>(params, stream);
}
}
template void causal_conv1d_fwd_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_fwd_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_channellast_fwd_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_channellast_fwd_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
template void causal_conv1d_channellast_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
///////
......@@ -649,7 +516,7 @@ struct Causal_conv1d_update_kernel_traits {
static_assert(kNBytes == 2 || kNBytes == 4);
};
template<typename Ktraits>
template<typename Ktraits, bool kIsCircularBuffer>
__global__ __launch_bounds__(Ktraits::kNThreads)
void causal_conv1d_update_kernel(ConvParamsBase params) {
constexpr int kWidth = Ktraits::kWidth;
......@@ -660,6 +527,8 @@ void causal_conv1d_update_kernel(ConvParamsBase params) {
const int tidx = threadIdx.x;
const int batch_id = blockIdx.x;
const int channel_id = blockIdx.y * kNThreads + tidx;
if (channel_id >= params.dim) return;
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
+ channel_id * params.x_c_stride;
......@@ -668,6 +537,10 @@ void causal_conv1d_update_kernel(ConvParamsBase params) {
const int conv_state_batch_coord = params.conv_state_indices_ptr == nullptr
? batch_id
: params.conv_state_indices_ptr[batch_id];
// conv_state_batch_coord == params.pad_slot_id is defined as padding so we exit early
if (conv_state_batch_coord == params.pad_slot_id){
return;
}
input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr)
+ conv_state_batch_coord * params.conv_state_batch_stride
+ channel_id * params.conv_state_c_stride;
......@@ -675,35 +548,70 @@ void causal_conv1d_update_kernel(ConvParamsBase params) {
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
+ channel_id * params.out_c_stride;
float bias_val = params.bias_ptr == nullptr || channel_id >= params.dim ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
int state_len = params.conv_state_len;
int advance_len = params.seqlen;
int cache_seqlen = kIsCircularBuffer ? params.cache_seqlens[batch_id] % state_len : 0;
int update_idx = cache_seqlen - (kWidth - 1);
update_idx = update_idx < 0 ? update_idx + state_len : update_idx;
float weight_vals[kWidth] = {0};
if (channel_id < params.dim) {
#pragma unroll
for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
}
#pragma unroll
for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
float x_vals[kWidth] = {0};
if (channel_id < params.dim) {
if constexpr (!kIsCircularBuffer) {
#pragma unroll 2
for (int i = 0; i < state_len - advance_len - (kWidth - 1); ++i) {
conv_state[i * params.conv_state_l_stride] = conv_state[(i + advance_len) * params.conv_state_l_stride];
}
#pragma unroll
for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = float(conv_state[(i + 1) * params.conv_state_l_stride]); }
x_vals[kWidth - 1] = float(x[0]);
for (int i = 0; i < kWidth - 1; ++i) {
input_t state_val = conv_state[(state_len - (kWidth - 1) + i) * params.conv_state_l_stride];
if (i < advance_len + (kWidth - 1) && state_len - advance_len - (kWidth - 1) + i >= 0) {
conv_state[(state_len - advance_len - (kWidth - 1) + i) * params.conv_state_l_stride] = state_val;
}
x_vals[i] = float(state_val);
}
} else {
#pragma unroll
for (int i = 0; i < kWidth; ++i) { conv_state[i * params.conv_state_l_stride] = input_t(x_vals[i]); }
for (int i = 0; i < kWidth - 1; ++i, update_idx = update_idx + 1 >= state_len ? update_idx + 1 - state_len : update_idx + 1) {
input_t state_val = conv_state[update_idx * params.conv_state_l_stride];
x_vals[i] = float(state_val);
}
}
#pragma unroll 2
for (int i = 0; i < params.seqlen; ++i) {
input_t x_val = x[i * params.x_l_stride];
if constexpr (!kIsCircularBuffer) {
if (i < advance_len && state_len - advance_len + i >= 0) {
conv_state[(state_len - advance_len + i) * params.conv_state_l_stride] = x_val;
}
} else {
conv_state[update_idx * params.conv_state_l_stride] = x_val;
++update_idx;
update_idx = update_idx >= state_len ? update_idx - state_len : update_idx;
}
x_vals[kWidth - 1] = float(x_val);
float out_val = bias_val;
#pragma unroll
for (int j = 0; j < kWidth; ++j) { out_val += weight_vals[j] * x_vals[j]; }
if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); }
out[i * params.out_l_stride] = input_t(out_val);
// Shift the input buffer by 1
#pragma unroll
for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = x_vals[i + 1]; }
}
float out_val = bias_val;
#pragma unroll
for (int i = 0; i < kWidth; ++i) { out_val += weight_vals[i] * x_vals[i]; }
if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); }
if (channel_id < params.dim) { out[0] = input_t(out_val); }
}
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
void causal_conv1d_update_launch(ConvParamsBase &params, cudaStream_t stream) {
using Ktraits = Causal_conv1d_update_kernel_traits<kNThreads, kWidth, input_t, weight_t>;
dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads);
auto kernel = &causal_conv1d_update_kernel<Ktraits>;
auto kernel = params.cache_seqlens == nullptr
? &causal_conv1d_update_kernel<Ktraits, false>
: &causal_conv1d_update_kernel<Ktraits, true>;
kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
......
......@@ -13,6 +13,7 @@ struct ConvParamsBase {
using index_t = uint32_t;
int batch, dim, seqlen, width;
int64_t pad_slot_id;
bool silu_activation;
index_t x_batch_stride;
......@@ -24,6 +25,7 @@ struct ConvParamsBase {
index_t out_c_stride;
index_t out_l_stride;
int conv_state_len;
index_t conv_state_batch_stride;
index_t conv_state_c_stride;
index_t conv_state_l_stride;
......@@ -35,6 +37,10 @@ struct ConvParamsBase {
void *__restrict__ out_ptr;
void *__restrict__ conv_state_ptr;
void *__restrict__ query_start_loc_ptr;
void *__restrict__ has_initial_state_ptr;
void *__restrict__ cache_indices_ptr;
int32_t *__restrict__ cache_seqlens;
// For the continuous batching case. Makes it so that the mamba state for
// the current batch doesn't need to be a contiguous tensor.
......@@ -52,6 +58,11 @@ struct ConvParamsBase {
index_t final_states_batch_stride;
index_t final_states_l_stride;
index_t final_states_c_stride;
void * conv_states_ptr;
index_t conv_states_batch_stride;
index_t conv_states_l_stride;
index_t conv_states_c_stride;
};
......
......@@ -21,6 +21,7 @@ struct SSMParamsBase {
int dim_ngroups_ratio;
bool is_variable_B;
bool is_variable_C;
int64_t pad_slot_id;
bool delta_softplus;
......@@ -54,10 +55,14 @@ struct SSMParamsBase {
void *__restrict__ delta_ptr;
void *__restrict__ delta_bias_ptr;
void *__restrict__ out_ptr;
void *__restrict__ x_ptr;
void *__restrict__ ssm_states_ptr;
void *__restrict__ z_ptr;
void *__restrict__ out_z_ptr;
void *__restrict__ index_ptr;
void *__restrict__ query_start_loc_ptr;
void *__restrict__ cache_indices_ptr;
void *__restrict__ has_initial_state_ptr;
};
......@@ -201,7 +206,7 @@ inline __device__ void load_input(typename Ktraits::input_t *u,
typename Ktraits::input_t (&u_vals)[Ktraits::kNItems],
typename Ktraits::BlockLoadT::TempStorage &smem_load,
int seqlen) {
if constexpr (Ktraits::kIsEvenLen) {
if constexpr (Ktraits::kIsEvenLen && !Ktraits::kVarlen) {
auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_load);
using vec_t = typename Ktraits::vec_t;
typename Ktraits::BlockLoadVecT(smem_load_vec).Load(
......@@ -217,21 +222,6 @@ inline __device__ void load_input(typename Ktraits::input_t *u,
}
}
template<typename Ktraits>
inline __device__ void load_index(int *u,
int (&u_vals)[Ktraits::kNItems],
typename Ktraits::BlockLoadIndexT::TempStorage &smem_load_index,
int seqlen) {
if constexpr (Ktraits::kIsEvenLen) {
auto& smem_load_index_vec = reinterpret_cast<typename Ktraits::BlockLoadIndexVecT::TempStorage&>(smem_load_index);
Ktraits::BlockLoadIndexVecT(smem_load_index_vec).Load(
reinterpret_cast<uint4*>(u),
reinterpret_cast<uint4(&)[Ktraits::kNLoadsIndex]>(u_vals)
);
} else {
Ktraits::BlockLoadIndexT(smem_load_index).Load(u, u_vals, seqlen, 0);
}
}
template<typename Ktraits>
inline __device__ void load_weight(typename Ktraits::input_t *Bvar,
......@@ -240,7 +230,7 @@ inline __device__ void load_weight(typename Ktraits::input_t *Bvar,
int seqlen) {
constexpr int kNItems = Ktraits::kNItems;
typename Ktraits::input_t B_vals_load[kNItems];
if constexpr (Ktraits::kIsEvenLen) {
if constexpr (Ktraits::kIsEvenLen && !Ktraits::kVarlen) {
auto& smem_load_weight_vec = reinterpret_cast<typename Ktraits::BlockLoadWeightVecT::TempStorage&>(smem_load_weight);
using vec_t = typename Ktraits::vec_t;
typename Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load(
......@@ -263,7 +253,7 @@ inline __device__ void store_output(typename Ktraits::input_t *out,
typename Ktraits::input_t write_vals[Ktraits::kNItems];
#pragma unroll
for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; }
if constexpr (Ktraits::kIsEvenLen) {
if constexpr (Ktraits::kIsEvenLen && !Ktraits::kVarlen) {
auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_store);
using vec_t = typename Ktraits::vec_t;
typename Ktraits::BlockStoreVecT(smem_store_vec).Store(
......
......@@ -23,7 +23,7 @@
template<int kNThreads_, int kNItems_, int kNRows_, bool kIsEvenLen_,
bool kIsVariableB_, bool kIsVariableC_,
bool kHasZ_, bool kUseIndex_, typename input_t_, typename weight_t_>
bool kHasZ_, bool kVarlen_, typename input_t_, typename weight_t_>
struct Selective_Scan_fwd_kernel_traits {
static_assert(kNItems_ % 4 == 0);
using input_t = input_t_;
......@@ -38,22 +38,19 @@ struct Selective_Scan_fwd_kernel_traits {
static constexpr int kNElts = kNBytes == 4 ? 4 : constexpr_min(8, kNItems);
static_assert(kNItems % kNElts == 0);
static constexpr int kNLoads = kNItems / kNElts;
static constexpr bool kIsEvenLen = kIsEvenLen_;
static constexpr bool kIsEvenLen = kVarlen_ ? false : kIsEvenLen_;
static constexpr bool kIsVariableB = kIsVariableB_;
static constexpr bool kIsVariableC = kIsVariableC_;
static constexpr bool kHasZ = kHasZ_;
static constexpr bool kUseIndex = kUseIndex_;
static constexpr bool kVarlen = kVarlen_;
static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1;
static constexpr bool kDirectIO = kVarlen_ ? false : kIsEvenLen && kNLoads == 1;
static constexpr int kNLoadsIndex = kNItems / 4;
using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
using scan_t = float2;
using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads,
!kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
using BlockLoadIndexT = cub::BlockLoad<int, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
using BlockLoadIndexVecT = cub::BlockLoad<uint4, kNThreads, kNLoadsIndex,
!(kIsEvenLen && kNLoadsIndex == 1) ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
using BlockLoadWeightT = cub::BlockLoad<input_t, kNThreads, kNItems , cub::BLOCK_LOAD_WARP_TRANSPOSE>;
using BlockLoadWeightVecT = cub::BlockLoad<vec_t, kNThreads, kNLoads ,
!kDirectIO ? cub::BLOCK_LOAD_WARP_TRANSPOSE : cub::BLOCK_LOAD_DIRECT>;
......@@ -65,8 +62,6 @@ struct Selective_Scan_fwd_kernel_traits {
using BlockScanT = cub::BlockScan<scan_t, kNThreads, cub::BLOCK_SCAN_WARP_SCANS>;
static constexpr int kSmemIOSize = custom_max({sizeof(typename BlockLoadT::TempStorage),
sizeof(typename BlockLoadVecT::TempStorage),
sizeof(typename BlockLoadIndexT::TempStorage),
sizeof(typename BlockLoadIndexVecT::TempStorage),
(int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage),
(int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage),
sizeof(typename BlockStoreT::TempStorage),
......@@ -80,7 +75,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
constexpr bool kIsVariableB = Ktraits::kIsVariableB;
constexpr bool kIsVariableC = Ktraits::kIsVariableC;
constexpr bool kHasZ = Ktraits::kHasZ;
constexpr bool kUseIndex = Ktraits::kUseIndex;
constexpr bool kVarlen = Ktraits::kVarlen;
constexpr int kNThreads = Ktraits::kNThreads;
constexpr int kNItems = Ktraits::kNItems;
constexpr int kNRows = Ktraits::kNRows;
......@@ -97,7 +92,6 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
// auto& smem_load = reinterpret_cast<typename BlockLoadT::TempStorage&>(smem_loadstorescan);
auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
auto& smem_load_weight = reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage&>(smem_);
auto& smem_load_index = reinterpret_cast<typename Ktraits::BlockLoadIndexT::TempStorage&>(smem_);
auto& smem_load_weight1 = *reinterpret_cast<typename Ktraits::BlockLoadWeightT::TempStorage*>(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage));
auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
auto& smem_scan = *reinterpret_cast<typename Ktraits::BlockScanT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);
......@@ -108,17 +102,33 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
const int batch_id = blockIdx.x;
const int dim_id = blockIdx.y;
const int group_id = dim_id / (params.dim_ngroups_ratio);
input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + batch_id * params.u_batch_stride
int seqlen = params.seqlen;
int sequence_start_index = batch_id;
if constexpr (kVarlen){
int *query_start_loc = reinterpret_cast<int *>(params.query_start_loc_ptr);
sequence_start_index = query_start_loc[batch_id];
seqlen = query_start_loc[batch_id + 1] - sequence_start_index;
}
const bool has_initial_state = params.has_initial_state_ptr == nullptr ? false
: reinterpret_cast<bool *>(params.has_initial_state_ptr)[batch_id];
const int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr
: reinterpret_cast<int *>(params.cache_indices_ptr);
const int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id];
// cache_index == params.pad_slot_id is defined as padding, so we exit early
if (cache_index == params.pad_slot_id){
return;
}
input_t *u = reinterpret_cast<input_t *>(params.u_ptr) + sequence_start_index * params.u_batch_stride
+ dim_id * kNRows * params.u_d_stride;
input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + batch_id * params.delta_batch_stride
input_t *delta = reinterpret_cast<input_t *>(params.delta_ptr) + sequence_start_index * params.delta_batch_stride
+ dim_id * kNRows * params.delta_d_stride;
weight_t *A = reinterpret_cast<weight_t *>(params.A_ptr) + dim_id * kNRows * params.A_d_stride;
weight_t *B = reinterpret_cast<weight_t *>(params.B_ptr) + dim_id * kNRows * params.B_d_stride;
input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride;
input_t *Bvar = reinterpret_cast<input_t *>(params.B_ptr) + sequence_start_index * params.B_batch_stride + group_id * params.B_group_stride;
weight_t *C = reinterpret_cast<weight_t *>(params.C_ptr) + dim_id * kNRows * params.C_d_stride;
input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride;
scan_t *x = reinterpret_cast<scan_t *>(params.x_ptr) + (batch_id * params.dim + dim_id * kNRows) * params.n_chunks * params.dstate;
int *index = !kUseIndex ? nullptr :reinterpret_cast<int *>(params.index_ptr) + batch_id * params.seqlen;
input_t *Cvar = reinterpret_cast<input_t *>(params.C_ptr) + sequence_start_index * params.C_batch_stride + group_id * params.C_group_stride;
input_t *ssm_states = reinterpret_cast<input_t *>(params.ssm_states_ptr) + (cache_index * params.dim + dim_id * kNRows) * params.dstate;
float D_val[kNRows] = {0};
if (params.D_ptr != nullptr) {
......@@ -142,9 +152,9 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
// }
constexpr int kChunkSize = kNThreads * kNItems;
for (int chunk = 0; chunk < params.n_chunks; ++chunk) {
const int n_chunks = (seqlen + 2048 - 1) / 2048;
for (int chunk = 0; chunk < n_chunks; ++chunk) {
input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems];
int index_vals_load[kNRows][kNItems];
__syncthreads();
#pragma unroll
......@@ -152,15 +162,9 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
if constexpr (!kDirectIO) {
if (r > 0) { __syncthreads(); }
}
load_input<Ktraits>(u + r * params.u_d_stride, u_vals[r], smem_load, params.seqlen - chunk * kChunkSize);
load_input<Ktraits>(u + r * params.u_d_stride, u_vals[r], smem_load, seqlen - chunk * kChunkSize);
if constexpr (!kDirectIO) { __syncthreads(); }
load_input<Ktraits>(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize);
if constexpr (kUseIndex) {
load_index<Ktraits>(index + r * params.delta_d_stride, index_vals_load[r], smem_load_index, params.seqlen - chunk * kChunkSize);
}
}
if constexpr (kUseIndex) {
index += kChunkSize;
load_input<Ktraits>(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, seqlen - chunk * kChunkSize);
}
u += kChunkSize;
delta += kChunkSize;
......@@ -195,9 +199,9 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
// If both B and C vary, this is unused.
weight_t BC_val[kNRows];
weight_t B_vals[kNItems], C_vals[kNItems];
if constexpr (kIsVariableB) {
if constexpr (kIsVariableB) {
load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,
smem_load_weight, (params.seqlen - chunk * kChunkSize) * (1));
smem_load_weight, (seqlen - chunk * kChunkSize) * (1));
if constexpr (!kIsVariableC) {
#pragma unroll
for (int r = 0; r < kNRows; ++r) {
......@@ -208,7 +212,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
if constexpr (kIsVariableC) {
auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1;
load_weight<Ktraits>(Cvar + state_idx * params.C_dstate_stride, C_vals,
smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (1 ));
smem_load_weight_C, (seqlen - chunk * kChunkSize) * (1 ));
if constexpr (!kIsVariableB) {
#pragma unroll
for (int r = 0; r < kNRows; ++r) {
......@@ -232,24 +236,16 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]),
!kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]);
// Reset A bar for cumulative sequences (Real)
if constexpr (kUseIndex) {
if (index_vals_load[r][i] == 0) {
thread_data[i].x = 0.f;
}
}
if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct
if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) {
if (seqlen % (kNItems * kNThreads) != 0) { // So that the last state is correct
if (threadIdx.x * kNItems + i >= seqlen - chunk * kChunkSize) {
thread_data[i] = make_float2(1.f, 0.f);
}
}
}
// Initialize running total
scan_t running_prefix;
// If we use WARP_SCAN then all lane 0 of all warps (not just thread 0) needs to read
running_prefix = chunk == 0 ? x[(r * params.n_chunks) * params.dstate + state_idx] : ( threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.f, 0.f));
// running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float2(1.f, 0.f);
scan_t running_prefix = chunk > 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.0, has_initial_state ? float(ssm_states[state_idx]): 0.0);
SSMScanPrefixCallbackOp<weight_t> prefix_op(running_prefix);
typename Ktraits::BlockScanT(smem_scan).InclusiveScan(
thread_data, thread_data, SSMScanOp<weight_t>(), prefix_op
......@@ -258,7 +254,9 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
// Unless there's only 1 warp, but then it's the same thread (0) reading and writing.
if (threadIdx.x == 0) {
smem_running_prefix[state_idx] = prefix_op.running_prefix;
x[(r * params.n_chunks + chunk) * params.dstate + state_idx] = prefix_op.running_prefix;
if (chunk == n_chunks - 1) {
ssm_states[state_idx] = input_t(prefix_op.running_prefix.y);
}
}
#pragma unroll
for (int i = 0; i < kNItems; ++i) {
......@@ -270,7 +268,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
}
}
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + sequence_start_index * params.out_batch_stride
+ dim_id * kNRows * params.out_d_stride + chunk * kChunkSize;
__syncthreads();
#pragma unroll
......@@ -278,26 +276,26 @@ void selective_scan_fwd_kernel(SSMParamsBase params) {
if constexpr (!kDirectIO) {
if (r > 0) { __syncthreads(); }
}
store_output<Ktraits>(out + r * params.out_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize);
store_output<Ktraits>(out + r * params.out_d_stride, out_vals[r], smem_store, seqlen - chunk * kChunkSize);
}
if constexpr (kHasZ) {
input_t *z = reinterpret_cast<input_t *>(params.z_ptr) + batch_id * params.z_batch_stride
input_t *z = reinterpret_cast<input_t *>(params.z_ptr) + sequence_start_index * params.z_batch_stride
+ dim_id * kNRows * params.z_d_stride + chunk * kChunkSize;
input_t *out_z = reinterpret_cast<input_t *>(params.out_z_ptr) + batch_id * params.out_z_batch_stride
input_t *out_z = reinterpret_cast<input_t *>(params.out_z_ptr) + sequence_start_index * params.out_z_batch_stride
+ dim_id * kNRows * params.out_z_d_stride + chunk * kChunkSize;
#pragma unroll
for (int r = 0; r < kNRows; ++r) {
input_t z_vals[kNItems];
__syncthreads();
load_input<Ktraits>(z + r * params.z_d_stride, z_vals, smem_load, params.seqlen - chunk * kChunkSize);
load_input<Ktraits>(z + r * params.z_d_stride, z_vals, smem_load, seqlen - chunk * kChunkSize);
#pragma unroll
for (int i = 0; i < kNItems; ++i) {
float z_val = z_vals[i];
out_vals[r][i] *= z_val / (1 + expf(-z_val));
}
__syncthreads();
store_output<Ktraits>(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize);
store_output<Ktraits>(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, seqlen - chunk * kChunkSize);
}
}
......@@ -316,8 +314,8 @@ void selective_scan_fwd_launch(SSMParamsBase &params, cudaStream_t stream) {
constexpr bool kIsVariableC = true;
constexpr bool kHasZ = true;
BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
BOOL_SWITCH(params.index_ptr != nullptr , kUseIndex, [&] {
using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kUseIndex, input_t, weight_t>;
BOOL_SWITCH(params.query_start_loc_ptr != nullptr , kVarlen, [&] {
using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kVarlen, input_t, weight_t>;
constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
dim3 grid(params.batch, params.dim / kNRows);
auto kernel = &selective_scan_fwd_kernel<Ktraits>;
......@@ -393,7 +391,6 @@ void set_ssm_params_fwd(SSMParamsBase &params,
const size_t seqlen,
const size_t dstate,
const size_t n_groups,
const size_t n_chunks,
const bool is_variable_B,
const bool is_variable_C,
// device pointers
......@@ -405,12 +402,16 @@ void set_ssm_params_fwd(SSMParamsBase &params,
const torch::Tensor out,
const torch::Tensor z,
const torch::Tensor out_z,
void* D_ptr,
void* delta_bias_ptr,
void* x_ptr,
const c10::optional<at::Tensor>& D,
const c10::optional<at::Tensor>& delta_bias,
const torch::Tensor ssm_states,
bool has_z,
bool delta_softplus,
void* index_ptr) {
const c10::optional<at::Tensor>& query_start_loc,
const c10::optional<at::Tensor>& cache_indices,
const c10::optional<at::Tensor>& has_initial_state,
bool varlen,
int64_t pad_slot_id) {
// Reset the parameters
memset(&params, 0, sizeof(params));
......@@ -420,8 +421,8 @@ void set_ssm_params_fwd(SSMParamsBase &params,
params.seqlen = seqlen;
params.dstate = dstate;
params.n_groups = n_groups;
params.n_chunks = n_chunks;
params.dim_ngroups_ratio = dim / n_groups;
params.pad_slot_id = pad_slot_id;
params.delta_softplus = delta_softplus;
......@@ -434,55 +435,86 @@ void set_ssm_params_fwd(SSMParamsBase &params,
params.A_ptr = A.data_ptr();
params.B_ptr = B.data_ptr();
params.C_ptr = C.data_ptr();
params.D_ptr = D_ptr;
params.delta_bias_ptr = delta_bias_ptr;
params.D_ptr = D.has_value() ? D.value().data_ptr() : nullptr;
params.delta_bias_ptr = delta_bias.has_value() ? delta_bias.value().data_ptr() : nullptr;
params.out_ptr = out.data_ptr();
params.x_ptr = x_ptr;
params.ssm_states_ptr = ssm_states.data_ptr();
params.z_ptr = has_z ? z.data_ptr() : nullptr;
params.out_z_ptr = has_z ? out_z.data_ptr() : nullptr;
params.query_start_loc_ptr = query_start_loc.has_value() ? query_start_loc.value().data_ptr() : nullptr;
params.cache_indices_ptr = cache_indices.has_value() ? cache_indices.value().data_ptr() : nullptr;
params.has_initial_state_ptr = has_initial_state.has_value() ? has_initial_state.value().data_ptr() : nullptr;
params.index_ptr = index_ptr;
// All stride are in elements, not bytes.
params.A_d_stride = A.stride(0);
params.A_dstate_stride = A.stride(1);
if (!is_variable_B) {
params.B_d_stride = B.stride(0);
} else {
params.B_batch_stride = B.stride(0);
params.B_group_stride = B.stride(1);
}
params.B_dstate_stride = !is_variable_B ? B.stride(1) : B.stride(2);
if (!is_variable_C) {
params.C_d_stride = C.stride(0);
} else {
params.C_batch_stride = C.stride(0);
params.C_group_stride = C.stride(1);
if (varlen){
params.B_batch_stride = B.stride(2);
params.B_group_stride = B.stride(0);
params.B_dstate_stride = B.stride(1);
params.C_batch_stride = C.stride(2);
params.C_group_stride = C.stride(0);
params.C_dstate_stride = C.stride(1);
params.u_batch_stride = u.stride(1);
params.u_d_stride = u.stride(0);
params.delta_batch_stride = delta.stride(1);
params.delta_d_stride = delta.stride(0);
if (has_z) {
params.z_batch_stride = z.stride(1);
params.z_d_stride = z.stride(0);
params.out_z_batch_stride = out_z.stride(1);
params.out_z_d_stride = out_z.stride(0);
}
params.out_batch_stride = out.stride(1);
params.out_d_stride = out.stride(0);
}
params.C_dstate_stride = !is_variable_C ? C.stride(1) : C.stride(2);
params.u_batch_stride = u.stride(0);
params.u_d_stride = u.stride(1);
params.delta_batch_stride = delta.stride(0);
params.delta_d_stride = delta.stride(1);
if (has_z) {
params.z_batch_stride = z.stride(0);
params.z_d_stride = z.stride(1);
params.out_z_batch_stride = out_z.stride(0);
params.out_z_d_stride = out_z.stride(1);
else{
if (!is_variable_B) {
params.B_d_stride = B.stride(0);
} else {
params.B_batch_stride = B.stride(0);
params.B_group_stride = B.stride(1);
}
params.B_dstate_stride = !is_variable_B ? B.stride(1) : B.stride(2);
if (!is_variable_C) {
params.C_d_stride = C.stride(0);
} else {
params.C_batch_stride = C.stride(0);
params.C_group_stride = C.stride(1);
}
params.C_dstate_stride = !is_variable_C ? C.stride(1) : C.stride(2);
params.u_batch_stride = u.stride(0);
params.u_d_stride = u.stride(1);
params.delta_batch_stride = delta.stride(0);
params.delta_d_stride = delta.stride(1);
if (has_z) {
params.z_batch_stride = z.stride(0);
params.z_d_stride = z.stride(1);
params.out_z_batch_stride = out_z.stride(0);
params.out_z_d_stride = out_z.stride(1);
}
params.out_batch_stride = out.stride(0);
params.out_d_stride = out.stride(1);
}
params.out_batch_stride = out.stride(0);
params.out_d_stride = out.stride(1);
}
std::vector<torch::Tensor>
selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
const torch::Tensor &A, const torch::Tensor &B, const torch::Tensor &C,
const c10::optional<torch::Tensor> &D_,
const c10::optional<torch::Tensor> &z_,
const c10::optional<torch::Tensor> &delta_bias_,
bool delta_softplus,
const c10::optional<torch::Tensor> &index_,
const c10::optional<torch::Tensor> &x) {
const c10::optional<torch::Tensor> &query_start_loc,
const c10::optional<torch::Tensor> &cache_indices,
const c10::optional<torch::Tensor> &has_initial_state,
const torch::Tensor &ssm_states,
// used to identify padding entries if cache_indices provided
// in case of padding, the kernel will return early
int64_t pad_slot_id) {
auto input_type = u.scalar_type();
auto weight_type = A.scalar_type();
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
......@@ -505,23 +537,37 @@ selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1);
const auto sizes = u.sizes();
const int batch_size = sizes[0];
const int dim = sizes[1];
const int seqlen = sizes[2];
const bool varlen = query_start_loc.has_value();
const int batch_size = varlen ? query_start_loc.value().sizes()[0] - 1 : sizes[0];
const int dim = varlen ? sizes[0] : sizes[1];
const int seqlen = varlen ? sizes[1] : sizes[2];
const int dstate = A.size(1);
const int n_groups = is_variable_B ? B.size(1) : 1;
const int n_groups = varlen ? B.size(0) : B.size(1);
TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256");
CHECK_SHAPE(u, batch_size, dim, seqlen);
CHECK_SHAPE(delta, batch_size, dim, seqlen);
if (varlen) {
CHECK_SHAPE(u, dim, seqlen);
CHECK_SHAPE(delta, dim, seqlen);
} else {
CHECK_SHAPE(u, batch_size, dim, seqlen);
CHECK_SHAPE(delta, batch_size, dim, seqlen);
}
CHECK_SHAPE(A, dim, dstate);
TORCH_CHECK(is_variable_B, "is_variable_B = False is disabled in favor of reduced binary size")
CHECK_SHAPE(B, batch_size, n_groups, dstate, seqlen );
if (varlen) {
CHECK_SHAPE(B, n_groups, dstate, seqlen);
} else {
CHECK_SHAPE(B, batch_size, n_groups, dstate, seqlen);
}
TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1);
TORCH_CHECK(is_variable_C, "is_variable_C = False is disabled in favor of reduced binary size")
CHECK_SHAPE(C, batch_size, n_groups, dstate, seqlen);
if (varlen) {
CHECK_SHAPE(C, n_groups, dstate, seqlen);
} else {
CHECK_SHAPE(C, batch_size, n_groups, dstate, seqlen);
}
TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1);
if (D_.has_value()) {
......@@ -539,13 +585,31 @@ selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1);
CHECK_SHAPE(delta_bias, dim);
}
if (index_.has_value()) {
auto index = index_.value();
TORCH_CHECK(index.scalar_type() == at::ScalarType::Int);
TORCH_CHECK(index.is_cuda());
CHECK_SHAPE(index, batch_size, seqlen);
if (has_initial_state.has_value()) {
auto has_initial_state_ = has_initial_state.value();
TORCH_CHECK(has_initial_state_.scalar_type() == at::ScalarType::Bool);
TORCH_CHECK(has_initial_state_.is_cuda());
CHECK_SHAPE(has_initial_state_, batch_size);
}
if (query_start_loc.has_value()) {
auto query_start_loc_ = query_start_loc.value();
TORCH_CHECK(query_start_loc_.scalar_type() == at::ScalarType::Int);
TORCH_CHECK(query_start_loc_.is_cuda());
}
if (cache_indices.has_value()) {
auto cache_indices_ = cache_indices.value();
TORCH_CHECK(cache_indices_.scalar_type() == at::ScalarType::Int);
TORCH_CHECK(cache_indices_.is_cuda());
CHECK_SHAPE(cache_indices_, batch_size);
}
at::Tensor z, out_z;
const bool has_z = z_.has_value();
TORCH_CHECK(has_z, "has_z = False is disabled in favor of reduced binary size")
......@@ -553,32 +617,36 @@ selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
TORCH_CHECK(z.scalar_type() == input_type);
TORCH_CHECK(z.is_cuda());
TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
CHECK_SHAPE(z, batch_size, dim, seqlen);
out_z = torch::empty_like(z);
if (varlen){
CHECK_SHAPE(z, dim, seqlen);
} else {
CHECK_SHAPE(z, batch_size, dim, seqlen);
}
out_z = z;
const int n_chunks = (seqlen + 2048 - 1) / 2048;
// const int n_chunks = (seqlen + 1024 - 1) / 1024;
// at::Tensor out = torch::empty_like(u);
// Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
at::Tensor out = torch::empty_like(delta);
if (x.has_value()){
auto _x = x.value();
TORCH_CHECK(_x.scalar_type() == weight_type);
TORCH_CHECK(_x.is_cuda());
TORCH_CHECK(_x.stride(-1) == 1);
CHECK_SHAPE(_x, batch_size, dim, n_chunks, dstate * 2);
}
at::Tensor out = delta;
TORCH_CHECK(ssm_states.scalar_type() == input_type);
TORCH_CHECK(ssm_states.is_cuda());
TORCH_CHECK(ssm_states.stride(-1) == 1);
SSMParamsBase params;
set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C,
set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, is_variable_B, is_variable_C,
u, delta, A, B, C, out, z, out_z,
D_.has_value() ? D_.value().data_ptr() : nullptr,
delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr,
x.value().data_ptr(),
D_,
delta_bias_,
ssm_states,
has_z,
delta_softplus,
index_.has_value() ? index_.value().data_ptr() : nullptr);
query_start_loc,
cache_indices,
has_initial_state,
varlen,
pad_slot_id
);
// Otherwise the kernel will be launched from cuda:0 device
// Cast to char to avoid compiler warning about narrowing
at::cuda::CUDAGuard device_guard{(char)u.get_device()};
......@@ -586,8 +654,5 @@ selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] {
selective_scan_fwd_cuda<input_t, weight_t>(params, stream);
});
std::vector<at::Tensor> result = {out};
if (has_z) { result.push_back(out_z); }
return result;
}
......@@ -38,6 +38,7 @@ using FragA = Vec<half2, 4>;
using FragB = Vec<half2, 2>;
using FragC = Vec<float, 4>;
using FragS = Vec<half2, 1>; // quantization scales
using FragZP = Vec<half2, 4>;
// Predicated asynchronous global->shared copy; used for inputs A where we apply
// predication to handle batchsizes that are not multiples of 16.
......@@ -175,6 +176,46 @@ __device__ inline FragB dequant<vllm::kU8B128.id()>(int q) {
return frag_b;
}
template <>
__device__ inline FragB dequant<vllm::kU4.id()>(int q) {
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
const int EX = 0x64006400;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
const int SUB = 0x64006400;
const int MUL = 0x2c002c00;
const int ADD = 0xd400d400;
FragB frag_b;
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
*reinterpret_cast<const half2*>(&SUB));
frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
*reinterpret_cast<const half2*>(&MUL),
*reinterpret_cast<const half2*>(&ADD));
return frag_b;
}
template <>
__device__ inline FragB dequant<vllm::kU8.id()>(int q) {
static constexpr uint32_t mask_for_elt_01 = 0x5250;
static constexpr uint32_t mask_for_elt_23 = 0x5351;
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400;
FragB frag_b;
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
return frag_b;
}
// Multiply dequantized values by the corresponding quantization scale; used
// only for grouped quantization.
__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
......@@ -183,11 +224,10 @@ __device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) {
frag_b[1] = __hmul2(frag_b[1], s);
}
// Given 2 floats multiply by 2 scales (halves)
__device__ inline void scale_float(float* c, FragS& s) {
__half* s_ptr = reinterpret_cast<__half*>(&s);
c[0] = __fmul_rn(c[0], __half2float(s_ptr[0]));
c[1] = __fmul_rn(c[1], __half2float(s_ptr[1]));
__device__ inline void sub_zp(FragB& frag_b, half2& frag_zp, int i) {
half2 zp = __half2half2(reinterpret_cast<__half*>(&frag_zp)[i]);
frag_b[0] = __hsub2(frag_b[0], zp);
frag_b[1] = __hsub2(frag_b[1], zp);
}
// Same as above, but for act_order (each K is multiplied individually)
......@@ -205,6 +245,13 @@ __device__ inline void scale4(FragB& frag_b, FragS& frag_s_1, FragS& frag_s_2,
frag_b[1] = __hmul2(frag_b[1], s_val_3_4);
}
// Given 2 floats multiply by 2 scales (halves)
__device__ inline void scale_float(float* c, FragS& s) {
__half* s_ptr = reinterpret_cast<__half*>(&s);
c[0] = __fmul_rn(c[0], __half2float(s_ptr[0]));
c[1] = __fmul_rn(c[1], __half2float(s_ptr[1]));
}
// Wait until barrier reaches `count`, then lock for current threadblock.
__device__ inline void barrier_acquire(int* lock, int count) {
if (threadIdx.x == 0) {
......@@ -248,10 +295,11 @@ template <const vllm::ScalarTypeId w_type_id, // weight ScalarType id
const int stages, // number of stages for the async global->shared
// fetch pipeline
const bool has_act_order, // whether act_order is enabled
const bool has_zp, // whether zero-points are enabled
const int group_blocks = -1 // number of consecutive 16x16 blocks
// with a separate quantization scale
>
__device__ inline void MarlinMoESingle(
__device__ void MarlinMoESingle(
const int4* __restrict__ A, // fp16 input matrix of shape mxk
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
int4* __restrict__ C, // fp16 output buffer of shape mxn
......@@ -259,6 +307,8 @@ __device__ inline void MarlinMoESingle(
const float* __restrict__ topk_weights, // float topk weights
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
// (k/groupsize)xn
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
// (k/groupsize)x(n/pack_factor)
const int* __restrict__ g_idx, // int32 group indices of shape k
const int* __restrict__ expert_offsets,
int num_groups, // number of scale groups per output channel
......@@ -400,8 +450,12 @@ __device__ inline void MarlinMoESingle(
int tb_n_warps = thread_n_blocks / 4;
int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps;
constexpr int sorted_sh_stride = threads;
constexpr int sorted_gl_stride = threads;
// Zero-points sizes/strides
int zp_gl_stride = (prob_n / pack_factor) / 4;
constexpr int zp_sh_stride = ((16 * thread_n_blocks) / pack_factor) / 4;
constexpr int zp_tb_groups = s_tb_groups;
constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0;
int zp_gl_rd_delta = zp_gl_stride;
// Global A read index of current thread.
int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
......@@ -442,6 +496,19 @@ __device__ inline void MarlinMoESingle(
int s_sh_wr = threadIdx.x;
bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
// Zero-points
int zp_gl_rd;
if constexpr (has_zp) {
if constexpr (group_blocks == -1) {
zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
} else {
zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
zp_sh_stride * slice_col + threadIdx.x;
}
}
int zp_sh_wr = threadIdx.x;
bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride;
// We use a different scale layout for grouped and column-wise quantization as
// we scale a `half2` tile in column-major layout in the former and in
// row-major in the latter case.
......@@ -453,23 +520,29 @@ __device__ inline void MarlinMoESingle(
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) % 4;
// Zero-points have the same read layout as the scales
// (without column-wise case)
constexpr int num_col_threads = 8;
constexpr int num_row_threads = 4;
constexpr int num_ints_per_thread = 8 / pack_factor;
int zp_sh_rd;
if constexpr (has_zp) {
zp_sh_rd = num_ints_per_thread * num_col_threads *
((threadIdx.x / 32) % (thread_n_blocks / 4)) +
num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads);
}
int sh_first_group_id = -1;
int sh_num_groups = -1;
constexpr int sh_max_num_groups = 32;
int shs_size;
if constexpr (has_act_order)
shs_size = sh_max_num_groups * s_sh_stride + threads;
else
shs_size = group_blocks > 0 ? stages * s_sh_stage : threads;
extern __shared__ int4 sh[];
// Shared memory storage for global fetch pipelines.
int4* sh_a = sh;
int4* sh_b = sh_a + (stages * a_sh_stage);
int4* sh_g_idx = sh_b + (stages * b_sh_stage);
int4* sh_s = sh_g_idx + (stages * g_idx_stage);
int* sh_sorted = (int*)(sh_s + shs_size);
int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
int4* sh_s = sh_zp + (stages * zp_sh_stage);
// Precompute which thread should not read memory in which iterations; this is
// needed if there are more threads than required for a certain tilesize or
......@@ -525,8 +598,10 @@ __device__ inline void MarlinMoESingle(
FragA frag_a[2][thread_m_blocks];
I4 frag_b_quant[2][b_thread_vecs];
FragC frag_c[thread_m_blocks][4][2];
FragS frag_s[2][4]; // No act-order
FragS act_frag_s[2][4][4]; // For act-order
FragS frag_s[2][4]; // No act-order
FragS act_frag_s[2][4][4]; // For act-order
int frag_qzp[2][num_ints_per_thread]; // Zero-points
FragZP frag_zp; // Zero-points in fp16
// Zero accumulators.
auto zero_accums = [&]() {
......@@ -633,6 +708,28 @@ __device__ inline void MarlinMoESingle(
}
}
}
if constexpr (has_zp && group_blocks != -1) {
int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
if constexpr (group_blocks >= thread_k_blocks) {
// Only fetch zero-points if this tile starts a new group
if (pipe % (group_blocks / thread_k_blocks) == 0) {
if (zp_sh_wr_pred) {
cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]);
}
zp_gl_rd += zp_gl_rd_delta;
}
} else {
for (int i = 0; i < zp_tb_groups; i++) {
if (zp_sh_wr_pred) {
cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr],
&zp_ptr[zp_gl_rd]);
}
zp_gl_rd += zp_gl_rd_delta;
}
}
}
}
}
// Insert a fence even when we are winding down the pipeline to ensure that
......@@ -640,15 +737,9 @@ __device__ inline void MarlinMoESingle(
cp_async_fence();
};
// TODO we are currently hitting illegal memory accesses when fetching
// sorted_ids to shared data: fix this
auto fetch_sorted_ids_to_shared = [&]() {
const int mpt = ceildiv(prob_m, threads);
for (int i = 0; i < mpt; i++) {
if ((i * sorted_gl_stride) + threadIdx.x < prob_m) {
sh_sorted[(i * sorted_sh_stride) + threadIdx.x] =
sorted_ids[(i * sorted_gl_stride) + threadIdx.x];
}
auto fetch_zp_to_shared = [&]() {
if (zp_sh_wr_pred) {
cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]);
}
};
......@@ -799,8 +890,83 @@ __device__ inline void MarlinMoESingle(
}
};
auto fetch_zp_to_registers = [&](int k, int full_pipe) {
// This code does not handle group_blocks == 0,
// which signifies act_order.
// has_zp implies AWQ, which doesn't have act_order,
static_assert(!has_zp || group_blocks != 0);
if constexpr (has_zp) {
int pipe = full_pipe % stages;
if constexpr (group_blocks == -1) {
for (int i = 0; i < num_ints_per_thread; i++) {
frag_qzp[k % 2][i] = (reinterpret_cast<int*>(sh_zp))[zp_sh_rd + i];
}
} else if constexpr (group_blocks >= thread_k_blocks) {
int4* sh_zp_stage =
sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks)));
for (int i = 0; i < num_ints_per_thread; i++) {
frag_qzp[k % 2][i] =
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
}
} else {
int warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4;
int warp_row = warp_id / n_warps;
int cur_k = warp_row * 16;
cur_k += k_iter_size * (k % b_sh_wr_iters);
int k_blocks = cur_k / 16;
int cur_group_id = 0;
// Suppress bogus and persistent divide-by-zero warning
#pragma nv_diagnostic push
#pragma nv_diag_suppress divide_by_zero
cur_group_id = k_blocks / group_blocks;
#pragma nv_diagnostic pop
int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
sh_zp_stage += cur_group_id * zp_sh_stride;
for (int i = 0; i < num_ints_per_thread; i++) {
frag_qzp[k % 2][i] =
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
}
}
}
};
// Execute the actual tensor core matmul of a sub-tile.
auto matmul = [&](int k) {
if constexpr (has_zp) {
FragB frag_zp_0;
FragB frag_zp_1;
int zp_quant_0, zp_quant_1;
if constexpr (w_type.size_bits() == 4) {
zp_quant_0 = frag_qzp[k % 2][0];
zp_quant_1 = zp_quant_0 >> 8;
} else {
static_assert(w_type.size_bits() == 8);
zp_quant_0 = frag_qzp[k % 2][0];
zp_quant_1 = frag_qzp[k % 2][1];
}
frag_zp_0 = dequant<w_type_id>(zp_quant_0);
frag_zp_1 = dequant<w_type_id>(zp_quant_1);
frag_zp[0] = frag_zp_0[0];
frag_zp[1] = frag_zp_0[1];
frag_zp[2] = frag_zp_1[0];
frag_zp[3] = frag_zp_1[1];
}
// We have the m dimension as the inner loop in order to encourage overlapping
// dequantization and matmul operations.
#pragma unroll
......@@ -818,6 +984,10 @@ __device__ inline void MarlinMoESingle(
FragB frag_b0 = dequant<w_type_id>(b_quant_0);
FragB frag_b1 = dequant<w_type_id>(b_quant_1);
// Apply zero-point to frag_b0
if constexpr (has_zp) {
sub_zp(frag_b0, frag_zp[j], 0);
}
// Apply scale to frag_b0
if constexpr (has_act_order) {
......@@ -829,6 +999,11 @@ __device__ inline void MarlinMoESingle(
}
}
// Apply zero-point to frag_b1
if constexpr (has_zp) {
sub_zp(frag_b1, frag_zp[j], 1);
}
// Apply scale to frag_b1
if constexpr (has_act_order) {
scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j],
......@@ -1062,9 +1237,6 @@ __device__ inline void MarlinMoESingle(
// Start global fetch and register load pipelines.
auto start_pipes = [&]() {
// TODO re-enable after fixing this function
// fetch_sorted_ids_to_shared();
// __syncthreads();
#pragma unroll
for (int i = 0; i < stages - 1; i++) {
......@@ -1075,6 +1247,12 @@ __device__ inline void MarlinMoESingle(
}
fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]);
}
if constexpr (has_zp && group_blocks == -1) {
if (i == 0) {
fetch_zp_to_shared();
}
}
fetch_to_shared(i, i, i < slice_iters);
}
......@@ -1083,6 +1261,7 @@ __device__ inline void MarlinMoESingle(
init_same_group(0);
fetch_to_registers(0, 0);
fetch_scales_to_registers(0, 0);
fetch_zp_to_registers(0, 0);
a_gl_rd += a_gl_rd_delta_o * (stages - 1);
slice_k_start_shared_fetch += tb_k * (stages - 1);
};
......@@ -1102,6 +1281,7 @@ __device__ inline void MarlinMoESingle(
for (int k = 0; k < b_sh_wr_iters; k++) {
fetch_to_registers(k + 1, pipe % stages);
fetch_scales_to_registers(k + 1, pipe);
fetch_zp_to_registers(k + 1, pipe);
if (k == b_sh_wr_iters - 2) {
fetch_to_shared((pipe + stages - 1) % stages, pipe,
slice_iters >= stages);
......@@ -1236,7 +1416,9 @@ __device__ inline void MarlinMoESingle(
} else {
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
}
start_pipes();
}
}
......@@ -1250,6 +1432,7 @@ template <const vllm::ScalarTypeId w_type_id, // weight ScalarType id
const int stages, // number of stages for the async global->shared
// fetch pipeline
const bool has_act_order, // whether act_order is enabled
const bool has_zp, // whether zero-points are enabled
const int group_blocks = -1 // number of consecutive 16x16 blocks
// with a separate quantization scale
>
......@@ -1261,6 +1444,8 @@ __global__ void MarlinMoE(
const float* __restrict__ topk_weights, // float topk weights
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
// (k/groupsize)xn
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
// (k/groupsize)x(n/pack_factor)
const int* __restrict__ g_idx, // int32 group indices of shape k
const int* __restrict__ expert_offsets,
int num_groups, // number of scale groups per output channel
......@@ -1309,29 +1494,29 @@ __global__ void MarlinMoE(
if (max_block == 1) {
MarlinMoESingle<w_type_id, threads, 1, thread_n_blocks, thread_k_blocks,
stages, has_act_order, group_blocks>(
A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx,
stages, has_act_order, has_zp, group_blocks>(
A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx,
expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
current_m_block);
} else if (max_block == 2) {
MarlinMoESingle<w_type_id, threads, 2, thread_n_blocks, thread_k_blocks,
stages, has_act_order, group_blocks>(
A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx,
stages, has_act_order, has_zp, group_blocks>(
A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx,
expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
current_m_block);
} else if (max_block == 3) {
MarlinMoESingle<w_type_id, threads, 3, thread_n_blocks, thread_k_blocks,
stages, has_act_order, group_blocks>(
A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx,
stages, has_act_order, has_zp, group_blocks>(
A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx,
expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
current_m_block);
} else {
MarlinMoESingle<w_type_id, threads, 4, thread_n_blocks, thread_k_blocks,
stages, has_act_order, group_blocks>(
A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx,
stages, has_act_order, has_zp, group_blocks>(
A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx,
expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m,
prob_n, prob_k, tot_m, locks, replicate_input, apply_weights,
current_m_block);
......@@ -1347,6 +1532,7 @@ template <const vllm::ScalarTypeId w_type_id, // weight ScalarType id
const int stages, // number of stages for the async global->shared
// fetch pipeline
const bool has_act_order, // whether act_order is enabled
const bool has_zp, // whether zero-points are enabled
const int group_blocks = -1 // number of consecutive 16x16 blocks
// with a separate quantization scale
>
......@@ -1358,6 +1544,8 @@ __global__ void MarlinMoE(
const float* __restrict__ topk_weights, // float topk weights
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
// (k/groupsize)xn
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
// (k/groupsize)x(n/pack_factor)
const int* __restrict__ g_idx, // int32 group indices of shape k
const int* __restrict__ expert_offsets,
int num_groups, // number of scale groups per output channel
......@@ -1374,7 +1562,6 @@ __global__ void MarlinMoE(
int current_m_block, // current m block to start kernel computation from
int max_par, // maximum parallelism
int cfg_max_m_blocks // upper bound on m blocks
) {
// Marlin is not implemented yet for SM < 8.0
assert(false);
......@@ -1389,37 +1576,41 @@ __global__ void MarlinMoE(
const int USER_THREADS =
256; // Note: This is only used with user-provided thread_k/n
const int STAGES = 4; // 4 pipeline stages fit into shared memory
// const int SHARED_MEM =
// 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0)
static constexpr int min_thread_n = 64;
static constexpr int min_thread_k = 64;
#define __CALL_IF_MOE(W_TYPE, THREAD_N_BLOCKS, THREAD_K_BLOCKS, HAS_ACT_ORDER, \
GROUP_BLOCKS, NUM_THREADS) \
HAS_ZP, GROUP_BLOCKS, NUM_THREADS) \
else if (q_type == W_TYPE && thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \
num_threads == NUM_THREADS) { \
has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \
cudaFuncSetAttribute( \
MarlinMoE<W_TYPE.id(), NUM_THREADS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
STAGES, HAS_ACT_ORDER, GROUP_BLOCKS>, \
STAGES, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
MarlinMoE<W_TYPE.id(), NUM_THREADS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
STAGES, HAS_ACT_ORDER, GROUP_BLOCKS> \
STAGES, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS> \
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \
g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \
zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \
num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \
replicate_input, apply_weights, m_block, max_par, \
cfg_max_m_blocks); \
}
#define GPTQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)
#define GPTQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS)
#define AWQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
__CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)
} // namespace marlin_moe
#include "marlin_moe_kernel_ku4.h"
namespace marlin_moe {
// We return bool so we can create these different kernel calls as a sequence
// of if-elseif's.
bool call_marlin_moe_kernel_ku4(
vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
int m_block, int max_par, int cfg_max_m_blocks) {
bool has_zp = true;
if (false) {
}
AWQ_CALL_IF_MOE(vllm::kU4, 16, 4, 256)
AWQ_CALL_IF_MOE(vllm::kU4, 8, 8, 256)
AWQ_CALL_IF_MOE(vllm::kU4, 8, 4, 128)
AWQ_CALL_IF_MOE(vllm::kU4, 4, 8, 128)
else {
return false;
}
return true;
}
} // namespace marlin_moe
#pragma once
#include "marlin_moe_kernel.h"
namespace marlin_moe {
// We return bool so we can create these different kernel calls as a sequence
// of if-elseif's.
bool call_marlin_moe_kernel_ku4(
vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks,
bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
int m_block, int max_par, int cfg_max_m_blocks);
} // namespace marlin_moe
......@@ -9,11 +9,13 @@ bool call_marlin_moe_kernel_ku4b8(
bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int* g_idx_ptr,
int* expert_offsets_ptr, int num_groups, int expert_idx, int num_experts,
int topk, int prob_m, int prob_n, int prob_k, int tot_m, int* locks,
bool replicate_input, bool apply_weights, int m_block, int max_par,
int cfg_max_m_blocks) {
const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
int m_block, int max_par, int cfg_max_m_blocks) {
bool has_zp = false;
if (false) {
}
GPTQ_CALL_IF_MOE(vllm::kU4B8, 16, 4, 256)
......
......@@ -11,10 +11,10 @@ bool call_marlin_moe_kernel_ku4b8(
bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int* g_idx_ptr,
int* expert_offsets_ptr, int num_groups, int expert_idx, int num_experts,
int topk, int prob_m, int prob_n, int prob_k, int tot_m, int* locks,
bool replicate_input, bool apply_weights, int m_block, int max_par,
int cfg_max_m_blocks);
const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
int m_block, int max_par, int cfg_max_m_blocks);
} // namespace marlin_moe
......@@ -9,11 +9,13 @@ bool call_marlin_moe_kernel_ku8b128(
bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int* g_idx_ptr,
int* expert_offsets_ptr, int num_groups, int expert_idx, int num_experts,
int topk, int prob_m, int prob_n, int prob_k, int tot_m, int* locks,
bool replicate_input, bool apply_weights, int m_block, int max_par,
int cfg_max_m_blocks) {
const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
int m_block, int max_par, int cfg_max_m_blocks) {
bool has_zp = false;
if (false) {
}
GPTQ_CALL_IF_MOE(vllm::kU8B128, 16, 4, 256)
......
......@@ -9,10 +9,10 @@ bool call_marlin_moe_kernel_ku8b128(
bool has_act_order, int group_blocks, int num_threads, int blocks,
int max_shared_mem, cudaStream_t stream, const int4* A_ptr,
const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr,
const float* topk_weights_ptr, const int4* s_ptr, const int* g_idx_ptr,
int* expert_offsets_ptr, int num_groups, int expert_idx, int num_experts,
int topk, int prob_m, int prob_n, int prob_k, int tot_m, int* locks,
bool replicate_input, bool apply_weights, int m_block, int max_par,
int cfg_max_m_blocks);
const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr,
const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups,
int expert_idx, int num_experts, int topk, int prob_m, int prob_n,
int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights,
int m_block, int max_par, int cfg_max_m_blocks);
}
......@@ -25,9 +25,12 @@
#include <iostream>
#include "core/exception.hpp"
#include "core/scalar_type.hpp"
#include "core/registration.h"
#include "marlin_kernels/marlin_moe_kernel_ku4b8.h"
#include "marlin_kernels/marlin_moe_kernel_ku8b128.h"
#include "marlin_kernels/marlin_moe_kernel_ku4.h"
template <typename T>
inline std::string str(T x) {
......@@ -155,6 +158,7 @@ thread_config_t small_batch_thread_configs[] = {
{128, 64, 128}, // Reduce N 2X, same K
{64, 256, 256}, // Reduce K 2X, increase N 2X
{64, 128, 128}, // Reduce K 2X, same N
{64, 64, 128}, // Reduce both 2X
};
thread_config_t large_batch_thread_configs[] = {
......@@ -165,6 +169,7 @@ thread_config_t large_batch_thread_configs[] = {
{128, 128, 256}, // Reduce N 2X, increase K 2X
{64, 128, 128}, // Reduce N 2X, same K
{128, 64, 128}, // Reduce N 4X, increase K 2X
{64, 64, 128}, // Reduce N 4X, same K
};
int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
......@@ -189,7 +194,7 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
int load_groups =
tb_groups * STAGES * 2; // Chunk size is 2x pipeline over dim K
load_groups = max(load_groups, 32); // We load at least 32 scale groups
return load_groups * tb_n * 2;
return load_groups * tb_n * 4;
} else {
int tb_scales = tb_groups * tb_n * 2;
......@@ -310,27 +315,28 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
return exec_config_t{0, {-1, -1, -1}};
}
#define CALL_MOE_KERNEL_FUNCTION(KERNEL_FUNCTION) \
else if (KERNEL_FUNCTION(q_type, thread_n_blocks, thread_k_blocks, \
has_act_order, group_blocks, num_threads, blocks, \
max_shared_mem, stream, A_ptr, B_ptr, C_ptr, \
sorted_ids_ptr, topk_weights_ptr, s_ptr, g_idx_ptr, \
expert_offsets_ptr, num_groups, expert_idx, \
num_experts, topk, prob_m, prob_n, prob_k, tot_m, \
locks, replicate_input, apply_weights, m_block, \
max_par, exec_cfg.max_m_blocks)) { \
#define CALL_MOE_KERNEL_FUNCTION(KERNEL_FUNCTION) \
else if (KERNEL_FUNCTION( \
q_type, thread_n_blocks, thread_k_blocks, has_act_order, \
group_blocks, num_threads, blocks, max_shared_mem, stream, \
A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \
zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \
num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \
replicate_input, apply_weights, m_block, max_par, \
exec_cfg.max_m_blocks)) { \
}
void marlin_mm_moe(const void* A, const void* B, void* C,
const void* sorted_ids, const void* topk_weights,
const void* topk_ids, const void* s, const void* g_idx,
const void* perm, void* a_tmp, void* expert_offsets,
int prob_m, int prob_n, int prob_k, void* workspace,
vllm::ScalarType const& q_type, bool has_act_order,
bool is_k_full, int num_groups, int group_size,
int num_experts, int topk, int moe_block_size, int dev,
cudaStream_t stream, int thread_k, int thread_n, int sms,
int max_par, bool replicate_input, bool apply_weights) {
const void* topk_ids, const void* s, void* zp,
const void* g_idx, const void* perm, void* a_tmp,
void* expert_offsets, int prob_m, int prob_n, int prob_k,
void* workspace, vllm::ScalarType const& q_type,
bool has_act_order, bool is_k_full, bool has_zp,
int num_groups, int group_size, int num_experts, int topk,
int moe_block_size, int dev, cudaStream_t stream,
int thread_k, int thread_n, int sms, int max_par,
bool replicate_input, bool apply_weights) {
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
", ", prob_n, ", ", prob_k, "]");
......@@ -433,11 +439,9 @@ void marlin_mm_moe(const void* A, const void* B, void* C,
int4* C_ptr = (int4*)C;
const float* topk_weights_ptr = (const float*)topk_weights;
const int* sorted_ids_ptr = (const int*)sorted_ids;
const int4* s_ptr =
(const int4*)s +
(((group_size == -1 || group_size == 0) ? 1 : prob_k / group_size) *
prob_n / 8) *
expert_idx;
const int4* s_ptr = (const int4*)s + num_groups * prob_n / 8 * expert_idx;
const int4* zp_ptr =
(const int4*)zp + num_groups * prob_n / (pack_factor * 4) * expert_idx;
const int* g_idx_ptr = (const int*)g_idx + prob_k * expert_idx;
const int* perm_ptr = (const int*)perm + prob_k * expert_idx;
int* locks = (int*)workspace;
......@@ -458,6 +462,7 @@ void marlin_mm_moe(const void* A, const void* B, void* C,
}
CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4b8)
CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku8b128)
CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4)
else {
TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " +
str(prob_n) + ", " + str(prob_k) + "]" +
......@@ -477,13 +482,21 @@ torch::Tensor marlin_gemm_moe(
const torch::Tensor& a, const torch::Tensor& b_q_weights,
const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights,
const torch::Tensor& topk_ids, const torch::Tensor& b_scales,
const torch::Tensor& g_idx, const torch::Tensor& perm,
torch::Tensor& workspace, vllm::ScalarTypeTorchPtr const& b_q_type,
int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full,
int64_t num_experts, int64_t topk, int64_t moe_block_size,
bool replicate_input, bool apply_weights) {
TORCH_CHECK(*b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128,
"b_q_type must be uint4b8 or uint8b128. Got = ", b_q_type->str());
torch::Tensor& b_zeros, const torch::Tensor& g_idx,
const torch::Tensor& perm, torch::Tensor& workspace,
vllm::ScalarTypeTorchPtr const& b_q_type, int64_t size_m, int64_t size_n,
int64_t size_k, bool is_k_full, int64_t num_experts, int64_t topk,
int64_t moe_block_size, bool replicate_input, bool apply_weights) {
bool has_zp = b_zeros.size(1) != 0;
if (has_zp) {
TORCH_CHECK(
*b_q_type == vllm::kU4,
"b_q_type must be u4 when has_zp = True. Got = ", b_q_type->str());
} else {
TORCH_CHECK(
*b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128,
"b_q_type must be uint4b8 or uint8b128. Got = ", b_q_type->str());
}
int pack_factor = 32 / b_q_type->size_bits();
......@@ -521,6 +534,9 @@ torch::Tensor marlin_gemm_moe(
" is not size_n = ", size_n);
num_groups = b_scales.size(1);
TORCH_CHECK(VLLM_IMPLIES(!is_k_full, has_act_order),
"if is_k_full is false, has_act_order must be true");
if (has_act_order) {
if (is_k_full) {
TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1");
......@@ -542,13 +558,30 @@ torch::Tensor marlin_gemm_moe(
}
}
// Verify b_zeros
if (has_zp) {
int rank = b_zeros.sizes().size();
TORCH_CHECK(rank == 3, "b_zeros rank = ", rank, " is not 3");
TORCH_CHECK(b_zeros.size(1) == num_groups,
"b_zeros dim 1 = ", b_zeros.size(1),
" is not num_groups = ", num_groups);
TORCH_CHECK(b_zeros.size(2) == size_n / pack_factor,
"b_zeros dim 2 = ", b_zeros.size(2),
" is not size_n / pack_factor = ", size_n / pack_factor);
}
marlin_moe::marlin_mm_moe(
a.data_ptr(), b_q_weights.data_ptr(), c.data_ptr(), sorted_ids.data_ptr(),
topk_weights.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(),
g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(),
b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(),
expert_offsets.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(),
*b_q_type, has_act_order, is_k_full, num_groups, group_size, num_experts,
topk, moe_block_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k,
thread_n, sms, max_par, replicate_input, apply_weights);
*b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size,
num_experts, topk, moe_block_size, dev,
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par,
replicate_input, apply_weights);
return c;
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("marlin_gemm_moe", &marlin_gemm_moe);
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment