diff --git a/.buildkite/nightly-benchmarks/README.md b/.buildkite/nightly-benchmarks/README.md index 72c52d5bb5e9ba8fe9b9602862ed2a9b20aa9ab0..cdf6a645147e5aae5ca9a066a8685bc149e7988e 100644 --- a/.buildkite/nightly-benchmarks/README.md +++ b/.buildkite/nightly-benchmarks/README.md @@ -11,7 +11,7 @@ See [vLLM performance dashboard](https://perf.vllm.ai) for the latest performanc ## Performance benchmark quick overview -**Benchmarking Coverage**: latency, throughput and fix-qps serving on A100 (the support for FP8 benchmark on H100 is coming!), with different models. +**Benchmarking Coverage**: latency, throughput and fix-qps serving on A100 (the support for FP8 benchmark on H100 is coming!) and Intel® Xeon® Processors, with different models. **Benchmarking Duration**: about 1hr. @@ -31,13 +31,27 @@ Performance benchmark will be triggered when: - A PR being merged into vllm. - Every commit for those PRs with `perf-benchmarks` label AND `ready` label. +Manually Trigger the benchmark + +```bash +bash .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh +``` + +Runtime environment variables: +- `ON_CPU`: set the value to '1' on Intel® Xeon® Processors. Default value is 0. +- `SERVING_JSON`: JSON file to use for the serving tests. Default value is empty string (use default file). +- `LATENCY_JSON`: JSON file to use for the latency tests. Default value is empty string (use default file). +- `THROUGHPUT_JSON`: JSON file to use for the throughout tests. Default value is empty string (use default file). +- `REMOTE_HOST`: IP for the remote vLLM service to benchmark. Default value is empty string. +- `REMOTE_PORT`: Port for the remote vLLM service to benchmark. Default value is empty string. + Nightly benchmark will be triggered when: - Every commit for those PRs with `perf-benchmarks` label and `nightly-benchmarks` label. ## Performance benchmark details See [performance-benchmarks-descriptions.md](performance-benchmarks-descriptions.md) for detailed descriptions, and use `tests/latency-tests.json`, `tests/throughput-tests.json`, `tests/serving-tests.json` to configure the test cases. - +> NOTE: For Intel® Xeon® Processors, use `tests/latency-tests-cpu.json`, `tests/throughput-tests-cpu.json`, `tests/serving-tests-cpu.json` instead. ### Latency test Here is an example of one test inside `latency-tests.json`: @@ -119,6 +133,30 @@ If you do not see the table, please wait till the benchmark finish running. The json version of the table (together with the json version of the benchmark) will be also attached to the markdown file. The raw benchmarking results (in the format of json files) are in the `Artifacts` tab of the benchmarking. +The `compare-json-results.py` helps to compare benchmark results JSON files converted using `convert-results-json-to-markdown.py`. +When run, benchmark script generates results under `benchmark/results` folder, along with the `benchmark_results.md` and `benchmark_results.json`. +`compare-json-results.py` compares two `benchmark_results.json` files and provides performance ratio e.g. for Output Tput, Median TTFT and Median TPOT. + +Here is an example using the script to compare result_a and result_b without detail test name. +`python3 compare-json-results.py -f results_a/benchmark_results.json -f results_b/benchmark_results.json --ignore_test_name` + +| | results_a/benchmark_results.json | results_b/benchmark_results.json | perf_ratio | +|----|----------------------------------------|----------------------------------------|----------| +| 0 | 142.633982 | 156.526018 | 1.097396 | +| 1 | 241.620334 | 294.018783 | 1.216863 | +| 2 | 218.298905 | 262.664916 | 1.203235 | +| 3 | 242.743860 | 299.816190 | 1.235113 | + +Here is an example using the script to compare result_a and result_b with detail test name. +`python3 compare-json-results.py -f results_a/benchmark_results.json -f results_b/benchmark_results.json` +| | results_a/benchmark_results.json_name | results_a/benchmark_results.json | results_b/benchmark_results.json_name | results_b/benchmark_results.json | perf_ratio | +|---|---------------------------------------------|----------------------------------------|---------------------------------------------|----------------------------------------|----------| +| 0 | serving_llama8B_tp1_sharegpt_qps_1 | 142.633982 | serving_llama8B_tp1_sharegpt_qps_1 | 156.526018 | 1.097396 | +| 1 | serving_llama8B_tp1_sharegpt_qps_16 | 241.620334 | serving_llama8B_tp1_sharegpt_qps_16 | 294.018783 | 1.216863 | +| 2 | serving_llama8B_tp1_sharegpt_qps_4 | 218.298905 | serving_llama8B_tp1_sharegpt_qps_4 | 262.664916 | 1.203235 | +| 3 | serving_llama8B_tp1_sharegpt_qps_inf | 242.743860 | serving_llama8B_tp1_sharegpt_qps_inf | 299.816190 | 1.235113 | +| 4 | serving_llama8B_tp2_random_1024_128_qps_1 | 96.613390 | serving_llama8B_tp4_random_1024_128_qps_1 | 108.404853 | 1.122048 | + ## Nightly test details See [nightly-descriptions.md](nightly-descriptions.md) for the detailed description on test workload, models and docker containers of benchmarking other llm engines. diff --git a/.buildkite/nightly-benchmarks/nightly-annotation.md b/.buildkite/nightly-benchmarks/nightly-annotation.md index e43ea765f1556125db422bd18007fff5fa0a17f6..ef11c040057c8d218464e792c7af838871ea8393 100644 --- a/.buildkite/nightly-benchmarks/nightly-annotation.md +++ b/.buildkite/nightly-benchmarks/nightly-annotation.md @@ -16,7 +16,7 @@ Please download the visualization scripts in the post - Download `nightly-benchmarks.zip`. - In the same folder, run the following code: - ```console + ```bash export HF_TOKEN= apt update apt install -y git diff --git a/.buildkite/nightly-benchmarks/performance-benchmarks-descriptions.md b/.buildkite/nightly-benchmarks/performance-benchmarks-descriptions.md index cacaef986c658ae4e4145ad738589a8c229efe95..a1f8441ccdac8a11586adbc4da4b99e4d77183f6 100644 --- a/.buildkite/nightly-benchmarks/performance-benchmarks-descriptions.md +++ b/.buildkite/nightly-benchmarks/performance-benchmarks-descriptions.md @@ -4,7 +4,8 @@ - Input length: 32 tokens. - Output length: 128 tokens. - Batch size: fixed (8). -- Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B. +- GPU Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B. +- CPU Models: llama-3.1 8B. - Evaluation metrics: end-to-end latency (mean, median, p99). {latency_tests_markdown_table} @@ -14,7 +15,8 @@ - Input length: randomly sample 200 prompts from ShareGPT dataset (with fixed random seed). - Output length: the corresponding output length of these 200 prompts. - Batch size: dynamically determined by vllm to achieve maximum throughput. -- Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B. +- GPU Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B. +- CPU Models: llama-3.1 8B. - Evaluation metrics: throughput. {throughput_tests_markdown_table} @@ -25,12 +27,18 @@ - Output length: the corresponding output length of these 200 prompts. - Batch size: dynamically determined by vllm and the arrival pattern of the requests. - **Average QPS (query per second)**: 1, 4, 16 and inf. QPS = inf means all requests come at once. For other QPS values, the arrival time of each query is determined using a random Poisson process (with fixed random seed). -- Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B. -- We also added a speculative decoding test for llama-3 70B, under QPS 2 +- GPU Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B. +- We also added a speculative decoding test for llama-3 70B on GPU, under QPS 2 +- CPU Models: llama-3.1 8B. - Evaluation metrics: throughput, TTFT (time to the first token, with mean, median and p99), ITL (inter-token latency, with mean, median and p99). +- For CPU, we added random dataset tests to benchmark fixed input/output length with 100 prompts. {serving_tests_markdown_table} +## Platform Information + +{platform_markdown_table} + ## json version of the benchmarking tables This section contains the data of the markdown tables above in JSON format. diff --git a/.buildkite/nightly-benchmarks/scripts/compare-json-results.py b/.buildkite/nightly-benchmarks/scripts/compare-json-results.py new file mode 100644 index 0000000000000000000000000000000000000000..20c106234935c30f05e62b2b59896e7f1b84df81 --- /dev/null +++ b/.buildkite/nightly-benchmarks/scripts/compare-json-results.py @@ -0,0 +1,66 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import argparse + +import pandas as pd + + +def compare_data_columns( + files, name_column, data_column, drop_column, ignore_test_name=False +): + print("\ncompare_data_column: " + data_column) + frames = [] + compare_frames = [] + for file in files: + data_df = pd.read_json(file) + serving_df = data_df.dropna(subset=[drop_column], ignore_index=True) + if ignore_test_name is False: + serving_df = serving_df.rename(columns={name_column: file + "_name"}) + frames.append(serving_df[file + "_name"]) + serving_df = serving_df.rename(columns={data_column: file}) + frames.append(serving_df[file]) + compare_frames.append(serving_df[file]) + if len(compare_frames) >= 2: + # Compare numbers among two files + ratio_df = compare_frames[1] / compare_frames[0] + frames.append(ratio_df) + compare_frames.pop(1) + + concat_df = pd.concat(frames, axis=1) + return concat_df + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "-f", "--file", action="append", type=str, help="input file name" + ) + parser.add_argument( + "--ignore_test_name", action="store_true", help="ignore_test_name or not" + ) + args = parser.parse_args() + files = args.file + print("comparing : " + ", ".join(files)) + + drop_column = "P99" + name_column = "Test name" + data_cols_to_compare = ["Output Tput (tok/s)", "Median TTFT (ms)", "Median"] + html_msgs_for_data_cols = [ + "Compare Output Tokens /n", + "Median TTFT /n", + "Median TPOT /n", + ] + ignore_test_name = args.ignore_test_name + with open("perf_comparison.html", "w") as text_file: + for i in range(len(data_cols_to_compare)): + output_df = compare_data_columns( + files, + name_column, + data_cols_to_compare[i], + drop_column, + ignore_test_name=ignore_test_name, + ) + print(output_df) + html = output_df.to_html() + text_file.write(html_msgs_for_data_cols[i]) + text_file.write(html) diff --git a/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py b/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py index a4f1638c1adb8db7336960f9e227b23b054181a8..724b53056ca8fc93a1a0c576b0d7ac9cf060cb14 100644 --- a/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py +++ b/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py @@ -3,9 +3,11 @@ import json import os +from importlib import util from pathlib import Path import pandas as pd +import psutil from tabulate import tabulate results_folder = Path("results/") @@ -29,11 +31,11 @@ throughput_results = [] throughput_results_column_mapping = { "test_name": "Test name", "gpu_type": "GPU", - # "num_requests": "# of req.", - # "total_num_tokens": "Total # of tokens", - # "elapsed_time": "Elapsed time (s)", + "num_requests": "# of req.", + "total_num_tokens": "Total # of tokens", + "elapsed_time": "Elapsed time (s)", "requests_per_second": "Tput (req/s)", - # "tokens_per_second": "Tput (tok/s)", + "tokens_per_second": "Tput (tok/s)", } # serving results and the keys that will be printed into markdown @@ -41,16 +43,18 @@ serving_results = [] serving_column_mapping = { "test_name": "Test name", "gpu_type": "GPU", - # "completed": "# of req.", + "completed": "# of req.", "request_throughput": "Tput (req/s)", - # "input_throughput": "Input Tput (tok/s)", - # "output_throughput": "Output Tput (tok/s)", + "total_token_throughput": "Total Token Tput (tok/s)", + "output_throughput": "Output Tput (tok/s)", + "total_input_tokens": "Total input tokens", + "total_output_tokens": "Total output tokens", "mean_ttft_ms": "Mean TTFT (ms)", "median_ttft_ms": "Median TTFT (ms)", "p99_ttft_ms": "P99 TTFT (ms)", - # "mean_tpot_ms": "Mean TPOT (ms)", - # "median_tpot_ms": "Median", - # "p99_tpot_ms": "P99", + "mean_tpot_ms": "Mean TPOT (ms)", + "median_tpot_ms": "Median", + "p99_tpot_ms": "P99", "mean_itl_ms": "Mean ITL (ms)", "median_itl_ms": "Median ITL (ms)", "p99_itl_ms": "P99 ITL (ms)", @@ -75,6 +79,20 @@ def results_to_json(latency, throughput, serving): ) +def get_size_with_unit(bytes, suffix="B"): + """ + Scale bytes to its proper format + e.g: + 1253656 => '1.20MB' + 1253656678 => '1.17GB' + """ + factor = 1024 + for unit in ["", "K", "M", "G", "T", "P"]: + if bytes < factor: + return f"{bytes:.2f}{unit}{suffix}" + bytes /= factor + + if __name__ == "__main__": # collect results for test_file in results_folder.glob("*.json"): @@ -155,6 +173,27 @@ if __name__ == "__main__": serving_results = pd.DataFrame.from_dict(serving_results) throughput_results = pd.DataFrame.from_dict(throughput_results) + svmem = psutil.virtual_memory() + platform_data = { + "Physical cores": [psutil.cpu_count(logical=False)], + "Total cores": [psutil.cpu_count(logical=True)], + "Total Memory": [get_size_with_unit(svmem.total)], + } + + if util.find_spec("numa") is not None: + from numa import info + + platform_data["Total NUMA nodes"] = [info.get_num_configured_nodes()] + + if util.find_spec("cpuinfo") is not None: + from cpuinfo import get_cpu_info + + platform_data["CPU Brand"] = [get_cpu_info()["brand_raw"]] + + platform_results = pd.DataFrame.from_dict( + platform_data, orient="index", columns=["Platform Info"] + ) + raw_results_json = results_to_json( latency_results, throughput_results, serving_results ) @@ -200,6 +239,9 @@ if __name__ == "__main__": throughput_md_table = tabulate( throughput_results, headers="keys", tablefmt="pipe", showindex=False ) + platform_md_table = tabulate( + platform_results, headers="keys", tablefmt="pipe", showindex=True + ) # document the result with open(results_folder / "benchmark_results.md", "w") as f: @@ -211,6 +253,7 @@ if __name__ == "__main__": latency_tests_markdown_table=latency_md_table, throughput_tests_markdown_table=throughput_md_table, serving_tests_markdown_table=serving_md_table, + platform_markdown_table=platform_md_table, benchmarking_results_in_json_string=processed_results_json, ) f.write(results) diff --git a/.buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh b/.buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh index 80ebb370ad4615a954c2e80b56c49d64a19bf3a8..f05040618981cc29430936145d13dc2d8e028632 100644 --- a/.buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh +++ b/.buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh @@ -31,6 +31,20 @@ check_gpus() { echo "GPU type is $gpu_type" } +check_cpus() { + # check the number of CPUs and NUMA Node and GPU type. + declare -g numa_count=$(python3 -c "from numa import info;numa_size = info.get_num_configured_nodes(); print(numa_size)") + if [[ $numa_count -gt 0 ]]; then + echo "NUMA found." + echo $numa_count + else + echo "Need at least 1 NUMA to run benchmarking." + exit 1 + fi + declare -g gpu_type="cpu" + echo "GPU type is $gpu_type" +} + check_hf_token() { # check if HF_TOKEN is available and valid if [[ -z "$HF_TOKEN" ]]; then @@ -69,6 +83,22 @@ json2args() { echo "$args" } +json2envs() { + # transforms the JSON string to environment variables. + # example: + # input: { "VLLM_CPU_KVCACHE_SPACE": 5 } + # output: VLLM_CPU_KVCACHE_SPACE=5 + local json_string=$1 + local args=$( + echo "$json_string" | jq -r ' + to_entries | + map((.key ) + "=" + (.value | tostring)) | + join(" ") + ' + ) + echo "$args" +} + wait_for_server() { # wait for vllm server to start # return 1 if vllm server crashes @@ -158,15 +188,24 @@ run_latency_tests() { # get arguments latency_params=$(echo "$params" | jq -r '.parameters') latency_args=$(json2args "$latency_params") + latency_environment_variables=$(echo "$params" | jq -r '.environment_variables') + latency_envs=$(json2envs "$latency_environment_variables") # check if there is enough GPU to run the test tp=$(echo "$latency_params" | jq -r '.tensor_parallel_size') - if [[ $gpu_count -lt $tp ]]; then - echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name." - continue + if [ "$ON_CPU" == "1" ];then + if [[ $numa_count -lt $tp ]]; then + echo "Required tensor-parallel-size $tp but only $numa_count NUMA nodes found. Skip testcase $test_name." + continue + fi + else + if [[ $gpu_count -lt $tp ]]; then + echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name." + continue + fi fi - latency_command="python3 benchmark_latency.py \ + latency_command=" $latency_envs python3 benchmark_latency.py \ --output-json $RESULTS_FOLDER/${test_name}.json \ $latency_args" @@ -216,15 +255,24 @@ run_throughput_tests() { # get arguments throughput_params=$(echo "$params" | jq -r '.parameters') throughput_args=$(json2args "$throughput_params") + throughput_environment_variables=$(echo "$params" | jq -r '.environment_variables') + throughput_envs=$(json2envs "$throughput_environment_variables") # check if there is enough GPU to run the test tp=$(echo "$throughput_params" | jq -r '.tensor_parallel_size') - if [[ $gpu_count -lt $tp ]]; then - echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name." - continue + if [ "$ON_CPU" == "1" ];then + if [[ $numa_count -lt $tp ]]; then + echo "Required tensor-parallel-size $tp but only $numa_count NUMA nodes found. Skip testcase $test_name." + continue + fi + else + if [[ $gpu_count -lt $tp ]]; then + echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name." + continue + fi fi - throughput_command="python3 benchmark_throughput.py \ + throughput_command=" $throughput_envs python3 benchmark_throughput.py \ --output-json $RESULTS_FOLDER/${test_name}.json \ $throughput_args" @@ -272,18 +320,27 @@ run_serving_tests() { # get client and server arguments server_params=$(echo "$params" | jq -r '.server_parameters') + server_envs=$(echo "$params" | jq -r '.server_environment_variables') client_params=$(echo "$params" | jq -r '.client_parameters') server_args=$(json2args "$server_params") + server_envs=$(json2envs "$server_envs") client_args=$(json2args "$client_params") qps_list=$(echo "$params" | jq -r '.qps_list') qps_list=$(echo "$qps_list" | jq -r '.[] | @sh') echo "Running over qps list $qps_list" - # check if there is enough GPU to run the test + # check if there is enough resources to run the test tp=$(echo "$server_params" | jq -r '.tensor_parallel_size') - if [[ $gpu_count -lt $tp ]]; then - echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name." - continue + if [ "$ON_CPU" == "1" ];then + if [[ $numa_count -lt $tp ]]; then + echo "Required tensor-parallel-size $tp but only $numa_count NUMA nodes found. Skip testcase $test_name." + continue + fi + else + if [[ $gpu_count -lt $tp ]]; then + echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name." + continue + fi fi # check if server model and client model is aligned @@ -294,23 +351,33 @@ run_serving_tests() { continue fi - server_command="python3 \ + server_command="$server_envs python3 \ -m vllm.entrypoints.openai.api_server \ $server_args" # run the server echo "Running test case $test_name" echo "Server command: $server_command" - bash -c "$server_command" & - server_pid=$! - - # wait until the server is alive - if wait_for_server; then - echo "" - echo "vllm server is up and running." + # support remote vllm server + client_remote_args="" + if [[ -z "${REMOTE_HOST}" ]]; then + bash -c "$server_command" & + server_pid=$! + # wait until the server is alive + if wait_for_server; then + echo "" + echo "vLLM server is up and running." + else + echo "" + echo "vLLM failed to start within the timeout period." + fi else - echo "" - echo "vllm failed to start within the timeout period." + server_command="Using Remote Server $REMOTE_HOST $REMOTE_PORT" + if [[ ${REMOTE_PORT} ]]; then + client_remote_args=" --host=$REMOTE_HOST --port=$REMOTE_PORT " + else + client_remote_args=" --host=$REMOTE_HOST " + fi fi # iterate over different QPS @@ -332,7 +399,7 @@ run_serving_tests() { --result-filename ${new_test_name}.json \ --request-rate $qps \ --metadata "tensor_parallel_size=$tp" \ - $client_args" + $client_args $client_remote_args " echo "Running test case $test_name with qps $qps" echo "Client command: $client_command" @@ -360,7 +427,14 @@ run_serving_tests() { } main() { - check_gpus + local ARCH + ARCH='' + if [ "$ON_CPU" == "1" ];then + check_cpus + ARCH='-cpu' + else + check_gpus + fi check_hf_token # Set to v1 to run v1 benchmark @@ -386,9 +460,9 @@ main() { QUICK_BENCHMARK_ROOT=../.buildkite/nightly-benchmarks/ # benchmarking - run_serving_tests $QUICK_BENCHMARK_ROOT/tests/serving-tests.json - run_latency_tests $QUICK_BENCHMARK_ROOT/tests/latency-tests.json - run_throughput_tests $QUICK_BENCHMARK_ROOT/tests/throughput-tests.json + run_serving_tests $QUICK_BENCHMARK_ROOT/tests/"${SERVING_JSON:-serving-tests$ARCH.json}" + run_latency_tests $QUICK_BENCHMARK_ROOT/tests/"${LATENCY_JSON:-latency-tests$ARCH.json}" + run_throughput_tests $QUICK_BENCHMARK_ROOT/tests/"${THROUGHPUT_JSON:-throughput-tests$ARCH.json}" # postprocess benchmarking results pip install tabulate pandas diff --git a/.buildkite/nightly-benchmarks/tests/latency-tests-cpu.json b/.buildkite/nightly-benchmarks/tests/latency-tests-cpu.json new file mode 100644 index 0000000000000000000000000000000000000000..da93fdd1dbac1f496845bf931f69a016960317ef --- /dev/null +++ b/.buildkite/nightly-benchmarks/tests/latency-tests-cpu.json @@ -0,0 +1,30 @@ +[ + { + "test_name": "latency_llama8B_tp1", + "environment_variables": { + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "parameters": { + "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "tensor_parallel_size": 1, + "load_format": "dummy", + "num_iters_warmup": 5, + "num_iters": 15 + } + }, + { + "test_name": "latency_llama8B_tp4", + "environment_variables": { + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "parameters": { + "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "tensor_parallel_size": 4, + "load_format": "dummy", + "num_iters_warmup": 5, + "num_iters": 15 + } + } +] diff --git a/.buildkite/nightly-benchmarks/tests/serving-tests-cpu.json b/.buildkite/nightly-benchmarks/tests/serving-tests-cpu.json new file mode 100644 index 0000000000000000000000000000000000000000..22f71c993ff3345189d5affb3548062c0720a2de --- /dev/null +++ b/.buildkite/nightly-benchmarks/tests/serving-tests-cpu.json @@ -0,0 +1,158 @@ +[ + { + "test_name": "serving_llama8B_tp1_sharegpt", + "qps_list": [1, 4, 16, "inf"], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "tensor_parallel_size": 1, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "disable_log_stats": "", + "disable_log_requests": "", + "enforce_eager": "", + "load_format": "dummy" + }, + "client_parameters": { + "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "backend": "vllm", + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "max_concurrency": 60, + "num_prompts": 200 + } + }, + { + "test_name": "serving_llama8B_tp2_sharegpt", + "qps_list": [1, 4, 16, "inf"], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "tensor_parallel_size": 2, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "disable_log_stats": "", + "disable_log_requests": "", + "enforce_eager": "", + "load_format": "dummy" + }, + "client_parameters": { + "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "backend": "vllm", + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "max_concurrency": 60, + "num_prompts": 200 + } + }, + { + "test_name": "serving_llama8B_tp4_sharegpt", + "qps_list": [1, 4, 16, "inf"], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "tensor_parallel_size": 4, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "disable_log_stats": "", + "disable_log_requests": "", + "enforce_eager": "", + "load_format": "dummy" + }, + "client_parameters": { + "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "backend": "vllm", + "dataset_name": "sharegpt", + "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "max_concurrency": 60, + "num_prompts": 200 + } + }, + { + "test_name": "serving_llama8B_tp4_random_1024_128", + "qps_list": [1, 4, 16, "inf"], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "tensor_parallel_size": 4, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "enable_chunked_prefill": "", + "disable_log_stats": "", + "disable_log_requests": "", + "enforce_eager": "", + "load_format": "dummy" + }, + "client_parameters": { + "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "backend": "vllm", + "dataset_name": "random", + "random-input-len": 1024, + "random-output-len": 128, + "ignore-eos": "", + "max_concurrency": 100, + "num_prompts": 100 + } + }, + { + "test_name": "serving_llama8B_pp6_random_1024_128", + "qps_list": [1, 4, 16, "inf"], + "server_environment_variables": { + "VLLM_RPC_TIMEOUT": 100000, + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_ENGINE_ITERATION_TIMEOUT_S": 120, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "server_parameters": { + "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "pipeline_parallel_size": 6, + "dtype": "bfloat16", + "distributed_executor_backend": "mp", + "block_size": 128, + "trust_remote_code": "", + "enable_chunked_prefill": "", + "disable_log_stats": "", + "disable_log_requests": "", + "enforce_eager": "", + "load_format": "dummy" + }, + "client_parameters": { + "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "backend": "vllm", + "dataset_name": "random", + "random-input-len": 1024, + "random-output-len": 128, + "ignore-eos": "", + "max_concurrency": 100, + "num_prompts": 100 + } + } +] diff --git a/.buildkite/nightly-benchmarks/tests/throughput-tests-cpu.json b/.buildkite/nightly-benchmarks/tests/throughput-tests-cpu.json new file mode 100644 index 0000000000000000000000000000000000000000..f159c30637d349c4d446fe48569bf5440721e1ad --- /dev/null +++ b/.buildkite/nightly-benchmarks/tests/throughput-tests-cpu.json @@ -0,0 +1,32 @@ +[ + { + "test_name": "throughput_llama8B_tp1", + "environment_variables": { + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "parameters": { + "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "tensor_parallel_size": 1, + "load_format": "dummy", + "dataset": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200, + "backend": "vllm" + } + }, + { + "test_name": "throughput_llama8B_tp4", + "environment_variables": { + "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1, + "VLLM_CPU_KVCACHE_SPACE": 40 + }, + "parameters": { + "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", + "tensor_parallel_size": 4, + "load_format": "dummy", + "dataset": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "num_prompts": 200, + "backend": "vllm" + } + } +] diff --git a/.buildkite/release-pipeline.yaml b/.buildkite/release-pipeline.yaml index 16b5ad0297fe79ced4bbb25f7a1ced0c58425c22..6314afd652340352056e2770c4ee6dd2cbe121ab 100644 --- a/.buildkite/release-pipeline.yaml +++ b/.buildkite/release-pipeline.yaml @@ -52,7 +52,7 @@ steps: queue: cpu_queue_postmerge commands: - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" - - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT --target vllm-openai --progress plain -f docker/Dockerfile ." + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --build-arg INSTALL_KV_CONNECTORS=true --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT --target vllm-openai --progress plain -f docker/Dockerfile ." - "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT" - label: "Annotate release workflow" @@ -101,7 +101,8 @@ steps: queue: cpu_queue_postmerge commands: - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" - - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version) --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:latest --progress plain --target vllm-openai -f docker/Dockerfile.cpu ." + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --build-arg VLLM_CPU_AVX512BF16=true --build-arg VLLM_CPU_AVX512VNNI=true --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version) --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:latest --progress plain --target vllm-openai -f docker/Dockerfile.cpu ." + - "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:latest" - "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version)" env: DOCKER_BUILDKIT: "1" @@ -117,6 +118,7 @@ steps: commands: - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --tag public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:$(buildkite-agent meta-data get release-version) --tag public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:latest --progress plain -f docker/Dockerfile.neuron ." + - "docker push public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:latest" - "docker push public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:$(buildkite-agent meta-data get release-version)" env: DOCKER_BUILDKIT: "1" diff --git a/.buildkite/scripts/hardware_ci/run-amd-test.sh b/.buildkite/scripts/hardware_ci/run-amd-test.sh index 6e9af1e721bb70cdf38914ea32f48156bbb7248f..156456c92e63cc92ee2b2dda446d987fb6190354 100755 --- a/.buildkite/scripts/hardware_ci/run-amd-test.sh +++ b/.buildkite/scripts/hardware_ci/run-amd-test.sh @@ -107,10 +107,9 @@ fi if [[ $commands == *" kernels/attention"* ]]; then commands="${commands} \ - --ignore=kernels/attention/stest_attention_selector.py \ + --ignore=kernels/attention/test_attention_selector.py \ --ignore=kernels/attention/test_blocksparse_attention.py \ --ignore=kernels/attention/test_encoder_decoder_attn.py \ - --ignore=kernels/attention/test_attention_selector.py \ --ignore=kernels/attention/test_flash_attn.py \ --ignore=kernels/attention/test_flashinfer.py \ --ignore=kernels/attention/test_prefix_prefill.py \ diff --git a/.buildkite/scripts/hardware_ci/run-cpu-test.sh b/.buildkite/scripts/hardware_ci/run-cpu-test.sh index bbcde4009c0eb9e9f30ed14bb4b7cfc7a25ea0a2..42506730e868cc601d110686b1c447fc3f083305 100644 --- a/.buildkite/scripts/hardware_ci/run-cpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-cpu-test.sh @@ -24,13 +24,22 @@ numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --tag cpu-test-"$NUMA_NODE numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" --tag cpu-test-"$NUMA_NODE"-avx2 --target vllm-test -f docker/Dockerfile.cpu . # Run the image, setting --shm-size=4g for tensor parallel. -docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_OMP_THREADS_BIND="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE" -docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_OMP_THREADS_BIND="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE"-avx2 cpu-test-"$NUMA_NODE"-avx2 +docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_OMP_THREADS_BIND="$OMP_CORE_RANGE" --env VLLM_CPU_CI_ENV=1 --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE" +docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_OMP_THREADS_BIND="$OMP_CORE_RANGE" --env VLLM_CPU_CI_ENV=1 --shm-size=4g --name cpu-test-"$NUMA_NODE"-avx2 cpu-test-"$NUMA_NODE"-avx2 function cpu_tests() { set -e export NUMA_NODE=$2 + # list packages + docker exec cpu-test-"$NUMA_NODE"-avx2 bash -c " + set -e + pip list" + + docker exec cpu-test-"$NUMA_NODE" bash -c " + set -e + pip list" + # offline inference docker exec cpu-test-"$NUMA_NODE"-avx2 bash -c " set -e @@ -42,6 +51,7 @@ function cpu_tests() { pytest -v -s tests/kernels/attention/test_cache.py -m cpu_model pytest -v -s tests/kernels/attention/test_mla_decode_cpu.py -m cpu_model pytest -v -s tests/models/language/generation -m cpu_model + VLLM_CPU_SGL_KERNEL=1 pytest -v -s tests/models/language/generation -m cpu_model pytest -v -s tests/models/language/pooling -m cpu_model pytest -v -s tests/models/multimodal/generation \ --ignore=tests/models/multimodal/generation/test_mllama.py \ @@ -72,7 +82,7 @@ function cpu_tests() { set -e python3 -m vllm.entrypoints.openai.api_server --model facebook/opt-125m --dtype half & timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1 - python3 benchmarks/benchmark_serving.py \ + VLLM_CPU_CI_ENV=0 python3 benchmarks/benchmark_serving.py \ --backend vllm \ --dataset-name random \ --model facebook/opt-125m \ @@ -89,4 +99,4 @@ function cpu_tests() { # All of CPU tests are expected to be finished less than 40 mins. export -f cpu_tests -timeout 1h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE" +timeout 1.5h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE" diff --git a/.buildkite/scripts/hardware_ci/run-hpu-test.sh b/.buildkite/scripts/hardware_ci/run-hpu-test.sh index 5efac3ddf469f244ee3889e97482849d7c16feb8..ae5b35a9ac6bd37290b415ac2382c3ad13551b16 100644 --- a/.buildkite/scripts/hardware_ci/run-hpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-hpu-test.sh @@ -2,10 +2,34 @@ # This script build the CPU docker image and run the offline inference inside the container. # It serves a sanity check for compilation and basic model usage. -set -ex +set -exuo pipefail # Try building the docker image -docker build -t hpu-test-env -f docker/Dockerfile.hpu . +cat <&2 +fi + +# The trap will handle the container removal and final exit. \ No newline at end of file diff --git a/.buildkite/scripts/hardware_ci/run-neuron-test.sh b/.buildkite/scripts/hardware_ci/run-neuron-test.sh index 3d294ea5f8a755f4efc57c76eacafd3d8ba9a94b..a397457c83261757a6b1a2812c032be4f43558da 100644 --- a/.buildkite/scripts/hardware_ci/run-neuron-test.sh +++ b/.buildkite/scripts/hardware_ci/run-neuron-test.sh @@ -54,10 +54,11 @@ docker run --rm -it --device=/dev/neuron0 --network bridge \ --name "${container_name}" \ ${image_name} \ /bin/bash -c " + set -e; # Exit on first error python3 /workspace/vllm/examples/offline_inference/neuron.py; python3 -m pytest /workspace/vllm/tests/neuron/1_core/ -v --capture=tee-sys; for f in /workspace/vllm/tests/neuron/2_core/*.py; do - echo 'Running test file: '$f; + echo \"Running test file: \$f\"; python3 -m pytest \$f -v --capture=tee-sys; done " \ No newline at end of file diff --git a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh index a2a5c2a02cbb9776eda2a42fa4232861bb82896c..90cad506ab1e90ed9e21d5c8500233b44ec3f847 100755 --- a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh +++ b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh @@ -159,6 +159,8 @@ run_and_track_test 14 "test_tpu_qkv_linear.py" \ "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_tpu_qkv_linear.py" run_and_track_test 15 "test_spmd_model_weight_loading.py" \ "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_spmd_model_weight_loading.py" +run_and_track_test 16 "test_kv_cache_update_kernel.py" \ + "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_kv_cache_update_kernel.py" # After all tests have been attempted, exit with the overall status. if [ "$overall_script_exit_code" -ne 0 ]; then diff --git a/.buildkite/scripts/hardware_ci/run-xpu-test.sh b/.buildkite/scripts/hardware_ci/run-xpu-test.sh index f54010c4231f959c94e3e1c4aa2a4971ba3e6f5e..827649bfcf5487db8b105dcf81145e869b41907b 100644 --- a/.buildkite/scripts/hardware_ci/run-xpu-test.sh +++ b/.buildkite/scripts/hardware_ci/run-xpu-test.sh @@ -28,4 +28,5 @@ docker run \ sh -c ' VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m -tp 2 + VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager ' diff --git a/.buildkite/scripts/tpu/config_v6e_1.env b/.buildkite/scripts/tpu/config_v6e_1.env index 44175864734746458fd1c0c69bf3c5d907f71848..03ec116f698d2d416f58ded48e45510529a0bd5d 100644 --- a/.buildkite/scripts/tpu/config_v6e_1.env +++ b/.buildkite/scripts/tpu/config_v6e_1.env @@ -4,8 +4,8 @@ CONTAINER_NAME=vllm-tpu # vllm config MODEL=meta-llama/Llama-3.1-8B-Instruct -MAX_NUM_SEQS=512 -MAX_NUM_BATCHED_TOKENS=512 +MAX_NUM_SEQS=256 +MAX_NUM_BATCHED_TOKENS=1024 TENSOR_PARALLEL_SIZE=1 MAX_MODEL_LEN=2048 DOWNLOAD_DIR=/mnt/disks/persist diff --git a/.buildkite/scripts/tpu/docker_run_bm.sh b/.buildkite/scripts/tpu/docker_run_bm.sh index 6705da03e3d761baf4d8849bba5444cc0e9c7c6b..715afce5f71ab7cbbbff6e614482858022029969 100755 --- a/.buildkite/scripts/tpu/docker_run_bm.sh +++ b/.buildkite/scripts/tpu/docker_run_bm.sh @@ -68,7 +68,7 @@ docker run \ echo "run script..." echo -docker exec "$CONTAINER_NAME" /bin/bash -c ".buildkite/scripts/hardware_ci/run_bm.sh" +docker exec "$CONTAINER_NAME" /bin/bash -c ".buildkite/scripts/tpu/run_bm.sh" echo "copy result back..." VLLM_LOG="$LOG_ROOT/$TEST_NAME"_vllm_log.txt diff --git a/.buildkite/scripts/tpu/quantized_v6e_1.env b/.buildkite/scripts/tpu/quantized_v6e_1.env new file mode 100644 index 0000000000000000000000000000000000000000..bab34b3be3b9a01cb7dd19120914224f446cf386 --- /dev/null +++ b/.buildkite/scripts/tpu/quantized_v6e_1.env @@ -0,0 +1,14 @@ +# Environment config +TEST_NAME=llama8bw8a8 +CONTAINER_NAME=vllm-tpu + +# vllm config +MODEL=RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8 +MAX_NUM_SEQS=128 +MAX_NUM_BATCHED_TOKENS=1024 +TENSOR_PARALLEL_SIZE=1 +MAX_MODEL_LEN=2048 +DOWNLOAD_DIR=/mnt/disks/persist +EXPECTED_THROUGHPUT=10.0 +INPUT_LEN=1800 +OUTPUT_LEN=128 diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index b739851cb90528b0f1b7feab14b14a6c4ded0802..148cf8074232fbfb8db4e22555d2a8c10c432dc5 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -41,6 +41,16 @@ steps: # TODO: add `--strict` once warnings in docstrings are fixed - mkdocs build +- label: Pytorch Nightly Dependency Override Check # 2min + # if this test fails, it means the nightly torch version is not compatible with some + # of the dependencies. Please check the error message and add the package to whitelist + # in /vllm/tools/generate_nightly_torch_test.py + soft_fail: true + source_file_dependencies: + - requirements/nightly_torch_test.txt + commands: + - bash standalone_tests/pytorch_nightly_dependency.sh + - label: Async Engine, Inputs, Utils, Worker Test # 24min mirror_hardwares: [amdexperimental] source_file_dependencies: @@ -89,7 +99,7 @@ steps: - VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py - label: Chunked Prefill Test - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] source_file_dependencies: - vllm/ - tests/basic_correctness/test_chunked_prefill @@ -145,6 +155,7 @@ steps: - examples/offline_inference/rlhf_colocate.py - tests/examples/offline_inference/data_parallel.py - tests/v1/test_async_llm_dp.py + - tests/v1/test_external_lb_dp.py - tests/v1/engine/test_engine_core_client.py commands: # test with tp=2 and external_dp=2 @@ -153,8 +164,9 @@ steps: # test with tp=2 and pp=2 - PP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py # test with internal dp - - python3 ../examples/offline_inference/data_parallel.py + - python3 ../examples/offline_inference/data_parallel.py --enforce-eager - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py + - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_external_lb_dp.py - pytest -v -s v1/engine/test_engine_core_client.py::test_kv_cache_events_dp - pytest -v -s distributed/test_utils.py - pytest -v -s compile/test_basic_correctness.py @@ -168,6 +180,23 @@ steps: - VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py - popd +- label: EPLB Algorithm Test + working_dir: "/vllm-workspace/tests" + source_file_dependencies: + - vllm/distributed/eplb + - tests/distributed/test_eplb_algo.py + commands: + - pytest -v -s distributed/test_eplb_algo.py + +- label: EPLB Execution Test # 5min + working_dir: "/vllm-workspace/tests" + num_gpus: 4 + source_file_dependencies: + - vllm/distributed/eplb + - tests/distributed/test_eplb_execute.py + commands: + - pytest -v -s distributed/test_eplb_execute.py + - label: Metrics, Tracing Test # 10min mirror_hardwares: [amdexperimental, amdproduction] num_gpus: 2 @@ -177,13 +206,18 @@ steps: - tests/tracing commands: - pytest -v -s metrics + - "pip install \ + 'opentelemetry-sdk>=1.26.0' \ + 'opentelemetry-api>=1.26.0' \ + 'opentelemetry-exporter-otlp>=1.26.0' \ + 'opentelemetry-semantic-conventions-ai>=0.4.1'" - pytest -v -s tracing ##### fast check tests ##### ##### 1 GPU test ##### - label: Regression Test # 5min - mirror_hardwares: [amdexperimental, amdproduction] + mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/ - tests/test_regression @@ -193,7 +227,7 @@ steps: working_dir: "/vllm-workspace/tests" # optional - label: Engine Test # 10min - mirror_hardwares: [amdexperimental, amdproduction] + mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/ - tests/engine @@ -266,6 +300,15 @@ steps: commands: - pytest -v -s prefix_caching + +- label: Platform Tests (CUDA) + mirror_hardwares: [amdexperimental] + source_file_dependencies: + - vllm/ + - tests/cuda + commands: + - pytest -v -s cuda/test_cuda_context.py + - label: Samplers Test # 36min mirror_hardwares: [amdexperimental] source_file_dependencies: @@ -297,7 +340,7 @@ steps: parallelism: 4 - label: PyTorch Compilation Unit Tests - mirror_hardwares: [amdexperimental, amdproduction] + mirror_hardwares: [amdexperimental] torch_nightly: true source_file_dependencies: - vllm/ @@ -305,6 +348,7 @@ steps: commands: - pytest -v -s compile/test_pass_manager.py - pytest -v -s compile/test_fusion.py + - pytest -v -s compile/test_fusion_attn.py - pytest -v -s compile/test_silu_mul_quant_fusion.py - pytest -v -s compile/test_sequence_parallelism.py - pytest -v -s compile/test_async_tp.py @@ -378,7 +422,7 @@ steps: - pytest -v -s kernels/mamba - label: Tensorizer Test # 11min - mirror_hardwares: [amdexperimental, amdproduction] + mirror_hardwares: [amdexperimental] soft_fail: true source_file_dependencies: - vllm/model_executor/model_loader @@ -470,7 +514,7 @@ steps: ##### models test ##### - label: Basic Models Test # 24min - mirror_hardwares: [amdexperimental, amdproduction] + mirror_hardwares: [amdexperimental] torch_nightly: true source_file_dependencies: - vllm/ @@ -494,6 +538,17 @@ steps: - pip freeze | grep -E 'torch' - pytest -v -s models/language -m core_model +- label: Language Models Test (Hybrid) # 35 min + mirror_hardwares: [amdexperimental] + torch_nightly: true + source_file_dependencies: + - vllm/ + - tests/models/language/generation + commands: + # Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile. + - pip install 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.0.post8' + - pytest -v -s models/language/generation -m hybrid_model + - label: Language Models Test (Extended Generation) # 1hr20min mirror_hardwares: [amdexperimental] optional: true @@ -503,7 +558,7 @@ steps: commands: # Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile. - pip install 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.0.post8' - - pytest -v -s models/language/generation -m 'not core_model' + - pytest -v -s models/language/generation -m '(not core_model) and (not hybrid_model)' - label: Language Models Test (Extended Pooling) # 36min mirror_hardwares: [amdexperimental] @@ -548,7 +603,7 @@ steps: - pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=0) and not core_model' - label: Multi-Modal Models Test (Extended) 3 - mirror_hardwares: [amdexperimental, amdproduction] + mirror_hardwares: [amdexperimental] optional: true source_file_dependencies: - vllm/ @@ -600,13 +655,18 @@ steps: - vllm/executor/ - vllm/model_executor/models/ - tests/distributed/ + - tests/examples/offline_inference/data_parallel.py commands: - # the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up) - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed' + - NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 'Node count test passed' + - python3 ../examples/offline_inference/data_parallel.py --dp-size=2 --tp-size=1 --node-size=2 --node-rank=0 --master-addr=192.168.10.10 --master-port=12345 --enforce-eager --trust-remote-code - VLLM_MULTI_NODE=1 pytest -v -s distributed/test_multi_node_assignment.py - VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py - # the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up) - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed' + - NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 'Node count test passed' + - python3 ../examples/offline_inference/data_parallel.py --dp-size=2 --tp-size=1 --node-size=2 --node-rank=1 --master-addr=192.168.10.10 --master-port=12345 --enforce-eager --trust-remote-code - label: Distributed Tests (2 GPUs) # 40min mirror_hardwares: [amdexperimental] @@ -624,10 +684,12 @@ steps: - vllm/worker/model_runner.py - entrypoints/llm/test_collective_rpc.py - tests/v1/test_async_llm_dp.py + - tests/v1/test_external_lb_dp.py - tests/v1/entrypoints/openai/test_multi_api_servers.py - vllm/v1/engine/ commands: - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py + - TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_external_lb_dp.py - DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py - pytest -v -s entrypoints/llm/test_collective_rpc.py - pytest -v -s ./compile/test_basic_correctness.py @@ -669,7 +731,7 @@ steps: - pytest -v -s plugins/lora_resolvers # unit tests for in-tree lora resolver plugins - label: Multi-step Tests (4 GPUs) # 36min - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] working_dir: "/vllm-workspace/tests" num_gpus: 4 source_file_dependencies: @@ -730,7 +792,7 @@ steps: - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models.txt - label: Weight Loading Multiple GPU Test - Large Models # optional - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" num_gpus: 2 gpu: a100 diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index e98ccd035ee90997f72cb1ba748b34e6e54fcdac..2acb03d52a67cc2d478545d86f094adb22ef199a 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -16,7 +16,11 @@ /vllm/lora @jeejeelee /vllm/reasoning @aarnphm /vllm/entrypoints @aarnphm -CMakeLists.txt @tlrmchlsmth +CMakeLists.txt @tlrmchlsmth @LucasWilkinson + +# Any change to the VllmConfig changes can have a large user-facing impact, +# so spam a lot of people +/vllm/config.py @simon-mo @WoosukKwon @youkaichao @robertgshaw2-redhat @mgoin @tlrmchlsmth @houseroad @hmellor # vLLM V1 /vllm/v1 @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat diff --git a/.github/mergify.yml b/.github/mergify.yml index 5692bb5d363d8013e09dd26c9e1534dfd9871a5a..20f3be8304a81041c887dd844c291f0c76a7b079 100644 --- a/.github/mergify.yml +++ b/.github/mergify.yml @@ -27,6 +27,22 @@ pull_request_rules: add: - ci/build +- name: label-deepseek + description: Automatically apply deepseek label + conditions: + - or: + - files~=^examples/.*deepseek.*\.py + - files~=^tests/.*deepseek.*\.py + - files~=^vllm/entrypoints/openai/tool_parsers/.*deepseek.*\.py + - files~=^vllm/model_executor/models/.*deepseek.*\.py + - files~=^vllm/reasoning/.*deepseek.*\.py + - files~=^vllm/transformers_utils/.*deepseek.*\.py + - title~=(?i)DeepSeek + actions: + label: + add: + - deepseek + - name: label-frontend description: Automatically apply frontend label conditions: @@ -45,6 +61,7 @@ pull_request_rules: - files~=^vllm/entrypoints/openai/tool_parsers/llama.*\.py - files~=^vllm/model_executor/models/.*llama.*\.py - files~=^vllm/transformers_utils/configs/.*llama.*\.py + - title~=(?i)llama actions: label: add: @@ -57,14 +74,72 @@ pull_request_rules: - files~=^vllm/multimodal/ - files~=^tests/multimodal/ - files~=^tests/models/multimodal/ - - files~=^tests/models/*/audio_language/ - - files~=^tests/models/*/vision_language/ - files=tests/models/test_vision.py actions: label: add: - multi-modality +- name: label-new-model + description: Automatically apply new-model label + conditions: + - and: + - files~=^vllm/model_executor/models/ + - files=vllm/model_executor/models/registry.py + - files=tests/models/registry.py + - files=docs/models/supported_models.md + actions: + label: + add: + - new-model + +- name: label-performance + description: Automatically apply performance label + conditions: + - or: + - files~=^benchmarks/ + - files~=^vllm/benchmarks/ + - files~=^tests/benchmarks/ + - files~=^\.buildkite/nightly-benchmarks/ + actions: + label: + add: + - performance + +- name: label-qwen + description: Automatically apply qwen label + conditions: + - or: + - files~=^examples/.*qwen.*\.py + - files~=^tests/.*qwen.*\.py + - files~=^vllm/model_executor/models/.*qwen.*\.py + - files~=^vllm/reasoning/.*qwen.*\.py + - title~=(?i)Qwen + actions: + label: + add: + - qwen + +- name: label-rocm + description: Automatically apply rocm label + conditions: + - or: + - files~=^csrc/rocm/ + - files~=^docker/Dockerfile.rocm + - files~=^requirements/rocm.*\.txt + - files~=^vllm/attention/backends/rocm.*\.py + - files~=^vllm/attention/ops/rocm.*\.py + - files~=^vllm/model_executor/layers/fused_moe/rocm.*\.py + - files~=^vllm/v1/attention/backends/mla/rocm.*\.py + - files~=^tests/kernels/.*_rocm.*\.py + - files=vllm/platforms/rocm.py + - title~=(?i)AMD + - title~=(?i)ROCm + actions: + label: + add: + - rocm + - name: label-structured-output description: Automatically apply structured-output label conditions: @@ -92,8 +167,14 @@ pull_request_rules: conditions: - or: - files~=^vllm/spec_decode/ + - files~=^vllm/v1/spec_decode/ - files=vllm/model_executor/layers/spec_decode_base_sampler.py - files~=^tests/spec_decode/ + - files~=^tests/v1/spec_decode/ + - files~=^examples/.*(spec_decode|mlpspeculator|eagle|speculation).*\.py + - files~=^vllm/model_executor/models/.*eagle.*\.py + - files=vllm/model_executor/models/mlp_speculator.py + - files~=^vllm/transformers_utils/configs/(eagle|medusa|mlp_speculator)\.py actions: label: add: diff --git a/.github/workflows/lint-and-deploy.yaml b/.github/workflows/lint-and-deploy.yaml index 64011922ad82535803664331d667d55cf3283c02..74a7a3a3530f50ba2d30a1c75191f1b08af88184 100644 --- a/.github/workflows/lint-and-deploy.yaml +++ b/.github/workflows/lint-and-deploy.yaml @@ -68,7 +68,7 @@ jobs: export AWS_ACCESS_KEY_ID=minioadmin export AWS_SECRET_ACCESS_KEY=minioadmin sleep 30 && kubectl -n ns-vllm logs -f "$(kubectl -n ns-vllm get pods | awk '/deployment/ {print $1;exit}')" & - helm install --wait --wait-for-jobs --timeout 5m0s --debug --create-namespace --namespace=ns-vllm test-vllm examples/online_serving/chart-helm -f examples/online_serving/chart-helm/values.yaml --set secrets.s3endpoint=http://minio:9000 --set secrets.s3bucketname=testbucket --set secrets.s3accesskeyid=$AWS_ACCESS_KEY_ID --set secrets.s3accesskey=$AWS_SECRET_ACCESS_KEY --set resources.requests.cpu=1 --set resources.requests.memory=4Gi --set resources.limits.cpu=2 --set resources.limits.memory=5Gi --set image.env[0].name=VLLM_CPU_KVCACHE_SPACE --set image.env[1].name=VLLM_LOGGING_LEVEL --set-string image.env[0].value="1" --set-string image.env[1].value="DEBUG" --set-string extraInit.s3modelpath="opt-125m/" --set-string 'resources.limits.nvidia\.com/gpu=0' --set-string 'resources.requests.nvidia\.com/gpu=0' --set-string image.repository="vllm-cpu-env" + helm install --wait --wait-for-jobs --timeout 5m0s --debug --create-namespace --namespace=ns-vllm test-vllm examples/online_serving/chart-helm -f examples/online_serving/chart-helm/values.yaml --set secrets.s3endpoint=http://minio:9000 --set secrets.s3bucketname=testbucket --set secrets.s3accesskeyid=$AWS_ACCESS_KEY_ID --set secrets.s3accesskey=$AWS_SECRET_ACCESS_KEY --set resources.requests.cpu=1 --set resources.requests.memory=4Gi --set resources.limits.cpu=2 --set resources.limits.memory=5Gi --set image.env[0].name=VLLM_CPU_KVCACHE_SPACE --set image.env[1].name=VLLM_LOGGING_LEVEL --set image.env[2].name=VLLM_CPU_CI_ENV --set-string image.env[0].value="1" --set-string image.env[1].value="DEBUG" --set-string image.env[2].value="1" --set-string extraInit.s3modelpath="opt-125m/" --set-string 'resources.limits.nvidia\.com/gpu=0' --set-string 'resources.requests.nvidia\.com/gpu=0' --set-string image.repository="vllm-cpu-env" - name: curl test run: | diff --git a/.gitignore b/.gitignore index e49d1d6ba6191565feb0378871852a35e748780e..88a42a5c0f64419e1583293dd159f12e7f2d0c1d 100644 --- a/.gitignore +++ b/.gitignore @@ -200,5 +200,5 @@ benchmarks/**/*.json actionlint shellcheck*/ -# Ingore moe/marlin_moe gen code +# Ignore moe/marlin_moe gen code csrc/moe/marlin_moe_wna16/kernel_* diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a105b0e14c4aff2330af3139cbcc050cb01b2bcc..d962252eb3dd827ea6ac11a19f7e7f0d2e5a4212 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -20,12 +20,10 @@ repos: args: [--output-format, github, --fix] - id: ruff-format files: ^(.buildkite|benchmarks|examples)/.* -- repo: https://github.com/codespell-project/codespell - rev: v2.4.1 +- repo: https://github.com/crate-ci/typos + rev: v1.32.0 hooks: - - id: codespell - additional_dependencies: ['tomli'] - args: ['--toml', 'pyproject.toml'] + - id: typos - repo: https://github.com/PyCQA/isort rev: 6.0.1 hooks: @@ -55,6 +53,11 @@ repos: files: ^requirements/test\.(in|txt)$ - repo: local hooks: + - id: format-torch-nightly-test + name: reformat nightly_torch_test.txt to be in sync with test.in + language: python + entry: python tools/generate_nightly_torch_test.py + files: ^requirements/test\.(in|txt)$ - id: mypy-local name: Run mypy for local Python installation entry: tools/mypy.sh 0 "local" @@ -117,6 +120,11 @@ repos: entry: python tools/check_spdx_header.py language: python types: [python] + - id: check-root-lazy-imports + name: Check root lazy imports + entry: python tools/check_init_lazy_imports.py + language: python + types: [python] - id: check-filenames name: Check for spaces in all filenames entry: bash @@ -145,6 +153,20 @@ repos: types: [python] pass_filenames: false additional_dependencies: [regex] + - id: check-pickle-imports + name: Prevent new pickle/cloudpickle imports + entry: python tools/check_pickle_imports.py + language: python + types: [python] + pass_filenames: false + additional_dependencies: [pathspec, regex] + - id: validate-config + name: Validate configuration has default values and that each field has a docstring + entry: python tools/validate_config.py + language: python + types: [python] + pass_filenames: true + files: vllm/config.py|tests/test_config.py # Keep `suggestion` last - id: suggestion name: Suggestion diff --git a/CMakeLists.txt b/CMakeLists.txt index dbf0ca291dfb8acb28ea4d9542e9cc10dce191de..52539e474d9ccf041d14789bb006e56540127f28 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -260,7 +260,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library") # Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building. - set(CUTLASS_REVISION "v3.9.2" CACHE STRING "CUTLASS revision to use") + set(CUTLASS_REVISION "v4.0.0" CACHE STRING "CUTLASS revision to use") # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR}) @@ -421,9 +421,39 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") endif() endif() - # The cutlass_scaled_mm kernels for Blackwell (c3x, i.e. CUTLASS 3.x) require + + # The cutlass_scaled_mm kernels for Geforce Blackwell SM120 (c3x, i.e. CUTLASS 3.x) require # CUDA 12.8 or later - cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;12.0a" "${CUDA_ARCHS}") + cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0;12.0a" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND SCALED_MM_ARCHS) + set(SRCS + "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu" + "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8.cu" + ) + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${SCALED_MM_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM120=1") + # Let scaled_mm_c2x know it doesn't need to build these arches + list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}") + message(STATUS "Building scaled_mm_c3x_sm120 for archs: ${SCALED_MM_ARCHS}") + else() + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND SCALED_MM_ARCHS) + message(STATUS "Not building scaled_mm_c3x_sm120 as CUDA Compiler version is " + "not >= 12.8, we recommend upgrading to CUDA 12.8 or " + "later if you intend on running FP8 quantized models on " + "Blackwell.") + else() + message(STATUS "Not building scaled_mm_c3x_120 as no compatible archs found " + "in CUDA target architectures") + endif() + endif() + + + # The cutlass_scaled_mm kernels for Blackwell SM100 (c3x, i.e. CUTLASS 3.x) + # require CUDA 12.8 or later + cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a" "${CUDA_ARCHS}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND SCALED_MM_ARCHS) set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu" @@ -514,6 +544,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") CUDA_ARCHS "${FP4_ARCHS}") list(APPEND VLLM_EXT_SRC "${SRCS}") list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4=1") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1") message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}") else() message(STATUS "Not building NVFP4 as no compatible archs were found.") @@ -543,13 +574,12 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # CUTLASS MoE kernels - # The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and only works + # The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and ONLY works # on Hopper). get_cutlass_(pplx_)moe_mm_data should only be compiled # if it's possible to compile MoE kernels that use its output. - cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;10.0a" "${CUDA_ARCHS}") + cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a" "${CUDA_ARCHS}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS) - set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu" - "csrc/quantization/cutlass_w8a8/moe/moe_data.cu") + set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" CUDA_ARCHS "${SCALED_MM_ARCHS}") @@ -563,6 +593,46 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "if you intend on running FP8 quantized MoE models on Hopper.") else() message(STATUS "Not building grouped_mm_c3x as no compatible archs found " + "in CUDA target architectures.") + endif() + endif() + + # moe_data.cu is used by all CUTLASS MoE kernels. + cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0a" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS) + set(SRCS "csrc/quantization/cutlass_w8a8/moe/moe_data.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${CUTLASS_MOE_DATA_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + message(STATUS "Building moe_data for archs: ${CUTLASS_MOE_DATA_ARCHS}") + else() + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS) + message(STATUS "Not building moe_data as CUDA Compiler version is " + "not >= 12.3, we recommend upgrading to CUDA 12.3 or later " + "if you intend on running FP8 quantized MoE models on Hopper or Blackwell.") + else() + message(STATUS "Not building moe_data as no compatible archs found " + "in CUDA target architectures.") + endif() + endif() + + cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) + set(SRCS "csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${SCALED_MM_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1") + message(STATUS "Building blockwise_scaled_group_mm_sm100 for archs: ${SCALED_MM_ARCHS}") + else() + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) + message(STATUS "Not building blockwise_scaled_group_mm_sm100 kernels as CUDA Compiler version is " + "not >= 12.8, we recommend upgrading to CUDA 12.8 or later " + "if you intend on running FP8 quantized MoE models on Blackwell.") + else() + message(STATUS "Not building blockwise_scaled_group_mm_sm100 as no compatible archs found " "in CUDA target architectures") endif() endif() @@ -639,6 +709,14 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # if CUDA endif endif() +if (VLLM_GPU_LANG STREQUAL "HIP") + # Add QuickReduce kernels + list(APPEND VLLM_EXT_SRC + "csrc/custom_quickreduce.cu" + ) +# if ROCM endif +endif() + message(STATUS "Enabling C extension.") define_gpu_extension_target( _C diff --git a/README.md b/README.md index ec16d758327d4ecde377c9e96f9fdc49fb4da705..3e6ae2acab2a97f741181ce7e5990633168fccef 100644 --- a/README.md +++ b/README.md @@ -154,11 +154,13 @@ If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs ## Contact Us + - For technical questions and feature requests, please use GitHub [Issues](https://github.com/vllm-project/vllm/issues) or [Discussions](https://github.com/vllm-project/vllm/discussions) - For discussing with fellow users, please use the [vLLM Forum](https://discuss.vllm.ai) -- coordinating contributions and development, please use [Slack](https://slack.vllm.ai) +- For coordinating contributions and development, please use [Slack](https://slack.vllm.ai) - For security disclosures, please use GitHub's [Security Advisories](https://github.com/vllm-project/vllm/security/advisories) feature - For collaborations and partnerships, please contact us at [vllm-questions@lists.berkeley.edu](mailto:vllm-questions@lists.berkeley.edu) + ## Media Kit diff --git a/benchmarks/README.md b/benchmarks/README.md index 6f9fbb91cbd9110a36c7a6708c2a8d6cf3a50a35..fb8690d42db98393c118947d7facce057f3cd48f 100644 --- a/benchmarks/README.md +++ b/benchmarks/README.md @@ -4,7 +4,7 @@ This README guides you through running benchmark tests with the extensive datasets supported on vLLM. It’s a living document, updated as new features and datasets become available. -## Dataset Overview +**Dataset Overview** @@ -82,7 +82,10 @@ become available. **Note**: HuggingFace dataset's `dataset-name` should be set to `hf` --- -## Example - Online Benchmark +
+🚀 Example - Online Benchmark + +
First start serving your model @@ -130,7 +133,8 @@ P99 ITL (ms): 8.39 ================================================== ``` -### Custom Dataset +**Custom Dataset** + If the dataset you want to benchmark is not supported yet in vLLM, even then you can benchmark on it using `CustomDataset`. Your data needs to be in `.jsonl` format and needs to have "prompt" field per entry, e.g., data.jsonl ``` @@ -162,7 +166,7 @@ python3 benchmarks/benchmark_serving.py --port 9001 --save-result --save-detaile You can skip applying chat template if your data already has it by using `--custom-skip-chat-template`. -### VisionArena Benchmark for Vision Language Models +**VisionArena Benchmark for Vision Language Models** ```bash # need a model with vision capability here @@ -180,7 +184,7 @@ python3 vllm/benchmarks/benchmark_serving.py \ --num-prompts 1000 ``` -### InstructCoder Benchmark with Speculative Decoding +**InstructCoder Benchmark with Speculative Decoding** ``` bash VLLM_USE_V1=1 vllm serve meta-llama/Meta-Llama-3-8B-Instruct \ @@ -197,7 +201,7 @@ python3 benchmarks/benchmark_serving.py \ --num-prompts 2048 ``` -### Other HuggingFaceDataset Examples +**Other HuggingFaceDataset Examples** ```bash vllm serve Qwen/Qwen2-VL-7B-Instruct --disable-log-requests @@ -251,7 +255,7 @@ python3 vllm/benchmarks/benchmark_serving.py \ --num-prompts 80 ``` -### Running With Sampling Parameters +**Running With Sampling Parameters** When using OpenAI-compatible backends such as `vllm`, optional sampling parameters can be specified. Example client command: @@ -269,8 +273,27 @@ python3 vllm/benchmarks/benchmark_serving.py \ --num-prompts 10 ``` ---- -## Example - Offline Throughput Benchmark +**Running With Ramp-Up Request Rate** + +The benchmark tool also supports ramping up the request rate over the +duration of the benchmark run. This can be useful for stress testing the +server or finding the maximum throughput that it can handle, given some latency budget. + +Two ramp-up strategies are supported: +- `linear`: Increases the request rate linearly from a start value to an end value. +- `exponential`: Increases the request rate exponentially. + +The following arguments can be used to control the ramp-up: +- `--ramp-up-strategy`: The ramp-up strategy to use (`linear` or `exponential`). +- `--ramp-up-start-rps`: The request rate at the beginning of the benchmark. +- `--ramp-up-end-rps`: The request rate at the end of the benchmark. + +
+ +
+📈 Example - Offline Throughput Benchmark + +
```bash python3 vllm/benchmarks/benchmark_throughput.py \ @@ -288,7 +311,7 @@ Total num prompt tokens: 5014 Total num output tokens: 1500 ``` -### VisionArena Benchmark for Vision Language Models +**VisionArena Benchmark for Vision Language Models** ``` bash python3 vllm/benchmarks/benchmark_throughput.py \ @@ -308,7 +331,7 @@ Total num prompt tokens: 14527 Total num output tokens: 1280 ``` -### InstructCoder Benchmark with Speculative Decoding +**InstructCoder Benchmark with Speculative Decoding** ``` bash VLLM_WORKER_MULTIPROC_METHOD=spawn \ @@ -332,7 +355,7 @@ Total num prompt tokens: 261136 Total num output tokens: 204800 ``` -### Other HuggingFaceDataset Examples +**Other HuggingFaceDataset Examples** **`lmms-lab/LLaVA-OneVision-Data`** @@ -371,7 +394,7 @@ python3 benchmarks/benchmark_throughput.py \ --num-prompts 10 ``` -### Benchmark with LoRA Adapters +**Benchmark with LoRA Adapters** ``` bash # download dataset @@ -387,3 +410,196 @@ python3 vllm/benchmarks/benchmark_throughput.py \ --enable-lora \ --lora-path yard1/llama-2-7b-sql-lora-test ``` + +
+ +
+🛠️ Example - Structured Output Benchmark + +
+ +Benchmark the performance of structured output generation (JSON, grammar, regex). + +**Server Setup** + +```bash +vllm serve NousResearch/Hermes-3-Llama-3.1-8B --disable-log-requests +``` + +**JSON Schema Benchmark** + +```bash +python3 benchmarks/benchmark_serving_structured_output.py \ + --backend vllm \ + --model NousResearch/Hermes-3-Llama-3.1-8B \ + --dataset json \ + --structured-output-ratio 1.0 \ + --request-rate 10 \ + --num-prompts 1000 +``` + +**Grammar-based Generation Benchmark** + +```bash +python3 benchmarks/benchmark_serving_structured_output.py \ + --backend vllm \ + --model NousResearch/Hermes-3-Llama-3.1-8B \ + --dataset grammar \ + --structure-type grammar \ + --request-rate 10 \ + --num-prompts 1000 +``` + +**Regex-based Generation Benchmark** + +```bash +python3 benchmarks/benchmark_serving_structured_output.py \ + --backend vllm \ + --model NousResearch/Hermes-3-Llama-3.1-8B \ + --dataset regex \ + --request-rate 10 \ + --num-prompts 1000 +``` + +**Choice-based Generation Benchmark** + +```bash +python3 benchmarks/benchmark_serving_structured_output.py \ + --backend vllm \ + --model NousResearch/Hermes-3-Llama-3.1-8B \ + --dataset choice \ + --request-rate 10 \ + --num-prompts 1000 +``` + +**XGrammar Benchmark Dataset** + +```bash +python3 benchmarks/benchmark_serving_structured_output.py \ + --backend vllm \ + --model NousResearch/Hermes-3-Llama-3.1-8B \ + --dataset xgrammar_bench \ + --request-rate 10 \ + --num-prompts 1000 +``` + +
+ +
+📚 Example - Long Document QA Benchmark + +
+ +Benchmark the performance of long document question-answering with prefix caching. + +**Basic Long Document QA Test** + +```bash +python3 benchmarks/benchmark_long_document_qa_throughput.py \ + --model meta-llama/Llama-2-7b-chat-hf \ + --enable-prefix-caching \ + --num-documents 16 \ + --document-length 2000 \ + --output-len 50 \ + --repeat-count 5 +``` + +**Different Repeat Modes** + +```bash +# Random mode (default) - shuffle prompts randomly +python3 benchmarks/benchmark_long_document_qa_throughput.py \ + --model meta-llama/Llama-2-7b-chat-hf \ + --enable-prefix-caching \ + --num-documents 8 \ + --document-length 3000 \ + --repeat-count 3 \ + --repeat-mode random + +# Tile mode - repeat entire prompt list in sequence +python3 benchmarks/benchmark_long_document_qa_throughput.py \ + --model meta-llama/Llama-2-7b-chat-hf \ + --enable-prefix-caching \ + --num-documents 8 \ + --document-length 3000 \ + --repeat-count 3 \ + --repeat-mode tile + +# Interleave mode - repeat each prompt consecutively +python3 benchmarks/benchmark_long_document_qa_throughput.py \ + --model meta-llama/Llama-2-7b-chat-hf \ + --enable-prefix-caching \ + --num-documents 8 \ + --document-length 3000 \ + --repeat-count 3 \ + --repeat-mode interleave +``` + +
+ +
+🗂️ Example - Prefix Caching Benchmark + +
+ +Benchmark the efficiency of automatic prefix caching. + +**Fixed Prompt with Prefix Caching** + +```bash +python3 benchmarks/benchmark_prefix_caching.py \ + --model meta-llama/Llama-2-7b-chat-hf \ + --enable-prefix-caching \ + --num-prompts 1 \ + --repeat-count 100 \ + --input-length-range 128:256 +``` + +**ShareGPT Dataset with Prefix Caching** + +```bash +# download dataset +# wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json + +python3 benchmarks/benchmark_prefix_caching.py \ + --model meta-llama/Llama-2-7b-chat-hf \ + --dataset-path /path/ShareGPT_V3_unfiltered_cleaned_split.json \ + --enable-prefix-caching \ + --num-prompts 20 \ + --repeat-count 5 \ + --input-length-range 128:256 +``` + +
+ +
+⚡ Example - Request Prioritization Benchmark + +
+ +Benchmark the performance of request prioritization in vLLM. + +**Basic Prioritization Test** + +```bash +python3 benchmarks/benchmark_prioritization.py \ + --model meta-llama/Llama-2-7b-chat-hf \ + --input-len 128 \ + --output-len 64 \ + --num-prompts 100 \ + --scheduling-policy priority +``` + +**Multiple Sequences per Prompt** + +```bash +python3 benchmarks/benchmark_prioritization.py \ + --model meta-llama/Llama-2-7b-chat-hf \ + --input-len 128 \ + --output-len 64 \ + --num-prompts 100 \ + --scheduling-policy priority \ + --n 2 +``` + +
diff --git a/benchmarks/auto_tune.sh b/benchmarks/auto_tune.sh index 1b01bbd61b628f0ed18041b8945d2e5f086951a4..b257b57ce06f5038338bc8a8f114fcb13c0e203b 100644 --- a/benchmarks/auto_tune.sh +++ b/benchmarks/auto_tune.sh @@ -10,6 +10,7 @@ # 3. Set variables (ALL REQUIRED) # BASE: your directory for vllm repo # MODEL: the model served by vllm +# SYSTEM: the hardware, choice TPU or GPU, for other systems, "get best profile" might not support. # TP: ways of tensor parallelism # DOWNLOAD_DIR: directory to download and load model weights. # INPUT_LEN: request input len @@ -34,6 +35,7 @@ TAG=$(date +"%Y_%m_%d_%H_%M") BASE="" MODEL="meta-llama/Llama-3.1-8B-Instruct" +SYSTEM="TPU" TP=1 DOWNLOAD_DIR="" INPUT_LEN=4000 @@ -45,12 +47,15 @@ NUM_BATCHED_TOKENS_LIST="512 1024 2048 4096" LOG_FOLDER="$BASE/auto-benchmark/$TAG" RESULT="$LOG_FOLDER/result.txt" +PROFILE_PATH="$LOG_FOLDER/profile" echo "result file: $RESULT" echo "model: $MODEL" rm -rf $LOG_FOLDER +rm -rf $PROFILE_PATH mkdir -p $LOG_FOLDER +mkdir -p $PROFILE_PATH cd "$BASE/vllm" @@ -70,10 +75,11 @@ start_server() { local max_num_seqs=$2 local max_num_batched_tokens=$3 local vllm_log=$4 + local profile_dir=$5 pkill -f vllm - VLLM_USE_V1=1 VLLM_SERVER_DEV_MODE=1 vllm serve $MODEL \ + VLLM_USE_V1=1 VLLM_SERVER_DEV_MODE=1 VLLM_TORCH_PROFILER_DIR=$profile_dir vllm serve $MODEL \ --disable-log-requests \ --port 8004 \ --gpu-memory-utilization $gpu_memory_utilization \ @@ -105,19 +111,37 @@ start_server() { fi } +update_best_profile() { + local profile_dir=$1 + local profile_index=$2 + sorted_paths=($(find "$profile_dir" -maxdepth 1 -not -path "$profile_dir" | sort)) + selected_profile_file= + if [[ "$SYSTEM" == "TPU" ]]; then + selected_profile_file="${sorted_paths[$profile_index]}/*.xplane.pb" + fi + if [[ "$SYSTEM" == "GPU" ]]; then + selected_profile_file="${sorted_paths[$profile_index]}" + fi + rm -f $PROFILE_PATH/* + cp $selected_profile_file $PROFILE_PATH +} + run_benchmark() { local max_num_seqs=$1 local max_num_batched_tokens=$2 local gpu_memory_utilization=$3 echo "max_num_seq: $max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens" local vllm_log="$LOG_FOLDER/vllm_log_${max_num_seqs}_${max_num_batched_tokens}.txt" + local profile_dir="$LOG_FOLDER/profile_${max_num_seqs}_${max_num_batched_tokens}" echo "vllm_log: $vllm_log" echo rm -f $vllm_log + mkdir -p $profile_dir pkill -f vllm + local profile_index=0 echo "starting server..." - start_server $gpu_memory_utilization $max_num_seqs $max_num_batched_tokens $vllm_log + start_server $gpu_memory_utilization $max_num_seqs $max_num_batched_tokens $vllm_log $profile_dir result=$? if [[ "$result" -eq 1 ]]; then echo "server failed to start. gpu_memory_utilization:$gpu_memory_utilization, max_num_seqs:$max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens" @@ -144,7 +168,8 @@ run_benchmark() { --goodput e2el:$MAX_LATENCY_ALLOWED_MS \ --num-prompts 1000 \ --random-prefix-len $prefix_len \ - --port 8004 &> "$bm_log" + --port 8004 \ + --profile &> "$bm_log" throughput=$(grep "Request throughput (req/s):" "$bm_log" | sed 's/[^0-9.]//g') e2el=$(grep "P99 E2EL (ms):" "$bm_log" | awk '{print $NF}') goodput=$(grep "Request goodput (req/s):" "$bm_log" | sed 's/[^0-9.]//g') @@ -158,6 +183,7 @@ run_benchmark() { # start from request-rate as int(throughput) + 1 request_rate=$((${throughput%.*} + 1)) while ((request_rate > 0)); do + profile_index=$((profile_index+1)) # clear prefix cache curl -X POST http://0.0.0.0:8004/reset_prefix_cache sleep 5 @@ -195,6 +221,12 @@ run_benchmark() { best_max_num_seqs=$max_num_seqs best_num_batched_tokens=$max_num_batched_tokens best_goodput=$goodput + if [[ "$SYSTEM" == "TPU" ]]; then + update_best_profile "$profile_dir/plugins/profile" $profile_index + fi + if [[ "$SYSTEM" == "GPU" ]]; then + update_best_profile "$profile_dir" $profile_index + fi fi else echo "max_num_seqs: $max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens does not meet latency requirement ${MAX_LATENCY_ALLOWED_MS}" @@ -239,6 +271,6 @@ for num_seqs in "${num_seqs_list[@]}"; do done done echo "finish permutations" -echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput" -echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput" >> "$RESULT" +echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput, profile saved in: $PROFILE_PATH" +echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput, profile saved in: $PROFILE_PATH" >> "$RESULT" diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index ddb38e304cd6565f5b5369c9555633fec7dfa373..c7229dbb8e90d1167eab8fda1a5fb353e31e81a4 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -404,8 +404,14 @@ async def async_request_openai_chat_completions( chunk_bytes = chunk_bytes.strip() if not chunk_bytes: continue + chunk_bytes = chunk_bytes.decode("utf-8") + # NOTE: SSE comments (often used as pings) start with a colon. + # These are not JSON data payload and should be skipped. + if chunk_bytes.startswith(":"): + continue + + chunk = chunk_bytes.removeprefix("data: ") - chunk = chunk_bytes.decode("utf-8").removeprefix("data: ") if chunk != "[DONE]": timestamp = time.perf_counter() data = json.loads(chunk) diff --git a/benchmarks/benchmark_dataset.py b/benchmarks/benchmark_dataset.py index 5d2a26cd443c0060eb5ed80ed44565812caa96c6..55c0cf851264f37d271f7a8cd1fc5ad17451f386 100644 --- a/benchmarks/benchmark_dataset.py +++ b/benchmarks/benchmark_dataset.py @@ -349,11 +349,12 @@ class RandomDataset(BenchmarkDataset): # [1650, 939, 486] -> ['Ġcall', 'sh', 'ere'] # To avoid uncontrolled change of the prompt length, # the encoded sequence is truncated before being decode again. + total_input_len = prefix_len + int(input_lens[i]) re_encoded_sequence = tokenizer.encode(prompt, add_special_tokens=False)[ - : input_lens[i] + :total_input_len ] prompt = tokenizer.decode(re_encoded_sequence) - total_input_len = prefix_len + int(input_lens[i]) + total_input_len = len(re_encoded_sequence) requests.append( SampleRequest( prompt=prompt, diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index c06857247eeed9bbb61c437ff144d3bd9edba4b1..4d2ea126b24a514bde2e7e3808d2120ede353cc7 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -123,7 +123,7 @@ def main(args: argparse.Namespace): save_to_pytorch_benchmark_format(args, results) -if __name__ == "__main__": +def create_argument_parser(): parser = FlexibleArgumentParser( description="Benchmark the latency of processing a single batch of " "requests till completion." @@ -171,6 +171,12 @@ if __name__ == "__main__": # V1 enables prefix caching by default which skews the latency # numbers. We need to disable prefix caching by default. parser.set_defaults(enable_prefix_caching=False) + + return parser + + +if __name__ == "__main__": + parser = create_argument_parser() args = parser.parse_args() if args.profile and not envs.VLLM_TORCH_PROFILER_DIR: raise OSError( diff --git a/benchmarks/benchmark_long_document_qa_throughput.py b/benchmarks/benchmark_long_document_qa_throughput.py index 00869fa94e71a78c8ec26f8d35d1ff67871584ee..6e0f3b51c9d284d22dd1acdc1b794b9a5c0f31b1 100644 --- a/benchmarks/benchmark_long_document_qa_throughput.py +++ b/benchmarks/benchmark_long_document_qa_throughput.py @@ -142,7 +142,7 @@ def main(args): ) -if __name__ == "__main__": +def create_argument_parser(): parser = FlexibleArgumentParser( description="Benchmark the performance with or " "without automatic prefix caching." @@ -192,5 +192,11 @@ if __name__ == "__main__": ) parser = EngineArgs.add_cli_args(parser) + + return parser + + +if __name__ == "__main__": + parser = create_argument_parser() args = parser.parse_args() main(args) diff --git a/benchmarks/benchmark_prefix_caching.py b/benchmarks/benchmark_prefix_caching.py index 3e4704f0b8205870914a0c1d7ffe100ece91e6ee..b5e2613de1cd4a0a2557cf09baa458aa15a6049c 100644 --- a/benchmarks/benchmark_prefix_caching.py +++ b/benchmarks/benchmark_prefix_caching.py @@ -218,7 +218,7 @@ def main(args): ) -if __name__ == "__main__": +def create_argument_parser(): parser = FlexibleArgumentParser( description="Benchmark the performance with or without " "automatic prefix caching." @@ -268,5 +268,11 @@ if __name__ == "__main__": ) parser = EngineArgs.add_cli_args(parser) + + return parser + + +if __name__ == "__main__": + parser = create_argument_parser() args = parser.parse_args() main(args) diff --git a/benchmarks/benchmark_prioritization.py b/benchmarks/benchmark_prioritization.py index 5496703f23ccbe368f4b8dbc20da5ad308109996..bb453791c1862bdd0883f0a47af6fddeef764f8f 100644 --- a/benchmarks/benchmark_prioritization.py +++ b/benchmarks/benchmark_prioritization.py @@ -161,7 +161,7 @@ def main(args: argparse.Namespace): json.dump(results, f, indent=4) -if __name__ == "__main__": +def create_argument_parser(): parser = FlexibleArgumentParser(description="Benchmark the throughput.") parser.add_argument( "--backend", type=str, choices=["vllm", "hf", "mii"], default="vllm" @@ -204,6 +204,12 @@ if __name__ == "__main__": ) parser = EngineArgs.add_cli_args(parser) + + return parser + + +if __name__ == "__main__": + parser = create_argument_parser() args = parser.parse_args() if args.tokenizer is None: args.tokenizer = args.model diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 81428fb7dae12a9d4a1f6a0755f9bff25e28a585..9b235266dff1aac95d6fe0509f4a5f01641579cc 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -33,7 +33,7 @@ import warnings from collections.abc import AsyncGenerator, Iterable from dataclasses import dataclass from datetime import datetime -from typing import Any, Optional +from typing import Any, Literal, Optional import numpy as np from tqdm.asyncio import tqdm @@ -107,14 +107,42 @@ class BenchmarkMetrics: percentiles_e2el_ms: list[tuple[float, float]] +def _get_current_request_rate( + ramp_up_strategy: Optional[Literal["linear", "exponential"]], + ramp_up_start_rps: Optional[int], + ramp_up_end_rps: Optional[int], + request_index: int, + total_requests: int, + request_rate: float, +) -> float: + if ( + ramp_up_strategy + and ramp_up_start_rps is not None + and ramp_up_end_rps is not None + ): + progress = request_index / max(total_requests - 1, 1) + if ramp_up_strategy == "linear": + increase = (ramp_up_end_rps - ramp_up_start_rps) * progress + return ramp_up_start_rps + increase + elif ramp_up_strategy == "exponential": + ratio = ramp_up_end_rps / ramp_up_start_rps + return ramp_up_start_rps * (ratio**progress) + else: + raise ValueError(f"Unknown ramp-up strategy: {ramp_up_strategy}") + return request_rate + + async def get_request( input_requests: list[SampleRequest], request_rate: float, burstiness: float = 1.0, -) -> AsyncGenerator[SampleRequest, None]: + ramp_up_strategy: Optional[Literal["linear", "exponential"]] = None, + ramp_up_start_rps: Optional[int] = None, + ramp_up_end_rps: Optional[int] = None, +) -> AsyncGenerator[tuple[SampleRequest, float], None]: """ Asynchronously generates requests at a specified rate - with OPTIONAL burstiness. + with OPTIONAL burstiness and OPTIONAL ramp-up strategy. Args: input_requests: @@ -129,22 +157,44 @@ async def get_request( A lower burstiness value (0 < burstiness < 1) results in more bursty requests, while a higher burstiness value (burstiness > 1) results in a more uniform arrival of requests. + ramp_up_strategy (optional): + The ramp-up strategy. Can be "linear" or "exponential". + If None, uses constant request rate (specified by request_rate). + ramp_up_start_rps (optional): + The starting request rate for ramp-up. + ramp_up_end_rps (optional): + The ending request rate for ramp-up. """ - input_requests: Iterable[SampleRequest] = iter(input_requests) - - # Calculate scale parameter theta to maintain the desired request_rate. assert burstiness > 0, ( f"A positive burstiness factor is expected, but given {burstiness}." ) - theta = 1.0 / (request_rate * burstiness) + # Convert to list to get length for ramp-up calculations + if isinstance(input_requests, Iterable) and not isinstance(input_requests, list): + input_requests = list(input_requests) + + total_requests = len(input_requests) + request_index = 0 for request in input_requests: - yield request + current_request_rate = _get_current_request_rate( + ramp_up_strategy, + ramp_up_start_rps, + ramp_up_end_rps, + request_index, + total_requests, + request_rate, + ) + + yield request, current_request_rate - if request_rate == float("inf"): + request_index += 1 + + if current_request_rate == float("inf"): # If the request rate is infinity, then we don't need to wait. continue + theta = 1.0 / (current_request_rate * burstiness) + # Sample the request interval from the gamma distribution. # If burstiness is 1, it follows exponential distribution. interval = np.random.gamma(shape=burstiness, scale=theta) @@ -290,6 +340,9 @@ async def benchmark( max_concurrency: Optional[int], lora_modules: Optional[Iterable[str]], extra_body: Optional[dict], + ramp_up_strategy: Optional[Literal["linear", "exponential"]] = None, + ramp_up_start_rps: Optional[int] = None, + ramp_up_end_rps: Optional[int] = None, ): if backend in ASYNC_REQUEST_FUNCS: request_func = ASYNC_REQUEST_FUNCS[backend] @@ -353,7 +406,15 @@ async def benchmark( distribution = "Poisson process" if burstiness == 1.0 else "Gamma distribution" - print(f"Traffic request rate: {request_rate}") + if ramp_up_strategy is not None: + print( + f"Traffic ramp-up strategy: {ramp_up_strategy}. Will increase " + f"RPS from {ramp_up_start_rps} to {ramp_up_end_rps} RPS over " + "the duration of the benchmark." + ) + else: + print(f"Traffic request rate: {request_rate} RPS.") + print(f"Burstiness factor: {burstiness} ({distribution})") print(f"Maximum request concurrency: {max_concurrency}") @@ -373,7 +434,34 @@ async def benchmark( benchmark_start_time = time.perf_counter() tasks: list[asyncio.Task] = [] - async for request in get_request(input_requests, request_rate, burstiness): + + rps_change_events = [] + last_int_rps = -1 + if ramp_up_strategy is not None and ramp_up_start_rps is not None: + last_int_rps = ramp_up_start_rps + rps_change_events.append( + { + "rps": last_int_rps, + "timestamp": datetime.now().isoformat(), + } + ) + + async for request, current_request_rate in get_request( + input_requests, + request_rate, + burstiness, + ramp_up_strategy, + ramp_up_start_rps, + ramp_up_end_rps, + ): + if ramp_up_strategy is not None: + current_int_rps = int(current_request_rate) + if current_int_rps > last_int_rps: + timestamp = datetime.now().isoformat() + for rps_val in range(last_int_rps + 1, current_int_rps + 1): + rps_change_events.append({"rps": rps_val, "timestamp": timestamp}) + last_int_rps = current_int_rps + prompt, prompt_len, output_len, mm_content = ( request.prompt, request.prompt_len, @@ -397,11 +485,8 @@ async def benchmark( ignore_eos=ignore_eos, extra_body=extra_body, ) - tasks.append( - asyncio.create_task( - limited_request_func(request_func_input=request_func_input, pbar=pbar) - ) - ) + task = limited_request_func(request_func_input=request_func_input, pbar=pbar) + tasks.append(asyncio.create_task(task)) outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks) if profile: @@ -466,7 +551,7 @@ async def benchmark( "total_input_tokens": metrics.total_input, "total_output_tokens": metrics.total_output, "request_throughput": metrics.request_throughput, - "request_goodput:": metrics.request_goodput if goodput_config_dict else None, + "request_goodput": metrics.request_goodput if goodput_config_dict else None, "output_throughput": metrics.output_throughput, "total_token_throughput": metrics.total_token_throughput, "input_lens": [output.prompt_len for output in outputs], @@ -477,6 +562,9 @@ async def benchmark( "errors": [output.error for output in outputs], } + if rps_change_events: + result["rps_change_events"] = rps_change_events + def process_one_metric( # E.g., "ttft" metric_attribute_name: str, @@ -610,6 +698,26 @@ def main(args: argparse.Namespace): tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model tokenizer_mode = args.tokenizer_mode + # Validate ramp-up arguments + if args.ramp_up_strategy is not None: + if args.request_rate != float("inf"): + raise ValueError( + "When using ramp-up, do not specify --request-rate. " + "The request rate will be controlled by ramp-up parameters. " + "Please remove the --request-rate argument." + ) + if args.ramp_up_start_rps is None or args.ramp_up_end_rps is None: + raise ValueError( + "When using --ramp-up-strategy, both --ramp-up-start-rps and " + "--ramp-up-end-rps must be specified" + ) + if args.ramp_up_start_rps < 0 or args.ramp_up_end_rps < 0: + raise ValueError("Ramp-up start and end RPS must be non-negative") + if args.ramp_up_start_rps > args.ramp_up_end_rps: + raise ValueError("Ramp-up start RPS must be less than end RPS") + if args.ramp_up_strategy == "exponential" and args.ramp_up_start_rps == 0: + raise ValueError("For exponential ramp-up, the start RPS cannot be 0.") + if args.base_url is not None: api_url = f"{args.base_url}{args.endpoint}" base_url = f"{args.base_url}" @@ -802,6 +910,9 @@ def main(args: argparse.Namespace): max_concurrency=args.max_concurrency, lora_modules=args.lora_modules, extra_body=sampling_params, + ramp_up_strategy=args.ramp_up_strategy, + ramp_up_start_rps=args.ramp_up_start_rps, + ramp_up_end_rps=args.ramp_up_end_rps, ) ) @@ -834,6 +945,11 @@ def main(args: argparse.Namespace): result_json["burstiness"] = args.burstiness result_json["max_concurrency"] = args.max_concurrency + if args.ramp_up_strategy is not None: + result_json["ramp_up_strategy"] = args.ramp_up_strategy + result_json["ramp_up_start_rps"] = args.ramp_up_start_rps + result_json["ramp_up_end_rps"] = args.ramp_up_end_rps + # Merge with benchmark result result_json = {**result_json, **benchmark_result} @@ -859,7 +975,10 @@ def main(args: argparse.Namespace): if args.max_concurrency is not None else "" ) - file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa + if args.ramp_up_strategy is not None: + file_name = f"{backend}-ramp-up-{args.ramp_up_strategy}-{args.ramp_up_start_rps}qps-{args.ramp_up_end_rps}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa + else: + file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa if args.result_filename: file_name = args.result_filename if args.result_dir: @@ -875,7 +994,7 @@ def main(args: argparse.Namespace): save_to_pytorch_benchmark_format(args, result_json, file_name) -if __name__ == "__main__": +def create_argument_parser(): parser = FlexibleArgumentParser( description="Benchmark the online serving throughput." ) @@ -1225,6 +1344,35 @@ if __name__ == "__main__": "script chooses a LoRA module at random.", ) - args = parser.parse_args() + parser.add_argument( + "--ramp-up-strategy", + type=str, + default=None, + choices=["linear", "exponential"], + help="The ramp-up strategy. This would be used to " + "ramp up the request rate from initial RPS to final " + "RPS rate (specified by --ramp-up-start-rps and --ramp-up-end-rps). " + "over the duration of the benchmark.", + ) + parser.add_argument( + "--ramp-up-start-rps", + type=int, + default=None, + help="The starting request rate for ramp-up (RPS). " + "Needs to be specified when --ramp-up-strategy is used.", + ) + parser.add_argument( + "--ramp-up-end-rps", + type=int, + default=None, + help="The ending request rate for ramp-up (RPS). " + "Needs to be specified when --ramp-up-strategy is used.", + ) + + return parser + +if __name__ == "__main__": + parser = create_argument_parser() + args = parser.parse_args() main(args) diff --git a/benchmarks/benchmark_serving_structured_output.py b/benchmarks/benchmark_serving_structured_output.py index c1501ad52c25af1f101785e4a2143e8a5e317489..e23a5a9e2233df685998fbae0ae8cc8910b572cf 100644 --- a/benchmarks/benchmark_serving_structured_output.py +++ b/benchmarks/benchmark_serving_structured_output.py @@ -850,7 +850,7 @@ def main(args: argparse.Namespace): json.dump(results, outfile, indent=4) -if __name__ == "__main__": +def create_argument_parser(): parser = FlexibleArgumentParser( description="Benchmark the online serving throughput." ) @@ -1034,5 +1034,10 @@ if __name__ == "__main__": help="Ratio of Structured Outputs requests", ) + return parser + + +if __name__ == "__main__": + parser = create_argument_parser() args = parser.parse_args() main(args) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index d19753d40e497d95c56ed1330a802d788b0d5ded..0ded34c70badd28d972b3c993ddb7881b7b94aff 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -97,7 +97,7 @@ def run_vllm( assert lora_requests is None, "BeamSearch API does not support LoRA" prompts = [request.prompt for request in requests] # output_len should be the same for all requests. - output_len = requests[0][2] + output_len = requests[0].expected_output_len for request in requests: assert request.expected_output_len == output_len start = time.perf_counter() @@ -595,7 +595,7 @@ def validate_args(args): ) -if __name__ == "__main__": +def create_argument_parser(): parser = FlexibleArgumentParser(description="Benchmark the throughput.") parser.add_argument( "--backend", @@ -717,6 +717,12 @@ if __name__ == "__main__": ) parser = AsyncEngineArgs.add_cli_args(parser) + + return parser + + +if __name__ == "__main__": + parser = create_argument_parser() args = parser.parse_args() if args.tokenizer is None: args.tokenizer = args.model diff --git a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py index cec422e8d597f1df353eb5a4836c8f88f1685c2a..a5a5b52f603977cc0c967e455724bc7c6a2bf585 100644 --- a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py +++ b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py @@ -19,7 +19,7 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.fp8_utils import ( w8a8_block_fp8_matmul, ) -from vllm.utils import FlexibleArgumentParser +from vllm.utils import FlexibleArgumentParser, cdiv DEFAULT_MODELS = list(WEIGHT_SHAPES.keys()) DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512] @@ -117,14 +117,9 @@ def bench_fp8( scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) - def ceil_div(x: int, y: int) -> int: - return (x + y - 1) // y - - block_scale_a = torch.rand( - (m, ceil_div(k, 128)), device="cuda", dtype=torch.float32 - ) + block_scale_a = torch.rand((m, cdiv(k, 128)), device="cuda", dtype=torch.float32) block_scale_b = torch.rand( - ceil_div(k, 128), ceil_div(n, 128), device="cuda", dtype=torch.float32 + cdiv(k, 128), cdiv(n, 128), device="cuda", dtype=torch.float32 ) block_scale_a_M_major = block_scale_a.t().contiguous().t() block_scale_b_K_major = block_scale_b.t().contiguous().t() diff --git a/benchmarks/kernels/bench_fp8_gemm.py b/benchmarks/kernels/bench_fp8_gemm.py index b964ed242edf8f16a8fd63ab27c098d11c9663f4..920961899038061c13125bdd37f9843b5e4d548a 100644 --- a/benchmarks/kernels/bench_fp8_gemm.py +++ b/benchmarks/kernels/bench_fp8_gemm.py @@ -11,6 +11,80 @@ from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant from vllm.triton_utils import triton +PROVIDER_CFGS = { + "torch-bf16": dict(enabled=True), + "fp8-tensor-w-token-a": dict( + w="tensor", a="token", no_a_quant=False, enabled=False + ), + "fp8-tensor-w-tensor-a": dict( + w="tensor", a="tensor", no_a_quant=False, enabled=True + ), + "fp8-channel-w-token-a": dict( + w="channel", a="token", no_a_quant=False, enabled=True + ), + "fp8-channel-w-tensor-a": dict( + w="channel", a="tensor", no_a_quant=False, enabled=False + ), + "fp8-tensor-w-token-a-noquant": dict( + w="tensor", a="token", no_a_quant=True, enabled=False + ), + "fp8-tensor-w-tensor-a-noquant": dict( + w="tensor", a="tensor", no_a_quant=True, enabled=True + ), + "fp8-channel-w-token-a-noquant": dict( + w="channel", a="token", no_a_quant=True, enabled=True + ), + "fp8-channel-w-tensor-a-noquant": dict( + w="channel", a="tensor", no_a_quant=True, enabled=False + ), +} + +_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]] + + +def _quant_weight_fp8(b: torch.Tensor, w_type: str, device: str): + if w_type == "tensor": + scale_b = torch.ones(1, device=device, dtype=torch.float32) + b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b) + else: + b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, use_per_token_if_dynamic=True) + return b_fp8.t(), scale_b_fp8 + + +def build_fp8_runner(cfg, a, b, dtype, device): + b_fp8, scale_b_fp8 = _quant_weight_fp8(b, cfg["w"], device) + + scale_a_const = ( + torch.ones(1, device=device, dtype=torch.float32) + if cfg["a"] == "tensor" + else None + ) + + if cfg["no_a_quant"]: + if cfg["a"] == "tensor": + a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a_const) + else: + a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, use_per_token_if_dynamic=True) + + def run(): + return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype) + + return run + + if cfg["a"] == "tensor": + + def run(): + a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a_const) + return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype) + + else: + + def run(): + a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, use_per_token_if_dynamic=True) + return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype) + + return run + @triton.testing.perf_report( triton.testing.Benchmark( @@ -18,28 +92,8 @@ from vllm.triton_utils import triton x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384], x_log=False, line_arg="provider", - line_vals=[ - "torch-bf16", - # "fp8-tensor-w-token-a", - "fp8-tensor-w-tensor-a", - "fp8-channel-w-token-a", - # "fp8-channel-w-tensor-a", - # "fp8-tensor-w-token-a-noquant", - "fp8-tensor-w-tensor-a-noquant", - "fp8-channel-w-token-a-noquant", - # "fp8-channel-w-tensor-a-noquant", - ], - line_names=[ - "torch-bf16", - # "fp8-tensor-w-token-a", - "fp8-tensor-w-tensor-a", - "fp8-channel-w-token-a", - # "fp8-channel-w-tensor-a", - # "fp8-tensor-w-token-a-noquant", - "fp8-tensor-w-tensor-a-noquant", - "fp8-channel-w-token-a-noquant", - # "fp8-channel-w-tensor-a-noquant", - ], + line_vals=_enabled, + line_names=_enabled, ylabel="TFLOP/s (larger is better)", plot_name="BF16 vs FP8 GEMMs", args={}, @@ -50,144 +104,34 @@ def benchmark(batch_size, provider, N, K): device = "cuda" dtype = torch.bfloat16 - # Create input tensors a = torch.randn((M, K), device=device, dtype=dtype) b = torch.randn((N, K), device=device, dtype=dtype) quantiles = [0.5, 0.2, 0.8] - if "torch-bf16" in provider: + if provider == "torch-bf16": ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( lambda: torch.nn.functional.linear(a, b), quantiles=quantiles ) - - elif "fp8" in provider: - # Weights are always quantized ahead of time - if "noquant" in provider: - # For no quantization, we just measure the GEMM - if "tensor-w-token-a" in provider: - # Dynamic per-token quant for A, per-tensor quant for B - b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b) - assert scale_b_fp8.numel() == 1 - a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant( - a, use_per_token_if_dynamic=True - ) - - def run_quant(): - return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype) - - elif "tensor-w-tensor-a" in provider: - # Static per-tensor quantization with fixed scales - # for both A and B - scale_a = torch.tensor([1.0], device=device, dtype=torch.float32) - scale_b = torch.tensor([1.0], device=device, dtype=torch.float32) - b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b) - assert scale_b_fp8.numel() == 1 - a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a) - - def run_quant(): - return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype) - - elif "channel-w-token-a" in provider: - # Static per-channel quantization for weights, per-token - # quant for A - scale_b = torch.tensor((N,), device=device, dtype=torch.float32) - b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b) - scale_b_fp8 = scale_b_fp8.expand(N).contiguous() - assert scale_b_fp8.numel() == N - a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant( - a, use_per_token_if_dynamic=True - ) - - def run_quant(): - return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype) - - elif "channel-w-tensor-a" in provider: - # Static per-channel quantization for weights, per-tensor - # quant for A - scale_a = torch.tensor([1.0], device=device, dtype=torch.float32) - scale_b = torch.tensor((N,), device=device, dtype=torch.float32) - b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b) - scale_b_fp8 = scale_b_fp8.expand(N).contiguous() - assert scale_b_fp8.numel() == N - a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a) - - def run_quant(): - return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype) - - else: - # In these cases, we quantize the activations during the GEMM call - if "tensor-w-token-a" in provider: - # Dynamic per-token quant for A, per-tensor quant for B - b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b) - assert scale_b_fp8.numel() == 1 - - def run_quant(): - a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant( - a, use_per_token_if_dynamic=True - ) - return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype) - - elif "tensor-w-tensor-a" in provider: - # Static per-tensor quantization with fixed scales - # for both A and B - scale_a = torch.tensor([1.0], device=device, dtype=torch.float32) - scale_b = torch.tensor([1.0], device=device, dtype=torch.float32) - b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b) - assert scale_b_fp8.numel() == 1 - - def run_quant(): - a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a) - return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype) - - elif "channel-w-token-a" in provider: - # Static per-channel quantization for weights, per-token - # quant for A - scale_b = torch.tensor((N,), device=device, dtype=torch.float32) - b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b) - scale_b_fp8 = scale_b_fp8.expand(N).contiguous() - assert scale_b_fp8.numel() == N - - def run_quant(): - a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant( - a, use_per_token_if_dynamic=True - ) - return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype) - - elif "channel-w-tensor-a" in provider: - # Static per-channel quantization for weights, per-tensor - # quant for A - scale_a = torch.tensor([1.0], device=device, dtype=torch.float32) - scale_b = torch.tensor((N,), device=device, dtype=torch.float32) - b_fp8, scale_b_fp8 = vllm_scaled_fp8_quant(b, scale_b) - scale_b_fp8 = scale_b_fp8.expand(N).contiguous() - assert scale_b_fp8.numel() == N - - def run_quant(): - a_fp8, scale_a_fp8 = vllm_scaled_fp8_quant(a, scale_a) - return vllm_scaled_mm(a_fp8, b_fp8, scale_a_fp8, scale_b_fp8, dtype) - - b_fp8 = b_fp8.t() - + else: + cfg = PROVIDER_CFGS[provider] + run_quant = build_fp8_runner(cfg, a, b, dtype, device) ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( lambda: run_quant(), quantiles=quantiles ) - # Calculate TFLOP/s, two flops per multiply-add - tflops = lambda ms: (2 * M * N * K) * 1e-12 / (ms * 1e-3) - return tflops(ms), tflops(max_ms), tflops(min_ms) + to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3) + return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms) def prepare_shapes(args): - KN_model_names = [] - models_tps = list(itertools.product(args.models, args.tp_sizes)) - for model, tp_size in models_tps: - assert model in WEIGHT_SHAPES - for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model]): - KN[tp_split_dim] = KN[tp_split_dim] // tp_size + out = [] + for model, tp_size in itertools.product(args.models, args.tp_sizes): + for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]): + KN[tp_dim] //= tp_size KN.append(model) - KN_model_names.append(KN) - return KN_model_names + out.append(KN) + return out if __name__ == "__main__": @@ -197,21 +141,13 @@ if __name__ == "__main__": nargs="+", type=str, default=["meta-llama/Llama-3.1-8B-Instruct"], - choices=[*WEIGHT_SHAPES.keys()], - help="List of models to benchmark", - ) - parser.add_argument( - "--tp-sizes", - nargs="+", - type=int, - default=[1], - help="List of tensor parallel sizes", + choices=list(WEIGHT_SHAPES.keys()), ) + parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1]) args = parser.parse_args() - KN_model_names = prepare_shapes(args) - for K, N, model_name in KN_model_names: - print(f"{model_name}, N={N} K={K}, BF16 vs FP8 GEMMs TFLOP/s:") + for K, N, model in prepare_shapes(args): + print(f"{model}, N={N} K={K}, BF16 vs FP8 GEMMs TFLOP/s:") benchmark.run( print_data=True, show_plots=True, diff --git a/benchmarks/kernels/bench_int8_gemm.py b/benchmarks/kernels/bench_int8_gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..e9c6d64404d0dc6958aeb675bf6b893623649ffa --- /dev/null +++ b/benchmarks/kernels/bench_int8_gemm.py @@ -0,0 +1,169 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import argparse +import copy +import itertools + +import torch +from weight_shapes import WEIGHT_SHAPES + +from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm +from vllm._custom_ops import scaled_int8_quant as vllm_scaled_int8_quant +from vllm.triton_utils import triton + +PROVIDER_CFGS = { + "torch-bf16": dict(enabled=True), + "int8-tensor-w-token-a": dict( + w="tensor", a="token", no_a_quant=False, enabled=False + ), + "int8-tensor-w-tensor-a": dict( + w="tensor", a="tensor", no_a_quant=False, enabled=True + ), + "int8-channel-w-token-a": dict( + w="channel", a="token", no_a_quant=False, enabled=True + ), + "int8-channel-w-tensor-a": dict( + w="channel", a="tensor", no_a_quant=False, enabled=False + ), + "int8-tensor-w-token-a-noquant": dict( + w="tensor", a="token", no_a_quant=True, enabled=False + ), + "int8-tensor-w-tensor-a-noquant": dict( + w="tensor", a="tensor", no_a_quant=True, enabled=True + ), + "int8-channel-w-token-a-noquant": dict( + w="channel", a="token", no_a_quant=True, enabled=True + ), + "int8-channel-w-tensor-a-noquant": dict( + w="channel", a="tensor", no_a_quant=True, enabled=False + ), +} + + +def _quant_weight(b, w_type, device): + if w_type == "tensor": + scale_b = torch.ones(1, device=device, dtype=torch.float32) + b_int8, scale_b_int8, _ = vllm_scaled_int8_quant(b, scale_b) + assert scale_b_int8.numel() == 1 + else: # channel + b_int8, scale_b_int8, _ = vllm_scaled_int8_quant(b) + assert scale_b_int8.numel() == b.shape[0] + return b_int8.t(), scale_b_int8 + + +def build_int8_runner(cfg, a, b, dtype, device): + # quant before running the kernel + b_int8, scale_b_int8 = _quant_weight(b, cfg["w"], device) + + scale_a_const = None + if cfg["a"] == "tensor": + scale_a_const = torch.ones(1, device=device, dtype=torch.float32) + + # no quant, create activation ahead + if cfg["no_a_quant"]: + if cfg["a"] == "tensor": + a_int8, scale_a_int8, _ = vllm_scaled_int8_quant(a, scale_a_const) + else: # token + a_int8, scale_a_int8, _ = vllm_scaled_int8_quant(a) + + def run_quant(): + return vllm_scaled_mm(a_int8, b_int8, scale_a_int8, scale_b_int8, dtype) + + return run_quant + + # dynamic quant, create activation inside + if cfg["a"] == "tensor": + + def run_quant(): + a_int8, scale_a_int8, _ = vllm_scaled_int8_quant(a, scale_a_const) + return vllm_scaled_mm(a_int8, b_int8, scale_a_int8, scale_b_int8, dtype) + + else: # token + + def run_quant(): + a_int8, scale_a_int8, _ = vllm_scaled_int8_quant(a) + return vllm_scaled_mm(a_int8, b_int8, scale_a_int8, scale_b_int8, dtype) + + return run_quant + + +_enabled = [k for k, v in PROVIDER_CFGS.items() if v.get("enabled")] + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["batch_size"], + x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384], + x_log=False, + line_arg="provider", + line_vals=_enabled, + line_names=[k for k in _enabled], + ylabel="TFLOP/s (larger is better)", + plot_name="BF16 vs INT8 GEMMs", + args={}, + ) +) +def benchmark(batch_size, provider, N, K): + M = batch_size + device = "cuda" + dtype = torch.bfloat16 + a = torch.randn((M, K), device=device, dtype=dtype) + b = torch.randn((N, K), device=device, dtype=dtype) + + quantiles = [0.5, 0.2, 0.8] + + if provider == "torch-bf16": + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + lambda: torch.nn.functional.linear(a, b), quantiles=quantiles + ) + else: + cfg = PROVIDER_CFGS[provider] + run_quant = build_int8_runner(cfg, a, b, dtype, device) + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph( + lambda: run_quant(), quantiles=quantiles + ) + + to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3) + return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms) + + +def prepare_shapes(args): + KN_model_names = [] + for model, tp_size in itertools.product(args.models, args.tp_sizes): + for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]): + KN[tp_dim] //= tp_size + KN.append(model) + KN_model_names.append(KN) + return KN_model_names + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--models", + nargs="+", + type=str, + default=["meta-llama/Llama-3.1-8B-Instruct"], + choices=list(WEIGHT_SHAPES.keys()), + help="List of models to benchmark", + ) + parser.add_argument( + "--tp-sizes", + nargs="+", + type=int, + default=[1], + help="List of tensor parallel sizes", + ) + args = parser.parse_args() + + for K, N, model in prepare_shapes(args): + print(f"{model}, N={N} K={K}, BF16 vs INT8 GEMMs TFLOP/s:") + benchmark.run( + print_data=True, + show_plots=True, + save_path=f"bench_int8_res_n{N}_k{K}", + N=N, + K=K, + ) + + print("Benchmark finished!") diff --git a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py index acabe6c1ddb0a18aac64cef4326fb14220fe4858..1d4e730f99ae911535caa4cca17122656db06a4b 100644 --- a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py +++ b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py @@ -113,6 +113,7 @@ def bench_run( w2_scale: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, + per_act_token: bool, num_repeats: int, ): for _ in range(num_repeats): @@ -124,7 +125,8 @@ def bench_run( topk_ids, w1_scale, w2_scale, - a1_scale=a_scale, + per_act_token, + a1_scale=None, ) def run_cutlass_from_graph( @@ -148,7 +150,8 @@ def bench_run( topk_ids, w1_scale, w2_scale, - a1_scale=a_scale, + per_act_token, + a1_scale=None, ) def run_triton_from_graph( @@ -227,6 +230,7 @@ def bench_run( "w2_q": w2_q, "w1_scale": w1_scale, "w2_scale": w2_scale, + "per_act_token": per_act_token, # cuda graph params "cutlass_graph": cutlass_graph, "triton_graph": triton_graph, @@ -287,12 +291,13 @@ def bench_run( w2_scale, topk_weights, topk_ids, + per_act_token, num_warmup, ) results.append( benchmark.Timer( - stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, num_runs)", # noqa: E501 + stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, diff --git a/benchmarks/kernels/benchmark_machete.py b/benchmarks/kernels/benchmark_machete.py index 0f896f187ecb9783b18036779d9ebf56a1783a30..f73d0511e01fc3851847278f8604a2baed8c8f33 100644 --- a/benchmarks/kernels/benchmark_machete.py +++ b/benchmarks/kernels/benchmark_machete.py @@ -234,8 +234,10 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable: fn = lambda: ops.gptq_marlin_gemm( a=bt.a, + c=None, b_q_weight=w_q, b_scales=w_s, + global_scale=None, b_zeros=w_zp, g_idx=g_idx, perm=sort_indices, diff --git a/benchmarks/kernels/benchmark_marlin.py b/benchmarks/kernels/benchmark_marlin.py index 9ea1fddae2a3b0b7ff595a70c0c61f9355f0ff41..34cc45e94d76d930013c9d60ca8c5f78dd0b7254 100644 --- a/benchmarks/kernels/benchmark_marlin.py +++ b/benchmarks/kernels/benchmark_marlin.py @@ -22,8 +22,16 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import ( MARLIN_SUPPORTED_GROUP_SIZES, query_marlin_supported_quant_types, ) +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( + FP4_MARLIN_SUPPORTED_GROUP_SIZES, + rand_marlin_weight_fp4_like, +) +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( + marlin_quant_fp8_torch, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( MarlinWorkspace, + awq_marlin_quantize, marlin_quantize, ) from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import ( @@ -35,7 +43,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( quantize_weights, sort_weights, ) -from vllm.scalar_type import ScalarType +from vllm.scalar_type import ScalarType, scalar_types from vllm.utils import FlexibleArgumentParser DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"] @@ -57,80 +65,144 @@ def bench_run( size_n: int, ): label = "Quant Matmul" - sub_label = "{}, act={} k_full={}, q={}, g={}, MKN=({}x{}x{})".format( model, act_order, is_k_full, str(quant_type), group_size, size_m, size_k, size_n ) - print(f"Testing: {sub_label}") a = torch.randn(size_m, size_k).to(torch.half).cuda() b = torch.rand(size_k, size_n).to(torch.half).cuda() + has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8] + if act_order and (group_size == -1 or group_size == size_k or has_zp): + return + if size_k % group_size != 0: + return - a_tmp = torch.zeros(size_m, size_k).to(torch.half).cuda() + marlin_24_supported = ( + quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES + and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES + ) + repack_supported = ( + quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES + and group_size in MARLIN_SUPPORTED_GROUP_SIZES + ) + allspark_supported = ( + quant_type in ALLSPARK_SUPPORTED_QUANT_TYPES + and group_size == -1 + and not act_order + and is_k_full + ) + + def gen_marlin_params(): + # Marlin quant + marlin_g_idx = marlin_sort_indices = marlin_zp = marlin_s2 = None + if quant_type == scalar_types.float4_e2m1f: + if group_size != 16 or act_order: + return + marlin_w_ref, marlin_q_w, marlin_s, marlin_s2 = rand_marlin_weight_fp4_like( + b.T, group_size + ) + elif quant_type == scalar_types.float8_e4m3fn: + if group_size not in [-1, 128] or act_order: + return + marlin_w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(b.T, group_size) + elif group_size == 16: + return + elif has_zp: + marlin_w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize( + b, quant_type, group_size + ) + else: + marlin_w_ref, marlin_q_w, marlin_s, marlin_g_idx, marlin_sort_indices, _ = ( + marlin_quantize(b, quant_type, group_size, act_order) + ) + return ( + marlin_w_ref, + marlin_q_w, + marlin_s, + marlin_s2, + marlin_zp, + marlin_g_idx, + marlin_sort_indices, + ) + + def gen_marlin_24_params(): + marlin_24_w_ref = marlin_24_q_w_comp = marlin_24_meta = marlin_24_s = None + if marlin_24_supported: + (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s) = ( + marlin_24_quantize(b, quant_type, group_size) + ) + return (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s) + + def gen_repack_params(): + q_w_gptq = None + repack_sort_indices = None + if repack_supported: + (w_ref, q_w, s, g_idx, rand_perm) = gptq_quantize_weights( + b, quant_type, group_size, act_order + ) + q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n) + + # For act_order, sort the "weights" and "g_idx" + # so that group ids are increasing + repack_sort_indices = torch.empty(0, dtype=torch.int, device=b.device) + if act_order: + (q_w, g_idx, repack_sort_indices) = sort_weights(q_w, g_idx) + return q_w_gptq, repack_sort_indices + + def gen_allspark_params(): + qw_reorder = s_reorder = zp_reorder = sm_count = sm_version = ( + CUBLAS_M_THRESHOLD + ) = None + nonlocal allspark_supported + if allspark_supported: + properties = torch.cuda.get_device_properties(b.device.index) + sm_count = properties.multi_processor_count + sm_version = properties.major * 10 + properties.minor + + supported_arch = sm_version >= 80 and sm_version < 90 + allspark_supported = allspark_supported and supported_arch + if supported_arch: + w_ref, qw, s, zp = quantize_weights(b, quant_type, group_size, has_zp) + qw = qw.to(torch.uint8) + + qw_reorder, s_reorder, zp_reorder = ops.allspark_repack_weight( + qw, s, zp, has_zp + ) + CUBLAS_M_THRESHOLD = ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD + return ( + qw_reorder, + s_reorder, + zp_reorder, + sm_count, + sm_version, + CUBLAS_M_THRESHOLD, + ) - # Marlin quant ( marlin_w_ref, marlin_q_w, marlin_s, + marlin_s2, + marlin_zp, marlin_g_idx, marlin_sort_indices, - marlin_rand_perm, - ) = marlin_quantize(b, quant_type, group_size, act_order) - - # Marlin_24 quant - (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s) = ( - marlin_24_quantize(b, quant_type, group_size) + ) = gen_marlin_params() + marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s = ( + gen_marlin_24_params() ) - - marlin_zp = torch.empty(0, dtype=torch.int, device=b.device) - - # GPTQ quant - (w_ref, q_w, s, g_idx, rand_perm) = gptq_quantize_weights( - b, quant_type, group_size, act_order + q_w_gptq, repack_sort_indices = gen_repack_params() + qw_reorder, s_reorder, zp_reorder, sm_count, sm_version, CUBLAS_M_THRESHOLD = ( + gen_allspark_params() ) - q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n) - - # For act_order, sort the "weights" and "g_idx" - # so that group ids are increasing - repack_sort_indices = torch.empty(0, dtype=torch.int, device=b.device) - if act_order: - (q_w, g_idx, repack_sort_indices) = sort_weights(q_w, g_idx) # Prepare marlin_workspace = MarlinWorkspace( size_n, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL ) - marlin_24_workspace = MarlinWorkspace( size_n, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL ) - marlin_zp = torch.zeros_like(marlin_s, dtype=torch.int) - - # AllSpark W8A16 quant - as_supported_case = ( - quant_type in ALLSPARK_SUPPORTED_QUANT_TYPES - and group_size == -1 - and not act_order - and is_k_full - ) - if as_supported_case: - properties = torch.cuda.get_device_properties(b.device.index) - sm_count = properties.multi_processor_count - sm_version = properties.major * 10 + properties.minor - - supported_arch = sm_version >= 80 and sm_version < 90 - as_supported_case = as_supported_case and supported_arch - if supported_arch: - has_zp = False - w_ref, qw, s, zp = quantize_weights(b, quant_type, group_size, has_zp) - qw = qw.to(torch.uint8) - - qw_reorder, s_reorder, zp_reorder = ops.allspark_repack_weight( - qw, s, zp, has_zp - ) - CUBLAS_M_THRESHOLD = ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD globals = { # Gen params @@ -140,15 +212,14 @@ def bench_run( "size_n": size_n, "size_k": size_k, "a": a, - "a_tmp": a_tmp, # Marlin params "marlin_w_ref": marlin_w_ref, "marlin_q_w": marlin_q_w, "marlin_s": marlin_s, + "marlin_s2": marlin_s2, "marlin_zp": marlin_zp, "marlin_g_idx": marlin_g_idx, "marlin_sort_indices": marlin_sort_indices, - "marlin_rand_perm": marlin_rand_perm, "marlin_workspace": marlin_workspace, "is_k_full": is_k_full, # Marlin_24 params @@ -161,12 +232,12 @@ def bench_run( "q_w_gptq": q_w_gptq, "repack_sort_indices": repack_sort_indices, # AllSpark W8A16 params - "qw_reorder": qw_reorder if as_supported_case else None, - "s_reorder": s_reorder if as_supported_case else None, - "zp_reorder": zp_reorder if as_supported_case else None, - "sm_count": sm_count if as_supported_case else None, - "sm_version": sm_version if as_supported_case else None, - "CUBLAS_M_THRESHOLD": CUBLAS_M_THRESHOLD if as_supported_case else None, + "qw_reorder": qw_reorder, + "s_reorder": s_reorder, + "zp_reorder": zp_reorder, + "sm_count": sm_count, + "sm_version": sm_version, + "CUBLAS_M_THRESHOLD": CUBLAS_M_THRESHOLD, # Kernels "gptq_marlin_gemm": ops.gptq_marlin_gemm, "gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm, @@ -177,7 +248,7 @@ def bench_run( min_run_time = 1 # Warmup pytorch - for i in range(5): + for _ in range(5): torch.matmul(a, marlin_w_ref) results.append( @@ -192,17 +263,17 @@ def bench_run( results.append( benchmark.Timer( - stmt="output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501 + stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, - description="gptq_marlin_gemm_fp16", + description="gptq_marlin_gemm", ).blocked_autorange(min_run_time=min_run_time) ) results.append( benchmark.Timer( - stmt="output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501 + stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, @@ -210,10 +281,7 @@ def bench_run( ).blocked_autorange(min_run_time=min_run_time) ) - if ( - quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES - and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES - ): + if marlin_24_supported: results.append( benchmark.Timer( stmt="output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, quant_type, size_m, size_n, size_k)", # noqa: E501 @@ -224,17 +292,18 @@ def bench_run( ).blocked_autorange(min_run_time=min_run_time) ) - results.append( - benchmark.Timer( - stmt="q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, quant_type.size_bits)", # noqa: E501 - globals=globals, - label=label, - sub_label=sub_label, - description="gptq_marlin_repack", - ).blocked_autorange(min_run_time=min_run_time) - ) + if repack_supported: + results.append( + benchmark.Timer( + stmt="q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, quant_type.size_bits)", # noqa: E501 + globals=globals, + label=label, + sub_label=sub_label, + description="gptq_marlin_repack", + ).blocked_autorange(min_run_time=min_run_time) + ) - if as_supported_case: + if allspark_supported: results.append( benchmark.Timer( stmt="output = allspark_w8a16_gemm(a, qw_reorder, s_reorder, zp_reorder, size_n, group_size, sm_count, sm_version, CUBLAS_M_THRESHOLD, False, True)", # noqa: E501 @@ -250,7 +319,6 @@ def main(args): print("Benchmarking models:") for i, model in enumerate(args.models): print(f"[{i}] {model}") - results: list[benchmark.Measurement] = [] for model in args.models: @@ -278,14 +346,17 @@ def main(args): ): continue - for quant_type in query_marlin_supported_quant_types(False): + for quant_type in query_marlin_supported_quant_types(): if ( len(args.limit_num_bits) > 0 and quant_type.size_bits not in args.limit_num_bits ): continue - for group_size in MARLIN_SUPPORTED_GROUP_SIZES: + for group_size in ( + MARLIN_SUPPORTED_GROUP_SIZES + + FP4_MARLIN_SUPPORTED_GROUP_SIZES + ): if ( len(args.limit_group_size) > 0 and group_size not in args.limit_group_size diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index cef53b183cef3bd9afadc2f45bbf9624c58e6ba0..07af58d81c68331597f1e63c0c62dc814b41025c 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -620,7 +620,7 @@ def main(args: argparse.Namespace): 4096, ] else: - batch_sizes = [args.batch_size] + batch_sizes = args.batch_size use_deep_gemm = bool(args.use_deep_gemm) @@ -728,7 +728,7 @@ if __name__ == "__main__": ) parser.add_argument("--use-deep-gemm", action="store_true") parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--batch-size", type=int, required=False) + parser.add_argument("--batch-size", type=int, nargs="+", required=False) parser.add_argument("--tune", action="store_true") parser.add_argument("--trust-remote-code", action="store_true") parser.add_argument("--model-prefix", type=str, required=False) diff --git a/benchmarks/kernels/benchmark_moe_align_block_size.py b/benchmarks/kernels/benchmark_moe_align_block_size.py new file mode 100644 index 0000000000000000000000000000000000000000..5170ac09dc42af8fac0f15e48dc5d348b728f04a --- /dev/null +++ b/benchmarks/kernels/benchmark_moe_align_block_size.py @@ -0,0 +1,159 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import argparse +import itertools + +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( + moe_align_block_size_triton, +) +from vllm.triton_utils import triton + + +def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor: + return torch.stack( + [ + torch.randperm(num_experts, dtype=torch.int32, device="cuda")[:topk] + for _ in range(num_tokens) + ] + ) + + +def check_correctness(num_tokens, num_experts=256, block_size=256, topk=8): + """ + Verifies vllm vs. Triton + """ + topk_ids = get_topk_ids(num_tokens, num_experts, topk) + + # 1. malloc space for triton and vllm + # malloc enough space (max_num_tokens_padded) for the sorted ids + max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) + sorted_ids_triton = torch.empty( + (max_num_tokens_padded,), dtype=torch.int32, device="cuda" + ) + sorted_ids_triton.fill_(topk_ids.numel()) # fill with sentinel value + expert_ids_triton = torch.zeros( + (max_num_tokens_padded // block_size,), dtype=torch.int32, device="cuda" + ) + num_tokens_post_pad_triton = torch.empty((1,), dtype=torch.int32, device="cuda") + + sorted_ids_vllm = torch.empty_like(sorted_ids_triton) + sorted_ids_vllm.fill_(topk_ids.numel()) + expert_ids_vllm = torch.zeros_like(expert_ids_triton) + num_tokens_post_pad_vllm = torch.empty_like(num_tokens_post_pad_triton) + + # 2. run implementations + moe_align_block_size_triton( + topk_ids, + num_experts, + block_size, + sorted_ids_triton, + expert_ids_triton, + num_tokens_post_pad_triton, + ) + + ops.moe_align_block_size( + topk_ids, + num_experts, + block_size, + sorted_ids_vllm, + expert_ids_vllm, + num_tokens_post_pad_vllm, + ) + print(f"✅ VLLM implementation works with {num_experts} experts!") + + # 3. compare results + if torch.allclose(expert_ids_triton, expert_ids_vllm) and torch.allclose( + num_tokens_post_pad_triton, num_tokens_post_pad_vllm + ): + print("✅ Triton and VLLM implementations match.") + else: + print("❌ Triton and VLLM implementations DO NOT match.") + print("Triton expert_ids:", expert_ids_triton) + print("VLLM expert_ids:", expert_ids_vllm) + print("Triton num_tokens_post_pad:", num_tokens_post_pad_triton) + print("VLLM num_tokens_post_pad:", num_tokens_post_pad_vllm) + + +# test configurations +num_tokens_range = [1, 16, 256, 4096] +num_experts_range = [16, 64, 224, 256, 280, 512] +topk_range = [1, 2, 8] +configs = list(itertools.product(num_tokens_range, num_experts_range, topk_range)) + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["num_tokens", "num_experts", "topk"], + x_vals=configs, + line_arg="provider", + line_vals=["vllm", "triton"], # "triton" + line_names=["VLLM", "Triton"], # "Triton" + plot_name="moe-align-block-size-performance", + args={}, + ) +) +def benchmark(num_tokens, num_experts, topk, provider): + """Benchmark function for Triton.""" + block_size = 256 + topk_ids = get_topk_ids(num_tokens, num_experts, topk) + + max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) + sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device="cuda") + sorted_ids.fill_(topk_ids.numel()) + max_num_m_blocks = max_num_tokens_padded // block_size + expert_ids = torch.empty((max_num_m_blocks,), dtype=torch.int32, device="cuda") + num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device="cuda") + + quantiles = [0.5, 0.2, 0.8] + + if provider == "vllm": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: ops.moe_align_block_size( + topk_ids, + num_experts, + block_size, + sorted_ids.clone(), + expert_ids.clone(), + num_tokens_post_pad.clone(), + ), + quantiles=quantiles, + ) + elif provider == "triton": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: moe_align_block_size_triton( + topk_ids, + num_experts, + block_size, + sorted_ids.clone(), + expert_ids.clone(), + num_tokens_post_pad.clone(), + ), + quantiles=quantiles, + ) + + return 1000 * ms, 1000 * max_ms, 1000 * min_ms + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--num_experts", + type=int, + default=64, + choices=[8, 16, 32, 64, 128, 256], + ) + parser.add_argument( + "--topk", + type=int, + default=8, + choices=[2, 4, 8], + help="Top-k value for correctness check.", + ) + args = parser.parse_args() + + print("Running correctness check...") + check_correctness(num_tokens=1024, num_experts=args.num_experts, topk=args.topk) + benchmark.run(print_data=True, show_plots=True) diff --git a/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py index e67ce054531818d6d0b59a21c0bd192c4290763a..43c54d56ca8c1be027bec0d5f54dd3e53a294771 100644 --- a/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py +++ b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py @@ -85,12 +85,6 @@ def benchmark_shape(m: int, # === DeepGEMM Implementation === def deepgemm_gemm(): - # A quantization is inside the loop as it depends on activations - # A_deepgemm, A_scale_deepgemm = per_token_cast_to_fp8(A) - # A_deepgemm, A_scale_deepgemm = per_token_group_quant_fp8( - # A, block_size[1]) - # A_scale_aligned = get_col_major_tma_aligned_tensor(A_scale_deepgemm) - # C_deepgemm = torch.empty((m, n), device='cuda', dtype=torch.bfloat16) deep_gemm.gemm_fp8_fp8_bf16_nt((A_deepgemm, A_scale_deepgemm), (B_deepgemm, B_scale_deepgemm), C_deepgemm) @@ -98,8 +92,6 @@ def benchmark_shape(m: int, # === vLLM Triton Implementation === def vllm_triton_gemm(): - # A quantization is inside the loop as it depends on activations - # A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1]) return w8a8_block_fp8_matmul(A_vllm, B_vllm, A_scale_vllm, @@ -109,9 +101,6 @@ def benchmark_shape(m: int, # === vLLM CUTLASS Implementation === def vllm_cutlass_gemm(): - # A quantization is inside the loop as it depends on activations - # A_vllm_cutlass, A_scale_vllm_cutlass = per_token_group_quant_fp8( - # A, block_size[1], column_major_scales=True) return ops.cutlass_scaled_mm(A_vllm_cutlass, B_vllm.T, scale_a=A_scale_vllm_cutlass, diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake index 5cd2c98f234381818063d7947a05ca28cb544129..fc7291972309a92ff92744679eca6166477951d6 100644 --- a/cmake/cpu_extension.cmake +++ b/cmake/cpu_extension.cmake @@ -12,9 +12,8 @@ endif() # # Define environment variables for special configurations # -if(DEFINED ENV{VLLM_CPU_AVX512BF16}) - set(ENABLE_AVX512BF16 ON) -endif() +set(ENABLE_AVX512BF16 $ENV{VLLM_CPU_AVX512BF16}) +set(ENABLE_AVX512VNNI $ENV{VLLM_CPU_AVX512VNNI}) include_directories("${CMAKE_SOURCE_DIR}/csrc") @@ -96,12 +95,30 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED) if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3) list(APPEND CXX_COMPILE_FLAGS "-mavx512bf16") + set(ENABLE_AVX512BF16 ON) else() + set(ENABLE_AVX512BF16 OFF) message(WARNING "Disable AVX512-BF16 ISA support, requires gcc/g++ >= 12.3") endif() else() + set(ENABLE_AVX512BF16 OFF) message(WARNING "Disable AVX512-BF16 ISA support, no avx512_bf16 found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512BF16=1.") endif() + + find_isa(${CPUINFO} "avx512_vnni" AVX512VNNI_FOUND) + if (AVX512VNNI_FOUND OR ENABLE_AVX512VNNI) + if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND + CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3) + list(APPEND CXX_COMPILE_FLAGS "-mavx512vnni") + set(ENABLE_AVX512VNNI ON) + else() + set(ENABLE_AVX512VNNI OFF) + message(WARNING "Disable AVX512-VNNI ISA support, requires gcc/g++ >= 12.3") + endif() + else() + set(ENABLE_AVX512VNNI OFF) + message(WARNING "Disable AVX512-VNNI ISA support, no avx512_vnni found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512VNNI=1.") + endif() elseif (AVX2_FOUND) list(APPEND CXX_COMPILE_FLAGS "-mavx2") @@ -231,12 +248,25 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED) "csrc/cpu/quant.cpp" "csrc/cpu/shm.cpp" ${VLLM_EXT_SRC}) + if (ENABLE_AVX512BF16 AND ENABLE_AVX512VNNI) + set(VLLM_EXT_SRC + "csrc/cpu/sgl-kernels/gemm.cpp" + "csrc/cpu/sgl-kernels/gemm_int8.cpp" + "csrc/cpu/sgl-kernels/gemm_fp8.cpp" + "csrc/cpu/sgl-kernels/moe.cpp" + "csrc/cpu/sgl-kernels/moe_int8.cpp" + "csrc/cpu/sgl-kernels/moe_fp8.cpp" + ${VLLM_EXT_SRC}) + add_compile_definitions(-DCPU_CAPABILITY_AVX512) + endif() elseif(POWER10_FOUND) set(VLLM_EXT_SRC "csrc/cpu/quant.cpp" ${VLLM_EXT_SRC}) endif() +message(STATUS "CPU extension source files: ${VLLM_EXT_SRC}") + # # Define extension targets # diff --git a/cmake/external_projects/vllm_flash_attn.cmake b/cmake/external_projects/vllm_flash_attn.cmake index a4edd5b96fe29ccfece190016aad0cb6c7d283d6..ef45a5fbebf6900c59110996d0977aea89bccdf6 100644 --- a/cmake/external_projects/vllm_flash_attn.cmake +++ b/cmake/external_projects/vllm_flash_attn.cmake @@ -38,7 +38,7 @@ else() FetchContent_Declare( vllm-flash-attn GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git - GIT_TAG 8798f27777fb57f447070301bf33a9f9c607f491 + GIT_TAG 1c2624e53c078854e0637ee566c72fe2107e75f4 GIT_PROGRESS TRUE # Don't share the vllm-flash-attn build between build types BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn diff --git a/cmake/utils.cmake b/cmake/utils.cmake index 8002dd74477b6f2d0a4c3e3d64bb7edf8128c635..774e01c15af3345732bd9822f9984c5a0e9393f3 100644 --- a/cmake/utils.cmake +++ b/cmake/utils.cmake @@ -122,6 +122,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG) # "-DENABLE_FP8" "-U__HIP_NO_HALF_CONVERSIONS__" "-U__HIP_NO_HALF_OPERATORS__" + "-Werror=unused-variable" "-fno-gpu-rdc" "--gpu-max-threads-per-block=1024") @@ -265,8 +266,8 @@ macro(set_gencode_flags_for_srcs) endmacro() # -# For the given `SRC_CUDA_ARCHS` list of gencode versions in the form -# `.[letter]` compute the "loose intersection" with the +# For the given `SRC_CUDA_ARCHS` list of gencode versions in the form +# `.[letter]` compute the "loose intersection" with the # `TGT_CUDA_ARCHS` list of gencodes. We also support the `+PTX` suffix in # `SRC_CUDA_ARCHS` which indicates that the PTX code should be built when there # is a CUDA_ARCH in `TGT_CUDA_ARCHS` that is equal to or larger than the @@ -278,7 +279,7 @@ endmacro() # in `SRC_CUDA_ARCHS` that is less or equal to the version in `TGT_CUDA_ARCHS`. # We have special handling for x.0a, if x.0a is in `SRC_CUDA_ARCHS` and x.0 is # in `TGT_CUDA_ARCHS` then we should remove x.0a from `SRC_CUDA_ARCHS` and add -# x.0a to the result (and remove x.0 from TGT_CUDA_ARCHS). +# x.0a to the result (and remove x.0 from TGT_CUDA_ARCHS). # The result is stored in `OUT_CUDA_ARCHS`. # # Example: @@ -313,21 +314,16 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR # if x.0a is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should # remove x.0a from SRC_CUDA_ARCHS and add x.0a to _CUDA_ARCHS set(_CUDA_ARCHS) - if ("9.0a" IN_LIST _SRC_CUDA_ARCHS) - list(REMOVE_ITEM _SRC_CUDA_ARCHS "9.0a") - if ("9.0" IN_LIST TGT_CUDA_ARCHS) - list(REMOVE_ITEM _TGT_CUDA_ARCHS "9.0") - set(_CUDA_ARCHS "9.0a") - endif() - endif() - - if ("10.0a" IN_LIST _SRC_CUDA_ARCHS) - list(REMOVE_ITEM _SRC_CUDA_ARCHS "10.0a") - if ("10.0" IN_LIST TGT_CUDA_ARCHS) - list(REMOVE_ITEM _TGT_CUDA_ARCHS "10.0") - set(_CUDA_ARCHS "10.0a") + foreach(_arch ${_SRC_CUDA_ARCHS}) + if(_arch MATCHES "\\a$") + list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_arch}") + string(REPLACE "a" "" _base "${_arch}") + if ("${_base}" IN_LIST TGT_CUDA_ARCHS) + list(REMOVE_ITEM _TGT_CUDA_ARCHS "${_base}") + list(APPEND _CUDA_ARCHS "${_arch}") + endif() endif() - endif() + endforeach() list(SORT _SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING) @@ -359,7 +355,7 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR endforeach() list(REMOVE_DUPLICATES _CUDA_ARCHS) - + # reapply +PTX suffix to architectures that requested PTX set(_FINAL_ARCHS) foreach(_arch ${_CUDA_ARCHS}) @@ -370,7 +366,7 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR endif() endforeach() set(_CUDA_ARCHS ${_FINAL_ARCHS}) - + set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE) endfunction() diff --git a/csrc/attention/mla/cutlass_mla_kernels.cu b/csrc/attention/mla/cutlass_mla_kernels.cu index f4b6b19f4b232c8bfd9e66ecf693a0dbfa3a1068..9d05d910dd81f56e9a139b95b9c0f488fe1a69fb 100644 --- a/csrc/attention/mla/cutlass_mla_kernels.cu +++ b/csrc/attention/mla/cutlass_mla_kernels.cu @@ -207,7 +207,7 @@ void cutlass_mla_decode_sm100a(torch::Tensor const& out, "page_table must be a 32-bit integer tensor"); auto in_dtype = q_nope.dtype(); - at::cuda::CUDAGuard device_guard{(char)q_nope.get_device()}; + const at::cuda::OptionalCUDAGuard device_guard(device_of(q_nope)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(q_nope.get_device()); if (in_dtype == at::ScalarType::Half) { diff --git a/csrc/attention/paged_attention_v1.cu b/csrc/attention/paged_attention_v1.cu index 9b3a5c4b1014a6bb08340054ddd3671830538ac0..46108a32d719b8ca450ffd1cf298599af5f1716c 100644 --- a/csrc/attention/paged_attention_v1.cu +++ b/csrc/attention/paged_attention_v1.cu @@ -65,9 +65,6 @@ void paged_attention_v1_launcher( int kv_block_stride = key_cache.stride(0); int kv_head_stride = key_cache.stride(1); - [[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); - assert(head_size % thread_group_size == 0); - // NOTE: alibi_slopes is optional. const float* alibi_slopes_ptr = alibi_slopes @@ -193,4 +190,4 @@ void paged_attention_v1( #undef WARP_SIZE #undef MAX #undef MIN -#undef DIVIDE_ROUND_UP \ No newline at end of file +#undef DIVIDE_ROUND_UP diff --git a/csrc/attention/paged_attention_v2.cu b/csrc/attention/paged_attention_v2.cu index 9935359e02fb1855ba972d9df54e31b129e02463..9358c0d9f6a2a6fb917db68a361e7e29541fd379 100644 --- a/csrc/attention/paged_attention_v2.cu +++ b/csrc/attention/paged_attention_v2.cu @@ -66,9 +66,6 @@ void paged_attention_v2_launcher( int kv_block_stride = key_cache.stride(0); int kv_head_stride = key_cache.stride(1); - [[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); - assert(head_size % thread_group_size == 0); - // NOTE: alibi_slopes is optional. const float* alibi_slopes_ptr = alibi_slopes @@ -203,4 +200,4 @@ void paged_attention_v2( #undef WARP_SIZE #undef MAX #undef MIN -#undef DIVIDE_ROUND_UP \ No newline at end of file +#undef DIVIDE_ROUND_UP diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp index 0257d8ff16baf0ccdfe1db1d85ff796a03feeee1..82862fea7f2be78bb562bfce673b324b3ad97b3d 100644 --- a/csrc/cpu/attention.cpp +++ b/csrc/cpu/attention.cpp @@ -137,8 +137,8 @@ FORCE_INLINE std::pair reduceSoftmaxAlibi(T* data, const int size, } template -FORCE_INLINE void reducePartitonSoftmax(const T* max_data, T* sum_data, - const int size) { +FORCE_INLINE void reducePartitionSoftmax(const T* max_data, T* sum_data, + const int size) { T max = max_data[0]; for (int i = 1; i < size; ++i) { max = max >= max_data[i] ? max : max_data[i]; @@ -634,7 +634,7 @@ struct paged_attention_v2_impl { if (partition_num == 1) continue; - reducePartitonSoftmax( + reducePartitionSoftmax( max_logits + seq_idx * num_heads * max_num_partitions + head_idx * max_num_partitions, exp_sums + seq_idx * num_heads * max_num_partitions + diff --git a/csrc/cpu/cpu_types_x86.hpp b/csrc/cpu/cpu_types_x86.hpp index 9a613ba588ddfe98c01b7d1e6de96f8b115307f5..3952c43cbc727de4dcdb2de2fa447d837742d123 100644 --- a/csrc/cpu/cpu_types_x86.hpp +++ b/csrc/cpu/cpu_types_x86.hpp @@ -83,7 +83,7 @@ struct FP16Vec16 : public Vec { explicit FP16Vec16(const void* ptr) : reg((__m256i)_mm256_loadu_si256((__m256i*)ptr)) {} - // non-temproal load + // non-temporal load explicit FP16Vec16(bool, void* ptr) : reg(_mm256_stream_load_si256((__m256i*)ptr)) {} @@ -120,7 +120,7 @@ struct BF16Vec16 : public Vec { explicit BF16Vec16(const void* ptr) : reg((__m256i)_mm256_loadu_si256((__m256i*)ptr)) {} - // non-temproal load + // non-temporal load explicit BF16Vec16(bool, void* ptr) : reg(_mm256_stream_load_si256((__m256i*)ptr)) {} @@ -327,7 +327,7 @@ struct FP32Vec16 : public Vec { // normal load explicit FP32Vec16(const float* ptr) : reg(_mm512_loadu_ps(ptr)) {} - // non-temproal load + // non-temporal load explicit FP32Vec16(bool, void* ptr) : reg((__m512)_mm512_stream_load_si512(ptr)) {} @@ -576,7 +576,7 @@ struct INT8Vec64 : public Vec { // normal load explicit INT8Vec64(void* ptr) : reg(_mm512_loadu_epi8(ptr)) {} - // non-temproal load + // non-temporal load explicit INT8Vec64(bool, void* ptr) : reg(_mm512_stream_load_si512(ptr)) {} void save(void* ptr) const { _mm512_storeu_epi8(ptr, reg); } @@ -587,7 +587,7 @@ struct INT8Vec64 : public Vec { _mm512_mask_storeu_epi8(ptr, mask, reg); } - // non-temproal save + // non-temporal save void nt_save(int8_t* ptr) { _mm512_stream_si512((__m512i*)ptr, reg); } }; #endif diff --git a/csrc/cpu/sgl-kernels/common.h b/csrc/cpu/sgl-kernels/common.h new file mode 100644 index 0000000000000000000000000000000000000000..20261c1ef3e871034a9d38d06d40bb68f7f9bf0d --- /dev/null +++ b/csrc/cpu/sgl-kernels/common.h @@ -0,0 +1,238 @@ +// Adapted from +// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu + +#pragma once + +#include +#include +#include + +// clang-format off + +#if defined(_OPENMP) +#include +#endif + +namespace { + +// dispatch bool +#define AT_DISPATCH_BOOL(BOOL_V, BOOL_NAME, ...) \ + [&] { \ + if (BOOL_V) { \ + constexpr bool BOOL_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr bool BOOL_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() + +// dispatch: bfloat16, float16, int8_t, fp8_e4m3 +#define CPU_DISPATCH_PACKED_TYPES(TYPE, ...) \ + [&] { \ + switch (TYPE) { \ + case at::ScalarType::BFloat16 : { \ + using packed_t = at::BFloat16; \ + return __VA_ARGS__(); \ + } \ + case at::ScalarType::Half: { \ + using packed_t = at::Half; \ + return __VA_ARGS__(); \ + } \ + case at::ScalarType::Char : { \ + using packed_t = int8_t; \ + return __VA_ARGS__(); \ + } \ + case at::ScalarType::Float8_e4m3fn : { \ + using packed_t = at::Float8_e4m3fn; \ + return __VA_ARGS__(); \ + } \ + default: \ + TORCH_CHECK(false, "Unsupported floating data type.\n"); \ + } \ + }() + +#define UNUSED(x) (void)(x) + +#define CHECK_CPU(x) TORCH_CHECK(x.device().type() == at::kCPU, #x " must be a CPU tensor") + +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_LAST_DIM_CONTIGUOUS(x) \ + TORCH_CHECK(x.strides()[x.strides().size() - 1] == 1, #x "must be contiguous at last dimention") + +#define CHECK_INPUT(x) \ + CHECK_CPU(x); \ + CHECK_CONTIGUOUS(x) +#define CHECK_LAST_DIM_CONTIGUOUS_INPUT(x) \ + CHECK_CPU(x); \ + CHECK_LAST_DIM_CONTIGUOUS(x) + +#define CHECK_DIM(d, x) TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor") + +#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b) + +// parallel routines +constexpr int GRAIN_SIZE = 1024; + +template ::value, int>::type = 0> +inline T div_up(T x, T y) { return (x + y - 1) / y; } + +template +inline void balance211(T n, T nth, T ith, T& n_start, T& n_end) { +#if 0 + // onednn partition pattern + T& n_my = n_end; + if (nth <= 1 || n == 0) { + n_start = 0; + n_my = n; + } else { + T n1 = div_up(n, nth); + T n2 = n1 - 1; + T T1 = n - n2 * nth; + n_my = ith < T1 ? n1 : n2; + n_start = ith <= T1 ? ith*n1 : T1 * n1 + (ith - T1) * n2; + } + n_end += n_start; +#else + // pytorch aten partition pattern + T n_my = div_up(n, nth); + n_start = ith * n_my; + n_end = std::min(n_start + n_my, n); +#endif +} + +template +inline void parallel_for(int n, const func_t& f) { +#if defined(_OPENMP) +#pragma omp parallel +{ + int nth = omp_get_num_threads(); + int ith = omp_get_thread_num(); + int tbegin, tend; + balance211(n, nth, ith, tbegin, tend); + f(tbegin, tend); +} +#else + f(0, n); +#endif +} + +// for 1d parallel, use `actual_nth` +// for 2d parallel, use even nths, e.g. 43->42 +int inline adjust_num_threads(int m) { + int actual_nth = at::get_num_threads(); + if (m == 1) { + return actual_nth; + } + return std::max(1, (actual_nth >> 1) * 2); +} + +template +inline void parallel_2d(int m, int n, const func_t& f) { + + // make sure we have even num_threads + int nth = adjust_num_threads(m); + + // [NOTE] thread blocking: + // + // 1) prefer square block per thread + // 2) use even number of CPU cores + // 3) use all `num_threads` cores + // + // we have: + // TM * TN = T + // BM / TM = BN / TN + // then: + // TM = ((BM / BN) * T) ^ 0.5 + // + float r = float(m) / n; + int nth_m = std::ceil(std::sqrt(r * nth)); + int nth_n = 1; + for (; nth_m > 0; --nth_m) { + nth_n = nth / nth_m; + if (nth_m * nth_n == nth) { + break; + } + } + +#if defined(_OPENMP) +#pragma omp parallel num_threads(nth) +{ + int ith = omp_get_thread_num(); + int ith_m = ith / nth_n; + int ith_n = ith % nth_n; + + int thread_block_m = div_up(m, nth_m); + int thread_block_n = div_up(n, nth_n); + + int begin_m = ith_m * thread_block_m; + int end_m = std::min(m, begin_m + thread_block_m); + int begin_n = ith_n * thread_block_n; + int end_n = std::min(n, begin_n + thread_block_n); + + f(begin_m, end_m, begin_n, end_n); +} +#else + f(0, m, 0, n); +#endif +} + +template +int get_cache_blocks(int BLOCK_SIZE, int K) { + // L2 2MB and ratio of 50% + const int L2_size = 2048 * 1024 >> 1; + return std::max(1, int(L2_size / (BLOCK_SIZE * K * sizeof(T)))); +} + +// data indexing for dimension collapse +template +inline T data_index_init(T offset) { + return offset; +} + +template +inline T data_index_init(T offset, T& x, const T& X, Args&&... args) { + offset = data_index_init(offset, std::forward(args)...); + x = offset % X; + return offset / X; +} + +inline bool data_index_step() { + return true; +} + +template +inline bool data_index_step(T& x, const T& X, Args&&... args) { + if (data_index_step(std::forward(args)...)) { + x = ((x + 1) == X) ? 0 : (x + 1); + return x == 0; + } + return false; +} + +// forced unroll for perf critical path + +#if __has_attribute(always_inline) +#define ALWAYS_INLINE __attribute__((__always_inline__)) inline +#else +#define ALWAYS_INLINE inline +#endif + +template +struct Unroll { + template + ALWAYS_INLINE void operator()(const Func& f, Args... args) const { + Unroll{}(f, args...); + f(std::integral_constant{}, args...); + } +}; + +template <> +struct Unroll<1> { + template + ALWAYS_INLINE void operator()(const Func& f, Args... args) const { + f(std::integral_constant{}, args...); + } +}; + +} // anonymous namespace diff --git a/csrc/cpu/sgl-kernels/gemm.cpp b/csrc/cpu/sgl-kernels/gemm.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c122d07185ddbc9953943c5b2f091b8692095f88 --- /dev/null +++ b/csrc/cpu/sgl-kernels/gemm.cpp @@ -0,0 +1,464 @@ +// Adapted from +// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu + +#include "common.h" +#include "vec.h" +#include "gemm.h" + +// clang-format off + +namespace { + +// packed layout: +// quants {N, K} int8_t +// comp {N} int32_t +template +inline void s8s8_compensation(int8_t* __restrict__ packed, int K) { +#if defined(CPU_CAPABILITY_AVX512) + constexpr int COLS = BLOCK_N / 16; + __m512i vcomp[COLS]; + + for (int col = 0; col < COLS; ++col) { + vcomp[col] = _mm512_setzero_si512(); + } + + const int64_t offset = BLOCK_N * K; + const __m512i off = _mm512_set1_epi8(static_cast(0x80)); + for (int k = 0; k < K / 4; ++k) { + for (int col = 0; col < COLS; ++col) { + __m512i vb = _mm512_loadu_si512((const __m512i *)(packed + k * BLOCK_N * 4 + col * 64)); + vcomp[col] = _mm512_dpbusd_epi32(vcomp[col], off, vb); + } + } + + for (int col = 0; col < COLS; ++col) { + _mm512_storeu_si512((__m512i *)(packed + offset + col * 64), vcomp[col]); + } +#else + TORCH_CHECK(false, "s8s8_compensation not implemented!"); +#endif +} + +// convert to vnni format +// from [N, K] to [K/2, N, 2] for bfloat16 and float16 +template +inline void pack_vnni(packed_t* __restrict__ packed, const packed_t* __restrict__ weight, int N, int K) { + const int VNNI_BLK = 2; + for (int n = 0; n < N; ++n) { + for (int k = 0; k < K / VNNI_BLK; ++k) { + for (int d = 0; d < VNNI_BLK; ++d) { + packed[k * N * VNNI_BLK + n * VNNI_BLK + d] = weight[n * K + k * VNNI_BLK + d]; + } + } + } +} + +template <> +inline void pack_vnni(int8_t* __restrict__ packed, const int8_t* __restrict__ weight, int N, int K) { + constexpr int BLOCK_N = block_size_n(); + TORCH_CHECK(N == BLOCK_N); + + const int VNNI_BLK = 4; + for (int n = 0; n < N; ++n) { + for (int k = 0; k < K / VNNI_BLK; ++k) { + for (int d = 0; d < VNNI_BLK; ++d) { + packed[k * N * VNNI_BLK + n * VNNI_BLK + d] = weight[n * K + k * VNNI_BLK + d]; + } + } + } + s8s8_compensation(packed, K); +} + +template +inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ input, int64_t size) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + + int64_t d; + #pragma GCC unroll 4 + for (d = 0; d <= size - kVecSize; d += kVecSize) { + fVec data0 = fVec::loadu(input + d); + fVec data1 = fVec::loadu(input + d + fVec::size()); + bVec out_vec = convert_from_float_ext(data0, data1); + out_vec.store(out + d); + } + for (; d < size; ++d) { + out[d] = static_cast(input[d]); + } +} + +template +inline void copy_add_stub(scalar_t* __restrict__ out, const float* __restrict__ input, const float* __restrict__ bias, int64_t size) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + + int64_t d; + #pragma GCC unroll 4 + for (d = 0; d <= size - kVecSize; d += kVecSize) { + fVec data0 = fVec::loadu(input + d) + fVec::loadu(bias + d); + fVec data1 = fVec::loadu(input + d + fVec::size()) + fVec::loadu(bias + d + fVec::size()); + bVec out_vec = convert_from_float_ext(data0, data1); + out_vec.store(out + d); + } + for (; d < size; ++d) { + out[d] = static_cast(input[d] + bias[d]); + } +} + +template +struct tinygemm_kernel_nn { + static inline void apply( + const scalar_t* __restrict__ A, const scalar_t* __restrict__ B, scalar_t* __restrict__ C, + const float* __restrict__ bias, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { + TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!"); + } +}; + +#if defined(CPU_CAPABILITY_AVX512) +template +struct tinygemm_kernel_nn { + static inline void apply( + const at::BFloat16* __restrict__ A, const at::BFloat16* __restrict__ B, at::BFloat16* __restrict__ C, + const float* __restrict__ bias, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { + + constexpr int ROWS = BLOCK_M; + constexpr int COLS = BLOCK_N / 16; + + // prefetch distance + constexpr int PREFETCH_SIZE_K = 0; + + __m512bh va; + __m512bh vb[COLS]; + __m512 vc[ROWS * COLS]; + + auto loadc = [&](auto i) { + constexpr int col = i % COLS; + if constexpr (has_bias) { + vc[i] = _mm512_loadu_ps(bias + col * 16); + } else { + vc[i] = _mm512_set1_ps(0.f); + } + }; + Unroll{}(loadc); + + const int64_t K2 = K >> 1; + const int64_t lda2 = lda >> 1; + const int64_t ldb2 = ldb; // ldb * 2 >> 1; + const float* a_ptr = reinterpret_cast(A); + const float* b_ptr = reinterpret_cast(B); + + auto compute = [&](auto i, int64_t k) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + if constexpr (col == 0) { + va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k])); + } + if constexpr (row == 0) { + vb[col] = (__m512bh)(_mm512_loadu_si512(b_ptr + k * ldb2 + col * 16)); + if constexpr (PREFETCH_SIZE_K > 0) { + _mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0); + } + } + vc[i] = _mm512_dpbf16_ps(vc[i], va, vb[col]); + }; + for (int64_t k = 0; k < K2; ++k) { + Unroll{}(compute, k); + } + + auto storec = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + // for COLS = 2, 4 use 512bit store + // for COLS = 1, 3 use 256bit store + if constexpr (COLS % 2 == 0) { + if constexpr (col % 2 == 0) { + _mm512_storeu_si512( + reinterpret_cast<__m512i*>((C + row * ldc + col * 16)), + (__m512i)(_mm512_cvtne2ps_pbh(vc[row * COLS + col + 1], vc[row * COLS + col]))); + } + } else { + _mm256_storeu_si256( + reinterpret_cast<__m256i*>(C + row * ldc + col * 16), + (__m256i)(_mm512_cvtneps_pbh(vc[i]))); + } + }; + Unroll{}(storec); + } +}; +#endif + +#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \ + tinygemm_kernel_nn::apply( \ + A + mb_start * lda, B + nb_start * 2, C + mb_start * ldc + nb_start, \ + has_bias ? bias + nb_start : nullptr, K, lda, ldb, ldc); + +template +struct brgemm { + static inline void apply( + const scalar_t* __restrict__ A, const scalar_t* __restrict__ B, scalar_t* __restrict__ C, + float* __restrict__ Ctmp, const float* __restrict__ bias, + int64_t M, int64_t N, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { + + constexpr int BLOCK_N = block_size_n(); + at::native::cpublas::brgemm( + M, N, K, lda, ldb, BLOCK_N, /* add_C */false, + A, B, Ctmp); + + // copy from Ctmp to C + for (int64_t m = 0; m < M; ++m) { + if constexpr (has_bias) { + copy_add_stub(C + m * ldc, Ctmp + m * BLOCK_N, bias, N); + } else { + copy_stub(C + m * ldc, Ctmp + m * BLOCK_N, N); + } + } + } +}; + +template +void tinygemm_kernel( + const scalar_t* __restrict__ A, + const scalar_t* __restrict__ B, + scalar_t* __restrict__ C, + float* __restrict__ Ctmp, + const float* __restrict__ bias, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc, + bool brg) { + + if (brg) { + brgemm::apply( + A, B, C, Ctmp, bias, + M, N, K, lda, ldb, ldc); + return; + } + + // pattern: 1-4-16 + constexpr int64_t BLOCK_M = 4; + constexpr int64_t BLOCK_N = 64; + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + for (int mb = 0; mb < MB; ++mb) { + int64_t mb_start = mb * BLOCK_M; + int64_t mb_size = std::min(BLOCK_M, M - mb_start); + for (int64_t nb = 0; nb < NB; ++nb) { + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(BLOCK_N, N - nb_start); + + switch(mb_size << 4 | nb_size >> 4) { + // mb_size = 1 + case 0x12: LAUNCH_TINYGEMM_KERNEL_NN(1, 32); break; + case 0x14: LAUNCH_TINYGEMM_KERNEL_NN(1, 64); break; + // mb_size = 2 + case 0x22: LAUNCH_TINYGEMM_KERNEL_NN(2, 32); break; + case 0x24: LAUNCH_TINYGEMM_KERNEL_NN(2, 64); break; + // mb_size = 3 + case 0x32: LAUNCH_TINYGEMM_KERNEL_NN(3, 32); break; + case 0x34: LAUNCH_TINYGEMM_KERNEL_NN(3, 64); break; + // mb_size = 4 + case 0x42: LAUNCH_TINYGEMM_KERNEL_NN(4, 32); break; + case 0x44: LAUNCH_TINYGEMM_KERNEL_NN(4, 64); break; + default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size"); + } + } + } +} + +template +void weight_packed_linear_kernel_impl( + scalar_t* __restrict__ out, + const scalar_t* __restrict__ mat1, + const scalar_t* __restrict__ mat2, + const float* __restrict__ bias, + int64_t M, + int64_t N, + int64_t K, + int64_t mat1_strideM, + int64_t out_strideM) { + + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n(); + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + + // use avx512-bf16 when a) M is small; b) dtype is bfloat16, otherwise use amx + const bool use_brgemm = (M > 4) || (!std::is_same_v); + + // l2 cache block for n + int64_t cache_blocks_nb = get_cache_blocks(BLOCK_N, K); + + // parallel on [MB, NB] + AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] { + parallel_2d(MB, NB, [&](int64_t begin_mb, int64_t end_mb, int64_t begin_nb, int64_t end_nb) { + + // for brgemm, use float32 for accumulate + alignas(64) float Ctmp[BLOCK_M * BLOCK_N]; + + for (int64_t nbb = begin_nb; nbb < end_nb; nbb += cache_blocks_nb) { + for (int64_t mb = begin_mb; mb < end_mb; ++mb) { + for (int64_t nb = nbb; nb < std::min(nbb + cache_blocks_nb, end_nb); ++nb) { + + int64_t mb_start = mb * BLOCK_M; + int64_t mb_size = std::min(M - mb_start, BLOCK_M); + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(N - nb_start, BLOCK_N); + + tinygemm_kernel( + /* A */ mat1 + mb_start * mat1_strideM, + /* B */ mat2 + nb_start * K /* nb * BLOCK_N * K */, + /* C */ out + mb_start * out_strideM + nb_start, + /* Ctmp*/ Ctmp, + /* bias*/ bias + nb_start, + /* M */ mb_size, + /* N */ nb_size, + /* K */ K, + /* lda */ mat1_strideM, + /* ldb */ nb_size, + /* ldc */ out_strideM, + /* brg */ use_brgemm); + }}} + + if (use_brgemm) { + at::native::cpublas::brgemm_release(); + } + }); + }); +} + +} // anonymous namespace + +// tinygemm interface +template +void tinygemm_kernel(const scalar_t* __restrict__ A, const scalar_t* __restrict__ B, scalar_t* __restrict__ C, + float* __restrict__ Ctmp, int64_t M, int64_t N, int64_t K, int64_t lda, int64_t ldb, int64_t ldc, bool brg) { + tinygemm_kernel(A, B, C, Ctmp, nullptr, M, N, K, lda, ldb, ldc, brg); +} + +#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \ + template void tinygemm_kernel( \ + const TYPE* __restrict__ A, const TYPE* __restrict__ B, TYPE* __restrict__ C, \ + float* __restrict__ Ctmp, int64_t M, int64_t N, int64_t K, int64_t lda, \ + int64_t ldb, int64_t ldc, bool brg) + +INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16); +INSTANTIATE_TINYGEMM_TEMPLATE(at::Half); + +at::Tensor convert_weight_packed(at::Tensor& weight) { + // for 3d moe weights + // weight : [E, OC, IC] + // w1 : [E, 2N, K] + // w2 : [E, K, N] + CHECK_INPUT(weight); + + const int64_t ndim = weight.ndimension(); + TORCH_CHECK(ndim == 2 || ndim == 3, "expect weight to be 2d or 3d, got ", ndim, "d tensor."); + const auto st = weight.scalar_type(); + const int64_t E = ndim == 3 ? weight.size(0) : 1; + const int64_t OC = ndim == 3 ? weight.size(1) : weight.size(0); + const int64_t IC = ndim == 3 ? weight.size(2) : weight.size(1); + + // we handle 2 TILE_N at a time. + TORCH_CHECK(OC % TILE_N == 0, "invalid weight out features ", OC); + TORCH_CHECK(IC % TILE_K == 0, "invalid weight input features ", IC); + + constexpr int64_t BLOCK_N = block_size_n(); + const int64_t NB = div_up(OC, BLOCK_N); + + // use phony sizes here [E, OC, IC], for each [E], [OC, IC] -> [IC / 2, OC, 2] + auto packed_weight = at::empty({}, weight.options()); + const int64_t stride = OC * IC; + + TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf || st == at::kChar || st == at::kFloat8_e4m3fn, + "expect weight to be bfloat16, float16, int8 or fp8_e4m3."); + + CPU_DISPATCH_PACKED_TYPES(st, [&] { + // adjust most inner dimension size + const int packed_row_size = get_row_size(IC); + auto sizes = weight.sizes().vec(); + sizes[ndim - 1] = packed_row_size; + packed_weight.resize_(sizes); + + const packed_t* w_data = weight.data_ptr(); + packed_t* packed_data = packed_weight.data_ptr(); + + // parallel on {E, NB} + at::parallel_for(0, E * NB, 0, [&](int64_t begin, int64_t end) { + int64_t e{0}, nb{0}; + data_index_init(begin, e, E, nb, NB); + + for (int64_t i = begin; i < end; ++i) { + UNUSED(i); + + int64_t n = nb * BLOCK_N; + int64_t n_size = std::min(BLOCK_N, OC - n); + pack_vnni( + packed_data + e * OC * packed_row_size + n * packed_row_size, + w_data + e * stride + n * IC, + n_size, + IC); + + // move to the next index + data_index_step(e, E, nb, NB); + } + }); + }); + return packed_weight; +} + +// mat1 : [M, K] +// mat2 : [N, K] +// bias : [N] +// out : [M, N] +// +at::Tensor weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, + const std::optional& bias, bool is_vnni) { + RECORD_FUNCTION( + "sgl-kernel::weight_packed_linear", std::vector({mat1, mat2, bias})); + + auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2); + + CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1); + CHECK_INPUT(mat2); + + int64_t M = mat1.size(0); + int64_t N = mat2.size(0); + int64_t K = mat2.size(1); + CHECK_EQ(mat1.size(1), K); + CHECK_DIM(2, mat1); + CHECK_DIM(2, mat2); + + auto out = at::empty({M, N}, mat1.options()); + + // strides + int64_t mat1_strideM = mat1.stride(0); + int64_t out_strideM = out.stride(0); + + const bool has_bias = bias.has_value(); + const float* bias_data = nullptr; + if (has_bias) { + CHECK_EQ(bias.value().size(0), N); + bias_data = bias.value().data_ptr(); + } + + AT_DISPATCH_REDUCED_FLOATING_TYPES(mat1.scalar_type(), "weight_packed_linear_kernel_impl", [&] { + weight_packed_linear_kernel_impl( + out.data_ptr(), + mat1.data_ptr(), + packed_w.data_ptr(), + bias_data, + M, + N, + K, + mat1_strideM, + out_strideM); + }); + + return out; +} diff --git a/csrc/cpu/sgl-kernels/gemm.h b/csrc/cpu/sgl-kernels/gemm.h new file mode 100644 index 0000000000000000000000000000000000000000..afae19721ae9656c51d98ebf2b1fe09a8791c8b6 --- /dev/null +++ b/csrc/cpu/sgl-kernels/gemm.h @@ -0,0 +1,266 @@ +#pragma once + +#include + +// clang-format off + +// amx-bf16 +#define TILE_M 16 +#define TILE_N 16 +#define TILE_K 32 + +// block size for AMX gemm +constexpr int block_size_m() { return 2 * TILE_M; } +constexpr int block_size_n() { return 2 * TILE_N; } + +// define threshold using brgemm (intel AMX) +template inline bool can_use_brgemm(int M); +template <> inline bool can_use_brgemm(int M) { return M > 4; } +template <> inline bool can_use_brgemm(int M) { return true; } +// TODO: add u8s8 brgemm, this requires PyTorch 2.7 +template <> inline bool can_use_brgemm(int M) { return false; } +template <> inline bool can_use_brgemm(int M) { return M > 4; } +template <> inline bool can_use_brgemm(int M) { return M > 4; } + +// work around compiler internal error +#define BLOCK_K 128 // 4 * TILE_K + +// adjust leading dimension size for K +template +inline int64_t get_row_size(int64_t K) { + return K; +} + +template <> +inline int64_t get_row_size(int64_t K) { + return K + sizeof(int32_t); +} + +inline int64_t get_row_size(int64_t K, bool use_int8_w8a8) { + return use_int8_w8a8 ? K + sizeof(int32_t) : K; +} + +// pack weight to vnni format +at::Tensor convert_weight_packed(at::Tensor& weight); + +// moe implementations for int8 w8a8 +template +void fused_experts_int8_kernel_impl( + scalar_t* __restrict__ output, + scalar_t* __restrict__ ic1, + scalar_t* __restrict__ ic2, + uint8_t* __restrict__ A_tmp, + float* __restrict__ C_tmp, + uint8_t* __restrict__ Aq_tmp, + float* __restrict__ As_tmp, + const scalar_t* __restrict__ input, + const int8_t* __restrict__ packed_w1, + const int8_t* __restrict__ packed_w2, + const float* __restrict__ w1s, + const float* __restrict__ w2s, + const float* __restrict__ topk_weights, + const int32_t* __restrict__ sorted_ids, + const int32_t* __restrict__ expert_ids, + const int32_t* __restrict__ offsets, + int64_t M, + int64_t N, + int64_t K, + int64_t E, + int64_t topk, + int64_t num_tokens_post_pad); + +// moe implementations for fp8 w8a16 +template +void fused_experts_fp8_kernel_impl( + scalar_t* __restrict__ output, + scalar_t* __restrict__ ic0, + scalar_t* __restrict__ ic1, + scalar_t* __restrict__ ic2, + scalar_t* __restrict__ A_tmp, + scalar_t* __restrict__ B_tmp, + float* __restrict__ C_tmp, + const scalar_t* __restrict__ input, + const at::Float8_e4m3fn* __restrict__ packed_w1, + const at::Float8_e4m3fn* __restrict__ packed_w2, + const float* __restrict__ w1s, + const float* __restrict__ w2s, + int64_t block_size_N, + int64_t block_size_K, + const float* __restrict__ topk_weights, + const int32_t* __restrict__ sorted_ids, + const int32_t* __restrict__ expert_ids, + const int32_t* __restrict__ offsets, + int64_t M, + int64_t N, + int64_t K, + int64_t E, + int64_t topk, + int64_t num_tokens_post_pad); + +// moe implementations for int4 w4a16 +template +void fused_experts_int4_w4a16_kernel_impl( + scalar_t* __restrict__ output, + scalar_t* __restrict__ ic0, + scalar_t* __restrict__ ic1, + scalar_t* __restrict__ ic2, + scalar_t* __restrict__ A_tmp, + scalar_t* __restrict__ B_tmp, + float* __restrict__ C_tmp, + const scalar_t* __restrict__ input, + const at::quint4x2* __restrict__ packed_w1, + const at::quint4x2* __restrict__ packed_w2, + const uint8_t* __restrict__ w1z, + const uint8_t* __restrict__ w2z, + const scalar_t* __restrict__ w1s, + const scalar_t* __restrict__ w2s, + int group_size, + const float* __restrict__ topk_weights, + const int32_t* __restrict__ sorted_ids, + const int32_t* __restrict__ expert_ids, + const int32_t* __restrict__ offsets, + int64_t M, + int64_t N, + int64_t K, + int64_t E, + int64_t topk, + int64_t num_tokens_post_pad); + +// shared expert implememntation for int8 w8a8 +template +void shared_expert_int8_kernel_impl( + scalar_t* __restrict__ output, + scalar_t* __restrict__ ic1, + float* __restrict__ C_tmp, + uint8_t* __restrict__ Aq_tmp, + float* __restrict__ As_tmp, + const scalar_t* __restrict__ input, + const int8_t* __restrict__ packed_w1, + const int8_t* __restrict__ packed_w2, + const float* __restrict__ w1s, + const float* __restrict__ w2s, + const scalar_t* __restrict__ fused_experts_out, + float routed_scaling_factor, + int64_t M, + int64_t N, + int64_t K); + +template +void shared_expert_fp8_kernel_impl( + scalar_t* __restrict__ output, + scalar_t* __restrict__ ic0, + scalar_t* __restrict__ ic1, + scalar_t* __restrict__ B_tmp, + float* __restrict__ C_tmp, + const scalar_t* __restrict__ input, + const at::Float8_e4m3fn* __restrict__ packed_w1, + const at::Float8_e4m3fn* __restrict__ packed_w2, + const float* __restrict__ w1s, + const float* __restrict__ w2s, + int64_t block_size_N, + int64_t block_size_K, + const scalar_t* __restrict__ fused_experts_out, + float routed_scaling_factor, + int64_t M, + int64_t N, + int64_t K); + +// tinygemm interface +template +void tinygemm_kernel( + const scalar_t* __restrict__ A, + const scalar_t* __restrict__ B, + scalar_t* __restrict__ C, + float* __restrict__ Ctmp, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc, + bool brg); + +template +void tinygemm_kernel( + const uint8_t* __restrict__ A, + const int8_t* __restrict__ B, + scalar_t* __restrict__ C, + int32_t* __restrict__ Ctmp, + const float* __restrict__ As, + const float* __restrict__ Bs, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc, + bool brg); + +template +void tinygemm_kernel( + const scalar_t* __restrict__ A, + const at::Float8_e4m3fn* __restrict__ B, + scalar_t* __restrict__ C, + scalar_t* __restrict__ Btmp, + float* __restrict__ Ctmp, + const float* __restrict__ scale, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc, + bool brg, + int64_t block_size_K); + +template +void tinygemm_kernel( + const scalar_t* __restrict__ A, + const at::quint4x2* __restrict__ B, + scalar_t* __restrict__ C, + const uint8_t* __restrict__ Bz, + const scalar_t* __restrict__ Bs, + scalar_t* __restrict__ Btmp, + float* __restrict__ Ctmp, + int64_t M, + int64_t N, + int64_t K, + int group_size, + int64_t lda, + int64_t ldb, + int64_t ldc, + int64_t strideBz, + int64_t strideBs, + bool brg); + +// TODO: debug print, remove me later +inline void print_16x32i(const __m512i x) { + int32_t a[16]; + _mm512_storeu_si512((__m512i *)a, x); + + for (int i = 0; i < 16; i++){ + std::cout << a[i] << " "; + } + std::cout << std::endl; +} + +inline void print_16x32(const __m512 x) { + float a[16]; + _mm512_storeu_ps((__m512 *)a, x); + + for (int i = 0; i < 16; i++){ + std::cout << a[i] << " "; + } + std::cout << std::endl; +} + + +inline void print_32x8u(const __m256i x) { + uint8_t a[32]; + _mm256_storeu_si256((__m256i *)a, x); + + for (int i = 0; i < 32; ++i) { + std::cout << int32_t(a[i]) << " "; + } + std::cout << std::endl; +} diff --git a/csrc/cpu/sgl-kernels/gemm_fp8.cpp b/csrc/cpu/sgl-kernels/gemm_fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b5f2f07bad623aa60d29f96bcde7ba9af49c1701 --- /dev/null +++ b/csrc/cpu/sgl-kernels/gemm_fp8.cpp @@ -0,0 +1,530 @@ +// Adapted from +// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu + +#include "common.h" +#include "vec.h" +#include "gemm.h" + +// clang-format off + +// we use 4x32 for BLOCK_M +#define BLOCK_SIZE_M_SCALE 4 + +namespace { + +template +inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ input, int64_t size) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + + int64_t d; + #pragma GCC unroll 4 + for (d = 0; d <= size - kVecSize; d += kVecSize) { + fVec data0 = fVec::loadu(input + d); + fVec data1 = fVec::loadu(input + d + fVec::size()); + bVec out_vec = convert_from_float_ext(data0, data1); + out_vec.store(out + d); + } + for (; d < size; ++d) { + out[d] = static_cast(input[d]); + } +} + +template +inline void copy_add_stub(scalar_t* __restrict__ out, const float* __restrict__ input, const float* __restrict__ bias, int64_t size) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + + int64_t d; + #pragma GCC unroll 4 + for (d = 0; d <= size - kVecSize; d += kVecSize) { + fVec data0 = fVec::loadu(input + d) + fVec::loadu(bias + d); + fVec data1 = fVec::loadu(input + d + fVec::size()) + fVec::loadu(bias + d + fVec::size()); + bVec out_vec = convert_from_float_ext(data0, data1); + out_vec.store(out + d); + } + for (; d < size; ++d) { + out[d] = static_cast(input[d] + bias[d]); + } +} + +inline void unpack_B( + at::BFloat16* __restrict__ Btmp, + const at::Float8_e4m3fn* __restrict__ packed_B, + int N, + int K, + int ldb, + int ldb_tmp, + float scale) { +#if defined(CPU_CAPABILITY_AVX512) + // [K/2, N, 2] + const int K2 = K >> 1; + const int ldb2 = ldb; // ldb * 2 >> 1; + const uint16_t* b_ptr = reinterpret_cast(packed_B); + const __m512 vd = _mm512_set1_ps(scale); + + constexpr int BLOCK_N = block_size_n(); + static_assert(BLOCK_N == 32); + + // prefetch distance + constexpr int PREFETCH_SIZE_K = 64; + +#pragma GCC unroll 4 + for (int k = 0; k < K2; ++k) { + __m512i b8 = _mm512_loadu_si512(b_ptr + k * ldb2); + if constexpr (PREFETCH_SIZE_K > 0) { + _mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2, _MM_HINT_T0); + } + + __m256i b8_0 = _mm512_extracti32x8_epi32(b8, 0); + __m256i b8_1 = _mm512_extracti32x8_epi32(b8, 1); + + __m512bh bf16_0 = CVT_FP8_TO_BF16(b8_0); + __m512bh bf16_1 = CVT_FP8_TO_BF16(b8_1); + + // Apply scale + __m512 f0_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 0)); + __m512 f0_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 1)); + __m512 f1_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 0)); + __m512 f1_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 1)); + + f0_lo = _mm512_mul_ps(f0_lo, vd); + f0_hi = _mm512_mul_ps(f0_hi, vd); + f1_lo = _mm512_mul_ps(f1_lo, vd); + f1_hi = _mm512_mul_ps(f1_hi, vd); + + bf16_0 = _mm512_cvtne2ps_pbh(f0_hi, f0_lo); + bf16_1 = _mm512_cvtne2ps_pbh(f1_hi, f1_lo); + + _mm512_storeu_si512(Btmp + k * ldb_tmp * 2 + 0, (__m512i)bf16_0); + _mm512_storeu_si512(Btmp + k * ldb_tmp * 2 + 32, (__m512i)bf16_1); + } +#else + TORCH_CHECK(false, "unpack_B: scalar path not implemented!"); +#endif +} + +template +struct tinygemm_kernel_nn { + static inline void apply( + const scalar_t* __restrict__ A, const packed_t* __restrict__ B, scalar_t* __restrict__ C, + const float* __restrict__ bias, const float* __restrict__ scale, int K, int lda, int ldb, int ldc, int64_t block_size_K) { + TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!"); + } +}; + +#if defined(CPU_CAPABILITY_AVX512) +template +struct tinygemm_kernel_nn { + static inline void apply( + const at::BFloat16* __restrict__ A, const at::Float8_e4m3fn* __restrict__ B, at::BFloat16* __restrict__ C, + const float* __restrict__ bias, const float* __restrict__ scale, int K, int lda, int ldb, int ldc, int64_t block_size_K) { + + constexpr int ROWS = BLOCK_M; + constexpr int COLS = BLOCK_N / 16; + + const int KB = div_up(K, BLOCK_K); + + // prefetch distance + constexpr int PREFETCH_SIZE_K = 64; + constexpr int PREFETCH_SIZE_KB = 1; + + __m512bh va; + __m512bh vb[COLS]; + __m512 vc[ROWS * COLS]; + __m512 vsum[ROWS * COLS]; + + // block quant scale + __m512 vscale; + + auto loadc = [&](auto i) { + constexpr int col = i % COLS; + if constexpr (has_bias) { + vc[i] = _mm512_loadu_ps(bias + col * 16); + } else { + vc[i] = _mm512_setzero_ps(); + } + }; + Unroll{}(loadc); + + const int lda2 = lda >> 1; + const int ldb2 = ldb; // ldb * 2 >> 1; + const float* a_ptr = reinterpret_cast(A); + const uint16_t* b_ptr = reinterpret_cast(B); + + auto compute = [&](auto i, int k) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + if constexpr (col == 0) { + va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k])); + if constexpr (PREFETCH_SIZE_K > 0) { + _mm_prefetch(a_ptr + row * lda2 + k + PREFETCH_SIZE_K, _MM_HINT_T0); + } + } + if constexpr (row == 0) { + if constexpr (col % 2 == 0) { + __m512i b8 = _mm512_loadu_si512(b_ptr + k * ldb2 + col * 16); + if constexpr (PREFETCH_SIZE_K > 0) { + _mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0); + } + vb[col + 0] = CVT_FP8_TO_BF16(_mm512_extracti32x8_epi32(b8, 0)); + vb[col + 1] = CVT_FP8_TO_BF16(_mm512_extracti32x8_epi32(b8, 1)); + } + } + vsum[i] = _mm512_dpbf16_ps(vsum[i], va, vb[col]); + }; + + constexpr int BLOCK_K2 = BLOCK_K >> 1; + for (int kb = 0; kb < KB; ++kb) { + int kb_start = kb * BLOCK_K2; + int kb_end = std::min(K, kb_start + BLOCK_K2); + // 1. load scale vector + vscale = _mm512_set1_ps(scale[kb]); + if constexpr (PREFETCH_SIZE_KB > 0) { + _mm_prefetch(scale + kb + PREFETCH_SIZE_KB, _MM_HINT_T0); + } + // 2. zero vsum for each block + Unroll{}([&](auto i) { + vsum[i] = _mm512_setzero_ps(); + }); + // 3. accumulate across each block + for (int k = kb_start; k < kb_end; ++k) { + Unroll{}(compute, k); + } + // 4. apply scale + Unroll{}([&](auto i) { + vc[i] = _mm512_fmadd_ps(vsum[i], vscale, vc[i]); + }); + } + + auto storec = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + // for COLS = 2,4 use 512bit store + if constexpr (col % 2 == 0) { + _mm512_storeu_si512( + reinterpret_cast<__m512i*>((C + row * ldc + col * 16)), + (__m512i)(_mm512_cvtne2ps_pbh(vc[row * COLS + col + 1], vc[row * COLS + col]))); + } + }; + Unroll{}(storec); + } +}; +#endif + +#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \ + tinygemm_kernel_nn::apply( \ + A + mb_start * lda, B + nb_start * 2, C + mb_start * ldc + nb_start, \ + has_bias ? bias + nb_start : nullptr, scale, K, lda, ldb, ldc, block_size_K); + +template +struct brgemm { + static inline void apply( + const scalar_t* __restrict__ A, + const packed_t* __restrict__ B, + scalar_t* __restrict__ C, + scalar_t* __restrict__ Btmp, + float* __restrict__ Ctmp, + const float* __restrict__ bias, + const float* __restrict__ scale, + int M, + int N, + int K, + int lda, + int ldb, + int ldc) { + TORCH_CHECK(false, "struct brgemm: primary template not implemented!"); + } +}; + +template +struct brgemm { + static inline void apply( + const at::BFloat16* __restrict__ A, + const at::Float8_e4m3fn* __restrict__ B, + at::BFloat16* __restrict__ C, + at::BFloat16* __restrict__ Btmp, + float* __restrict__ Ctmp, + const float* __restrict__ bias, + const float* __restrict__ scale, + int M, + int N, + int K, + int lda, + int ldb, + int ldc) { + + constexpr int BLOCK_N = block_size_n(); + + // [K, BLOCK_N] -> [K / 2, BLOCK_N * 2] + const int ldb_tmp = BLOCK_N; + + for (int k = 0; k < K; k += BLOCK_K) { + int kb_size = std::min(BLOCK_K, K - k); + + int idx = k >> 7; // k / BLOCK_K where BLOCK_K = 128 + unpack_B(Btmp + k * ldb_tmp, B + k * ldb, N, kb_size, ldb, ldb_tmp, scale[idx]); + } + + at::native::cpublas::brgemm( + M, N, K, lda, ldb_tmp, BLOCK_N, /* add_C */ false, A, Btmp, Ctmp); + + // copy from Ctmp to C + for (int m = 0; m < M; ++m) { + if constexpr (has_bias) { + copy_add_stub(C + m * ldc, Ctmp + m * BLOCK_N, bias, N); + } else { + copy_stub(C + m * ldc, Ctmp + m * BLOCK_N, N); + } + } + } +}; + +template +void tinygemm_kernel( + const scalar_t* __restrict__ A, + const at::Float8_e4m3fn* __restrict__ B, + scalar_t* __restrict__ C, + scalar_t* __restrict__ Btmp, + float* __restrict__ Ctmp, + const float* __restrict__ scale, + const float* __restrict__ bias, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc, + bool brg, + int64_t block_size_K) { + + if (brg) { + brgemm::apply( + A, B, C, Btmp, Ctmp, bias, scale, M, N, K, lda, ldb, ldc); + return; + } + + // pattern: 1-4-16 + constexpr int64_t BLOCK_M = 4; + constexpr int64_t BLOCK_N = 64; + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + for (int mb = 0; mb < MB; ++mb) { + int64_t mb_start = mb * BLOCK_M; + int64_t mb_size = std::min(BLOCK_M, M - mb_start); + for (int64_t nb = 0; nb < NB; ++nb) { + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(BLOCK_N, N - nb_start); + + switch(mb_size << 4 | nb_size >> 4) { + case 0x12: LAUNCH_TINYGEMM_KERNEL_NN(1, 32); break; + case 0x22: LAUNCH_TINYGEMM_KERNEL_NN(2, 32); break; + case 0x32: LAUNCH_TINYGEMM_KERNEL_NN(3, 32); break; + case 0x42: LAUNCH_TINYGEMM_KERNEL_NN(4, 32); break; + default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size"); + } + } + } +} + +template +void fp8_scaled_mm_kernel_impl( + scalar_t* __restrict__ out, + const scalar_t* __restrict__ mat1, + const at::Float8_e4m3fn* __restrict__ mat2, + const float* __restrict__ scales2, + const float* __restrict__ bias, + scalar_t* __restrict__ buffer, + int64_t M, + int64_t N, + int64_t K, + int64_t mat1_strideM, + int64_t out_strideM, + int64_t block_size_N, + int64_t block_size_K, + int64_t buffer_size_per_thread) { + + constexpr int64_t BLOCK_M = block_size_m() * BLOCK_SIZE_M_SCALE; + constexpr int64_t BLOCK_N = block_size_n(); + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + + const int64_t scale_size_K = div_up(K, block_size_K); + const int64_t blocks_n_per_group = block_size_N / BLOCK_N; + + const bool use_brgemm = can_use_brgemm(M); + + // parallel on [MB, NB] + AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] { + at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { + int64_t mb{0}, nb{0}; + data_index_init(begin, mb, MB, nb, NB); + + int tid = at::get_thread_num(); + scalar_t* __restrict__ Btmp = buffer + tid * buffer_size_per_thread; + float* __restrict__ Ctmp = (float*)((void*)(Btmp + BLOCK_N * K)); + + for (int64_t i = begin; i < end; ++i) { + UNUSED(i); + const float* scale_ptr = scales2 + (nb / blocks_n_per_group) * scale_size_K; + + int64_t mb_start = mb * BLOCK_M; + int64_t mb_size = std::min(M - mb_start, BLOCK_M); + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(N - nb_start, BLOCK_N); + + tinygemm_kernel( + /* A */ mat1 + mb_start * mat1_strideM, + /* B */ mat2 + nb_start * K, // nb * BLOCK_N * K + /* C */ out + mb_start * out_strideM + nb_start, + /* Btmp */ Btmp, + /* Ctmp */ Ctmp, + /* scale */ scale_ptr, + /* bias */ bias + nb_start, + /* M */ mb_size, + /* N */ nb_size, + /* K */ K, + /* lda */ mat1_strideM, + /* ldb */ nb_size, + /* ldc */ out_strideM, + /* brg */ use_brgemm, + /* block_size_K */ block_size_K); + + // move to the next index + data_index_step(mb, MB, nb, NB); + } + + if (use_brgemm) { + at::native::cpublas::brgemm_release(); + } + }); + }); +} + +} // anonymous namespace + +// tinygemm interface +template +void tinygemm_kernel( + const scalar_t* __restrict__ A, + const at::Float8_e4m3fn* __restrict__ B, + scalar_t* __restrict__ C, + scalar_t* __restrict__ Btmp, + float* __restrict__ Ctmp, + const float* __restrict__ scale, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc, + bool brg, + int64_t block_size_K) { + tinygemm_kernel(A, B, C, Btmp, Ctmp, scale, nullptr, M, N, K, lda, ldb, ldc, brg, block_size_K); +} + +#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \ + template void tinygemm_kernel( \ + const TYPE* __restrict__ A, \ + const at::Float8_e4m3fn* __restrict__ B, \ + TYPE* __restrict__ C, \ + TYPE* __restrict__ Btmp, \ + float* __restrict__ Ctmp, \ + const float* __restrict__ scale, \ + int64_t M, \ + int64_t N, \ + int64_t K, \ + int64_t lda, \ + int64_t ldb, \ + int64_t ldc, \ + bool brg, \ + int64_t block_size_K) + +INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16); +INSTANTIATE_TINYGEMM_TEMPLATE(at::Half); + +at::Tensor fp8_scaled_mm_cpu(at::Tensor& mat1, at::Tensor& mat2, at::Tensor& scales2, + std::vector block_size, std::optional& bias, + at::ScalarType out_dtype, bool is_vnni) { + RECORD_FUNCTION("sgl-kernel::fp8_scaled_mm_cpu", std::vector({mat1, mat2, scales2, block_size, bias})); + + auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2); + + CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1); + CHECK_INPUT(mat2); + CHECK_INPUT(scales2); + TORCH_CHECK(scales2.scalar_type() == at::kFloat, + "fp8_scaled_mm_cpu: expect scales2 to be float32."); + + int64_t M = mat1.size(0); + int64_t N = mat2.size(0); + int64_t K = mat2.size(1); + + CHECK_EQ(mat1.size(1), K); + CHECK_DIM(2, mat1); + CHECK_DIM(2, mat2); + + TORCH_CHECK(block_size.size() == 2, + "fp8_scaled_mm_cpu: expect block_size.size() to be 2."); + + int64_t block_size_N = block_size[0]; + int64_t block_size_K = block_size[1]; + + constexpr int64_t BLOCK_M = block_size_m() * BLOCK_SIZE_M_SCALE; + constexpr int64_t BLOCK_N = block_size_n(); + TORCH_CHECK(block_size_N % BLOCK_N == 0, "fp8_scaled_mm_cpu: expect block_size_N to be multiples of BLOCK_N"); + TORCH_CHECK(block_size_K == BLOCK_K, "fp8_scaled_mm_cpu: expect block_size_K equals to BLOCK_K"); + CHECK_EQ(scales2.size(0), div_up(N, block_size_N)); + CHECK_EQ(scales2.size(1), div_up(K, block_size_K)); + + const auto st = mat1.scalar_type(); + TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf, + "fp8_scaled_mm_cpu: expect A to be bfloat16 or half."); + TORCH_CHECK(st == out_dtype, + "fp8_scaled_mm_cpu: expect A has same dtype with out_dtype."); + TORCH_CHECK(mat2.scalar_type() == at::kFloat8_e4m3fn, + "fp8_scaled_mm_cpu: expect mat2 to be fp8_e4m3."); + TORCH_CHECK(scales2.scalar_type() == at::kFloat, + "fp8_scaled_mm_cpu: expect scales to be float32."); + auto out = at::empty({M, N}, mat1.options().dtype(out_dtype)); + + // strides + int64_t mat1_strideM = mat1.stride(0); + int64_t out_strideM = out.stride(0); + + const bool has_bias = bias.has_value(); + const float* bias_data = nullptr; + if (has_bias) { + CHECK_EQ(bias.value().size(0), N); + bias_data = bias.value().data_ptr(); + } + + // Btmp : [T, BLOCK_N * K] + // Ctmp : [T, BLOCK_M * BLOCK_N] + int num_threads = at::get_num_threads(); + int64_t size_per_thread = BLOCK_N * K + BLOCK_M * BLOCK_N * 2; + auto buffer = at::empty({num_threads, size_per_thread}, mat1.options()); + + AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "fp8_scaled_mm_kernel_impl", [&] { + fp8_scaled_mm_kernel_impl( + out.data_ptr(), + mat1.data_ptr(), + packed_w.data_ptr(), + scales2.data_ptr(), + bias_data, + buffer.data_ptr(), + M, + N, + K, + mat1_strideM, + out_strideM, + block_size_N, + block_size_K, + size_per_thread); + }); + + return out; +} diff --git a/csrc/cpu/sgl-kernels/gemm_int8.cpp b/csrc/cpu/sgl-kernels/gemm_int8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5a0f65a9200d4233b1ab86985c29495d4aa941bb --- /dev/null +++ b/csrc/cpu/sgl-kernels/gemm_int8.cpp @@ -0,0 +1,440 @@ +// Adapted from +// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu + +#include "common.h" +#include "vec.h" +#include "gemm.h" + +// clang-format off + +namespace { + +template +struct tinygemm_kernel_nn { + static inline void apply( + const uint8_t* __restrict__ A, const int8_t* __restrict__ B, scalar_t* __restrict__ C, + const float* __restrict__ As, const float* __restrict__ Bs, const int32_t* __restrict__ Bcomp, + const float* __restrict__ bias, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { + TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!"); + } +}; + +#if defined(CPU_CAPABILITY_AVX512) +template +struct tinygemm_kernel_nn { + static inline void apply( + const uint8_t* __restrict__ A, const int8_t* __restrict__ B, at::BFloat16* __restrict__ C, + const float* __restrict__ As, const float* __restrict__ Bs, const int32_t* __restrict__ Bcomp, + const float* __restrict__ bias, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { + + constexpr int ROWS = BLOCK_M; + constexpr int COLS = BLOCK_N / 16; + static_assert(COLS % 2 == 0); + + // prefetch distance + constexpr int PREFETCH_SIZE_K = 0; + + __m512i va; + __m512i vb[COLS]; + __m512i vc[ROWS * COLS]; + __m512i vcomp[COLS]; + __m512 vd0; + __m512 vd1[COLS]; + + // oops! 4x4 spills but luckly we use 4x2 + __m512 vbias[COLS]; + + // [NOTE]: s8s8 igemm compensation in avx512-vnni + // + // avx512-vnni has no s8s8, so we need to change s8s8 to u8s8 with compensate: + // + // a * b = (a + 128) * b - 128 * b + // s s u s u s + // + // 1) 128 * b is pre-computed when packing B to vnni formats + // 2) a + 128 is fused when dynamically quantize A + // + auto loadc = [&](auto i) { + vc[i] = _mm512_set1_epi32(0); + }; + Unroll{}(loadc); + + const int64_t K4 = K >> 2; + const int64_t lda4 = lda >> 2; + const int64_t ldb4 = ldb; // ldb * 4 >> 2; + const int32_t* a_ptr = reinterpret_cast(A); + const int32_t* b_ptr = reinterpret_cast(B); + + auto compute = [&](auto i, int64_t k) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + if constexpr (col == 0) { + va = _mm512_set1_epi32(a_ptr[row * lda4 + k]); + } + if constexpr (row == 0) { + vb[col] = _mm512_loadu_si512(b_ptr + k * ldb4 + col * 16); + if constexpr (PREFETCH_SIZE_K > 0) { + _mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb4 + col * 16, _MM_HINT_T0); + } + } + vc[i] = _mm512_dpbusd_epi32(vc[i], va, vb[col]); + }; + for (int64_t k = 0; k < K4; ++k) { + Unroll{}(compute, k); + } + + auto storec = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + // load a scale + if constexpr(col == 0) { + vd0 = _mm512_set1_ps(As[row]); + } + // load b scale and vcomp per 2 vectors + // also load bias if any + if constexpr (row == 0) { + if constexpr (col % 2 == 0) { + vd1[col + 0] = _mm512_loadu_ps(Bs + col * 16); + vd1[col + 1] = _mm512_loadu_ps(Bs + col * 16 + 16); + vcomp[col + 0] = _mm512_loadu_si512(Bcomp + col * 16); + vcomp[col + 1] = _mm512_loadu_si512(Bcomp + col * 16 + 16); + if constexpr (has_bias) { + vbias[col + 0] = _mm512_loadu_ps(bias + col * 16); + vbias[col + 1] = _mm512_loadu_ps(bias + col * 16 + 16); + } + } + } + + // for COLS = 2, 4 use 512bit store + if constexpr (col % 2 == 0) { + __m512 vc0 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc[row * COLS + col + 0], vcomp[col + 0])); + __m512 vc1 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc[row * COLS + col + 1], vcomp[col + 1])); + if constexpr (has_bias) { + vc0 = _mm512_fmadd_ps(_mm512_mul_ps(vc0, vd0), vd1[col + 0], vbias[col + 0]); + vc1 = _mm512_fmadd_ps(_mm512_mul_ps(vc1, vd0), vd1[col + 1], vbias[col + 1]); + } else { + vc0 = _mm512_mul_ps(_mm512_mul_ps(vc0, vd0), vd1[col + 0]); + vc1 = _mm512_mul_ps(_mm512_mul_ps(vc1, vd0), vd1[col + 1]); + } + + _mm512_storeu_si512( + reinterpret_cast<__m512i*>((C + row * ldc + col * 16)), + (__m512i)(_mm512_cvtne2ps_pbh(vc1, vc0))); + } + }; + Unroll{}(storec); + } +}; +#endif + +#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \ + tinygemm_kernel_nn::apply( \ + A + mb_start * lda, B + nb_start * 4, C + mb_start * ldc + nb_start, \ + As + mb_start, Bs + nb_start, Bcomp + nb_start, \ + has_bias ? bias + nb_start : nullptr, K, lda, ldb, ldc); + +template +void tinygemm_kernel( + const uint8_t* __restrict__ A, + const int8_t* __restrict__ B, + scalar_t* __restrict__ C, + int32_t* __restrict__ Ctmp, + const float* __restrict__ As, + const float* __restrict__ Bs, + const float* __restrict__ bias, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc, + bool brg) { + + // B compensation + const int32_t* Bcomp = reinterpret_cast(B + block_size_n() * K); + + // pattern: 1-4-16 + constexpr int64_t BLOCK_M = 4; + constexpr int64_t BLOCK_N = 64; + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + for (int64_t mb = 0; mb < MB; ++mb) { + int64_t mb_start = mb * BLOCK_M; + int64_t mb_size = std::min(BLOCK_M, M - mb_start); + for (int64_t nb = 0; nb < NB; ++nb) { + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(BLOCK_N, N - nb_start); + + switch(mb_size << 4 | nb_size >> 4) { + // mb_size = 1 + case 0x12: LAUNCH_TINYGEMM_KERNEL_NN(1, 32); break; + case 0x14: LAUNCH_TINYGEMM_KERNEL_NN(1, 64); break; + // mb_size = 2 + case 0x22: LAUNCH_TINYGEMM_KERNEL_NN(2, 32); break; + case 0x24: LAUNCH_TINYGEMM_KERNEL_NN(2, 64); break; + // mb_size = 3 + case 0x32: LAUNCH_TINYGEMM_KERNEL_NN(3, 32); break; + case 0x34: LAUNCH_TINYGEMM_KERNEL_NN(3, 64); break; + // mb_size = 4 + case 0x42: LAUNCH_TINYGEMM_KERNEL_NN(4, 32); break; + case 0x44: LAUNCH_TINYGEMM_KERNEL_NN(4, 64); break; + default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size"); + } + } + } +} + +template +void int8_scaled_mm_kernel_impl( + scalar_t* __restrict__ out, + const uint8_t* __restrict__ mat1, + const int8_t* __restrict__ mat2, + const float* __restrict__ scales1, + const float* __restrict__ scales2, + const float* __restrict__ bias, + int64_t M, + int64_t N, + int64_t K) { + + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n(); + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + + // TODO: brgemm u8s8 depends on PyTorch 2.7 release. + const bool use_brgemm = false; + + // K + 4 after compensation + const int64_t packed_row_size = get_row_size(K); + + AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] { + at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { + int64_t mb{0}, nb{0}; + data_index_init(begin, mb, MB, nb, NB); + + // for brgemm, use int32_t for accumulate + alignas(64) int32_t Ctmp[BLOCK_M * BLOCK_N]; + + for (int i = begin; i < end; ++i) { + UNUSED(i); + int mb_start = mb * BLOCK_M; + int mb_size = std::min(M - mb_start, BLOCK_M); + int nb_start = nb * BLOCK_N; + int nb_size = std::min(N - nb_start, BLOCK_N); + + tinygemm_kernel( + /* A */ mat1 + mb_start * K, + /* B */ mat2 + nb_start * packed_row_size /* nb * BLOCK_N * (K + 4) */, + /* C */ out + mb_start * N + nb_start, + /* Ctmp*/ Ctmp, + /* As */ scales1 + mb_start, + /* Bs */ scales2 + nb_start, + /* bias*/ bias + nb_start, + /* M */ mb_size, + /* N */ nb_size, + /* K */ K, + /* lda */ K, + /* ldb */ nb_size, + /* ldc */ N, + /* brg */ use_brgemm); + + // move to the next index + data_index_step(mb, MB, nb, NB); + } + + if (use_brgemm) { + at::native::cpublas::brgemm_release(); + } + }); + }); +} + +} // anonymous namespace + +// tinygemm interface +template +void tinygemm_kernel(const uint8_t* __restrict__ A, const int8_t* __restrict__ B, scalar_t* __restrict__ C, + int32_t* __restrict__ Ctmp, const float* __restrict__ As, const float* __restrict__ Bs, + int64_t M, int64_t N, int64_t K, int64_t lda, int64_t ldb, int64_t ldc, bool brg) { + tinygemm_kernel(A, B, C, Ctmp, As, Bs, nullptr, M, N, K, lda, ldb, ldc, brg); +} + +#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \ + template void tinygemm_kernel( \ + const uint8_t* __restrict__ A, const int8_t* __restrict__ B, TYPE* __restrict__ C, \ + int32_t* __restrict__ Ctmp, const float* __restrict__ As, const float* __restrict__ Bs, \ + int64_t M, int64_t N, int64_t K, int64_t lda, int64_t ldb, int64_t ldc, bool brg) + +INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16); +INSTANTIATE_TINYGEMM_TEMPLATE(at::Half); + +std::tuple per_token_quant_int8_cpu(at::Tensor& A) { + RECORD_FUNCTION("sgl-kernel::per_token_quant_int8_cpu", std::vector({A})); + + CHECK_LAST_DIM_CONTIGUOUS_INPUT(A); + CHECK_DIM(2, A); + + int64_t M = A.size(0); + int64_t K = A.size(1); + int64_t lda = A.stride(0); + + const auto st = A.scalar_type(); + TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf, + "per_token_quant_int8: expect A to be bfloat16 or half."); + + auto Aq = at::empty({M, K}, A.options().dtype(at::kByte)); + auto As = at::empty({M}, A.options().dtype(at::kFloat)); + + AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "per_token_quant_int8", [&] { + uint8_t* __restrict__ Aq_data = Aq.data_ptr(); + float* __restrict__ As_data = As.data_ptr(); + const scalar_t* __restrict__ A_data = A.data_ptr(); + + at::parallel_for(0, M, 0, [&] (int64_t begin, int64_t end) { + for (int64_t m = begin; m < end; ++m) { + quantize_row_int8( + Aq_data + m * K, + As_data[m], + A_data + m * lda, + K); + } + }); + }); + return std::make_tuple(Aq, As); +} + +// weight : static, per-channel, symmetric +// activation : dynamic, per-token, symmetric +// +// mat1 : [M, K] +// mat2 : [N, K] +// scales1 : [M] +// scales2 : [N] +// bias : [N] +// out : [M, N] +// +at::Tensor int8_scaled_mm_cpu(at::Tensor& mat1, at::Tensor& mat2, + at::Tensor& scales1, at::Tensor& scales2, + std::optional& bias, at::ScalarType out_dtype, bool is_vnni) { + RECORD_FUNCTION("sgl-kernel::int8_scaled_mm_cpu", std::vector({mat1, mat2, scales1, scales2, bias})); + + auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2); + + CHECK_INPUT(mat1); + CHECK_INPUT(mat2); + CHECK_INPUT(scales1); + CHECK_INPUT(scales2); + CHECK_DIM(2, mat1); + CHECK_DIM(2, mat2); + + int64_t M = mat1.size(0); + int64_t N = mat2.size(0); + int64_t K = mat1.size(1); + + // see [NOTE]: s8s8 igemm compensation in avx512-vnni + CHECK_EQ(mat2.size(1), (int64_t)(is_vnni ? K + sizeof(int32_t) : K)); + CHECK_EQ(scales1.numel(), M); + CHECK_EQ(scales2.numel(), N); + + TORCH_CHECK(mat1.scalar_type() == at::kByte, "int8_scaled_mm: expect mat1 to be uint8."); + TORCH_CHECK(mat2.scalar_type() == at::kChar, "int8_scaled_mm: expect mat2 to be int8."); + TORCH_CHECK(scales1.scalar_type() == at::kFloat && scales2.scalar_type() == at::kFloat, + "int8_scaled_mm: expect scales to be float32."); + + auto out = at::empty({M, N}, mat1.options().dtype(out_dtype)); + + const bool has_bias = bias.has_value(); + const float* bias_data = nullptr; + if (has_bias) { + CHECK_EQ(bias.value().size(0), N); + bias_data = bias.value().data_ptr(); + } + + AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "int8_scaled_mm_kernel_impl", [&] { + int8_scaled_mm_kernel_impl( + out.data_ptr(), + mat1.data_ptr(), + packed_w.data_ptr(), + scales1.data_ptr(), + scales2.data_ptr(), + bias_data, + M, + N, + K); + }); + return out; +} + +// fused `per_token_quant_int8_cpu` and `int8_scaled_mm_cpu` +at::Tensor int8_scaled_mm_with_quant(at::Tensor& mat1, at::Tensor& mat2, at::Tensor& scales2, + const std::optional& bias, at::ScalarType out_dtype, bool is_vnni) { + RECORD_FUNCTION("sgl-kernel::int8_scaled_mm_cpu", std::vector({mat1, mat2, scales2, bias})); + + auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2); + + CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1); + CHECK_INPUT(mat2); + CHECK_INPUT(scales2); + CHECK_DIM(2, mat1); + CHECK_DIM(2, mat2); + + int64_t M = mat1.size(0); + int64_t N = mat2.size(0); + int64_t K = mat1.size(1); + int64_t lda = mat1.stride(0); + + // see [NOTE]: s8s8 igemm compensation in avx512-vnni + CHECK_EQ(mat2.size(1), (int64_t)(is_vnni ? K + sizeof(int32_t) : K)); + CHECK_EQ(scales2.numel(), N); + + const auto st = mat1.scalar_type(); + TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf, + "int8_scaled_mm_with_quant: expect A to be bfloat16 or half."); + TORCH_CHECK(st == out_dtype, + "int8_scaled_mm_with_quant: expect A has same dtype with out_dtype."); + TORCH_CHECK(mat2.scalar_type() == at::kChar, + "int8_scaled_mm_with_quant: expect mat2 to be int8."); + TORCH_CHECK(scales2.scalar_type() == at::kFloat, + "int8_scaled_mm_with_quant: expect scales to be float32."); + + const int64_t buffer_size = M * K + M * sizeof(float); + auto buffer = at::empty({buffer_size}, mat1.options().dtype(at::kByte)); + auto out = at::empty({M, N}, mat1.options().dtype(out_dtype)); + + const bool has_bias = bias.has_value(); + const float* bias_data = nullptr; + if (has_bias) { + CHECK_EQ(bias.value().size(0), N); + bias_data = bias.value().data_ptr(); + } + + AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "int8_scaled_mm_with_quant_kernel_impl", [&] { + uint8_t* __restrict__ Aq_data = buffer.data_ptr(); + float* __restrict__ As_data = (float*)((void*)(Aq_data + M * K)); + const scalar_t* __restrict__ A_data = mat1.data_ptr(); + + at::parallel_for(0, M, 0, [&] (int64_t begin, int64_t end) { + for (int64_t m = begin; m < end; ++m) { + quantize_row_int8( + Aq_data + m * K, + As_data[m], + A_data + m * lda, + K); + } + }); + + int8_scaled_mm_kernel_impl( + out.data_ptr(), + Aq_data, + packed_w.data_ptr(), + As_data, + scales2.data_ptr(), + bias_data, + M, + N, + K); + }); + return out; +} diff --git a/csrc/cpu/sgl-kernels/moe.cpp b/csrc/cpu/sgl-kernels/moe.cpp new file mode 100644 index 0000000000000000000000000000000000000000..beeccff783ea04cb20717d8328841e7f91a4cada --- /dev/null +++ b/csrc/cpu/sgl-kernels/moe.cpp @@ -0,0 +1,1330 @@ +// Adapted from +// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu + +#include "common.h" +#include "vec.h" +#include "gemm.h" + +// clang-format off + +namespace { + +// [NOTE]: Fused MoE kernel with AMX +// +// This file contains implementations for +// * `moe_align_block_size` +// * `fused_moe` +// +// The functionality is identical to triton kernel, excepts: +// * fuse silu_and_mul with gemm1, therefore this kernel +// allocates 2 intermediate_caches instead of 3 +// * add `offsets` in `moe_align_block_size` which keeps track +// of starting offset for each M block. this is for keeping +// output of silu_and_mul in sorted order, thus load_A for +// the 2nd gemm would be contiguous, therefore we can directly +// load A from intermediate_cache1. +// +// TODO: +// 1. tune BLOCK_M and BLOCK_N (BLOCK_N * K fit L2) +// 2. add prefetch for load A which is indexed access +// 3. abstract at::native::cpublas::brgemm with WoQ gemm (M = 1 & M != 1) +// + +template +inline void fill_stub(scalar_t* __restrict__ out, scalar_t val, int64_t size) { + using Vec = at::vec::Vectorized; + const Vec data_vec(val); + at::vec::map([data_vec](Vec out) { return out = data_vec; }, out, out, size); +} + +template +inline void copy_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t size) { + using Vec = at::vec::Vectorized; + // no remainder + #pragma GCC unroll 4 + for (int64_t d = 0; d < size; d += Vec::size()) { + Vec data = Vec::loadu(input + d); + data.store(out + d); + } +} + +template +inline void copy_mul_stub(scalar_t* __restrict__ out, const float* __restrict__ input, float weight, int64_t size) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + const fVec weight_vec = fVec(weight); + int64_t d; + #pragma GCC unroll 4 + for (d = 0; d <= size - kVecSize; d += kVecSize) { + fVec data0 = fVec::loadu(input + d) * weight_vec; + fVec data1 = fVec::loadu(input + d + fVec::size()) * weight_vec; + bVec out_vec = convert_from_float_ext(data0, data1); + out_vec.store(out + d); + } + for (; d < size; ++d) { + out[d] = static_cast(input[d] * weight); + } +} + +// acc from [topk, K] to [K] +template +inline void sum_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t topk, int64_t K) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + if (topk == 1) { + // do copy for topk = 1 + copy_stub(out, input, K); + } else { + // do sum for topk != 1 + int64_t d; + #pragma GCC unroll 4 + for (d = 0; d <= K - kVecSize; d += kVecSize) { + fVec sum_fvec0 = fVec(0.f); + fVec sum_fvec1 = fVec(0.f); + for (int t = 0; t < topk; ++t) { + bVec x_bvec = bVec::loadu(input + t * K + d); + fVec x_fvec0, x_fvec1; + std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec); + + sum_fvec0 += x_fvec0; + sum_fvec1 += x_fvec1; + } + bVec out_bvec = convert_from_float_ext(sum_fvec0, sum_fvec1); + out_bvec.store(out + d); + } + for (; d < K; ++d) { + float sum_val = 0.f; + for (int t = 0; t < topk; ++t) { + sum_val += static_cast(input[t * K + d]); + } + out[d] = static_cast(sum_val); + } + } +} + +// out = input + input2 * scale +template +inline void add_mul_stub(scalar_t* __restrict__ out, const float* __restrict__ input, + const scalar_t* __restrict__ input2, float scale, int64_t size) { + + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + const fVec s_vec = fVec(scale); + int64_t d; + #pragma GCC unroll 4 + for (d = 0; d <= size - kVecSize; d += kVecSize) { + fVec x0 = fVec::loadu(input + d); + fVec x1 = fVec::loadu(input + d + fVec::size()); + + bVec y_bvec = bVec::loadu(input2 + d); + fVec y0, y1; + std::tie(y0, y1) = at::vec::convert_to_float(y_bvec); + + x0 = x0 + y0 * s_vec; + x1 = x1 + y1 * s_vec; + bVec out_vec = convert_from_float_ext(x0, x1); + out_vec.store(out + d); + } + for (; d < size; ++d) { + out[d] = static_cast(input[d] + float(input2[d]) * scale); + } +} + +template +int moe_align_block_size( + int32_t* __restrict__ sorted_ids, + int32_t* __restrict__ expert_ids, + int32_t* __restrict__ topk_ids, + int32_t* __restrict__ total_cnts, + int32_t* __restrict__ cumsums, + int32_t* __restrict__ offsets, + int num_experts, + int numel, + int num_threads) { + + #define T_INDEX(tt) total_cnts + (tt) * num_experts + + // accumulate count of expert ids locally + at::parallel_for(0, numel, 0, [&](int begin, int end) { + int tid = at::get_thread_num(); + int32_t* __restrict__ local_cnts = T_INDEX(tid + 1); + + for (int i = begin; i < end; ++i) { + local_cnts[topk_ids[i]]++; + } + }); + + using iVec = at::vec::Vectorized; + for (int t = 0; t < num_threads; ++t) { + at::vec::map2( + [](iVec x, iVec y) { return x + y; }, + T_INDEX(t + 1), T_INDEX(t + 1), T_INDEX(t), num_experts); + } + + // the last row holds sums of each experts + int32_t* total_cnts_t_1 = T_INDEX(num_threads); + + cumsums[0] = 0; + for (int e = 0; e < num_experts; ++e) { + // accumulate `num_tokens_post_pad`, also as the expert offset + cumsums[e + 1] = cumsums[e] + div_up(total_cnts_t_1[e], BLOCK_M) * BLOCK_M; + + for (int k = cumsums[e]; k < cumsums[e + 1]; k += BLOCK_M) { + expert_ids[k / BLOCK_M] = e; + } + } + int num_tokens_post_pad = cumsums[num_experts]; + + at::parallel_for(0, numel, 0, [&](int begin, int end) { + int tid = at::get_thread_num(); + // thread tid offsets in `total_cnts` + int32_t* __restrict__ offsets = T_INDEX(tid); + + for (int i = begin; i < end; ++i) { + int32_t expert_id = topk_ids[i]; + int32_t b_offset = cumsums[expert_id]; + int32_t t_offset = offsets[expert_id]; + sorted_ids[b_offset + t_offset] = i; + offsets[expert_id]++; + } + }); + + // debug: the offset for thread t_1 should be identical to t_2 + int32_t* total_cnts_t_2 = T_INDEX(num_threads - 1); + for (int e = 0; e < num_experts; ++e) { + TORCH_CHECK(total_cnts_t_1[e] == total_cnts_t_2[e]); + } + + // padding value for sorted_ids: numel + auto sorted_id_size = [=](const int32_t* sorted_ids_ptr) { + for (int d = 0; d < BLOCK_M; ++d) { + if (sorted_ids_ptr[d] == numel) { return d; } + } + return BLOCK_M; + }; + + // offsets holds starting offset for each valida M blocks + // shape : [num_token_blocks + 1] + offsets[0] = 0; + const int num_token_blocks = num_tokens_post_pad / BLOCK_M; + at::parallel_for(0, num_token_blocks, GRAIN_SIZE / BLOCK_M, [&](int begin, int end) { + for (int mb = begin; mb < end; ++mb) { + offsets[mb + 1] = sorted_id_size(sorted_ids + mb * BLOCK_M); + } + }); + // TODO: do we need to vecterize this ? + for (int mb = 0; mb < num_token_blocks; ++mb) { + offsets[mb + 1] += offsets[mb]; + } + // debug: the last value of offsets should be `numel` + TORCH_CHECK(offsets[num_token_blocks] == numel); + + return num_tokens_post_pad; +} + +// silu : shape leading dimension +// input0 [m_size, BLOCK_N] BLOCK_N +// input1 [m_size, BLOCK_N] BLOCK_N +// output [M * topk, N] N +template +inline void silu_and_mul( + scalar_t* __restrict__ output, + const float* __restrict__ input0, // x: x0, x1 + const float* __restrict__ input1, // y: y0, y1 + int64_t m_size, + int64_t N) { + + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + + const fVec one = fVec(1.f); + + // no remainder + for (int64_t m = 0; m < m_size; ++m) { + scalar_t* __restrict__ out = output + m * N; + const float* __restrict__ x = input0 + m * BLOCK_N; + const float* __restrict__ y = input1 + m * BLOCK_N; + + for (int64_t d = 0; d < BLOCK_N; d += bVec::size()) { + fVec x0 = fVec::loadu(x + d); + fVec x1 = fVec::loadu(x + d + fVec::size()); + fVec y0 = fVec::loadu(y + d); + fVec y1 = fVec::loadu(y + d + fVec::size()); + // silu + x0 = x0 / (one + x0.neg().exp_u20()); + x1 = x1 / (one + x1.neg().exp_u20()); + // mul + x0 = x0 * y0; + x1 = x1 * y1; + // convert + bVec out_vec = convert_from_float_ext(x0, x1); + out_vec.store(out + d); + } + } +} + +template +struct tinygemm_kernel_nn2 { + static inline void apply( + const scalar_t* __restrict__ A, const scalar_t* __restrict__ B0, const scalar_t* __restrict__ B1, + scalar_t* __restrict__ C, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { + TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!"); + } +}; + +#if defined(CPU_CAPABILITY_AVX512) +template +struct tinygemm_kernel_nn2 { + static inline void apply( + const at::BFloat16* __restrict__ A, const at::BFloat16* __restrict__ B0, const at::BFloat16* __restrict__ B1, + at::BFloat16* __restrict__ C, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { + + constexpr int ROWS = BLOCK_M; + constexpr int COLS = BLOCK_N / 16; + + static_assert(COLS % 2 == 0); + + // prefetch distance + constexpr int PREFETCH_SIZE_K = 0; + + __m512bh va; + __m512bh vb0[COLS]; + __m512bh vb1[COLS]; + __m512 vc0[ROWS * COLS]; + __m512 vc1[ROWS * COLS]; + + auto loadc = [&](auto i) { + vc0[i] = _mm512_set1_ps(0.f); + vc1[i] = _mm512_set1_ps(0.f); + }; + Unroll{}(loadc); + + const int64_t K2 = K >> 1; + const int64_t lda2 = lda >> 1; + const int64_t ldb2 = ldb; // ldb * 2 >> 1; + const float* a_ptr = reinterpret_cast(A); + const float* b0_ptr = reinterpret_cast(B0); + const float* b1_ptr = reinterpret_cast(B1); + + auto compute = [&](auto i, int64_t k) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + if constexpr (col == 0) { + va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k])); + } + if constexpr (row == 0) { + vb0[col] = (__m512bh)(_mm512_loadu_si512(b0_ptr + k * ldb2 + col * 16)); + vb1[col] = (__m512bh)(_mm512_loadu_si512(b1_ptr + k * ldb2 + col * 16)); + if constexpr (PREFETCH_SIZE_K > 0) { + _mm_prefetch(b0_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0); + _mm_prefetch(b1_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0); + } + } + vc0[i] = _mm512_dpbf16_ps(vc0[i], va, vb0[col]); + vc1[i] = _mm512_dpbf16_ps(vc1[i], va, vb1[col]); + }; + for (int64_t k = 0; k < K2; ++k) { + Unroll{}(compute, k); + } + + using Vec = at::vec::Vectorized; + const Vec one = Vec(1.f); + auto storec = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + // for COLS = 2, 4 use 512bit store + if constexpr (col % 2 == 0) { + Vec x0 = vc0[row * COLS + col + 0]; + Vec x1 = vc0[row * COLS + col + 1]; + Vec y0 = vc1[row * COLS + col + 0]; + Vec y1 = vc1[row * COLS + col + 1]; + // silu + x0 = x0 / (one + x0.neg().exp_u20()); + x1 = x1 / (one + x1.neg().exp_u20()); + // mul + x0 = x0 * y0; + x1 = x1 * y1; + + _mm512_storeu_si512( + reinterpret_cast<__m512i*>((C + row * ldc + col * 16)), + (__m512i)(_mm512_cvtne2ps_pbh(__m512(x1), __m512(x0)))); + } + }; + Unroll{}(storec); + } +}; +#endif + +#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \ + tinygemm_kernel_nn2::apply( \ + A + mb_start * lda, B0 + nb_start * 2, B1 + nb_start * 2, \ + C + mb_start * ldc + nb_start, K, lda, ldb, ldc); + +template +void tinygemm_kernel( + const scalar_t* __restrict__ A, + const scalar_t* __restrict__ B0, + const scalar_t* __restrict__ B1, + scalar_t* __restrict__ C, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc) { + + // pattern: 1-(2+2)-(8+8) + constexpr int64_t BLOCK_M = 4; + constexpr int64_t BLOCK_N = 32; + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + for (int mb = 0; mb < MB; ++mb) { + int64_t mb_start = mb * BLOCK_M; + int64_t mb_size = std::min(BLOCK_M, M - mb_start); + for (int64_t nb = 0; nb < NB; ++nb) { + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(BLOCK_N, N - nb_start); + + switch(mb_size << 4 | nb_size >> 4) { + // mb_size = 1 + case 0x12: LAUNCH_TINYGEMM_KERNEL_NN(1, 32); break; + // mb_size = 2 + case 0x22: LAUNCH_TINYGEMM_KERNEL_NN(2, 32); break; + // mb_size = 3 + case 0x32: LAUNCH_TINYGEMM_KERNEL_NN(3, 32); break; + // mb_size = 4 + case 0x42: LAUNCH_TINYGEMM_KERNEL_NN(4, 32); break; + default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size"); + } + } + } +} + +template +struct tinygemm_kernel_nn { + static inline void apply( + const scalar_t* __restrict__ A, const scalar_t* __restrict__ B, float* __restrict__ C, + int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { + TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!"); + } +}; + +#if defined(CPU_CAPABILITY_AVX512) +template +struct tinygemm_kernel_nn { + static inline void apply( + const at::BFloat16* __restrict__ A, const at::BFloat16* __restrict__ B, float* __restrict__ C, + int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { + + constexpr int ROWS = BLOCK_M; + constexpr int COLS = BLOCK_N / 16; + + static_assert(COLS % 2 == 0); + + // prefetch distance + constexpr int PREFETCH_SIZE_K = 0; + + __m512bh va; + __m512bh vb[COLS]; + __m512 vc[ROWS * COLS]; + + auto loadc = [&](auto i) { + vc[i] = _mm512_set1_ps(0.f); + }; + Unroll{}(loadc); + + const int64_t K2 = K >> 1; + const int64_t lda2 = lda >> 1; + const int64_t ldb2 = ldb; // ldb * 2 >> 1; + const float* a_ptr = reinterpret_cast(A); + const float* b_ptr = reinterpret_cast(B); + + auto compute = [&](auto i, int64_t k) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + if constexpr (col == 0) { + va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k])); + } + if constexpr (row == 0) { + vb[col] = (__m512bh)(_mm512_loadu_si512(b_ptr + k * ldb2 + col * 16)); + if constexpr (PREFETCH_SIZE_K > 0) { + _mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0); + } + } + vc[i] = _mm512_dpbf16_ps(vc[i], va, vb[col]); + }; + for (int64_t k = 0; k < K2; ++k) { + Unroll{}(compute, k); + } + + auto storec = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + _mm512_storeu_ps(reinterpret_cast<__m512*>(C + row * ldc + col * 16), vc[i]); + + }; + Unroll{}(storec); + } +}; +#endif + +#define LAUNCH_TINYGEMM_KERNEL_NN2(MB_SIZE, NB_SIZE) \ + tinygemm_kernel_nn::apply( \ + A + mb_start * lda, B + nb_start * 2, C + mb_start * ldc + nb_start, \ + K, lda, ldb, ldc); + +template +void tinygemm_kernel( + const scalar_t* __restrict__ A, + const scalar_t* __restrict__ B, + float* __restrict__ C, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc) { + + // pattern: 1-2-8 + constexpr int64_t BLOCK_M = 4; + constexpr int64_t BLOCK_N = 32; + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + for (int mb = 0; mb < MB; ++mb) { + int64_t mb_start = mb * BLOCK_M; + int64_t mb_size = std::min(BLOCK_M, M - mb_start); + for (int64_t nb = 0; nb < NB; ++nb) { + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(BLOCK_N, N - nb_start); + + switch(mb_size << 4 | nb_size >> 4) { + // mb_size = 1 + case 0x12: LAUNCH_TINYGEMM_KERNEL_NN2(1, 32); break; + // mb_size = 2 + case 0x22: LAUNCH_TINYGEMM_KERNEL_NN2(2, 32); break; + // mb_size = 3 + case 0x32: LAUNCH_TINYGEMM_KERNEL_NN2(3, 32); break; + // mb_size = 4 + case 0x42: LAUNCH_TINYGEMM_KERNEL_NN2(4, 32); break; + default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size"); + } + } + } +} + +template +void fused_experts_kernel_impl( + scalar_t* __restrict__ output, + scalar_t* __restrict__ ic1, + scalar_t* __restrict__ ic2, + scalar_t* __restrict__ A_tmp, + float* __restrict__ C_tmp, + const scalar_t* __restrict__ input, + const scalar_t* __restrict__ packed_w1, + const scalar_t* __restrict__ packed_w2, + const float* __restrict__ topk_weights, + const int32_t* __restrict__ sorted_ids, + const int32_t* __restrict__ expert_ids, + const int32_t* __restrict__ offsets, + int64_t M, + int64_t N, + int64_t K, + int64_t E, + int64_t topk, + int64_t num_tokens_post_pad) { + + // handle 2 tiles per block + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n(); + + // stage 1: intermediate_cache1 = silu(hidden_states @ w1) + const int64_t MB = div_up(num_tokens_post_pad, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + + // strides for w1: [E, 2N, K] + TORCH_CHECK(N % BLOCK_N == 0, "Fixme when N is not multiples of ", BLOCK_N); + + const int64_t stride_e = 2 * N * K; + const int64_t stride_n = K; + + // here we only parallel on half of 2N to fuse silu_and_mul with gemm + at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { + // get local pointers + int tid = at::get_thread_num(); + scalar_t* __restrict__ A = A_tmp + tid * BLOCK_M * K; + float* __restrict__ C0 = C_tmp + tid * 2 * BLOCK_M * BLOCK_N; + float* __restrict__ C1 = C0 + BLOCK_M * BLOCK_N; + + bool is_brgemm_used = false; + + for (int64_t i = begin; i < end; ++i) { + int64_t mb = i / NB; + int64_t nb = i % NB; + + // nb0 from top half and nb1 from bottom half + int64_t nb0 = nb, nb1 = nb + NB; + int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N); + + // B shape [K, n_size] in vnni format + int32_t expert_id = expert_ids[mb]; + const scalar_t* __restrict__ B0 = packed_w1 + expert_id * stride_e + nb0 * BLOCK_N * stride_n; + const scalar_t* __restrict__ B1 = packed_w1 + expert_id * stride_e + nb1 * BLOCK_N * stride_n; + + // 1.a load A + const int32_t* A_ids = sorted_ids + mb * BLOCK_M; + int64_t m_size = offsets[mb + 1] - offsets[mb]; + + const bool use_brgemm = can_use_brgemm(m_size); + is_brgemm_used = is_brgemm_used || use_brgemm; + + for (int64_t m = 0; m < m_size; ++m) { + int32_t index = A_ids[m] / topk; + copy_stub(A + m * K, input + index * K, K); + } + + if (use_brgemm) { + // 1.b gemm: C0 = A @ B0 + at::native::cpublas::brgemm( + /* M */ m_size, + /* N */ n_size, + /* K */ K, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ BLOCK_N, + /* add_C */ false, + /* A */ A, + /* B */ B0, + /* C */ C0); + + // 1.c gemm: C1 = A @ B1 + at::native::cpublas::brgemm( + /* M */ m_size, + /* N */ n_size, + /* K */ K, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ BLOCK_N, + /* add_C */ false, + /* A */ A, + /* B */ B1, + /* C */ C1); + + // 1.d silu and mul + const int64_t offset = offsets[mb]; + silu_and_mul( + ic1 + offset * N + nb * BLOCK_N, + C0, + C1, + m_size, + N); + } else { + // fused 1.bcd: silu_and_mul(A @ B0, A @ B1) + const int64_t offset = offsets[mb]; + tinygemm_kernel( + /* A */ A, + /* B0 */ B0, + /* B1 */ B1, + /* C */ ic1 + offset * N + nb * BLOCK_N, + /* M */ m_size, + /* N */ n_size, + /* K */ K, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ N); + } + } + + if (is_brgemm_used) { + at::native::cpublas::brgemm_release(); + } + }); + + // stage 2: intermediate_cache2 = intermediate_cache1 @ w2 + // w2 : [E, K, N] as [E, OC, IC] + const int64_t OC = K; // rename K as OC + const int64_t IC = N; // rename N as IC + const int64_t MB2 = MB; + const int64_t NB2 = div_up(OC, BLOCK_N); + const int64_t stride_e2 = OC * IC; + const int64_t stride_oc = IC; + + // parallel on [MB2, NB2] + at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) { + // get local pointers + int tid = at::get_thread_num(); + // we won't be using C1 for gemm2 + float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N; + + bool is_brgemm_used = false; + + for (int64_t i = begin; i < end; ++i) { + int64_t mb = i / NB2; + int64_t nb = i % NB2; + + int64_t m_size = offsets[mb + 1] - offsets[mb]; + int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N); + + const bool use_brgemm = can_use_brgemm(m_size); + is_brgemm_used = is_brgemm_used || use_brgemm; + + // A ptr from ic1 of [M * topk, N] in sorted order + // so as to avoid copy A to tmp buffer again + const scalar_t* __restrict__ A = ic1 + offsets[mb] * N; + const int32_t* A_ids = sorted_ids + mb * BLOCK_M; + + // B shape [IC, n_size] in vnni format + int32_t expert_id = expert_ids[mb]; + const scalar_t* __restrict__ B = packed_w2 + expert_id * stride_e2 + nb * BLOCK_N * stride_oc; + + // 2.a gemm: C = A @ B + if (use_brgemm) { + at::native::cpublas::brgemm( + /* M */ m_size, + /* N */ n_size, + /* K */ IC, + /* lda */ IC, + /* ldb */ n_size, + /* ldc */ BLOCK_N, + /* add_C */ false, + /* A */ A, + /* B */ B, + /* C */ C); + } else { + tinygemm_kernel( + /* A */ A, + /* B */ B, + /* C */ C, + /* M */ m_size, + /* N */ n_size, + /* K */ IC, + /* lda */ IC, + /* ldb */ n_size, + /* ldc */ BLOCK_N); + } + + // 2.b copy from C to ic2 in original order + // and also mul topk_weights in float32 + for (int64_t m = 0; m < m_size; ++m) { + int32_t index = A_ids[m]; + float weight = topk_weights[index]; + copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size); + } + } + + if (is_brgemm_used) { + at::native::cpublas::brgemm_release(); + } + }); + + // stage 3: out = intermediate_cache2.sum(dim=1) + // from [M, topk, K] to [M, K] + at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) { + for (int64_t m = begin; m < end; ++m) { + sum_stub(output + m * K, ic2 + m * topk * K, topk, K); + } + }); +} + +template +void shared_expert_kernel_impl( + scalar_t* __restrict__ output, + scalar_t* __restrict__ ic1, + float* __restrict__ C_tmp, + scalar_t* __restrict__ input, + const scalar_t* __restrict__ packed_w1, + const scalar_t* __restrict__ packed_w2, + const scalar_t* __restrict__ fused_experts_out, + float routed_scaling_factor, + int64_t M, + int64_t N, + int64_t K) { + + // handle 2 tiles per block + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n(); + + // stage 1: intermediate_cache1 = silu(hidden_states @ w1) + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + + TORCH_CHECK(N % BLOCK_N == 0, "Fixme when N is not multiples of ", BLOCK_N); + const int64_t stride_n = K; + + // here we only parallel on half of 2N to fuse silu_and_mul with gemm + at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { + // get local pointers + int tid = at::get_thread_num(); + float* __restrict__ C0 = C_tmp + tid * 2 * BLOCK_M * BLOCK_N; + float* __restrict__ C1 = C0 + BLOCK_M * BLOCK_N; + + bool is_brgemm_used = false; + + for (int64_t i = begin; i < end; ++i) { + int64_t mb = i / NB; + int64_t nb = i % NB; + + // nb0 from top half and nb1 from bottom half + int64_t nb0 = nb, nb1 = nb + NB; + int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N); + int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M); + + //int64_t mb_start = mb * BLOCK_M; + //int64_t mb_size = std::min(M - mb_start, BLOCK_M); + + // A shape [m_size, K] + const scalar_t* A = input + mb * BLOCK_M * K; + + // B shape [K, n_size] in vnni format + const scalar_t* __restrict__ B0 = packed_w1 + nb0 * BLOCK_N * stride_n; + const scalar_t* __restrict__ B1 = packed_w1 + nb1 * BLOCK_N * stride_n; + + const bool use_brgemm = can_use_brgemm(m_size); + is_brgemm_used = is_brgemm_used || use_brgemm; + + if (use_brgemm) { + // 1.b gemm: C0 = A @ B0 + at::native::cpublas::brgemm( + /* M */ m_size, + /* N */ n_size, + /* K */ K, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ BLOCK_N, + /* add_C */ false, + /* A */ A, + /* B */ B0, + /* C */ C0); + + // 1.c gemm: C1 = A @ B1 + at::native::cpublas::brgemm( + /* M */ m_size, + /* N */ n_size, + /* K */ K, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ BLOCK_N, + /* add_C */ false, + /* A */ A, + /* B */ B1, + /* C */ C1); + + // 1.d silu and mul + silu_and_mul( + ic1 + mb * BLOCK_M * N + nb * BLOCK_N, + C0, + C1, + m_size, + N); + } else { + // fused 1.bcd: silu_and_mul(A @ B0, A @ B1) + tinygemm_kernel( + /* A */ A, + /* B0 */ B0, + /* B1 */ B1, + /* C */ ic1 + mb * BLOCK_M * N + nb * BLOCK_N, + /* M */ m_size, + /* N */ n_size, + /* K */ K, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ N); + } + } + + if (is_brgemm_used) { + at::native::cpublas::brgemm_release(); + } + }); + + // stage 2: output = intermediate_cache1 @ w2 + // w2 : [K, N] as [OC, IC] + const int64_t OC = K; // rename K as OC + const int64_t IC = N; // rename N as IC + const int64_t MB2 = MB; + const int64_t NB2 = div_up(OC, BLOCK_N); + const int64_t stride_oc = IC; + + // parallel on [MB2, NB2] + at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) { + // get local pointers + int tid = at::get_thread_num(); + // we won't be using C1 for gemm2 + float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N; + + bool is_brgemm_used = false; + + for (int64_t i = begin; i < end; ++i) { + int64_t mb = i / NB2; + int64_t nb = i % NB2; + + int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M); + int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N); + + const bool use_brgemm = can_use_brgemm(m_size); + is_brgemm_used = is_brgemm_used || use_brgemm; + + // A shape [m_size, IC] + const scalar_t* __restrict__ A = ic1 + mb * BLOCK_M * N; + + // B shape [IC, n_size] in vnni format + const scalar_t* __restrict__ B = packed_w2 + nb * BLOCK_N * stride_oc; + + // 2.a gemm: C = A @ B + if (use_brgemm) { + at::native::cpublas::brgemm( + /* M */ m_size, + /* N */ n_size, + /* K */ IC, + /* lda */ IC, + /* ldb */ n_size, + /* ldc */ BLOCK_N, + /* add_C */ false, + /* A */ A, + /* B */ B, + /* C */ C); + } else { + tinygemm_kernel( + /* A */ A, + /* B */ B, + /* C */ C, + /* M */ m_size, + /* N */ n_size, + /* K */ IC, + /* lda */ IC, + /* ldb */ n_size, + /* ldc */ BLOCK_N); + } + + // 2.b copy from C to output and add fused_experts_out + scalar_t* __restrict__ out = output + mb * BLOCK_M * K + nb * BLOCK_N; + const scalar_t* __restrict__ fused_out = fused_experts_out + mb * BLOCK_M * K + nb * BLOCK_N; + for (int64_t m = 0; m < m_size; ++m) { + add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size); + } + } + + if (is_brgemm_used) { + at::native::cpublas::brgemm_release(); + } + }); +} + +} // anonymous namespace + +// common checks +static inline void check_moe_scales( + bool use_int8_w8a8, + bool use_fp8_w8a16, + const std::optional& w1_scale, + const std::optional& w2_scale, + const std::optional> block_size, + const std::optional& a1_scale, + const std::optional& a2_scale) { + if (use_int8_w8a8) { + TORCH_CHECK(w1_scale.has_value(), "missing w1_scale for int8 w8a8."); + TORCH_CHECK(w2_scale.has_value(), "missing w2_scale for int8 w8a8."); + TORCH_CHECK(!a1_scale.has_value(), "static quantization for activation not supported."); + TORCH_CHECK(!a2_scale.has_value(), "static quantization for activation not supported."); + } + if (use_fp8_w8a16) { + TORCH_CHECK(w1_scale.has_value(), "missing w1_scale for fp8 w8a16."); + TORCH_CHECK(w2_scale.has_value(), "missing w2_scale for fp8 w8a16."); + TORCH_CHECK(block_size.has_value(), "missing block_size for fp8 w8a16."); + TORCH_CHECK(block_size.value().size() == 2, "expect block_size.size() to be 2."); + } +} + +#define CHECK_MOE_SCALES_FP8(DIM0, DIM1) \ + auto w1s = w1_scale.value(); \ + auto w2s = w2_scale.value(); \ + auto block_size_val = block_size.value(); \ + int64_t block_size_N = block_size_val[0]; \ + int64_t block_size_K = block_size_val[1]; \ + TORCH_CHECK(w1s.size(DIM0) == 2 * N / block_size_N); \ + TORCH_CHECK(w1s.size(DIM1) == K / block_size_K); \ + TORCH_CHECK(w2s.size(DIM0) == K / block_size_N); \ + TORCH_CHECK(w2s.size(DIM1) == N / block_size_K) + +// hidden_states: [M, K] +// w1: [E, 2N, K] +// w2: [E, K, N] +// topk_weights: [M, topk] +// topk_ids: [M, topk] (int32_t) +// +at::Tensor fused_experts_cpu( + at::Tensor& hidden_states, + at::Tensor& w1, + at::Tensor& w2, + at::Tensor& topk_weights, + at::Tensor& topk_ids, + bool inplace, + bool use_int8_w8a8, + bool use_fp8_w8a16, + const std::optional& w1_scale, + const std::optional& w2_scale, + const std::optional> block_size, + const std::optional& a1_scale, + const std::optional& a2_scale, + bool is_vnni) { + RECORD_FUNCTION("sgl-kernel::fused_experts_cpu", std::vector({hidden_states, w1, w2, topk_weights, topk_ids})); + + auto packed_w1 = is_vnni ? w1 : convert_weight_packed(w1); + auto packed_w2 = is_vnni ? w2 : convert_weight_packed(w2); + + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n(); + + const auto st = hidden_states.scalar_type(); + CHECK_INPUT(hidden_states); + CHECK_INPUT(w1); + CHECK_INPUT(w2); + CHECK_EQ(topk_weights.sizes(), topk_ids.sizes()); + CHECK_DIM(2, hidden_states); + CHECK_DIM(3, w1); + CHECK_DIM(3, w2); + CHECK_DIM(2, topk_weights); + CHECK_DIM(2, topk_ids); + + CHECK_EQ(topk_ids.scalar_type(), at::kInt); + CHECK_EQ(topk_weights.scalar_type(), at::kFloat); + + int64_t M = hidden_states.size(0); + int64_t K = hidden_states.size(1); + int64_t N = w1.size(1) / 2; + int64_t E = w1.size(0); + int64_t topk = topk_weights.size(1); + + // we use int32_t compensation for int8 w8a8 + int64_t packed_K = get_row_size(K, use_int8_w8a8); + int64_t packed_N = get_row_size(N, use_int8_w8a8); + + // check weight shapes + CHECK_EQ(w2.size(0), E); + CHECK_EQ(w2.size(1), K); + CHECK_EQ(packed_w1.size(2), packed_K); + CHECK_EQ(packed_w2.size(2), packed_N); + + // check scales + check_moe_scales(use_int8_w8a8, use_fp8_w8a16, w1_scale, w2_scale, block_size, a1_scale, a2_scale); + + at::Tensor out_hidden_states = inplace ? hidden_states : at::empty_like(hidden_states); + + // NB: worst case is each expert holds a block with remainder of 1 + // 1. sorted_ids : [M * topk + E * (BLOCK_M - 1)] + // 2. expert_ids : [max_num_blocks] + // 3. total_cnts : [T + 1, E] + // 4. cumsums : [E + 1] + // 5. offsets : [max_num_blocks + 1] + // + int num_threads = at::get_num_threads(); + int64_t max_num_tokens_padded = M * topk + E * (BLOCK_M - 1); + int64_t max_num_blocks = div_up(max_num_tokens_padded, BLOCK_M); + auto buffer = at::empty( + {max_num_tokens_padded + max_num_blocks + (num_threads + 1) * E + (E + 1) + (max_num_blocks + 1)}, + topk_ids.options()); + + int32_t* __restrict__ sorted_ids = buffer.data_ptr(); + int32_t* __restrict__ expert_ids = sorted_ids + max_num_tokens_padded; + int32_t* __restrict__ total_cnts = expert_ids + max_num_blocks; + int32_t* __restrict__ cumsums = total_cnts + (num_threads + 1) * E; + int32_t* __restrict__ offsets = cumsums + (E + 1); + + // init sorted_ids with `numel` as the padding number + // init expert_ids with `num_experts` + int64_t numel = M * topk; + at::parallel_for(0, max_num_blocks, GRAIN_SIZE / BLOCK_M, [&](int64_t begin, int64_t end) { + int64_t m_start = begin * BLOCK_M; + int64_t m_size = std::min((end - begin) * BLOCK_M, max_num_tokens_padded - m_start); + fill_stub(sorted_ids + m_start, (int32_t)numel, m_size); + fill_stub(expert_ids + begin, (int32_t)E, end - begin); + }); + // zero total_cnts and cumsums + at::parallel_for(0, (num_threads + 1) * E + (E + 1), GRAIN_SIZE, [&](int64_t begin, int64_t end) { + fill_stub(total_cnts + begin, 0, end - begin); + }); + + // align experts index + int64_t num_tokens_post_pad = moe_align_block_size( + sorted_ids, expert_ids, topk_ids.data_ptr(), total_cnts, cumsums, offsets, E, numel, num_threads); + + // unlike triton kernel, we fuse silu with gemm1 so only need 2 intermediate_caches: + // 1. intermediate_cache1 : [M * topk, N] + // 2. intermediate_cache2 : [M * topk, K] + // 3. A_tmp : [T, BLOCK_M * K] + // 4. C_tmp : [T, 2 * BLOCK_M * BLOCK_N] + // + // for int8 w8a8: + // 5. Aq_tmp : [M, K] or [M * topk, N] + // 6. As_tmp : [M * topk] + // + // for fp8 w8a16: + // 7. intermediate_cache0 : [M * topk, 2N] + // 8. B_tmp : [T, BLOCK_N, std::max(K, N)] + // + int64_t buffer_size_nbytes = M * topk * N * 2 + M * topk * K * 2 + + num_threads * BLOCK_M * K * (use_int8_w8a8 ? 1 : 2) + + num_threads * 2 * BLOCK_M * BLOCK_N * sizeof(float); + + if (use_int8_w8a8) { + buffer_size_nbytes += std::max(M * K, M * topk * N) + M * topk * sizeof(float); + } + if (use_fp8_w8a16) { + buffer_size_nbytes += M * topk * 2 * N * 2 + num_threads * BLOCK_N * std::max(K, N) * 2; + } + + auto buffer2 = at::empty({buffer_size_nbytes}, hidden_states.options().dtype(at::kChar)); + + AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "fused_experts_kernel_impl", [&] { + scalar_t* __restrict__ intermediate_cache1 = (scalar_t*)((void*)(buffer2.data_ptr())); + scalar_t* __restrict__ intermediate_cache2 = intermediate_cache1 + M * topk * N; + + if (use_int8_w8a8) { + uint8_t* __restrict__ A_tmp = (uint8_t*)((void*)(intermediate_cache2 + M * topk * K)); + float* __restrict__ C_tmp = (float*)((void*)(A_tmp + num_threads * BLOCK_M * K)); + uint8_t* __restrict__ Aq_tmp = (uint8_t*)((void*)(C_tmp + num_threads * 2 * BLOCK_M * BLOCK_N)); + float* __restrict__ As_tmp = (float*)((void*)(Aq_tmp + std::max(M * K, M * topk * N))); + + auto w1s = w1_scale.value(); + auto w2s = w2_scale.value(); + TORCH_CHECK(w1s.numel() == E * 2 * N); + TORCH_CHECK(w2s.numel() == E * K); + + fused_experts_int8_kernel_impl( + out_hidden_states.data_ptr(), + intermediate_cache1, + intermediate_cache2, + A_tmp, + C_tmp, + Aq_tmp, + As_tmp, + hidden_states.data_ptr(), + packed_w1.data_ptr(), + packed_w2.data_ptr(), + w1s.data_ptr(), + w2s.data_ptr(), + topk_weights.data_ptr(), + sorted_ids, + expert_ids, + offsets, + M, + N, + K, + E, + topk, + num_tokens_post_pad); + } else if (use_fp8_w8a16) { + // here we just ignore C_tmp as it is not used + scalar_t* __restrict__ A_tmp = (scalar_t*)((void*)(intermediate_cache2 + M * topk * K)); + float* __restrict__ C_tmp = (float*)((void*)(A_tmp + num_threads * BLOCK_M * K)); + scalar_t* __restrict__ intermediate_cache0 = (scalar_t*)((void*)(C_tmp + num_threads * 2 * BLOCK_M * BLOCK_N)); + scalar_t* __restrict__ B_tmp = (scalar_t*)((void*)(intermediate_cache0 + M * topk * 2 * N)); + + CHECK_MOE_SCALES_FP8(1, 2); + fused_experts_fp8_kernel_impl( + out_hidden_states.data_ptr(), + intermediate_cache0, + intermediate_cache1, + intermediate_cache2, + A_tmp, + B_tmp, + C_tmp, + hidden_states.data_ptr(), + packed_w1.data_ptr(), + packed_w2.data_ptr(), + w1s.data_ptr(), + w2s.data_ptr(), + block_size_N, + block_size_K, + topk_weights.data_ptr(), + sorted_ids, + expert_ids, + offsets, + M, + N, + K, + E, + topk, + num_tokens_post_pad); + } else { + scalar_t* __restrict__ A_tmp = intermediate_cache2 + M * topk * K; + float* __restrict__ C_tmp = (float*)((void*)(A_tmp + num_threads * BLOCK_M * K)); + + fused_experts_kernel_impl( + out_hidden_states.data_ptr(), + intermediate_cache1, + intermediate_cache2, + A_tmp, + C_tmp, + hidden_states.data_ptr(), + packed_w1.data_ptr(), + packed_w2.data_ptr(), + topk_weights.data_ptr(), + sorted_ids, + expert_ids, + offsets, + M, + N, + K, + E, + topk, + num_tokens_post_pad); + } + }); + return out_hidden_states; +} + +// shared expert kernel +// +// hidden_states: [M, K] +// w1: [2N, K] +// w2: [K, N] +// fused_experts_out +at::Tensor shared_expert_cpu( + at::Tensor& hidden_states, + at::Tensor& w1, + at::Tensor& w2, + at::Tensor& fused_experts_out, + double routed_scaling_factor, + bool inplace, + bool use_int8_w8a8, + bool use_fp8_w8a16, + std::optional& w1_scale, + std::optional& w2_scale, + std::optional> block_size, + std::optional& a1_scale, + std::optional& a2_scale, + bool is_vnni) { + RECORD_FUNCTION("sgl-kernel::shared_expert_cpu", std::vector({hidden_states, w1, w2})); + + auto packed_w1 = is_vnni ? w1 : convert_weight_packed(w1); + auto packed_w2 = is_vnni ? w2 : convert_weight_packed(w2); + + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n(); + + const auto st = hidden_states.scalar_type(); + CHECK_INPUT(hidden_states); + CHECK_INPUT(fused_experts_out); + CHECK_INPUT(w1); + CHECK_INPUT(w2); + CHECK_DIM(2, hidden_states); + CHECK_DIM(2, w1); + CHECK_DIM(2, w2); + CHECK_EQ(hidden_states.sizes(), fused_experts_out.sizes()); + CHECK_EQ(hidden_states.scalar_type(), st); + + int64_t M = hidden_states.size(0); + int64_t K = hidden_states.size(1); + int64_t N = w1.size(0) / 2; + + // we use int32_t compensation for int8 w8a8 + int64_t packed_K = get_row_size(K, use_int8_w8a8); + int64_t packed_N = get_row_size(N, use_int8_w8a8); + + // check weight shapes + CHECK_EQ(w2.size(0), K); + CHECK_EQ(packed_w1.size(1), packed_K); + CHECK_EQ(packed_w2.size(1), packed_N); + + // check scales + check_moe_scales(use_int8_w8a8, use_fp8_w8a16, w1_scale, w2_scale, block_size, a1_scale, a2_scale); + + at::Tensor out_hidden_states = inplace ? hidden_states : at::empty_like(hidden_states); + + // unlike triton kernel, we fuse silu with gemm1 so only need 2 intermediate_caches: + // 1. intermediate_cache1 : [M, N] + // 2. C_tmp : [T, 2 * BLOCK_M * BLOCK_N] + // + // for int8 w8a8: + // 3. Aq_tmp : [M, K] or [M, N] + // 4. As_tmp : [M] + // + // for fp8 w8a16: + // 5. intermediate_cache0 : [M, 2N] + // 6. B_tmp: [T, BLOCK_M, max(K, N)] + // + int num_threads = at::get_num_threads(); + int64_t buffer_size_nbytes = M * N * 2 + num_threads * 2 * BLOCK_M * BLOCK_N * sizeof(float); + + if (use_int8_w8a8) { + buffer_size_nbytes += std::max(M * K, M * N) + M * sizeof(float); + } + if (use_fp8_w8a16) { + buffer_size_nbytes += M * 2 * N * 2 + num_threads * BLOCK_M * std::max(K, N) * 2; + } + + auto buffer = at::empty({buffer_size_nbytes}, hidden_states.options().dtype(at::kChar)); + AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "share_experts_kernel_impl", [&] { + scalar_t* __restrict__ intermediate_cache1 = (scalar_t*)((void*)(buffer.data_ptr())); + float* __restrict__ C_tmp = (float*)((void*)(intermediate_cache1 + M * N)); + + if (use_int8_w8a8) { + uint8_t* __restrict__ Aq_tmp = (uint8_t*)((void*)(C_tmp + num_threads * 2 * BLOCK_M * BLOCK_N)); + float* __restrict__ As_tmp = (float*)((void*)(Aq_tmp + std::max(M * K, M * N))); + + auto w1s = w1_scale.value(); + auto w2s = w2_scale.value(); + TORCH_CHECK(w1s.numel() == 2 * N); + TORCH_CHECK(w2s.numel() == K); + + shared_expert_int8_kernel_impl( + out_hidden_states.data_ptr(), + intermediate_cache1, + C_tmp, + Aq_tmp, + As_tmp, + hidden_states.data_ptr(), + packed_w1.data_ptr(), + packed_w2.data_ptr(), + w1s.data_ptr(), + w2s.data_ptr(), + fused_experts_out.data_ptr(), + routed_scaling_factor, + M, + N, + K); + } else if (use_fp8_w8a16) { + scalar_t* __restrict__ intermediate_cache0 = (scalar_t*)((void*)(C_tmp + num_threads * 2 * BLOCK_M * BLOCK_N)); + scalar_t* __restrict__ B_tmp = (scalar_t*)((void*)(intermediate_cache0 + M * 2 * N)); + + CHECK_MOE_SCALES_FP8(0, 1); + shared_expert_fp8_kernel_impl( + out_hidden_states.data_ptr(), + intermediate_cache0, + intermediate_cache1, + B_tmp, + C_tmp, + hidden_states.data_ptr(), + packed_w1.data_ptr(), + packed_w2.data_ptr(), + w1s.data_ptr(), + w2s.data_ptr(), + block_size_N, + block_size_K, + fused_experts_out.data_ptr(), + routed_scaling_factor, + M, + N, + K); + } else { + shared_expert_kernel_impl( + out_hidden_states.data_ptr(), + intermediate_cache1, + C_tmp, + hidden_states.data_ptr(), + packed_w1.data_ptr(), + packed_w2.data_ptr(), + fused_experts_out.data_ptr(), + routed_scaling_factor, + M, + N, + K); + } + }); + return out_hidden_states; +} diff --git a/csrc/cpu/sgl-kernels/moe_fp8.cpp b/csrc/cpu/sgl-kernels/moe_fp8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..84a6af267740a7e3a40691a0c3566d7eaf0c7f1b --- /dev/null +++ b/csrc/cpu/sgl-kernels/moe_fp8.cpp @@ -0,0 +1,502 @@ +// Adapted from +// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu + +#include "common.h" +#include "gemm.h" +#include "vec.h" + +// clang-format off + +namespace { + +template +inline void copy_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t size) { + using Vec = at::vec::Vectorized; + // no remainder + #pragma GCC unroll 4 + for (int64_t d = 0; d < size; d += Vec::size()) { + Vec data = Vec::loadu(input + d); + data.store(out + d); + } +} + +template +inline void copy_mul_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, float weight, int64_t size) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + const fVec weight_vec = fVec(weight); + int64_t d; + #pragma GCC unroll 4 + for (d = 0; d <= size - kVecSize; d += kVecSize) { + bVec x = bVec::loadu(input + d); + fVec x0, x1; + std::tie(x0, x1) = at::vec::convert_to_float(x); + x0 = x0 * weight_vec; + x1 = x1 * weight_vec; + bVec out_vec = convert_from_float_ext(x0, x1); + out_vec.store(out + d); + } + for (; d < size; ++d) { + out[d] = static_cast(input[d] * weight); + } +} + +// acc from [topk, K] to [K] +template +inline void sum_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t topk, int64_t K) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + if (topk == 1) { + // do copy for topk = 1 + copy_stub(out, input, K); + } else { + // do sum for topk != 1 + int64_t d; + #pragma GCC unroll 4 + for (d = 0; d <= K - kVecSize; d += kVecSize) { + fVec sum_fvec0 = fVec(0.f); + fVec sum_fvec1 = fVec(0.f); + for (int t = 0; t < topk; ++t) { + bVec x_bvec = bVec::loadu(input + t * K + d); + fVec x_fvec0, x_fvec1; + std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec); + + sum_fvec0 += x_fvec0; + sum_fvec1 += x_fvec1; + } + bVec out_bvec = convert_from_float_ext(sum_fvec0, sum_fvec1); + out_bvec.store(out + d); + } + for (; d < K; ++d) { + float sum_val = 0.f; + for (int t = 0; t < topk; ++t) { + sum_val += static_cast(input[t * K + d]); + } + out[d] = static_cast(sum_val); + } + } +} + +// out = input + input2 * scale +template +inline void add_mul_stub( + scalar_t* __restrict__ out, + const scalar_t* __restrict__ input, + const scalar_t* __restrict__ input2, + float scale, + int64_t size) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + const fVec s_vec = fVec(scale); + + int64_t d; +#pragma GCC unroll 4 + for (d = 0; d <= size - kVecSize; d += kVecSize) { + bVec x_bvec = bVec::loadu(input + d); + fVec x0, x1; + std::tie(x0, x1) = at::vec::convert_to_float(x_bvec); + + bVec y_bvec = bVec::loadu(input2 + d); + fVec y0, y1; + std::tie(y0, y1) = at::vec::convert_to_float(y_bvec); + + x0 = x0 + y0 * s_vec; + x1 = x1 + y1 * s_vec; + bVec out_vec = convert_from_float_ext(x0, x1); + out_vec.store(out + d); + } + for (; d < size; ++d) { + out[d] = static_cast(input[d] + float(input2[d]) * scale); + } +} + +template +inline void silu_and_mul_stub( + scalar_t* __restrict__ out, + const scalar_t* __restrict__ input, + const scalar_t* __restrict__ input2, + int64_t size) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + const fVec one = fVec(1.f); + + // no remainder +#pragma GCC unroll 4 + for (int64_t d = 0; d < size; d += bVec::size()) { + bVec x = bVec::loadu(input + d); + fVec x0, x1; + std::tie(x0, x1) = at::vec::convert_to_float(x); + bVec y = bVec::loadu(input2 + d); + fVec y0, y1; + std::tie(y0, y1) = at::vec::convert_to_float(y); + x0 = x0 / (one + x0.neg().exp_u20()); + x1 = x1 / (one + x1.neg().exp_u20()); + x0 = x0 * y0; + x1 = x1 * y1; + bVec out_vec = convert_from_float_ext(x0, x1); + out_vec.store(out + d); + } +} + +} // anonymous namespace + +template +void fused_experts_fp8_kernel_impl( + scalar_t* __restrict__ output, + scalar_t* __restrict__ ic0, + scalar_t* __restrict__ ic1, + scalar_t* __restrict__ ic2, + scalar_t* __restrict__ A_tmp, + scalar_t* __restrict__ B_tmp, + float* __restrict__ C_tmp, + const scalar_t* __restrict__ input, + const at::Float8_e4m3fn* __restrict__ packed_w1, + const at::Float8_e4m3fn* __restrict__ packed_w2, + const float* __restrict__ w1s, + const float* __restrict__ w2s, + int64_t block_size_N, + int64_t block_size_K, + const float* __restrict__ topk_weights, + const int32_t* __restrict__ sorted_ids, + const int32_t* __restrict__ expert_ids, + const int32_t* __restrict__ offsets, + int64_t M, + int64_t N, + int64_t K, + int64_t E, + int64_t topk, + int64_t num_tokens_post_pad) { + + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n(); + + // stage 1: intermediate_cache0 = hidden_states @ w1 + const int64_t MB = div_up(num_tokens_post_pad, BLOCK_M); + const int64_t NB = div_up(2 * N, BLOCK_N); + int64_t scale_size_N = div_up(2 * N, block_size_N); + int64_t scale_size_K = div_up(K, block_size_K); + int64_t blocks_n_per_group = block_size_N / BLOCK_N; + + const int64_t stride_e = 2 * N * K; + const int64_t stride_n = K; + + // here we only parallel on half of 2N to fuse silu_and_mul with gemm + at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { + // get local pointers + int tid = at::get_thread_num(); + scalar_t* __restrict__ A = A_tmp + tid * BLOCK_M * K; + + bool is_brgemm_used = false; + + for (int64_t i = begin; i < end; ++i) { + int64_t mb = i / NB; + int64_t nb = i % NB; + + int64_t n_size = std::min(2 * N - nb * BLOCK_N, BLOCK_N); + + // B shape [K, n_size] in vnni format + int32_t expert_id = expert_ids[mb]; + const at::Float8_e4m3fn* __restrict__ B = packed_w1 + expert_id * stride_e + nb * BLOCK_N * stride_n; + const float* __restrict__ Bs = w1s + expert_id * scale_size_N * scale_size_K + (nb / blocks_n_per_group) * scale_size_K; + + // 1.a load A + const int32_t* A_ids = sorted_ids + mb * BLOCK_M; + int64_t m_size = offsets[mb + 1] - offsets[mb]; + + const bool use_brgemm = can_use_brgemm(m_size); + is_brgemm_used = is_brgemm_used || use_brgemm; + + for (int64_t m = 0; m < m_size; ++m) { + int32_t index = A_ids[m] / topk; + copy_stub(A + m * K, input + index * K, K); + } + + const int64_t offset = offsets[mb]; + tinygemm_kernel( + /* A */ A, + /* B */ B, + /* C */ ic0 + offset * 2 * N + nb * BLOCK_N, + /* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N), + /* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N, + /* scale */ Bs, + /* M */ m_size, + /* N */ n_size, + /* K */ K, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ 2 * N, + /* brg */ use_brgemm, + /* block_size_K */ block_size_K); + } + + if (is_brgemm_used) { + at::native::cpublas::brgemm_release(); + } + }); + + // stage 1.5: intermediate_cache1 = silu(intermediate_cache0) + at::parallel_for(0, M * topk, 0, [&](int64_t begin, int64_t end) { + for (int64_t m = begin; m < end; ++m) { + silu_and_mul_stub( + ic1 + m * N, + ic0 + m * 2 * N, + ic0 + m * 2 * N + N, + N); + } + }); + + // stage 2: intermediate_cache2 = intermediate_cache1 @ w2 + // w2 : [E, K, N] as [E, OC, IC] + const int64_t OC = K; // rename K as OC + const int64_t IC = N; // rename N as IC + const int64_t MB2 = MB; + const int64_t NB2 = div_up(OC, BLOCK_N); + scale_size_N = div_up(K, block_size_N); + scale_size_K = div_up(N, block_size_K); + const int64_t stride_e2 = OC * IC; + const int64_t stride_oc = IC; + + // parallel on [MB2, NB2] + at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) { + int tid = at::get_thread_num(); + alignas(64) scalar_t C[BLOCK_M * BLOCK_K]; + + bool is_brgemm_used = false; + + for (int64_t i = begin; i < end; ++i) { + int64_t mb = i / NB2; + int64_t nb = i % NB2; + + int64_t m_size = offsets[mb + 1] - offsets[mb]; + int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N); + + const bool use_brgemm = can_use_brgemm(m_size); + is_brgemm_used = is_brgemm_used || use_brgemm; + + // A ptr from ic1 of [M * topk, N] in sorted order + // so as to avoid copy A to tmp buffer again + const scalar_t* __restrict__ A = ic1 + offsets[mb] * N; + const int32_t* A_ids = sorted_ids + mb * BLOCK_M; + + // B shape [IC, n_size] in vnni format + int32_t expert_id = expert_ids[mb]; + const at::Float8_e4m3fn* __restrict__ B = packed_w2 + expert_id * stride_e2 + nb * BLOCK_N * stride_oc; + const float* __restrict__ Bs = w2s + expert_id * scale_size_N * scale_size_K + (nb / blocks_n_per_group) * scale_size_K; + + tinygemm_kernel( + /* A */ A, + /* B */ B, + /* C */ C, + /* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N), + /* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N, + /* scale */ Bs, + /* M */ m_size, + /* N */ n_size, + /* K */ IC, + /* lda */ IC, + /* ldb */ n_size, + /* ldc */ BLOCK_N, + /* brg */ use_brgemm, + /* block_size_K */ block_size_K); + + // 2.b copy from C to ic2 in original order + // and also mul topk_weights in float32 + for (int64_t m = 0; m < m_size; ++m) { + int32_t index = A_ids[m]; + float weight = topk_weights[index]; + copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size); + } + } + + if (is_brgemm_used) { + at::native::cpublas::brgemm_release(); + } + }); + + // stage 3: out = intermediate_cache2.sum(dim=1) + // from [M, topk, K] to [M, K] + at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) { + for (int64_t m = begin; m < end; ++m) { + sum_stub(output + m * K, ic2 + m * topk * K, topk, K); + } + }); +} + +#define INSTANTIATE_MOE_FP8_TEMPLATE(TYPE) \ + template void fused_experts_fp8_kernel_impl( \ + TYPE* __restrict__ output, \ + TYPE* __restrict__ ic0, \ + TYPE* __restrict__ ic1, \ + TYPE* __restrict__ ic2, \ + TYPE* __restrict__ A_tmp, \ + TYPE* __restrict__ B_tmp, \ + float* __restrict__ C_tmp, \ + const TYPE* __restrict__ input, \ + const at::Float8_e4m3fn* __restrict__ packed_w1, \ + const at::Float8_e4m3fn* __restrict__ packed_w2, \ + const float* __restrict__ w1s, \ + const float* __restrict__ w2s, \ + int64_t block_size_N, \ + int64_t block_size_K, \ + const float* __restrict__ topk_weights, \ + const int32_t* __restrict__ sorted_ids, \ + const int32_t* __restrict__ expert_ids, \ + const int32_t* __restrict__ offsets, \ + int64_t M, \ + int64_t N, \ + int64_t K, \ + int64_t E, \ + int64_t topk, \ + int64_t num_tokens_post_pad) + +INSTANTIATE_MOE_FP8_TEMPLATE(at::BFloat16); +INSTANTIATE_MOE_FP8_TEMPLATE(at::Half); + +template +void shared_expert_fp8_kernel_impl( + scalar_t* __restrict__ output, + scalar_t* __restrict__ ic0, + scalar_t* __restrict__ ic1, + scalar_t* __restrict__ B_tmp, + float* __restrict__ C_tmp, + const scalar_t* __restrict__ input, + const at::Float8_e4m3fn* __restrict__ packed_w1, + const at::Float8_e4m3fn* __restrict__ packed_w2, + const float* __restrict__ w1s, + const float* __restrict__ w2s, + int64_t block_size_N, + int64_t block_size_K, + const scalar_t* __restrict__ fused_experts_out, + float routed_scaling_factor, + int64_t M, + int64_t N, + int64_t K) { + + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n(); + + // stage 1: intermediate_cache0 = hidden_states @ w1 + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(2 * N, BLOCK_N); + int64_t scale_size_K = div_up(K, block_size_K); + int64_t blocks_n_per_group = block_size_N / BLOCK_N; + + const bool use_brgemm = can_use_brgemm(M); + + at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { + int tid = at::get_thread_num(); + + for (int64_t i = begin; i < end; ++i) { + int64_t mb = i / NB; + int64_t nb = i % NB; + int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M); + int64_t n_size = std::min(2 * N - nb * BLOCK_N, BLOCK_N); + + tinygemm_kernel( + /* A */ input + mb * BLOCK_M * K, + /* B */ packed_w1 + nb * BLOCK_N * K, + /* C */ ic0 + mb * BLOCK_M * 2 * N + nb * BLOCK_N, + /* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N), + /* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N, + /* scale */ w1s + (nb / blocks_n_per_group) * scale_size_K, + /* M */ m_size, + /* N */ n_size, + /* K */ K, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ 2 * N, + /* brg */ use_brgemm, + /* block_size_K */ block_size_K); + } + + if (use_brgemm) { + at::native::cpublas::brgemm_release(); + } + }); + + // stage 1.5: intermediate_cache1 = silu(intermediate_cache0) + at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) { + for (int64_t m = begin; m < end; ++m) { + silu_and_mul_stub( + ic1 + m * N, + ic0 + m * 2 * N, + ic0 + m * 2 * N + N, + N); + } + }); + + // stage 2: intermediate_cache2 = intermediate_cache1 @ w2 + // w2 : [K, N] as [OC, IC] + const int64_t OC = K; // rename K as OC + const int64_t IC = N; // rename N as IC + const int64_t MB2 = MB; + const int64_t NB2 = div_up(K, BLOCK_N); + scale_size_K = div_up(N, block_size_K); + + // parallel on [MB2, NB2] + at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) { + int tid = at::get_thread_num(); + alignas(64) scalar_t C[BLOCK_M * BLOCK_K]; + + for (int64_t i = begin; i < end; ++i) { + int64_t mb = i / NB2; + int64_t nb = i % NB2; + int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M); + int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N); + + // 2.a gemm: C = A @ B + tinygemm_kernel( + /* A */ ic1 + mb * BLOCK_M * N, + /* B */ packed_w2 + nb * BLOCK_N * N, + /* C */ C, + /* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N), + /* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N, + /* scale */ w2s + (nb / blocks_n_per_group) * scale_size_K, + /* M */ m_size, + /* N */ n_size, + /* K */ IC, + /* lda */ IC, + /* ldb */ n_size, + /* ldc */ BLOCK_N, + /* brg */ use_brgemm, + /* block_size_K */ block_size_K); + + // 2.b copy from C to output and add fused_experts_out + scalar_t* __restrict__ out = output + mb * BLOCK_M * K + nb * BLOCK_N; + const scalar_t* __restrict__ fused_out = fused_experts_out + mb * BLOCK_M * K + nb * BLOCK_N; + for (int64_t m = 0; m < m_size; ++m) { + add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size); + } + } + }); + + if (use_brgemm) { + at::native::cpublas::brgemm_release(); + } +} + +#define INSTANTIATE_SHARED_EXPERT_FP8_TEMPLATE(TYPE) \ + template void shared_expert_fp8_kernel_impl( \ + TYPE* __restrict__ output, \ + TYPE* __restrict__ ic0, \ + TYPE* __restrict__ ic1, \ + TYPE* __restrict__ B_tmp, \ + float* __restrict__ C_tmp, \ + const TYPE* __restrict__ input, \ + const at::Float8_e4m3fn* __restrict__ packed_w1, \ + const at::Float8_e4m3fn* __restrict__ packed_w2, \ + const float* __restrict__ w1s, \ + const float* __restrict__ w2s, \ + int64_t block_size_N, \ + int64_t block_size_K, \ + const TYPE* __restrict__ fused_experts_out, \ + float routed_scaling_factor, \ + int64_t M, \ + int64_t N, \ + int64_t K) + +INSTANTIATE_SHARED_EXPERT_FP8_TEMPLATE(at::BFloat16); +INSTANTIATE_SHARED_EXPERT_FP8_TEMPLATE(at::Half); diff --git a/csrc/cpu/sgl-kernels/moe_int8.cpp b/csrc/cpu/sgl-kernels/moe_int8.cpp new file mode 100644 index 0000000000000000000000000000000000000000..89d0fb5d9f3b727bace530b492fe882c22119ae6 --- /dev/null +++ b/csrc/cpu/sgl-kernels/moe_int8.cpp @@ -0,0 +1,769 @@ +// Adapted from +// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu + +#include "common.h" +#include "vec.h" +#include "gemm.h" + +// clang-format off + +namespace { + +template +inline void copy_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t size) { + using Vec = at::vec::Vectorized; + // no remainder + #pragma GCC unroll 4 + for (int64_t d = 0; d < size; d += Vec::size()) { + Vec data = Vec::loadu(input + d); + data.store(out + d); + } +} + +template <> +inline void copy_stub(uint8_t* __restrict__ out, const uint8_t* __restrict__ input, int64_t size) { + // size might be 64x + 32 + std::memcpy(out, input, size * sizeof(uint8_t)); +} + +template +inline void copy_mul_stub(scalar_t* __restrict__ out, const float* __restrict__ input, float weight, int64_t size) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + const fVec weight_vec = fVec(weight); + int64_t d; + #pragma GCC unroll 4 + for (d = 0; d <= size - kVecSize; d += kVecSize) { + fVec data0 = fVec::loadu(input + d) * weight_vec; + fVec data1 = fVec::loadu(input + d + fVec::size()) * weight_vec; + bVec out_vec = convert_from_float_ext(data0, data1); + out_vec.store(out + d); + } + for (; d < size; ++d) { + out[d] = static_cast(input[d] * weight); + } +} + +// acc from [topk, K] to [K] +template +inline void sum_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t topk, int64_t K) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + if (topk == 1) { + // do copy for topk = 1 + copy_stub(out, input, K); + } else { + // do sum for topk != 1 + int64_t d; + #pragma GCC unroll 4 + for (d = 0; d <= K - kVecSize; d += kVecSize) { + fVec sum_fvec0 = fVec(0.f); + fVec sum_fvec1 = fVec(0.f); + for (int t = 0; t < topk; ++t) { + bVec x_bvec = bVec::loadu(input + t * K + d); + fVec x_fvec0, x_fvec1; + std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec); + + sum_fvec0 += x_fvec0; + sum_fvec1 += x_fvec1; + } + bVec out_bvec = convert_from_float_ext(sum_fvec0, sum_fvec1); + out_bvec.store(out + d); + } + for (; d < K; ++d) { + float sum_val = 0.f; + for (int t = 0; t < topk; ++t) { + sum_val += static_cast(input[t * K + d]); + } + out[d] = static_cast(sum_val); + } + } +} + +// out = input + input2 * scale +template +inline void add_mul_stub(scalar_t* __restrict__ out, const float* __restrict__ input, + const scalar_t* __restrict__ input2, float scale, int64_t size) { + + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + const fVec s_vec = fVec(scale); + int64_t d; + #pragma GCC unroll 4 + for (d = 0; d <= size - kVecSize; d += kVecSize) { + fVec x0 = fVec::loadu(input + d); + fVec x1 = fVec::loadu(input + d + fVec::size()); + + bVec y_bvec = bVec::loadu(input2 + d); + fVec y0, y1; + std::tie(y0, y1) = at::vec::convert_to_float(y_bvec); + + x0 = x0 + y0 * s_vec; + x1 = x1 + y1 * s_vec; + bVec out_vec = convert_from_float_ext(x0, x1); + out_vec.store(out + d); + } + for (; d < size; ++d) { + out[d] = static_cast(input[d] + float(input2[d]) * scale); + } +} + +/// gemm for w13 +template +struct tinygemm_kernel_vnni { + static inline void apply( + const uint8_t* __restrict__ A, const int8_t* __restrict__ B0, const int8_t* __restrict__ B1, scalar_t* __restrict__ C, + const float* __restrict__ As, const float* __restrict__ Bs0, const float* __restrict__ Bs1, + const int32_t* __restrict__ Bcomp0, const int32_t* __restrict__ Bcomp1, + int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { + TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!"); + } +}; + +#if defined(CPU_CAPABILITY_AVX512) +template +struct tinygemm_kernel_vnni { + static inline void apply( + const uint8_t* __restrict__ A, const int8_t* __restrict__ B0, const int8_t* __restrict__ B1, at::BFloat16* __restrict__ C, + const float* __restrict__ As, const float* __restrict__ Bs0, const float* __restrict__ Bs1, + const int32_t* __restrict__ Bcomp0, const int32_t* __restrict__ Bcomp1, + int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { + + constexpr int ROWS = BLOCK_M; + constexpr int COLS = BLOCK_N / 16; + static_assert(COLS % 2 == 0); + + __m512i va; + __m512i vb0[COLS]; + __m512i vb1[COLS]; + __m512i vc0[ROWS * COLS]; + __m512i vc1[ROWS * COLS]; + __m512i vcomp0[COLS]; + __m512i vcomp1[COLS]; + __m512 was; + __m512 vbs0[COLS]; + __m512 vbs1[COLS]; + + auto loadc = [&](auto i) { + vc0[i] = _mm512_set1_epi32(0); + vc1[i] = _mm512_set1_epi32(0); + }; + Unroll{}(loadc); + + const int64_t K4 = K >> 2; + const int64_t lda4 = lda >> 2; + const int64_t ldb4 = ldb; // ldb * 4 >> 2; + const int32_t* a_ptr = reinterpret_cast(A); + const int32_t* b0_ptr = reinterpret_cast(B0); + const int32_t* b1_ptr = reinterpret_cast(B1); + + auto compute = [&](auto i, int64_t k) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + if constexpr (col == 0) { + va = _mm512_set1_epi32(a_ptr[row * lda4 + k]); + } + if constexpr (row == 0) { + vb0[col] = _mm512_loadu_si512(b0_ptr + k * ldb4 + col * 16); + vb1[col] = _mm512_loadu_si512(b1_ptr + k * ldb4 + col * 16); + } + vc0[i] = _mm512_dpbusd_epi32(vc0[i], va, vb0[col]); + vc1[i] = _mm512_dpbusd_epi32(vc1[i], va, vb1[col]); + }; + for (int64_t k = 0; k < K4; ++k) { + Unroll{}(compute, k); + } + + auto scalec = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + // load a scale + if constexpr(col == 0) { + was = _mm512_set1_ps(As[row]); + } + // load b scale and vcomp + if constexpr (row == 0) { + vbs0[col] = _mm512_loadu_ps(Bs0 + col * 16); + vbs1[col] = _mm512_loadu_ps(Bs1 + col * 16); + vcomp0[col] = _mm512_loadu_si512(Bcomp0 + col * 16); + vcomp1[col] = _mm512_loadu_si512(Bcomp1 + col * 16); + } + __m512 c0 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc0[i], vcomp0[col])); + __m512 c1 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc1[i], vcomp1[col])); + vc0[i] = _mm512_castps_si512(_mm512_mul_ps(_mm512_mul_ps(c0, was), vbs0[col])); + vc1[i] = _mm512_castps_si512(_mm512_mul_ps(_mm512_mul_ps(c1, was), vbs1[col])); + }; + Unroll{}(scalec); + + using Vec = at::vec::Vectorized; + const Vec one = Vec(1.f); + auto storec = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + // for COLS = 2, 4 use 512bit store + if constexpr (col % 2 == 0) { + Vec x0 = _mm512_castsi512_ps(vc0[row * COLS + col + 0]); + Vec x1 = _mm512_castsi512_ps(vc0[row * COLS + col + 1]); + Vec y0 = _mm512_castsi512_ps(vc1[row * COLS + col + 0]); + Vec y1 = _mm512_castsi512_ps(vc1[row * COLS + col + 1]); + // silu + x0 = x0 / (one + x0.neg().exp_u20()); + x1 = x1 / (one + x1.neg().exp_u20()); + // mul + x0 = x0 * y0; + x1 = x1 * y1; + + _mm512_storeu_si512( + reinterpret_cast<__m512i*>((C + row * ldc + col * 16)), + (__m512i)(_mm512_cvtne2ps_pbh(__m512(x1), __m512(x0)))); + } + }; + Unroll{}(storec); + } +}; +#endif + +#define LAUNCH_TINYGEMM_KERNEL_VNNI(MB_SIZE, NB_SIZE) \ + tinygemm_kernel_vnni::apply( \ + A + mb_start * lda, B0 + nb_start * 4, B1 + nb_start * 4, \ + C + mb_start * ldc + nb_start, As + mb_start, \ + Bs0 + nb_start, Bs1 + nb_start, Bcomp0 + nb_start, Bcomp1 + nb_start,\ + K, lda, ldb, ldc); + +template +void tinygemm_kernel( + const uint8_t* __restrict__ A, + const int8_t* __restrict__ B0, + const int8_t* __restrict__ B1, + scalar_t* __restrict__ C, + const float* __restrict__ As, + const float* __restrict__ Bs0, + const float* __restrict__ Bs1, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc) { + + const int32_t* Bcomp0 = reinterpret_cast(B0 + block_size_n() * K); + const int32_t* Bcomp1 = reinterpret_cast(B1 + block_size_n() * K); + + // pattern: 1-(2+2)-(8+8) + constexpr int64_t BLOCK_M = 4; + constexpr int64_t BLOCK_N = 32; + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + for (int mb = 0; mb < MB; ++mb) { + int64_t mb_start = mb * BLOCK_M; + int64_t mb_size = std::min(BLOCK_M, M - mb_start); + for (int64_t nb = 0; nb < NB; ++nb) { + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(BLOCK_N, N - nb_start); + + switch(mb_size << 4 | nb_size >> 4) { + case 0x12: LAUNCH_TINYGEMM_KERNEL_VNNI(1, 32); break; + case 0x22: LAUNCH_TINYGEMM_KERNEL_VNNI(2, 32); break; + case 0x32: LAUNCH_TINYGEMM_KERNEL_VNNI(3, 32); break; + case 0x42: LAUNCH_TINYGEMM_KERNEL_VNNI(4, 32); break; + default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size"); + } + } + } +} + +/// gemm for w2 +template +struct tinygemm_kernel_vnni2 { + static inline void apply( + const uint8_t* __restrict__ A, const int8_t* __restrict__ B, float* __restrict__ C, + const float* __restrict__ As, const float* __restrict__ Bs, const int32_t* __restrict__ Bcomp, + int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { + TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!"); + } +}; + +#if defined(CPU_CAPABILITY_AVX512) +template +struct tinygemm_kernel_vnni2 { + static inline void apply( + const uint8_t* __restrict__ A, const int8_t* __restrict__ B, float* __restrict__ C, + const float* __restrict__ As, const float* __restrict__ Bs, const int32_t* __restrict__ Bcomp, + int64_t K, int64_t lda, int64_t ldb, int64_t ldc) { + + constexpr int ROWS = BLOCK_M; + constexpr int COLS = BLOCK_N / 16; + static_assert(COLS % 2 == 0); + + __m512i va; + __m512i vb[COLS]; + __m512i vc[ROWS * COLS]; + __m512i vcomp[COLS]; + __m512 was; + __m512 vbs[COLS]; + + auto loadc = [&](auto i) { + vc[i] = _mm512_set1_epi32(0); + }; + Unroll{}(loadc); + + const int64_t K4 = K >> 2; + const int64_t lda4 = lda >> 2; + const int64_t ldb4 = ldb; // ldb * 4 >> 2; + const int32_t* a_ptr = reinterpret_cast(A); + const int32_t* b_ptr = reinterpret_cast(B); + + auto compute = [&](auto i, int64_t k) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + if constexpr (col == 0) { + va = _mm512_set1_epi32(a_ptr[row * lda4 + k]); + } + if constexpr (row == 0) { + vb[col] = _mm512_loadu_si512(b_ptr + k * ldb4 + col * 16); + } + vc[i] = _mm512_dpbusd_epi32(vc[i], va, vb[col]); + }; + for (int64_t k = 0; k < K4; ++k) { + Unroll{}(compute, k); + } + + auto storec = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + // load a scale + if constexpr(col == 0) { + was = _mm512_set1_ps(As[row]); + } + // load b scale and vcomp per 2 vectors + // also load bias if any + if constexpr (row == 0) { + if constexpr (col % 2 == 0) { + vbs[col + 0] = _mm512_loadu_ps(Bs + col * 16); + vbs[col + 1] = _mm512_loadu_ps(Bs + col * 16 + 16); + vcomp[col + 0] = _mm512_loadu_si512(Bcomp + col * 16); + vcomp[col + 1] = _mm512_loadu_si512(Bcomp + col * 16 + 16); + } + } + __m512 x = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc[i], vcomp[col])); + x = _mm512_mul_ps(_mm512_mul_ps(x, was), vbs[col]); + _mm512_storeu_ps(reinterpret_cast<__m512*>(C + row * ldc + col * 16), x); + }; + Unroll{}(storec); + } +}; +#endif + +#define LAUNCH_TINYGEMM_KERNEL_VNNI2(MB_SIZE, NB_SIZE) \ + tinygemm_kernel_vnni2::apply( \ + A + mb_start * lda, B + nb_start * 4, C + mb_start * ldc + nb_start, \ + As + mb_start, Bs + nb_start, Bcomp + nb_start, \ + K, lda, ldb, ldc); + +template +void tinygemm_kernel( + const uint8_t* __restrict__ A, + const int8_t* __restrict__ B, + float* __restrict__ C, + const float* __restrict__ As, + const float* __restrict__ Bs, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc) { + + // B compensation + const int32_t* Bcomp = reinterpret_cast(B + block_size_n() * K); + + // pattern: 1-4-16 + constexpr int64_t BLOCK_M = 4; + constexpr int64_t BLOCK_N = 64; + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + for (int64_t mb = 0; mb < MB; ++mb) { + int64_t mb_start = mb * BLOCK_M; + int64_t mb_size = std::min(BLOCK_M, M - mb_start); + for (int64_t nb = 0; nb < NB; ++nb) { + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(BLOCK_N, N - nb_start); + + switch(mb_size << 4 | nb_size >> 4) { + case 0x12: LAUNCH_TINYGEMM_KERNEL_VNNI2(1, 32); break; + case 0x22: LAUNCH_TINYGEMM_KERNEL_VNNI2(2, 32); break; + case 0x32: LAUNCH_TINYGEMM_KERNEL_VNNI2(3, 32); break; + case 0x42: LAUNCH_TINYGEMM_KERNEL_VNNI2(4, 32); break; + default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size"); + } + } + } +} + +} // anonymous namespace + +template +void fused_experts_int8_kernel_impl( + scalar_t* __restrict__ output, + scalar_t* __restrict__ ic1, + scalar_t* __restrict__ ic2, + uint8_t* __restrict__ A_tmp, + float* __restrict__ C_tmp, + uint8_t* __restrict__ Aq_tmp, + float* __restrict__ As_tmp, + const scalar_t* __restrict__ input, + const int8_t* __restrict__ packed_w1, + const int8_t* __restrict__ packed_w2, + const float* __restrict__ w1s, + const float* __restrict__ w2s, + const float* __restrict__ topk_weights, + const int32_t* __restrict__ sorted_ids, + const int32_t* __restrict__ expert_ids, + const int32_t* __restrict__ offsets, + int64_t M, + int64_t N, + int64_t K, + int64_t E, + int64_t topk, + int64_t num_tokens_post_pad) { + + // handle 2 tiles per block + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n(); + + // stage 0: quantize input to uint8, [M, K] + at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) { + for (int64_t m = begin; m < end; ++m) { + quantize_row_int8( + Aq_tmp + m * K, + As_tmp[m], + input + m * K, + K); + } + }); + + // stage 1: intermediate_cache1 = silu(hidden_states @ w1) + const int64_t MB = div_up(num_tokens_post_pad, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + + // strides for w1: [E, 2N, K] + TORCH_CHECK(N % BLOCK_N == 0, "Fixme when N is not multiples of ", BLOCK_N); + + // K and N are packed for int8 + const int64_t packed_K = get_row_size(K); + const int64_t packed_N = get_row_size(N); + + const int64_t stride_e = 2 * N * packed_K; + const int64_t stride_n = packed_K; + // here we only parallel on half of 2N to fuse silu_and_mul with gemm + at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { + // get local pointers + int tid = at::get_thread_num(); + uint8_t* __restrict__ A = A_tmp + tid * BLOCK_M * K; + + alignas(64) float As[BLOCK_M]; + + for (int64_t i = begin; i < end; ++i) { + int64_t mb = i / NB; + int64_t nb = i % NB; + + // nb0 from top half and nb1 from bottom half + int64_t nb0 = nb, nb1 = nb + NB; + int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N); + + // B shape [K, n_size] in vnni format + int32_t expert_id = expert_ids[mb]; + const int8_t* __restrict__ B0 = packed_w1 + expert_id * stride_e + nb0 * BLOCK_N * stride_n; + const int8_t* __restrict__ B1 = packed_w1 + expert_id * stride_e + nb1 * BLOCK_N * stride_n; + const float* __restrict__ Bs0 = w1s + expert_id * 2 * N + nb0 * BLOCK_N; + const float* __restrict__ Bs1 = w1s + expert_id * 2 * N + nb1 * BLOCK_N; + + // 1.a load A + const int32_t* A_ids = sorted_ids + mb * BLOCK_M; + int64_t m_size = offsets[mb + 1] - offsets[mb]; + + for (int64_t m = 0; m < m_size; ++m) { + int32_t index = A_ids[m] / topk; + copy_stub(A + m * K, Aq_tmp + index * K, K); + As[m] = As_tmp[index]; + } + + // fused 1.b: silu_and_mul(A @ B0, A @ B1) + const int64_t offset = offsets[mb]; + tinygemm_kernel( + /* A */ A, + /* B0 */ B0, + /* B1 */ B1, + /* C */ ic1 + offset * N + nb * BLOCK_N, + /* As */ As, + /* Bs0 */ Bs0, + /* Bs1 */ Bs1, + /* M */ m_size, + /* N */ n_size, + /* K */ K, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ N); + } + }); + + // stage 1.5: quantize ic1 to uint8, [M * topk, N] + at::parallel_for(0, M * topk, 0, [&](int64_t begin, int64_t end) { + for (int64_t m = begin; m < end; ++m) { + quantize_row_int8( + Aq_tmp + m * N, + As_tmp[m], + ic1 + m * N, + N); + } + }); + + // stage 2: intermediate_cache2 = intermediate_cache1 @ w2 + // w2 : [E, K, N] as [E, OC, IC] + const int64_t OC = K; // rename K as OC + const int64_t IC = N; // rename N as IC + const int64_t MB2 = MB; + const int64_t NB2 = div_up(OC, BLOCK_N); + const int64_t stride_e2 = OC * packed_N; + const int64_t stride_oc = packed_N; + + // parallel on [MB2, NB2] + at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) { + // get local pointers + int tid = at::get_thread_num(); + // we won't be using C1 for gemm2 + float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N; + + for (int64_t i = begin; i < end; ++i) { + int64_t mb = i / NB2; + int64_t nb = i % NB2; + + int64_t m_size = offsets[mb + 1] - offsets[mb]; + int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N); + + // A ptr from ic1 of [M * topk, N] in sorted order + // so as to avoid copy A to tmp buffer again + const uint8_t* __restrict__ A = Aq_tmp + offsets[mb] * N; + const float* __restrict__ As = As_tmp + offsets[mb]; + const int32_t* A_ids = sorted_ids + mb * BLOCK_M; + + // B shape [IC, n_size] in vnni format + int32_t expert_id = expert_ids[mb]; + const int8_t* __restrict__ B = packed_w2 + expert_id * stride_e2 + nb * BLOCK_N * stride_oc; + const float* __restrict__ Bs = w2s + expert_id * K + nb * BLOCK_N; + + // 2.a gemm: C = A @ B + tinygemm_kernel( + /* A */ A, + /* B */ B, + /* C */ C, + /* As */ As, + /* Bs */ Bs, + /* M */ m_size, + /* N */ n_size, + /* K */ IC, + /* lda */ IC, + /* ldb */ n_size, + /* ldc */ BLOCK_N); + + // 2.b copy from C to ic2 in original order + // and also mul topk_weights in float32 + for (int64_t m = 0; m < m_size; ++m) { + int32_t index = A_ids[m]; + float weight = topk_weights[index]; + copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size); + } + } + }); + + // stage 3: out = intermediate_cache2.sum(dim=1) + // from [M, topk, K] to [M, K] + at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) { + for (int64_t m = begin; m < end; ++m) { + sum_stub(output + m * K, ic2 + m * topk * K, topk, K); + } + }); +} + +#define INSTANTIATE_MOE_INT8_TEMPLATE(TYPE) \ + template void fused_experts_int8_kernel_impl ( \ + TYPE* __restrict__ output, TYPE* __restrict__ ic1, \ + TYPE* __restrict__ ic2, uint8_t* __restrict__ A_tmp, \ + float* __restrict__ C_tmp, uint8_t* __restrict__ Aq_tmp, \ + float* __restrict__ As_tmp, const TYPE* __restrict__ input, \ + const int8_t* __restrict__ packed_w1, const int8_t* __restrict__ packed_w2, \ + const float* __restrict__ w1s, const float* __restrict__ w2s, \ + const float* __restrict__ topk_weights, const int32_t* __restrict__ sorted_ids, \ + const int32_t* __restrict__ expert_ids, const int32_t* __restrict__ offsets, \ + int64_t M, int64_t N, int64_t K, int64_t E, int64_t topk, int64_t num_tokens_post_pad) + +INSTANTIATE_MOE_INT8_TEMPLATE(at::BFloat16); +INSTANTIATE_MOE_INT8_TEMPLATE(at::Half); + +template +void shared_expert_int8_kernel_impl( + scalar_t* __restrict__ output, + scalar_t* __restrict__ ic1, + float* __restrict__ C_tmp, + uint8_t* __restrict__ Aq_tmp, + float* __restrict__ As_tmp, + const scalar_t* __restrict__ input, + const int8_t* __restrict__ packed_w1, + const int8_t* __restrict__ packed_w2, + const float* __restrict__ w1s, + const float* __restrict__ w2s, + const scalar_t* __restrict__ fused_experts_out, + float routed_scaling_factor, + int64_t M, + int64_t N, + int64_t K) { + + // handle 2 tiles per block + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n(); + + // stage 0: quantize input to uint8, [M, K] + at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) { + for (int64_t m = begin; m < end; ++m) { + quantize_row_int8( + Aq_tmp + m * K, + As_tmp[m], + input + m * K, + K); + } + }); + + // stage 1: intermediate_cache1 = silu(hidden_states @ w1) + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + + TORCH_CHECK(N % BLOCK_N == 0, "Fixme when N is not multiples of ", BLOCK_N); + + // K and N are packed for int8 + const int64_t packed_K = get_row_size(K); + const int64_t packed_N = get_row_size(N); + const int64_t stride_n = packed_K; + + // here we only parallel on half of 2N to fuse silu_and_mul with gemm + at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { + for (int64_t i = begin; i < end; ++i) { + int64_t mb = i / NB; + int64_t nb = i % NB; + + // nb0 from top half and nb1 from bottom half + int64_t nb0 = nb, nb1 = nb + NB; + int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N); + int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M); + + // A shape [m_size, K] + const uint8_t* A = Aq_tmp + mb * BLOCK_M * K; + const float* As = As_tmp + mb * BLOCK_M; + + // B shape [K, n_size] in vnni format + const int8_t* __restrict__ B0 = packed_w1 + nb0 * BLOCK_N * stride_n; + const int8_t* __restrict__ B1 = packed_w1 + nb1 * BLOCK_N * stride_n; + const float* __restrict__ Bs0 = w1s + nb0 * BLOCK_N; + const float* __restrict__ Bs1 = w1s + nb1 * BLOCK_N; + + // fused 1.b: silu_and_mul(A @ B0, A @ B1) + tinygemm_kernel( + /* A */ A, + /* B0 */ B0, + /* B1 */ B1, + /* C */ ic1 + mb * BLOCK_M * N + nb * BLOCK_N, + /* As */ As, + /* Bs0 */ Bs0, + /* Bs1 */ Bs1, + /* M */ m_size, + /* N */ n_size, + /* K */ K, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ N); + } + }); + + // stage 1.5: quantize ic1 to uint8, [M * topk, N] + at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) { + for (int64_t m = begin; m < end; ++m) { + quantize_row_int8( + Aq_tmp + m * N, + As_tmp[m], + ic1 + m * N, + N); + } + }); + + // stage 2: intermediate_cache2 = intermediate_cache1 @ w2 + // w2 : [K, N] as [OC, IC] + const int64_t OC = K; // rename K as OC + const int64_t IC = N; // rename N as IC + const int64_t MB2 = MB; + const int64_t NB2 = div_up(OC, BLOCK_N); + const int64_t stride_oc = packed_N; + + // parallel on [MB2, NB2] + at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) { + // get local pointers + int tid = at::get_thread_num(); + // we won't be using C1 for gemm2 + float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N; + + for (int64_t i = begin; i < end; ++i) { + int64_t mb = i / NB2; + int64_t nb = i % NB2; + + int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M); + int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N); + + // A shape [m_size, IC] + const uint8_t* __restrict__ A = Aq_tmp + mb * BLOCK_M * N; + const float* __restrict__ As = As_tmp + mb * BLOCK_M; + + // B shape [IC, n_size] in vnni format + const int8_t* __restrict__ B = packed_w2 + nb * BLOCK_N * stride_oc; + const float* __restrict__ Bs = w2s + nb * BLOCK_N; + + // 2.a gemm: C = A @ B + tinygemm_kernel( + /* A */ A, + /* B */ B, + /* C */ C, + /* As */ As, + /* Bs */ Bs, + /* M */ m_size, + /* N */ n_size, + /* K */ IC, + /* lda */ IC, + /* ldb */ n_size, + /* ldc */ BLOCK_N); + + // 2.b copy from C to output and add fused_experts_out + scalar_t* __restrict__ out = output + mb * BLOCK_M * K + nb * BLOCK_N; + const scalar_t* __restrict__ fused_out = fused_experts_out + mb * BLOCK_M * K + nb * BLOCK_N; + for (int64_t m = 0; m < m_size; ++m) { + add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size); + } + } + }); +} + +#define INSTANTIATE_SHARED_EXPERT_INT8_TEMPLATE(TYPE) \ + template void shared_expert_int8_kernel_impl ( \ + TYPE* __restrict__ output, TYPE* __restrict__ ic1, \ + float* __restrict__ C_tmp, uint8_t* __restrict__ Aq_tmp, \ + float* __restrict__ As_tmp, const TYPE* __restrict__ input, \ + const int8_t* __restrict__ packed_w1, const int8_t* __restrict__ packed_w2, \ + const float* __restrict__ w1s, const float* __restrict__ w2s, \ + const TYPE* __restrict__ fused_experts_out, float routed_scaling_factor, \ + int64_t M, int64_t N, int64_t K) + +INSTANTIATE_SHARED_EXPERT_INT8_TEMPLATE(at::BFloat16); +INSTANTIATE_SHARED_EXPERT_INT8_TEMPLATE(at::Half); diff --git a/csrc/cpu/sgl-kernels/vec.h b/csrc/cpu/sgl-kernels/vec.h new file mode 100644 index 0000000000000000000000000000000000000000..87955cfb2922ca2050ba0ec1234671aa6ceb4152 --- /dev/null +++ b/csrc/cpu/sgl-kernels/vec.h @@ -0,0 +1,308 @@ +// Adapted from +// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu + +#pragma once + +// clang-format off + +#if defined(__AVX512F__) && defined(__AVX512BF16__) && defined(__AMX_BF16__) +#define CPU_CAPABILITY_AVX512 +#endif + +#include +#include + +namespace { + +using namespace at::vec; + +template , int> = 0> +inline Vectorized convert_from_float_ext(const Vectorized& a, const Vectorized& b) { + return at::vec::convert_from_float(a, b); +} + +#if defined(CPU_CAPABILITY_AVX512) + +// `at::vec::convert_from_float<>` from PyTorch doesn't have avx512-bf16 intrinsics +// use native instruction for bfloat16->float32 conversion +template <> +inline Vectorized convert_from_float_ext(const Vectorized& a, const Vectorized& b) { + return (__m512i)(_mm512_cvtne2ps_pbh(__m512(b), __m512(a))); +} + +#define CVT_BF16_TO_FP32(a) \ + _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16)) + +#define CVT_FP16_TO_FP32(a) \ + _mm512_cvtps_ph(a, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)) + +// this doesn't hanel NaN. +inline __m512bh cvt_e4m3_bf16_intrinsic_no_nan(__m256i fp8_vec) { + const __m512i x = _mm512_cvtepu8_epi16(fp8_vec); + + const __m512i mant = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x07)), 4); + const __m512i raw_exp = _mm512_srli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x78)), 3); + const __m512i exp = _mm512_slli_epi16(_mm512_add_epi16(raw_exp, _mm512_set1_epi16(120)), 7); + const __m512i nonsign = _mm512_or_si512(exp, mant); + + const __m512i sign = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x80)), 8); + const __m512i combined = _mm512_or_si512(nonsign, sign); + + const __mmask32 is_nonzero = _mm512_cmpneq_epi16_mask(x, _mm512_setzero_si512()); + return (__m512bh)_mm512_maskz_mov_epi16(is_nonzero, combined); +} + +inline __m512bh cvt_e4m3_bf16_intrinsic_without_denorm(__m256i fp8_vec) { + // The following conversion is without denorm behavior, that is to say, + // Max subnorm : S.0000.111 = 0.875 ∗ 2**(−6) + // Min subnorm : S.0000.001 = 2**(−9) + // 0.0019 ~ 0.0137 cannot be converted correctly. + __m512i x = _mm512_cvtepu8_epi16(fp8_vec); + auto mask = _mm512_cmpneq_epi16_mask( + _mm512_and_si512(x, _mm512_set1_epi16(127)), + _mm512_setzero_si512()); // mask = x & 0x7f + auto mask_nan = _mm512_cmpneq_epi16_mask( + _mm512_and_si512(x, _mm512_set1_epi16(127)), + _mm512_set1_epi16(127)); // mask_nan = x & 0x7f + auto mantissa = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(7)), 4); // mantissa = (x & 7) << 4 + auto exponent = _mm512_add_epi16( + _mm512_srli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(120)), 3), + _mm512_set1_epi16(120)); // exponent = (((x >> 3) & 15) + 120) + auto nonsign = _mm512_maskz_mov_epi16(mask, _mm512_or_si512(mantissa, _mm512_slli_epi16(exponent, 7))); + nonsign = _mm512_mask_mov_epi16(_mm512_set1_epi16(0x7fff), mask_nan, nonsign); // deal with Nan + return (__m512bh)(_mm512_or_si512( + nonsign, + _mm512_slli_epi16( + _mm512_and_si512(x, _mm512_set1_epi16(128)), + 8))); // add sign (x & 128) << 8 +} + +inline __m512bh cvt_e4m3_bf16_intrinsic_with_denorm(__m256i fp8_vec) { + __m512i x = _mm512_cvtepu8_epi16(fp8_vec); + __m512i lg2mant = _mm512_mask_mov_epi16( + _mm512_mask_mov_epi16( + _mm512_setzero_si512(), _mm512_test_epi16_mask(x, _mm512_set1_epi16(2)), _mm512_set1_epi16(1)), + _mm512_test_epi16_mask(x, _mm512_set1_epi16(4)), + _mm512_set1_epi16(2)); + return (__m512bh)(_mm512_or_si512( + _mm512_maskz_mov_epi16( + _mm512_cmpneq_epi16_mask(_mm512_and_si512(x, _mm512_set1_epi16(127)), _mm512_setzero_si512()), + _mm512_mask_blend_epi16( + _mm512_test_epi16_mask(x, _mm512_set1_epi16(120)), + _mm512_or_si512( + _mm512_and_si512( + _mm512_sllv_epi16( + _mm512_and_si512(x, _mm512_set1_epi16(3)), _mm512_sub_epi16(_mm512_set1_epi16(7), lg2mant)), + _mm512_set1_epi16(0x007f)), + _mm512_slli_epi16(_mm512_add_epi16(lg2mant, _mm512_set1_epi16(118)), 7)), + _mm512_or_si512( + _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(7)), 4), + _mm512_slli_epi16( + _mm512_add_epi16( + _mm512_srli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(120)), 3), _mm512_set1_epi16(120)), + 7)))), + _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(128)), 8))); +} + +inline __m512bh CVT_FP8_TO_BF16(__m256i a) { +#ifdef SGLANG_CPU_FP8_CVT_FTZ + return cvt_e4m3_bf16_intrinsic_no_nan(a); +#else + return cvt_e4m3_bf16_intrinsic_with_denorm(a); +#endif +} + +#endif + +// vector to scalar reduction +#if defined(CPU_CAPABILITY_AVX512) && 0 +inline float vec_reduce_sum(const Vectorized& a) { + return _mm512_reduce_add_ps(__m512(a)); +} + +inline float vec_reduce_max(const Vectorized& a) { + return _mm512_reduce_max_ps(__m512(a)); +} +#else +inline float vec_reduce_sum(const Vectorized& a) { + return vec_reduce_all([](Vectorized& x, Vectorized& y) { return x + y; }, a); +} + +inline float vec_reduce_max(const Vectorized& a) { + return vec_reduce_all([](Vectorized& x, Vectorized& y) { return maximum(x, y); }, a); +} +#endif + +// https://github.com/InternLM/lmdeploy/blob/086481ed84b59bee3b8e4274e5fc69620040c048/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py#L282 +template +inline void quantize_row_int8(uint8_t* __restrict__ Aq, float& As, + const scalar_t* __restrict__ A, int64_t K, float eps = 1e-7) { + + float amax = 0.f; // absolute max + for (int64_t k = 0; k < K; ++k) { + const float val = static_cast(A[k]); + amax = std::max(amax, std::abs(val)); + } + + amax = std::max(amax, eps); + const float scale = amax / 127; + const float inv_scale = 127 / amax; + + for (int64_t k = 0; k < K; ++k) { + const float val = static_cast(A[k]) * inv_scale; + Aq[k] = (uint8_t)(std::round(val)) + 128; + } + As = scale; +} + +#if defined(CPU_CAPABILITY_AVX512) +template <> +inline void quantize_row_int8(uint8_t* __restrict__ Aq, float& As, + const at::BFloat16* __restrict__ A, int64_t K, float eps) { + + const __m512 signBit = _mm512_set1_ps(-0.0f); + const __m512i off = _mm512_set1_epi32(128); + + // K is 32x, no remainder + float amax = 0.f; + __m512 vamax0 = _mm512_set1_ps(0.f); + __m512 vamax1 = _mm512_set1_ps(0.f); + for (int64_t k = 0; k < K; k += 32) { + __m512i va = _mm512_loadu_si512((void*)(A + k)); + __m512 va0 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 0)); + __m512 va1 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 1)); + vamax0 = _mm512_max_ps(vamax0, _mm512_andnot_ps(signBit, va0)); + vamax1 = _mm512_max_ps(vamax1, _mm512_andnot_ps(signBit, va1)); + } + amax = _mm512_reduce_max_ps(_mm512_max_ps(vamax0, vamax1)); + amax = std::max(amax, eps); + const float scale = amax / 127; + const float inv_scale = 127 / amax; + const __m512 vd = _mm512_set1_ps(inv_scale); + + for (int64_t k = 0; k < K; k += 32) { + __m512i va = _mm512_loadu_si512((void*)(A + k)); + __m512 va0 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 0)); + __m512 va1 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 1)); + va0 = _mm512_mul_ps(va0, vd); + va1 = _mm512_mul_ps(va1, vd); + va0 = _mm512_roundscale_ps(va0, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + va1 = _mm512_roundscale_ps(va1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + __m128i i0 = _mm512_cvtepi32_epi8(_mm512_add_epi32(_mm512_cvtps_epi32(va0), off)); + __m128i i1 = _mm512_cvtepi32_epi8(_mm512_add_epi32(_mm512_cvtps_epi32(va1), off)); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(Aq + k), _mm256_set_m128i(i1, i0)); + } + As = scale; +} +#endif + +// transpose utils +// taken from my PR in ggml: https://github.com/ggml-org/llama.cpp/pull/8998 +#if defined(CPU_CAPABILITY_AVX512) +inline void transpose_16x16_32bit(__m512i * v) { + __m512i v1[16]; + v1[0] = _mm512_unpacklo_epi32(v[0], v[1]); + v1[1] = _mm512_unpackhi_epi32(v[0], v[1]); + v1[2] = _mm512_unpacklo_epi32(v[2], v[3]); + v1[3] = _mm512_unpackhi_epi32(v[2], v[3]); + v1[4] = _mm512_unpacklo_epi32(v[4], v[5]); + v1[5] = _mm512_unpackhi_epi32(v[4], v[5]); + v1[6] = _mm512_unpacklo_epi32(v[6], v[7]); + v1[7] = _mm512_unpackhi_epi32(v[6], v[7]); + v1[8] = _mm512_unpacklo_epi32(v[8], v[9]); + v1[9] = _mm512_unpackhi_epi32(v[8], v[9]); + v1[10] = _mm512_unpacklo_epi32(v[10], v[11]); + v1[11] = _mm512_unpackhi_epi32(v[10], v[11]); + v1[12] = _mm512_unpacklo_epi32(v[12], v[13]); + v1[13] = _mm512_unpackhi_epi32(v[12], v[13]); + v1[14] = _mm512_unpacklo_epi32(v[14], v[15]); + v1[15] = _mm512_unpackhi_epi32(v[14], v[15]); + + v[0] = _mm512_unpacklo_epi64(v1[0], v1[2]); + v[1] = _mm512_unpackhi_epi64(v1[0], v1[2]); + v[2] = _mm512_unpacklo_epi64(v1[1], v1[3]); + v[3] = _mm512_unpackhi_epi64(v1[1], v1[3]); + v[4] = _mm512_unpacklo_epi64(v1[4], v1[6]); + v[5] = _mm512_unpackhi_epi64(v1[4], v1[6]); + v[6] = _mm512_unpacklo_epi64(v1[5], v1[7]); + v[7] = _mm512_unpackhi_epi64(v1[5], v1[7]); + v[8] = _mm512_unpacklo_epi64(v1[8], v1[10]); + v[9] = _mm512_unpackhi_epi64(v1[8], v1[10]); + v[10] = _mm512_unpacklo_epi64(v1[9], v1[11]); + v[11] = _mm512_unpackhi_epi64(v1[9], v1[11]); + v[12] = _mm512_unpacklo_epi64(v1[12], v1[14]); + v[13] = _mm512_unpackhi_epi64(v1[12], v1[14]); + v[14] = _mm512_unpacklo_epi64(v1[13], v1[15]); + v[15] = _mm512_unpackhi_epi64(v1[13], v1[15]); + + v1[0] = _mm512_shuffle_i32x4(v[0], v[4], 0x88); + v1[1] = _mm512_shuffle_i32x4(v[1], v[5], 0x88); + v1[2] = _mm512_shuffle_i32x4(v[2], v[6], 0x88); + v1[3] = _mm512_shuffle_i32x4(v[3], v[7], 0x88); + v1[4] = _mm512_shuffle_i32x4(v[0], v[4], 0xdd); + v1[5] = _mm512_shuffle_i32x4(v[1], v[5], 0xdd); + v1[6] = _mm512_shuffle_i32x4(v[2], v[6], 0xdd); + v1[7] = _mm512_shuffle_i32x4(v[3], v[7], 0xdd); + v1[8] = _mm512_shuffle_i32x4(v[8], v[12], 0x88); + v1[9] = _mm512_shuffle_i32x4(v[9], v[13], 0x88); + v1[10] = _mm512_shuffle_i32x4(v[10], v[14], 0x88); + v1[11] = _mm512_shuffle_i32x4(v[11], v[15], 0x88); + v1[12] = _mm512_shuffle_i32x4(v[8], v[12], 0xdd); + v1[13] = _mm512_shuffle_i32x4(v[9], v[13], 0xdd); + v1[14] = _mm512_shuffle_i32x4(v[10], v[14], 0xdd); + v1[15] = _mm512_shuffle_i32x4(v[11], v[15], 0xdd); + + v[0] = _mm512_shuffle_i32x4(v1[0], v1[8], 0x88); + v[1] = _mm512_shuffle_i32x4(v1[1], v1[9], 0x88); + v[2] = _mm512_shuffle_i32x4(v1[2], v1[10], 0x88); + v[3] = _mm512_shuffle_i32x4(v1[3], v1[11], 0x88); + v[4] = _mm512_shuffle_i32x4(v1[4], v1[12], 0x88); + v[5] = _mm512_shuffle_i32x4(v1[5], v1[13], 0x88); + v[6] = _mm512_shuffle_i32x4(v1[6], v1[14], 0x88); + v[7] = _mm512_shuffle_i32x4(v1[7], v1[15], 0x88); + v[8] = _mm512_shuffle_i32x4(v1[0], v1[8], 0xdd); + v[9] = _mm512_shuffle_i32x4(v1[1], v1[9], 0xdd); + v[10] = _mm512_shuffle_i32x4(v1[2], v1[10], 0xdd); + v[11] = _mm512_shuffle_i32x4(v1[3], v1[11], 0xdd); + v[12] = _mm512_shuffle_i32x4(v1[4], v1[12], 0xdd); + v[13] = _mm512_shuffle_i32x4(v1[5], v1[13], 0xdd); + v[14] = _mm512_shuffle_i32x4(v1[6], v1[14], 0xdd); + v[15] = _mm512_shuffle_i32x4(v1[7], v1[15], 0xdd); +} + +// remove warning : ignoring attributes on template argument ‘__m512i’ [-Wignored-attributes] +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wignored-attributes" + +// transpose from [2, 32] to [32, 2] +inline std::tuple<__m512i, __m512i> transpose_2x32_16bit(__m512i r0, __m512i r1) { + // r0: {a0, a1, ..., a31} + // r1: {b0, b1, ..., b31} + // + // d0: {a0, b0, ..., a15, b15} + // d1: {a16, b16, ..., a31, b31} + // + __m512i d0 = _mm512_unpacklo_epi16(r0, r1); + __m512i d1 = _mm512_unpackhi_epi16(r0, r1); + r0 = _mm512_shuffle_i32x4(d0, d1, 0x88); + r1 = _mm512_shuffle_i32x4(d0, d1, 0xdd); + d0 = _mm512_shuffle_i32x4(r0, r1, 0x88); + d1 = _mm512_shuffle_i32x4(r0, r1, 0xdd); + return std::make_tuple(d0, d1); +} +#pragma GCC diagnostic pop + +#endif + +// TODO: debug print, remove me later +template +void print_array(scalar_t* ptr, int size) { + for (int d = 0; d < size; ++d) { + if (d % 16 == 0) { std::cout << std::endl; } + std::cout << ptr[d] << " "; + } + std::cout << std::endl; +} + +} // anonymous namespace diff --git a/csrc/cpu/shm.cpp b/csrc/cpu/shm.cpp index f55e96de251d06750af18b71911535d23407f5d7..9adb6f27ec411c25203b5e571930b2a0d8b611a0 100644 --- a/csrc/cpu/shm.cpp +++ b/csrc/cpu/shm.cpp @@ -7,9 +7,10 @@ namespace { #define MAX_SHM_RANK_NUM 8 -#define MAX_THREAD_NUM 12 -#define PER_THREAD_SHM_BUFFER_BYTES (4 * 1024 * 1024) -#define MIN_THREAD_PROCESS_SIZE (8 * 1024) +#define PER_THREAD_SHM_BUFFER_BYTES (2 * 1024 * 1024) +static_assert(PER_THREAD_SHM_BUFFER_BYTES % 2 == 0); +#define PER_THREAD_SHM_BUFFER_OFFSET (PER_THREAD_SHM_BUFFER_BYTES >> 1) +#define MIN_THREAD_PROCESS_SIZE (256) #define MAX_P2P_SEND_TENSOR_NUM 8 template @@ -32,10 +33,10 @@ struct KernelVecType { using scalar_vec_t = vec_op::FP16Vec16; }; -enum class ThreadSHMStat : char { THREAD_READY = 0, SHM_DATA_READY, DONE }; - struct ThreadSHMContext { - volatile ThreadSHMStat thread_stats[MAX_SHM_RANK_NUM]; + volatile char _curr_thread_stamp; + volatile char _ready_thread_stamp; + char _padding1[6]; int thread_id; int thread_num; int rank; @@ -44,14 +45,19 @@ struct ThreadSHMContext { int swizzled_ranks[MAX_SHM_RANK_NUM]; void* thread_shm_ptrs[MAX_SHM_RANK_NUM]; ThreadSHMContext* shm_contexts[MAX_SHM_RANK_NUM]; + size_t _thread_buffer_mask; + char _padding2[56]; ThreadSHMContext(const int thread_id, const int thread_num, const int rank, const int group_size, void* thread_shm_ptr) - : thread_id(thread_id), + : _curr_thread_stamp(1), + _ready_thread_stamp(0), + thread_id(thread_id), thread_num(thread_num), rank(rank), group_size(group_size), - _spinning_count(0) { + _spinning_count(0), + _thread_buffer_mask(0) { static_assert(sizeof(ThreadSHMContext) % 64 == 0); TORCH_CHECK(group_size <= MAX_SHM_RANK_NUM); TORCH_CHECK((size_t)this % 64 == 0); @@ -60,7 +66,6 @@ struct ThreadSHMContext { shm_contexts[i] = nullptr; thread_shm_ptrs[i] = nullptr; swizzled_ranks[i] = (i + rank) % group_size; - thread_stats[i] = ThreadSHMStat::DONE; } set_context(rank, this, thread_shm_ptr); } @@ -77,59 +82,66 @@ struct ThreadSHMContext { template T* get_thread_shm_ptr(int rank) { - return reinterpret_cast(thread_shm_ptrs[rank]); + return reinterpret_cast( + reinterpret_cast(thread_shm_ptrs[rank]) + + (PER_THREAD_SHM_BUFFER_OFFSET & _thread_buffer_mask)); + } + + void next_buffer() { _thread_buffer_mask ^= 0xFFFFFFFFFFFFFFFF; } + + char get_curr_stamp() const { return _curr_thread_stamp; } + + char get_ready_stamp() const { return _ready_thread_stamp; } + + void next_stamp() { + _mm_mfence(); + _curr_thread_stamp += 1; + } + + void commit_ready_stamp() { + _mm_mfence(); + _ready_thread_stamp = _curr_thread_stamp; } int get_swizzled_rank(int idx) { return swizzled_ranks[idx]; } - void wait_for_all(ThreadSHMStat prev_stat) { - for (int idx = 0; idx < group_size; ++idx) { + template + void wait_for_all(Cond&& cond) { + for (int idx = 1; idx < group_size; ++idx) { int rank = get_swizzled_rank(idx); - while (thread_stats[rank] == prev_stat) { - ++_spinning_count; - _mm_pause(); - } + wait_for_one(rank, std::forward(cond)); } - vec_op::mem_barrier(); } - void wait_for_one(int rank, ThreadSHMStat prev_stat) { - while (thread_stats[rank] == prev_stat) { + template + void wait_for_one(int rank, Cond&& cond) { + ThreadSHMContext* rank_ctx = shm_contexts[rank]; + for (;;) { + char local_curr_stamp = get_curr_stamp(); + char local_ready_stamp = get_ready_stamp(); + char rank_curr_stamp = rank_ctx->get_curr_stamp(); + char rank_ready_stamp = rank_ctx->get_ready_stamp(); + if (cond(local_curr_stamp, local_ready_stamp, rank_curr_stamp, + rank_ready_stamp)) { + break; + } ++_spinning_count; _mm_pause(); } - vec_op::mem_barrier(); } - void set_thread_stat(ThreadSHMStat stat) { - for (int idx = 0; idx < group_size; ++idx) { - int rank = get_swizzled_rank(idx); - shm_contexts[rank]->thread_stats[this->rank] = stat; - } - } - - void set_thread_stat(int target_rank, ThreadSHMStat stat) { - for (int idx = 0; idx < group_size; ++idx) { - int rank = get_swizzled_rank(idx); - shm_contexts[rank]->thread_stats[target_rank] = stat; - } + static bool check_no_buffer_conflict(char local_curr_stamp, + char local_ready_stamp, + char rank_curr_stamp, + char rank_ready_stamp) { + char temp = rank_curr_stamp + 2; + return local_curr_stamp != temp; } - // barrier for all ranks in the group, used for all2all ops - // DONE -> THREAD_READY -> SHM_DATA_READY -> DONE -> ... - void barrier(ThreadSHMStat next_stat) { - if (next_stat == ThreadSHMStat::THREAD_READY) { - set_thread_stat(ThreadSHMStat::THREAD_READY); - wait_for_all(ThreadSHMStat::DONE); - } else if (next_stat == ThreadSHMStat::SHM_DATA_READY) { - set_thread_stat(ThreadSHMStat::SHM_DATA_READY); - wait_for_all(ThreadSHMStat::THREAD_READY); - } else if (next_stat == ThreadSHMStat::DONE) { - set_thread_stat(ThreadSHMStat::DONE); - wait_for_all(ThreadSHMStat::SHM_DATA_READY); - } else { - TORCH_CHECK(false, "Invalid next_stat to barrier."); - } + static bool check_stamp_ready(char local_curr_stamp, char local_ready_stamp, + char rank_curr_stamp, char rank_ready_stamp) { + char temp = local_curr_stamp + 1; + return (local_curr_stamp == rank_ready_stamp) || (temp == rank_ready_stamp); } std::string to_string() const { @@ -164,7 +176,7 @@ class SHMManager { const int group_size) : _rank(rank), _group_size(group_size), - _thread_num(std::min(torch::get_num_threads(), MAX_THREAD_NUM)), + _thread_num(torch::get_num_threads()), _shm_names({""}), _shared_mem_ptrs({nullptr}), _shm_ctx(nullptr) { @@ -326,7 +338,8 @@ void shm_cc_loop(ThreadSHMContext* ctx, int64_t elem_num, F&& inner_func) { (total_units_num + thread_num - 1) / thread_num; int64_t per_unit_elem_num = MIN_THREAD_PROCESS_SIZE / sizeof(scalar_t); int64_t max_per_thread_iteration_elem_num = - PER_THREAD_SHM_BUFFER_BYTES / sizeof(scalar_t); + (PER_THREAD_SHM_BUFFER_BYTES >> 1) / + sizeof(scalar_t); // Note: double buffer int64_t per_thread_elem_num = per_unit_elem_num * per_thread_units_num; #pragma omp parallel for schedule(static, 1) @@ -336,10 +349,13 @@ void shm_cc_loop(ThreadSHMContext* ctx, int64_t elem_num, F&& inner_func) { int64_t curr_elem_num = std::min(max_per_thread_iteration_elem_num, end - offset); ThreadSHMContext* thread_ctx = ctx + i; + bool fast_mode = ((end - offset) <= max_per_thread_iteration_elem_num); while (curr_elem_num > 0) { - inner_func(thread_ctx, offset, curr_elem_num); + inner_func(thread_ctx, offset, curr_elem_num, fast_mode); + thread_ctx->next_stamp(); + thread_ctx->next_buffer(); offset += max_per_thread_iteration_elem_num; curr_elem_num = std::min(max_per_thread_iteration_elem_num, end - offset); } @@ -397,7 +413,7 @@ void all_reduce_sum_impl(ThreadSHMContext* ctx, scalar_t* data, shm_cc_ops::shm_cc_loop( ctx, elem_num, [&](ThreadSHMContext* thread_ctx, int64_t data_offset, - int64_t data_elem_num) { + int64_t data_elem_num, bool fast_mode) { int rank = thread_ctx->rank; scalar_t* thread_shm_ptr = thread_ctx->get_thread_shm_ptr(rank); @@ -410,16 +426,17 @@ void all_reduce_sum_impl(ThreadSHMContext* ctx, scalar_t* data, thread_ctx->get_swizzled_rank(idx + 1)); }); - thread_ctx->barrier(ThreadSHMStat::THREAD_READY); + if (!fast_mode) { + thread_ctx->wait_for_all(ThreadSHMContext::check_no_buffer_conflict); + } shm_cc_ops::memcpy_to_shm(thread_shm_ptr, thread_data_ptr, thread_data_elem_num); - - thread_ctx->barrier(ThreadSHMStat::SHM_DATA_READY); - + thread_ctx->commit_ready_stamp(); int64_t aligned_data_elem_num = (data_elem_num / vec_elem_num) * vec_elem_num; int64_t i = 0; + thread_ctx->wait_for_all(ThreadSHMContext::check_stamp_ready); #pragma GCC unroll 4 for (; i < aligned_data_elem_num; i += vec_elem_num) { vec_t local_data(thread_data_ptr + i); // load from cache @@ -447,8 +464,6 @@ void all_reduce_sum_impl(ThreadSHMContext* ctx, scalar_t* data, reduced_data.save(thread_data_ptr + i, data_elem_num - aligned_data_elem_num); } - - thread_ctx->barrier(ThreadSHMStat::DONE); }); return; @@ -488,18 +503,18 @@ void shm_gather_impl(ThreadSHMContext* ctx, scalar_t* data, size_t elem_num, shm_cc_ops::shm_cc_loop( ctx, elem_num, [&](ThreadSHMContext* thread_ctx, int64_t data_offset, - int64_t data_elem_num) { + int64_t data_elem_num, bool fast_mode) { int rank = thread_ctx->rank; scalar_t* thread_shm_ptr = thread_ctx->get_thread_shm_ptr(rank); - thread_ctx->barrier(ThreadSHMStat::THREAD_READY); - - shm_cc_ops::memcpy_to_shm(thread_shm_ptr, data + data_offset, - data_elem_num * sizeof(scalar_t)); - - thread_ctx->barrier(ThreadSHMStat::SHM_DATA_READY); + if (!fast_mode) { + thread_ctx->wait_for_all(ThreadSHMContext::check_no_buffer_conflict); + } + shm_cc_ops::memcpy(thread_shm_ptr, data + data_offset, + data_elem_num * sizeof(scalar_t)); + thread_ctx->commit_ready_stamp(); if (rank == dst) { shm_cc_ops::memcpy(outputs[rank] + data_offset, data + data_offset, data_elem_num * sizeof(scalar_t)); @@ -508,12 +523,12 @@ void shm_gather_impl(ThreadSHMContext* ctx, scalar_t* data, size_t elem_num, scalar_t* src_ptr = thread_ctx->get_thread_shm_ptr(src_rank); // shm scalar_t* dst_ptr = outputs[src_rank] + data_offset; - shm_cc_ops::memcpy_from_shm(dst_ptr, src_ptr, - data_elem_num * sizeof(scalar_t)); + thread_ctx->wait_for_one(src_rank, + ThreadSHMContext::check_stamp_ready); + shm_cc_ops::memcpy(dst_ptr, src_ptr, + data_elem_num * sizeof(scalar_t)); } } - - thread_ctx->barrier(ThreadSHMStat::DONE); }); return; @@ -599,7 +614,7 @@ struct TensorListMeta { int8_t _padding[40]; }; -void shm_send_tensor_list_impl(ThreadSHMContext* ctx, +void shm_send_tensor_list_impl(ThreadSHMContext* ctx, int64_t dst, const std::vector& tensor_list) { CPU_KERNEL_GUARD_IN(shm_send_tensor_list_impl) std::vector tensor_list_with_metadata; @@ -620,12 +635,11 @@ void shm_send_tensor_list_impl(ThreadSHMContext* ctx, shm_cc_ops::shm_cc_loop( ctx, metadata->total_bytes, [&](ThreadSHMContext* thread_ctx, int64_t data_offset, - int64_t data_elem_num) { + int64_t data_elem_num, bool fast_mode) { int rank = thread_ctx->rank; - // Wait until the receiver set the stat to DONE - thread_ctx->wait_for_one(rank, ThreadSHMStat::SHM_DATA_READY); - int64_t curr_shm_offset = 0; + thread_ctx->wait_for_one(dst, + ThreadSHMContext::check_no_buffer_conflict); while (curr_shm_offset < data_elem_num) { MemPiece frag = metadata->get_data(data_offset + curr_shm_offset); frag.size = std::min(frag.size, data_elem_num - curr_shm_offset); @@ -634,8 +648,7 @@ void shm_send_tensor_list_impl(ThreadSHMContext* ctx, frag.ptr, frag.size); curr_shm_offset += frag.size; } - - thread_ctx->set_thread_stat(rank, ThreadSHMStat::SHM_DATA_READY); + thread_ctx->commit_ready_stamp(); }); } @@ -646,8 +659,7 @@ std::vector shm_recv_tensor_list_impl(ThreadSHMContext* ctx, torch::Tensor metadata_tensor = torch::empty({sizeof(TensorListMeta)}, options); - // Wait until the sender set the stat of the thread 0 to SHM_DATA_READY - ctx->wait_for_one(src, ThreadSHMStat::DONE); + ctx->wait_for_one(src, ThreadSHMContext::check_stamp_ready); shm_cc_ops::memcpy(metadata_tensor.data_ptr(), ctx->get_thread_shm_ptr(src), sizeof(TensorListMeta)); @@ -664,9 +676,8 @@ std::vector shm_recv_tensor_list_impl(ThreadSHMContext* ctx, shm_cc_ops::shm_cc_loop( ctx, metadata.total_bytes, [&](ThreadSHMContext* thread_ctx, int64_t data_offset, - int64_t data_elem_num) { - // Wait until the sender set the stat to SHM_DATA_READY - thread_ctx->wait_for_one(src, ThreadSHMStat::DONE); + int64_t data_elem_num, bool fast_mode) { + ctx->wait_for_one(src, ThreadSHMContext::check_stamp_ready); int64_t curr_shm_offset = 0; while (curr_shm_offset < data_elem_num) { MemPiece frag = metadata.get_data(data_offset + curr_shm_offset); @@ -677,8 +688,6 @@ std::vector shm_recv_tensor_list_impl(ThreadSHMContext* ctx, frag.size); curr_shm_offset += frag.size; } - - thread_ctx->set_thread_stat(src, ThreadSHMStat::DONE); }); std::vector tensor_list; @@ -756,7 +765,8 @@ void shm_send_tensor_list(int64_t handle, int64_t dst) { CPU_KERNEL_GUARD_IN(shm_send_tensor_list) shm_send_tensor_list_impl( - SHMManager::get_singleton_instance(handle)->get_shm_ctx(), tensor_list); + SHMManager::get_singleton_instance(handle)->get_shm_ctx(), dst, + tensor_list); CPU_KERNEL_GUARD_OUT(shm_send_tensor_list) } @@ -778,4 +788,4 @@ std::string join_shm_manager(int64_t handle, const std::string& name) { TORCH_CHECK(shm_manager); shm_manager->join(name); return shm_manager->get_shm_ctx()->to_string(); -} \ No newline at end of file +} diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index 447e826bc1c09b83c55824d0438dad1c0e12681b..ebfc81f858367f95fbf858e9629ab13b3eba0333 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -50,6 +50,27 @@ void shm_send_tensor_list(int64_t handle, std::vector shm_recv_tensor_list(int64_t handle, int64_t src); +at::Tensor weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, + const std::optional& bias, + bool is_vnni); + +at::Tensor convert_weight_packed(at::Tensor& weight); + +at::Tensor fused_experts_cpu( + at::Tensor& hidden_states, at::Tensor& w1, at::Tensor& w2, + at::Tensor& topk_weights, at::Tensor& topk_ids, bool inplace, + bool use_int8_w8a8, bool use_fp8_w8a16, + const std::optional& w1_scale, + const std::optional& w2_scale, + const std::optional> block_size, + const std::optional& a1_scale, + const std::optional& a2_scale, bool is_vnni); + +at::Tensor int8_scaled_mm_with_quant(at::Tensor& mat1, at::Tensor& mat2, + at::Tensor& scales2, + const std::optional& bias, + at::ScalarType out_dtype, bool is_vnni); + TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // vLLM custom ops @@ -131,16 +152,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Quantization #ifdef __AVX512F__ + at::Tag stride_tag = at::Tag::needs_fixed_stride_order; // Compute int8 quantized tensor for given scaling factor. ops.def( "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale," - "Tensor? azp) -> ()"); + "Tensor? azp) -> ()", + {stride_tag}); ops.impl("static_scaled_int8_quant", torch::kCPU, &static_scaled_int8_quant); // Compute int8 quantized tensor and scaling factor ops.def( "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, " - "Tensor!? azp) -> ()"); + "Tensor!? azp) -> ()", + {stride_tag}); ops.impl("dynamic_scaled_int8_quant", torch::kCPU, &dynamic_scaled_int8_quant); // W8A8 GEMM, supporting symmetric per-tensor or per-row/column @@ -148,7 +172,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def( "cutlass_scaled_mm(Tensor! out, Tensor a," " Tensor b, Tensor a_scales," - " Tensor b_scales, Tensor? bias) -> ()"); + " Tensor b_scales, Tensor? bias) -> ()", + {stride_tag}); ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm); // w8a8 GEMM, supporting asymmetric per-tensor or per-row/column // quantization. @@ -156,7 +181,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "cutlass_scaled_mm_azp(Tensor! out, Tensor a," " Tensor b, Tensor a_scales," " Tensor b_scales, Tensor azp_adj," - " Tensor? azp, Tensor? bias) -> ()"); + " Tensor? azp, Tensor? bias) -> ()", + {stride_tag}); ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp); #elif defined(__powerpc64__) // Compute int8 quantized tensor for given scaling factor. @@ -209,6 +235,28 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("shm_recv_tensor_list(int handle, int src) -> Tensor[](a)", &shm_recv_tensor_list); #endif + + // sgl-kernels +#if defined(__AVX512BF16__) && defined(__AVX512F__) && defined(__AVX512VNNI__) + ops.def( + "weight_packed_linear(Tensor(a0!) mat1, Tensor(a1!) mat2, Tensor(a2!)? " + "bias, bool is_vnni) -> Tensor"); + ops.impl("weight_packed_linear", torch::kCPU, &weight_packed_linear); + ops.def("convert_weight_packed(Tensor! weight) -> Tensor"); + ops.impl("convert_weight_packed", torch::kCPU, &convert_weight_packed); + ops.def( + "fused_experts_cpu(Tensor! hidden_states, Tensor w1, Tensor w2, Tensor " + "topk_weights, Tensor topk_ids, bool inplace, bool use_int8_w8a8, bool " + "use_fp8_w8a16, Tensor? w1_scale, Tensor? w2_scale, SymInt[]? " + "block_size, Tensor? a1_scale, Tensor? a2_scale, bool is_vnni) -> " + "Tensor"); + ops.impl("fused_experts_cpu", torch::kCPU, &fused_experts_cpu); + ops.def( + "int8_scaled_mm_with_quant(Tensor mat1, Tensor mat2, Tensor scales2, " + "Tensor? bias, ScalarType out_dtype, bool is_vnni) -> Tensor"); + ops.impl("int8_scaled_mm_with_quant", torch::kCPU, + &int8_scaled_mm_with_quant); +#endif } TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { diff --git a/csrc/cpu/utils.cpp b/csrc/cpu/utils.cpp index c17a8961629a6e4c24bb28d05b8a03d860f2c399..02514edce80733f4ebcc24d1be906dfd0c4c574b 100644 --- a/csrc/cpu/utils.cpp +++ b/csrc/cpu/utils.cpp @@ -54,8 +54,7 @@ std::string init_cpu_threads_env(const std::string& cpu_ids) { *(src_mask->maskp) = *(src_mask->maskp) ^ *(mask->maskp); int page_num = numa_migrate_pages(pid, src_mask, mask); if (page_num == -1) { - TORCH_CHECK(false, - "numa_migrate_pages failed. errno: " + std::to_string(errno)); + TORCH_WARN("numa_migrate_pages failed. errno: " + std::to_string(errno)); } // restrict memory allocation node. @@ -105,4 +104,4 @@ std::string init_cpu_threads_env(const std::string& cpu_ids) { return ss.str(); } -#endif \ No newline at end of file +#endif diff --git a/csrc/custom_quickreduce.cu b/csrc/custom_quickreduce.cu new file mode 100644 index 0000000000000000000000000000000000000000..33d0d4a7226e695f32c5e1fa580d03af09ebf3d1 --- /dev/null +++ b/csrc/custom_quickreduce.cu @@ -0,0 +1,114 @@ +#include +#include +#include +#include + +#ifdef USE_ROCM + + #include "quickreduce/quick_reduce.h" + +quickreduce::fptr_t init_custom_qr(int64_t rank, int64_t world_size, + std::optional qr_max_size) { + if (world_size > 8) + throw std::invalid_argument("world size > 8 is not supported"); + if (world_size == 6) + throw std::invalid_argument("world size == 6 is not supported"); + if (world_size % 2 != 0) + throw std::invalid_argument("Odd num gpus is not supported for now"); + if (rank < 0 || rank >= world_size) + throw std::invalid_argument("invalid rank passed in"); + quickreduce::DeviceComms* fptr = new quickreduce::DeviceComms(); + fptr->init(world_size, rank, qr_max_size); + return (quickreduce::fptr_t)fptr; +} + +void qr_destroy(quickreduce::fptr_t _fa) { + if (_fa) { + auto fa = reinterpret_cast(_fa); + fa->destroy(); + delete fa; + } +} + +torch::Tensor qr_get_handle(quickreduce::fptr_t _fa) { + auto fa = reinterpret_cast(_fa); + hipIpcMemHandle_t handle = fa->get_handle(); + auto options = + torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU); + auto data_handle = + torch::empty({static_cast(sizeof(hipIpcMemHandle_t))}, options); + std::memcpy(data_handle.data_ptr(), &handle, sizeof(hipIpcMemHandle_t)); + return data_handle; +} + +void qr_open_handles(quickreduce::fptr_t _fa, + const std::vector& handles) { + auto fa = reinterpret_cast(_fa); + std::vector ipc_handles; + ipc_handles.reserve(handles.size()); + for (auto& handle : handles) { + // Ensure the tensor is on the same device as the current device. + hipIpcMemHandle_t ipc_handle; + std::memcpy(&ipc_handle, handle.data_ptr(), sizeof(hipIpcMemHandle_t)); + ipc_handles.push_back(ipc_handle); + } + fa->open_ipc_handles(ipc_handles); +} + +void qr_all_reduce(quickreduce::fptr_t _fa, torch::Tensor& inp, + torch::Tensor& out, int64_t quant_level, bool cast_bf2half) { + auto fa = reinterpret_cast(_fa); + const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); + auto stream = at::cuda::getCurrentHIPStreamMasqueradingAsCUDA(); + + TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); + TORCH_CHECK_EQ(inp.numel(), out.numel()); + TORCH_CHECK_LE(out.numel(), fa->kMaxProblemSize); + if (out.scalar_type() == at::ScalarType::Half) { + fa->allreduce(reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), + out.numel(), quant_level, stream); + } else if (out.scalar_type() == at::ScalarType::BFloat16) { + if (cast_bf2half) { + fa->allreduce(reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), + out.numel(), quant_level, stream); + } else { + fa->allreduce( + reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), + out.numel(), quant_level, stream); + } + } else { + throw std::runtime_error( + "quick allreduce only supports float16 and bfloat16"); + } +} + +int64_t qr_max_size() { + // The default is 2GB (2,147,483,648 bytes) + return static_cast(std::numeric_limits::max()) + 1; +} + + #define INSTANTIATE_FOR_WORLDSIZE(T, Codec, cast_bf2half) \ + template struct quickreduce::AllReduceTwoshot, \ + cast_bf2half>; \ + template struct quickreduce::AllReduceTwoshot, \ + cast_bf2half>; \ + template struct quickreduce::AllReduceTwoshot, cast_bf2half>; + +INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecFP, false) +INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ4, false) +INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ6, false) +INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ8, false) +INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecFP, true) +INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ4, true) +INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ6, true) +INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ8, true) + +INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecFP, false) +INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ4, false) +INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ6, false) +INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ8, false) + +#endif // USE_ROCM \ No newline at end of file diff --git a/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp b/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp index d922a3349e1e19b73dcd2282a36b429838a71820..ce7f47cf723377f8f777d7641850d6c93ac84a49 100644 --- a/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp +++ b/csrc/cutlass_extensions/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8_blockwise_scaling.hpp @@ -45,7 +45,6 @@ #include "cute/algorithm/functional.hpp" #include "cute/atom/mma_atom.hpp" #include "cute/algorithm/gemm.hpp" -#include "cute/tensor_predicate.hpp" #include "cute/numeric/arithmetic_tuple.hpp" #include "cutlass_extensions/gemm/dispatch_policy.hpp" diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.cu b/csrc/mamba/causal_conv1d/causal_conv1d.cu index f62d08c17c6d891ca9e8bd668e281c8f75962401..c83d72751a55cdd2e74c3b64e112e8832519f2d9 100644 --- a/csrc/mamba/causal_conv1d/causal_conv1d.cu +++ b/csrc/mamba/causal_conv1d/causal_conv1d.cu @@ -185,9 +185,7 @@ void causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight, params.conv_states_ptr = nullptr; } - // Otherwise the kernel will be launched from cuda:0 device - // Cast to char to avoid compiler warning about narrowing - at::cuda::CUDAGuard device_guard{(char)x.get_device()}; + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); auto stream = at::cuda::getCurrentCUDAStream().stream(); DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] { causal_conv1d_fwd_cuda(params, stream); @@ -278,9 +276,7 @@ void causal_conv1d_update(const at::Tensor &x, params.conv_state_indices_ptr = nullptr; } - // Otherwise the kernel will be launched from cuda:0 device - // Cast to char to avoid compiler warning about narrowing - at::cuda::CUDAGuard device_guard{(char)x.get_device()}; + const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); auto stream = at::cuda::getCurrentCUDAStream().stream(); DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] { causal_conv1d_update_cuda(params, stream); diff --git a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu index 0c9df925bdbf60faf3631f8290c626211c76aeae..785d316025eca0a6e8b5ffca0a7ed47a974f95d8 100644 --- a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu +++ b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu @@ -647,9 +647,7 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, ); - // Otherwise the kernel will be launched from cuda:0 device - // Cast to char to avoid compiler warning about narrowing - at::cuda::CUDAGuard device_guard{(char)u.get_device()}; + const at::cuda::OptionalCUDAGuard device_guard(device_of(u)); auto stream = at::cuda::getCurrentCUDAStream().stream(); DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] { selective_scan_fwd_cuda(params, stream); diff --git a/csrc/moe/marlin_moe_wna16/marlin_template.h b/csrc/moe/marlin_moe_wna16/marlin_template.h index 1c255396099d5fbe3b6ac1d953a6edcece7bf098..8a913bb4a738ce0a0e848bbcf5ec560184e05013 100644 --- a/csrc/moe/marlin_moe_wna16/marlin_template.h +++ b/csrc/moe/marlin_moe_wna16/marlin_template.h @@ -1255,8 +1255,6 @@ __global__ void Marlin( if constexpr (has_zp && !is_zp_float) { if (is_new_zp) { if constexpr (group_blocks == -1) is_first_matmul_in_slice = false; - FragB frag_zp_0; - FragB frag_zp_1; int zp_quant_0, zp_quant_1; if constexpr (w_type.size_bits() == 4) { diff --git a/csrc/moe/moe_align_sum_kernels.cu b/csrc/moe/moe_align_sum_kernels.cu index 6b6a9d04a60f40df4542f0552f587d2347767cfc..462dbd1f8b380ac2b0771bec008bccf1d372cf8c 100644 --- a/csrc/moe/moe_align_sum_kernels.cu +++ b/csrc/moe/moe_align_sum_kernels.cu @@ -13,232 +13,45 @@ namespace vllm { namespace moe { -namespace { -__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, - int32_t col) { - // don't worry about overflow because num_experts is relatively small - return row * total_col + col; -} -} // namespace - -template -__global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, - int32_t* sorted_token_ids, - int32_t* expert_ids, - int32_t* total_tokens_post_pad, - int32_t num_experts, - int32_t block_size, size_t numel) { - const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); - const size_t start_idx = threadIdx.x * tokens_per_thread; - - extern __shared__ int32_t shared_mem[]; - int32_t* cumsum = shared_mem; // 1d tensor with shape (num_experts + 1) - token_cnts_t* tokens_cnts = - (token_cnts_t*)(shared_mem + num_experts + - 1); // 2d tensor with shape (blockDim.x + 1, num_experts) - - for (int i = 0; i < num_experts; ++i) { - tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0; - } - - /** - * In the first step we compute token_cnts[thread_index + 1][expert_index], - * which counts how many tokens in the token shard of thread_index are - * assigned to expert expert_index. - */ - for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { - ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])]; - } - - __syncthreads(); - - // For each expert we accumulate the token counts from the different threads. - if (threadIdx.x < num_experts) { - tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0; - for (int i = 1; i <= blockDim.x; ++i) { - tokens_cnts[index(num_experts, i, threadIdx.x)] += - tokens_cnts[index(num_experts, i - 1, threadIdx.x)]; - } - } - - __syncthreads(); - - // We accumulate the token counts of all experts in thread 0. - if (threadIdx.x == 0) { - cumsum[0] = 0; - for (int i = 1; i <= num_experts; ++i) { - cumsum[i] = cumsum[i - 1] + - CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], - block_size) * - block_size; - } - *total_tokens_post_pad = static_cast(cumsum[num_experts]); - } - - __syncthreads(); - - /** - * For each expert, each thread processes the tokens of the corresponding - * blocks and stores the corresponding expert_id for each block. - */ - if (threadIdx.x < num_experts) { - for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; - i += block_size) { - expert_ids[i / block_size] = threadIdx.x; - } - } - - /** - * Each thread processes a token shard, calculating the index of each token - * after sorting by expert number. Given the example topk_ids = - * [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *, - * *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a - * padding value(preset in python). - */ - for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { - int32_t expert_id = topk_ids[i]; - /** The cumsum[expert_id] stores the starting index of the tokens that the - * expert with expert_id needs to process, and - * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens - * processed by the expert with expert_id within the current thread's token - * shard. - */ - int32_t rank_post_pad = - tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + - cumsum[expert_id]; - sorted_token_ids[rank_post_pad] = i; - ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)]; - } -} - -// TODO(simon): this is temporarily adapted from -// https://github.com/sgl-project/sglang/commit/31548116a8dc8c6df7e146e0587335a59fc5b9d7 -// we did this to unblock Deepseek V3 but there should be a better -// implementation to manage shared memory. -template -__global__ void moe_align_block_size_global_mem_kernel( - scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids, - int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts, - int32_t block_size, size_t numel, int32_t* tokens_cnts, int32_t* cumsum) { - const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); - const size_t start_idx = threadIdx.x * tokens_per_thread; - - for (int i = 0; i < num_experts; ++i) { - tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0; - } - - /** - * In the first step we compute token_cnts[thread_index + 1][expert_index], - * which counts how many tokens in the token shard of thread_index are - * assigned to expert expert_index. - */ - for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { - ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])]; - } - - __syncthreads(); - - // For each expert we accumulate the token counts from the different threads. - if (threadIdx.x < num_experts) { - tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0; - for (int i = 1; i <= blockDim.x; ++i) { - tokens_cnts[index(num_experts, i, threadIdx.x)] += - tokens_cnts[index(num_experts, i - 1, threadIdx.x)]; - } - } - - __syncthreads(); - - // We accumulate the token counts of all experts in thread 0. - if (threadIdx.x == 0) { - cumsum[0] = 0; - for (int i = 1; i <= num_experts; ++i) { - cumsum[i] = cumsum[i - 1] + - CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], - block_size) * - block_size; - } - *total_tokens_post_pad = cumsum[num_experts]; - } - - __syncthreads(); - - /** - * For each expert, each thread processes the tokens of the corresponding - * blocks and stores the corresponding expert_id for each block. - */ - if (threadIdx.x < num_experts) { - for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; - i += block_size) { - expert_ids[i / block_size] = threadIdx.x; - } - } - - /** - * Each thread processes a token shard, calculating the index of each token - * after sorting by expert number. Given the example topk_ids = - * [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *, - * *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a - * padding value(preset in python). - */ - for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { - int32_t expert_id = topk_ids[i]; - /** The cumsum[expert_id] stores the starting index of the tokens that the - * expert with expert_id needs to process, and - * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens - * processed by the expert with expert_id within the current thread's token - * shard. - */ - int32_t rank_post_pad = - tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + - cumsum[expert_id]; - sorted_token_ids[rank_post_pad] = i; - ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)]; - } -} - -// taken from -// https://github.com/sgl-project/sglang/commit/cdae77b03dfc6fec3863630550b45bbfc789f957 template -__global__ void sgl_moe_align_block_size_kernel( - scalar_t* __restrict__ topk_ids, int32_t* sorted_token_ids, - int32_t* expert_ids, int32_t* total_tokens_post_pad, int32_t num_experts, - int32_t block_size, size_t numel, int32_t* cumsum) { - __shared__ int32_t shared_counts[32][8]; - - const int warp_id = threadIdx.x / 32; - const int experts_per_warp = 8; +__global__ void moe_align_block_size_kernel( + const scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids, + int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts, + int32_t padded_num_experts, int32_t experts_per_warp, int32_t block_size, + size_t numel, int32_t* __restrict__ cumsum) { + extern __shared__ int32_t shared_counts[]; + + const int warp_id = threadIdx.x / WARP_SIZE; const int my_expert_start = warp_id * experts_per_warp; - // Initialize shared_counts for this warp's experts for (int i = 0; i < experts_per_warp; ++i) { - if (my_expert_start + i < num_experts) { - shared_counts[warp_id][i] = 0; + if (my_expert_start + i < padded_num_experts) { + shared_counts[warp_id * experts_per_warp + i] = 0; } } __syncthreads(); - const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); - const size_t start_idx = threadIdx.x * tokens_per_thread; + const size_t tid = threadIdx.x; + const size_t stride = blockDim.x; - for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { + for (size_t i = tid; i < numel; i += stride) { int expert_id = topk_ids[i]; int warp_idx = expert_id / experts_per_warp; int expert_offset = expert_id % experts_per_warp; - atomicAdd(&shared_counts[warp_idx][expert_offset], 1); + atomicAdd(&shared_counts[warp_idx * experts_per_warp + expert_offset], 1); } __syncthreads(); - // Single thread computes cumulative sum and total tokens if (threadIdx.x == 0) { cumsum[0] = 0; for (int i = 1; i <= num_experts; ++i) { int expert_count = 0; int warp_idx = (i - 1) / experts_per_warp; int expert_offset = (i - 1) % experts_per_warp; - expert_count = shared_counts[warp_idx][expert_offset]; + expert_count = shared_counts[warp_idx * experts_per_warp + expert_offset]; cumsum[i] = cumsum[i - 1] + CEILDIV(expert_count, block_size) * block_size; @@ -248,7 +61,6 @@ __global__ void sgl_moe_align_block_size_kernel( __syncthreads(); - // Assign expert IDs to blocks if (threadIdx.x < num_experts) { for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; i += block_size) { @@ -257,13 +69,11 @@ __global__ void sgl_moe_align_block_size_kernel( } } -// taken from -// https://github.com/sgl-project/sglang/commit/cdae77b03dfc6fec3863630550b45bbfc789f957 template -__global__ void sgl_moe_token_sort_kernel(scalar_t* __restrict__ topk_ids, - int32_t* sorted_token_ids, - int32_t* cumsum_buffer, - size_t numel) { +__global__ void count_and_sort_expert_tokens_kernel( + const scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ cumsum_buffer, + size_t numel) { const size_t tid = blockIdx.x * blockDim.x + threadIdx.x; const size_t stride = blockDim.x * gridDim.x; @@ -290,132 +100,138 @@ __global__ void moe_sum_kernel( } } +template +__global__ void moe_align_block_size_small_batch_expert_kernel( + const scalar_t* __restrict__ topk_ids, + int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids, + int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts, + int32_t block_size, size_t numel) { + const size_t tid = threadIdx.x; + const size_t stride = blockDim.x; + + extern __shared__ int32_t shared_mem[]; + int32_t* cumsum = shared_mem; + int32_t* tokens_cnts = (int32_t*)(shared_mem + num_experts + 1); + + for (int i = 0; i < num_experts; ++i) { + tokens_cnts[(threadIdx.x + 1) * num_experts + i] = 0; + } + + for (size_t i = tid; i < numel; i += stride) { + ++tokens_cnts[(threadIdx.x + 1) * num_experts + topk_ids[i]]; + } + + __syncthreads(); + + if (threadIdx.x < num_experts) { + tokens_cnts[threadIdx.x] = 0; + for (int i = 1; i <= blockDim.x; ++i) { + tokens_cnts[i * num_experts + threadIdx.x] += + tokens_cnts[(i - 1) * num_experts + threadIdx.x]; + } + } + + __syncthreads(); + + if (threadIdx.x == 0) { + cumsum[0] = 0; + for (int i = 1; i <= num_experts; ++i) { + cumsum[i] = + cumsum[i - 1] + + CEILDIV(tokens_cnts[blockDim.x * num_experts + i - 1], block_size) * + block_size; + } + *total_tokens_post_pad = static_cast(cumsum[num_experts]); + } + + __syncthreads(); + + if (threadIdx.x < num_experts) { + for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; + i += block_size) { + expert_ids[i / block_size] = threadIdx.x; + } + } + + for (size_t i = tid; i < numel; i += stride) { + int32_t expert_id = topk_ids[i]; + int32_t rank_post_pad = + tokens_cnts[threadIdx.x * num_experts + expert_id] + cumsum[expert_id]; + sorted_token_ids[rank_post_pad] = i; + ++tokens_cnts[threadIdx.x * num_experts + expert_id]; + } +} + } // namespace moe } // namespace vllm +// taken from +// https://github.com/sgl-project/sglang/blob/8b5f83ed3b7d2a49ad5c5cd5aa61c5d502f47dbc void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size, torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad) { const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - int device_max_shared_mem; - auto dev = topk_ids.get_device(); - cudaDeviceGetAttribute(&device_max_shared_mem, - cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); - - const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE); - const int32_t shared_mem_i32 = - ((num_thread + 1) * num_experts + (num_experts + 1)) * sizeof(int32_t); - const int32_t shared_mem_i16 = - ((num_thread + 1) * num_experts) * sizeof(uint16_t) + - (num_experts + 1) * sizeof(int32_t); - - bool use_global_memory = false; - bool use_i16 = false; // Use uint16_t for shared memory token counts - if (shared_mem_i32 < device_max_shared_mem) { - // Do nothing in this case. We're all set to use int32_t token counts - } else if (shared_mem_i16 < device_max_shared_mem && - topk_ids.numel() <= 65535) { - // when nelements of topk_ids is smaller than 65535 (max value of uint16), - // element value of token_cnts would also smaller than 65535, - // so we can use uint16 as dtype of token_cnts - use_i16 = true; - } else { - use_global_memory = true; - } + int64_t padded_num_experts = + ((num_experts + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; + int experts_per_warp = WARP_SIZE; + int threads = 1024; + threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE; - if (use_global_memory) { - VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES( - topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] { - // calc needed amount of shared mem for `tokens_cnts` and `cumsum` - // tensors - const int32_t num_thread = max((int32_t)num_experts, WARP_SIZE); - - auto options_int = torch::TensorOptions() - .dtype(torch::kInt) - .device(topk_ids.device()); - torch::Tensor token_cnts_buffer = - torch::empty({(num_experts + 1) * num_experts}, options_int); - torch::Tensor cumsum_buffer = - torch::empty({num_experts + 1}, options_int); - - auto kernel = - vllm::moe::moe_align_block_size_global_mem_kernel; - kernel<<<1, num_thread, 0, stream>>>( + VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES( + topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { + // calc needed amount of shared mem for `cumsum` tensors + auto options_int = + torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device()); + torch::Tensor cumsum_buffer = + torch::zeros({num_experts + 1}, options_int); + bool small_batch_expert_mode = + (topk_ids.numel() < 1024) && (num_experts <= 64); + + if (small_batch_expert_mode) { + const int32_t threads = max((int32_t)num_experts, WARP_SIZE); + const int32_t shared_mem_size = + ((threads + 1) * num_experts + (num_experts + 1)) * + sizeof(int32_t); + + auto small_batch_expert_kernel = + vllm::moe::moe_align_block_size_small_batch_expert_kernel< + scalar_t>; + small_batch_expert_kernel<<<1, threads, shared_mem_size, stream>>>( topk_ids.data_ptr(), sorted_token_ids.data_ptr(), experts_ids.data_ptr(), num_tokens_post_pad.data_ptr(), num_experts, block_size, - topk_ids.numel(), token_cnts_buffer.data_ptr(), - cumsum_buffer.data_ptr()); - }); - } else if (use_i16) { - VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES( - topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { - // set dynamic shared mem - auto kernel = - vllm::moe::moe_align_block_size_kernel; - AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( - (void*)kernel, shared_mem_i16)); - kernel<<<1, num_thread, shared_mem_i16, stream>>>( + topk_ids.numel()); + } else { + auto align_kernel = vllm::moe::moe_align_block_size_kernel; + + size_t num_warps = CEILDIV(padded_num_experts, experts_per_warp); + size_t shared_mem_size = + num_warps * experts_per_warp * sizeof(int32_t); + + align_kernel<<<1, threads, shared_mem_size, stream>>>( topk_ids.data_ptr(), sorted_token_ids.data_ptr(), experts_ids.data_ptr(), - num_tokens_post_pad.data_ptr(), num_experts, block_size, - topk_ids.numel()); - }); - } else { - VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES( - topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { - auto kernel = - vllm::moe::moe_align_block_size_kernel; - AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( - (void*)kernel, shared_mem_i32)); - kernel<<<1, num_thread, shared_mem_i32, stream>>>( + num_tokens_post_pad.data_ptr(), num_experts, + padded_num_experts, experts_per_warp, block_size, + topk_ids.numel(), cumsum_buffer.data_ptr()); + + const int block_threads = std::min(256, (int)threads); + const int num_blocks = + (topk_ids.numel() + block_threads - 1) / block_threads; + const int max_blocks = 65535; + const int actual_blocks = std::min(num_blocks, max_blocks); + + auto sort_kernel = + vllm::moe::count_and_sort_expert_tokens_kernel; + sort_kernel<<>>( topk_ids.data_ptr(), sorted_token_ids.data_ptr(), - experts_ids.data_ptr(), - num_tokens_post_pad.data_ptr(), num_experts, block_size, - topk_ids.numel()); - }); - } -} - -void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, - int64_t block_size, - torch::Tensor sorted_token_ids, - torch::Tensor experts_ids, - torch::Tensor num_tokens_post_pad) { - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - TORCH_CHECK(num_experts == 256, - "sgl_moe_align_block_size kernel only supports deepseek v3."); - - VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES( - topk_ids.scalar_type(), "sgl_moe_align_block_size_kernel", [&] { - // calc needed amount of shared mem for `cumsum` tensors - auto options_int = - torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device()); - torch::Tensor cumsum_buffer = - torch::zeros({num_experts + 1}, options_int); - - auto align_kernel = - vllm::moe::sgl_moe_align_block_size_kernel; - align_kernel<<<1, 1024, 0, stream>>>( - topk_ids.data_ptr(), sorted_token_ids.data_ptr(), - experts_ids.data_ptr(), - num_tokens_post_pad.data_ptr(), num_experts, block_size, - topk_ids.numel(), cumsum_buffer.data_ptr()); - - const int block_threads = 256; - const int num_blocks = - (topk_ids.numel() + block_threads - 1) / block_threads; - const int max_blocks = 65535; - const int actual_blocks = std::min(num_blocks, max_blocks); - auto sort_kernel = vllm::moe::sgl_moe_token_sort_kernel; - sort_kernel<<>>( - topk_ids.data_ptr(), sorted_token_ids.data_ptr(), - cumsum_buffer.data_ptr(), topk_ids.numel()); + cumsum_buffer.data_ptr(), topk_ids.numel()); + } }); } @@ -423,7 +239,7 @@ void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size] torch::Tensor& output) // [num_tokens, hidden_size] { const int hidden_size = input.size(-1); - const int num_tokens = output.numel() / hidden_size; + const auto num_tokens = output.numel() / hidden_size; const int topk = input.size(1); dim3 grid(num_tokens); diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index c4faef731060a6d60fdfacb4d5f307bcf8d4e452..661730c96867edd7b8962a3756ca42210bf9608f 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -12,12 +12,6 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, int64_t block_size, torch::Tensor sorted_token_ids, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad); - -void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, - int64_t block_size, - torch::Tensor sorted_token_ids, - torch::Tensor experts_ids, - torch::Tensor num_tokens_post_pad); #ifndef USE_ROCM torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output, torch::Tensor b_qweight, torch::Tensor b_scales, diff --git a/csrc/moe/moe_permute_unpermute_op.cu b/csrc/moe/moe_permute_unpermute_op.cu index 68f429fac18ab892bd8a7ebc7369841512f96773..a77471a7f207884668d943b171f24f5b2a66bc24 100644 --- a/csrc/moe/moe_permute_unpermute_op.cu +++ b/csrc/moe/moe_permute_unpermute_op.cu @@ -12,7 +12,7 @@ void moe_permute( const torch::Tensor& input, // [n_token, hidden] const torch::Tensor& topk_weights, //[n_token, topk] torch::Tensor& topk_ids, // [n_token, topk] - const torch::Tensor& token_expert_indicies, // [n_token, topk] + const torch::Tensor& token_expert_indices, // [n_token, topk] const std::optional& expert_map, // [n_expert] int64_t n_expert, int64_t n_local_expert, int64_t topk, const std::optional& align_block_size, @@ -27,15 +27,15 @@ void moe_permute( "expert_first_token_offset must be int64"); TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int, "topk_ids must be int32"); - TORCH_CHECK(token_expert_indicies.scalar_type() == at::ScalarType::Int, - "token_expert_indicies must be int32"); + TORCH_CHECK(token_expert_indices.scalar_type() == at::ScalarType::Int, + "token_expert_indices must be int32"); TORCH_CHECK(src_row_id2dst_row_id_map.scalar_type() == at::ScalarType::Int, "src_row_id2dst_row_id_map must be int32"); TORCH_CHECK(expert_first_token_offset.size(0) == n_local_expert + 1, "expert_first_token_offset shape != n_local_expert+1") TORCH_CHECK( - src_row_id2dst_row_id_map.sizes() == token_expert_indicies.sizes(), - "token_expert_indicies shape must be same as src_row_id2dst_row_id_map"); + src_row_id2dst_row_id_map.sizes() == token_expert_indices.sizes(), + "token_expert_indices shape must be same as src_row_id2dst_row_id_map"); auto n_token = input.sizes()[0]; auto n_hidden = input.sizes()[1]; auto align_block_size_value = @@ -71,7 +71,7 @@ void moe_permute( expert_map_ptr, n_expert, stream); } // expert sort topk expert id and scan expert id get expert_first_token_offset - sortAndScanExpert(get_ptr(topk_ids), get_ptr(token_expert_indicies), + sortAndScanExpert(get_ptr(topk_ids), get_ptr(token_expert_indices), get_ptr(permuted_experts_id), get_ptr(dst_row_id2src_row_id_map), get_ptr(expert_first_token_offset), n_token, @@ -190,7 +190,7 @@ void shuffle_rows(const torch::Tensor& input_tensor, void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_weights, torch::Tensor& topk_ids, - const torch::Tensor& token_expert_indicies, + const torch::Tensor& token_expert_indices, const std::optional& expert_map, int64_t n_expert, int64_t n_local_expert, int64_t topk, const std::optional& align_block_size, @@ -203,7 +203,7 @@ void moe_permute(const torch::Tensor& input, const torch::Tensor& topk_weights, void moe_unpermute(const torch::Tensor& input, const torch::Tensor& topk_weights, torch::Tensor& topk_ids, - const torch::Tensor& token_expert_indicies, + const torch::Tensor& token_expert_indices, const std::optional& expert_map, int64_t n_expert, int64_t n_local_expert, int64_t topk, const std::optional& align_block_size, diff --git a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl index 42441800fb1107fe0b1a97f9d9a7d1e205ec1452..ad0d390665a00aa0e292ca9d449193688e6ebeb1 100644 --- a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl +++ b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl @@ -20,7 +20,6 @@ __global__ void expandInputRowsKernel( int expert_id = sorted_experts[expanded_dest_row]; extern __shared__ int64_t smem_expert_first_token_offset[]; - int64_t align_expanded_row_accumulate = 0; if constexpr (ALIGN_BLOCK_SIZE) { // load g2s for (int idx = threadIdx.x; idx < num_local_experts + 1; @@ -63,7 +62,6 @@ __global__ void expandInputRowsKernel( using DataElem = cutlass::Array; // Duplicate and permute rows - int64_t const source_k_rank = expanded_source_row / num_rows; int64_t const source_row = expanded_source_row % num_rows; auto const* source_row_ptr = @@ -160,7 +158,6 @@ __global__ void finalizeMoeRoutingKernel( elem_index += stride) { ComputeElem thread_output; thread_output.fill(0); - float row_rescale{0.f}; for (int k_idx = 0; k_idx < k; ++k_idx) { int64_t const expanded_original_row = original_row + k_idx * num_rows; int64_t const expanded_permuted_row = @@ -177,8 +174,6 @@ __global__ void finalizeMoeRoutingKernel( auto const* expanded_permuted_rows_row_ptr = expanded_permuted_rows_v + expanded_permuted_row * num_elems_in_col; - int64_t const expert_idx = expert_for_source_row[k_offset]; - ComputeElem expert_result = arrayConvert( expanded_permuted_rows_row_ptr[elem_index]); thread_output = thread_output + row_scale * (expert_result); diff --git a/csrc/moe/topk_softmax_kernels.cu b/csrc/moe/topk_softmax_kernels.cu index 10be47966f61189e995d45a8010a126edbfcc34c..064b76c9cd42730c4ff681938945980493988231 100644 --- a/csrc/moe/topk_softmax_kernels.cu +++ b/csrc/moe/topk_softmax_kernels.cu @@ -425,7 +425,7 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f #define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB) \ topkGatingSoftmaxLauncherHelper( \ - gating_output, nullptr, topk_weights, topk_indicies, \ + gating_output, nullptr, topk_weights, topk_indices, \ token_expert_indices, num_tokens, topk, 0, num_experts, \ stream); @@ -433,7 +433,7 @@ template void topkGatingSoftmaxKernelLauncher( const float* gating_output, float* topk_weights, - IndType* topk_indicies, + IndType* topk_indices, int* token_expert_indices, float* softmax_workspace, const int num_tokens, @@ -476,7 +476,7 @@ void topkGatingSoftmaxKernelLauncher( moeSoftmax<<>>( gating_output, nullptr, softmax_workspace, num_experts); moeTopK<<>>( - softmax_workspace, nullptr, topk_weights, topk_indicies, token_expert_indices, + softmax_workspace, nullptr, topk_weights, topk_indices, token_expert_indices, num_experts, topk, 0, num_experts); } } @@ -492,7 +492,7 @@ void topk_softmax( torch::Tensor& gating_output) // [num_tokens, num_experts] { const int num_experts = gating_output.size(-1); - const int num_tokens = gating_output.numel() / num_experts; + const auto num_tokens = gating_output.numel() / num_experts; const int topk = topk_weights.size(-1); const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0); diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index a74eb3720cf1cf48433813da0a21ab0d79c87521..97df311d04409c5d288a364d4515788ba60cc196 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -22,15 +22,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { " Tensor! num_tokens_post_pad) -> ()"); m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size); - // temporarily adapted from - // https://github.com/sgl-project/sglang/commit/ded9fcd09a43d5e7d5bb31a2bc3e9fc21bf65d2a - m.def( - "sgl_moe_align_block_size(Tensor topk_ids, int num_experts," - " int block_size, Tensor! sorted_token_ids," - " Tensor! experts_ids," - " Tensor! num_tokens_post_pad) -> ()"); - m.impl("sgl_moe_align_block_size", torch::kCUDA, &sgl_moe_align_block_size); - #ifndef USE_ROCM m.def( "moe_wna16_gemm(Tensor input, Tensor! output, Tensor b_qweight, " @@ -66,7 +57,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { m.def( "moe_permute(Tensor input, Tensor topk_weight, Tensor! topk_ids," - "Tensor token_expert_indicies, Tensor? expert_map, int n_expert," + "Tensor token_expert_indices, Tensor? expert_map, int n_expert," "int n_local_expert," "int topk, int? align_block_size,Tensor! permuted_input, Tensor! " "expert_first_token_offset, Tensor! src_row_id2dst_row_id_map, Tensor! " diff --git a/csrc/ops.h b/csrc/ops.h index 6b3d50ae8bfd85ae22286814a8e6a8b01da2a297..1190db2ab8e74a1085c8af98fd2685fc91706fed 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -362,3 +362,14 @@ std::tuple allocate_shared_buffer_and_handle( int64_t size); int64_t open_mem_handle(torch::Tensor& mem_handle); void free_shared_buffer(int64_t buffer); + +#ifdef USE_ROCM +fptr_t init_custom_qr(int64_t rank, int64_t world_size, + std::optional qr_max_size = std::nullopt); +void qr_destroy(fptr_t _fa); +torch::Tensor qr_get_handle(fptr_t _fa); +void qr_open_handles(fptr_t _fa, const std::vector& handles); +void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, + int64_t quant_level, bool cast_bf2half = false); +int64_t qr_max_size(); +#endif \ No newline at end of file diff --git a/csrc/prepare_inputs/advance_step.cu b/csrc/prepare_inputs/advance_step.cu index fea4bc2ca0d8fadfc6595da321342d81e9fd7a66..3d5077d9de46126ae5c3d2f08ce8e8b24ef8e044 100644 --- a/csrc/prepare_inputs/advance_step.cu +++ b/csrc/prepare_inputs/advance_step.cu @@ -274,7 +274,6 @@ void advance_step_flashinfer( cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev); cudaDeviceGetAttribute(&threads, cudaDevAttrMaxThreadsPerBlock, dev); - [[maybe_unused]] int block_tables_stride = block_tables.stride(0); TORCH_CHECK((blocks * threads > num_queries), "multi-step: not enough threads to map to num_queries = ", num_queries, " block_tables.stride(0) = ", block_tables.stride(0), diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index bf46cce60a233909205140dfa66680f75a7720c6..5cd2ac179768b37cb84f5490351d7b43b4cd0597 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -1,15 +1,17 @@ #include #include + #include #include "../../dispatch_utils.h" +#include "../vectorization_utils.cuh" #ifndef USE_ROCM - #include #include + #include #else - #include #include + #include #endif static inline __device__ int8_t float_to_int8_rn(float x) { @@ -103,134 +105,172 @@ static inline __device__ int8_t int32_to_int8(int32_t x) { namespace vllm { -template +template __global__ void static_scaled_int8_quant_kernel( - scalar_t const* __restrict__ input, int8_t* __restrict__ out, - scale_type const* scale_ptr, const int hidden_size) { - int const tid = threadIdx.x; - int64_t const token_idx = blockIdx.x; - scale_type const scale = *scale_ptr; + const scalar_t* __restrict__ input, int8_t* __restrict__ output, + const scale_t* scale_ptr, const int hidden_size) { + const int tid = threadIdx.x; + const int stride = blockDim.x; + const int64_t token_idx = blockIdx.x; + const float scale = *scale_ptr; // Must be performed using 64-bit math to avoid integer overflow. - out += token_idx * hidden_size; - input += token_idx * hidden_size; + const scalar_t* row_in = input + token_idx * hidden_size; + int8_t* row_out = output + token_idx * hidden_size; - for (int i = tid; i < hidden_size; i += blockDim.x) { - out[i] = float_to_int8_rn(static_cast(input[i]) / scale); - } + vectorize_with_alignment<16>( + row_in, row_out, hidden_size, tid, stride, + [=] __device__(int8_t& dst, const scalar_t& src) { + dst = float_to_int8_rn(static_cast(src) / scale); + }); } -template +template __global__ void static_scaled_int8_azp_quant_kernel( - scalar_t const* __restrict__ input, int8_t* __restrict__ out, - scale_type const* scale_ptr, azp_type const* azp_ptr, - const int hidden_size) { - int const tid = threadIdx.x; - int64_t const token_idx = blockIdx.x; - scale_type const scale = *scale_ptr; - azp_type const azp = *azp_ptr; + const scalar_t* __restrict__ input, int8_t* __restrict__ output, + const scale_t* scale_ptr, const azp_t* azp_ptr, const int hidden_size) { + const int tid = threadIdx.x; + const int stride = blockDim.x; + const int64_t token_idx = blockIdx.x; + const float scale = *scale_ptr; + const azp_t azp = *azp_ptr; + const float inv_s = 1.0f / scale; // Must be performed using 64-bit math to avoid integer overflow. - out += token_idx * hidden_size; - input += token_idx * hidden_size; - - for (int i = tid; i < hidden_size; i += blockDim.x) { - auto const val = static_cast(input[i]); - auto const quant_val = int32_to_int8(float_to_int32_rn(val / scale) + azp); - out[i] = quant_val; - } + const scalar_t* row_in = input + token_idx * hidden_size; + int8_t* row_out = output + token_idx * hidden_size; + + vectorize_with_alignment<16>( + row_in, row_out, hidden_size, tid, stride, + [=] __device__(int8_t& dst, const scalar_t& src) { + const auto v = static_cast(src) * inv_s; + dst = int32_to_int8(float_to_int32_rn(v) + azp); + }); } -template +template __global__ void dynamic_scaled_int8_quant_kernel( - scalar_t const* __restrict__ input, int8_t* __restrict__ out, - scale_type* scale, const int hidden_size) { - int const tid = threadIdx.x; - int64_t const token_idx = blockIdx.x; - float absmax_val = 0.0f; - float const zero = 0.0f; + const scalar_t* __restrict__ input, int8_t* __restrict__ output, + scale_t* scale_out, const int hidden_size) { + const int tid = threadIdx.x; + const int stride = blockDim.x; + const int64_t token_idx = blockIdx.x; // Must be performed using 64-bit math to avoid integer overflow. - out += token_idx * hidden_size; - input += token_idx * hidden_size; - - for (int i = tid; i < hidden_size; i += blockDim.x) { - float val = static_cast(input[i]); - val = val > zero ? val : -val; - absmax_val = val > absmax_val ? val : absmax_val; - } - - using BlockReduce = cub::BlockReduce; - __shared__ typename BlockReduce::TempStorage reduceStorage; - float const block_absmax_val_maybe = - BlockReduce(reduceStorage).Reduce(absmax_val, cub::Max{}, blockDim.x); - __shared__ float block_absmax_val; + const scalar_t* row_in = input + token_idx * hidden_size; + int8_t* row_out = output + token_idx * hidden_size; + + // calculate for absmax + float thread_max = 0.f; + vectorize_read_with_alignment<16>( + row_in, hidden_size, tid, stride, [&] __device__(const scalar_t& src) { + const float v = fabsf(static_cast(src)); + thread_max = fmaxf(thread_max, v); + }); + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmp; + float block_max = BlockReduce(tmp).Reduce(thread_max, cub::Max{}, blockDim.x); + __shared__ float absmax; if (tid == 0) { - block_absmax_val = block_absmax_val_maybe; - scale[token_idx] = block_absmax_val / 127.0f; + absmax = block_max; + scale_out[blockIdx.x] = absmax / 127.f; } __syncthreads(); - float const tmp_scale = 127.0f / block_absmax_val; - for (int i = tid; i < hidden_size; i += blockDim.x) { - out[i] = float_to_int8_rn(static_cast(input[i]) * tmp_scale); - } + float inv_s = (absmax == 0.f) ? 0.f : 127.f / absmax; + + // 2. quantize + vectorize_with_alignment<16>( + row_in, row_out, hidden_size, tid, stride, + [=] __device__(int8_t& dst, const scalar_t& src) { + dst = float_to_int8_rn(static_cast(src) * inv_s); + }); } -template -__global__ void dynamic_scaled_int8_azp_quant_kernel( - scalar_t const* __restrict__ input, int8_t* __restrict__ out, - scale_type* scale, azp_type* azp, const int hidden_size) { - int64_t const token_idx = blockIdx.x; +// MinMax structure to hold min and max values in one go +struct MinMax { + float min, max; - // Must be performed using 64-bit math to avoid integer overflow. - out += token_idx * hidden_size; - input += token_idx * hidden_size; - - // Scan for the min and max value for this token - float max_val = std::numeric_limits::min(); - float min_val = std::numeric_limits::max(); - for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) { - auto val = static_cast(input[i]); - max_val = std::max(max_val, val); - min_val = std::min(min_val, val); + __host__ __device__ MinMax() + : min(std::numeric_limits::max()), + max(std::numeric_limits::lowest()) {} + + __host__ __device__ explicit MinMax(float v) : min(v), max(v) {} + + // add a value to the MinMax + __host__ __device__ MinMax& operator+=(float v) { + min = fminf(min, v); + max = fmaxf(max, v); + return *this; } - // Reduce the max and min values across the block - using BlockReduce = cub::BlockReduce; - __shared__ typename BlockReduce::TempStorage reduceStorage; - max_val = BlockReduce(reduceStorage).Reduce(max_val, cub::Max{}, blockDim.x); - __syncthreads(); // Make sure min doesn't mess with max shared memory - min_val = BlockReduce(reduceStorage).Reduce(min_val, cub::Min{}, blockDim.x); - - __shared__ scale_type scale_sh; - __shared__ azp_type azp_sh; - - // Compute the scale and zero point and store them, only on the first thread - if (threadIdx.x == 0) { - float const scale_val = (max_val - min_val) / 255.0f; - // Use rounding to even (same as torch.round) - auto const azp_float = std::nearbyint(-128.0f - min_val / scale_val); - auto const azp_val = static_cast(azp_float); - - // Store the scale and azp into shared and global - scale[token_idx] = scale_sh = scale_val; - azp[token_idx] = azp_sh = azp_val; + // merge two MinMax objects + __host__ __device__ MinMax& operator&=(const MinMax& other) { + min = fminf(min, other.min); + max = fmaxf(max, other.max); + return *this; } +}; - // Wait for the scale and azp to be computed - __syncthreads(); +__host__ __device__ inline MinMax operator+(MinMax a, float v) { + return a += v; +} +__host__ __device__ inline MinMax operator&(MinMax a, const MinMax& b) { + return a &= b; +} - float const scale_val = scale_sh; - azp_type const azp_val = azp_sh; +template +__global__ void dynamic_scaled_int8_azp_quant_kernel( + const scalar_t* __restrict__ input, int8_t* __restrict__ output, + scale_t* scale_out, azp_t* azp_out, const int hidden_size) { + const int tid = threadIdx.x; + const int stride = blockDim.x; + const int64_t token_idx = blockIdx.x; - // Quantize the values - for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) { - auto const val = static_cast(input[i]); - auto const quant_val = - int32_to_int8(float_to_int32_rn(val / scale_val) + azp_val); - out[i] = quant_val; + // Must be performed using 64-bit math to avoid integer overflow. + const scalar_t* row_in = input + token_idx * hidden_size; + int8_t* row_out = output + token_idx * hidden_size; + + // 1. calculate min & max + MinMax thread_mm; + vectorize_read_with_alignment<16>(row_in, hidden_size, tid, stride, + [&] __device__(const scalar_t& src) { + thread_mm += static_cast(src); + }); + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmp; + + MinMax mm = BlockReduce(tmp).Reduce( + thread_mm, + [] __device__(MinMax a, const MinMax& b) { + a &= b; + return a; + }, + blockDim.x); + + __shared__ float scale_sh; + __shared__ azp_t azp_sh; + if (tid == 0) { + float s = (mm.max - mm.min) / 255.f; + float zp = nearbyintf(-128.f - mm.min / s); // round-to-even + scale_sh = s; + azp_sh = azp_t(zp); + scale_out[blockIdx.x] = s; + azp_out[blockIdx.x] = azp_sh; } + __syncthreads(); + + const float inv_s = 1.f / scale_sh; + const azp_t azp = azp_sh; + + // 2. quantize + vectorize_with_alignment<16>( + row_in, row_out, hidden_size, tid, stride, + [=] __device__(int8_t& dst, const scalar_t& src) { + const auto v = static_cast(src) * inv_s; + dst = int32_to_int8(float_to_int32_rn(v) + azp); + }); } } // namespace vllm @@ -247,7 +287,7 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] int const hidden_size = input.size(-1); int const num_tokens = input.numel() / hidden_size; dim3 const grid(num_tokens); - dim3 const block(std::min(hidden_size, 1024)); + dim3 const block(std::min(hidden_size, 256)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( input.scalar_type(), "static_scaled_int8_quant_kernel", [&] { @@ -278,7 +318,7 @@ void dynamic_scaled_int8_quant( int const hidden_size = input.size(-1); int const num_tokens = input.numel() / hidden_size; dim3 const grid(num_tokens); - dim3 const block(std::min(hidden_size, 1024)); + dim3 const block(std::min(hidden_size, 256)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] { diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh index 8f4df836bcc8d6cc9ac23b01590493084780d663..2d67da98763e034f015bba8609e60b511b4e0468 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh @@ -51,7 +51,8 @@ struct cutlass_3x_gemm { // These are the minimum alignments needed for the kernels to compile static constexpr int AlignmentAB = 128 / cutlass::sizeof_bits::value; - static constexpr int AlignmentCD = 4; + static constexpr int AlignmentCD = + 128 / cutlass::sizeof_bits::value; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< @@ -144,4 +145,65 @@ struct cutlass_3x_gemm_sm100 { Shape, CollectiveMainloop, CollectiveEpilogue, void>; }; +template typename Epilogue_, + typename TileShape, typename ClusterShape, typename KernelSchedule, + typename EpilogueSchedule> +struct cutlass_3x_gemm_sm120 { + using ElementAB = ElementAB_; + using LayoutA = cutlass::layout::RowMajor; + static constexpr int AlignmentA = + 128 / cutlass::sizeof_bits::value; + + using LayoutB = cutlass::layout::ColumnMajor; + static constexpr int AlignmentB = + 128 / cutlass::sizeof_bits::value; + + using ElementC = void; + using LayoutC = cutlass::layout::RowMajor; + static constexpr int AlignmentC = + 128 / cutlass::sizeof_bits::value; + + using ElementD = ElementD_; + using LayoutD = cutlass::layout::RowMajor; + static constexpr int AlignmentD = AlignmentC; + + using ElementAcc = + typename std::conditional, int32_t, + float>::type; + using Epilogue = Epilogue_; + + // MMA type + using ElementAccumulator = float; + + // Epilogue types + using ElementBias = cutlass::half_t; + using ElementCompute = float; + using ElementAux = ElementD; + using LayoutAux = LayoutD; + using ElementAmax = float; + + using EVTCompute = typename Epilogue::EVTCompute; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp, TileShape, + ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, EpilogueSchedule, + EVTCompute>::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp, ElementAB, + LayoutA, AlignmentA, ElementAB, LayoutB, AlignmentB, + ElementAccumulator, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, CollectiveMainloop, CollectiveEpilogue, void>; +}; + } // namespace vllm diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp index c1242fdb39da9c58b1d4fbd674631862cecf7446..e049a5f2d2c9a35c9c63f5d9960d17511c5062fe 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp @@ -36,6 +36,12 @@ void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b_scales, std::optional const& bias); +void cutlass_scaled_mm_sm120_fp8(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + std::optional const& bias); + void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh index 6da2da63407590b9a4bd80dac08db1552e764609..24564efbd21be8db286d58632c1f65c2c79f81ea 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh @@ -15,11 +15,11 @@ using c3x::cutlass_gemm_caller; template typename Epilogue> struct sm100_fp8_config_default { - // M in (128, inf) + // M in (256, inf) static_assert(std::is_same()); using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; - using TileShape = Shape<_256, _128, _64>; + using TileShape = Shape<_256, _128, _128>; using ClusterShape = Shape<_2, _2, _1>; using Cutlass3xGemm = cutlass_3x_gemm_sm100 typename Epilogue> -struct sm100_fp8_config_M128 { - // M in (64, 128] +struct sm100_fp8_config_M256 { + // M in (64, 256] static_assert(std::is_same()); using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; - using TileShape = Shape<_128, _128, _64>; - using ClusterShape = Shape<_2, _2, _1>; + using TileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_2, _1, _1>; using Cutlass3xGemm = cutlass_3x_gemm_sm100; @@ -43,12 +43,26 @@ struct sm100_fp8_config_M128 { template typename Epilogue> struct sm100_fp8_config_M64 { - // M in [1, 64] + // M in (16, 64] + static_assert(std::is_same()); + using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; + using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; + using TileShape = Shape<_64, _64, _128>; + using ClusterShape = Shape<_1, _1, _1>; + using Cutlass3xGemm = + cutlass_3x_gemm_sm100; +}; + +template typename Epilogue> +struct sm100_fp8_config_M16 { + // M in [1, 16] static_assert(std::is_same()); using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; - using TileShape = Shape<_64, _64, _256>; - using ClusterShape = Shape<_1, _8, _1>; + using TileShape = Shape<_64, _64, _128>; + using ClusterShape = Shape<_1, _4, _1>; using Cutlass3xGemm = cutlass_3x_gemm_sm100; @@ -68,25 +82,31 @@ inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out, using Cutlass3xGemmDefault = typename sm100_fp8_config_default::Cutlass3xGemm; + using Cutlass3xGemmM16 = + typename sm100_fp8_config_M16::Cutlass3xGemm; using Cutlass3xGemmM64 = typename sm100_fp8_config_M64::Cutlass3xGemm; - using Cutlass3xGemmM128 = - typename sm100_fp8_config_M128::Cutlass3xGemm; + using Cutlass3xGemmM256 = + typename sm100_fp8_config_M256::Cutlass3xGemm; uint32_t const m = a.size(0); uint32_t const mp2 = - std::max(static_cast(64), next_pow_2(m)); // next power of 2 + std::max(static_cast(16), next_pow_2(m)); // next power of 2 - if (mp2 <= 64) { - // m in [1, 64] + if (mp2 <= 16) { + // m in [1, 16] + return cutlass_gemm_caller( + out, a, b, std::forward(args)...); + } else if (mp2 <= 64) { + // m in (16, 64] return cutlass_gemm_caller( out, a, b, std::forward(args)...); - } else if (mp2 <= 128) { - // m in (64, 128] - return cutlass_gemm_caller( + } else if (mp2 <= 256) { + // m in (64, 256] + return cutlass_gemm_caller( out, a, b, std::forward(args)...); } else { - // m in (128, inf) + // m in (256, inf) return cutlass_gemm_caller( out, a, b, std::forward(args)...); } diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8.cu b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8.cu new file mode 100644 index 0000000000000000000000000000000000000000..bc816cbdf86e5363f164f06f0e5bdbc372ab03c2 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8.cu @@ -0,0 +1,24 @@ +#include "scaled_mm_kernels.hpp" +#include "scaled_mm_sm120_fp8_dispatch.cuh" +#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" + +namespace vllm { + +void cutlass_scaled_mm_sm120_fp8(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + std::optional const& bias) { + TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); + if (bias) { + TORCH_CHECK(bias->dtype() == out.dtype(), + "currently bias dtype must match output dtype ", out.dtype()); + return cutlass_scaled_mm_sm120_fp8_epilogue( + out, a, b, a_scales, b_scales, *bias); + } else { + return cutlass_scaled_mm_sm120_fp8_epilogue( + out, a, b, a_scales, b_scales); + } +} + +} // namespace vllm diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8_dispatch.cuh new file mode 100644 index 0000000000000000000000000000000000000000..c31f96bf7c0e237bd3c80c9bdffdbf5df915b4d1 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8_dispatch.cuh @@ -0,0 +1,67 @@ +#pragma once + +#include "scaled_mm.cuh" +#include "cutlass_gemm_caller.cuh" + +/** + * This file defines Gemm kernel configurations for SM120 (fp8) based on the + * Gemm shape. + */ + +namespace vllm { + +using c3x::cutlass_gemm_caller; + +template typename Epilogue> +struct sm120_fp8_config_default { + static_assert(std::is_same()); + using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; + using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; + using TileShape = Shape<_128, _128, _128>; + using ClusterShape = Shape<_1, _1, _1>; // Only work with Shape<_1, _1, _1> + using Cutlass3xGemm = + cutlass_3x_gemm_sm120; +}; + +template typename Epilogue, + typename... EpilogueArgs> +inline void cutlass_gemm_sm120_fp8_dispatch(torch::Tensor& out, + torch::Tensor const& a, + torch::Tensor const& b, + EpilogueArgs&&... args) { + static_assert(std::is_same()); + TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); + + using Cutlass3xGemmDefault = + typename sm120_fp8_config_default::Cutlass3xGemm; + return cutlass_gemm_caller( + out, a, b, std::forward(args)...); +} + +template