"deploy/vscode:/vscode.git/clone" did not exist on "6d4d0a61d00f256652a0af4e9790d7dd88db3e03"
Commit 3fb4b5fa authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.18.0' into v0.18.0-ori

parents bcf25339 89138b21
...@@ -126,7 +126,7 @@ def benchmark_decode( ...@@ -126,7 +126,7 @@ def benchmark_decode(
) )
def time_fn(fn, warmup=10, trials=20): def time_fn(fn, warmup=10, trials=20):
torch.cuda.synchronize() torch.accelerator.synchronize()
start = torch.Event(enable_timing=True) start = torch.Event(enable_timing=True)
end = torch.Event(enable_timing=True) end = torch.Event(enable_timing=True)
times = [] times = []
...@@ -136,7 +136,7 @@ def benchmark_decode( ...@@ -136,7 +136,7 @@ def benchmark_decode(
start.record() start.record()
fn() fn()
end.record() end.record()
torch.cuda.synchronize() torch.accelerator.synchronize()
times.append(start.elapsed_time(end)) # ms times.append(start.elapsed_time(end)) # ms
return sum(times) / len(times), torch.std(torch.tensor(times)) return sum(times) / len(times), torch.std(torch.tensor(times))
......
...@@ -138,7 +138,7 @@ def benchmark_prefill( ...@@ -138,7 +138,7 @@ def benchmark_prefill(
) )
def time_fn(fn, warmup=10, trials=20): def time_fn(fn, warmup=10, trials=20):
torch.cuda.synchronize() torch.accelerator.synchronize()
start = torch.Event(enable_timing=True) start = torch.Event(enable_timing=True)
end = torch.Event(enable_timing=True) end = torch.Event(enable_timing=True)
times = [] times = []
...@@ -148,7 +148,7 @@ def benchmark_prefill( ...@@ -148,7 +148,7 @@ def benchmark_prefill(
start.record() start.record()
fn() fn()
end.record() end.record()
torch.cuda.synchronize() torch.accelerator.synchronize()
times.append(start.elapsed_time(end)) # ms times.append(start.elapsed_time(end)) # ms
return sum(times) / len(times), torch.std(torch.tensor(times)) return sum(times) / len(times), torch.std(torch.tensor(times))
......
...@@ -177,18 +177,18 @@ def benchmark_config( ...@@ -177,18 +177,18 @@ def benchmark_config(
def run(): def run():
w8a8_block_matmul(A, B, As, Bs, block_size, config, out_dtype) w8a8_block_matmul(A, B, As, Bs, block_size, config, out_dtype)
torch.cuda.synchronize() torch.accelerator.synchronize()
# JIT complication & warmup # JIT complication & warmup
for _ in range(5): for _ in range(5):
run() run()
torch.cuda.synchronize() torch.accelerator.synchronize()
start_event = torch.Event(enable_timing=True) start_event = torch.Event(enable_timing=True)
end_event = torch.Event(enable_timing=True) end_event = torch.Event(enable_timing=True)
latencies: list[float] = [] latencies: list[float] = []
for i in range(num_iters): for i in range(num_iters):
torch.cuda.synchronize() torch.accelerator.synchronize()
start_event.record() start_event.record()
run() run()
end_event.record() end_event.record()
...@@ -285,7 +285,7 @@ def tune_on_gpu(args_dict): ...@@ -285,7 +285,7 @@ def tune_on_gpu(args_dict):
weight_shapes = args_dict["weight_shapes"] weight_shapes = args_dict["weight_shapes"]
args = args_dict["args"] args = args_dict["args"]
torch.cuda.set_device(gpu_id) torch.accelerator.set_device_index(gpu_id)
print(f"Starting tuning on GPU {gpu_id} with batch sizes {batch_sizes}") print(f"Starting tuning on GPU {gpu_id} with batch sizes {batch_sizes}")
block_n = args.block_n block_n = args.block_n
...@@ -334,7 +334,7 @@ def distribute_batch_sizes(batch_sizes, num_gpus): ...@@ -334,7 +334,7 @@ def distribute_batch_sizes(batch_sizes, num_gpus):
def main(args): def main(args):
print(args) print(args)
num_gpus = torch.cuda.device_count() num_gpus = torch.accelerator.device_count()
if num_gpus == 0: if num_gpus == 0:
raise RuntimeError("No GPU available for tuning") raise RuntimeError("No GPU available for tuning")
print(f"Found {num_gpus} GPUs for parallel tuning") print(f"Found {num_gpus} GPUs for parallel tuning")
......
...@@ -35,7 +35,7 @@ def benchmark_shape( ...@@ -35,7 +35,7 @@ def benchmark_shape(
B = torch.randn((n, k), device="cuda", dtype=torch.bfloat16) B = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)
# Reference result in BF16 # Reference result in BF16
torch.cuda.synchronize() torch.accelerator.synchronize()
C_ref = A @ B.t() C_ref = A @ B.t()
# Pre-quantize B for all implementations # Pre-quantize B for all implementations
...@@ -121,14 +121,14 @@ def benchmark_shape( ...@@ -121,14 +121,14 @@ def benchmark_shape(
# Warmup # Warmup
for _ in range(warmup): for _ in range(warmup):
func() func()
torch.cuda.synchronize() torch.accelerator.synchronize()
# Timing loop # Timing loop
torch.cuda.synchronize() torch.accelerator.synchronize()
start = time.time() start = time.time()
for _ in range(repeat): for _ in range(repeat):
func() func()
torch.cuda.synchronize() torch.accelerator.synchronize()
end = time.time() end = time.time()
# Calculate timing and TFLOPS # Calculate timing and TFLOPS
......
...@@ -7,7 +7,7 @@ First start serving your model ...@@ -7,7 +7,7 @@ First start serving your model
```bash ```bash
export MODEL_PATH=/models/meta-llama/Meta-Llama-3.1-8B-Instruct/ export MODEL_PATH=/models/meta-llama/Meta-Llama-3.1-8B-Instruct/
vllm serve $MODEL_PATH --served-model-name Llama --disable-log-requests vllm serve $MODEL_PATH --served-model-name Llama
``` ```
The variable `MODEL_PATH` should be a path to the model files (e.g. downloaded from huggingface). The variable `MODEL_PATH` should be a path to the model files (e.g. downloaded from huggingface).
......
...@@ -71,7 +71,7 @@ while [[ $# -gt 0 ]]; do ...@@ -71,7 +71,7 @@ while [[ $# -gt 0 ]]; do
usage usage
;; ;;
*) *)
echo "Unknown argument: $1\n" printf "Unknown argument: %s\n" "$1"
usage usage
;; ;;
esac esac
...@@ -84,15 +84,17 @@ mkdir -p "$OUTPUT_DIR" ...@@ -84,15 +84,17 @@ mkdir -p "$OUTPUT_DIR"
QPS_VALUES=(25 20 15 10 5 1) QPS_VALUES=(25 20 15 10 5 1)
# Common parameters # Common parameters
COMMON_PARAMS="--backend $BACKEND \ COMMON_PARAMS=(
--model $MODEL \ --backend "$BACKEND"
--dataset $DATASET \ --model "$MODEL"
--structured-output-ratio $STRUCTURED_OUTPUT_RATIO \ --dataset "$DATASET"
--save-results \ --structured-output-ratio "$STRUCTURED_OUTPUT_RATIO"
--result-dir $OUTPUT_DIR \ --save-results
--output-len $MAX_NEW_TOKENS \ --result-dir "$OUTPUT_DIR"
--port $PORT \ --output-len "$MAX_NEW_TOKENS"
--tokenizer-mode $TOKENIZER_MODE" --port "$PORT"
--tokenizer-mode "$TOKENIZER_MODE"
)
echo "Starting structured output benchmark with model: $MODEL" echo "Starting structured output benchmark with model: $MODEL"
echo "Backend: $BACKEND" echo "Backend: $BACKEND"
...@@ -109,17 +111,17 @@ for qps in "${QPS_VALUES[@]}"; do ...@@ -109,17 +111,17 @@ for qps in "${QPS_VALUES[@]}"; do
GIT_BRANCH=$(git rev-parse --abbrev-ref HEAD 2>/dev/null || echo "unknown") GIT_BRANCH=$(git rev-parse --abbrev-ref HEAD 2>/dev/null || echo "unknown")
# Construct filename for this run # Construct filename for this run
FILENAME="${BACKEND}_${qps}qps_$(basename $MODEL)_${DATASET}_${GIT_HASH}.json" FILENAME="${BACKEND}_${qps}qps_$(basename "$MODEL")_${DATASET}_${GIT_HASH}_${GIT_BRANCH}.json"
NUM_PROMPTS=$(echo "$TOTAL_SECONDS * $qps" | bc) NUM_PROMPTS=$(echo "$TOTAL_SECONDS * $qps" | bc)
NUM_PROMPTS=${NUM_PROMPTS%.*} # Remove fractional part NUM_PROMPTS=${NUM_PROMPTS%.*} # Remove fractional part
echo "Running benchmark with $NUM_PROMPTS prompts" echo "Running benchmark with $NUM_PROMPTS prompts"
# Run the benchmark # Run the benchmark
python "$SCRIPT_DIR/benchmark_serving_structured_output.py" $COMMON_PARAMS \ python "$SCRIPT_DIR/benchmark_serving_structured_output.py" "${COMMON_PARAMS[@]}" \
--request-rate $qps \ --request-rate "$qps" \
--result-filename "$FILENAME" \ --result-filename "$FILENAME" \
--num-prompts $NUM_PROMPTS --num-prompts "$NUM_PROMPTS"
echo "Completed benchmark with QPS: $qps" echo "Completed benchmark with QPS: $qps"
echo "----------------------------------------" echo "----------------------------------------"
......
...@@ -13,27 +13,16 @@ endif() ...@@ -13,27 +13,16 @@ endif()
# #
# Define environment variables for special configurations # Define environment variables for special configurations
# #
set(ENABLE_AVX2 $ENV{VLLM_CPU_AVX2}) set(ENABLE_X86_ISA $ENV{VLLM_CPU_X86})
set(ENABLE_AVX512 $ENV{VLLM_CPU_AVX512}) set(ENABLE_ARM_BF16 $ENV{VLLM_CPU_ARM_BF16})
set(ENABLE_AVX512BF16 $ENV{VLLM_CPU_AVX512BF16})
set(ENABLE_AVX512VNNI $ENV{VLLM_CPU_AVX512VNNI})
set(ENABLE_AMXBF16 $ENV{VLLM_CPU_AMXBF16})
include_directories("${CMAKE_SOURCE_DIR}/csrc") include_directories("${CMAKE_SOURCE_DIR}/csrc")
set (ENABLE_NUMA TRUE) set (ENABLE_NUMA TRUE)
# #
# Check the compile flags # Check the compile flags
# #
if (CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64")
list(APPEND CXX_COMPILE_FLAGS
"-mf16c"
)
endif()
if(MACOSX_FOUND) if(MACOSX_FOUND)
list(APPEND CXX_COMPILE_FLAGS list(APPEND CXX_COMPILE_FLAGS
"-DVLLM_CPU_EXTENSION") "-DVLLM_CPU_EXTENSION")
...@@ -77,18 +66,6 @@ function(check_sysctl TARGET OUT) ...@@ -77,18 +66,6 @@ function(check_sysctl TARGET OUT)
endif() endif()
endfunction() endfunction()
function (is_avx512_disabled OUT)
set(DISABLE_AVX512 $ENV{VLLM_CPU_DISABLE_AVX512})
if(DISABLE_AVX512 AND DISABLE_AVX512 STREQUAL "true")
set(${OUT} ON PARENT_SCOPE)
else()
set(${OUT} OFF PARENT_SCOPE)
endif()
endfunction()
is_avx512_disabled(AVX512_DISABLED)
if (MACOSX_FOUND AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") if (MACOSX_FOUND AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
message(STATUS "Apple Silicon Detected") message(STATUS "Apple Silicon Detected")
set(APPLE_SILICON_FOUND TRUE) set(APPLE_SILICON_FOUND TRUE)
...@@ -96,84 +73,44 @@ if (MACOSX_FOUND AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") ...@@ -96,84 +73,44 @@ if (MACOSX_FOUND AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
check_sysctl(hw.optional.neon ASIMD_FOUND) check_sysctl(hw.optional.neon ASIMD_FOUND)
check_sysctl(hw.optional.arm.FEAT_BF16 ARM_BF16_FOUND) check_sysctl(hw.optional.arm.FEAT_BF16 ARM_BF16_FOUND)
else() else()
find_isa(${CPUINFO} "avx2" AVX2_FOUND)
find_isa(${CPUINFO} "avx512f" AVX512_FOUND)
find_isa(${CPUINFO} "Power11" POWER11_FOUND) find_isa(${CPUINFO} "Power11" POWER11_FOUND)
find_isa(${CPUINFO} "POWER10" POWER10_FOUND) find_isa(${CPUINFO} "POWER10" POWER10_FOUND)
find_isa(${CPUINFO} "POWER9" POWER9_FOUND) find_isa(${CPUINFO} "POWER9" POWER9_FOUND)
find_isa(${CPUINFO} "asimd" ASIMD_FOUND) # Check for ARM NEON support find_isa(${CPUINFO} "asimd" ASIMD_FOUND) # Check for ARM NEON support
find_isa(${CPUINFO} "bf16" ARM_BF16_FOUND) # Check for ARM BF16 support find_isa(${CPUINFO} "bf16" ARM_BF16_FOUND) # Check for ARM BF16 support
find_isa(${CPUINFO} "S390" S390_FOUND) find_isa(${CPUINFO} "S390" S390_FOUND)
find_isa(${CPUINFO} "v" RVV_FOUND) # Check for RISC-V RVV support find_isa(${CPUINFO} "zvfhmin" RVV_FP16_FOUND) # Check for RISC-V Vector FP16 support
find_isa(${CPUINFO} "zvfbfmin" RVV_BF16_FOUND) # Check for RISC-V Vector BF16 support
# Support cross-compilation by allowing override via environment variables # Support cross-compilation by allowing override via environment variables
if (ENABLE_AVX2) if (ENABLE_ARM_BF16)
set(AVX2_FOUND ON) set(ARM_BF16_FOUND ON)
message(STATUS "AVX2 support enabled via VLLM_CPU_AVX2 environment variable") message(STATUS "ARM BF16 support enabled via VLLM_CPU_ARM_BF16 environment variable")
endif()
if (ENABLE_AVX512)
set(AVX512_FOUND ON)
message(STATUS "AVX512 support enabled via VLLM_CPU_AVX512 environment variable")
endif() endif()
endif() endif()
if (AVX512_FOUND AND NOT AVX512_DISABLED) if (CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64|amd64" OR ENABLE_X86_ISA)
list(APPEND CXX_COMPILE_FLAGS set(ENABLE_X86_ISA ON)
if (NOT (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3))
message(FATAL_ERROR "X86 backend requires gcc/g++ >= 12.3")
endif()
list(APPEND CXX_COMPILE_FLAGS "-mf16c")
list(APPEND CXX_COMPILE_FLAGS_AVX512 ${CXX_COMPILE_FLAGS})
list(APPEND CXX_COMPILE_FLAGS_AVX2 ${CXX_COMPILE_FLAGS})
list(APPEND CXX_COMPILE_FLAGS_AVX512
"-mavx512f" "-mavx512f"
"-mavx512vl" "-mavx512vl"
"-mavx512bw" "-mavx512bw"
"-mavx512dq") "-mavx512dq")
list(APPEND CXX_COMPILE_FLAGS_AVX512_AMX
find_isa(${CPUINFO} "avx512_bf16" AVX512BF16_FOUND) ${CXX_COMPILE_FLAGS_AVX512}
if (AVX512BF16_FOUND OR ENABLE_AVX512BF16) "-mamx-bf16"
if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND "-mamx-tile"
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3) "-mavx512bf16"
list(APPEND CXX_COMPILE_FLAGS "-mavx512bf16") "-mavx512vnni")
set(ENABLE_AVX512BF16 ON) list(APPEND CXX_COMPILE_FLAGS_AVX2
else() "-mavx2")
set(ENABLE_AVX512BF16 OFF)
message(WARNING "Disable AVX512-BF16 ISA support, requires gcc/g++ >= 12.3")
endif()
else()
set(ENABLE_AVX512BF16 OFF)
message(WARNING "Disable AVX512-BF16 ISA support, no avx512_bf16 found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512BF16=1.")
endif()
find_isa(${CPUINFO} "avx512_vnni" AVX512VNNI_FOUND)
if (AVX512VNNI_FOUND OR ENABLE_AVX512VNNI)
if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3)
list(APPEND CXX_COMPILE_FLAGS "-mavx512vnni")
set(ENABLE_AVX512VNNI ON)
else()
set(ENABLE_AVX512VNNI OFF)
message(WARNING "Disable AVX512-VNNI ISA support, requires gcc/g++ >= 12.3")
endif()
else()
set(ENABLE_AVX512VNNI OFF)
message(WARNING "Disable AVX512-VNNI ISA support, no avx512_vnni found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512VNNI=1.")
endif()
find_isa(${CPUINFO} "amx_bf16" AMXBF16_FOUND)
if (AMXBF16_FOUND OR ENABLE_AMXBF16)
if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3)
list(APPEND CXX_COMPILE_FLAGS "-mamx-bf16" "-mamx-tile")
set(ENABLE_AMXBF16 ON)
add_compile_definitions(-DCPU_CAPABILITY_AMXBF16)
else()
set(ENABLE_AMXBF16 OFF)
message(WARNING "Disable AMX_BF16 ISA support, requires gcc/g++ >= 12.3")
endif()
else()
set(ENABLE_AMXBF16 OFF)
message(WARNING "Disable AMX_BF16 ISA support, no amx_bf16 found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AMXBF16=1.")
endif()
elseif (AVX2_FOUND)
list(APPEND CXX_COMPILE_FLAGS "-mavx2")
message(WARNING "vLLM CPU backend using AVX2 ISA")
elseif (POWER9_FOUND OR POWER10_FOUND OR POWER11_FOUND) elseif (POWER9_FOUND OR POWER10_FOUND OR POWER11_FOUND)
message(STATUS "PowerPC detected") message(STATUS "PowerPC detected")
if (POWER9_FOUND) if (POWER9_FOUND)
...@@ -208,18 +145,26 @@ elseif (S390_FOUND) ...@@ -208,18 +145,26 @@ elseif (S390_FOUND)
"-march=native" "-march=native"
"-mtune=native") "-mtune=native")
elseif (CMAKE_SYSTEM_PROCESSOR MATCHES "riscv64") elseif (CMAKE_SYSTEM_PROCESSOR MATCHES "riscv64")
if(RVV_FOUND) message(STATUS "RISC-V detected")
message(FAIL_ERROR "Can't support rvv now.") if(RVV_BF16_FOUND)
message(STATUS "BF16 extension detected")
set(MARCH_FLAGS -march=rv64gcv_zvfh_zfbfmin_zvfbfmin_zvl128b -mrvv-vector-bits=zvl -mabi=lp64d)
add_compile_definitions(RISCV_BF16_SUPPORT)
elseif (RVV_FP16_FOUND)
message(WARNING "BF16 functionality is not available")
set(MARCH_FLAGS -march=rv64gcv_zvfh_zvl128b -mrvv-vector-bits=zvl -mabi=lp64d)
else() else()
message(STATUS "compile riscv with scalar")
list(APPEND CXX_COMPILE_FLAGS "-march=rv64gc") list(APPEND CXX_COMPILE_FLAGS "-march=rv64gc")
endif() endif()
list(APPEND CXX_COMPILE_FLAGS ${MARCH_FLAGS})
else() else()
message(FATAL_ERROR "vLLM CPU backend requires AVX512, AVX2, Power9+ ISA, S390X ISA, ARMv8 or RISC-V support.") message(FATAL_ERROR "vLLM CPU backend requires X86, Power9+ ISA, S390X ISA, ARMv8 or RISC-V support.")
endif() endif()
# Build oneDNN for GEMM kernels (only for x86-AVX512 /ARM platforms) # Build oneDNN for GEMM kernels
if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON_FOUND) OR POWER9_FOUND OR POWER10_FOUND OR POWER11_FOUND) if (ENABLE_X86_ISA OR (ASIMD_FOUND AND NOT APPLE_SILICON_FOUND) OR POWER9_FOUND OR POWER10_FOUND OR POWER11_FOUND)
# Fetch and build Arm Compute Library (ACL) as oneDNN's backend for AArch64 # Fetch and build Arm Compute Library (ACL) as oneDNN's backend for AArch64
# TODO [fadara01]: remove this once ACL can be fetched and built automatically as a dependency of oneDNN # TODO [fadara01]: remove this once ACL can be fetched and built automatically as a dependency of oneDNN
set(ONEDNN_AARCH64_USE_ACL OFF CACHE BOOL "") set(ONEDNN_AARCH64_USE_ACL OFF CACHE BOOL "")
...@@ -308,13 +253,24 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON ...@@ -308,13 +253,24 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON
) )
else() else()
message(STATUS "Downloading oneDNN from GitHub") message(STATUS "Downloading oneDNN from GitHub")
FetchContent_Declare( if(ASIMD_FOUND AND NOT APPLE_SILICON_FOUND)
oneDNN message(STATUS "aarch64 detected: using pinned oneDNN commit 9c5be1cc59e368aebf0909e6cf20f981ea61462a")
GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git FetchContent_Declare(
GIT_TAG v3.10 oneDNN
GIT_PROGRESS TRUE GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git
GIT_SHALLOW TRUE GIT_TAG 9c5be1cc59e368aebf0909e6cf20f981ea61462a
) GIT_PROGRESS TRUE
GIT_SHALLOW FALSE
)
else()
FetchContent_Declare(
oneDNN
GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git
GIT_TAG v3.10
GIT_PROGRESS TRUE
GIT_SHALLOW TRUE
)
endif()
endif() endif()
set(ONEDNN_LIBRARY_TYPE "STATIC") set(ONEDNN_LIBRARY_TYPE "STATIC")
...@@ -324,13 +280,21 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON ...@@ -324,13 +280,21 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON
set(ONEDNN_ENABLE_WORKLOAD "INFERENCE") set(ONEDNN_ENABLE_WORKLOAD "INFERENCE")
set(ONEDNN_ENABLE_PRIMITIVE "MATMUL;REORDER") set(ONEDNN_ENABLE_PRIMITIVE "MATMUL;REORDER")
set(ONEDNN_BUILD_GRAPH "OFF") set(ONEDNN_BUILD_GRAPH "OFF")
set(ONEDNN_ENABLE_JIT_PROFILING "OFF") set(ONEDNN_ENABLE_JIT_PROFILING "ON")
set(ONEDNN_ENABLE_ITT_TASKS "OFF") set(ONEDNN_ENABLE_ITT_TASKS "OFF")
set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF") set(ONEDNN_ENABLE_MAX_CPU_ISA "ON")
set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF") set(ONEDNN_ENABLE_CPU_ISA_HINTS "ON")
set(ONEDNN_VERBOSE "OFF") set(ONEDNN_VERBOSE "ON")
set(CMAKE_POLICY_DEFAULT_CMP0077 NEW) set(CMAKE_POLICY_DEFAULT_CMP0077 NEW)
# TODO: Refactor this
if (ENABLE_X86_ISA)
# Note: only enable oneDNN for AVX512
list(APPEND DNNL_COMPILE_FLAGS ${CXX_COMPILE_FLAGS_AVX512})
else()
list(APPEND DNNL_COMPILE_FLAGS ${CXX_COMPILE_FLAGS})
endif()
set(VLLM_BUILD_TYPE ${CMAKE_BUILD_TYPE}) set(VLLM_BUILD_TYPE ${CMAKE_BUILD_TYPE})
set(CMAKE_BUILD_TYPE "Release") # remove oneDNN debug symbols to reduce size set(CMAKE_BUILD_TYPE "Release") # remove oneDNN debug symbols to reduce size
FetchContent_MakeAvailable(oneDNN) FetchContent_MakeAvailable(oneDNN)
...@@ -343,14 +307,21 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON ...@@ -343,14 +307,21 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON
PRIVATE ${oneDNN_SOURCE_DIR}/src PRIVATE ${oneDNN_SOURCE_DIR}/src
) )
target_link_libraries(dnnl_ext dnnl torch) target_link_libraries(dnnl_ext dnnl torch)
target_compile_options(dnnl_ext PRIVATE ${CXX_COMPILE_FLAGS} -fPIC) target_compile_options(dnnl_ext PRIVATE ${DNNL_COMPILE_FLAGS} -fPIC)
list(APPEND LIBS dnnl_ext) list(APPEND LIBS dnnl_ext)
set(USE_ONEDNN ON) set(USE_ONEDNN ON)
else() else()
set(USE_ONEDNN OFF) set(USE_ONEDNN OFF)
endif() endif()
message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}") # TODO: Refactor this
if (ENABLE_X86_ISA)
message(STATUS "CPU extension (AVX512F + BF16 + VNNI + AMX) compile flags: ${CXX_COMPILE_FLAGS_AVX512_AMX}")
message(STATUS "CPU extension (AVX512F) compile flags: ${CXX_COMPILE_FLAGS_AVX512}")
message(STATUS "CPU extension (AVX2) compile flags: ${CXX_COMPILE_FLAGS_AVX2}")
else()
message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}")
endif()
if(ENABLE_NUMA) if(ENABLE_NUMA)
list(APPEND LIBS numa) list(APPEND LIBS numa)
...@@ -385,25 +356,6 @@ set(VLLM_EXT_SRC ...@@ -385,25 +356,6 @@ set(VLLM_EXT_SRC
"csrc/cpu/cpu_attn.cpp" "csrc/cpu/cpu_attn.cpp"
"csrc/cpu/torch_bindings.cpp") "csrc/cpu/torch_bindings.cpp")
if (AVX512_FOUND AND NOT AVX512_DISABLED)
set(VLLM_EXT_SRC
"csrc/cpu/shm.cpp"
"csrc/cpu/cpu_wna16.cpp"
"csrc/cpu/cpu_fused_moe.cpp"
${VLLM_EXT_SRC})
if (ENABLE_AVX512BF16 AND ENABLE_AVX512VNNI)
set(VLLM_EXT_SRC
"csrc/cpu/sgl-kernels/gemm.cpp"
"csrc/cpu/sgl-kernels/gemm_int8.cpp"
"csrc/cpu/sgl-kernels/gemm_fp8.cpp"
"csrc/cpu/sgl-kernels/moe.cpp"
"csrc/cpu/sgl-kernels/moe_int8.cpp"
"csrc/cpu/sgl-kernels/moe_fp8.cpp"
${VLLM_EXT_SRC})
add_compile_definitions(-DCPU_CAPABILITY_AVX512)
endif()
endif()
if (ASIMD_FOUND AND NOT APPLE_SILICON_FOUND) if (ASIMD_FOUND AND NOT APPLE_SILICON_FOUND)
set(VLLM_EXT_SRC set(VLLM_EXT_SRC
"csrc/cpu/shm.cpp" "csrc/cpu/shm.cpp"
...@@ -416,21 +368,102 @@ if(USE_ONEDNN) ...@@ -416,21 +368,102 @@ if(USE_ONEDNN)
${VLLM_EXT_SRC}) ${VLLM_EXT_SRC})
endif() endif()
message(STATUS "CPU extension source files: ${VLLM_EXT_SRC}") if (ENABLE_X86_ISA)
set(VLLM_EXT_SRC_SGL
"csrc/cpu/sgl-kernels/gemm.cpp"
"csrc/cpu/sgl-kernels/gemm_int8.cpp"
"csrc/cpu/sgl-kernels/gemm_fp8.cpp"
"csrc/cpu/sgl-kernels/moe.cpp"
"csrc/cpu/sgl-kernels/moe_int8.cpp"
"csrc/cpu/sgl-kernels/moe_fp8.cpp")
# set(VLLM_EXT_SRC_AVX512
# Define extension targets "csrc/cpu/shm.cpp"
# "csrc/cpu/cpu_wna16.cpp"
"csrc/cpu/cpu_fused_moe.cpp"
"csrc/cpu/utils.cpp"
"csrc/cpu/cpu_attn.cpp"
"csrc/cpu/dnnl_kernels.cpp"
"csrc/cpu/torch_bindings.cpp"
# TODO: Remove these files
"csrc/cpu/activation.cpp"
"csrc/cpu/layernorm.cpp"
"csrc/cpu/mla_decode.cpp"
"csrc/cpu/pos_encoding.cpp"
"csrc/moe/dynamic_4bit_int_moe_cpu.cpp")
set(VLLM_EXT_SRC_AVX2
"csrc/cpu/utils.cpp"
"csrc/cpu/cpu_attn.cpp"
"csrc/cpu/torch_bindings.cpp"
# TODO: Remove these files
"csrc/cpu/activation.cpp"
"csrc/cpu/layernorm.cpp"
"csrc/cpu/mla_decode.cpp"
"csrc/cpu/pos_encoding.cpp"
"csrc/moe/dynamic_4bit_int_moe_cpu.cpp")
message(STATUS "CPU extension (AVX512F + BF16 + VNNI + AMX) source files: ${VLLM_EXT_SRC_AVX512} ${VLLM_EXT_SRC_SGL}")
message(STATUS "CPU extension (AVX512F) source files: ${VLLM_EXT_SRC_AVX512}")
message(STATUS "CPU extension (AVX2) source files: ${VLLM_EXT_SRC_AVX2}")
set(_C_LIBS numa dnnl_ext)
set(_C_AVX512_LIBS numa dnnl_ext)
set(_C_AVX2_LIBS numa)
# AMX + AVX512F + AVX512BF16 + AVX512VNNI
define_extension_target(
_C
DESTINATION vllm
LANGUAGE CXX
SOURCES ${VLLM_EXT_SRC_AVX512} ${VLLM_EXT_SRC_SGL}
LIBRARIES ${_C_LIBS}
COMPILE_FLAGS ${CXX_COMPILE_FLAGS_AVX512_AMX}
USE_SABI 3
WITH_SOABI
)
define_extension_target( # For AMX kernels
_C target_compile_definitions(_C PRIVATE "-DCPU_CAPABILITY_AMXBF16")
DESTINATION vllm
LANGUAGE CXX # AVX512F
SOURCES ${VLLM_EXT_SRC} define_extension_target(
LIBRARIES ${LIBS} _C_AVX512
COMPILE_FLAGS ${CXX_COMPILE_FLAGS} DESTINATION vllm
USE_SABI 3 LANGUAGE CXX
WITH_SOABI SOURCES ${VLLM_EXT_SRC_AVX512}
) LIBRARIES ${_C_AVX512_LIBS}
COMPILE_FLAGS ${CXX_COMPILE_FLAGS_AVX512}
USE_SABI 3
WITH_SOABI
)
# AVX2
define_extension_target(
_C_AVX2
DESTINATION vllm
LANGUAGE CXX
SOURCES ${VLLM_EXT_SRC_AVX2}
LIBRARIES ${_C_AVX2_LIBS}
COMPILE_FLAGS ${CXX_COMPILE_FLAGS_AVX2}
USE_SABI 3
WITH_SOABI
)
else()
message(STATUS "CPU extension source files: ${VLLM_EXT_SRC}")
#
# Define extension targets
#
define_extension_target(
_C
DESTINATION vllm
LANGUAGE CXX
SOURCES ${VLLM_EXT_SRC}
LIBRARIES ${LIBS}
COMPILE_FLAGS ${CXX_COMPILE_FLAGS}
USE_SABI 3
WITH_SOABI
)
endif()
message(STATUS "Enabling C extension.") message(STATUS "Enabling C extension.")
...@@ -19,7 +19,7 @@ else() ...@@ -19,7 +19,7 @@ else()
FetchContent_Declare( FetchContent_Declare(
flashmla flashmla
GIT_REPOSITORY https://github.com/vllm-project/FlashMLA GIT_REPOSITORY https://github.com/vllm-project/FlashMLA
GIT_TAG c2afa9cb93e674d5a9120a170a6da57b89267208 GIT_TAG 692917b1cda61b93ac9ee2d846ec54e75afe87b1
GIT_PROGRESS TRUE GIT_PROGRESS TRUE
CONFIGURE_COMMAND "" CONFIGURE_COMMAND ""
BUILD_COMMAND "" BUILD_COMMAND ""
......
...@@ -17,7 +17,8 @@ endif() ...@@ -17,7 +17,8 @@ endif()
# They should be identical but if they aren't, this is a massive footgun. # They should be identical but if they aren't, this is a massive footgun.
# #
# The vllm-flash-attn install rules are nested under vllm to make sure the library gets installed in the correct place. # The vllm-flash-attn install rules are nested under vllm to make sure the library gets installed in the correct place.
# To only install vllm-flash-attn, use --component _vllm_fa2_C (for FA2) or --component _vllm_fa3_C (for FA3). # To only install vllm-flash-attn, use --component _vllm_fa2_C (for FA2), --component _vllm_fa3_C (for FA3),
# or --component _vllm_fa4_cutedsl_C (for FA4 CuteDSL Python files).
# If no component is specified, vllm-flash-attn is still installed. # If no component is specified, vllm-flash-attn is still installed.
# If VLLM_FLASH_ATTN_SRC_DIR is set, vllm-flash-attn is installed from that directory instead of downloading. # If VLLM_FLASH_ATTN_SRC_DIR is set, vllm-flash-attn is installed from that directory instead of downloading.
...@@ -38,22 +39,16 @@ else() ...@@ -38,22 +39,16 @@ else()
FetchContent_Declare( FetchContent_Declare(
vllm-flash-attn vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG 188be16520ceefdc625fdf71365585d2ee348fe2 GIT_TAG 1488682bb545f7d020e958a33116b1419d1cfc83
GIT_PROGRESS TRUE GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types # Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
) )
endif() endif()
# Ensure the vllm/vllm_flash_attn directory exists before installation
install(CODE "file(MAKE_DIRECTORY \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn\")" ALL_COMPONENTS)
# Make sure vllm-flash-attn install rules are nested under vllm/ # Make sure vllm-flash-attn install rules are nested under vllm/
# This is here to support installing all components under the same prefix with cmake --install. # ALL_COMPONENTS ensures the save/modify/restore runs exactly once regardless
# setup.py installs every component separately but uses the same prefix for all. # of how many components are being installed, avoiding double-append of /vllm/.
# ALL_COMPONENTS is used to avoid duplication for FA2 and FA3,
# and these statements don't hurt when installing neither component.
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY FALSE)" ALL_COMPONENTS) install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY FALSE)" ALL_COMPONENTS)
install(CODE "set(OLD_CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}\")" ALL_COMPONENTS) install(CODE "set(OLD_CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}\")" ALL_COMPONENTS)
install(CODE "set(CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}/vllm/\")" ALL_COMPONENTS) install(CODE "set(CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}/vllm/\")" ALL_COMPONENTS)
...@@ -62,22 +57,48 @@ install(CODE "set(CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}/vllm/\")" ALL_ ...@@ -62,22 +57,48 @@ install(CODE "set(CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}/vllm/\")" ALL_
FetchContent_MakeAvailable(vllm-flash-attn) FetchContent_MakeAvailable(vllm-flash-attn)
message(STATUS "vllm-flash-attn is available at ${vllm-flash-attn_SOURCE_DIR}") message(STATUS "vllm-flash-attn is available at ${vllm-flash-attn_SOURCE_DIR}")
# Restore the install prefix # Restore the install prefix after FA's install rules
install(CODE "set(CMAKE_INSTALL_PREFIX \"\${OLD_CMAKE_INSTALL_PREFIX}\")" ALL_COMPONENTS) install(CODE "set(CMAKE_INSTALL_PREFIX \"\${OLD_CMAKE_INSTALL_PREFIX}\")" ALL_COMPONENTS)
install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS) install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS)
# Copy over the vllm-flash-attn python files (duplicated for fa2 and fa3, in # Install shared Python files for both FA2 and FA3 components
# case only one is built, in the case both are built redundant work is done) foreach(_FA_COMPONENT _vllm_fa2_C _vllm_fa3_C)
install( # Ensure the vllm/vllm_flash_attn directory exists before installation
DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/ install(CODE "file(MAKE_DIRECTORY \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn\")"
DESTINATION vllm/vllm_flash_attn COMPONENT ${_FA_COMPONENT})
COMPONENT _vllm_fa2_C
FILES_MATCHING PATTERN "*.py" # Copy vllm_flash_attn python files (except __init__.py and flash_attn_interface.py
) # which are source-controlled in vllm)
install(
DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/
DESTINATION vllm/vllm_flash_attn
COMPONENT ${_FA_COMPONENT}
FILES_MATCHING PATTERN "*.py"
PATTERN "__init__.py" EXCLUDE
PATTERN "flash_attn_interface.py" EXCLUDE
)
endforeach()
install( #
DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/ # FA4 CuteDSL component
DESTINATION vllm/vllm_flash_attn # This is a Python-only component that copies the flash_attn/cute directory
COMPONENT _vllm_fa3_C # and transforms imports to match our package structure.
FILES_MATCHING PATTERN "*.py" #
) add_custom_target(_vllm_fa4_cutedsl_C)
# Copy flash_attn/cute directory (needed for FA4) and transform imports
# The cute directory uses flash_attn.cute imports internally, which we replace
# with vllm.vllm_flash_attn.cute to match our package structure.
install(CODE "
file(GLOB_RECURSE CUTE_PY_FILES \"${vllm-flash-attn_SOURCE_DIR}/flash_attn/cute/*.py\")
foreach(SRC_FILE \${CUTE_PY_FILES})
file(RELATIVE_PATH REL_PATH \"${vllm-flash-attn_SOURCE_DIR}/flash_attn/cute\" \${SRC_FILE})
set(DST_FILE \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn/cute/\${REL_PATH}\")
get_filename_component(DST_DIR \${DST_FILE} DIRECTORY)
file(MAKE_DIRECTORY \${DST_DIR})
file(READ \${SRC_FILE} FILE_CONTENTS)
string(REPLACE \"flash_attn.cute\" \"vllm.vllm_flash_attn.cute\" FILE_CONTENTS \"\${FILE_CONTENTS}\")
file(WRITE \${DST_FILE} \"\${FILE_CONTENTS}\")
endforeach()
" COMPONENT _vllm_fa4_cutedsl_C)
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include <cmath> #include <cmath>
#include "cuda_compat.h" #include "cuda_compat.h"
#include "cuda_vec_utils.cuh"
#include "dispatch_utils.h" #include "dispatch_utils.h"
namespace vllm { namespace vllm {
...@@ -16,52 +17,55 @@ __device__ __forceinline__ scalar_t compute(const scalar_t& x, ...@@ -16,52 +17,55 @@ __device__ __forceinline__ scalar_t compute(const scalar_t& x,
return act_first ? ACT_FN(x) * y : x * ACT_FN(y); return act_first ? ACT_FN(x) * y : x * ACT_FN(y);
} }
// Check if all pointers are 16-byte aligned for int4 vectorized access template <typename packed_t, packed_t (*PACKED_ACT_FN)(const packed_t&),
__device__ __forceinline__ bool is_16byte_aligned(const void* ptr) { bool act_first>
return (reinterpret_cast<uintptr_t>(ptr) & 15) == 0; __device__ __forceinline__ packed_t packed_compute(const packed_t& x,
const packed_t& y) {
return act_first ? packed_mul(PACKED_ACT_FN(x), y)
: packed_mul(x, PACKED_ACT_FN(y));
} }
// Activation and gating kernel template. // Activation and gating kernel template.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&), template <typename scalar_t, typename packed_t,
bool act_first> scalar_t (*ACT_FN)(const scalar_t&),
packed_t (*PACKED_ACT_FN)(const packed_t&), bool act_first,
bool use_vec, bool use_256b = false>
__global__ void act_and_mul_kernel( __global__ void act_and_mul_kernel(
scalar_t* __restrict__ out, // [..., d] scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., 2, d] const scalar_t* __restrict__ input, // [..., 2, d]
const int d) { const int d) {
constexpr int VEC_SIZE = 16 / sizeof(scalar_t); const scalar_t* x_ptr = input + blockIdx.x * 2 * d;
const int64_t token_idx = blockIdx.x;
const scalar_t* x_ptr = input + token_idx * 2 * d;
const scalar_t* y_ptr = x_ptr + d; const scalar_t* y_ptr = x_ptr + d;
scalar_t* out_ptr = out + token_idx * d; scalar_t* out_ptr = out + blockIdx.x * d;
// Check alignment for 128-bit vectorized access. if constexpr (use_vec) {
// All three pointers must be 16-byte aligned for safe int4 operations. using cuda_t = typename CUDATypeConverter<scalar_t>::Type;
const bool aligned = is_16byte_aligned(x_ptr) && is_16byte_aligned(y_ptr) && using pvec_t = PackedVec<cuda_t, use_256b>;
is_16byte_aligned(out_ptr);
if (aligned && d >= VEC_SIZE) { const pvec_t* x_vec = reinterpret_cast<const pvec_t*>(x_ptr);
// Fast path: 128-bit vectorized loop const pvec_t* y_vec = reinterpret_cast<const pvec_t*>(y_ptr);
const int4* x_vec = reinterpret_cast<const int4*>(x_ptr); pvec_t* out_vec = reinterpret_cast<pvec_t*>(out_ptr);
const int4* y_vec = reinterpret_cast<const int4*>(y_ptr); const int num_vecs = d / 2 / pvec_t::NUM_ELTS;
int4* out_vec = reinterpret_cast<int4*>(out_ptr);
const int num_vecs = d / VEC_SIZE;
const int vec_end = num_vecs * VEC_SIZE;
for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) { for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) {
int4 x = VLLM_LDG(&x_vec[i]), y = VLLM_LDG(&y_vec[i]), r; pvec_t x, y;
auto* xp = reinterpret_cast<scalar_t*>(&x); if constexpr (use_256b) {
auto* yp = reinterpret_cast<scalar_t*>(&y); ld256(x, &x_vec[i]);
auto* rp = reinterpret_cast<scalar_t*>(&r); ld256(y, &y_vec[i]);
} else {
ld128(x, &x_vec[i]);
ld128(y, &y_vec[i]);
}
#pragma unroll #pragma unroll
for (int j = 0; j < VEC_SIZE; j++) { for (int j = 0; j < pvec_t::NUM_ELTS; j++) {
rp[j] = compute<scalar_t, ACT_FN, act_first>(xp[j], yp[j]); x.elts[j] = packed_compute<packed_t, PACKED_ACT_FN, act_first>(
x.elts[j], y.elts[j]);
}
if constexpr (use_256b) {
st256(x, &out_vec[i]);
} else {
st128(x, &out_vec[i]);
} }
out_vec[i] = r;
}
// Scalar cleanup for remaining elements
for (int i = vec_end + threadIdx.x; i < d; i += blockDim.x) {
out_ptr[i] = compute<scalar_t, ACT_FN, act_first>(VLLM_LDG(&x_ptr[i]),
VLLM_LDG(&y_ptr[i]));
} }
} else { } else {
// Scalar fallback for unaligned data or small d // Scalar fallback for unaligned data or small d
...@@ -79,6 +83,15 @@ __device__ __forceinline__ T silu_kernel(const T& x) { ...@@ -79,6 +83,15 @@ __device__ __forceinline__ T silu_kernel(const T& x) {
return (T)(((float)x) / (1.0f + expf((float)-x))); return (T)(((float)x) / (1.0f + expf((float)-x)));
} }
template <typename packed_t>
__device__ __forceinline__ packed_t packed_silu_kernel(const packed_t& val) {
// x * sigmoid(x)
float2 fval = cast_to_float2(val);
fval.x = fval.x / (1.0f + expf(-fval.x));
fval.y = fval.y / (1.0f + expf(-fval.y));
return cast_to_packed<packed_t>(fval);
}
template <typename T> template <typename T>
__device__ __forceinline__ T gelu_kernel(const T& x) { __device__ __forceinline__ T gelu_kernel(const T& x) {
// Equivalent to PyTorch GELU with 'none' approximation. // Equivalent to PyTorch GELU with 'none' approximation.
...@@ -89,6 +102,18 @@ __device__ __forceinline__ T gelu_kernel(const T& x) { ...@@ -89,6 +102,18 @@ __device__ __forceinline__ T gelu_kernel(const T& x) {
return (T)(f * 0.5f * (1.0f + ::erf(f * ALPHA))); return (T)(f * 0.5f * (1.0f + ::erf(f * ALPHA)));
} }
template <typename packed_t>
__device__ __forceinline__ packed_t packed_gelu_kernel(const packed_t& val) {
// Equivalent to PyTorch GELU with 'none' approximation.
// Refer to:
// https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38
constexpr float ALPHA = M_SQRT1_2;
float2 fval = cast_to_float2(val);
fval.x = fval.x * 0.5f * (1.0f + ::erf(fval.x * ALPHA));
fval.y = fval.y * 0.5f * (1.0f + ::erf(fval.y * ALPHA));
return cast_to_packed<packed_t>(fval);
}
template <typename T> template <typename T>
__device__ __forceinline__ T gelu_tanh_kernel(const T& x) { __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
// Equivalent to PyTorch GELU with 'tanh' approximation. // Equivalent to PyTorch GELU with 'tanh' approximation.
...@@ -102,32 +127,86 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) { ...@@ -102,32 +127,86 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) {
return (T)(0.5f * f * (1.0f + ::tanhf(inner))); return (T)(0.5f * f * (1.0f + ::tanhf(inner)));
} }
template <typename packed_t>
__device__ __forceinline__ packed_t
packed_gelu_tanh_kernel(const packed_t& val) {
// Equivalent to PyTorch GELU with 'tanh' approximation.
// Refer to:
// https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30
float2 fval = cast_to_float2(val);
constexpr float BETA = M_SQRT2 * M_2_SQRTPI * 0.5f;
constexpr float KAPPA = 0.044715;
float x_cube = fval.x * fval.x * fval.x;
float inner = BETA * (fval.x + KAPPA * x_cube);
fval.x = 0.5f * fval.x * (1.0f + ::tanhf(inner));
x_cube = fval.y * fval.y * fval.y;
inner = BETA * (fval.y + KAPPA * x_cube);
fval.y = 0.5f * fval.y * (1.0f + ::tanhf(inner));
return cast_to_packed<packed_t>(fval);
}
} // namespace vllm } // namespace vllm
// Launch activation and gating kernel. // Launch activation and gating kernel.
// Use ACT_FIRST (bool) indicating whether to apply the activation function // Use ACT_FIRST (bool) indicating whether to apply the activation function
// first. // first.
#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL, ACT_FIRST) \ #define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL, PACKED_KERNEL, ACT_FIRST) \
int d = input.size(-1) / 2; \ auto dtype = input.scalar_type(); \
int64_t num_tokens = input.numel() / input.size(-1); \ int d = input.size(-1) / 2; \
dim3 grid(num_tokens); \ int64_t num_tokens = input.numel() / input.size(-1); \
dim3 block(std::min(d, 1024)); \ if (num_tokens == 0) { \
if (num_tokens == 0) { \ return; \
return; \ } \
} \ dim3 grid(num_tokens); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ int cc_major = at::cuda::getCurrentDeviceProperties()->major; \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ int support_vec = \
VLLM_DISPATCH_FLOATING_TYPES( \ (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) \
input.scalar_type(), "act_and_mul_kernel", [&] { \ ? vllm::VecTraits<true>::ARCH_MAX_VEC_SIZE \
vllm::act_and_mul_kernel<scalar_t, KERNEL<scalar_t>, ACT_FIRST> \ : vllm::VecTraits<false>::ARCH_MAX_VEC_SIZE; \
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \ int vec_size = support_vec / at::elementSize(dtype); \
input.data_ptr<scalar_t>(), d); \ const bool use_vec = (d % vec_size == 0); \
}); const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
if (use_vec) { \
dim3 block(std::min(d / vec_size, 1024)); \
if (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) { \
VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \
vllm::act_and_mul_kernel< \
scalar_t, typename vllm::PackedTypeConverter<scalar_t>::Type, \
KERNEL<scalar_t>, \
PACKED_KERNEL<typename vllm::PackedTypeConverter<scalar_t>::Type>, \
ACT_FIRST, true, true><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d); \
}); \
} else { \
VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \
vllm::act_and_mul_kernel< \
scalar_t, typename vllm::PackedTypeConverter<scalar_t>::Type, \
KERNEL<scalar_t>, \
PACKED_KERNEL<typename vllm::PackedTypeConverter<scalar_t>::Type>, \
ACT_FIRST, true, false><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d); \
}); \
} \
} else { \
dim3 block(std::min(d, 1024)); \
VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \
vllm::act_and_mul_kernel< \
scalar_t, typename vllm::PackedTypeConverter<scalar_t>::Type, \
KERNEL<scalar_t>, \
PACKED_KERNEL<typename vllm::PackedTypeConverter<scalar_t>::Type>, \
ACT_FIRST, false><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d); \
}); \
}
void silu_and_mul(torch::Tensor& out, // [..., d] void silu_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d] torch::Tensor& input) // [..., 2 * d]
{ {
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, true); LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, vllm::packed_silu_kernel,
true);
} }
void mul_and_silu(torch::Tensor& out, // [..., d] void mul_and_silu(torch::Tensor& out, // [..., d]
...@@ -135,19 +214,22 @@ void mul_and_silu(torch::Tensor& out, // [..., d] ...@@ -135,19 +214,22 @@ void mul_and_silu(torch::Tensor& out, // [..., d]
{ {
// The difference between mul_and_silu and silu_and_mul is that mul_and_silu // The difference between mul_and_silu and silu_and_mul is that mul_and_silu
// applies the silu to the latter half of the input. // applies the silu to the latter half of the input.
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, false); LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel, vllm::packed_silu_kernel,
false);
} }
void gelu_and_mul(torch::Tensor& out, // [..., d] void gelu_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d] torch::Tensor& input) // [..., 2 * d]
{ {
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel, true); LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel, vllm::packed_gelu_kernel,
true);
} }
void gelu_tanh_and_mul(torch::Tensor& out, // [..., d] void gelu_tanh_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& input) // [..., 2 * d] torch::Tensor& input) // [..., 2 * d]
{ {
LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel, true); LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel,
vllm::packed_gelu_tanh_kernel, true);
} }
namespace vllm { namespace vllm {
...@@ -158,42 +240,53 @@ __device__ __forceinline__ T fatrelu_kernel(const T& x, const float threshold) { ...@@ -158,42 +240,53 @@ __device__ __forceinline__ T fatrelu_kernel(const T& x, const float threshold) {
return (T)(f > threshold ? f : 0.0f); return (T)(f > threshold ? f : 0.0f);
} }
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&, const float)> template <typename packed_t>
__device__ __forceinline__ packed_t
packed_fatrelu_kernel(const packed_t& val, const float threshold) {
float2 fval = cast_to_float2(val);
fval.x = fval.x > threshold ? fval.x : 0.0f;
fval.y = fval.y > threshold ? fval.y : 0.0f;
return cast_to_packed<packed_t>(fval);
}
template <typename scalar_t, typename packed_t,
scalar_t (*ACT_FN)(const scalar_t&, const float),
packed_t (*PACKED_ACT_FN)(const packed_t&, const float), bool use_vec,
bool use_256b = false>
__global__ void act_and_mul_kernel_with_param( __global__ void act_and_mul_kernel_with_param(
scalar_t* __restrict__ out, const scalar_t* __restrict__ input, const int d, scalar_t* __restrict__ out, const scalar_t* __restrict__ input, const int d,
const float param) { const float param) {
constexpr int VEC_SIZE = 16 / sizeof(scalar_t); const scalar_t* x_ptr = input + blockIdx.x * 2 * d;
const int64_t token_idx = blockIdx.x;
const scalar_t* x_ptr = input + token_idx * 2 * d;
const scalar_t* y_ptr = x_ptr + d; const scalar_t* y_ptr = x_ptr + d;
scalar_t* out_ptr = out + token_idx * d; scalar_t* out_ptr = out + blockIdx.x * d;
// Check alignment for 128-bit vectorized access if constexpr (use_vec) {
const bool aligned = is_16byte_aligned(x_ptr) && is_16byte_aligned(y_ptr) && using cuda_t = typename CUDATypeConverter<scalar_t>::Type;
is_16byte_aligned(out_ptr); using pvec_t = PackedVec<cuda_t, use_256b>;
if (aligned && d >= VEC_SIZE) { const pvec_t* x_vec = reinterpret_cast<const pvec_t*>(x_ptr);
// Fast path: 128-bit vectorized loop const pvec_t* y_vec = reinterpret_cast<const pvec_t*>(y_ptr);
const int4* x_vec = reinterpret_cast<const int4*>(x_ptr); pvec_t* out_vec = reinterpret_cast<pvec_t*>(out_ptr);
const int4* y_vec = reinterpret_cast<const int4*>(y_ptr); const int num_vecs = d / 2 / pvec_t::NUM_ELTS;
int4* out_vec = reinterpret_cast<int4*>(out_ptr);
const int num_vecs = d / VEC_SIZE;
const int vec_end = num_vecs * VEC_SIZE;
for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) { for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) {
int4 x = VLLM_LDG(&x_vec[i]), y = VLLM_LDG(&y_vec[i]), r; pvec_t x, y;
auto* xp = reinterpret_cast<scalar_t*>(&x); if constexpr (use_256b) {
auto* yp = reinterpret_cast<scalar_t*>(&y); ld256(x, &x_vec[i]);
auto* rp = reinterpret_cast<scalar_t*>(&r); ld256(y, &y_vec[i]);
} else {
ld128(x, &x_vec[i]);
ld128(y, &y_vec[i]);
}
#pragma unroll #pragma unroll
for (int j = 0; j < VEC_SIZE; j++) { for (int j = 0; j < pvec_t::NUM_ELTS; j++) {
rp[j] = ACT_FN(xp[j], param) * yp[j]; x.elts[j] = packed_mul(PACKED_ACT_FN(x.elts[j], param), y.elts[j]);
}
if constexpr (use_256b) {
st256(x, &out_vec[i]);
} else {
st128(x, &out_vec[i]);
} }
out_vec[i] = r;
}
// Scalar cleanup for remaining elements
for (int i = vec_end + threadIdx.x; i < d; i += blockDim.x) {
out_ptr[i] = ACT_FN(VLLM_LDG(&x_ptr[i]), param) * VLLM_LDG(&y_ptr[i]);
} }
} else { } else {
// Scalar fallback for unaligned data or small d // Scalar fallback for unaligned data or small d
...@@ -276,20 +369,61 @@ __global__ void swigluoai_and_mul_kernel( ...@@ -276,20 +369,61 @@ __global__ void swigluoai_and_mul_kernel(
} // namespace vllm } // namespace vllm
#define LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(KERNEL, PARAM) \ #define LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(KERNEL, PACKED_KERNEL, PARAM) \
int d = input.size(-1) / 2; \ auto dtype = input.scalar_type(); \
int64_t num_tokens = input.numel() / input.size(-1); \ int d = input.size(-1) / 2; \
dim3 grid(num_tokens); \ int64_t num_tokens = input.numel() / input.size(-1); \
dim3 block(std::min(d, 1024)); \ if (num_tokens == 0) { \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ return; \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ } \
VLLM_DISPATCH_FLOATING_TYPES( \ dim3 grid(num_tokens); \
input.scalar_type(), "act_and_mul_kernel_with_param", [&] { \ int cc_major = at::cuda::getCurrentDeviceProperties()->major; \
vllm::act_and_mul_kernel_with_param<scalar_t, KERNEL<scalar_t>> \ int support_vec = \
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \ (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) \
input.data_ptr<scalar_t>(), d, \ ? vllm::VecTraits<true>::ARCH_MAX_VEC_SIZE \
PARAM); \ : vllm::VecTraits<false>::ARCH_MAX_VEC_SIZE; \
}); int vec_size = support_vec / at::elementSize(dtype); \
const bool use_vec = (d % vec_size == 0); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
if (use_vec) { \
dim3 block(std::min(d / vec_size, 1024)); \
if (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) { \
VLLM_DISPATCH_FLOATING_TYPES( \
dtype, "act_and_mul_kernel_with_param", [&] { \
vllm::act_and_mul_kernel_with_param< \
scalar_t, typename vllm::PackedTypeConverter<scalar_t>::Type, \
KERNEL<scalar_t>, \
PACKED_KERNEL< \
typename vllm::PackedTypeConverter<scalar_t>::Type>, \
true, true><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d, \
PARAM); \
}); \
} else { \
VLLM_DISPATCH_FLOATING_TYPES( \
dtype, "act_and_mul_kernel_with_param", [&] { \
vllm::act_and_mul_kernel_with_param< \
scalar_t, typename vllm::PackedTypeConverter<scalar_t>::Type, \
KERNEL<scalar_t>, \
PACKED_KERNEL< \
typename vllm::PackedTypeConverter<scalar_t>::Type>, \
true, false><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d, \
PARAM); \
}); \
} \
} else { \
dim3 block(std::min(d, 1024)); \
VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel_with_param", [&] { \
vllm::act_and_mul_kernel_with_param< \
scalar_t, typename vllm::PackedTypeConverter<scalar_t>::Type, \
KERNEL<scalar_t>, \
PACKED_KERNEL<typename vllm::PackedTypeConverter<scalar_t>::Type>, \
false><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), d, PARAM); \
}); \
}
#define LAUNCH_SIGLUOAI_AND_MUL(KERNEL, ALPHA, LIMIT) \ #define LAUNCH_SIGLUOAI_AND_MUL(KERNEL, ALPHA, LIMIT) \
int d = input.size(-1) / 2; \ int d = input.size(-1) / 2; \
...@@ -309,7 +443,8 @@ __global__ void swigluoai_and_mul_kernel( ...@@ -309,7 +443,8 @@ __global__ void swigluoai_and_mul_kernel(
void fatrelu_and_mul(torch::Tensor& out, // [..., d], void fatrelu_and_mul(torch::Tensor& out, // [..., d],
torch::Tensor& input, // [..., 2 * d] torch::Tensor& input, // [..., 2 * d]
double threshold) { double threshold) {
LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(vllm::fatrelu_kernel, threshold); LAUNCH_ACTIVATION_GATE_KERNEL_WITH_PARAM(
vllm::fatrelu_kernel, vllm::packed_fatrelu_kernel, threshold);
} }
void swigluoai_and_mul(torch::Tensor& out, // [..., d] void swigluoai_and_mul(torch::Tensor& out, // [..., d]
torch::Tensor& input, // [..., 2 * d] torch::Tensor& input, // [..., 2 * d]
...@@ -319,39 +454,41 @@ void swigluoai_and_mul(torch::Tensor& out, // [..., d] ...@@ -319,39 +454,41 @@ void swigluoai_and_mul(torch::Tensor& out, // [..., d]
namespace vllm { namespace vllm {
// Element-wise activation kernel template. // Element-wise activation kernel template.
template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)> template <typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&), bool use_vec,
bool use_256b = false>
__global__ void activation_kernel( __global__ void activation_kernel(
scalar_t* __restrict__ out, // [..., d] scalar_t* __restrict__ out, // [..., d]
const scalar_t* __restrict__ input, // [..., d] const scalar_t* __restrict__ input, // [..., d]
const int d) { const int d) {
constexpr int VEC_SIZE = 16 / sizeof(scalar_t); const scalar_t* in_ptr = input + blockIdx.x * d;
const int64_t token_idx = blockIdx.x; scalar_t* out_ptr = out + blockIdx.x * d;
const scalar_t* in_ptr = input + token_idx * d;
scalar_t* out_ptr = out + token_idx * d; if constexpr (use_vec) {
// Fast path: 128-bit/256-bit vectorized loop
// Check alignment for 128-bit vectorized access using vec_t = typename VecTraits<use_256b>::vec_t;
const bool aligned = is_16byte_aligned(in_ptr) && is_16byte_aligned(out_ptr); constexpr int ARCH_MAX_VEC_SIZE = VecTraits<use_256b>::ARCH_MAX_VEC_SIZE;
constexpr int VEC_SIZE = ARCH_MAX_VEC_SIZE / sizeof(scalar_t);
if (aligned && d >= VEC_SIZE) { const vec_t* in_vec = reinterpret_cast<const vec_t*>(in_ptr);
// Fast path: 128-bit vectorized loop vec_t* out_vec = reinterpret_cast<vec_t*>(out_ptr);
const int4* in_vec = reinterpret_cast<const int4*>(in_ptr);
int4* out_vec = reinterpret_cast<int4*>(out_ptr);
const int num_vecs = d / VEC_SIZE; const int num_vecs = d / VEC_SIZE;
const int vec_end = num_vecs * VEC_SIZE;
for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) { for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) {
int4 v = VLLM_LDG(&in_vec[i]), r; vec_t v;
if constexpr (use_256b) {
ld256(v, &in_vec[i]);
} else {
v = VLLM_LDG(&in_vec[i]);
}
auto* vp = reinterpret_cast<scalar_t*>(&v); auto* vp = reinterpret_cast<scalar_t*>(&v);
auto* rp = reinterpret_cast<scalar_t*>(&r);
#pragma unroll #pragma unroll
for (int j = 0; j < VEC_SIZE; j++) { for (int j = 0; j < VEC_SIZE; j++) {
rp[j] = ACT_FN(vp[j]); vp[j] = ACT_FN(vp[j]);
}
if constexpr (use_256b) {
st256(v, &out_vec[i]);
} else {
out_vec[i] = v;
} }
out_vec[i] = r;
}
// Scalar cleanup for remaining elements
for (int i = vec_end + threadIdx.x; i < d; i += blockDim.x) {
out_ptr[i] = ACT_FN(VLLM_LDG(&in_ptr[i]));
} }
} else { } else {
// Scalar fallback for unaligned data or small d // Scalar fallback for unaligned data or small d
...@@ -365,18 +502,46 @@ __global__ void activation_kernel( ...@@ -365,18 +502,46 @@ __global__ void activation_kernel(
} // namespace vllm } // namespace vllm
// Launch element-wise activation kernel. // Launch element-wise activation kernel.
#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \ #define LAUNCH_ACTIVATION_KERNEL(KERNEL) \
int d = input.size(-1); \ auto dtype = input.scalar_type(); \
int64_t num_tokens = input.numel() / d; \ int d = input.size(-1); \
dim3 grid(num_tokens); \ int64_t num_tokens = input.numel() / input.size(-1); \
dim3 block(std::min(d, 1024)); \ if (num_tokens == 0) { \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ return; \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ } \
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "activation_kernel", [&] { \ dim3 grid(num_tokens); \
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>> \ int cc_major = at::cuda::getCurrentDeviceProperties()->major; \
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \ int support_vec = \
input.data_ptr<scalar_t>(), d); \ (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) \
}); ? vllm::VecTraits<true>::ARCH_MAX_VEC_SIZE \
: vllm::VecTraits<false>::ARCH_MAX_VEC_SIZE; \
int vec_size = support_vec / at::elementSize(dtype); \
const bool use_vec = (d % vec_size == 0); \
const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \
if (use_vec) { \
dim3 block(std::min(d / vec_size, 1024)); \
if (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) { \
VLLM_DISPATCH_FLOATING_TYPES(dtype, "activation_kernel", [&] { \
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>, true, true> \
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
}); \
} else { \
VLLM_DISPATCH_FLOATING_TYPES(dtype, "activation_kernel", [&] { \
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>, true, false> \
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
}); \
} \
} else { \
dim3 block(std::min(d, 1024)); \
VLLM_DISPATCH_FLOATING_TYPES(dtype, "activation_kernel", [&] { \
vllm::activation_kernel<scalar_t, KERNEL<scalar_t>, false> \
<<<grid, block, 0, stream>>>(out.data_ptr<scalar_t>(), \
input.data_ptr<scalar_t>(), d); \
}); \
}
namespace vllm { namespace vllm {
......
...@@ -74,6 +74,12 @@ void indexer_k_quant_and_cache( ...@@ -74,6 +74,12 @@ void indexer_k_quant_and_cache(
int64_t quant_block_size, // quantization block size int64_t quant_block_size, // quantization block size
const std::string& scale_fmt); const std::string& scale_fmt);
// Concatenate query nope and rope for MLA/DSA attention
void concat_mla_q(
torch::Tensor& ql_nope, // [num_tokens, num_heads, nope_dim]
torch::Tensor& q_pe, // [num_tokens, num_heads, rope_dim]
torch::Tensor& q_out); // [num_tokens, num_heads, nope_dim + rope_dim]
// Extract function to gather quantized K cache // Extract function to gather quantized K cache
void cp_gather_indexer_k_quant_cache( void cp_gather_indexer_k_quant_cache(
const torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride] const torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride]
......
...@@ -8,6 +8,7 @@ ...@@ -8,6 +8,7 @@
#include "cuda_compat.h" #include "cuda_compat.h"
#include "dispatch_utils.h" #include "dispatch_utils.h"
#include "quantization/vectorization_utils.cuh" #include "quantization/vectorization_utils.cuh"
#include "concat_mla_q.cuh"
#ifdef USE_ROCM #ifdef USE_ROCM
#include "quantization/w8a8/fp8/amd/quant_utils.cuh" #include "quantization/w8a8/fp8/amd/quant_utils.cuh"
...@@ -918,8 +919,8 @@ __global__ void gather_and_maybe_dequant_cache( ...@@ -918,8 +919,8 @@ __global__ void gather_and_maybe_dequant_cache(
// SCALAR_T is the data type of the destination tensor. // SCALAR_T is the data type of the destination tensor.
// CACHE_T is the stored data type of kv-cache. // CACHE_T is the stored data type of kv-cache.
// KV_DTYPE is the real data type of kv-cache. // KV_DTYPE is the real data type of kv-cache.
#define CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE) \ #define CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE, ENTRY_SZ) \
vllm::gather_and_maybe_dequant_cache<SCALAR_T, CACHE_T, KV_DTYPE, 576, \ vllm::gather_and_maybe_dequant_cache<SCALAR_T, CACHE_T, KV_DTYPE, ENTRY_SZ, \
thread_block_size> \ thread_block_size> \
<<<grid, block, 0, stream>>>( \ <<<grid, block, 0, stream>>>( \
reinterpret_cast<CACHE_T*>(src_cache.data_ptr()), \ reinterpret_cast<CACHE_T*>(src_cache.data_ptr()), \
...@@ -930,6 +931,12 @@ __global__ void gather_and_maybe_dequant_cache( ...@@ -930,6 +931,12 @@ __global__ void gather_and_maybe_dequant_cache(
dst_entry_stride, reinterpret_cast<const float*>(scale.data_ptr()), \ dst_entry_stride, reinterpret_cast<const float*>(scale.data_ptr()), \
seq_starts_ptr); seq_starts_ptr);
#define CALL_GATHER_CACHE_576(SCALAR_T, CACHE_T, KV_DTYPE) \
CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE, 576)
#define CALL_GATHER_CACHE_320(SCALAR_T, CACHE_T, KV_DTYPE) \
CALL_GATHER_CACHE(SCALAR_T, CACHE_T, KV_DTYPE, 320)
// Gather sequences from the cache into the destination tensor. // Gather sequences from the cache into the destination tensor.
// - cu_seq_lens contains the cumulative sequence lengths for each batch // - cu_seq_lens contains the cumulative sequence lengths for each batch
// - block_table contains the cache block indices for each sequence // - block_table contains the cache block indices for each sequence
...@@ -959,9 +966,10 @@ void gather_and_maybe_dequant_cache( ...@@ -959,9 +966,10 @@ void gather_and_maybe_dequant_cache(
TORCH_CHECK(seq_starts.value().dtype() == torch::kInt32, TORCH_CHECK(seq_starts.value().dtype() == torch::kInt32,
"seq_starts must be int32"); "seq_starts must be int32");
} }
TORCH_CHECK(head_dim == 576, TORCH_CHECK(
"gather_and_maybe_dequant_cache only support the head_dim to 576 " head_dim == 320 || head_dim == 576,
"for better performance") "gather_and_maybe_dequant_cache only support the head_dim to 320 or 576 "
"for better performance")
TORCH_CHECK(src_cache.device() == dst.device(), TORCH_CHECK(src_cache.device() == dst.device(),
"src_cache and dst must be on the same device"); "src_cache and dst must be on the same device");
...@@ -986,7 +994,13 @@ void gather_and_maybe_dequant_cache( ...@@ -986,7 +994,13 @@ void gather_and_maybe_dequant_cache(
const int32_t* seq_starts_ptr = const int32_t* seq_starts_ptr =
seq_starts.has_value() ? seq_starts.value().data_ptr<int32_t>() : nullptr; seq_starts.has_value() ? seq_starts.value().data_ptr<int32_t>() : nullptr;
DISPATCH_BY_KV_CACHE_DTYPE(dst.dtype(), kv_cache_dtype, CALL_GATHER_CACHE); if (head_dim == 576) {
DISPATCH_BY_KV_CACHE_DTYPE(dst.dtype(), kv_cache_dtype,
CALL_GATHER_CACHE_576);
} else {
DISPATCH_BY_KV_CACHE_DTYPE(dst.dtype(), kv_cache_dtype,
CALL_GATHER_CACHE_320);
}
} }
namespace vllm { namespace vllm {
...@@ -995,75 +1009,67 @@ namespace vllm { ...@@ -995,75 +1009,67 @@ namespace vllm {
// Similar to cp_gather_cache but specifically for FP8->BF16 conversion // Similar to cp_gather_cache but specifically for FP8->BF16 conversion
__global__ void cp_gather_and_upconvert_fp8_kv_cache( __global__ void cp_gather_and_upconvert_fp8_kv_cache(
const uint8_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656] const uint8_t* __restrict__ src_cache, // [NUM_BLOCKS, BLOCK_SIZE, 656]
__nv_bfloat16* __restrict__ dst, // [TOT_TOKENS, 576] __nv_bfloat16* __restrict__ dst, // [total_tokens, 576]
const int32_t* __restrict__ block_table, // [BATCH, BLOCK_INDICES] const int32_t* __restrict__ block_table, // [num_reqs, BLOCK_INDICES]
const int32_t* __restrict__ seq_lens, // [BATCH] const int32_t* __restrict__ workspace_starts, // [num_reqs]
const int32_t* __restrict__ workspace_starts, // [BATCH] const int32_t num_reqs, const int32_t block_size,
const int32_t block_size, const int32_t head_dim, const int32_t total_tokens, const int64_t block_table_stride,
const int64_t block_table_stride, const int64_t cache_block_stride, const int64_t cache_block_stride, const int64_t cache_entry_stride,
const int64_t cache_entry_stride, const int64_t dst_entry_stride) { const int64_t dst_entry_stride) {
const int64_t bid = blockIdx.x; // Batch ID const int flat_warp_id = (blockIdx.x * blockDim.x + threadIdx.x) >> 5;
const int32_t num_splits = gridDim.y; if (flat_warp_id >= total_tokens) return;
const int32_t split = blockIdx.y; const int lane_id = threadIdx.x & 31;
const int32_t seq_start = workspace_starts[bid];
const int32_t seq_len = seq_lens[bid]; // Binary search to find which request owns this output token
const int32_t tot_slots = seq_len; int lo = 0, hi = num_reqs - 1;
const int32_t split_slots = cuda_utils::ceil_div(tot_slots, num_splits); while (lo < hi) {
int mid = (lo + hi + 1) >> 1;
if (workspace_starts[mid] <= flat_warp_id)
lo = mid;
else
hi = mid - 1;
}
const int req_id = lo;
const int32_t split_start = split * split_slots; // Compute physical token address via block table
const int32_t split_end = min((split + 1) * split_slots, tot_slots); const int out_token_id = flat_warp_id;
const int token_offset = out_token_id - workspace_starts[req_id];
const int cache_block_idx = token_offset / block_size;
const int offset_in_block = token_offset % block_size;
const int physical_block =
block_table[req_id * block_table_stride + cache_block_idx];
const bool is_active_split = (split_start < tot_slots); const uint8_t* token_ptr = src_cache + physical_block * cache_block_stride +
offset_in_block * cache_entry_stride;
if (!is_active_split) return; const int4* nope_src = reinterpret_cast<const int4*>(token_ptr);
const int4 fp8_data = nope_src[lane_id];
// Adjust the pointer for the block_table for this batch const float* scales_ptr = reinterpret_cast<const float*>(token_ptr + 512);
const int32_t batch_offset = bid * block_table_stride; const float scale = scales_ptr[lane_id >> 3];
int32_t offset = split_start;
int32_t offset_div = offset / block_size;
offset = offset % block_size;
const int32_t* batch_block_table = block_table + batch_offset;
// Adjust dst pointer based on the cumulative sequence lengths const uint2 fp8_lo = make_uint2(fp8_data.x, fp8_data.y);
dst += seq_start * dst_entry_stride; const uint2 fp8_hi = make_uint2(fp8_data.z, fp8_data.w);
#ifdef USE_ROCM
const bf16_8_t bf16_lo =
fp8::scaled_vec_conversion<bf16_8_t, uint2>(fp8_lo, scale);
const bf16_8_t bf16_hi =
fp8::scaled_vec_conversion<bf16_8_t, uint2>(fp8_hi, scale);
#else
const bf16_8_t bf16_lo =
fp8::scaled_vec_conversion<bf16_8_t, uint2>(fp8_lo, scale, __NV_E4M3);
const bf16_8_t bf16_hi =
fp8::scaled_vec_conversion<bf16_8_t, uint2>(fp8_hi, scale, __NV_E4M3);
#endif
const int tid = threadIdx.x; __nv_bfloat16* dst_ptr = dst + out_token_id * dst_entry_stride;
int4* nope_dst = reinterpret_cast<int4*>(dst_ptr) + lane_id * 2;
nope_dst[0] = *reinterpret_cast<const int4*>(&bf16_lo);
nope_dst[1] = *reinterpret_cast<const int4*>(&bf16_hi);
// Process each token in this split const int* rope_src = reinterpret_cast<const int*>(token_ptr + 528);
for (int pid = split_start; pid < split_end; ++pid) { int* rope_dst = reinterpret_cast<int*>(dst_ptr + 512);
auto block_id = batch_block_table[offset_div]; rope_dst[lane_id] = rope_src[lane_id];
const uint8_t* token_ptr =
src_cache + block_id * cache_block_stride + offset * cache_entry_stride;
__nv_bfloat16* dst_ptr = dst + pid * dst_entry_stride;
// FP8 format: 512 bytes fp8 + 16 bytes scales + 128 bytes rope (64 bf16)
const uint8_t* no_pe_ptr = token_ptr;
const float* scales_ptr = reinterpret_cast<const float*>(token_ptr + 512);
const __nv_bfloat16* rope_ptr =
reinterpret_cast<const __nv_bfloat16*>(token_ptr + 512 + 16);
// Parallelize fp8 dequant (512 elements) and rope copy (64 elements)
if (tid < 512) {
// FP8 dequantization
const int tile = tid >> 7; // each tile is 128 elements
const float scale = scales_ptr[tile];
const uint8_t val = no_pe_ptr[tid];
dst_ptr[tid] =
fp8::scaled_convert<__nv_bfloat16, uint8_t,
vllm::Fp8KVCacheDataType::kFp8E4M3>(val, scale);
} else if (tid < 576) {
// Rope copy (64 bf16 elements)
const int rope_idx = tid - 512;
dst_ptr[512 + rope_idx] = rope_ptr[rope_idx];
}
// Move to next token
offset += 1;
if (offset == block_size) {
offset_div += 1;
offset = 0;
}
}
} }
template <typename scalar_t> template <typename scalar_t>
...@@ -1234,8 +1240,13 @@ void cp_gather_and_upconvert_fp8_kv_cache( ...@@ -1234,8 +1240,13 @@ void cp_gather_and_upconvert_fp8_kv_cache(
"src_cache and seq_lens must be on the same device"); "src_cache and seq_lens must be on the same device");
TORCH_CHECK(src_cache.device() == workspace_starts.device(), TORCH_CHECK(src_cache.device() == workspace_starts.device(),
"src_cache and workspace_starts must be on the same device"); "src_cache and workspace_starts must be on the same device");
auto dtype = src_cache.scalar_type();
TORCH_CHECK(src_cache.dtype() == torch::kUInt8, "src_cache must be uint8"); TORCH_CHECK(
dtype == at::ScalarType::Byte || // uint8
dtype == at::ScalarType::Float8_e4m3fn || // fp8 e4m3
dtype == at::ScalarType::Float8_e5m2, // fp8 e5m2
"src_cache must be uint8, float8_e4m3fn, or float8_e5m2, but got ",
src_cache.dtype());
TORCH_CHECK(dst.dtype() == torch::kBFloat16, "dst must be bfloat16"); TORCH_CHECK(dst.dtype() == torch::kBFloat16, "dst must be bfloat16");
TORCH_CHECK(head_dim == 576, "head_dim must be 576 for MLA"); TORCH_CHECK(head_dim == 576, "head_dim must be 576 for MLA");
...@@ -1244,16 +1255,24 @@ void cp_gather_and_upconvert_fp8_kv_cache( ...@@ -1244,16 +1255,24 @@ void cp_gather_and_upconvert_fp8_kv_cache(
int64_t cache_entry_stride = src_cache.stride(1); int64_t cache_entry_stride = src_cache.stride(1);
int64_t dst_entry_stride = dst.stride(0); int64_t dst_entry_stride = dst.stride(0);
// Decide on the number of splits based on the batch size const uint8_t* src_ptr = nullptr;
int num_splits = batch_size > 128 ? 2 : batch_size > 64 ? 4 : 16; if (dtype == at::ScalarType::Byte) {
dim3 grid(batch_size, num_splits); src_ptr = src_cache.data_ptr<uint8_t>();
dim3 block(576); } else {
// float8_e4m3fn or float8_e5m2
src_ptr = reinterpret_cast<const uint8_t*>(src_cache.data_ptr());
}
const int total_tokens = dst.size(0);
constexpr int warps_per_block = 8;
const int grid_size = (total_tokens + warps_per_block - 1) / warps_per_block;
const int block_size_threads = warps_per_block * 32; // 256 threads
vllm::cp_gather_and_upconvert_fp8_kv_cache<<<grid, block, 0, stream>>>( vllm::cp_gather_and_upconvert_fp8_kv_cache<<<grid_size, block_size_threads, 0,
src_cache.data_ptr<uint8_t>(), stream>>>(
reinterpret_cast<__nv_bfloat16*>(dst.data_ptr()), src_ptr, reinterpret_cast<__nv_bfloat16*>(dst.data_ptr()),
block_table.data_ptr<int32_t>(), seq_lens.data_ptr<int32_t>(), block_table.data_ptr<int32_t>(), workspace_starts.data_ptr<int32_t>(),
workspace_starts.data_ptr<int32_t>(), block_size, head_dim, static_cast<int32_t>(batch_size), block_size, total_tokens,
block_table_stride, cache_block_stride, cache_entry_stride, block_table_stride, cache_block_stride, cache_entry_stride,
dst_entry_stride); dst_entry_stride);
} }
...@@ -1293,7 +1312,8 @@ void indexer_k_quant_and_cache( ...@@ -1293,7 +1312,8 @@ void indexer_k_quant_and_cache(
const at::cuda::OptionalCUDAGuard device_guard(device_of(k)); const at::cuda::OptionalCUDAGuard device_guard(device_of(k));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
DISPATCH_BY_KV_CACHE_DTYPE(k.dtype(), "fp8_e4m3", static const std::string kv_cache_dtype = "fp8_e4m3";
DISPATCH_BY_KV_CACHE_DTYPE(k.dtype(), kv_cache_dtype,
CALL_INDEXER_K_QUANT_AND_CACHE); CALL_INDEXER_K_QUANT_AND_CACHE);
} }
...@@ -1352,3 +1372,43 @@ void cp_gather_indexer_k_quant_cache( ...@@ -1352,3 +1372,43 @@ void cp_gather_indexer_k_quant_cache(
CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(32); CALL_CP_GATHER_INDEXER_K_QUANT_CACHE(32);
} }
} }
// Concatenate ql_nope and q_pe into a contiguous q_out tensor for MLA/DSA.
// Replaces torch.cat((ql_nope, q_pe), dim=-1).
void concat_mla_q(torch::Tensor& ql_nope, // [num_tokens, num_heads, nope_dim]
torch::Tensor& q_pe, // [num_tokens, num_heads, rope_dim]
torch::Tensor& q_out // [num_tokens, num_heads, nope_dim +
// rope_dim]
) {
const int num_tokens = ql_nope.size(0);
const int num_heads = ql_nope.size(1);
const int nope_dim = ql_nope.size(2);
const int rope_dim = q_pe.size(2);
TORCH_CHECK(nope_dim % 512 == 0, "nope_dim must be a multiple of 512, got ",
nope_dim);
TORCH_CHECK(rope_dim == 64, "rope_dim must be 64, got ", rope_dim);
TORCH_CHECK(q_out.size(2) == nope_dim + rope_dim);
TORCH_CHECK(ql_nope.stride(2) == 1, "ql_nope must have stride 1 in dim 2");
TORCH_CHECK(q_pe.stride(2) == 1, "q_pe must have stride 1 in dim 2");
TORCH_CHECK(q_out.stride(2) == 1, "q_out must have stride 1 in dim 2");
if (num_tokens == 0) return;
constexpr int warps_per_block = 8;
const int total_warps = num_tokens * num_heads;
const int grid_size = (total_warps + warps_per_block - 1) / warps_per_block;
const int block_size = warps_per_block * 32;
const at::cuda::OptionalCUDAGuard device_guard(device_of(ql_nope));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(ql_nope.scalar_type(), "concat_mla_q", [&] {
vllm::ConcatMLAQKernel<scalar_t, 512><<<grid_size, block_size, 0, stream>>>(
q_out.data_ptr<scalar_t>(), ql_nope.data_ptr<scalar_t>(),
q_pe.data_ptr<scalar_t>(), num_tokens, num_heads, q_out.stride(0),
q_out.stride(1), ql_nope.stride(0), ql_nope.stride(1), q_pe.stride(0),
q_pe.stride(1));
});
}
#ifndef CONCAT_MLA_Q_CUH_
#define CONCAT_MLA_Q_CUH_
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include "cuda_vec_utils.cuh"
namespace vllm {
// Concatenates ql_nope [num_tokens, num_heads, NOPE_DIM] and
// q_pe [num_tokens, num_heads, 64]
// into q_out [num_tokens, num_heads, NOPE_DIM+64].
// Currently instantiated only for NOPE_DIM=512.
// Rope dim is hardcoded to 64 (DeepSeek V3.2 MLA)
template <typename DType, int NOPE_DIM>
__global__ void ConcatMLAQKernel(
DType* __restrict__ q_out, const DType* __restrict__ ql_nope,
const DType* __restrict__ q_pe, const int num_tokens, const int num_heads,
const int64_t out_stride_0, const int64_t out_stride_1,
const int64_t nope_stride_0, const int64_t nope_stride_1,
const int64_t pe_stride_0, const int64_t pe_stride_1) {
const int flat_warp_id = (blockIdx.x * blockDim.x + threadIdx.x) >> 5;
if (flat_warp_id >= num_tokens * num_heads) return;
const int token_id = flat_warp_id / num_heads;
const int head_id = flat_warp_id % num_heads;
const int lane_id = threadIdx.x & 31;
constexpr bool use_256b = VLLM_256B_PTX_ENABLED;
constexpr int nope_vec_loads =
NOPE_DIM * sizeof(DType) / (VecTraits<use_256b>::ARCH_MAX_VEC_SIZE * 32);
const DType* nope_src =
ql_nope + token_id * nope_stride_0 + head_id * nope_stride_1;
DType* nope_dst = q_out + token_id * out_stride_0 + head_id * out_stride_1;
#pragma unroll
for (int i = 0; i < nope_vec_loads; i++) {
const int offset = i * 32 + lane_id;
if constexpr (use_256b) {
st256_cs(reinterpret_cast<u32x8_t*>(nope_dst) + offset,
ld256_cs(reinterpret_cast<const u32x8_t*>(nope_src) + offset));
} else {
st128_cs(reinterpret_cast<int4*>(nope_dst) + offset,
ld128_cs(reinterpret_cast<const int4*>(nope_src) + offset));
}
}
const int* rope_src = reinterpret_cast<const int*>(
q_pe + token_id * pe_stride_0 + head_id * pe_stride_1);
int* rope_dst = reinterpret_cast<int*>(q_out + token_id * out_stride_0 +
head_id * out_stride_1 + NOPE_DIM);
st32_cs(rope_dst + lane_id, ld32_cs(rope_src + lane_id));
}
} // namespace vllm
#endif // CONCAT_MLA_Q_CUH_
...@@ -16,6 +16,8 @@ torch::Tensor get_scheduler_metadata( ...@@ -16,6 +16,8 @@ torch::Tensor get_scheduler_metadata(
isa = cpu_attention::ISA::VEC16; isa = cpu_attention::ISA::VEC16;
} else if (isa_hint == "neon") { } else if (isa_hint == "neon") {
isa = cpu_attention::ISA::NEON; isa = cpu_attention::ISA::NEON;
} else if (isa_hint == "vxe") {
isa = cpu_attention::ISA::VXE;
} else { } else {
TORCH_CHECK(false, "Unsupported CPU attention ISA hint: " + isa_hint); TORCH_CHECK(false, "Unsupported CPU attention ISA hint: " + isa_hint);
} }
...@@ -100,6 +102,8 @@ void cpu_attn_reshape_and_cache( ...@@ -100,6 +102,8 @@ void cpu_attn_reshape_and_cache(
return cpu_attention::ISA::VEC16; return cpu_attention::ISA::VEC16;
} else if (isa == "neon") { } else if (isa == "neon") {
return cpu_attention::ISA::NEON; return cpu_attention::ISA::NEON;
} else if (isa == "vxe") {
return cpu_attention::ISA::VXE;
} else { } else {
TORCH_CHECK(false, "Invalid ISA type: " + isa); TORCH_CHECK(false, "Invalid ISA type: " + isa);
} }
......
...@@ -420,7 +420,7 @@ class AttentionImpl<ISA::AMX, scalar_t, head_dim> { ...@@ -420,7 +420,7 @@ class AttentionImpl<ISA::AMX, scalar_t, head_dim> {
const int64_t block_size, const int64_t block_size_stride) { const int64_t block_size, const int64_t block_size_stride) {
// For AMX 2D tiles, size of each line is 64 bytes // For AMX 2D tiles, size of each line is 64 bytes
constexpr int64_t amx_tile_row_size = AMX_TILE_ROW_BYTES; constexpr int64_t amx_tile_row_size = AMX_TILE_ROW_BYTES;
// For AMX B martix, N always is 16 // For AMX B matrix, N always is 16
constexpr int64_t amx_b_tile_n_size = AMX_TILE_ROW_BYTES / 4; constexpr int64_t amx_b_tile_n_size = AMX_TILE_ROW_BYTES / 4;
constexpr int64_t amx_b_tile_k_size = amx_tile_row_size / sizeof(scalar_t); constexpr int64_t amx_b_tile_k_size = amx_tile_row_size / sizeof(scalar_t);
// For now suppose block_size is divisible by amx_tile_column_num // For now suppose block_size is divisible by amx_tile_column_num
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
#include "cpu/utils.hpp" #include "cpu/utils.hpp"
namespace cpu_attention { namespace cpu_attention {
enum class ISA { AMX, VEC, VEC16, NEON }; enum class ISA { AMX, VEC, VEC16, NEON, VXE };
template <ISA isa, typename scalar_t, int64_t head_dim> template <ISA isa, typename scalar_t, int64_t head_dim>
class AttentionImpl {}; class AttentionImpl {};
...@@ -821,7 +821,7 @@ struct VecTypeTrait<c10::BFloat16> { ...@@ -821,7 +821,7 @@ struct VecTypeTrait<c10::BFloat16> {
using vec_t = vec_op::BF16Vec16; using vec_t = vec_op::BF16Vec16;
}; };
#if !defined(__powerpc__) && !defined(__s390x__) #if !defined(__powerpc__)
template <> template <>
struct VecTypeTrait<c10::Half> { struct VecTypeTrait<c10::Half> {
using vec_t = vec_op::FP16Vec16; using vec_t = vec_op::FP16Vec16;
......
#ifndef CPU_ATTN_VXE_HPP
#define CPU_ATTN_VXE_HPP
#include "cpu_attn_impl.hpp"
#include <vecintrin.h>
#include <type_traits>
namespace cpu_attention {
namespace {
// s390x Vector = 16 bytes (128 bits)
#define BLOCK_SIZE_ALIGNMENT 32
#define HEAD_SIZE_ALIGNMENT 32
#define MAX_Q_HEAD_NUM_PER_ITER 16
template <typename kv_cache_t>
FORCE_INLINE void load_row8_B_as_f32(const kv_cache_t* p, __vector float& b0,
__vector float& b1);
// [1] Float Specialization
template <>
FORCE_INLINE void load_row8_B_as_f32<float>(const float* p, __vector float& b0,
__vector float& b1) {
// Explicitly cast to long long for offset, and float* for pointer
b0 = vec_xl((long long)0, const_cast<float*>(p));
b1 = vec_xl((long long)0, const_cast<float*>(p + 4));
}
// [2] BFloat16 Specialization (Big Endian Fix)
template <>
FORCE_INLINE void load_row8_B_as_f32<c10::BFloat16>(const c10::BFloat16* p,
__vector float& b0,
__vector float& b1) {
// 1. Load 8 BF16s (16 bytes) into one vector
// Explicit cast to unsigned short* for vec_xl to return vector unsigned short
__vector unsigned short raw = vec_xl((long long)0, (unsigned short*)p);
// 2. Prepare Zero vector
__vector unsigned short zeros = vec_splat_u16(0);
// 3. Merge High/Low to expand BF16 -> Float32
// On Big Endian, a float is [BF16_bits | 16_zero_bits]
b0 = (__vector float)vec_mergeh(raw, zeros);
b1 = (__vector float)vec_mergel(raw, zeros);
}
template <>
FORCE_INLINE void load_row8_B_as_f32<c10::Half>(const c10::Half* p,
__vector float& b0,
__vector float& b1) {
alignas(16) float tmp[8];
// Manual unroll / conversion
tmp[0] = static_cast<float>(p[0]);
tmp[1] = static_cast<float>(p[1]);
tmp[2] = static_cast<float>(p[2]);
tmp[3] = static_cast<float>(p[3]);
tmp[4] = static_cast<float>(p[4]);
tmp[5] = static_cast<float>(p[5]);
tmp[6] = static_cast<float>(p[6]);
tmp[7] = static_cast<float>(p[7]);
// Explicit arguments for intrinsic: (long long offset, float* ptr)
b0 = vec_xl((long long)0, (float*)tmp);
b1 = vec_xl((long long)0, (float*)(tmp + 4));
}
template <int32_t M, typename kv_cache_t>
FORCE_INLINE void gemm_micro_s390x_Mx8_Ku4(
const float* __restrict A, // [M x K]
const kv_cache_t* __restrict B, // [K x 8]
float* __restrict C, // [M x 8]
int64_t lda, int64_t ldb, int64_t ldc, int32_t K, bool accumulate) {
static_assert(1 <= M && M <= 8, "M must be in [1,8]");
// Helper macros to unroll codegen for M rows
#define ROWS_APPLY(OP) OP(0) OP(1) OP(2) OP(3) OP(4) OP(5) OP(6) OP(7)
#define IF_M(i) if constexpr (M > (i))
// 1. Define A pointers
#define DECL_A(i) const float* a##i = A + (i) * lda;
ROWS_APPLY(DECL_A)
#undef DECL_A
// 2. Define Accumulators (2 vectors covers 8 columns)
#define DECL_ACC(i) __vector float acc##i##_0, acc##i##_1;
ROWS_APPLY(DECL_ACC)
#undef DECL_ACC
// 3. Initialize Accumulators (Load C or Zero)
#define INIT_ACC(i) \
IF_M(i) { \
if (accumulate) { \
acc##i##_0 = \
vec_xl((long long)0, const_cast<float*>(C + (i) * ldc + 0)); \
acc##i##_1 = \
vec_xl((long long)0, const_cast<float*>(C + (i) * ldc + 4)); \
} else { \
acc##i##_0 = vec_splats(0.0f); \
acc##i##_1 = vec_splats(0.0f); \
} \
}
ROWS_APPLY(INIT_ACC)
#undef INIT_ACC
int32_t k = 0;
for (; k + 3 < K; k += 4) {
// Load 4 values of A for each Row M: A[k...k+3]
#define LOAD_A4(i) \
__vector float a##i##v; \
IF_M(i) a##i##v = vec_xl((long long)0, const_cast<float*>(a##i + k));
ROWS_APPLY(LOAD_A4)
#undef LOAD_A4
// Helper: FMA for specific lane L of A
// s390x: vec_madd(b, vec_splat(a, lane), acc)
#define FMAS_LANE(i, aiv, L) \
IF_M(i) { \
__vector float a_broad = vec_splat(aiv, L); \
acc##i##_0 = vec_madd(b0, a_broad, acc##i##_0); \
acc##i##_1 = vec_madd(b1, a_broad, acc##i##_1); \
}
// Unroll K=0..3
{
__vector float b0, b1;
load_row8_B_as_f32<kv_cache_t>(B + (int64_t)(k + 0) * ldb, b0, b1);
#define STEP_K0(i) FMAS_LANE(i, a##i##v, 0)
ROWS_APPLY(STEP_K0)
#undef STEP_K0
}
{
__vector float b0, b1;
load_row8_B_as_f32<kv_cache_t>(B + (int64_t)(k + 1) * ldb, b0, b1);
#define STEP_K1(i) FMAS_LANE(i, a##i##v, 1)
ROWS_APPLY(STEP_K1)
#undef STEP_K1
}
{
__vector float b0, b1;
load_row8_B_as_f32<kv_cache_t>(B + (int64_t)(k + 2) * ldb, b0, b1);
#define STEP_K2(i) FMAS_LANE(i, a##i##v, 2)
ROWS_APPLY(STEP_K2)
#undef STEP_K2
}
{
__vector float b0, b1;
load_row8_B_as_f32<kv_cache_t>(B + (int64_t)(k + 3) * ldb, b0, b1);
#define STEP_K3(i) FMAS_LANE(i, a##i##v, 3)
ROWS_APPLY(STEP_K3)
#undef STEP_K3
}
#undef FMAS_LANE
}
for (; k < K; ++k) {
__vector float b0, b1;
load_row8_B_as_f32<kv_cache_t>(B + (int64_t)k * ldb, b0, b1);
#define TAIL_ROW(i) \
IF_M(i) { \
__vector float ai = vec_splats(*(a##i + k)); \
acc##i##_0 = vec_madd(b0, ai, acc##i##_0); \
acc##i##_1 = vec_madd(b1, ai, acc##i##_1); \
}
ROWS_APPLY(TAIL_ROW)
#undef TAIL_ROW
}
#define STORE_ROW(i) \
IF_M(i) { \
vec_xst(acc##i##_0, 0, C + (i) * ldc + 0); \
vec_xst(acc##i##_1, 0, C + (i) * ldc + 4); \
}
ROWS_APPLY(STORE_ROW)
#undef STORE_ROW
#undef ROWS_APPLY
#undef IF_M
}
template <int32_t N, typename kv_cache_t>
FORCE_INLINE void gemm_macro_s390x_Mx8_Ku4(const float* __restrict A,
const kv_cache_t* __restrict B,
float* __restrict C, int32_t M,
int32_t K, int64_t lda, int64_t ldb,
int64_t ldc, bool accumulate) {
static_assert(N % 8 == 0, "N must be a multiple of 8");
for (int32_t m = 0; m < M;) {
int32_t mb = (M - m >= 8) ? 8 : (M - m >= 4) ? 4 : (M - m >= 2) ? 2 : 1;
const float* Ab = A + m * lda;
float* Cb = C + m * ldc;
for (int32_t n = 0; n < N; n += 8) {
const kv_cache_t* Bn = B + n;
float* Cn = Cb + n;
switch (mb) {
case 8:
gemm_micro_s390x_Mx8_Ku4<8, kv_cache_t>(Ab, Bn, Cn, lda, ldb, ldc, K,
accumulate);
break;
case 4:
gemm_micro_s390x_Mx8_Ku4<4, kv_cache_t>(Ab, Bn, Cn, lda, ldb, ldc, K,
accumulate);
break;
case 2:
gemm_micro_s390x_Mx8_Ku4<2, kv_cache_t>(Ab, Bn, Cn, lda, ldb, ldc, K,
accumulate);
break;
default:
gemm_micro_s390x_Mx8_Ku4<1, kv_cache_t>(Ab, Bn, Cn, lda, ldb, ldc, K,
accumulate);
break;
}
}
m += mb;
}
}
template <typename kv_cache_t>
class TileGemmS390X {
public:
template <AttentionGemmPhase phase, int32_t k_size>
FORCE_INLINE static void gemm(const int32_t m_size,
float* __restrict__ a_tile,
kv_cache_t* __restrict__ b_tile,
float* __restrict__ c_tile, const int64_t lda,
const int64_t ldb, const int64_t ldc,
const int32_t block_size,
const int32_t dynamic_k_size,
const bool accum_c) {
if constexpr (phase == AttentionGemmPhase::QK) {
gemm_macro_s390x_Mx8_Ku4<BLOCK_SIZE_ALIGNMENT, kv_cache_t>(
a_tile, b_tile, c_tile, m_size, k_size, lda, ldb, ldc, accum_c);
} else {
gemm_macro_s390x_Mx8_Ku4<HEAD_SIZE_ALIGNMENT, kv_cache_t>(
a_tile, b_tile, c_tile, m_size, dynamic_k_size, lda, ldb, ldc,
accum_c);
}
}
};
} // namespace
template <typename scalar_t, int64_t head_dim>
class AttentionImpl<ISA::VXE, scalar_t, head_dim> {
public:
using query_t = scalar_t;
using q_buffer_t = float;
using kv_cache_t = scalar_t;
using logits_buffer_t = float;
using partial_output_buffer_t = float;
using prob_buffer_t = float;
constexpr static int64_t BlockSizeAlignment = BLOCK_SIZE_ALIGNMENT;
constexpr static int64_t HeadDimAlignment = HEAD_SIZE_ALIGNMENT;
constexpr static int64_t MaxQHeadNumPerIteration = MAX_Q_HEAD_NUM_PER_ITER;
constexpr static int64_t HeadDim = head_dim;
constexpr static ISA ISAType = ISA::VXE;
constexpr static bool scale_on_logits =
false; // Scale is applied to Q during copy
public:
AttentionImpl() {}
template <template <typename tile_gemm_t> typename attention>
FORCE_INLINE void execute_attention(DEFINE_CPU_ATTENTION_PARAMS) {
attention<TileGemmS390X<kv_cache_t>> attention_iteration;
attention_iteration(CPU_ATTENTION_PARAMS);
}
// Strides for Memory Layout
constexpr static int64_t k_cache_token_group_stride(
const int32_t block_size) {
return BlockSizeAlignment; // [head_dim, block_size] layout
}
constexpr static int64_t v_cache_token_group_stride(
const int32_t block_size) {
return head_dim * BlockSizeAlignment;
}
constexpr static int64_t v_cache_head_group_stride(const int32_t block_size) {
return HeadDimAlignment;
}
static void copy_q_heads_tile(scalar_t* __restrict__ src,
float* __restrict__ q_buffer,
const int32_t q_num,
const int32_t q_heads_per_kv,
const int64_t q_num_stride,
const int64_t q_head_stride, float scale) {
__vector float scale_vec = vec_splats(scale);
constexpr bool is_bf16 = std::is_same<scalar_t, c10::BFloat16>::value;
// Process 8 elements at a time (32 bytes of float output)
for (int32_t i = 0; i < q_num; ++i) {
for (int32_t h = 0; h < q_heads_per_kv; ++h) {
scalar_t* curr_src = src + i * q_num_stride + h * q_head_stride;
float* curr_dst =
q_buffer + i * q_heads_per_kv * head_dim + h * head_dim;
int32_t d = 0;
for (; d <= head_dim - 8; d += 8) {
if constexpr (is_bf16) {
__vector float v0, v1;
// Reuse our Big-Endian-Safe loader
load_row8_B_as_f32<scalar_t>(curr_src + d, v0, v1);
v0 = vec_mul(v0, scale_vec);
v1 = vec_mul(v1, scale_vec);
vec_xst(v0, 0, curr_dst + d);
vec_xst(v1, 0, curr_dst + d + 4);
} else {
__vector float v0 = vec_xl((long long)0, (float*)curr_src + d);
__vector float v1 = vec_xl((long long)0, (float*)curr_src + d + 4);
v0 = vec_mul(v0, scale_vec);
v1 = vec_mul(v1, scale_vec);
vec_xst(v0, 0, curr_dst + d);
vec_xst(v1, 0, curr_dst + d + 4);
}
}
for (; d < head_dim; ++d) {
float val = static_cast<float>(curr_src[d]);
curr_dst[d] = val * scale;
}
}
}
}
static void reshape_and_cache(
const scalar_t* __restrict__ key, const scalar_t* __restrict__ value,
scalar_t* __restrict__ key_cache, scalar_t* __restrict__ value_cache,
const int64_t* __restrict__ slot_mapping, const int64_t token_num,
const int64_t key_token_num_stride, const int64_t value_token_num_stride,
const int64_t head_num, const int64_t key_head_num_stride,
const int64_t value_head_num_stride, const int64_t num_blocks,
const int64_t num_blocks_stride, const int64_t cache_head_num_stride,
const int64_t block_size, const int64_t block_size_stride) {
#pragma omp parallel for collapse(2)
for (int64_t token_idx = 0; token_idx < token_num; ++token_idx) {
for (int64_t head_idx = 0; head_idx < head_num; ++head_idx) {
const int64_t pos = slot_mapping[token_idx];
if (pos < 0) continue;
const int64_t block_idx = pos / block_size;
const int64_t block_offset = pos % block_size;
{
const scalar_t* key_src = key + token_idx * key_token_num_stride +
head_idx * key_head_num_stride;
scalar_t* key_dst = key_cache + block_idx * num_blocks_stride +
head_idx * cache_head_num_stride + block_offset;
for (int64_t i = 0, j = 0; i < head_dim; ++i, j += block_size) {
key_dst[j] = key_src[i];
}
}
{
const scalar_t* val_src = value + token_idx * value_token_num_stride +
head_idx * value_head_num_stride;
scalar_t* val_dst = value_cache + block_idx * num_blocks_stride +
head_idx * cache_head_num_stride +
block_offset * head_dim;
std::memcpy(val_dst, val_src, sizeof(scalar_t) * head_dim);
}
}
}
}
};
} // namespace cpu_attention
#undef BLOCK_SIZE_ALIGNMENT
#undef HEAD_SIZE_ALIGNMENT
#undef MAX_Q_HEAD_NUM_PER_ITER
#endif
\ No newline at end of file
...@@ -147,7 +147,7 @@ void fused_moe_impl(scalar_t* __restrict__ output, scalar_t* __restrict__ input, ...@@ -147,7 +147,7 @@ void fused_moe_impl(scalar_t* __restrict__ output, scalar_t* __restrict__ input,
const int32_t token_num, const int32_t expert_num, const int32_t token_num, const int32_t expert_num,
const int32_t topk_num, const int32_t input_size_13, const int32_t topk_num, const int32_t input_size_13,
const int32_t output_size_13, const int32_t input_size_2, const int32_t output_size_13, const int32_t input_size_2,
const int32_t output_size_2) { const int32_t output_size_2, const bool skip_weighted) {
using scalar_vec_t = typename cpu_utils::VecTypeTrait<scalar_t>::vec_t; using scalar_vec_t = typename cpu_utils::VecTypeTrait<scalar_t>::vec_t;
constexpr int32_t gemm_n_tile_size = gemm_t::NSize; constexpr int32_t gemm_n_tile_size = gemm_t::NSize;
constexpr int32_t gemm_m_tile_size = gemm_t::MaxMSize; constexpr int32_t gemm_m_tile_size = gemm_t::MaxMSize;
...@@ -582,6 +582,11 @@ void fused_moe_impl(scalar_t* __restrict__ output, scalar_t* __restrict__ input, ...@@ -582,6 +582,11 @@ void fused_moe_impl(scalar_t* __restrict__ output, scalar_t* __restrict__ input,
scalar_t* __restrict__ curr_output_buffer = scalar_t* __restrict__ curr_output_buffer =
output + token_id * output_size_2; output + token_id * output_size_2;
if (skip_weighted) {
// Only for topk_num == 1
*curr_weight = 1.0f;
}
if (topk_num > 1) { if (topk_num > 1) {
{ {
int32_t w2_output_idx = curr_expand_token_id_index_buffer[0]; int32_t w2_output_idx = curr_expand_token_id_index_buffer[0];
...@@ -699,7 +704,7 @@ void cpu_fused_moe( ...@@ -699,7 +704,7 @@ void cpu_fused_moe(
const std::optional<torch::Tensor>& w2_bias, // [expert_num, output_size_2] const std::optional<torch::Tensor>& w2_bias, // [expert_num, output_size_2]
const torch::Tensor& topk_weights, // [token_num, k], float32 const torch::Tensor& topk_weights, // [token_num, k], float32
const torch::Tensor& topk_id, // [token_num, k], int32 const torch::Tensor& topk_id, // [token_num, k], int32
const std::string& act, const std::string& isa) { const bool skip_weighted, const std::string& act, const std::string& isa) {
const int32_t token_num = input.size(0); const int32_t token_num = input.size(0);
const int32_t input_size_13 = input.size(1); const int32_t input_size_13 = input.size(1);
const int64_t input_stride = input.stride(0); const int64_t input_stride = input.stride(0);
...@@ -711,6 +716,8 @@ void cpu_fused_moe( ...@@ -711,6 +716,8 @@ void cpu_fused_moe(
const int32_t topk_num = topk_id.size(1); const int32_t topk_num = topk_id.size(1);
const FusedMOEAct act_type = get_act_type(act); const FusedMOEAct act_type = get_act_type(act);
cpu_utils::ISA isa_type = cpu_utils::get_isa(isa); cpu_utils::ISA isa_type = cpu_utils::get_isa(isa);
TORCH_CHECK(!skip_weighted || topk_num == 1,
"skip_weighted is only supported for topk=1 on CPU");
VLLM_DISPATCH_FLOATING_TYPES(w13.scalar_type(), "cpu_fused_moe", [&]() { VLLM_DISPATCH_FLOATING_TYPES(w13.scalar_type(), "cpu_fused_moe", [&]() {
CPU_ISA_DISPATCH_IMPL(isa_type, [&]() { CPU_ISA_DISPATCH_IMPL(isa_type, [&]() {
...@@ -721,7 +728,7 @@ void cpu_fused_moe( ...@@ -721,7 +728,7 @@ void cpu_fused_moe(
w2_bias.has_value() ? w2_bias->data_ptr<scalar_t>() : nullptr, w2_bias.has_value() ? w2_bias->data_ptr<scalar_t>() : nullptr,
topk_weights.data_ptr<float>(), topk_id.data_ptr<int32_t>(), act_type, topk_weights.data_ptr<float>(), topk_id.data_ptr<int32_t>(), act_type,
token_num, expert_num, topk_num, input_size_13, output_size_13, token_num, expert_num, topk_num, input_size_13, output_size_13,
input_size_2, output_size_2); input_size_2, output_size_2, skip_weighted);
}); });
}); });
} }
...@@ -13,6 +13,9 @@ ...@@ -13,6 +13,9 @@
#elif defined(__aarch64__) #elif defined(__aarch64__)
// arm implementation // arm implementation
#include "cpu_types_arm.hpp" #include "cpu_types_arm.hpp"
#elif defined(__riscv_v)
// riscv implementation
#include "cpu_types_riscv.hpp"
#else #else
#warning "unsupported vLLM cpu implementation, vLLM will compile with scalar" #warning "unsupported vLLM cpu implementation, vLLM will compile with scalar"
#include "cpu_types_scalar.hpp" #include "cpu_types_scalar.hpp"
......
#ifndef CPU_TYPES_RISCV_HPP
#define CPU_TYPES_RISCV_HPP
#include <algorithm>
#include <cmath>
#include <cstring>
#include <iostream>
#include <limits>
#include <riscv_vector.h>
#include <torch/all.h>
// ============================================================================
// Vector Register Type Definitions (VLEN=128 bits)
// ============================================================================
typedef vfloat16m1_t fixed_vfloat16m1_t
__attribute__((riscv_rvv_vector_bits(128)));
typedef vfloat16m2_t fixed_vfloat16m2_t
__attribute__((riscv_rvv_vector_bits(256)));
typedef vfloat32m1_t fixed_vfloat32m1_t
__attribute__((riscv_rvv_vector_bits(128)));
typedef vfloat32m2_t fixed_vfloat32m2_t
__attribute__((riscv_rvv_vector_bits(256)));
typedef vfloat32m4_t fixed_vfloat32m4_t
__attribute__((riscv_rvv_vector_bits(512)));
typedef vfloat32m8_t fixed_vfloat32m8_t
__attribute__((riscv_rvv_vector_bits(1024)));
typedef vint32m2_t fixed_vint32m2_t __attribute__((riscv_rvv_vector_bits(256)));
typedef vint32m4_t fixed_vint32m4_t __attribute__((riscv_rvv_vector_bits(512)));
typedef vuint16m1_t fixed_vuint16m1_t
__attribute__((riscv_rvv_vector_bits(128)));
typedef vuint16m2_t fixed_vuint16m2_t
__attribute__((riscv_rvv_vector_bits(256)));
typedef vuint16m4_t fixed_vuint16m4_t
__attribute__((riscv_rvv_vector_bits(512)));
#ifdef RISCV_BF16_SUPPORT
typedef vbfloat16m1_t fixed_vbfloat16m1_t
__attribute__((riscv_rvv_vector_bits(128)));
typedef vbfloat16m2_t fixed_vbfloat16m2_t
__attribute__((riscv_rvv_vector_bits(256)));
typedef vbfloat16m4_t fixed_vbfloat16m4_t
__attribute__((riscv_rvv_vector_bits(512)));
#endif
namespace vec_op {
#ifdef RISCV_BF16_SUPPORT
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
#else
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
#endif
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
#define FORCE_INLINE __attribute__((always_inline)) inline
namespace {
template <typename T, T... indexes, typename F>
constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F&& f) {
(f(std::integral_constant<T, indexes>{}), ...);
};
} // namespace
template <typename T, T count, typename F,
typename = std::enable_if_t<std::is_invocable_v<F, T>>>
constexpr void unroll_loop(F&& f) {
unroll_loop_item(std::make_integer_sequence<T, count>{}, std::forward<F>(f));
}
template <typename T>
struct Vec {
constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; };
};
struct FP32Vec8;
struct FP32Vec16;
// ============================================================================
// FP16 Implementation
// ============================================================================
struct FP16Vec8 : public Vec<FP16Vec8> {
constexpr static int VEC_ELEM_NUM = 8;
fixed_vfloat16m1_t reg;
explicit FP16Vec8(const void* ptr)
: reg(__riscv_vle16_v_f16m1(static_cast<const _Float16*>(ptr),
VEC_ELEM_NUM)) {};
explicit FP16Vec8(const FP32Vec8&);
void save(void* ptr) const {
__riscv_vse16_v_f16m1(static_cast<_Float16*>(ptr), reg, VEC_ELEM_NUM);
}
void save(void* ptr, int elem_num) const {
__riscv_vse16_v_f16m1(static_cast<_Float16*>(ptr), reg, elem_num);
}
void save_strided(void* ptr, ptrdiff_t stride) const {
ptrdiff_t byte_stride = stride * sizeof(_Float16);
__riscv_vsse16_v_f16m1(static_cast<_Float16*>(ptr), byte_stride, reg,
VEC_ELEM_NUM);
}
};
struct FP16Vec16 : public Vec<FP16Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
fixed_vfloat16m2_t reg;
explicit FP16Vec16(const void* ptr)
: reg(__riscv_vle16_v_f16m2(static_cast<const _Float16*>(ptr),
VEC_ELEM_NUM)) {};
explicit FP16Vec16(const FP32Vec16& vec);
void save(void* ptr) const {
__riscv_vse16_v_f16m2(static_cast<_Float16*>(ptr), reg, VEC_ELEM_NUM);
}
void save(void* ptr, int elem_num) const {
__riscv_vse16_v_f16m2(static_cast<_Float16*>(ptr), reg, elem_num);
}
void save_strided(void* ptr, ptrdiff_t stride) const {
ptrdiff_t byte_stride = stride * sizeof(_Float16);
__riscv_vsse16_v_f16m2(static_cast<_Float16*>(ptr), byte_stride, reg,
VEC_ELEM_NUM);
}
};
// ============================================================================
// BF16 Implementation
// ============================================================================
#ifdef RISCV_BF16_SUPPORT
FORCE_INLINE fixed_vuint16m1_t bf16_to_u16(fixed_vbfloat16m1_t v) {
return __riscv_vreinterpret_v_bf16m1_u16m1(v);
}
FORCE_INLINE fixed_vuint16m2_t bf16_to_u16(fixed_vbfloat16m2_t v) {
return __riscv_vreinterpret_v_bf16m2_u16m2(v);
}
FORCE_INLINE fixed_vuint16m4_t bf16_to_u16(fixed_vbfloat16m4_t v) {
return __riscv_vreinterpret_v_bf16m4_u16m4(v);
}
struct BF16Vec8 : public Vec<BF16Vec8> {
constexpr static int VEC_ELEM_NUM = 8;
fixed_vbfloat16m1_t reg;
explicit BF16Vec8(const void* ptr)
: reg(__riscv_vreinterpret_v_u16m1_bf16m1(__riscv_vle16_v_u16m1(
reinterpret_cast<const uint16_t*>(ptr), VEC_ELEM_NUM))) {};
explicit BF16Vec8(fixed_vbfloat16m1_t data) : reg(data) {};
explicit BF16Vec8(const FP32Vec8&);
void save(void* ptr) const {
__riscv_vse16_v_u16m1(reinterpret_cast<uint16_t*>(ptr), bf16_to_u16(reg),
VEC_ELEM_NUM);
}
void save(void* ptr, int elem_num) const {
__riscv_vse16_v_u16m1(reinterpret_cast<uint16_t*>(ptr), bf16_to_u16(reg),
elem_num);
}
void save_strided(void* ptr, ptrdiff_t stride) const {
ptrdiff_t byte_stride = stride * sizeof(uint16_t);
__riscv_vsse16_v_u16m1(reinterpret_cast<uint16_t*>(ptr), byte_stride,
bf16_to_u16(reg), VEC_ELEM_NUM);
}
};
struct BF16Vec16 : public Vec<BF16Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
fixed_vbfloat16m2_t reg;
explicit BF16Vec16(const void* ptr)
: reg(__riscv_vreinterpret_v_u16m2_bf16m2(__riscv_vle16_v_u16m2(
reinterpret_cast<const uint16_t*>(ptr), VEC_ELEM_NUM))) {};
explicit BF16Vec16(fixed_vbfloat16m2_t data) : reg(data) {};
explicit BF16Vec16(const FP32Vec16&);
void save(void* ptr) const {
__riscv_vse16_v_u16m2(reinterpret_cast<uint16_t*>(ptr), bf16_to_u16(reg),
VEC_ELEM_NUM);
}
void save(void* ptr, int elem_num) const {
__riscv_vse16_v_u16m2(reinterpret_cast<uint16_t*>(ptr), bf16_to_u16(reg),
elem_num);
}
void save_strided(void* ptr, ptrdiff_t stride) const {
ptrdiff_t byte_stride = stride * sizeof(uint16_t);
__riscv_vsse16_v_u16m2(reinterpret_cast<uint16_t*>(ptr), byte_stride,
bf16_to_u16(reg), VEC_ELEM_NUM);
}
};
struct BF16Vec32 : public Vec<BF16Vec32> {
constexpr static int VEC_ELEM_NUM = 32;
fixed_vbfloat16m4_t reg;
explicit BF16Vec32(const void* ptr)
: reg(__riscv_vreinterpret_v_u16m4_bf16m4(__riscv_vle16_v_u16m4(
reinterpret_cast<const uint16_t*>(ptr), VEC_ELEM_NUM))) {};
explicit BF16Vec32(fixed_vbfloat16m4_t data) : reg(data) {};
explicit BF16Vec32(const BF16Vec8& v) {
fixed_vuint16m1_t u16_val = bf16_to_u16(v.reg);
fixed_vuint16m4_t u16_combined =
__riscv_vcreate_v_u16m1_u16m4(u16_val, u16_val, u16_val, u16_val);
reg = __riscv_vreinterpret_v_u16m4_bf16m4(u16_combined);
};
void save(void* ptr) const {
__riscv_vse16_v_u16m4(reinterpret_cast<uint16_t*>(ptr), bf16_to_u16(reg),
VEC_ELEM_NUM);
}
void save(void* ptr, int elem_num) const {
__riscv_vse16_v_u16m4(reinterpret_cast<uint16_t*>(ptr), bf16_to_u16(reg),
elem_num);
}
void save_strided(void* ptr, ptrdiff_t stride) const {
ptrdiff_t byte_stride = stride * sizeof(uint16_t);
__riscv_vsse16_v_u16m4(reinterpret_cast<uint16_t*>(ptr), byte_stride,
bf16_to_u16(reg), VEC_ELEM_NUM);
}
};
#else
// ============================================================================
// BF16 Fallback Implementation (FP32 Simulation)
// ============================================================================
struct BF16Vec8 : public Vec<BF16Vec8> {
constexpr static int VEC_ELEM_NUM = 8;
fixed_vfloat32m2_t reg_fp32;
explicit BF16Vec8(const void* ptr) {
const uint16_t* u16 = static_cast<const uint16_t*>(ptr);
float tmp[8];
for (int i = 0; i < 8; ++i) {
uint32_t v = static_cast<uint32_t>(u16[i]) << 16;
std::memcpy(&tmp[i], &v, 4);
}
reg_fp32 = __riscv_vle32_v_f32m2(tmp, 8);
}
explicit BF16Vec8(const FP32Vec8&);
void save(void* ptr) const {
float tmp[8];
__riscv_vse32_v_f32m2(tmp, reg_fp32, 8);
uint16_t* u16 = static_cast<uint16_t*>(ptr);
for (int i = 0; i < 8; ++i) {
uint32_t v;
std::memcpy(&v, &tmp[i], 4);
u16[i] = static_cast<uint16_t>(v >> 16);
}
}
void save(void* ptr, int elem_num) const {
float tmp[8];
__riscv_vse32_v_f32m2(tmp, reg_fp32, 8);
uint16_t* u16 = static_cast<uint16_t*>(ptr);
for (int i = 0; i < elem_num; ++i) {
uint32_t v;
std::memcpy(&v, &tmp[i], 4);
u16[i] = static_cast<uint16_t>(v >> 16);
}
}
void save_strided(void* ptr, ptrdiff_t stride) const {
float tmp[8];
__riscv_vse32_v_f32m2(tmp, reg_fp32, 8);
uint8_t* u8 = static_cast<uint8_t*>(ptr);
ptrdiff_t byte_stride = stride * sizeof(uint16_t);
for (int i = 0; i < 8; ++i) {
uint32_t v;
std::memcpy(&v, &tmp[i], 4);
uint16_t val = static_cast<uint16_t>(v >> 16);
*reinterpret_cast<uint16_t*>(u8 + i * byte_stride) = val;
}
}
};
struct BF16Vec16 : public Vec<BF16Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
fixed_vfloat32m4_t reg_fp32;
explicit BF16Vec16(const void* ptr) {
const uint16_t* u16 = static_cast<const uint16_t*>(ptr);
float tmp[16];
for (int i = 0; i < 16; ++i) {
uint32_t v = static_cast<uint32_t>(u16[i]) << 16;
std::memcpy(&tmp[i], &v, 4);
}
reg_fp32 = __riscv_vle32_v_f32m4(tmp, 16);
}
explicit BF16Vec16(const FP32Vec16&);
void save(void* ptr) const {
float tmp[16];
__riscv_vse32_v_f32m4(tmp, reg_fp32, 16);
uint16_t* u16 = static_cast<uint16_t*>(ptr);
for (int i = 0; i < 16; ++i) {
uint32_t v;
std::memcpy(&v, &tmp[i], 4);
u16[i] = static_cast<uint16_t>(v >> 16);
}
}
void save(void* ptr, int elem_num) const {
float tmp[16];
__riscv_vse32_v_f32m4(tmp, reg_fp32, 16);
uint16_t* u16 = static_cast<uint16_t*>(ptr);
for (int i = 0; i < elem_num; ++i) {
uint32_t v;
std::memcpy(&v, &tmp[i], 4);
u16[i] = static_cast<uint16_t>(v >> 16);
}
}
void save_strided(void* ptr, ptrdiff_t stride) const {
float tmp[16];
__riscv_vse32_v_f32m4(tmp, reg_fp32, 16);
uint8_t* u8 = static_cast<uint8_t*>(ptr);
ptrdiff_t byte_stride = stride * sizeof(uint16_t);
for (int i = 0; i < 16; ++i) {
uint32_t v;
std::memcpy(&v, &tmp[i], 4);
uint16_t val = static_cast<uint16_t>(v >> 16);
*reinterpret_cast<uint16_t*>(u8 + i * byte_stride) = val;
}
}
};
struct BF16Vec32 : public Vec<BF16Vec32> {
constexpr static int VEC_ELEM_NUM = 32;
fixed_vfloat32m8_t reg_fp32;
explicit BF16Vec32(const void* ptr) {
const uint16_t* u16 = static_cast<const uint16_t*>(ptr);
float tmp[32];
for (int i = 0; i < 32; ++i) {
uint32_t v = static_cast<uint32_t>(u16[i]) << 16;
std::memcpy(&tmp[i], &v, 4);
}
reg_fp32 = __riscv_vle32_v_f32m8(tmp, 32);
}
explicit BF16Vec32(const BF16Vec8& v) {
float tmp_small[8];
__riscv_vse32_v_f32m2(tmp_small, v.reg_fp32, 8);
float tmp_large[32];
for (int i = 0; i < 4; ++i) {
std::memcpy(tmp_large + (i * 8), tmp_small, 8 * sizeof(float));
}
reg_fp32 = __riscv_vle32_v_f32m8(tmp_large, 32);
}
void save(void* ptr) const {
float tmp[32];
__riscv_vse32_v_f32m8(tmp, reg_fp32, 32);
uint16_t* u16 = static_cast<uint16_t*>(ptr);
for (int i = 0; i < 32; ++i) {
uint32_t v;
std::memcpy(&v, &tmp[i], 4);
u16[i] = static_cast<uint16_t>(v >> 16);
}
}
void save(void* ptr, int elem_num) const {
float tmp[32];
__riscv_vse32_v_f32m8(tmp, reg_fp32, 32);
uint16_t* u16 = static_cast<uint16_t*>(ptr);
for (int i = 0; i < elem_num; ++i) {
uint32_t v;
std::memcpy(&v, &tmp[i], 4);
u16[i] = static_cast<uint16_t>(v >> 16);
}
}
void save_strided(void* ptr, ptrdiff_t stride) const {
float tmp[32];
__riscv_vse32_v_f32m8(tmp, reg_fp32, 32);
uint8_t* u8 = static_cast<uint8_t*>(ptr);
ptrdiff_t byte_stride = stride * sizeof(uint16_t);
for (int i = 0; i < 32; ++i) {
uint32_t v;
std::memcpy(&v, &tmp[i], 4);
uint16_t val = static_cast<uint16_t>(v >> 16);
*reinterpret_cast<uint16_t*>(u8 + i * byte_stride) = val;
}
}
};
#endif
// ============================================================================
// FP32 Implementation
// ============================================================================
struct FP32Vec4 : public Vec<FP32Vec4> {
constexpr static int VEC_ELEM_NUM = 4;
fixed_vfloat32m1_t reg;
explicit FP32Vec4(float v) : reg(__riscv_vfmv_v_f_f32m1(v, VEC_ELEM_NUM)) {};
explicit FP32Vec4() : reg(__riscv_vfmv_v_f_f32m1(0.0f, VEC_ELEM_NUM)) {};
explicit FP32Vec4(const float* ptr)
: reg(__riscv_vle32_v_f32m1(ptr, VEC_ELEM_NUM)) {};
explicit FP32Vec4(fixed_vfloat32m1_t data) : reg(data) {};
explicit FP32Vec4(const FP32Vec4& data) : reg(data.reg) {};
void save(float* ptr) const { __riscv_vse32_v_f32m1(ptr, reg, VEC_ELEM_NUM); }
void save(float* ptr, int elem_num) const {
__riscv_vse32_v_f32m1(ptr, reg, elem_num);
}
};
struct FP32Vec8 : public Vec<FP32Vec8> {
constexpr static int VEC_ELEM_NUM = 8;
fixed_vfloat32m2_t reg;
explicit FP32Vec8(float v) : reg(__riscv_vfmv_v_f_f32m2(v, VEC_ELEM_NUM)) {};
explicit FP32Vec8() : reg(__riscv_vfmv_v_f_f32m2(0.0f, VEC_ELEM_NUM)) {};
explicit FP32Vec8(const float* ptr)
: reg(__riscv_vle32_v_f32m2(ptr, VEC_ELEM_NUM)) {};
explicit FP32Vec8(fixed_vfloat32m2_t data) : reg(data) {};
explicit FP32Vec8(const FP32Vec8& data) : reg(data.reg) {};
explicit FP32Vec8(const FP16Vec8& v)
: reg(__riscv_vfwcvt_f_f_v_f32m2(v.reg, VEC_ELEM_NUM)) {};
explicit FP32Vec8(fixed_vfloat16m1_t v)
: reg(__riscv_vfwcvt_f_f_v_f32m2(v, VEC_ELEM_NUM)) {};
#ifdef RISCV_BF16_SUPPORT
explicit FP32Vec8(fixed_vbfloat16m1_t v)
: reg(__riscv_vfwcvtbf16_f_f_v_f32m2(v, VEC_ELEM_NUM)) {};
explicit FP32Vec8(const BF16Vec8& v)
: reg(__riscv_vfwcvtbf16_f_f_v_f32m2(v.reg, VEC_ELEM_NUM)) {};
#else
explicit FP32Vec8(const BF16Vec8& v) : reg(v.reg_fp32) {};
#endif
float reduce_sum() const {
fixed_vfloat32m1_t scalar = __riscv_vfmv_s_f_f32m1(0.0f, 1);
scalar = __riscv_vfredusum_vs_f32m2_f32m1(reg, scalar, VEC_ELEM_NUM);
return __riscv_vfmv_f_s_f32m1_f32(scalar);
}
FP32Vec8 operator*(const FP32Vec8& b) const {
return FP32Vec8(__riscv_vfmul_vv_f32m2(reg, b.reg, VEC_ELEM_NUM));
}
FP32Vec8 operator+(const FP32Vec8& b) const {
return FP32Vec8(__riscv_vfadd_vv_f32m2(reg, b.reg, VEC_ELEM_NUM));
}
FP32Vec8 operator-(const FP32Vec8& b) const {
return FP32Vec8(__riscv_vfsub_vv_f32m2(reg, b.reg, VEC_ELEM_NUM));
}
FP32Vec8 operator/(const FP32Vec8& b) const {
return FP32Vec8(__riscv_vfdiv_vv_f32m2(reg, b.reg, VEC_ELEM_NUM));
}
FP32Vec8 min(const FP32Vec8& b) const {
return FP32Vec8(__riscv_vfmin_vv_f32m2(reg, b.reg, VEC_ELEM_NUM));
}
FP32Vec8 max(const FP32Vec8& b) const {
return FP32Vec8(__riscv_vfmax_vv_f32m2(reg, b.reg, VEC_ELEM_NUM));
}
FP32Vec8 abs() const {
return FP32Vec8(__riscv_vfabs_v_f32m2(reg, VEC_ELEM_NUM));
}
FP32Vec8 min(const FP32Vec8& b, int elem_num) const {
return FP32Vec8(__riscv_vfmin_vv_f32m2(reg, b.reg, elem_num));
}
FP32Vec8 max(const FP32Vec8& b, int elem_num) const {
return FP32Vec8(__riscv_vfmax_vv_f32m2(reg, b.reg, elem_num));
}
FP32Vec8 clamp(const FP32Vec8& min_v, const FP32Vec8& max_v) const {
fixed_vfloat32m2_t temp =
__riscv_vfmax_vv_f32m2(min_v.reg, reg, VEC_ELEM_NUM);
return FP32Vec8(__riscv_vfmin_vv_f32m2(max_v.reg, temp, VEC_ELEM_NUM));
}
void save(float* ptr) const { __riscv_vse32_v_f32m2(ptr, reg, VEC_ELEM_NUM); }
void save(float* ptr, int elem_num) const {
__riscv_vse32_v_f32m2(ptr, reg, elem_num);
}
void save_strided(float* ptr, ptrdiff_t stride) const {
ptrdiff_t byte_stride = stride * sizeof(float);
__riscv_vsse32_v_f32m2(ptr, byte_stride, reg, VEC_ELEM_NUM);
}
FP32Vec8 exp() const {
const float inv_ln2 = 1.44269504088896341f;
fixed_vfloat32m2_t x_scaled =
__riscv_vfmul_vf_f32m2(reg, inv_ln2, VEC_ELEM_NUM);
fixed_vint32m2_t n_int = __riscv_vfcvt_x_f_v_i32m2(x_scaled, VEC_ELEM_NUM);
fixed_vfloat32m2_t n_float = __riscv_vfcvt_f_x_v_f32m2(n_int, VEC_ELEM_NUM);
fixed_vfloat32m2_t r =
__riscv_vfsub_vv_f32m2(x_scaled, n_float, VEC_ELEM_NUM);
fixed_vfloat32m2_t poly =
__riscv_vfmv_v_f_f32m2(0.001333355810164f, VEC_ELEM_NUM);
poly = __riscv_vfmul_vv_f32m2(poly, r, VEC_ELEM_NUM);
poly = __riscv_vfadd_vf_f32m2(poly, 0.009618129107628f, VEC_ELEM_NUM);
poly = __riscv_vfmul_vv_f32m2(poly, r, VEC_ELEM_NUM);
poly = __riscv_vfadd_vf_f32m2(poly, 0.055504108664821f, VEC_ELEM_NUM);
poly = __riscv_vfmul_vv_f32m2(poly, r, VEC_ELEM_NUM);
poly = __riscv_vfadd_vf_f32m2(poly, 0.240226506959101f, VEC_ELEM_NUM);
poly = __riscv_vfmul_vv_f32m2(poly, r, VEC_ELEM_NUM);
poly = __riscv_vfadd_vf_f32m2(poly, 0.693147180559945f, VEC_ELEM_NUM);
poly = __riscv_vfmul_vv_f32m2(poly, r, VEC_ELEM_NUM);
poly = __riscv_vfadd_vf_f32m2(poly, 1.0f, VEC_ELEM_NUM);
fixed_vint32m2_t biased_exp =
__riscv_vadd_vx_i32m2(n_int, 127, VEC_ELEM_NUM);
biased_exp = __riscv_vmax_vx_i32m2(biased_exp, 0, VEC_ELEM_NUM);
fixed_vint32m2_t exponent_bits =
__riscv_vsll_vx_i32m2(biased_exp, 23, VEC_ELEM_NUM);
fixed_vfloat32m2_t scale =
__riscv_vreinterpret_v_i32m2_f32m2(exponent_bits);
return FP32Vec8(__riscv_vfmul_vv_f32m2(poly, scale, VEC_ELEM_NUM));
}
FP32Vec8 tanh() const {
fixed_vfloat32m2_t x_clamped = __riscv_vfmin_vf_f32m2(
__riscv_vfmax_vf_f32m2(reg, -9.0f, VEC_ELEM_NUM), 9.0f, VEC_ELEM_NUM);
fixed_vfloat32m2_t x2 =
__riscv_vfmul_vf_f32m2(x_clamped, 2.0f, VEC_ELEM_NUM);
FP32Vec8 exp_val = FP32Vec8(x2).exp();
fixed_vfloat32m2_t num =
__riscv_vfsub_vf_f32m2(exp_val.reg, 1.0f, VEC_ELEM_NUM);
fixed_vfloat32m2_t den =
__riscv_vfadd_vf_f32m2(exp_val.reg, 1.0f, VEC_ELEM_NUM);
return FP32Vec8(__riscv_vfdiv_vv_f32m2(num, den, VEC_ELEM_NUM));
}
FP32Vec8 er() const {
const float p = 0.3275911f, a1 = 0.254829592f, a2 = -0.284496736f,
a3 = 1.421413741f, a4 = -1.453152027f, a5 = 1.061405429f;
fixed_vfloat32m2_t abs_x = __riscv_vfabs_v_f32m2(reg, VEC_ELEM_NUM);
fixed_vfloat32m2_t t = __riscv_vfadd_vf_f32m2(
__riscv_vfmul_vf_f32m2(abs_x, p, VEC_ELEM_NUM), 1.0f, VEC_ELEM_NUM);
t = __riscv_vfrdiv_vf_f32m2(t, 1.0f, VEC_ELEM_NUM);
fixed_vfloat32m2_t poly = __riscv_vfmv_v_f_f32m2(a5, VEC_ELEM_NUM);
poly = __riscv_vfadd_vf_f32m2(__riscv_vfmul_vv_f32m2(poly, t, VEC_ELEM_NUM),
a4, VEC_ELEM_NUM);
poly = __riscv_vfadd_vf_f32m2(__riscv_vfmul_vv_f32m2(poly, t, VEC_ELEM_NUM),
a3, VEC_ELEM_NUM);
poly = __riscv_vfadd_vf_f32m2(__riscv_vfmul_vv_f32m2(poly, t, VEC_ELEM_NUM),
a2, VEC_ELEM_NUM);
poly = __riscv_vfadd_vf_f32m2(__riscv_vfmul_vv_f32m2(poly, t, VEC_ELEM_NUM),
a1, VEC_ELEM_NUM);
poly = __riscv_vfmul_vv_f32m2(poly, t, VEC_ELEM_NUM);
fixed_vfloat32m2_t exp_val =
FP32Vec8(__riscv_vfneg_v_f32m2(
__riscv_vfmul_vv_f32m2(abs_x, abs_x, VEC_ELEM_NUM),
VEC_ELEM_NUM))
.exp()
.reg;
fixed_vfloat32m2_t res = __riscv_vfrsub_vf_f32m2(
__riscv_vfmul_vv_f32m2(poly, exp_val, VEC_ELEM_NUM), 1.0f,
VEC_ELEM_NUM);
vbool16_t mask = __riscv_vmflt_vf_f32m2_b16(reg, 0.0f, VEC_ELEM_NUM);
return FP32Vec8(__riscv_vfneg_v_f32m2_m(mask, res, VEC_ELEM_NUM));
}
};
struct FP32Vec16 : public Vec<FP32Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
fixed_vfloat32m4_t reg;
explicit FP32Vec16(float v) : reg(__riscv_vfmv_v_f_f32m4(v, VEC_ELEM_NUM)) {};
explicit FP32Vec16() : reg(__riscv_vfmv_v_f_f32m4(0.0f, VEC_ELEM_NUM)) {};
explicit FP32Vec16(const float* ptr)
: reg(__riscv_vle32_v_f32m4(ptr, VEC_ELEM_NUM)) {};
explicit FP32Vec16(fixed_vfloat32m4_t data) : reg(data) {};
explicit FP32Vec16(const FP32Vec8& data)
: reg(__riscv_vcreate_v_f32m2_f32m4(data.reg, data.reg)) {};
explicit FP32Vec16(const FP32Vec16& data) : reg(data.reg) {};
explicit FP32Vec16(const FP16Vec16& v);
#ifdef RISCV_BF16_SUPPORT
explicit FP32Vec16(fixed_vbfloat16m2_t v)
: reg(__riscv_vfwcvtbf16_f_f_v_f32m4(v, VEC_ELEM_NUM)) {};
explicit FP32Vec16(const BF16Vec16& v)
: reg(__riscv_vfwcvtbf16_f_f_v_f32m4(v.reg, VEC_ELEM_NUM)) {};
#else
explicit FP32Vec16(const BF16Vec16& v) : reg(v.reg_fp32) {};
#endif
FP32Vec16 operator+(const FP32Vec16& b) const {
return FP32Vec16(__riscv_vfadd_vv_f32m4(reg, b.reg, VEC_ELEM_NUM));
}
FP32Vec16 operator-(const FP32Vec16& b) const {
return FP32Vec16(__riscv_vfsub_vv_f32m4(reg, b.reg, VEC_ELEM_NUM));
}
FP32Vec16 operator*(const FP32Vec16& b) const {
return FP32Vec16(__riscv_vfmul_vv_f32m4(reg, b.reg, VEC_ELEM_NUM));
}
FP32Vec16 operator/(const FP32Vec16& b) const {
return FP32Vec16(__riscv_vfdiv_vv_f32m4(reg, b.reg, VEC_ELEM_NUM));
}
FP32Vec16 fma(const FP32Vec16& a, const FP32Vec16& b) const {
return FP32Vec16(__riscv_vfmacc_vv_f32m4(reg, a.reg, b.reg, VEC_ELEM_NUM));
}
float reduce_sum() const {
fixed_vfloat32m1_t scalar = __riscv_vfmv_s_f_f32m1(0.0f, 1);
scalar = __riscv_vfredusum_vs_f32m4_f32m1(reg, scalar, VEC_ELEM_NUM);
return __riscv_vfmv_f_s_f32m1_f32(scalar);
}
float reduce_max() const {
fixed_vfloat32m1_t scalar =
__riscv_vfmv_s_f_f32m1(std::numeric_limits<float>::lowest(), 1);
scalar = __riscv_vfredmax_vs_f32m4_f32m1(reg, scalar, VEC_ELEM_NUM);
return __riscv_vfmv_f_s_f32m1_f32(scalar);
}
float reduce_min() const {
fixed_vfloat32m1_t scalar =
__riscv_vfmv_s_f_f32m1(std::numeric_limits<float>::max(), 1);
scalar = __riscv_vfredmin_vs_f32m4_f32m1(reg, scalar, VEC_ELEM_NUM);
return __riscv_vfmv_f_s_f32m1_f32(scalar);
}
template <int group_size>
float reduce_sub_sum(int idx) {
static_assert(VEC_ELEM_NUM % group_size == 0);
const int start = idx * group_size;
vuint32m4_t indices = __riscv_vid_v_u32m4(VEC_ELEM_NUM);
vbool8_t mask = __riscv_vmand_mm_b8(
__riscv_vmsgeu_vx_u32m4_b8(indices, start, VEC_ELEM_NUM),
__riscv_vmsltu_vx_u32m4_b8(indices, start + group_size, VEC_ELEM_NUM),
VEC_ELEM_NUM);
fixed_vfloat32m1_t scalar = __riscv_vfmv_s_f_f32m1(0.0f, 1);
scalar =
__riscv_vfredusum_vs_f32m4_f32m1_m(mask, reg, scalar, VEC_ELEM_NUM);
return __riscv_vfmv_f_s_f32m1_f32(scalar);
};
FP32Vec16 max(const FP32Vec16& b) const {
return FP32Vec16(__riscv_vfmax_vv_f32m4(reg, b.reg, VEC_ELEM_NUM));
}
FP32Vec16 min(const FP32Vec16& b) const {
return FP32Vec16(__riscv_vfmin_vv_f32m4(reg, b.reg, VEC_ELEM_NUM));
}
FP32Vec16 abs() const {
return FP32Vec16(__riscv_vfabs_v_f32m4(reg, VEC_ELEM_NUM));
}
FP32Vec16 clamp(const FP32Vec16& min_v, const FP32Vec16& max_v) const {
return FP32Vec16(__riscv_vfmin_vv_f32m4(
max_v.reg, __riscv_vfmax_vv_f32m4(min_v.reg, reg, VEC_ELEM_NUM),
VEC_ELEM_NUM));
}
void save(float* ptr) const { __riscv_vse32_v_f32m4(ptr, reg, VEC_ELEM_NUM); }
void save(float* ptr, int elem_num) const {
__riscv_vse32_v_f32m4(ptr, reg, elem_num);
}
void save_strided(float* ptr, ptrdiff_t stride) const {
ptrdiff_t byte_stride = stride * sizeof(float);
__riscv_vsse32_v_f32m4(ptr, byte_stride, reg, VEC_ELEM_NUM);
}
FP32Vec16 exp() const {
const float inv_ln2 = 1.44269504088896341f;
fixed_vfloat32m4_t x_scaled =
__riscv_vfmul_vf_f32m4(reg, inv_ln2, VEC_ELEM_NUM);
fixed_vint32m4_t n_int = __riscv_vfcvt_x_f_v_i32m4(x_scaled, VEC_ELEM_NUM);
fixed_vfloat32m4_t n_float = __riscv_vfcvt_f_x_v_f32m4(n_int, VEC_ELEM_NUM);
fixed_vfloat32m4_t r =
__riscv_vfsub_vv_f32m4(x_scaled, n_float, VEC_ELEM_NUM);
fixed_vfloat32m4_t poly =
__riscv_vfmv_v_f_f32m4(0.001333355810164f, VEC_ELEM_NUM);
poly = __riscv_vfadd_vf_f32m4(__riscv_vfmul_vv_f32m4(poly, r, VEC_ELEM_NUM),
0.009618129107628f, VEC_ELEM_NUM);
poly = __riscv_vfadd_vf_f32m4(__riscv_vfmul_vv_f32m4(poly, r, VEC_ELEM_NUM),
0.055504108664821f, VEC_ELEM_NUM);
poly = __riscv_vfadd_vf_f32m4(__riscv_vfmul_vv_f32m4(poly, r, VEC_ELEM_NUM),
0.240226506959101f, VEC_ELEM_NUM);
poly = __riscv_vfadd_vf_f32m4(__riscv_vfmul_vv_f32m4(poly, r, VEC_ELEM_NUM),
0.693147180559945f, VEC_ELEM_NUM);
poly = __riscv_vfadd_vf_f32m4(__riscv_vfmul_vv_f32m4(poly, r, VEC_ELEM_NUM),
1.0f, VEC_ELEM_NUM);
fixed_vint32m4_t biased_exp = __riscv_vmax_vx_i32m4(
__riscv_vadd_vx_i32m4(n_int, 127, VEC_ELEM_NUM), 0, VEC_ELEM_NUM);
fixed_vfloat32m4_t scale = __riscv_vreinterpret_v_i32m4_f32m4(
__riscv_vsll_vx_i32m4(biased_exp, 23, VEC_ELEM_NUM));
return FP32Vec16(__riscv_vfmul_vv_f32m4(poly, scale, VEC_ELEM_NUM));
}
FP32Vec16 tanh() const {
fixed_vfloat32m4_t x_clamped = __riscv_vfmin_vf_f32m4(
__riscv_vfmax_vf_f32m4(reg, -9.0f, VEC_ELEM_NUM), 9.0f, VEC_ELEM_NUM);
FP32Vec16 exp_val =
FP32Vec16(__riscv_vfmul_vf_f32m4(x_clamped, 2.0f, VEC_ELEM_NUM)).exp();
return FP32Vec16(__riscv_vfdiv_vv_f32m4(
__riscv_vfsub_vf_f32m4(exp_val.reg, 1.0f, VEC_ELEM_NUM),
__riscv_vfadd_vf_f32m4(exp_val.reg, 1.0f, VEC_ELEM_NUM), VEC_ELEM_NUM));
}
FP32Vec16 er() const {
const float p = 0.3275911f, a1 = 0.254829592f, a2 = -0.284496736f,
a3 = 1.421413741f, a4 = -1.453152027f, a5 = 1.061405429f;
fixed_vfloat32m4_t abs_x = __riscv_vfabs_v_f32m4(reg, VEC_ELEM_NUM);
fixed_vfloat32m4_t t = __riscv_vfrdiv_vf_f32m4(
__riscv_vfadd_vf_f32m4(__riscv_vfmul_vf_f32m4(abs_x, p, VEC_ELEM_NUM),
1.0f, VEC_ELEM_NUM),
1.0f, VEC_ELEM_NUM);
fixed_vfloat32m4_t poly = __riscv_vfmv_v_f_f32m4(a5, VEC_ELEM_NUM);
poly = __riscv_vfadd_vf_f32m4(__riscv_vfmul_vv_f32m4(poly, t, VEC_ELEM_NUM),
a4, VEC_ELEM_NUM);
poly = __riscv_vfadd_vf_f32m4(__riscv_vfmul_vv_f32m4(poly, t, VEC_ELEM_NUM),
a3, VEC_ELEM_NUM);
poly = __riscv_vfadd_vf_f32m4(__riscv_vfmul_vv_f32m4(poly, t, VEC_ELEM_NUM),
a2, VEC_ELEM_NUM);
poly = __riscv_vfadd_vf_f32m4(__riscv_vfmul_vv_f32m4(poly, t, VEC_ELEM_NUM),
a1, VEC_ELEM_NUM);
poly = __riscv_vfmul_vv_f32m4(poly, t, VEC_ELEM_NUM);
fixed_vfloat32m4_t exp_val =
FP32Vec16(__riscv_vfneg_v_f32m4(
__riscv_vfmul_vv_f32m4(abs_x, abs_x, VEC_ELEM_NUM),
VEC_ELEM_NUM))
.exp()
.reg;
fixed_vfloat32m4_t res = __riscv_vfrsub_vf_f32m4(
__riscv_vfmul_vv_f32m4(poly, exp_val, VEC_ELEM_NUM), 1.0f,
VEC_ELEM_NUM);
vbool8_t mask = __riscv_vmflt_vf_f32m4_b8(reg, 0.0f, VEC_ELEM_NUM);
return FP32Vec16(__riscv_vfneg_v_f32m4_m(mask, res, VEC_ELEM_NUM));
}
};
// ============================================================================
// Type Traits & Global Helpers
// ============================================================================
template <typename T>
struct VecType {
using vec_type = void;
using vec_t = void;
};
template <typename T>
using vec_t = typename VecType<T>::vec_type;
template <>
struct VecType<float> {
using vec_type = FP32Vec8;
using vec_t = FP32Vec8;
};
template <>
struct VecType<c10::Half> {
using vec_type = FP16Vec8;
using vec_t = FP16Vec8;
};
template <>
struct VecType<c10::BFloat16> {
using vec_type = BF16Vec8;
using vec_t = BF16Vec8;
};
template <typename T>
void storeFP32(float v, T* ptr) {
*ptr = v;
}
template <>
inline void storeFP32<c10::Half>(float v, c10::Half* ptr) {
*reinterpret_cast<_Float16*>(ptr) = static_cast<_Float16>(v);
}
inline FP16Vec16::FP16Vec16(const FP32Vec16& v) {
reg = __riscv_vfncvt_f_f_w_f16m2(v.reg, VEC_ELEM_NUM);
}
inline FP16Vec8::FP16Vec8(const FP32Vec8& v) {
reg = __riscv_vfncvt_f_f_w_f16m1(v.reg, VEC_ELEM_NUM);
}
inline FP32Vec16::FP32Vec16(const FP16Vec16& v) {
reg = __riscv_vfwcvt_f_f_v_f32m4(v.reg, VEC_ELEM_NUM);
}
inline void fma(FP32Vec16& acc, const FP32Vec16& a, const FP32Vec16& b) {
acc = acc.fma(a, b);
}
#ifdef RISCV_BF16_SUPPORT
template <>
inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16* ptr) {
*ptr = static_cast<__bf16>(v);
};
inline BF16Vec8::BF16Vec8(const FP32Vec8& v)
: reg(__riscv_vfncvtbf16_f_f_w_bf16m1(v.reg, VEC_ELEM_NUM)) {};
inline BF16Vec16::BF16Vec16(const FP32Vec16& v)
: reg(__riscv_vfncvtbf16_f_f_w_bf16m2(v.reg, VEC_ELEM_NUM)) {};
#else
template <>
inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16* ptr) {
uint32_t val;
std::memcpy(&val, &v, 4);
*reinterpret_cast<uint16_t*>(ptr) = static_cast<uint16_t>(val >> 16);
}
inline BF16Vec8::BF16Vec8(const FP32Vec8& v) : reg_fp32(v.reg) {}
inline BF16Vec16::BF16Vec16(const FP32Vec16& v) : reg_fp32(v.reg) {}
#endif
inline void prefetch(const void* addr) { __builtin_prefetch(addr, 0, 1); }
} // namespace vec_op
#ifndef CPU_KERNEL_GUARD_IN
#define CPU_KERNEL_GUARD_IN(NAME)
#endif
#ifndef CPU_KERNEL_GUARD_OUT
#define CPU_KERNEL_GUARD_OUT(NAME)
#endif
#endif // CPU_TYPES_RISCV_HPP
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment