diff --git a/.buildkite/check-wheel-size.py b/.buildkite/check-wheel-size.py index a378bc6baa5a5e8b4327654ffd3d595445f1365a..e29881fcbac0175b7cbcd93c82fbecd8d9d59b59 100644 --- a/.buildkite/check-wheel-size.py +++ b/.buildkite/check-wheel-size.py @@ -8,12 +8,12 @@ import zipfile # Note that we have 400 MiB quota, please use it wisely. # See https://github.com/pypi/support/issues/3792 . # Please also sync the value with the one in Dockerfile. -VLLM_MAX_SIZE_MB = int(os.environ.get('VLLM_MAX_SIZE_MB', 400)) +VLLM_MAX_SIZE_MB = int(os.environ.get("VLLM_MAX_SIZE_MB", 400)) def print_top_10_largest_files(zip_file): """Print the top 10 largest files in the given zip file.""" - with zipfile.ZipFile(zip_file, 'r') as z: + with zipfile.ZipFile(zip_file, "r") as z: file_sizes = [(f, z.getinfo(f).file_size) for f in z.namelist()] file_sizes.sort(key=lambda x: x[1], reverse=True) for f, size in file_sizes[:10]: @@ -28,14 +28,18 @@ def check_wheel_size(directory): wheel_path = os.path.join(root, file_name) wheel_size_mb = os.path.getsize(wheel_path) / (1024 * 1024) if wheel_size_mb > VLLM_MAX_SIZE_MB: - print(f"Not allowed: Wheel {wheel_path} is larger " - f"({wheel_size_mb:.2f} MB) than the limit " - f"({VLLM_MAX_SIZE_MB} MB).") + print( + f"Not allowed: Wheel {wheel_path} is larger " + f"({wheel_size_mb:.2f} MB) than the limit " + f"({VLLM_MAX_SIZE_MB} MB)." + ) print_top_10_largest_files(wheel_path) return 1 else: - print(f"Wheel {wheel_path} is within the allowed size " - f"({wheel_size_mb:.2f} MB).") + print( + f"Wheel {wheel_path} is within the allowed size " + f"({wheel_size_mb:.2f} MB)." + ) return 0 @@ -45,4 +49,4 @@ if __name__ == "__main__": sys.exit(1) directory = sys.argv[1] - sys.exit(check_wheel_size(directory)) \ No newline at end of file + sys.exit(check_wheel_size(directory)) diff --git a/.buildkite/generate_index.py b/.buildkite/generate_index.py index 36e1b6c01326aa136e3cbb3cf2f585697f77a50e..270663c415c7206b82bde00377da3f45ecc08b70 100644 --- a/.buildkite/generate_index.py +++ b/.buildkite/generate_index.py @@ -22,5 +22,5 @@ with open("index.html", "w") as f: print(f"Generated index.html for {args.wheel}") # cloudfront requires escaping the '+' character f.write( - template.format(wheel=filename, - wheel_html_escaped=filename.replace("+", "%2B"))) + template.format(wheel=filename, wheel_html_escaped=filename.replace("+", "%2B")) + ) diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3.2-1B-Instruct-FP8-compressed-tensors.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3.2-1B-Instruct-FP8-compressed-tensors.yaml new file mode 100644 index 0000000000000000000000000000000000000000..cca58097e8aa6d9ecc76418a81bfbc6cc39b4642 --- /dev/null +++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3.2-1B-Instruct-FP8-compressed-tensors.yaml @@ -0,0 +1,11 @@ +# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m RedHatAI/Llama-3.2-1B-Instruct-FP8 -b "auto" -l 1319 -f 5 -t 1 +model_name: "RedHatAI/Llama-3.2-1B-Instruct-FP8" +tasks: +- name: "gsm8k" + metrics: + - name: "exact_match,strict-match" + value: 0.335 + - name: "exact_match,flexible-extract" + value: 0.323 +limit: 1319 +num_fewshot: 5 diff --git a/.buildkite/lm-eval-harness/configs/Qwen2.5-1.5B-Instruct.yaml b/.buildkite/lm-eval-harness/configs/Qwen2.5-1.5B-Instruct.yaml new file mode 100644 index 0000000000000000000000000000000000000000..54579a63a9b864ac08937c88d3a189a01282f5ad --- /dev/null +++ b/.buildkite/lm-eval-harness/configs/Qwen2.5-1.5B-Instruct.yaml @@ -0,0 +1,11 @@ +# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m Qwen/Qwen2.5-1.5B-Instruct -b auto -l 1319 -f 5 -t 1 +model_name: "Qwen/Qwen2.5-1.5B-Instruct" +tasks: +- name: "gsm8k" + metrics: + - name: "exact_match,strict-match" + value: 0.54 + - name: "exact_match,flexible-extract" + value: 0.59 +limit: 1319 +num_fewshot: 5 diff --git a/.buildkite/lm-eval-harness/configs/Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml b/.buildkite/lm-eval-harness/configs/Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a2f235f485815848db1106809d714fd519f24e60 --- /dev/null +++ b/.buildkite/lm-eval-harness/configs/Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml @@ -0,0 +1,11 @@ +# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m RedHatAI/Qwen2.5-VL-3B-Instruct-FP8-Dynamic -b auto -l 1319 -f 5 -t 1 +model_name: "RedHatAI/Qwen2.5-VL-3B-Instruct-FP8-Dynamic" +tasks: +- name: "gsm8k" + metrics: + - name: "exact_match,strict-match" + value: 0.47 + - name: "exact_match,flexible-extract" + value: 0.64 +limit: 1319 +num_fewshot: 5 diff --git a/.buildkite/lm-eval-harness/configs/models-large.txt b/.buildkite/lm-eval-harness/configs/models-large.txt index 37eeac85c933b8a5a077364d0566772d2c592208..27a1a9a82bd352623c44728e4480ee47209bd9f0 100644 --- a/.buildkite/lm-eval-harness/configs/models-large.txt +++ b/.buildkite/lm-eval-harness/configs/models-large.txt @@ -3,3 +3,4 @@ Meta-Llama-3-70B-Instruct.yaml Mixtral-8x7B-Instruct-v0.1.yaml Qwen2-57B-A14-Instruct.yaml DeepSeek-V2-Lite-Chat.yaml +Meta-Llama-3-8B-QQQ.yaml diff --git a/.buildkite/lm-eval-harness/configs/models-small.txt b/.buildkite/lm-eval-harness/configs/models-small.txt index 254d01edf84492f05946dc99d3fad420450d0cea..36e0543879b3810c464a4d036d1d12b417086e39 100644 --- a/.buildkite/lm-eval-harness/configs/models-small.txt +++ b/.buildkite/lm-eval-harness/configs/models-small.txt @@ -1,10 +1,6 @@ -Meta-Llama-3-8B-Instruct.yaml -Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml +Qwen2.5-1.5B-Instruct.yaml Meta-Llama-3.2-1B-Instruct-INT8-compressed-tensors.yaml Meta-Llama-3-8B-Instruct-INT8-compressed-tensors-asym.yaml Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml -Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml +Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml Qwen1.5-MoE-W4A16-compressed-tensors.yaml -Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml -Qwen2-1.5B-Instruct-FP8W8.yaml -Meta-Llama-3-8B-QQQ.yaml diff --git a/.buildkite/lm-eval-harness/conftest.py b/.buildkite/lm-eval-harness/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..769d2efda4adc494cd9e78074b30b4a721cb279a --- /dev/null +++ b/.buildkite/lm-eval-harness/conftest.py @@ -0,0 +1,43 @@ +# SPDX-License-Identifier: Apache-2.0 +from pathlib import Path + +import pytest + + +def pytest_addoption(parser): + parser.addoption( + "--config-list-file", + action="store", + help="Path to the file listing model config YAMLs (one per line)", + ) + parser.addoption( + "--tp-size", + action="store", + default="1", + help="Tensor parallel size to use for evaluation", + ) + + +@pytest.fixture(scope="session") +def config_list_file(pytestconfig, config_dir): + rel_path = pytestconfig.getoption("--config-list-file") + return config_dir / rel_path + + +@pytest.fixture(scope="session") +def tp_size(pytestconfig): + return pytestconfig.getoption("--tp-size") + + +def pytest_generate_tests(metafunc): + if "config_filename" in metafunc.fixturenames: + rel_path = metafunc.config.getoption("--config-list-file") + config_list_file = Path(rel_path).resolve() + config_dir = config_list_file.parent + with open(config_list_file, encoding="utf-8") as f: + configs = [ + config_dir / line.strip() + for line in f + if line.strip() and not line.startswith("#") + ] + metafunc.parametrize("config_filename", configs) diff --git a/.buildkite/lm-eval-harness/run-tests.sh b/.buildkite/lm-eval-harness/run-tests.sh deleted file mode 100644 index 26f33b744289a46ce636c9b7c502ce7dba62de4d..0000000000000000000000000000000000000000 --- a/.buildkite/lm-eval-harness/run-tests.sh +++ /dev/null @@ -1,59 +0,0 @@ -#!/bin/bash - -usage() { - echo`` - echo "Runs lm eval harness on GSM8k using vllm and compares to " - echo "precomputed baseline (measured by HF transformers.)" - echo - echo "usage: ${0} " - echo - echo " -c - path to the test data config (e.g. configs/small-models.txt)" - echo " -t - tensor parallel size" - echo -} - -SUCCESS=0 - -while getopts "c:t:" OPT; do - case ${OPT} in - c ) - CONFIG="$OPTARG" - ;; - t ) - TP_SIZE="$OPTARG" - ;; - \? ) - usage - exit 1 - ;; - esac -done - -# Parse list of configs. -IFS=$'\n' read -d '' -r -a MODEL_CONFIGS < "$CONFIG" - -for MODEL_CONFIG in "${MODEL_CONFIGS[@]}" -do - LOCAL_SUCCESS=0 - - echo "=== RUNNING MODEL: $MODEL_CONFIG WITH TP SIZE: $TP_SIZE===" - - export LM_EVAL_TEST_DATA_FILE=$PWD/configs/${MODEL_CONFIG} - export LM_EVAL_TP_SIZE=$TP_SIZE - pytest -s test_lm_eval_correctness.py || LOCAL_SUCCESS=$? - - if [[ $LOCAL_SUCCESS == 0 ]]; then - echo "=== PASSED MODEL: ${MODEL_CONFIG} ===" - else - echo "=== FAILED MODEL: ${MODEL_CONFIG} ===" - fi - - SUCCESS=$((SUCCESS + LOCAL_SUCCESS)) - -done - -if [ "${SUCCESS}" -eq "0" ]; then - exit 0 -else - exit 1 -fi diff --git a/.buildkite/lm-eval-harness/test_lm_eval_correctness.py b/.buildkite/lm-eval-harness/test_lm_eval_correctness.py index 6015a83e829504b0e9a9c4c41f96f7c48a747034..409a6ca82008243ace99a3c4dc735305cdb1730c 100644 --- a/.buildkite/lm-eval-harness/test_lm_eval_correctness.py +++ b/.buildkite/lm-eval-harness/test_lm_eval_correctness.py @@ -3,67 +3,52 @@ LM eval harness on model to compare vs HF baseline computed offline. Configs are found in configs/$MODEL.yaml -* export LM_EVAL_TEST_DATA_FILE=configs/Meta-Llama-3-70B-Instruct.yaml -* export LM_EVAL_TP_SIZE=4 -* pytest -s test_lm_eval_correctness.py +pytest -s -v test_lm_eval_correctness.py \ + --config-list-file=configs/models-small.txt \ + --tp-size=1 """ -import os -from pathlib import Path - import lm_eval -import numpy -import pytest +import numpy as np import yaml RTOL = 0.08 -TEST_DATA_FILE = os.environ.get( - "LM_EVAL_TEST_DATA_FILE", - ".buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct.yaml") - -TP_SIZE = os.environ.get("LM_EVAL_TP_SIZE", 1) - -def launch_lm_eval(eval_config): - trust_remote_code = eval_config.get('trust_remote_code', False) - - model_args = f"pretrained={eval_config['model_name']}," \ - f"tensor_parallel_size={TP_SIZE}," \ - f"add_bos_token=true," \ - f"trust_remote_code={trust_remote_code}" +def launch_lm_eval(eval_config, tp_size): + trust_remote_code = eval_config.get("trust_remote_code", False) + model_args = ( + f"pretrained={eval_config['model_name']}," + f"tensor_parallel_size={tp_size}," + f"enforce_eager=true," + f"add_bos_token=true," + f"trust_remote_code={trust_remote_code}" + ) results = lm_eval.simple_evaluate( model="vllm", model_args=model_args, tasks=[task["name"] for task in eval_config["tasks"]], num_fewshot=eval_config["num_fewshot"], limit=eval_config["limit"], - batch_size="auto") - + batch_size="auto", + ) return results -def test_lm_eval_correctness(): - eval_config = yaml.safe_load( - Path(TEST_DATA_FILE).read_text(encoding="utf-8")) - - if eval_config[ - "model_name"] == "nm-testing/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform": #noqa: E501 - pytest.skip("FBGEMM is currently failing on main.") +def test_lm_eval_correctness_param(config_filename, tp_size): + eval_config = yaml.safe_load(config_filename.read_text(encoding="utf-8")) - # Launch eval requests. - results = launch_lm_eval(eval_config) + results = launch_lm_eval(eval_config, tp_size) - # Confirm scores match ground truth. success = True for task in eval_config["tasks"]: for metric in task["metrics"]: ground_truth = metric["value"] measured_value = results["results"][task["name"]][metric["name"]] - print(f'{task["name"]} | {metric["name"]}: ' - f'ground_truth={ground_truth} | measured={measured_value}') - success = success and numpy.isclose( - ground_truth, measured_value, rtol=RTOL) + print( + f"{task['name']} | {metric['name']}: " + f"ground_truth={ground_truth} | measured={measured_value}" + ) + success = success and np.isclose(ground_truth, measured_value, rtol=RTOL) - # Assert at the end, print all scores even on failure for debugging. assert success diff --git a/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py b/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py index 1030ec24e8d7fa9fe2742067e33f6d47e2acefda..7f2a2d8dc2969275bd0739da8dd8c976ad728c42 100644 --- a/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py +++ b/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py @@ -65,18 +65,18 @@ def read_markdown(file): def results_to_json(latency, throughput, serving): - return json.dumps({ - 'latency': latency.to_dict(), - 'throughput': throughput.to_dict(), - 'serving': serving.to_dict() - }) + return json.dumps( + { + "latency": latency.to_dict(), + "throughput": throughput.to_dict(), + "serving": serving.to_dict(), + } + ) if __name__ == "__main__": - # collect results for test_file in results_folder.glob("*.json"): - with open(test_file) as f: raw_result = json.loads(f.read()) @@ -120,7 +120,8 @@ if __name__ == "__main__": for perc in [10, 25, 50, 75, 90, 99]: # Multiply 1000 to convert the time unit from s to ms raw_result.update( - {f"P{perc}": 1000 * raw_result["percentiles"][str(perc)]}) + {f"P{perc}": 1000 * raw_result["percentiles"][str(perc)]} + ) raw_result["avg_latency"] = raw_result["avg_latency"] * 1000 # add the result to raw_result @@ -153,26 +154,27 @@ if __name__ == "__main__": serving_results = pd.DataFrame.from_dict(serving_results) throughput_results = pd.DataFrame.from_dict(throughput_results) - raw_results_json = results_to_json(latency_results, throughput_results, - serving_results) + raw_results_json = results_to_json( + latency_results, throughput_results, serving_results + ) # remapping the key, for visualization purpose if not latency_results.empty: - latency_results = latency_results[list( - latency_column_mapping.keys())].rename( - columns=latency_column_mapping) + latency_results = latency_results[list(latency_column_mapping.keys())].rename( + columns=latency_column_mapping + ) if not serving_results.empty: - serving_results = serving_results[list( - serving_column_mapping.keys())].rename( - columns=serving_column_mapping) + serving_results = serving_results[list(serving_column_mapping.keys())].rename( + columns=serving_column_mapping + ) if not throughput_results.empty: - throughput_results = throughput_results[list( - throughput_results_column_mapping.keys())].rename( - columns=throughput_results_column_mapping) + throughput_results = throughput_results[ + list(throughput_results_column_mapping.keys()) + ].rename(columns=throughput_results_column_mapping) - processed_results_json = results_to_json(latency_results, - throughput_results, - serving_results) + processed_results_json = results_to_json( + latency_results, throughput_results, serving_results + ) for df in [latency_results, serving_results, throughput_results]: if df.empty: @@ -184,38 +186,39 @@ if __name__ == "__main__": # The GPUs sometimes come in format of "GPUTYPE\nGPUTYPE\n...", # we want to turn it into "8xGPUTYPE" df["GPU"] = df["GPU"].apply( - lambda x: f"{len(x.split('\n'))}x{x.split('\n')[0]}") + lambda x: f"{len(x.split('\n'))}x{x.split('\n')[0]}" + ) # get markdown tables - latency_md_table = tabulate(latency_results, - headers='keys', - tablefmt='pipe', - showindex=False) - serving_md_table = tabulate(serving_results, - headers='keys', - tablefmt='pipe', - showindex=False) - throughput_md_table = tabulate(throughput_results, - headers='keys', - tablefmt='pipe', - showindex=False) + latency_md_table = tabulate( + latency_results, headers="keys", tablefmt="pipe", showindex=False + ) + serving_md_table = tabulate( + serving_results, headers="keys", tablefmt="pipe", showindex=False + ) + throughput_md_table = tabulate( + throughput_results, headers="keys", tablefmt="pipe", showindex=False + ) # document the result with open(results_folder / "benchmark_results.md", "w") as f: - - results = read_markdown("../.buildkite/nightly-benchmarks/" + - "performance-benchmarks-descriptions.md") + results = read_markdown( + "../.buildkite/nightly-benchmarks/" + + "performance-benchmarks-descriptions.md" + ) results = results.format( latency_tests_markdown_table=latency_md_table, throughput_tests_markdown_table=throughput_md_table, serving_tests_markdown_table=serving_md_table, - benchmarking_results_in_json_string=processed_results_json) + benchmarking_results_in_json_string=processed_results_json, + ) f.write(results) # document benchmarking results in json with open(results_folder / "benchmark_results.json", "w") as f: - - results = latency_results.to_dict( - orient='records') + throughput_results.to_dict( - orient='records') + serving_results.to_dict(orient='records') + results = ( + latency_results.to_dict(orient="records") + + throughput_results.to_dict(orient="records") + + serving_results.to_dict(orient="records") + ) f.write(json.dumps(results)) diff --git a/.buildkite/nightly-benchmarks/scripts/download-tokenizer.py b/.buildkite/nightly-benchmarks/scripts/download-tokenizer.py index 5e17b79d26a1ba4c735d9c61252d859c14e7eed2..778a3a8d87f63f3aba83d21d2a36b8159e7f81b9 100644 --- a/.buildkite/nightly-benchmarks/scripts/download-tokenizer.py +++ b/.buildkite/nightly-benchmarks/scripts/download-tokenizer.py @@ -14,15 +14,12 @@ def main(model, cachedir): if __name__ == "__main__": parser = argparse.ArgumentParser( - description="Download and save Hugging Face tokenizer") - parser.add_argument("--model", - type=str, - required=True, - help="Name of the model") - parser.add_argument("--cachedir", - type=str, - required=True, - help="Directory to save the tokenizer") + description="Download and save Hugging Face tokenizer" + ) + parser.add_argument("--model", type=str, required=True, help="Name of the model") + parser.add_argument( + "--cachedir", type=str, required=True, help="Directory to save the tokenizer" + ) args = parser.parse_args() main(args.model, args.cachedir) diff --git a/.buildkite/nightly-benchmarks/scripts/generate-nightly-markdown.py b/.buildkite/nightly-benchmarks/scripts/generate-nightly-markdown.py index 0ff95a0911b16d57e7137fab28ca0ebca90113e4..10a7a2f5a467e7ba9fefc08914a4af033b89b163 100644 --- a/.buildkite/nightly-benchmarks/scripts/generate-nightly-markdown.py +++ b/.buildkite/nightly-benchmarks/scripts/generate-nightly-markdown.py @@ -11,33 +11,33 @@ from tabulate import tabulate def parse_arguments(): parser = argparse.ArgumentParser( - description= - 'Parse command line arguments for summary-nightly-results script.') - parser.add_argument('--results-folder', - type=str, - required=True, - help='The folder where the results are stored.') - parser.add_argument('--description', - type=str, - required=True, - help='Description of the results.') + description="Parse command line arguments for summary-nightly-results script." + ) + parser.add_argument( + "--results-folder", + type=str, + required=True, + help="The folder where the results are stored.", + ) + parser.add_argument( + "--description", type=str, required=True, help="Description of the results." + ) args = parser.parse_args() return args def get_perf(df, method, model, metric): - means = [] for qps in [2, 4, 8, 16, "inf"]: - target = df['Test name'].str.contains(model) - target = target & df['Engine'].str.contains(method) - target = target & df['Test name'].str.contains("qps_" + str(qps)) + target = df["Test name"].str.contains(model) + target = target & df["Engine"].str.contains(method) + target = target & df["Test name"].str.contains("qps_" + str(qps)) filtered_df = df[target] if filtered_df.empty: - means.append(0.) + means.append(0.0) else: means.append(filtered_df[metric].values[0]) @@ -45,7 +45,6 @@ def get_perf(df, method, model, metric): def get_perf_w_std(df, method, model, metric): - if metric in ["TTFT", "ITL"]: mean = get_perf(df, method, model, "Mean " + metric + " (ms)") mean = mean.tolist() @@ -60,7 +59,8 @@ def get_perf_w_std(df, method, model, metric): else: assert metric == "Tput" mean = get_perf(df, method, model, "Input Tput (tok/s)") + get_perf( - df, method, model, "Output Tput (tok/s)") + df, method, model, "Output Tput (tok/s)" + ) mean = mean.tolist() std = None @@ -80,18 +80,17 @@ def main(args): # generate markdown table df = pd.DataFrame.from_dict(results) - md_table = tabulate(df, headers='keys', tablefmt='pipe', showindex=False) + md_table = tabulate(df, headers="keys", tablefmt="pipe", showindex=False) with open(args.description) as f: description = f.read() - description = description.format( - nightly_results_benchmarking_table=md_table) + description = description.format(nightly_results_benchmarking_table=md_table) with open("nightly_results.md", "w") as f: f.write(description) -if __name__ == '__main__': +if __name__ == "__main__": args = parse_arguments() main(args) diff --git a/.buildkite/nightly-benchmarks/scripts/summary-nightly-results.py b/.buildkite/nightly-benchmarks/scripts/summary-nightly-results.py index 62ee5e10b5095fcdc2ea177450f163aa0102b33c..2a7b37991f31a0e4a553e3f3b300b4cc37d19da4 100644 --- a/.buildkite/nightly-benchmarks/scripts/summary-nightly-results.py +++ b/.buildkite/nightly-benchmarks/scripts/summary-nightly-results.py @@ -34,10 +34,8 @@ serving_column_mapping = { } if __name__ == "__main__": - # collect results for test_file in results_folder.glob("*.json"): - with open(test_file) as f: raw_result = json.loads(f.read()) @@ -56,17 +54,16 @@ if __name__ == "__main__": serving_results = pd.DataFrame.from_dict(serving_results) if not serving_results.empty: - serving_results = serving_results[list( - serving_column_mapping.keys())].rename( - columns=serving_column_mapping) + serving_results = serving_results[list(serving_column_mapping.keys())].rename( + columns=serving_column_mapping + ) - serving_md_table_with_headers = tabulate(serving_results, - headers='keys', - tablefmt='pipe', - showindex=False) + serving_md_table_with_headers = tabulate( + serving_results, headers="keys", tablefmt="pipe", showindex=False + ) # remove the first line of header - serving_md_table_lines = serving_md_table_with_headers.split('\n') - serving_md_table_without_header = '\n'.join(serving_md_table_lines[2:]) + serving_md_table_lines = serving_md_table_with_headers.split("\n") + serving_md_table_without_header = "\n".join(serving_md_table_lines[2:]) prefix = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S") prefix = prefix + "_" + os.environ.get("CURRENT_LLM_SERVING_ENGINE") @@ -76,10 +73,9 @@ if __name__ == "__main__": # document results with header. # for those who wants to reproduce our benchmark. f.write(serving_md_table_with_headers) - f.write('\n') + f.write("\n") # document benchmarking results in json with open(results_folder / f"{prefix}_nightly_results.json", "w") as f: - - results = serving_results.to_dict(orient='records') + results = serving_results.to_dict(orient="records") f.write(json.dumps(results)) diff --git a/.buildkite/pyproject.toml b/.buildkite/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..083bb795caf5af92021fc58de7644406444248e7 --- /dev/null +++ b/.buildkite/pyproject.toml @@ -0,0 +1,51 @@ +# This local pyproject file is part of the migration from yapf to ruff format. +# It uses the same core rules as the main pyproject.toml file, but with the +# following differences: +# - ruff line length is overridden to 88 +# - deprecated typing ignores (UP006, UP035) have been removed + +[tool.ruff] +line-length = 88 +exclude = [ + # External file, leaving license intact + "examples/other/fp8/quantizer/quantize.py", + "vllm/vllm_flash_attn/flash_attn_interface.pyi" +] + +[tool.ruff.lint.per-file-ignores] +"vllm/third_party/**" = ["ALL"] +"vllm/version.py" = ["F401"] +"vllm/_version.py" = ["ALL"] + +[tool.ruff.lint] +select = [ + # pycodestyle + "E", + # Pyflakes + "F", + # pyupgrade + "UP", + # flake8-bugbear + "B", + # flake8-simplify + "SIM", + # isort + "I", + # flake8-logging-format + "G", +] +ignore = [ + # star imports + "F405", "F403", + # lambda expression assignment + "E731", + # Loop control variable not used within loop body + "B007", + # f-string format + "UP032", + # Can remove once 3.10+ is the minimum Python version + "UP007", +] + +[tool.ruff.format] +docstring-code-format = true diff --git a/.buildkite/release-pipeline.yaml b/.buildkite/release-pipeline.yaml index a21a657c4b05e742f1a60b9b00cbccf175154fca..2118cf4595eba8120f11cc4b052af13a117e594e 100644 --- a/.buildkite/release-pipeline.yaml +++ b/.buildkite/release-pipeline.yaml @@ -1,20 +1,20 @@ steps: - - label: "Build wheel - CUDA 12.4" + - label: "Build wheel - CUDA 12.8" agents: queue: cpu_queue_postmerge commands: - - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.4.0 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ." + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ." - "mkdir artifacts" - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'" - "bash .buildkite/scripts/upload-wheels.sh" env: DOCKER_BUILDKIT: "1" - - label: "Build wheel - CUDA 12.1" + - label: "Build wheel - CUDA 12.6" agents: queue: cpu_queue_postmerge commands: - - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.1.0 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ." + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.6.3 --build-arg torch_cuda_arch_list='7.0 7.5 8.0 8.9 9.0+PTX' --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ." - "mkdir artifacts" - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'" - "bash .buildkite/scripts/upload-wheels.sh" @@ -31,7 +31,7 @@ steps: agents: queue: cpu_queue_postmerge commands: - - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=11.8.0 --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ." + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=11.8.0 --build-arg torch_cuda_arch_list='7.0 7.5 8.0 8.9 9.0+PTX' --tag vllm-ci:build-image --target build --progress plain -f docker/Dockerfile ." - "mkdir artifacts" - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'" - "bash .buildkite/scripts/upload-wheels.sh" @@ -48,7 +48,7 @@ steps: queue: cpu_queue_postmerge commands: - "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7" - - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.4.0 --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT --target vllm-openai --progress plain -f docker/Dockerfile ." + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT --target vllm-openai --progress plain -f docker/Dockerfile ." - "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT" - label: "Build and publish TPU release image" @@ -57,6 +57,8 @@ steps: agents: queue: tpu_queue_postmerge commands: + - "yes | docker system prune -a" + - "git fetch --all" - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --tag vllm/vllm-tpu:nightly --tag vllm/vllm-tpu:$BUILDKITE_COMMIT --progress plain -f docker/Dockerfile.tpu ." - "docker push vllm/vllm-tpu:nightly" - "docker push vllm/vllm-tpu:$BUILDKITE_COMMIT" diff --git a/.buildkite/scripts/hardware_ci/run-amd-test.sh b/.buildkite/scripts/hardware_ci/run-amd-test.sh index 368f30434aa1d3c29029db2b444e8a27a2c4bfc4..bbc896ec68190b5b05b47be6d0a6c8e1c4d8ef7d 100755 --- a/.buildkite/scripts/hardware_ci/run-amd-test.sh +++ b/.buildkite/scripts/hardware_ci/run-amd-test.sh @@ -3,6 +3,9 @@ # This script runs test inside the corresponding ROCm docker container. set -o pipefail +# Export Python path +export PYTHONPATH=".." + # Print ROCm version echo "--- Confirming Clean Initial State" while true; do @@ -74,38 +77,69 @@ HF_MOUNT="/root/.cache/huggingface" commands=$@ echo "Commands:$commands" + +if [[ $commands == *"pytest -v -s basic_correctness/test_basic_correctness.py"* ]]; then + commands=${commands//"pytest -v -s basic_correctness/test_basic_correctness.py"/"VLLM_USE_TRITON_FLASH_ATTN=0 pytest -v -s basic_correctness/test_basic_correctness.py"} +fi + +if [[ $commands == *"pytest -v -s models/test_registry.py"* ]]; then + commands=${commands//"pytest -v -s models/test_registry.py"/"pytest -v -s models/test_registry.py -k 'not BambaForCausalLM and not GritLM and not Mamba2ForCausalLM and not Zamba2ForCausalLM'"} +fi + +if [[ $commands == *"VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2'"* ]]; then + commands=${commands//"VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2'"/"VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2 and not BambaForCausalLM and not Gemma2ForCausalLM and not Grok1ModelForCausalLM and not Zamba2ForCausalLM and not Gemma2Model and not GritLM'"} +fi + +if [[ $commands == *"pytest -v -s compile/test_basic_correctness.py"* ]]; then + commands=${commands//"pytest -v -s compile/test_basic_correctness.py"/"VLLM_USE_TRITON_FLASH_ATTN=0 pytest -v -s compile/test_basic_correctness.py"} +fi + #ignore certain kernels tests -if [[ $commands == *" kernels "* ]]; then +if [[ $commands == *" kernels/core"* ]]; then commands="${commands} \ - --ignore=kernels/test_attention_selector.py \ - --ignore=kernels/test_blocksparse_attention.py \ - --ignore=kernels/test_causal_conv1d.py \ - --ignore=kernels/test_cutlass.py \ - --ignore=kernels/test_encoder_decoder_attn.py \ - --ignore=kernels/test_flash_attn.py \ - --ignore=kernels/test_flashinfer.py \ - --ignore=kernels/test_int8_quant.py \ - --ignore=kernels/test_machete_gemm.py \ - --ignore=kernels/test_mamba_ssm.py \ - --ignore=kernels/test_marlin_gemm.py \ - --ignore=kernels/test_moe.py \ - --ignore=kernels/test_prefix_prefill.py \ - --ignore=kernels/test_rand.py \ - --ignore=kernels/test_sampler.py \ - --ignore=kernels/test_cascade_flash_attn.py \ - --ignore=kernels/test_mamba_mixer2.py \ - --ignore=kernels/test_aqlm.py \ - --ignore=kernels/test_machete_mm.py \ - --ignore=kernels/test_mha_attn.py \ - --ignore=kernels/test_block_fp8.py \ - --ignore=kernels/test_cutlass_moe.py \ - --ignore=kernels/test_mamba_ssm_ssd.py \ - --ignore=kernels/test_attention.py \ - --ignore=kernels/test_block_int8.py \ - --ignore=kernels/test_fused_quant_layernorm.py \ - --ignore=kernels/test_int8_kernel.py \ - --ignore=kernels/test_triton_moe_ptpc_fp8.py \ - --ignore=kernels/test_permute_cols.py" + --ignore=kernels/core/test_fused_quant_layernorm.py \ + --ignore=kernels/core/test_permute_cols.py" +fi + +if [[ $commands == *" kernels/attention"* ]]; then + commands="${commands} \ + --ignore=kernels/attention/stest_attention_selector.py \ + --ignore=kernels/attention/test_blocksparse_attention.py \ + --ignore=kernels/attention/test_encoder_decoder_attn.py \ + --ignore=kernels/attention/test_attention_selector.py \ + --ignore=kernels/attention/test_flash_attn.py \ + --ignore=kernels/attention/test_flashinfer.py \ + --ignore=kernels/attention/test_prefix_prefill.py \ + --ignore=kernels/attention/test_cascade_flash_attn.py \ + --ignore=kernels/attention/test_mha_attn.py \ + --ignore=kernels/attention/test_lightning_attn.py \ + --ignore=kernels/attention/test_attention.py" +fi + +if [[ $commands == *" kernels/quantization"* ]]; then + commands="${commands} \ + --ignore=kernels/quantization/test_int8_quant.py \ + --ignore=kernels/quantization/test_aqlm.py \ + --ignore=kernels/quantization/test_machete_mm.py \ + --ignore=kernels/quantization/test_block_fp8.py \ + --ignore=kernels/quantization/test_block_int8.py \ + --ignore=kernels/quantization/test_marlin_gemm.py \ + --ignore=kernels/quantization/test_cutlass_scaled_mm.py \ + --ignore=kernels/quantization/test_int8_kernel.py" +fi + +if [[ $commands == *" kernels/mamba"* ]]; then + commands="${commands} \ + --ignore=kernels/mamba/test_mamba_mixer2.py \ + --ignore=kernels/mamba/test_causal_conv1d.py \ + --ignore=kernels/mamba/test_mamba_ssm_ssd.py" +fi + +if [[ $commands == *" kernels/moe"* ]]; then + commands="${commands} \ + --ignore=kernels/moe/test_moe.py \ + --ignore=kernels/moe/test_cutlass_moe.py \ + --ignore=kernels/moe/test_triton_moe_ptpc_fp8.py" fi #ignore certain Entrypoints/openai tests @@ -147,6 +181,8 @@ fi PARALLEL_JOB_COUNT=8 +MYPYTHONPATH=".." + # check if the command contains shard flag, we will run all shards in parallel because the host have 8 GPUs. if [[ $commands == *"--shard-id="* ]]; then # assign job count as the number of shards used @@ -167,6 +203,7 @@ if [[ $commands == *"--shard-id="* ]]; then -e AWS_SECRET_ACCESS_KEY \ -v "${HF_CACHE}:${HF_MOUNT}" \ -e "HF_HOME=${HF_MOUNT}" \ + -e "PYTHONPATH=${MYPYTHONPATH}" \ --name "${container_name}_${GPU}" \ "${image_name}" \ /bin/bash -c "${commands_gpu}" \ @@ -197,6 +234,7 @@ else -e AWS_SECRET_ACCESS_KEY \ -v "${HF_CACHE}:${HF_MOUNT}" \ -e "HF_HOME=${HF_MOUNT}" \ + -e "PYTHONPATH=${MYPYTHONPATH}" \ --name "${container_name}" \ "${image_name}" \ /bin/bash -c "${commands}" diff --git a/.buildkite/scripts/hardware_ci/run-cpu-test-ppc64le.sh b/.buildkite/scripts/hardware_ci/run-cpu-test-ppc64le.sh index 5d863dd82e9b88276c341a763de359bdd90ec055..077bd9914907945d5a99f964eda7377a3ed71294 100755 --- a/.buildkite/scripts/hardware_ci/run-cpu-test-ppc64le.sh +++ b/.buildkite/scripts/hardware_ci/run-cpu-test-ppc64le.sh @@ -32,9 +32,12 @@ function cpu_tests() { set -e pip install pytest pytest-asyncio einops peft Pillow soundfile transformers_stream_generator matplotlib pip install sentence-transformers datamodel_code_generator - pytest -v -s tests/models/embedding/language/test_cls_models.py::test_classification_models[float-jason9693/Qwen2.5-1.5B-apeach] - pytest -v -s tests/models/embedding/language/test_embedding.py::test_models[half-BAAI/bge-base-en-v1.5] - pytest -v -s tests/models/encoder_decoder/language -m cpu_model" + pytest -v -s tests/models/language/generation/test_bart.py -m cpu_model + pytest -v -s tests/models/language/generation/test_common.py::test_models[False-5-32-openai-community/gpt2] + pytest -v -s tests/models/language/generation/test_common.py::test_models[False-5-32-facebook/opt-125m] + pytest -v -s tests/models/language/generation/test_common.py::test_models[False-5-32-google/gemma-1.1-2b-it] + pytest -v -s tests/models/language/pooling/test_classification.py::test_models[float-jason9693/Qwen2.5-1.5B-apeach] + pytest -v -s tests/models/language/pooling/test_embedding.py::test_models[half-BAAI/bge-base-en-v1.5]" } # All of CPU tests are expected to be finished less than 40 mins. diff --git a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh index 21982b01b9cc7783f9c40312e46e3c4162eea71d..2d375d7e9d8711502bfc104737972ee78fc482c3 100755 --- a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh +++ b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh @@ -1,6 +1,6 @@ #!/bin/bash -set -xue +set -xu # Build the docker image. docker build -f docker/Dockerfile.tpu -t vllm-tpu . @@ -24,31 +24,80 @@ docker run --privileged --net host --shm-size=16G -it \ && export VLLM_XLA_CHECK_RECOMPILATION=1 \ && echo HARDWARE \ && tpu-info \ - && echo TEST_0 \ - && pytest -v -s /workspace/vllm/tests/v1/tpu/test_perf.py \ - && echo TEST_1 \ - && pytest -v -s /workspace/vllm/tests/tpu/test_compilation.py \ - && echo TEST_2 \ - && pytest -v -s /workspace/vllm/tests/v1/tpu/test_basic.py \ - && echo TEST_3 \ - && pytest -v -s /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine \ - && echo TEST_4 \ - && pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py \ - && echo TEST_5 \ - && python3 /workspace/vllm/examples/offline_inference/tpu.py \ - && echo TEST_6 \ - && pytest -s -v /workspace/vllm/tests/v1/tpu/worker/test_tpu_model_runner.py \ - && echo TEST_7 \ - && pytest -s -v /workspace/vllm/tests/v1/tpu/test_sampler.py \ - && echo TEST_8 \ - && pytest -s -v /workspace/vllm/tests/v1/tpu/test_topk_topp_sampler.py \ - && echo TEST_9 \ - && pytest -s -v /workspace/vllm/tests/v1/tpu/test_multimodal.py \ - && echo TEST_10 \ - && pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py \ - && echo TEST_11 \ - && pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py" \ - + && { \ + echo TEST_0: Running test_perf.py; \ + python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_perf.py; \ + echo TEST_0_EXIT_CODE: \$?; \ + } & \ + { \ + echo TEST_1: Running test_compilation.py; \ + python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_compilation.py; \ + echo TEST_1_EXIT_CODE: \$?; \ + } & \ + { \ + echo TEST_2: Running test_basic.py; \ + python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_basic.py; \ + echo TEST_2_EXIT_CODE: \$?; \ + } & \ + { \ + echo TEST_3: Running test_accuracy.py::test_lm_eval_accuracy_v1_engine; \ + python3 -m pytest -s -v /workspace/vllm/tests/entrypoints/llm/test_accuracy.py::test_lm_eval_accuracy_v1_engine; \ + echo TEST_3_EXIT_CODE: \$?; \ + } & \ + { \ + echo TEST_4: Running test_quantization_accuracy.py; \ + python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_quantization_accuracy.py; \ + echo TEST_4_EXIT_CODE: \$?; \ + } & \ + { \ + echo TEST_5: Running examples/offline_inference/tpu.py; \ + python3 /workspace/vllm/examples/offline_inference/tpu.py; \ + echo TEST_5_EXIT_CODE: \$?; \ + } & \ + { \ + echo TEST_6: Running test_tpu_model_runner.py; \ + python3 -m pytest -s -v /workspace/vllm/tests/tpu/worker/test_tpu_model_runner.py; \ + echo TEST_6_EXIT_CODE: \$?; \ + } & \ + { \ + echo TEST_7: Running test_sampler.py; \ + python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_sampler.py; \ + echo TEST_7_EXIT_CODE: \$?; \ + } & \ + { \ + echo TEST_8: Running test_topk_topp_sampler.py; \ + python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_topk_topp_sampler.py; \ + echo TEST_8_EXIT_CODE: \$?; \ + } & \ + { \ + echo TEST_9: Running test_multimodal.py; \ + python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_multimodal.py; \ + echo TEST_9_EXIT_CODE: \$?; \ + } & \ + { \ + echo TEST_10: Running test_pallas.py; \ + python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py; \ + echo TEST_10_EXIT_CODE: \$?; \ + } & \ + { \ + echo TEST_11: Running test_struct_output_generate.py; \ + python3 -m pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py; \ + echo TEST_11_EXIT_CODE: \$?; \ + } & \ + { \ + echo TEST_12: Running test_moe_pallas.py; \ + python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_moe_pallas.py; \ + echo TEST_12_EXIT_CODE: \$?; \ + } & \ + # Disable the TPU LoRA tests until the feature is activated + # & { \ + # echo TEST_13: Running test_moe_pallas.py; \ + # python3 -m pytest -s -v /workspace/vllm/tests/tpu/lora/; \ + # echo TEST_13_EXIT_CODE: \$?; \ + # } & \ + wait \ + && echo 'All tests have attempted to run. Check logs for individual test statuses and exit codes.' \ +" # TODO: This test fails because it uses RANDOM_SEED sampling # && VLLM_USE_V1=1 pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py \ diff --git a/.buildkite/scripts/upload-wheels.sh b/.buildkite/scripts/upload-wheels.sh index a681f892706002add0d74b8c7588637bc54b0786..037897e53dbef42d43fc7656c3d6caf8c18fbe9a 100644 --- a/.buildkite/scripts/upload-wheels.sh +++ b/.buildkite/scripts/upload-wheels.sh @@ -50,11 +50,11 @@ aws s3 cp "$normal_wheel" "s3://vllm-wheels/$BUILDKITE_COMMIT/" if [[ $normal_wheel == *"cu118"* ]]; then # if $normal_wheel matches cu118, do not upload the index.html echo "Skipping index files for cu118 wheels" -elif [[ $normal_wheel == *"cu121"* ]]; then - # if $normal_wheel matches cu121, do not upload the index.html - echo "Skipping index files for cu121 wheels" +elif [[ $normal_wheel == *"cu126"* ]]; then + # if $normal_wheel matches cu126, do not upload the index.html + echo "Skipping index files for cu126 wheels" else - # only upload index.html for cu124 wheels (default wheels) + # only upload index.html for cu128 wheels (default wheels) aws s3 cp index.html "s3://vllm-wheels/$BUILDKITE_COMMIT/vllm/index.html" aws s3 cp "s3://vllm-wheels/nightly/index.html" "s3://vllm-wheels/$BUILDKITE_COMMIT/index.html" fi @@ -66,12 +66,13 @@ aws s3 cp "$normal_wheel" "s3://vllm-wheels/nightly/" if [[ $normal_wheel == *"cu118"* ]]; then # if $normal_wheel matches cu118, do not upload the index.html echo "Skipping index files for cu118 wheels" -elif [[ $normal_wheel == *"cu121"* ]]; then - # if $normal_wheel matches cu121, do not upload the index.html - echo "Skipping index files for cu121 wheels" +elif [[ $normal_wheel == *"cu126"* ]]; then + # if $normal_wheel matches cu126, do not upload the index.html + echo "Skipping index files for cu126 wheels" else - # only upload index.html for cu124 wheels (default wheels) + # only upload index.html for cu128 wheels (default wheels) aws s3 cp index.html "s3://vllm-wheels/nightly/vllm/index.html" fi -aws s3 cp "$wheel" "s3://vllm-wheels/$version/" \ No newline at end of file +aws s3 cp "$wheel" "s3://vllm-wheels/$version/" +aws s3 cp index.html "s3://vllm-wheels/$version/vllm/index.html" diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 20d858cb15a1169c05e1cf034f68960a40e5e491..461fb6d30c45e7ef55b8608d71c7eb99dd890d74 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -32,6 +32,7 @@ steps: ##### fast check tests ##### - label: Documentation Build # 2min + mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/test_docs/docs" fast_check: true no_gpu: True @@ -39,9 +40,10 @@ steps: - pip install -r ../../requirements/docs.txt - SPHINXOPTS=\"-W\" make html # Check API reference (if it fails, you may have missing mock imports) - - grep \"sig sig-object py\" build/html/api/inference_params.html + - grep \"sig sig-object py\" build/html/api/vllm/vllm.sampling_params.html - label: Async Engine, Inputs, Utils, Worker Test # 24min + mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/ - tests/mq_llm_engine @@ -62,6 +64,7 @@ steps: - pytest -v -s worker # Worker - label: Python-only Installation Test + mirror_hardwares: [amdexperimental] source_file_dependencies: - tests/standalone_tests/python_only_compile.sh - setup.py @@ -69,7 +72,7 @@ steps: - bash standalone_tests/python_only_compile.sh - label: Basic Correctness Test # 30min - #mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental, amdproduction] fast_check: true torch_nightly: true source_file_dependencies: @@ -86,6 +89,7 @@ steps: - VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py - label: Chunked Prefill Test + mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/ - tests/basic_correctness/test_chunked_prefill @@ -94,7 +98,7 @@ steps: - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py - label: Core Test # 10min - mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental, amdproduction] fast_check: true source_file_dependencies: - vllm/core @@ -104,10 +108,10 @@ steps: - pytest -v -s core - label: Entrypoints Test # 40min + mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" fast_check: true torch_nightly: true - #mirror_hardwares: [amd] source_file_dependencies: - vllm/ - tests/entrypoints/llm @@ -126,6 +130,7 @@ steps: - VLLM_USE_V1=0 pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests - label: Distributed Tests (4 GPUs) # 10min + mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" num_gpus: 4 source_file_dependencies: @@ -143,6 +148,8 @@ steps: # test with tp=2 and external_dp=2 - VLLM_USE_V1=0 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py - torchrun --nproc-per-node=4 distributed/test_torchrun_example.py + # test with tp=2 and pp=2 + - PP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py # test with internal dp - python3 ../examples/offline_inference/data_parallel.py - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py @@ -153,12 +160,12 @@ steps: # TODO: create a dedicated test section for multi-GPU example tests # when we have multiple distributed example tests - pushd ../examples/offline_inference - - python3 rlhf.py - - RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py + - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py + - VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py - popd - label: Metrics, Tracing Test # 10min - mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental, amdproduction] num_gpus: 2 source_file_dependencies: - vllm/ @@ -172,7 +179,7 @@ steps: ##### 1 GPU test ##### - label: Regression Test # 5min - #mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental, amdproduction] source_file_dependencies: - vllm/ - tests/test_regression @@ -182,7 +189,7 @@ steps: working_dir: "/vllm-workspace/tests" # optional - label: Engine Test # 10min - mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental, amdproduction] source_file_dependencies: - vllm/ - tests/engine @@ -196,7 +203,7 @@ steps: - pytest -v -s tokenization - label: V1 Test - #mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/ - tests/v1 @@ -209,8 +216,8 @@ steps: - pytest -v -s v1/worker - pytest -v -s v1/structured_output - pytest -v -s v1/spec_decode + - pytest -v -s v1/kv_connector/unit - pytest -v -s v1/test_serial_utils.py - - pytest -v -s v1/test_stats.py - pytest -v -s v1/test_utils.py - pytest -v -s v1/test_oracle.py # TODO: accuracy does not match, whether setting @@ -221,8 +228,8 @@ steps: - pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine - label: Examples Test # 25min + mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/examples" - #mirror_hardwares: [amd] source_file_dependencies: - vllm/entrypoints - examples/ @@ -246,7 +253,7 @@ steps: - VLLM_USE_V1=0 python3 offline_inference/profiling.py --model facebook/opt-125m run_num_steps --num-steps 2 - label: Prefix Caching Test # 9min - mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental, amdproduction] source_file_dependencies: - vllm/ - tests/prefix_caching @@ -254,6 +261,7 @@ steps: - pytest -v -s prefix_caching - label: Samplers Test # 36min + mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/model_executor/layers - vllm/sampling_metadata.py @@ -264,7 +272,7 @@ steps: - VLLM_USE_FLASHINFER_SAMPLER=1 pytest -v -s samplers - label: LogitsProcessor Test # 5min - mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental, amdproduction] source_file_dependencies: - vllm/model_executor/layers - vllm/model_executor/guided_decoding @@ -275,6 +283,7 @@ steps: - pytest -v -s model_executor/test_guided_processors.py - label: Speculative decoding tests # 40min + mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/spec_decode - tests/spec_decode @@ -285,7 +294,7 @@ steps: - pytest -v -s spec_decode/e2e/test_eagle_correctness.py - label: LoRA Test %N # 15min each - #mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/lora - tests/lora @@ -293,15 +302,20 @@ steps: parallelism: 4 - label: PyTorch Compilation Unit Tests + mirror_hardwares: [amdexperimental, amdproduction] + torch_nightly: true source_file_dependencies: - vllm/ - tests/compile commands: - pytest -v -s compile/test_pass_manager.py - pytest -v -s compile/test_fusion.py + - pytest -v -s compile/test_silu_mul_quant_fusion.py - pytest -v -s compile/test_sequence_parallelism.py - label: PyTorch Fullgraph Smoke Test # 9min + mirror_hardwares: [amdexperimental, amdproduction] + torch_nightly: true source_file_dependencies: - vllm/ - tests/compile @@ -312,6 +326,8 @@ steps: - pytest -v -s compile/piecewise/test_toy_llama.py - label: PyTorch Fullgraph Test # 18min + mirror_hardwares: [amdexperimental, amdproduction] + torch_nightly: true source_file_dependencies: - vllm/ - tests/compile @@ -319,6 +335,7 @@ steps: - pytest -v -s compile/test_full_graph.py - label: Kernels Core Operation Test + mirror_hardwares: [amdexperimental, amdproduction] source_file_dependencies: - csrc/ - tests/kernels/core @@ -326,6 +343,7 @@ steps: - pytest -v -s kernels/core - label: Kernels Attention Test %N + mirror_hardwares: [amdexperimental, amdproduction] source_file_dependencies: - csrc/attention/ - vllm/attention @@ -336,6 +354,7 @@ steps: parallelism: 2 - label: Kernels Quantization Test %N + mirror_hardwares: [amdexperimental, amdproduction] source_file_dependencies: - csrc/quantization/ - vllm/model_executor/layers/quantization @@ -345,6 +364,7 @@ steps: parallelism: 2 - label: Kernels MoE Test + mirror_hardwares: [amdexperimental] source_file_dependencies: - csrc/moe/ - tests/kernels/moe @@ -353,6 +373,7 @@ steps: - pytest -v -s kernels/moe - label: Kernels Mamba Test + mirror_hardwares: [amdexperimental] source_file_dependencies: - csrc/mamba/ - tests/kernels/mamba @@ -360,7 +381,7 @@ steps: - pytest -v -s kernels/mamba - label: Tensorizer Test # 11min - # mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental, amdproduction] soft_fail: true source_file_dependencies: - vllm/model_executor/model_loader @@ -371,37 +392,42 @@ steps: - pytest -v -s tensorizer_loader - label: Benchmarks # 9min + mirror_hardwares: [amdexperimental, amdproduction] working_dir: "/vllm-workspace/.buildkite" - mirror_hardwares: [amd] source_file_dependencies: - benchmarks/ commands: - bash scripts/run-benchmarks.sh - label: Benchmarks CLI Test # 10min + mirror_hardwares: [amdexperimental, amdproduction] source_file_dependencies: - vllm/ - tests/benchmarks/ commands: - pytest -v -s benchmarks/ -- label: Quantization Test # 33min +- label: Quantization Test + mirror_hardwares: [amdexperimental] source_file_dependencies: - csrc/ - vllm/model_executor/layers/quantization - tests/quantization - command: VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization + commands: + - VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization - label: LM Eval Small Models # 53min + mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/.buildkite/lm-eval-harness" source_file_dependencies: - csrc/ - vllm/model_executor/layers/quantization commands: - export VLLM_WORKER_MULTIPROC_METHOD=spawn - - bash ./run-tests.sh -c configs/models-small.txt -t 1 + - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-small.txt --tp-size=1 - label: OpenAI API correctness + mirror_hardwares: [amdexperimental] source_file_dependencies: - csrc/ - vllm/entrypoints/openai/ @@ -410,6 +436,7 @@ steps: - pytest -s entrypoints/openai/correctness/ - label: Encoder Decoder tests # 5min + mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/ - tests/encoder_decoder @@ -417,8 +444,8 @@ steps: - pytest -v -s encoder_decoder - label: OpenAI-Compatible Tool Use # 20 min + mirror_hardwares: [amdexperimental] fast_check: false - #mirror_hardwares: [ amd ] source_file_dependencies: - vllm/ - tests/tool_use @@ -430,92 +457,98 @@ steps: ##### models test ##### - label: Basic Models Test # 24min + mirror_hardwares: [amdexperimental, amdproduction] + torch_nightly: true source_file_dependencies: - vllm/ - tests/models commands: - pytest -v -s models/test_transformers.py - pytest -v -s models/test_registry.py + - pytest -v -s models/test_utils.py + - pytest -v -s models/test_vision.py # V1 Test: https://github.com/vllm-project/vllm/issues/14531 - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2' - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'llama4' - VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'plamo2' -- label: Language Models Test (Standard) # 32min - #mirror_hardwares: [amd] +- label: Language Models Test (Standard) + mirror_hardwares: [amdexperimental] + torch_nightly: true source_file_dependencies: - vllm/ - - tests/models/decoder_only/language - - tests/models/embedding/language - - tests/models/encoder_decoder/language + - tests/models/language commands: # Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile. - - pip install causal-conv1d - - pytest -v -s models/decoder_only/language -m 'core_model or quant_model' - - pytest -v -s models/embedding/language -m core_model + - pip install 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.0.post8' + - pip freeze | grep -E 'torch' + - pytest -v -s models/language -m core_model -- label: Language Models Test (Extended) # 1h10min +- label: Language Models Test (Extended) + mirror_hardwares: [amdexperimental] optional: true source_file_dependencies: - vllm/ - - tests/models/decoder_only/language - - tests/models/embedding/language - - tests/models/encoder_decoder/language + - tests/models/language commands: # Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile. - - pip install causal-conv1d - - pytest -v -s models/decoder_only/language -m 'not core_model and not quant_model' - - pytest -v -s models/embedding/language -m 'not core_model' + - pip install 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.0.post8' + - pytest -v -s models/language -m 'not core_model' -- label: Multi-Modal Models Test (Standard) # 40min - #mirror_hardwares: [amd] +- label: Multi-Modal Models Test (Standard) + mirror_hardwares: [amdexperimental] + torch_nightly: true source_file_dependencies: - vllm/ - - tests/models/decoder_only/audio_language - - tests/models/decoder_only/vision_language - - tests/models/embedding/vision_language - - tests/models/encoder_decoder/audio_language - - tests/models/encoder_decoder/vision_language + - tests/models/multimodal commands: - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git - - pytest -v -s models/multimodal - - pytest -v -s models/decoder_only/audio_language -m 'core_model or quant_model' - - pytest -v -s models/decoder_only/vision_language -m 'core_model or quant_model' - - pytest -v -s models/embedding/vision_language -m core_model - - pytest -v -s models/encoder_decoder/audio_language -m core_model - - pytest -v -s models/encoder_decoder/language -m core_model - - pytest -v -s models/encoder_decoder/vision_language -m core_model - - pytest -v -s models/decoder_only/vision_language/test_interleaved.py - -- label: Multi-Modal Models Test (Extended) 1 # 48m + - pip freeze | grep -E 'torch' + - pytest -v -s models/multimodal/processing + - pytest -v -s --ignore models/multimodal/generation/test_whisper.py models/multimodal -m core_model + - cd .. && pytest -v -s tests/models/multimodal/generation/test_whisper.py -m core_model # Otherwise, mp_method="spawn" doesn't work + +- label: Multi-Modal Models Test (Extended) 1 + mirror_hardwares: [amdexperimental] optional: true source_file_dependencies: - vllm/ - - tests/models/decoder_only/audio_language - - tests/models/decoder_only/vision_language - - tests/models/embedding/vision_language - - tests/models/encoder_decoder/vision_language + - tests/models/multimodal commands: - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git - - pytest -v -s models/decoder_only/audio_language -m 'not core_model and not quant_model' - - pytest -v -s models/decoder_only/vision_language/test_models.py -m 'split(group=0) and not core_model and not quant_model' - - pytest -v -s --ignore models/decoder_only/vision_language/test_models.py models/decoder_only/vision_language -m 'not core_model and not quant_model' - - pytest -v -s models/embedding/vision_language -m 'not core_model' - - pytest -v -s models/encoder_decoder/language -m 'not core_model' - - pytest -v -s models/encoder_decoder/vision_language -m 'not core_model' - -- label: Multi-Modal Models Test (Extended) 2 # 38m + - pytest -v -s --ignore models/multimodal/generation/test_common.py --ignore models/multimodal/processing models/multimodal -m 'not core_model' + +- label: Multi-Modal Models Test (Extended) 2 + mirror_hardwares: [amdexperimental] optional: true source_file_dependencies: - vllm/ - - tests/models/decoder_only/vision_language + - tests/models/multimodal commands: - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git - - pytest -v -s models/decoder_only/vision_language/test_models.py -m 'split(group=1) and not core_model and not quant_model' + - pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=0) and not core_model' + +- label: Multi-Modal Models Test (Extended) 3 + mirror_hardwares: [amdexperimental, amdproduction] + optional: true + source_file_dependencies: + - vllm/ + - tests/models/multimodal + commands: + - pip install git+https://github.com/TIGER-AI-Lab/Mantis.git + - pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=1) and not core_model' + +- label: Quantized Models Test + mirror_hardwares: [amdexperimental, amdproduction] + source_file_dependencies: + - vllm/model_executor/layers/quantization + - tests/models/quantization + commands: + - pytest -v -s models/quantization # This test is used only in PR development phase to test individual models and should never run on main - label: Custom Models Test - mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental, amdproduction] optional: true commands: - echo 'Testing custom models...' @@ -527,7 +560,7 @@ steps: ##### multi gpus test ##### - label: Distributed Comm Ops Test # 7min - mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental, amdproduction] working_dir: "/vllm-workspace/tests" num_gpus: 2 source_file_dependencies: @@ -538,6 +571,7 @@ steps: - pytest -v -s distributed/test_shm_broadcast.py - label: 2 Node Tests (4 GPUs in total) # 16min + mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" num_gpus: 2 num_nodes: 2 @@ -556,7 +590,7 @@ steps: - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed' - label: Distributed Tests (2 GPUs) # 40min - #mirror_hardwares: [amd] + mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" num_gpus: 2 source_file_dependencies: @@ -581,9 +615,8 @@ steps: - TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)' # Avoid importing model tests that cause CUDA reinitialization error - pytest models/test_transformers.py -v -s -m 'distributed(num_gpus=2)' - - pytest models/encoder_decoder/language/test_bart.py -v -s -m 'distributed(num_gpus=2)' - - pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m 'distributed(num_gpus=2)' - - pytest models/decoder_only/vision_language/test_models.py -v -s -m 'distributed(num_gpus=2)' + - pytest models/language -v -s -m 'distributed(num_gpus=2)' + - pytest models/multimodal -v -s -m 'distributed(num_gpus=2)' # test sequence parallel - pytest -v -s distributed/test_sequence_parallel.py # this test fails consistently. @@ -594,13 +627,14 @@ steps: - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown - label: Plugin Tests (2 GPUs) # 40min + mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" num_gpus: 2 source_file_dependencies: - vllm/plugins/ - tests/plugins/ commands: - # begin platform plugin tests, all the code in-between runs on dummy platform + # begin platform plugin and general plugin tests, all the code in-between runs on dummy platform - pip install -e ./plugins/vllm_add_dummy_platform - pytest -v -s plugins_tests/test_platform_plugins.py - pip uninstall vllm_add_dummy_platform -y @@ -611,8 +645,10 @@ steps: - pytest -v -s distributed/test_distributed_oot.py - pytest -v -s entrypoints/openai/test_oot_registration.py # it needs a clean process - pytest -v -s models/test_oot_registration.py # it needs a clean process + - pytest -v -s plugins/lora_resolvers # unit tests for in-tree lora resolver plugins - label: Multi-step Tests (4 GPUs) # 36min + mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" num_gpus: 4 source_file_dependencies: @@ -633,6 +669,7 @@ steps: - pytest -v -s multi_step/test_correctness_llm.py - label: Pipeline Parallelism Test # 45min + mirror_hardwares: [amdexperimental, amdproduction] working_dir: "/vllm-workspace/tests" num_gpus: 4 source_file_dependencies: @@ -646,6 +683,7 @@ steps: - pytest -v -s distributed/test_pipeline_parallel.py - label: LoRA TP Test (Distributed) + mirror_hardwares: [amdexperimental, amdproduction] num_gpus: 4 source_file_dependencies: - vllm/lora @@ -661,6 +699,7 @@ steps: - label: Weight Loading Multiple GPU Test # 33min + mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" num_gpus: 2 source_file_dependencies: @@ -670,6 +709,7 @@ steps: - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models.txt - label: Weight Loading Multiple GPU Test - Large Models # optional + mirror_hardwares: [amdexperimental] working_dir: "/vllm-workspace/tests" num_gpus: 2 gpu: a100 @@ -708,4 +748,4 @@ steps: - vllm/model_executor/layers/quantization commands: - export VLLM_WORKER_MULTIPROC_METHOD=spawn - - bash ./run-tests.sh -c configs/models-large.txt -t 4 + - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large.txt --tp-size=4 diff --git a/.github/ISSUE_TEMPLATE/400-bug-report.yml b/.github/ISSUE_TEMPLATE/400-bug-report.yml index b96ab40749003a9ea6fbcf112567028f9594bfdf..00b0f024c0da5ea39f6fecdcb0e1f8a313a0651d 100644 --- a/.github/ISSUE_TEMPLATE/400-bug-report.yml +++ b/.github/ISSUE_TEMPLATE/400-bug-report.yml @@ -21,12 +21,12 @@ body: It is suggested to download and execute the latest script, as vllm might frequently update the diagnosis information needed for accurately and quickly responding to issues. value: |
- The output of `python collect_env.py` + The output of python collect_env.py ```text Your output of `python collect_env.py` here ``` - +
validations: required: true @@ -75,7 +75,7 @@ body: ``` ``` - The error message you got, with the full traceback. + The error message you got, with the full traceback and the error logs with [dump_input.py:##] if present. ``` validations: required: true diff --git a/.github/mergify.yml b/.github/mergify.yml index 15fa3660a87df9154454e47e293a9b0b3b741f22..ccfd571625b54f96fa18a074e90f37e525d6ab8d 100644 --- a/.github/mergify.yml +++ b/.github/mergify.yml @@ -163,6 +163,17 @@ pull_request_rules: https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork +- name: assign reviewer for tensorizer changes + conditions: + - files~=^vllm/model_executor/model_loader/tensorizer.py + - files~=^vllm/model_executor/model_loader/tensorizer_loader.py + - files~=^tests/entrypoints/openai/test_tensorizer_entrypoint.py + - files~=^tests/tensorizer_loader/ + actions: + assign: + users: + - "sangstar" + - name: remove 'needs-rebase' label when conflict is resolved conditions: - -conflict diff --git a/.github/workflows/add_label_automerge.yml b/.github/workflows/add_label_automerge.yml index c9d6d4259df9998413c2ca2e26417ebf4ade12aa..315042fbf5cf44f409fb57923947e9fa3f1791b4 100644 --- a/.github/workflows/add_label_automerge.yml +++ b/.github/workflows/add_label_automerge.yml @@ -1,4 +1,6 @@ name: Add label on auto-merge enabled +permissions: + pull-requests: write on: pull_request_target: types: diff --git a/.github/workflows/lint-and-deploy.yaml b/.github/workflows/lint-and-deploy.yaml index 7b1d9f69938c82258fc9279c047e192d9b1f459f..64011922ad82535803664331d667d55cf3283c02 100644 --- a/.github/workflows/lint-and-deploy.yaml +++ b/.github/workflows/lint-and-deploy.yaml @@ -2,6 +2,9 @@ name: Lint and Deploy Charts on: pull_request +permissions: + contents: read + jobs: lint-and-deploy: runs-on: ubuntu-latest @@ -66,7 +69,7 @@ jobs: export AWS_SECRET_ACCESS_KEY=minioadmin sleep 30 && kubectl -n ns-vllm logs -f "$(kubectl -n ns-vllm get pods | awk '/deployment/ {print $1;exit}')" & helm install --wait --wait-for-jobs --timeout 5m0s --debug --create-namespace --namespace=ns-vllm test-vllm examples/online_serving/chart-helm -f examples/online_serving/chart-helm/values.yaml --set secrets.s3endpoint=http://minio:9000 --set secrets.s3bucketname=testbucket --set secrets.s3accesskeyid=$AWS_ACCESS_KEY_ID --set secrets.s3accesskey=$AWS_SECRET_ACCESS_KEY --set resources.requests.cpu=1 --set resources.requests.memory=4Gi --set resources.limits.cpu=2 --set resources.limits.memory=5Gi --set image.env[0].name=VLLM_CPU_KVCACHE_SPACE --set image.env[1].name=VLLM_LOGGING_LEVEL --set-string image.env[0].value="1" --set-string image.env[1].value="DEBUG" --set-string extraInit.s3modelpath="opt-125m/" --set-string 'resources.limits.nvidia\.com/gpu=0' --set-string 'resources.requests.nvidia\.com/gpu=0' --set-string image.repository="vllm-cpu-env" - + - name: curl test run: | kubectl -n ns-vllm port-forward service/test-vllm-service 8001:80 & @@ -79,4 +82,4 @@ jobs: "max_tokens": 7, "temperature": 0 }'):$CODE" - echo "$CODE" \ No newline at end of file + echo "$CODE" diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index 6ab63a402770419390d4d251b276e11d0851b6c8..8e694d18134efeebd280976cabe1488b5f37c31d 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -5,6 +5,9 @@ on: push: branches: [main] +permissions: + contents: read + jobs: pre-commit: runs-on: ubuntu-latest diff --git a/.github/workflows/reminder_comment.yml b/.github/workflows/reminder_comment.yml index 27318c2fdd93f837bd93608cde7225807fd3c72f..16ae1aadb96be289c4a153dda43772e3586e84cb 100644 --- a/.github/workflows/reminder_comment.yml +++ b/.github/workflows/reminder_comment.yml @@ -1,4 +1,6 @@ name: PR Reminder Comment Bot +permissions: + pull-requests: write on: pull_request_target: types: [opened] diff --git a/.gitignore b/.gitignore index 728213ceb74f050cf63f4a0437f0f50658568238..2756c612b82f874615ace1e71d4767f0c7b3f4b2 100644 --- a/.gitignore +++ b/.gitignore @@ -80,6 +80,7 @@ instance/ # Sphinx documentation docs/_build/ docs/source/getting_started/examples/ +docs/source/api/vllm # PyBuilder .pybuilder/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f76b24c025ffb9da08b181ecec0a7188129ba599..f5c0c368d578cb2ac83b5c4535650f4a9b078f29 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -12,29 +12,31 @@ repos: - id: yapf args: [--in-place, --verbose] - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.9.3 + rev: v0.11.7 hooks: - id: ruff args: [--output-format, github, --fix] + - id: ruff-format + files: ^(.buildkite|benchmarks)/.* - repo: https://github.com/codespell-project/codespell - rev: v2.4.0 + rev: v2.4.1 hooks: - id: codespell additional_dependencies: ['tomli'] args: ['--toml', 'pyproject.toml'] - repo: https://github.com/PyCQA/isort - rev: 0a0b7a830386ba6a31c2ec8316849ae4d1b8240d # 6.0.0 + rev: 6.0.1 hooks: - id: isort - repo: https://github.com/pre-commit/mirrors-clang-format - rev: v19.1.7 + rev: v20.1.3 hooks: - id: clang-format exclude: 'csrc/(moe/topk_softmax_kernels.cu|quantization/gguf/(ggml-common.h|dequantize.cuh|vecdotq.cuh|mmq.cuh|mmvq.cuh))|vllm/third_party/.*' types_or: [c++, cuda] args: [--style=file, --verbose] - repo: https://github.com/jackdewinter/pymarkdown - rev: v0.9.27 + rev: v0.9.29 hooks: - id: pymarkdown args: [fix] @@ -43,10 +45,10 @@ repos: hooks: - id: actionlint - repo: https://github.com/astral-sh/uv-pre-commit - rev: 0.6.2 + rev: 0.6.17 hooks: - id: pip-compile - args: [requirements/test.in, -o, requirements/test.txt] + args: [requirements/test.in, -o, requirements/test.txt, --index-strategy, unsafe-best-match, --torch-backend, cu128] files: ^requirements/test\.(in|txt)$ - repo: local hooks: @@ -101,8 +103,8 @@ repos: args: - -c - | - if ! grep -q "^Signed-off-by: $(git config user.name) <$(git config user.email)>" .git/COMMIT_EDITMSG; then - printf "\nSigned-off-by: $(git config user.name) <$(git config user.email)>\n" >> .git/COMMIT_EDITMSG + if ! grep -q "^Signed-off-by: $(git config user.name) <$(git config user.email)>" "$(git rev-parse --git-path COMMIT_EDITMSG)"; then + printf "\nSigned-off-by: $(git config user.name) <$(git config user.email)>\n" >> "$(git rev-parse --git-path COMMIT_EDITMSG)" fi language: system verbose: true @@ -125,8 +127,6 @@ repos: name: Update Dockerfile dependency graph entry: tools/update-dockerfile-graph.sh language: script - files: ^docker/Dockerfile$ - pass_filenames: false # Keep `suggestion` last - id: suggestion name: Suggestion diff --git a/CMakeLists.txt b/CMakeLists.txt index e237bd43176404fc3a8010df4264b20708d0d6f1..93b198078537e1c4d58e7d22a85c97a449b75b39 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -15,7 +15,6 @@ project(vllm_extensions LANGUAGES CXX) # CUDA by default, can be overridden by using -DVLLM_TARGET_DEVICE=... (used by setup.py) set(VLLM_TARGET_DEVICE "cuda" CACHE STRING "Target device backend for vLLM") - message(STATUS "Build type: ${CMAKE_BUILD_TYPE}") message(STATUS "Target device: ${VLLM_TARGET_DEVICE}") @@ -46,8 +45,8 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1 # requirements.txt files and should be kept consistent. The ROCm torch # versions are derived from docker/Dockerfile.rocm # -set(TORCH_SUPPORTED_VERSION_CUDA "2.6.0") -set(TORCH_SUPPORTED_VERSION_ROCM "2.6.0") +set(TORCH_SUPPORTED_VERSION_CUDA "2.7.0") +set(TORCH_SUPPORTED_VERSION_ROCM "2.7.0") # # Try to find python package with an executable that exactly matches @@ -231,6 +230,7 @@ set(VLLM_EXT_SRC "csrc/attention/paged_attention_v1.cu" "csrc/attention/paged_attention_v2.cu" "csrc/attention/merge_attn_states.cu" + "csrc/attention/vertical_slash_index.cu" "csrc/pos_encoding_kernels.cu" "csrc/activation_kernels.cu" "csrc/layernorm_kernels.cu" @@ -242,6 +242,7 @@ set(VLLM_EXT_SRC # "csrc/quantization/fp8/common.cu" # "csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu" "csrc/quantization/gguf/gguf_kernel.cu" + "csrc/quantization/activation_kernels.cu" "csrc/cuda_utils_kernels.cu" "csrc/prepare_inputs/advance_step.cu" "csrc/custom_all_reduce.cu" @@ -250,9 +251,8 @@ set(VLLM_EXT_SRC if(VLLM_GPU_LANG STREQUAL "CUDA") SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library") - # Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case. - # Please keep this in sync with FetchContent_Declare line below. - set(CUTLASS_REVISION "v3.9.0" CACHE STRING "CUTLASS revision to use") + # Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building. + set(CUTLASS_REVISION "v3.9.2" CACHE STRING "CUTLASS revision to use") # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR}) @@ -270,7 +270,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") cutlass GIT_REPOSITORY https://github.com/nvidia/cutlass.git # Please keep this in sync with CUTLASS_REVISION line above. - GIT_TAG v3.9.0 + GIT_TAG ${CUTLASS_REVISION} GIT_PROGRESS TRUE # Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history. @@ -290,6 +290,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu" "csrc/quantization/fp4/nvfp4_quant_entry.cu" "csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu" + "csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu" "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu" "csrc/cutlass_extensions/common.cpp" "csrc/attention/mla/cutlass_mla_entry.cu") @@ -301,10 +302,55 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # Only build Marlin kernels if we are building for at least some compatible archs. # Keep building Marlin for 9.0 as there are some group sizes and shapes that # are not supported by Machete yet. - cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}") + # 9.0 for latest bf16 atomicAdd PTX + cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;9.0+PTX" "${CUDA_ARCHS}") if (MARLIN_ARCHS) + + # + # For the Marlin kernels we automatically generate sources for various + # preselected input type pairs and schedules. + # Generate sources: + set(MARLIN_GEN_SCRIPT + ${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/gptq_marlin/generate_kernels.py) + file(MD5 ${MARLIN_GEN_SCRIPT} MARLIN_GEN_SCRIPT_HASH) + + message(STATUS "Marlin generation script hash: ${MARLIN_GEN_SCRIPT_HASH}") + message(STATUS "Last run Marlin generate script hash: $CACHE{MARLIN_GEN_SCRIPT_HASH}") + + if (NOT DEFINED CACHE{MARLIN_GEN_SCRIPT_HASH} + OR NOT $CACHE{MARLIN_GEN_SCRIPT_HASH} STREQUAL ${MARLIN_GEN_SCRIPT_HASH}) + execute_process( + COMMAND ${CMAKE_COMMAND} -E env + PYTHONPATH=$PYTHONPATH + ${Python_EXECUTABLE} ${MARLIN_GEN_SCRIPT} + RESULT_VARIABLE marlin_generation_result + OUTPUT_VARIABLE marlin_generation_result + OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log + ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log + ) + + if (NOT marlin_generation_result EQUAL 0) + message(FATAL_ERROR "Marlin generation failed." + " Result: \"${marlin_generation_result}\"" + "\nCheck the log for details: " + "${CMAKE_CURRENT_BINARY_DIR}/marlin_generation.log") + else() + set(MARLIN_GEN_SCRIPT_HASH ${MARLIN_GEN_SCRIPT_HASH} + CACHE STRING "Last run Marlin generate script hash" FORCE) + message(STATUS "Marlin generation completed successfully.") + endif() + else() + message(STATUS "Marlin generation script has not changed, skipping generation.") + endif() + + file(GLOB MARLIN_TEMPLATE_KERNEL_SRC "csrc/quantization/gptq_marlin/kernel_*.cu") + set_gencode_flags_for_srcs( + SRCS "${MARLIN_TEMPLATE_KERNEL_SRC}" + CUDA_ARCHS "${MARLIN_ARCHS}") + + list(APPEND VLLM_EXT_SRC ${MARLIN_TEMPLATE_KERNEL_SRC}) + set(MARLIN_SRCS - "csrc/quantization/fp8/fp8_marlin.cu" "csrc/quantization/marlin/dense/marlin_cuda_kernel.cu" "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu" "csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu" @@ -376,6 +422,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu" "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu" + "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu" ) set_gencode_flags_for_srcs( SRCS "${SRCS}" @@ -400,8 +447,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # # For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x) # kernels for the remaining archs that are not already built for 3x. + # (Build 8.9 for FP8) cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS - "7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}") + "7.5;8.0;8.9+PTX" "${CUDA_ARCHS}") # subtract out the archs that are already built for 3x list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS}) if (SCALED_MM_2X_ARCHS) @@ -452,7 +500,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND FP4_ARCHS) set(SRCS "csrc/quantization/fp4/nvfp4_quant_kernels.cu" - "csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu") + "csrc/quantization/fp4/nvfp4_experts_quant.cu" + "csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu" + "csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" CUDA_ARCHS "${FP4_ARCHS}") @@ -490,7 +540,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and only works # on Hopper). get_cutlass_moe_mm_data should only be compiled if it's possible # to compile MoE kernels that use its output. - cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}") + cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;10.0a" "${CUDA_ARCHS}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS) set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu" "csrc/quantization/cutlass_w8a8/moe/moe_data.cu") @@ -628,7 +678,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") CUDA_ARCHS "${CUDA_ARCHS}") list(APPEND VLLM_MOE_EXT_SRC "${VLLM_MOE_WNA16_SRC}") - cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}") + # 9.0 for latest bf16 atomicAdd PTX + cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;9.0+PTX" "${CUDA_ARCHS}") if (MARLIN_MOE_ARCHS) # @@ -646,7 +697,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") OR NOT $CACHE{MOE_MARLIN_GEN_SCRIPT_HASH} STREQUAL ${MOE_MARLIN_GEN_SCRIPT_HASH}) execute_process( COMMAND ${CMAKE_COMMAND} -E env - PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:${VLLM_PYTHON_PATH}:$PYTHONPATH + PYTHONPATH=$PYTHONPATH ${Python_EXECUTABLE} ${MOE_MARLIN_GEN_SCRIPT} RESULT_VARIABLE moe_marlin_generation_result OUTPUT_VARIABLE moe_marlin_generation_output @@ -682,6 +733,17 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") endif() endif() +if(VLLM_GPU_LANG STREQUAL "CUDA") + set(MOE_PERMUTE_SRC + "csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu" + "csrc/moe/moe_permute_unpermute_op.cu") + + set_gencode_flags_for_srcs( + SRCS "${MARLIN_PERMUTE_SRC}" + CUDA_ARCHS "${MOE_PERMUTE_ARCHS}") + + list(APPEND VLLM_MOE_EXT_SRC "${MOE_PERMUTE_SRC}") +endif() message(STATUS "Enabling moe extension.") define_gpu_extension_target( _moe_C @@ -690,6 +752,8 @@ define_gpu_extension_target( SOURCES ${VLLM_MOE_EXT_SRC} COMPILE_FLAGS ${VLLM_GPU_FLAGS} ARCHITECTURES ${VLLM_GPU_ARCHES} + INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR} + INCLUDE_DIRECTORIES ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} USE_SABI 3 WITH_SOABI) diff --git a/README.md b/README.md index dda3ae6009f555eac423e68a5a72b5d0b423ddf1..5b87ae838885c52aa3651758a859d777c2e8753f 100644 --- a/README.md +++ b/README.md @@ -16,18 +16,20 @@ Easy, fast, and cheap LLM serving for everyone --- *Latest News* 🔥 +- [2025/05] We hosted [NYC vLLM Meetup](https://lu.ma/c1rqyf1f)! Please find the meetup slides [here](https://docs.google.com/presentation/d/1_q_aW_ioMJWUImf1s1YM-ZhjXz8cUeL0IJvaquOYBeA/edit?usp=sharing). +- [2025/05] vLLM is now a hosted project under PyTorch Foundation! Please find the announcement [here](https://pytorch.org/blog/pytorch-foundation-welcomes-vllm/). - [2025/04] We hosted [Asia Developer Day](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/19cp6Qu8u48ihB91A064XfaXruNYiBOUKrBxAmDOllOo/edit?usp=sharing). +- [2025/01] We are excited to announce the alpha release of vLLM V1: A major architectural upgrade with 1.7x speedup! Clean code, optimized execution loop, zero-overhead prefix caching, enhanced multimodal support, and more. Please check out our blog post [here](https://blog.vllm.ai/2025/01/27/v1-alpha-release.html). + +
+Previous News + - [2025/03] We hosted [vLLM x Ollama Inference Night](https://lu.ma/vllm-ollama)! Please find the meetup slides from the vLLM team [here](https://docs.google.com/presentation/d/16T2PDD1YwRnZ4Tu8Q5r6n53c5Lr5c73UV9Vd2_eBo4U/edit?usp=sharing). - [2025/03] We hosted [the first vLLM China Meetup](https://mp.weixin.qq.com/s/n77GibL2corAtQHtVEAzfg)! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1REHvfQMKGnvz6p3Fd23HhSO4c8j5WPGZV0bKYLwnHyQ/edit?usp=sharing). - [2025/03] We hosted [the East Coast vLLM Meetup](https://lu.ma/7mu4k4xx)! Please find the meetup slides [here](https://docs.google.com/presentation/d/1NHiv8EUFF1NLd3fEYODm56nDmL26lEeXCaDgyDlTsRs/edit#slide=id.g31441846c39_0_0). - [2025/02] We hosted [the ninth vLLM meetup](https://lu.ma/h7g3kuj9) with Meta! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1jzC_PZVXrVNSFVCW-V4cFXb6pn7zZ2CyP_Flwo05aqg/edit?usp=sharing) and AMD [here](https://drive.google.com/file/d/1Zk5qEJIkTmlQ2eQcXQZlljAx3m9s7nwn/view?usp=sharing). The slides from Meta will not be posted. -- [2025/01] We are excited to announce the alpha release of vLLM V1: A major architectural upgrade with 1.7x speedup! Clean code, optimized execution loop, zero-overhead prefix caching, enhanced multimodal support, and more. Please check out our blog post [here](https://blog.vllm.ai/2025/01/27/v1-alpha-release.html). - [2025/01] We hosted [the eighth vLLM meetup](https://lu.ma/zep56hui) with Google Cloud! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1epVkt4Zu8Jz_S5OhEHPc798emsYh2BwYfRuDDVEF7u4/edit?usp=sharing), and Google Cloud team [here](https://drive.google.com/file/d/1h24pHewANyRL11xy5dXUbvRC9F9Kkjix/view?usp=sharing). - [2024/12] vLLM joins [pytorch ecosystem](https://pytorch.org/blog/vllm-joins-pytorch)! Easy, Fast, and Cheap LLM Serving for Everyone! - -
-Previous News - - [2024/11] We hosted [the seventh vLLM meetup](https://lu.ma/h0qvrajz) with Snowflake! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1e3CxQBV3JsfGp30SwyvS3eM_tW-ghOhJ9PAJGK6KR54/edit?usp=sharing), and Snowflake team [here](https://docs.google.com/presentation/d/1qF3RkDAbOULwz9WK5TOltt2fE9t6uIc_hVNLFAaQX6A/edit?usp=sharing). - [2024/10] We have just created a developer slack ([slack.vllm.ai](https://slack.vllm.ai)) focusing on coordinating contributions and discussing features. Please feel free to join us there! - [2024/10] Ray Summit 2024 held a special track for vLLM! Please find the opening talk slides from the vLLM team [here](https://docs.google.com/presentation/d/1B_KQxpHBTRa_mDF-tR6i8rWdOU5QoTZNcEg2MKZxEHM/edit?usp=sharing). Learn more from the [talks](https://www.youtube.com/playlist?list=PLzTswPQNepXl6AQwifuwUImLPFRVpksjR) from other vLLM contributors and users! @@ -72,7 +74,7 @@ vLLM is flexible and easy to use with: - OpenAI-compatible API server - Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, TPU, and AWS Neuron. - Prefix caching support -- Multi-lora support +- Multi-LoRA support vLLM seamlessly supports most popular open-source models on HuggingFace, including: - Transformer-like LLMs (e.g., Llama) diff --git a/benchmarks/auto_tune.sh b/benchmarks/auto_tune.sh new file mode 100644 index 0000000000000000000000000000000000000000..ea63c6f71a6c50ae698b0c9969d38da91896f728 --- /dev/null +++ b/benchmarks/auto_tune.sh @@ -0,0 +1,212 @@ +#!/bin/bash + +# This script aims to tune the best server parameter combinations to maximize throughput for given requirement. +# The current server parameter combination is max_num_seqs and max_num_batched_tokens +# It also supports additional requirement: e2e latency and prefix cache. + +# Pre-requisite: +# 1. Checkout to your branch, install/ update the correct running env. For TPU, activate conda env and install the corresponding torch, xla version. +# 2. If the model is customized, replace the MODEL's config with the customized config. +# 3. Set variables (ALL REQUIRED) +# BASE: your directory for vllm repo +# MODEL: the model served by vllm +# DOWNLOAD_DIR: directory to download and load model weights. +# INPUT_LEN: request input len +# OUTPUT_LEN: request output len +# MIN_CACHE_HIT_PCT: prefix cache rate +# MAX_LATENCY_ALLOWED_MS: (e2e) latency requirement. If there's no latency requirement, set it to a large number like 1000000000 +# 4. Run the script, it might take a long time, you can use tmux to avoid the script stop if disconnection happens. +# 5. The final result will be saved in RESULT file. + + +# Example use cases +# 1. Given input_len=1800, output_len=20, what's the best max_num_seqs and max_num_batched_tokens to get highest throughput? +# Use INPUT_LEN=1800, OUTPUT_LEN=20, MIN_CACHE_HIT_PCT=0, MAX_LATENCY_ALLOWED_MS=100000000000 +# 2. If we have latency requirement to be lower than 500ms, what's the best server parameter? +# Use INPUT_LEN=1800, OUTPUT_LEN=20, MIN_CACHE_HIT_PCT=0, MAX_LATENCY_ALLOWED_MS=500 +# 3. If we want to reach 60% prefix cache, what's the best server parameter? +# Use INPUT_LEN=1800, OUTPUT_LEN=20, MIN_CACHE_HIT_PCT=60, MAX_LATENCY_ALLOWED_MS=500 + +TAG=$(date +"%Y_%m_%d_%H_%M") +BASE="" +MODEL="meta-llama/Llama-3.1-8B-Instruct" +DOWNLOAD_DIR="" +INPUT_LEN=4000 +OUTPUT_LEN=16 +MIN_CACHE_HIT_PCT_PCT=0 +MAX_LATENCY_ALLOWED_MS=100000000000 + +LOG_FOLDER="$BASE/auto-benchmark/$TAG" +RESULT="$LOG_FOLDER/result.txt" + +echo "result file$ $RESULT" +echo "model: $MODEL" +echo + +rm -rf $LOG_FOLDER +mkdir -p $LOG_FOLDER + +cd "$BASE/vllm" +# create sonnet-4x.txt so that we can sample 2048 tokens for input +echo "" > benchmarks/sonnet_4x.txt +for _ in {1..4} +do +cat benchmarks/sonnet.txt >> benchmarks/sonnet_4x.txt +done + +pip install datasets + +current_hash=$(git rev-parse HEAD) +echo "hash:$current_hash" >> "$RESULT" +echo "current_hash: $current_hash" + +best_throughput=0 +best_max_num_seqs=0 +best_num_batched_tokens=0 +best_goodput=0 +run_benchmark() { + local max_num_seqs=$1 + local max_num_batched_tokens=$2 + echo "max_num_seq: $max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens" + local vllm_log="$LOG_FOLDER/vllm_log_${max_num_seqs}_${max_num_batched_tokens}.txt" + echo "vllm_log: $vllm_log" + echo + rm -f $vllm_log + + # start the server + VLLM_USE_V1=1 VLLM_SERVER_DEV_MODE=1 vllm serve $MODEL \ + --disable-log-requests \ + --port 8004 \ + --gpu-memory-utilization 0.98 \ + --max-num-seqs $max_num_seqs \ + --max-num-batched-tokens $max_num_batched_tokens \ + --tensor-parallel-size 1 \ + --enable-prefix-caching \ + --load-format dummy \ + --download-dir $DOWNLOAD_DIR \ + --max-model-len $(( INPUT_LEN+OUTPUT_LEN )) > "$vllm_log" 2>&1 & + echo "wait for 10 minutes.." + echo + # wait for 10 minutes... + server_started=0 + for i in {1..60}; do + if grep -Fq "Application startup complete" "$vllm_log"; then + echo "Application started" + server_started=1 + break + else + # echo "wait for 10 seconds..." + sleep 10 + fi + done + + if (( ! server_started )); then + echo "server did not start within 10 minutes, terminate the benchmarking. Please check server log at $vllm_log" + echo "pkill -f vllm" + echo + pkill vllm + sleep 10 + return 1 + fi + + echo "run benchmark test..." + echo + meet_latency_requirement=0 + # get a basic qps by using request-rate inf + bm_log="$LOG_FOLDER/bm_log_${max_num_seqs}_${max_num_batched_tokens}_requestrate_inf.txt" + prefix_len=$(( INPUT_LEN * MIN_CACHE_HIT_PCT / 100 )) + python benchmarks/benchmark_serving.py \ + --backend vllm \ + --model $MODEL \ + --dataset-name sonnet \ + --dataset-path benchmarks/sonnet_4x.txt \ + --sonnet-input-len $INPUT_LEN \ + --sonnet-output-len $OUTPUT_LEN \ + --ignore-eos \ + --disable-tqdm \ + --request-rate inf \ + --percentile-metrics ttft,tpot,itl,e2el \ + --goodput e2el:$MAX_LATENCY_ALLOWED_MS \ + --num-prompts 100 \ + --sonnet-prefix-len $prefix_len \ + --port 8004 > "$bm_log" + through_put=$(grep "Request throughput (req/s):" "$bm_log" | sed 's/[^0-9.]//g') + e2el=$(grep "P99 E2EL (ms):" "$bm_log" | awk '{print $NF}') + goodput=$(grep "Request goodput (req/s):" "$bm_log" | sed 's/[^0-9.]//g') + + if (( $(echo "$e2el <= $MAX_LATENCY_ALLOWED_MS" | bc -l) )); then + meet_latency_requirement=1 + fi + + if (( ! meet_latency_requirement )); then + # start from request-rate as int(through_put) + 1 + request_rate=$((${through_put%.*} + 1)) + while ((request_rate > 0)); do + # clear prefix cache + curl -X POST http://0.0.0.0:8004/reset_prefix_cache + sleep 5 + bm_log="$LOG_FOLDER/bm_log_${max_num_seqs}_${max_num_batched_tokens}_requestrate_${request_rate}.txt" + python benchmarks/benchmark_serving.py \ + --backend vllm \ + --model $MODEL \ + --dataset-name sonnet \ + --dataset-path benchmarks/sonnet_4x.txt \ + --sonnet-input-len $INPUT_LEN \ + --sonnet-output-len $OUTPUT_LEN \ + --ignore_eos \ + --disable-tqdm \ + --request-rate $request_rate \ + --percentile-metrics ttft,tpot,itl,e2el \ + --goodput e2el:$MAX_LATENCY_ALLOWED_MS \ + --num-prompts 100 \ + --sonnet-prefix-len $prefix_len \ + --port 8004 > "$bm_log" + through_put=$(grep "Request throughput (req/s):" "$bm_log" | sed 's/[^0-9.]//g') + e2el=$(grep "P99 E2EL (ms):" "$bm_log" | awk '{print $NF}') + goodput=$(grep "Request goodput (req/s):" "$bm_log" | sed 's/[^0-9.]//g') + if (( $(echo "$e2el <= $MAX_LATENCY_ALLOWED_MS" | bc -l) )); then + meet_latency_requirement=1 + break + fi + request_rate=$((request_rate-1)) + done + fi + # write the results and update the best result. + if ((meet_latency_requirement)); then + echo "max_num_seqs: $max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens, request_rate: $request_rate, e2el: $e2el, through put: $through_put, goodput: $goodput" + echo "max_num_seqs: $max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens, request_rate: $request_rate, e2el: $e2el, through put: $through_put, goodput: $goodput" >> "$RESULT" + if (( $(echo "$through_put > $best_throughput" | bc -l) )); then + best_throughput=$through_put + best_max_num_seqs=$max_num_seqs + best_num_batched_tokens=$max_num_batched_tokens + best_goodput=$goodput + fi + else + echo "max_num_seqs: $max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens does not meet latency requirement ${MAX_LATENCY_ALLOWED_MS}" + echo "max_num_seqs: $max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens does not meet latency requirement ${MAX_LATENCY_ALLOWED_MS}" >> "$RESULT" + fi + + echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput" + + echo "pkill -f vllm" + echo + pkill vllm + sleep 10 + rm -f $vllm_log + printf '=%.0s' $(seq 1 20) + return 0 +} + + +num_seqs_list="128 256" +num_batched_tokens_list="512 1024 2048 4096" +for num_seqs in $num_seqs_list; do + for num_batched_tokens in $num_batched_tokens_list; do + run_benchmark $num_seqs $num_batched_tokens + exit 0 + done +done +echo "finish permutations" +echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput" +echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput" >> "$RESULT" + diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index efd51c79c37cfff04b91e37f5f123b2a8c489e84..800d426c6d11822faf273667019c17f044855321 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -12,8 +12,7 @@ from typing import Optional, Union import aiohttp import huggingface_hub.constants from tqdm.asyncio import tqdm -from transformers import (AutoTokenizer, PreTrainedTokenizer, - PreTrainedTokenizerFast) +from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast # NOTE(simon): do not import vLLM here so the benchmark script # can run without vLLM installed. @@ -43,8 +42,7 @@ class RequestFuncOutput: latency: float = 0.0 output_tokens: int = 0 ttft: float = 0.0 # Time to first token - itl: list[float] = field( - default_factory=list) # list of inter-token latencies + itl: list[float] = field(default_factory=list) # list of inter-token latencies tpot: float = 0.0 # avg next-token latencies prompt_len: int = 0 error: str = "" @@ -57,8 +55,9 @@ async def async_request_tgi( api_url = request_func_input.api_url assert api_url.endswith("generate_stream") - async with aiohttp.ClientSession(trust_env=True, - timeout=AIOHTTP_TIMEOUT) as session: + async with aiohttp.ClientSession( + trust_env=True, timeout=AIOHTTP_TIMEOUT + ) as session: params = { "max_new_tokens": request_func_input.output_len, "do_sample": True, @@ -105,8 +104,7 @@ async def async_request_tgi( # Decoding phase else: - output.itl.append(timestamp - - most_recent_timestamp) + output.itl.append(timestamp - most_recent_timestamp) most_recent_timestamp = timestamp @@ -133,8 +131,9 @@ async def async_request_trt_llm( api_url = request_func_input.api_url assert api_url.endswith("generate_stream") - async with aiohttp.ClientSession(trust_env=True, - timeout=AIOHTTP_TIMEOUT) as session: + async with aiohttp.ClientSession( + trust_env=True, timeout=AIOHTTP_TIMEOUT + ) as session: payload = { "accumulate_tokens": True, "text_input": request_func_input.prompt, @@ -159,8 +158,7 @@ async def async_request_trt_llm( if not chunk_bytes: continue - chunk = chunk_bytes.decode("utf-8").removeprefix( - "data:") + chunk = chunk_bytes.decode("utf-8").removeprefix("data:") data = json.loads(chunk) output.generated_text += data["text_output"] @@ -172,8 +170,7 @@ async def async_request_trt_llm( # Decoding phase else: - output.itl.append(timestamp - - most_recent_timestamp) + output.itl.append(timestamp - most_recent_timestamp) most_recent_timestamp = timestamp @@ -197,10 +194,11 @@ async def async_request_deepspeed_mii( request_func_input: RequestFuncInput, pbar: Optional[tqdm] = None, ) -> RequestFuncOutput: - async with aiohttp.ClientSession(trust_env=True, - timeout=AIOHTTP_TIMEOUT) as session: - + async with aiohttp.ClientSession( + trust_env=True, timeout=AIOHTTP_TIMEOUT + ) as session: payload = { + "model": request_func_input.model, "prompt": request_func_input.prompt, "max_tokens": request_func_input.output_len, "temperature": 0.01, # deepspeed-mii does not accept 0.0 temp. @@ -216,19 +214,21 @@ async def async_request_deepspeed_mii( st = time.perf_counter() try: - async with session.post(url=request_func_input.api_url, - json=payload) as response: + async with session.post( + url=request_func_input.api_url, json=payload + ) as response: if response.status == 200: parsed_resp = await response.json() output.latency = time.perf_counter() - st if "choices" in parsed_resp: - output.generated_text = parsed_resp["choices"][0][ - "text"] + output.generated_text = parsed_resp["choices"][0]["text"] elif "text" in parsed_resp: output.generated_text = parsed_resp["text"][0] else: - output.error = ("Unexpected response format: " - "neither 'choices' nor 'text' found") + output.error = ( + "Unexpected response format: " + "neither 'choices' nor 'text' found" + ) output.success = False output.success = True else: @@ -249,17 +249,20 @@ async def async_request_openai_completions( pbar: Optional[tqdm] = None, ) -> RequestFuncOutput: api_url = request_func_input.api_url - assert api_url.endswith( - ("completions", "profile") - ), "OpenAI Completions API URL must end with 'completions' or 'profile'." + assert api_url.endswith(("completions", "profile")), ( + "OpenAI Completions API URL must end with 'completions' or 'profile'." + ) - async with aiohttp.ClientSession(trust_env=True, - timeout=AIOHTTP_TIMEOUT) as session: + async with aiohttp.ClientSession( + trust_env=True, timeout=AIOHTTP_TIMEOUT + ) as session: payload = { - "model": request_func_input.model_name \ - if request_func_input.model_name else request_func_input.model, + "model": request_func_input.model_name + if request_func_input.model_name + else request_func_input.model, "prompt": request_func_input.prompt, "temperature": 0.0, + "repetition_penalty": 1.0, "max_tokens": request_func_input.output_len, "logprobs": request_func_input.logprobs, "stream": True, @@ -271,9 +274,7 @@ async def async_request_openai_completions( payload["ignore_eos"] = request_func_input.ignore_eos if request_func_input.extra_body: payload.update(request_func_input.extra_body) - headers = { - "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}" - } + headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} output = RequestFuncOutput() output.prompt_len = request_func_input.prompt_len @@ -282,8 +283,9 @@ async def async_request_openai_completions( st = time.perf_counter() most_recent_timestamp = st try: - async with session.post(url=api_url, json=payload, - headers=headers) as response: + async with session.post( + url=api_url, json=payload, headers=headers + ) as response: if response.status == 200: first_chunk_received = False async for chunk_bytes in response.content: @@ -291,8 +293,7 @@ async def async_request_openai_completions( if not chunk_bytes: continue - chunk = chunk_bytes.decode("utf-8").removeprefix( - "data: ") + chunk = chunk_bytes.decode("utf-8").removeprefix("data: ") if chunk != "[DONE]": data = json.loads(chunk) @@ -312,21 +313,20 @@ async def async_request_openai_completions( # Decoding phase else: - output.itl.append(timestamp - - most_recent_timestamp) + output.itl.append(timestamp - most_recent_timestamp) most_recent_timestamp = timestamp generated_text += text or "" elif usage := data.get("usage"): - output.output_tokens = usage.get( - "completion_tokens") + output.output_tokens = usage.get("completion_tokens") if first_chunk_received: output.success = True else: output.success = False output.error = ( "Never received a valid chunk to calculate TTFT." - "This response will be marked as failed!") + "This response will be marked as failed!" + ) output.generated_text = generated_text output.latency = most_recent_timestamp - st else: @@ -347,23 +347,22 @@ async def async_request_openai_chat_completions( pbar: Optional[tqdm] = None, ) -> RequestFuncOutput: api_url = request_func_input.api_url - assert api_url.endswith( - ("chat/completions", "profile") - ), "OpenAI Chat Completions API URL must end with 'chat/completions'." + assert api_url.endswith(("chat/completions", "profile")), ( + "OpenAI Chat Completions API URL must end with 'chat/completions'." + ) - async with aiohttp.ClientSession(trust_env=True, - timeout=AIOHTTP_TIMEOUT) as session: + async with aiohttp.ClientSession( + trust_env=True, timeout=AIOHTTP_TIMEOUT + ) as session: content = [{"type": "text", "text": request_func_input.prompt}] if request_func_input.multi_modal_content: content.append(request_func_input.multi_modal_content) payload = { - "model": request_func_input.model_name \ - if request_func_input.model_name else request_func_input.model, + "model": request_func_input.model_name + if request_func_input.model_name + else request_func_input.model, "messages": [ - { - "role": "user", - "content": content - }, + {"role": "user", "content": content}, ], "temperature": 0.0, "max_completion_tokens": request_func_input.output_len, @@ -389,16 +388,16 @@ async def async_request_openai_chat_completions( st = time.perf_counter() most_recent_timestamp = st try: - async with session.post(url=api_url, json=payload, - headers=headers) as response: + async with session.post( + url=api_url, json=payload, headers=headers + ) as response: if response.status == 200: async for chunk_bytes in response.content: chunk_bytes = chunk_bytes.strip() if not chunk_bytes: continue - chunk = chunk_bytes.decode("utf-8").removeprefix( - "data: ") + chunk = chunk_bytes.decode("utf-8").removeprefix("data: ") if chunk != "[DONE]": timestamp = time.perf_counter() data = json.loads(chunk) @@ -412,13 +411,11 @@ async def async_request_openai_chat_completions( # Decoding phase else: - output.itl.append(timestamp - - most_recent_timestamp) + output.itl.append(timestamp - most_recent_timestamp) generated_text += content or "" elif usage := data.get("usage"): - output.output_tokens = usage.get( - "completion_tokens") + output.output_tokens = usage.get("completion_tokens") most_recent_timestamp = timestamp @@ -444,25 +441,28 @@ async def async_request_openai_audio( ) -> RequestFuncOutput: # Lazy import without PlaceholderModule to avoid vllm dep. import soundfile + api_url = request_func_input.api_url - assert api_url.endswith( - ("transcriptions", "translations" - )), "OpenAI Chat Completions API URL must end with 'transcriptions' " + assert api_url.endswith(("transcriptions", "translations")), ( + "OpenAI Chat Completions API URL must end with 'transcriptions' " + ) "or `translations`." - async with aiohttp.ClientSession(trust_env=True, - timeout=AIOHTTP_TIMEOUT) as session: + async with aiohttp.ClientSession( + trust_env=True, timeout=AIOHTTP_TIMEOUT + ) as session: content = [{"type": "text", "text": request_func_input.prompt}] payload = { - "model": request_func_input.model_name \ - if request_func_input.model_name else request_func_input.model, + "model": request_func_input.model_name + if request_func_input.model_name + else request_func_input.model, "temperature": 0.0, "max_completion_tokens": request_func_input.output_len, "stream": True, "language": "en", # Flattened due to multipart/form-data "stream_include_usage": True, - "stream_continuous_usage_stats": True + "stream_continuous_usage_stats": True, } if request_func_input.extra_body: payload.update(request_func_input.extra_body) @@ -477,9 +477,9 @@ async def async_request_openai_audio( buffer.seek(0) return buffer - with to_bytes(*request_func_input.multi_modal_content['audio']) as f: + with to_bytes(*request_func_input.multi_modal_content["audio"]) as f: form = aiohttp.FormData() - form.add_field('file', f, content_type='audio/wav') + form.add_field("file", f, content_type="audio/wav") for key, value in payload.items(): form.add_field(key, str(value)) @@ -491,24 +491,22 @@ async def async_request_openai_audio( st = time.perf_counter() most_recent_timestamp = st try: - async with session.post(url=api_url, - data=form, - headers=headers) as response: + async with session.post( + url=api_url, data=form, headers=headers + ) as response: if response.status == 200: async for chunk_bytes in response.content: chunk_bytes = chunk_bytes.strip() if not chunk_bytes: continue - chunk = chunk_bytes.decode("utf-8").removeprefix( - "data: ") + chunk = chunk_bytes.decode("utf-8").removeprefix("data: ") if chunk != "[DONE]": timestamp = time.perf_counter() data = json.loads(chunk) if choices := data.get("choices"): - content = choices[0]["delta"].get( - "content") + content = choices[0]["delta"].get("content") # First token if ttft == 0.0: ttft = timestamp - st @@ -517,12 +515,14 @@ async def async_request_openai_audio( # Decoding phase else: output.itl.append( - timestamp - most_recent_timestamp) + timestamp - most_recent_timestamp + ) generated_text += content or "" elif usage := data.get("usage"): output.output_tokens = usage.get( - "completion_tokens") + "completion_tokens" + ) most_recent_timestamp = timestamp @@ -543,7 +543,7 @@ async def async_request_openai_audio( def get_model(pretrained_model_name_or_path: str) -> str: - if os.getenv('VLLM_USE_MODELSCOPE', 'False').lower() == 'true': + if os.getenv("VLLM_USE_MODELSCOPE", "False").lower() == "true": from modelscope import snapshot_download from vllm.model_executor.model_loader.weight_utils import get_lock @@ -554,7 +554,8 @@ def get_model(pretrained_model_name_or_path: str) -> str: model_path = snapshot_download( model_id=pretrained_model_name_or_path, local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, - ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"]) + ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"], + ) return model_path return pretrained_model_name_or_path @@ -567,23 +568,23 @@ def get_tokenizer( **kwargs, ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: if pretrained_model_name_or_path is not None and not os.path.exists( - pretrained_model_name_or_path): - pretrained_model_name_or_path = get_model( - pretrained_model_name_or_path) + pretrained_model_name_or_path + ): + pretrained_model_name_or_path = get_model(pretrained_model_name_or_path) if tokenizer_mode == "slow": if kwargs.get("use_fast", False): - raise ValueError( - "Cannot use the fast tokenizer in slow tokenizer mode.") + raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.") kwargs["use_fast"] = False if tokenizer_mode == "mistral": try: from vllm.transformers_utils.tokenizer import MistralTokenizer except ImportError as e: - raise ImportError("MistralTokenizer requires vllm package.\n" - "Please install it with `pip install vllm` " - "to use mistral tokenizer mode.") from e - return MistralTokenizer.from_pretrained( - str(pretrained_model_name_or_path)) + raise ImportError( + "MistralTokenizer requires vllm package.\n" + "Please install it with `pip install vllm` " + "to use mistral tokenizer mode." + ) from e + return MistralTokenizer.from_pretrained(str(pretrained_model_name_or_path)) else: return AutoTokenizer.from_pretrained( pretrained_model_name_or_path, @@ -606,7 +607,7 @@ ASYNC_REQUEST_FUNCS = { } OPENAI_COMPATIBLE_BACKENDS = [ - k for k, v in ASYNC_REQUEST_FUNCS.items() - if v in (async_request_openai_completions, - async_request_openai_chat_completions) + k + for k, v in ASYNC_REQUEST_FUNCS.items() + if v in (async_request_openai_completions, async_request_openai_chat_completions) ] diff --git a/benchmarks/benchmark_dataset.py b/benchmarks/benchmark_dataset.py index ccbc6c022f1f935576810450d914e7848966be41..d8f48644cc00506b9b5eb4ef5768f9c8ef437b1c 100644 --- a/benchmarks/benchmark_dataset.py +++ b/benchmarks/benchmark_dataset.py @@ -82,14 +82,12 @@ class BenchmarkDataset(ABC): self.dataset_path = dataset_path # Set the random seed, ensuring that a None value is replaced with the # default seed. - self.random_seed = (random_seed - if random_seed is not None else self.DEFAULT_SEED) + self.random_seed = random_seed if random_seed is not None else self.DEFAULT_SEED self.data = None def apply_multimodal_chat_transformation( - self, - prompt: str, - mm_content: Optional[MultiModalDataDict] = None) -> list[dict]: + self, prompt: str, mm_content: Optional[MultiModalDataDict] = None + ) -> list[dict]: """ Transform a prompt and optional multimodal content into a chat format. This method is used for chat models that expect a specific conversation @@ -111,8 +109,7 @@ class BenchmarkDataset(ABC): NotImplementedError: If a subclass does not implement this method. """ # TODO (jenniferzhao): add support for downloading data - raise NotImplementedError( - "load_data must be implemented in subclasses.") + raise NotImplementedError("load_data must be implemented in subclasses.") def get_random_lora_request( self, @@ -158,8 +155,9 @@ class BenchmarkDataset(ABC): return lora_request, lora_tokenizer_cache[lora_id] or tokenizer @abstractmethod - def sample(self, tokenizer: PreTrainedTokenizerBase, - num_requests: int) -> list[SampleRequest]: + def sample( + self, tokenizer: PreTrainedTokenizerBase, num_requests: int + ) -> list[SampleRequest]: """ Abstract method to generate sample requests from the dataset. @@ -177,8 +175,9 @@ class BenchmarkDataset(ABC): """ raise NotImplementedError("sample must be implemented in subclasses.") - def maybe_oversample_requests(self, requests: list[SampleRequest], - num_requests: int) -> None: + def maybe_oversample_requests( + self, requests: list[SampleRequest], num_requests: int + ) -> None: """ Oversamples the list of requests if its size is less than the desired number. @@ -189,11 +188,9 @@ class BenchmarkDataset(ABC): """ if len(requests) < num_requests: random.seed(self.random_seed) - additional = random.choices(requests, - k=num_requests - len(requests)) + additional = random.choices(requests, k=num_requests - len(requests)) requests.extend(additional) - logger.info("Oversampled requests to reach %d total samples.", - num_requests) + logger.info("Oversampled requests to reach %d total samples.", num_requests) # ----------------------------------------------------------------------------- @@ -218,14 +215,14 @@ def is_valid_sequence( """ # Check for invalid conditions prompt_too_short = prompt_len < min_len - output_too_short = (not skip_min_output_len_check) and (output_len - < min_len) + output_too_short = (not skip_min_output_len_check) and (output_len < min_len) prompt_too_long = prompt_len > max_prompt_len combined_too_long = (prompt_len + output_len) > max_total_len # Return True if none of the invalid conditions are met - return not (prompt_too_short or output_too_short or prompt_too_long - or combined_too_long) + return not ( + prompt_too_short or output_too_short or prompt_too_long or combined_too_long + ) @cache @@ -257,28 +254,28 @@ def process_image(image: Any) -> Mapping[str, Any]: Raises: ValueError: If the input is not a supported type. """ - if isinstance(image, dict) and 'bytes' in image: - image = Image.open(BytesIO(image['bytes'])) + if isinstance(image, dict) and "bytes" in image: + image = Image.open(BytesIO(image["bytes"])) if isinstance(image, Image.Image): image = image.convert("RGB") with io.BytesIO() as image_data: image.save(image_data, format="JPEG") - image_base64 = base64.b64encode( - image_data.getvalue()).decode("utf-8") + image_base64 = base64.b64encode(image_data.getvalue()).decode("utf-8") return { "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{image_base64}" - }, + "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}, } if isinstance(image, str): - image_url = (image if image.startswith( - ("http://", "file://")) else f"file://{image}") + image_url = ( + image if image.startswith(("http://", "file://")) else f"file://{image}" + ) return {"type": "image_url", "image_url": {"url": image_url}} - raise ValueError(f"Invalid image input {image}. Must be a PIL.Image.Image" - " or str or dictionary with raw image bytes.") + raise ValueError( + f"Invalid image input {image}. Must be a PIL.Image.Image" + " or str or dictionary with raw image bytes." + ) # ----------------------------------------------------------------------------- @@ -315,42 +312,56 @@ class RandomDataset(BenchmarkDataset): ) vocab_size = tokenizer.vocab_size + num_special_tokens = tokenizer.num_special_tokens_to_add() + real_input_len = input_len - num_special_tokens - prefix_token_ids = (np.random.randint( - 0, vocab_size, size=prefix_len).tolist() if prefix_len > 0 else []) + prefix_token_ids = ( + np.random.randint(0, vocab_size, size=prefix_len).tolist() + if prefix_len > 0 + else [] + ) # New sampling logic: [X * (1 - b), X * (1 + b)] - input_low = int(input_len * (1 - range_ratio)) - input_high = int(input_len * (1 + range_ratio)) + input_low = int(real_input_len * (1 - range_ratio)) + input_high = int(real_input_len * (1 + range_ratio)) output_low = int(output_len * (1 - range_ratio)) output_high = int(output_len * (1 + range_ratio)) # Add logging for debugging logger.info("Sampling input_len from [%s, %s]", input_low, input_high) - logger.info("Sampling output_len from [%s, %s]", output_low, - output_high) - - input_lens = np.random.randint(input_low, - input_high + 1, - size=num_requests) - output_lens = np.random.randint(output_low, - output_high + 1, - size=num_requests) + logger.info("Sampling output_len from [%s, %s]", output_low, output_high) + + input_lens = np.random.randint(input_low, input_high + 1, size=num_requests) + output_lens = np.random.randint(output_low, output_high + 1, size=num_requests) offsets = np.random.randint(0, vocab_size, size=num_requests) requests = [] for i in range(num_requests): - inner_seq = ((offsets[i] + i + np.arange(input_lens[i])) % - vocab_size).tolist() + inner_seq = ( + (offsets[i] + i + np.arange(input_lens[i])) % vocab_size + ).tolist() token_sequence = prefix_token_ids + inner_seq prompt = tokenizer.decode(token_sequence) + # After decoding the prompt we have to encode and decode it again. + # This is done because in some cases N consecutive tokens + # give a string tokenized into != N number of tokens. + # For example for GPT2Tokenizer: + # [6880, 6881] -> ['Ġcalls', 'here'] -> + # [1650, 939, 486] -> ['Ġcall', 'sh', 'ere'] + # To avoid uncontrolled change of the prompt length, + # the encoded sequence is truncated before being decode again. + re_encoded_sequence = tokenizer.encode(prompt, add_special_tokens=False)[ + : input_lens[i] + ] + prompt = tokenizer.decode(re_encoded_sequence) total_input_len = prefix_len + int(input_lens[i]) requests.append( SampleRequest( prompt=prompt, prompt_len=total_input_len, expected_output_len=int(output_lens[i]), - )) + ) + ) return requests @@ -377,7 +388,8 @@ class ShareGPTDataset(BenchmarkDataset): self.data = json.load(f) # Filter entries with at least two conversation turns. self.data = [ - entry for entry in self.data + entry + for entry in self.data if "conversations" in entry and len(entry["conversations"]) >= 2 ] random.seed(self.random_seed) @@ -403,27 +415,28 @@ class ShareGPTDataset(BenchmarkDataset): ) lora_request, tokenizer = self.get_random_lora_request( - tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path) + tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path + ) prompt_ids = tokenizer(prompt).input_ids completion_ids = tokenizer(completion).input_ids prompt_len = len(prompt_ids) - new_output_len = (len(completion_ids) - if output_len is None else output_len) - if not is_valid_sequence(prompt_len, - new_output_len, - skip_min_output_len_check=output_len - is not None): + new_output_len = len(completion_ids) if output_len is None else output_len + if not is_valid_sequence( + prompt_len, + new_output_len, + skip_min_output_len_check=output_len is not None, + ): continue if enable_multimodal_chat: - prompt = self.apply_multimodal_chat_transformation( - prompt, None) + prompt = self.apply_multimodal_chat_transformation(prompt, None) samples.append( SampleRequest( prompt=prompt, prompt_len=prompt_len, expected_output_len=new_output_len, lora_request=lora_request, - )) + ) + ) self.maybe_oversample_requests(samples, num_requests) return samples @@ -469,20 +482,20 @@ class SonnetDataset(BenchmarkDataset): ) -> list: # Calculate average token length for a poem line. tokenized_lines = [tokenizer(line).input_ids for line in self.data] - avg_len = sum(len(tokens) - for tokens in tokenized_lines) / len(tokenized_lines) + avg_len = sum(len(tokens) for tokens in tokenized_lines) / len(tokenized_lines) # Build the base prompt. base_prompt = "Pick as many lines as you can from these poem lines:\n" base_msg = [{"role": "user", "content": base_prompt}] - base_fmt = tokenizer.apply_chat_template(base_msg, - add_generation_prompt=True, - tokenize=False) + base_fmt = tokenizer.apply_chat_template( + base_msg, add_generation_prompt=True, tokenize=False + ) base_offset = len(tokenizer(base_fmt).input_ids) if input_len <= base_offset: raise ValueError( f"'input_len' must be higher than the base prompt length " - f"({base_offset}).") + f"({base_offset})." + ) # Determine how many poem lines to use. num_input_lines = round((input_len - base_offset) / avg_len) @@ -491,21 +504,23 @@ class SonnetDataset(BenchmarkDataset): samples = [] while len(samples) < num_requests: - extra_lines = random.choices(self.data, - k=num_input_lines - num_prefix_lines) + extra_lines = random.choices( + self.data, k=num_input_lines - num_prefix_lines + ) prompt = f"{base_prompt}{''.join(prefix_lines + extra_lines)}" msg = [{"role": "user", "content": prompt}] prompt_formatted = tokenizer.apply_chat_template( - msg, add_generation_prompt=True, tokenize=False) + msg, add_generation_prompt=True, tokenize=False + ) prompt_len = len(tokenizer(prompt_formatted).input_ids) if prompt_len <= input_len: samples.append( SampleRequest( - prompt=prompt_formatted - if return_prompt_formatted else prompt, + prompt=prompt_formatted if return_prompt_formatted else prompt, prompt_len=prompt_len, expected_output_len=output_len, - )) + ) + ) return samples @@ -525,7 +540,9 @@ class BurstGPTDataset(BenchmarkDataset): super().__init__(**kwargs) self.load_data() - def load_data(self, ): + def load_data( + self, + ): if self.dataset_path is None: raise ValueError("dataset_path must be provided for loading data.") @@ -539,8 +556,7 @@ class BurstGPTDataset(BenchmarkDataset): def _sample_loaded_data(self, num_requests: int) -> list: if num_requests <= len(self.data): - data = self.data.sample(n=num_requests, - random_state=self.random_seed) + data = self.data.sample(n=num_requests, random_state=self.random_seed) else: data = self.data.sample( n=num_requests, @@ -564,7 +580,8 @@ class BurstGPTDataset(BenchmarkDataset): input_len = int(data[i][2]) output_len = int(data[i][3]) lora_req, tokenizer = self.get_random_lora_request( - tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path) + tokenizer=tokenizer, max_loras=max_loras, lora_path=lora_path + ) vocab_size = tokenizer.vocab_size # Generate a synthetic prompt: a list of token IDs computed as (i + # j) modulo vocab_size. @@ -576,7 +593,8 @@ class BurstGPTDataset(BenchmarkDataset): prompt_len=input_len, expected_output_len=output_len, lora_request=lora_req, - )) + ) + ) return samples @@ -619,20 +637,23 @@ class HuggingFaceDataset(BenchmarkDataset): class ConversationDataset(HuggingFaceDataset): """Dataset for conversation data with multimodal support.""" + SUPPORTED_DATASET_PATHS = { - 'lmms-lab/LLaVA-OneVision-Data', 'Aeala/ShareGPT_Vicuna_unfiltered' + "lmms-lab/LLaVA-OneVision-Data", + "Aeala/ShareGPT_Vicuna_unfiltered", } IS_MULTIMODAL = True - def sample(self, - tokenizer: PreTrainedTokenizerBase, - num_requests: int, - output_len: Optional[int] = None, - enable_multimodal_chat: bool = False, - **kwargs) -> list: + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + **kwargs, + ) -> list: # Filter examples with at least 2 conversations - filtered_data = self.data.filter( - lambda x: len(x["conversations"]) >= 2) + filtered_data = self.data.filter(lambda x: len(x["conversations"]) >= 2) sampled_requests = [] dynamic_output = output_len is None @@ -648,24 +669,22 @@ class ConversationDataset(HuggingFaceDataset): completion_len = len(completion_ids) output_len = completion_len if dynamic_output else output_len assert isinstance(output_len, int) and output_len > 0 - if dynamic_output and not is_valid_sequence( - prompt_len, completion_len): + if dynamic_output and not is_valid_sequence(prompt_len, completion_len): continue - mm_content = process_image( - item["image"]) if "image" in item else None + mm_content = process_image(item["image"]) if "image" in item else None if enable_multimodal_chat: # Note: when chat is enabled the request prompt_len is no longer # accurate and we will be using request output to count the # actual prompt len and output len - prompt = self.apply_multimodal_chat_transformation( - prompt, mm_content) + prompt = self.apply_multimodal_chat_transformation(prompt, mm_content) sampled_requests.append( SampleRequest( prompt=prompt, prompt_len=prompt_len, expected_output_len=output_len, multi_modal_data=mm_content, - )) + ) + ) self.maybe_oversample_requests(sampled_requests, num_requests) return sampled_requests @@ -682,10 +701,8 @@ class VisionArenaDataset(HuggingFaceDataset): DEFAULT_OUTPUT_LEN = 128 SUPPORTED_DATASET_PATHS = { - "lmarena-ai/VisionArena-Chat": - lambda x: x["conversation"][0][0]["content"], - "lmarena-ai/vision-arena-bench-v0.1": - lambda x: x["turns"][0][0]["content"] + "lmarena-ai/VisionArena-Chat": lambda x: x["conversation"][0][0]["content"], + "lmarena-ai/vision-arena-bench-v0.1": lambda x: x["turns"][0][0]["content"], } IS_MULTIMODAL = True @@ -697,16 +714,14 @@ class VisionArenaDataset(HuggingFaceDataset): enable_multimodal_chat: bool = False, **kwargs, ) -> list: - output_len = (output_len - if output_len is not None else self.DEFAULT_OUTPUT_LEN) + output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN sampled_requests = [] for item in self.data: if len(sampled_requests) >= num_requests: break parser_fn = self.SUPPORTED_DATASET_PATHS.get(self.dataset_path) if parser_fn is None: - raise ValueError( - f"Unsupported dataset path: {self.dataset_path}") + raise ValueError(f"Unsupported dataset path: {self.dataset_path}") prompt = parser_fn(item) mm_content = process_image(item["images"][0]) prompt_len = len(tokenizer(prompt).input_ids) @@ -714,15 +729,15 @@ class VisionArenaDataset(HuggingFaceDataset): # Note: when chat is enabled the request prompt_len is no longer # accurate and we will be using request output to count the # actual prompt len - prompt = self.apply_multimodal_chat_transformation( - prompt, mm_content) + prompt = self.apply_multimodal_chat_transformation(prompt, mm_content) sampled_requests.append( SampleRequest( prompt=prompt, prompt_len=prompt_len, expected_output_len=output_len, multi_modal_data=mm_content, - )) + ) + ) self.maybe_oversample_requests(sampled_requests, num_requests) return sampled_requests @@ -747,14 +762,15 @@ class InstructCoderDataset(HuggingFaceDataset): "likaixin/InstructCoder", } - def sample(self, - tokenizer: PreTrainedTokenizerBase, - num_requests: int, - output_len: Optional[int] = None, - enable_multimodal_chat: bool = False, - **kwargs) -> list: - output_len = (output_len - if output_len is not None else self.DEFAULT_OUTPUT_LEN) + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + **kwargs, + ) -> list: + output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN sampled_requests = [] for item in self.data: if len(sampled_requests) >= num_requests: @@ -766,7 +782,63 @@ class InstructCoderDataset(HuggingFaceDataset): prompt=prompt, prompt_len=prompt_len, expected_output_len=output_len, - )) + ) + ) + self.maybe_oversample_requests(sampled_requests, num_requests) + return sampled_requests + + +# ----------------------------------------------------------------------------- +# MT-Bench Dataset Implementation +# ----------------------------------------------------------------------------- + + +class MTBenchDataset(HuggingFaceDataset): + """ + MT-Bench Dataset. + https://huggingface.co/datasets/philschmid/mt-bench + + We create a single turn dataset for MT-Bench. + This is similar to Spec decoding benchmark setup in vLLM + https://github.com/vllm-project/vllm/blob/9d98ab5ec/examples/offline_inference/eagle.py#L14-L18 + """ # noqa: E501 + + DEFAULT_OUTPUT_LEN = 256 # avg len used in SD bench in vLLM + SUPPORTED_DATASET_PATHS = { + "philschmid/mt-bench", + } + + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + enable_multimodal_chat: bool = False, + **kwargs, + ) -> list: + output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN + sampled_requests = [] + + for item in self.data: + if len(sampled_requests) >= num_requests: + break + prompt = item["turns"][0] + + # apply template + prompt = tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + add_generation_prompt=True, + tokenize=False, + ) + + prompt_len = len(tokenizer(prompt).input_ids) + sampled_requests.append( + SampleRequest( + prompt=prompt, + prompt_len=prompt_len, + expected_output_len=output_len, + ) + ) self.maybe_oversample_requests(sampled_requests, num_requests) return sampled_requests @@ -780,23 +852,27 @@ class AIMODataset(HuggingFaceDataset): """ Dataset class for processing a AIMO dataset with reasoning questions. """ + SUPPORTED_DATASET_PATHS = { - "AI-MO/aimo-validation-aime", "AI-MO/NuminaMath-1.5", - "AI-MO/NuminaMath-CoT" + "AI-MO/aimo-validation-aime", + "AI-MO/NuminaMath-1.5", + "AI-MO/NuminaMath-CoT", } - def sample(self, - tokenizer: PreTrainedTokenizerBase, - num_requests: int, - output_len: Optional[int] = None, - **kwargs) -> list: + def sample( + self, + tokenizer: PreTrainedTokenizerBase, + num_requests: int, + output_len: Optional[int] = None, + **kwargs, + ) -> list: sampled_requests = [] dynamic_output = output_len is None for item in self.data: if len(sampled_requests) >= num_requests: break - prompt, completion = item['problem'], item["solution"] + prompt, completion = item["problem"], item["solution"] prompt_ids = tokenizer(prompt).input_ids completion_ids = tokenizer(completion).input_ids @@ -804,10 +880,9 @@ class AIMODataset(HuggingFaceDataset): completion_len = len(completion_ids) output_len = completion_len if dynamic_output else output_len assert isinstance(output_len, int) and output_len > 0 - if dynamic_output and not is_valid_sequence(prompt_len, - completion_len, - max_prompt_len=2048, - max_total_len=32000): + if dynamic_output and not is_valid_sequence( + prompt_len, completion_len, max_prompt_len=2048, max_total_len=32000 + ): continue sampled_requests.append( SampleRequest( @@ -815,11 +890,100 @@ class AIMODataset(HuggingFaceDataset): prompt_len=prompt_len, expected_output_len=output_len, multi_modal_data=None, - )) + ) + ) self.maybe_oversample_requests(sampled_requests, num_requests) return sampled_requests +# ----------------------------------------------------------------------------- +# Next Edit Prediction Dataset Implementation +# ----------------------------------------------------------------------------- + + +zeta_prompt = """### Instruction: +You are a code completion assistant and your task is to analyze user edits and then rewrite an excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking into account the cursor location. + +### User Edits: + +{} + +### User Excerpt: + +{} + +### Response: + +""" # noqa: E501 + + +def _format_zeta_prompt( + sample: dict, original_start_marker: str = "<|editable_region_start|>" +) -> dict: + """Format the zeta prompt for the Next Edit Prediction (NEP) dataset. + + This function formats examples from the NEP dataset + into prompts and expected outputs. It could be + further extended to support more NEP datasets. + + Args: + sample: The dataset sample containing events, + inputs, and outputs. + original_start_marker: The marker indicating the + start of the editable region. Defaults to + "<|editable_region_start|>". + + Returns: + A dictionary with the formatted prompts and expected outputs. + """ + events = sample["events"] + input = sample["input"] + output = sample["output"] + prompt = zeta_prompt.format(events, input) + + # following the original implementation, extract the focused region + # from the raw output + output_start_index = output.find(original_start_marker) + output_focused_region = output[output_start_index:] + expected_output = output_focused_region + + return {"prompt": prompt, "expected_output": expected_output} + + +class NextEditPredictionDataset(HuggingFaceDataset): + """ + Dataset class for processing a Next Edit Prediction dataset. + """ + + SUPPORTED_DATASET_PATHS = { + "zed-industries/zeta", + } + MAPPING_PROMPT_FUNCS = { + "zed-industries/zeta": _format_zeta_prompt, + } + + def sample(self, tokenizer: PreTrainedTokenizerBase, num_requests: int, **kwargs): + formatting_prompt_func = self.MAPPING_PROMPT_FUNCS.get(self.dataset_path) + if formatting_prompt_func is None: + raise ValueError(f"Unsupported dataset path: {self.dataset_path}") + samples = [] + for sample in self.data: + sample = formatting_prompt_func(sample) + samples.append( + SampleRequest( + prompt=sample["prompt"], + prompt_len=len(tokenizer(sample["prompt"]).input_ids), + expected_output_len=len( + tokenizer(sample["expected_output"]).input_ids + ), + ) + ) + if len(samples) >= num_requests: + break + self.maybe_oversample_requests(samples, num_requests) + return samples + + # ----------------------------------------------------------------------------- # ASR Dataset Implementation # ----------------------------------------------------------------------------- @@ -842,18 +1006,22 @@ class ASRDataset(HuggingFaceDataset): | AMI | Meetings | Spontaneous | ihm, sdm | +----------------+----------------------------------------+--------------------------+-----------------------------+ - """ # noqa: E501 + """ # noqa: E501 + SUPPORTED_DATASET_PATHS = { - "openslr/librispeech_asr", "facebook/voxpopuli", "LIUM/tedlium", - "edinburghcstr/ami", "speechcolab/gigaspeech", "kensho/spgispeech" + "openslr/librispeech_asr", + "facebook/voxpopuli", + "LIUM/tedlium", + "edinburghcstr/ami", + "speechcolab/gigaspeech", + "kensho/spgispeech", } DEFAULT_OUTPUT_LEN = 128 IS_MULTIMODAL = True # TODO Whisper-specific. Abstract interface when more models are supported. - TRANSCRIPTION_PREAMBLE = "<|startoftranscript|><|en|><|transcribe|>"\ - "<|notimestamps|>" + TRANSCRIPTION_PREAMBLE = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>" skip_long_audios: bool = True def sample( @@ -864,8 +1032,8 @@ class ASRDataset(HuggingFaceDataset): **kwargs, ) -> list: import librosa - output_len = (output_len - if output_len is not None else self.DEFAULT_OUTPUT_LEN) + + output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN prompt = ASRDataset.TRANSCRIPTION_PREAMBLE prompt_len = len(tokenizer(prompt).input_ids) sampled_requests = [] @@ -888,10 +1056,14 @@ class ASRDataset(HuggingFaceDataset): prompt_len=prompt_len, expected_output_len=output_len, multi_modal_data=mm_content, - )) + ) + ) if skipped: - logger.warning("%d samples discarded from dataset due to" \ - " their length being greater than" \ - " what Whisper supports.", skipped) + logger.warning( + "%d samples discarded from dataset due to" + " their length being greater than" + " what Whisper supports.", + skipped, + ) self.maybe_oversample_requests(sampled_requests, num_requests) return sampled_requests diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index dfd9bb1e6a4d0c4a8961e3d29510a2bea4dfcd78..d5aaceeb8c9c37e95d705710e2d4876cf5133ed9 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -11,9 +11,9 @@ from typing import Any, Optional import numpy as np import torch -from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json from tqdm import tqdm +from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json from vllm import LLM, SamplingParams from vllm.engine.arg_utils import EngineArgs from vllm.inputs import PromptType @@ -21,13 +21,14 @@ from vllm.sampling_params import BeamSearchParams from vllm.utils import FlexibleArgumentParser -def save_to_pytorch_benchmark_format(args: argparse.Namespace, - results: dict[str, Any]) -> None: +def save_to_pytorch_benchmark_format( + args: argparse.Namespace, results: dict[str, Any] +) -> None: pt_records = convert_to_pytorch_benchmark_format( args=args, metrics={"latency": results["latencies"]}, - extra_info={k: results[k] - for k in ["avg_latency", "percentiles"]}) + extra_info={k: results[k] for k in ["avg_latency", "percentiles"]}, + ) if pt_records: pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json" write_to_json(pt_file, pt_records) @@ -42,9 +43,11 @@ def main(args: argparse.Namespace): # the engine will automatically process the request in multiple batches. llm = LLM(**dataclasses.asdict(engine_args)) assert llm.llm_engine.model_config.max_model_len >= ( - args.input_len + - args.output_len), ("Please ensure that max_model_len is greater than" - " the sum of input_len and output_len.") + args.input_len + args.output_len + ), ( + "Please ensure that max_model_len is greater than" + " the sum of input_len and output_len." + ) sampling_params = SamplingParams( n=args.n, @@ -55,18 +58,16 @@ def main(args: argparse.Namespace): detokenize=not args.disable_detokenize, ) print(sampling_params) - dummy_prompt_token_ids = np.random.randint(10000, - size=(args.batch_size, - args.input_len)) - dummy_prompts: list[PromptType] = [{ - "prompt_token_ids": batch - } for batch in dummy_prompt_token_ids.tolist()] + dummy_prompt_token_ids = np.random.randint( + 10000, size=(args.batch_size, args.input_len) + ) + dummy_prompts: list[PromptType] = [ + {"prompt_token_ids": batch} for batch in dummy_prompt_token_ids.tolist() + ] def llm_generate(): if not args.use_beam_search: - llm.generate(dummy_prompts, - sampling_params=sampling_params, - use_tqdm=False) + llm.generate(dummy_prompts, sampling_params=sampling_params, use_tqdm=False) else: llm.beam_search( dummy_prompts, @@ -80,12 +81,13 @@ def main(args: argparse.Namespace): def run_to_completion(profile_dir: Optional[str] = None): if profile_dir: with torch.profiler.profile( - activities=[ - torch.profiler.ProfilerActivity.CPU, - torch.profiler.ProfilerActivity.CUDA, - ], - on_trace_ready=torch.profiler.tensorboard_trace_handler( - str(profile_dir)), + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + on_trace_ready=torch.profiler.tensorboard_trace_handler( + str(profile_dir) + ), ) as p: llm_generate() print(p.key_averages().table(sort_by="self_cuda_time_total")) @@ -103,8 +105,9 @@ def main(args: argparse.Namespace): if args.profile: profile_dir = args.profile_result_dir if not profile_dir: - profile_dir = (Path(".") / "vllm_benchmark_result" / - f"latency_result_{time.time()}") + profile_dir = ( + Path(".") / "vllm_benchmark_result" / f"latency_result_{time.time()}" + ) print(f"Profiling (results will be saved to '{profile_dir}')...") run_to_completion(profile_dir=profile_dir) return @@ -135,7 +138,8 @@ def main(args: argparse.Namespace): if __name__ == "__main__": parser = FlexibleArgumentParser( description="Benchmark the latency of processing a single batch of " - "requests till completion.") + "requests till completion." + ) parser.add_argument("--input-len", type=int, default=32) parser.add_argument("--output-len", type=int, default=128) parser.add_argument("--batch-size", type=int, default=8) @@ -152,10 +156,9 @@ if __name__ == "__main__": default=10, help="Number of iterations to run for warmup.", ) - parser.add_argument("--num-iters", - type=int, - default=30, - help="Number of iterations to run.") + parser.add_argument( + "--num-iters", type=int, default=30, help="Number of iterations to run." + ) parser.add_argument( "--profile", action="store_true", @@ -165,8 +168,10 @@ if __name__ == "__main__": "--profile-result-dir", type=str, default=None, - help=("path to save the pytorch profiler output. Can be visualized " - "with ui.perfetto.dev or Tensorboard."), + help=( + "path to save the pytorch profiler output. Can be visualized " + "with ui.perfetto.dev or Tensorboard." + ), ) parser.add_argument( "--output-json", @@ -177,8 +182,10 @@ if __name__ == "__main__": parser.add_argument( "--disable-detokenize", action="store_true", - help=("Do not detokenize responses (i.e. do not include " - "detokenization time in the latency measurement)"), + help=( + "Do not detokenize responses (i.e. do not include " + "detokenization time in the latency measurement)" + ), ) parser = EngineArgs.add_cli_args(parser) diff --git a/benchmarks/benchmark_long_document_qa_throughput.py b/benchmarks/benchmark_long_document_qa_throughput.py index 21480578edbd5212b9b362ee4d2fa336f44fccef..109624c877891c87d042abdbc5785e61b97a0924 100644 --- a/benchmarks/benchmark_long_document_qa_throughput.py +++ b/benchmarks/benchmark_long_document_qa_throughput.py @@ -76,7 +76,7 @@ def repeat_prompts(prompts, repeat_count, mode: str): - 'random': Shuffle the prompts randomly after repetition. - 'tile': Repeat the entire prompt list in sequence. Example: [1, 2, 3] -> [1, 2, 3, 1, 2, 3]. - - 'interleave': Repeat each prompt consecutively before moving to + - 'interleave': Repeat each prompt consecutively before moving to the next. Example: [1, 2, 3] -> [1, 1, 2, 2, 3, 3]. Returns: @@ -86,20 +86,21 @@ def repeat_prompts(prompts, repeat_count, mode: str): ValueError: If an invalid mode is provided. """ print("Repeat mode: ", mode) - if mode == 'random': + if mode == "random": repeated_prompts = prompts * repeat_count random.shuffle(repeated_prompts) return repeated_prompts - elif mode == 'tile': + elif mode == "tile": return prompts * repeat_count - elif mode == 'interleave': + elif mode == "interleave": repeated_prompts = [] for prompt in prompts: repeated_prompts.extend([prompt] * repeat_count) return repeated_prompts else: - raise ValueError(f"Invalid mode: {mode}, only support " - "'random', 'tile', 'interleave'") + raise ValueError( + f"Invalid mode: {mode}, only support 'random', 'tile', 'interleave'" + ) def main(args): @@ -109,16 +110,16 @@ def main(args): # we append the document id at the beginning to avoid any of the document # being the prefix of other documents prompts = [ - str(i) + ' '.join(['hi'] * args.document_length) + str(i) + " ".join(["hi"] * args.document_length) for i in range(args.num_documents) ] prompts = repeat_prompts(prompts, args.repeat_count, mode=args.repeat_mode) warmup_prompts = [ - "This is warm up request " + str(i) + \ - ' '.join(['hi'] * args.document_length) - for i in range(args.num_documents)] + "This is warm up request " + str(i) + " ".join(["hi"] * args.document_length) + for i in range(args.num_documents) + ] # Create the LLM engine engine_args = EngineArgs.from_cli_args(args) @@ -142,42 +143,52 @@ def main(args): if __name__ == "__main__": parser = FlexibleArgumentParser( - description= - 'Benchmark the performance with or without automatic prefix caching.') + description="Benchmark the performance with or " + "without automatic prefix caching." + ) parser.add_argument( - '--document-length', + "--document-length", type=int, # Roughly the number of tokens for a system paper, # excluding images default=20000, - help='Range of input lengths for sampling prompts,' - 'specified as "min:max" (e.g., "128:256").') - - parser.add_argument('--num-documents', - type=int, - default=8, - help='Range of input lengths for sampling prompts,' - 'specified as "min:max" (e.g., "128:256").') - - parser.add_argument('--output-len', type=int, default=10) - - parser.add_argument('--repeat-count', - type=int, - default=2, - help='Number of times to repeat each prompt') - - parser.add_argument("--repeat-mode", - type=str, - default='random', - help='The mode to repeat prompts. The supported ' - 'modes are "random", "tile", and "interleave". ' - 'See repeat_prompts() in the source code for details.') - - parser.add_argument("--shuffle-seed", - type=int, - default=0, - help='Random seed when the repeat mode is "random"') + help="Range of input lengths for sampling prompts, " + 'specified as "min:max" (e.g., "128:256").', + ) + + parser.add_argument( + "--num-documents", + type=int, + default=8, + help="Range of input lengths for sampling prompts, " + 'specified as "min:max" (e.g., "128:256").', + ) + + parser.add_argument("--output-len", type=int, default=10) + + parser.add_argument( + "--repeat-count", + type=int, + default=2, + help="Number of times to repeat each prompt", + ) + + parser.add_argument( + "--repeat-mode", + type=str, + default="random", + help="The mode to repeat prompts. The supported " + 'modes are "random", "tile", and "interleave". ' + "See repeat_prompts() in the source code for details.", + ) + + parser.add_argument( + "--shuffle-seed", + type=int, + default=0, + help='Random seed when the repeat mode is "random"', + ) parser = EngineArgs.add_cli_args(parser) args = parser.parse_args() diff --git a/benchmarks/benchmark_prefix_caching.py b/benchmarks/benchmark_prefix_caching.py index f44da95d3216fcaf6d77636ae4339239d3af51a1..ffaa8035797c10accebde3ab6e35356bd6db6292 100644 --- a/benchmarks/benchmark_prefix_caching.py +++ b/benchmarks/benchmark_prefix_caching.py @@ -63,8 +63,7 @@ class Request: output_len: int -def sample_tokens(tokenizer: PreTrainedTokenizerBase, - length: int) -> list[int]: +def sample_tokens(tokenizer: PreTrainedTokenizerBase, length: int) -> list[int]: vocab = tokenizer.get_vocab() all_special_ids = set(tokenizer.all_special_ids) @@ -91,8 +90,10 @@ def sample_requests_from_dataset( # Filter out the conversations with less than 2 turns. dataset = [data for data in dataset if len(data["conversations"]) >= 2] # Only keep the first two turns of each conversation. - dataset = [(data["conversations"][0]["value"], - data["conversations"][1]["value"]) for data in dataset] + dataset = [ + (data["conversations"][0]["value"], data["conversations"][1]["value"]) + for data in dataset + ] # Shuffle the dataset. random.shuffle(dataset) @@ -113,8 +114,9 @@ def sample_requests_from_dataset( completion = dataset[i][1] completion_token_ids = tokenizer(completion).input_ids prompt_len = len(prompt_token_ids) - output_len = (len(completion_token_ids) - if fixed_output_len is None else fixed_output_len) + output_len = ( + len(completion_token_ids) if fixed_output_len is None else fixed_output_len + ) if min_len <= prompt_len <= max_len: filtered_requests.append(Request(prompt, prompt_len, output_len)) @@ -128,27 +130,27 @@ def sample_requests_from_random( fixed_output_len: Optional[int], prefix_len: int, ) -> list[Request]: - requests = [] prefix_token_ids = sample_tokens(tokenizer, prefix_len) min_len, max_len = input_length_range for i in range(num_requests): unique_part_token_ids = sample_tokens( - tokenizer, - random.randint(min_len - prefix_len, max_len - prefix_len)) + tokenizer, random.randint(min_len - prefix_len, max_len - prefix_len) + ) prompt_token_ids = prefix_token_ids + unique_part_token_ids prompt = tokenizer.decode(prompt_token_ids) prompt_len = len(prompt_token_ids) - assert (min_len <= prompt_len <= max_len - ), f"prompt_len {prompt_len} out of range {min_len}:{max_len}" + assert min_len <= prompt_len <= max_len, ( + f"prompt_len {prompt_len} out of range {min_len}:{max_len}" + ) requests.append(Request(prompt, prompt_len, fixed_output_len)) return requests -def repeat_and_sort_requests(requests: list[Request], - repeat_count: int, - sort: bool = False) -> list[str]: +def repeat_and_sort_requests( + requests: list[Request], repeat_count: int, sort: bool = False +) -> list[str]: repeated_requests = requests * repeat_count if sort: repeated_requests.sort(key=lambda x: x[1]) @@ -159,14 +161,14 @@ def repeat_and_sort_requests(requests: list[Request], def main(args): tokenizer = get_tokenizer(args.model, trust_remote_code=True) - input_length_range = tuple(map(int, args.input_length_range.split(':'))) + input_length_range = tuple(map(int, args.input_length_range.split(":"))) random.seed(args.seed) if args.dataset_path is not None: if args.prefix_len > 0: - raise ValueError("prefix-len is not supported when " - "dataset-path is provided.") - print(f"Start to sample {args.num_prompts} prompts " - f"from {args.dataset_path}") + raise ValueError( + "prefix-len is not supported when dataset-path is provided." + ) + print(f"Start to sample {args.num_prompts} prompts from {args.dataset_path}") filtered_requests = sample_requests_from_dataset( dataset_path=args.dataset_path, num_requests=args.num_prompts, @@ -196,14 +198,16 @@ def main(args): llm = LLM(**dataclasses.asdict(engine_args)) - sampling_params = SamplingParams(temperature=0, - max_tokens=args.output_len, - detokenize=not args.disable_detokenize) + sampling_params = SamplingParams( + temperature=0, + max_tokens=args.output_len, + detokenize=not args.disable_detokenize, + ) print("Testing filtered requests") - prompts = repeat_and_sort_requests(filtered_requests, - repeat_count=args.repeat_count, - sort=args.sort) + prompts = repeat_and_sort_requests( + filtered_requests, repeat_count=args.repeat_count, sort=args.sort + ) print("------start generating------") test_prefix( @@ -215,29 +219,35 @@ def main(args): if __name__ == "__main__": parser = FlexibleArgumentParser( - description= - 'Benchmark the performance with or without automatic prefix caching.') - parser.add_argument("--dataset-path", - type=str, - default=None, - help="Path to the dataset.") - parser.add_argument('--output-len', type=int, default=10) - parser.add_argument('--num-prompts', - type=int, - required=True, - help="Number of the prompts sampled from dataset") - parser.add_argument('--repeat-count', - type=int, - default=1, - help='Number of times to repeat each prompt') - parser.add_argument('--sort', - action='store_true', - help='Sort prompts by input length') - parser.add_argument('--input-length-range', - type=str, - required=True, - help='Range of input lengths for sampling prompts,' - 'specified as "min:max" (e.g., "128:256").') + description="Benchmark the performance with or without " + "automatic prefix caching." + ) + parser.add_argument( + "--dataset-path", type=str, default=None, help="Path to the dataset." + ) + parser.add_argument("--output-len", type=int, default=10) + parser.add_argument( + "--num-prompts", + type=int, + required=True, + help="Number of the prompts sampled from dataset", + ) + parser.add_argument( + "--repeat-count", + type=int, + default=1, + help="Number of times to repeat each prompt", + ) + parser.add_argument( + "--sort", action="store_true", help="Sort prompts by input length" + ) + parser.add_argument( + "--input-length-range", + type=str, + required=True, + help="Range of input lengths for sampling prompts," + 'specified as "min:max" (e.g., "128:256").', + ) parser.add_argument( "--prefix-len", type=int, @@ -248,10 +258,12 @@ if __name__ == "__main__": "when dataset-path is not provided.", ) parser.add_argument( - '--disable-detokenize', - action='store_true', - help=("Do not detokenize responses (i.e. do not include " - "detokenization time in the latency measurement)"), + "--disable-detokenize", + action="store_true", + help=( + "Do not detokenize responses (i.e. do not include " + "detokenization time in the latency measurement)" + ), ) parser = EngineArgs.add_cli_args(parser) diff --git a/benchmarks/benchmark_prioritization.py b/benchmarks/benchmark_prioritization.py index 76fe00ede249b9e0695df289d50a3d76c6dfcb70..a05dd24dece83d5c7a12130f56d4a82d5eb377f4 100644 --- a/benchmarks/benchmark_prioritization.py +++ b/benchmarks/benchmark_prioritization.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 """Benchmark offline prioritization.""" + import argparse import dataclasses import json @@ -13,7 +14,7 @@ from vllm.engine.arg_utils import EngineArgs from vllm.utils import FlexibleArgumentParser -#Select a equi-probable random priority +# Select a equi-probable random priority def get_random_flag(): return 0 if random.random() < 0.5 else 1 @@ -33,8 +34,10 @@ def sample_requests( # Filter out the conversations with less than 2 turns. dataset = [data for data in dataset if len(data["conversations"]) >= 2] # Only keep the first two turns of each conversation. - dataset = [(data["conversations"][0]["value"], - data["conversations"][1]["value"]) for data in dataset] + dataset = [ + (data["conversations"][0]["value"], data["conversations"][1]["value"]) + for data in dataset + ] # Shuffle the dataset. random.shuffle(dataset) @@ -51,8 +54,9 @@ def sample_requests( completion = dataset[i][1] completion_token_ids = tokenizer(completion).input_ids prompt_len = len(prompt_token_ids) - output_len = len(completion_token_ids - ) if fixed_output_len is None else fixed_output_len + output_len = ( + len(completion_token_ids) if fixed_output_len is None else fixed_output_len + ) if prompt_len < 4 or output_len < 4: # Prune too short sequences. continue @@ -74,13 +78,16 @@ def run_vllm( disable_detokenize: bool = False, ) -> float: from vllm import LLM, SamplingParams + llm = LLM(**dataclasses.asdict(engine_args)) assert all( llm.llm_engine.model_config.max_model_len >= (request[1] + request[2]) - for request in requests), ( - "Please ensure that max_model_len is greater than the sum of" - " input_len and output_len for all requests.") + for request in requests + ), ( + "Please ensure that max_model_len is greater than the sum of" + " input_len and output_len for all requests." + ) # Add the requests to the engine. prompts = [] @@ -97,7 +104,8 @@ def run_vllm( ignore_eos=True, max_tokens=output_len, detokenize=not disable_detokenize, - )) + ) + ) start = time.perf_counter() llm.generate(prompts, sampling_params, priority=priority, use_tqdm=True) @@ -111,26 +119,33 @@ def main(args: argparse.Namespace): # Sample the requests. tokenizer = AutoTokenizer.from_pretrained( - args.tokenizer, trust_remote_code=args.trust_remote_code) + args.tokenizer, trust_remote_code=args.trust_remote_code + ) if args.dataset is None: # Synthesize a prompt with the given input length. prompt = "hi" * (args.input_len - 1) - requests = [(prompt, args.input_len, args.output_len, - get_random_flag()) for _ in range(args.num_prompts)] + requests = [ + (prompt, args.input_len, args.output_len, get_random_flag()) + for _ in range(args.num_prompts) + ] else: - requests = sample_requests(args.dataset, args.num_prompts, tokenizer, - args.output_len) + requests = sample_requests( + args.dataset, args.num_prompts, tokenizer, args.output_len + ) if args.backend == "vllm": - elapsed_time = run_vllm(requests, args.n, - EngineArgs.from_cli_args(args), - args.disable_detokenize) + elapsed_time = run_vllm( + requests, args.n, EngineArgs.from_cli_args(args), args.disable_detokenize + ) else: raise ValueError(f"Unknown backend: {args.backend}") - total_num_tokens = sum(prompt_len + output_len - for _, prompt_len, output_len, priority in requests) - print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " - f"{total_num_tokens / elapsed_time:.2f} tokens/s") + total_num_tokens = sum( + prompt_len + output_len for _, prompt_len, output_len, priority in requests + ) + print( + f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " + f"{total_num_tokens / elapsed_time:.2f} tokens/s" + ) # Output JSON results if specified if args.output_json: @@ -147,41 +162,44 @@ def main(args: argparse.Namespace): if __name__ == "__main__": parser = FlexibleArgumentParser(description="Benchmark the throughput.") - parser.add_argument("--backend", - type=str, - choices=["vllm", "hf", "mii"], - default="vllm") - parser.add_argument("--dataset", - type=str, - default=None, - help="Path to the dataset.") - parser.add_argument("--input-len", - type=int, - default=None, - help="Input prompt length for each request") - parser.add_argument("--output-len", - type=int, - default=None, - help="Output length for each request. Overrides the " - "output length from the dataset.") - parser.add_argument("--n", - type=int, - default=1, - help="Number of generated sequences per prompt.") - parser.add_argument("--num-prompts", - type=int, - default=200, - help="Number of prompts to process.") parser.add_argument( - '--output-json', + "--backend", type=str, choices=["vllm", "hf", "mii"], default="vllm" + ) + parser.add_argument( + "--dataset", type=str, default=None, help="Path to the dataset." + ) + parser.add_argument( + "--input-len", + type=int, + default=None, + help="Input prompt length for each request", + ) + parser.add_argument( + "--output-len", + type=int, + default=None, + help="Output length for each request. Overrides the " + "output length from the dataset.", + ) + parser.add_argument( + "--n", type=int, default=1, help="Number of generated sequences per prompt." + ) + parser.add_argument( + "--num-prompts", type=int, default=200, help="Number of prompts to process." + ) + parser.add_argument( + "--output-json", type=str, default=None, - help='Path to save the throughput results in JSON format.') + help="Path to save the throughput results in JSON format.", + ) parser.add_argument( - '--disable-detokenize', - action='store_true', - help=("Do not detokenize responses (i.e. do not include " - "detokenization time in the latency measurement)"), + "--disable-detokenize", + action="store_true", + help=( + "Do not detokenize responses (i.e. do not include " + "detokenization time in the latency measurement)" + ), ) parser = EngineArgs.add_cli_args(parser) diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index da124e1a81b487c79b26f4746aada51200c13a8e..a887e7150dc78ad1f5ca03951b5482d170a36f54 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -20,6 +20,7 @@ On the client side, run: --endpoint /generate_stream to the end of the command above. """ + import argparse import asyncio import gc @@ -34,12 +35,16 @@ from datetime import datetime from typing import Any, Optional import numpy as np -from backend_request_func import (ASYNC_REQUEST_FUNCS, - OPENAI_COMPATIBLE_BACKENDS, RequestFuncInput, - RequestFuncOutput) from tqdm.asyncio import tqdm from transformers import PreTrainedTokenizerBase +from backend_request_func import ( + ASYNC_REQUEST_FUNCS, + OPENAI_COMPATIBLE_BACKENDS, + RequestFuncInput, + RequestFuncOutput, +) + try: from vllm.transformers_utils.tokenizer import get_tokenizer except ImportError: @@ -50,11 +55,21 @@ try: except ImportError: from argparse import ArgumentParser as FlexibleArgumentParser -from benchmark_dataset import (AIMODataset, ASRDataset, BurstGPTDataset, - ConversationDataset, HuggingFaceDataset, - InstructCoderDataset, RandomDataset, - SampleRequest, ShareGPTDataset, SonnetDataset, - VisionArenaDataset) +from benchmark_dataset import ( + AIMODataset, + ASRDataset, + BurstGPTDataset, + ConversationDataset, + HuggingFaceDataset, + InstructCoderDataset, + MTBenchDataset, + NextEditPredictionDataset, + RandomDataset, + SampleRequest, + ShareGPTDataset, + SonnetDataset, + VisionArenaDataset, +) from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json MILLISECONDS_TO_SECONDS_CONVERSION = 1000 @@ -117,7 +132,8 @@ async def get_request( # Calculate scale parameter theta to maintain the desired request_rate. assert burstiness > 0, ( - f"A positive burstiness factor is expected, but given {burstiness}.") + f"A positive burstiness factor is expected, but given {burstiness}." + ) theta = 1.0 / (request_rate * burstiness) for request in input_requests: @@ -163,8 +179,10 @@ def calculate_metrics( # bundled together # Note : this may inflate the output token count slightly output_len = len( - tokenizer(outputs[i].generated_text, - add_special_tokens=False).input_ids) + tokenizer( + outputs[i].generated_text, add_special_tokens=False + ).input_ids + ) actual_output_lens.append(output_len) total_input += input_requests[i].prompt_len tpot = 0 @@ -187,16 +205,19 @@ def calculate_metrics( if "ttft" in goodput_config_dict: valid_metrics.append(ttfts) - slo_values.append(goodput_config_dict["ttft"] / - MILLISECONDS_TO_SECONDS_CONVERSION) + slo_values.append( + goodput_config_dict["ttft"] / MILLISECONDS_TO_SECONDS_CONVERSION + ) if "tpot" in goodput_config_dict: valid_metrics.append(all_tpots) - slo_values.append(goodput_config_dict["tpot"] / - MILLISECONDS_TO_SECONDS_CONVERSION) + slo_values.append( + goodput_config_dict["tpot"] / MILLISECONDS_TO_SECONDS_CONVERSION + ) if "e2el" in goodput_config_dict: valid_metrics.append(e2els) - slo_values.append(goodput_config_dict["e2el"] / - MILLISECONDS_TO_SECONDS_CONVERSION) + slo_values.append( + goodput_config_dict["e2el"] / MILLISECONDS_TO_SECONDS_CONVERSION + ) for req_metric in zip(*valid_metrics): is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)]) @@ -207,7 +228,8 @@ def calculate_metrics( warnings.warn( "All requests failed. This is likely due to a misconfiguration " "on the benchmark arguments.", - stacklevel=2) + stacklevel=2, + ) metrics = BenchmarkMetrics( completed=completed, total_input=total_input, @@ -216,27 +238,31 @@ def calculate_metrics( request_goodput=good_completed / dur_s, output_throughput=sum(actual_output_lens) / dur_s, total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s, - mean_ttft_ms=np.mean(ttfts or 0) * - 1000, # ttfts is empty if streaming is not supported by backend + mean_ttft_ms=np.mean(ttfts or 0) + * 1000, # ttfts is empty if streaming is not supported by backend std_ttft_ms=np.std(ttfts or 0) * 1000, median_ttft_ms=np.median(ttfts or 0) * 1000, - percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000) - for p in selected_percentiles], + percentiles_ttft_ms=[ + (p, np.percentile(ttfts or 0, p) * 1000) for p in selected_percentiles + ], mean_tpot_ms=np.mean(tpots or 0) * 1000, std_tpot_ms=np.std(tpots or 0) * 1000, median_tpot_ms=np.median(tpots or 0) * 1000, - percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000) - for p in selected_percentiles], + percentiles_tpot_ms=[ + (p, np.percentile(tpots or 0, p) * 1000) for p in selected_percentiles + ], mean_itl_ms=np.mean(itls or 0) * 1000, std_itl_ms=np.std(itls or 0) * 1000, median_itl_ms=np.median(itls or 0) * 1000, - percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000) - for p in selected_percentiles], + percentiles_itl_ms=[ + (p, np.percentile(itls or 0, p) * 1000) for p in selected_percentiles + ], mean_e2el_ms=np.mean(e2els or 0) * 1000, std_e2el_ms=np.std(e2els or 0) * 1000, median_e2el_ms=np.median(e2els or 0) * 1000, - percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000) - for p in selected_percentiles], + percentiles_e2el_ms=[ + (p, np.percentile(e2els or 0, p) * 1000) for p in selected_percentiles + ], ) return metrics, actual_output_lens @@ -269,10 +295,12 @@ async def benchmark( raise ValueError(f"Unknown backend: {backend}") print("Starting initial single prompt test run...") - test_prompt, test_prompt_len, test_output_len, test_mm_content = \ - input_requests[0].prompt, input_requests[0].prompt_len, \ - input_requests[0].expected_output_len, \ - input_requests[0].multi_modal_data + test_prompt, test_prompt_len, test_output_len, test_mm_content = ( + input_requests[0].prompt, + input_requests[0].prompt_len, + input_requests[0].expected_output_len, + input_requests[0].multi_modal_data, + ) assert test_mm_content is None or isinstance(test_mm_content, dict) test_input = RequestFuncInput( @@ -292,36 +320,36 @@ async def benchmark( if not test_output.success: raise ValueError( "Initial test run failed - Please make sure benchmark arguments " - f"are correctly specified. Error: {test_output.error}") + f"are correctly specified. Error: {test_output.error}" + ) else: print("Initial test run completed. Starting main benchmark run...") if lora_modules: # For each input request, choose a LoRA module at random. lora_modules = iter( - [random.choice(lora_modules) \ - for _ in range(len(input_requests))]) + [random.choice(lora_modules) for _ in range(len(input_requests))] + ) if profile: print("Starting profiler...") - profile_input = RequestFuncInput(model=model_id, - model_name=model_name, - prompt=test_prompt, - api_url=base_url + "/start_profile", - prompt_len=test_prompt_len, - output_len=test_output_len, - logprobs=logprobs, - multi_modal_content=test_mm_content, - ignore_eos=ignore_eos, - extra_body=extra_body) + profile_input = RequestFuncInput( + model=model_id, + model_name=model_name, + prompt=test_prompt, + api_url=base_url + "/start_profile", + prompt_len=test_prompt_len, + output_len=test_output_len, + logprobs=logprobs, + multi_modal_content=test_mm_content, + ignore_eos=ignore_eos, + extra_body=extra_body, + ) profile_output = await request_func(request_func_input=profile_input) if profile_output.success: print("Profiler started") - if burstiness == 1.0: - distribution = "Poisson process" - else: - distribution = "Gamma distribution" + distribution = "Poisson process" if burstiness == 1.0 else "Gamma distribution" print(f"Traffic request rate: {request_rate}") print(f"Burstiness factor: {burstiness} ({distribution})") @@ -333,42 +361,45 @@ async def benchmark( # and it will simplify the code in limited_request_func. # semaphore = (asyncio.Semaphore(max_concurrency) # if max_concurrency else contextlib.nullcontext()) - semaphore = (asyncio.Semaphore(max_concurrency) - if max_concurrency else None) + semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None async def limited_request_func(request_func_input, pbar): if semaphore is None: - return await request_func(request_func_input=request_func_input, - pbar=pbar) + return await request_func(request_func_input=request_func_input, pbar=pbar) async with semaphore: - return await request_func(request_func_input=request_func_input, - pbar=pbar) + return await request_func(request_func_input=request_func_input, pbar=pbar) benchmark_start_time = time.perf_counter() tasks: list[asyncio.Task] = [] async for request in get_request(input_requests, request_rate, burstiness): - prompt, prompt_len, output_len, mm_content = request.prompt, \ - request.prompt_len, request.expected_output_len, \ - request.multi_modal_data + prompt, prompt_len, output_len, mm_content = ( + request.prompt, + request.prompt_len, + request.expected_output_len, + request.multi_modal_data, + ) req_model_id, req_model_name = model_id, model_name if lora_modules: req_lora_module = next(lora_modules) req_model_id, req_model_name = req_lora_module, req_lora_module - request_func_input = RequestFuncInput(model=req_model_id, - model_name=req_model_name, - prompt=prompt, - api_url=api_url, - prompt_len=prompt_len, - output_len=output_len, - logprobs=logprobs, - multi_modal_content=mm_content, - ignore_eos=ignore_eos, - extra_body=extra_body) + request_func_input = RequestFuncInput( + model=req_model_id, + model_name=req_model_name, + prompt=prompt, + api_url=api_url, + prompt_len=prompt_len, + output_len=output_len, + logprobs=logprobs, + multi_modal_content=mm_content, + ignore_eos=ignore_eos, + extra_body=extra_body, + ) tasks.append( asyncio.create_task( - limited_request_func(request_func_input=request_func_input, - pbar=pbar))) + limited_request_func(request_func_input=request_func_input, pbar=pbar) + ) + ) outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks) if profile: @@ -400,22 +431,32 @@ async def benchmark( goodput_config_dict=goodput_config_dict, ) - print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='=')) + print("{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="=")) print("{:<40} {:<10}".format("Successful requests:", metrics.completed)) - print("{:<40} {:<10.2f}".format("Benchmark duration (s):", - benchmark_duration)) + print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration)) print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input)) - print("{:<40} {:<10}".format("Total generated tokens:", - metrics.total_output)) - print("{:<40} {:<10.2f}".format("Request throughput (req/s):", - metrics.request_throughput)) + print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output)) + print( + "{:<40} {:<10.2f}".format( + "Request throughput (req/s):", metrics.request_throughput + ) + ) if goodput_config_dict: - print("{:<40} {:<10.2f}".format("Request goodput (req/s):", - metrics.request_goodput)) - print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):", - metrics.output_throughput)) - print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):", - metrics.total_token_throughput)) + print( + "{:<40} {:<10.2f}".format( + "Request goodput (req/s):", metrics.request_goodput + ) + ) + print( + "{:<40} {:<10.2f}".format( + "Output token throughput (tok/s):", metrics.output_throughput + ) + ) + print( + "{:<40} {:<10.2f}".format( + "Total Token throughput (tok/s):", metrics.total_token_throughput + ) + ) result = { "duration": benchmark_duration, @@ -423,8 +464,7 @@ async def benchmark( "total_input_tokens": metrics.total_input, "total_output_tokens": metrics.total_output, "request_throughput": metrics.request_throughput, - "request_goodput:": - metrics.request_goodput if goodput_config_dict else None, + "request_goodput:": metrics.request_goodput if goodput_config_dict else None, "output_throughput": metrics.output_throughput, "total_token_throughput": metrics.total_token_throughput, "input_lens": [output.prompt_len for output in outputs], @@ -447,29 +487,35 @@ async def benchmark( # metric. if metric_attribute_name not in selected_percentile_metrics: return - print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-')) - print("{:<40} {:<10.2f}".format( - f"Mean {metric_name} (ms):", - getattr(metrics, f"mean_{metric_attribute_name}_ms"))) - print("{:<40} {:<10.2f}".format( - f"Median {metric_name} (ms):", - getattr(metrics, f"median_{metric_attribute_name}_ms"))) + print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-")) + print( + "{:<40} {:<10.2f}".format( + f"Mean {metric_name} (ms):", + getattr(metrics, f"mean_{metric_attribute_name}_ms"), + ) + ) + print( + "{:<40} {:<10.2f}".format( + f"Median {metric_name} (ms):", + getattr(metrics, f"median_{metric_attribute_name}_ms"), + ) + ) result[f"mean_{metric_attribute_name}_ms"] = getattr( - metrics, f"mean_{metric_attribute_name}_ms") + metrics, f"mean_{metric_attribute_name}_ms" + ) result[f"median_{metric_attribute_name}_ms"] = getattr( - metrics, f"median_{metric_attribute_name}_ms") + metrics, f"median_{metric_attribute_name}_ms" + ) result[f"std_{metric_attribute_name}_ms"] = getattr( - metrics, f"std_{metric_attribute_name}_ms") - for p, value in getattr(metrics, - f"percentiles_{metric_attribute_name}_ms"): + metrics, f"std_{metric_attribute_name}_ms" + ) + for p, value in getattr(metrics, f"percentiles_{metric_attribute_name}_ms"): p_word = str(int(p)) if int(p) == p else str(p) - print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", - value)) + print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", value)) result[f"p{p_word}_{metric_attribute_name}_ms"] = value process_one_metric("ttft", "TTFT", "Time to First Token") - process_one_metric("tpot", "TPOT", - "Time per Output Token (excl. 1st token)") + process_one_metric("tpot", "TPOT", "Time per Output Token (excl. 1st token)") process_one_metric("itl", "ITL", "Inter-token Latency") process_one_metric("e2el", "E2EL", "End-to-end Latency") @@ -489,12 +535,14 @@ def check_goodput_args(args): raise ValueError( f"Invalid metric name found, {slo_name}: {slo_val}. " "The service level objective name should be one of " - f"{str(VALID_NAMES)}. ") + f"{str(VALID_NAMES)}. " + ) if slo_val < 0: raise ValueError( f"Invalid value found, {slo_name}: {slo_val}. " "The service level objective value should be " - "non-negative.") + "non-negative." + ) return goodput_config_dict @@ -507,31 +555,42 @@ def parse_goodput(slo_pairs): except ValueError as err: raise argparse.ArgumentTypeError( "Invalid format found for service level objectives. " - "Specify service level objectives for goodput as \"KEY:VALUE\" " + 'Specify service level objectives for goodput as "KEY:VALUE" ' "pairs, where the key is a metric name, and the value is a " - "number in milliseconds.") from err + "number in milliseconds." + ) from err return goodput_config_dict -def save_to_pytorch_benchmark_format(args: argparse.Namespace, - results: dict[str, Any], - file_name: str) -> None: +def save_to_pytorch_benchmark_format( + args: argparse.Namespace, results: dict[str, Any], file_name: str +) -> None: metrics = [ - "median_ttft_ms", "mean_ttft_ms", "std_ttft_ms", "p99_ttft_ms", - "mean_tpot_ms", "median_tpot_ms", "std_tpot_ms", "p99_tpot_ms", - "median_itl_ms", "mean_itl_ms", "std_itl_ms", "p99_itl_ms" + "median_ttft_ms", + "mean_ttft_ms", + "std_ttft_ms", + "p99_ttft_ms", + "mean_tpot_ms", + "median_tpot_ms", + "std_tpot_ms", + "p99_tpot_ms", + "median_itl_ms", + "mean_itl_ms", + "std_itl_ms", + "p99_itl_ms", ] # These raw data might be useful, but they are rather big. They can be added # later if needed ignored_metrics = ["ttfts", "itls", "generated_texts", "errors"] pt_records = convert_to_pytorch_benchmark_format( args=args, - metrics={k: [results[k]] - for k in metrics}, + metrics={k: [results[k]] for k in metrics}, extra_info={ k: results[k] - for k in results if k not in metrics and k not in ignored_metrics - }) + for k in results + if k not in metrics and k not in ignored_metrics + }, + ) if pt_records: # Don't use json suffix here as we don't want CI to pick it up pt_file = f"{os.path.splitext(file_name)[0]}.pytorch.json" @@ -556,34 +615,42 @@ def main(args: argparse.Namespace): api_url = f"http://{args.host}:{args.port}{args.endpoint}" base_url = f"http://{args.host}:{args.port}" - tokenizer = get_tokenizer(tokenizer_id, - tokenizer_mode=tokenizer_mode, - trust_remote_code=args.trust_remote_code) + tokenizer = get_tokenizer( + tokenizer_id, + tokenizer_mode=tokenizer_mode, + trust_remote_code=args.trust_remote_code, + ) if args.dataset_name is None: raise ValueError( "Please specify '--dataset-name' and the corresponding " - "'--dataset-path' if required.") + "'--dataset-path' if required." + ) if args.dataset_name == "sonnet": dataset = SonnetDataset(dataset_path=args.dataset_path) # For the "sonnet" dataset, formatting depends on the backend. if args.backend == "openai-chat": - input_requests = dataset.sample(num_requests=args.num_prompts, - input_len=args.sonnet_input_len, - output_len=args.sonnet_output_len, - prefix_len=args.sonnet_prefix_len, - tokenizer=tokenizer, - return_prompt_formatted=False) + input_requests = dataset.sample( + num_requests=args.num_prompts, + input_len=args.sonnet_input_len, + output_len=args.sonnet_output_len, + prefix_len=args.sonnet_prefix_len, + tokenizer=tokenizer, + return_prompt_formatted=False, + ) else: assert tokenizer.chat_template or tokenizer.default_chat_template, ( - "Tokenizer/model must have chat template for sonnet dataset.") - input_requests = dataset.sample(num_requests=args.num_prompts, - input_len=args.sonnet_input_len, - output_len=args.sonnet_output_len, - prefix_len=args.sonnet_prefix_len, - tokenizer=tokenizer, - return_prompt_formatted=True) + "Tokenizer/model must have chat template for sonnet dataset." + ) + input_requests = dataset.sample( + num_requests=args.num_prompts, + input_len=args.sonnet_input_len, + output_len=args.sonnet_output_len, + prefix_len=args.sonnet_prefix_len, + tokenizer=tokenizer, + return_prompt_formatted=True, + ) elif args.dataset_name == "hf": # all following datasets are implemented from the @@ -595,32 +662,45 @@ def main(args: argparse.Namespace): elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS: dataset_class = InstructCoderDataset args.hf_split = "train" + elif args.dataset_path in MTBenchDataset.SUPPORTED_DATASET_PATHS: + dataset_class = MTBenchDataset + args.hf_split = "train" elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS: dataset_class = ConversationDataset elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS: dataset_class = AIMODataset args.hf_split = "train" + elif args.dataset_path in NextEditPredictionDataset.SUPPORTED_DATASET_PATHS: # noqa: E501 + dataset_class = NextEditPredictionDataset + args.hf_split = "train" elif args.dataset_path in ASRDataset.SUPPORTED_DATASET_PATHS: dataset_class = ASRDataset args.hf_split = "train" else: - supported_datasets = set([ - dataset_name for cls in HuggingFaceDataset.__subclasses__() - for dataset_name in cls.SUPPORTED_DATASET_PATHS - ]) + supported_datasets = set( + [ + dataset_name + for cls in HuggingFaceDataset.__subclasses__() + for dataset_name in cls.SUPPORTED_DATASET_PATHS + ] + ) raise ValueError( f"Unsupported dataset path: {args.dataset_path}. " "Huggingface dataset only supports dataset_path" f" from one of following: {supported_datasets}. " "Please consider contributing if you would " - "like to add support for additional dataset formats.") + "like to add support for additional dataset formats." + ) - if (dataset_class.IS_MULTIMODAL and backend not in \ - ["openai-chat", "openai-audio"]): + if dataset_class.IS_MULTIMODAL and backend not in [ + "openai-chat", + "openai-audio", + ]: # multi-modal benchmark is only available on OpenAI Chat backend. raise ValueError( - "Multi-modal content is only supported on 'openai-chat' and " \ - "'openai-audio' backend.") + "Multi-modal content is only supported on 'openai-chat' and " + "'openai-audio' backend." + ) input_requests = dataset_class( dataset_path=args.dataset_path, dataset_subset=args.hf_subset, @@ -635,26 +715,24 @@ def main(args: argparse.Namespace): else: # For datasets that follow a similar structure, use a mapping. dataset_mapping = { - "sharegpt": - lambda: ShareGPTDataset(random_seed=args.seed, - dataset_path=args.dataset_path).sample( - tokenizer=tokenizer, - num_requests=args.num_prompts, - output_len=args.sharegpt_output_len, - ), - "burstgpt": - lambda: BurstGPTDataset(random_seed=args.seed, - dataset_path=args.dataset_path). - sample(tokenizer=tokenizer, num_requests=args.num_prompts), - "random": - lambda: RandomDataset(dataset_path=args.dataset_path).sample( + "sharegpt": lambda: ShareGPTDataset( + random_seed=args.seed, dataset_path=args.dataset_path + ).sample( + tokenizer=tokenizer, + num_requests=args.num_prompts, + output_len=args.sharegpt_output_len, + ), + "burstgpt": lambda: BurstGPTDataset( + random_seed=args.seed, dataset_path=args.dataset_path + ).sample(tokenizer=tokenizer, num_requests=args.num_prompts), + "random": lambda: RandomDataset(dataset_path=args.dataset_path).sample( tokenizer=tokenizer, num_requests=args.num_prompts, prefix_len=args.random_prefix_len, input_len=args.random_input_len, output_len=args.random_output_len, range_ratio=args.random_range_ratio, - ) + ), } try: @@ -670,15 +748,16 @@ def main(args: argparse.Namespace): "top_p": args.top_p, "top_k": args.top_k, "min_p": args.min_p, - "temperature": args.temperature - }.items() if v is not None + "temperature": args.temperature, + }.items() + if v is not None } # Sampling parameters are only supported by openai-compatible backend. if sampling_params and args.backend not in OPENAI_COMPATIBLE_BACKENDS: raise ValueError( - "Sampling parameters are only supported by openai-compatible " - "backends.") + "Sampling parameters are only supported by openai-compatible backends." + ) if "temperature" not in sampling_params: sampling_params["temperature"] = 0.0 # Default to greedy decoding. @@ -702,15 +781,14 @@ def main(args: argparse.Namespace): disable_tqdm=args.disable_tqdm, profile=args.profile, selected_percentile_metrics=args.percentile_metrics.split(","), - selected_percentiles=[ - float(p) for p in args.metric_percentiles.split(",") - ], + selected_percentiles=[float(p) for p in args.metric_percentiles.split(",")], ignore_eos=args.ignore_eos, goodput_config_dict=goodput_config_dict, max_concurrency=args.max_concurrency, lora_modules=args.lora_modules, extra_body=sampling_params, - )) + ) + ) # Save config and results to json if args.save_result or args.append_result: @@ -735,8 +813,9 @@ def main(args: argparse.Namespace): "Invalid metadata format. Please use KEY=VALUE format." ) # Traffic - result_json["request_rate"] = (args.request_rate if args.request_rate - < float("inf") else "inf") + result_json["request_rate"] = ( + args.request_rate if args.request_rate < float("inf") else "inf" + ) result_json["burstiness"] = args.burstiness result_json["max_concurrency"] = args.max_concurrency @@ -746,24 +825,31 @@ def main(args: argparse.Namespace): if not args.save_detailed: # Remove fields with too many data points for field in [ - "input_lens", "output_lens", "ttfts", "itls", - "generated_texts", "errors" + "input_lens", + "output_lens", + "ttfts", + "itls", + "generated_texts", + "errors", ]: if field in result_json: del result_json[field] # Save to file base_model_id = model_id.split("/")[-1] - max_concurrency_str = (f"-concurrency{args.max_concurrency}" - if args.max_concurrency is not None else "") - file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" #noqa + max_concurrency_str = ( + f"-concurrency{args.max_concurrency}" + if args.max_concurrency is not None + else "" + ) + file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa if args.result_filename: file_name = args.result_filename if args.result_dir: file_name = os.path.join(args.result_dir, file_name) - with open(file_name, - mode="a+" if args.append_result else "w", - encoding='utf-8') as outfile: + with open( + file_name, mode="a+" if args.append_result else "w", encoding="utf-8" + ) as outfile: # Append a newline. if args.append_result and outfile.tell() != 0: outfile.write("\n") @@ -773,7 +859,8 @@ def main(args: argparse.Namespace): if __name__ == "__main__": parser = FlexibleArgumentParser( - description="Benchmark the online serving throughput.") + description="Benchmark the online serving throughput." + ) parser.add_argument( "--backend", type=str, @@ -802,11 +889,13 @@ if __name__ == "__main__": choices=["sharegpt", "burstgpt", "sonnet", "random", "hf"], help="Name of the dataset to benchmark on.", ) - parser.add_argument("--dataset-path", - type=str, - default=None, - help="Path to the sharegpt/sonnet dataset. " - "Or the huggingface dataset ID if using HF dataset.") + parser.add_argument( + "--dataset-path", + type=str, + default=None, + help="Path to the sharegpt/sonnet dataset. " + "Or the huggingface dataset ID if using HF dataset.", + ) parser.add_argument( "--max-concurrency", type=int, @@ -818,7 +907,8 @@ if __name__ == "__main__": "initiated, this argument will control how many are actually allowed " "to execute at a time. This means that when used in combination, the " "actual request rate may be lower than specified with --request-rate, " - "if the server is not processing requests fast enough to keep up.") + "if the server is not processing requests fast enough to keep up.", + ) parser.add_argument( "--model", @@ -829,8 +919,7 @@ if __name__ == "__main__": parser.add_argument( "--tokenizer", type=str, - help= - "Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501 + help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501 ) parser.add_argument("--use-beam-search", action="store_true") parser.add_argument( @@ -843,11 +932,13 @@ if __name__ == "__main__": "--logprobs", type=int, default=None, - help=("Number of logprobs-per-token to compute & return as part of " - "the request. If unspecified, then either (1) if beam search " - "is disabled, no logprobs are computed & a single dummy " - "logprob is returned for each token; or (2) if beam search " - "is enabled 1 logprob per token is computed"), + help=( + "Number of logprobs-per-token to compute & return as part of " + "the request. If unspecified, then either (1) if beam search " + "is disabled, no logprobs are computed & a single dummy " + "logprob is returned for each token; or (2) if beam search " + "is enabled 1 logprob per token is computed" + ), ) parser.add_argument( "--request-rate", @@ -931,35 +1022,38 @@ if __name__ == "__main__": "--ignore-eos", action="store_true", help="Set ignore_eos flag when sending the benchmark request." - "Warning: ignore_eos is not supported in deepspeed_mii and tgi.") + "Warning: ignore_eos is not supported in deepspeed_mii and tgi.", + ) parser.add_argument( "--percentile-metrics", type=str, default="ttft,tpot,itl", help="Comma-separated list of selected metrics to report percentils. " "This argument specifies the metrics to report percentiles. " - "Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\". " - "Default value is \"ttft,tpot,itl\".") + 'Allowed metric names are "ttft", "tpot", "itl", "e2el". ' + 'Default value is "ttft,tpot,itl".', + ) parser.add_argument( "--metric-percentiles", type=str, default="99", help="Comma-separated list of percentiles for selected metrics. " - "To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". " - "Default value is \"99\". " - "Use \"--percentile-metrics\" to select metrics.", + 'To report 25-th, 50-th, and 75-th percentiles, use "25,50,75". ' + 'Default value is "99". ' + 'Use "--percentile-metrics" to select metrics.', ) parser.add_argument( "--goodput", nargs="+", required=False, - help="Specify service level objectives for goodput as \"KEY:VALUE\" " + help='Specify service level objectives for goodput as "KEY:VALUE" ' "pairs, where the key is a metric name, and the value is in " - "milliseconds. Multiple \"KEY:VALUE\" pairs can be provided, " + 'milliseconds. Multiple "KEY:VALUE" pairs can be provided, ' "separated by spaces. Allowed request level metric names are " - "\"ttft\", \"tpot\", \"e2el\". For more context on the definition of " + '"ttft", "tpot", "e2el". For more context on the definition of ' "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 " - "and the blog: https://hao-ai-lab.github.io/blogs/distserve") + "and the blog: https://hao-ai-lab.github.io/blogs/distserve", + ) # group for dataset specific arguments sonnet_group = parser.add_argument_group("sonnet dataset options") @@ -967,22 +1061,19 @@ if __name__ == "__main__": "--sonnet-input-len", type=int, default=550, - help= - "Number of input tokens per request, used only for sonnet dataset.", + help="Number of input tokens per request, used only for sonnet dataset.", ) sonnet_group.add_argument( "--sonnet-output-len", type=int, default=150, - help= - "Number of output tokens per request, used only for sonnet dataset.", + help="Number of output tokens per request, used only for sonnet dataset.", ) sonnet_group.add_argument( "--sonnet-prefix-len", type=int, default=200, - help= - "Number of prefix tokens per request, used only for sonnet dataset.", + help="Number of prefix tokens per request, used only for sonnet dataset.", ) sharegpt_group = parser.add_argument_group("sharegpt dataset options") @@ -991,22 +1082,21 @@ if __name__ == "__main__": type=int, default=None, help="Output length for each request. Overrides the output length " - "from the ShareGPT dataset.") + "from the ShareGPT dataset.", + ) random_group = parser.add_argument_group("random dataset options") random_group.add_argument( "--random-input-len", type=int, default=1024, - help= - "Number of input tokens per request, used only for random sampling.", + help="Number of input tokens per request, used only for random sampling.", ) random_group.add_argument( "--random-output-len", type=int, default=128, - help= - "Number of output tokens per request, used only for random sampling.", + help="Number of output tokens per request, used only for random sampling.", ) random_group.add_argument( "--random-range-ratio", @@ -1021,23 +1111,23 @@ if __name__ == "__main__": "--random-prefix-len", type=int, default=0, - help=("Number of fixed prefix tokens before the random context " - "in a request. " - "The total input length is the sum of `random-prefix-len` and " - "a random " - "context length sampled from [input_len * (1 - range_ratio), " - "input_len * (1 + range_ratio)]."), + help=( + "Number of fixed prefix tokens before the random context " + "in a request. " + "The total input length is the sum of `random-prefix-len` and " + "a random " + "context length sampled from [input_len * (1 - range_ratio), " + "input_len * (1 + range_ratio)]." + ), ) hf_group = parser.add_argument_group("hf dataset options") - hf_group.add_argument("--hf-subset", - type=str, - default=None, - help="Subset of the HF dataset.") - hf_group.add_argument("--hf-split", - type=str, - default=None, - help="Split of the HF dataset.") + hf_group.add_argument( + "--hf-subset", type=str, default=None, help="Subset of the HF dataset." + ) + hf_group.add_argument( + "--hf-split", type=str, default=None, help="Split of the HF dataset." + ) hf_group.add_argument( "--hf-output-len", type=int, @@ -1051,52 +1141,58 @@ if __name__ == "__main__": "--top-p", type=float, default=None, - help="Top-p sampling parameter. Only has effect on openai-compatible " - "backends.") + help="Top-p sampling parameter. Only has effect on openai-compatible backends.", + ) sampling_group.add_argument( "--top-k", type=int, default=None, - help="Top-k sampling parameter. Only has effect on openai-compatible " - "backends.") + help="Top-k sampling parameter. Only has effect on openai-compatible backends.", + ) sampling_group.add_argument( "--min-p", type=float, default=None, - help="Min-p sampling parameter. Only has effect on openai-compatible " - "backends.") + help="Min-p sampling parameter. Only has effect on openai-compatible backends.", + ) sampling_group.add_argument( "--temperature", type=float, default=None, help="Temperature sampling parameter. Only has effect on " "openai-compatible backends. If not specified, default to greedy " - "decoding (i.e. temperature==0.0).") + "decoding (i.e. temperature==0.0).", + ) parser.add_argument( - '--tokenizer-mode', + "--tokenizer-mode", type=str, default="auto", - choices=['auto', 'slow', 'mistral', 'custom'], + choices=["auto", "slow", "mistral", "custom"], help='The tokenizer mode.\n\n* "auto" will use the ' 'fast tokenizer if available.\n* "slow" will ' - 'always use the slow tokenizer. \n* ' + "always use the slow tokenizer. \n* " '"mistral" will always use the `mistral_common` tokenizer. \n*' - '"custom" will use --tokenizer to select the preregistered tokenizer.') - - parser.add_argument("--served-model-name", - type=str, - default=None, - help="The model name used in the API. " - "If not specified, the model name will be the " - "same as the ``--model`` argument. ") - - parser.add_argument("--lora-modules", - nargs='+', - default=None, - help="A subset of LoRA module names passed in when " - "launching the server. For each request, the " - "script chooses a LoRA module at random.") + '"custom" will use --tokenizer to select the preregistered tokenizer.', + ) + + parser.add_argument( + "--served-model-name", + type=str, + default=None, + help="The model name used in the API. " + "If not specified, the model name will be the " + "same as the ``--model`` argument. ", + ) + + parser.add_argument( + "--lora-modules", + nargs="+", + default=None, + help="A subset of LoRA module names passed in when " + "launching the server. For each request, the " + "script chooses a LoRA module at random.", + ) args = parser.parse_args() diff --git a/benchmarks/benchmark_serving_structured_output.py b/benchmarks/benchmark_serving_structured_output.py index 74ee00ec893076f52f8428a10f7ced66189e712d..5088c805f53ef8f85181f5d1c701534abf583ebc 100644 --- a/benchmarks/benchmark_serving_structured_output.py +++ b/benchmarks/benchmark_serving_structured_output.py @@ -19,6 +19,7 @@ On the client side, run: --endpoint /generate_stream to the end of the command above. """ + import argparse import asyncio import copy @@ -36,11 +37,15 @@ from typing import Optional import datasets import numpy as np import pandas as pd -from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput, - RequestFuncOutput) from tqdm.asyncio import tqdm from transformers import PreTrainedTokenizerBase +from backend_request_func import ( + ASYNC_REQUEST_FUNCS, + RequestFuncInput, + RequestFuncOutput, +) + try: from vllm.transformers_utils.tokenizer import get_tokenizer except ImportError: @@ -52,7 +57,8 @@ except ImportError: from argparse import ArgumentParser as FlexibleArgumentParser from vllm.v1.structured_output.backend_xgrammar import ( - has_xgrammar_unsupported_json_features) + has_xgrammar_unsupported_json_features, +) MILLISECONDS_TO_SECONDS_CONVERSION = 1000 @@ -98,6 +104,7 @@ class SampleRequest: prompt_len: The length of the prompt in tokens. expected_output_len: The expected length of the output in tokens. """ + prompt: str prompt_len: int expected_output_len: int @@ -106,45 +113,45 @@ class SampleRequest: completion: str = None -def sample_requests(tokenizer: PreTrainedTokenizerBase, - args: argparse.Namespace) -> list[SampleRequest]: - if args.dataset == 'json' or args.dataset == 'json-unique': +def sample_requests( + tokenizer: PreTrainedTokenizerBase, args: argparse.Namespace +) -> list[SampleRequest]: + if args.dataset == "json" or args.dataset == "json-unique": if args.json_schema_path is None: dir_path = os.path.dirname(os.path.realpath(__file__)) - args.json_schema_path = os.path.join(dir_path, - "structured_schemas", - "structured_schema_1.json") + args.json_schema_path = os.path.join( + dir_path, "structured_schemas", "structured_schema_1.json" + ) json_schemas = [] with open(args.json_schema_path) as f: schema = json.load(f) - if args.dataset == 'json-unique': - json_schemas = [ - copy.deepcopy(schema) for _ in range(args.num_prompts) - ] + if args.dataset == "json-unique": + json_schemas = [copy.deepcopy(schema) for _ in range(args.num_prompts)] for i in range(len(json_schemas)): - json_schemas[i]["properties"][ - f"__optional_field_{uuid.uuid4()}"] = { - "type": - "string", - "description": - "An unique optional field to avoid cached schemas" - } + if "properties" not in json_schemas[i]: + json_schemas[i]["properties"] = {} + json_schemas[i]["properties"][f"__optional_field_{uuid.uuid4()}"] = { + "type": "string", + "description": "An unique optional field to avoid cached schemas", + } else: json_schemas = [schema] * args.num_prompts def gen_prompt(index: int): - return f"Generate an example of a user profile given the following schema: {json.dumps(get_schema(index))}" # noqa: E501 + return f"Generate an example of a brief user profile given the following schema: {json.dumps(get_schema(index))}" # noqa: E501 def get_schema(index: int): return json_schemas[index % len(json_schemas)] requests = [ - SampleRequest(prompt=gen_prompt(i), - prompt_len=len(tokenizer(gen_prompt(i)).input_ids), - expected_output_len=args.output_len, - schema=get_schema(i), - structure_type=args.structure_type) + SampleRequest( + prompt=gen_prompt(i), + prompt_len=len(tokenizer(gen_prompt(i)).input_ids), + expected_output_len=args.output_len, + schema=get_schema(i), + structure_type=args.structure_type, + ) for i in range(args.num_prompts) ] @@ -168,11 +175,13 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase, input_len = len(tokenizer(prompt).input_ids) print(f"Input length of the prompt: {input_len} tokens") requests = [ - SampleRequest(prompt=prompt, - prompt_len=input_len, - expected_output_len=args.output_len, - schema=schema, - structure_type=args.structure_type) + SampleRequest( + prompt=prompt, + prompt_len=input_len, + expected_output_len=args.output_len, + schema=schema, + structure_type=args.structure_type, + ) for _ in range(args.num_prompts) ] @@ -186,11 +195,13 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase, input_len = len(tokenizer(prompt).input_ids) print(f"Input length of the prompt: {input_len} tokens") requests = [ - SampleRequest(prompt=prompt, - prompt_len=input_len, - expected_output_len=args.output_len, - schema=regex, - structure_type=args.structure_type) + SampleRequest( + prompt=prompt, + prompt_len=input_len, + expected_output_len=args.output_len, + schema=regex, + structure_type=args.structure_type, + ) for _ in range(args.num_prompts) ] @@ -201,47 +212,55 @@ def sample_requests(tokenizer: PreTrainedTokenizerBase, input_len = len(tokenizer(prompt).input_ids) print(f"Input length of the prompt: {input_len} tokens") requests = [ - SampleRequest(prompt=prompt, - prompt_len=input_len, - expected_output_len=args.output_len, - schema=choice, - structure_type=args.structure_type) + SampleRequest( + prompt=prompt, + prompt_len=input_len, + expected_output_len=args.output_len, + schema=choice, + structure_type=args.structure_type, + ) for _ in range(args.num_prompts) ] elif args.dataset == "xgrammar_bench": requests: list[SampleRequest] = [] - dataset = datasets.load_dataset("NousResearch/json-mode-eval", - split="train") + dataset = datasets.load_dataset("NousResearch/json-mode-eval", split="train") full_dataset_len = len(dataset) def _filter_func(item): import json + schema = json.loads(item["schema"]) return not has_xgrammar_unsupported_json_features(schema) dataset = dataset.filter(_filter_func) num_filtered_out = full_dataset_len - len(dataset) - print(f"dataset has {len(dataset)} entries after filtering " - f"out {num_filtered_out} entries with unsupported features") + print( + f"dataset has {len(dataset)} entries after filtering " + f"out {num_filtered_out} entries with unsupported features" + ) len_dataset = len(dataset) for data_point_idx in range(args.num_prompts): idx = data_point_idx while idx >= len_dataset: idx -= len_dataset schema = dataset["schema"][idx] - prompt = tokenizer.apply_chat_template(dataset["prompt"][idx], - tokenize=False) + prompt = tokenizer.apply_chat_template( + dataset["prompt"][idx], tokenize=False, add_generation_prompt=True + ) input_len = len(tokenizer(prompt).input_ids) completion = dataset["completion"][idx] requests.append( - SampleRequest(prompt=prompt, - prompt_len=input_len, - expected_output_len=args.output_len, - schema=schema, - structure_type=args.structure_type, - completion=completion)) + SampleRequest( + prompt=prompt, + prompt_len=input_len, + expected_output_len=args.output_len, + schema=schema, + structure_type=args.structure_type, + completion=completion, + ) + ) return requests @@ -273,7 +292,8 @@ async def get_request( # Calculate scale parameter theta to maintain the desired request_rate. assert burstiness > 0, ( - f"A positive burstiness factor is expected, but given {burstiness}.") + f"A positive burstiness factor is expected, but given {burstiness}." + ) theta = 1.0 / (request_rate * burstiness) for i, request in enumerate(input_requests): @@ -315,8 +335,8 @@ def calculate_metrics( # multiple output tokens may be bundled together # Note : this may inflate the output token count slightly output_len = len( - tokenizer(outputs[i].generated_text, - add_special_tokens=False).input_ids) + tokenizer(outputs[i].generated_text, add_special_tokens=False).input_ids + ) actual_output_lens.append(output_len) total_input += input_requests[i].prompt_len tpot = 0 @@ -340,16 +360,19 @@ def calculate_metrics( if "ttft" in goodput_config_dict: valid_metrics.append(ttfts) - slo_values.append(goodput_config_dict["ttft"] / - MILLISECONDS_TO_SECONDS_CONVERSION) + slo_values.append( + goodput_config_dict["ttft"] / MILLISECONDS_TO_SECONDS_CONVERSION + ) if "tpot" in goodput_config_dict: valid_metrics.append(all_tpots) - slo_values.append(goodput_config_dict["tpot"] / - MILLISECONDS_TO_SECONDS_CONVERSION) + slo_values.append( + goodput_config_dict["tpot"] / MILLISECONDS_TO_SECONDS_CONVERSION + ) if "e2el" in goodput_config_dict: valid_metrics.append(e2els) - slo_values.append(goodput_config_dict["e2el"] / - MILLISECONDS_TO_SECONDS_CONVERSION) + slo_values.append( + goodput_config_dict["e2el"] / MILLISECONDS_TO_SECONDS_CONVERSION + ) for req_metric in zip(*valid_metrics): is_good_req = all([s >= r for s, r in zip(slo_values, req_metric)]) @@ -360,7 +383,8 @@ def calculate_metrics( warnings.warn( "All requests failed. This is likely due to a misconfiguration " "on the benchmark arguments.", - stacklevel=2) + stacklevel=2, + ) metrics = BenchmarkMetrics( completed=completed, total_input=total_input, @@ -369,27 +393,31 @@ def calculate_metrics( request_goodput=good_completed / dur_s, output_throughput=sum(actual_output_lens) / dur_s, total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s, - mean_ttft_ms=np.mean(ttfts or 0) * - 1000, # ttfts is empty if streaming is not supported by backend + mean_ttft_ms=np.mean(ttfts or 0) + * 1000, # ttfts is empty if streaming is not supported by backend std_ttft_ms=np.std(ttfts or 0) * 1000, median_ttft_ms=np.median(ttfts or 0) * 1000, - percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000) - for p in selected_percentiles], + percentiles_ttft_ms=[ + (p, np.percentile(ttfts or 0, p) * 1000) for p in selected_percentiles + ], mean_tpot_ms=np.mean(tpots or 0) * 1000, std_tpot_ms=np.std(tpots or 0) * 1000, median_tpot_ms=np.median(tpots or 0) * 1000, - percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000) - for p in selected_percentiles], + percentiles_tpot_ms=[ + (p, np.percentile(tpots or 0, p) * 1000) for p in selected_percentiles + ], mean_itl_ms=np.mean(itls or 0) * 1000, std_itl_ms=np.std(itls or 0) * 1000, median_itl_ms=np.median(itls or 0) * 1000, - percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000) - for p in selected_percentiles], + percentiles_itl_ms=[ + (p, np.percentile(itls or 0, p) * 1000) for p in selected_percentiles + ], mean_e2el_ms=np.mean(e2els or 0) * 1000, std_e2el_ms=np.std(e2els or 0) * 1000, median_e2el_ms=np.median(e2els or 0) * 1000, - percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000) - for p in selected_percentiles], + percentiles_e2el_ms=[ + (p, np.percentile(e2els or 0, p) * 1000) for p in selected_percentiles + ], ) return metrics, actual_output_lens @@ -411,7 +439,6 @@ async def benchmark( ignore_eos: bool, max_concurrency: Optional[int], structured_output_ratio: float, - structured_output_backend: str, goodput_config_dict: Optional[dict[str, float]] = None, ): if backend in ASYNC_REQUEST_FUNCS: @@ -423,18 +450,17 @@ async def benchmark( extra_body = {} # Add the schema to the extra_body extra_body[request.structure_type] = request.schema - # Add the specific structured_output_backend - extra_body["guided_decoding_backend"] = structured_output_backend return extra_body print("Starting initial single prompt test run...") structured_output_req_idx = random.sample( - range(len(input_requests)), - int(len(input_requests) * structured_output_ratio)) + range(len(input_requests)), int(len(input_requests) * structured_output_ratio) + ) test_request = input_requests[0] - test_req_extra_body = (prepare_extra_body(test_request) - if 0 in structured_output_req_idx else None) + test_req_extra_body = ( + prepare_extra_body(test_request) if 0 in structured_output_req_idx else None + ) test_input = RequestFuncInput( model=model_id, prompt=test_request.prompt, @@ -448,7 +474,8 @@ async def benchmark( if not test_output.success: raise ValueError( "Initial test run failed - Please make sure benchmark arguments " - f"are correctly specified. Error: {test_output.error}") + f"are correctly specified. Error: {test_output.error}" + ) else: print("Initial test run completed. Starting main benchmark run...") @@ -467,10 +494,7 @@ async def benchmark( if profile_output.success: print("Profiler started") - if burstiness == 1.0: - distribution = "Poisson process" - else: - distribution = "Gamma distribution" + distribution = "Poisson process" if burstiness == 1.0 else "Gamma distribution" print(f"Traffic request rate: {request_rate}") print(f"Burstiness factor: {burstiness} ({distribution})") @@ -482,24 +506,21 @@ async def benchmark( # and it will simplify the code in limited_request_func. # semaphore = (asyncio.Semaphore(max_concurrency) # if max_concurrency else contextlib.nullcontext()) - semaphore = (asyncio.Semaphore(max_concurrency) - if max_concurrency else None) + semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None async def limited_request_func(request_func_input, pbar): if semaphore is None: - return await request_func(request_func_input=request_func_input, - pbar=pbar) + return await request_func(request_func_input=request_func_input, pbar=pbar) async with semaphore: - return await request_func(request_func_input=request_func_input, - pbar=pbar) + return await request_func(request_func_input=request_func_input, pbar=pbar) benchmark_start_time = time.perf_counter() tasks: list[asyncio.Task] = [] expected: list[str] = [] - async for i, request in get_request(input_requests, request_rate, - burstiness): - extra_body = prepare_extra_body( - request) if i in structured_output_req_idx else None + async for i, request in get_request(input_requests, request_rate, burstiness): + extra_body = ( + prepare_extra_body(request) if i in structured_output_req_idx else None + ) request_func_input = RequestFuncInput( model=model_id, prompt=request.prompt, @@ -512,8 +533,9 @@ async def benchmark( expected.append(request.completion) tasks.append( asyncio.create_task( - limited_request_func(request_func_input=request_func_input, - pbar=pbar))) + limited_request_func(request_func_input=request_func_input, pbar=pbar) + ) + ) outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks) if profile: @@ -545,54 +567,58 @@ async def benchmark( goodput_config_dict=goodput_config_dict, ) - print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='=')) + print("{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="=")) print("{:<40} {:<10}".format("Successful requests:", metrics.completed)) - print("{:<40} {:<10.2f}".format("Benchmark duration (s):", - benchmark_duration)) + print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration)) print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input)) - print("{:<40} {:<10}".format("Total generated tokens:", - metrics.total_output)) - print("{:<40} {:<10.2f}".format("Request throughput (req/s):", - metrics.request_throughput)) + print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output)) + print( + "{:<40} {:<10.2f}".format( + "Request throughput (req/s):", metrics.request_throughput + ) + ) if goodput_config_dict: - print("{:<40} {:<10.2f}".format("Request goodput (req/s):", - metrics.request_goodput)) - print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):", - metrics.output_throughput)) - print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):", - metrics.total_token_throughput)) + print( + "{:<40} {:<10.2f}".format( + "Request goodput (req/s):", metrics.request_goodput + ) + ) + print( + "{:<40} {:<10.2f}".format( + "Output token throughput (tok/s):", metrics.output_throughput + ) + ) + print( + "{:<40} {:<10.2f}".format( + "Total Token throughput (tok/s):", metrics.total_token_throughput + ) + ) result = { - "duration": - benchmark_duration, - "completed": - metrics.completed, - "total_input_tokens": - metrics.total_input, - "total_output_tokens": - metrics.total_output, - "request_throughput": - metrics.request_throughput, - "output_throughput": - metrics.output_throughput, - "total_token_throughput": - metrics.total_token_throughput, - "ttft_description": - pd.Series([output.ttft for output in outputs]).describe().to_dict(), - "tpot_description": - pd.Series([output.tpot for output in outputs]).describe().to_dict(), + "duration": benchmark_duration, + "completed": metrics.completed, + "total_input_tokens": metrics.total_input, + "total_output_tokens": metrics.total_output, + "request_throughput": metrics.request_throughput, + "output_throughput": metrics.output_throughput, + "total_token_throughput": metrics.total_token_throughput, + "ttft_description": pd.Series([output.ttft for output in outputs]) + .describe() + .to_dict(), + "tpot_description": pd.Series([output.tpot for output in outputs]) + .describe() + .to_dict(), "input_lens": [output.prompt_len for output in outputs], - "output_lens": - actual_output_lens, + "output_lens": actual_output_lens, "ttfts": [output.ttft for output in outputs], "itls": [output.itl for output in outputs], "errors": [output.error for output in outputs], } - ret = [{ - 'generated': output.generated_text, - 'expected': gt - } for output, gt in zip(outputs, expected)] + ret = [ + {"generated": output.generated_text, "expected": gt} + for output, gt in zip(outputs, expected) + ] def process_one_metric( # E.g., "ttft" @@ -606,29 +632,35 @@ async def benchmark( # metric. if metric_attribute_name not in selected_percentile_metrics: return - print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-')) - print("{:<40} {:<10.2f}".format( - f"Mean {metric_name} (ms):", - getattr(metrics, f"mean_{metric_attribute_name}_ms"))) - print("{:<40} {:<10.2f}".format( - f"Median {metric_name} (ms):", - getattr(metrics, f"median_{metric_attribute_name}_ms"))) + print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-")) + print( + "{:<40} {:<10.2f}".format( + f"Mean {metric_name} (ms):", + getattr(metrics, f"mean_{metric_attribute_name}_ms"), + ) + ) + print( + "{:<40} {:<10.2f}".format( + f"Median {metric_name} (ms):", + getattr(metrics, f"median_{metric_attribute_name}_ms"), + ) + ) result[f"mean_{metric_attribute_name}_ms"] = getattr( - metrics, f"mean_{metric_attribute_name}_ms") + metrics, f"mean_{metric_attribute_name}_ms" + ) result[f"median_{metric_attribute_name}_ms"] = getattr( - metrics, f"median_{metric_attribute_name}_ms") + metrics, f"median_{metric_attribute_name}_ms" + ) result[f"std_{metric_attribute_name}_ms"] = getattr( - metrics, f"std_{metric_attribute_name}_ms") - for p, value in getattr(metrics, - f"percentiles_{metric_attribute_name}_ms"): + metrics, f"std_{metric_attribute_name}_ms" + ) + for p, value in getattr(metrics, f"percentiles_{metric_attribute_name}_ms"): p_word = str(int(p)) if int(p) == p else str(p) - print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", - value)) + print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", value)) result[f"p{p_word}_{metric_attribute_name}_ms"] = value process_one_metric("ttft", "TTFT", "Time to First Token") - process_one_metric("tpot", "TPOT", - "Time per Output Token (excl. 1st token)") + process_one_metric("tpot", "TPOT", "Time per Output Token (excl. 1st token)") process_one_metric("itl", "ITL", "Inter-token Latency") process_one_metric("e2el", "E2EL", "End-to-end Latency") @@ -638,13 +670,13 @@ async def benchmark( def evaluate(ret, args): - def _eval_correctness_json(expected, actual): # extract json string from string using regex import re - actual = actual.replace('\n', '').replace(' ', '').strip() + + actual = actual.replace("\n", "").replace(" ", "").strip() try: - actual = re.search(r'\{.*\}', actual).group() + actual = re.search(r"\{.*\}", actual).group() actual = json.loads(actual) except Exception: return False @@ -656,28 +688,32 @@ def evaluate(ret, args): def _eval_correctness_regex(expected, actual): import re + return re.match(args.regex, actual) is not None def _eval_correctness(expected, actual): - if args.structure_type == 'guided_json': + if args.structure_type == "guided_json": return _eval_correctness_json(expected, actual) - elif args.structure_type == 'guided_regex': + elif args.structure_type == "guided_regex": return _eval_correctness_regex(expected, actual) - elif args.structure_type == 'guided_choice': + elif args.structure_type == "guided_choice": return _eval_correctness_choice(expected, actual) else: return None scores = [] for res in ret: - score = _eval_correctness(res['expected'], res['generated']) - res['correctness'] = score + score = _eval_correctness(res["expected"], res["generated"]) + res["correctness"] = score scores.append(score) not_none_scores = [score for score in scores if score is not None] - return (sum(not_none_scores) / len(not_none_scores) * - 100) if len(not_none_scores) > 0 else None + return ( + (sum(not_none_scores) / len(not_none_scores) * 100) + if len(not_none_scores) > 0 + else None + ) def parse_goodput(slo_pairs): @@ -689,9 +725,10 @@ def parse_goodput(slo_pairs): except ValueError as err: raise argparse.ArgumentTypeError( "Invalid format found for service level objectives. " - "Specify service level objectives for goodput as \"KEY:VALUE\" " + 'Specify service level objectives for goodput as "KEY:VALUE" ' "pairs, where the key is a metric name, and the value is a " - "number in milliseconds.") from err + "number in milliseconds." + ) from err return goodput_config_dict @@ -705,12 +742,14 @@ def check_goodput_args(args): raise ValueError( f"Invalid metric name found, {slo_name}: {slo_val}. " "The service level objective name should be one of " - f"{str(VALID_NAMES)}. ") + f"{str(VALID_NAMES)}. " + ) if slo_val < 0: raise ValueError( f"Invalid value found, {slo_name}: {slo_val}. " "The service level objective value should be " - "non-negative.") + "non-negative." + ) return goodput_config_dict @@ -736,19 +775,19 @@ def main(args: argparse.Namespace): tokenizer_mode=args.tokenizer_mode, ) - if args.dataset == 'grammar': - args.structure_type = 'guided_grammar' - elif args.dataset == 'regex': - args.structure_type = 'guided_regex' - elif args.dataset == 'choice': - args.structure_type = 'guided_choice' + if args.dataset == "grammar": + args.structure_type = "guided_grammar" + elif args.dataset == "regex": + args.structure_type = "guided_regex" + elif args.dataset == "choice": + args.structure_type = "guided_choice" else: - args.structure_type = 'guided_json' + args.structure_type = "guided_json" if args.no_structured_output: args.structured_output_ratio = 0 if args.save_results: - result_file_name = f'{args.structured_output_ratio}guided' + result_file_name = f"{args.structured_output_ratio}guided" result_file_name += f"_{backend}" result_file_name += f"_{args.request_rate}qps" result_file_name += f"_{args.model.split('/')[-1]}" @@ -776,37 +815,29 @@ def main(args: argparse.Namespace): disable_tqdm=args.disable_tqdm, profile=args.profile, selected_percentile_metrics=args.percentile_metrics.split(","), - selected_percentiles=[ - float(p) for p in args.metric_percentiles.split(",") - ], + selected_percentiles=[float(p) for p in args.metric_percentiles.split(",")], ignore_eos=args.ignore_eos, max_concurrency=args.max_concurrency, structured_output_ratio=args.structured_output_ratio, - structured_output_backend=args.structured_output_backend, goodput_config_dict=goodput_config_dict, - )) + ) + ) # Save config and results to json score = evaluate(ret, args) - print("correct_rate(%)", score, '\n') + print("correct_rate(%)", score, "\n") if args.save_results: results = { - "backend": - backend, - "model_id": - model_id, - "tokenizer_id": - tokenizer_id, - "num_prompts": - args.num_prompts, - "request_rate": - args.request_rate if args.request_rate < float("inf") else "inf", - "burstiness": - args.burstiness, - "max_concurrency": - args.max_concurrency, - "correct_rate(%)": - score + "backend": backend, + "model_id": model_id, + "tokenizer_id": tokenizer_id, + "num_prompts": args.num_prompts, + "request_rate": args.request_rate + if args.request_rate < float("inf") + else "inf", + "burstiness": args.burstiness, + "max_concurrency": args.max_concurrency, + "correct_rate(%)": score, } results = {"outputs": ret, **results, **benchmark_result} @@ -815,13 +846,14 @@ def main(args: argparse.Namespace): result_file_name = args.result_filename if args.result_dir: result_file_name = os.path.join(args.result_dir, result_file_name) - with open(result_file_name, "w", encoding='utf-8') as outfile: + with open(result_file_name, "w", encoding="utf-8") as outfile: json.dump(results, outfile, indent=4) if __name__ == "__main__": parser = FlexibleArgumentParser( - description="Benchmark the online serving throughput.") + description="Benchmark the online serving throughput." + ) parser.add_argument( "--backend", type=str, @@ -843,16 +875,14 @@ if __name__ == "__main__": default="/v1/completions", help="API endpoint.", ) - parser.add_argument("--dataset", - default='json', - choices=[ - 'json', 'json-unique', 'grammar', 'regex', - 'choice', 'xgrammar_bench' - ]) - parser.add_argument("--json_schema_path", - type=str, - default=None, - help="Path to json schema.") + parser.add_argument( + "--dataset", + default="json", + choices=["json", "json-unique", "grammar", "regex", "choice", "xgrammar_bench"], + ) + parser.add_argument( + "--json-schema-path", type=str, default=None, help="Path to json schema." + ) parser.add_argument( "--max-concurrency", type=int, @@ -864,7 +894,8 @@ if __name__ == "__main__": "initiated, this argument will control how many are actually allowed " "to execute at a time. This means that when used in combination, the " "actual request rate may be lower than specified with --request-rate, " - "if the server is not processing requests fast enough to keep up.") + "if the server is not processing requests fast enough to keep up.", + ) parser.add_argument( "--model", type=str, @@ -874,15 +905,13 @@ if __name__ == "__main__": parser.add_argument( "--tokenizer", type=str, - help= - "Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501 + help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501 ) parser.add_argument( "--tokenizer-mode", type=str, default="auto", - help= - "Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501 + help="Name or path of the tokenizer, if not using the default tokenizer.", # noqa: E501 ) parser.add_argument( "--num-prompts", @@ -959,52 +988,51 @@ if __name__ == "__main__": "--ignore-eos", action="store_true", help="Set ignore_eos flag when sending the benchmark request." - "Warning: ignore_eos is not supported in deepspeed_mii and tgi.") + "Warning: ignore_eos is not supported in deepspeed_mii and tgi.", + ) parser.add_argument( "--percentile-metrics", type=str, default="ttft,tpot,itl", help="Comma-separated list of selected metrics to report percentils. " "This argument specifies the metrics to report percentiles. " - "Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\". " - "Default value is \"ttft,tpot,itl\".") + 'Allowed metric names are "ttft", "tpot", "itl", "e2el". ' + 'Default value is "ttft,tpot,itl".', + ) parser.add_argument( "--metric-percentiles", type=str, default="99", help="Comma-separated list of percentiles for selected metrics. " - "To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". " - "Default value is \"99\". " - "Use \"--percentile-metrics\" to select metrics.", + 'To report 25-th, 50-th, and 75-th percentiles, use "25,50,75". ' + 'Default value is "99". ' + 'Use "--percentile-metrics" to select metrics.', ) parser.add_argument( "--goodput", nargs="+", required=False, - help="Specify service level objectives for goodput as \"KEY:VALUE\" " + help='Specify service level objectives for goodput as "KEY:VALUE" ' "pairs, where the key is a metric name, and the value is in " - "milliseconds. Multiple \"KEY:VALUE\" pairs can be provided, " + 'milliseconds. Multiple "KEY:VALUE" pairs can be provided, ' "separated by spaces. Allowed request level metric names are " - "\"ttft\", \"tpot\", \"e2el\". For more context on the definition of " + '"ttft", "tpot", "e2el". For more context on the definition of ' "goodput, refer to DistServe paper: https://arxiv.org/pdf/2401.09670 " - "and the blog: https://hao-ai-lab.github.io/blogs/distserve") - - parser.add_argument("--no-structured-output", - action='store_true', - default=False, - help="Whether to disable JSON decoding or not.") - parser.add_argument("--structured-output-ratio", - type=float, - default=1.0, - help="Ratio of Structured Outputs requests") - parser.add_argument("--structured-output-backend", - type=str, - choices=[ - "outlines", "lm-format-enforcer", "xgrammar", - "guidance", "auto" - ], - default="auto", - help="Backend to use for structured outputs") + "and the blog: https://hao-ai-lab.github.io/blogs/distserve", + ) + + parser.add_argument( + "--no-structured-output", + action="store_true", + default=False, + help="Whether to disable JSON decoding or not.", + ) + parser.add_argument( + "--structured-output-ratio", + type=float, + default=1.0, + help="Ratio of Structured Outputs requests", + ) args = parser.parse_args() main(args) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 1f65277e1bfebc98543d4dc009d5dee8351e1419..7a13babda9d16227385e5bae3df37d5daed12b05 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 """Benchmark offline inference throughput.""" + import argparse import dataclasses import json @@ -11,18 +12,25 @@ from typing import Any, Optional, Union import torch import uvloop -from benchmark_dataset import (AIMODataset, BurstGPTDataset, - ConversationDataset, InstructCoderDataset, - RandomDataset, SampleRequest, ShareGPTDataset, - SonnetDataset, VisionArenaDataset) -from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json from tqdm import tqdm -from transformers import (AutoModelForCausalLM, AutoTokenizer, - PreTrainedTokenizerBase) - +from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase + +from benchmark_dataset import ( + AIMODataset, + BurstGPTDataset, + ConversationDataset, + InstructCoderDataset, + RandomDataset, + SampleRequest, + ShareGPTDataset, + SonnetDataset, + VisionArenaDataset, +) +from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs from vllm.entrypoints.openai.api_server import ( - build_async_engine_client_from_engine_args) + build_async_engine_client_from_engine_args, +) from vllm.inputs import TextPrompt, TokensPrompt from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput @@ -37,23 +45,30 @@ def run_vllm( disable_detokenize: bool = False, ) -> tuple[float, Optional[list[RequestOutput]]]: from vllm import LLM, SamplingParams + llm = LLM(**dataclasses.asdict(engine_args)) assert all( - llm.llm_engine.model_config.max_model_len >= ( - request.prompt_len + request.expected_output_len) - for request in requests), ( - "Please ensure that max_model_len is greater than the sum of" - " prompt_len and expected_output_len for all requests.") + llm.llm_engine.model_config.max_model_len + >= (request.prompt_len + request.expected_output_len) + for request in requests + ), ( + "Please ensure that max_model_len is greater than the sum of" + " prompt_len and expected_output_len for all requests." + ) # Add the requests to the engine. prompts: list[Union[TextPrompt, TokensPrompt]] = [] sampling_params: list[SamplingParams] = [] for request in requests: prompts.append( - TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"], - multi_modal_data=request.multi_modal_data) - if "prompt_token_ids" in request.prompt else \ - TextPrompt(prompt=request.prompt, - multi_modal_data=request.multi_modal_data)) + TokensPrompt( + prompt_token_ids=request.prompt["prompt_token_ids"], + multi_modal_data=request.multi_modal_data, + ) + if "prompt_token_ids" in request.prompt + else TextPrompt( + prompt=request.prompt, multi_modal_data=request.multi_modal_data + ) + ) sampling_params.append( SamplingParams( n=n, @@ -62,7 +77,8 @@ def run_vllm( ignore_eos=True, max_tokens=request.expected_output_len, detokenize=not disable_detokenize, - )) + ) + ) lora_requests: Optional[list[LoRARequest]] = None if engine_args.enable_lora: lora_requests = [request.lora_request for request in requests] @@ -72,10 +88,9 @@ def run_vllm( outputs = None if not use_beam_search: start = time.perf_counter() - outputs = llm.generate(prompts, - sampling_params, - lora_request=lora_requests, - use_tqdm=True) + outputs = llm.generate( + prompts, sampling_params, lora_request=lora_requests, use_tqdm=True + ) end = time.perf_counter() else: assert lora_requests is None, "BeamSearch API does not support LoRA" @@ -91,30 +106,35 @@ def run_vllm( beam_width=n, max_tokens=output_len, ignore_eos=True, - )) + ), + ) end = time.perf_counter() return end - start, outputs def run_vllm_chat( - requests: list[SampleRequest], - n: int, - engine_args: EngineArgs, - disable_detokenize: bool = False) -> tuple[float, list[RequestOutput]]: + requests: list[SampleRequest], + n: int, + engine_args: EngineArgs, + disable_detokenize: bool = False, +) -> tuple[float, list[RequestOutput]]: """ Run vLLM chat benchmark. This function is recommended ONLY for benchmarking multimodal models as it properly handles multimodal inputs and chat formatting. For non-multimodal models, use run_vllm() instead. """ from vllm import LLM, SamplingParams + llm = LLM(**dataclasses.asdict(engine_args)) assert all( - llm.llm_engine.model_config.max_model_len >= ( - request.prompt_len + request.expected_output_len) - for request in requests), ( - "Please ensure that max_model_len is greater than the sum of " - "prompt_len and expected_output_len for all requests.") + llm.llm_engine.model_config.max_model_len + >= (request.prompt_len + request.expected_output_len) + for request in requests + ), ( + "Please ensure that max_model_len is greater than the sum of " + "prompt_len and expected_output_len for all requests." + ) prompts = [] sampling_params: list[SamplingParams] = [] @@ -128,7 +148,8 @@ def run_vllm_chat( ignore_eos=True, max_tokens=request.expected_output_len, detokenize=not disable_detokenize, - )) + ) + ) start = time.perf_counter() outputs = llm.chat(prompts, sampling_params, use_tqdm=True) end = time.perf_counter() @@ -145,13 +166,17 @@ async def run_vllm_async( from vllm import SamplingParams async with build_async_engine_client_from_engine_args( - engine_args, disable_frontend_multiprocessing) as llm: + engine_args, disable_frontend_multiprocessing + ) as llm: + model_config = await llm.get_model_config() assert all( - llm.model_config.max_model_len >= (request.prompt_len + - request.expected_output_len) - for request in requests), ( - "Please ensure that max_model_len is greater than the sum of" - " prompt_len and expected_output_len for all requests.") + model_config.max_model_len + >= (request.prompt_len + request.expected_output_len) + for request in requests + ), ( + "Please ensure that max_model_len is greater than the sum of" + " prompt_len and expected_output_len for all requests." + ) # Add the requests to the engine. prompts: list[Union[TextPrompt, TokensPrompt]] = [] @@ -159,11 +184,15 @@ async def run_vllm_async( lora_requests: list[Optional[LoRARequest]] = [] for request in requests: prompts.append( - TokensPrompt(prompt_token_ids=request.prompt["prompt_token_ids"], - multi_modal_data=request.multi_modal_data) - if "prompt_token_ids" in request.prompt else \ - TextPrompt(prompt=request.prompt, - multi_modal_data=request.multi_modal_data)) + TokensPrompt( + prompt_token_ids=request.prompt["prompt_token_ids"], + multi_modal_data=request.multi_modal_data, + ) + if "prompt_token_ids" in request.prompt + else TextPrompt( + prompt=request.prompt, multi_modal_data=request.multi_modal_data + ) + ) sampling_params.append( SamplingParams( n=n, @@ -172,17 +201,16 @@ async def run_vllm_async( ignore_eos=True, max_tokens=request.expected_output_len, detokenize=not disable_detokenize, - )) + ) + ) lora_requests.append(request.lora_request) generators = [] start = time.perf_counter() - for i, (prompt, sp, - lr) in enumerate(zip(prompts, sampling_params, lora_requests)): - generator = llm.generate(prompt, - sp, - lora_request=lr, - request_id=f"test{i}") + for i, (prompt, sp, lr) in enumerate( + zip(prompts, sampling_params, lora_requests) + ): + generator = llm.generate(prompt, sp, lora_request=lr, request_id=f"test{i}") generators.append(generator) all_gens = merge_async_iterators(*generators) async for i, res in all_gens: @@ -201,7 +229,8 @@ def run_hf( disable_detokenize: bool = False, ) -> float: llm = AutoModelForCausalLM.from_pretrained( - model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code) + model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code + ) if llm.config.model_type == "llama": # To enable padding in the HF backend. tokenizer.pad_token = tokenizer.eos_token @@ -224,14 +253,15 @@ def run_hf( # Check if we can add more requests to the batch. next_prompt_len = requests[i + 1].prompt_len next_output_len = requests[i + 1].expected_output_len - if (max(max_prompt_len, next_prompt_len) + - max(max_output_len, next_output_len)) <= 2048: + if ( + max(max_prompt_len, next_prompt_len) + + max(max_output_len, next_output_len) + ) <= 2048: # We can add more requests to the batch. continue # Generate the sequences. - input_ids = tokenizer(batch, return_tensors="pt", - padding=True).input_ids + input_ids = tokenizer(batch, return_tensors="pt", padding=True).input_ids llm_outputs = llm.generate( input_ids=input_ids.cuda(), do_sample=True, @@ -261,6 +291,7 @@ def run_mii( output_len: int, ) -> float: from mii import client, serve + llm = serve(model, tensor_parallel=tensor_parallel_size) prompts = [request.prompt for request in requests] @@ -272,8 +303,9 @@ def run_mii( return end - start -def save_to_pytorch_benchmark_format(args: argparse.Namespace, - results: dict[str, Any]) -> None: +def save_to_pytorch_benchmark_format( + args: argparse.Namespace, results: dict[str, Any] +) -> None: pt_records = convert_to_pytorch_benchmark_format( args=args, metrics={ @@ -281,9 +313,9 @@ def save_to_pytorch_benchmark_format(args: argparse.Namespace, "tokens_per_second": [results["tokens_per_second"]], }, extra_info={ - k: results[k] - for k in ["elapsed_time", "num_requests", "total_num_tokens"] - }) + k: results[k] for k in ["elapsed_time", "num_requests", "total_num_tokens"] + }, + ) if pt_records: # Don't use json suffix here as we don't want CI to pick it up pt_file = f"{os.path.splitext(args.output_json)[0]}.pytorch.json" @@ -315,7 +347,8 @@ def get_requests(args, tokenizer): sample_kwargs["enable_multimodal_chat"] = True elif args.dataset_name == "sonnet": assert tokenizer.chat_template or tokenizer.default_chat_template, ( - "Tokenizer/model must have chat template for sonnet dataset.") + "Tokenizer/model must have chat template for sonnet dataset." + ) dataset_cls = SonnetDataset sample_kwargs["prefix_len"] = args.prefix_len sample_kwargs["return_prompt_formatted"] = True @@ -324,21 +357,21 @@ def get_requests(args, tokenizer): elif args.dataset_name == "hf": if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS: dataset_cls = VisionArenaDataset - common_kwargs['dataset_subset'] = None - common_kwargs['dataset_split'] = "train" + common_kwargs["dataset_subset"] = None + common_kwargs["dataset_split"] = "train" sample_kwargs["enable_multimodal_chat"] = True elif args.dataset_path in InstructCoderDataset.SUPPORTED_DATASET_PATHS: dataset_cls = InstructCoderDataset - common_kwargs['dataset_split'] = "train" + common_kwargs["dataset_split"] = "train" elif args.dataset_path in ConversationDataset.SUPPORTED_DATASET_PATHS: dataset_cls = ConversationDataset - common_kwargs['dataset_subset'] = args.hf_subset - common_kwargs['dataset_split'] = args.hf_split + common_kwargs["dataset_subset"] = args.hf_subset + common_kwargs["dataset_split"] = args.hf_split sample_kwargs["enable_multimodal_chat"] = True elif args.dataset_path in AIMODataset.SUPPORTED_DATASET_PATHS: dataset_cls = AIMODataset - common_kwargs['dataset_subset'] = None - common_kwargs['dataset_split'] = "train" + common_kwargs["dataset_subset"] = None + common_kwargs["dataset_split"] = "train" else: raise ValueError(f"Unknown dataset name: {args.dataset_name}") # Remove None values @@ -353,10 +386,10 @@ def main(args: argparse.Namespace): random.seed(args.seed) # Sample the requests. tokenizer = AutoTokenizer.from_pretrained( - args.tokenizer, trust_remote_code=args.trust_remote_code) + args.tokenizer, trust_remote_code=args.trust_remote_code + ) requests = get_requests(args, tokenizer) - is_multi_modal = any(request.multi_modal_data is not None - for request in requests) + is_multi_modal = any(request.multi_modal_data is not None for request in requests) request_outputs: Optional[list[RequestOutput]] = None if args.backend == "vllm": if args.async_engine: @@ -367,23 +400,34 @@ def main(args: argparse.Namespace): AsyncEngineArgs.from_cli_args(args), args.disable_frontend_multiprocessing, args.disable_detokenize, - )) + ) + ) else: elapsed_time, request_outputs = run_vllm( - requests, args.n, EngineArgs.from_cli_args(args), - args.disable_detokenize) + requests, + args.n, + EngineArgs.from_cli_args(args), + args.disable_detokenize, + ) elif args.backend == "hf": assert args.tensor_parallel_size == 1 - elapsed_time = run_hf(requests, args.model, tokenizer, args.n, - args.hf_max_batch_size, args.trust_remote_code, - args.disable_detokenize) + elapsed_time = run_hf( + requests, + args.model, + tokenizer, + args.n, + args.hf_max_batch_size, + args.trust_remote_code, + args.disable_detokenize, + ) elif args.backend == "mii": - elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size, - args.output_len) + elapsed_time = run_mii( + requests, args.model, args.tensor_parallel_size, args.output_len + ) elif args.backend == "vllm-chat": elapsed_time, request_outputs = run_vllm_chat( - requests, args.n, EngineArgs.from_cli_args(args), - args.disable_detokenize) + requests, args.n, EngineArgs.from_cli_args(args), args.disable_detokenize + ) else: raise ValueError(f"Unknown backend: {args.backend}") @@ -395,28 +439,31 @@ def main(args: argparse.Namespace): for ro in request_outputs: if not isinstance(ro, RequestOutput): continue - total_prompt_tokens += len( - ro.prompt_token_ids) if ro.prompt_token_ids else 0 - total_output_tokens += sum( - len(o.token_ids) for o in ro.outputs if o) + total_prompt_tokens += ( + len(ro.prompt_token_ids) if ro.prompt_token_ids else 0 + ) + total_output_tokens += sum(len(o.token_ids) for o in ro.outputs if o) total_num_tokens = total_prompt_tokens + total_output_tokens else: - total_num_tokens = sum(r.prompt_len + r.expected_output_len - for r in requests) + total_num_tokens = sum(r.prompt_len + r.expected_output_len for r in requests) total_output_tokens = sum(r.expected_output_len for r in requests) total_prompt_tokens = total_num_tokens - total_output_tokens if is_multi_modal and args.backend != "vllm-chat": - print("\033[91mWARNING\033[0m: Multi-modal request with " - f"{args.backend} backend detected. The " - "following metrics are not accurate because image tokens are not" - " counted. See vllm-project/vllm/issues/9778 for details.") + print( + "\033[91mWARNING\033[0m: Multi-modal request with " + f"{args.backend} backend detected. The " + "following metrics are not accurate because image tokens are not" + " counted. See vllm-project/vllm/issues/9778 for details." + ) # TODO(vllm-project/vllm/issues/9778): Count multi-modal token length. # vllm-chat backend counts the image tokens now - print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " - f"{total_num_tokens / elapsed_time:.2f} total tokens/s, " - f"{total_output_tokens / elapsed_time:.2f} output tokens/s") + print( + f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " + f"{total_num_tokens / elapsed_time:.2f} total tokens/s, " + f"{total_output_tokens / elapsed_time:.2f} output tokens/s" + ) print(f"Total num prompt tokens: {total_prompt_tokens}") print(f"Total num output tokens: {total_output_tokens}") @@ -444,7 +491,8 @@ def validate_args(args): warnings.warn( "The '--dataset' argument will be deprecated in the next release. " "Please use '--dataset-name' and '--dataset-path' instead.", - stacklevel=2) + stacklevel=2, + ) args.dataset_path = args.dataset if not getattr(args, "tokenizer", None): @@ -457,9 +505,8 @@ def validate_args(args): # === Dataset Configuration === if not args.dataset and not args.dataset_path: - print( - "When dataset path is not set, it will default to random dataset") - args.dataset_name = 'random' + print("When dataset path is not set, it will default to random dataset") + args.dataset_name = "random" if args.input_len is None: raise ValueError("input_len must be provided for a random dataset") @@ -467,41 +514,55 @@ def validate_args(args): # --hf-subset and --hf-split: only used # when dataset_name is 'hf' if args.dataset_name != "hf" and ( - getattr(args, "hf_subset", None) is not None - or getattr(args, "hf_split", None) is not None): - warnings.warn("--hf-subset and --hf-split will be ignored \ + getattr(args, "hf_subset", None) is not None + or getattr(args, "hf_split", None) is not None + ): + warnings.warn( + "--hf-subset and --hf-split will be ignored \ since --dataset-name is not 'hf'.", - stacklevel=2) + stacklevel=2, + ) elif args.dataset_name == "hf": if args.dataset_path in ( - VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys() - | ConversationDataset.SUPPORTED_DATASET_PATHS): - assert args.backend == "vllm-chat", f"{args.dataset_path} needs to use vllm-chat as the backend." #noqa: E501 - elif args.dataset_path in (InstructCoderDataset.SUPPORTED_DATASET_PATHS - | AIMODataset.SUPPORTED_DATASET_PATHS): - assert args.backend == "vllm", f"{args.dataset_path} needs to use vllm as the backend." #noqa: E501 + VisionArenaDataset.SUPPORTED_DATASET_PATHS.keys() + | ConversationDataset.SUPPORTED_DATASET_PATHS + ): + assert args.backend == "vllm-chat", ( + f"{args.dataset_path} needs to use vllm-chat as the backend." + ) # noqa: E501 + elif args.dataset_path in ( + InstructCoderDataset.SUPPORTED_DATASET_PATHS + | AIMODataset.SUPPORTED_DATASET_PATHS + ): + assert args.backend == "vllm", ( + f"{args.dataset_path} needs to use vllm as the backend." + ) # noqa: E501 else: - raise ValueError( - f"{args.dataset_path} is not supported by hf dataset.") + raise ValueError(f"{args.dataset_path} is not supported by hf dataset.") # --random-range-ratio: only used when dataset_name is 'random' - if args.dataset_name != 'random' and args.random_range_ratio is not None: - warnings.warn("--random-range-ratio will be ignored since \ + if args.dataset_name != "random" and args.random_range_ratio is not None: + warnings.warn( + "--random-range-ratio will be ignored since \ --dataset-name is not 'random'.", - stacklevel=2) + stacklevel=2, + ) # --prefix-len: only used when dataset_name is 'random', 'sonnet', or not # set. - if args.dataset_name not in {"random", "sonnet", None - } and args.prefix_len is not None: - warnings.warn("--prefix-len will be ignored since --dataset-name\ + if ( + args.dataset_name not in {"random", "sonnet", None} + and args.prefix_len is not None + ): + warnings.warn( + "--prefix-len will be ignored since --dataset-name\ is not 'random', 'sonnet', or not set.", - stacklevel=2) + stacklevel=2, + ) # === LoRA Settings === if getattr(args, "enable_lora", False) and args.backend != "vllm": - raise ValueError( - "LoRA benchmarking is only supported for vLLM backend") + raise ValueError("LoRA benchmarking is only supported for vLLM backend") if getattr(args, "enable_lora", False) and args.lora_path is None: raise ValueError("LoRA path must be provided when enable_lora is True") @@ -511,8 +572,10 @@ def validate_args(args): if args.backend != "hf" and args.hf_max_batch_size is not None: raise ValueError("HF max batch size is only for HF backend.") - if args.backend in {"hf", "mii"} and getattr(args, "quantization", - None) is not None: + if ( + args.backend in {"hf", "mii"} + and getattr(args, "quantization", None) is not None + ): raise ValueError("Quantization is only for vLLM backend.") if args.backend == "mii" and args.dtype != "auto": @@ -520,29 +583,32 @@ def validate_args(args): if args.backend == "mii" and args.n != 1: raise ValueError("n must be 1 for MII backend.") if args.backend == "mii" and args.tokenizer != args.model: - raise ValueError( - "Tokenizer must be the same as the model for MII backend.") + raise ValueError("Tokenizer must be the same as the model for MII backend.") # --data-parallel is not supported currently. # https://github.com/vllm-project/vllm/issues/16222 if args.data_parallel_size > 1: raise ValueError( "Data parallel is not supported in offline benchmark, \ - please use benchmark serving instead") + please use benchmark serving instead" + ) if __name__ == "__main__": parser = FlexibleArgumentParser(description="Benchmark the throughput.") - parser.add_argument("--backend", - type=str, - choices=["vllm", "hf", "mii", "vllm-chat"], - default="vllm") + parser.add_argument( + "--backend", + type=str, + choices=["vllm", "hf", "mii", "vllm-chat"], + default="vllm", + ) parser.add_argument( "--dataset-name", type=str, choices=["sharegpt", "random", "sonnet", "burstgpt", "hf"], help="Name of the dataset to benchmark on.", - default="sharegpt") + default="sharegpt", + ) parser.add_argument( "--dataset", type=str, @@ -550,57 +616,70 @@ if __name__ == "__main__": help="Path to the ShareGPT dataset, will be deprecated in\ the next release. The dataset is expected to " "be a json in form of list[dict[..., conversations: " - "list[dict[..., value: ]]]]") - parser.add_argument("--dataset-path", - type=str, - default=None, - help="Path to the dataset") - parser.add_argument("--input-len", - type=int, - default=None, - help="Input prompt length for each request") - parser.add_argument("--output-len", - type=int, - default=None, - help="Output length for each request. Overrides the " - "output length from the dataset.") - parser.add_argument("--n", - type=int, - default=1, - help="Number of generated sequences per prompt.") - parser.add_argument("--num-prompts", - type=int, - default=1000, - help="Number of prompts to process.") - parser.add_argument("--hf-max-batch-size", - type=int, - default=None, - help="Maximum batch size for HF backend.") + "list[dict[..., value: ]]]]", + ) + parser.add_argument( + "--dataset-path", type=str, default=None, help="Path to the dataset" + ) + parser.add_argument( + "--input-len", + type=int, + default=None, + help="Input prompt length for each request", + ) + parser.add_argument( + "--output-len", + type=int, + default=None, + help="Output length for each request. Overrides the " + "output length from the dataset.", + ) + parser.add_argument( + "--n", type=int, default=1, help="Number of generated sequences per prompt." + ) parser.add_argument( - '--output-json', + "--num-prompts", type=int, default=1000, help="Number of prompts to process." + ) + parser.add_argument( + "--hf-max-batch-size", + type=int, + default=None, + help="Maximum batch size for HF backend.", + ) + parser.add_argument( + "--output-json", type=str, default=None, - help='Path to save the throughput results in JSON format.') - parser.add_argument("--async-engine", - action='store_true', - default=False, - help="Use vLLM async engine rather than LLM class.") - parser.add_argument("--disable-frontend-multiprocessing", - action='store_true', - default=False, - help="Disable decoupled async engine frontend.") + help="Path to save the throughput results in JSON format.", + ) + parser.add_argument( + "--async-engine", + action="store_true", + default=False, + help="Use vLLM async engine rather than LLM class.", + ) + parser.add_argument( + "--disable-frontend-multiprocessing", + action="store_true", + default=False, + help="Disable decoupled async engine frontend.", + ) parser.add_argument( "--disable-detokenize", action="store_true", - help=("Do not detokenize the response (i.e. do not include " - "detokenization time in the measurement)")) + help=( + "Do not detokenize the response (i.e. do not include " + "detokenization time in the measurement)" + ), + ) # LoRA parser.add_argument( "--lora-path", type=str, default=None, - help="Path to the lora adapters to use. This can be an absolute path, " - "a relative path, or a Hugging Face model identifier.") + help="Path to the LoRA adapters to use. This can be an absolute path, " + "a relative path, or a Hugging Face model identifier.", + ) parser.add_argument( "--prefix-len", type=int, @@ -614,7 +693,8 @@ if __name__ == "__main__": f"prefix_len (default: {SonnetDataset.DEFAULT_PREFIX_LEN}) " "controls how much of the input is fixed lines versus " "random lines, but the total input length remains approximately " - "input_len tokens.") + "input_len tokens.", + ) # random dataset parser.add_argument( "--random-range-ratio", @@ -628,14 +708,12 @@ if __name__ == "__main__": ) # hf dtaset - parser.add_argument("--hf-subset", - type=str, - default=None, - help="Subset of the HF dataset.") - parser.add_argument("--hf-split", - type=str, - default=None, - help="Split of the HF dataset.") + parser.add_argument( + "--hf-subset", type=str, default=None, help="Subset of the HF dataset." + ) + parser.add_argument( + "--hf-split", type=str, default=None, help="Split of the HF dataset." + ) parser = AsyncEngineArgs.add_cli_args(parser) args = parser.parse_args() diff --git a/benchmarks/benchmark_utils.py b/benchmarks/benchmark_utils.py index 45a0ddbd5d08d65a2ecf7200bb179eff0e792ce1..b0c4fca92c3d0035904691d5a1fc9b99e741b7a0 100644 --- a/benchmarks/benchmark_utils.py +++ b/benchmarks/benchmark_utils.py @@ -7,9 +7,9 @@ import os from typing import Any -def convert_to_pytorch_benchmark_format(args: argparse.Namespace, - metrics: dict[str, list], - extra_info: dict[str, Any]) -> list: +def convert_to_pytorch_benchmark_format( + args: argparse.Namespace, metrics: dict[str, list], extra_info: dict[str, Any] +) -> list: """ Save the benchmark results in the format used by PyTorch OSS benchmark with on metric per record @@ -37,12 +37,12 @@ def convert_to_pytorch_benchmark_format(args: argparse.Namespace, }, } - tp = record["benchmark"]["extra_info"]["args"].get( - "tensor_parallel_size") + tp = record["benchmark"]["extra_info"]["args"].get("tensor_parallel_size") # Save tensor_parallel_size parameter if it's part of the metadata if not tp and "tensor_parallel_size" in extra_info: - record["benchmark"]["extra_info"]["args"][ - "tensor_parallel_size"] = extra_info["tensor_parallel_size"] + record["benchmark"]["extra_info"]["args"]["tensor_parallel_size"] = ( + extra_info["tensor_parallel_size"] + ) records.append(record) @@ -50,7 +50,6 @@ def convert_to_pytorch_benchmark_format(args: argparse.Namespace, class InfEncoder(json.JSONEncoder): - def clear_inf(self, o: Any): if isinstance(o, dict): return {k: self.clear_inf(v) for k, v in o.items()} diff --git a/benchmarks/cutlass_benchmarks/sparse_benchmarks.py b/benchmarks/cutlass_benchmarks/sparse_benchmarks.py index 9e36b0a9d3bb959457e2b6fec6772552f0fd7eb8..da258f98e085f973110dc623111012f5f6b61b93 100644 --- a/benchmarks/cutlass_benchmarks/sparse_benchmarks.py +++ b/benchmarks/cutlass_benchmarks/sparse_benchmarks.py @@ -23,8 +23,9 @@ DEFAULT_TP_SIZES = [1] # bench -def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args, - **kwargs) -> TMeasurement: +def bench_fn( + label: str, sub_label: str, description: str, fn: Callable, *args, **kwargs +) -> TMeasurement: min_run_time = 1 globals = { @@ -41,16 +42,18 @@ def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args, ).blocked_autorange(min_run_time=min_run_time) -def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str, - sub_label: str) -> Iterable[TMeasurement]: +def bench_int8( + dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str +) -> Iterable[TMeasurement]: assert dtype == torch.int8 b_compressed, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k) scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) - bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) + bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16) - out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b, - torch.bfloat16) + out = ops.cutlass_scaled_sparse_mm( + a, b_compressed, e, scale_a, scale_b, torch.bfloat16 + ) out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16) if not torch.allclose(out, out_ref): @@ -63,54 +66,107 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str, timers = [] # pytorch impl - bfloat16 timers.append( - bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales", - torch.mm, a.to(dtype=torch.bfloat16), - b.to(dtype=torch.bfloat16))) + bench_fn( + label, + sub_label, + "pytorch_bf16_bf16_bf16_matmul-no-scales", + torch.mm, + a.to(dtype=torch.bfloat16), + b.to(dtype=torch.bfloat16), + ) + ) # pytorch impl - float16 timers.append( - bench_fn(label, sub_label, - "pytorch_fp16_fp16_fp16_matmul-no-scales", torch.mm, - a.to(dtype=torch.float16), b.to(dtype=torch.float16))) + bench_fn( + label, + sub_label, + "pytorch_fp16_fp16_fp16_matmul-no-scales", + torch.mm, + a.to(dtype=torch.float16), + b.to(dtype=torch.float16), + ) + ) # cutlass impl timers.append( - bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm", - ops.cutlass_scaled_mm, a, b, scale_a, scale_b, - torch.bfloat16)) + bench_fn( + label, + sub_label, + "cutlass_i8_i8_bf16_scaled_mm", + ops.cutlass_scaled_mm, + a, + b, + scale_a, + scale_b, + torch.bfloat16, + ) + ) # cutlass with bias timers.append( - bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_bias", - ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16, - bias)) + bench_fn( + label, + sub_label, + "cutlass_i8_i8_bf16_scaled_mm_bias", + ops.cutlass_scaled_mm, + a, + b, + scale_a, + scale_b, + torch.bfloat16, + bias, + ) + ) # cutlass sparse impl timers.append( - bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm", - ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, - scale_b, torch.bfloat16)) + bench_fn( + label, + sub_label, + "cutlass_i8_i8_bf16_scaled_sparse_mm", + ops.cutlass_scaled_sparse_mm, + a, + b_compressed, + e, + scale_a, + scale_b, + torch.bfloat16, + ) + ) # cutlass sparse with bias timers.append( - bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_sparse_mm_bias", - ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, - scale_b, torch.bfloat16, bias)) + bench_fn( + label, + sub_label, + "cutlass_i8_i8_bf16_scaled_sparse_mm_bias", + ops.cutlass_scaled_sparse_mm, + a, + b_compressed, + e, + scale_a, + scale_b, + torch.bfloat16, + bias, + ) + ) return timers -def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str, - sub_label: str) -> Iterable[TMeasurement]: +def bench_fp8( + dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str +) -> Iterable[TMeasurement]: assert dtype == torch.float8_e4m3fn - b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, - k) + b_compressed, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k) scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) - bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) + bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16) - out = ops.cutlass_scaled_sparse_mm(a, b_compressed, e, scale_a, scale_b, - torch.bfloat16) + out = ops.cutlass_scaled_sparse_mm( + a, b_compressed, e, scale_a, scale_b, torch.bfloat16 + ) out_ref = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16) if not torch.allclose(out, out_ref): @@ -124,97 +180,165 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str, # pytorch impl w. bf16 timers.append( - bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales", - torch.mm, a.to(dtype=torch.bfloat16, device="cuda"), - b.to(dtype=torch.bfloat16, device="cuda"))) + bench_fn( + label, + sub_label, + "pytorch_bf16_bf16_bf16_matmul-no-scales", + torch.mm, + a.to(dtype=torch.bfloat16, device="cuda"), + b.to(dtype=torch.bfloat16, device="cuda"), + ) + ) # pytorch impl: bf16 output, without fp8 fast accum timers.append( - bench_fn(label, - sub_label, - "pytorch_fp8_fp8_bf16_scaled_mm", - torch._scaled_mm, - a, - b, - scale_a=scale_a, - scale_b=scale_b, - out_dtype=torch.bfloat16)) + bench_fn( + label, + sub_label, + "pytorch_fp8_fp8_bf16_scaled_mm", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.bfloat16, + ) + ) # pytorch impl: bf16 output, with fp8 fast accum timers.append( - bench_fn(label, - sub_label, - "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum", - torch._scaled_mm, - a, - b, - scale_a=scale_a, - scale_b=scale_b, - out_dtype=torch.bfloat16, - use_fast_accum=True)) + bench_fn( + label, + sub_label, + "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.bfloat16, + use_fast_accum=True, + ) + ) # pytorch impl: fp16 output, without fp8 fast accum timers.append( - bench_fn(label, - sub_label, - "pytorch_fp8_fp8_fp16_scaled_mm", - torch._scaled_mm, - a, - b, - scale_a=scale_a, - scale_b=scale_b, - out_dtype=torch.float16)) + bench_fn( + label, + sub_label, + "pytorch_fp8_fp8_fp16_scaled_mm", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.float16, + ) + ) # pytorch impl: fp16 output, with fp8 fast accum timers.append( - bench_fn(label, - sub_label, - "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum", - torch._scaled_mm, - a, - b, - scale_a=scale_a, - scale_b=scale_b, - out_dtype=torch.float16, - use_fast_accum=True)) + bench_fn( + label, + sub_label, + "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.float16, + use_fast_accum=True, + ) + ) # cutlass impl: bf16 output timers.append( - bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm", - ops.cutlass_scaled_mm, a, b, scale_a, scale_b, - torch.bfloat16)) + bench_fn( + label, + sub_label, + "cutlass_fp8_fp8_bf16_scaled_mm", + ops.cutlass_scaled_mm, + a, + b, + scale_a, + scale_b, + torch.bfloat16, + ) + ) # cutlass impl: bf16 output timers.append( - bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_sparse_mm", - ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, - scale_b, torch.bfloat16)) + bench_fn( + label, + sub_label, + "cutlass_fp8_fp8_bf16_scaled_sparse_mm", + ops.cutlass_scaled_sparse_mm, + a, + b_compressed, + e, + scale_a, + scale_b, + torch.bfloat16, + ) + ) # cutlass impl: fp16 output timers.append( - bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_sparse_mm", - ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, - scale_b, torch.float16)) + bench_fn( + label, + sub_label, + "cutlass_fp8_fp8_fp16_scaled_sparse_mm", + ops.cutlass_scaled_sparse_mm, + a, + b_compressed, + e, + scale_a, + scale_b, + torch.float16, + ) + ) # cutlass impl: bf16 output, with bias timers.append( - bench_fn(label, sub_label, - "cutlass_fp8_fp8_bf16_scaled_sparse_mm_bias", - ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, - scale_b, torch.bfloat16, bias)) + bench_fn( + label, + sub_label, + "cutlass_fp8_fp8_bf16_scaled_sparse_mm_bias", + ops.cutlass_scaled_sparse_mm, + a, + b_compressed, + e, + scale_a, + scale_b, + torch.bfloat16, + bias, + ) + ) # cutlass impl: fp16 output, with bias timers.append( - bench_fn(label, sub_label, - "cutlass_fp8_fp8_fp16_scaled_sparse_mm_bias", - ops.cutlass_scaled_sparse_mm, a, b_compressed, e, scale_a, - scale_b, torch.float16, bias.to(dtype=torch.float16))) + bench_fn( + label, + sub_label, + "cutlass_fp8_fp8_fp16_scaled_sparse_mm_bias", + ops.cutlass_scaled_sparse_mm, + a, + b_compressed, + e, + scale_a, + scale_b, + torch.float16, + bias.to(dtype=torch.float16), + ) + ) return timers -def bench(dtype: torch.dtype, m: int, k: int, n: int, label: str, - sub_label: str) -> Iterable[TMeasurement]: +def bench( + dtype: torch.dtype, m: int, k: int, n: int, label: str, sub_label: str +) -> Iterable[TMeasurement]: if dtype == torch.int8: return bench_int8(dtype, m, k, n, label, sub_label) if dtype == torch.float8_e4m3fn: @@ -228,12 +352,12 @@ def print_timers(timers: Iterable[TMeasurement]): compare.print() -def run(dtype: torch.dtype, - MKNs: Iterable[tuple[int, int, int]]) -> Iterable[TMeasurement]: +def run( + dtype: torch.dtype, MKNs: Iterable[tuple[int, int, int]] +) -> Iterable[TMeasurement]: results = [] for m, k, n in MKNs: - timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm", - f"MKN=({m}x{k}x{n})") + timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm", f"MKN=({m}x{k}x{n})") print_timers(timers) results.extend(timers) @@ -241,10 +365,12 @@ def run(dtype: torch.dtype, # output makers -def make_output(data: Iterable[TMeasurement], - MKNs: Iterable[tuple[int, int, int]], - base_description: str, - timestamp=None): +def make_output( + data: Iterable[TMeasurement], + MKNs: Iterable[tuple[int, int, int]], + base_description: str, + timestamp=None, +): print(f"== All Results {base_description} ====") print_timers(data) @@ -258,8 +384,7 @@ def make_output(data: Iterable[TMeasurement], def run_square_bench(args): - dim_sizes = list( - range(args.dim_start, args.dim_end + 1, args.dim_increment)) + dim_sizes = list(range(args.dim_start, args.dim_end + 1, args.dim_increment)) MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes)) data = run(args.dtype, MKNs) @@ -319,7 +444,7 @@ def run_model_bench(args): pkl.dump(all_data, f) -if __name__ == '__main__': +if __name__ == "__main__": def to_torch_dtype(dt): if dt == "int8": @@ -344,12 +469,15 @@ Benchmark Cutlass GEMM. Output: - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs. """, # noqa: E501 - formatter_class=argparse.RawTextHelpFormatter) - - parser.add_argument("--dtype", - type=to_torch_dtype, - required=True, - help="Available options are ['int8', 'fp8']") + formatter_class=argparse.RawTextHelpFormatter, + ) + + parser.add_argument( + "--dtype", + type=to_torch_dtype, + required=True, + help="Available options are ['int8', 'fp8']", + ) subparsers = parser.add_subparsers(dest="cmd") square_parser = subparsers.add_parser("square_bench") @@ -368,19 +496,19 @@ Benchmark Cutlass GEMM. range_parser.set_defaults(func=run_range_bench) model_parser = subparsers.add_parser("model_bench") - model_parser.add_argument("--models", - nargs="+", - type=str, - default=DEFAULT_MODELS, - choices=WEIGHT_SHAPES.keys()) - model_parser.add_argument("--tp-sizes", - nargs="+", - type=int, - default=DEFAULT_TP_SIZES) - model_parser.add_argument("--batch-sizes", - nargs="+", - type=int, - default=DEFAULT_BATCH_SIZES) + model_parser.add_argument( + "--models", + nargs="+", + type=str, + default=DEFAULT_MODELS, + choices=WEIGHT_SHAPES.keys(), + ) + model_parser.add_argument( + "--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES + ) + model_parser.add_argument( + "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES + ) model_parser.set_defaults(func=run_model_bench) args = parser.parse_args() diff --git a/benchmarks/cutlass_benchmarks/utils.py b/benchmarks/cutlass_benchmarks/utils.py index fe4d8fdfc0669b8d3f4723fbbda83edf10c40d00..7e9f5a7fc0f464718e15e1cc024df89081da74f9 100644 --- a/benchmarks/cutlass_benchmarks/utils.py +++ b/benchmarks/cutlass_benchmarks/utils.py @@ -10,8 +10,9 @@ import vllm._custom_ops as ops def to_fp8(tensor: torch.Tensor) -> torch.Tensor: finfo = torch.finfo(torch.float8_e4m3fn) - return torch.round(tensor.clamp( - min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) + return torch.round(tensor.clamp(min=finfo.min, max=finfo.max)).to( + dtype=torch.float8_e4m3fn + ) def to_int8(tensor: torch.Tensor) -> torch.Tensor: @@ -26,10 +27,11 @@ def to_fp16(tensor: torch.Tensor) -> torch.Tensor: return tensor.to(dtype=torch.float16) -def make_rand_tensors(dtype: torch.dtype, m: int, n: int, - k: int) -> tuple[torch.Tensor, torch.Tensor]: - a = torch.randn((m, k), device='cuda') * 5 - b = torch.randn((n, k), device='cuda').t() * 5 +def make_rand_tensors( + dtype: torch.dtype, m: int, n: int, k: int +) -> tuple[torch.Tensor, torch.Tensor]: + a = torch.randn((m, k), device="cuda") * 5 + b = torch.randn((n, k), device="cuda").t() * 5 if dtype == torch.int8: return to_int8(a), to_int8(b) @@ -49,9 +51,7 @@ def prune_to_2_4(tensor): # Create binary mask mask = torch.zeros_like(reshaped) - mask.scatter_(dim=1, - index=indices, - src=torch.ones_like(indices, dtype=mask.dtype)) + mask.scatter_(dim=1, index=indices, src=torch.ones_like(indices, dtype=mask.dtype)) # Apply mask and reshape back pruned = reshaped * mask @@ -62,10 +62,11 @@ def prune_to_2_4(tensor): return pruned.reshape(original_shape) -def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int, - k: int) -> tuple[torch.Tensor, torch.Tensor]: - a = torch.randn((m, k), device='cuda') * 5 - b = torch.randn((n, k), device='cuda').t() * 5 +def make_rand_sparse_tensors( + dtype: torch.dtype, m: int, n: int, k: int +) -> tuple[torch.Tensor, torch.Tensor]: + a = torch.randn((m, k), device="cuda") * 5 + b = torch.randn((n, k), device="cuda").t() * 5 b = prune_to_2_4(b.t()).t() @@ -86,9 +87,9 @@ def make_rand_sparse_tensors(dtype: torch.dtype, m: int, n: int, return b_compressed, e, a, b -def make_n_rand_sparse_tensors(num_tensors: int, dtype: torch.dtype, - m: int, n: int, k: int) -> \ - tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]]: +def make_n_rand_sparse_tensors( + num_tensors: int, dtype: torch.dtype, m: int, n: int, k: int +) -> tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]]: ABs = [] for _ in range(num_tensors): b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k) diff --git a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py index e7b742d8bec9363576176573fed8db265a4a6efb..08e93837f7ddff3da18c37c482e527695dff75b2 100644 --- a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py +++ b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py @@ -16,7 +16,8 @@ from weight_shapes import WEIGHT_SHAPES from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - w8a8_block_fp8_matmul) + w8a8_block_fp8_matmul, +) from vllm.utils import FlexibleArgumentParser DEFAULT_MODELS = list(WEIGHT_SHAPES.keys()) @@ -25,8 +26,9 @@ DEFAULT_TP_SIZES = [1] # bench -def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args, - **kwargs) -> TMeasurement: +def bench_fn( + label: str, sub_label: str, description: str, fn: Callable, *args, **kwargs +) -> TMeasurement: min_run_time = 1 globals = { @@ -44,45 +46,48 @@ def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args, def bench_int8( - dtype: torch.dtype, - m: int, - k: int, - n: int, - label: str, - sub_label: str, - bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]: + dtype: torch.dtype, + m: int, + k: int, + n: int, + label: str, + sub_label: str, + bench_kernels: Optional[list[str]] = None, +) -> Iterable[TMeasurement]: """Benchmark INT8-based kernels.""" assert dtype == torch.int8 a, b = make_rand_tensors(torch.int8, m, n, k) scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) - bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) - azp = torch.zeros((m, ), device="cuda", dtype=torch.int32) - azp_adj = torch.zeros((n, ), device="cuda", dtype=torch.int32) + bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16) + azp = torch.zeros((m,), device="cuda", dtype=torch.int32) + azp_adj = torch.zeros((n,), device="cuda", dtype=torch.int32) bench_fns = { - "pytorch_bf16_bf16_bf16_matmul-no-scales": - lambda: torch.mm(a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16) - ), - "pytorch_fp16_fp16_fp16_matmul-no-scales": - lambda: torch.mm(a.to(dtype=torch.float16), b.to(dtype=torch.float16)), - "cutlass_i8_i8_bf16_scaled_mm": - lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16), - "cutlass_i8_i8_bf16_scaled_mm_bias": - lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16, - bias), - "cutlass_i8_i8_bf16_scaled_mm_azp": - lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch. - bfloat16, azp_adj), - "cutlass_i8_i8_bf16_scaled_mm_azp_bias": - lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch. - bfloat16, azp_adj, None, bias), - "cutlass_i8_i8_bf16_scaled_mm_azp_pt": - lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch. - bfloat16, azp_adj, azp), - "cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias": - lambda: ops.cutlass_scaled_mm_azp(a, b, scale_a, scale_b, torch. - bfloat16, azp_adj, azp, bias), + "pytorch_bf16_bf16_bf16_matmul-no-scales": lambda: torch.mm( + a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16) + ), + "pytorch_fp16_fp16_fp16_matmul-no-scales": lambda: torch.mm( + a.to(dtype=torch.float16), b.to(dtype=torch.float16) + ), + "cutlass_i8_i8_bf16_scaled_mm": lambda: ops.cutlass_scaled_mm( + a, b, scale_a, scale_b, torch.bfloat16 + ), + "cutlass_i8_i8_bf16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm( + a, b, scale_a, scale_b, torch.bfloat16, bias + ), + "cutlass_i8_i8_bf16_scaled_mm_azp": lambda: ops.cutlass_scaled_mm_azp( + a, b, scale_a, scale_b, torch.bfloat16, azp_adj + ), + "cutlass_i8_i8_bf16_scaled_mm_azp_bias": lambda: ops.cutlass_scaled_mm_azp( + a, b, scale_a, scale_b, torch.bfloat16, azp_adj, None, bias + ), + "cutlass_i8_i8_bf16_scaled_mm_azp_pt": lambda: ops.cutlass_scaled_mm_azp( + a, b, scale_a, scale_b, torch.bfloat16, azp_adj, azp + ), + "cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias": lambda: ops.cutlass_scaled_mm_azp( + a, b, scale_a, scale_b, torch.bfloat16, azp_adj, azp, bias + ), } timers = [] @@ -96,73 +101,73 @@ def bench_int8( def bench_fp8( - dtype: torch.dtype, - m: int, - k: int, - n: int, - label: str, - sub_label: str, - bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]: + dtype: torch.dtype, + m: int, + k: int, + n: int, + label: str, + sub_label: str, + bench_kernels: Optional[list[str]] = None, +) -> Iterable[TMeasurement]: """Benchmark FP8-based kernels.""" assert dtype == torch.float8_e4m3fn a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k) a_cont = a.contiguous() scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) - block_scale_a = torch.rand((m, k // 128), - device="cuda", - dtype=torch.float32) - block_scale_b = torch.rand((k // 128, n // 128), - device="cuda", - dtype=torch.float32) + + def ceil_div(x: int, y: int) -> int: + return (x + y - 1) // y + + block_scale_a = torch.rand( + (m, ceil_div(k, 128)), device="cuda", dtype=torch.float32 + ) + block_scale_b = torch.rand( + ceil_div(k, 128), ceil_div(n, 128), device="cuda", dtype=torch.float32 + ) block_scale_a_M_major = block_scale_a.t().contiguous().t() block_scale_b_K_major = block_scale_b.t().contiguous().t() - bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) + bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16) print(m, k, n) bench_fns = { - "pytorch_bf16_bf16_bf16_matmul-no-scales": - lambda: torch.mm(a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16) - ), - "pytorch_fp16_fp16_fp16_matmul-no-scales": - lambda: torch.mm(a.to(dtype=torch.float16), b.to(dtype=torch.float16)), - "pytorch_fp8_fp8_fp16_scaled_mm": - lambda: torch._scaled_mm( - a, b, scale_a, scale_b, out_dtype=torch.float16), - "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum": - lambda: torch._scaled_mm(a, - b, - scale_a, - scale_b, - out_dtype=torch.float16, - use_fast_accum=True), - "pytorch_fp8_fp8_bf16_scaled_mm": - lambda: torch._scaled_mm( - a, b, scale_a, scale_b, out_dtype=torch.bfloat16), - "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum": - lambda: torch._scaled_mm(a, - b, - scale_a, - scale_b, - out_dtype=torch.bfloat16, - use_fast_accum=True), - "cutlass_fp8_fp8_bf16_scaled_mm": - lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16), - "cutlass_fp8_fp8_fp16_scaled_mm": - lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.float16), - "cutlass_fp8_fp8_bf16_scaled_mm_bias": - lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16, - bias), - "cutlass_fp8_fp8_fp16_scaled_mm_bias": - lambda: ops.cutlass_scaled_mm(a, b, scale_a, scale_b, torch.float16, - bias.to(dtype=torch.float16)), - "triton_fp8_fp8_fp16_scaled_mm_blockwise": - lambda: w8a8_block_fp8_matmul(a_cont, b.t(), block_scale_a, - block_scale_b.t(), (128, 128)), - "cutlass_fp8_fp8_fp16_scaled_mm_blockwise": - lambda: ops.cutlass_scaled_mm(a, b, block_scale_a_M_major, - block_scale_b_K_major, torch.float16), + "pytorch_bf16_bf16_bf16_matmul-no-scales": lambda: torch.mm( + a.to(dtype=torch.bfloat16), b.to(dtype=torch.bfloat16) + ), + "pytorch_fp16_fp16_fp16_matmul-no-scales": lambda: torch.mm( + a.to(dtype=torch.float16), b.to(dtype=torch.float16) + ), + "pytorch_fp8_fp8_fp16_scaled_mm": lambda: torch._scaled_mm( + a, b, scale_a, scale_b, out_dtype=torch.float16 + ), + "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum": lambda: torch._scaled_mm( + a, b, scale_a, scale_b, out_dtype=torch.float16, use_fast_accum=True + ), + "pytorch_fp8_fp8_bf16_scaled_mm": lambda: torch._scaled_mm( + a, b, scale_a, scale_b, out_dtype=torch.bfloat16 + ), + "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum": lambda: torch._scaled_mm( + a, b, scale_a, scale_b, out_dtype=torch.bfloat16, use_fast_accum=True + ), + "cutlass_fp8_fp8_bf16_scaled_mm": lambda: ops.cutlass_scaled_mm( + a, b, scale_a, scale_b, torch.bfloat16 + ), + "cutlass_fp8_fp8_fp16_scaled_mm": lambda: ops.cutlass_scaled_mm( + a, b, scale_a, scale_b, torch.float16 + ), + "cutlass_fp8_fp8_bf16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm( + a, b, scale_a, scale_b, torch.bfloat16, bias + ), + "cutlass_fp8_fp8_fp16_scaled_mm_bias": lambda: ops.cutlass_scaled_mm( + a, b, scale_a, scale_b, torch.float16, bias.to(dtype=torch.float16) + ), + "triton_fp8_fp8_fp16_scaled_mm_blockwise": lambda: w8a8_block_fp8_matmul( + a_cont, b.t(), block_scale_a, block_scale_b.t(), (128, 128) + ), + "cutlass_fp8_fp8_fp16_scaled_mm_blockwise": lambda: ops.cutlass_scaled_mm( + a, b, block_scale_a_M_major, block_scale_b_K_major, torch.float16 + ), } timers = [] @@ -175,13 +180,15 @@ def bench_fp8( return timers -def bench(dtype: torch.dtype, - m: int, - k: int, - n: int, - label: str, - sub_label: str, - bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]: +def bench( + dtype: torch.dtype, + m: int, + k: int, + n: int, + label: str, + sub_label: str, + bench_kernels: Optional[list[str]] = None, +) -> Iterable[TMeasurement]: if dtype == torch.int8: return bench_int8(dtype, m, k, n, label, sub_label, bench_kernels) if dtype == torch.float8_e4m3fn: @@ -195,27 +202,33 @@ def print_timers(timers: Iterable[TMeasurement]): compare.print() -def run(dtype: torch.dtype, - MKNs: Iterable[tuple[int, int, int]], - bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]: +def run( + dtype: torch.dtype, + MKNs: Iterable[tuple[int, int, int]], + bench_kernels: Optional[list[str]] = None, +) -> Iterable[TMeasurement]: results = [] for m, k, n in MKNs: - timers = bench(dtype, - m, - k, - n, - f"scaled-{dtype}-gemm", - f"MKN=({m}x{k}x{n})", - bench_kernels=bench_kernels) + timers = bench( + dtype, + m, + k, + n, + f"scaled-{dtype}-gemm", + f"MKN=({m}x{k}x{n})", + bench_kernels=bench_kernels, + ) print_timers(timers) results.extend(timers) return results -def make_output(data: Iterable[TMeasurement], - MKNs: Iterable[tuple[int, int, int]], - base_description: str, - timestamp=None): +def make_output( + data: Iterable[TMeasurement], + MKNs: Iterable[tuple[int, int, int]], + base_description: str, + timestamp=None, +): print(f"== All Results {base_description} ====") print_timers(data) @@ -226,8 +239,7 @@ def make_output(data: Iterable[TMeasurement], def run_square_bench(args): - dim_sizes = list( - range(args.dim_start, args.dim_end + 1, args.dim_increment)) + dim_sizes = list(range(args.dim_start, args.dim_end + 1, args.dim_increment)) MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes)) data = run(args.dtype, MKNs, bench_kernels=args.kernels) make_output(data, MKNs, f"square_bench-{args.dtype}") @@ -285,7 +297,7 @@ def run_model_bench(args): pkl.dump(all_data, f) -if __name__ == '__main__': +if __name__ == "__main__": def to_torch_dtype(dt): if dt == "int8": @@ -310,19 +322,21 @@ Benchmark Cutlass GEMM. Output: - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs. """, # noqa: E501 - formatter_class=argparse.RawTextHelpFormatter) + formatter_class=argparse.RawTextHelpFormatter, + ) - parser.add_argument("--dtype", - type=to_torch_dtype, - required=True, - help="Available options are ['int8', 'fp8']") + parser.add_argument( + "--dtype", + type=to_torch_dtype, + required=True, + help="Available options are ['int8', 'fp8']", + ) parser.add_argument( "--kernels", nargs="+", type=str, default=None, - help= - "Exact names of the kernels to benchmark. If not set, runs all kernels." + help="Exact names of the kernels to benchmark. If not set, runs all kernels.", ) subparsers = parser.add_subparsers(dest="cmd") @@ -343,19 +357,19 @@ Benchmark Cutlass GEMM. range_parser.set_defaults(func=run_range_bench) model_parser = subparsers.add_parser("model_bench") - model_parser.add_argument("--models", - nargs="+", - type=str, - default=DEFAULT_MODELS, - choices=WEIGHT_SHAPES.keys()) - model_parser.add_argument("--tp-sizes", - nargs="+", - type=int, - default=DEFAULT_TP_SIZES) - model_parser.add_argument("--batch-sizes", - nargs="+", - type=int, - default=DEFAULT_BATCH_SIZES) + model_parser.add_argument( + "--models", + nargs="+", + type=str, + default=DEFAULT_MODELS, + choices=WEIGHT_SHAPES.keys(), + ) + model_parser.add_argument( + "--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES + ) + model_parser.add_argument( + "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES + ) model_parser.set_defaults(func=run_model_bench) args = parser.parse_args() diff --git a/benchmarks/cutlass_benchmarks/weight_shapes.py b/benchmarks/cutlass_benchmarks/weight_shapes.py index 3d1121df40d01c4b051cb3ce6abac2ed0921a9ea..d31b623a1ee604c2bf0b4ab7cb90d37ffa463adb 100644 --- a/benchmarks/cutlass_benchmarks/weight_shapes.py +++ b/benchmarks/cutlass_benchmarks/weight_shapes.py @@ -42,4 +42,4 @@ WEIGHT_SHAPES = { ([8192, 57344], 1), ([28672, 8192], 0), ], -} \ No newline at end of file +} diff --git a/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py b/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py index 980e68668911f7bd28a5b7c5e87f9781966bfbfc..fce156e1c96c62bf0938facd623e40f6fc2a22a3 100644 --- a/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py +++ b/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py @@ -12,39 +12,37 @@ app = Quart(__name__) async def forward_request(url, data): async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: - headers = { - "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}" - } - async with session.post(url=url, json=data, - headers=headers) as response: + headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"} + async with session.post(url=url, json=data, headers=headers) as response: if response.status == 200: # if response.headers.get('Transfer-Encoding') == 'chunked': if True: - async for chunk_bytes in response.content.iter_chunked( - 1024): + async for chunk_bytes in response.content.iter_chunked(1024): yield chunk_bytes else: content = await response.read() yield content -@app.route('/v1/completions', methods=['POST']) +@app.route("/v1/completions", methods=["POST"]) async def handle_request(): try: original_request_data = await request.get_json() prefill_request = original_request_data.copy() # change max_tokens = 1 to let it only do prefill - prefill_request['max_tokens'] = 1 + prefill_request["max_tokens"] = 1 # finish prefill - async for _ in forward_request('http://localhost:8100/v1/completions', - prefill_request): + async for _ in forward_request( + "http://localhost:8100/v1/completions", prefill_request + ): continue # return decode - generator = forward_request('http://localhost:8200/v1/completions', - original_request_data) + generator = forward_request( + "http://localhost:8200/v1/completions", original_request_data + ) response = await make_response(generator) response.timeout = None @@ -53,11 +51,12 @@ async def handle_request(): except Exception as e: import sys import traceback + exc_info = sys.exc_info() print("Error occurred in disagg prefill proxy server") print(e) print("".join(traceback.format_exception(*exc_info))) -if __name__ == '__main__': +if __name__ == "__main__": app.run(port=8000) diff --git a/benchmarks/disagg_benchmarks/round_robin_proxy.py b/benchmarks/disagg_benchmarks/round_robin_proxy.py index c2ad4916bf0775ab4543afeb50ad24594cb65fee..fd19b40bf252c3076f317b5482e439de2c01ebbb 100644 --- a/benchmarks/disagg_benchmarks/round_robin_proxy.py +++ b/benchmarks/disagg_benchmarks/round_robin_proxy.py @@ -8,7 +8,6 @@ from aiohttp import web class RoundRobinProxy: - def __init__(self, target_ports): self.target_ports = target_ports self.port_cycle = itertools.cycle(self.target_ports) @@ -21,14 +20,15 @@ class RoundRobinProxy: try: # Forward the request async with session.request( - method=request.method, - url=target_url, - headers=request.headers, - data=request.content, + method=request.method, + url=target_url, + headers=request.headers, + data=request.content, ) as response: # Start sending the response - resp = web.StreamResponse(status=response.status, - headers=response.headers) + resp = web.StreamResponse( + status=response.status, headers=response.headers + ) await resp.prepare(request) # Stream the response content @@ -45,11 +45,11 @@ class RoundRobinProxy: async def main(): proxy = RoundRobinProxy([8100, 8200]) app = web.Application() - app.router.add_route('*', '/{path:.*}', proxy.handle_request) + app.router.add_route("*", "/{path:.*}", proxy.handle_request) runner = web.AppRunner(app) await runner.setup() - site = web.TCPSite(runner, 'localhost', 8000) + site = web.TCPSite(runner, "localhost", 8000) await site.start() print("Proxy server started on http://localhost:8000") @@ -58,5 +58,5 @@ async def main(): await asyncio.Event().wait() -if __name__ == '__main__': +if __name__ == "__main__": asyncio.run(main()) diff --git a/benchmarks/disagg_benchmarks/visualize_benchmark_results.py b/benchmarks/disagg_benchmarks/visualize_benchmark_results.py index a7b4b9e8bf302975458a675a1710e9eb653c1551..484d0cb3cba7d74db8dea04e68e17dde38be25e6 100644 --- a/benchmarks/disagg_benchmarks/visualize_benchmark_results.py +++ b/benchmarks/disagg_benchmarks/visualize_benchmark_results.py @@ -6,43 +6,41 @@ import matplotlib.pyplot as plt import pandas as pd if __name__ == "__main__": - data = [] - for name in ['disagg_prefill', 'chunked_prefill']: + for name in ["disagg_prefill", "chunked_prefill"]: for qps in [2, 4, 6, 8]: with open(f"results/{name}-qps-{qps}.json") as f: x = json.load(f) - x['name'] = name - x['qps'] = qps + x["name"] = name + x["qps"] = qps data.append(x) df = pd.DataFrame.from_dict(data) - dis_df = df[df['name'] == 'disagg_prefill'] - chu_df = df[df['name'] == 'chunked_prefill'] + dis_df = df[df["name"] == "disagg_prefill"] + chu_df = df[df["name"] == "chunked_prefill"] - plt.style.use('bmh') - plt.rcParams['font.size'] = 20 + plt.style.use("bmh") + plt.rcParams["font.size"] = 20 for key in [ - 'mean_ttft_ms', 'median_ttft_ms', 'p99_ttft_ms', 'mean_itl_ms', - 'median_itl_ms', 'p99_itl_ms' + "mean_ttft_ms", + "median_ttft_ms", + "p99_ttft_ms", + "mean_itl_ms", + "median_itl_ms", + "p99_itl_ms", ]: - fig, ax = plt.subplots(figsize=(11, 7)) - plt.plot(dis_df['qps'], - dis_df[key], - label='disagg_prefill', - marker='o', - linewidth=4) - plt.plot(chu_df['qps'], - chu_df[key], - label='chunked_prefill', - marker='o', - linewidth=4) + plt.plot( + dis_df["qps"], dis_df[key], label="disagg_prefill", marker="o", linewidth=4 + ) + plt.plot( + chu_df["qps"], chu_df[key], label="chunked_prefill", marker="o", linewidth=4 + ) ax.legend() - ax.set_xlabel('QPS') + ax.set_xlabel("QPS") ax.set_ylabel(key) ax.set_ylim(bottom=0) - fig.savefig(f'results/{key}.png') + fig.savefig(f"results/{key}.png") plt.close(fig) diff --git a/benchmarks/fused_kernels/layernorm_rms_benchmarks.py b/benchmarks/fused_kernels/layernorm_rms_benchmarks.py index 3da583a334480f81a3f0edd18a22970273b0e21b..37a9173a1a937808271cc832ee49fe4172474601 100644 --- a/benchmarks/fused_kernels/layernorm_rms_benchmarks.py +++ b/benchmarks/fused_kernels/layernorm_rms_benchmarks.py @@ -24,10 +24,12 @@ class bench_params_t: dtype: torch.dtype def description(self): - return (f'N {self.num_tokens} ' - f'x D {self.hidden_size} ' - f'x R {self.add_residual} ' - f'x DT {self.dtype}') + return ( + f"N {self.num_tokens} " + f"x D {self.hidden_size} " + f"x R {self.add_residual} " + f"x DT {self.dtype}" + ) def get_bench_params() -> list[bench_params_t]: @@ -38,15 +40,19 @@ def get_bench_params() -> list[bench_params_t]: DTYPES = [torch.bfloat16, torch.float] combinations = product(NUM_TOKENS, HIDDEN_SIZES, ADD_RESIDUAL, DTYPES) - bench_params = list(map(lambda x: \ - bench_params_t(x[0], x[1], x[2], x[3]), combinations)) + bench_params = list( + map(lambda x: bench_params_t(x[0], x[1], x[2], x[3]), combinations) + ) return bench_params # Reference impls -def unfused_int8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor, - residual: Optional[torch.Tensor], - quant_dtype: torch.dtype): +def unfused_int8_impl( + rms_norm_layer: RMSNorm, + x: torch.Tensor, + residual: Optional[torch.Tensor], + quant_dtype: torch.dtype, +): # Norm torch_out = None if residual is None: @@ -58,9 +64,12 @@ def unfused_int8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor, torch_out, _, _ = ops.scaled_int8_quant(torch_out) -def unfused_fp8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor, - residual: Optional[torch.Tensor], - quant_dtype: torch.dtype): +def unfused_fp8_impl( + rms_norm_layer: RMSNorm, + x: torch.Tensor, + residual: Optional[torch.Tensor], + quant_dtype: torch.dtype, +): # Norm torch_out = None if residual is None: @@ -73,22 +82,27 @@ def unfused_fp8_impl(rms_norm_layer: RMSNorm, x: torch.Tensor, def fused_impl( - rms_norm_layer: RMSNorm, # this stores the weights - x: torch.Tensor, - residual: Optional[torch.Tensor], - quant_dtype: torch.dtype): - out, _ = ops.rms_norm_dynamic_per_token_quant(x, - rms_norm_layer.weight, - 1e-6, - quant_dtype, - residual=residual) + rms_norm_layer: RMSNorm, # this stores the weights + x: torch.Tensor, + residual: Optional[torch.Tensor], + quant_dtype: torch.dtype, +): + out, _ = ops.rms_norm_dynamic_per_token_quant( + x, rms_norm_layer.weight, 1e-6, quant_dtype, residual=residual + ) # Bench functions -def bench_fn(rms_norm_layer: RMSNorm, x: torch.Tensor, residual: torch.Tensor, - quant_dtype: torch.dtype, label: str, sub_label: str, - fn: Callable, description: str) -> TMeasurement: - +def bench_fn( + rms_norm_layer: RMSNorm, + x: torch.Tensor, + residual: torch.Tensor, + quant_dtype: torch.dtype, + label: str, + sub_label: str, + fn: Callable, + description: str, +) -> TMeasurement: min_run_time = 1 globals = { @@ -106,43 +120,81 @@ def bench_fn(rms_norm_layer: RMSNorm, x: torch.Tensor, residual: torch.Tensor, description=description, ).blocked_autorange(min_run_time=min_run_time) -def bench(params: bench_params_t, label: str, sub_label: str) \ - -> Iterable[TMeasurement]: +def bench(params: bench_params_t, label: str, sub_label: str) -> Iterable[TMeasurement]: # Make inputs layer = RMSNorm(params.hidden_size, 1e-6).to(dtype=params.dtype) # Make weights layer.weight.data.normal_(mean=1.0, std=0.1) # Make inputs scale = 1 / params.hidden_size - x = torch.randn(params.num_tokens, - params.hidden_size, - dtype=params.dtype, - device='cuda') * scale - residual = (torch.randn_like(x) * scale).to(device='cuda') \ - if params.add_residual else None + x = ( + torch.randn( + params.num_tokens, params.hidden_size, dtype=params.dtype, device="cuda" + ) + * scale + ) + residual = ( + (torch.randn_like(x) * scale).to(device="cuda") if params.add_residual else None + ) timers = [] # unfused int8 impl. timers.append( - bench_fn(layer, x, residual, torch.int8, label, sub_label, - unfused_int8_impl, "unfused_int8_impl")) + bench_fn( + layer, + x, + residual, + torch.int8, + label, + sub_label, + unfused_int8_impl, + "unfused_int8_impl", + ) + ) # unfused fp8 impl. timers.append( - bench_fn(layer, x, residual, torch.float8_e4m3fn, label, sub_label, - unfused_fp8_impl, "unfused_fp8_impl")) + bench_fn( + layer, + x, + residual, + torch.float8_e4m3fn, + label, + sub_label, + unfused_fp8_impl, + "unfused_fp8_impl", + ) + ) # fused int8 impl. timers.append( - bench_fn(layer, x, residual, torch.int8, label, sub_label, fused_impl, - "fused_int8_impl")) + bench_fn( + layer, + x, + residual, + torch.int8, + label, + sub_label, + fused_impl, + "fused_int8_impl", + ) + ) # fused fp8 impl. timers.append( - bench_fn(layer, x, residual, torch.float8_e4m3fn, label, sub_label, - fused_impl, "fused_fp8_impl")) + bench_fn( + layer, + x, + residual, + torch.float8_e4m3fn, + label, + sub_label, + fused_impl, + "fused_fp8_impl", + ) + ) print_timers(timers) @@ -157,13 +209,12 @@ def print_timers(timers: Iterable[TMeasurement]): def main(): - torch.set_default_device('cuda') + torch.set_default_device("cuda") bench_params = get_bench_params() timers = [] for bp in tqdm(bench_params): - timers.extend( - bench(bp, "rms-norm-dynamic-per-token-quant", bp.description())) + timers.extend(bench(bp, "rms-norm-dynamic-per-token-quant", bp.description())) print_timers(timers) # pickle all the results @@ -172,5 +223,5 @@ def main(): pkl.dump(timers, f) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/benchmarks/kernels/benchmark_aqlm.py b/benchmarks/kernels/benchmark_aqlm.py index 8d20b91560dd62cb0c404e813da4d64fe48dda69..e9934aa479dd6bb21d0e6e45d5e15999c0d23dda 100644 --- a/benchmarks/kernels/benchmark_aqlm.py +++ b/benchmarks/kernels/benchmark_aqlm.py @@ -9,32 +9,39 @@ import torch.nn.functional as F from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.aqlm import ( - dequantize_weight, generic_dequantize_gemm, get_int_dtype, - optimized_dequantize_gemm) + dequantize_weight, + generic_dequantize_gemm, + get_int_dtype, + optimized_dequantize_gemm, +) from vllm.utils import FlexibleArgumentParser -os.environ['CUDA_VISIBLE_DEVICES'] = '0' +os.environ["CUDA_VISIBLE_DEVICES"] = "0" def torch_mult( - input: torch.Tensor, # [..., in_features] - weights: torch.Tensor, - scales: torch.Tensor, # [num_out_groups, 1, 1, 1] + # [..., in_features] + input: torch.Tensor, + weights: torch.Tensor, + # [num_out_groups, 1, 1, 1] + scales: torch.Tensor, ) -> torch.Tensor: output = F.linear(input, weights) return output def dequant_out_scale( - input: torch.Tensor, # [..., in_features] - codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks] - codebooks: torch. - Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size] - scales: torch.Tensor, # [num_out_groups, 1, 1, 1] + # [..., in_features] + input: torch.Tensor, + # [num_out_groups, num_in_groups, num_codebooks] + codes: torch.IntTensor, + # [num_codebooks, codebook_size, out_group_size, in_group_size] + codebooks: torch.Tensor, + # [num_out_groups, 1, 1, 1] + scales: torch.Tensor, output_partition_sizes: torch.IntTensor, bias: Optional[torch.Tensor], ) -> torch.Tensor: - weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes) if bias is None: @@ -46,40 +53,42 @@ def dequant_out_scale( flattened_output *= b_scales return flattened_output.view(orig_shape) else: - b_scales = scales.view(scales.shape[:-3] + (-1, )).expand( - -1, weights.shape[1]) + b_scales = scales.view(scales.shape[:-3] + (-1,)).expand(-1, weights.shape[1]) weights *= b_scales return F.linear(input, weights, bias) def dequant_weight_scale( - input: torch.Tensor, # [..., in_features] - codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks] - codebooks: torch. - Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size] - scales: torch.Tensor, # [num_out_groups, 1, 1, 1] + # [..., in_features] + input: torch.Tensor, + # [num_out_groups, num_in_groups, num_codebooks] + codes: torch.IntTensor, + # [num_codebooks, codebook_size, out_group_size, in_group_size] + codebooks: torch.Tensor, + # [num_out_groups, 1, 1, 1] + scales: torch.Tensor, output_partition_sizes: torch.IntTensor, bias: Optional[torch.Tensor], ) -> torch.Tensor: - weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes) - b_scales = scales.view(scales.shape[:-3] + (-1, )).expand( - -1, weights.shape[1]) + b_scales = scales.view(scales.shape[:-3] + (-1,)).expand(-1, weights.shape[1]) weights *= b_scales return F.linear(input, weights, bias) def dequant_no_scale( - input: torch.Tensor, # [..., in_features] - codes: torch.IntTensor, # [num_out_groups, num_in_groups, num_codebooks] - codebooks: torch. - Tensor, # [num_codebooks, codebook_size, out_group_size, in_group_size] - scales: torch.Tensor, # [num_out_groups, 1, 1, 1] + # [..., in_features] + input: torch.Tensor, + # [num_out_groups, num_in_groups, num_codebooks] + codes: torch.IntTensor, + # [num_codebooks, codebook_size, out_group_size, in_group_size] + codebooks: torch.Tensor, + # [num_out_groups, 1, 1, 1] + scales: torch.Tensor, output_partition_sizes: torch.IntTensor, bias: Optional[torch.Tensor], ) -> torch.Tensor: - weights = ops.aqlm_dequant(codes, codebooks, output_partition_sizes) return F.linear(input, weights, bias) @@ -89,23 +98,26 @@ def dequant_no_scale( # the generic pytorch version. # Just visual comparison. def dequant_test(k: int, parts: torch.Tensor, nbooks: int, bits: int) -> None: - n = int(parts.sum().item()) - device = torch.device('cuda:0') + device = torch.device("cuda:0") code_range = (1 << bits) // 2 ingroups = 8 - codes = torch.randint(-code_range, - code_range, - size=(n, k // ingroups, nbooks), - dtype=get_int_dtype(bits), - device=device) + codes = torch.randint( + -code_range, + code_range, + size=(n, k // ingroups, nbooks), + dtype=get_int_dtype(bits), + device=device, + ) - codebooks = torch.randn(size=(parts.shape[0] * nbooks, 1 << bits, 1, 8), - dtype=torch.float16, - device=device) + codebooks = torch.randn( + size=(parts.shape[0] * nbooks, 1 << bits, 1, 8), + dtype=torch.float16, + device=device, + ) count = 0 for index in range(16): @@ -138,24 +150,25 @@ def dequant_test(k: int, parts: torch.Tensor, nbooks: int, bits: int) -> None: def main(): - parser = FlexibleArgumentParser(description="Benchmark aqlm performance.") # Add arguments - parser.add_argument("--nbooks", - type=int, - default=1, - help="Number of codebooks (default: 1)") - parser.add_argument("--bits", - type=int, - default=16, - help="Number of bits per code element (default: 16)") + parser.add_argument( + "--nbooks", type=int, default=1, help="Number of codebooks (default: 1)" + ) + parser.add_argument( + "--bits", + type=int, + default=16, + help="Number of bits per code element (default: 16)", + ) parser.add_argument( "--test", type=bool, default=False, help="Run the decompression/dequant tester rather than benchmarking " - "(default: False)") + "(default: False)", + ) # Parse the arguments args = parser.parse_args() @@ -165,7 +178,7 @@ def main(): bits = args.bits if args.test: - dequant_test(4096, torch.tensor((4096, )), nbooks, bits) + dequant_test(4096, torch.tensor((4096,)), nbooks, bits) return # Otherwise, benchmark. @@ -184,31 +197,54 @@ def main(): with open(filename, "w") as f: sys.stdout = f - print('m | k | n | n parts', end='') + print("m | k | n | n parts", end="") for method in methods: - print(f" | {method.__name__.replace('_', ' ')} (µs)", end='') - print('') + print(f" | {method.__name__.replace('_', ' ')} (µs)", end="") + print("") # These are reasonable prefill sizes. - ksandpartions = ((4096, (4096, 4096, 4096)), (4096, (4096, )), - (4096, (11008, 11008)), (11008, (4096, ))) + ksandpartions = ( + (4096, (4096, 4096, 4096)), + (4096, (4096,)), + (4096, (11008, 11008)), + (11008, (4096,)), + ) # reasonable ranges for m. for m in [ - 1, 2, 4, 8, 10, 12, 14, 16, 24, 32, 48, 52, 56, 64, 96, 112, - 128, 256, 512, 1024, 1536, 2048, 3072, 4096 + 1, + 2, + 4, + 8, + 10, + 12, + 14, + 16, + 24, + 32, + 48, + 52, + 56, + 64, + 96, + 112, + 128, + 256, + 512, + 1024, + 1536, + 2048, + 3072, + 4096, ]: - print(f'{m}', file=sys.__stdout__) + print(f"{m}", file=sys.__stdout__) for ksp in ksandpartions: - run_grid(m, ksp[0], torch.tensor(ksp[1]), nbooks, bits, - methods) + run_grid(m, ksp[0], torch.tensor(ksp[1]), nbooks, bits, methods) sys.stdout = sys.__stdout__ -def run_grid(m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int, - methods): - +def run_grid(m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int, methods): # I didn't see visible improvements from increasing these, but feel free :) num_warmup_trials = 1 num_trials = 1 @@ -229,7 +265,7 @@ def run_grid(m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int, ) n = parts.sum().item() - print(f'{m} | {k} | {n} | {parts.tolist()}', end='') + print(f"{m} | {k} | {n} | {parts.tolist()}", end="") for method in methods: best_time_us = 1e20 @@ -249,32 +285,36 @@ def run_grid(m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int, if kernel_dur_us < best_time_us: best_time_us = kernel_dur_us - print(f' | {kernel_dur_us:.0f}', end='') + print(f" | {kernel_dur_us:.0f}", end="") - print('') + print("") -def run_timing(num_calls: int, m: int, k: int, parts: torch.Tensor, - nbooks: int, bits: int, method) -> float: - +def run_timing( + num_calls: int, m: int, k: int, parts: torch.Tensor, nbooks: int, bits: int, method +) -> float: n = int(parts.sum().item()) - device = torch.device('cuda:0') + device = torch.device("cuda:0") input = torch.randn((1, m, k), dtype=torch.float16, device=device) code_range = (1 << bits) // 2 ingroups = 8 - codes = torch.randint(-code_range, - code_range, - size=(n, k // ingroups, nbooks), - dtype=get_int_dtype(bits), - device=device) - - codebooks = torch.randn(size=(parts.shape[0] * nbooks, 1 << bits, 1, 8), - dtype=torch.float16, - device=device) + codes = torch.randint( + -code_range, + code_range, + size=(n, k // ingroups, nbooks), + dtype=get_int_dtype(bits), + device=device, + ) + + codebooks = torch.randn( + size=(parts.shape[0] * nbooks, 1 << bits, 1, 8), + dtype=torch.float16, + device=device, + ) scales = torch.randn(size=(n, 1, 1, 1), dtype=torch.float16, device=device) diff --git a/benchmarks/kernels/benchmark_bitblas.py b/benchmarks/kernels/benchmark_bitblas.py index b23b4f3ea685aa1c07b2d9f26c56102ba16785c1..d40ab70ec539b27af09739167a7c11900e6e936d 100644 --- a/benchmarks/kernels/benchmark_bitblas.py +++ b/benchmarks/kernels/benchmark_bitblas.py @@ -3,27 +3,33 @@ # Licensed under the MIT License. from vllm.model_executor.layers.quantization.utils.bitblas_utils import ( - MINIMUM_BITBLAS_VERSION) + MINIMUM_BITBLAS_VERSION, +) try: import bitblas + if bitblas.__version__ < MINIMUM_BITBLAS_VERSION: - raise ImportError("bitblas version is wrong. Please " - f"install bitblas>={MINIMUM_BITBLAS_VERSION}") + raise ImportError( + "bitblas version is wrong. Please " + f"install bitblas>={MINIMUM_BITBLAS_VERSION}" + ) except ImportError as e: bitblas_import_exception = e - raise ValueError("Trying to use the bitblas backend, but could not import" - f"with the following error: {bitblas_import_exception}. " - "Please install bitblas through the following command: " - f"`pip install bitblas>={MINIMUM_BITBLAS_VERSION}`" - ) from bitblas_import_exception + raise ValueError( + "Trying to use the bitblas backend, but could not import" + f"with the following error: {bitblas_import_exception}. " + "Please install bitblas through the following command: " + f"`pip install bitblas>={MINIMUM_BITBLAS_VERSION}`" + ) from bitblas_import_exception from bitblas import Matmul, MatmulConfig, auto_detect_nvidia_target from vllm.utils import FlexibleArgumentParser parser = FlexibleArgumentParser( - description="Benchmark BitBLAS int4 on a specific target.") + description="Benchmark BitBLAS int4 on a specific target." +) # Add arguments to the parser parser.add_argument( @@ -32,10 +38,9 @@ parser.add_argument( default=auto_detect_nvidia_target(), help="Specify the target device for benchmarking.", ) -parser.add_argument("--group_size", - type=int, - default=None, - help="Group size for grouped quantization.") +parser.add_argument( + "--group_size", type=int, default=None, help="Group size for grouped quantization." +) parser.add_argument( "--A_dtype", type=str, @@ -82,17 +87,17 @@ parser.add_argument( choices=["nt", "nn"], help="Matrix layout, 'nt' for non-transpose A and transpose W.", ) -parser.add_argument("--with_bias", - action="store_true", - help="Include bias in the benchmark.") +parser.add_argument( + "--with_bias", action="store_true", help="Include bias in the benchmark." +) parser.add_argument( "--with_scaling", action="store_true", help="Include scaling factor in the quantization.", ) -parser.add_argument("--with_zeros", - action="store_true", - help="Include zeros in the quantization.") +parser.add_argument( + "--with_zeros", action="store_true", help="Include zeros in the quantization." +) parser.add_argument( "--zeros_mode", type=str, @@ -170,8 +175,7 @@ shapes = [ ] # Build test shapes with all the shared arguments -test_shapes = [(MatmulConfig, Matmul, (*shape, *shared_args)) - for shape in shapes] +test_shapes = [(MatmulConfig, Matmul, (*shape, *shared_args)) for shape in shapes] benchmark_sets = [] benchmark_sets.extend(test_shapes) @@ -206,12 +210,12 @@ for config_key, values in benchmark_results.items(): func_name = args_split[0] input_args_str = "-".join(args_split[1:]) col_widths[0] = max(col_widths[0], len(func_name) + 2, len(headers[0]) + 2) - col_widths[1] = max(col_widths[1], - len(input_args_str) + 2, - len(headers[1]) + 2) - col_widths[2] = max(col_widths[2], - len(f"{values['BitBLAS_top20_latency']:.3f} ms") + 2, - len(headers[2]) + 2) + col_widths[1] = max(col_widths[1], len(input_args_str) + 2, len(headers[1]) + 2) + col_widths[2] = max( + col_widths[2], + len(f"{values['BitBLAS_top20_latency']:.3f} ms") + 2, + len(headers[2]) + 2, + ) # break only if you want to measure widths from a single example; # otherwise, let it loop over all items. @@ -232,5 +236,6 @@ for config_key, values in benchmark_results.items(): f"{values['BitBLAS_top20_latency']:.3f} ms", ] row_str = "".join( - [str(cell).ljust(col_widths[idx]) for idx, cell in enumerate(row)]) + [str(cell).ljust(col_widths[idx]) for idx, cell in enumerate(row)] + ) print(row_str) diff --git a/benchmarks/kernels/benchmark_cutlass_fp4_moe.py b/benchmarks/kernels/benchmark_cutlass_fp4_moe.py new file mode 100644 index 0000000000000000000000000000000000000000..d39d8a6e3aba31612817e62f50ed8bb911fb66e5 --- /dev/null +++ b/benchmarks/kernels/benchmark_cutlass_fp4_moe.py @@ -0,0 +1,489 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Benchmark the performance of the cutlass_moe_fp4 kernel vs the triton_moe +kernel. The cutlass_moe_fp4 kernel takes in fp4 quantized weights and 16-bit +activations. The triton_moe kernel takes in fp8 weights(tensor scaled to fp8) +and 16-bit activations. +""" + +import nvtx +import torch +import torch.utils.benchmark as benchmark + +from vllm import _custom_ops as ops +from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4 +from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts, fused_topk +from vllm.scalar_type import scalar_types +from vllm.utils import FlexibleArgumentParser + +WEIGHT_SHAPES_MOE = { + "nvidia/DeepSeek-R1-FP4": [ + [256, 8, 2048, 7168], + ], +} + +DEFAULT_MODELS = [ + "nvidia/DeepSeek-R1-FP4", +] + +DEFAULT_BATCH_SIZES = [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048] +DEFAULT_TP_SIZES = [1] + +PER_ACT_TOKEN_OPTS = [False] +PER_OUT_CH_OPTS = [False] +FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max() +FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max + + +def to_fp8(tensor: torch.Tensor): + finfo = torch.finfo(torch.float8_e4m3fn) + return torch.round(tensor.clamp(min=finfo.min, max=finfo.max)).to( + dtype=torch.float8_e4m3fn + ) + + +def bench_run( + results: list[benchmark.Measurement], + model: str, + num_experts: int, + topk: int, + per_act_token: bool, + per_out_ch: bool, + mkn: tuple[int, int, int], +): + label = "NVFP4 Blockscaled CUTLASS MOE vs FP8 Tensor Scaled Triton" + + sub_label = ( + "{}, num_experts={}, topk={}, per_act_token={} per_out_ch={}, MKN=({})".format( + model, num_experts, topk, per_act_token, per_out_ch, mkn + ) + ) + + print(f"Testing: {sub_label}") + + (m, k, n) = mkn + + dtype = torch.half + device = "cuda" + a = torch.randn((m, k), device=device, dtype=dtype) / 10 + w1 = torch.randn((num_experts, 2 * n, k), device=device, dtype=dtype) / 10 + w2 = torch.randn((num_experts, k, n), device=device, dtype=dtype) / 10 + + _, a_fp8_scale = ops.scaled_fp8_quant(a) + + w1_fp8q = torch.empty( + (num_experts, 2 * n, k), device=device, dtype=torch.float8_e4m3fn + ) + w2_fp8q = torch.empty((num_experts, k, n), device=device, dtype=torch.float8_e4m3fn) + w1_fp8scale = torch.empty((num_experts, 1, 1), device=device, dtype=torch.float32) + w2_fp8scale = torch.empty((num_experts, 1, 1), device=device, dtype=torch.float32) + + for expert in range(num_experts): + w1_fp8q[expert], w1_fp8scale[expert] = ops.scaled_fp8_quant(w1[expert]) + w2_fp8q[expert], w2_fp8scale[expert] = ops.scaled_fp8_quant(w2[expert]) + + w1_fp8q_notransp = w1_fp8q.clone() + w2_fp8q_notransp = w2_fp8q.clone() + w1_fp8q = w1_fp8q.transpose(1, 2) + w2_fp8q = w2_fp8q.transpose(1, 2) + + score = torch.randn((m, num_experts), device=device, dtype=dtype) + + topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False) + + quant_blocksize = 16 + w1_blockscale = torch.empty( + (num_experts, 2 * n, k // quant_blocksize), + device=device, + dtype=torch.float8_e4m3fn, + ) + w2_blockscale = torch.empty( + (num_experts, k, n // quant_blocksize), device=device, dtype=torch.float8_e4m3fn + ) + + # n_b_scales = 2 * n if per_out_ch else 1 + # k_b_scales = k if per_out_ch else 1 + w1_fp4 = torch.empty((num_experts, 2 * n, k // 2), device=device, dtype=torch.uint8) + w2_fp4 = torch.empty((num_experts, k, n // 2), device=device, dtype=torch.uint8) + + w1_gs = torch.empty((num_experts,), device=device, dtype=torch.float32) + w2_gs = torch.empty((num_experts,), device=device, dtype=torch.float32) + a1_gs = torch.ones((num_experts,), device=device, dtype=torch.float32) + a2_gs = torch.ones((num_experts,), device=device, dtype=torch.float32) + + for expert in range(num_experts): + w1_e = w1[expert] + w2_e = w2[expert] + w1_amax = torch.abs(w1_e).max().to(torch.float32) + w2_amax = torch.abs(w2_e).max().to(torch.float32) + w1_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax + w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax + + w1_fp4[expert], w1_blockscale[expert] = ops.scaled_fp4_quant( + w1_e, w1_gs[expert] + ) + + w2_fp4[expert], w2_blockscale[expert] = ops.scaled_fp4_quant( + w2_e, w2_gs[expert] + ) + + def run_triton_moe( + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + a_fp8_scale: torch.Tensor, + num_repeats: int, + ): + for _ in range(num_repeats): + fused_experts( + a, + w1, + w2, + topk_weights, + topk_ids, + use_fp8_w8a8=True, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a_fp8_scale, + ) + + def run_cutlass_moe_fp4( + a: torch.Tensor, + w1_fp4: torch.Tensor, + w2_fp4: torch.Tensor, + w1_blockscale: torch.Tensor, + w2_blockscale: torch.Tensor, + w1_gs: torch.Tensor, + w2_gs: torch.Tensor, + a1_gs: torch.Tensor, + a2_gs: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + m: int, + n: int, + k: int, + e: int, + device: torch.device, + num_repeats: int, + ): + for _ in range(num_repeats): + with nvtx.annotate("cutlass_moe_fp4", color="green"): + cutlass_moe_fp4( + a=a, + a1_gscale=a1_gs, + a2_gscale=a2_gs, + w1_fp4=w1_fp4, + w1_blockscale=w1_blockscale, + w1_alphas=w1_gs, + w2_fp4=w2_fp4, + w2_blockscale=w2_blockscale, + w2_alphas=w2_gs, + topk_weights=topk_weights, + topk_ids=topk_ids, + m=m, + n=n, + k=k, + e=num_experts, + device=device, + ) + + def run_cutlass_from_graph( + a: torch.Tensor, + a1_gscale: torch.Tensor, + w1_fp4: torch.Tensor, + w1_blockscale: torch.Tensor, + w1_alphas: torch.Tensor, + a2_gscale: torch.Tensor, + w2_fp4: torch.Tensor, + w2_blockscale: torch.Tensor, + w2_alphas: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + m: int, + n: int, + k: int, + e: int, + device: torch.device, + ): + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) + ): + return cutlass_moe_fp4( + a=a, + a1_gscale=a1_gs, + w1_fp4=w1_fp4, + w1_blockscale=w1_blockscale, + w1_alphas=w1_alphas, + a2_gscale=a2_gs, + w2_fp4=w2_fp4, + w2_blockscale=w2_blockscale, + w2_alphas=w2_alphas, + topk_weights=topk_weights, + topk_ids=topk_ids, + m=m, + n=n, + k=k, + e=num_experts, + device=device, + ) + + def run_triton_from_graph( + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + a_fp8_scale: torch.Tensor, + ): + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) + ): + return fused_experts( + a, + w1, + w2, + topk_weights, + topk_ids, + use_fp8_w8a8=True, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a_fp8_scale, + ) + + def replay_graph(graph, num_repeats): + for _ in range(num_repeats): + graph.replay() + torch.cuda.synchronize() + + cutlass_stream = torch.cuda.Stream() + cutlass_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(cutlass_graph, stream=cutlass_stream): + run_cutlass_from_graph( + a=a, + a1_gscale=a1_gs, + w1_fp4=w1_fp4, + w1_blockscale=w1_blockscale, + w1_alphas=w1_gs, + a2_gscale=a2_gs, + w2_fp4=w2_fp4, + w2_blockscale=w2_blockscale, + w2_alphas=w2_gs, + topk_weights=topk_weights, + topk_ids=topk_ids, + m=m, + n=n, + k=k, + e=num_experts, + device=device, + ) + torch.cuda.synchronize() + + triton_stream = torch.cuda.Stream() + triton_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(triton_graph, stream=triton_stream): + run_triton_from_graph( + a, + w1_fp8q_notransp, + w2_fp8q_notransp, + topk_weights, + topk_ids, + w1_fp8scale, + w2_fp8scale, + a_fp8_scale, + ) + torch.cuda.synchronize() + + min_run_time = 5 + num_warmup = 5 + num_runs = 25 + + globals = { + # Baseline params + "w1": w1, + "w2": w2, + "score": score, + "topk": topk, + "w1_fp8q_notransp": w1_fp8q_notransp, + "w2_fp8q_notransp": w2_fp8q_notransp, + "w1_fp8scale": w1_fp8scale, + "w2_fp8scale": w2_fp8scale, + "a_fp8_scale": a_fp8_scale, + # Cutlass params + "a": a, + "a1_gscale": a1_gs, + "w1_fp4": w1_fp4, + "w1_blockscale": w1_blockscale, + "w1_alphas": w1_gs, + "a2_gscale": a2_gs, + "w2_fp4": w2_fp4, + "w2_blockscale": w2_blockscale, + "w2_alphas": w2_gs, + "topk_weights": topk_weights, + "topk_ids": topk_ids, + "m": m, + "n": n, + "k": k, + "e": num_experts, + "device": device, + # cuda graph params + "cutlass_graph": cutlass_graph, + "triton_graph": triton_graph, + # Gen params + "num_runs": num_runs, + # Kernels + "run_triton_moe": run_triton_moe, + "run_cutlass_moe_fp4": run_cutlass_moe_fp4, + "replay_graph": replay_graph, + } + + # Warmup + run_triton_moe( + a, + w1_fp8q_notransp, + w2_fp8q_notransp, + topk_weights, + topk_ids, + w1_fp8scale, + w2_fp8scale, + a_fp8_scale, + num_warmup, + ) + + results.append( + benchmark.Timer( + stmt="run_triton_moe(a, w1_fp8q_notransp, w2_fp8q_notransp, topk_weights, topk_ids, w1_fp8scale, w2_fp8scale, a_fp8_scale, num_runs)", # noqa: E501 + globals=globals, + label=label, + sub_label=sub_label, + description="triton_moe", + ).blocked_autorange(min_run_time=min_run_time) + ) + + # Warmup + replay_graph(triton_graph, num_warmup) + + results.append( + benchmark.Timer( + stmt="replay_graph(triton_graph, num_runs)", + globals=globals, + label=label, + sub_label=sub_label, + description="triton_moe_cuda_graphs", + ).blocked_autorange(min_run_time=min_run_time) + ) + + # Warmup + + run_cutlass_moe_fp4( + a, + w1_fp4, + w2_fp4, + w1_blockscale, + w2_blockscale, + w1_gs, + w2_gs, + a1_gs, + a2_gs, + topk_weights, + topk_ids, + m, + n, + k, + num_experts, + device, + num_warmup, + ) + + results.append( + benchmark.Timer( + stmt="run_cutlass_moe_fp4(a, w1_fp4, w2_fp4, w1_blockscale, w2_blockscale, w1_alphas, w2_alphas, a1_gscale, a2_gscale, topk_weights, topk_ids, m, n, k, e, device, num_runs)", # noqa: E501 + globals=globals, + label=label, + sub_label=sub_label, + description="cutlass_moe_fp4", + ).blocked_autorange(min_run_time=min_run_time) + ) + + # Warmup + replay_graph(cutlass_graph, num_warmup) + + results.append( + benchmark.Timer( + stmt="replay_graph(cutlass_graph, num_runs)", + globals=globals, + label=label, + sub_label=sub_label, + description="cutlass_moe_fp4_cuda_graphs", + ).blocked_autorange(min_run_time=min_run_time) + ) + + +def main(args): + print("Benchmarking models:") + for i, model in enumerate(args.models): + print(f"[{i}] {model}") + + results: list[benchmark.Measurement] = [] + + for model in args.models: + for tp in args.tp_sizes: + for layer in WEIGHT_SHAPES_MOE[model]: + num_experts = layer[0] + topk = layer[1] + size_k = layer[2] + size_n = layer[3] // tp + + if len(args.limit_k) > 0 and size_k not in args.limit_k: + continue + + if len(args.limit_n) > 0 and size_n not in args.limit_n: + continue + + for per_act_token in PER_ACT_TOKEN_OPTS: + for per_out_ch in PER_OUT_CH_OPTS: + for size_m in args.batch_sizes: + mkn = (size_m, size_k, size_n) + bench_run( + results, + model, + num_experts, + topk, + per_act_token, + per_out_ch, + mkn, + ) + + compare = benchmark.Compare(results) + compare.print() + + +if __name__ == "__main__": + parser = FlexibleArgumentParser( + description="Benchmark NVFP4 CUTLASS MOE across specified models/shapes/batches" + ) + parser.add_argument( + "--models", + nargs="+", + type=str, + default=DEFAULT_MODELS, + choices=WEIGHT_SHAPES_MOE.keys(), + ) + parser.add_argument("--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES) + parser.add_argument( + "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES + ) + parser.add_argument("--limit-k", nargs="+", type=int, default=[]) + parser.add_argument("--limit-n", nargs="+", type=int, default=[]) + parser.add_argument("--limit-num-groups", nargs="+", type=int, default=[]) + parser.add_argument("--limit-per-act-token", nargs="+", type=int, default=[]) + parser.add_argument("--limit-per-out-ch", nargs="+", type=int, default=[]) + + args = parser.parse_args() + main(args) diff --git a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py index bcdbf6c7551a32d449165522b87f8e8da85dba73..2197bceabe6c034eb591fae5308e7d483016c317 100644 --- a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py +++ b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py @@ -6,14 +6,18 @@ from benchmark_shapes import WEIGHT_SHAPES_MOE from vllm import _custom_ops as ops from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config -from vllm.model_executor.layers.fused_moe.fused_moe import (cutlass_moe_fp8, - fused_experts, - fused_topk) +from vllm.model_executor.layers.fused_moe.fused_moe import ( + cutlass_moe_fp8, + fused_experts, + fused_topk, +) from vllm.utils import FlexibleArgumentParser DEFAULT_MODELS = [ - "nm-testing/Mixtral-8x7B-Instruct-v0.1", "nm-testing/deepseekv2-lite", - "ibm-granite/granite-3.0-1b-a400m", "ibm-granite/granite-3.0-3b-a800m" + "nm-testing/Mixtral-8x7B-Instruct-v0.1", + "nm-testing/deepseekv2-lite", + "ibm-granite/granite-3.0-1b-a400m", + "ibm-granite/granite-3.0-3b-a800m", ] DEFAULT_BATCH_SIZES = [1, 4, 8, 16, 32, 64, 128, 256, 512] DEFAULT_TP_SIZES = [1] @@ -24,19 +28,27 @@ PER_OUT_CH_OPTS = [False] def to_fp8(tensor: torch.Tensor): finfo = torch.finfo(torch.float8_e4m3fn) - return torch.round(tensor.clamp( - min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn) + return torch.round(tensor.clamp(min=finfo.min, max=finfo.max)).to( + dtype=torch.float8_e4m3fn + ) -def bench_run(results: list[benchmark.Measurement], model: str, - num_experts: int, topk: int, per_act_token: bool, - per_out_ch: bool, mkn: tuple[int, int, int]): +def bench_run( + results: list[benchmark.Measurement], + model: str, + num_experts: int, + topk: int, + per_act_token: bool, + per_out_ch: bool, + mkn: tuple[int, int, int], +): label = "Quant Matmul" sub_label = ( - "{}, num_experts={}, topk={}, per_act_token={} per_out_ch={}, " - "MKN=({})".format(model, num_experts, topk, per_act_token, per_out_ch, - mkn)) + "{}, num_experts={}, topk={}, per_act_token={} per_out_ch={}, MKN=({})".format( + model, num_experts, topk, per_act_token, per_out_ch, mkn + ) + ) print(f"Testing: {sub_label}") @@ -50,35 +62,17 @@ def bench_run(results: list[benchmark.Measurement], model: str, _, a_scale = ops.scaled_fp8_quant(a) - w1_q = torch.empty((num_experts, 2 * n, k), - device="cuda", - dtype=torch.float8_e4m3fn) - w2_q = torch.empty((num_experts, k, n), - device="cuda", - dtype=torch.float8_e4m3fn) - w1_scale = torch.empty((num_experts, 1, 1), - device="cuda", - dtype=torch.float32) - w2_scale = torch.empty((num_experts, 1, 1), - device="cuda", - dtype=torch.float32) - - ab_strides1 = torch.full((num_experts, ), - k, - device="cuda", - dtype=torch.int64) - c_strides1 = torch.full((num_experts, ), - 2 * n, - device="cuda", - dtype=torch.int64) - ab_strides2 = torch.full((num_experts, ), - n, - device="cuda", - dtype=torch.int64) - c_strides2 = torch.full((num_experts, ), - k, - device="cuda", - dtype=torch.int64) + w1_q = torch.empty( + (num_experts, 2 * n, k), device="cuda", dtype=torch.float8_e4m3fn + ) + w2_q = torch.empty((num_experts, k, n), device="cuda", dtype=torch.float8_e4m3fn) + w1_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32) + w2_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32) + + ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64) + c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64) + ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64) + c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64) for expert in range(num_experts): w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(w1[expert]) @@ -90,82 +84,121 @@ def bench_run(results: list[benchmark.Measurement], model: str, score = torch.randn((m, num_experts), device="cuda", dtype=dtype) - topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False) + topk_weights, topk_ids, token_expert_indices = fused_topk( + a, score, topk, renormalize=False + ) - def run_triton_moe(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - w1_scale: torch.Tensor, w2_scale: torch.Tensor, - a_scale: torch.Tensor, num_repeats: int): + def run_triton_moe( + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + a_scale: torch.Tensor, + num_repeats: int, + ): for _ in range(num_repeats): - fused_experts(a, - w1, - w2, - topk_weights, - topk_ids, - use_fp8_w8a8=True, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a_scale) - - def run_cutlass_moe(a: torch.Tensor, a_scale: torch.Tensor, - w1: torch.Tensor, w2: torch.Tensor, - w1_scale: torch.Tensor, w2_scale: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - ab_strides1: torch.Tensor, c_strides1: torch.Tensor, - ab_strides2: torch.Tensor, c_strides2: torch.Tensor, - num_repeats: int): + fused_experts( + a, + w1, + w2, + topk_weights, + topk_ids, + use_fp8_w8a8=True, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a_scale, + ) + + def run_cutlass_moe( + a: torch.Tensor, + a_scale: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + ab_strides1: torch.Tensor, + c_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides2: torch.Tensor, + num_repeats: int, + ): for _ in range(num_repeats): - cutlass_moe_fp8(a, - w1, - w2, - w1_scale, - w2_scale, - topk_weights, - topk_ids, - ab_strides1, - c_strides1, - ab_strides2, - c_strides2, - a1_scale=a_scale) + cutlass_moe_fp8( + a, + w1, + w2, + w1_scale, + w2_scale, + topk_weights, + topk_ids, + ab_strides1, + c_strides1, + ab_strides2, + c_strides2, + a1_scale=a_scale, + ) def run_cutlass_from_graph( - a: torch.Tensor, a_scale: torch.Tensor, w1_q: torch.Tensor, - w2_q: torch.Tensor, w1_scale: torch.Tensor, w2_scale: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - ab_strides1: torch.Tensor, c_strides1: torch.Tensor, - ab_strides2: torch.Tensor, c_strides2: torch.Tensor): + a: torch.Tensor, + a_scale: torch.Tensor, + w1_q: torch.Tensor, + w2_q: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + ab_strides1: torch.Tensor, + c_strides1: torch.Tensor, + ab_strides2: torch.Tensor, + c_strides2: torch.Tensor, + ): with set_current_vllm_config( - VllmConfig(parallel_config=ParallelConfig( - pipeline_parallel_size=1))): - return cutlass_moe_fp8(a, - w1_q, - w2_q, - w1_scale, - w2_scale, - topk_weights, - topk_ids, - ab_strides1, - c_strides1, - ab_strides2, - c_strides2, - a1_scale=a_scale) - - def run_triton_from_graph(a: torch.Tensor, w1: torch.Tensor, - w2: torch.Tensor, topk_weights: torch.Tensor, - topk_ids: torch.Tensor, w1_scale: torch.Tensor, - w2_scale: torch.Tensor, a_scale: torch.Tensor): + VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) + ): + return cutlass_moe_fp8( + a, + w1_q, + w2_q, + w1_scale, + w2_scale, + topk_weights, + topk_ids, + ab_strides1, + c_strides1, + ab_strides2, + c_strides2, + a1_scale=a_scale, + ) + + def run_triton_from_graph( + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + a_scale: torch.Tensor, + ): with set_current_vllm_config( - VllmConfig(parallel_config=ParallelConfig( - pipeline_parallel_size=1))): - return fused_experts(a, - w1, - w2, - topk_weights, - topk_ids, - use_fp8_w8a8=True, - w1_scale=w1_scale, - w2_scale=w2_scale, - a1_scale=a_scale) + VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) + ): + return fused_experts( + a, + w1, + w2, + topk_weights, + topk_ids, + use_fp8_w8a8=True, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a_scale, + ) def replay_graph(graph, num_repeats): for _ in range(num_repeats): @@ -175,16 +208,35 @@ def bench_run(results: list[benchmark.Measurement], model: str, cutlass_stream = torch.cuda.Stream() cutlass_graph = torch.cuda.CUDAGraph() with torch.cuda.graph(cutlass_graph, stream=cutlass_stream): - run_cutlass_from_graph(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, - topk_weights, topk_ids, ab_strides1, c_strides1, - ab_strides2, c_strides2) + run_cutlass_from_graph( + a, + a_scale, + w1_q, + w2_q, + w1_scale, + w2_scale, + topk_weights, + topk_ids, + ab_strides1, + c_strides1, + ab_strides2, + c_strides2, + ) torch.cuda.synchronize() triton_stream = torch.cuda.Stream() triton_graph = torch.cuda.CUDAGraph() with torch.cuda.graph(triton_graph, stream=triton_stream): - run_triton_from_graph(a, w1_q_notransp, w2_q_notransp, topk_weights, - topk_ids, w1_scale, w2_scale, a_scale) + run_triton_from_graph( + a, + w1_q_notransp, + w2_q_notransp, + topk_weights, + topk_ids, + w1_scale, + w2_scale, + a_scale, + ) torch.cuda.synchronize() min_run_time = 5 @@ -224,18 +276,27 @@ def bench_run(results: list[benchmark.Measurement], model: str, } # Warmup - run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, - w1_scale, w2_scale, a_scale, num_warmup) + run_triton_moe( + a, + w1_q_notransp, + w2_q_notransp, + topk_weights, + topk_ids, + w1_scale, + w2_scale, + a_scale, + num_warmup, + ) results.append( benchmark.Timer( - stmt= - "run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)", # noqa: E501 + stmt="run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, description="triton_moe", - ).blocked_autorange(min_run_time=min_run_time)) + ).blocked_autorange(min_run_time=min_run_time) + ) # Warmup replay_graph(triton_graph, num_warmup) @@ -247,22 +308,35 @@ def bench_run(results: list[benchmark.Measurement], model: str, label=label, sub_label=sub_label, description="triton_moe_cuda_graphs", - ).blocked_autorange(min_run_time=min_run_time)) + ).blocked_autorange(min_run_time=min_run_time) + ) # Warmup - run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, - topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2, - num_warmup) + run_cutlass_moe( + a, + a_scale, + w1_q, + w2_q, + w1_scale, + w2_scale, + topk_weights, + topk_ids, + ab_strides1, + c_strides1, + ab_strides2, + c_strides2, + num_warmup, + ) results.append( benchmark.Timer( - stmt= - "run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2, num_runs)", # noqa: E501 + stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2, num_runs)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, description="grouped_gemm_moe", - ).blocked_autorange(min_run_time=min_run_time)) + ).blocked_autorange(min_run_time=min_run_time) + ) # Warmup replay_graph(cutlass_graph, num_warmup) @@ -274,7 +348,8 @@ def bench_run(results: list[benchmark.Measurement], model: str, label=label, sub_label=sub_label, description="grouped_gemm_moe_cuda_graphs", - ).blocked_autorange(min_run_time=min_run_time)) + ).blocked_autorange(min_run_time=min_run_time) + ) def main(args): @@ -302,8 +377,15 @@ def main(args): for per_out_ch in PER_OUT_CH_OPTS: for size_m in DEFAULT_BATCH_SIZES: mkn = (size_m, size_k, size_n) - bench_run(results, model, num_experts, topk, - per_act_token, per_out_ch, mkn) + bench_run( + results, + model, + num_experts, + topk, + per_act_token, + per_out_ch, + mkn, + ) compare = benchmark.Compare(results) compare.print() @@ -311,7 +393,8 @@ def main(args): if __name__ == "__main__": parser = FlexibleArgumentParser( - description="Benchmark Marlin across specified models/shapes/batches") + description="Benchmark Marlin across specified models/shapes/batches" + ) parser.add_argument( "--models", nargs="+", @@ -319,21 +402,14 @@ if __name__ == "__main__": default=DEFAULT_MODELS, choices=WEIGHT_SHAPES_MOE.keys(), ) - parser.add_argument("--tp-sizes", - nargs="+", - type=int, - default=DEFAULT_TP_SIZES) - parser.add_argument("--batch-sizes", - nargs="+", - type=int, - default=DEFAULT_BATCH_SIZES) + parser.add_argument("--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES) + parser.add_argument( + "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES + ) parser.add_argument("--limit-k", nargs="+", type=int, default=[]) parser.add_argument("--limit-n", nargs="+", type=int, default=[]) parser.add_argument("--limit-num-groups", nargs="+", type=int, default=[]) - parser.add_argument("--limit-per-act-token", - nargs="+", - type=int, - default=[]) + parser.add_argument("--limit-per-act-token", nargs="+", type=int, default=[]) parser.add_argument("--limit-per-out-ch", nargs="+", type=int, default=[]) args = parser.parse_args() diff --git a/benchmarks/kernels/benchmark_layernorm.py b/benchmarks/kernels/benchmark_layernorm.py index e12d74c01e43c4a7ac52cfdebd76b2f186696124..f21ca97eeb8a9b2abbf04f979f90ece9003f1cef 100644 --- a/benchmarks/kernels/benchmark_layernorm.py +++ b/benchmarks/kernels/benchmark_layernorm.py @@ -10,14 +10,16 @@ from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser @torch.inference_mode() -def main(num_tokens: int, - hidden_size: int, - add_residual: bool, - dtype: torch.dtype, - seed: int = 0, - do_profile: bool = False, - num_warmup_iters: int = 5, - num_iters: int = 100) -> None: +def main( + num_tokens: int, + hidden_size: int, + add_residual: bool, + dtype: torch.dtype, + seed: int = 0, + do_profile: bool = False, + num_warmup_iters: int = 5, + num_iters: int = 100, +) -> None: current_platform.seed_everything(seed) torch.set_default_device("cuda") @@ -56,33 +58,35 @@ def main(num_tokens: int, print(f"Kernel running time: {latency * 1000000:.3f} us") -if __name__ == '__main__': - parser = FlexibleArgumentParser( - description="Benchmark the layernorm kernel.") +if __name__ == "__main__": + parser = FlexibleArgumentParser(description="Benchmark the layernorm kernel.") parser.add_argument("--num-tokens", type=int, default=4096) parser.add_argument("--hidden-size", type=int, default=8192) parser.add_argument("--add-residual", action="store_true") - parser.add_argument("--dtype", - type=str, - choices=["half", "bfloat16", "float"], - default="half") + parser.add_argument( + "--dtype", type=str, choices=["half", "bfloat16", "float"], default="half" + ) parser.add_argument("--seed", type=int, default=0) parser.add_argument("--profile", action="store_true") parser.add_argument("--num-warmup-iters", type=int, default=5) - parser.add_argument("--num-iters", - type=int, - default=100, - help="Number of benchmark iterations. " - "If --profile is set, this number is ignored") + parser.add_argument( + "--num-iters", + type=int, + default=100, + help="Number of benchmark iterations. " + "If --profile is set, this number is ignored", + ) args = parser.parse_args() print(args) - main(num_tokens=args.num_tokens, - hidden_size=args.hidden_size, - add_residual=args.add_residual, - dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype], - seed=args.seed, - do_profile=args.profile, - num_warmup_iters=args.num_warmup_iters, - num_iters=args.num_iters) + main( + num_tokens=args.num_tokens, + hidden_size=args.hidden_size, + add_residual=args.add_residual, + dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype], + seed=args.seed, + do_profile=args.profile, + num_warmup_iters=args.num_warmup_iters, + num_iters=args.num_iters, + ) diff --git a/benchmarks/kernels/benchmark_lora.py b/benchmarks/kernels/benchmark_lora.py index d382ede10b41be1b2168214d88789100dc04e1d4..6c1284930c1ec3f6963dd9ec325e72179d14da5f 100644 --- a/benchmarks/kernels/benchmark_lora.py +++ b/benchmarks/kernels/benchmark_lora.py @@ -20,18 +20,36 @@ from weight_shapes import WEIGHT_SHAPES from vllm.triton_utils import HAS_TRITON if HAS_TRITON: - from vllm.lora.ops.triton_ops import (LoRAKernelMeta, lora_expand, - lora_shrink) - from vllm.lora.ops.triton_ops.utils import (_LORA_A_PTR_DICT, - _LORA_B_PTR_DICT) + from vllm.lora.ops.triton_ops import LoRAKernelMeta, lora_expand, lora_shrink + from vllm.lora.ops.triton_ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT from vllm.utils import FlexibleArgumentParser DEFAULT_MODELS = list(WEIGHT_SHAPES.keys()) DEFAULT_TP_SIZES = [1] DEFAULT_BATCH_SIZES = [ - 1, 16, 32, 64, 128, 192, 256, 320, 384, 448, 512, 640, 768, 896, 1024, - 2048, 3072, 4096, 5120, 6144, 7168, 8192 + 1, + 16, + 32, + 64, + 128, + 192, + 256, + 320, + 384, + 448, + 512, + 640, + 768, + 896, + 1024, + 2048, + 3072, + 4096, + 5120, + 6144, + 7168, + 8192, ] DEFAULT_HIDDEN_SIZES = [1024, 2048, 4096, 8192, 16384] DEFAULT_LORA_RANKS = [16] @@ -52,12 +70,9 @@ def dtype_to_str(dtype: torch.dtype): raise ValueError(f"Unsupported dtype {dtype}") -def make_rand_lora_weight_tensor(k: int, - n: int, - num_loras: int, - dtype: torch.dtype, - device: str = "cuda") -> torch.Tensor: - +def make_rand_lora_weight_tensor( + k: int, n: int, num_loras: int, dtype: torch.dtype, device: str = "cuda" +) -> torch.Tensor: # LoRA weights column major return torch.rand((num_loras, n, k), dtype=dtype).to(device) @@ -78,18 +93,15 @@ def make_rand_tensors( A = torch.rand(a_shape, dtype=a_dtype).to(device) # LoRA weights column major - Bs = [ - torch.rand(b_shape, dtype=b_dtype).to(device) - for _ in range(num_slices) - ] + Bs = [torch.rand(b_shape, dtype=b_dtype).to(device) for _ in range(num_slices)] C = torch.zeros(c_shape, dtype=c_dtype).to(device) return A, Bs, C -def make_prompt_lora_mapping(num_prompts: int, num_active_loras: int, - sort_by_lora_id: bool, - device: str) -> torch.Tensor: +def make_prompt_lora_mapping( + num_prompts: int, num_active_loras: int, sort_by_lora_id: bool, device: str +) -> torch.Tensor: """ All prompts are mapped to a LoRA ID in range [0, num_active_loras). where 0 refers to first lora, 1 refers to second lora and so on. @@ -97,9 +109,7 @@ def make_prompt_lora_mapping(num_prompts: int, num_active_loras: int, assert num_active_loras > 0 if not sort_by_lora_id: - return torch.randint(0, - num_active_loras, (num_prompts, ), - dtype=torch.long) + return torch.randint(0, num_active_loras, (num_prompts,), dtype=torch.long) # Divide LoRAs equally and in order. part_size = num_prompts // num_active_loras @@ -110,14 +120,18 @@ def make_prompt_lora_mapping(num_prompts: int, num_active_loras: int, while len(prompt_lora_mapping) < num_prompts: prompt_lora_mapping.extend([lora_id] * part_size) lora_id = lora_id + 1 if lora_id + 1 < num_active_loras else lora_id - return torch.tensor(prompt_lora_mapping[:num_prompts], - dtype=torch.long, - device=device) - - -def make_token_lora_mapping(num_tokens: int, num_prompts: int, - prompt_lora_mapping: torch.Tensor, - seq_len_tensor: torch.Tensor, device: str): + return torch.tensor( + prompt_lora_mapping[:num_prompts], dtype=torch.long, device=device + ) + + +def make_token_lora_mapping( + num_tokens: int, + num_prompts: int, + prompt_lora_mapping: torch.Tensor, + seq_len_tensor: torch.Tensor, + device: str, +): """ Make token_lora_mapping from prompt_lora_mapping and seq_lens_tensor """ @@ -136,11 +150,15 @@ def make_token_lora_mapping(num_tokens: int, num_prompts: int, return torch.tensor(token_lora_mapping, dtype=torch.long, device=device) -def ref_group_gemm(ref_out: torch.Tensor, input: torch.Tensor, - lora_weights: list[torch.Tensor], - seq_lens_cpu: torch.Tensor, - prompt_lora_mapping_cpu: torch.Tensor, scaling: float, - add_inputs: Optional[bool]): +def ref_group_gemm( + ref_out: torch.Tensor, + input: torch.Tensor, + lora_weights: list[torch.Tensor], + seq_lens_cpu: torch.Tensor, + prompt_lora_mapping_cpu: torch.Tensor, + scaling: float, + add_inputs: Optional[bool], +): """ Torch group gemm reference implementation to test correctness of benchmarking operations. @@ -149,7 +167,7 @@ def ref_group_gemm(ref_out: torch.Tensor, input: torch.Tensor, out_list = [] current_offset = 0 for lora_index, b_length in zip(range(batches), seq_lens_cpu): - x = input[current_offset:b_length + current_offset, :] + x = input[current_offset : b_length + current_offset, :] current_offset += b_length w = lora_weights[prompt_lora_mapping_cpu[lora_index]] result = torch.nn.functional.linear(x, w) @@ -168,6 +186,7 @@ class OpType(Enum): """ LoRA Ops to benchmark and its properties. """ + LORA_SHRINK = auto() LORA_EXPAND = auto() @@ -188,8 +207,9 @@ class OpType(Enum): def num_slices(self) -> list[int]: return [1, 2, 3] - def mkn(self, batch_size: int, seq_length: int, hidden_size: int, - lora_rank: int) -> tuple[int, int, int]: + def mkn( + self, batch_size: int, seq_length: int, hidden_size: int, lora_rank: int + ) -> tuple[int, int, int]: num_tokens = batch_size * seq_length if self.is_shrink_fn(): m = num_tokens @@ -203,7 +223,7 @@ class OpType(Enum): return m, k, n def matmul_dtypes( - self, op_dtype: torch.dtype + self, op_dtype: torch.dtype ) -> tuple[torch.dtype, torch.dtype, torch.dtype]: """ return a type, b type and c type for A x B = C @@ -215,9 +235,14 @@ class OpType(Enum): return torch.float32, op_dtype, op_dtype def matmul_shapes( - self, batch_size: int, seq_length: int, hidden_size: int, - lora_rank: int, num_loras: int, - num_slices: int) -> tuple[tuple[int], tuple[int], tuple[int]]: + self, + batch_size: int, + seq_length: int, + hidden_size: int, + lora_rank: int, + num_loras: int, + num_slices: int, + ) -> tuple[tuple[int], tuple[int], tuple[int]]: """ Given num_slices, return the shapes of the A, B, and C matrices in A x B = C, for the op_type @@ -241,31 +266,38 @@ class OpType(Enum): raise ValueError(f"Unrecognized optype {self}") - def run_ref_group_gemm(self, output: torch.Tensor, input: torch.Tensor, - lora_weights: list[torch.Tensor], - **kwargs) -> Callable: + def run_ref_group_gemm( + self, + output: torch.Tensor, + input: torch.Tensor, + lora_weights: list[torch.Tensor], + **kwargs, + ) -> Callable: """Each benchmark operation expects the input, lora_weights and outputs - in a slightly different format. Refer to self.matmul_shapes(). - run_ref_group_gemm accounts for those differences in executing a - reference group gemm for correctness testing. + in a slightly different format. Refer to self.matmul_shapes(). + run_ref_group_gemm accounts for those differences in executing a + reference group gemm for correctness testing. """ w_dtype = lora_weights[0].dtype num_slices = len(lora_weights) if self in [OpType.LORA_SHRINK]: for slice_idx in range(num_slices): - ref_group_gemm(ref_out=output[slice_idx, :], - input=input, - lora_weights=lora_weights[slice_idx], - **kwargs) + ref_group_gemm( + ref_out=output[slice_idx, :], + input=input, + lora_weights=lora_weights[slice_idx], + **kwargs, + ) elif self in [OpType.LORA_EXPAND]: hidden_size = lora_weights[0].shape[1] for slice_idx in range(num_slices): slice_offset = slice_idx * hidden_size ref_group_gemm( - ref_out=output[:, slice_offset:slice_offset + hidden_size], + ref_out=output[:, slice_offset : slice_offset + hidden_size], input=input[slice_idx].clone().to(dtype=w_dtype), lora_weights=lora_weights[slice_idx], - **kwargs) + **kwargs, + ) else: raise ValueError(f"Unrecognized optype {self}") @@ -275,6 +307,7 @@ class BenchmarkContext: """ LoRA benchmark context """ + batch_size: int hidden_size: int num_loras: int @@ -299,17 +332,18 @@ class BenchmarkContext: return f"lora-{self.dtype}" def bench_sublabel(self, op_type: OpType) -> str: - m, k, n = op_type.mkn(self.batch_size, self.seq_length, - self.hidden_size, self.lora_rank) + m, k, n = op_type.mkn( + self.batch_size, self.seq_length, self.hidden_size, self.lora_rank + ) desc = { - 'bs': self.batch_size, - 'sl': self.seq_length, - 'm': m, - 'k': k, - 'n': n, - 'num_loras': self.num_loras, - 'sort_by_lora': self.sort_by_lora_id, - 'num_slices': self.num_slices, + "bs": self.batch_size, + "sl": self.seq_length, + "m": m, + "k": k, + "n": n, + "num_loras": self.num_loras, + "sort_by_lora": self.sort_by_lora_id, + "num_slices": self.num_slices, } return json.dumps(desc) @@ -319,6 +353,7 @@ class BenchmarkTensors: """ Input/Output tensors used for benchmarks """ + # matmul tensors input: torch.Tensor lora_weights_lst: list[torch.Tensor] @@ -330,23 +365,29 @@ class BenchmarkTensors: prompt_lora_mapping: torch.Tensor def io_types(self) -> str: - return (f"{dtype_to_str(self.input.dtype)}x" - f"{dtype_to_str(self.lora_weights_lst[0].dtype)}=>" - f"{dtype_to_str(self.output.dtype)}") + return ( + f"{dtype_to_str(self.input.dtype)}x" + f"{dtype_to_str(self.lora_weights_lst[0].dtype)}=>" + f"{dtype_to_str(self.output.dtype)}" + ) @staticmethod - def make(ctx: BenchmarkContext, - op_type: OpType, - device: str = "cuda") -> "BenchmarkTensors": - + def make( + ctx: BenchmarkContext, op_type: OpType, device: str = "cuda" + ) -> "BenchmarkTensors": # Make input / output matmul tensors. a_shape, b_shape, c_shape = op_type.matmul_shapes( - ctx.batch_size, ctx.seq_length, ctx.hidden_size, ctx.lora_rank, - ctx.num_loras, ctx.num_slices) + ctx.batch_size, + ctx.seq_length, + ctx.hidden_size, + ctx.lora_rank, + ctx.num_loras, + ctx.num_slices, + ) a_type, b_type, c_type = op_type.matmul_dtypes(ctx.dtype) - input_tensor, lora_weights, output_tensor = \ - make_rand_tensors(a_shape, b_shape, c_shape, a_type, b_type, c_type, - num_slices = ctx.num_slices) + input_tensor, lora_weights, output_tensor = make_rand_tensors( + a_shape, b_shape, c_shape, a_type, b_type, c_type, num_slices=ctx.num_slices + ) # Make metadata tensors. # Keep the metadata tensors in the CPU for further processing if needed. @@ -356,27 +397,38 @@ class BenchmarkTensors: # Make metadata tensors involved in correctness testing. # Prepare seq lens tensor - seq_len_tensor = torch.randint(ctx.seq_length, ctx.seq_length + 1, - (ctx.batch_size, )) + seq_len_tensor = torch.randint( + ctx.seq_length, ctx.seq_length + 1, (ctx.batch_size,) + ) assert total_tokens == seq_len_tensor.sum() # Prepare prompt lora indices tensor prompt_lora_indices_tensor = make_prompt_lora_mapping( - ctx.batch_size, ctx.num_active_loras, ctx.sort_by_lora_id, "cpu") + ctx.batch_size, ctx.num_active_loras, ctx.sort_by_lora_id, "cpu" + ) # Make LoRAKernelMeta token_lora_indices_tensor = make_token_lora_mapping( - total_tokens, ctx.batch_size, prompt_lora_indices_tensor, - seq_len_tensor, "cpu") + total_tokens, + ctx.batch_size, + prompt_lora_indices_tensor, + seq_len_tensor, + "cpu", + ) lora_kernel_meta = LoRAKernelMeta.make( max_loras=ctx.num_loras, max_num_tokens=token_lora_indices_tensor.size(0), - device="cpu") - lora_kernel_meta.prepare_tensors( - token_lora_mapping=token_lora_indices_tensor) - - return BenchmarkTensors(input_tensor, lora_weights, output_tensor, - lora_kernel_meta, seq_len_tensor, - prompt_lora_indices_tensor) + device="cpu", + ) + lora_kernel_meta.prepare_tensors(token_lora_mapping=token_lora_indices_tensor) + + return BenchmarkTensors( + input_tensor, + lora_weights, + output_tensor, + lora_kernel_meta, + seq_len_tensor, + prompt_lora_indices_tensor, + ) def sanity_check(self) -> None: """ @@ -386,7 +438,7 @@ class BenchmarkTensors: # check metadata tensors assert torch.sum(self.seq_lens) == num_tokens num_seqs = self.seq_lens.shape[0] - #assert self.seq_start_loc.shape[0] == num_seqs + # assert self.seq_start_loc.shape[0] == num_seqs assert self.prompt_lora_mapping.shape[0] == num_seqs assert self.lora_kernel_meta.token_lora_mapping.shape[0] == num_tokens @@ -430,8 +482,11 @@ class BenchmarkTensors: _, num_tokens, _, num_slices = self.metadata() # Sanity check matrix shapes. - i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[ - 0].shape, self.output.shape + i_shape, lw_shape, o_shape = ( + self.input.shape, + self.lora_weights_lst[0].shape, + self.output.shape, + ) # Expected input shape [num_tokens, hidden_size] assert len(i_shape) == 2 assert i_shape[0] == num_tokens @@ -445,16 +500,17 @@ class BenchmarkTensors: assert o_shape == (num_slices, num_tokens, lora_rank) return { - 'inputs': self.input, - 'lora_a_weights': self.lora_weights_lst, - 'output_tensor': self.output, - 'token_lora_mapping': self.lora_kernel_meta.token_lora_mapping, - 'token_indices_sorted_by_lora_ids': - self.lora_kernel_meta.token_indices_sorted_by_lora_ids, - 'num_tokens_per_lora': self.lora_kernel_meta.num_tokens_per_lora, - 'lora_token_start_loc': self.lora_kernel_meta.lora_token_start_loc, - 'lora_ids': self.lora_kernel_meta.active_lora_ids, - 'scaling': 1.0, + "inputs": self.input, + "lora_a_weights": self.lora_weights_lst, + "output_tensor": self.output, + "token_lora_mapping": self.lora_kernel_meta.token_lora_mapping, + "token_indices_sorted_by_lora_ids": ( + self.lora_kernel_meta.token_indices_sorted_by_lora_ids + ), + "num_tokens_per_lora": self.lora_kernel_meta.num_tokens_per_lora, + "lora_token_start_loc": self.lora_kernel_meta.lora_token_start_loc, + "lora_ids": self.lora_kernel_meta.active_lora_ids, + "scaling": 1.0, } def as_lora_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]: @@ -464,8 +520,11 @@ class BenchmarkTensors: _, num_tokens, _, num_slices = self.metadata() # Sanity check matrix shapes. - i_shape, lw_shape, o_shape = self.input.shape, self.lora_weights_lst[ - 0].shape, self.output.shape + i_shape, lw_shape, o_shape = ( + self.input.shape, + self.lora_weights_lst[0].shape, + self.output.shape, + ) # Expected input shape : [num_slices, num_tokens, lora_rank] assert len(i_shape) == 3 assert i_shape[0] == num_slices @@ -480,22 +539,23 @@ class BenchmarkTensors: assert o_shape == (num_tokens, hidden_size * num_slices) return { - 'inputs': self.input, - 'lora_b_weights': self.lora_weights_lst, - 'output_tensor': self.output, - 'token_lora_mapping': self.lora_kernel_meta.token_lora_mapping, - 'token_indices_sorted_by_lora_ids': - self.lora_kernel_meta.token_indices_sorted_by_lora_ids, - 'num_tokens_per_lora': self.lora_kernel_meta.num_tokens_per_lora, - 'lora_token_start_loc': self.lora_kernel_meta.lora_token_start_loc, - 'lora_ids': self.lora_kernel_meta.active_lora_ids, - 'offset_start': 0, - 'add_inputs': add_inputs, + "inputs": self.input, + "lora_b_weights": self.lora_weights_lst, + "output_tensor": self.output, + "token_lora_mapping": self.lora_kernel_meta.token_lora_mapping, + "token_indices_sorted_by_lora_ids": ( + self.lora_kernel_meta.token_indices_sorted_by_lora_ids + ), + "num_tokens_per_lora": self.lora_kernel_meta.num_tokens_per_lora, + "lora_token_start_loc": self.lora_kernel_meta.lora_token_start_loc, + "lora_ids": self.lora_kernel_meta.active_lora_ids, + "offset_start": 0, + "add_inputs": add_inputs, } - def bench_fn_kwargs(self, - op_type: OpType, - add_inputs: Optional[bool] = None) -> dict[str, Any]: + def bench_fn_kwargs( + self, op_type: OpType, add_inputs: Optional[bool] = None + ) -> dict[str, Any]: if op_type.is_shrink_fn(): assert add_inputs is None else: @@ -507,8 +567,9 @@ class BenchmarkTensors: return self.as_lora_expand_kwargs(add_inputs) raise ValueError(f"Unrecognized optype {self}") - def test_correctness(self, op_type: OpType, - expand_fn_add_inputs: Optional[bool]) -> bool: + def test_correctness( + self, op_type: OpType, expand_fn_add_inputs: Optional[bool] + ) -> bool: """ Test correctness of op_type implementation against a grouped gemm reference implementation. @@ -518,8 +579,7 @@ class BenchmarkTensors: ref_output = self.output.clone() self.output.zero_() - op_type.bench_fn()( - **self.bench_fn_kwargs(op_type, expand_fn_add_inputs)) + op_type.bench_fn()(**self.bench_fn_kwargs(op_type, expand_fn_add_inputs)) op_type.run_ref_group_gemm( ref_output, @@ -528,7 +588,8 @@ class BenchmarkTensors: seq_lens_cpu=seq_lens_cpu, prompt_lora_mapping_cpu=prompt_lora_mapping_cpu, scaling=1.0, - add_inputs=expand_fn_add_inputs) + add_inputs=expand_fn_add_inputs, + ) rtol, atol = { torch.float16: (6e-2, 6e-2), @@ -539,13 +600,14 @@ class BenchmarkTensors: return torch.allclose(ref_output, self.output, rtol=rtol, atol=atol) -def bench_optype(ctx: BenchmarkContext, - arg_pool_size: int, - op_type: OpType, - cuda_graph_nops: Optional[int] = None, - expand_fn_add_inputs: Optional[bool] = None, - test_correctness: bool = False) -> TMeasurement: - +def bench_optype( + ctx: BenchmarkContext, + arg_pool_size: int, + op_type: OpType, + cuda_graph_nops: Optional[int] = None, + expand_fn_add_inputs: Optional[bool] = None, + test_correctness: bool = False, +) -> TMeasurement: assert arg_pool_size >= 1 if op_type.is_shrink_fn(): assert expand_fn_add_inputs is None @@ -553,17 +615,17 @@ def bench_optype(ctx: BenchmarkContext, assert expand_fn_add_inputs is not None # BenchmarkContext -> BenchmarkTensors - bench_tensors : list[BenchmarkTensors] = \ - [BenchmarkTensors.make(ctx, op_type) for _ in range(arg_pool_size)] + bench_tensors: list[BenchmarkTensors] = [ + BenchmarkTensors.make(ctx, op_type) for _ in range(arg_pool_size) + ] for bt in bench_tensors: bt.sanity_check() # Test correctness of our implementation. if test_correctness: - assert all([ - bt.test_correctness(op_type, expand_fn_add_inputs) - for bt in bench_tensors - ]) + assert all( + [bt.test_correctness(op_type, expand_fn_add_inputs) for bt in bench_tensors] + ) # BenchmarkTensors -> dict (kwargs) kwargs_list = [ @@ -585,40 +647,49 @@ def bench_optype(ctx: BenchmarkContext, for k, v in _kwargs.items(): kwargs[k].values.append(v) - describe_args = (f"add_inputs={expand_fn_add_inputs}" - if expand_fn_add_inputs is not None else "") - description = ( - f"{op_type.name}({describe_args}) ({bench_tensors[0].io_types()})") + describe_args = ( + f"add_inputs={expand_fn_add_inputs}" if expand_fn_add_inputs is not None else "" + ) + description = f"{op_type.name}({describe_args}) ({bench_tensors[0].io_types()})" cuda_graph_params = None if cuda_graph_nops: cuda_graph_params = CudaGraphBenchParams(cuda_graph_nops) timer = None - with Bench(cuda_graph_params, - ctx.bench_label(), ctx.bench_sublabel(op_type), description, - op_type.bench_fn(), **kwargs) as bench: + with Bench( + cuda_graph_params, + ctx.bench_label(), + ctx.bench_sublabel(op_type), + description, + op_type.bench_fn(), + **kwargs, + ) as bench: timer = bench.run() return timer -def bench_torch_mm(ctx: BenchmarkContext, - arg_pool_size: int, - op_type: OpType, - cuda_graph_nops: Optional[int] = None) -> TMeasurement: +def bench_torch_mm( + ctx: BenchmarkContext, + arg_pool_size: int, + op_type: OpType, + cuda_graph_nops: Optional[int] = None, +) -> TMeasurement: """ Benchmark basic torch.mm as a roofline. When all the input tokens have the same LoRA ID, the LoRA kernels are just - a matmul. This torch.mm benchmark serves as a roofline for that case. + a matmul. This torch.mm benchmark serves as a roofline for that case. input op_type is used in determining the m, k, n dimensions for the matmul. """ - batch_size, hidden_size, lora_rank, seq_length, dtype = (ctx.batch_size, - ctx.hidden_size, - ctx.lora_rank, - ctx.seq_length, - ctx.dtype) + batch_size, hidden_size, lora_rank, seq_length, dtype = ( + ctx.batch_size, + ctx.hidden_size, + ctx.lora_rank, + ctx.seq_length, + ctx.dtype, + ) m, k, n = op_type.mkn(batch_size, seq_length, hidden_size, lora_rank) # For a fairer comparison. @@ -632,18 +703,24 @@ def bench_torch_mm(ctx: BenchmarkContext, Cs.append(torch.rand((m, n), dtype=dtype).to("cuda")) # Make torch.mm kwargs - mm_kwargs = {'input': ArgPool(As), 'mat2': ArgPool(Bs), 'out': ArgPool(Cs)} + mm_kwargs = {"input": ArgPool(As), "mat2": ArgPool(Bs), "out": ArgPool(Cs)} description = ( f"single-lora roofline using torch.mm ({dtype_to_str(dtype)}" f"x{dtype_to_str(dtype)}" - f"=>{dtype_to_str(dtype)})") + f"=>{dtype_to_str(dtype)})" + ) cuda_graph_params = None if cuda_graph_nops: cuda_graph_params = CudaGraphBenchParams(cuda_graph_nops) - with Bench(cuda_graph_params, ctx.bench_label(), - ctx.bench_sublabel(op_type), description, torch.mm, - **mm_kwargs) as bench: + with Bench( + cuda_graph_params, + ctx.bench_label(), + ctx.bench_sublabel(op_type), + description, + torch.mm, + **mm_kwargs, + ) as bench: return bench.run() @@ -660,8 +737,7 @@ def use_cuda_graph_recommendation() -> str: """ -def print_timers(timers: list[TMeasurement], - args: Optional[argparse.Namespace] = None): +def print_timers(timers: list[TMeasurement], args: Optional[argparse.Namespace] = None): compare = TBenchmark.Compare(timers) compare.print() @@ -670,22 +746,23 @@ def print_timers(timers: list[TMeasurement], f"Note : The timings reported above is for {args.cuda_graph_nops} " "consecutive invocations of the benchmarking functions. " f"Please divide by {args.cuda_graph_nops} for single invocation " - "timings.") + "timings." + ) - print("Note on Comparison with torch.mm : The torch.mm numbers are " - "benchmark numbers of a simple matmul emulating the single lora " - "case. It is provided as a roofline for comparing our LoRA Kernel " - "implementations. It is expected that the LoRA kernels will be " - "slower than torch.mm in cases where num_loras is big. But for " - "small num_loras the goal should be to match the torch.mm numbers.") + print( + "Note on Comparison with torch.mm : The torch.mm numbers are " + "benchmark numbers of a simple matmul emulating the single lora " + "case. It is provided as a roofline for comparing our LoRA Kernel " + "implementations. It is expected that the LoRA kernels will be " + "slower than torch.mm in cases where num_loras is big. But for " + "small num_loras the goal should be to match the torch.mm numbers." + ) def run(args: argparse.Namespace, bench_ctxs: list[BenchmarkContext]): - if args.cuda_graph_nops is not None: assert args.cuda_graph_nops > 0 - print(f"Benchmarking {args.cuda_graph_nops} invocations inside a CUDA " - "Graph") + print(f"Benchmarking {args.cuda_graph_nops} invocations inside a CUDA Graph") else: print(f"CUDA Graphs not enabled.\n{use_cuda_graph_recommendation()}") @@ -697,21 +774,30 @@ def run(args: argparse.Namespace, bench_ctxs: list[BenchmarkContext]): for bench_op in bench_ops: for num_slices in bench_op.num_slices(): _ctx = bench_ctx.with_seq_length(seq_len).with_num_slices( - num_slices) + num_slices + ) # Benchmark torch.mm as a roofline seq_len_timers.append( - bench_torch_mm(_ctx, args.arg_pool_size, bench_op, - args.cuda_graph_nops)) + bench_torch_mm( + _ctx, args.arg_pool_size, bench_op, args.cuda_graph_nops + ) + ) # Benchmark bench_op - expand_fn_add_inputs = [ - None - ] if bench_op.is_shrink_fn() else args.expand_fn_add_inputs + expand_fn_add_inputs = ( + [None] if bench_op.is_shrink_fn() else args.expand_fn_add_inputs + ) for add_input_arg in expand_fn_add_inputs: seq_len_timers.append( - bench_optype(_ctx, args.arg_pool_size, bench_op, - args.cuda_graph_nops, add_input_arg, - args.test_correctness)) + bench_optype( + _ctx, + args.arg_pool_size, + bench_op, + args.cuda_graph_nops, + add_input_arg, + args.test_correctness, + ) + ) print_timers(seq_len_timers) timers.extend(seq_len_timers) @@ -733,13 +819,17 @@ def run(args: argparse.Namespace, bench_ctxs: list[BenchmarkContext]): pickle.dump(timers, f) -def as_benchmark_contexts(hidden_sizes: list[int], lora_ranks: list[int], - args: argparse.Namespace) -> list[BenchmarkContext]: - +def as_benchmark_contexts( + hidden_sizes: list[int], lora_ranks: list[int], args: argparse.Namespace +) -> list[BenchmarkContext]: ctxs: list[BenchmarkContext] = [] for batch_size, hidden_size, lora_rank, num_loras, sort_by_lora_id in product( # noqa - args.batch_sizes, list(hidden_sizes), lora_ranks, args.num_loras, - args.sort_by_lora_id): + args.batch_sizes, + list(hidden_sizes), + lora_ranks, + args.num_loras, + args.sort_by_lora_id, + ): ctxs.append( BenchmarkContext( batch_size=batch_size, @@ -747,13 +837,16 @@ def as_benchmark_contexts(hidden_sizes: list[int], lora_ranks: list[int], lora_rank=lora_rank, num_loras=num_loras, num_active_loras=args.num_active_loras - if args.num_active_loras else num_loras, + if args.num_active_loras + else num_loras, # To be filled based on the OpType to benchmark seq_length=None, sort_by_lora_id=sort_by_lora_id, dtype=args.dtype, # To be filled based on the OpType to benchmark - num_slices=None)) + num_slices=None, + ) + ) return ctxs @@ -761,13 +854,16 @@ def as_benchmark_contexts(hidden_sizes: list[int], lora_ranks: list[int], def run_list_bench(args: argparse.Namespace): print(args) - print("List bench :\n" - f" Hidden Sizes {args.hidden_sizes}" - f" LoRA Ranks {args.lora_ranks}") + print( + "List bench :\n" + f" Hidden Sizes {args.hidden_sizes}" + f" LoRA Ranks {args.lora_ranks}" + ) # Get all benchmarking contexts bench_contexts: list[BenchmarkContext] = as_benchmark_contexts( - hidden_sizes=args.hidden_sizes, lora_ranks=args.lora_ranks, args=args) + hidden_sizes=args.hidden_sizes, lora_ranks=args.lora_ranks, args=args + ) run(args, bench_contexts) @@ -776,19 +872,22 @@ def run_range_bench(args: argparse.Namespace): print(args) hidden_sizes = list( - range(args.hidden_sizes_start, args.hidden_sizes_end + 1, - args.hidden_sizes_increment)) + range( + args.hidden_sizes_start, + args.hidden_sizes_end + 1, + args.hidden_sizes_increment, + ) + ) lora_ranks = list( - range(args.lora_ranks_start, args.lora_ranks_end + 1, - args.lora_ranks_increment)) + range(args.lora_ranks_start, args.lora_ranks_end + 1, args.lora_ranks_increment) + ) - print("Range bench :\n" - f" Hidden Sizes {hidden_sizes}" - f" LoRA Ranks {lora_ranks}") + print(f"Range bench :\n Hidden Sizes {hidden_sizes} LoRA Ranks {lora_ranks}") # Get all benchmarking contexts bench_contexts: list[BenchmarkContext] = as_benchmark_contexts( - hidden_sizes=hidden_sizes, lora_ranks=lora_ranks, args=args) + hidden_sizes=hidden_sizes, lora_ranks=lora_ranks, args=args + ) run(args, bench_contexts) @@ -806,21 +905,19 @@ def run_model_bench(args: argparse.Namespace): # Get all hidden sizes hidden_sizes: set[int] = set() for model_name, tp_size in product(args.models, args.tp_sizes): - hidden_sizes = hidden_sizes.union( - hidden_sizes_from_model(model_name, tp_size)) + hidden_sizes = hidden_sizes.union(hidden_sizes_from_model(model_name, tp_size)) - print("Model bench :\n" - f" Hidden Sizes {hidden_sizes}" - f" LoRA Ranks {args.lora_ranks}") + print(f"Model bench :\n Hidden Sizes {hidden_sizes} LoRA Ranks {args.lora_ranks}") # Get all benchmarking contexts bench_contexts: list[BenchmarkContext] = as_benchmark_contexts( - hidden_sizes=hidden_sizes, lora_ranks=args.lora_ranks, args=args) + hidden_sizes=hidden_sizes, lora_ranks=args.lora_ranks, args=args + ) run(args, bench_contexts) -if __name__ == '__main__': +if __name__ == "__main__": def to_torch_dtype(dt): if dt == "torch.float16": @@ -830,14 +927,15 @@ if __name__ == '__main__': raise ValueError("unsupported dtype") def get_bool(s: str) -> bool: - return s.lower() in ['true', '1'] + return s.lower() in ["true", "1"] def add_common_command_args(p: argparse.ArgumentParser): p.add_argument( "--dtype", type=to_torch_dtype, required=True, - help="Available options are ['torch.float16', 'torch.bfloat16']") + help="Available options are ['torch.float16', 'torch.bfloat16']", + ) p.add_argument( "--arg-pool-size", @@ -845,56 +943,66 @@ if __name__ == '__main__': default=32, help="Run profiles with a pool of input/output/meta tensors instead" "of simply reusing the same tensors for all runs. A bigger arg-pool" - "mitigates hardware caching effects during benchmarking.") + "mitigates hardware caching effects during benchmarking.", + ) p.add_argument( "--cuda-graph-nops", type=int, - help=("when set profiling is done using cudagraph, " - "with the given number of operations in a graph." - "Note that the measurement returned is the time " - "taken for N consecutive executions of the benchmarking " - "functions, where N is the value of this argument.")) - p.add_argument("--num-loras", - nargs="+", - type=int, - default=DEFAULT_NUM_LORAS) - p.add_argument("--num-active-loras", - type=int, - default=None, - help="Active LoRAs. When None, all LoRAs are active") - p.add_argument("--sort-by-lora-id", - nargs="+", - type=get_bool, - default=DEFAULT_SORT_BY_LORA_IDS) - p.add_argument("--op-types", - nargs="+", - type=OpType.from_str, - default=list(OpType)) - p.add_argument('--seq-lengths', - nargs="+", - type=int, - default=DEFAULT_SEQ_LENGTHS) - p.add_argument("--batch-sizes", - nargs="+", - type=int, - default=DEFAULT_BATCH_SIZES) - p.add_argument("--expand-fn-add-inputs", - nargs="+", - type=get_bool, - default=DEFAULT_EXPAND_FN_ADD_INPUTS) + help=( + "when set profiling is done using cudagraph, " + "with the given number of operations in a graph." + "Note that the measurement returned is the time " + "taken for N consecutive executions of the benchmarking " + "functions, where N is the value of this argument." + ), + ) + p.add_argument("--num-loras", nargs="+", type=int, default=DEFAULT_NUM_LORAS) + p.add_argument( + "--num-active-loras", + type=int, + default=None, + help="Active LoRAs. When None, all LoRAs are active", + ) + p.add_argument( + "--sort-by-lora-id", + nargs="+", + type=get_bool, + default=DEFAULT_SORT_BY_LORA_IDS, + ) + p.add_argument( + "--op-types", nargs="+", type=OpType.from_str, default=list(OpType) + ) + p.add_argument( + "--seq-lengths", nargs="+", type=int, default=DEFAULT_SEQ_LENGTHS + ) + p.add_argument( + "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES + ) + p.add_argument( + "--expand-fn-add-inputs", + nargs="+", + type=get_bool, + default=DEFAULT_EXPAND_FN_ADD_INPUTS, + ) p.add_argument( - '-o', - '--output-directory', + "-o", + "--output-directory", type=str, - help=("Output directory to store a the list of benchmarking" - "TMeasurement objects as a pickle file")) + help=( + "Output directory to store a the list of benchmarking" + "TMeasurement objects as a pickle file" + ), + ) p.add_argument( "--test-correctness", - action='store_true', - help=("When enabled, the benchmarking functions are tested" - "for correctness before the actual benchmarking")) + action="store_true", + help=( + "When enabled, the benchmarking functions are tested" + "for correctness before the actual benchmarking" + ), + ) parser = FlexibleArgumentParser( description=f""" @@ -910,50 +1018,45 @@ Benchmark LoRA kernels: range_bench example: python3 benchmarks/kernels/benchmark_lora.py range_bench --arg-pool-size 32 --batch-sizes 1 16 32 --dtype torch.float16 --num-loras 1 4 --op-types lora_shrink lora_expand --seq-lengths 1 16 --sort-by-lora-id 1 --cuda-graph-nops 32 --hidden-sizes-start 1024 --hidden-sizes-end 4096 --hidden-sizes-increment 1024 --lora-ranks-start 8 --lora-ranks-end 24 --lora-ranks-increment 8 """, # noqa: E501 - formatter_class=argparse.RawTextHelpFormatter) + formatter_class=argparse.RawTextHelpFormatter, + ) subparsers = parser.add_subparsers(dest="cmd", required=True) list_parser = subparsers.add_parser("list_bench") - list_parser.add_argument("--hidden-sizes", - nargs="+", - type=int, - default=DEFAULT_HIDDEN_SIZES) - list_parser.add_argument("--lora-ranks", - nargs="+", - type=int, - default=DEFAULT_LORA_RANKS) + list_parser.add_argument( + "--hidden-sizes", nargs="+", type=int, default=DEFAULT_HIDDEN_SIZES + ) + list_parser.add_argument( + "--lora-ranks", nargs="+", type=int, default=DEFAULT_LORA_RANKS + ) add_common_command_args(list_parser) list_parser.set_defaults(func=run_list_bench) range_parser = subparsers.add_parser("range_bench") range_parser.add_argument("--hidden-sizes-start", type=int, required=True) range_parser.add_argument("--hidden-sizes-end", type=int, required=True) - range_parser.add_argument("--hidden-sizes-increment", - type=int, - required=True) + range_parser.add_argument("--hidden-sizes-increment", type=int, required=True) range_parser.add_argument("--lora-ranks-start", type=int, required=True) range_parser.add_argument("--lora-ranks-end", type=int, required=True) - range_parser.add_argument("--lora-ranks-increment", - type=int, - required=True) + range_parser.add_argument("--lora-ranks-increment", type=int, required=True) add_common_command_args(range_parser) range_parser.set_defaults(func=run_range_bench) model_parser = subparsers.add_parser("model_bench") - model_parser.add_argument("--models", - nargs="+", - type=str, - default=DEFAULT_MODELS, - choices=WEIGHT_SHAPES.keys()) - model_parser.add_argument("--tp-sizes", - nargs="+", - type=int, - default=DEFAULT_TP_SIZES) - model_parser.add_argument("--lora-ranks", - nargs="+", - type=int, - default=DEFAULT_LORA_RANKS) + model_parser.add_argument( + "--models", + nargs="+", + type=str, + default=DEFAULT_MODELS, + choices=WEIGHT_SHAPES.keys(), + ) + model_parser.add_argument( + "--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES + ) + model_parser.add_argument( + "--lora-ranks", nargs="+", type=int, default=DEFAULT_LORA_RANKS + ) add_common_command_args(model_parser) model_parser.set_defaults(func=run_model_bench) diff --git a/benchmarks/kernels/benchmark_machete.py b/benchmarks/kernels/benchmark_machete.py index a661ea9d7e60be322f28eeff62f9ace449ee702e..f8f1db04790bfb714751ef29da3391440a01e378 100644 --- a/benchmarks/kernels/benchmark_machete.py +++ b/benchmarks/kernels/benchmark_machete.py @@ -20,12 +20,18 @@ from weight_shapes import WEIGHT_SHAPES from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, marlin_permute_scales, - marlin_zero_points) + GPTQ_MARLIN_MAX_PARALLEL, + GPTQ_MARLIN_MIN_THREAD_N, + marlin_permute_scales, + marlin_zero_points, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( - MarlinWorkspace) + MarlinWorkspace, +) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - pack_rows, quantize_weights) + pack_rows, + quantize_weights, +) from vllm.scalar_type import ScalarType, scalar_types from vllm.utils import FlexibleArgumentParser @@ -82,12 +88,14 @@ def rand_data(shape, dtype=torch.float16, scale=1): return torch.randint(-15, 15, shape, dtype=dtype, device="cuda") -def quantize_and_pack(atype: torch.dtype, - w: torch.Tensor, - wtype: ScalarType, - stype: Optional[torch.dtype], - group_size: Optional[int], - zero_points: bool = False): +def quantize_and_pack( + atype: torch.dtype, + w: torch.Tensor, + wtype: ScalarType, + stype: Optional[torch.dtype], + group_size: Optional[int], + zero_points: bool = False, +): assert wtype.is_integer(), "TODO: support floating point weights" w_ref, w_q, w_s, w_zp = quantize_weights( @@ -96,21 +104,24 @@ def quantize_and_pack(atype: torch.dtype, group_size=group_size, zero_points=zero_points, # to match how the kernel applies zps - ref_zero_points_after_scales=True) + ref_zero_points_after_scales=True, + ) w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape) return w_ref, w_q, w_s, w_zp -def create_bench_tensors(shape: tuple[int, int, int], types: TypeConfig, - group_size: Optional[int]) -> list[BenchmarkTensors]: +def create_bench_tensors( + shape: tuple[int, int, int], types: TypeConfig, group_size: Optional[int] +) -> list[BenchmarkTensors]: m, n, k = shape # we want to make sure that weights don't fit into L2 cache between runs so # we construct enough weights to exceed L2 cache, which is 50mb on a H100 # so we target total weight size > 2*50mb - num_weights = math.ceil(2 * 50 * 1024**2 * 8 / - (k * n * types.weight_type.size_bits)) + num_weights = math.ceil( + 2 * 50 * 1024**2 * 8 / (k * n * types.weight_type.size_bits) + ) a = rand_data((m, k), types.act_type, scale=5) @@ -124,8 +135,13 @@ def create_bench_tensors(shape: tuple[int, int, int], types: TypeConfig, w = w.to(torch.float16) w_ref, w_q_packed, w_s, w_zp = quantize_and_pack( - a.dtype, w, types.weight_type, types.group_scale_type, group_size, - types.group_zero_type is not None) + a.dtype, + w, + types.weight_type, + types.group_scale_type, + group_size, + types.group_zero_type is not None, + ) if not a.dtype.is_floating_point: aiinfo = torch.iinfo(a.dtype) @@ -133,21 +149,30 @@ def create_bench_tensors(shape: tuple[int, int, int], types: TypeConfig, w_ref = w_ref.to(torch.float32) - w_ch_s = None if types.channel_scale_type is None else\ - rand_data((n,), types.channel_scale_type) - w_tok_s = None if types.token_scale_type is None else\ - rand_data((m,), types.token_scale_type) + w_ch_s = ( + None + if types.channel_scale_type is None + else rand_data((n,), types.channel_scale_type) + ) + w_tok_s = ( + None + if types.token_scale_type is None + else rand_data((m,), types.token_scale_type) + ) benchmark_tensors.append( - BenchmarkTensors(w_ref=w_ref, - a=a, - w_q=w_q_packed, - wtype=types.weight_type, - w_g_s=w_s, - w_g_zp=w_zp, - group_size=group_size, - w_ch_s=w_ch_s, - w_tok_s=w_tok_s)) + BenchmarkTensors( + w_ref=w_ref, + a=a, + w_q=w_q_packed, + wtype=types.weight_type, + w_g_s=w_s, + w_g_zp=w_zp, + group_size=group_size, + w_ch_s=w_ch_s, + w_tok_s=w_tok_s, + ) + ) return benchmark_tensors @@ -170,50 +195,57 @@ def cutlass_scaled_mm_create_bench_fn(bt: BenchmarkTensors) -> Callable: scale_b = torch.tensor(1.0, dtype=torch.float32, device=bt.a.device) w_col_major = bt.w_ref.to(bt.a.dtype).t().contiguous().t() return lambda: ops.cutlass_scaled_mm( - bt.a, w_col_major, scale_a, scale_b, out_dtype=torch.float16) + bt.a, w_col_major, scale_a, scale_b, out_dtype=torch.float16 + ) def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable: device = bt.a.device - workspace = MarlinWorkspace(bt.w_ref.shape[1], GPTQ_MARLIN_MIN_THREAD_N, - GPTQ_MARLIN_MAX_PARALLEL) + workspace = MarlinWorkspace( + bt.w_ref.shape[1], GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL + ) if bt.w_g_zp is None: w_zp = torch.empty(0, dtype=torch.int, device=device) else: - w_zp = marlin_zero_points(bt.w_g_zp, bt.w_ref.shape[0], - bt.w_ref.shape[1], bt.wtype.size_bits) + w_zp = marlin_zero_points( + bt.w_g_zp, bt.w_ref.shape[0], bt.w_ref.shape[1], bt.wtype.size_bits + ) if bt.group_size is None: w_s = torch.tensor([], device="cuda", dtype=torch.half) else: - w_s = marlin_permute_scales(bt.w_g_s, bt.w_ref.shape[0], - bt.w_ref.shape[1], bt.group_size) + w_s = marlin_permute_scales( + bt.w_g_s, bt.w_ref.shape[0], bt.w_ref.shape[1], bt.group_size + ) sort_indices = torch.empty(0, dtype=torch.int, device=device) g_idx = torch.empty(0, dtype=torch.int, device=device) - w_q = ops.gptq_marlin_repack(bt.w_q, sort_indices, bt.w_ref.shape[0], - bt.w_ref.shape[1], bt.wtype.size_bits) + w_q = ops.gptq_marlin_repack( + bt.w_q, sort_indices, bt.w_ref.shape[0], bt.w_ref.shape[1], bt.wtype.size_bits + ) if bt.a.dtype.is_floating_point: assert bt.w_ch_s is None assert bt.w_tok_s is None assert bt.group_size is not None - fn = lambda: ops.gptq_marlin_gemm(a=bt.a, - b_q_weight=w_q, - b_scales=w_s, - b_zeros=w_zp, - g_idx=g_idx, - perm=sort_indices, - workspace=workspace.scratch, - b_q_type=bt.wtype, - size_m=bt.a.shape[0], - size_n=bt.w_ref.shape[1], - size_k=bt.w_ref.shape[0], - is_k_full=True, - is_zp_float=False) + fn = lambda: ops.gptq_marlin_gemm( + a=bt.a, + b_q_weight=w_q, + b_scales=w_s, + b_zeros=w_zp, + g_idx=g_idx, + perm=sort_indices, + workspace=workspace.scratch, + b_q_type=bt.wtype, + size_m=bt.a.shape[0], + size_n=bt.w_ref.shape[1], + size_k=bt.w_ref.shape[0], + is_k_full=True, + is_zp_float=False, + ) else: assert bt.a.dtype == torch.int8 assert bt.wtype == scalar_types.uint4b8 @@ -221,36 +253,35 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable: if bt.w_ch_s is not None: s_ch = bt.w_ch_s.to(torch.float32) else: - s_ch = torch.ones(bt.w_ref.shape[1], - dtype=torch.float32, - device=device) + s_ch = torch.ones(bt.w_ref.shape[1], dtype=torch.float32, device=device) if bt.w_tok_s is not None: s_tok = bt.w_tok_s.to(torch.float32) else: - s_tok = torch.ones(bt.a.shape[0], - dtype=torch.float32, - device=device) - - fn = lambda: ops.marlin_qqq_gemm(a=bt.a, - b_q_weight=w_q, - s_group=w_s, - s_tok=s_tok, - s_ch=s_ch, - workspace=workspace.scratch, - size_m=bt.a.shape[0], - size_n=bt.w_ref.shape[1], - size_k=bt.w_ref.shape[0]) + s_tok = torch.ones(bt.a.shape[0], dtype=torch.float32, device=device) + + fn = lambda: ops.marlin_qqq_gemm( + a=bt.a, + b_q_weight=w_q, + s_group=w_s, + s_tok=s_tok, + s_ch=s_ch, + workspace=workspace.scratch, + size_m=bt.a.shape[0], + size_n=bt.w_ref.shape[1], + size_k=bt.w_ref.shape[0], + ) return fn -def machete_create_bench_fn(bt: BenchmarkTensors, - out_type=torch.dtype, - schedule=None) -> Callable: +def machete_create_bench_fn( + bt: BenchmarkTensors, out_type=torch.dtype, schedule=None +) -> Callable: w_q = bt.w_q.t().contiguous().t() # make col major - w_q = ops.machete_prepack_B(w_q, bt.a.dtype, bt.wtype, - None if bt.w_g_s is None else bt.w_g_s.dtype) + w_q = ops.machete_prepack_B( + w_q, bt.a.dtype, bt.wtype, None if bt.w_g_s is None else bt.w_g_s.dtype + ) w_g_zp = bt.w_g_zp if w_g_zp is not None: @@ -275,26 +306,24 @@ def machete_create_bench_fn(bt: BenchmarkTensors, # bench -def bench_fns(label: str, sub_label: str, description: str, - fns: list[Callable]): - +def bench_fns(label: str, sub_label: str, description: str, fns: list[Callable]): min_run_time = 1 if not NVTX_PROFILE else 0.1 res = TBenchmark.Timer( stmt=""" for fn in fns: fn() """, - globals={ - "fns": fns - }, + globals={"fns": fns}, label=label, sub_label=sub_label, description=description, ).blocked_autorange(min_run_time=min_run_time) if NVTX_PROFILE: - with nvtx.annotate("mm-bench"), nvtx.annotate( - f"{label}|{sub_label}|{description}"): + with ( + nvtx.annotate("mm-bench"), + nvtx.annotate(f"{label}|{sub_label}|{description}"), + ): fns[0]() return res @@ -304,19 +333,20 @@ _SWEEP_SCHEDULES_RESULTS: Optional[pd.DataFrame] = None _SWEEP_SCHEDULES_RESULTS_CSV: Optional[str] = None -def bench(types: TypeConfig, - group_size: int, - m: int, - k: int, - n: int, - label: str, - sub_label: str, - sweep_schedules: bool = True) -> list[TMeasurement]: +def bench( + types: TypeConfig, + group_size: int, + m: int, + k: int, + n: int, + label: str, + sub_label: str, + sweep_schedules: bool = True, +) -> list[TMeasurement]: benchmark_tensors = create_bench_tensors((m, n, k), types, group_size) sub_label += f", L={len(benchmark_tensors)}" - name_type_string = f"W{types.weight_type}"+\ - f"-A{terse_type_name(types.act_type)}" + name_type_string = f"W{types.weight_type}" + f"-A{terse_type_name(types.act_type)}" if types.group_scale_type is not None: name_type_string += f"-GS{terse_type_name(types.group_scale_type)}" if types.group_zero_type is not None: @@ -332,31 +362,45 @@ def bench(types: TypeConfig, # pytorch impl timers.append( bench_fns( - label, sub_label, "torch.matmul (fp16)", - [torch_matmul_f16_create_bench_fn(bt) - for bt in benchmark_tensors])) + label, + sub_label, + "torch.matmul (fp16)", + [torch_matmul_f16_create_bench_fn(bt) for bt in benchmark_tensors], + ) + ) if types.act_type == torch.int8 or types.act_type == torch.float8_e4m3fn: timers.append( bench_fns( - label, sub_label, - f"cutlass_scaled_mm ({terse_type_name(types.act_type)})", [ - cutlass_scaled_mm_create_bench_fn(bt) - for bt in benchmark_tensors - ])) + label, + sub_label, + f"cutlass_scaled_mm ({terse_type_name(types.act_type)})", + [cutlass_scaled_mm_create_bench_fn(bt) for bt in benchmark_tensors], + ) + ) if types.act_type != torch.float8_e4m3fn: timers.append( - bench_fns(label, sub_label, f"marlin ({name_type_string})", - [marlin_create_bench_fn(bt) - for bt in benchmark_tensors])) + bench_fns( + label, + sub_label, + f"marlin ({name_type_string})", + [marlin_create_bench_fn(bt) for bt in benchmark_tensors], + ) + ) # machete timers.append( - bench_fns(label, sub_label, f"machete ({name_type_string})", [ - machete_create_bench_fn(bt, out_type=types.output_type) - for bt in benchmark_tensors - ])) + bench_fns( + label, + sub_label, + f"machete ({name_type_string})", + [ + machete_create_bench_fn(bt, out_type=types.output_type) + for bt in benchmark_tensors + ], + ) + ) if sweep_schedules: global _SWEEP_SCHEDULES_RESULTS @@ -371,7 +415,8 @@ def bench(types: TypeConfig, group_zeros_type=types.group_zero_type, token_scales_type=types.token_scale_type, channel_scales_type=types.channel_scale_type, - out_type=types.output_type) + out_type=types.output_type, + ) if schedules is None or len(schedules) == 0: raise ValueError("No schedules found to sweep") @@ -383,11 +428,17 @@ def bench(types: TypeConfig, if schedule_M >= 2 * max(m, 16) or schedule_M < m // 4: continue - res = bench_fns(label, sub_label, "machete_best", [ - machete_create_bench_fn( - bt, out_type=types.output_type, schedule=schedule) - for bt in benchmark_tensors - ]) + res = bench_fns( + label, + sub_label, + "machete_best", + [ + machete_create_bench_fn( + bt, out_type=types.output_type, schedule=schedule + ) + for bt in benchmark_tensors + ], + ) results_row = { "M": m, @@ -398,10 +449,8 @@ def bench(types: TypeConfig, "median": res.median, } if _SWEEP_SCHEDULES_RESULTS is None: - _SWEEP_SCHEDULES_RESULTS = pd.DataFrame( - columns=results_row.keys()) - _SWEEP_SCHEDULES_RESULTS.\ - loc[len(_SWEEP_SCHEDULES_RESULTS)] = results_row + _SWEEP_SCHEDULES_RESULTS = pd.DataFrame(columns=results_row.keys()) + _SWEEP_SCHEDULES_RESULTS.loc[len(_SWEEP_SCHEDULES_RESULTS)] = results_row print(f" {res.median:5.5} ", schedule) if not best or res.median < best.median: @@ -422,8 +471,9 @@ def print_timers(timers: list[TMeasurement]): def run(args, MKNs: Iterable[tuple[int, int, int]]) -> Iterable[TMeasurement]: types = TypeConfig( act_type=args.act_type, - weight_type=scalar_types.uint4b8 if args.group_zero_type is None \ - else scalar_types.uint4, + weight_type=scalar_types.uint4b8 + if args.group_zero_type is None + else scalar_types.uint4, output_type=args.out_type, group_scale_type=args.group_scale_type, group_zero_type=args.group_zero_type, @@ -433,14 +483,16 @@ def run(args, MKNs: Iterable[tuple[int, int, int]]) -> Iterable[TMeasurement]: results: list[TMeasurement] = [] for m, k, n in MKNs: - timers = bench(types, - args.group_size, - m, - k, - n, - f"{args.act_type}-gemm", - f"MKN=({m}x{k}x{n})", - sweep_schedules=args.sweep_schedules) + timers = bench( + types, + args.group_size, + m, + k, + n, + f"{args.act_type}-gemm", + f"MKN=({m}x{k}x{n})", + sweep_schedules=args.sweep_schedules, + ) print_timers(timers) results.extend(timers) @@ -454,7 +506,6 @@ def make_output( base_description: str, timestamp=None, ): - print(f"== All Results {base_description} ====") print_timers(data) @@ -468,8 +519,7 @@ def make_output( def run_square_bench(args): - dim_sizes = list( - range(args.dim_start, args.dim_end + 1, args.dim_increment)) + dim_sizes = list(range(args.dim_start, args.dim_end + 1, args.dim_increment)) MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes)) data = run(args.dtype, args.sweep_schedules, MKNs) @@ -479,8 +529,9 @@ def run_square_bench(args): def run_range_bench(args): m_start, k_start, n_start = (int(x) for x in args.dim_start.split(",")) m_end, k_end, n_end = (int(x) for x in args.dim_end.split(",")) - m_increment, k_increment, n_increment = \ - (int(x) for x in args.dim_increment.split(",")) + m_increment, k_increment, n_increment = ( + int(x) for x in args.dim_increment.split(",") + ) Ms = list(range(m_start, m_end + 1, m_increment)) Ks = list(range(k_start, k_end + 1, k_increment)) Ns = list(range(n_start, n_end + 1, n_increment)) @@ -492,7 +543,6 @@ def run_range_bench(args): def run_model_bench(args): - print("Benchmarking models:") for i, model in enumerate(args.models): print(f"[{i}] {model}") @@ -535,10 +585,13 @@ def run_model_bench(args): with open(f"model_bench-{type_string}-{timestr}.pkl", "wb") as f: args_dict = vars(args) args_dict.pop("func") - pkl.dump({ - "args": args_dict, - "results": all_results, - }, f) + pkl.dump( + { + "args": args_dict, + "results": all_results, + }, + f, + ) if __name__ == "__main__": @@ -554,7 +607,6 @@ if __name__ == "__main__": }[dt] class ToTorchDtype(argparse.Action): - def __call__(self, parser, namespace, values, option_string=None): setattr(namespace, self.dest, to_torch_dtype(values)) @@ -580,32 +632,32 @@ Benchmark Machete GEMM. "--act-type", action=ToTorchDtype, required=True, - choices=['bfloat16', 'float16', 'int8', 'float8_e4m3fn'], + choices=["bfloat16", "float16", "int8", "float8_e4m3fn"], ) parser.add_argument( "--group-scale-type", action=ToTorchDtype, - choices=['bfloat16', 'float16'], + choices=["bfloat16", "float16"], ) parser.add_argument( "--group-zero-type", type=to_torch_dtype, - choices=['bfloat16', 'float16'], + choices=["bfloat16", "float16"], ) parser.add_argument( "--channel-scale-type", action=ToTorchDtype, - choices=['float'], + choices=["float"], ) parser.add_argument( "--token-scale-type", action=ToTorchDtype, - choices=['float'], + choices=["float"], ) parser.add_argument( "--out-type", action=ToTorchDtype, - choices=['bfloat16', 'float16'], + choices=["bfloat16", "float16"], ) parser.add_argument( "--group-size", @@ -618,9 +670,11 @@ Benchmark Machete GEMM. action="store_true", help="Run a sweep over all supported schedules", ) - parser.add_argument("--sweep-csv-out", - help="CSV to store sweep results", - default="sch_sweep_results.csv") + parser.add_argument( + "--sweep-csv-out", + help="CSV to store sweep results", + default="sch_sweep_results.csv", + ) subparsers = parser.add_subparsers(dest="cmd", required=True) square_parser = subparsers.add_parser("square_bench") @@ -634,17 +688,20 @@ Benchmark Machete GEMM. "--dim-start", type=str, required=True, - help="Start value for M,K,N as common separated list") + help="Start value for M,K,N as common separated list", + ) range_parser.add_argument( "--dim-end", type=str, required=True, - help="End value (inclusive) for M,K,N as common separated list") + help="End value (inclusive) for M,K,N as common separated list", + ) range_parser.add_argument( "--dim-increment", type=str, required=True, - help="Increment value for M,K,N as common separated list") + help="Increment value for M,K,N as common separated list", + ) range_parser.set_defaults(func=run_range_bench) model_parser = subparsers.add_parser("model_bench") @@ -655,14 +712,12 @@ Benchmark Machete GEMM. default=DEFAULT_MODELS, choices=WEIGHT_SHAPES.keys(), ) - model_parser.add_argument("--tp-sizes", - nargs="+", - type=int, - default=DEFAULT_TP_SIZES) - model_parser.add_argument("--batch-sizes", - nargs="+", - type=int, - default=DEFAULT_BATCH_SIZES) + model_parser.add_argument( + "--tp-sizes", nargs="+", type=int, default=DEFAULT_TP_SIZES + ) + model_parser.add_argument( + "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES + ) model_parser.set_defaults(func=run_model_bench) args = parser.parse_args() diff --git a/benchmarks/kernels/benchmark_marlin.py b/benchmarks/kernels/benchmark_marlin.py index 1e785ac8fc73a539abc74c13b3f75e9df1846228..b17baff2e5f5d36042e4737454b75c4dd5868cd6 100644 --- a/benchmarks/kernels/benchmark_marlin.py +++ b/benchmarks/kernels/benchmark_marlin.py @@ -6,19 +6,34 @@ from benchmark_shapes import WEIGHT_SHAPES from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( - GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N, - GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES) + GPTQ_MARLIN_24_MAX_PARALLEL, + GPTQ_MARLIN_24_MIN_THREAD_N, + GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, + GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES, +) from vllm.model_executor.layers.quantization.utils.allspark_utils import ( - ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, ALLSPARK_SUPPORTED_QUANT_TYPES) + ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD, + ALLSPARK_SUPPORTED_QUANT_TYPES, +) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, - MARLIN_SUPPORTED_GROUP_SIZES, query_marlin_supported_quant_types) + GPTQ_MARLIN_MAX_PARALLEL, + GPTQ_MARLIN_MIN_THREAD_N, + MARLIN_SUPPORTED_GROUP_SIZES, + query_marlin_supported_quant_types, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( - MarlinWorkspace, marlin_quantize) + MarlinWorkspace, + marlin_quantize, +) from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import ( - marlin_24_quantize) + marlin_24_quantize, +) from vllm.model_executor.layers.quantization.utils.quant_utils import ( - gptq_pack, gptq_quantize_weights, quantize_weights, sort_weights) + gptq_pack, + gptq_quantize_weights, + quantize_weights, + sort_weights, +) from vllm.scalar_type import ScalarType from vllm.utils import FlexibleArgumentParser @@ -29,22 +44,29 @@ ACT_ORDER_OPTS = [False, True] K_FULL_OPTS = [False, True] -def bench_run(results: list[benchmark.Measurement], model: str, - act_order: bool, is_k_full: bool, quant_type: ScalarType, - group_size: int, size_m: int, size_k: int, size_n: int): +def bench_run( + results: list[benchmark.Measurement], + model: str, + act_order: bool, + is_k_full: bool, + quant_type: ScalarType, + group_size: int, + size_m: int, + size_k: int, + size_n: int, +): label = "Quant Matmul" - sub_label = ("{}, act={} k_full={}, q={}, g={}, " - "MKN=({}x{}x{})".format(model, act_order, is_k_full, - str(quant_type), group_size, size_m, - size_k, size_n)) + sub_label = "{}, act={} k_full={}, q={}, g={}, MKN=({}x{}x{})".format( + model, act_order, is_k_full, str(quant_type), group_size, size_m, size_k, size_n + ) print(f"Testing: {sub_label}") a = torch.randn(size_m, size_k).to(torch.half).cuda() b = torch.rand(size_k, size_n).to(torch.half).cuda() - a_tmp = (torch.zeros(size_m, size_k).to(torch.half).cuda()) + a_tmp = torch.zeros(size_m, size_k).to(torch.half).cuda() # Marlin quant ( @@ -57,14 +79,16 @@ def bench_run(results: list[benchmark.Measurement], model: str, ) = marlin_quantize(b, quant_type, group_size, act_order) # Marlin_24 quant - (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, - marlin_24_s) = marlin_24_quantize(b, quant_type, group_size) + (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s) = ( + marlin_24_quantize(b, quant_type, group_size) + ) marlin_zp = torch.empty(0, dtype=torch.int, device=b.device) # GPTQ quant - (w_ref, q_w, s, g_idx, - rand_perm) = gptq_quantize_weights(b, quant_type, group_size, act_order) + (w_ref, q_w, s, g_idx, rand_perm) = gptq_quantize_weights( + b, quant_type, group_size, act_order + ) q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n) # For act_order, sort the "weights" and "g_idx" @@ -74,32 +98,37 @@ def bench_run(results: list[benchmark.Measurement], model: str, (q_w, g_idx, repack_sort_indices) = sort_weights(q_w, g_idx) # Prepare - marlin_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N, - GPTQ_MARLIN_MAX_PARALLEL) + marlin_workspace = MarlinWorkspace( + size_n, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL + ) - marlin_24_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N, - GPTQ_MARLIN_24_MAX_PARALLEL) + marlin_24_workspace = MarlinWorkspace( + size_n, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL + ) marlin_zp = torch.zeros_like(marlin_s, dtype=torch.int) # AllSpark W8A16 quant - as_supported_case = (quant_type in ALLSPARK_SUPPORTED_QUANT_TYPES - and group_size == -1 and not act_order and is_k_full) + as_supported_case = ( + quant_type in ALLSPARK_SUPPORTED_QUANT_TYPES + and group_size == -1 + and not act_order + and is_k_full + ) if as_supported_case: properties = torch.cuda.get_device_properties(b.device.index) sm_count = properties.multi_processor_count sm_version = properties.major * 10 + properties.minor - supported_arch = (sm_version >= 80 and sm_version < 90) + supported_arch = sm_version >= 80 and sm_version < 90 as_supported_case = as_supported_case and supported_arch if supported_arch: has_zp = False - w_ref, qw, s, zp = quantize_weights(b, quant_type, group_size, - has_zp) + w_ref, qw, s, zp = quantize_weights(b, quant_type, group_size, has_zp) qw = qw.to(torch.uint8) - qw_reorder, s_reorder, zp_reorder = \ - ops.allspark_repack_weight( - qw, s, zp, has_zp) + qw_reorder, s_reorder, zp_reorder = ops.allspark_repack_weight( + qw, s, zp, has_zp + ) CUBLAS_M_THRESHOLD = ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD globals = { @@ -136,8 +165,7 @@ def bench_run(results: list[benchmark.Measurement], model: str, "zp_reorder": zp_reorder if as_supported_case else None, "sm_count": sm_count if as_supported_case else None, "sm_version": sm_version if as_supported_case else None, - "CUBLAS_M_THRESHOLD": - CUBLAS_M_THRESHOLD if as_supported_case else None, + "CUBLAS_M_THRESHOLD": CUBLAS_M_THRESHOLD if as_supported_case else None, # Kernels "gptq_marlin_gemm": ops.gptq_marlin_gemm, "gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm, @@ -158,60 +186,63 @@ def bench_run(results: list[benchmark.Measurement], model: str, label=label, sub_label=sub_label, description="pytorch_gemm", - ).blocked_autorange(min_run_time=min_run_time)) + ).blocked_autorange(min_run_time=min_run_time) + ) results.append( benchmark.Timer( - stmt= - "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501 + stmt="output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, description="gptq_marlin_gemm_fp16", - ).blocked_autorange(min_run_time=min_run_time)) + ).blocked_autorange(min_run_time=min_run_time) + ) results.append( benchmark.Timer( - stmt= - "output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501 + stmt="output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, description="gptq_marlin_gemm_fp32", - ).blocked_autorange(min_run_time=min_run_time)) + ).blocked_autorange(min_run_time=min_run_time) + ) - if (quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES - and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES): + if ( + quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES + and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES + ): results.append( benchmark.Timer( - stmt= - "output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, quant_type, size_m, size_n, size_k)", # noqa: E501 + stmt="output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, quant_type, size_m, size_n, size_k)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, description="gptq_marlin_24_gemm", - ).blocked_autorange(min_run_time=min_run_time)) + ).blocked_autorange(min_run_time=min_run_time) + ) results.append( benchmark.Timer( - stmt= - "q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, quant_type.size_bits)", # noqa: E501 + stmt="q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, quant_type.size_bits)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, description="gptq_marlin_repack", - ).blocked_autorange(min_run_time=min_run_time)) + ).blocked_autorange(min_run_time=min_run_time) + ) if as_supported_case: results.append( benchmark.Timer( - stmt= - "output = allspark_w8a16_gemm(a, qw_reorder, s_reorder, zp_reorder, size_n, group_size, sm_count, sm_version, CUBLAS_M_THRESHOLD, False, True)", # noqa: E501 + stmt="output = allspark_w8a16_gemm(a, qw_reorder, s_reorder, zp_reorder, size_n, group_size, sm_count, sm_version, CUBLAS_M_THRESHOLD, False, True)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, description="allspark_w8a16_gemm_fp32", - ).blocked_autorange(min_run_time=min_run_time)) + ).blocked_autorange(min_run_time=min_run_time) + ) def main(args): @@ -233,37 +264,50 @@ def main(args): continue for act_order in ACT_ORDER_OPTS: - if len(args.limit_act_order - ) > 0 and act_order not in args.limit_act_order: + if ( + len(args.limit_act_order) > 0 + and act_order not in args.limit_act_order + ): continue for is_k_full in K_FULL_OPTS: - if len(args.limit_k_full - ) > 0 and is_k_full not in args.limit_k_full: + if ( + len(args.limit_k_full) > 0 + and is_k_full not in args.limit_k_full + ): continue - for quant_type in query_marlin_supported_quant_types( - False): - if len(args.limit_num_bits) > 0 and \ - quant_type.size_bits not in args.limit_num_bits: + for quant_type in query_marlin_supported_quant_types(False): + if ( + len(args.limit_num_bits) > 0 + and quant_type.size_bits not in args.limit_num_bits + ): continue for group_size in MARLIN_SUPPORTED_GROUP_SIZES: - if len( - args.limit_group_size - ) > 0 and group_size not in args.limit_group_size: + if ( + len(args.limit_group_size) > 0 + and group_size not in args.limit_group_size + ): continue # For act_order, the group_size must be less than # size_k - if act_order and (group_size == size_k - or group_size == -1): + if act_order and (group_size == size_k or group_size == -1): continue for size_m in args.batch_sizes: - bench_run(results, model, act_order, is_k_full, - quant_type, group_size, size_m, - size_k, size_n) + bench_run( + results, + model, + act_order, + is_k_full, + quant_type, + group_size, + size_m, + size_k, + size_n, + ) compare = benchmark.Compare(results) compare.print() @@ -274,7 +318,8 @@ def main(args): # if __name__ == "__main__": parser = FlexibleArgumentParser( - description="Benchmark Marlin across specified models/shapes/batches") + description="Benchmark Marlin across specified models/shapes/batches" + ) parser.add_argument( "--models", nargs="+", @@ -282,10 +327,9 @@ if __name__ == "__main__": default=DEFAULT_MODELS, choices=WEIGHT_SHAPES.keys(), ) - parser.add_argument("--batch-sizes", - nargs="+", - type=int, - default=DEFAULT_BATCH_SIZES) + parser.add_argument( + "--batch-sizes", nargs="+", type=int, default=DEFAULT_BATCH_SIZES + ) parser.add_argument("--limit-k", nargs="+", type=int, default=[]) parser.add_argument("--limit-n", nargs="+", type=int, default=[]) parser.add_argument("--limit-group-size", nargs="+", type=int, default=[]) diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index a274537a67515db153f0a255c2a6ebabbc8baf5f..c2f7660858f574791d1477d69f936b354eba6f20 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -6,16 +6,17 @@ import time from contextlib import nullcontext from datetime import datetime from itertools import product +from types import SimpleNamespace from typing import Any, TypedDict import ray import torch -import triton from ray.experimental.tqdm_ray import tqdm -from transformers import AutoConfig from vllm.model_executor.layers.fused_moe.fused_moe import * from vllm.platforms import current_platform +from vllm.transformers_utils.config import get_config +from vllm.triton_utils import triton from vllm.utils import FlexibleArgumentParser FP8_DTYPE = current_platform.fp8_dtype() @@ -30,56 +31,60 @@ class BenchmarkConfig(TypedDict): num_stages: int -def benchmark_config(config: BenchmarkConfig, - num_tokens: int, - num_experts: int, - shard_intermediate_size: int, - hidden_size: int, - topk: int, - dtype: torch.dtype, - use_fp8_w8a8: bool, - use_int8_w8a16: bool, - num_iters: int = 100, - block_quant_shape: List[int] = None, - use_deep_gemm: bool = False) -> float: +def benchmark_config( + config: BenchmarkConfig, + num_tokens: int, + num_experts: int, + shard_intermediate_size: int, + hidden_size: int, + topk: int, + dtype: torch.dtype, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + num_iters: int = 100, + block_quant_shape: List[int] = None, + use_deep_gemm: bool = False, +) -> float: init_dtype = torch.float16 if use_fp8_w8a8 else dtype x = torch.randn(num_tokens, hidden_size, dtype=dtype) if use_int8_w8a16: - w1 = torch.randint(-127, - 127, ( - num_experts, - shard_intermediate_size, - hidden_size, - ), - dtype=torch.int8) - w2 = torch.randint(-127, - 127, ( - num_experts, - hidden_size, - shard_intermediate_size // 2, - ), - dtype=torch.int8) + w1 = torch.randint( + -127, + 127, + ( + num_experts, + shard_intermediate_size, + hidden_size, + ), + dtype=torch.int8, + ) + w2 = torch.randint( + -127, + 127, + ( + num_experts, + hidden_size, + shard_intermediate_size // 2, + ), + dtype=torch.int8, + ) else: - w1 = torch.randn(num_experts, - shard_intermediate_size, - hidden_size, - dtype=init_dtype) - w2 = torch.randn(num_experts, - hidden_size, - shard_intermediate_size // 2, - dtype=init_dtype) - gating_output = torch.randn(num_iters, - num_tokens, - num_experts, - dtype=torch.float32) + w1 = torch.randn( + num_experts, shard_intermediate_size, hidden_size, dtype=init_dtype + ) + w2 = torch.randn( + num_experts, hidden_size, shard_intermediate_size // 2, dtype=init_dtype + ) + gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32) w1_scale = None w2_scale = None a1_scale = None a2_scale = None if use_int8_w8a16: - w1_scale = torch.randn((num_experts, 2 * shard_intermediate_size), - dtype=torch.float32) + w1_scale = torch.randn( + (num_experts, 2 * shard_intermediate_size), dtype=torch.float32 + ) w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32) if use_fp8_w8a8: if block_quant_shape: @@ -92,10 +97,14 @@ def benchmark_config(config: BenchmarkConfig, n_tiles_w2 = (K + block_n - 1) // block_n k_tiles_w1 = (K + block_k - 1) // block_k k_tiles_w2 = (N + block_k - 1) // block_k - w1_scale = torch.rand((E, n_tiles_w1, k_tiles_w1), - dtype=torch.float32) * factor_for_scale - w2_scale = torch.rand((E, n_tiles_w2, k_tiles_w2), - dtype=torch.float32) * factor_for_scale + w1_scale = ( + torch.rand((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) + * factor_for_scale + ) + w2_scale = ( + torch.rand((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) + * factor_for_scale + ) else: w1_scale = torch.randn(num_experts, dtype=torch.float32) w2_scale = torch.randn(num_experts, dtype=torch.float32) @@ -113,10 +122,12 @@ def benchmark_config(config: BenchmarkConfig, def run(): from vllm.model_executor.layers.fused_moe import override_config + with override_config(config): if use_deep_gemm: - topk_weights, topk_ids = fused_topk(x, input_gating, topk, - False) + topk_weights, topk_ids, token_expert_indices = fused_topk( + x, input_gating, topk, False + ) return fused_experts( x, w1, @@ -212,8 +223,7 @@ def get_rocm_tuning_space(use_fp16): return param_ranges -def get_configs_compute_bound(use_fp16, - block_quant_shape) -> list[dict[str, int]]: +def get_configs_compute_bound(use_fp16, block_quant_shape) -> list[dict[str, int]]: configs: list[BenchmarkConfig] = [] if current_platform.is_rocm(): @@ -249,20 +259,25 @@ def get_configs_compute_bound(use_fp16, if block_quant_shape is not None and not use_fp16: block_n, block_k = block_quant_shape[0], block_quant_shape[1] for config in configs[:]: - if config["BLOCK_SIZE_K"] % block_k != 0 or config[ - "BLOCK_SIZE_N"] % block_n != 0: + if ( + config["BLOCK_SIZE_K"] % block_k != 0 + or config["BLOCK_SIZE_N"] % block_n != 0 + ): configs.remove(config) return configs -def prune_rocm_search_space(num_tokens, shard_intermediate_size, hidden_size, - search_space, is_fp16, topk): +def prune_rocm_search_space( + num_tokens, shard_intermediate_size, hidden_size, search_space, is_fp16, topk +): N1, K1 = shard_intermediate_size, hidden_size N2, K2 = hidden_size, shard_intermediate_size // 2 - pruned_space_1 = prune_rocm_configs(num_tokens * topk, N1, K1, - search_space, is_fp16) - pruned_space_2 = prune_rocm_configs(num_tokens * topk, N2, K2, - search_space, is_fp16) + pruned_space_1 = prune_rocm_configs( + num_tokens * topk, N1, K1, search_space, is_fp16 + ) + pruned_space_2 = prune_rocm_configs( + num_tokens * topk, N2, K2, search_space, is_fp16 + ) search_space = merge_unique_dicts(pruned_space_1, pruned_space_2) return search_space @@ -300,14 +315,14 @@ def prune_rocm_configs(M, N, K, configs, is_fp16=True): SPLIT_K = config.get("SPLIT_K", 1) GROUP_M = config.get("GROUP_SIZE_M") if is_fp16: - if (matrix_instr_nonkdim > BLOCK_SIZE_M - or matrix_instr_nonkdim > BLOCK_SIZE_N): + if ( + matrix_instr_nonkdim > BLOCK_SIZE_M + or matrix_instr_nonkdim > BLOCK_SIZE_N + ): continue - if (matrix_instr_nonkdim >= M - and matrix_instr_nonkdim != BLOCK_SIZE_M): + if matrix_instr_nonkdim >= M and matrix_instr_nonkdim != BLOCK_SIZE_M: continue - if (matrix_instr_nonkdim >= N - and matrix_instr_nonkdim != BLOCK_SIZE_N): + if matrix_instr_nonkdim >= N and matrix_instr_nonkdim != BLOCK_SIZE_N: continue # Skip BLOCK_SIZE that is too large compare to M/N # unless BLOCK_SIZE is already small enough @@ -328,8 +343,10 @@ def prune_rocm_configs(M, N, K, configs, is_fp16=True): continue # out of shared memory resource # TODO (zhanglx): This does not consider the LDS usage in the epilogue - LDS = (BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a + - BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b) + LDS = ( + BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a + + BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b + ) if LDS > 65536: continue # Skip small block sizes and num_warps for large gemm @@ -363,7 +380,6 @@ def merge_unique_dicts(list1, list2): @ray.remote(num_gpus=1) class BenchmarkWorker: - def __init__(self, seed: int) -> None: torch.set_default_device("cuda") current_platform.seed_everything(seed) @@ -387,36 +403,40 @@ class BenchmarkWorker: use_deep_gemm: bool = False, ) -> tuple[dict[str, int], float]: current_platform.seed_everything(self.seed) - dtype_str = get_config_dtype_str(dtype, - use_int8_w8a16=use_int8_w8a16, - use_fp8_w8a8=use_fp8_w8a8) + dtype_str = get_config_dtype_str( + dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8 + ) # NOTE(woosuk): The current naming convention uses w2.shape[2], which # is the intermediate size after silu_and_mul. - op_config = get_moe_configs(num_experts, shard_intermediate_size // 2, - dtype_str) + op_config = get_moe_configs( + num_experts, shard_intermediate_size // 2, dtype_str + ) if op_config is None: - config = get_default_config(num_tokens, - num_experts, - shard_intermediate_size, - hidden_size, - topk, - dtype_str, - is_marlin=False) + config = get_default_config( + num_tokens, + num_experts, + shard_intermediate_size, + hidden_size, + topk, + dtype_str, + is_marlin=False, + ) else: - config = op_config[min(op_config.keys(), - key=lambda x: abs(x - num_tokens))] - kernel_time = benchmark_config(config, - num_tokens, - num_experts, - shard_intermediate_size, - hidden_size, - topk, - dtype, - use_fp8_w8a8, - use_int8_w8a16, - num_iters=100, - block_quant_shape=block_quant_shape, - use_deep_gemm=use_deep_gemm) + config = op_config[min(op_config.keys(), key=lambda x: abs(x - num_tokens))] + kernel_time = benchmark_config( + config, + num_tokens, + num_experts, + shard_intermediate_size, + hidden_size, + topk, + dtype, + use_fp8_w8a8, + use_int8_w8a16, + num_iters=100, + block_quant_shape=block_quant_shape, + use_deep_gemm=use_deep_gemm, + ) return config, kernel_time def tune( @@ -437,13 +457,22 @@ class BenchmarkWorker: best_time = float("inf") if current_platform.is_rocm(): is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16) - search_space = prune_rocm_search_space(num_tokens, - shard_intermediate_size, - hidden_size, search_space, - is_fp16, topk) + search_space = prune_rocm_search_space( + num_tokens, + shard_intermediate_size, + hidden_size, + search_space, + is_fp16, + topk, + ) + + need_device_guard = False + if current_platform.is_rocm(): + visible_device = os.environ.get("ROCR_VISIBLE_DEVICES", None) + if visible_device != f"{self.device_id}": + need_device_guard = True - with torch.cuda.device(self.device_id) if current_platform.is_rocm( - ) else nullcontext(): + with torch.cuda.device(self.device_id) if need_device_guard else nullcontext(): for config in tqdm(search_space): try: kernel_time = benchmark_config( @@ -458,7 +487,8 @@ class BenchmarkWorker: use_int8_w8a16, num_iters=20, block_quant_shape=block_quant_shape, - use_deep_gemm=use_deep_gemm) + use_deep_gemm=use_deep_gemm, + ) except triton.runtime.autotuner.OutOfResources: # Some configurations may be invalid and fail to compile. continue @@ -474,42 +504,44 @@ class BenchmarkWorker: def sort_config(config: BenchmarkConfig) -> BenchmarkConfig: return { - "BLOCK_SIZE_M": - config["BLOCK_SIZE_M"], - "BLOCK_SIZE_N": - config["BLOCK_SIZE_N"], - "BLOCK_SIZE_K": - config["BLOCK_SIZE_K"], - "GROUP_SIZE_M": - config["GROUP_SIZE_M"], - "num_warps": - config["num_warps"], - "num_stages": - config["num_stages"], - **({ - "waves_per_eu": config["waves_per_eu"] - } if "waves_per_eu" in config else {}), - **({ - "matrix_instr_nonkdim": config["matrix_instr_nonkdim"] - } if "matrix_instr_nonkdim" in config else {}), - **({ - "kpack": config["kpack"] - } if "kpack" in config else {}), + "BLOCK_SIZE_M": config["BLOCK_SIZE_M"], + "BLOCK_SIZE_N": config["BLOCK_SIZE_N"], + "BLOCK_SIZE_K": config["BLOCK_SIZE_K"], + "GROUP_SIZE_M": config["GROUP_SIZE_M"], + "num_warps": config["num_warps"], + "num_stages": config["num_stages"], + **( + {"waves_per_eu": config["waves_per_eu"]} if "waves_per_eu" in config else {} + ), + **( + {"matrix_instr_nonkdim": config["matrix_instr_nonkdim"]} + if "matrix_instr_nonkdim" in config + else {} + ), + **({"kpack": config["kpack"]} if "kpack" in config else {}), } -def save_configs(configs: dict[int, BenchmarkConfig], num_experts: int, - shard_intermediate_size: int, hidden_size: int, topk: int, - dtype: torch.dtype, use_fp8_w8a8: bool, use_int8_w8a16: bool, - block_quant_shape: List[int]) -> None: - dtype_str = get_config_dtype_str(dtype, - use_int8_w8a16=use_int8_w8a16, - use_fp8_w8a8=use_fp8_w8a8) +def save_configs( + configs: dict[int, BenchmarkConfig], + num_experts: int, + shard_intermediate_size: int, + hidden_size: int, + topk: int, + dtype: torch.dtype, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + block_quant_shape: List[int], +) -> None: + dtype_str = get_config_dtype_str( + dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8 + ) # NOTE(woosuk): The current naming convention uses w2.shape[2], which # is the intermediate size after silu_and_mul. - filename = get_config_file_name(num_experts, shard_intermediate_size // 2, - dtype_str, block_quant_shape) + filename = get_config_file_name( + num_experts, shard_intermediate_size // 2, dtype_str, block_quant_shape + ) print(f"Writing best config to {filename}...") with open(filename, "w") as f: @@ -518,18 +550,20 @@ def save_configs(configs: dict[int, BenchmarkConfig], num_experts: int, def get_weight_block_size_safety(config, default_value=None): - - quantization_config = getattr(config, 'quantization_config', {}) + quantization_config = getattr(config, "quantization_config", {}) if isinstance(quantization_config, dict): - return quantization_config.get('weight_block_size', default_value) + return quantization_config.get("weight_block_size", default_value) return default_value def main(args: argparse.Namespace): print(args) - config = AutoConfig.from_pretrained( - args.model, trust_remote_code=args.trust_remote_code) + config = get_config(model=args.model, trust_remote_code=args.trust_remote_code) + if args.model_prefix: + config = getattr(config, args.model_prefix) + config = SimpleNamespace(**config) + if config.architectures[0] == "DbrxForCausalLM": E = config.ffn_config.moe_num_experts topk = config.ffn_config.moe_top_k @@ -540,15 +574,12 @@ def main(args: argparse.Namespace): topk = config.num_experts_per_tok intermediate_size = config.intermediate_size shard_intermediate_size = 2 * intermediate_size // args.tp_size - elif (config.architectures[0] == "DeepseekV3ForCausalLM" - or config.architectures[0] == "DeepseekV2ForCausalLM"): + elif config.architectures[0] in ("DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM"): E = config.n_routed_experts topk = config.num_experts_per_tok intermediate_size = config.moe_intermediate_size shard_intermediate_size = 2 * intermediate_size // args.tp_size - elif config.architectures[0] in [ - "Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM" - ]: + elif config.architectures[0] in ("Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"): E = config.num_experts topk = config.num_experts_per_tok intermediate_size = config.moe_intermediate_size @@ -563,21 +594,51 @@ def main(args: argparse.Namespace): shard_intermediate_size = 2 * intermediate_size // args.tp_size hidden_size = config.hidden_size - dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype + dtype = ( + torch.float16 + if current_platform.is_rocm() + else getattr(torch, config.torch_dtype) + ) use_fp8_w8a8 = args.dtype == "fp8_w8a8" use_int8_w8a16 = args.dtype == "int8_w8a16" block_quant_shape = get_weight_block_size_safety(config) if args.batch_size is None: batch_sizes = [ - 1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536, - 2048, 3072, 4096 + 1, + 2, + 4, + 8, + 16, + 24, + 32, + 48, + 64, + 96, + 128, + 256, + 512, + 1024, + 1536, + 2048, + 3072, + 4096, ] else: batch_sizes = [args.batch_size] use_deep_gemm = bool(args.use_deep_gemm) + if current_platform.is_rocm() and "HIP_VISIBLE_DEVICES" in os.environ: + # Ray will set ROCR_VISIBLE_DEVICES for device visibility + logger.warning( + "Ray uses ROCR_VISIBLE_DEVICES to control device accessibility." + "Replacing HIP_VISIBLE_DEVICES with ROCR_VISIBLE_DEVICES." + ) + val = os.environ["HIP_VISIBLE_DEVICES"] + os.environ["ROCR_VISIBLE_DEVICES"] = val + del os.environ["HIP_VISIBLE_DEVICES"] + ray.init() num_gpus = int(ray.available_resources()["GPU"]) workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)] @@ -600,25 +661,59 @@ def main(args: argparse.Namespace): start = time.time() configs = _distribute( - "tune", [(batch_size, E, shard_intermediate_size, hidden_size, - topk, dtype, use_fp8_w8a8, use_int8_w8a16, search_space, - block_quant_shape, use_deep_gemm) - for batch_size in batch_sizes]) + "tune", + [ + ( + batch_size, + E, + shard_intermediate_size, + hidden_size, + topk, + dtype, + use_fp8_w8a8, + use_int8_w8a16, + search_space, + block_quant_shape, + use_deep_gemm, + ) + for batch_size in batch_sizes + ], + ) best_configs = { - M: sort_config(config) - for M, config in zip(batch_sizes, configs) + M: sort_config(config) for M, config in zip(batch_sizes, configs) } - save_configs(best_configs, E, shard_intermediate_size, hidden_size, - topk, dtype, use_fp8_w8a8, use_int8_w8a16, - block_quant_shape) + save_configs( + best_configs, + E, + shard_intermediate_size, + hidden_size, + topk, + dtype, + use_fp8_w8a8, + use_int8_w8a16, + block_quant_shape, + ) end = time.time() print(f"Tuning took {end - start:.2f} seconds") else: outputs = _distribute( "benchmark", - [(batch_size, E, shard_intermediate_size, hidden_size, topk, dtype, - use_fp8_w8a8, use_int8_w8a16, block_quant_shape, use_deep_gemm) - for batch_size in batch_sizes]) + [ + ( + batch_size, + E, + shard_intermediate_size, + hidden_size, + topk, + dtype, + use_fp8_w8a8, + use_int8_w8a16, + block_quant_shape, + use_deep_gemm, + ) + for batch_size in batch_sizes + ], + ) for batch_size, (config, kernel_time) in zip(batch_sizes, outputs): print(f"Batch size: {batch_size}, config: {config}") @@ -627,23 +722,21 @@ def main(args: argparse.Namespace): if __name__ == "__main__": parser = FlexibleArgumentParser() - parser.add_argument("--model", - type=str, - default="mistralai/Mixtral-8x7B-Instruct-v0.1") - parser.add_argument("--tp-size", - "-tp", - "--tensor-parallel-size", - type=int, - default=2) - parser.add_argument("--dtype", - type=str, - choices=["auto", "fp8_w8a8", "int8_w8a16"], - default="auto") + parser.add_argument( + "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1" + ) + parser.add_argument( + "--tp-size", "-tp", "--tensor-parallel-size", type=int, default=2 + ) + parser.add_argument( + "--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto" + ) parser.add_argument("--use-deep-gemm", action="store_true") parser.add_argument("--seed", type=int, default=0) parser.add_argument("--batch-size", type=int, required=False) parser.add_argument("--tune", action="store_true") parser.add_argument("--trust-remote-code", action="store_true") + parser.add_argument("--model-prefix", type=str, required=False) args = parser.parse_args() main(args) diff --git a/benchmarks/kernels/benchmark_moe_permute_unpermute.py b/benchmarks/kernels/benchmark_moe_permute_unpermute.py new file mode 100644 index 0000000000000000000000000000000000000000..333986fdf5eff52574194bbea7c51a8000c37424 --- /dev/null +++ b/benchmarks/kernels/benchmark_moe_permute_unpermute.py @@ -0,0 +1,416 @@ +# SPDX-License-Identifier: Apache-2.0 + +import argparse +from typing import Any, TypedDict + +import ray +import torch +from transformers import AutoConfig + +from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( + _moe_permute, + _moe_unpermute_and_reduce, +) +from vllm.model_executor.layers.fused_moe.fused_moe import * +from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import * +from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize +from vllm.platforms import current_platform +from vllm.utils import FlexibleArgumentParser + +FP8_DTYPE = current_platform.fp8_dtype() + + +class BenchmarkConfig(TypedDict): + BLOCK_SIZE_M: int + BLOCK_SIZE_N: int + BLOCK_SIZE_K: int + GROUP_SIZE_M: int + num_warps: int + num_stages: int + + +def benchmark_permute( + num_tokens: int, + num_experts: int, + hidden_size: int, + topk: int, + dtype: torch.dtype, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + num_iters: int = 100, + use_customized_permute: bool = False, +) -> float: + # init_dtype = torch.float16 if use_fp8_w8a8 else dtype + hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype) + # output_hidden_states = torch.empty_like(hidden_states) + if use_fp8_w8a8: + align_block_size = 128 # deepgemm needs 128 m aligned block + qhidden_states, scale = _fp8_quantize(hidden_states, None, None) + else: + align_block_size = None + qhidden_states = hidden_states + + gating_output = torch.randn(num_iters, num_tokens, num_experts, dtype=torch.float32) + + input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32) + topk_weights, topk_ids, token_expert_indices = fused_topk( + qhidden_states, input_gating, topk, False + ) + + def prepare(i: int): + input_gating.copy_(gating_output[i]) + + def run(): + if use_customized_permute: + (permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = ( + moe_permute( + qhidden_states, + topk_weights=topk_weights, + topk_ids=topk_ids, + token_expert_indices=token_expert_indices, + topk=topk, + n_expert=num_experts, + n_local_expert=num_experts, + expert_map=None, + align_block_size=align_block_size, + ) + ) + else: + ( + permuted_hidden_states, + a1q_scale, + sorted_token_ids, + expert_ids, + inv_perm, + ) = _moe_permute( + qhidden_states, None, topk_ids, num_experts, None, align_block_size + ) + + # JIT compilation & warmup + run() + torch.cuda.synchronize() + + # Capture 10 invocations with CUDA graph + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + for _ in range(10): + run() + torch.cuda.synchronize() + + # Warmup + for _ in range(5): + graph.replay() + torch.cuda.synchronize() + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + latencies: list[float] = [] + for i in range(num_iters): + prepare(i) + torch.cuda.synchronize() + + start_event.record() + graph.replay() + end_event.record() + end_event.synchronize() + latencies.append(start_event.elapsed_time(end_event)) + avg = sum(latencies) / (num_iters * 10) * 1000 # us + graph.reset() + return avg + + +def benchmark_unpermute( + num_tokens: int, + num_experts: int, + hidden_size: int, + topk: int, + dtype: torch.dtype, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + num_iters: int = 100, + use_customized_permute: bool = False, +) -> float: + # init_dtype = torch.float16 if use_fp8_w8a8 else dtype + hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype) + output_hidden_states = torch.empty_like(hidden_states) + if use_fp8_w8a8: + align_block_size = 128 # deepgemm needs 128 m aligned block + qhidden_states, scale = _fp8_quantize(hidden_states, None, None) + else: + align_block_size = None + qhidden_states = hidden_states + + input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32) + + topk_weights, topk_ids, token_expert_indices = fused_topk( + qhidden_states, input_gating, topk, False + ) + + def prepare(): + if use_customized_permute: + (permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = ( + moe_permute( + qhidden_states, + topk_weights=topk_weights, + topk_ids=topk_ids, + token_expert_indices=token_expert_indices, + topk=topk, + n_expert=num_experts, + n_local_expert=num_experts, + expert_map=None, + align_block_size=align_block_size, + ) + ) + # convert to fp16/bf16 as gemm output + return ( + permuted_hidden_states.to(dtype), + first_token_off, + inv_perm_idx, + m_indices, + ) + else: + ( + permuted_qhidden_states, + a1q_scale, + sorted_token_ids, + expert_ids, + inv_perm, + ) = _moe_permute( + qhidden_states, None, topk_ids, num_experts, None, align_block_size + ) + # convert to fp16/bf16 as gemm output + return ( + permuted_qhidden_states.to(dtype), + a1q_scale, + sorted_token_ids, + expert_ids, + inv_perm, + ) + + def run(input: tuple): + if use_customized_permute: + (permuted_hidden_states, first_token_off, inv_perm_idx, m_indices) = input + moe_unpermute( + permuted_hidden_states, + topk_weights, + topk_ids, + inv_perm_idx, + first_token_off, + topk, + num_experts, + num_experts, + ) + else: + ( + permuted_hidden_states, + a1q_scale, + sorted_token_ids, + expert_ids, + inv_perm, + ) = input + _moe_unpermute_and_reduce( + output_hidden_states, permuted_hidden_states, inv_perm, topk_weights + ) + + # JIT compilation & warmup + input = prepare() + run(input) + torch.cuda.synchronize() + + # Capture 10 invocations with CUDA graph + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + for _ in range(10): + run(input) + torch.cuda.synchronize() + + # Warmup + for _ in range(5): + graph.replay() + torch.cuda.synchronize() + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + latencies: list[float] = [] + for i in range(num_iters): + torch.cuda.synchronize() + start_event.record() + graph.replay() + end_event.record() + end_event.synchronize() + latencies.append(start_event.elapsed_time(end_event)) + avg = sum(latencies) / (num_iters * 10) * 1000 # us + graph.reset() + return avg + + +@ray.remote(num_gpus=1) +class BenchmarkWorker: + def __init__(self, seed: int) -> None: + torch.set_default_device("cuda") + current_platform.seed_everything(seed) + self.seed = seed + # Get the device ID to allocate tensors and kernels + # on the respective GPU. This is required for Ray to work + # correctly with multi-GPU tuning on the ROCm platform. + self.device_id = int(ray.get_gpu_ids()[0]) + + def benchmark( + self, + num_tokens: int, + num_experts: int, + hidden_size: int, + topk: int, + dtype: torch.dtype, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + use_customized_permute: bool = False, + ) -> tuple[dict[str, int], float]: + current_platform.seed_everything(self.seed) + + permute_time = benchmark_permute( + num_tokens, + num_experts, + hidden_size, + topk, + dtype, + use_fp8_w8a8, + use_int8_w8a16, + num_iters=100, + use_customized_permute=use_customized_permute, + ) + unpermute_time = benchmark_unpermute( + num_tokens, + num_experts, + hidden_size, + topk, + dtype, + use_fp8_w8a8, + use_int8_w8a16, + num_iters=100, + use_customized_permute=use_customized_permute, + ) + return permute_time, unpermute_time + + +def get_weight_block_size_safety(config, default_value=None): + quantization_config = getattr(config, "quantization_config", {}) + if isinstance(quantization_config, dict): + return quantization_config.get("weight_block_size", default_value) + return default_value + + +def main(args: argparse.Namespace): + print(args) + + config = AutoConfig.from_pretrained( + args.model, trust_remote_code=args.trust_remote_code + ) + if config.architectures[0] == "DbrxForCausalLM": + E = config.ffn_config.moe_num_experts + topk = config.ffn_config.moe_top_k + elif config.architectures[0] == "JambaForCausalLM": + E = config.num_experts + topk = config.num_experts_per_tok + elif ( + config.architectures[0] == "DeepseekV3ForCausalLM" + or config.architectures[0] == "DeepseekV2ForCausalLM" + ): + E = config.n_routed_experts + topk = config.num_experts_per_tok + elif config.architectures[0] in ["Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"]: + E = config.num_experts + topk = config.num_experts_per_tok + + else: + # Support for llama4 + config = config.get_text_config() + # Default: Mixtral. + E = config.num_local_experts + topk = config.num_experts_per_tok + + hidden_size = config.hidden_size + dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype + use_fp8_w8a8 = args.dtype == "fp8_w8a8" + use_int8_w8a16 = args.dtype == "int8_w8a16" + use_customized_permute = args.use_customized_permute + + if args.batch_size is None: + batch_sizes = [ + 1, + 2, + 4, + 8, + 16, + 24, + 32, + 48, + 64, + 96, + 128, + 256, + 512, + 1024, + 1536, + 2048, + 3072, + 4096, + ] + else: + batch_sizes = [args.batch_size] + + ray.init() + num_gpus = int(ray.available_resources()["GPU"]) + workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)] + + def _distribute(method: str, inputs: list[Any]) -> list[Any]: + outputs = [] + worker_idx = 0 + for input_args in inputs: + worker = workers[worker_idx] + worker_method = getattr(worker, method) + output = worker_method.remote(*input_args) + outputs.append(output) + worker_idx = (worker_idx + 1) % num_gpus + return ray.get(outputs) + + outputs = _distribute( + "benchmark", + [ + ( + batch_size, + E, + hidden_size, + topk, + dtype, + use_fp8_w8a8, + use_int8_w8a16, + use_customized_permute, + ) + for batch_size in batch_sizes + ], + ) + + for batch_size, (permute, unpermute) in zip(batch_sizes, outputs): + print(f"Batch size: {batch_size}") + print(f"Permute time: {permute:.2f} us") + print(f"Unpermute time: {unpermute:.2f} us") + + +if __name__ == "__main__": + parser = FlexibleArgumentParser() + parser.add_argument( + "--model", type=str, default="mistralai/Mixtral-8x7B-Instruct-v0.1" + ) + parser.add_argument( + "--dtype", type=str, choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto" + ) + parser.add_argument("--use-customized-permute", action="store_true") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--batch-size", type=int, required=False) + parser.add_argument("--trust-remote-code", action="store_true") + args = parser.parse_args() + + main(args) diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index 2625239b08ef29a0574abd110e54fe9166745ca6..17432159c94e7fa4a2ceb3ff80e36cc041013448 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -9,8 +9,11 @@ import torch from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser, - create_kv_caches_with_random) +from vllm.utils import ( + STR_DTYPE_TO_TORCH_DTYPE, + FlexibleArgumentParser, + create_kv_caches_with_random, +) logger = init_logger(__name__) @@ -38,19 +41,15 @@ def main( current_platform.seed_everything(seed) scale = float(1.0 / (head_size**0.5)) - query = torch.empty(num_seqs, - num_query_heads, - head_size, - dtype=dtype, - device=device) + query = torch.empty( + num_seqs, num_query_heads, head_size, dtype=dtype, device=device + ) query.uniform_(-scale, scale) assert num_query_heads % num_kv_heads == 0 alibi_slopes = None if use_alibi: - alibi_slopes = torch.randn(num_query_heads, - dtype=torch.float, - device=device) + alibi_slopes = torch.randn(num_query_heads, dtype=torch.float, device=device) seq_lens = [seq_len for _ in range(num_seqs)] max_seq_len = max(seq_lens) @@ -61,24 +60,23 @@ def main( block_tables_lst: list[list[int]] = [] for _ in range(num_seqs): block_table = [ - random.randint(0, NUM_BLOCKS - 1) - for _ in range(max_num_blocks_per_seq) + random.randint(0, NUM_BLOCKS - 1) for _ in range(max_num_blocks_per_seq) ] block_tables_lst.append(block_table) - block_tables = torch.tensor(block_tables_lst, - dtype=torch.int, - device=device) + block_tables = torch.tensor(block_tables_lst, dtype=torch.int, device=device) # Create the KV cache. - key_caches, value_caches = create_kv_caches_with_random(NUM_BLOCKS, - block_size, - 1, - num_kv_heads, - head_size, - kv_cache_dtype, - dtype, - device=device) + key_caches, value_caches = create_kv_caches_with_random( + NUM_BLOCKS, + block_size, + 1, + num_kv_heads, + head_size, + kv_cache_dtype, + dtype, + device=device, + ) key_cache, value_cache = key_caches[0], value_caches[0] # Prepare for the paged attention kernel. @@ -86,11 +84,8 @@ def main( if version == "v2": if current_platform.is_rocm(): global PARTITION_SIZE - if not args.custom_paged_attn: - PARTITION_SIZE = 1024 - else: - PARTITION_SIZE = PARTITION_SIZE_ROCM - num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) + PARTITION_SIZE = 1024 if not args.custom_paged_attn else PARTITION_SIZE_ROCM + num_partitions = (max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE tmp_output = torch.empty( size=(num_seqs, num_query_heads, num_partitions, head_size), dtype=output.dtype, @@ -110,9 +105,7 @@ def main( start_time = time.perf_counter() # Using default kv_scale - k_scale = v_scale = torch.tensor(1.0, - dtype=torch.float32, - device=device) + k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device) for _ in range(num_iters): if version == "v1": @@ -195,30 +188,29 @@ def main( print(f"Kernel running time: {latency * 1000000:.3f} us") -if __name__ == '__main__': - logger.warning("This script benchmarks the paged attention kernel. " - "By default this is no longer used in vLLM inference.") +if __name__ == "__main__": + logger.warning( + "This script benchmarks the paged attention kernel. " + "By default this is no longer used in vLLM inference." + ) - parser = FlexibleArgumentParser( - description="Benchmark the paged attention kernel.") - parser.add_argument("--version", - type=str, - choices=["v1", "v2"], - default="v2") + parser = FlexibleArgumentParser(description="Benchmark the paged attention kernel.") + parser.add_argument("--version", type=str, choices=["v1", "v2"], default="v2") parser.add_argument("--batch-size", type=int, default=8) parser.add_argument("--seq-len", type=int, default=4096) parser.add_argument("--num-query-heads", type=int, default=64) parser.add_argument("--num-kv-heads", type=int, default=8) - parser.add_argument("--head-size", - type=int, - choices=[64, 80, 96, 112, 120, 128, 192, 256], - default=128) + parser.add_argument( + "--head-size", + type=int, + choices=[64, 80, 96, 112, 120, 128, 192, 256], + default=128, + ) parser.add_argument("--block-size", type=int, choices=[16, 32], default=16) parser.add_argument("--use-alibi", action="store_true") - parser.add_argument("--dtype", - type=str, - choices=["half", "bfloat16", "float"], - default="half") + parser.add_argument( + "--dtype", type=str, choices=["half", "bfloat16", "float"], default="half" + ) parser.add_argument("--seed", type=int, default=0) parser.add_argument("--profile", action="store_true") parser.add_argument( @@ -228,10 +220,11 @@ if __name__ == '__main__': default="auto", help="Data type for kv cache storage. If 'auto', will use model " "data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. " - "ROCm (AMD GPU) supports fp8 (=fp8_e4m3)") - parser.add_argument("--custom-paged-attn", - action="store_true", - help="Use custom paged attention") + "ROCm (AMD GPU) supports fp8 (=fp8_e4m3)", + ) + parser.add_argument( + "--custom-paged-attn", action="store_true", help="Use custom paged attention" + ) args = parser.parse_args() print(args) diff --git a/benchmarks/kernels/benchmark_quant.py b/benchmarks/kernels/benchmark_quant.py index b643897a60eef3ed32a4796f8b5a4bded3830f5d..2463dfebe83cce8511a68775e4a4666a6607f892 100644 --- a/benchmarks/kernels/benchmark_quant.py +++ b/benchmarks/kernels/benchmark_quant.py @@ -10,15 +10,17 @@ from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser @torch.inference_mode() -def main(num_tokens: int, - hidden_size: int, - static_scale: bool, - quant_dtype: torch.dtype, - dtype: torch.dtype, - seed: int = 0, - do_profile: bool = False, - num_warmup_iters: int = 5, - num_iters: int = 100) -> None: +def main( + num_tokens: int, + hidden_size: int, + static_scale: bool, + quant_dtype: torch.dtype, + dtype: torch.dtype, + seed: int = 0, + do_profile: bool = False, + num_warmup_iters: int = 5, + num_iters: int = 100, +) -> None: current_platform.seed_everything(seed) torch.set_default_device("cuda") @@ -56,7 +58,7 @@ def main(num_tokens: int, print(f"Kernel running time: {latency * 1000000:.3f} us") -if __name__ == '__main__': +if __name__ == "__main__": def to_torch_dtype(dt): if dt == "int8": @@ -66,37 +68,40 @@ if __name__ == '__main__': raise ValueError(f"Unsupported dtype: {dt}") parser = FlexibleArgumentParser( - description="Benchmark the quantization (fp8 or int8) kernel.") + description="Benchmark the quantization (fp8 or int8) kernel." + ) parser.add_argument("--num-tokens", type=int, default=4096) parser.add_argument("--hidden-size", type=int, default=8192) parser.add_argument("--static-scale", action="store_true") - parser.add_argument("--quant-dtype", - type=str, - choices=["fp8", "int8"], - default="int8") - parser.add_argument("--dtype", - type=str, - choices=["half", "bfloat16", "float"], - default="half") + parser.add_argument( + "--quant-dtype", type=str, choices=["fp8", "int8"], default="int8" + ) + parser.add_argument( + "--dtype", type=str, choices=["half", "bfloat16", "float"], default="half" + ) parser.add_argument("--seed", type=int, default=0) parser.add_argument("--profile", action="store_true") parser.add_argument("--num-warmup-iters", type=int, default=5) - parser.add_argument("--num-iters", - type=int, - default=100, - help="Number of benchmark iterations. " - "If --profile is set, this number is ignored") + parser.add_argument( + "--num-iters", + type=int, + default=100, + help="Number of benchmark iterations. " + "If --profile is set, this number is ignored", + ) args = parser.parse_args() print(args) - main(num_tokens=args.num_tokens, - hidden_size=args.hidden_size, - static_scale=args.static_scale, - quant_dtype=to_torch_dtype(args.quant_dtype), - dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype], - seed=args.seed, - do_profile=args.profile, - num_warmup_iters=args.num_warmup_iters, - num_iters=args.num_iters) + main( + num_tokens=args.num_tokens, + hidden_size=args.hidden_size, + static_scale=args.static_scale, + quant_dtype=to_torch_dtype(args.quant_dtype), + dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype], + seed=args.seed, + do_profile=args.profile, + num_warmup_iters=args.num_warmup_iters, + num_iters=args.num_iters, + ) diff --git a/benchmarks/kernels/benchmark_rmsnorm.py b/benchmarks/kernels/benchmark_rmsnorm.py index eaf6b25e8ca4f33a0c5f9fb5ed2d915d8116aa8c..d720083b615037d9f0236086cf42d392a3494b00 100644 --- a/benchmarks/kernels/benchmark_rmsnorm.py +++ b/benchmarks/kernels/benchmark_rmsnorm.py @@ -4,15 +4,14 @@ import itertools from typing import Optional, Union import torch -import triton from flashinfer.norm import fused_add_rmsnorm, rmsnorm from torch import nn from vllm import _custom_ops as vllm_ops +from vllm.triton_utils import triton class HuggingFaceRMSNorm(nn.Module): - def __init__(self, hidden_size: int, eps: float = 1e-6) -> None: super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) @@ -114,23 +113,19 @@ def rmsnorm_vllm( def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True): dtype = torch.bfloat16 - x = torch.randn(batch_size, - seq_len, - hidden_size, - dtype=dtype, - device="cuda") + x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda") weight = torch.ones(hidden_size, dtype=dtype, device="cuda") residual = torch.randn_like(x) if use_residual else None output_naive = rmsnorm_naive( - x.clone(), weight, - residual.clone() if residual is not None else None) + x.clone(), weight, residual.clone() if residual is not None else None + ) output_flashinfer = rmsnorm_flashinfer( - x.clone(), weight, - residual.clone() if residual is not None else None) + x.clone(), weight, residual.clone() if residual is not None else None + ) output_vllm = rmsnorm_vllm( - x.clone(), weight, - residual.clone() if residual is not None else None) + x.clone(), weight, residual.clone() if residual is not None else None + ) if use_residual: output_naive = output_naive[0] @@ -141,9 +136,9 @@ def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True): print(f"FlashInfer output={output_flashinfer}") print(f"vLLM output={output_vllm}") - if torch.allclose(output_naive, output_flashinfer, atol=1e-2, - rtol=1e-2) and torch.allclose( - output_naive, output_vllm, atol=1e-2, rtol=1e-2): + if torch.allclose( + output_naive, output_flashinfer, atol=1e-2, rtol=1e-2 + ) and torch.allclose(output_naive, output_vllm, atol=1e-2, rtol=1e-2): print("✅ All implementations match") else: print("❌ Implementations differ") @@ -152,12 +147,10 @@ def calculate_diff(batch_size, seq_len, hidden_size, use_residual=True): batch_size_range = [2**i for i in range(0, 7, 2)] seq_length_range = [2**i for i in range(6, 11, 1)] head_num_range = [32, 48] -configs = list( - itertools.product(head_num_range, batch_size_range, seq_length_range)) +configs = list(itertools.product(head_num_range, batch_size_range, seq_length_range)) def get_benchmark(use_residual): - @triton.testing.perf_report( triton.testing.Benchmark( x_names=["head_num", "batch_size", "seq_len"], @@ -167,19 +160,15 @@ def get_benchmark(use_residual): line_names=["HuggingFace", "FlashInfer", "vLLM"], styles=[("blue", "-"), ("green", "-"), ("red", "-")], ylabel="us", - plot_name= - f"rmsnorm-perf-{'with' if use_residual else 'without'}-residual", + plot_name=f"rmsnorm-perf-{'with' if use_residual else 'without'}-residual", args={}, - )) + ) + ) def benchmark(head_num, batch_size, seq_len, provider): dtype = torch.bfloat16 hidden_size = head_num * 128 # assuming head_dim = 128 - x = torch.randn(batch_size, - seq_len, - hidden_size, - dtype=dtype, - device="cuda") + x = torch.randn(batch_size, seq_len, hidden_size, dtype=dtype, device="cuda") weight = torch.ones(hidden_size, dtype=dtype, device="cuda") residual = torch.randn_like(x) if use_residual else None @@ -240,9 +229,9 @@ if __name__ == "__main__": default=4096, help="Hidden size (2nd dimension) of the sequence", ) - parser.add_argument("--use-residual", - action="store_true", - help="Whether to use residual connection") + parser.add_argument( + "--use-residual", action="store_true", help="Whether to use residual connection" + ) parser.add_argument( "--save-path", type=str, @@ -253,10 +242,12 @@ if __name__ == "__main__": args = parser.parse_args() # Run correctness test - calculate_diff(batch_size=args.batch_size, - seq_len=args.seq_len, - hidden_size=args.hidden_size, - use_residual=args.use_residual) + calculate_diff( + batch_size=args.batch_size, + seq_len=args.seq_len, + hidden_size=args.hidden_size, + use_residual=args.use_residual, + ) # Get the benchmark function with proper use_residual setting benchmark = get_benchmark(args.use_residual) diff --git a/benchmarks/kernels/benchmark_rope.py b/benchmarks/kernels/benchmark_rope.py index 05d24fc4b16d4af78e3e75b73c00ad49e088853f..110d36db157fdf70afa3a2dde76e8c31cfab6922 100644 --- a/benchmarks/kernels/benchmark_rope.py +++ b/benchmarks/kernels/benchmark_rope.py @@ -6,8 +6,7 @@ from typing import Optional import nvtx import torch -from vllm.model_executor.layers.rotary_embedding import (RotaryEmbedding, - get_rope) +from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding, get_rope from vllm.platforms import current_platform from vllm.utils import FlexibleArgumentParser @@ -32,40 +31,49 @@ def benchmark_rope_kernels_multi_lora( # silulating serving 4 LoRAs scaling_factors = [1, 2, 4, 8] # batched RoPE can take multiple scaling factors - batched_rope = get_rope(head_size, rotary_dim, max_position, base, - is_neox_style, { - "rope_type": "linear", - "factor": tuple(scaling_factors) - }) + batched_rope = get_rope( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + {"rope_type": "linear", "factor": tuple(scaling_factors)}, + ) # non-batched RoPE takes only one scaling factor, we create multiple # instances to simulate the same behavior non_batched_ropes: list[RotaryEmbedding] = [] for scaling_factor in scaling_factors: non_batched_ropes.append( - get_rope(head_size, rotary_dim, max_position, base, is_neox_style, - { - "rope_type": "linear", - "factor": (scaling_factor, ) - })) + get_rope( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + {"rope_type": "linear", "factor": (scaling_factor,)}, + ) + ) positions = torch.randint(0, max_position, (batch_size, seq_len)) - query = torch.randn(batch_size, - seq_len, - num_heads * head_size, - dtype=dtype) + query = torch.randn(batch_size, seq_len, num_heads * head_size, dtype=dtype) key = torch.randn_like(query) # create query offsets for batched RoPE, we concat multiple kv cache # together and each query needs to find the right kv cache of its type offset_map = torch.tensor( list( - accumulate([0] + [ - max_position * scaling_factor * 2 - for scaling_factor in scaling_factors[:-1] - ]))) - query_types = torch.randint(0, - len(scaling_factors), (batch_size, seq_len), - device=device) + accumulate( + [0] + + [ + max_position * scaling_factor * 2 + for scaling_factor in scaling_factors[:-1] + ] + ) + ) + ) + query_types = torch.randint( + 0, len(scaling_factors), (batch_size, seq_len), device=device + ) # map query types to offsets query_offsets = offset_map[query_types] # the kernel takes flattened offsets @@ -86,27 +94,28 @@ def benchmark_rope_kernels_multi_lora( torch.cuda.synchronize() -if __name__ == '__main__': +if __name__ == "__main__": parser = FlexibleArgumentParser( - description="Benchmark the rotary embedding kernels.") + description="Benchmark the rotary embedding kernels." + ) parser.add_argument("--is-neox-style", type=bool, default=True) parser.add_argument("--batch-size", type=int, default=16) parser.add_argument("--seq-len", type=int, default=512) parser.add_argument("--num-heads", type=int, default=8) - parser.add_argument("--head-size", - type=int, - choices=[64, 80, 96, 112, 120, 128, 192, 256], - default=128) + parser.add_argument( + "--head-size", + type=int, + choices=[64, 80, 96, 112, 120, 128, 192, 256], + default=128, + ) parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32) - parser.add_argument("--dtype", - type=str, - choices=["bfloat16", "float"], - default="float") + parser.add_argument( + "--dtype", type=str, choices=["bfloat16", "float"], default="float" + ) parser.add_argument("--seed", type=int, default=0) - parser.add_argument("--device", - type=str, - choices=["cuda:0", "cuda:1"], - default="cuda:0") + parser.add_argument( + "--device", type=str, choices=["cuda:0", "cuda:1"], default="cuda:0" + ) args = parser.parse_args() print(args) diff --git a/benchmarks/kernels/benchmark_w8a8_block_fp8.py b/benchmarks/kernels/benchmark_w8a8_block_fp8.py index 8f07bc8ca52eb9c272b7753336695d7b32d14649..6315c1ee6cdd6893e9d5e513c7dadc79f228f7d3 100644 --- a/benchmarks/kernels/benchmark_w8a8_block_fp8.py +++ b/benchmarks/kernels/benchmark_w8a8_block_fp8.py @@ -14,14 +14,16 @@ import tqdm import triton from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - _w8a8_block_fp8_matmul) + _w8a8_block_fp8_matmul, +) from vllm.platforms import current_platform from vllm.utils import FlexibleArgumentParser mp.set_start_method("spawn", force=True) -assert current_platform.is_cuda( -), "Only support tune w8a8 block fp8 kernel on CUDA device." +assert current_platform.is_cuda(), ( + "Only support tune w8a8 block fp8 kernel on CUDA device." +) DTYPE_MAP = { "float32": torch.float32, @@ -40,7 +42,7 @@ def w8a8_block_matmul( config: dict[str, Any], output_dtype: torch.dtype = torch.float16, ) -> torch.Tensor: - """This function performs matrix multiplication with + """This function performs matrix multiplication with block-wise quantization. It takes two input tensors `A` and `B` with scales `As` and `Bs`. @@ -51,7 +53,7 @@ def w8a8_block_matmul( B: The input tensor, e.g., weight. As: The per-token-group quantization scale for `A`. Bs: The per-block quantization scale for `B`. - block_size: The block size for per-block quantization. + block_size: The block size for per-block quantization. It should be 2-dim, e.g., [128, 128]. output_dytpe: The dtype of the returned tensor. @@ -71,18 +73,18 @@ def w8a8_block_matmul( assert triton.cdiv(N, block_n) == Bs.shape[0] assert triton.cdiv(K, block_k) == Bs.shape[1] - C_shape = A.shape[:-1] + (N, ) + C_shape = A.shape[:-1] + (N,) C = A.new_empty(C_shape, dtype=output_dtype) def grid(META): - return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * - triton.cdiv(N, META["BLOCK_SIZE_N"]), ) + return ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) if A.dtype == torch.float8_e4m3fn: kernel = _w8a8_block_fp8_matmul else: - raise RuntimeError( - "Currently, only support tune w8a8 block fp8 kernel.") + raise RuntimeError("Currently, only support tune w8a8 block fp8 kernel.") kernel[grid]( A, @@ -119,14 +121,16 @@ def get_configs_compute_bound(): for block_n in [32, 64, 128, 256]: for num_warps in [4, 8]: for group_size in [1, 16, 32, 64]: - configs.append({ - "BLOCK_SIZE_M": block_m, - "BLOCK_SIZE_N": block_n, - "BLOCK_SIZE_K": block_k, - "GROUP_SIZE_M": group_size, - "num_warps": num_warps, - "num_stages": num_stages, - }) + configs.append( + { + "BLOCK_SIZE_M": block_m, + "BLOCK_SIZE_N": block_n, + "BLOCK_SIZE_K": block_k, + "GROUP_SIZE_M": group_size, + "num_warps": num_warps, + "num_stages": num_stages, + } + ) return configs @@ -165,15 +169,9 @@ def get_weight_shapes(tp_size): return weight_shapes -def benchmark_config(A, - B, - As, - Bs, - block_size, - config, - out_dtype=torch.float16, - num_iters=10): - +def benchmark_config( + A, B, As, Bs, block_size, config, out_dtype=torch.float16, num_iters=10 +): def run(): w8a8_block_matmul(A, B, As, Bs, block_size, config, out_dtype) @@ -206,26 +204,26 @@ def tune(M, N, K, block_size, out_dtype, search_space, input_type): fp8_max, fp8_min = fp8_info.max, fp8_info.min A_fp32 = ( - (torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * - fp8_max) + (torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max + ) A = A_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) B_fp32 = ( - (torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * - fp8_max) + (torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max + ) B = B_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) else: - raise RuntimeError( - "Currently, only support tune w8a8 block fp8 kernel.") + raise RuntimeError("Currently, only support tune w8a8 block fp8 kernel.") block_n, block_k = block_size[0], block_size[1] n_tiles = (N + block_n - 1) // block_n k_tiles = (K + block_k - 1) // block_k - As = torch.rand(M, k_tiles, dtype=torch.float32, - device="cuda") * factor_for_scale - Bs = (torch.rand(n_tiles, k_tiles, dtype=torch.float32, device="cuda") * - factor_for_scale) + As = torch.rand(M, k_tiles, dtype=torch.float32, device="cuda") * factor_for_scale + Bs = ( + torch.rand(n_tiles, k_tiles, dtype=torch.float32, device="cuda") + * factor_for_scale + ) best_config = None best_time = float("inf") @@ -267,7 +265,8 @@ def save_configs( device_name = current_platform.get_device_name().replace(" ", "_") json_file_name = ( f"N={N},K={K},device_name={device_name},dtype={input_type}_w8a8," - f"block_shape=[{block_n},{block_k}].json") + f"block_shape=[{block_n},{block_k}].json" + ) config_file_path = os.path.join(save_path, json_file_name) print(f"Writing best config to {config_file_path}...") @@ -295,8 +294,7 @@ def tune_on_gpu(args_dict): search_space = get_configs_compute_bound() search_space = [ - config for config in search_space - if block_k % config["BLOCK_SIZE_K"] == 0 + config for config in search_space if block_k % config["BLOCK_SIZE_K"] == 0 ] start = time.time() @@ -312,15 +310,11 @@ def tune_on_gpu(args_dict): out_dtype, search_space, input_type, - ) for batch_size in tqdm(batch_sizes, - desc=f"GPU {gpu_id} - Batch sizes") + ) + for batch_size in tqdm(batch_sizes, desc=f"GPU {gpu_id} - Batch sizes") ] - best_configs = { - M: config - for M, config in zip(batch_sizes, benchmark_results) - } - save_configs(N, K, block_n, block_k, best_configs, save_path, - input_type) + best_configs = {M: config for M, config in zip(batch_sizes, benchmark_results)} + save_configs(N, K, block_n, block_k, best_configs, save_path, input_type) end = time.time() print(f"Tuning on GPU {gpu_id} took {end - start:.2f} seconds") @@ -376,13 +370,14 @@ def main(args): process_args = [] for gpu_id in range(num_gpus): - process_args.append({ - "gpu_id": gpu_id, - "batch_sizes": batches_per_gpu[gpu_id], - "weight_shapes": - weight_shapes, # Each GPU processes all weight shapes - "args": args, - }) + process_args.append( + { + "gpu_id": gpu_id, + "batch_sizes": batches_per_gpu[gpu_id], + "weight_shapes": weight_shapes, # Each GPU processes all weight shapes + "args": args, + } + ) ctx = mp.get_context("spawn") with ctx.Pool(num_gpus) as pool: @@ -398,13 +393,11 @@ Tune triton w8a8 block fp8 for DeepSeek-V3/DeepSeek-R1: python3 benchmark_w8a8_block_fp8.py --tp-size 8 --input-type fp8 Then copy to model_executor/layers/quantization/utils/configs """, - formatter_class=argparse.RawTextHelpFormatter) + formatter_class=argparse.RawTextHelpFormatter, + ) parser.add_argument("--tp-size", "-tp", type=int, default=8) - parser.add_argument("--input-type", - type=str, - choices=["fp8"], - default="fp8") + parser.add_argument("--input-type", type=str, choices=["fp8"], default="fp8") parser.add_argument( "--out-dtype", type=str, diff --git a/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py index 7892f126e7d694c4845187f039f6233236bc10e2..e377648254512dab59b6b97b678a52872cdc2b22 100644 --- a/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py +++ b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py @@ -6,13 +6,15 @@ import time # Import DeepGEMM functions import deep_gemm import torch -import triton from deep_gemm import calc_diff, ceil_div, get_col_major_tma_aligned_tensor # Import vLLM functions from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - per_token_group_quant_fp8, w8a8_block_fp8_matmul) + per_token_group_quant_fp8, + w8a8_block_fp8_matmul, +) +from vllm.triton_utils import triton # Copied from diff --git a/benchmarks/kernels/graph_machete_bench.py b/benchmarks/kernels/graph_machete_bench.py index bd62173a7b3a643487357672902ac089e2e49a23..ab364a84d6cb2083269edd96a27c0544159054f5 100644 --- a/benchmarks/kernels/graph_machete_bench.py +++ b/benchmarks/kernels/graph_machete_bench.py @@ -14,13 +14,14 @@ from vllm.utils import FlexibleArgumentParser if __name__ == "__main__": parser = FlexibleArgumentParser( - description='Benchmark the latency of processing a single batch of ' - 'requests till completion.') - parser.add_argument('filename', type=str) + description="Benchmark the latency of processing a single batch of " + "requests till completion." + ) + parser.add_argument("filename", type=str) args = parser.parse_args() - with open(args.filename, 'rb') as f: + with open(args.filename, "rb") as f: data = pickle.load(f) raw_results: list[TMeasurement] = data["results"] @@ -38,11 +39,7 @@ if __name__ == "__main__": raise Exception("MKN not found") kernel = v.task_spec.description - results[KN].append({ - "kernel": kernel, - "batch_size": M, - "median": v.median - }) + results[KN].append({"kernel": kernel, "batch_size": M, "median": v.median}) rows = int(math.ceil(len(results) / 2)) fig, axs = plt.subplots(rows, 2, figsize=(12, 5 * rows)) @@ -50,14 +47,16 @@ if __name__ == "__main__": for axs_idx, (shape, data) in enumerate(results.items()): plt.sca(axs[axs_idx]) df = pd.DataFrame(data) - sns.lineplot(data=df, - x="batch_size", - y="median", - hue="kernel", - style="kernel", - markers=True, - dashes=False, - palette="Dark2") + sns.lineplot( + data=df, + x="batch_size", + y="median", + hue="kernel", + style="kernel", + markers=True, + dashes=False, + palette="Dark2", + ) plt.title(f"Shape: {shape}") plt.ylabel("time (median, s)") plt.tight_layout() diff --git a/benchmarks/kernels/utils.py b/benchmarks/kernels/utils.py index ac64f786f18406d99b0e799b02e2d02d2db0b111..877a29feed9dfe226a28baaac2d4d5a72ef23943 100644 --- a/benchmarks/kernels/utils.py +++ b/benchmarks/kernels/utils.py @@ -23,6 +23,7 @@ class ArgPool: For every invocation during a benchmarking run, it will choose a different value from the list. """ + values: Iterable[Any] def __getitem__(self, index): @@ -30,9 +31,7 @@ class ArgPool: class Bench: - class ArgsIterator: - def __init__(self, args_list, kwargs_list): assert len(args_list) == len(kwargs_list) self.args_list = args_list @@ -53,10 +52,16 @@ class Bench: def n_args(self): return self.n - def __init__(self, cuda_graph_params: Optional[CudaGraphBenchParams], - label: str, sub_label: str, description: str, fn: Callable, - *args, **kwargs): - + def __init__( + self, + cuda_graph_params: Optional[CudaGraphBenchParams], + label: str, + sub_label: str, + description: str, + fn: Callable, + *args, + **kwargs, + ): self.cuda_graph_params = cuda_graph_params self.use_cuda_graph = self.cuda_graph_params is not None self.label = label @@ -67,10 +72,8 @@ class Bench: # Process args self._args = args self._kwargs = kwargs - self.args_list, self.kwargs_list = self.collapse_argpool( - *args, **kwargs) - self.args_iterator = self.ArgsIterator(self.args_list, - self.kwargs_list) + self.args_list, self.kwargs_list = self.collapse_argpool(*args, **kwargs) + self.args_iterator = self.ArgsIterator(self.args_list, self.kwargs_list) # Cudagraph runner self.g = None @@ -100,16 +103,13 @@ class Bench: for i in range(argpool_size): # collapse args; Just pick the ith value - args_list[i] = tuple([ - arg[i] if isinstance(arg, ArgPool) else arg - for arg in args_list[i] - ]) + args_list[i] = tuple( + [arg[i] if isinstance(arg, ArgPool) else arg for arg in args_list[i]] + ) # collapse kwargs kwargs_i = kwargs_list[i] - arg_pool_keys = [ - k for k, v in kwargs_i.items() if isinstance(v, ArgPool) - ] + arg_pool_keys = [k for k, v in kwargs_i.items() if isinstance(v, ArgPool)] for k in arg_pool_keys: # again just pick the ith value kwargs_i[k] = kwargs_i[k][i] @@ -142,7 +142,7 @@ class Bench: def run_cudagrah(self) -> TMeasurement: assert self.use_cuda_graph - globals = {'g': self.g} + globals = {"g": self.g} return TBenchmark.Timer( stmt="g.replay()", @@ -162,15 +162,15 @@ class Bench: has_arg_pool = self.args_iterator.n_args > 1 if has_arg_pool: - setup = ''' + setup = """ args_iterator.reset() args_it = args_iterator.__next__() - ''' - stmt = ''' + """ + stmt = """ args, kwargs = next(args_it) fn(*args, **kwargs) - ''' - globals = {'fn': self.fn, 'args_iterator': self.args_iterator} + """ + globals = {"fn": self.fn, "args_iterator": self.args_iterator} else: # no arg pool. Just use the args and kwargs directly self.args_iterator.reset() @@ -178,10 +178,10 @@ class Bench: args, kwargs = next(args_it) setup = "" - stmt = ''' + stmt = """ fn(*args, **kwargs) - ''' - globals = {'fn': self.fn, 'args': args, 'kwargs': kwargs} + """ + globals = {"fn": self.fn, "args": args, "kwargs": kwargs} return TBenchmark.Timer( stmt=stmt, diff --git a/benchmarks/overheads/benchmark_hashing.py b/benchmarks/overheads/benchmark_hashing.py index 5f94552e9dc85233b82daf85b162aec47f287284..d5701a8fbd6d85f75c4830ca379590306fe59eae 100644 --- a/benchmarks/overheads/benchmark_hashing.py +++ b/benchmarks/overheads/benchmark_hashing.py @@ -7,9 +7,8 @@ from vllm import LLM, SamplingParams from vllm.utils import FlexibleArgumentParser # A very long prompt, total number of tokens is about 15k. -LONG_PROMPT = ["You are an expert in large language models, aren't you?" - ] * 1000 -LONG_PROMPT = ' '.join(LONG_PROMPT) +LONG_PROMPT = ["You are an expert in large language models, aren't you?"] * 1000 +LONG_PROMPT = " ".join(LONG_PROMPT) def main(args): @@ -30,32 +29,35 @@ def main(args): print("------start generating------") for i in range(3): - profiler.runctx('llm.generate(LONG_PROMPT, sampling_params)', - globals(), locals()) + profiler.runctx( + "llm.generate(LONG_PROMPT, sampling_params)", globals(), locals() + ) # analyze the runtime of hashing function stats = pstats.Stats(profiler) - stats.sort_stats('cumulative') + stats.sort_stats("cumulative") total_time = 0 total_calls = 0 for func in stats.stats: - if 'hash_of_block' in func[2]: + if "hash_of_block" in func[2]: total_time = stats.stats[func][3] total_calls = stats.stats[func][0] percentage = (total_time / stats.total_tt) * 100 - print(f"Hashing took {total_time:.2f} seconds," - f"{percentage:.2f}% of the total runtime.") + print( + f"Hashing took {total_time:.2f} seconds,{percentage:.2f}% of the total runtime." + ) if __name__ == "__main__": parser = FlexibleArgumentParser( - description='Benchmark the performance of hashing function in' - 'automatic prefix caching.') - parser.add_argument('--model', type=str, default='lmsys/longchat-7b-16k') - parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1) - parser.add_argument('--output-len', type=int, default=10) - parser.add_argument('--enable-prefix-caching', - action='store_true', - help='enable prefix caching') + description="Benchmark the performance of hashing function in" + "automatic prefix caching." + ) + parser.add_argument("--model", type=str, default="lmsys/longchat-7b-16k") + parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1) + parser.add_argument("--output-len", type=int, default=10) + parser.add_argument( + "--enable-prefix-caching", action="store_true", help="enable prefix caching" + ) args = parser.parse_args() main(args) diff --git a/benchmarks/pyproject.toml b/benchmarks/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..f825cb203269cf2f2e17bf1f530cfb2a599559f9 --- /dev/null +++ b/benchmarks/pyproject.toml @@ -0,0 +1,54 @@ +# This local pyproject file is part of the migration from yapf to ruff format. +# It uses the same core rules as the main pyproject.toml file, but with the +# following differences: +# - ruff line length is overridden to 88 +# - deprecated typing ignores (UP006, UP035) have been removed + +[tool.ruff] +line-length = 88 +exclude = [ + # External file, leaving license intact + "examples/other/fp8/quantizer/quantize.py", + "vllm/vllm_flash_attn/flash_attn_interface.pyi" +] + +[tool.ruff.lint.per-file-ignores] +"vllm/third_party/**" = ["ALL"] +"vllm/version.py" = ["F401"] +"vllm/_version.py" = ["ALL"] + +[tool.ruff.lint] +select = [ + # pycodestyle + "E", + # Pyflakes + "F", + # pyupgrade + "UP", + # flake8-bugbear + "B", + # flake8-simplify + "SIM", + # isort + "I", + # flake8-logging-format + "G", +] +ignore = [ + # star imports + "F405", "F403", + # lambda expression assignment + "E731", + # Loop control variable not used within loop body + "B007", + # f-string format + "UP032", + # Can remove once 3.10+ is the minimum Python version + "UP007", +] + +[tool.ruff.lint.isort] +known-first-party = ["vllm"] + +[tool.ruff.format] +docstring-code-format = true \ No newline at end of file diff --git a/benchmarks/run_structured_output_benchmark.sh b/benchmarks/run_structured_output_benchmark.sh index 126dfbc24416102f5eb9100f06dc86f7865f2ce5..b043ab83e4608e7b27e2f94e6ec24a05f7c474aa 100755 --- a/benchmarks/run_structured_output_benchmark.sh +++ b/benchmarks/run_structured_output_benchmark.sh @@ -1,41 +1,102 @@ #!/bin/bash -# Define the model to use -MODEL=${1:-"Qwen/Qwen2.5-7B-Instruct"} - -# Define the backend to use -BACKEND=${2:-"vllm"} - -# Define the dataset to use -DATASET=${3:-"xgrammar_bench"} - -# Define the guided decoding backend -GUIDED_BACKEND=${4:-"xgrammar"} - +# default values +MODEL=${MODEL:-"Qwen/Qwen2.5-7B-Instruct"} +BACKEND=${BACKEND:-"vllm"} +DATASET=${DATASET:-"xgrammar_bench"} SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -OUTPUT_DIR=${5:-"$SCRIPT_DIR/structured_output_benchmark_results"} - -GUIDED_RATIO=${6:-0.5} +OUTPUT_DIR=${OUTPUT_DIR:-"$SCRIPT_DIR/structured_output_benchmark_results"} +PORT=${PORT:-8000} +STRUCTURED_OUTPUT_RATIO=${STRUCTURED_OUTPUT_RATIO:-1} +TOTAL_SECONDS=${TOTAL_SECONDS:-90} +MAX_NEW_TOKENS=${MAX_NEW_TOKENS:-300} +TOKENIZER_MODE=${TOKENIZER_MODE:-"auto"} + +usage() { + echo "Usage: $0 [options]" + echo "Options:" + echo " --model MODEL Model to benchmark (default: $MODEL)" + echo " --backend BACKEND Backend to use (default: $BACKEND)" + echo " --dataset DATASET Dataset to use (default: $DATASET)" + echo " --max-new-tokens N Maximum number of tokens to generate (default: $MAX_NEW_TOKENS)" + echo " --output-dir DIR Output directory for results (default: $OUTPUT_DIR)" + echo " --port PORT Port to use (default: $PORT)" + echo " --structured-output-ratio N Ratio of structured outputs (default: $STRUCTURED_OUTPUT_RATIO)" + echo " --tokenizer-mode MODE Tokenizer mode to use (default: $TOKENIZER_MODE)" + echo " --total-seconds N Total seconds to run the benchmark (default: $TOTAL_SECONDS)" + echo " -h, --help Show this help message and exit" + exit 0 +} + +# parse command line arguments +while [[ $# -gt 0 ]]; do + case $1 in + --model) + MODEL="$2" + shift 2 + ;; + --backend) + BACKEND="$2" + shift 2 + ;; + --dataset) + DATASET="$2" + shift 2 + ;; + --max-new-tokens) + MAX_NEW_TOKENS="$2" + shift 2 + ;; + --output-dir) + OUTPUT_DIR="$2" + shift 2 + ;; + --port) + PORT="$2" + shift 2 + ;; + --structured-output-ratio) + STRUCTURED_OUTPUT_RATIO="$2" + shift 2 + ;; + --tokenizer-mode) + TOKENIZER_MODE="$2" + shift 2 + ;; + --total-seconds) + TOTAL_SECONDS="$2" + shift 2 + ;; + -h|--help) + usage + ;; + *) + echo "Unknown argument: $1\n" + usage + ;; + esac +done # Create output directory if it doesn't exist mkdir -p "$OUTPUT_DIR" # Define QPS values to test -QPS_VALUES=(70 60 50 25 20 15 10) +QPS_VALUES=(25 20 15 10 5 1) # Common parameters COMMON_PARAMS="--backend $BACKEND \ --model $MODEL \ --dataset $DATASET \ - --structured-output-backend $GUIDED_BACKEND \ - --structured-output-ratio $GUIDED_RATIO \ + --structured-output-ratio $STRUCTURED_OUTPUT_RATIO \ --save-results \ - --result-dir $OUTPUT_DIR" + --result-dir $OUTPUT_DIR \ + --output-len $MAX_NEW_TOKENS \ + --port $PORT \ + --tokenizer-mode $TOKENIZER_MODE" echo "Starting structured output benchmark with model: $MODEL" echo "Backend: $BACKEND" echo "Dataset: $DATASET" -echo "Structured output backend: $GUIDED_BACKEND" echo "Results will be saved to: $OUTPUT_DIR" echo "----------------------------------------" @@ -48,14 +109,17 @@ for qps in "${QPS_VALUES[@]}"; do GIT_BRANCH=$(git rev-parse --abbrev-ref HEAD 2>/dev/null || echo "unknown") # Construct filename for this run - FILENAME="${GUIDED_BACKEND}_${BACKEND}_${qps}qps_$(basename $MODEL)_${DATASET}_${GIT_HASH}.json" + FILENAME="${BACKEND}_${qps}qps_$(basename $MODEL)_${DATASET}_${GIT_HASH}.json" + + NUM_PROMPTS=$(echo "$TOTAL_SECONDS * $qps" | bc) + NUM_PROMPTS=${NUM_PROMPTS%.*} # Remove fractional part + echo "Running benchmark with $NUM_PROMPTS prompts" # Run the benchmark python "$SCRIPT_DIR/benchmark_serving_structured_output.py" $COMMON_PARAMS \ --request-rate $qps \ --result-filename "$FILENAME" \ - --tokenizer-mode ${TOKENIZER_MODE:-"auto"} \ - --port ${PORT:-8000} + --num-prompts $NUM_PROMPTS echo "Completed benchmark with QPS: $qps" echo "----------------------------------------" diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake index 00670bd398b5d61bf8a62ba2a022ce665a83c92a..fb763db9fc359ef9dbe96b688fe914e36245011d 100644 --- a/cmake/cpu_extension.cmake +++ b/cmake/cpu_extension.cmake @@ -167,6 +167,33 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED) FetchContent_MakeAvailable(oneDNN) + list(APPEND LIBS dnnl) +elseif(POWER10_FOUND) + FetchContent_Declare( + oneDNN + GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git + GIT_TAG v3.7.2 + GIT_PROGRESS TRUE + GIT_SHALLOW TRUE + ) + + set(ONEDNN_LIBRARY_TYPE "STATIC") + set(ONEDNN_BUILD_DOC "OFF") + set(ONEDNN_BUILD_EXAMPLES "OFF") + set(ONEDNN_BUILD_TESTS "OFF") + set(ONEDNN_ENABLE_WORKLOAD "INFERENCE") + set(ONEDNN_ENABLE_PRIMITIVE "MATMUL;REORDER") + set(ONEDNN_BUILD_GRAPH "OFF") + set(ONEDNN_ENABLE_JIT_PROFILING "OFF") + set(ONEDNN_ENABLE_ITT_TASKS "OFF") + set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF") + set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF") + set(CMAKE_POLICY_DEFAULT_CMP0077 NEW) + + set(DNNL_CPU_RUNTIME "OMP") + + FetchContent_MakeAvailable(oneDNN) + list(APPEND LIBS dnnl) endif() @@ -197,6 +224,10 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED) "csrc/cpu/quant.cpp" "csrc/cpu/shm.cpp" ${VLLM_EXT_SRC}) +elseif(POWER10_FOUND) + set(VLLM_EXT_SRC + "csrc/cpu/quant.cpp" + ${VLLM_EXT_SRC}) endif() # @@ -214,4 +245,4 @@ define_gpu_extension_target( WITH_SOABI ) -message(STATUS "Enabling C extension.") \ No newline at end of file +message(STATUS "Enabling C extension.") diff --git a/cmake/utils.cmake b/cmake/utils.cmake index f18a95a96a076ba6198d024a74610c314e9913a3..375d254ba343ff9c0e968cb8c7984e1a56e0dd6a 100644 --- a/cmake/utils.cmake +++ b/cmake/utils.cmake @@ -229,11 +229,26 @@ macro(set_gencode_flags_for_srcs) "${multiValueArgs}" ${ARGN} ) foreach(_ARCH ${arg_CUDA_ARCHS}) - string(REPLACE "." "" _ARCH "${_ARCH}") - set_gencode_flag_for_srcs( - SRCS ${arg_SRCS} - ARCH "compute_${_ARCH}" - CODE "sm_${_ARCH}") + # handle +PTX suffix: generate both sm and ptx codes if requested + string(FIND "${_ARCH}" "+PTX" _HAS_PTX) + if(NOT _HAS_PTX EQUAL -1) + string(REPLACE "+PTX" "" _BASE_ARCH "${_ARCH}") + string(REPLACE "." "" _STRIPPED_ARCH "${_BASE_ARCH}") + set_gencode_flag_for_srcs( + SRCS ${arg_SRCS} + ARCH "compute_${_STRIPPED_ARCH}" + CODE "sm_${_STRIPPED_ARCH}") + set_gencode_flag_for_srcs( + SRCS ${arg_SRCS} + ARCH "compute_${_STRIPPED_ARCH}" + CODE "compute_${_STRIPPED_ARCH}") + else() + string(REPLACE "." "" _STRIPPED_ARCH "${_ARCH}") + set_gencode_flag_for_srcs( + SRCS ${arg_SRCS} + ARCH "compute_${_STRIPPED_ARCH}" + CODE "sm_${_STRIPPED_ARCH}") + endif() endforeach() if (${arg_BUILD_PTX_FOR_ARCH}) @@ -252,7 +267,10 @@ endmacro() # # For the given `SRC_CUDA_ARCHS` list of gencode versions in the form # `.[letter]` compute the "loose intersection" with the -# `TGT_CUDA_ARCHS` list of gencodes. +# `TGT_CUDA_ARCHS` list of gencodes. We also support the `+PTX` suffix in +# `SRC_CUDA_ARCHS` which indicates that the PTX code should be built when there +# is a CUDA_ARCH in `TGT_CUDA_ARCHS` that is equal to or larger than the +# architecture in `SRC_CUDA_ARCHS`. # The loose intersection is defined as: # { max{ x \in tgt | x <= y } | y \in src, { x \in tgt | x <= y } != {} } # where `<=` is the version comparison operator. @@ -269,44 +287,63 @@ endmacro() # cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS) # OUT_CUDA_ARCHS="8.0;8.6;9.0;9.0a" # +# Example With PTX: +# SRC_CUDA_ARCHS="8.0+PTX" +# TGT_CUDA_ARCHS="9.0" +# cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS) +# OUT_CUDA_ARCHS="8.0+PTX" +# function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS) - list(REMOVE_DUPLICATES SRC_CUDA_ARCHS) - set(TGT_CUDA_ARCHS_ ${TGT_CUDA_ARCHS}) + set(_SRC_CUDA_ARCHS "${SRC_CUDA_ARCHS}") + set(_TGT_CUDA_ARCHS ${TGT_CUDA_ARCHS}) + + # handle +PTX suffix: separate base arch for matching, record PTX requests + set(_PTX_ARCHS) + foreach(_arch ${_SRC_CUDA_ARCHS}) + if(_arch MATCHES "\\+PTX$") + string(REPLACE "+PTX" "" _base "${_arch}") + list(APPEND _PTX_ARCHS "${_base}") + list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_arch}") + list(APPEND _SRC_CUDA_ARCHS "${_base}") + endif() + endforeach() + list(REMOVE_DUPLICATES _PTX_ARCHS) + list(REMOVE_DUPLICATES _SRC_CUDA_ARCHS) # if x.0a is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should # remove x.0a from SRC_CUDA_ARCHS and add x.0a to _CUDA_ARCHS set(_CUDA_ARCHS) - if ("9.0a" IN_LIST SRC_CUDA_ARCHS) - list(REMOVE_ITEM SRC_CUDA_ARCHS "9.0a") - if ("9.0" IN_LIST TGT_CUDA_ARCHS_) - list(REMOVE_ITEM TGT_CUDA_ARCHS_ "9.0") + if ("9.0a" IN_LIST _SRC_CUDA_ARCHS) + list(REMOVE_ITEM _SRC_CUDA_ARCHS "9.0a") + if ("9.0" IN_LIST TGT_CUDA_ARCHS) + list(REMOVE_ITEM _TGT_CUDA_ARCHS "9.0") set(_CUDA_ARCHS "9.0a") endif() endif() - if ("10.0a" IN_LIST SRC_CUDA_ARCHS) - list(REMOVE_ITEM SRC_CUDA_ARCHS "10.0a") + if ("10.0a" IN_LIST _SRC_CUDA_ARCHS) + list(REMOVE_ITEM _SRC_CUDA_ARCHS "10.0a") if ("10.0" IN_LIST TGT_CUDA_ARCHS) - list(REMOVE_ITEM TGT_CUDA_ARCHS_ "10.0") + list(REMOVE_ITEM _TGT_CUDA_ARCHS "10.0") set(_CUDA_ARCHS "10.0a") endif() endif() - list(SORT SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING) + list(SORT _SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING) # for each ARCH in TGT_CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that # is less or equal to ARCH (but has the same major version since SASS binary # compatibility is only forward compatible within the same major version). - foreach(_ARCH ${TGT_CUDA_ARCHS_}) + foreach(_ARCH ${_TGT_CUDA_ARCHS}) set(_TMP_ARCH) # Extract the major version of the target arch string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" TGT_ARCH_MAJOR "${_ARCH}") - foreach(_SRC_ARCH ${SRC_CUDA_ARCHS}) + foreach(_SRC_ARCH ${_SRC_CUDA_ARCHS}) # Extract the major version of the source arch string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" SRC_ARCH_MAJOR "${_SRC_ARCH}") - # Check major-version match AND version-less-or-equal + # Check version-less-or-equal, and allow PTX arches to match across majors if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH) - if (SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR) + if (_SRC_ARCH IN_LIST _PTX_ARCHS OR SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR) set(_TMP_ARCH "${_SRC_ARCH}") endif() else() @@ -322,6 +359,18 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR endforeach() list(REMOVE_DUPLICATES _CUDA_ARCHS) + + # reapply +PTX suffix to architectures that requested PTX + set(_FINAL_ARCHS) + foreach(_arch ${_CUDA_ARCHS}) + if(_arch IN_LIST _PTX_ARCHS) + list(APPEND _FINAL_ARCHS "${_arch}+PTX") + else() + list(APPEND _FINAL_ARCHS "${_arch}") + endif() + endforeach() + set(_CUDA_ARCHS ${_FINAL_ARCHS}) + set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE) endfunction() diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index 88275dbdd83a10b1fd155cd8e96023036bb8d580..55e6596797010403c8f2d8cc4d2ebbcae1c75d7e 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -70,6 +70,9 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) { int64_t num_tokens = input.numel() / input.size(-1); \ dim3 grid(num_tokens); \ dim3 block(std::min(d, 1024)); \ + if (num_tokens == 0) { \ + return; \ + } \ const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ VLLM_DISPATCH_FLOATING_TYPES( \ diff --git a/csrc/attention/attention_kernels.cuh b/csrc/attention/attention_kernels.cuh index 57f00d6b050d3277a270b4ac488c4f49e12a0c94..72bd70474943e52785bb1839b92da42e74b64894 100644 --- a/csrc/attention/attention_kernels.cuh +++ b/csrc/attention/attention_kernels.cuh @@ -17,660 +17,660 @@ * limitations under the License. */ - #include - #include - #include - #include - - #include "attention_dtypes.h" - #include "attention_utils.cuh" - - #ifdef USE_ROCM - #include - #include "../quantization/fp8/amd/quant_utils.cuh" - typedef __hip_bfloat16 __nv_bfloat16; - #else - #include "../quantization/fp8/nvidia/quant_utils.cuh" - #endif - - #ifndef USE_ROCM - #define WARP_SIZE 32 - #else - #define WARP_SIZE warpSize - #endif - - #define MAX(a, b) ((a) > (b) ? (a) : (b)) - #define MIN(a, b) ((a) < (b) ? (a) : (b)) - #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) - - namespace vllm { - - // Utility function for attention softmax. - template - inline __device__ float block_sum(float* red_smem, float sum) { - // Decompose the thread index into warp / lane. - int warp = threadIdx.x / WARP_SIZE; - int lane = threadIdx.x % WARP_SIZE; - - // Compute the sum per warp. - #pragma unroll - for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { - sum += VLLM_SHFL_XOR_SYNC(sum, mask); - } - - // Warp leaders store the data to shared memory. - if (lane == 0) { - red_smem[warp] = sum; - } - - // Make sure the data is in shared memory. - __syncthreads(); - - // The warps compute the final sums. - if (lane < NUM_WARPS) { - sum = red_smem[lane]; - } - - // Parallel reduction inside the warp. - #pragma unroll - for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { - sum += VLLM_SHFL_XOR_SYNC(sum, mask); - } - - // Broadcast to other threads. - return VLLM_SHFL_SYNC(sum, 0); - } - - // TODO(woosuk): Merge the last two dimensions of the grid. - // Grid: (num_heads, num_seqs, max_num_partitions). - template // Zero means no partitioning. - __device__ void paged_attention_kernel( - float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - float* __restrict__ max_logits, // [num_seqs, num_heads, - // max_num_partitions] - scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, - // head_size] - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, - // head_size/x, block_size, x] - const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, - // head_size, block_size] - const int num_kv_heads, // [num_heads] - const float scale, - const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ seq_lens, // [num_seqs] - const int max_num_blocks_per_seq, - const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, const int kv_block_stride, const int kv_head_stride, - const float* k_scale, const float* v_scale, const int tp_rank, - const int blocksparse_local_blocks, const int blocksparse_vert_stride, - const int blocksparse_block_size, const int blocksparse_head_sliding_step) { - const int seq_idx = blockIdx.y; - const int partition_idx = blockIdx.z; - const int max_num_partitions = gridDim.z; - constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0; - const int seq_len = seq_lens[seq_idx]; - if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) { - // No work to do. Terminate the thread block. - return; - } - - const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE); - const int num_blocks_per_partition = - USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks; - - // [start_block_idx, end_block_idx) is the range of blocks to process. - const int start_block_idx = - USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0; - const int end_block_idx = - MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks); - const int num_blocks = end_block_idx - start_block_idx; - - // [start_token_idx, end_token_idx) is the range of tokens to process. - const int start_token_idx = start_block_idx * BLOCK_SIZE; - const int end_token_idx = - MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len); - const int num_tokens = end_token_idx - start_token_idx; - - constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); - constexpr int NUM_THREAD_GROUPS = - NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE - // divides NUM_THREADS - assert(NUM_THREADS % THREAD_GROUP_SIZE == 0); - constexpr int NUM_TOKENS_PER_THREAD_GROUP = - DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE); - constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - const int thread_idx = threadIdx.x; - const int warp_idx = thread_idx / WARP_SIZE; - const int lane = thread_idx % WARP_SIZE; - - const int head_idx = blockIdx.x; - const int num_heads = gridDim.x; - const int num_queries_per_kv = num_heads / num_kv_heads; - const int kv_head_idx = head_idx / num_queries_per_kv; - const float alibi_slope = - alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; - - // A vector type to store a part of a key or a query. - // The vector size is configured in such a way that the threads in a thread - // group fetch or compute 16 bytes at a time. For example, if the size of a - // thread group is 4 and the data type is half, then the vector size is 16 / - // (4 * sizeof(half)) == 2. - constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1); - using K_vec = typename Vec::Type; - using Q_vec = typename Vec::Type; - using Quant_vec = typename Vec::Type; - - constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE; - constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; - - const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE; - const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE; - - // Load the query to registers. - // Each thread in a thread group has a different part of the query. - // For example, if the the thread group size is 4, then the first thread in - // the group has 0, 4, 8, ... th vectors of the query, and the second thread - // has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because - // q is split from a qkv tensor, it may not be contiguous. - const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; - __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD]; - #pragma unroll - for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; - i += NUM_THREAD_GROUPS) { - const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; - q_vecs[thread_group_offset][i] = - *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); - } - __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a - // memory wall right before we use q_vecs - - // Memory planning. - extern __shared__ char shared_mem[]; - // NOTE(woosuk): We use FP32 for the softmax logits for better accuracy. - float* logits = reinterpret_cast(shared_mem); - // Workspace for reduction. - __shared__ float red_smem[2 * NUM_WARPS]; - - // x == THREAD_GROUP_SIZE * VEC_SIZE - // Each thread group fetches x elements from the key at a time. - constexpr int x = 16 / sizeof(cache_t); - float qk_max = -FLT_MAX; - - // Iterate over the key blocks. - // Each warp fetches a block of keys for each iteration. - // Each thread group in a warp fetches a key from the block, and computes - // dot product with the query. - const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; - - // blocksparse specific vars - int bs_block_offset; - int q_bs_block_id; - if constexpr (IS_BLOCK_SPARSE) { - // const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len, - // blocksparse_block_size); - q_bs_block_id = (seq_len - 1) / blocksparse_block_size; - if (blocksparse_head_sliding_step >= 0) - // sliding on q heads - bs_block_offset = - (tp_rank * num_heads + head_idx) * blocksparse_head_sliding_step + 1; - else - // sliding on kv heads - bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) * - (-blocksparse_head_sliding_step) + - 1; - } - - for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; - block_idx += NUM_WARPS) { - // NOTE(woosuk): The block number is stored in int32. However, we cast it to - // int64 because int32 can lead to overflow when this variable is multiplied - // by large numbers (e.g., kv_block_stride). - // For blocksparse attention: skip computation on blocks that are not - // attended - if constexpr (IS_BLOCK_SPARSE) { - const int k_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size; - const bool is_remote = - ((k_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0); - const bool is_local = - (k_bs_block_id > q_bs_block_id - blocksparse_local_blocks); - if (!is_remote && !is_local) { - for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { - const int physical_block_offset = - (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; - const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; - - if (thread_group_offset == 0) { - // NOTE(linxihui): assign very large number to skipped tokens to - // avoid contribution to the sumexp softmax normalizer. This will - // not be used at computing sum(softmax*v) as the blocks will be - // skipped. - logits[token_idx - start_token_idx] = -FLT_MAX; - } - } - continue; - } - } - const int64_t physical_block_number = - static_cast(block_table[block_idx]); - - // Load a key to registers. - // Each thread in a thread group has a different part of the key. - // For example, if the the thread group size is 4, then the first thread in - // the group has 0, 4, 8, ... th vectors of the key, and the second thread - // has 1, 5, 9, ... th vectors of the key, and so on. - for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { - const int physical_block_offset = - (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; - const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; - K_vec k_vecs[NUM_VECS_PER_THREAD]; - - #pragma unroll - for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { - const cache_t* k_ptr = - k_cache + physical_block_number * kv_block_stride + - kv_head_idx * kv_head_stride + physical_block_offset * x; - const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; - const int offset1 = (vec_idx * VEC_SIZE) / x; - const int offset2 = (vec_idx * VEC_SIZE) % x; - - if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { - k_vecs[j] = *reinterpret_cast( - k_ptr + offset1 * BLOCK_SIZE * x + offset2); - } else { - // Vector conversion from Quant_vec to K_vec. - Quant_vec k_vec_quant = *reinterpret_cast( - k_ptr + offset1 * BLOCK_SIZE * x + offset2); - k_vecs[j] = fp8::scaled_convert( - k_vec_quant, *k_scale); - } - } - - // Compute dot product. - // This includes a reduction across the threads in the same thread group. - float qk = scale * Qk_dot::dot( - q_vecs[thread_group_offset], k_vecs); - // Add the ALiBi bias if slopes are given. - qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0; - - if (thread_group_offset == 0) { - // Store the partial reductions to shared memory. - // NOTE(woosuk): It is required to zero out the masked logits. - const bool mask = token_idx >= seq_len; - logits[token_idx - start_token_idx] = mask ? 0.f : qk; - // Update the max value. - qk_max = mask ? qk_max : fmaxf(qk_max, qk); - } - } - } - - // Perform reduction across the threads in the same warp to get the - // max qk value for each "warp" (not across the thread block yet). - // The 0-th thread of each thread group already has its max qk value. - #pragma unroll - for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { - qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); - } - if (lane == 0) { - red_smem[warp_idx] = qk_max; - } - __syncthreads(); - - // TODO(woosuk): Refactor this part. - // Get the max qk value for the sequence. - qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; - #pragma unroll - for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { - qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); - } - // Broadcast the max qk value to all threads. - qk_max = VLLM_SHFL_SYNC(qk_max, 0); - - // Get the sum of the exp values. - float exp_sum = 0.f; - for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { - float val = __expf(logits[i] - qk_max); - logits[i] = val; - exp_sum += val; - } - exp_sum = block_sum(&red_smem[NUM_WARPS], exp_sum); - - // Compute softmax. - const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); - for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { - logits[i] *= inv_sum; - } - __syncthreads(); - - // If partitioning is enabled, store the max logit and exp_sum. - if (USE_PARTITIONING && thread_idx == 0) { - float* max_logits_ptr = max_logits + - seq_idx * num_heads * max_num_partitions + - head_idx * max_num_partitions + partition_idx; - *max_logits_ptr = qk_max; - float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions + - head_idx * max_num_partitions + partition_idx; - *exp_sums_ptr = exp_sum; - } - - // Each thread will fetch 16 bytes from the value cache at a time. - constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE); - using V_vec = typename Vec::Type; - using L_vec = typename Vec::Type; - using V_quant_vec = typename Vec::Type; - using Float_L_vec = typename FloatVec::Type; - - constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; - constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; - constexpr int NUM_ROWS_PER_THREAD = - DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER); - - // NOTE(woosuk): We use FP32 for the accumulator for better accuracy. - float accs[NUM_ROWS_PER_THREAD]; - #pragma unroll - for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { - accs[i] = 0.f; - } - - scalar_t zero_value; - zero(zero_value); - for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; - block_idx += NUM_WARPS) { - // NOTE(woosuk): The block number is stored in int32. However, we cast it to - // int64 because int32 can lead to overflow when this variable is multiplied - // by large numbers (e.g., kv_block_stride). - // For blocksparse attention: skip computation on blocks that are not - // attended - if constexpr (IS_BLOCK_SPARSE) { - int v_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size; - if (!((v_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0) && - !((v_bs_block_id > q_bs_block_id - blocksparse_local_blocks))) { - continue; - } - } - const int64_t physical_block_number = - static_cast(block_table[block_idx]); - const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; - const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; - L_vec logits_vec; - from_float(logits_vec, *reinterpret_cast(logits + token_idx - - start_token_idx)); - - const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride + - kv_head_idx * kv_head_stride; - #pragma unroll - for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { - const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; - if (row_idx < HEAD_SIZE) { - const int offset = row_idx * BLOCK_SIZE + physical_block_offset; - V_vec v_vec; - - if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { - v_vec = *reinterpret_cast(v_ptr + offset); - } else { - V_quant_vec v_quant_vec = - *reinterpret_cast(v_ptr + offset); - // Vector conversion from V_quant_vec to V_vec. - v_vec = fp8::scaled_convert(v_quant_vec, - *v_scale); - } - if (block_idx == num_seq_blocks - 1) { - // NOTE(woosuk): When v_vec contains the tokens that are out of the - // context, we should explicitly zero out the values since they may - // contain NaNs. See - // https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 - scalar_t* v_vec_ptr = reinterpret_cast(&v_vec); - #pragma unroll - for (int j = 0; j < V_VEC_SIZE; j++) { - v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value; - } - } - accs[i] += dot(logits_vec, v_vec); - } - } - } - - // Perform reduction within each warp. - #pragma unroll - for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { - float acc = accs[i]; - #pragma unroll - for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { - acc += VLLM_SHFL_XOR_SYNC(acc, mask); - } - accs[i] = acc; - } - - // NOTE(woosuk): A barrier is required because the shared memory space for - // logits is reused for the output. - __syncthreads(); - - // Perform reduction across warps. - float* out_smem = reinterpret_cast(shared_mem); - #pragma unroll - for (int i = NUM_WARPS; i > 1; i /= 2) { - int mid = i / 2; - // Upper warps write to shared memory. - if (warp_idx >= mid && warp_idx < i) { - float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE]; - #pragma unroll - for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { - const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; - if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { - dst[row_idx] = accs[i]; - } - } - } - __syncthreads(); - - // Lower warps update the output. - if (warp_idx < mid) { - const float* src = &out_smem[warp_idx * HEAD_SIZE]; - #pragma unroll - for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { - const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; - if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { - accs[i] += src[row_idx]; - } - } - } - __syncthreads(); - } - - // Write the final output. - if (warp_idx == 0) { - scalar_t* out_ptr = - out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + - head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE; - #pragma unroll - for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { - const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; - if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { - from_float(*(out_ptr + row_idx), accs[i]); - } - } - } - } - - // Grid: (num_heads, num_seqs, 1). - template - __global__ void paged_attention_v1_kernel( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, - // head_size/x, block_size, x] - const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, - // head_size, block_size] - const int num_kv_heads, // [num_heads] - const float scale, - const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ seq_lens, // [num_seqs] - const int max_num_blocks_per_seq, - const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, const int kv_block_stride, const int kv_head_stride, - const float* k_scale, const float* v_scale, const int tp_rank, - const int blocksparse_local_blocks, const int blocksparse_vert_stride, - const int blocksparse_block_size, const int blocksparse_head_sliding_step) { - paged_attention_kernel( - /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache, - v_cache, num_kv_heads, scale, block_tables, seq_lens, - max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, - kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks, - blocksparse_vert_stride, blocksparse_block_size, - blocksparse_head_sliding_step); - } - - // Grid: (num_heads, num_seqs, max_num_partitions). - template - __global__ void paged_attention_v2_kernel( - float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - float* __restrict__ max_logits, // [num_seqs, num_heads, +#include +#include +#include +#include + +#include "attention_dtypes.h" +#include "attention_utils.cuh" + +#ifdef USE_ROCM + #include + #include "../quantization/fp8/amd/quant_utils.cuh" +typedef __hip_bfloat16 __nv_bfloat16; +#else + #include "../quantization/fp8/nvidia/quant_utils.cuh" +#endif + +#ifndef USE_ROCM + #define WARP_SIZE 32 +#else + #define WARP_SIZE warpSize +#endif + +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) + +namespace vllm { + +// Utility function for attention softmax. +template +inline __device__ float block_sum(float* red_smem, float sum) { + // Decompose the thread index into warp / lane. + int warp = threadIdx.x / WARP_SIZE; + int lane = threadIdx.x % WARP_SIZE; + + // Compute the sum per warp. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + sum += VLLM_SHFL_XOR_SYNC(sum, mask); + } + + // Warp leaders store the data to shared memory. + if (lane == 0) { + red_smem[warp] = sum; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // The warps compute the final sums. + if (lane < NUM_WARPS) { + sum = red_smem[lane]; + } + + // Parallel reduction inside the warp. +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + sum += VLLM_SHFL_XOR_SYNC(sum, mask); + } + + // Broadcast to other threads. + return VLLM_SHFL_SYNC(sum, 0); +} + +// TODO(woosuk): Merge the last two dimensions of the grid. +// Grid: (num_heads, num_seqs, max_num_partitions). +template // Zero means no partitioning. +__device__ void paged_attention_kernel( + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, // [num_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const float* k_scale, const float* v_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { + const int seq_idx = blockIdx.y; + const int partition_idx = blockIdx.z; + const int max_num_partitions = gridDim.z; + constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0; + const int seq_len = seq_lens[seq_idx]; + if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) { + // No work to do. Terminate the thread block. + return; + } + + const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE); + const int num_blocks_per_partition = + USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks; + + // [start_block_idx, end_block_idx) is the range of blocks to process. + const int start_block_idx = + USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0; + const int end_block_idx = + MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks); + const int num_blocks = end_block_idx - start_block_idx; + + // [start_token_idx, end_token_idx) is the range of tokens to process. + const int start_token_idx = start_block_idx * BLOCK_SIZE; + const int end_token_idx = + MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len); + const int num_tokens = end_token_idx - start_token_idx; + + constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); + constexpr int NUM_THREAD_GROUPS = + NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE + // divides NUM_THREADS + assert(NUM_THREADS % THREAD_GROUP_SIZE == 0); + constexpr int NUM_TOKENS_PER_THREAD_GROUP = + DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE); + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + const int thread_idx = threadIdx.x; + const int warp_idx = thread_idx / WARP_SIZE; + const int lane = thread_idx % WARP_SIZE; + + const int head_idx = blockIdx.x; + const int num_heads = gridDim.x; + const int num_queries_per_kv = num_heads / num_kv_heads; + const int kv_head_idx = head_idx / num_queries_per_kv; + const float alibi_slope = + alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; + + // A vector type to store a part of a key or a query. + // The vector size is configured in such a way that the threads in a thread + // group fetch or compute 16 bytes at a time. For example, if the size of a + // thread group is 4 and the data type is half, then the vector size is 16 / + // (4 * sizeof(half)) == 2. + constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1); + using K_vec = typename Vec::Type; + using Q_vec = typename Vec::Type; + using Quant_vec = typename Vec::Type; + + constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE; + constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; + + const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE; + const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE; + + // Load the query to registers. + // Each thread in a thread group has a different part of the query. + // For example, if the thread group size is 4, then the first thread in + // the group has 0, 4, 8, ... th vectors of the query, and the second thread + // has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because + // q is split from a qkv tensor, it may not be contiguous. + const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; + __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD]; +#pragma unroll + for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; + i += NUM_THREAD_GROUPS) { + const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; + q_vecs[thread_group_offset][i] = + *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); + } + __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a + // memory wall right before we use q_vecs + + // Memory planning. + extern __shared__ char shared_mem[]; + // NOTE(woosuk): We use FP32 for the softmax logits for better accuracy. + float* logits = reinterpret_cast(shared_mem); + // Workspace for reduction. + __shared__ float red_smem[2 * NUM_WARPS]; + + // x == THREAD_GROUP_SIZE * VEC_SIZE + // Each thread group fetches x elements from the key at a time. + constexpr int x = 16 / sizeof(cache_t); + float qk_max = -FLT_MAX; + + // Iterate over the key blocks. + // Each warp fetches a block of keys for each iteration. + // Each thread group in a warp fetches a key from the block, and computes + // dot product with the query. + const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; + + // blocksparse specific vars + int bs_block_offset; + int q_bs_block_id; + if constexpr (IS_BLOCK_SPARSE) { + // const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len, + // blocksparse_block_size); + q_bs_block_id = (seq_len - 1) / blocksparse_block_size; + if (blocksparse_head_sliding_step >= 0) + // sliding on q heads + bs_block_offset = + (tp_rank * num_heads + head_idx) * blocksparse_head_sliding_step + 1; + else + // sliding on kv heads + bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) * + (-blocksparse_head_sliding_step) + + 1; + } + + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; + block_idx += NUM_WARPS) { + // NOTE(woosuk): The block number is stored in int32. However, we cast it to + // int64 because int32 can lead to overflow when this variable is multiplied + // by large numbers (e.g., kv_block_stride). + // For blocksparse attention: skip computation on blocks that are not + // attended + if constexpr (IS_BLOCK_SPARSE) { + const int k_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size; + const bool is_remote = + ((k_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0); + const bool is_local = + (k_bs_block_id > q_bs_block_id - blocksparse_local_blocks); + if (!is_remote && !is_local) { + for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { + const int physical_block_offset = + (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + + if (thread_group_offset == 0) { + // NOTE(linxihui): assign very large number to skipped tokens to + // avoid contribution to the sumexp softmax normalizer. This will + // not be used at computing sum(softmax*v) as the blocks will be + // skipped. + logits[token_idx - start_token_idx] = -FLT_MAX; + } + } + continue; + } + } + const int64_t physical_block_number = + static_cast(block_table[block_idx]); + + // Load a key to registers. + // Each thread in a thread group has a different part of the key. + // For example, if the thread group size is 4, then the first thread in + // the group has 0, 4, 8, ... th vectors of the key, and the second thread + // has 1, 5, 9, ... th vectors of the key, and so on. + for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { + const int physical_block_offset = + (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + K_vec k_vecs[NUM_VECS_PER_THREAD]; + +#pragma unroll + for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { + const cache_t* k_ptr = + k_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride + physical_block_offset * x; + const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; + const int offset1 = (vec_idx * VEC_SIZE) / x; + const int offset2 = (vec_idx * VEC_SIZE) % x; + + if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { + k_vecs[j] = *reinterpret_cast( + k_ptr + offset1 * BLOCK_SIZE * x + offset2); + } else { + // Vector conversion from Quant_vec to K_vec. + Quant_vec k_vec_quant = *reinterpret_cast( + k_ptr + offset1 * BLOCK_SIZE * x + offset2); + k_vecs[j] = fp8::scaled_convert( + k_vec_quant, *k_scale); + } + } + + // Compute dot product. + // This includes a reduction across the threads in the same thread group. + float qk = scale * Qk_dot::dot( + q_vecs[thread_group_offset], k_vecs); + // Add the ALiBi bias if slopes are given. + qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0; + + if (thread_group_offset == 0) { + // Store the partial reductions to shared memory. + // NOTE(woosuk): It is required to zero out the masked logits. + const bool mask = token_idx >= seq_len; + logits[token_idx - start_token_idx] = mask ? 0.f : qk; + // Update the max value. + qk_max = mask ? qk_max : fmaxf(qk_max, qk); + } + } + } + + // Perform reduction across the threads in the same warp to get the + // max qk value for each "warp" (not across the thread block yet). + // The 0-th thread of each thread group already has its max qk value. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { + qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); + } + if (lane == 0) { + red_smem[warp_idx] = qk_max; + } + __syncthreads(); + + // TODO(woosuk): Refactor this part. + // Get the max qk value for the sequence. + qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); + } + // Broadcast the max qk value to all threads. + qk_max = VLLM_SHFL_SYNC(qk_max, 0); + + // Get the sum of the exp values. + float exp_sum = 0.f; + for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { + float val = __expf(logits[i] - qk_max); + logits[i] = val; + exp_sum += val; + } + exp_sum = block_sum(&red_smem[NUM_WARPS], exp_sum); + + // Compute softmax. + const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); + for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { + logits[i] *= inv_sum; + } + __syncthreads(); + + // If partitioning is enabled, store the max logit and exp_sum. + if (USE_PARTITIONING && thread_idx == 0) { + float* max_logits_ptr = max_logits + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + partition_idx; + *max_logits_ptr = qk_max; + float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + partition_idx; + *exp_sums_ptr = exp_sum; + } + + // Each thread will fetch 16 bytes from the value cache at a time. + constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE); + using V_vec = typename Vec::Type; + using L_vec = typename Vec::Type; + using V_quant_vec = typename Vec::Type; + using Float_L_vec = typename FloatVec::Type; + + constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; + constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; + constexpr int NUM_ROWS_PER_THREAD = + DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER); + + // NOTE(woosuk): We use FP32 for the accumulator for better accuracy. + float accs[NUM_ROWS_PER_THREAD]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + accs[i] = 0.f; + } + + scalar_t zero_value; + zero(zero_value); + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; + block_idx += NUM_WARPS) { + // NOTE(woosuk): The block number is stored in int32. However, we cast it to + // int64 because int32 can lead to overflow when this variable is multiplied + // by large numbers (e.g., kv_block_stride). + // For blocksparse attention: skip computation on blocks that are not + // attended + if constexpr (IS_BLOCK_SPARSE) { + int v_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size; + if (!((v_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0) && + !((v_bs_block_id > q_bs_block_id - blocksparse_local_blocks))) { + continue; + } + } + const int64_t physical_block_number = + static_cast(block_table[block_idx]); + const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + L_vec logits_vec; + from_float(logits_vec, *reinterpret_cast(logits + token_idx - + start_token_idx)); + + const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE) { + const int offset = row_idx * BLOCK_SIZE + physical_block_offset; + V_vec v_vec; + + if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { + v_vec = *reinterpret_cast(v_ptr + offset); + } else { + V_quant_vec v_quant_vec = + *reinterpret_cast(v_ptr + offset); + // Vector conversion from V_quant_vec to V_vec. + v_vec = fp8::scaled_convert(v_quant_vec, + *v_scale); + } + if (block_idx == num_seq_blocks - 1) { + // NOTE(woosuk): When v_vec contains the tokens that are out of the + // context, we should explicitly zero out the values since they may + // contain NaNs. See + // https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 + scalar_t* v_vec_ptr = reinterpret_cast(&v_vec); +#pragma unroll + for (int j = 0; j < V_VEC_SIZE; j++) { + v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value; + } + } + accs[i] += dot(logits_vec, v_vec); + } + } + } + + // Perform reduction within each warp. +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + float acc = accs[i]; +#pragma unroll + for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { + acc += VLLM_SHFL_XOR_SYNC(acc, mask); + } + accs[i] = acc; + } + + // NOTE(woosuk): A barrier is required because the shared memory space for + // logits is reused for the output. + __syncthreads(); + + // Perform reduction across warps. + float* out_smem = reinterpret_cast(shared_mem); +#pragma unroll + for (int i = NUM_WARPS; i > 1; i /= 2) { + int mid = i / 2; + // Upper warps write to shared memory. + if (warp_idx >= mid && warp_idx < i) { + float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + dst[row_idx] = accs[i]; + } + } + } + __syncthreads(); + + // Lower warps update the output. + if (warp_idx < mid) { + const float* src = &out_smem[warp_idx * HEAD_SIZE]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + accs[i] += src[row_idx]; + } + } + } + __syncthreads(); + } + + // Write the final output. + if (warp_idx == 0) { + scalar_t* out_ptr = + out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + from_float(*(out_ptr + row_idx), accs[i]); + } + } + } +} + +// Grid: (num_heads, num_seqs, 1). +template +__global__ void paged_attention_v1_kernel( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, // [num_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const float* k_scale, const float* v_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { + paged_attention_kernel( + /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache, + v_cache, num_kv_heads, scale, block_tables, seq_lens, + max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, + kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks, + blocksparse_vert_stride, blocksparse_block_size, + blocksparse_head_sliding_step); +} + +// Grid: (num_heads, num_seqs, max_num_partitions). +template +__global__ void paged_attention_v2_kernel( + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, + // max_num_partitions, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, // [num_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const float* k_scale, const float* v_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { + paged_attention_kernel( + exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, + block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, + kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, + blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, + blocksparse_head_sliding_step); +} + +// Grid: (num_heads, num_seqs). +template +__global__ void paged_attention_v2_reduce_kernel( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const float* __restrict__ exp_sums, // [num_seqs, num_heads, // max_num_partitions] - scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, + const float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, // max_num_partitions, head_size] - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, - // head_size/x, block_size, x] - const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, - // head_size, block_size] - const int num_kv_heads, // [num_heads] - const float scale, - const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ seq_lens, // [num_seqs] - const int max_num_blocks_per_seq, - const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, const int kv_block_stride, const int kv_head_stride, - const float* k_scale, const float* v_scale, const int tp_rank, - const int blocksparse_local_blocks, const int blocksparse_vert_stride, - const int blocksparse_block_size, const int blocksparse_head_sliding_step) { - paged_attention_kernel( - exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, - block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, - kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, - blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, - blocksparse_head_sliding_step); - } - - // Grid: (num_heads, num_seqs). - template - __global__ void paged_attention_v2_reduce_kernel( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] - const float* __restrict__ exp_sums, // [num_seqs, num_heads, - // max_num_partitions] - const float* __restrict__ max_logits, // [num_seqs, num_heads, - // max_num_partitions] - const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, - // max_num_partitions, head_size] - const int* __restrict__ seq_lens, // [num_seqs] - const int max_num_partitions) { - const int num_heads = gridDim.x; - const int head_idx = blockIdx.x; - const int seq_idx = blockIdx.y; - const int seq_len = seq_lens[seq_idx]; - const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE); - if (num_partitions == 1) { - // No need to reduce. Only copy tmp_out to out. - scalar_t* out_ptr = - out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; - const scalar_t* tmp_out_ptr = - tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + - head_idx * max_num_partitions * HEAD_SIZE; - for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) { - out_ptr[i] = tmp_out_ptr[i]; - } - // Terminate the thread block. - return; - } - - constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - const int warp_idx = threadIdx.x / WARP_SIZE; - const int lane = threadIdx.x % WARP_SIZE; - - // Size: 2 * num_partitions. - extern __shared__ char shared_mem[]; - // Workspace for reduction. - __shared__ float red_smem[2 * NUM_WARPS]; - - // Load max logits to shared memory. - float* shared_max_logits = reinterpret_cast(shared_mem); - const float* max_logits_ptr = max_logits + - seq_idx * num_heads * max_num_partitions + - head_idx * max_num_partitions; - float max_logit = -FLT_MAX; - for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { - const float l = max_logits_ptr[i]; - shared_max_logits[i] = l; - max_logit = fmaxf(max_logit, l); - } - __syncthreads(); - - // Get the global max logit. - // Reduce within the warp. - #pragma unroll - for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { - max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask)); - } - if (lane == 0) { - red_smem[warp_idx] = max_logit; - } - __syncthreads(); - // Reduce across warps. - max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; - #pragma unroll - for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { - max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask)); - } - // Broadcast the max value to all threads. - max_logit = VLLM_SHFL_SYNC(max_logit, 0); - - // Load rescaled exp sums to shared memory. - float* shared_exp_sums = - reinterpret_cast(shared_mem + sizeof(float) * num_partitions); - const float* exp_sums_ptr = exp_sums + - seq_idx * num_heads * max_num_partitions + - head_idx * max_num_partitions; - float global_exp_sum = 0.0f; - for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { - float l = shared_max_logits[i]; - float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit); - global_exp_sum += rescaled_exp_sum; - shared_exp_sums[i] = rescaled_exp_sum; - } - __syncthreads(); - global_exp_sum = block_sum(&red_smem[NUM_WARPS], global_exp_sum); - const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f); - - // Aggregate tmp_out to out. - const scalar_t* tmp_out_ptr = - tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + - head_idx * max_num_partitions * HEAD_SIZE; - scalar_t* out_ptr = - out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; - #pragma unroll - for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) { - float acc = 0.0f; - for (int j = 0; j < num_partitions; ++j) { - acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * - inv_global_exp_sum; - } - from_float(out_ptr[i], acc); - } - } - - } // namespace vllm - - #undef WARP_SIZE - #undef MAX - #undef MIN - #undef DIVIDE_ROUND_UP \ No newline at end of file + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_partitions) { + const int num_heads = gridDim.x; + const int head_idx = blockIdx.x; + const int seq_idx = blockIdx.y; + const int seq_len = seq_lens[seq_idx]; + const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE); + if (num_partitions == 1) { + // No need to reduce. Only copy tmp_out to out. + scalar_t* out_ptr = + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + const scalar_t* tmp_out_ptr = + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE; + for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) { + out_ptr[i] = tmp_out_ptr[i]; + } + // Terminate the thread block. + return; + } + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + const int warp_idx = threadIdx.x / WARP_SIZE; + const int lane = threadIdx.x % WARP_SIZE; + + // Size: 2 * num_partitions. + extern __shared__ char shared_mem[]; + // Workspace for reduction. + __shared__ float red_smem[2 * NUM_WARPS]; + + // Load max logits to shared memory. + float* shared_max_logits = reinterpret_cast(shared_mem); + const float* max_logits_ptr = max_logits + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + float max_logit = -FLT_MAX; + for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { + const float l = max_logits_ptr[i]; + shared_max_logits[i] = l; + max_logit = fmaxf(max_logit, l); + } + __syncthreads(); + + // Get the global max logit. + // Reduce within the warp. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask)); + } + if (lane == 0) { + red_smem[warp_idx] = max_logit; + } + __syncthreads(); + // Reduce across warps. + max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask)); + } + // Broadcast the max value to all threads. + max_logit = VLLM_SHFL_SYNC(max_logit, 0); + + // Load rescaled exp sums to shared memory. + float* shared_exp_sums = + reinterpret_cast(shared_mem + sizeof(float) * num_partitions); + const float* exp_sums_ptr = exp_sums + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + float global_exp_sum = 0.0f; + for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { + float l = shared_max_logits[i]; + float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit); + global_exp_sum += rescaled_exp_sum; + shared_exp_sums[i] = rescaled_exp_sum; + } + __syncthreads(); + global_exp_sum = block_sum(&red_smem[NUM_WARPS], global_exp_sum); + const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f); + + // Aggregate tmp_out to out. + const scalar_t* tmp_out_ptr = + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE; + scalar_t* out_ptr = + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; +#pragma unroll + for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) { + float acc = 0.0f; + for (int j = 0; j < num_partitions; ++j) { + acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * + inv_global_exp_sum; + } + from_float(out_ptr[i], acc); + } +} + +} // namespace vllm + +#undef WARP_SIZE +#undef MAX +#undef MIN +#undef DIVIDE_ROUND_UP \ No newline at end of file diff --git a/csrc/attention/vertical_slash_index.cu b/csrc/attention/vertical_slash_index.cu new file mode 100644 index 0000000000000000000000000000000000000000..c1b45b143f4e1ad11548ecd981572257482694a7 --- /dev/null +++ b/csrc/attention/vertical_slash_index.cu @@ -0,0 +1,401 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#include + +#include + +#include + +__device__ int64_t save_blocks(int* block_offset, int64_t range_start, + int64_t range_end, int64_t block_size, + int64_t input_block_count, int64_t kv_seqlen) { + if (range_start >= kv_seqlen) { + return input_block_count; + } + if (range_end > kv_seqlen) { + range_end = kv_seqlen; + } + int64_t current_block_count = input_block_count; + for (int idx = range_start; idx < range_end; idx += block_size) { + block_offset[current_block_count++] = idx; + } + return current_block_count; +} + +__global__ void convert_vertical_slash_indexes_kernel( + const int* q_seqlens, // [BATCH, ] + const int* kv_seqlens, // [BATCH, ] + const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S] + int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S] + int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V] + int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N, + int64_t NNZ_V, int64_t NNZ_S, + bool causal // True for intra, False for succ +) { + const int batch_idx = blockIdx.y; + const int head_idx = blockIdx.x; + const int group_idx = blockIdx.z; + + int64_t q_seqlen = q_seqlens[batch_idx]; + int64_t kv_seqlen = kv_seqlens[batch_idx]; + int64_t block_idx_m = group_idx * blockDim.x + threadIdx.x; + int64_t start_m = block_idx_m * BLOCK_SIZE_M; + if (start_m >= q_seqlen) { + return; + } + int64_t end_m = start_m + BLOCK_SIZE_M; + vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V; + slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S; + int64_t row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m; + block_count += row_offset; + block_offset += row_offset * NNZ_S; + column_count += row_offset; + column_index += row_offset * NNZ_V; + + bool has_slash = true; + int64_t tmp_col_cnt = 0, tmp_blk_cnt = 0; + int64_t s = 0, v = 0; + int64_t v_idx = vertical_indexes[v++]; + int64_t s_idx = slash_indexes[s++]; + if (causal) { + while (s_idx >= end_m + (kv_seqlen - q_seqlen) && s < NNZ_S) { + s_idx = slash_indexes[s++]; + } + if (s_idx > end_m + (kv_seqlen - q_seqlen)) has_slash = false; + s_idx = max((kv_seqlen - q_seqlen) + end_m - s_idx, BLOCK_SIZE_M); + } else { + while (s_idx >= end_m + kv_seqlen && s < NNZ_S) { + s_idx = slash_indexes[s++]; + } + if (s_idx > end_m + kv_seqlen) has_slash = false; + s_idx = max(kv_seqlen + end_m - s_idx, BLOCK_SIZE_M); + } + + int64_t range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx; + if (!has_slash) { + if (causal) { + range_start = (kv_seqlen - q_seqlen) + end_m; + range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N; + } else { + range_start = kv_seqlen; + range_end = kv_seqlen + BLOCK_SIZE_N; + } + } + + bool slash_finished = false; + while (1) { + if (v_idx < range_end) { + if (v_idx < range_start) { + column_index[tmp_col_cnt++] = v_idx; + } + if (v < NNZ_V) { + v_idx = vertical_indexes[v++]; + } else { + if (causal) + v_idx = end_m + BLOCK_SIZE_N + (kv_seqlen - q_seqlen); + else + v_idx = end_m + BLOCK_SIZE_N + kv_seqlen; + } + } else { + if ((s < NNZ_S && causal) || + (s < NNZ_S && !causal && slash_indexes[s] >= start_m)) { + if (causal) + s_idx = max((kv_seqlen - q_seqlen) + end_m - slash_indexes[s++], + BLOCK_SIZE_M); + else + s_idx = max(kv_seqlen + end_m - slash_indexes[s++], BLOCK_SIZE_M); + } else { + if (v == NNZ_V || (v_idx > range_start && causal)) { + // add the last vertical if no more slash + if (v == NNZ_V && !causal && v_idx < kv_seqlen) { + column_index[tmp_col_cnt++] = v_idx; + } + tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, + BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen); + break; + } else { + if (causal) { + range_start = (kv_seqlen - q_seqlen) + end_m; + range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N; + } else { + // if slash_finished but there are vertical left, save current + // blocks + tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, + BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen); + range_start = kv_seqlen; + range_end = kv_seqlen + BLOCK_SIZE_N; + } + slash_finished = true; + } + } + if (!slash_finished) { + if (s_idx > range_end + BLOCK_SIZE_M) { + tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, + BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen); + range_start = s_idx - BLOCK_SIZE_M; + range_end = s_idx; + } else if (s_idx > range_end) { + range_end += BLOCK_SIZE_M; + } + } + } + } + + block_count[0] = tmp_blk_cnt; + column_count[0] = tmp_col_cnt; +} + +void convert_vertical_slash_indexes_64x64( + const int* q_seqlens, // [BATCH, ] + const int* kv_seqlens, // [BATCH, ] + const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S] + int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S] + int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V] + int64_t BATCH_SIZE, int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M, + int64_t BLOCK_SIZE_N, int64_t NNZ_V, int64_t NNZ_S, bool causal) { + const int N_THREADS = 64; + const dim3 dimBlock(N_THREADS); + const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS); + convert_vertical_slash_indexes_kernel<<>>( + q_seqlens, kv_seqlens, vertical_indexes, slash_indexes, block_count, + block_offset, column_count, column_index, N_HEADS, N_ROWS, BLOCK_SIZE_M, + BLOCK_SIZE_N, NNZ_V, NNZ_S, causal); +} + +/** + * Implements the Algorithm 4 in paper https://arxiv.org/abs/2407.02490. + * + * This function builds the index of each row of blocks from vertical indices + * and slash indices. The vertical indices are treated as points, while the + * slash indices are converted as ranges. The output consists of the merged + * ranges and separate column indices, where the ranges are represented by + * block indices. + * + * The implementation is referenced from the original MInference repo: + * https://github.com/microsoft/MInference/blob/main/csrc/vertical_slash_index.cu. + */ +void convert_vertical_slash_indexes( + torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS] + torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S] + torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS] + torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V] + torch::Tensor q_seqlens, // [BATCH, ] + torch::Tensor kv_seqlens, // [BATCH, ] + torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S] + int64_t context_size, int64_t block_size_M, int64_t block_size_N, + bool causal) { + cudaSetDevice(q_seqlens.get_device()); + + int batch_size = slash_indexes.size(0); + int num_heads = slash_indexes.size(1); + int nnz_slash = slash_indexes.size(2); + int nnz_vertical = vertical_indexes.size(2); + int num_rows = (context_size + block_size_M - 1) / block_size_M; + + convert_vertical_slash_indexes_64x64( + q_seqlens.data_ptr(), kv_seqlens.data_ptr(), + vertical_indexes.data_ptr(), slash_indexes.data_ptr(), + block_count.data_ptr(), block_offset.data_ptr(), + column_count.data_ptr(), column_index.data_ptr(), batch_size, + num_heads, num_rows, block_size_M, block_size_N, nnz_vertical, nnz_slash, + causal); +} + +__global__ void convert_vertical_slash_indexes_kernel_mergehead( + const int* q_seqlens, // [BATCH, ] + const int* kv_seqlens, // [BATCH, ] + const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S] + const int* per_head_vertical_topkv, const int* per_head_slash_topkv, + int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S] + int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V] + int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M, int64_t BLOCK_SIZE_N, + int64_t NNZ_V, int64_t NNZ_S, + bool causal // True for intra, False for succ +) { + const int batch_idx = blockIdx.y; + const int head_idx = blockIdx.x; + const int group_idx = blockIdx.z; + + int64_t q_seqlen = q_seqlens[batch_idx]; + int64_t kv_seqlen = kv_seqlens[batch_idx]; + int64_t block_idx_m = group_idx * blockDim.x + threadIdx.x; + int64_t start_m = block_idx_m * BLOCK_SIZE_M; + if (start_m >= q_seqlen) { + return; + } + int64_t end_m = start_m + BLOCK_SIZE_M; + vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V; + slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S; + int64_t row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m; + block_count += row_offset; + block_offset += row_offset * NNZ_S; + column_count += row_offset; + column_index += row_offset * NNZ_V; + + // MergeHead: each head has it's unique max topk NNZ_V,NNZ_S. (NNZ_V,NNZ_S + // above is buffer size, use to compute offset) + NNZ_S = per_head_slash_topkv[head_idx]; + NNZ_V = per_head_vertical_topkv[head_idx]; + + bool has_slash = true; + int64_t tmp_col_cnt = 0, tmp_blk_cnt = 0; + int64_t s = 0, v = 0; + int64_t v_idx = vertical_indexes[v++]; + int64_t s_idx = slash_indexes[s++]; + if (causal) { + while (s_idx >= end_m + (kv_seqlen - q_seqlen) && s < NNZ_S) { + s_idx = slash_indexes[s++]; + } + if (s_idx > end_m + (kv_seqlen - q_seqlen)) has_slash = false; + s_idx = max((kv_seqlen - q_seqlen) + end_m - s_idx, BLOCK_SIZE_M); + } else { + while (s_idx >= end_m + kv_seqlen && s < NNZ_S) { + s_idx = slash_indexes[s++]; + } + if (s_idx > end_m + kv_seqlen) has_slash = false; + s_idx = max(kv_seqlen + end_m - s_idx, BLOCK_SIZE_M); + } + + int64_t range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx; + if (!has_slash) { + if (causal) { + range_start = (kv_seqlen - q_seqlen) + end_m; + range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N; + } else { + range_start = kv_seqlen; + range_end = kv_seqlen + BLOCK_SIZE_N; + } + } + + bool slash_finished = false; + while (1) { + if (v_idx < range_end) { + if (v_idx < range_start) { + column_index[tmp_col_cnt++] = v_idx; + } + if (v < NNZ_V) { + v_idx = vertical_indexes[v++]; + } else { + if (causal) + v_idx = end_m + BLOCK_SIZE_N + (kv_seqlen - q_seqlen); + else + v_idx = end_m + BLOCK_SIZE_N + kv_seqlen; + } + } else { + if ((s < NNZ_S && causal) || + (s < NNZ_S && !causal && slash_indexes[s] >= start_m)) { + if (causal) + s_idx = max((kv_seqlen - q_seqlen) + end_m - slash_indexes[s++], + BLOCK_SIZE_M); + else + s_idx = max(kv_seqlen + end_m - slash_indexes[s++], BLOCK_SIZE_M); + } else { + if (v == NNZ_V || (v_idx > range_start && causal)) { + // add the last vertical if no more slash + if (v == NNZ_V && !causal && v_idx < kv_seqlen) { + column_index[tmp_col_cnt++] = v_idx; + } + tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, + BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen); + break; + } else { + if (causal) { + range_start = (kv_seqlen - q_seqlen) + end_m; + range_end = (kv_seqlen - q_seqlen) + end_m + BLOCK_SIZE_N; + } else { + // if slash_finished but there are vertical left, save current + // blocks + tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, + BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen); + range_start = kv_seqlen; + range_end = kv_seqlen + BLOCK_SIZE_N; + } + slash_finished = true; + } + } + if (!slash_finished) { + if (s_idx > range_end + BLOCK_SIZE_M) { + tmp_blk_cnt = save_blocks(block_offset, range_start, range_end, + BLOCK_SIZE_N, tmp_blk_cnt, kv_seqlen); + range_start = s_idx - BLOCK_SIZE_M; + range_end = s_idx; + } else if (s_idx > range_end) { + range_end += BLOCK_SIZE_M; + } + } + } + } + + block_count[0] = tmp_blk_cnt; + column_count[0] = tmp_col_cnt; +} + +void convert_vertical_slash_indexes_64x64_mergehead( + const int* q_seqlens, // [BATCH, ] + const int* kv_seqlens, // [BATCH, ] + const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + const int* slash_indexes, // [BATCH, N_HEADS, NNZ_S] + int* per_head_vertical_topkv, int* per_head_slash_topkv, + int* block_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* block_offset, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S] + int* column_count, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)] + int* column_index, // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V] + int64_t BATCH_SIZE, int64_t N_HEADS, int64_t N_ROWS, int64_t BLOCK_SIZE_M, + int64_t BLOCK_SIZE_N, int64_t NNZ_V, int64_t NNZ_S, bool causal) { + const int N_THREADS = 64; + const dim3 dimBlock(N_THREADS); + const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS); + convert_vertical_slash_indexes_kernel_mergehead<<>>( + q_seqlens, kv_seqlens, vertical_indexes, slash_indexes, + per_head_vertical_topkv, per_head_slash_topkv, block_count, block_offset, + column_count, column_index, N_HEADS, N_ROWS, BLOCK_SIZE_M, BLOCK_SIZE_N, + NNZ_V, NNZ_S, causal); +} + +/** + * Implements the Algorithm 4 in paper https://arxiv.org/abs/2407.02490. + * + * Like the above convert_vertical_slash_indexes, but with + * pre-computed vertical and slash counts. + */ +void convert_vertical_slash_indexes_mergehead( + torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS] + torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S] + torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS] + torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V] + torch::Tensor q_seqlens, // [BATCH, ] + torch::Tensor kv_seqlens, // [BATCH, ] + torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S] + torch::Tensor vertical_indices_count, // [N_HEADS, ] + torch::Tensor slash_indices_count, // [N_HEADS, ] + int64_t context_size, int64_t block_size_M, int64_t block_size_N, + bool causal) { + cudaSetDevice(q_seqlens.get_device()); + + int batch_size = slash_indexes.size(0); + int num_heads = slash_indexes.size(1); + int nnz_slash = slash_indexes.size(2); + int nnz_vertical = vertical_indexes.size(2); + int num_rows = (context_size + block_size_M - 1) / block_size_M; + + convert_vertical_slash_indexes_64x64_mergehead( + q_seqlens.data_ptr(), kv_seqlens.data_ptr(), + vertical_indexes.data_ptr(), slash_indexes.data_ptr(), + vertical_indices_count.data_ptr(), + slash_indices_count.data_ptr(), block_count.data_ptr(), + block_offset.data_ptr(), column_count.data_ptr(), + column_index.data_ptr(), batch_size, num_heads, num_rows, + block_size_M, block_size_N, nnz_vertical, nnz_slash, causal); +} diff --git a/csrc/core/math.hpp b/csrc/core/math.hpp index b8171133f6aad2ea0f48fc3363054a93fffde9cc..6764e1fd60545ad89d809934d6be02b04475ed2d 100644 --- a/csrc/core/math.hpp +++ b/csrc/core/math.hpp @@ -7,3 +7,22 @@ inline constexpr uint32_t next_pow_2(uint32_t const num) { if (num <= 1) return num; return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1)); } + +template +static inline constexpr auto div_ceil(A a, B b) { + return (a + b - 1) / b; +} + +// Round a down to the next multiple of b. The caller is responsible for making +// sure that b is non-zero +template +inline constexpr T round_to_previous_multiple_of(T a, T b) { + return a % b == 0 ? a : (a / b) * b; +} + +// Round a up to the next multiple of b. The caller is responsible for making +// sure that b is non-zero +template +inline constexpr T round_to_next_multiple_of(T a, T b) { + return a % b == 0 ? a : ((a / b) + 1) * b; +} diff --git a/csrc/core/scalar_type.hpp b/csrc/core/scalar_type.hpp index c2ae554c9f8e81c39c9e22a2ce24613b1fe8d2a6..d0f85e23609b0b4b0510a83d1e14a8797da4c112 100644 --- a/csrc/core/scalar_type.hpp +++ b/csrc/core/scalar_type.hpp @@ -315,6 +315,8 @@ static inline constexpr auto kS8 = ScalarType::int_(8); static inline constexpr auto kU8 = ScalarType::uint(8); static inline constexpr auto kU8B128 = ScalarType::uint(8, 128); +static inline constexpr auto kFE2M1f = + ScalarType::float_(2, 1, true, ScalarType::NAN_NONE); static inline constexpr auto kFE3M2f = ScalarType::float_(3, 2, true, ScalarType::NAN_NONE); static inline constexpr auto kFE4M3fn = @@ -332,6 +334,7 @@ static inline constexpr auto kInt8 = kS8; static inline constexpr auto kUint8 = kU8; static inline constexpr auto kUint8b128 = kU8B128; +static inline constexpr auto kFloat4_e2m1f = kFE2M1f; static inline constexpr auto kFloat6_e3m2f = kFE3M2f; static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn; static inline constexpr auto kFloat8_e5m2 = kFE5M2; diff --git a/csrc/cpu/cpu_types_vsx.hpp b/csrc/cpu/cpu_types_vsx.hpp index a8e1be37eb418bea8e23c29192d30ea6916c5793..089b9840ea2ed6dea2684326fa7ef8f505f1e63b 100644 --- a/csrc/cpu/cpu_types_vsx.hpp +++ b/csrc/cpu/cpu_types_vsx.hpp @@ -4,6 +4,7 @@ #include #include +#include #include namespace vec_op { @@ -62,6 +63,10 @@ typedef struct f32x4x4_t { __vector float val[4]; } f32x4x4_t; +typedef struct i32x4x4_t { + __vector int32_t val[4]; +} i32x4x4_t; + struct FP32Vec8; struct FP32Vec16; @@ -98,6 +103,28 @@ struct BF16Vec16 : public Vec { vec_xst(reg.val[0], 0, (signed short*)ptr); vec_xst(reg.val[1], 16, (signed short*)ptr); } + + void save(void* ptr, const int elem_num) const { + const int clamped_elem = std::max(0, std::min(elem_num, 16)); + + // Calculate elements to store in each 128-bit part (8 elements each) + const int elements_val0 = std::min(clamped_elem, 8); + const int elements_val1 = std::max(clamped_elem - 8, 0); + + // Convert elements to bytes (2 bytes per element) + const size_t bytes_val0 = elements_val0 * sizeof(signed short); + const size_t bytes_val1 = elements_val1 * sizeof(signed short); + + signed short* dest = static_cast(ptr); + // Store the first part using vec_xst_len + if (bytes_val0 > 0) { + vec_xst_len(reg.val[0], dest, bytes_val0); + } + // Store the second part if needed + if (bytes_val1 > 0) { + vec_xst_len(reg.val[1], dest + elements_val0, bytes_val1); + } + } }; const static __vector signed short zero = vec_splats((signed short)0); @@ -257,6 +284,64 @@ struct FP32Vec8 : public Vec { } }; +struct INT32Vec16 : public Vec { + constexpr static int VEC_ELEM_NUM = 16; + union AliasReg { + i32x4x4_t reg; + int32_t values[VEC_ELEM_NUM]; + }; + + i32x4x4_t reg; + + explicit INT32Vec16(const void* data_ptr) { + reg.val[0] = vec_xl(0, reinterpret_cast(data_ptr)); + reg.val[1] = + vec_xl(16, reinterpret_cast(data_ptr)); + reg.val[2] = + vec_xl(32, reinterpret_cast(data_ptr)); + reg.val[3] = + vec_xl(48, reinterpret_cast(data_ptr)); + } + + void save(int32_t* ptr) const { + vec_xst(reg.val[0], 0, reinterpret_cast<__vector int32_t*>(ptr)); + vec_xst(reg.val[1], 16, reinterpret_cast<__vector int32_t*>(ptr)); + vec_xst(reg.val[2], 32, reinterpret_cast<__vector int32_t*>(ptr)); + vec_xst(reg.val[3], 48, reinterpret_cast<__vector int32_t*>(ptr)); + } + + void save(int32_t* ptr, const int elem_num) const { + const int elements_in_chunk1 = + (elem_num >= 0) ? ((elem_num >= 4) ? 4 : elem_num) : 0; + const int elements_in_chunk2 = + (elem_num > 4) ? ((elem_num >= 8) ? 4 : elem_num - 4) : 0; + const int elements_in_chunk3 = + (elem_num > 8) ? ((elem_num >= 12) ? 4 : elem_num - 8) : 0; + const int elements_in_chunk4 = + (elem_num > 12) ? ((elem_num >= 16) ? 4 : elem_num - 12) : 0; + + const size_t bytes_chunk1 = + static_cast(elements_in_chunk1 * sizeof(int32_t)); + const size_t bytes_chunk2 = + static_cast(elements_in_chunk2 * sizeof(int32_t)); + const size_t bytes_chunk3 = + static_cast(elements_in_chunk3 * sizeof(int32_t)); + const size_t bytes_chunk4 = + static_cast(elements_in_chunk4 * sizeof(int32_t)); + + vec_xst_len(reg.val[0], reinterpret_cast(ptr), bytes_chunk1); + vec_xst_len(reg.val[1], + reinterpret_cast(reinterpret_cast(ptr) + 16), + bytes_chunk2); + vec_xst_len(reg.val[2], + reinterpret_cast(reinterpret_cast(ptr) + 32), + bytes_chunk3); + vec_xst_len(reg.val[3], + reinterpret_cast(reinterpret_cast(ptr) + 48), + bytes_chunk4); + } +}; + struct FP32Vec16 : public Vec { constexpr static int VEC_ELEM_NUM = 16; union AliasReg { @@ -319,6 +404,13 @@ struct FP32Vec16 : public Vec { explicit FP32Vec16(const BF16Vec8& v) : FP32Vec16(FP32Vec8(v)) {} + explicit FP32Vec16(const INT32Vec16& v) { + reg.val[0] = vec_ctf(v.reg.val[0], 0); + reg.val[1] = vec_ctf(v.reg.val[1], 0); + reg.val[2] = vec_ctf(v.reg.val[2], 0); + reg.val[3] = vec_ctf(v.reg.val[3], 0); + } + FP32Vec16 operator*(const FP32Vec16& b) const { return FP32Vec16(f32x4x4_t({vec_mul(reg.val[0], b.reg.val[0]), vec_mul(reg.val[1], b.reg.val[1]), @@ -347,6 +439,117 @@ struct FP32Vec16 : public Vec { vec_div(reg.val[3], b.reg.val[3])})); } + FP32Vec16 clamp(const FP32Vec16& min, const FP32Vec16& max) const { + return FP32Vec16(f32x4x4_t( + {vec_min(max.reg.val[0], vec_max(min.reg.val[0], reg.val[0])), + vec_min(max.reg.val[1], vec_max(min.reg.val[1], reg.val[1])), + vec_min(max.reg.val[2], vec_max(min.reg.val[2], reg.val[2])), + vec_min(max.reg.val[3], vec_max(min.reg.val[3], reg.val[3]))})); + } + + FP32Vec16 max(const FP32Vec16& b) const { + return FP32Vec16(f32x4x4_t({vec_max(reg.val[0], b.reg.val[0]), + vec_max(reg.val[1], b.reg.val[1]), + vec_max(reg.val[2], b.reg.val[2]), + vec_max(reg.val[3], b.reg.val[3])})); + } + + FP32Vec16 max(const FP32Vec16& b, int elem_num) const { + FP32Vec16 result; + + // Create a vector of element indices for each chunk + __vector unsigned int indices = {0, 1, 2, 3}; + __vector unsigned int elem_num_vec = + vec_splats(static_cast(elem_num)); + + // Compute masks for each chunk + __vector unsigned int chunk_offset0 = {0, 0, 0, + 0}; // Chunk 0: Elements 0-3 + __vector unsigned int chunk_offset1 = {4, 4, 4, + 4}; // Chunk 1: Elements 4-7 + __vector unsigned int chunk_offset2 = {8, 8, 8, + 8}; // Chunk 2: Elements 8-11 + __vector unsigned int chunk_offset3 = {12, 12, 12, + 12}; // Chunk 3: Elements 12-15 + + // Compute masks for each chunk + __vector bool int mask0 = vec_cmplt(indices + chunk_offset0, elem_num_vec); + __vector bool int mask1 = vec_cmplt(indices + chunk_offset1, elem_num_vec); + __vector bool int mask2 = vec_cmplt(indices + chunk_offset2, elem_num_vec); + __vector bool int mask3 = vec_cmplt(indices + chunk_offset3, elem_num_vec); + + // Apply masks to compute the result for each chunk + result.reg.val[0] = vec_sel(this->reg.val[0], + vec_max(this->reg.val[0], b.reg.val[0]), mask0); + result.reg.val[1] = vec_sel(this->reg.val[1], + vec_max(this->reg.val[1], b.reg.val[1]), mask1); + result.reg.val[2] = vec_sel(this->reg.val[2], + vec_max(this->reg.val[2], b.reg.val[2]), mask2); + result.reg.val[3] = vec_sel(this->reg.val[3], + vec_max(this->reg.val[3], b.reg.val[3]), mask3); + + return FP32Vec16(result.reg); + } + + FP32Vec16 min(const FP32Vec16& b) const { + return FP32Vec16(f32x4x4_t({vec_min(reg.val[0], b.reg.val[0]), + vec_min(reg.val[1], b.reg.val[1]), + vec_min(reg.val[2], b.reg.val[2]), + vec_min(reg.val[3], b.reg.val[3])})); + } + + FP32Vec16 min(const FP32Vec16& b, int elem_num) const { + FP32Vec16 result; + + vector unsigned int indices = {0, 1, 2, 3}; + vector unsigned int elem_num_vec = + vec_splats(static_cast(elem_num)); + + vector unsigned int chunk_offset0 = {0, 0, 0, 0}; + vector unsigned int chunk_offset1 = {4, 4, 4, 4}; + vector unsigned int chunk_offset2 = {8, 8, 8, 8}; + vector unsigned int chunk_offset3 = {12, 12, 12, 12}; + + vector bool int mask0 = vec_cmplt(indices + chunk_offset0, elem_num_vec); + vector bool int mask1 = vec_cmplt(indices + chunk_offset1, elem_num_vec); + vector bool int mask2 = vec_cmplt(indices + chunk_offset2, elem_num_vec); + vector bool int mask3 = vec_cmplt(indices + chunk_offset3, elem_num_vec); + + result.reg.val[0] = vec_sel(this->reg.val[0], + vec_min(this->reg.val[0], b.reg.val[0]), mask0); + result.reg.val[1] = vec_sel(this->reg.val[1], + vec_min(this->reg.val[1], b.reg.val[1]), mask1); + result.reg.val[2] = vec_sel(this->reg.val[2], + vec_min(this->reg.val[2], b.reg.val[2]), mask2); + result.reg.val[3] = vec_sel(this->reg.val[3], + vec_min(this->reg.val[3], b.reg.val[3]), mask3); + + return FP32Vec16(result.reg); + } + + FP32Vec16 abs() const { + return FP32Vec16(f32x4x4_t({vec_abs(reg.val[0]), vec_abs(reg.val[1]), + vec_abs(reg.val[2]), vec_abs(reg.val[3])})); + } + + float reduce_max() { + __vector float max01 = vec_max(reg.val[0], reg.val[1]); + __vector float max23 = vec_max(reg.val[2], reg.val[3]); + __vector float max_all = vec_max(max01, max23); + __vector float temp = vec_max(max_all, vec_sld(max_all, max_all, 8)); + temp = vec_max(temp, vec_sld(temp, temp, 4)); + return vec_extract(temp, 0); + } + + float reduce_min() { + __vector float min01 = vec_min(reg.val[0], reg.val[1]); + __vector float min23 = vec_min(reg.val[2], reg.val[3]); + __vector float min_all = vec_min(min01, min23); + __vector float temp = vec_min(min_all, vec_sld(min_all, min_all, 8)); + temp = vec_min(temp, vec_sld(temp, temp, 4)); + return vec_extract(temp, 0); + } + float reduce_sum() const { AliasReg ar; ar.reg = reg; @@ -377,6 +580,68 @@ struct FP32Vec16 : public Vec { vec_xst(reg.val[2], 32, ptr); vec_xst(reg.val[3], 48, ptr); } + + void save(float* ptr, const int elem_num) const { + const int elements_in_chunk1 = + (elem_num >= 0) ? ((elem_num >= 4) ? 4 : elem_num) : 0; + const int elements_in_chunk2 = + (elem_num > 4) ? ((elem_num >= 8) ? 4 : elem_num - 4) : 0; + const int elements_in_chunk3 = + (elem_num > 8) ? ((elem_num >= 12) ? 4 : elem_num - 8) : 0; + const int elements_in_chunk4 = + (elem_num > 12) ? ((elem_num >= 16) ? 4 : elem_num - 12) : 0; + + const size_t bytes_chunk1 = + static_cast(elements_in_chunk1 * sizeof(float)); + const size_t bytes_chunk2 = + static_cast(elements_in_chunk2 * sizeof(float)); + const size_t bytes_chunk3 = + static_cast(elements_in_chunk3 * sizeof(float)); + const size_t bytes_chunk4 = + static_cast(elements_in_chunk4 * sizeof(float)); + + vec_xst_len(reg.val[0], ptr, bytes_chunk1); + vec_xst_len(reg.val[1], + reinterpret_cast(reinterpret_cast(ptr) + 16), + bytes_chunk2); + vec_xst_len(reg.val[2], + reinterpret_cast(reinterpret_cast(ptr) + 32), + bytes_chunk3); + vec_xst_len(reg.val[3], + reinterpret_cast(reinterpret_cast(ptr) + 48), + bytes_chunk4); + } +}; + +struct INT8Vec16 : public Vec { + constexpr static int VEC_NUM_ELEM = 16; // 128 bits / 8 bits = 16 + + union AliasReg { + __vector signed char reg; + int8_t values[VEC_NUM_ELEM]; + }; + + __vector signed char reg; + + explicit INT8Vec16(const FP32Vec16& vec) { + __vector signed int ret[4]; + ret[0] = vec_cts(vec.reg.val[0], 0); + ret[1] = vec_cts(vec.reg.val[1], 0); + ret[2] = vec_cts(vec.reg.val[2], 0); + ret[3] = vec_cts(vec.reg.val[3], 0); + + __vector signed short packed1 = vec_packs(ret[0], ret[1]); + __vector signed short packed2 = vec_packs(ret[2], ret[3]); + + reg = vec_packs(packed1, packed2); + } + + void save(void* ptr) const { + *reinterpret_cast<__vector signed char*>(ptr) = reg; + } + void save(signed char* ptr, const int elem_num) { + vec_xst_len(reg, ptr, static_cast(elem_num)); + } }; template diff --git a/csrc/cpu/pos_encoding.cpp b/csrc/cpu/pos_encoding.cpp index 8a59e884d6c82eacc988bd1684629d5f65960c7d..74bb014cf39e90aa7e6a4258575346054baccb75 100644 --- a/csrc/cpu/pos_encoding.cpp +++ b/csrc/cpu/pos_encoding.cpp @@ -9,7 +9,8 @@ void rotary_embedding_impl( scalar_t* __restrict__ query, /// [batch_size, seq_len, num_heads, /// head_size] or [num_tokens, num_heads, /// head_size] - scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, + scalar_t* __restrict__ key, // nullptr (optional) or + // [batch_size, seq_len, num_kv_heads, // head_size] or [num_tokens, num_kv_heads, // head_size] const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // @@ -85,10 +86,13 @@ void rotary_embedding_impl( compute_loop(token_head, cache_ptr, query); } - for (int i = 0; i < num_kv_heads; ++i) { - const int head_idx = i; - const int64_t token_head = token_idx * key_stride + head_idx * head_size; - compute_loop(token_head, cache_ptr, key); + if (key != nullptr) { + for (int i = 0; i < num_kv_heads; ++i) { + const int head_idx = i; + const int64_t token_head = + token_idx * key_stride + head_idx * head_size; + compute_loop(token_head, cache_ptr, key); + } } } } @@ -100,7 +104,8 @@ void rotary_embedding_gptj_impl( scalar_t* __restrict__ query, /// [batch_size, seq_len, num_heads, /// head_size] or [num_tokens, num_heads, /// head_size] - scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, + scalar_t* __restrict__ key, // nullptr (optional) or + // [batch_size, seq_len, num_kv_heads, // head_size] or [num_tokens, num_kv_heads, // head_size] const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // @@ -138,6 +143,10 @@ void rotary_embedding_gptj_impl( } } + if (key == nullptr) { + return; + } + #pragma omp parallel for collapse(2) for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { for (int i = 0; i < num_kv_heads; ++i) { @@ -168,13 +177,13 @@ void rotary_embedding_gptj_impl( }; // namespace void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, - torch::Tensor& key, int64_t head_size, + std::optional key, int64_t head_size, torch::Tensor& cos_sin_cache, bool is_neox) { int num_tokens = positions.numel(); int rot_dim = cos_sin_cache.size(1); int num_heads = query.size(-1) / head_size; - int num_kv_heads = key.size(-1) / head_size; - int64_t key_stride = key.stride(-2); + int num_kv_heads = key.has_value() ? key->size(-1) / head_size : num_heads; + int64_t key_stride = key.has_value() ? key->stride(-2) : 0; int64_t query_stride = query.stride(-2); VLLM_DISPATCH_FLOATING_TYPES( @@ -183,15 +192,15 @@ void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, if (is_neox) { rotary_embedding_impl( positions.data_ptr(), query.data_ptr(), - key.data_ptr(), cos_sin_cache.data_ptr(), - rot_dim, query_stride, key_stride, num_heads, num_kv_heads, - head_size, num_tokens); + key.has_value() ? key->data_ptr() : nullptr, + cos_sin_cache.data_ptr(), rot_dim, query_stride, + key_stride, num_heads, num_kv_heads, head_size, num_tokens); } else { rotary_embedding_gptj_impl( positions.data_ptr(), query.data_ptr(), - key.data_ptr(), cos_sin_cache.data_ptr(), - rot_dim, query_stride, key_stride, num_heads, num_kv_heads, - head_size, num_tokens); + key.has_value() ? key->data_ptr() : nullptr, + cos_sin_cache.data_ptr(), rot_dim, query_stride, + key_stride, num_heads, num_kv_heads, head_size, num_tokens); } CPU_KERNEL_GUARD_OUT(rotary_embedding_impl) diff --git a/csrc/cpu/quant.cpp b/csrc/cpu/quant.cpp index 6751e7e55fc51994b4004c796788b840cd524399..f61dbcc948e83a5764a602a880fc9ddb99096dd0 100644 --- a/csrc/cpu/quant.cpp +++ b/csrc/cpu/quant.cpp @@ -239,6 +239,280 @@ void static_quant_epilogue(const float* input, scalar_t* output, } } +template +void dynamic_quant_epilogue(const float* input, scalar_t* output, + const float* a_scale, const float* b_scale, + const int32_t* azp, const int32_t* azp_adj, + const scalar_t* bias, const int num_tokens, + const int hidden_size) { + CPU_KERNEL_GUARD_IN(dynamic_quant_epilogue) + using load_vec_t = typename KernelVecType::load_vec_type; + using azp_adj_load_vec_t = + typename KernelVecType::azp_adj_load_vec_type; + using cvt_vec_t = typename KernelVecType::cvt_vec_type; + constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; + + #pragma omp parallel for + for (int i = 0; i < num_tokens; ++i) { + int j = 0; + cvt_vec_t token_scale_vec(a_scale[i]); + cvt_vec_t token_zp_scale_vec; + if constexpr (AZP) { + float zp_scale_val = a_scale[i] * static_cast(azp[i]); + if constexpr (!PerChannel) { + zp_scale_val *= *b_scale; + } + token_zp_scale_vec = cvt_vec_t(zp_scale_val); + } + + for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { + cvt_vec_t elems_fp32(input + i * hidden_size + j); + elems_fp32 = elems_fp32 * token_scale_vec; + + if constexpr (AZP) { + azp_adj_load_vec_t azp_adj_vec(azp_adj + j); + cvt_vec_t azp_adj_fp32(azp_adj_vec); + azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec; + + if constexpr (PerChannel) { + cvt_vec_t b_scale_vec(b_scale + j); + azp_adj_fp32 = azp_adj_fp32 * b_scale_vec; + } + + elems_fp32 = elems_fp32 - azp_adj_fp32; + } + + if constexpr (Bias) { + load_vec_t bias_vec(bias + j); + cvt_vec_t bias_vec_fp32(bias_vec); + elems_fp32 = elems_fp32 + bias_vec_fp32; + } + + load_vec_t elems_out(elems_fp32); + elems_out.save(output + i * hidden_size + j); + } + + cvt_vec_t elems_fp32(input + i * hidden_size + j); + elems_fp32 = elems_fp32 * token_scale_vec; + + if constexpr (AZP) { + azp_adj_load_vec_t azp_adj_vec(azp_adj + j); + cvt_vec_t azp_adj_fp32(azp_adj_vec); + azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec; + + if constexpr (PerChannel) { + cvt_vec_t b_scale_vec(b_scale + j); + azp_adj_fp32 = azp_adj_fp32 * b_scale_vec; + } + + elems_fp32 = elems_fp32 - azp_adj_fp32; + } + + if constexpr (Bias) { + load_vec_t bias_vec(bias + j); + cvt_vec_t bias_vec_fp32(bias_vec); + elems_fp32 = elems_fp32 + bias_vec_fp32; + } + + load_vec_t elems_out(elems_fp32); + elems_out.save(output + i * hidden_size + j, hidden_size - j); + } +} +#elif defined(__powerpc64__) +template +void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, + const float* scale, const int32_t* azp, + const int num_tokens, + const int hidden_size) { + using load_vec_t = typename KernelVecType::load_vec_type; + using cvt_vec_t = typename KernelVecType::cvt_vec_type; + constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; + + constexpr float i8_min = + static_cast(std::numeric_limits::min()); + constexpr float i8_max = + static_cast(std::numeric_limits::max()); + + const cvt_vec_t inv_scale(1.0 / *scale); + const cvt_vec_t i8_min_vec(i8_min); + const cvt_vec_t i8_max_vec(i8_max); + + cvt_vec_t zp_vec; + if constexpr (AZP) { + zp_vec = cvt_vec_t(static_cast(*azp)); + } + #pragma omp parallel for + for (int i = 0; i < num_tokens; ++i) { + int j = 0; + for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { + load_vec_t elems(input + i * hidden_size + j); + cvt_vec_t elems_fp32(elems); + elems_fp32 = elems_fp32 * inv_scale; + if constexpr (AZP) { + elems_fp32 = elems_fp32 + zp_vec; + } + elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); + vec_op::INT8Vec16 elems_int8(elems_fp32); + elems_int8.save(output + i * hidden_size + j); + } + load_vec_t elems(input + i * hidden_size + j); + cvt_vec_t elems_fp32(elems); + elems_fp32 = elems_fp32 * inv_scale; + + if constexpr (AZP) { + elems_fp32 = elems_fp32 + zp_vec; + } + + elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); + vec_op::INT8Vec16 elems_int8(elems_fp32); + elems_int8.save(output + i * hidden_size + j, hidden_size - j); + } +} +template +void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, + float* scale, int32_t* azp, + const int num_tokens, + const int hidden_size) { + using load_vec_t = typename KernelVecType::load_vec_type; + using cvt_vec_t = typename KernelVecType::cvt_vec_type; + constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; + + constexpr float i8_min = + static_cast(std::numeric_limits::min()); + constexpr float i8_max = + static_cast(std::numeric_limits::max()); + const cvt_vec_t i8_min_vec(i8_min); + const cvt_vec_t i8_max_vec(i8_max); + + #pragma omp parallel for + for (int i = 0; i < num_tokens; ++i) { + cvt_vec_t max_value(std::numeric_limits::lowest()); + cvt_vec_t min_value(std::numeric_limits::max()); + { + int j = 0; + for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { + load_vec_t elems(input + i * hidden_size + j); + cvt_vec_t elems_fp32(elems); + if constexpr (AZP) { + max_value = max_value.max(elems_fp32); + min_value = min_value.min(elems_fp32); + } else { + max_value = max_value.max(elems_fp32.abs()); + } + } + + load_vec_t elems(input + i * hidden_size + j); + cvt_vec_t elems_fp32(elems); + + if (j + vec_elem_num == hidden_size) { + if constexpr (AZP) { + max_value = max_value.max(elems_fp32); + min_value = min_value.min(elems_fp32); + } else { + max_value = max_value.max(elems_fp32.abs()); + } + } else { + if constexpr (AZP) { + max_value = max_value.max(elems_fp32, hidden_size - j); + min_value = min_value.min(elems_fp32, hidden_size - j); + } else { + max_value = max_value.max(elems_fp32.abs(), hidden_size - j); + } + } + } + + float scale_val, azp_val; + if constexpr (AZP) { + float max_scalar = max_value.reduce_max(); + float min_scalar = min_value.reduce_min(); + scale_val = (max_scalar - min_scalar) / 255.0f; + azp_val = std::nearbyint(-128.0f - min_scalar / scale_val); + azp[i] = static_cast(azp_val); + scale[i] = scale_val; + } else { + scale_val = max_value.reduce_max() / 127.0f; + scale[i] = scale_val; + } + + const cvt_vec_t inv_scale(1.0 / scale_val); + const cvt_vec_t azp_vec(azp_val); + + { + int j = 0; + for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { + load_vec_t elems(input + i * hidden_size + j); + cvt_vec_t elems_fp32(elems); + elems_fp32 = (elems_fp32 * inv_scale); + + if constexpr (AZP) { + elems_fp32 = elems_fp32 + azp_vec; + } + elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); + vec_op::INT8Vec16 elems_int8(elems_fp32); + elems_int8.save(output + i * hidden_size + j); + } + + load_vec_t elems(input + i * hidden_size + j); + cvt_vec_t elems_fp32(elems); + elems_fp32 = (elems_fp32 * inv_scale); + + if constexpr (AZP) { + elems_fp32 = elems_fp32 + azp_vec; + } + elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); + vec_op::INT8Vec16 elems_int8(elems_fp32); + elems_int8.save(output + i * hidden_size + j, hidden_size - j); + } + } +} +template +void static_quant_epilogue(const float* input, scalar_t* output, + const float a_scale, const float* b_scale, + const int32_t* azp_with_adj, const int num_tokens, + const int hidden_size) { + CPU_KERNEL_GUARD_IN(dynamic_output_scale_impl) + using load_vec_t = typename KernelVecType::load_vec_type; + using azp_adj_load_vec_t = + typename KernelVecType::azp_adj_load_vec_type; + using cvt_vec_t = typename KernelVecType::cvt_vec_type; + constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; + + #pragma omp parallel for + for (int i = 0; i < num_tokens; ++i) { + cvt_vec_t a_scale_vec(a_scale); + cvt_vec_t b_scale_vec(*b_scale); + cvt_vec_t scale_vec = a_scale_vec * b_scale_vec; + + int j = 0; + for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { + cvt_vec_t elems_fp32(input + i * hidden_size + j); + azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j); + cvt_vec_t azp_adj_fp32(azp_adj_vec); + + if constexpr (PerChannel) { + b_scale_vec = cvt_vec_t(b_scale + j); + scale_vec = b_scale_vec * a_scale_vec; + } + elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32; + load_vec_t elems_out(elems_fp32); + elems_out.save(output + i * hidden_size + j); + } + + cvt_vec_t elems_fp32(input + i * hidden_size + j); + azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j); + cvt_vec_t azp_adj_fp32(azp_adj_vec); + + if constexpr (PerChannel) { + b_scale_vec = cvt_vec_t(b_scale + j); + scale_vec = b_scale_vec * a_scale_vec; + } + + elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32; + + load_vec_t elems_out(elems_fp32); + elems_out.save(output + i * hidden_size + j, hidden_size - j); + } +} template void dynamic_quant_epilogue(const float* input, scalar_t* output, const float* a_scale, const float* b_scale, @@ -324,7 +598,8 @@ void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, const float* scale, const int32_t* azp, const int num_tokens, const int hidden_size) { - TORCH_CHECK(false, "static_scaled_int8_quant_impl requires AVX512 support.") + TORCH_CHECK( + false, "static_scaled_int8_quant_impl requires AVX512/powerpc64 support.") } template @@ -332,7 +607,9 @@ void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, float* scale, int32_t* azp, const int num_tokens, const int hidden_size) { - TORCH_CHECK(false, "dynamic_scaled_int8_quant_impl requires AVX512 support.") + TORCH_CHECK( + false, + "dynamic_scaled_int8_quant_impl requires AVX512/powerpc64 support.") } template @@ -340,7 +617,7 @@ void static_quant_epilogue(const float* input, scalar_t* output, const float a_scale, const float* b_scale, const int32_t* azp_with_adj, const int num_tokens, const int hidden_size) { - TORCH_CHECK(false, "static_quant_epilogue requires AVX512 support.") + TORCH_CHECK(false, "static_quant_epilogue requires AVX512/powerpc64 support.") } template @@ -349,7 +626,8 @@ void dynamic_quant_epilogue(const float* input, scalar_t* output, const int32_t* azp, const int32_t* azp_with_adj, const scalar_t* bias, const int num_tokens, const int hidden_size) { - TORCH_CHECK(false, "dynamic_quant_epilogue requires AVX512 support.") + TORCH_CHECK(false, + "dynamic_quant_epilogue requires AVX512/powerpc64 support.") } #endif } // namespace @@ -611,3 +889,58 @@ void dynamic_scaled_int8_quant( } }); } + +#if defined(__powerpc64__) +void int8_scaled_mm_ppc64le(torch::Tensor& c, // [M, OC], row-major + const torch::Tensor& a, // [M, IC], row-major + const torch::Tensor& b, // [IC, OC], column-major + const torch::Tensor& a_scales, + const torch::Tensor& b_scales, + const std::optional& bias // [OC] +) { + CPU_KERNEL_GUARD_IN(cutlass_scaled_mm) + // Checks for conformality + TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8, + "int8_scaled_mm_ppc64le only supports INT8 inputs."); + TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2); + TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) && + b.size(1) == c.size(1)); + // We dont need this + TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0)); + TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1)); + + // Check for strides and alignment + TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major + TORCH_CHECK(b.stride(0) == 1); // Column-major + TORCH_CHECK(c.stride(0) % 16 == 0 && + b.stride(1) % 16 == 0); // 16 Byte Alignment + TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); + + if (bias) { + TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() && + bias->dim() == 1); + } + VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm_ppc64le", [&] { + torch::Tensor tmp_fp32_out = torch::empty_like(c, ::at::ScalarType::Float); + // Compute C_inter=s_b * (A@B) + DNNLPrimitiveHelper::gemm_s8s8_jit( + a.data_ptr(), b.data_ptr(), + tmp_fp32_out.data_ptr(), nullptr, a.size(0), b.size(1), + a.size(1), nullptr, b_scales.data_ptr(), 0, b_scales.numel()); + if (bias.has_value()) { + // Compute C=s_a * C_inter + bias + dynamic_quant_epilogue( + tmp_fp32_out.data_ptr(), c.data_ptr(), + a_scales.data_ptr(), nullptr, nullptr, nullptr, + bias->data_ptr(), c.size(0), c.size(1)); + } else { + // Compute C=s_a * C_inter + dynamic_quant_epilogue( + tmp_fp32_out.data_ptr(), c.data_ptr(), + a_scales.data_ptr(), nullptr, nullptr, nullptr, nullptr, + c.size(0), c.size(1)); + } + }); +} + +#endif diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index 7ae7e3386b4ed047b2c0a7ddcef18001080c6a7e..447e826bc1c09b83c55824d0438dad1c0e12681b 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -18,6 +18,14 @@ void int8_scaled_mm_azp(torch::Tensor& c, const torch::Tensor& a, const std::optional& azp, const std::optional& bias); +#if defined(__powerpc64__) +void int8_scaled_mm_ppc64le(torch::Tensor& c, const torch::Tensor& a, + const torch::Tensor& b, + const torch::Tensor& a_scales, + const torch::Tensor& b_scales, + const std::optional& bias); +#endif + void mla_decode_kvcache(torch::Tensor& out, torch::Tensor& query, torch::Tensor& kv_cache, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens); @@ -117,7 +125,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Apply GPT-NeoX or GPT-J style rotary embedding to query and key. ops.def( "rotary_embedding(Tensor positions, Tensor! query," - " Tensor! key, int head_size," + " Tensor!? key, int head_size," " Tensor cos_sin_cache, bool is_neox) -> ()"); ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding); @@ -150,6 +158,33 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor b_scales, Tensor azp_adj," " Tensor? azp, Tensor? bias) -> ()"); ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp); +#elif defined(__powerpc64__) + // Compute int8 quantized tensor for given scaling factor. + ops.def( + "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale," + "Tensor? azp) -> ()"); + ops.impl("static_scaled_int8_quant", torch::kCPU, &static_scaled_int8_quant); + + // Compute int8 quantized tensor and scaling factor + ops.def( + "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, " + "Tensor!? azp) -> ()"); + ops.impl("dynamic_scaled_int8_quant", torch::kCPU, + &dynamic_scaled_int8_quant); + // W8A8 GEMM, supporting symmetric quantization. + ops.def( + "cutlass_scaled_mm(Tensor! out, Tensor a," + " Tensor b, Tensor a_scales," + " Tensor b_scales, Tensor? bias) -> ()"); + ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm_ppc64le); + // w8a8 GEMM, supporting asymmetric per-tensor or per-row/column + // quantization. + ops.def( + "cutlass_scaled_mm_azp(Tensor! out, Tensor a," + " Tensor b, Tensor a_scales," + " Tensor b_scales, Tensor azp_adj," + " Tensor? azp, Tensor? bias) -> ()"); + ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp); #endif // SHM CCL diff --git a/csrc/cutlass_extensions/common.hpp b/csrc/cutlass_extensions/common.hpp index dbe0e30f5cbfe608b2a2d75aecae274dcc5bc094..0877da52435eb5ac6c24081c9bb974d7390063d5 100644 --- a/csrc/cutlass_extensions/common.hpp +++ b/csrc/cutlass_extensions/common.hpp @@ -59,3 +59,13 @@ struct enable_sm90_only : Kernel { #endif } }; + +template +struct enable_sm100_only : Kernel { + template + CUTLASS_DEVICE void operator()(Args&&... args) { +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 1000 + Kernel::operator()(std::forward(args)...); +#endif + } +}; diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h index dc6e0769b8780d091d835395d9f5b7833f561533..f7b75c48373f68e9025020eea507415fb9405e2e 100644 --- a/csrc/dispatch_utils.h +++ b/csrc/dispatch_utils.h @@ -65,5 +65,19 @@ AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) +#define VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::UInt16, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::UInt32, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::UInt64, __VA_ARGS__) + #define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) + +#define VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(__VA_ARGS__)) diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index fb6882f3e7c3ea47d42fb1e2e48bd285b47dff0c..d073dd6d2dee134c462647f2afd1ce751481aa74 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -140,6 +140,10 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size] torch::Tensor& input, // [..., hidden_size] torch::Tensor& weight, // [hidden_size] double epsilon) { + TORCH_CHECK(out.is_contiguous()); + TORCH_CHECK(input.is_contiguous()); + TORCH_CHECK(weight.is_contiguous()); + int hidden_size = input.size(-1); int num_tokens = input.numel() / hidden_size; diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel.h b/csrc/moe/marlin_kernels/marlin_moe_kernel.h deleted file mode 100644 index 47ecf109d0f53fc2691a00e4ba4e6add86328b29..0000000000000000000000000000000000000000 --- a/csrc/moe/marlin_kernels/marlin_moe_kernel.h +++ /dev/null @@ -1,1616 +0,0 @@ -#pragma once - -#include - -#include -#include -#include -#include -#include - -#include - -#include "core/scalar_type.hpp" - -namespace marlin_moe { - -constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - -// Instances of `Vec` are used to organize groups of >>registers<<, as needed -// for instance as inputs to tensor core operations. Consequently, all -// corresponding index accesses must be compile-time constants, which is why we -// extensively use `#pragma unroll` throughout the kernel code to guarantee -// this. -template -struct Vec { - T elems[n]; - __device__ T& operator[](int i) { return elems[i]; } -}; - -using I4 = Vec; - -// Matrix fragments for tensor core instructions; their precise layout is -// documented here: -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type -using FragA = Vec; -using FragB = Vec; -using FragC = Vec; -using FragS = Vec; // quantization scales -using FragZP = Vec; - -// Predicated asynchronous global->shared copy; used for inputs A where we apply -// predication to handle batchsizes that are not multiples of 16. -__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, - bool pred = true) { - const int BYTES = 16; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " .reg .pred p;\n" - " setp.ne.b32 p, %0, 0;\n" - " @p cp.async.cg.shared.global [%1], [%2], %3;\n" - "}\n" ::"r"((int)pred), - "r"(smem), "l"(glob_ptr), "n"(BYTES)); -} - -// Asynchronous global->shared copy -__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { - const int BYTES = 16; - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile( - "{\n" - " cp.async.cg.shared.global [%0], [%1], %2;\n" - "}\n" ::"r"(smem), - "l"(glob_ptr), "n"(BYTES)); -} - -// Async copy fence. -__device__ inline void cp_async_fence() { - asm volatile("cp.async.commit_group;\n" ::); -} - -// Wait until at most `n` async copy stages are still pending. -template -__device__ inline void cp_async_wait() { - asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); -} - -// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 -// output/accumulation. -__device__ inline void mma(const FragA& a_frag, const FragB& frag_b, - FragC& frag_c) { - const uint32_t* a = reinterpret_cast(&a_frag); - const uint32_t* b = reinterpret_cast(&frag_b); - float* c = reinterpret_cast(&frag_c); - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); -} - -// Instruction for loading a full 16x16 matrix fragment of operand A from shared -// memory, directly in tensor core layout. -__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { - uint32_t* a = reinterpret_cast(&frag_a); - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" - : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) - : "r"(smem)); -} - -// Lookup-table based 3-input logical operation; explicitly used for -// dequantization as the compiler does not seem to automatically recognize it in -// all cases. -template -__device__ inline int lop3(int a, int b, int c) { - int res; - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(res) - : "r"(a), "r"(b), "r"(c), "n"(lut)); - return res; -} - -// Constructs destination register by taking bytes from 2 sources (based on -// mask) -template -__device__ inline uint32_t prmt(uint32_t a) { - uint32_t res; - asm volatile("prmt.b32 %0, %1, %2, %3;\n" - : "=r"(res) - : "r"(a), "n"(start_byte), "n"(mask)); - return res; -} - -template -__device__ inline FragB dequant(int q); - -// Efficiently dequantize 4bit values packed in an int32 value into a full -// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below, -// with some small changes: -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 -template <> -__device__ inline FragB dequant(int q) { - const int LO = 0x000f000f; - const int HI = 0x00f000f0; - const int EX = 0x64006400; - // Guarantee that the `(a & b) | c` operations are LOP3s. - int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX); - int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX); - // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point - // directly into `SUB` and `ADD`. - const int SUB = 0x64086408; - const int MUL = 0x2c002c00; - const int ADD = 0xd480d480; - FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&SUB)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - return frag_b; -} - -// Fast Int8ToFp16: Efficiently dequantize 8bit int values to fp16 -// Reference: -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 -template <> -__device__ inline FragB dequant(int q) { - static constexpr uint32_t mask_for_elt_01 = 0x5250; - static constexpr uint32_t mask_for_elt_23 = 0x5351; - static constexpr uint32_t start_byte_for_fp16 = 0x64646464; - - uint32_t lo = prmt(q); - uint32_t hi = prmt(q); - - static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; - - FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - frag_b[1] = __hsub2(*reinterpret_cast(&hi), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - return frag_b; -} - -template <> -__device__ inline FragB dequant(int q) { - const int LO = 0x000f000f; - const int HI = 0x00f000f0; - const int EX = 0x64006400; - // Guarantee that the `(a & b) | c` operations are LOP3s. - int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX); - int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX); - - const int SUB = 0x64006400; - const int MUL = 0x2c002c00; - const int ADD = 0xd400d400; - FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&SUB)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - return frag_b; -} - -template <> -__device__ inline FragB dequant(int q) { - static constexpr uint32_t mask_for_elt_01 = 0x5250; - static constexpr uint32_t mask_for_elt_23 = 0x5351; - static constexpr uint32_t start_byte_for_fp16 = 0x64646464; - - uint32_t lo = prmt(q); - uint32_t hi = prmt(q); - - static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400; - - FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - frag_b[1] = __hsub2(*reinterpret_cast(&hi), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - return frag_b; -} - -// Multiply dequantized values by the corresponding quantization scale; used -// only for grouped quantization. -__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { - half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]); - frag_b[0] = __hmul2(frag_b[0], s); - frag_b[1] = __hmul2(frag_b[1], s); -} - -__device__ inline void sub_zp(FragB& frag_b, half2& frag_zp, int i) { - half2 zp = __half2half2(reinterpret_cast<__half*>(&frag_zp)[i]); - frag_b[0] = __hsub2(frag_b[0], zp); - frag_b[1] = __hsub2(frag_b[1], zp); -} - -// Same as above, but for act_order (each K is multiplied individually) -__device__ inline void scale4(FragB& frag_b, FragS& frag_s_1, FragS& frag_s_2, - FragS& frag_s_3, FragS& frag_s_4, int i) { - __half2 s_val_1_2; - s_val_1_2.x = reinterpret_cast<__half*>(&frag_s_1)[i]; - s_val_1_2.y = reinterpret_cast<__half*>(&frag_s_2)[i]; - - __half2 s_val_3_4; - s_val_3_4.x = reinterpret_cast<__half*>(&frag_s_3)[i]; - s_val_3_4.y = reinterpret_cast<__half*>(&frag_s_4)[i]; - - frag_b[0] = __hmul2(frag_b[0], s_val_1_2); - frag_b[1] = __hmul2(frag_b[1], s_val_3_4); -} - -// Given 2 floats multiply by 2 scales (halves) -__device__ inline void scale_float(float* c, FragS& s) { - __half* s_ptr = reinterpret_cast<__half*>(&s); - c[0] = __fmul_rn(c[0], __half2float(s_ptr[0])); - c[1] = __fmul_rn(c[1], __half2float(s_ptr[1])); -} - -// Wait until barrier reaches `count`, then lock for current threadblock. -__device__ inline void barrier_acquire(int* lock, int count) { - if (threadIdx.x == 0) { - int state = -1; - do - // Guarantee that subsequent writes by this threadblock will be visible - // globally. - asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" - : "=r"(state) - : "l"(lock)); - while (state != count); - } - __syncthreads(); -} - -// Release barrier and increment visitation count. -__device__ inline void barrier_release(int* lock, bool reset = false) { - __syncthreads(); - if (threadIdx.x == 0) { - if (reset) { - lock[0] = 0; - return; - } - int val = 1; - // Make sure that all writes since acquiring this barrier are visible - // globally, while releasing the barrier. - asm volatile("fence.acq_rel.gpu;\n"); - asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" - : - : "l"(lock), "r"(val)); - } -} - -template shared - // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const bool has_zp, // whether zero-points are enabled - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__device__ void MarlinMoESingle( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int* __restrict__ sorted_ids, // int32 sorted ids of experts - const float* __restrict__ topk_weights, // float topk weights - const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape - // (k/groupsize)x(n/pack_factor) - const int* __restrict__ g_idx, // int32 group indices of shape k - const int* __restrict__ expert_offsets, - int num_groups, // number of scale groups per output channel - int expert_idx, // idx of current expert - int num_experts, // number of experts - int topk, // topk parameter of moe - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int tot_m, // total number of rows in A and C - int* locks, // extra global storage for barrier synchronization - bool replicate_input, // do we use the same input for each expert? - bool apply_weights, // apply weights to output - int current_m_block // current m block to start kernel computation from -) { - static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id); - constexpr int pack_factor = 32 / w_type.size_bits(); - - // For larger GEMMs we run multiple batchsize 64 versions in parallel for a - // better partitioning with less reductions - int parallel = 1; - if (prob_m > 16 * thread_m_blocks) { - parallel = prob_m / (16 * thread_m_blocks); - prob_m = 16 * thread_m_blocks; - } - - int k_tiles = prob_k / 16 / thread_k_blocks; - int n_tiles = prob_n / 16 / thread_n_blocks; - int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); - - if constexpr (!has_act_order && group_blocks != -1) { - if (group_blocks >= thread_k_blocks) { - // Ensure that the number of tiles in each stripe is a multiple of the - // groupsize; this avoids an annoying special case where a stripe starts - // in the middle of group. - iters = (group_blocks / thread_k_blocks) * - ceildiv(iters, (group_blocks / thread_k_blocks)); - } - } - - int slice_row = (iters * blockIdx.x) % k_tiles; - int slice_col_par = (iters * blockIdx.x) / k_tiles; - int slice_col = slice_col_par; - int slice_iters; // number of threadblock tiles in the current slice - int slice_count = - 0; // total number of active threadblocks in the current slice - int slice_idx; // index of threadblock in current slice; numbered bottom to - // top - - // We can easily implement parallel problem execution by just remapping - // indices and advancing global pointers - if (slice_col_par >= n_tiles) { - locks += (slice_col_par / n_tiles) * n_tiles; - slice_col = slice_col_par % n_tiles; - sorted_ids += (slice_col_par / n_tiles) * 16 * thread_m_blocks; - } - - // Compute all information about the current slice which is required for - // synchronization. - auto init_slice = [&]() { - slice_iters = - iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); - if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; - if (slice_iters == 0) return; - if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; - slice_count = 1; - slice_idx = 0; - int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); - if (col_first <= k_tiles * (slice_col_par + 1)) { - int col_off = col_first - k_tiles * slice_col_par; - slice_count = ceildiv(k_tiles - col_off, iters); - if (col_off > 0) slice_count++; - int delta_first = iters * blockIdx.x - col_first; - if (delta_first < 0 || (col_off == 0 && delta_first == 0)) - slice_idx = slice_count - 1; - else { - slice_idx = slice_count - 1 - delta_first / iters; - if (col_off > 0) slice_idx--; - } - } - if (slice_col == n_tiles) { - sorted_ids += 16 * thread_m_blocks; - locks += n_tiles; - slice_col = 0; - } - }; - init_slice(); - - // A sizes/strides - - // stride of the A matrix in global memory - int a_gl_stride = prob_k / 8; - // stride of an A matrix tile in shared memory - constexpr int a_sh_stride = 16 * thread_k_blocks / 8; - // delta between subsequent A tiles in global memory - constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; - // between subsequent accesses within a tile - int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); - // between shared memory writes - constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); - // between shared memory tile reads - constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); - // within a shared memory tile - constexpr int a_sh_rd_delta_i = a_sh_stride * 16; - // overall size of a tile - constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); - // number of shared write iterations for a tile - constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta); - - // B sizes/strides - int b_gl_stride = 16 * prob_n / (pack_factor * 4); - constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; - constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 1 : 2; - constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; - - int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; - int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); - constexpr int b_sh_wr_delta = threads * b_thread_vecs; - constexpr int b_sh_rd_delta = threads * b_thread_vecs; - constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; - constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; - - // Scale sizes/strides without act_order - int s_gl_stride = prob_n / 8; - constexpr int s_sh_stride = 16 * thread_n_blocks / 8; - constexpr int s_tb_groups = - !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks - ? thread_k_blocks / group_blocks - : 1; - constexpr int s_sh_stage = s_tb_groups * s_sh_stride; - int s_gl_rd_delta = s_gl_stride; - // Scale size/strides with act_order - constexpr int tb_k = 16 * thread_k_blocks; - constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0; - // constexpr int act_s_row_stride = 1; - // int act_s_col_stride = act_s_row_stride * num_groups; - int act_s_col_stride = 1; - int act_s_col_warp_stride = act_s_col_stride * 8; - int tb_n_warps = thread_n_blocks / 4; - int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; - - // Zero-points sizes/strides - int zp_gl_stride = (prob_n / pack_factor) / 4; - constexpr int zp_sh_stride = ((16 * thread_n_blocks) / pack_factor) / 4; - constexpr int zp_tb_groups = s_tb_groups; - constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0; - int zp_gl_rd_delta = zp_gl_stride; - - // Global A read index of current thread. - int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - a_gl_rd += a_gl_rd_delta_o * slice_row; - // Shared write index of current thread. - int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - // Shared read index. - int a_sh_rd = - a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; - a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); - - int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + - (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; - b_gl_rd += b_sh_stride * slice_col; - b_gl_rd += b_gl_rd_delta_o * slice_row; - int b_sh_wr = threadIdx.x * b_thread_vecs; - int b_sh_rd = threadIdx.x * b_thread_vecs; - - // For act_order - constexpr int k_iter_size = tb_k / b_sh_wr_iters; - int slice_k_start = tb_k * slice_row; - int slice_k_finish = slice_k_start + tb_k * slice_iters; - int slice_k_start_shared_fetch = slice_k_start; - int slice_n_offset = act_s_col_tb_stride * slice_col; - - // No act_order - int s_gl_rd; - if constexpr (!has_act_order) { - if constexpr (group_blocks == -1) { - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - } else { - s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + - s_sh_stride * slice_col + threadIdx.x; - } - } - int s_sh_wr = threadIdx.x; - bool s_sh_wr_pred = threadIdx.x < s_sh_stride; - - // Zero-points - int zp_gl_rd; - if constexpr (has_zp) { - if constexpr (group_blocks == -1) { - zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; - } else { - zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + - zp_sh_stride * slice_col + threadIdx.x; - } - } - int zp_sh_wr = threadIdx.x; - bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride; - - // We use a different scale layout for grouped and column-wise quantization as - // we scale a `half2` tile in column-major layout in the former and in - // row-major in the latter case. - int s_sh_rd; - if constexpr (group_blocks != -1) - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; - else - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) % 4; - - // Zero-points have the same read layout as the scales - // (without column-wise case) - constexpr int num_col_threads = 8; - constexpr int num_row_threads = 4; - constexpr int num_ints_per_thread = 8 / pack_factor; - int zp_sh_rd; - if constexpr (has_zp) { - zp_sh_rd = num_ints_per_thread * num_col_threads * - ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); - } - - int sh_first_group_id = -1; - int sh_num_groups = -1; - constexpr int sh_max_num_groups = 32; - - extern __shared__ int4 sh[]; - // Shared memory storage for global fetch pipelines. - int4* sh_a = sh; - int4* sh_b = sh_a + (stages * a_sh_stage); - int4* sh_g_idx = sh_b + (stages * b_sh_stage); - int4* sh_zp = sh_g_idx + (stages * g_idx_stage); - int4* sh_s = sh_zp + (stages * zp_sh_stage); - - // Precompute which thread should not read memory in which iterations; this is - // needed if there are more threads than required for a certain tilesize or - // when the batchsize is not a multiple of 16. - bool a_sh_wr_pred[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) { - int a_idx = a_sh_wr_delta * i + a_sh_wr; - int row = a_idx / a_gl_rd_delta_o; - if (row >= prob_m) { - a_sh_wr_pred[i] = false; - } else { - a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; - } - } - - // To ensure that writing and reading A tiles to/from shared memory, the - // latter in fragment format, is fully bank conflict free, we need to use a - // rather fancy XOR-based layout. The key here is that neither reads nor - // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the - // same shared memory banks. Further, it seems (based on NSight-Compute) that - // each warp must also write a consecutive memory segment? - auto transform_a = [&](int i) { - int row = i / a_gl_rd_delta_o; - return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; - }; - // Since the computation of this remapping is non-trivial and, due to our main - // loop unrolls, all shared memory accesses are static, we simply precompute - // both transformed reads and writes. - int a_sh_wr_trans[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); - int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < thread_m_blocks; j++) - a_sh_rd_trans[i][j] = - transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); - } - - // Since B-accesses have non-constant stride they have to be computed at - // runtime; we break dependencies between subsequent accesses with a tile by - // maintining multiple pointers (we have enough registers), a tiny - // optimization. - const int4* B_ptr[b_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; - - // Register storage for double buffer of shared memory reads. - FragA frag_a[2][thread_m_blocks]; - I4 frag_b_quant[2][b_thread_vecs]; - FragC frag_c[thread_m_blocks][4][2]; - FragS frag_s[2][4]; // No act-order - FragS act_frag_s[2][4][4]; // For act-order - int frag_qzp[2][num_ints_per_thread]; // Zero-points - FragZP frag_zp; // Zero-points in fp16 - - // Zero accumulators. - auto zero_accums = [&]() { - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) - reinterpret_cast(frag_c)[i] = 0; - }; - - auto fetch_scales_to_shared = [&](bool is_async, int first_group_id, - int last_group_id) { - sh_first_group_id = first_group_id; - sh_num_groups = last_group_id - first_group_id + 1; - - if (sh_num_groups < sh_max_num_groups) { - sh_num_groups = sh_max_num_groups; - } - - if (sh_first_group_id + sh_num_groups > num_groups) { - sh_num_groups = num_groups - sh_first_group_id; - } - - int row_offset = first_group_id * s_gl_stride; - - if (is_async) { - for (int i = 0; i < sh_num_groups; i++) { - if (threadIdx.x < s_sh_stride) { - cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], - &scales_ptr[row_offset + (i * s_gl_stride) + - slice_n_offset + threadIdx.x]); - } - } - } else { - for (int i = 0; i < sh_num_groups; i++) { - if (threadIdx.x < s_sh_stride) { - sh_s[(i * s_sh_stride) + threadIdx.x] = - scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + - threadIdx.x]; - } - } - } - }; - // Asynchronously fetch the next A, B and s tile from global to the next - // shared memory pipeline location. - auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { - if (pred) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) { - int a_idx = a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off; - int row = a_idx / a_gl_stride; - int sorted_row = - replicate_input ? sorted_ids[row] / topk : sorted_ids[row]; - int new_idx = sorted_row * a_gl_stride + a_idx % a_gl_stride; - if (sorted_row < tot_m * (replicate_input ? 1 : topk) && - new_idx < a_gl_stride * tot_m * (replicate_input ? 1 : topk)) { - cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[new_idx], - a_sh_wr_pred[i]); - } - } - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < b_thread_vecs; j++) { - cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); - } - B_ptr[i] += b_gl_rd_delta_o; - } - - if constexpr (has_act_order) { - // Fetch g_idx thread-block portion - int full_pipe = a_off; - int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; - if (cur_k < prob_k && cur_k < slice_k_finish) { - int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - - int4 const* cur_g_idx_stage_ptr = - reinterpret_cast(&g_idx[cur_k]); - - if (threadIdx.x < g_idx_stage) { - cp_async4_pred(&sh_g_idx_stage[threadIdx.x], - &cur_g_idx_stage_ptr[threadIdx.x]); - } - } - } else { - if constexpr (group_blocks != -1) { - int4* sh_s_stage = sh_s + s_sh_stage * pipe; - - if constexpr (group_blocks >= thread_k_blocks) { - // Only fetch scales if this tile starts a new group - if (pipe % (group_blocks / thread_k_blocks) == 0) { - if (s_sh_wr_pred) { - cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); - } - s_gl_rd += s_gl_rd_delta; - } - } else { - for (int i = 0; i < s_tb_groups; i++) { - if (s_sh_wr_pred) { - cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], - &scales_ptr[s_gl_rd]); - } - s_gl_rd += s_gl_rd_delta; - } - } - } - - if constexpr (has_zp && group_blocks != -1) { - int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; - - if constexpr (group_blocks >= thread_k_blocks) { - // Only fetch zero-points if this tile starts a new group - if (pipe % (group_blocks / thread_k_blocks) == 0) { - if (zp_sh_wr_pred) { - cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); - } - zp_gl_rd += zp_gl_rd_delta; - } - } else { - for (int i = 0; i < zp_tb_groups; i++) { - if (zp_sh_wr_pred) { - cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr], - &zp_ptr[zp_gl_rd]); - } - zp_gl_rd += zp_gl_rd_delta; - } - } - } - } - } - // Insert a fence even when we are winding down the pipeline to ensure that - // waiting is also correct at this point. - cp_async_fence(); - }; - - auto fetch_zp_to_shared = [&]() { - if (zp_sh_wr_pred) { - cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]); - } - }; - - // Wait until the next thread tile has been loaded to shared memory. - auto wait_for_stage = [&]() { - // We only have `stages - 2` active fetches since we are double buffering - // and can only issue the next fetch when it is guaranteed that the previous - // shared memory load is fully complete (as it may otherwise be - // overwritten). - cp_async_wait(); - __syncthreads(); - }; - - // Load the next sub-tile from the current location in the shared memory pipe - // into the current register buffer. - auto fetch_to_registers = [&](int k, int pipe) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) - ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - - #pragma unroll - for (int i = 0; i < b_thread_vecs; i++) { - frag_b_quant[k % 2][i] = *reinterpret_cast( - &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); - } - }; - - bool is_same_group[stages]; - int same_group_id[stages]; - - auto init_same_group = [&](int pipe) { - if constexpr (!has_act_order) { - is_same_group[pipe] = false; - same_group_id[pipe] = 0; - return; - } - - int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); - - int group_id_1 = sh_g_idx_int_ptr[0]; - int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; - - is_same_group[pipe] = group_id_1 == group_id_2; - same_group_id[pipe] = group_id_1; - }; - - auto fetch_scales_to_registers = [&](int k, int full_pipe) { - int pipe = full_pipe % stages; - - if constexpr (!has_act_order) { - // No act-order case - if constexpr (group_blocks != -1) { - if constexpr (group_blocks >= thread_k_blocks) { - int4* sh_s_stage = - sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; - } else { - int warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - - int warp_row = warp_id / n_warps; - - int cur_k = warp_row * 16; - cur_k += k_iter_size * (k % b_sh_wr_iters); - - int k_blocks = cur_k / 16; - int cur_group_id = k_blocks / group_blocks; - - int4* sh_s_stage = sh_s + s_sh_stage * pipe; - - reinterpret_cast(&frag_s[k % 2])[0] = - sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; - } - } - - return; - } - - // Act-order case - - // Determine K of the "current" thread-block - int cur_k = slice_k_start + tb_k * full_pipe; - if (cur_k >= prob_k || cur_k >= slice_k_finish) { - return; - } - - // Reset (to current thread-block) since we read g_idx portion from the - // shared memory - cur_k = 0; - - // Progress to current iteration - cur_k += k_iter_size * (k % b_sh_wr_iters); - - // Determine "position" inside the thread-block (based on warp and - // thread-id) - int warp_id = threadIdx.x / 32; - int n_warps = - thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N - - int warp_row = warp_id / n_warps; - int warp_col = warp_id % n_warps; - - cur_k += warp_row * 16; - - int th_id = threadIdx.x % 32; - cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix - - int s_col_shift = - /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + - (th_id / 4) * act_s_col_stride; - - if (is_same_group[pipe]) { - if (k % 2 == 0) { - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = - sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + - s_col_shift]; - } else { - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = - *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); - } - - for (int i = 1; i < 4; i++) { - *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); - } - return; - } - - int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); - - constexpr int k_frag_offsets[4] = {0, 1, 8, - 9}; // Tensor core offsets per thread - - #pragma unroll - for (int i = 0; i < 4; i++) { - int actual_k = cur_k + k_frag_offsets[i]; - - int group_id = sh_g_idx_int_ptr[actual_k]; - int rel_group_id = group_id - sh_first_group_id; - - *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = - sh_s[rel_group_id * s_sh_stride + s_col_shift]; - } - }; - - auto fetch_zp_to_registers = [&](int k, int full_pipe) { - // This code does not handle group_blocks == 0, - // which signifies act_order. - // has_zp implies AWQ, which doesn't have act_order, - static_assert(!has_zp || group_blocks != 0); - - if constexpr (has_zp) { - int pipe = full_pipe % stages; - - if constexpr (group_blocks == -1) { - for (int i = 0; i < num_ints_per_thread; i++) { - frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp))[zp_sh_rd + i]; - } - - } else if constexpr (group_blocks >= thread_k_blocks) { - int4* sh_zp_stage = - sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); - for (int i = 0; i < num_ints_per_thread; i++) { - frag_qzp[k % 2][i] = - (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; - } - } else { - int warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - - int warp_row = warp_id / n_warps; - - int cur_k = warp_row * 16; - cur_k += k_iter_size * (k % b_sh_wr_iters); - - int k_blocks = cur_k / 16; - int cur_group_id = 0; - - // Suppress bogus and persistent divide-by-zero warning - #pragma nv_diagnostic push - #pragma nv_diag_suppress divide_by_zero - cur_group_id = k_blocks / group_blocks; - #pragma nv_diagnostic pop - - int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; - - sh_zp_stage += cur_group_id * zp_sh_stride; - - for (int i = 0; i < num_ints_per_thread; i++) { - frag_qzp[k % 2][i] = - (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; - } - } - } - }; - - // Execute the actual tensor core matmul of a sub-tile. - auto matmul = [&](int k) { - if constexpr (has_zp) { - FragB frag_zp_0; - FragB frag_zp_1; - int zp_quant_0, zp_quant_1; - - if constexpr (w_type.size_bits() == 4) { - zp_quant_0 = frag_qzp[k % 2][0]; - zp_quant_1 = zp_quant_0 >> 8; - } else { - static_assert(w_type.size_bits() == 8); - zp_quant_0 = frag_qzp[k % 2][0]; - zp_quant_1 = frag_qzp[k % 2][1]; - } - - frag_zp_0 = dequant(zp_quant_0); - frag_zp_1 = dequant(zp_quant_1); - - frag_zp[0] = frag_zp_0[0]; - frag_zp[1] = frag_zp_0[1]; - frag_zp[2] = frag_zp_1[0]; - frag_zp[3] = frag_zp_1[1]; - } - - // We have the m dimension as the inner loop in order to encourage overlapping - // dequantization and matmul operations. - #pragma unroll - for (int j = 0; j < 4; j++) { - int b_quant_0, b_quant_1; - if constexpr (w_type.size_bits() == 4) { - b_quant_0 = frag_b_quant[k % 2][0][j]; - b_quant_1 = b_quant_0 >> 8; - } else { - static_assert(w_type.size_bits() == 8); - int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); - b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; - b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; - } - - FragB frag_b0 = dequant(b_quant_0); - FragB frag_b1 = dequant(b_quant_1); - // Apply zero-point to frag_b0 - if constexpr (has_zp) { - sub_zp(frag_b0, frag_zp[j], 0); - } - - // Apply scale to frag_b0 - if constexpr (has_act_order) { - scale4(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], - act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 0); - } else { - if constexpr (group_blocks != -1) { - scale(frag_b0, frag_s[k % 2][j], 0); - } - } - - // Apply zero-point to frag_b1 - if constexpr (has_zp) { - sub_zp(frag_b1, frag_zp[j], 1); - } - - // Apply scale to frag_b1 - if constexpr (has_act_order) { - scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], - act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 1); - - } else { - if constexpr (group_blocks != -1) { - scale(frag_b1, frag_s[k % 2][j], 1); - } - } - - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); - mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); - } - } - }; - - // Since we slice across the k dimension of a tile in order to increase the - // number of warps while keeping the n dimension of a tile reasonable, we have - // multiple warps that accumulate their partial sums of the same output - // location; which we have to reduce over in the end. We do in shared memory. - auto thread_block_reduce = [&]() { - constexpr int red_off = threads / b_sh_stride_threads / 2; - if (red_off >= 1) { - int red_idx = threadIdx.x / b_sh_stride_threads; - constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; - constexpr int red_sh_delta = b_sh_stride_threads; - int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + - (threadIdx.x % b_sh_stride_threads); - - // Parallel logarithmic shared memory reduction. We make sure to avoid any - // unnecessary read or write iterations, e.g., for two warps we write only - // once by warp 1 and read only once by warp 0. - - #pragma unroll - for (int m_block = 0; m_block < thread_m_blocks; m_block++) { - #pragma unroll - for (int i = red_off; i > 0; i /= 2) { - if (i <= red_idx && red_idx < 2 * i) { - #pragma unroll - for (int j = 0; j < 4 * 2; j++) { - int red_sh_wr = - red_sh_delta * j + (red_sh_rd - red_sh_stride * i); - if (i < red_off) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); - float* c_wr = reinterpret_cast(&sh[red_sh_wr]); - #pragma unroll - for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += - c_rd[k] + c_wr[k]; - } - sh[red_sh_wr] = - reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; - } - } - __syncthreads(); - } - if (red_idx == 0) { - #pragma unroll - for (int i = 0; i < 4 * 2; i++) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); - #pragma unroll - for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += - c_rd[j]; - } - } - __syncthreads(); - } - } - }; - - // Since multiple threadblocks may process parts of the same column slice, we - // finally have to globally reduce over the results. As the striped - // partitioning minimizes the number of such reductions and our outputs are - // usually rather small, we perform this reduction serially in L2 cache. - auto global_reduce = [&](bool first = false, bool last = false) { - // We are very careful here to reduce directly in the output buffer to - // maximize L2 cache utilization in this step. To do this, we write out - // results in FP16 (but still reduce with FP32 compute). - constexpr int active_threads = 32 * thread_n_blocks / 4; - if (threadIdx.x < active_threads) { - int c_gl_stride = prob_n / 8; - int c_gl_wr_delta_o = 8 * c_gl_stride; - int c_gl_wr_delta_i = 4 * (active_threads / 32); - int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + - 4 * (threadIdx.x / 32) + threadIdx.x % 4; - c_gl_wr += (2 * thread_n_blocks) * slice_col; - constexpr int c_sh_wr_delta = active_threads; - int c_sh_wr = threadIdx.x; - - int row = (threadIdx.x % 32) / 4; - - if (!first) { - // Interestingly, doing direct global accesses here really seems to mess up - // the compiler and lead to slowdowns, hence we also use async-copies even - // though these fetches are not actually asynchronous. - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - int c_idx = - c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); - int sorted_row = sorted_ids[c_idx / c_gl_stride]; - int new_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride; - cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], &C[new_idx], - sorted_row < tot_m * topk && - (8 * (i / 2) + row < prob_m && - (i < (thread_m_blocks - 1) * 4 || - sorted_ids[8 * (i / 2) + row] < tot_m * topk))); - } - cp_async_fence(); - cp_async_wait<0>(); - } - - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - if (8 * (i / 2) + row < prob_m && - (i < (thread_m_blocks - 1) * 4 || - sorted_ids[8 * (i / 2) + row] < tot_m * topk)) { - if (!first) { - int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; - #pragma unroll - for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += - __half2float(reinterpret_cast<__half*>(&c_red)[j]); - } - } - if (!last) { - int4 c; - #pragma unroll - for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast<__half*>(&c)[j] = - __float2half(reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); - } - int c_idx = - c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); - int row = sorted_ids[c_idx / c_gl_stride]; - if (row < tot_m * topk) { - int new_idx = row * c_gl_stride + c_idx % c_gl_stride; - C[new_idx] = c; - } - } - } - } - } - }; - - // Write out the reduce final result in the correct layout. We only actually - // reshuffle matrix fragments in this step, the reduction above is performed - // in fragment layout. - auto write_result = [&]() { - int c_gl_stride = prob_n / 8; - constexpr int c_sh_stride = 2 * thread_n_blocks + 1; - int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); - constexpr int c_sh_rd_delta = - c_sh_stride * (threads / (2 * thread_n_blocks)); - - int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - c_gl_wr += (2 * thread_n_blocks) * slice_col; - int c_sh_wr = - (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; - c_sh_wr += 32 * (threadIdx.x / 32); - int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - - int c_gl_wr_end = c_gl_stride * prob_m; - - // We first reorder in shared memory to guarantee the most efficient final - // global write patterns - auto write = [&](int idx, float c0, float c1, FragS& s) { - half2 res = __halves2half2(__float2half(c0), __float2half(c1)); - - // For per-column quantization we finally apply the scale here (only for - // 4-bit) - if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 4) { - res = __hmul2(res, s[0]); - } - - ((half2*)sh)[idx] = res; - }; - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - #pragma unroll - for (int j = 0; j < 4; j++) { - int wr = c_sh_wr + 8 * j; - write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], - frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], - frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], - frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); - write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], - frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); - } - c_sh_wr += 16 * (4 * c_sh_stride); - } - } - __syncthreads(); - - #pragma unroll - for (int i = 0; - i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); - i++) { - if (c_gl_wr < c_gl_wr_end) { - int row = sorted_ids[c_gl_wr / c_gl_stride]; - if (row < tot_m * topk) { - int off = row * c_gl_stride + c_gl_wr % c_gl_stride; - if (!apply_weights) { - C[off] = sh[c_sh_rd]; - } else { - __half* ctrg = reinterpret_cast<__half*>(&C[off]); - __half* csrc = reinterpret_cast<__half*>(&sh[c_sh_rd]); - for (int j = 0; j < 8; ++j) { - ctrg[j] = __float2half(topk_weights[row] * __half2float(csrc[j])); - } - } - c_gl_wr += c_gl_wr_delta; - c_sh_rd += c_sh_rd_delta; - } - } - } - }; - - // Start global fetch and register load pipelines. - auto start_pipes = [&]() { - - #pragma unroll - for (int i = 0; i < stages - 1; i++) { - if (has_act_order && i == 0) { - int last_g_idx = slice_k_start + stages * tb_k * 2; - if (last_g_idx >= prob_k) { - last_g_idx = prob_k - 1; - } - fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); - } - - if constexpr (has_zp && group_blocks == -1) { - if (i == 0) { - fetch_zp_to_shared(); - } - } - fetch_to_shared(i, i, i < slice_iters); - } - - zero_accums(); - wait_for_stage(); - init_same_group(0); - fetch_to_registers(0, 0); - fetch_scales_to_registers(0, 0); - fetch_zp_to_registers(0, 0); - a_gl_rd += a_gl_rd_delta_o * (stages - 1); - slice_k_start_shared_fetch += tb_k * (stages - 1); - }; - if (slice_iters) { - start_pipes(); - } - - // Main loop. - while (slice_iters) { - // We unroll over both the global fetch and the register load pipeline to - // ensure all shared memory accesses are static. Note that both pipelines - // have even length meaning that the next iteration will always start at - // index 0. - #pragma unroll - for (int pipe = 0; pipe < stages;) { - #pragma unroll - for (int k = 0; k < b_sh_wr_iters; k++) { - fetch_to_registers(k + 1, pipe % stages); - fetch_scales_to_registers(k + 1, pipe); - fetch_zp_to_registers(k + 1, pipe); - if (k == b_sh_wr_iters - 2) { - fetch_to_shared((pipe + stages - 1) % stages, pipe, - slice_iters >= stages); - pipe++; - wait_for_stage(); - init_same_group(pipe % stages); - } - matmul(k); - } - slice_iters--; - if (slice_iters == 0) { - break; - } - } - - a_gl_rd += a_gl_rd_delta_o * stages; - slice_k_start += tb_k * stages; - slice_k_start_shared_fetch += tb_k * stages; - - if constexpr (has_act_order) { - int first_group_id = g_idx[slice_k_start]; - int last_g_idx = slice_k_start + stages * tb_k * 2; - if (last_g_idx >= prob_k) { - last_g_idx = prob_k - 1; - } - int last_group_id = g_idx[last_g_idx]; - if (last_group_id >= sh_first_group_id + sh_num_groups) { - fetch_scales_to_shared(false, first_group_id, last_group_id); - __syncthreads(); - } - } - - // Process results and, if necessary, proceed to the next column slice. - // While this pattern may not be the most readable, other ways of writing - // the loop seemed to noticeably worse performance after compilation. - if (slice_iters == 0) { - cp_async_wait<0>(); - bool last = slice_idx == slice_count - 1; - if constexpr (!has_act_order && group_blocks == -1) { - if constexpr (w_type.size_bits() == 8) { - if (s_sh_wr_pred) { - cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); - } - cp_async_fence(); - } else { - // For 4-bit per-column scales, we only fetch them here in the - // final step before write-out - if (last) { - if (s_sh_wr_pred) { - cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); - } - cp_async_fence(); - } - } - } - - thread_block_reduce(); - if constexpr (!has_act_order && group_blocks == -1) { - if constexpr (w_type.size_bits() == 8) { - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; - } - - } else { - if (last) { - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; - } - } - } - } - - // For 8-bit channelwise, we apply the scale before the global reduction - // that converts the fp32 results to fp16 (so that we avoid possible - // overflow in fp16) - if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 8) { - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - #pragma unroll - for (int j = 0; j < 4; j++) { - scale_float(reinterpret_cast(&frag_c[i][j][0][0]), - frag_s[j / 2][2 * (j % 2) + 0]); - scale_float(reinterpret_cast(&frag_c[i][j][0][2]), - frag_s[j / 2][2 * (j % 2) + 0]); - - scale_float(reinterpret_cast(&frag_c[i][j][1][0]), - frag_s[j / 2][2 * (j % 2) + 1]); - scale_float(reinterpret_cast(&frag_c[i][j][1][2]), - frag_s[j / 2][2 * (j % 2) + 1]); - } - } - } - } - - if (slice_count > 1) { // only globally reduce if there is more than one - // block in a slice - barrier_acquire(&locks[slice_col], slice_idx); - global_reduce(slice_idx == 0, last); - barrier_release(&locks[slice_col], last); - } - if (last) // only the last block in a slice actually writes the result - write_result(); - slice_row = 0; - slice_col_par++; - slice_col++; - init_slice(); - if (slice_iters) { - a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; - if (slice_col == 0) { - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; - } - - // Update slice k/n for scales loading - if constexpr (has_act_order) { - slice_k_start = tb_k * slice_row; - slice_k_finish = slice_k_start + tb_k * slice_iters; - slice_k_start_shared_fetch = slice_k_start; - slice_n_offset = act_s_col_tb_stride * slice_col; - - } else { - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; - } - - start_pipes(); - } - } - } -} - -template shared - // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const bool has_zp, // whether zero-points are enabled - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void MarlinMoE( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int* __restrict__ sorted_ids_base, // int32 sorted ids of experts - const float* __restrict__ topk_weights, // float topk weights - const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape - // (k/groupsize)x(n/pack_factor) - const int* __restrict__ g_idx, // int32 group indices of shape k - const int* __restrict__ expert_offsets, - int num_groups, // number of scale groups per output channel - int expert_idx, // idx of current expert - int num_experts, // number of experts - int topk, // topk parameter of moe - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int tot_m, // total number of rows in A and C - int* locks, // extra global storage for barrier synchronization - bool replicate_input, // do we use the same input for each expert? - bool apply_weights, // apply weights to output - int current_m_block, // current m block to start kernel computation from - int max_par, // maximum parallelism - int cfg_max_m_blocks // upper bound on m blocks -) { - int m_block_ctr = current_m_block; - - const int* sorted_ids_expert = - sorted_ids_base + expert_offsets[expert_idx] + m_block_ctr * 4 * max_par; - int tot_its = expert_offsets[expert_idx + 1] - expert_offsets[expert_idx]; - if (tot_its == 0) { - return; - } - int tot_m_blocks = ceildiv(tot_its, 16); - int pad = 16 * tot_m_blocks - tot_its; - - if (m_block_ctr >= tot_m_blocks) { - return; - } - - int max_block = tot_m_blocks - m_block_ctr; - prob_m = tot_its - 16 * m_block_ctr; - - int par = 1; - if (max_block > cfg_max_m_blocks) { - // Note that parallel > 1 currently only works for inputs without any - // padding - par = (16 * max_block - pad) / (16 * cfg_max_m_blocks); - if (par > max_par) par = max_par; - prob_m = (16 * cfg_max_m_blocks) * par; - m_block_ctr += cfg_max_m_blocks * (par - 1); - max_block = cfg_max_m_blocks; - } - - if (max_block == 1) { - MarlinMoESingle( - A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, - expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, - prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, - current_m_block); - } else if (max_block == 2) { - MarlinMoESingle( - A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, - expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, - prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, - current_m_block); - } else if (max_block == 3) { - MarlinMoESingle( - A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, - expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, - prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, - current_m_block); - } else { - MarlinMoESingle( - A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, - expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, - prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, - current_m_block); - } -} - -#else - -template shared - // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const bool has_zp, // whether zero-points are enabled - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void MarlinMoE( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int* __restrict__ sorted_ids, // int32 sorted ids of experts - const float* __restrict__ topk_weights, // float topk weights - const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape - // (k/groupsize)x(n/pack_factor) - const int* __restrict__ g_idx, // int32 group indices of shape k - const int* __restrict__ expert_offsets, - int num_groups, // number of scale groups per output channel - int expert_idx, // idx of current expert - int num_experts, // number of experts - int topk, // topk parameter of moe - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int tot_m, // total number of rows in A and C - int* locks, // extra global storage for barrier synchronization - bool replicate_input, // do we use the same input for each expert? - bool apply_weights, // apply weights to output - int current_m_block, // current m block to start kernel computation from - int max_par, // maximum parallelism - int cfg_max_m_blocks // upper bound on m blocks -) { - // Marlin is not implemented yet for SM < 8.0 - assert(false); - return; -} - -#endif - -// 8 warps are a good choice since every SM has 4 schedulers and having more -// than 1 warp per schedule allows some more latency hiding. At the same time, -// we want relatively few warps to have many registers per warp and small tiles. -const int USER_THREADS = - 256; // Note: This is only used with user-provided thread_k/n -const int STAGES = 4; // 4 pipeline stages fit into shared memory - -static constexpr int min_thread_n = 64; -static constexpr int min_thread_k = 64; - -#define __CALL_IF_MOE(W_TYPE, THREAD_N_BLOCKS, THREAD_K_BLOCKS, HAS_ACT_ORDER, \ - HAS_ZP, GROUP_BLOCKS, NUM_THREADS) \ - else if (q_type == W_TYPE && thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \ - group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ - cudaFuncSetAttribute( \ - MarlinMoE, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - MarlinMoE \ - <<>>( \ - A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ - zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ - num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ - replicate_input, apply_weights, m_block, max_par, \ - cfg_max_m_blocks); \ - } - -#define GPTQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) - -#define AWQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ - __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) - -} // namespace marlin_moe diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu deleted file mode 100644 index 77bc0dd90edde03ab80a0a77241bacfbd4955712..0000000000000000000000000000000000000000 --- a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu +++ /dev/null @@ -1,31 +0,0 @@ -#include "marlin_moe_kernel_ku4.h" - -namespace marlin_moe { - -// We return bool so we can create these different kernel calls as a sequence -// of if-elseif's. -bool call_marlin_moe_kernel_ku4( - vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, - bool has_act_order, int group_blocks, int num_threads, int blocks, - int max_shared_mem, cudaStream_t stream, const int4* A_ptr, - const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, - const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, - const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, - int expert_idx, int num_experts, int topk, int prob_m, int prob_n, - int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, - int m_block, int max_par, int cfg_max_m_blocks) { - bool has_zp = true; - - if (false) { - } - AWQ_CALL_IF_MOE(vllm::kU4, 16, 4, 256) - AWQ_CALL_IF_MOE(vllm::kU4, 8, 8, 256) - AWQ_CALL_IF_MOE(vllm::kU4, 8, 4, 128) - AWQ_CALL_IF_MOE(vllm::kU4, 4, 8, 128) - else { - return false; - } - return true; -} - -} // namespace marlin_moe diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h deleted file mode 100644 index 833fadf37721f93717e060c7c0379588704eb777..0000000000000000000000000000000000000000 --- a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h +++ /dev/null @@ -1,20 +0,0 @@ -#pragma once - -#include "marlin_moe_kernel.h" - -namespace marlin_moe { - -// We return bool so we can create these different kernel calls as a sequence -// of if-elseif's. -bool call_marlin_moe_kernel_ku4( - vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, - bool has_act_order, int group_blocks, int num_threads, int blocks, - int max_shared_mem, cudaStream_t stream, const int4* A_ptr, - const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, - const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, - const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, - int expert_idx, int num_experts, int topk, int prob_m, int prob_n, - int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, - int m_block, int max_par, int cfg_max_m_blocks); - -} // namespace marlin_moe diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu deleted file mode 100644 index f7e57b037594539ed77d69d1968b01cc2d506afa..0000000000000000000000000000000000000000 --- a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu +++ /dev/null @@ -1,31 +0,0 @@ -#include "marlin_moe_kernel_ku4b8.h" - -namespace marlin_moe { - -// We return bool so we can create these different kernel calls as a sequence -// of if-elseif's. -bool call_marlin_moe_kernel_ku4b8( - vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, - bool has_act_order, int group_blocks, int num_threads, int blocks, - int max_shared_mem, cudaStream_t stream, const int4* A_ptr, - const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, - const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, - const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, - int expert_idx, int num_experts, int topk, int prob_m, int prob_n, - int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, - int m_block, int max_par, int cfg_max_m_blocks) { - bool has_zp = false; - - if (false) { - } - GPTQ_CALL_IF_MOE(vllm::kU4B8, 16, 4, 256) - GPTQ_CALL_IF_MOE(vllm::kU4B8, 8, 8, 256) - GPTQ_CALL_IF_MOE(vllm::kU4B8, 8, 4, 128) - GPTQ_CALL_IF_MOE(vllm::kU4B8, 4, 8, 128) - else { - return false; - } - return true; -} - -} // namespace marlin_moe diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h deleted file mode 100644 index 494da8f10e26255d3f408ce7c8d87c53f67af845..0000000000000000000000000000000000000000 --- a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h +++ /dev/null @@ -1,20 +0,0 @@ -#pragma once - -#include "marlin_moe_kernel.h" - -namespace marlin_moe { - -// We return bool so we can create these different kernel calls as a sequence -// of if-elseif's. -bool call_marlin_moe_kernel_ku4b8( - vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, - bool has_act_order, int group_blocks, int num_threads, int blocks, - int max_shared_mem, cudaStream_t stream, const int4* A_ptr, - const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, - const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, - const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, - int expert_idx, int num_experts, int topk, int prob_m, int prob_n, - int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, - int m_block, int max_par, int cfg_max_m_blocks); - -} // namespace marlin_moe diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu deleted file mode 100644 index a901f0b11cd786b9f5574823939fe3f3c40783af..0000000000000000000000000000000000000000 --- a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu +++ /dev/null @@ -1,31 +0,0 @@ -#include "marlin_moe_kernel_ku8b128.h" - -namespace marlin_moe { - -// We return bool so we can create these different kernel calls as a sequence -// of if-elseif's. -bool call_marlin_moe_kernel_ku8b128( - vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, - bool has_act_order, int group_blocks, int num_threads, int blocks, - int max_shared_mem, cudaStream_t stream, const int4* A_ptr, - const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, - const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, - const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, - int expert_idx, int num_experts, int topk, int prob_m, int prob_n, - int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, - int m_block, int max_par, int cfg_max_m_blocks) { - bool has_zp = false; - - if (false) { - } - GPTQ_CALL_IF_MOE(vllm::kU8B128, 16, 4, 256) - GPTQ_CALL_IF_MOE(vllm::kU8B128, 8, 8, 256) - GPTQ_CALL_IF_MOE(vllm::kU8B128, 8, 4, 128) - GPTQ_CALL_IF_MOE(vllm::kU8B128, 4, 8, 128) - else { - return false; - } - return true; -} - -} // namespace marlin_moe diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h deleted file mode 100644 index f3018aa0c1ab7938ab99bbd0604d7b462d216721..0000000000000000000000000000000000000000 --- a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h +++ /dev/null @@ -1,18 +0,0 @@ -#pragma once - -#include "marlin_moe_kernel.h" - -namespace marlin_moe { - -bool call_marlin_moe_kernel_ku8b128( - vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, - bool has_act_order, int group_blocks, int num_threads, int blocks, - int max_shared_mem, cudaStream_t stream, const int4* A_ptr, - const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, - const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, - const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, - int expert_idx, int num_experts, int topk, int prob_m, int prob_n, - int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, - int m_block, int max_par, int cfg_max_m_blocks); - -} diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu deleted file mode 100644 index 5f12483e951e849f5e0575bfd361ef6caafcf983..0000000000000000000000000000000000000000 --- a/csrc/moe/marlin_moe_ops.cu +++ /dev/null @@ -1,588 +0,0 @@ -/* - * Modified by Neural Magic - * Copyright (C) Marlin.2024 Elias Frantar - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#include -#include -#include -#include -#include - -#include - -#include "core/exception.hpp" -#include "core/scalar_type.hpp" -#include "core/registration.h" -#include "marlin_kernels/marlin_moe_kernel_ku4b8.h" -#include "marlin_kernels/marlin_moe_kernel_ku8b128.h" -#include "marlin_kernels/marlin_moe_kernel_ku4.h" - -template -inline std::string str(T x) { - return std::to_string(x); -} - -namespace marlin_moe { - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - -// For a given "a" of size [M,K] performs a permutation of the K columns based -// on the given "perm" indices. -__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, - int const* __restrict__ perm_int_ptr, - int4* __restrict__ out_int4_ptr, int size_m, - int size_k, int block_rows) { - int start_row = block_rows * blockIdx.x; - int finish_row = start_row + block_rows; - if (finish_row > size_m) { - finish_row = size_m; - } - int cur_block_rows = finish_row - start_row; - - int row_stride = size_k * sizeof(half) / 16; - - auto permute_row = [&](int row) { - int iters = size_k / blockDim.x; - int rest = size_k % blockDim.x; - - int offset = row * row_stride; - - half const* a_row_half = reinterpret_cast(a_int4_ptr + offset); - half* out_half = reinterpret_cast(out_int4_ptr + offset); - - int base_k = 0; - - for (int i = 0; i < iters; i++) { - int cur_k = base_k + threadIdx.x; - int src_pos = perm_int_ptr[cur_k]; - - out_half[cur_k] = a_row_half[src_pos]; - - base_k += blockDim.x; - } - - if (rest) { - if (threadIdx.x < rest) { - int cur_k = base_k + threadIdx.x; - int src_pos = perm_int_ptr[cur_k]; - - out_half[cur_k] = a_row_half[src_pos]; - } - } - }; - - for (int i = 0; i < cur_block_rows; i++) { - int cur_row = start_row + i; - if (cur_row < size_m) { - permute_row(cur_row); - } - } -} - -__global__ void compute_expert_offsets(int const* __restrict__ topk_ids, - int* __restrict__ expert_offsets, - int topk_length, int block_size) { - int expert_id = threadIdx.x; - int num_experts = blockDim.x; - - int occurrences = 0; - for (int i = 0; i < topk_length; ++i) { - occurrences += (topk_ids[i] == expert_id); - } - expert_offsets[expert_id + 1] = occurrences; - __syncthreads(); - - if (threadIdx.x == 0) { - int tot_offset = 0; - expert_offsets[0] = 0; - for (int i = 0; i < num_experts; ++i) { - tot_offset += ceildiv(expert_offsets[i + 1], block_size) * block_size; - expert_offsets[i + 1] = tot_offset; - } - } - __syncthreads(); -} - -#else - -__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, - int const* __restrict__ perm_int_ptr, - int4* __restrict__ out_int4_ptr, int size_m, - int size_k, int block_rows) { - // Marlin is not implemented yet for SM < 8.0 - assert(false); - return; -} - -__global__ void compute_expert_offsets(int const* __restrict__ topk_ids, - int* __restrict__ expert_offsets, - int topk_length, int block_size) { - // Marlin is not implemented yet for SM < 8.0 - assert(false); - return; -} - -#endif - -typedef struct { - int thread_k; - int thread_n; - int num_threads; -} thread_config_t; - -typedef struct { - int max_m_blocks; - thread_config_t tb_cfg; -} exec_config_t; - -thread_config_t small_batch_thread_configs[] = { - // Ordered by priority - - // thread_k, thread_n, num_threads - {128, 128, 256}, // Default - {128, 64, 128}, // Reduce N 2X, same K - {64, 256, 256}, // Reduce K 2X, increase N 2X - {64, 128, 128}, // Reduce K 2X, same N - {64, 64, 128}, // Reduce both 2X -}; - -thread_config_t large_batch_thread_configs[] = { - // Ordered by priority - - // thread_k, thread_n, num_threads - {64, 256, 256}, // Default - {128, 128, 256}, // Reduce N 2X, increase K 2X - {64, 128, 128}, // Reduce N 2X, same K - {128, 64, 128}, // Reduce N 4X, increase K 2X - {64, 64, 128}, // Reduce N 4X, same K -}; - -int get_scales_cache_size(thread_config_t const& th_config, int prob_m, - int prob_n, int prob_k, int num_bits, int group_size, - bool has_act_order, bool is_k_full) { - bool cache_scales_chunk = has_act_order && !is_k_full; - - int tb_n = th_config.thread_n; - int tb_k = th_config.thread_k; - - // Get max scale groups per thread-block - int tb_groups; - if (group_size == -1) { - tb_groups = 1; - } else if (group_size == 0) { - tb_groups = ceildiv(tb_k, 32); // Worst case is 32 group size - } else { - tb_groups = ceildiv(tb_k, group_size); - } - - if (cache_scales_chunk) { - int load_groups = - tb_groups * STAGES * 2; // Chunk size is 2x pipeline over dim K - load_groups = max(load_groups, 32); // We load at least 32 scale groups - return load_groups * tb_n * 4; - - } else { - int tb_scales = tb_groups * tb_n * 2; - - return tb_scales * STAGES; - } -} - -bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks, - int prob_m, int prob_n, int prob_k, int num_bits, - int scales_cache_size, int max_shared_mem) { - int pack_factor = 32 / num_bits; - - // Get B size - int tb_k = th_config.thread_k; - int tb_n = th_config.thread_n; - - int b_size = (tb_k * tb_n / pack_factor) * 4; - - // Get A size - int m_blocks = ceildiv(prob_m, 16); - int tb_max_m = 16; - - while (true) { - if (m_blocks >= max_m_blocks) { - tb_max_m *= max_m_blocks; - break; - } - - max_m_blocks--; - if (max_m_blocks == 0) { - TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks); - } - } - - int a_size = (tb_max_m * tb_k) * 2; - - float pipe_size = (a_size + b_size) * STAGES; - - TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity - - return pipe_size < 0.95f * (max_shared_mem - scales_cache_size); -} - -bool is_valid_config(thread_config_t const& th_config, int max_m_blocks, - int prob_m, int prob_n, int prob_k, int num_bits, - int group_size, bool has_act_order, bool is_k_full, - int max_shared_mem) { - // Sanity - if (th_config.thread_k == -1 || th_config.thread_n == -1 || - th_config.num_threads == -1) { - return false; - } - - // Verify K/N are divisible by thread K/N - if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { - return false; - } - - // thread_k can be only 128 or 64 (because it must be less than groupsize - // which is 128) - if (th_config.thread_k != 128 && th_config.thread_k != 64) { - return false; - } - - // Verify min for thread K/N - if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { - return false; - } - - // num_threads must be at least 128 (= 4 warps) - if (th_config.num_threads < 128) { - return false; - } - - // Determine cache for scales - int scales_cache_size = - get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, - group_size, has_act_order, is_k_full); - - // Check that pipeline fits into cache - if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k, - num_bits, scales_cache_size, max_shared_mem)) { - return false; - } - - return true; -} - -exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, - int num_bits, int group_size, - bool has_act_order, bool is_k_full, - int max_shared_mem) { - int max_m_blocks = 4; - while (max_m_blocks > 0) { - if (prob_m <= 16) { - for (auto th_config : small_batch_thread_configs) { - if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, - num_bits, group_size, has_act_order, is_k_full, - max_shared_mem)) { - return exec_config_t{max_m_blocks, th_config}; - } - } - } else { - for (auto th_config : large_batch_thread_configs) { - if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, - num_bits, group_size, has_act_order, is_k_full, - max_shared_mem)) { - return exec_config_t{max_m_blocks, th_config}; - } - } - } - - max_m_blocks--; // Process less M blocks per invocation to reduce cache - // usage - } - - return exec_config_t{0, {-1, -1, -1}}; -} - -#define CALL_MOE_KERNEL_FUNCTION(KERNEL_FUNCTION) \ - else if (KERNEL_FUNCTION( \ - q_type, thread_n_blocks, thread_k_blocks, has_act_order, \ - group_blocks, num_threads, blocks, max_shared_mem, stream, \ - A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ - zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ - num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ - replicate_input, apply_weights, m_block, max_par, \ - exec_cfg.max_m_blocks)) { \ - } - -void marlin_mm_moe(const void* A, const void* B, void* C, - const void* sorted_ids, const void* topk_weights, - const void* topk_ids, const void* s, void* zp, - const void* g_idx, const void* perm, void* a_tmp, - void* expert_offsets, int prob_m, int prob_n, int prob_k, - void* workspace, vllm::ScalarType const& q_type, - bool has_act_order, bool is_k_full, bool has_zp, - int num_groups, int group_size, int num_experts, int topk, - int moe_block_size, int dev, cudaStream_t stream, - int thread_k, int thread_n, int sms, int max_par, - bool replicate_input, bool apply_weights) { - TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, - ", ", prob_n, ", ", prob_k, "]"); - - if (sms == -1) { - cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); - } - - int max_shared_mem = 0; - cudaDeviceGetAttribute(&max_shared_mem, - cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); - TORCH_CHECK(max_shared_mem > 0); - - int num_bits = q_type.size_bits(); - - // Set thread config - exec_config_t exec_cfg; - if (thread_k != -1 && thread_n != -1) { - // User-defined config - exec_cfg = - exec_config_t{4, thread_config_t{thread_k, thread_n, USER_THREADS}}; - } else { - // Auto config - exec_cfg = - determine_thread_config(prob_m, prob_n, prob_k, num_bits, group_size, - has_act_order, is_k_full, max_shared_mem); - } - - TORCH_CHECK(exec_cfg.max_m_blocks > 0 && - is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks, - prob_m, prob_n, prob_k, num_bits, group_size, - has_act_order, is_k_full, max_shared_mem), - "Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks, - ", thread_k = ", exec_cfg.tb_cfg.thread_k, - ", thread_n = ", exec_cfg.tb_cfg.thread_n, - ", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [", - prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits, - ", group_size = ", group_size, - ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full, - ", max_shared_mem = ", max_shared_mem); - - int num_threads = exec_cfg.tb_cfg.num_threads; - thread_k = exec_cfg.tb_cfg.thread_k; - thread_n = exec_cfg.tb_cfg.thread_n; - - int thread_k_blocks = thread_k / 16; - int thread_n_blocks = thread_n / 16; - - int blocks = sms; - - TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, - " is not divisible by thread_n = ", thread_n); - TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, - " is not divisible by thread_k = ", thread_k); - - int group_blocks = 0; - if (has_act_order) { - if (is_k_full) { - TORCH_CHECK(group_size != -1); - group_blocks = group_size / 16; - TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, - " is not divisible by group_blocks = ", group_blocks); - } else { - TORCH_CHECK(group_size == 0); - group_blocks = 0; - } - - } else { - if (group_size == -1) { - group_blocks = -1; - } else { - group_blocks = group_size / 16; - TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, - " is not divisible by group_blocks = ", group_blocks); - } - } - - int tot_m = prob_m; - - const int* topk_ids_ptr = (const int*)topk_ids; - int* expert_offsets_ptr = (int*)expert_offsets; - compute_expert_offsets<<<1, num_experts, 0, stream>>>( - topk_ids_ptr, expert_offsets_ptr, tot_m * topk, moe_block_size); - - bool do_permute_a = has_act_order; - - // If we have a full K, then we can run the non-act-order version of Marlin - // (since the weight rows are reordered by increasing group ids, and by - // having a full K, we have full original groups) - if (is_k_full) { - has_act_order = false; - } - - int pack_factor = 32 / q_type.size_bits(); - - for (int expert_idx = 0; expert_idx < num_experts; ++expert_idx) { - const int4* A_ptr = (const int4*)A; - int4* a_tmp_ptr = (int4*)a_tmp; - const int4* B_ptr = - (const int4*)B + (prob_n * prob_k / (pack_factor * 4)) * expert_idx; - int4* C_ptr = (int4*)C; - const float* topk_weights_ptr = (const float*)topk_weights; - const int* sorted_ids_ptr = (const int*)sorted_ids; - const int4* s_ptr = (const int4*)s + num_groups * prob_n / 8 * expert_idx; - const int4* zp_ptr = - (const int4*)zp + num_groups * prob_n / (pack_factor * 4) * expert_idx; - const int* g_idx_ptr = (const int*)g_idx + prob_k * expert_idx; - const int* perm_ptr = (const int*)perm + prob_k * expert_idx; - int* locks = (int*)workspace; - - if (do_permute_a) { - // Permute A columns - int topk_rows = replicate_input ? tot_m : tot_m * topk; - int block_rows = ceildiv(topk_rows, blocks); - permute_cols_kernel<<>>( - A_ptr, perm_ptr, a_tmp_ptr, topk_rows, prob_k, block_rows); - A_ptr = a_tmp_ptr; - } - - int tot_m_blocks = ceildiv(tot_m, 16); - for (int m_block = 0; m_block < tot_m_blocks; - m_block += 4 * exec_cfg.max_m_blocks) { - if (false) { - } - CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4b8) - CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku8b128) - CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4) - else { - TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + - str(prob_n) + ", " + str(prob_k) + "]" + - ", has_act_order = " + str(has_act_order) + - ", num_groups = " + str(num_groups) + - ", group_size = " + str(group_size) + - ", thread_n_blocks = " + str(thread_n_blocks) + - ", thread_k_blocks = " + str(thread_k_blocks)); - } - } - } -} - -} // namespace marlin_moe - -torch::Tensor marlin_gemm_moe( - const torch::Tensor& a, const torch::Tensor& b_q_weights, - const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights, - const torch::Tensor& topk_ids, const torch::Tensor& b_scales, - torch::Tensor& b_zeros, const torch::Tensor& g_idx, - const torch::Tensor& perm, torch::Tensor& workspace, - vllm::ScalarTypeId const b_q_type_id, int64_t size_m, int64_t size_n, - int64_t size_k, bool is_k_full, int64_t num_experts, int64_t topk, - int64_t moe_block_size, bool replicate_input, bool apply_weights) { - vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id); - bool has_zp = b_zeros.size(1) != 0; - if (has_zp) { - TORCH_CHECK( - b_q_type == vllm::kU4, - "b_q_type must be u4 when has_zp = True. Got = ", b_q_type.str()); - } else { - TORCH_CHECK( - b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128, - "b_q_type must be uint4b8 or uint8b128. Got = ", b_q_type.str()); - } - - int pack_factor = 32 / b_q_type.size_bits(); - - int max_par = 4; - - int dev = a.get_device(); - - auto options_dtype = - torch::TensorOptions().dtype(a.dtype()).device(a.device()); - auto options_int = - torch::TensorOptions().dtype(torch::kInt).device(a.device()); - torch::Tensor c = torch::zeros({size_m, topk, size_n}, options_dtype); - torch::Tensor a_tmp = - replicate_input ? torch::zeros({size_m, size_k}, options_dtype) - : torch::zeros({size_m, topk, size_k}, options_dtype); - torch::Tensor expert_offsets = torch::empty({num_experts + 1}, options_int); - - // thread_k: `k` size of a thread_tile in `weights` (can usually be left as - // auto -1) - int thread_k = -1; - // thread_n: `n` size of a thread_tile in `weights` (can usually be left as - // auto -1) - int thread_n = -1; - // sms: number of SMs to use for the kernel (can usually be left as auto -1) - int sms = -1; - - // Detect groupsize and act_order - int num_groups = -1; - int group_size = -1; - bool has_act_order = g_idx.size(1) != 0; - - int b_rank = b_scales.sizes().size(); - TORCH_CHECK(b_rank == 3, "b_scales rank = ", b_rank, " is not 3"); - TORCH_CHECK(b_scales.size(2) == size_n, "b_scales dim 2 = ", b_scales.size(2), - " is not size_n = ", size_n); - num_groups = b_scales.size(1); - - TORCH_CHECK(VLLM_IMPLIES(!is_k_full, has_act_order), - "if is_k_full is false, has_act_order must be true"); - - if (has_act_order) { - if (is_k_full) { - TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1"); - TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k, - ", is not divisible by num_groups = ", num_groups); - group_size = size_k / num_groups; - } else { - group_size = 0; - } - - } else { - if (num_groups > 1) { - TORCH_CHECK( - size_k % num_groups == 0, "size_k = ", size_k, - ", is not divisible by b_scales.size(0) = ", b_scales.size(0)); - group_size = size_k / num_groups; - } else { - group_size = -1; - } - } - - // Verify b_zeros - if (has_zp) { - int rank = b_zeros.sizes().size(); - TORCH_CHECK(rank == 3, "b_zeros rank = ", rank, " is not 3"); - TORCH_CHECK(b_zeros.size(1) == num_groups, - "b_zeros dim 1 = ", b_zeros.size(1), - " is not num_groups = ", num_groups); - TORCH_CHECK(b_zeros.size(2) == size_n / pack_factor, - "b_zeros dim 2 = ", b_zeros.size(2), - " is not size_n / pack_factor = ", size_n / pack_factor); - } - - marlin_moe::marlin_mm_moe( - a.data_ptr(), b_q_weights.data_ptr(), c.data_ptr(), sorted_ids.data_ptr(), - topk_weights.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(), - b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), - expert_offsets.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(), - b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size, - num_experts, topk, moe_block_size, dev, - at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par, - replicate_input, apply_weights); - return c; -} - -TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { - m.impl("marlin_gemm_moe", &marlin_gemm_moe); -} diff --git a/csrc/moe/marlin_moe_wna16/.gitignore b/csrc/moe/marlin_moe_wna16/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..77088552b85b4f8614516756a31ac1fb673266f3 --- /dev/null +++ b/csrc/moe/marlin_moe_wna16/.gitignore @@ -0,0 +1 @@ +kernel_*.cu \ No newline at end of file diff --git a/csrc/moe/marlin_moe_wna16/generate_kernels.py b/csrc/moe/marlin_moe_wna16/generate_kernels.py index d1c0d92f6814a080084fec7e77531f316a1e6373..15f008d4f61ed66ad7e0df5643796ed6176b2af5 100644 --- a/csrc/moe/marlin_moe_wna16/generate_kernels.py +++ b/csrc/moe/marlin_moe_wna16/generate_kernels.py @@ -25,15 +25,16 @@ TEMPLATE = ("template __global__ void Marlin<" "{{thread_k_blocks}}, " "{{'true' if m_block_size_8 else 'false'}}, " "{{stages}}, " - "{{'true' if has_act_order else 'false'}}, " - "{{'true' if has_zp else 'false'}}, " "{{group_blocks}}, " "{{'true' if is_zp_float else 'false'}}>" "( MARLIN_KERNEL_PARAMS );") # int8 with zero point case (vllm::kU8) is also supported, # we don't add it to reduce wheel size. -SCALAR_TYPES = ["vllm::kU4", "vllm::kU4B8", "vllm::kU8B128"] +SCALAR_TYPES = [ + "vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn", + "vllm::kFE2M1f" +] THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)] THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4] @@ -41,7 +42,7 @@ THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4] # = 0 : act order case # = -1 : channelwise quantization # > 0 : group_size=16*group_blocks -GROUP_BLOCKS = [0, -1, 2, 4, 8] +GROUP_BLOCKS = [0, -1, 1, 2, 4, 8] DTYPES = ["fp16", "bf16"] @@ -52,21 +53,35 @@ def remove_old_kernels(): def generate_new_kernels(): for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES): - has_zp = "B" not in scalar_type all_template_str_list = [] for group_blocks, m_blocks, thread_configs in itertools.product( GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS): - has_act_order = group_blocks == 0 - if has_zp and has_act_order: + # act order case only support gptq-int4 and gptq-int8 + if group_blocks == 0 and scalar_type not in [ + "vllm::kU4B8", "vllm::kU8B128" + ]: continue if thread_configs[2] == 256: + # for small batch (m_blocks == 1), we only need (128, 128, 256) + # for large batch (m_blocks > 1), we only need (64, 256, 256) if m_blocks <= 1 and thread_configs[0] != 128: continue if m_blocks > 1 and thread_configs[0] != 64: continue + # we only support channelwise quantization and group_size == 128 + # for fp8 + if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]: + continue + # nvfp4 only supports group_size == 16 + if scalar_type == "vllm::kFE2M1f" and group_blocks not in [1, 2]: + continue + # other quantization methods don't support group_size = 16 + if scalar_type != "vllm::kFE2M1f" and group_blocks == 1: + continue + k_blocks = thread_configs[0] // 16 n_blocks = thread_configs[1] // 16 threads = thread_configs[2] @@ -82,8 +97,6 @@ def generate_new_kernels(): thread_k_blocks=k_blocks, m_block_size_8=m_blocks == 0.5, stages="pipe_stages", - has_act_order=has_act_order, - has_zp=has_zp, group_blocks=group_blocks, is_zp_float=False, ) diff --git a/csrc/moe/marlin_moe_wna16/kernel.h b/csrc/moe/marlin_moe_wna16/kernel.h index 3d92660e8028e94885d0de42e1557054e619edab..537282aba8c87ad56c19ca177debd8797ab4dbef 100644 --- a/csrc/moe/marlin_moe_wna16/kernel.h +++ b/csrc/moe/marlin_moe_wna16/kernel.h @@ -7,18 +7,19 @@ #include "quantization/gptq_marlin/marlin_dtypes.cuh" #include "core/scalar_type.hpp" -#define MARLIN_KERNEL_PARAMS \ - const int4 *__restrict__ A, const int4 *__restrict__ B, \ - int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ - const int4 *__restrict__ scales_ptr, const int4 *__restrict__ zp_ptr, \ - const int *__restrict__ g_idx, \ - const int32_t *__restrict__ sorted_token_ids_ptr, \ - const int32_t *__restrict__ expert_ids_ptr, \ - const int32_t *__restrict__ num_tokens_past_padded_ptr, \ - const float *__restrict__ topk_weights_ptr, int top_k, \ - bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \ - int prob_n, int prob_k, int *locks, bool use_atomic_add, \ - bool use_fp32_reduce +#define MARLIN_KERNEL_PARAMS \ + const int4 *__restrict__ A, const int4 *__restrict__ B, \ + int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ + const int4 *__restrict__ scales_ptr, \ + const uint16_t *__restrict__ scale2_ptr, \ + const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \ + const int32_t *__restrict__ sorted_token_ids_ptr, \ + const int32_t *__restrict__ expert_ids_ptr, \ + const int32_t *__restrict__ num_tokens_past_padded_ptr, \ + const float *__restrict__ topk_weights_ptr, int top_k, \ + bool mul_topk_weights, bool is_ep, int num_groups, int prob_m, \ + int prob_n, int prob_k, int *locks, bool use_atomic_add, \ + bool use_fp32_reduce, int max_shared_mem namespace MARLIN_NAMESPACE_NAME { template shared // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const bool has_zp, // whether zero-points are enabled - const int group_blocks, // number of consecutive 16x16 blocks - // with a separate quantization scale - const bool is_zp_float // is zero point of float16 type? + const int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale + const bool is_zp_float // is zero point of float16 type? > __global__ void Marlin(MARLIN_KERNEL_PARAMS); diff --git a/csrc/moe/marlin_moe_wna16/marlin_template.h b/csrc/moe/marlin_moe_wna16/marlin_template.h index 205b308fe511bdf57863b48ecd308572b2d68a2c..1c255396099d5fbe3b6ac1d953a6edcece7bf098 100644 --- a/csrc/moe/marlin_moe_wna16/marlin_template.h +++ b/csrc/moe/marlin_moe_wna16/marlin_template.h @@ -25,6 +25,7 @@ #include "quantization/gptq_marlin/marlin.cuh" #include "quantization/gptq_marlin/marlin_dtypes.cuh" +#include "quantization/gptq_marlin/dequant.h" #include "core/scalar_type.hpp" #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ @@ -48,11 +49,9 @@ template shared // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const bool has_zp, // whether zero-points are enabled - const int group_blocks, // number of consecutive 16x16 blocks - // with a separate quantization scale - const bool is_zp_float // is zero point of float16 type? + const int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale + const bool is_zp_float // is zero point of float16 type? > __global__ void Marlin( const int4* __restrict__ A, // fp16 input matrix of shape mxk @@ -77,8 +76,8 @@ __global__ void Marlin( int prob_k, // reduction dimension k int* locks, // extra global storage for barrier synchronization bool use_atomic_add, // whether to use atomic add to reduce - bool use_fp32_reduce // whether to use fp32 global reduce -) {} + bool use_fp32_reduce, // whether to use fp32 global reduce + int max_shared_mem) {} } // namespace MARLIN_NAMESPACE_NAME @@ -166,144 +165,6 @@ __device__ inline void ldsm(typename ScalarType::FragA& frag_a, } } -// Lookup-table based 3-input logical operation; explicitly used for -// dequantization as the compiler does not seem to automatically recognize it in -// all cases. -template -__device__ inline int lop3(int a, int b, int c) { - int res; - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(res) - : "r"(a), "r"(b), "r"(c), "n"(lut)); - return res; -} - -// Constructs destination register by taking bytes from 2 sources (based on -// mask) -template -__device__ inline uint32_t prmt(uint32_t a) { - uint32_t res; - asm volatile("prmt.b32 %0, %1, %2, %3;\n" - : "=r"(res) - : "r"(a), "n"(start_byte), "n"(mask)); - return res; -} - -template -__device__ inline typename ScalarType::FragB dequant( - int q, typename ScalarType::FragB& frag_b); - -// -// Efficiently dequantize 4bit values packed in an int32 value into a full -// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below, -// with some small changes: -// - FP16: -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 -// - BF16: -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385 -// -template <> -__device__ inline typename ScalarType::FragB dequant( - int q, typename ScalarType::FragB& frag_b) { - const int LO = 0x000f000f; - const int HI = 0x00f000f0; - const int EX = 0x64006400; - // Guarantee that the `(a & b) | c` operations are LOP3s. - int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX); - int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX); - // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point - // directly into `SUB` and `ADD`. - const int SUB = 0x64086408; - const int MUL = 0x2c002c00; - const int ADD = 0xd480d480; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&SUB)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - return frag_b; -} - -template <> -__device__ inline typename ScalarType::FragB -dequant(int q, - typename ScalarType::FragB& frag_b) { - static constexpr uint32_t MASK = 0x000f000f; - static constexpr uint32_t EX = 0x43004300; - - // Guarantee that the `(a & b) | c` operations are LOP3s. - - int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX); - q >>= 4; - int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX); - - static constexpr uint32_t MUL = 0x3F803F80; - static constexpr uint32_t ADD = 0xC308C308; - - frag_b[0] = __hfma2(*reinterpret_cast(&lo), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - return frag_b; -} - -// -// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or -// bf16 Reference: -// - FP16: -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 -// - BF16: -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175 -// -template <> -__device__ inline typename ScalarType::FragB dequant( - int q, typename ScalarType::FragB& frag_b) { - static constexpr uint32_t mask_for_elt_01 = 0x5250; - static constexpr uint32_t mask_for_elt_23 = 0x5351; - static constexpr uint32_t start_byte_for_fp16 = 0x64646464; - - uint32_t lo = prmt(q); - uint32_t hi = prmt(q); - - static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; - - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - frag_b[1] = __hsub2(*reinterpret_cast(&hi), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - return frag_b; -} - -template <> -__device__ inline typename ScalarType::FragB -dequant(int q, - typename ScalarType::FragB& frag_b) { - float fp32_intermediates[4]; - uint32_t* fp32_intermediates_casted = - reinterpret_cast(fp32_intermediates); - - static constexpr uint32_t fp32_base = 0x4B000000; - fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); - fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); - fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); - fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); - - fp32_intermediates[0] -= 8388736.f; - fp32_intermediates[1] -= 8388736.f; - fp32_intermediates[2] -= 8388736.f; - fp32_intermediates[3] -= 8388736.f; - - uint32_t* bf16_result_ptr = reinterpret_cast(&frag_b); - bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], - fp32_intermediates_casted[1], 0x7632); - bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], - fp32_intermediates_casted[3], 0x7632); - - return frag_b; -} - // Multiply dequantized values by the corresponding quantization scale; used // only for grouped quantization. template @@ -429,11 +290,9 @@ template shared // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const bool has_zp, // whether zero-points are enabled - const int group_blocks, // number of consecutive 16x16 blocks - // with a separate quantization scale - const bool is_zp_float // is zero point of float16 type? + const int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale + const bool is_zp_float // is zero point of float16 type? > __global__ void Marlin( const int4* __restrict__ A, // fp16 input matrix of shape mxk @@ -442,9 +301,11 @@ __global__ void Marlin( int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce) const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape // (k/groupsize)xn - const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape - // (k/groupsize)x(n/pack_factor) - const int* __restrict__ g_idx, // int32 group indices of shape k + const uint16_t* __restrict__ scale2_ptr, // fp16 global scale (for nvfp4 + // only) + const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape + // (k/groupsize)x(n/pack_factor) + const int* __restrict__ g_idx, // int32 group indices of shape k const int32_t* __restrict__ sorted_token_ids_ptr, // moe sorted_ids const int32_t* __restrict__ expert_ids_ptr, // moe expert ids const int32_t* __restrict__ num_tokens_past_padded_ptr, // moe num tokens @@ -458,8 +319,8 @@ __global__ void Marlin( int prob_k, // reduction dimension k int* locks, // extra global storage for barrier synchronization bool use_atomic_add, // whether to use atomic add to reduce - bool use_fp32_reduce // whether to use fp32 global reduce -) { + bool use_fp32_reduce, // whether to use fp32 global reduce + int max_shared_mem) { // Each threadblock processes one "stripe" of the B matrix with (roughly) the // same size, which might involve multiple column "slices" (of width 16 * // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM @@ -481,13 +342,26 @@ __global__ void Marlin( extern __shared__ int4 sh[]; static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id); + constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8; + constexpr bool is_int_type = w_type == vllm::kU4 || w_type == vllm::kU8 || + w_type == vllm::kU4B8 || w_type == vllm::kU8B128; + // see comments of dequant.h for more details + constexpr bool dequant_skip_flop = + !is_int_type || + has_zp && !is_zp_float && !std::is_same::value || + has_zp && !is_zp_float && !(w_type == vllm::kU8); + + scalar_t2 global_scale; + + constexpr bool has_act_order = group_blocks == 0; constexpr int pack_factor = 32 / w_type.size_bits(); static_assert(thread_m_blocks == 1 || !m_block_size_8); constexpr int moe_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks); const int group_size = (!has_act_order && group_blocks == -1) ? prob_k : prob_k / num_groups; - const int scales_expert_stride = prob_n * prob_k / group_size / 8; + const int scales_expert_stride = + prob_n * prob_k / group_size / (w_type == vllm::kFE2M1f ? 16 : 8); const int zp_expert_stride = is_zp_float ? prob_n * prob_k / group_size / 8 : prob_n * prob_k / group_size / (pack_factor * 4); @@ -534,13 +408,20 @@ __global__ void Marlin( int64_t B_expert_off = 0; int4* sh_block_sorted_ids_int4 = sh; + int4* sh_rd_block_sorted_ids_int4 = + sh_block_sorted_ids_int4 + moe_block_size / 4; + int4* sh_block_topk_weights_int4 = + sh_rd_block_sorted_ids_int4 + moe_block_size / 4; + // sh_block_topk_weights_int4 only need (moe_block_size / 4); + // but we pad to align to 256 bytes + int4* sh_new = + sh_block_topk_weights_int4 + moe_block_size / 2 + moe_block_size; int32_t* sh_block_sorted_ids = reinterpret_cast(sh_block_sorted_ids_int4); - int4* sh_block_topk_weights_int4 = - sh_block_sorted_ids_int4 + moe_block_size / 4; + int32_t* sh_rd_block_sorted_ids = + reinterpret_cast(sh_rd_block_sorted_ids_int4); scalar_t2* sh_block_topk_weights = reinterpret_cast(sh_block_topk_weights_int4); - int4* sh_new = sh_block_topk_weights_int4 + moe_block_size / 4; int32_t block_num_valid_tokens = 0; int32_t locks_off = 0; @@ -584,12 +465,24 @@ __global__ void Marlin( sh_block_sorted_ids_int4[tid4] = reinterpret_cast( sorted_token_ids_ptr)[block_id * moe_block_size / 4 + tid4]; + #pragma unroll + for (int i = 0; i < 4; i++) + sh_rd_block_sorted_ids[tid4 * 4 + i] = + sh_block_sorted_ids[tid4 * 4 + i] / top_k; + if (mul_topk_weights) { #pragma unroll for (int i = 0; i < 4; i++) { - sh_block_topk_weights[tid4 * 4 + i] = - Dtype::num2num2(Dtype::float2num( - topk_weights_ptr[sh_block_sorted_ids[tid4 * 4 + i]])); + int idx = tid4 * 4 + i; + idx = idx < block_num_valid_tokens ? idx : 0; + if constexpr (w_type == vllm::kFE2M1f) { + sh_block_topk_weights[idx] = __hmul2( + global_scale, Dtype::num2num2(Dtype::float2num( + topk_weights_ptr[sh_block_sorted_ids[idx]]))); + } else { + sh_block_topk_weights[idx] = Dtype::num2num2( + Dtype::float2num(topk_weights_ptr[sh_block_sorted_ids[idx]])); + } } } } @@ -620,6 +513,11 @@ __global__ void Marlin( expert_id = expert_ids_ptr[block_id]; } + if constexpr (w_type == vllm::kFE2M1f) { + uint16_t val = scale2_ptr[expert_id]; + global_scale = Dtype::num2num2(*reinterpret_cast(&val)); + } + B_expert_off = expert_id * prob_n * prob_k / (pack_factor * 4); scales_ptr += (expert_id - old_expert_id) * scales_expert_stride; if constexpr (has_zp) { @@ -733,7 +631,7 @@ __global__ void Marlin( constexpr int s_sh_stride = 16 * thread_n_blocks / 8; constexpr int s_tb_groups = !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks - ? thread_k_blocks / group_blocks + ? thread_k_blocks / group_blocks / (w_type == vllm::kFE2M1f ? 2 : 1) : 1; constexpr int s_sh_stage = s_tb_groups * s_sh_stride; int s_gl_rd_delta = s_gl_stride; @@ -743,6 +641,7 @@ __global__ void Marlin( constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0; // constexpr int act_s_row_stride = 1; // int act_s_col_stride = act_s_row_stride * num_groups; + constexpr int act_s_max_num_groups = 32; int act_s_col_stride = 1; int act_s_col_warp_stride = act_s_col_stride * 8; int tb_n_warps = thread_n_blocks / 4; @@ -758,9 +657,9 @@ __global__ void Marlin( int zp_gl_rd_delta = zp_gl_stride; // Global A read index of current thread. - int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - a_gl_rd += a_gl_rd_delta_o * slice_row; + int a_gl_rd_row = threadIdx.x / a_gl_rd_delta_o; + int a_gl_rd_col = a_gl_rd_delta_o * slice_row + threadIdx.x % a_gl_rd_delta_o; + // Shared write index of current thread. int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); @@ -774,8 +673,8 @@ __global__ void Marlin( (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; b_gl_rd += b_sh_stride * slice_col; b_gl_rd += b_gl_rd_delta_o * slice_row; - int b_sh_wr = threadIdx.x * b_thread_vecs; - int b_sh_rd = threadIdx.x * b_thread_vecs; + auto b_sh_wr = threadIdx.x * b_thread_vecs; + auto b_sh_rd = threadIdx.x * b_thread_vecs; // For act_order constexpr int k_iter_size = tb_k / b_sh_wr_iters; @@ -790,11 +689,12 @@ __global__ void Marlin( if constexpr (group_blocks == -1) { s_gl_rd = s_sh_stride * slice_col + threadIdx.x; } else { - s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) / + (w_type == vllm::kFE2M1f ? 2 : 1) + s_sh_stride * slice_col + threadIdx.x; } } - int s_sh_wr = threadIdx.x; + auto s_sh_wr = threadIdx.x; bool s_sh_wr_pred = threadIdx.x < s_sh_stride; // Zero-points @@ -807,17 +707,27 @@ __global__ void Marlin( zp_sh_stride * slice_col + threadIdx.x; } } - int zp_sh_wr = threadIdx.x; + auto zp_sh_wr = threadIdx.x; bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride; // We use a different scale layout for grouped and column-wise quantization as // we scale a `half2` tile in column-major layout in the former and in // row-major in the latter case. int s_sh_rd; - if constexpr (group_blocks != -1) + if constexpr (group_blocks != -1 && w_type == vllm::kFE2M1f) { + auto warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + int warp_row = warp_id / n_warps; + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; - else if constexpr (group_blocks == -1 && (m_block_size_8 || has_zp)) + s_sh_rd = s_sh_rd * 2 + warp_row % 2; + + } else if constexpr (group_blocks != -1) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + else if constexpr (group_blocks == -1 && + (m_block_size_8 || (has_zp && !dequant_skip_flop))) s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 8; else @@ -851,7 +761,7 @@ __global__ void Marlin( // each warp must also write a consecutive memory segment? auto transform_a = [&](int i) { int row = i / a_gl_rd_delta_o; - return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ (row % 8); }; // Since the computation of this remapping is non-trivial and, due to our main // loop unrolls, all shared memory accesses are static, we simply precompute @@ -879,12 +789,28 @@ __global__ void Marlin( B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; // Shared memory storage for global fetch pipelines. - int4* sh_a = sh_new; - int4* sh_b = sh_a + (stages * a_sh_stage); - int4* sh_g_idx = sh_b + (stages * b_sh_stage); + constexpr int sh_red_size = (2 * thread_n_blocks + 1) * 16 * thread_m_blocks; + constexpr int sh_b_size = stages * b_sh_stage; + int4* sh_b = sh_new; + int4* sh_red = sh_new; + int4* sh_g_idx = sh_b + (sh_red_size > sh_b_size ? sh_red_size : sh_b_size); int4* sh_zp = sh_g_idx + (stages * g_idx_stage); + constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride) + : (stages * s_sh_stage); int4* sh_s = sh_zp + (stages * zp_sh_stage); - int4* sh_red = sh_b; + // shared memory reused by reduction should be smaller than + // shared memory used by weight. + static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <= + stages * b_sh_stage); + int4* sh_a = sh_s + sh_s_size; + constexpr int shm_size_used = + moe_block_size + stages * (g_idx_stage + zp_sh_stage) + sh_s_size + + (sh_red_size > sh_b_size ? sh_red_size : sh_b_size); + + // all remaining shared memory is used to cache A (input) + // sh_a_max_row is at least ` stages * 16 * thread_m_blocks ` + int sh_a_max_row = + ((max_shared_mem - 1024) / 16 - shm_size_used) / (thread_k_blocks * 2); // Register storage for double buffer of shared memory reads. FragA frag_a[2][thread_m_blocks]; @@ -905,15 +831,14 @@ __global__ void Marlin( int sh_first_group_id = -1; int sh_num_groups = -1; - constexpr int sh_max_num_groups = 32; auto fetch_act_order_scales_to_shared = [&](bool is_async, int first_group_id, int last_group_id) { sh_first_group_id = first_group_id; sh_num_groups = last_group_id - first_group_id + 1; - if (sh_num_groups < sh_max_num_groups) { - sh_num_groups = sh_max_num_groups; + if (sh_num_groups > act_s_max_num_groups) { + sh_num_groups = act_s_max_num_groups; } if (sh_first_group_id + sh_num_groups > num_groups) { @@ -940,27 +865,31 @@ __global__ void Marlin( } } }; + // Asynchronously fetch the next A, B and s tile from global to the next // shared memory pipeline location. - int a_remaining_load_count_in_slice = stages; - auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { + bool should_load_a = true; + int max_num_stage_groups = + ((sh_a_max_row - moe_block_size) / moe_block_size + 1) / stages; + max_num_stage_groups = max(max_num_stage_groups, 1); + auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true, + int pipe_a = 0) { if (pred) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - if (prob_k > thread_k_blocks * 16 * stages || slice_col == 0 || - a_remaining_load_count_in_slice > 0) { - a_remaining_load_count_in_slice--; + if (should_load_a) { + int4* sh_a_stage = sh_a + moe_block_size * a_sh_stride * pipe_a; #pragma unroll for (int i = 0; i < a_sh_wr_iters; i++) { - int a_idx = a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off; - int row = a_idx / a_gl_stride; + int row = a_gl_rd_delta_i / a_gl_stride * i + a_gl_rd_row; int64_t sorted_row = 0; if (!m_block_size_8 || row < 8) - sorted_row = sh_block_sorted_ids[row] / top_k; - int64_t true_idx = sorted_row * a_gl_stride + a_idx % a_gl_stride; + sorted_row = sh_rd_block_sorted_ids[row]; + int64_t true_idx = + sorted_row * a_gl_stride + a_gl_rd_col + a_gl_rd_delta_o * a_off; cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[true_idx], row < block_num_valid_tokens); } } + int4* sh_b_stage = sh_b + b_sh_stage * pipe; #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) { @@ -1063,8 +992,8 @@ __global__ void Marlin( // Load the next sub-tile from the current location in the shared memory pipe // into the current register buffer. - auto fetch_to_registers = [&](int k, int pipe) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; + auto fetch_to_registers = [&](int k, int pipe, int pipe_a = 0) { + int4* sh_a_stage = sh_a + moe_block_size * a_sh_stride * pipe_a; #pragma unroll for (int i = 0; i < thread_m_blocks; i++) ldsm( @@ -1109,12 +1038,17 @@ __global__ void Marlin( } } else if constexpr (group_blocks != -1) { if constexpr (group_blocks >= thread_k_blocks) { - int4* sh_s_stage = - sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + if (k % b_sh_wr_iters == 0) { + int4* sh_s_stage = + sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } else { + reinterpret_cast(&frag_s[1])[0] = + reinterpret_cast(&frag_s[0])[0]; + } } else { - int warp_id = threadIdx.x / 32; + auto warp_id = threadIdx.x / 32; int n_warps = thread_n_blocks / 4; int warp_row = warp_id / n_warps; @@ -1123,12 +1057,19 @@ __global__ void Marlin( cur_k += k_iter_size * (k % b_sh_wr_iters); int k_blocks = cur_k / 16; - int cur_group_id = k_blocks / group_blocks; + int cur_group_id = + k_blocks / (group_blocks * (w_type == vllm::kFE2M1f ? 2 : 1)); int4* sh_s_stage = sh_s + s_sh_stage * pipe; - reinterpret_cast(&frag_s[k % 2])[0] = - sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; + if constexpr (w_type_id != vllm::kFE2M1f.id()) { + reinterpret_cast(&frag_s[k % 2])[0] = + sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; + } else { + reinterpret_cast(&frag_s[k % 2])[0] = + reinterpret_cast( + sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)]; + } } } @@ -1152,7 +1093,7 @@ __global__ void Marlin( // Determine "position" inside the thread-block (based on warp and // thread-id) - int warp_id = threadIdx.x / 32; + auto warp_id = threadIdx.x / 32; int n_warps = thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N @@ -1161,7 +1102,7 @@ __global__ void Marlin( cur_k += warp_row * 16; - int th_id = threadIdx.x % 32; + auto th_id = threadIdx.x % 32; cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix int s_col_shift = @@ -1222,15 +1163,18 @@ __global__ void Marlin( } } else if constexpr (group_blocks >= thread_k_blocks) { - int4* sh_zp_stage = - sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); - for (int i = 0; i < num_ints_per_thread; i++) { - frag_qzp[k % 2][i] = - (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + if (k % b_sh_wr_iters == 0) { + int4* sh_zp_stage = + sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + #pragma unroll + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = + (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + } } } else { - int warp_id = threadIdx.x / 32; + auto warp_id = threadIdx.x / 32; int n_warps = thread_n_blocks / 4; int warp_row = warp_id / n_warps; @@ -1251,6 +1195,7 @@ __global__ void Marlin( sh_zp_stage += cur_group_id * zp_sh_stride; + #pragma unroll for (int i = 0; i < num_ints_per_thread; i++) { frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; @@ -1263,12 +1208,16 @@ __global__ void Marlin( if constexpr (group_blocks != -1) { if constexpr (group_blocks >= thread_k_blocks) { - int4* sh_zp_stage = - sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd]; + if (k % b_sh_wr_iters == 0) { + int4* sh_zp_stage = + sh_zp + + zp_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_zpf[k % 2])[0] = + sh_zp_stage[zp_sh_rd]; + } } else { - int warp_id = threadIdx.x / 32; + auto warp_id = threadIdx.x / 32; int n_warps = thread_n_blocks / 4; int warp_row = warp_id / n_warps; @@ -1292,6 +1241,10 @@ __global__ void Marlin( } }; + auto dequant_data = [&](int q, scalar_t2* frag_b_ptr) { + dequant(q, frag_b_ptr); + }; + // Execute the actual tensor core matmul of a sub-tile. bool is_first_matmul_in_slice = true; auto matmul = [&](int k) { @@ -1315,15 +1268,27 @@ __global__ void Marlin( zp_quant_1 = frag_qzp[k2][1]; } - dequant(zp_quant_0, frag_zp_0); - dequant(zp_quant_1, frag_zp_1); - - frag_zp[0] = frag_zp_0[0]; - frag_zp[1] = frag_zp_0[1]; - frag_zp[2] = frag_zp_1[0]; - frag_zp[3] = frag_zp_1[1]; + dequant_data(zp_quant_0, reinterpret_cast(&frag_zp)); + dequant_data(zp_quant_1, reinterpret_cast(&frag_zp) + 2); + } + } + if constexpr (!dequant_skip_flop && has_zp && is_zp_float) { + if (is_new_zp) { + reinterpret_cast(&frag_zp)[0] = + reinterpret_cast(&frag_zpf[k2])[0]; } } + + if constexpr (w_type == vllm::kFE2M1f) { + int s_quant_0 = reinterpret_cast(frag_s[k2])[0]; + int s_quant_1 = reinterpret_cast(frag_s[k2])[1]; + + dequant_fp8_scales(s_quant_0, + reinterpret_cast(&frag_s[k2])); + dequant_fp8_scales( + s_quant_1, reinterpret_cast(&frag_s[k2]) + 2); + } + // We have the m dimension as the inner loop in order to encourage overlapping // dequantization and matmul operations. #pragma unroll @@ -1332,7 +1297,10 @@ __global__ void Marlin( FragB frag_b1; int b_quant_0, b_quant_1; - if constexpr (w_type.size_bits() == 4) { + if constexpr (w_type_id == vllm::kFE2M1f.id()) { + b_quant_1 = frag_b_quant[k2][0][j]; + b_quant_0 = b_quant_1 << 8; + } else if constexpr (w_type.size_bits() == 4) { b_quant_0 = frag_b_quant[k2][0][j]; b_quant_1 = b_quant_0 >> 8; } else { @@ -1342,8 +1310,13 @@ __global__ void Marlin( b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; } - dequant(b_quant_0, frag_b0); - dequant(b_quant_1, frag_b1); + dequant_data(b_quant_0, reinterpret_cast(&frag_b0)); + dequant_data(b_quant_1, reinterpret_cast(&frag_b1)); + + if constexpr (dequant_skip_flop && has_zp && !is_zp_float) { + sub_zp(frag_b0, frag_zp[j], 0); + sub_zp(frag_b1, frag_zp[j], 1); + } // Apply scale to frag_b0 if constexpr (has_act_order) { @@ -1351,9 +1324,9 @@ __global__ void Marlin( scale4(frag_b0, act_frag_s[k2][0][j], act_frag_s[k2][1][j], act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0); scale4(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j], - act_frag_s[k][2][j], act_frag_s[k2][3][j], 1); - - } else if constexpr (has_zp && !is_zp_float && group_blocks == -1) { + act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1); + } else if constexpr (!dequant_skip_flop && has_zp && !is_zp_float && + group_blocks == -1) { int idx = (threadIdx.x / 4) % 2; scalar_t2 s2 = Dtype::nums2num2( reinterpret_cast(&frag_s[j / 2][j % 2 * 2 + 0])[idx], @@ -1361,18 +1334,12 @@ __global__ void Marlin( if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2); scale_and_sub(frag_b0, s2.x, frag_zp[j].x); scale_and_sub(frag_b1, s2.y, frag_zp[j].y); - } else if constexpr (has_zp && !is_zp_float && group_blocks != -1) { + } else if constexpr (!dequant_skip_flop && has_zp && group_blocks != -1) { if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], *reinterpret_cast(&frag_s[k2][j])); - scale_and_sub(frag_b0, frag_s[k % 2][j][0].x, frag_zp[j].x); - scale_and_sub(frag_b1, frag_s[k % 2][j][0].y, frag_zp[j].y); - } else if constexpr (has_zp && is_zp_float && group_blocks != -1) { - if (is_new_zp) - frag_zpf[k2][j] = __hmul2( - frag_zpf[k2][j], *reinterpret_cast(&frag_s[k2][j])); - scale_and_sub(frag_b0, frag_s[k2][j].x, frag_zpf[k2][j].x); - scale_and_sub(frag_b1, frag_s[k2][j].y, frag_zpf[k2][j].y); + scale_and_sub(frag_b0, frag_s[k2][j][0].x, frag_zp[j].x); + scale_and_sub(frag_b1, frag_s[k2][j][0].y, frag_zp[j].y); } else if constexpr (group_blocks != -1) { scale(frag_b0, frag_s[k2][j], 0); scale(frag_b1, frag_s[k2][j], 1); @@ -1397,7 +1364,7 @@ __global__ void Marlin( auto thread_block_reduce = [&]() { constexpr int red_off = threads / b_sh_stride_threads / 2; if (red_off >= 1) { - int red_idx = threadIdx.x / b_sh_stride_threads; + auto red_idx = threadIdx.x / b_sh_stride_threads; constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; constexpr int red_sh_delta = b_sh_stride_threads; int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + @@ -1634,10 +1601,17 @@ __global__ void Marlin( // For per-column quantization we finally apply the scale here (only for // 4-bit) if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 4 && !has_zp) { + w_type.size_bits() == 4 && + (has_zp && dequant_skip_flop || !has_zp)) { res = __hmul2(res, s[0]); } + if constexpr (w_type == vllm::kFE2M1f) { + if (!mul_topk_weights) { + res = __hmul2(res, global_scale); + } + } + if constexpr (m_block_size_8) { ((scalar_t*)sh_red)[idx] = res.x; ((scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y; @@ -1728,10 +1702,12 @@ __global__ void Marlin( if constexpr (has_zp && !is_zp_float && group_blocks == -1) { if (i == 0) { fetch_col_zp_to_shared(); - fetch_col_scale_to_shared(); + if constexpr (!dequant_skip_flop) { + fetch_col_scale_to_shared(); + } } } - fetch_to_shared(i, i, i < slice_iters); + fetch_to_shared(i, i, i < slice_iters, i); } zero_accums(); @@ -1740,8 +1716,10 @@ __global__ void Marlin( fetch_to_registers(0, 0); fetch_scales_to_registers(0, 0); fetch_zp_to_registers(0, 0); - a_gl_rd += a_gl_rd_delta_o * (stages - 1); - slice_k_start_shared_fetch += tb_k * (stages - 1); + a_gl_rd_col += a_gl_rd_delta_o * (stages - 1); + if constexpr (has_act_order) { + slice_k_start_shared_fetch += tb_k * (stages - 1); + } }; if (slice_iters) { start_pipes(); @@ -1754,43 +1732,59 @@ __global__ void Marlin( // have even length meaning that the next iteration will always start at // index 0. + for (int stage_group_id = 0; stage_group_id < max_num_stage_groups; + stage_group_id++) { #pragma unroll - for (int pipe = 0; pipe < stages;) { + for (int pipe = 0; pipe < stages;) { #pragma unroll - for (int k = 0; k < b_sh_wr_iters; k++) { - fetch_to_registers(k + 1, pipe % stages); - fetch_scales_to_registers(k + 1, pipe); - fetch_zp_to_registers(k + 1, pipe); - if (k == b_sh_wr_iters - 2) { - fetch_to_shared((pipe + stages - 1) % stages, pipe, - slice_iters >= stages); - pipe++; - wait_for_stage(); - init_same_group(pipe % stages); + for (int k = 0; k < b_sh_wr_iters; k++) { + int idx = + (pipe >= stages && stage_group_id == max_num_stage_groups - 1) + ? (pipe - stages) + : (pipe + stage_group_id * stages); + fetch_to_registers(k + 1, pipe % stages, idx); + fetch_scales_to_registers(k + 1, pipe); + fetch_zp_to_registers(k + 1, pipe); + if (k == b_sh_wr_iters - 2) { + int idx = (pipe >= 1 && stage_group_id == max_num_stage_groups - 1) + ? (pipe - 1) + : (pipe + (stage_group_id + 1) * stages - 1); + fetch_to_shared((pipe + stages - 1) % stages, pipe, + slice_iters >= stages, idx); + pipe++; + wait_for_stage(); + init_same_group(pipe % stages); + } + matmul(k); + } + slice_iters--; + if (slice_iters == 0) { + break; } - matmul(k); - } - slice_iters--; - if (slice_iters == 0) { - break; } - } - a_remaining_load_count_in_slice = 0; - a_gl_rd += a_gl_rd_delta_o * stages; - slice_k_start += tb_k * stages; - slice_k_start_shared_fetch += tb_k * stages; + a_gl_rd_col += a_gl_rd_delta_o * stages; - if constexpr (has_act_order) { - int first_group_id = g_idx[slice_k_start]; - int last_g_idx = slice_k_start + stages * tb_k * 2; - if (last_g_idx >= prob_k) { - last_g_idx = prob_k - 1; + if constexpr (has_act_order) { + slice_k_start += tb_k * stages; + + if (slice_k_start < prob_k) { + slice_k_start_shared_fetch += tb_k * stages; + int first_group_id = g_idx[slice_k_start]; + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + int last_group_id = g_idx[last_g_idx]; + if (last_group_id >= sh_first_group_id + sh_num_groups) { + fetch_act_order_scales_to_shared(false, first_group_id, + last_group_id); + __syncthreads(); + } + } } - int last_group_id = g_idx[last_g_idx]; - if (last_group_id >= sh_first_group_id + sh_num_groups) { - fetch_act_order_scales_to_shared(false, first_group_id, last_group_id); - __syncthreads(); + if (slice_iters == 0) { + break; } } @@ -1802,7 +1796,8 @@ __global__ void Marlin( bool last = slice_idx == slice_count - 1; // For per-column scales, we only fetch them here in the final step before // write-out - if constexpr (!has_act_order && group_blocks == -1 && !has_zp) { + if constexpr (!has_act_order && group_blocks == -1 && + (has_zp && dequant_skip_flop || !has_zp)) { if (w_type.size_bits() == 8 || (last || use_atomic_add)) { if (s_sh_wr_pred) { cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); @@ -1812,7 +1807,8 @@ __global__ void Marlin( } thread_block_reduce(); - if constexpr (!has_act_order && group_blocks == -1 && !has_zp) { + if constexpr (!has_act_order && group_blocks == -1 && + (has_zp && dequant_skip_flop || !has_zp)) { if (w_type.size_bits() == 8 || (last || use_atomic_add)) { cp_async_wait<0>(); __syncthreads(); @@ -1836,7 +1832,8 @@ __global__ void Marlin( // that converts the fp32 results to fp16 (so that we avoid possible // overflow in fp16) if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 8 && !has_zp) { + w_type.size_bits() == 8 && + (has_zp && dequant_skip_flop || !has_zp)) { if (threadIdx.x / 32 < thread_n_blocks / 4) { #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { @@ -1877,15 +1874,30 @@ __global__ void Marlin( if (last || use_atomic_add) // only the last block in a slice actually writes the result write_result(); - if (slice_row) a_remaining_load_count_in_slice = stages; + int old_slice_row = slice_row; slice_row = 0; slice_col_par++; slice_col++; is_first_matmul_in_slice = true; init_slice(); + + // Should we load A matrix in next slice? + // `slice_col == 0`: when move to a new moe block + // `old_slice_row > 0`: + // when the last slice is not starting from k_index == 0 + // (only happen when it is the first slice of a threadblock) + // `prob_k > thread_k_blocks * 16 * stages * max_num_stage_groups`: + // when the required shared memory size is larger than + // the remaining shared memory + if (slice_col == 0 || old_slice_row || + prob_k > thread_k_blocks * 16 * stages * max_num_stage_groups) { + should_load_a = true; + } else { + should_load_a = false; + } + if (slice_iters) { - a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); + a_gl_rd_col = (threadIdx.x % a_gl_rd_delta_o); #pragma unroll for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; @@ -1900,12 +1912,10 @@ __global__ void Marlin( slice_k_finish = slice_k_start + tb_k * slice_iters; slice_k_start_shared_fetch = slice_k_start; slice_n_offset = act_s_col_tb_stride * slice_col; - } else { s_gl_rd = s_sh_stride * slice_col + threadIdx.x; zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; } - start_pipes(); } } diff --git a/csrc/moe/marlin_moe_wna16/ops.cu b/csrc/moe/marlin_moe_wna16/ops.cu index a16e955a325e236a9131135d5935543603bf931c..2cff04f699b04a6694a11edcd3dd2e80be50a2ac 100644 --- a/csrc/moe/marlin_moe_wna16/ops.cu +++ b/csrc/moe/marlin_moe_wna16/ops.cu @@ -116,7 +116,7 @@ __global__ void permute_cols_kernel( int base_k = 0; for (int i = 0; i < iters; i++) { - int cur_k = base_k + threadIdx.x; + auto cur_k = base_k + threadIdx.x; int src_pos = perm_int_ptr[cur_k]; out_half[cur_k] = a_row_half[src_pos]; @@ -126,7 +126,7 @@ __global__ void permute_cols_kernel( if (rest) { if (threadIdx.x < rest) { - int cur_k = base_k + threadIdx.x; + auto cur_k = base_k + threadIdx.x; int src_pos = perm_int_ptr[cur_k]; out_half[cur_k] = a_row_half[src_pos]; @@ -195,7 +195,6 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m, tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K load_groups = max(load_groups, 32); // We load at least 32 scale groups return load_groups * tb_n * 2; - } else { int tb_scales = tb_groups * tb_n * 2; @@ -203,22 +202,24 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m, } } -int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks, - int prob_m, int prob_n, int prob_k, int num_bits, - int group_size, bool has_act_order, bool is_k_full, - int has_zp, int is_zp_float) { +int get_kernel_cache_size(thread_config_t const& th_config, bool m_block_size_8, + int thread_m_blocks, int prob_m, int prob_n, + int prob_k, int num_bits, int group_size, + bool has_act_order, bool is_k_full, int has_zp, + int is_zp_float) { int pack_factor = 32 / num_bits; // Get B size int tb_k = th_config.thread_k; int tb_n = th_config.thread_n; - int tb_m = thread_m_blocks * 16; + int tb_m = thread_m_blocks * (m_block_size_8 ? 8 : 16); - // shm size for block_sorted_ids/block_topk_weights + // shm size for block_sorted_ids/rd_block_sorted_ids/block_topk_weights // both of them requires tb_m * 4 bytes (tb_m * int32 or tb_m * float32) - int sh_block_meta_size = tb_m * 4 * 2; + int sh_block_meta_size = tb_m * 4; int sh_a_size = pipe_stages * (tb_m * tb_k) * 2; int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4; + int sh_red_size = tb_m * (tb_n + 8) * 2; int sh_s_size = get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, group_size, has_act_order, is_k_full); @@ -233,16 +234,17 @@ int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks, sh_zp_size = sh_s_size / 2; } - int total_size = sh_a_size + sh_b_size + sh_s_size + sh_zp_size + - sh_g_idx_size + sh_block_meta_size; + int total_size = max(sh_b_size, sh_red_size) + sh_a_size + sh_s_size + + sh_zp_size + sh_g_idx_size + sh_block_meta_size; return total_size; } -bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks, - int prob_m, int prob_n, int prob_k, int num_bits, - int group_size, bool has_act_order, bool is_k_full, - int has_zp, int is_zp_float, int max_shared_mem) { +bool is_valid_config(thread_config_t const& th_config, bool m_block_size_8, + int thread_m_blocks, int prob_m, int prob_n, int prob_k, + int num_bits, int group_size, bool has_act_order, + bool is_k_full, int has_zp, int is_zp_float, + int max_shared_mem) { // Sanity if (th_config.thread_k == -1 || th_config.thread_n == -1 || th_config.num_threads == -1) { @@ -266,143 +268,129 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks, // Check that pipeline fits into cache int cache_size = get_kernel_cache_size( - th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size, - has_act_order, is_k_full, has_zp, is_zp_float); + th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k, + num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float); return cache_size <= max_shared_mem; } - #define __GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ - M_BLOCK_SIZE_8, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \ - NUM_THREADS, IS_ZP_FLOAT) \ - else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ - thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - m_block_size_8 == M_BLOCK_SIZE_8 && \ - has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \ - group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \ - is_zp_float == IS_ZP_FLOAT) { \ - kernel = Marlin; \ + #define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ + M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \ + else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ + thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + m_block_size_8 == M_BLOCK_SIZE_8 && \ + group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \ + is_zp_float == IS_ZP_FLOAT) { \ + kernel = Marlin; \ } - #define GPTQ_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, true, false, 0, NUM_THREADS, \ - false) \ - __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, false, 0, \ - NUM_THREADS, false) \ - \ - __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, -1, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 2, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 4, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, false, 8, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, -1, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 2, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 4, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, false, 8, \ - NUM_THREADS, false) - - #define GPTQ_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, false, 0, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, false, 0, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, false, 0, \ - NUM_THREADS, false) \ - \ - __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, -1, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 2, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 4, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, false, 8, \ - NUM_THREADS, false) \ - \ - __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, -1, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 2, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 4, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, false, 8, \ - NUM_THREADS, false) \ - \ - __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, -1, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 2, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 4, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, false, 8, \ - NUM_THREADS, false) - - #define AWQ_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, -1, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 2, NUM_THREADS, \ - false) \ - __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 4, NUM_THREADS, \ - false) \ - __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 8, NUM_THREADS, \ - false) \ - __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, -1, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 2, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 4, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 8, \ - NUM_THREADS, false) - - #define AWQ_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, -1, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 2, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 4, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 8, \ - NUM_THREADS, false) \ - \ - __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, -1, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 2, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 4, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 8, \ - NUM_THREADS, false) \ - \ - __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, -1, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 2, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 4, \ - NUM_THREADS, false) \ - __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 8, \ - NUM_THREADS, false) + // COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false) + // this is the most common cases + // BIGGROUP: cases for big group size (group_blocks in [-1, 8]) + // FZP: cases for float-zero-point (is_zp_float = true) + // ACT: cases for act order case (group_blocks == 0) + // FP4: cases for nvfp4(e2m1) (group_blocks == 1) + #define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + + #define COMMON_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + + #define COMMON_GET_IF(W_TYPE) \ + COMMON_GET_IF_M1(W_TYPE, 8, 8, 256) \ + COMMON_GET_IF_M1(W_TYPE, 8, 4, 128) \ + COMMON_GET_IF_M234(W_TYPE, 16, 4, 256) \ + COMMON_GET_IF_M234(W_TYPE, 8, 4, 128) + + #define BIGGROUP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + + #define BIGGROUP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + + #define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) + + #define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) + + #define FP4_GET_IF(W_TYPE) \ + FP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ + FP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ + FP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ + FP4_GET_IF_M234(W_TYPE, 8, 4, 128) + + #define BIGGROUP_GET_IF(W_TYPE) \ + BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \ + BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \ + BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \ + BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128) + + // We currently have 4-bit models only with group_blocks == 4 + #define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) + + #define FZP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) + + #define FZP_GET_IF(W_TYPE) \ + FZP_GET_IF_M1(W_TYPE, 8, 8, 256) \ + FZP_GET_IF_M1(W_TYPE, 8, 4, 128) \ + FZP_GET_IF_M234(W_TYPE, 16, 4, 256) \ + FZP_GET_IF_M234(W_TYPE, 8, 4, 128) // We currently have 4-bit models only with group_blocks == 4 - #define HQQ_GET_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, true, 4, NUM_THREADS, \ - true) \ - __GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, true, 4, \ - NUM_THREADS, true) \ - __GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, true, 4, \ - NUM_THREADS, true) \ - __GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, true, 4, \ - NUM_THREADS, true) \ - __GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, true, 4, \ - NUM_THREADS, true) + #define ACT_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) + + #define ACT_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) + + #define ACT_GET_IF(W_TYPE) \ + ACT_GET_IF_M1(W_TYPE, 8, 8, 256) \ + ACT_GET_IF_M1(W_TYPE, 8, 4, 128) \ + ACT_GET_IF_M234(W_TYPE, 16, 4, 256) \ + ACT_GET_IF_M234(W_TYPE, 8, 4, 128) template MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type, @@ -415,23 +403,17 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type, auto kernel = MarlinDefault; if (false) { } - GPTQ_GET_IF_M1(vllm::kU4B8, 8, 8, 256) - GPTQ_GET_IF_M1(vllm::kU4B8, 8, 4, 128) - - GPTQ_GET_IF_M234(vllm::kU4B8, 16, 4, 256) - GPTQ_GET_IF_M234(vllm::kU4B8, 8, 4, 128) - GPTQ_GET_IF_M1(vllm::kU8B128, 8, 8, 256) - GPTQ_GET_IF_M1(vllm::kU8B128, 8, 4, 128) + COMMON_GET_IF(vllm::kU4) + COMMON_GET_IF(vllm::kU4B8) + COMMON_GET_IF(vllm::kU8B128) - GPTQ_GET_IF_M234(vllm::kU8B128, 16, 4, 256) - GPTQ_GET_IF_M234(vllm::kU8B128, 8, 4, 128) + BIGGROUP_GET_IF(vllm::kFE4M3fn) - AWQ_GET_IF_M1(vllm::kU4, 8, 8, 256) - AWQ_GET_IF_M1(vllm::kU4, 8, 4, 128) + FP4_GET_IF(vllm::kFE2M1f) - AWQ_GET_IF_M234(vllm::kU4, 16, 4, 256) - AWQ_GET_IF_M234(vllm::kU4, 8, 4, 128) + ACT_GET_IF(vllm::kU4B8) + ACT_GET_IF(vllm::kU8B128) return kernel; } @@ -457,19 +439,19 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m, for (int i = 0; i < thread_configs_size; i++) { thread_config_t th_config = thread_configs[i]; - if (!is_valid_config(th_config, thread_m_blocks, prob_m, prob_n, prob_k, - num_bits, group_size, has_act_order, is_k_full, has_zp, - is_zp_float, max_shared_mem)) { + if (!is_valid_config(th_config, m_block_size_8, thread_m_blocks, prob_m, + prob_n, prob_k, num_bits, group_size, has_act_order, + is_k_full, has_zp, is_zp_float, max_shared_mem)) { continue; } int cache_size = get_kernel_cache_size( - th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, - group_size, has_act_order, is_k_full, has_zp, is_zp_float); + th_config, m_block_size_8, thread_m_blocks, prob_m, prob_n, prob_k, + num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float); int group_blocks = 0; if (!has_act_order) { - group_blocks = group_size == -1 ? -1 : group_size / 16; + group_blocks = group_size == -1 ? -1 : (group_size / 16); } auto kernel = get_marlin_kernel( @@ -501,7 +483,7 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m, template void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, - void* zp, void* g_idx, void* perm, void* a_tmp, + void* s2, void* zp, void* g_idx, void* perm, void* a_tmp, void* sorted_token_ids, void* expert_ids, void* num_tokens_past_padded, void* topk_weights, int moe_block_size, int top_k, bool mul_topk_weights, bool is_ep, @@ -520,8 +502,10 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, "q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str()); } else { TORCH_CHECK( - q_type == vllm::kU4B8 || q_type == vllm::kU8B128, - "q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ", + q_type == vllm::kU4B8 || q_type == vllm::kU8B128 || + q_type == vllm::kFE4M3fn || q_type == vllm::kFE2M1f, + "q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when " + "has_zp = False. Got = ", q_type.str()); } @@ -555,6 +539,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, int4* C_ptr = (int4*)C; int4* C_tmp_ptr = (int4*)C_tmp; const int4* s_ptr = (const int4*)s; + const uint16_t* s2_ptr = (const uint16_t*)s2; const int4* zp_ptr = (const int4*)zp; const int* g_idx_ptr = (const int*)g_idx; const int* perm_ptr = (const int*)perm; @@ -631,18 +616,18 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, int thread_k_blocks = thread_k / 16; int thread_n_blocks = thread_n / 16; - TORCH_CHECK(is_valid_config(thread_tfg, thread_m_blocks, prob_m, prob_n, - prob_k, num_bits, group_size, has_act_order, - is_k_full, has_zp, is_zp_float, max_shared_mem), - "Invalid thread config: thread_m_blocks = ", thread_m_blocks, - ", thread_k = ", thread_tfg.thread_k, - ", thread_n = ", thread_tfg.thread_n, - ", num_threads = ", thread_tfg.num_threads, " for MKN = [", - prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits, - ", group_size = ", group_size, - ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full, - ", has_zp = ", has_zp, ", is_zp_float = ", is_zp_float, - ", max_shared_mem = ", max_shared_mem); + TORCH_CHECK( + is_valid_config(thread_tfg, m_block_size_8, thread_m_blocks, prob_m, + prob_n, prob_k, num_bits, group_size, has_act_order, + is_k_full, has_zp, is_zp_float, max_shared_mem), + "Invalid thread config: thread_m_blocks = ", thread_m_blocks, + ", thread_k = ", thread_tfg.thread_k, + ", thread_n = ", thread_tfg.thread_n, + ", num_threads = ", thread_tfg.num_threads, " for MKN = [", prob_m, ", ", + prob_k, ", ", prob_n, "] and num_bits = ", num_bits, + ", group_size = ", group_size, ", has_act_order = ", has_act_order, + ", is_k_full = ", is_k_full, ", has_zp = ", has_zp, + ", is_zp_float = ", is_zp_float, ", max_shared_mem = ", max_shared_mem); auto kernel = get_marlin_kernel( q_type, thread_m_blocks, thread_n_blocks, thread_k_blocks, m_block_size_8, @@ -663,10 +648,10 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, // avoid ">>>" being formatted to "> > >" // clang-format off kernel<<>>( - A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, + A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr, sorted_token_ids_ptr, expert_ids_ptr, num_tokens_past_padded_ptr, topk_weights_ptr, top_k, mul_topk_weights, is_ep, num_groups, prob_m, - prob_n, prob_k, locks, use_atomic_add, use_fp32_reduce); + prob_n, prob_k, locks, use_atomic_add, use_fp32_reduce, max_shared_mem); // clang-format on } @@ -675,6 +660,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, torch::Tensor moe_wna16_marlin_gemm( torch::Tensor& a, std::optional const& c_or_none, torch::Tensor& b_q_weight, torch::Tensor& b_scales, + std::optional const& global_scale_or_none, std::optional const& b_zeros_or_none, std::optional const& g_idx_or_none, std::optional const& perm_or_none, torch::Tensor& workspace, @@ -826,6 +812,17 @@ torch::Tensor moe_wna16_marlin_gemm( } } + torch::Tensor global_scale; + if (global_scale_or_none.has_value()) { + global_scale = global_scale_or_none.value(); + TORCH_CHECK(b_q_type == vllm::kFE2M1f, + "global_scale can only be used for float4_e2m1f."); + } else { + global_scale = torch::empty({0}, options); + TORCH_CHECK(!(b_q_type == vllm::kFE2M1f), + "the global_scale parameter must be passed for float4_e2m1f."); + } + torch::Tensor b_zeros; if (b_zeros_or_none.has_value()) { b_zeros = b_zeros_or_none.value(); @@ -838,13 +835,15 @@ torch::Tensor moe_wna16_marlin_gemm( if (has_zp) { TORCH_CHECK( - b_q_type == vllm::kU4, - "b_q_type must be u4 when has_zp = True. Got = ", b_q_type.str()); + b_q_type == vllm::kU4 || b_q_type == vllm::kU8, + "b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str()); } else { - TORCH_CHECK( - b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128, - "b_q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ", - b_q_type.str()); + TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128 || + b_q_type == vllm::kFE4M3fn || b_q_type == vllm::kFE2M1f, + "b_q_type must be uint4b8, uint8b128, float8_e4m3fn or " + "float4_e2m1f when " + "has_zp = False. Got = ", + b_q_type.str()); } if (has_zp && is_zp_float) { @@ -889,9 +888,16 @@ torch::Tensor moe_wna16_marlin_gemm( int dev = a.get_device(); if (a.scalar_type() == at::ScalarType::Half) { + void* scales_ptr; + if (b_q_type == vllm::kFE2M1f) { + scales_ptr = b_scales.data_ptr(); + } else { + scales_ptr = b_scales.data_ptr(); + } + MARLIN_NAMESPACE_NAME::marlin_mm( a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), - c_tmp.data_ptr(), b_scales.data_ptr(), + c_tmp.data_ptr(), scales_ptr, global_scale.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), sorted_token_ids.data_ptr(), expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(), @@ -901,11 +907,18 @@ torch::Tensor moe_wna16_marlin_gemm( at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, use_atomic_add, use_fp32_reduce, is_zp_float); } else if (a.scalar_type() == at::ScalarType::BFloat16) { + void* scales_ptr; + if (b_q_type == vllm::kFE2M1f) { + scales_ptr = b_scales.data_ptr(); + } else { + scales_ptr = b_scales.data_ptr(); + } + MARLIN_NAMESPACE_NAME::marlin_mm( a.data_ptr(), b_q_weight.data_ptr(), - c.data_ptr(), c_tmp.data_ptr(), - b_scales.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), - perm.data_ptr(), a_tmp.data_ptr(), + c.data_ptr(), c_tmp.data_ptr(), scales_ptr, + global_scale.data_ptr(), b_zeros.data_ptr(), + g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), sorted_token_ids.data_ptr(), expert_ids.data_ptr(), num_tokens_past_padded.data_ptr(), topk_weights.data_ptr(), moe_block_size, top_k, mul_topk_weights, is_ep, size_m, size_n, size_k, diff --git a/csrc/moe/moe_align_sum_kernels.cu b/csrc/moe/moe_align_sum_kernels.cu index d7be769458e35b175d557dc62026b4afdf88bc5e..6b6a9d04a60f40df4542f0552f587d2347767cfc 100644 --- a/csrc/moe/moe_align_sum_kernels.cu +++ b/csrc/moe/moe_align_sum_kernels.cu @@ -326,7 +326,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, } if (use_global_memory) { - VLLM_DISPATCH_INTEGRAL_TYPES( + VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES( topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] { // calc needed amount of shared mem for `tokens_cnts` and `cumsum` // tensors @@ -351,7 +351,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, cumsum_buffer.data_ptr()); }); } else if (use_i16) { - VLLM_DISPATCH_INTEGRAL_TYPES( + VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES( topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { // set dynamic shared mem auto kernel = @@ -366,7 +366,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, topk_ids.numel()); }); } else { - VLLM_DISPATCH_INTEGRAL_TYPES( + VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES( topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { auto kernel = vllm::moe::moe_align_block_size_kernel; @@ -391,7 +391,7 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, TORCH_CHECK(num_experts == 256, "sgl_moe_align_block_size kernel only supports deepseek v3."); - VLLM_DISPATCH_INTEGRAL_TYPES( + VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES( topk_ids.scalar_type(), "sgl_moe_align_block_size_kernel", [&] { // calc needed amount of shared mem for `cumsum` tensors auto options_int = diff --git a/csrc/moe/moe_permute_unpermute_op.cu b/csrc/moe/moe_permute_unpermute_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..76d5f0eab0218025381251b39ea979f0d98e5bc3 --- /dev/null +++ b/csrc/moe/moe_permute_unpermute_op.cu @@ -0,0 +1,133 @@ +#include +#include +#include +#include "permute_unpermute_kernels/moe_permute_unpermute_kernel.h" +#include "permute_unpermute_kernels/dispatch.h" +#include "core/registration.h" + +void moe_permute( + const torch::Tensor& input, // [n_token, hidden] + const torch::Tensor& topk_weights, //[n_token, topk] + torch::Tensor& topk_ids, // [n_token, topk] + const torch::Tensor& token_expert_indicies, // [n_token, topk] + const std::optional& expert_map, // [n_expert] + int64_t n_expert, int64_t n_local_expert, int64_t topk, + const std::optional& align_block_size, + torch::Tensor& + permuted_input, // [topk * n_token/align_block_size_m, hidden] + torch::Tensor& expert_first_token_offset, // [n_local_expert + 1] + torch::Tensor& src_row_id2dst_row_id_map, // [n_token, topk] + torch::Tensor& m_indices) { // [align_expand_m] + TORCH_CHECK(topk_weights.scalar_type() == at::ScalarType::Float, + "topk_weights must be float32"); + TORCH_CHECK(expert_first_token_offset.scalar_type() == at::ScalarType::Long, + "expert_first_token_offset must be int64"); + TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int, + "topk_ids must be int32"); + TORCH_CHECK(token_expert_indicies.scalar_type() == at::ScalarType::Int, + "token_expert_indicies must be int32"); + TORCH_CHECK(src_row_id2dst_row_id_map.scalar_type() == at::ScalarType::Int, + "src_row_id2dst_row_id_map must be int32"); + TORCH_CHECK(expert_first_token_offset.size(0) == n_local_expert + 1, + "expert_first_token_offset shape != n_local_expert+1") + TORCH_CHECK( + src_row_id2dst_row_id_map.sizes() == token_expert_indicies.sizes(), + "token_expert_indicies shape must be same as src_row_id2dst_row_id_map"); + auto n_token = input.sizes()[0]; + auto n_hidden = input.sizes()[1]; + auto align_block_size_value = + align_block_size.has_value() ? align_block_size.value() : -1; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + const long sorter_size = + CubKeyValueSorter::getWorkspaceSize(n_token * topk, n_expert); + auto sort_workspace = torch::empty( + {sorter_size}, + torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false)); + auto permuted_experts_id = torch::empty_like(topk_ids); + auto dst_row_id2src_row_id_map = torch::empty_like(src_row_id2dst_row_id_map); + auto align_expert_first_token_offset = + torch::zeros_like(expert_first_token_offset); + + CubKeyValueSorter sorter{}; + int64_t* valid_num_ptr = nullptr; + // pre-process kernel for expert-parallelism: + // no local expert id plus "n_expert" offset for priority to local expert + // map local expert id [n, .., n+n_local_expert-1] to [0, n_local_expert -1] + // For example, 4 expert with ep_size=2. ep_rank=1 owns global expert id + // [2,3] with expert_map[-1, -1, 0, 1], preprocess_topk_id process topk_ids + // and map global expert id [2, 3] to local_expert id [0, 1] and map global + // expert id [0, 1] ( not in ep rank=1) to [4, 5] by plus n_expert. This map + // operation is to make local expert high priority in following sort topk_ids + // and scan local expert_first_token_offset for each ep rank for next group + // gemm. + if (expert_map.has_value()) { + const int* expert_map_ptr = get_ptr(expert_map.value()); + valid_num_ptr = + get_ptr(expert_first_token_offset) + n_local_expert; + preprocessTopkIdLauncher(get_ptr(topk_ids), n_token * topk, + expert_map_ptr, n_expert, stream); + } + // expert sort topk expert id and scan expert id get expert_first_token_offset + sortAndScanExpert(get_ptr(topk_ids), get_ptr(token_expert_indicies), + get_ptr(permuted_experts_id), + get_ptr(dst_row_id2src_row_id_map), + get_ptr(expert_first_token_offset), n_token, + n_expert, n_local_expert, topk, sorter, + get_ptr(sort_workspace), stream); + + // dispatch expandInputRowsKernelLauncher + MOE_DISPATCH(input.scalar_type(), [&] { + expandInputRowsKernelLauncher( + get_ptr(input), get_ptr(permuted_input), + get_ptr(topk_weights), get_ptr(permuted_experts_id), + get_ptr(dst_row_id2src_row_id_map), + get_ptr(src_row_id2dst_row_id_map), + get_ptr(expert_first_token_offset), n_token, valid_num_ptr, + n_hidden, topk, n_local_expert, align_block_size_value, stream); + }); + + // get m_indices and update expert_first_token_offset with align block + getMIndices(get_ptr(expert_first_token_offset), + get_ptr(align_expert_first_token_offset), + get_ptr(m_indices), n_local_expert, align_block_size_value, + stream); + if (align_block_size.has_value()) { + // update align_expert_first_token_offset + expert_first_token_offset.copy_(align_expert_first_token_offset); + } +} + +void moe_unpermute( + const torch::Tensor& permuted_hidden_states, // [n_token * topk, hidden] + const torch::Tensor& topk_weights, //[n_token, topk] + const torch::Tensor& topk_ids, // [n_token, topk] + const torch::Tensor& src_row_id2dst_row_id_map, // [n_token, topk] + const torch::Tensor& expert_first_token_offset, // [n_local_expert+1] + int64_t n_expert, int64_t n_local_expert, int64_t topk, + torch::Tensor& hidden_states // [n_token, hidden] +) { + TORCH_CHECK(src_row_id2dst_row_id_map.sizes() == topk_ids.sizes(), + "topk_ids shape must be same as src_row_id2dst_row_id_map"); + TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int, + "topk_ids must be int32"); + TORCH_CHECK( + permuted_hidden_states.scalar_type() == hidden_states.scalar_type(), + "topk_ids dtype must be same as src_row_id2dst_row_id_map"); + auto n_token = hidden_states.size(0); + auto n_hidden = hidden_states.size(1); + auto stream = at::cuda::getCurrentCUDAStream().stream(); + const int64_t* valid_ptr = + get_ptr(expert_first_token_offset) + n_local_expert; + MOE_DISPATCH(hidden_states.scalar_type(), [&] { + finalizeMoeRoutingKernelLauncher( + get_ptr(permuted_hidden_states), + get_ptr(hidden_states), get_ptr(topk_weights), + get_ptr(src_row_id2dst_row_id_map), get_ptr(topk_ids), + n_token, n_hidden, topk, valid_ptr, stream); + }); +} + +TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { + m.impl("moe_permute", &moe_permute); + m.impl("moe_unpermute", &moe_unpermute); +} \ No newline at end of file diff --git a/csrc/moe/moe_wna16_utils.h b/csrc/moe/moe_wna16_utils.h index 4396b80240efe98741eeb85586ea2b1f9146a17b..8ef03f0e60527bd37850d98e07d7341bd62d653d 100644 --- a/csrc/moe/moe_wna16_utils.h +++ b/csrc/moe/moe_wna16_utils.h @@ -108,11 +108,11 @@ __device__ inline void dequant(int q, half2* res) { const int MUL = 0x2c002c00; const int ADD = 0xd400d400; - int lo0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX); - int hi0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX); + int lo0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); q >>= 8; - int lo1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX); - int hi1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX); + int lo1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); res[0] = __hsub2(*reinterpret_cast(&lo0), *reinterpret_cast(&SUB)); @@ -149,13 +149,13 @@ __device__ inline void dequant(int q, nv_bfloat162* res) { static constexpr uint32_t MASK = 0x000f000f; static constexpr uint32_t EX = 0x43004300; - int lo0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX); + int lo0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); q >>= 4; - int hi0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX); + int hi0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); q >>= 4; - int lo1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX); + int lo1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); q >>= 4; - int hi1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX); + int hi1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); static constexpr uint32_t MUL = 0x3F803F80; static constexpr uint32_t ADD = 0xC300C300; diff --git a/csrc/moe/permute_unpermute_kernels/dispatch.h b/csrc/moe/permute_unpermute_kernels/dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..41932cdd85bcd35b1623943695d05c6935cc6038 --- /dev/null +++ b/csrc/moe/permute_unpermute_kernels/dispatch.h @@ -0,0 +1,53 @@ +#pragma once +#include +#define MOE_SWITCH(TYPE, ...) \ + at::ScalarType _st = ::detail::scalar_type(TYPE); \ + switch (_st) { \ + __VA_ARGS__ \ + default: \ + TORCH_CHECK(false, "[moe permute]data type dispatch fail!") \ + } + +#define MOE_DISPATCH_CASE(enum_type, ...) \ + case enum_type: { \ + using scalar_t = ScalarType2CudaType::type; \ + __VA_ARGS__(); \ + break; \ + } +#define MOE_DISPATCH_FLOAT_CASE(...) \ + MOE_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + MOE_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ + MOE_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ + MOE_DISPATCH_CASE(at::ScalarType::Float8_e5m2, __VA_ARGS__) \ + MOE_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) + +#define MOE_DISPATCH(TYPE, ...) \ + MOE_SWITCH(TYPE, MOE_DISPATCH_FLOAT_CASE(__VA_ARGS__)) + +template +struct ScalarType2CudaType; + +template <> +struct ScalarType2CudaType { + using type = float; +}; +template <> +struct ScalarType2CudaType { + using type = half; +}; +template <> +struct ScalarType2CudaType { + using type = __nv_bfloat16; +}; + +// #if __CUDA_ARCH__ >= 890 +// fp8 +template <> +struct ScalarType2CudaType { + using type = __nv_fp8_e5m2; +}; +template <> +struct ScalarType2CudaType { + using type = __nv_fp8_e4m3; +}; +// #endif \ No newline at end of file diff --git a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..aa353d0f0437f863d79ed3b3151d40cbfefcb33d --- /dev/null +++ b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu @@ -0,0 +1,229 @@ + +#include "moe_permute_unpermute_kernel.h" + +// CubKeyValueSorter definition begin +CubKeyValueSorter::CubKeyValueSorter() + : num_experts_(0), num_bits_(sizeof(int) * 8) {} + +int CubKeyValueSorter::expertsToBits(int num_experts) { + // Max value we represent is V = num_experts + (num_experts - 1) = 2 * + // num_experts - 1 The maximum number of bits is therefore floor(log2(V)) + 1 + return static_cast(log2(2 * num_experts - 1)) + 1; +} + +CubKeyValueSorter::CubKeyValueSorter(int const num_experts) + : num_experts_(num_experts), num_bits_(expertsToBits(num_experts)) {} + +void CubKeyValueSorter::updateNumExperts(int const num_experts) { + num_experts_ = num_experts; + num_bits_ = expertsToBits(num_experts); +} + +size_t CubKeyValueSorter::getWorkspaceSize(size_t const num_key_value_pairs, + int const num_experts) { + int num_bits = expertsToBits(num_experts); + size_t required_storage = 0; + int* null_int = nullptr; + cub::DeviceRadixSort::SortPairs(nullptr, required_storage, null_int, null_int, + null_int, null_int, num_key_value_pairs, 0, + num_bits); + + // when num_key_value_pairs, num_experts, num_bits, required_storage = 64, + // 4, 3, 0 The required_storage seems to vary between 0 and 1 for the same + // inputs + if (required_storage == 0) { + required_storage = 1; + } + return required_storage; +} + +void CubKeyValueSorter::run(void* workspace, size_t const workspace_size, + int const* keys_in, int* keys_out, + int const* values_in, int* values_out, + size_t const num_key_value_pairs, + cudaStream_t stream) { + size_t expected_ws_size = getWorkspaceSize(num_key_value_pairs, num_experts_); + size_t actual_ws_size = workspace_size; + + TORCH_CHECK(expected_ws_size <= workspace_size, + "[CubKeyValueSorter::run] The allocated workspace is too small " + "to run this problem."); + cub::DeviceRadixSort::SortPairs(workspace, actual_ws_size, keys_in, keys_out, + values_in, values_out, num_key_value_pairs, 0, + num_bits_, stream); +} +// CubKeyValueSorter definition end + +static inline size_t pad_to_multiple_of_16(size_t const& input) { + static constexpr int ALIGNMENT = 16; + return ALIGNMENT * ((input + ALIGNMENT - 1) / ALIGNMENT); +} +template +__device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices, + int64_t const arr_length, + T const target) { + int64_t low = 0, high = arr_length - 1, target_location = -1; + while (low <= high) { + int64_t mid = (low + high) / 2; + + if (sorted_indices[mid] >= target) { + high = mid - 1; + } else { + low = mid + 1; + target_location = mid; + } + } + return target_location + 1; +} + +// Calculates the start offset of the tokens for a given expert. The last +// element is the total number of valid tokens +__global__ void computeExpertFirstTokenOffsetKernel( + int const* sorted_experts, int64_t const sorted_experts_len, + int const num_experts, int64_t* expert_first_token_offset) { + // First, compute the global tid. We only need 1 thread per expert. + int const expert = blockIdx.x * blockDim.x + threadIdx.x; + + // Note that expert goes [0, num_experts] (inclusive) because we want a count + // for the total number of active tokens at the end of the scan. + if (expert >= num_experts + 1) { + return; + } + expert_first_token_offset[expert] = + findTotalEltsLessThanTarget(sorted_experts, sorted_experts_len, expert); +} + +void computeExpertFirstTokenOffset(int const* sorted_indices, + int const total_indices, + int const num_experts, + int64_t* expert_first_token_offset, + cudaStream_t stream) { + int const num_entries = num_experts + 1; + int const threads = std::min(1024, num_entries); + int const blocks = (num_entries + threads - 1) / threads; + + computeExpertFirstTokenOffsetKernel<<>>( + sorted_indices, total_indices, num_experts, expert_first_token_offset); +} + +void sortAndScanExpert(int* expert_for_source_row, const int* source_rows, + int* permuted_experts, int* permuted_rows, + int64_t* expert_first_token_offset, int num_rows, + int num_experts, int num_experts_per_node, int k, + CubKeyValueSorter& sorter, void* sorter_ws, + cudaStream_t stream) { + int64_t const expanded_num_rows = static_cast(k) * num_rows; + // We need to use the full num_experts because that is the sentinel value used + // by topk for disabled experts + sorter.updateNumExperts(num_experts); + size_t const sorter_ws_size_bytes = pad_to_multiple_of_16( + sorter.getWorkspaceSize(expanded_num_rows, num_experts)); + sorter.run((void*)sorter_ws, sorter_ws_size_bytes, expert_for_source_row, + permuted_experts, source_rows, permuted_rows, expanded_num_rows, + stream); + computeExpertFirstTokenOffset(permuted_experts, expanded_num_rows, + num_experts_per_node, expert_first_token_offset, + stream); +} + +__global__ void preprocessTopkIdKernel(int* topk_id_ptr, int size, + const int* expert_map_ptr, + int num_experts) { + auto tidx = threadIdx.x; + auto bidx = blockIdx.x; + auto lidx = tidx & 31; + auto widx = tidx >> 5; + auto warp_count = (blockDim.x + 31) >> 5; + auto offset = bidx * blockDim.x; + auto bound = min(offset + blockDim.x, size); + extern __shared__ int smem_expert_map[]; + // store expert_map in smem + for (int i = tidx; i < num_experts; i += blockDim.x) { + smem_expert_map[i] = expert_map_ptr[i]; + } + __syncthreads(); + + // query global expert id in expert map. + // if global expert id = -1 in exert map, plus n_expert + // else set global expert id = exert map[global expert id] + if (offset + tidx < bound) { + auto topk_id = topk_id_ptr[offset + tidx]; + auto local_expert_idx = smem_expert_map[topk_id]; + if (local_expert_idx == -1) { + topk_id += num_experts; + } else { + topk_id = local_expert_idx; + } + __syncwarp(); + topk_id_ptr[offset + tidx] = topk_id; + } +} +void preprocessTopkIdLauncher(int* topk_id_ptr, int size, + const int* expert_map_ptr, int num_experts, + cudaStream_t stream) { + int block = std::min(size, 1024); + int grid = (size + block - 1) / block; + int smem_size = (num_experts) * sizeof(int); + preprocessTopkIdKernel<<>>( + topk_id_ptr, size, expert_map_ptr, num_experts); +} + +template +__global__ void getMIndicesKernel(int64_t* expert_first_token_offset, + int64_t* align_expert_first_token_offset, + int* m_indices, const int num_local_expert, + const int align_block_size) { + int eidx = blockIdx.x; + int tidx = threadIdx.x; + extern __shared__ int64_t smem_expert_first_token_offset[]; + for (int i = tidx; i <= num_local_expert; i += blockDim.x) { + smem_expert_first_token_offset[tidx] = __ldg(expert_first_token_offset + i); + } + __syncthreads(); + auto last_token_offset = smem_expert_first_token_offset[eidx + 1]; + auto first_token_offset = smem_expert_first_token_offset[eidx]; + int n_token_in_expert = last_token_offset - first_token_offset; + + if constexpr (ALIGN_BLOCK_SIZE) { + n_token_in_expert = (n_token_in_expert + align_block_size - 1) / + align_block_size * align_block_size; + // round up to ALIGN_BLOCK_SIZE + int64_t accumulate_align_offset = 0; + for (int i = 1; i <= eidx + 1; i++) { + int n_token = smem_expert_first_token_offset[i] - + smem_expert_first_token_offset[i - 1]; + accumulate_align_offset = + accumulate_align_offset + (n_token + align_block_size - 1) / + align_block_size * align_block_size; + if (i == eidx) { + first_token_offset = accumulate_align_offset; + } + // last block store align_expert_first_token_offset + if (eidx == num_local_expert - 1 && threadIdx.x == 0) { + align_expert_first_token_offset[i] = accumulate_align_offset; + } + } + } + for (int idx = tidx; idx < n_token_in_expert; idx += blockDim.x) { + // update m_indice with expert id + m_indices[first_token_offset + idx] = eidx; + } +} + +void getMIndices(int64_t* expert_first_token_offset, + int64_t* align_expert_first_token_offset, int* m_indices, + int num_local_expert, const int align_block_size, + cudaStream_t stream) { + int block = 256; + int grid = num_local_expert; + int smem_size = sizeof(int64_t) * (num_local_expert + 1); + if (align_block_size == -1) { + getMIndicesKernel<<>>( + expert_first_token_offset, align_expert_first_token_offset, m_indices, + num_local_expert, align_block_size); + } else { + getMIndicesKernel<<>>( + expert_first_token_offset, align_expert_first_token_offset, m_indices, + num_local_expert, align_block_size); + } +} \ No newline at end of file diff --git a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..43c29721cd16e7faa2d0750a7b4c7d982f9d4287 --- /dev/null +++ b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.h @@ -0,0 +1,95 @@ +#pragma once +// reference from tensorrt_llm moe kernel implementation archive in +// https://github.com/BBuf/tensorrt-llm-moe/tree/master + +#include +#include +#include "dispatch.h" +#include +#include +#include +#include "cutlass/numeric_size.h" +#include "cutlass/array.h" + +template +inline T* get_ptr(torch::Tensor& t) { + return reinterpret_cast(t.data_ptr()); +} + +template +inline const T* get_ptr(const torch::Tensor& t) { + return reinterpret_cast(t.data_ptr()); +} + +class CubKeyValueSorter { + public: + CubKeyValueSorter(); + + CubKeyValueSorter(int const num_experts); + + void updateNumExperts(int const num_experts); + + static size_t getWorkspaceSize(size_t const num_key_value_pairs, + int const num_experts); + + void run(void* workspace, size_t const workspace_size, int const* keys_in, + int* keys_out, int const* values_in, int* values_out, + size_t const num_key_value_pairs, cudaStream_t stream); + + private: + static int expertsToBits(int experts); + int num_experts_; + int num_bits_; +}; + +void computeExpertFirstTokenOffset(int const* sorted_indices, + int const total_indices, + int const num_experts, + int64_t* expert_first_token_offset, + cudaStream_t stream); + +void sortAndScanExpert(int* expert_for_source_row, const int* source_rows, + int* permuted_experts, int* permuted_rows, + int64_t* expert_first_token_offset, int num_rows, + int num_experts, int num_experts_per_node, int k, + CubKeyValueSorter& sorter, void* sorter_ws, + cudaStream_t stream); + +template +void expandInputRowsKernelLauncher( + T const* unpermuted_input, T* permuted_output, + const float* unpermuted_scales, int* sorted_experts, + int const* expanded_dest_row_to_expanded_source_row, + int* expanded_source_row_to_expanded_dest_row, + int64_t* expert_first_token_offset, int64_t const num_rows, + int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k, + int num_local_experts, const int& align_block_size, cudaStream_t stream); + +// Final kernel to unpermute and scale +// This kernel unpermutes the original data, does the k-way reduction and +// performs the final skip connection. +template +__global__ void finalizeMoeRoutingKernel( + T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output, + float const* scales, int const* expanded_source_row_to_expanded_dest_row, + int const* expert_for_source_row, int64_t const orig_cols, int64_t const k, + int64_t const* num_valid_ptr); + +template +void finalizeMoeRoutingKernelLauncher( + T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output, + float const* scales, int const* expanded_source_row_to_expanded_dest_row, + int const* expert_for_source_row, int64_t const num_rows, + int64_t const cols, int64_t const k, int64_t const* num_valid_ptr, + cudaStream_t stream); + +void preprocessTopkIdLauncher(int* topk_id_ptr, int size, + const int* expert_map_ptr, int num_experts, + cudaStream_t stream); + +void getMIndices(int64_t* expert_first_token_offset, + int64_t* align_expert_first_token_offset, int* m_indices, + int num_local_expert, const int align_block_size, + cudaStream_t stream); + +#include "moe_permute_unpermute_kernel.inl" diff --git a/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl new file mode 100644 index 0000000000000000000000000000000000000000..42441800fb1107fe0b1a97f9d9a7d1e205ec1452 --- /dev/null +++ b/csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.inl @@ -0,0 +1,211 @@ +#pragma once + +template +__global__ void expandInputRowsKernel( + T const* unpermuted_input, T* permuted_output, + const float* unpermuted_scales, int* sorted_experts, + int const* expanded_dest_row_to_expanded_source_row, + int* expanded_source_row_to_expanded_dest_row, + int64_t* expert_first_token_offset, int64_t const num_rows, + int64_t const* num_dest_rows, int64_t const cols, int64_t k, + int num_local_experts, int align_block_size) { + // Reverse permutation map. + // I do this so that later, we can use the source -> dest map to do the k-way + // reduction and unpermuting. I need the reverse map for that reduction to + // allow each threadblock to do 1 k-way reduce without atomics later in MoE. 1 + // thread block will be responsible for all k summations. + int64_t expanded_dest_row = blockIdx.x; + int64_t const expanded_source_row = + expanded_dest_row_to_expanded_source_row[expanded_dest_row]; + int expert_id = sorted_experts[expanded_dest_row]; + + extern __shared__ int64_t smem_expert_first_token_offset[]; + int64_t align_expanded_row_accumulate = 0; + if constexpr (ALIGN_BLOCK_SIZE) { + // load g2s + for (int idx = threadIdx.x; idx < num_local_experts + 1; + idx += blockDim.x) { + smem_expert_first_token_offset[idx] = + __ldg(expert_first_token_offset + idx); + } + __syncthreads(); + int lane_idx = threadIdx.x & 31; + + if (lane_idx == 0) { + // set token_offset_in_expert = 0 if this expert is not local expert + int token_offset_in_expert = + expert_id >= num_local_experts + ? 0 + : expanded_dest_row - smem_expert_first_token_offset[expert_id]; + int64_t accumulate_align_offset = 0; +#pragma unroll 1 + for (int eidx = 1; eidx <= min(expert_id, num_local_experts); eidx++) { + auto n_token_in_expert = smem_expert_first_token_offset[eidx] - + smem_expert_first_token_offset[eidx - 1]; + accumulate_align_offset += (n_token_in_expert + align_block_size - 1) / + align_block_size * align_block_size; + } + expanded_dest_row = accumulate_align_offset + token_offset_in_expert; + } + // lane0 shuffle broadcast align_expanded_dest_row + expanded_dest_row = __shfl_sync(0xffffffff, expanded_dest_row, 0); + } + + if (threadIdx.x == 0) { + assert(expanded_dest_row <= INT32_MAX); + expanded_source_row_to_expanded_dest_row[expanded_source_row] = + static_cast(expanded_dest_row); + } + + if (!CHECK_SKIPPED || blockIdx.x < *num_dest_rows) { + // Load 128-bits per thread + constexpr int64_t ELEM_PER_THREAD = 128 / cutlass::sizeof_bits::value; + using DataElem = cutlass::Array; + + // Duplicate and permute rows + int64_t const source_k_rank = expanded_source_row / num_rows; + int64_t const source_row = expanded_source_row % num_rows; + + auto const* source_row_ptr = + reinterpret_cast(unpermuted_input + source_row * cols); + auto* dest_row_ptr = + reinterpret_cast(permuted_output + expanded_dest_row * cols); + + int64_t const start_offset = threadIdx.x; + int64_t const stride = blockDim.x; + int64_t const num_elems_in_col = cols / ELEM_PER_THREAD; + + for (int elem_index = start_offset; elem_index < num_elems_in_col; + elem_index += stride) { + dest_row_ptr[elem_index] = source_row_ptr[elem_index]; + } + } +} + +template +void expandInputRowsKernelLauncher( + T const* unpermuted_input, T* permuted_output, + const float* unpermuted_scales, int* sorted_experts, + int const* expanded_dest_row_to_expanded_source_row, + int* expanded_source_row_to_expanded_dest_row, + int64_t* expert_first_token_offset, int64_t const num_rows, + int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k, + int num_local_experts, const int& align_block_size, cudaStream_t stream) { + int64_t const blocks = num_rows * k; + int64_t const threads = 256; + using FuncPtr = decltype(&expandInputRowsKernel); + FuncPtr func_map[2][2] = { + {&expandInputRowsKernel, + &expandInputRowsKernel}, + {&expandInputRowsKernel, + &expandInputRowsKernel}, + }; + bool is_check_skip = num_valid_tokens_ptr != nullptr; + bool is_align_block_size = align_block_size != -1; + auto func = func_map[is_check_skip][is_align_block_size]; + + int64_t smem_size = sizeof(int64_t) * (num_local_experts + 1); + + func<<>>( + unpermuted_input, permuted_output, unpermuted_scales, sorted_experts, + expanded_dest_row_to_expanded_source_row, + expanded_source_row_to_expanded_dest_row, expert_first_token_offset, + num_rows, num_valid_tokens_ptr, cols, k, num_local_experts, + align_block_size); +} + +template +__host__ __device__ constexpr static U arrayConvert(T const& input) { + using Type = typename U::Element; + static_assert(T::kElements == U::kElements); + U u; +#pragma unroll + for (int i = 0; i < U::kElements; i++) { + u[i] = static_cast(input[i]); + } + return u; +} + +template +__global__ void finalizeMoeRoutingKernel( + T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output, + float const* scales, int const* expanded_source_row_to_expanded_dest_row, + int const* expert_for_source_row, int64_t const orig_cols, int64_t const k, + int64_t const* num_valid_ptr) { + assert(orig_cols % 4 == 0); + int64_t const original_row = blockIdx.x; + int64_t const num_rows = gridDim.x; + auto const offset = original_row * orig_cols; + OutputType* reduced_row_ptr = reduced_unpermuted_output + offset; + int64_t const num_valid = *num_valid_ptr; + + // Load 128-bits per thread, according to the smallest data type we read/write + constexpr int64_t FINALIZE_ELEM_PER_THREAD = + 128 / std::min(cutlass::sizeof_bits::value, + cutlass::sizeof_bits::value); + + int64_t const start_offset = threadIdx.x; + int64_t const stride = blockDim.x; + int64_t const num_elems_in_col = orig_cols / FINALIZE_ELEM_PER_THREAD; + + using InputElem = cutlass::Array; + using OutputElem = cutlass::Array; + using ComputeElem = cutlass::Array; + auto const* expanded_permuted_rows_v = + reinterpret_cast(expanded_permuted_rows); + auto* reduced_row_ptr_v = reinterpret_cast(reduced_row_ptr); + +#pragma unroll + for (int elem_index = start_offset; elem_index < num_elems_in_col; + elem_index += stride) { + ComputeElem thread_output; + thread_output.fill(0); + float row_rescale{0.f}; + for (int k_idx = 0; k_idx < k; ++k_idx) { + int64_t const expanded_original_row = original_row + k_idx * num_rows; + int64_t const expanded_permuted_row = + expanded_source_row_to_expanded_dest_row[expanded_original_row]; + + int64_t const k_offset = original_row * k + k_idx; + float const row_scale = scales[k_offset]; + + // Check after row_rescale has accumulated + if (CHECK_SKIPPED && expanded_permuted_row >= num_valid) { + continue; + } + + auto const* expanded_permuted_rows_row_ptr = + expanded_permuted_rows_v + expanded_permuted_row * num_elems_in_col; + + int64_t const expert_idx = expert_for_source_row[k_offset]; + + ComputeElem expert_result = arrayConvert( + expanded_permuted_rows_row_ptr[elem_index]); + thread_output = thread_output + row_scale * (expert_result); + } + + OutputElem output_elem = + arrayConvert(thread_output); + reduced_row_ptr_v[elem_index] = output_elem; + } +} + +template +void finalizeMoeRoutingKernelLauncher( + T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output, + float const* scales, int const* expanded_source_row_to_expanded_dest_row, + int const* expert_for_source_row, int64_t const num_rows, + int64_t const cols, int64_t const k, int64_t const* num_valid_ptr, + cudaStream_t stream) { + int64_t const blocks = num_rows; + int64_t const threads = 256; + bool const check_finished = num_valid_ptr != nullptr; + using FuncPtr = decltype(&finalizeMoeRoutingKernel); + FuncPtr func_map[2] = {&finalizeMoeRoutingKernel, + &finalizeMoeRoutingKernel}; + auto* const kernel = func_map[check_finished]; + kernel<<>>( + expanded_permuted_rows, reduced_unpermuted_output, scales, + expanded_source_row_to_expanded_dest_row, expert_for_source_row, cols, k, + num_valid_ptr); +} diff --git a/csrc/moe/topk_softmax_kernels.cu b/csrc/moe/topk_softmax_kernels.cu index de9747b602524a158808e8bebb040c10471059f1..a9379032245d9b74838ff30398cf796a6568ca72 100644 --- a/csrc/moe/topk_softmax_kernels.cu +++ b/csrc/moe/topk_softmax_kernels.cu @@ -108,9 +108,17 @@ __launch_bounds__(TPB) __global__ } } -template -__launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax, const bool* finished, float* output, - int* indices, int* source_rows, const int num_experts, const int k, const int start_expert, const int end_expert) +template +__launch_bounds__(TPB) __global__ void moeTopK( + const float* inputs_after_softmax, + const bool* finished, + float* output, + IndType* indices, + int* source_rows, + const int num_experts, + const int k, + const int start_expert, + const int end_expert) { using cub_kvp = cub::KeyValuePair; @@ -182,9 +190,9 @@ __launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax 2) This implementation assumes k is small, but will work for any k. */ -template +template __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ - void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, int* indices, + void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, IndType* indices, int* source_rows, const int k, const int start_expert, const int end_expert) { // We begin by enforcing compile time assertions and setting up compile time constants. @@ -397,8 +405,8 @@ struct TopkConstants }; } // namespace detail -template -void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, int* indices, +template +void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, IndType* indices, int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream) { static constexpr std::size_t MAX_BYTES_PER_LDG = 16; @@ -421,10 +429,11 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f token_expert_indices, num_tokens, topk, 0, num_experts, \ stream); +template void topkGatingSoftmaxKernelLauncher( const float* gating_output, float* topk_weights, - int* topk_indicies, + IndType* topk_indicies, int* token_expert_indices, float* softmax_workspace, const int num_tokens, @@ -493,14 +502,32 @@ void topk_softmax( const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); torch::Tensor softmax_workspace = torch::empty({workspace_size}, gating_output.options()); - vllm::moe::topkGatingSoftmaxKernelLauncher( - gating_output.data_ptr(), - topk_weights.data_ptr(), - topk_indices.data_ptr(), - token_expert_indices.data_ptr(), - softmax_workspace.data_ptr(), - num_tokens, - num_experts, - topk, - stream); + + if(topk_indices.scalar_type() == at::ScalarType::Int) + { + vllm::moe::topkGatingSoftmaxKernelLauncher( + gating_output.data_ptr(), + topk_weights.data_ptr(), + topk_indices.data_ptr(), + token_expert_indices.data_ptr(), + softmax_workspace.data_ptr(), + num_tokens, + num_experts, + topk, + stream); + } + else + { + assert(topk_indices.scalar_type() == at::ScalarType::UInt32); + vllm::moe::topkGatingSoftmaxKernelLauncher( + gating_output.data_ptr(), + topk_weights.data_ptr(), + topk_indices.data_ptr(), + token_expert_indices.data_ptr(), + softmax_workspace.data_ptr(), + num_tokens, + num_experts, + topk, + stream); + } } diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index d0de42251f97adf6624c39d1feeea041a1e09e58..810026d034c07d5ff3ac612a2cb1786c1ffd9676 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -44,7 +44,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { m.def( "moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none," - "Tensor! b_q_weight, Tensor! b_scales, Tensor? b_zeros_or_none," + "Tensor! b_q_weight, Tensor! b_scales, Tensor? global_scale, Tensor? " + "b_zeros_or_none," "Tensor? g_idx_or_none, Tensor? perm_or_none, Tensor! workspace," "Tensor sorted_token_ids," "Tensor! expert_ids, Tensor! num_tokens_past_padded," @@ -53,7 +54,29 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { "int size_m, int size_n, int size_k," "bool is_full_k, bool use_atomic_add," "bool use_fp32_reduce, bool is_zp_float) -> Tensor"); + m.def( + "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, " + "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! " + "b_zeros, Tensor! g_idx, Tensor! perm, Tensor! workspace, " + "int b_q_type, SymInt size_m, " + "SymInt size_n, SymInt size_k, bool is_k_full, int num_experts, int " + "topk, " + "int moe_block_size, bool replicate_input, bool apply_weights)" + " -> Tensor"); + + m.def( + "moe_permute(Tensor input, Tensor topk_weight, Tensor! topk_ids," + "Tensor token_expert_indicies, Tensor? expert_map, int n_expert," + "int n_local_expert," + "int topk, int? align_block_size,Tensor! permuted_input, Tensor! " + "expert_first_token_offset, Tensor! src_row_id2dst_row_id_map, Tensor! " + "m_indices)->()"); + m.def( + "moe_unpermute(Tensor permuted_hidden_states, Tensor topk_weights," + "Tensor topk_ids,Tensor src_row_id2dst_row_id_map, Tensor " + "expert_first_token_offset, int n_expert, int n_local_expert,int " + "topk, Tensor! hidden_states)->()"); // conditionally compiled so impl registration is in source file #endif diff --git a/csrc/ops.h b/csrc/ops.h index 01873caec04ca4e7ba9e2cd3f15cb9f38c4f23cd..6e19f033bd68e0ea260a765fbed75229dd949d34 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -59,6 +59,31 @@ void merge_attn_states(torch::Tensor& output, const torch::Tensor& prefix_lse, const torch::Tensor& suffix_output, const torch::Tensor& suffix_lse); + +void convert_vertical_slash_indexes( + torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS] + torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S] + torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS] + torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V] + torch::Tensor q_seqlens, // [BATCH, ] + torch::Tensor kv_seqlens, // [BATCH, ] + torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S] + int64_t context_size, int64_t block_size_M, int64_t block_size_N, + bool causal); + +void convert_vertical_slash_indexes_mergehead( + torch::Tensor& block_count, // [BATCH, N_HEADS, NUM_ROWS] + torch::Tensor& block_offset, // [BATCH, N_HEADS, NUM_ROWS, NNZ_S] + torch::Tensor& column_count, // [BATCH, N_HEADS, NUM_ROWS] + torch::Tensor& column_index, // [BATCH, N_HEADS, NUM_ROWS, NNZ_V] + torch::Tensor q_seqlens, // [BATCH, ] + torch::Tensor kv_seqlens, // [BATCH, ] + torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V] + torch::Tensor slash_indexes, // [BATCH, N_HEADS, NNZ_S] + torch::Tensor vertical_indices_count, // [N_HEADS, ] + torch::Tensor slash_indices_count, int64_t context_size, + int64_t block_size_M, int64_t block_size_N, bool causal); #endif void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, @@ -86,17 +111,20 @@ void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual, // std::optional residual); void rotary_embedding(torch::Tensor& positions, torch::Tensor& query, - torch::Tensor& key, int64_t head_size, + std::optional key, int64_t head_size, torch::Tensor& cos_sin_cache, bool is_neox); void batched_rotary_embedding(torch::Tensor& positions, torch::Tensor& query, - torch::Tensor& key, int64_t head_size, - torch::Tensor& cos_sin_cache, bool is_neox, - int64_t rot_dim, + std::optional key, + int64_t head_size, torch::Tensor& cos_sin_cache, + bool is_neox, int64_t rot_dim, torch::Tensor& cos_sin_cache_offsets); void silu_and_mul(torch::Tensor& out, torch::Tensor& input); +void silu_and_mul_quant(torch::Tensor& out, torch::Tensor& input, + torch::Tensor& scale); + void mul_and_silu(torch::Tensor& out, torch::Tensor& input); void gelu_and_mul(torch::Tensor& out, torch::Tensor& input); @@ -177,6 +205,10 @@ torch::Tensor ggml_moe_a8(torch::Tensor X, torch::Tensor W, torch::Tensor num_tokens_post_padded, int64_t type, int64_t row, int64_t top_k, int64_t tokens); +torch::Tensor ggml_moe_a8_vec(torch::Tensor X, torch::Tensor W, + torch::Tensor topk_ids, int64_t top_k, + int64_t type, int64_t row, int64_t tokens); + int64_t ggml_moe_get_block_size(int64_t type); #ifndef USE_ROCM @@ -203,6 +235,12 @@ void cutlass_moe_mm( torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::Tensor const& b_strides, torch::Tensor const& c_strides); +void cutlass_fp4_group_mm( + torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales, + const torch::Tensor& alphas, const torch::Tensor& problem_sizes, + const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets); + void get_cutlass_moe_mm_data( const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, @@ -230,6 +268,12 @@ std::vector cutlass_sparse_compress(torch::Tensor const& a); void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_scale, torch::Tensor const& input_scale); + +void scaled_fp4_experts_quant( + torch::Tensor& output, torch::Tensor& output_scale, + torch::Tensor const& input, torch::Tensor const& input_global_scale, + torch::Tensor const& input_offset_by_experts, + torch::Tensor const& output_scale_offset_by_experts); #endif void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu index c085d31a3e9b183cba672b83f173da842f1601c7..266f2a0667a24e3bcbb06d49449369fce6b69042 100644 --- a/csrc/pos_encoding_kernels.cu +++ b/csrc/pos_encoding_kernels.cu @@ -38,12 +38,14 @@ inline __device__ void apply_rotary_embedding( scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, // head_size] or [num_tokens, num_heads, // head_size] - scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, + scalar_t* __restrict__ key, // nullptr or + // [batch_size, seq_len, num_kv_heads, // head_size] or [num_tokens, num_kv_heads, // head_size] const scalar_t* cache_ptr, const int head_size, const int num_heads, const int num_kv_heads, const int rot_dim, const int token_idx, - const int64_t query_stride, const int64_t key_stride) { + const int64_t query_stride, const int64_t key_stride, + const int64_t head_stride) { const int embed_dim = rot_dim / 2; const scalar_t* cos_ptr = cache_ptr; const scalar_t* sin_ptr = cache_ptr + embed_dim; @@ -51,19 +53,23 @@ inline __device__ void apply_rotary_embedding( const int nq = num_heads * embed_dim; for (int i = threadIdx.x; i < nq; i += blockDim.x) { const int head_idx = i / embed_dim; - const int64_t token_head = token_idx * query_stride + head_idx * head_size; + const int64_t token_head = + token_idx * query_stride + head_idx * head_stride; const int rot_offset = i % embed_dim; apply_token_rotary_embedding( query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); } - const int nk = num_kv_heads * embed_dim; - for (int i = threadIdx.x; i < nk; i += blockDim.x) { - const int head_idx = i / embed_dim; - const int64_t token_head = token_idx * key_stride + head_idx * head_size; - const int rot_offset = i % embed_dim; - apply_token_rotary_embedding( - key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); + if (key != nullptr) { + const int nk = num_kv_heads * embed_dim; + for (int i = threadIdx.x; i < nk; i += blockDim.x) { + const int head_idx = i / embed_dim; + const int64_t token_head = + token_idx * key_stride + head_idx * head_stride; + const int rot_offset = i % embed_dim; + apply_token_rotary_embedding( + key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); + } } } @@ -74,13 +80,15 @@ __global__ void rotary_embedding_kernel( scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, // head_size] or [num_tokens, num_heads, // head_size] - scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, + scalar_t* __restrict__ key, // nullptr or + // [batch_size, seq_len, num_kv_heads, // head_size] or [num_tokens, num_kv_heads, // head_size] const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // // 2] const int rot_dim, const int64_t query_stride, const int64_t key_stride, - const int num_heads, const int num_kv_heads, const int head_size) { + const int64_t head_stride, const int num_heads, const int num_kv_heads, + const int head_size) { // Each thread block is responsible for one token. const int token_idx = blockIdx.x; int64_t pos = positions[token_idx]; @@ -88,7 +96,7 @@ __global__ void rotary_embedding_kernel( apply_rotary_embedding( query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, - token_idx, query_stride, key_stride); + token_idx, query_stride, key_stride, head_stride); } template @@ -98,15 +106,16 @@ __global__ void batched_rotary_embedding_kernel( scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, // head_size] or [num_tokens, num_heads, // head_size] - scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, + scalar_t* __restrict__ key, // nullptr or + // [batch_size, seq_len, num_kv_heads, // head_size] or [num_tokens, num_kv_heads, // head_size] const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // // 2] const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len] - // or [num_tokens] const int rot_dim, const int64_t query_stride, const int64_t key_stride, - const int num_heads, const int num_kv_heads, const int head_size) { + const int64_t head_stride, const int num_heads, const int num_kv_heads, + const int head_size) { // Each thread block is responsible for one token. const int token_idx = blockIdx.x; int64_t pos = positions[token_idx]; @@ -116,7 +125,7 @@ __global__ void batched_rotary_embedding_kernel( apply_rotary_embedding( query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, - token_idx, query_stride, key_stride); + token_idx, query_stride, key_stride, head_stride); } } // namespace vllm @@ -127,10 +136,12 @@ void rotary_embedding( // [num_tokens, num_heads * head_size] or // [batch_size, seq_len, num_heads, head_size] or // [num_tokens, num_heads, head_size] - torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or - // [num_tokens, num_kv_heads * head_size] or - // [batch_size, seq_len, num_heads, head_size] or - // [num_tokens, num_heads, head_size] + std::optional key, + // null or + // [batch_size, seq_len, num_kv_heads * head_size] or + // [num_tokens, num_kv_heads * head_size] or + // [batch_size, seq_len, num_heads, head_size] or + // [num_tokens, num_heads, head_size] int64_t head_size, torch::Tensor& cos_sin_cache, // [max_position, rot_dim] bool is_neox) { @@ -138,40 +149,46 @@ void rotary_embedding( int64_t num_tokens = positions.numel(); int positions_ndim = positions.dim(); - // Make sure num_tokens dim is consistent across positions, query, and key. + // Make sure num_tokens dim is consistent across positions, query, and key TORCH_CHECK( positions_ndim == 1 || positions_ndim == 2, "positions must have shape [num_tokens] or [batch_size, seq_len]"); if (positions_ndim == 1) { - TORCH_CHECK( - query.size(0) == positions.size(0) && key.size(0) == positions.size(0), - "query, key and positions must have the same number of tokens"); + TORCH_CHECK(query.size(0) == positions.size(0) && + (!key.has_value() || key->size(0) == positions.size(0)), + "query, key and positions must have the same number of tokens"); } if (positions_ndim == 2) { TORCH_CHECK( query.size(0) == positions.size(0) && - key.size(0) == positions.size(0) && + (!key.has_value() || key->size(0) == positions.size(0)) && query.size(1) == positions.size(1) && - key.size(1) == positions.size(1), + (!key.has_value() || key->size(1) == positions.size(1)), "query, key and positions must have the same batch_size and seq_len"); } // Make sure head_size is valid for query and key // hidden_size = num_heads * head_size int query_hidden_size = query.numel() / num_tokens; - int key_hidden_size = key.numel() / num_tokens; + int key_hidden_size = key.has_value() ? key->numel() / num_tokens : 0; TORCH_CHECK(query_hidden_size % head_size == 0); TORCH_CHECK(key_hidden_size % head_size == 0); // Make sure query and key have consistent number of heads int num_heads = query_hidden_size / head_size; - int num_kv_heads = key_hidden_size / head_size; + int num_kv_heads = key.has_value() ? key_hidden_size / head_size : num_heads; TORCH_CHECK(num_heads % num_kv_heads == 0); int rot_dim = cos_sin_cache.size(1); int seq_dim_idx = positions_ndim - 1; int64_t query_stride = query.stride(seq_dim_idx); - int64_t key_stride = key.stride(seq_dim_idx); + int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0; + // Determine head stride: for [*, heads, head_size] use stride of last dim; + // for flat [*, heads*head_size], heads blocks are contiguous of size + // head_size + int query_ndim = query.dim(); + int64_t head_stride = + (query_ndim == positions_ndim + 2) ? query.stride(-2) : head_size; dim3 grid(num_tokens); dim3 block(std::min(num_heads * rot_dim / 2, 512)); @@ -181,15 +198,16 @@ void rotary_embedding( if (is_neox) { vllm::rotary_embedding_kernel<<>>( positions.data_ptr(), query.data_ptr(), - key.data_ptr(), cos_sin_cache.data_ptr(), rot_dim, - query_stride, key_stride, num_heads, num_kv_heads, head_size); + key.has_value() ? key->data_ptr() : nullptr, + cos_sin_cache.data_ptr(), rot_dim, query_stride, key_stride, + head_stride, num_heads, num_kv_heads, head_size); } else { vllm::rotary_embedding_kernel <<>>( positions.data_ptr(), query.data_ptr(), - key.data_ptr(), cos_sin_cache.data_ptr(), - rot_dim, query_stride, key_stride, num_heads, num_kv_heads, - head_size); + key.has_value() ? key->data_ptr() : nullptr, + cos_sin_cache.data_ptr(), rot_dim, query_stride, + key_stride, head_stride, num_heads, num_kv_heads, head_size); } }); } @@ -204,10 +222,12 @@ void batched_rotary_embedding( // [num_tokens, num_heads * head_size] or // [batch_size, seq_len, num_heads, head_size] or // [num_tokens, num_heads, head_size] - torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or - // [num_tokens, num_kv_heads * head_size] or - // [batch_size, seq_len, num_heads, head_size] or - // [num_tokens, num_heads, head_size] + std::optional + key, // null or + // [batch_size, seq_len, num_kv_heads * head_size] or + // [num_tokens, num_kv_heads * head_size] or + // [batch_size, seq_len, num_heads, head_size] or + // [num_tokens, num_heads, head_size] int64_t head_size, torch::Tensor& cos_sin_cache, // [max_position, rot_dim] bool is_neox, int64_t rot_dim, @@ -221,38 +241,44 @@ void batched_rotary_embedding( "cos_sin_cache_offsets"); int positions_ndim = positions.dim(); - // Make sure num_tokens dim is consistent across positions, query, and key. + // Make sure num_tokens dim is consistent across positions, query, and key TORCH_CHECK( positions_ndim == 1 || positions_ndim == 2, "positions must have shape [num_tokens] or [batch_size, seq_len]"); if (positions_ndim == 1) { - TORCH_CHECK( - query.size(0) == positions.size(0) && key.size(0) == positions.size(0), - "query, key and positions must have the same number of tokens"); + TORCH_CHECK(query.size(0) == positions.size(0) && + (!key.has_value() || key->size(0) == positions.size(0)), + "query, key and positions must have the same number of tokens"); } if (positions_ndim == 2) { TORCH_CHECK( query.size(0) == positions.size(0) && - key.size(0) == positions.size(0) && + (!key.has_value() || key->size(0) == positions.size(0)) && query.size(1) == positions.size(1) && - key.size(1) == positions.size(1), + (!key.has_value() || key->size(1) == positions.size(1)), "query, key and positions must have the same batch_size and seq_len"); } // Make sure head_size is valid for query and key int query_hidden_size = query.numel() / num_tokens; - int key_hidden_size = key.numel() / num_tokens; + int key_hidden_size = key.has_value() ? key->numel() / num_tokens : 0; TORCH_CHECK(query_hidden_size % head_size == 0); TORCH_CHECK(key_hidden_size % head_size == 0); // Make sure query and key have concistent number of heads int num_heads = query_hidden_size / head_size; - int num_kv_heads = key_hidden_size / head_size; + int num_kv_heads = key.has_value() ? key_hidden_size / head_size : num_heads; TORCH_CHECK(num_heads % num_kv_heads == 0); int seq_dim_idx = positions_ndim - 1; int64_t query_stride = query.stride(seq_dim_idx); - int64_t key_stride = key.stride(seq_dim_idx); + int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0; + // Determine head stride: for [*, heads, head_size] use stride of last dim; + // for flat [*, heads*head_size], heads blocks are contiguous of size + // head_size + int query_ndim = query.dim(); + int64_t head_stride = + (query_ndim == positions_ndim + 2) ? query.stride(-2) : head_size; dim3 grid(num_tokens); dim3 block(std::min(num_heads * rot_dim / 2, 512)); @@ -263,16 +289,18 @@ void batched_rotary_embedding( vllm::batched_rotary_embedding_kernel <<>>( positions.data_ptr(), query.data_ptr(), - key.data_ptr(), cos_sin_cache.data_ptr(), + key.has_value() ? key->data_ptr() : nullptr, + cos_sin_cache.data_ptr(), cos_sin_cache_offsets.data_ptr(), rot_dim, query_stride, - key_stride, num_heads, num_kv_heads, head_size); + key_stride, head_stride, num_heads, num_kv_heads, head_size); } else { vllm::batched_rotary_embedding_kernel <<>>( positions.data_ptr(), query.data_ptr(), - key.data_ptr(), cos_sin_cache.data_ptr(), + key.has_value() ? key->data_ptr() : nullptr, + cos_sin_cache.data_ptr(), cos_sin_cache_offsets.data_ptr(), rot_dim, query_stride, - key_stride, num_heads, num_kv_heads, head_size); + key_stride, head_stride, num_heads, num_kv_heads, head_size); } }); } diff --git a/csrc/quantization/activation_kernels.cu b/csrc/quantization/activation_kernels.cu new file mode 100644 index 0000000000000000000000000000000000000000..67e9149c137950c068534cc696317f2c6a1bce1d --- /dev/null +++ b/csrc/quantization/activation_kernels.cu @@ -0,0 +1,121 @@ +#include +#include +#include + +#include +#include "core/math.hpp" +#include "cuda_compat.h" +#include "dispatch_utils.h" + +#include "quantization/fp8/common.cuh" + +namespace vllm { + +template +__device__ __forceinline__ T silu_kernel(const T& x) { + // x * sigmoid(x) + return (T)(((float)x) / (1.0f + expf((float)-x))); +} + +// Activation and gating kernel template. +template +__global__ void act_and_mul_quant_kernel( + fp8_type* __restrict__ out, // [..., d] + const scalar_t* __restrict__ input, // [..., 2, d] + const float* scale, const int d) { + const int32_t blocks_per_token = gridDim.y; + + const int32_t elems_per_128bit_load = (128 / 8) / sizeof(scalar_t); + + // We don't expect the hidden dimension to exceed 32 bits so int32 should + // be safe here. + const int32_t tgt_elems_per_block = div_ceil(d, blocks_per_token); + const int32_t elems_per_block = + round_to_next_multiple_of(tgt_elems_per_block, elems_per_128bit_load); + const int32_t block_start = blockIdx.y * elems_per_block; + int32_t block_end = block_start + elems_per_block; + block_end = block_end > d ? d : block_end; + + // token_idx is 64 bit to prevent 32 bit overflow when the number of tokens + // is very large + const int64_t token_idx = blockIdx.x; + const scalar_t* __restrict__ x_ptr = input + token_idx * 2 * d; + const scalar_t* __restrict__ y_ptr = input + token_idx * 2 * d + d; + fp8_type* __restrict__ out_ptr = out + token_idx * d; + + // 128-bit vectorized code + const int32_t vec_loop_end = + round_to_previous_multiple_of(elems_per_128bit_load, block_end); + const int32_t vec_end_idx = vec_loop_end / elems_per_128bit_load; + const int32_t vec_start_idx = block_start / elems_per_128bit_load; + + const int4* __restrict__ x_128bit_ptr = reinterpret_cast(x_ptr); + const int4* __restrict__ y_128bit_ptr = reinterpret_cast(y_ptr); + int2* __restrict__ out_128bit_ptr = reinterpret_cast(out_ptr); + + float inverted_scale = 1 / *scale; +#pragma unroll + for (int32_t vec_idx = vec_start_idx + threadIdx.x; vec_idx < vec_end_idx; + vec_idx += blockDim.x) { + const int4 x_128bit = VLLM_LDG(&x_128bit_ptr[vec_idx]); + const int4 y_128bit = VLLM_LDG(&y_128bit_ptr[vec_idx]); + using scalar_128bit_vec_t = std::array; + using scalar_64bit_vec_t = std::array; + + scalar_64bit_vec_t out_vec; + const auto x_vec = reinterpret_cast(x_128bit); + const auto y_vec = reinterpret_cast(y_128bit); + +#pragma unroll + for (int i = 0; i < elems_per_128bit_load; i++) { + out_vec[i] = scaled_fp8_conversion( + ACT_FN(x_vec[i]) * y_vec[i], inverted_scale); + } + + out_128bit_ptr[vec_idx] = reinterpret_cast(out_vec); + } + + // Scalar cleanup code + if (block_end > vec_loop_end) { + for (int64_t idx = vec_loop_end + threadIdx.x; idx < block_end; + idx += blockDim.x) { + const scalar_t x = VLLM_LDG(&x_ptr[idx]); + const scalar_t y = VLLM_LDG(&y_ptr[idx]); + out_ptr[idx] = + scaled_fp8_conversion(ACT_FN(x) * y, inverted_scale); + } + } +} +} // namespace vllm + +// Launch activation, gating, and quantize kernel. +#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \ + int d = input.size(-1) / 2; \ + int64_t num_tokens = input.numel() / input.size(-1); \ + dim3 grid(num_tokens, num_tokens > 16 ? num_tokens > 32 ? 1 : 2 : 4); \ + dim3 block(std::min(d, 512)); \ + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ + VLLM_DISPATCH_FLOATING_TYPES( \ + input.scalar_type(), "act_and_mul_kernel", [&] { \ + VLLM_DISPATCH_FP8_TYPES( \ + out.scalar_type(), "fused_add_rms_norm_kernel_fp8_type", [&] { \ + vllm::act_and_mul_quant_kernel, \ + fp8_t> \ + <<>>(out.data_ptr(), \ + input.data_ptr(), \ + scale.data_ptr(), d); \ + }); \ + }); + +void silu_and_mul_quant(torch::Tensor& out, // [..., d] + torch::Tensor& input, // [..., 2 * d] + torch::Tensor& scale) { + TORCH_CHECK(out.dtype() == torch::kFloat8_e4m3fn || + out.dtype() == torch::kFloat8_e4m3fnuz); + TORCH_CHECK(input.dtype() == torch::kFloat16 || + input.dtype() == torch::kBFloat16); + TORCH_CHECK(input.size(-1) % 2 == 0); + LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel); +} diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index e79785827189de75ea03d06c462787d55cf0763b..bf46cce60a233909205140dfa66680f75a7720c6 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -26,7 +26,13 @@ static inline __device__ int8_t float_to_int8_rn(float x) { float dst = std::nearbyint(x); // saturate - dst = std::clamp(dst, i8_min, i8_max); + + // See https://github.com/pytorch/pytorch/issues/127666 + // See https://github.com/llvm/llvm-project/issues/95183 + // hip-clang std::clamp __glibcxx_assert_fail host function when building on + // Arch/gcc14. The following replaces std::clamp usage with similar logic + // dst = std::clamp(dst, i8_min, i8_max); + dst = (dst < i8_min) ? i8_min : (dst > i8_max) ? i8_max : dst; return static_cast(dst); #else // CUDA path @@ -79,7 +85,13 @@ static inline __device__ int8_t int32_to_int8(int32_t x) { static_cast(std::numeric_limits::max()); // saturate - int32_t dst = std::clamp(x, i8_min, i8_max); + + // See https://github.com/pytorch/pytorch/issues/127666 + // See https://github.com/llvm/llvm-project/issues/95183 + // hip-clang std::clamp __glibcxx_assert_fail host function when building on + // Arch/gcc14. The following replaces std::clamp usage with similar logic + // int32_t dst = std::clamp(x, i8_min, i8_max); + int32_t dst = (x < i8_min) ? i8_min : (x > i8_max) ? i8_max : x; return static_cast(dst); #else // CUDA path diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu new file mode 100644 index 0000000000000000000000000000000000000000..84492553c02f2177e3fa81da1033fcb612e6f98c --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu @@ -0,0 +1,27 @@ +#include "scaled_mm_kernels.hpp" +#include "scaled_mm_blockwise_sm100_fp8_dispatch.cuh" +#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" + +namespace vllm { + +void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out, + torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + TORCH_CHECK( + a.size(0) % 4 == 0, + "Input tensor must have a number of rows that is a multiple of 4. ", + "but got: ", a.size(0), " rows."); + if (out.dtype() == torch::kBFloat16) { + cutlass_gemm_blockwise_sm100_fp8_dispatch( + out, a, b, a_scales, b_scales); + + } else { + TORCH_CHECK(out.dtype() == torch::kFloat16); + cutlass_gemm_blockwise_sm100_fp8_dispatch( + out, a, b, a_scales, b_scales); + } +} + +} // namespace vllm diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh new file mode 100644 index 0000000000000000000000000000000000000000..ef324364c6d5e01cc3f32222f07cfe0fdd20f589 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh @@ -0,0 +1,205 @@ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass_extensions/gemm/dispatch_policy.hpp" +#include "cutlass_extensions/gemm/collective/collective_builder.hpp" + +#include "cutlass_gemm_caller.cuh" + +namespace vllm { + +using namespace cute; + +template +struct cutlass_3x_gemm_fp8_blockwise { + using ElementAB = cutlass::float_e4m3_t; + + using ElementA = ElementAB; + using LayoutA = cutlass::layout::RowMajor; + static constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; + + using ElementB = ElementAB; + using LayoutB = cutlass::layout::ColumnMajor; + static constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; + + using ElementC = void; + using ElementD = OutType; + using LayoutD = cutlass::layout::RowMajor; + static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + + using LayoutC = LayoutD; + static constexpr int AlignmentC = AlignmentD; + + using ElementAccumulator = float; + using ElementCompute = float; + using ElementBlockScale = float; + + // MMA and Cluster Tile Shapes + // Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster + // Shape %2 == 0 using MmaTileShape_MNK = Shape<_128,_128,_128>; + static constexpr int ScaleMsPerTile = size<0>(ScalesPerTile{}); + static constexpr int ScaleGranularityM = + size<0>(MmaTileShape{}) / ScaleMsPerTile; + static constexpr int ScaleGranularityN = + size<1>(MmaTileShape{}) / size<1>(ScalesPerTile{}); + static constexpr int ScaleGranularityK = + size<2>(MmaTileShape{}) / size<2>(ScalesPerTile{}); + + // Shape of the threadblocks in a cluster + using ClusterShape_MNK = ClusterShape; + + using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig< + ScaleGranularityM, ScaleGranularityN, ScaleGranularityK, + cute::UMMA::Major::MN, cute::UMMA::Major::K>; + using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA()); + using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB()); + + using ArchTag = cutlass::arch::Sm100; + using OperatorClass = cutlass::arch::OpClassTensorOp; + + static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; + using ElementScalar = float; + // clang-format off + using DefaultOperation = cutlass::epilogue::fusion::LinearCombination; + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + MmaTileShape, + ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, + ElementCompute, + ElementC, + LayoutC, + AlignmentC, + ElementD, + LayoutD, + AlignmentD, + EpilogueScheduler, + DefaultOperation + >::CollectiveOp; + + using StageCountType = cutlass::gemm::collective::StageCountAuto; + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, + OperatorClass, + ElementA, + cute::tuple, + AlignmentA, + ElementB, + cute::tuple, + AlignmentB, + ElementAccumulator, + MmaTileShape, + ClusterShape, + + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + MainloopScheduler + >::CollectiveOp; + // clang-format on + + using KernelType = enable_sm100_only, CollectiveMainloop, CollectiveEpilogue>>; + + struct GemmKernel : public KernelType {}; +}; + +template +void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + using GemmKernel = typename Gemm::GemmKernel; + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideD = typename Gemm::GemmKernel::StrideD; + using StrideC = typename Gemm::GemmKernel::StrideC; + using LayoutSFA = typename Gemm::LayoutSFA; + using LayoutSFB = typename Gemm::LayoutSFB; + using ScaleConfig = typename Gemm::ScaleConfig; + + using ElementAB = typename Gemm::ElementAB; + using ElementD = typename Gemm::ElementD; + + int32_t m = a.size(0), n = b.size(1), k = a.size(1); + auto prob_shape = cute::make_shape(m, n, k, 1); + + StrideA a_stride; + StrideB b_stride; + StrideC c_stride; + a_stride = + cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1)); + b_stride = + cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1)); + c_stride = + cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1)); + + LayoutSFA layout_SFA = + ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1)); + LayoutSFB layout_SFB = + ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1)); + + auto a_ptr = static_cast(a.data_ptr()); + auto b_ptr = static_cast(b.data_ptr()); + auto a_scales_ptr = static_cast(a_scales.data_ptr()); + auto b_scales_ptr = static_cast(b_scales.data_ptr()); + + typename GemmKernel::MainloopArguments mainloop_args{ + a_ptr, a_stride, b_ptr, b_stride, + a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB}; + + auto c_ptr = static_cast(out.data_ptr()); + typename GemmKernel::EpilogueArguments epilogue_args{ + {}, c_ptr, c_stride, c_ptr, c_stride}; + c3x::cutlass_gemm_caller(a.device(), prob_shape, mainloop_args, + epilogue_args); +} + +template +void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out, + torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales) { + auto m = a.size(0); + auto k = a.size(1); + auto n = b.size(1); + int sms; + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device()); + + auto should_use_2sm = [&sms](int m, int n, int tile1SM = 128) { + return std::ceil(static_cast(m) / tile1SM) * + std::ceil(static_cast(n) / tile1SM) >= + sms; + }; + bool use_2sm = should_use_2sm(m, n); + if (use_2sm) { + cutlass_gemm_caller_blockwise, Shape<_256, _1, _1>, + Shape<_2, _2, _1>, cutlass::epilogue::TmaWarpSpecialized2Sm, + cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>( + out, a, b, a_scales, b_scales); + } else { + cutlass_gemm_caller_blockwise, Shape<_128, _1, _1>, + Shape<_1, _1, _1>, cutlass::epilogue::TmaWarpSpecialized1Sm, + cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>( + out, a, b, a_scales, b_scales); + } +} + +} // namespace vllm diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2ee6a19407f923e13110fae658e0809f141afa03 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp @@ -0,0 +1,75 @@ +#include +#include "cuda_utils.h" +#include "cutlass_extensions/common.hpp" + +template +void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a, + torch::Tensor const& b, torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + std::optional const& bias, + Fp8Func fp8_func, Int8Func int8_func, + BlockwiseFunc blockwise_func) { + TORCH_CHECK(a_scales.dtype() == torch::kFloat32); + TORCH_CHECK(b_scales.dtype() == torch::kFloat32); + + int M = a.size(0), N = b.size(1), K = a.size(1); + + if ((a_scales.numel() == 1 || a_scales.numel() == a.size(0)) && + (b_scales.numel() == 1 || b_scales.numel() == b.size(1))) { + // Standard per-tensor/per-token/per-channel scaling + TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); + if (a.dtype() == torch::kFloat8_e4m3fn) { + fp8_func(c, a, b, a_scales, b_scales, bias); + } else { + TORCH_CHECK(a.dtype() == torch::kInt8); + if constexpr (!std::is_same_v) { + int8_func(c, a, b, a_scales, b_scales, bias); + } else { + TORCH_CHECK(false, "Int8 not supported for this architecture"); + } + } + } else { + TORCH_CHECK(a_scales.dim() == 2, "a scale must be 2d tensor."); + TORCH_CHECK(b_scales.dim() == 2, "b scale must be 2d tensor."); + int32_t version_num = get_sm_version_num(); + if (version_num >= 100) { + TORCH_CHECK( + a.size(0) == a_scales.size(0) && + cuda_utils::ceil_div(a.size(1), int64_t(128)) == a_scales.size(1), + "a_scale_group_shape must be [1, 128]."); + TORCH_CHECK( + cuda_utils::ceil_div(b.size(0), int64_t(128)) == b_scales.size(0) && + cuda_utils::ceil_div(b.size(1), int64_t(128)) == b_scales.size(1), + "b_scale_group_shape must be [128, 128]."); + } else { + // TODO: Remove this after using cutlass sm90 blockwise scaling gemm + // kernel, or introducing ceil_div to the load_init() of mainloop. + using GroupShape = std::array; + auto make_group_shape = [](torch::Tensor const& x, + torch::Tensor const& s) -> GroupShape { + TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D"); + return {cuda_utils::ceil_div(x.size(0), s.size(0)), + cuda_utils::ceil_div(x.size(1), s.size(1))}; + }; + + GroupShape a_scale_group_shape = make_group_shape(a, a_scales); + GroupShape b_scale_group_shape = make_group_shape(b, b_scales); + + // 1x128 per-token group scales for activations + // 128x128 blockwise scales for weights + TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} && + b_scale_group_shape == GroupShape{128, 128} && + a.dtype() == torch::kFloat8_e4m3fn && + b.dtype() == torch::kFloat8_e4m3fn), + "cutlass_scaled_mm only supports datatype float8_e4m3fn.\n" + "a_scale_group_shape must be [1, 128]. Got: [", + a_scale_group_shape[0], ", ", a_scale_group_shape[1], + "]\n" + "b_scale_group_shape must be [128, 128]. Got: [", + b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]"); + } + + TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm"); + blockwise_func(c, a, b, a_scales, b_scales); + } +} diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp index 85272804774dbf6f45c60ee6d9113d39dc92a32c..c1242fdb39da9c58b1d4fbd674631862cecf7446 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp @@ -36,4 +36,9 @@ void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b_scales, std::optional const& bias); +void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out, + torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales); } // namespace vllm diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu index 459eb1bb76eb07e3a380fa7bc283218fb0d778b8..0cbd5305e3c252c9bf3d9bc8ea8a816b7306690d 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu @@ -1,8 +1,6 @@ -#include +#include "c3x/scaled_mm_helper.hpp" #include "c3x/scaled_mm_kernels.hpp" -#include "cuda_utils.h" - /* This file defines quantized GEMM operations using the CUTLASS 3.x API, for NVIDIA GPUs with sm100 (Blackwell). @@ -15,20 +13,10 @@ void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& a_scales, torch::Tensor const& b_scales, std::optional const& bias) { - TORCH_CHECK(a_scales.dtype() == torch::kFloat32); - TORCH_CHECK(b_scales.dtype() == torch::kFloat32); - - int M = a.size(0), N = b.size(1), K = a.size(1); - TORCH_CHECK( - (a_scales.numel() == 1 || a_scales.numel() == a.size(0)) && - (b_scales.numel() == 1 || b_scales.numel() == b.size(1)), - "Currently, block scaled fp8 gemm is not implemented for Blackwell"); - - // Standard per-tensor/per-token/per-channel scaling - TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); - TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn, - "Currently, only fp8 gemm is implemented for Blackwell"); - vllm::cutlass_scaled_mm_sm100_fp8(c, a, b, a_scales, b_scales, bias); + dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias, + vllm::cutlass_scaled_mm_sm100_fp8, + nullptr, // int8 not supported on SM100 + vllm::cutlass_scaled_mm_blockwise_sm100_fp8); } #endif diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu index bcb91040d5e2e18b08b96ae261fb81f9f8bb44d4..211302171f07458d3d2392edc4eb0a4ebdddbfbb 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu @@ -1,8 +1,6 @@ -#include +#include "c3x/scaled_mm_helper.hpp" #include "c3x/scaled_mm_kernels.hpp" -#include "cuda_utils.h" - /* This file defines quantized GEMM operations using the CUTLASS 3.x API, for NVIDIA GPUs with sm90a (Hopper). @@ -15,49 +13,10 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& a_scales, torch::Tensor const& b_scales, std::optional const& bias) { - TORCH_CHECK(a_scales.dtype() == torch::kFloat32); - TORCH_CHECK(b_scales.dtype() == torch::kFloat32); - - int M = a.size(0), N = b.size(1), K = a.size(1); - - if ((a_scales.numel() == 1 || a_scales.numel() == a.size(0)) && - (b_scales.numel() == 1 || b_scales.numel() == b.size(1))) { - // Standard per-tensor/per-token/per-channel scaling - TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); - if (a.dtype() == torch::kFloat8_e4m3fn) { - vllm::cutlass_scaled_mm_sm90_fp8(c, a, b, a_scales, b_scales, bias); - } else { - TORCH_CHECK(a.dtype() == torch::kInt8); - vllm::cutlass_scaled_mm_sm90_int8(c, a, b, a_scales, b_scales, bias); - } - } else { - using GroupShape = std::array; - auto make_group_shape = [](torch::Tensor const& x, - torch::Tensor const& s) -> GroupShape { - TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D"); - return {cuda_utils::ceil_div(x.size(0), s.size(0)), - cuda_utils::ceil_div(x.size(1), s.size(1))}; - }; - - GroupShape a_scale_group_shape = make_group_shape(a, a_scales); - GroupShape b_scale_group_shape = make_group_shape(b, b_scales); - - // 1x128 per-token group scales for activations - // 128x128 blockwise scales for weights - TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} && - b_scale_group_shape == GroupShape{128, 128} && - a.dtype() == torch::kFloat8_e4m3fn && - b.dtype() == torch::kFloat8_e4m3fn), - "cutlass_scaled_mm only supports datatype float8_e4m3fn.\n" - "a_scale_group_shape must be [1, 128]. Got: [", - a_scale_group_shape[0], ", ", a_scale_group_shape[1], - "]\n" - "b_scale_group_shape must be [128, 128]. Got: [", - b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]"); - TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm"); - - vllm::cutlass_scaled_mm_blockwise_sm90_fp8(c, a, b, a_scales, b_scales); - } + dispatch_scaled_mm(c, a, b, a_scales, b_scales, bias, + vllm::cutlass_scaled_mm_sm90_fp8, + vllm::cutlass_scaled_mm_sm90_int8, + vllm::cutlass_scaled_mm_blockwise_sm90_fp8); } void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a, diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index 54b63894e4cbcd222ad48415cb4850c2a3fa3b14..3c258ddce61e626d2efce4bb5891cf5f9d3ea00a 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -29,7 +29,8 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& a_scales, torch::Tensor const& b_scales, std::optional const& bias); - +#endif +#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90 void cutlass_moe_mm_sm90( torch::Tensor& out_tensors, torch::Tensor const& a_tensors, torch::Tensor const& b_tensors, torch::Tensor const& a_scales, @@ -37,12 +38,6 @@ void cutlass_moe_mm_sm90( torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, torch::Tensor const& b_strides, torch::Tensor const& c_strides); -void get_cutlass_moe_mm_data_caller( - const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, - torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, - torch::Tensor& input_permutation, torch::Tensor& output_permutation, - const int64_t num_experts, const int64_t n, const int64_t k); - #endif #if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100 @@ -53,6 +48,15 @@ void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a, std::optional const& bias); #endif +#if defined(ENABLE_SCALED_MM_SM90) && ENABLE_SCALED_MM_SM90 || \ + defined(ENABLE_SCALED_MM_SM100) && ENABLE_SCALED_MM_SM100 +void get_cutlass_moe_mm_data_caller( + const torch::Tensor& topk_ids, torch::Tensor& expert_offsets, + torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, + torch::Tensor& input_permutation, torch::Tensor& output_permutation, + const int64_t num_experts, const int64_t n, const int64_t k); +#endif + void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, @@ -110,6 +114,8 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) { #if defined CUDA_VERSION if (cuda_device_capability >= 90 && cuda_device_capability < 100) { return CUDA_VERSION >= 12000; + } else if (cuda_device_capability >= 100) { + return CUDA_VERSION >= 12080; } #endif @@ -222,7 +228,8 @@ void get_cutlass_moe_mm_data( // This function currently gets compiled only if we have a valid cutlass moe // mm to run it for. int32_t version_num = get_sm_version_num(); -#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90 +#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \ + (defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM90) get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1, problem_sizes2, input_permutation, output_permutation, num_experts, n, k); diff --git a/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu b/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..45ec3d29ce045759dfab8521f61ab7c6e0222a13 --- /dev/null +++ b/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu @@ -0,0 +1,402 @@ +#include +#include + +#include +#include +#include + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include + +using namespace cute; + +template +__global__ void __get_group_gemm_starts( + ElementAB** a_offsets, ElementAB** b_offsets, ElementC** out_offsets, + ElementSF** a_scales_offsets, ElementSF** b_scales_offsets, + ElementAccumulator** alpha_offsets, LayoutSFA* layout_sfa_base_as_int, + LayoutSFB* layout_sfb_base_as_int, ElementAB* a_base_as_int, + ElementAB* b_base_as_int, ElementC* out_base_as_int, + ElementSF* a_scales_base_as_int, ElementSF* b_scales_base_as_int, + ElementAccumulator* alphas_base_as_int, const int32_t* expert_offsets, + const int32_t* sf_offsets, const int32_t* problem_sizes_as_shapes, + const int K, const int N) { + int64_t expert_id = threadIdx.x; + if (expert_id >= gridDim.x * blockDim.x) { + return; + } + // Originally int32_t but upcasting to int64_t to avoid overflow + // during offset calculations + int64_t expert_offset = static_cast(expert_offsets[expert_id]); + int64_t sf_offset = static_cast(sf_offsets[expert_id]); + // size for block in block scale. + int64_t group_size = 16; + int64_t m = static_cast(problem_sizes_as_shapes[expert_id * 3]); + int64_t n = static_cast(problem_sizes_as_shapes[expert_id * 3 + 1]); + int64_t k = static_cast(problem_sizes_as_shapes[expert_id * 3 + 2]); + assert((m >= 0 && n == N && k == K && k % 2 == 0) && + "unexpected problem sizes"); + + int64_t half_k = static_cast(k / 2); + int64_t group_k = static_cast(k / group_size); + // Shape of A as uint8/byte = [M, K // 2] + // Shape of B as uint8/byte = [E, N, K // 2] + a_offsets[expert_id] = a_base_as_int + expert_offset * half_k; + + b_offsets[expert_id] = b_base_as_int + expert_id * n * half_k; + // Shape of C = [M, N] + out_offsets[expert_id] = out_base_as_int + expert_offset * n; + // Shape of a_scale = [sum(sf_sizes), K // group_size] + a_scales_offsets[expert_id] = a_scales_base_as_int + sf_offset * group_k; + + assert((reinterpret_cast(a_scales_offsets[expert_id]) % 128) == + 0 && + "TMA requires 128-byte alignment"); + + // Shape of B scale = [E, N, K // group_size] + b_scales_offsets[expert_id] = b_scales_base_as_int + expert_id * n * group_k; + assert((reinterpret_cast(b_scales_offsets[expert_id]) % 128) == + 0 && + "TMA requires 128-byte alignment"); + // Shape of alpha = [E] + alpha_offsets[expert_id] = alphas_base_as_int + expert_id; + + LayoutSFA* layout_sfa_ptr = layout_sfa_base_as_int + expert_id; + LayoutSFB* layout_sfb_ptr = layout_sfb_base_as_int + expert_id; + + *layout_sfa_ptr = ScaleConfig::tile_atom_to_shape_SFA(cute::make_shape( + static_cast(m), static_cast(n), static_cast(k), 1)); + *layout_sfb_ptr = ScaleConfig::tile_atom_to_shape_SFB(cute::make_shape( + static_cast(m), static_cast(n), static_cast(k), 1)); +} + +#define __CALL_GET_STARTS_KERNEL_BLOCKSCALE(ELEMENT_AB_TYPE, SF_TYPE, \ + TENSOR_C_TYPE, C_TYPE, LayoutSFA, \ + LayoutSFB, ScaleConfig) \ + else if (out_tensors.dtype() == TENSOR_C_TYPE) { \ + __get_group_gemm_starts \ + <<<1, num_experts, 0, stream>>>( \ + static_cast(a_starts.data_ptr()), \ + static_cast(b_starts.data_ptr()), \ + static_cast(out_starts.data_ptr()), \ + static_cast(a_scales_starts.data_ptr()), \ + static_cast(b_scales_starts.data_ptr()), \ + static_cast(alpha_starts.data_ptr()), \ + reinterpret_cast(layout_sfa.data_ptr()), \ + reinterpret_cast(layout_sfb.data_ptr()), \ + static_cast(a_tensors.data_ptr()), \ + static_cast(b_tensors.data_ptr()), \ + static_cast(out_tensors.data_ptr()), \ + static_cast(a_scales.data_ptr()), \ + static_cast(b_scales.data_ptr()), \ + static_cast(alphas.data_ptr()), \ + static_cast(expert_offsets.data_ptr()), \ + static_cast(sf_offsets.data_ptr()), \ + static_cast(problem_sizes.data_ptr()), K, N); \ + } + +template +void run_get_group_gemm_starts( + const torch::Tensor& a_starts, const torch::Tensor& b_starts, + const torch::Tensor& out_starts, const torch::Tensor& a_scales_starts, + const torch::Tensor& b_scales_starts, const torch::Tensor& alpha_starts, + const torch::Tensor& layout_sfa, const torch::Tensor& layout_sfb, + /*these are used for their base addresses*/ + torch::Tensor const& a_tensors, torch::Tensor const& b_tensors, + torch::Tensor const& out_tensors, torch::Tensor const& a_scales, + torch::Tensor const& b_scales, torch::Tensor const& alphas, + torch::Tensor const& expert_offsets, torch::Tensor const& sf_offsets, + torch::Tensor const& problem_sizes, int M, int N, int K) { + int num_experts = (int)expert_offsets.size(0); + auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index()); + + TORCH_CHECK(out_tensors.size(1) == N, + "Output tensor shape doesn't match expected shape"); + TORCH_CHECK(K / 2 == b_tensors.size(2), + "b_tensors(dim = 2) and a_tensors(dim = 1) trailing" + " dimension must match"); + if (false) { + } + //(ELEMENT_AB_TYPE, BS_TYPE, TENSOR_C_TYPE, C_TYPE, LayoutSFA, LayoutSFB, + // ScaleConfig) + __CALL_GET_STARTS_KERNEL_BLOCKSCALE( + cutlass::float_e2m1_t, cutlass::float_ue4m3_t, torch::kBFloat16, + cutlass::bfloat16_t, LayoutSFA, LayoutSFB, ScaleConfig) + __CALL_GET_STARTS_KERNEL_BLOCKSCALE(cutlass::float_e2m1_t, + cutlass::float_ue4m3_t, torch::kFloat16, + half, LayoutSFA, LayoutSFB, ScaleConfig) + else { + TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)"); + } +} + +template +void run_fp4_blockwise_scaled_group_mm( + torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales, + const torch::Tensor& alphas, const torch::Tensor& problem_sizes, + const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets, int M, + int N, int K) { + using ProblemShape = + cutlass::gemm::GroupProblemShape>; + using ElementType = cutlass::float_e2m1_t; + using ElementSFType = cutlass::float_ue4m3_t; + using ElementA = cutlass::nv_float4_t; + using ElementB = cutlass::nv_float4_t; + + using ElementC = OutType; + using ElementD = ElementC; + using ElementAccumulator = float; + // Layout definitions + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = LayoutC; + + // Alignment constraints + static constexpr int AlignmentA = 32; + static constexpr int AlignmentB = 32; + static constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; + static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + + // Architecture definitions + using ArchTag = cutlass::arch::Sm100; + using EpilogueOperatorClass = + cutlass::arch::OpClassTensorOp; // Epilogue Operator class tag + using MainloopOperatorClass = + cutlass::arch::OpClassBlockScaledTensorOp; // Mainloop Operator class tag + using StageCountType = + cutlass::gemm::collective::StageCountAuto; // Stage count maximized based + // on the tile size + + using ClusterShape = Shape<_1, _1, _1>; + struct MMA1SMConfig { + using MmaTileShape = Shape<_128, _128, _128>; + using KernelSchedule = cutlass::gemm:: + KernelPtrArrayTmaWarpSpecialized1SmNvf4Sm100; // Kernel to launch + using EpilogueSchedule = + cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; // Epilogue to launch + }; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, EpilogueOperatorClass, typename MMA1SMConfig::MmaTileShape, + ClusterShape, Shape<_128, _64>, ElementAccumulator, + ElementAccumulator, ElementC, LayoutC*, AlignmentC, ElementD, + LayoutC*, AlignmentD, + typename MMA1SMConfig::EpilogueSchedule>::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, MainloopOperatorClass, ElementA, LayoutA*, AlignmentA, + ElementB, LayoutB*, AlignmentB, ElementAccumulator, + typename MMA1SMConfig::MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + typename MMA1SMConfig::KernelSchedule>::CollectiveOp; + + using GemmKernel = + cutlass::gemm::kernel::GemmUniversal; + + using Gemm1SM = cutlass::gemm::device::GemmUniversalAdapter; + using Gemm = Gemm1SM; + using StrideA = typename Gemm::GemmKernel::InternalStrideA; + using StrideB = typename Gemm::GemmKernel::InternalStrideB; + using StrideC = typename Gemm::GemmKernel::InternalStrideC; + using StrideD = typename Gemm::GemmKernel::InternalStrideD; + + using LayoutSFA = + typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA; + using LayoutSFB = + typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB; + using ScaleConfig = + typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; + + using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape; + int num_experts = static_cast(expert_offsets.size(0)); + auto options_int = + torch::TensorOptions().dtype(torch::kInt64).device(a.device()); + + torch::Tensor a_ptrs = torch::empty(num_experts, options_int); + torch::Tensor b_ptrs = torch::empty(num_experts, options_int); + torch::Tensor out_ptrs = torch::empty(num_experts, options_int); + torch::Tensor a_scales_ptrs = torch::empty(num_experts, options_int); + torch::Tensor b_scales_ptrs = torch::empty(num_experts, options_int); + torch::Tensor alpha_ptrs = torch::empty(num_experts, options_int); + torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int); + torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int); + torch::Tensor c_strides1 = + torch::full({num_experts}, output.stride(0), options_int); + torch::Tensor a_strides1 = + torch::full({num_experts}, a.stride(0) * 2, options_int); + torch::Tensor b_strides1 = + torch::full({num_experts}, b.stride(1) * 2, options_int); + + run_get_group_gemm_starts( + a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, alpha_ptrs, + layout_sfa, layout_sfb, a, b, output, a_blockscale, b_blockscales, alphas, + expert_offsets, sf_offsets, problem_sizes, M, N, K); + + // Create an instance of the GEMM + Gemm gemm_op; + + // Initialize problem_sizes_as_shapes correctly + UnderlyingProblemShape* problem_sizes_as_shapes = + static_cast(problem_sizes.data_ptr()); + + // Set the Scheduler info + cutlass::KernelHardwareInfo hw_info; + using RasterOrderOptions = typename cutlass::gemm::kernel::detail:: + PersistentTileSchedulerSm100GroupParams< + typename ProblemShape::UnderlyingProblemShape>::RasterOrderOptions; + typename Gemm::GemmKernel::TileSchedulerArguments scheduler; + scheduler.raster_order = RasterOrderOptions::AlongM; + hw_info.device_id = a.get_device(); + static std::unordered_map cached_sm_counts; + if (cached_sm_counts.find(hw_info.device_id) == cached_sm_counts.end()) { + cached_sm_counts[hw_info.device_id] = + cutlass::KernelHardwareInfo::query_device_multiprocessor_count( + hw_info.device_id); + } + hw_info.sm_count = min(cached_sm_counts[hw_info.device_id], INT_MAX); + + // Mainloop Arguments + typename GemmKernel::MainloopArguments mainloop_args{ + static_cast(a_ptrs.data_ptr()), + static_cast(a_strides1.data_ptr()), + static_cast(b_ptrs.data_ptr()), + static_cast(b_strides1.data_ptr()), + static_cast(a_scales_ptrs.data_ptr()), + reinterpret_cast(layout_sfa.data_ptr()), + static_cast(b_scales_ptrs.data_ptr()), + reinterpret_cast(layout_sfb.data_ptr())}; + + // Epilogue Arguments + typename GemmKernel::EpilogueArguments epilogue_args{ + {}, // epilogue.thread + nullptr, + static_cast(c_strides1.data_ptr()), + static_cast(out_ptrs.data_ptr()), + static_cast(c_strides1.data_ptr())}; + auto& fusion_args = epilogue_args.thread; + fusion_args.alpha_ptr_array = + reinterpret_cast(alpha_ptrs.data_ptr()); + fusion_args.dAlpha = {_0{}, _0{}, 1}; + + // Gemm Arguments + typename GemmKernel::Arguments args{ + cutlass::gemm::GemmUniversalMode::kGrouped, + {num_experts, problem_sizes_as_shapes, nullptr}, + mainloop_args, + epilogue_args, + hw_info, + scheduler}; + + size_t workspace_size = Gemm::get_workspace_size(args); + auto const workspace_options = + torch::TensorOptions().dtype(torch::kUInt8).device(a.device()); + auto workspace = torch::empty(workspace_size, workspace_options); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device()); + + auto can_implement_status = gemm_op.can_implement(args); + TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, + "Failed to implement GEMM"); + + // Run the GEMM + auto status = gemm_op.initialize(args, workspace.data_ptr()); + TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM"); + + status = gemm_op.run(args, workspace.data_ptr(), stream); + TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM"); +} + +constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte; +constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn; + +#define CHECK_TYPE(x, st, m) \ + TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m) +#define CHECK_TH_CUDA(x, m) \ + TORCH_CHECK(x.is_cuda(), m, ": must be a CUDA tensor.") +#define CHECK_CONTIGUOUS(x, m) \ + TORCH_CHECK(x.is_contiguous(), m, ": must be contiguous.") +#define CHECK_INPUT(x, st, m) \ + CHECK_TH_CUDA(x, m); \ + CHECK_CONTIGUOUS(x, m); \ + CHECK_TYPE(x, st, m) + +void cutlass_fp4_group_mm( + torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& a_blockscale, const torch::Tensor& b_blockscales, + const torch::Tensor& alphas, const torch::Tensor& problem_sizes, + const torch::Tensor& expert_offsets, const torch::Tensor& sf_offsets) { +#if defined ENABLE_NVFP4 && ENABLE_NVFP4 + // Input validation + CHECK_INPUT(a, FLOAT4_E2M1X2, "a"); + CHECK_INPUT(b, FLOAT4_E2M1X2, "b"); + CHECK_INPUT(a_blockscale, SF_DTYPE, "a_blockscale"); + CHECK_INPUT(b_blockscales, SF_DTYPE, "b_blockscales"); + CHECK_INPUT(alphas, at::ScalarType::Float, "alphas"); + + TORCH_CHECK(a_blockscale.dim() == 2, + "expected a_blockscale to be of shape [num_experts, rounded_m," + " k // group_size], observed rank: ", + a_blockscale.dim()) + TORCH_CHECK(b_blockscales.dim() == 3, + "expected b_blockscale to be of shape: " + " [num_experts, n, k // group_size], observed rank: ", + b_blockscales.dim()) + TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be a 2D tensor"); + TORCH_CHECK(problem_sizes.size(1) == 3, + "problem_sizes must have the shape (num_experts, 3)"); + TORCH_CHECK(problem_sizes.size(0) == expert_offsets.size(0), + "Number of experts in problem_sizes must match expert_offsets"); + TORCH_CHECK(problem_sizes.dtype() == torch::kInt32, + "problem_sizes must be int32."); + + int M = static_cast(a.size(0)); + int N = static_cast(b.size(1)); + int E = static_cast(b.size(0)); + int K = static_cast(2 * b.size(2)); + + if (output.scalar_type() == torch::kBFloat16) { + run_fp4_blockwise_scaled_group_mm( + output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes, + expert_offsets, sf_offsets, M, N, K); + } else { + run_fp4_blockwise_scaled_group_mm( + output, a, b, a_blockscale, b_blockscales, alphas, problem_sizes, + expert_offsets, sf_offsets, M, N, K); + } +#else + TORCH_CHECK_NOT_IMPLEMENTED( + false, + "No compiled cutlass_fp4_group_mm kernel, vLLM must " + "be compiled with ENABLE_NVFP4 for SM100+ and CUDA " + "12.8 or above."); +#endif +} diff --git a/csrc/quantization/fp4/nvfp4_experts_quant.cu b/csrc/quantization/fp4/nvfp4_experts_quant.cu new file mode 100644 index 0000000000000000000000000000000000000000..076c4a085337b434a4cc5e6f6d1caabaff5ceed8 --- /dev/null +++ b/csrc/quantization/fp4/nvfp4_experts_quant.cu @@ -0,0 +1,404 @@ +#include + +#include +#include + +#include +#include + +template +struct TypeConverter { + using Type = half2; +}; // keep for generality + +template <> +struct TypeConverter { + using Type = half; +}; + +template <> +struct TypeConverter { + using Type = half2; +}; + +template <> +struct TypeConverter<__nv_bfloat162> { + using Type = __nv_bfloat16; +}; + +template <> +struct TypeConverter<__nv_bfloat16> { + using Type = __nv_bfloat162; +}; + +#define ELTS_PER_THREAD 8 + +constexpr int CVT_FP4_ELTS_PER_THREAD = 8; +constexpr int CVT_FP4_SF_VEC_SIZE = 16; + +// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t). +inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + uint32_t val; + asm volatile( + "{\n" + ".reg .b8 byte0;\n" + ".reg .b8 byte1;\n" + ".reg .b8 byte2;\n" + ".reg .b8 byte3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" + "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" + "}" + : "=r"(val) + : "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]), + "f"(array[4]), "f"(array[5]), "f"(array[6]), "f"(array[7])); + return val; +#else + return 0; +#endif +} + +// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t). +inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + uint32_t val; + asm volatile( + "{\n" + ".reg .b8 byte0;\n" + ".reg .b8 byte1;\n" + ".reg .b8 byte2;\n" + ".reg .b8 byte3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n" + "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" + "}" + : "=r"(val) + : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y), + "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y)); + return val; +#else + return 0; +#endif +} + +// Fast reciprocal. +inline __device__ float reciprocal_approximate_ftz(float a) { + float b; + asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a)); + return b; +} + +template +__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx, + int numCols, + SFType* SFout) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 || + CVT_FP4_NUM_THREADS_PER_SF == 2); + + // One pair of threads write one SF to global memory. + // TODO: stage through smem for packed STG.32 + // is it better than STG.8 from 4 threads ? + if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) { + // SF vector index (16 elements share one SF in the K dimension). + int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF; + int32_t mIdx = rowIdx; + + // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)] + // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx] + + int32_t mTileIdx = mIdx / (32 * 4); + // SF vector size 16. + int factor = CVT_FP4_SF_VEC_SIZE * 4; + int32_t numKTiles = (numCols + factor - 1) / factor; + int64_t mTileStride = numKTiles * 32 * 4 * 4; + + int32_t kTileIdx = (kIdx / 4); + int64_t kTileStride = 32 * 4 * 4; + + // M tile layout [32, 4] is column-major. + int32_t outerMIdx = (mIdx % 32); + int64_t outerMStride = 4 * 4; + + int32_t innerMIdx = (mIdx % (32 * 4)) / 32; + int64_t innerMStride = 4; + + int32_t innerKIdx = (kIdx % 4); + int64_t innerKStride = 1; + + // Compute the global offset. + int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride + + outerMIdx * outerMStride + innerMIdx * innerMStride + + innerKIdx * innerKStride; + + return reinterpret_cast(SFout) + SFOffset; + } +#endif + return nullptr; +} + +// Define a 16 bytes packed data type. +template +struct PackedVec { + typename TypeConverter::Type elts[4]; +}; + +template <> +struct PackedVec<__nv_fp8_e4m3> { + __nv_fp8x2_e4m3 elts[8]; +}; + +// Quantizes the provided PackedVec into the uint32_t output +template +__device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec& vec, float SFScaleVal, + uint8_t* SFout) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + // Get absolute maximum values among the local 8 values. + auto localMax = __habs2(vec.elts[0]); + + // Local maximum value. + #pragma unroll + for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { + localMax = __hmax2(localMax, __habs2(vec.elts[i])); + } + + // Get the absolute maximum among all 16 values (two threads). + localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax); + // Get the final absolute maximum values. + float vecMax = float(__hmax(localMax.x, localMax.y)); + + // Get the SF (max value of the vector / max value of e2m1). + // maximum value of e2m1 = 6.0. + // TODO: use half as compute data type. + float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f)); + // 8 bits representation of the SF. + uint8_t fp8SFVal; + // Write the SF to global memory (STG.8). + if constexpr (UE8M0_SF) { + // Extract the 8 exponent bits from float32. + // float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits. + uint32_t tmp = reinterpret_cast(SFValue) >> 23; + fp8SFVal = tmp & 0xff; + // Convert back to fp32. + reinterpret_cast(SFValue) = tmp << 23; + } else { + // Here SFValue is always positive, so E4M3 is the same as UE4M3. + __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue); + reinterpret_cast<__nv_fp8_e4m3&>(fp8SFVal) = tmp; + // Convert back to fp32. + SFValue = float(tmp); + } + // Get the output scale. + // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) * + // reciprocal(SFScaleVal)) + float outputScale = + SFValue != 0 ? reciprocal_approximate_ftz( + SFValue * reciprocal_approximate_ftz(SFScaleVal)) + : 0.0f; + + if (SFout) { + // Write the SF to global memory (STG.8). + *SFout = fp8SFVal; + } + + // Convert the input to float. + float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2]; + + #pragma unroll + for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { + if constexpr (std::is_same_v) { + fp2Vals[i] = __half22float2(vec.elts[i]); + } else { + fp2Vals[i] = __bfloat1622float2(vec.elts[i]); + } + fp2Vals[i].x *= outputScale; + fp2Vals[i].y *= outputScale; + } + + // Convert to e2m1 values. + uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals); + + // Write the e2m1 values to global memory. + return e2m1Vec; +#else + return 0; +#endif +} + +// Use UE4M3 by default. +template +__global__ void +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +__launch_bounds__(512, 4) cvt_fp16_to_fp4( +#else +cvt_fp16_to_fp4( +#endif + int32_t numRows, int32_t numCols, Type const* in, float const* SFScale, + uint32_t* out, uint32_t* SFout, uint32_t* input_offset_by_experts, + uint32_t* output_scale_offset_by_experts, int n_experts) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + using PackedVec = PackedVec; + static constexpr int CVT_FP4_NUM_THREADS_PER_SF = + (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); + static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, + "Vec size is not matched."); + + // Input tensor row/col loops. + for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) { + for (int colIdx = threadIdx.x; colIdx < numCols / CVT_FP4_ELTS_PER_THREAD; + colIdx += blockDim.x) { + int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx; + PackedVec in_vec = reinterpret_cast(in)[inOffset]; + // Get the output tensor offset. + // Same as inOffset because 8 elements are packed into one uint32_t. + int64_t outOffset = inOffset; + auto& out_pos = out[outOffset]; + + // Find index within the experts. + int rowIdx_in_expert = 0; + int expert_idx = 0; + for (int i = 0; i < n_experts; i++) { + if (rowIdx >= input_offset_by_experts[i] && + rowIdx < input_offset_by_experts[i + 1]) { + rowIdx_in_expert = rowIdx - input_offset_by_experts[i]; + expert_idx = i; + break; + } + } + + // Get the global scaling factor, which will be applied to the SF. + // Note SFScale is the same as next GEMM's alpha, which is + // (448.f / (Alpha_A / 6.f)). + float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[expert_idx]; + + int factor = CVT_FP4_SF_VEC_SIZE * 4; + // The actual output_scales dim is computed from the padded numCols. + int32_t numCols_padded = (numCols + factor - 1) / factor * factor; + int numCols_SFout = numCols_padded / CVT_FP4_SF_VEC_SIZE / 4; + uint32_t* SFout_in_expert = + SFout + output_scale_offset_by_experts[expert_idx] * numCols_SFout; + + auto sf_out = + cvt_quant_to_fp4_get_sf_out_offset( + rowIdx_in_expert, colIdx, numCols, SFout_in_expert); + + out_pos = + cvt_warp_fp16_to_fp4(in_vec, SFScaleVal, sf_out); + } + } +#endif +} + +template +void quant_impl(void* output, void* output_scale, void* input, + void* input_global_scale, void* input_offset_by_experts, + void* output_scale_offset_by_experts, int m_topk, int k, + int n_experts, cudaStream_t stream) { + // TODO: this multiProcessorCount should be cached. + int device; + cudaGetDevice(&device); + int multiProcessorCount; + cudaDeviceGetAttribute(&multiProcessorCount, cudaDevAttrMultiProcessorCount, + device); + + // Grid, Block size. + // Each thread converts 8 values. + dim3 block(std::min(int(k / ELTS_PER_THREAD), 512)); + // Get number of blocks per SM (assume we can fully utilize the SM). + int const numBlocksPerSM = 2048 / block.x; + dim3 grid(std::min(int(m_topk), multiProcessorCount * numBlocksPerSM)); + + cvt_fp16_to_fp4<<>>( + m_topk, k, reinterpret_cast(input), + reinterpret_cast(input_global_scale), + reinterpret_cast(output), + reinterpret_cast(output_scale), + reinterpret_cast(input_offset_by_experts), + reinterpret_cast(output_scale_offset_by_experts), n_experts); +} + +/*Quantization entry for fp4 experts quantization*/ +#define CHECK_TH_CUDA(x, m) TORCH_CHECK(x.is_cuda(), m, "must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x, m) \ + TORCH_CHECK(x.is_contiguous(), m, "must be contiguous") +#define CHECK_INPUT(x, m) \ + CHECK_TH_CUDA(x, m); \ + CHECK_CONTIGUOUS(x, m); + +constexpr auto HALF = at::ScalarType::Half; +constexpr auto BF16 = at::ScalarType::BFloat16; +constexpr auto FLOAT = at::ScalarType::Float; +constexpr auto INT = at::ScalarType::Int; +constexpr auto UINT8 = at::ScalarType::Byte; + +void scaled_fp4_experts_quant_sm100a( + torch::Tensor& output, torch::Tensor& output_scale, + torch::Tensor const& input, torch::Tensor const& input_global_scale, + torch::Tensor const& input_offset_by_experts, + torch::Tensor const& output_scale_offset_by_experts) { + CHECK_INPUT(output, "output must be a CUDA tensor"); + CHECK_INPUT(output_scale, "output_scale must be a CUDA tensor"); + CHECK_INPUT(input, "input must be a CUDA tensor"); + CHECK_INPUT(input_global_scale, "input_global_scale must be a CUDA tensor"); + CHECK_INPUT(input_offset_by_experts, + "input_offset_by_experts must be a CUDA tensor"); + CHECK_INPUT(output_scale_offset_by_experts, + "output_scale_offset_by_experts must be a CUDA tensor"); + + TORCH_CHECK(output.dim() == 2); + TORCH_CHECK(output_scale.dim() == 2); + TORCH_CHECK(input.dim() == 2); + TORCH_CHECK(input_global_scale.dim() == 1); + TORCH_CHECK(input_offset_by_experts.dim() == 1); + TORCH_CHECK(output_scale_offset_by_experts.dim() == 1); + + TORCH_CHECK(input.scalar_type() == HALF || input.scalar_type() == BF16); + TORCH_CHECK(input_global_scale.scalar_type() == FLOAT); + TORCH_CHECK(input_offset_by_experts.scalar_type() == INT); + TORCH_CHECK(output_scale_offset_by_experts.scalar_type() == INT); + // output is uint8 (two nvfp4 values are packed into one uint8) + // output_scale is int32 (four fp8 values are packed into one int32) + TORCH_CHECK(output.scalar_type() == UINT8); + TORCH_CHECK(output_scale.scalar_type() == INT); + + const int BLOCK_SIZE = 16; + auto m_topk = input.size(0); + auto k = input.size(1); + TORCH_CHECK(k % BLOCK_SIZE == 0, "k must be a multiple of 16"); + auto n_experts = input_global_scale.size(0); + TORCH_CHECK(input_offset_by_experts.size(0) == n_experts + 1); + TORCH_CHECK(output_scale_offset_by_experts.size(0) == n_experts + 1); + TORCH_CHECK(output.size(0) == m_topk); + TORCH_CHECK(output.size(1) == k / 2); + int scales_k = k / BLOCK_SIZE; + // 4 means the swizzle requirement by nvidia nvfp4. + int padded_k = (scales_k + (4 - 1)) / 4 * 4; + // 4 means 4 fp8 values are packed into one int32 + TORCH_CHECK(output_scale.size(1) * 4 == padded_k); + + auto in_dtype = input.dtype(); + at::cuda::CUDAGuard device_guard{(char)input.get_device()}; + const cudaStream_t stream = + at::cuda::getCurrentCUDAStream(input.get_device()); + if (in_dtype == at::ScalarType::Half) { + quant_impl(output.data_ptr(), output_scale.data_ptr(), + input.data_ptr(), input_global_scale.data_ptr(), + input_offset_by_experts.data_ptr(), + output_scale_offset_by_experts.data_ptr(), m_topk, k, + n_experts, stream); + } else if (in_dtype == at::ScalarType::BFloat16) { + quant_impl<__nv_bfloat16>(output.data_ptr(), output_scale.data_ptr(), + input.data_ptr(), input_global_scale.data_ptr(), + input_offset_by_experts.data_ptr(), + output_scale_offset_by_experts.data_ptr(), m_topk, + k, n_experts, stream); + } else { + TORCH_CHECK(false, "Expected input data type to be half or bfloat16"); + } +} \ No newline at end of file diff --git a/csrc/quantization/fp4/nvfp4_quant_entry.cu b/csrc/quantization/fp4/nvfp4_quant_entry.cu index b1426c43b456b5d8ee57b68934a273bacb351cb0..badbb7e310df08260af594ba0fbc7512cd3bb91d 100644 --- a/csrc/quantization/fp4/nvfp4_quant_entry.cu +++ b/csrc/quantization/fp4/nvfp4_quant_entry.cu @@ -23,10 +23,32 @@ void scaled_fp4_quant_sm100a(torch::Tensor const& output, torch::Tensor const& input_sf); #endif +#if defined ENABLE_NVFP4 && ENABLE_NVFP4 +void scaled_fp4_experts_quant_sm100a( + torch::Tensor& output, torch::Tensor& output_scale, + torch::Tensor const& input, torch::Tensor const& input_global_scale, + torch::Tensor const& input_offset_by_experts, + torch::Tensor const& output_scale_offset_by_experts); +#endif + void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input, torch::Tensor& output_sf, torch::Tensor const& input_sf) { #if defined ENABLE_NVFP4 && ENABLE_NVFP4 return scaled_fp4_quant_sm100a(output, input, output_sf, input_sf); #endif - TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 quantization"); + TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 quantization kernel"); +} + +void scaled_fp4_experts_quant( + torch::Tensor& output, torch::Tensor& output_scale, + torch::Tensor const& input, torch::Tensor const& input_global_scale, + torch::Tensor const& input_offset_by_experts, + torch::Tensor const& output_scale_offset_by_experts) { +#if defined ENABLE_NVFP4 && ENABLE_NVFP4 + return scaled_fp4_experts_quant_sm100a( + output, output_scale, input, input_global_scale, input_offset_by_experts, + output_scale_offset_by_experts); +#endif + TORCH_CHECK_NOT_IMPLEMENTED(false, + "No compiled nvfp4 experts quantization kernel"); } diff --git a/csrc/quantization/fp8/fp8_marlin.cu b/csrc/quantization/fp8/fp8_marlin.cu deleted file mode 100644 index 376bbd498ca52c8d025d81a8c2a973a843a5b659..0000000000000000000000000000000000000000 --- a/csrc/quantization/fp8/fp8_marlin.cu +++ /dev/null @@ -1,1311 +0,0 @@ -/* - * Modified by Neural Magic - * Copyright (C) Marlin.2024 Elias Frantar - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * Adapted from https://github.com/IST-DASLab/marlin - */ - -#include "../gptq_marlin/marlin.cuh" -#include "../gptq_marlin/marlin_dtypes.cuh" - -#include "core/registration.h" - -using namespace marlin; - -#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ - static_assert(std::is_same::value || \ - std::is_same::value, \ - "only float16 and bfloat16 is supported"); - -template -inline std::string str(T x) { - return std::to_string(x); -} - -namespace fp8_marlin { - -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void Marlin( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - int num_groups, // number of scale groups per output channel - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks // extra global storage for barrier synchronization -) {} - -} // namespace fp8_marlin - -torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_scales, torch::Tensor& workspace, - int64_t num_bits, int64_t size_m, int64_t size_n, - int64_t size_k) { - TORCH_CHECK_NOT_IMPLEMENTED(false, - "marlin_gemm(..) requires CUDA_ARCH >= 8.0"); - return torch::empty({1, 1}); -} - -#else - -// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 -// output/accumulation. -template -__device__ inline void mma(const typename ScalarType::FragA& a_frag, - const typename ScalarType::FragB& frag_b, - typename ScalarType::FragC& frag_c) { - const uint32_t* a = reinterpret_cast(&a_frag); - const uint32_t* b = reinterpret_cast(&frag_b); - float* c = reinterpret_cast(&frag_c); - if constexpr (std::is_same::value) { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); - } else if constexpr (std::is_same::value) { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); - } else { - STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); - } -} - -// Instruction for loading a full 16x16 matrix fragment of operand A from shared -// memory, directly in tensor core layout. -template -__device__ inline void ldsm4(typename ScalarType::FragA& frag_a, - const void* smem_ptr) { - uint32_t* a = reinterpret_cast(&frag_a); - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" - : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) - : "r"(smem)); -} - -// Fast FP8ToFp16/FP8ToBf16: Efficiently dequantize 8bit fp8_e4m3 values to fp16 -// bf16 Reference: -// - FP16: -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 -// - BF16: -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175 -template -__device__ inline typename ScalarType::FragB dequant_8bit(int q) { - STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); -} - -template <> -__device__ inline typename ScalarType::FragB dequant_8bit(int q) { - // Constants for FP8 (E4M3) and FP16 formats - constexpr int FP8_EXPONENT = 4, FP8_MANTISSA = 3, FP16_EXPONENT = 5; - constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT; - - // Calculate MASK for extracting mantissa and exponent - constexpr int MASK1 = 0x80000000; - constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA); - constexpr int MASK3 = MASK2 & 0x7fffffff; - constexpr int MASK = MASK3 | (MASK3 >> 16); - // Final MASK value: 0x7F007F00 - - // Extract and shift FP8 values to FP16 format - int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); - int Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT); - - // Construct and apply exponent bias - constexpr int BIAS_OFFSET = - (1 << (FP16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1)); - const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET)); - - // Convert to half2 and apply bias - typename ScalarType::FragB frag_b; - // Note: reverse indexing is intentional because weights are permuted - frag_b[1] = __hmul2(*reinterpret_cast(&Out1), bias_reg); - frag_b[0] = __hmul2(*reinterpret_cast(&Out2), bias_reg); - return frag_b; -} - -template <> -__device__ inline typename ScalarType::FragB -dequant_8bit(int q) { - // Constants for FP8 (E4M3) and BF16 formats - constexpr int FP8_EXPONENT = 4, FP8_MANTISSA = 3, BF16_EXPONENT = 8; - constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT; - - // Calculate MASK for extracting mantissa and exponent - constexpr int MASK1 = 0x80000000; - constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA); - constexpr int MASK3 = MASK2 & 0x7fffffff; - constexpr int MASK = MASK3 | (MASK3 >> 16); - // Final MASK value: 0x7F007F00 - - // Extract and shift FP8 values to BF16 format - int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); - int Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT); - - // Construct and apply exponent bias - constexpr int BIAS_OFFSET = - (1 << (BF16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1)); - // Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent - // position - constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23; - const nv_bfloat162 bias_reg = - __float2bfloat162_rn(*reinterpret_cast(&BIAS)); - - // Convert to bfloat162 and apply bias - typename ScalarType::FragB frag_b; - // Note: reverse indexing is intentional because weights are permuted - frag_b[1] = __hmul2(*reinterpret_cast(&Out1), bias_reg); - frag_b[0] = __hmul2(*reinterpret_cast(&Out2), bias_reg); - return frag_b; -} - -// Multiply dequantized values by the corresponding quantization scale; used -// only for grouped quantization. -template -__device__ inline void scale(typename ScalarType::FragB& frag_b, - typename ScalarType::FragS& frag_s, - int i) { - using scalar_t2 = typename ScalarType::scalar_t2; - scalar_t2 s = - ScalarType::num2num2(reinterpret_cast(&frag_s)[i]); - frag_b[0] = __hmul2(frag_b[0], s); - frag_b[1] = __hmul2(frag_b[1], s); -} - -// Given 2 floats multiply by 2 scales (halves) -template -__device__ inline void scale_float(float* c, - typename ScalarType::FragS& s) { - scalar_t* s_ptr = reinterpret_cast(&s); - c[0] = __fmul_rn(c[0], ScalarType::num2float(s_ptr[0])); - c[1] = __fmul_rn(c[1], ScalarType::num2float(s_ptr[1])); -} - -// Wait until barrier reaches `count`, then lock for current threadblock. -__device__ inline void barrier_acquire(int* lock, int count) { - if (threadIdx.x == 0) { - int state = -1; - do - // Guarantee that subsequent writes by this threadblock will be visible - // globally. - asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" - : "=r"(state) - : "l"(lock)); - while (state != count); - } - __syncthreads(); -} - -// Release barrier and increment visitation count. -__device__ inline void barrier_release(int* lock, bool reset = false) { - __syncthreads(); - if (threadIdx.x == 0) { - if (reset) { - lock[0] = 0; - return; - } - int val = 1; - // Make sure that all writes since acquiring this barrier are visible - // globally, while releasing the barrier. - asm volatile("fence.acq_rel.gpu;\n"); - asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" - : - : "l"(lock), "r"(val)); - } -} - -template shared - // fetch pipeline - const int group_blocks = -1 // number of consecutive 16x16 blocks - // with a separate quantization scale - > -__global__ void Marlin( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - int num_groups, // number of scale groups per output channel - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks // extra global storage for barrier synchronization -) { - // Each threadblock processes one "stripe" of the B matrix with (roughly) the - // same size, which might involve multiple column "slices" (of width 16 * - // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM - // example: - // 0 1 3 - // 0 2 3 - // 1 2 4 - // While this kind of partitioning makes things somewhat more complicated, it - // ensures good utilization of all SMs for many kinds of shape and GPU - // configurations, while requiring as few slow global cross-threadblock - // reductions as possible. - using Dtype = ScalarType; - using scalar_t2 = typename ScalarType::scalar_t2; - using FragA = typename ScalarType::FragA; - using FragB = typename ScalarType::FragB; - using FragC = typename ScalarType::FragC; - using FragS = typename ScalarType::FragS; - - constexpr int pack_factor = 32 / num_bits; - - // For larger GEMMs we run multiple batchsize 64 versions in parallel for a - // better partitioning with less reductions - int parallel = 1; - if (prob_m > 16 * thread_m_blocks) { - parallel = prob_m / (16 * thread_m_blocks); - prob_m = 16 * thread_m_blocks; - } - - int k_tiles = prob_k / 16 / thread_k_blocks; - int n_tiles = prob_n / 16 / thread_n_blocks; - int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x); - - int slice_row = (iters * blockIdx.x) % k_tiles; - int slice_col_par = (iters * blockIdx.x) / k_tiles; - int slice_col = slice_col_par; - int slice_iters; // number of threadblock tiles in the current slice - int slice_count = - 0; // total number of active threadblocks in the current slice - int slice_idx; // index of threadblock in current slice; numbered bottom to - // top - - // We can easily implement parallel problem execution by just remapping - // indices and advancing global pointers - if (slice_col_par >= n_tiles) { - A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8; - C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; - locks += (slice_col_par / n_tiles) * n_tiles; - slice_col = slice_col_par % n_tiles; - } - - // Compute all information about the current slice which is required for - // synchronization. - auto init_slice = [&]() { - slice_iters = - iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); - if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; - if (slice_iters == 0) return; - if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; - slice_count = 1; - slice_idx = 0; - int col_first = iters * div_ceil(k_tiles * slice_col_par, iters); - if (col_first <= k_tiles * (slice_col_par + 1)) { - int col_off = col_first - k_tiles * slice_col_par; - slice_count = div_ceil(k_tiles - col_off, iters); - if (col_off > 0) slice_count++; - int delta_first = iters * blockIdx.x - col_first; - if (delta_first < 0 || (col_off == 0 && delta_first == 0)) - slice_idx = slice_count - 1; - else { - slice_idx = slice_count - 1 - delta_first / iters; - if (col_off > 0) slice_idx--; - } - } - if (slice_col == n_tiles) { - A += 16 * thread_m_blocks * prob_k / 8; - C += 16 * thread_m_blocks * prob_n / 8; - locks += n_tiles; - slice_col = 0; - } - }; - init_slice(); - - // A sizes/strides - - // stride of the A matrix in global memory - int a_gl_stride = prob_k / 8; - // stride of an A matrix tile in shared memory - constexpr int a_sh_stride = 16 * thread_k_blocks / 8; - // delta between subsequent A tiles in global memory - constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; - // between subsequent accesses within a tile - int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); - // between shared memory writes - constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); - // between shared memory tile reads - constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); - // within a shared memory tile - constexpr int a_sh_rd_delta_i = a_sh_stride * 16; - // overall size of a tile - constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); - // number of shared write iterations for a tile - constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta); - - // B sizes/strides - int b_gl_stride = 16 * prob_n / (pack_factor * 4); - constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; - constexpr int b_thread_vecs = num_bits == 4 ? 1 : 2; - constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; - - int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; - int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); - constexpr int b_sh_wr_delta = threads * b_thread_vecs; - constexpr int b_sh_rd_delta = threads * b_thread_vecs; - constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; - constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; - - // Scale sizes/strides without act_order - int s_gl_stride = prob_n / 8; - constexpr int s_sh_stride = 16 * thread_n_blocks / 8; - - // Scale size/strides with act_order - constexpr int tb_k = 16 * thread_k_blocks; - constexpr int g_idx_stage = 0; - // constexpr int act_s_row_stride = 1; - // int act_s_col_stride = act_s_row_stride * num_groups; - int act_s_col_stride = 1; - int act_s_col_warp_stride = act_s_col_stride * 8; - int tb_n_warps = thread_n_blocks / 4; - int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; - - // Global A read index of current thread. - int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - a_gl_rd += a_gl_rd_delta_o * slice_row; - // Shared write index of current thread. - int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - // Shared read index. - int a_sh_rd = - a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; - a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); - - int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + - (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; - b_gl_rd += b_sh_stride * slice_col; - b_gl_rd += b_gl_rd_delta_o * slice_row; - int b_sh_wr = threadIdx.x * b_thread_vecs; - int b_sh_rd = threadIdx.x * b_thread_vecs; - - // For act_order - int slice_k_start = tb_k * slice_row; - int slice_k_start_shared_fetch = slice_k_start; - int slice_n_offset = act_s_col_tb_stride * slice_col; - - // No act_order - int s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - int s_sh_wr = threadIdx.x; - bool s_sh_wr_pred = threadIdx.x < s_sh_stride; - - // We scale a `half2` tile in row-major layout for column-wise quantization. - int s_sh_rd = - 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4; - - // Precompute which thread should not read memory in which iterations; this is - // needed if there are more threads than required for a certain tilesize or - // when the batchsize is not a multiple of 16. - bool a_sh_wr_pred[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; - - // To ensure that writing and reading A tiles to/from shared memory, the - // latter in fragment format, is fully bank conflict free, we need to use a - // rather fancy XOR-based layout. The key here is that neither reads nor - // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the - // same shared memory banks. Further, it seems (based on NSight-Compute) that - // each warp must also write a consecutive memory segment? - auto transform_a = [&](int i) { - int row = i / a_gl_rd_delta_o; - return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; - }; - // Since the computation of this remapping is non-trivial and, due to our main - // loop unrolls, all shared memory accesses are static, we simply precompute - // both transformed reads and writes. - int a_sh_wr_trans[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); - int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < thread_m_blocks; j++) - a_sh_rd_trans[i][j] = - transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); - } - - // Since B-accesses have non-constant stride they have to be computed at - // runtime; we break dependencies between subsequent accesses with a tile by - // maintining multiple pointers (we have enough registers), a tiny - // optimization. - const int4* B_ptr[b_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; - - extern __shared__ int4 sh[]; - // Shared memory storage for global fetch pipelines. - int4* sh_a = sh; - int4* sh_b = sh_a + (stages * a_sh_stage); - int4* sh_g_idx = sh_b + (stages * b_sh_stage); - int4* sh_s = sh_g_idx + (stages * g_idx_stage); - - // Register storage for double buffer of shared memory reads. - FragA frag_a[2][thread_m_blocks]; - I4 frag_b_quant[2][b_thread_vecs]; - FragC frag_c[thread_m_blocks][4][2]; - FragS frag_s[2][4]; - - // Zero accumulators. - auto zero_accums = [&]() { - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) - reinterpret_cast(frag_c)[i] = 0; - }; - - int sh_first_group_id = -1; - int sh_num_groups = -1; - constexpr int sh_max_num_groups = 32; - - auto fetch_scales_to_shared = [&](bool is_async, int first_group_id, - int last_group_id) { - sh_first_group_id = first_group_id; - sh_num_groups = last_group_id - first_group_id + 1; - - if (sh_num_groups < sh_max_num_groups) { - sh_num_groups = sh_max_num_groups; - } - - if (sh_first_group_id + sh_num_groups > num_groups) { - sh_num_groups = num_groups - sh_first_group_id; - } - - int row_offset = first_group_id * s_gl_stride; - - if (is_async) { - for (int i = 0; i < sh_num_groups; i++) { - if (threadIdx.x < s_sh_stride) { - cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], - &scales_ptr[row_offset + (i * s_gl_stride) + - slice_n_offset + threadIdx.x]); - } - } - } else { - for (int i = 0; i < sh_num_groups; i++) { - if (threadIdx.x < s_sh_stride) { - sh_s[(i * s_sh_stride) + threadIdx.x] = - scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + - threadIdx.x]; - } - } - } - }; - // Asynchronously fetch the next A, B and s tile from global to the next - // shared memory pipeline location. - auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { - if (pred) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) { - cp_async4_pred( - &sh_a_stage[a_sh_wr_trans[i]], - &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], - a_sh_wr_pred[i]); - } - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < b_thread_vecs; j++) { - cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); - } - - B_ptr[i] += b_gl_rd_delta_o; - } - } - // Insert a fence even when we are winding down the pipeline to ensure that - // waiting is also correct at this point. - cp_async_fence(); - }; - - // Wait until the next thread tile has been loaded to shared memory. - auto wait_for_stage = [&]() { - // We only have `stages - 2` active fetches since we are double buffering - // and can only issue the next fetch when it is guaranteed that the previous - // shared memory load is fully complete (as it may otherwise be - // overwritten). - cp_async_wait(); - __syncthreads(); - }; - - // Load the next sub-tile from the current location in the shared memory pipe - // into the current register buffer. - auto fetch_to_registers = [&](int k, int pipe) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) - ldsm4(frag_a[k % 2][i], - &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - - #pragma unroll - for (int i = 0; i < b_thread_vecs; i++) { - frag_b_quant[k % 2][i] = *reinterpret_cast( - &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); - } - }; - - bool is_same_group[stages]; - int same_group_id[stages]; - - auto init_same_group = [&](int pipe) { - is_same_group[pipe] = false; - same_group_id[pipe] = 0; - return; - }; - - // Execute the actual tensor core matmul of a sub-tile. - auto matmul = [&](int k) { - // We have the m dimension as the inner loop in order to encourage overlapping - // dequantization and matmul operations. - #pragma unroll - for (int j = 0; j < 4; j++) { - FragB frag_b0; - FragB frag_b1; - - int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); - int b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; - int b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; - - frag_b0 = dequant_8bit(b_quant_0); - frag_b1 = dequant_8bit(b_quant_1); - - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); - mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); - } - } - }; - - // Since we slice across the k dimension of a tile in order to increase the - // number of warps while keeping the n dimension of a tile reasonable, we have - // multiple warps that accumulate their partial sums of the same output - // location; which we have to reduce over in the end. We do in shared memory. - auto thread_block_reduce = [&]() { - constexpr int red_off = threads / b_sh_stride_threads / 2; - if (red_off >= 1) { - int red_idx = threadIdx.x / b_sh_stride_threads; - constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; - constexpr int red_sh_delta = b_sh_stride_threads; - int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + - (threadIdx.x % b_sh_stride_threads); - - // Parallel logarithmic shared memory reduction. We make sure to avoid any - // unnecessary read or write iterations, e.g., for two warps we write only - // once by warp 1 and read only once by warp 0. - - #pragma unroll - for (int m_block = 0; m_block < thread_m_blocks; m_block++) { - #pragma unroll - for (int i = red_off; i > 0; i /= 2) { - if (i <= red_idx && red_idx < 2 * i) { - #pragma unroll - for (int j = 0; j < 4 * 2; j++) { - int red_sh_wr = - red_sh_delta * j + (red_sh_rd - red_sh_stride * i); - if (i < red_off) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); - float* c_wr = reinterpret_cast(&sh[red_sh_wr]); - #pragma unroll - for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += - c_rd[k] + c_wr[k]; - } - sh[red_sh_wr] = - reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; - } - } - __syncthreads(); - } - if (red_idx == 0) { - #pragma unroll - for (int i = 0; i < 4 * 2; i++) { - float* c_rd = - reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); - #pragma unroll - for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += - c_rd[j]; - } - } - __syncthreads(); - } - } - }; - - // Since multiple threadblocks may process parts of the same column slice, we - // finally have to globally reduce over the results. As the striped - // partitioning minimizes the number of such reductions and our outputs are - // usually rather small, we perform this reduction serially in L2 cache. - auto global_reduce = [&](bool first = false, bool last = false) { - // We are very careful here to reduce directly in the output buffer to - // maximize L2 cache utilization in this step. To do this, we write out - // results in FP16 (but still reduce with FP32 compute). - constexpr int active_threads = 32 * thread_n_blocks / 4; - if (threadIdx.x < active_threads) { - int c_gl_stride = prob_n / 8; - int c_gl_wr_delta_o = 8 * c_gl_stride; - int c_gl_wr_delta_i = 4 * (active_threads / 32); - int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + - 4 * (threadIdx.x / 32) + threadIdx.x % 4; - c_gl_wr += (2 * thread_n_blocks) * slice_col; - constexpr int c_sh_wr_delta = active_threads; - int c_sh_wr = threadIdx.x; - - int row = (threadIdx.x % 32) / 4; - - if (!first) { - // Interestingly, doing direct global accesses here really seems to mess up - // the compiler and lead to slowdowns, hence we also use async-copies even - // though these fetches are not actually asynchronous. - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - cp_async4_pred( - &sh[c_sh_wr + c_sh_wr_delta * i], - &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + - c_gl_wr_delta_i * (i % 2)], - i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); - } - cp_async_fence(); - cp_async_wait<0>(); - } - - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { - if (!first) { - int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; - #pragma unroll - for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += - Dtype::num2float(reinterpret_cast(&c_red)[j]); - } - } - if (!last) { - int4 c; - #pragma unroll - for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast(&c)[j] = - Dtype::float2num(reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); - } - C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = - c; - } - } - } - } - }; - - // Write out the reduce final result in the correct layout. We only actually - // reshuffle matrix fragments in this step, the reduction above is performed - // in fragment layout. - auto write_result = [&]() { - int c_gl_stride = prob_n / 8; - constexpr int c_sh_stride = 2 * thread_n_blocks + 1; - int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); - constexpr int c_sh_rd_delta = - c_sh_stride * (threads / (2 * thread_n_blocks)); - - int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - c_gl_wr += (2 * thread_n_blocks) * slice_col; - int c_sh_wr = - (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; - c_sh_wr += 32 * (threadIdx.x / 32); - int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - - int c_gl_wr_end = c_gl_stride * prob_m; - - // We first reorder in shared memory to guarantee the most efficient final - // global write patterns - auto write = [&](int idx, float c0, float c1, FragS& s) { - scalar_t2 res = - Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); - - ((scalar_t2*)sh)[idx] = res; - }; - - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - #pragma unroll - for (int j = 0; j < 4; j++) { - int wr = c_sh_wr + 8 * j; - write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], - frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], - frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], - frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); - write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], - frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); - } - c_sh_wr += 16 * (4 * c_sh_stride); - } - } - __syncthreads(); - - #pragma unroll - for (int i = 0; - i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); - i++) { - if (c_gl_wr < c_gl_wr_end) { - C[c_gl_wr] = sh[c_sh_rd]; - c_gl_wr += c_gl_wr_delta; - c_sh_rd += c_sh_rd_delta; - } - } - }; - - // Start global fetch and register load pipelines. - auto start_pipes = [&]() { - - #pragma unroll - for (int i = 0; i < stages - 1; i++) { - fetch_to_shared(i, i, i < slice_iters); - } - - zero_accums(); - wait_for_stage(); - init_same_group(0); - fetch_to_registers(0, 0); - a_gl_rd += a_gl_rd_delta_o * (stages - 1); - slice_k_start_shared_fetch += tb_k * (stages - 1); - }; - if (slice_iters) { - start_pipes(); - } - - // Main loop. - while (slice_iters) { - // We unroll over both the global fetch and the register load pipeline to - // ensure all shared memory accesses are static. Note that both pipelines - // have even length meaning that the next iteration will always start at - // index 0. - - #pragma unroll - for (int pipe = 0; pipe < stages;) { - #pragma unroll - for (int k = 0; k < b_sh_wr_iters; k++) { - fetch_to_registers(k + 1, pipe % stages); - if (k == b_sh_wr_iters - 2) { - fetch_to_shared((pipe + stages - 1) % stages, pipe, - slice_iters >= stages); - pipe++; - wait_for_stage(); - init_same_group(pipe % stages); - } - matmul(k); - } - slice_iters--; - if (slice_iters == 0) { - break; - } - } - - a_gl_rd += a_gl_rd_delta_o * stages; - slice_k_start += tb_k * stages; - slice_k_start_shared_fetch += tb_k * stages; - - // Process results and, if necessary, proceed to the next column slice. - // While this pattern may not be the most readable, other ways of writing - // the loop seemed to noticeably worse performance after compilation. - if (slice_iters == 0) { - cp_async_wait<0>(); - bool last = slice_idx == slice_count - 1; - // For per-column scales, we only fetch them here in the final step before - // write-out - if (s_sh_wr_pred) { - cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); - } - cp_async_fence(); - - thread_block_reduce(); - - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; - } - - // For 8-bit channelwise, we apply the scale before the global reduction - // that converts the fp32 results to fp16 (so that we avoid possible - // overflow in fp16) - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - #pragma unroll - for (int j = 0; j < 4; j++) { - scale_float(reinterpret_cast(&frag_c[i][j][0][0]), - frag_s[j / 2][2 * (j % 2) + 0]); - scale_float(reinterpret_cast(&frag_c[i][j][0][2]), - frag_s[j / 2][2 * (j % 2) + 0]); - - scale_float(reinterpret_cast(&frag_c[i][j][1][0]), - frag_s[j / 2][2 * (j % 2) + 1]); - scale_float(reinterpret_cast(&frag_c[i][j][1][2]), - frag_s[j / 2][2 * (j % 2) + 1]); - } - } - } - - if (slice_count > 1) { // only globally reduce if there is more than one - // block in a slice - barrier_acquire(&locks[slice_col], slice_idx); - global_reduce(slice_idx == 0, last); - barrier_release(&locks[slice_col], last); - } - if (last) // only the last block in a slice actually writes the result - write_result(); - slice_row = 0; - slice_col_par++; - slice_col++; - init_slice(); - if (slice_iters) { - a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; - if (slice_col == 0) { - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; - } - - // Update slice k/n for scales loading - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - - start_pipes(); - } - } - } -} - - #define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \ - THREAD_K_BLOCKS, GROUP_BLOCKS, NUM_THREADS) \ - else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \ - thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ - cudaFuncSetAttribute( \ - Marlin, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - Marlin \ - <<>>( \ - A_ptr, B_ptr, C_ptr, s_ptr, num_groups, prob_m, prob_n, prob_k, \ - locks); \ - } - -typedef struct { - int thread_k; - int thread_n; - int num_threads; -} thread_config_t; - -typedef struct { - int max_m_blocks; - thread_config_t tb_cfg; -} exec_config_t; - -thread_config_t small_batch_thread_configs[] = { - // Ordered by priority - - // thread_k, thread_n, num_threads - {128, 128, 256}, - {64, 128, 128}, - {128, 64, 128}, -}; - -thread_config_t large_batch_thread_configs[] = { - // Ordered by priority - - // thread_k, thread_n, num_threads - {64, 256, 256}, - {64, 128, 128}, - {128, 64, 128}, - -}; - -int get_scales_cache_size(thread_config_t const& th_config, int prob_m, - int prob_n, int prob_k, int num_bits, - int group_size) { - int tb_n = th_config.thread_n; - - // Get max scale groups per thread-block - // Fixed for channelwise - int tb_groups = 1; - int tb_scales = tb_groups * tb_n * 2; - - return tb_scales * pipe_stages; -} - -bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks, - int prob_m, int prob_n, int prob_k, int num_bits, - int scales_cache_size, int max_shared_mem) { - int pack_factor = 32 / num_bits; - - // Get B size - int tb_k = th_config.thread_k; - int tb_n = th_config.thread_n; - - int b_size = (tb_k * tb_n / pack_factor) * 4; - - // Get A size - int m_blocks = div_ceil(prob_m, 16); - int tb_max_m = 16; - - while (true) { - if (m_blocks >= max_m_blocks) { - tb_max_m *= max_m_blocks; - break; - } - - max_m_blocks--; - if (max_m_blocks == 0) { - TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks); - } - } - - int a_size = (tb_max_m * tb_k) * 2; - - float pipe_size = (a_size + b_size) * pipe_stages; - - TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity - - return pipe_size < 0.95f * (max_shared_mem - scales_cache_size); -} - -bool is_valid_config(thread_config_t const& th_config, int max_m_blocks, - int prob_m, int prob_n, int prob_k, int num_bits, - int group_size, int max_shared_mem) { - // Sanity - if (th_config.thread_k == -1 || th_config.thread_n == -1 || - th_config.num_threads == -1) { - return false; - } - - // Verify K/N are divisible by thread K/N - if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { - return false; - } - - // Verify min for thread K/N - if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { - return false; - } - - // num_threads must be at least 128 (= 4 warps) - if (th_config.num_threads < 128) { - return false; - } - - // Determine cache for scales - int scales_cache_size = get_scales_cache_size(th_config, prob_m, prob_n, - prob_k, num_bits, group_size); - - // Check that pipeline fits into cache - if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k, - num_bits, scales_cache_size, max_shared_mem)) { - return false; - } - - return true; -} - -exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, - int num_bits, int group_size, - int max_shared_mem) { - int max_m_blocks = 4; - while (max_m_blocks > 0) { - if (prob_m <= 16) { - for (auto th_config : small_batch_thread_configs) { - if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, - num_bits, group_size, max_shared_mem)) { - return exec_config_t{max_m_blocks, th_config}; - } - } - } else { - for (auto th_config : large_batch_thread_configs) { - if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, - num_bits, group_size, max_shared_mem)) { - return exec_config_t{max_m_blocks, th_config}; - } - } - } - - max_m_blocks--; // Process less M blocks per invocation to reduce cache - // usage - } - - return exec_config_t{0, {-1, -1, -1}}; -} - - #define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ - __CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) - -template -void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, int prob_m, - int prob_n, int prob_k, void* workspace, int num_bits, - int num_groups, int group_size, int dev, - cudaStream_t stream, int thread_k, int thread_n, int sms, - int max_par) { - TORCH_CHECK(num_bits == 8, "num_bits must be 8. Got = ", num_bits); - TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, - ", ", prob_n, ", ", prob_k, "]"); - - int tot_m = prob_m; - int tot_m_blocks = div_ceil(tot_m, 16); - int pad = 16 * tot_m_blocks - tot_m; - - if (sms == -1) { - cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); - } - - int max_shared_mem = 0; - cudaDeviceGetAttribute(&max_shared_mem, - cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); - TORCH_CHECK(max_shared_mem > 0); - - // Set thread config - exec_config_t exec_cfg; - if (thread_k != -1 && thread_n != -1) { - // User-defined config - exec_cfg = - exec_config_t{4, thread_config_t{thread_k, thread_n, default_threads}}; - } else { - // Auto config - exec_cfg = determine_thread_config(prob_m, prob_n, prob_k, num_bits, - group_size, max_shared_mem); - } - - TORCH_CHECK( - exec_cfg.max_m_blocks > 0 && - is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks, prob_m, - prob_n, prob_k, num_bits, group_size, max_shared_mem), - "Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks, - ", thread_k = ", exec_cfg.tb_cfg.thread_k, - ", thread_n = ", exec_cfg.tb_cfg.thread_n, - ", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [", prob_m, - ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits, - ", group_size = ", group_size, ", max_shared_mem = ", max_shared_mem); - - int num_threads = exec_cfg.tb_cfg.num_threads; - thread_k = exec_cfg.tb_cfg.thread_k; - thread_n = exec_cfg.tb_cfg.thread_n; - - int thread_k_blocks = thread_k / 16; - int thread_n_blocks = thread_n / 16; - - int blocks = sms; - - TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, - " is not divisible by thread_n = ", thread_n); - TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, - " is not divisible by thread_k = ", thread_k); - - int group_blocks = -1; - - const int4* A_ptr = (const int4*)A; - const int4* B_ptr = (const int4*)B; - int4* C_ptr = (int4*)C; - const int4* s_ptr = (const int4*)s; - - int* locks = (int*)workspace; - - // Main loop - for (int i = 0; i < tot_m_blocks; i += exec_cfg.max_m_blocks) { - int thread_m_blocks = tot_m_blocks - i; - prob_m = tot_m - 16 * i; - int par = 1; - if (thread_m_blocks > exec_cfg.max_m_blocks) { - // Note that parallel > 1 currently only works for inputs without any - // padding - par = (16 * thread_m_blocks - pad) / (16 * exec_cfg.max_m_blocks); - if (par > max_par) par = max_par; - prob_m = (16 * exec_cfg.max_m_blocks) * par; - i += exec_cfg.max_m_blocks * (par - 1); - thread_m_blocks = exec_cfg.max_m_blocks; - } - - // Define kernel configurations - if (false) { - } - CALL_IF(8, 32, 2, 256) - CALL_IF(8, 16, 4, 256) - CALL_IF(8, 8, 8, 256) - CALL_IF(8, 8, 4, 128) - CALL_IF(8, 4, 8, 128) - else { - TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + - str(prob_n) + ", " + str(prob_k) + "]" + - ", num_groups = " + str(num_groups) + - ", group_size = " + str(group_size) + - ", thread_m_blocks = " + str(thread_m_blocks) + - ", thread_n_blocks = " + str(thread_n_blocks) + - ", thread_k_blocks = " + str(thread_k_blocks)); - } - - A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par; - C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par; - } -} - -} // namespace fp8_marlin - -torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_scales, torch::Tensor& workspace, - int64_t num_bits, int64_t size_m, int64_t size_n, - int64_t size_k) { - // Verify num_bits - TORCH_CHECK(num_bits == 8, "num_bits must be 8. Got = ", num_bits); - int pack_factor = 32 / num_bits; - - // Verify A - TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0), - ", size_m = ", size_m); - TORCH_CHECK(a.size(1) == size_k, "Shape mismatch: a.size(1) = ", a.size(1), - ", size_k = ", size_k); - - // Verify B - TORCH_CHECK(size_k % marlin::tile_size == 0, "size_k = ", size_k, - " is not divisible by tile_size = ", marlin::tile_size); - TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0), - "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0), - ", size_k = ", size_k, ", tile_size = ", marlin::tile_size); - TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0, - "b_q_weight.size(1) = ", b_q_weight.size(1), - " is not divisible by tile_size = ", marlin::tile_size); - int actual_size_n = (b_q_weight.size(1) / marlin::tile_size) * pack_factor; - TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n, - ", actual_size_n = ", actual_size_n); - - // Verify device and strides - TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); - TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); - - TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); - TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); - - TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); - TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); - - // Alloc buffers - const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); - auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); - torch::Tensor c = torch::empty({size_m, size_n}, options); - - // thread_k: `k` size of a thread_tile in `weights` (can usually be left as - // auto -1) - int thread_k = -1; - // thread_n: `n` size of a thread_tile in `weights` (can usually be left as - // auto -1) - int thread_n = -1; - // sms: number of SMs to use for the kernel (can usually be left as auto -1) - int sms = -1; - - // Detect groupsize and act_order - int num_groups = -1; - int group_size = -1; - - int b_rank = b_scales.sizes().size(); - TORCH_CHECK(b_rank == 2, "b_scales rank = ", b_rank, " is not 2"); - TORCH_CHECK(b_scales.size(1) == size_n, "b_scales dim 1 = ", b_scales.size(1), - " is not size_n = ", size_n); - // Channelwise only for FP8 - TORCH_CHECK(b_scales.size(0) == 1) - num_groups = b_scales.size(0); - - // Verify workspace size - TORCH_CHECK(size_n % marlin::min_thread_n == 0, "size_n = ", size_n, - ", is not divisible by min_thread_n = ", marlin::min_thread_n); - int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par; - TORCH_CHECK(workspace.numel() >= min_workspace_size, - "workspace.numel = ", workspace.numel(), - " is below min_workspace_size = ", min_workspace_size); - - int dev = a.get_device(); - if (a.scalar_type() == at::ScalarType::Half) { - fp8_marlin::marlin_mm_f16i4( - a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), - b_scales.data_ptr(), size_m, size_n, size_k, - workspace.data_ptr(), num_bits, num_groups, group_size, dev, - at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, - marlin::max_par); - } else if (a.scalar_type() == at::ScalarType::BFloat16) { - fp8_marlin::marlin_mm_f16i4( - a.data_ptr(), b_q_weight.data_ptr(), - c.data_ptr(), b_scales.data_ptr(), size_m, - size_n, size_k, workspace.data_ptr(), num_bits, num_groups, group_size, - dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, - marlin::max_par); - } else { - TORCH_CHECK(false, "fp8_marlin_gemm only supports bfloat16 and float16"); - } - - return c; -} - -#endif - -TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { - m.impl("fp8_marlin_gemm", &fp8_marlin_gemm); -} \ No newline at end of file diff --git a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu index 2b6ab7fcec9026bcc5e3a1d34eb76bd68b47e2c9..95aa92e25b30c22f2465a906b4e9d1b923d127e3 100644 --- a/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu +++ b/csrc/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu @@ -96,7 +96,7 @@ void rms_norm_dynamic_per_token_quant_dispatch( std::optional const& scale_ub, std::optional& residual) { int32_t hidden_size = input.size(-1); - int32_t num_tokens = input.numel() / hidden_size; + auto num_tokens = input.numel() / hidden_size; dim3 grid(num_tokens); dim3 block(std::min(hidden_size, 1024)); diff --git a/csrc/quantization/fused_kernels/quant_conversions.cuh b/csrc/quantization/fused_kernels/quant_conversions.cuh index edd3b2d93f344bd950c849025d8fc3ca9039b968..5256d21598a990aec8394f07bccc270757271804 100644 --- a/csrc/quantization/fused_kernels/quant_conversions.cuh +++ b/csrc/quantization/fused_kernels/quant_conversions.cuh @@ -21,7 +21,13 @@ static __device__ __forceinline__ int8_t float_to_int8_rn(float const x) { // round float dst = std::nearbyint(x); // saturate - dst = std::clamp(dst, i8_min, i8_max); + + // See https://github.com/pytorch/pytorch/issues/127666 + // See https://github.com/llvm/llvm-project/issues/95183 + // hip-clang std::clamp __glibcxx_assert_fail host function when building on + // Arch/gcc14. The following replaces std::clamp usage with similar logic + // dst = std::clamp(dst, i8_min, i8_max); + dst = (dst < i8_min) ? i8_min : (dst > i8_max) ? i8_max : dst; return static_cast(dst); #else // CUDA path diff --git a/csrc/quantization/gguf/gguf_kernel.cu b/csrc/quantization/gguf/gguf_kernel.cu index 56b78f1834d15dd5b1ebdc5dd7a8347de379e505..6c146c3fb6fdeea1a83d7ec0f54e9bd7100d9bb6 100644 --- a/csrc/quantization/gguf/gguf_kernel.cu +++ b/csrc/quantization/gguf/gguf_kernel.cu @@ -13,6 +13,7 @@ #include "mmvq.cuh" #include "mmq.cuh" #include "moe.cuh" +#include "moe_vec.cuh" // Q8 gemv template @@ -377,6 +378,142 @@ torch::Tensor ggml_moe_a8(torch::Tensor X, // input return Y; } +torch::Tensor ggml_moe_a8_vec(torch::Tensor X, // input + torch::Tensor W, // expert weights + torch::Tensor topk_ids, int64_t top_k, + int64_t type, int64_t row, int64_t tokens) { + int col = X.sizes()[1]; + const int padded = (col + 512 - 1) / 512 * 512; + const at::cuda::OptionalCUDAGuard device_guard(device_of(X)); + auto options = torch::TensorOptions().dtype(X.dtype()).device(W.device()); + at::Tensor Y = torch::zeros({tokens * top_k, row}, options); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + options = torch::TensorOptions().dtype(torch::kInt32).device(W.device()); + at::Tensor quant_X = torch::empty({tokens, padded / 32 * 9}, options); + VLLM_DISPATCH_FLOATING_TYPES(X.scalar_type(), "ggml_moe_vec_a8", [&] { + quantize_row_q8_1_cuda((scalar_t*)X.data_ptr(), + (void*)quant_X.data_ptr(), col, tokens, + stream); + switch (type) { + case 2: + moe_vec_q4_0_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens, + col, row, quant_X.stride(0), stream); + break; + case 3: + moe_vec_q4_1_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens, + col, row, quant_X.stride(0), stream); + break; + case 6: + moe_vec_q5_0_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens, + col, row, quant_X.stride(0), stream); + break; + case 7: + moe_vec_q5_1_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens, + col, row, quant_X.stride(0), stream); + break; + case 8: + moe_vec_q8_0_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens, + col, row, quant_X.stride(0), stream); + break; + case 10: + moe_vec_q2_K_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens, + col, row, quant_X.stride(0), stream); + break; + case 11: + moe_vec_q3_K_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens, + col, row, quant_X.stride(0), stream); + break; + case 12: + moe_vec_q4_K_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens, + col, row, quant_X.stride(0), stream); + break; + case 13: + moe_vec_q5_K_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens, + col, row, quant_X.stride(0), stream); + break; + case 14: + moe_vec_q6_K_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens, + col, row, quant_X.stride(0), stream); + break; + case 16: + moe_vec_iq2_xxs_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens, + col, row, quant_X.stride(0), stream); + break; + case 17: + moe_vec_iq2_xs_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens, + col, row, quant_X.stride(0), stream); + break; + case 18: + moe_vec_iq3_xxs_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens, + col, row, quant_X.stride(0), stream); + break; + case 19: + moe_vec_iq1_s_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens, + col, row, quant_X.stride(0), stream); + break; + case 20: + moe_vec_iq4_nl_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens, + col, row, quant_X.stride(0), stream); + break; + case 21: + moe_vec_iq3_s_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens, + col, row, quant_X.stride(0), stream); + break; + case 22: + moe_vec_iq2_s_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens, + col, row, quant_X.stride(0), stream); + break; + case 23: + moe_vec_iq4_xs_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens, + col, row, quant_X.stride(0), stream); + break; + case 29: + moe_vec_iq1_m_q8_1_cuda( + (void*)W.data_ptr(), (void*)quant_X.data_ptr(), + (scalar_t*)Y.data_ptr(), (int*)topk_ids.data_ptr(), top_k, tokens, + col, row, quant_X.stride(0), stream); + break; + } + }); + return Y; +} + int64_t ggml_moe_get_block_size(int64_t type) { switch (type) { case 2: diff --git a/csrc/quantization/gguf/moe_vec.cuh b/csrc/quantization/gguf/moe_vec.cuh new file mode 100644 index 0000000000000000000000000000000000000000..60f65a1bfdcba4160ff8f0ee28508027e4c97263 --- /dev/null +++ b/csrc/quantization/gguf/moe_vec.cuh @@ -0,0 +1,338 @@ +// copied and adapted from +// https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-cuda/mmvq.cu +template +static __global__ void moe_vec_q(const void* __restrict__ vx, + const void* __restrict__ vy, + scalar_t* __restrict__ dst, + const int* topk_ids, const int topk, + const int ncols, const int nrows, + const int token_stride) { + const auto row = blockIdx.x * blockDim.y + threadIdx.y; + + const auto token = blockIdx.z / topk; + const auto expert = (topk_ids)[blockIdx.z]; + + if (row >= nrows) { + return; + } + + const int blocks_per_row = ncols / qk; + const int blocks_per_warp = vdr * WARP_SIZE / qi; + + // partial sum for each thread + float tmp = 0.0f; + + const block_q_t* x = ((const block_q_t*)vx) + expert * nrows * blocks_per_row; + const block_q8_1* y = + (const block_q8_1*)(((const int*)vy) + token * token_stride); + + for (auto i = threadIdx.x / (qi / vdr); i < blocks_per_row; + i += blocks_per_warp) { + const int ibx = row * blocks_per_row + i; // x block index + + const int iby = i * (qk / QK8_1); // y block index that aligns with ibx + + const int iqs = + vdr * + (threadIdx.x % + (qi / vdr)); // x block quant index when casting the quants to int + + tmp += vec_dot_q_cuda(&x[ibx], &y[iby], iqs); + } + + // sum up partial sums and write back result +#pragma unroll + for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) { + tmp += VLLM_SHFL_XOR_SYNC(tmp, mask); + } + + if (threadIdx.x == 0) { + dst[blockIdx.z * nrows + row] = tmp; + } +} + +template +static void moe_vec_q4_0_q8_1_cuda(const void* vx, const void* vy, + scalar_t* dst, const int* topk_ids, + const int top_k, const int tokens, + const int ncols, const int nrows, + const int token_stride, + cudaStream_t stream) { + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, tokens * top_k); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + moe_vec_q<<>>( + vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride); +} + +template +static void moe_vec_q4_1_q8_1_cuda(const void* vx, const void* vy, + scalar_t* dst, const int* topk_ids, + const int top_k, const int tokens, + const int ncols, const int nrows, + const int token_stride, + cudaStream_t stream) { + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, tokens * top_k); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + moe_vec_q<<>>( + vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride); +} + +template +static void moe_vec_q5_0_q8_1_cuda(const void* vx, const void* vy, + scalar_t* dst, const int* topk_ids, + const int top_k, const int tokens, + const int ncols, const int nrows, + const int token_stride, + cudaStream_t stream) { + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, tokens * top_k); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + moe_vec_q<<>>( + vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride); +} + +template +static void moe_vec_q5_1_q8_1_cuda(const void* vx, const void* vy, + scalar_t* dst, const int* topk_ids, + const int top_k, const int tokens, + const int ncols, const int nrows, + const int token_stride, + cudaStream_t stream) { + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, tokens * top_k); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + moe_vec_q<<>>( + vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride); +} + +template +static void moe_vec_q8_0_q8_1_cuda(const void* vx, const void* vy, + scalar_t* dst, const int* topk_ids, + const int top_k, const int tokens, + const int ncols, const int nrows, + const int token_stride, + cudaStream_t stream) { + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, tokens * top_k); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + moe_vec_q<<>>( + vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride); +} + +template +static void moe_vec_q2_K_q8_1_cuda(const void* vx, const void* vy, + scalar_t* dst, const int* topk_ids, + const int top_k, const int tokens, + const int ncols, const int nrows, + const int token_stride, + cudaStream_t stream) { + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, tokens * top_k); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + moe_vec_q<<>>( + vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride); +} + +template +static void moe_vec_q3_K_q8_1_cuda(const void* vx, const void* vy, + scalar_t* dst, const int* topk_ids, + const int top_k, const int tokens, + const int ncols, const int nrows, + const int token_stride, + cudaStream_t stream) { + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, tokens * top_k); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + moe_vec_q<<>>( + vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride); +} + +template +static void moe_vec_q4_K_q8_1_cuda(const void* vx, const void* vy, + scalar_t* dst, const int* topk_ids, + const int top_k, const int tokens, + const int ncols, const int nrows, + const int token_stride, + cudaStream_t stream) { + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, tokens * top_k); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + moe_vec_q<<>>( + vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride); +} + +template +static void moe_vec_q5_K_q8_1_cuda(const void* vx, const void* vy, + scalar_t* dst, const int* topk_ids, + const int top_k, const int tokens, + const int ncols, const int nrows, + const int token_stride, + cudaStream_t stream) { + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, tokens * top_k); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + moe_vec_q<<>>( + vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride); +} + +template +static void moe_vec_q6_K_q8_1_cuda(const void* vx, const void* vy, + scalar_t* dst, const int* topk_ids, + const int top_k, const int tokens, + const int ncols, const int nrows, + const int token_stride, + cudaStream_t stream) { + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, tokens * top_k); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + moe_vec_q<<>>( + vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride); +} + +template +static void moe_vec_iq2_xxs_q8_1_cuda(const void* vx, const void* vy, + scalar_t* dst, const int* topk_ids, + const int top_k, const int tokens, + const int ncols, const int nrows, + const int token_stride, + cudaStream_t stream) { + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, tokens * top_k); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + moe_vec_q + <<>>(vx, vy, dst, topk_ids, top_k, + ncols, nrows, token_stride); +} + +template +static void moe_vec_iq2_xs_q8_1_cuda(const void* vx, const void* vy, + scalar_t* dst, const int* topk_ids, + const int top_k, const int tokens, + const int ncols, const int nrows, + const int token_stride, + cudaStream_t stream) { + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, tokens * top_k); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + moe_vec_q + <<>>(vx, vy, dst, topk_ids, top_k, + ncols, nrows, token_stride); +} + +template +static void moe_vec_iq2_s_q8_1_cuda(const void* vx, const void* vy, + scalar_t* dst, const int* topk_ids, + const int top_k, const int tokens, + const int ncols, const int nrows, + const int token_stride, + cudaStream_t stream) { + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, tokens * top_k); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + moe_vec_q + <<>>(vx, vy, dst, topk_ids, top_k, + ncols, nrows, token_stride); +} + +template +static void moe_vec_iq3_xxs_q8_1_cuda(const void* vx, const void* vy, + scalar_t* dst, const int* topk_ids, + const int top_k, const int tokens, + const int ncols, const int nrows, + const int token_stride, + cudaStream_t stream) { + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, tokens * top_k); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + moe_vec_q + <<>>(vx, vy, dst, topk_ids, top_k, + ncols, nrows, token_stride); +} + +template +static void moe_vec_iq1_s_q8_1_cuda(const void* vx, const void* vy, + scalar_t* dst, const int* topk_ids, + const int top_k, const int tokens, + const int ncols, const int nrows, + const int token_stride, + cudaStream_t stream) { + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, tokens * top_k); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + moe_vec_q + <<>>(vx, vy, dst, topk_ids, top_k, + ncols, nrows, token_stride); +} + +template +static void moe_vec_iq1_m_q8_1_cuda(const void* vx, const void* vy, + scalar_t* dst, const int* topk_ids, + const int top_k, const int tokens, + const int ncols, const int nrows, + const int token_stride, + cudaStream_t stream) { + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, tokens * top_k); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + moe_vec_q + <<>>(vx, vy, dst, topk_ids, top_k, + ncols, nrows, token_stride); +} + +template +static void moe_vec_iq4_nl_q8_1_cuda(const void* vx, const void* vy, + scalar_t* dst, const int* topk_ids, + const int top_k, const int tokens, + const int ncols, const int nrows, + const int token_stride, + cudaStream_t stream) { + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, tokens * top_k); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + moe_vec_q<<>>( + vx, vy, dst, topk_ids, top_k, ncols, nrows, token_stride); +} + +template +static void moe_vec_iq4_xs_q8_1_cuda(const void* vx, const void* vy, + scalar_t* dst, const int* topk_ids, + const int top_k, const int tokens, + const int ncols, const int nrows, + const int token_stride, + cudaStream_t stream) { + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, tokens * top_k); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + moe_vec_q + <<>>(vx, vy, dst, topk_ids, top_k, + ncols, nrows, token_stride); +} + +template +static void moe_vec_iq3_s_q8_1_cuda(const void* vx, const void* vy, + scalar_t* dst, const int* topk_ids, + const int top_k, const int tokens, + const int ncols, const int nrows, + const int token_stride, + cudaStream_t stream) { + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, tokens * top_k); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + moe_vec_q + <<>>(vx, vy, dst, topk_ids, top_k, + ncols, nrows, token_stride); +} diff --git a/csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu b/csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu index ec0bf2c3cb4bd3e32b10d83a61fe7f35ce2275f2..03bd5964a7fc4f1efda78584c0a44ead4bae8a63 100644 --- a/csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu +++ b/csrc/quantization/gptq_allspark/allspark_qgemm_w8a16.cu @@ -9,7 +9,7 @@ at::Tensor as_g_workspace; torch::Tensor allspark_w8a16_gemm( torch::Tensor const& a, torch::Tensor const& b_qweight, - torch::Tensor const& b_scales, c10::optional const& b_qzeros, + torch::Tensor const& b_scales, std::optional const& b_qzeros, int64_t n, int64_t group_size, int64_t sm_count, int64_t sm_version, int64_t CUBLAS_M_THRESHOLD, bool has_zp, bool n32k16_reorder) { TORCH_CHECK_NOT_IMPLEMENTED( @@ -347,7 +347,7 @@ struct ComputeTile_W8A16_PerC_MtilexNtilex32_multistage_SM8x_SplitK { for (int n_idx = 0; n_idx < WARP_NITER; ++n_idx) { hmma16816_f32( C_frag[m_idx][n_idx], A_frag[reg_buf_idx][m_idx], - reinterpret_cast(BF_frag[reg_buf_idx][n_idx])); + reinterpret_cast(BF_frag[reg_buf_idx][n_idx])); } } } @@ -918,7 +918,7 @@ void allspark_qgemm_w8a16_perc_ampere( torch::Tensor allspark_w8a16_gemm( torch::Tensor const& a, torch::Tensor const& b_qweight, - torch::Tensor const& b_scales, c10::optional const& b_qzeros, + torch::Tensor const& b_scales, std::optional const& b_qzeros, int64_t n, int64_t group_size, int64_t sm_count, int64_t sm_version, int64_t CUBLAS_M_THRESHOLD, bool has_zp, bool n32k16_reorder) { // Verify device and strides diff --git a/csrc/quantization/gptq_allspark/allspark_repack.cu b/csrc/quantization/gptq_allspark/allspark_repack.cu index ea8eccf040df698d932423e88fd862c9b88857f1..7a5b2f95cc2efcc0f598753d353727535d4b28ae 100644 --- a/csrc/quantization/gptq_allspark/allspark_repack.cu +++ b/csrc/quantization/gptq_allspark/allspark_repack.cu @@ -100,9 +100,9 @@ void rearrange_kn_weight_as_n32k16_order_ldg16( void rearrange_kn_weight_as_n32k16_order( torch::Tensor const& b_qweight, torch::Tensor const& b_scales, - c10::optional const& b_zeros, bool has_zp, + std::optional const& b_zeros, bool has_zp, torch::Tensor& b_qweight_reorder, torch::Tensor& b_scales_reorder, - c10::optional const& b_zeros_reorder, const int64_t K, + std::optional const& b_zeros_reorder, const int64_t K, const int64_t N, const int64_t N_32align) { // Verify device and strides TORCH_CHECK(b_qweight.device().is_cuda(), "b_qweight is not on GPU"); diff --git a/csrc/quantization/gptq_marlin/.gitignore b/csrc/quantization/gptq_marlin/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..77088552b85b4f8614516756a31ac1fb673266f3 --- /dev/null +++ b/csrc/quantization/gptq_marlin/.gitignore @@ -0,0 +1 @@ +kernel_*.cu \ No newline at end of file diff --git a/csrc/quantization/gptq_marlin/dequant.h b/csrc/quantization/gptq_marlin/dequant.h new file mode 100644 index 0000000000000000000000000000000000000000..ae0d6c0f20020aafa2f5103aed9a1bc3e0e211bb --- /dev/null +++ b/csrc/quantization/gptq_marlin/dequant.h @@ -0,0 +1,507 @@ +/* +Fast Dequantization (Converting INT4/INT8/FP4/FP8 to FP16/BF16) + +The process of fast dequantization can be summarized as a combination +of bitwise operations and floating-point computations: + +weight =>(bit_op / bitwise operations)=> +f16_value =>(flop / floating-point computation)=> +dequantized_weight + +Since the dequantized weights typically require subtracting the zero point and +applying a scale factor, the floating-point computation step can be fused with +the zero-point subtraction and scaling operations. + +The following are the parts that need to be modified for the fused operation +of zero-point subtraction and scaling. + +## INT4 => FP16/BF16 or INT8 => FP16 + +The floating-point computation is `__hsub2` + +If has zero points: + + flop(bit_op(weight)) - flop(bit_op(zp)) + = sub(bit_op(weight), bias) - sub(bit_op(zp), bias) + = bit_op(weight) - bit_op(zp) + +so we don't need additional modification. + +If has float zero points: + + flop(bit_op(weight)) - fzp + = sub(bit_op(weight), bias) - fzp + = bit_op(weight) - (fzp + bias) + +where the `fzp + bias` can be computed at weight loading. But this +may have accuracy issue, so we should not use this in most cases. + +If has not zero points: + + scale(flop(bit_op(weight))) + = scale(sub(bit_op(weight), bias)) + = scale(bit_op(weight)) - scale(bias) + = fma(bit_op(weight), scale_factor, scale(bias)) + +where the `scale(bias)` can be cached. But this may have accuracy issue, +so we should not use this in most cases. + + +## INT8 => BF16 + +INT8 => BF16 is a special case, it use byte_perm instead of flop. +We cannot fused byte_perm with scaling. + + +## FP4/FP8 => FP16/BF16 + + scale(flop(bit_op(weight))) + = scale(mul(bit_op(weight), multiplier)) + = mul(bit_op(weight), scale_factor * multiplier) + +where `scale_factor * multiplier` can be computed at weight loading. + +*/ + +#include "marlin_dtypes.cuh" + +namespace MARLIN_NAMESPACE_NAME { + +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 +// Lookup-table based 3-input logical operation; explicitly used for +// dequantization as the compiler does not seem to automatically recognize it in +// all cases. +template +__device__ inline int lop3(int a, int b, int c) { + int res; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(res) + : "r"(a), "r"(b), "r"(c), "n"(lut)); + return res; +} + +// Constructs destination register by taking bytes from 2 sources (based on +// mask) +template +__device__ inline uint32_t prmt(uint32_t a) { + uint32_t res; + asm volatile("prmt.b32 %0, %1, %2, %3;\n" + : "=r"(res) + : "r"(a), "n"(start_byte), "n"(mask)); + return res; +} + +template +__device__ inline void dequant(int q, scalar_t2* frag_b); + +// +// Efficiently dequantize 4bit values packed in an int32 value into a full +// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below, +// with some small changes: +// - FP16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 +// - BF16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385 +// +template <> +__device__ inline void dequant(int q, + half2* frag_b) { + const int MASK = 0x000f000f; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + q >>= 4; + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + + frag_b[0] = *reinterpret_cast(&lo); + frag_b[1] = *reinterpret_cast(&hi); +} + +template <> +__device__ inline void dequant(int q, + half2* frag_b) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + // clang-format off + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + // clang-format on + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point + // directly into `SUB` and `ADD`. + const int SUB = 0x64086408; + const int MUL = 0x2c002c00; + const int ADD = 0xd480d480; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); +} + +template <> +__device__ inline void dequant(int q, + half2* frag_b) { + dequant(q, frag_b); +} + +template <> +__device__ inline void dequant(int q, + half2* frag_b) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + // clang-format off + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + // clang-format on + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point + // directly into `SUB` and `ADD`. + const int SUB = 0x64006400; + const int MUL = 0x2c002c00; + const int ADD = 0xd400d400; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); +} + +template <> +__device__ inline void dequant( + int q, nv_bfloat162* frag_b) { + static constexpr uint32_t MASK = 0x000f000f; + static constexpr uint32_t EX = 0x43004300; + + // Guarantee that the `(a & b) | c` operations are LOP3s. + // clang-format off + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + q >>= 4; + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + // clang-format on + + frag_b[0] = *reinterpret_cast(&lo); + frag_b[1] = *reinterpret_cast(&hi); +} + +template <> +__device__ inline void dequant( + int q, nv_bfloat162* frag_b) { + dequant(q, frag_b); + + static constexpr uint32_t SUB = 0x43084308; + + frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast(&SUB)); + frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast(&SUB)); +} + +template <> +__device__ inline void dequant( + int q, nv_bfloat162* frag_b) { + dequant(q, frag_b); +} + +template <> +__device__ inline void dequant( + int q, nv_bfloat162* frag_b) { + dequant(q, frag_b); + + static constexpr uint32_t SUB = 0x43004300; + + frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast(&SUB)); + frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast(&SUB)); +} + +// +// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or +// bf16 Reference: +// - FP16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 +// - BF16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175 +// +template <> +__device__ inline void dequant(int q, + half2* frag_b) { + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + uint32_t lo = prmt(q); + uint32_t hi = prmt(q); + + frag_b[0] = *reinterpret_cast(&lo); + frag_b[1] = *reinterpret_cast(&hi); +} + +template <> +__device__ inline void dequant( + int q, half2* frag_b) { + dequant(q, frag_b); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + frag_b[0] = __hsub2(frag_b[0], + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(frag_b[1], + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); +} + +template <> +__device__ inline void dequant(int q, + half2* frag_b) { + dequant(q, frag_b); +} + +template <> +__device__ inline void dequant(int q, + half2* frag_b) { + dequant(q, frag_b); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400; + frag_b[0] = __hsub2(frag_b[0], + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(frag_b[1], + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); +} + +template <> +__device__ inline void dequant( + int q, nv_bfloat162* frag_b) { + float fp32_intermediates[4]; + uint32_t* fp32_intermediates_casted = + reinterpret_cast(fp32_intermediates); + + static constexpr uint32_t fp32_base = 0x4B000000; + fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); + fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); + fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); + + fp32_intermediates[0] -= 8388736.f; + fp32_intermediates[1] -= 8388736.f; + fp32_intermediates[2] -= 8388736.f; + fp32_intermediates[3] -= 8388736.f; + + uint32_t* bf16_result_ptr = reinterpret_cast(frag_b); + bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], + fp32_intermediates_casted[1], 0x7632); + bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], + fp32_intermediates_casted[3], 0x7632); +} + +template <> +__device__ inline void dequant( + int q, nv_bfloat162* frag_b) { + float fp32_intermediates[4]; + uint32_t* fp32_intermediates_casted = + reinterpret_cast(fp32_intermediates); + + static constexpr uint32_t fp32_base = 0x4B000000; + fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); + fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); + fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); + + fp32_intermediates[0] -= 8388608.f; + fp32_intermediates[1] -= 8388608.f; + fp32_intermediates[2] -= 8388608.f; + fp32_intermediates[3] -= 8388608.f; + + uint32_t* bf16_result_ptr = reinterpret_cast(frag_b); + bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], + fp32_intermediates_casted[1], 0x7632); + bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], + fp32_intermediates_casted[3], 0x7632); +} + +template <> +__device__ inline void dequant( + int q, half2* frag_b) { + // Constants for FP8 (E4M3) and FP16 formats + constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5; + constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT; + constexpr int MASK = 0x7F007F00; + + // Extract and shift FP8 values to FP16 format + int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 8; + int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +template <> +__device__ inline void dequant( + int q, half2* frag_b) { + dequant(q, frag_b); + + // Constants for FP8 (E4M3) and FP16 formats + constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5; + + // Construct and apply exponent bias + constexpr int BIAS_OFFSET = + (1 << (FP16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1)); + const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET)); + + // Convert to half2 and apply bias + frag_b[1] = __hmul2(frag_b[1], bias_reg); + frag_b[0] = __hmul2(frag_b[0], bias_reg); +} + +template <> +__device__ inline void dequant( + int q, nv_bfloat162* frag_b) { + // Constants for FP8 (E4M3) and BF16 formats + constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8; + constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT; + + constexpr int MASK = 0x7F007F00; + + // Extract and shift FP8 values to BF16 format + int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 8; + int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +template <> +__device__ inline void dequant( + int q, nv_bfloat162* frag_b) { + dequant(q, frag_b); + + // Constants for FP8 (E4M3) and BF16 formats + constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8; + + // Construct and apply exponent bias + constexpr int BIAS_OFFSET = + (1 << (BF16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1)); + // Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent + // position + constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23; + const nv_bfloat162 bias_reg = + __float2bfloat162_rn(*reinterpret_cast(&BIAS)); + + // Convert to bfloat162 and apply bias + frag_b[1] = __hmul2(frag_b[1], bias_reg); + frag_b[0] = __hmul2(frag_b[0], bias_reg); +} + +template <> +__device__ inline void dequant(int q, + half2* frag_b) { + // Constants for FP4 (E2M1) and FP16 formats + constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5; + constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP4_EXPONENT; + constexpr int MASK = 0x70007000; + + // Extract and shift FP4 values to FP16 format + int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 4; + int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +template <> +__device__ inline void dequant( + int q, half2* frag_b) { + dequant(q, frag_b); + + // Constants for FP4 (E2M1) and FP16 formats + constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5; + + // Construct and apply exponent bias + constexpr int BIAS_OFFSET = + (1 << (FP16_EXPONENT - 1)) - (1 << (FP4_EXPONENT - 1)); + const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET)); + + // Convert to half2 and apply bias + frag_b[1] = __hmul2(frag_b[1], bias_reg); + frag_b[0] = __hmul2(frag_b[0], bias_reg); +} + +template <> +__device__ inline void dequant( + int q, nv_bfloat162* frag_b) { + // Constants for FP4 (E2M1) and FP16 formats + constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8; + constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP4_EXPONENT; + constexpr int MASK = 0x70007000; + + // Extract and shift FP4 values to FP16 format + int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 4; + int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +template <> +__device__ inline void dequant( + int q, nv_bfloat162* frag_b) { + dequant(q, frag_b); + + // Constants for FP4 (E2M1) and BF16 formats + constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8; + + // Construct and apply exponent bias + constexpr int BIAS_OFFSET = + (1 << (BF16_EXPONENT - 1)) - (1 << (FP4_EXPONENT - 1)); + // Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent + // position + constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23; + const nv_bfloat162 bias_reg = + __float2bfloat162_rn(*reinterpret_cast(&BIAS)); + + // Convert to half2 and apply bias + frag_b[1] = __hmul2(frag_b[1], bias_reg); + frag_b[0] = __hmul2(frag_b[0], bias_reg); +} + +template +__device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b); + +template <> +__device__ inline void dequant_fp8_scales(int q, half2* frag_b) { + int Out1 = (q & 0xFF00FF00) >> 1; + ; + q <<= 8; + int Out2 = (q & 0xFF00FF00) >> 1; + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +}; + +template <> +__device__ inline void dequant_fp8_scales(int q, + nv_bfloat162* frag_b) { + constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8; + constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT; + constexpr int MASK = 0x7F007F00; + + // Extract and shift FP8 values to BF16 format + int Out1 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 8; + int Out2 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +#endif + +} // namespace MARLIN_NAMESPACE_NAME diff --git a/csrc/quantization/gptq_marlin/generate_kernels.py b/csrc/quantization/gptq_marlin/generate_kernels.py new file mode 100644 index 0000000000000000000000000000000000000000..4ac7121ab4e1b63316942cf658e77f32a66681d1 --- /dev/null +++ b/csrc/quantization/gptq_marlin/generate_kernels.py @@ -0,0 +1,125 @@ +# SPDX-License-Identifier: Apache-2.0 +import glob +import itertools +import os +import subprocess + +import jinja2 + +FILE_HEAD = """ +// auto generated by generate.py +// clang-format off + +#include "kernel.h" +#include "marlin_template.h" + +namespace MARLIN_NAMESPACE_NAME { +""".strip() + +TEMPLATE = ("template __global__ void Marlin<" + "{{scalar_t}}, " + "{{w_type_id}}, " + "{{threads}}, " + "{{thread_m_blocks}}, " + "{{thread_n_blocks}}, " + "{{thread_k_blocks}}, " + "{{'true' if m_block_size_8 else 'false'}}, " + "{{stages}}, " + "{{group_blocks}}, " + "{{'true' if is_zp_float else 'false'}}>" + "( MARLIN_KERNEL_PARAMS );") + +# int8 with zero point case (vllm::kU8) is also supported, +# we don't add it to reduce wheel size. +SCALAR_TYPES = [ + "vllm::kU4", "vllm::kU4B8", "vllm::kU8B128", "vllm::kFE4M3fn", + "vllm::kFE2M1f" +] +THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128), + (128, 64, 128)] + +THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4] +# group_blocks: +# = 0 : act order case +# = -1 : channelwise quantization +# > 0 : group_size=16*group_blocks +GROUP_BLOCKS = [0, 1, -1, 2, 4, 8] +DTYPES = ["fp16", "bf16"] + + +def remove_old_kernels(): + for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cu"): + subprocess.call(["rm", "-f", filename]) + + +def generate_new_kernels(): + for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES): + all_template_str_list = [] + + for group_blocks, m_blocks, thread_configs in itertools.product( + GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS): + + # act order case only support gptq-int4 and gptq-int8 + if group_blocks == 0 and scalar_type not in [ + "vllm::kU4B8", "vllm::kU8B128" + ]: + continue + if thread_configs[2] == 256: + # for small batch (m_blocks == 1), we only need (128, 128, 256) + # for large batch (m_blocks > 1), we only need (64, 256, 256) + if m_blocks <= 1 and thread_configs[0] != 128: + continue + if m_blocks > 1 and thread_configs[0] != 64: + continue + + # we only support channelwise quantization and group_size == 128 + # for fp8 + if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]: + continue + # nvfp4 only supports group_size == 16 + if scalar_type == "vllm::kFE2M1f" and group_blocks != 1: + continue + # other quantization methods don't support group_size = 16 + if scalar_type != "vllm::kFE2M1f" and group_blocks == 1: + continue + + k_blocks = thread_configs[0] // 16 + n_blocks = thread_configs[1] // 16 + threads = thread_configs[2] + + c_dtype = "half" if dtype == "fp16" else "nv_bfloat16" + + is_zp_float_list = [False] + if dtype == "fp16" and scalar_type == "vllm::kU4" and \ + group_blocks == 4: + # HQQ (is_zp_float = true) only supports + # 4bit quantization and fp16 + is_zp_float_list.append(True) + + for is_zp_float in is_zp_float_list: + template_str = jinja2.Template(TEMPLATE).render( + scalar_t=c_dtype, + w_type_id=scalar_type + ".id()", + threads=threads, + thread_m_blocks=max(m_blocks, 1), + thread_n_blocks=n_blocks, + thread_k_blocks=k_blocks, + m_block_size_8=m_blocks == 0.5, + stages="pipe_stages", + group_blocks=group_blocks, + is_zp_float=is_zp_float, + ) + + all_template_str_list.append(template_str) + + file_content = FILE_HEAD + "\n\n" + file_content += "\n\n".join(all_template_str_list) + "\n\n}\n" + filename = f"kernel_{dtype}_{scalar_type[6:].lower()}.cu" + + with open(os.path.join(os.path.dirname(__file__), filename), "w") as f: + f.write(file_content) + + +if __name__ == "__main__": + remove_old_kernels() + generate_new_kernels() diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu index 83bbd1e6816a80130279d077f48f70d6f3911250..4a242f2050d56f50a72c736ca1a2645d104a4c38 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -19,10 +19,11 @@ * Adapted from https://github.com/IST-DASLab/marlin */ -#include "marlin.cuh" -#include "marlin_dtypes.cuh" -#include "core/scalar_type.hpp" +#ifndef MARLIN_NAMESPACE_NAME + #define MARLIN_NAMESPACE_NAME marlin +#endif +#include "kernel.h" #include "core/registration.h" #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ @@ -30,13 +31,12 @@ std::is_same::value, \ "only float16 and bfloat16 is supported"); -template -inline std::string str(T x) { - return std::to_string(x); -} - namespace marlin { +__global__ void MarlinDefault(MARLIN_KERNEL_PARAMS){}; + +using MarlinFuncPtr = void (*)(MARLIN_KERNEL_PARAMS); + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, @@ -44,46 +44,17 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, int4* __restrict__ out_int4_ptr, int size_m, int size_k, int lda, int block_rows) {} -template shared - // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const int group_blocks = -1, // number of consecutive 16x16 blocks - // with a separate quantization scale - const bool is_zp_float // is zero point of float16 type? - > -__global__ void Marlin( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce) - const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - const int* __restrict__ g_idx, // int32 group indices of shape k - int num_groups, // number of scale groups per output channel - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int* locks, // extra global storage for barrier synchronization - bool use_fp32_reduce // whether to use fp32 global reduce -) {} - } // namespace marlin -torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_scales, torch::Tensor& b_zeros, - torch::Tensor& g_idx, torch::Tensor& perm, - torch::Tensor& workspace, - vllm::ScalarTypeId const b_q_type_id, - int64_t size_m, int64_t size_n, int64_t size_k, - bool is_k_full, bool has_zp, bool is_zp_float) { +torch::Tensor gptq_marlin_gemm( + torch::Tensor& a, std::optional c_or_none, + torch::Tensor& b_q_weight, torch::Tensor& b_scales, + std::optional const& b_zeros_or_none, + std::optional const& g_idx_or_none, + std::optional const& perm_or_none, torch::Tensor& workspace, + vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n, + int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce, + bool is_zp_float) { TORCH_CHECK_NOT_IMPLEMENTED(false, "marlin_gemm(..) requires CUDA_ARCH >= 8.0"); return torch::empty({1, 1}); @@ -91,369 +62,6 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, #else -// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 -// output/accumulation. -template -__device__ inline void mma(const typename ScalarType::FragA& a_frag, - const typename ScalarType::FragB& frag_b, - typename ScalarType::FragC& frag_c) { - const uint32_t* a = reinterpret_cast(&a_frag); - const uint32_t* b = reinterpret_cast(&frag_b); - float* c = reinterpret_cast(&frag_c); - if constexpr (std::is_same::value) { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); - } else if constexpr (std::is_same::value) { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); - } else { - STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); - } -} - -// Instruction for loading a full 16x16 matrix fragment of operand A from shared -// memory, directly in tensor core layout. -template -__device__ inline void ldsm4(typename ScalarType::FragA& frag_a, - const void* smem_ptr) { - uint32_t* a = reinterpret_cast(&frag_a); - uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); - asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" - : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) - : "r"(smem)); -} - -// Lookup-table based 3-input logical operation; explicitly used for -// dequantization as the compiler does not seem to automatically recognize it in -// all cases. -template -__device__ inline int lop3(int a, int b, int c) { - int res; - asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" - : "=r"(res) - : "r"(a), "r"(b), "r"(c), "n"(lut)); - return res; -} - -// Constructs destination register by taking bytes from 2 sources (based on -// mask) -template -__device__ inline uint32_t prmt(uint32_t a) { - uint32_t res; - asm volatile("prmt.b32 %0, %1, %2, %3;\n" - : "=r"(res) - : "r"(a), "n"(start_byte), "n"(mask)); - return res; -} - -template -__device__ inline typename ScalarType::FragB dequant(int q); - -// -// Efficiently dequantize 4bit values packed in an int32 value into a full -// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below, -// with some small changes: -// - FP16: -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 -// - BF16: -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385 -// -template <> -__device__ inline typename ScalarType::FragB -dequant(int q) { - const int LO = 0x000f000f; - const int HI = 0x00f000f0; - const int EX = 0x64006400; - // Guarantee that the `(a & b) | c` operations are LOP3s. - int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX); - int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX); - // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point - // directly into `SUB` and `ADD`. - const int SUB = 0x64086408; - const int MUL = 0x2c002c00; - const int ADD = 0xd480d480; - typename ScalarType::FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&SUB)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - return frag_b; -} - -template <> -__device__ inline typename ScalarType::FragB -dequant(int q) { - static constexpr uint32_t MASK = 0x000f000f; - static constexpr uint32_t EX = 0x43004300; - - // Guarantee that the `(a & b) | c` operations are LOP3s. - - int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX); - q >>= 4; - int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX); - - typename ScalarType::FragB frag_b; - static constexpr uint32_t MUL = 0x3F803F80; - static constexpr uint32_t ADD = 0xC308C308; - - frag_b[0] = __hfma2(*reinterpret_cast(&lo), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - return frag_b; -} - -template <> -__device__ inline typename ScalarType::FragB -dequant(int q) { - const int LO = 0x000f000f; - const int HI = 0x00f000f0; - const int EX = 0x64006400; - // Guarantee that the `(a & b) | c` operations are LOP3s. - int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX); - int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX); - - const int SUB = 0x64006400; - const int MUL = 0x2c002c00; - const int ADD = 0xd400d400; - typename ScalarType::FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&SUB)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - return frag_b; -} - -template <> -__device__ inline typename ScalarType::FragB -dequant(int q) { - static constexpr uint32_t MASK = 0x000f000f; - static constexpr uint32_t EX = 0x43004300; - - // Guarantee that the `(a & b) | c` operations are LOP3s. - - int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX); - q >>= 4; - int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, MASK, EX); - - typename ScalarType::FragB frag_b; - static constexpr uint32_t MUL = 0x3F803F80; - static constexpr uint32_t ADD = 0xC300C300; - - frag_b[0] = __hfma2(*reinterpret_cast(&lo), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - frag_b[1] = __hfma2(*reinterpret_cast(&hi), - *reinterpret_cast(&MUL), - *reinterpret_cast(&ADD)); - return frag_b; -} - -// -// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or -// bf16 Reference: -// - FP16: -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 -// - BF16: -// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175 -// -template <> -__device__ inline typename ScalarType::FragB -dequant(int q) { - static constexpr uint32_t mask_for_elt_01 = 0x5250; - static constexpr uint32_t mask_for_elt_23 = 0x5351; - static constexpr uint32_t start_byte_for_fp16 = 0x64646464; - - uint32_t lo = prmt(q); - uint32_t hi = prmt(q); - - static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; - - typename ScalarType::FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - frag_b[1] = __hsub2(*reinterpret_cast(&hi), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - return frag_b; -} - -template <> -__device__ inline typename ScalarType::FragB -dequant(int q) { - typename ScalarType::FragB frag_b; - - float fp32_intermediates[4]; - uint32_t* fp32_intermediates_casted = - reinterpret_cast(fp32_intermediates); - - static constexpr uint32_t fp32_base = 0x4B000000; - fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); - fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); - fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); - fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); - - fp32_intermediates[0] -= 8388736.f; - fp32_intermediates[1] -= 8388736.f; - fp32_intermediates[2] -= 8388736.f; - fp32_intermediates[3] -= 8388736.f; - - uint32_t* bf16_result_ptr = reinterpret_cast(&frag_b); - bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], - fp32_intermediates_casted[1], 0x7632); - bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], - fp32_intermediates_casted[3], 0x7632); - - return frag_b; -} - -template <> -__device__ inline typename ScalarType::FragB -dequant(int q) { - static constexpr uint32_t mask_for_elt_01 = 0x5250; - static constexpr uint32_t mask_for_elt_23 = 0x5351; - static constexpr uint32_t start_byte_for_fp16 = 0x64646464; - - uint32_t lo = prmt(q); - uint32_t hi = prmt(q); - - static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400; - - typename ScalarType::FragB frag_b; - frag_b[0] = __hsub2(*reinterpret_cast(&lo), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - frag_b[1] = __hsub2(*reinterpret_cast(&hi), - *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); - return frag_b; -} - -template <> -__device__ inline typename ScalarType::FragB -dequant(int q) { - typename ScalarType::FragB frag_b; - - float fp32_intermediates[4]; - uint32_t* fp32_intermediates_casted = - reinterpret_cast(fp32_intermediates); - - static constexpr uint32_t fp32_base = 0x4B000000; - fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); - fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); - fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); - fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); - - fp32_intermediates[0] -= 8388608.f; - fp32_intermediates[1] -= 8388608.f; - fp32_intermediates[2] -= 8388608.f; - fp32_intermediates[3] -= 8388608.f; - - uint32_t* bf16_result_ptr = reinterpret_cast(&frag_b); - bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], - fp32_intermediates_casted[1], 0x7632); - bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], - fp32_intermediates_casted[3], 0x7632); - - return frag_b; -} - -// Multiply dequantized values by the corresponding quantization scale; used -// only for grouped quantization. -template -__device__ inline void scale(typename ScalarType::FragB& frag_b, - typename ScalarType::FragS& frag_s, - int i) { - using scalar_t2 = typename ScalarType::scalar_t2; - scalar_t2 s = - ScalarType::num2num2(reinterpret_cast(&frag_s)[i]); - frag_b[0] = __hmul2(frag_b[0], s); - frag_b[1] = __hmul2(frag_b[1], s); -} - -template -__device__ inline void sub_zp(typename ScalarType::FragB& frag_b, - typename ScalarType::scalar_t2& frag_zp, - int i) { - using scalar_t2 = typename ScalarType::scalar_t2; - scalar_t2 zp = - ScalarType::num2num2(reinterpret_cast(&frag_zp)[i]); - frag_b[0] = __hsub2(frag_b[0], zp); - frag_b[1] = __hsub2(frag_b[1], zp); -} - -// Same as above, but for act_order (each K is multiplied individually) -template -__device__ inline void scale4(typename ScalarType::FragB& frag_b, - typename ScalarType::FragS& frag_s_1, - typename ScalarType::FragS& frag_s_2, - typename ScalarType::FragS& frag_s_3, - typename ScalarType::FragS& frag_s_4, - int i) { - using scalar_t2 = typename ScalarType::scalar_t2; - scalar_t2 s_val_1_2; - s_val_1_2.x = reinterpret_cast(&frag_s_1)[i]; - s_val_1_2.y = reinterpret_cast(&frag_s_2)[i]; - - scalar_t2 s_val_3_4; - s_val_3_4.x = reinterpret_cast(&frag_s_3)[i]; - s_val_3_4.y = reinterpret_cast(&frag_s_4)[i]; - - frag_b[0] = __hmul2(frag_b[0], s_val_1_2); - frag_b[1] = __hmul2(frag_b[1], s_val_3_4); -} - -// Given 2 floats multiply by 2 scales (halves) -template -__device__ inline void scale_float(float* c, - typename ScalarType::FragS& s) { - scalar_t* s_ptr = reinterpret_cast(&s); - c[0] = __fmul_rn(c[0], ScalarType::num2float(s_ptr[0])); - c[1] = __fmul_rn(c[1], ScalarType::num2float(s_ptr[1])); -} - -// Wait until barrier reaches `count`, then lock for current threadblock. -__device__ inline void barrier_acquire(int* lock, int count) { - if (threadIdx.x == 0) { - int state = -1; - do - // Guarantee that subsequent writes by this threadblock will be visible - // globally. - asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" - : "=r"(state) - : "l"(lock)); - while (state != count); - } - __syncthreads(); -} - -// Release barrier and increment visitation count. -__device__ inline void barrier_release(int* lock, bool reset = false) { - __syncthreads(); - if (threadIdx.x == 0) { - if (reset) { - lock[0] = 0; - return; - } - int val = 1; - // Make sure that all writes since acquiring this barrier are visible - // globally, while releasing the barrier. - asm volatile("fence.acq_rel.gpu;\n"); - asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" - : - : "l"(lock), "r"(val)); - } -} - // For a given "a" of size [M,K] performs a permutation of the K columns based // on the given "perm" indices. __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, @@ -510,1304 +118,19 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, } } -template shared - // fetch pipeline - const bool has_act_order, // whether act_order is enabled - const bool has_zp, // whether zero-points are enabled - const int group_blocks = -1, // number of consecutive 16x16 blocks - // with a separate quantization scale - const bool is_zp_float // is zero point of float16 type? - > -__global__ void Marlin( - const int4* __restrict__ A, // fp16 input matrix of shape mxk - const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn - int4* __restrict__ C, // fp16 output buffer of shape mxn - int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce) - const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape - // (k/groupsize)xn - const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape - // (k/groupsize)x(n/pack_factor) - const int* __restrict__ g_idx, // int32 group indices of shape k - int num_groups, // number of scale groups per output channel - int prob_m, // batch dimension m - int prob_n, // output dimension n - int prob_k, // reduction dimension k - int lda, // A.stride(0), equal to prob_k is A is contiguous - int* locks, // extra global storage for barrier synchronization - bool use_atomic_add, // whether to use atomic add to reduce - bool use_fp32_reduce // whether to use fp32 global reduce -) { - // Each threadblock processes one "stripe" of the B matrix with (roughly) the - // same size, which might involve multiple column "slices" (of width 16 * - // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM - // example: - // 0 1 3 - // 0 2 3 - // 1 2 4 - // While this kind of partitioning makes things somewhat more complicated, it - // ensures good utilization of all SMs for many kinds of shape and GPU - // configurations, while requiring as few slow global cross-threadblock - // reductions as possible. - using Dtype = ScalarType; - using scalar_t2 = typename ScalarType::scalar_t2; - using FragA = typename ScalarType::FragA; - using FragB = typename ScalarType::FragB; - using FragC = typename ScalarType::FragC; - using FragS = typename ScalarType::FragS; - using FragZP = typename ScalarType::FragZP; - - static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id); - - constexpr int pack_factor = 32 / w_type.size_bits(); - - // For larger GEMMs we run multiple batchsize 64 versions in parallel for a - // better partitioning with less reductions - int parallel = 1; - if (prob_m > 16 * thread_m_blocks) { - parallel = prob_m / (16 * thread_m_blocks); - prob_m = 16 * thread_m_blocks; - } - - int k_tiles = prob_k / 16 / thread_k_blocks; - int n_tiles = prob_n / 16 / thread_n_blocks; - int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x); - - if constexpr (!has_act_order && group_blocks != -1) { - if (group_blocks >= thread_k_blocks) { - // Ensure that the number of tiles in each stripe is a multiple of the - // groupsize; this avoids an annoying special case where a stripe starts - // in the middle of group. - iters = (group_blocks / thread_k_blocks) * - div_ceil(iters, (group_blocks / thread_k_blocks)); - } - } - - int slice_row = (iters * blockIdx.x) % k_tiles; - int slice_col_par = (iters * blockIdx.x) / k_tiles; - int slice_col = slice_col_par; - int slice_iters; // number of threadblock tiles in the current slice - int slice_count = - 0; // total number of active threadblocks in the current slice - int slice_idx; // index of threadblock in current slice; numbered bottom to - // top - - int par_id = 0; - - // We can easily implement parallel problem execution by just remapping - // indices and advancing global pointers - if (slice_col_par >= n_tiles) { - A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * lda / 8; - C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; - locks += (slice_col_par / n_tiles) * n_tiles; - slice_col = slice_col_par % n_tiles; - par_id = slice_col_par / n_tiles; - } - - // Compute all information about the current slice which is required for - // synchronization. - auto init_slice = [&]() { - slice_iters = - iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); - if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; - if (slice_iters == 0) return; - if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; - slice_count = 1; - slice_idx = 0; - int col_first = iters * div_ceil(k_tiles * slice_col_par, iters); - if (col_first <= k_tiles * (slice_col_par + 1)) { - int col_off = col_first - k_tiles * slice_col_par; - slice_count = div_ceil(k_tiles - col_off, iters); - if (col_off > 0) slice_count++; - int delta_first = iters * blockIdx.x - col_first; - if (delta_first < 0 || (col_off == 0 && delta_first == 0)) - slice_idx = slice_count - 1; - else { - slice_idx = slice_count - 1 - delta_first / iters; - if (col_off > 0) slice_idx--; - } - } - if (slice_col == n_tiles) { - A += 16 * thread_m_blocks * lda / 8; - C += 16 * thread_m_blocks * prob_n / 8; - locks += n_tiles; - slice_col = 0; - par_id++; - } - }; - init_slice(); - - // A sizes/strides - - // stride of the A matrix in global memory - int a_gl_stride = lda / 8; - // stride of an A matrix tile in shared memory - constexpr int a_sh_stride = 16 * thread_k_blocks / 8; - // delta between subsequent A tiles in global memory - constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; - // between subsequent accesses within a tile - int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); - // between shared memory writes - constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); - // between shared memory tile reads - constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); - // within a shared memory tile - constexpr int a_sh_rd_delta_i = a_sh_stride * 16; - // overall size of a tile - constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); - // number of shared write iterations for a tile - constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta); - - // B sizes/strides - int b_gl_stride = 16 * prob_n / (pack_factor * 4); - constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; - constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 1 : 2; - constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; - - int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; - int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); - constexpr int b_sh_wr_delta = threads * b_thread_vecs; - constexpr int b_sh_rd_delta = threads * b_thread_vecs; - constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; - constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; - - // Scale sizes/strides without act_order - int s_gl_stride = prob_n / 8; - constexpr int s_sh_stride = 16 * thread_n_blocks / 8; - constexpr int s_tb_groups = - !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks - ? thread_k_blocks / group_blocks - : 1; - constexpr int s_sh_stage = s_tb_groups * s_sh_stride; - int s_gl_rd_delta = s_gl_stride; - - // Scale size/strides with act_order - constexpr int tb_k = 16 * thread_k_blocks; - constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0; - // constexpr int act_s_row_stride = 1; - // int act_s_col_stride = act_s_row_stride * num_groups; - int act_s_col_stride = 1; - int act_s_col_warp_stride = act_s_col_stride * 8; - int tb_n_warps = thread_n_blocks / 4; - int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; - - // Zero-points sizes/strides - int zp_gl_stride = is_zp_float ? prob_n / 8 : (prob_n / pack_factor) / 4; - constexpr int zp_sh_stride = is_zp_float - ? 16 * thread_n_blocks / 8 - : ((16 * thread_n_blocks) / pack_factor) / 4; - constexpr int zp_tb_groups = s_tb_groups; - constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0; - int zp_gl_rd_delta = zp_gl_stride; - - // Global A read index of current thread. - int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - a_gl_rd += a_gl_rd_delta_o * slice_row; - // Shared write index of current thread. - int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - // Shared read index. - int a_sh_rd = - a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; - a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); - - int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + - (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; - b_gl_rd += b_sh_stride * slice_col; - b_gl_rd += b_gl_rd_delta_o * slice_row; - auto b_sh_wr = threadIdx.x * b_thread_vecs; - auto b_sh_rd = threadIdx.x * b_thread_vecs; - - // For act_order - constexpr int k_iter_size = tb_k / b_sh_wr_iters; - int slice_k_start = tb_k * slice_row; - int slice_k_finish = slice_k_start + tb_k * slice_iters; - int slice_k_start_shared_fetch = slice_k_start; - int slice_n_offset = act_s_col_tb_stride * slice_col; - - // No act_order - int s_gl_rd; - if constexpr (!has_act_order) { - if constexpr (group_blocks == -1) { - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - } else { - s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + - s_sh_stride * slice_col + threadIdx.x; - } - } - auto s_sh_wr = threadIdx.x; - bool s_sh_wr_pred = threadIdx.x < s_sh_stride; - - // Zero-points - int zp_gl_rd; - if constexpr (has_zp) { - if constexpr (group_blocks == -1) { - zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; - } else { - zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + - zp_sh_stride * slice_col + threadIdx.x; - } - } - auto zp_sh_wr = threadIdx.x; - bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride; - - // We use a different scale layout for grouped and column-wise quantization as - // we scale a `half2` tile in column-major layout in the former and in - // row-major in the latter case. - int s_sh_rd; - if constexpr (group_blocks != -1) - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; - else - s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) % 4; - - // Zero-points have the same read layout as the scales - // (without column-wise case) - constexpr int num_col_threads = 8; - constexpr int num_row_threads = 4; - constexpr int num_ints_per_thread = 8 / pack_factor; - int zp_sh_rd; - if constexpr (has_zp) { - if constexpr (is_zp_float) { - if constexpr (group_blocks != -1) { - zp_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - (threadIdx.x % 32) / 4; - } - } else { - zp_sh_rd = num_ints_per_thread * num_col_threads * - ((threadIdx.x / 32) % (thread_n_blocks / 4)) + - num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); - } - } - - // Precompute which thread should not read memory in which iterations; this is - // needed if there are more threads than required for a certain tilesize or - // when the batchsize is not a multiple of 16. - bool a_sh_wr_pred[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; - - // To ensure that writing and reading A tiles to/from shared memory, the - // latter in fragment format, is fully bank conflict free, we need to use a - // rather fancy XOR-based layout. The key here is that neither reads nor - // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the - // same shared memory banks. Further, it seems (based on NSight-Compute) that - // each warp must also write a consecutive memory segment? - auto transform_a = [&](int i) { - int row = i / a_gl_rd_delta_o; - return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; - }; - // Since the computation of this remapping is non-trivial and, due to our main - // loop unrolls, all shared memory accesses are static, we simply precompute - // both transformed reads and writes. - int a_sh_wr_trans[a_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) - a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); - int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < thread_m_blocks; j++) - a_sh_rd_trans[i][j] = - transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); - } - - // Since B-accesses have non-constant stride they have to be computed at - // runtime; we break dependencies between subsequent accesses with a tile by - // maintining multiple pointers (we have enough registers), a tiny - // optimization. - const int4* B_ptr[b_sh_wr_iters]; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; - - extern __shared__ int4 sh[]; - // Shared memory storage for global fetch pipelines. - int4* sh_a = sh; - int4* sh_b = sh_a + (stages * a_sh_stage); - int4* sh_g_idx = sh_b + (stages * b_sh_stage); - int4* sh_zp = sh_g_idx + (stages * g_idx_stage); - int4* sh_s = sh_zp + (stages * zp_sh_stage); - int4* sh_red = sh_s + (stages * s_sh_stage); - - // Register storage for double buffer of shared memory reads. - FragA frag_a[2][thread_m_blocks]; - I4 frag_b_quant[2][b_thread_vecs]; - FragC frag_c[thread_m_blocks][4][2]; - FragS frag_s[2][4]; // No act-order - FragS act_frag_s[2][4][4]; // For act-order - int frag_qzp[2][num_ints_per_thread]; // Zero-points - FragZP frag_zp; // Zero-points in fp16 - FragZP frag_zpf[2]; // Zero-points in fp16 in HQQ - - // Zero accumulators. - auto zero_accums = [&]() { - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) - reinterpret_cast(frag_c)[i] = 0; - }; - - int sh_first_group_id = -1; - int sh_num_groups = -1; - constexpr int sh_max_num_groups = 32; - - auto fetch_scales_to_shared = [&](bool is_async, int first_group_id, - int last_group_id) { - sh_first_group_id = first_group_id; - sh_num_groups = last_group_id - first_group_id + 1; - - if (sh_num_groups < sh_max_num_groups) { - sh_num_groups = sh_max_num_groups; - } - - if (sh_first_group_id + sh_num_groups > num_groups) { - sh_num_groups = num_groups - sh_first_group_id; - } - - int row_offset = first_group_id * s_gl_stride; - - if (is_async) { - for (int i = 0; i < sh_num_groups; i++) { - if (threadIdx.x < s_sh_stride) { - cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], - &scales_ptr[row_offset + (i * s_gl_stride) + - slice_n_offset + threadIdx.x]); - } - } - } else { - for (int i = 0; i < sh_num_groups; i++) { - if (threadIdx.x < s_sh_stride) { - sh_s[(i * s_sh_stride) + threadIdx.x] = - scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + - threadIdx.x]; - } - } - } - }; - // Asynchronously fetch the next A, B and s tile from global to the next - // shared memory pipeline location. - auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { - if (pred) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < a_sh_wr_iters; i++) { - cp_async4_pred( - &sh_a_stage[a_sh_wr_trans[i]], - &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], - a_sh_wr_pred[i]); - } - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) { - #pragma unroll - for (int j = 0; j < b_thread_vecs; j++) { - cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); - } - - B_ptr[i] += b_gl_rd_delta_o; - } - - if constexpr (has_act_order) { - // Fetch g_idx thread-block portion - int full_pipe = a_off; - int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; - if (cur_k < prob_k && cur_k < slice_k_finish) { - int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - - int4 const* cur_g_idx_stage_ptr = - reinterpret_cast(&g_idx[cur_k]); - - if (threadIdx.x < g_idx_stage) { - cp_async4_pred(&sh_g_idx_stage[threadIdx.x], - &cur_g_idx_stage_ptr[threadIdx.x]); - } - } - } else { - if constexpr (group_blocks != -1) { - int4* sh_s_stage = sh_s + s_sh_stage * pipe; - - if constexpr (group_blocks >= thread_k_blocks) { - if (s_sh_wr_pred) { - cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); - } - // Only fetch scales if this tile starts a new group - if ((pipe + 1) % (group_blocks / thread_k_blocks) == 0) { - s_gl_rd += s_gl_rd_delta; - } - } else { - for (int i = 0; i < s_tb_groups; i++) { - if (s_sh_wr_pred) { - cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], - &scales_ptr[s_gl_rd]); - } - s_gl_rd += s_gl_rd_delta; - } - } - } - - if constexpr (has_zp && group_blocks != -1) { - int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; - - if constexpr (group_blocks >= thread_k_blocks) { - // Only fetch zero-points if this tile starts a new group - if (pipe % (group_blocks / thread_k_blocks) == 0) { - if (zp_sh_wr_pred) { - cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); - } - zp_gl_rd += zp_gl_rd_delta; - } - } else { - for (int i = 0; i < zp_tb_groups; i++) { - if (zp_sh_wr_pred) { - cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr], - &zp_ptr[zp_gl_rd]); - } - zp_gl_rd += zp_gl_rd_delta; - } - } - } - } - } - // Insert a fence even when we are winding down the pipeline to ensure that - // waiting is also correct at this point. - cp_async_fence(); - }; - - auto fetch_zp_to_shared = [&]() { - if (zp_sh_wr_pred) { - cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]); - } - }; - - // Wait until the next thread tile has been loaded to shared memory. - auto wait_for_stage = [&]() { - // We only have `stages - 2` active fetches since we are double buffering - // and can only issue the next fetch when it is guaranteed that the previous - // shared memory load is fully complete (as it may otherwise be - // overwritten). - cp_async_wait(); - __syncthreads(); - }; - - // Load the next sub-tile from the current location in the shared memory pipe - // into the current register buffer. - auto fetch_to_registers = [&](int k, int pipe) { - int4* sh_a_stage = sh_a + a_sh_stage * pipe; - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) - ldsm4(frag_a[k % 2][i], - &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); - int4* sh_b_stage = sh_b + b_sh_stage * pipe; - - #pragma unroll - for (int i = 0; i < b_thread_vecs; i++) { - frag_b_quant[k % 2][i] = *reinterpret_cast( - &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); - } - }; - - bool is_same_group[stages]; - int same_group_id[stages]; - - auto init_same_group = [&](int pipe) { - if constexpr (!has_act_order) { - is_same_group[pipe] = false; - same_group_id[pipe] = 0; - return; - } - - int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); - - int group_id_1 = sh_g_idx_int_ptr[0]; - int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; - - is_same_group[pipe] = group_id_1 == group_id_2; - same_group_id[pipe] = group_id_1; - }; - - auto fetch_scales_to_registers = [&](int k, int full_pipe) { - int pipe = full_pipe % stages; - - if constexpr (!has_act_order) { - // No act-order case - if constexpr (group_blocks != -1) { - if constexpr (group_blocks >= thread_k_blocks) { - int4* sh_s_stage = sh_s + s_sh_stage * pipe; - reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; - } else { - auto warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - - int warp_row = warp_id / n_warps; - - int cur_k = warp_row * 16; - cur_k += k_iter_size * (k % b_sh_wr_iters); - - int k_blocks = cur_k / 16; - int cur_group_id = k_blocks / group_blocks; - - int4* sh_s_stage = sh_s + s_sh_stage * pipe; - - reinterpret_cast(&frag_s[k % 2])[0] = - sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; - } - } - - return; - } - - // Act-order case - - // Determine K of the "current" thread-block - int cur_k = slice_k_start + tb_k * full_pipe; - if (cur_k >= prob_k || cur_k >= slice_k_finish) { - return; - } - - // Reset (to current thread-block) since we read g_idx portion from the - // shared memory - cur_k = 0; - - // Progress to current iteration - cur_k += k_iter_size * (k % b_sh_wr_iters); - - // Determine "position" inside the thread-block (based on warp and - // thread-id) - auto warp_id = threadIdx.x / 32; - int n_warps = - thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N - - int warp_row = warp_id / n_warps; - int warp_col = warp_id % n_warps; - - cur_k += warp_row * 16; - - auto th_id = threadIdx.x % 32; - cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix - - int s_col_shift = - /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + - (th_id / 4) * act_s_col_stride; - - if (is_same_group[pipe]) { - if (k % 2 == 0) { - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = - sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + - s_col_shift]; - } else { - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = - *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); - } - - for (int i = 1; i < 4; i++) { - *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = - *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); - } - return; - } - - int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; - int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); - - constexpr int k_frag_offsets[4] = {0, 1, 8, - 9}; // Tensor core offsets per thread - - #pragma unroll - for (int i = 0; i < 4; i++) { - int actual_k = cur_k + k_frag_offsets[i]; - - int group_id = sh_g_idx_int_ptr[actual_k]; - int rel_group_id = group_id - sh_first_group_id; - - *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = - sh_s[rel_group_id * s_sh_stride + s_col_shift]; - } - }; - - auto fetch_zp_to_registers = [&](int k, int full_pipe) { - // This code does not handle group_blocks == 0, - // which signifies act_order. - // has_zp implies AWQ, which doesn't have act_order, - static_assert(!has_zp || group_blocks != 0); - - if constexpr (has_zp && !is_zp_float) { - int pipe = full_pipe % stages; - - if constexpr (group_blocks == -1) { - for (int i = 0; i < num_ints_per_thread; i++) { - frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp))[zp_sh_rd + i]; - } - - } else if constexpr (group_blocks >= thread_k_blocks) { - int4* sh_zp_stage = - sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); - for (int i = 0; i < num_ints_per_thread; i++) { - frag_qzp[k % 2][i] = - (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; - } - } else { - auto warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - - int warp_row = warp_id / n_warps; - - int cur_k = warp_row * 16; - cur_k += k_iter_size * (k % b_sh_wr_iters); - - int k_blocks = cur_k / 16; - int cur_group_id = 0; - - // Suppress bogus and persistent divide-by-zero warning - #pragma nv_diagnostic push - #pragma nv_diag_suppress divide_by_zero - cur_group_id = k_blocks / group_blocks; - #pragma nv_diagnostic pop - - int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; - - sh_zp_stage += cur_group_id * zp_sh_stride; - - for (int i = 0; i < num_ints_per_thread; i++) { - frag_qzp[k % 2][i] = - (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; - } - } - } - - else if constexpr (has_zp && is_zp_float) { - int pipe = full_pipe % stages; - - if constexpr (group_blocks != -1) { - if constexpr (group_blocks >= thread_k_blocks) { - int4* sh_zp_stage = - sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * - (pipe / (group_blocks / thread_k_blocks))); - reinterpret_cast(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd]; - } else { - auto warp_id = threadIdx.x / 32; - int n_warps = thread_n_blocks / 4; - - int warp_row = warp_id / n_warps; - - int cur_k = warp_row * 16; - cur_k += k_iter_size * (k % b_sh_wr_iters); - - int k_blocks = cur_k / 16; - // Suppress bogus and persistent divide-by-zero warning - #pragma nv_diagnostic push - #pragma nv_diag_suppress divide_by_zero - int cur_group_id = k_blocks / group_blocks; - #pragma nv_diagnostic pop - - int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; - - reinterpret_cast(&frag_zpf[k % 2])[0] = - sh_zp_stage[zp_sh_rd + cur_group_id * zp_sh_stride]; - } - } - } - }; - - // Execute the actual tensor core matmul of a sub-tile. - auto matmul = [&](int k) { - if constexpr (has_zp && !is_zp_float) { - FragB frag_zp_0; - FragB frag_zp_1; - int zp_quant_0, zp_quant_1; - - if constexpr (w_type.size_bits() == 4) { - zp_quant_0 = frag_qzp[k % 2][0]; - zp_quant_1 = zp_quant_0 >> 8; - } else { - static_assert(w_type.size_bits() == 8); - zp_quant_0 = frag_qzp[k % 2][0]; - zp_quant_1 = frag_qzp[k % 2][1]; - } - - frag_zp_0 = dequant(zp_quant_0); - frag_zp_1 = dequant(zp_quant_1); - - frag_zp[0] = frag_zp_0[0]; - frag_zp[1] = frag_zp_0[1]; - frag_zp[2] = frag_zp_1[0]; - frag_zp[3] = frag_zp_1[1]; - } - - // We have the m dimension as the inner loop in order to encourage overlapping - // dequantization and matmul operations. - #pragma unroll - for (int j = 0; j < 4; j++) { - FragB frag_b0; - FragB frag_b1; - int b_quant_0, b_quant_1; - - if constexpr (w_type.size_bits() == 4) { - b_quant_0 = frag_b_quant[k % 2][0][j]; - b_quant_1 = b_quant_0 >> 8; - } else { - static_assert(w_type.size_bits() == 8); - int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); - b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; - b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; - } - - frag_b0 = dequant(b_quant_0); - frag_b1 = dequant(b_quant_1); - - // Apply zero-point to frag_b0 - if constexpr (has_zp && !is_zp_float) { - sub_zp(frag_b0, frag_zp[j], 0); - } - - else if constexpr (has_zp && is_zp_float && group_blocks != -1) { - sub_zp(frag_b0, frag_zpf[k % 2][j], 0); - } - - // Apply scale to frag_b0 - if constexpr (has_act_order) { - scale4(frag_b0, act_frag_s[k % 2][0][j], - act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j], - act_frag_s[k % 2][3][j], 0); - } else { - if constexpr (group_blocks != -1) { - scale(frag_b0, frag_s[k % 2][j], 0); - } - } - - // Apply zero-point to frag_b1 - if constexpr (has_zp && !is_zp_float) { - sub_zp(frag_b1, frag_zp[j], 1); - } - - else if constexpr (has_zp && is_zp_float && group_blocks != -1) { - sub_zp(frag_b1, frag_zpf[k % 2][j], 1); - } - - // Apply scale to frag_b1 - if constexpr (has_act_order) { - scale4(frag_b1, act_frag_s[k % 2][0][j], - act_frag_s[k % 2][1][j], act_frag_s[k % 2][2][j], - act_frag_s[k % 2][3][j], 1); - - } else { - if constexpr (group_blocks != -1) { - scale(frag_b1, frag_s[k % 2][j], 1); - } - } - - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); - mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); - } - } - }; - - // Since we slice across the k dimension of a tile in order to increase the - // number of warps while keeping the n dimension of a tile reasonable, we have - // multiple warps that accumulate their partial sums of the same output - // location; which we have to reduce over in the end. We do in shared memory. - auto thread_block_reduce = [&]() { - constexpr int red_off = threads / b_sh_stride_threads / 2; - if (red_off >= 1) { - auto red_idx = threadIdx.x / b_sh_stride_threads; - constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; - constexpr int red_sh_delta = b_sh_stride_threads; - int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + - (threadIdx.x % b_sh_stride_threads); - - // Parallel logarithmic shared memory reduction. We make sure to avoid any - // unnecessary read or write iterations, e.g., for two warps we write only - // once by warp 1 and read only once by warp 0. - - #pragma unroll - for (int m_block = 0; m_block < thread_m_blocks; m_block++) { - #pragma unroll - for (int i = red_off; i > 0; i /= 2) { - if (i <= red_idx && red_idx < 2 * i) { - #pragma unroll - for (int j = 0; j < 4 * 2; j++) { - int red_sh_wr = - red_sh_delta * j + (red_sh_rd - red_sh_stride * i); - if (i < red_off) { - float* c_rd = reinterpret_cast( - &sh_red[red_sh_delta * j + red_sh_rd]); - float* c_wr = reinterpret_cast(&sh_red[red_sh_wr]); - #pragma unroll - for (int k = 0; k < 4; k++) - reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += - c_rd[k] + c_wr[k]; - } - sh_red[red_sh_wr] = - reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; - } - } - __syncthreads(); - } - if (red_idx == 0) { - #pragma unroll - for (int i = 0; i < 4 * 2; i++) { - float* c_rd = - reinterpret_cast(&sh_red[red_sh_delta * i + red_sh_rd]); - #pragma unroll - for (int j = 0; j < 4; j++) - reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += - c_rd[j]; - } - } - __syncthreads(); - } - } - }; - - // Since multiple threadblocks may process parts of the same column slice, we - // finally have to globally reduce over the results. As the striped - // partitioning minimizes the number of such reductions and our outputs are - // usually rather small, we perform this reduction serially in L2 cache. - auto global_reduce_fp16 = [&](bool first = false, bool last = false) { - // We are very careful here to reduce directly in the output buffer to - // maximize L2 cache utilization in this step. To do this, we write out - // results in FP16 (but still reduce with FP32 compute). - constexpr int active_threads = 32 * thread_n_blocks / 4; - if (threadIdx.x < active_threads) { - int c_gl_stride = prob_n / 8; - int c_gl_wr_delta_o = 8 * c_gl_stride; - int c_gl_wr_delta_i = 4 * (active_threads / 32); - int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + - 4 * (threadIdx.x / 32) + threadIdx.x % 4; - c_gl_wr += (2 * thread_n_blocks) * slice_col; - constexpr int c_sh_wr_delta = active_threads; - auto c_sh_wr = threadIdx.x; - - int row = (threadIdx.x % 32) / 4; - - if (!first) { - // Interestingly, doing direct global accesses here really seems to mess up - // the compiler and lead to slowdowns, hence we also use async-copies even - // though these fetches are not actually asynchronous. - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - cp_async4_pred( - &sh_red[c_sh_wr + c_sh_wr_delta * i], - &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + - c_gl_wr_delta_i * (i % 2)], - i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); - } - cp_async_fence(); - cp_async_wait<0>(); - } - - #pragma unroll - for (int i = 0; i < thread_m_blocks * 4; i++) { - if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { - if (!first) { - int4 c_red = sh_red[c_sh_wr + i * c_sh_wr_delta]; - #pragma unroll - for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += - Dtype::num2float(reinterpret_cast(&c_red)[j]); - } - } - if (!last) { - int4 c; - #pragma unroll - for (int j = 0; j < 2 * 4; j++) { - reinterpret_cast(&c)[j] = - Dtype::float2num(reinterpret_cast( - &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); - } - C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = - c; - } - } - } - } - }; - - // Globally reduce over threadblocks that compute the same column block. - // We use a tmp C buffer to reduce in full fp32 precision. - auto global_reduce_fp32 = [&](bool first = false, bool last = false) { - constexpr int tb_m = thread_m_blocks * 16; - constexpr int tb_n = thread_n_blocks * 16; - - constexpr int c_size = tb_m * tb_n * sizeof(float) / 16; - - constexpr int active_threads = 32 * thread_n_blocks / 4; - bool is_th_active = threadIdx.x < active_threads; - - int par_offset = c_size * n_tiles * par_id; - int slice_offset = c_size * slice_col; - - constexpr int num_floats = thread_m_blocks * 4 * 2 * 4; - constexpr int th_size = num_floats * sizeof(float) / 16; - - int c_cur_offset = par_offset + slice_offset; - - if (!is_th_active) { - return; - } - - if (!first) { - float* frag_c_ptr = reinterpret_cast(&frag_c); - #pragma unroll - for (int k = 0; k < th_size; k++) { - sh_red[threadIdx.x] = - C_tmp[c_cur_offset + active_threads * k + threadIdx.x]; - - float* sh_c_ptr = reinterpret_cast(&sh_red[threadIdx.x]); - #pragma unroll - for (int f = 0; f < 4; f++) { - frag_c_ptr[k * 4 + f] += sh_c_ptr[f]; - } - } - } - - if (!last) { - int4* frag_c_ptr = reinterpret_cast(&frag_c); - #pragma unroll - for (int k = 0; k < th_size; k++) { - C_tmp[c_cur_offset + active_threads * k + threadIdx.x] = frag_c_ptr[k]; - } - } - }; - - // Write out the reduce final result in the correct layout. We only actually - // reshuffle matrix fragments in this step, the reduction above is performed - // in fragment layout. - auto write_result = [&]() { - int c_gl_stride = prob_n / 8; - constexpr int c_sh_stride = 2 * thread_n_blocks + 1; - int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); - constexpr int c_sh_rd_delta = - c_sh_stride * (threads / (2 * thread_n_blocks)); - - int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - c_gl_wr += (2 * thread_n_blocks) * slice_col; - int c_sh_wr = - (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; - c_sh_wr += 32 * (threadIdx.x / 32); - int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + - (threadIdx.x % (2 * thread_n_blocks)); - - int c_gl_wr_end = c_gl_stride * prob_m; - - // We first reorder in shared memory to guarantee the most efficient final - // global write patterns - auto write = [&](int idx, float c0, float c1, FragS& s) { - scalar_t2 res = - Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); - - // For per-column quantization we finally apply the scale here (only for - // 4-bit) - if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 4) { - res = __hmul2(res, s[0]); - } - - ((scalar_t2*)sh_red)[idx] = res; - }; - - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - #pragma unroll - for (int j = 0; j < 4; j++) { - int wr = c_sh_wr + 8 * j; - write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], - frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], - frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); - write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], - frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); - write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], - frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); - } - c_sh_wr += 16 * (4 * c_sh_stride); - } - } - __syncthreads(); - - #pragma unroll - for (int i = 0; - i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); - i++) { - if (c_gl_wr < c_gl_wr_end) { - if (use_atomic_add && slice_count > 1) { - scalar_t2* C_half2 = reinterpret_cast(&C[c_gl_wr]); - scalar_t2* sh_red_half2 = - reinterpret_cast(&sh_red[c_sh_rd]); - #pragma unroll - for (int a = 0; a < 4; a++) { - atomicAdd(&C_half2[a], sh_red_half2[a]); - } - } else { - C[c_gl_wr] = sh_red[c_sh_rd]; - } - c_gl_wr += c_gl_wr_delta; - c_sh_rd += c_sh_rd_delta; - } - } - }; - - // Start global fetch and register load pipelines. - auto start_pipes = [&]() { - - #pragma unroll - for (int i = 0; i < stages - 1; i++) { - if (has_act_order && i == 0) { - int last_g_idx = slice_k_start + stages * tb_k * 2; - if (last_g_idx >= prob_k) { - last_g_idx = prob_k - 1; - } - fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); - } - - if constexpr (has_zp && !is_zp_float && group_blocks == -1) { - if (i == 0) { - fetch_zp_to_shared(); - } - } - fetch_to_shared(i, i, i < slice_iters); - } - - zero_accums(); - wait_for_stage(); - init_same_group(0); - fetch_to_registers(0, 0); - fetch_scales_to_registers(0, 0); - fetch_zp_to_registers(0, 0); - a_gl_rd += a_gl_rd_delta_o * (stages - 1); - slice_k_start_shared_fetch += tb_k * (stages - 1); - }; - if (slice_iters) { - start_pipes(); - } - - // Main loop. - while (slice_iters) { - // We unroll over both the global fetch and the register load pipeline to - // ensure all shared memory accesses are static. Note that both pipelines - // have even length meaning that the next iteration will always start at - // index 0. - - #pragma unroll - for (int pipe = 0; pipe < stages;) { - #pragma unroll - for (int k = 0; k < b_sh_wr_iters; k++) { - fetch_to_registers(k + 1, pipe % stages); - fetch_scales_to_registers(k + 1, pipe); - fetch_zp_to_registers(k + 1, pipe); - if (k == b_sh_wr_iters - 2) { - fetch_to_shared((pipe + stages - 1) % stages, pipe, - slice_iters >= stages); - pipe++; - wait_for_stage(); - init_same_group(pipe % stages); - } - matmul(k); - } - slice_iters--; - if (slice_iters == 0) { - break; - } - } - - a_gl_rd += a_gl_rd_delta_o * stages; - slice_k_start += tb_k * stages; - slice_k_start_shared_fetch += tb_k * stages; - - if constexpr (has_act_order) { - int first_group_id = g_idx[slice_k_start]; - int last_g_idx = slice_k_start + stages * tb_k * 2; - if (last_g_idx >= prob_k) { - last_g_idx = prob_k - 1; - } - int last_group_id = g_idx[last_g_idx]; - if (last_group_id >= sh_first_group_id + sh_num_groups) { - fetch_scales_to_shared(false, first_group_id, last_group_id); - __syncthreads(); - } - } - - // Process results and, if necessary, proceed to the next column slice. - // While this pattern may not be the most readable, other ways of writing - // the loop seemed to noticeably worse performance after compilation. - if (slice_iters == 0) { - cp_async_wait<0>(); - bool last = slice_idx == slice_count - 1; - // For per-column scales, we only fetch them here in the final step before - // write-out - if constexpr (!has_act_order && group_blocks == -1) { - if constexpr (w_type.size_bits() == 8) { - if (s_sh_wr_pred) { - cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); - } - cp_async_fence(); - } else { - if (last || use_atomic_add) { - if (s_sh_wr_pred) { - cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); - } - cp_async_fence(); - } - } - } - - thread_block_reduce(); - if constexpr (!has_act_order && group_blocks == -1) { - if constexpr (w_type.size_bits() == 8) { - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; - } - - } else { - if (last || use_atomic_add) { - cp_async_wait<0>(); - __syncthreads(); - if (threadIdx.x / 32 < thread_n_blocks / 4) { - reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; - reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; - } - } - } - } - - // For 8-bit channelwise, we apply the scale before the global reduction - // that converts the fp32 results to fp16 (so that we avoid possible - // overflow in fp16) - if constexpr (!has_act_order && group_blocks == -1 && - w_type.size_bits() == 8) { - if (threadIdx.x / 32 < thread_n_blocks / 4) { - #pragma unroll - for (int i = 0; i < thread_m_blocks; i++) { - #pragma unroll - for (int j = 0; j < 4; j++) { - scale_float( - reinterpret_cast(&frag_c[i][j][0][0]), - frag_s[j / 2][2 * (j % 2) + 0]); - scale_float( - reinterpret_cast(&frag_c[i][j][0][2]), - frag_s[j / 2][2 * (j % 2) + 0]); - - scale_float( - reinterpret_cast(&frag_c[i][j][1][0]), - frag_s[j / 2][2 * (j % 2) + 1]); - scale_float( - reinterpret_cast(&frag_c[i][j][1][2]), - frag_s[j / 2][2 * (j % 2) + 1]); - } - } - } - } - - if (slice_count > 1 && !use_atomic_add) { - // only globally reduce if there is more than one block in a slice - barrier_acquire(&locks[slice_col], slice_idx); - if (use_fp32_reduce) { - global_reduce_fp32(slice_idx == 0, last); - } else { - global_reduce_fp16(slice_idx == 0, last); - } - barrier_release(&locks[slice_col], last); - } - if (last || use_atomic_add) - // only the last block in a slice actuallywrites the result - write_result(); - slice_row = 0; - slice_col_par++; - slice_col++; - init_slice(); - if (slice_iters) { - a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + - (threadIdx.x % a_gl_rd_delta_o); - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) - B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; - if (slice_col == 0) { - #pragma unroll - for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; - } - - // Update slice k/n for scales loading - if constexpr (has_act_order) { - slice_k_start = tb_k * slice_row; - slice_k_finish = slice_k_start + tb_k * slice_iters; - slice_k_start_shared_fetch = slice_k_start; - slice_n_offset = act_s_col_tb_stride * slice_col; - - } else { - s_gl_rd = s_sh_stride * slice_col + threadIdx.x; - zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; - } - - start_pipes(); - } - } - } -} - - #define __CALL_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ - HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, NUM_THREADS, \ - IS_ZP_FLOAT) \ - else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ - thread_n_blocks == THREAD_N_BLOCKS && \ - thread_k_blocks == THREAD_K_BLOCKS && \ - has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \ - group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \ - is_zp_float == IS_ZP_FLOAT) { \ - if constexpr (!IS_ZP_FLOAT || std::is_same::value) { \ - cudaFuncSetAttribute( \ - Marlin, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - Marlin \ - <<>>( \ - A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, \ - num_groups, prob_m, prob_n, prob_k, lda, locks, \ - part_use_atomic_add, use_fp32_reduce); \ - } \ - } - typedef struct { int thread_k; int thread_n; int num_threads; } thread_config_t; -typedef struct { - int max_m_blocks; - thread_config_t tb_cfg; -} exec_config_t; - thread_config_t small_batch_thread_configs[] = { // Ordered by priority // thread_k, thread_n, num_threads {128, 128, 256}, {64, 128, 128}, - {128, 64, 128}, -}; + {128, 64, 128}}; thread_config_t large_batch_thread_configs[] = { // Ordered by priority @@ -1815,9 +138,12 @@ thread_config_t large_batch_thread_configs[] = { // thread_k, thread_n, num_threads {64, 256, 256}, {64, 128, 128}, - {128, 64, 128}, + {128, 64, 128}}; -}; +typedef struct { + int blocks_per_sm; + thread_config_t tb_cfg; +} exec_config_t; int get_scales_cache_size(thread_config_t const& th_config, int prob_m, int prob_n, int prob_k, int num_bits, int group_size, @@ -1842,7 +168,6 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m, tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K load_groups = max(load_groups, 32); // We load at least 32 scale groups return load_groups * tb_n * 2; - } else { int tb_scales = tb_groups * tb_n * 2; @@ -1850,49 +175,43 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m, } } -bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks, - int prob_m, int prob_n, int prob_k, int num_bits, - int scales_cache_size, int max_shared_mem) { +int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks, + int prob_m, int prob_n, int prob_k, int num_bits, + int group_size, bool has_act_order, bool is_k_full, + int has_zp, int is_zp_float) { int pack_factor = 32 / num_bits; // Get B size int tb_k = th_config.thread_k; int tb_n = th_config.thread_n; - - int b_size = (tb_k * tb_n / pack_factor) * 4; - - // Get A size - int m_blocks = div_ceil(prob_m, 16); - int tb_max_m = 16; - - while (true) { - if (m_blocks >= max_m_blocks) { - tb_max_m *= max_m_blocks; - break; - } - - max_m_blocks--; - if (max_m_blocks == 0) { - TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks); - } + int tb_m = thread_m_blocks * 16; + int sh_a_size = pipe_stages * (tb_m * tb_k) * 2; + int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4; + int sh_red_size = tb_m * (tb_n + 8); + int sh_s_size = + get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, + group_size, has_act_order, is_k_full); + int sh_g_idx_size = has_act_order && !is_k_full ? pipe_stages * tb_k / 4 : 0; + int sh_zp_size = 0; + if (has_zp) { + if (is_zp_float) + sh_zp_size = sh_s_size; + else if (num_bits == 4) + sh_zp_size = sh_s_size / 4; + else if (num_bits == 8) + sh_zp_size = sh_s_size / 2; } - int a_size = (tb_max_m * tb_k) * 2; - - float pipe_size = (a_size + b_size) * pipe_stages; - - float reduce_size = max(th_config.num_threads * 32 * 4, - (tb_n / 64) * 32 * (tb_max_m / 16) * 4 * 2 * 4 * 2); + int total_size = max(sh_b_size, sh_red_size) + sh_a_size + sh_s_size + + sh_zp_size + sh_g_idx_size; - TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity - - return pipe_size + reduce_size < 0.95f * (max_shared_mem - scales_cache_size); + return total_size; } -bool is_valid_config(thread_config_t const& th_config, int max_m_blocks, +bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks, int prob_m, int prob_n, int prob_k, int num_bits, int group_size, bool has_act_order, bool is_k_full, - int max_shared_mem) { + int has_zp, int is_zp_float, int max_shared_mem) { // Sanity if (th_config.thread_k == -1 || th_config.thread_n == -1 || th_config.num_threads == -1) { @@ -1914,242 +233,250 @@ bool is_valid_config(thread_config_t const& th_config, int max_m_blocks, return false; } - // Determine cache for scales - int scales_cache_size = - get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, - group_size, has_act_order, is_k_full); - // Check that pipeline fits into cache - if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k, - num_bits, scales_cache_size, max_shared_mem)) { - return false; - } - - return true; + int cache_size = get_kernel_cache_size( + th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size, + has_act_order, is_k_full, has_zp, is_zp_float); + return cache_size <= max_shared_mem; } -int determine_reduce_max_m(int prob_m, int max_par) { - constexpr int tile_m_size = 16; + #define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ + M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \ + else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ + thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + m_block_size_8 == M_BLOCK_SIZE_8 && \ + group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \ + is_zp_float == IS_ZP_FLOAT) { \ + kernel = Marlin; \ + } + + // COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false) + // this is the most common cases + // BIGGROUP: cases for big group size (group_blocks in [-1, 8]) + // FZP: cases for float-zero-point (is_zp_float = true) + // ACT: cases for act order case (group_blocks == 0) + // FP4: cases for nvfp4(e2m1) (group_blocks == 1) + #define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + + #define COMMON_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + + #define COMMON_GET_IF(W_TYPE) \ + COMMON_GET_IF_M1(W_TYPE, 8, 8, 256) \ + COMMON_GET_IF_M1(W_TYPE, 8, 4, 128) \ + COMMON_GET_IF_M1(W_TYPE, 4, 8, 128) \ + COMMON_GET_IF_M234(W_TYPE, 16, 4, 256) \ + COMMON_GET_IF_M234(W_TYPE, 8, 4, 128) \ + COMMON_GET_IF_M234(W_TYPE, 4, 8, 128) + + #define BIGGROUP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + + #define BIGGROUP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + + #define BIGGROUP_GET_IF(W_TYPE) \ + BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \ + BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \ + BIGGROUP_GET_IF_M1(W_TYPE, 4, 8, 128) \ + BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \ + BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128) \ + BIGGROUP_GET_IF_M234(W_TYPE, 4, 8, 128) + + #define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) + + #define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) + + #define FP4_GET_IF(W_TYPE) \ + FP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ + FP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ + FP4_GET_IF_M1(W_TYPE, 4, 8, 128) \ + FP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ + FP4_GET_IF_M234(W_TYPE, 8, 4, 128) \ + FP4_GET_IF_M234(W_TYPE, 4, 8, 128) + + // We currently have 4-bit models only with group_blocks == 4 + #define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) + + #define FZP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) + + #define FZP_GET_IF(W_TYPE) \ + FZP_GET_IF_M1(W_TYPE, 8, 8, 256) \ + FZP_GET_IF_M1(W_TYPE, 8, 4, 128) \ + FZP_GET_IF_M1(W_TYPE, 4, 8, 128) \ + FZP_GET_IF_M234(W_TYPE, 16, 4, 256) \ + FZP_GET_IF_M234(W_TYPE, 8, 4, 128) \ + FZP_GET_IF_M234(W_TYPE, 4, 8, 128) + + // We currently have 4-bit models only with group_blocks == 4 + #define ACT_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) + + #define ACT_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) + + #define ACT_GET_IF(W_TYPE) \ + ACT_GET_IF_M1(W_TYPE, 8, 8, 256) \ + ACT_GET_IF_M1(W_TYPE, 8, 4, 128) \ + ACT_GET_IF_M1(W_TYPE, 4, 8, 128) \ + ACT_GET_IF_M234(W_TYPE, 16, 4, 256) \ + ACT_GET_IF_M234(W_TYPE, 8, 4, 128) \ + ACT_GET_IF_M234(W_TYPE, 4, 8, 128) - if (prob_m <= tile_m_size) { - return tile_m_size; +template +MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type, + int thread_m_blocks, int thread_n_blocks, + int thread_k_blocks, bool m_block_size_8, + bool has_act_order, bool has_zp, + int group_blocks, int num_threads, + bool is_zp_float) { + int num_bits = q_type.size_bits(); + auto kernel = MarlinDefault; + if (false) { + } - } else if (prob_m <= tile_m_size * 2) { - return tile_m_size * 2; + COMMON_GET_IF(vllm::kU4) + COMMON_GET_IF(vllm::kU4B8) + COMMON_GET_IF(vllm::kU8B128) - } else if (prob_m <= tile_m_size * 3) { - return tile_m_size * 3; + FP4_GET_IF(vllm::kFE2M1f) - } else if (prob_m <= tile_m_size * 4) { - return tile_m_size * 4; + BIGGROUP_GET_IF(vllm::kFE4M3fn) - } else { - int cur_par = min(div_ceil(prob_m, tile_m_size * 4), max_par); - return tile_m_size * 4 * cur_par; + ACT_GET_IF(vllm::kU4B8) + ACT_GET_IF(vllm::kU8B128) + + if (std::is_same::value) { + if (false) { + } + FZP_GET_IF(vllm::kU4) } + + return kernel; } -exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, - int num_bits, int group_size, - bool has_act_order, bool is_k_full, - int max_shared_mem) { - int max_m_blocks = 4; - while (max_m_blocks > 0) { - if (prob_m <= 16) { - for (auto th_config : small_batch_thread_configs) { - if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, - num_bits, group_size, has_act_order, is_k_full, - max_shared_mem)) { - return exec_config_t{max_m_blocks, th_config}; - } - } - } else { - for (auto th_config : large_batch_thread_configs) { - if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, - num_bits, group_size, has_act_order, is_k_full, - max_shared_mem)) { - return exec_config_t{max_m_blocks, th_config}; - } - } +template +exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m, + int prob_n, int prob_k, int thread_m_blocks, + bool m_block_size_8, int num_bits, + int group_size, bool has_act_order, + bool is_k_full, bool has_zp, + bool is_zp_float, int max_shared_mem, + int sms) { + exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}}; + thread_config_t* thread_configs = thread_m_blocks > 1 + ? large_batch_thread_configs + : small_batch_thread_configs; + int thread_configs_size = + thread_m_blocks > 1 + ? sizeof(large_batch_thread_configs) / sizeof(thread_config_t) + : sizeof(small_batch_thread_configs) / sizeof(thread_config_t); + + for (int i = 0; i < thread_configs_size; i++) { + thread_config_t th_config = thread_configs[i]; + + if (!is_valid_config(th_config, thread_m_blocks, prob_m, prob_n, prob_k, + num_bits, group_size, has_act_order, is_k_full, has_zp, + is_zp_float, max_shared_mem)) { + continue; } - max_m_blocks--; // Process less M blocks per invocation to reduce cache - // usage - } + int cache_size = get_kernel_cache_size( + th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, + group_size, has_act_order, is_k_full, has_zp, is_zp_float); - return exec_config_t{0, {-1, -1, -1}}; -} + int group_blocks = 0; + if (!has_act_order) { + group_blocks = group_size == -1 ? -1 : group_size / 16; + } - #define GPTQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS, \ - false) \ - __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS, \ - false) \ - __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS, \ - false) \ - __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS, \ - false) \ - \ - __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS, \ - false) \ - __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS, \ - false) \ - __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS, \ - false) \ - __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS, \ - false) \ - \ - __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS, \ - false) \ - __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS, \ - false) \ - __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS, \ - false) \ - __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS, \ - false) \ - \ - __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS, \ - false) \ - __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS, \ - false) \ - __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS, \ - false) \ - __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS, \ - false) \ - \ - __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS, \ - false) \ - __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS, \ - false) \ - __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS, \ - false) \ - __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS, \ - false) - - #define AWQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \ - false) \ - __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \ - false) \ - __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \ - false) \ - __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, \ - false) \ - \ - __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \ - false) \ - __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \ - false) \ - __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \ - false) \ - __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, \ - false) \ - \ - __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \ - false) \ - __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \ - false) \ - __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \ - false) \ - __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, \ - false) \ - \ - __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS, \ - false) \ - __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS, \ - false) \ - __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \ - false) \ - __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS, false) + auto kernel = get_marlin_kernel( + q_type, thread_m_blocks, th_config.thread_n / 16, + th_config.thread_k / 16, m_block_size_8, has_act_order, has_zp, + group_blocks, th_config.num_threads, is_zp_float); - // We currently have 4-bit models only with group_blocks == 4 - #define HQQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - __CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \ - true) \ - __CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \ - true) \ - __CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, \ - true) \ - __CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS, true) + if (kernel == MarlinDefault) continue; + + // int m_tiles = div_ceil(prob_m, thread_m_blocks * 16); + // int n_tiles = prob_n / th_config.thread_n; + // int k_tiles = prob_k / th_config.thread_k; + + return {1, th_config}; + } + + return exec_cfg; +} template void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, - void* zp, void* g_idx, void* perm, void* a_tmp, int prob_m, - int prob_n, int prob_k, int lda, void* workspace, + void* s2, void* zp, void* g_idx, void* perm, void* a_tmp, + int prob_m, int prob_n, int prob_k, int lda, void* workspace, vllm::ScalarType const& q_type, bool has_act_order, bool is_k_full, bool has_zp, int num_groups, int group_size, - int dev, cudaStream_t stream, int thread_k, int thread_n, - int sms, int max_par, bool use_atomic_add, bool use_fp32_reduce, - bool is_zp_float) { + int dev, cudaStream_t stream, int thread_k_init, + int thread_n_init, int sms, bool use_atomic_add, + bool use_fp32_reduce, bool is_zp_float) { if (has_zp) { TORCH_CHECK( q_type == vllm::kU4 || q_type == vllm::kU8, "q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str()); } else { TORCH_CHECK( - q_type == vllm::kU4B8 || q_type == vllm::kU8B128, - "q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ", + q_type == vllm::kU4B8 || q_type == vllm::kU8B128 || + q_type == vllm::kFE4M3fn || q_type == vllm::kFE2M1f, + "q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when " + "has_zp = False. Got = ", q_type.str()); } TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, ", ", prob_n, ", ", prob_k, "]"); - // TODO: remove alias when we start supporting other 8bit types - int num_bits = q_type.size_bits(); - int tot_m = prob_m; - int tot_m_blocks = div_ceil(tot_m, 16); - int pad = 16 * tot_m_blocks - tot_m; - - if (sms == -1) { - cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); - } - - int max_shared_mem = 0; - cudaDeviceGetAttribute(&max_shared_mem, - cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); - TORCH_CHECK(max_shared_mem > 0); - - // Set thread config - exec_config_t exec_cfg; - if (thread_k != -1 && thread_n != -1) { - // User-defined config - exec_cfg = - exec_config_t{4, thread_config_t{thread_k, thread_n, default_threads}}; - } else { - // Auto config - exec_cfg = - determine_thread_config(prob_m, prob_n, prob_k, num_bits, group_size, - has_act_order, is_k_full, max_shared_mem); - } - - TORCH_CHECK(exec_cfg.max_m_blocks > 0 && - is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks, - prob_m, prob_n, prob_k, num_bits, group_size, - has_act_order, is_k_full, max_shared_mem), - "Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks, - ", thread_k = ", exec_cfg.tb_cfg.thread_k, - ", thread_n = ", exec_cfg.tb_cfg.thread_n, - ", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [", - prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits, - ", group_size = ", group_size, - ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full, - ", max_shared_mem = ", max_shared_mem); - - int num_threads = exec_cfg.tb_cfg.num_threads; - thread_k = exec_cfg.tb_cfg.thread_k; - thread_n = exec_cfg.tb_cfg.thread_n; - - int thread_k_blocks = thread_k / 16; - int thread_n_blocks = thread_n / 16; - - int blocks = sms; - - TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, - " is not divisible by thread_n = ", thread_n); - TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, - " is not divisible by thread_k = ", thread_k); - int group_blocks = 0; if (has_act_order) { if (is_k_full) { @@ -2161,7 +488,6 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, TORCH_CHECK(group_size == 0); group_blocks = 0; } - } else { if (group_size == -1) { group_blocks = -1; @@ -2172,11 +498,13 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, } } + int num_bits = q_type.size_bits(); const int4* A_ptr = (const int4*)A; const int4* B_ptr = (const int4*)B; int4* C_ptr = (int4*)C; int4* C_tmp_ptr = (int4*)C_tmp; const int4* s_ptr = (const int4*)s; + const uint16_t* s2_ptr = (const uint16_t*)s2; const int4* zp_ptr = (const int4*)zp; const int* g_idx_ptr = (const int*)g_idx; const int* perm_ptr = (const int*)perm; @@ -2186,106 +514,139 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s, if (has_act_order) { // Permute A columns - int block_rows = div_ceil(prob_m, blocks); - permute_cols_kernel<<>>( + int block_rows = div_ceil(prob_m, sms); + // avoid ">>>" being formatted to "> > >" + // clang-format off + permute_cols_kernel<<>>( A_ptr, perm_ptr, a_tmp_ptr, prob_m, prob_k, lda, block_rows); + // clang-format on A_ptr = a_tmp_ptr; lda = prob_k; - } - // If we have a full K, then we can run the non-act-order version of Marlin - // (since the weight rows are reordered by increasing group ids, and by having - // a full K, we have full original groups) - if (is_k_full) { - has_act_order = false; + // If we have a full K, then we can run the non-act-order version of Marlin + // (since the weight rows are reordered by increasing group ids, and by + // having a full K, we have full original groups) + if (is_k_full) has_act_order = false; } - // Main loop - for (int i = 0; i < tot_m_blocks; i += exec_cfg.max_m_blocks) { - int thread_m_blocks = tot_m_blocks - i; - prob_m = tot_m - 16 * i; - int par = 1; - if (thread_m_blocks > exec_cfg.max_m_blocks) { - // Note that parallel > 1 currently only works for inputs without any - // padding - par = (16 * thread_m_blocks - pad) / (16 * exec_cfg.max_m_blocks); - if (par > max_par) par = max_par; - prob_m = (16 * exec_cfg.max_m_blocks) * par; - i += exec_cfg.max_m_blocks * (par - 1); - thread_m_blocks = exec_cfg.max_m_blocks; + int max_shared_mem = 0; + cudaDeviceGetAttribute(&max_shared_mem, + cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + TORCH_CHECK(max_shared_mem > 0); + + int max_par = 16; + if (prob_n <= 4096) max_par = 16 * 8; + int max_shared_mem_new = max_shared_mem; + int rest_m = prob_m; + int max_thread_m_blocks = 4; + while (rest_m) { + int par_count = rest_m / (max_thread_m_blocks * 16); + if (par_count > max_par) par_count = max_par; + int prob_m_split = + par_count > 0 ? (par_count * (max_thread_m_blocks * 16)) : rest_m; + + int thread_k = thread_k_init; + int thread_n = thread_n_init; + + int thread_m_blocks = min(div_ceil(prob_m_split, 16), max_thread_m_blocks); + int m_block_size_8 = prob_m_split <= 8; + + // Set thread config + exec_config_t exec_cfg; + thread_config_t thread_tfg; + if (thread_k != -1 && thread_n != -1) { + thread_tfg = thread_config_t{thread_k, thread_n, default_threads}; + exec_cfg = exec_config_t{1, thread_tfg}; + TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, + " is not divisible by thread_n = ", thread_n); + TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, + " is not divisible by thread_k = ", thread_k); + } else { + // Auto config + exec_cfg = determine_exec_config( + q_type, prob_m_split, prob_n, prob_k, thread_m_blocks, m_block_size_8, + num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float, + max_shared_mem, sms); + thread_tfg = exec_cfg.tb_cfg; + if (thread_tfg.thread_k == -1 && max_thread_m_blocks > 1) { + max_thread_m_blocks--; + continue; + } } - // atomic add reduce have better performance only when m * n is small - bool part_use_atomic_add = - use_atomic_add && div_ceil(prob_m, 64) * prob_n <= 2048; + int num_threads = thread_tfg.num_threads; + thread_k = thread_tfg.thread_k; + thread_n = thread_tfg.thread_n; + int blocks = sms * exec_cfg.blocks_per_sm; + if (exec_cfg.blocks_per_sm > 1) + max_shared_mem_new = max_shared_mem / exec_cfg.blocks_per_sm - 1024; - if (false) { - } - GPTQ_CALL_IF(vllm::kU4B8, 16, 4, 256) - GPTQ_CALL_IF(vllm::kU4B8, 8, 8, 256) - GPTQ_CALL_IF(vllm::kU4B8, 8, 4, 128) - GPTQ_CALL_IF(vllm::kU4B8, 4, 8, 128) - GPTQ_CALL_IF(vllm::kU8B128, 16, 4, 256) - GPTQ_CALL_IF(vllm::kU8B128, 8, 8, 256) - GPTQ_CALL_IF(vllm::kU8B128, 8, 4, 128) - GPTQ_CALL_IF(vllm::kU8B128, 4, 8, 128) - - AWQ_CALL_IF(vllm::kU4, 16, 4, 256) - AWQ_CALL_IF(vllm::kU4, 8, 8, 256) - AWQ_CALL_IF(vllm::kU4, 8, 4, 128) - AWQ_CALL_IF(vllm::kU4, 4, 8, 128) - AWQ_CALL_IF(vllm::kU8, 16, 4, 256) - AWQ_CALL_IF(vllm::kU8, 8, 8, 256) - AWQ_CALL_IF(vllm::kU8, 8, 4, 128) - AWQ_CALL_IF(vllm::kU8, 4, 8, 128) - - HQQ_CALL_IF(vllm::kU4, 16, 4, 256) - HQQ_CALL_IF(vllm::kU4, 8, 8, 256) - HQQ_CALL_IF(vllm::kU4, 8, 4, 128) - HQQ_CALL_IF(vllm::kU4, 4, 8, 128) - else { + int thread_k_blocks = thread_k / 16; + int thread_n_blocks = thread_n / 16; + + TORCH_CHECK( + is_valid_config(thread_tfg, thread_m_blocks, prob_m_split, prob_n, + prob_k, num_bits, group_size, has_act_order, is_k_full, + has_zp, is_zp_float, max_shared_mem_new), + "Invalid thread config: thread_m_blocks = ", thread_m_blocks, + ", thread_k = ", thread_tfg.thread_k, + ", thread_n = ", thread_tfg.thread_n, + ", num_threads = ", thread_tfg.num_threads, " for MKN = [", prob_m, + ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits, + ", prob_m_split = ", prob_m_split, ", group_size = ", group_size, + ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full, + ", has_zp = ", has_zp, ", is_zp_float = ", is_zp_float, + ", max_shared_mem_new = ", max_shared_mem_new); + + auto kernel = get_marlin_kernel( + q_type, thread_m_blocks, thread_n_blocks, thread_k_blocks, + m_block_size_8, has_act_order, has_zp, group_blocks, num_threads, + is_zp_float); + + if (kernel == MarlinDefault) { TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n, ", ", prob_k, "]", ", has_act_order = ", has_act_order, ", num_groups = ", num_groups, ", group_size = ", group_size, + ", prob_m_split = ", prob_m_split, ", thread_m_blocks = ", thread_m_blocks, ", thread_n_blocks = ", thread_n_blocks, ", thread_k_blocks = ", thread_k_blocks, - ", num_bits = ", num_bits); + ", num_threads = ", num_threads, ", num_bits = ", num_bits); } - A_ptr += 16 * thread_m_blocks * (lda / 8) * par; - C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par; + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + max_shared_mem_new); + + bool part_use_atomic_add = + use_atomic_add && div_ceil(prob_m_split, 64) * prob_n <= 2048; + + // avoid ">>>" being formatted to "> > >" + // clang-format off + kernel<<>>( + A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, s2_ptr, zp_ptr, g_idx_ptr, num_groups, + prob_m_split, prob_n, prob_k, lda, locks, part_use_atomic_add, + use_fp32_reduce, max_shared_mem_new); + // clang-format on + + A_ptr += prob_m_split * (lda / 8); + C_ptr += prob_m_split * (prob_n / 8); + rest_m -= prob_m_split; } } } // namespace marlin -torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_scales, torch::Tensor& b_zeros, - torch::Tensor& g_idx, torch::Tensor& perm, - torch::Tensor& workspace, - vllm::ScalarTypeId const& b_q_type_id, - int64_t size_m, int64_t size_n, int64_t size_k, - bool is_k_full, bool has_zp, bool use_atomic_add, - bool use_fp32_reduce, bool is_zp_float) { +torch::Tensor gptq_marlin_gemm( + torch::Tensor& a, std::optional c_or_none, + torch::Tensor& b_q_weight, torch::Tensor& b_scales, + std::optional const& global_scale_or_none, + std::optional const& b_zeros_or_none, + std::optional const& g_idx_or_none, + std::optional const& perm_or_none, torch::Tensor& workspace, + vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n, + int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce, + bool is_zp_float) { vllm::ScalarType const b_q_type = vllm::ScalarType::from_id(b_q_type_id); - if (has_zp) { - TORCH_CHECK( - b_q_type == vllm::kU4 || b_q_type == vllm::kU8, - "b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str()); - } else { - TORCH_CHECK( - b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128, - "b_q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ", - b_q_type.str()); - } - - if (has_zp && is_zp_float) { - TORCH_CHECK(a.scalar_type() == at::ScalarType::Half, - "Computation type must be float16 (half) when using float zero " - "points."); - } - int pack_factor = 32 / b_q_type.size_bits(); // Verify A @@ -2295,15 +656,19 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, ", size_k = ", size_k); // Verify B - TORCH_CHECK(size_k % marlin::tile_size == 0, "size_k = ", size_k, - " is not divisible by tile_size = ", marlin::tile_size); - TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0), + TORCH_CHECK( + size_k % MARLIN_NAMESPACE_NAME::tile_size == 0, "size_k = ", size_k, + " is not divisible by tile_size = ", MARLIN_NAMESPACE_NAME::tile_size); + TORCH_CHECK((size_k / MARLIN_NAMESPACE_NAME::tile_size) == b_q_weight.size(0), "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0), - ", size_k = ", size_k, ", tile_size = ", marlin::tile_size); - TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0, - "b_q_weight.size(1) = ", b_q_weight.size(1), - " is not divisible by tile_size = ", marlin::tile_size); - int actual_size_n = (b_q_weight.size(1) / marlin::tile_size) * pack_factor; + ", size_k = ", size_k, + ", tile_size = ", MARLIN_NAMESPACE_NAME::tile_size); + TORCH_CHECK( + b_q_weight.size(1) % MARLIN_NAMESPACE_NAME::tile_size == 0, + "b_q_weight.size(1) = ", b_q_weight.size(1), + " is not divisible by tile_size = ", MARLIN_NAMESPACE_NAME::tile_size); + int actual_size_n = + (b_q_weight.size(1) / MARLIN_NAMESPACE_NAME::tile_size) * pack_factor; TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n, ", actual_size_n = ", actual_size_n); @@ -2320,63 +685,47 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); - TORCH_CHECK(b_zeros.device().is_cuda(), "b_zeros is not on GPU"); - TORCH_CHECK(b_zeros.is_contiguous(), "b_zeros is not contiguous"); - - TORCH_CHECK(g_idx.device().is_cuda(), "g_idx is not on GPU"); - TORCH_CHECK(g_idx.is_contiguous(), "g_idx is not contiguous"); - - TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU"); - TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous"); + // thread_k: `k` size of a thread_tile in `weights` (can usually be left as + // auto -1) + int thread_k = -1; + // thread_n: `n` size of a thread_tile in `weights` (can usually be left as + // auto -1) + int thread_n = -1; + // sms: number of SMs to use for the kernel + int sms = -1; + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device()); // Alloc buffers const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); torch::Tensor c; - if (use_atomic_add) { - c = torch::zeros({size_m, size_n}, options); + if (c_or_none.has_value()) { + c = c_or_none.value(); + TORCH_CHECK(c.device().is_cuda(), "c is not on GPU"); + TORCH_CHECK(c.is_contiguous(), "c is not contiguous"); + TORCH_CHECK(c.size(0) == size_m, "Shape mismatch: c.size(0) = ", c.size(0), + ", size_m = ", size_m); + TORCH_CHECK(c.size(1) == size_n, "Shape mismatch: c.size(1) = ", c.size(1), + ", size_n = ", size_n); } else { c = torch::empty({size_m, size_n}, options); } - - torch::Tensor a_tmp; - bool has_act_order = g_idx.size(0) != 0; - if (has_act_order) { - a_tmp = torch::empty({size_m, size_k}, options); - } else { - a_tmp = torch::empty({0}, options); - } + if (size_m == 0) return c; // Alloc C tmp buffer that is going to be used for the global reduce torch::Tensor c_tmp; - int reduce_max_m = marlin::determine_reduce_max_m(size_m, marlin::max_par); - int reduce_n = size_n; auto options_fp32 = torch::TensorOptions().dtype(at::kFloat).device(a.device()); if (use_fp32_reduce) { - c_tmp = torch::empty({reduce_max_m, reduce_n}, options_fp32); + int max_m_block_size = (size_m + 16 - 1) / 16 * 16; + max_m_block_size = min(max_m_block_size, 64); + int max_c_tmp_size = + sms * max_m_block_size * MARLIN_NAMESPACE_NAME::max_thread_n; + c_tmp = torch::empty({max_c_tmp_size}, options_fp32); } else { - reduce_max_m = 0; - reduce_n = 0; c_tmp = torch::empty({0}, options_fp32); } - // thread_k: `k` size of a thread_tile in `weights` (can usually be left as - // auto -1) - int thread_k = -1; - // thread_n: `n` size of a thread_tile in `weights` (can usually be left as - // auto -1) - int thread_n = -1; - // sms: number of SMs to use for the kernel (can usually be left as auto -1) - int sms = -1; - - // Verify g_idx and perm - TORCH_CHECK((g_idx.size(0) == 0 && perm.size(0) == 0) || - (g_idx.size(0) == size_k && perm.size(0) == size_k), - "Unexpected g_idx.size(0) = ", g_idx.size(0), - " and perm.size(0) = ", perm.size(0), - ", where size_k = ", size_k); - // Detect groupsize and act_order int num_groups = -1; int group_size = -1; @@ -2387,7 +736,31 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, " is not size_n = ", size_n); num_groups = b_scales.size(0); + torch::Tensor g_idx, perm, a_tmp; + if (g_idx_or_none.has_value() && perm_or_none.has_value()) { + g_idx = g_idx_or_none.value(); + perm = perm_or_none.value(); + + TORCH_CHECK(g_idx.device().is_cuda(), "g_idx is not on GPU"); + TORCH_CHECK(g_idx.is_contiguous(), "g_idx is not contiguous"); + TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU"); + TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous"); + + // Verify g_idx and perm + TORCH_CHECK((g_idx.size(-1) == 0 && perm.size(-1) == 0) || + (g_idx.size(-1) == size_k && perm.size(-1) == size_k), + "Unexpected g_idx.size(-1) = ", g_idx.size(-1), + " and perm.size(-1) = ", perm.size(-1), + ", where size_k = ", size_k); + } else { + g_idx = torch::empty({0}, options); + perm = torch::empty({0}, options); + a_tmp = torch::empty({0}, options); + } + bool has_act_order = g_idx.size(-1) > 0 && perm.size(-1) > 0; + if (has_act_order) { + a_tmp = torch::empty({size_m, size_k}, options); if (is_k_full) { TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1"); TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k, @@ -2398,6 +771,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, } } else { + a_tmp = torch::empty({0}, options); if (num_groups > 1) { TORCH_CHECK( size_k % num_groups == 0, "size_k = ", size_k, @@ -2408,6 +782,45 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, } } + torch::Tensor global_scale; + if (global_scale_or_none.has_value()) { + global_scale = global_scale_or_none.value(); + TORCH_CHECK(b_q_type == vllm::kFE2M1f, + "global_scale can only be used for float4_e2m1f."); + } else { + global_scale = torch::empty({0}, options); + TORCH_CHECK(!(b_q_type == vllm::kFE2M1f), + "the global_scale parameter must be passed for float4_e2m1f."); + } + + torch::Tensor b_zeros; + if (b_zeros_or_none.has_value()) { + b_zeros = b_zeros_or_none.value(); + TORCH_CHECK(b_zeros.device().is_cuda(), "b_zeros is not on GPU"); + TORCH_CHECK(b_zeros.is_contiguous(), "b_zeros is not contiguous"); + } else { + b_zeros = torch::empty({0}, options); + } + bool has_zp = b_zeros.size(-1) > 0; + if (has_zp) { + TORCH_CHECK( + b_q_type == vllm::kU4 || b_q_type == vllm::kU8, + "b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str()); + } else { + TORCH_CHECK(b_q_type == vllm::kU4B8 || b_q_type == vllm::kU8B128 || + b_q_type == vllm::kFE4M3fn || b_q_type == vllm::kFE2M1f, + "b_q_type must be uint4b8, uint8b128, float8_e4m3fn or " + "float4_e2m1f when " + "has_zp = False. Got = ", + b_q_type.str()); + } + + if (has_zp && is_zp_float) { + TORCH_CHECK(a.scalar_type() == at::ScalarType::Half, + "Computation type must be float16 (half) when using float zero " + "points."); + } + // Verify b_zeros if (has_zp) { int rank = b_zeros.sizes().size(); @@ -2431,34 +844,49 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, } // Verify workspace size - TORCH_CHECK(size_n % marlin::min_thread_n == 0, "size_n = ", size_n, - ", is not divisible by min_thread_n = ", marlin::min_thread_n); - int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par; + TORCH_CHECK(size_n % MARLIN_NAMESPACE_NAME::min_thread_n == 0, + "size_n = ", size_n, ", is not divisible by min_thread_n = ", + MARLIN_NAMESPACE_NAME::min_thread_n); + + int min_workspace_size = sms; TORCH_CHECK(workspace.numel() >= min_workspace_size, "workspace.numel = ", workspace.numel(), " is below min_workspace_size = ", min_workspace_size); int dev = a.get_device(); if (a.scalar_type() == at::ScalarType::Half) { + void* scales_ptr; + if (b_q_type == vllm::kFE2M1f) { + scales_ptr = b_scales.data_ptr(); + } else { + scales_ptr = b_scales.data_ptr(); + } + marlin::marlin_mm( a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), - c_tmp.data_ptr(), b_scales.data_ptr(), + c_tmp.data_ptr(), scales_ptr, global_scale.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, size_k, a.stride(0), workspace.data_ptr(), b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), - thread_k, thread_n, sms, marlin::max_par, use_atomic_add, - use_fp32_reduce, is_zp_float); + thread_k, thread_n, sms, use_atomic_add, use_fp32_reduce, is_zp_float); } else if (a.scalar_type() == at::ScalarType::BFloat16) { + void* scales_ptr; + if (b_q_type == vllm::kFE2M1f) { + scales_ptr = b_scales.data_ptr(); + } else { + scales_ptr = b_scales.data_ptr(); + } + marlin::marlin_mm( a.data_ptr(), b_q_weight.data_ptr(), - c.data_ptr(), c_tmp.data_ptr(), - b_scales.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), - perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, size_k, - a.stride(0), workspace.data_ptr(), b_q_type, has_act_order, is_k_full, - has_zp, num_groups, group_size, dev, + c.data_ptr(), c_tmp.data_ptr(), scales_ptr, + global_scale.data_ptr(), b_zeros.data_ptr(), + g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), + size_m, size_n, size_k, a.stride(0), workspace.data_ptr(), b_q_type, + has_act_order, is_k_full, has_zp, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, - marlin::max_par, use_atomic_add, use_fp32_reduce, is_zp_float); + use_atomic_add, use_fp32_reduce, is_zp_float); } else { TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16"); } diff --git a/csrc/quantization/gptq_marlin/kernel.h b/csrc/quantization/gptq_marlin/kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..f92056589d206f5c361da947216fa7bb9f8cc916 --- /dev/null +++ b/csrc/quantization/gptq_marlin/kernel.h @@ -0,0 +1,38 @@ + +#ifndef MARLIN_NAMESPACE_NAME + #define MARLIN_NAMESPACE_NAME marlin +#endif + +#include "marlin.cuh" +#include "marlin_dtypes.cuh" +#include "core/scalar_type.hpp" + +#define MARLIN_KERNEL_PARAMS \ + const int4 *__restrict__ A, const int4 *__restrict__ B, \ + int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ + const int4 *__restrict__ scales_ptr, \ + const uint16_t *__restrict__ scale2_ptr, \ + const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \ + int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks, \ + bool use_atomic_add, bool use_fp32_reduce, int max_shared_mem + +namespace MARLIN_NAMESPACE_NAME { +template shared + // fetch pipeline + const int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale + const bool is_zp_float // is zero point of float16 type? + > +__global__ void Marlin(MARLIN_KERNEL_PARAMS); + +} diff --git a/csrc/quantization/gptq_marlin/marlin_template.h b/csrc/quantization/gptq_marlin/marlin_template.h new file mode 100644 index 0000000000000000000000000000000000000000..e416d5a76a410eba0a2ae054e42ccbf5ca871da7 --- /dev/null +++ b/csrc/quantization/gptq_marlin/marlin_template.h @@ -0,0 +1,1731 @@ +/* + * Modified by Neural Magic + * Copyright (C) Marlin.2024 Elias Frantar + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Adapted from https://github.com/IST-DASLab/marlin + */ + +#ifndef MARLIN_NAMESPACE_NAME + #define MARLIN_NAMESPACE_NAME marlin +#endif + +#include "marlin.cuh" +#include "marlin_dtypes.cuh" +#include "dequant.h" +#include "core/scalar_type.hpp" + +#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ + static_assert(std::is_same::value || \ + std::is_same::value, \ + "only float16 and bfloat16 is supported"); + +namespace MARLIN_NAMESPACE_NAME { + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale + const bool is_zp_float // is zero point of float16 type? + > +__global__ void Marlin( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce) + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int* __restrict__ g_idx, // int32 group indices of shape k + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks, // extra global storage for barrier synchronization + bool use_fp32_reduce // whether to use fp32 global reduce +) {} + +} // namespace marlin + +#else + +// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 +// output/accumulation. +template +__device__ inline void mma(const typename ScalarType::FragA& a_frag, + const typename ScalarType::FragB& frag_b, + typename ScalarType::FragC& frag_c) { + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + float* c = reinterpret_cast(&frag_c); + if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); + } +} + +template +__device__ inline void mma_trans( + const typename ScalarType::FragA& a_frag, + const typename ScalarType::FragB& frag_b, + const typename ScalarType::FragB& frag_b2, + typename ScalarType::FragC& frag_c) { + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + const uint32_t* b2 = reinterpret_cast(&frag_b2); + float* c = reinterpret_cast(&frag_c); + if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else if constexpr (std::is_same::value) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); + } +} + +// Instruction for loading a full 16x16 matrix fragment of operand A from shared +// memory, directly in tensor core layout. +template +__device__ inline void ldsm(typename ScalarType::FragA& frag_a, + const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_a); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + if constexpr (count == 4) { + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) + : "r"(smem)); + } else if constexpr (count == 2) { + asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n" + : "=r"(a[0]), "=r"(a[1]) + : "r"(smem)); + } else if constexpr (count == 1) { + asm volatile("ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\n" + : "=r"(a[0]) + : "r"(smem)); + } else { + static_assert(count == 1 || count == 2 || count == 4, "invalid count"); + } +} + +// Multiply dequantized values by the corresponding quantization scale; used +// only for grouped quantization. +template +__device__ inline void scale(typename ScalarType::FragB& frag_b, + typename ScalarType::FragS& frag_s, + int i) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 s = + ScalarType::num2num2(reinterpret_cast(&frag_s)[i]); + frag_b[0] = __hmul2(frag_b[0], s); + frag_b[1] = __hmul2(frag_b[1], s); +} + +template +__device__ inline void scale_and_sub( + typename ScalarType::FragB& frag_b, scalar_t s, scalar_t zp) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 s2 = ScalarType::num2num2(s); + scalar_t2 zp2 = ScalarType::num2num2(zp); + frag_b[0] = __hfma2(frag_b[0], s2, __hneg2(zp2)); + frag_b[1] = __hfma2(frag_b[1], s2, __hneg2(zp2)); +} + +template +__device__ inline void sub_zp(typename ScalarType::FragB& frag_b, + typename ScalarType::scalar_t2& frag_zp, + int i) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 zp = + ScalarType::num2num2(reinterpret_cast(&frag_zp)[i]); + frag_b[0] = __hsub2(frag_b[0], zp); + frag_b[1] = __hsub2(frag_b[1], zp); +} + +// Same as above, but for act_order (each K is multiplied individually) +template +__device__ inline void scale4(typename ScalarType::FragB& frag_b, + typename ScalarType::FragS& frag_s_1, + typename ScalarType::FragS& frag_s_2, + typename ScalarType::FragS& frag_s_3, + typename ScalarType::FragS& frag_s_4, + int i) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 s_val_1_2; + s_val_1_2.x = reinterpret_cast(&frag_s_1)[i]; + s_val_1_2.y = reinterpret_cast(&frag_s_2)[i]; + + scalar_t2 s_val_3_4; + s_val_3_4.x = reinterpret_cast(&frag_s_3)[i]; + s_val_3_4.y = reinterpret_cast(&frag_s_4)[i]; + + frag_b[0] = __hmul2(frag_b[0], s_val_1_2); + frag_b[1] = __hmul2(frag_b[1], s_val_3_4); +} + +// Given 2 floats multiply by 2 scales (halves) +template +__device__ inline void scale_float(float* c, + typename ScalarType::FragS& s) { + scalar_t* s_ptr = reinterpret_cast(&s); + c[0] = __fmul_rn(c[0], ScalarType::num2float(s_ptr[0])); + c[1] = __fmul_rn(c[1], ScalarType::num2float(s_ptr[1])); +} + +// Wait until barrier reaches `count`, then lock for current threadblock. +__device__ inline void barrier_acquire(int* lock, int count) { + if (threadIdx.x == 0) { + int state = -1; + do + // Guarantee that subsequent writes by this threadblock will be visible + // globally. + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" + : "=r"(state) + : "l"(lock)); + while (state != count); + } + __syncthreads(); +} + +// Release barrier and increment visitation count. +__device__ inline void barrier_release(int* lock, bool reset = false) { + __syncthreads(); + if (threadIdx.x == 0) { + if (reset) { + lock[0] = 0; + return; + } + int val = 1; + // Make sure that all writes since acquiring this barrier are visible + // globally, while releasing the barrier. + asm volatile("fence.acq_rel.gpu;\n"); + asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" + : + : "l"(lock), "r"(val)); + } +} + +// Wait until value of lock to be negative, and then add 1 +__device__ inline void wait_negative_and_add(int* lock) { + if (threadIdx.x == 0) { + int state = 0; + do + // Guarantee that subsequent writes by this threadblock will be visible + // globally. + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" + : "=r"(state) + : "l"(lock)); + while (state >= 0); + atomicAdd(lock, 1); + } + __syncthreads(); +} + +template shared + // fetch pipeline + const int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale + const bool is_zp_float // is zero point of float16 type? + > +__global__ void Marlin( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce) + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const uint16_t* __restrict__ scale2_ptr, // fp16 global scale (for nvfp4 + // only) + const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape + // (k/groupsize)x(n/pack_factor) + const int* __restrict__ g_idx, // int32 group indices of shape k + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int lda, // A.stride(0), equal to prob_k is A is contiguous + int* locks, // extra global storage for barrier synchronization + bool use_atomic_add, // whether to use atomic add to reduce + bool use_fp32_reduce, // whether to use fp32 global reduce + int max_shared_mem) { + // Each threadblock processes one "stripe" of the B matrix with (roughly) the + // same size, which might involve multiple column "slices" (of width 16 * + // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM + // example: + // 0 1 3 + // 0 2 3 + // 1 2 4 + // While this kind of partitioning makes things somewhat more complicated, it + // ensures good utilization of all SMs for many kinds of shape and GPU + // configurations, while requiring as few slow global cross-threadblock + // reductions as possible. + using Dtype = ScalarType; + using scalar_t2 = typename ScalarType::scalar_t2; + using FragA = typename ScalarType::FragA; + using FragB = typename ScalarType::FragB; + using FragC = typename ScalarType::FragC; + using FragS = typename ScalarType::FragS; + using FragZP = typename ScalarType::FragZP; + + static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id); + constexpr bool has_zp = w_type == vllm::kU4 || w_type == vllm::kU8; + constexpr bool is_int_type = w_type == vllm::kU4 || w_type == vllm::kU8 || + w_type == vllm::kU4B8 || w_type == vllm::kU8B128; + // see comments of dequant.h for more details + constexpr bool dequant_skip_flop = + !is_int_type || + has_zp && !is_zp_float && !std::is_same::value || + has_zp && !is_zp_float && !(w_type == vllm::kU8); + + scalar_t2 global_scale; + + if constexpr (w_type == vllm::kFE2M1f) { + uint16_t val = scale2_ptr[0]; + global_scale = Dtype::num2num2(*reinterpret_cast(&val)); + } + + constexpr bool has_act_order = group_blocks == 0; + constexpr int m_block_size = m_block_size_8 ? 8 : (16 * thread_m_blocks); + + constexpr int pack_factor = 32 / w_type.size_bits(); + static_assert(thread_m_blocks == 1 || !m_block_size_8); + + // For larger GEMMs we run multiple batchsize 64 versions in parallel for a + // better partitioning with less reductions + int parallel = 1; + if (prob_m > m_block_size) { + parallel = prob_m / m_block_size; + prob_m = m_block_size; + } + + int k_tiles = prob_k / 16 / thread_k_blocks; + int n_tiles = prob_n / 16 / thread_n_blocks; + int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x); + + if constexpr (!has_act_order && group_blocks != -1) { + if (group_blocks >= thread_k_blocks) { + // Ensure that the number of tiles in each stripe is a multiple of the + // groupsize; this avoids an annoying special case where a stripe starts + // in the middle of group. + iters = (group_blocks / thread_k_blocks) * + div_ceil(iters, (group_blocks / thread_k_blocks)); + } + } + + int slice_row = (iters * blockIdx.x) % k_tiles; + int slice_col_par = (iters * blockIdx.x) / k_tiles; + int slice_col = slice_col_par; + int slice_iters; // number of threadblock tiles in the current slice + int slice_count = + 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to + // top + + int par_id = 0; + int locks_off = 0; + + // We can easily implement parallel problem execution by just remapping + // indices and advancing global pointers + if (slice_col_par >= n_tiles) { + A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * lda / 8; + C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; + slice_col = slice_col_par % n_tiles; + par_id = slice_col_par / n_tiles; + } + if (parallel * n_tiles >= gridDim.x) { + // when parallel * n_tiles >= sms + // then there are at most $sms$ conflict tile blocks + locks_off = blockIdx.x; + } else { + locks_off = (iters * blockIdx.x) / k_tiles - 1; + } + + // Compute all information about the current slice which is required for + // synchronization. + auto init_slice = [&](bool first_init = false) { + slice_iters = + iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; + if (slice_iters == 0) return; + if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; + slice_count = 1; + slice_idx = 0; + int col_first = iters * div_ceil(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) { + int col_off = col_first - k_tiles * slice_col_par; + slice_count = div_ceil(k_tiles - col_off, iters); + if (col_off > 0) slice_count++; + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) + slice_idx = slice_count - 1; + else { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) slice_idx--; + } + } + if (parallel * n_tiles >= gridDim.x) { + if (slice_count > 1 && slice_idx == slice_count - 1) { + locks_off++; + } + } else { + locks_off++; + } + + if (first_init && use_atomic_add && slice_count > 1 && slice_idx == 0) { + constexpr int threads_per_m = 16 * thread_n_blocks / 8; + int m_per_thread = + div_ceil(thread_m_blocks * 16, threads / threads_per_m); + if (m_block_size_8) m_per_thread = div_ceil(8, threads / threads_per_m); + for (int i = 0; i < m_per_thread; i++) { + int row = threads / threads_per_m * i + threadIdx.x / threads_per_m; + if (row < prob_m) { + int col = slice_col * 16 * thread_n_blocks / 8 + + threadIdx.x % threads_per_m; + C[row * prob_n / 8 + col] = {0, 0, 0, 0}; + } + } + // After write zero to output, write a negative value to lock. + // Every SM that processes the same slice would wait for + // the negative value, and then atomicAdd 1 to it. + // After all SMs are processed, the lock value would back to 0 again. + __syncthreads(); + if (threadIdx.x == 0) locks[locks_off] = 1 - slice_count; + } + + if (slice_col == n_tiles) { + A += 16 * thread_m_blocks * lda / 8; + C += 16 * thread_m_blocks * prob_n / 8; + slice_col = 0; + par_id++; + } + }; + init_slice(true); + + // A sizes/strides + + // stride of the A matrix in global memory + int a_gl_stride = lda / 8; + // stride of an A matrix tile in shared memory + constexpr int a_sh_stride = 16 * thread_k_blocks / 8; + // delta between subsequent A tiles in global memory + constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; + // between subsequent accesses within a tile + int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); + // between shared memory writes + constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); + // between shared memory tile reads + constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); + // within a shared memory tile + constexpr int a_sh_rd_delta_i = a_sh_stride * 16; + // overall size of a tile + constexpr int a_sh_stage = a_sh_stride * m_block_size; + // number of shared write iterations for a tile + constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta); + + // B sizes/strides + int b_gl_stride = 16 * prob_n / (pack_factor * 4); + constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; + constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 1 : 2; + constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; + + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; + int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); + constexpr int b_sh_wr_delta = threads * b_thread_vecs; + constexpr int b_sh_rd_delta = threads * b_thread_vecs; + constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + + // Scale sizes/strides without act_order + int s_gl_stride = prob_n / 8; + constexpr int s_sh_stride = 16 * thread_n_blocks / 8; + constexpr int s_tb_groups = + !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks + ? thread_k_blocks / group_blocks / (w_type == vllm::kFE2M1f ? 2 : 1) + : 1; + constexpr int s_sh_stage = s_tb_groups * s_sh_stride; + int s_gl_rd_delta = s_gl_stride; + + // Scale size/strides with act_order + constexpr int tb_k = 16 * thread_k_blocks; + constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0; + // constexpr int act_s_row_stride = 1; + // int act_s_col_stride = act_s_row_stride * num_groups; + constexpr int act_s_max_num_groups = 32; + int act_s_col_stride = 1; + int act_s_col_warp_stride = act_s_col_stride * 8; + + int tb_n_warps = thread_n_blocks / 4; + int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; + + // Zero-points sizes/strides + int zp_gl_stride = is_zp_float ? prob_n / 8 : (prob_n / pack_factor) / 4; + constexpr int zp_sh_stride = is_zp_float + ? 16 * thread_n_blocks / 8 + : ((16 * thread_n_blocks) / pack_factor) / 4; + constexpr int zp_tb_groups = s_tb_groups; + constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0; + int zp_gl_rd_delta = zp_gl_stride; + + // Global A read index of current thread. + int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + a_gl_rd += a_gl_rd_delta_o * slice_row; + // Shared write index of current thread. + int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + // Shared read index. + int a_sh_rd = + a_sh_stride * ((threadIdx.x % 32) % (16 / (m_block_size_8 ? 2 : 1))) + + (threadIdx.x % 32) / (16 / (m_block_size_8 ? 2 : 1)); + a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + + int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; + b_gl_rd += b_sh_stride * slice_col; + b_gl_rd += b_gl_rd_delta_o * slice_row; + auto b_sh_wr = threadIdx.x * b_thread_vecs; + auto b_sh_rd = threadIdx.x * b_thread_vecs; + + // For act_order + constexpr int k_iter_size = tb_k / b_sh_wr_iters; + int slice_k_start = tb_k * slice_row; + int slice_k_finish = slice_k_start + tb_k * slice_iters; + int slice_k_start_shared_fetch = slice_k_start; + int slice_n_offset = act_s_col_tb_stride * slice_col; + + // No act_order + int s_gl_rd; + if constexpr (!has_act_order) { + if constexpr (group_blocks == -1) { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + } else { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) / + (w_type == vllm::kFE2M1f ? 2 : 1) + + s_sh_stride * slice_col + threadIdx.x; + } + } + auto s_sh_wr = threadIdx.x; + bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + + // Zero-points + int zp_gl_rd; + if constexpr (has_zp) { + if constexpr (group_blocks == -1) { + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + } else { + zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + zp_sh_stride * slice_col + threadIdx.x; + } + } + auto zp_sh_wr = threadIdx.x; + bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride; + + // We use a different scale layout for grouped and column-wise quantization as + // we scale a `half2` tile in column-major layout in the former and in + // row-major in the latter case. + int s_sh_rd; + if constexpr (group_blocks != -1 && w_type == vllm::kFE2M1f) { + auto warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + int warp_row = warp_id / n_warps; + + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + s_sh_rd = s_sh_rd * 2 + warp_row % 2; + + } else if constexpr (group_blocks != -1) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + else if constexpr (group_blocks == -1 && + (m_block_size_8 || (has_zp && !dequant_skip_flop))) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 8; + else + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) % 4; + + // Zero-points have the same read layout as the scales + // (without column-wise case) + constexpr int num_col_threads = 8; + constexpr int num_row_threads = 4; + constexpr int num_ints_per_thread = 8 / pack_factor; + int zp_sh_rd; + if constexpr (has_zp) { + if constexpr (is_zp_float) { + if constexpr (group_blocks != -1) { + zp_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + } + } else { + zp_sh_rd = num_ints_per_thread * num_col_threads * + ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); + } + } + + // Precompute which thread should not read memory in which iterations; this is + // needed if there are more threads than required for a certain tilesize or + // when the batchsize is not a multiple of 16. + bool a_sh_wr_pred[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; + + // To ensure that writing and reading A tiles to/from shared memory, the + // latter in fragment format, is fully bank conflict free, we need to use a + // rather fancy XOR-based layout. The key here is that neither reads nor + // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the + // same shared memory banks. Further, it seems (based on NSight-Compute) that + // each warp must also write a consecutive memory segment? + auto transform_a = [&](int i) { + int row = i / a_gl_rd_delta_o; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ (row % 8); + }; + // Since the computation of this remapping is non-trivial and, due to our main + // loop unrolls, all shared memory accesses are static, we simply precompute + // both transformed reads and writes. + int a_sh_wr_trans[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < thread_m_blocks; j++) + a_sh_rd_trans[i][j] = + transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + } + + // Since B-accesses have non-constant stride they have to be computed at + // runtime; we break dependencies between subsequent accesses with a tile by + // maintining multiple pointers (we have enough registers), a tiny + // optimization. + const int4* B_ptr[b_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + + extern __shared__ int4 sh[]; + // Shared memory storage for global fetch pipelines. + constexpr int sh_red_size = (2 * thread_n_blocks + 1) * 16 * thread_m_blocks; + constexpr int sh_b_size = stages * b_sh_stage; + int4* sh_b = sh; + int4* sh_red = sh; + int4* sh_g_idx = sh_b + (sh_red_size > sh_b_size ? sh_red_size : sh_b_size); + int4* sh_zp = sh_g_idx + (stages * g_idx_stage); + constexpr int sh_s_size = has_act_order ? (act_s_max_num_groups * s_sh_stride) + : (stages * s_sh_stage); + int4* sh_s = sh_zp + (stages * zp_sh_stage); + // shared memory reused by reduction should be smaller than + // shared memory used by weight. + static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <= + stages * b_sh_stage); + int4* sh_a = sh_s + sh_s_size; + // constexpr int shm_size_used = + // stages * (g_idx_stage + zp_sh_stage) + sh_s_size + + // (sh_red_size > sh_b_size ? sh_red_size : sh_b_size); + + // Register storage for double buffer of shared memory reads. + FragA frag_a[2][thread_m_blocks]; + I4 frag_b_quant[2][b_thread_vecs]; + FragC frag_c[thread_m_blocks][4][2]; + FragS frag_s[2][4]; // No act-order + FragS act_frag_s[2][4][4]; // For act-order + int frag_qzp[2][num_ints_per_thread]; // Zero-points + FragZP frag_zp; // Zero-points in fp16 + FragZP frag_zpf[2]; // Zero-points in fp16 in HQQ + + // Zero accumulators. + auto zero_accums = [&]() { + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) + reinterpret_cast(frag_c)[i] = 0; + }; + + int sh_first_group_id = -1; + int sh_num_groups = -1; + + auto fetch_act_order_scales_to_shared = [&](bool is_async, int first_group_id, + int last_group_id) { + sh_first_group_id = first_group_id; + sh_num_groups = last_group_id - first_group_id + 1; + + if (sh_num_groups > act_s_max_num_groups) { + sh_num_groups = act_s_max_num_groups; + } + + if (sh_first_group_id + sh_num_groups > num_groups) { + sh_num_groups = num_groups - sh_first_group_id; + } + + int row_offset = first_group_id * s_gl_stride; + + if (is_async) { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], + &scales_ptr[row_offset + (i * s_gl_stride) + + slice_n_offset + threadIdx.x]); + } + } + } else { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + sh_s[(i * s_sh_stride) + threadIdx.x] = + scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + + threadIdx.x]; + } + } + } + }; + // Asynchronously fetch the next A, B and s tile from global to the next + // shared memory pipeline location. + auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { + if (pred) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + cp_async4_pred( + &sh_a_stage[a_sh_wr_trans[i]], + &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], + a_sh_wr_pred[i]); + } + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < b_thread_vecs; j++) { + cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); + } + + B_ptr[i] += b_gl_rd_delta_o; + } + + if constexpr (has_act_order) { + // Fetch g_idx thread-block portion + int full_pipe = a_off; + int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; + if (cur_k < prob_k && cur_k < slice_k_finish) { + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + + int4 const* cur_g_idx_stage_ptr = + reinterpret_cast(&g_idx[cur_k]); + + if (threadIdx.x < g_idx_stage) { + cp_async4_pred(&sh_g_idx_stage[threadIdx.x], + &cur_g_idx_stage_ptr[threadIdx.x]); + } + } + } else { + if constexpr (group_blocks != -1) { + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) { + // Only fetch scales if this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } else { + for (int i = 0; i < s_tb_groups; i++) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], + &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } + } + + if constexpr (has_zp && group_blocks != -1) { + int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) { + // Only fetch zero-points if this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); + } + zp_gl_rd += zp_gl_rd_delta; + } + } else { + for (int i = 0; i < zp_tb_groups; i++) { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr], + &zp_ptr[zp_gl_rd]); + } + zp_gl_rd += zp_gl_rd_delta; + } + } + } + } + } + // Insert a fence even when we are winding down the pipeline to ensure that + // waiting is also correct at this point. + cp_async_fence(); + }; + + auto fetch_col_zp_to_shared = [&]() { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]); + } + }; + + auto fetch_col_scale_to_shared = [&]() { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + }; + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); + }; + + // Load the next sub-tile from the current location in the shared memory pipe + // into the current register buffer. + auto fetch_to_registers = [&](int k, int pipe) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) + ldsm( + frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + + #pragma unroll + for (int i = 0; i < b_thread_vecs; i++) { + frag_b_quant[k % 2][i] = *reinterpret_cast( + &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); + } + }; + + bool is_same_group[stages]; + int same_group_id[stages]; + + auto init_same_group = [&](int pipe) { + if constexpr (!has_act_order) { + return; + } + + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + int group_id_1 = sh_g_idx_int_ptr[0]; + int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; + + is_same_group[pipe] = group_id_1 == group_id_2; + same_group_id[pipe] = group_id_1; + }; + + auto fetch_scales_to_registers = [&](int k, int full_pipe) { + int pipe = full_pipe % stages; + + if constexpr (!has_act_order) { + // No act-order case + if constexpr (group_blocks == -1) { + // load only when starting a new slice + if (k == 0 && full_pipe == 0) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + } else if constexpr (group_blocks != -1) { + if constexpr (group_blocks >= thread_k_blocks) { + if (k % b_sh_wr_iters == 0) { + int4* sh_s_stage = + sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } else { + reinterpret_cast(&frag_s[1])[0] = + reinterpret_cast(&frag_s[0])[0]; + } + } else { + auto warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = + k_blocks / (group_blocks * (w_type == vllm::kFE2M1f ? 2 : 1)); + + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + + if constexpr (w_type_id != vllm::kFE2M1f.id()) { + reinterpret_cast(&frag_s[k % 2])[0] = + sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; + } else { + reinterpret_cast(&frag_s[k % 2])[0] = + reinterpret_cast( + sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)]; + } + } + } + + return; + } + + // Act-order case + + // Determine K of the "current" thread-block + int cur_k = slice_k_start + tb_k * full_pipe; + if (cur_k >= prob_k || cur_k >= slice_k_finish) { + return; + } + + // Reset (to current thread-block) since we read g_idx portion from the + // shared memory + cur_k = 0; + + // Progress to current iteration + cur_k += k_iter_size * (k % b_sh_wr_iters); + + // Determine "position" inside the thread-block (based on warp and + // thread-id) + auto warp_id = threadIdx.x / 32; + int n_warps = + thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N + + int warp_row = warp_id / n_warps; + int warp_col = warp_id % n_warps; + + cur_k += warp_row * 16; + + auto th_id = threadIdx.x % 32; + cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix + + int s_col_shift = + /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + + (th_id / 4) * act_s_col_stride; + + if (is_same_group[pipe]) { + if (k % 2 == 0) { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + + s_col_shift]; + } else { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); + } + + for (int i = 1; i < 4; i++) { + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); + } + return; + } + + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + constexpr int k_frag_offsets[4] = {0, 1, 8, + 9}; // Tensor core offsets per thread + + #pragma unroll + for (int i = 0; i < 4; i++) { + int actual_k = cur_k + k_frag_offsets[i]; + + int group_id = sh_g_idx_int_ptr[actual_k]; + int rel_group_id = group_id - sh_first_group_id; + + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = + sh_s[rel_group_id * s_sh_stride + s_col_shift]; + } + }; + + auto fetch_zp_to_registers = [&](int k, int full_pipe) { + // This code does not handle group_blocks == 0, + // which signifies act_order. + // has_zp implies AWQ, which doesn't have act_order, + static_assert(!has_zp || group_blocks != 0); + + if constexpr (has_zp && !is_zp_float) { + int pipe = full_pipe % stages; + + if constexpr (group_blocks == -1) { + // load only when starting a new slice + if (k == 0 && full_pipe == 0) { + #pragma unroll + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp))[zp_sh_rd + i]; + } + } + + } else if constexpr (group_blocks >= thread_k_blocks) { + if (k % b_sh_wr_iters == 0) { + int4* sh_zp_stage = + sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + #pragma unroll + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = + (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + } + } + } else { + auto warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = 0; + + // Suppress bogus and persistent divide-by-zero warning + #pragma nv_diagnostic push + #pragma nv_diag_suppress divide_by_zero + cur_group_id = k_blocks / group_blocks; + #pragma nv_diagnostic pop + + int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + sh_zp_stage += cur_group_id * zp_sh_stride; + + #pragma unroll + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = + (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + } + } + } + + else if constexpr (has_zp && is_zp_float) { + int pipe = full_pipe % stages; + + if constexpr (group_blocks != -1) { + if constexpr (group_blocks >= thread_k_blocks) { + if (k % b_sh_wr_iters == 0) { + int4* sh_zp_stage = + sh_zp + + zp_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_zpf[k % 2])[0] = + sh_zp_stage[zp_sh_rd]; + } + } else { + auto warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + // Suppress bogus and persistent divide-by-zero warning + #pragma nv_diagnostic push + #pragma nv_diag_suppress divide_by_zero + int cur_group_id = k_blocks / group_blocks; + #pragma nv_diagnostic pop + + int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + reinterpret_cast(&frag_zpf[k % 2])[0] = + sh_zp_stage[zp_sh_rd + cur_group_id * zp_sh_stride]; + } + } + } + }; + + auto dequant_data = [&](int q, scalar_t2* frag_b_ptr) { + dequant(q, frag_b_ptr); + }; + + // Execute the actual tensor core matmul of a sub-tile. + bool is_first_matmul_in_slice = true; + auto matmul = [&](int k) { + int k2 = k % 2; + const bool is_new_zp = + ((group_blocks != -1) && (group_blocks < thread_k_blocks || k == 0)) || + (group_blocks == -1 && is_first_matmul_in_slice); + if constexpr (has_zp && !is_zp_float) { + if (is_new_zp) { + if constexpr (group_blocks == -1) is_first_matmul_in_slice = false; + FragB frag_zp_0; + FragB frag_zp_1; + int zp_quant_0, zp_quant_1; + + if constexpr (w_type.size_bits() == 4) { + zp_quant_0 = frag_qzp[k2][0]; + zp_quant_1 = zp_quant_0 >> 8; + } else { + static_assert(w_type.size_bits() == 8); + zp_quant_0 = frag_qzp[k2][0]; + zp_quant_1 = frag_qzp[k2][1]; + } + + dequant_data(zp_quant_0, reinterpret_cast(&frag_zp)); + dequant_data(zp_quant_1, reinterpret_cast(&frag_zp) + 2); + } + } + if constexpr (!dequant_skip_flop && has_zp && is_zp_float) { + if (is_new_zp) { + reinterpret_cast(&frag_zp)[0] = + reinterpret_cast(&frag_zpf[k2])[0]; + } + } + + if constexpr (w_type == vllm::kFE2M1f) { + int s_quant_0 = reinterpret_cast(frag_s[k2])[0]; + int s_quant_1 = reinterpret_cast(frag_s[k2])[1]; + + dequant_fp8_scales(s_quant_0, + reinterpret_cast(&frag_s[k2])); + dequant_fp8_scales( + s_quant_1, reinterpret_cast(&frag_s[k2]) + 2); + } + + // We have the m dimension as the inner loop in order to encourage overlapping + // dequantization and matmul operations. + #pragma unroll + for (int j = 0; j < 4; j++) { + FragB frag_b0; + FragB frag_b1; + int b_quant_0, b_quant_1; + + if constexpr (w_type_id == vllm::kFE2M1f.id()) { + b_quant_1 = frag_b_quant[k2][0][j]; + b_quant_0 = b_quant_1 << 8; + } else if constexpr (w_type.size_bits() == 4) { + b_quant_0 = frag_b_quant[k2][0][j]; + b_quant_1 = b_quant_0 >> 8; + } else { + static_assert(w_type.size_bits() == 8); + int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k2]); + b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; + b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; + } + + dequant_data(b_quant_0, reinterpret_cast(&frag_b0)); + dequant_data(b_quant_1, reinterpret_cast(&frag_b1)); + + if constexpr (dequant_skip_flop && has_zp && !is_zp_float) { + sub_zp(frag_b0, frag_zp[j], 0); + sub_zp(frag_b1, frag_zp[j], 1); + } + + // Apply scale to frag_b0 + if constexpr (has_act_order) { + static_assert(group_blocks != -1); + scale4(frag_b0, act_frag_s[k2][0][j], act_frag_s[k2][1][j], + act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0); + scale4(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j], + act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1); + } else if constexpr (!dequant_skip_flop && has_zp && !is_zp_float && + group_blocks == -1) { + int idx = (threadIdx.x / 4) % 2; + scalar_t2 s2 = Dtype::nums2num2( + reinterpret_cast(&frag_s[j / 2][j % 2 * 2 + 0])[idx], + reinterpret_cast(&frag_s[j / 2][j % 2 * 2 + 1])[idx]); + if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2); + scale_and_sub(frag_b0, s2.x, frag_zp[j].x); + scale_and_sub(frag_b1, s2.y, frag_zp[j].y); + } else if constexpr (!dequant_skip_flop && has_zp && group_blocks != -1) { + if (is_new_zp) + frag_zp[j] = __hmul2(frag_zp[j], + *reinterpret_cast(&frag_s[k2][j])); + scale_and_sub(frag_b0, frag_s[k2][j][0].x, frag_zp[j].x); + scale_and_sub(frag_b1, frag_s[k2][j][0].y, frag_zp[j].y); + } else if constexpr (group_blocks != -1) { + scale(frag_b0, frag_s[k2][j], 0); + scale(frag_b1, frag_s[k2][j], 1); + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + if constexpr (m_block_size_8) { + mma_trans(frag_a[k2][i], frag_b0, frag_b1, frag_c[i][j][0]); + } else { + mma(frag_a[k2][i], frag_b0, frag_c[i][j][0]); + mma(frag_a[k2][i], frag_b1, frag_c[i][j][1]); + } + } + } + }; + + // Since we slice across the k dimension of a tile in order to increase the + // number of warps while keeping the n dimension of a tile reasonable, we have + // multiple warps that accumulate their partial sums of the same output + // location; which we have to reduce over in the end. We do in shared memory. + auto thread_block_reduce = [&]() { + constexpr int red_off = threads / b_sh_stride_threads / 2; + if (red_off >= 1) { + auto red_idx = threadIdx.x / b_sh_stride_threads; + constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; + constexpr int red_sh_delta = b_sh_stride_threads; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads); + + // Parallel logarithmic shared memory reduction. We make sure to avoid any + // unnecessary read or write iterations, e.g., for two warps we write only + // once by warp 1 and read only once by warp 0. + + #pragma unroll + for (int m_block = 0; m_block < thread_m_blocks; m_block++) { + #pragma unroll + for (int i = red_off; i > 0; i /= 2) { + if (i <= red_idx && red_idx < 2 * i) { + #pragma unroll + for (int j = 0; j < 4 * 2; j += (m_block_size_8 ? 2 : 1)) { + int red_sh_wr = + red_sh_delta * j + (red_sh_rd - red_sh_stride * i); + if (i < red_off) { + float* c_rd = reinterpret_cast( + &sh_red[red_sh_delta * j + red_sh_rd]); + float* c_wr = reinterpret_cast(&sh_red[red_sh_wr]); + #pragma unroll + for (int k = 0; k < 4; k++) + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += + c_rd[k] + c_wr[k]; + } + sh_red[red_sh_wr] = + reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + } + } + __syncthreads(); + } + if (red_idx == 0) { + #pragma unroll + for (int i = 0; i < 4 * 2; i += (m_block_size_8 ? 2 : 1)) { + float* c_rd = + reinterpret_cast(&sh_red[red_sh_delta * i + red_sh_rd]); + #pragma unroll + for (int j = 0; j < 4; j++) + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += + c_rd[j]; + } + } + __syncthreads(); + } + } + }; + + // Since multiple threadblocks may process parts of the same column slice, we + // finally have to globally reduce over the results. As the striped + // partitioning minimizes the number of such reductions and our outputs are + // usually rather small, we perform this reduction serially in L2 cache. + auto global_reduce_fp16 = [&](bool first = false, bool last = false) { + // We are very careful here to reduce directly in the output buffer to + // maximize L2 cache utilization in this step. To do this, we write out + // results in FP16 (but still reduce with FP32 compute). + constexpr int active_threads = 32 * thread_n_blocks / 4; + if (threadIdx.x < active_threads) { + int c_gl_stride = prob_n / 8; + int c_gl_wr_delta_o = 8 * c_gl_stride; + int c_gl_wr_delta_i = 4 * (active_threads / 32); + int c_gl_wr; + if constexpr (m_block_size_8) { + c_gl_wr = c_gl_stride * ((threadIdx.x % 4) * 2) + + 4 * (threadIdx.x / 32) + (threadIdx.x % 32) / 8; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + } else { + c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + + 4 * (threadIdx.x / 32) + threadIdx.x % 4; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + } + constexpr int c_sh_wr_delta = active_threads; + auto c_sh_wr = threadIdx.x; + + int row = (threadIdx.x % 32) / 4; + + if (!first) { + // Interestingly, doing direct global accesses here really seems to mess up + // the compiler and lead to slowdowns, hence we also use async-copies even + // though these fetches are not actually asynchronous. + #pragma unroll + for (int i = 0; i < (m_block_size_8 ? 2 : thread_m_blocks * 4); i++) { + if constexpr (m_block_size_8) { + cp_async4_pred(&sh_red[c_sh_wr + c_sh_wr_delta * i], + &C[c_gl_wr + i * c_gl_stride + + (threadIdx.x % 8) / 4 * c_gl_wr_delta_i], + (threadIdx.x % 4) * 2 + i < prob_m); + } else { + cp_async4_pred( + &sh_red[c_sh_wr + c_sh_wr_delta * i], + &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + + c_gl_wr_delta_i * (i % 2)], + i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); + } + } + cp_async_fence(); + cp_async_wait<0>(); + } + + #pragma unroll + for (int i = 0; i < (m_block_size_8 ? 2 : thread_m_blocks * 4); i++) { + bool mask = (!m_block_size_8) && (i < (thread_m_blocks - 1) * 4 || + 8 * (i / 2) + row < prob_m) || + (m_block_size_8) && ((threadIdx.x % 4) * 2 + i < prob_m); + if (mask) { + if (!first) { + int4 c_red = sh_red[c_sh_wr + i * c_sh_wr_delta]; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + int delta = 0; + if constexpr (m_block_size_8) { + delta = j % 2 == 1 ? -2 : 0; + } + reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta] += + Dtype::num2float(reinterpret_cast(&c_red)[j]); + } + } + if (!last) { + int4 c; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + int delta = 0; + if constexpr (m_block_size_8) { + delta = j % 2 == 1 ? -2 : 0; + } + reinterpret_cast(&c)[j] = + Dtype::float2num(reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta]); + } + if constexpr (m_block_size_8) + C[c_gl_wr + i * c_gl_stride + + (threadIdx.x % 8) / 4 * c_gl_wr_delta_i] = c; + else + C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + + c_gl_wr_delta_i * (i % 2)] = c; + } + } + } + } + }; + + // Globally reduce over threadblocks that compute the same column block. + // We use a tmp C buffer to reduce in full fp32 precision. + auto global_reduce_fp32 = [&](bool first = false, bool last = false) { + constexpr int tb_m = thread_m_blocks * 16; + constexpr int tb_n = thread_n_blocks * 16; + + constexpr int c_size = tb_m * tb_n * sizeof(float) / 16; + + constexpr int active_threads = 32 * thread_n_blocks / 4; + bool is_th_active = threadIdx.x < active_threads; + + constexpr int num_floats = thread_m_blocks * 4 * 2 * 4; + constexpr int th_size = num_floats * sizeof(float) / 16; + + int c_cur_offset = locks_off * c_size; + + if (!is_th_active) { + return; + } + + if (!first) { + float* frag_c_ptr = reinterpret_cast(&frag_c); + #pragma unroll + for (int k = 0; k < th_size; k += (m_block_size_8 ? 2 : 1)) { + sh_red[threadIdx.x] = + C_tmp[c_cur_offset + active_threads * k + threadIdx.x]; + + float* sh_c_ptr = reinterpret_cast(&sh_red[threadIdx.x]); + #pragma unroll + for (int f = 0; f < 4; f++) { + frag_c_ptr[k * 4 + f] += sh_c_ptr[f]; + } + } + } + + if (!last) { + int4* frag_c_ptr = reinterpret_cast(&frag_c); + #pragma unroll + for (int k = 0; k < th_size; k += (m_block_size_8 ? 2 : 1)) { + C_tmp[c_cur_offset + active_threads * k + threadIdx.x] = frag_c_ptr[k]; + } + } + }; + + // Write out the reduce final result in the correct layout. We only actually + // reshuffle matrix fragments in this step, the reduction above is performed + // in fragment layout. + auto write_result = [&]() { + int c_gl_stride = prob_n / 8; + constexpr int c_sh_stride = 2 * thread_n_blocks + 1; + int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); + constexpr int c_sh_rd_delta = + c_sh_stride * (threads / (2 * thread_n_blocks)); + + int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + c_gl_wr += (2 * thread_n_blocks) * slice_col; + int c_sh_wr; + if constexpr (m_block_size_8) { + c_sh_wr = (8 * c_sh_stride) * ((threadIdx.x % 32) % 4 * 2) + + (threadIdx.x % 32) / 4; + c_sh_wr += 64 * (threadIdx.x / 32); + } else { + c_sh_wr = + (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; + c_sh_wr += 32 * (threadIdx.x / 32); + } + + int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + + int c_gl_wr_end = c_gl_stride * prob_m; + // We first reorder in shared memory to guarantee the most efficient final + // global write patterns + auto write = [&](int idx, float c0, float c1, FragS& s) { + scalar_t2 res = + Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); + + // For per-column quantization we finally apply the scale here (only for + // 4-bit) + if constexpr (!has_act_order && group_blocks == -1 && + w_type.size_bits() == 4 && + (has_zp && dequant_skip_flop || !has_zp)) { + res = __hmul2(res, s[0]); + } + + if constexpr (w_type == vllm::kFE2M1f) { + res = __hmul2(res, global_scale); + } + + if constexpr (m_block_size_8) { + ((scalar_t*)sh_red)[idx] = res.x; + ((scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y; + } else { + ((scalar_t2*)sh_red)[idx] = res; + } + }; + + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + if constexpr (m_block_size_8) { + int wr = c_sh_wr + 16 * j; + write(wr, frag_c[i][j][0][0], frag_c[i][j][0][1], + frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + 8, frag_c[i][j][0][2], frag_c[i][j][0][3], + frag_s[j / 2][2 * (j % 2) + 1]); + } else { + int wr = c_sh_wr + 8 * j; + write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], + frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], + frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], + frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); + write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], + frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); + } + } + c_sh_wr += 16 * (4 * c_sh_stride); + } + } + __syncthreads(); + + #pragma unroll + for (int i = 0; + i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); + i++) { + if (c_gl_wr < c_gl_wr_end) { + if (use_atomic_add && slice_count > 1) { + scalar_t2* C_half2 = reinterpret_cast(&C[c_gl_wr]); + scalar_t2* sh_red_half2 = + reinterpret_cast(&sh_red[c_sh_rd]); + #pragma unroll + for (int a = 0; a < 4; a++) { + atomicAdd(&C_half2[a], sh_red_half2[a]); + } + } else { + C[c_gl_wr] = sh_red[c_sh_rd]; + } + c_gl_wr += c_gl_wr_delta; + c_sh_rd += c_sh_rd_delta; + } + } + __syncthreads(); + }; + + // Start global fetch and register load pipelines. + auto start_pipes = [&]() { + + #pragma unroll + for (int i = 0; i < stages - 1; i++) { + if (has_act_order && i == 0) { + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + fetch_act_order_scales_to_shared(true, g_idx[slice_k_start], + g_idx[last_g_idx]); + } + + if constexpr (has_zp && !is_zp_float && group_blocks == -1) { + if (i == 0) { + fetch_col_zp_to_shared(); + if constexpr (!dequant_skip_flop) { + fetch_col_scale_to_shared(); + } + } + } + fetch_to_shared(i, i, i < slice_iters); + } + + zero_accums(); + wait_for_stage(); + init_same_group(0); + fetch_to_registers(0, 0); + fetch_scales_to_registers(0, 0); + fetch_zp_to_registers(0, 0); + a_gl_rd += a_gl_rd_delta_o * (stages - 1); + if constexpr (has_act_order) { + slice_k_start_shared_fetch += tb_k * (stages - 1); + } + }; + if (slice_iters) { + start_pipes(); + } + + // Main loop. + while (slice_iters) { + // We unroll over both the global fetch and the register load pipeline to + // ensure all shared memory accesses are static. Note that both pipelines + // have even length meaning that the next iteration will always start at + // index 0. + + #pragma unroll + for (int pipe = 0; pipe < stages;) { + #pragma unroll + for (int k = 0; k < b_sh_wr_iters; k++) { + fetch_to_registers(k + 1, pipe % stages); + fetch_scales_to_registers(k + 1, pipe); + fetch_zp_to_registers(k + 1, pipe); + if (k == b_sh_wr_iters - 2) { + fetch_to_shared((pipe + stages - 1) % stages, pipe, + slice_iters >= stages); + pipe++; + wait_for_stage(); + init_same_group(pipe % stages); + } + matmul(k); + } + slice_iters--; + if (slice_iters == 0) { + break; + } + } + + a_gl_rd += a_gl_rd_delta_o * stages; + + if constexpr (has_act_order) { + slice_k_start += tb_k * stages; + + if (slice_k_start < prob_k) { + slice_k_start_shared_fetch += tb_k * stages; + int first_group_id = g_idx[slice_k_start]; + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + int last_group_id = g_idx[last_g_idx]; + if (last_group_id >= sh_first_group_id + sh_num_groups) { + fetch_act_order_scales_to_shared(false, first_group_id, + last_group_id); + __syncthreads(); + } + } + } + + // Process results and, if necessary, proceed to the next column slice. + // While this pattern may not be the most readable, other ways of writing + // the loop seemed to noticeably worse performance after compilation. + if (slice_iters == 0) { + cp_async_wait<0>(); + bool last = slice_idx == slice_count - 1; + // For per-column scales, we only fetch them here in the final step before + // write-out + if constexpr (!has_act_order && group_blocks == -1 && + (has_zp && dequant_skip_flop || !has_zp)) { + if (w_type.size_bits() == 8 || (last || use_atomic_add)) { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + } + } + + thread_block_reduce(); + if constexpr (!has_act_order && group_blocks == -1 && + (has_zp && dequant_skip_flop || !has_zp)) { + if (w_type.size_bits() == 8 || (last || use_atomic_add)) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + if constexpr (m_block_size_8) { + int idx = (threadIdx.x / 4) % 2; + scalar_t2* frag_s_half2 = reinterpret_cast(frag_s); + #pragma unroll + for (int i = 0; i < 8; i++) { + frag_s_half2[i] = Dtype::num2num2( + reinterpret_cast(&frag_s_half2[i])[idx]); + } + } + } + } + } + + // For 8-bit channelwise, we apply the scale before the global reduction + // that converts the fp32 results to fp16 (so that we avoid possible + // overflow in fp16) + if constexpr (!has_act_order && group_blocks == -1 && + w_type.size_bits() == 8 && + (has_zp && dequant_skip_flop || !has_zp)) { + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + scale_float( + reinterpret_cast(&frag_c[i][j][0][0]), + frag_s[j / 2][2 * (j % 2) + 0]); + scale_float( + reinterpret_cast(&frag_c[i][j][0][2]), + frag_s[j / 2][2 * (j % 2) + (m_block_size_8 ? 1 : 0)]); + + if constexpr (!m_block_size_8) { + scale_float( + reinterpret_cast(&frag_c[i][j][1][0]), + frag_s[j / 2][2 * (j % 2) + 1]); + scale_float( + reinterpret_cast(&frag_c[i][j][1][2]), + frag_s[j / 2][2 * (j % 2) + 1]); + } + } + } + } + } + + if (slice_count > 1 && !use_atomic_add) { + // only globally reduce if there is more than one block in a slice + barrier_acquire(&locks[locks_off], slice_idx); + if (use_fp32_reduce) { + global_reduce_fp32(slice_idx == 0, last); + } else { + global_reduce_fp16(slice_idx == 0, last); + } + barrier_release(&locks[locks_off], last); + } + if (use_atomic_add && slice_count > 1 && slice_idx != 0) + wait_negative_and_add(&locks[locks_off]); + if (last || use_atomic_add) + // only the last block in a slice actually writes the result + write_result(); + slice_row = 0; + slice_col_par++; + slice_col++; + is_first_matmul_in_slice = true; + init_slice(); + + if (slice_iters) { + a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; + if (slice_col == 0) { + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; + } + + // Update slice k/n for scales loading + if constexpr (has_act_order) { + slice_k_start = tb_k * slice_row; + slice_k_finish = slice_k_start + tb_k * slice_iters; + slice_k_start_shared_fetch = slice_k_start; + slice_n_offset = act_s_col_tb_stride * slice_col; + + } else { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + } + + start_pipes(); + } + } + } +} + +} // namespace MARLIN_NAMESPACE_NAME + +#endif diff --git a/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu b/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu index ba0a2410c037c08aa603e11acef03bee0f48ab3c..ea96326ed7e61e95d27d12b26e02eaece0c073a8 100644 --- a/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu +++ b/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu @@ -96,8 +96,8 @@ __device__ inline FragB dequant(int q) { const int HI = 0x00f000f0; const int EX = 0x64006400; // Guarantee that the `(a & b) | c` operations are LOP3s. - int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX); - int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX); + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point // directly into `SUB` and `ADD`. const int SUB = 0x64086408; diff --git a/csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu b/csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu index cd1830764cceb37cba2663640b3c9b735646f418..c96d68d9b29aaa38b27bfe82e9bd18d118595505 100644 --- a/csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu +++ b/csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu @@ -141,8 +141,8 @@ __device__ inline FragB dequant_per_group(int q, FragS_GROUP& frag_s, int i) { static constexpr uint32_t HI = 0x00f000f0; static constexpr uint32_t EX = 0x64006400; // Guarantee that the `(a & b) | c` operations are LOP3s. - uint32_t t0 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX); - uint32_t t1 = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX); + uint32_t t0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + uint32_t t1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point // directly into `SUB` and `ADD`. static constexpr uint32_t SUB = 0x64086408; diff --git a/csrc/quantization/marlin/sparse/common/mma.h b/csrc/quantization/marlin/sparse/common/mma.h index 49eee4128ee7c0a2ab0b96223ac2eeea60ab3ca8..b26505f771c8b1ee5eb8dbbf736a7d2a5d0243ec 100644 --- a/csrc/quantization/marlin/sparse/common/mma.h +++ b/csrc/quantization/marlin/sparse/common/mma.h @@ -127,8 +127,8 @@ __device__ inline FragB dequant_4bit(int q) { const int HI = 0x00f000f0; const int EX = 0x64006400; // Guarantee that the `(a & b) | c` operations are LOP3s. - int lo = lop3 < (0xf0 & 0xcc) | 0xaa > (q, LO, EX); - int hi = lop3 < (0xf0 & 0xcc) | 0xaa > (q, HI, EX); + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point // directly into `SUB` and `ADD`. const int SUB = 0x64086408; diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index 2c3cae95e7f55fc560b4269e78265cbfac8c335a..8cc5a0f4f218683fbc6410a583e64624530443ce 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -25,8 +25,9 @@ #include "../attention/dtype_fp8.cuh" #include "../quantization/fp8/amd/quant_utils.cuh" -#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx942__)) - #define __HIP__MI300_MI250__ +#if defined(__HIPCC__) && \ + (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)) + #define __HIP__GFX9__ #endif #if defined(NDEBUG) @@ -42,7 +43,7 @@ #define MIN(a, b) ((a) < (b) ? (a) : (b)) #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) -#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support +#if defined(__HIP__GFX9__) // TODO: Add NAVI support #define GCN_MFMA_INSTR1 __builtin_amdgcn_mfma_f32_16x16x4f32 #define GCN_MFMA_INSTR __builtin_amdgcn_mfma_f32_4x4x4f16 @@ -1286,7 +1287,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( // max_num_partitions, head_size] const int* __restrict__ context_lens, // [num_seqs] const int* __restrict__ query_start_loc_ptr, // [num_seqs] - const int max_num_partitions) { + const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) { const auto num_heads = gridDim.x; const auto head_idx = blockIdx.x; const auto seq_idx = blockIdx.y; @@ -1464,8 +1465,10 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( const float inv_global_exp_sum = __fdividef(1.0f, shared_global_exp_sum + 1e-6f); + const float out_scale = + (fp8_out_scale_ptr != nullptr) ? 1.0f / (*fp8_out_scale_ptr) : 1.0f; acc *= inv_global_exp_sum; - + acc *= out_scale; const int64_t query_start_off = static_cast( query_start_loc_ptr ? query_start_loc_ptr[seq_idx] : seq_idx); OUTT* out_ptr = out + query_start_off * num_heads * HEAD_SIZE + @@ -1479,7 +1482,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( } } -#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support +#else // !defined(__HIP__GFX9__) TODO: Add NAVI support // clang-format off template \ <<>>( \ out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, \ - context_lens_ptr, query_start_loc_ptr, max_num_partitions); + context_lens_ptr, query_start_loc_ptr, max_num_partitions, \ + fp8_out_scale_ptr); template & query_start_loc, int max_context_len, const std::optional& alibi_slopes, torch::Tensor& k_scale, - torch::Tensor& v_scale) { + torch::Tensor& v_scale, const std::optional& fp8_out_scale) { int num_seqs = block_tables.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -1625,6 +1629,11 @@ void paged_attention_custom_launcher( int* context_lens_ptr = context_lens.data_ptr(); const float* k_scale_ptr = reinterpret_cast(k_scale.data_ptr()); const float* v_scale_ptr = reinterpret_cast(v_scale.data_ptr()); + // NOTE: fp8_out_scale is optional. + const auto fp8_out_scale_ptr = + fp8_out_scale + ? static_cast(fp8_out_scale.value().data_ptr()) + : nullptr; OUTT* out_ptr = reinterpret_cast(out.data_ptr()); const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE); @@ -1735,33 +1744,54 @@ void paged_attention_custom_launcher( } } -#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, PSIZE, \ - ALIBI_ENABLED) \ - paged_attention_custom_launcher( \ - out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ - num_kv_heads, scale, block_tables, context_lens, query_start_loc, \ - max_context_len, alibi_slopes, k_scale, v_scale); - -#define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \ - PSIZE) \ - if (alibi_slopes) { \ - CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, PSIZE, true); \ - } else { \ - CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, PSIZE, false); \ +#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, \ + PSIZE, ALIBI_ENABLED) \ + paged_attention_custom_launcher( \ + out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ + num_kv_heads, scale, block_tables, context_lens, query_start_loc, \ + max_context_len, alibi_slopes, k_scale, v_scale, fp8_out_scale); + +#define CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \ + OUTT, PSIZE) \ + if (alibi_slopes) { \ + CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, \ + true); \ + } else { \ + CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, PSIZE, \ + false); \ } -#define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE) \ - switch (block_size) { \ - case 16: \ - CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, 16, HEAD_SIZE, 256); \ - break; \ - case 32: \ - CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, 32, HEAD_SIZE, 256); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported block size: ", block_size); \ - break; \ +#if defined(__HIPCC__) && defined(__gfx90a__) + #define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \ + if (fp8_out_scale) { \ + TORCH_CHECK(false, "fp8 out scale unsupported for gfx90a"); \ + } else { \ + CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T, \ + 256); \ + } +#else + #define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \ + if (fp8_out_scale) { \ + CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \ + uint8_t, 256); \ + } else { \ + CALL_CUSTOM_LAUNCHER_ALIBI(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T, \ + 256); \ + } +#endif + +#define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE) \ + switch (block_size) { \ + case 16: \ + CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 16, HEAD_SIZE); \ + break; \ + case 32: \ + CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 32, HEAD_SIZE); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ } #define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T, KVT, KV_DTYPE) \ @@ -1794,7 +1824,8 @@ void paged_attention( int64_t block_size, int64_t max_context_len, const std::optional& alibi_slopes, const std::string& kv_cache_dtype, torch::Tensor& k_scale, - torch::Tensor& v_scale) { + torch::Tensor& v_scale, + const std::optional& fp8_out_scale) { // clang-format on const int head_size = query.size(2); if (kv_cache_dtype == "auto") { diff --git a/csrc/rocm/ops.h b/csrc/rocm/ops.h index b90cfdc617afdbef58879ae1ad5ad2ee008878ed..e538197dbcb0423268032262633f32c349705281 100644 --- a/csrc/rocm/ops.h +++ b/csrc/rocm/ops.h @@ -11,14 +11,12 @@ torch::Tensor wvSplitK(at::Tensor& in_a, at::Tensor& in_b, void wvSplitKQ(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, at::Tensor& scale_a, at::Tensor& scale_b, const int64_t CuCount); -void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums, - torch::Tensor& max_logits, torch::Tensor& tmp_out, - torch::Tensor& query, torch::Tensor& key_cache, - torch::Tensor& value_cache, int64_t num_kv_heads, - double scale, torch::Tensor& block_tables, - torch::Tensor& context_lens, - const std::optional& query_start_loc, - int64_t block_size, int64_t max_context_len, - const std::optional& alibi_slopes, - const std::string& kv_cache_dtype, torch::Tensor& k_scale, - torch::Tensor& v_scale); +void paged_attention( + torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, + torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int64_t num_kv_heads, double scale, + torch::Tensor& block_tables, torch::Tensor& context_lens, + const std::optional& query_start_loc, int64_t block_size, + int64_t max_context_len, const std::optional& alibi_slopes, + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale, const std::optional& fp8_out_scale); diff --git a/csrc/rocm/skinny_gemms.cu b/csrc/rocm/skinny_gemms.cu index 72d2820f2aabfdbc8c9eb3c4aa9a014074164e33..b3717892db784a125a8795223abd88c5be06b28e 100644 --- a/csrc/rocm/skinny_gemms.cu +++ b/csrc/rocm/skinny_gemms.cu @@ -126,8 +126,8 @@ __global__ void LLGemm1_kernel(const scalar_t* in_a, const scalar_t* in_b, const int warp = threadIdx.x / WARP_SIZE; const int lane = threadIdx.x % WARP_SIZE; const int num_warps = blockDim.x / WARP_SIZE; - const int qwarpid = threadid / num_warps; - const int qthreadid = threadid % num_warps; + const int qwarpid = threadid / 16; + const int qthreadid = threadid % 16; float4 rowA_elem4[NUM_A_ROWS_PER_BLOCK]; scalar2_t colB_elem4x, colB_elem4y, colB_elem4z, colB_elem4w; float acc[NUM_A_ROWS_PER_BLOCK]; @@ -142,15 +142,13 @@ __global__ void LLGemm1_kernel(const scalar_t* in_a, const scalar_t* in_b, // rowA_elem4[i] holds 8 * half numbers seen as a single float4. rowA_elem4[i] = load_ntmprl(&af4[row_addr + threadid + K / 8 * i]); } + colB_elem4x = bf4[threadid * 4 + 0]; + colB_elem4y = bf4[threadid * 4 + 1]; + colB_elem4z = bf4[threadid * 4 + 2]; + colB_elem4w = bf4[threadid * 4 + 3]; } - colB_elem4x = bf4[threadid * 4 + 0]; - colB_elem4y = bf4[threadid * 4 + 1]; - colB_elem4z = bf4[threadid * 4 + 2]; - colB_elem4w = bf4[threadid * 4 + 3]; - scalar2_t Af2; - [[maybe_unused]] scalar2_t Bf2; float2 S; auto Ah2ptr = reinterpret_cast(&rowA_elem4); @@ -193,12 +191,13 @@ __global__ void LLGemm1_kernel(const scalar_t* in_a, const scalar_t* in_b, if (qwarpid < NUM_A_ROWS_PER_BLOCK) { acc[qwarpid] = qthreadid < num_warps ? red_smem[qwarpid][qthreadid] : 0.f; - for (int mask = num_warps / 2; mask >= 1; mask /= 2) { +#pragma unroll + for (int mask = 16 / 2; mask >= 1; mask /= 2) { acc[qwarpid] += __shfl_xor(acc[qwarpid], mask); } - float oval2 = __shfl_xor(acc[qwarpid], num_warps); + float oval2 = __shfl_xor(acc[qwarpid], 16); - if (lane % (num_warps * 2) == 0) { + if (lane % 32 == 0) { oval = __float22s2_rn(make_float2(acc[qwarpid], oval2)); c[blockIdx.x * NUM_A_ROWS_PER_BLOCK / 2 + qwarpid / 2] = oval; } @@ -222,9 +221,10 @@ torch::Tensor LLMM1(at::Tensor& in_a, at::Tensor& in_b, // NUM_TREADS need to be a multiple of WARP_SIZE, as we are using warp shuffle // operations. const int NUM_THREADS = - K * 2 / 16 % WARP_SIZE == 0 - ? K * 2 / 16 - : K * 2 / 16 + (WARP_SIZE - K * 2 / 16 % WARP_SIZE); + max(rows_per_block * 16, + K * 2 / 16 % WARP_SIZE == 0 + ? K * 2 / 16 + : K * 2 / 16 + (WARP_SIZE - K * 2 / 16 % WARP_SIZE)); int NUM_BLOCKS = M / rows_per_block; @@ -275,13 +275,22 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) wvSplitK_hf_sml_(const int K, const int M, const scalar_t* B, const scalar_t* __restrict__ A, scalar_t* C, const int _WvPrGrp, const int CuCount) { + #if defined(__HIP__MI300__) + constexpr bool use_mfma = (std::is_same_v); + #else + constexpr bool use_mfma = false; + #endif + using scalar8 = __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; + using half4 = + __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(__bf16)))) __bf16; union bigType { scalar_t h[A_CHUNK]; float f[A_CHUNK / 2]; float2 f2[A_CHUNK / 4]; double d[A_CHUNK / 4]; + half4 h4[A_CHUNK / 4]; scalar8 h8; }; @@ -318,6 +327,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) uint32_t m = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE; float sum[N][YTILE]; + scalar8 sum4[N][YTILE]; //---------------------------------------------------- // Each wave works on a single column of weight matrix. @@ -343,7 +353,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // are being worked on by each wave. //---------------------------------------------------- for (int i = 0; i < YTILE; i++) - for (int n = 0; n < N; n++) sum[n][i] = 0; + for (int n = 0; n < N; n++) + if constexpr (!use_mfma) + sum[n][i] = 0; + else + sum4[n][i] = {0, 0, 0, 0}; bigType bigA[N][UNRL]; bigType bigB[YTILE][UNRL]; @@ -374,24 +388,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (k_ >= K) break; const scalar_t* B_ = &B[(m + 0) * K + k_]; - bigB[0][k2].h8 = (loadnt((scalar8*)(&B_[0 * K]))); - //---------------------------------------------------- - // The following code with YTILE > 1 has to be deleted - //---------------------------------------------------- - if constexpr (YTILE >= 2) - bigB[1][k2].h8 = (loadnt((scalar8*)(&B_[1 * K]))); - if constexpr (YTILE >= 3) - bigB[2][k2].h8 = (loadnt((scalar8*)(&B_[2 * K]))); - if constexpr (YTILE >= 4) - bigB[3][k2].h8 = (loadnt((scalar8*)(&B_[3 * K]))); - if constexpr (YTILE >= 5) - bigB[4][k2].h8 = (loadnt((scalar8*)(&B_[4 * K]))); - if constexpr (YTILE >= 6) - bigB[5][k2].h8 = (loadnt((scalar8*)(&B_[5 * K]))); - if constexpr (YTILE >= 7) - bigB[6][k2].h8 = (loadnt((scalar8*)(&B_[6 * K]))); - if constexpr (YTILE >= 8) - bigB[7][k2].h8 = (loadnt((scalar8*)(&B_[7 * K]))); + for (int y = 0; y < YTILE; y++) + bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[y * K]))); } // Fetch activation matrix from either just LDS or from both LDS / memory @@ -419,32 +417,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) #pragma unroll for (uint32_t n = 0; n < N; n++) { #pragma unroll - for (uint32_t b = 0; b < A_CHUNK / 2; b++) { - DOT2C(sum[n][0], bigA[n][k2].f[b], bigB[0][k2].f[b]) - //---------------------------------------------------- - // The following code with YTILE > 1 - //---------------------------------------------------- - if constexpr (YTILE >= 2) { - DOT2C(sum[n][1], bigA[n][k2].f[b], bigB[1][k2].f[b]); - } - if constexpr (YTILE >= 3) { - DOT2C(sum[n][2], bigA[n][k2].f[b], bigB[2][k2].f[b]); - } - if constexpr (YTILE >= 4) { - DOT2C(sum[n][3], bigA[n][k2].f[b], bigB[3][k2].f[b]); - } - if constexpr (YTILE >= 5) { - DOT2C(sum[n][4], bigA[n][k2].f[b], bigB[4][k2].f[b]); - } - if constexpr (YTILE >= 6) { - DOT2C(sum[n][5], bigA[n][k2].f[b], bigB[5][k2].f[b]); - } - if constexpr (YTILE >= 7) { - DOT2C(sum[n][6], bigA[n][k2].f[b], bigB[6][k2].f[b]); - } - if constexpr (YTILE >= 8) { - DOT2C(sum[n][7], bigA[n][k2].f[b], bigB[7][k2].f[b]); - } + for (int y = 0; y < YTILE; y++) { + if constexpr (!use_mfma) + #pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 2; b++) { + DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b]) + } + else + #pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 4; b++) + sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k( + bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0); } } } @@ -453,37 +436,84 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) //---------------------------------------------------- // Final reduction step using shuffle //---------------------------------------------------- - for (int n = 0; n < N; n++) { - for (int y = 0; y < YTILE; y++) { - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + if constexpr (!use_mfma) { + for (int n = 0; n < N; n++) { + for (int y = 0; y < YTILE; y++) { + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + } } - } - if (threadIdx.x == 63) { + + if (threadIdx.x == 63) { + for (int n = 0; n < N; n++) { + for (int i = 0; i < YTILE; i++) { + // if (commitColumn[i]) C[m + i + n * M] = __float2half(sum[n][i]); + C[m + i + n * M] = __float2s(sum[n][i]); + } + } + } + } else { + #pragma unroll for (int n = 0; n < N; n++) { - for (int i = 0; i < YTILE; i++) { - // if (commitColumn[i]) C[m + i + n * M] = __float2half(sum[n][i]); - C[m + i + n * M] = __float2s(sum[n][i]); + #pragma unroll + for (int y = 0; y < YTILE; y++) { + // float accm1 = 0; + // for (int i=0; i<64; i++) + // accm1 += __shfl(sum4[n][y][i%4], i); + float accm = sum4[n][y][0]; + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[n][y][1]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[n][y][2]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[n][y][3]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + asm("s_nop 0\n\tv_mov_b32 %0, %2 row_shr:15 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + + sum4[n][y][0] = accm; + } + } + if (threadIdx.x == 63) { + for (int n = 0; n < N; n++) { + for (int i = 0; i < YTILE; i++) { + // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[n][i]); + C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]); + } } } } - m += CuCount * _WvPrGrp * YTILE; } } @@ -505,13 +535,22 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) wvSplitK_hf_(const int K, const int M, const scalar_t* B, const scalar_t* __restrict__ A, scalar_t* C, const int _WvPrGrp, const int CuCount) { + #if defined(__HIP__MI300__) + constexpr bool use_mfma = (std::is_same_v); + #else + constexpr bool use_mfma = false; + #endif + using scalar8 = __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; + using half4 = + __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(__bf16)))) __bf16; union bigType { scalar_t h[A_CHUNK]; float f[A_CHUNK / 2]; float2 f2[A_CHUNK / 4]; double d[A_CHUNK / 4]; + half4 h4[A_CHUNK / 4]; scalar8 h8; }; @@ -573,6 +612,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (threadIdx.y >= _WvPrGrp) return; float sum[N][YTILE]; + scalar8 sum4[N][YTILE]; //---------------------------------------------------- // Each wave works on a single column of weight matrix. @@ -598,7 +638,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // are being worked on by each wave. //---------------------------------------------------- for (int i = 0; i < YTILE; i++) - for (int n = 0; n < N; n++) sum[n][i] = 0; + for (int n = 0; n < N; n++) + if constexpr (!use_mfma) + sum[n][i] = 0; + else + sum4[n][i] = {0, 0, 0, 0}; bigType bigA[N][UNRL]; bigType bigB[YTILE][UNRL]; @@ -628,24 +672,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (k_ >= K) break; const scalar_t* B_ = &B[(m + 0) * K + k_]; - bigB[0][k2].h8 = (loadnt((scalar8*)(&B_[0 * K]))); - //---------------------------------------------------- - // The following code with YTILE > 1 has to be deleted - //---------------------------------------------------- - if constexpr (YTILE >= 2) - bigB[1][k2].h8 = (loadnt((scalar8*)(&B_[1 * K]))); - if constexpr (YTILE >= 3) - bigB[2][k2].h8 = (loadnt((scalar8*)(&B_[2 * K]))); - if constexpr (YTILE >= 4) - bigB[3][k2].h8 = (loadnt((scalar8*)(&B_[3 * K]))); - if constexpr (YTILE >= 5) - bigB[4][k2].h8 = (loadnt((scalar8*)(&B_[4 * K]))); - if constexpr (YTILE >= 6) - bigB[5][k2].h8 = (loadnt((scalar8*)(&B_[5 * K]))); - if constexpr (YTILE >= 7) - bigB[6][k2].h8 = (loadnt((scalar8*)(&B_[6 * K]))); - if constexpr (YTILE >= 8) - bigB[7][k2].h8 = (loadnt((scalar8*)(&B_[7 * K]))); + for (int b = 0; b < YTILE; b++) + bigB[b][k2].h8 = (loadnt((scalar8*)(&B_[b * K]))); } // Fetch activation matrix from either just LDS or from both LDS / memory @@ -676,32 +704,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // Do the matrix multiplication of activation and weight matrix // - Remember the accumulation is happening for K-split of 64! #pragma unroll - for (uint32_t b = 0; b < A_CHUNK / 2; b++) { - DOT2C(sum[n][0], bigA[n][k2].f[b], bigB[0][k2].f[b]); - //---------------------------------------------------- - // The following code with YTILE > 1 - //---------------------------------------------------- - if constexpr (YTILE >= 2) { - DOT2C(sum[n][1], bigA[n][k2].f[b], bigB[1][k2].f[b]); - } - if constexpr (YTILE >= 3) { - DOT2C(sum[n][2], bigA[n][k2].f[b], bigB[2][k2].f[b]); - } - if constexpr (YTILE >= 4) { - DOT2C(sum[n][3], bigA[n][k2].f[b], bigB[3][k2].f[b]); - } - if constexpr (YTILE >= 5) { - DOT2C(sum[n][4], bigA[n][k2].f[b], bigB[4][k2].f[b]); - } - if constexpr (YTILE >= 6) { - DOT2C(sum[n][5], bigA[n][k2].f[b], bigB[5][k2].f[b]); - } - if constexpr (YTILE >= 7) { - DOT2C(sum[n][6], bigA[n][k2].f[b], bigB[6][k2].f[b]); - } - if constexpr (YTILE >= 8) { - DOT2C(sum[n][7], bigA[n][k2].f[b], bigB[7][k2].f[b]); - } + for (int y = 0; y < YTILE; y++) { + if constexpr (!use_mfma) + #pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 2; b++) { + DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b]) + } + else + #pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 4; b++) + sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k( + bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0); } } } @@ -710,34 +723,82 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) //---------------------------------------------------- // Final reduction step using shuffle //---------------------------------------------------- - for (int n = 0; n < N; n++) { - for (int y = 0; y < YTILE; y++) { - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + if constexpr (!use_mfma) { + for (int n = 0; n < N; n++) { + for (int y = 0; y < YTILE; y++) { + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + } } - } - if (threadIdx.x == 63) { + if (threadIdx.x == 63) { + for (int n = 0; n < N; n++) { + for (int i = 0; i < YTILE; i++) { + if (commitColumn[i]) + C[m + i + n * M] = __float2s(sum[n][i]); + } + } + } + } else { + #pragma unroll for (int n = 0; n < N; n++) { - for (int i = 0; i < YTILE; i++) { - if (commitColumn[i]) - C[m + i + n * M] = __float2s(sum[n][i]); + #pragma unroll + for (int y = 0; y < YTILE; y++) { + // float accm1 = 0; + // for (int i=0; i<64; i++) + // accm1 += __shfl(sum4[n][y][i%4], i); + + float accm = sum4[n][y][0]; + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[n][y][1]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[n][y][2]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[n][y][3]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + asm("s_nop 0\n\tv_mov_b32 %0, %2 row_shr:15 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + + sum4[n][y][0] = accm; + } + } + if (threadIdx.x == 63) { + for (int n = 0; n < N; n++) { + for (int i = 0; i < YTILE; i++) { + // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[n][i]); + C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]); + } } } } @@ -774,14 +835,22 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) wvSplitK_hf_big_(const int K, const int M, const scalar_t* B, const scalar_t* __restrict__ A, scalar_t* C, const int _WvPrGrp, const int CuCount) { + #if defined(__HIP__MI300__) + constexpr bool use_mfma = (std::is_same_v); + #else + constexpr bool use_mfma = false; + #endif + using scalar8 = __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; - + using half4 = + __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(__bf16)))) __bf16; union bigType { scalar_t h[A_CHUNK]; float f[A_CHUNK / 2]; float2 f2[A_CHUNK / 4]; double d[A_CHUNK / 4]; + half4 h4[A_CHUNK / 4]; scalar8 h8; }; @@ -857,6 +926,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) kFit = min(kFit, K); float sum[N][YTILE]; + scalar8 sum4[N][YTILE]; //---------------------------------------------------- // Each wave works on a single column of weight matrix. @@ -888,7 +958,11 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // are being worked on by each wave. //---------------------------------------------------- for (int i = 0; i < YTILE; i++) - for (int n = 0; n < N; n++) sum[n][i] = 0; + for (int n = 0; n < N; n++) + if constexpr (!use_mfma) + sum[n][i] = 0; + else + sum4[n][i] = {0, 0, 0, 0}; bigType bigA[N][UNRL]; bigType bigB[YTILE][UNRL]; @@ -937,24 +1011,8 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if (k_ >= K) break; const scalar_t* B_ = &B[(m + 0) * K + k_]; - bigB[0][k2].h8 = (loadnt((scalar8*)(&B_[0 * K]))); - //---------------------------------------------------- - // The following code with YTILE > 1 has to be deleted - //---------------------------------------------------- - if constexpr (YTILE >= 2) - bigB[1][k2].h8 = (loadnt((scalar8*)(&B_[1 * K]))); - if constexpr (YTILE >= 3) - bigB[2][k2].h8 = (loadnt((scalar8*)(&B_[2 * K]))); - if constexpr (YTILE >= 4) - bigB[3][k2].h8 = (loadnt((scalar8*)(&B_[3 * K]))); - if constexpr (YTILE >= 5) - bigB[4][k2].h8 = (loadnt((scalar8*)(&B_[4 * K]))); - if constexpr (YTILE >= 6) - bigB[5][k2].h8 = (loadnt((scalar8*)(&B_[5 * K]))); - if constexpr (YTILE >= 7) - bigB[6][k2].h8 = (loadnt((scalar8*)(&B_[6 * K]))); - if constexpr (YTILE >= 8) - bigB[7][k2].h8 = (loadnt((scalar8*)(&B_[7 * K]))); + for (int b = 0; b < YTILE; b++) + bigB[b][k2].h8 = (loadnt((scalar8*)(&B_[b * K]))); } // Fetch activation matrix from either just LDS or from both LDS / memory @@ -989,32 +1047,17 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // Do the matrix multiplication of activation and weight matrix // - Remember the accumulation is happening for K-split of 64! #pragma unroll - for (uint32_t b = 0; b < A_CHUNK / 2; b++) { - DOT2C(sum[n][0], bigA[n][k2].f[b], bigB[0][k2].f[b]); - //---------------------------------------------------- - // The following code with YTILE > 1 - //---------------------------------------------------- - if constexpr (YTILE >= 2) { - DOT2C(sum[n][1], bigA[n][k2].f[b], bigB[1][k2].f[b]); - } - if constexpr (YTILE >= 3) { - DOT2C(sum[n][2], bigA[n][k2].f[b], bigB[2][k2].f[b]); - } - if constexpr (YTILE >= 4) { - DOT2C(sum[n][3], bigA[n][k2].f[b], bigB[3][k2].f[b]); - } - if constexpr (YTILE >= 5) { - DOT2C(sum[n][4], bigA[n][k2].f[b], bigB[4][k2].f[b]); - } - if constexpr (YTILE >= 6) { - DOT2C(sum[n][5], bigA[n][k2].f[b], bigB[5][k2].f[b]); - } - if constexpr (YTILE >= 7) { - DOT2C(sum[n][6], bigA[n][k2].f[b], bigB[6][k2].f[b]); - } - if constexpr (YTILE >= 8) { - DOT2C(sum[n][7], bigA[n][k2].f[b], bigB[7][k2].f[b]); - } + for (int y = 0; y < YTILE; y++) { + if constexpr (!use_mfma) + #pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 2; b++) { + DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b]) + } + else + #pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 4; b++) + sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k( + bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0); } } } @@ -1031,34 +1074,78 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) //---------------------------------------------------- // Final reduction step using shuffle //---------------------------------------------------- - for (int n = 0; n < N; n++) { - for (int y = 0; y < YTILE; y++) { - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + if constexpr (!use_mfma) { + for (int n = 0; n < N; n++) { + for (int y = 0; y < YTILE; y++) { + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" + : "=v"(sum[n][y]) + : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + } } - } - if (threadIdx.x == 63) { + if (threadIdx.x == 63) { + for (int n = 0; n < N; n++) { + for (int i = 0; i < YTILE; i++) { + if (commitColumn[i]) + C[m + i + n * M] = __float2s(sum[n][i]); + } + } + } + } else { + #pragma unroll for (int n = 0; n < N; n++) { - for (int i = 0; i < YTILE; i++) { - if (commitColumn[i]) - C[m + i + n * M] = __float2s(sum[n][i]); + #pragma unroll + for (int y = 0; y < YTILE; y++) { + float accm = sum4[n][y][0]; + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[n][y][1]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[n][y][2]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(sum4[n][y][3]), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + asm("s_nop 0\n\tv_mov_b32 %0, %2 row_shr:15 bound_ctrl:0 " + : "=v"(accm) + : "0"(accm), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" + : "=v"(accm) + : "0"(accm), "v"(accm), "v"(accm)); + + sum4[n][y][0] = accm; + } + } + if (threadIdx.x == 63) { + for (int n = 0; n < N; n++) { + for (int i = 0; i < YTILE; i++) { + // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[n][i]); + C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]); + } } } } diff --git a/csrc/rocm/torch_bindings.cpp b/csrc/rocm/torch_bindings.cpp index 4ac6fd1e994081a5a88e2b6ac28bb1b45f07fdb6..34575477bcc94346f2c31268e40a2103053a9b9e 100644 --- a/csrc/rocm/torch_bindings.cpp +++ b/csrc/rocm/torch_bindings.cpp @@ -47,7 +47,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) { " int max_context_len," " Tensor? alibi_slopes," " str kv_cache_dtype," - " Tensor k_scale, Tensor v_scale) -> ()"); + " Tensor k_scale, Tensor v_scale," + " Tensor? fp8_out_scale) -> ()"); rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention); } diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 8025d926bcadabc28e3844cbb6634a4b299568cf..9259b65a90ba13775e92a647b85db1ce85139aa9 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -77,13 +77,40 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor suffix_output," " Tensor suffix_lse) -> ()"); ops.impl("merge_attn_states", torch::kCUDA, &merge_attn_states); + + ops.def( + "convert_vertical_slash_indexes(" + " Tensor! block_count, Tensor! block_offset, " + " Tensor! column_count, Tensor! column_index, " + " Tensor q_seqlens, Tensor q_seqlens, " + " Tensor vertical_indexes, Tensor slash_indexes, " + " int context_size, int block_size_M, int block_size_N, " + " bool causal) -> ()"); + ops.impl("convert_vertical_slash_indexes", torch::kCUDA, + &convert_vertical_slash_indexes); + + ops.def( + "convert_vertical_slash_indexes_mergehead(" + " Tensor! block_count, Tensor! block_offset, " + " Tensor! column_count, Tensor! column_index, " + " Tensor q_seqlens, Tensor q_seqlens, " + " Tensor vertical_indexes, Tensor slash_indexes, " + " Tensor vertical_indices_count, Tensor slash_indices_count, " + " int context_size, int block_size_M, int block_size_N, " + " bool causal) -> ()"); + ops.impl("convert_vertical_slash_indexes_mergehead", torch::kCUDA, + &convert_vertical_slash_indexes_mergehead); #endif // Activation ops // Activation function used in SwiGLU. - ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()"); + ops.def("silu_and_mul(Tensor! result, Tensor input) -> ()"); ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul); + ops.def( + "silu_and_mul_quant(Tensor! result, Tensor input, Tensor scale) -> ()"); + ops.impl("silu_and_mul_quant", torch::kCUDA, &silu_and_mul_quant); + ops.def("mul_and_silu(Tensor! out, Tensor input) -> ()"); ops.impl("mul_and_silu", torch::kCUDA, &mul_and_silu); @@ -130,13 +157,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ") -> ()"); ops.impl("advance_step_flashinfer", torch::kCUDA, &advance_step_flashinfer); - // Compute MLA decode using cutlass. -// ops.def( -// "cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe," -// " Tensor kv_c_and_k_pe_cache, Tensor seq_lens," -// " Tensor page_table, float scale) -> ()"); -// ops.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode); - // Layernorm // Apply Root Mean Square (RMS) Normalization to the input tensor. ops.def( @@ -179,7 +199,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // Apply GPT-NeoX or GPT-J style rotary embedding to query and key. ops.def( "rotary_embedding(Tensor positions, Tensor! query," - " Tensor! key, int head_size," + " Tensor!? key, int head_size," " Tensor cos_sin_cache, bool is_neox) -> ()"); ops.impl("rotary_embedding", torch::kCUDA, &rotary_embedding); @@ -187,7 +207,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // (supports multiple loras). ops.def( "batched_rotary_embedding(Tensor positions, Tensor! query," - " Tensor! key, int head_size," + " Tensor!? key, int head_size," " Tensor cos_sin_cache, bool is_neox," " int rot_dim," " Tensor cos_sin_cache_offsets) -> ()"); @@ -298,12 +318,11 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // gptq_marlin Optimized Quantized GEMM for GPTQ. ops.def( - "gptq_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, " - "Tensor b_zeros, Tensor g_idx, Tensor perm, Tensor workspace, " - "int b_q_type, " + "gptq_marlin_gemm(Tensor a, Tensor? c_or_none, Tensor b_q_weight, " + "Tensor b_scales, Tensor? global_scale, Tensor? b_zeros_or_none, Tensor? " + "g_idx_or_none, Tensor? perm_or_none, Tensor workspace, int b_q_type, " "SymInt size_m, SymInt size_n, SymInt size_k, bool is_k_full, " - "bool has_zp, bool use_atomic_add, bool use_fp32_reduce, " - "bool is_zp_float) -> Tensor", + "bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) -> Tensor", {stride_tag}); // conditionally compiled so impl registration is in source file @@ -345,17 +364,15 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "int type, SymInt row, SymInt top_k, SymInt tokens) -> Tensor"); ops.impl("ggml_moe_a8", torch::kCUDA, &ggml_moe_a8); + ops.def( + "ggml_moe_a8_vec(Tensor X, Tensor W, " + "Tensor topk_ids, int top_k, " + "int type, SymInt row, SymInt tokens) -> Tensor"); + ops.impl("ggml_moe_a8_vec", torch::kCUDA, &ggml_moe_a8_vec); + ops.def("ggml_moe_get_block_size", &ggml_moe_get_block_size); #ifndef USE_ROCM - // fp8_marlin Optimized Quantized GEMM for FP8 weight-only. - ops.def( - "fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, " - "Tensor! workspace, int num_bits, SymInt size_m, SymInt size_n, " - "SymInt size_k) -> Tensor", - {stride_tag}); - // conditionally compiled so impl registration is in source file - // marlin_qqq_gemm for QQQ. ops.def( "marlin_qqq_gemm(Tensor a, Tensor b_q_weight, " @@ -373,6 +390,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { {stride_tag}); ops.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm); + // cutlass nvfp4 block scaled group GEMM + ops.def( + "cutlass_fp4_group_mm(Tensor! out, Tensor a, Tensor b," + " Tensor a_blockscale, Tensor b_blockscales, Tensor alphas," + " Tensor problem_sizes, Tensor expert_offsets, Tensor sf_offsets) -> ()", + {stride_tag}); + ops.impl("cutlass_fp4_group_mm", torch::kCUDA, &cutlass_fp4_group_mm); + // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column // quantization, as well as bias ops.def( @@ -454,6 +479,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("cutlass_sparse_compress(Tensor a) -> Tensor[]"); ops.impl("cutlass_sparse_compress", &cutlass_sparse_compress); + // CUTLASS MLA decode + ops.def( + "cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe," + " Tensor kv_c_and_k_pe_cache, Tensor seq_lens," + " Tensor page_table, float scale) -> ()"); + ops.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode); + // Mamba selective scan kernel ops.def( "selective_scan_fwd(Tensor! u, Tensor! delta," @@ -495,6 +527,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor! output_scale, Tensor input_scale) -> ()"); ops.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant); + // Compute NVFP4 experts quantization. + ops.def( + "scaled_fp4_experts_quant(Tensor! output, Tensor! output_scale," + "Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts," + "Tensor output_scale_offset_by_experts) -> ()"); + ops.impl("scaled_fp4_experts_quant", torch::kCUDA, &scaled_fp4_experts_quant); + // Check if cutlass_scaled_mm_fp4 is supported for CUDA devices // of the given capability ops.def("cutlass_scaled_mm_supports_fp4(int cuda_device_capability) -> bool"); diff --git a/docker/Dockerfile b/docker/Dockerfile index 1b28845d0ac04ac0b277360c58f82a0e6b1cf347..97a7879da8767c2935074df4c65df9fc1b11514a 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -5,11 +5,11 @@ # docs/source/contributing/dockerfile/dockerfile.md and # docs/source/assets/contributing/dockerfile-stages-dependency.png -ARG CUDA_VERSION=12.4.1 +ARG CUDA_VERSION=12.8.1 #################### BASE BUILD IMAGE #################### # prepare basic build environment FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04 AS base -ARG CUDA_VERSION=12.4.1 +ARG CUDA_VERSION=12.8.1 ARG PYTHON_VERSION=3.12 ARG TARGETPLATFORM ENV DEBIAN_FRONTEND=noninteractive @@ -19,7 +19,10 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \ && echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \ && apt-get update -y \ && apt-get install -y ccache software-properties-common git curl sudo \ - && add-apt-repository ppa:deadsnakes/ppa \ + && for i in 1 2 3; do \ + add-apt-repository -y ppa:deadsnakes/ppa && break || \ + { echo "Attempt $i failed, retrying in 5s..."; sleep 5; }; \ + done \ && apt-get update -y \ && apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \ && update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \ @@ -34,6 +37,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \ # This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out # Reference: https://github.com/astral-sh/uv/pull/1694 ENV UV_HTTP_TIMEOUT=500 +ENV UV_INDEX_STRATEGY="unsafe-best-match" # Upgrade to GCC 10 to avoid https://gcc.gnu.org/bugzilla/show_bug.cgi?id=92519 # as it was causing spam when compiling the CUTLASS kernels @@ -66,13 +70,14 @@ RUN --mount=type=cache,target=/root/.cache/uv \ COPY requirements/common.txt requirements/common.txt COPY requirements/cuda.txt requirements/cuda.txt RUN --mount=type=cache,target=/root/.cache/uv \ - uv pip install --system -r requirements/cuda.txt + uv pip install --system -r requirements/cuda.txt \ + --extra-index-url https://download.pytorch.org/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') # cuda arch list used by torch # can be useful for both `dev` and `test` # explicitly set the list to avoid issues with torch 2.2 # see https://github.com/pytorch/pytorch/pull/123243 -ARG torch_cuda_arch_list='7.0 7.5 8.0 8.6 8.9 9.0+PTX' +ARG torch_cuda_arch_list='7.0 7.5 8.0 8.9 9.0 10.0+PTX' ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list} # Override the arch list for flash-attn to reduce the binary size ARG vllm_fa_cmake_gpu_arches='80-real;90-real' @@ -89,9 +94,11 @@ COPY requirements/build.txt requirements/build.txt # This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out # Reference: https://github.com/astral-sh/uv/pull/1694 ENV UV_HTTP_TIMEOUT=500 +ENV UV_INDEX_STRATEGY="unsafe-best-match" RUN --mount=type=cache,target=/root/.cache/uv \ - uv pip install --system -r requirements/build.txt + uv pip install --system -r requirements/build.txt \ + --extra-index-url https://download.pytorch.org/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') COPY . . ARG GIT_REPO_CHECK=0 @@ -158,22 +165,25 @@ FROM base as dev # This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out # Reference: https://github.com/astral-sh/uv/pull/1694 ENV UV_HTTP_TIMEOUT=500 +ENV UV_INDEX_STRATEGY="unsafe-best-match" + +# Workaround for #17068 +RUN --mount=type=cache,target=/root/.cache/uv \ + uv pip install --system --no-build-isolation "git+https://github.com/state-spaces/mamba@v2.2.4" COPY requirements/lint.txt requirements/lint.txt COPY requirements/test.txt requirements/test.txt COPY requirements/dev.txt requirements/dev.txt -# Workaround for #17068 RUN --mount=type=cache,target=/root/.cache/uv \ - uv pip install --system mamba-ssm==2.2.4 --no-build-isolation -RUN --mount=type=cache,target=/root/.cache/uv \ - uv pip install --system -r requirements/dev.txt + uv pip install --system -r requirements/dev.txt \ + --extra-index-url https://download.pytorch.org/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') #################### DEV IMAGE #################### #################### vLLM installation IMAGE #################### # image with vLLM installed # TODO: Restore to base image after FlashInfer AOT wheel fixed FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04 AS vllm-base -ARG CUDA_VERSION=12.4.1 +ARG CUDA_VERSION=12.8.1 ARG PYTHON_VERSION=3.12 WORKDIR /vllm-workspace ENV DEBIAN_FRONTEND=noninteractive @@ -188,7 +198,10 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \ && apt-get update -y \ && apt-get install -y ccache software-properties-common git curl wget sudo vim python3-pip \ && apt-get install -y ffmpeg libsm6 libxext6 libgl1 \ - && add-apt-repository ppa:deadsnakes/ppa \ + && for i in 1 2 3; do \ + add-apt-repository -y ppa:deadsnakes/ppa && break || \ + { echo "Attempt $i failed, retrying in 5s..."; sleep 5; }; \ + done \ && apt-get update -y \ && apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv libibverbs-dev \ && update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \ @@ -203,6 +216,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \ # This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out # Reference: https://github.com/astral-sh/uv/pull/1694 ENV UV_HTTP_TIMEOUT=500 +ENV UV_INDEX_STRATEGY="unsafe-best-match" # Workaround for https://github.com/openai/triton/issues/2507 and # https://github.com/pytorch/pytorch/issues/107960 -- hopefully @@ -223,7 +237,8 @@ RUN --mount=type=cache,target=/root/.cache/uv \ # Install vllm wheel first, so that torch etc will be installed. RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \ --mount=type=cache,target=/root/.cache/uv \ - uv pip install --system dist/*.whl --verbose + uv pip install --system dist/*.whl --verbose \ + --extra-index-url https://download.pytorch.org/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') # If we need to build FlashInfer wheel before its release: # $ export FLASHINFER_ENABLE_AOT=1 @@ -240,19 +255,32 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist RUN --mount=type=cache,target=/root/.cache/uv \ . /etc/environment && \ if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \ - uv pip install --system https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.1.post2/flashinfer_python-0.2.1.post2+cu124torch2.6-cp38-abi3-linux_x86_64.whl ; \ + # uv pip install --system https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.4/flashinfer_python-0.2.4+cu124torch2.6-cp38-abi3-linux_x86_64.whl ; \ + # TESTING: install FlashInfer from source to test 2.7.0 final RC + if [[ "$CUDA_VERSION" == 12.8* ]]; then \ + export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0 10.0+PTX'; \ + else \ + export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0+PTX'; \ + fi && \ + export FLASHINFER_ENABLE_AOT=1; \ + uv pip install --system --no-build-isolation "git+https://github.com/flashinfer-ai/flashinfer@21ea1d2545f74782b91eb8c08fd503ac4c0743fc" ; \ fi COPY examples examples COPY benchmarks benchmarks COPY ./vllm/collect_env.py . +RUN --mount=type=cache,target=/root/.cache/uv \ +. /etc/environment && \ +uv pip list + # Although we build Flashinfer with AOT mode, there's still # some issues w.r.t. JIT compilation. Therefore we need to # install build dependencies for JIT compilation. # TODO: Remove this once FlashInfer AOT wheel is fixed COPY requirements/build.txt requirements/build.txt RUN --mount=type=cache,target=/root/.cache/uv \ - uv pip install --system -r requirements/build.txt + uv pip install --system -r requirements/build.txt \ + --extra-index-url https://download.pytorch.org/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') #################### vLLM installation IMAGE #################### @@ -266,11 +294,13 @@ ADD . /vllm-workspace/ # This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out # Reference: https://github.com/astral-sh/uv/pull/1694 ENV UV_HTTP_TIMEOUT=500 +ENV UV_INDEX_STRATEGY="unsafe-best-match" -# install development dependencies (for testing) # Workaround for #17068 RUN --mount=type=cache,target=/root/.cache/uv \ - uv pip install --system mamba-ssm==2.2.4 --no-build-isolation + uv pip install --system --no-build-isolation "git+https://github.com/state-spaces/mamba@v2.2.4" + +# install development dependencies (for testing) RUN --mount=type=cache,target=/root/.cache/uv \ uv pip install --system -r requirements/dev.txt diff --git a/docker/Dockerfile.nightly_torch b/docker/Dockerfile.nightly_torch index 0063712e47818fa08cca5ae350024140961d9fd5..53b8ccd804924da807e1e378b62eb8f9428840cf 100644 --- a/docker/Dockerfile.nightly_torch +++ b/docker/Dockerfile.nightly_torch @@ -16,7 +16,10 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \ && echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \ && apt-get update -y \ && apt-get install -y ccache software-properties-common git curl sudo \ - && add-apt-repository ppa:deadsnakes/ppa \ + && for i in 1 2 3; do \ + add-apt-repository -y ppa:deadsnakes/ppa && break || \ + { echo "Attempt $i failed, retrying in 5s..."; sleep 5; }; \ + done \ && apt-get update -y \ && apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \ && update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \ @@ -197,7 +200,10 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \ && apt-get update -y \ && apt-get install -y ccache software-properties-common git curl wget sudo vim python3-pip \ && apt-get install -y ffmpeg libsm6 libxext6 libgl1 \ - && add-apt-repository ppa:deadsnakes/ppa \ + && for i in 1 2 3; do \ + add-apt-repository -y ppa:deadsnakes/ppa && break || \ + { echo "Attempt $i failed, retrying in 5s..."; sleep 5; }; \ + done \ && apt-get update -y \ && apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv libibverbs-dev \ && update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \ @@ -303,5 +309,7 @@ ENV HF_HUB_ENABLE_HF_TRANSFER 1 RUN --mount=type=cache,target=/root/.cache/uv \ uv pip install --system -r requirements/nightly_torch_test.txt -#################### UNITTEST IMAGE ############################# +# Logging to confirm the torch versions +RUN pip freeze | grep -E 'torch|xformers|vllm|flashinfer' +#################### UNITTEST IMAGE ############################# diff --git a/docker/Dockerfile.ppc64le b/docker/Dockerfile.ppc64le index ec979227871c63a1c0bb7a345509e87caebddba9..14043eb7a8e3ba6ebca87b638e59525f83117d6f 100644 --- a/docker/Dockerfile.ppc64le +++ b/docker/Dockerfile.ppc64le @@ -21,12 +21,8 @@ ENV UV_LINK_MODE=copy # Note: A dummy file 'control' is created in /tmp/ to artificially create dependencies between stages when building stages in parallel # when `--jobs=` is passed with podman build command RUN microdnf install -y openssl-devel dnf \ - && dnf install -y https://mirror.stream.centos.org/9-stream/BaseOS/`arch`/os/Packages/centos-gpg-keys-9.0-24.el9.noarch.rpm \ - https://mirror.stream.centos.org/9-stream/BaseOS/`arch`/os/Packages/centos-stream-repos-9.0-24.el9.noarch.rpm \ - https://dl.fedoraproject.org/pub/epel/epel-release-latest-9.noarch.rpm \ - && dnf config-manager --add-repo https://mirror.stream.centos.org/9-stream/BaseOS/`arch`/os \ - && dnf config-manager --add-repo https://mirror.stream.centos.org/9-stream/AppStream/`arch`/os \ - && dnf config-manager --set-enabled crb \ + && dnf install -y https://dl.fedoraproject.org/pub/epel/epel-release-latest-9.noarch.rpm \ + && dnf config-manager --set-enabled codeready-builder-for-rhel-9-ppc64le-rpms \ && dnf install -y \ git tar gcc-toolset-13 automake libtool numactl-devel lapack-devel \ pkgconfig xsimd zeromq-devel kmod findutils protobuf* \ diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm index f9ebb10ca8731ba318f416d7ed61e3250d26437c..e60cf5e69a4c46d258859d7700e74a66201ab87d 100644 --- a/docker/Dockerfile.rocm +++ b/docker/Dockerfile.rocm @@ -114,8 +114,16 @@ COPY --from=export_vllm /examples ${COMMON_WORKDIR}/vllm/examples ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1 ENV TOKENIZERS_PARALLELISM=false +# ENV that can improve safe tensor loading, and end-to-end time +ENV SAFETENSORS_FAST_GPU=1 + +# User-friendly environment setting for multi-processing to avoid below RuntimeError. +# RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, +# you must use the 'spawn' start method +# See https://pytorch.org/docs/stable/notes/multiprocessing.html#cuda-in-multiprocessing +ENV VLLM_WORKER_MULTIPROC_METHOD=spawn + # Performance environment variable. ENV HIP_FORCE_DEV_KERNARG=1 CMD ["/bin/bash"] - diff --git a/docker/Dockerfile.rocm_base b/docker/Dockerfile.rocm_base index 1776b26d445ce0b9b404f3860b93782686e352cd..222b9c158e5e0dec2bcaae06fc642c45b4d6164d 100644 --- a/docker/Dockerfile.rocm_base +++ b/docker/Dockerfile.rocm_base @@ -12,7 +12,7 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git" ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git" ARG FA_BRANCH="1a7f4dfa" ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git" -ARG AITER_BRANCH="7e1ed08" +ARG AITER_BRANCH="5a77249" ARG AITER_REPO="https://github.com/ROCm/aiter.git" FROM ${BASE_IMAGE} AS base @@ -32,7 +32,10 @@ ENV DEBIAN_FRONTEND=noninteractive # Install Python and other dependencies RUN apt-get update -y \ && apt-get install -y software-properties-common git curl sudo vim less libgfortran5 \ - && add-apt-repository ppa:deadsnakes/ppa \ + && for i in 1 2 3; do \ + add-apt-repository -y ppa:deadsnakes/ppa && break || \ + { echo "Attempt $i failed, retrying in 5s..."; sleep 5; }; \ + done \ && apt-get update -y \ && apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \ python${PYTHON_VERSION}-lib2to3 python-is-python3 \ diff --git a/docker/Dockerfile.s390x b/docker/Dockerfile.s390x index 128929ac333113fc94c67d160d7c96180ea9af4f..9c10cd56b5949929c8a182a6c4caf447d752e7bf 100644 --- a/docker/Dockerfile.s390x +++ b/docker/Dockerfile.s390x @@ -16,7 +16,7 @@ ENV LANG=C.UTF-8 \ RUN microdnf install -y \ which procps findutils tar vim git gcc gcc-gfortran g++ make patch zlib-devel \ libjpeg-turbo-devel libtiff-devel libpng-devel libwebp-devel freetype-devel harfbuzz-devel \ - openssl-devel openblas openblas-devel autoconf automake libtool cmake && \ + openssl-devel openblas openblas-devel autoconf automake libtool cmake numpy && \ microdnf clean all # Python Installation @@ -123,6 +123,7 @@ ENV UV_LINK_MODE=copy ENV CARGO_HOME=/root/.cargo ENV RUSTUP_HOME=/root/.rustup ENV PATH="$CARGO_HOME/bin:$RUSTUP_HOME/bin:$PATH" +ENV GRPC_PYTHON_BUILD_SYSTEM_OPENSSL=1 COPY . /workspace/vllm WORKDIR /workspace/vllm diff --git a/docker/Dockerfile.tpu b/docker/Dockerfile.tpu index 50806d8820a301990e403a2d2dec8dc66caca792..295270d29f7656da796cc1c14019a8efaedd27cc 100644 --- a/docker/Dockerfile.tpu +++ b/docker/Dockerfile.tpu @@ -23,7 +23,7 @@ RUN --mount=type=cache,target=/root/.cache/pip \ --mount=type=bind,source=.git,target=.git \ python3 -m pip install \ -r requirements/tpu.txt -RUN python3 setup.py develop +RUN python3 -m pip install -e . # install development dependencies (for testing) RUN python3 -m pip install -e tests/vllm_test_utils diff --git a/docker/Dockerfile.xpu b/docker/Dockerfile.xpu index ad4abf16b43b6ae658192e9e2eab1e4bee5b531a..681102b9d18be25991985ef5860501b6357256f5 100644 --- a/docker/Dockerfile.xpu +++ b/docker/Dockerfile.xpu @@ -40,12 +40,6 @@ RUN --mount=type=cache,target=/root/.cache/pip \ --mount=type=bind,source=.git,target=.git \ python3 setup.py install -# Please refer xpu doc, we need manually install intel-extension-for-pytorch 2.6.10+xpu due to there are some conflict dependencies with torch 2.6.0+xpu -# FIXME: This will be fix in ipex 2.7. just leave this here for awareness. -RUN --mount=type=cache,target=/root/.cache/pip \ - pip install intel-extension-for-pytorch==2.6.10+xpu \ - --extra-index-url=https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ - CMD ["/bin/bash"] FROM vllm-base AS vllm-openai diff --git a/docs/Makefile b/docs/Makefile index 5b801f79d1f26e776917f5ae2da3c80c75837af7..d3b429dfb92578c6c23d4140a0c2b6ccce10d744 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -22,3 +22,4 @@ help: clean: @$(SPHINXBUILD) -M clean "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) rm -rf "$(SOURCEDIR)/getting_started/examples" + rm -rf "$(SOURCEDIR)/api/vllm" diff --git a/docs/source/api/engine/async_llm_engine.md b/docs/source/api/engine/async_llm_engine.md deleted file mode 100644 index 904feaa5051647e85077b0619b9cc4ea7325296e..0000000000000000000000000000000000000000 --- a/docs/source/api/engine/async_llm_engine.md +++ /dev/null @@ -1,7 +0,0 @@ -# AsyncLLMEngine - -```{eval-rst} -.. autoclass:: vllm.AsyncLLMEngine - :members: - :show-inheritance: -``` diff --git a/docs/source/api/engine/index.md b/docs/source/api/engine/index.md deleted file mode 100644 index b6544d94afdf8e18b0de73691e4d256ba84cfe78..0000000000000000000000000000000000000000 --- a/docs/source/api/engine/index.md +++ /dev/null @@ -1,17 +0,0 @@ -# vLLM Engine - -```{eval-rst} -.. automodule:: vllm.engine -``` - -```{eval-rst} -.. currentmodule:: vllm.engine -``` - -:::{toctree} -:caption: Engines -:maxdepth: 2 - -llm_engine -async_llm_engine -::: diff --git a/docs/source/api/engine/llm_engine.md b/docs/source/api/engine/llm_engine.md deleted file mode 100644 index d6613ef5562dce7e98a383ddbe0b02f62b8121ab..0000000000000000000000000000000000000000 --- a/docs/source/api/engine/llm_engine.md +++ /dev/null @@ -1,7 +0,0 @@ -# LLMEngine - -```{eval-rst} -.. autoclass:: vllm.LLMEngine - :members: - :show-inheritance: -``` diff --git a/docs/source/api/inference_params.md b/docs/source/api/inference_params.md deleted file mode 100644 index 181c30cab9c4a3515a8fe1952995164a15524e44..0000000000000000000000000000000000000000 --- a/docs/source/api/inference_params.md +++ /dev/null @@ -1,21 +0,0 @@ -# Inference Parameters - -Inference parameters for vLLM APIs. - -(sampling-params)= - -## Sampling Parameters - -```{eval-rst} -.. autoclass:: vllm.SamplingParams - :members: -``` - -(pooling-params)= - -## Pooling Parameters - -```{eval-rst} -.. autoclass:: vllm.PoolingParams - :members: -``` diff --git a/docs/source/api/model/adapters.md b/docs/source/api/model/adapters.md deleted file mode 100644 index e103a51d0070d1187f226132d5df39f50ffb2b4c..0000000000000000000000000000000000000000 --- a/docs/source/api/model/adapters.md +++ /dev/null @@ -1,9 +0,0 @@ -# Model Adapters - -## Module Contents - -```{eval-rst} -.. automodule:: vllm.model_executor.models.adapters - :members: - :member-order: bysource -``` diff --git a/docs/source/api/model/index.md b/docs/source/api/model/index.md deleted file mode 100644 index 8fee3a55c93de82d7ba12b871f9711795dd1851b..0000000000000000000000000000000000000000 --- a/docs/source/api/model/index.md +++ /dev/null @@ -1,11 +0,0 @@ -# Model Development - -## Submodules - -:::{toctree} -:maxdepth: 1 - -interfaces_base -interfaces -adapters -::: diff --git a/docs/source/api/model/interfaces.md b/docs/source/api/model/interfaces.md deleted file mode 100644 index 55bee57f64faa4d2c4da336202100f5c67edfe9c..0000000000000000000000000000000000000000 --- a/docs/source/api/model/interfaces.md +++ /dev/null @@ -1,9 +0,0 @@ -# Optional Interfaces - -## Module Contents - -```{eval-rst} -.. automodule:: vllm.model_executor.models.interfaces - :members: - :member-order: bysource -``` diff --git a/docs/source/api/model/interfaces_base.md b/docs/source/api/model/interfaces_base.md deleted file mode 100644 index 75d58d34228e9eaf925a86adfcad56a785eaca7f..0000000000000000000000000000000000000000 --- a/docs/source/api/model/interfaces_base.md +++ /dev/null @@ -1,9 +0,0 @@ -# Base Model Interfaces - -## Module Contents - -```{eval-rst} -.. automodule:: vllm.model_executor.models.interfaces_base - :members: - :member-order: bysource -``` diff --git a/docs/source/api/multimodal/index.md b/docs/source/api/multimodal/index.md deleted file mode 100644 index 069ed53e545c5648ac335a10751751045a9bf0e1..0000000000000000000000000000000000000000 --- a/docs/source/api/multimodal/index.md +++ /dev/null @@ -1,28 +0,0 @@ -(multi-modality)= - -# Multi-Modality - -vLLM provides experimental support for multi-modal models through the {mod}`vllm.multimodal` package. - -Multi-modal inputs can be passed alongside text and token prompts to [supported models](#supported-mm-models) -via the `multi_modal_data` field in {class}`vllm.inputs.PromptType`. - -Looking to add your own multi-modal model? Please follow the instructions listed [here](#supports-multimodal). - -## Module Contents - -```{eval-rst} -.. autodata:: vllm.multimodal.MULTIMODAL_REGISTRY -``` - -## Submodules - -:::{toctree} -:maxdepth: 1 - -inputs -parse -processing -profiling -registry -::: diff --git a/docs/source/api/multimodal/inputs.md b/docs/source/api/multimodal/inputs.md deleted file mode 100644 index 21bd938be9e89e468f74dd187ef64f2895232857..0000000000000000000000000000000000000000 --- a/docs/source/api/multimodal/inputs.md +++ /dev/null @@ -1,49 +0,0 @@ -# Input Definitions - -## User-facing inputs - -```{eval-rst} -.. autodata:: vllm.multimodal.inputs.MultiModalDataDict -``` - -## Internal data structures - -```{eval-rst} -.. autoclass:: vllm.multimodal.inputs.PlaceholderRange - :members: - :show-inheritance: -``` - -```{eval-rst} -.. autodata:: vllm.multimodal.inputs.NestedTensors -``` - -```{eval-rst} -.. autoclass:: vllm.multimodal.inputs.MultiModalFieldElem - :members: - :show-inheritance: -``` - -```{eval-rst} -.. autoclass:: vllm.multimodal.inputs.MultiModalFieldConfig - :members: - :show-inheritance: -``` - -```{eval-rst} -.. autoclass:: vllm.multimodal.inputs.MultiModalKwargsItem - :members: - :show-inheritance: -``` - -```{eval-rst} -.. autoclass:: vllm.multimodal.inputs.MultiModalKwargs - :members: - :show-inheritance: -``` - -```{eval-rst} -.. autoclass:: vllm.multimodal.inputs.MultiModalInputs - :members: - :show-inheritance: -``` diff --git a/docs/source/api/multimodal/parse.md b/docs/source/api/multimodal/parse.md deleted file mode 100644 index 4676139efe6260a5568c2f0cd71790c7f3e18b9e..0000000000000000000000000000000000000000 --- a/docs/source/api/multimodal/parse.md +++ /dev/null @@ -1,9 +0,0 @@ -# Data Parsing - -## Module Contents - -```{eval-rst} -.. automodule:: vllm.multimodal.parse - :members: - :member-order: bysource -``` diff --git a/docs/source/api/multimodal/processing.md b/docs/source/api/multimodal/processing.md deleted file mode 100644 index 0d81c8d3966ee8da607db4668d605a8c8c7c2e3c..0000000000000000000000000000000000000000 --- a/docs/source/api/multimodal/processing.md +++ /dev/null @@ -1,9 +0,0 @@ -# Data Processing - -## Module Contents - -```{eval-rst} -.. automodule:: vllm.multimodal.processing - :members: - :member-order: bysource -``` diff --git a/docs/source/api/multimodal/profiling.md b/docs/source/api/multimodal/profiling.md deleted file mode 100644 index b455145212202f7ab16d78db5ab594598fc2555e..0000000000000000000000000000000000000000 --- a/docs/source/api/multimodal/profiling.md +++ /dev/null @@ -1,9 +0,0 @@ -# Memory Profiling - -## Module Contents - -```{eval-rst} -.. automodule:: vllm.multimodal.profiling - :members: - :member-order: bysource -``` diff --git a/docs/source/api/multimodal/registry.md b/docs/source/api/multimodal/registry.md deleted file mode 100644 index 0737a4385cf324f812f2f527c98a9f28c7ae7436..0000000000000000000000000000000000000000 --- a/docs/source/api/multimodal/registry.md +++ /dev/null @@ -1,9 +0,0 @@ -# Registry - -## Module Contents - -```{eval-rst} -.. automodule:: vllm.multimodal.registry - :members: - :member-order: bysource -``` diff --git a/docs/source/api/offline_inference/index.md b/docs/source/api/offline_inference/index.md deleted file mode 100644 index ec2cc599d923cc88301b044a6301665f7b3c2d47..0000000000000000000000000000000000000000 --- a/docs/source/api/offline_inference/index.md +++ /dev/null @@ -1,9 +0,0 @@ -# Offline Inference - -:::{toctree} -:caption: Contents -:maxdepth: 1 - -llm -llm_inputs -::: diff --git a/docs/source/api/offline_inference/llm.md b/docs/source/api/offline_inference/llm.md deleted file mode 100644 index 9f129d5e41686314aac02e9016399e8407159957..0000000000000000000000000000000000000000 --- a/docs/source/api/offline_inference/llm.md +++ /dev/null @@ -1,7 +0,0 @@ -# LLM Class - -```{eval-rst} -.. autoclass:: vllm.LLM - :members: - :show-inheritance: -``` diff --git a/docs/source/api/offline_inference/llm_inputs.md b/docs/source/api/offline_inference/llm_inputs.md deleted file mode 100644 index 21f688a12c5369023e242b133648342af3064952..0000000000000000000000000000000000000000 --- a/docs/source/api/offline_inference/llm_inputs.md +++ /dev/null @@ -1,19 +0,0 @@ -# LLM Inputs - -```{eval-rst} -.. autodata:: vllm.inputs.PromptType -``` - -```{eval-rst} -.. autoclass:: vllm.inputs.TextPrompt - :show-inheritance: - :members: - :member-order: bysource -``` - -```{eval-rst} -.. autoclass:: vllm.inputs.TokensPrompt - :show-inheritance: - :members: - :member-order: bysource -``` diff --git a/docs/source/api/summary.md b/docs/source/api/summary.md new file mode 100644 index 0000000000000000000000000000000000000000..46de545f9ded479f3b4c1a702d80f642eb050c3c --- /dev/null +++ b/docs/source/api/summary.md @@ -0,0 +1,133 @@ +# Summary + +(configuration)= + +## Configuration + +API documentation for vLLM's configuration classes. + +```{autodoc2-summary} + vllm.config.ModelConfig + vllm.config.CacheConfig + vllm.config.TokenizerPoolConfig + vllm.config.LoadConfig + vllm.config.ParallelConfig + vllm.config.SchedulerConfig + vllm.config.DeviceConfig + vllm.config.SpeculativeConfig + vllm.config.LoRAConfig + vllm.config.PromptAdapterConfig + vllm.config.MultiModalConfig + vllm.config.PoolerConfig + vllm.config.DecodingConfig + vllm.config.ObservabilityConfig + vllm.config.KVTransferConfig + vllm.config.CompilationConfig + vllm.config.VllmConfig +``` + +(offline-inference-api)= + +## Offline Inference + +LLM Class. + +```{autodoc2-summary} + vllm.LLM +``` + +LLM Inputs. + +```{autodoc2-summary} + vllm.inputs.PromptType + vllm.inputs.TextPrompt + vllm.inputs.TokensPrompt +``` + +## vLLM Engines + +Engine classes for offline and online inference. + +```{autodoc2-summary} + vllm.LLMEngine + vllm.AsyncLLMEngine +``` + +## Inference Parameters + +Inference parameters for vLLM APIs. + +(sampling-params)= +(pooling-params)= + +```{autodoc2-summary} + vllm.SamplingParams + vllm.PoolingParams +``` + +(multi-modality)= + +## Multi-Modality + +vLLM provides experimental support for multi-modal models through the {mod}`vllm.multimodal` package. + +Multi-modal inputs can be passed alongside text and token prompts to [supported models](#supported-mm-models) +via the `multi_modal_data` field in {class}`vllm.inputs.PromptType`. + +Looking to add your own multi-modal model? Please follow the instructions listed [here](#supports-multimodal). + +```{autodoc2-summary} + vllm.multimodal.MULTIMODAL_REGISTRY +``` + +### Inputs + +User-facing inputs. + +```{autodoc2-summary} + vllm.multimodal.inputs.MultiModalDataDict +``` + +Internal data structures. + +```{autodoc2-summary} + vllm.multimodal.inputs.PlaceholderRange + vllm.multimodal.inputs.NestedTensors + vllm.multimodal.inputs.MultiModalFieldElem + vllm.multimodal.inputs.MultiModalFieldConfig + vllm.multimodal.inputs.MultiModalKwargsItem + vllm.multimodal.inputs.MultiModalKwargs + vllm.multimodal.inputs.MultiModalInputs +``` + +### Data Parsing + +```{autodoc2-summary} + vllm.multimodal.parse +``` + +### Data Processing + +```{autodoc2-summary} + vllm.multimodal.processing +``` + +### Memory Profiling + +```{autodoc2-summary} + vllm.multimodal.profiling +``` + +### Registry + +```{autodoc2-summary} + vllm.multimodal.registry +``` + +## Model Development + +```{autodoc2-summary} + vllm.model_executor.models.interfaces_base + vllm.model_executor.models.interfaces + vllm.model_executor.models.adapters +``` diff --git a/docs/source/assets/contributing/dockerfile-stages-dependency.png b/docs/source/assets/contributing/dockerfile-stages-dependency.png index 6ace54f6676203dc05aa2a9d44248b621771c8a2..0838bfa37fe62d60fba9adcbd18c81de0809f253 100644 Binary files a/docs/source/assets/contributing/dockerfile-stages-dependency.png and b/docs/source/assets/contributing/dockerfile-stages-dependency.png differ diff --git a/docs/source/assets/deployment/chatbox-chat.png b/docs/source/assets/deployment/chatbox-chat.png new file mode 100644 index 0000000000000000000000000000000000000000..b1718cb504717578dd36062759af8b834426483c Binary files /dev/null and b/docs/source/assets/deployment/chatbox-chat.png differ diff --git a/docs/source/assets/deployment/chatbox-settings.png b/docs/source/assets/deployment/chatbox-settings.png new file mode 100644 index 0000000000000000000000000000000000000000..a8e3d7b2894c720fdbcf7b6615ea0b1c892db409 Binary files /dev/null and b/docs/source/assets/deployment/chatbox-settings.png differ diff --git a/docs/source/assets/deployment/dify-chat.png b/docs/source/assets/deployment/dify-chat.png new file mode 100644 index 0000000000000000000000000000000000000000..dfea23309c1cfac44ea6c021fafdcaaab80633a3 Binary files /dev/null and b/docs/source/assets/deployment/dify-chat.png differ diff --git a/docs/source/assets/deployment/dify-create-chatbot.png b/docs/source/assets/deployment/dify-create-chatbot.png new file mode 100644 index 0000000000000000000000000000000000000000..07bbde5ba28541a4d64aea55335f0ecb7bfc7ff9 Binary files /dev/null and b/docs/source/assets/deployment/dify-create-chatbot.png differ diff --git a/docs/source/assets/deployment/dify-settings.png b/docs/source/assets/deployment/dify-settings.png new file mode 100644 index 0000000000000000000000000000000000000000..7900cc774741b9884869a5a38fbb0348f1b694a6 Binary files /dev/null and b/docs/source/assets/deployment/dify-settings.png differ diff --git a/docs/source/assets/deployment/streamlit-chat.png b/docs/source/assets/deployment/streamlit-chat.png new file mode 100644 index 0000000000000000000000000000000000000000..1e37b9d70e15df2d253319dcd0ebeb123ee719a0 Binary files /dev/null and b/docs/source/assets/deployment/streamlit-chat.png differ diff --git a/docs/source/autodoc2_docstring_parser.py b/docs/source/autodoc2_docstring_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..41c49ed1c545afb85a7b45417cf07b97941e5373 --- /dev/null +++ b/docs/source/autodoc2_docstring_parser.py @@ -0,0 +1,21 @@ +# SPDX-License-Identifier: Apache-2.0 +from docutils import nodes +from myst_parser.parsers.sphinx_ import MystParser +from sphinx.ext.napoleon import docstring + + +class NapoleonParser(MystParser): + + def parse(self, input_string: str, document: nodes.document) -> None: + # Get the Sphinx configuration + config = document.settings.env.config + + parsed_content = str( + docstring.GoogleDocstring( + str(docstring.NumpyDocstring(input_string, config)), + config, + )) + return super().parse(parsed_content, document) + + +Parser = NapoleonParser diff --git a/docs/source/community/meetups.md b/docs/source/community/meetups.md index 085918bed2b0981e1c4104517f2867ca585b0abf..aa1a71c86c0a6ce56666cb748879bb66324fa18f 100644 --- a/docs/source/community/meetups.md +++ b/docs/source/community/meetups.md @@ -4,6 +4,7 @@ We host regular meetups in San Francisco Bay Area every 2 months. We will share the project updates from the vLLM team and have guest speakers from the industry to share their experience and insights. Please find the materials of our previous meetups below: +- [NYC vLLM Meetup](https://lu.ma/c1rqyf1f), May 7th, 2025. [[Slides]](https://docs.google.com/presentation/d/1_q_aW_ioMJWUImf1s1YM-ZhjXz8cUeL0IJvaquOYBeA/edit?usp=sharing) - [Asia Developer Day](https://www.sginnovate.com/event/limited-availability-morning-evening-slots-remaining-inaugural-vllm-asia-developer-day), April 3rd 2025. [[Slides]](https://docs.google.com/presentation/d/19cp6Qu8u48ihB91A064XfaXruNYiBOUKrBxAmDOllOo/edit?usp=sharing). - [vLLM x Ollama Inference Night](https://lu.ma/vllm-ollama), March 27th 2025. [[Slides]](https://docs.google.com/presentation/d/16T2PDD1YwRnZ4Tu8Q5r6n53c5Lr5c73UV9Vd2_eBo4U/edit?usp=sharing). - [The first vLLM China Meetup](https://mp.weixin.qq.com/s/n77GibL2corAtQHtVEAzfg), March 16th 2025. [[Slides]](https://docs.google.com/presentation/d/1REHvfQMKGnvz6p3Fd23HhSO4c8j5WPGZV0bKYLwnHyQ/edit?usp=sharing). diff --git a/docs/source/conf.py b/docs/source/conf.py index c2ad6f9fa3a55d61fd499f3314911caeb1939351..5620d6de2c59be9af67680951641c89874844916 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -13,16 +13,17 @@ # documentation root, use os.path.abspath to make it absolute, like shown here. import datetime -import inspect import logging import os +import re import sys +from pathlib import Path import requests -from sphinx.ext import autodoc logger = logging.getLogger(__name__) -sys.path.append(os.path.abspath("../..")) +REPO_ROOT = Path(__file__).resolve().parent.parent.parent +sys.path.append(os.path.abspath(REPO_ROOT)) # -- Project information ----------------------------------------------------- @@ -40,8 +41,7 @@ extensions = [ "sphinx.ext.linkcode", "sphinx.ext.intersphinx", "sphinx_copybutton", - "sphinx.ext.autodoc", - "sphinx.ext.autosummary", + "autodoc2", "myst_parser", "sphinxarg.ext", "sphinx_design", @@ -49,7 +49,19 @@ extensions = [ ] myst_enable_extensions = [ "colon_fence", + "fieldlist", ] +autodoc2_packages = [ + { + "path": "../../vllm", + "exclude_dirs": ["__pycache__", "third_party"], + }, +] +autodoc2_output_dir = "api" +autodoc2_render_plugin = "myst" +autodoc2_hidden_objects = ["dunder", "private", "inherited"] +autodoc2_sort_names = True +autodoc2_index_template = None # Add any paths that contain templates here, relative to this directory. templates_path = ['_templates'] @@ -77,6 +89,11 @@ html_theme_options = { 'repository_url': 'https://github.com/vllm-project/vllm', 'use_repository_button': True, 'use_edit_page_button': True, + # Prevents the full API being added to the left sidebar of every page. + # Reduces build time by 2.5x and reduces build size from ~225MB to ~95MB. + 'collapse_navbar': True, + # Makes API visible in the right sidebar on API reference pages. + 'show_toc_level': 3, } # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, @@ -164,73 +181,64 @@ def linkcode_resolve(domain, info): return None if not info['module']: return None - filename = info['module'].replace('.', '/') - module = info['module'] - - # try to determine the correct file and line number to link to - obj = sys.modules[module] - - # get as specific as we can - lineno: int = 0 - filename: str = "" - try: - for part in info['fullname'].split('.'): - obj = getattr(obj, part) - - # Skip decorator wrappers by checking if the object is a function - # and has a __wrapped__ attribute (which decorators typically set) - while hasattr(obj, '__wrapped__'): - obj = obj.__wrapped__ - - if not (inspect.isclass(obj) or inspect.isfunction(obj) - or inspect.ismethod(obj)): - obj = obj.__class__ # Get the class of the instance - - lineno = inspect.getsourcelines(obj)[1] - filename = (inspect.getsourcefile(obj) - or f"{filename}.py").split("vllm/", 1)[1] - except Exception: - # For some things, like a class member, won't work, so - # we'll use the line number of the parent (the class) - pass - - if filename.startswith("checkouts/"): + + # Get path from module name + file = Path(f"{info['module'].replace('.', '/')}.py") + path = REPO_ROOT / file + if not path.exists(): + path = REPO_ROOT / file.with_suffix("") / "__init__.py" + if not path.exists(): + return None + + # Get the line number of the object + with open(path) as f: + lines = f.readlines() + name = info['fullname'].split(".")[-1] + pattern = fr"^( {{4}})*((def|class) )?{name}\b.*" + for lineno, line in enumerate(lines, 1): + if not line or line.startswith("#"): + continue + if re.match(pattern, line): + break + + # If the line number is not found, return None + if lineno == len(lines): + return None + + # If the line number is found, create the URL + filename = path.relative_to(REPO_ROOT) + if "checkouts" in path.parts: # a PR build on readthedocs - pr_number = filename.split("/")[1] - filename = filename.split("/", 2)[2] + pr_number = REPO_ROOT.name base, branch = get_repo_base_and_branch(pr_number) if base and branch: return f"https://github.com/{base}/blob/{branch}/{filename}#L{lineno}" - # Otherwise, link to the source file on the main branch return f"https://github.com/vllm-project/vllm/blob/main/{filename}#L{lineno}" -# Mock out external dependencies here, otherwise the autodoc pages may be blank. +# Mock out external dependencies here, otherwise sphinx-argparse won't work. autodoc_mock_imports = [ + "huggingface_hub", + "pydantic", + "zmq", + "cloudpickle", + "aiohttp", + "starlette", "blake3", - "compressed_tensors", "cpuinfo", - "cv2", - "torch", "transformers", "psutil", - "prometheus_client", - "sentencepiece", "vllm._C", "PIL", "numpy", - 'triton', "tqdm", - "tensorizer", - "pynvml", - "outlines", - "xgrammar", - "librosa", - "soundfile", - "gguf", - "lark", - "decord", + # The mocks below are required by + # docs/source/serving/openai_compatible_server.md's + # vllm.entrypoints.openai.cli_args + "openai", + "fastapi", + "partial_json_parser", ] for mock_target in autodoc_mock_imports: @@ -241,18 +249,6 @@ for mock_target in autodoc_mock_imports: "been loaded into sys.modules when the sphinx build starts.", mock_target) - -class MockedClassDocumenter(autodoc.ClassDocumenter): - """Remove note about base class when a class is derived from object.""" - - def add_line(self, line: str, source: str, *lineno: int) -> None: - if line == " Bases: :py:class:`object`": - return - super().add_line(line, source, *lineno) - - -autodoc.ClassDocumenter = MockedClassDocumenter - intersphinx_mapping = { "python": ("https://docs.python.org/3", None), "typing_extensions": @@ -264,7 +260,4 @@ intersphinx_mapping = { "psutil": ("https://psutil.readthedocs.io/en/stable", None), } -autodoc_preserve_defaults = True -autodoc_warningiserror = True - navigation_with_keys = False diff --git a/docs/source/contributing/deprecation_policy.md b/docs/source/contributing/deprecation_policy.md new file mode 100644 index 0000000000000000000000000000000000000000..598f1612d3af383a1b3558fba327b78a31684c69 --- /dev/null +++ b/docs/source/contributing/deprecation_policy.md @@ -0,0 +1,87 @@ +# Deprecation Policy + +This document outlines the official policy and process for deprecating features +in the vLLM project. + +## Overview + +vLLM uses a structured "deprecation pipeline" to guide the lifecycle of +deprecated features. This policy ensures that users are given clear and +sufficient notice when a feature is deprecated and that deprecations proceed in +a consistent and predictable manner. + +We aim to strike a balance between continued innovation and respecting users’ +reliance on existing functionality. Deprecations are tied to our **minor (Y) +releases** following semantic versioning (X.Y.Z), where: + +- **X** is a major version (rare) +- **Y** is a minor version (used for significant changes, including deprecations/removals) +- **Z** is a patch version (used for fixes and safer enhancements) + +Features that fall under this policy include (at a minimum) the following: + +- CLI flags +- Environment variables +- Configuration files +- APIs in the OpenAI-compatible API server +- Public Python APIs for the `vllm` library + +## Deprecation Pipeline + +The deprecation process consists of several clearly defined stages that span +multiple Y releases: + +**1. Deprecated (Still On By Default)** + +- **Action**: Feature is marked as deprecated. +- **Timeline**: A removal version is explicitly stated in the deprecation +warning (e.g., "This will be removed in v0.10.0"). +- **Communication**: Deprecation is noted in the following, as applicable: + - Help strings + - Log output + - API responses + - `/metrics` output (for metrics features) + - User-facing documentation + - Release notes + - GitHub Issue (RFC) for feedback + - Documentation and use of the `@typing_extensions.deprecated` decorator for Python APIs + +**2.Deprecated (Off By Default)** + +- **Action**: Feature is disabled by default, but can still be re-enabled via a +CLI flag or environment variable. Feature throws an error when used without +re-enabling. +- **Purpose**: Allows users who missed earlier warnings a temporary escape hatch +while signaling imminent removal. Ensures any remaining usage is clearly +surfaced and blocks silent breakage before full removal. + +**3. Removed** + +- **Action**: Feature is completely removed from the codebase. +- **Note**: Only features that have passed through the previous deprecation +stages will be removed. + +## Example Timeline + +Assume a feature is deprecated in `v0.9.0`. + +| Release | Status | +|---------------|-------------------------------------------------------------------------------------------------| +| `v0.9.0` | Feature is deprecated with clear removal version listed. | +| `v0.10.0` | Feature is now off by default, throws an error when used, and can be re-enabled for legacy use. | +| `v0.11.0` | Feature is removed. | + +## Important Guidelines + +- **No Removals in Patch Releases**: Removing deprecated features in patch +(`.Z`) releases is disallowed to avoid surprising users. +- **Grace Period for Existing Deprecations**: Any feature deprecated **before +this policy** will have its grace period start **now**, not retroactively. +- **Documentation is Critical**: Ensure every stage of the pipeline is +documented clearly for users. + +## Final Notes + +This policy is a living document and may evolve as the needs of the project and +its users change. Community feedback is welcome and encouraged as we refine the +process. diff --git a/docs/source/contributing/overview.md b/docs/source/contributing/overview.md index 31c7059fda3648fd432ef1ad2a3f4878fb0938a4..89b31f0311e23e5bcd3da63d982ac53245dedfed 100644 --- a/docs/source/contributing/overview.md +++ b/docs/source/contributing/overview.md @@ -17,7 +17,7 @@ Unsure on where to start? Check out the following links for tasks to work on: - [Good first issues](https://github.com/vllm-project/vllm/issues?q=is%3Aissue%20state%3Aopen%20label%3A%22good%20first%20issue%22) - [Selected onboarding tasks](gh-project:6) -- [New model requests](https://github.com/vllm-project/vllm/issues?q=is%3Aissue%20state%3Aopen%20label%3A%22new%20model%22) +- [New model requests](https://github.com/vllm-project/vllm/issues?q=is%3Aissue%20state%3Aopen%20label%3A%22new-model%22) - [Models with multi-modal capabilities](gh-project:10) ## License @@ -40,6 +40,10 @@ pre-commit install --hook-type pre-commit --hook-type commit-msg # You can manually run pre-commit with pre-commit run --all-files +# To manually run something from CI that does not run +# locally by default, you can run: +pre-commit run mypy-3.9 --hook-stage manual --all-files + # Unit tests pytest tests/ ``` @@ -54,6 +58,12 @@ Therefore, we recommend developing with Python 3.12 to minimise the chance of yo Currently, the repository is not fully checked by `mypy`. ::: +:::{note} +Currently, not all unit tests pass when run on CPU platforms. If you don't have access to a GPU +platform to run unit tests locally, rely on the continuous integration system to run the tests for +now. +::: + ## Issues If you encounter a bug or have a feature request, please [search existing issues](https://github.com/vllm-project/vllm/issues?q=is%3Aissue) first to see if it has already been reported. If not, please [file a new issue](https://github.com/vllm-project/vllm/issues/new/choose), providing as much relevant information as possible. diff --git a/docs/source/deployment/frameworks/chatbox.md b/docs/source/deployment/frameworks/chatbox.md new file mode 100644 index 0000000000000000000000000000000000000000..e62f4647150f4f8899d23ac40e1bab51525b17e1 --- /dev/null +++ b/docs/source/deployment/frameworks/chatbox.md @@ -0,0 +1,36 @@ +(deployment-chatbox)= + +# Chatbox + +[Chatbox](https://github.com/chatboxai/chatbox) is a desktop client for LLMs, available on Windows, Mac, Linux. + +It allows you to deploy a large language model (LLM) server with vLLM as the backend, which exposes OpenAI-compatible endpoints. + +## Prerequisites + +- Setup vLLM environment + +## Deploy + +- Start the vLLM server with the supported chat completion model, e.g. + +```console +vllm serve qwen/Qwen1.5-0.5B-Chat +``` + +- Download and install [Chatbox desktop](https://chatboxai.app/en#download). + +- On the bottom left of settings, Add Custom Provider + - API Mode: `OpenAI API Compatible` + - Name: vllm + - API Host: `http://{vllm server host}:{vllm server port}/v1` + - API Path: `/chat/completions` + - Model: `qwen/Qwen1.5-0.5B-Chat` + +:::{image} /assets/deployment/chatbox-settings.png +::: + +- Go to `Just chat`, and start to chat: + +:::{image} /assets/deployment/chatbox-chat.png +::: diff --git a/docs/source/deployment/frameworks/dify.md b/docs/source/deployment/frameworks/dify.md new file mode 100644 index 0000000000000000000000000000000000000000..5cdf6a3876371db12b0c642af396fcc496aa5414 --- /dev/null +++ b/docs/source/deployment/frameworks/dify.md @@ -0,0 +1,56 @@ +(deployment-dify)= + +# Dify + +[Dify](https://github.com/langgenius/dify) is an open-source LLM app development platform. Its intuitive interface combines agentic AI workflow, RAG pipeline, agent capabilities, model management, observability features, and more, allowing you to quickly move from prototype to production. + +It supports vLLM as a model provider to efficiently serve large language models. + +This guide walks you through deploying Dify using a vLLM backend. + +## Prerequisites + +- Setup vLLM environment +- Install [Docker](https://docs.docker.com/engine/install/) and [Docker Compose](https://docs.docker.com/compose/install/) + +## Deploy + +- Start the vLLM server with the supported chat completion model, e.g. + +```console +vllm serve Qwen/Qwen1.5-7B-Chat +``` + +- Start the Dify server with docker compose ([details](https://github.com/langgenius/dify?tab=readme-ov-file#quick-start)): + +```console +git clone https://github.com/langgenius/dify.git +cd dify +cd docker +cp .env.example .env +docker compose up -d +``` + +- Open the browser to access `http://localhost/install`, config the basic login information and login. + +- In the top-right user menu (under the profile icon), go to Settings, then click `Model Provider`, and locate the `vLLM` provider to install it. + +- Fill in the model provider details as follows: + - **Model Type**: `LLM` + - **Model Name**: `Qwen/Qwen1.5-7B-Chat` + - **API Endpoint URL**: `http://{vllm_server_host}:{vllm_server_port}/v1` + - **Model Name for API Endpoint**: `Qwen/Qwen1.5-7B-Chat` + - **Completion Mode**: `Completion` + +:::{image} /assets/deployment/dify-settings.png +::: + +- To create a test chatbot, go to `Studio → Chatbot → Create from Blank`, then select Chatbot as the type: + +:::{image} /assets/deployment/dify-create-chatbot.png +::: + +- Click the chatbot you just created to open the chat interface and start interacting with the model: + +:::{image} /assets/deployment/dify-chat.png +::: diff --git a/docs/source/deployment/frameworks/index.md b/docs/source/deployment/frameworks/index.md index a1b405386b77aa1b0f6a48685eb2c703a960f40b..3408c6c10edef898a2df595d274e5e59fc6948ef 100644 --- a/docs/source/deployment/frameworks/index.md +++ b/docs/source/deployment/frameworks/index.md @@ -6,11 +6,17 @@ anything-llm bentoml cerebrium +chatbox +dify dstack helm +litellm +lobe-chat lws modal open-webui +retrieval_augmented_generation skypilot +streamlit triton ::: diff --git a/docs/source/deployment/frameworks/litellm.md b/docs/source/deployment/frameworks/litellm.md new file mode 100644 index 0000000000000000000000000000000000000000..6dd3607ca5e370f216a32dd50c09339206bbd5c5 --- /dev/null +++ b/docs/source/deployment/frameworks/litellm.md @@ -0,0 +1,75 @@ +(deployment-litellm)= + +# LiteLLM + +[LiteLLM](https://github.com/BerriAI/litellm) call all LLM APIs using the OpenAI format [Bedrock, Huggingface, VertexAI, TogetherAI, Azure, OpenAI, Groq etc.] + +LiteLLM manages: + +- Translate inputs to provider's `completion`, `embedding`, and `image_generation` endpoints +- [Consistent output](https://docs.litellm.ai/docs/completion/output), text responses will always be available at `['choices'][0]['message']['content']` +- Retry/fallback logic across multiple deployments (e.g. Azure/OpenAI) - [Router](https://docs.litellm.ai/docs/routing) +- Set Budgets & Rate limits per project, api key, model [LiteLLM Proxy Server (LLM Gateway)](https://docs.litellm.ai/docs/simple_proxy) + +And LiteLLM supports all models on VLLM. + +## Prerequisites + +- Setup vLLM and litellm environment + +```console +pip install vllm litellm +``` + +## Deploy + +### Chat completion + +- Start the vLLM server with the supported chat completion model, e.g. + +```console +vllm serve qwen/Qwen1.5-0.5B-Chat +``` + +- Call it with litellm: + +```python +import litellm + +messages = [{ "content": "Hello, how are you?","role": "user"}] + +# hosted_vllm is prefix key word and necessary +response = litellm.completion( + model="hosted_vllm/qwen/Qwen1.5-0.5B-Chat", # pass the vllm model name + messages=messages, + api_base="http://{your-vllm-server-host}:{your-vllm-server-port}/v1", + temperature=0.2, + max_tokens=80) + +print(response) +``` + +### Embeddings + +- Start the vLLM server with the supported embedding model, e.g. + +```console +vllm serve BAAI/bge-base-en-v1.5 +``` + +- Call it with litellm: + +```python +from litellm import embedding +import os + +os.environ["HOSTED_VLLM_API_BASE"] = "http://{your-vllm-server-host}:{your-vllm-server-port}/v1" + +# hosted_vllm is prefix key word and necessary +# pass the vllm model name +embedding = embedding(model="hosted_vllm/BAAI/bge-base-en-v1.5", input=["Hello world"]) + +print(embedding) +``` + +For details, see the tutorial [Using vLLM in LiteLLM](https://docs.litellm.ai/docs/providers/vllm). diff --git a/docs/source/deployment/frameworks/lobe-chat.md b/docs/source/deployment/frameworks/lobe-chat.md new file mode 100644 index 0000000000000000000000000000000000000000..6d86b7fa9cce1203f54c2b8cd056a6e27edbc76f --- /dev/null +++ b/docs/source/deployment/frameworks/lobe-chat.md @@ -0,0 +1,13 @@ +(deployment-lobe-chat)= + +# Lobe Chat + +[Lobe Chat](https://github.com/lobehub/lobe-chat) is an open-source, modern-design ChatGPT/LLMs UI/Framework. + +Supports speech-synthesis, multi-modal, and extensible (function call) plugin system. + +One-click FREE deployment of your private OpenAI ChatGPT/Claude/Gemini/Groq/Ollama chat application. + +It supports vLLM as a AI model provider to efficiently serve large language models. + +For details, see the tutorial [Using vLLM in LobeChat](https://lobehub.com/docs/usage/providers/vllm). diff --git a/docs/source/deployment/frameworks/retrieval_augmented_generation.md b/docs/source/deployment/frameworks/retrieval_augmented_generation.md new file mode 100644 index 0000000000000000000000000000000000000000..f84451fafe91d6602ced1ac14625e6f52c94ea36 --- /dev/null +++ b/docs/source/deployment/frameworks/retrieval_augmented_generation.md @@ -0,0 +1,84 @@ +(deployment-retrieval-augmented-generation)= + +# Retrieval-Augmented Generation + +[Retrieval-augmented generation (RAG)](https://en.wikipedia.org/wiki/Retrieval-augmented_generation) is a technique that enables generative artificial intelligence (Gen AI) models to retrieve and incorporate new information. It modifies interactions with a large language model (LLM) so that the model responds to user queries with reference to a specified set of documents, using this information to supplement information from its pre-existing training data. This allows LLMs to use domain-specific and/or updated information. Use cases include providing chatbot access to internal company data or generating responses based on authoritative sources. + +Here are the integrations: +- vLLM + [langchain](https://github.com/langchain-ai/langchain) + [milvus](https://github.com/milvus-io/milvus) +- vLLM + [llamaindex](https://github.com/run-llama/llama_index) + [milvus](https://github.com/milvus-io/milvus) + +## vLLM + langchain + +### Prerequisites + +- Setup vLLM and langchain environment + +```console +pip install -U vllm \ + langchain_milvus langchain_openai \ + langchain_community beautifulsoup4 \ + langchain-text-splitters +``` + +### Deploy + +- Start the vLLM server with the supported embedding model, e.g. + +```console +# Start embedding service (port 8000) +vllm serve ssmits/Qwen2-7B-Instruct-embed-base +``` + +- Start the vLLM server with the supported chat completion model, e.g. + +```console +# Start chat service (port 8001) +vllm serve qwen/Qwen1.5-0.5B-Chat --port 8001 +``` + +- Use the script: + +- Run the script + +```python +python retrieval_augmented_generation_with_langchain.py +``` + +## vLLM + llamaindex + +### Prerequisites + +- Setup vLLM and llamaindex environment + +```console +pip install vllm \ + llama-index llama-index-readers-web \ + llama-index-llms-openai-like \ + llama-index-embeddings-openai-like \ + llama-index-vector-stores-milvus \ +``` + +### Deploy + +- Start the vLLM server with the supported embedding model, e.g. + +```console +# Start embedding service (port 8000) +vllm serve ssmits/Qwen2-7B-Instruct-embed-base +``` + +- Start the vLLM server with the supported chat completion model, e.g. + +```console +# Start chat service (port 8001) +vllm serve qwen/Qwen1.5-0.5B-Chat --port 8001 +``` + +- Use the script: + +- Run the script + +```python +python retrieval_augmented_generation_with_llamaindex.py +``` diff --git a/docs/source/deployment/frameworks/streamlit.md b/docs/source/deployment/frameworks/streamlit.md new file mode 100644 index 0000000000000000000000000000000000000000..084550ec991e1af0ac258a9a8babed882a10ec3b --- /dev/null +++ b/docs/source/deployment/frameworks/streamlit.md @@ -0,0 +1,42 @@ +(deployment-streamlit)= + +# Streamlit + +[Streamlit](https://github.com/streamlit/streamlit) lets you transform Python scripts into interactive web apps in minutes, instead of weeks. Build dashboards, generate reports, or create chat apps. + +It can be quickly integrated with vLLM as a backend API server, enabling powerful LLM inference via API calls. + +## Prerequisites + +- Setup vLLM environment + +## Deploy + +- Start the vLLM server with the supported chat completion model, e.g. + +```console +vllm serve qwen/Qwen1.5-0.5B-Chat +``` + +- Install streamlit and openai: + +```console +pip install streamlit openai +``` + +- Use the script: + +- Start the streamlit web UI and start to chat: + +```console +streamlit run streamlit_openai_chatbot_webserver.py + +# or specify the VLLM_API_BASE or VLLM_API_KEY +VLLM_API_BASE="http://vllm-server-host:vllm-server-port/v1" streamlit run streamlit_openai_chatbot_webserver.py + +# start with debug mode to view more details +streamlit run streamlit_openai_chatbot_webserver.py --logger.level=debug +``` + +:::{image} /assets/deployment/streamlit-chat.png +::: diff --git a/docs/source/deployment/security.md b/docs/source/deployment/security.md index e2ef8196c16711ca8837e5a9966ff75465d457fd..9c4d639c0b3da68b2b1dc96e35d627f18218a589 100644 --- a/docs/source/deployment/security.md +++ b/docs/source/deployment/security.md @@ -53,6 +53,45 @@ Key points from the PyTorch security guide: - Implement proper authentication and authorization for management interfaces - Follow the principle of least privilege for all system components +## Security and Firewalls: Protecting Exposed vLLM Systems + +While vLLM is designed to allow unsafe network services to be isolated to +private networks, there are components—such as dependencies and underlying +frameworks—that may open insecure services listening on all network interfaces, +sometimes outside of vLLM's direct control. + +A major concern is the use of `torch.distributed`, which vLLM leverages for +distributed communication, including when using vLLM on a single host. When vLLM +uses TCP initialization (see [PyTorch TCP Initialization +documentation](https://docs.pytorch.org/docs/stable/distributed.html#tcp-initialization)), +PyTorch creates a `TCPStore` that, by default, listens on all network +interfaces. This means that unless additional protections are put in place, +these services may be accessible to any host that can reach your machine via any +network interface. + +**From a PyTorch perspective, any use of `torch.distributed` should be +considered insecure by default.** This is a known and intentional behavior from +the PyTorch team. + +### Firewall Configuration Guidance + +The best way to protect your vLLM system is to carefully configure a firewall to +expose only the minimum network surface area necessary. In most cases, this +means: + +- **Block all incoming connections except to the TCP port the API server is +listening on.** + +- Ensure that ports used for internal communication (such as those for +`torch.distributed` and KV cache transfer) are only accessible from trusted +hosts or networks. + +- Never expose these internal ports to the public internet or untrusted +networks. + +Consult your operating system or application platform documentation for specific +firewall configuration instructions. + ## Reporting Security Vulnerabilities If you believe you have found a security vulnerability in vLLM, please report it following the project's security policy. For more information on how to report security issues and the project's security policy, please see the [vLLM Security Policy](https://github.com/vllm-project/vllm/blob/main/SECURITY.md). diff --git a/docs/source/design/arch_overview.md b/docs/source/design/arch_overview.md index 7bed0a001d6f50ec6e4e1d6fd3ccfa5bc0c912c0..94bda8b5c58d505e2d2cb5c80f3a4ac29e727a78 100644 --- a/docs/source/design/arch_overview.md +++ b/docs/source/design/arch_overview.md @@ -52,8 +52,8 @@ for output in outputs: print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") ``` -More API details can be found in the {doc}`Offline Inference -` section of the API docs. +More API details can be found in the [Offline Inference] +(#offline-inference-api) section of the API docs. The code for the `LLM` class can be found in . diff --git a/docs/source/design/v1/metrics.md b/docs/source/design/v1/metrics.md index 3f96290798a334c768f95996db5d68478a6dd6f1..de80226553728c2cd939b768028113c013181955 100644 --- a/docs/source/design/v1/metrics.md +++ b/docs/source/design/v1/metrics.md @@ -415,8 +415,8 @@ The discussion in about adding prefix cache metrics yielded some interesting points which may be relevant to how we approach future metrics. -Every time the prefix cache is queried, we record the number of blocks -queried and the number of queried blocks present in the cache +Every time the prefix cache is queried, we record the number of tokens +queried and the number of queried tokens present in the cache (i.e. hits). However, the metric of interest is the hit rate - i.e. the number of @@ -467,6 +467,9 @@ In general: hatch](https://kubernetes.io/docs/concepts/cluster-administration/system-metrics/#show-hidden-metrics) for some time before deleting them. +See the [deprecation policy](project:../../contributing/deprecation_policy.md) for +the project-wide deprecation policy. + ### Unimplemented - `vllm:tokens_total` Added by , but apparently never implemented. This can just be diff --git a/docs/source/design/v1/prefix_caching.md b/docs/source/design/v1/prefix_caching.md index ec1f3cb8d64a84fd20a2b56097368e4b2d2a61aa..0f7475777797b6cb2fe454d6acb8f5908d3a24e8 100644 --- a/docs/source/design/v1/prefix_caching.md +++ b/docs/source/design/v1/prefix_caching.md @@ -16,7 +16,7 @@ In the example above, the KV cache in the first block can be uniquely identified * Parent hash value: The hash value of the parent hash block. * Block tokens: A tuple of tokens in this block. The reason to include the exact tokens is to reduce potential hash value collision. -* Extra hashes: Other values required to make this block unique, such as LoRA IDs and multi-modality input hashes (see the example below). +* Extra hashes: Other values required to make this block unique, such as LoRA IDs, multi-modality input hashes (see the example below), and cache salts to isolate caches in multi-tenant environments. > **Note 1:** We only cache full blocks. @@ -76,6 +76,24 @@ Block 3 In the rest of this document, we first introduce the data structure used for prefix caching in vLLM v1, followed by the prefix caching workflow of major KV cache operators (e.g., allocate, append, free, eviction). Finally, we use an example to illustrate the end to end prefix caching workflow. +**Cache Isolation for Security** +To improve privacy in shared environments, vLLM supports isolating prefix cache reuse through optional per-request salting. By including a `cache_salt` in the request, this value is injected into the hash of the first block, ensuring that only requests with the same salt can reuse cached KV blocks. This prevents timing-based attacks where an adversary could infer cached content by observing latency differences. This offers protection without compromising performance. + +```json +{ + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Here is a document with details about the world series: ..."}, + {"role": "user", "content": "Who won the world series in 2020?"} + ], + "cache_salt": "your-cache-salt" +} +``` + +With this setup, cache sharing is limited to users or requests that explicitly agree on a common salt, enabling cache reuse within a trust group while isolating others. + +> **Note:** Cache isolation is not supported in engine V0. + ## Data Structure The prefix caching in vLLM v1 is implemented in the KV cache manager. The basic building block is the “Block” data class (simplified): diff --git a/docs/source/design/v1/torch_compile.md b/docs/source/design/v1/torch_compile.md index 7920131643c26153d5ca4b7305ffb02c31835fab..4d8ce0fd9227f1255840124ff23cdd2d1285be18 100644 --- a/docs/source/design/v1/torch_compile.md +++ b/docs/source/design/v1/torch_compile.md @@ -137,3 +137,9 @@ By default, vLLM will try to determine a set of sizes to capture cudagraph. You `vllm serve meta-llama/Llama-3.2-1B --compilation-config "{'cudagraph_capture_sizes': [1, 2, 4, 8]}"` Then it will only capture cudagraph for the specified sizes. It can be useful to have fine-grained control over the cudagraph capture. + +### Full Cudagraph capture + +It is possible to include attention as part of the cudagraph if using an attention backend that is cudagraph compatible. This can improve performance in some cases such as decode speed for smaller models. Enable this using `--compilation-config "{'full_cuda_graph': True}"` + +Currently only FlashAttention 3 is compatible, and only when cascade attention is disabled. diff --git a/docs/source/features/compatibility_matrix.md b/docs/source/features/compatibility_matrix.md index 6056ca0d366b5e38ae391285dc0033c9993471c6..8865d26deaedaf18a746aec8bbf75e3a12620ae7 100644 --- a/docs/source/features/compatibility_matrix.md +++ b/docs/source/features/compatibility_matrix.md @@ -42,7 +42,7 @@ Check the ❌ or 🟠 with links to see tracking issue for unsupported feature/h * [APC](#automatic-prefix-caching) * [LoRA](#lora-adapter) * prmpt adptr - * [SD](#spec_decode) + * [SD](#spec-decode) * CUDA graph * pooling * enc-dec @@ -122,7 +122,7 @@ Check the ❌ or 🟠 with links to see tracking issue for unsupported feature/h * * * -- * [SD](#spec_decode) +- * [SD](#spec-decode) * ✅ * ✅ * ❌ @@ -377,7 +377,7 @@ Check the ❌ or 🟠 with links to see tracking issue for unsupported feature/h * ✅ * [❌](gh-issue:8475) * ✅ -- * [SD](#spec_decode) +- * [SD](#spec-decode) * ✅ * ✅ * ✅ diff --git a/docs/source/features/lora.md b/docs/source/features/lora.md index b5b51095b3a75656d69e3303200a80346052f19c..5a3ce0c01f3fabc133d1f28f4fac0cabd2d5a826 100644 --- a/docs/source/features/lora.md +++ b/docs/source/features/lora.md @@ -66,7 +66,7 @@ The commit ID `0dfa347e8877a4d4ed19ee56c140fa518470028c` may change over time. P The server entrypoint accepts all other LoRA configuration parameters (`max_loras`, `max_lora_rank`, `max_cpu_loras`, etc.), which will apply to all forthcoming requests. Upon querying the `/models` endpoint, we should see our LoRA along -with its base model: +with its base model (if `jq` is not installed, you can follow [this guide](https://jqlang.org/download/) to install it.): ```bash curl localhost:8000/v1/models | jq . @@ -134,7 +134,7 @@ curl -X POST http://localhost:8000/v1/load_lora_adapter \ }' ``` -Upon a successful request, the API will respond with a 200 OK status code. If an error occurs, such as if the adapter +Upon a successful request, the API will respond with a `200 OK` status code from `vllm serve`, and `curl` returns the response body: `Success: LoRA adapter 'sql_adapter' added successfully`. If an error occurs, such as if the adapter cannot be found or loaded, an appropriate error message will be returned. Unloading a LoRA Adapter: @@ -142,6 +142,8 @@ Unloading a LoRA Adapter: To unload a LoRA adapter that has been previously loaded, send a POST request to the `/v1/unload_lora_adapter` endpoint with the name or ID of the adapter to be unloaded. +Upon a successful request, the API responds with a `200 OK` status code from `vllm serve`, and `curl` returns the response body: `Success: LoRA adapter 'sql_adapter' removed successfully`. + Example request to unload a LoRA adapter: ```bash @@ -157,9 +159,12 @@ Alternatively, you can use the LoRAResolver plugin to dynamically load LoRA adap You can set up multiple LoRAResolver plugins if you want to load LoRA adapters from different sources. For example, you might have one resolver for local files and another for S3 storage. vLLM will load the first LoRA adapter that it finds. -You can either install existing plugins or implement your own. +You can either install existing plugins or implement your own. By default, vLLM comes with a [resolver plugin to load LoRA adapters from a local directory.](https://github.com/vllm-project/vllm/tree/main/vllm/plugins/lora_resolvers) +To enable this resolver, set `VLLM_ALLOW_RUNTIME_LORA_UPDATING` to True, set `VLLM_PLUGINS` to include `lora_filesystem_resolver`, and then set `VLLM_LORA_RESOLVER_CACHE_DIR` to a local directory. When vLLM receives a request using a LoRA adapter `foobar`, +it will first look in the local directory for a directory `foobar`, and attempt to load the contents of that directory as a LoRA adapter. If successful, the request will complete as normal and +that adapter will then be available for normal use on the server. -Steps to implement your own LoRAResolver plugin: +Alternatively, follow these example steps to implement your own plugin: 1. Implement the LoRAResolver interface. Example of a simple S3 LoRAResolver implementation: diff --git a/docs/source/serving/multimodal_inputs.md b/docs/source/features/multimodal_inputs.md similarity index 96% rename from docs/source/serving/multimodal_inputs.md rename to docs/source/features/multimodal_inputs.md index d9a093e8d145d2facb9438212efc9a8dfbb806ae..bb2997f008ed5f30b32394b4c70496c0943b7d00 100644 --- a/docs/source/serving/multimodal_inputs.md +++ b/docs/source/features/multimodal_inputs.md @@ -213,10 +213,13 @@ Our OpenAI-compatible server accepts multi-modal data via the [Chat Completions :::{important} A chat template is **required** to use Chat Completions API. +For HF format models, the default chat template is defined inside `chat_template.json` or `tokenizer_config.json`. -Although most models come with a chat template, for others you have to define one yourself. -The chat template can be inferred based on the documentation on the model's HuggingFace repo. -For example, LLaVA-1.5 (`llava-hf/llava-1.5-7b-hf`) requires a chat template that can be found here: +If no default chat template is available, we will first look for a built-in fallback in . +If no fallback is available, an error is raised and you have to provide the chat template manually via the `--chat-template` argument. + +For certain models, we provide alternative chat templates inside . +For example, VLM2Vec uses which is different from the default one for Phi-3-Vision. ::: ### Image Inputs diff --git a/docs/source/features/prompt_embeds.md b/docs/source/features/prompt_embeds.md new file mode 100644 index 0000000000000000000000000000000000000000..4e4648d171d55d453dcd4ddd5dcee2f04e535963 --- /dev/null +++ b/docs/source/features/prompt_embeds.md @@ -0,0 +1,144 @@ +# Prompt Embedding Inputs + +This page teaches you how to pass prompt embedding inputs to vLLM. + +## What are prompt embeddings? + +The traditional flow of text data for a Large Language Model goes from text to token ids (via a tokenizer) then from token ids to prompt embeddings. For a traditional decoder-only model (such as meta-llama/Llama-3.1-8B-Instruct), this step of converting token ids to prompt embeddings happens via a look-up from a learned embedding matrix, but the model is not limited to processing only the embeddings corresponding to its token vocabulary. + +:::{note} +Prompt embeddings are currently only supported in the v0 engine. +::: + +## Offline Inference + +To input multi-modal data, follow this schema in {class}`vllm.inputs.EmbedsPrompt`: + +- `prompt_embeds`: A torch tensor representing a sequence of prompt/token embeddings. This has the shape (sequence_length, hidden_size), where sequence length is the number of tokens embeddings and hidden_size is the hidden size (embedding size) of the model. + +### Hugging Face Transformers Inputs + +You can pass prompt embeddings from Hugging Face Transformers models to the `'prompt_embeds'` field of the prompt embedding dictionary, as shown in the following examples: + +```python +from vllm import LLM +import transformers + +model_name = "meta-llama/Llama-3.2-1B-Instruct" + +# Transformers +tokenizer = transformers.AutoTokenizer.from_pretrained(model_name) +transformers_model = transformers.AutoModelForCausalLM.from_pretrained(model_name) + +llm = LLM(model=model_name, enable_prompt_embeds=True) + +# Refer to the HuggingFace repo for the correct format to use +chat = [{"role": "user", "content": "Please tell me about the capital of France."}] +token_ids = tokenizer.apply_chat_template(chat, add_generation_prompt=True, return_tensors='pt') + +embedding_layer = transformers_model.get_input_embeddings() +prompt_embeds = embedding_layer(token_ids).squeeze(0) + +# Single prompt inference +outputs = llm.generate({ + "prompt_embeds": prompt_embeds, +}) + +for o in outputs: + generated_text = o.outputs[0].text + print(generated_text) + +# Batch inference + +chats = [ + [{"role": "user", "content": "Please tell me about the capital of France."}], + [{"role": "user", "content": "When is the day longest during the year?"}], + [{"role": "user", "content": "Where is bigger, the moon or the sun?"}] +] + +token_ids_list = [ + tokenizer.apply_chat_template(chat, add_generation_prompt=True, return_tensors='pt') for chat in chats +] +prompt_embeds_list = [embedding_layer(token_ids).squeeze(0) for token_ids in token_ids_list] + +outputs = llm.generate( + [ + { + "prompt_embeds": prompt_embeds, + } for prompt_embeds in prompt_embeds_list + ] +) + +for o in outputs: + generated_text = o.outputs[0].text + print(generated_text) +``` + +## Online Serving + +Our OpenAI-compatible server accepts prompt embeddings inputs via the [Completions API](https://platform.openai.com/docs/api-reference/completions). Prompt embeddings inputs are added via a new `'prompt_embeds'` key in the JSON package. + +When a mixture of `'prompt_embeds'` and `'prompt'` inputs are provided in a single request, the prompt embeds are always returned first. + +Prompt embeddings are passed in as base64 encoded torch tensors. + +### Transformers Inputs via OpenAI Client + +First, launch the OpenAI-compatible server: + +```bash +vllm serve meta-llama/Llama-3.2-1B-Instruct --task generate \ + --max-model-len 4096 --enable-prompt-embeds +``` + +Then, you can use the OpenAI client as follows: + +```python +from openai import OpenAI +import transformers +import torch + +openai_api_key = "EMPTY" +openai_api_base = "http://localhost:8000/v1" + +client = OpenAI( + api_key=openai_api_key, + base_url=openai_api_base, +) + +model_name = "meta-llama/Llama-3.2-1B-Instruct" + +# Transformers +tokenizer = transformers.AutoTokenizer.from_pretrained(model_name) +transformers_model = transformers.AutoModelForCausalLM.from_pretrained(model_name) + + +# Refer to the HuggingFace repo for the correct format to use +chat = [{"role": "user", "content": "Please tell me about the capital of France."}] +token_ids = tokenizer.apply_chat_template(chat, add_generation_prompt=True, return_tensors='pt') + +embedding_layer = transformers_model.get_input_embeddings() +prompt_embeds = embedding_layer(token_ids).squeeze(0) + +# Prompt embeddings +buffer = io.BytesIO() +torch.save(prompt_embeds, buffer) +buffer.seek(0) +binary_data = buffer.read() +encoded_embeds = base64.b64encode(binary_data).decode('utf-8') + + +completion = client_with_prompt_embeds.completions.create( + model=model_name, + # NOTE: The OpenAI client does not allow `None` as an input to + # `prompt`. Use an empty string if you have no text prompts. + prompt="", + max_tokens=5, + temperature=0.0, + # NOTE: The OpenAI client allows passing in extra JSON body via the + # `extra_body` argument. + extra_body={"prompt_embeds": encoded_embeds} +) + +print(completion.choices[0].text) +``` diff --git a/docs/source/features/quantization/fp8.md b/docs/source/features/quantization/fp8.md index a62e0124b77060f5bc38c49bd7541525715df3d5..cb304d54726c828706230b477470628448525c36 100644 --- a/docs/source/features/quantization/fp8.md +++ b/docs/source/features/quantization/fp8.md @@ -19,23 +19,6 @@ FP8 computation is supported on NVIDIA GPUs with compute capability > 8.9 (Ada L FP8 models will run on compute capability > 8.0 (Ampere) as weight-only W8A16, utilizing FP8 Marlin. ::: -## Quick Start with Online Dynamic Quantization - -Dynamic quantization of an original precision BF16/FP16 model to FP8 can be achieved with vLLM without any calibration data required. You can enable the feature by specifying `--quantization="fp8"` in the command line or setting `quantization="fp8"` in the LLM constructor. - -In this mode, all Linear modules (except for the final `lm_head`) have their weights quantized down to FP8_E4M3 precision with a per-tensor scale. Activations have their minimum and maximum values calculated during each forward pass to provide a dynamic per-tensor scale for high accuracy. As a result, latency improvements are limited in this mode. - -```python -from vllm import LLM -model = LLM("facebook/opt-125m", quantization="fp8") -# INFO 06-10 17:55:42 model_runner.py:157] Loading model weights took 0.1550 GB -result = model.generate("Hello, my name is") -``` - -:::{warning} -Currently, we load the model at original precision before quantizing down to 8-bits, so you need enough memory to load the whole model. -::: - ## Installation To produce performant FP8 quantized models with vLLM, you'll need to install the [llm-compressor](https://github.com/vllm-project/llm-compressor/) library: @@ -86,7 +69,7 @@ recipe = QuantizationModifier( # Apply the quantization algorithm. oneshot(model=model, recipe=recipe) -# Save the model. +# Save the model: Meta-Llama-3-8B-Instruct-FP8-Dynamic SAVE_DIR = MODEL_ID.split("/")[1] + "-FP8-Dynamic" model.save_pretrained(SAVE_DIR) tokenizer.save_pretrained(SAVE_DIR) @@ -94,7 +77,7 @@ tokenizer.save_pretrained(SAVE_DIR) ### 3. Evaluating Accuracy -Install `vllm` and `lm-evaluation-harness`: +Install `vllm` and `lm-evaluation-harness` for evaluation: ```console pip install vllm lm-eval==0.4.4 @@ -105,7 +88,8 @@ Load and run the model in `vllm`: ```python from vllm import LLM model = LLM("./Meta-Llama-3-8B-Instruct-FP8-Dynamic") -model.generate("Hello my name is") +result = model.generate("Hello my name is") +print(result[0].outputs[0].text) ``` Evaluate accuracy with `lm_eval` (for example on 250 samples of `gsm8k`): @@ -133,59 +117,22 @@ Here's an example of the resulting scores: ## Troubleshooting and Support -If you encounter any issues or have feature requests, please open an issue on the `vllm-project/llm-compressor` GitHub repository. - -## Deprecated Flow - -:::{note} -The following information is preserved for reference and search purposes. -The quantization method described below is deprecated in favor of the `llmcompressor` method described above. -::: - -For static per-tensor offline quantization to FP8, please install the [AutoFP8 library](https://github.com/neuralmagic/autofp8). - -```bash -git clone https://github.com/neuralmagic/AutoFP8.git -pip install -e AutoFP8 -``` - -This package introduces the `AutoFP8ForCausalLM` and `BaseQuantizeConfig` objects for managing how your model will be compressed. - -## Offline Quantization with Static Activation Scaling Factors - -You can use AutoFP8 with calibration data to produce per-tensor static scales for both the weights and activations by enabling the `activation_scheme="static"` argument. - -```python -from datasets import load_dataset -from transformers import AutoTokenizer -from auto_fp8 import AutoFP8ForCausalLM, BaseQuantizeConfig - -pretrained_model_dir = "meta-llama/Meta-Llama-3-8B-Instruct" -quantized_model_dir = "Meta-Llama-3-8B-Instruct-FP8" - -tokenizer = AutoTokenizer.from_pretrained(pretrained_model_dir, use_fast=True) -tokenizer.pad_token = tokenizer.eos_token +If you encounter any issues or have feature requests, please open an issue on the [vllm-project/llm-compressor](https://github.com/vllm-project/llm-compressor/issues) GitHub repository. -# Load and tokenize 512 dataset samples for calibration of activation scales -ds = load_dataset("mgoin/ultrachat_2k", split="train_sft").select(range(512)) -examples = [tokenizer.apply_chat_template(batch["messages"], tokenize=False) for batch in ds] -examples = tokenizer(examples, padding=True, truncation=True, return_tensors="pt").to("cuda") +## Online Dynamic Quantization -# Define quantization config with static activation scales -quantize_config = BaseQuantizeConfig(quant_method="fp8", activation_scheme="static") - -# Load the model, quantize, and save checkpoint -model = AutoFP8ForCausalLM.from_pretrained(pretrained_model_dir, quantize_config) -model.quantize(examples) -model.save_quantized(quantized_model_dir) -``` +Dynamic quantization of an original precision BF16/FP16 model to FP8 can be achieved with vLLM without any calibration data required. You can enable the feature by specifying `--quantization="fp8"` in the command line or setting `quantization="fp8"` in the LLM constructor. -Your model checkpoint with quantized weights and activations should be available at `Meta-Llama-3-8B-Instruct-FP8/`. -Finally, you can load the quantized model checkpoint directly in vLLM. +In this mode, all Linear modules (except for the final `lm_head`) have their weights quantized down to FP8_E4M3 precision with a per-tensor scale. Activations have their minimum and maximum values calculated during each forward pass to provide a dynamic per-tensor scale for high accuracy. As a result, latency improvements are limited in this mode. ```python from vllm import LLM -model = LLM(model="Meta-Llama-3-8B-Instruct-FP8/") -# INFO 06-10 21:15:41 model_runner.py:159] Loading model weights took 8.4596 GB +model = LLM("facebook/opt-125m", quantization="fp8") +# INFO 06-10 17:55:42 model_runner.py:157] Loading model weights took 0.1550 GB result = model.generate("Hello, my name is") +print(result[0].outputs[0].text) ``` + +:::{warning} +Currently, we load the model at original precision before quantizing down to 8-bits, so you need enough memory to load the whole model. +::: diff --git a/docs/source/features/quantization/index.md b/docs/source/features/quantization/index.md index c7c8aeb662a56d189111fff7c0c1df334b385335..7ad46b7094ee964ccf86eb4f854e5edb60ae51d2 100644 --- a/docs/source/features/quantization/index.md +++ b/docs/source/features/quantization/index.md @@ -17,6 +17,7 @@ gptqmodel int4 int8 fp8 +modelopt quark quantized_kvcache torchao diff --git a/docs/source/features/quantization/int4.md b/docs/source/features/quantization/int4.md index f8939e5bf01505e5b057633820e89fac98c3b251..7a0ab4ad229e6f278a320846383f71817518cd8c 100644 --- a/docs/source/features/quantization/int4.md +++ b/docs/source/features/quantization/int4.md @@ -18,6 +18,12 @@ To use INT4 quantization with vLLM, you'll need to install the [llm-compressor]( pip install llmcompressor ``` +Additionally, install `vllm` and `lm-evaluation-harness` for evaluation: + +```console +pip install vllm lm-eval==0.4.4 +``` + ## Quantization Process The quantization process involves four main steps: @@ -87,7 +93,7 @@ oneshot( num_calibration_samples=NUM_CALIBRATION_SAMPLES, ) -# Save the compressed model +# Save the compressed model: Meta-Llama-3-8B-Instruct-W4A16-G128 SAVE_DIR = MODEL_ID.split("/")[1] + "-W4A16-G128" model.save_pretrained(SAVE_DIR, save_compressed=True) tokenizer.save_pretrained(SAVE_DIR) @@ -163,4 +169,4 @@ recipe = GPTQModifier( ## Troubleshooting and Support -If you encounter any issues or have feature requests, please open an issue on the [`vllm-project/llm-compressor`](https://github.com/vllm-project/llm-compressor) GitHub repository. The full INT4 quantization example in `llm-compressor` is available [here](https://github.com/vllm-project/llm-compressor/blob/main/examples/quantization_w4a16/llama3_example.py). +If you encounter any issues or have feature requests, please open an issue on the [vllm-project/llm-compressor](https://github.com/vllm-project/llm-compressor/issues) GitHub repository. The full INT4 quantization example in `llm-compressor` is available [here](https://github.com/vllm-project/llm-compressor/blob/main/examples/quantization_w4a16/llama3_example.py). diff --git a/docs/source/features/quantization/int8.md b/docs/source/features/quantization/int8.md index b381f34bccd34e24180958435a5733640e64bb59..1e4b01d35575c70aa65e8a857d7c8967e391d7b8 100644 --- a/docs/source/features/quantization/int8.md +++ b/docs/source/features/quantization/int8.md @@ -19,6 +19,12 @@ To use INT8 quantization with vLLM, you'll need to install the [llm-compressor]( pip install llmcompressor ``` +Additionally, install `vllm` and `lm-evaluation-harness` for evaluation: + +```console +pip install vllm lm-eval==0.4.4 +``` + ## Quantization Process The quantization process involves four main steps: @@ -91,7 +97,7 @@ oneshot( num_calibration_samples=NUM_CALIBRATION_SAMPLES, ) -# Save the compressed model +# Save the compressed model: Meta-Llama-3-8B-Instruct-W8A8-Dynamic-Per-Token SAVE_DIR = MODEL_ID.split("/")[1] + "-W8A8-Dynamic-Per-Token" model.save_pretrained(SAVE_DIR, save_compressed=True) tokenizer.save_pretrained(SAVE_DIR) @@ -132,4 +138,4 @@ Quantized models can be sensitive to the presence of the `bos` token. Make sure ## Troubleshooting and Support -If you encounter any issues or have feature requests, please open an issue on the [`vllm-project/llm-compressor`](https://github.com/vllm-project/llm-compressor) GitHub repository. +If you encounter any issues or have feature requests, please open an issue on the [vllm-project/llm-compressor](https://github.com/vllm-project/llm-compressor/issues) GitHub repository. diff --git a/docs/source/features/quantization/modelopt.md b/docs/source/features/quantization/modelopt.md new file mode 100644 index 0000000000000000000000000000000000000000..001d18657dad084e1b64e5dac7440a6df39e3c6c --- /dev/null +++ b/docs/source/features/quantization/modelopt.md @@ -0,0 +1,78 @@ +# NVIDIA TensorRT Model Optimizer + +The [NVIDIA TensorRT Model Optimizer](https://github.com/NVIDIA/TensorRT-Model-Optimizer) is a library designed to optimize models for inference with NVIDIA GPUs. It includes tools for Post-Training Quantization (PTQ) and Quantization Aware Training (QAT) of Large Language Models (LLMs), Vision Language Models (VLMs), and diffusion models. + +We recommend installing the library with: + +```console +pip install nvidia-modelopt +``` + +## Quantizing HuggingFace Models with PTQ + +You can quantize HuggingFace models using the example scripts provided in the TensorRT Model Optimizer repository. The primary script for LLM PTQ is typically found within the `examples/llm_ptq` directory. + +Below is an example showing how to quantize a model using modelopt's PTQ API: + +```python +import modelopt.torch.quantization as mtq +from transformers import AutoModelForCausalLM + +# Load the model from HuggingFace +model = AutoModelForCausalLM.from_pretrained("") + +# Select the quantization config, for example, FP8 +config = mtq.FP8_DEFAULT_CFG + +# Define a forward loop function for calibration +def forward_loop(model): + for data in calib_set: + model(data) + +# PTQ with in-place replacement of quantized modules +model = mtq.quantize(model, config, forward_loop) +``` + +After the model is quantized, you can export it to a quantized checkpoint using the export API: + +```python +import torch +from modelopt.torch.export import export_hf_checkpoint + +with torch.inference_mode(): + export_hf_checkpoint( + model, # The quantized model. + export_dir, # The directory where the exported files will be stored. + ) +``` + +The quantized checkpoint can then be deployed with vLLM. As an example, the following code shows how to deploy `nvidia/Llama-3.1-8B-Instruct-FP8`, which is the FP8 quantized checkpoint derived from `meta-llama/Llama-3.1-8B-Instruct`, using vLLM: + +```python +from vllm import LLM, SamplingParams + +def main(): + + model_id = "nvidia/Llama-3.1-8B-Instruct-FP8" + # Ensure you specify quantization='modelopt' when loading the modelopt checkpoint + llm = LLM(model=model_id, quantization="modelopt", trust_remote_code=True) + + sampling_params = SamplingParams(temperature=0.8, top_p=0.9) + + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + outputs = llm.generate(prompts, sampling_params) + + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + +if __name__ == "__main__": + main() +``` diff --git a/docs/source/features/quantization/quantized_kvcache.md b/docs/source/features/quantization/quantized_kvcache.md index 9f36c2949e0dd11849be687365eaa521e322c0d7..86e6354ec82e094f8f86543ca3255f1480236f07 100644 --- a/docs/source/features/quantization/quantized_kvcache.md +++ b/docs/source/features/quantization/quantized_kvcache.md @@ -126,7 +126,7 @@ oneshot( num_calibration_samples=NUM_CALIBRATION_SAMPLES, ) -# Save quantized model +# Save quantized model: Llama-3.1-8B-Instruct-FP8-KV SAVE_DIR = MODEL_ID.split("/")[1] + "-FP8-KV" model.save_pretrained(SAVE_DIR, save_compressed=True) tokenizer.save_pretrained(SAVE_DIR) diff --git a/docs/source/features/quantization/quark.md b/docs/source/features/quantization/quark.md index 935ee37a815ffd1d4a382a2bd1266b6f58e5b11f..955890dbc75ba3099213ccbf97076b0e2b81b28f 100644 --- a/docs/source/features/quantization/quark.md +++ b/docs/source/features/quantization/quark.md @@ -19,6 +19,12 @@ pip install amd-quark You can refer to [Quark installation guide](https://quark.docs.amd.com/latest/install.html) for more installation details. +Additionally, install `vllm` and `lm-evaluation-harness` for evaluation: + +```console +pip install vllm lm-eval==0.4.4 +``` + ## Quantization Process After installing Quark, we will use an example to illustrate how to use Quark. @@ -150,6 +156,7 @@ LLAMA_KV_CACHE_GROUP = ["*k_proj", "*v_proj"] export_config = ExporterConfig(json_export_config=JsonExporterConfig()) export_config.json_export_config.kv_cache_group = LLAMA_KV_CACHE_GROUP +# Model: Llama-2-70b-chat-hf-w-fp8-a-fp8-kvcache-fp8-pertensor-autosmoothquant EXPORT_DIR = MODEL_ID.split("/")[1] + "-w-fp8-a-fp8-kvcache-fp8-pertensor-autosmoothquant" exporter = ModelExporter(config=export_config, export_dir=EXPORT_DIR) with torch.no_grad(): diff --git a/docs/source/features/quantization/supported_hardware.md b/docs/source/features/quantization/supported_hardware.md index 984e6626e2417fc095850a14d1894a288d2a37ac..f8af1ba60b125f5a2bc0d8b99e501ec529fb8d49 100644 --- a/docs/source/features/quantization/supported_hardware.md +++ b/docs/source/features/quantization/supported_hardware.md @@ -80,7 +80,7 @@ The table below shows the compatibility of various quantization implementations * ✅︎ * ✅︎ * ✅︎ - * ✅︎ + * ❌ * ❌ * ❌ * ❌ @@ -129,7 +129,17 @@ The table below shows the compatibility of various quantization implementations * ❌ * ❌ * ❌ - +- * modelopt + * ✅︎ + * ✅︎ + * ✅︎ + * ✅︎ + * ✅︎︎ + * ❌ + * ❌ + * ❌ + * ❌ + * ❌ ::: - Volta refers to SM 7.0, Turing to SM 7.5, Ampere to SM 8.0/8.6, Ada to SM 8.9, and Hopper to SM 9.0. diff --git a/docs/source/features/reasoning_outputs.md b/docs/source/features/reasoning_outputs.md index 3a0be69f8e1c661f899889aeacd20bfff9fbf015..bf4f8901a11a8ed11577cea1a88ff838ffc4611d 100644 --- a/docs/source/features/reasoning_outputs.md +++ b/docs/source/features/reasoning_outputs.md @@ -15,16 +15,19 @@ vLLM currently supports the following reasoning models: | [DeepSeek R1 series](https://huggingface.co/collections/deepseek-ai/deepseek-r1-678e1e131c0169c0bc89728d) | `deepseek_r1` | `guided_json`, `guided_regex` | ❌ | | [QwQ-32B](https://huggingface.co/Qwen/QwQ-32B) | `deepseek_r1` | `guided_json`, `guided_regex` | ✅ | | [IBM Granite 3.2 language models](https://huggingface.co/collections/ibm-granite/granite-32-language-models-67b3bc8c13508f6d064cff9a) | `granite` | ❌ | ❌ | +| [Qwen3 series](https://huggingface.co/collections/Qwen/qwen3-67dd247413f0e2e4f653967f) | `qwen3` | `guided_json`, `guided_regex` | ✅ | -- IBM Granite 3.2 reasoning is disabled by default; to enable it, you must also pass `thinking=True` in your `chat_template_kwargs`. +:::{note} +IBM Granite 3.2 reasoning is disabled by default; to enable it, you must also pass `thinking=True` in your `chat_template_kwargs`. +The reasoning feature for the Qwen3 series is enabled by default. To disable it, you must pass `enable_thinking=False` in your `chat_template_kwargs`. +::: ## Quickstart -To use reasoning models, you need to specify the `--enable-reasoning` and `--reasoning-parser` flags when making a request to the chat completion endpoint. The `--reasoning-parser` flag specifies the reasoning parser to use for extracting reasoning content from the model output. +To use reasoning models, you need to specify the `--reasoning-parser` flags when making a request to the chat completion endpoint. The `--reasoning-parser` flag specifies the reasoning parser to use for extracting reasoning content from the model output. ```bash -vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \ - --enable-reasoning --reasoning-parser deepseek_r1 +vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B --reasoning-parser deepseek_r1 ``` Next, make a request to the model that should return the reasoning content in the response. @@ -47,6 +50,8 @@ model = models.data[0].id # Round 1 messages = [{"role": "user", "content": "9.11 and 9.8, which is greater?"}] # For granite, add: `extra_body={"chat_template_kwargs": {"thinking": True}}` +# For Qwen3 series, if you want to disable thinking in reasoning mode, add: +# extra_body={"chat_template_kwargs": {"enable_thinking": False}} response = client.chat.completions.create(model=model, messages=messages) reasoning_content = response.choices[0].message.reasoning_content @@ -83,7 +88,7 @@ Streaming chat completions are also supported for reasoning models. The `reasoni } ``` -OpenAI Python client library does not officially support `reasoning_content` attribute for streaming output. But the client support extra attributes in the response. You can use `hasattr` to check if the `reasoning_content` attribute is present in the response. For example: +OpenAI Python client library does not officially support `reasoning_content` attribute for streaming output. But the client supports extra attributes in the response. You can use `hasattr` to check if the `reasoning_content` attribute is present in the response. For example: ```python from openai import OpenAI @@ -102,6 +107,8 @@ model = models.data[0].id messages = [{"role": "user", "content": "9.11 and 9.8, which is greater?"}] # For granite, add: `extra_body={"chat_template_kwargs": {"thinking": True}}` +# For Qwen3 series, if you want to disable thinking in reasoning mode, add: +# extra_body={"chat_template_kwargs": {"enable_thinking": False}} stream = client.chat.completions.create(model=model, messages=messages, stream=True) @@ -139,11 +146,10 @@ Remember to check whether the `reasoning_content` exists in the response before The reasoning content is also available in the structured output. The structured output engine like `xgrammar` will use the reasoning content to generate structured output. It is only supported in v0 engine now. ```bash -VLLM_USE_V1=0 vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \ - --enable-reasoning --reasoning-parser deepseek_r1 +vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B --reasoning-parser deepseek_r1 ``` -Please note that the `VLLM_USE_V1` environment variable must be set to `0` to use the v0 engine. +The following is an example client: ```python from openai import OpenAI @@ -222,7 +228,7 @@ print(f"Function called: {tool_call.name}") print(f"Arguments: {tool_call.arguments}") ``` -For more examples, please refer to . +For more examples, please refer to . ## Limitations @@ -230,13 +236,12 @@ For more examples, please refer to . ```python # import the required packages -from vllm.entrypoints.openai.reasoning_parsers.abs_reasoning_parsers import ( - ReasoningParser, ReasoningParserManager) +from vllm.reasoning import ReasoningParser, ReasoningParserManager from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, DeltaMessage) @@ -287,7 +292,7 @@ class ExampleParser(ReasoningParser): """ ``` -Additionally, to enable structured output, you'll need to create a new `Reasoner` similar to the one in `vllm/model_executor/guided_decoding/reasoner/deepseek_reasoner.py`. +Additionally, to enable structured output, you'll need to create a new `Reasoner` similar to the one in . ```python @dataclass @@ -313,11 +318,10 @@ class DeepSeekReasoner(Reasoner): ... ``` -The structured output engine like `xgrammar` will use `end_token_id` to check if the reasoning content is present in the model output and skip the structured output if it is the case. +The structured output engine like [xgrammar](https://github.com/mlc-ai/xgrammar) will use `end_token_id` to check if the reasoning content is present in the model output and skip the structured output if it is the case. -Finally, you can enable reasoning for the model by using the `--enable-reasoning` and `--reasoning-parser` flags. +Finally, you can enable reasoning for the model by using the `--reasoning-parser` flags. ```bash -vllm serve \ - --enable-reasoning --reasoning-parser example +vllm serve --reasoning-parser example ``` diff --git a/docs/source/features/tool_calling.md b/docs/source/features/tool_calling.md index f98ec6108cea616068ab726c5f1bbd569467156e..2795b769345eea216fb5ab4810bffce013b3596d 100644 --- a/docs/source/features/tool_calling.md +++ b/docs/source/features/tool_calling.md @@ -141,9 +141,9 @@ Known issues: much shorter than what vLLM generates. Since an exception is thrown when this condition is not met, the following additional chat templates are provided: -* `examples/tool_chat_template_mistral.jinja` - this is the "official" Mistral chat template, but tweaked so that +* - this is the "official" Mistral chat template, but tweaked so that it works with vLLM's tool call IDs (provided `tool_call_id` fields are truncated to the last 9 digits) -* `examples/tool_chat_template_mistral_parallel.jinja` - this is a "better" version that adds a tool-use system prompt +* - this is a "better" version that adds a tool-use system prompt when tools are provided, that results in much better reliability when working with parallel tool calling. Recommended flags: `--tool-call-parser mistral --chat-template examples/tool_chat_template_mistral_parallel.jinja` @@ -170,15 +170,15 @@ Known issues: VLLM provides two JSON based chat templates for Llama 3.1 and 3.2: -* `examples/tool_chat_template_llama3.1_json.jinja` - this is the "official" chat template for the Llama 3.1 +* - this is the "official" chat template for the Llama 3.1 models, but tweaked so that it works better with vLLM. -* `examples/tool_chat_template_llama3.2_json.jinja` - this extends upon the Llama 3.1 chat template by adding support for +* - this extends upon the Llama 3.1 chat template by adding support for images. Recommended flags: `--tool-call-parser llama3_json --chat-template {see_above}` VLLM also provides a JSON based chat template for Llama 4: -* `examples/tool_chat_template_llama4_json.jinja` - this is based on the "official" chat template for the Llama 4 +* - this is based on the "official" chat template for the Llama 4 models, but tweaked so that it works better with vLLM. For Llama 4 use `--tool-call-parser llama4_json examples/tool_chat_template_llama4_json.jinja`. @@ -191,7 +191,7 @@ Supported models: Recommended flags: `--tool-call-parser granite --chat-template examples/tool_chat_template_granite.jinja` -`examples/tool_chat_template_granite.jinja`: this is a modified chat template from the original on Huggingface. Parallel function calls are supported. +: this is a modified chat template from the original on Huggingface. Parallel function calls are supported. * `ibm-granite/granite-3.1-8b-instruct` @@ -203,7 +203,7 @@ The chat template from Huggingface can be used directly. Parallel function calls Recommended flags: `--tool-call-parser granite-20b-fc --chat-template examples/tool_chat_template_granite_20b_fc.jinja` -`examples/tool_chat_template_granite_20b_fc.jinja`: this is a modified chat template from the original on Huggingface, which is not vLLM compatible. It blends function description elements from the Hermes template and follows the same system prompt as "Response Generation" mode from [the paper](https://arxiv.org/abs/2407.00121). Parallel function calls are supported. +: this is a modified chat template from the original on Huggingface, which is not vLLM compatible. It blends function description elements from the Hermes template and follows the same system prompt as "Response Generation" mode from [the paper](https://arxiv.org/abs/2407.00121). Parallel function calls are supported. ### InternLM Models (`internlm`) @@ -236,6 +236,13 @@ For Qwen2.5, the chat template in tokenizer_config.json has already included sup Flags: `--tool-call-parser hermes` +### DeepSeek-V3 Models (`deepseek_v3`) + +Supported models: +* `deepseek-ai/DeepSeek-V3-0324` + +Flags: `--tool-call-parser deepseek_v3 --chat-template examples/tool_chat_template_deepseekv3.jinja` + ### Models with Pythonic Tool Calls (`pythonic`) A growing number of models output a python list to represent tool calls instead of using JSON. This has the advantage of inherently supporting parallel tool calls and removing ambiguity around the JSON schema required for tool calls. The `pythonic` tool parser can support such models. @@ -253,12 +260,12 @@ Limitations: Example supported models: -* `meta-llama/Llama-3.2-1B-Instruct`\* (use with `examples/tool_chat_template_llama3.2_pythonic.jinja`) -* `meta-llama/Llama-3.2-3B-Instruct`\* (use with `examples/tool_chat_template_llama3.2_pythonic.jinja`) -* `Team-ACE/ToolACE-8B` (use with `examples/tool_chat_template_toolace.jinja`) -* `fixie-ai/ultravox-v0_4-ToolACE-8B` (use with `examples/tool_chat_template_toolace.jinja`) -* `meta-llama/Llama-4-Scout-17B-16E-Instruct`\* (use with `examples/tool_chat_template_llama4_pythonic.jinja`) -* `meta-llama/Llama-4-Maverick-17B-128E-Instruct`\* (use with `examples/tool_chat_template_llama4_pythonic.jinja`) +* `meta-llama/Llama-3.2-1B-Instruct`\* (use with ) +* `meta-llama/Llama-3.2-3B-Instruct`\* (use with ) +* `Team-ACE/ToolACE-8B` (use with ) +* `fixie-ai/ultravox-v0_4-ToolACE-8B` (use with ) +* `meta-llama/Llama-4-Scout-17B-16E-Instruct`\* (use with ) +* `meta-llama/Llama-4-Maverick-17B-128E-Instruct`\* (use with ) Flags: `--tool-call-parser pythonic --chat-template {see_above}` @@ -270,7 +277,7 @@ Llama's smaller models frequently fail to emit tool calls in the correct format. ## How to write a tool parser plugin -A tool parser plugin is a Python file containing one or more ToolParser implementations. You can write a ToolParser similar to the `Hermes2ProToolParser` in vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py. +A tool parser plugin is a Python file containing one or more ToolParser implementations. You can write a ToolParser similar to the `Hermes2ProToolParser` in . Here is a summary of a plugin file: diff --git a/docs/source/getting_started/installation/ai_accelerator/tpu.inc.md b/docs/source/getting_started/installation/ai_accelerator/tpu.inc.md index 8beb92ef7da0a3e50b7be726b78dd9631e9f95a6..4459cc61e1cde4772bbe220dace7809d4e02ecf7 100644 --- a/docs/source/getting_started/installation/ai_accelerator/tpu.inc.md +++ b/docs/source/getting_started/installation/ai_accelerator/tpu.inc.md @@ -158,7 +158,7 @@ sudo apt-get install libopenblas-base libopenmpi-dev libomp-dev Run the setup script: ```bash -VLLM_TARGET_DEVICE="tpu" python setup.py develop +VLLM_TARGET_DEVICE="tpu" python -m pip install -e . ``` ## Set up using Docker diff --git a/docs/source/getting_started/installation/gpu/cuda.inc.md b/docs/source/getting_started/installation/gpu/cuda.inc.md index 46bdb08ebb77c921bbce06193f269c6ff11c7591..06915f09dd5171dcc9885ca656e796c823bde849 100644 --- a/docs/source/getting_started/installation/gpu/cuda.inc.md +++ b/docs/source/getting_started/installation/gpu/cuda.inc.md @@ -1,6 +1,6 @@ # Installation -vLLM contains pre-compiled C++ and CUDA (12.1) binaries. +vLLM contains pre-compiled C++ and CUDA (12.6) binaries. ## Requirements @@ -23,12 +23,12 @@ Therefore, it is recommended to install vLLM with a **fresh new** environment. I You can install vLLM using either `pip` or `uv pip`: ```console -# Install vLLM with CUDA 12.4. +# Install vLLM with CUDA 12.6. pip install vllm # If you are using pip. uv pip install vllm # If you are using uv. ``` -As of now, vLLM's binaries are compiled with CUDA 12.4 and public PyTorch release versions by default. We also provide vLLM binaries compiled with CUDA 12.1, 11.8, and public PyTorch release versions: +As of now, vLLM's binaries are compiled with CUDA 12.6 and public PyTorch release versions by default. We also provide vLLM binaries compiled with CUDA 12.8, 11.8, and public PyTorch release versions: ```console # Install vLLM with CUDA 11.8. diff --git a/docs/source/getting_started/installation/gpu/rocm.inc.md b/docs/source/getting_started/installation/gpu/rocm.inc.md index 21c8d7d01adebcd1271a5e8498169c427508bfb5..dc74368fe2c96e289b00ba3ac3e265c816c791b1 100644 --- a/docs/source/getting_started/installation/gpu/rocm.inc.md +++ b/docs/source/getting_started/installation/gpu/rocm.inc.md @@ -73,7 +73,22 @@ Currently, there are no pre-built ROCm wheels. You might need to downgrade the "ninja" version to 1.10 it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`) ::: -3. Build vLLM. For example, vLLM on ROCM 6.3 can be built with the following steps: +3. If you choose to build AITER yourself to use a certain branch or commit, you can build AITER using the following steps: + + ```console + python3 -m pip uninstall -y aiter + git clone --recursive https://github.com/ROCm/aiter.git + cd aiter + git checkout $AITER_BRANCH_OR_COMMIT + git submodule sync; git submodule update --init --recursive + python3 setup.py develop + ``` + + :::{note} + You will need to config the `$AITER_BRANCH_OR_COMMIT` for your purpose. + ::: + +4. Build vLLM. For example, vLLM on ROCM 6.3 can be built with the following steps: ```bash $ pip install --upgrade pip diff --git a/docs/source/getting_started/installation/gpu/xpu.inc.md b/docs/source/getting_started/installation/gpu/xpu.inc.md index fbf5421eeec5b358732a2c73246708121dcbb2d6..4ab41a21c2a15e6e4c2c6ca184a00a66f7aa40c6 100644 --- a/docs/source/getting_started/installation/gpu/xpu.inc.md +++ b/docs/source/getting_started/installation/gpu/xpu.inc.md @@ -35,13 +35,6 @@ pip install -v -r requirements/xpu.txt VLLM_TARGET_DEVICE=xpu python setup.py install ``` -- Finally, due to a known issue of conflict dependency(oneapi related) in torch-xpu 2.6 and ipex-xpu 2.6, we install ipex here. This will be fixed in the ipex-xpu 2.7. - -```console -pip install intel-extension-for-pytorch==2.6.10+xpu \ - --extra-index-url=https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ -``` - :::{note} - FP16 is the default data type in the current XPU backend. The BF16 data type is supported on Intel Data Center GPU, not supported on Intel Arc GPU yet. @@ -81,5 +74,3 @@ python -m vllm.entrypoints.openai.api_server \ ``` By default, a ray instance will be launched automatically if no existing one is detected in the system, with `num-gpus` equals to `parallel_config.world_size`. We recommend properly starting a ray cluster before execution, referring to the helper script. - -There are some new features coming with ipex-xpu 2.6, e.g. **chunked prefill**, **V1 engine support**, **lora**, **MoE**, etc. diff --git a/docs/source/getting_started/installation/python_env_setup.inc.md b/docs/source/getting_started/installation/python_env_setup.inc.md index a03d35030fe8a46563ef0b83ac89165032b2d674..00b61ea5c826468fd37cec437858b238d35fb40c 100644 --- a/docs/source/getting_started/installation/python_env_setup.inc.md +++ b/docs/source/getting_started/installation/python_env_setup.inc.md @@ -14,6 +14,6 @@ Or you can create a new Python environment using [uv](https://docs.astral.sh/uv/ ```console # (Recommended) Create a new uv environment. Use `--seed` to install `pip` and `setuptools` in the environment. -uv venv vllm --python 3.12 --seed -source vllm/bin/activate +uv venv --python 3.12 --seed +source .venv/bin/activate ``` diff --git a/docs/source/getting_started/quickstart.md b/docs/source/getting_started/quickstart.md index 25189b006c2602d866712a1e9b0ce675d889894d..298ba59f7d8b6e5885170b369660c1cadf9b3a0d 100644 --- a/docs/source/getting_started/quickstart.md +++ b/docs/source/getting_started/quickstart.md @@ -19,8 +19,8 @@ If you are using NVIDIA GPUs, you can install vLLM using [pip](https://pypi.org/ It's recommended to use [uv](https://docs.astral.sh/uv/), a very fast Python environment manager, to create and manage Python environments. Please follow the [documentation](https://docs.astral.sh/uv/#getting-started) to install `uv`. After installing `uv`, you can create a new Python environment and install vLLM using the following commands: ```console -uv venv myenv --python 3.12 --seed -source myenv/bin/activate +uv venv --python 3.12 --seed +source .venv/bin/activate uv pip install vllm ``` diff --git a/docs/source/index.md b/docs/source/index.md index 43b330e4b432e5c83a2282809db1b498621625bb..db2192e87dcf2b4998733c3d2df50d5923629486 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -90,6 +90,8 @@ models/extensions/index :maxdepth: 1 features/quantization/index +features/multimodal_inputs +features/prompt_embeds features/lora features/tool_calling features/reasoning_outputs @@ -117,7 +119,7 @@ training/rlhf.md serving/offline_inference serving/openai_compatible_server -serving/multimodal_inputs +serving/serve_args serving/distributed_serving serving/metrics serving/engine_args @@ -181,6 +183,7 @@ design/v1/metrics :maxdepth: 2 contributing/overview +contributing/deprecation_policy contributing/profiling/profiling_index contributing/dockerfile/dockerfile contributing/model/index @@ -193,11 +196,8 @@ contributing/vulnerability_management :caption: API Reference :maxdepth: 2 -api/offline_inference/index -api/engine/index -api/inference_params -api/multimodal/index -api/model/index +api/summary +api/vllm/vllm ::: % Latest news and acknowledgements diff --git a/docs/source/models/generative_models.md b/docs/source/models/generative_models.md index 3291006ed668ca9bd0c95999740bbd5ad1be54bf..dd765e4a976583c26705a5fd33a7d0a76ced2a0f 100644 --- a/docs/source/models/generative_models.md +++ b/docs/source/models/generative_models.md @@ -14,7 +14,7 @@ Usually, this is automatically inferred so you don't have to specify it. ## Offline Inference The {class}`~vllm.LLM` class provides various methods for offline inference. -See [Engine Arguments](#engine-args) for a list of options when initializing the model. +See for a list of options when initializing the model. ### `LLM.generate` diff --git a/docs/source/models/pooling_models.md b/docs/source/models/pooling_models.md index 7daa0ec1de4de6596c823d8365dd850464f2e01c..3fd35e2e8bd1798042467f53fcafd0331b25444e 100644 --- a/docs/source/models/pooling_models.md +++ b/docs/source/models/pooling_models.md @@ -60,7 +60,7 @@ which takes priority over both the model's and Sentence Transformers's defaults. ## Offline Inference The {class}`~vllm.LLM` class provides various methods for offline inference. -See [Engine Arguments](#engine-args) for a list of options when initializing the model. +See for a list of options when initializing the model. ### `LLM.encode` @@ -140,6 +140,7 @@ Our [OpenAI-Compatible Server](#openai-compatible-server) provides endpoints tha - [Pooling API](#pooling-api) is similar to `LLM.encode`, being applicable to all types of pooling models. - [Embeddings API](#embeddings-api) is similar to `LLM.embed`, accepting both text and [multi-modal inputs](#multimodal-inputs) for embedding models. +- [Classification API](#classification-api) is similar to `LLM.classify` and is applicable to sequence classification models. - [Score API](#score-api) is similar to `LLM.score` for cross-encoder models. ## Matryoshka Embeddings diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md index bc68e34832ccb116826f4c1d5a69b26a3bf58952..2b18ea197fd170619032e9594f2a927a06d9384c 100644 --- a/docs/source/models/supported_models.md +++ b/docs/source/models/supported_models.md @@ -54,7 +54,7 @@ For a model to be compatible with the Transformers backend for vLLM it must: If the compatible model is: -- on the Hugging Face Model Hub, simply set `trust_remote_code=True` for or `--trust-remode-code` for the . +- on the Hugging Face Model Hub, simply set `trust_remote_code=True` for or `--trust-remote-code` for the . - in a local directory, simply pass directory path to `model=` for or `vllm serve ` for the . This means that, with the Transformers backend for vLLM, new models can be used before they are officially supported in Transformers or vLLM! @@ -168,6 +168,66 @@ If vLLM successfully returns text (for generative models) or hidden states (for Otherwise, please refer to [Adding a New Model](#new-model) for instructions on how to implement your model in vLLM. Alternatively, you can [open an issue on GitHub](https://github.com/vllm-project/vllm/issues/new/choose) to request vLLM support. +#### Download a model + +If you prefer, you can use the Hugging Face CLI to [download a model](https://huggingface.co/docs/huggingface_hub/guides/cli#huggingface-cli-download) or specific files from a model repository: + +```console +# Download a model +huggingface-cli download HuggingFaceH4/zephyr-7b-beta + +# Specify a custom cache directory +huggingface-cli download HuggingFaceH4/zephyr-7b-beta --cache-dir ./path/to/cache + +# Download a specific file from a model repo +huggingface-cli download HuggingFaceH4/zephyr-7b-beta eval_results.json +``` + +#### List the downloaded models + +Use the Hugging Face CLI to [manage models](https://huggingface.co/docs/huggingface_hub/guides/manage-cache#scan-your-cache) stored in local cache: + +```console +# List cached models +huggingface-cli scan-cache + +# Show detailed (verbose) output +huggingface-cli scan-cache -v + +# Specify a custom cache directory +huggingface-cli scan-cache --dir ~/.cache/huggingface/hub +``` + +#### Delete a cached model + +Use the Hugging Face CLI to interactively [delete downloaded model](https://huggingface.co/docs/huggingface_hub/guides/manage-cache#clean-your-cache) from the cache: + +```console +# The `delete-cache` command requires extra dependencies to work with the TUI. +# Please run `pip install huggingface_hub[cli]` to install them. + +# Launch the interactive TUI to select models to delete +$ huggingface-cli delete-cache +? Select revisions to delete: 1 revisions selected counting for 438.9M. + ○ None of the following (if selected, nothing will be deleted). +Model BAAI/bge-base-en-v1.5 (438.9M, used 1 week ago) +❯ ◉ a5beb1e3: main # modified 1 week ago + +Model BAAI/bge-large-en-v1.5 (1.3G, used 1 week ago) + ○ d4aa6901: main # modified 1 week ago + +Model BAAI/bge-reranker-base (1.1G, used 4 weeks ago) + ○ 2cfc18c9: main # modified 4 weeks ago + +Press to select, to validate and to quit without modification. + +# Need to confirm after selected +? Select revisions to delete: 1 revision(s) selected. +? 1 revisions selected counting for 438.9M. Confirm deletion ? Yes +Start deletion. +Done. Deleted 1 repo(s) and 0 revision(s) for a total of 438.9M. +``` + #### Using a proxy Here are some tips for loading/downloading models from Hugging Face using a proxy: @@ -239,7 +299,9 @@ print(output) See [this page](#generative-models) for more information on how to use generative models. -#### Text Generation (`--task generate`) +#### Text Generation + +Specified using `--task generate`. :::{list-table} :widths: 25 25 50 5 5 @@ -385,6 +447,11 @@ See [this page](#generative-models) for more information on how to use generativ * `ibm-granite/granite-3.0-1b-a400m-base`, `ibm-granite/granite-3.0-3b-a800m-instruct`, `ibm/PowerMoE-3b`, etc. * ✅︎ * ✅︎ +- * `GraniteMoeHybridForCausalLM` + * Granite 4.0 MoE Hybrid + * `ibm-granite/granite-4.0-tiny-preview`, etc. + * ✅︎ + * ✅︎ - * `GraniteMoeSharedForCausalLM` * Granite MoE Shared * `ibm-research/moe-7b-1b-active-shared-experts` (test model) @@ -472,7 +539,7 @@ See [this page](#generative-models) for more information on how to use generativ * ✅︎ - * `OLMo2ForCausalLM` * OLMo2 - * `allenai/OLMo2-7B-1124`, etc. + * `allenai/OLMo-2-0425-1B`, etc. * * ✅︎ - * `OLMoEForCausalLM` @@ -542,8 +609,8 @@ See [this page](#generative-models) for more information on how to use generativ * ✅︎ - * `Qwen3MoeForCausalLM` * Qwen3MoE - * `Qwen/Qwen3-MoE-15B-A2B`, etc. - * ✅︎ + * `Qwen/Qwen3-30B-A3B`, etc. + * * ✅︎ - * `StableLmForCausalLM` * StableLM @@ -585,6 +652,11 @@ See [this page](#generative-models) for more information on how to use generativ * `Zyphra/Zamba2-7B-instruct`, `Zyphra/Zamba2-2.7B-instruct`, `Zyphra/Zamba2-1.2B-instruct`, etc. * * +- * `MiMoForCausalLM` + * MiMo + * `XiaomiMiMo/MiMo-7B-RL`, etc. + * + * ::: :::{note} @@ -600,7 +672,9 @@ Since some model architectures support both generative and pooling tasks, you should explicitly specify the task type to ensure that the model is used in pooling mode instead of generative mode. ::: -#### Text Embedding (`--task embed`) +#### Text Embedding + +Specified using `--task embed`. :::{list-table} :widths: 25 25 50 5 5 @@ -613,7 +687,7 @@ you should explicitly specify the task type to ensure that the model is used in * [PP](#distributed-serving) - * `BertModel` * BERT-based - * `BAAI/bge-base-en-v1.5`, etc. + * `BAAI/bge-base-en-v1.5`, `Snowflake/snowflake-arctic-embed-xs`, etc. * * - * `Gemma2Model` @@ -626,6 +700,26 @@ you should explicitly specify the task type to ensure that the model is used in * `parasail-ai/GritLM-7B-vllm`. * ✅︎ * ✅︎ +- * `GteModel` + * Arctic-Embed-2.0-M + * `Snowflake/snowflake-arctic-embed-m-v2.0`. + * + * ︎ +- * `GteNewModel` + * mGTE-TRM (see note) + * `Alibaba-NLP/gte-multilingual-base`, etc. + * ︎ + * ︎ +- * `ModernBertModel` + * ModernBERT-based + * `Alibaba-NLP/gte-modernbert-base`, etc. + * ︎ + * ︎ +- * `NomicBertModel` + * Nomic BERT + * `nomic-ai/nomic-embed-text-v1`, `nomic-ai/nomic-embed-text-v2-moe`, `Snowflake/snowflake-arctic-embed-m-long`, etc. + * ︎ + * ︎ - * `LlamaModel`, `LlamaForCausalLM`, `MistralModel`, etc. * Llama-based * `intfloat/e5-mistral-7b-instruct`, etc. @@ -638,12 +732,12 @@ you should explicitly specify the task type to ensure that the model is used in * ✅︎ - * `RobertaModel`, `RobertaForMaskedLM` * RoBERTa-based - * `sentence-transformers/all-roberta-large-v1`, `sentence-transformers/all-roberta-large-v1`, etc. + * `sentence-transformers/all-roberta-large-v1`, etc. * * - * `XLMRobertaModel` * XLM-RoBERTa-based - * `intfloat/multilingual-e5-large`, `jinaai/jina-reranker-v2-base-multilingual`, etc. + * `intfloat/multilingual-e5-large`, `jinaai/jina-reranker-v2-base-multilingual`, `Snowflake/snowflake-arctic-embed-l-v2.0`, `jinaai/jina-embeddings-v3`(see note), etc. * * ::: @@ -661,11 +755,21 @@ For both the 1.5B and 7B variants, you also need to enable `--trust-remote-code` See [relevant issue on HF Transformers](https://github.com/huggingface/transformers/issues/34882). ::: +:::{note} +`jinaai/jina-embeddings-v3` supports multiple tasks through lora, while vllm temporarily only supports text-matching tasks by merging lora weights. +::: + +:::{note} +The second-generation GTE model (mGTE-TRM) is named `NewModel`. The name `NewModel` is too generic, you should set `--hf-overrides '{"architectures": ["GteNewModel"]}'` to specify the use of the `GteNewModel` architecture. +::: + If your model is not in the above list, we will try to automatically convert the model using {func}`~vllm.model_executor.models.adapters.as_embedding_model`. By default, the embeddings of the whole prompt are extracted from the normalized hidden state corresponding to the last token. -#### Reward Modeling (`--task reward`) +#### Reward Modeling + +Specified using `--task reward`. :::{list-table} :widths: 25 25 50 5 5 @@ -706,7 +810,9 @@ For process-supervised reward models such as `peiyi9979/math-shepherd-mistral-7b e.g.: `--override-pooler-config '{"pooling_type": "STEP", "step_tag_id": 123, "returned_token_ids": [456, 789]}'`. ::: -#### Classification (`--task classify`) +#### Classification + +Specified using `--task classify`. :::{list-table} :widths: 25 25 50 5 5 @@ -732,7 +838,9 @@ e.g.: `--override-pooler-config '{"pooling_type": "STEP", "step_tag_id": 123, "r If your model is not in the above list, we will try to automatically convert the model using {func}`~vllm.model_executor.models.adapters.as_classification_model`. By default, the class probabilities are extracted from the softmaxed hidden state corresponding to the last token. -#### Sentence Pair Scoring (`--task score`) +#### Sentence Pair Scoring + +Specified using `--task score`. :::{list-table} :widths: 25 25 50 5 5 @@ -819,7 +927,9 @@ vLLM currently only supports adding LoRA to the language backbone of multimodal See [this page](#generative-models) for more information on how to use generative models. -#### Text Generation (`--task generate`) +#### Text Generation + +Specified using `--task generate`. :::{list-table} :widths: 25 25 15 20 5 5 5 @@ -986,11 +1096,18 @@ See [this page](#generative-models) for more information on how to use generativ * ✅︎ * ✅︎ * ✅︎ +- * `MiniMaxVL01ForConditionalGeneration` + * MiniMax-VL + * T + IE+ + * `MiniMaxAI/MiniMax-VL-01`, etc. + * + * ✅︎ + * ✅︎ - * `Mistral3ForConditionalGeneration` * Mistral3 * T + I+ * `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, etc. - * + * ✅︎ * ✅︎ * ✅︎ - * `MllamaForConditionalGeneration` @@ -1014,6 +1131,13 @@ See [this page](#generative-models) for more information on how to use generativ * * ✅︎ * ✅︎ +- * `Ovis` + * Ovis2, Ovis1.6 + * T + I+ + * `AIDC-AI/Ovis2-1B`, `AIDC-AI/Ovis1.6-Llama3.2-3B`, etc. + * + * + * ✅︎ - * `PaliGemmaForConditionalGeneration` * PaliGemma, PaliGemma 2 * T + IE @@ -1106,11 +1230,6 @@ See [this page](#generative-models) for more information on how to use generativ E Pre-computed embeddings can be inputted for this modality. + Multiple items can be inputted per text prompt for this modality. -:::{important} -Pan-and-scan image pre-processing is currently supported on V0 (but not V1). -You can enable it by passing `--mm-processor-kwargs '{"do_pan_and_scan": true}'`. -::: - :::{warning} Both V0 and V1 support `Gemma3ForConditionalGeneration` for text-only inputs. However, there are differences in how they handle text + image inputs: @@ -1130,7 +1249,7 @@ This limitation exists because the model's mixed attention pattern (bidirectiona ::: :::{note} -`h2oai/h2ovl-mississippi-2b` will be available in V1 once we support backends other than FlashAttention. +`h2oai/h2ovl-mississippi-2b` will be available in V1 once we support head size 80. ::: :::{note} @@ -1193,7 +1312,9 @@ Since some model architectures support both generative and pooling tasks, you should explicitly specify the task type to ensure that the model is used in pooling mode instead of generative mode. ::: -#### Text Embedding (`--task embed`) +#### Text Embedding + +Specified using `--task embed`. Any text generation model can be converted into an embedding model by passing `--task embed`. @@ -1233,7 +1354,9 @@ The following table lists those that are tested in vLLM. * ✅︎ ::: -#### Transcription (`--task transcription`) +#### Transcription + +Specified using `--task transcription`. Speech2Text models trained specifically for Automatic Speech Recognition. diff --git a/docs/source/performance/optimization.md b/docs/source/performance/optimization.md index ccbe8a367061fa7e821ce39c17d5427d4a639706..4160f078496268c9e900ecf7d0ff54d611ed9425 100644 --- a/docs/source/performance/optimization.md +++ b/docs/source/performance/optimization.md @@ -2,65 +2,188 @@ # Optimization and Tuning +This guide covers optimization strategies and performance tuning for vLLM V1. + ## Preemption Due to the auto-regressive nature of transformer architecture, there are times when KV cache space is insufficient to handle all batched requests. -The vLLM can preempt requests to free up KV cache space for other requests. Preempted requests are recomputed when sufficient KV cache space becomes -available again. When this occurs, the following warning is printed: +In such cases, vLLM can preempt requests to free up KV cache space for other requests. Preempted requests are recomputed when sufficient KV cache space becomes +available again. When this occurs, you may see the following warning: ```text -WARNING 05-09 00:49:33 scheduler.py:1057 Sequence group 0 is preempted by PreemptionMode.SWAP mode because there is not enough KV cache space. This can affect the end-to-end performance. Increase gpu_memory_utilization or tensor_parallel_size to provide more KV cache memory. total_cumulative_preemption_cnt=1 +WARNING 05-09 00:49:33 scheduler.py:1057 Sequence group 0 is preempted by PreemptionMode.RECOMPUTE mode because there is not enough KV cache space. This can affect the end-to-end performance. Increase gpu_memory_utilization or tensor_parallel_size to provide more KV cache memory. total_cumulative_preemption_cnt=1 ``` While this mechanism ensures system robustness, preemption and recomputation can adversely affect end-to-end latency. -If you frequently encounter preemptions from the vLLM engine, consider the following actions: +If you frequently encounter preemptions, consider the following actions: + +- Increase `gpu_memory_utilization`. vLLM pre-allocates GPU cache using this percentage of memory. By increasing utilization, you can provide more KV cache space. +- Decrease `max_num_seqs` or `max_num_batched_tokens`. This reduces the number of concurrent requests in a batch, thereby requiring less KV cache space. +- Increase `tensor_parallel_size`. This shards model weights across GPUs, allowing each GPU to have more memory available for KV cache. However, increasing this value may cause excessive synchronization overhead. +- Increase `pipeline_parallel_size`. This distributes model layers across GPUs, reducing the memory needed for model weights on each GPU, indirectly leaving more memory available for KV cache. However, increasing this value may cause latency penalties. -- Increase `gpu_memory_utilization`. The vLLM pre-allocates GPU cache by using gpu_memory_utilization% of memory. By increasing this utilization, you can provide more KV cache space. -- Decrease `max_num_seqs` or `max_num_batched_tokens`. This can reduce the number of concurrent requests in a batch, thereby requiring less KV cache space. -- Increase `tensor_parallel_size`. This approach shards model weights, so each GPU has more memory available for KV cache. -- Increase `pipeline_parallel_size`. This approach distributes model layers across GPUs, reducing the memory needed for model weights on each GPU, which indirectly leaves more memory available for KV cache. +You can monitor the number of preemption requests through Prometheus metrics exposed by vLLM. Additionally, you can log the cumulative number of preemption requests by setting `disable_log_stats=False`. -You can also monitor the number of preemption requests through Prometheus metrics exposed by the vLLM. Additionally, you can log the cumulative number of preemption requests by setting disable_log_stats=False. +In vLLM V1, the default preemption mode is `RECOMPUTE` rather than `SWAP`, as recomputation has lower overhead in the V1 architecture. (chunked-prefill)= ## Chunked Prefill -vLLM supports an experimental feature chunked prefill. Chunked prefill allows to chunk large prefills into smaller chunks and batch them together with decode requests. +Chunked prefill allows vLLM to process large prefills in smaller chunks and batch them together with decode requests. This feature helps improve both throughput and latency by better balancing compute-bound (prefill) and memory-bound (decode) operations. + +In vLLM V1, **chunked prefill is always enabled by default**. This is different from vLLM V0, where it was conditionally enabled based on model characteristics. + +With chunked prefill enabled, the scheduling policy prioritizes decode requests. It batches all pending decode requests before scheduling any prefill operations. When there are available tokens in the `max_num_batched_tokens` budget, it schedules pending prefills. If a pending prefill request cannot fit into `max_num_batched_tokens`, it automatically chunks it. + +This policy has two benefits: + +- It improves ITL and generation decode because decode requests are prioritized. +- It helps achieve better GPU utilization by locating compute-bound (prefill) and memory-bound (decode) requests to the same batch. -You can enable the feature by specifying `--enable-chunked-prefill` in the command line or setting `enable_chunked_prefill=True` in the LLM constructor. +### Performance Tuning with Chunked Prefill + +You can tune the performance by adjusting `max_num_batched_tokens`: + +- Smaller values (e.g., 2048) achieve better inter-token latency (ITL) because there are fewer prefills slowing down decodes. +- Higher values achieve better time to first token (TTFT) as you can process more prefill tokens in a batch. +- For optimal throughput, we recommend setting `max_num_batched_tokens > 8096` especially for smaller models on large GPUs. +- If `max_num_batched_tokens` is the same as `max_model_len`, that's almost the equivalent to the V0 default scheduling policy (except that it still prioritizes decodes). ```python from vllm import LLM -llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_chunked_prefill=True) -# Set max_num_batched_tokens to tune performance. -# NOTE: 2048 is the default max_num_batched_tokens for chunked prefill. -# llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_chunked_prefill=True, max_num_batched_tokens=2048) +# Set max_num_batched_tokens to tune performance +llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", max_num_batched_tokens=16384) ``` -By default, vLLM scheduler prioritizes prefills and doesn't batch prefill and decode to the same batch. -This policy optimizes the TTFT (time to the first token), but incurs slower ITL (inter token latency) and inefficient GPU utilization. +See related papers for more details ( or ). -Once chunked prefill is enabled, the policy is changed to prioritize decode requests. -It batches all pending decode requests to the batch before scheduling any prefill. -When there are available token_budget (`max_num_batched_tokens`), it schedules pending prefills. -If a last pending prefill request cannot fit into `max_num_batched_tokens`, it chunks it. +## Parallelism Strategies -This policy has two benefits: +vLLM supports multiple parallelism strategies that can be combined to optimize performance across different hardware configurations. -- It improves ITL and generation decode because decode requests are prioritized. -- It helps achieve better GPU utilization by locating compute-bound (prefill) and memory-bound (decode) requests to the same batch. +### Tensor Parallelism (TP) -You can tune the performance by changing `max_num_batched_tokens`. By default, it is set to 2048. -Smaller `max_num_batched_tokens` achieves better ITL because there are fewer prefills interrupting decodes. -Higher `max_num_batched_tokens` achieves better TTFT as you can put more prefill to the batch. +Tensor parallelism shards model parameters across multiple GPUs within each model layer. This is the most common strategy for large model inference within a single node. -- If `max_num_batched_tokens` is the same as `max_model_len`, that's almost the equivalent to the default scheduling policy (except that it still prioritizes decodes). -- Note that the default value (2048) of `max_num_batched_tokens` is optimized for ITL, and it may have lower throughput than the default scheduler. +**When to use:** -We recommend you set `max_num_batched_tokens > 2048` for throughput. +- When the model is too large to fit on a single GPU +- When you need to reduce memory pressure per GPU to allow more KV cache space for higher throughput -See related papers for more details ( or ). +```python +from vllm import LLM + +# Split model across 4 GPUs +llm = LLM(model="meta-llama/Llama-3.3-70B-Instruct", tensor_parallel_size=4) +``` + +For models that are too large to fit on a single GPU (like 70B parameter models), tensor parallelism is essential. + +### Pipeline Parallelism (PP) + +Pipeline parallelism distributes model layers across multiple GPUs. Each GPU processes different parts of the model in sequence. + +**When to use:** + +- When you've already maxed out efficient tensor parallelism but need to distribute the model further, or across nodes +- For very deep and narrow models where layer distribution is more efficient than tensor sharding + +Pipeline parallelism can be combined with tensor parallelism for very large models: + +```python +from vllm import LLM + +# Combine pipeline and tensor parallelism +llm = LLM( + model="meta-llama/Llama-3.3-70B-Instruct, + tensor_parallel_size=4, + pipeline_parallel_size=2 +) +``` + +### Expert Parallelism (EP) + +Expert parallelism is a specialized form of parallelism for Mixture of Experts (MoE) models, where different expert networks are distributed across GPUs. + +**When to use:** -Please try out this feature and let us know your feedback via GitHub issues! +- Specifically for MoE models (like DeepSeekV3, Qwen3MoE, Llama-4) +- When you want to balance the expert computation load across GPUs + +Expert parallelism is enabled by setting `enable_expert_parallel=True`, which will use expert parallelism instead of tensor parallelism for MoE layers. +It will use the same degree of parallelism as what you have set for tensor parallelism. + +### Data Parallelism (DP) + +Data parallelism replicates the entire model across multiple GPU sets and processes different batches of requests in parallel. + +**When to use:** + +- When you have enough GPUs to replicate the entire model +- When you need to scale throughput rather than model size +- In multi-user environments where isolation between request batches is beneficial + +Data parallelism can be combined with the other parallelism strategies and is set by `data_parallel_size=N`. +Note that MoE layers will be sharded according to the product of the tensor parallel size and data parallel size. + +## Reducing Memory Usage + +If you encounter out-of-memory issues, consider these strategies: + +### Context Length and Batch Size + +You can reduce memory usage by limiting the context length and batch size: + +```python +from vllm import LLM + +llm = LLM( + model="meta-llama/Llama-3.1-8B-Instruct", + max_model_len=2048, # Limit context window + max_num_seqs=4 # Limit batch size +) +``` + +### Adjust CUDA Graph Compilation + +CUDA graph compilation in V1 uses more memory than in V0. You can reduce memory usage by adjusting the compilation level: + +```python +from vllm import LLM +from vllm.config import CompilationConfig, CompilationLevel + +llm = LLM( + model="meta-llama/Llama-3.1-8B-Instruct", + compilation_config=CompilationConfig( + level=CompilationLevel.PIECEWISE, + cudagraph_capture_sizes=[1, 2, 4, 8] # Capture fewer batch sizes + ) +) +``` + +Or, if you are not concerned about latency or overall performance, disable CUDA graph compilation entirely with `enforce_eager=True`: + +```python +from vllm import LLM + +llm = LLM( + model="meta-llama/Llama-3.1-8B-Instruct", + enforce_eager=True # Disable CUDA graph compilation +) +``` + +### Multimodal Models + +For multi-modal models, you can reduce memory usage by limiting the number of images/videos per request: + +```python +from vllm import LLM + +# Accept up to 2 images per prompt +llm = LLM( + model="Qwen/Qwen2.5-VL-3B-Instruct", + limit_mm_per_prompt={"image": 2} +) +``` diff --git a/docs/source/serving/engine_args.md b/docs/source/serving/engine_args.md index 97ea01cd3b2e66c35cc675a50478740935021b37..9325a2406e8cae7a62aa0a92886f9794b939b87f 100644 --- a/docs/source/serving/engine_args.md +++ b/docs/source/serving/engine_args.md @@ -7,6 +7,8 @@ Engine arguments control the behavior of the vLLM engine. - For [offline inference](#offline-inference), they are part of the arguments to `LLM` class. - For [online serving](#openai-compatible-server), they are part of the arguments to `vllm serve`. +For references to all arguments available from `vllm serve` see the [serve args](#serve-args) documentation. + Below, you can find an explanation of every engine argument: diff --git a/docs/source/serving/offline_inference.md b/docs/source/serving/offline_inference.md index 894878ed14e764c814af965d6bdcacf06b1d8155..433d2e894dd8dd9746194de25e9e9a0088652f0b 100644 --- a/docs/source/serving/offline_inference.md +++ b/docs/source/serving/offline_inference.md @@ -25,7 +25,7 @@ The available APIs depend on the type of model that is being run: Please refer to the above pages for more details about each API. :::{seealso} -[API Reference](/api/offline_inference/index) +[API Reference](#offline-inference-api) ::: (configuration-options)= @@ -33,7 +33,7 @@ Please refer to the above pages for more details about each API. ## Configuration Options This section lists the most common options for running the vLLM engine. -For a full list, refer to the [Engine Arguments](#engine-args) page. +For a full list, refer to the page. (model-resolution)= @@ -74,6 +74,8 @@ Tensor parallelism (`tensor_parallel_size` option) can be used to split the mode The following code splits the model across 2 GPUs. ```python +from vllm import LLM + llm = LLM(model="ibm-granite/granite-3.1-8b-instruct", tensor_parallel_size=2) ``` @@ -95,7 +97,7 @@ You can convert the model checkpoint to a sharded checkpoint using config file values > defaults`. -e.g. `vllm serve SOME_MODEL --config config.yaml`, SOME_MODEL takes precedence over `model` in config file. -::: - ## API Reference (completions-api)= @@ -443,6 +397,130 @@ The input format is the same as [Embeddings API](#embeddings-api), but the outpu Code example: +(classification-api)= + +### Classification API + +Our Classification API directly supports Hugging Face sequence-classification models such as [ai21labs/Jamba-tiny-reward-dev](https://huggingface.co/ai21labs/Jamba-tiny-reward-dev) and [jason9693/Qwen2.5-1.5B-apeach](https://huggingface.co/jason9693/Qwen2.5-1.5B-apeach). + +We automatically wrap any other transformer via `as_classification_model()`, which pools on the last token, attaches a `RowParallelLinear` head, and applies a softmax to produce per-class probabilities. + +Code example: + +#### Example Requests + +You can classify multiple texts by passing an array of strings: + +Request: + +```bash +curl -v "http://127.0.0.1:8000/classify" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "jason9693/Qwen2.5-1.5B-apeach", + "input": [ + "Loved the new café—coffee was great.", + "This update broke everything. Frustrating." + ] + }' +``` + +Response: + +```bash +{ + "id": "classify-7c87cac407b749a6935d8c7ce2a8fba2", + "object": "list", + "created": 1745383065, + "model": "jason9693/Qwen2.5-1.5B-apeach", + "data": [ + { + "index": 0, + "label": "Default", + "probs": [ + 0.565970778465271, + 0.4340292513370514 + ], + "num_classes": 2 + }, + { + "index": 1, + "label": "Spoiled", + "probs": [ + 0.26448777318000793, + 0.7355121970176697 + ], + "num_classes": 2 + } + ], + "usage": { + "prompt_tokens": 20, + "total_tokens": 20, + "completion_tokens": 0, + "prompt_tokens_details": null + } +} +``` + +You can also pass a string directly to the `input` field: + +Request: + +```bash +curl -v "http://127.0.0.1:8000/classify" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "jason9693/Qwen2.5-1.5B-apeach", + "input": "Loved the new café—coffee was great." + }' +``` + +Response: + +```bash +{ + "id": "classify-9bf17f2847b046c7b2d5495f4b4f9682", + "object": "list", + "created": 1745383213, + "model": "jason9693/Qwen2.5-1.5B-apeach", + "data": [ + { + "index": 0, + "label": "Default", + "probs": [ + 0.565970778465271, + 0.4340292513370514 + ], + "num_classes": 2 + } + ], + "usage": { + "prompt_tokens": 10, + "total_tokens": 10, + "completion_tokens": 0, + "prompt_tokens_details": null + } +} +``` + +#### Extra parameters + +The following [pooling parameters](#pooling-params) are supported. + +:::{literalinclude} ../../../vllm/entrypoints/openai/protocol.py +:language: python +:start-after: begin-classification-pooling-params +:end-before: end-classification-pooling-params +::: + +The following extra parameters are supported: + +:::{literalinclude} ../../../vllm/entrypoints/openai/protocol.py +:language: python +:start-after: begin-classification-extra-params +:end-before: end-classification-extra-params +::: + (score-api)= ### Score API diff --git a/docs/source/serving/serve_args.md b/docs/source/serving/serve_args.md new file mode 100644 index 0000000000000000000000000000000000000000..edb49f4ba6de4c1e84e36bad091072b78137dec8 --- /dev/null +++ b/docs/source/serving/serve_args.md @@ -0,0 +1,47 @@ +(serve-args)= + +# Server Arguments + +The `vllm serve` command is used to launch the OpenAI-compatible server. + +## CLI Arguments + +The following are all arguments available from the `vllm serve` command: + + +```{eval-rst} +.. argparse:: + :module: vllm.entrypoints.openai.cli_args + :func: create_parser_for_docs + :prog: vllm serve + :nodefaultconst: + :markdownhelp: +``` + +## Configuration file + +You can load CLI arguments via a [YAML](https://yaml.org/) config file. +The argument names must be the long form of those outlined [above](#serve-args). + +For example: + +```yaml +# config.yaml + +model: meta-llama/Llama-3.1-8B-Instruct +host: "127.0.0.1" +port: 6379 +uvicorn-log-level: "info" +``` + +To use the above config file: + +```bash +vllm serve --config config.yaml +``` + +:::{note} +In case an argument is supplied simultaneously using command line and the config file, the value from the command line will take precedence. +The order of priorities is `command line > config file values > defaults`. +e.g. `vllm serve SOME_MODEL --config config.yaml`, SOME_MODEL takes precedence over `model` in config file. +::: diff --git a/examples/lmcache/README.md b/examples/lmcache/README.md index 7d0c23f529bb2173ba926706a68874499419b9b7..95a6bf995b2fd586574c8ff7e810202fd204e2ad 100644 --- a/examples/lmcache/README.md +++ b/examples/lmcache/README.md @@ -44,8 +44,8 @@ The main script generates several log files: ## 2. CPU Offload Examples -- `cpu_offload_lmcache_v0.py` - CPU offloading implementation for vLLM v0 -- `cpu_offload_lmcache_v1.py` - CPU offloading implementation for vLLM v1 +- `python cpu_offload_lmcache.py -v v0` - CPU offloading implementation for vLLM v0 +- `python cpu_offload_lmcache.py -v v1` - CPU offloading implementation for vLLM v1 ## 3. KV Cache Sharing diff --git a/examples/lmcache/cpu_offload_lmcache_v0.py b/examples/lmcache/cpu_offload_lmcache.py similarity index 54% rename from examples/lmcache/cpu_offload_lmcache_v0.py rename to examples/lmcache/cpu_offload_lmcache.py index 37aea281032fd8b826c79182f3d28d72f49f87e1..eedb47dfc12e5112b2029ccec44866ec975d4642 100644 --- a/examples/lmcache/cpu_offload_lmcache_v0.py +++ b/examples/lmcache/cpu_offload_lmcache.py @@ -1,25 +1,40 @@ # SPDX-License-Identifier: Apache-2.0 """ This file demonstrates the example usage of cpu offloading -with LMCache. +with LMCache in vLLM v1 or v0. + +Usage: + + Specify vLLM version + + -v v0 : Use LMCacheConnector + model = mistralai/Mistral-7B-Instruct-v0.2 + (Includes enable_chunked_prefill = True) + + -v v1 : Use LMCacheConnectorV1 (default) + model = meta-llama/Meta-Llama-3.1-8B-Instruct + (Without enable_chunked_prefill) Note that `lmcache` is needed to run this example. Requirements: Linux, Python: 3.10 or higher, CUDA: 12.1 Learn more about LMCache environment setup, please refer to: https://docs.lmcache.ai/getting_started/installation.html """ +import argparse import contextlib import os import time +from dataclasses import asdict from lmcache.experimental.cache_engine import LMCacheEngineBuilder from lmcache.integration.vllm.utils import ENGINE_NAME from vllm import LLM, SamplingParams from vllm.config import KVTransferConfig +from vllm.engine.arg_utils import EngineArgs -def setup_environment_variables(): +def setup_environment_variables(vllm_version: str): # LMCache-related environment variables # Use experimental features in LMCache os.environ["LMCACHE_USE_EXPERIMENTAL"] = "True" @@ -29,21 +44,37 @@ def setup_environment_variables(): os.environ["LMCACHE_LOCAL_CPU"] = "True" # Set local CPU memory limit to 5.0 GB os.environ["LMCACHE_MAX_LOCAL_CPU_SIZE"] = "5.0" + if vllm_version == "v0": + os.environ["VLLM_USE_V1"] = "0" @contextlib.contextmanager -def build_llm_with_lmcache(): - ktc = KVTransferConfig.from_cli( - '{"kv_connector":"LMCacheConnector", "kv_role":"kv_both"}') +def build_llm_with_lmcache(lmcache_connector: str, model: str, + vllm_version: str): + ktc = KVTransferConfig( + kv_connector=lmcache_connector, + kv_role="kv_both", + ) # Set GPU memory utilization to 0.8 for an A40 GPU with 40GB # memory. Reduce the value if your GPU has less memory. # Note: LMCache supports chunked prefill (see vLLM#14505, LMCache#392). - llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2", - kv_transfer_config=ktc, - max_model_len=8000, - enable_chunked_prefill=True, - gpu_memory_utilization=0.8) - + if vllm_version == "v0": + llm_args = EngineArgs( + model=model, + kv_transfer_config=ktc, + max_model_len=8000, + gpu_memory_utilization=0.8, + enable_chunked_prefill=True, # Only in v0 + ) + else: + llm_args = EngineArgs( + model=model, + kv_transfer_config=ktc, + max_model_len=8000, + gpu_memory_utilization=0.8, + ) + + llm = LLM(**asdict(llm_args)) try: yield llm finally: @@ -57,6 +88,9 @@ def print_output( sampling_params: SamplingParams, req_str: str, ): + # Should be able to see logs like the following: + # `LMCache INFO: Storing KV cache for 6006 out of 6006 tokens for request 0` + # This indicates that the KV cache has been stored in LMCache. start = time.time() outputs = llm.generate(prompt, sampling_params) print("-" * 50) @@ -68,10 +102,29 @@ def print_output( print("-" * 50) +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("-v", + "--version", + choices=["v0", "v1"], + default="v1", + help="Specify vLLM version (default: v1)") + return parser.parse_args() + + def main(): - setup_environment_variables() + args = parse_args() + + if args.version == "v0": + lmcache_connector = "LMCacheConnector" + model = "mistralai/Mistral-7B-Instruct-v0.2" + else: + lmcache_connector = "LMCacheConnectorV1" + model = "meta-llama/Meta-Llama-3.1-8B-Instruct" + + setup_environment_variables(args.version) - with build_llm_with_lmcache() as llm: + with build_llm_with_lmcache(lmcache_connector, model, args.version) as llm: # This example script runs two requests with a shared prefix. # Define the shared prompt and specific prompts diff --git a/examples/lmcache/cpu_offload_lmcache_v1.py b/examples/lmcache/cpu_offload_lmcache_v1.py deleted file mode 100644 index f44075a36965fc12e677bad2d69a4d0797378b05..0000000000000000000000000000000000000000 --- a/examples/lmcache/cpu_offload_lmcache_v1.py +++ /dev/null @@ -1,57 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -""" -This file demonstrates the example usage of cpu offloading -with LMCache in vLLM v1. - -Note that lmcache needs to be installed to run this example. -Learn more about LMCache in https://github.com/LMCache/LMCache. -""" -import os - -from lmcache.experimental.cache_engine import LMCacheEngineBuilder -from lmcache.integration.vllm.utils import ENGINE_NAME - -from vllm import LLM, SamplingParams -from vllm.config import KVTransferConfig - -# LMCache-related environment variables -# Use experimental features in LMCache -os.environ["LMCACHE_USE_EXPERIMENTAL"] = "True" -# LMCache is set to use 256 tokens per chunk -os.environ["LMCACHE_CHUNK_SIZE"] = "256" -# Enable local CPU backend in LMCache -os.environ["LMCACHE_LOCAL_CPU"] = "True" -# Set local CPU memory limit to 5.0 GB -os.environ["LMCACHE_MAX_LOCAL_CPU_SIZE"] = "5.0" - -# This example script runs two requests with a shared prefix. -shared_prompt = "Hello, how are you?" * 1000 -first_prompt = [ - shared_prompt + "Hello, my name is", -] -second_prompt = [ - shared_prompt + "Tell me a very long story", -] - -sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10) - -ktc = KVTransferConfig.from_cli( - '{"kv_connector":"LMCacheConnectorV1", "kv_role":"kv_both"}') -# Set GPU memory utilization to 0.8 for an A40 GPU with 40GB -# memory. Reduce the value if your GPU has less memory. -# Note that LMCache is not compatible with chunked prefill for now. -llm = LLM(model="meta-llama/Meta-Llama-3.1-8B-Instruct", - kv_transfer_config=ktc, - max_model_len=8000, - gpu_memory_utilization=0.8) - -# Should be able to see logs like the following: -# `LMCache INFO: Storing KV cache for 6006 out of 6006 tokens for request 0` -# This indicates that the KV cache has been stored in LMCache. -outputs = llm.generate(first_prompt, sampling_params) -for output in outputs: - generated_text = output.outputs[0].text - print(f"Generated text: {generated_text!r}") - -# Clean up lmcache backend -LMCacheEngineBuilder.destroy(ENGINE_NAME) diff --git a/examples/lmcache/disagg_prefill_lmcache_v0.py b/examples/lmcache/disagg_prefill_lmcache_v0.py index 7da6fb7aaa230fbcc56983a9da4dfad818737f92..66cc941852307f3d679f8c2d7241e4c1390b1df6 100644 --- a/examples/lmcache/disagg_prefill_lmcache_v0.py +++ b/examples/lmcache/disagg_prefill_lmcache_v0.py @@ -49,9 +49,10 @@ def run_prefill(prefill_done, prompts): sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1) - ktc = KVTransferConfig.from_cli( - '{"kv_connector":"LMCacheConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' - ) + ktc = KVTransferConfig(kv_connector="LMCacheConnector", + kv_role="kv_producer", + kv_rank=0, + kv_parallel_size=2) # Set GPU memory utilization to 0.8 for an A40 GPU with 40GB # memory. Reduce the value if your GPU has less memory. llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2", @@ -78,9 +79,10 @@ def run_decode(prefill_done, prompts, timeout=1): sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10) - ktc = KVTransferConfig.from_cli( - '{"kv_connector":"LMCacheConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' - ) + ktc = KVTransferConfig(kv_connector="LMCacheConnector", + kv_role="kv_consumer", + kv_rank=1, + kv_parallel_size=2) # Set GPU memory utilization to 0.8 for an A40 GPU with 40GB # of memory. Reduce the value if your GPU has less memory. llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2", diff --git a/examples/lmcache/disagg_prefill_lmcache_v1/disagg_vllm_launcher.sh b/examples/lmcache/disagg_prefill_lmcache_v1/disagg_vllm_launcher.sh index 831ef0bb574bf1ffdce6803db2336ac27ccbd051..5719fa821292389fb10dde3bf558442afc7a64e5 100644 --- a/examples/lmcache/disagg_prefill_lmcache_v1/disagg_vllm_launcher.sh +++ b/examples/lmcache/disagg_prefill_lmcache_v1/disagg_vllm_launcher.sh @@ -54,6 +54,6 @@ elif [[ $1 == "decoder" ]]; then else echo "Invalid role: $1" - echo "Should be either prefill, decode" + echo "Should be either prefiller, decoder" exit 1 fi diff --git a/examples/lmcache/kv_cache_sharing_lmcache_v1.py b/examples/lmcache/kv_cache_sharing_lmcache_v1.py index af1b4351dd54c8cd91f2cda5d0724d199ae63b2e..7748f8ca6133abe2ebdef95b5dbb3663fa8a8687 100644 --- a/examples/lmcache/kv_cache_sharing_lmcache_v1.py +++ b/examples/lmcache/kv_cache_sharing_lmcache_v1.py @@ -49,8 +49,8 @@ def run_store(store_done, prompts): sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10) - ktc = KVTransferConfig.from_cli( - '{"kv_connector":"LMCacheConnectorV1", "kv_role":"kv_both"}') + ktc = KVTransferConfig(kv_connector="LMCacheConnectorV1", + kv_role="kv_both") # Set GPU memory utilization to 0.8 for an A40 GPU with 40GB # memory. Reduce the value if your GPU has less memory. llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2", @@ -76,8 +76,8 @@ def run_retrieve(store_done, prompts, timeout=1): sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10) - ktc = KVTransferConfig.from_cli( - '{"kv_connector":"LMCacheConnectorV1", "kv_role":"kv_both"}') + ktc = KVTransferConfig(kv_connector="LMCacheConnectorV1", + kv_role="kv_both") # Set GPU memory utilization to 0.8 for an A40 GPU with 40GB # of memory. Reduce the value if your GPU has less memory. llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.2", diff --git a/examples/offline_inference/basic/chat.py b/examples/offline_inference/basic/chat.py index 6857c6e9e31dfaac9f0fe472fc78d34f689d7be1..8e6f78ed7de21fa078981770dde2d14a2197355a 100644 --- a/examples/offline_inference/basic/chat.py +++ b/examples/offline_inference/basic/chat.py @@ -7,9 +7,8 @@ from vllm.utils import FlexibleArgumentParser def create_parser(): parser = FlexibleArgumentParser() # Add engine args - engine_group = parser.add_argument_group("Engine arguments") - EngineArgs.add_cli_args(engine_group) - engine_group.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct") + EngineArgs.add_cli_args(parser) + parser.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct") # Add sampling params sampling_group = parser.add_argument_group("Sampling parameters") sampling_group.add_argument("--max-tokens", type=int) diff --git a/examples/offline_inference/basic/generate.py b/examples/offline_inference/basic/generate.py index 54b52b22a45a977cccfe19fc8435c963977e481b..72f4a8208386d5a9d842fd993adeabb9236a66c5 100644 --- a/examples/offline_inference/basic/generate.py +++ b/examples/offline_inference/basic/generate.py @@ -7,9 +7,8 @@ from vllm.utils import FlexibleArgumentParser def create_parser(): parser = FlexibleArgumentParser() # Add engine args - engine_group = parser.add_argument_group("Engine arguments") - EngineArgs.add_cli_args(engine_group) - engine_group.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct") + EngineArgs.add_cli_args(parser) + parser.set_defaults(model="meta-llama/Llama-3.2-1B-Instruct") # Add sampling params sampling_group = parser.add_argument_group("Sampling parameters") sampling_group.add_argument("--max-tokens", type=int) diff --git a/examples/offline_inference/chat_with_tools.py b/examples/offline_inference/chat_with_tools.py index 15519bfed9cb49bbe893557ee7e3f2604a04e3d9..b532bf42adfbaf637ffc4710af821920525dd2ff 100644 --- a/examples/offline_inference/chat_with_tools.py +++ b/examples/offline_inference/chat_with_tools.py @@ -68,7 +68,7 @@ def get_current_weather(city: str, state: str, unit: 'str'): "partly cloudly, with highs in the 90's.") -tool_funtions = {"get_current_weather": get_current_weather} +tool_functions = {"get_current_weather": get_current_weather} tools = [{ "type": "function", @@ -122,7 +122,7 @@ messages.append({ # above defined function tool_calls = json.loads(output) tool_answers = [ - tool_funtions[call['name']](**call['arguments']) for call in tool_calls + tool_functions[call['name']](**call['arguments']) for call in tool_calls ] # append the answer as a tool message and let the LLM give you an answer diff --git a/examples/offline_inference/data_parallel.py b/examples/offline_inference/data_parallel.py index 965915beaf58f073698eab081d4734a1501cf4cd..f636a08c0b097e2ba29f01ca65b1c50f9f325460 100644 --- a/examples/offline_inference/data_parallel.py +++ b/examples/offline_inference/data_parallel.py @@ -65,11 +65,17 @@ def parse_args(): type=int, default=0, help="Master node port") + parser.add_argument("--enforce-eager", + action='store_true', + help="Enforce eager mode execution.") + parser.add_argument("--trust-remote-code", + action='store_true', + help="Trust remote code.") return parser.parse_args() def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip, - dp_master_port, GPUs_per_dp_rank): + dp_master_port, GPUs_per_dp_rank, enforce_eager, trust_remote_code): os.environ["VLLM_DP_RANK"] = str(global_dp_rank) os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank) os.environ["VLLM_DP_SIZE"] = str(dp_size) @@ -109,10 +115,13 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip, max_tokens=[16, 20][global_dp_rank % 2]) # Create an LLM. - llm = LLM(model=model, - tensor_parallel_size=GPUs_per_dp_rank, - enforce_eager=True, - enable_expert_parallel=True) + llm = LLM( + model=model, + tensor_parallel_size=GPUs_per_dp_rank, + enforce_eager=enforce_eager, + enable_expert_parallel=True, + trust_remote_code=trust_remote_code, + ) outputs = llm.generate(prompts, sampling_params) # Print the outputs. for i, output in enumerate(outputs): @@ -155,7 +164,8 @@ if __name__ == "__main__": proc = Process(target=main, args=(args.model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip, dp_master_port, - tp_size)) + tp_size, args.enforce_eager, + args.trust_remote_code)) proc.start() procs.append(proc) exit_code = 0 diff --git a/examples/offline_inference/disaggregated-prefill-v1/README.md b/examples/offline_inference/disaggregated-prefill-v1/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f708eb25383801cd5af701744d80d8b9b1488aa8 --- /dev/null +++ b/examples/offline_inference/disaggregated-prefill-v1/README.md @@ -0,0 +1,9 @@ +# Disaggregated Prefill V1 + +This example contains scripts that demonstrate disaggregated prefill in the offline setting of vLLM. + +## Files + +- `run.sh` - A helper script that will run `prefill_example.py` and `decode_example.py` sequentially. +- `prefill_example.py` - A script which performs prefill only, saving the KV state to the `local_storage` directory and the prompts to `output.txt`. +- `decode_example.py` - A script which performs decode only, loading the KV state from the `local_storage` directory and the prompts from `output.txt`. diff --git a/examples/offline_inference/disaggregated-prefill-v1/decode_example.py b/examples/offline_inference/disaggregated-prefill-v1/decode_example.py index 66efbc0c9deecf083047c27aa0fd6db16f95ad7f..11918f72feec8398bc1cdddcbfdedad0589f38ff 100644 --- a/examples/offline_inference/disaggregated-prefill-v1/decode_example.py +++ b/examples/offline_inference/disaggregated-prefill-v1/decode_example.py @@ -16,16 +16,17 @@ except FileNotFoundError: sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=10) -llm = LLM( - model="meta-llama/Llama-3.2-1B-Instruct", - enforce_eager=True, - gpu_memory_utilization=0.8, - max_num_batched_tokens=64, - max_num_seqs=16, - kv_transfer_config=KVTransferConfig.from_cli( - '{"kv_connector":"SharedStorageConnector","kv_role":"kv_both",' - '"kv_connector_extra_config": {"shared_storage_path": "local_storage"}}' - )) #, max_model_len=2048, max_num_batched_tokens=2048) +llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", + enforce_eager=True, + gpu_memory_utilization=0.8, + max_num_batched_tokens=64, + max_num_seqs=16, + kv_transfer_config=KVTransferConfig( + kv_connector="SharedStorageConnector", + kv_role="kv_both", + kv_connector_extra_config={ + "shared_storage_path": "local_storage" + })) #, max_model_len=2048, max_num_batched_tokens=2048) # 1ST generation (prefill instance) outputs = llm.generate(prompts, sampling_params) diff --git a/examples/offline_inference/disaggregated-prefill-v1/prefill_example.py b/examples/offline_inference/disaggregated-prefill-v1/prefill_example.py index f7cbf6557d54f8ea0c25394fc86bb649a0ff4b33..798128301e0f0e4be7428e7023e1635987c375fb 100644 --- a/examples/offline_inference/disaggregated-prefill-v1/prefill_example.py +++ b/examples/offline_inference/disaggregated-prefill-v1/prefill_example.py @@ -17,11 +17,12 @@ sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1) llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", enforce_eager=True, gpu_memory_utilization=0.8, - kv_transfer_config=KVTransferConfig.from_cli( - '{"kv_connector":"SharedStorageConnector","kv_role":"kv_both", ' - '"kv_connector_extra_config": ' - '{"shared_storage_path": "local_storage"}}') - ) #, max_model_len=2048, max_num_batched_tokens=2048) + kv_transfer_config=KVTransferConfig( + kv_connector="SharedStorageConnector", + kv_role="kv_both", + kv_connector_extra_config={ + "shared_storage_path": "local_storage" + })) #, max_model_len=2048, max_num_batched_tokens=2048) # 1ST generation (prefill instance) outputs = llm.generate( diff --git a/examples/offline_inference/disaggregated_prefill.py b/examples/offline_inference/disaggregated_prefill.py index d60985146c5c9172483b422597dc21183c5a6488..bb6fdd48f79e1f728cada9f22a37c88d7afe5a5a 100644 --- a/examples/offline_inference/disaggregated_prefill.py +++ b/examples/offline_inference/disaggregated_prefill.py @@ -32,9 +32,10 @@ def run_prefill(prefill_done): # This instance is the prefill node (kv_producer, rank 0). # The number of parallel instances for KV cache transfer is set to 2, # as required for PyNcclConnector. - ktc = KVTransferConfig.from_cli( - '{"kv_connector":"PyNcclConnector","kv_role":"kv_producer","kv_rank":0,"kv_parallel_size":2}' - ) + ktc = KVTransferConfig(kv_connector="PyNcclConnector", + kv_role="kv_producer", + kv_rank=0, + kv_parallel_size=2) # Set GPU memory utilization to 0.8 for an A6000 GPU with 40GB # memory. You may need to adjust the value to fit your GPU. @@ -71,9 +72,10 @@ def run_decode(prefill_done): # This instance is the decode node (kv_consumer, rank 1). # The number of parallel instances for KV cache transfer is set to 2, # as required for PyNcclConnector. - ktc = KVTransferConfig.from_cli( - '{"kv_connector":"PyNcclConnector","kv_role":"kv_consumer","kv_rank":1,"kv_parallel_size":2}' - ) + ktc = KVTransferConfig(kv_connector="PyNcclConnector", + kv_role="kv_consumer", + kv_rank=1, + kv_parallel_size=2) # Set GPU memory utilization to 0.8 for an A6000 GPU with 40GB # memory. You may need to adjust the value to fit your GPU. diff --git a/examples/offline_inference/eagle.py b/examples/offline_inference/eagle.py index 474b745a610607da3c0626c3c931ffe1a666babe..615f67e9f8d818c38ba42f8d15d8fddd0a575978 100644 --- a/examples/offline_inference/eagle.py +++ b/examples/offline_inference/eagle.py @@ -36,6 +36,10 @@ def parse_args(): help="downloaded from the eagle repo " \ "https://github.com/SafeAILab/EAGLE/blob/main/eagle/data/" ) + parser.add_argument("--method", + type=str, + default='eagle', + choices=['eagle', 'eagle3']) parser.add_argument("--max_num_seqs", type=int, default=8) parser.add_argument("--num_prompts", type=int, default=80) parser.add_argument("--num_spec_tokens", type=int, default=2) @@ -53,7 +57,13 @@ def main(): args = parse_args() model_dir = "meta-llama/Llama-3.1-8B-Instruct" - eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B" + + if args.method == 'eagle': + eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B" + elif args.method == 'eagle3': + eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B" + else: + raise ValueError(f"unknown method: {args.method}") max_model_len = 2048 @@ -81,7 +91,7 @@ def main(): max_num_seqs=args.max_num_seqs, gpu_memory_utilization=0.8, speculative_config={ - "method": "eagle3" if "eagle3" in eagle_dir.lower() else "eagle", + "method": args.method, "model": eagle_dir, "num_speculative_tokens": args.num_spec_tokens, "draft_tensor_parallel_size": args.draft_tp, @@ -95,6 +105,13 @@ def main(): outputs = llm.generate(prompt_token_ids=prompt_ids, sampling_params=sampling_params) + # print the generated text + for output in outputs: + print("-" * 50) + print(f"prompt: {output.prompt}") + print(f"generated text: {output.outputs[0].text}") + print("-" * 50) + if not hasattr(outputs, "metrics") or outputs.metrics is None: return @@ -108,8 +125,8 @@ def main(): acceptance_counts[step] += count print("-" * 50) - print(f"mean acceptance length: \ - {sum(acceptance_counts) / acceptance_counts[0]:.2f}") + print(f"mean acceptance length (including bonus tokens): \ + {1 + (sum(acceptance_counts) / acceptance_counts[0]):.2f}") print("-" * 50) # print acceptance at each token position diff --git a/examples/offline_inference/lora_with_quantization_inference.py b/examples/offline_inference/lora_with_quantization_inference.py index ab235ddd75455fa434a67fee4069b01683bab573..b6608ec6e958002c23ad316f462058deb31ee564 100644 --- a/examples/offline_inference/lora_with_quantization_inference.py +++ b/examples/offline_inference/lora_with_quantization_inference.py @@ -75,43 +75,38 @@ def initialize_engine(model: str, quantization: str, lora_repo: Optional[str]) -> LLMEngine: """Initialize the LLMEngine.""" - if quantization == "bitsandbytes": - # QLoRA (https://arxiv.org/abs/2305.14314) is a quantization technique. - # It quantizes the model when loading, with some config info from the - # LoRA adapter repo. So need to set the parameter of load_format and - # qlora_adapter_name_or_path as below. - engine_args = EngineArgs(model=model, - quantization=quantization, - qlora_adapter_name_or_path=lora_repo, - enable_lora=True, - max_lora_rank=64) - else: - engine_args = EngineArgs(model=model, - quantization=quantization, - enable_lora=True, - max_loras=4) + engine_args = EngineArgs(model=model, + quantization=quantization, + enable_lora=True, + max_lora_rank=64, + max_loras=4) return LLMEngine.from_engine_args(engine_args) def main(): """Main function that sets up and runs the prompt processing.""" - test_configs = [{ - "name": "qlora_inference_example", - 'model': "huggyllama/llama-7b", - 'quantization': "bitsandbytes", - 'lora_repo': 'timdettmers/qlora-flan-7b' - }, { - "name": "AWQ_inference_with_lora_example", - 'model': 'TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ', - 'quantization': "awq", - 'lora_repo': 'jashing/tinyllama-colorist-lora' - }, { - "name": "GPTQ_inference_with_lora_example", - 'model': 'TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ', - 'quantization': "gptq", - 'lora_repo': 'jashing/tinyllama-colorist-lora' - }] + test_configs = [ + # QLoRA (https://arxiv.org/abs/2305.14314) + { + "name": "qlora_inference_example", + 'model': "huggyllama/llama-7b", + 'quantization': "bitsandbytes", + 'lora_repo': 'timdettmers/qlora-flan-7b' + }, + { + "name": "AWQ_inference_with_lora_example", + 'model': 'TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ', + 'quantization': "awq", + 'lora_repo': 'jashing/tinyllama-colorist-lora' + }, + { + "name": "GPTQ_inference_with_lora_example", + 'model': 'TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ', + 'quantization': "gptq", + 'lora_repo': 'jashing/tinyllama-colorist-lora' + } + ] for test_config in test_configs: print( diff --git a/examples/offline_inference/neuron_eagle.py b/examples/offline_inference/neuron_eagle.py new file mode 100644 index 0000000000000000000000000000000000000000..4f63f1a2fb3c890ae2f6ffc19383752606a49768 --- /dev/null +++ b/examples/offline_inference/neuron_eagle.py @@ -0,0 +1,54 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +This example shows how to run offline inference with an EAGLE speculative +decoding model on neuron. To use EAGLE speculative decoding, you must use +a draft model that is specifically fine-tuned for EAGLE speculation. +Additionally, to use EAGLE with NxD Inference, the draft model must include +the LM head weights from the target model. These weights are shared between +the draft and target model. +""" + +from vllm import LLM, SamplingParams + +# Sample prompts. +prompts = [ + "What is annapurna labs?", +] + +# Create a sampling params object. +sampling_params = SamplingParams(top_k=1, max_tokens=500, ignore_eos=True) + +# Create an LLM. +llm = LLM( + model="/home/ubuntu/model_hf/Meta-Llama-3.1-70B-Instruct", + speculative_config={ + "model": "/home/ubuntu/model_hf/Llama-3.1-70B-Instruct-EAGLE-Draft", + "num_speculative_tokens": 5, + "max_model_len": 2048 + }, + max_num_seqs=4, + # The max_model_len and block_size arguments are required to be same as + # max sequence length when targeting neuron device. + # Currently, this is a known limitation in continuous batching support + # in neuronx-distributed-inference. + max_model_len=2048, + block_size=2048, + # The device can be automatically detected when AWS Neuron SDK is installed. + # The device argument can be either unspecified for automated detection, + # or explicitly assigned. + device="neuron", + tensor_parallel_size=32, + override_neuron_config={ + "enable_eagle_speculation": True, + "enable_fused_speculation": True + }, +) + +# Generate texts from the prompts. The output is a list of RequestOutput objects +# that contain the prompt, generated text, and other information. +outputs = llm.generate(prompts, sampling_params) +# Print the outputs. +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, \n\n\n\ Generated text: {generated_text!r}") diff --git a/examples/offline_inference/neuron_speculation.py b/examples/offline_inference/neuron_speculation.py new file mode 100644 index 0000000000000000000000000000000000000000..bef434bae5bacfb591cdbd25796ef8008bdd0fe4 --- /dev/null +++ b/examples/offline_inference/neuron_speculation.py @@ -0,0 +1,64 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +This example shows how to run offline inference with a speculative +decoding model on neuron. +""" + +import os + +from vllm import LLM, SamplingParams + +# Sample prompts. +prompts = [ + "Hello, I am a language model and I can help", + "The president of the United States is", + "The capital of France is", +] + + +def config_buckets(): + """Configure context length and token gen buckets.""" + # creates XLA hlo graphs for all the context length buckets. + os.environ['NEURON_CONTEXT_LENGTH_BUCKETS'] = "128,512,1024,2048" + # creates XLA hlo graphs for all the token gen buckets. + os.environ['NEURON_TOKEN_GEN_BUCKETS'] = "128,512,1024,2048" + + +def initialize_model(): + """Create an LLM with speculative decoding.""" + return LLM( + model="openlm-research/open_llama_7b", + speculative_config={ + "model": "openlm-research/open_llama_3b", + "num_speculative_tokens": 4, + "max_model_len": 2048 + }, + max_num_seqs=4, + max_model_len=2048, + block_size=2048, + use_v2_block_manager=True, + device="neuron", + tensor_parallel_size=32, + ) + + +def process_requests(model: LLM, sampling_params: SamplingParams): + """Generate texts from prompts and print them.""" + outputs = model.generate(prompts, sampling_params) + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + +def main(): + """Main function that sets up the model and processes prompts.""" + config_buckets() + model = initialize_model() + # Create a sampling params object. + sampling_params = SamplingParams(max_tokens=100, top_k=1) + process_requests(model, sampling_params) + + +if __name__ == '__main__': + main() diff --git a/examples/offline_inference/openai/openai_batch.md b/examples/offline_inference/openai_batch/README.md similarity index 94% rename from examples/offline_inference/openai/openai_batch.md rename to examples/offline_inference/openai_batch/README.md index d271573aa96fcb0f304b8c17ba006915c574a299..42a19f71e9de313c0f17679b554272b3e10b6926 100644 --- a/examples/offline_inference/openai/openai_batch.md +++ b/examples/offline_inference/openai_batch/README.md @@ -8,7 +8,7 @@ This is a guide to performing batch inference using the OpenAI batch file format The OpenAI batch file format consists of a series of json objects on new lines. -[See here for an example file.](https://github.com/vllm-project/vllm/blob/main/examples/offline_inference/openai/openai_example_batch.jsonl) +[See here for an example file.](https://github.com/vllm-project/vllm/blob/main/examples/offline_inference/openai_batch/openai_example_batch.jsonl) Each line represents a separate request. See the [OpenAI package reference](https://platform.openai.com/docs/api-reference/batch/requestInput) for more details. @@ -30,13 +30,13 @@ We currently support `/v1/chat/completions`, `/v1/embeddings`, and `/v1/score` e To follow along with this example, you can download the example batch, or create your own batch file in your working directory. ```console -wget https://raw.githubusercontent.com/vllm-project/vllm/main/examples/offline_inference/openai/openai_example_batch.jsonl +wget https://raw.githubusercontent.com/vllm-project/vllm/main/examples/offline_inference/openai_batch/openai_example_batch.jsonl ``` Once you've created your batch file it should look like this ```console -$ cat offline_inference/openai/openai_example_batch.jsonl +$ cat offline_inference/openai_batch/openai_example_batch.jsonl {"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_completion_tokens": 1000}} {"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_completion_tokens": 1000}} ``` @@ -48,7 +48,7 @@ The batch running tool is designed to be used from the command line. You can run the batch with the following command, which will write its results to a file called `results.jsonl` ```console -python -m vllm.entrypoints.openai.run_batch -i offline_inference/openai/openai_example_batch.jsonl -o results.jsonl --model meta-llama/Meta-Llama-3-8B-Instruct +python -m vllm.entrypoints.openai.run_batch -i offline_inference/openai_batch/openai_example_batch.jsonl -o results.jsonl --model meta-llama/Meta-Llama-3-8B-Instruct ``` ### Step 3: Check your results @@ -65,10 +65,10 @@ $ cat results.jsonl The batch runner supports remote input and output urls that are accessible via http/https. -For example, to run against our example input file located at `https://raw.githubusercontent.com/vllm-project/vllm/main/examples/offline_inference/openai/openai_example_batch.jsonl`, you can run +For example, to run against our example input file located at `https://raw.githubusercontent.com/vllm-project/vllm/main/examples/offline_inference/openai_batch/openai_example_batch.jsonl`, you can run ```console -python -m vllm.entrypoints.openai.run_batch -i https://raw.githubusercontent.com/vllm-project/vllm/main/examples/offline_inference/openai/openai_example_batch.jsonl -o results.jsonl --model meta-llama/Meta-Llama-3-8B-Instruct +python -m vllm.entrypoints.openai.run_batch -i https://raw.githubusercontent.com/vllm-project/vllm/main/examples/offline_inference/openai_batch/openai_example_batch.jsonl -o results.jsonl --model meta-llama/Meta-Llama-3-8B-Instruct ``` ## Example 3: Integrating with AWS S3 @@ -89,13 +89,13 @@ To integrate with cloud blob storage, we recommend using presigned urls. To follow along with this example, you can download the example batch, or create your own batch file in your working directory. ```console -wget https://raw.githubusercontent.com/vllm-project/vllm/main/examples/offline_inference/openai/openai_example_batch.jsonl +wget https://raw.githubusercontent.com/vllm-project/vllm/main/examples/offline_inference/openai_batch/openai_example_batch.jsonl ``` Once you've created your batch file it should look like this ```console -$ cat offline_inference/openai/openai_example_batch.jsonl +$ cat offline_inference/openai_batch/openai_example_batch.jsonl {"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_completion_tokens": 1000}} {"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_completion_tokens": 1000}} ``` @@ -103,7 +103,7 @@ $ cat offline_inference/openai/openai_example_batch.jsonl Now upload your batch file to your S3 bucket. ```console -aws s3 cp offline_inference/openai/openai_example_batch.jsonl s3://MY_BUCKET/MY_INPUT_FILE.jsonl +aws s3 cp offline_inference/openai_batch/openai_example_batch.jsonl s3://MY_BUCKET/MY_INPUT_FILE.jsonl ``` ### Step 2: Generate your presigned urls diff --git a/examples/offline_inference/openai/openai_example_batch.jsonl b/examples/offline_inference/openai_batch/openai_example_batch.jsonl similarity index 100% rename from examples/offline_inference/openai/openai_example_batch.jsonl rename to examples/offline_inference/openai_batch/openai_example_batch.jsonl diff --git a/examples/offline_inference/profiling.py b/examples/offline_inference/profiling.py index 9c818d0757345e7024adfa34c0ba2d3cfce13315..3cf0c340d670597940b667a2a16bfd0f3e00ebff 100644 --- a/examples/offline_inference/profiling.py +++ b/examples/offline_inference/profiling.py @@ -14,7 +14,7 @@ import tqdm from vllm import LLM, SamplingParams from vllm.engine.arg_utils import EngineArgs -from vllm.profiler import layerwise_profile +from vllm.profiler.layerwise_profile import layerwise_profile from vllm.utils import FlexibleArgumentParser BATCH_SIZE_DEFAULT = 1 @@ -193,7 +193,7 @@ def run_profile(context: ProfileContext, csv_output: Optional[str], batch_size = context.batch_size prompt_len = context.prompt_len - scheduler_config = llm.llm_engine.scheduler_config + scheduler_config = llm.llm_engine.vllm_config.scheduler_config max_model_len = llm.llm_engine.model_config.max_model_len max_num_batched_tokens = scheduler_config.max_num_batched_tokens max_num_seqs = scheduler_config.max_num_seqs diff --git a/examples/offline_inference/qwen2_5_omni/only_thinker.py b/examples/offline_inference/qwen2_5_omni/only_thinker.py index c75a990120e0741aad71853c0a85b42a84cfd70c..52b6e977eaa2a014b2246d03273f27aaf4e88c73 100644 --- a/examples/offline_inference/qwen2_5_omni/only_thinker.py +++ b/examples/offline_inference/qwen2_5_omni/only_thinker.py @@ -47,8 +47,7 @@ def get_mixed_modalities_query() -> QueryResult: "image": ImageAsset("cherry_blossom").pil_image.convert("RGB"), "video": - VideoAsset(name="sample_demo_1.mp4", - num_frames=16).np_ndarrays, + VideoAsset(name="baby_reading", num_frames=16).np_ndarrays, }, }, limit_mm_per_prompt={ @@ -66,7 +65,7 @@ def get_use_audio_in_video_query() -> QueryResult: "<|im_start|>user\n<|vision_bos|><|VIDEO|><|vision_eos|>" f"{question}<|im_end|>\n" f"<|im_start|>assistant\n") - asset = VideoAsset(name="sample_demo_1.mp4", num_frames=16) + asset = VideoAsset(name="baby_reading", num_frames=16) audio = asset.get_audio(sampling_rate=16000) assert not envs.VLLM_USE_V1, ("V1 does not support use_audio_in_video. " "Please launch this example with " @@ -141,7 +140,7 @@ def main(args): print(generated_text) -if __name__ == "__main__": +def parse_args(): parser = FlexibleArgumentParser( description='Demo on using vLLM for offline inference with ' 'audio language models') @@ -156,5 +155,9 @@ if __name__ == "__main__": default=None, help="Set the seed when initializing `vllm.LLM`.") - args = parser.parse_args() + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() main(args) diff --git a/examples/offline_inference/qwen_1m.py b/examples/offline_inference/qwen_1m.py new file mode 100644 index 0000000000000000000000000000000000000000..64a1f4c54b670e490ec93a2f736534a12e0a7cde --- /dev/null +++ b/examples/offline_inference/qwen_1m.py @@ -0,0 +1,66 @@ +# SPDX-License-Identifier: Apache-2.0 +import os +from urllib.request import urlopen + +from vllm import LLM, SamplingParams + +os.environ["VLLM_ATTENTION_BACKEND"] = "DUAL_CHUNK_FLASH_ATTN" +os.environ["VLLM_ALLOW_LONG_MAX_MODEL_LEN"] = "1" + + +def load_prompt() -> str: + # Test cases with various lengths can be found at: + # + # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/64k.txt + # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/200k.txt + # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/600k.txt + # https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2.5-1M/test-data/1m.txt + + with urlopen( + "https://qianwen-res.oss-cn-beijing.aliyuncs.com" + "/Qwen2.5-1M/test-data/600k.txt", + timeout=5) as response: + prompt = response.read().decode('utf-8') + return prompt + + +# Processing the prompt. +def process_requests(llm: LLM, prompts: list[str]) -> None: + # Create a sampling params object. + sampling_params = SamplingParams( + temperature=0.7, + top_p=0.8, + top_k=20, + repetition_penalty=1.05, + detokenize=True, + max_tokens=256, + ) + # Generate texts from the prompts. + outputs = llm.generate(prompts, sampling_params) + # Print the outputs. + for output in outputs: + prompt_token_ids = output.prompt_token_ids + generated_text = output.outputs[0].text + print(f"Prompt length: {len(prompt_token_ids)}, " + f"Generated text: {generated_text!r}") + + +# Create an LLM. +def initialize_engine() -> LLM: + llm = LLM(model="Qwen/Qwen2.5-7B-Instruct-1M", + max_model_len=1048576, + tensor_parallel_size=4, + enforce_eager=True, + enable_chunked_prefill=True, + max_num_batched_tokens=131072) + return llm + + +def main(): + llm = initialize_engine() + prompt = load_prompt() + process_requests(llm, [prompt]) + + +if __name__ == '__main__': + main() diff --git a/examples/offline_inference/reproduciblity.py b/examples/offline_inference/reproducibility.py similarity index 100% rename from examples/offline_inference/reproduciblity.py rename to examples/offline_inference/reproducibility.py diff --git a/examples/offline_inference/torchrun_example.py b/examples/offline_inference/torchrun_example.py index c6d9e6b47e21f12a12c0d9b1d19155b25215c1e9..bb61a0a29e32207600f5494660444400d164c929 100644 --- a/examples/offline_inference/torchrun_example.py +++ b/examples/offline_inference/torchrun_example.py @@ -8,6 +8,8 @@ the argument 2 should match the `tensor_parallel_size` below. see `tests/distributed/test_torchrun_example.py` for the unit test. """ +import torch.distributed as dist + from vllm import LLM, SamplingParams # Create prompts, the same across all ranks @@ -27,23 +29,26 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95) # all ranks have the same random seed, so that sampling can be # deterministic across ranks. llm = LLM( - model="facebook/opt-125m", + model="meta-llama/Llama-3.1-8B", tensor_parallel_size=2, + pipeline_parallel_size=2, distributed_executor_backend="external_launcher", - seed=0, + max_model_len=32768, + seed=1, ) outputs = llm.generate(prompts, sampling_params) # all ranks will have the same outputs -print("-" * 50) -for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}\n" - f"Generated text: {generated_text!r}") +if dist.get_rank() == 0: print("-" * 50) -""" + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}\n" + f"Generated text: {generated_text!r}\n") + print("-" * 50) + """ Further tips: 1. to communicate control messages across all ranks, use the cpu group, diff --git a/examples/offline_inference/tpu.py b/examples/offline_inference/tpu.py index dea717c36082f577ba13d65eab6ddd7725b8d7ad..71cd88f2788ad748d1c0cb36d1990979ff373b5b 100644 --- a/examples/offline_inference/tpu.py +++ b/examples/offline_inference/tpu.py @@ -22,7 +22,8 @@ def main(): # In real workloads, `enforace_eager` should be `False`. llm = LLM(model="Qwen/Qwen2-1.5B-Instruct", max_num_batched_tokens=64, - max_num_seqs=4) + max_num_seqs=4, + max_model_len=128) outputs = llm.generate(prompts, sampling_params) print("-" * 50) for output, answer in zip(outputs, answers): diff --git a/examples/offline_inference/vision_language.py b/examples/offline_inference/vision_language.py index d02ac17cfdd68ba6db4ec70df87f5d2aa05f1db0..c54f328c7a382a0458dce77bb5b20c2efbdca9cc 100644 --- a/examples/offline_inference/vision_language.py +++ b/examples/offline_inference/vision_language.py @@ -45,7 +45,7 @@ def run_aria(questions: list[str], modality: str) -> ModelRequestData: max_model_len=4096, max_num_seqs=2, dtype="bfloat16", - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) prompts = [(f"<|im_start|>user\n<|img|>{question}" @@ -71,7 +71,7 @@ def run_aya_vision(questions: list[str], modality: str) -> ModelRequestData: max_model_len=2048, max_num_seqs=2, mm_processor_kwargs={"crop_to_patches": True}, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) prompts = [ f"<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{question}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>" @@ -92,7 +92,7 @@ def run_blip2(questions: list[str], modality: str) -> ModelRequestData: prompts = [f"Question: {question} Answer:" for question in questions] engine_args = EngineArgs( model="Salesforce/blip2-opt-6.7b", - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) return ModelRequestData( @@ -110,7 +110,7 @@ def run_chameleon(questions: list[str], modality: str) -> ModelRequestData: model="facebook/chameleon-7b", max_model_len=4096, max_num_seqs=2, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) return ModelRequestData( @@ -130,7 +130,7 @@ def run_deepseek_vl2(questions: list[str], modality: str) -> ModelRequestData: max_model_len=4096, max_num_seqs=2, hf_overrides={"architectures": ["DeepseekVLV2ForCausalLM"]}, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) prompts = [ @@ -155,7 +155,7 @@ def run_florence2(questions: list[str], modality: str) -> ModelRequestData: max_num_seqs=2, trust_remote_code=True, dtype="bfloat16", - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) prompts = ["" for _ in questions] @@ -175,7 +175,7 @@ def run_fuyu(questions: list[str], modality: str) -> ModelRequestData: model="adept/fuyu-8b", max_model_len=2048, max_num_seqs=2, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) return ModelRequestData( @@ -194,7 +194,7 @@ def run_gemma3(questions: list[str], modality: str) -> ModelRequestData: max_model_len=2048, max_num_seqs=2, mm_processor_kwargs={"do_pan_and_scan": True}, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) prompts = [("user\n" @@ -219,7 +219,7 @@ def run_glm4v(questions: list[str], modality: str) -> ModelRequestData: trust_remote_code=True, enforce_eager=True, hf_overrides={"architectures": ["GLM4VForCausalLM"]}, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) prompts = [ @@ -246,7 +246,7 @@ def run_h2ovl(questions: list[str], modality: str) -> ModelRequestData: model=model_name, trust_remote_code=True, max_model_len=8192, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) tokenizer = AutoTokenizer.from_pretrained(model_name, @@ -287,7 +287,7 @@ def run_idefics3(questions: list[str], modality: str) -> ModelRequestData: "longest_edge": 3 * 364 }, }, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) prompts = [( f"<|begin_of_text|>User:{question}\nAssistant:" @@ -314,7 +314,7 @@ def run_smolvlm(questions: list[str], modality: str) -> ModelRequestData: "longest_edge": 384 }, }, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) prompts = [ (f"<|im_start|>User:{question}\nAssistant:") @@ -337,7 +337,7 @@ def run_internvl(questions: list[str], modality: str) -> ModelRequestData: model=model_name, trust_remote_code=True, max_model_len=4096, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) tokenizer = AutoTokenizer.from_pretrained(model_name, @@ -378,7 +378,7 @@ def run_kimi_vl(questions: list[str], modality: str) -> ModelRequestData: model="moonshotai/Kimi-VL-A3B-Instruct", trust_remote_code=True, max_model_len=4096, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) return ModelRequestData( @@ -398,7 +398,7 @@ def run_llava(questions: list[str], modality: str) -> ModelRequestData: engine_args = EngineArgs( model="llava-hf/llava-1.5-7b-hf", max_model_len=4096, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) return ModelRequestData( @@ -415,7 +415,7 @@ def run_llava_next(questions: list[str], modality: str) -> ModelRequestData: engine_args = EngineArgs( model="llava-hf/llava-v1.6-mistral-7b-hf", max_model_len=8192, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) return ModelRequestData( @@ -437,7 +437,7 @@ def run_llava_next_video(questions: list[str], model="llava-hf/LLaVA-NeXT-Video-7B-hf", max_model_len=8192, max_num_seqs=2, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) return ModelRequestData( @@ -465,7 +465,7 @@ def run_llava_onevision(questions: list[str], engine_args = EngineArgs( model="llava-hf/llava-onevision-qwen2-7b-ov-hf", max_model_len=16384, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) return ModelRequestData( @@ -488,7 +488,7 @@ def run_mantis(questions: list[str], modality: str) -> ModelRequestData: model="TIGER-Lab/Mantis-8B-siglip-llama3", max_model_len=4096, hf_overrides={"architectures": ["MantisForConditionalGeneration"]}, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) stop_token_ids = [128009] @@ -529,7 +529,7 @@ def run_minicpmv_base(questions: list[str], modality: str, model_name): max_model_len=4096, max_num_seqs=2, trust_remote_code=True, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) # NOTE The stop_token_ids are different for various versions of MiniCPM-V # 2.0 @@ -584,7 +584,7 @@ def run_mistral3(questions: list[str], modality: str) -> ModelRequestData: max_model_len=8192, max_num_seqs=2, tensor_parallel_size=2, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) prompts = [f"[INST]{question}\n[IMG][/INST]" for question in questions] @@ -610,7 +610,7 @@ def run_mllama(questions: list[str], modality: str) -> ModelRequestData: model=model_name, max_model_len=8192, max_num_seqs=2, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) tokenizer = AutoTokenizer.from_pretrained(model_name) @@ -645,7 +645,7 @@ def run_llama4(questions: list[str], modality: str) -> ModelRequestData: max_num_seqs=4, tensor_parallel_size=8, gpu_memory_utilization=0.4, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) tokenizer = AutoTokenizer.from_pretrained(model_name) @@ -680,7 +680,7 @@ def run_molmo(questions: list[str], modality: str) -> ModelRequestData: model=model_name, trust_remote_code=True, dtype="bfloat16", - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) prompts = [ @@ -706,7 +706,38 @@ def run_nvlm_d(questions: list[str], modality: str) -> ModelRequestData: trust_remote_code=True, max_model_len=4096, tensor_parallel_size=4, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, + ) + + tokenizer = AutoTokenizer.from_pretrained(model_name, + trust_remote_code=True) + messages = [[{ + 'role': 'user', + 'content': f"\n{question}" + }] for question in questions] + prompts = tokenizer.apply_chat_template(messages, + tokenize=False, + add_generation_prompt=True) + + return ModelRequestData( + engine_args=engine_args, + prompts=prompts, + ) + + +# Ovis +def run_ovis(questions: list[str], modality: str) -> ModelRequestData: + assert modality == "image" + + model_name = "AIDC-AI/Ovis2-1B" + + engine_args = EngineArgs( + model=model_name, + max_model_len=4096, + max_num_seqs=2, + trust_remote_code=True, + dtype="half", + limit_mm_per_prompt={modality: 1}, ) tokenizer = AutoTokenizer.from_pretrained(model_name, @@ -733,7 +764,7 @@ def run_paligemma(questions: list[str], modality: str) -> ModelRequestData: prompts = ["caption en" for _ in questions] engine_args = EngineArgs( model="google/paligemma-3b-mix-224", - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) return ModelRequestData( @@ -750,7 +781,7 @@ def run_paligemma2(questions: list[str], modality: str) -> ModelRequestData: prompts = ["caption en" for _ in questions] engine_args = EngineArgs( model="google/paligemma2-3b-ft-docci-448", - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) return ModelRequestData( @@ -787,7 +818,7 @@ def run_phi3v(questions: list[str], modality: str) -> ModelRequestData: max_num_seqs=2, # Note - mm_processor_kwargs can also be passed to generate/chat calls mm_processor_kwargs={"num_crops": 16}, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) return ModelRequestData( @@ -821,7 +852,7 @@ def run_phi4mm(questions: list[str], modality: str) -> ModelRequestData: max_lora_rank=320, # Note - mm_processor_kwargs can also be passed to generate/chat calls mm_processor_kwargs={"dynamic_hd": 16}, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) return ModelRequestData( @@ -842,7 +873,7 @@ def run_pixtral_hf(questions: list[str], modality: str) -> ModelRequestData: model=model_name, max_model_len=6144, max_num_seqs=2, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) prompts = [f"[INST]{question}\n[IMG][/INST]" for question in questions] @@ -863,7 +894,7 @@ def run_qwen_vl(questions: list[str], modality: str) -> ModelRequestData: max_model_len=1024, max_num_seqs=2, hf_overrides={"architectures": ["QwenVLForConditionalGeneration"]}, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) prompts = [f"{question}Picture 1: \n" for question in questions] @@ -888,7 +919,7 @@ def run_qwen2_vl(questions: list[str], modality: str) -> ModelRequestData: "min_pixels": 28 * 28, "max_pixels": 1280 * 28 * 28, }, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) if modality == "image": @@ -923,7 +954,7 @@ def run_qwen2_5_vl(questions: list[str], modality: str) -> ModelRequestData: "max_pixels": 1280 * 28 * 28, "fps": 1, }, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) if modality == "image": @@ -957,7 +988,7 @@ def run_qwen2_5_omni(questions: list[str], modality: str): "max_pixels": 1280 * 28 * 28, "fps": [1], }, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) if modality == "image": @@ -990,7 +1021,7 @@ def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData: model=model_name, trust_remote_code=True, max_model_len=4096, - limit_mm_per_prompt={"image": 1}, + limit_mm_per_prompt={modality: 1}, ) tokenizer = AutoTokenizer.from_pretrained(model_name, @@ -1041,6 +1072,7 @@ model_example_map = { "llama4": run_llama4, "molmo": run_molmo, "NVLM_D": run_nvlm_d, + "ovis": run_ovis, "paligemma": run_paligemma, "paligemma2": run_paligemma2, "phi3_v": run_phi3v, @@ -1080,7 +1112,7 @@ def get_multi_modal_input(args): if args.modality == "video": # Input video and question - video = VideoAsset(name="sample_demo_1.mp4", + video = VideoAsset(name="baby_reading", num_frames=args.num_frames).np_ndarrays vid_questions = ["Why is this video funny?"] diff --git a/examples/offline_inference/vision_language_multi_image.py b/examples/offline_inference/vision_language_multi_image.py index 7f6608559f9c4a31d27818a272d54e520605447a..20a8e635e322f87ff3ae2b71973f9a86ce489e07 100644 --- a/examples/offline_inference/vision_language_multi_image.py +++ b/examples/offline_inference/vision_language_multi_image.py @@ -436,6 +436,36 @@ def load_nvlm_d(question: str, image_urls: list[str]) -> ModelRequestData: ) +# Ovis +def load_ovis(question: str, image_urls: list[str]) -> ModelRequestData: + model_name = "AIDC-AI/Ovis2-1B" + + engine_args = EngineArgs( + model=model_name, + max_model_len=8192, + max_num_seqs=2, + trust_remote_code=True, + dtype="half", + limit_mm_per_prompt={"image": len(image_urls)}, + ) + + placeholders = "\n".join(f"Image-{i}: \n" + for i, _ in enumerate(image_urls, start=1)) + messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}] + + tokenizer = AutoTokenizer.from_pretrained(model_name, + trust_remote_code=True) + prompt = tokenizer.apply_chat_template(messages, + tokenize=False, + add_generation_prompt=True) + + return ModelRequestData( + engine_args=engine_args, + prompt=prompt, + image_data=[fetch_image(url) for url in image_urls], + ) + + def load_pixtral_hf(question: str, image_urls: list[str]) -> ModelRequestData: model_name = "mistral-community/pixtral-12b" @@ -685,6 +715,7 @@ model_example_map = { "mistral3": load_mistral3, "mllama": load_mllama, "NVLM_D": load_nvlm_d, + "ovis": load_ovis, "phi3_v": load_phi3v, "phi4_mm": load_phi4mm, "pixtral_hf": load_pixtral_hf, diff --git a/examples/online_serving/chart-helm/values.yaml b/examples/online_serving/chart-helm/values.yaml index 9c48e7d061bf7e8daa102fc883dee03c50241f3e..28dba9a6f6882a4173013ea2e76b29eb4ce77813 100644 --- a/examples/online_serving/chart-helm/values.yaml +++ b/examples/online_serving/chart-helm/values.yaml @@ -8,7 +8,7 @@ image: # -- Image tag tag: "latest" # -- Container launch command - command: ["vllm", "serve", "/data/", "--served-model-name", "opt-125m", "--dtype", "bfloat16", "--host", "0.0.0.0", "--port", "8000"] + command: ["vllm", "serve", "/data/", "--served-model-name", "opt-125m", "--dtype", "float32", "--block-size", "16", "--host", "0.0.0.0", "--port", "8000"] # -- Container port containerPort: 8000 diff --git a/examples/online_serving/disaggregated_serving/README.md b/examples/online_serving/disaggregated_serving/README.md new file mode 100644 index 0000000000000000000000000000000000000000..090afd7515ee084bf137d32c99ee4a00a6a5256c --- /dev/null +++ b/examples/online_serving/disaggregated_serving/README.md @@ -0,0 +1,8 @@ +# Disaggregated Serving + +This example contains scripts that demonstrate the disaggregated serving features of vLLM. + +## Files + +- `disagg_proxy_demo.py` - Demonstrates XpYd (X prefill instances, Y decode instances). +- `kv_events.sh` - Demonstrates KV cache event publishing. diff --git a/examples/online_serving/disagg_examples/disagg_proxy_demo.py b/examples/online_serving/disaggregated_serving/disagg_proxy_demo.py similarity index 99% rename from examples/online_serving/disagg_examples/disagg_proxy_demo.py rename to examples/online_serving/disaggregated_serving/disagg_proxy_demo.py index a701636f357a80613f772a9c94721e60fc7cb51c..c6d26778ee49703cc8903e8c18d4194abab655a3 100644 --- a/examples/online_serving/disagg_examples/disagg_proxy_demo.py +++ b/examples/online_serving/disaggregated_serving/disagg_proxy_demo.py @@ -4,7 +4,7 @@ This file provides a disaggregated prefilling proxy demo to demonstrate an example usage of XpYd disaggregated prefilling. We can launch multiple vllm instances (2 for prefill and 2 for decode), and launch this proxy demo through: - python3 examples/online_serving/disagg_examples/disagg_proxy_demo.py \ + python3 examples/online_serving/disaggregated_serving/disagg_proxy_demo.py \ --model $model_name \ --prefill localhost:8100 localhost:8101 \ --decode localhost:8200 localhost:8201 \ @@ -414,7 +414,7 @@ class ProxyServer: server.run() -if __name__ == "__main__": +def parse_args(): # Todo: allow more config parser = argparse.ArgumentParser("vLLM disaggregated proxy server.") parser.add_argument("--model", @@ -445,6 +445,10 @@ if __name__ == "__main__": default=8000, help="Server port number", ) - args = parser.parse_args() + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() proxy_server = ProxyServer(args=args) proxy_server.run_server() diff --git a/examples/online_serving/disaggregated_serving/kv_events.sh b/examples/online_serving/disaggregated_serving/kv_events.sh new file mode 100644 index 0000000000000000000000000000000000000000..a111db2179fc9db4b005ab8dc9dd0d63fb788806 --- /dev/null +++ b/examples/online_serving/disaggregated_serving/kv_events.sh @@ -0,0 +1,86 @@ +#!/bin/bash +# This file demonstrates the KV cache event publishing +# We will launch a vllm instances configured to publish KV cache +# events and launch a simple subscriber to log those events. + +set -xe + +echo "🚧🚧 Warning: The usage of KV cache events is experimental and subject to change 🚧🚧" +sleep 1 + +MODEL_NAME=${HF_MODEL_NAME:-meta-llama/Meta-Llama-3.1-8B-Instruct} + +# Trap the SIGINT signal (triggered by Ctrl+C) +trap 'cleanup' INT + +# Cleanup function +cleanup() { + echo "Caught Ctrl+C, cleaning up..." + # Cleanup commands + pgrep python | xargs kill -9 + pkill -f python + echo "Cleanup complete. Exiting." + exit 0 +} + +export VLLM_HOST_IP=$(hostname -I | awk '{print $1}') + +# a function that waits vLLM server to start +wait_for_server() { + local port=$1 + timeout 1200 bash -c " + until curl -s localhost:${port}/v1/completions > /dev/null; do + sleep 1 + done" && return 0 || return 1 +} + +vllm serve $MODEL_NAME \ + --port 8100 \ + --max-model-len 100 \ + --enforce-eager \ + --gpu-memory-utilization 0.8 \ + --trust-remote-code \ + --kv-events-config \ + '{"enable_kv_cache_events": true, "publisher": "zmq", "topic": "kv-events"}' & + +wait_for_server 8100 + +SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" + +python3 "$SCRIPT_DIR/kv_events_subscriber.py" & +sleep 1 + +# serve two example requests +output1=$(curl -X POST -s http://localhost:8100/v1/completions \ +-H "Content-Type: application/json" \ +-d '{ +"model": "'"$MODEL_NAME"'", +"prompt": "Explain quantum computing in simple terms a 5-year-old could understand.", +"max_tokens": 80, +"temperature": 0 +}') + +output2=$(curl -X POST -s http://localhost:8100/v1/completions \ +-H "Content-Type: application/json" \ +-d '{ +"model": "'"$MODEL_NAME"'", +"prompt": "Explain quantum computing in simple terms a 50-year-old could understand.", +"max_tokens": 80, +"temperature": 0 +}') + +# Cleanup commands +pkill -9 -u "$USER" -f python +pkill -9 -u "$USER" -f vllm + +sleep 1 + +echo "Cleaned up" + +# Print the outputs of the curl requests +echo "" +echo "Output of first request: $output1" +echo "Output of second request: $output2" + +echo "🎉🎉 Successfully finished 2 test requests! 🎉🎉" +echo "" diff --git a/examples/online_serving/kv_events_subscriber.py b/examples/online_serving/kv_events_subscriber.py new file mode 100644 index 0000000000000000000000000000000000000000..88bbbebd7478770e8c0d7595991e172779cab7cf --- /dev/null +++ b/examples/online_serving/kv_events_subscriber.py @@ -0,0 +1,114 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Any, Optional, Union + +import msgspec +import zmq +from msgspec.msgpack import Decoder + + +# +# Types copied from vllm.distributed.kv_events +# +class EventBatch(msgspec.Struct, array_like=True, omit_defaults=True, + gc=False): + ts: float + events: list[Any] + + +class KVCacheEvent(msgspec.Struct, + array_like=True, + omit_defaults=True, + gc=False, + tag=True): + """Base class for all KV cache-related events""" + + +class BlockStored(KVCacheEvent): + block_hashes: list[int] + parent_block_hash: Optional[int] + token_ids: list[int] + block_size: int + lora_id: Optional[int] + + +class BlockRemoved(KVCacheEvent): + block_hashes: list[int] + + +class AllBlocksCleared(KVCacheEvent): + pass + + +class KVEventBatch(EventBatch): + events: list[Union[BlockStored, BlockRemoved, AllBlocksCleared]] + + +def process_event(event_batch): + print(f"Received event batch at {event_batch.ts}:") + for event in event_batch.events: + print(f" - {event}") + + +def main(): + decoder = Decoder(type=KVEventBatch) + last_seq = -1 + + context = zmq.Context() + + # Set up the main subscription socket + sub = context.socket(zmq.SUB) + sub.connect("tcp://localhost:5557") + topic = "kv-events" + sub.setsockopt_string(zmq.SUBSCRIBE, topic) + + # Initialize replay socket + replay = context.socket(zmq.REQ) + replay.connect("tcp://localhost:5558") + poller = zmq.Poller() + poller.register(replay, zmq.POLLIN) + + print("Listening for KV cache events on topic:", topic) + + while True: + try: + if sub.poll(50): + _, seq_bytes, payload = sub.recv_multipart() + seq = int.from_bytes(seq_bytes, "big") + + if last_seq >= 0 and seq > last_seq + 1: + missed = seq - last_seq - 1 + print(f"Missed {missed} messages" + f" (last: {last_seq}, current: {seq})") + + replay.send((last_seq + 1).to_bytes(8, "big")) + + while poller.poll(timeout=200): + seq_bytes, replay_payload = replay.recv_multipart() + if not replay_payload: + # End of replay marker is sent as an empty frame + # for the payload + break + + replay_seq = int.from_bytes(seq_bytes, "big") + + if replay_seq > last_seq: + event_batch = decoder.decode(replay_payload) + process_event(event_batch) + last_seq = replay_seq + if replay_seq >= seq - 1: + break + + event_batch = decoder.decode(payload) + process_event(event_batch) + + # ... do other periodic work or check for shutdown ... + + except KeyboardInterrupt: + print("Interrupted") + break + except Exception as e: + print("Error decoding message:", e) + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/openai_chat_completion_client_for_multimodal.py b/examples/online_serving/openai_chat_completion_client_for_multimodal.py index 70db4d95e64941a163440b32545d55704e495966..2707d46f46e2aa6cbf17f2a5353c1472e68b8f63 100644 --- a/examples/online_serving/openai_chat_completion_client_for_multimodal.py +++ b/examples/online_serving/openai_chat_completion_client_for_multimodal.py @@ -1,23 +1,29 @@ # SPDX-License-Identifier: Apache-2.0 -"""An example showing how to use vLLM to serve multimodal models +"""An example showing how to use vLLM to serve multimodal models and run online serving with OpenAI client. Launch the vLLM server with the following command: (single image inference with Llava) -vllm serve llava-hf/llava-1.5-7b-hf --chat-template template_llava.jinja +vllm serve llava-hf/llava-1.5-7b-hf (multi-image inference with Phi-3.5-vision-instruct) vllm serve microsoft/Phi-3.5-vision-instruct --task generate \ --trust-remote-code --max-model-len 4096 --limit-mm-per-prompt '{"image":2}' (audio inference with Ultravox) -vllm serve fixie-ai/ultravox-v0_5-llama-3_2-1b --max-model-len 4096 +vllm serve fixie-ai/ultravox-v0_5-llama-3_2-1b \ + --max-model-len 4096 --trust-remote-code + +run the script with +python openai_chat_completion_client_for_multimodal.py --chat-type audio """ + import base64 import requests from openai import OpenAI +from utils import get_first_model from vllm.utils import FlexibleArgumentParser @@ -31,9 +37,6 @@ client = OpenAI( base_url=openai_api_base, ) -models = client.models.list() -model = models.data[0].id - def encode_base64_content_from_url(content_url: str) -> str: """Encode a content retrieved from a remote url to base64 format.""" @@ -46,7 +49,7 @@ def encode_base64_content_from_url(content_url: str) -> str: # Text-only inference -def run_text_only() -> None: +def run_text_only(model: str) -> None: chat_completion = client.chat.completions.create( messages=[{ "role": "user", @@ -61,7 +64,7 @@ def run_text_only() -> None: # Single-image input inference -def run_single_image() -> None: +def run_single_image(model: str) -> None: ## Use image url in the payload image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" @@ -117,7 +120,7 @@ def run_single_image() -> None: # Multi-image input inference -def run_multi_image() -> None: +def run_multi_image(model: str) -> None: image_url_duck = "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg" image_url_lion = "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg" chat_completion_from_url = client.chat.completions.create( @@ -152,7 +155,7 @@ def run_multi_image() -> None: # Video input inference -def run_video() -> None: +def run_video(model: str) -> None: video_url = "http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/ForBiggerFun.mp4" video_base64 = encode_base64_content_from_url(video_url) @@ -208,7 +211,7 @@ def run_video() -> None: # Audio input inference -def run_audio() -> None: +def run_audio(model: str) -> None: from vllm.assets.audio import AudioAsset audio_url = AudioAsset("winning_call").url @@ -318,7 +321,8 @@ def parse_args(): def main(args) -> None: chat_type = args.chat_type - example_function_map[chat_type]() + model = get_first_model(client) + example_function_map[chat_type](model) if __name__ == "__main__": diff --git a/examples/online_serving/openai_chat_completion_client_with_tools.py b/examples/online_serving/openai_chat_completion_client_with_tools.py index c25203860ff398176eb02c4c6adbe025b86cf987..94f9c157058645251bacd4509c3dd4457509fd60 100644 --- a/examples/online_serving/openai_chat_completion_client_with_tools.py +++ b/examples/online_serving/openai_chat_completion_client_with_tools.py @@ -7,12 +7,12 @@ IMPORTANT: for mistral, you must use one of the provided mistral tool call templates, or your own - the model default doesn't work for tool calls with vLLM See the vLLM docs on OpenAI server & tool calling for more details. -vllm serve --model mistralai/Mistral-7B-Instruct-v0.3 \ +vllm serve mistralai/Mistral-7B-Instruct-v0.3 \ --chat-template examples/tool_chat_template_mistral.jinja \ --enable-auto-tool-choice --tool-call-parser mistral OR -vllm serve --model NousResearch/Hermes-2-Pro-Llama-3-8B \ +vllm serve NousResearch/Hermes-2-Pro-Llama-3-8B \ --chat-template examples/tool_chat_template_hermes.jinja \ --enable-auto-tool-choice --tool-call-parser hermes """ diff --git a/examples/online_serving/openai_chat_completion_structured_outputs.py b/examples/online_serving/openai_chat_completion_structured_outputs.py index f71162e36efd20415589f139903665b670fa92b4..660369e55d40e8be422ef31b1c92c2e5d159dffd 100644 --- a/examples/online_serving/openai_chat_completion_structured_outputs.py +++ b/examples/online_serving/openai_chat_completion_structured_outputs.py @@ -112,8 +112,8 @@ def extra_backend_options_completion(client: OpenAI, model: str): "alan.turing@enigma.com\n") try: - # The no-fallback option forces vLLM to use xgrammar, so when it fails - # you get a 400 with the reason why + # The guided_decoding_disable_fallback option forces vLLM to use + # xgrammar, so when it fails you get a 400 with the reason why completion = client.chat.completions.create( model=model, messages=[{ @@ -123,7 +123,8 @@ def extra_backend_options_completion(client: OpenAI, model: str): extra_body={ "guided_regex": r"\w+@\w+\.com\n", "stop": ["\n"], - "guided_decoding_backend": "xgrammar:no-fallback" + "guided_decoding_backend": "xgrammar", + "guided_decoding_disable_fallback": True, }, ) return completion.choices[0].message.content @@ -137,7 +138,7 @@ def main(): api_key="-", ) - model = "Qwen/Qwen2.5-3B-Instruct" + model = client.models.list().data[0].id print("Guided Choice Completion:") print(guided_choice_completion(client, model)) diff --git a/examples/online_serving/openai_chat_completion_structured_outputs_structural_tag.py b/examples/online_serving/openai_chat_completion_structured_outputs_structural_tag.py index b807bc5405262790f35bea7f7c52acfa9b280bd2..42aa12c451c04c34903c70d6937760156956ab77 100644 --- a/examples/online_serving/openai_chat_completion_structured_outputs_structural_tag.py +++ b/examples/online_serving/openai_chat_completion_structured_outputs_structural_tag.py @@ -59,7 +59,7 @@ and San Francisco? }] response = client.chat.completions.create( - model="meta-llama/Llama-3.1-8B-Instruct", + model=client.models.list().data[0].id, messages=messages, response_format={ "type": diff --git a/examples/online_serving/openai_chat_completion_structured_outputs_with_reasoning.py b/examples/online_serving/openai_chat_completion_structured_outputs_with_reasoning.py index cb7f30d93255483ef7045dddfc1b2fc44f01d254..a04f0cdf12f76ba820a8ddb7c1bb2f9c18e660fa 100644 --- a/examples/online_serving/openai_chat_completion_structured_outputs_with_reasoning.py +++ b/examples/online_serving/openai_chat_completion_structured_outputs_with_reasoning.py @@ -4,12 +4,12 @@ An example shows how to generate structured outputs from reasoning models like DeepSeekR1. The thinking process will not be guided by the JSON schema provided by the user. Only the final output will be structured. -To run this example, you need to start the vLLM server with the reasoning +To run this example, you need to start the vLLM server with the reasoning parser: ```bash vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \ - --enable-reasoning --reasoning-parser deepseek_r1 + --reasoning-parser deepseek_r1 ``` This example demonstrates how to generate chat completions from reasoning models diff --git a/examples/online_serving/openai_chat_completion_tool_calls_with_reasoning.py b/examples/online_serving/openai_chat_completion_tool_calls_with_reasoning.py index 8c6470aa3dd41333d25702337fb9acba8729a275..9417abd3989a2ddeb7442c5117b108159266b6c5 100644 --- a/examples/online_serving/openai_chat_completion_tool_calls_with_reasoning.py +++ b/examples/online_serving/openai_chat_completion_tool_calls_with_reasoning.py @@ -9,7 +9,7 @@ the reasoning parser and tool calling enabled. ```bash vllm serve Qwen/QwQ-32B \ - --enable-reasoning --reasoning-parser deepseek_r1 \ + --reasoning-parser deepseek_r1 \ --enable-auto-tool-choice --tool-call-parser hermes ``` diff --git a/examples/online_serving/openai_chat_completion_with_reasoning.py b/examples/online_serving/openai_chat_completion_with_reasoning.py index 6f5f7b5fa20ba70e32ae399bdfd3f53f5c629faf..4bf7731cb41e357dd6d5aa0153f21ed0e5db9805 100644 --- a/examples/online_serving/openai_chat_completion_with_reasoning.py +++ b/examples/online_serving/openai_chat_completion_with_reasoning.py @@ -8,7 +8,7 @@ with the reasoning parser: ```bash vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \ - --enable-reasoning --reasoning-parser deepseek_r1 + --reasoning-parser deepseek_r1 ``` This example demonstrates how to generate chat completions from reasoning models diff --git a/examples/online_serving/openai_chat_completion_with_reasoning_streaming.py b/examples/online_serving/openai_chat_completion_with_reasoning_streaming.py index 90481cdc0fb797452cd0d5e4a34e4c2ad1b421e6..9cc0a5f2476b3cb65a35c4e9f147db0c4718293a 100644 --- a/examples/online_serving/openai_chat_completion_with_reasoning_streaming.py +++ b/examples/online_serving/openai_chat_completion_with_reasoning_streaming.py @@ -8,7 +8,7 @@ parser: ```bash vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \ - --enable-reasoning --reasoning-parser deepseek_r1 + --reasoning-parser deepseek_r1 ``` Unlike openai_chat_completion_with_reasoning.py, this example demonstrates the diff --git a/examples/online_serving/openai_classification_client.py b/examples/online_serving/openai_classification_client.py new file mode 100644 index 0000000000000000000000000000000000000000..99241346373ea8e4b31477c86f8bb222d2b41e1b --- /dev/null +++ b/examples/online_serving/openai_classification_client.py @@ -0,0 +1,49 @@ +# SPDX-License-Identifier: Apache-2.0 + +import argparse +import pprint + +import requests + + +def post_http_request(payload: dict, api_url: str) -> requests.Response: + headers = {"User-Agent": "Test Client"} + response = requests.post(api_url, headers=headers, json=payload) + return response + + +def parse_args(): + parse = argparse.ArgumentParser() + parse.add_argument("--host", type=str, default="localhost") + parse.add_argument("--port", type=int, default=8000) + parse.add_argument("--model", + type=str, + default="jason9693/Qwen2.5-1.5B-apeach") + return parse.parse_args() + + +def main(args): + host = args.host + port = args.port + model_name = args.model + + api_url = f"http://{host}:{port}/classify" + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + payload = { + "model": model_name, + "input": prompts, + } + + classify_response = post_http_request(payload=payload, api_url=api_url) + pprint.pprint(classify_response.json()) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/online_serving/openai_transcription_client.py b/examples/online_serving/openai_transcription_client.py index 5fcb7c5264162e45705582c41b4e2b2f6042d771..66e622672ef2a8ce9b619a90df1a061a031d8a27 100644 --- a/examples/online_serving/openai_transcription_client.py +++ b/examples/online_serving/openai_transcription_client.py @@ -46,11 +46,15 @@ async def stream_openai_response(): "model": "openai/whisper-large-v3", } url = openai_api_base + "/audio/transcriptions" + headers = {"Authorization": f"Bearer {openai_api_key}"} print("transcription result:", end=' ') async with httpx.AsyncClient() as client: with open(str(winning_call), "rb") as f: - async with client.stream('POST', url, files={'file': f}, - data=data) as response: + async with client.stream('POST', + url, + files={'file': f}, + data=data, + headers=headers) as response: async for line in response.aiter_lines(): # Each line is a JSON object prefixed with 'data: ' if line: diff --git a/examples/online_serving/opentelemetry/Otel.md b/examples/online_serving/opentelemetry/README.md similarity index 100% rename from examples/online_serving/opentelemetry/Otel.md rename to examples/online_serving/opentelemetry/README.md diff --git a/examples/online_serving/ray_serve_deepseek.py b/examples/online_serving/ray_serve_deepseek.py index f9ef3e2da1a19fe3ffcf640d715f6b82d3e47e93..e2dce107e78a3427993dea425531881861ed5bc3 100644 --- a/examples/online_serving/ray_serve_deepseek.py +++ b/examples/online_serving/ray_serve_deepseek.py @@ -1,7 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 """ Example to deploy DeepSeek R1 or V3 with Ray Serve LLM. -See Ray Serve LLM documentation at: +See more details at: +https://docs.ray.io/en/latest/serve/tutorials/serve-deepseek.html +And see Ray Serve LLM documentation at: https://docs.ray.io/en/latest/serve/llm/serving-llms.html Run `python3 ray_serve_deepseek.py` to deploy the model. diff --git a/examples/online_serving/retrieval_augmented_generation_with_langchain.py b/examples/online_serving/retrieval_augmented_generation_with_langchain.py new file mode 100644 index 0000000000000000000000000000000000000000..73063065cb36e7afaca0a86115c55647c169f4b8 --- /dev/null +++ b/examples/online_serving/retrieval_augmented_generation_with_langchain.py @@ -0,0 +1,249 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Retrieval Augmented Generation (RAG) Implementation with Langchain +================================================================== + +This script demonstrates a RAG implementation using LangChain, Milvus +and vLLM. RAG enhances LLM responses by retrieving relevant context +from a document collection. + +Features: +- Web content loading and chunking +- Vector storage with Milvus +- Embedding generation with vLLM +- Question answering with context + +Prerequisites: +1. Install dependencies: + pip install -U vllm \ + langchain_milvus langchain_openai \ + langchain_community beautifulsoup4 \ + langchain-text-splitters + +2. Start services: + # Start embedding service (port 8000) + vllm serve ssmits/Qwen2-7B-Instruct-embed-base + + # Start chat service (port 8001) + vllm serve qwen/Qwen1.5-0.5B-Chat --port 8001 + +Usage: + python retrieval_augmented_generation_with_langchain.py + +Notes: + - Ensure both vLLM services are running before executing + - Default ports: 8000 (embedding), 8001 (chat) + - First run may take time to download models +""" + +import argparse +from argparse import Namespace +from typing import Any + +from langchain_community.document_loaders import WebBaseLoader +from langchain_core.documents import Document +from langchain_core.output_parsers import StrOutputParser +from langchain_core.prompts import PromptTemplate +from langchain_core.runnables import RunnablePassthrough +from langchain_milvus import Milvus +from langchain_openai import ChatOpenAI, OpenAIEmbeddings +from langchain_text_splitters import RecursiveCharacterTextSplitter + + +def load_and_split_documents(config: dict[str, Any]): + """ + Load and split documents from web URL + """ + try: + loader = WebBaseLoader(web_paths=(config["url"], )) + docs = loader.load() + + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=config["chunk_size"], + chunk_overlap=config["chunk_overlap"], + ) + return text_splitter.split_documents(docs) + except Exception as e: + print(f"Error loading document from {config['url']}: {str(e)}") + raise + + +def init_vectorstore(config: dict[str, Any], documents: list[Document]): + """ + Initialize vector store with documents + """ + return Milvus.from_documents( + documents=documents, + embedding=OpenAIEmbeddings( + model=config["embedding_model"], + openai_api_key=config["vllm_api_key"], + openai_api_base=config["vllm_embedding_endpoint"], + ), + connection_args={"uri": config["uri"]}, + drop_old=True, + ) + + +def init_llm(config: dict[str, Any]): + """ + Initialize llm + """ + return ChatOpenAI( + model=config["chat_model"], + openai_api_key=config["vllm_api_key"], + openai_api_base=config["vllm_chat_endpoint"], + ) + + +def get_qa_prompt(): + """ + Get question answering prompt template + """ + template = """You are an assistant for question-answering tasks. +Use the following pieces of retrieved context to answer the question. +If you don't know the answer, just say that you don't know. +Use three sentences maximum and keep the answer concise. +Question: {question} +Context: {context} +Answer: +""" + return PromptTemplate.from_template(template) + + +def format_docs(docs: list[Document]): + """ + Format documents for prompt + """ + return "\n\n".join(doc.page_content for doc in docs) + + +def create_qa_chain(retriever: Any, llm: ChatOpenAI, prompt: PromptTemplate): + """ + Set up question answering chain + """ + return ({ + "context": retriever | format_docs, + "question": RunnablePassthrough(), + } + | prompt + | llm + | StrOutputParser()) + + +def get_parser() -> argparse.ArgumentParser: + """ + Parse command line arguments + """ + parser = argparse.ArgumentParser(description='RAG with vLLM and langchain') + + # Add command line arguments + parser.add_argument('--vllm-api-key', + default="EMPTY", + help='API key for vLLM compatible services') + parser.add_argument('--vllm-embedding-endpoint', + default="http://localhost:8000/v1", + help='Base URL for embedding service') + parser.add_argument('--vllm-chat-endpoint', + default="http://localhost:8001/v1", + help='Base URL for chat service') + parser.add_argument('--uri', + default="./milvus.db", + help='URI for Milvus database') + parser.add_argument( + '--url', + default=("https://docs.vllm.ai/en/latest/getting_started/" + "quickstart.html"), + help='URL of the document to process') + parser.add_argument('--embedding-model', + default="ssmits/Qwen2-7B-Instruct-embed-base", + help='Model name for embeddings') + parser.add_argument('--chat-model', + default="qwen/Qwen1.5-0.5B-Chat", + help='Model name for chat') + parser.add_argument('-i', + '--interactive', + action='store_true', + help='Enable interactive Q&A mode') + parser.add_argument('-k', + '--top-k', + type=int, + default=3, + help='Number of top results to retrieve') + parser.add_argument('-c', + '--chunk-size', + type=int, + default=1000, + help='Chunk size for document splitting') + parser.add_argument('-o', + '--chunk-overlap', + type=int, + default=200, + help='Chunk overlap for document splitting') + + return parser + + +def init_config(args: Namespace): + """ + Initialize configuration settings from command line arguments + """ + + return { + "vllm_api_key": args.vllm_api_key, + "vllm_embedding_endpoint": args.vllm_embedding_endpoint, + "vllm_chat_endpoint": args.vllm_chat_endpoint, + "uri": args.uri, + "embedding_model": args.embedding_model, + "chat_model": args.chat_model, + "url": args.url, + "chunk_size": args.chunk_size, + "chunk_overlap": args.chunk_overlap, + "top_k": args.top_k + } + + +def main(): + # Parse command line arguments + args = get_parser().parse_args() + + # Initialize configuration + config = init_config(args) + + # Load and split documents + documents = load_and_split_documents(config) + + # Initialize vector store and retriever + vectorstore = init_vectorstore(config, documents) + retriever = vectorstore.as_retriever(search_kwargs={"k": config["top_k"]}) + + # Initialize llm and prompt + llm = init_llm(config) + prompt = get_qa_prompt() + + # Set up QA chain + qa_chain = create_qa_chain(retriever, llm, prompt) + + # Interactive mode + if args.interactive: + print("\nWelcome to Interactive Q&A System!") + print("Enter 'q' or 'quit' to exit.") + + while True: + question = input("\nPlease enter your question: ") + if question.lower() in ['q', 'quit']: + print("\nThank you for using! Goodbye!") + break + + output = qa_chain.invoke(question) + print(output) + else: + # Default single question mode + question = ("How to install vLLM?") + output = qa_chain.invoke(question) + print("-" * 50) + print(output) + print("-" * 50) + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/retrieval_augmented_generation_with_llamaindex.py b/examples/online_serving/retrieval_augmented_generation_with_llamaindex.py new file mode 100644 index 0000000000000000000000000000000000000000..a8f76dfe4c697e728af72add47d92eb99a98c55d --- /dev/null +++ b/examples/online_serving/retrieval_augmented_generation_with_llamaindex.py @@ -0,0 +1,217 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +RAG (Retrieval Augmented Generation) Implementation with LlamaIndex +================================================================ + +This script demonstrates a RAG system using: +- LlamaIndex: For document indexing and retrieval +- Milvus: As vector store backend +- vLLM: For embedding and text generation + +Features: +1. Document Loading & Processing +2. Embedding & Storage +3. Query Processing + +Requirements: +1. Install dependencies: +pip install llama-index llama-index-readers-web \ + llama-index-llms-openai-like \ + llama-index-embeddings-openai-like \ + llama-index-vector-stores-milvus \ + +2. Start services: + # Start embedding service (port 8000) + vllm serve ssmits/Qwen2-7B-Instruct-embed-base + + # Start chat service (port 8001) + vllm serve qwen/Qwen1.5-0.5B-Chat --port 8001 + +Usage: + python retrieval_augmented_generation_with_llamaindex.py + +Notes: + - Ensure both vLLM services are running before executing + - Default ports: 8000 (embedding), 8001 (chat) + - First run may take time to download models +""" +import argparse +from argparse import Namespace +from typing import Any + +from llama_index.core import Settings, StorageContext, VectorStoreIndex +from llama_index.core.node_parser import SentenceSplitter +from llama_index.embeddings.openai_like import OpenAILikeEmbedding +from llama_index.llms.openai_like import OpenAILike +from llama_index.readers.web import SimpleWebPageReader +from llama_index.vector_stores.milvus import MilvusVectorStore + + +def init_config(args: Namespace): + """Initialize configuration with command line arguments""" + return { + "url": args.url, + "embedding_model": args.embedding_model, + "chat_model": args.chat_model, + "vllm_api_key": args.vllm_api_key, + "embedding_endpoint": args.embedding_endpoint, + "chat_endpoint": args.chat_endpoint, + "db_path": args.db_path, + "chunk_size": args.chunk_size, + "chunk_overlap": args.chunk_overlap, + "top_k": args.top_k + } + + +def load_documents(url: str) -> list: + """Load and process web documents""" + return SimpleWebPageReader(html_to_text=True).load_data([url]) + + +def setup_models(config: dict[str, Any]): + """Configure embedding and chat models""" + Settings.embed_model = OpenAILikeEmbedding( + api_base=config["embedding_endpoint"], + api_key=config["vllm_api_key"], + model_name=config["embedding_model"], + ) + + Settings.llm = OpenAILike( + model=config["chat_model"], + api_key=config["vllm_api_key"], + api_base=config["chat_endpoint"], + context_window=128000, + is_chat_model=True, + is_function_calling_model=False, + ) + + Settings.transformations = [ + SentenceSplitter( + chunk_size=config["chunk_size"], + chunk_overlap=config["chunk_overlap"], + ) + ] + + +def setup_vector_store(db_path: str) -> MilvusVectorStore: + """Initialize vector store""" + sample_emb = Settings.embed_model.get_text_embedding("test") + print(f"Embedding dimension: {len(sample_emb)}") + return MilvusVectorStore(uri=db_path, dim=len(sample_emb), overwrite=True) + + +def create_index(documents: list, vector_store: MilvusVectorStore): + """Create document index""" + storage_context = StorageContext.from_defaults(vector_store=vector_store) + return VectorStoreIndex.from_documents( + documents, + storage_context=storage_context, + ) + + +def query_document(index: VectorStoreIndex, question: str, top_k: int): + """Query document with given question""" + query_engine = index.as_query_engine(similarity_top_k=top_k) + return query_engine.query(question) + + +def get_parser() -> argparse.ArgumentParser: + """Parse command line arguments""" + parser = argparse.ArgumentParser( + description='RAG with vLLM and LlamaIndex') + + # Add command line arguments + parser.add_argument( + '--url', + default=("https://docs.vllm.ai/en/latest/getting_started/" + "quickstart.html"), + help='URL of the document to process') + parser.add_argument('--embedding-model', + default="ssmits/Qwen2-7B-Instruct-embed-base", + help='Model name for embeddings') + parser.add_argument('--chat-model', + default="qwen/Qwen1.5-0.5B-Chat", + help='Model name for chat') + parser.add_argument('--vllm-api-key', + default="EMPTY", + help='API key for vLLM compatible services') + parser.add_argument('--embedding-endpoint', + default="http://localhost:8000/v1", + help='Base URL for embedding service') + parser.add_argument('--chat-endpoint', + default="http://localhost:8001/v1", + help='Base URL for chat service') + parser.add_argument('--db-path', + default="./milvus_demo.db", + help='Path to Milvus database') + parser.add_argument('-i', + '--interactive', + action='store_true', + help='Enable interactive Q&A mode') + parser.add_argument('-c', + '--chunk-size', + type=int, + default=1000, + help='Chunk size for document splitting') + parser.add_argument('-o', + '--chunk-overlap', + type=int, + default=200, + help='Chunk overlap for document splitting') + parser.add_argument('-k', + '--top-k', + type=int, + default=3, + help='Number of top results to retrieve') + + return parser + + +def main(): + # Parse command line arguments + args = get_parser().parse_args() + + # Initialize configuration + config = init_config(args) + + # Load documents + documents = load_documents(config["url"]) + + # Setup models + setup_models(config) + + # Setup vector store + vector_store = setup_vector_store(config["db_path"]) + + # Create index + index = create_index(documents, vector_store) + + if args.interactive: + print("\nEntering interactive mode. Type 'quit' to exit.") + while True: + # Get user question + question = input("\nEnter your question: ") + + # Check for exit command + if question.lower() in ['quit', 'exit', 'q']: + print("Exiting interactive mode...") + break + + # Get and print response + print("\n" + "-" * 50) + print("Response:\n") + response = query_document(index, question, config["top_k"]) + print(response) + print("-" * 50) + else: + # Single query mode + question = "How to install vLLM?" + response = query_document(index, question, config["top_k"]) + print("-" * 50) + print("Response:\n") + print(response) + print("-" * 50) + + +if __name__ == "__main__": + main() diff --git a/examples/online_serving/streamlit_openai_chatbot_webserver.py b/examples/online_serving/streamlit_openai_chatbot_webserver.py new file mode 100644 index 0000000000000000000000000000000000000000..d8a0f211d44d5045e4f0d6ede48a22f52c539600 --- /dev/null +++ b/examples/online_serving/streamlit_openai_chatbot_webserver.py @@ -0,0 +1,185 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +vLLM Chat Assistant - A Streamlit Web Interface + +A streamlined chat interface that quickly integrates +with vLLM API server. + +Features: +- Multiple chat sessions management +- Streaming response display +- Configurable API endpoint +- Real-time chat history + +Requirements: + pip install streamlit openai + +Usage: + # Start the app with default settings + streamlit run streamlit_openai_chatbot_webserver.py + + # Start with custom vLLM API endpoint + VLLM_API_BASE="http://your-server:8000/v1" \ + streamlit run streamlit_openai_chatbot_webserver.py + + # Enable debug mode + streamlit run streamlit_openai_chatbot_webserver.py \ + --logger.level=debug +""" +import os +from datetime import datetime + +import streamlit as st +from openai import OpenAI + +# Get command line arguments from environment variables +openai_api_key = os.getenv('VLLM_API_KEY', "EMPTY") +openai_api_base = os.getenv('VLLM_API_BASE', "http://localhost:8000/v1") + +# Initialize session states for managing chat sessions +if "sessions" not in st.session_state: + st.session_state.sessions = {} + +if "current_session" not in st.session_state: + st.session_state.current_session = None + +if "messages" not in st.session_state: + st.session_state.messages = [] + +if "active_session" not in st.session_state: + st.session_state.active_session = None + +# Initialize session state for API base URL +if "api_base_url" not in st.session_state: + st.session_state.api_base_url = openai_api_base + + +def create_new_chat_session(): + """Create a new chat session with timestamp as ID""" + session_id = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + st.session_state.sessions[session_id] = [] + st.session_state.current_session = session_id + st.session_state.active_session = session_id + st.session_state.messages = [] + + +def switch_to_chat_session(session_id): + """Switch to a different chat session""" + st.session_state.current_session = session_id + st.session_state.active_session = session_id + st.session_state.messages = st.session_state.sessions[session_id] + + +def get_llm_response(messages, model): + """Get streaming response from llm + + Args: + messages: List of message dictionaries + model: Name of model + + Returns: + Streaming response object or error message string + """ + try: + response = client.chat.completions.create(model=model, + messages=messages, + stream=True) + return response + except Exception as e: + st.error(f"Error details: {str(e)}") + return f"Error: {str(e)}" + + +# Sidebar - API Settings first +st.sidebar.title("API Settings") +new_api_base = st.sidebar.text_input("API Base URL:", + value=st.session_state.api_base_url) +if new_api_base != st.session_state.api_base_url: + st.session_state.api_base_url = new_api_base + st.rerun() + +st.sidebar.divider() + +# Sidebar - Session Management +st.sidebar.title("Chat Sessions") +if st.sidebar.button("New Session"): + create_new_chat_session() + +# Display all sessions in reverse chronological order +for session_id in sorted(st.session_state.sessions.keys(), reverse=True): + # Mark the active session with a pinned button + if session_id == st.session_state.active_session: + st.sidebar.button(f"📍 {session_id}", + key=session_id, + type="primary", + on_click=switch_to_chat_session, + args=(session_id, )) + else: + st.sidebar.button(f"Session {session_id}", + key=session_id, + on_click=switch_to_chat_session, + args=(session_id, )) + +# Main interface +st.title("vLLM Chat Assistant") + +# Initialize OpenAI client with API settings +client = OpenAI(api_key=openai_api_key, base_url=st.session_state.api_base_url) + +# Get and display current model id +models = client.models.list() +model = models.data[0].id +st.markdown(f"**Model**: {model}") + +# Initialize first session if none exists +if st.session_state.current_session is None: + create_new_chat_session() + st.session_state.active_session = st.session_state.current_session + +# Display chat history for current session +for message in st.session_state.messages: + with st.chat_message(message["role"]): + st.write(message["content"]) + +# Handle user input and generate llm response +if prompt := st.chat_input("Type your message here..."): + # Save user message to session + st.session_state.messages.append({"role": "user", "content": prompt}) + st.session_state.sessions[ + st.session_state.current_session] = st.session_state.messages + + # Display user message + with st.chat_message("user"): + st.write(prompt) + + # Prepare messages for llm + messages_for_llm = [{ + "role": m["role"], + "content": m["content"] + } for m in st.session_state.messages] + + # Generate and display llm response + with st.chat_message("assistant"): + message_placeholder = st.empty() + full_response = "" + + # Get streaming response from llm + response = get_llm_response(messages_for_llm, model) + if isinstance(response, str): + message_placeholder.markdown(response) + full_response = response + else: + for chunk in response: + if hasattr(chunk.choices[0].delta, "content"): + content = chunk.choices[0].delta.content + if content: + full_response += content + message_placeholder.markdown(full_response + "▌") + + message_placeholder.markdown(full_response) + + # Save llm response to session history + st.session_state.messages.append({ + "role": "assistant", + "content": full_response + }) diff --git a/examples/online_serving/utils.py b/examples/online_serving/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4826e8e20528284b0525c8345e0cee3f1f1777d2 --- /dev/null +++ b/examples/online_serving/utils.py @@ -0,0 +1,25 @@ +# SPDX-License-Identifier: Apache-2.0 +from openai import APIConnectionError, OpenAI +from openai.pagination import SyncPage +from openai.types.model import Model + + +def get_first_model(client: OpenAI) -> str: + """ + Get the first model from the vLLM server. + """ + try: + models: SyncPage[Model] = client.models.list() + except APIConnectionError as e: + raise RuntimeError( + "Failed to get the list of models from the vLLM server at " + f"{client.base_url} with API key {client.api_key}. Check\n" + "1. the server is running\n" + "2. the server URL is correct\n" + "3. the API key is correct") from e + + if len(models.data) == 0: + raise RuntimeError( + f"No models found on the vLLM server at {client.base_url}") + + return models.data[0].id diff --git a/examples/template_florence2.jinja b/examples/template_florence2.jinja deleted file mode 100644 index d257aed6a85b05537d262284047f759dde2c9db0..0000000000000000000000000000000000000000 --- a/examples/template_florence2.jinja +++ /dev/null @@ -1,7 +0,0 @@ -{%- for message in messages -%} - {%- if message['role'] == 'user' -%} - {{- message['content'] -}} - {%- elif message['role'] == 'assistant' -%} - {{- message['content'] -}} - {%- endif -%} -{%- endfor -%} diff --git a/examples/template_llava.jinja b/examples/template_llava.jinja deleted file mode 100644 index 6a902ee167725d08bca2ea6dacdfdcea9b17ce96..0000000000000000000000000000000000000000 --- a/examples/template_llava.jinja +++ /dev/null @@ -1,23 +0,0 @@ -{%- if messages[0]['role'] == 'system' -%} - {%- set system_message = messages[0]['content'] -%} - {%- set messages = messages[1:] -%} -{%- else -%} - {% set system_message = '' -%} -{%- endif -%} - -{{ bos_token + system_message }} -{%- for message in messages -%} - {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%} - {{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }} - {%- endif -%} - - {%- if message['role'] == 'user' -%} - {{ 'USER: ' + message['content'] + '\n' }} - {%- elif message['role'] == 'assistant' -%} - {{ 'ASSISTANT: ' + message['content'] + eos_token + '\n' }} - {%- endif -%} -{%- endfor -%} - -{%- if add_generation_prompt -%} - {{ 'ASSISTANT:' }} -{% endif %} diff --git a/examples/tool_chat_template_deepseekv3.jinja b/examples/tool_chat_template_deepseekv3.jinja new file mode 100644 index 0000000000000000000000000000000000000000..36f3781439ede83e5f7932bb258c9449a6df2ceb --- /dev/null +++ b/examples/tool_chat_template_deepseekv3.jinja @@ -0,0 +1,96 @@ +{% if not add_generation_prompt is defined %} + {% set add_generation_prompt = false %} +{% endif %} + +{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='', is_first_sp=true, is_last_user=false) %} + +{%- for message in messages %} + {%- if message['role'] == 'system' %} + {%- if ns.is_first_sp %} + {% set ns.system_prompt = ns.system_prompt + message['content'] %} + {% set ns.is_first_sp = false %} + {%- else %} + {% set ns.system_prompt = ns.system_prompt + '\n\n' + message['content'] %} + {%- endif %} + {%- endif %} +{%- endfor %} + +{{ bos_token }} +{{ ns.system_prompt }} +{%- if tools %} + {{"\n\n# Tools\n\nYou may call one or more functions to assist with the user query." }} + {%- for tool in tools %} + {{- "\n" }} + {{- tool | tojson }} + {%- endfor %} + {{"\n\n\n"}} + + {{"For function call returns, you should first print <|tool▁calls▁begin|>"}} + + {{"For each function call, you should return object like:\n" }} + {{"<|tool▁call▁begin|>function<|tool▁sep|>\n```json\n\n```<|tool▁call▁end|>"}} + + {{"At the end of function call returns, you should print <|tool▁calls▁end|><|end▁of▁sentence|>"}} +{%- endif %} + +{%- for message in messages %} + {%- if message['role'] == 'user' %} + {%- set ns.is_tool = false -%} + {%- set ns.is_first = false -%} + {%- set ns.is_last_user = true -%} + {{'<|User|>' + message['content'] + '<|Assistant|>'}} + {%- endif %} + + {%- if message['role'] == 'assistant' and message['tool_calls'] is defined and message['tool_calls'] is not none %} + {%- set ns.is_last_user = false -%} + {%- if ns.is_tool %} + {{'<|tool▁outputs▁end|>'}} + {%- endif %} + {%- set ns.is_first = false %} + {%- set ns.is_tool = false -%} + {%- set ns.is_output_first = true %} + + {%- for tool in message['tool_calls'] %} + {%- if not ns.is_first %} + {%- if message['content'] is none %} + {{'<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments']|tojson + '\n' + '```' + '<|tool▁call▁end|>'}} + {%- else %} + {{message['content'] + '<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments']|tojson + '\n' + '```' + '<|tool▁call▁end|>'}} + {%- endif %} + {%- set ns.is_first = true -%} + {%- else %} + {{'\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments']|tojson + '\n' + '```' + '<|tool▁call▁end|>'}} + {%- endif %} + {%- endfor %} + {{'<|tool▁calls▁end|><|end▁of▁sentence|>'}} + {%- endif %} + {%- if message['role'] == 'assistant' and (message['tool_calls'] is not defined or message['tool_calls'] is none)%} + {%- set ns.is_last_user = false -%} + {%- if ns.is_tool %} + {{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}} + {%- set ns.is_tool = false -%} + {%- else %} + {% set content = message['content'] %} + {{content + '<|end▁of▁sentence|>'}} + {%- endif %} + {%- endif %} + + {%- if message['role'] == 'tool' %} + {%- set ns.is_last_user = false -%} + {%- set ns.is_tool = true -%} + {%- if ns.is_output_first %} + {{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}} + {%- set ns.is_output_first = false %} + {%- else %} + {{'\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}} + {%- endif %} + {%- endif %} +{%- endfor -%} + +{% if ns.is_tool %} + {{'<|tool▁outputs▁end|>'}} +{% endif %} + +{% if add_generation_prompt and not ns.is_last_user and not ns.is_tool %} + {{'<|Assistant|>'}} +{% endif %} diff --git a/examples/tool_chat_template_mistral3.jinja b/examples/tool_chat_template_mistral3.jinja new file mode 100644 index 0000000000000000000000000000000000000000..7c4249ec44c561bf5b5140488afab2da4dec22cf --- /dev/null +++ b/examples/tool_chat_template_mistral3.jinja @@ -0,0 +1,126 @@ +{%- set today = strftime_now("%Y-%m-%d") %} +{%- set default_system_message = "You are Mistral Small 3, a Large Language Model (LLM) created by Mistral AI, a French startup headquartered in Paris.\nYour knowledge base was last updated on 2023-10-01. The current date is " + today + ".\n\nWhen you're not sure about some information, you say that you don't have the information and don't make up anything.\nIf the user's question is not clear, ambiguous, or does not provide enough context for you to accurately answer the question, you do not try to answer it right away and you rather ask the user to clarify their request (e.g. \"What are some good restaurants around me?\" => \"Where are you?\" or \"When is the next flight to Tokyo\" => \"Where do you travel from?\")" %} + +{{- bos_token }} + +{%- if messages[0]['role'] == 'system' %} + {%- if messages[0]['content'] is string %} + {%- set system_message = messages[0]['content'] %} + {%- set loop_messages = messages[1:] %} + {%- else %} + {%- set system_message = messages[0]['content'][0]['text'] %} + {%- set loop_messages = messages[1:] %} + {%- endif %} +{%- else %} + {%- set system_message = default_system_message %} + {%- set loop_messages = messages %} +{%- endif %} +{%- if not tools is defined %} + {%- set tools = none %} +{%- elif tools is not none %} + {%- set parallel_tool_prompt = "You are a helpful assistant that can call tools. If you call one or more tools, format them in a single JSON array or objects, where each object is a tool call, not as separate objects outside of an array or multiple arrays. Use the format [{\"name\": tool call name, \"arguments\": tool call arguments}, additional tool calls] if you call more than one tool. If you call tools, do not attempt to interpret them or otherwise provide a response until you receive a tool call result that you can interpret for the user." %} + {%- if system_message is defined %} + {%- set system_message = parallel_tool_prompt + "\n\n" + system_message %} + {%- else %} + {%- set system_message = parallel_tool_prompt %} + {%- endif %} +{%- endif %} +{{- '[SYSTEM_PROMPT]' + system_message + '[/SYSTEM_PROMPT]' }} + +{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %} + +{%- set filtered_messages = [] %} +{%- for message in loop_messages %} + {%- if message["role"] not in ["tool", "tool_results"] and not message.get("tool_calls") %} + {%- set filtered_messages = filtered_messages + [message] %} + {%- endif %} +{%- endfor %} + +{%- for message in filtered_messages %} + {%- if (message["role"] == "user") != (loop.index0 % 2 == 0) %} + {{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }} + {%- endif %} +{%- endfor %} + +{%- for message in loop_messages %} + {%- if message["role"] == "user" %} + {%- if tools is not none and (message == user_messages[-1]) %} + {{- "[AVAILABLE_TOOLS] [" }} + {%- for tool in tools %} + {%- set tool = tool.function %} + {{- '{"type": "function", "function": {' }} + {%- for key, val in tool.items() if key != "return" %} + {%- if val is string %} + {{- '"' + key + '": "' + val + '"' }} + {%- else %} + {{- '"' + key + '": ' + val|tojson }} + {%- endif %} + {%- if not loop.last %} + {{- ", " }} + {%- endif %} + {%- endfor %} + {{- "}}" }} + {%- if not loop.last %} + {{- ", " }} + {%- else %} + {{- "]" }} + {%- endif %} + {%- endfor %} + {{- "[/AVAILABLE_TOOLS]" }} + {%- endif %} + {%- if message['content'] is string %} + {{- '[INST]' + message['content'] + '[/INST]' }} + {%- else %} + {{- '[INST]' }} + {%- for block in message['content'] %} + {%- if block['type'] == 'text' %} + {{- block['text'] }} + {%- elif block['type'] == 'image' or block['type'] == 'image_url' %} + {{- '[IMG]' }} + {%- else %} + {{- raise_exception('Only text and image blocks are supported in message content!') }} + {%- endif %} + {%- endfor %} + {{- '[/INST]' }} + {%- endif %} + {%- elif message["role"] == "tool_calls" or message.tool_calls is defined %} + {%- if message.tool_calls is defined %} + {%- set tool_calls = message.tool_calls %} + {%- else %} + {%- set tool_calls = message.content %} + {%- endif %} + {{- "[TOOL_CALLS] [" }} + {%- for tool_call in tool_calls %} + {%- set out = tool_call.function|tojson %} + {{- out[:-1] }} + {%- if not tool_call.id is defined or tool_call.id|length < 9 %} + {{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (1)" + tool_call.id) }} + {%- endif %} + {{- ', "id": "' + tool_call.id[-9:] + '"}' }} + {%- if not loop.last %} + {{- ", " }} + {%- else %} + {{- "]" + eos_token }} + {%- endif %} + {%- endfor %} + {%- elif message['role'] == 'assistant' %} + {%- if message['content'] is string %} + {{- message['content'] + eos_token }} + {%- else %} + {{- message['content'][0]['text'] + eos_token }} + {%- endif %} + {%- elif message["role"] == "tool_results" or message["role"] == "tool" %} + {%- if message.content is defined and message.content.content is defined %} + {%- set content = message.content.content %} + {%- else %} + {%- set content = message.content %} + {%- endif %} + {{- '[TOOL_RESULTS] {"content": ' + content|string + ", " }} + {%- if not message.tool_call_id is defined or message.tool_call_id|length < 9 %} + {{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (2)" + message.tool_call_id) }} + {%- endif %} + {{- '"call_id": "' + message.tool_call_id[-9:] + '"}[/TOOL_RESULTS]' }} + {%- else %} + {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }} + {%- endif %} +{%- endfor %} diff --git a/pyproject.toml b/pyproject.toml index b5f1039b44daccaf84389cb1d0ba766ac0eb9a71..0b803a26b658116d75a8ffcb04bc5e422b543431 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,10 +3,10 @@ requires = [ "cmake>=3.26", "ninja", - "packaging", - "setuptools>=61", + "packaging>=24.2", + "setuptools>=77.0.3,<80.0.0", "setuptools-scm>=8.0", - "torch == 2.6.0", + "torch == 2.7.0", "wheel", "jinja2", ] @@ -41,6 +41,9 @@ Slack="http://slack.vllm.ai/" [project.scripts] vllm = "vllm.entrypoints.cli.main:main" +[project.entry-points."vllm.general_plugins"] +lora_filesystem_resolver = "vllm.plugins.lora_resolvers.filesystem_resolver:register_filesystem_resolver" + [tool.setuptools_scm] # no extra settings needed, presence enables setuptools-scm @@ -50,6 +53,8 @@ include = ["vllm*"] [tool.yapfignore] ignore_patterns = [ + ".buildkite/**", + "benchmarks/**", "build/**", ] @@ -66,26 +71,15 @@ exclude = [ "vllm/third_party/**" = ["ALL"] "vllm/version.py" = ["F401"] "vllm/_version.py" = ["ALL"] -# Python 3.8 typing. TODO: Remove these excludes after v1.0.0 -"vllm/adapter_commons/**/*.py" = ["UP006", "UP035"] +# Python 3.8 typing - skip V0 code "vllm/attention/**/*.py" = ["UP006", "UP035"] -"vllm/compilation/**/*.py" = ["UP006", "UP035"] "vllm/core/**/*.py" = ["UP006", "UP035"] -"vllm/device_allocator/**/*.py" = ["UP006", "UP035"] -"vllm/distributed/**/*.py" = ["UP006", "UP035"] "vllm/engine/**/*.py" = ["UP006", "UP035"] "vllm/executor/**/*.py" = ["UP006", "UP035"] -"vllm/lora/**/*.py" = ["UP006", "UP035"] -"vllm/model_executor/**/*.py" = ["UP006", "UP035"] -"vllm/platforms/**/*.py" = ["UP006", "UP035"] -"vllm/plugins/**/*.py" = ["UP006", "UP035"] -"vllm/profiler/**/*.py" = ["UP006", "UP035"] "vllm/prompt_adapter/**/*.py" = ["UP006", "UP035"] "vllm/spec_decode/**/*.py" = ["UP006", "UP035"] -"vllm/transformers_utils/**/*.py" = ["UP006", "UP035"] -"vllm/triton_utils/**/*.py" = ["UP006", "UP035"] -"vllm/vllm_flash_attn/**/*.py" = ["UP006", "UP035"] "vllm/worker/**/*.py" = ["UP006", "UP035"] +# Python 3.8 typing - skip utils for ROCm "vllm/utils.py" = ["UP006", "UP035"] [tool.ruff.lint] @@ -102,6 +96,7 @@ select = [ "SIM", # isort # "I", + # flake8-logging-format "G", ] ignore = [ @@ -150,6 +145,10 @@ ignore-words-list = "dout, te, indicies, subtile, ElementE" skip = "tests/models/fixtures/*,tests/prompts/*,benchmarks/sonnet.txt,tests/lora/data/*,build/*,vllm/third_party/*" [tool.isort] +skip_glob = [ + ".buildkite/*", + "benchmarks/*", +] use_parentheses = true skip_gitignore = true @@ -158,7 +157,6 @@ markers = [ "skip_global_cleanup", "core_model: enable this model test in each PR instead of only nightly", "cpu_model: enable this model test in CPU tests", - "quant_model: run this model test under Quantized category", "split: run this test as part of a split", "distributed: run this test only in distributed GPU tests", "skip_v1: do not run this test with v1", @@ -171,3 +169,9 @@ plugins.md013.enabled = false # line-length plugins.md041.enabled = false # first-line-h1 plugins.md033.enabled = false # inline-html plugins.md024.allow_different_nesting = true # no-duplicate-headers + +[tool.ty] +respect-ignore-files = true + +[tool.ty.environment] +python = "./.venv" diff --git a/requirements/build.txt b/requirements/build.txt index 13d643bcaff104f5c5443fd1f92e247d8ffe333e..5edc593b9270094916b9620fbca3813fe6d9b37a 100644 --- a/requirements/build.txt +++ b/requirements/build.txt @@ -1,9 +1,9 @@ # Should be mirrored in pyproject.toml cmake>=3.26 ninja -packaging -setuptools>=61 +packaging>=24.2 +setuptools>=77.0.3,<80.0.0 setuptools-scm>=8 -torch==2.6.0 +torch==2.7.0 wheel jinja2>=3.1.6 diff --git a/requirements/common.txt b/requirements/common.txt index 33c4c3219f159002c57327a30ab0d7114e172c45..80f90e60007e0654731bfbc23d93a6cd424a3b19 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -19,31 +19,31 @@ pillow # Required for image processing prometheus-fastapi-instrumentator >= 7.0.0 tiktoken >= 0.6.0 # Required for DBRX tokenizer lm-format-enforcer >= 0.10.11, < 0.11 -llguidance >= 0.7.9, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64" +llguidance >= 0.7.11, < 0.8.0; platform_machine == "x86_64" or platform_machine == "arm64" or platform_machine == "aarch64" outlines == 0.1.11 lark == 1.2.2 -xgrammar == 0.1.18; platform_machine == "x86_64" or platform_machine == "aarch64" +xgrammar == 0.1.19; platform_machine == "x86_64" or platform_machine == "aarch64" typing_extensions >= 4.10 filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317 partial-json-parser # used for parsing partial JSON outputs pyzmq >= 25.0.0 msgspec gguf >= 0.13.0 -importlib_metadata +importlib_metadata; python_version < '3.10' mistral_common[opencv] >= 1.5.4 opencv-python-headless >= 4.11.0 # required for video IO pyyaml six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12 -setuptools>=74.1.1; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12 +setuptools>=77.0.3,<80; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12 einops # Required for Qwen2-VL. -compressed-tensors == 0.9.3 # required for compressed-tensors +compressed-tensors == 0.9.4 # required for compressed-tensors depyf==0.18.0 # required for profiling and debugging with compilation config cloudpickle # allows pickling lambda functions in model_executor/models/registry.py watchfiles # required for http server to monitor the updates of TLS files python-json-logger # Used by logging as per examples/other/logging_configuration.md scipy # Required for phi-4-multimodal-instruct ninja # Required for xgrammar, rocm, tpu, xpu -opentelemetry-sdk>=1.26.0,<1.27.0 # vllm.tracing -opentelemetry-api>=1.26.0,<1.27.0 # vllm.tracing -opentelemetry-exporter-otlp>=1.26.0,<1.27.0 # vllm.tracing -opentelemetry-semantic-conventions-ai>=0.4.1,<0.5.0 # vllm.tracing +opentelemetry-sdk>=1.26.0 # vllm.tracing +opentelemetry-api>=1.26.0 # vllm.tracing +opentelemetry-exporter-otlp>=1.26.0 # vllm.tracing +opentelemetry-semantic-conventions-ai>=0.4.1 # vllm.tracing diff --git a/requirements/cpu.txt b/requirements/cpu.txt index 69f732c2417a1b4270296652b391e3a3c7c2f4ff..752931158a056f8efc72ab9184b427ac573ffba7 100644 --- a/requirements/cpu.txt +++ b/requirements/cpu.txt @@ -2,18 +2,19 @@ -r common.txt # Dependencies for CPUs -torch==2.6.0+cpu; platform_machine == "x86_64" -torch==2.6.0; platform_system == "Darwin" -torch==2.6.0; platform_machine == "ppc64le" or platform_machine == "aarch64" +--extra-index-url https://download.pytorch.org/whl/cpu +torch==2.7.0+cpu; platform_machine == "x86_64" +torch==2.7.0; platform_system == "Darwin" +torch==2.7.0; platform_machine == "ppc64le" or platform_machine == "aarch64" torch==2.7.0.dev20250304; platform_machine == "s390x" # required for the image processor of minicpm-o-2_6, this must be updated alongside torch torchaudio; platform_machine != "ppc64le" and platform_machine != "s390x" -torchaudio==2.6.0; platform_machine == "ppc64le" +torchaudio==2.7.0; platform_machine == "ppc64le" # required for the image processor of phi3v, this must be updated alongside torch torchvision; platform_machine != "ppc64le" and platform_machine != "s390x" -torchvision==0.21.0; platform_machine == "ppc64le" +torchvision==0.22.0; platform_machine == "ppc64le" datasets # for benchmark scripts # cpu cannot use triton 3.3.0 diff --git a/requirements/cuda.txt b/requirements/cuda.txt index cdc6ee75afbcda2da7edb02461d8b1fa0d26d2f3..a71d9728f38ad9c7c61d8f80e6d6d14580b6e56d 100644 --- a/requirements/cuda.txt +++ b/requirements/cuda.txt @@ -6,8 +6,9 @@ numba == 0.61.2; python_version > '3.9' # Dependencies for NVIDIA GPUs ray[cgraph]>=2.43.0, !=2.44.* # Ray Compiled Graph, required for pipeline parallelism in V1. -torch==2.6.0 -torchaudio==2.6.0 +torch==2.7.0 +torchaudio==2.7.0 # These must be updated alongside torch -torchvision==0.21.0 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version -xformers==0.0.29.post2; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch 2.6.0 +torchvision==0.22.0 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version +# https://github.com/facebookresearch/xformers/releases/tag/v0.0.30 +xformers==0.0.30; platform_system == 'Linux' and platform_machine == 'x86_64' # Requires PyTorch >= 2.7 diff --git a/requirements/docs.txt b/requirements/docs.txt index d84fd633ce108946a263d028a3bb7df76213b2d4..9c267edaceaf1d8051b638b1be4732fdfa111611 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -1,27 +1,19 @@ -sphinx==6.2.1 -sphinx-argparse==0.4.0 -sphinx-book-theme==1.0.1 +sphinx==7.4.7 +sphinx-argparse==0.5.2 +sphinx-book-theme==1.1.4 sphinx-copybutton==0.5.2 sphinx-design==0.6.1 sphinx-togglebutton==0.3.2 -myst-parser==3.0.1 +myst-parser==3.0.1 # `myst-parser==4.0.1` breaks inline code in titles msgspec -cloudpickle +snowballstemmer<3 # https://github.com/snowballstem/snowball/issues/229 commonmark # Required by sphinx-argparse when using :markdownhelp: +# Custom autodoc2 is necessary for faster docstring processing +# see: https://github.com/sphinx-extensions2/sphinx-autodoc2/issues/33#issuecomment-2856386035 +git+https://github.com/hmellor/sphinx-autodoc2.git # sphinx-autodoc2==0.5.0 + # packages to install to build the documentation cachetools -pydantic >= 2.8 -f https://download.pytorch.org/whl/cpu -torch -py-cpuinfo -transformers -mistral_common >= 1.5.4 -aiohttp -starlette -scipy -openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args -fastapi # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args -partial-json-parser # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args -requests -zmq +torch \ No newline at end of file diff --git a/requirements/hpu.txt b/requirements/hpu.txt index 5ac58bc02892e7aa3c9ac006b7afb1b51440f52d..a88777268a342c1152ba9db9c5c2d4a358dcef17 100644 --- a/requirements/hpu.txt +++ b/requirements/hpu.txt @@ -7,6 +7,6 @@ triton==3.1.0 pandas numpy==1.26.4 tabulate -setuptools>=61 +setuptools>=77.0.3,<80.0.0 setuptools-scm>=8 vllm-hpu-extension @ git+https://github.com/HabanaAI/vllm-hpu-extension.git@f1f6624 diff --git a/requirements/neuron.txt b/requirements/neuron.txt index 5f25bd0546e695b6f28ddee4409162bf3c703346..7df478eddde3fffccea18182b8e29882c88812e7 100644 --- a/requirements/neuron.txt +++ b/requirements/neuron.txt @@ -2,5 +2,8 @@ -r common.txt # Dependencies for Neuron devices +packaging>=24.2 +setuptools>=77.0.3,<80.0.0 torch-neuronx >= 2.5.0 -neuronx-cc +neuronx-cc>=2.0.0a0 +torchvision # Required for Llama3.2 multimodal image preprocessing diff --git a/requirements/nightly_torch_test.txt b/requirements/nightly_torch_test.txt index 20372a9b2ef168128702ec70e6be05c52f18875f..3aebcaa623c03a14f74384d4c2a3cd2285baabac 100644 --- a/requirements/nightly_torch_test.txt +++ b/requirements/nightly_torch_test.txt @@ -8,12 +8,10 @@ pytest-rerunfailures pytest-shard pytest-timeout - librosa # required by audio tests in entrypoints/openai sentence-transformers numba == 0.61.2; python_version > '3.9' # testing utils -awscli boto3 botocore datasets @@ -24,5 +22,20 @@ runai-model-streamer-s3==0.11.0 tensorizer>=2.9.0 lm-eval==0.4.8 buildkite-test-collector==0.1.9 - lm-eval[api]==0.4.8 # required for model evaluation test + +# required for quantization test +bitsandbytes>=0.45.3 + +# required for minicpmo_26 test +vector_quantize_pytorch +vocos + +# required for Basic Models Test +blobfile # required for kimi-vl test +matplotlib # required for qwen-vl test + +# required for Multi-Modal Models Test (Standard) +num2words # required for smolvlm test +pqdm +timm # required for internvl test diff --git a/requirements/rocm-build.txt b/requirements/rocm-build.txt index 05de4ff168453d055180605327561e266e685f18..981b90632c182b5ad67e22f189c46d96896fb52f 100644 --- a/requirements/rocm-build.txt +++ b/requirements/rocm-build.txt @@ -2,14 +2,14 @@ -r common.txt --extra-index-url https://download.pytorch.org/whl/rocm6.2.4 -torch==2.6.0 -torchvision==0.21.0 -torchaudio==2.6.0 +torch==2.7.0 +torchvision==0.22.0 +torchaudio==2.7.0 triton==3.2 cmake>=3.26,<4 -packaging -setuptools>=61 +packaging>=24.2 +setuptools>=77.0.3,<80.0.0 setuptools-scm>=8 wheel jinja2>=3.1.6 diff --git a/requirements/rocm-test.txt b/requirements/rocm-test.txt index 52fbf787f1dff99db97bacf0f176a4ce5cd6a56c..25f950a99eceb73437c933f7c24b3f07f4f2a87c 100644 --- a/requirements/rocm-test.txt +++ b/requirements/rocm-test.txt @@ -1,3 +1,5 @@ +# Common dependencies +-r common.txt # entrypoints test # librosa==0.10.2.post1 # required by audio tests in entrypoints/openai @@ -20,4 +22,10 @@ decord==0.6.0 #sentence-transformers # required by entrypoints/openai/test_score.py sentence-transformers==3.4.1 +# Basic Models Test +matplotlib==3.10.3 + +# Multi-Modal Models Test (Extended) 3 +blobfile==3.0.0 + diff --git a/requirements/rocm.txt b/requirements/rocm.txt index 4df92aab3749e179ad84ee039517dfa04e990247..8a84f2ff1ed01c0bc01e8d84677a682bbf285a05 100644 --- a/requirements/rocm.txt +++ b/requirements/rocm.txt @@ -5,11 +5,10 @@ numba == 0.60.0; python_version == '3.9' # v0.61 doesn't support Python 3.9. Req numba == 0.61.2; python_version > '3.9' # Dependencies for AMD GPUs -awscli boto3 botocore datasets -ray >= 2.10.0 +ray>=2.10.0,<2.45.0 peft pytest-asyncio tensorizer>=2.9.0 diff --git a/requirements/test.in b/requirements/test.in index c5d2c4cd4c30f87243fb4cacebdde11529a96179..cdc7c563f087916c07b125da1be84acd224b7d88 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -8,7 +8,6 @@ pytest-shard pytest-timeout # testing utils -awscli backoff # required for phi4mm test blobfile # required for kimi-vl test einops # required for MPT, qwen-vl and Mamba @@ -23,9 +22,9 @@ sentence-transformers # required for embedding tests soundfile # required for audio tests jiwer # required for audio tests timm # required for internvl test -torch==2.6.0 -torchaudio==2.6.0 -torchvision==0.21.0 +torch==2.7.0 +torchaudio==2.7.0 +torchvision==0.22.0 transformers_stream_generator # required for qwen-vl test mamba_ssm # required for plamo2 test matplotlib # required for qwen-vl test diff --git a/requirements/test.txt b/requirements/test.txt index 9642a5bfe68d421fdcd4932ebb0d790f15e9f56a..9a15d9a0d82403a7de88500021a08bde52477635 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -1,5 +1,5 @@ # This file was autogenerated by uv via the following command: -# uv pip compile requirements/test.in -o requirements/test.txt +# uv pip compile requirements/test.in -o requirements/test.txt --index-strategy unsafe-best-match --torch-backend cu128 absl-py==2.1.0 # via rouge-score accelerate==1.0.1 @@ -37,8 +37,6 @@ attrs==24.2.0 # referencing audioread==3.0.1 # via librosa -awscli==1.35.23 - # via -r requirements/test.in backoff==2.2.1 # via # -r requirements/test.in @@ -53,7 +51,6 @@ boto3==1.35.57 # via tensorizer botocore==1.35.57 # via - # awscli # boto3 # s3transfer bounded-pool-executor==0.0.3 @@ -81,7 +78,6 @@ click==8.1.7 # typer colorama==0.4.6 # via - # awscli # sacrebleu # schemathesis # tqdm-multiprocess @@ -115,8 +111,6 @@ dnspython==2.7.0 # via email-validator docopt==0.6.2 # via num2words -docutils==0.16 - # via awscli einops==0.8.0 # via # -r requirements/test.in @@ -274,7 +268,7 @@ mamba-ssm==2.2.4 # via -r requirements/test.in markdown-it-py==3.0.0 # via rich -markupsafe==3.0.2 +markupsafe==3.0.1 # via # jinja2 # werkzeug @@ -355,45 +349,48 @@ numpy==1.26.4 # transformers # tritonclient # vocos -nvidia-cublas-cu12==12.4.5.8 +nvidia-cublas-cu12==12.8.3.14 # via # nvidia-cudnn-cu12 # nvidia-cusolver-cu12 # torch -nvidia-cuda-cupti-cu12==12.4.127 +nvidia-cuda-cupti-cu12==12.8.57 + # via torch +nvidia-cuda-nvrtc-cu12==12.8.61 # via torch -nvidia-cuda-nvrtc-cu12==12.4.127 +nvidia-cuda-runtime-cu12==12.8.57 # via torch -nvidia-cuda-runtime-cu12==12.4.127 +nvidia-cudnn-cu12==9.7.1.26 # via torch -nvidia-cudnn-cu12==9.1.0.70 +nvidia-cufft-cu12==11.3.3.41 # via torch -nvidia-cufft-cu12==11.2.1.3 +nvidia-cufile-cu12==1.13.0.11 # via torch -nvidia-curand-cu12==10.3.5.147 +nvidia-curand-cu12==10.3.9.55 # via torch -nvidia-cusolver-cu12==11.6.1.9 +nvidia-cusolver-cu12==11.7.2.55 # via torch -nvidia-cusparse-cu12==12.3.1.170 +nvidia-cusparse-cu12==12.5.7.53 # via # nvidia-cusolver-cu12 # torch -nvidia-cusparselt-cu12==0.6.2 +nvidia-cusparselt-cu12==0.6.3 # via torch -nvidia-nccl-cu12==2.21.5 +nvidia-nccl-cu12==2.26.2 # via torch -nvidia-nvjitlink-cu12==12.4.127 +nvidia-nvjitlink-cu12==12.8.61 # via + # nvidia-cufft-cu12 # nvidia-cusolver-cu12 # nvidia-cusparse-cu12 # torch -nvidia-nvtx-cu12==12.4.127 +nvidia-nvtx-cu12==12.8.55 # via torch opencv-python-headless==4.11.0.86 # via # -r requirements/test.in # mistral-common -packaging==24.1 +packaging==24.2 # via # accelerate # black @@ -469,8 +466,6 @@ pyarrow==18.0.0 # via # datasets # genai-perf -pyasn1==0.6.1 - # via rsa pybind11==2.13.6 # via lm-eval pycparser==2.22 @@ -534,7 +529,6 @@ pytz==2024.2 pyyaml==6.0.2 # via # accelerate - # awscli # datamodel-code-generator # datasets # genai-perf @@ -593,16 +587,12 @@ rpds-py==0.20.1 # via # jsonschema # referencing -rsa==4.7.2 - # via awscli runai-model-streamer==0.11.0 # via -r requirements/test.in runai-model-streamer-s3==0.11.0 # via -r requirements/test.in s3transfer==0.10.3 - # via - # awscli - # boto3 + # via boto3 sacrebleu==2.4.3 # via lm-eval safetensors==0.4.5 @@ -629,11 +619,12 @@ sentence-transformers==3.2.1 # via -r requirements/test.in sentencepiece==0.2.0 # via mistral-common -setuptools==75.8.0 +setuptools==77.0.3 # via # mamba-ssm # pytablewriter # torch + # triton shellingham==1.5.4 # via typer six==1.16.0 @@ -664,7 +655,7 @@ starlette-testclient==0.4.1 # via schemathesis statsmodels==0.14.4 # via genai-perf -sympy==1.13.1 +sympy==1.13.3 # via # einx # torch @@ -696,7 +687,7 @@ tomli==2.2.1 # via schemathesis tomli-w==1.2.0 # via schemathesis -torch==2.6.0 +torch==2.7.0+cu128 # via # -r requirements/test.in # accelerate @@ -714,12 +705,12 @@ torch==2.6.0 # torchvision # vector-quantize-pytorch # vocos -torchaudio==2.6.0 +torchaudio==2.7.0+cu128 # via # -r requirements/test.in # encodec # vocos -torchvision==0.21.0 +torchvision==0.22.0+cu128 # via # -r requirements/test.in # timm @@ -748,7 +739,7 @@ transformers==4.51.3 # transformers-stream-generator transformers-stream-generator==0.0.5 # via -r requirements/test.in -triton==3.2.0 +triton==3.3.0 # via torch tritonclient==2.51.0 # via diff --git a/requirements/tpu.txt b/requirements/tpu.txt index b63993ba1ee453776333bfde58328aba53722293..11501bc5d92f3e7fa490ddcab756c2ed14678a80 100644 --- a/requirements/tpu.txt +++ b/requirements/tpu.txt @@ -3,12 +3,13 @@ # Dependencies for TPU cmake>=3.26 -packaging +packaging>=24.2 setuptools-scm>=8 wheel jinja2>=3.1.6 ray[default] ray[data] +setuptools==78.1.0 # Install torch_xla --pre @@ -17,9 +18,9 @@ ray[data] --find-links https://storage.googleapis.com/libtpu-releases/index.html --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html -torch==2.8.0.dev20250408 -torchvision==0.22.0.dev20250408 -torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250408-cp39-cp39-linux_x86_64.whl ; python_version == "3.9" -torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250408-cp310-cp310-linux_x86_64.whl ; python_version == "3.10" -torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250408-cp311-cp311-linux_x86_64.whl ; python_version == "3.11" +torch==2.8.0.dev20250430 +torchvision==0.22.0.dev20250430 +torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250430-cp39-cp39-linux_x86_64.whl ; python_version == "3.9" +torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250430-cp310-cp310-linux_x86_64.whl ; python_version == "3.10" +torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250430-cp311-cp311-linux_x86_64.whl ; python_version == "3.11" diff --git a/requirements/xpu.txt b/requirements/xpu.txt index fa09004d0a9cb9b556d9a38cd8637652de982e7e..04c4d4ff85a0db4fdaa1aba650b9d73b4ac54394 100644 --- a/requirements/xpu.txt +++ b/requirements/xpu.txt @@ -3,14 +3,14 @@ ray>=2.9 cmake>=3.26 -packaging +packaging>=24.2 setuptools-scm>=8 -setuptools>=75.8.0 +setuptools>=77.0.3,<80.0.0 wheel jinja2>=3.1.6 datasets # for benchmark scripts -torch==2.6.0+xpu +torch==2.7.0+xpu torchaudio torchvision pytorch-triton-xpu @@ -18,6 +18,6 @@ pytorch-triton-xpu # Please refer xpu doc, we need manually install intel-extension-for-pytorch 2.6.10+xpu due to there are some conflict dependencies with torch 2.6.0+xpu # FIXME: This will be fix in ipex 2.7. just leave this here for awareness. -# intel-extension-for-pytorch==2.6.10+xpu -oneccl_bind_pt==2.6.0+xpu +intel-extension-for-pytorch==2.7.10+xpu +oneccl_bind_pt==2.7.0+xpu --extra-index-url=https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ diff --git a/setup.py b/setup.py index ec404d5acb82654f8cd9cca8a5c9629e9f6b8dfa..e67bdc2e87cbcc826eb0f895b2ade56feb4060e8 100755 --- a/setup.py +++ b/setup.py @@ -54,7 +54,7 @@ elif (sys.platform.startswith("linux") and torch.version.cuda is None # fallback to cpu VLLM_TARGET_DEVICE = "cpu" -MAIN_CUDA_VERSION = "12.4" +MAIN_CUDA_VERSION = "12.8" def is_sccache_available() -> bool: diff --git a/tests/async_engine/test_async_llm_engine.py b/tests/async_engine/test_async_llm_engine.py index 48e2e31e5db885d950ecd1a7bc1eaa528a27597f..b6f44871497c8bcd9ccb243c395fde4e02d74cd8 100644 --- a/tests/async_engine/test_async_llm_engine.py +++ b/tests/async_engine/test_async_llm_engine.py @@ -41,7 +41,7 @@ class MockEngine: self.abort_request_calls = 0 self.request_id = None # Ugly, remove dependency when possible - self.parallel_config = ParallelConfig(1, 1, False) + self.parallel_config = ParallelConfig() self.model_config = MockModelConfig() async def step_async(self, virtual_engine): diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index 1458f0893a93c65b3721896c4dd98fadb5617ccb..9f3b0e8ae079b630eb1e7b8eca1b9f1cc3dd0532 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -5,11 +5,13 @@ Run `pytest tests/basic_correctness/test_basic_correctness.py`. """ import os import weakref +from unittest.mock import Mock import pytest from vllm import LLM from vllm.platforms import current_platform +from vllm.v1.engine.llm_engine import LLMEngine as LLMEngineV1 from ..conftest import VllmRunner from ..models.utils import check_outputs_equal @@ -152,9 +154,44 @@ def test_models_distributed( with hf_runner(model, dtype=dtype) as hf_model: hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) - check_outputs_equal( - outputs_0_lst=hf_outputs, - outputs_1_lst=vllm_outputs, - name_0="hf", - name_1="vllm", - ) + check_outputs_equal( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) + + +def test_failed_model_execution(vllm_runner, monkeypatch) -> None: + + from vllm.envs import VLLM_USE_V1 + + if not VLLM_USE_V1: + pytest.skip("Skipping V0 test, dump input not supported") + + # Needed to mock an error in the same process + monkeypatch.setenv('VLLM_ENABLE_V1_MULTIPROCESSING', '0') + + with vllm_runner('facebook/opt-125m', enforce_eager=True) as vllm_model: + if isinstance(vllm_model.model.llm_engine, LLMEngineV1): + v1_test_failed_model_execution(vllm_model) + + +def v1_test_failed_model_execution(vllm_model): + + engine = vllm_model.model.llm_engine + mocked_execute_model = Mock( + side_effect=RuntimeError("Mocked Critical Error")) + engine.engine_core.engine_core.model_executor.execute_model =\ + mocked_execute_model + + with pytest.raises(RuntimeError) as exc_info: + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + vllm_model.generate_greedy(prompts, 200, use_tqdm=False) + assert isinstance(exc_info.value, RuntimeError) + assert "Mocked Critical Error" in str(exc_info.value) diff --git a/tests/compile/piecewise/test_full_cudagraph.py b/tests/compile/piecewise/test_full_cudagraph.py new file mode 100644 index 0000000000000000000000000000000000000000..a71a40cda73ea07634961b31d6a1a4e83dde5a88 --- /dev/null +++ b/tests/compile/piecewise/test_full_cudagraph.py @@ -0,0 +1,97 @@ +# SPDX-License-Identifier: Apache-2.0 +import contextlib +import os + +import pytest + +from vllm import LLM, SamplingParams +from vllm.config import CompilationConfig + +MODEL = "Qwen/Qwen2-1.5B-Instruct" + + +@contextlib.contextmanager +def temporary_environ(env_vars): + """ + Temporarily set environment variables and restore them afterward. + We have to do this vs monkeypatch because monkeypatch doesn't work + with "module" scoped fixtures. + """ + original_env = {k: os.environ.get(k) for k in env_vars} + try: + os.environ.update(env_vars) + yield + finally: + for k, v in original_env.items(): + if v is None: + os.environ.pop(k, None) + else: + os.environ[k] = v + + +@pytest.fixture(scope="module") +def full_cudagraph_llm(): + with temporary_environ({ + "VLLM_USE_V1": "1", + "VLLM_FLASH_ATTN_VERSION": "3" + }): + return LLM(model=MODEL, + gpu_memory_utilization=0.2, + compilation_config=CompilationConfig(full_cuda_graph=True)) + + +@pytest.fixture(scope="module") +def piecewise_llm(): + with temporary_environ({ + "VLLM_USE_V1": "1", + "VLLM_FLASH_ATTN_VERSION": "3" + }): + return LLM(model=MODEL, + gpu_memory_utilization=0.5, + compilation_config=CompilationConfig()) + + +def generate_text(llm: LLM, batch_size: int, max_tokens: int): + prompts = ["Hi my name is"] * batch_size + sampling_params = SamplingParams(temperature=0.0, + max_tokens=max_tokens, + top_p=0.95) + + return llm.generate(prompts, sampling_params) + + +@pytest.mark.parametrize(("batch_size", "max_tokens"), [(1, 10), (7, 10), + (16, 10), (25, 10), + (32, 10), (45, 10), + (64, 10), (8, 5), + (8, 20), (8, 200)]) +def test_full_cudagraph(batch_size, max_tokens, full_cudagraph_llm, + piecewise_llm): + """ + Load full cudagraph model and piecewise model once, and at the same time to + reuse them across various test cases. + + Test various batch sizes and max_tokens to ensure that the full cudagraph + compilation works for padded cases too. + """ + piecewise_responses = generate_text(piecewise_llm, + batch_size=batch_size, + max_tokens=max_tokens) + full_cudagraph_responses = generate_text(full_cudagraph_llm, + batch_size=batch_size, + max_tokens=max_tokens) + + # Check that all responses are the same + for i in range(len(piecewise_responses)): + assert piecewise_responses[i].outputs[ + 0].text == full_cudagraph_responses[i].outputs[0].text + + +def test_full_cudagraph_with_invalid_backend(): + with temporary_environ({ + "VLLM_USE_V1": "1", + "VLLM_FLASH_ATTN_VERSION": + "2" #FA2 not supported with full_cuda_graph + }), pytest.raises(RuntimeError): + LLM(model=MODEL, + compilation_config=CompilationConfig(full_cuda_graph=True)) diff --git a/tests/compile/test_basic_correctness.py b/tests/compile/test_basic_correctness.py index 0b76779b3a752e87403cbd2ae2662c93ea6fa31a..b6b45d1cbe880560d74bd2a4894c1c039cbc74c4 100644 --- a/tests/compile/test_basic_correctness.py +++ b/tests/compile/test_basic_correctness.py @@ -103,7 +103,8 @@ def test_compile_correctness( method = test_setting.method fullgraph = test_setting.fullgraph if cuda_device_count_stateless() != pp_size * tp_size: - pytest.skip("Not correct CUDA devices for the test.") + pytest.skip(f"Need exactly {pp_size}*{tp_size} CUDA gpus but got " + f"{cuda_device_count_stateless()}") with monkeypatch.context() as m: m.setenv("VLLM_ATTENTION_BACKEND", attn_backend) diff --git a/tests/compile/test_full_graph.py b/tests/compile/test_full_graph.py index c094063859876ff715247de8b344021f8db94142..397517b8665bc3f44342c475c8ddb6cfd549f532 100644 --- a/tests/compile/test_full_graph.py +++ b/tests/compile/test_full_graph.py @@ -9,7 +9,7 @@ import torch from tests.quantization.utils import is_quant_method_supported from vllm import LLM, SamplingParams -from vllm.config import CompilationConfig, CompilationLevel +from vllm.config import CompilationConfig, CompilationLevel, PassConfig from vllm.platforms import current_platform from ..utils import create_new_process_for_each_test @@ -95,9 +95,6 @@ def test_full_graph( run_model(optimization_level, model, model_kwargs) -PassConfig = CompilationConfig.PassConfig - - # TODO(luka) add other supported compilation config scenarios here @pytest.mark.parametrize( "compilation_config, model_info", diff --git a/tests/compile/test_functionalization.py b/tests/compile/test_functionalization.py index 27cd10b77491ae03ce1e8e1fc0d91214b5a2780f..5d38ff91490ee10d35ad16ef757786e06b20ec5f 100644 --- a/tests/compile/test_functionalization.py +++ b/tests/compile/test_functionalization.py @@ -5,19 +5,19 @@ import torch import vllm.envs as envs from vllm import LLM, SamplingParams +from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass from vllm.compilation.fix_functionalization import FixFunctionalizationPass from vllm.compilation.fusion import (FUSED_OPS, FusionPass, QuantKey, kFp8DynamicTokenSym, kFp8StaticTensorSym) from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func from vllm.compilation.noop_elimination import NoOpEliminationPass -from vllm.config import CompilationConfig, VllmConfig +from vllm.config import CompilationConfig, PassConfig, VllmConfig from .backend import TestBackend OPS_IN_MODEL = [ torch.ops._C.rotary_embedding.default, torch.ops._C.fused_add_rms_norm.default, - torch.ops._C.silu_and_mul.default, ] RMS_OP = torch.ops._C.rms_norm.default @@ -29,6 +29,9 @@ RMS_QUANT_OPS = { ], } +SILU_MUL_OP = torch.ops._C.silu_and_mul.default + +SILU_MUL_QUANT_OP = torch.ops._C.silu_and_mul_quant.default prompts = [ "Hello, my name is", "The president of the United States is", @@ -50,13 +53,14 @@ def test_fix_functionalization(model: str, quant_key: QuantKey, torch.set_default_device("cuda") vllm_config = VllmConfig() - vllm_config.compilation_config = CompilationConfig(pass_config= \ - CompilationConfig.PassConfig(enable_fusion=do_fusion, - enable_noop=True)) + vllm_config.compilation_config = CompilationConfig( + pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True)) noop_pass = NoOpEliminationPass(vllm_config) fusion_pass = FusionPass.instance(vllm_config) + act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config) - passes = [noop_pass, fusion_pass] if do_fusion else [noop_pass] + passes = [noop_pass, fusion_pass, act_quant_fusion_pass + ] if do_fusion else [noop_pass] func_pass = FixFunctionalizationPass(vllm_config) backend_func = TestBackend(*passes, func_pass) backend_no_func = TestBackend(*passes) @@ -79,6 +83,7 @@ def test_fix_functionalization(model: str, quant_key: QuantKey, model_runner.model = torch.compile(orig_model, fullgraph=True, backend=backend_no_func) + gen_no_func = llm.generate(prompts, sampling_params) for output_func, output_no_func in zip(gen_func, gen_no_func): @@ -88,7 +93,12 @@ def test_fix_functionalization(model: str, quant_key: QuantKey, # and replaced by fused quantized ops in RMS_QUANT_OPS. rms_ops = [FUSED_OPS[(quant_key, True)], FUSED_OPS[(quant_key, False)] ] if do_fusion else [RMS_OP] - ops = OPS_IN_MODEL + rms_ops + silu_mul_ops = [SILU_MUL_QUANT_OP] if do_fusion and \ + quant_key == kFp8StaticTensorSym else [ + SILU_MUL_OP + ] + + ops = OPS_IN_MODEL + rms_ops + silu_mul_ops for op in ops: find_auto_fn(backend_no_func.graph_post_pass.nodes, op) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 6a696fe0226b1a4a8dd61797aa5842154226ecda..4d56b34bdecfb9caf622029b4ba8d6f14f36bf01 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -9,7 +9,8 @@ from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey, FusionPass, QuantKey) from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe from vllm.compilation.noop_elimination import NoOpEliminationPass -from vllm.config import CompilationConfig, CompilationLevel, VllmConfig +from vllm.config import (CompilationConfig, CompilationLevel, PassConfig, + VllmConfig) from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( CUTLASS_FP8_SUPPORTED, Fp8LinearOp, maybe_create_device_identity) @@ -78,8 +79,7 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static, vllm_config = VllmConfig(compilation_config=CompilationConfig( level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"])) vllm_config.compilation_config.pass_config = \ - CompilationConfig.PassConfig(enable_fusion=True, - enable_noop=True) + PassConfig(enable_fusion=True, enable_noop=True) with vllm.config.set_current_vllm_config(vllm_config): # Reshape pass is needed for the fusion pass to work noop_pass = NoOpEliminationPass(vllm_config) diff --git a/tests/compile/test_pass_manager.py b/tests/compile/test_pass_manager.py index 673ebe8b6fdc0c800f4b2d832bcd07db94f97061..b630d0e85d31ac7cc9c0fe0edb6d92e54429e79c 100644 --- a/tests/compile/test_pass_manager.py +++ b/tests/compile/test_pass_manager.py @@ -22,7 +22,7 @@ def test_bad_callable(): pass_manager.configure(config) with pytest.raises(AssertionError): - pass_manager.add(simple_callable) # noqa, type wrong on purpose + pass_manager.add(simple_callable) # Pass that inherits from InductorPass diff --git a/tests/compile/test_sequence_parallelism.py b/tests/compile/test_sequence_parallelism.py index 79f5486dadcdd6d06c683b6dd2913541a7dad227..6152f171705b16a2df1e38bf57633a51025420c8 100644 --- a/tests/compile/test_sequence_parallelism.py +++ b/tests/compile/test_sequence_parallelism.py @@ -10,7 +10,7 @@ from vllm.compilation.fx_utils import (find_auto_fn, find_auto_fn_maybe, find_specified_fn_maybe, is_func) from vllm.compilation.sequence_parallelism import SequenceParallelismPass from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig, - VllmConfig) + PassConfig, VllmConfig) from vllm.distributed import tensor_model_parallel_all_reduce from vllm.distributed.parallel_state import (init_distributed_environment, initialize_model_parallel) @@ -126,9 +126,8 @@ def sequence_parallelism_pass_on_test_model(local_rank: int, world_size: int, # configure vllm config for SequenceParallelismPass vllm_config = VllmConfig() - vllm_config.compilation_config = CompilationConfig( - pass_config=CompilationConfig.PassConfig( - enable_sequence_parallelism=True, ), ) + vllm_config.compilation_config = CompilationConfig(pass_config=PassConfig( + enable_sequence_parallelism=True)) vllm_config.device_config = DeviceConfig(device=torch.device("cuda")) # this is a fake model name to construct the model config diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/test_silu_mul_quant_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..9eae48d60f368c14a063f60eeb5ec0335cd3f20b --- /dev/null +++ b/tests/compile/test_silu_mul_quant_fusion.py @@ -0,0 +1,73 @@ +# SPDX-License-Identifier: Apache-2.0 +import pytest +import torch + +import vllm.envs as envs +from vllm._custom_ops import scaled_fp8_quant +from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass +from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe +from vllm.config import CompilationConfig, PassConfig, VllmConfig +from vllm.model_executor.layers.activation import SiluAndMul + +from .backend import TestBackend + + +class TestModel(torch.nn.Module): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.silu_and_mul = SiluAndMul() + self.scale = torch.rand(1, dtype=torch.float32) + + def forward(self, x): + y = self.silu_and_mul(x) + x2 = scaled_fp8_quant(y, self.scale) + return x2 + + +@pytest.mark.parametrize("num_tokens", [256]) +@pytest.mark.parametrize("hidden_size", [64]) +@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"], + reason="Only test on CUDA and ROCm") +def test_fusion_silu_and_mul_quant(num_tokens, hidden_size): + torch.set_default_device("cuda") + torch.set_default_dtype(torch.float16) + + # Reshape pass is needed for the fusion pass to work + config = VllmConfig() + config.compilation_config = CompilationConfig( + pass_config=PassConfig(enable_fusion=True, enable_noop=True)) + fusion_pass = ActivationQuantFusionPass(config) + + backend = TestBackend(fusion_pass) + model = TestModel() + + # First dimension dynamic + x = torch.rand(num_tokens, hidden_size) + torch._dynamo.mark_dynamic(x, 0) + + result = model(x) + + model2 = torch.compile(model, backend=backend) + result2 = model2(x) + + # Check that it gives the same answer + torch.testing.assert_close(result[0].to(dtype=torch.float16), + result2[0].to(dtype=torch.float16), + atol=1e-3, + rtol=1e-3) + + # Check substitution worked + pre_nodes = backend.graph_pre_pass.nodes + post_nodes = backend.graph_post_pass.nodes + + silu_and_mul_quant = torch.ops._C.silu_and_mul_quant.default + fp8_quant = torch.ops._C.static_scaled_fp8_quant.default + + # In pre-nodes, fp8 quant should be present and fused kernels should not + assert find_auto_fn_maybe(pre_nodes, silu_and_mul_quant) is None + find_auto_fn(pre_nodes, fp8_quant) + + # In post-nodes, fused kernels should be present and fp8 quant should not + find_auto_fn(post_nodes, silu_and_mul_quant) + assert find_auto_fn_maybe(post_nodes, fp8_quant) is None diff --git a/tests/conftest.py b/tests/conftest.py index e62b56cb58252b89bb783f210c6dd33e6b66088b..c5700179c228411b93b56cd5b8f0fd722428bb14 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,9 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 - import json import os import tempfile -from collections import UserList from enum import Enum from typing import Any, Callable, Optional, TypedDict, TypeVar, Union @@ -58,16 +56,12 @@ def _read_prompts(filename: str) -> list[str]: return prompts -class _ImageAssetPrompts(TypedDict): +class ImageAssetPrompts(TypedDict): stop_sign: str cherry_blossom: str -class _ImageAssetsBase(UserList[ImageAsset]): - pass - - -class _ImageAssets(_ImageAssetsBase): +class ImageTestAssets(list[ImageAsset]): def __init__(self) -> None: super().__init__([ @@ -75,7 +69,7 @@ class _ImageAssets(_ImageAssetsBase): ImageAsset("cherry_blossom"), ]) - def prompts(self, prompts: _ImageAssetPrompts) -> list[str]: + def prompts(self, prompts: ImageAssetPrompts) -> list[str]: """ Convenience method to define the prompt for each test image. @@ -85,30 +79,27 @@ class _ImageAssets(_ImageAssetsBase): return [prompts["stop_sign"], prompts["cherry_blossom"]] -class _VideoAssetPrompts(TypedDict): - sample_demo_1: str +class VideoAssetPrompts(TypedDict): + baby_reading: str -class _VideoAssetsBase(UserList[VideoAsset]): - pass - - -class _VideoAssets(_VideoAssetsBase): +class VideoTestAssets(list[VideoAsset]): def __init__(self) -> None: super().__init__([ - VideoAsset("sample_demo_1.mp4"), + VideoAsset("baby_reading"), ]) - def prompts(self, prompts: _VideoAssetPrompts) -> list[str]: - return [prompts["sample_demo_1"]] + def prompts(self, prompts: VideoAssetPrompts) -> list[str]: + return [prompts["baby_reading"]] -class _AudioAssetsBase(UserList[AudioAsset]): - pass +class AudioAssetPrompts(TypedDict): + mary_had_lamb: str + winning_call: str -class _AudioAssets(_AudioAssetsBase): +class AudioTestAssets(list[AudioAsset]): def __init__(self) -> None: super().__init__([ @@ -116,13 +107,16 @@ class _AudioAssets(_AudioAssetsBase): AudioAsset("winning_call"), ]) + def prompts(self, prompts: AudioAssetPrompts) -> list[str]: + return [prompts["mary_had_lamb"], prompts["winning_call"]] + -IMAGE_ASSETS = _ImageAssets() -"""Singleton instance of :class:`_ImageAssets`.""" -VIDEO_ASSETS = _VideoAssets() -"""Singleton instance of :class:`_VideoAssets`.""" -AUDIO_ASSETS = _AudioAssets() -"""Singleton instance of :class:`_AudioAssets`.""" +IMAGE_ASSETS = ImageTestAssets() +"""Singleton instance of {class}`ImageTestAssets`.""" +VIDEO_ASSETS = VideoTestAssets() +"""Singleton instance of {class}`VideoTestAssets`.""" +AUDIO_ASSETS = AudioTestAssets() +"""Singleton instance of {class}`AudioTestAssets`.""" @pytest.fixture(scope="function", autouse=True) @@ -270,17 +264,17 @@ def example_long_prompts() -> list[str]: @pytest.fixture(scope="session") -def image_assets() -> _ImageAssets: +def image_assets() -> ImageTestAssets: return IMAGE_ASSETS @pytest.fixture(scope="session") -def video_assets() -> _VideoAssets: +def video_assets() -> VideoTestAssets: return VIDEO_ASSETS @pytest.fixture(scope="session") -def audio_assets() -> _AudioAssets: +def audio_assets() -> AudioTestAssets: return AUDIO_ASSETS @@ -293,7 +287,8 @@ class HfRunner: def get_default_device(self): from vllm.platforms import current_platform - return ("cpu" if current_platform.is_cpu() else "cuda") + return ("cpu" + if current_platform.is_cpu() else current_platform.device_type) def wrap_device(self, x: _T, device: Optional[str] = None) -> _T: if x is None or isinstance(x, (bool, )): @@ -360,10 +355,16 @@ class HfRunner: **model_kwargs, ) + # in case some unquantized custom models are not in same dtype + if (getattr(model, "quantization_method", None) is None + and any(p.dtype != self.dtype + for p in model.parameters())): + model = model.to(dtype=self.dtype) + if (getattr(model, "quantization_method", None) != "bitsandbytes" and len({p.device for p in model.parameters()}) < 2): - model = model.to(self.device) + model = model.to(device=self.device) self.model = model @@ -729,7 +730,7 @@ def hf_runner(): class VllmRunner: """ The default value of some arguments have been modified from - :class:`~vllm.LLM` as follows: + {class}`~vllm.LLM` as follows: - `trust_remote_code`: Set to `True` instead of `False` for convenience. - `seed`: Set to `0` instead of `None` for test reproducibility. @@ -737,7 +738,7 @@ class VllmRunner: - `block_size`: Set to `16` instead of `None` to reduce memory usage. - `enable_chunked_prefill`: Set to `False` instead of `None` for test reproducibility. - - `enforce_eager`: Set to `False` instead of `None` to test CUDA graph. + - `enforce_eager`: Set to `False` to test CUDA graph. """ def __init__( @@ -778,7 +779,7 @@ class VllmRunner: def get_inputs( self, - prompts: list[str], + prompts: Union[list[str], list[torch.Tensor]], images: Optional[PromptImageInput] = None, videos: Optional[PromptVideoInput] = None, audios: Optional[PromptAudioInput] = None, @@ -800,16 +801,18 @@ class VllmRunner: if audios is not None and (audio := audios[i]) is not None: multi_modal_data["audio"] = audio - inputs.append( - TextPrompt(prompt=prompt, - multi_modal_data=multi_modal_data - if multi_modal_data else None)) + text_prompt_kwargs = { + ("prompt" if isinstance(prompt, str) else "prompt_embeds"): + prompt, + "multi_modal_data": multi_modal_data or None + } + inputs.append(TextPrompt(**text_prompt_kwargs)) return inputs def generate( self, - prompts: list[str], + prompts: Union[list[str], list[torch.Tensor]], sampling_params: SamplingParams, images: Optional[PromptImageInput] = None, videos: Optional[PromptVideoInput] = None, @@ -835,7 +838,7 @@ class VllmRunner: output_str = sample.text output_ids = list(sample.token_ids) req_sample_output_ids.append(prompt_ids + output_ids) - req_sample_output_strs.append(prompt_str + output_str) + req_sample_output_strs.append((prompt_str or "") + output_str) outputs.append((req_sample_output_ids, req_sample_output_strs)) return outputs @@ -902,7 +905,7 @@ class VllmRunner: def generate_greedy( self, - prompts: list[str], + prompts: Union[list[str], list[torch.Tensor]], max_tokens: int, images: Optional[PromptImageInput] = None, videos: Optional[PromptVideoInput] = None, diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index 8bd64923fe2291de78720f6dff62444536eca437..a5ba16898d89118c9119ac572a25afb762738620 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -2,16 +2,18 @@ import time from collections import deque +from typing import Optional from unittest.mock import MagicMock import pytest # noqa +import torch from torch import Use # noqa from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.core.interfaces import AllocStatus from vllm.core.scheduler import Scheduler, SchedulingBudget from vllm.lora.request import LoRARequest -from vllm.sequence import SequenceGroup +from vllm.sequence import SequenceGroup, SequenceStatus from .utils import (append_new_token, append_new_token_seq, append_new_token_seq_group, create_dummy_prompt, @@ -968,3 +970,73 @@ def test_no_multiple_partial_prefills_with_chunked_prefill_and_prefix_caching( ), "A partial prefix of C (4 tokens) should be prefilled, with the " "remaining tokens fit into 3 token budget (4-1 from the seqA). It will " "then be rounded down to 2 tokens on block size, thus 6 tokens in total." + + +def test_no_batches_mixed_with_prompt_tokens_and_prompt_embeds(): + """ + Test that the scheduler does not schedule batches with prompt tokens and + prompt embeddings co-mingled. + """ + block_size = 2 + max_seq_group = 3 + scheduler = initialize_scheduler( + block_size=block_size, + num_cpu_blocks=16, + num_gpu_blocks=16, + max_num_seqs=max_seq_group, + max_model_len=100, + enable_prefix_caching=True, + ) + + # the odd indexed inputs should be passed in via embeddings, + # evens via token_ids + seq_length = 7 + embedding_size = 5 + num_seqs = 11 + seq_tokens: list[list[int]] = [] + seq_embeds: list[Optional[torch.Tensor]] = [] + for i in range(num_seqs): + if i % 2: + seq_tokens.append(list(range(seq_length))) + seq_embeds.append(None) + else: + seq_tokens.append([0] * seq_length) + seq_embeds.append(torch.rand(embedding_size)) + + seq_and_seq_groups = [ + create_dummy_prompt(f"{i}", + prompt_tokens=seq_tokens[i], + prompt_embeds=seq_embeds[i], + block_size=block_size) + for i in range(len(seq_tokens)) + ] + + for _, seq_group in seq_and_seq_groups: + scheduler.add_seq_group(seq_group) + + while not all(seq.is_finished() for seq, _ in seq_and_seq_groups): + unfinished_seq_groups = [ + seq_group for _, seq_group in seq_and_seq_groups + if not seq_group.is_finished() + ] + _, out = schedule_and_update_computed_tokens(scheduler) + assert len(out.scheduled_seq_groups) > 0 + batch_is_prompt_embeds = out.scheduled_seq_groups[ + 0].seq_group.uses_prompt_embeds() + expected_scheduled_seq_groups = [ + seq_group for seq_group in unfinished_seq_groups + if seq_group.uses_prompt_embeds() == batch_is_prompt_embeds + ] + + # We should have as many scheduled groups as possible, without mixing + assert len(out.scheduled_seq_groups) == min( + max_seq_group, len(expected_scheduled_seq_groups)) + assert all(scheduled_seq_group.seq_group.uses_prompt_embeds() == + batch_is_prompt_embeds + for scheduled_seq_group in out.scheduled_seq_groups) + + # Finish the scheduled groups + for scheduled_seq_group in out.scheduled_seq_groups: + for seq in scheduled_seq_group.seq_group.seqs: + seq.status = SequenceStatus.FINISHED_STOPPED + scheduler.free_finished_seq_groups() diff --git a/tests/core/utils.py b/tests/core/utils.py index ea18b879a31727a1429ada196f836418f886dc30..84b0426b470bc9cd012da02bb0db11126773154c 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -5,9 +5,11 @@ from collections import defaultdict from collections.abc import Sequence as GenericSequence from typing import Any, Optional +import torch + from vllm import SamplingParams from vllm.core.scheduler import Scheduler, SchedulerOutputs -from vllm.inputs import EncoderDecoderInputs, token_inputs +from vllm.inputs import EncoderDecoderInputs, embeds_inputs, token_inputs from vllm.lora.request import LoRARequest from vllm.sequence import (Logprob, Sequence, SequenceGroup, SequenceGroupMetadata) @@ -19,6 +21,7 @@ def create_dummy_prompt( block_size: Optional[int] = None, lora_request: Optional[LoRARequest] = None, prompt_tokens: Optional[list[int]] = None, + prompt_embeds: Optional[torch.Tensor] = None, min_tokens: int = 0, max_tokens: int = 16, ) -> tuple[Sequence, SequenceGroup]: @@ -31,9 +34,13 @@ def create_dummy_prompt( prompt_tokens = list(range(prompt_length)) prompt_str = " ".join([str(t) for t in prompt_tokens]) + inputs = token_inputs( + prompt_token_ids=prompt_tokens, + prompt=prompt_str) if prompt_embeds is None else embeds_inputs( + prompt_embeds=prompt_embeds) prompt = Sequence( int(request_id), - inputs=token_inputs(prompt_tokens, prompt=prompt_str), + inputs=inputs, block_size=block_size, ) seq_group = SequenceGroup( diff --git a/tests/distributed/conftest.py b/tests/distributed/conftest.py new file mode 100644 index 0000000000000000000000000000000000000000..ee8f2097933d1dbdbb5b7b0942a04f3faf74b5d6 --- /dev/null +++ b/tests/distributed/conftest.py @@ -0,0 +1,145 @@ +# SPDX-License-Identifier: Apache-2.0 +import random +from typing import Optional, Union + +import msgspec +import msgspec.msgpack +import pytest +import zmq + +from vllm.config import KVEventsConfig +from vllm.distributed.kv_events import EventPublisherFactory + +from .test_events import SampleBatch + + +@pytest.fixture +def random_port(): + """Generate a random port number for testing""" + return random.randint(10000, 60000) + + +@pytest.fixture +def publisher_config(random_port, request): + """Create a publisher config with inproc transport""" + how = request.param if hasattr(request, "param") else "inproc" + + if how == "inproc": + endpoint = f"inproc://test-{random_port}" + replay_endpoint = endpoint + "-replay" + else: + endpoint = f"tcp://*:{random_port}" + replay_endpoint = f"tcp://*:{random_port + 1}" + + return KVEventsConfig(enable_kv_cache_events=True, + publisher="zmq", + endpoint=endpoint, + replay_endpoint=replay_endpoint, + buffer_steps=100, + hwm=1000, + topic="test") + + +@pytest.fixture +def publisher(publisher_config): + """Create and return a publisher instance""" + pub = EventPublisherFactory.create(publisher_config) + yield pub + pub.shutdown() + + +@pytest.fixture +def subscriber(publisher_config): + """Create and return a subscriber for testing""" + endpoint = publisher_config.endpoint + replay_endpoint = publisher_config.replay_endpoint + + if endpoint.startswith("tcp://*"): + endpoint = endpoint.replace("*", "127.0.0.1") + if replay_endpoint and replay_endpoint.startswith("tcp://*"): + replay_endpoint = replay_endpoint.replace("*", "127.0.0.1") + + sub = MockSubscriber(endpoint, replay_endpoint, publisher_config.topic) + yield sub + sub.close() + + +class MockSubscriber: + """Helper class to receive and verify published events""" + + def __init__(self, + pub_endpoint: str, + replay_endpoint: Optional[str] = None, + topic: str = "", + decode_type=SampleBatch): + self.ctx = zmq.Context.instance() + + # Set up subscriber socket + self.sub = self.ctx.socket(zmq.SUB) + self.sub.setsockopt(zmq.SUBSCRIBE, topic.encode('utf-8')) + self.sub.connect(pub_endpoint) + + # Set up replay socket if provided + self.replay = None + if replay_endpoint: + self.replay = self.ctx.socket(zmq.REQ) + self.replay.connect(replay_endpoint) + + self.topic = topic + self.topic_bytes = topic.encode('utf-8') + self.received_msgs: list[tuple[int, SampleBatch]] = [] + self.last_seq = -1 + self.decoder = msgspec.msgpack.Decoder(type=decode_type) + + def receive_one(self, + timeout=1000) -> Union[tuple[int, SampleBatch], None]: + """Receive a single message with timeout""" + if not self.sub.poll(timeout): + return None + + topic_bytes, seq_bytes, payload = self.sub.recv_multipart() + assert topic_bytes == self.topic_bytes + + seq = int.from_bytes(seq_bytes, "big") + data = self.decoder.decode(payload) + self.last_seq = seq + self.received_msgs.append((seq, data)) + return seq, data + + def request_replay(self, start_seq: int) -> None: + """Request replay of messages starting from start_seq""" + if not self.replay: + raise ValueError("Replay socket not initialized") + + self.replay.send(start_seq.to_bytes(8, "big")) + + def receive_replay(self) -> list[tuple[int, SampleBatch]]: + """Receive replayed messages""" + if not self.replay: + raise ValueError("Replay socket not initialized") + + replayed: list[tuple[int, SampleBatch]] = [] + while True: + try: + if not self.replay.poll(1000): + break + + frames = self.replay.recv_multipart() + if not frames or not frames[-1]: + # End of replay marker + break + + seq_bytes, payload = frames + seq = int.from_bytes(seq_bytes, "big") + data = self.decoder.decode(payload) + replayed.append((seq, data)) + except zmq.ZMQError as _: + break + + return replayed + + def close(self): + """Clean up resources""" + self.sub.close() + if self.replay: + self.replay.close() diff --git a/tests/distributed/test_events.py b/tests/distributed/test_events.py new file mode 100644 index 0000000000000000000000000000000000000000..15bcfdb8555f3880f70acdc79d9fbf9ecd5ed63a --- /dev/null +++ b/tests/distributed/test_events.py @@ -0,0 +1,193 @@ +# SPDX-License-Identifier: Apache-2.0 +import threading +import time + +import msgspec +import pytest + +from vllm.distributed.kv_events import (EventBatch, EventPublisherFactory, + NullEventPublisher) + + +class EventSample( + msgspec.Struct, + tag=True, # type: ignore + array_like=True # type: ignore +): + """Test event for publisher testing""" + id: int + value: str + + +class SampleBatch(EventBatch): + """Test event batch for publisher testing""" + events: list[EventSample] + + +def create_test_events(count: int) -> SampleBatch: + """Create a batch of test events""" + events = [EventSample(id=i, value=f"test-{i}") for i in range(count)] + return SampleBatch(ts=time.time(), events=events) + + +def test_basic_publishing(publisher, subscriber): + """Test basic event publishing works""" + + test_batch = create_test_events(5) + publisher.publish(test_batch) + + result = subscriber.receive_one(timeout=1000) + assert result is not None, "No message received" + + seq, received = result + assert seq == 0, "Sequence number mismatch" + assert received.ts == pytest.approx(test_batch.ts, + abs=0.1), ("Timestamp mismatch") + assert len(received.events) == len( + test_batch.events), ("Number of events mismatch") + + for i, event in enumerate(received.events): + assert event.id == i, "Event id mismatch" + assert event.value == f"test-{i}", "Event value mismatch" + + +def test_multiple_events(publisher, subscriber): + """Test publishing and receiving multiple event batches""" + for _ in range(10): + batch = create_test_events(2) + publisher.publish(batch) + + received = [] + for _ in range(10): + data = subscriber.receive_one(timeout=100) + if data: + received.append(data) + + assert len(received) == 10, "Number of messages mismatch" + seqs = [seq for seq, _ in received] + assert seqs == list(range(10)), "Sequence numbers mismatch" + + +def test_replay_mechanism(publisher, subscriber): + """Test the replay mechanism works correctly""" + for _ in range(19): + batch = create_test_events(1) + publisher.publish(batch) + + time.sleep(0.5) # Need publisher to process above requests + subscriber.request_replay(10) + + batch = create_test_events(1) + publisher.publish(batch) # 20th message + + replayed = subscriber.receive_replay() + + assert len(replayed) > 0, "No replayed messages received" + seqs = [seq for seq, _ in replayed] + assert all(seq >= 10 for seq in seqs), "Replayed messages not in order" + assert seqs == list(range(min(seqs), + max(seqs) + + 1)), ("Replayed messages not consecutive") + + +def test_buffer_limit(publisher, subscriber, publisher_config): + """Test buffer limit behavior""" + buffer_size = publisher_config.buffer_steps + + # Publish more events than the buffer can hold + for i in range(buffer_size + 10): + batch = create_test_events(1) + publisher.publish(batch) + + time.sleep(0.5) # Need publisher to process above requests + subscriber.request_replay(0) + + batch = create_test_events(1) + publisher.publish(batch) + + replayed = subscriber.receive_replay() + + assert len(replayed) <= buffer_size, "Can't replay more than buffer size" + + oldest_seq = min(seq for seq, _ in replayed) + assert oldest_seq >= 10, "The oldest sequence should be at least 10" + + +def test_topic_filtering(publisher_config): + """ + Test that a subscriber only receives messages matching its topic filter + """ + publisher_config.replay_endpoint = None + + cfg = publisher_config.model_copy() + cfg.topic = "foo" + pub = EventPublisherFactory.create(cfg) + + from .conftest import MockSubscriber + sub_foo = MockSubscriber(cfg.endpoint, None, "foo") + sub_bar = MockSubscriber(cfg.endpoint, None, "bar") + + try: + time.sleep(0.1) + + for _ in range(3): + pub.publish(create_test_events(1)) + + foo_received = [sub_foo.receive_one(timeout=200) for _ in range(3)] + assert all(msg is not None for msg in foo_received), ( + "Subscriber with matching topic should receive messages") + + bar_received = [sub_bar.receive_one(timeout=200) for _ in range(3)] + assert all(msg is None for msg in bar_received), ( + "Subscriber with non-matching topic should receive no messages") + finally: + pub.shutdown() + sub_foo.close() + sub_bar.close() + + +def test_high_volume(publisher, subscriber): + """Test publishing and receiving a high volume of events""" + num_batches = 10_000 + events_per_batch = 100 + + # Publish events in a separate thread to not block + def publish_events(): + for i in range(num_batches): + batch = create_test_events(events_per_batch) + publisher.publish(batch) + # Small delay to avoid overwhelming + if i % 100 == 0: + time.sleep(0.01) + + received: list[tuple[int, SampleBatch]] = [] + + publisher_thread = threading.Thread(target=publish_events) + publisher_thread.start() + + start_time = time.time() + while len(received) < num_batches: + if time.time() - start_time > 10: # Timeout after 10 seconds + break + + result = subscriber.receive_one(timeout=100) + if result: + received.append(result) + + publisher_thread.join() + + assert len(received) >= num_batches * 0.9, ( + "We should have received most messages") + + seqs = [seq for seq, _ in received] + assert sorted(seqs) == seqs, "Sequence numbers should be in order" + + +def test_null_publisher(): + """Test that NullEventPublisher can be used without errors""" + publisher = NullEventPublisher() + + # This should not raise any errors + batch = create_test_events(5) + publisher.publish(batch) + publisher.shutdown() diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 03de8d9b92bff3709095516bebb8a835a6c41915..5346d67b10d16d7cc947239accb07250573be2db 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -100,9 +100,8 @@ class PPTestSettings: eager_mode=True, chunked_prefill=False), ], - # only ray is supported for V1 - distributed_backends=["mp", "ray", "ray"], - vllm_major_versions=["0", "0", "1"], + distributed_backends=["mp", "mp", "ray", "ray"], + vllm_major_versions=["0", "1", "0", "1"], task=task, test_options=PPTestOptions(multi_node_only=multi_node_only, load_format=load_format), @@ -186,7 +185,7 @@ TEXT_GENERATION_MODELS = { "mosaicml/mpt-7b": PPTestSettings.fast(), "nvidia/Minitron-8B-Base": PPTestSettings.fast(), "allenai/OLMo-1B-hf": PPTestSettings.fast(), - "shanearora/OLMo-7B-1124-hf": PPTestSettings.fast(), + "allenai/OLMo-2-0425-1B": PPTestSettings.fast(), "allenai/OLMoE-1B-7B-0924-Instruct": PPTestSettings.fast(), "facebook/opt-iml-max-1.3b": PPTestSettings.fast(), "OrionStarAI/Orion-14B-Chat": PPTestSettings.fast(), @@ -350,6 +349,11 @@ def _compare_tp( # Temporary. Currently when zeromq + SPMD is used, it does not properly # terminate because of a Ray Compiled Graph issue. common_args.append("--disable-frontend-multiprocessing") + elif distributed_backend == "mp": + # Both V0/V1 of multiprocessing executor support PP + pp_env = { + "VLLM_USE_V1": vllm_major_version, + } else: pp_env = None diff --git a/tests/distributed/test_sequence_parallel.py b/tests/distributed/test_sequence_parallel.py index 19497ad9c14090aec82d499fafd4fbbd317fa17e..c9eba2b43788ecbc4c91f9b66a51dc5af2798277 100644 --- a/tests/distributed/test_sequence_parallel.py +++ b/tests/distributed/test_sequence_parallel.py @@ -26,6 +26,7 @@ VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1" class ParallelSetup(NamedTuple): tp_size: int + pp_size: int sp_enabled: bool eager_mode: bool chunked_prefill: bool @@ -60,6 +61,7 @@ class SPTestSettings: def detailed( *, tp_base: int = 2, + pp_base: int = 1, multi_node_only: bool = False, task: TaskOption = "auto", load_format: Optional[str] = None, @@ -67,18 +69,42 @@ class SPTestSettings: return SPTestSettings( parallel_setups=[ ParallelSetup(tp_size=tp_base, + pp_size=pp_base, sp_enabled=True, eager_mode=False, chunked_prefill=False), ParallelSetup(tp_size=tp_base, + pp_size=pp_base, sp_enabled=True, eager_mode=False, chunked_prefill=True), ParallelSetup(tp_size=tp_base, + pp_size=pp_base, sp_enabled=True, eager_mode=True, chunked_prefill=False), ParallelSetup(tp_size=tp_base, + pp_size=pp_base, + sp_enabled=True, + eager_mode=True, + chunked_prefill=True), + ParallelSetup(tp_size=tp_base, + pp_size=2 * pp_base, + sp_enabled=True, + eager_mode=False, + chunked_prefill=False), + ParallelSetup(tp_size=tp_base, + pp_size=2 * pp_base, + sp_enabled=True, + eager_mode=False, + chunked_prefill=True), + ParallelSetup(tp_size=tp_base, + pp_size=2 * pp_base, + sp_enabled=True, + eager_mode=True, + chunked_prefill=False), + ParallelSetup(tp_size=tp_base, + pp_size=2 * pp_base, sp_enabled=True, eager_mode=True, chunked_prefill=True) @@ -94,6 +120,7 @@ class SPTestSettings: def fast( *, tp_base: int = 2, + pp_base: int = 1, task: TaskOption = "auto", multi_node_only: bool = False, load_format: Optional[str] = None, @@ -101,6 +128,12 @@ class SPTestSettings: return SPTestSettings( parallel_setups=[ ParallelSetup(tp_size=tp_base, + pp_size=pp_base, + sp_enabled=True, + eager_mode=False, + chunked_prefill=False), + ParallelSetup(tp_size=tp_base, + pp_size=2 * pp_base, sp_enabled=True, eager_mode=False, chunked_prefill=False), @@ -136,6 +169,7 @@ def _compare_sp( ): ( tp_size, + pp_size, sp_enabled, eager_mode, chunked_prefill, @@ -167,7 +201,6 @@ def _compare_sp( else: model_info.check_available_online(on_fail="skip") - pp_size = 1 if num_gpus_available < tp_size * pp_size: pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs") if VLLM_MULTI_NODE and distributed_backend == "mp": @@ -206,7 +239,7 @@ def _compare_sp( 'compile_sizes': [4, 8], 'splitting_ops': [], 'pass_config': { - 'enable_sequence_parallism': sp_enabled, + 'enable_sequence_parallelism': sp_enabled, 'enable_noop': True, 'enable_fusion': True, }, @@ -223,7 +256,7 @@ def _compare_sp( "--distributed-executor-backend", distributed_backend, "--compilation_config", - str(compilation_config), + json.dumps(compilation_config), ] tp_env = { @@ -256,7 +289,7 @@ def _compare_sp( SP_TEXT_GENERATION_MODELS = { # [Decoder-only] - "meta-llama/Llama-3.2-1B-Instruct": SPTestSettings.detailed(), + "meta-llama/Llama-3.2-1B-Instruct": SPTestSettings.fast(), } SP_TEST_MODELS = [ diff --git a/tests/distributed/test_torchrun_example.py b/tests/distributed/test_torchrun_example.py index 0420a6454d461121ed17d6aad367839c545b0fc6..bb38e908b7345bcfa791a105d2e246ae7b7294f4 100644 --- a/tests/distributed/test_torchrun_example.py +++ b/tests/distributed/test_torchrun_example.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # unit test for `examples/offline_inference/torchrun_example.py` - +import os import random import torch.distributed as dist @@ -25,6 +25,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95) # to test if all ranks agree on the same kv cache configuration. llm = LLM(model="facebook/opt-125m", tensor_parallel_size=2, + pipeline_parallel_size=int(os.getenv("PP_SIZE", 1)), distributed_executor_backend="external_launcher", gpu_memory_utilization=random.uniform(0.7, 0.9), swap_space=random.randint(1, 4), diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py index 052d5793c1b3ab238a696c6c93a9161051148ce0..05d9cfc7ab747191b3413892439d9991c905b724 100644 --- a/tests/engine/test_arg_utils.py +++ b/tests/engine/test_arg_utils.py @@ -8,20 +8,18 @@ from typing import Literal, Optional import pytest -from vllm.config import PoolerConfig, config +from vllm.config import CompilationConfig, config from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs, get_type, is_not_builtin, is_type, - nullable_kvs, optional_type) + literal_to_kwargs, nullable_kvs, + optional_type, parse_type) from vllm.utils import FlexibleArgumentParser @pytest.mark.parametrize(("type", "value", "expected"), [ (int, "42", 42), - (int, "None", None), (float, "3.14", 3.14), - (float, "None", None), (str, "Hello World!", "Hello World!"), - (str, "None", None), (json.loads, '{"foo":1,"bar":2}', { "foo": 1, "bar": 2 @@ -30,15 +28,20 @@ from vllm.utils import FlexibleArgumentParser "foo": 1, "bar": 2 }), - (json.loads, "None", None), ]) -def test_optional_type(type, value, expected): - optional_type_func = optional_type(type) +def test_parse_type(type, value, expected): + parse_type_func = parse_type(type) context = nullcontext() if value == "foo=1,bar=2": context = pytest.warns(DeprecationWarning) with context: - assert optional_type_func(value) == expected + assert parse_type_func(value) == expected + + +def test_optional_type(): + optional_type_func = optional_type(int) + assert optional_type_func("None") is None + assert optional_type_func("42") == 42 @pytest.mark.parametrize(("type_hint", "type", "expected"), [ @@ -71,9 +74,57 @@ def test_get_type(type_hints, type, expected): assert get_type(type_hints, type) == expected +@pytest.mark.parametrize(("type_hints", "expected"), [ + ({Literal[1, 2]}, { + "type": int, + "choices": [1, 2] + }), + ({Literal[1, "a"]}, Exception), +]) +def test_literal_to_kwargs(type_hints, expected): + context = nullcontext() + if expected is Exception: + context = pytest.raises(expected) + with context: + assert literal_to_kwargs(type_hints) == expected + + +@config +@dataclass +class NestedConfig: + field: int = 1 + """field""" + + +@config +@dataclass +class FromCliConfig1: + field: int = 1 + """field""" + + @classmethod + def from_cli(cls, cli_value: str): + inst = cls(**json.loads(cli_value)) + inst.field += 1 + return inst + + +@config +@dataclass +class FromCliConfig2: + field: int = 1 + """field""" + + @classmethod + def from_cli(cls, cli_value: str): + inst = cls(**json.loads(cli_value)) + inst.field += 2 + return inst + + @config @dataclass -class DummyConfigClass: +class DummyConfig: regular_bool: bool = True """Regular bool with default True""" optional_bool: Optional[bool] = None @@ -81,23 +132,35 @@ class DummyConfigClass: optional_literal: Optional[Literal["x", "y"]] = None """Optional literal with default None""" tuple_n: tuple[int, ...] = field(default_factory=lambda: (1, 2, 3)) - """Tuple with default (1, 2, 3)""" + """Tuple with variable length""" tuple_2: tuple[int, int] = field(default_factory=lambda: (1, 2)) - """Tuple with default (1, 2)""" + """Tuple with fixed length""" list_n: list[int] = field(default_factory=lambda: [1, 2, 3]) - """List with default [1, 2, 3]""" + """List with variable length""" + list_literal: list[Literal[1, 2]] = field(default_factory=list) + """List with literal choices""" + literal_literal: Literal[Literal[1], Literal[2]] = 1 + """Literal of literals with default 1""" + json_tip: dict = field(default_factory=dict) + """Dict which will be JSON in CLI""" + nested_config: NestedConfig = field(default_factory=NestedConfig) + """Nested config""" + from_cli_config1: FromCliConfig1 = field(default_factory=FromCliConfig1) + """Config with from_cli method""" + from_cli_config2: FromCliConfig2 = field(default_factory=FromCliConfig2) + """Different config with from_cli method""" @pytest.mark.parametrize(("type_hint", "expected"), [ (int, False), - (DummyConfigClass, True), + (DummyConfig, True), ]) def test_is_not_builtin(type_hint, expected): assert is_not_builtin(type_hint) == expected def test_get_kwargs(): - kwargs = get_kwargs(DummyConfigClass) + kwargs = get_kwargs(DummyConfig) print(kwargs) # bools should not have their type set @@ -111,6 +174,20 @@ def test_get_kwargs(): # lists should work assert kwargs["list_n"]["type"] is int assert kwargs["list_n"]["nargs"] == "+" + # lists with literals should have the correct choices + assert kwargs["list_literal"]["type"] is int + assert kwargs["list_literal"]["nargs"] == "+" + assert kwargs["list_literal"]["choices"] == [1, 2] + # literals of literals should have merged choices + assert kwargs["literal_literal"]["choices"] == [1, 2] + # dict should have json tip in help + json_tip = "Should either be a valid JSON string or JSON keys" + assert json_tip in kwargs["json_tip"]["help"] + # nested config should should construct the nested config + assert kwargs["nested_config"]["type"]('{"field": 2}') == NestedConfig(2) + # from_cli configs should be constructed with the correct method + assert kwargs["from_cli_config1"]["type"]('{"field": 2}').field == 3 + assert kwargs["from_cli_config2"]["type"]('{"field": 2}').field == 4 @pytest.mark.parametrize(("arg", "expected"), [ @@ -146,7 +223,7 @@ def test_compilation_config(): # default value args = parser.parse_args([]) - assert args.compilation_config is None + assert args.compilation_config == CompilationConfig() # set to O3 args = parser.parse_args(["-O3"]) @@ -163,7 +240,7 @@ def test_compilation_config(): # set to string form of a dict args = parser.parse_args([ "--compilation-config", - "{'level': 3, 'cudagraph_capture_sizes': [1, 2, 4, 8]}", + '{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}', ]) assert (args.compilation_config.level == 3 and args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]) @@ -171,7 +248,7 @@ def test_compilation_config(): # set to string form of a dict args = parser.parse_args([ "--compilation-config=" - "{'level': 3, 'cudagraph_capture_sizes': [1, 2, 4, 8]}", + '{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}', ]) assert (args.compilation_config.level == 3 and args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]) @@ -196,17 +273,6 @@ def test_prefix_cache_default(): assert not engine_args.enable_prefix_caching -def test_valid_pooling_config(): - parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) - args = parser.parse_args([ - '--override-pooler-config', - '{"pooling_type": "MEAN"}', - ]) - engine_args = EngineArgs.from_cli_args(args=args) - assert engine_args.override_pooler_config == PoolerConfig( - pooling_type="MEAN", ) - - @pytest.mark.parametrize( ("arg"), [ diff --git a/tests/engine/test_options.py b/tests/engine/test_options.py new file mode 100644 index 0000000000000000000000000000000000000000..0cf4f69d56a8703430a39d7fe944e3e5692eafad --- /dev/null +++ b/tests/engine/test_options.py @@ -0,0 +1,60 @@ +# SPDX-License-Identifier: Apache-2.0 +from contextlib import nullcontext + +import pytest + +from vllm.entrypoints.llm import LLM +from vllm.sampling_params import SamplingParams + + +@pytest.mark.parametrize("model", ["distilbert/distilgpt2"]) +def test_skip_tokenizer_initialization(model: str): + # This test checks if the flag skip_tokenizer_init skips the initialization + # of tokenizer and detokenizer. The generated output is expected to contain + # token ids. + llm = LLM( + model=model, + skip_tokenizer_init=True, + enforce_eager=True, + ) + sampling_params = SamplingParams(prompt_logprobs=True, detokenize=True) + + with pytest.raises(ValueError, match="cannot pass text prompts when"): + llm.generate("abc", sampling_params) + + outputs = llm.generate({"prompt_token_ids": [1, 2, 3]}, + sampling_params=sampling_params) + assert len(outputs) > 0 + completions = outputs[0].outputs + assert len(completions) > 0 + assert completions[0].text == "" + assert completions[0].token_ids + + +@pytest.mark.parametrize("model", ["distilbert/distilgpt2"]) +@pytest.mark.parametrize("enable_prompt_embeds", [True, False]) +def test_enable_prompt_embeds(hf_runner, model: str, + enable_prompt_embeds: bool): + prompt = "abc" + + with hf_runner(model) as hf_model: + token_ids = hf_model.tokenizer(prompt, return_tensors="pt").input_ids + token_ids = token_ids.to(hf_model.model.device) + + embed_layer = hf_model.model.get_input_embeddings() + prompt_embeds = embed_layer(token_ids).squeeze(0) + + ctx = (nullcontext() if enable_prompt_embeds else pytest.raises( + ValueError, match="set `--enable-prompt-embeds`")) + + # This test checks if the flag skip_tokenizer_init skips the initialization + # of tokenizer and detokenizer. The generated output is expected to contain + # token ids. + llm = LLM( + model=model, + enable_prompt_embeds=enable_prompt_embeds, + enforce_eager=True, + ) + + with ctx: + llm.generate({"prompt_embeds": prompt_embeds}) diff --git a/tests/engine/test_skip_tokenizer_init.py b/tests/engine/test_skip_tokenizer_init.py deleted file mode 100644 index 5e197f5ffe5926c3b0ba060ccc994ddcc6bb4fd6..0000000000000000000000000000000000000000 --- a/tests/engine/test_skip_tokenizer_init.py +++ /dev/null @@ -1,29 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -import pytest - -from vllm.entrypoints.llm import LLM -from vllm.sampling_params import SamplingParams - - -@pytest.mark.parametrize("model", ["distilbert/distilgpt2"]) -def test_skip_tokenizer_initialization(model: str): - # This test checks if the flag skip_tokenizer_init skips the initialization - # of tokenizer and detokenizer. The generated output is expected to contain - # token ids. - llm = LLM( - model=model, - skip_tokenizer_init=True, - ) - sampling_params = SamplingParams(prompt_logprobs=True, detokenize=True) - - with pytest.raises(ValueError, match="cannot pass text prompts when"): - llm.generate("abc", sampling_params) - - outputs = llm.generate({"prompt_token_ids": [1, 2, 3]}, - sampling_params=sampling_params) - assert len(outputs) > 0 - completions = outputs[0].outputs - assert len(completions) > 0 - assert completions[0].text == "" - assert completions[0].token_ids diff --git a/tests/entrypoints/llm/test_chat.py b/tests/entrypoints/llm/test_chat.py index 6a4862123b517226c859bc0525856edb723eb28b..742a66683445736ba416d0bf4ef40f1b2092b1b6 100644 --- a/tests/entrypoints/llm/test_chat.py +++ b/tests/entrypoints/llm/test_chat.py @@ -1,15 +1,31 @@ # SPDX-License-Identifier: Apache-2.0 +import weakref import pytest from vllm import LLM +from vllm.distributed import cleanup_dist_env_and_memory from ..openai.test_vision import TEST_IMAGE_URLS -def test_chat(): - llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct") +@pytest.fixture(scope="function") +def text_llm(): + # pytest caches the fixture so we use weakref.proxy to + # enable garbage collection + llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", + enforce_eager=True, + seed=0) + with llm.deprecate_legacy_api(): + yield weakref.proxy(llm) + + del llm + + cleanup_dist_env_and_memory() + + +def test_chat(text_llm): prompt1 = "Explain the concept of entropy." messages = [ { @@ -21,13 +37,11 @@ def test_chat(): "content": prompt1 }, ] - outputs = llm.chat(messages) + outputs = text_llm.chat(messages) assert len(outputs) == 1 -def test_multi_chat(): - llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct") - +def test_multi_chat(text_llm): prompt1 = "Explain the concept of entropy." prompt2 = "Explain what among us is." @@ -55,13 +69,14 @@ def test_multi_chat(): messages = [conversation1, conversation2] - outputs = llm.chat(messages) + outputs = text_llm.chat(messages) assert len(outputs) == 2 -@pytest.mark.parametrize("image_urls", - [[TEST_IMAGE_URLS[0], TEST_IMAGE_URLS[1]]]) -def test_chat_multi_image(image_urls: list[str]): +@pytest.fixture(scope="function") +def vision_llm(): + # pytest caches the fixture so we use weakref.proxy to + # enable garbage collection llm = LLM( model="microsoft/Phi-3.5-vision-instruct", max_model_len=4096, @@ -69,8 +84,20 @@ def test_chat_multi_image(image_urls: list[str]): enforce_eager=True, trust_remote_code=True, limit_mm_per_prompt={"image": 2}, + seed=0, ) + with llm.deprecate_legacy_api(): + yield weakref.proxy(llm) + + del llm + + cleanup_dist_env_and_memory() + + +@pytest.mark.parametrize("image_urls", + [[TEST_IMAGE_URLS[0], TEST_IMAGE_URLS[1]]]) +def test_chat_multi_image(vision_llm, image_urls: list[str]): messages = [{ "role": "user", @@ -87,16 +114,15 @@ def test_chat_multi_image(image_urls: list[str]): }, ], }] - outputs = llm.chat(messages) + outputs = vision_llm.chat(messages) assert len(outputs) >= 0 -def test_llm_chat_tokenization_no_double_bos(): +def test_llm_chat_tokenization_no_double_bos(text_llm): """ LLM.chat() should not add special tokens when using chat templates. Check we get a single BOS token for llama chat. """ - llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", enforce_eager=True) messages = [ { "role": "system", @@ -107,13 +133,64 @@ def test_llm_chat_tokenization_no_double_bos(): "content": "Hello!" }, ] - outputs = llm.chat(messages) + outputs = text_llm.chat(messages) assert len(outputs) == 1 - prompt_token_ids = getattr(outputs[0], "prompt_token_ids", None) + + prompt_token_ids = outputs[0].prompt_token_ids assert prompt_token_ids is not None - bos_token = llm.get_tokenizer().bos_token_id + bos_token = text_llm.get_tokenizer().bos_token_id # Ensure we have a single BOS assert prompt_token_ids[0] == bos_token assert prompt_token_ids[1] != bos_token, "Double BOS" + + +@pytest.fixture(scope="function") +def thinking_llm(): + # pytest caches the fixture so we use weakref.proxy to + # enable garbage collection + llm = LLM( + model="Qwen/Qwen3-0.6B", + max_model_len=4096, + enforce_eager=True, + seed=0, + ) + + with llm.deprecate_legacy_api(): + yield weakref.proxy(llm) + + del llm + + cleanup_dist_env_and_memory() + + +@pytest.mark.parametrize("enable_thinking", [True, False]) +def test_chat_extra_kwargs(thinking_llm, enable_thinking): + messages = [ + { + "role": "system", + "content": "You are a helpful assistant" + }, + { + "role": "user", + "content": "What is 1+1?" + }, + ] + + outputs = thinking_llm.chat( + messages, + chat_template_kwargs={"enable_thinking": enable_thinking}, + ) + assert len(outputs) == 1 + + prompt_token_ids = outputs[0].prompt_token_ids + assert prompt_token_ids is not None + + think_id = thinking_llm.get_tokenizer().get_vocab()[""] + + if enable_thinking: + assert think_id not in prompt_token_ids + else: + # The chat template includes dummy thinking process + assert think_id in prompt_token_ids diff --git a/tests/entrypoints/llm/test_collective_rpc.py b/tests/entrypoints/llm/test_collective_rpc.py index d51b7c26344f888f21ccdb35424e40951196c78d..6470249dddbcf7958acb7ea24acaacc8a7c7b050 100644 --- a/tests/entrypoints/llm/test_collective_rpc.py +++ b/tests/entrypoints/llm/test_collective_rpc.py @@ -10,7 +10,7 @@ from ...utils import create_new_process_for_each_test @pytest.mark.parametrize("tp_size", [1, 2]) @pytest.mark.parametrize("backend", ["mp", "ray"]) @create_new_process_for_each_test() -def test_collective_rpc(tp_size, backend): +def test_collective_rpc(tp_size, backend, monkeypatch): if tp_size == 1 and backend == "ray": pytest.skip("Skip duplicate test case") if tp_size == 1: @@ -21,6 +21,7 @@ def test_collective_rpc(tp_size, backend): def echo_rank(self): return self.rank + monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1") llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", enforce_eager=True, load_format="dummy", diff --git a/tests/entrypoints/llm/test_guided_generate.py b/tests/entrypoints/llm/test_guided_generate.py index ad726fa8ce51835af9ea563d7d768e4ab15cfde5..fdbdccd4654c16ba8b8d06f1ea083af70142cb78 100644 --- a/tests/entrypoints/llm/test_guided_generate.py +++ b/tests/entrypoints/llm/test_guided_generate.py @@ -16,10 +16,11 @@ from vllm.sampling_params import GuidedDecodingParams, SamplingParams MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct" GUIDED_DECODING_BACKENDS = [ - "outlines", - "lm-format-enforcer", - "xgrammar:disable-any-whitespace", - "guidance:disable-any-whitespace", + # (backend, disable_any_whitespace), + ("outlines", False), + ("lm-format-enforcer", False), + ("xgrammar", True), + ("guidance", True), ] @@ -36,13 +37,17 @@ def llm(): @pytest.mark.skip_global_cleanup -@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) -def test_guided_regex(sample_regex, llm, guided_decoding_backend: str): - sampling_params = SamplingParams(temperature=0.8, - top_p=0.95, - guided_decoding=GuidedDecodingParams( - regex=sample_regex, - backend=guided_decoding_backend)) +@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", + GUIDED_DECODING_BACKENDS) +def test_guided_regex(sample_regex, llm, guided_decoding_backend: str, + disable_any_whitespace: bool): + sampling_params = SamplingParams( + temperature=0.8, + top_p=0.95, + guided_decoding=GuidedDecodingParams( + regex=sample_regex, + backend=guided_decoding_backend, + disable_any_whitespace=disable_any_whitespace)) outputs = llm.generate(prompts=[ f"Give an example IPv4 address with this regex: {sample_regex}" ] * 2, @@ -62,14 +67,18 @@ def test_guided_regex(sample_regex, llm, guided_decoding_backend: str): @pytest.mark.skip_global_cleanup -@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) +@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", + GUIDED_DECODING_BACKENDS) def test_guided_json_completion(sample_json_schema, llm, - guided_decoding_backend: str): - sampling_params = SamplingParams(temperature=1.0, - max_tokens=1000, - guided_decoding=GuidedDecodingParams( - json=sample_json_schema, - backend=guided_decoding_backend)) + guided_decoding_backend: str, + disable_any_whitespace: bool): + sampling_params = SamplingParams( + temperature=1.0, + max_tokens=1000, + guided_decoding=GuidedDecodingParams( + json=sample_json_schema, + backend=guided_decoding_backend, + disable_any_whitespace=disable_any_whitespace)) outputs = llm.generate(prompts=[ f"Give an example JSON for an employee profile " f"that fits this schema: {sample_json_schema}" @@ -92,14 +101,18 @@ def test_guided_json_completion(sample_json_schema, llm, @pytest.mark.skip_global_cleanup -@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) +@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", + GUIDED_DECODING_BACKENDS) def test_guided_complex_json_completion(sample_complex_json_schema, llm, - guided_decoding_backend: str): - sampling_params = SamplingParams(temperature=1.0, - max_tokens=1000, - guided_decoding=GuidedDecodingParams( - json=sample_complex_json_schema, - backend=guided_decoding_backend)) + guided_decoding_backend: str, + disable_any_whitespace: bool): + sampling_params = SamplingParams( + temperature=1.0, + max_tokens=1000, + guided_decoding=GuidedDecodingParams( + json=sample_complex_json_schema, + backend=guided_decoding_backend, + disable_any_whitespace=disable_any_whitespace)) outputs = llm.generate(prompts=[ f"Give an example JSON for an assignment grade " f"that fits this schema: {sample_complex_json_schema}" @@ -123,14 +136,18 @@ def test_guided_complex_json_completion(sample_complex_json_schema, llm, @pytest.mark.skip_global_cleanup -@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) +@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", + GUIDED_DECODING_BACKENDS) def test_guided_definition_json_completion(sample_definition_json_schema, llm, - guided_decoding_backend: str): - sampling_params = SamplingParams(temperature=1.0, - max_tokens=1000, - guided_decoding=GuidedDecodingParams( - json=sample_definition_json_schema, - backend=guided_decoding_backend)) + guided_decoding_backend: str, + disable_any_whitespace: bool): + sampling_params = SamplingParams( + temperature=1.0, + max_tokens=1000, + guided_decoding=GuidedDecodingParams( + json=sample_definition_json_schema, + backend=guided_decoding_backend, + disable_any_whitespace=disable_any_whitespace)) outputs = llm.generate(prompts=[ f"Give an example JSON for solving 8x + 7 = -23 " f"that fits this schema: {sample_definition_json_schema}" @@ -154,14 +171,18 @@ def test_guided_definition_json_completion(sample_definition_json_schema, llm, @pytest.mark.skip_global_cleanup -@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) +@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", + GUIDED_DECODING_BACKENDS) def test_guided_enum_json_completion(sample_enum_json_schema, llm, - guided_decoding_backend: str): - sampling_params = SamplingParams(temperature=1.0, - max_tokens=1000, - guided_decoding=GuidedDecodingParams( - json=sample_enum_json_schema, - backend=guided_decoding_backend)) + guided_decoding_backend: str, + disable_any_whitespace: bool): + sampling_params = SamplingParams( + temperature=1.0, + max_tokens=1000, + guided_decoding=GuidedDecodingParams( + json=sample_enum_json_schema, + backend=guided_decoding_backend, + disable_any_whitespace=disable_any_whitespace)) outputs = llm.generate(prompts=[ "Create a bug report JSON that fits this schema: " f"{sample_enum_json_schema}. Make it for a high priority critical bug." @@ -195,14 +216,18 @@ def test_guided_enum_json_completion(sample_enum_json_schema, llm, @pytest.mark.skip_global_cleanup -@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) +@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", + GUIDED_DECODING_BACKENDS) def test_guided_choice_completion(sample_guided_choice, llm, - guided_decoding_backend: str): - sampling_params = SamplingParams(temperature=0.8, - top_p=0.95, - guided_decoding=GuidedDecodingParams( - choice=sample_guided_choice, - backend=guided_decoding_backend)) + guided_decoding_backend: str, + disable_any_whitespace: bool): + sampling_params = SamplingParams( + temperature=0.8, + top_p=0.95, + guided_decoding=GuidedDecodingParams( + choice=sample_guided_choice, + backend=guided_decoding_backend, + disable_any_whitespace=disable_any_whitespace)) outputs = llm.generate( prompts="The best language for type-safe systems programming is ", sampling_params=sampling_params, @@ -221,15 +246,19 @@ def test_guided_choice_completion(sample_guided_choice, llm, @pytest.mark.skip_global_cleanup -@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) +@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", + GUIDED_DECODING_BACKENDS) def test_guided_grammar(sample_sql_statements, llm, - guided_decoding_backend: str): - sampling_params = SamplingParams(temperature=0.8, - top_p=0.95, - max_tokens=1000, - guided_decoding=GuidedDecodingParams( - grammar=sample_sql_statements, - backend=guided_decoding_backend)) + guided_decoding_backend: str, + disable_any_whitespace: bool): + sampling_params = SamplingParams( + temperature=0.8, + top_p=0.95, + max_tokens=1000, + guided_decoding=GuidedDecodingParams( + grammar=sample_sql_statements, + backend=guided_decoding_backend, + disable_any_whitespace=disable_any_whitespace)) outputs = llm.generate( prompts=("Generate a sql state that select col_1 from " "table_1 where it is equals to 1"), @@ -300,7 +329,8 @@ def test_disable_guided_decoding_fallback(sample_regex, llm): top_p=0.95, guided_decoding=GuidedDecodingParams( json=unsupported_json, - backend="xgrammar:no-fallback")) + backend="xgrammar", + disable_fallback=True)) with pytest.raises( ValueError, @@ -312,14 +342,18 @@ def test_disable_guided_decoding_fallback(sample_regex, llm): @pytest.mark.skip_global_cleanup -@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) -def test_guided_json_object(llm, guided_decoding_backend: str): - sampling_params = SamplingParams(temperature=1.0, - max_tokens=100, - n=2, - guided_decoding=GuidedDecodingParams( - json_object=True, - backend=guided_decoding_backend)) +@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", + GUIDED_DECODING_BACKENDS) +def test_guided_json_object(llm, guided_decoding_backend: str, + disable_any_whitespace: bool): + sampling_params = SamplingParams( + temperature=1.0, + max_tokens=100, + n=2, + guided_decoding=GuidedDecodingParams( + json_object=True, + backend=guided_decoding_backend, + disable_any_whitespace=disable_any_whitespace)) outputs = llm.generate( prompts=("Generate a JSON object with curly braces for a person with " @@ -337,7 +371,7 @@ def test_guided_json_object(llm, guided_decoding_backend: str): print(generated_text) assert generated_text is not None - if 'disable-any-whitespace' in guided_decoding_backend: + if disable_any_whitespace: assert "\n" not in generated_text # Parse to verify it is valid JSON @@ -359,14 +393,18 @@ class CarDescription(BaseModel): @pytest.mark.skip_global_cleanup -@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) -def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str): +@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", + GUIDED_DECODING_BACKENDS) +def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str, + disable_any_whitespace: bool): json_schema = CarDescription.model_json_schema() - sampling_params = SamplingParams(temperature=1.0, - max_tokens=1000, - guided_decoding=GuidedDecodingParams( - json=json_schema, - backend=guided_decoding_backend)) + sampling_params = SamplingParams( + temperature=1.0, + max_tokens=1000, + guided_decoding=GuidedDecodingParams( + json=json_schema, + backend=guided_decoding_backend, + disable_any_whitespace=disable_any_whitespace)) outputs = llm.generate( prompts="Generate a JSON with the brand, model and car_type of" "the most iconic car from the 90's", @@ -387,9 +425,10 @@ def test_guided_json_completion_with_enum(llm, guided_decoding_backend: str): @pytest.mark.skip_global_cleanup -@pytest.mark.parametrize("guided_decoding_backend", GUIDED_DECODING_BACKENDS) -def test_guided_number_range_json_completion(llm, - guided_decoding_backend: str): +@pytest.mark.parametrize("guided_decoding_backend,disable_any_whitespace", + GUIDED_DECODING_BACKENDS) +def test_guided_number_range_json_completion(llm, guided_decoding_backend: str, + disable_any_whitespace: bool): sample_output_schema = { "type": "object", "properties": { @@ -413,8 +452,10 @@ def test_guided_number_range_json_completion(llm, sampling_params = SamplingParams( temperature=1.0, max_tokens=1000, - guided_decoding=GuidedDecodingParams(json=sample_output_schema, - backend=guided_decoding_backend), + guided_decoding=GuidedDecodingParams( + json=sample_output_schema, + backend=guided_decoding_backend, + disable_any_whitespace=disable_any_whitespace), ) outputs = llm.generate( prompts=[ @@ -466,8 +507,12 @@ def test_guidance_no_additional_properties(llm): "large JSON object with key-value pairs a1=b1, a2=b2, ..., a20=b20" "<|im_end|>\n<|im_start|>assistant\n") - def generate_with_backend(backend): - guided_params = GuidedDecodingParams(json=schema, backend=backend) + def generate_with_backend(backend, disable_additional_properties): + guided_params = GuidedDecodingParams( + json=schema, + backend=backend, + disable_any_whitespace=True, + disable_additional_properties=disable_additional_properties) sampling_params = SamplingParams(temperature=0, max_tokens=256, guided_decoding=guided_params) @@ -481,7 +526,7 @@ def test_guidance_no_additional_properties(llm): jsonschema.validate(instance=parsed_json, schema=schema) return parsed_json - base_generated = generate_with_backend('guidance:disable-any-whitespace') + base_generated = generate_with_backend("guidance", False) assert "a1" in base_generated assert "a2" in base_generated assert "a3" in base_generated @@ -490,8 +535,7 @@ def test_guidance_no_additional_properties(llm): assert "a5" in base_generated assert "a6" in base_generated - generated = generate_with_backend( - 'guidance:no-additional-properties,disable-any-whitespace') + generated = generate_with_backend("guidance", True) assert "a1" in generated assert "a2" in generated assert "a3" in generated diff --git a/tests/entrypoints/openai/test_audio.py b/tests/entrypoints/openai/test_audio.py index 72e616656775e6f26db15ae628dfca080befa934..7f959f31201911957d7280817276e2765138bc07 100644 --- a/tests/entrypoints/openai/test_audio.py +++ b/tests/entrypoints/openai/test_audio.py @@ -272,7 +272,7 @@ async def test_chat_streaming_audio(client: openai.AsyncOpenAI, chat_completion = await client.chat.completions.create( model=model_name, messages=messages, - max_completion_tokens=10, + max_completion_tokens=8, temperature=0.0, ) output = chat_completion.choices[0].message.content @@ -282,7 +282,7 @@ async def test_chat_streaming_audio(client: openai.AsyncOpenAI, stream = await client.chat.completions.create( model=model_name, messages=messages, - max_completion_tokens=10, + max_completion_tokens=8, temperature=0.0, stream=True, ) @@ -332,7 +332,7 @@ async def test_chat_streaming_input_audio(client: openai.AsyncOpenAI, chat_completion = await client.chat.completions.create( model=model_name, messages=messages, - max_completion_tokens=10, + max_completion_tokens=8, temperature=0.0, ) output = chat_completion.choices[0].message.content @@ -342,7 +342,7 @@ async def test_chat_streaming_input_audio(client: openai.AsyncOpenAI, stream = await client.chat.completions.create( model=model_name, messages=messages, - max_completion_tokens=10, + max_completion_tokens=8, temperature=0.0, stream=True, ) diff --git a/tests/entrypoints/openai/test_chat_template.py b/tests/entrypoints/openai/test_chat_template.py index 78e40eeecde13493808b68dbc1414c42e6cbb6a3..f18fbb0a9c71153de56d04872d06fc77d2515b41 100644 --- a/tests/entrypoints/openai/test_chat_template.py +++ b/tests/entrypoints/openai/test_chat_template.py @@ -2,11 +2,13 @@ import pytest +from vllm.config import ModelConfig from vllm.entrypoints.chat_utils import (apply_hf_chat_template, load_chat_template) from vllm.entrypoints.openai.protocol import ChatCompletionRequest from vllm.transformers_utils.tokenizer import get_tokenizer +from ...models.registry import HF_EXAMPLE_MODELS from ...utils import VLLM_PATH chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja" @@ -91,8 +93,22 @@ def test_no_load_chat_template_literallike(): MODEL_TEMPLATE_GENERATON_OUTPUT) def test_get_gen_prompt(model, template, add_generation_prompt, continue_final_message, expected_output): + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_available_online(on_fail="skip") + + model_config = ModelConfig( + model, + tokenizer=model_info.tokenizer or model, + tokenizer_mode=model_info.tokenizer_mode, + trust_remote_code=model_info.trust_remote_code, + hf_overrides=model_info.hf_overrides, + ) + # Initialize the tokenizer - tokenizer = get_tokenizer(tokenizer_name=model) + tokenizer = get_tokenizer( + tokenizer_name=model_config.tokenizer, + trust_remote_code=model_config.trust_remote_code, + ) template_content = load_chat_template(chat_template=template) # Create a mock request object using keyword arguments @@ -106,10 +122,10 @@ def test_get_gen_prompt(model, template, add_generation_prompt, # Call the function and get the result result = apply_hf_chat_template( - tokenizer, - trust_remote_code=True, + tokenizer=tokenizer, conversation=mock_request.messages, chat_template=mock_request.chat_template or template_content, + model_config=model_config, tools=None, add_generation_prompt=mock_request.add_generation_prompt, continue_final_message=mock_request.continue_final_message, diff --git a/tests/entrypoints/openai/test_chat_with_tool_reasoning.py b/tests/entrypoints/openai/test_chat_with_tool_reasoning.py index 53df1d9241b7877fafa58687be450dc18629de57..e00f001ef730d925c51dc54b698db6beff807d9e 100644 --- a/tests/entrypoints/openai/test_chat_with_tool_reasoning.py +++ b/tests/entrypoints/openai/test_chat_with_tool_reasoning.py @@ -13,9 +13,9 @@ MODEL_NAME = "Qwen/QwQ-32B" @pytest.fixture(scope="module") def server(): # noqa: F811 args = [ - "--max-model-len", "8192", "--enforce-eager", "--enable-reasoning", - "--reasoning-parser", "deepseek_r1", "--enable-auto-tool-choice", - "--tool-call-parser", "hermes" + "--max-model-len", "8192", "--enforce-eager", "--reasoning-parser", + "deepseek_r1", "--enable-auto-tool-choice", "--tool-call-parser", + "hermes" ] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: diff --git a/tests/entrypoints/openai/test_classification.py b/tests/entrypoints/openai/test_classification.py new file mode 100644 index 0000000000000000000000000000000000000000..97124c85e0d33ff6bf3fb98f8f3ecc6328a5553f --- /dev/null +++ b/tests/entrypoints/openai/test_classification.py @@ -0,0 +1,156 @@ +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import requests + +from vllm.entrypoints.openai.protocol import ClassificationResponse + +from ...utils import RemoteOpenAIServer + +MODEL_NAME = "jason9693/Qwen2.5-1.5B-apeach" +DTYPE = "float32" # Use float32 to avoid NaN issue + + +@pytest.fixture(scope="module") +def server(): + args = [ + "--enforce-eager", + "--max-model-len", + "512", + "--dtype", + DTYPE, + ] + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +def test_single_input_classification(server: RemoteOpenAIServer, + model_name: str): + input_text = "This product was excellent and exceeded my expectations" + + classification_response = requests.post( + server.url_for("classify"), + json={ + "model": model_name, + "input": input_text + }, + ) + + classification_response.raise_for_status() + output = ClassificationResponse.model_validate( + classification_response.json()) + + assert output.object == "list" + assert output.model == MODEL_NAME + assert len(output.data) == 1 + assert hasattr(output.data[0], "label") + assert hasattr(output.data[0], "probs") + + +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +def test_multiple_inputs_classification(server: RemoteOpenAIServer, + model_name: str): + input_texts = [ + "The product arrived on time and works perfectly", + "I'm very satisfied with my purchase, would buy again", + "The customer service was helpful and resolved my issue quickly", + "This product broke after one week, terrible quality", + "I'm very disappointed with this purchase, complete waste of money", + "The customer service was rude and unhelpful", + ] + + classification_response = requests.post( + server.url_for("classify"), + json={ + "model": model_name, + "input": input_texts + }, + ) + output = ClassificationResponse.model_validate( + classification_response.json()) + + assert len(output.data) == len(input_texts) + for i, item in enumerate(output.data): + assert item.index == i + assert hasattr(item, "label") + assert hasattr(item, "probs") + assert len(item.probs) == item.num_classes + assert item.label in ["Default", "Spoiled"] + + +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +def test_truncate_prompt_tokens(server: RemoteOpenAIServer, model_name: str): + long_text = "hello " * 600 + + classification_response = requests.post( + server.url_for("classify"), + json={ + "model": model_name, + "input": long_text, + "truncate_prompt_tokens": 5 + }, + ) + + classification_response.raise_for_status() + output = ClassificationResponse.model_validate( + classification_response.json()) + + assert len(output.data) == 1 + assert output.data[0].index == 0 + assert hasattr(output.data[0], "probs") + assert output.usage.prompt_tokens == 5 + assert output.usage.total_tokens == 5 + + +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +def test_invalid_truncate_prompt_tokens_error(server: RemoteOpenAIServer, + model_name: str): + classification_response = requests.post( + server.url_for("classify"), + json={ + "model": model_name, + "input": "test", + "truncate_prompt_tokens": 513 + }, + ) + + error = classification_response.json() + assert classification_response.status_code == 400 + assert error["object"] == "error" + assert "truncate_prompt_tokens" in error["message"] + + +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +def test_empty_input_error(server: RemoteOpenAIServer, model_name: str): + classification_response = requests.post( + server.url_for("classify"), + json={ + "model": model_name, + "input": "" + }, + ) + + error = classification_response.json() + assert classification_response.status_code == 400 + assert error["object"] == "error" + + +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +def test_batch_classification_empty_list(server: RemoteOpenAIServer, + model_name: str): + classification_response = requests.post( + server.url_for("classify"), + json={ + "model": model_name, + "input": [] + }, + ) + classification_response.raise_for_status() + output = ClassificationResponse.model_validate( + classification_response.json()) + + assert output.object == "list" + assert isinstance(output.data, list) + assert len(output.data) == 0 diff --git a/tests/entrypoints/openai/test_cli_args.py b/tests/entrypoints/openai/test_cli_args.py index e0285b5e556646a6e5378b99ba413a0f95275d94..8d1abe28a027ad7bfa2282036fe4345ae660d540 100644 --- a/tests/entrypoints/openai/test_cli_args.py +++ b/tests/entrypoints/openai/test_cli_args.py @@ -122,31 +122,23 @@ def test_enable_auto_choice_fails_with_enable_reasoning(serve_parser): """Ensure validation fails if reasoning is enabled with auto tool choice""" args = serve_parser.parse_args(args=[ "--enable-auto-tool-choice", - "--enable-reasoning", + "--reasoning-parser", + "deepseek_r1", ]) with pytest.raises(TypeError): validate_parsed_serve_args(args) -def test_enable_reasoning_passes_with_reasoning_parser(serve_parser): +def test_passes_with_reasoning_parser(serve_parser): """Ensure validation passes if reasoning is enabled with a reasoning parser""" args = serve_parser.parse_args(args=[ - "--enable-reasoning", "--reasoning-parser", "deepseek_r1", ]) validate_parsed_serve_args(args) -def test_enable_reasoning_fails_without_reasoning_parser(serve_parser): - """Ensure validation fails if reasoning is enabled - without a reasoning parser""" - args = serve_parser.parse_args(args=["--enable-reasoning"]) - with pytest.raises(TypeError): - validate_parsed_serve_args(args) - - def test_chat_template_validation_for_happy_paths(serve_parser): """Ensure validation passes if the chat template exists""" args = serve_parser.parse_args( diff --git a/tests/entrypoints/openai/test_completion_with_function_calling.py b/tests/entrypoints/openai/test_completion_with_function_calling.py new file mode 100644 index 0000000000000000000000000000000000000000..dad76b54c5e99ab9894803e546d465a6f324576e --- /dev/null +++ b/tests/entrypoints/openai/test_completion_with_function_calling.py @@ -0,0 +1,154 @@ +# SPDX-License-Identifier: Apache-2.0 + +import openai # use the official client for correctness check +import pytest +import pytest_asyncio + +# downloading lora to test lora requests +from ...utils import RemoteOpenAIServer + +# any model with a chat template should work here +MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct" + + +@pytest.fixture(scope="module") +def server(): # noqa: F811 + args = [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "half", + "--enable-auto-tool-choice", + "--guided-decoding-backend", + "xgrammar", + "--tool-call-parser", + "hermes" + ] + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def client(server): + async with server.get_async_client() as async_client: + yield async_client + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_required_tool_use(client: openai.AsyncOpenAI, model_name: str): + tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": + "The city to find the weather for, e.g. 'Vienna'", + "default": "Vienna", + }, + "country": { + "type": + "string", + "description": + "The country that the city is in, e.g. 'Austria'", + }, + "unit": { + "type": "string", + "description": + "The unit to fetch the temperature in", + "enum": ["celsius", "fahrenheit"], + }, + }, + "required": ["country", "unit"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_forecast", + "description": "Get the weather forecast for a given location", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": + "The city to get the forecast for, e.g. 'Vienna'", + "default": "Vienna", + }, + "country": { + "type": + "string", + "description": + "The country that the city is in, e.g. 'Austria'", + }, + "days": { + "type": + "integer", + "description": + "Number of days to get the forecast for (1-7)", + }, + "unit": { + "type": "string", + "description": + "The unit to fetch the temperature in", + "enum": ["celsius", "fahrenheit"], + }, + }, + "required": ["country", "days", "unit"], + }, + }, + }, + ] + + messages = [ + { + "role": "user", + "content": "Hi! How are you doing today?" + }, + { + "role": "assistant", + "content": "I'm doing well! How can I help you?" + }, + { + "role": + "user", + "content": + "Can you tell me what the current weather is in Berlin and the "\ + "forecast for the next 5 days, in fahrenheit?", + }, + ] + + # Non-streaming test + chat_completion = await client.chat.completions.create( + messages=messages, + model=model_name, + tools=tools, + tool_choice="required", + ) + + assert chat_completion.choices[0].message.tool_calls is not None + assert len(chat_completion.choices[0].message.tool_calls) > 0 + + # Streaming test + stream = await client.chat.completions.create( + messages=messages, + model=model_name, + tools=tools, + tool_choice="required", + stream=True, + ) + + output = [] + async for chunk in stream: + if chunk.choices and chunk.choices[0].delta.tool_calls: + output.extend(chunk.choices[0].delta.tool_calls) + + assert len(output) > 0 diff --git a/tests/entrypoints/openai/test_completion_with_prompt_embeds.py b/tests/entrypoints/openai/test_completion_with_prompt_embeds.py new file mode 100644 index 0000000000000000000000000000000000000000..b7ee3e33c2d250fd3ac3f5f2cfb6f1f9cf19a096 --- /dev/null +++ b/tests/entrypoints/openai/test_completion_with_prompt_embeds.py @@ -0,0 +1,257 @@ +# SPDX-License-Identifier: Apache-2.0 + +import base64 +import io +import shutil +from tempfile import TemporaryDirectory + +import openai # use the official client for correctness check +import pytest +import pytest_asyncio +import torch +# downloading lora to test lora requests +from huggingface_hub import snapshot_download +from openai import BadRequestError +from transformers import AutoConfig, AutoTokenizer + +from ...utils import RemoteOpenAIServer + +# any model with a chat template should work here +MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" +LORA_NAME = "typeof/zephyr-7b-beta-lora" + +CONFIG = AutoConfig.from_pretrained(MODEL_NAME) + + +@pytest.fixture(scope="module") +def zephyr_lora_files(): + return snapshot_download(repo_id=LORA_NAME) + + +@pytest.fixture(scope="module") +def zephyr_lora_added_tokens_files(zephyr_lora_files): + tmp_dir = TemporaryDirectory() + tmp_model_dir = f"{tmp_dir.name}/zephyr" + shutil.copytree(zephyr_lora_files, tmp_model_dir) + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + # Copy tokenizer to adapter and add some unique tokens + # 32000, 32001, 32002 + added = tokenizer.add_tokens(["vllm1", "vllm2", "vllm3"], + special_tokens=True) + assert added == 3 + tokenizer.save_pretrained(tmp_model_dir) + yield tmp_model_dir + tmp_dir.cleanup() + + +@pytest.fixture(scope="module") +def default_server_args( + zephyr_lora_files, + zephyr_lora_added_tokens_files, +) -> list[str]: + return [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "bfloat16", + "--max-model-len", + "8192", + "--max-num-seqs", + "128", + "--enforce-eager", + # Prompt Embeds server args + "--enable-prompt-embeds", + "--no-enable-chunked-prefill", + ] + + +@pytest.fixture(scope="module", + params=["", "--disable-frontend-multiprocessing"]) +def server_with_prompt_embeds(default_server_args, request): + if request.param: + default_server_args.append(request.param) + + with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def client_with_prompt_embeds(server_with_prompt_embeds): + async with server_with_prompt_embeds.get_async_client() as async_client: + yield async_client + + +def create_dummy_embeds(num_tokens: int = 5) -> str: + """Create dummy embeddings and return them as base64 encoded string.""" + dummy_embeds = torch.randn(num_tokens, CONFIG.hidden_size) + buffer = io.BytesIO() + torch.save(dummy_embeds, buffer) + return base64.b64encode(buffer.getvalue()).decode('utf-8') + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_completions_with_prompt_embeds( + client_with_prompt_embeds: openai.AsyncOpenAI, model_name: str): + # Test case: Single prompt embeds input + encoded_embeds = create_dummy_embeds() + completion = await client_with_prompt_embeds.completions.create( + model=model_name, + prompt="", # Add empty prompt as required parameter + max_tokens=5, + temperature=0.0, + extra_body={"prompt_embeds": encoded_embeds}) + assert len(completion.choices[0].text) >= 1 + assert completion.choices[0].prompt_logprobs is None + + # Test case: batch completion with prompt_embeds + encoded_embeds2 = create_dummy_embeds() + completion = await client_with_prompt_embeds.completions.create( + model=model_name, + prompt="", # Add empty prompt as required parameter + max_tokens=5, + temperature=0.0, + extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]}) + assert len(completion.choices) == 2 + assert len(completion.choices[0].text) >= 1 + assert len(completion.choices[1].text) >= 1 + + # Test case: streaming with prompt_embeds + encoded_embeds = create_dummy_embeds() + single_completion = await client_with_prompt_embeds.completions.create( + model=model_name, + prompt="", # Add empty prompt as required parameter + max_tokens=5, + temperature=0.0, + extra_body={"prompt_embeds": encoded_embeds}) + single_output = single_completion.choices[0].text + + stream = await client_with_prompt_embeds.completions.create( + model=model_name, + prompt="", # Add empty prompt as required parameter + max_tokens=5, + temperature=0.0, + stream=True, + extra_body={"prompt_embeds": encoded_embeds}) + chunks = [] + finish_reason_count = 0 + async for chunk in stream: + chunks.append(chunk.choices[0].text) + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + assert finish_reason_count == 1 + assert chunk.choices[0].finish_reason == "length" + assert chunk.choices[0].text + assert "".join(chunks) == single_output + + # Test case: batch streaming with prompt_embeds + encoded_embeds2 = create_dummy_embeds() + stream = await client_with_prompt_embeds.completions.create( + model=model_name, + prompt="", # Add empty prompt as required parameter + max_tokens=5, + temperature=0.0, + stream=True, + extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]}) + chunks_stream_embeds: list[list[str]] = [[], []] + finish_reason_count = 0 + async for chunk in stream: + chunks_stream_embeds[chunk.choices[0].index].append( + chunk.choices[0].text) + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + assert finish_reason_count == 2 + assert chunk.choices[0].finish_reason == "length" + assert chunk.choices[0].text + assert len(chunks_stream_embeds[0]) > 0 + assert len(chunks_stream_embeds[1]) > 0 + + # Test case: mixed text and prompt_embeds + encoded_embeds = create_dummy_embeds() + completion_mixed = await client_with_prompt_embeds.completions.create( + model=model_name, + prompt="This is a prompt", + max_tokens=5, + temperature=0.0, + extra_body={"prompt_embeds": encoded_embeds}) + assert len(completion.choices) == 2 + completion_text_only = await client_with_prompt_embeds.completions.create( + model=model_name, + prompt="This is a prompt", + max_tokens=5, + temperature=0.0, + ) + completion_embeds_only = await client_with_prompt_embeds.completions.create( + model=model_name, + prompt="", + max_tokens=5, + temperature=0.0, + extra_body={"prompt_embeds": encoded_embeds}) + # Embeddings responses should be handled first + assert completion_mixed.choices[0].text == completion_embeds_only.choices[ + 0].text + assert completion_mixed.choices[1].text == completion_text_only.choices[ + 0].text + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_completions_errors_with_prompt_embeds( + client_with_prompt_embeds: openai.AsyncOpenAI, model_name: str): + # Test error case: invalid prompt_embeds + with pytest.raises(BadRequestError): + await client_with_prompt_embeds.completions.create( + prompt="", + model=model_name, + max_tokens=5, + temperature=0.0, + extra_body={"prompt_embeds": "invalid_base64"}) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("logprobs_arg", [1, 0]) +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_completions_with_logprobs_and_prompt_embeds( + client_with_prompt_embeds: openai.AsyncOpenAI, logprobs_arg: int, + model_name: str): + # Test case: Logprobs using prompt_embeds + encoded_embeds = create_dummy_embeds() + completion = await client_with_prompt_embeds.completions.create( + model=model_name, + prompt="", # Add empty prompt as required parameter + max_tokens=5, + temperature=0.0, + echo=False, + logprobs=logprobs_arg, + extra_body={"prompt_embeds": encoded_embeds}) + + logprobs = completion.choices[0].logprobs + assert logprobs is not None + assert len(logprobs.text_offset) == 5 + assert len(logprobs.token_logprobs) == 5 + assert len(logprobs.top_logprobs) == 5 + for top_logprobs in logprobs.top_logprobs[1:]: + assert max(logprobs_arg, 1) <= len(top_logprobs) <= logprobs_arg + 1 + assert len(logprobs.tokens) == 5 + + # Test case: Log probs with batch completion and prompt_embeds + encoded_embeds2 = create_dummy_embeds() + completion = await client_with_prompt_embeds.completions.create( + model=model_name, + prompt="", # Add empty prompt as required parameter + max_tokens=5, + temperature=0.0, + echo=False, + logprobs=logprobs_arg, + extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]}) + + assert len(completion.choices) == 2 + for choice in completion.choices: + logprobs = choice.logprobs + assert logprobs is not None + assert len(logprobs.text_offset) == 5 + assert len(logprobs.token_logprobs) == 5 + assert len(logprobs.top_logprobs) == 5 + for top_logprobs in logprobs.top_logprobs[1:]: + assert max(logprobs_arg, + 1) <= len(top_logprobs) <= logprobs_arg + 1 + assert len(logprobs.tokens) == 5 diff --git a/tests/entrypoints/openai/test_embedding.py b/tests/entrypoints/openai/test_embedding.py index 50b20e78c4c420ec0306c690f4b3bd38360c0e05..1019bfd589362eb6205eb92e0e34cef4452b9715 100644 --- a/tests/entrypoints/openai/test_embedding.py +++ b/tests/entrypoints/openai/test_embedding.py @@ -11,7 +11,7 @@ import requests from vllm.entrypoints.openai.protocol import EmbeddingResponse from vllm.transformers_utils.tokenizer import get_tokenizer -from ...models.embedding.utils import correctness_test +from ...models.utils import run_embedding_correctness_test from ...utils import RemoteOpenAIServer MODEL_NAME = "intfloat/multilingual-e5-small" @@ -76,7 +76,7 @@ async def test_single_embedding(hf_model, client: openai.AsyncOpenAI, assert embeddings.usage.total_tokens == 11 vllm_outputs = [d.embedding for d in embeddings.data] - correctness_test(hf_model, input_texts, vllm_outputs) + run_embedding_correctness_test(hf_model, input_texts, vllm_outputs) # test using token IDs input_tokens = [1, 1, 1, 1, 1] @@ -121,7 +121,7 @@ async def test_batch_embedding(hf_model, client: openai.AsyncOpenAI, assert embeddings.usage.total_tokens == 33 vllm_outputs = [d.embedding for d in embeddings.data] - correctness_test(hf_model, input_texts, vllm_outputs) + run_embedding_correctness_test(hf_model, input_texts, vllm_outputs) # test list[list[int]] input_tokens = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24], @@ -208,7 +208,7 @@ async def test_batch_base64_embedding(hf_model, client: openai.AsyncOpenAI, model=model_name, encoding_format="float") float_data = [d.embedding for d in responses_float.data] - correctness_test(hf_model, input_texts, float_data) + run_embedding_correctness_test(hf_model, input_texts, float_data) responses_base64 = await client.embeddings.create(input=input_texts, model=model_name, @@ -219,13 +219,13 @@ async def test_batch_base64_embedding(hf_model, client: openai.AsyncOpenAI, np.frombuffer(base64.b64decode(data.embedding), dtype="float32").tolist()) - correctness_test(hf_model, input_texts, base64_data) + run_embedding_correctness_test(hf_model, input_texts, base64_data) # Default response is float32 decoded from base64 by OpenAI Client responses_default = await client.embeddings.create(input=input_texts, model=model_name) default_data = [d.embedding for d in responses_default.data] - correctness_test(hf_model, input_texts, default_data) + run_embedding_correctness_test(hf_model, input_texts, default_data) @pytest.mark.asyncio diff --git a/tests/entrypoints/openai/test_embedding_dimensions.py b/tests/entrypoints/openai/test_embedding_dimensions.py index 9f5a8c6839bc550eae827149342def6bb3ea74e1..332fa332a4a41f90435c11c95384783f7ba712ad 100644 --- a/tests/entrypoints/openai/test_embedding_dimensions.py +++ b/tests/entrypoints/openai/test_embedding_dimensions.py @@ -11,7 +11,7 @@ import pytest from vllm.entrypoints.openai.protocol import EmbeddingResponse from ...conftest import HfRunner -from ...models.embedding.utils import EmbedModelInfo, correctness_test +from ...models.utils import EmbedModelInfo, run_embedding_correctness_test from ...utils import RemoteOpenAIServer MODELS = [ @@ -95,7 +95,8 @@ async def test_matryoshka(model_info: EmbedModelInfo, assert len(embeddings.data[0].embedding) == dimensions vllm_outputs = [d.embedding for d in embeddings.data] - correctness_test(hf_model, prompts, vllm_outputs, dimensions) + run_embedding_correctness_test(hf_model, prompts, vllm_outputs, + dimensions) if model_info.is_matryoshka: valid_dimensions: list[Optional[int]] = [None] diff --git a/tests/entrypoints/openai/test_openai_schema.py b/tests/entrypoints/openai/test_openai_schema.py index 1ccb803a328d608bf1fbf0cd7731c4dd935ef4d7..5c585d54c429b22e78f5a5bcce5353aafbc6a6ab 100644 --- a/tests/entrypoints/openai/test_openai_schema.py +++ b/tests/entrypoints/openai/test_openai_schema.py @@ -44,6 +44,6 @@ schema = schemathesis.from_pytest_fixture("get_schema") @schema.parametrize() @schema.override(headers={"Content-Type": "application/json"}) -async def test_openapi_stateless(case): +def test_openapi_stateless(case: schemathesis.Case): #No need to verify SSL certificate for localhost - await case.call_and_validate(verify=False) + case.call_and_validate(verify=False) diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index 19d16713b209f2eeffcb41f613e7ec91af04bef6..5e11af8cf89294ae49f3138a65873bf1bb59b7ff 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -272,3 +272,43 @@ def test_serving_chat_could_load_correct_generation_config(): assert mock_engine.generate.call_args.args[1].temperature == 0.0 assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05 + + +def test_serving_chat_did_set_correct_cache_salt(): + mock_model_config = MockModelConfig() + + mock_engine = MagicMock(spec=MQLLMEngineClient) + mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME) + mock_engine.errored = False + + # Initialize the serving chat + models = OpenAIServingModels(engine_client=mock_engine, + base_model_paths=BASE_MODEL_PATHS, + model_config=mock_model_config) + serving_chat = OpenAIServingChat(mock_engine, + mock_model_config, + models, + response_role="assistant", + chat_template=CHAT_TEMPLATE, + chat_template_content_format="auto", + request_logger=None) + + # Test cache_salt + req = ChatCompletionRequest( + model=MODEL_NAME, + messages=[{ + "role": "user", + "content": "what is 1+1?" + }], + ) + + # By default cache_salt in the engine prompt is not set + with suppress(Exception): + asyncio.run(serving_chat.create_chat_completion(req)) + assert "cache_salt" not in mock_engine.generate.call_args.args[0] + + # Test with certain cache_salt + req.cache_salt = "test_salt" + with suppress(Exception): + asyncio.run(serving_chat.create_chat_completion(req)) + assert mock_engine.generate.call_args.args[0]["cache_salt"] == "test_salt" diff --git a/tests/entrypoints/openai/test_tokenization.py b/tests/entrypoints/openai/test_tokenization.py index 663b722426c5860235ac7eda26594b82b36fde8c..9773f3e45b99c452d3432b9c44e9a44221cada4c 100644 --- a/tests/entrypoints/openai/test_tokenization.py +++ b/tests/entrypoints/openai/test_tokenization.py @@ -145,6 +145,83 @@ async def test_tokenize_chat( } +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name,tokenizer_name", + [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")], + indirect=["tokenizer_name"], +) +async def test_tokenize_chat_with_tools( + server: RemoteOpenAIServer, + model_name: str, + tokenizer_name: str, +): + tokenizer = get_tokenizer(tokenizer_name=tokenizer_name, + tokenizer_mode="fast") + + for add_generation in [False, True]: + for add_special in [False, True]: + conversation = [{ + "role": + "user", + "content": + "What's the weather like in Paris today?", + }] + + tools = [{ + "type": "function", + "function": { + "name": "get_weather", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string" + } + }, + }, + }, + }] + + for continue_final in [False, True]: + if add_generation and continue_final: + continue + if continue_final: + conversation.append({ + "role": "assistant", + "content": "Sure," + }) + + prompt = tokenizer.apply_chat_template( + add_generation_prompt=add_generation, + continue_final_message=continue_final, + conversation=conversation, + tools=tools, + tokenize=False, + ) + tokens = tokenizer.encode(prompt, + add_special_tokens=add_special) + + response = requests.post( + server.url_for("tokenize"), + json={ + "add_generation_prompt": add_generation, + "continue_final_message": continue_final, + "add_special_tokens": add_special, + "messages": conversation, + "model": model_name, + "tools": tools, + }, + ) + response.raise_for_status() + + assert response.json() == { + "tokens": tokens, + "count": len(tokens), + "max_model_len": 8192, + } + + @pytest.mark.asyncio @pytest.mark.parametrize( "model_name,tokenizer_name", diff --git a/tests/entrypoints/openai/test_truncation.py b/tests/entrypoints/openai/test_truncation.py new file mode 100644 index 0000000000000000000000000000000000000000..137ed9db85891ba93613078ac3e08b6ab308e20e --- /dev/null +++ b/tests/entrypoints/openai/test_truncation.py @@ -0,0 +1,103 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import openai +import pytest +import pytest_asyncio + +from tests.utils import RemoteOpenAIServer + +MODEL_NAME = "sentence-transformers/all-MiniLM-L12-v2" +max_model_len = 128 + +input = """Immerse yourself in the enchanting chronicle of calculus, a + mathematical domain that has radically transformed our comprehension of + change and motion. Despite its roots in ancient civilizations, the + formal birth of calculus predominantly occurred in the 17th century, + primarily under the influential guidance of Sir Isaac Newton and Gottfried + Wilhelm Leibniz. The earliest traces of calculus concepts are found in + ancient Greek mathematics,most notably in the works of Eudoxus and + Archimedes, around 300 BCE. They utilized the 'method of exhaustion'—a + technique for computing areas and volumes through the use of finite sums. + This methodology laid crucial foundational work for integral calculus. + In the 17th century, both Newton and Leibniz independently pioneered + calculus, each contributing unique perspectives that would shape this new + field.""" + + +@pytest.fixture(scope="module") +def server(): + args = [ + "--task", + "embed", + "--dtype", + "bfloat16", + "--enforce-eager", + "--max-model-len", + str(max_model_len), + ] + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def client(server): + async with server.get_async_client() as async_client: + yield async_client + + +@pytest.mark.asyncio +async def test_smaller_truncation_size(client: openai.AsyncOpenAI): + truncation_size = 10 + kwargs: dict[str, Any] = { + "model": MODEL_NAME, + "input": input, + "truncate_prompt_tokens": truncation_size + } + + response = await client.post(path="embeddings", + cast_to=object, + body={**kwargs}) + + assert response["usage"]["prompt_tokens"] == truncation_size + + +@pytest.mark.asyncio +async def test_bigger_truncation_size(client: openai.AsyncOpenAI): + truncation_size = max_model_len + 1 + kwargs: dict[str, Any] = { + "model": MODEL_NAME, + "input": input, + "truncate_prompt_tokens": truncation_size + } + + with pytest.raises(openai.BadRequestError) as err: + err = await client.post(path="embeddings", + cast_to=object, + body={**kwargs}) + + assert str(err) == f"""openai.BadRequestError: + Error code: 400 - {{'object': 'error', + 'message': 'truncate_prompt_tokens value + ({truncation_size}) + is greater than max_model_len ({max_model_len}). + Please, select a smaller truncation size.', + 'type': 'BadRequestError', + 'param': None, 'code': 400}}""" + + +@pytest.mark.asyncio +async def test_max_truncation_size(client: openai.AsyncOpenAI): + truncation_size = -1 + kwargs: dict[str, Any] = { + "model": MODEL_NAME, + "input": input, + "truncate_prompt_tokens": truncation_size + } + + response = await client.post(path="embeddings", + cast_to=object, + body={**kwargs}) + + assert response["usage"]["prompt_tokens"] == max_model_len diff --git a/tests/entrypoints/openai/tool_parsers/utils.py b/tests/entrypoints/openai/tool_parsers/utils.py index 6ad5aa26ffa14a991e3652a97e985fdaaa95a76f..ab8f4bd678fdfd6d0cb4a902fa433806811b3751 100644 --- a/tests/entrypoints/openai/tool_parsers/utils.py +++ b/tests/entrypoints/openai/tool_parsers/utils.py @@ -32,7 +32,7 @@ class StreamingToolReconstructor: assert len(delta.tool_calls) < 2, ( "Streaming should include only one tool call per update.") for call_delta in delta.tool_calls: - assert call_delta.type == "function", ( + assert call_delta.type is None or call_delta.type == "function", ( "Streaming tool calls should only emit function calls. Got " f"{call_delta.type}") current_tool_call = self.tool_calls[ diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py index 92c1e0fec6b743154e0bc6e322614136a7c410c3..9f1f2321d9e648dabdb291997ae97d850fab1be3 100644 --- a/tests/entrypoints/test_chat_utils.py +++ b/tests/entrypoints/test_chat_utils.py @@ -4,8 +4,6 @@ import warnings from typing import Optional import pytest -from packaging.version import Version -from transformers import __version__ as TRANSFORMERS_VERSION from vllm.assets.image import ImageAsset from vllm.config import ModelConfig @@ -19,6 +17,7 @@ from vllm.multimodal import MultiModalDataDict from vllm.multimodal.utils import encode_image_base64 from vllm.transformers_utils.tokenizer_group import TokenizerGroup +from ..models.registry import HF_EXAMPLE_MODELS from ..utils import VLLM_PATH EXAMPLES_DIR = VLLM_PATH / "examples" @@ -772,6 +771,7 @@ def test_multimodal_image_parsing_matches_hf(model, image_url): enable_lora=False, max_num_seqs=5, max_input_length=None, + trust_remote_code=model_config.trust_remote_code, ) tokenizer = tokenizer_group.tokenizer @@ -793,10 +793,10 @@ def test_multimodal_image_parsing_matches_hf(model, image_url): ) vllm_result = apply_hf_chat_template( - tokenizer, - trust_remote_code=model_config.trust_remote_code, + tokenizer=tokenizer, conversation=conversation, chat_template=None, + model_config=model_config, tools=None, add_generation_prompt=True, ) @@ -813,6 +813,16 @@ def test_multimodal_image_parsing_matches_hf(model, image_url): @pytest.mark.parametrize("use_tools", [True, False]) def test_resolve_hf_chat_template(sample_json_schema, model, use_tools): """checks that chat_template is a dict type for HF models.""" + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_available_online(on_fail="skip") + + model_config = ModelConfig( + model, + tokenizer=model_info.tokenizer or model, + tokenizer_mode=model_info.tokenizer_mode, + trust_remote_code=model_info.trust_remote_code, + hf_overrides=model_info.hf_overrides, + ) # Build the tokenizer group and grab the underlying tokenizer tokenizer_group = TokenizerGroup( @@ -820,6 +830,7 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools): enable_lora=False, max_num_seqs=5, max_input_length=None, + trust_remote_code=model_config.trust_remote_code, ) tokenizer = tokenizer_group.tokenizer @@ -837,7 +848,7 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools): tokenizer, chat_template=None, tools=tools, - trust_remote_code=True, + model_config=model_config, ) assert isinstance(chat_template, str) @@ -857,15 +868,23 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools): ) # yapf: enable def test_resolve_content_format_hf_defined(model, expected_format): - if model == QWEN25VL_MODEL_ID and Version(TRANSFORMERS_VERSION) < Version( - "4.49.0"): - pytest.skip("Qwen2.5-VL requires transformers>=4.49.0") + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_available_online(on_fail="skip") + + model_config = ModelConfig( + model, + tokenizer=model_info.tokenizer or model, + tokenizer_mode=model_info.tokenizer_mode, + trust_remote_code=model_info.trust_remote_code, + hf_overrides=model_info.hf_overrides, + ) tokenizer_group = TokenizerGroup( model, enable_lora=False, max_num_seqs=5, max_input_length=None, + trust_remote_code=model_config.trust_remote_code, ) tokenizer = tokenizer_group.tokenizer @@ -874,7 +893,7 @@ def test_resolve_content_format_hf_defined(model, expected_format): tokenizer, chat_template=None, tools=None, - trust_remote_code=True, + model_config=model_config, ) assert isinstance(chat_template, str) @@ -888,7 +907,66 @@ def test_resolve_content_format_hf_defined(model, expected_format): None, "auto", tokenizer, - trust_remote_code=True, + model_config=model_config, + ) + + assert resolved_format == expected_format + + +# yapf: disable +@pytest.mark.parametrize( + ("model", "expected_format"), + [("Salesforce/blip2-opt-2.7b", "string"), + ("facebook/chameleon-7b", "string"), + ("deepseek-ai/deepseek-vl2-tiny", "string"), + ("microsoft/Florence-2-base", "string"), + ("adept/fuyu-8b", "string"), + ("google/paligemma-3b-mix-224", "string"), + ("Qwen/Qwen-VL", "string"), + ("Qwen/Qwen-VL-Chat", "string")], +) +# yapf: enable +def test_resolve_content_format_fallbacks(model, expected_format): + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_available_online(on_fail="skip") + + model_config = ModelConfig( + model, + tokenizer=model_info.tokenizer or model, + tokenizer_mode=model_info.tokenizer_mode, + trust_remote_code=model_info.trust_remote_code, + hf_overrides=model_info.hf_overrides, + ) + + tokenizer_group = TokenizerGroup( + model_config.tokenizer, + enable_lora=False, + max_num_seqs=5, + max_input_length=None, + trust_remote_code=model_config.trust_remote_code, + ) + tokenizer = tokenizer_group.tokenizer + + # Test detecting the tokenizer's chat_template + chat_template = resolve_hf_chat_template( + tokenizer, + chat_template=None, + tools=None, + model_config=model_config, + ) + assert isinstance(chat_template, str) + + print("[TEXT]") + print(chat_template) + print("[AST]") + print(_try_extract_ast(chat_template)) + + resolved_format = resolve_chat_template_content_format( + None, # Test detecting the tokenizer's chat_template + None, + "auto", + tokenizer, + model_config=model_config, ) assert resolved_format == expected_format @@ -899,17 +977,13 @@ def test_resolve_content_format_hf_defined(model, expected_format): ("template_path", "expected_format"), [("template_alpaca.jinja", "string"), ("template_baichuan.jinja", "string"), - ("template_blip2.jinja", "string"), ("template_chatglm.jinja", "string"), ("template_chatglm2.jinja", "string"), ("template_chatml.jinja", "string"), - ("template_deepseek_vl2.jinja", "string"), ("template_dse_qwen2_vl.jinja", "openai"), ("template_falcon_180b.jinja", "string"), ("template_falcon.jinja", "string"), - ("template_florence2.jinja", "string"), ("template_inkbot.jinja", "string"), - ("template_llava.jinja", "string"), ("template_teleflm.jinja", "string"), ("template_vlm2vec.jinja", "openai"), ("tool_chat_template_granite_20b_fc.jinja", "string"), @@ -922,11 +996,18 @@ def test_resolve_content_format_hf_defined(model, expected_format): ) # yapf: enable def test_resolve_content_format_examples(template_path, expected_format): + model_config = ModelConfig( + PHI3V_MODEL_ID, # Dummy + tokenizer=PHI3V_MODEL_ID, # Dummy + trust_remote_code=True, + ) + tokenizer_group = TokenizerGroup( - PHI3V_MODEL_ID, + PHI3V_MODEL_ID, # Dummy enable_lora=False, max_num_seqs=5, max_input_length=None, + trust_remote_code=model_config.trust_remote_code, ) dummy_tokenizer = tokenizer_group.tokenizer dummy_tokenizer.chat_template = None @@ -944,7 +1025,7 @@ def test_resolve_content_format_examples(template_path, expected_format): None, "auto", dummy_tokenizer, - trust_remote_code=True, + model_config=model_config, ) assert resolved_format == expected_format diff --git a/tests/kernels/attention/test_attention_selector.py b/tests/kernels/attention/test_attention_selector.py index b0414244c2151c8139b8d6faf0416b1507997327..58da01f0ebbf3f08bacc2a0f39299c114d645d72 100644 --- a/tests/kernels/attention/test_attention_selector.py +++ b/tests/kernels/attention/test_attention_selector.py @@ -102,7 +102,10 @@ def test_env( block_size, False, use_mla=use_mla) - assert backend.get_name() == name + if use_v1 and name != "TRITON_MLA": + assert backend.get_name() == f"{name}_VLLM_V1" + else: + assert backend.get_name() == name else: with pytest.raises(ValueError) as exc_info: get_attn_backend(16, @@ -185,8 +188,9 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch): m.setenv(STR_BACKEND_ENV_VAR, STR_FLASH_ATTN_VAL) # Unsupported CUDA arch - monkeypatch.setattr(torch.cuda, "get_device_capability", lambda: - (7, 5)) + monkeypatch.setattr(torch.cuda, + "get_device_capability", + lambda _=None: (7, 5)) backend = get_attn_backend(16, torch.float16, None, 16, False) assert backend.get_name() != STR_FLASH_ATTN_VAL diff --git a/tests/kernels/attention/test_flashmla.py b/tests/kernels/attention/test_flashmla.py index 3985c6834f60e57d1a97f4e7ce85cb45c4c70d41..0d51a8e7fee19a0271f68448dd83db5913304366 100644 --- a/tests/kernels/attention/test_flashmla.py +++ b/tests/kernels/attention/test_flashmla.py @@ -5,11 +5,11 @@ import random import pytest import torch -import triton from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, get_mla_metadata, is_flashmla_supported) +from vllm.triton_utils import triton def cal_diff(x: torch.Tensor, y: torch.Tensor, name: str) -> None: diff --git a/tests/kernels/attention/test_rocm_attention_selector.py b/tests/kernels/attention/test_rocm_attention_selector.py index 4cf7bcb01d4d7641d9fe12a3ad53c76a1674689a..6ffe27abf709e992de426eb7142e793bc086c740 100644 --- a/tests/kernels/attention/test_rocm_attention_selector.py +++ b/tests/kernels/attention/test_rocm_attention_selector.py @@ -48,7 +48,8 @@ def test_selector(monkeypatch: pytest.MonkeyPatch): m.setenv(STR_BACKEND_ENV_VAR, "ROCM_AITER_MLA") backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, False, True) - assert backend.get_name() == "ROCM_AITER_MLA" + assert (backend.get_name() == "ROCM_AITER_MLA" + or backend.get_name() == "ROCM_AITER_MLA_VLLM_V1") # If attention backend is None # If use_mla is true @@ -58,4 +59,5 @@ def test_selector(monkeypatch: pytest.MonkeyPatch): m.setenv("VLLM_ROCM_USE_AITER", "1") backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, False, True) - assert backend.get_name() == "ROCM_AITER_MLA" + assert (backend.get_name() == "ROCM_AITER_MLA" + or backend.get_name() == "ROCM_AITER_MLA_VLLM_V1") diff --git a/tests/kernels/attention/test_triton_unified_attention.py b/tests/kernels/attention/test_triton_unified_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..4e15d00255a4ff3f7b8b8fefc99c05cca46e7439 --- /dev/null +++ b/tests/kernels/attention/test_triton_unified_attention.py @@ -0,0 +1,192 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional + +import pytest +import torch + +from vllm.attention.ops.triton_unified_attention import unified_attention +from vllm.platforms import current_platform + +NUM_HEADS = [(4, 4), (8, 2), (16, 2)] +HEAD_SIZES = [128, 256] +BLOCK_SIZES = [16, 32] + +DTYPES = [torch.float16, torch.bfloat16] +QDTYPES = [None, torch.float8_e4m3fn] +# one value large enough to test overflow in index calculation. +# one value small enough to test the schema op check +NUM_BLOCKS = [32768, 2048] + + +def ref_paged_attn( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + query_lens: list[int], + kv_lens: list[int], + block_tables: torch.Tensor, + scale: float, + sliding_window: Optional[int] = None, + soft_cap: Optional[float] = None, +) -> torch.Tensor: + num_seqs = len(query_lens) + block_tables = block_tables.cpu().numpy() + _, block_size, num_kv_heads, head_size = key_cache.shape + + outputs: list[torch.Tensor] = [] + start_idx = 0 + for i in range(num_seqs): + query_len = query_lens[i] + kv_len = kv_lens[i] + q = query[start_idx:start_idx + query_len] + q *= scale + + num_kv_blocks = (kv_len + block_size - 1) // block_size + block_indices = block_tables[i, :num_kv_blocks] + + k = key_cache[block_indices].view(-1, num_kv_heads, head_size) + k = k[:kv_len] + v = value_cache[block_indices].view(-1, num_kv_heads, head_size) + v = v[:kv_len] + + if q.shape[1] != k.shape[1]: + k = torch.repeat_interleave(k, q.shape[1] // k.shape[1], dim=1) + v = torch.repeat_interleave(v, q.shape[1] // v.shape[1], dim=1) + attn = torch.einsum("qhd,khd->hqk", q, k).float() + empty_mask = torch.ones(query_len, kv_len) + mask = torch.triu(empty_mask, diagonal=kv_len - query_len + 1).bool() + if sliding_window is not None: + sliding_window_mask = torch.triu(empty_mask, + diagonal=kv_len - + (query_len + sliding_window) + + 1).bool().logical_not() + mask |= sliding_window_mask + if soft_cap is not None and soft_cap > 0: + attn = soft_cap * torch.tanh(attn / soft_cap) + attn.masked_fill_(mask, float("-inf")) + attn = torch.softmax(attn, dim=-1).to(v.dtype) + out = torch.einsum("hqk,khd->qhd", attn, v) + + outputs.append(out) + start_idx += query_len + + return torch.cat(outputs, dim=0) + + +@pytest.mark.parametrize("seq_lens", + [[(1, 1328), (5, 18), + (129, 463)], [(1, 523), (1, 37), (1, 2011)]]) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("sliding_window", [None, 256]) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0]) +@pytest.mark.parametrize("num_blocks", NUM_BLOCKS) +@pytest.mark.parametrize("q_dtype", QDTYPES) +@torch.inference_mode() +def test_triton_unified_attn( + seq_lens: list[tuple[int, int]], + num_heads: tuple[int, int], + head_size: int, + sliding_window: Optional[int], + dtype: torch.dtype, + block_size: int, + soft_cap: Optional[float], + num_blocks: int, + q_dtype: Optional[torch.dtype], +) -> None: + torch.set_default_device("cuda") + + if q_dtype is not None and q_dtype.itemsize < 2 and block_size < 32: + pytest.skip("block size must be at least 32 for fp8") + + current_platform.seed_everything(0) + num_seqs = len(seq_lens) + query_lens = [x[0] for x in seq_lens] + kv_lens = [x[1] for x in seq_lens] + num_query_heads = num_heads[0] + num_kv_heads = num_heads[1] + assert num_query_heads % num_kv_heads == 0 + max_query_len = max(query_lens) + max_kv_len = max(kv_lens) + window_size = ((sliding_window - 1, 0) if sliding_window is not None else + (-1, -1)) + scale = head_size**-0.5 + + query = torch.randn(sum(query_lens), + num_query_heads, + head_size, + dtype=dtype) + key_cache = torch.randn(num_blocks, + block_size, + num_kv_heads, + head_size, + dtype=dtype) + value_cache = torch.randn_like(key_cache) + cu_query_lens = torch.tensor([0] + query_lens, + dtype=torch.int32).cumsum(dim=0, + dtype=torch.int32) + kv_lens = torch.tensor(kv_lens, dtype=torch.int32) + + max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size + block_tables = torch.randint(0, + num_blocks, + (num_seqs, max_num_blocks_per_seq), + dtype=torch.int32) + + output = torch.empty_like(query) + + maybe_quantized_query = query + maybe_quantized_key_cache = key_cache + maybe_quantized_value_cache = value_cache + q_descale = None + k_descale = None + v_descale = None + if q_dtype is not None: + # QKV are drawn from N(0, 1): no need for a fp8 scaling factor + maybe_quantized_query = query.to(q_dtype) + maybe_quantized_key_cache = key_cache.to(q_dtype) + maybe_quantized_value_cache = value_cache.to(q_dtype) + + scale_shape = (num_seqs, num_kv_heads) + q_descale = None # Not yet supported + k_descale = torch.rand(scale_shape, dtype=torch.float32) + v_descale = torch.rand(scale_shape, dtype=torch.float32) + + unified_attention( + q=maybe_quantized_query, + k=maybe_quantized_key_cache, + v=maybe_quantized_value_cache, + out=output, + cu_seqlens_q=cu_query_lens, + seqused_k=kv_lens, + max_seqlen_q=max_query_len, + max_seqlen_k=max_kv_len, + softmax_scale=scale, + causal=True, + window_size=window_size, + block_table=block_tables, + softcap=soft_cap if soft_cap is not None else 0, + q_descale=q_descale, + k_descale=k_descale, + v_descale=v_descale, + ) + + ref_output = ref_paged_attn( + query=query, + key_cache=key_cache, + value_cache=value_cache, + query_lens=query_lens, + kv_lens=kv_lens, + block_tables=block_tables, + scale=scale, + sliding_window=sliding_window, + soft_cap=soft_cap, + ) + atol, rtol = 1.5e-2, 1e-2 + if q_dtype is not None: + atol, rtol = 1.5e-1, 1.5e-1 + torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol), \ + f"{torch.max(torch.abs(output - ref_output))}" diff --git a/tests/kernels/core/test_pos_encoding.py b/tests/kernels/core/test_pos_encoding.py index 2b7bf755ec22d98e6970a9664201b50179bc4a6f..f327deb0e549ec74d6d8f0eda45b57e18b152529 100644 --- a/tests/kernels/core/test_pos_encoding.py +++ b/tests/kernels/core/test_pos_encoding.py @@ -21,6 +21,7 @@ SEEDS = [0] CUDA_DEVICES = [ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) ] +USE_KEY = [True, False] def _get_flat_tensor_shape(batch_size: int, seq_len: int, num_heads: int, @@ -28,12 +29,20 @@ def _get_flat_tensor_shape(batch_size: int, seq_len: int, num_heads: int, return (batch_size, seq_len, num_heads * head_size) +# For testing sliced tensors +def _get_padded_tensor_shape(batch_size: int, seq_len: int, num_heads: int, + head_size: int) -> tuple[int, ...]: + return (batch_size, seq_len, num_heads, head_size + 64) + + def _get_batch_tensor_shape(batch_size: int, seq_len: int, num_heads: int, head_size: int) -> tuple[int, ...]: return (batch_size, seq_len, num_heads, head_size) -TENSORS_SHAPES_FN = [_get_batch_tensor_shape, _get_flat_tensor_shape] +TENSORS_SHAPES_FN = [ + _get_batch_tensor_shape, _get_flat_tensor_shape, _get_padded_tensor_shape +] @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE) @@ -46,6 +55,7 @@ TENSORS_SHAPES_FN = [_get_batch_tensor_shape, _get_flat_tensor_shape] @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("use_key", USE_KEY) @torch.inference_mode() def test_rotary_embedding( is_neox_style: bool, @@ -58,6 +68,7 @@ def test_rotary_embedding( dtype: torch.dtype, seed: int, device: str, + use_key: bool, max_position: int = 8192, base: int = 10000, ) -> None: @@ -74,7 +85,11 @@ def test_rotary_embedding( positions = torch.randint(0, max_position, (batch_size, seq_len)) query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size) query = torch.randn(query_shape, dtype=dtype) - key = torch.randn_like(query) + key = torch.randn_like(query) if use_key else None + + # slice tensor if required, noop otherwise + query = query[..., :head_size] + key = key[..., :head_size] if use_key else None # NOTE(woosuk): The reference implementation should be executed first # because the custom kernel is in-place. @@ -85,10 +100,14 @@ def test_rotary_embedding( ref_query, atol=get_default_atol(out_query), rtol=get_default_rtol(out_query)) - torch.testing.assert_close(out_key, - ref_key, - atol=get_default_atol(out_key), - rtol=get_default_rtol(out_key)) + if use_key: + torch.testing.assert_close(out_key, + ref_key, + atol=get_default_atol(out_key), + rtol=get_default_rtol(out_key)) + else: + assert ref_key is None and out_key is None, \ + "expected returned key to be None" @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE) @@ -101,6 +120,7 @@ def test_rotary_embedding( @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("use_key", USE_KEY) @torch.inference_mode() def test_batched_rotary_embedding( is_neox_style: bool, @@ -113,6 +133,7 @@ def test_batched_rotary_embedding( dtype: torch.dtype, seed: int, device: str, + use_key: bool, max_position: int = 8192, base: int = 10000, ) -> None: @@ -129,7 +150,11 @@ def test_batched_rotary_embedding( positions = torch.randint(0, max_position, (batch_size, seq_len)) query_shape = tensor_shape_fn(batch_size, seq_len, num_heads, head_size) query = torch.randn(query_shape, dtype=dtype) - key = torch.randn_like(query) + key = torch.randn_like(query) if use_key else None + + # slice tensor if required, noop otherwise + query = query[..., :head_size] + key = key[..., :head_size] if use_key else None # NOTE(woosuk): The reference implementation should be executed first # because the custom kernel is in-place. @@ -145,10 +170,14 @@ def test_batched_rotary_embedding( ref_query, atol=get_default_atol(out_query), rtol=get_default_rtol(out_query)) - torch.testing.assert_close(out_key, - ref_key, - atol=get_default_atol(out_key), - rtol=get_default_rtol(out_key)) + if use_key: + torch.testing.assert_close(out_key, + ref_key, + atol=get_default_atol(out_key), + rtol=get_default_rtol(out_key)) + else: + assert ref_key is None and out_key is None, \ + "expected returned key to be None" @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE) @@ -160,6 +189,7 @@ def test_batched_rotary_embedding( @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("use_key", USE_KEY) @torch.inference_mode() def test_batched_rotary_embedding_multi_lora( is_neox_style: bool, @@ -171,6 +201,7 @@ def test_batched_rotary_embedding_multi_lora( dtype: torch.dtype, seed: int, device: str, + use_key: bool, max_position: int = 8192, base: int = 10000, ) -> None: @@ -190,7 +221,7 @@ def test_batched_rotary_embedding_multi_lora( seq_len, num_heads * head_size, dtype=dtype) - key = torch.randn_like(query) + key = torch.randn_like(query) if use_key else None offset_map = torch.tensor( list( @@ -214,10 +245,14 @@ def test_batched_rotary_embedding_multi_lora( ref_query, atol=get_default_atol(out_query), rtol=get_default_rtol(out_query)) - torch.testing.assert_close(out_key, - ref_key, - atol=get_default_atol(out_key), - rtol=get_default_rtol(out_key)) + if use_key: + torch.testing.assert_close(out_key, + ref_key, + atol=get_default_atol(out_key), + rtol=get_default_rtol(out_key)) + else: + assert ref_key is None and out_key is None, \ + "expected returned key to be None" @torch.inference_mode() diff --git a/tests/kernels/core/test_rotary_embedding.py b/tests/kernels/core/test_rotary_embedding.py index c497dd90edda863ea16d59aab26718e10856763f..8383f943b9fa4705212982ed3c16b5d3916ac470 100644 --- a/tests/kernels/core/test_rotary_embedding.py +++ b/tests/kernels/core/test_rotary_embedding.py @@ -15,7 +15,7 @@ from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding def rotary_embedding_opcheck(rot, positions: torch.Tensor, query: torch.Tensor, - key: torch.Tensor, + key: Optional[torch.Tensor] = None, offsets: Optional[torch.Tensor] = None): cos_sin_cache = rot.cos_sin_cache.to(query.device, dtype=query.dtype) @@ -37,9 +37,11 @@ def rotary_embedding_opcheck(rot, @pytest.mark.parametrize("rotary_dim", [32]) @pytest.mark.parametrize("head_size", [32, 108]) @pytest.mark.parametrize("seq_len", [11, 1024]) +@pytest.mark.parametrize("use_key", [True, False]) +@pytest.mark.parametrize("head_stride_is_contingous", [True, False]) def test_rotary_embedding_opcheck(dist_init, device, max_position, is_neox_style, rotary_dim, head_size, - seq_len): + seq_len, use_key, head_stride_is_contingous): batch_size = 1 base = 10000 num_heads = 7 @@ -49,15 +51,27 @@ def test_rotary_embedding_opcheck(dist_init, device, max_position, positions = torch.randint(0, max_position, (batch_size, seq_len), device=device) + head_stride = head_size + (64 if head_stride_is_contingous else 0) + query = torch.randn(batch_size, seq_len, - num_heads * head_size, + num_heads, + head_stride, dtype=torch.float32, device=device) - key = torch.randn_like(query) + key = torch.randn_like(query) if use_key else None + query = query[..., :head_size] + key = key[..., :head_size] if use_key else None rotary_embedding_opcheck(rot, positions, query, key) offsets = torch.zeros(batch_size * seq_len, device=device, dtype=torch.long) rotary_embedding_opcheck(rot, positions, query, key, offsets) + + # if we have a contiguous head stride, test the alternate + # [..., num_heads * head_dim] shape/layout + if head_stride_is_contingous: + rotary_embedding_opcheck( + rot, positions, query.flatten(start_dim=-2), + key.flatten(start_dim=-2) if use_key else None) diff --git a/tests/kernels/mamba/test_mamba_ssm_ssd.py b/tests/kernels/mamba/test_mamba_ssm_ssd.py index ee908105f557fa961a18864e3e7b1e6989758c0d..f5e751bea4149fdc7d4685808ace79b16ad21b42 100644 --- a/tests/kernels/mamba/test_mamba_ssm_ssd.py +++ b/tests/kernels/mamba/test_mamba_ssm_ssd.py @@ -6,7 +6,7 @@ import torch.nn.functional as F from einops import rearrange, repeat from vllm.model_executor.layers.mamba.mamba2_metadata import ( - _seq_idx_to_chunk_indices_offsets) + _query_start_loc_to_chunk_indices_offsets) from vllm.model_executor.layers.mamba.ops.ssd_combined import ( mamba_chunk_scan_combined) from vllm.platforms import current_platform @@ -274,8 +274,9 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases, last_taken, exhausted, n_heads, d_head, itype): - chunk_indices, chunk_offsets = _seq_idx_to_chunk_indices_offsets( - seq_idx, chunk_size) + chunk_indices, chunk_offsets = \ + _query_start_loc_to_chunk_indices_offsets( + cu_seqlens, chunk_size, cu_seqlens[-1]) Y, new_states = mamba_chunk_scan_combined( X, diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py new file mode 100644 index 0000000000000000000000000000000000000000..7d369edfc86a4393ecc1cbc243ee83439b8db084 --- /dev/null +++ b/tests/kernels/moe/test_batched_moe.py @@ -0,0 +1,114 @@ +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass + +import pytest +import torch +import triton.language as tl + +from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( + invoke_moe_batched_triton_kernel) + + +@dataclass +class BatchedMMConfig: + dtype: torch.dtype + num_experts: int + max_tokens_per_expert: int + K: int + N: int + + +@dataclass +class BatchedMMTensors: + A: torch.Tensor # [E, max_tokens, K] + B: torch.Tensor # [E, K, N] - column major + C: torch.Tensor # [E, max_tokens, N] + num_expert_tokens: torch.Tensor # [E] + + @staticmethod + def make_tensors(config: BatchedMMConfig): + A = torch.randn( + (config.num_experts, config.max_tokens_per_expert, config.K), + device="cuda", + dtype=config.dtype) / 10 + B = torch.randn((config.num_experts, config.N, config.K), + device="cuda", + dtype=config.dtype) + C = torch.zeros( + (config.num_experts, config.max_tokens_per_expert, config.N), + device="cuda", + dtype=config.dtype) + num_expert_tokens = torch.randint(low=0, + high=config.max_tokens_per_expert, + size=(config.num_experts, ), + device="cuda", + dtype=torch.int32) + return BatchedMMTensors(A, B, C, num_expert_tokens) + + +def ref_impl(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, + num_expert_tokens: torch.Tensor) -> torch.Tensor: + + num_expert_tokens_cpu = num_expert_tokens.clone() + num_expert_tokens_cpu = num_expert_tokens_cpu.to(device="cpu") + num_experts = num_expert_tokens.size(0) + + for e in range(num_experts): + num_tokens = num_expert_tokens_cpu[e] + C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1) + + return C + + +@pytest.mark.parametrize("num_experts", [16, 32]) +@pytest.mark.parametrize("max_tokens_per_expert", + [32, 64, 128, 192, 224, 256, 512]) +@pytest.mark.parametrize("K", [128, 256, 1024]) +@pytest.mark.parametrize("N", [128, 256, 512, 1024]) +@pytest.mark.parametrize("dtype", + [torch.float32, torch.float16, torch.bfloat16]) +def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, + N: int, dtype: torch.dtype): + + config = BatchedMMConfig(dtype, num_experts, max_tokens_per_expert, K, N) + tensors = BatchedMMTensors.make_tensors(config) + + test_output = tensors.C + ref_output = test_output.clone() + + compute_tl_dtype = { + torch.float16: tl.float16, + torch.bfloat16: tl.bfloat16, + torch.float32: tl.float32 + }[test_output.dtype] + invoke_moe_batched_triton_kernel( + tensors.A, + tensors.B, + test_output, + tensors.num_expert_tokens, + compute_tl_dtype, + # Quantization data + None, + None, + None, + # Quantization schemes + False, + False, + False, + config={ + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 16 + }) + + ref_output = ref_impl(tensors.A, tensors.B, ref_output, + tensors.num_expert_tokens) + + rtol, atol = { + torch.float16: (6e-2, 6e-2), + torch.bfloat16: (6e-2, 6e-2), + torch.float32: (1e-2, 1e-2), + }[test_output.dtype] + + torch.testing.assert_close(test_output, ref_output, atol=atol, rtol=rtol) diff --git a/tests/kernels/moe/test_cutlass_moe.py b/tests/kernels/moe/test_cutlass_moe.py index 975cd418a171f28d01fc36f7fa2bba16d3707e58..7db4fe0f46e3fab44a9637cd75453eaa2b7482ff 100644 --- a/tests/kernels/moe/test_cutlass_moe.py +++ b/tests/kernels/moe/test_cutlass_moe.py @@ -30,6 +30,11 @@ MNK_FACTORS = [ (224, 3072, 1536), ] +vllm_config = VllmConfig(parallel_config=ParallelConfig( + pipeline_parallel_size=1)) +vllm_config.scheduler_config.max_num_seqs = 128 +vllm_config.scheduler_config.max_model_len = 8192 + @dataclasses.dataclass class MOETensors: @@ -190,7 +195,7 @@ def run_8_bit(moe_tensors: MOETensors8Bit, 'w1_q': moe_tensors.w1_q.transpose(1, 2), # type: ignore[union-attr] 'w2_q': moe_tensors.w2_q.transpose(1, 2), # type: ignore[union-attr] 'topk_weights': topk_weights, - 'topk_ids_': topk_ids, + 'topk_ids': topk_ids, 'ab_strides1': moe_tensors.ab_strides1, 'c_strides1': moe_tensors.c_strides1, 'ab_strides2': moe_tensors.ab_strides2, @@ -231,18 +236,15 @@ def test_cutlass_moe_8_bit_no_graph( per_out_ch: bool, ): current_platform.seed_everything(7) - with set_current_vllm_config( - VllmConfig(parallel_config=ParallelConfig( - pipeline_parallel_size=1))): - + with set_current_vllm_config(vllm_config): mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, per_out_ch) score = torch.randn((m, e), device="cuda", dtype=torch.half) - topk_weights, topk_ids = fused_topk(mt.a, - score, - topk, - renormalize=False) + topk_weights, topk_ids, _ = fused_topk(mt.a, + score, + topk, + renormalize=False) # Note that we are using the dequantized versions of the tensors. # Using a, w1 and w2 directly results in minor output differences. @@ -276,20 +278,17 @@ def test_cutlass_moe_8_bit_cuda_graph( per_out_ch: bool, ): current_platform.seed_everything(7) - with set_current_vllm_config( - VllmConfig(parallel_config=ParallelConfig( - pipeline_parallel_size=1))): - + with set_current_vllm_config(vllm_config): dtype = torch.half mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, per_out_ch) score = torch.randn((m, e), device="cuda", dtype=dtype) - topk_weights, topk_ids = fused_topk(mt.a, - score, - topk, - renormalize=False) + topk_weights, topk_ids, _ = fused_topk(mt.a, + score, + topk, + renormalize=False) # Note that we are using the dequantized versions of the tensors. # Using a, w1 and w2 directly results in minor output differences. @@ -334,18 +333,15 @@ def test_cutlass_moe_8_bit_EP( ep_size: int, ): current_platform.seed_everything(7) - with set_current_vllm_config( - VllmConfig(parallel_config=ParallelConfig( - pipeline_parallel_size=1))): - + with set_current_vllm_config(vllm_config): mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, per_out_channel) score = torch.randn((m, e), device="cuda", dtype=torch.half) - topk_weights, topk_ids = fused_topk(mt.a, - score, - topk, - renormalize=False) + topk_weights, topk_ids, _ = fused_topk(mt.a, + score, + topk, + renormalize=False) # Note that we are using the dequantized versions of the tensors. # Using a, w1 and w2 directly results in minor output differences. diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 425f36984a33be28bea56fe3597aece64039fb69..43ddc79fcb8182937b3cc82a8ade5fc896b0ca64 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -11,24 +11,32 @@ from transformers import MixtralConfig from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock import vllm.model_executor.layers.fused_moe # noqa -from tests.kernels.utils import (opcheck, stack_and_dev, torch_moe, - torch_moe_single) +from tests.kernels.utils import opcheck, stack_and_dev, torch_moe +from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( fused_moe as iterative_moe) +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( + rand_marlin_weight_fp4_like) +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( + marlin_quant_fp8_torch) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( awq_marlin_quantize, marlin_quantize) from vllm.model_executor.layers.quantization.utils.quant_utils import ( quantize_weights) from vllm.model_executor.models.mixtral import MixtralMoE from vllm.platforms import current_platform -from vllm.scalar_type import scalar_types +from vllm.scalar_type import ScalarType, scalar_types NUM_EXPERTS = [8, 64] EP_SIZE = [1, 4] TOP_KS = [2, 6] +vllm_config = VllmConfig() +vllm_config.scheduler_config.max_num_seqs = 128 +vllm_config.scheduler_config.max_model_len = 8192 + @pytest.mark.parametrize("m", [1, 33, 64, 222, 1024 * 128]) @pytest.mark.parametrize("n", [128, 1024, 2048]) @@ -67,31 +75,33 @@ def test_fused_moe( else: e_map = None - torch_output = torch_moe(a, w1, w2, score, topk, e_map) - iterative_output = iterative_moe(a, - w1, - w2, - score, - topk, - global_num_experts=e, - expert_map=e_map, - renormalize=False) - - # Pad the weight if moe padding is enabled - if padding: - w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128] - torch.cuda.empty_cache() - w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128] - torch.cuda.empty_cache() + with set_current_vllm_config(vllm_config): + torch_output = torch_moe(a, w1, w2, score, topk, e_map) + iterative_output = iterative_moe(a, + w1, + w2, + score, + topk, + global_num_experts=e, + expert_map=e_map, + renormalize=False) + + # Pad the weight if moe padding is enabled + if padding: + w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128] + torch.cuda.empty_cache() + w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128] + torch.cuda.empty_cache() + + triton_output = fused_moe(a, + w1, + w2, + score, + topk, + global_num_experts=e, + expert_map=e_map, + renormalize=False) - triton_output = fused_moe(a, - w1, - w2, - score, - topk, - global_num_experts=e, - expert_map=e_map, - renormalize=False) torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) torch.testing.assert_close(iterative_output, torch_output, @@ -112,7 +122,6 @@ def test_fused_moe( def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, ep_size: int, dtype: torch.dtype, group_size: int, has_zp: bool, weight_bits: int): - print(m, n, k, e, topk, dtype, group_size, has_zp, weight_bits) a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 @@ -191,22 +200,24 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, else: e_map = None - triton_output = fused_moe(a, - w1_qweight, - w2_qweight, - score, - topk, - renormalize=False, - use_int4_w4a16=weight_bits == 4, - use_int8_w8a16=weight_bits == 8, - global_num_experts=e, - expert_map=e_map, - w1_scale=w1_scales, - w2_scale=w2_scales, - w1_zp=w1_qzeros if has_zp else None, - w2_zp=w2_qzeros if has_zp else None, - block_shape=[0, group_size]) - torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, e_map) + with set_current_vllm_config(vllm_config): + triton_output = fused_moe(a, + w1_qweight, + w2_qweight, + score, + topk, + renormalize=False, + use_int4_w4a16=weight_bits == 4, + use_int8_w8a16=weight_bits == 8, + global_num_experts=e, + expert_map=e_map, + w1_scale=w1_scales, + w2_scale=w2_scales, + w1_zp=w1_qzeros if has_zp else None, + w2_zp=w2_qzeros if has_zp else None, + block_shape=[0, group_size]) + torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, e_map) + torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) @@ -221,9 +232,16 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool, """Make sure our Mixtral MoE implementation agrees with the one from huggingface.""" + # clear the cache before every test + from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( + is_rocm_aiter_moe_enabled) + is_rocm_aiter_moe_enabled.cache_clear() if use_rocm_aiter: monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") + if dtype == torch.float32: + pytest.skip("AITER ROCm test skip for float32") + # Instantiate our and huggingface's MoE blocks config = MixtralConfig() hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda") @@ -285,18 +303,64 @@ def test_mixtral_moe(dtype: torch.dtype, padding: bool, use_rocm_aiter: bool, atol=mixtral_moe_tol[dtype]) -@pytest.mark.parametrize("m", [1, 33, 123]) -@pytest.mark.parametrize("n", [128, 1024]) -@pytest.mark.parametrize("k", [256, 2048]) -@pytest.mark.parametrize("e", [4, 12]) -@pytest.mark.parametrize("topk", [2, 3]) -@pytest.mark.parametrize("ep_size", [1, 4]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("group_size", [-1, 32, 128]) -@pytest.mark.parametrize("act_order", [True, False]) -@pytest.mark.parametrize("num_bits", [4, 8]) -@pytest.mark.parametrize("has_zp", [True, False]) -@pytest.mark.parametrize("is_k_full", [True, False]) +def marlin_moe_generate_valid_test_cases(): + import itertools + m_list = [1, 123, 666] + n_list = [128, 1024] + k_list = [256, 2048] + e_list = [4, 12] + topk_list = [2, 3] + ep_size_list = [1, 4] + dtype_list = [torch.half, torch.bfloat16] + group_size_list = [-1, 16, 32, 128] + act_order_list = [True, False] + quant_type_list = [ + scalar_types.float4_e2m1f, + scalar_types.float8_e4m3fn, + scalar_types.uint4, + scalar_types.uint4b8, + scalar_types.uint8b128, + ] + is_k_full_list = [True, False] + + all_combinations = itertools.product(m_list, n_list, k_list, e_list, + topk_list, ep_size_list, dtype_list, + group_size_list, act_order_list, + quant_type_list, is_k_full_list) + + def is_invalid(m, n, k, e, topk, ep_size, dtype, group_size, act_order, + quant_type, is_k_full): + + if quant_type == scalar_types.float8_e4m3fn and \ + group_size not in [-1, 128]: + return False + if quant_type == scalar_types.float4_e2m1f and group_size != 16: + return False + if quant_type != scalar_types.float4_e2m1f and group_size == 16: + return False + + # Filter act_order + if act_order: + if group_size in (-1, k, n): + return False + if quant_type not in [scalar_types.uint4b8]: + return False + elif not is_k_full: + return False + + return True + + cases = [] + for case in all_combinations: + if is_invalid(*case): + cases.append(case) + return cases + + +@pytest.mark.flaky(reruns=2) +@pytest.mark.parametrize(("m, n, k, e, topk, ep_size, dtype, group_size," + "act_order, quant_type, is_k_full"), + marlin_moe_generate_valid_test_cases()) @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") def test_fused_marlin_moe( m: int, @@ -308,14 +372,22 @@ def test_fused_marlin_moe( dtype: torch.dtype, group_size: int, act_order: bool, - num_bits: int, - has_zp: bool, + quant_type: ScalarType, is_k_full: bool, ): - current_platform.seed_everything(7) + torch.cuda.manual_seed(0) + has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8] + + if quant_type == scalar_types.float8_e4m3fn: + if group_size not in [-1, 128]: + return + if act_order: + return # Filter act_order if act_order: + if quant_type == scalar_types.float8_e4m3fn: + return if group_size == -1: return if group_size in (k, n): @@ -326,17 +398,14 @@ def test_fused_marlin_moe( if not is_k_full: return - if has_zp: - # we don't build kernel for int8 with zero - if num_bits == 8: - return - quant_type = scalar_types.uint4 if num_bits == 4 else scalar_types.uint8 - else: - quant_type = scalar_types.uint4b8 \ - if num_bits == 4 else scalar_types.uint8b128 + if quant_type == scalar_types.float4_e2m1f and group_size != 16: + return + if quant_type != scalar_types.float4_e2m1f and group_size == 16: + return + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 20 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 20 if ep_size > 1: local_e = e // ep_size @@ -351,12 +420,27 @@ def test_fused_marlin_moe( w_ref1_l = [] qweight1_l = [] scales1_l = [] + global_scale1_l = [] zeros1_l = [] g_idx1_l = [] sort_indices1_l = [] for i in range(w1.shape[0]): - if has_zp: + if quant_type == scalar_types.float4_e2m1f: + w_ref1, qweight1, scales1, global_scale1 = \ + rand_marlin_weight_fp4_like(w1[i], group_size) + + w_ref1_l.append(w_ref1.T) + qweight1_l.append(qweight1) + scales1_l.append(scales1) + global_scale1_l.append(global_scale1) + elif quant_type == scalar_types.float8_e4m3fn: + w_ref1, qweight1, scales1 = marlin_quant_fp8_torch( + w1[i], group_size) + w_ref1_l.append(w_ref1.T) + qweight1_l.append(qweight1) + scales1_l.append(scales1) + elif has_zp: w_ref1, qweight1, scales1, zeros1 = awq_marlin_quantize( w1[i].transpose(1, 0), quant_type, group_size) @@ -366,9 +450,9 @@ def test_fused_marlin_moe( zeros1_l.append(zeros1) else: test_perm = torch.randperm(k) - quant_res = marlin_quantize(w1[i].transpose(1, 0), quant_type, - group_size, act_order, test_perm) - w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = quant_res + w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = \ + marlin_quantize(w1[i].transpose(1, 0), quant_type, + group_size, act_order, test_perm) w_ref1_l.append(w_ref1.T) qweight1_l.append(qweight1) @@ -379,6 +463,7 @@ def test_fused_marlin_moe( w_ref1 = stack_and_dev(w_ref1_l) qweight1 = stack_and_dev(qweight1_l).contiguous() scales1 = stack_and_dev(scales1_l) + global_scale1 = stack_and_dev(global_scale1_l) if global_scale1_l else None g_idx1 = stack_and_dev(g_idx1_l) if g_idx1_l else None zeros1 = stack_and_dev(zeros1_l) if zeros1_l else None sort_indices1 = stack_and_dev(sort_indices1_l) if sort_indices1_l else None @@ -386,12 +471,27 @@ def test_fused_marlin_moe( w_ref2_l = [] qweight2_l = [] scales2_l = [] + global_scale2_l = [] zeros2_l = [] g_idx2_l = [] sort_indices2_l = [] for i in range(w2.shape[0]): - if has_zp: + if quant_type == scalar_types.float4_e2m1f: + w_ref2, qweight2, scales2, global_scale2 = \ + rand_marlin_weight_fp4_like(w2[i], group_size) + + w_ref2_l.append(w_ref2.T) + qweight2_l.append(qweight2) + scales2_l.append(scales2) + global_scale2_l.append(global_scale2) + elif quant_type == scalar_types.float8_e4m3fn: + w_ref2, qweight2, scales2 = marlin_quant_fp8_torch( + w2[i], group_size) + w_ref2_l.append(w_ref2.T) + qweight2_l.append(qweight2) + scales2_l.append(scales2) + elif has_zp: w_ref2, qweight2, scales2, zeros2 = awq_marlin_quantize( w2[i].transpose(1, 0), quant_type, group_size) @@ -401,9 +501,9 @@ def test_fused_marlin_moe( zeros2_l.append(zeros2) else: test_perm = torch.randperm(n) - quant_res = marlin_quantize(w2[i].transpose(1, 0), quant_type, - group_size, act_order, test_perm) - w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = quant_res + w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = \ + marlin_quantize(w2[i].transpose(1, 0), quant_type, + group_size, act_order, test_perm) w_ref2_l.append(w_ref2.T) qweight2_l.append(qweight2) @@ -414,15 +514,17 @@ def test_fused_marlin_moe( w_ref2 = stack_and_dev(w_ref2_l) qweight2 = stack_and_dev(qweight2_l).contiguous() scales2 = stack_and_dev(scales2_l) + global_scale2 = stack_and_dev(global_scale2_l) if global_scale2_l else None g_idx2 = stack_and_dev(g_idx2_l) if g_idx2_l else None zeros2 = stack_and_dev(zeros2_l) if zeros2_l else None sort_indices2 = stack_and_dev(sort_indices2_l) if sort_indices2_l else None score = torch.randn((m, e), device="cuda", dtype=dtype) - topk_weights, topk_ids = fused_topk(a, score, topk, False) + topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) - torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map) + with set_current_vllm_config(vllm_config): + torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map) marlin_output = torch.ops.vllm.fused_marlin_moe( a, @@ -435,108 +537,18 @@ def test_fused_marlin_moe( topk_ids, global_num_experts=e, expert_map=e_map, + global_scale1=global_scale1, + global_scale2=global_scale2, g_idx1=g_idx1, g_idx2=g_idx2, sort_indices1=sort_indices1, sort_indices2=sort_indices2, w1_zeros=zeros1, w2_zeros=zeros2, - num_bits=num_bits, + quant_type_id=quant_type.id, is_k_full=is_k_full) - torch.testing.assert_close(marlin_output, torch_output, atol=2e-2, rtol=0) - - -@pytest.mark.skip("This test is here for the sake of debugging, " - "don't run it in automated tests.") -@pytest.mark.parametrize("m", [1, 33, 123]) -@pytest.mark.parametrize("n", [128, 1024]) -@pytest.mark.parametrize("k", [256, 2048]) -@pytest.mark.parametrize("e", [4, 12]) -@pytest.mark.parametrize("topk", [2, 3]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("group_size", [-1, 32, 128]) -@pytest.mark.parametrize("act_order", [True, False]) -@pytest.mark.parametrize("num_bits", [4, 8]) -@pytest.mark.parametrize("has_zp", [True, False]) -@pytest.mark.parametrize("is_k_full", [True, False]) -def test_single_marlin_moe_multiply(m: int, n: int, k: int, e: int, topk: int, - dtype: torch.dtype, group_size: int, - act_order: bool, num_bits: int, - has_zp: bool, is_k_full: bool): - # Filter act_order - if act_order: - if group_size == -1: - return - if group_size in (k, n): - return - if has_zp: - return - else: - if not is_k_full: - return - - if has_zp: - quant_type = scalar_types.uint4 if num_bits == 4 else scalar_types.uint8 - else: - quant_type = scalar_types.uint4b8 \ - if num_bits == 4 else scalar_types.uint8b128 - a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 - w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10 - - w_ref_l = [] - qweight_l = [] - scales_l = [] - zeros_l = [] - g_idx_l = [] - sort_indices_l = [] - - for i in range(w.shape[0]): - if has_zp: - w_ref, qweight, scales, zeros = awq_marlin_quantize( - w[i].transpose(1, 0), quant_type, group_size) - - w_ref_l.append(w_ref.T) - qweight_l.append(qweight) - scales_l.append(scales) - zeros_l.append(zeros) - else: - test_perm = torch.randperm(k) - w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize( - w[i].transpose(1, 0), quant_type, group_size, act_order, - test_perm) - - w_ref_l.append(w_ref.T) - qweight_l.append(qweight) - scales_l.append(scales) - g_idx_l.append(g_idx) - sort_indices_l.append(sort_indices) - - w_ref = stack_and_dev(w_ref_l) - qweight = stack_and_dev(qweight_l).contiguous() - scales = stack_and_dev(scales_l) - g_idx = stack_and_dev(g_idx_l) if g_idx_l else None - zeros = stack_and_dev(zeros_l) if zeros_l else None - sort_indices = stack_and_dev(sort_indices_l) if sort_indices_l else None - - score = torch.randn((m, e), device="cuda", dtype=dtype) - marlin_output = torch.ops.vllm.single_marlin_moe( - a, - qweight, - scales, - score, - topk, - renormalize=False, - g_idx=g_idx, - sort_indices=sort_indices, - w_zeros=zeros, - num_bits=num_bits, - is_k_full=is_k_full, - ) - - torch_output = torch_moe_single(a, w_ref, score, topk) - - torch.testing.assert_close(marlin_output, torch_output, atol=2e-2, rtol=0) + torch.testing.assert_close(marlin_output, torch_output, atol=5e-2, rtol=0) def test_moe_align_block_size_opcheck(): diff --git a/tests/kernels/moe/test_moe_permute_unpermute.py b/tests/kernels/moe/test_moe_permute_unpermute.py new file mode 100644 index 0000000000000000000000000000000000000000..dfcd61f7758702b24399dc0e3855bde9c1fe35ce --- /dev/null +++ b/tests/kernels/moe/test_moe_permute_unpermute.py @@ -0,0 +1,223 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the MOE permute/unpermute kernel + +Run `pytest tests/kernels/test_moe_permute_unpermute.py`. +""" + +from typing import Optional + +import numpy as np +import pytest +import torch + +from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk +from vllm.model_executor.layers.fused_moe.layer import determine_expert_map +from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( + moe_permute, moe_unpermute) +from vllm.platforms import current_platform + +NUM_EXPERTS = [16, 64] +TOP_KS = [2, 4, 6, 8] +EP_SIZE = [1, 4, 16] +current_platform.seed_everything(0) + + +def torch_permute(hidden_states: torch.Tensor, + topk_ids: torch.Tensor, + token_expert_indices: torch.Tensor, + topk: int, + n_expert: int, + n_local_expert: int, + start_expert: int, + expert_map: Optional[torch.Tensor] = None, + align_block_size: Optional[int] = None, + fill_invalid_expert: int = -1) -> list[torch.Tensor]: + n_token, n_hidden = hidden_states.shape[0], hidden_states.shape[1] + if expert_map is not None: + is_local_expert = (expert_map[topk_ids] != -1) + not_local_expert = (expert_map[topk_ids] == -1) + topk_ids = is_local_expert * ( + topk_ids - start_expert) + not_local_expert * (topk_ids + n_expert) + + sorted_topk_ids, sorted_indices = torch.sort(topk_ids.flatten(), + stable=True) + dst_row_id2src_row_id_map = token_expert_indices.flatten()[sorted_indices] + + expert_first_token_offset = torch.zeros(n_local_expert + 1, + dtype=torch.int64, + device="cuda") + idx = 0 + for i in range(0, n_local_expert): + cnt = 0 + while idx < sorted_topk_ids.numel() and sorted_topk_ids[idx] == i: + cnt += 1 + idx += 1 + expert_first_token_offset[i + 1] = expert_first_token_offset[i] + cnt + + _, src2dst_idx = torch.sort(dst_row_id2src_row_id_map) + valid_row_idx = [] + if align_block_size is None: + + permuted_hidden_states = hidden_states[dst_row_id2src_row_id_map % + n_token, ...] + permuted_row_size = permuted_hidden_states.shape[0] + m_indices = torch.empty(permuted_row_size, + device="cuda", + dtype=torch.int32).fill_(fill_invalid_expert) + for i in range(1, n_local_expert + 1): + first_token_offset = expert_first_token_offset[i - 1] + last_token_offset = expert_first_token_offset[i] + m_indices[first_token_offset:last_token_offset] = i - 1 + src_row_id2dst_row_id_map = torch.arange( + 0, n_token * topk, device="cuda", + dtype=torch.int32)[src2dst_idx].reshape((n_token, topk)) + valid_row_idx += [i for i in range(expert_first_token_offset[-1])] + return [ + permuted_hidden_states, expert_first_token_offset, + src_row_id2dst_row_id_map, m_indices, valid_row_idx + ] + else: + permuted_row_size = (topk * n_token + n_expert * + (align_block_size - 1) + align_block_size - + 1) // align_block_size * align_block_size + permuted_hidden_states = torch.empty((permuted_row_size, n_hidden), + device="cuda", + dtype=hidden_states.dtype) + align_src_row_id2dst_row_id = torch.empty(n_token * topk, + device="cuda", + dtype=torch.int32) + align_expert_first_token_offset = torch.zeros_like( + expert_first_token_offset) + m_indices = torch.empty(permuted_row_size, + device="cuda", + dtype=torch.int32).fill_(fill_invalid_expert) + # get align_permuted_hidden_states, + # valid row_idx and align_expert_first_token_offset + for i in range(1, n_local_expert + 1): + first_token_offset = expert_first_token_offset[i - 1] + last_token_offset = expert_first_token_offset[i] + n_token_in_expert = last_token_offset - first_token_offset + align_expert_first_token_offset[ + i] = align_expert_first_token_offset[ + i - 1] + (n_token_in_expert + align_block_size - + 1) // align_block_size * align_block_size + align_first_token_offset = align_expert_first_token_offset[i - 1] + align_last_token_offset = align_expert_first_token_offset[i] + dst_row_id2src_row_id_in_expert = dst_row_id2src_row_id_map[ + first_token_offset:first_token_offset + + n_token_in_expert] % n_token + # store token in current expert with align_first_token_offset + permuted_hidden_states[align_first_token_offset:\ + align_first_token_offset+n_token_in_expert,\ + ...] = hidden_states[\ + dst_row_id2src_row_id_in_expert, ...] + # set current expert m_indices + m_indices[align_first_token_offset:align_last_token_offset] = i - 1 + valid_row_idx += [ + i for i in range(align_first_token_offset, + align_first_token_offset + n_token_in_expert) + ] + # get align_src_row_id2dst_row_id + for i in range(n_token * topk): + eid = sorted_topk_ids[i] + if (eid >= n_local_expert): + # check token not in local expert + align_src_row_id2dst_row_id[ + i] = align_expert_first_token_offset[-1] + continue + first_token_offset = expert_first_token_offset[eid] + align_first_token_offset = align_expert_first_token_offset[eid] + token_offset = i - first_token_offset + align_src_row_id2dst_row_id[ + i] = align_first_token_offset + token_offset + align_src_row_id2dst_row_id = align_src_row_id2dst_row_id[\ + src2dst_idx].reshape((n_token, topk)) + return [ + permuted_hidden_states, align_expert_first_token_offset, + align_src_row_id2dst_row_id, m_indices, valid_row_idx + ] + + +def torch_unpermute(permuted_hidden_states: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + token_expert_indices: torch.Tensor, + src_row_id2dst_row_id_map: torch.Tensor, + valid_row_idx: torch.Tensor, topk: int, + n_expert: int) -> torch.Tensor: + # ignore invalid row + mask = torch.zeros(permuted_hidden_states.shape[0], + dtype=bool, + device="cuda") + mask[valid_row_idx] = True + permuted_hidden_states[~mask] = 0 + idx = src_row_id2dst_row_id_map.flatten()[ + token_expert_indices.flatten()].reshape(token_expert_indices.shape) + output = permuted_hidden_states[idx, ...] * topk_weights[..., None] + output = output.sum(dim=1).to(permuted_hidden_states.dtype) + return output + + +@pytest.mark.parametrize("n_token", [1, 33, 64, 222, 1024, 2048, 3000, 5000]) +@pytest.mark.parametrize("n_hidden", [2048, 4096, 7168]) +@pytest.mark.parametrize("n_expert", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("ep_size", EP_SIZE) +@pytest.mark.parametrize("align_block_size", [None, 128]) +def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int, + n_expert: int, ep_size: int, dtype: torch.dtype, + align_block_size: Optional[int]): + fill_invalid_expert = 0 + ep_rank = np.random.randint(0, ep_size) + expert_map = None + n_local_expert = n_expert + if (ep_size != 1): + n_local_expert, expert_map = determine_expert_map( + ep_size, ep_rank, n_expert) + expert_map = expert_map.cuda() + start_expert = n_local_expert * ep_rank + current_platform.seed_everything(0) + hidden_states = torch.randn((n_token, n_hidden), device="cuda").to(dtype) + gating_output = torch.randn((n_token, n_expert), device="cuda").to(dtype) + topk_weights, topk_ids, token_expert_indices = fused_topk( + hidden_states, gating_output, topk, False) + gold0, gold1, gold2, gold3, valid_row_idx = torch_permute( + hidden_states, + topk_ids, + token_expert_indices, + topk, + n_expert, + n_local_expert, + start_expert, + expert_map=expert_map, + align_block_size=align_block_size, + fill_invalid_expert=fill_invalid_expert) + + result0, result1, result2, result3 = moe_permute( + hidden_states, topk_weights, topk_ids, token_expert_indices, topk, + n_expert, n_local_expert, expert_map, align_block_size, + fill_invalid_expert) + + # check expert_first_token_offset + torch.testing.assert_close(gold1, result1, atol=0, rtol=0) + # check src_row_id2dst_row_id_map + torch.testing.assert_close(gold2, result2, atol=0, rtol=0) + # check mindice + torch.testing.assert_close(gold3, result3, atol=0, rtol=0) + # check permuted_hidden_states, only valid token + torch.testing.assert_close(gold0[valid_row_idx], + result0[valid_row_idx], + atol=0, + rtol=0) + + # add a random tensor to simulate group gemm + result0 = 0.5 * result0 + torch.randn_like(result0) + + result4 = moe_unpermute(result0, topk_weights, topk_ids, result2, result1, + topk, n_expert, n_local_expert) + gold4 = torch_unpermute(result0, topk_weights, topk_ids, + token_expert_indices, result2, valid_row_idx, topk, + n_local_expert) + + # check unpermuted hidden + torch.testing.assert_close(result4, gold4, atol=2e-2, rtol=0) diff --git a/tests/kernels/moe/test_nvfp4_moe.py b/tests/kernels/moe/test_nvfp4_moe.py new file mode 100644 index 0000000000000000000000000000000000000000..ae63b379f39d10fc72f49f2222d3354f101ba9e8 --- /dev/null +++ b/tests/kernels/moe/test_nvfp4_moe.py @@ -0,0 +1,144 @@ +# SPDX-License-Identifier: Apache-2.0 +import pytest +import torch + +from tests.kernels.quantization.nvfp4_utils import (FLOAT4_E2M1_MAX, + FLOAT8_E4M3_MAX, + dequantize_nvfp4_to_dtype) +from tests.kernels.utils import torch_moe +from vllm import _custom_ops as ops +from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4 +from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk +from vllm.platforms import current_platform + +if not current_platform.has_device_capability(100): + pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.", + allow_module_level=True) + +MNK_FACTORS = [ + (2, 1024, 1024), + (2, 1024, 1536), + (2, 3072, 1024), + (2, 3072, 1536), + (64, 1024, 1024), + (64, 1024, 1536), + (64, 3072, 1024), + (64, 2048, 1536), + (224, 1024, 1024), + (224, 1024, 1536), +] + + +@pytest.mark.parametrize("m,n,k", MNK_FACTORS) +@pytest.mark.parametrize("e", [40, 64, 256]) +@pytest.mark.parametrize("topk", [1, 6, 8]) +@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) +@torch.inference_mode() +def test_cutlass_fp4_moe_no_graph(m: int, n: int, k: int, e: int, topk: int, + dtype: torch.dtype): + current_platform.seed_everything(7) + with set_current_vllm_config( + VllmConfig(parallel_config=ParallelConfig( + pipeline_parallel_size=1))): + + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + quant_blocksize = 16 + round_up = lambda x, y: (x + y - 1) // y * y + sf_w1_2n = round_up(2 * n, 128) + sf_w1_k = round_up(k // quant_blocksize, 4) + w1_blockscale = torch.empty((e, sf_w1_2n, sf_w1_k), + device="cuda", + dtype=torch.float8_e4m3fn) + + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + sf_w2_k = round_up(k, 128) + sf_w2_n = round_up(n // quant_blocksize, 4) + w2_blockscale = torch.empty((e, sf_w2_k, sf_w2_n), + device="cuda", + dtype=torch.float8_e4m3fn) + + w1_q = torch.empty((e, 2 * n, k // 2), + device="cuda", + dtype=torch.uint8) + w2_q = torch.empty((e, k, n // 2), device="cuda", dtype=torch.uint8) + w1_gs = torch.empty((e, ), device="cuda", dtype=torch.float32) + w2_gs = torch.empty((e, ), device="cuda", dtype=torch.float32) + + for expert in range(e): + w1_amax = torch.abs(w1).max().to(torch.float32) + w2_amax = torch.abs(w2).max().to(torch.float32) + w1_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax + w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax + + w1_q[expert], w1_blockscale[expert] = ops.scaled_fp4_quant( + w1[expert], w1_gs[expert]) + + w2_q[expert], w2_blockscale[expert] = ops.scaled_fp4_quant( + w2[expert], w2_gs[expert]) + + score = torch.randn((m, e), device="cuda", dtype=dtype) + topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False) + + a1_gs = torch.ones((e, ), device="cuda", dtype=torch.float32) + a2_gs = torch.ones((e, ), device="cuda", dtype=torch.float32) + + cutlass_output = cutlass_moe_fp4( + a=a, + a1_gscale=a1_gs, + w1_fp4=w1_q, + w1_blockscale=w1_blockscale, + w1_alphas=(1 / w1_gs), + a2_gscale=a2_gs, + w2_fp4=w2_q, + w2_blockscale=w2_blockscale, + w2_alphas=(1 / w2_gs), + topk_weights=topk_weights, + topk_ids=topk_ids, + m=m, + n=n, + k=k, + e=e, + device=a.device, + ) + + # Reference check: + a_global_scale = ((FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / + torch.amax(a.flatten(), dim=-1)).to(torch.float32) + a_fp4, a_scale_interleaved = ops.scaled_fp4_quant(a, a_global_scale) + _, m_k = a_fp4.shape + a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4, + a_scale_interleaved, + a_global_scale, + dtype=a.dtype, + device=a.device, + block_size=quant_blocksize) + + w1_d = torch.empty((e, 2 * n, k), device="cuda", dtype=dtype) + w2_d = torch.empty((e, k, n), device="cuda", dtype=dtype) + + for idx in range(0, e): + w1_d[idx] = dequantize_nvfp4_to_dtype(w1_q[idx], + w1_blockscale[idx], + w1_gs[idx], + dtype=w1.dtype, + device=w1.device, + block_size=quant_blocksize) + w2_d[idx] = dequantize_nvfp4_to_dtype(w2_q[idx], + w2_blockscale[idx], + w2_gs[idx], + dtype=w2.dtype, + device=w2.device, + block_size=quant_blocksize) + + torch_output = torch_moe(a_in_dtype, w1_d, w2_d, score, topk, None) + + torch.testing.assert_close(torch_output, + cutlass_output, + atol=1e-1, + rtol=1e-1) + + +if __name__ == "__main__": + test_cutlass_fp4_moe_no_graph((2, 1024, 1024), 40, 1, torch.half) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py new file mode 100644 index 0000000000000000000000000000000000000000..8c4a2c3fa440ff47ae8fe0797e93c12c95fd6a37 --- /dev/null +++ b/tests/kernels/moe/test_pplx_moe.py @@ -0,0 +1,691 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the MOE layers. + +Run `pytest tests/kernels/test_pplx_moe.py`. +""" +import dataclasses +import os +import traceback +from typing import Callable, Optional + +import pytest +import torch + +try: + from pplx_kernels import AllToAll + from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id, + nvshmem_finalize, nvshmem_get_unique_id, + nvshmem_init) + has_pplx = True +except ImportError: + has_pplx = False + +from torch.multiprocessing import ( + spawn) # pyright: ignore[reportPrivateImportUsage] +from typing_extensions import Concatenate, ParamSpec + +from vllm.config import VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import override_config +from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( + BatchedExperts, BatchedPrepareAndFinalize, BatchedTritonExperts) +from vllm.model_executor.layers.fused_moe.fused_moe import (fused_topk, + get_default_config) +from vllm.model_executor.layers.fused_moe.modular_kernel import ( + FusedMoEModularKernel) +from vllm.platforms import current_platform + +PPLX_PREPARE_COMBOS = [(4, 128, 128), (32, 1024, 512), (64, 1024, 512), + (222, 2048, 1024)] + +PPLX_MOE_COMBOS = [ + (1, 128, 128), + (2, 128, 512), + (3, 1024, 2048), + (32, 128, 1024), + (45, 512, 2048), + (64, 1024, 1024), + (222, 1024, 2048), +] + +NUM_EXPERTS = [8, 64] +EP_SIZE = [1, 4] +TOP_KS = [1, 2, 6] + +vllm_config = VllmConfig() +vllm_config.scheduler_config.max_num_seqs = 128 +vllm_config.scheduler_config.max_model_len = 8192 + +P = ParamSpec("P") + +requires_pplx = pytest.mark.skipif( + not has_pplx, + reason="Requires PPLX kernels", +) + + +@dataclasses.dataclass +class ProcessGroupInfo: + world_size: int + world_local_size: int + rank: int + node_rank: int + local_rank: int + device: torch.device + + +def _worker_parallel_launch( + local_rank: int, + world_size: int, + world_local_size: int, + node_rank: int, + init_method: str, + worker: Callable[Concatenate[ProcessGroupInfo, P], None], + *args: P.args, + **kwargs: P.kwargs, +) -> None: + rank = node_rank * world_local_size + local_rank + torch.cuda.set_device(local_rank) + device = torch.device("cuda", local_rank) + torch.distributed.init_process_group( + backend="cpu:gloo,cuda:nccl", + init_method=init_method, + rank=rank, + world_size=world_size, + device_id=device, + ) + barrier = torch.tensor([rank], device=device) + torch.distributed.all_reduce(barrier) + + try: + worker( + ProcessGroupInfo( + world_size=world_size, + world_local_size=world_local_size, + rank=rank, + node_rank=node_rank, + local_rank=local_rank, + device=device, + ), + *args, + **kwargs, + ) + except Exception as ex: + print(ex) + traceback.print_exc() + raise + finally: + torch.distributed.destroy_process_group() + + +def parallel_launch( + world_size: int, + worker: Callable[Concatenate[ProcessGroupInfo, P], None], + *args: P.args, + **kwargs: P.kwargs, +) -> None: + assert not kwargs + spawn( + _worker_parallel_launch, + args=( + world_size, + world_size, + 0, + "tcp://localhost:29500", + worker, + ) + args, + nprocs=world_size, + join=True, + ) + + +def parallel_launch_from_env( + worker: Callable[Concatenate[ProcessGroupInfo, P], None], + *args: P.args, + **kwargs: P.kwargs, +) -> None: + """ + Launches a worker function in parallel across all processes in the current + environment. The environment must have the following variables set: + - WORLD_SIZE: The total number of processes. + - WORLD_LOCAL_SIZE: The number of processes on the current node. + - NODE_RANK: The rank of the current + - MASTER_ADDR: The address of the master process. + - MASTER_PORT: The port of the master process. + """ + assert not kwargs + world_size = int(os.environ["WORLD_SIZE"]) + world_local_size = int(os.environ["WORLD_LOCAL_SIZE"]) + node_rank = int(os.environ["NODE_RANK"]) + assert "MASTER_ADDR" in os.environ + assert "MASTER_PORT" in os.environ + spawn( + _worker_parallel_launch, + args=( + world_size, + world_local_size, + node_rank, + "env://", + worker, + ) + args, + nprocs=world_local_size, + join=True, + ) + + +def torch_prepare( + a: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + max_num_tokens: Optional[int] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + assert topk_ids.dim() == 2 + assert topk_ids.shape[0] == a.shape[0] + + num_tokens, hidden_dim = a.shape + topk = topk_ids.shape[1] + + tokens_per_expert = torch.bincount(topk_ids.view(-1), + minlength=num_experts) + + assert tokens_per_expert.numel() == num_experts + + if max_num_tokens is None: + max_num_tokens = int(tokens_per_expert.max().item()) + + b_a = torch.zeros((num_experts, max_num_tokens, hidden_dim), + dtype=a.dtype, + device=a.device) + + token_counts = torch.zeros(num_experts, dtype=torch.int, device=a.device) + + for token in range(num_tokens): + for j in range(topk): + expert_id = topk_ids[token, j] + idx = token_counts[expert_id] + b_a[expert_id, idx:idx + 1, :] = a[token, :] + token_counts[expert_id] = token_counts[expert_id] + 1 + + return b_a, tokens_per_expert + + +def torch_finalize(b_out: torch.Tensor, topk_weight: torch.Tensor, + topk_ids: torch.Tensor) -> torch.Tensor: + num_tokens = topk_ids.shape[0] + num_experts = b_out.shape[0] + K = b_out.shape[-1] + out = torch.zeros((num_tokens, K), dtype=b_out.dtype, device=b_out.device) + expert_counts = torch.zeros(num_experts, + dtype=torch.int, + device=b_out.device) + for token in range(num_tokens): + expert_ids = topk_ids[token] + for i in range(expert_ids.numel()): + expert_id = expert_ids[i] + idx = expert_counts[expert_id] + out[token, :] = out[token, :] + b_out[expert_id, idx:idx + + 1, :] * topk_weight[token, i] + expert_counts[expert_id] = expert_counts[expert_id] + 1 + + return out + + +def torch_batched_moe( + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weight: torch.Tensor, + topk_ids: torch.Tensor, +) -> torch.Tensor: + num_experts = w1.shape[0] + b_a, tokens_per_expert = torch_prepare(a, topk_ids, num_experts) + assert b_a.dim() == 3 + num_tokens, topk = topk_ids.shape + _, max_num_tokens, K = b_a.shape + assert num_experts == b_a.shape[0] and w2.shape[1] == K + out = torch.zeros((num_experts, max_num_tokens, K), + dtype=b_a.dtype, + device=b_a.device) + tmp = torch.empty((max_num_tokens, w1.shape[1] // 2), + dtype=b_a.dtype, + device=b_a.device) + for expert in range(num_experts): + num = tokens_per_expert[expert] + if num > 0: + torch.ops._C.silu_and_mul( + tmp[:num], b_a[expert, :num, :] @ w1[expert].transpose(0, 1)) + out[expert, :num, :] = tmp[:num] @ w2[expert].transpose(0, 1) + + return torch_finalize(out, topk_weight, topk_ids) + + +def batched_moe( + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weight: torch.Tensor, + topk_ids: torch.Tensor, +) -> torch.Tensor: + num_experts = w1.shape[0] + + fused_experts = FusedMoEModularKernel( + BatchedPrepareAndFinalize(a.shape[0], world_size=1, dp_size=1, rank=0), + BatchedExperts(max_num_tokens=a.shape[0], dp_size=1, world_size=1)) + + return fused_experts(a, w1, w2, topk_weight, topk_ids, num_experts) + + +# Note: same as torch_moe but with fused_topk factored out. +def torch_moe2( + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weight: torch.Tensor, + topk_ids: torch.Tensor, +) -> torch.Tensor: + M, K = a.shape + topk = topk_ids.shape[1] + a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) + out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device) + num_experts = w1.shape[0] + for i in range(num_experts): + mask = (topk_ids == i).view(-1) + if mask.sum(): + out[mask] = SiluAndMul()( + a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1) + + return (out.view(M, -1, w2.shape[1]) * + topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) + + +@pytest.mark.parametrize("m", [1, 33, 64, 222]) +@pytest.mark.parametrize("n", [128, 1024, 2048]) +@pytest.mark.parametrize("k", [128, 512, 1024]) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +def test_fused_moe_batched_experts( + m: int, + n: int, + k: int, + e: int, + topk: int, + dtype: torch.dtype, +): + current_platform.seed_everything(7) + + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + score = torch.randn((m, e), device="cuda", dtype=dtype) + + with set_current_vllm_config(vllm_config): + topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) + baseline_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) + torch_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids) + batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids) + + torch.testing.assert_close(baseline_output, + torch_output, + atol=2e-2, + rtol=0) + torch.testing.assert_close(baseline_output, + batched_output, + atol=2e-2, + rtol=0) + + +def rank_chunk(num: int, r: int, w: int) -> int: + rem = num % w + return (num // w) + (1 if r < rem else 0) + + +def chunk_by_rank(t: torch.Tensor, r: int, w: int) -> torch.Tensor: + chunk = rank_chunk(t.shape[0], r, w) + return t[(r * chunk):(r + 1) * chunk] + + +def pplx_prepare_finalize(pgi: ProcessGroupInfo, dp_size: int, a: torch.Tensor, + topk_weight: torch.Tensor, topk_ids: torch.Tensor, + num_experts: int) -> torch.Tensor: + from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import ( + PplxPrepareAndFinalize) + + assert torch.cuda.current_device() == pgi.local_rank + + topk = topk_ids.shape[1] + num_tokens, hidden_dim = a.shape + block_size = 128 + device = pgi.device + rank = pgi.rank + world_size = pgi.world_size + max_num_tokens = rank_chunk(num_tokens, 0, world_size) + + ata = AllToAll.internode( + max_num_tokens=max_num_tokens, + num_experts=num_experts, + experts_per_token=topk, + rank=rank, + world_size=world_size, + dp_size=dp_size, + hidden_dim=hidden_dim, + hidden_dim_bytes=hidden_dim * a.dtype.itemsize, + hidden_dim_scale_bytes=(0 if a.dtype.itemsize != 1 else + ((hidden_dim + block_size - 1) // block_size * + torch.float32.itemsize)), + ) + + topk_ids = topk_ids.to(dtype=torch.uint32) + + prepare_finalize = PplxPrepareAndFinalize( + ata, + max_num_tokens, + world_size, + rank, + dp_size, + a.dtype, + ) + + a_chunk = chunk_by_rank(a, rank, world_size).to(device) + chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device) + chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device) + + b_a, b_a_scale, expert_num_tokens = prepare_finalize.prepare( + a_chunk, + None, + None, + chunk_topk_weight, + chunk_topk_ids, + num_experts, + None, + False, + ) + + b_a = b_a * 1.5 + + out = torch.full( + (max_num_tokens, hidden_dim), + torch.nan, + dtype=a.dtype, + device=device, + ) + + prepare_finalize.finalize( + out, + b_a, + chunk_topk_weight, + chunk_topk_ids, + False, + ) + + torch.cuda.synchronize() + + ata.destroy() + + num_tokens = a_chunk.shape[0] + + return out[:num_tokens] + + +def _pplx_prepare_finalize( + pgi: ProcessGroupInfo, + dp_size: int, + a: torch.Tensor, + score: torch.Tensor, + topk: torch.Tensor, + num_experts: int, +): + uid = nvshmem_get_unique_id( + ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() + torch.distributed.broadcast(uid, src=0) + nvshmem_init(uid, pgi.rank, pgi.world_size) + device = pgi.device + + topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) + k = a.shape[1] + + a_rep = torch.repeat_interleave(a, topk, dim=0).to(device) + + torch_output = (a_rep.view(-1, topk, k) * 1.5 * + topk_weight.view(-1, topk, 1).to(device)).sum(dim=1).to( + a.dtype) + + pplx_output = pplx_prepare_finalize(pgi, dp_size, a, topk_weight, topk_ids, + num_experts) + + torch_output = chunk_by_rank(torch_output, pgi.rank, + pgi.world_size).to(pplx_output.device) + + torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0) + + nvshmem_finalize() + + +# TODO (bnell): this test point does not work for odd M due to how the test is +# written, not due to limitations of the pplx kernels. The pplx_moe +# test below is able to deal with odd M. +@pytest.mark.parametrize("mnk", PPLX_PREPARE_COMBOS) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("world_dp_size", [[2, 1]]) +@requires_pplx +def test_pplx_prepare_finalize( + mnk: tuple[int, int, int], + e: int, + topk: int, + dtype: torch.dtype, + world_dp_size: tuple[int, int], +): + current_platform.seed_everything(7) + m, n, k = mnk + world_size, dp_size = world_dp_size + device = "cuda" + a = torch.randn((m, k), device=device, dtype=dtype) / 10 + score = torch.randn((m, e), device=device, dtype=dtype) + + parallel_launch(world_size, _pplx_prepare_finalize, dp_size, a, score, + topk, e) + + +def pplx_moe( + rank: int, + world_size: int, + dp_size: int, + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weight: torch.Tensor, + topk_ids: torch.Tensor, + use_compile: bool = True, + use_cudagraphs: bool = True, +) -> torch.Tensor: + from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import ( + PplxPrepareAndFinalize) + + device = torch.device("cuda", rank) + hidden_dim = a.shape[1] + num_experts = w1.shape[0] + block_size = 128 + topk = topk_ids.shape[1] + max_num_tokens = rank_chunk(a.shape[0], 0, world_size) + + ata = AllToAll.internode( + max_num_tokens=max_num_tokens, + num_experts=num_experts, + experts_per_token=topk, + rank=rank, + world_size=world_size, + dp_size=dp_size, + hidden_dim=hidden_dim, + hidden_dim_bytes=hidden_dim * a.dtype.itemsize, + hidden_dim_scale_bytes=(0 if a.dtype.itemsize != 1 else + ((hidden_dim + block_size - 1) // block_size * + torch.float32.itemsize)), + ) + + topk_ids = topk_ids.to(dtype=torch.uint32) + + prepare_finalize = PplxPrepareAndFinalize( + ata, + max_num_tokens, + world_size, + rank, + dp_size, + ) + + experts = BatchedTritonExperts(max_num_tokens=a.shape[0], + world_size=world_size, + dp_size=dp_size) + + fused_experts = FusedMoEModularKernel( + prepare_finalize, + experts, + ) + + # Note: workers with the same dp_rank must use the exact same inputs. + a_chunk = chunk_by_rank(a, rank, world_size).to(device) + chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device) + chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device) + + # Chunking weights like this only works for batched format + w1_chunk = chunk_by_rank(w1, rank, world_size).to(device) + w2_chunk = chunk_by_rank(w2, rank, world_size).to(device) + + if use_compile: + _fused_experts = torch.compile(fused_experts, + backend='inductor', + fullgraph=True) + else: + _fused_experts = fused_experts + + out = _fused_experts(a_chunk, + w1_chunk, + w2_chunk, + chunk_topk_weight, + chunk_topk_ids, + global_num_experts=num_experts) + + if use_cudagraphs: + out.fill_(0) + stream = torch.cuda.Stream() + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, stream=stream): + out = _fused_experts(a_chunk, + w1_chunk, + w2_chunk, + chunk_topk_weight, + chunk_topk_ids, + global_num_experts=num_experts) + + torch.cuda.synchronize() + graph.replay() + + torch.cuda.synchronize() + + ata.destroy() + + return out + + +def _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids): + assert torch.cuda.current_device() == pgi.local_rank + + num_experts = w1.shape[0] + device = pgi.device + rank = pgi.rank + world_size = pgi.world_size + max_num_tokens = rank_chunk(a.shape[0], 0, world_size) + + prepare_finalize = BatchedPrepareAndFinalize( + max_num_tokens=max_num_tokens, + world_size=world_size, + dp_size=dp_size, + rank=rank, + ) + + experts = BatchedExperts(max_num_tokens=a.shape[0], + world_size=1, + dp_size=1) + + fused_experts = FusedMoEModularKernel( + prepare_finalize, + experts, + ) + + # Note: workers with the same dp_rank must use the exact same inputs. + a_chunk = chunk_by_rank(a, rank, world_size).to(device) + chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device) + chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device) + + out = fused_experts( + a_chunk, + # Chunking weights like this only works for batched format + chunk_by_rank(w1, rank, world_size).to(device), + chunk_by_rank(w2, rank, world_size).to(device), + chunk_topk_weight, + chunk_topk_ids, + global_num_experts=num_experts) + + return out + + +def _pplx_moe( + pgi: ProcessGroupInfo, + dp_size: int, + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + score: torch.Tensor, + topk: int, +): + uid = nvshmem_get_unique_id( + ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() + torch.distributed.broadcast(uid, src=0) + nvshmem_init(uid, pgi.rank, pgi.world_size) + + m, k = a.shape + e, _, n = w2.shape + + moe_config = get_default_config(m, e, n, k, topk, a.dtype, False) + + with set_current_vllm_config(vllm_config), override_config(moe_config): + topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) + torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) + pplx_output = pplx_moe(pgi.rank, pgi.world_size, dp_size, a, w1, w2, + topk_weight, topk_ids) + # TODO (bnell): fix + re-enable + #batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, + # topk_ids) + + torch_output = chunk_by_rank(torch_output, pgi.rank, + pgi.world_size).to(pplx_output.device) + + torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0) + #torch.testing.assert_close(batched_output, torch_output, atol=2e-2, rtol=0) + + nvshmem_finalize() + + +@pytest.mark.parametrize("mnk", PPLX_MOE_COMBOS) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("world_dp_size", [[2, 1]]) +@requires_pplx +def test_pplx_moe( + mnk: tuple[int, int, int], + e: int, + topk: int, + dtype: torch.dtype, + world_dp_size: tuple[int, int], +): + current_platform.seed_everything(7) + m, n, k = mnk + world_size, dp_size = world_dp_size + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + score = torch.randn((m, e), device="cuda", dtype=dtype) + + parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk) diff --git a/tests/kernels/moe/test_rocm_aiter_topk.py b/tests/kernels/moe/test_rocm_aiter_topk.py new file mode 100644 index 0000000000000000000000000000000000000000..b0d34ddfd4234bc0c41b7a8bc54d88d3cbf770d6 --- /dev/null +++ b/tests/kernels/moe/test_rocm_aiter_topk.py @@ -0,0 +1,122 @@ +# SPDX-License-Identifier: Apache-2.0 +# This is a test for the AITER ops. +# It tests if the AITER ops are +# 1. correctly registered as custom ops +# 2. correctly defined the relationship between +# implementation and fake function +# 3. can be used with torch.compile +# This file will be skipped if AITER is not installed +# and the platform is not ROCm. + +import importlib.util + +import pytest +import torch + +# this import statement is needed to ensure the ops are registered +import vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe # noqa: F401 +from vllm.platforms import current_platform + +# need to import once to ensure the ops are registered +# Check if aiter package is installed +aiter_available = importlib.util.find_spec("aiter") is not None + +pytestmark = pytest.mark.skipif( + not (current_platform.is_rocm() and aiter_available), + reason="AITER ops are only available on ROCm with aiter package installed") + + +def test_rocm_aiter_biased_grouped_topk_custom_op_registration(): + """Test that the custom op is correctly registered.""" + # Check if the op exists in torch.ops.vllm + assert hasattr(torch.ops.vllm, 'rocm_aiter_biased_grouped_topk') + + # Check if the op is callable + assert callable(torch.ops.vllm.rocm_aiter_biased_grouped_topk) + + +def test_rocm_aiter_biased_grouped_topk_torch_compile_compatibility(): + """Test that the op can be used with torch.compile.""" + # Create test tensors + token = 64 + expert = 256 + num_expert_group = 8 + topk = 8 + topk_group = 4 + renormalize = True + scale_factor = 1.0 + + gating_output = torch.randn((token, expert), + dtype=torch.bfloat16, + device="cuda") + e_score_correction_bias = torch.randn((expert, ), + dtype=torch.bfloat16, + device="cuda") + + device = gating_output.device + topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device) + topk_weights = torch.empty((token, topk), + dtype=torch.float32, + device=device) + + # Define a function that uses the op + def biased_grouped_topk_fn(gating_output, e_score_correction_bias, + topk_weights, topk_ids): + return torch.ops.vllm.rocm_aiter_biased_grouped_topk( + gating_output, e_score_correction_bias, topk_weights, topk_ids, + num_expert_group, topk_group, renormalize, scale_factor) + + # Verify the op's fake implementation + torch.library.opcheck( + torch.ops.vllm.rocm_aiter_biased_grouped_topk, + (gating_output, e_score_correction_bias, topk_weights, topk_ids), + kwargs={ + "num_expert_group": num_expert_group, + "topk_group": topk_group, + "need_renorm": renormalize, + "routed_scaling_factor": scale_factor + }, + test_utils=("test_faketensor")) + + # Compile the function with appropriate settings + compiled_fn = torch.compile(biased_grouped_topk_fn, + fullgraph=True, + backend="inductor", + mode="reduce-overhead", + dynamic=False) + + topk_weights_original = torch.empty((token, topk), + dtype=torch.float32, + device=device) + topk_ids_original = torch.empty((token, topk), + dtype=torch.int32, + device=device) + + topk_weights_compiled = torch.empty((token, topk), + dtype=torch.float32, + device=device) + topk_ids_compiled = torch.empty((token, topk), + dtype=torch.int32, + device=device) + + # Run both compiled (V1 graph mode) and uncompiled versions (V1 eager mode) + biased_grouped_topk_fn(gating_output, e_score_correction_bias, + topk_weights_original, topk_ids_original) + compiled_fn(gating_output, e_score_correction_bias, topk_weights_compiled, + topk_ids_compiled) + + # Sort the results for comparison since the order might not be deterministic + topk_ids_original, indices_original = torch.sort(topk_ids_original) + topk_weights_original = torch.gather(topk_weights_original, 1, + indices_original) + + topk_ids_compiled, indices_compiled = torch.sort(topk_ids_compiled) + topk_weights_compiled = torch.gather(topk_weights_compiled, 1, + indices_compiled) + + # Verify results match + assert torch.allclose(topk_weights_original, + topk_weights_compiled, + rtol=1e-2, + atol=1e-2) + assert torch.allclose(topk_ids_original, topk_ids_compiled) diff --git a/tests/kernels/moe/test_triton_moe_ptpc_fp8.py b/tests/kernels/moe/test_triton_moe_ptpc_fp8.py index 44734e9340aa1e3738fbbe85f92933a739f830bb..3b5838a99fa156c398b8063b185c32cfa7ff1046 100644 --- a/tests/kernels/moe/test_triton_moe_ptpc_fp8.py +++ b/tests/kernels/moe/test_triton_moe_ptpc_fp8.py @@ -7,6 +7,7 @@ import pytest import torch from vllm import _custom_ops as ops +from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.platforms import current_platform @@ -15,6 +16,10 @@ if current_platform.get_device_capability() < (9, 0): pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True) +vllm_config = VllmConfig() +vllm_config.scheduler_config.max_num_seqs = 128 +vllm_config.scheduler_config.max_model_len = 8192 + def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16): """Matrix multiplication function that supports per-token input @@ -137,20 +142,21 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed): w2_s = torch.rand(E, K, device=w2_fp32.device) * factor_for_scale score = torch.randn((M, E), dtype=dtype) - ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk) - out = fused_moe( - a, - w1, - w2, - score, - topk, - renormalize=False, - use_fp8_w8a8=True, # using fp8 - per_channel_quant=True, - w1_scale=w1_s, - w2_scale=w2_s, - block_shape=None, # Not using block quantization - ) + with set_current_vllm_config(vllm_config): + ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk) + out = fused_moe( + a, + w1, + w2, + score, + topk, + renormalize=False, + use_fp8_w8a8=True, # using fp8 + per_channel_quant=True, + w1_scale=w1_s, + w2_scale=w2_s, + block_shape=None, # Not using block quantization + ) # Check results rel_diff = (torch.mean( diff --git a/tests/kernels/quantization/nvfp4_utils.py b/tests/kernels/quantization/nvfp4_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..58eaeee1c0b88204073cc244ad916ede944ce149 --- /dev/null +++ b/tests/kernels/quantization/nvfp4_utils.py @@ -0,0 +1,66 @@ +# SPDX-License-Identifier: Apache-2.0 +import torch + +from vllm.scalar_type import scalar_types + +FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max() +FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max + +kE2M1ToFloat = torch.tensor([0., 0.5, 1., 1.5, 2., 3., 4., 6.], + dtype=torch.float32) + + +def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size): + m_tiles = (m + 128 - 1) // 128 + f = block_size * 4 + k_tiles = (k + f - 1) // f + tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4)) + tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5)) + out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size) + return out[0:m, 0:k] + + +def dequantize_nvfp4_to_dtype(tensor_fp4, + tensor_sf, + global_scale, + dtype, + device, + block_size=16): + """Dequantize the fp4 tensor back to high precision.""" + # Two fp4 values are packed into one uint8. + assert tensor_fp4.dtype == torch.uint8 + m, packed_k = tensor_fp4.shape + k = packed_k * 2 + tensor_f32 = break_fp4_bytes(tensor_fp4, dtype) + tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size) + tensor_sf = tensor_sf.view(torch.float8_e4m3fn) + tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size) + tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale + + # scale the tensor + out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k) + return out.to(dtype=dtype) + + +def break_fp4_bytes(a, dtype): + assert a.dtype == torch.uint8 + m, n = a.shape + + # Vectorized nibble processing + a_flat = a.flatten() + high = (a_flat & 0xF0) >> 4 # Upper nibbles + low = a_flat & 0x0F # Lower nibbles + + # Combine nibbles for batch processing + combined = torch.stack((low, high), dim=1).flatten() + + # Vectorized sign and magnitude extraction + signs = (combined & 0x08).to(torch.bool) # Sign bits + abs_vals = (combined & 0x07).to(torch.long) # Magnitude indices + + # Device-aware lookup and sign application + kE2M1 = kE2M1ToFloat.to(device=a.device) + values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0) + + # Reshape to final form + return values.reshape(m, n * 2).to(dtype=dtype) diff --git a/tests/kernels/quantization/test_awq_marlin.py b/tests/kernels/quantization/test_awq_marlin.py deleted file mode 100644 index 939b0e7157be7b75514333749241339e0b6ee374..0000000000000000000000000000000000000000 --- a/tests/kernels/quantization/test_awq_marlin.py +++ /dev/null @@ -1,163 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -"""Test AWQ with fused MoE Marlin kernels. - -Run `pytest tests/kernels/test_awq_marlin.py`. -""" -import pytest -import torch - -import vllm.model_executor.layers.fused_moe # noqa -from tests.kernels.utils import (compute_max_diff, stack_and_dev, torch_moe, - torch_moe_single) -from vllm import _custom_ops as ops -from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk -from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( - awq_marlin_quantize) -from vllm.scalar_type import scalar_types - -NUM_EXPERTS = [8, 64] -TOP_KS = [2, 6] -GROUP_SIZES = [-1, 32, 128] - - -@pytest.mark.parametrize("m", [1, 33, 64, 222]) -@pytest.mark.parametrize("n", [128, 2048]) -@pytest.mark.parametrize("k", [128, 1024]) -@pytest.mark.parametrize("e", NUM_EXPERTS) -@pytest.mark.parametrize("topk", TOP_KS) -@pytest.mark.parametrize("group_size", GROUP_SIZES) -@pytest.mark.skipif(not (ops.supports_moe_ops - and hasattr(torch.ops._moe_C, "marlin_gemm_moe")), - reason="Marlin is not supported on this GPU type.") -def test_fused_marlin_moe_awq( - m: int, - n: int, - k: int, - e: int, - topk: int, - group_size: int, -): - torch.manual_seed(7) - - num_bits = 4 - quant_type = scalar_types.uint4 - dtype = torch.float16 - a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 - - w_ref1_l = [] - qweights1_l = [] - scales1_l = [] - zp1_l = [] - - for i in range(w1.shape[0]): - w_ref1, qweight1, scales1, zp1 = awq_marlin_quantize( - w1[i].transpose(1, 0), quant_type, group_size) - w_ref1_l.append(w_ref1) - qweights1_l.append(qweight1) - scales1_l.append(scales1) - zp1_l.append(zp1) - - w_ref1 = stack_and_dev(w_ref1_l) - qweight1 = stack_and_dev(qweights1_l).contiguous() - scales1 = stack_and_dev(scales1_l) - zp1 = stack_and_dev(zp1_l) - - w_ref2_l = [] - qweights2_l = [] - scales2_l = [] - zp2_l = [] - - for i in range(w2.shape[0]): - w_ref2, qweight2, scales2, zp2 = awq_marlin_quantize( - w2[i].transpose(1, 0), quant_type, group_size) - w_ref2_l.append(w_ref2) - qweights2_l.append(qweight2) - scales2_l.append(scales2) - zp2_l.append(zp2) - - w_ref2 = stack_and_dev(w_ref2_l) - qweight2 = stack_and_dev(qweights2_l).contiguous() - scales2 = stack_and_dev(scales2_l) - zp2 = stack_and_dev(zp2_l) - - score = torch.randn((m, e), device="cuda", dtype=dtype) - - topk_weights, topk_ids = fused_topk(a, score, topk, False) - marlin_output = torch.ops.vllm.fused_marlin_moe( - a, - qweight1, - qweight2, - scales1, - scales2, - score, - topk_weights, - topk_ids, - w1_zeros=zp1, - w2_zeros=zp2, - num_bits=num_bits, - ) - - torch_output = torch_moe(a, w_ref1.transpose(1, 2), w_ref2.transpose(1, 2), - score, topk, None) - - assert compute_max_diff(marlin_output, torch_output) < 4e-2 - - -@pytest.mark.skip("This test is here for the sake of debugging, " - "don't run it in automated tests.") -@pytest.mark.parametrize("m", [64, 512, 222, 33, 1]) -@pytest.mark.parametrize("n", [128, 2048, 256, 1024]) -@pytest.mark.parametrize("k", [128, 1024, 512]) -@pytest.mark.parametrize("e", [8, 64]) -@pytest.mark.parametrize("topk", [2, 6]) -@pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) -def test_single_marlin_moe_multiply_awq( - m: int, - n: int, - k: int, - e: int, - topk: int, - group_size: int, -): - torch.manual_seed(7) - - num_bits = 4 - quant_type = scalar_types.uint4 - dtype = torch.float16 - a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 - w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10 - - w_ref_l = [] - qweights_l = [] - scales_l = [] - zp_l = [] - - for i in range(w.shape[0]): - w_ref, qweight, scales, zp = awq_marlin_quantize( - w[i].transpose(1, 0), quant_type, group_size) - w_ref_l.append(w_ref) - qweights_l.append(qweight) - scales_l.append(scales) - zp_l.append(zp) - - w_ref = stack_and_dev(w_ref_l) - qweight = stack_and_dev(qweights_l).contiguous() - scales = stack_and_dev(scales_l).contiguous() - zp = stack_and_dev(zp_l).contiguous() - - score = torch.randn((m, e), device="cuda", dtype=dtype) - - marlin_output = torch.ops.vllm.single_marlin_moe(a, - qweight, - scales, - score, - topk, - renormalize=False, - w_zeros=zp, - num_bits=num_bits) - - torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk) - - assert compute_max_diff(marlin_output, torch_output) < 1e-2 diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index c57e39f4250646db4247ed0368ad3727dd40e0e5..ef1d7e47ef8107da43357223987093a44c8ce0b1 100644 --- a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -11,7 +11,7 @@ from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( - deep_gemm_moe_fp8) + _valid_deep_gemm_shape, deep_gemm_moe_fp8) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( moe_align_block_size) @@ -30,6 +30,10 @@ if current_platform.get_device_capability() < (9, 0): pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True) +vllm_config = VllmConfig() +vllm_config.scheduler_config.max_num_seqs = 128 +vllm_config.scheduler_config.max_model_len = 8192 + # Test configurations DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32] NUM_TOKENS = [7, 83, 2048] @@ -210,7 +214,6 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): score = torch.randn((M, E), dtype=dtype) # Set the context to avoid lots of warning spam. - vllm_config = VllmConfig() with set_current_vllm_config(vllm_config): out = fused_moe( a, @@ -258,6 +261,7 @@ def per_block_cast_to_fp8( @pytest.mark.parametrize( "M,N,K,block_size,out_dtype,seed", itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS)) +@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.") @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): # only aligned sizes @@ -338,7 +342,8 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk, M, K = a.shape N = w2.shape[-1] - topk_weight, topk_ids = fused_topk(a, score.float(), topk, False) + topk_weight, topk_ids, token_expert_indices = fused_topk( + a, score.float(), topk, False) block_m = deep_gemm.get_m_alignment_for_contiguous_layout() @@ -380,15 +385,11 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed): block_size = [block_m, block_m] dtype = torch.bfloat16 - # only aligned sizes - if (N % block_m != 0 or K % block_m != 0 or topk > E): - pytest.skip( - f"Skipping test; bad size m={M}, n={N}, k={K}, topk={topk}, E={E}") - - if N <= 512: - pytest.skip("Skipping N <= 512 until performance issues solved.") + if topk > E: + pytest.skip(f"Skipping test: topk={topk} > E={E}") - vllm_config = VllmConfig() + if not _valid_deep_gemm_shape(M, N, K): + pytest.skip(f"Skipping test: invalid size m={M}, n={N}, k={K}") torch.manual_seed(seed) fp8_info = torch.finfo(torch.float8_e4m3fn) @@ -435,7 +436,8 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed): ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_size) - topk_weights, topk_ids = fused_topk(a, score.float(), topk, False) + topk_weights, topk_ids, token_expert_indices = fused_topk( + a, score.float(), topk, False) out = deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids) diff --git a/tests/kernels/quantization/test_block_int8.py b/tests/kernels/quantization/test_block_int8.py index 104f23fd7cd2f63fd101facb494b9587f2846510..a4e9f83f0eaf1ddebadec2f4b73a9c9e81638948 100644 --- a/tests/kernels/quantization/test_block_int8.py +++ b/tests/kernels/quantization/test_block_int8.py @@ -18,6 +18,10 @@ if current_platform.get_device_capability() < (7, 0): pytest.skip("INT8 Triton requires CUDA 7.0 or higher", allow_module_level=True) +vllm_config = VllmConfig() +vllm_config.scheduler_config.max_num_seqs = 128 +vllm_config.scheduler_config.max_model_len = 8192 + # For test def native_per_token_group_quant_int8(x, @@ -174,7 +178,6 @@ def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): score = torch.randn((M, E), dtype=dtype) # Set the context to avoid lots of warning spam. - vllm_config = VllmConfig() with set_current_vllm_config(vllm_config): out = fused_moe( a, diff --git a/tests/kernels/quantization/test_cutlass_scaled_mm.py b/tests/kernels/quantization/test_cutlass_scaled_mm.py index 8084d9bf2c2da23930095e2873e6fb7af27fc0d2..633addd421f4389f89e3f085109db741f1f97078 100644 --- a/tests/kernels/quantization/test_cutlass_scaled_mm.py +++ b/tests/kernels/quantization/test_cutlass_scaled_mm.py @@ -95,7 +95,7 @@ def cutlass_fp8_gemm_helper(m: int, out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias) - torch.testing.assert_close(out, baseline, rtol=1e-2, atol=5e-2) + torch.testing.assert_close(out, baseline, rtol=1e-2, atol=1.5e-1) opcheck(torch.ops._C.cutlass_scaled_mm, (out, a, b, scale_a, scale_b, bias)) @@ -161,6 +161,8 @@ def test_cutlass_fp8_blockwise_scale_gemm(m: int, n: int, k: int, return if m % a_scale_group_shape[0] != 0 or k % a_scale_group_shape[1] != 0: return + if m % 4 != 0 and current_platform.has_device_capability(100): + return cutlass_fp8_gemm_helper(m, n, k, a_scale_group_shape, b_scale_group_shape, use_bias) diff --git a/tests/kernels/quantization/test_ggml.py b/tests/kernels/quantization/test_ggml.py index cc157da518cbfd4bf17b11e59853b8704eec0956..73697a6d1242dd4c1c7add0d305b96ad21e633e0 100644 --- a/tests/kernels/quantization/test_ggml.py +++ b/tests/kernels/quantization/test_ggml.py @@ -36,3 +36,9 @@ def test_ggml_opcheck(quant_type): opcheck(torch.ops._C.ggml_moe_a8, (x, qweight, sorted_token_ids, expert_ids, num_tokens_post_padded, quant_type, qweight.shape[0], 1, x.shape[0])) + + topk_ids = torch.zeros((1, 1), device='cuda', dtype=torch.int32) + + opcheck( + torch.ops._C.ggml_moe_a8_vec, + (x, qweight, topk_ids, 1, quant_type, qweight.shape[0], x.shape[0])) diff --git a/tests/kernels/quantization/test_gguf.py b/tests/kernels/quantization/test_gguf.py index 4c0fae9d9fd752739fd69b7eb7daa1212cb600df..6cf88604ec65ea5cce750ff8e9f926b8389c9aa4 100644 --- a/tests/kernels/quantization/test_gguf.py +++ b/tests/kernels/quantization/test_gguf.py @@ -151,20 +151,7 @@ def test_mmq(num_tokens: int, hidden_size: int, dtype: torch.dtype, @pytest.mark.parametrize("hidden_size", [512]) @pytest.mark.parametrize("top_k", [4, 8]) @pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize( - "quant_type", - [ - # k-quants - GGMLQuantizationType.Q2_K, - GGMLQuantizationType.Q3_K, - GGMLQuantizationType.Q4_K, - GGMLQuantizationType.Q5_K, - GGMLQuantizationType.Q6_K, - # standard quants - GGMLQuantizationType.Q4_0, - GGMLQuantizationType.Q5_0, - GGMLQuantizationType.Q8_0, - ]) +@pytest.mark.parametrize("quant_type", QUANT_TYPES) @torch.inference_mode() def test_moe(num_tokens: int, hidden_size: int, dtype: torch.dtype, quant_type: GGMLQuantizationType, top_k: int): @@ -174,7 +161,10 @@ def test_moe(num_tokens: int, hidden_size: int, dtype: torch.dtype, x = torch.rand((num_tokens, H), dtype=dtype, device="cuda") topk_weights = torch.rand(num_tokens, top_k, device="cuda", dtype=dtype) - topk_ids = torch.randint(0, E, (num_tokens, top_k), device="cuda") + topk_ids = torch.randint(0, + E, (num_tokens, top_k), + device="cuda", + dtype=torch.int32) tensors = get_gguf_MoE_tensors(hidden_size, quant_type) diff --git a/tests/kernels/quantization/test_marlin_gemm.py b/tests/kernels/quantization/test_marlin_gemm.py index 3165201aa35321da8b1bfddae869ba6409c7fce9..52507b375c2717c645062902a7be72592beb470c 100644 --- a/tests/kernels/quantization/test_marlin_gemm.py +++ b/tests/kernels/quantization/test_marlin_gemm.py @@ -18,9 +18,12 @@ from vllm.model_executor.layers.quantization.qqq import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import ( GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, MARLIN_SUPPORTED_GROUP_SIZES, marlin_make_empty_g_idx, - marlin_permute_scales, query_marlin_supported_quant_types) + marlin_make_workspace_new, marlin_permute_scales, + query_marlin_supported_quant_types) +from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( + FP4_MARLIN_SUPPORTED_GROUP_SIZES, rand_marlin_weight_fp4_like) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( - pack_fp8_to_int32) + marlin_quant_fp8_torch) from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( MarlinWorkspace, awq_marlin_quantize, get_weight_perm, marlin_quantize, marlin_weights) @@ -73,7 +76,7 @@ def rand_data(shape, dtype=torch.float16): @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) @pytest.mark.parametrize("quant_type", - query_marlin_supported_quant_types(False)) + query_marlin_supported_quant_types(False, False)) @pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES) @pytest.mark.parametrize("act_order", ACT_ORDER_OPTS) @pytest.mark.parametrize("mnk_factors", MNK_FACTORS) @@ -138,7 +141,7 @@ def test_gptq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) @pytest.mark.parametrize("quant_type", - query_marlin_supported_quant_types(False)) + query_marlin_supported_quant_types(True)) @pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES) @pytest.mark.parametrize("mnk_factors", MNK_FACTORS) def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, @@ -189,9 +192,10 @@ def test_awq_marlin_repack(k_chunk, n_chunk, quant_type, group_size, reason="Marlin is not supported on this GPU type.") @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) -@pytest.mark.parametrize("quant_type", - query_marlin_supported_quant_types(False)) -@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES) +@pytest.mark.parametrize("quant_type", query_marlin_supported_quant_types()) +@pytest.mark.parametrize( + "group_size", + set(MARLIN_SUPPORTED_GROUP_SIZES + FP4_MARLIN_SUPPORTED_GROUP_SIZES)) @pytest.mark.parametrize("mnk_factors", MNK_FACTORS) @pytest.mark.parametrize("act_order", ACT_ORDER_OPTS) @pytest.mark.parametrize("is_k_full", K_FULL_OPTS) @@ -209,6 +213,7 @@ def test_gptq_marlin_gemm( use_fp32_reduce, ): m_factor, n_factor, k_factor = mnk_factors + has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8] size_m = m_factor size_k = k_chunk * k_factor @@ -219,39 +224,74 @@ def test_gptq_marlin_gemm( return if group_size == size_k: return + if has_zp: + return + + if size_k % group_size != 0: + return a_input = rand_data((size_m, size_k)) b_weight = rand_data((size_k, size_n)) - w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize( - b_weight, quant_type, group_size, act_order) - - marlin_zp = marlin_make_empty_g_idx(marlin_s.device) + if quant_type == scalar_types.float4_e2m1f: + if group_size != 16 or act_order: + return + w_ref, marlin_q_w, marlin_s, marlin_s2 = rand_marlin_weight_fp4_like( + b_weight.T, group_size) + g_idx = None + sort_indices = None + marlin_zp = None + elif quant_type == scalar_types.float8_e4m3fn: + if group_size not in [-1, 128]: + return + if act_order: + return + w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch( + b_weight.T, group_size) + g_idx = None + sort_indices = None + marlin_zp = None + marlin_s2 = None + elif has_zp: + if group_size == 16: + return + w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize( + b_weight, quant_type, group_size) + g_idx = None + sort_indices = None + marlin_s2 = None + else: + if group_size == 16: + return + w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, _ = marlin_quantize( + b_weight, quant_type, group_size, act_order) + marlin_zp = None + marlin_s2 = None - workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N, - GPTQ_MARLIN_MAX_PARALLEL) + workspace = marlin_make_workspace_new(w_ref.device) opcheck(torch.ops._C.gptq_marlin_gemm, - (a_input, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices, - workspace.scratch, quant_type.id, a_input.shape[0], - b_weight.shape[1], a_input.shape[1], is_k_full, False, - use_atomic_add, use_fp32_reduce, False), + (a_input, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, g_idx, + sort_indices, workspace, quant_type.id, a_input.shape[0], + b_weight.shape[1], a_input.shape[1], is_k_full, use_atomic_add, + use_fp32_reduce, False), test_utils=DEFAULT_OPCHECK_TEST_UTILS) output = ops.gptq_marlin_gemm( a_input, + None, marlin_q_w, marlin_s, + marlin_s2, marlin_zp, g_idx, sort_indices, - workspace.scratch, + workspace, quant_type, a_input.shape[0], b_weight.shape[1], a_input.shape[1], is_k_full=is_k_full, - has_zp=False, use_atomic_add=use_atomic_add, use_fp32_reduce=use_fp32_reduce, is_zp_float=False, @@ -326,143 +366,6 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size, assert max_diff < 0.04 -@pytest.mark.skipif(not is_quant_method_supported("fp8"), - reason="Marlin is not supported on this GPU type.") -@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) -@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) -@pytest.mark.parametrize("num_bits", [8]) -@pytest.mark.parametrize("group_size", [-1]) -@pytest.mark.parametrize("mnk_factors", MNK_FACTORS) -@pytest.mark.parametrize("dtype", DTYPES) -def test_fp8_marlin_gemm( - k_chunk, - n_chunk, - num_bits, - group_size, - mnk_factors, - dtype, -): - m_factor, n_factor, k_factor = mnk_factors - - size_m = m_factor - size_k = k_chunk * k_factor - size_n = n_chunk * n_factor - - a_input = rand_data((size_m, size_k), dtype=dtype) - b_weight = rand_data((size_k, size_n), dtype=dtype) - - # WEIGHTS - fp8_weight, weight_scale = ops.scaled_fp8_quant(b_weight, scale=None) - # Repack weights to gptq format (packed int32 elements) - packed_gptq_qweight = pack_fp8_to_int32(fp8_weight) - # Repack weights to marlin format - marlin_qweight = ops.gptq_marlin_repack( - b_q_weight=packed_gptq_qweight, - perm=torch.empty(0, dtype=torch.int, device="cuda"), - size_k=size_k, - size_n=size_n, - num_bits=8, - ) - - # WEIGHT SCALES - # Currently Marlin doesn't support per-tensor scales, so we - # expand it to channelwise - scales = weight_scale.repeat(1, size_n).to(a_input.dtype).to("cuda") - # Permute scales - marlin_scales = marlin_permute_scales(s=scales, - size_k=size_k, - size_n=size_n, - group_size=-1) - - workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N, - GPTQ_MARLIN_MAX_PARALLEL) - - opcheck(torch.ops._C.fp8_marlin_gemm, - (a_input, marlin_qweight, marlin_scales, workspace.scratch, - num_bits, a_input.shape[0], b_weight.shape[1], a_input.shape[1])) - - output = ops.fp8_marlin_gemm( - a=a_input, - b_q_weight=marlin_qweight, - b_scales=marlin_scales, - workspace=workspace.scratch, - num_bits=num_bits, - size_m=a_input.shape[0], - size_n=b_weight.shape[1], - size_k=a_input.shape[1], - ) - output_ref = torch.matmul(a_input, b_weight) - - torch.cuda.synchronize() - - max_diff = compute_max_diff(output, output_ref) - - assert max_diff < 0.04 - - -@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"), - reason="Marlin is not supported on this GPU type.") -@pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) -@pytest.mark.parametrize("n_chunk", MARLIN_N_CHUNKS) -@pytest.mark.parametrize("quant_type", - query_marlin_supported_quant_types(True)) -@pytest.mark.parametrize("group_size", MARLIN_SUPPORTED_GROUP_SIZES) -@pytest.mark.parametrize("mnk_factors", MNK_FACTORS) -@pytest.mark.parametrize("use_fp32_reduce", USE_FP32_REDUCE_OPTS) -def test_awq_marlin_gemm( - k_chunk, - n_chunk, - quant_type, - group_size, - mnk_factors, - use_fp32_reduce, -): - m_factor, n_factor, k_factor = mnk_factors - - size_m = m_factor - size_k = k_chunk * k_factor - size_n = n_chunk * n_factor - - a_input = rand_data((size_m, size_k)) - b_weight = rand_data((size_k, size_n)) - - w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize( - b_weight, quant_type, group_size) - - g_idx = torch.empty(0, dtype=torch.int, device=marlin_q_w.device) - sort_indices = torch.empty(0, dtype=torch.int, device=marlin_q_w.device) - is_k_full = True - has_zp = True - - workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N, - GPTQ_MARLIN_MAX_PARALLEL) - - output = ops.gptq_marlin_gemm( - a_input, - marlin_q_w, - marlin_s, - marlin_zp, - g_idx, - sort_indices, - workspace.scratch, - quant_type, - a_input.shape[0], - b_weight.shape[1], - a_input.shape[1], - is_k_full=is_k_full, - has_zp=has_zp, - use_fp32_reduce=use_fp32_reduce, - is_zp_float=False, - ) - output_ref = torch.matmul(a_input, w_ref) - - torch.cuda.synchronize() - - max_diff = compute_max_diff(output, output_ref) - - assert max_diff < 0.04 - - @pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"), reason="Marlin is not supported on this GPU type.") @pytest.mark.parametrize("k_chunk", MARLIN_K_CHUNKS) @@ -508,23 +411,23 @@ def test_hqq_marlin_gemm( g_idx = marlin_make_empty_g_idx(dev) g_idx_sort_indices = marlin_make_empty_g_idx(dev) - workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N, - GPTQ_MARLIN_MAX_PARALLEL) + workspace = marlin_make_workspace_new(b_weight.device) output = ops.gptq_marlin_gemm( a_input, + None, marlin_w_q, marlin_s, + None, marlin_zp, g_idx, g_idx_sort_indices, - workspace.scratch, + workspace, quant_type, a_input.shape[0], b_weight.shape[0], a_input.shape[1], is_k_full=True, - has_zp=True, use_fp32_reduce=use_fp32_reduce, is_zp_float=True, ) @@ -621,23 +524,23 @@ def test_marlin_gemm_subset_input(): b_weight, quant_type, group_size, False) marlin_zp = marlin_make_empty_g_idx(marlin_s.device) - workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_MIN_THREAD_N, - GPTQ_MARLIN_MAX_PARALLEL) + workspace = marlin_make_workspace_new(a_input.device) output = ops.gptq_marlin_gemm( a_input, + None, marlin_q_w, marlin_s, + None, marlin_zp, g_idx, sort_indices, - workspace.scratch, + workspace, quant_type, a_input.shape[0], b_weight.shape[1], a_input.shape[1], is_k_full=True, - has_zp=False, use_atomic_add=False, use_fp32_reduce=True, is_zp_float=False, diff --git a/tests/kernels/quantization/test_nvfp4_quant.py b/tests/kernels/quantization/test_nvfp4_quant.py index 93735fc096d793a818eb3e317b2eb968ed4b3e05..b8aa1672100e2ccad5206b02e853e44646dc60b1 100644 --- a/tests/kernels/quantization/test_nvfp4_quant.py +++ b/tests/kernels/quantization/test_nvfp4_quant.py @@ -17,7 +17,7 @@ PAD_SHAPES = [(90, 64), (150, 64), (128, 48), (128, 80), (150, 80), (90, 48), SEEDS = [42] CUDA_DEVICES = ['cuda:0'] -FLOAT4_E2M1_MAX = scalar_types.float4_e2m1fn.max() +FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max() FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max # E2M1 to float diff --git a/tests/kernels/quantization/test_nvfp4_scaled_mm.py b/tests/kernels/quantization/test_nvfp4_scaled_mm.py index b08026c5867dad3617f60eecb01b085612a78647..1f49900b2d90b8df868bde205521e1f5ee81cc2e 100644 --- a/tests/kernels/quantization/test_nvfp4_scaled_mm.py +++ b/tests/kernels/quantization/test_nvfp4_scaled_mm.py @@ -1,10 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 import pytest import torch +from nvfp4_utils import (FLOAT4_E2M1_MAX, FLOAT8_E4M3_MAX, + dequantize_nvfp4_to_dtype) from vllm import _custom_ops as ops from vllm.platforms import current_platform -from vllm.scalar_type import scalar_types if not current_platform.has_device_capability(100): pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.", @@ -19,95 +20,24 @@ SHAPES.extend(PAD_SHAPES) SEEDS = [42] CUDA_DEVICES = ['cuda:0'] -FLOAT4_E2M1_MAX = scalar_types.float4_e2m1fn.max() -FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max - -kE2M1ToFloatArray = [ - 0., - 0.5, - 1., - 1.5, - 2., - 3., - 4., - 6., -] - - -def e2m1_to_fp32(int4_value): - signBit = (int4_value & 0x8) - int4_absValue = int4_value & 0x7 - float_result = kE2M1ToFloatArray[int4_absValue] - if (signBit): - float_result = -float_result - return float_result - - -def break_fp4_bytes(a, dtype): - assert (a.dtype == torch.uint8) - m, n = a.shape - a = a.flatten() - # Get upper 4 bits - highHalfByte = (a & 0xF0) >> 4 - # Get lower 4 bits - lowHalfByte = a & 0x0F - fH = torch.tensor([e2m1_to_fp32(x) for x in highHalfByte]).to(a.device) - fL = torch.tensor([e2m1_to_fp32(x) for x in lowHalfByte]).to(a.device) - # [0xAB, 0xCD] -> [0xB, 0xA, 0xD, 0xC] - out = torch.stack((fL, fH), dim=-1).reshape(m, n * 2) - return out - - -def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size): - sf_m, sf_k = a_sf_swizzled.shape - m_tiles = (m + 128 - 1) // 128 - f = block_size * 4 - k_tiles = (k + f - 1) // f - tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4)) - tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5)) - out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size) - return out[0:m, 0:k] - - -def dequantize_to_dtype(tensor_fp4, - tensor_sf, - global_scale, - dtype, - device, - block_size=16): - """Dequantize the fp4 tensor back to high precision.""" - # Two fp4 values are packed into one uint8. - assert tensor_fp4.dtype == torch.uint8 - m, packed_k = tensor_fp4.shape - k = packed_k * 2 - tensor_f32 = break_fp4_bytes(tensor_fp4, dtype) - tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size) - tensor_sf = tensor_sf.view(torch.float8_e4m3fn) - tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size) - tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale - - # scale the tensor - out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k) - return out - def get_ref_results(a_fp4, b_fp4, a_sf, b_sf, a_global_scale, b_global_scale, m, n, dtype, block_size, device): _, m_k = a_fp4.shape _, n_k = b_fp4.shape assert (m_k == n_k) - a_in_dtype = dequantize_to_dtype(a_fp4, - a_sf, - a_global_scale, - dtype=dtype, - device=device, - block_size=block_size) - b_in_dtype = dequantize_to_dtype(b_fp4, - b_sf, - b_global_scale, - dtype=dtype, - device=device, - block_size=block_size) + a_in_dtype = dequantize_nvfp4_to_dtype(a_fp4, + a_sf, + a_global_scale, + dtype=dtype, + device=device, + block_size=block_size) + b_in_dtype = dequantize_nvfp4_to_dtype(b_fp4, + b_sf, + b_global_scale, + dtype=dtype, + device=device, + block_size=block_size) return torch.matmul(a_in_dtype, b_in_dtype.t()) diff --git a/tests/kernels/quantization/test_rocm_skinny_gemms.py b/tests/kernels/quantization/test_rocm_skinny_gemms.py index 622079c394457c5ef6eb6f0199ffd57ff0c81684..c7eee899896acea5ad15a0377d29212a4250feb3 100644 --- a/tests/kernels/quantization/test_rocm_skinny_gemms.py +++ b/tests/kernels/quantization/test_rocm_skinny_gemms.py @@ -8,7 +8,7 @@ from vllm.platforms import current_platform DTYPES = [torch.bfloat16, torch.float16] M = [16, 32, 64, 128, 256, 512, 1024, 4096, 8192] -K = [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192] # k % 8 == 0 +K = [8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 6144, 8192] # k % 8 == 0 N = [1, 2, 3, 4] SEEDS = [0] @@ -58,8 +58,9 @@ def test_rocm_wvsplitk_kernel(n, k, m, dtype, seed): @pytest.mark.parametrize("m", M + [28672]) # m >= 16 @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("seed", SEEDS) -@pytest.mark.skipif(not current_platform.is_rocm(), - reason="only test for rocm") +@pytest.mark.skipif( + not (current_platform.is_rocm() and current_platform.supports_fp8()), + reason="only test for rocm fp8") def test_rocm_wvsplitk_fp8_kernel(n, k, m, dtype, seed): torch.manual_seed(seed) diff --git a/tests/kernels/test_fused_quant_activation.py b/tests/kernels/test_fused_quant_activation.py new file mode 100644 index 0000000000000000000000000000000000000000..faa8d49ce41be5990e2560e220d36b6716a94f92 --- /dev/null +++ b/tests/kernels/test_fused_quant_activation.py @@ -0,0 +1,70 @@ +# SPDX-License-Identifier: Apache-2.0 +import pytest +import torch + +import vllm._custom_ops as ops +from tests.kernels.utils import opcheck +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.platforms import current_platform + +DTYPES = [torch.bfloat16, torch.float16] +QUANT_DTYPES = [current_platform.fp8_dtype()] +NUM_TOKENS = [1, 17, 86, 1234, 3045] # Arbitrary values for testing +HIDDEN_SIZES = [16, 48, 128, 1562, 4096] # Arbitrary values for testing +SEEDS = [0] +CUDA_DEVICES = [ + f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) +] + + +def ref_impl(silu_and_mul: SiluAndMul, x: torch.Tensor, + scale: torch.Tensor) -> torch.Tensor: + silu_and_mul_out = silu_and_mul.forward_native(x) + out, scales = ops.scaled_fp8_quant(silu_and_mul_out, scale) + return out + + +def ops_impl(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + out_shape = (x.shape[0], x.shape[1] // 2) + out = torch.empty(out_shape, + dtype=current_platform.fp8_dtype(), + device=x.device) + torch.ops._C.silu_and_mul_quant(out, x, scale) + return out + + +@pytest.mark.parametrize("num_tokens", NUM_TOKENS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("quant_dtype", QUANT_DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_silu_and_mul( + num_tokens: int, + hidden_size: int, + dtype: torch.dtype, + quant_dtype: torch.dtype, + seed: int, + device: str, +) -> None: + torch.random.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.set_default_device(device) + + layer = SiluAndMul() + + # Make inputs + scale = (torch.randn((1), device=device, dtype=torch.float32)) + x = torch.randn(num_tokens, hidden_size, dtype=dtype) + + ref_out = ref_impl(layer, x, scale) + ops_out = ops_impl(x, scale) + + assert ref_out.dtype == quant_dtype + assert ops_out.dtype == quant_dtype + assert ref_out.shape == ops_out.shape + assert torch.allclose(ref_out.to(dtype=torch.float32), + ops_out.to(dtype=torch.float32)) + opcheck(torch.ops._C.silu_and_mul_quant, (ops_out, x, scale)) diff --git a/tests/kv_transfer/test_disagg.py b/tests/kv_transfer/test_disagg.py index 5b9ea6dba401b4a47b3196bbed3d74cc3598afb2..dc948a48bf3267da4273fb8ee459cd5186e14150 100644 --- a/tests/kv_transfer/test_disagg.py +++ b/tests/kv_transfer/test_disagg.py @@ -14,8 +14,8 @@ import torch # Fixture to set up environment variables and teardown servers after tests @pytest.fixture(scope="module", autouse=True) def setup_servers(): - if torch.cuda.device_count() < 4: - pytest.skip("Skipping test: fewer than 4 GPUs available") + if torch.cuda.device_count() < 2: + pytest.skip("Skipping test: fewer than 2 GPUs available") # Set up environment variables VLLM_HOST_IP = subprocess.check_output("hostname -I | awk '{print $1}'", diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index dc433f9dad260551f75c0fdc68ab0b9af3aa7b91..399311ce65bb823faaa86ba1f07845277fc54b62 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -47,7 +47,7 @@ def dist_init(): temp_file = tempfile.mkstemp()[1] backend = "nccl" - if current_platform.is_cpu(): + if current_platform.is_cpu() or current_platform.is_tpu(): backend = "gloo" init_distributed_environment(world_size=1, @@ -139,6 +139,12 @@ def dummy_model_gate_up() -> nn.Module: return model +@pytest.fixture(scope="session") +def llama_2_7b_base_huggingface_id(): + # used as a base model for testing with sql lora adapter + return "meta-llama/Llama-2-7b-hf" + + @pytest.fixture(scope="session") def sql_lora_huggingface_id(): # huggingface repo id is used to test lora runtime downloading. @@ -198,6 +204,12 @@ def qwen2vl_lora_files(): return snapshot_download(repo_id="jeeejeee/qwen2-vl-lora-pokemon") +@pytest.fixture(scope="session") +def qwen25vl_base_huggingface_id(): + # used as a base model for testing with qwen25vl lora adapter + return "Qwen/Qwen2.5-VL-3B-Instruct" + + @pytest.fixture(scope="session") def qwen25vl_lora_files(): return snapshot_download(repo_id="jeeejeee/qwen25-vl-lora-pokemon") @@ -261,8 +273,8 @@ def run_with_both_engines_lora(request, monkeypatch): @pytest.fixture def reset_default_device(): """ - Some tests, such as `test_punica_ops.py`, explicitly set the - default device, which can affect subsequent tests. Adding this fixture + Some tests, such as `test_punica_ops.py`, explicitly set the + default device, which can affect subsequent tests. Adding this fixture helps avoid this problem. """ original_device = torch.get_default_device() diff --git a/tests/lora/test_lora_allowed_token_ids.py b/tests/lora/test_lora_allowed_token_ids.py new file mode 100644 index 0000000000000000000000000000000000000000..094541aef02bb893cf2d77211b90db713374e8bb --- /dev/null +++ b/tests/lora/test_lora_allowed_token_ids.py @@ -0,0 +1,134 @@ +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from vllm.config import (CacheConfig, DeviceConfig, LoRAConfig, ModelConfig, + VllmConfig) +from vllm.lora.request import LoRARequest +from vllm.sampling_params import SamplingParams +from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs +from vllm.v1.engine.processor import Processor + + +def test_allowed_token_ids_with_lora_vocab(llama_2_7b_base_huggingface_id, + sql_lora_files): + """ + Test that we properly resolve the range of allowed token ids for lora + adapters that define additional tokens. + """ + + # Setup a base model compatible with the sql_lora_files adapter and + # a known number of tokens in the base model. + model_config = ModelConfig( + model=llama_2_7b_base_huggingface_id, + tokenizer=llama_2_7b_base_huggingface_id, + tokenizer_mode="auto", + ) + + vllm_config = VllmConfig( + model_config=model_config, + cache_config=CacheConfig(), + device_config=DeviceConfig(), + lora_config=LoRAConfig(), + ) + + tokenizer = init_tokenizer_from_configs( + model_config=vllm_config.model_config, + scheduler_config=vllm_config.scheduler_config, + lora_config=vllm_config.lora_config) + processor = Processor(vllm_config, tokenizer) + + lora_request = LoRARequest("1", 1, str(sql_lora_files)) + request_id = "1" + prompt = "a prompt" + + # tokens added in the lora adapter should not raise an error + lora_token_ids = [32000, 32001, 32002, 32003] + processor.process_inputs( + request_id, + prompt, + params=SamplingParams(allowed_token_ids=lora_token_ids), + lora_request=lora_request) + + # tokens in the base model should not raise an error + base_token_ids = [1000, 1001, 1002, 1003] + processor.process_inputs( + request_id, + prompt, + params=SamplingParams(allowed_token_ids=base_token_ids), + lora_request=lora_request) + + # tokens not in the lora adapter should raise an error + invalid_token_ids = [35000, 35001, 35002, 35003] + with pytest.raises(ValueError): + processor.process_inputs( + request_id, + prompt, + params=SamplingParams(allowed_token_ids=invalid_token_ids), + lora_request=lora_request) + + # tokens in the lora adapter with no lora request should raise an error + with pytest.raises(ValueError): + processor.process_inputs( + request_id, + prompt, + params=SamplingParams(allowed_token_ids=lora_token_ids), + ) + + +def test_allowed_token_ids_with_lora_adapter_no_vocab( + qwen25vl_base_huggingface_id, qwen25vl_lora_files): + """ + Test that we properly resolve the range of allowed token ids for lora + adapters that do not define additional tokens. + """ + + # Setup a base model compatible with the qwen25vl_lora_files adapter and + # a known number of tokens in the base model. + model_config = ModelConfig( + model=qwen25vl_base_huggingface_id, + tokenizer=qwen25vl_base_huggingface_id, + tokenizer_mode="auto", + ) + + vllm_config = VllmConfig( + model_config=model_config, + cache_config=CacheConfig(), + device_config=DeviceConfig(), + lora_config=LoRAConfig(), + ) + + tokenizer = init_tokenizer_from_configs( + model_config=vllm_config.model_config, + scheduler_config=vllm_config.scheduler_config, + lora_config=vllm_config.lora_config) + processor = Processor(vllm_config, tokenizer) + + lora_request = LoRARequest("1", 1, str(qwen25vl_lora_files)) + request_id = "1" + prompt = "a prompt" + + # tokens in the base model should not raise an error + base_token_ids = [1000, 1001, 1002, 1003] + processor.process_inputs( + request_id, + prompt, + params=SamplingParams(allowed_token_ids=base_token_ids), + lora_request=lora_request) + + # tokens in the base model with no lora request should not raise an error + base_token_ids = [1000, 1001, 1002, 1003] + processor.process_inputs( + request_id, + prompt, + params=SamplingParams(allowed_token_ids=base_token_ids), + ) + + # tokens not in the base model should raise an error + invalid_token_ids = [200000, 200001, 200002, 200003] + with pytest.raises(ValueError): + processor.process_inputs( + request_id, + prompt, + params=SamplingParams(allowed_token_ids=invalid_token_ids), + lora_request=lora_request) diff --git a/tests/lora/test_lora_huggingface.py b/tests/lora/test_lora_huggingface.py index 0875128c4ff1baa4e5c987ce00ea0205ef88bd24..90498c47fb10431e1727a5dc5cf1f6b16f234226 100644 --- a/tests/lora/test_lora_huggingface.py +++ b/tests/lora/test_lora_huggingface.py @@ -30,7 +30,7 @@ def test_load_checkpoints_from_huggingface(lora_fixture_name, request): lora_path = get_adapter_absolute_path(lora_name) - # lora loading should work for either absolute path and hugggingface id. + # lora loading should work for either absolute path and huggingface id. peft_helper = PEFTHelper.from_local_dir(lora_path, 4096) lora_model = LoRAModel.from_local_checkpoint( lora_path, diff --git a/tests/lora/test_utils.py b/tests/lora/test_utils.py index 67f3866beff55c9b6ebfb3601bf2c129cd934fc2..0d4e0bf681f2c05ddc23af1d8d3129cf163ed51c 100644 --- a/tests/lora/test_utils.py +++ b/tests/lora/test_utils.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 from collections import OrderedDict +from typing import NamedTuple, Optional from unittest.mock import patch import pytest @@ -9,52 +10,96 @@ from torch import nn from vllm.lora.utils import (get_adapter_absolute_path, parse_fine_tuned_lora_name, replace_submodule) +from vllm.model_executor.models.utils import WeightsMapper + + +class LoRANameParserTestConfig(NamedTuple): + name: str + module_name: str + is_lora_a: bool + is_bias: bool + weights_mapper: Optional[WeightsMapper] = None def test_parse_fine_tuned_lora_name_valid(): - fixture = { - ("base_model.model.lm_head.lora_A.weight", "lm_head", True, False), - ("base_model.model.lm_head.lora_B.weight", "lm_head", False, False), - ( + fixture = [ + LoRANameParserTestConfig("base_model.model.lm_head.lora_A.weight", + "lm_head", True, False), + LoRANameParserTestConfig("base_model.model.lm_head.lora_B.weight", + "lm_head", False, False), + LoRANameParserTestConfig( "base_model.model.model.embed_tokens.lora_embedding_A", "model.embed_tokens", True, False, ), - ( + LoRANameParserTestConfig( "base_model.model.model.embed_tokens.lora_embedding_B", "model.embed_tokens", False, False, ), - ( + LoRANameParserTestConfig( "base_model.model.model.layers.9.mlp.down_proj.lora_A.weight", "model.layers.9.mlp.down_proj", True, False, ), - ( + LoRANameParserTestConfig( "base_model.model.model.layers.9.mlp.down_proj.lora_B.weight", "model.layers.9.mlp.down_proj", False, False, ), - ( + LoRANameParserTestConfig( "language_model.layers.9.mlp.down_proj.lora_A.weight", "language_model.layers.9.mlp.down_proj", True, False, ), - ( + LoRANameParserTestConfig( "language_model.layers.9.mlp.down_proj.lora_B.weight", "language_model.layers.9.mlp.down_proj", False, False, ), - } - for name, module_name, is_lora_a, is_bias in fixture: + # Test with WeightsMapper + LoRANameParserTestConfig( + "base_model.model.model.layers.9.mlp.down_proj.lora_A.weight", + "language_model.model.layers.9.mlp.down_proj", + True, + False, + weights_mapper=WeightsMapper( + orig_to_new_prefix={"model.": "language_model.model."}), + ), + LoRANameParserTestConfig( + "base_model.model.model.layers.9.mlp.down_proj.lora_B.weight", + "language_model.model.layers.9.mlp.down_proj", + False, + False, + weights_mapper=WeightsMapper( + orig_to_new_prefix={"model.": "language_model.model."}), + ), + LoRANameParserTestConfig( + "model.layers.9.mlp.down_proj.lora_A.weight", + "language_model.model.layers.9.mlp.down_proj", + True, + False, + weights_mapper=WeightsMapper( + orig_to_new_prefix={"model.": "language_model.model."}), + ), + LoRANameParserTestConfig( + "model.layers.9.mlp.down_proj.lora_B.weight", + "language_model.model.layers.9.mlp.down_proj", + False, + False, + weights_mapper=WeightsMapper( + orig_to_new_prefix={"model.": "language_model.model."}), + ), + ] + for name, module_name, is_lora_a, is_bias, weights_mapper in fixture: assert (module_name, is_lora_a, - is_bias) == parse_fine_tuned_lora_name(name) + is_bias) == parse_fine_tuned_lora_name(name, weights_mapper) def test_parse_fine_tuned_lora_name_invalid(): diff --git a/tests/lora/test_worker.py b/tests/lora/test_worker.py index 30b74ce3ef70ed6f9ed6b328853cb5a26ad13b89..e5ae660af1400aa6bf7b4e680116681a8bdc57b1 100644 --- a/tests/lora/test_worker.py +++ b/tests/lora/test_worker.py @@ -58,13 +58,19 @@ def test_worker_apply_lora(sql_lora_files): download_dir=None, load_format="dummy", ), - parallel_config=ParallelConfig(1, 1, False), + parallel_config=ParallelConfig( + pipeline_parallel_size=1, + tensor_parallel_size=1, + data_parallel_size=1, + ), scheduler_config=SchedulerConfig("generate", 32, 32, 32), device_config=DeviceConfig("cuda"), - cache_config=CacheConfig(block_size=16, - gpu_memory_utilization=1., - swap_space=0, - cache_dtype="auto"), + cache_config=CacheConfig( + block_size=16, + gpu_memory_utilization=1.0, + swap_space=0, + cache_dtype="auto", + ), lora_config=LoRAConfig(max_lora_rank=8, max_cpu_loras=32, max_loras=32), ) diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py index 2d9cf1d48fd5fefad6ccb234df6e286bb1d9c53c..e957db5b3f16a0425cdf2ce8a2c6c302dc35b05d 100644 --- a/tests/model_executor/test_enabled_custom_ops.py +++ b/tests/model_executor/test_enabled_custom_ops.py @@ -1,21 +1,22 @@ # SPDX-License-Identifier: Apache-2.0 import pytest +import torch from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.activation import (GeluAndMul, ReLUSquaredActivation, SiluAndMul) -from vllm.model_executor.layers.fused_moe.fused_moe import ( - dispatch_fused_experts_func, dispatch_topk_func, - torch_vllm_inplace_fused_experts, torch_vllm_outplace_fused_experts, - vllm_topk_softmax) +from vllm.model_executor.layers.fused_moe.fused_moe import (dispatch_topk_func, + vllm_topk_softmax) from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( is_rocm_aiter_moe_enabled) from vllm.model_executor.layers.layernorm import ( RMSNorm, dispatch_cuda_rmsnorm_func, fused_add_rms_norm, rms_norm, rocm_aiter_fused_add_rms_norm, rocm_aiter_rms_norm) +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + cutlass_scaled_mm, dispatch_w8a8_blockscale_func, w8a8_block_fp8_matmul) from vllm.platforms import current_platform @@ -98,35 +99,45 @@ def test_enabled_ops_invalid(env: str): RMSNorm(1024).enabled() +@pytest.mark.skipif( + not current_platform.is_rocm() or not current_platform.is_fp8_fnuz(), + reason="AITER is a feature exclusive for ROCm and FP8_FNUZ") +@pytest.mark.parametrize("use_cutlass", [True, False]) @pytest.mark.parametrize("use_rocm_aiter", ["0", "1"]) -def test_topk_dispatch(use_rocm_aiter: str, monkeypatch): +@pytest.mark.parametrize("use_rocm_aiter_gemm_w8a8_blockscale", ["0", "1"]) +def test_w8a8_blockscale_dispatch(use_cutlass: bool, use_rocm_aiter: str, + use_rocm_aiter_gemm_w8a8_blockscale: str, + monkeypatch): + monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter) - topk_func = dispatch_topk_func() - is_rocm_aiter_moe_enabled.cache_clear() - if current_platform.is_rocm() and int(use_rocm_aiter): - from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - rocm_aiter_topk_softmax) - assert topk_func == rocm_aiter_topk_softmax + monkeypatch.setenv("VLLM_ROCM_USE_AITER_LINEAR", + use_rocm_aiter_gemm_w8a8_blockscale) + + use_aiter_and_is_supported = (bool(int(use_rocm_aiter)) and bool( + int(use_rocm_aiter_gemm_w8a8_blockscale))) + block_scale_func = dispatch_w8a8_blockscale_func( + use_cutlass, use_aiter_and_is_supported=use_aiter_and_is_supported) + if use_cutlass: + assert block_scale_func == cutlass_scaled_mm + elif current_platform.is_rocm() and int(use_rocm_aiter) and int( + use_rocm_aiter_gemm_w8a8_blockscale): + assert block_scale_func == ( + torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale) else: - assert topk_func == vllm_topk_softmax + assert block_scale_func == w8a8_block_fp8_matmul @pytest.mark.parametrize("use_rocm_aiter", ["0", "1"]) -@pytest.mark.parametrize("inplace", [True, False]) -def test_fused_experts_dispatch(use_rocm_aiter: str, inplace: bool, - monkeypatch): - +def test_topk_dispatch(use_rocm_aiter: str, monkeypatch): monkeypatch.setenv("VLLM_ROCM_USE_AITER", use_rocm_aiter) + topk_func = dispatch_topk_func() is_rocm_aiter_moe_enabled.cache_clear() - fused_experts_func = dispatch_fused_experts_func(inplace) if current_platform.is_rocm() and int(use_rocm_aiter): from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( - rocm_aiter_fused_experts) - assert fused_experts_func == rocm_aiter_fused_experts - elif inplace: - assert fused_experts_func == torch_vllm_inplace_fused_experts + rocm_aiter_topk_softmax) + assert topk_func == rocm_aiter_topk_softmax else: - assert fused_experts_func == torch_vllm_outplace_fused_experts + assert topk_func == vllm_topk_softmax @pytest.mark.parametrize("add_residual", [True, False]) diff --git a/tests/model_executor/test_guided_processors.py b/tests/model_executor/test_guided_processors.py index 59da575e37b18ab6bf9cf9634c9c1745b2ca0ba7..6cd966f84802bc703d9bef89228d77927111869c 100644 --- a/tests/model_executor/test_guided_processors.py +++ b/tests/model_executor/test_guided_processors.py @@ -202,12 +202,15 @@ def test_multiple_guided_options_not_allowed(sample_json_schema, sample_regex): def test_guided_decoding_backend_options(): """Test backend-specific options""" - params = GuidedDecodingParams( - backend="xgrammar:option-1,option-2,option-3") - assert params.backend_options() == ["option-1", "option-2", "option-3"] - - no_fallback = GuidedDecodingParams(backend="xgrammar:option-1,no-fallback") - assert no_fallback.no_fallback() + with pytest.warns(DeprecationWarning): + guided_decoding_params = GuidedDecodingParams( + backend= + "xgrammar:no-fallback,disable-any-whitespace,no-additional-properties" + ) + assert guided_decoding_params.backend == "xgrammar" + assert guided_decoding_params.disable_fallback + assert guided_decoding_params.disable_any_whitespace + assert guided_decoding_params.disable_additional_properties def test_pickle_xgrammar_tokenizer_data(): diff --git a/tests/model_executor/weight_utils.py b/tests/model_executor/weight_utils.py index 11dfe4d4995d51bf881e6941413561fbaeef906a..bdaba22c3c7a8962485bf976d671f8d6b077a0af 100644 --- a/tests/model_executor/weight_utils.py +++ b/tests/model_executor/weight_utils.py @@ -20,11 +20,11 @@ def test_hf_transfer_auto_activation(): try: # enable hf hub transfer if available import hf_transfer # type: ignore # noqa - HF_TRANFER_ACTIVE = True + HF_TRANSFER_ACTIVE = True except ImportError: - HF_TRANFER_ACTIVE = False + HF_TRANSFER_ACTIVE = False assert (huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER == - HF_TRANFER_ACTIVE) + HF_TRANSFER_ACTIVE) def test_download_weights_from_hf(): diff --git a/tests/models/embedding/utils.py b/tests/models/embedding/utils.py deleted file mode 100644 index 6d4df2c265c4d7f168f1c9467a698ab090a653fc..0000000000000000000000000000000000000000 --- a/tests/models/embedding/utils.py +++ /dev/null @@ -1,66 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -from collections.abc import Sequence -from typing import NamedTuple, Optional - -import torch -import torch.nn.functional as F - - -def check_embeddings_close( - *, - embeddings_0_lst: Sequence[list[float]], - embeddings_1_lst: Sequence[list[float]], - name_0: str, - name_1: str, - tol: float = 1e-3, -) -> None: - assert len(embeddings_0_lst) == len(embeddings_1_lst) - - for prompt_idx, (embeddings_0, embeddings_1) in enumerate( - zip(embeddings_0_lst, embeddings_1_lst)): - assert len(embeddings_0) == len(embeddings_1), ( - f"Length mismatch: {len(embeddings_0)} vs. {len(embeddings_1)}") - - sim = F.cosine_similarity(torch.tensor(embeddings_0), - torch.tensor(embeddings_1), - dim=0) - - fail_msg = (f"Test{prompt_idx}:" - f"\n{name_0}:\t{embeddings_0[:16]!r}" - f"\n{name_1}:\t{embeddings_1[:16]!r}") - - assert sim >= 1 - tol, fail_msg - - -def matryoshka_fy(tensor, dimensions): - tensor = torch.tensor(tensor) - tensor = tensor[..., :dimensions] - tensor = F.normalize(tensor, p=2, dim=1) - return tensor - - -class EmbedModelInfo(NamedTuple): - name: str - is_matryoshka: bool - matryoshka_dimensions: Optional[list[int]] = None - architecture: str = "" - enable_test: bool = True - - -def correctness_test(hf_model, - inputs, - vllm_outputs: Sequence[list[float]], - dimensions: Optional[int] = None): - - hf_outputs = hf_model.encode(inputs) - if dimensions: - hf_outputs = matryoshka_fy(hf_outputs, dimensions) - - check_embeddings_close( - embeddings_0_lst=hf_outputs, - embeddings_1_lst=vllm_outputs, - name_0="hf", - name_1="vllm", - tol=1e-2, - ) diff --git a/tests/models/encoder_decoder/vision_language/__init__.py b/tests/models/encoder_decoder/vision_language/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/tests/models/encoder_decoder/vision_language/test_broadcast.py b/tests/models/encoder_decoder/vision_language/test_broadcast.py deleted file mode 100644 index 8d986414eec863998b6598789415b1198e684616..0000000000000000000000000000000000000000 --- a/tests/models/encoder_decoder/vision_language/test_broadcast.py +++ /dev/null @@ -1,37 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 - -import pytest - -from ....utils import multi_gpu_test - - -@multi_gpu_test(num_gpus=2) -@pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"]) -@pytest.mark.parametrize("model", [ - "meta-llama/Llama-3.2-11B-Vision-Instruct", -]) -def test_models(hf_runner, vllm_runner, image_assets, - distributed_executor_backend, model) -> None: - - dtype = "half" - max_tokens = 5 - num_logprobs = 5 - tensor_parallel_size = 2 - - if model.startswith("meta-llama/Llama-3.2-11B-Vision-Instruct"): - from .test_mllama import models, run_test - else: - raise NotImplementedError(f"Unsupported model: {model}") - - run_test( - hf_runner, - vllm_runner, - image_assets, - model=models[0], - size_factors=[0.25, 0.5, 1.0], - dtype=dtype, - max_tokens=max_tokens, - num_logprobs=num_logprobs, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, - ) diff --git a/tests/models/decoder_only/__init__.py b/tests/models/language/__init__.py similarity index 100% rename from tests/models/decoder_only/__init__.py rename to tests/models/language/__init__.py diff --git a/tests/models/decoder_only/audio_language/__init__.py b/tests/models/language/generation/__init__.py similarity index 100% rename from tests/models/decoder_only/audio_language/__init__.py rename to tests/models/language/generation/__init__.py diff --git a/tests/models/encoder_decoder/language/test_bart.py b/tests/models/language/generation/test_bart.py similarity index 98% rename from tests/models/encoder_decoder/language/test_bart.py rename to tests/models/language/generation/test_bart.py index e8070d28befa6ab868b6feee5522b8b0c1721cd5..8ab0167dc771d2b8d50b6bedcb33de01cfa23acb 100644 --- a/tests/models/encoder_decoder/language/test_bart.py +++ b/tests/models/language/generation/test_bart.py @@ -1,8 +1,4 @@ # SPDX-License-Identifier: Apache-2.0 -"""Compare the outputs of HF and vLLM for BART models using greedy sampling. - -Run `pytest tests/models/encoder_decoder/language/test_bart.py`. -""" from typing import Optional import pytest diff --git a/tests/models/decoder_only/language/test_models.py b/tests/models/language/generation/test_common.py similarity index 77% rename from tests/models/decoder_only/language/test_models.py rename to tests/models/language/generation/test_common.py index d35d87459cd98212509947d9d15d5fc5191d85d2..05dd18fbdf8b341538add0f68a3aa77b77ab53f3 100644 --- a/tests/models/decoder_only/language/test_models.py +++ b/tests/models/language/generation/test_common.py @@ -1,8 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 -"""Compare the outputs of HF and vLLM when using greedy sampling. - -Run `pytest tests/models/test_models.py`. -""" +import os +from typing import Optional import pytest import torch @@ -29,7 +27,8 @@ AITER_MODEL_LIST = [ "openbmb/MiniCPM3-4B", "Qwen/Qwen-7B-Chat", "Qwen/Qwen2.5-0.5B-Instruct", - "ehristoforu/Falcon3-MoE-2x7B-Insruct", + "TitanML/tiny-mixtral", + "Qwen/Qwen3-8B", ] @@ -80,12 +79,14 @@ AITER_MODEL_LIST = [ "Qwen/Qwen2.5-0.5B-Instruct", # qwen2 marks=[pytest.mark.core_model], ), + pytest.param( + "Qwen/Qwen3-8B", # qwen (text-only) + ), pytest.param("stabilityai/stablelm-3b-4e1t"), # stablelm pytest.param("bigcode/starcoder2-3b"), # starcoder2 pytest.param( - "ehristoforu/Falcon3-MoE-2x7B-Insruct", # mixtral - marks=[pytest.mark.cpu_model, - large_gpu_mark(min_gb=48)], + "TitanML/tiny-mixtral", # mixtral + marks=[pytest.mark.cpu_model], ) ]) @pytest.mark.parametrize("max_tokens", [32]) @@ -112,19 +113,38 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str, # in parts of the operators pytest.skip(f"Skipping '{model}' model test with AITER kernel.") + use_prompt_embeds = os.getenv("VLLM_USE_V1") == "0" + with hf_runner(model) as hf_model: hf_outputs = hf_model.generate_greedy_logprobs_limit( example_prompts, max_tokens, num_logprobs) + prompt_embeds: Optional[list[torch.Tensor]] = ([] if use_prompt_embeds + else None) + + prompt_token_ids = [] + for prompt in example_prompts: + token_ids = hf_model.tokenizer(prompt, + return_tensors="pt").input_ids.to( + hf_model.model.device) + prompt_token_ids.append(token_ids) + if prompt_embeds is not None: + prompt_embeds.append(hf_model.model.get_input_embeddings()( + token_ids).squeeze(0)) + with vllm_runner( model, tokenizer_name=model_info.tokenizer or model, tokenizer_mode=model_info.tokenizer_mode, trust_remote_code=model_info.trust_remote_code, max_num_seqs=2, + enable_prompt_embeds=use_prompt_embeds, ) as vllm_model: vllm_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) + if prompt_embeds is not None: + vllm_outputs_from_embeds = vllm_model.generate_greedy_logprobs( + prompt_embeds, max_tokens, num_logprobs) check_logprobs_close( outputs_0_lst=hf_outputs, @@ -132,6 +152,14 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str, name_0="hf", name_1="vllm", ) + if prompt_embeds is not None: + check_logprobs_close( + outputs_0_lst=vllm_outputs, + outputs_1_lst=vllm_outputs_from_embeds, + name_0="vllm", + name_1="vllm_from_embeds", + ) + if use_rocm_aiter: # this is to ensure that vllm engine # has deallocated the memory before running the next diff --git a/tests/models/decoder_only/language/test_granite.py b/tests/models/language/generation/test_granite.py similarity index 89% rename from tests/models/decoder_only/language/test_granite.py rename to tests/models/language/generation/test_granite.py index 119b79d64c9696d6727d83f0dd99c69fc2138975..f381c34f44b8ca546f00f268ffade99ee303c0b5 100644 --- a/tests/models/decoder_only/language/test_granite.py +++ b/tests/models/language/generation/test_granite.py @@ -1,8 +1,4 @@ # SPDX-License-Identifier: Apache-2.0 -"""Compare the outputs of HF and vLLM for Granite models using greedy sampling. - -Run `pytest tests/models/test_granite.py`. -""" import pytest from ...utils import check_logprobs_close diff --git a/tests/models/language/generation/test_granitemoehybrid.py b/tests/models/language/generation/test_granitemoehybrid.py new file mode 100644 index 0000000000000000000000000000000000000000..da3f5e1100bfd37b6d419f7707ae4d795857d14f --- /dev/null +++ b/tests/models/language/generation/test_granitemoehybrid.py @@ -0,0 +1,41 @@ +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from ...utils import check_logprobs_close + +# Path of the checkpoints +MODELS = [ + "ibm-granite/granite-4.0-tiny-preview", +] + + +@pytest.mark.skip( + reason="Granite 4.0 is not yet available in huggingface transformers") +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_model_equivalence_to_hf_greedy( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + num_logprobs: int, +): + with vllm_runner(model, dtype=dtype) as vllm_model: + vllm_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + + with hf_runner(model, dtype=dtype) as hf_model: + hf_outputs = hf_model.generate_greedy_logprobs_limit( + example_prompts, max_tokens, num_logprobs) + + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) diff --git a/tests/models/decoder_only/language/test_hybrid.py b/tests/models/language/generation/test_hybrid.py similarity index 94% rename from tests/models/decoder_only/language/test_hybrid.py rename to tests/models/language/generation/test_hybrid.py index 5931c25b8d8082087d9385bacbb9124c7c097cb0..9b7a42acece5974df5fd2d5ee2e8b19aa0f43a39 100644 --- a/tests/models/decoder_only/language/test_hybrid.py +++ b/tests/models/language/generation/test_hybrid.py @@ -23,12 +23,15 @@ SSM_MODELS = [ HYBRID_MODELS = [ "ai21labs/Jamba-tiny-dev", + # NOTE: ibm-granite/granite-4.0-tiny-preview are skipped currently as + # it is not yet available in huggingface transformers + # "ibm-granite/granite-4.0-tiny-preview", # NOTE: Running Plamo2 in transformers implementation requires to install # causal-conv1d package, which is not listed as a test dependency as it's # not compatible with pip-compile. "pfnet/plamo-2-1b", "Zyphra/Zamba2-1.2B-instruct", - "ibm-ai-platform/Bamba-9B", + "hmellor/bamba-tiny-random", ] # Avoid OOM @@ -289,23 +292,25 @@ def test_multistep_correctness( @multi_gpu_test(num_gpus=2) @pytest.mark.parametrize("model", [SSM_MODELS[0], HYBRID_MODELS[0]]) @pytest.mark.parametrize("max_tokens", [64]) -def test_hybrid_distributed_produces_identical_generation( +@pytest.mark.parametrize("num_logprobs", [5]) +def test_distributed_correctness( vllm_runner, example_prompts, model: str, max_tokens: int, + num_logprobs: int, ) -> None: - with vllm_runner(model, tensor_parallel_size=2, + with vllm_runner(model, tensor_parallel_size=1, max_num_seqs=2) as vllm_model: - vllm_outputs_tp_2 = vllm_model.generate_greedy(example_prompts, - max_tokens) + vllm_outputs_tp_1 = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) - with vllm_runner(model, tensor_parallel_size=1, + with vllm_runner(model, tensor_parallel_size=2, max_num_seqs=2) as vllm_model: - vllm_outputs_tp_1 = vllm_model.generate_greedy(example_prompts, - max_tokens) + vllm_outputs_tp_2 = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) - check_outputs_equal( + check_logprobs_close( outputs_0_lst=vllm_outputs_tp_1, outputs_1_lst=vllm_outputs_tp_2, name_0="vllm_tp_1", diff --git a/tests/models/decoder_only/language/test_mistral.py b/tests/models/language/generation/test_mistral.py similarity index 98% rename from tests/models/decoder_only/language/test_mistral.py rename to tests/models/language/generation/test_mistral.py index 79778072cc8b31f4b518b3ae84c5c18ac08fd1fa..c1b612ae213b97b2035007584fcbf7b9f77a90e4 100644 --- a/tests/models/decoder_only/language/test_mistral.py +++ b/tests/models/language/generation/test_mistral.py @@ -1,8 +1,4 @@ # SPDX-License-Identifier: Apache-2.0 -"""Compare the outputs of HF and vLLM for Mistral models using greedy sampling. - -Run `pytest tests/models/test_mistral.py`. -""" import copy import json diff --git a/tests/models/decoder_only/language/test_phimoe.py b/tests/models/language/generation/test_phimoe.py similarity index 96% rename from tests/models/decoder_only/language/test_phimoe.py rename to tests/models/language/generation/test_phimoe.py index f9757d6ac295ebb3ab740434713bd1a6c897576e..603ca1cb12a5b61f7bf75d327a6d5b750c12745c 100644 --- a/tests/models/decoder_only/language/test_phimoe.py +++ b/tests/models/language/generation/test_phimoe.py @@ -1,8 +1,4 @@ # SPDX-License-Identifier: Apache-2.0 -"""Compare the outputs of HF and vLLM for moe models using greedy sampling. - -Run `pytest tests/models/test_phimoe.py`. -""" import pytest import torch diff --git a/tests/models/decoder_only/language/__init__.py b/tests/models/language/pooling/__init__.py similarity index 100% rename from tests/models/decoder_only/language/__init__.py rename to tests/models/language/pooling/__init__.py diff --git a/tests/models/language/pooling/mteb_utils.py b/tests/models/language/pooling/mteb_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7de2a9af2f2edfcc92bdebc0915ace9cfd422fa9 --- /dev/null +++ b/tests/models/language/pooling/mteb_utils.py @@ -0,0 +1,118 @@ +# SPDX-License-Identifier: Apache-2.0 +import math +from collections.abc import Sequence + +import mteb +import numpy as np +import pytest + +from tests.models.utils import EmbedModelInfo +from vllm.model_executor.model_loader.utils import set_default_torch_dtype + +# Most models on the STS12 task (See #17175): +# - Model implementation and minor changes in tensor dtype +# results in differences less than 1e-4 +# - Different model results in differences more than 1e-3 +# 1e-4 is a good tolerance threshold +MTEB_EMBED_TASKS = ["STS12"] +MTEB_EMBED_TOL = 1e-4 + + +class VllmMtebEncoder(mteb.Encoder): + + def __init__(self, vllm_model): + super().__init__() + self.model = vllm_model + self.rng = np.random.default_rng(seed=42) + + def encode( + self, + sentences: Sequence[str], + *args, + **kwargs, + ) -> np.ndarray: + # Hoping to discover potential scheduling + # issues by randomizing the order. + r = self.rng.permutation(len(sentences)) + sentences = [sentences[i] for i in r] + outputs = self.model.encode(sentences, use_tqdm=False) + embeds = np.array(outputs) + embeds = embeds[np.argsort(r)] + return embeds + + +class OpenAIClientMtebEncoder(mteb.Encoder): + + def __init__(self, model_name: str, client): + super().__init__() + self.model_name = model_name + self.client = client + self.rng = np.random.default_rng(seed=42) + + def encode(self, sentences: Sequence[str], *args, **kwargs) -> np.ndarray: + # Hoping to discover potential scheduling + # issues by randomizing the order. + r = self.rng.permutation(len(sentences)) + sentences = [sentences[i] for i in r] + + embeddings = self.client.embeddings.create(model=self.model_name, + input=sentences) + outputs = [d.embedding for d in embeddings.data] + embeds = np.array(outputs) + embeds = embeds[np.argsort(r)] + return embeds + + +def run_mteb_embed_task(encoder, tasks): + tasks = mteb.get_tasks(tasks=tasks) + evaluation = mteb.MTEB(tasks=tasks) + results = evaluation.run(encoder, verbosity=0, output_folder=None) + + main_score = results[0].scores["test"][0]["main_score"] + return main_score + + +def run_mteb_embed_task_st(model_name, tasks): + from sentence_transformers import SentenceTransformer + model = SentenceTransformer(model_name) + return run_mteb_embed_task(model, tasks) + + +def mteb_test_embed_models(hf_runner, + vllm_runner, + model_info: EmbedModelInfo, + vllm_extra_kwargs=None): + if not model_info.enable_test: + # A model family has many models with the same architecture, + # and we don't need to test each one. + pytest.skip("Skipping test.") + + vllm_extra_kwargs = vllm_extra_kwargs or {} + + with vllm_runner(model_info.name, + task="embed", + max_model_len=None, + dtype=model_info.dtype, + **vllm_extra_kwargs) as vllm_model: + + if model_info.architecture: + assert (model_info.architecture + in vllm_model.model.llm_engine.model_config.architectures) + + vllm_main_score = run_mteb_embed_task(VllmMtebEncoder(vllm_model), + MTEB_EMBED_TASKS) + vllm_dtype = vllm_model.model.llm_engine.model_config.dtype + model_dtype = getattr( + vllm_model.model.llm_engine.model_config.hf_config, "torch_dtype", + vllm_dtype) + + with set_default_torch_dtype(model_dtype) and hf_runner( + model_info.name, is_sentence_transformer=True, + dtype=model_dtype) as hf_model: + st_main_score = run_mteb_embed_task(hf_model, MTEB_EMBED_TASKS) + + print("VLLM:", vllm_dtype, vllm_main_score) + print("SentenceTransformer:", model_dtype, st_main_score) + print("Difference:", st_main_score - vllm_main_score) + + assert math.isclose(st_main_score, vllm_main_score, rel_tol=MTEB_EMBED_TOL) diff --git a/tests/models/embedding/language/test_cls_models.py b/tests/models/language/pooling/test_classification.py similarity index 91% rename from tests/models/embedding/language/test_cls_models.py rename to tests/models/language/pooling/test_classification.py index 6a3cd8a5c594e70ca978593028de07260ad5546d..44af3df08a8673370006208e6c67119c7be4a1fe 100644 --- a/tests/models/embedding/language/test_cls_models.py +++ b/tests/models/language/pooling/test_classification.py @@ -1,8 +1,4 @@ # SPDX-License-Identifier: Apache-2.0 -"""Compare the classification outputs of HF and vLLM models. - -Run `pytest tests/models/test_cls_models.py`. -""" import pytest import torch from transformers import AutoModelForSequenceClassification @@ -19,7 +15,7 @@ from vllm.platforms import current_platform ) @pytest.mark.parametrize("dtype", ["half"] if current_platform.is_rocm() else ["float"]) -def test_classification_models( +def test_models( hf_runner, vllm_runner, example_prompts, diff --git a/tests/models/embedding/language/test_embedding.py b/tests/models/language/pooling/test_embedding.py similarity index 94% rename from tests/models/embedding/language/test_embedding.py rename to tests/models/language/pooling/test_embedding.py index 5deb35fa321089f46aa56267cc3a2d0422034341..9db385e77bdbb75add7e1191cbbf4540cee7436a 100644 --- a/tests/models/embedding/language/test_embedding.py +++ b/tests/models/language/pooling/test_embedding.py @@ -1,14 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 -"""Compare the embedding outputs of HF and vLLM models. - -Run `pytest tests/models/embedding/language/test_embedding.py`. -""" import pytest from vllm.config import PoolerConfig from vllm.platforms import current_platform -from ..utils import check_embeddings_close +from ...utils import check_embeddings_close @pytest.mark.parametrize( diff --git a/tests/models/embedding/language/test_gritlm.py b/tests/models/language/pooling/test_gritlm.py similarity index 60% rename from tests/models/embedding/language/test_gritlm.py rename to tests/models/language/pooling/test_gritlm.py index 87a1dde9381fdbb60e3aff372c7fdd3a0319a2a7..7dd3c8a4e79e2b34d8aed724d48e6698e5462c31 100644 --- a/tests/models/embedding/language/test_gritlm.py +++ b/tests/models/language/pooling/test_gritlm.py @@ -7,12 +7,10 @@ from array import array import openai import pytest -import pytest_asyncio from scipy.spatial.distance import cosine -import vllm -import vllm.config -from vllm.utils import STR_BACKEND_ENV_VAR +from vllm import LLM, SamplingParams +from vllm.config import ModelConfig from ....utils import RemoteOpenAIServer @@ -31,73 +29,45 @@ def _arr(arr): return array("i", arr) -def test_find_array(monkeypatch: pytest.MonkeyPatch): - # GritLM embedding implementation is only supported by XFormers backend. - with monkeypatch.context() as m: - m.setenv(STR_BACKEND_ENV_VAR, "XFORMERS") - - from vllm.model_executor.models.gritlm import GritLMPooler - - # Create an LLM object to get the model config. - llm = vllm.LLM(MODEL_NAME, task="embed", max_model_len=MAX_MODEL_LEN) - pooler = GritLMPooler(model_config=llm.llm_engine.model_config) - - arr = _arr([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) - - assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=0) == 3 - assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=1) == 3 - assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=5) == -1 - assert pooler._find_array(arr, _arr([3, 5]), start_idx=0) == -1 - - with pytest.raises(ValueError): - pooler._find_array(arr, _arr([3, 4, 5]), start_idx=-1) - - -@pytest.fixture(scope="module") -def server_embedding(): - # GritLM embedding implementation is only supported by XFormers backend. - args = ["--task", "embed", "--max_model_len", str(MAX_MODEL_LEN)] - with pytest.MonkeyPatch.context() as m: - m.setenv(STR_BACKEND_ENV_VAR, "XFORMERS") - with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: - yield remote_server - - -@pytest.fixture(scope="module") -def server_generate(): - args = ["--task", "generate", "--max_model_len", str(MAX_MODEL_LEN)] - with pytest.MonkeyPatch.context() as m: - m.setenv(STR_BACKEND_ENV_VAR, "XFORMERS") - with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: - yield remote_server +def test_find_array(): + from vllm.model_executor.models.gritlm import GritLMPooler + model_config = ModelConfig( + MODEL_NAME, + task="embed", + tokenizer=MODEL_NAME, + tokenizer_mode="auto", + trust_remote_code=False, + dtype="bfloat16", + seed=0, + ) + pooler = GritLMPooler(model_config=model_config) -@pytest_asyncio.fixture -async def client_embedding(server_embedding: RemoteOpenAIServer): - async with server_embedding.get_async_client() as async_client: - yield async_client + arr = _arr([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]) + assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=0) == 3 + assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=1) == 3 + assert pooler._find_array(arr, _arr([3, 4, 5]), start_idx=5) == -1 + assert pooler._find_array(arr, _arr([3, 5]), start_idx=0) == -1 -@pytest_asyncio.fixture -async def client_generate(server_generate: RemoteOpenAIServer): - async with server_generate.get_async_client() as async_client: - yield async_client + with pytest.raises(ValueError): + pooler._find_array(arr, _arr([3, 4, 5]), start_idx=-1) def run_llm_encode( - llm: vllm.LLM, + llm: LLM, queries: list[str], instruction: str, -) -> list[float]: - outputs = llm.encode([instruction + q for q in queries], ) +) -> list[list[float]]: + outputs = llm.embed([instruction + q for q in queries]) return [output.outputs.embedding for output in outputs] async def run_client_embeddings( - client: vllm.LLM, + client: openai.AsyncOpenAI, queries: list[str], instruction: str, -) -> list[float]: +) -> list[list[float]]: outputs = await client.embeddings.create( model=MODEL_NAME, input=[instruction + q for q in queries], @@ -132,7 +102,7 @@ def get_test_data(): return queries, q_instruction, documents, d_instruction -def validate_embed_output(q_rep: list[float], d_rep: list[float]): +def validate_embed_output(q_rep: list[list[float]], d_rep: list[list[float]]): cosine_sim_q0_d0 = 1 - cosine(q_rep[0], d_rep[0]) assert math.isclose(cosine_sim_q0_d0, 0.609, abs_tol=0.001) @@ -143,17 +113,18 @@ def validate_embed_output(q_rep: list[float], d_rep: list[float]): assert math.isclose(cosine_sim_q1_d0, 0.120, abs_tol=0.001) cosine_sim_q1_d1 = 1 - cosine(q_rep[1], d_rep[1]) - assert math.isclose(cosine_sim_q1_d1, 0.532, abs_tol=0.001) - + assert math.isclose(cosine_sim_q1_d1, 0.534, abs_tol=0.001) -def test_gritlm_offline_embedding(monkeypatch: pytest.MonkeyPatch): - # GritLM embedding implementation is only supported by XFormers backend. - with monkeypatch.context() as m: - m.setenv(STR_BACKEND_ENV_VAR, "XFORMERS") - queries, q_instruction, documents, d_instruction = get_test_data() +def test_gritlm_offline_embedding(vllm_runner): + queries, q_instruction, documents, d_instruction = get_test_data() - llm = vllm.LLM(MODEL_NAME, task="embed", max_model_len=MAX_MODEL_LEN) + with vllm_runner( + MODEL_NAME, + task="embed", + max_model_len=MAX_MODEL_LEN, + ) as vllm_model: + llm = vllm_model.model d_rep = run_llm_encode( llm, @@ -166,47 +137,62 @@ def test_gritlm_offline_embedding(monkeypatch: pytest.MonkeyPatch): q_instruction, ) - validate_embed_output(q_rep, d_rep) + validate_embed_output(q_rep, d_rep) @pytest.mark.asyncio -async def test_gritlm_api_server_embedding( - client_embedding: openai.AsyncOpenAI, ): +async def test_gritlm_api_server_embedding(): queries, q_instruction, documents, d_instruction = get_test_data() - d_rep = await run_client_embeddings( - client_embedding, - documents, - d_instruction, - ) - q_rep = await run_client_embeddings( - client_embedding, - queries, - q_instruction, - ) + args = ["--task", "embed", "--max_model_len", str(MAX_MODEL_LEN)] + + with RemoteOpenAIServer(MODEL_NAME, args) as server: + client_embedding = server.get_async_client() + + d_rep = await run_client_embeddings( + client_embedding, + documents, + d_instruction, + ) + q_rep = await run_client_embeddings( + client_embedding, + queries, + q_instruction, + ) validate_embed_output(q_rep, d_rep) -def test_gritlm_offline_gen(): +def test_gritlm_offline_generate(monkeypatch: pytest.MonkeyPatch, vllm_runner): input = "<|user|>\nWhat is the capital of France?\n<|assistant|>\n" - llm = vllm.LLM(MODEL_NAME, max_model_len=MAX_MODEL_LEN) - sampling_params = vllm.SamplingParams(temperature=0.0, max_tokens=256) - outputs = llm.generate(input, sampling_params=sampling_params) + with vllm_runner( + MODEL_NAME, + task="generate", + max_model_len=MAX_MODEL_LEN, + ) as vllm_model: + llm = vllm_model.model + + sampling_params = SamplingParams(temperature=0.0, max_tokens=256) + outputs = llm.generate(input, sampling_params=sampling_params) assert outputs[0].outputs[0].text == "The capital of France is Paris." @pytest.mark.asyncio -async def test_gritlm_api_server_gen(client_generate: openai.AsyncOpenAI): +async def test_gritlm_api_server_generate(): input = "<|user|>\nWhat is the capital of France?\n<|assistant|>\n" - outputs = await client_generate.completions.create( - model=MODEL_NAME, - prompt=input, - max_tokens=256, - temperature=0.0, - ) + args = ["--task", "generate", "--max_model_len", str(MAX_MODEL_LEN)] + + with RemoteOpenAIServer(MODEL_NAME, args) as server: + client_generate = server.get_async_client() + + outputs = await client_generate.completions.create( + model=MODEL_NAME, + prompt=input, + max_tokens=256, + temperature=0.0, + ) assert outputs.choices[0].text == "The capital of France is Paris." diff --git a/tests/models/language/pooling/test_gte.py b/tests/models/language/pooling/test_gte.py new file mode 100644 index 0000000000000000000000000000000000000000..3ccf2999664c22d156d8dd0f6b5ea225b7325922 --- /dev/null +++ b/tests/models/language/pooling/test_gte.py @@ -0,0 +1,104 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import pytest + +from ...utils import EmbedModelInfo, run_embedding_correctness_test + +MODELS = [ + ########## BertModel + EmbedModelInfo("thenlper/gte-large", + architecture="BertModel", + dtype="float32", + enable_test=True), + EmbedModelInfo("thenlper/gte-base", + architecture="BertModel", + dtype="float32", + enable_test=False), + EmbedModelInfo("thenlper/gte-small", + architecture="BertModel", + dtype="float32", + enable_test=False), + EmbedModelInfo("thenlper/gte-large-zh", + architecture="BertModel", + dtype="float32", + enable_test=False), + EmbedModelInfo("thenlper/gte-base-zh", + architecture="BertModel", + dtype="float32", + enable_test=False), + EmbedModelInfo("thenlper/gte-small-zh", + architecture="BertModel", + dtype="float32", + enable_test=False), + ########### NewModel + EmbedModelInfo("Alibaba-NLP/gte-multilingual-base", + architecture="GteNewModel", + enable_test=True), + EmbedModelInfo("Alibaba-NLP/gte-base-en-v1.5", + architecture="GteNewModel", + enable_test=True), + EmbedModelInfo("Alibaba-NLP/gte-large-en-v1.5", + architecture="GteNewModel", + enable_test=True), + ########### Qwen2ForCausalLM + EmbedModelInfo("Alibaba-NLP/gte-Qwen2-1.5B-instruct", + architecture="Qwen2ForCausalLM", + enable_test=True), + EmbedModelInfo("Alibaba-NLP/gte-Qwen2-7B-instruct", + architecture="Qwen2ForCausalLM", + enable_test=False), + ########## ModernBertModel + EmbedModelInfo("Alibaba-NLP/gte-modernbert-base", + architecture="ModernBertModel", + enable_test=True), +] + + +@pytest.mark.parametrize("model_info", MODELS) +def test_models_mteb(hf_runner, vllm_runner, + model_info: EmbedModelInfo) -> None: + pytest.skip("Skipping mteb test.") + + from .mteb_utils import mteb_test_embed_models + + vllm_extra_kwargs: dict[str, Any] = {} + if model_info.name == "Alibaba-NLP/gte-Qwen2-1.5B-instruct": + vllm_extra_kwargs["hf_overrides"] = {"is_causal": True} + + if model_info.architecture == "GteNewModel": + vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]} + + mteb_test_embed_models(hf_runner, vllm_runner, model_info, + vllm_extra_kwargs) + + +@pytest.mark.parametrize("model_info", MODELS) +def test_models_correctness(hf_runner, vllm_runner, model_info: EmbedModelInfo, + example_prompts) -> None: + if not model_info.enable_test: + pytest.skip("Skipping test.") + + # ST will strip the input texts, see test_embedding.py + example_prompts = [str(s).strip() for s in example_prompts] + + vllm_extra_kwargs: dict[str, Any] = {} + if model_info.name == "Alibaba-NLP/gte-Qwen2-1.5B-instruct": + vllm_extra_kwargs["hf_overrides"] = {"is_causal": True} + + if model_info.architecture == "GteNewModel": + vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]} + + with vllm_runner(model_info.name, + task="embed", + dtype=model_info.dtype, + max_model_len=None, + **vllm_extra_kwargs) as vllm_model: + vllm_outputs = vllm_model.encode(example_prompts) + + with hf_runner( + model_info.name, + dtype=model_info.dtype, + is_sentence_transformer=True, + ) as hf_model: + run_embedding_correctness_test(hf_model, example_prompts, vllm_outputs) diff --git a/tests/models/embedding/language/test_jina.py b/tests/models/language/pooling/test_jina.py similarity index 95% rename from tests/models/embedding/language/test_jina.py rename to tests/models/language/pooling/test_jina.py index 1e234368f3b317a5569bfabef841a4cded4793c1..5287ca37c0fb54cde57a21e97ac5edde140b8d5c 100644 --- a/tests/models/embedding/language/test_jina.py +++ b/tests/models/language/pooling/test_jina.py @@ -1,16 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 -# ruff: noqa: E501 -"""Compare the scoring outputs of HF and vLLM models. - -Run `pytest tests/models/embedding/language/test_jina.py`. -""" import math import pytest -from tests.models.embedding.utils import check_embeddings_close, matryoshka_fy from vllm import PoolingParams +from ...utils import check_embeddings_close, matryoshka_fy + SCORING_MODELS = [ "jinaai/jina-reranker-v2-base-multilingual", # Roberta ] @@ -21,9 +17,9 @@ TEXTS_2 = [ "Organic skincare for sensitive skin with aloe vera and chamomile.", "New makeup trends focus on bold colors and innovative techniques", "Bio-Hautpflege für empfindliche Haut mit Aloe Vera und Kamille", - "Neue Make-up-Trends setzen auf kräftige Farben und innovative Techniken", - "Cuidado de la piel orgánico para piel sensible con aloe vera y manzanilla", - "Las nuevas tendencias de maquillaje se centran en colores vivos y técnicas innovadoras", + "Neue Make-up-Trends setzen auf kräftige Farben und innovative Techniken", # noqa: E501 + "Cuidado de la piel orgánico para piel sensible con aloe vera y manzanilla", # noqa: E501 + "Las nuevas tendencias de maquillaje se centran en colores vivos y técnicas innovadoras", # noqa: E501 "针对敏感肌专门设计的天然有机护肤产品", "新的化妆趋势注重鲜艳的颜色和创新的技巧", "敏感肌のために特別に設計された天然有機スキンケア製品", diff --git a/tests/models/language/pooling/test_nomic.py b/tests/models/language/pooling/test_nomic.py new file mode 100644 index 0000000000000000000000000000000000000000..6e9de30f977dbaba3a077c97e9b353792889ee4e --- /dev/null +++ b/tests/models/language/pooling/test_nomic.py @@ -0,0 +1,51 @@ +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from ...utils import EmbedModelInfo, run_embedding_correctness_test + +MODELS = [ + EmbedModelInfo("nomic-ai/nomic-embed-text-v1", + architecture="NomicBertModel", + dtype="float32", + enable_test=True), + EmbedModelInfo("nomic-ai/nomic-embed-text-v1.5", + architecture="NomicBertModel", + dtype="float32", + enable_test=False), + EmbedModelInfo("nomic-ai/nomic-embed-text-v2-moe", + architecture="NomicBertModel", + dtype="float32", + enable_test=True) +] + + +@pytest.mark.parametrize("model_info", MODELS) +def test_models_mteb(hf_runner, vllm_runner, + model_info: EmbedModelInfo) -> None: + pytest.skip("Skipping mteb test.") + from .mteb_utils import mteb_test_embed_models + mteb_test_embed_models(hf_runner, vllm_runner, model_info) + + +@pytest.mark.parametrize("model_info", MODELS) +def test_models_correctness(hf_runner, vllm_runner, model_info: EmbedModelInfo, + example_prompts) -> None: + if not model_info.enable_test: + pytest.skip("Skipping test.") + + # ST will strip the input texts, see test_embedding.py + example_prompts = [str(s).strip() for s in example_prompts] + + with vllm_runner(model_info.name, + task="embed", + dtype=model_info.dtype, + max_model_len=None) as vllm_model: + vllm_outputs = vllm_model.encode(example_prompts) + + with hf_runner( + model_info.name, + dtype=model_info.dtype, + is_sentence_transformer=True, + ) as hf_model: + run_embedding_correctness_test(hf_model, example_prompts, vllm_outputs) diff --git a/tests/models/embedding/language/test_scoring.py b/tests/models/language/pooling/test_scoring.py similarity index 72% rename from tests/models/embedding/language/test_scoring.py rename to tests/models/language/pooling/test_scoring.py index d6408258ffce9cece5961786130eab5a7308f606..e9527700c3ca2da0f4fb1c0e6d21f6639261999d 100644 --- a/tests/models/embedding/language/test_scoring.py +++ b/tests/models/language/pooling/test_scoring.py @@ -1,15 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 -"""Compare the scoring outputs of HF and vLLM models. - -Run `pytest tests/models/embedding/language/test_scoring.py`. -""" import math import pytest import torch import torch.nn.functional as F -MODELS = [ +CROSS_ENCODER_MODELS = [ "cross-encoder/ms-marco-MiniLM-L-6-v2", # Bert "BAAI/bge-reranker-v2-m3", # Roberta ] @@ -28,21 +24,21 @@ TEXTS_2 = [ "The capital of Germany is Berlin.", ] +DTYPE = "half" + -@pytest.fixture(scope="module", params=MODELS) +@pytest.fixture(scope="module", params=CROSS_ENCODER_MODELS) def model_name(request): yield request.param -@pytest.mark.parametrize("dtype", ["half"]) -def test_llm_1_to_1(vllm_runner, hf_runner, model_name, dtype: str): - +def test_cross_encoder_1_to_1(vllm_runner, hf_runner, model_name): text_pair = [TEXTS_1[0], TEXTS_2[0]] - with hf_runner(model_name, dtype=dtype, is_cross_encoder=True) as hf_model: + with hf_runner(model_name, dtype=DTYPE, is_cross_encoder=True) as hf_model: hf_outputs = hf_model.predict([text_pair]).tolist() - with vllm_runner(model_name, task="score", dtype=dtype, + with vllm_runner(model_name, task="score", dtype=DTYPE, max_model_len=None) as vllm_model: vllm_outputs = vllm_model.score(text_pair[0], text_pair[1]) @@ -52,18 +48,16 @@ def test_llm_1_to_1(vllm_runner, hf_runner, model_name, dtype: str): assert math.isclose(hf_outputs[0], vllm_outputs[0], rel_tol=0.01) -@pytest.mark.parametrize("dtype", ["half"]) -def test_llm_1_to_N(vllm_runner, hf_runner, model_name, dtype: str): - +def test_cross_encoder_1_to_N(vllm_runner, hf_runner, model_name): text_pairs = [ [TEXTS_1[0], TEXTS_2[0]], [TEXTS_1[0], TEXTS_2[1]], ] - with hf_runner(model_name, dtype=dtype, is_cross_encoder=True) as hf_model: + with hf_runner(model_name, dtype=DTYPE, is_cross_encoder=True) as hf_model: hf_outputs = hf_model.predict(text_pairs).tolist() - with vllm_runner(model_name, task="score", dtype=dtype, + with vllm_runner(model_name, task="score", dtype=DTYPE, max_model_len=None) as vllm_model: vllm_outputs = vllm_model.score(TEXTS_1[0], TEXTS_2) @@ -74,18 +68,16 @@ def test_llm_1_to_N(vllm_runner, hf_runner, model_name, dtype: str): assert math.isclose(hf_outputs[1], vllm_outputs[1], rel_tol=0.01) -@pytest.mark.parametrize("dtype", ["half"]) -def test_llm_N_to_N(vllm_runner, hf_runner, model_name, dtype: str): - +def test_cross_encoder_N_to_N(vllm_runner, hf_runner, model_name): text_pairs = [ [TEXTS_1[0], TEXTS_2[0]], [TEXTS_1[1], TEXTS_2[1]], ] - with hf_runner(model_name, dtype=dtype, is_cross_encoder=True) as hf_model: + with hf_runner(model_name, dtype=DTYPE, is_cross_encoder=True) as hf_model: hf_outputs = hf_model.predict(text_pairs).tolist() - with vllm_runner(model_name, task="score", dtype=dtype, + with vllm_runner(model_name, task="score", dtype=DTYPE, max_model_len=None) as vllm_model: vllm_outputs = vllm_model.score(TEXTS_1, TEXTS_2) @@ -101,13 +93,10 @@ def emb_model_name(request): yield request.param -@pytest.mark.parametrize("dtype", ["half"]) -def test_llm_1_to_1_embedding(vllm_runner, hf_runner, emb_model_name, - dtype: str): - +def test_embedding_1_to_1(vllm_runner, hf_runner, emb_model_name): text_pair = [TEXTS_1[0], TEXTS_2[0]] - with hf_runner(emb_model_name, dtype=dtype, + with hf_runner(emb_model_name, dtype=DTYPE, is_sentence_transformer=True) as hf_model: hf_embeddings = hf_model.encode(text_pair) hf_outputs = [ @@ -116,7 +105,7 @@ def test_llm_1_to_1_embedding(vllm_runner, hf_runner, emb_model_name, with vllm_runner(emb_model_name, task="embed", - dtype=dtype, + dtype=DTYPE, max_model_len=None) as vllm_model: vllm_outputs = vllm_model.score(text_pair[0], text_pair[1]) @@ -126,16 +115,13 @@ def test_llm_1_to_1_embedding(vllm_runner, hf_runner, emb_model_name, assert math.isclose(hf_outputs[0], vllm_outputs[0], rel_tol=0.01) -@pytest.mark.parametrize("dtype", ["half"]) -def test_llm_1_to_N_embedding(vllm_runner, hf_runner, emb_model_name, - dtype: str): - +def test_embedding_1_to_N(vllm_runner, hf_runner, emb_model_name): text_pairs = [ [TEXTS_1[0], TEXTS_2[0]], [TEXTS_1[0], TEXTS_2[1]], ] - with hf_runner(emb_model_name, dtype=dtype, + with hf_runner(emb_model_name, dtype=DTYPE, is_sentence_transformer=True) as hf_model: hf_embeddings = [ hf_model.encode(text_pair) for text_pair in text_pairs @@ -147,7 +133,7 @@ def test_llm_1_to_N_embedding(vllm_runner, hf_runner, emb_model_name, with vllm_runner(emb_model_name, task="embed", - dtype=dtype, + dtype=DTYPE, max_model_len=None) as vllm_model: vllm_outputs = vllm_model.score(TEXTS_1[0], TEXTS_2) @@ -158,16 +144,13 @@ def test_llm_1_to_N_embedding(vllm_runner, hf_runner, emb_model_name, assert math.isclose(hf_outputs[1], vllm_outputs[1], rel_tol=0.01) -@pytest.mark.parametrize("dtype", ["half"]) -def test_llm_N_to_N_embedding(vllm_runner, hf_runner, emb_model_name, - dtype: str): - +def test_embedding_N_to_N(vllm_runner, hf_runner, emb_model_name): text_pairs = [ [TEXTS_1[0], TEXTS_2[0]], [TEXTS_1[1], TEXTS_2[1]], ] - with hf_runner(emb_model_name, dtype=dtype, + with hf_runner(emb_model_name, dtype=DTYPE, is_sentence_transformer=True) as hf_model: hf_embeddings = [ hf_model.encode(text_pair) for text_pair in text_pairs @@ -179,7 +162,7 @@ def test_llm_N_to_N_embedding(vllm_runner, hf_runner, emb_model_name, with vllm_runner(emb_model_name, task="embed", - dtype=dtype, + dtype=DTYPE, max_model_len=None) as vllm_model: vllm_outputs = vllm_model.score(TEXTS_1, TEXTS_2) diff --git a/tests/models/embedding/language/test_snowflake_arctic_embed.py b/tests/models/language/pooling/test_snowflake_arctic_embed.py similarity index 55% rename from tests/models/embedding/language/test_snowflake_arctic_embed.py rename to tests/models/language/pooling/test_snowflake_arctic_embed.py index 2b884fceec80ce332d2ae4a451aac62d97c552cf..7d9c3c73d8529dc0cdb346b0cf31f3ae4bd8aa8e 100644 --- a/tests/models/embedding/language/test_snowflake_arctic_embed.py +++ b/tests/models/language/pooling/test_snowflake_arctic_embed.py @@ -1,18 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 -"""Compare the embedding outputs of HF and vLLM models. -Run `pytest tests/models/embedding/language/test_snowflake_arctic_embed.py`. -""" import pytest -from tests.models.embedding.utils import EmbedModelInfo - -from ..utils import check_embeddings_close - -EMBEDDING_PROMPTS = [ - 'what is snowflake?', 'Where can I get the best tacos?', 'The Data Cloud!', - 'Mexico City of Course!' -] +from ...utils import EmbedModelInfo, run_embedding_correctness_test MODELS = [ EmbedModelInfo("Snowflake/snowflake-arctic-embed-xs", @@ -51,51 +41,38 @@ MODELS = [ @pytest.mark.parametrize("model_info", MODELS) -@pytest.mark.parametrize("dtype", ["half"]) -def test_models( +def test_models_mteb( hf_runner, vllm_runner, - example_prompts, model_info: EmbedModelInfo, - dtype: str, - monkeypatch, ) -> None: - if not model_info.enable_test: - # A model family has many models with the same architecture, - # and we don't need to test each one. - pytest.skip("Skipping test.") + pytest.skip("Skipping mteb test.") + from .mteb_utils import mteb_test_embed_models + mteb_test_embed_models(hf_runner, vllm_runner, model_info) - example_prompts = example_prompts + EMBEDDING_PROMPTS - vllm_extra_kwargs = { - "hf_overrides": { - "is_matryoshka": model_info.is_matryoshka - } - } +@pytest.mark.parametrize("model_info", MODELS) +def test_models_correctness( + hf_runner, + vllm_runner, + model_info: EmbedModelInfo, + example_prompts, +) -> None: + if not model_info.enable_test: + pytest.skip("Skipping test.") - with hf_runner(model_info.name, dtype=dtype, - is_sentence_transformer=True) as hf_model: - hf_outputs = hf_model.encode(example_prompts) + # ST will strip the input texts, see test_embedding.py + example_prompts = [str(s).strip() for s in example_prompts] with vllm_runner(model_info.name, task="embed", - dtype=dtype, - max_model_len=None, - **vllm_extra_kwargs) as vllm_model: - - assert (vllm_model.model.llm_engine.model_config.is_matryoshka == - model_info.is_matryoshka) - - if model_info.architecture: - assert (model_info.architecture - in vllm_model.model.llm_engine.model_config.architectures) - + dtype=model_info.dtype, + max_model_len=None) as vllm_model: vllm_outputs = vllm_model.encode(example_prompts) - check_embeddings_close( - embeddings_0_lst=hf_outputs, - embeddings_1_lst=vllm_outputs, - name_0="hf", - name_1="vllm", - tol=1e-2, - ) + with hf_runner( + model_info.name, + dtype=model_info.dtype, + is_sentence_transformer=True, + ) as hf_model: + run_embedding_correctness_test(hf_model, example_prompts, vllm_outputs) diff --git a/tests/models/language/pooling/test_truncation_control.py b/tests/models/language/pooling/test_truncation_control.py new file mode 100644 index 0000000000000000000000000000000000000000..1b8ac395ed179c61aab296268ff534671a2d4cc3 --- /dev/null +++ b/tests/models/language/pooling/test_truncation_control.py @@ -0,0 +1,69 @@ +# SPDX-License-Identifier: Apache-2.0 +import pytest + +MODEL_NAME = "sentence-transformers/all-MiniLM-L12-v2" +max_model_len = 128 + +input_str = """Immerse yourself in the enchanting chronicle of calculus, a +mathematical domain that has radically transformed our comprehension of +change and motion. Despite its roots in ancient civilizations, the +formal birth of calculus predominantly occurred in the 17th century, +primarily under the influential guidance of Sir Isaac Newton and Gottfried +Wilhelm Leibniz. The earliest traces of calculus concepts are found in +ancient Greek mathematics,most notably in the works of Eudoxus and +Archimedes, around 300 BCE. They utilized the 'method of exhaustion'—a +technique for computing areas and volumes through the use of finite sums. +This methodology laid crucial foundational work for integral calculus. +In the 17th century, both Newton and Leibniz independently pioneered +calculus, each contributing unique perspectives that would shape this new +field.""" + + +def test_smaller_truncation_size(vllm_runner, + model_name=MODEL_NAME, + input_str=input_str): + + truncate_prompt_tokens = 10 + + with vllm_runner(model_name, task="embed", + max_model_len=max_model_len) as vllm_model: + vllm_output = vllm_model.model.encode( + input_str, truncate_prompt_tokens=truncate_prompt_tokens) + + prompt_tokens = vllm_output[0].prompt_token_ids + + assert len(prompt_tokens) == truncate_prompt_tokens + + +def test_max_truncation_size(vllm_runner, + model_name=MODEL_NAME, + input_str=input_str): + truncate_prompt_tokens = -1 + + with vllm_runner(model_name, task="embed", + max_model_len=max_model_len) as vllm_model: + vllm_output = vllm_model.model.encode( + input_str, truncate_prompt_tokens=truncate_prompt_tokens) + + prompt_tokens = vllm_output[0].prompt_token_ids + + assert len(prompt_tokens) == max_model_len + + +def test_bigger_truncation_size(vllm_runner, + model_name=MODEL_NAME, + input_str=input_str): + + truncate_prompt_tokens = max_model_len + 1 + + with pytest.raises(ValueError), vllm_runner( + model_name, task="embed", + max_model_len=max_model_len) as vllm_model: + + llm_output = vllm_model.model.encode( + input_str, truncate_prompt_tokens=truncate_prompt_tokens) + + assert llm_output == f"""truncate_prompt_tokens value + ({truncate_prompt_tokens}) is greater than + max_model_len ({max_model_len}). Please, select + a smaller truncation size.""" diff --git a/tests/models/decoder_only/vision_language/__init__.py b/tests/models/multimodal/generation/__init__.py similarity index 100% rename from tests/models/decoder_only/vision_language/__init__.py rename to tests/models/multimodal/generation/__init__.py diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/multimodal/generation/test_common.py similarity index 86% rename from tests/models/decoder_only/vision_language/test_models.py rename to tests/models/multimodal/generation/test_common.py index 6e8e2ecb2ab2dc2edb33f1569e12803fc135f7ae..f9b842ff4c3695ea29e0ac7e32660120d23199f3 100644 --- a/tests/models/decoder_only/vision_language/test_models.py +++ b/tests/models/multimodal/generation/test_common.py @@ -8,13 +8,14 @@ from collections import defaultdict from pathlib import PosixPath import pytest -from transformers import AutoModelForImageTextToText, AutoModelForVision2Seq +from transformers import (AutoModel, AutoModelForImageTextToText, + AutoModelForTextToWaveform, AutoModelForVision2Seq) from vllm.platforms import current_platform from vllm.utils import identity -from ....conftest import (IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets, - _VideoAssets) +from ....conftest import (IMAGE_ASSETS, AudioTestAssets, HfRunner, + ImageTestAssets, VideoTestAssets, VllmRunner) from ....utils import (create_new_process_for_each_test, large_gpu_mark, multi_gpu_marks) from ...utils import check_outputs_equal @@ -140,7 +141,7 @@ VLM_TEST_SETTINGS = { marks=[pytest.mark.core_model, pytest.mark.cpu_model], ), "qwen2_5_omni": VLMTestInfo( - models=["Qwen/Qwen2.5-Omni-7B"], + models=["Qwen/Qwen2.5-Omni-3B"], test_type=( VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE, @@ -151,11 +152,23 @@ VLM_TEST_SETTINGS = { video_idx_to_prompt=lambda idx: "<|vision_bos|><|VIDEO|><|vision_eos|>", # noqa: E501 max_model_len=4096, max_num_seqs=2, - auto_cls=AutoModelForVision2Seq, + auto_cls=AutoModelForTextToWaveform, vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output, + patch_hf_runner=model_utils.qwen2_5_omni_patch_hf_runner, image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)], marks=[pytest.mark.core_model, pytest.mark.cpu_model], ), + "ultravox": VLMTestInfo( + models = ["fixie-ai/ultravox-v0_5-llama-3_2-1b"], + test_type=VLMTestType.AUDIO, + prompt_formatter=lambda audio_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{audio_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501 + audio_idx_to_prompt=lambda idx: "<|audio|>", + max_model_len=4096, + max_num_seqs=2, + auto_cls=AutoModel, + hf_output_post_proc=model_utils.ultravox_trunc_hf_output, + marks=[pytest.mark.core_model, pytest.mark.cpu_model], + ), #### Extended model tests "aria": VLMTestInfo( models=["rhymes-ai/Aria"], @@ -267,6 +280,7 @@ VLM_TEST_SETTINGS = { multi_image_prompt="Describe the two images in detail.", # noqa: E501 max_model_len=4096, max_num_seqs=2, + dtype="bfloat16", auto_cls=AutoModelForImageTextToText, vllm_runner_kwargs={"mm_processor_kwargs": {"do_pan_and_scan": True}}, patch_hf_runner=model_utils.gemma3_patch_hf_runner, @@ -390,7 +404,6 @@ VLM_TEST_SETTINGS = { formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 ), limit_mm_per_prompt={"video": 4}, - runner_mm_key="videos", )], ), "llava_next_video": VLMTestInfo( @@ -423,6 +436,8 @@ VLM_TEST_SETTINGS = { get_stop_token_ids=lambda tok: [tok.eos_id, tok.eot_id], hf_output_post_proc=model_utils.minicpmv_trunc_hf_output, patch_hf_runner=model_utils.minicpmv_25_patch_hf_runner, + # FIXME: https://huggingface.co/openbmb/MiniCPM-V-2_6/discussions/55 + marks=[pytest.mark.skip("HF import fails")], ), "minicpmo_26": VLMTestInfo( models=["openbmb/MiniCPM-o-2_6"], @@ -434,6 +449,8 @@ VLM_TEST_SETTINGS = { get_stop_token_ids=lambda tok: tok.convert_tokens_to_ids(['<|im_end|>', '<|endoftext|>']), # noqa: E501 hf_output_post_proc=model_utils.minicpmv_trunc_hf_output, patch_hf_runner=model_utils.minicpmo_26_patch_hf_runner, + # FIXME: https://huggingface.co/openbmb/MiniCPM-V-2_6/discussions/55 + marks=[pytest.mark.skip("HF import fails")], ), "minicpmv_26": VLMTestInfo( models=["openbmb/MiniCPM-V-2_6"], @@ -445,6 +462,21 @@ VLM_TEST_SETTINGS = { get_stop_token_ids=lambda tok: tok.convert_tokens_to_ids(['<|im_end|>', '<|endoftext|>']), # noqa: E501 hf_output_post_proc=model_utils.minicpmv_trunc_hf_output, patch_hf_runner=model_utils.minicpmv_26_patch_hf_runner, + # FIXME: https://huggingface.co/openbmb/MiniCPM-V-2_6/discussions/55 + marks=[pytest.mark.skip("HF import fails")], + ), + "minimax_vl_01": VLMTestInfo( + models=["MiniMaxAI/MiniMax-VL-01"], + prompt_formatter=lambda img_prompt: f"user: {img_prompt} assistant:", # noqa: E501 + img_idx_to_prompt=lambda _: "", + test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), + max_model_len=8192, + max_num_seqs=4, + dtype="bfloat16", + hf_output_post_proc=model_utils.minimax_vl_01_hf_output, + patch_hf_runner=model_utils.minimax_vl_01_patch_hf_runner, + auto_cls=AutoModelForImageTextToText, + marks=[large_gpu_mark(min_gb=80)], ), "molmo": VLMTestInfo( models=["allenai/Molmo-7B-D-0924"], @@ -454,6 +486,43 @@ VLM_TEST_SETTINGS = { max_num_seqs=2, patch_hf_runner=model_utils.molmo_patch_hf_runner, ), + "ovis1_6-gemma2": VLMTestInfo( + models=["AIDC-AI/Ovis1.6-Gemma2-9B"], + test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), + prompt_formatter=lambda img_prompt: f"user\n{img_prompt}\nmodel\n", # noqa: E501 + img_idx_to_prompt=lambda idx: "\n", # noqa: E501 + max_model_len=4096, + max_num_seqs=2, + dtype="half", + # use sdpa mode for hf runner since ovis2 didn't work with flash_attn + hf_model_kwargs={"llm_attn_implementation": "sdpa"}, + patch_hf_runner=model_utils.ovis_patch_hf_runner, + marks=[large_gpu_mark(min_gb=32)], + ), + "ovis1_6": VLMTestInfo( + models=["AIDC-AI/Ovis1.6-Llama3.2-3B"], + test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), + prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful and honest multimodal assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{img_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501 + img_idx_to_prompt=lambda idx: "\n", # noqa: E501 + max_model_len=4096, + max_num_seqs=2, + dtype="half", + # use sdpa mode for hf runner since ovis2 didn't work with flash_attn + hf_model_kwargs={"llm_attn_implementation": "sdpa"}, + patch_hf_runner=model_utils.ovis_patch_hf_runner, + ), + "ovis2": VLMTestInfo( + models=["AIDC-AI/Ovis2-1B"], + test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), + prompt_formatter=lambda img_prompt: f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 + img_idx_to_prompt=lambda idx: "\n", # noqa: E501 + max_model_len=4096, + max_num_seqs=2, + dtype="half", + # use sdpa mode for hf runner since ovis2 didn't work with flash_attn + hf_model_kwargs={"llm_attn_implementation": "sdpa"}, + patch_hf_runner=model_utils.ovis_patch_hf_runner, + ), "phi3v": VLMTestInfo( models=["microsoft/Phi-3.5-vision-instruct"], test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE), @@ -663,6 +732,7 @@ VLM_TEST_SETTINGS = _mark_splits(VLM_TEST_SETTINGS, num_groups=2) # - multi-image # - image embeddings # - video +# - audio # - custom inputs @pytest.mark.parametrize( "model_type,test_case", @@ -675,7 +745,7 @@ def test_single_image_models(tmp_path: PosixPath, model_type: str, test_case: ExpandableVLMTestArgs, hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - image_assets: _ImageAssets, monkeypatch): + image_assets: ImageTestAssets, monkeypatch): if model_type in REQUIRES_V0_MODELS: monkeypatch.setenv("VLLM_USE_V1", "0") model_test_info = VLM_TEST_SETTINGS[model_type] @@ -700,7 +770,7 @@ def test_multi_image_models(tmp_path: PosixPath, model_type: str, test_case: ExpandableVLMTestArgs, hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - image_assets: _ImageAssets, monkeypatch): + image_assets: ImageTestAssets, monkeypatch): if model_type in REQUIRES_V0_MODELS: monkeypatch.setenv("VLLM_USE_V1", "0") model_test_info = VLM_TEST_SETTINGS[model_type] @@ -725,7 +795,7 @@ def test_image_embedding_models(model_type: str, test_case: ExpandableVLMTestArgs, hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - image_assets: _ImageAssets, monkeypatch): + image_assets: ImageTestAssets, monkeypatch): if model_type in REQUIRES_V0_MODELS: monkeypatch.setenv("VLLM_USE_V1", "0") model_test_info = VLM_TEST_SETTINGS[model_type] @@ -747,7 +817,7 @@ def test_image_embedding_models(model_type: str, )) def test_video_models(model_type: str, test_case: ExpandableVLMTestArgs, hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - video_assets: _VideoAssets, monkeypatch): + video_assets: VideoTestAssets, monkeypatch): if model_type in REQUIRES_V0_MODELS: monkeypatch.setenv("VLLM_USE_V1", "0") model_test_info = VLM_TEST_SETTINGS[model_type] @@ -760,6 +830,28 @@ def test_video_models(model_type: str, test_case: ExpandableVLMTestArgs, ) +@pytest.mark.parametrize( + "model_type,test_case", + get_parametrized_options( + VLM_TEST_SETTINGS, + test_type=VLMTestType.AUDIO, + create_new_process_for_each_test=False, + )) +def test_audio_models(model_type: str, test_case: ExpandableVLMTestArgs, + hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], + audio_assets: AudioTestAssets, monkeypatch): + if model_type in REQUIRES_V0_MODELS: + monkeypatch.setenv("VLLM_USE_V1", "0") + model_test_info = VLM_TEST_SETTINGS[model_type] + runners.run_audio_test( + model_test_info=model_test_info, + test_case=test_case, + hf_runner=hf_runner, + vllm_runner=vllm_runner, + audio_assets=audio_assets, + ) + + @pytest.mark.parametrize( "model_type,test_case", get_parametrized_options( @@ -798,7 +890,7 @@ def test_single_image_models_heavy(tmp_path: PosixPath, model_type: str, test_case: ExpandableVLMTestArgs, hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - image_assets: _ImageAssets, monkeypatch): + image_assets: ImageTestAssets, monkeypatch): if model_type in REQUIRES_V0_MODELS: monkeypatch.setenv("VLLM_USE_V1", "0") model_test_info = VLM_TEST_SETTINGS[model_type] @@ -824,7 +916,7 @@ def test_multi_image_models_heavy(tmp_path: PosixPath, model_type: str, test_case: ExpandableVLMTestArgs, hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - image_assets: _ImageAssets, monkeypatch): + image_assets: ImageTestAssets, monkeypatch): if model_type in REQUIRES_V0_MODELS: monkeypatch.setenv("VLLM_USE_V1", "0") model_test_info = VLM_TEST_SETTINGS[model_type] @@ -850,7 +942,8 @@ def test_image_embedding_models_heavy(model_type: str, test_case: ExpandableVLMTestArgs, hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - image_assets: _ImageAssets, monkeypatch): + image_assets: ImageTestAssets, + monkeypatch): if model_type in REQUIRES_V0_MODELS: monkeypatch.setenv("VLLM_USE_V1", "0") model_test_info = VLM_TEST_SETTINGS[model_type] @@ -873,7 +966,7 @@ def test_image_embedding_models_heavy(model_type: str, def test_video_models_heavy(model_type: str, test_case: ExpandableVLMTestArgs, hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - video_assets: _VideoAssets, monkeypatch): + video_assets: VideoTestAssets, monkeypatch): if model_type in REQUIRES_V0_MODELS: monkeypatch.setenv("VLLM_USE_V1", "0") model_test_info = VLM_TEST_SETTINGS[model_type] @@ -886,6 +979,29 @@ def test_video_models_heavy(model_type: str, test_case: ExpandableVLMTestArgs, ) +@pytest.mark.parametrize( + "model_type,test_case", + get_parametrized_options( + VLM_TEST_SETTINGS, + test_type=VLMTestType.AUDIO, + create_new_process_for_each_test=True, + )) +def test_audio_models_heavy(model_type: str, test_case: ExpandableVLMTestArgs, + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + audio_assets: AudioTestAssets, monkeypatch): + if model_type in REQUIRES_V0_MODELS: + monkeypatch.setenv("VLLM_USE_V1", "0") + model_test_info = VLM_TEST_SETTINGS[model_type] + runners.run_audio_test( + model_test_info=model_test_info, + test_case=test_case, + hf_runner=hf_runner, + vllm_runner=vllm_runner, + audio_assets=audio_assets, + ) + + @pytest.mark.parametrize( "model_type,test_case", get_parametrized_options( diff --git a/tests/models/encoder_decoder/vision_language/test_florence2.py b/tests/models/multimodal/generation/test_florence2.py similarity index 97% rename from tests/models/encoder_decoder/vision_language/test_florence2.py rename to tests/models/multimodal/generation/test_florence2.py index 14b64393bf52aa8156a64a3aef431771fd0ab5c9..b8225f5f12437942234a1947edd0bac0f1bc1656 100644 --- a/tests/models/encoder_decoder/vision_language/test_florence2.py +++ b/tests/models/multimodal/generation/test_florence2.py @@ -9,7 +9,7 @@ from vllm.inputs.data import ExplicitEncoderDecoderPrompt, TextPrompt from vllm.multimodal.image import rescale_image_size from vllm.sequence import SampleLogprobs -from ....conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets +from ....conftest import IMAGE_ASSETS, HfRunner, ImageTestAssets, VllmRunner from ...utils import check_logprobs_close MODELS = ["microsoft/Florence-2-base"] @@ -118,7 +118,7 @@ def run_test( @pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("num_logprobs", [5]) def test_models(hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - image_assets: _ImageAssets, model: str, + image_assets: ImageTestAssets, model: str, size_factors: list[int], dtype: str, max_tokens: int, num_logprobs: int) -> None: images = [asset.pil_image for asset in image_assets] diff --git a/tests/models/decoder_only/audio_language/test_granite_speech.py b/tests/models/multimodal/generation/test_granite_speech.py similarity index 94% rename from tests/models/decoder_only/audio_language/test_granite_speech.py rename to tests/models/multimodal/generation/test_granite_speech.py index 7c14845ec54d465832ccc623e8bcfda77f6b4eeb..96c444441e3d2c02ab5abf27b22a930b516199e8 100644 --- a/tests/models/decoder_only/audio_language/test_granite_speech.py +++ b/tests/models/multimodal/generation/test_granite_speech.py @@ -9,7 +9,8 @@ from transformers import AutoModelForSpeechSeq2Seq from vllm.lora.request import LoRARequest from vllm.sequence import SampleLogprobs -from ....conftest import HfRunner, PromptAudioInput, VllmRunner, _AudioAssets +from ....conftest import (AudioTestAssets, HfRunner, PromptAudioInput, + VllmRunner) from ...registry import HF_EXAMPLE_MODELS from ...utils import check_logprobs_close @@ -116,9 +117,9 @@ def run_test( @pytest.mark.parametrize("max_model_len", [2048]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [10]) -def test_models(hf_runner, vllm_runner, model: str, audio_assets: _AudioAssets, - dtype: str, max_model_len: int, max_tokens: int, - num_logprobs: int) -> None: +def test_models(hf_runner, vllm_runner, model: str, + audio_assets: AudioTestAssets, dtype: str, max_model_len: int, + max_tokens: int, num_logprobs: int) -> None: model_info = HF_EXAMPLE_MODELS.find_hf_info(model) model_info.check_available_online(on_fail="skip") model_info.check_transformers_version(on_fail="skip") diff --git a/tests/models/decoder_only/vision_language/test_interleaved.py b/tests/models/multimodal/generation/test_interleaved.py similarity index 96% rename from tests/models/decoder_only/vision_language/test_interleaved.py rename to tests/models/multimodal/generation/test_interleaved.py index 8804497ae616f1e311c5c1db373cdbf95efc7a5f..eec84751e4504408d92238a17604c402eff5ccdb 100644 --- a/tests/models/decoder_only/vision_language/test_interleaved.py +++ b/tests/models/multimodal/generation/test_interleaved.py @@ -16,6 +16,7 @@ INTERLEAVED_PROMPT = base_prompt("