Commit 3fb4b5fa authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.18.0' into v0.18.0-ori

parents bcf25339 89138b21
......@@ -126,7 +126,7 @@ def benchmark_decode(
)
def time_fn(fn, warmup=10, trials=20):
torch.cuda.synchronize()
torch.accelerator.synchronize()
start = torch.Event(enable_timing=True)
end = torch.Event(enable_timing=True)
times = []
......@@ -136,7 +136,7 @@ def benchmark_decode(
start.record()
fn()
end.record()
torch.cuda.synchronize()
torch.accelerator.synchronize()
times.append(start.elapsed_time(end)) # ms
return sum(times) / len(times), torch.std(torch.tensor(times))
......
......@@ -138,7 +138,7 @@ def benchmark_prefill(
)
def time_fn(fn, warmup=10, trials=20):
torch.cuda.synchronize()
torch.accelerator.synchronize()
start = torch.Event(enable_timing=True)
end = torch.Event(enable_timing=True)
times = []
......@@ -148,7 +148,7 @@ def benchmark_prefill(
start.record()
fn()
end.record()
torch.cuda.synchronize()
torch.accelerator.synchronize()
times.append(start.elapsed_time(end)) # ms
return sum(times) / len(times), torch.std(torch.tensor(times))
......
......@@ -177,18 +177,18 @@ def benchmark_config(
def run():
w8a8_block_matmul(A, B, As, Bs, block_size, config, out_dtype)
torch.cuda.synchronize()
torch.accelerator.synchronize()
# JIT complication & warmup
for _ in range(5):
run()
torch.cuda.synchronize()
torch.accelerator.synchronize()
start_event = torch.Event(enable_timing=True)
end_event = torch.Event(enable_timing=True)
latencies: list[float] = []
for i in range(num_iters):
torch.cuda.synchronize()
torch.accelerator.synchronize()
start_event.record()
run()
end_event.record()
......@@ -285,7 +285,7 @@ def tune_on_gpu(args_dict):
weight_shapes = args_dict["weight_shapes"]
args = args_dict["args"]
torch.cuda.set_device(gpu_id)
torch.accelerator.set_device_index(gpu_id)
print(f"Starting tuning on GPU {gpu_id} with batch sizes {batch_sizes}")
block_n = args.block_n
......@@ -334,7 +334,7 @@ def distribute_batch_sizes(batch_sizes, num_gpus):
def main(args):
print(args)
num_gpus = torch.cuda.device_count()
num_gpus = torch.accelerator.device_count()
if num_gpus == 0:
raise RuntimeError("No GPU available for tuning")
print(f"Found {num_gpus} GPUs for parallel tuning")
......
......@@ -35,7 +35,7 @@ def benchmark_shape(
B = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)
# Reference result in BF16
torch.cuda.synchronize()
torch.accelerator.synchronize()
C_ref = A @ B.t()
# Pre-quantize B for all implementations
......@@ -121,14 +121,14 @@ def benchmark_shape(
# Warmup
for _ in range(warmup):
func()
torch.cuda.synchronize()
torch.accelerator.synchronize()
# Timing loop
torch.cuda.synchronize()
torch.accelerator.synchronize()
start = time.time()
for _ in range(repeat):
func()
torch.cuda.synchronize()
torch.accelerator.synchronize()
end = time.time()
# Calculate timing and TFLOPS
......
......@@ -7,7 +7,7 @@ First start serving your model
```bash
export MODEL_PATH=/models/meta-llama/Meta-Llama-3.1-8B-Instruct/
vllm serve $MODEL_PATH --served-model-name Llama --disable-log-requests
vllm serve $MODEL_PATH --served-model-name Llama
```
The variable `MODEL_PATH` should be a path to the model files (e.g. downloaded from huggingface).
......
......@@ -71,7 +71,7 @@ while [[ $# -gt 0 ]]; do
usage
;;
*)
echo "Unknown argument: $1\n"
printf "Unknown argument: %s\n" "$1"
usage
;;
esac
......@@ -84,15 +84,17 @@ mkdir -p "$OUTPUT_DIR"
QPS_VALUES=(25 20 15 10 5 1)
# Common parameters
COMMON_PARAMS="--backend $BACKEND \
--model $MODEL \
--dataset $DATASET \
--structured-output-ratio $STRUCTURED_OUTPUT_RATIO \
--save-results \
--result-dir $OUTPUT_DIR \
--output-len $MAX_NEW_TOKENS \
--port $PORT \
--tokenizer-mode $TOKENIZER_MODE"
COMMON_PARAMS=(
--backend "$BACKEND"
--model "$MODEL"
--dataset "$DATASET"
--structured-output-ratio "$STRUCTURED_OUTPUT_RATIO"
--save-results
--result-dir "$OUTPUT_DIR"
--output-len "$MAX_NEW_TOKENS"
--port "$PORT"
--tokenizer-mode "$TOKENIZER_MODE"
)
echo "Starting structured output benchmark with model: $MODEL"
echo "Backend: $BACKEND"
......@@ -109,17 +111,17 @@ for qps in "${QPS_VALUES[@]}"; do
GIT_BRANCH=$(git rev-parse --abbrev-ref HEAD 2>/dev/null || echo "unknown")
# Construct filename for this run
FILENAME="${BACKEND}_${qps}qps_$(basename $MODEL)_${DATASET}_${GIT_HASH}.json"
FILENAME="${BACKEND}_${qps}qps_$(basename "$MODEL")_${DATASET}_${GIT_HASH}_${GIT_BRANCH}.json"
NUM_PROMPTS=$(echo "$TOTAL_SECONDS * $qps" | bc)
NUM_PROMPTS=${NUM_PROMPTS%.*} # Remove fractional part
echo "Running benchmark with $NUM_PROMPTS prompts"
# Run the benchmark
python "$SCRIPT_DIR/benchmark_serving_structured_output.py" $COMMON_PARAMS \
--request-rate $qps \
python "$SCRIPT_DIR/benchmark_serving_structured_output.py" "${COMMON_PARAMS[@]}" \
--request-rate "$qps" \
--result-filename "$FILENAME" \
--num-prompts $NUM_PROMPTS
--num-prompts "$NUM_PROMPTS"
echo "Completed benchmark with QPS: $qps"
echo "----------------------------------------"
......
......@@ -13,27 +13,16 @@ endif()
#
# Define environment variables for special configurations
#
set(ENABLE_AVX2 $ENV{VLLM_CPU_AVX2})
set(ENABLE_AVX512 $ENV{VLLM_CPU_AVX512})
set(ENABLE_AVX512BF16 $ENV{VLLM_CPU_AVX512BF16})
set(ENABLE_AVX512VNNI $ENV{VLLM_CPU_AVX512VNNI})
set(ENABLE_AMXBF16 $ENV{VLLM_CPU_AMXBF16})
set(ENABLE_X86_ISA $ENV{VLLM_CPU_X86})
set(ENABLE_ARM_BF16 $ENV{VLLM_CPU_ARM_BF16})
include_directories("${CMAKE_SOURCE_DIR}/csrc")
set (ENABLE_NUMA TRUE)
#
# Check the compile flags
#
if (CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64")
list(APPEND CXX_COMPILE_FLAGS
"-mf16c"
)
endif()
if(MACOSX_FOUND)
list(APPEND CXX_COMPILE_FLAGS
"-DVLLM_CPU_EXTENSION")
......@@ -77,18 +66,6 @@ function(check_sysctl TARGET OUT)
endif()
endfunction()
function (is_avx512_disabled OUT)
set(DISABLE_AVX512 $ENV{VLLM_CPU_DISABLE_AVX512})
if(DISABLE_AVX512 AND DISABLE_AVX512 STREQUAL "true")
set(${OUT} ON PARENT_SCOPE)
else()
set(${OUT} OFF PARENT_SCOPE)
endif()
endfunction()
is_avx512_disabled(AVX512_DISABLED)
if (MACOSX_FOUND AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
message(STATUS "Apple Silicon Detected")
set(APPLE_SILICON_FOUND TRUE)
......@@ -96,84 +73,44 @@ if (MACOSX_FOUND AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
check_sysctl(hw.optional.neon ASIMD_FOUND)
check_sysctl(hw.optional.arm.FEAT_BF16 ARM_BF16_FOUND)
else()
find_isa(${CPUINFO} "avx2" AVX2_FOUND)
find_isa(${CPUINFO} "avx512f" AVX512_FOUND)
find_isa(${CPUINFO} "Power11" POWER11_FOUND)
find_isa(${CPUINFO} "POWER10" POWER10_FOUND)
find_isa(${CPUINFO} "POWER9" POWER9_FOUND)
find_isa(${CPUINFO} "asimd" ASIMD_FOUND) # Check for ARM NEON support
find_isa(${CPUINFO} "bf16" ARM_BF16_FOUND) # Check for ARM BF16 support
find_isa(${CPUINFO} "S390" S390_FOUND)
find_isa(${CPUINFO} "v" RVV_FOUND) # Check for RISC-V RVV support
find_isa(${CPUINFO} "zvfhmin" RVV_FP16_FOUND) # Check for RISC-V Vector FP16 support
find_isa(${CPUINFO} "zvfbfmin" RVV_BF16_FOUND) # Check for RISC-V Vector BF16 support
# Support cross-compilation by allowing override via environment variables
if (ENABLE_AVX2)
set(AVX2_FOUND ON)
message(STATUS "AVX2 support enabled via VLLM_CPU_AVX2 environment variable")
endif()
if (ENABLE_AVX512)
set(AVX512_FOUND ON)
message(STATUS "AVX512 support enabled via VLLM_CPU_AVX512 environment variable")
if (ENABLE_ARM_BF16)
set(ARM_BF16_FOUND ON)
message(STATUS "ARM BF16 support enabled via VLLM_CPU_ARM_BF16 environment variable")
endif()
endif()
if (AVX512_FOUND AND NOT AVX512_DISABLED)
list(APPEND CXX_COMPILE_FLAGS
if (CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64|amd64" OR ENABLE_X86_ISA)
set(ENABLE_X86_ISA ON)
if (NOT (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3))
message(FATAL_ERROR "X86 backend requires gcc/g++ >= 12.3")
endif()
list(APPEND CXX_COMPILE_FLAGS "-mf16c")
list(APPEND CXX_COMPILE_FLAGS_AVX512 ${CXX_COMPILE_FLAGS})
list(APPEND CXX_COMPILE_FLAGS_AVX2 ${CXX_COMPILE_FLAGS})
list(APPEND CXX_COMPILE_FLAGS_AVX512
"-mavx512f"
"-mavx512vl"
"-mavx512bw"
"-mavx512dq")
find_isa(${CPUINFO} "avx512_bf16" AVX512BF16_FOUND)
if (AVX512BF16_FOUND OR ENABLE_AVX512BF16)
if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3)
list(APPEND CXX_COMPILE_FLAGS "-mavx512bf16")
set(ENABLE_AVX512BF16 ON)
else()
set(ENABLE_AVX512BF16 OFF)
message(WARNING "Disable AVX512-BF16 ISA support, requires gcc/g++ >= 12.3")
endif()
else()
set(ENABLE_AVX512BF16 OFF)
message(WARNING "Disable AVX512-BF16 ISA support, no avx512_bf16 found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512BF16=1.")
endif()
find_isa(${CPUINFO} "avx512_vnni" AVX512VNNI_FOUND)
if (AVX512VNNI_FOUND OR ENABLE_AVX512VNNI)
if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3)
list(APPEND CXX_COMPILE_FLAGS "-mavx512vnni")
set(ENABLE_AVX512VNNI ON)
else()
set(ENABLE_AVX512VNNI OFF)
message(WARNING "Disable AVX512-VNNI ISA support, requires gcc/g++ >= 12.3")
endif()
else()
set(ENABLE_AVX512VNNI OFF)
message(WARNING "Disable AVX512-VNNI ISA support, no avx512_vnni found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512VNNI=1.")
endif()
find_isa(${CPUINFO} "amx_bf16" AMXBF16_FOUND)
if (AMXBF16_FOUND OR ENABLE_AMXBF16)
if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3)
list(APPEND CXX_COMPILE_FLAGS "-mamx-bf16" "-mamx-tile")
set(ENABLE_AMXBF16 ON)
add_compile_definitions(-DCPU_CAPABILITY_AMXBF16)
else()
set(ENABLE_AMXBF16 OFF)
message(WARNING "Disable AMX_BF16 ISA support, requires gcc/g++ >= 12.3")
endif()
else()
set(ENABLE_AMXBF16 OFF)
message(WARNING "Disable AMX_BF16 ISA support, no amx_bf16 found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AMXBF16=1.")
endif()
elseif (AVX2_FOUND)
list(APPEND CXX_COMPILE_FLAGS "-mavx2")
message(WARNING "vLLM CPU backend using AVX2 ISA")
list(APPEND CXX_COMPILE_FLAGS_AVX512_AMX
${CXX_COMPILE_FLAGS_AVX512}
"-mamx-bf16"
"-mamx-tile"
"-mavx512bf16"
"-mavx512vnni")
list(APPEND CXX_COMPILE_FLAGS_AVX2
"-mavx2")
elseif (POWER9_FOUND OR POWER10_FOUND OR POWER11_FOUND)
message(STATUS "PowerPC detected")
if (POWER9_FOUND)
......@@ -208,18 +145,26 @@ elseif (S390_FOUND)
"-march=native"
"-mtune=native")
elseif (CMAKE_SYSTEM_PROCESSOR MATCHES "riscv64")
if(RVV_FOUND)
message(FAIL_ERROR "Can't support rvv now.")
message(STATUS "RISC-V detected")
if(RVV_BF16_FOUND)
message(STATUS "BF16 extension detected")
set(MARCH_FLAGS -march=rv64gcv_zvfh_zfbfmin_zvfbfmin_zvl128b -mrvv-vector-bits=zvl -mabi=lp64d)
add_compile_definitions(RISCV_BF16_SUPPORT)
elseif (RVV_FP16_FOUND)
message(WARNING "BF16 functionality is not available")
set(MARCH_FLAGS -march=rv64gcv_zvfh_zvl128b -mrvv-vector-bits=zvl -mabi=lp64d)
else()
message(STATUS "compile riscv with scalar")
list(APPEND CXX_COMPILE_FLAGS "-march=rv64gc")
endif()
list(APPEND CXX_COMPILE_FLAGS ${MARCH_FLAGS})
else()
message(FATAL_ERROR "vLLM CPU backend requires AVX512, AVX2, Power9+ ISA, S390X ISA, ARMv8 or RISC-V support.")
message(FATAL_ERROR "vLLM CPU backend requires X86, Power9+ ISA, S390X ISA, ARMv8 or RISC-V support.")
endif()
# Build oneDNN for GEMM kernels (only for x86-AVX512 /ARM platforms)
if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON_FOUND) OR POWER9_FOUND OR POWER10_FOUND OR POWER11_FOUND)
# Build oneDNN for GEMM kernels
if (ENABLE_X86_ISA OR (ASIMD_FOUND AND NOT APPLE_SILICON_FOUND) OR POWER9_FOUND OR POWER10_FOUND OR POWER11_FOUND)
# Fetch and build Arm Compute Library (ACL) as oneDNN's backend for AArch64
# TODO [fadara01]: remove this once ACL can be fetched and built automatically as a dependency of oneDNN
set(ONEDNN_AARCH64_USE_ACL OFF CACHE BOOL "")
......@@ -308,13 +253,24 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON
)
else()
message(STATUS "Downloading oneDNN from GitHub")
FetchContent_Declare(
oneDNN
GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git
GIT_TAG v3.10
GIT_PROGRESS TRUE
GIT_SHALLOW TRUE
)
if(ASIMD_FOUND AND NOT APPLE_SILICON_FOUND)
message(STATUS "aarch64 detected: using pinned oneDNN commit 9c5be1cc59e368aebf0909e6cf20f981ea61462a")
FetchContent_Declare(
oneDNN
GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git
GIT_TAG 9c5be1cc59e368aebf0909e6cf20f981ea61462a
GIT_PROGRESS TRUE
GIT_SHALLOW FALSE
)
else()
FetchContent_Declare(
oneDNN
GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git
GIT_TAG v3.10
GIT_PROGRESS TRUE
GIT_SHALLOW TRUE
)
endif()
endif()
set(ONEDNN_LIBRARY_TYPE "STATIC")
......@@ -324,13 +280,21 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON
set(ONEDNN_ENABLE_WORKLOAD "INFERENCE")
set(ONEDNN_ENABLE_PRIMITIVE "MATMUL;REORDER")
set(ONEDNN_BUILD_GRAPH "OFF")
set(ONEDNN_ENABLE_JIT_PROFILING "OFF")
set(ONEDNN_ENABLE_JIT_PROFILING "ON")
set(ONEDNN_ENABLE_ITT_TASKS "OFF")
set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF")
set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF")
set(ONEDNN_VERBOSE "OFF")
set(ONEDNN_ENABLE_MAX_CPU_ISA "ON")
set(ONEDNN_ENABLE_CPU_ISA_HINTS "ON")
set(ONEDNN_VERBOSE "ON")
set(CMAKE_POLICY_DEFAULT_CMP0077 NEW)
# TODO: Refactor this
if (ENABLE_X86_ISA)
# Note: only enable oneDNN for AVX512
list(APPEND DNNL_COMPILE_FLAGS ${CXX_COMPILE_FLAGS_AVX512})
else()
list(APPEND DNNL_COMPILE_FLAGS ${CXX_COMPILE_FLAGS})
endif()
set(VLLM_BUILD_TYPE ${CMAKE_BUILD_TYPE})
set(CMAKE_BUILD_TYPE "Release") # remove oneDNN debug symbols to reduce size
FetchContent_MakeAvailable(oneDNN)
......@@ -343,14 +307,21 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON
PRIVATE ${oneDNN_SOURCE_DIR}/src
)
target_link_libraries(dnnl_ext dnnl torch)
target_compile_options(dnnl_ext PRIVATE ${CXX_COMPILE_FLAGS} -fPIC)
target_compile_options(dnnl_ext PRIVATE ${DNNL_COMPILE_FLAGS} -fPIC)
list(APPEND LIBS dnnl_ext)
set(USE_ONEDNN ON)
else()
set(USE_ONEDNN OFF)
endif()
message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}")
# TODO: Refactor this
if (ENABLE_X86_ISA)
message(STATUS "CPU extension (AVX512F + BF16 + VNNI + AMX) compile flags: ${CXX_COMPILE_FLAGS_AVX512_AMX}")
message(STATUS "CPU extension (AVX512F) compile flags: ${CXX_COMPILE_FLAGS_AVX512}")
message(STATUS "CPU extension (AVX2) compile flags: ${CXX_COMPILE_FLAGS_AVX2}")
else()
message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}")
endif()
if(ENABLE_NUMA)
list(APPEND LIBS numa)
......@@ -385,25 +356,6 @@ set(VLLM_EXT_SRC
"csrc/cpu/cpu_attn.cpp"
"csrc/cpu/torch_bindings.cpp")
if (AVX512_FOUND AND NOT AVX512_DISABLED)
set(VLLM_EXT_SRC
"csrc/cpu/shm.cpp"
"csrc/cpu/cpu_wna16.cpp"
"csrc/cpu/cpu_fused_moe.cpp"
${VLLM_EXT_SRC})
if (ENABLE_AVX512BF16 AND ENABLE_AVX512VNNI)
set(VLLM_EXT_SRC
"csrc/cpu/sgl-kernels/gemm.cpp"
"csrc/cpu/sgl-kernels/gemm_int8.cpp"
"csrc/cpu/sgl-kernels/gemm_fp8.cpp"
"csrc/cpu/sgl-kernels/moe.cpp"
"csrc/cpu/sgl-kernels/moe_int8.cpp"
"csrc/cpu/sgl-kernels/moe_fp8.cpp"
${VLLM_EXT_SRC})
add_compile_definitions(-DCPU_CAPABILITY_AVX512)
endif()
endif()
if (ASIMD_FOUND AND NOT APPLE_SILICON_FOUND)
set(VLLM_EXT_SRC
"csrc/cpu/shm.cpp"
......@@ -416,21 +368,102 @@ if(USE_ONEDNN)
${VLLM_EXT_SRC})
endif()
message(STATUS "CPU extension source files: ${VLLM_EXT_SRC}")
if (ENABLE_X86_ISA)
set(VLLM_EXT_SRC_SGL
"csrc/cpu/sgl-kernels/gemm.cpp"
"csrc/cpu/sgl-kernels/gemm_int8.cpp"
"csrc/cpu/sgl-kernels/gemm_fp8.cpp"
"csrc/cpu/sgl-kernels/moe.cpp"
"csrc/cpu/sgl-kernels/moe_int8.cpp"
"csrc/cpu/sgl-kernels/moe_fp8.cpp")
#
# Define extension targets
#
set(VLLM_EXT_SRC_AVX512
"csrc/cpu/shm.cpp"
"csrc/cpu/cpu_wna16.cpp"
"csrc/cpu/cpu_fused_moe.cpp"
"csrc/cpu/utils.cpp"
"csrc/cpu/cpu_attn.cpp"
"csrc/cpu/dnnl_kernels.cpp"
"csrc/cpu/torch_bindings.cpp"
# TODO: Remove these files
"csrc/cpu/activation.cpp"
"csrc/cpu/layernorm.cpp"
"csrc/cpu/mla_decode.cpp"
"csrc/cpu/pos_encoding.cpp"
"csrc/moe/dynamic_4bit_int_moe_cpu.cpp")
set(VLLM_EXT_SRC_AVX2
"csrc/cpu/utils.cpp"
"csrc/cpu/cpu_attn.cpp"
"csrc/cpu/torch_bindings.cpp"
# TODO: Remove these files
"csrc/cpu/activation.cpp"
"csrc/cpu/layernorm.cpp"
"csrc/cpu/mla_decode.cpp"
"csrc/cpu/pos_encoding.cpp"
"csrc/moe/dynamic_4bit_int_moe_cpu.cpp")
message(STATUS "CPU extension (AVX512F + BF16 + VNNI + AMX) source files: ${VLLM_EXT_SRC_AVX512} ${VLLM_EXT_SRC_SGL}")
message(STATUS "CPU extension (AVX512F) source files: ${VLLM_EXT_SRC_AVX512}")
message(STATUS "CPU extension (AVX2) source files: ${VLLM_EXT_SRC_AVX2}")
set(_C_LIBS numa dnnl_ext)
set(_C_AVX512_LIBS numa dnnl_ext)
set(_C_AVX2_LIBS numa)
# AMX + AVX512F + AVX512BF16 + AVX512VNNI
define_extension_target(
_C
DESTINATION vllm
LANGUAGE CXX
SOURCES ${VLLM_EXT_SRC_AVX512} ${VLLM_EXT_SRC_SGL}
LIBRARIES ${_C_LIBS}
COMPILE_FLAGS ${CXX_COMPILE_FLAGS_AVX512_AMX}
USE_SABI 3
WITH_SOABI
)
define_extension_target(
_C
DESTINATION vllm
LANGUAGE CXX
SOURCES ${VLLM_EXT_SRC}
LIBRARIES ${LIBS}
COMPILE_FLAGS ${CXX_COMPILE_FLAGS}
USE_SABI 3
WITH_SOABI
)
# For AMX kernels
target_compile_definitions(_C PRIVATE "-DCPU_CAPABILITY_AMXBF16")
# AVX512F
define_extension_target(
_C_AVX512
DESTINATION vllm
LANGUAGE CXX
SOURCES ${VLLM_EXT_SRC_AVX512}
LIBRARIES ${_C_AVX512_LIBS}
COMPILE_FLAGS ${CXX_COMPILE_FLAGS_AVX512}
USE_SABI 3
WITH_SOABI
)
# AVX2
define_extension_target(
_C_AVX2
DESTINATION vllm
LANGUAGE CXX
SOURCES ${VLLM_EXT_SRC_AVX2}
LIBRARIES ${_C_AVX2_LIBS}
COMPILE_FLAGS ${CXX_COMPILE_FLAGS_AVX2}
USE_SABI 3
WITH_SOABI
)
else()
message(STATUS "CPU extension source files: ${VLLM_EXT_SRC}")
#
# Define extension targets
#
define_extension_target(
_C
DESTINATION vllm
LANGUAGE CXX
SOURCES ${VLLM_EXT_SRC}
LIBRARIES ${LIBS}
COMPILE_FLAGS ${CXX_COMPILE_FLAGS}
USE_SABI 3
WITH_SOABI
)
endif()
message(STATUS "Enabling C extension.")
......@@ -19,7 +19,7 @@ else()
FetchContent_Declare(
flashmla
GIT_REPOSITORY https://github.com/vllm-project/FlashMLA
GIT_TAG c2afa9cb93e674d5a9120a170a6da57b89267208
GIT_TAG 692917b1cda61b93ac9ee2d846ec54e75afe87b1
GIT_PROGRESS TRUE
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
......
......@@ -17,7 +17,8 @@ endif()
# They should be identical but if they aren't, this is a massive footgun.
#
# The vllm-flash-attn install rules are nested under vllm to make sure the library gets installed in the correct place.
# To only install vllm-flash-attn, use --component _vllm_fa2_C (for FA2) or --component _vllm_fa3_C (for FA3).
# To only install vllm-flash-attn, use --component _vllm_fa2_C (for FA2), --component _vllm_fa3_C (for FA3),
# or --component _vllm_fa4_cutedsl_C (for FA4 CuteDSL Python files).
# If no component is specified, vllm-flash-attn is still installed.
# If VLLM_FLASH_ATTN_SRC_DIR is set, vllm-flash-attn is installed from that directory instead of downloading.
......@@ -38,22 +39,16 @@ else()
FetchContent_Declare(
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG 188be16520ceefdc625fdf71365585d2ee348fe2
GIT_TAG 1488682bb545f7d020e958a33116b1419d1cfc83
GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
)
endif()
# Ensure the vllm/vllm_flash_attn directory exists before installation
install(CODE "file(MAKE_DIRECTORY \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn\")" ALL_COMPONENTS)
# Make sure vllm-flash-attn install rules are nested under vllm/
# This is here to support installing all components under the same prefix with cmake --install.
# setup.py installs every component separately but uses the same prefix for all.
# ALL_COMPONENTS is used to avoid duplication for FA2 and FA3,
# and these statements don't hurt when installing neither component.
# ALL_COMPONENTS ensures the save/modify/restore runs exactly once regardless
# of how many components are being installed, avoiding double-append of /vllm/.
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY FALSE)" ALL_COMPONENTS)
install(CODE "set(OLD_CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}\")" ALL_COMPONENTS)
install(CODE "set(CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}/vllm/\")" ALL_COMPONENTS)
......@@ -62,22 +57,48 @@ install(CODE "set(CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}/vllm/\")" ALL_
FetchContent_MakeAvailable(vllm-flash-attn)
message(STATUS "vllm-flash-attn is available at ${vllm-flash-attn_SOURCE_DIR}")
# Restore the install prefix
# Restore the install prefix after FA's install rules
install(CODE "set(CMAKE_INSTALL_PREFIX \"\${OLD_CMAKE_INSTALL_PREFIX}\")" ALL_COMPONENTS)
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)
# Copy over the vllm-flash-attn python files (duplicated for fa2 and fa3, in
# case only one is built, in the case both are built redundant work is done)
install(
DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
DESTINATION vllm/vllm_flash_attn
COMPONENT _vllm_fa2_C
FILES_MATCHING PATTERN "*.py"
)
# Install shared Python files for both FA2 and FA3 components
foreach(_FA_COMPONENT _vllm_fa2_C _vllm_fa3_C)
# Ensure the vllm/vllm_flash_attn directory exists before installation
install(CODE "file(MAKE_DIRECTORY \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn\")"
COMPONENT ${_FA_COMPONENT})
# Copy vllm_flash_attn python files (except __init__.py and flash_attn_interface.py
# which are source-controlled in vllm)
install(
DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
DESTINATION vllm/vllm_flash_attn
COMPONENT ${_FA_COMPONENT}
FILES_MATCHING PATTERN "*.py"
PATTERN "__init__.py" EXCLUDE
PATTERN "flash_attn_interface.py" EXCLUDE
)
endforeach()
install(
DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
DESTINATION vllm/vllm_flash_attn
COMPONENT _vllm_fa3_C
FILES_MATCHING PATTERN "*.py"
)
#
# FA4 CuteDSL component
# This is a Python-only component that copies the flash_attn/cute directory
# and transforms imports to match our package structure.
#
add_custom_target(_vllm_fa4_cutedsl_C)
# Copy flash_attn/cute directory (needed for FA4) and transform imports
# The cute directory uses flash_attn.cute imports internally, which we replace
# with vllm.vllm_flash_attn.cute to match our package structure.
install(CODE "
file(GLOB_RECURSE CUTE_PY_FILES \"${vllm-flash-attn_SOURCE_DIR}/flash_attn/cute/*.py\")
foreach(SRC_FILE \${CUTE_PY_FILES})
file(RELATIVE_PATH REL_PATH \"${vllm-flash-attn_SOURCE_DIR}/flash_attn/cute\" \${SRC_FILE})
set(DST_FILE \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn/cute/\${REL_PATH}\")
get_filename_component(DST_DIR \${DST_FILE} DIRECTORY)
file(MAKE_DIRECTORY \${DST_DIR})
file(READ \${SRC_FILE} FILE_CONTENTS)
string(REPLACE \"flash_attn.cute\" \"vllm.vllm_flash_attn.cute\" FILE_CONTENTS \"\${FILE_CONTENTS}\")
file(WRITE \${DST_FILE} \"\${FILE_CONTENTS}\")
endforeach()
" COMPONENT _vllm_fa4_cutedsl_C)
......@@ -5,6 +5,7 @@
#include <cmath>
#include "cuda_compat.h"
#include "cuda_vec_utils.cuh"
#include "dispatch_utils.h"
namespace vllm {
......@@ -16,52 +17,55 @@ __device__ __forceinline__ scalar_t compute(const scalar_t& x,
return act_first ? ACT_FN(x) * y : x * ACT_FN(y);
}
// Check if all pointers are 16-byte aligned for int4 vectorized access
__device__ __forceinline__ bool is_16byte_aligned(const void* ptr) {
return (reinterpret_cast<uintptr_t>(ptr) & 15) == 0;
template <typename packed_t, packed_t (*PACKED_ACT_FN)(const packed_t&),
bool act_first>
__device__ __forceinline__ packed_t packed_compute(const packed_t& x,
const packed_t& y) {
return act_first ? packed_mul(PACKED_ACT_FN(x), y)
: packed_mul(x, PACKED_ACT_FN(y));
}
// Activation and gating kernel template.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&),
bool act_first>
template <typename scalar_t, typename packed_t,
scalar_t (*ACT_FN)(const scalar_t&),
packed_t (*PACKED_ACT_FN)(const packed_t&), bool act_first,
bool use_vec, bool use_256b = false>
__global__ void act_and_mul_kernel(
scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2, d]
const int d) {
constexpr int VEC_SIZE = 16 / sizeof(scalar_t);
const int64_t token_idx = blockIdx.x;
const scalar_t* x_ptr = input + token_idx * 2 * d;
const scalar_t* x_ptr = input + blockIdx.x * 2 * d;
const scalar_t* y_ptr = x_ptr + d;
scalar_t* out_ptr = out + token_idx * d;
scalar_t* out_ptr = out + blockIdx.x * d;
// Check alignment for 128-bit vectorized access.
// All three pointers must be 16-byte aligned for safe int4 operations.
const bool aligned = is_16byte_aligned(x_ptr) && is_16byte_aligned(y_ptr) &&
is_16byte_aligned(out_ptr);
if constexpr (use_vec) {
using cuda_t = typename CUDATypeConverter<scalar_t>::Type;
using pvec_t = PackedVec<cuda_t, use_256b>;
if (aligned && d >= VEC_SIZE) {
// Fast path: 128-bit vectorized loop
const int4* x_vec = reinterpret_cast<const int4*>(x_ptr);
const int4* y_vec = reinterpret_cast<const int4*>(y_ptr);
int4* out_vec = reinterpret_cast<int4*>(out_ptr);
const int num_vecs = d / VEC_SIZE;
const int vec_end = num_vecs * VEC_SIZE;
const pvec_t* x_vec = reinterpret_cast<const pvec_t*>(x_ptr);
const pvec_t* y_vec = reinterpret_cast<const pvec_t*>(y_ptr);
pvec_t* out_vec = reinterpret_cast<pvec_t*>(out_ptr);
const int num_vecs = d / 2 / pvec_t::NUM_ELTS;
for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) {
int4 x = VLLM_LDG(&x_vec[i]), y = VLLM_LDG(&y_vec[i]), r;
auto* xp = reinterpret_cast<scalar_t*>(&x);
auto* yp = reinterpret_cast<scalar_t*>(&y);
auto* rp = reinterpret_cast<scalar_t*>(&r);
pvec_t x, y;
if constexpr (use_256b) {
ld256(x, &x_vec[i]);
ld256(y, &y_vec[i]);
} else {
ld128(x, &x_vec[i]);
ld128(y, &y_vec[i]);
}
#pragma unroll
for (int j = 0; j < VEC_SIZE; j++) {
rp[j] = compute<scalar_t, ACT_FN, act_first>(xp[j], yp[j]);
for (int j = 0; j < pvec_t::NUM_ELTS; j++) {
x.elts[j] = packed_compute<packed_t, PACKED_ACT_FN, act_first>(
x.elts[j], y.elts[j]);
}
if constexpr (use_256b) {
st256(x, &out_vec[i]);
} else {
st128(x, &out_vec[i]);
}
out_vec[i] = r;
}
// Scalar cleanup for remaining elements
for (int i = vec_end + threadIdx.x; i < d; i += blockDim.x) {
out_ptr[i] = compute<scalar_t, ACT_FN, act_first>(VLLM_LDG(&x_ptr[i]),
VLLM_LDG(&y_ptr[i]));
}
} else {
// Scalar fallback for unaligned data or small d
......@@ -79,6 +83,15 @@ __device__ __forceinline__ T silu_kernel(const T& x) {
return (T)(((float)x) / (1.0f + expf((float)-x)));
}
template <typename packed_t>
__device__ __forceinline__ packed_t packed_silu_kernel(const packed_t& val) {
// x * sigmoid(x)
float2 fval = cast_to_float2(val);
fval.x = fval.x / (1.0f + expf(-fval.x));
fval.y = fval.y / (1.0f + expf(-fval.y));
return cast_to_packed<packed_t>(fval);
}
template <typename T>
__device__ __forceinline__ T gelu_kernel(const T& x) {
// Equivalent to PyTorch GELU with 'none' approximation.
......@@ -89,6 +102,18 @@ __device__ __forceinline__ T gelu_kernel(const T& x) {
return (T)(f * 0.5f * (1.0f + ::erf(f * ALPHA)));
}
template <typename packed_t>
__device__ __forceinline__ packed_t packed_gelu_kernel(const packed_t& val) {
// Equivalent to PyTorch GELU with 'none' approximation.
// Refer to:
// https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38
constexpr float ALPHA = M_SQRT1_2;
float2 fval = cast_to_float2(val);
fval.x = fval.x * 0.5f * (1.0f + ::erf(fval.x * ALPHA));
fval.y = fval.y * 0.5f * (1.0f + ::erf(fval.y * ALPHA));
return cast_to_packed<packed_t>(fval);
}
template <typename T>
__device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
// Equivalent to PyTorch GELU with 'tanh' approximation.
......@@ -102,32 +127,86 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
return (T)(0.5f * f * (1.0f + ::tanhf(inner)));
}
template <typename packed_t>
__device__ __forceinline__ packed_t
packed_gelu_tanh_kernel(const packed_t& val) {
// Equivalent to PyTorch GELU with 'tanh' approximation.
// Refer to:
// https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30
float2 fval = cast_to_float2(val);
constexpr float BETA = M_SQRT2 * M_2_SQRTPI * 0.5f;
constexpr float KAPPA = 0.044715;
float x_cube = fval.x * fval.x * fval.x;
float inner = BETA * (fval.x + KAPPA * x_cube);
fval.x = 0.5f * fval.x * (1.0f + ::tanhf(inner));
x_cube = fval.y * fval.y * fval.y;
inner = BETA * (fval.y + KAPPA * x_cube);
fval.y = 0.5f * fval.y * (1.0f + ::tanhf(inner));
return cast_to_packed<packed_t>(fval);
}
} // namespace vllm
// Launch activation and gating kernel.
// Use ACT_FIRST (bool) indicating whether to apply the activation function
// first.
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL, ACT_FIRST) \
int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
if (num_tokens == 0) { \
return; \
} \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "act_and_mul_kernel", [&] { \
vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>, ACT_FIRST> \
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
});
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL, PACKED_KERNEL, ACT_FIRST) \
auto dtype = input.scalar_type(); \
int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \
if (num_tokens == 0) { \
return; \
} \
dim3 grid(num_tokens); \
int cc_major = at::cuda::getCurrentDeviceProperties()->major; \
int support_vec = \
(CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) \
? vllm::VecTraits<true>::ARCH_MAX_VEC_SIZE \
: vllm::VecTraits<false>::ARCH_MAX_VEC_SIZE; \
int vec_size = support_vec / at::elementSize(dtype); \
const bool use_vec = (d % vec_size == 0); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
if (use_vec) { \
dim3 block(std::min(d / vec_size, 1024)); \
if (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) { \
VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \
vllm::act_and_mul_kernel< \
scalar_t, typename vllm::PackedTypeConverter<scalar_t>::Type, \
KERNEL<scalar_t>, \
PACKED_KERNEL<typename vllm::PackedTypeConverter<scalar_t>::Type>, \
ACT_FIRST, true, true><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d); \
}); \
} else { \
VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \
vllm::act_and_mul_kernel< \
scalar_t, typename vllm::PackedTypeConverter<scalar_t>::Type, \
KERNEL<scalar_t>, \
PACKED_KERNEL<typename vllm::PackedTypeConverter<scalar_t>::Type>, \
ACT_FIRST, true, false><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d); \
}); \
} \
} else { \
dim3 block(std::min(d, 1024)); \
VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \
vllm::act_and_mul_kernel< \
scalar_t, typename vllm::PackedTypeConverter<scalar_t>::Type, \
KERNEL<scalar_t>, \
PACKED_KERNEL<typename vllm::PackedTypeConverter<scalar_t>::Type>, \
ACT_FIRST, false><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d); \
}); \
}
void silu_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
{
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, true);
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, vllm::packed_silu_kernel,
true);
}
void mul_and_silu(torch::Tensor& out, // [..., d]
......@@ -135,19 +214,22 @@ void mul_and_silu(torch::Tensor& out, // [..., d]
{
// The difference between mul_and_silu and silu_and_mul is that mul_and_silu
// applies the silu to the latter half of the input.
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, false);
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, vllm::packed_silu_kernel,
false);
}
void gelu_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
{
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel, true);
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel, vllm::packed_gelu_kernel,
true);
}
void gelu_tanh_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d]
{
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel, true);
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel,
vllm::packed_gelu_tanh_kernel, true);
}
namespace vllm {
......@@ -158,42 +240,53 @@ __device__ __forceinline__ T fatrelu_kernel(const T& x, const float threshold) {
return (T)(f > threshold ? f : 0.0f);
}
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&, const float)>
template <typename packed_t>
__device__ __forceinline__ packed_t
packed_fatrelu_kernel(const packed_t& val, const float threshold) {
float2 fval = cast_to_float2(val);
fval.x = fval.x > threshold ? fval.x : 0.0f;
fval.y = fval.y > threshold ? fval.y : 0.0f;
return cast_to_packed<packed_t>(fval);
}
template <typename scalar_t, typename packed_t,
scalar_t (*ACT_FN)(const scalar_t&, const float),
packed_t (*PACKED_ACT_FN)(const packed_t&, const float), bool use_vec,
bool use_256b = false>
__global__ void act_and_mul_kernel_with_param(
scalar_t* __restrict__ out, const scalar_t* __restrict__ input, const int d,
const float param) {
constexpr int VEC_SIZE = 16 / sizeof(scalar_t);
const int64_t token_idx = blockIdx.x;
const scalar_t* x_ptr = input + token_idx * 2 * d;
const scalar_t* x_ptr = input + blockIdx.x * 2 * d;
const scalar_t* y_ptr = x_ptr + d;
scalar_t* out_ptr = out + token_idx * d;
scalar_t* out_ptr = out + blockIdx.x * d;
// Check alignment for 128-bit vectorized access
const bool aligned = is_16byte_aligned(x_ptr) && is_16byte_aligned(y_ptr) &&
is_16byte_aligned(out_ptr);
if constexpr (use_vec) {
using cuda_t = typename CUDATypeConverter<scalar_t>::Type;
using pvec_t = PackedVec<cuda_t, use_256b>;
if (aligned && d >= VEC_SIZE) {
// Fast path: 128-bit vectorized loop
const int4* x_vec = reinterpret_cast<const int4*>(x_ptr);
const int4* y_vec = reinterpret_cast<const int4*>(y_ptr);
int4* out_vec = reinterpret_cast<int4*>(out_ptr);
const int num_vecs = d / VEC_SIZE;
const int vec_end = num_vecs * VEC_SIZE;
const pvec_t* x_vec = reinterpret_cast<const pvec_t*>(x_ptr);
const pvec_t* y_vec = reinterpret_cast<const pvec_t*>(y_ptr);
pvec_t* out_vec = reinterpret_cast<pvec_t*>(out_ptr);
const int num_vecs = d / 2 / pvec_t::NUM_ELTS;
for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) {
int4 x = VLLM_LDG(&x_vec[i]), y = VLLM_LDG(&y_vec[i]), r;
auto* xp = reinterpret_cast<scalar_t*>(&x);
auto* yp = reinterpret_cast<scalar_t*>(&y);
auto* rp = reinterpret_cast<scalar_t*>(&r);
pvec_t x, y;
if constexpr (use_256b) {
ld256(x, &x_vec[i]);
ld256(y, &y_vec[i]);
} else {
ld128(x, &x_vec[i]);
ld128(y, &y_vec[i]);
}
#pragma unroll
for (int j = 0; j < VEC_SIZE; j++) {
rp[j] = ACT_FN(xp[j], param) * yp[j];
for (int j = 0; j < pvec_t::NUM_ELTS; j++) {
x.elts[j] = packed_mul(PACKED_ACT_FN(x.elts[j], param), y.elts[j]);
}
if constexpr (use_256b) {
st256(x, &out_vec[i]);
} else {
st128(x, &out_vec[i]);
}
out_vec[i] = r;
}
// Scalar cleanup for remaining elements
for (int i = vec_end + threadIdx.x; i < d; i += blockDim.x) {
out_ptr[i] = ACT_FN(VLLM_LDG(&x_ptr[i]), param) * VLLM_LDG(&y_ptr[i]);
}
} else {
// Scalar fallback for unaligned data or small d
......@@ -276,20 +369,61 @@ __global__ void swigluoai_and_mul_kernel(
} // namespace vllm
#define LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(KERNEL, PARAM) \
int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "act_and_mul_kernel_with_param", [&] { \
vllm::act_and_mul_kernel_with_param<scalar_t, KERNEL<scalar_t>> \
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d, \
PARAM); \
});
#define LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(KERNEL, PACKED_KERNEL, PARAM) \
auto dtype = input.scalar_type(); \
int d = input.size(-1) / 2; \
int64_t num_tokens = input.numel() / input.size(-1); \
if (num_tokens == 0) { \
return; \
} \
dim3 grid(num_tokens); \
int cc_major = at::cuda::getCurrentDeviceProperties()->major; \
int support_vec = \
(CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) \
? vllm::VecTraits<true>::ARCH_MAX_VEC_SIZE \
: vllm::VecTraits<false>::ARCH_MAX_VEC_SIZE; \
int vec_size = support_vec / at::elementSize(dtype); \
const bool use_vec = (d % vec_size == 0); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
if (use_vec) { \
dim3 block(std::min(d / vec_size, 1024)); \
if (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) { \
VLLM_DISPATCH_FLOATING_TYPES( \
dtype, "act_and_mul_kernel_with_param", [&] { \
vllm::act_and_mul_kernel_with_param< \
scalar_t, typename vllm::PackedTypeConverter<scalar_t>::Type, \
KERNEL<scalar_t>, \
PACKED_KERNEL< \
typename vllm::PackedTypeConverter<scalar_t>::Type>, \
true, true><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d, \
PARAM); \
}); \
} else { \
VLLM_DISPATCH_FLOATING_TYPES( \
dtype, "act_and_mul_kernel_with_param", [&] { \
vllm::act_and_mul_kernel_with_param< \
scalar_t, typename vllm::PackedTypeConverter<scalar_t>::Type, \
KERNEL<scalar_t>, \
PACKED_KERNEL< \
typename vllm::PackedTypeConverter<scalar_t>::Type>, \
true, false><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d, \
PARAM); \
}); \
} \
} else { \
dim3 block(std::min(d, 1024)); \
VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel_with_param", [&] { \
vllm::act_and_mul_kernel_with_param< \
scalar_t, typename vllm::PackedTypeConverter<scalar_t>::Type, \
KERNEL<scalar_t>, \
PACKED_KERNEL<typename vllm::PackedTypeConverter<scalar_t>::Type>, \
false><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d, PARAM); \
}); \
}
#define LAUNCH_SIGLUOAI_AND_MUL(KERNEL, ALPHA, LIMIT) \
int d = input.size(-1) / 2; \
......@@ -309,7 +443,8 @@ __global__ void swigluoai_and_mul_kernel(
void fatrelu_and_mul(torch::Tensor& out, // [..., d],
torch::Tensor& input, // [..., 2 * d]
double threshold) {
LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(vllm::fatrelu_kernel, threshold);
LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(
vllm::fatrelu_kernel, vllm::packed_fatrelu_kernel, threshold);
}
void swigluoai_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& input, // [..., 2 * d]
......@@ -319,39 +454,41 @@ void swigluoai_and_mul(torch::Tensor& out, // [..., d]
namespace vllm {
// Element-wise activation kernel template.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&), bool use_vec,
bool use_256b = false>
__global__ void activation_kernel(
scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., d]
const int d) {
constexpr int VEC_SIZE = 16 / sizeof(scalar_t);
const int64_t token_idx = blockIdx.x;
const scalar_t* in_ptr = input + token_idx * d;
scalar_t* out_ptr = out + token_idx * d;
// Check alignment for 128-bit vectorized access
const bool aligned = is_16byte_aligned(in_ptr) && is_16byte_aligned(out_ptr);
if (aligned && d >= VEC_SIZE) {
// Fast path: 128-bit vectorized loop
const int4* in_vec = reinterpret_cast<const int4*>(in_ptr);
int4* out_vec = reinterpret_cast<int4*>(out_ptr);
const scalar_t* in_ptr = input + blockIdx.x * d;
scalar_t* out_ptr = out + blockIdx.x * d;
if constexpr (use_vec) {
// Fast path: 128-bit/256-bit vectorized loop
using vec_t = typename VecTraits<use_256b>::vec_t;
constexpr int ARCH_MAX_VEC_SIZE = VecTraits<use_256b>::ARCH_MAX_VEC_SIZE;
constexpr int VEC_SIZE = ARCH_MAX_VEC_SIZE / sizeof(scalar_t);
const vec_t* in_vec = reinterpret_cast<const vec_t*>(in_ptr);
vec_t* out_vec = reinterpret_cast<vec_t*>(out_ptr);
const int num_vecs = d / VEC_SIZE;
const int vec_end = num_vecs * VEC_SIZE;
for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) {
int4 v = VLLM_LDG(&in_vec[i]), r;
vec_t v;
if constexpr (use_256b) {
ld256(v, &in_vec[i]);
} else {
v = VLLM_LDG(&in_vec[i]);
}
auto* vp = reinterpret_cast<scalar_t*>(&v);
auto* rp = reinterpret_cast<scalar_t*>(&r);
#pragma unroll
for (int j = 0; j < VEC_SIZE; j++) {
rp[j] = ACT_FN(vp[j]);
vp[j] = ACT_FN(vp[j]);
}
if constexpr (use_256b) {
st256(v, &out_vec[i]);
} else {
out_vec[i] = v;
}
out_vec[i] = r;
}
// Scalar cleanup for remaining elements
for (int i = vec_end + threadIdx.x; i < d; i += blockDim.x) {
out_ptr[i] = ACT_FN(VLLM_LDG(&in_ptr[i]));
}
} else {
// Scalar fallback for unaligned data or small d
......@@ -365,18 +502,46 @@ __global__ void activation_kernel(
} // namespace vllm
// Launch element-wise activation kernel.
#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \
int d = input.size(-1); \
int64_t num_tokens = input.numel() / d; \
dim3 grid(num_tokens); \
dim3 block(std::min(d, 1024)); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "activation_kernel", [&] { \
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>> \
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
});
#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \
auto dtype = input.scalar_type(); \
int d = input.size(-1); \
int64_t num_tokens = input.numel() / input.size(-1); \
if (num_tokens == 0) { \
return; \
} \
dim3 grid(num_tokens); \
int cc_major = at::cuda::getCurrentDeviceProperties()->major; \
int support_vec = \
(CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) \
? vllm::VecTraits<true>::ARCH_MAX_VEC_SIZE \
: vllm::VecTraits<false>::ARCH_MAX_VEC_SIZE; \
int vec_size = support_vec / at::elementSize(dtype); \
const bool use_vec = (d % vec_size == 0); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
if (use_vec) { \
dim3 block(std::min(d / vec_size, 1024)); \
if (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) { \
VLLM_DISPATCH_FLOATING_TYPES(dtype, "activation_kernel", [&] { \
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>, true, true> \
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
}); \
} else { \
VLLM_DISPATCH_FLOATING_TYPES(dtype, "activation_kernel", [&] { \
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>, true, false> \
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
}); \
} \
} else { \
dim3 block(std::min(d, 1024)); \
VLLM_DISPATCH_FLOATING_TYPES(dtype, "activation_kernel", [&] { \
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>, false> \
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
}); \
}
namespace vllm {
......
......@@ -74,6 +74,12 @@ void indexer_k_quant_and_cache(
int64_t quant_block_size, // quantization block size
const std::string& scale_fmt);
// Concatenate query nope and rope for MLA/DSA attention
void concat_mla_q(
torch::Tensor& ql_nope, // [num_tokens, num_heads, nope_dim]
torch::Tensor& q_pe, // [num_tokens, num_heads, rope_dim]
torch::Tensor& q_out); // [num_tokens, num_heads, nope_dim + rope_dim]
// Extract function to gather quantized K cache
void cp_gather_indexer_k_quant_cache(
const torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride]
......
......@@ -8,6 +8,7 @@
#include "cuda_compat.h"
#include "dispatch_utils.h"
#include "quantization/vectorization_utils.cuh"
#include "concat_mla_q.cuh"
#ifdef USE_ROCM
#include "quantization/w8a8/fp8/amd/quant_utils.cuh"
......@@ -918,8 +919,8 @@ __global__ void gather_and_maybe_dequant_cache(
// SCALAR_T is the data type of the destination tensor.
// CACHE_T is the stored data type of kv-cache.
// KV_DTYPE is the real data type of kv-cache.
#define CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE) \
vllm::gather_and_maybe_dequant_cache<SCALAR_T, CACHE_T, KV_DTYPE, 576, \
#define CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE, ENTRY_SZ) \
vllm::gather_and_maybe_dequant_cache<SCALAR_T, CACHE_T, KV_DTYPE, ENTRY_SZ, \
thread_block_size> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<CACHE_T*>(src_cache.data_ptr()), \
......@@ -930,6 +931,12 @@ __global__ void gather_and_maybe_dequant_cache(
dst_entry_stride, reinterpret_cast<const float*>(scale.data_ptr()), \
seq_starts_ptr);
#define CALL_GATHER_CACHE_576(SCALAR_T, CACHE_T, KV_DTYPE) \
CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE, 576)
#define CALL_GATHER_CACHE_320(SCALAR_T, CACHE_T, KV_DTYPE) \
CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE, 320)
// Gather sequences from the cache into the destination tensor.
// - cu_seq_lens contains the cumulative sequence lengths for each batch
// - block_table contains the cache block indices for each sequence
......@@ -959,9 +966,10 @@ void gather_and_maybe_dequant_cache(
TORCH_CHECK(seq_starts.value().dtype() == torch::kInt32,
"seq_starts must be int32");
}
TORCH_CHECK(head_dim == 576,
"gather_and_maybe_dequant_cache only support the head_dim to 576 "
"for better performance")
TORCH_CHECK(
head_dim == 320 || head_dim == 576,
"gather_and_maybe_dequant_cache only support the head_dim to 320 or 576 "
"for better performance")
TORCH_CHECK(src_cache.device() == dst.device(),
"src_cache and dst must be on the same device");
......@@ -986,7 +994,13 @@ void gather_and_maybe_dequant_cache(
const int32_t* seq_starts_ptr =
seq_starts.has_value() ? seq_starts.value().data_ptr<int32_t>() : nullptr;
DISPATCH_BY_KV_CACHE_DTYPE(dst.dtype(), kv_cache_dtype, CALL_GATHER_CACHE);
if (head_dim == 576) {
DISPATCH_BY_KV_CACHE_DTYPE(dst.dtype(), kv_cache_dtype,
CALL_GATHER_CACHE_576);
} else {
DISPATCH_BY_KV_CACHE_DTYPE(dst.dtype(), kv_cache_dtype,
CALL_GATHER_CACHE_320);
}
}
namespace vllm {
......@@ -995,75 +1009,67 @@ namespace vllm {
// Similar to cp_gather_cache but specifically for FP8->BF16 conversion
__global__ void cp_gather_and_upconvert_fp8_kv_cache(
const uint8_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656]
__nv_bfloat16* __restrict__ dst, // [TOT_TOKENS, 576]
const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES]
const int32_t* __restrict__ seq_lens, // [BATCH]
const int32_t* __restrict__ workspace_starts, // [BATCH]
const int32_t block_size, const int32_t head_dim,
const int64_t block_table_stride, const int64_t cache_block_stride,
const int64_t cache_entry_stride, const int64_t dst_entry_stride) {
const int64_t bid = blockIdx.x; // Batch ID
const int32_t num_splits = gridDim.y;
const int32_t split = blockIdx.y;
const int32_t seq_start = workspace_starts[bid];
const int32_t seq_len = seq_lens[bid];
const int32_t tot_slots = seq_len;
const int32_t split_slots = cuda_utils::ceil_div(tot_slots, num_splits);
__nv_bfloat16* __restrict__ dst, // [total_tokens, 576]
const int32_t* __restrict__ block_table, // [num_reqs, BLOCK_INDICES]
const int32_t* __restrict__ workspace_starts, // [num_reqs]
const int32_t num_reqs, const int32_t block_size,
const int32_t total_tokens, const int64_t block_table_stride,
const int64_t cache_block_stride, const int64_t cache_entry_stride,
const int64_t dst_entry_stride) {
const int flat_warp_id = (blockIdx.x * blockDim.x + threadIdx.x) >> 5;
if (flat_warp_id >= total_tokens) return;
const int lane_id = threadIdx.x & 31;
// Binary search to find which request owns this output token
int lo = 0, hi = num_reqs - 1;
while (lo < hi) {
int mid = (lo + hi + 1) >> 1;
if (workspace_starts[mid] <= flat_warp_id)
lo = mid;
else
hi = mid - 1;
}
const int req_id = lo;
const int32_t split_start = split * split_slots;
const int32_t split_end = min((split + 1) * split_slots, tot_slots);
// Compute physical token address via block table
const int out_token_id = flat_warp_id;
const int token_offset = out_token_id - workspace_starts[req_id];
const int cache_block_idx = token_offset / block_size;
const int offset_in_block = token_offset % block_size;
const int physical_block =
block_table[req_id * block_table_stride + cache_block_idx];
const bool is_active_split = (split_start < tot_slots);
const uint8_t* token_ptr = src_cache + physical_block * cache_block_stride +
offset_in_block * cache_entry_stride;
if (!is_active_split) return;
const int4* nope_src = reinterpret_cast<const int4*>(token_ptr);
const int4 fp8_data = nope_src[lane_id];
// Adjust the pointer for the block_table for this batch
const int32_t batch_offset = bid * block_table_stride;
int32_t offset = split_start;
int32_t offset_div = offset / block_size;
offset = offset % block_size;
const int32_t* batch_block_table = block_table + batch_offset;
const float* scales_ptr = reinterpret_cast<const float*>(token_ptr + 512);
const float scale = scales_ptr[lane_id >> 3];
// Adjust dst pointer based on the cumulative sequence lengths
dst += seq_start * dst_entry_stride;
const uint2 fp8_lo = make_uint2(fp8_data.x, fp8_data.y);
const uint2 fp8_hi = make_uint2(fp8_data.z, fp8_data.w);
#ifdef USE_ROCM
const bf16_8_t bf16_lo =
fp8::scaled_vec_conversion<bf16_8_t, uint2>(fp8_lo, scale);
const bf16_8_t bf16_hi =
fp8::scaled_vec_conversion<bf16_8_t, uint2>(fp8_hi, scale);
#else
const bf16_8_t bf16_lo =
fp8::scaled_vec_conversion<bf16_8_t, uint2>(fp8_lo, scale, __NV_E4M3);
const bf16_8_t bf16_hi =
fp8::scaled_vec_conversion<bf16_8_t, uint2>(fp8_hi, scale, __NV_E4M3);
#endif
const int tid = threadIdx.x;
__nv_bfloat16* dst_ptr = dst + out_token_id * dst_entry_stride;
int4* nope_dst = reinterpret_cast<int4*>(dst_ptr) + lane_id * 2;
nope_dst[0] = *reinterpret_cast<const int4*>(&bf16_lo);
nope_dst[1] = *reinterpret_cast<const int4*>(&bf16_hi);
// Process each token in this split
for (int pid = split_start; pid < split_end; ++pid) {
auto block_id = batch_block_table[offset_div];
const uint8_t* token_ptr =
src_cache + block_id * cache_block_stride + offset * cache_entry_stride;
__nv_bfloat16* dst_ptr = dst + pid * dst_entry_stride;
// FP8 format: 512 bytes fp8 + 16 bytes scales + 128 bytes rope (64 bf16)
const uint8_t* no_pe_ptr = token_ptr;
const float* scales_ptr = reinterpret_cast<const float*>(token_ptr + 512);
const __nv_bfloat16* rope_ptr =
reinterpret_cast<const __nv_bfloat16*>(token_ptr + 512 + 16);
// Parallelize fp8 dequant (512 elements) and rope copy (64 elements)
if (tid < 512) {
// FP8 dequantization
const int tile = tid >> 7; // each tile is 128 elements
const float scale = scales_ptr[tile];
const uint8_t val = no_pe_ptr[tid];
dst_ptr[tid] =
fp8::scaled_convert<__nv_bfloat16, uint8_t,
vllm::Fp8KVCacheDataType::kFp8E4M3>(val, scale);
} else if (tid < 576) {
// Rope copy (64 bf16 elements)
const int rope_idx = tid - 512;
dst_ptr[512 + rope_idx] = rope_ptr[rope_idx];
}
// Move to next token
offset += 1;
if (offset == block_size) {
offset_div += 1;
offset = 0;
}
}
const int* rope_src = reinterpret_cast<const int*>(token_ptr + 528);
int* rope_dst = reinterpret_cast<int*>(dst_ptr + 512);
rope_dst[lane_id] = rope_src[lane_id];
}
template <typename scalar_t>
......@@ -1234,8 +1240,13 @@ void cp_gather_and_upconvert_fp8_kv_cache(
"src_cache and seq_lens must be on the same device");
TORCH_CHECK(src_cache.device() == workspace_starts.device(),
"src_cache and workspace_starts must be on the same device");
TORCH_CHECK(src_cache.dtype() == torch::kUInt8, "src_cache must be uint8");
auto dtype = src_cache.scalar_type();
TORCH_CHECK(
dtype == at::ScalarType::Byte || // uint8
dtype == at::ScalarType::Float8_e4m3fn || // fp8 e4m3
dtype == at::ScalarType::Float8_e5m2, // fp8 e5m2
"src_cache must be uint8, float8_e4m3fn, or float8_e5m2, but got ",
src_cache.dtype());
TORCH_CHECK(dst.dtype() == torch::kBFloat16, "dst must be bfloat16");
TORCH_CHECK(head_dim == 576, "head_dim must be 576 for MLA");
......@@ -1244,16 +1255,24 @@ void cp_gather_and_upconvert_fp8_kv_cache(
int64_t cache_entry_stride = src_cache.stride(1);
int64_t dst_entry_stride = dst.stride(0);
// Decide on the number of splits based on the batch size
int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 4 : 16;
dim3 grid(batch_size, num_splits);
dim3 block(576);
const uint8_t* src_ptr = nullptr;
if (dtype == at::ScalarType::Byte) {
src_ptr = src_cache.data_ptr<uint8_t>();
} else {
// float8_e4m3fn or float8_e5m2
src_ptr = reinterpret_cast<const uint8_t*>(src_cache.data_ptr());
}
const int total_tokens = dst.size(0);
constexpr int warps_per_block = 8;
const int grid_size = (total_tokens + warps_per_block - 1) / warps_per_block;
const int block_size_threads = warps_per_block * 32; // 256 threads
vllm::cp_gather_and_upconvert_fp8_kv_cache<<<grid, block, 0, stream>>>(
src_cache.data_ptr<uint8_t>(),
reinterpret_cast<__nv_bfloat16*>(dst.data_ptr()),
block_table.data_ptr<int32_t>(), seq_lens.data_ptr<int32_t>(),
workspace_starts.data_ptr<int32_t>(), block_size, head_dim,
vllm::cp_gather_and_upconvert_fp8_kv_cache<<<grid_size, block_size_threads, 0,
stream>>>(
src_ptr, reinterpret_cast<__nv_bfloat16*>(dst.data_ptr()),
block_table.data_ptr<int32_t>(), workspace_starts.data_ptr<int32_t>(),
static_cast<int32_t>(batch_size), block_size, total_tokens,
block_table_stride, cache_block_stride, cache_entry_stride,
dst_entry_stride);
}
......@@ -1293,7 +1312,8 @@ void indexer_k_quant_and_cache(
const at::cuda::OptionalCUDAGuard device_guard(device_of(k));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
DISPATCH_BY_KV_CACHE_DTYPE(k.dtype(), "fp8_e4m3",
static const std::string kv_cache_dtype = "fp8_e4m3";
DISPATCH_BY_KV_CACHE_DTYPE(k.dtype(), kv_cache_dtype,
CALL_INDEXER_K_QUANT_AND_CACHE);
}
......@@ -1352,3 +1372,43 @@ void cp_gather_indexer_k_quant_cache(
CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(32);
}
}
// Concatenate ql_nope and q_pe into a contiguous q_out tensor for MLA/DSA.
// Replaces torch.cat((ql_nope, q_pe), dim=-1).
void concat_mla_q(torch::Tensor& ql_nope, // [num_tokens, num_heads, nope_dim]
torch::Tensor& q_pe, // [num_tokens, num_heads, rope_dim]
torch::Tensor& q_out // [num_tokens, num_heads, nope_dim +
// rope_dim]
) {
const int num_tokens = ql_nope.size(0);
const int num_heads = ql_nope.size(1);
const int nope_dim = ql_nope.size(2);
const int rope_dim = q_pe.size(2);
TORCH_CHECK(nope_dim % 512 == 0, "nope_dim must be a multiple of 512, got ",
nope_dim);
TORCH_CHECK(rope_dim == 64, "rope_dim must be 64, got ", rope_dim);
TORCH_CHECK(q_out.size(2) == nope_dim + rope_dim);
TORCH_CHECK(ql_nope.stride(2) == 1, "ql_nope must have stride 1 in dim 2");
TORCH_CHECK(q_pe.stride(2) == 1, "q_pe must have stride 1 in dim 2");
TORCH_CHECK(q_out.stride(2) == 1, "q_out must have stride 1 in dim 2");
if (num_tokens == 0) return;
constexpr int warps_per_block = 8;
const int total_warps = num_tokens * num_heads;
const int grid_size = (total_warps + warps_per_block - 1) / warps_per_block;
const int block_size = warps_per_block * 32;
const at::cuda::OptionalCUDAGuard device_guard(device_of(ql_nope));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(ql_nope.scalar_type(), "concat_mla_q", [&] {
vllm::ConcatMLAQKernel<scalar_t, 512><<<grid_size, block_size, 0, stream>>>(
q_out.data_ptr<scalar_t>(), ql_nope.data_ptr<scalar_t>(),
q_pe.data_ptr<scalar_t>(), num_tokens, num_heads, q_out.stride(0),
q_out.stride(1), ql_nope.stride(0), ql_nope.stride(1), q_pe.stride(0),
q_pe.stride(1));
});
}
#ifndef CONCAT_MLA_Q_CUH_
#define CONCAT_MLA_Q_CUH_
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include "cuda_vec_utils.cuh"
namespace vllm {
// Concatenates ql_nope [num_tokens, num_heads, NOPE_DIM] and
// q_pe [num_tokens, num_heads, 64]
// into q_out [num_tokens, num_heads, NOPE_DIM+64].
// Currently instantiated only for NOPE_DIM=512.
// Rope dim is hardcoded to 64 (DeepSeek V3.2 MLA)
template <typename DType, int NOPE_DIM>
__global__ void ConcatMLAQKernel(
DType* __restrict__ q_out, const DType* __restrict__ ql_nope,
const DType* __restrict__ q_pe, const int num_tokens, const int num_heads,
const int64_t out_stride_0, const int64_t out_stride_1,
const int64_t nope_stride_0, const int64_t nope_stride_1,
const int64_t pe_stride_0, const int64_t pe_stride_1) {
const int flat_warp_id = (blockIdx.x * blockDim.x + threadIdx.x) >> 5;
if (flat_warp_id >= num_tokens * num_heads) return;
const int token_id = flat_warp_id / num_heads;
const int head_id = flat_warp_id % num_heads;
const int lane_id = threadIdx.x & 31;
constexpr bool use_256b = VLLM_256B_PTX_ENABLED;
constexpr int nope_vec_loads =
NOPE_DIM * sizeof(DType) / (VecTraits<use_256b>::ARCH_MAX_VEC_SIZE * 32);
const DType* nope_src =
ql_nope + token_id * nope_stride_0 + head_id * nope_stride_1;
DType* nope_dst = q_out + token_id * out_stride_0 + head_id * out_stride_1;
#pragma unroll
for (int i = 0; i < nope_vec_loads; i++) {
const int offset = i * 32 + lane_id;
if constexpr (use_256b) {
st256_cs(reinterpret_cast<u32x8_t*>(nope_dst) + offset,
ld256_cs(reinterpret_cast<const u32x8_t*>(nope_src) + offset));
} else {
st128_cs(reinterpret_cast<int4*>(nope_dst) + offset,
ld128_cs(reinterpret_cast<const int4*>(nope_src) + offset));
}
}
const int* rope_src = reinterpret_cast<const int*>(
q_pe + token_id * pe_stride_0 + head_id * pe_stride_1);
int* rope_dst = reinterpret_cast<int*>(q_out + token_id * out_stride_0 +
head_id * out_stride_1 + NOPE_DIM);
st32_cs(rope_dst + lane_id, ld32_cs(rope_src + lane_id));
}
} // namespace vllm
#endif // CONCAT_MLA_Q_CUH_
......@@ -16,6 +16,8 @@ torch::Tensor get_scheduler_metadata(
isa = cpu_attention::ISA::VEC16;
} else if (isa_hint == "neon") {
isa = cpu_attention::ISA::NEON;
} else if (isa_hint == "vxe") {
isa = cpu_attention::ISA::VXE;
} else {
TORCH_CHECK(false, "Unsupported CPU attention ISA hint: " + isa_hint);
}
......@@ -100,6 +102,8 @@ void cpu_attn_reshape_and_cache(
return cpu_attention::ISA::VEC16;
} else if (isa == "neon") {
return cpu_attention::ISA::NEON;
} else if (isa == "vxe") {
return cpu_attention::ISA::VXE;
} else {
TORCH_CHECK(false, "Invalid ISA type: " + isa);
}
......
......@@ -420,7 +420,7 @@ class AttentionImpl<ISA::AMX, scalar_t, head_dim> {
const int64_t block_size, const int64_t block_size_stride) {
// For AMX 2D tiles, size of each line is 64 bytes
constexpr int64_t amx_tile_row_size = AMX_TILE_ROW_BYTES;
// For AMX B martix, N always is 16
// For AMX B matrix, N always is 16
constexpr int64_t amx_b_tile_n_size = AMX_TILE_ROW_BYTES / 4;
constexpr int64_t amx_b_tile_k_size = amx_tile_row_size / sizeof(scalar_t);
// For now suppose block_size is divisible by amx_tile_column_num
......
......@@ -12,7 +12,7 @@
#include "cpu/utils.hpp"
namespace cpu_attention {
enum class ISA { AMX, VEC, VEC16, NEON };
enum class ISA { AMX, VEC, VEC16, NEON, VXE };
template <ISA isa, typename scalar_t, int64_t head_dim>
class AttentionImpl {};
......@@ -821,7 +821,7 @@ struct VecTypeTrait<c10::BFloat16> {
using vec_t = vec_op::BF16Vec16;
};
#if !defined(__powerpc__) && !defined(__s390x__)
#if !defined(__powerpc__)
template <>
struct VecTypeTrait<c10::Half> {
using vec_t = vec_op::FP16Vec16;
......
#ifndef CPU_ATTN_VXE_HPP
#define CPU_ATTN_VXE_HPP
#include "cpu_attn_impl.hpp"
#include <vecintrin.h>
#include <type_traits>
namespace cpu_attention {
namespace {
// s390x Vector = 16 bytes (128 bits)
#define BLOCK_SIZE_ALIGNMENT 32
#define HEAD_SIZE_ALIGNMENT 32
#define MAX_Q_HEAD_NUM_PER_ITER 16
template <typename kv_cache_t>
FORCE_INLINE void load_row8_B_as_f32(const kv_cache_t* p, __vector float& b0,
__vector float& b1);
// [1] Float Specialization
template <>
FORCE_INLINE void load_row8_B_as_f32<float>(const float* p, __vector float& b0,
__vector float& b1) {
// Explicitly cast to long long for offset, and float* for pointer
b0 = vec_xl((long long)0, const_cast<float*>(p));
b1 = vec_xl((long long)0, const_cast<float*>(p + 4));
}
// [2] BFloat16 Specialization (Big Endian Fix)
template <>
FORCE_INLINE void load_row8_B_as_f32<c10::BFloat16>(const c10::BFloat16* p,
__vector float& b0,
__vector float& b1) {
// 1. Load 8 BF16s (16 bytes) into one vector
// Explicit cast to unsigned short* for vec_xl to return vector unsigned short
__vector unsigned short raw = vec_xl((long long)0, (unsigned short*)p);
// 2. Prepare Zero vector
__vector unsigned short zeros = vec_splat_u16(0);
// 3. Merge High/Low to expand BF16 -> Float32
// On Big Endian, a float is [BF16_bits | 16_zero_bits]
b0 = (__vector float)vec_mergeh(raw, zeros);
b1 = (__vector float)vec_mergel(raw, zeros);
}
template <>
FORCE_INLINE void load_row8_B_as_f32<c10::Half>(const c10::Half* p,
__vector float& b0,
__vector float& b1) {
alignas(16) float tmp[8];
// Manual unroll / conversion
tmp[0] = static_cast<float>(p[0]);
tmp[1] = static_cast<float>(p[1]);
tmp[2] = static_cast<float>(p[2]);
tmp[3] = static_cast<float>(p[3]);
tmp[4] = static_cast<float>(p[4]);
tmp[5] = static_cast<float>(p[5]);
tmp[6] = static_cast<float>(p[6]);
tmp[7] = static_cast<float>(p[7]);
// Explicit arguments for intrinsic: (long long offset, float* ptr)
b0 = vec_xl((long long)0, (float*)tmp);
b1 = vec_xl((long long)0, (float*)(tmp + 4));
}
template <int32_t M, typename kv_cache_t>
FORCE_INLINE void gemm_micro_s390x_Mx8_Ku4(
const float* __restrict A, // [M x K]
const kv_cache_t* __restrict B, // [K x 8]
float* __restrict C, // [M x 8]
int64_t lda, int64_t ldb, int64_t ldc, int32_t K, bool accumulate) {
static_assert(1 <= M && M <= 8, "M must be in [1,8]");
// Helper macros to unroll codegen for M rows
#define ROWS_APPLY(OP) OP(0) OP(1) OP(2) OP(3) OP(4) OP(5) OP(6) OP(7)
#define IF_M(i) if constexpr (M > (i))
// 1. Define A pointers
#define DECL_A(i) const float* a##i = A + (i) * lda;
ROWS_APPLY(DECL_A)
#undef DECL_A
// 2. Define Accumulators (2 vectors covers 8 columns)
#define DECL_ACC(i) __vector float acc##i##_0, acc##i##_1;
ROWS_APPLY(DECL_ACC)
#undef DECL_ACC
// 3. Initialize Accumulators (Load C or Zero)
#define INIT_ACC(i) \
IF_M(i) { \
if (accumulate) { \
acc##i##_0 = \
vec_xl((long long)0, const_cast<float*>(C + (i) * ldc + 0)); \
acc##i##_1 = \
vec_xl((long long)0, const_cast<float*>(C + (i) * ldc + 4)); \
} else { \
acc##i##_0 = vec_splats(0.0f); \
acc##i##_1 = vec_splats(0.0f); \
} \
}
ROWS_APPLY(INIT_ACC)
#undef INIT_ACC
int32_t k = 0;
for (; k + 3 < K; k += 4) {
// Load 4 values of A for each Row M: A[k...k+3]
#define LOAD_A4(i) \
__vector float a##i##v; \
IF_M(i) a##i##v = vec_xl((long long)0, const_cast<float*>(a##i + k));
ROWS_APPLY(LOAD_A4)
#undef LOAD_A4
// Helper: FMA for specific lane L of A
// s390x: vec_madd(b, vec_splat(a, lane), acc)
#define FMAS_LANE(i, aiv, L) \
IF_M(i) { \
__vector float a_broad = vec_splat(aiv, L); \
acc##i##_0 = vec_madd(b0, a_broad, acc##i##_0); \
acc##i##_1 = vec_madd(b1, a_broad, acc##i##_1); \
}
// Unroll K=0..3
{
__vector float b0, b1;
load_row8_B_as_f32<kv_cache_t>(B + (int64_t)(k + 0) * ldb, b0, b1);
#define STEP_K0(i) FMAS_LANE(i, a##i##v, 0)
ROWS_APPLY(STEP_K0)
#undef STEP_K0
}
{
__vector float b0, b1;
load_row8_B_as_f32<kv_cache_t>(B + (int64_t)(k + 1) * ldb, b0, b1);
#define STEP_K1(i) FMAS_LANE(i, a##i##v, 1)
ROWS_APPLY(STEP_K1)
#undef STEP_K1
}
{
__vector float b0, b1;
load_row8_B_as_f32<kv_cache_t>(B + (int64_t)(k + 2) * ldb, b0, b1);
#define STEP_K2(i) FMAS_LANE(i, a##i##v, 2)
ROWS_APPLY(STEP_K2)
#undef STEP_K2
}
{
__vector float b0, b1;
load_row8_B_as_f32<kv_cache_t>(B + (int64_t)(k + 3) * ldb, b0, b1);
#define STEP_K3(i) FMAS_LANE(i, a##i##v, 3)
ROWS_APPLY(STEP_K3)
#undef STEP_K3
}
#undef FMAS_LANE
}
for (; k < K; ++k) {
__vector float b0, b1;
load_row8_B_as_f32<kv_cache_t>(B + (int64_t)k * ldb, b0, b1);
#define TAIL_ROW(i) \
IF_M(i) { \
__vector float ai = vec_splats(*(a##i + k)); \
acc##i##_0 = vec_madd(b0, ai, acc##i##_0); \
acc##i##_1 = vec_madd(b1, ai, acc##i##_1); \
}
ROWS_APPLY(TAIL_ROW)
#undef TAIL_ROW
}
#define STORE_ROW(i) \
IF_M(i) { \
vec_xst(acc##i##_0, 0, C + (i) * ldc + 0); \
vec_xst(acc##i##_1, 0, C + (i) * ldc + 4); \
}
ROWS_APPLY(STORE_ROW)
#undef STORE_ROW
#undef ROWS_APPLY
#undef IF_M
}
template <int32_t N, typename kv_cache_t>
FORCE_INLINE void gemm_macro_s390x_Mx8_Ku4(const float* __restrict A,
const kv_cache_t* __restrict B,
float* __restrict C, int32_t M,
int32_t K, int64_t lda, int64_t ldb,
int64_t ldc, bool accumulate) {
static_assert(N % 8 == 0, "N must be a multiple of 8");
for (int32_t m = 0; m < M;) {
int32_t mb = (M - m >= 8) ? 8 : (M - m >= 4) ? 4 : (M - m >= 2) ? 2 : 1;
const float* Ab = A + m * lda;
float* Cb = C + m * ldc;
for (int32_t n = 0; n < N; n += 8) {
const kv_cache_t* Bn = B + n;
float* Cn = Cb + n;
switch (mb) {
case 8:
gemm_micro_s390x_Mx8_Ku4<8, kv_cache_t>(Ab, Bn, Cn, lda, ldb, ldc, K,
accumulate);
break;
case 4:
gemm_micro_s390x_Mx8_Ku4<4, kv_cache_t>(Ab, Bn, Cn, lda, ldb, ldc, K,
accumulate);
break;
case 2:
gemm_micro_s390x_Mx8_Ku4<2, kv_cache_t>(Ab, Bn, Cn, lda, ldb, ldc, K,
accumulate);
break;
default:
gemm_micro_s390x_Mx8_Ku4<1, kv_cache_t>(Ab, Bn, Cn, lda, ldb, ldc, K,
accumulate);
break;
}
}
m += mb;
}
}
template <typename kv_cache_t>
class TileGemmS390X {
public:
template <AttentionGemmPhase phase, int32_t k_size>
FORCE_INLINE static void gemm(const int32_t m_size,
float* __restrict__ a_tile,
kv_cache_t* __restrict__ b_tile,
float* __restrict__ c_tile, const int64_t lda,
const int64_t ldb, const int64_t ldc,
const int32_t block_size,
const int32_t dynamic_k_size,
const bool accum_c) {
if constexpr (phase == AttentionGemmPhase::QK) {
gemm_macro_s390x_Mx8_Ku4<BLOCK_SIZE_ALIGNMENT, kv_cache_t>(
a_tile, b_tile, c_tile, m_size, k_size, lda, ldb, ldc, accum_c);
} else {
gemm_macro_s390x_Mx8_Ku4<HEAD_SIZE_ALIGNMENT, kv_cache_t>(
a_tile, b_tile, c_tile, m_size, dynamic_k_size, lda, ldb, ldc,
accum_c);
}
}
};
} // namespace
template <typename scalar_t, int64_t head_dim>
class AttentionImpl<ISA::VXE, scalar_t, head_dim> {
public:
using query_t = scalar_t;
using q_buffer_t = float;
using kv_cache_t = scalar_t;
using logits_buffer_t = float;
using partial_output_buffer_t = float;
using prob_buffer_t = float;
constexpr static int64_t BlockSizeAlignment = BLOCK_SIZE_ALIGNMENT;
constexpr static int64_t HeadDimAlignment = HEAD_SIZE_ALIGNMENT;
constexpr static int64_t MaxQHeadNumPerIteration = MAX_Q_HEAD_NUM_PER_ITER;
constexpr static int64_t HeadDim = head_dim;
constexpr static ISA ISAType = ISA::VXE;
constexpr static bool scale_on_logits =
false; // Scale is applied to Q during copy
public:
AttentionImpl() {}
template <template <typename tile_gemm_t> typename attention>
FORCE_INLINE void execute_attention(DEFINE_CPU_ATTENTION_PARAMS) {
attention<TileGemmS390X<kv_cache_t>> attention_iteration;
attention_iteration(CPU_ATTENTION_PARAMS);
}
// Strides for Memory Layout
constexpr static int64_t k_cache_token_group_stride(
const int32_t block_size) {
return BlockSizeAlignment; // [head_dim, block_size] layout
}
constexpr static int64_t v_cache_token_group_stride(
const int32_t block_size) {
return head_dim * BlockSizeAlignment;
}
constexpr static int64_t v_cache_head_group_stride(const int32_t block_size) {
return HeadDimAlignment;
}
static void copy_q_heads_tile(scalar_t* __restrict__ src,
float* __restrict__ q_buffer,
const int32_t q_num,
const int32_t q_heads_per_kv,
const int64_t q_num_stride,
const int64_t q_head_stride, float scale) {
__vector float scale_vec = vec_splats(scale);
constexpr bool is_bf16 = std::is_same<scalar_t, c10::BFloat16>::value;
// Process 8 elements at a time (32 bytes of float output)
for (int32_t i = 0; i < q_num; ++i) {
for (int32_t h = 0; h < q_heads_per_kv; ++h) {
scalar_t* curr_src = src + i * q_num_stride + h * q_head_stride;
float* curr_dst =
q_buffer + i * q_heads_per_kv * head_dim + h * head_dim;
int32_t d = 0;
for (; d <= head_dim - 8; d += 8) {
if constexpr (is_bf16) {
__vector float v0, v1;
// Reuse our Big-Endian-Safe loader
load_row8_B_as_f32<scalar_t>(curr_src + d, v0, v1);
v0 = vec_mul(v0, scale_vec);
v1 = vec_mul(v1, scale_vec);
vec_xst(v0, 0, curr_dst + d);
vec_xst(v1, 0, curr_dst + d + 4);
} else {
__vector float v0 = vec_xl((long long)0, (float*)curr_src + d);
__vector float v1 = vec_xl((long long)0, (float*)curr_src + d + 4);
v0 = vec_mul(v0, scale_vec);
v1 = vec_mul(v1, scale_vec);
vec_xst(v0, 0, curr_dst + d);
vec_xst(v1, 0, curr_dst + d + 4);
}
}
for (; d < head_dim; ++d) {
float val = static_cast<float>(curr_src[d]);
curr_dst[d] = val * scale;
}
}
}
}
static void reshape_and_cache(
const scalar_t* __restrict__ key, const scalar_t* __restrict__ value,
scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache,
const int64_t* __restrict__ slot_mapping, const int64_t token_num,
const int64_t key_token_num_stride, const int64_t value_token_num_stride,
const int64_t head_num, const int64_t key_head_num_stride,
const int64_t value_head_num_stride, const int64_t num_blocks,
const int64_t num_blocks_stride, const int64_t cache_head_num_stride,
const int64_t block_size, const int64_t block_size_stride) {
#pragma omp parallel for collapse(2)
for (int64_t token_idx = 0; token_idx < token_num; ++token_idx) {
for (int64_t head_idx = 0; head_idx < head_num; ++head_idx) {
const int64_t pos = slot_mapping[token_idx];
if (pos < 0) continue;
const int64_t block_idx = pos / block_size;
const int64_t block_offset = pos % block_size;
{
const scalar_t* key_src = key + token_idx * key_token_num_stride +
head_idx * key_head_num_stride;
scalar_t* key_dst = key_cache + block_idx * num_blocks_stride +
head_idx * cache_head_num_stride + block_offset;
for (int64_t i = 0, j = 0; i < head_dim; ++i, j += block_size) {
key_dst[j] = key_src[i];
}
}
{
const scalar_t* val_src = value + token_idx * value_token_num_stride +
head_idx * value_head_num_stride;
scalar_t* val_dst = value_cache + block_idx * num_blocks_stride +
head_idx * cache_head_num_stride +
block_offset * head_dim;
std::memcpy(val_dst, val_src, sizeof(scalar_t) * head_dim);
}
}
}
}
};
} // namespace cpu_attention
#undef BLOCK_SIZE_ALIGNMENT
#undef HEAD_SIZE_ALIGNMENT
#undef MAX_Q_HEAD_NUM_PER_ITER
#endif
\ No newline at end of file
......@@ -147,7 +147,7 @@ void fused_moe_impl(scalar_t* __restrict__ output, scalar_t* __restrict__ input,
const int32_t token_num, const int32_t expert_num,
const int32_t topk_num, const int32_t input_size_13,
const int32_t output_size_13, const int32_t input_size_2,
const int32_t output_size_2) {
const int32_t output_size_2, const bool skip_weighted) {
using scalar_vec_t = typename cpu_utils::VecTypeTrait<scalar_t>::vec_t;
constexpr int32_t gemm_n_tile_size = gemm_t::NSize;
constexpr int32_t gemm_m_tile_size = gemm_t::MaxMSize;
......@@ -582,6 +582,11 @@ void fused_moe_impl(scalar_t* __restrict__ output, scalar_t* __restrict__ input,
scalar_t* __restrict__ curr_output_buffer =
output + token_id * output_size_2;
if (skip_weighted) {
// Only for topk_num == 1
*curr_weight = 1.0f;
}
if (topk_num > 1) {
{
int32_t w2_output_idx = curr_expand_token_id_index_buffer[0];
......@@ -699,7 +704,7 @@ void cpu_fused_moe(
const std::optional<torch::Tensor>& w2_bias, // [expert_num, output_size_2]
const torch::Tensor& topk_weights, // [token_num, k], float32
const torch::Tensor& topk_id, // [token_num, k], int32
const std::string& act, const std::string& isa) {
const bool skip_weighted, const std::string& act, const std::string& isa) {
const int32_t token_num = input.size(0);
const int32_t input_size_13 = input.size(1);
const int64_t input_stride = input.stride(0);
......@@ -711,6 +716,8 @@ void cpu_fused_moe(
const int32_t topk_num = topk_id.size(1);
const FusedMOEAct act_type = get_act_type(act);
cpu_utils::ISA isa_type = cpu_utils::get_isa(isa);
TORCH_CHECK(!skip_weighted || topk_num == 1,
"skip_weighted is only supported for topk=1 on CPU");
VLLM_DISPATCH_FLOATING_TYPES(w13.scalar_type(), "cpu_fused_moe", [&]() {
CPU_ISA_DISPATCH_IMPL(isa_type, [&]() {
......@@ -721,7 +728,7 @@ void cpu_fused_moe(
w2_bias.has_value() ? w2_bias->data_ptr<scalar_t>() : nullptr,
topk_weights.data_ptr<float>(), topk_id.data_ptr<int32_t>(), act_type,
token_num, expert_num, topk_num, input_size_13, output_size_13,
input_size_2, output_size_2);
input_size_2, output_size_2, skip_weighted);
});
});
}
......@@ -13,6 +13,9 @@
#elif defined(__aarch64__)
// arm implementation
#include "cpu_types_arm.hpp"
#elif defined(__riscv_v)
// riscv implementation
#include "cpu_types_riscv.hpp"
#else
#warning "unsupported vLLM cpu implementation, vLLM will compile with scalar"
#include "cpu_types_scalar.hpp"
......
#ifndef CPU_TYPES_RISCV_HPP
#define CPU_TYPES_RISCV_HPP
#include <algorithm>
#include <cmath>
#include <cstring>
#include <iostream>
#include <limits>
#include <riscv_vector.h>
#include <torch/all.h>
// ============================================================================
// Vector Register Type Definitions (VLEN=128 bits)
// ============================================================================
typedef vfloat16m1_t fixed_vfloat16m1_t
__attribute__((riscv_rvv_vector_bits(128)));
typedef vfloat16m2_t fixed_vfloat16m2_t
__attribute__((riscv_rvv_vector_bits(256)));
typedef vfloat32m1_t fixed_vfloat32m1_t
__attribute__((riscv_rvv_vector_bits(128)));
typedef vfloat32m2_t fixed_vfloat32m2_t
__attribute__((riscv_rvv_vector_bits(256)));
typedef vfloat32m4_t fixed_vfloat32m4_t
__attribute__((riscv_rvv_vector_bits(512)));
typedef vfloat32m8_t fixed_vfloat32m8_t
__attribute__((riscv_rvv_vector_bits(1024)));
typedef vint32m2_t fixed_vint32m2_t __attribute__((riscv_rvv_vector_bits(256)));
typedef vint32m4_t fixed_vint32m4_t __attribute__((riscv_rvv_vector_bits(512)));
typedef vuint16m1_t fixed_vuint16m1_t
__attribute__((riscv_rvv_vector_bits(128)));
typedef vuint16m2_t fixed_vuint16m2_t
__attribute__((riscv_rvv_vector_bits(256)));
typedef vuint16m4_t fixed_vuint16m4_t
__attribute__((riscv_rvv_vector_bits(512)));
#ifdef RISCV_BF16_SUPPORT
typedef vbfloat16m1_t fixed_vbfloat16m1_t
__attribute__((riscv_rvv_vector_bits(128)));
typedef vbfloat16m2_t fixed_vbfloat16m2_t
__attribute__((riscv_rvv_vector_bits(256)));
typedef vbfloat16m4_t fixed_vbfloat16m4_t
__attribute__((riscv_rvv_vector_bits(512)));
#endif
namespace vec_op {
#ifdef RISCV_BF16_SUPPORT
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
#else
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
#endif
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
#define FORCE_INLINE __attribute__((always_inline)) inline
namespace {
template <typename T, T... indexes, typename F>
constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F&& f) {
(f(std::integral_constant<T, indexes>{}), ...);
};
} // namespace
template <typename T, T count, typename F,
typename = std::enable_if_t<std::is_invocable_v<F, T>>>
constexpr void unroll_loop(F&& f) {
unroll_loop_item(std::make_integer_sequence<T, count>{}, std::forward<F>(f));
}
template <typename T>
struct Vec {
constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; };
};
struct FP32Vec8;
struct FP32Vec16;
// ============================================================================
// FP16 Implementation
// ============================================================================
struct FP16Vec8 : public Vec<FP16Vec8> {
constexpr static int VEC_ELEM_NUM = 8;
fixed_vfloat16m1_t reg;
explicit FP16Vec8(const void* ptr)
: reg(__riscv_vle16_v_f16m1(static_cast<const _Float16*>(ptr),
VEC_ELEM_NUM)) {};
explicit FP16Vec8(const FP32Vec8&);
void save(void* ptr) const {
__riscv_vse16_v_f16m1(static_cast<_Float16*>(ptr), reg, VEC_ELEM_NUM);
}
void save(void* ptr, int elem_num) const {
__riscv_vse16_v_f16m1(static_cast<_Float16*>(ptr), reg, elem_num);
}
void save_strided(void* ptr, ptrdiff_t stride) const {
ptrdiff_t byte_stride = stride * sizeof(_Float16);
__riscv_vsse16_v_f16m1(static_cast<_Float16*>(ptr), byte_stride, reg,
VEC_ELEM_NUM);
}
};
struct FP16Vec16 : public Vec<FP16Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
fixed_vfloat16m2_t reg;
explicit FP16Vec16(const void* ptr)
: reg(__riscv_vle16_v_f16m2(static_cast<const _Float16*>(ptr),
VEC_ELEM_NUM)) {};
explicit FP16Vec16(const FP32Vec16& vec);
void save(void* ptr) const {
__riscv_vse16_v_f16m2(static_cast<_Float16*>(ptr), reg, VEC_ELEM_NUM);
}
void save(void* ptr, int elem_num) const {
__riscv_vse16_v_f16m2(static_cast<_Float16*>(ptr), reg, elem_num);
}
void save_strided(void* ptr, ptrdiff_t stride) const {
ptrdiff_t byte_stride = stride * sizeof(_Float16);
__riscv_vsse16_v_f16m2(static_cast<_Float16*>(ptr), byte_stride, reg,
VEC_ELEM_NUM);
}
};
// ============================================================================
// BF16 Implementation
// ============================================================================
#ifdef RISCV_BF16_SUPPORT
FORCE_INLINE fixed_vuint16m1_t bf16_to_u16(fixed_vbfloat16m1_t v) {
return __riscv_vreinterpret_v_bf16m1_u16m1(v);
}
FORCE_INLINE fixed_vuint16m2_t bf16_to_u16(fixed_vbfloat16m2_t v) {
return __riscv_vreinterpret_v_bf16m2_u16m2(v);
}
FORCE_INLINE fixed_vuint16m4_t bf16_to_u16(fixed_vbfloat16m4_t v) {
return __riscv_vreinterpret_v_bf16m4_u16m4(v);
}
struct BF16Vec8 : public Vec<BF16Vec8> {
constexpr static int VEC_ELEM_NUM = 8;
fixed_vbfloat16m1_t reg;
explicit BF16Vec8(const void* ptr)
: reg(__riscv_vreinterpret_v_u16m1_bf16m1(__riscv_vle16_v_u16m1(
reinterpret_cast<const uint16_t*>(ptr), VEC_ELEM_NUM))) {};
explicit BF16Vec8(fixed_vbfloat16m1_t data) : reg(data) {};
explicit BF16Vec8(const FP32Vec8&);
void save(void* ptr) const {
__riscv_vse16_v_u16m1(reinterpret_cast<uint16_t*>(ptr), bf16_to_u16(reg),
VEC_ELEM_NUM);
}
void save(void* ptr, int elem_num) const {
__riscv_vse16_v_u16m1(reinterpret_cast<uint16_t*>(ptr), bf16_to_u16(reg),
elem_num);
}
void save_strided(void* ptr, ptrdiff_t stride) const {
ptrdiff_t byte_stride = stride * sizeof(uint16_t);
__riscv_vsse16_v_u16m1(reinterpret_cast<uint16_t*>(ptr), byte_stride,
bf16_to_u16(reg), VEC_ELEM_NUM);
}
};
struct BF16Vec16 : public Vec<BF16Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
fixed_vbfloat16m2_t reg;
explicit BF16Vec16(const void* ptr)
: reg(__riscv_vreinterpret_v_u16m2_bf16m2(__riscv_vle16_v_u16m2(
reinterpret_cast<const uint16_t*>(ptr), VEC_ELEM_NUM))) {};
explicit BF16Vec16(fixed_vbfloat16m2_t data) : reg(data) {};
explicit BF16Vec16(const FP32Vec16&);
void save(void* ptr) const {
__riscv_vse16_v_u16m2(reinterpret_cast<uint16_t*>(ptr), bf16_to_u16(reg),
VEC_ELEM_NUM);
}
void save(void* ptr, int elem_num) const {
__riscv_vse16_v_u16m2(reinterpret_cast<uint16_t*>(ptr), bf16_to_u16(reg),
elem_num);
}
void save_strided(void* ptr, ptrdiff_t stride) const {
ptrdiff_t byte_stride = stride * sizeof(uint16_t);
__riscv_vsse16_v_u16m2(reinterpret_cast<uint16_t*>(ptr), byte_stride,
bf16_to_u16(reg), VEC_ELEM_NUM);
}
};
struct BF16Vec32 : public Vec<BF16Vec32> {
constexpr static int VEC_ELEM_NUM = 32;
fixed_vbfloat16m4_t reg;
explicit BF16Vec32(const void* ptr)
: reg(__riscv_vreinterpret_v_u16m4_bf16m4(__riscv_vle16_v_u16m4(
reinterpret_cast<const uint16_t*>(ptr), VEC_ELEM_NUM))) {};
explicit BF16Vec32(fixed_vbfloat16m4_t data) : reg(data) {};
explicit BF16Vec32(const BF16Vec8& v) {
fixed_vuint16m1_t u16_val = bf16_to_u16(v.reg);
fixed_vuint16m4_t u16_combined =
__riscv_vcreate_v_u16m1_u16m4(u16_val, u16_val, u16_val, u16_val);
reg = __riscv_vreinterpret_v_u16m4_bf16m4(u16_combined);
};
void save(void* ptr) const {
__riscv_vse16_v_u16m4(reinterpret_cast<uint16_t*>(ptr), bf16_to_u16(reg),
VEC_ELEM_NUM);
}
void save(void* ptr, int elem_num) const {
__riscv_vse16_v_u16m4(reinterpret_cast<uint16_t*>(ptr), bf16_to_u16(reg),
elem_num);
}
void save_strided(void* ptr, ptrdiff_t stride) const {
ptrdiff_t byte_stride = stride * sizeof(uint16_t);
__riscv_vsse16_v_u16m4(reinterpret_cast<uint16_t*>(ptr), byte_stride,
bf16_to_u16(reg), VEC_ELEM_NUM);
}
};
#else
// ============================================================================
// BF16 Fallback Implementation (FP32 Simulation)
// ============================================================================
struct BF16Vec8 : public Vec<BF16Vec8> {
constexpr static int VEC_ELEM_NUM = 8;
fixed_vfloat32m2_t reg_fp32;
explicit BF16Vec8(const void* ptr) {
const uint16_t* u16 = static_cast<const uint16_t*>(ptr);
float tmp[8];
for (int i = 0; i < 8; ++i) {
uint32_t v = static_cast<uint32_t>(u16[i]) << 16;
std::memcpy(&tmp[i], &v, 4);
}
reg_fp32 = __riscv_vle32_v_f32m2(tmp, 8);
}
explicit BF16Vec8(const FP32Vec8&);
void save(void* ptr) const {
float tmp[8];
__riscv_vse32_v_f32m2(tmp, reg_fp32, 8);
uint16_t* u16 = static_cast<uint16_t*>(ptr);
for (int i = 0; i < 8; ++i) {
uint32_t v;
std::memcpy(&v, &tmp[i], 4);
u16[i] = static_cast<uint16_t>(v >> 16);
}
}
void save(void* ptr, int elem_num) const {
float tmp[8];
__riscv_vse32_v_f32m2(tmp, reg_fp32, 8);
uint16_t* u16 = static_cast<uint16_t*>(ptr);
for (int i = 0; i < elem_num; ++i) {
uint32_t v;
std::memcpy(&v, &tmp[i], 4);
u16[i] = static_cast<uint16_t>(v >> 16);
}
}
void save_strided(void* ptr, ptrdiff_t stride) const {
float tmp[8];
__riscv_vse32_v_f32m2(tmp, reg_fp32, 8);
uint8_t* u8 = static_cast<uint8_t*>(ptr);
ptrdiff_t byte_stride = stride * sizeof(uint16_t);
for (int i = 0; i < 8; ++i) {
uint32_t v;
std::memcpy(&v, &tmp[i], 4);
uint16_t val = static_cast<uint16_t>(v >> 16);
*reinterpret_cast<uint16_t*>(u8 + i * byte_stride) = val;
}
}
};
struct BF16Vec16 : public Vec<BF16Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
fixed_vfloat32m4_t reg_fp32;
explicit BF16Vec16(const void* ptr) {
const uint16_t* u16 = static_cast<const uint16_t*>(ptr);
float tmp[16];
for (int i = 0; i < 16; ++i) {
uint32_t v = static_cast<uint32_t>(u16[i]) << 16;
std::memcpy(&tmp[i], &v, 4);
}
reg_fp32 = __riscv_vle32_v_f32m4(tmp, 16);
}
explicit BF16Vec16(const FP32Vec16&);
void save(void* ptr) const {
float tmp[16];
__riscv_vse32_v_f32m4(tmp, reg_fp32, 16);
uint16_t* u16 = static_cast<uint16_t*>(ptr);
for (int i = 0; i < 16; ++i) {
uint32_t v;
std::memcpy(&v, &tmp[i], 4);
u16[i] = static_cast<uint16_t>(v >> 16);
}
}
void save(void* ptr, int elem_num) const {
float tmp[16];
__riscv_vse32_v_f32m4(tmp, reg_fp32, 16);
uint16_t* u16 = static_cast<uint16_t*>(ptr);
for (int i = 0; i < elem_num; ++i) {
uint32_t v;
std::memcpy(&v, &tmp[i], 4);
u16[i] = static_cast<uint16_t>(v >> 16);
}
}
void save_strided(void* ptr, ptrdiff_t stride) const {
float tmp[16];
__riscv_vse32_v_f32m4(tmp, reg_fp32, 16);
uint8_t* u8 = static_cast<uint8_t*>(ptr);
ptrdiff_t byte_stride = stride * sizeof(uint16_t);
for (int i = 0; i < 16; ++i) {
uint32_t v;
std::memcpy(&v, &tmp[i], 4);
uint16_t val = static_cast<uint16_t>(v >> 16);
*reinterpret_cast<uint16_t*>(u8 + i * byte_stride) = val;
}
}
};
struct BF16Vec32 : public Vec<BF16Vec32> {
constexpr static int VEC_ELEM_NUM = 32;
fixed_vfloat32m8_t reg_fp32;
explicit BF16Vec32(const void* ptr) {
const uint16_t* u16 = static_cast<const uint16_t*>(ptr);
float tmp[32];
for (int i = 0; i < 32; ++i) {
uint32_t v = static_cast<uint32_t>(u16[i]) << 16;
std::memcpy(&tmp[i], &v, 4);
}
reg_fp32 = __riscv_vle32_v_f32m8(tmp, 32);
}
explicit BF16Vec32(const BF16Vec8& v) {
float tmp_small[8];
__riscv_vse32_v_f32m2(tmp_small, v.reg_fp32, 8);
float tmp_large[32];
for (int i = 0; i < 4; ++i) {
std::memcpy(tmp_large + (i * 8), tmp_small, 8 * sizeof(float));
}
reg_fp32 = __riscv_vle32_v_f32m8(tmp_large, 32);
}
void save(void* ptr) const {
float tmp[32];
__riscv_vse32_v_f32m8(tmp, reg_fp32, 32);
uint16_t* u16 = static_cast<uint16_t*>(ptr);
for (int i = 0; i < 32; ++i) {
uint32_t v;
std::memcpy(&v, &tmp[i], 4);
u16[i] = static_cast<uint16_t>(v >> 16);
}
}
void save(void* ptr, int elem_num) const {
float tmp[32];
__riscv_vse32_v_f32m8(tmp, reg_fp32, 32);
uint16_t* u16 = static_cast<uint16_t*>(ptr);
for (int i = 0; i < elem_num; ++i) {
uint32_t v;
std::memcpy(&v, &tmp[i], 4);
u16[i] = static_cast<uint16_t>(v >> 16);
}
}
void save_strided(void* ptr, ptrdiff_t stride) const {
float tmp[32];
__riscv_vse32_v_f32m8(tmp, reg_fp32, 32);
uint8_t* u8 = static_cast<uint8_t*>(ptr);
ptrdiff_t byte_stride = stride * sizeof(uint16_t);
for (int i = 0; i < 32; ++i) {
uint32_t v;
std::memcpy(&v, &tmp[i], 4);
uint16_t val = static_cast<uint16_t>(v >> 16);
*reinterpret_cast<uint16_t*>(u8 + i * byte_stride) = val;
}
}
};
#endif
// ============================================================================
// FP32 Implementation
// ============================================================================
struct FP32Vec4 : public Vec<FP32Vec4> {
constexpr static int VEC_ELEM_NUM = 4;
fixed_vfloat32m1_t reg;
explicit FP32Vec4(float v) : reg(__riscv_vfmv_v_f_f32m1(v, VEC_ELEM_NUM)) {};
explicit FP32Vec4() : reg(__riscv_vfmv_v_f_f32m1(0.0f, VEC_ELEM_NUM)) {};
explicit FP32Vec4(const float* ptr)
: reg(__riscv_vle32_v_f32m1(ptr, VEC_ELEM_NUM)) {};
explicit FP32Vec4(fixed_vfloat32m1_t data) : reg(data) {};
explicit FP32Vec4(const FP32Vec4& data) : reg(data.reg) {};
void save(float* ptr) const { __riscv_vse32_v_f32m1(ptr, reg, VEC_ELEM_NUM); }
void save(float* ptr, int elem_num) const {
__riscv_vse32_v_f32m1(ptr, reg, elem_num);
}
};
struct FP32Vec8 : public Vec<FP32Vec8> {
constexpr static int VEC_ELEM_NUM = 8;
fixed_vfloat32m2_t reg;
explicit FP32Vec8(float v) : reg(__riscv_vfmv_v_f_f32m2(v, VEC_ELEM_NUM)) {};
explicit FP32Vec8() : reg(__riscv_vfmv_v_f_f32m2(0.0f, VEC_ELEM_NUM)) {};
explicit FP32Vec8(const float* ptr)
: reg(__riscv_vle32_v_f32m2(ptr, VEC_ELEM_NUM)) {};
explicit FP32Vec8(fixed_vfloat32m2_t data) : reg(data) {};
explicit FP32Vec8(const FP32Vec8& data) : reg(data.reg) {};
explicit FP32Vec8(const FP16Vec8& v)
: reg(__riscv_vfwcvt_f_f_v_f32m2(v.reg, VEC_ELEM_NUM)) {};
explicit FP32Vec8(fixed_vfloat16m1_t v)
: reg(__riscv_vfwcvt_f_f_v_f32m2(v, VEC_ELEM_NUM)) {};
#ifdef RISCV_BF16_SUPPORT
explicit FP32Vec8(fixed_vbfloat16m1_t v)
: reg(__riscv_vfwcvtbf16_f_f_v_f32m2(v, VEC_ELEM_NUM)) {};
explicit FP32Vec8(const BF16Vec8& v)
: reg(__riscv_vfwcvtbf16_f_f_v_f32m2(v.reg, VEC_ELEM_NUM)) {};
#else
explicit FP32Vec8(const BF16Vec8& v) : reg(v.reg_fp32) {};
#endif
float reduce_sum() const {
fixed_vfloat32m1_t scalar = __riscv_vfmv_s_f_f32m1(0.0f, 1);
scalar = __riscv_vfredusum_vs_f32m2_f32m1(reg, scalar, VEC_ELEM_NUM);
return __riscv_vfmv_f_s_f32m1_f32(scalar);
}
FP32Vec8 operator*(const FP32Vec8& b) const {
return FP32Vec8(__riscv_vfmul_vv_f32m2(reg, b.reg, VEC_ELEM_NUM));
}
FP32Vec8 operator+(const FP32Vec8& b) const {
return FP32Vec8(__riscv_vfadd_vv_f32m2(reg, b.reg, VEC_ELEM_NUM));
}
FP32Vec8 operator-(const FP32Vec8& b) const {
return FP32Vec8(__riscv_vfsub_vv_f32m2(reg, b.reg, VEC_ELEM_NUM));
}
FP32Vec8 operator/(const FP32Vec8& b) const {
return FP32Vec8(__riscv_vfdiv_vv_f32m2(reg, b.reg, VEC_ELEM_NUM));
}
FP32Vec8 min(const FP32Vec8& b) const {
return FP32Vec8(__riscv_vfmin_vv_f32m2(reg, b.reg, VEC_ELEM_NUM));
}
FP32Vec8 max(const FP32Vec8& b) const {
return FP32Vec8(__riscv_vfmax_vv_f32m2(reg, b.reg, VEC_ELEM_NUM));
}
FP32Vec8 abs() const {
return FP32Vec8(__riscv_vfabs_v_f32m2(reg, VEC_ELEM_NUM));
}
FP32Vec8 min(const FP32Vec8& b, int elem_num) const {
return FP32Vec8(__riscv_vfmin_vv_f32m2(reg, b.reg, elem_num));
}
FP32Vec8 max(const FP32Vec8& b, int elem_num) const {
return FP32Vec8(__riscv_vfmax_vv_f32m2(reg, b.reg, elem_num));
}
FP32Vec8 clamp(const FP32Vec8& min_v, const FP32Vec8& max_v) const {
fixed_vfloat32m2_t temp =
__riscv_vfmax_vv_f32m2(min_v.reg, reg, VEC_ELEM_NUM);
return FP32Vec8(__riscv_vfmin_vv_f32m2(max_v.reg, temp, VEC_ELEM_NUM));
}
void save(float* ptr) const { __riscv_vse32_v_f32m2(ptr, reg, VEC_ELEM_NUM); }
void save(float* ptr, int elem_num) const {
__riscv_vse32_v_f32m2(ptr, reg, elem_num);
}
void save_strided(float* ptr, ptrdiff_t stride) const {
ptrdiff_t byte_stride = stride * sizeof(float);
__riscv_vsse32_v_f32m2(ptr, byte_stride, reg, VEC_ELEM_NUM);
}
FP32Vec8 exp() const {
const float inv_ln2 = 1.44269504088896341f;
fixed_vfloat32m2_t x_scaled =
__riscv_vfmul_vf_f32m2(reg, inv_ln2, VEC_ELEM_NUM);
fixed_vint32m2_t n_int = __riscv_vfcvt_x_f_v_i32m2(x_scaled, VEC_ELEM_NUM);
fixed_vfloat32m2_t n_float = __riscv_vfcvt_f_x_v_f32m2(n_int, VEC_ELEM_NUM);
fixed_vfloat32m2_t r =
__riscv_vfsub_vv_f32m2(x_scaled, n_float, VEC_ELEM_NUM);
fixed_vfloat32m2_t poly =
__riscv_vfmv_v_f_f32m2(0.001333355810164f, VEC_ELEM_NUM);
poly = __riscv_vfmul_vv_f32m2(poly, r, VEC_ELEM_NUM);
poly = __riscv_vfadd_vf_f32m2(poly, 0.009618129107628f, VEC_ELEM_NUM);
poly = __riscv_vfmul_vv_f32m2(poly, r, VEC_ELEM_NUM);
poly = __riscv_vfadd_vf_f32m2(poly, 0.055504108664821f, VEC_ELEM_NUM);
poly = __riscv_vfmul_vv_f32m2(poly, r, VEC_ELEM_NUM);
poly = __riscv_vfadd_vf_f32m2(poly, 0.240226506959101f, VEC_ELEM_NUM);
poly = __riscv_vfmul_vv_f32m2(poly, r, VEC_ELEM_NUM);
poly = __riscv_vfadd_vf_f32m2(poly, 0.693147180559945f, VEC_ELEM_NUM);
poly = __riscv_vfmul_vv_f32m2(poly, r, VEC_ELEM_NUM);
poly = __riscv_vfadd_vf_f32m2(poly, 1.0f, VEC_ELEM_NUM);
fixed_vint32m2_t biased_exp =
__riscv_vadd_vx_i32m2(n_int, 127, VEC_ELEM_NUM);
biased_exp = __riscv_vmax_vx_i32m2(biased_exp, 0, VEC_ELEM_NUM);
fixed_vint32m2_t exponent_bits =
__riscv_vsll_vx_i32m2(biased_exp, 23, VEC_ELEM_NUM);
fixed_vfloat32m2_t scale =
__riscv_vreinterpret_v_i32m2_f32m2(exponent_bits);
return FP32Vec8(__riscv_vfmul_vv_f32m2(poly, scale, VEC_ELEM_NUM));
}
FP32Vec8 tanh() const {
fixed_vfloat32m2_t x_clamped = __riscv_vfmin_vf_f32m2(
__riscv_vfmax_vf_f32m2(reg, -9.0f, VEC_ELEM_NUM), 9.0f, VEC_ELEM_NUM);
fixed_vfloat32m2_t x2 =
__riscv_vfmul_vf_f32m2(x_clamped, 2.0f, VEC_ELEM_NUM);
FP32Vec8 exp_val = FP32Vec8(x2).exp();
fixed_vfloat32m2_t num =
__riscv_vfsub_vf_f32m2(exp_val.reg, 1.0f, VEC_ELEM_NUM);
fixed_vfloat32m2_t den =
__riscv_vfadd_vf_f32m2(exp_val.reg, 1.0f, VEC_ELEM_NUM);
return FP32Vec8(__riscv_vfdiv_vv_f32m2(num, den, VEC_ELEM_NUM));
}
FP32Vec8 er() const {
const float p = 0.3275911f, a1 = 0.254829592f, a2 = -0.284496736f,
a3 = 1.421413741f, a4 = -1.453152027f, a5 = 1.061405429f;
fixed_vfloat32m2_t abs_x = __riscv_vfabs_v_f32m2(reg, VEC_ELEM_NUM);
fixed_vfloat32m2_t t = __riscv_vfadd_vf_f32m2(
__riscv_vfmul_vf_f32m2(abs_x, p, VEC_ELEM_NUM), 1.0f, VEC_ELEM_NUM);
t = __riscv_vfrdiv_vf_f32m2(t, 1.0f, VEC_ELEM_NUM);
fixed_vfloat32m2_t poly = __riscv_vfmv_v_f_f32m2(a5, VEC_ELEM_NUM);
poly = __riscv_vfadd_vf_f32m2(__riscv_vfmul_vv_f32m2(poly, t, VEC_ELEM_NUM),
a4, VEC_ELEM_NUM);
poly = __riscv_vfadd_vf_f32m2(__riscv_vfmul_vv_f32m2(poly, t, VEC_ELEM_NUM),
a3, VEC_ELEM_NUM);
poly = __riscv_vfadd_vf_f32m2(__riscv_vfmul_vv_f32m2(poly, t, VEC_ELEM_NUM),
a2, VEC_ELEM_NUM);
poly = __riscv_vfadd_vf_f32m2(__riscv_vfmul_vv_f32m2(poly, t, VEC_ELEM_NUM),
a1, VEC_ELEM_NUM);
poly = __riscv_vfmul_vv_f32m2(poly, t, VEC_ELEM_NUM);
fixed_vfloat32m2_t exp_val =
FP32Vec8(__riscv_vfneg_v_f32m2(
__riscv_vfmul_vv_f32m2(abs_x, abs_x, VEC_ELEM_NUM),
VEC_ELEM_NUM))
.exp()
.reg;
fixed_vfloat32m2_t res = __riscv_vfrsub_vf_f32m2(
__riscv_vfmul_vv_f32m2(poly, exp_val, VEC_ELEM_NUM), 1.0f,
VEC_ELEM_NUM);
vbool16_t mask = __riscv_vmflt_vf_f32m2_b16(reg, 0.0f, VEC_ELEM_NUM);
return FP32Vec8(__riscv_vfneg_v_f32m2_m(mask, res, VEC_ELEM_NUM));
}
};
struct FP32Vec16 : public Vec<FP32Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
fixed_vfloat32m4_t reg;
explicit FP32Vec16(float v) : reg(__riscv_vfmv_v_f_f32m4(v, VEC_ELEM_NUM)) {};
explicit FP32Vec16() : reg(__riscv_vfmv_v_f_f32m4(0.0f, VEC_ELEM_NUM)) {};
explicit FP32Vec16(const float* ptr)
: reg(__riscv_vle32_v_f32m4(ptr, VEC_ELEM_NUM)) {};
explicit FP32Vec16(fixed_vfloat32m4_t data) : reg(data) {};
explicit FP32Vec16(const FP32Vec8& data)
: reg(__riscv_vcreate_v_f32m2_f32m4(data.reg, data.reg)) {};
explicit FP32Vec16(const FP32Vec16& data) : reg(data.reg) {};
explicit FP32Vec16(const FP16Vec16& v);
#ifdef RISCV_BF16_SUPPORT
explicit FP32Vec16(fixed_vbfloat16m2_t v)
: reg(__riscv_vfwcvtbf16_f_f_v_f32m4(v, VEC_ELEM_NUM)) {};
explicit FP32Vec16(const BF16Vec16& v)
: reg(__riscv_vfwcvtbf16_f_f_v_f32m4(v.reg, VEC_ELEM_NUM)) {};
#else
explicit FP32Vec16(const BF16Vec16& v) : reg(v.reg_fp32) {};
#endif
FP32Vec16 operator+(const FP32Vec16& b) const {
return FP32Vec16(__riscv_vfadd_vv_f32m4(reg, b.reg, VEC_ELEM_NUM));
}
FP32Vec16 operator-(const FP32Vec16& b) const {
return FP32Vec16(__riscv_vfsub_vv_f32m4(reg, b.reg, VEC_ELEM_NUM));
}
FP32Vec16 operator*(const FP32Vec16& b) const {
return FP32Vec16(__riscv_vfmul_vv_f32m4(reg, b.reg, VEC_ELEM_NUM));
}
FP32Vec16 operator/(const FP32Vec16& b) const {
return FP32Vec16(__riscv_vfdiv_vv_f32m4(reg, b.reg, VEC_ELEM_NUM));
}
FP32Vec16 fma(const FP32Vec16& a, const FP32Vec16& b) const {
return FP32Vec16(__riscv_vfmacc_vv_f32m4(reg, a.reg, b.reg, VEC_ELEM_NUM));
}
float reduce_sum() const {
fixed_vfloat32m1_t scalar = __riscv_vfmv_s_f_f32m1(0.0f, 1);
scalar = __riscv_vfredusum_vs_f32m4_f32m1(reg, scalar, VEC_ELEM_NUM);
return __riscv_vfmv_f_s_f32m1_f32(scalar);
}
float reduce_max() const {
fixed_vfloat32m1_t scalar =
__riscv_vfmv_s_f_f32m1(std::numeric_limits<float>::lowest(), 1);
scalar = __riscv_vfredmax_vs_f32m4_f32m1(reg, scalar, VEC_ELEM_NUM);
return __riscv_vfmv_f_s_f32m1_f32(scalar);
}
float reduce_min() const {
fixed_vfloat32m1_t scalar =
__riscv_vfmv_s_f_f32m1(std::numeric_limits<float>::max(), 1);
scalar = __riscv_vfredmin_vs_f32m4_f32m1(reg, scalar, VEC_ELEM_NUM);
return __riscv_vfmv_f_s_f32m1_f32(scalar);
}
template <int group_size>
float reduce_sub_sum(int idx) {
static_assert(VEC_ELEM_NUM % group_size == 0);
const int start = idx * group_size;
vuint32m4_t indices = __riscv_vid_v_u32m4(VEC_ELEM_NUM);
vbool8_t mask = __riscv_vmand_mm_b8(
__riscv_vmsgeu_vx_u32m4_b8(indices, start, VEC_ELEM_NUM),
__riscv_vmsltu_vx_u32m4_b8(indices, start + group_size, VEC_ELEM_NUM),
VEC_ELEM_NUM);
fixed_vfloat32m1_t scalar = __riscv_vfmv_s_f_f32m1(0.0f, 1);
scalar =
__riscv_vfredusum_vs_f32m4_f32m1_m(mask, reg, scalar, VEC_ELEM_NUM);
return __riscv_vfmv_f_s_f32m1_f32(scalar);
};
FP32Vec16 max(const FP32Vec16& b) const {
return FP32Vec16(__riscv_vfmax_vv_f32m4(reg, b.reg, VEC_ELEM_NUM));
}
FP32Vec16 min(const FP32Vec16& b) const {
return FP32Vec16(__riscv_vfmin_vv_f32m4(reg, b.reg, VEC_ELEM_NUM));
}
FP32Vec16 abs() const {
return FP32Vec16(__riscv_vfabs_v_f32m4(reg, VEC_ELEM_NUM));
}
FP32Vec16 clamp(const FP32Vec16& min_v, const FP32Vec16& max_v) const {
return FP32Vec16(__riscv_vfmin_vv_f32m4(
max_v.reg, __riscv_vfmax_vv_f32m4(min_v.reg, reg, VEC_ELEM_NUM),
VEC_ELEM_NUM));
}
void save(float* ptr) const { __riscv_vse32_v_f32m4(ptr, reg, VEC_ELEM_NUM); }
void save(float* ptr, int elem_num) const {
__riscv_vse32_v_f32m4(ptr, reg, elem_num);
}
void save_strided(float* ptr, ptrdiff_t stride) const {
ptrdiff_t byte_stride = stride * sizeof(float);
__riscv_vsse32_v_f32m4(ptr, byte_stride, reg, VEC_ELEM_NUM);
}
FP32Vec16 exp() const {
const float inv_ln2 = 1.44269504088896341f;
fixed_vfloat32m4_t x_scaled =
__riscv_vfmul_vf_f32m4(reg, inv_ln2, VEC_ELEM_NUM);
fixed_vint32m4_t n_int = __riscv_vfcvt_x_f_v_i32m4(x_scaled, VEC_ELEM_NUM);
fixed_vfloat32m4_t n_float = __riscv_vfcvt_f_x_v_f32m4(n_int, VEC_ELEM_NUM);
fixed_vfloat32m4_t r =
__riscv_vfsub_vv_f32m4(x_scaled, n_float, VEC_ELEM_NUM);
fixed_vfloat32m4_t poly =
__riscv_vfmv_v_f_f32m4(0.001333355810164f, VEC_ELEM_NUM);
poly = __riscv_vfadd_vf_f32m4(__riscv_vfmul_vv_f32m4(poly, r, VEC_ELEM_NUM),
0.009618129107628f, VEC_ELEM_NUM);
poly = __riscv_vfadd_vf_f32m4(__riscv_vfmul_vv_f32m4(poly, r, VEC_ELEM_NUM),
0.055504108664821f, VEC_ELEM_NUM);
poly = __riscv_vfadd_vf_f32m4(__riscv_vfmul_vv_f32m4(poly, r, VEC_ELEM_NUM),
0.240226506959101f, VEC_ELEM_NUM);
poly = __riscv_vfadd_vf_f32m4(__riscv_vfmul_vv_f32m4(poly, r, VEC_ELEM_NUM),
0.693147180559945f, VEC_ELEM_NUM);
poly = __riscv_vfadd_vf_f32m4(__riscv_vfmul_vv_f32m4(poly, r, VEC_ELEM_NUM),
1.0f, VEC_ELEM_NUM);
fixed_vint32m4_t biased_exp = __riscv_vmax_vx_i32m4(
__riscv_vadd_vx_i32m4(n_int, 127, VEC_ELEM_NUM), 0, VEC_ELEM_NUM);
fixed_vfloat32m4_t scale = __riscv_vreinterpret_v_i32m4_f32m4(
__riscv_vsll_vx_i32m4(biased_exp, 23, VEC_ELEM_NUM));
return FP32Vec16(__riscv_vfmul_vv_f32m4(poly, scale, VEC_ELEM_NUM));
}
FP32Vec16 tanh() const {
fixed_vfloat32m4_t x_clamped = __riscv_vfmin_vf_f32m4(
__riscv_vfmax_vf_f32m4(reg, -9.0f, VEC_ELEM_NUM), 9.0f, VEC_ELEM_NUM);
FP32Vec16 exp_val =
FP32Vec16(__riscv_vfmul_vf_f32m4(x_clamped, 2.0f, VEC_ELEM_NUM)).exp();
return FP32Vec16(__riscv_vfdiv_vv_f32m4(
__riscv_vfsub_vf_f32m4(exp_val.reg, 1.0f, VEC_ELEM_NUM),
__riscv_vfadd_vf_f32m4(exp_val.reg, 1.0f, VEC_ELEM_NUM), VEC_ELEM_NUM));
}
FP32Vec16 er() const {
const float p = 0.3275911f, a1 = 0.254829592f, a2 = -0.284496736f,
a3 = 1.421413741f, a4 = -1.453152027f, a5 = 1.061405429f;
fixed_vfloat32m4_t abs_x = __riscv_vfabs_v_f32m4(reg, VEC_ELEM_NUM);
fixed_vfloat32m4_t t = __riscv_vfrdiv_vf_f32m4(
__riscv_vfadd_vf_f32m4(__riscv_vfmul_vf_f32m4(abs_x, p, VEC_ELEM_NUM),
1.0f, VEC_ELEM_NUM),
1.0f, VEC_ELEM_NUM);
fixed_vfloat32m4_t poly = __riscv_vfmv_v_f_f32m4(a5, VEC_ELEM_NUM);
poly = __riscv_vfadd_vf_f32m4(__riscv_vfmul_vv_f32m4(poly, t, VEC_ELEM_NUM),
a4, VEC_ELEM_NUM);
poly = __riscv_vfadd_vf_f32m4(__riscv_vfmul_vv_f32m4(poly, t, VEC_ELEM_NUM),
a3, VEC_ELEM_NUM);
poly = __riscv_vfadd_vf_f32m4(__riscv_vfmul_vv_f32m4(poly, t, VEC_ELEM_NUM),
a2, VEC_ELEM_NUM);
poly = __riscv_vfadd_vf_f32m4(__riscv_vfmul_vv_f32m4(poly, t, VEC_ELEM_NUM),
a1, VEC_ELEM_NUM);
poly = __riscv_vfmul_vv_f32m4(poly, t, VEC_ELEM_NUM);
fixed_vfloat32m4_t exp_val =
FP32Vec16(__riscv_vfneg_v_f32m4(
__riscv_vfmul_vv_f32m4(abs_x, abs_x, VEC_ELEM_NUM),
VEC_ELEM_NUM))
.exp()
.reg;
fixed_vfloat32m4_t res = __riscv_vfrsub_vf_f32m4(
__riscv_vfmul_vv_f32m4(poly, exp_val, VEC_ELEM_NUM), 1.0f,
VEC_ELEM_NUM);
vbool8_t mask = __riscv_vmflt_vf_f32m4_b8(reg, 0.0f, VEC_ELEM_NUM);
return FP32Vec16(__riscv_vfneg_v_f32m4_m(mask, res, VEC_ELEM_NUM));
}
};
// ============================================================================
// Type Traits & Global Helpers
// ============================================================================
template <typename T>
struct VecType {
using vec_type = void;
using vec_t = void;
};
template <typename T>
using vec_t = typename VecType<T>::vec_type;
template <>
struct VecType<float> {
using vec_type = FP32Vec8;
using vec_t = FP32Vec8;
};
template <>
struct VecType<c10::Half> {
using vec_type = FP16Vec8;
using vec_t = FP16Vec8;
};
template <>
struct VecType<c10::BFloat16> {
using vec_type = BF16Vec8;
using vec_t = BF16Vec8;
};
template <typename T>
void storeFP32(float v, T* ptr) {
*ptr = v;
}
template <>
inline void storeFP32<c10::Half>(float v, c10::Half* ptr) {
*reinterpret_cast<_Float16*>(ptr) = static_cast<_Float16>(v);
}
inline FP16Vec16::FP16Vec16(const FP32Vec16& v) {
reg = __riscv_vfncvt_f_f_w_f16m2(v.reg, VEC_ELEM_NUM);
}
inline FP16Vec8::FP16Vec8(const FP32Vec8& v) {
reg = __riscv_vfncvt_f_f_w_f16m1(v.reg, VEC_ELEM_NUM);
}
inline FP32Vec16::FP32Vec16(const FP16Vec16& v) {
reg = __riscv_vfwcvt_f_f_v_f32m4(v.reg, VEC_ELEM_NUM);
}
inline void fma(FP32Vec16& acc, const FP32Vec16& a, const FP32Vec16& b) {
acc = acc.fma(a, b);
}
#ifdef RISCV_BF16_SUPPORT
template <>
inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16* ptr) {
*ptr = static_cast<__bf16>(v);
};
inline BF16Vec8::BF16Vec8(const FP32Vec8& v)
: reg(__riscv_vfncvtbf16_f_f_w_bf16m1(v.reg, VEC_ELEM_NUM)) {};
inline BF16Vec16::BF16Vec16(const FP32Vec16& v)
: reg(__riscv_vfncvtbf16_f_f_w_bf16m2(v.reg, VEC_ELEM_NUM)) {};
#else
template <>
inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16* ptr) {
uint32_t val;
std::memcpy(&val, &v, 4);
*reinterpret_cast<uint16_t*>(ptr) = static_cast<uint16_t>(val >> 16);
}
inline BF16Vec8::BF16Vec8(const FP32Vec8& v) : reg_fp32(v.reg) {}
inline BF16Vec16::BF16Vec16(const FP32Vec16& v) : reg_fp32(v.reg) {}
#endif
inline void prefetch(const void* addr) { __builtin_prefetch(addr, 0, 1); }
} // namespace vec_op
#ifndef CPU_KERNEL_GUARD_IN
#define CPU_KERNEL_GUARD_IN(NAME)
#endif
#ifndef CPU_KERNEL_GUARD_OUT
#define CPU_KERNEL_GUARD_OUT(NAME)
#endif
#endif // CPU_TYPES_RISCV_HPP
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment