diff --git a/.buildkite/check-wheel-size.py b/.buildkite/check-wheel-size.py index b39dce2659a54fb0e70fa1f1c0b845463b79470b..0412c5f37952dd19227545913354b4afca2b23f4 100644 --- a/.buildkite/check-wheel-size.py +++ b/.buildkite/check-wheel-size.py @@ -1,36 +1,43 @@ import os +import sys import zipfile -MAX_SIZE_MB = 250 +# Read the VLLM_MAX_SIZE_MB environment variable, defaulting to 250 MB +VLLM_MAX_SIZE_MB = int(os.environ.get('VLLM_MAX_SIZE_MB', 250)) def print_top_10_largest_files(zip_file): + """Print the top 10 largest files in the given zip file.""" with zipfile.ZipFile(zip_file, 'r') as z: file_sizes = [(f, z.getinfo(f).file_size) for f in z.namelist()] file_sizes.sort(key=lambda x: x[1], reverse=True) for f, size in file_sizes[:10]: - print(f"{f}: {size/(1024*1024)} MBs uncompressed.") + print(f"{f}: {size / (1024 * 1024):.2f} MBs uncompressed.") def check_wheel_size(directory): + """Check the size of .whl files in the given directory.""" for root, _, files in os.walk(directory): - for f in files: - if f.endswith(".whl"): - wheel_path = os.path.join(root, f) - wheel_size = os.path.getsize(wheel_path) - wheel_size_mb = wheel_size / (1024 * 1024) - if wheel_size_mb > MAX_SIZE_MB: - print( - f"Wheel {wheel_path} is too large ({wheel_size_mb} MB) " - f"compare to the allowed size ({MAX_SIZE_MB} MB).") + for file_name in files: + if file_name.endswith(".whl"): + wheel_path = os.path.join(root, file_name) + wheel_size_mb = os.path.getsize(wheel_path) / (1024 * 1024) + if wheel_size_mb > VLLM_MAX_SIZE_MB: + print(f"Not allowed: Wheel {wheel_path} is larger " + f"({wheel_size_mb:.2f} MB) than the limit " + f"({VLLM_MAX_SIZE_MB} MB).") print_top_10_largest_files(wheel_path) return 1 else: print(f"Wheel {wheel_path} is within the allowed size " - f"({wheel_size_mb} MB).") + f"({wheel_size_mb:.2f} MB).") return 0 if __name__ == "__main__": - import sys - sys.exit(check_wheel_size(sys.argv[1])) + if len(sys.argv) < 2: + print("Usage: python check-wheel-size.py ") + sys.exit(1) + + directory = sys.argv[1] + sys.exit(check_wheel_size(directory)) \ No newline at end of file diff --git a/.buildkite/lm-eval-harness/configs/DeepSeek-V2-Lite-Chat.yaml b/.buildkite/lm-eval-harness/configs/DeepSeek-V2-Lite-Chat.yaml index 15268395ec68bb4450b065349ff262ea44827922..d70ecb2a7e7b0789bf433bc174a1b14237248e16 100644 --- a/.buildkite/lm-eval-harness/configs/DeepSeek-V2-Lite-Chat.yaml +++ b/.buildkite/lm-eval-harness/configs/DeepSeek-V2-Lite-Chat.yaml @@ -9,3 +9,4 @@ tasks: value: 0.664 limit: 1000 num_fewshot: 5 +trust_remote_code: True \ No newline at end of file diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-INT8-compressed-tensors-asym.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-INT8-compressed-tensors-asym.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0ecfc01ef049f7afc58249f5adf4c5b4091a192c --- /dev/null +++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-INT8-compressed-tensors-asym.yaml @@ -0,0 +1,11 @@ +# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Asym-Per-Token-Test -b "auto" -l 250 -f 5 -t 1 +model_name: "nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Asym-Per-Token-Test" +tasks: +- name: "gsm8k" + metrics: + - name: "exact_match,strict-match" + value: 0.764 + - name: "exact_match,flexible-extract" + value: 0.764 +limit: 250 +num_fewshot: 5 diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml index c457468902c98a890f96585cb0a77994e2550983..042458659839198ee8415e1602c9375401d083fe 100644 --- a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml +++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml @@ -4,8 +4,8 @@ tasks: - name: "gsm8k" metrics: - name: "exact_match,strict-match" - value: 0.409 + value: 0.419 - name: "exact_match,flexible-extract" - value: 0.406 + value: 0.416 limit: 1000 num_fewshot: 5 diff --git a/.buildkite/lm-eval-harness/configs/Minitron-4B-Base.yaml b/.buildkite/lm-eval-harness/configs/Minitron-4B-Base-FP8.yaml similarity index 60% rename from .buildkite/lm-eval-harness/configs/Minitron-4B-Base.yaml rename to .buildkite/lm-eval-harness/configs/Minitron-4B-Base-FP8.yaml index a0466748ea71e24b37d576c7cef8f672e210a162..3ea0b7bb5cd66f29e4146f6675dd8779f0942d35 100644 --- a/.buildkite/lm-eval-harness/configs/Minitron-4B-Base.yaml +++ b/.buildkite/lm-eval-harness/configs/Minitron-4B-Base-FP8.yaml @@ -1,11 +1,11 @@ -# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nvidia/Minitron-4B-Base -b auto -l 1000 -f 5 -t 1 -model_name: "nvidia/Minitron-4B-Base" +# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m mgoin/Minitron-4B-Base-FP8 -b auto -l 1000 -f 5 -t 1 +model_name: "mgoin/Minitron-4B-Base-FP8" tasks: - name: "gsm8k" metrics: - name: "exact_match,strict-match" - value: 0.252 + value: 0.233 - name: "exact_match,flexible-extract" - value: 0.252 + value: 0.236 limit: 1000 num_fewshot: 5 diff --git a/.buildkite/lm-eval-harness/configs/models-small.txt b/.buildkite/lm-eval-harness/configs/models-small.txt index bca89f00653e354eace23b7df48d431def9197be..64a0f428587af1071a50b6cf6d62818c3965c9f6 100644 --- a/.buildkite/lm-eval-harness/configs/models-small.txt +++ b/.buildkite/lm-eval-harness/configs/models-small.txt @@ -1,10 +1,10 @@ Meta-Llama-3-8B-Instruct.yaml -Meta-Llama-3-8B-Instruct-FP8.yaml Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml +Meta-Llama-3-8B-Instruct-INT8-compressed-tensors-asym.yaml Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml -Minitron-4B-Base.yaml +Minitron-4B-Base-FP8.yaml Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml Qwen2-1.5B-Instruct-FP8W8.yaml Meta-Llama-3-8B-QQQ.yaml diff --git a/.buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh b/.buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh index fdb8ec5393b36b2507e73e9ccc3db387addf3c13..b2e910e1ba8a791ada3dc7751fb95f7a454be957 100644 --- a/.buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh +++ b/.buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh @@ -2,7 +2,7 @@ # We can use this script to compute baseline accuracy on GSM for transformers. # # Make sure you have lm-eval-harness installed: -# pip install git+https://github.com/EleutherAI/lm-evaluation-harness.git@9516087b81a61d0e220b22cc1b75be76de23bc10 +# pip install lm-eval==0.4.4 usage() { echo`` diff --git a/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh b/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh index de841d959a4e40af2ab74b9a968f5931885730cd..4d32b49a4fac31ad1f6e5701002ee9eff4691a3b 100644 --- a/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh +++ b/.buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh @@ -3,7 +3,7 @@ # We use this for fp8, which HF does not support. # # Make sure you have lm-eval-harness installed: -# pip install lm-eval==0.4.3 +# pip install lm-eval==0.4.4 usage() { echo`` diff --git a/.buildkite/lm-eval-harness/test_lm_eval_correctness.py b/.buildkite/lm-eval-harness/test_lm_eval_correctness.py index 7fdce7b53bd7fa59872fb2674fc33c3c94b8a2f3..afc935c1a931832246680d34a4d9ea51aa8708b2 100644 --- a/.buildkite/lm-eval-harness/test_lm_eval_correctness.py +++ b/.buildkite/lm-eval-harness/test_lm_eval_correctness.py @@ -14,7 +14,7 @@ import lm_eval import numpy import yaml -RTOL = 0.02 +RTOL = 0.05 TEST_DATA_FILE = os.environ.get( "LM_EVAL_TEST_DATA_FILE", ".buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct.yaml") @@ -23,9 +23,12 @@ TP_SIZE = os.environ.get("LM_EVAL_TP_SIZE", 1) def launch_lm_eval(eval_config): + trust_remote_code = eval_config.get('trust_remote_code', False) + model_args = f"pretrained={eval_config['model_name']}," \ f"tensor_parallel_size={TP_SIZE}," \ - f"add_bos_token=true" + f"add_bos_token=true," \ + f"trust_remote_code={trust_remote_code}" results = lm_eval.simple_evaluate( model="vllm", @@ -46,10 +49,15 @@ def test_lm_eval_correctness(): results = launch_lm_eval(eval_config) # Confirm scores match ground truth. + success = True for task in eval_config["tasks"]: for metric in task["metrics"]: ground_truth = metric["value"] measured_value = results["results"][task["name"]][metric["name"]] print(f'{task["name"]} | {metric["name"]}: ' f'ground_truth={ground_truth} | measured={measured_value}') - assert numpy.isclose(ground_truth, measured_value, rtol=RTOL) + success = success and numpy.isclose( + ground_truth, measured_value, rtol=RTOL) + + # Assert at the end, print all scores even on failure for debugging. + assert success diff --git a/.buildkite/nightly-benchmarks/README.md b/.buildkite/nightly-benchmarks/README.md index c1aebaf5b3bbe2e01470cf59c099924821358aa9..fbf41eb10a392aace343f4d300924752ec1ab2af 100644 --- a/.buildkite/nightly-benchmarks/README.md +++ b/.buildkite/nightly-benchmarks/README.md @@ -34,17 +34,18 @@ See [vLLM performance dashboard](https://perf.vllm.ai) for the latest performan Performance benchmark will be triggered when: - A PR being merged into vllm. -- Every commit for those PRs with `perf-benchmarks` label. +- Every commit for those PRs with `perf-benchmarks` label AND `ready` label. Nightly benchmark will be triggered when: -- Every commit for those PRs with `nightly-benchmarks` label. +- Every commit for those PRs with `perf-benchmarks` label and `nightly-benchmarks` label. ## Performance benchmark details -See [descriptions.md](tests/descriptions.md) for detailed descriptions, and use `tests/latency-tests.json`, `tests/throughput-tests.json`, `tests/serving-tests.json` to configure the test cases. + +See [performance-benchmarks-descriptions.md](performance-benchmarks-descriptions.md) for detailed descriptions, and use `tests/latency-tests.json`, `tests/throughput-tests.json`, `tests/serving-tests.json` to configure the test cases. #### Latency test @@ -68,7 +69,7 @@ Here is an example of one test inside `latency-tests.json`: In this example: - The `test_name` attributes is a unique identifier for the test. In `latency-tests.json`, it must start with `latency_`. -- The `parameters` attribute control the command line arguments to be used for `benchmark_latency.py`. Note that please use underline `_` instead of the dash `-` when specifying the command line arguments, and `run-benchmarks-suite.sh` will convert the underline to dash when feeding the arguments to `benchmark_latency.py`. For example, the corresponding command line arguments for `benchmark_latency.py` will be `--model meta-llama/Meta-Llama-3-8B --tensor-parallel-size 1 --load-format dummy --num-iters-warmup 5 --num-iters 15` +- The `parameters` attribute control the command line arguments to be used for `benchmark_latency.py`. Note that please use underline `_` instead of the dash `-` when specifying the command line arguments, and `run-performance-benchmarks.sh` will convert the underline to dash when feeding the arguments to `benchmark_latency.py`. For example, the corresponding command line arguments for `benchmark_latency.py` will be `--model meta-llama/Meta-Llama-3-8B --tensor-parallel-size 1 --load-format dummy --num-iters-warmup 5 --num-iters 15` Note that the performance numbers are highly sensitive to the value of the parameters. Please make sure the parameters are set correctly. diff --git a/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml b/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml index 8490c9f1da221c3121e917d8335c95c278f65f15..eec2a51e2f8fde0d8dc0abeccd7e8fdd5ea6eb94 100644 --- a/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml +++ b/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml @@ -8,8 +8,7 @@ steps: containers: - image: badouralix/curl-jq command: - - sh - - .buildkite/nightly-benchmarks/scripts/wait-for-image.sh + - sh .buildkite/nightly-benchmarks/scripts/wait-for-image.sh - wait - label: "A100" agents: @@ -21,7 +20,7 @@ steps: containers: - image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT command: - - bash .buildkite/nightly-benchmarks/run-benchmarks-suite.sh + - bash .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh resources: limits: nvidia.com/gpu: 8 diff --git a/.buildkite/nightly-benchmarks/nightly-annotation.md b/.buildkite/nightly-benchmarks/nightly-annotation.md new file mode 100644 index 0000000000000000000000000000000000000000..1e33793842bf8c4a4c580a6b534e5dec46ab93cf --- /dev/null +++ b/.buildkite/nightly-benchmarks/nightly-annotation.md @@ -0,0 +1,28 @@ + +## Description + +This file contains the downloading link for benchmarking results. + +- [benchmarking pipeline](artifact://nightly-pipeline.yaml) +- [benchmarking results](artifact://results.zip) +- [benchmarking code](artifact://nightly-benchmarks.zip) + +Please download the visualization scripts in the post + + +## Results reproduction + +- Find the docker we use in `benchmarking pipeline` +- Deploy the docker, and inside the docker: + - Download `nightly-benchmarks.zip`. + - In the same folder, run the following code +``` +export HF_TOKEN= +apt update +apt install -y git +unzip nightly-benchmarks.zip +VLLM_SOURCE_CODE_LOC=./ bash .buildkite/nightly-benchmarks/scripts/run-nightly-benchmarks.sh +``` + +And the results will be inside `./benchmarks/results`. + diff --git a/.buildkite/nightly-benchmarks/nightly-descriptions.md b/.buildkite/nightly-benchmarks/nightly-descriptions.md index c3d3cbf4739689ab81fa2401246a87706a241e40..7dec7a0fe0b4e9258a2fb9655e7ea63daae95a34 100644 --- a/.buildkite/nightly-benchmarks/nightly-descriptions.md +++ b/.buildkite/nightly-benchmarks/nightly-descriptions.md @@ -1,45 +1,39 @@ # Nightly benchmark -The main goal of this benchmarking is two-fold: -- Performance clarity: Provide clarity on which one (vllm, tensorrt-llm, lmdeploy and tgi) leads in performance in what workload. -- Reproducible: one can run the exact same set of benchmarking commands inside the exact same docker by following reproducing instructions in [reproduce.md](). - - -## Docker images - -We benchmark vllm, tensorrt-llm, lmdeploy and tgi using the following docker images: -- vllm/vllm-openai:v0.5.0.post1 -- nvcr.io/nvidia/tritonserver:24.04-trtllm-python-py3 -- openmmlab/lmdeploy:v0.5.0 -- ghcr.io/huggingface/text-generation-inference:2.1 - - - - -## Hardware - -One AWS node with 8x NVIDIA A100 GPUs. - - -## Workload description - -We benchmark vllm, tensorrt-llm, lmdeploy and tgi using the following workload: - -- Input length: randomly sample 500 prompts from ShareGPT dataset (with fixed random seed). -- Output length: the corresponding output length of these 500 prompts. -- Models: llama-3 8B, llama-3 70B, mixtral 8x7B. -- Average QPS (query per second): 4 for the small model (llama-3 8B) and 2 for other two models. For each QPS, the arrival time of each query is determined using a random Poisson process (with fixed random seed). -- Evaluation metrics: Throughput (higher the better), TTFT (time to the first token, lower the better), ITL (inter-token latency, lower the better). - - - -## Plots - -In the following plots, the dot shows the mean and the error bar shows the standard error of the mean. Value 0 means that the corresponding benchmark crashed. - -Benchmarking results - -## Results - -{nightly_results_benchmarking_table} +This benchmark aims to: +- Provide performance clarity: Provide clarity on which one (vllm, tensorrt-llm, lmdeploy and SGLang) leads in performance in what workload. +- Be reproducible: one can run the exact same set of benchmarking commands inside the exact same docker by following reproducing instructions. + +Latest results: [results link](https://blog.vllm.ai/2024/09/05/perf-update.html), scroll to the end. + +Latest reproduction guilde: [github issue link](https://github.com/vllm-project/vllm/issues/8176) + + +## Setup + +- Docker images: + - vLLM: `vllm/vllm-openai:v0.6.2` + - SGLang: `lmsysorg/sglang:v0.3.2-cu121` + - LMDeploy: `openmmlab/lmdeploy:v0.6.1-cu12` + - TensorRT-LLM: `nvcr.io/nvidia/tritonserver:24.07-trtllm-python-py3` + - *NOTE: we uses r24.07 as the current implementation only works for this version. We are going to bump this up.* + - Check [nightly-pipeline.yaml](nightly-pipeline.yaml) for the concrete docker images, specs and commands we use for the benchmark. +- Hardware + - 8x Nvidia A100 GPUs +- Workload: + - Dataset + - ShareGPT dataset + - Prefill-heavy dataset (in average 462 input tokens, 16 tokens as output) + - Decode-heavy dataset (in average 462 input tokens, 256 output tokens) + - Check [nightly-tests.json](tests/nightly-tests.json) for the concrete configuration of datasets we use. + - Models: llama-3 8B, llama-3 70B. + - We do not use llama 3.1 as it is incompatible with trt-llm r24.07. ([issue](https://github.com/NVIDIA/TensorRT-LLM/issues/2105)). + - Average QPS (query per second): 2, 4, 8, 16, 32 and inf. + - Queries are randomly sampled, and arrival patterns are determined via Poisson process, but all with fixed random seed. + - Evaluation metrics: Throughput (higher the better), TTFT (time to the first token, lower the better), ITL (inter-token latency, lower the better). + +# Known issues + +- TRT-LLM crashes with Llama 3.1 8B [issue](https://github.com/NVIDIA/TensorRT-LLM/issues/2105). +- TGI does not support `ignore-eos` flag. \ No newline at end of file diff --git a/.buildkite/nightly-benchmarks/nightly-pipeline.yaml b/.buildkite/nightly-benchmarks/nightly-pipeline.yaml index 6e399bb936fbca5abc4d6c1fbe75b0bafade566c..199517e8b067c05df6375a7bcbc35e5a93ddb838 100644 --- a/.buildkite/nightly-benchmarks/nightly-pipeline.yaml +++ b/.buildkite/nightly-benchmarks/nightly-pipeline.yaml @@ -13,7 +13,7 @@ common_pod_spec: &common_pod_spec common_container_settings: &common_container_settings command: - - bash .buildkite/nightly-benchmarks/run-nightly-suite.sh + - bash .buildkite/nightly-benchmarks/scripts/run-nightly-benchmarks.sh resources: limits: nvidia.com/gpu: 8 @@ -37,7 +37,10 @@ common_container_settings: &common_container_settings steps: - block: ":rocket: Ready for comparing vllm against alternatives? This will take 4 hours." - - label: "A100 trt benchmark" + + + + - label: "A100 vllm step 10" priority: 100 agents: queue: A100 @@ -46,7 +49,21 @@ steps: podSpec: <<: *common_pod_spec containers: - - image: nvcr.io/nvidia/tritonserver:24.04-trtllm-python-py3 + - image: vllm/vllm-openai:v0.6.2 + <<: *common_container_settings + + + + - label: "A100 sglang benchmark" + priority: 100 + agents: + queue: A100 + plugins: + - kubernetes: + podSpec: + <<: *common_pod_spec + containers: + - image: lmsysorg/sglang:v0.3.2-cu121 <<: *common_container_settings - label: "A100 lmdeploy benchmark" @@ -58,11 +75,13 @@ steps: podSpec: <<: *common_pod_spec containers: - - image: openmmlab/lmdeploy:v0.5.0 + - image: openmmlab/lmdeploy:v0.6.1-cu12 <<: *common_container_settings - - - label: "A100 vllm benchmark" + + + + - label: "A100 trt llama-8B" priority: 100 agents: queue: A100 @@ -71,10 +90,25 @@ steps: podSpec: <<: *common_pod_spec containers: - - image: vllm/vllm-openai:latest + - image: nvcr.io/nvidia/tritonserver:24.07-trtllm-python-py3 <<: *common_container_settings + env: + - name: VLLM_USAGE_SOURCE + value: ci-test + - name: HF_HOME + value: /root/.cache/huggingface + - name: VLLM_SOURCE_CODE_LOC + value: /workspace/build/buildkite/vllm/performance-benchmark + - name: HF_TOKEN + valueFrom: + secretKeyRef: + name: hf-token-secret + key: token + - name: TEST_SELECTOR + value: "llama8B" - - label: "A100 tgi benchmark" + + - label: "A100 trt llama-70B" priority: 100 agents: queue: A100 @@ -83,12 +117,54 @@ steps: podSpec: <<: *common_pod_spec containers: - - image: ghcr.io/huggingface/text-generation-inference:2.1 + - image: nvcr.io/nvidia/tritonserver:24.07-trtllm-python-py3 <<: *common_container_settings + env: + - name: VLLM_USAGE_SOURCE + value: ci-test + - name: HF_HOME + value: /root/.cache/huggingface + - name: VLLM_SOURCE_CODE_LOC + value: /workspace/build/buildkite/vllm/performance-benchmark + - name: HF_TOKEN + valueFrom: + secretKeyRef: + name: hf-token-secret + key: token + - name: TEST_SELECTOR + value: "llama70B" + + + # FIXME(Kuntai): uncomment this after NVIDIA gives us their test docker image + # - label: "A100 trt benchmark" + # priority: 100 + # agents: + # queue: A100 + # plugins: + # - kubernetes: + # podSpec: + # <<: *common_pod_spec + # containers: + # - image: nvcr.io/nvidia/tritonserver:24.07-trtllm-python-py3 + # <<: *common_container_settings + + + # FIXME(Kuntai): uncomment this after TGI supports `--ignore-eos`. + # - label: "A100 tgi benchmark" + # priority: 100 + # agents: + # queue: A100 + # plugins: + # - kubernetes: + # podSpec: + # <<: *common_pod_spec + # containers: + # - image: ghcr.io/huggingface/text-generation-inference:2.2.0 + # <<: *common_container_settings - wait - - label: "Plot" + - label: "Collect the results" priority: 100 agents: queue: A100 @@ -117,4 +193,4 @@ steps: name: hf-token-secret key: token - - wait \ No newline at end of file + - block: ":rocket: check the results!" \ No newline at end of file diff --git a/.buildkite/nightly-benchmarks/tests/descriptions.md b/.buildkite/nightly-benchmarks/performance-benchmarks-descriptions.md similarity index 81% rename from .buildkite/nightly-benchmarks/tests/descriptions.md rename to .buildkite/nightly-benchmarks/performance-benchmarks-descriptions.md index 891e4917070d930f33c00e7ef0440bdb44e5bfdf..da32d1f073cea81cfbc5ba77481636a13bd84332 100644 --- a/.buildkite/nightly-benchmarks/tests/descriptions.md +++ b/.buildkite/nightly-benchmarks/performance-benchmarks-descriptions.md @@ -1,47 +1,42 @@ ## Latency tests -This test suite aims to test vllm's end-to-end latency under a controlled setup. - - Input length: 32 tokens. - Output length: 128 tokens. - Batch size: fixed (8). -- Models: llama-3 8B, llama-3 70B, mixtral 8x7B. +- Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B. - Evaluation metrics: end-to-end latency (mean, median, p99). -### Latency benchmarking results {latency_tests_markdown_table} -## Throughput tests -This test suite aims to test vllm's throughput. +## Throughput tests - Input length: randomly sample 200 prompts from ShareGPT dataset (with fixed random seed). - Output length: the corresponding output length of these 200 prompts. - Batch size: dynamically determined by vllm to achieve maximum throughput. -- Models: llama-3 8B, llama-3 70B, mixtral 8x7B. +- Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B. - Evaluation metrics: throughput. -### Throughput benchmarking results {throughput_tests_markdown_table} -## Serving tests -This test suite aims to test vllm's real serving metrics. +## Serving tests - Input length: randomly sample 200 prompts from ShareGPT dataset (with fixed random seed). - Output length: the corresponding output length of these 200 prompts. - Batch size: dynamically determined by vllm and the arrival pattern of the requests. - **Average QPS (query per second)**: 1, 4, 16 and inf. QPS = inf means all requests come at once. For other QPS values, the arrival time of each query is determined using a random Poisson process (with fixed random seed). -- Models: llama-3 8B, llama-3 70B, mixtral 8x7B. +- Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B. +- We also added a speculative decoding test for llama-3 70B, under QPS 2 - Evaluation metrics: throughput, TTFT (time to the first token, with mean, median and p99), ITL (inter-token latency, with mean, median and p99). -### Serving benchmarking results {serving_tests_markdown_table} + ## json version of the benchmarking tables This section contains the data of the markdown tables above in JSON format. diff --git a/.buildkite/nightly-benchmarks/run-nightly-suite.sh b/.buildkite/nightly-benchmarks/run-nightly-suite.sh deleted file mode 100644 index 627a3e6971578019f0555f241e82b0a9116e1ea0..0000000000000000000000000000000000000000 --- a/.buildkite/nightly-benchmarks/run-nightly-suite.sh +++ /dev/null @@ -1,76 +0,0 @@ -#!/bin/bash - -set -o pipefail -set -x - -check_gpus() { - # check the number of GPUs and GPU type. - declare -g gpu_count=$(nvidia-smi --list-gpus | wc -l) - if [[ $gpu_count -gt 0 ]]; then - echo "GPU found." - else - echo "Need at least 1 GPU to run benchmarking." - exit 1 - fi - declare -g gpu_type=$(echo $(nvidia-smi --query-gpu=name --format=csv,noheader) | awk '{print $2}') - echo "GPU type is $gpu_type" -} - -check_hf_token() { - # check if HF_TOKEN is available and valid - if [[ -z "$HF_TOKEN" ]]; then - echo "Error: HF_TOKEN is not set." - exit 1 - elif [[ ! "$HF_TOKEN" =~ ^hf_ ]]; then - echo "Error: HF_TOKEN does not start with 'hf_'." - exit 1 - else - echo "HF_TOKEN is set and valid." - fi -} - -main() { - - check_gpus - check_hf_token - - df -h - - (which wget && which curl) || (apt-get update && apt-get install -y wget curl) - (which jq) || (apt-get update && apt-get -y install jq) - - cd $VLLM_SOURCE_CODE_LOC/benchmarks - wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json - - - # run lmdeploy - if which lmdeploy >/dev/null; then - echo "lmdeploy is available, redirect to run-lmdeploy-nightly.sh" - bash ../.buildkite/nightly-benchmarks/scripts/run-lmdeploy-nightly.sh - exit 0 - fi - - # run tgi - if [ -e /tgi-entrypoint.sh ]; then - echo "tgi is available, redirect to run-tgi-nightly.sh" - bash ../.buildkite/nightly-benchmarks/scripts/run-tgi-nightly.sh - exit 0 - fi - - # run trt - if which trtllm-build >/dev/null; then - echo "trtllm is available, redirect to run-trt-nightly.sh" - bash ../.buildkite/nightly-benchmarks/scripts/run-trt-nightly.sh - exit 0 - fi - - # run vllm - if [ -e /vllm-workspace ]; then - echo "vllm is available, redirect to run-vllm-nightly.sh" - bash ../.buildkite/nightly-benchmarks/scripts/run-vllm-nightly.sh - exit 0 - fi - -} - -main "$@" \ No newline at end of file diff --git a/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py b/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py index 534ecf17930e93d5f5523a35af04c952139c2c88..f90e464288cf1ffeb4fff9716ae8fef31406e1c8 100644 --- a/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py +++ b/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py @@ -174,8 +174,8 @@ if __name__ == "__main__": # document the result with open(results_folder / "benchmark_results.md", "w") as f: - results = read_markdown( - "../.buildkite/nightly-benchmarks/tests/descriptions.md") + results = read_markdown("../.buildkite/nightly-benchmarks/" + + "performance-benchmarks-descriptions.md") results = results.format( latency_tests_markdown_table=latency_md_table, throughput_tests_markdown_table=throughput_md_table, diff --git a/.buildkite/nightly-benchmarks/scripts/generate-nightly-markdown.py b/.buildkite/nightly-benchmarks/scripts/generate-nightly-markdown.py new file mode 100644 index 0000000000000000000000000000000000000000..6059588fe7277c62ca8263136e5282e15ee5888d --- /dev/null +++ b/.buildkite/nightly-benchmarks/scripts/generate-nightly-markdown.py @@ -0,0 +1,95 @@ +import argparse +import json +from pathlib import Path + +import numpy as np +import pandas as pd +from tabulate import tabulate + + +def parse_arguments(): + parser = argparse.ArgumentParser( + description= + 'Parse command line arguments for summary-nightly-results script.') + parser.add_argument('--results-folder', + type=str, + required=True, + help='The folder where the results are stored.') + parser.add_argument('--description', + type=str, + required=True, + help='Description of the results.') + + args = parser.parse_args() + return args + + +def get_perf(df, method, model, metric): + + means = [] + + for qps in [2, 4, 8, 16, "inf"]: + target = df['Test name'].str.contains(model) + target = target & df['Engine'].str.contains(method) + target = target & df['Test name'].str.contains("qps_" + str(qps)) + filtered_df = df[target] + + if filtered_df.empty: + means.append(0.) + else: + means.append(filtered_df[metric].values[0]) + + return np.array(means) + + +def get_perf_w_std(df, method, model, metric): + + if metric in ["TTFT", "ITL"]: + mean = get_perf(df, method, model, "Mean " + metric + " (ms)") + mean = mean.tolist() + std = get_perf(df, method, model, "Std " + metric + " (ms)") + if std.mean() == 0: + std = None + success = get_perf(df, method, model, "Successful req.") + if std is not None: + std = std / np.sqrt(success) + std = std.tolist() + + else: + assert metric == "Tput" + mean = get_perf(df, method, model, "Input Tput (tok/s)") + get_perf( + df, method, model, "Output Tput (tok/s)") + mean = mean.tolist() + std = None + + return mean, std + + +def main(args): + results_folder = Path(args.results_folder) + + results = [] + + # collect results + for test_file in results_folder.glob("*_nightly_results.json"): + with open(test_file, "r") as f: + results = results + json.loads(f.read()) + + # generate markdown table + df = pd.DataFrame.from_dict(results) + + md_table = tabulate(df, headers='keys', tablefmt='pipe', showindex=False) + + with open(args.description, "r") as f: + description = f.read() + + description = description.format( + nightly_results_benchmarking_table=md_table) + + with open("nightly_results.md", "w") as f: + f.write(description) + + +if __name__ == '__main__': + args = parse_arguments() + main(args) diff --git a/.buildkite/nightly-benchmarks/scripts/launch-server.sh b/.buildkite/nightly-benchmarks/scripts/launch-server.sh new file mode 100644 index 0000000000000000000000000000000000000000..e9d7d6a8d760ae3da6d1cf94dec14f6ff297daee --- /dev/null +++ b/.buildkite/nightly-benchmarks/scripts/launch-server.sh @@ -0,0 +1,241 @@ +#!/bin/bash + +# Currently FP8 benchmark is NOT enabled. + +set -x +server_params=$1 +common_params=$2 + +json2args() { + # transforms the JSON string to command line args, and '_' is replaced to '-' + # example: + # input: { "model": "meta-llama/Llama-2-7b-chat-hf", "tensor_parallel_size": 1 } + # output: --model meta-llama/Llama-2-7b-chat-hf --tensor-parallel-size 1 + local json_string=$1 + local args=$( + echo "$json_string" | jq -r ' + to_entries | + map("--" + (.key | gsub("_"; "-")) + " " + (.value | tostring)) | + join(" ") + ' + ) + echo "$args" +} + +launch_trt_server() { + + model_path=$(echo "$common_params" | jq -r '.model') + model_name="${model_path#*/}" + model_type=$(echo "$server_params" | jq -r '.model_type') + model_dtype=$(echo "$server_params" | jq -r '.model_dtype') + model_tp_size=$(echo "$common_params" | jq -r '.tp') + max_batch_size=$(echo "$server_params" | jq -r '.max_batch_size') + max_input_len=$(echo "$server_params" | jq -r '.max_input_len') + max_seq_len=$(echo "$server_params" | jq -r '.max_seq_len') + max_num_tokens=$(echo "$server_params" | jq -r '.max_num_tokens') + trt_llm_version=$(echo "$server_params" | jq -r '.trt_llm_version') + + # create model caching directory + cd ~ + rm -rf models + mkdir -p models + cd models + models_dir=$(pwd) + trt_model_path=${models_dir}/${model_name}-trt-ckpt + trt_engine_path=${models_dir}/${model_name}-trt-engine + + # clone tensorrt backend + cd / + rm -rf tensorrtllm_backend + git clone https://github.com/triton-inference-server/tensorrtllm_backend.git + git lfs install + cd tensorrtllm_backend + git checkout $trt_llm_version + tensorrtllm_backend_dir=$(pwd) + git submodule update --init --recursive + + # build trtllm engine + cd /tensorrtllm_backend + cd ./tensorrt_llm/examples/${model_type} + python3 convert_checkpoint.py \ + --model_dir ${model_path} \ + --dtype ${model_dtype} \ + --tp_size ${model_tp_size} \ + --output_dir ${trt_model_path} + trtllm-build \ + --checkpoint_dir ${trt_model_path} \ + --use_fused_mlp \ + --reduce_fusion disable \ + --workers 8 \ + --gpt_attention_plugin ${model_dtype} \ + --gemm_plugin ${model_dtype} \ + --tp_size ${model_tp_size} \ + --max_batch_size ${max_batch_size} \ + --max_input_len ${max_input_len} \ + --max_seq_len ${max_seq_len} \ + --max_num_tokens ${max_num_tokens} \ + --output_dir ${trt_engine_path} + + # handle triton protobuf files and launch triton server + cd /tensorrtllm_backend + mkdir triton_model_repo + cp -r all_models/inflight_batcher_llm/* triton_model_repo/ + cd triton_model_repo + rm -rf ./tensorrt_llm/1/* + cp -r ${trt_engine_path}/* ./tensorrt_llm/1 + python3 ../tools/fill_template.py -i tensorrt_llm/config.pbtxt triton_backend:tensorrtllm,engine_dir:/tensorrtllm_backend/triton_model_repo/tensorrt_llm/1,decoupled_mode:true,batching_strategy:inflight_fused_batching,batch_scheduler_policy:guaranteed_no_evict,exclude_input_in_output:true,triton_max_batch_size:2048,max_queue_delay_microseconds:0,max_beam_width:1,max_queue_size:2048,enable_kv_cache_reuse:false + python3 ../tools/fill_template.py -i preprocessing/config.pbtxt triton_max_batch_size:2048,tokenizer_dir:$model_path,preprocessing_instance_count:5 + python3 ../tools/fill_template.py -i postprocessing/config.pbtxt triton_max_batch_size:2048,tokenizer_dir:$model_path,postprocessing_instance_count:5,skip_special_tokens:false + python3 ../tools/fill_template.py -i ensemble/config.pbtxt triton_max_batch_size:$max_batch_size + python3 ../tools/fill_template.py -i tensorrt_llm_bls/config.pbtxt triton_max_batch_size:$max_batch_size,decoupled_mode:true,accumulate_tokens:"False",bls_instance_count:1 + cd /tensorrtllm_backend + python3 scripts/launch_triton_server.py \ + --world_size=${model_tp_size} \ + --model_repo=/tensorrtllm_backend/triton_model_repo & + +} + +launch_tgi_server() { + model=$(echo "$common_params" | jq -r '.model') + tp=$(echo "$common_params" | jq -r '.tp') + dataset_name=$(echo "$common_params" | jq -r '.dataset_name') + dataset_path=$(echo "$common_params" | jq -r '.dataset_path') + port=$(echo "$common_params" | jq -r '.port') + num_prompts=$(echo "$common_params" | jq -r '.num_prompts') + server_args=$(json2args "$server_params") + + if echo "$common_params" | jq -e 'has("fp8")' >/dev/null; then + echo "Key 'fp8' exists in common params." + server_command="/tgi-entrypoint.sh \ + --model-id $model \ + --num-shard $tp \ + --port $port \ + --quantize fp8 \ + $server_args" + else + echo "Key 'fp8' does not exist in common params." + server_command="/tgi-entrypoint.sh \ + --model-id $model \ + --num-shard $tp \ + --port $port \ + $server_args" + fi + + echo "Server command: $server_command" + eval "$server_command" & + +} + +launch_lmdeploy_server() { + model=$(echo "$common_params" | jq -r '.model') + tp=$(echo "$common_params" | jq -r '.tp') + dataset_name=$(echo "$common_params" | jq -r '.dataset_name') + dataset_path=$(echo "$common_params" | jq -r '.dataset_path') + port=$(echo "$common_params" | jq -r '.port') + num_prompts=$(echo "$common_params" | jq -r '.num_prompts') + server_args=$(json2args "$server_params") + + server_command="lmdeploy serve api_server $model \ + --tp $tp \ + --server-port $port \ + $server_args" + + # run the server + echo "Server command: $server_command" + bash -c "$server_command" & +} + +launch_sglang_server() { + + model=$(echo "$common_params" | jq -r '.model') + tp=$(echo "$common_params" | jq -r '.tp') + dataset_name=$(echo "$common_params" | jq -r '.dataset_name') + dataset_path=$(echo "$common_params" | jq -r '.dataset_path') + port=$(echo "$common_params" | jq -r '.port') + num_prompts=$(echo "$common_params" | jq -r '.num_prompts') + server_args=$(json2args "$server_params") + + if echo "$common_params" | jq -e 'has("fp8")' >/dev/null; then + echo "Key 'fp8' exists in common params. Use neuralmagic fp8 model for convenience." + model=$(echo "$common_params" | jq -r '.neuralmagic_quantized_model') + server_command="python3 \ + -m sglang.launch_server \ + --tp $tp \ + --model-path $model \ + --port $port \ + $server_args" + else + echo "Key 'fp8' does not exist in common params." + server_command="python3 \ + -m sglang.launch_server \ + --tp $tp \ + --model-path $model \ + --port $port \ + $server_args" + fi + + # run the server + echo "Server command: $server_command" + eval "$server_command" & +} + +launch_vllm_server() { + + export VLLM_HOST_IP=$(hostname -I | awk '{print $1}') + + model=$(echo "$common_params" | jq -r '.model') + tp=$(echo "$common_params" | jq -r '.tp') + dataset_name=$(echo "$common_params" | jq -r '.dataset_name') + dataset_path=$(echo "$common_params" | jq -r '.dataset_path') + port=$(echo "$common_params" | jq -r '.port') + num_prompts=$(echo "$common_params" | jq -r '.num_prompts') + server_args=$(json2args "$server_params") + + if echo "$common_params" | jq -e 'has("fp8")' >/dev/null; then + echo "Key 'fp8' exists in common params. Use neuralmagic fp8 model for convenience." + model=$(echo "$common_params" | jq -r '.neuralmagic_quantized_model') + server_command="python3 \ + -m vllm.entrypoints.openai.api_server \ + -tp $tp \ + --model $model \ + --port $port \ + $server_args" + else + echo "Key 'fp8' does not exist in common params." + server_command="python3 \ + -m vllm.entrypoints.openai.api_server \ + -tp $tp \ + --model $model \ + --port $port \ + $server_args" + fi + + # run the server + echo "Server command: $server_command" + eval "$server_command" & +} + +main() { + + if [[ $CURRENT_LLM_SERVING_ENGINE == "trt" ]]; then + launch_trt_server + fi + + if [[ $CURRENT_LLM_SERVING_ENGINE == "tgi" ]]; then + launch_tgi_server + fi + + if [[ $CURRENT_LLM_SERVING_ENGINE == "lmdeploy" ]]; then + launch_lmdeploy_server + fi + + if [[ $CURRENT_LLM_SERVING_ENGINE == "sglang" ]]; then + launch_sglang_server + fi + + if [[ "$CURRENT_LLM_SERVING_ENGINE" == *"vllm"* ]]; then + launch_vllm_server + fi +} + +main diff --git a/.buildkite/nightly-benchmarks/scripts/launch-trt-server.sh b/.buildkite/nightly-benchmarks/scripts/launch-trt-server.sh deleted file mode 100644 index f8262653a662893c5250446b3f3a5b1416af04f6..0000000000000000000000000000000000000000 --- a/.buildkite/nightly-benchmarks/scripts/launch-trt-server.sh +++ /dev/null @@ -1,102 +0,0 @@ -#!/bin/bash - - -server_params=$1 -common_params=$2 - - - -model_path=$(echo "$common_params" | jq -r '.model') -model_name="${model_path#*/}" -model_type=$(echo "$server_params" | jq -r '.model_type') -model_dtype=$(echo "$server_params" | jq -r '.model_dtype') -model_tp_size=$(echo "$common_params" | jq -r '.tp') -max_batch_size=$(echo "$server_params" | jq -r '.max_batch_size') -max_input_len=$(echo "$server_params" | jq -r '.max_input_len') -max_output_len=$(echo "$server_params" | jq -r '.max_output_len') -trt_llm_version=$(echo "$server_params" | jq -r '.trt_llm_version') - -cd ~ -rm -rf models -mkdir -p models -cd models -models_dir=$(pwd) -trt_model_path=${models_dir}/${model_name}-trt-ckpt -trt_engine_path=${models_dir}/${model_name}-trt-engine - -cd ~ -rm -rf tensorrt-demo -git clone https://github.com/neuralmagic/tensorrt-demo.git -cd tensorrt-demo -tensorrt_demo_dir=$(pwd) - -# make sure the parameter inside tensorrt_demo is consistent to envvar -sed -i.bak "/key: \"tokenizer_dir\"/,/string_value:/s|string_value: \".*\"|string_value: \"$model_path\"|" ./triton_model_repo/postprocessing/config.pbtxt -sed -i.bak "/key: \"tokenizer_dir\"/,/string_value:/s|string_value: \".*\"|string_value: \"$model_path\"|" ./triton_model_repo/preprocessing/config.pbtxt -sed -i.bak "s|\(max_batch_size:\s*\)[0-9]*|\1$max_batch_size|g" ./triton_model_repo/ensemble/config.pbtxt -sed -i.bak "s|\(max_batch_size:\s*\)[0-9]*|\1$max_batch_size|g" ./triton_model_repo/preprocessing/config.pbtxt -sed -i.bak "s|\(max_batch_size:\s*\)[0-9]*|\1$max_batch_size|g" ./triton_model_repo/postprocessing/config.pbtxt -sed -i.bak "s|\(max_batch_size:\s*\)[0-9]*|\1$max_batch_size|g" ./triton_model_repo/tensorrt_llm_bls/config.pbtxt - - -cd / -rm -rf tensorrtllm_backend -git clone https://github.com/triton-inference-server/tensorrtllm_backend.git -git lfs install -cd tensorrtllm_backend -git checkout $trt_llm_version -tensorrtllm_backend_dir=$(pwd) -git submodule update --init --recursive -cp -r ${tensorrt_demo_dir}/triton_model_repo ${tensorrtllm_backend_dir}/ - -cd /tensorrtllm_backend -cd ./tensorrt_llm/examples/${model_type} - - -if echo "$common_params" | jq -e 'has("fp8")' > /dev/null; then - - echo "Key 'fp8' exists in common params. Use quantize.py instead of convert_checkpoint.py" - echo "Reference: https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/llama/README.md" - python ../quantization/quantize.py \ - --model_dir ${model_path} \ - --dtype ${model_dtype} \ - --tp_size ${model_tp_size} \ - --output_dir ${trt_model_path} \ - --qformat fp8 \ - --kv_cache_dtype fp8 \ - --calib_size 2 - -else - - echo "Key 'fp8' does not exist in common params. Use convert_checkpoint.py" - python3 convert_checkpoint.py \ - --model_dir ${model_path} \ - --dtype ${model_dtype} \ - --tp_size ${model_tp_size} \ - --output_dir ${trt_model_path} - -fi - - - -trtllm-build \ ---checkpoint_dir=${trt_model_path} \ ---gpt_attention_plugin=${model_dtype} \ ---gemm_plugin=${model_dtype} \ ---remove_input_padding=enable \ ---paged_kv_cache=enable \ ---tp_size=${model_tp_size} \ ---max_batch_size=${max_batch_size} \ ---max_input_len=${max_input_len} \ ---max_output_len=${max_output_len} \ ---max_num_tokens=${max_output_len} \ ---opt_num_tokens=${max_output_len} \ ---output_dir=${trt_engine_path} - -cd /tensorrtllm_backend/triton_model_repo -rm -rf ./tensorrt_llm/1/* -cp -r ${trt_engine_path}/* ./tensorrt_llm/1 -cd /tensorrtllm_backend -python3 scripts/launch_triton_server.py \ ---world_size=${model_tp_size} \ ---model_repo=/tensorrtllm_backend/triton_model_repo & \ No newline at end of file diff --git a/.buildkite/nightly-benchmarks/scripts/nightly-annotate.sh b/.buildkite/nightly-benchmarks/scripts/nightly-annotate.sh index 1168912c6e229ecbc7eb3837bbf188756d0a2467..c6a1bbdeb7d48b14769c1d41218b9cc4dfd0da3d 100644 --- a/.buildkite/nightly-benchmarks/scripts/nightly-annotate.sh +++ b/.buildkite/nightly-benchmarks/scripts/nightly-annotate.sh @@ -8,6 +8,7 @@ main() { (which wget && which curl) || (apt-get update && apt-get install -y wget curl) (which jq) || (apt-get update && apt-get -y install jq) + (which zip) || (apt-get install -y zip) if [ ! -f /workspace/buildkite-agent ]; then echo "buildkite-agent binary not found. Skip plotting the results." @@ -24,17 +25,54 @@ main() { ls ls results/ - # generate figures - python3 -m pip install tabulate pandas matplotlib - python3 $VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/scripts/plot-nightly-results.py \ - --description $description \ - --results-folder results/ + # upload benchmark results + zip -r results.zip results/ + /workspace/buildkite-agent artifact upload "results.zip" + + # upload benchmarking scripts + cd $VLLM_SOURCE_CODE_LOC/ + zip -r nightly-benchmarks.zip .buildkite/ benchmarks/ + /workspace/buildkite-agent artifact upload "nightly-benchmarks.zip" + + cd $VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/ + # upload benchmarking pipeline + /workspace/buildkite-agent artifact upload "nightly-pipeline.yaml" + + cd $VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/ + /workspace/buildkite-agent annotate --style "success" --context "nightly-benchmarks-results" --append < nightly-annotation.md + + + + # The figures should be genereated by a separate process outside the CI/CD pipeline + + # # generate figures + # python3 -m pip install tabulate pandas matplotlib + + # python3 $VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/scripts/generate-nightly-markdown.py \ + # --description $description \ + # --results-folder results/ + + + # python3 $VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/scripts/plot-nightly-results.py \ + # --description $description \ + # --results-folder results/ \ + # --dataset sharegpt + + # python3 $VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/scripts/plot-nightly-results.py \ + # --description $description \ + # --results-folder results/ \ + # --dataset sonnet_2048_128 + + # python3 $VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/scripts/plot-nightly-results.py \ + # --description $description \ + # --results-folder results/ \ + # --dataset sonnet_128_2048 - # upload results and figures - /workspace/buildkite-agent artifact upload "nightly_results.png" - /workspace/buildkite-agent artifact upload $VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/nightly-pipeline.yaml - /workspace/buildkite-agent artifact upload $VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/tests/nightly-tests.json - /workspace/buildkite-agent annotate --style "success" --context "nightly-benchmarks-results" --append < nightly_results.md + # # upload results and figures + # /workspace/buildkite-agent artifact upload "nightly_results*.png" + # /workspace/buildkite-agent artifact upload $VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/nightly-pipeline.yaml + # /workspace/buildkite-agent artifact upload $VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/tests/nightly-tests.json + # /workspace/buildkite-agent annotate --style "success" --context "nightly-benchmarks-results" --append < nightly_results.md } main "$@" \ No newline at end of file diff --git a/.buildkite/nightly-benchmarks/scripts/plot-nightly-results.py b/.buildkite/nightly-benchmarks/scripts/plot-nightly-results.py deleted file mode 100644 index e5cfcc64a9b2a5c860eb21bc4e7393506fb49fa1..0000000000000000000000000000000000000000 --- a/.buildkite/nightly-benchmarks/scripts/plot-nightly-results.py +++ /dev/null @@ -1,135 +0,0 @@ -import argparse -import json -import math -from pathlib import Path - -import matplotlib.pyplot as plt -import pandas as pd -from tabulate import tabulate - - -def parse_arguments(): - parser = argparse.ArgumentParser( - description= - 'Parse command line arguments for summary-nightly-results script.') - parser.add_argument('--results-folder', - type=str, - required=True, - help='The folder where the results are stored.') - parser.add_argument('--description', - type=str, - required=True, - help='Description of the results.') - - args = parser.parse_args() - return args - - -def main(args): - bar_colors = ['#56B4E9', '#009E73', '#D55E00', '#E69F00'] - results_folder = Path(args.results_folder) - - results = [] - - # collect results - for test_file in results_folder.glob("*_nightly_results.json"): - with open(test_file, "r") as f: - results = results + json.loads(f.read()) - - # generate markdown table - df = pd.DataFrame.from_dict(results) - - md_table = tabulate(df, headers='keys', tablefmt='pipe', showindex=False) - - with open(args.description, "r") as f: - description = f.read() - - description = description.format( - nightly_results_benchmarking_table=md_table) - - with open("nightly_results.md", "w") as f: - f.write(description) - - plt.rcParams.update({'font.size': 20}) - - # plot results - fig, axes = plt.subplots(3, 3, figsize=(16, 14)) - fig.subplots_adjust(hspace=1) - methods = ["vllm", "trt", "lmdeploy", "tgi"] - for i, model in enumerate(["llama8B", "llama70B", "mixtral8x7B"]): - for j, metric in enumerate(["TTFT", "ITL"]): - means, stds = [], [] - for method in methods: - target = df['Test name'].str.contains(model) - target = target & df['Engine'].str.contains(method) - filtered_df = df[target] - - if filtered_df.empty: - means.append(0.) - stds.append(0.) - else: - means.append(filtered_df[f"Mean {metric} (ms)"].values[0]) - std = filtered_df[f"Std {metric} (ms)"].values[0] - success = filtered_df["Successful req."].values[0] - stds.append(std / math.sqrt(success)) - - print(model, metric) - print(means, stds) - - ax = axes[i, j + 1] - - bars = ax.bar( - ["vllm", "trt", "lmdeploy", "tgi"], - means, - yerr=stds, - capsize=10, - ) - for idx, bar in enumerate(bars): - bar.set_color(bar_colors[idx]) - ax.set_ylim(bottom=0) - - ax.set_ylabel(f"{metric} (ms)") - ax.set_title(f"{model} {metric}") - ax.grid(axis='y') - - metric = "Tput" - j = 0 - if True: - tputs = [] - for method in methods: - target = df['Test name'].str.contains(model) - target = target & df['Engine'].str.contains(method) - filtered_df = df[target] - - if filtered_df.empty: - tputs.append(0.) - else: - input_tput = filtered_df["Input Tput (tok/s)"].values[0] - output_tput = filtered_df["Output Tput (tok/s)"].values[0] - tputs.append(input_tput + output_tput) - - print(model, metric) - print(tputs) - - ax = axes[i, j] - - bars = ax.bar( - ["vllm", "trt", "lmdeploy", "tgi"], - tputs, - ) - for idx, bar in enumerate(bars): - bar.set_color(bar_colors[idx]) - - ax.set_ylim(bottom=0) - - ax.set_ylabel("Tput (token/s)") - ax.set_title(f"{model} {metric}") - ax.grid(axis='y') - - fig.tight_layout() - fig.savefig("nightly_results.png", bbox_inches='tight', dpi=400) - - -if __name__ == '__main__': - args = parse_arguments() - main(args) diff --git a/.buildkite/nightly-benchmarks/scripts/run-lmdeploy-nightly.sh b/.buildkite/nightly-benchmarks/scripts/run-lmdeploy-nightly.sh deleted file mode 100644 index d6f112aaa42fd3a9e64c583aefaaa4bc6e4e98f2..0000000000000000000000000000000000000000 --- a/.buildkite/nightly-benchmarks/scripts/run-lmdeploy-nightly.sh +++ /dev/null @@ -1,218 +0,0 @@ -#!/bin/bash - -set -o pipefail - -check_gpus() { - # check the number of GPUs and GPU type. - declare -g gpu_count=$(nvidia-smi --list-gpus | wc -l) - if [[ $gpu_count -gt 0 ]]; then - echo "GPU found." - else - echo "Need at least 1 GPU to run benchmarking." - exit 1 - fi - declare -g gpu_type=$(echo $(nvidia-smi --query-gpu=name --format=csv,noheader) | awk '{print $2}') - echo "GPU type is $gpu_type" -} - -kill_gpu_processes() { - pkill lmdeploy || true - # waiting for GPU processes to be fully killed - sleep 10 - # Print the GPU memory usage - # so that we know if all GPU processes are killed. - gpu_memory_usage=$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits -i 0) - # The memory usage should be 0 MB. - echo "GPU 0 Memory Usage: $gpu_memory_usage MB" -} - -json2args() { - # transforms the JSON string to command line args, and '_' is replaced to '-' - # example: - # input: { "model": "meta-llama/Llama-2-7b-chat-hf", "tensor_parallel_size": 1 } - # output: --model meta-llama/Llama-2-7b-chat-hf --tensor-parallel-size 1 - local json_string=$1 - local args=$( - echo "$json_string" | jq -r ' - to_entries | - map("--" + (.key | gsub("_"; "-")) + " " + (.value | tostring)) | - join(" ") - ' - ) - echo "$args" -} - -wait_for_server() { - # wait for vllm server to start - # return 1 if vllm server crashes - timeout 1200 bash -c ' - until curl -s localhost:8000/v1/completions > /dev/null; do - sleep 1 - done' && return 0 || return 1 -} - -run_serving_tests() { - # run serving tests using `benchmark_serving.py` - # $1: a json file specifying serving test cases - - local serving_test_file - serving_test_file=$1 - - # Iterate over serving tests - jq -c '.[]' "$serving_test_file" | while read -r params; do - # get the test name, and append the GPU type back to it. - test_name=$(echo "$params" | jq -r '.test_name') - - # if TEST_SELECTOR is set, only run the test cases that match the selector - if [[ -n "$TEST_SELECTOR" ]] && [[ ! "$test_name" =~ $TEST_SELECTOR ]]; then - echo "Skip test case $test_name." - continue - fi - - # append lmdeploy to the test name - test_name=lmdeploy_$test_name - - # get common parameters - common_params=$(echo "$params" | jq -r '.common_parameters') - model=$(echo "$common_params" | jq -r '.model') - tp=$(echo "$common_params" | jq -r '.tp') - dataset_name=$(echo "$common_params" | jq -r '.dataset_name') - dataset_path=$(echo "$common_params" | jq -r '.dataset_path') - port=$(echo "$common_params" | jq -r '.port') - num_prompts=$(echo "$common_params" | jq -r '.num_prompts') - - - - # get client and server arguments - server_params=$(echo "$params" | jq -r '.lmdeploy_server_parameters') - client_params=$(echo "$params" | jq -r '.lmdeploy_client_parameters') - server_args=$(json2args "$server_params") - client_args=$(json2args "$client_params") - qps_list=$(echo "$params" | jq -r '.qps_list') - qps_list=$(echo "$qps_list" | jq -r '.[] | @sh') - echo "Running over qps list $qps_list" - - # check if there is enough GPU to run the test - if [[ $gpu_count -lt $tp ]]; then - echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name." - continue - fi - - # prepare tokenizer - rm -rf /tokenizer_cache - mkdir /tokenizer_cache - python ../.buildkite/nightly-benchmarks/scripts/download-tokenizer.py \ - --model "$model" \ - --cachedir /tokenizer_cache - - server_command="lmdeploy serve api_server $model \ - --tp $tp \ - --server-port $port \ - $server_args" - - # run the server - echo "Running test case $test_name" - echo "Server command: $server_command" - bash -c "$server_command" & - - # wait until the server is alive - wait_for_server - if [ $? -eq 0 ]; then - echo "" - echo "lmdeploy server is up and running." - else - echo "" - echo "lmdeploy failed to start within the timeout period." - break - fi - - # get model name - model_name=$(python ../.buildkite/nightly-benchmarks/scripts/get-lmdeploy-modelname.py) - - # iterate over different QPS - for qps in $qps_list; do - # remove the surrounding single quote from qps - if [[ "$qps" == *"inf"* ]]; then - echo "qps was $qps" - qps="inf" - echo "now qps is $qps" - fi - - new_test_name=$test_name"_qps_"$qps - - client_command="python3 benchmark_serving.py \ - --backend lmdeploy \ - --tokenizer /tokenizer_cache \ - --dataset-name $dataset_name \ - --dataset-path $dataset_path \ - --num-prompts $num_prompts \ - --port $port \ - --save-result \ - --result-dir $RESULTS_FOLDER \ - --result-filename ${new_test_name}.json \ - --request-rate $qps \ - --model \"$model_name\" \ - $client_args" - - echo "Running test case $test_name with qps $qps" - echo "Client command: $client_command" - - eval "$client_command" - - # record the benchmarking commands - jq_output=$(jq -n \ - --arg server "$server_command" \ - --arg client "$client_command" \ - --arg gpu "$gpu_type" \ - --arg engine "lmdeploy" \ - '{ - server_command: $server, - client_command: $client, - gpu_type: $gpu, - engine: $engine - }') - echo "$jq_output" >"$RESULTS_FOLDER/${new_test_name}.commands" - - done - - # clean up - kill_gpu_processes - rm -rf /root/.cache/huggingface/* - done -} - - -upload_to_buildkite() { - # upload the benchmarking results to buildkite - - # if the agent binary is not found, skip uploading the results, exit 0 - if [ ! -f /workspace/buildkite-agent ]; then - echo "buildkite-agent binary not found. Skip uploading the results." - return 0 - fi - # /workspace/buildkite-agent annotate --style "success" --context "benchmark-results" --append < $RESULTS_FOLDER/${CURRENT_LLM_SERVING_ENGINE}_nightly_results.md - /workspace/buildkite-agent artifact upload "$RESULTS_FOLDER/*" -} - - -main() { - - check_gpus - # enter vllm directory - cd $VLLM_SOURCE_CODE_LOC/benchmarks - - declare -g RESULTS_FOLDER=results/ - mkdir -p $RESULTS_FOLDER - BENCHMARK_ROOT=../.buildkite/nightly-benchmarks/ - - python -m pip install transformers==4.41.2 - - export CURRENT_LLM_SERVING_ENGINE=lmdeploy - run_serving_tests $BENCHMARK_ROOT/tests/nightly-tests.json - python -m pip install tabulate pandas - python $BENCHMARK_ROOT/scripts/summary-nightly-results.py - upload_to_buildkite - -} - -main "$@" diff --git a/.buildkite/nightly-benchmarks/scripts/run-nightly-benchmarks.sh b/.buildkite/nightly-benchmarks/scripts/run-nightly-benchmarks.sh new file mode 100644 index 0000000000000000000000000000000000000000..dd8c15e0700ebc11996f5eff589a5ff509f42546 --- /dev/null +++ b/.buildkite/nightly-benchmarks/scripts/run-nightly-benchmarks.sh @@ -0,0 +1,357 @@ +#!/bin/bash + +set -o pipefail +set -x + +check_gpus() { + # check the number of GPUs and GPU type. + declare -g gpu_count=$(nvidia-smi --list-gpus | wc -l) + if [[ $gpu_count -gt 0 ]]; then + echo "GPU found." + else + echo "Need at least 1 GPU to run benchmarking." + exit 1 + fi + declare -g gpu_type=$(echo $(nvidia-smi --query-gpu=name --format=csv,noheader) | awk '{print $2}') + echo "GPU type is $gpu_type" +} + +check_hf_token() { + # check if HF_TOKEN is available and valid + if [[ -z "$HF_TOKEN" ]]; then + echo "Error: HF_TOKEN is not set." + exit 1 + elif [[ ! "$HF_TOKEN" =~ ^hf_ ]]; then + echo "Error: HF_TOKEN does not start with 'hf_'." + exit 1 + else + echo "HF_TOKEN is set and valid." + fi +} + + +upload_to_buildkite() { + # upload the benchmarking results to buildkite + + # if the agent binary is not found, skip uploading the results, exit 0 + if [ ! -f /workspace/buildkite-agent ]; then + echo "buildkite-agent binary not found. Skip uploading the results." + return 0 + fi + # /workspace/buildkite-agent annotate --style "success" --context "benchmark-results" --append < $RESULTS_FOLDER/${CURRENT_LLM_SERVING_ENGINE}_nightly_results.md + /workspace/buildkite-agent artifact upload "$RESULTS_FOLDER/*" +} + + +get_current_llm_serving_engine() { + + if which lmdeploy >/dev/null; then + echo "Container: lmdeploy" + export CURRENT_LLM_SERVING_ENGINE=lmdeploy + return + fi + + if [ -e /tgi-entrypoint.sh ]; then + echo "Container: tgi" + export CURRENT_LLM_SERVING_ENGINE=tgi + return + fi + + if which trtllm-build >/dev/null; then + echo "Container: tensorrt-llm" + export CURRENT_LLM_SERVING_ENGINE=trt + return + fi + + if [ -e /sgl-workspace ]; then + echo "Container: sglang" + export CURRENT_LLM_SERVING_ENGINE=sglang + return + fi + + if [ -e /vllm-workspace ]; then + echo "Container: vllm" + # move to a completely irrelevant directory, to avoid import vllm from current folder + export CURRENT_LLM_SERVING_ENGINE=vllm + + return + fi +} + +json2args() { + # transforms the JSON string to command line args, and '_' is replaced to '-' + # example: + # input: { "model": "meta-llama/Llama-2-7b-chat-hf", "tensor_parallel_size": 1 } + # output: --model meta-llama/Llama-2-7b-chat-hf --tensor-parallel-size 1 + local json_string=$1 + local args=$( + echo "$json_string" | jq -r ' + to_entries | + map("--" + (.key | gsub("_"; "-")) + " " + (.value | tostring)) | + join(" ") + ' + ) + echo "$args" +} + +kill_gpu_processes() { + pkill -f python + pkill -f python3 + pkill -f tritonserver + pkill -f pt_main_thread + pkill -f text-generation + pkill -f lmdeploy + + while [ $(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits | head -n 1) -ge 1000 ]; do + sleep 1 + done +} + +wait_for_server() { + # wait for vllm server to start + # return 1 if vllm server crashes + timeout 1200 bash -c ' + until curl -s localhost:8000/v1/completions > /dev/null; do + sleep 1 + done' && return 0 || return 1 +} + +ensure_installed() { + # Ensure that the given command is installed by apt-get + local cmd=$1 + if ! which $cmd >/dev/null; then + apt-get update && apt-get install -y $cmd + fi +} + +run_serving_tests() { + # run serving tests using `benchmark_serving.py` + # $1: a json file specifying serving test cases + + local serving_test_file + serving_test_file=$1 + + # Iterate over serving tests + jq -c '.[]' "$serving_test_file" | while read -r params; do + # get the test name, and append the GPU type back to it. + test_name=$(echo "$params" | jq -r '.test_name') + + # if TEST_SELECTOR is set, only run the test cases that match the selector + if [[ -n "$TEST_SELECTOR" ]] && [[ ! "$test_name" =~ $TEST_SELECTOR ]]; then + echo "Skip test case $test_name." + continue + fi + + # prepend the current serving engine to the test name + test_name=${CURRENT_LLM_SERVING_ENGINE}_${test_name} + + # get common parameters + common_params=$(echo "$params" | jq -r '.common_parameters') + model=$(echo "$common_params" | jq -r '.model') + tp=$(echo "$common_params" | jq -r '.tp') + dataset_name=$(echo "$common_params" | jq -r '.dataset_name') + dataset_path=$(echo "$common_params" | jq -r '.dataset_path') + port=$(echo "$common_params" | jq -r '.port') + num_prompts=$(echo "$common_params" | jq -r '.num_prompts') + reuse_server=$(echo "$common_params" | jq -r '.reuse_server') + + # get client and server arguments + server_params=$(echo "$params" | jq -r ".${CURRENT_LLM_SERVING_ENGINE}_server_parameters") + client_params=$(echo "$params" | jq -r ".${CURRENT_LLM_SERVING_ENGINE}_client_parameters") + client_args=$(json2args "$client_params") + qps_list=$(echo "$params" | jq -r '.qps_list') + qps_list=$(echo "$qps_list" | jq -r '.[] | @sh') + echo "Running over qps list $qps_list" + + # check if there is enough GPU to run the test + if [[ $gpu_count -lt $tp ]]; then + echo "Required num-shard $tp but only $gpu_count GPU found. Skip testcase $test_name." + continue + fi + + if [[ $reuse_server == "true" ]]; then + echo "Reuse previous server for test case $test_name" + else + kill_gpu_processes + bash $VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/scripts/launch-server.sh \ + "$server_params" "$common_params" + fi + + wait_for_server + + if [ $? -eq 0 ]; then + echo "" + echo "$CURRENT_LLM_SERVING_ENGINE server is up and running." + else + echo "" + echo "$CURRENT_LLM_SERVING_ENGINE failed to start within the timeout period." + break + fi + + # prepare tokenizer + # this is required for lmdeploy. + cd $VLLM_SOURCE_CODE_LOC/benchmarks + rm -rf /tokenizer_cache + mkdir /tokenizer_cache + python3 ../.buildkite/nightly-benchmarks/scripts/download-tokenizer.py \ + --model "$model" \ + --cachedir /tokenizer_cache + cd $VLLM_SOURCE_CODE_LOC/benchmarks + + + # change model name for lmdeploy (it will not follow standard hf name) + if [[ "$CURRENT_LLM_SERVING_ENGINE" == "lmdeploy" ]]; then + model=$(python ../.buildkite/nightly-benchmarks/scripts/get-lmdeploy-modelname.py) + fi + + # iterate over different QPS + for qps in $qps_list; do + # remove the surrounding single quote from qps + if [[ "$qps" == *"inf"* ]]; then + echo "qps was $qps" + qps="inf" + echo "now qps is $qps" + fi + + new_test_name=$test_name"_qps_"$qps + + backend=$CURRENT_LLM_SERVING_ENGINE + + if [[ $backend = "trt" ]]; then + backend="tensorrt-llm" + fi + + if [[ "$backend" == *"vllm"* ]]; then + backend="vllm" + fi + + if [[ "$dataset_name" = "sharegpt" ]]; then + + client_command="python3 benchmark_serving.py \ + --backend $backend \ + --tokenizer /tokenizer_cache \ + --model $model \ + --dataset-name $dataset_name \ + --dataset-path $dataset_path \ + --num-prompts $num_prompts \ + --port $port \ + --save-result \ + --result-dir $RESULTS_FOLDER \ + --result-filename ${new_test_name}.json \ + --request-rate $qps \ + --ignore-eos \ + $client_args" + + elif [[ "$dataset_name" = "sonnet" ]]; then + + sonnet_input_len=$(echo "$common_params" | jq -r '.sonnet_input_len') + sonnet_output_len=$(echo "$common_params" | jq -r '.sonnet_output_len') + sonnet_prefix_len=$(echo "$common_params" | jq -r '.sonnet_prefix_len') + + client_command="python3 benchmark_serving.py \ + --backend $backend \ + --tokenizer /tokenizer_cache \ + --model $model \ + --dataset-name $dataset_name \ + --dataset-path $dataset_path \ + --num-prompts $num_prompts \ + --sonnet-input-len $sonnet_input_len \ + --sonnet-output-len $sonnet_output_len \ + --sonnet-prefix-len $sonnet_prefix_len \ + --port $port \ + --save-result \ + --result-dir $RESULTS_FOLDER \ + --result-filename ${new_test_name}.json \ + --request-rate $qps \ + --ignore-eos \ + $client_args" + + else + + echo "The dataset name must be either 'sharegpt' or 'sonnet'. Got $dataset_name." + exit 1 + + fi + + + + echo "Running test case $test_name with qps $qps" + echo "Client command: $client_command" + + eval "$client_command" + + server_command="None" + + # record the benchmarking commands + jq_output=$(jq -n \ + --arg server "$server_command" \ + --arg client "$client_command" \ + --arg gpu "$gpu_type" \ + --arg engine "$CURRENT_LLM_SERVING_ENGINE" \ + '{ + server_command: $server, + client_command: $client, + gpu_type: $gpu, + engine: $engine + }') + echo "$jq_output" >"$RESULTS_FOLDER/${new_test_name}.commands" + + done + + done + + kill_gpu_processes +} + + +prepare_dataset() { + + # download sharegpt dataset + cd $VLLM_SOURCE_CODE_LOC/benchmarks + wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json + + # duplicate sonnet by 4x, to allow benchmarking with input length 2048 + cd $VLLM_SOURCE_CODE_LOC/benchmarks + echo "" > sonnet_4x.txt + for _ in {1..4} + do + cat sonnet.txt >> sonnet_4x.txt + done + +} + +main() { + + # check if the environment variable is successfully injected from yaml + + check_gpus + check_hf_token + get_current_llm_serving_engine + + pip install -U transformers + + # check storage + df -h + + ensure_installed wget + ensure_installed curl + ensure_installed jq + + prepare_dataset + + cd $VLLM_SOURCE_CODE_LOC/benchmarks + declare -g RESULTS_FOLDER=results/ + mkdir -p $RESULTS_FOLDER + BENCHMARK_ROOT=$VLLM_SOURCE_CODE_LOC/.buildkite/nightly-benchmarks/ + + # run the test + run_serving_tests $BENCHMARK_ROOT/tests/nightly-tests.json + + # upload benchmark results to buildkite + python3 -m pip install tabulate pandas + python3 $BENCHMARK_ROOT/scripts/summary-nightly-results.py + upload_to_buildkite + +} + +main "$@" diff --git a/.buildkite/nightly-benchmarks/run-benchmarks-suite.sh b/.buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh similarity index 89% rename from .buildkite/nightly-benchmarks/run-benchmarks-suite.sh rename to .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh index 1a88d038b4b527ecc2e3ae1c8089d95fbd7e8d33..a0b9a409b758d121d67930b0c9f669160564745f 100644 --- a/.buildkite/nightly-benchmarks/run-benchmarks-suite.sh +++ b/.buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh @@ -37,9 +37,9 @@ check_hf_token() { ensure_sharegpt_downloaded() { local FILE=ShareGPT_V3_unfiltered_cleaned_split.json if [ ! -f "$FILE" ]; then - wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/$FILE + wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/$FILE else - echo "$FILE already exists." + echo "$FILE already exists." fi } @@ -68,35 +68,38 @@ wait_for_server() { done' && return 0 || return 1 } -kill_gpu_processes() { - # kill all processes on GPU. - pids=$(nvidia-smi --query-compute-apps=pid --format=csv,noheader) - if [ -z "$pids" ]; then - echo "No GPU processes found." +kill_processes_launched_by_current_bash() { + # Kill all python processes launched from current bash script + current_shell_pid=$$ + processes=$(ps -eo pid,ppid,command | awk -v ppid="$current_shell_pid" -v proc="$1" '$2 == ppid && $3 ~ proc {print $1}') + if [ -n "$processes" ]; then + echo "Killing the following processes matching '$1':" + echo "$processes" + echo "$processes" | xargs kill -9 else - for pid in $pids; do - kill -9 "$pid" - echo "Killed process with PID: $pid" - done - - echo "All GPU processes have been killed." + echo "No processes found matching '$1'." fi +} + +kill_gpu_processes() { - # waiting for GPU processes to be fully killed - # loop while nvidia-smi returns any processes - while [ -n "$(nvidia-smi --query-compute-apps=pid --format=csv,noheader)" ]; do + ps -aux + lsof -t -i:8000 | xargs -r kill -9 + pkill -f pt_main_thread + # this line doesn't work now + # ps aux | grep python | grep openai | awk '{print $2}' | xargs -r kill -9 + pkill -f python3 + pkill -f /usr/bin/python3 + + + # wait until GPU memory usage smaller than 1GB + while [ $(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits | head -n 1) -ge 1000 ]; do sleep 1 - echo "Waiting for GPU processes to be killed" done # remove vllm config file rm -rf ~/.config/vllm - # Print the GPU memory usage - # so that we know if all GPU processes are killed. - gpu_memory_usage=$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits -i 0) - # The memory usage should be 0 MB. - echo "GPU 0 Memory Usage: $gpu_memory_usage MB" } upload_to_buildkite() { @@ -114,7 +117,7 @@ upload_to_buildkite() { fi # Use the determined command to annotate and upload artifacts - $BUILDKITE_AGENT_COMMAND annotate --style "info" --context "$BUILDKITE_LABEL-benchmark-results" < $RESULTS_FOLDER/benchmark_results.md + $BUILDKITE_AGENT_COMMAND annotate --style "info" --context "$BUILDKITE_LABEL-benchmark-results" <$RESULTS_FOLDER/benchmark_results.md $BUILDKITE_AGENT_COMMAND artifact upload "$RESULTS_FOLDER/*" } @@ -166,7 +169,7 @@ run_latency_tests() { latency_command: $latency, gpu_type: $gpu }') - echo "$jq_output" > "$RESULTS_FOLDER/$test_name.commands" + echo "$jq_output" >"$RESULTS_FOLDER/$test_name.commands" # run the benchmark eval "$latency_command" @@ -176,7 +179,6 @@ run_latency_tests() { done } - run_throughput_tests() { # run throughput tests using `benchmark_throughput.py` # $1: a json file specifying throughput test cases @@ -224,7 +226,7 @@ run_throughput_tests() { throughput_command: $command, gpu_type: $gpu }') - echo "$jq_output" > "$RESULTS_FOLDER/$test_name.commands" + echo "$jq_output" >"$RESULTS_FOLDER/$test_name.commands" # run the benchmark eval "$throughput_command" @@ -256,7 +258,6 @@ run_serving_tests() { continue fi - # get client and server arguments server_params=$(echo "$params" | jq -r '.server_parameters') client_params=$(echo "$params" | jq -r '.client_parameters') @@ -334,7 +335,7 @@ run_serving_tests() { client_command: $client, gpu_type: $gpu }') - echo "$jq_output" > "$RESULTS_FOLDER/${new_test_name}.commands" + echo "$jq_output" >"$RESULTS_FOLDER/${new_test_name}.commands" done @@ -351,6 +352,7 @@ main() { # dependencies (which wget && which curl) || (apt-get update && apt-get install -y wget curl) (which jq) || (apt-get update && apt-get -y install jq) + (which lsof) || (apt-get update && apt-get install -y lsof) # get the current IP address, required by benchmark_serving.py export VLLM_HOST_IP=$(hostname -I | awk '{print $1}') @@ -369,7 +371,6 @@ main() { run_latency_tests $QUICK_BENCHMARK_ROOT/tests/latency-tests.json run_throughput_tests $QUICK_BENCHMARK_ROOT/tests/throughput-tests.json - # postprocess benchmarking results pip install tabulate pandas python3 $QUICK_BENCHMARK_ROOT/scripts/convert-results-json-to-markdown.py diff --git a/.buildkite/nightly-benchmarks/scripts/run-tgi-nightly.sh b/.buildkite/nightly-benchmarks/scripts/run-tgi-nightly.sh deleted file mode 100644 index fed03654f8b77c43130677deaf11abf3a1cdd387..0000000000000000000000000000000000000000 --- a/.buildkite/nightly-benchmarks/scripts/run-tgi-nightly.sh +++ /dev/null @@ -1,216 +0,0 @@ -#!/bin/bash - -set -o pipefail - -check_gpus() { - # check the number of GPUs and GPU type. - declare -g gpu_count=$(nvidia-smi --list-gpus | wc -l) - if [[ $gpu_count -gt 0 ]]; then - echo "GPU found." - else - echo "Need at least 1 GPU to run benchmarking." - exit 1 - fi - declare -g gpu_type=$(echo $(nvidia-smi --query-gpu=name --format=csv,noheader) | awk '{print $2}') - echo "GPU type is $gpu_type" -} - -kill_gpu_processes() { - pkill text-generation || true - # waiting for GPU processes to be fully killed - sleep 10 - # Print the GPU memory usage - # so that we know if all GPU processes are killed. - gpu_memory_usage=$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits -i 0) - # The memory usage should be 0 MB. - echo "GPU 0 Memory Usage: $gpu_memory_usage MB" -} - -json2args() { - # transforms the JSON string to command line args, and '_' is replaced to '-' - # example: - # input: { "model": "meta-llama/Llama-2-7b-chat-hf", "tensor_parallel_size": 1 } - # output: --model meta-llama/Llama-2-7b-chat-hf --tensor-parallel-size 1 - local json_string=$1 - local args=$( - echo "$json_string" | jq -r ' - to_entries | - map("--" + (.key | gsub("_"; "-")) + " " + (.value | tostring)) | - join(" ") - ' - ) - echo "$args" -} - -wait_for_server() { - timeout 1200 bash -c ' - until curl -s localhost:8000/generate_stream > /dev/null; do - sleep 1 - done' && return 0 || return 1 -} - -run_serving_tests() { - # run serving tests using `benchmark_serving.py` - # $1: a json file specifying serving test cases - - local serving_test_file - serving_test_file=$1 - - # Iterate over serving tests - jq -c '.[]' "$serving_test_file" | while read -r params; do - # get the test name, and append the GPU type back to it. - test_name=$(echo "$params" | jq -r '.test_name') - - - # if TEST_SELECTOR is set, only run the test cases that match the selector - if [[ -n "$TEST_SELECTOR" ]] && [[ ! "$test_name" =~ $TEST_SELECTOR ]]; then - echo "Skip test case $test_name." - continue - fi - - # append tgi to the test name - test_name=tgi_$test_name - - # get common parameters - common_params=$(echo "$params" | jq -r '.common_parameters') - model=$(echo "$common_params" | jq -r '.model') - tp=$(echo "$common_params" | jq -r '.tp') - dataset_name=$(echo "$common_params" | jq -r '.dataset_name') - dataset_path=$(echo "$common_params" | jq -r '.dataset_path') - port=$(echo "$common_params" | jq -r '.port') - num_prompts=$(echo "$common_params" | jq -r '.num_prompts') - - # get client and server arguments - server_params=$(echo "$params" | jq -r '.tgi_server_parameters') - client_params=$(echo "$params" | jq -r '.tgi_client_parameters') - server_args=$(json2args "$server_params") - client_args=$(json2args "$client_params") - qps_list=$(echo "$params" | jq -r '.qps_list') - qps_list=$(echo "$qps_list" | jq -r '.[] | @sh') - echo "Running over qps list $qps_list" - - # check if there is enough GPU to run the test - if [[ $gpu_count -lt $tp ]]; then - echo "Required num-shard $tp but only $gpu_count GPU found. Skip testcase $test_name." - continue - fi - - if echo "$common_params" | jq -e 'has("fp8")' > /dev/null; then - echo "Key 'fp8' exists in common params." - server_command="/tgi-entrypoint.sh \ - --model-id $model \ - --num-shard $tp \ - --port $port \ - --quantize fp8 \ - $server_args" - else - echo "Key 'fp8' does not exist in common params." - server_command="/tgi-entrypoint.sh \ - --model-id $model \ - --num-shard $tp \ - --port $port \ - $server_args" - fi - - - - - # run the server - echo "Running test case $test_name" - echo "Server command: $server_command" - eval "$server_command" & - - # wait until the server is alive - wait_for_server - if [ $? -eq 0 ]; then - echo "" - echo "tgi server is up and running." - else - echo "" - echo "tgi failed to start within the timeout period." - break - fi - - # iterate over different QPS - for qps in $qps_list; do - # remove the surrounding single quote from qps - if [[ "$qps" == *"inf"* ]]; then - echo "qps was $qps" - qps="inf" - echo "now qps is $qps" - fi - - new_test_name=$test_name"_qps_"$qps - - client_command="python3 benchmark_serving.py \ - --backend tgi \ - --model $model \ - --dataset-name $dataset_name \ - --dataset-path $dataset_path \ - --num-prompts $num_prompts \ - --port $port \ - --save-result \ - --result-dir $RESULTS_FOLDER \ - --result-filename ${new_test_name}.json \ - --request-rate $qps \ - $client_args" - - echo "Running test case $test_name with qps $qps" - echo "Client command: $client_command" - - eval "$client_command" - - # record the benchmarking commands - jq_output=$(jq -n \ - --arg server "$server_command" \ - --arg client "$client_command" \ - --arg gpu "$gpu_type" \ - --arg engine "tgi" \ - '{ - server_command: $server, - client_command: $client, - gpu_type: $gpu, - engine: $engine - }') - echo "$jq_output" >"$RESULTS_FOLDER/${new_test_name}.commands" - - done - - # clean up - kill_gpu_processes - rm -rf /root/.cache/huggingface/* - done -} - - - -upload_to_buildkite() { - # upload the benchmarking results to buildkite - - # if the agent binary is not found, skip uploading the results, exit 0 - if [ ! -f /workspace/buildkite-agent ]; then - echo "buildkite-agent binary not found. Skip uploading the results." - return 0 - fi - # /workspace/buildkite-agent annotate --style "success" --context "benchmark-results" --append < $RESULTS_FOLDER/${CURRENT_LLM_SERVING_ENGINE}_nightly_results.md - /workspace/buildkite-agent artifact upload "$RESULTS_FOLDER/*" -} - -main() { - - check_gpus - # enter vllm directory - cd $VLLM_SOURCE_CODE_LOC/benchmarks - declare -g RESULTS_FOLDER=results/ - mkdir -p $RESULTS_FOLDER - BENCHMARK_ROOT=../.buildkite/nightly-benchmarks/ - - export CURRENT_LLM_SERVING_ENGINE=tgi - run_serving_tests $BENCHMARK_ROOT/tests/nightly-tests.json - python -m pip install tabulate pandas - python $BENCHMARK_ROOT/scripts/summary-nightly-results.py - upload_to_buildkite - -} - -main "$@" diff --git a/.buildkite/nightly-benchmarks/scripts/run-trt-nightly.sh b/.buildkite/nightly-benchmarks/scripts/run-trt-nightly.sh deleted file mode 100644 index 4a82b9ec64d713ed9be72bc888aa651123574ae3..0000000000000000000000000000000000000000 --- a/.buildkite/nightly-benchmarks/scripts/run-trt-nightly.sh +++ /dev/null @@ -1,214 +0,0 @@ -#!/bin/bash - -set -o pipefail - -check_gpus() { - # check the number of GPUs and GPU type. - declare -g gpu_count=$(nvidia-smi --list-gpus | wc -l) - if [[ $gpu_count -gt 0 ]]; then - echo "GPU found." - else - echo "Need at least 1 GPU to run benchmarking." - exit 1 - fi - declare -g gpu_type=$(echo $(nvidia-smi --query-gpu=name --format=csv,noheader) | awk '{print $2}') - echo "GPU type is $gpu_type" -} - -kill_gpu_processes() { - pkill tritonserver || true - # waiting for GPU processes to be fully killed - sleep 20 - # Print the GPU memory usage - # so that we know if all GPU processes are killed. - gpu_memory_usage=$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits -i 0) - # The memory usage should be 0 MB. - echo "GPU 0 Memory Usage: $gpu_memory_usage MB" -} - -json2args() { - # transforms the JSON string to command line args, and '_' is replaced to '-' - # example: - # input: { "model": "meta-llama/Llama-2-7b-chat-hf", "tensor_parallel_size": 1 } - # output: --model meta-llama/Llama-2-7b-chat-hf --tensor-parallel-size 1 - local json_string=$1 - local args=$( - echo "$json_string" | jq -r ' - to_entries | - map("--" + (.key | gsub("_"; "-")) + " " + (.value | tostring)) | - join(" ") - ' - ) - echo "$args" -} - -wait_for_server() { - timeout 1200 bash -c ' - until curl -s localhost:8000/generate_stream > /dev/null; do - sleep 1 - done' && return 0 || return 1 -} - -run_serving_tests() { - # run serving tests using `benchmark_serving.py` - # $1: a json file specifying serving test cases - - local serving_test_file - serving_test_file=$1 - - # Iterate over serving tests - jq -c '.[]' "$serving_test_file" | while read -r params; do - # get the test name, and append the GPU type back to it. - test_name=$(echo "$params" | jq -r '.test_name') - - # if TEST_SELECTOR is set, only run the test cases that match the selector - if [[ -n "$TEST_SELECTOR" ]] && [[ ! "$test_name" =~ $TEST_SELECTOR ]]; then - echo "Skip test case $test_name." - continue - fi - - # append trt to the test name - test_name=trt_$test_name - - # get common parameters - common_params=$(echo "$params" | jq -r '.common_parameters') - model=$(echo "$common_params" | jq -r '.model') - tp=$(echo "$common_params" | jq -r '.tp') - dataset_name=$(echo "$common_params" | jq -r '.dataset_name') - dataset_path=$(echo "$common_params" | jq -r '.dataset_path') - port=$(echo "$common_params" | jq -r '.port') - num_prompts=$(echo "$common_params" | jq -r '.num_prompts') - - # get client and server arguments - server_params=$(echo "$params" | jq -r '.trt_server_parameters') - client_params=$(echo "$params" | jq -r '.trt_client_parameters') - client_args=$(json2args "$client_params") - qps_list=$(echo "$params" | jq -r '.qps_list') - qps_list=$(echo "$qps_list" | jq -r '.[] | @sh') - echo "Running over qps list $qps_list" - - # check if there is enough GPU to run the test - if [[ $gpu_count -lt $tp ]]; then - echo "Required model_tp_size $tp but only $gpu_count GPU found. Skip testcase $test_name." - continue - fi - - - - cd $VLLM_SOURCE_CODE_LOC/benchmarks - - - echo "Running test case $test_name" - bash ../.buildkite/nightly-benchmarks/scripts/launch-trt-server.sh "$server_params" "$common_params" - - # wait until the server is alive - wait_for_server - if [ $? -eq 0 ]; then - echo "" - echo "trt server is up and running." - else - echo "" - echo "trt failed to start within the timeout period." - break - fi - - # prepare tokenizer - cd $VLLM_SOURCE_CODE_LOC/benchmarks - rm -rf /tokenizer_cache - mkdir /tokenizer_cache - python ../.buildkite/nightly-benchmarks/scripts/download-tokenizer.py \ - --model "$model" \ - --cachedir /tokenizer_cache - cd $VLLM_SOURCE_CODE_LOC/benchmarks - - - # iterate over different QPS - for qps in $qps_list; do - # remove the surrounding single quote from qps - if [[ "$qps" == *"inf"* ]]; then - echo "qps was $qps" - qps="inf" - echo "now qps is $qps" - fi - - new_test_name=$test_name"_qps_"$qps - - client_command="python3 benchmark_serving.py \ - --backend tensorrt-llm \ - --tokenizer /tokenizer_cache \ - --model $model \ - --dataset-name $dataset_name \ - --dataset-path $dataset_path \ - --num-prompts $num_prompts \ - --port $port \ - --save-result \ - --result-dir $RESULTS_FOLDER \ - --result-filename ${new_test_name}.json \ - --request-rate $qps \ - $client_args" - - echo "Running test case $test_name with qps $qps" - echo "Client command: $client_command" - - eval "$client_command" - - server_command="" - # record the benchmarking commands - jq_output=$(jq -n \ - --arg server "$server_command" \ - --arg client "$client_command" \ - --arg gpu "$gpu_type" \ - --arg engine "trt" \ - '{ - server_command: $server, - client_command: $client, - gpu_type: $gpu, - engine: $engine - }') - echo "$jq_output" >"$RESULTS_FOLDER/${new_test_name}.commands" - - done - - # clean up - kill_gpu_processes - rm -rf /root/.cache/huggingface/* - done -} - -upload_to_buildkite() { - # upload the benchmarking results to buildkite - - # if the agent binary is not found, skip uploading the results, exit 0 - if [ ! -f /workspace/buildkite-agent ]; then - echo "buildkite-agent binary not found. Skip uploading the results." - return 0 - fi - # /workspace/buildkite-agent annotate --style "success" --context "benchmark-results" --append < $RESULTS_FOLDER/${CURRENT_LLM_SERVING_ENGINE}_nightly_results.md - /workspace/buildkite-agent artifact upload "$RESULTS_FOLDER/*" -} - - -main() { - - check_gpus - - - # enter vllm directory - cd $VLLM_SOURCE_CODE_LOC/benchmarks - - declare -g RESULTS_FOLDER=results/ - mkdir -p $RESULTS_FOLDER - BENCHMARK_ROOT=../.buildkite/nightly-benchmarks/ - - # update transformers package, to make sure mixtral tokenizer is available - python -m pip install transformers -U - - export CURRENT_LLM_SERVING_ENGINE=trt - run_serving_tests $BENCHMARK_ROOT/tests/nightly-tests.json - python -m pip install tabulate pandas - python $BENCHMARK_ROOT/scripts/summary-nightly-results.py - upload_to_buildkite - -} - -main "$@" diff --git a/.buildkite/nightly-benchmarks/scripts/run-vllm-nightly.sh b/.buildkite/nightly-benchmarks/scripts/run-vllm-nightly.sh deleted file mode 100644 index 663045b8a912269e2a89796a63e8806c97c1e5c1..0000000000000000000000000000000000000000 --- a/.buildkite/nightly-benchmarks/scripts/run-vllm-nightly.sh +++ /dev/null @@ -1,221 +0,0 @@ -#!/bin/bash - -set -o pipefail - -check_gpus() { - # check the number of GPUs and GPU type. - declare -g gpu_count=$(nvidia-smi --list-gpus | wc -l) - if [[ $gpu_count -gt 0 ]]; then - echo "GPU found." - else - echo "Need at least 1 GPU to run benchmarking." - exit 1 - fi - declare -g gpu_type=$(echo $(nvidia-smi --query-gpu=name --format=csv,noheader) | awk '{print $2}') - echo "GPU type is $gpu_type" -} - -kill_gpu_processes() { - # kill all processes on GPU. - pkill pt_main_thread - sleep 10 - - # remove vllm config file - rm -rf ~/.config/vllm - - # Print the GPU memory usage - # so that we know if all GPU processes are killed. - gpu_memory_usage=$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits -i 0) - # The memory usage should be 0 MB. - echo "GPU 0 Memory Usage: $gpu_memory_usage MB" -} - -json2args() { - # transforms the JSON string to command line args, and '_' is replaced to '-' - # example: - # input: { "model": "meta-llama/Llama-2-7b-chat-hf", "tensor_parallel_size": 1 } - # output: --model meta-llama/Llama-2-7b-chat-hf --tensor-parallel-size 1 - local json_string=$1 - local args=$( - echo "$json_string" | jq -r ' - to_entries | - map("--" + (.key | gsub("_"; "-")) + " " + (.value | tostring)) | - join(" ") - ' - ) - echo "$args" -} - -wait_for_server() { - # wait for vllm server to start - # return 1 if vllm server crashes - timeout 1200 bash -c ' - until curl -s localhost:8000/v1/completions > /dev/null; do - sleep 1 - done' && return 0 || return 1 -} - -run_serving_tests() { - # run serving tests using `benchmark_serving.py` - # $1: a json file specifying serving test cases - - local serving_test_file - serving_test_file=$1 - - # Iterate over serving tests - jq -c '.[]' "$serving_test_file" | while read -r params; do - # get the test name, and append the GPU type back to it. - test_name=$(echo "$params" | jq -r '.test_name') - - # if TEST_SELECTOR is set, only run the test cases that match the selector - if [[ -n "$TEST_SELECTOR" ]] && [[ ! "$test_name" =~ $TEST_SELECTOR ]]; then - echo "Skip test case $test_name." - continue - fi - - # append vllm to the test name - test_name=vllm_$test_name - - - # get common parameters - common_params=$(echo "$params" | jq -r '.common_parameters') - model=$(echo "$common_params" | jq -r '.model') - tp=$(echo "$common_params" | jq -r '.tp') - dataset_name=$(echo "$common_params" | jq -r '.dataset_name') - dataset_path=$(echo "$common_params" | jq -r '.dataset_path') - port=$(echo "$common_params" | jq -r '.port') - num_prompts=$(echo "$common_params" | jq -r '.num_prompts') - - # get client and server arguments - server_params=$(echo "$params" | jq -r '.vllm_server_parameters') - client_params=$(echo "$params" | jq -r '.vllm_client_parameters') - server_args=$(json2args "$server_params") - client_args=$(json2args "$client_params") - qps_list=$(echo "$params" | jq -r '.qps_list') - qps_list=$(echo "$qps_list" | jq -r '.[] | @sh') - echo "Running over qps list $qps_list" - - # check if there is enough GPU to run the test - if [[ $gpu_count -lt $tp ]]; then - echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name." - continue - fi - - if echo "$common_params" | jq -e 'has("fp8")' > /dev/null; then - echo "Key 'fp8' exists in common params. Use neuralmagic fp8 model for convenience." - model=$(echo "$common_params" | jq -r '.neuralmagic_quantized_model') - server_command="python3 \ - -m vllm.entrypoints.openai.api_server \ - -tp $tp \ - --model $model \ - --port $port \ - $server_args" - else - echo "Key 'fp8' does not exist in common params." - server_command="python3 \ - -m vllm.entrypoints.openai.api_server \ - -tp $tp \ - --model $model \ - --port $port \ - $server_args" - fi - - # run the server - echo "Running test case $test_name" - echo "Server command: $server_command" - eval "$server_command" & - - # wait until the server is alive - wait_for_server - if [ $? -eq 0 ]; then - echo "" - echo "vllm server is up and running." - else - echo "" - echo "vllm failed to start within the timeout period." - break - fi - - # iterate over different QPS - for qps in $qps_list; do - # remove the surrounding single quote from qps - if [[ "$qps" == *"inf"* ]]; then - echo "qps was $qps" - qps="inf" - echo "now qps is $qps" - fi - - new_test_name=$test_name"_qps_"$qps - - client_command="python3 benchmark_serving.py \ - --backend vllm \ - --model $model \ - --dataset-name $dataset_name \ - --dataset-path $dataset_path \ - --num-prompts $num_prompts \ - --port $port \ - --save-result \ - --result-dir $RESULTS_FOLDER \ - --result-filename ${new_test_name}.json \ - --request-rate $qps \ - $client_args" - - echo "Running test case $test_name with qps $qps" - echo "Client command: $client_command" - - eval "$client_command" - - # record the benchmarking commands - jq_output=$(jq -n \ - --arg server "$server_command" \ - --arg client "$client_command" \ - --arg gpu "$gpu_type" \ - --arg engine "vllm" \ - '{ - server_command: $server, - client_command: $client, - gpu_type: $gpu, - engine: $engine - }') - echo "$jq_output" >"$RESULTS_FOLDER/${new_test_name}.commands" - - done - - # clean up - kill_gpu_processes - rm -rf /root/.cache/huggingface/* - done -} - - -upload_to_buildkite() { - # upload the benchmarking results to buildkite - - # if the agent binary is not found, skip uploading the results, exit 0 - if [ ! -f /workspace/buildkite-agent ]; then - echo "buildkite-agent binary not found. Skip uploading the results." - return 0 - fi - # /workspace/buildkite-agent annotate --style "success" --context "benchmark-results" --append < $RESULTS_FOLDER/${CURRENT_LLM_SERVING_ENGINE}_nightly_results.md - /workspace/buildkite-agent artifact upload "$RESULTS_FOLDER/*" -} - -main() { - - check_gpus - # enter vllm directory - cd $VLLM_SOURCE_CODE_LOC/benchmarks - declare -g RESULTS_FOLDER=results/ - mkdir -p $RESULTS_FOLDER - BENCHMARK_ROOT=../.buildkite/nightly-benchmarks/ - - export CURRENT_LLM_SERVING_ENGINE=vllm - run_serving_tests $BENCHMARK_ROOT/tests/nightly-tests.json - - python3 -m pip install tabulate pandas - python3 $BENCHMARK_ROOT/scripts/summary-nightly-results.py - upload_to_buildkite - -} - -main "$@" diff --git a/.buildkite/nightly-benchmarks/scripts/summary-nightly-results.py b/.buildkite/nightly-benchmarks/scripts/summary-nightly-results.py index 782d1ef9aab98842558e65122fec2f2e3190cab2..4e4d4cd4ca3c6c7933d5d47b95ebc3c7dbb32292 100644 --- a/.buildkite/nightly-benchmarks/scripts/summary-nightly-results.py +++ b/.buildkite/nightly-benchmarks/scripts/summary-nightly-results.py @@ -17,10 +17,17 @@ serving_column_mapping = { "request_throughput": "Tput (req/s)", "mean_ttft_ms": "Mean TTFT (ms)", "std_ttft_ms": "Std TTFT (ms)", + "median_ttft_ms": "Median TTFT (ms)", "mean_itl_ms": "Mean ITL (ms)", "std_itl_ms": "Std ITL (ms)", - "input_throughput": "Input Tput (tok/s)", + "median_itl_ms": "Median ITL (ms)", + "mean_tpot_ms": "Mean TPOT (ms)", + "std_tpot_ms": "Std TPOT (ms)", + "median_tpot_ms": "Median TPOT (ms)", + "total_token_throughput": "Total Token Tput (tok/s)", "output_throughput": "Output Tput (tok/s)", + "total_input_tokens": "Total input tokens", + "total_output_tokens": "Total output tokens", "engine": "Engine", } diff --git a/.buildkite/nightly-benchmarks/scripts/wait-for-image.sh b/.buildkite/nightly-benchmarks/scripts/wait-for-image.sh index c785e6a0da628d64df59c4f527500c5cd83bf1ad..f16862907def1b80ed66725b94dd9ca6c2131ab8 100644 --- a/.buildkite/nightly-benchmarks/scripts/wait-for-image.sh +++ b/.buildkite/nightly-benchmarks/scripts/wait-for-image.sh @@ -2,9 +2,11 @@ TOKEN=$(curl -s -L "https://public.ecr.aws/token?service=public.ecr.aws&scope=repository:q9t5s3a7/vllm-ci-test-repo:pull" | jq -r .token) URL="https://public.ecr.aws/v2/q9t5s3a7/vllm-ci-test-repo/manifests/$BUILDKITE_COMMIT" +TIMEOUT_SECONDS=10 + retries=0 while [ $retries -lt 1000 ]; do - if [ $(curl -s -L -H "Authorization: Bearer $TOKEN" -o /dev/null -w "%{http_code}" $URL) -eq 200 ]; then + if [ $(curl -s --max-time $TIMEOUT_SECONDS -L -H "Authorization: Bearer $TOKEN" -o /dev/null -w "%{http_code}" $URL) -eq 200 ]; then exit 0 fi diff --git a/.buildkite/nightly-benchmarks/tests/latency-tests.json b/.buildkite/nightly-benchmarks/tests/latency-tests.json index 06488cd79110aa1dd68c292c3a640d2e707802b0..1841186da158f7b55072ec6f080ee7954d9ed8bf 100644 --- a/.buildkite/nightly-benchmarks/tests/latency-tests.json +++ b/.buildkite/nightly-benchmarks/tests/latency-tests.json @@ -2,7 +2,7 @@ { "test_name": "latency_llama8B_tp1", "parameters": { - "model": "meta-llama/Meta-Llama-3-8B", + "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", "tensor_parallel_size": 1, "load_format": "dummy", "num_iters_warmup": 5, @@ -12,7 +12,7 @@ { "test_name": "latency_llama70B_tp4", "parameters": { - "model": "meta-llama/Meta-Llama-3-70B-Instruct", + "model": "meta-llama/Meta-Llama-3.1-70B-Instruct", "tensor_parallel_size": 4, "load_format": "dummy", "num-iters-warmup": 5, diff --git a/.buildkite/nightly-benchmarks/tests/nightly-tests.json b/.buildkite/nightly-benchmarks/tests/nightly-tests.json index f250833c62710f1165fcdcad7f76ea39085e24e7..fda1a7a3ec53cc7fb4bbc396e28105699adca796 100644 --- a/.buildkite/nightly-benchmarks/tests/nightly-tests.json +++ b/.buildkite/nightly-benchmarks/tests/nightly-tests.json @@ -1,16 +1,18 @@ [ { - "test_name": "llama8B_tp1", - "qps_list": [4], + "test_name": "llama8B_tp1_sharegpt", + "qps_list": [4,8,16,32,"inf"], "common_parameters": { - "model": "meta-llama/Meta-Llama-3-8B", + "model": "meta-llama/Meta-Llama-3-8B-Instruct", "tp": 1, "dataset_name": "sharegpt", "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", "num_prompts": 500, - "port": 8000 + "port": 8000, + "reuse_server": false }, "lmdeploy_server_parameters": { + "dtype": "bfloat16" }, "lmdeploy_client_parameters": { }, @@ -21,34 +23,158 @@ }, "trt_server_parameters": { "model_type": "llama", - "model_dtype": "float16", - "max_batch_size": 256, + "model_dtype": "bfloat16", + "max_batch_size": 2048, "max_input_len": 4096, - "max_output_len": 4096, - "trt_llm_version": "r24.04" + "max_seq_len": 6144, + "max_num_tokens": 16384, + "trt_llm_version": "v0.11.0" }, "trt_client_parameters": { "endpoint": "/v2/models/ensemble/generate_stream" + }, + "vllm_server_parameters": { + "disable_log_stats": "", + "disable_log_requests": "", + "gpu_memory_utilization": 0.9, + "num_scheduler_steps": 10, + "max_num_seqs": 512, + "dtype": "bfloat16" + }, + "vllm_client_parameters": { + }, + "sglang_server_parameters": { + "disable_radix_cache": "", + "enable_torch_compile": "", + "dtype": "bfloat16" + }, + "sglang_client_parameters": { + } + }, + { + "test_name": "llama8B_tp1_sonnet_512_16", + "qps_list": [4,8,16,32,"inf"], + "common_parameters": { + "model": "meta-llama/Meta-Llama-3-8B-Instruct", + "tp": 1, + "dataset_name": "sonnet", + "dataset_path": "./sonnet_4x.txt", + "num_prompts": 500, + "port": 8000, + "sonnet_input_len": 512, + "sonnet_output_len": 16, + "sonnet_prefix_len": 50, + "reuse_server": true + }, + "lmdeploy_server_parameters": { + "dtype": "bfloat16" + }, + "lmdeploy_client_parameters": { + }, + "tgi_server_parameters": { + }, + "tgi_client_parameters": { + "endpoint": "/generate_stream" + }, + "trt_server_parameters": { + "model_type": "llama", + "model_dtype": "bfloat16", + "max_batch_size": 2048, + "max_input_len": 4096, + "max_seq_len": 6144, + "max_num_tokens": 16384, + "trt_llm_version": "v0.11.0" + }, + "trt_client_parameters": { + "endpoint": "/v2/models/ensemble/generate_stream" + }, + "vllm_server_parameters": { + "disable_log_stats": "", + "disable_log_requests": "", + "gpu_memory_utilization": 0.9, + "num_scheduler_steps": 10, + "max_num_seqs": 512, + "dtype": "bfloat16" + }, + "vllm_client_parameters": { + }, + "sglang_server_parameters": { + "disable_radix_cache": "", + "enable_torch_compile": "", + "dtype": "bfloat16" + }, + "sglang_client_parameters": { + } + }, + { + "test_name": "llama8B_tp1_sonnet_512_256", + "qps_list": [4,8,16,32,"inf"], + "common_parameters": { + "model": "meta-llama/Meta-Llama-3-8B-Instruct", + "tp": 1, + "dataset_name": "sonnet", + "dataset_path": "./sonnet_4x.txt", + "num_prompts": 500, + "port": 8000, + "sonnet_input_len": 512, + "sonnet_output_len": 256, + "sonnet_prefix_len": 50, + "reuse_server": true + }, + "lmdeploy_server_parameters": { + "dtype": "bfloat16" + }, + "lmdeploy_client_parameters": { + }, + "tgi_server_parameters": { + }, + "tgi_client_parameters": { + "endpoint": "/generate_stream" + }, + "trt_server_parameters": { + "model_type": "llama", + "model_dtype": "bfloat16", + "max_batch_size": 2048, + "max_input_len": 4096, + "max_seq_len": 6144, + "max_num_tokens": 16384, + "trt_llm_version": "v0.11.0" }, + "trt_client_parameters": { + "endpoint": "/v2/models/ensemble/generate_stream" + }, "vllm_server_parameters": { "disable_log_stats": "", - "disable_log_requests": "" + "disable_log_requests": "", + "gpu_memory_utilization": 0.9, + "num_scheduler_steps": 10, + "max_num_seqs": 512, + "dtype": "bfloat16" }, "vllm_client_parameters": { + }, + "sglang_server_parameters": { + "disable_radix_cache": "", + "enable_torch_compile": "", + "dtype": "bfloat16" + }, + "sglang_client_parameters": { } }, { - "test_name": "llama70B_tp4", - "qps_list": [2], + "test_name": "llama70B_tp4_sharegpt", + "qps_list": [4,8,16,32,"inf"], "common_parameters": { "model": "meta-llama/Meta-Llama-3-70B-Instruct", "tp": 4, "dataset_name": "sharegpt", "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", "num_prompts": 500, - "port": 8000 + "port": 8000, + "reuse_server": false }, "lmdeploy_server_parameters": { + "dtype": "bfloat16" }, "lmdeploy_client_parameters": { }, @@ -59,34 +185,50 @@ }, "trt_server_parameters": { "model_type": "llama", - "model_dtype": "float16", - "max_batch_size": 256, + "model_dtype": "bfloat16", + "max_batch_size": 2048, "max_input_len": 4096, - "max_output_len": 4096, - "trt_llm_version": "r24.04" + "max_seq_len": 6144, + "max_num_tokens": 16384, + "trt_llm_version": "v0.11.0" }, "trt_client_parameters": { "endpoint": "/v2/models/ensemble/generate_stream" - }, + }, "vllm_server_parameters": { "disable_log_stats": "", - "disable_log_requests": "" + "disable_log_requests": "", + "gpu_memory_utilization": 0.9, + "num_scheduler_steps": 10, + "max_num_seqs": 512, + "dtype": "bfloat16" }, "vllm_client_parameters": { + }, + "sglang_server_parameters": { + "disable_radix_cache": "", + "dtype": "bfloat16" + }, + "sglang_client_parameters": { } }, { - "test_name": "mixtral8x7B_tp2", - "qps_list": [2], + "test_name": "llama70B_tp4_sonnet_512_16", + "qps_list": [4,8,16,32,"inf"], "common_parameters": { - "model": "mistralai/Mixtral-8x7B-Instruct-v0.1", - "tp": 2, - "dataset_name": "sharegpt", - "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", + "model": "meta-llama/Meta-Llama-3-70B-Instruct", + "tp": 4, + "dataset_name": "sonnet", + "dataset_path": "./sonnet_4x.txt", "num_prompts": 500, - "port": 8000 + "port": 8000, + "sonnet_input_len": 512, + "sonnet_output_len": 16, + "sonnet_prefix_len": 50, + "reuse_server": true }, "lmdeploy_server_parameters": { + "dtype": "bfloat16" }, "lmdeploy_client_parameters": { }, @@ -97,20 +239,85 @@ }, "trt_server_parameters": { "model_type": "llama", - "model_dtype": "float16", - "max_batch_size": 256, + "model_dtype": "bfloat16", + "max_batch_size": 2048, "max_input_len": 4096, - "max_output_len": 4096, - "trt_llm_version": "r24.04" + "max_seq_len": 6144, + "max_num_tokens": 16384, + "trt_llm_version": "v0.11.0" }, "trt_client_parameters": { "endpoint": "/v2/models/ensemble/generate_stream" + }, + "vllm_server_parameters": { + "disable_log_stats": "", + "disable_log_requests": "", + "gpu_memory_utilization": 0.9, + "num_scheduler_steps": 10, + "max_num_seqs": 512, + "dtype": "bfloat16" + }, + "vllm_client_parameters": { }, + "sglang_server_parameters": { + "disable_radix_cache": "", + "dtype": "bfloat16" + }, + "sglang_client_parameters": { + } + }, + { + "test_name": "llama70B_tp4_sonnet_512_256", + "qps_list": [4,8,16,32,"inf"], + "common_parameters": { + "model": "meta-llama/Meta-Llama-3-70B-Instruct", + "tp": 4, + "dataset_name": "sonnet", + "dataset_path": "./sonnet_4x.txt", + "num_prompts": 500, + "port": 8000, + "sonnet_input_len": 512, + "sonnet_output_len": 256, + "sonnet_prefix_len": 50, + "reuse_server": true + }, + "lmdeploy_server_parameters": { + "dtype": "bfloat16" + }, + "lmdeploy_client_parameters": { + }, + "tgi_server_parameters": { + }, + "tgi_client_parameters": { + "endpoint": "/generate_stream" + }, + "trt_server_parameters": { + "model_type": "llama", + "model_dtype": "bfloat16", + "max_batch_size": 2048, + "max_input_len": 4096, + "max_seq_len": 6144, + "max_num_tokens": 16384, + "trt_llm_version": "v0.11.0" + }, + "trt_client_parameters": { + "endpoint": "/v2/models/ensemble/generate_stream" + }, "vllm_server_parameters": { "disable_log_stats": "", - "disable_log_requests": "" + "disable_log_requests": "", + "gpu_memory_utilization": 0.9, + "num_scheduler_steps": 10, + "max_num_seqs": 512, + "dtype": "bfloat16" }, "vllm_client_parameters": { + }, + "sglang_server_parameters": { + "disable_radix_cache": "", + "dtype": "bfloat16" + }, + "sglang_client_parameters": { } } ] \ No newline at end of file diff --git a/.buildkite/nightly-benchmarks/tests/serving-tests.json b/.buildkite/nightly-benchmarks/tests/serving-tests.json index 300af0524d7c0dff2b12d19d4792ba8ff4124468..facb0eac749caf2c2c385b3615e2d49e67ec2816 100644 --- a/.buildkite/nightly-benchmarks/tests/serving-tests.json +++ b/.buildkite/nightly-benchmarks/tests/serving-tests.json @@ -3,7 +3,7 @@ "test_name": "serving_llama8B_tp1_sharegpt", "qps_list": [1, 4, 16, "inf"], "server_parameters": { - "model": "meta-llama/Meta-Llama-3-8B", + "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", "tensor_parallel_size": 1, "swap_space": 16, "disable_log_stats": "", @@ -11,7 +11,7 @@ "load_format": "dummy" }, "client_parameters": { - "model": "meta-llama/Meta-Llama-3-8B", + "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", "backend": "vllm", "dataset_name": "sharegpt", "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", @@ -22,7 +22,7 @@ "test_name": "serving_llama70B_tp4_sharegpt", "qps_list": [1, 4, 16, "inf"], "server_parameters": { - "model": "meta-llama/Meta-Llama-3-70B-Instruct", + "model": "meta-llama/Meta-Llama-3.1-70B-Instruct", "tensor_parallel_size": 4, "swap_space": 16, "disable_log_stats": "", @@ -30,7 +30,7 @@ "load_format": "dummy" }, "client_parameters": { - "model": "meta-llama/Meta-Llama-3-70B-Instruct", + "model": "meta-llama/Meta-Llama-3.1-70B-Instruct", "backend": "vllm", "dataset_name": "sharegpt", "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", @@ -60,7 +60,7 @@ "test_name": "serving_llama70B_tp4_sharegpt_specdecode", "qps_list": [2], "server_parameters": { - "model": "meta-llama/Meta-Llama-3-70B-Instruct", + "model": "meta-llama/Meta-Llama-3.1-70B-Instruct", "disable_log_requests": "", "tensor_parallel_size": 4, "swap_space": 16, @@ -70,7 +70,7 @@ "use_v2_block_manager": "" }, "client_parameters": { - "model": "meta-llama/Meta-Llama-3-70B-Instruct", + "model": "meta-llama/Meta-Llama-3.1-70B-Instruct", "backend": "vllm", "dataset_name": "sharegpt", "dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json", diff --git a/.buildkite/nightly-benchmarks/tests/throughput-tests.json b/.buildkite/nightly-benchmarks/tests/throughput-tests.json index 41ac1357487047ac2f6029376431b9f17f88b845..91ef6d16be6381576f70beb3d326f29f7d8185b0 100644 --- a/.buildkite/nightly-benchmarks/tests/throughput-tests.json +++ b/.buildkite/nightly-benchmarks/tests/throughput-tests.json @@ -2,7 +2,7 @@ { "test_name": "throughput_llama8B_tp1", "parameters": { - "model": "meta-llama/Meta-Llama-3-8B", + "model": "meta-llama/Meta-Llama-3.1-8B-Instruct", "tensor_parallel_size": 1, "load_format": "dummy", "dataset": "./ShareGPT_V3_unfiltered_cleaned_split.json", @@ -13,7 +13,7 @@ { "test_name": "throughput_llama70B_tp4", "parameters": { - "model": "meta-llama/Meta-Llama-3-70B-Instruct", + "model": "meta-llama/Meta-Llama-3.1-70B-Instruct", "tensor_parallel_size": 4, "load_format": "dummy", "dataset": "./ShareGPT_V3_unfiltered_cleaned_split.json", diff --git a/.buildkite/release-pipeline.yaml b/.buildkite/release-pipeline.yaml index 5be9a553dddd4dcd010e52b3341a0b09311ccd79..3b7fa0f2d94b39dc42af743707f528ae1156e0d6 100644 --- a/.buildkite/release-pipeline.yaml +++ b/.buildkite/release-pipeline.yaml @@ -1,9 +1,28 @@ steps: - - label: "Build wheel - CUDA {{matrix.cuda_version}}" + - label: "Build wheel - CUDA 12.1" agents: queue: cpu_queue commands: - - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg buildkite_commit=$BUILDKITE_COMMIT --build-arg USE_SCCACHE=1 --build-arg CUDA_VERSION={{matrix.cuda_version}} --tag vllm-ci:build-image --target build --progress plain ." + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.1.0 --tag vllm-ci:build-image --target build --progress plain ." + - "mkdir artifacts" + - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'" + # rename the files to change linux -> manylinux1 + - "for f in artifacts/dist/*.whl; do mv -- \"$$f\" \"$${f/linux/manylinux1}\"; done" + - "mv artifacts/dist/$(ls artifacts/dist) artifacts/dist/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl" + - "aws s3 cp artifacts/dist/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl s3://vllm-wheels/$BUILDKITE_COMMIT/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl" + - "aws s3 cp artifacts/dist/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl s3://vllm-wheels/nightly/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl" + env: + DOCKER_BUILDKIT: "1" + + - block: "Build CUDA 11.8 wheel" + key: block-build-cu118-wheel + + - label: "Build wheel - CUDA 11.8" + depends_on: block-build-cu118-wheel + agents: + queue: cpu_queue + commands: + - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=11.8.0 --tag vllm-ci:build-image --target build --progress plain ." - "mkdir artifacts" - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'" # rename the files to change linux -> manylinux1 @@ -12,8 +31,3 @@ steps: - "aws s3 cp --recursive artifacts/dist s3://vllm-wheels/nightly/" env: DOCKER_BUILDKIT: "1" - matrix: - setup: - cuda_version: - - "11.8.0" - - "12.1.0" diff --git a/.buildkite/run-amd-test.sh b/.buildkite/run-amd-test.sh old mode 100644 new mode 100755 index ccc2f090565e4303e181e8d5207f060c6002be85..df201cdc7c554c70e0581e8f3f8137a2eef21649 --- a/.buildkite/run-amd-test.sh +++ b/.buildkite/run-amd-test.sh @@ -1,5 +1,5 @@ # This script runs test inside the corresponding ROCm docker container. -set -ex +set -o pipefail # Print ROCm version echo "--- Confirming Clean Initial State" @@ -70,15 +70,85 @@ HF_CACHE="$(realpath ~)/huggingface" mkdir -p ${HF_CACHE} HF_MOUNT="/root/.cache/huggingface" -docker run \ +commands=$@ +echo "Commands:$commands" +#ignore certain kernels tests +if [[ $commands == *" kernels "* ]]; then + commands="${commands} \ + --ignore=kernels/test_attention.py \ + --ignore=kernels/test_attention_selector.py \ + --ignore=kernels/test_blocksparse_attention.py \ + --ignore=kernels/test_causal_conv1d.py \ + --ignore=kernels/test_cutlass.py \ + --ignore=kernels/test_encoder_decoder_attn.py \ + --ignore=kernels/test_flash_attn.py \ + --ignore=kernels/test_flashinfer.py \ + --ignore=kernels/test_gguf.py \ + --ignore=kernels/test_int8_quant.py \ + --ignore=kernels/test_machete_gemm.py \ + --ignore=kernels/test_mamba_ssm.py \ + --ignore=kernels/test_marlin_gemm.py \ + --ignore=kernels/test_moe.py \ + --ignore=kernels/test_prefix_prefill.py \ + --ignore=kernels/test_rand.py \ + --ignore=kernels/test_sampler.py" +fi + +#ignore certain Entrypoints tests +if [[ $commands == *" entrypoints/openai "* ]]; then + commands=${commands//" entrypoints/openai "/" entrypoints/openai \ + --ignore=entrypoints/openai/test_accuracy.py \ + --ignore=entrypoints/openai/test_audio.py \ + --ignore=entrypoints/openai/test_encoder_decoder.py \ + --ignore=entrypoints/openai/test_embedding.py \ + --ignore=entrypoints/openai/test_oot_registration.py "} +fi + +PARALLEL_JOB_COUNT=8 +# check if the command contains shard flag, we will run all shards in parallel because the host have 8 GPUs. +if [[ $commands == *"--shard-id="* ]]; then + for GPU in $(seq 0 $(($PARALLEL_JOB_COUNT-1))); do + #replace shard arguments + commands=${commands//"--shard-id= "/"--shard-id=${GPU} "} + commands=${commands//"--num-shards= "/"--num-shards=${PARALLEL_JOB_COUNT} "} + echo "Shard ${GPU} commands:$commands" + docker run \ --device /dev/kfd --device /dev/dri \ --network host \ --shm-size=16gb \ --rm \ + -e HIP_VISIBLE_DEVICES=${GPU} \ -e HF_TOKEN \ -v ${HF_CACHE}:${HF_MOUNT} \ -e HF_HOME=${HF_MOUNT} \ - --name ${container_name} \ + --name ${container_name}_${GPU} \ ${image_name} \ - /bin/bash -c "${@}" - + /bin/bash -c "${commands}" \ + |& while read -r line; do echo ">>Shard $GPU: $line"; done & + PIDS+=($!) + done + #wait for all processes to finish and collect exit codes + for pid in ${PIDS[@]}; do + wait ${pid} + STATUS+=($?) + done + for st in ${STATUS[@]}; do + if [[ ${st} -ne 0 ]]; then + echo "One of the processes failed with $st" + exit ${st} + fi + done +else + docker run \ + --device /dev/kfd --device /dev/dri \ + --network host \ + --shm-size=16gb \ + --rm \ + -e HIP_VISIBLE_DEVICES=0 \ + -e HF_TOKEN \ + -v ${HF_CACHE}:${HF_MOUNT} \ + -e HF_HOME=${HF_MOUNT} \ + --name ${container_name} \ + ${image_name} \ + /bin/bash -c "${commands}" +fi diff --git a/.buildkite/run-cpu-test-ppc64le.sh b/.buildkite/run-cpu-test-ppc64le.sh new file mode 100755 index 0000000000000000000000000000000000000000..fd60f5b6afeca639061c087fd8434a871d5b6dba --- /dev/null +++ b/.buildkite/run-cpu-test-ppc64le.sh @@ -0,0 +1,39 @@ +# This script build the CPU docker image and run the offline inference inside the container. +# It serves a sanity check for compilation and basic model usage. +set -ex + +# Try building the docker image +docker build -t cpu-test -f Dockerfile.ppc64le . + +# Setup cleanup +remove_docker_container() { docker rm -f cpu-test || true; } +trap remove_docker_container EXIT +remove_docker_container + +# Run the image, setting --shm-size=4g for tensor parallel. +source /etc/environment +#docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test cpu-test +docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true --network host -e HF_TOKEN=$HF_TOKEN --name cpu-test cpu-test + +# Run basic model test +docker exec cpu-test bash -c " + pip install pytest matplotlib einops transformers_stream_generator + pytest -v -s tests/models -m \"not vlm\" \ + --ignore=tests/models/test_embedding.py \ + --ignore=tests/models/test_oot_registration.py \ + --ignore=tests/models/test_registry.py \ + --ignore=tests/models/test_jamba.py \ + --ignore=tests/models/test_mamba.py \ + --ignore=tests/models/test_danube3_4b.py" # Mamba kernels and Danube3-4B on CPU is not supported + +# online inference +docker exec cpu-test bash -c " + python3 -m vllm.entrypoints.openai.api_server --model facebook/opt-125m & + timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1 + python3 benchmarks/benchmark_serving.py \ + --backend vllm \ + --dataset-name random \ + --model facebook/opt-125m \ + --num-prompts 20 \ + --endpoint /v1/completions \ + --tokenizer facebook/opt-125m" diff --git a/.buildkite/run-cpu-test.sh b/.buildkite/run-cpu-test.sh index 45bc8eb2f84772341c6a8abee40c71151000860f..c331a9c49c0d0ca676640742234be80fa9a32978 100644 --- a/.buildkite/run-cpu-test.sh +++ b/.buildkite/run-cpu-test.sh @@ -22,8 +22,25 @@ docker exec cpu-test-avx2 bash -c "python3 examples/offline_inference.py" # Run basic model test docker exec cpu-test bash -c " - pip install pytest Pillow protobuf - pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py --ignore=tests/models/test_jamba.py --ignore=tests/models/test_danube3_4b.py" # Mamba and Danube3-4B on CPU is not supported + pip install pytest matplotlib einops transformers_stream_generator datamodel_code_generator + pytest -v -s tests/models/encoder_decoder/language + pytest -v -s tests/models/decoder_only/language \ + --ignore=tests/models/test_fp8.py \ + --ignore=tests/models/decoder_only/language/test_jamba.py \ + --ignore=tests/models/decoder_only/language/test_mamba.py \ + --ignore=tests/models/decoder_only/language/test_granitemoe.py \ + --ignore=tests/models/decoder_only/language/test_danube3_4b.py" # Mamba and Danube3-4B on CPU is not supported + +# Run compressed-tensor test +docker exec cpu-test bash -c " + pytest -s -v \ + tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_static_setup \ + tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_dynamic_per_token" + +# Run AWQ test +docker exec cpu-test bash -c " + pytest -s -v \ + tests/quantization/test_ipex_quant.py" # online inference docker exec cpu-test bash -c " diff --git a/.buildkite/run-tpu-test.sh b/.buildkite/run-tpu-test.sh index 4aabd123ae2347ec9e97621ef6e31277d0729c98..6989c94d46a89e37c67082021f9d75c75b75bfd6 100644 --- a/.buildkite/run-tpu-test.sh +++ b/.buildkite/run-tpu-test.sh @@ -12,5 +12,4 @@ remove_docker_container # For HF_TOKEN. source /etc/environment # Run a simple end-to-end example. -docker run --privileged --net host --shm-size=16G -it -e HF_TOKEN=$HF_TOKEN --name tpu-test vllm-tpu \ - python3 /workspace/vllm/examples/offline_inference_tpu.py +docker run --privileged --net host --shm-size=16G -it -e HF_TOKEN=$HF_TOKEN --name tpu-test vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git && python3 -m pip install pytest && pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py && python3 /workspace/vllm/tests/tpu/test_compilation.py && python3 /workspace/vllm/examples/offline_inference_tpu.py" diff --git a/.buildkite/run-xpu-test.sh b/.buildkite/run-xpu-test.sh index 22a7e76937a76e1df7c33aad37e9b5d592efcf0c..6ffa66d5ef3d66e03e647e633bbd92f915c62704 100644 --- a/.buildkite/run-xpu-test.sh +++ b/.buildkite/run-xpu-test.sh @@ -11,4 +11,4 @@ trap remove_docker_container EXIT remove_docker_container # Run the image and launch offline inference -docker run --network host --name xpu-test --device /dev/dri -v /dev/dri/by-path:/dev/dri/by-path xpu-test python3 examples/offline_inference.py +docker run --network host --name xpu-test --device /dev/dri -v /dev/dri/by-path:/dev/dri/by-path --entrypoint="" xpu-test python3 examples/offline_inference.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 93b3e3fe916639b48f761ca92febbf2052af075e..d2324d7cee60f6ef22d4d126c54b858bca625d07 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -5,264 +5,498 @@ # https://github.com/vllm-project/buildkite-ci/blob/main/scripts/test-template-aws.j2 # to generate the final pipeline yaml file. +# Documentation +# label(str): the name of the test. emoji allowed. +# fast_check(bool): whether to run this on each commit on fastcheck pipeline. +# fast_check_only(bool): run this test on fastcheck pipeline only +# optional(bool): never run this test by default (i.e. need to unblock manually) +# command(str): the single command to run for tests. incompatible with commands. +# commands(list): the list of commands to run for test. incompatbile with command. +# mirror_hardwares(list): the list of hardwares to run the test on as well. currently only supports [amd] +# gpu(str): override the GPU selection for the test. default is on L4 GPUs. currently only supports a100 +# num_gpus(int): override the number of GPUs for the test. default to 1 GPU. currently support 2,4. +# num_nodes(int): whether to simulate multi-node setup by launch multiple containers on one host, +# in this case, commands must be specified. the first command runs on first host, the second +# command runs on the second host. +# working_dir(str): specify the place where command should execute, default to /vllm-workspace/tests +# source_file_dependencies(list): the list of prefix to opt-in the test for, if empty, the test will always run. + +# When adding a test +# - If the test belong to an existing group, add it there +# - If the test is short, add to any existing step +# - If the test takes more than 10min, then it is okay to create a new step. +# Note that all steps execute in parallel. steps: -- label: Async Engine, Inputs, Utils, Worker Test +##### fast check tests ##### + +- label: Documentation Build # 2min + working_dir: "/vllm-workspace/test_docs/docs" fast_check: true - fast_check_only: true + no_gpu: True commands: - - pytest -v -s async_engine # Async Engine + - pip install -r requirements-docs.txt + - SPHINXOPTS=\"-W\" make html + # Check API reference (if it fails, you may have missing mock imports) + - grep \"sig sig-object py\" build/html/dev/sampling_params.html + +- label: Async Engine, Inputs, Utils, Worker Test # 24min + fast_check: true + source_file_dependencies: + - vllm/ + - tests/mq_llm_engine + - tests/async_engine + - tests/test_inputs + - tests/multimodal + - tests/test_utils + - tests/worker + commands: + - pytest -v -s mq_llm_engine # MQLLMEngine + - pytest -v -s async_engine # AsyncLLMEngine + - NUM_SCHEDULER_STEPS=4 pytest -v -s async_engine/test_async_llm_engine.py - pytest -v -s test_inputs.py - pytest -v -s multimodal - pytest -v -s test_utils.py # Utils - pytest -v -s worker # Worker -- label: Metrics, Tracing Test - fast_check: true - fast_check_only: true - commands: - - pytest -v -s metrics # Metrics - - "pip install \ - opentelemetry-sdk \ - opentelemetry-api \ - opentelemetry-exporter-otlp \ - opentelemetry-semantic-conventions-ai" # Tracing - - pytest -v -s tracing - -- label: Regression Test - mirror_hardwares: [amd] - fast_check: true - command: pytest -v -s test_regression.py - working_dir: "/vllm-workspace/tests" # optional - -- label: AsyncEngine Test +- label: Basic Correctness Test # 30min #mirror_hardwares: [amd] - command: pytest -v -s async_engine - -- label: Basic Correctness Test - mirror_hardwares: [amd] fast_check: true + source_file_dependencies: + - vllm/ + - tests/basic_correctness/test_basic_correctness + - tests/basic_correctness/test_cpu_offload + - tests/basic_correctness/test_preemption commands: - # This flashinfer installation will fail on AMD ROCm, so it is set as optional. - - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.4-cp310-cp310-linux_x86_64.whl || true - pytest -v -s basic_correctness/test_basic_correctness.py - pytest -v -s basic_correctness/test_cpu_offload.py + - VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py + +- label: Chunked Prefill Test + source_file_dependencies: + - vllm/ + - tests/basic_correctness/test_chunked_prefill + commands: - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py - - VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py -- label: Core Test +- label: Core Test # 10min mirror_hardwares: [amd] fast_check: true + source_file_dependencies: + - vllm/core + - vllm/distributed + - tests/core commands: - pytest -v -s core -- label: Distributed Comm Ops Test - #mirror_hardwares: [amd] - working_dir: "/vllm-workspace/tests" - num_gpus: 2 - commands: - - pytest -v -s distributed/test_comm_ops.py - - pytest -v -s distributed/test_shm_broadcast.py - -- label: 2 Node Tests (4 GPUs in total) +- label: Entrypoints Test # 40min working_dir: "/vllm-workspace/tests" - num_gpus: 2 - num_nodes: 2 - commands: - - # the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up) - - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py - - VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py - - # the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up) - - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py - -- label: Distributed Tests (2 GPUs) + fast_check: true mirror_hardwares: [amd] - working_dir: "/vllm-workspace/tests" - num_gpus: 2 + source_file_dependencies: + - vllm/ commands: - - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py - - TARGET_TEST_SUITE=L4 pytest -v -s distributed/test_basic_distributed_correctness.py - - pytest -v -s distributed/test_chunked_prefill_distributed.py - - pytest -v -s distributed/test_multimodal_broadcast.py - - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py - - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py - - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py - -- label: Distributed Tests (4 GPUs) - #mirror_hardwares: [amd] + - pip install -e ./plugins/vllm_add_dummy_model + - pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_guided_generate.py + - pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process + - pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process + - pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process + - pytest -v -s entrypoints/llm/test_guided_generate.py # it needs a clean process + - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py + - pytest -v -s entrypoints/openai/test_oot_registration.py # it needs a clean process + - pytest -v -s entrypoints/test_chat_utils.py + - pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests + +- label: Distributed Tests (4 GPUs) # 10min working_dir: "/vllm-workspace/tests" num_gpus: 4 fast_check: true + source_file_dependencies: + - vllm/distributed/ + - vllm/core/ + - tests/distributed + - tests/spec_decode/e2e/test_integration_dist_tp4 + - tests/compile commands: + - pytest -v -s compile/test_basic_correctness.py - pytest -v -s distributed/test_pynccl.py - pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py -- label: Pipeline Parallelism Test - working_dir: "/vllm-workspace/tests" - num_gpus: 4 +- label: Metrics, Tracing Test # 10min + num_gpus: 2 + fast_check: true + source_file_dependencies: + - vllm/ + - tests/metrics + - tests/tracing commands: - - pytest -v -s distributed/test_pipeline_parallel.py + - pytest -v -s metrics + - "pip install \ + 'opentelemetry-sdk>=1.26.0,<1.27.0' \ + 'opentelemetry-api>=1.26.0,<1.27.0' \ + 'opentelemetry-exporter-otlp>=1.26.0,<1.27.0' \ + 'opentelemetry-semantic-conventions-ai>=0.4.1,<0.5.0'" + - pytest -v -s tracing -- label: Engine Test +##### fast check tests ##### +##### 1 GPU test ##### + +- label: Regression Test # 5min mirror_hardwares: [amd] + source_file_dependencies: + - vllm/ + - tests/test_regression commands: - - pytest -v -s engine test_sequence.py test_config.py test_logger.py - # OOM in the CI unless we run this separately - - pytest -v -s tokenization + - pip install modelscope + - pytest -v -s test_regression.py + working_dir: "/vllm-workspace/tests" # optional -- label: Entrypoints Test - fast_check: true +- label: Engine Test # 10min mirror_hardwares: [amd] - + source_file_dependencies: + - vllm/ + - tests/engine + - tests/tokenization commands: - - pytest -v -s entrypoints/llm - - pytest -v -s entrypoints/openai + - pytest -v -s engine test_sequence.py test_config.py test_logger.py + # OOM in the CI unless we run this separately + - pytest -v -s tokenization -- label: Examples Test +- label: Examples Test # 15min working_dir: "/vllm-workspace/examples" - mirror_hardwares: [amd] + #mirror_hardwares: [amd] + source_file_dependencies: + - vllm/entrypoints + - examples/ commands: - # install tensorizer for tensorize_vllm_model.py - - pip install awscli tensorizer + - pip install awscli tensorizer # for llava example and tensorizer test - python3 offline_inference.py - python3 cpu_offload.py + - python3 offline_inference_chat.py - python3 offline_inference_with_prefix.py - python3 llm_engine_example.py - python3 offline_inference_vision_language.py + - python3 offline_inference_vision_language_multi_image.py - python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors + - python3 offline_inference_encoder_decoder.py + - python3 offline_profile.py --model facebook/opt-125m -- label: Inputs Test +- label: Prefix Caching Test # 9min #mirror_hardwares: [amd] + source_file_dependencies: + - vllm/ + - tests/prefix_caching commands: - - pytest -v -s test_inputs.py - - pytest -v -s multimodal + - pytest -v -s prefix_caching -# - label: Kernels Test %N -# #mirror_hardwares: [amd] -# commands: -# - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp310-cp310-linux_x86_64.whl -# - pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT -# parallelism: 4 +- label: Samplers Test # 36min + source_file_dependencies: + - vllm/model_executor/layers + - vllm/sampling_metadata.py + - tests/samplers + commands: + - pytest -v -s samplers + - VLLM_USE_FLASHINFER_SAMPLER=1 pytest -v -s samplers -- label: Models Test - #mirror_hardwares: [amd] +- label: LogitsProcessor Test # 5min + mirror_hardwares: [amd] + source_file_dependencies: + - vllm/model_executor/layers + - tests/test_logits_processor + command: pytest -v -s test_logits_processor.py + +- label: Speculative decoding tests # 30min + source_file_dependencies: + - vllm/spec_decode + - tests/spec_decode commands: - - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.4-cp310-cp310-linux_x86_64.whl - - pytest -v -s models -m \"not vlm\" + - pytest -v -s spec_decode/e2e/test_multistep_correctness.py + - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py -- label: Vision Language Models Test +- label: LoRA Test %N # 15min each mirror_hardwares: [amd] + source_file_dependencies: + - vllm/lora + - tests/lora + command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py + parallelism: 4 + +- label: "PyTorch Fullgraph Smoke Test" # 9min + fast_check: true + source_file_dependencies: + - vllm/ + - tests/compile commands: - - pytest -v -s models -m vlm + - pytest -v -s compile/test_basic_correctness.py + +# TODO: re-write in comparison tests, and fix symbolic shape +# for quantization ops. +# - label: "PyTorch Fullgraph Test" # 18min +# source_file_dependencies: +# - vllm/ +# - tests/compile +# commands: +# - pytest -v -s compile/test_full_graph.py -- label: Prefix Caching Test +- label: Kernels Test %N # 1h each mirror_hardwares: [amd] + source_file_dependencies: + - csrc/ + - vllm/attention + - tests/kernels commands: - - pytest -v -s prefix_caching + - pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT + parallelism: 4 -- label: Samplers Test - #mirror_hardwares: [amd] - command: pytest -v -s samplers +- label: Tensorizer Test # 11min + mirror_hardwares: [amd] + soft_fail: true + source_file_dependencies: + - vllm/model_executor/model_loader + - tests/tensorizer_loader + commands: + - apt-get update && apt-get install -y curl libsodium23 + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - pytest -v -s tensorizer_loader -- label: LogitsProcessor Test +- label: Benchmarks # 9min + working_dir: "/vllm-workspace/.buildkite" mirror_hardwares: [amd] - command: pytest -v -s test_logits_processor.py + source_file_dependencies: + - benchmarks/ + commands: + - pip install aiohttp + - bash run-benchmarks.sh -- label: Utils Test +- label: Quantization Test # 33min + source_file_dependencies: + - csrc/ + - vllm/model_executor/layers/quantization + - tests/quantization + command: VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization + +- label: LM Eval Small Models # 53min + working_dir: "/vllm-workspace/.buildkite/lm-eval-harness" + source_file_dependencies: + - csrc/ + - vllm/model_executor/layers/quantization commands: - - pytest -v -s test_utils.py - - pytest -v -s test_embedded_commit.py + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - bash ./run-tests.sh -c configs/models-small.txt -t 1 -- label: Worker Test - mirror_hardwares: [amd] - command: pytest -v -s worker +- label: Encoder Decoder tests # 5min + source_file_dependencies: + - vllm/ + - tests/encoder_decoder + commands: + - pytest -v -s encoder_decoder + +- label: OpenAI-Compatible Tool Use # 20 min + fast_check: false + mirror_hardwares: [ amd ] + source_file_dependencies: + - vllm/ + - tests/tool_use + commands: + - pytest -v -s tool_use -- label: Speculative decoding tests - #mirror_hardwares: [amd] +##### models test ##### + +- label: Basic Models Test # 3min + source_file_dependencies: + - vllm/ + - tests/models commands: - # See https://github.com/vllm-project/vllm/issues/5152 - - export VLLM_ATTENTION_BACKEND=XFORMERS - - pytest -v -s spec_decode + - pip install -e ./plugins/vllm_add_dummy_model + - pytest -v -s models/test_oot_registration.py # it needs a clean process + - pytest -v -s models/*.py --ignore=models/test_oot_registration.py -# - label: LoRA Test %N -# #mirror_hardwares: [amd] -# command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py -# parallelism: 4 +- label: Decoder-only Language Models Test # 1h36min + #mirror_hardwares: [amd] + source_file_dependencies: + - vllm/ + - tests/models/decoder_only/language + commands: + - pytest -v -s models/decoder_only/language -# - label: LoRA Long Context (Distributed) -# #mirror_hardwares: [amd] -# num_gpus: 4 -# # This test runs llama 13B, so it is required to run on 4 GPUs. -# commands: -# # FIXIT: find out which code initialize cuda before running the test -# # before the fix, we need to use spawn to test it -# - export VLLM_WORKER_MULTIPROC_METHOD=spawn -# - pytest -v -s -x lora/test_long_context.py +- label: Decoder-only Multi-Modal Models Test # 1h31min + #mirror_hardwares: [amd] + source_file_dependencies: + - vllm/ + - tests/models/decoder_only/audio_language + - tests/models/decoder_only/vision_language + commands: + - pytest -v -s models/decoder_only/audio_language + - pytest -v -s models/decoder_only/vision_language -- label: Tensorizer Test +- label: Other Models Test # 6min #mirror_hardwares: [amd] - fast_check: true + source_file_dependencies: + - vllm/ + - tests/models/embedding/language + - tests/models/embedding/vision_language + - tests/models/encoder_decoder/language + - tests/models/encoder_decoder/vision_language commands: - - apt-get install -y curl libsodium23 - - export VLLM_WORKER_MULTIPROC_METHOD=spawn - - pytest -v -s tensorizer_loader + - pytest -v -s models/embedding/language + - pytest -v -s models/embedding/vision_language + - pytest -v -s models/encoder_decoder/language + - pytest -v -s models/encoder_decoder/vision_language + +# This test is used only in PR development phase to test individual models and should never run on main +- label: Custom Models Test + optional: true + commands: + - echo 'Testing custom models...' + # PR authors can temporarily add commands below to test individual models + # e.g. pytest -v -s models/encoder_decoder/vision_language/test_mllama.py + # *To avoid merge conflicts, remember to REMOVE (not just comment out) them before merging the PR* -- label: Metrics Test - mirror_hardwares: [amd] - command: pytest -v -s metrics +##### 1 GPU test ##### +##### multi gpus test ##### + +- label: Distributed Comm Ops Test # 7min + working_dir: "/vllm-workspace/tests" + num_gpus: 2 + source_file_dependencies: + - vllm/distributed + - tests/distributed + commands: + - pytest -v -s distributed/test_comm_ops.py + - pytest -v -s distributed/test_shm_broadcast.py + +- label: 2 Node Tests (4 GPUs in total) # 16min + working_dir: "/vllm-workspace/tests" + num_gpus: 2 + num_nodes: 2 + source_file_dependencies: + - vllm/distributed/ + - vllm/engine/ + - vllm/executor/ + - vllm/model_executor/models/ + - tests/distributed/ + commands: + - # the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up) + - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep -q 'Same node test passed' + - VLLM_MULTI_NODE=1 pytest -v -s distributed/test_multi_node_assignment.py + - VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py + - # the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up) + - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep -q 'Same node test passed' -- label: Quantization Test +- label: Distributed Tests (2 GPUs) # 40min #mirror_hardwares: [amd] - command: pytest -v -s quantization + working_dir: "/vllm-workspace/tests" + num_gpus: 2 + source_file_dependencies: + - vllm/distributed/ + - vllm/engine/ + - vllm/executor/ + - vllm/model_executor/models/ + - tests/distributed/ + - vllm/compilation + commands: + - pytest -v -s ./compile/test_basic_correctness.py + - pytest -v -s ./compile/test_wrapper.py + - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed' + - TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m distributed_2_gpus + # Avoid importing model tests that cause CUDA reinitialization error + - pytest models/encoder_decoder/language/test_bart.py -v -s -m distributed_2_gpus + - pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m distributed_2_gpus + - pytest models/decoder_only/vision_language/test_broadcast.py -v -s -m distributed_2_gpus + - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py + - pip install -e ./plugins/vllm_add_dummy_model + - pytest -v -s distributed/test_distributed_oot.py + - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py + - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py -- label: Tracing Test - commands: - - "pip install \ - opentelemetry-sdk \ - opentelemetry-api \ - opentelemetry-exporter-otlp \ - opentelemetry-semantic-conventions-ai" - - pytest -v -s tracing - -- label: Benchmarks - working_dir: "/vllm-workspace/.buildkite" - mirror_hardwares: [amd] +- label: Multi-step Tests (4 GPUs) # 36min + working_dir: "/vllm-workspace/tests" + num_gpus: 4 + source_file_dependencies: + - vllm/model_executor/layers/sampler.py + - vllm/sequence.py + - vllm/worker/worker_base.py + - vllm/worker/worker.py + - vllm/worker/multi_step_worker.py + - vllm/worker/model_runner_base.py + - vllm/worker/model_runner.py + - vllm/worker/multi_step_model_runner.py + - vllm/engine + - tests/multi_step commands: - - pip install aiohttp - - bash run-benchmarks.sh + - pytest -v -s multi_step/test_correctness_async_llm.py + - pytest -v -s multi_step/test_correctness_llm.py -- label: LM Eval Small Models - working_dir: "/vllm-workspace/.buildkite/lm-eval-harness" +- label: Pipeline Parallelism Test # 45min + working_dir: "/vllm-workspace/tests" + num_gpus: 4 + source_file_dependencies: + - vllm/distributed/ + - vllm/engine/ + - vllm/executor/ + - vllm/model_executor/models/ + - tests/distributed/ commands: - - pip install lm-eval - - export VLLM_WORKER_MULTIPROC_METHOD=spawn - - bash ./run-tests.sh -c configs/models-small.txt -t 1 + - pytest -v -s distributed/test_pp_cudagraph.py + - pytest -v -s distributed/test_pipeline_parallel.py -- label: LM Eval Large Models - gpu: a100 +- label: LoRA Long Context (Distributed) # 11min + # This test runs llama 13B, so it is required to run on 4 GPUs. num_gpus: 4 - working_dir: "/vllm-workspace/.buildkite/lm-eval-harness" + soft_fail: true + source_file_dependencies: + - vllm/lora + - tests/lora/test_long_context commands: - - pip install lm-eval - - export VLLM_WORKER_MULTIPROC_METHOD=spawn - - bash ./run-tests.sh -c configs/models-large.txt -t 4 + # FIXIT: find out which code initialize cuda before running the test + # before the fix, we need to use spawn to test it + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - pytest -v -s -x lora/test_long_context.py -- label: Documentation Build - working_dir: "/vllm-workspace/test_docs/docs" - fast_check: true - no_gpu: True +- label: Weight Loading Multiple GPU Test # 33min + working_dir: "/vllm-workspace/tests" + num_gpus: 2 + source_file_dependencies: + - vllm/ + - tests/weight_loading commands: - - pip install -r requirements-docs.txt - - SPHINXOPTS=\"-W\" make html + - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models.txt + +- label: Weight Loading Multiple GPU Test - Large Models # optional + working_dir: "/vllm-workspace/tests" + num_gpus: 2 + gpu: a100 + optional: true + source_file_dependencies: + - vllm/ + - tests/weight_loading + commands: + - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models-large.txt + + +##### multi gpus test ##### +##### A100 test ##### -- label: Distributed Tests (A100) +- label: Distributed Tests (A100) # optional gpu: a100 num_gpus: 4 + source_file_dependencies: + - vllm/ commands: # NOTE: don't test llama model here, it seems hf implementation is buggy # see https://github.com/vllm-project/vllm/pull/5689 for details - pytest -v -s distributed/test_custom_all_reduce.py - - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.4-cp310-cp310-linux_x86_64.whl - - TARGET_TEST_SUITE=A100 pytest -v -s distributed/test_basic_distributed_correctness.py + - TARGET_TEST_SUITE=A100 pytest basic_correctness/ -v -s -m distributed_2_gpus - pytest -v -s -x lora/test_mixtral.py + +- label: LM Eval Large Models # optional + gpu: a100 + num_gpus: 4 + working_dir: "/vllm-workspace/.buildkite/lm-eval-harness" + source_file_dependencies: + - csrc/ + - vllm/model_executor/layers/quantization + commands: + - export VLLM_WORKER_MULTIPROC_METHOD=spawn + - bash ./run-tests.sh -c configs/models-large.txt -t 4 diff --git a/.dockerignore b/.dockerignore index 5cfe0dcb065dc88226d754e98abbaa52125b1ccc..3863656915d035c3831614a5fcba05e09699542f 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1 +1,33 @@ +/.venv +/build +dist vllm/*.so + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +.mypy_cache + +# Distribution / packaging +.Python +/build/ +cmake-build-*/ +CMakeUserPresets.json +develop-eggs/ +/dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 0000000000000000000000000000000000000000..cd721971d01d6b1fc333458f4e9e05bfaa976641 --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1,30 @@ +# See https://help.github.com/articles/about-codeowners/ +# for more info about CODEOWNERS file + +# This lists cover the "core" components of vLLM that require careful review +/vllm/attention/backends/abstract.py @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill +/vllm/core @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill +/vllm/engine/llm_engine.py @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill +/vllm/executor/executor_base.py @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill +/vllm/worker/worker_base.py @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill +/vllm/worker/worker.py @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill +/vllm/model_executor/layers/sampler.py @WoosukKwon @zhuohan123 @youkaichao @alexm-neuralmagic @comaniac @njhill +CMakeLists.txt @tlrmchlsmth @WoosukKwon + +# Test ownership +/tests/async_engine @njhill @robertgshaw2-neuralmagic @simon-mo +/tests/test_inputs.py @DarkLight1337 @ywang96 +/tests/entrypoints @DarkLight1337 @robertgshaw2-neuralmagic @simon-mo +/tests/models @DarkLight1337 @ywang96 +/tests/multimodal @DarkLight1337 @ywang96 +/tests/prefix_caching @comaniac @KuntaiDu +/tests/spec_decode @njhill @LiuXiaoxuanPKU +/tests/kernels @tlrmchlsmth @WoosukKwon +/tests/quantization @mgoin @robertgshaw2-neuralmagic +/.buildkite/lm-eval-harness @mgoin @simon-mo +/tests/distributed/test_multi_node_assignment.py @youkaichao +/tests/distributed/test_pipeline_parallel.py @youkaichao +/tests/distributed/test_same_node.py @youkaichao +/tests/multi_step @alexm-neuralmagic @comaniac +/tests/weight_loading @mgoin @youkaichao +/tests/basic_correctness/test_chunked_prefill @rkooo567 @comaniac diff --git a/.github/ISSUE_TEMPLATE/100-documentation.yml b/.github/ISSUE_TEMPLATE/100-documentation.yml index 501c0aa48b887431ef69825a07dce98fba3d8959..74d397b231acdc23dd929637b1fa68f7d4ecd25b 100644 --- a/.github/ISSUE_TEMPLATE/100-documentation.yml +++ b/.github/ISSUE_TEMPLATE/100-documentation.yml @@ -20,3 +20,10 @@ body: attributes: value: > Thanks for contributing 🎉! +- type: checkboxes + id: askllm + attributes: + label: Before submitting a new issue... + options: + - label: Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the [documentation page](https://docs.vllm.ai/en/latest/), which can answer lots of frequently asked questions. + required: true diff --git a/.github/ISSUE_TEMPLATE/200-installation.yml b/.github/ISSUE_TEMPLATE/200-installation.yml index df41ade8c3c018dfd4ac8e2c2671dbfe352dda61..590e56c137813059a695e501833138e0a85a947e 100644 --- a/.github/ISSUE_TEMPLATE/200-installation.yml +++ b/.github/ISSUE_TEMPLATE/200-installation.yml @@ -38,3 +38,10 @@ body: attributes: value: > Thanks for contributing 🎉! +- type: checkboxes + id: askllm + attributes: + label: Before submitting a new issue... + options: + - label: Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the [documentation page](https://docs.vllm.ai/en/latest/), which can answer lots of frequently asked questions. + required: true diff --git a/.github/ISSUE_TEMPLATE/300-usage.yml b/.github/ISSUE_TEMPLATE/300-usage.yml index 54763af1058f643a6947f23fb4bcf3af9b145826..004798a388a63d949c8ef0e73194e53c3c4c0e39 100644 --- a/.github/ISSUE_TEMPLATE/300-usage.yml +++ b/.github/ISSUE_TEMPLATE/300-usage.yml @@ -36,3 +36,10 @@ body: attributes: value: > Thanks for contributing 🎉! +- type: checkboxes + id: askllm + attributes: + label: Before submitting a new issue... + options: + - label: Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the [documentation page](https://docs.vllm.ai/en/latest/), which can answer lots of frequently asked questions. + required: true diff --git a/.github/ISSUE_TEMPLATE/400-bug report.yml b/.github/ISSUE_TEMPLATE/400-bug report.yml index ce980c3f4a01d4df29be6feafd07922d164242f4..30db1721a9df71b4b7ef017208d7b6053cf8cb01 100644 --- a/.github/ISSUE_TEMPLATE/400-bug report.yml +++ b/.github/ISSUE_TEMPLATE/400-bug report.yml @@ -20,11 +20,25 @@ body: ``` It is suggested to download and execute the latest script, as vllm might frequently update the diagnosis information needed for accurately and quickly responding to issues. value: | +
+ The output of `python collect_env.py` + ```text - The output of `python collect_env.py` + Your output of `python collect_env.py` here ``` + +
validations: required: true +- type: textarea + attributes: + label: Model Input Dumps + description: | + If you are facing crashing due to illegal memory access or other issues with model execution, vLLM may dump the problematic input of the model. In this case, you will see the message `Error in model execution (input dumped to /tmp/err_xxx.pkl)`. If you see this message, please zip the file (because GitHub doesn't support .pkl file format) and upload it here. This will help us to reproduce the issue and facilitate the debugging process. + placeholder: | + Upload the dumped input file. + validations: + required: false - type: textarea attributes: label: 🐛 Describe the bug @@ -84,3 +98,10 @@ body: - If the error only appears in vllm, please provide the detailed script of how you run `transformers` and `vllm`, also highlight the difference and what you expect. Thanks for contributing 🎉! +- type: checkboxes + id: askllm + attributes: + label: Before submitting a new issue... + options: + - label: Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the [documentation page](https://docs.vllm.ai/en/latest/), which can answer lots of frequently asked questions. + required: true diff --git a/.github/ISSUE_TEMPLATE/500-feature request.yml b/.github/ISSUE_TEMPLATE/500-feature request.yml index 47a90628c76cea99d70b42db2bafd771ae389303..097d88f50930d64e765c6fa92d3d2094c69426e8 100644 --- a/.github/ISSUE_TEMPLATE/500-feature request.yml +++ b/.github/ISSUE_TEMPLATE/500-feature request.yml @@ -29,3 +29,10 @@ body: attributes: value: > Thanks for contributing 🎉! +- type: checkboxes + id: askllm + attributes: + label: Before submitting a new issue... + options: + - label: Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the [documentation page](https://docs.vllm.ai/en/latest/), which can answer lots of frequently asked questions. + required: true diff --git a/.github/ISSUE_TEMPLATE/600-new model.yml b/.github/ISSUE_TEMPLATE/600-new model.yml index bbddbfd67138ac91ec344351622c356551ad6364..794617a0cfdf6401c207bb2bae085bcd05f86921 100644 --- a/.github/ISSUE_TEMPLATE/600-new model.yml +++ b/.github/ISSUE_TEMPLATE/600-new model.yml @@ -31,3 +31,10 @@ body: attributes: value: > Thanks for contributing 🎉! +- type: checkboxes + id: askllm + attributes: + label: Before submitting a new issue... + options: + - label: Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the [documentation page](https://docs.vllm.ai/en/latest/), which can answer lots of frequently asked questions. + required: true diff --git a/.github/ISSUE_TEMPLATE/700-performance discussion.yml b/.github/ISSUE_TEMPLATE/700-performance discussion.yml index 4f8843420a94e94ba206aded662750c165e1bb3d..273f50d59cf76a61fa96e02478885f27d7d0f40b 100644 --- a/.github/ISSUE_TEMPLATE/700-performance discussion.yml +++ b/.github/ISSUE_TEMPLATE/700-performance discussion.yml @@ -50,3 +50,10 @@ body: attributes: value: > Thanks for contributing 🎉! +- type: checkboxes + id: askllm + attributes: + label: Before submitting a new issue... + options: + - label: Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the [documentation page](https://docs.vllm.ai/en/latest/), which can answer lots of frequently asked questions. + required: true diff --git a/.github/ISSUE_TEMPLATE/750-RFC.yml b/.github/ISSUE_TEMPLATE/750-RFC.yml index 5382b124dcd799e2b084baf61038133365bd07fd..e447c077473f0ed5dafef68e396bc32eb7d35a3e 100644 --- a/.github/ISSUE_TEMPLATE/750-RFC.yml +++ b/.github/ISSUE_TEMPLATE/750-RFC.yml @@ -47,3 +47,10 @@ body: attributes: value: > Thanks for contributing 🎉! +- type: checkboxes + id: askllm + attributes: + label: Before submitting a new issue... + options: + - label: Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the [documentation page](https://docs.vllm.ai/en/latest/), which can answer lots of frequently asked questions. + required: true diff --git a/.github/ISSUE_TEMPLATE/800-misc discussion.yml b/.github/ISSUE_TEMPLATE/800-misc discussion.yml index ddb10f72db293dfa4427f4f3f3fa68e5d371d382..79e6e9080d51cc513a7a41ee7bf7a1d8baf2dad0 100644 --- a/.github/ISSUE_TEMPLATE/800-misc discussion.yml +++ b/.github/ISSUE_TEMPLATE/800-misc discussion.yml @@ -19,3 +19,10 @@ body: attributes: value: > Thanks for contributing 🎉! +- type: checkboxes + id: askllm + attributes: + label: Before submitting a new issue... + options: + - label: Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the [documentation page](https://docs.vllm.ai/en/latest/), which can answer lots of frequently asked questions. + required: true diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 262ce8e1530a83a05f618e3874a8d326ca077c3b..be0afc6305044e84f68426e527c37f66faa99f31 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -39,6 +39,16 @@ FIX #xxxx (*link existing issues this PR will resolve*)
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.
  • +

    Adding or changing kernels

    +

    Each custom kernel needs a schema and one or more implementations to be registered with PyTorch.

    + +

    Notes for Large Changes

    Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

    diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000000000000000000000000000000000000..6fddca0d6e4b9fb23127719b8c5f7a7e1b5821c6 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,7 @@ +version: 2 +updates: + # Maintain dependencies for GitHub Actions + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "weekly" diff --git a/.github/workflows/actionlint.yml b/.github/workflows/actionlint.yml new file mode 100644 index 0000000000000000000000000000000000000000..2a0e3239f58dadc565d0e690ae28ffc9e225c1ea --- /dev/null +++ b/.github/workflows/actionlint.yml @@ -0,0 +1,37 @@ +name: Lint GitHub Actions workflows +on: + push: + branches: + - "main" + paths: + - '.github/workflows/*.ya?ml' + - '.github/workflows/actionlint.*' + pull_request: + branches: + - "main" + paths: + - '.github/workflows/*.ya?ml' + - '.github/workflows/actionlint.*' + +env: + LC_ALL: en_US.UTF-8 + +defaults: + run: + shell: bash + +permissions: + contents: read + +jobs: + actionlint: + runs-on: ubuntu-latest + steps: + - name: "Checkout" + uses: actions/checkout@eef61447b9ff4aafe5dcd4e0bbf5d482be7e7871 # v4.2.1 + with: + fetch-depth: 0 + + - name: "Run actionlint" + run: | + tools/actionlint.sh -color diff --git a/.github/workflows/add_label_automerge.yml b/.github/workflows/add_label_automerge.yml index cd53b764c7200eee0b335991648812ba9bf0f9db..2e7c7f7f087afe279f5dfc7abd0aa9ae9b3808b1 100644 --- a/.github/workflows/add_label_automerge.yml +++ b/.github/workflows/add_label_automerge.yml @@ -8,7 +8,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Add label - uses: actions/github-script@v5 + uses: actions/github-script@v7 with: script: | github.rest.issues.addLabels({ diff --git a/.github/workflows/add_label_ready_comment.yml b/.github/workflows/add_label_ready_comment.yml deleted file mode 100644 index 729c1452af03db14900ca90564ea308fa98e5dbb..0000000000000000000000000000000000000000 --- a/.github/workflows/add_label_ready_comment.yml +++ /dev/null @@ -1,23 +0,0 @@ -name: Add Ready Label on Ready Comment - -on: - issue_comment: - types: [created] - -jobs: - add-ready-label: - runs-on: ubuntu-latest - if: github.event.issue.pull_request && contains(github.event.comment.body, '/ready') - steps: - - name: Add label - uses: actions/github-script@v5 - with: - script: | - github.rest.issues.addLabels({ - owner: context.repo.owner, - repo: context.repo.repo, - issue_number: context.issue.number, - labels: ['ready'] - }) - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/clang-format.yml b/.github/workflows/clang-format.yml index 79b85d8cad0d5f7a9d5b492f96dd6748a5004d0d..064af291009fa63644946502967ddbaea7160e3f 100644 --- a/.github/workflows/clang-format.yml +++ b/.github/workflows/clang-format.yml @@ -17,9 +17,9 @@ jobs: matrix: python-version: ["3.11"] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install dependencies @@ -30,6 +30,11 @@ jobs: run: | EXCLUDES=( 'csrc/moe/topk_softmax_kernels.cu' + 'csrc/quantization/gguf/ggml-common.h' + 'csrc/quantization/gguf/dequantize.cuh' + 'csrc/quantization/gguf/vecdotq.cuh' + 'csrc/quantization/gguf/mmq.cuh' + 'csrc/quantization/gguf/mmvq.cuh' ) find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \ | grep -vFf <(printf "%s\n" "${EXCLUDES[@]}") \ diff --git a/.github/workflows/matchers/actionlint.json b/.github/workflows/matchers/actionlint.json new file mode 100644 index 0000000000000000000000000000000000000000..4613e1617bfe25b72dbf381d0da6e309112fbcbf --- /dev/null +++ b/.github/workflows/matchers/actionlint.json @@ -0,0 +1,17 @@ +{ + "problemMatcher": [ + { + "owner": "actionlint", + "pattern": [ + { + "regexp": "^(?:\\x1b\\[\\d+m)?(.+?)(?:\\x1b\\[\\d+m)*:(?:\\x1b\\[\\d+m)*(\\d+)(?:\\x1b\\[\\d+m)*:(?:\\x1b\\[\\d+m)*(\\d+)(?:\\x1b\\[\\d+m)*: (?:\\x1b\\[\\d+m)*(.+?)(?:\\x1b\\[\\d+m)* \\[(.+?)\\]$", + "file": 1, + "line": 2, + "column": 3, + "message": 4, + "code": 5 + } + ] + } + ] +} diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml index 8d423657630c293b69f4b2c1191e30a9ab8f0822..22e3564779ad9e90fa120c40f0f6390e9d173f5e 100644 --- a/.github/workflows/mypy.yaml +++ b/.github/workflows/mypy.yaml @@ -11,38 +11,25 @@ on: - main jobs: - ruff: + mypy: runs-on: ubuntu-latest strategy: matrix: python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | python -m pip install --upgrade pip - pip install mypy==1.9.0 + pip install mypy==1.11.1 pip install types-setuptools pip install types-PyYAML pip install types-requests pip install types-setuptools - name: Mypy run: | - mypy - mypy tests --follow-imports skip - mypy vllm/attention --follow-imports skip - mypy vllm/core --follow-imports skip - mypy vllm/distributed --follow-imports skip - mypy vllm/engine --follow-imports skip - mypy vllm/entrypoints --follow-imports skip - mypy vllm/executor --follow-imports skip - mypy vllm/lora --follow-imports skip - mypy vllm/model_executor --follow-imports skip - mypy vllm/prompt_adapter --follow-imports skip - mypy vllm/spec_decode --follow-imports skip - mypy vllm/worker --follow-imports skip - + tools/mypy.sh diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index aeeaf6efab04353181833ce0ea22030d8be12b39..96549b3f99181e48b349716388f508237388df97 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -21,16 +21,16 @@ jobs: upload_url: ${{ steps.create_release.outputs.upload_url }} steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Extract branch info shell: bash run: | - echo "release_tag=${GITHUB_REF#refs/*/}" >> $GITHUB_ENV + echo "release_tag=${GITHUB_REF#refs/*/}" >> "$GITHUB_ENV" - name: Create Release id: create_release - uses: "actions/github-script@v6" + uses: "actions/github-script@v7" env: RELEASE_TAG: ${{ env.release_tag }} with: @@ -54,7 +54,7 @@ jobs: steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Setup ccache uses: hendrikmuhs/ccache-action@v1.2 @@ -68,7 +68,7 @@ jobs: bash -x .github/workflows/scripts/env.sh - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} @@ -86,10 +86,10 @@ jobs: CMAKE_BUILD_TYPE: Release # do not compile with debug symbol to reduce wheel size run: | bash -x .github/workflows/scripts/build.sh ${{ matrix.python-version }} ${{ matrix.cuda-version }} - wheel_name=$(ls dist/*whl | xargs -n 1 basename) + wheel_name=$(find dist -name "*whl" -print0 | xargs -0 -n 1 basename) asset_name=${wheel_name//"linux"/"manylinux1"} - echo "wheel_name=${wheel_name}" >> $GITHUB_ENV - echo "asset_name=${asset_name}" >> $GITHUB_ENV + echo "wheel_name=${wheel_name}" >> "$GITHUB_ENV" + echo "asset_name=${asset_name}" >> "$GITHUB_ENV" - name: Upload Release Asset uses: actions/upload-release-asset@v1 diff --git a/.github/workflows/reminder_comment.yml b/.github/workflows/reminder_comment.yml index 390c88bb65308a23cadbc6f2f33245e45f230486..d1791c3bc865ab081d19e97c72421e9a5f3585d5 100644 --- a/.github/workflows/reminder_comment.yml +++ b/.github/workflows/reminder_comment.yml @@ -8,14 +8,14 @@ jobs: runs-on: ubuntu-latest steps: - name: Remind to run full CI on PR - uses: actions/github-script@v6 + uses: actions/github-script@v7 with: script: | github.rest.issues.createComment({ owner: context.repo.owner, repo: context.repo.repo, issue_number: context.issue.number, - body: '👋 Hi! Thank you for contributing to the vLLM project.\n Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run `fastcheck` CI which consists a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of default ones by unblocking the steps in your `fast-check` build on Buildkite UI. \n\nOnce the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge).\n\n To run full CI, you can do one of these:\n- Comment `/ready` on the PR\n- Add `ready` label to the PR\n- Enable auto-merge.\n\n🚀' + body: '👋 Hi! Thank you for contributing to the vLLM project.\n Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run `fastcheck` CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your `fastcheck` build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping `simon-mo` or `khluu` to add you in our Buildkite org. \n\nOnce the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.\n\n To run CI, PR reviewers can do one of these:\n- Add `ready` label to the PR\n- Enable auto-merge.\n\n🚀' }) env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/remove_label_not_ready_comment.yml b/.github/workflows/remove_label_not_ready_comment.yml deleted file mode 100644 index d1da7726eaee3d5a24f98d2149d665e7a02ce0a3..0000000000000000000000000000000000000000 --- a/.github/workflows/remove_label_not_ready_comment.yml +++ /dev/null @@ -1,23 +0,0 @@ -name: Remove ready Label on notready Comment - -on: - issue_comment: - types: [created] - -jobs: - add-ready-label: - runs-on: ubuntu-latest - if: github.event.issue.pull_request && contains(github.event.comment.body, '/notready') - steps: - - name: Remove ready label - uses: actions/github-script@v5 - with: - script: | - github.rest.issues.removeLabel({ - owner: context.repo.owner, - repo: context.repo.repo, - issue_number: context.issue.number, - name: 'ready' - }) - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/ruff.yml b/.github/workflows/ruff.yml index 1a794af572fefba3c138f50fbd8744fe74b65e6f..be73fb85ed1fa85c0d698b7c21a8f12f766e3393 100644 --- a/.github/workflows/ruff.yml +++ b/.github/workflows/ruff.yml @@ -17,18 +17,18 @@ jobs: matrix: python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | python -m pip install --upgrade pip - pip install ruff==0.1.5 codespell==2.3.0 tomli==2.0.1 isort==5.13.2 + pip install -r requirements-lint.txt - name: Analysing the code with ruff run: | - ruff . + ruff check . - name: Spelling check with codespell run: | codespell --toml pyproject.toml diff --git a/.github/workflows/scripts/build.sh b/.github/workflows/scripts/build.sh index 0a759d303238b0275cbd29cef9aad77e987f8adc..122e4e101e2011898335d8863f5bfae403ff2792 100644 --- a/.github/workflows/scripts/build.sh +++ b/.github/workflows/scripts/build.sh @@ -1,4 +1,5 @@ #!/bin/bash +set -eux python_executable=python$1 cuda_home=/usr/local/cuda-$2 @@ -8,12 +9,15 @@ PATH=${cuda_home}/bin:$PATH LD_LIBRARY_PATH=${cuda_home}/lib64:$LD_LIBRARY_PATH # Install requirements -$python_executable -m pip install wheel packaging -$python_executable -m pip install -r requirements-cuda.txt +$python_executable -m pip install -r requirements-build.txt -r requirements-cuda.txt # Limit the number of parallel jobs to avoid OOM export MAX_JOBS=1 # Make sure release wheels are built for the following architectures export TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.9 9.0+PTX" +export VLLM_FA_CMAKE_GPU_ARCHES="80-real;90-real" + +bash tools/check_repo.sh + # Build $python_executable setup.py bdist_wheel --dist-dir=dist diff --git a/.github/workflows/yapf.yml b/.github/workflows/yapf.yml index c89f82dfaaaf6aabbdb7dd2168c939a43691facb..eb728ae04dfc18f888f2d8d55e9d50e8a10d28f2 100644 --- a/.github/workflows/yapf.yml +++ b/.github/workflows/yapf.yml @@ -16,9 +16,9 @@ jobs: matrix: python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install dependencies diff --git a/.gitignore b/.gitignore index 17184b19127ca10f25a29a8add79f8ab679285e4..1ea6e3419db2a16bd8b77438fb25be6e792475c4 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,8 @@ -# vllm commit id, generated by setup.py -vllm/commit_id.py +# version file generated by setuptools-scm +/vllm/_version.py + +# vllm-flash-attn built from source +vllm/vllm_flash_attn/ # Byte-compiled / optimized / DLL files __pycache__/ @@ -12,6 +15,8 @@ __pycache__/ # Distribution / packaging .Python build/ +cmake-build-*/ +CMakeUserPresets.json develop-eggs/ dist/ downloads/ @@ -28,6 +33,7 @@ share/python-wheels/ .installed.cfg *.egg MANIFEST +/.deps/ # PyInstaller # Usually these files are written by a python script from a template @@ -87,6 +93,9 @@ target/ profile_default/ ipython_config.py +# generated files +**/generated/** + # pyenv # For a library or package, you might want to ignore these files since the code is # intended to run in multiple environments; otherwise, check them in: @@ -189,4 +198,7 @@ _build/ hip_compat.h # Benchmark dataset -*.json +benchmarks/*.json + +# Linting +actionlint diff --git a/.readthedocs.yaml b/.readthedocs.yaml index f1959ad2743f390c666b9beda9004a71cde8038e..42cbf18a0f7122002bc001ac25e53d9be8a6ed55 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -13,10 +13,10 @@ sphinx: fail_on_warning: true # If using Sphinx, optionally build your docs in additional formats such as PDF -formats: - - pdf +formats: [] # Optionally declare the Python requirements required to build your docs python: install: - requirements: docs/requirements-docs.txt + diff --git a/CMakeLists.txt b/CMakeLists.txt index 569ffab03a04b6747f4e19785497038e59e1a2ca..ad9616fc15eb451804d69cdc4c801772596aaea5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,5 +1,16 @@ -cmake_minimum_required(VERSION 3.21) +cmake_minimum_required(VERSION 3.26) +# When building directly using CMake, make sure you run the install step +# (it places the .so files in the correct location). +# +# Example: +# mkdir build && cd build +# cmake -G Ninja -DVLLM_PYTHON_EXECUTABLE=`which python3` -DCMAKE_INSTALL_PREFIX=.. .. +# cmake --build . --target install +# +# If you want to only build one target, make sure to install it manually: +# cmake --build . --target _C +# cmake --install . --component _C project(vllm_extensions LANGUAGES CXX) # CUDA by default, can be overridden by using -DVLLM_TARGET_DEVICE=... (used by setup.py) @@ -12,6 +23,14 @@ message(STATUS "Target device: ${VLLM_TARGET_DEVICE}") include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake) add_compile_options(-w) + + +# Suppress potential warnings about unused manually-specified variables +set(ignoreMe "${VLLM_PYTHON_PATH}") + +# Prevent installation of dependencies (cutlass) by default. +install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS) + # # Supported python versions. These versions will be searched in order, the # first match will be selected. These should be kept in sync with setup.py. @@ -69,19 +88,6 @@ endif() find_package(Torch REQUIRED) # -# Add the `default` target which detects which extensions should be -# built based on platform/architecture. This is the same logic that -# setup.py uses to select which extensions should be built and should -# be kept in sync. -# -# The `default` target makes direct use of cmake easier since knowledge -# of which extensions are supported has been factored in, e.g. -# -# mkdir build && cd build -# cmake -G Ninja -DVLLM_PYTHON_EXECUTABLE=`which python3` -DCMAKE_LIBRARY_OUTPUT_DIRECTORY=../vllm .. -# cmake --build . --target default -# -add_custom_target(default) message(STATUS "Enabling core extension.") # Define _core_C extension @@ -99,8 +105,6 @@ define_gpu_extension_target( USE_SABI 3 WITH_SOABI) -add_dependencies(default _core_C) - # # Forward the non-CUDA device extensions to external CMake scripts. # @@ -143,15 +147,33 @@ else() message(FATAL_ERROR "Can't find CUDA or HIP installation.") endif() -# -# Override the GPU architectures detected by cmake/torch and filter them by -# the supported versions for the current language. -# The final set of arches is stored in `VLLM_GPU_ARCHES`. -# -override_gpu_arches(VLLM_GPU_ARCHES - ${VLLM_GPU_LANG} - "${${VLLM_GPU_LANG}_SUPPORTED_ARCHS}") - + +if(VLLM_GPU_LANG STREQUAL "CUDA") + # + # For cuda we want to be able to control which architectures we compile for on + # a per-file basis in order to cut down on compile time. So here we extract + # the set of architectures we want to compile for and remove the from the + # CMAKE_CUDA_FLAGS so that they are not applied globally. + # + clear_cuda_arches(CUDA_ARCH_FLAGS) + extract_unique_cuda_archs_ascending(CUDA_ARCHS "${CUDA_ARCH_FLAGS}") + message(STATUS "CUDA target architectures: ${CUDA_ARCHS}") + # Filter the target architectures by the supported supported archs + # since for some files we will build for all CUDA_ARCHS. + cuda_archs_loose_intersection(CUDA_ARCHS + "${CUDA_SUPPORTED_ARCHS}" "${CUDA_ARCHS}") + message(STATUS "CUDA supported target architectures: ${CUDA_ARCHS}") +else() + # + # For other GPU targets override the GPU architectures detected by cmake/torch + # and filter them by the supported versions for the current language. + # The final set of arches is stored in `VLLM_GPU_ARCHES`. + # + override_gpu_arches(VLLM_GPU_ARCHES + ${VLLM_GPU_LANG} + "${${VLLM_GPU_LANG}_SUPPORTED_ARCHS}") +endif() + # # Query torch for additional GPU compilation flags for the given # `VLLM_GPU_LANG`. @@ -166,6 +188,17 @@ if(NVCC_THREADS AND VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_GPU_FLAGS "--threads=${NVCC_THREADS}") endif() + +# +# Use FetchContent for C++ dependencies that are compiled as part of vLLM's build process. +# Configure it to place files in vllm/.deps, in order to play nicely with sccache. +# +include(FetchContent) +get_filename_component(PROJECT_ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}" ABSOLUTE) +file(MAKE_DIRECTORY "${FETCHCONTENT_BASE_DIR}") +set(FETCHCONTENT_BASE_DIR "${PROJECT_ROOT_DIR}/.deps") +message(STATUS "FetchContent base directory: ${FETCHCONTENT_BASE_DIR}") + # # Define other extension targets # @@ -184,9 +217,9 @@ set(VLLM_EXT_SRC "csrc/opt/transpose_kernels.cu" "csrc/opt/activation_kernels_opt.cu" "csrc/attention/attention_kernels_opt.cu" + "csrc/attention/attention_kernels_opt_tc.cu" "csrc/opt/layernorm_kernels_opt.cu" - "csrc/quantization/squeezellm/quant_cuda_kernel.cu" - "csrc/quantization/gptq/q_gemm.cu" + # "csrc/quantization/gptq/q_gemm.cu" "csrc/quantization/compressed_tensors/int8_quant_kernels.cu" # "csrc/quantization/fp8/common.cu" "csrc/cuda_utils_kernels.cu" @@ -195,46 +228,188 @@ set(VLLM_EXT_SRC "csrc/torch_bindings.cpp") if(VLLM_GPU_LANG STREQUAL "CUDA") - include(FetchContent) SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library") + + # Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case. + set(CUTLASS_REVISION "v3.5.1" CACHE STRING "CUTLASS revision to use") + FetchContent_Declare( cutlass GIT_REPOSITORY https://github.com/nvidia/cutlass.git - # CUTLASS 3.5.1 - GIT_TAG 06b21349bcf6ddf6a1686a47a137ad1446579db9 + GIT_TAG v3.5.1 GIT_PROGRESS TRUE + + # Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history. + # Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags. + # So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE + GIT_SHALLOW TRUE ) FetchContent_MakeAvailable(cutlass) list(APPEND VLLM_EXT_SRC + "csrc/mamba/mamba_ssm/selective_scan_fwd.cu" + "csrc/mamba/causal_conv1d/causal_conv1d.cu" "csrc/quantization/aqlm/gemm_kernels.cu" "csrc/quantization/awq/gemm_kernels.cu" - "csrc/quantization/marlin/dense/marlin_cuda_kernel.cu" - "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu" - "csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu" - "csrc/quantization/gptq_marlin/gptq_marlin.cu" - "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu" - "csrc/quantization/gptq_marlin/awq_marlin_repack.cu" - "csrc/quantization/fp8/fp8_marlin.cu" + "csrc/quantization/gguf/gguf_kernel.cu" "csrc/custom_all_reduce.cu" - "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu" - "csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu" - "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu") + "csrc/permute_cols.cu" + "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu") + + set_gencode_flags_for_srcs( + SRCS "${VLLM_EXT_SRC}" + CUDA_ARCHS "${CUDA_ARCHS}") + + # Only build Marlin kernels if we are building for at least some compatible archs. + # Keep building Marlin for 9.0 as there are some group sizes and shapes that + # are not supported by Machete yet. + cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.9;9.0" ${CUDA_ARCHS}) + if (MARLIN_ARCHS) + set(MARLIN_SRCS + "csrc/quantization/fp8/fp8_marlin.cu" + "csrc/quantization/marlin/dense/marlin_cuda_kernel.cu" + "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu" + "csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu" + "csrc/quantization/gptq_marlin/gptq_marlin.cu" + "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu" + "csrc/quantization/gptq_marlin/awq_marlin_repack.cu") + set_gencode_flags_for_srcs( + SRCS "${MARLIN_SRCS}" + CUDA_ARCHS "${MARLIN_ARCHS}") + list(APPEND VLLM_EXT_SRC "${MARLIN_SRCS}") + message(STATUS "Building Marlin kernels for archs: ${MARLIN_ARCHS}") + else() + message(STATUS "Not building Marlin kernels as no compatible archs found" + "in CUDA target architectures") + endif() # - # The CUTLASS kernels for Hopper require sm90a to be enabled. - # This is done via the below gencode option, BUT that creates kernels for both sm90 and sm90a. - # That adds an extra 17MB to compiled binary, so instead we selectively enable it. - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0) - set_source_files_properties( - "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu" - PROPERTIES - COMPILE_FLAGS - "-gencode arch=compute_90a,code=sm_90a") + # The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require + # CUDA 12.0 or later (and only work on Hopper, 9.0/9.0a for now). + cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0;9.0a" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS) + set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${SCALED_MM_3X_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_C3X=1") + message(STATUS "Building scaled_mm_c3x for archs: ${SCALED_MM_3X_ARCHS}") + else() + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS) + message(STATUS "Not building scaled_mm_c3x as CUDA Compiler version is " + "not >= 12.0, we recommend upgrading to CUDA 12.0 or " + "later if you intend on running FP8 quantized models on " + "Hopper.") + else() + message(STATUS "Not building scaled_mm_c3x as no compatible archs found " + "in CUDA target architectures") + endif() + + # clear SCALED_MM_3X_ARCHS so the scaled_mm_c2x kernels know we didn't + # build any 3x kernels + set(SCALED_MM_3X_ARCHS) endif() + # + # For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x) + # kernels for the remaining archs that are not already built for 3x. + cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS + "7.5;8.0;8.6;8.9;9.0" "${CUDA_ARCHS}") + # subtract out the archs that are already built for 3x + list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS}) + if (SCALED_MM_2X_ARCHS) + set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${SCALED_MM_2X_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_C2X=1") + message(STATUS "Building scaled_mm_c2x for archs: ${SCALED_MM_2X_ARCHS}") + else() + if (SCALED_MM_3X_ARCHS) + message(STATUS "Not building scaled_mm_c2x as all archs are already built" + " for and covered by scaled_mm_c3x") + else() + message(STATUS "Not building scaled_mm_c2x as no compatible archs found " + "in CUDA target architectures") + endif() + endif() + + + # + # Machete kernels + + # The machete kernels only work on hopper and require CUDA 12.0 or later. + # Only build Machete kernels if we are building for something compatible with sm90a + cuda_archs_loose_intersection(MACHETE_ARCHS "9.0a" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND MACHETE_ARCHS) + # + # For the Machete kernels we automatically generate sources for various + # preselected input type pairs and schedules. + # Generate sources: + set(MACHETE_GEN_SCRIPT + ${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/machete/generate.py) + file(MD5 ${MACHETE_GEN_SCRIPT} MACHETE_GEN_SCRIPT_HASH) + + message(STATUS "Machete generation script hash: ${MACHETE_GEN_SCRIPT_HASH}") + message(STATUS "Last run machete generate script hash: $CACHE{MACHETE_GEN_SCRIPT_HASH}") + + if (NOT DEFINED CACHE{MACHETE_GEN_SCRIPT_HASH} + OR NOT $CACHE{MACHETE_GEN_SCRIPT_HASH} STREQUAL ${MACHETE_GEN_SCRIPT_HASH}) + execute_process( + COMMAND ${CMAKE_COMMAND} -E env + PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:${VLLM_PYTHON_PATH}:$PYTHONPATH + ${Python_EXECUTABLE} ${MACHETE_GEN_SCRIPT} + RESULT_VARIABLE machete_generation_result + OUTPUT_VARIABLE machete_generation_output + OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log + ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log + ) + + if (NOT machete_generation_result EQUAL 0) + message(FATAL_ERROR "Machete generation failed." + " Result: \"${machete_generation_result}\"" + "\nCheck the log for details: " + "${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log") + else() + set(MACHETE_GEN_SCRIPT_HASH ${MACHETE_GEN_SCRIPT_HASH} + CACHE STRING "Last run machete generate script hash" FORCE) + message(STATUS "Machete generation completed successfully.") + endif() + else() + message(STATUS "Machete generation script has not changed, skipping generation.") + endif() + + # Add machete generated sources + file(GLOB MACHETE_GEN_SOURCES "csrc/quantization/machete/generated/*.cu") + list(APPEND VLLM_EXT_SRC ${MACHETE_GEN_SOURCES}) + + # forward compatible + set_gencode_flags_for_srcs( + SRCS "${MACHETE_GEN_SOURCES}" + CUDA_ARCHS "${MACHETE_ARCHS}") + + list(APPEND VLLM_EXT_SRC + csrc/quantization/machete/machete_pytorch.cu) + + message(STATUS "Building Machete kernels for archs: ${MACHETE_ARCHS}") + else() + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 + AND MACHETE_ARCHS) + message(STATUS "Not building Machete kernels as CUDA Compiler version is " + "not >= 12.0, we recommend upgrading to CUDA 12.0 or " + "later if you intend on running w4a16 quantized models on " + "Hopper.") + else() + message(STATUS "Not building Machete kernels as no compatible archs " + "found in CUDA target architectures") + endif() + endif() +# if CUDA endif endif() +message(STATUS "Enabling C extension.") define_gpu_extension_target( _C DESTINATION vllm @@ -246,6 +421,12 @@ define_gpu_extension_target( USE_SABI 3 WITH_SOABI) +# If CUTLASS is compiled on NVCC >= 12.5, it by default uses +# cudaGetDriverEntryPointByVersion as a wrapper to avoid directly calling the +# driver API. This causes problems when linking with earlier versions of CUDA. +# Setting this variable sidesteps the issue by calling the driver directly. +target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1) + # # _moe_C extension # @@ -254,6 +435,36 @@ set(VLLM_MOE_EXT_SRC "csrc/moe/torch_bindings.cpp" "csrc/moe/topk_softmax_kernels.cu") +set_gencode_flags_for_srcs( + SRCS "${VLLM_MOE_EXT_SRC}" + CUDA_ARCHS "${CUDA_ARCHS}") + +if(VLLM_GPU_LANG STREQUAL "CUDA") + cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.9;9.0" "${CUDA_ARCHS}") + if (MARLIN_MOE_ARCHS) + set(MARLIN_MOE_SRC + "csrc/moe/marlin_kernels/marlin_moe_kernel.h" + "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h" + "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu" + "csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h" + "csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu" + "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h" + "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu" + "csrc/moe/marlin_moe_ops.cu") + + set_gencode_flags_for_srcs( + SRCS "${MARLIN_MOE_SRC}" + CUDA_ARCHS "${MARLIN_MOE_ARCHS}") + + list(APPEND VLLM_MOE_EXT_SRC "${MARLIN_MOE_SRC}") + message(STATUS "Building Marlin MOE kernels for archs: ${MARLIN_MOE_ARCHS}") + else() + message(STATUS "Not building Marlin MOE kernels as no compatible archs found" + "in CUDA target architectures") + endif() +endif() + +message(STATUS "Enabling moe extension.") define_gpu_extension_target( _moe_C DESTINATION vllm @@ -264,13 +475,102 @@ define_gpu_extension_target( USE_SABI 3 WITH_SOABI) +#[[ +if(VLLM_GPU_LANG STREQUAL "HIP") + # + # _rocm_C extension + # + set(VLLM_ROCM_EXT_SRC + "csrc/rocm/torch_bindings.cpp" + "csrc/rocm/attention.cu") + + define_gpu_extension_target( + _rocm_C + DESTINATION vllm + LANGUAGE ${VLLM_GPU_LANG} + SOURCES ${VLLM_ROCM_EXT_SRC} + COMPILE_FLAGS ${VLLM_GPU_FLAGS} + ARCHITECTURES ${VLLM_GPU_ARCHES} + USE_SABI 3 + WITH_SOABI) +endif() +]] + +# vllm-flash-attn currently only supported on CUDA +if (NOT VLLM_TARGET_DEVICE STREQUAL "cuda") + return() +endif () +# vLLM flash attention requires VLLM_GPU_ARCHES to contain the set of target +# arches in the CMake syntax (75-real, 89-virtual, etc), since we clear the +# arches in the CUDA case (and instead set the gencodes on a per file basis) +# we need to manually set VLLM_GPU_ARCHES here. +if(VLLM_GPU_LANG STREQUAL "CUDA") + foreach(_ARCH ${CUDA_ARCHS}) + string(REPLACE "." "" _ARCH "${_ARCH}") + list(APPEND VLLM_GPU_ARCHES "${_ARCH}-real") + endforeach() +endif() -if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP") - message(STATUS "Enabling C extension.") - add_dependencies(default _C) +# +# Build vLLM flash attention from source +# +# IMPORTANT: This has to be the last thing we do, because vllm-flash-attn uses the same macros/functions as vLLM. +# Because functions all belong to the global scope, vllm-flash-attn's functions overwrite vLLMs. +# They should be identical but if they aren't, this is a massive footgun. +# +# The vllm-flash-attn install rules are nested under vllm to make sure the library gets installed in the correct place. +# To only install vllm-flash-attn, use --component vllm_flash_attn_c. +# If no component is specified, vllm-flash-attn is still installed. - message(STATUS "Enabling moe extension.") - add_dependencies(default _moe_C) +# If VLLM_FLASH_ATTN_SRC_DIR is set, vllm-flash-attn is installed from that directory instead of downloading. +# This is to enable local development of vllm-flash-attn within vLLM. +# It can be set as an environment variable or passed as a cmake argument. +# The environment variable takes precedence. +if (DEFINED ENV{VLLM_FLASH_ATTN_SRC_DIR}) + set(VLLM_FLASH_ATTN_SRC_DIR $ENV{VLLM_FLASH_ATTN_SRC_DIR}) +endif() +if(VLLM_FLASH_ATTN_SRC_DIR) + FetchContent_Declare(vllm-flash-attn SOURCE_DIR ${VLLM_FLASH_ATTN_SRC_DIR}) +#[[ +else() + FetchContent_Declare( + vllm-flash-attn + GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git + GIT_TAG 013f0c4fc47e6574060879d9734c1df8c5c273bd + GIT_PROGRESS TRUE + ) +]] endif() + +# Set the parent build flag so that the vllm-flash-attn library does not redo compile flag and arch initialization. +set(VLLM_PARENT_BUILD ON) + +#[[ +# Ensure the vllm/vllm_flash_attn directory exists before installation +install(CODE "file(MAKE_DIRECTORY \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn\")" COMPONENT vllm_flash_attn_c) + +# Make sure vllm-flash-attn install rules are nested under vllm/ +install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY FALSE)" COMPONENT vllm_flash_attn_c) +install(CODE "set(OLD_CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}\")" COMPONENT vllm_flash_attn_c) +install(CODE "set(CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}/vllm/\")" COMPONENT vllm_flash_attn_c) + +# Fetch the vllm-flash-attn library +FetchContent_MakeAvailable(vllm-flash-attn) +message(STATUS "vllm-flash-attn is available at ${vllm-flash-attn_SOURCE_DIR}") + +# Restore the install prefix +install(CODE "set(CMAKE_INSTALL_PREFIX \"\${OLD_CMAKE_INSTALL_PREFIX}\")" COMPONENT vllm_flash_attn_c) +install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" COMPONENT vllm_flash_attn_c) + +# Copy over the vllm-flash-attn python files +install( + DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/ + DESTINATION vllm/vllm_flash_attn + COMPONENT vllm_flash_attn_c + FILES_MATCHING PATTERN "*.py" +) + +# Nothing after vllm-flash-attn, see comment about macros above +]] \ No newline at end of file diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 0000000000000000000000000000000000000000..f801b5f8f55133dd75c41d8c8494a5ce774ffb98 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,128 @@ + +# vLLM Code of Conduct + +## Our Pledge + +We as members, contributors, and leaders pledge to make participation in our +community a harassment-free experience for everyone, regardless of age, body +size, visible or invisible disability, ethnicity, sex characteristics, gender +identity and expression, level of experience, education, socioeconomic status, +nationality, personal appearance, race, caste, color, religion, or sexual +identity and orientation. + +We pledge to act and interact in ways that contribute to an open, welcoming, +diverse, inclusive, and healthy community. + +## Our Standards + +Examples of behavior that contributes to a positive environment for our +community include: + +* Demonstrating empathy and kindness toward other people +* Being respectful of differing opinions, viewpoints, and experiences +* Giving and gracefully accepting constructive feedback +* Accepting responsibility and apologizing to those affected by our mistakes, + and learning from the experience +* Focusing on what is best not just for us as individuals, but for the overall + community + +Examples of unacceptable behavior include: + +* The use of sexualized language or imagery, and sexual attention or advances of + any kind +* Trolling, insulting or derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or email address, + without their explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Enforcement Responsibilities + +Community leaders are responsible for clarifying and enforcing our standards of +acceptable behavior and will take appropriate and fair corrective action in +response to any behavior that they deem inappropriate, threatening, offensive, +or harmful. + +Community leaders have the right and responsibility to remove, edit, or reject +comments, commits, code, wiki edits, issues, and other contributions that are +not aligned to this Code of Conduct, and will communicate reasons for moderation +decisions when appropriate. + +## Scope + +This Code of Conduct applies within all community spaces, and also applies when +an individual is officially representing the community in public spaces. +Examples of representing our community include using an official email address, +posting via an official social media account, or acting as an appointed +representative at an online or offline/IRL event. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported to the community leaders responsible for enforcement in the #code-of-conduct +channel in the [vLLM Discord](https://discord.com/invite/jz7wjKhh6g). +All complaints will be reviewed and investigated promptly and fairly. + +All community leaders are obligated to respect the privacy and security of the +reporter of any incident. + +## Enforcement Guidelines + +Community leaders will follow these Community Impact Guidelines in determining +the consequences for any action they deem in violation of this Code of Conduct: + +### 1. Correction + +**Community Impact**: Use of inappropriate language or other behavior deemed +unprofessional or unwelcome in the community. + +**Consequence**: A private, written warning from community leaders, providing +clarity around the nature of the violation and an explanation of why the +behavior was inappropriate. A public apology may be requested. + +### 2. Warning + +**Community Impact**: A violation through a single incident or series of +actions. + +**Consequence**: A warning with consequences for continued behavior. No +interaction with the people involved, including unsolicited interaction with +those enforcing the Code of Conduct, for a specified period of time. This +includes avoiding interactions in community spaces as well as external channels +like social media. Violating these terms may lead to a temporary or permanent +ban. + +### 3. Temporary Ban + +**Community Impact**: A serious violation of community standards, including +sustained inappropriate behavior. + +**Consequence**: A temporary ban from any sort of interaction or public +communication with the community for a specified period of time. No public or +private interaction with the people involved, including unsolicited interaction +with those enforcing the Code of Conduct, is allowed during this period. +Violating these terms may lead to a permanent ban. + +### 4. Permanent Ban + +**Community Impact**: Demonstrating a pattern of violation of community +standards, including sustained inappropriate behavior, harassment of an +individual, or aggression toward or disparagement of classes of individuals. + +**Consequence**: A permanent ban from any sort of public interaction within the +community. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org/), +version 2.1, available at +[v2.1](https://www.contributor-covenant.org/version/2/1/code_of_conduct.html). + +Community Impact Guidelines were inspired by +[Mozilla's code of conduct enforcement ladder](https://github.com/mozilla/inclusion). + +For answers to common questions about this code of conduct, see the +[Contributor Covenant FAQ](https://www.contributor-covenant.org/faq). Translations are available at +[Contributor Covenant translations](https://www.contributor-covenant.org/translations). + diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 81a8db2b268b0acd03d676da41bee5be52dc245a..5f79356bd32f7a4be762f85bbc928183d7ac8679 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,30 +1,23 @@ # Contributing to vLLM -Thank you for your interest in contributing to vLLM! -Our community is open to everyone and welcomes all kinds of contributions, no matter how small or large. -There are several ways you can contribute to the project: +Thank you for your interest in contributing to vLLM! Our community is open to everyone and welcomes all kinds of contributions, no matter how small or large. There are several ways you can contribute to the project: - Identify and report any issues or bugs. -- Request or add a new model. +- Request or add support for a new model. - Suggest or implement new features. +- Improve documentation or contribute a how-to guide. -However, remember that contributions aren't just about code. -We believe in the power of community support; thus, answering queries, assisting others, and enhancing the documentation are highly regarded and beneficial contributions. +We also believe in the power of community support; thus, answering queries, offering PR reviews, and assisting others are also highly regarded and beneficial contributions. -Finally, one of the most impactful ways to support us is by raising awareness about vLLM. -Talk about it in your blog posts, highlighting how it's driving your incredible projects. -Express your support on Twitter if vLLM aids you, or simply offer your appreciation by starring our repository. +Finally, one of the most impactful ways to support us is by raising awareness about vLLM. Talk about it in your blog posts and highlight how it's driving your incredible projects. Express your support on social media if you're using vLLM, or simply offer your appreciation by starring our repository! -## Setup for development +## Developing -### Build from source +Depending on the kind of development you'd like to do (e.g. Python, CUDA), you can choose to build vLLM with or without compilation. Check out the [building from source](https://docs.vllm.ai/en/latest/getting_started/installation.html#build-from-source) documentation for details. -```bash -pip install -e . # This may take several minutes. -``` -### Testing +## Testing ```bash pip install -r requirements-dev.txt @@ -36,15 +29,16 @@ mypy # Unit tests pytest tests/ ``` -**Note:** Currently, the repository does not pass the mypy tests. +**Note:** Currently, the repository does not pass the ``mypy`` tests. +## Contribution Guidelines -## Contributing Guidelines +### Issues -### Issue Reporting +If you encounter a bug or have a feature request, please [search existing issues](https://github.com/vllm-project/vllm/issues?q=is%3Aissue) first to see if it has already been reported. If not, please [file a new issue](https://github.com/vllm-project/vllm/issues/new/choose), providing as much relevant information as possible. -If you encounter a bug or have a feature request, please check our issues page first to see if someone else has already reported it. -If not, please file a new issue, providing as much relevant information as possible. +> [!IMPORTANT] +> If you discover a security vulnerability, please follow the instructions [here](/SECURITY.md#reporting-a-vulnerability). ### Pull Requests & Code Reviews @@ -53,4 +47,4 @@ Please check the PR checklist in the [PR template](.github/PULL_REQUEST_TEMPLATE ### Thank You Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. -Your contributions make vLLM a great tool for everyone! +All of your contributions help make vLLM a great tool and community for everyone! diff --git a/Dockerfile b/Dockerfile index 49aaea2949ac6f838de003b0a02db272fd35d1a4..0a562253c537b6a07380a9e5c3215000581e5f53 100644 --- a/Dockerfile +++ b/Dockerfile @@ -9,28 +9,31 @@ ARG CUDA_VERSION=12.4.1 #################### BASE BUILD IMAGE #################### # prepare basic build environment FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04 AS base - ARG CUDA_VERSION=12.4.1 -ARG PYTHON_VERSION=3.10 - +ARG PYTHON_VERSION=3.12 ENV DEBIAN_FRONTEND=noninteractive +# Install Python and other dependencies RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \ && echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \ && apt-get update -y \ - && apt-get install -y ccache software-properties-common \ + && apt-get install -y ccache software-properties-common git curl sudo \ && add-apt-repository ppa:deadsnakes/ppa \ && apt-get update -y \ && apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \ - && if [ "${PYTHON_VERSION}" != "3" ]; then update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1; fi \ - && python3 --version - -RUN apt-get update -y \ - && apt-get install -y git curl sudo - -# Install pip s.t. it will be compatible with our PYTHON_VERSION -RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} -RUN python3 -m pip --version + && update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \ + && update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \ + && ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \ + && curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} \ + && python3 --version && python3 -m pip --version + +# Upgrade to GCC 10 to avoid https://gcc.gnu.org/bugzilla/show_bug.cgi?id=92519 +# as it was causing spam when compiling the CUTLASS kernels +RUN apt-get install -y gcc-10 g++-10 +RUN update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-10 110 --slave /usr/bin/g++ g++ /usr/bin/g++-10 +RUN <> /etc/environment + +# Install Python and other dependencies RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \ && echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \ && apt-get update -y \ - && apt-get install -y ccache software-properties-common \ + && apt-get install -y ccache software-properties-common git curl sudo vim python3-pip \ + && apt-get install -y ffmpeg libsm6 libxext6 libgl1 \ && add-apt-repository ppa:deadsnakes/ppa \ && apt-get update -y \ - && apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \ - && if [ "${PYTHON_VERSION}" != "3" ]; then update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1; fi \ - && python3 --version - -RUN apt-get update -y \ - && apt-get install -y python3-pip git vim curl libibverbs-dev - -# Install pip s.t. it will be compatible with our PYTHON_VERSION -RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} -RUN python3 -m pip --version + && apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv libibverbs-dev \ + && update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \ + && update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \ + && ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \ + && curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} \ + && python3 --version && python3 -m pip --version # Workaround for https://github.com/openai/triton/issues/2507 and # https://github.com/pytorch/pytorch/issues/107960 -- hopefully @@ -189,12 +173,10 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist --mount=type=cache,target=/root/.cache/pip \ python3 -m pip install dist/*.whl --verbose -RUN --mount=type=bind,from=mamba-builder,src=/usr/src/mamba,target=/usr/src/mamba \ - --mount=type=cache,target=/root/.cache/pip \ - python3 -m pip install /usr/src/mamba/*.whl --no-cache-dir - RUN --mount=type=cache,target=/root/.cache/pip \ - python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.4-cp310-cp310-linux_x86_64.whl + . /etc/environment && \ + python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.6/flashinfer-0.1.6+cu121torch2.4-cp${PYTHON_VERSION_STR}-cp${PYTHON_VERSION_STR}-linux_x86_64.whl +COPY examples examples #################### vLLM installation IMAGE #################### @@ -224,7 +206,7 @@ FROM vllm-base AS vllm-openai # install additional dependencies for openai api server RUN --mount=type=cache,target=/root/.cache/pip \ - pip install accelerate hf_transfer 'modelscope!=1.15.0' + pip install accelerate hf_transfer 'modelscope!=1.15.0' bitsandbytes>=0.44.0 timm==0.9.10 ENV VLLM_USAGE_SOURCE production-docker-image diff --git a/Dockerfile.cpu b/Dockerfile.cpu index 78730f39721cb9c029737a50f500429d1f6c4ac2..f1a21d6bd13fca364e6f1ea122cc907b994840c6 100644 --- a/Dockerfile.cpu +++ b/Dockerfile.cpu @@ -2,37 +2,61 @@ FROM ubuntu:22.04 AS cpu-test-1 -RUN apt-get update -y \ - && apt-get install -y curl git wget vim numactl gcc-12 g++-12 python3 python3-pip libtcmalloc-minimal4 libnuma-dev \ +ENV CCACHE_DIR=/root/.cache/ccache + +ENV CMAKE_CXX_COMPILER_LAUNCHER=ccache + +RUN --mount=type=cache,target=/var/cache/apt \ + apt-get update -y \ + && apt-get install -y curl ccache git wget vim numactl gcc-12 g++-12 python3 python3-pip libtcmalloc-minimal4 libnuma-dev \ + && apt-get install -y ffmpeg libsm6 libxext6 libgl1 \ && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12 # https://intel.github.io/intel-extension-for-pytorch/cpu/latest/tutorials/performance_tuning/tuning_guide.html # intel-openmp provides additional performance improvement vs. openmp # tcmalloc provides better memory allocation efficiency, e.g, holding memory in caches to speed up access of commonly-used objects. -RUN pip install intel-openmp +RUN --mount=type=cache,target=/root/.cache/pip \ + pip install intel-openmp -ENV LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:/usr/local/lib/libiomp5.so:$LD_PRELOAD" +ENV LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:/usr/local/lib/libiomp5.so" RUN echo 'ulimit -c 0' >> ~/.bashrc -RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/cpu/intel_extension_for_pytorch-2.4.0%2Bgitfbaa4bc-cp310-cp310-linux_x86_64.whl +RUN pip install intel_extension_for_pytorch==2.4.0 -RUN pip install --upgrade pip \ - && pip install wheel packaging ninja "setuptools>=49.4.0" numpy +WORKDIR /workspace -FROM cpu-test-1 AS build +ARG PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" +ENV PIP_EXTRA_INDEX_URL=${PIP_EXTRA_INDEX_URL} +RUN --mount=type=cache,target=/root/.cache/pip \ + --mount=type=bind,src=requirements-build.txt,target=requirements-build.txt \ + pip install --upgrade pip && \ + pip install -r requirements-build.txt -COPY ./ /workspace/vllm +FROM cpu-test-1 AS build WORKDIR /workspace/vllm -RUN pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu +RUN --mount=type=cache,target=/root/.cache/pip \ + --mount=type=bind,src=requirements-common.txt,target=requirements-common.txt \ + --mount=type=bind,src=requirements-cpu.txt,target=requirements-cpu.txt \ + pip install -v -r requirements-cpu.txt + +COPY . . +ARG GIT_REPO_CHECK=0 +RUN --mount=type=bind,source=.git,target=.git \ + if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh ; fi # Support for building with non-AVX512 vLLM: docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" ... ARG VLLM_CPU_DISABLE_AVX512 ENV VLLM_CPU_DISABLE_AVX512=${VLLM_CPU_DISABLE_AVX512} -RUN VLLM_TARGET_DEVICE=cpu python3 setup.py install +RUN --mount=type=cache,target=/root/.cache/pip \ + --mount=type=cache,target=/root/.cache/ccache \ + --mount=type=bind,source=.git,target=.git \ + VLLM_TARGET_DEVICE=cpu python3 setup.py bdist_wheel && \ + pip install dist/*.whl && \ + rm -rf dist WORKDIR /workspace/ diff --git a/Dockerfile.neuron b/Dockerfile.neuron index 010f23a143010effc0c84285aa2815f68aa98f6b..3d9d8e7da487c532b61992e278b66f1f70f0c5ba 100644 --- a/Dockerfile.neuron +++ b/Dockerfile.neuron @@ -1,36 +1,41 @@ # default base image -ARG BASE_IMAGE="763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference-neuronx:2.1.1-neuronx-py310-sdk2.17.0-ubuntu20.04" +ARG BASE_IMAGE="public.ecr.aws/neuron/pytorch-inference-neuronx:2.1.2-neuronx-py310-sdk2.20.0-ubuntu20.04" FROM $BASE_IMAGE RUN echo "Base image is $BASE_IMAGE" # Install some basic utilities -RUN apt-get update && apt-get install python3 python3-pip -y +RUN apt-get update && \ + apt-get install -y \ + git \ + python3 \ + python3-pip \ + ffmpeg libsm6 libxext6 libgl1 ### Mount Point ### # When launching the container, mount the code directory to /app ARG APP_MOUNT=/app VOLUME [ ${APP_MOUNT} ] -WORKDIR ${APP_MOUNT} +WORKDIR ${APP_MOUNT}/vllm RUN python3 -m pip install --upgrade pip RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas RUN python3 -m pip install sentencepiece transformers==4.36.2 -U RUN python3 -m pip install transformers-neuronx --extra-index-url=https://pip.repos.neuron.amazonaws.com -U -RUN python3 -m pip install --pre neuronx-cc==2.12.* --extra-index-url=https://pip.repos.neuron.amazonaws.com -U +RUN python3 -m pip install --pre neuronx-cc==2.15.* --extra-index-url=https://pip.repos.neuron.amazonaws.com -U -COPY ./vllm /app/vllm/vllm -COPY ./setup.py /app/vllm/setup.py -COPY ./requirements-common.txt /app/vllm/requirements-common.txt -COPY ./requirements-neuron.txt /app/vllm/requirements-neuron.txt +COPY . . +ARG GIT_REPO_CHECK=0 +RUN --mount=type=bind,source=.git,target=.git \ + if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh ; fi -RUN cd /app/vllm \ - && python3 -m pip install -U -r requirements-neuron.txt +RUN python3 -m pip install -U \ + cmake>=3.26 ninja packaging setuptools-scm>=8 wheel jinja2 \ + -r requirements-neuron.txt ENV VLLM_TARGET_DEVICE neuron -RUN cd /app/vllm \ - && pip install -e . \ - && cd .. +RUN --mount=type=bind,source=.git,target=.git \ + pip install --no-build-isolation -v -e . \ CMD ["/bin/bash"] diff --git a/Dockerfile.openvino b/Dockerfile.openvino index c84dea419e58a8c961010e3974a3bee4fbaf5d79..c89864da91180e36f94630931bbf4d2b06b11f99 100644 --- a/Dockerfile.openvino +++ b/Dockerfile.openvino @@ -4,24 +4,20 @@ FROM ubuntu:22.04 AS dev RUN apt-get update -y && \ - apt-get install -y python3-pip git + apt-get install -y \ + git python3-pip \ + ffmpeg libsm6 libxext6 libgl1 WORKDIR /workspace -# copy requirements -COPY requirements-build.txt /workspace/vllm/ -COPY requirements-common.txt /workspace/vllm/ -COPY requirements-openvino.txt /workspace/vllm/ - -COPY vllm/ /workspace/vllm/vllm -COPY csrc/core /workspace/vllm/csrc/core -COPY cmake/utils.cmake /workspace/vllm/cmake/ -COPY CMakeLists.txt /workspace/vllm/ -COPY setup.py /workspace/vllm/ +COPY . . +ARG GIT_REPO_CHECK=0 +RUN --mount=type=bind,source=.git,target=.git \ + if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh ; fi # install build requirements RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" python3 -m pip install -r /workspace/vllm/requirements-build.txt # build vLLM with OpenVINO backend -RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu https://storage.openvinotoolkit.org/simple/wheels/pre-release" VLLM_TARGET_DEVICE="openvino" python3 -m pip install /workspace/vllm/ +RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" VLLM_TARGET_DEVICE="openvino" python3 -m pip install /workspace/vllm/ COPY examples/ /workspace/vllm/examples COPY benchmarks/ /workspace/vllm/benchmarks diff --git a/Dockerfile.ppc64le b/Dockerfile.ppc64le index d4e4c483cada8b0d6fa34998350796f33fea1511..a84e00fd5677f7a2c27f8caf9cd8fb0ea92fe7d6 100644 --- a/Dockerfile.ppc64le +++ b/Dockerfile.ppc64le @@ -2,21 +2,35 @@ FROM mambaorg/micromamba ARG MAMBA_DOCKERFILE_ACTIVATE=1 USER root -RUN apt-get update -y && apt-get install -y git wget vim numactl gcc-12 g++-12 protobuf-compiler libprotobuf-dev && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12 +ENV PATH="/usr/local/cargo/bin:$PATH:/opt/conda/bin/" + +RUN apt-get update -y && apt-get install -y git wget curl vim libnuma-dev libsndfile-dev libprotobuf-dev build-essential ffmpeg libsm6 libxext6 libgl1 # Some packages in requirements-cpu are installed here # IBM provides optimized packages for ppc64le processors in the open-ce project for mamba # Currently these may not be available for venv or pip directly -RUN micromamba install -y -n base -c https://ftp.osuosl.org/pub/open-ce/1.11.0-p10/ -c defaults python=3.10 pytorch-cpu=2.1.2 torchvision-cpu=0.16.2 && micromamba clean --all --yes +RUN micromamba install -y -n base -c https://ftp.osuosl.org/pub/open-ce/1.11.0-p10/ -c defaults python=3.10 torchvision-cpu=0.16.2 rust && micromamba clean --all --yes COPY ./ /workspace/vllm WORKDIR /workspace/vllm +ARG GIT_REPO_CHECK=0 +RUN --mount=type=bind,source=.git,target=.git \ + if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh; fi # These packages will be in rocketce eventually -RUN pip install -v -r requirements-cpu.txt --prefer-binary --extra-index-url https://repo.fury.io/mgiessing +RUN --mount=type=cache,target=/root/.cache/pip \ + pip install -v --prefer-binary --extra-index-url https://repo.fury.io/mgiessing \ + cmake>=3.26 ninja packaging setuptools-scm>=8 wheel jinja2 \ + torch==2.3.1 \ + -r requirements-cpu.txt \ + xformers uvloop==0.20.0 + +RUN --mount=type=bind,source=.git,target=.git \ + VLLM_TARGET_DEVICE=cpu python3 setup.py install + +WORKDIR /workspace/ -RUN VLLM_TARGET_DEVICE=cpu python3 setup.py install +RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks -WORKDIR /vllm-workspace -ENTRYPOINT ["/opt/conda/bin/python3", "-m", "vllm.entrypoints.openai.api_server"] +ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"] diff --git a/Dockerfile.rocm b/Dockerfile.rocm index 33423fde4ff9618d5b8d2dc03d6c330bc6e41df7..d35889f053e27f1a2952a97297fec3b7ff47f59b 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -1,5 +1,5 @@ -# Default ROCm 6.1 base image -ARG BASE_IMAGE="rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging" +# Default ROCm 6.2 base image +ARG BASE_IMAGE="rocm/pytorch:rocm6.2_ubuntu20.04_py3.9_pytorch_release_2.3.0" # Default ROCm ARCHes to build vLLM for. ARG PYTORCH_ROCM_ARCH="gfx908;gfx90a;gfx942;gfx1100" @@ -7,18 +7,12 @@ ARG PYTORCH_ROCM_ARCH="gfx908;gfx90a;gfx942;gfx1100" # Whether to install CK-based flash-attention # If 0, will not install flash-attention ARG BUILD_FA="1" -# If `TRY_FA_WHEEL=1`, we will try installing flash-attention from `FA_WHEEL_URL` -# If this succeeds, we use the downloaded wheel and skip building flash-attention. -# Otherwise, ROCm flash-attention from `FA_BRANCH` will be built for the -# architectures specified in `FA_GFX_ARCHS` -ARG TRY_FA_WHEEL="1" -ARG FA_WHEEL_URL="https://github.com/ROCm/flash-attention/releases/download/v2.5.9post1-cktile-vllm/flash_attn-2.5.9.post1-cp39-cp39-linux_x86_64.whl" ARG FA_GFX_ARCHS="gfx90a;gfx942" -ARG FA_BRANCH="23a2b1c2" +ARG FA_BRANCH="3cea2fb" # Whether to build triton on rocm ARG BUILD_TRITON="1" -ARG TRITON_BRANCH="e0fc12c" +ARG TRITON_BRANCH="e192dba" ### Base image build stage FROM $BASE_IMAGE AS base @@ -50,14 +44,17 @@ RUN python3 -m pip install --upgrade pip # Remove sccache so it doesn't interfere with ccache # TODO: implement sccache support across components RUN apt-get purge -y sccache; python3 -m pip uninstall -y sccache; rm -f "$(which sccache)" -# Install torch == 2.5.0 on ROCm -RUN case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \ - *"rocm-6.1"*) \ + +# Install torch == 2.6.0 on ROCm +RUN --mount=type=cache,target=/root/.cache/pip \ + case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \ + *"rocm-6.2"*) \ python3 -m pip uninstall -y torch torchvision \ - && python3 -m pip install --no-cache-dir --pre \ - torch==2.5.0.dev20240726 \ - torchvision==0.20.0.dev20240726 \ - --index-url https://download.pytorch.org/whl/nightly/rocm6.1;; \ + && python3 -m pip install --pre \ + torch==2.6.0.dev20240918 \ + setuptools-scm>=8 \ + torchvision==0.20.0.dev20240918 \ + --extra-index-url https://download.pytorch.org/whl/nightly/rocm6.2;; \ *) ;; esac ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer @@ -79,25 +76,18 @@ RUN cd /opt/rocm/share/amd_smi \ ### Flash-Attention wheel build stage FROM base AS build_fa ARG BUILD_FA -ARG TRY_FA_WHEEL -ARG FA_WHEEL_URL ARG FA_GFX_ARCHS ARG FA_BRANCH # Build ROCm flash-attention wheel if `BUILD_FA = 1` RUN --mount=type=cache,target=${CCACHE_DIR} \ if [ "$BUILD_FA" = "1" ]; then \ - if [ "${TRY_FA_WHEEL}" = "1" ] && python3 -m pip install "${FA_WHEEL_URL}"; then \ - # If a suitable wheel exists, we download it instead of building FA - mkdir -p /install && wget -N "${FA_WHEEL_URL}" -P /install; \ - else \ - mkdir -p libs \ - && cd libs \ - && git clone https://github.com/ROCm/flash-attention.git \ - && cd flash-attention \ - && git checkout "${FA_BRANCH}" \ - && git submodule update --init \ - && GPU_ARCHS="${FA_GFX_ARCHS}" python3 setup.py bdist_wheel --dist-dir=/install; \ - fi; \ + mkdir -p libs \ + && cd libs \ + && git clone https://github.com/ROCm/flash-attention.git \ + && cd flash-attention \ + && git checkout "${FA_BRANCH}" \ + && git submodule update --init \ + && GPU_ARCHS="${FA_GFX_ARCHS}" python3 setup.py bdist_wheel --dist-dir=/install; \ # Create an empty directory otherwise as later build stages expect one else mkdir -p /install; \ fi @@ -112,6 +102,7 @@ RUN --mount=type=cache,target=${CCACHE_DIR} \ if [ "$BUILD_TRITON" = "1" ]; then \ mkdir -p libs \ && cd libs \ + && python3 -m pip install ninja cmake wheel pybind11 \ && git clone https://github.com/OpenAI/triton.git \ && cd triton \ && git checkout "${TRITON_BRANCH}" \ @@ -126,10 +117,13 @@ RUN --mount=type=cache,target=${CCACHE_DIR} \ FROM base AS final # Import the vLLM development directory from the build context COPY . . +ARG GIT_REPO_CHECK=0 +RUN --mount=type=bind,source=.git,target=.git \ + if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh ; fi # Package upgrades for useful functionality or to avoid dependency issues RUN --mount=type=cache,target=/root/.cache/pip \ - python3 -m pip install --upgrade numba scipy huggingface-hub[cli] + python3 -m pip install --upgrade numba scipy huggingface-hub[cli] pytest-shard # Workaround for ray >= 2.10.0 @@ -138,15 +132,9 @@ ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1 ENV TOKENIZERS_PARALLELISM=false RUN --mount=type=cache,target=${CCACHE_DIR} \ + --mount=type=bind,source=.git,target=.git \ --mount=type=cache,target=/root/.cache/pip \ python3 -m pip install -Ur requirements-rocm.txt \ - && case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \ - *"rocm-6.1"*) \ - # Bring in upgrades to HIP graph earlier than ROCm 6.2 for vLLM - wget -N https://github.com/ROCm/vllm/raw/fa78403/rocm_patch/libamdhip64.so.6 -P /opt/rocm/lib \ - # Prevent interference if torch bundles its own HIP runtime - && rm -f "$(python3 -c 'import torch; print(torch.__path__[0])')"/lib/libamdhip64.so* || true;; \ - *) ;; esac \ && python3 setup.py clean --all \ && python3 setup.py develop diff --git a/Dockerfile.tpu b/Dockerfile.tpu index adebb8ab5adca0d39a3e3dc02563a0405e906762..bdfab3f61910fd876e684f0806de86a496228c05 100644 --- a/Dockerfile.tpu +++ b/Dockerfile.tpu @@ -1,23 +1,32 @@ -ARG NIGHTLY_DATE="20240726" +ARG NIGHTLY_DATE="20240828" ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE" FROM $BASE_IMAGE -WORKDIR /workspace +WORKDIR /workspace/vllm -# Install aiohttp separately to avoid build errors. -RUN pip install aiohttp -# Install NumPy 1 instead of NumPy 2. -RUN pip install "numpy<2" -# Install the TPU and Pallas dependencies. -RUN pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html -RUN pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html +# Install some basic utilities +RUN apt-get update && apt-get install -y \ + git \ + ffmpeg libsm6 libxext6 libgl1 -# Fix FastAPI dependence -RUN pip install "starlette<0.38.0" +# Install the TPU and Pallas dependencies. +RUN --mount=type=cache,target=/root/.cache/pip \ + python3 -m pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html +RUN --mount=type=cache,target=/root/.cache/pip \ + python3 -m pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html # Build vLLM. -COPY . /workspace/vllm +COPY . . +ARG GIT_REPO_CHECK=0 +RUN --mount=type=bind,source=.git,target=.git \ + if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh; fi + ENV VLLM_TARGET_DEVICE="tpu" -RUN cd /workspace/vllm && python setup.py develop +RUN --mount=type=cache,target=/root/.cache/pip \ + --mount=type=bind,source=.git,target=.git \ + python3 -m pip install \ + cmake>=3.26 ninja packaging setuptools-scm>=8 wheel jinja2 \ + -r requirements-tpu.txt +RUN python3 setup.py develop CMD ["/bin/bash"] diff --git a/Dockerfile.xpu b/Dockerfile.xpu index f91baa11a37537d4423bd625a2da43bb52db5ad3..0ecb46df6256c110efb538a8b5c7725bb101d899 100644 --- a/Dockerfile.xpu +++ b/Dockerfile.xpu @@ -1,22 +1,58 @@ -FROM intel/oneapi-basekit:2024.1.0-devel-ubuntu20.04 +FROM intel/oneapi-basekit:2024.2.1-0-devel-ubuntu22.04 AS vllm-base RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB | gpg --dearmor | tee /usr/share/keyrings/intel-oneapi-archive-keyring.gpg > /dev/null && \ echo "deb [signed-by=/usr/share/keyrings/intel-oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main " | tee /etc/apt/sources.list.d/oneAPI.list && \ chmod 644 /usr/share/keyrings/intel-oneapi-archive-keyring.gpg && \ - rm /etc/apt/sources.list.d/intel-graphics.list && \ wget -O- https://repositories.intel.com/graphics/intel-graphics.key | gpg --dearmor | tee /usr/share/keyrings/intel-graphics.gpg > /dev/null && \ echo "deb [arch=amd64,i386 signed-by=/usr/share/keyrings/intel-graphics.gpg] https://repositories.intel.com/graphics/ubuntu jammy arc" | tee /etc/apt/sources.list.d/intel.gpu.jammy.list && \ chmod 644 /usr/share/keyrings/intel-graphics.gpg -RUN apt-get update -y \ -&& apt-get install -y curl libicu70 lsb-release git wget vim numactl python3 python3-pip - -COPY ./ /workspace/vllm +RUN apt-get update -y && \ + apt-get install -y --no-install-recommends --fix-missing \ + curl \ + ffmpeg \ + git \ + libsndfile1 \ + libsm6 \ + libxext6 \ + libgl1 \ + lsb-release \ + numactl \ + python3 \ + python3-dev \ + python3-pip \ + # vim \ + wget WORKDIR /workspace/vllm +COPY requirements-xpu.txt /workspace/vllm/requirements-xpu.txt +COPY requirements-common.txt /workspace/vllm/requirements-common.txt + +RUN --mount=type=cache,target=/root/.cache/pip \ + pip install --no-cache-dir \ + --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ \ + -r requirements-xpu.txt -RUN pip install -v -r requirements-xpu.txt +COPY . . +ARG GIT_REPO_CHECK +RUN --mount=type=bind,source=.git,target=.git \ + if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh; fi -RUN VLLM_TARGET_DEVICE=xpu python3 setup.py install +ENV VLLM_TARGET_DEVICE=xpu + +RUN --mount=type=cache,target=/root/.cache/pip \ + --mount=type=bind,source=.git,target=.git \ + python3 setup.py install CMD ["/bin/bash"] + +FROM vllm-base AS vllm-openai + +# install additional dependencies for openai api server +RUN --mount=type=cache,target=/root/.cache/pip \ + pip install accelerate hf_transfer 'modelscope!=1.15.0' + +ENV VLLM_USAGE_SOURCE production-docker-image \ + TRITON_XPU_PROFILE 1 + +ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"] diff --git a/MANIFEST.in b/MANIFEST.in index 5a41e5e7141849bd06a5cfb540dda42400d0265e..82be639ef4d739ce67ff982ecbe82de09aae1afd 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,5 +1,4 @@ include LICENSE -include requirements-adag.txt include requirements-common.txt include requirements-cuda.txt include requirements-rocm.txt diff --git a/README.md b/README.md index eb75d27f46928110709fcf79cb131d99e4da8b9c..e09f30cb353243d66d5ae0c1c4c7da50af2f019a 100644 --- a/README.md +++ b/README.md @@ -4,30 +4,23 @@ vLLM是一个快速且易于使用的LLM推理和服务库,使用PageAttention ## 暂不支持的官方功能 - **量化推理**:目前支持fp16的推理和gptq,awq-int4推理,mralin的权重量化、kv-cache fp8推理方案暂不支持 -- **模块支持**:目前不支持Sliding window attention、 moe kernel模块 +- **模块支持**:目前不支持Sliding window attention ## 支持模型结构列表 -| 结构 | 模型 | 模型并行 | FP16 | -| :----------: | :----------: | :------: | :--: | -| LlamaForCausalLM | LLaMA | Yes | Yes | -| LlamaForCausalLM | LLaMA-2 | Yes | Yes | -| LlamaForCausalLM | LLaMA-3 | Yes | Yes | -| LlamaForCausalLM | Codellama | Yes | Yes | -| QWenLMHeadModel | QWen | Yes | Yes | -| Qwen2ForCausalLM | QWen1.5 | Yes | Yes | -| Qwen2ForCausalLM | CodeQwen1.5 | Yes | Yes | -| Qwen2ForCausalLM | QWen2 | Yes | Yes | -| ChatGLMModel | chatglm2 | Yes | Yes | -| ChatGLMModel | chatglm3 | Yes | Yes | -| BaiChuanForCausalLM | Baichuan-7B | Yes | Yes | -| BaiChuanForCausalLM | Baichuan2-7B | Yes | Yes | -| InternLMForCausalLM | InternLM | Yes | Yes | -| InternLM2ForCausalLM | InternLM2 | Yes | Yes | -| LlamaForCausalLM | deepseek | Yes | Yes | -| DeepseekV2ForCausalLM | DeepSeek-V2 | Yes | Yes | -| LlamaForCausalLM | Yi | Yes | Yes | -| MixtralForCausalLM | Mixtral-8x7B | Yes | Yes | +| 结构 | 模型 | 模型并行 | FP16 | +| :------: | :------: | :------: | :------: | +| LlamaForCausalLM | LLaMA、LLaMA-2、LLaMA-3、Codellama、deepseek、Yi | Yes | Yes | +| QWenLMHeadModel | QWen、Qwen-VL | Yes | Yes | +| Qwen2ForCausalLM | QWen1.5、CodeQwen1.5、QWen2 | Yes | Yes | +| ChatGLMModel | chatglm2、chatglm3、chatglm4、glm-4v-9b | Yes | Yes | +| BaiChuanForCausalLM | Baichuan、Baichuan2 | Yes | Yes | +| BloomForCausalLM | BLOOM | Yes | Yes | +| InternLMForCausalLM | InternLM | Yes | Yes | +| InternLM2ForCausalLM | InternLM2 | Yes | Yes | +| DeepseekV2ForCausalLM | DeepSeek-V2 | Yes | Yes | +| MixtralForCausalLM | Mixtral-8x7B | Yes | Yes | +| TeleChat12BForCausalLM (#TelechatForCausalLM) | TeleChat-12B | Yes | Yes | ## 安装 @@ -36,15 +29,16 @@ vLLM支持 + Python 3.9. + Python 3.10. + Python 3.11. ++ Python 3.12. ### 使用源码编译方式安装 #### 编译环境准备 提供2种环境准备方式: -1. 基于光源pytorch2.1.0基础镜像环境:镜像下载地址:[https://sourcefind.cn/#/image/dcu/pytorch](https://sourcefind.cn/#/image/dcu/pytorch),根据pytorch2.1.0、python、dtk及系统下载对应的镜像版本。 +1. 基于光源pytorch2.3.0基础镜像环境:镜像下载地址:[https://sourcefind.cn/#/image/dcu/pytorch](https://sourcefind.cn/#/image/dcu/pytorch),根据pytorch2.1.0、python、dtk及系统下载对应的镜像版本。 -2. 基于现有python环境:安装pytorch2.1.0,pytorch whl包下载目录:[https://cancon.hpccube.com:65024/4/main/pytorch](https://cancon.hpccube.com:65024/4/main/pytorch),根据python、dtk版本,下载对应pytorch2.1.0的whl包。安装命令如下: +2. 基于现有python环境:安装pytorch2.3.0,pytorch whl包下载目录:[https://cancon.hpccube.com:65024/4/main/pytorch](https://cancon.hpccube.com:65024/4/main/pytorch),根据python、dtk版本,下载对应pytorch2.1.0的whl包。安装命令如下: ```shell pip install torch* (下载的torch的whl包) pip install setuptools wheel @@ -70,9 +64,9 @@ VLLM_INSTALL_PUNICA_KERNELS=1 python3 setup.py install ``` #### 运行基础环境准备 -1、使用上面基于光源pytorch2.1.0基础镜像环境 +1、使用上面基于光源pytorch2.3.0基础镜像环境 -2、根据pytorch2.1.0、python、dtk及系统下载对应的依赖包: +2、根据pytorch2.3.0、python、dtk及系统下载对应的依赖包: - triton:[https://cancon.hpccube.com:65024/4/main/triton](https://cancon.hpccube.com:65024/4/main/triton/) - xformers:[https://cancon.hpccube.com:65024/4/main/xformers](https://cancon.hpccube.com:65024/4/main/xformers) - flash_attn: [https://cancon.hpccube.com:65024/4/main/flash_attn](https://cancon.hpccube.com:65024/4/main/flash_attn) @@ -82,7 +76,7 @@ VLLM_INSTALL_PUNICA_KERNELS=1 python3 setup.py install + 若使用 pip install 下载安装过慢,可添加源:-i https://pypi.tuna.tsinghua.edu.cn/simple/ ## 验证 -- python -c "import vllm; print(vllm.\_\_version__)",版本号与官方版本同步,查询该软件的版本号,例如0.5.4; +- python -c "import vllm; print(vllm.\_\_version__)",版本号与官方版本同步,查询该软件的版本号,例如0.6.3.post1; ## Known Issue - 无 diff --git a/README_ORIGIN.md b/README_ORIGIN.md index b0e02c4bb729733bdde35f773d83cae73112f31b..53332719d588477edb194d30e68a03f9a0e8e651 100644 --- a/README_ORIGIN.md +++ b/README_ORIGIN.md @@ -10,13 +10,14 @@ Easy, fast, and cheap LLM serving for everyone

    -| Documentation | Blog | Paper | Discord | - +| Documentation | Blog | Paper | Discord | Twitter/X | Developer Slack |

    ---- *Latest News* 🔥 +- [2024/10] We have just created a developer slack ([slack.vllm.ai](https://slack.vllm.ai)) focusing on coordinating contributions and discussing features. Please feel free to join us there! +- [2024/10] Ray Summit 2024 held a special track for vLLM! Please find the opening talk slides from the vLLM team [here](https://docs.google.com/presentation/d/1B_KQxpHBTRa_mDF-tR6i8rWdOU5QoTZNcEg2MKZxEHM/edit?usp=sharing). Learn more from the [talks](https://raysummit.anyscale.com/flow/anyscale/raysummit2024/landing/page/sessioncatalog?tab.day=20241001&search.sessiontracks=1719251906298001uzJ2) from other vLLM contributors and users! +- [2024/09] We hosted [the sixth vLLM meetup](https://lu.ma/87q3nvnh) with NVIDIA! Please find the meetup slides [here](https://docs.google.com/presentation/d/1wrLGwytQfaOTd5wCGSPNhoaW3nq0E-9wqyP7ny93xRs/edit?usp=sharing). - [2024/07] We hosted [the fifth vLLM meetup](https://lu.ma/lp0gyjqr) with AWS! Please find the meetup slides [here](https://docs.google.com/presentation/d/1RgUD8aCfcHocghoP3zmXzck9vX3RCI9yfUAB2Bbcl4Y/edit?usp=sharing). - [2024/07] In partnership with Meta, vLLM officially supports Llama 3.1 with FP8 quantization and pipeline parallelism! Please check out our blog post [here](https://blog.vllm.ai/2024/07/23/llama31.html). - [2024/06] We hosted [the fourth vLLM meetup](https://lu.ma/agivllm) with Cloudflare and BentoML! Please find the meetup slides [here](https://docs.google.com/presentation/d/1iJ8o7V2bQEi0BFEljLTwc5G1S10_Rhv3beed5oB0NJ4/edit?usp=sharing). @@ -36,10 +37,12 @@ vLLM is fast with: - Efficient management of attention key and value memory with **PagedAttention** - Continuous batching of incoming requests - Fast model execution with CUDA/HIP graph -- Quantization: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [SqueezeLLM](https://arxiv.org/abs/2306.07629), FP8 KV Cache -- Optimized CUDA kernels +- Quantizations: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), INT4, INT8, and FP8. +- Optimized CUDA kernels, including integration with FlashAttention and FlashInfer. +- Speculative decoding +- Chunked prefill -**Performance benchmark**: We include a [performance benchmark](https://buildkite.com/vllm/performance-benchmark/builds/4068) that compares the performance of vllm against other LLM serving engines ([TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM), [text-generation-inference](https://github.com/huggingface/text-generation-inference) and [lmdeploy](https://github.com/InternLM/lmdeploy)). +**Performance benchmark**: We include a performance benchmark at the end of [our blog post](https://blog.vllm.ai/2024/09/05/perf-update.html). It compares the performance of vLLM against other LLM serving engines ([TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM), [SGLang](https://github.com/sgl-project/sglang) and [LMDeploy](https://github.com/InternLM/lmdeploy)). The implementation is under [nightly-benchmarks folder](.buildkite/nightly-benchmarks/) and you can [reproduce](https://github.com/vllm-project/vllm/issues/8176) this benchmark using our one-click runnable script. vLLM is flexible and easy to use with: @@ -48,20 +51,21 @@ vLLM is flexible and easy to use with: - Tensor parallelism and pipeline parallelism support for distributed inference - Streaming outputs - OpenAI-compatible API server -- Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs -- (Experimental) Prefix caching support -- (Experimental) Multi-lora support +- Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, TPU, and AWS Neuron. +- Prefix caching support +- Multi-lora support vLLM seamlessly supports most popular open-source models on HuggingFace, including: - Transformer-like LLMs (e.g., Llama) - Mixture-of-Expert LLMs (e.g., Mixtral) +- Embedding Models (e.g. E5-Mistral) - Multi-modal LLMs (e.g., LLaVA) Find the full list of supported models [here](https://docs.vllm.ai/en/latest/models/supported_models.html). ## Getting Started -Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source): +Install vLLM with `pip` or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source): ```bash pip install vllm @@ -99,6 +103,7 @@ vLLM is a community project. Our compute resources for development and testing a - Roblox - RunPod - Sequoia Capital +- Skywork AI - Trainy - UC Berkeley - UC San Diego @@ -116,4 +121,12 @@ If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs booktitle={Proceedings of the ACM SIGOPS 29th Symposium on Operating Systems Principles}, year={2023} } -``` \ No newline at end of file +``` + +## Contact Us + +* For technical questions and feature requests, please use Github issues or discussions. +* For discussing with fellow users, please use Discord. +* For coordinating contributions and development, please use Slack. +* For security disclosures, please use Github's security advisory feature. +* For collaborations and partnerships, please contact us at vllm-questions AT lists.berkeley.edu. \ No newline at end of file diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 0000000000000000000000000000000000000000..ad3f1f16ab5607293b64f1d2b92aeebfa63297e2 --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,11 @@ +# Security Policy + +## Reporting a Vulnerability + +If you believe you have found a security vulnerability in vLLM, we encourage you to let us know right away. We will investigate all legitimate reports and do our best to quickly fix the problem. + +Please report security issues privately using [the vulnerability submission form](https://github.com/vllm-project/vllm/security/advisories/new). + +--- + +Please see [PyTorch's Security Policy](https://github.com/pytorch/pytorch/blob/main/SECURITY.md) for more information and recommendations on how to securely interact with models. diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index fbab547d094fe6b7dd99af8eeddc6a64cbff2d64..4813fde27f0bc145809eae8a78ca4f1381c29bfb 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -23,7 +23,9 @@ class RequestFuncInput: output_len: int model: str best_of: int = 1 - use_beam_search: bool = False + logprobs: Optional[int] = None + multi_modal_content: Optional[dict] = None + ignore_eos: bool = False @dataclass @@ -46,13 +48,13 @@ async def async_request_tgi( assert api_url.endswith("generate_stream") async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: - assert not request_func_input.use_beam_search params = { "best_of": request_func_input.best_of, "max_new_tokens": request_func_input.output_len, "do_sample": True, "temperature": 0.01, # TGI does not accept 0.0 temperature. "top_p": 0.99, # TGI does not accept 1.0 top_p. + # TGI does not accept ignore_eos flag. } payload = { "inputs": request_func_input.prompt, @@ -117,7 +119,6 @@ async def async_request_trt_llm( assert api_url.endswith("generate_stream") async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: - assert not request_func_input.use_beam_search assert request_func_input.best_of == 1 payload = { "accumulate_tokens": True, @@ -127,6 +128,8 @@ async def async_request_trt_llm( "max_tokens": request_func_input.output_len, "stream": True, } + if request_func_input.ignore_eos: + payload["min_length"] = request_func_input.output_len output = RequestFuncOutput() output.prompt_len = request_func_input.prompt_len @@ -181,7 +184,6 @@ async def async_request_deepspeed_mii( ) -> RequestFuncOutput: async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: assert request_func_input.best_of == 1 - assert not request_func_input.use_beam_search payload = { "prompt": request_func_input.prompt, @@ -225,18 +227,19 @@ async def async_request_openai_completions( ) -> RequestFuncOutput: api_url = request_func_input.api_url assert api_url.endswith( - "completions" - ), "OpenAI Completions API URL must end with 'completions'." + ("completions", "profile") + ), "OpenAI Completions API URL must end with 'completions' or 'profile'." async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: - assert not request_func_input.use_beam_search payload = { "model": request_func_input.model, "prompt": request_func_input.prompt, "temperature": 0.0, "best_of": request_func_input.best_of, "max_tokens": request_func_input.output_len, + "logprobs": request_func_input.logprobs, "stream": True, + "ignore_eos": request_func_input.ignore_eos, } headers = { "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}" @@ -276,8 +279,9 @@ async def async_request_openai_completions( output.ttft = ttft # Decoding phase - output.itl.append(timestamp - - most_recent_timestamp) + else: + output.itl.append(timestamp - + most_recent_timestamp) most_recent_timestamp = timestamp generated_text += data["choices"][0]["text"] @@ -308,18 +312,21 @@ async def async_request_openai_chat_completions( ), "OpenAI Chat Completions API URL must end with 'chat/completions'." async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: - assert not request_func_input.use_beam_search + content = [{"type": "text", "text": request_func_input.prompt}] + if request_func_input.multi_modal_content: + content.append(request_func_input.multi_modal_content) payload = { "model": request_func_input.model, "messages": [ { "role": "user", - "content": request_func_input.prompt, + "content": content }, ], "temperature": 0.0, "max_tokens": request_func_input.output_len, "stream": True, + "ignore_eos": request_func_input.ignore_eos, } headers = { "Content-Type": "application/json", @@ -423,4 +430,5 @@ ASYNC_REQUEST_FUNCS = { "openai-chat": async_request_openai_chat_completions, "tensorrt-llm": async_request_trt_llm, "scalellm": async_request_openai_completions, + "sglang": async_request_openai_completions, } diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index 97afd301c8f24f911714001a10d73bc86e9062c8..ea1a7788f621d657c9286548490cc1c91fa81ce0 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -10,8 +10,8 @@ import torch from tqdm import tqdm from vllm import LLM, SamplingParams -from vllm.engine.arg_utils import EngineArgs -from vllm.inputs import PromptInputs +from vllm.engine.arg_utils import DEVICE_OPTIONS, EngineArgs +from vllm.inputs import PromptType from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.utils import FlexibleArgumentParser @@ -38,7 +38,6 @@ def main(args: argparse.Namespace): quantization_param_path=args.quantization_param_path, device=args.device, ray_workers_use_nsight=args.ray_workers_use_nsight, - use_v2_block_manager=args.use_v2_block_manager, enable_chunked_prefill=args.enable_chunked_prefill, download_dir=args.download_dir, block_size=args.block_size, @@ -51,9 +50,8 @@ def main(args: argparse.Namespace): sampling_params = SamplingParams( n=args.n, - temperature=0.0 if args.use_beam_search else 1.0, + temperature=1.0, top_p=1.0, - use_beam_search=args.use_beam_search, ignore_eos=True, max_tokens=args.output_len, ) @@ -61,7 +59,7 @@ def main(args: argparse.Namespace): dummy_prompt_token_ids = np.random.randint(10000, size=(args.batch_size, args.input_len)) - dummy_inputs: List[PromptInputs] = [{ + dummy_prompts: List[PromptType] = [{ "prompt_token_ids": batch } for batch in dummy_prompt_token_ids.tolist()] @@ -74,13 +72,13 @@ def main(args: argparse.Namespace): ], on_trace_ready=torch.profiler.tensorboard_trace_handler( str(profile_dir))) as p: - llm.generate(dummy_inputs, + llm.generate(dummy_prompts, sampling_params=sampling_params, use_tqdm=False) print(p.key_averages()) else: start_time = time.perf_counter() - llm.generate(dummy_inputs, + llm.generate(dummy_prompts, sampling_params=sampling_params, use_tqdm=False) end_time = time.perf_counter() @@ -205,13 +203,11 @@ if __name__ == '__main__': default=None, help=('path to save the pytorch profiler output. Can be visualized ' 'with ui.perfetto.dev or Tensorboard.')) - parser.add_argument( - "--device", - type=str, - default="auto", - choices=["auto", "cuda", "cpu", "openvino", "tpu", "xpu"], - help='device type for vLLM execution, supporting CUDA, OpenVINO and ' - 'CPU.') + parser.add_argument("--device", + type=str, + default="auto", + choices=DEVICE_OPTIONS, + help='device type for vLLM execution') parser.add_argument('--block-size', type=int, default=16, @@ -224,7 +220,6 @@ if __name__ == '__main__': parser.add_argument("--enable-prefix-caching", action='store_true', help="Enable automatic prefix caching") - parser.add_argument('--use-v2-block-manager', action='store_true') parser.add_argument( "--ray-workers-use-nsight", action='store_true', diff --git a/benchmarks/benchmark_prefix_caching.py b/benchmarks/benchmark_prefix_caching.py index 395107a5ec74761167eb1b23d771a38219556325..a354358e43aa382bb516b0797b9bc8a6c21d2208 100644 --- a/benchmarks/benchmark_prefix_caching.py +++ b/benchmarks/benchmark_prefix_caching.py @@ -1,8 +1,45 @@ +""" +Benchmark the efficiency of prefix caching. + +This script allows you to benchmark the performance of +a model with and without prefix caching using either fixed prompts +or prompts sampled from the ShareGPT dataset. + +Fixed example usage: + python benchmark_prefix_caching.py \ + --model meta-llama/Llama-2-7b-chat-hf \ + --enable-prefix-caching \ + --num-prompts 1 \ + --repeat-count 100 + +ShareGPT example usage: + # This command samples 20 prompts with input lengths + # between 128 and 256 tokens from the ShareGPT dataset, + # then replicates each prompt 5 times. + python benchmark_prefix_caching.py \ + --model meta-llama/Llama-2-7b-chat-hf \ + --dataset-path /path/to/ShareGPT_V3_unfiltered_cleaned_split.json \ + --enable-prefix-caching \ + --num-prompts 20 \ + --repeat-count 5 \ + --input-length-range 128:256 +""" + +import json +import random import time +from typing import List, Optional, Tuple + +from transformers import PreTrainedTokenizerBase from vllm import LLM, SamplingParams from vllm.utils import FlexibleArgumentParser +try: + from vllm.transformers_utils.tokenizer import get_tokenizer +except ImportError: + from backend_request_func import get_tokenizer + PROMPT = "You are a helpful assistant in recognizes the content of tables in markdown format. Here is a table as fellows. You need to answer my question about the table.\n# Table\n|Opening|Opening|Sl. No.|Film|Cast|Director|Music Director|Notes|\n|----|----|----|----|----|----|----|----|\n|J A N|9|1|Agni Pushpam|Jayabharathi, Kamalahasan|Jeassy|M. K. Arjunan||\n|J A N|16|2|Priyamvada|Mohan Sharma, Lakshmi, KPAC Lalitha|K. S. Sethumadhavan|V. Dakshinamoorthy||\n|J A N|23|3|Yakshagaanam|Madhu, Sheela|Sheela|M. S. Viswanathan||\n|J A N|30|4|Paalkkadal|Sheela, Sharada|T. K. Prasad|A. T. Ummer||\n|F E B|5|5|Amma|Madhu, Srividya|M. Krishnan Nair|M. K. Arjunan||\n|F E B|13|6|Appooppan|Thikkurissi Sukumaran Nair, Kamal Haasan|P. Bhaskaran|M. S. Baburaj||\n|F E B|20|7|Srishti|Chowalloor Krishnankutty, Ravi Alummoodu|K. T. Muhammad|M. S. Baburaj||\n|F E B|20|8|Vanadevatha|Prem Nazir, Madhubala|Yusufali Kechery|G. Devarajan||\n|F E B|27|9|Samasya|Madhu, Kamalahaasan|K. Thankappan|Shyam||\n|F E B|27|10|Yudhabhoomi|K. P. Ummer, Vidhubala|Crossbelt Mani|R. K. Shekhar||\n|M A R|5|11|Seemantha Puthran|Prem Nazir, Jayabharathi|A. B. Raj|M. K. Arjunan||\n|M A R|12|12|Swapnadanam|Rani Chandra, Dr. Mohandas|K. G. George|Bhaskar Chandavarkar||\n|M A R|19|13|Thulavarsham|Prem Nazir, sreedevi, Sudheer|N. Sankaran Nair|V. Dakshinamoorthy||\n|M A R|20|14|Aruthu|Kaviyoor Ponnamma, Kamalahasan|Ravi|G. Devarajan||\n|M A R|26|15|Swimming Pool|Kamal Haasan, M. G. Soman|J. Sasikumar|M. K. Arjunan||\n\n# Question\nWhat' s the content in the (1,1) cells\n" # noqa: E501 @@ -15,19 +52,97 @@ def test_prefix(llm=None, sampling_params=None, prompts=None): print(f"cost time {end_time - start_time}") +def sample_requests( + dataset_path: str, + num_requests: int, + tokenizer: PreTrainedTokenizerBase, + input_length_range: Tuple[int, int], + fixed_output_len: Optional[int], +) -> List[Tuple[str, int, int]]: + if fixed_output_len is not None and fixed_output_len < 4: + raise ValueError("output_len too small") + + # Load the dataset. + with open(dataset_path) as f: + dataset = json.load(f) + # Filter out the conversations with less than 2 turns. + dataset = [data for data in dataset if len(data["conversations"]) >= 2] + # Only keep the first two turns of each conversation. + dataset = [(data["conversations"][0]["value"], + data["conversations"][1]["value"]) for data in dataset] + + # Shuffle the dataset. + random.shuffle(dataset) + + min_len, max_len = input_length_range + + # Filter out sequences that are too long or too short + filtered_dataset: List[Tuple[str, int, int]] = [] + for i in range(len(dataset)): + if len(filtered_dataset) == num_requests: + break + + # Tokenize the prompts and completions. + prompt = dataset[i][0] + prompt_token_ids = tokenizer(prompt).input_ids + completion = dataset[i][1] + completion_token_ids = tokenizer(completion).input_ids + prompt_len = len(prompt_token_ids) + output_len = len(completion_token_ids + ) if fixed_output_len is None else fixed_output_len + if prompt_len < 4 or output_len < 4: + # Prune too short sequences. + continue + if min_len <= prompt_len <= max_len: + filtered_dataset.append((prompt, prompt_len, output_len)) + + return filtered_dataset + + +def repeat_and_sort_requests(requests: List[Tuple[str, int, int]], + repeat_count: int, + sort: bool = False) -> List[str]: + repeated_requests = requests * repeat_count + if sort: + repeated_requests.sort(key=lambda x: x[1]) + else: + random.shuffle(repeated_requests) + return [req[0] for req in repeated_requests] + + def main(args): + tokenizer = get_tokenizer(args.model, trust_remote_code=True) + input_length_range = tuple(map(int, args.input_length_range.split(':'))) + random.seed(args.seed) + if args.dataset_path is not None: + print(f"Start to sample {args.num_prompts} prompts" + "from {args.dataset_path}") + filtered_datasets = sample_requests( + dataset_path=args.dataset_path, + num_requests=args.num_prompts, + tokenizer=tokenizer, + input_length_range=input_length_range, + fixed_output_len=args.output_len, + ) + else: + prompt_len = len(tokenizer(PROMPT).input_ids) + filtered_datasets = [(PROMPT, prompt_len, args.output_len) + ] * args.num_prompts + llm = LLM(model=args.model, tokenizer_mode='auto', trust_remote_code=True, enforce_eager=True, - use_v2_block_manager=args.use_v2_block_manager, tensor_parallel_size=args.tensor_parallel_size, enable_prefix_caching=args.enable_prefix_caching) - num_prompts = 100 - prompts = [PROMPT] * num_prompts sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len) + print("Testing filtered datasets") + prompts = repeat_and_sort_requests(filtered_datasets, + repeat_count=args.repeat_count, + sort=args.sort) + print("------warm up------") test_prefix( llm=llm, @@ -45,18 +160,39 @@ def main(args): if __name__ == "__main__": parser = FlexibleArgumentParser( - description='Benchmark the performance with or without automatic ' - 'prefix caching.') + description= + 'Benchmark the performance with or without automatic prefix caching.') parser.add_argument('--model', type=str, default='baichuan-inc/Baichuan2-13B-Chat') + parser.add_argument("--dataset-path", + type=str, + default=None, + help="Path to the dataset.") parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1) parser.add_argument('--output-len', type=int, default=10) parser.add_argument('--enable-prefix-caching', action='store_true', help='enable prefix caching') - parser.add_argument('--use-v2-block-manager', + parser.add_argument('--num-prompts', + type=int, + default=1, + help="Number of the prompts sampled from dataset") + parser.add_argument('--repeat-count', + type=int, + default=100, + help='Number of times to repeat each prompt') + parser.add_argument('--sort', action='store_true', - help='Use BlockSpaceMangerV2') + help='Sort prompts by input length') + parser.add_argument('--input-length-range', + type=str, + default='128:256', + help='Range of input lengths for sampling prompts,' + 'specified as "min:max" (e.g., "128:256").') + parser.add_argument("--seed", + type=int, + default=0, + help='Random seed for reproducibility') args = parser.parse_args() main(args) diff --git a/benchmarks/benchmark_prioritization.py b/benchmarks/benchmark_prioritization.py new file mode 100644 index 0000000000000000000000000000000000000000..8843e3a927a0160d552b99a16384b5504739e80e --- /dev/null +++ b/benchmarks/benchmark_prioritization.py @@ -0,0 +1,293 @@ +"""Benchmark offline prioritization.""" +import argparse +import json +import random +import time +from typing import List, Optional, Tuple + +from transformers import AutoTokenizer, PreTrainedTokenizerBase + +from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS + + +def sample_requests( + dataset_path: str, + num_requests: int, + tokenizer: PreTrainedTokenizerBase, + fixed_output_len: Optional[int], +) -> List[Tuple[str, int, int]]: + if fixed_output_len is not None and fixed_output_len < 4: + raise ValueError("output_len too small") + + # Load the dataset. + with open(dataset_path) as f: + dataset = json.load(f) + # Filter out the conversations with less than 2 turns. + dataset = [data for data in dataset if len(data["conversations"]) >= 2] + # Only keep the first two turns of each conversation. + dataset = [(data["conversations"][0]["value"], + data["conversations"][1]["value"]) for data in dataset] + + # Shuffle the dataset. + random.shuffle(dataset) + + # Filter out sequences that are too long or too short + filtered_dataset: List[Tuple[str, int, int]] = [] + for i in range(len(dataset)): + if len(filtered_dataset) == num_requests: + break + + # Tokenize the prompts and completions. + prompt = dataset[i][0] + prompt_token_ids = tokenizer(prompt).input_ids + completion = dataset[i][1] + completion_token_ids = tokenizer(completion).input_ids + prompt_len = len(prompt_token_ids) + output_len = len(completion_token_ids + ) if fixed_output_len is None else fixed_output_len + if prompt_len < 4 or output_len < 4: + # Prune too short sequences. + continue + if prompt_len > 1024 or prompt_len + output_len > 2048: + # Prune too long sequences. + continue + + #Select a equi-probable random priority + priority = 0 if random.random() < 0.5 else 1 + + filtered_dataset.append((prompt, prompt_len, output_len, priority)) + + return filtered_dataset + + +def run_vllm( + requests: List[Tuple[str, int, int]], + model: str, + tokenizer: str, + quantization: Optional[str], + tensor_parallel_size: int, + seed: int, + n: int, + trust_remote_code: bool, + dtype: str, + max_model_len: Optional[int], + enforce_eager: bool, + kv_cache_dtype: str, + quantization_param_path: Optional[str], + device: str, + enable_prefix_caching: bool, + enable_chunked_prefill: bool, + max_num_batched_tokens: int, + gpu_memory_utilization: float = 0.9, + download_dir: Optional[str] = None, +) -> float: + from vllm import LLM, SamplingParams + llm = LLM( + model=model, + tokenizer=tokenizer, + quantization=quantization, + tensor_parallel_size=tensor_parallel_size, + seed=seed, + trust_remote_code=trust_remote_code, + dtype=dtype, + max_model_len=max_model_len, + gpu_memory_utilization=gpu_memory_utilization, + enforce_eager=enforce_eager, + kv_cache_dtype=kv_cache_dtype, + quantization_param_path=quantization_param_path, + device=device, + enable_prefix_caching=enable_prefix_caching, + download_dir=download_dir, + enable_chunked_prefill=enable_chunked_prefill, + max_num_batched_tokens=max_num_batched_tokens, + disable_log_stats=False, + ) + + # Add the requests to the engine. + prompts = [] + sampling_params = [] + priority = [] + for prompt, _, output_len, _priority in requests: + prompts.append(prompt) + priority.append(_priority) + sampling_params.append( + SamplingParams( + n=n, + temperature=1.0, + top_p=1.0, + ignore_eos=True, + max_tokens=output_len, + )) + + start = time.perf_counter() + llm.generate(prompts, sampling_params, priority=priority, use_tqdm=True) + end = time.perf_counter() + return end - start + + +def main(args: argparse.Namespace): + print(args) + random.seed(args.seed) + + # Sample the requests. + tokenizer = AutoTokenizer.from_pretrained( + args.tokenizer, trust_remote_code=args.trust_remote_code) + if args.dataset is None: + # Synthesize a prompt with the given input length. + prompt = "hi" * (args.input_len - 1) + requests = [(prompt, args.input_len, args.output_len) + for _ in range(args.num_prompts)] + else: + requests = sample_requests(args.dataset, args.num_prompts, tokenizer, + args.output_len) + + if args.backend == "vllm": + elapsed_time = run_vllm(requests, args.model, args.tokenizer, + args.quantization, args.tensor_parallel_size, + args.seed, args.n, args.trust_remote_code, + args.dtype, args.max_model_len, + args.enforce_eager, args.kv_cache_dtype, + args.quantization_param_path, args.device, + args.enable_prefix_caching, + args.enable_chunked_prefill, + args.max_num_batched_tokens, + args.gpu_memory_utilization, args.download_dir) + else: + raise ValueError(f"Unknown backend: {args.backend}") + total_num_tokens = sum(prompt_len + output_len + for _, prompt_len, output_len, priority in requests) + print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, " + f"{total_num_tokens / elapsed_time:.2f} tokens/s") + + # Output JSON results if specified + if args.output_json: + results = { + "elapsed_time": elapsed_time, + "num_requests": len(requests), + "total_num_tokens": total_num_tokens, + "requests_per_second": len(requests) / elapsed_time, + "tokens_per_second": total_num_tokens / elapsed_time, + } + with open(args.output_json, "w") as f: + json.dump(results, f, indent=4) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Benchmark the throughput.") + parser.add_argument("--backend", + type=str, + choices=["vllm", "hf", "mii"], + default="vllm") + parser.add_argument("--dataset", + type=str, + default=None, + help="Path to the dataset.") + parser.add_argument("--input-len", + type=int, + default=None, + help="Input prompt length for each request") + parser.add_argument("--output-len", + type=int, + default=None, + help="Output length for each request. Overrides the " + "output length from the dataset.") + parser.add_argument("--model", type=str, default="facebook/opt-125m") + parser.add_argument("--tokenizer", type=str, default=None) + parser.add_argument('--quantization', + '-q', + choices=[*QUANTIZATION_METHODS, None], + default=None) + parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1) + parser.add_argument("--n", + type=int, + default=1, + help="Number of generated sequences per prompt.") + parser.add_argument("--num-prompts", + type=int, + default=200, + help="Number of prompts to process.") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument('--trust-remote-code', + action='store_true', + help='trust remote code from huggingface') + parser.add_argument( + '--max-model-len', + type=int, + default=None, + help='Maximum length of a sequence (including prompt and output). ' + 'If None, will be derived from the model.') + parser.add_argument( + '--dtype', + type=str, + default='auto', + choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'], + help='data type for model weights and activations. ' + 'The "auto" option will use FP16 precision ' + 'for FP32 and FP16 models, and BF16 precision ' + 'for BF16 models.') + parser.add_argument('--gpu-memory-utilization', + type=float, + default=0.9, + help='the fraction of GPU memory to be used for ' + 'the model executor, which can range from 0 to 1.' + 'If unspecified, will use the default value of 0.9.') + parser.add_argument("--enforce-eager", + action="store_true", + help="enforce eager execution") + parser.add_argument( + '--kv-cache-dtype', + type=str, + choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'], + default="auto", + help='Data type for kv cache storage. If "auto", will use model ' + 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ' + 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)') + parser.add_argument( + '--quantization-param-path', + type=str, + default=None, + help='Path to the JSON file containing the KV cache scaling factors. ' + 'This should generally be supplied, when KV cache dtype is FP8. ' + 'Otherwise, KV cache scaling factors default to 1.0, which may cause ' + 'accuracy issues. FP8_E5M2 (without scaling) is only supported on ' + 'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is ' + 'instead supported for common inference criteria.') + parser.add_argument( + "--device", + type=str, + default="cuda", + choices=["cuda", "cpu"], + help='device type for vLLM execution, supporting CUDA and CPU.') + parser.add_argument( + "--enable-prefix-caching", + action='store_true', + help="enable automatic prefix caching for vLLM backend.") + parser.add_argument("--enable-chunked-prefill", + action='store_true', + help="enable chunked prefill for vLLM backend.") + parser.add_argument('--max-num-batched-tokens', + type=int, + default=None, + help='maximum number of batched tokens per ' + 'iteration') + parser.add_argument('--download-dir', + type=str, + default=None, + help='directory to download and load the weights, ' + 'default to the default cache dir of huggingface') + parser.add_argument( + '--output-json', + type=str, + default=None, + help='Path to save the throughput results in JSON format.') + + args = parser.parse_args() + if args.tokenizer is None: + args.tokenizer = args.model + if args.dataset is None: + assert args.input_len is not None + assert args.output_len is not None + else: + assert args.input_len is None + + main(args) diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index fc0dbf77f16b9aa98bf846e477842bb7c873341e..c1a396c81f666731f31f0a560a79c7110938edf8 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -1,4 +1,4 @@ -"""Benchmark online serving throughput. +r"""Benchmark online serving throughput. On the server side, run one of the following commands: vLLM OpenAI API server @@ -24,6 +24,8 @@ On the client side, run: """ import argparse import asyncio +import base64 +import io import json import os import random @@ -31,11 +33,13 @@ import time import warnings from dataclasses import dataclass from datetime import datetime -from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple +from typing import Any, AsyncGenerator, Collection, Dict, List, Optional, Tuple import numpy as np from backend_request_func import (ASYNC_REQUEST_FUNCS, RequestFuncInput, RequestFuncOutput) +from datasets import load_dataset +from PIL.Image import Image from tqdm.asyncio import tqdm from transformers import PreTrainedTokenizerBase @@ -56,20 +60,27 @@ class BenchmarkMetrics: total_input: int total_output: int request_throughput: float - input_throughput: float output_throughput: float + total_token_throughput: float mean_ttft_ms: float median_ttft_ms: float std_ttft_ms: float - p99_ttft_ms: float + percentiles_ttft_ms: List[Tuple[float, float]] mean_tpot_ms: float median_tpot_ms: float std_tpot_ms: float - p99_tpot_ms: float + percentiles_tpot_ms: List[Tuple[float, float]] mean_itl_ms: float median_itl_ms: float std_itl_ms: float - p99_itl_ms: float + percentiles_itl_ms: List[Tuple[float, float]] + # E2EL stands for end-to-end latency per request. + # It is the time taken on the client side from sending + # a request to receiving a complete response. + mean_e2el_ms: float + median_e2el_ms: float + std_e2el_ms: float + percentiles_e2el_ms: List[Tuple[float, float]] def sample_sharegpt_requests( @@ -77,11 +88,9 @@ def sample_sharegpt_requests( num_requests: int, tokenizer: PreTrainedTokenizerBase, fixed_output_len: Optional[int] = None, -) -> List[Tuple[str, int, int]]: - if fixed_output_len is not None and fixed_output_len < 4: - raise ValueError("output_len too small") +) -> List[Tuple[str, int, int, None]]: # Load the dataset. - with open(dataset_path) as f: + with open(dataset_path, encoding='utf-8') as f: dataset = json.load(f) # Filter out the conversations with less than 2 turns. dataset = [data for data in dataset if len(data["conversations"]) >= 2] @@ -106,13 +115,13 @@ def sample_sharegpt_requests( prompt_len = len(prompt_token_ids) output_len = len(completion_token_ids ) if fixed_output_len is None else fixed_output_len - if prompt_len < 4 or output_len < 4: + if prompt_len < 4 or (fixed_output_len is None and output_len < 4): # Prune too short sequences. continue if prompt_len > 1024 or prompt_len + output_len > 2048: # Prune too long sequences. continue - filtered_dataset.append((prompt, prompt_len, output_len)) + filtered_dataset.append((prompt, prompt_len, output_len, None)) return filtered_dataset @@ -124,13 +133,13 @@ def sample_sonnet_requests( output_len: int, prefix_len: int, tokenizer: PreTrainedTokenizerBase, -) -> List[Tuple[str, str, int, int]]: +) -> List[Tuple[str, str, int, int, None]]: assert ( input_len > prefix_len ), "'args.sonnet-input-len' must be greater than 'args.prefix-input-len'." # Load the dataset. - with open(dataset_path) as f: + with open(dataset_path, encoding='utf-8') as f: poem_lines = f.readlines() # Tokenize the poem lines. @@ -167,9 +176,9 @@ def sample_sonnet_requests( # Sample the rest of lines per request. sampled_requests: List[Tuple[str, int, int]] = [] for _ in range(num_requests): - sampled_lines = "".join( - prefix_lines + - random.sample(poem_lines, num_input_lines - num_prefix_lines)) + num_lines_needed = num_input_lines - num_prefix_lines + sampled_lines = "".join(prefix_lines + + random.choices(poem_lines, k=num_lines_needed)) prompt = f"{base_prompt}{sampled_lines}" message = [ @@ -182,14 +191,81 @@ def sample_sonnet_requests( message, add_generation_prompt=True, tokenize=False) prompt_len = len(tokenizer(prompt_formatted).input_ids) sampled_requests.append( - (prompt, prompt_formatted, prompt_len, output_len)) + (prompt, prompt_formatted, prompt_len, output_len, None)) + + return sampled_requests + + +def sample_hf_requests( + dataset_path: str, + dataset_subset: str, + dataset_split: str, + num_requests: int, + tokenizer: PreTrainedTokenizerBase, + fixed_output_len: Optional[int] = None, +) -> List[Tuple[str, str, int, Optional[Dict[str, Collection[str]]]]]: + dataset = load_dataset(dataset_path, + name=dataset_subset, + split=dataset_split, + streaming=True) + assert "conversations" in dataset.features, ( + "HF Dataset must have 'conversations' column.") + filtered_dataset = dataset.shuffle().filter( + lambda x: len(x["conversations"]) >= 2) + sampled_requests: List[Tuple[str, int, int, Dict[str, + Collection[str]]]] = [] + for data in filtered_dataset: + if len(sampled_requests) == num_requests: + break + + # Tokenize the prompts and completions. + prompt = data["conversations"][0]["value"] + prompt_token_ids = tokenizer(prompt).input_ids + completion = data["conversations"][1]["value"] + completion_token_ids = tokenizer(completion).input_ids + prompt_len = len(prompt_token_ids) + output_len = len(completion_token_ids + ) if fixed_output_len is None else fixed_output_len + if fixed_output_len is None and (prompt_len < 4 or output_len < 4): + # Prune too short sequences. + continue + if fixed_output_len is None and \ + (prompt_len > 1024 or prompt_len + output_len > 2048): + # Prune too long sequences. + continue + + if "image" in data and isinstance(data["image"], Image): + image: Image = data["image"] + image = image.convert("RGB") + image_data = io.BytesIO() + image.save(image_data, format='JPEG') + image_base64 = base64.b64encode( + image_data.getvalue()).decode("utf-8") + mm_content = { + "type": "image_url", + "image_url": { + "url": f"data:image/jpeg;base64,{image_base64}" + }, + } + else: + mm_content = None + + sampled_requests.append((prompt, prompt_len, output_len, mm_content)) return sampled_requests def sample_random_requests( - input_len: int, output_len: int, num_prompts: int, range_ratio: float, - tokenizer: PreTrainedTokenizerBase) -> List[Tuple[str, int, int]]: + prefix_len: int, + input_len: int, + output_len: int, + num_prompts: int, + range_ratio: float, + tokenizer: PreTrainedTokenizerBase, +) -> List[Tuple[str, int, int]]: + prefix_token_ids = np.random.randint(0, + tokenizer.vocab_size, + size=prefix_len).tolist() input_lens = np.random.randint( int(input_len * range_ratio), @@ -204,10 +280,12 @@ def sample_random_requests( offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts) input_requests = [] for i in range(num_prompts): - prompt = tokenizer.decode([(offsets[i] + i + j) % tokenizer.vocab_size + prompt = tokenizer.decode(prefix_token_ids + + [(offsets[i] + i + j) % tokenizer.vocab_size for j in range(input_lens[i])]) - input_requests.append( - (prompt, int(input_lens[i]), int(output_lens[i]))) + + input_requests.append((prompt, int(prefix_len + input_lens[i]), + int(output_lens[i]), None)) return input_requests @@ -235,6 +313,8 @@ def calculate_metrics( outputs: List[RequestFuncOutput], dur_s: float, tokenizer: PreTrainedTokenizerBase, + selected_percentile_metrics: List[str], + selected_percentiles: List[float], ) -> Tuple[BenchmarkMetrics, List[int]]: actual_output_lens: List[int] = [] total_input = 0 @@ -242,6 +322,7 @@ def calculate_metrics( itls: List[float] = [] tpots: List[float] = [] ttfts: List[float] = [] + e2els: List[float] = [] for i in range(len(outputs)): if outputs[i].success: # We use the tokenizer to count the number of output tokens for all @@ -258,6 +339,7 @@ def calculate_metrics( (outputs[i].latency - outputs[i].ttft) / (output_len - 1)) itls += outputs[i].itl ttfts.append(outputs[i].ttft) + e2els.append(outputs[i].latency) completed += 1 else: actual_output_lens.append(0) @@ -272,21 +354,29 @@ def calculate_metrics( total_input=total_input, total_output=sum(actual_output_lens), request_throughput=completed / dur_s, - input_throughput=total_input / dur_s, output_throughput=sum(actual_output_lens) / dur_s, + total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s, mean_ttft_ms=np.mean(ttfts or 0) * 1000, # ttfts is empty if streaming is not supported by backend - median_ttft_ms=np.median(ttfts or 0) * 1000, std_ttft_ms=np.std(ttfts or 0) * 1000, - p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000, + median_ttft_ms=np.median(ttfts or 0) * 1000, + percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000) + for p in selected_percentiles], mean_tpot_ms=np.mean(tpots or 0) * 1000, - median_tpot_ms=np.median(tpots or 0) * 1000, std_tpot_ms=np.std(tpots or 0) * 1000, - p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000, + median_tpot_ms=np.median(tpots or 0) * 1000, + percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000) + for p in selected_percentiles], mean_itl_ms=np.mean(itls or 0) * 1000, - median_itl_ms=np.median(itls or 0) * 1000, std_itl_ms=np.std(itls or 0) * 1000, - p99_itl_ms=np.percentile(itls or 0, 99) * 1000, + median_itl_ms=np.median(itls or 0) * 1000, + percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000) + for p in selected_percentiles], + mean_e2el_ms=np.median(e2els or 0) * 1000, + std_e2el_ms=np.std(e2els or 0) * 1000, + median_e2el_ms=np.mean(e2els or 0) * 1000, + percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000) + for p in selected_percentiles], ) return metrics, actual_output_lens @@ -295,13 +385,18 @@ def calculate_metrics( async def benchmark( backend: str, api_url: str, + base_url: str, model_id: str, tokenizer: PreTrainedTokenizerBase, input_requests: List[Tuple[str, int, int]], + logprobs: Optional[int], best_of: int, - use_beam_search: bool, request_rate: float, disable_tqdm: bool, + profile: bool, + selected_percentile_metrics: List[str], + selected_percentiles: List[str], + ignore_eos: bool, ): if backend in ASYNC_REQUEST_FUNCS: request_func = ASYNC_REQUEST_FUNCS[backend] @@ -309,15 +404,22 @@ async def benchmark( raise ValueError(f"Unknown backend: {backend}") print("Starting initial single prompt test run...") - test_prompt, test_prompt_len, test_output_len = input_requests[0] + test_prompt, test_prompt_len, test_output_len, test_mm_content = ( + input_requests[0]) + if backend != "openai-chat" and test_mm_content is not None: + # multi-modal benchmark is only available on OpenAI Chat backend. + raise ValueError( + "Multi-modal content is only supported on 'openai-chat' backend.") test_input = RequestFuncInput( model=model_id, prompt=test_prompt, api_url=api_url, prompt_len=test_prompt_len, output_len=test_output_len, + logprobs=logprobs, best_of=best_of, - use_beam_search=use_beam_search, + multi_modal_content=test_mm_content, + ignore_eos=ignore_eos, ) test_output = await request_func(request_func_input=test_input) if not test_output.success: @@ -326,6 +428,22 @@ async def benchmark( f"are correctly specified. Error: {test_output.error}") else: print("Initial test run completed. Starting main benchmark run...") + + if profile: + print("Starting profiler...") + profile_input = RequestFuncInput(model=model_id, + prompt=test_prompt, + api_url=base_url + "/start_profile", + prompt_len=test_prompt_len, + output_len=test_output_len, + logprobs=logprobs, + best_of=best_of, + multi_modal_content=test_mm_content, + ignore_eos=ignore_eos) + profile_output = await request_func(request_func_input=profile_input) + if profile_output.success: + print("Profiler started") + print(f"Traffic request rate: {request_rate}") pbar = None if disable_tqdm else tqdm(total=len(input_requests)) @@ -333,22 +451,37 @@ async def benchmark( benchmark_start_time = time.perf_counter() tasks: List[asyncio.Task] = [] async for request in get_request(input_requests, request_rate): - prompt, prompt_len, output_len = request - request_func_input = RequestFuncInput( - model=model_id, - prompt=prompt, - api_url=api_url, - prompt_len=prompt_len, - output_len=output_len, - best_of=best_of, - use_beam_search=use_beam_search, - ) + prompt, prompt_len, output_len, mm_content = request + request_func_input = RequestFuncInput(model=model_id, + prompt=prompt, + api_url=api_url, + prompt_len=prompt_len, + output_len=output_len, + logprobs=logprobs, + best_of=best_of, + multi_modal_content=mm_content, + ignore_eos=ignore_eos) tasks.append( asyncio.create_task( request_func(request_func_input=request_func_input, pbar=pbar))) outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks) + if profile: + print("Stopping profiler...") + profile_input = RequestFuncInput( + model=model_id, + prompt=test_prompt, + api_url=base_url + "/stop_profile", + prompt_len=test_prompt_len, + output_len=test_output_len, + logprobs=logprobs, + best_of=best_of, + ) + profile_output = await request_func(request_func_input=profile_input) + if profile_output.success: + print("Profiler stopped") + if pbar is not None: pbar.close() @@ -359,6 +492,8 @@ async def benchmark( outputs=outputs, dur_s=benchmark_duration, tokenizer=tokenizer, + selected_percentile_metrics=selected_percentile_metrics, + selected_percentiles=selected_percentiles, ) print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='=')) @@ -370,27 +505,10 @@ async def benchmark( metrics.total_output)) print("{:<40} {:<10.2f}".format("Request throughput (req/s):", metrics.request_throughput)) - print("{:<40} {:<10.2f}".format("Input token throughput (tok/s):", - metrics.input_throughput)) print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):", metrics.output_throughput)) - print("{s:{c}^{n}}".format(s='Time to First Token', n=50, c='-')) - print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms)) - print("{:<40} {:<10.2f}".format("Median TTFT (ms):", - metrics.median_ttft_ms)) - print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms)) - print("{s:{c}^{n}}".format(s='Time per Output Token (excl. 1st token)', - n=50, - c='-')) - print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms)) - print("{:<40} {:<10.2f}".format("Median TPOT (ms):", - metrics.median_tpot_ms)) - print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms)) - print("{s:{c}^{n}}".format(s='Inter-token Latency', n=50, c='-')) - print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms)) - print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms)) - print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms)) - print("=" * 50) + print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):", + metrics.total_token_throughput)) result = { "duration": benchmark_duration, @@ -398,20 +516,8 @@ async def benchmark( "total_input_tokens": metrics.total_input, "total_output_tokens": metrics.total_output, "request_throughput": metrics.request_throughput, - "input_throughput": metrics.input_throughput, "output_throughput": metrics.output_throughput, - "mean_ttft_ms": metrics.mean_ttft_ms, - "median_ttft_ms": metrics.median_ttft_ms, - "std_ttft_ms": metrics.std_ttft_ms, - "p99_ttft_ms": metrics.p99_ttft_ms, - "mean_tpot_ms": metrics.mean_tpot_ms, - "median_tpot_ms": metrics.median_tpot_ms, - "std_tpot_ms": metrics.std_tpot_ms, - "p99_tpot_ms": metrics.p99_tpot_ms, - "mean_itl_ms": metrics.mean_itl_ms, - "median_itl_ms": metrics.median_itl_ms, - "std_itl_ms": metrics.std_itl_ms, - "p99_itl_ms": metrics.p99_itl_ms, + "total_token_throughput": metrics.total_token_throughput, "input_lens": [output.prompt_len for output in outputs], "output_lens": actual_output_lens, "ttfts": [output.ttft for output in outputs], @@ -419,6 +525,47 @@ async def benchmark( "generated_texts": [output.generated_text for output in outputs], "errors": [output.error for output in outputs], } + + def process_one_metric( + # E.g., "ttft" + metric_attribute_name: str, + # E.g., "TTFT" + metric_name: str, + # E.g., "Time to First Token" + metric_header: str, + ): + # This function prints and adds statistics of the specified + # metric. + if metric_attribute_name not in selected_percentile_metrics: + return + print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-')) + print("{:<40} {:<10.2f}".format( + f"Mean {metric_name} (ms):", + getattr(metrics, f"mean_{metric_attribute_name}_ms"))) + print("{:<40} {:<10.2f}".format( + f"Median {metric_name} (ms):", + getattr(metrics, f"median_{metric_attribute_name}_ms"))) + result[f"mean_{metric_attribute_name}_ms"] = getattr( + metrics, f"mean_{metric_attribute_name}_ms") + result[f"median_{metric_attribute_name}_ms"] = getattr( + metrics, f"median_{metric_attribute_name}_ms") + result[f"std_{metric_attribute_name}_ms"] = getattr( + metrics, f"std_{metric_attribute_name}_ms") + for p, value in getattr(metrics, + f"percentiles_{metric_attribute_name}_ms"): + p_word = str(int(p)) if int(p) == p else str(p) + print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", + value)) + result[f"p{p_word}_{metric_attribute_name}_ms"] = value + + process_one_metric("ttft", "TTFT", "Time to First Token") + process_one_metric("tpot", "TPOT", + "Time per Output Token (excl. 1st token)") + process_one_metric("itl", "ITL", "Inter-token Latency") + process_one_metric("e2el", "E2EL", "End-to-end Latency") + + print("=" * 50) + return result @@ -433,8 +580,10 @@ def main(args: argparse.Namespace): if args.base_url is not None: api_url = f"{args.base_url}{args.endpoint}" + base_url = f"{args.base_url}" else: api_url = f"http://{args.host}:{args.port}{args.endpoint}" + base_url = f"http://{args.host}:{args.port}" tokenizer = get_tokenizer(tokenizer_id, trust_remote_code=args.trust_remote_code) @@ -471,9 +620,9 @@ def main(args: argparse.Namespace): prefix_len=args.sonnet_prefix_len, tokenizer=tokenizer, ) - input_requests = [(prompt, prompt_len, output_len) + input_requests = [(prompt, prompt_len, output_len, None) for prompt, prompt_formatted, prompt_len, - output_len in input_requests] + output_len, _ in input_requests] else: assert ( tokenizer.chat_template or tokenizer.default_chat_template @@ -486,12 +635,23 @@ def main(args: argparse.Namespace): prefix_len=args.sonnet_prefix_len, tokenizer=tokenizer, ) - input_requests = [(prompt_formatted, prompt_len, output_len) + input_requests = [(prompt_formatted, prompt_len, output_len, None) for prompt, prompt_formatted, prompt_len, - output_len in input_requests] + output_len, _ in input_requests] + + elif args.dataset_name == "hf": + input_requests = sample_hf_requests( + dataset_path=args.dataset_path, + dataset_subset=args.hf_subset, + dataset_split=args.hf_split, + num_requests=args.num_prompts, + tokenizer=tokenizer, + fixed_output_len=args.hf_output_len, + ) elif args.dataset_name == "random": input_requests = sample_random_requests( + prefix_len=args.random_prefix_len, input_len=args.random_input_len, output_len=args.random_output_len, num_prompts=args.num_prompts, @@ -506,13 +666,20 @@ def main(args: argparse.Namespace): benchmark( backend=backend, api_url=api_url, + base_url=base_url, model_id=model_id, tokenizer=tokenizer, input_requests=input_requests, + logprobs=args.logprobs, best_of=args.best_of, - use_beam_search=args.use_beam_search, request_rate=args.request_rate, disable_tqdm=args.disable_tqdm, + profile=args.profile, + selected_percentile_metrics=args.percentile_metrics.split(","), + selected_percentiles=[ + float(p) for p in args.metric_percentiles.split(",") + ], + ignore_eos=args.ignore_eos, )) # Save config and results to json @@ -526,7 +693,6 @@ def main(args: argparse.Namespace): result_json["model_id"] = model_id result_json["tokenizer_id"] = tokenizer_id result_json["best_of"] = args.best_of - result_json["use_beam_search"] = args.use_beam_search result_json["num_prompts"] = args.num_prompts # Metadata @@ -554,7 +720,7 @@ def main(args: argparse.Namespace): file_name = args.result_filename if args.result_dir: file_name = os.path.join(args.result_dir, file_name) - with open(file_name, "w") as outfile: + with open(file_name, "w", encoding='utf-8') as outfile: json.dump(result_json, outfile) @@ -592,13 +758,14 @@ if __name__ == "__main__": "--dataset-name", type=str, default="sharegpt", - choices=["sharegpt", "sonnet", "random"], + choices=["sharegpt", "sonnet", "random", "hf"], help="Name of the dataset to benchmark on.", ) parser.add_argument("--dataset-path", type=str, default=None, - help="Path to the dataset.") + help="Path to the sharegpt/sonnet dataset. " + "Or the huggingface dataset ID if using HF dataset.") parser.add_argument( "--model", type=str, @@ -626,52 +793,14 @@ if __name__ == "__main__": help="Number of prompts to process.", ) parser.add_argument( - "--sharegpt-output-len", + "--logprobs", type=int, default=None, - help="Output length for each request. Overrides the output length " - "from the ShareGPT dataset.") - parser.add_argument( - "--sonnet-input-len", - type=int, - default=550, - help= - "Number of input tokens per request, used only for sonnet dataset.", - ) - parser.add_argument( - "--sonnet-output-len", - type=int, - default=150, - help= - "Number of output tokens per request, used only for sonnet dataset.", - ) - parser.add_argument( - "--sonnet-prefix-len", - type=int, - default=200, - help= - "Number of prefix tokens per request, used only for sonnet dataset.", - ) - parser.add_argument( - "--random-input-len", - type=int, - default=1024, - help= - "Number of input tokens per request, used only for random sampling.", - ) - parser.add_argument( - "--random-output-len", - type=int, - default=128, - help= - "Number of output tokens per request, used only for random sampling.", - ) - parser.add_argument( - "--random-range-ratio", - type=float, - default=1.0, - help="Range of sampled ratio of input/output length, " - "used only for random sampling.", + help=("Number of logprobs-per-token to compute & return as part of " + "the request. If unspecified, then either (1) if beam search " + "is disabled, no logprobs are computed & a single dummy " + "logprob is returned for each token; or (2) if beam search " + "is enabled 1 logprob per token is computed"), ) parser.add_argument( "--request-rate", @@ -693,6 +822,12 @@ if __name__ == "__main__": action="store_true", help="Specify to disable tqdm progress bar.", ) + parser.add_argument( + "--profile", + action="store_true", + help="Use Torch Profiler. The endpoint must be launched with " + "VLLM_TORCH_PROFILER_DIR to enable profiler.", + ) parser.add_argument( "--save-result", action="store_true", @@ -722,6 +857,108 @@ if __name__ == "__main__": "{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json" " format.", ) + parser.add_argument( + "--ignore-eos", + action="store_true", + help="Set ignore_eos flag when sending the benchmark request." + "Warning: ignore_eos is not supported in deepspeed_mii and tgi.") + parser.add_argument( + "--percentile-metrics", + type=str, + default="ttft,tpot,itl", + help="Comma-seperated list of selected metrics to report percentils. " + "This argument specifies the metrics to report percentiles. " + "Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\". " + "Default value is \"ttft,tpot,itl\".") + parser.add_argument( + "--metric-percentiles", + type=str, + default="99", + help="Comma-seperated list of percentiles for selected metrics. " + "To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". " + "Default value is \"99\". " + "Use \"--percentile-metrics\" to select metrics.", + ) + + # group for dataset specific arguments + sonnet_group = parser.add_argument_group("sonnet dataset options") + sonnet_group.add_argument( + "--sonnet-input-len", + type=int, + default=550, + help= + "Number of input tokens per request, used only for sonnet dataset.", + ) + sonnet_group.add_argument( + "--sonnet-output-len", + type=int, + default=150, + help= + "Number of output tokens per request, used only for sonnet dataset.", + ) + sonnet_group.add_argument( + "--sonnet-prefix-len", + type=int, + default=200, + help= + "Number of prefix tokens per request, used only for sonnet dataset.", + ) + + sharegpt_group = parser.add_argument_group("sharegpt dataset options") + sharegpt_group.add_argument( + "--sharegpt-output-len", + type=int, + default=None, + help="Output length for each request. Overrides the output length " + "from the ShareGPT dataset.") + + random_group = parser.add_argument_group("random dataset options") + random_group.add_argument( + "--random-input-len", + type=int, + default=1024, + help= + "Number of input tokens per request, used only for random sampling.", + ) + random_group.add_argument( + "--random-output-len", + type=int, + default=128, + help= + "Number of output tokens per request, used only for random sampling.", + ) + random_group.add_argument( + "--random-range-ratio", + type=float, + default=1.0, + help="Range of sampled ratio of input/output length, " + "used only for random sampling.", + ) + random_group.add_argument( + "--random-prefix-len", + type=int, + default=0, + help="Number of fixed prefix tokens before random " + " context. The length range of context in a random " + " request is [random-prefix-len, " + " random-prefix-len + random-prefix-len * random-range-ratio).") + + hf_group = parser.add_argument_group("hf dataset options") + hf_group.add_argument("--hf-subset", + type=str, + default=None, + help="Subset of the HF dataset.") + hf_group.add_argument("--hf-split", + type=str, + default=None, + help="Split of the HF dataset.") + hf_group.add_argument( + "--hf-output-len", + type=int, + default=None, + help="Output length for each request. Overrides the output lengths " + "from the sampled HF dataset.", + ) args = parser.parse_args() main(args) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index f445ea8b625a765d14a6d28edf4b5656c0c19cb7..b40bc5ef9f7e99369aea28a4f543dea9a8c81068 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -7,14 +7,17 @@ from typing import List, Optional, Tuple import numpy as np import torch +import uvloop from tqdm import tqdm from transformers import (AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase) -from vllm.engine.arg_utils import EngineArgs -from vllm.inputs import PromptInputs +from vllm.engine.arg_utils import DEVICE_OPTIONS, AsyncEngineArgs, EngineArgs +from vllm.entrypoints.openai.api_server import ( + build_async_engine_client_from_engine_args) from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS -from vllm.utils import FlexibleArgumentParser +from vllm.sampling_params import BeamSearchParams +from vllm.utils import FlexibleArgumentParser, merge_async_iterators def sample_requests( @@ -72,7 +75,6 @@ def run_vllm( tensor_parallel_size: int, seed: int, n: int, - use_beam_search: bool, trust_remote_code: bool, dtype: str, max_model_len: Optional[int], @@ -85,8 +87,10 @@ def run_vllm( max_num_batched_tokens: int, distributed_executor_backend: Optional[str], gpu_memory_utilization: float = 0.9, + num_scheduler_steps: int = 1, download_dir: Optional[str] = None, load_format: str = EngineArgs.load_format, + disable_async_output_proc: bool = False, ) -> float: from vllm import LLM, SamplingParams llm = LLM( @@ -109,6 +113,8 @@ def run_vllm( max_num_batched_tokens=max_num_batched_tokens, distributed_executor_backend=distributed_executor_backend, load_format=load_format, + num_scheduler_steps=num_scheduler_steps, + disable_async_output_proc=disable_async_output_proc, ) # Add the requests to the engine. @@ -119,13 +125,12 @@ def run_vllm( sampling_params.append( SamplingParams( n=n, - temperature=0.0 if use_beam_search else 1.0, + temperature=1.0, top_p=1.0, - use_beam_search=use_beam_search, ignore_eos=True, max_tokens=output_len, )) - + # warmup warmup_prompts = [] warmup_sampling_params = [] @@ -134,9 +139,8 @@ def run_vllm( warmup_sampling_params.append( SamplingParams( n=n, - temperature=0.0 if use_beam_search else 1.0, + temperature=1.0, top_p=1.0, - use_beam_search=use_beam_search, ignore_eos=True, max_tokens=output_len, )) @@ -148,7 +152,7 @@ def run_vllm( # dummy_prompt_token_ids = np.random.randint(10000, # size=(args.num_prompts, # args.input_len)) - # dummy_inputs: List[PromptInputs] = [{ + # dummy_prompts: List[PromptType] = [{ # "prompt_token_ids": batch # } for batch in dummy_prompt_token_ids.tolist()] @@ -160,23 +164,122 @@ def run_vllm( # print("Warming up...") # for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"): # run_to_completion() + - start = time.perf_counter() - llm.generate(prompts, sampling_params, use_tqdm=True) - end = time.perf_counter() + use_beam_search = False + + if not use_beam_search: + start = time.perf_counter() + llm.generate(prompts, sampling_params, use_tqdm=True) + end = time.perf_counter() + else: + prompts = [prompt for prompt, _, _ in requests] + # output_len should be the same for all requests. + output_len = requests[0][2] + for prompt, input_len, _output_len in requests: + assert _output_len == output_len + start = time.perf_counter() + llm.beam_search( + prompts, + BeamSearchParams( + beam_width=n, + max_tokens=output_len, + ignore_eos=True, + )) + end = time.perf_counter() return end - start +async def run_vllm_async( + requests: List[Tuple[str, int, int]], + model: str, + tokenizer: str, + quantization: Optional[str], + tensor_parallel_size: int, + seed: int, + n: int, + trust_remote_code: bool, + dtype: str, + max_model_len: Optional[int], + enforce_eager: bool, + kv_cache_dtype: str, + quantization_param_path: Optional[str], + device: str, + enable_prefix_caching: bool, + enable_chunked_prefill: bool, + max_num_batched_tokens: int, + distributed_executor_backend: Optional[str], + gpu_memory_utilization: float = 0.9, + num_scheduler_steps: int = 1, + download_dir: Optional[str] = None, + load_format: str = EngineArgs.load_format, + disable_async_output_proc: bool = False, + disable_frontend_multiprocessing: bool = False, +) -> float: + from vllm import SamplingParams + engine_args = AsyncEngineArgs( + model=model, + tokenizer=tokenizer, + quantization=quantization, + tensor_parallel_size=tensor_parallel_size, + seed=seed, + trust_remote_code=trust_remote_code, + dtype=dtype, + max_model_len=max_model_len, + gpu_memory_utilization=gpu_memory_utilization, + enforce_eager=enforce_eager, + kv_cache_dtype=kv_cache_dtype, + quantization_param_path=quantization_param_path, + device=device, + enable_prefix_caching=enable_prefix_caching, + download_dir=download_dir, + enable_chunked_prefill=enable_chunked_prefill, + max_num_batched_tokens=max_num_batched_tokens, + distributed_executor_backend=distributed_executor_backend, + load_format=load_format, + num_scheduler_steps=num_scheduler_steps, + disable_async_output_proc=disable_async_output_proc, + worker_use_ray=False, + disable_log_requests=True, + ) + + async with build_async_engine_client_from_engine_args( + engine_args, disable_frontend_multiprocessing) as llm: + + # Add the requests to the engine. + prompts: List[str] = [] + sampling_params: List[SamplingParams] = [] + for prompt, _, output_len in requests: + prompts.append(prompt) + sampling_params.append( + SamplingParams( + n=n, + temperature=1.0, + top_p=1.0, + ignore_eos=True, + max_tokens=output_len, + )) + + generators = [] + start = time.perf_counter() + for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)): + generator = llm.generate(prompt, sp, request_id=f"test{i}") + generators.append(generator) + all_gens = merge_async_iterators(*generators) + async for i, res in all_gens: + pass + end = time.perf_counter() + return end - start + + def run_hf( requests: List[Tuple[str, int, int]], model: str, tokenizer: PreTrainedTokenizerBase, n: int, - use_beam_search: bool, max_batch_size: int, trust_remote_code: bool, ) -> float: - assert not use_beam_search llm = AutoModelForCausalLM.from_pretrained( model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code) if llm.config.model_type == "llama": @@ -208,7 +311,7 @@ def run_hf( padding=True).input_ids llm_outputs = llm.generate( input_ids=input_ids.cuda(), - do_sample=not use_beam_search, + do_sample=True, num_return_sequences=n, temperature=1.0, top_p=1.0, @@ -252,12 +355,11 @@ def main(args: argparse.Namespace): # Sample the requests. tokenizer = AutoTokenizer.from_pretrained( args.tokenizer, trust_remote_code=args.trust_remote_code) + warmup_prompt = "hi" * 10 + warmup_requests = [(warmup_prompt, 10, 10) + for _ in range(1)] if args.dataset is None: # Synthesize a prompt with the given input length. - warmup_prompt = "hi" * 10 - warmup_requests = [(warmup_prompt, 10, 10) - for _ in range(1)] - prompt = "hi" * (args.input_len - 1) requests = [(prompt, args.input_len, args.output_len) for _ in range(args.num_prompts)] @@ -266,20 +368,40 @@ def main(args: argparse.Namespace): args.output_len) if args.backend == "vllm": - elapsed_time = run_vllm( - warmup_requests, requests, args.model, args.tokenizer, args.quantization, - args.tensor_parallel_size, args.seed, args.n, args.use_beam_search, - args.trust_remote_code, args.dtype, args.max_model_len, - args.enforce_eager, args.kv_cache_dtype, - args.quantization_param_path, args.device, - args.enable_prefix_caching, args.enable_chunked_prefill, - args.max_num_batched_tokens, args.distributed_executor_backend, - args.gpu_memory_utilization, args.download_dir, args.load_format) + if args.async_engine: + run_args = [ + requests, args.model, args.tokenizer, args.quantization, + args.tensor_parallel_size, args.seed, args.n, + args.trust_remote_code, args.dtype, args.max_model_len, + args.enforce_eager, args.kv_cache_dtype, + args.quantization_param_path, args.device, + args.enable_prefix_caching, args.enable_chunked_prefill, + args.max_num_batched_tokens, args.distributed_executor_backend, + args.gpu_memory_utilization, args.num_scheduler_steps, + args.download_dir, args.load_format, args.disable_async_output_proc + ] + else: + run_args = [ + warmup_requests, requests, args.model, args.tokenizer, args.quantization, + args.tensor_parallel_size, args.seed, args.n, + args.trust_remote_code, args.dtype, args.max_model_len, + args.enforce_eager, args.kv_cache_dtype, + args.quantization_param_path, args.device, + args.enable_prefix_caching, args.enable_chunked_prefill, + args.max_num_batched_tokens, args.distributed_executor_backend, + args.gpu_memory_utilization, args.num_scheduler_steps, + args.download_dir, args.load_format, args.disable_async_output_proc + ] + + if args.async_engine: + run_args.append(args.disable_frontend_multiprocessing) + elapsed_time = uvloop.run(run_vllm_async(*run_args)) + else: + elapsed_time = run_vllm(*run_args) elif args.backend == "hf": assert args.tensor_parallel_size == 1 elapsed_time = run_hf(requests, args.model, tokenizer, args.n, - args.use_beam_search, args.hf_max_batch_size, - args.trust_remote_code) + args.hf_max_batch_size, args.trust_remote_code) elif args.backend == "mii": elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size, args.output_len) @@ -341,7 +463,6 @@ if __name__ == "__main__": type=int, default=1, help="Number of generated sequences per prompt.") - parser.add_argument("--use-beam-search", action="store_true") parser.add_argument('--num-iters-warmup', type=int, default=1, @@ -400,17 +521,20 @@ if __name__ == "__main__": 'accuracy issues. FP8_E5M2 (without scaling) is only supported on ' 'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is ' 'instead supported for common inference criteria.') + parser.add_argument("--device", + type=str, + default="auto", + choices=DEVICE_OPTIONS, + help='device type for vLLM execution') parser.add_argument( - "--device", - type=str, - default="auto", - choices=["auto", "cuda", "cpu", "openvino", "tpu", "xpu"], - help='device type for vLLM execution, supporting CUDA, OpenVINO and ' - 'CPU.') + "--num-scheduler-steps", + type=int, + default=1, + help="Maximum number of forward steps per scheduler call.") parser.add_argument( "--enable-prefix-caching", action='store_true', - help="enable automatic prefix caching for vLLM backend.") + help="Enable automatic prefix caching for vLLM backend.") parser.add_argument("--enable-chunked-prefill", action='store_true', help="enable chunked prefill for vLLM backend.") @@ -459,6 +583,19 @@ if __name__ == "__main__": 'section for more information.\n' '* "bitsandbytes" will load the weights using bitsandbytes ' 'quantization.\n') + parser.add_argument( + "--disable-async-output-proc", + action='store_true', + default=False, + help="Disable async output processor for vLLM backend.") + parser.add_argument("--async-engine", + action='store_true', + default=False, + help="Use vLLM async engine rather than LLM class.") + parser.add_argument("--disable-frontend-multiprocessing", + action='store_true', + default=False, + help="Disable decoupled async engine frontend.") args = parser.parse_args() if args.tokenizer is None: args.tokenizer = args.model @@ -481,8 +618,6 @@ if __name__ == "__main__": raise ValueError("dtype must be auto for MII backend.") if args.n != 1: raise ValueError("n must be 1 for MII backend.") - if args.use_beam_search: - raise ValueError("Beam search is not supported for MII backend.") if args.quantization is not None: raise ValueError("Quantization is only for vLLM backend.") if args.hf_max_batch_size is not None: diff --git a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py index 64011b2db23952f41fdf1df934049fa13e591129..63cf5d50cac751671f020e4a6dd5b59424c51753 100644 --- a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py +++ b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py @@ -32,7 +32,6 @@ def to_int8(tensor: torch.Tensor) -> torch.Tensor: def make_rand_tensors(dtype: torch.dtype, m: int, n: int, k: int) -> Tuple[torch.Tensor, torch.Tensor]: - a = torch.randn((m, k), device='cuda') * 5 b = torch.randn((n, k), device='cuda').t() * 5 @@ -44,59 +43,18 @@ def make_rand_tensors(dtype: torch.dtype, m: int, n: int, raise ValueError("unsupported dtype") -# impl - - -def pytorch_mm_impl(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: torch.dtype) -> torch.Tensor: - return torch.mm(a, b) - - -def pytorch_fp8_impl(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: torch.dtype) -> torch.Tensor: - return torch._scaled_mm(a, - b, - scale_a=scale_a, - scale_b=scale_b, - out_dtype=out_dtype) - - -def pytorch_fp8_impl_fast_accum(a: torch.Tensor, b: torch.Tensor, - scale_a: torch.Tensor, scale_b: torch.Tensor, - out_dtype: torch.dtype) -> torch.Tensor: - return torch._scaled_mm(a, - b, - scale_a=scale_a, - scale_b=scale_b, - out_dtype=out_dtype, - use_fast_accum=True) - - -def cutlass_impl(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor, - scale_b: torch.Tensor, - out_dtype: torch.dtype) -> torch.Tensor: - return ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype=out_dtype) - - # bench -def bench_fn(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor, - scale_b: torch.Tensor, out_dtype: torch.dtype, label: str, - sub_label: str, fn: Callable, description: str) -> TMeasurement: - +def bench_fn(label: str, sub_label: str, description: str, fn: Callable, *args, + **kwargs) -> TMeasurement: min_run_time = 1 globals = { - "a": a, - "b": b, - "scale_a": scale_a, - "scale_b": scale_b, - "out_dtype": out_dtype, + "args": args, + "kwargs": kwargs, "fn": fn, } return TBenchmark.Timer( - stmt="fn(a, b, scale_a, scale_b, out_dtype)", + stmt="fn(*args, **kwargs)", globals=globals, label=label, sub_label=sub_label, @@ -110,26 +68,58 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str, a, b = make_rand_tensors(torch.int8, m, n, k) scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) + bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) + azp = torch.zeros((m, ), device="cuda", dtype=torch.int32) + azp_adj = torch.zeros((n, ), device="cuda", dtype=torch.int32) timers = [] # pytorch impl - bfloat16 timers.append( - bench_fn(a.to(dtype=torch.bfloat16, device="cuda"), - b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b, - torch.bfloat16, label, sub_label, pytorch_mm_impl, - "pytorch_bf16_bf16_bf16_matmul-no-scales")) + bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales", + torch.mm, a.to(dtype=torch.bfloat16), + b.to(dtype=torch.bfloat16))) # pytorch impl - float16 timers.append( - bench_fn(a.to(dtype=torch.float16, device="cuda"), - b.to(dtype=torch.float16, device="cuda"), scale_a, scale_b, - torch.float16, label, sub_label, pytorch_mm_impl, - "pytorch_fp16_fp16_fp16_matmul-no-scales")) + bench_fn(label, sub_label, + "pytorch_fp16_fp16_fp16_matmul-no-scales", torch.mm, + a.to(dtype=torch.float16), b.to(dtype=torch.float16))) # cutlass impl timers.append( - bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label, - cutlass_impl, "cutlass_i8_i8_bf16_scaled_mm")) + bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm", + ops.cutlass_scaled_mm, a, b, scale_a, scale_b, + torch.bfloat16)) + + # cutlass with bias + timers.append( + bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_bias", + ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16, + bias)) + + # cutlass with azp per-tensor + timers.append( + bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp", + ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b, + torch.bfloat16, azp_adj)) + + # cutlass with azp per-tensor + bias + timers.append( + bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp_bias", + ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b, + torch.bfloat16, azp_adj, None, bias)) + + # cutlass with azp per-token + timers.append( + bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp_pt", + ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b, + torch.bfloat16, azp_adj, azp)) + + # cutlass with azp per-token + bias + timers.append( + bench_fn(label, sub_label, "cutlass_i8_i8_bf16_scaled_mm_azp_pt_bias", + ops.cutlass_scaled_mm_azp, a, b, scale_a, scale_b, + torch.bfloat16, azp_adj, azp, bias)) return timers @@ -140,46 +130,88 @@ def bench_fp8(dtype: torch.dtype, m: int, k: int, n: int, label: str, a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k) scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) + bias = torch.zeros((n, ), device="cuda", dtype=torch.bfloat16) timers = [] # pytorch impl w. bf16 timers.append( - bench_fn(a.to(dtype=torch.bfloat16, device="cuda"), - b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b, - torch.bfloat16, label, sub_label, pytorch_mm_impl, - "pytorch_bf16_bf16_bf16_matmul-no-scales")) + bench_fn(label, sub_label, "pytorch_bf16_bf16_bf16_matmul-no-scales", + torch.mm, a.to(dtype=torch.bfloat16, device="cuda"), + b.to(dtype=torch.bfloat16, device="cuda"))) # pytorch impl: bf16 output, without fp8 fast accum timers.append( - bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label, - pytorch_fp8_impl, "pytorch_fp8_fp8_bf16_scaled_mm")) + bench_fn(label, + sub_label, + "pytorch_fp8_fp8_bf16_scaled_mm", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.bfloat16)) # pytorch impl: bf16 output, with fp8 fast accum timers.append( - bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label, - pytorch_fp8_impl_fast_accum, - "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum")) + bench_fn(label, + sub_label, + "pytorch_fp8_fp8_bf16_scaled_mm_fast_accum", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.bfloat16, + use_fast_accum=True)) # pytorch impl: fp16 output, without fp8 fast accum timers.append( - bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label, - pytorch_fp8_impl, "pytorch_fp8_fp8_fp16_scaled_mm")) + bench_fn(label, + sub_label, + "pytorch_fp8_fp8_fp16_scaled_mm", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.float16)) # pytorch impl: fp16 output, with fp8 fast accum timers.append( - bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label, - pytorch_fp8_impl_fast_accum, - "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum")) + bench_fn(label, + sub_label, + "pytorch_fp8_fp8_fp16_scaled_mm_fast_accum", + torch._scaled_mm, + a, + b, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=torch.float16, + use_fast_accum=True)) # cutlass impl: bf16 output timers.append( - bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label, - cutlass_impl, "cutlass_fp8_fp8_bf16_scaled_mm")) + bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm", + ops.cutlass_scaled_mm, a, b, scale_a, scale_b, + torch.bfloat16)) # cutlass impl: fp16 output timers.append( - bench_fn(a, b, scale_a, scale_b, torch.float16, label, sub_label, - cutlass_impl, "cutlass_fp8_fp8_fp16_scaled_mm")) + bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_mm", + ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.float16)) + + # cutlass impl: bf16 output, with bias + timers.append( + bench_fn(label, sub_label, "cutlass_fp8_fp8_bf16_scaled_mm_bias", + ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.bfloat16, + bias)) + + # cutlass impl: fp16 output, with bias + timers.append( + bench_fn(label, sub_label, "cutlass_fp8_fp8_fp16_scaled_mm_bias", + ops.cutlass_scaled_mm, a, b, scale_a, scale_b, torch.float16, + bias.to(dtype=torch.float16))) + return timers @@ -200,7 +232,6 @@ def print_timers(timers: Iterable[TMeasurement]): def run(dtype: torch.dtype, MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]: - results = [] for m, k, n in MKNs: timers = bench(dtype, m, k, n, f"scaled-{dtype}-gemm", @@ -216,7 +247,6 @@ def make_output(data: Iterable[TMeasurement], MKNs: Iterable[Tuple[int, int, int]], base_description: str, timestamp=None): - print(f"== All Results {base_description} ====") print_timers(data) @@ -251,7 +281,6 @@ def run_range_bench(args): def run_model_bench(args): - print("Benchmarking models:") for i, model in enumerate(args.models): print(f"[{i}] {model}") diff --git a/benchmarks/kernels/benchmark_layernorm.py b/benchmarks/kernels/benchmark_layernorm.py new file mode 100644 index 0000000000000000000000000000000000000000..92f6053cc6d7e6d76b2ff3bbc8fe92bd2a9c0b8a --- /dev/null +++ b/benchmarks/kernels/benchmark_layernorm.py @@ -0,0 +1,86 @@ +import time + +import torch + +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser, + seed_everything) + + +@torch.inference_mode() +def main(num_tokens: int, + hidden_size: int, + add_residual: bool, + dtype: torch.dtype, + seed: int = 0, + do_profile: bool = False, + num_warmup_iters: int = 5, + num_iters: int = 100) -> None: + seed_everything(seed) + torch.set_default_device("cuda") + + layer = RMSNorm(hidden_size).to(dtype=dtype) + layer.weight.data.normal_(mean=1.0, std=0.1) + scale = 1 / (2 * hidden_size) + x = torch.randn(num_tokens, hidden_size, dtype=dtype) + x *= scale + residual = torch.randn_like(x) * scale if add_residual else None + + def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: + torch.cuda.synchronize() + if profile: + torch.cuda.cudart().cudaProfilerStart() + start_time = time.perf_counter() + + for _ in range(num_iters): + layer(x, residual) + torch.cuda.synchronize() + + end_time = time.perf_counter() + if profile: + torch.cuda.cudart().cudaProfilerStart() + return (end_time - start_time) / num_iters + + # Warmup. + print("Warming up...") + run_benchmark = run_cuda_benchmark + run_benchmark(num_iters=num_warmup_iters, profile=False) + + # Benchmark. + if do_profile: + latency = run_benchmark(num_iters=1, profile=True) + else: + latency = run_benchmark(num_iters=num_iters, profile=False) + print(f"Kernel running time: {latency * 1000000:.3f} us") + + +if __name__ == '__main__': + parser = FlexibleArgumentParser( + description="Benchmark the layernorm kernel.") + parser.add_argument("--num-tokens", type=int, default=4096) + parser.add_argument("--hidden-size", type=int, default=8192) + parser.add_argument("--add-residual", action="store_true") + parser.add_argument("--dtype", + type=str, + choices=["half", "bfloat16", "float"], + default="half") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--profile", action="store_true") + parser.add_argument("--num-warmup-iters", type=int, default=5) + parser.add_argument("--num-iters", + type=int, + default=100, + help="Number of benchmark iterations. " + "If --profile is set, this number is ignored") + + args = parser.parse_args() + print(args) + + main(num_tokens=args.num_tokens, + hidden_size=args.hidden_size, + add_residual=args.add_residual, + dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype], + seed=args.seed, + do_profile=args.profile, + num_warmup_iters=args.num_warmup_iters, + num_iters=args.num_iters) diff --git a/benchmarks/kernels/benchmark_machete.py b/benchmarks/kernels/benchmark_machete.py new file mode 100644 index 0000000000000000000000000000000000000000..b70c4b94c97a1af37cd5090ac26876fc5cc40738 --- /dev/null +++ b/benchmarks/kernels/benchmark_machete.py @@ -0,0 +1,420 @@ +import argparse +import copy +import itertools +import math +import pickle as pkl +import time +from itertools import product +from typing import Callable, Iterable, List, Optional, Tuple + +import pandas as pd +import torch +import torch.utils.benchmark as TBenchmark +from torch.utils.benchmark import Measurement as TMeasurement +from weight_shapes import WEIGHT_SHAPES + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N, marlin_permute_scales) +from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( + MarlinWorkspace) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + gptq_pack, pack_rows, quantize_weights) +from vllm.scalar_type import ScalarType, scalar_types +from vllm.utils import FlexibleArgumentParser + +DEFAULT_MODELS = ["meta-llama/Llama-3-8b", "meta-llama/Llama-2-70b-hf"] +DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024] +DEFAULT_TP_SIZES = [1] + + +def machete_pack_weights(w_q: torch.tensor, wtype: ScalarType) -> torch.tensor: + w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape) + w_q = w_q.t().contiguous().t() # make col major + return ops.machete_prepack_B(w_q, wtype) + + +def make_bench_tensors( + atype: torch.dtype, wtype: ScalarType, group_size: int, m: int, n: int, + k: int +) -> Tuple[torch.tensor, List[Tuple[torch.tensor, torch.tensor, torch.tensor, + torch.tensor]]]: + assert wtype.is_integer(), "TODO: support floating point weights" + + # we want to make sure that weights don't fit into L2 cache between runs so + # we construct enough weights to exceed L2 cache, which is 50mb on a H100 + # so we target total weight size > 2*50mb + num_weights = math.ceil(2 * 50 * 1024**2 * 8 / (k * n * wtype.size_bits)) + + a = torch.randn((m, k), device="cuda", dtype=atype) * 5 + weights = [ + torch.randn((k, n), device="cuda", dtype=atype) + for _ in range(num_weights) + ] + quanitized_weights = [ + quantize_weights(w, wtype, group_size) for w in weights + ] + + return a, quanitized_weights + + +# impl + + +# bench +def bench_fn(label: str, sub_label: str, description: str, + fn: Callable) -> TMeasurement: + + min_run_time = 1 + return TBenchmark.Timer( + stmt="fn()", + globals={ + "fn": fn + }, + label=label, + sub_label=sub_label, + description=description, + ).blocked_autorange(min_run_time=min_run_time) + + +def loop_over_weights( + a: torch.tensor, weights: List[Tuple[torch.tensor, torch.tensor, + torch.tensor, torch.tensor]], + fn: Callable[[torch.tensor, torch.tensor, torch.tensor, torch.tensor], + None]): + for w_ref, w_q, w_s, _ in weights: + fn(a, w_ref, w_q, w_s) + + +_SWEEP_SCHEDULES_RESULTS: Optional[pd.DataFrame] = None +_SWEEP_SCHEDULES_RESULTS_CSV: Optional[str] = None + + +def bench(atype: torch.dtype, + wtype: ScalarType, + group_size: int, + m: int, + k: int, + n: int, + label: str, + sub_label: str, + benchmark_marlinv1: bool = True, + sweep_schedules: bool = True) -> Iterable[TMeasurement]: + global _SWEEP_SCHEDULES_RESULTS + + a, weights = make_bench_tensors(atype, wtype, group_size, m, n, k) + sub_label += f", L={len(weights)}" + + weights_machete = [(w_ref, machete_pack_weights(w_q, wtype), w_s, w_zp) + for w_ref, w_q, w_s, w_zp in weights] + + timers = [] + # pytorch impl + timers.append( + bench_fn( + label, sub_label, "torch.matmul", lambda: loop_over_weights( + a, + weights, + lambda a, w_ref, w_q, w_s: torch.matmul(a, w_ref), + ))) + + if benchmark_marlinv1: + w_ref = weights[0][0] + + w_zp_empty = torch.empty(0, dtype=torch.int, device=w_ref.device) + sort_indices = torch.empty(0, dtype=torch.int, device=w_ref.device) + g_idx = torch.empty(0, dtype=torch.int, device=w_ref.device) + + def marlinv1_pack_weights(w_q: torch.tensor) -> torch.tensor: + w_q_gptq = gptq_pack(w_q, wtype.size_bits, *w_ref.shape) + return ops.gptq_marlin_repack(w_q_gptq, sort_indices, *w_ref.shape, + wtype.size_bits) + + def marlinv1_permute_scales(w_s: torch.tensor) -> torch.tensor: + return marlin_permute_scales(w_s, *w_ref.shape, group_size) + + weights_marlinv1 = [(w_ref, marlinv1_pack_weights(w_q), + marlinv1_permute_scales(w_s), w_zp) + for w_ref, w_q, w_s, w_zp in weights] + + workspace = MarlinWorkspace(w_ref.shape[1], GPTQ_MARLIN_MIN_THREAD_N, + GPTQ_MARLIN_MAX_PARALLEL) + + # marlinv1 + timers.append( + bench_fn( + label, sub_label, "marlin_orig", lambda: loop_over_weights( + a, weights_marlinv1, lambda a, w_ref, w_q, w_s: ops. + gptq_marlin_gemm(a, + w_q, + w_s, + w_zp_empty, + g_idx, + sort_indices, + workspace.scratch, + wtype, + size_m=a.shape[0], + size_n=w_ref.shape[1], + size_k=w_ref.shape[0], + is_k_full=True)))) + + # machete + timers.append( + bench_fn( + label, sub_label, "machete_heuristic", lambda: loop_over_weights( + a, weights_machete, lambda a, _, w_q, w_s: ops.machete_gemm( + a, w_q, wtype, b_scales=w_s, b_group_size=group_size)))) + + if sweep_schedules: + print("Finding best schedule for machete") + best = None + best_schedule = None + schedules = ops.machete_supported_schedules(wtype) + for schedule in reversed(schedules): + schedule_M = int(schedule.split("_")[0].split("x")[1]) + + # Prune known bad schedules + if schedule_M >= 2 * max(m, 16) or schedule_M < m // 4: + continue + + def run(a, _, w_q, w_s, schedule=schedule): + ops.machete_gemm(a, + w_q, + wtype, + w_s, + b_group_size=group_size, + schedule=schedule) + + res = bench_fn(label, sub_label, "machete_best", + lambda: loop_over_weights(a, weights_machete, run)) + + results_row = { + "M": m, + "K": k, + "N": n, + "group_size": group_size, + "schedule": schedule, + "median": res.median, + } + if _SWEEP_SCHEDULES_RESULTS is None: + _SWEEP_SCHEDULES_RESULTS = pd.DataFrame( + columns=results_row.keys()) + _SWEEP_SCHEDULES_RESULTS.\ + loc[len(_SWEEP_SCHEDULES_RESULTS)] = results_row + + print(f" {res.median:5.5} ", schedule) + if not best or res.median < best.median: + best = res + best_schedule = schedule + print("Best schedule:", best_schedule) + timers.append(best) + + return timers + + +# runner +def print_timers(timers: Iterable[TMeasurement]): + compare = TBenchmark.Compare(timers) + compare.print() + + +def run(dtype: torch.dtype, sweep_schedules: bool, + MKNs: Iterable[Tuple[int, int, int]]) -> Iterable[TMeasurement]: + + results = [] + for m, k, n in MKNs: + timers = bench(dtype, + scalar_types.uint4b8, + 128, + m, + k, + n, + f"{dtype}-gemm", + f"MKN=({m}x{k}x{n})", + sweep_schedules=sweep_schedules) + print_timers(timers) + results.extend(timers) + + return results + + +# output makers +def make_output( + data: Iterable[TMeasurement], + MKNs: Iterable[Tuple[int, int, int]], + base_description: str, + timestamp=None, +): + + print(f"== All Results {base_description} ====") + print_timers(data) + + # pickle all the results + timestamp = int(time.time()) if timestamp is None else timestamp + with open(f"{base_description}-{timestamp}.pkl", "wb") as f: + pkl.dump(data, f) + + +# argparse runners + + +def run_square_bench(args): + dim_sizes = list( + range(args.dim_start, args.dim_end + 1, args.dim_increment)) + MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes)) + + data = run(args.dtype, args.sweep_schedules, MKNs) + + make_output(data, MKNs, f"square_bench-{args.dtype}") + + +def run_range_bench(args): + m_start, k_start, n_start = [int(x) for x in args.dim_start.split(",")] + m_end, k_end, n_end = [int(x) for x in args.dim_end.split(",")] + m_increment, k_increment, n_increment = \ + [int(x) for x in args.dim_increment.split(",")] + Ms = list(range(m_start, m_end + 1, m_increment)) + Ks = list(range(k_start, k_end + 1, k_increment)) + Ns = list(range(n_start, n_end + 1, n_increment)) + MKNs = list(product(Ms, Ks, Ns)) + + data = run(args.dtype, args.sweep_schedules, MKNs) + + make_output(data, MKNs, f"range_bench-{args.dtype}") + + +def run_model_bench(args): + + print("Benchmarking models:") + for i, model in enumerate(args.models): + print(f"[{i}] {model}") + + def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]: + KNs = [] + for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]): + KN[tp_split_dim] = KN[tp_split_dim] // tp_size + KNs.append(KN) + return KNs + + model_bench_data = [] + models_tps = list(itertools.product(args.models, args.tp_sizes)) + for model, tp_size in models_tps: + Ms = args.batch_sizes + KNs = model_shapes(model, tp_size) + MKNs = [] + for m in Ms: + for k, n in KNs: + MKNs.append((m, k, n)) + + data = run(args.dtype, args.sweep_schedules, MKNs) + model_bench_data.append(data) + + # Print all results + for data, model_tp in zip(model_bench_data, models_tps): + model, tp_size = model_tp + print(f"== Results {args.dtype} {model}-TP{tp_size} ====") + print_timers(data) + + timestamp = int(time.time()) + + all_data = [] + for d in model_bench_data: + all_data.extend(d) + # pickle all data + with open(f"model_bench-{args.dtype}-{timestamp}.pkl", "wb") as f: + pkl.dump(all_data, f) + + +if __name__ == "__main__": + + def to_torch_dtype(dt): + if dt == "bfloat16": + return torch.bfloat16 + if dt == "float16": + return torch.float16 + raise ValueError("unsupported dtype") + + parser = FlexibleArgumentParser( + description=""" +Benchmark Machete GEMM. + + To run square GEMMs: + python3 ./benchmarks/kernels/benchmark_machete.py --dtype float16 square_bench --dim-start 128 --dim-end 512 --dim-increment 64 + + To run constant N and K and sweep M: + python3 ./benchmarks/kernels/benchmark_machete.py --dtype float16 range_bench --dim-start 128 --dim-end 512 --dim-increment 64 --n-constant 16384 --k-constant 16384 + + To run dimensions from a model: + python3 ./benchmarks/kernels/benchmark_machete.py --dtype float16 model_bench --models meta-llama/Llama-2-7b-hf --batch-sizes 16 --tp-sizes 1 + + Output: + - a .pkl file, that is a list of raw torch.benchmark.utils.Measurements for the pytorch and cutlass implementations for the various GEMMs. + """, # noqa: E501 + formatter_class=argparse.RawTextHelpFormatter, + ) + + parser.add_argument( + "--dtype", + type=to_torch_dtype, + required=True, + help="Available options are ['bfloat16', 'float16']", + ) + parser.add_argument( + "--sweep-schedules", + action="store_true", + help="Run a sweep over all supported schedules", + ) + parser.add_argument("--sweep-csv-out", + help="CSV to store sweep results", + default="sch_sweep_results.csv") + subparsers = parser.add_subparsers(dest="cmd", required=True) + + square_parser = subparsers.add_parser("square_bench") + square_parser.add_argument("--dim-start", type=int, required=True) + square_parser.add_argument("--dim-end", type=int, required=True) + square_parser.add_argument("--dim-increment", type=int, required=True) + square_parser.set_defaults(func=run_square_bench) + + range_parser = subparsers.add_parser("range_bench") + range_parser.add_argument( + "--dim-start", + type=str, + required=True, + help="Start value for M,K,N as common separated list") + range_parser.add_argument( + "--dim-end", + type=str, + required=True, + help="End value (inclusive) for M,K,N as common separated list") + range_parser.add_argument( + "--dim-increment", + type=str, + required=True, + help="Increment value for M,K,N as common separated list") + range_parser.set_defaults(func=run_range_bench) + + model_parser = subparsers.add_parser("model_bench") + model_parser.add_argument( + "--models", + nargs="+", + type=str, + default=DEFAULT_MODELS, + choices=WEIGHT_SHAPES.keys(), + ) + model_parser.add_argument("--tp-sizes", + nargs="+", + type=int, + default=DEFAULT_TP_SIZES) + model_parser.add_argument("--batch-sizes", + nargs="+", + type=int, + default=DEFAULT_BATCH_SIZES) + model_parser.set_defaults(func=run_model_bench) + + args = parser.parse_args() + + _SWEEP_SCHEDULES_RESULTS_CSV = args.sweep_csv_out + args.func(args) + + if _SWEEP_SCHEDULES_RESULTS is not None: + _SWEEP_SCHEDULES_RESULTS.to_csv(_SWEEP_SCHEDULES_RESULTS_CSV) diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index e00696d6d43cb5a02c38cef48455033ed23380b6..c2ad98b7e265651275e4c113bf26abc6ef497564 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -10,7 +10,7 @@ from ray.experimental.tqdm_ray import tqdm from transformers import AutoConfig from vllm.model_executor.layers.fused_moe.fused_moe import * -from vllm.utils import FlexibleArgumentParser +from vllm.utils import FlexibleArgumentParser, seed_everything class BenchmarkConfig(TypedDict): @@ -30,19 +30,36 @@ def benchmark_config( hidden_size: int, topk: int, dtype: torch.dtype, - use_fp8: bool, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, num_iters: int = 100, ) -> float: - init_dtype = torch.float16 if use_fp8 else dtype + init_dtype = torch.float16 if use_fp8_w8a8 else dtype x = torch.randn(num_tokens, hidden_size, dtype=dtype) - w1 = torch.randn(num_experts, - shard_intermediate_size, - hidden_size, - dtype=init_dtype) - w2 = torch.randn(num_experts, - hidden_size, - shard_intermediate_size // 2, - dtype=init_dtype) + if use_int8_w8a16: + w1 = torch.randint(-127, + 127, ( + num_experts, + shard_intermediate_size, + hidden_size, + ), + dtype=torch.int8) + w2 = torch.randint(-127, + 127, ( + num_experts, + hidden_size, + shard_intermediate_size // 2, + ), + dtype=torch.int8) + else: + w1 = torch.randn(num_experts, + shard_intermediate_size, + hidden_size, + dtype=init_dtype) + w2 = torch.randn(num_experts, + hidden_size, + shard_intermediate_size // 2, + dtype=init_dtype) gating_output = torch.randn(num_iters, num_tokens, num_experts, @@ -52,7 +69,11 @@ def benchmark_config( w2_scale = None a1_scale = None a2_scale = None - if use_fp8: + if use_int8_w8a16: + w1_scale = torch.randn((num_experts, 2 * shard_intermediate_size), + dtype=torch.float32) + w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32) + if use_fp8_w8a8: w1_scale = torch.randn(num_experts, dtype=torch.float32) w2_scale = torch.randn(num_experts, dtype=torch.float32) a1_scale = torch.randn(1, dtype=torch.float32) @@ -76,7 +97,8 @@ def benchmark_config( renormalize=True, inplace=True, override_config=config, - use_fp8=use_fp8, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a16=use_int8_w8a16, w1_scale=w1_scale, w2_scale=w2_scale, a1_scale=a1_scale, @@ -144,7 +166,7 @@ class BenchmarkWorker: def __init__(self, seed: int) -> None: torch.set_default_device("cuda") - torch.cuda.manual_seed_all(seed) + seed_everything(seed) self.seed = seed def benchmark( @@ -155,11 +177,13 @@ class BenchmarkWorker: hidden_size: int, topk: int, dtype: torch.dtype, - use_fp8: bool, + use_fp8_w8a8: bool, + use_int8_w8a16: bool, ) -> Tuple[Dict[str, int], float]: - torch.cuda.manual_seed_all(self.seed) - - dtype_str = "float8" if use_fp8 else None + seed_everything(self.seed) + dtype_str = get_config_dtype_str(dtype, + use_int8_w8a16=use_int8_w8a16, + use_fp8_w8a8=use_fp8_w8a8) # NOTE(woosuk): The current naming convention uses w2.shape[2], which # is the intermediate size after silu_and_mul. op_config = get_moe_configs(num_experts, shard_intermediate_size // 2, @@ -173,7 +197,8 @@ class BenchmarkWorker: key=lambda x: abs(x - num_tokens))] kernel_time = benchmark_config(config, num_tokens, num_experts, shard_intermediate_size, hidden_size, - topk, dtype, use_fp8) + topk, dtype, use_fp8_w8a8, + use_int8_w8a16) return config, kernel_time def tune( @@ -184,9 +209,10 @@ class BenchmarkWorker: hidden_size: int, topk: int, dtype: torch.dtype, - use_fp8: bool, - search_space: List[BenchmarkConfig], - ) -> BenchmarkConfig: + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + search_space: List[Dict[str, int]], + ) -> Dict[str, int]: best_config = None best_time = float("inf") for config in tqdm(search_space): @@ -198,7 +224,8 @@ class BenchmarkWorker: hidden_size, topk, dtype, - use_fp8, + use_fp8_w8a8, + use_int8_w8a16, num_iters=10) except triton.runtime.autotuner.OutOfResources: # Some configurations may be invalid and fail to compile. @@ -224,20 +251,19 @@ def sort_config(config: BenchmarkConfig) -> BenchmarkConfig: } -def save_configs( - configs: Dict[int, BenchmarkConfig], - num_experts: int, - shard_intermediate_size: int, - hidden_size: int, - topk: int, - dtype: torch.dtype, - use_fp8: bool, -) -> None: - dtype_str = "float8" if use_fp8 else None +def save_configs(configs: Dict[int, BenchmarkConfig], num_experts: int, + shard_intermediate_size: int, hidden_size: int, topk: int, + dtype: torch.dtype, use_fp8_w8a8: bool, + use_int8_w8a16: bool) -> None: + dtype_str = get_config_dtype_str(dtype, + use_int8_w8a16=use_int8_w8a16, + use_fp8_w8a8=use_fp8_w8a8) + # NOTE(woosuk): The current naming convention uses w2.shape[2], which # is the intermediate size after silu_and_mul. filename = get_config_file_name(num_experts, shard_intermediate_size // 2, dtype_str) + print(f"Writing best config to {filename}...") with open(filename, "w") as f: json.dump(configs, f, indent=4) @@ -253,6 +279,11 @@ def main(args: argparse.Namespace): topk = config.ffn_config.moe_top_k intermediate_size = config.ffn_config.ffn_hidden_size shard_intermediate_size = 2 * intermediate_size // args.tp_size + elif config.architectures[0] == "JambaForCausalLM": + E = config.num_experts + topk = config.num_experts_per_tok + intermediate_size = config.intermediate_size + shard_intermediate_size = 2 * intermediate_size // args.tp_size else: # Default: Mixtral. E = config.num_local_experts @@ -262,7 +293,8 @@ def main(args: argparse.Namespace): hidden_size = config.hidden_size dtype = config.torch_dtype - use_fp8 = args.dtype == "fp8" + use_fp8_w8a8 = args.dtype == "fp8_w8a8" + use_int8_w8a16 = args.dtype == "int8_w8a16" if args.batch_size is None: batch_sizes = [ @@ -294,21 +326,21 @@ def main(args: argparse.Namespace): start = time.time() configs = _distribute( "tune", [(batch_size, E, shard_intermediate_size, hidden_size, - topk, dtype, use_fp8, search_space) + topk, dtype, use_fp8_w8a8, use_int8_w8a16, search_space) for batch_size in batch_sizes]) best_configs = { M: sort_config(config) for M, config in zip(batch_sizes, configs) } save_configs(best_configs, E, shard_intermediate_size, hidden_size, - topk, dtype, use_fp8) + topk, dtype, use_fp8_w8a8, use_int8_w8a16) end = time.time() print(f"Tuning took {end - start:.2f} seconds") else: - outputs = _distribute("benchmark", - [(batch_size, E, shard_intermediate_size, - hidden_size, topk, dtype, use_fp8) - for batch_size in batch_sizes]) + outputs = _distribute( + "benchmark", [(batch_size, E, shard_intermediate_size, hidden_size, + topk, dtype, use_fp8_w8a8, use_int8_w8a16) + for batch_size in batch_sizes]) for batch_size, (config, kernel_time) in zip(batch_sizes, outputs): print(f"Batch size: {batch_size}, config: {config}") @@ -323,7 +355,7 @@ if __name__ == "__main__": parser.add_argument("--tp-size", "-tp", type=int, default=2) parser.add_argument("--dtype", type=str, - choices=["auto", "fp8"], + choices=["auto", "fp8_w8a8", "int8_w8a16"], default="auto") parser.add_argument("--seed", type=int, default=0) parser.add_argument("--batch-size", type=int, required=False) diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index a04433142da422a4b059eaf0d7444b495f1d758f..cc80546c309673d035e0570c6af4ba5434a2d8ab 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -6,7 +6,7 @@ import torch from vllm import _custom_ops as ops from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser, - create_kv_caches_with_random) + create_kv_caches_with_random, seed_everything) NUM_BLOCKS = 1024 PARTITION_SIZE = 512 @@ -28,10 +28,7 @@ def main( device: str = "cuda", kv_cache_dtype: Optional[str] = None, ) -> None: - random.seed(seed) - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) scale = float(1.0 / (head_size**0.5)) query = torch.empty(num_seqs, @@ -104,7 +101,25 @@ def main( for _ in range(num_iters): if version == "v1": - ops.paged_attention_v1( + if envs.VLLM_USE_OPT_OP: + ops.paged_attention_v1_opt( + output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + ) + else: + ops.paged_attention_v1( output, query, key_cache, @@ -121,7 +136,28 @@ def main( v_scale, ) elif version == "v2": - ops.paged_attention_v2( + if envs.VLLM_USE_OPT_OP: + ops.paged_attention_v2( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + ) + else: + ops.paged_attention_v2_opt( output, exp_sums, max_logits, diff --git a/benchmarks/kernels/benchmark_quant.py b/benchmarks/kernels/benchmark_quant.py new file mode 100644 index 0000000000000000000000000000000000000000..743a5744e8614fb014c138961d729e063d57d661 --- /dev/null +++ b/benchmarks/kernels/benchmark_quant.py @@ -0,0 +1,100 @@ +import time + +import torch + +from vllm import _custom_ops as ops +from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser, + seed_everything) + + +@torch.inference_mode() +def main(num_tokens: int, + hidden_size: int, + static_scale: bool, + quant_dtype: torch.dtype, + dtype: torch.dtype, + seed: int = 0, + do_profile: bool = False, + num_warmup_iters: int = 5, + num_iters: int = 100) -> None: + seed_everything(seed) + torch.set_default_device("cuda") + + x = torch.randn(num_tokens, hidden_size, dtype=dtype) + scale = torch.randn(1, 1, dtype=torch.float32) if static_scale else None + + def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: + torch.cuda.synchronize() + if profile: + torch.cuda.cudart().cudaProfilerStart() + start_time = time.perf_counter() + + for _ in range(num_iters): + if quant_dtype == torch.int8: + ops.scaled_int8_quant(x, scale) + else: + ops.scaled_fp8_quant(x, scale) + torch.cuda.synchronize() + + end_time = time.perf_counter() + if profile: + torch.cuda.cudart().cudaProfilerStart() + return (end_time - start_time) / num_iters + + # Warmup. + print("Warming up...") + run_benchmark = run_cuda_benchmark + run_benchmark(num_iters=num_warmup_iters, profile=False) + + # Benchmark. + if do_profile: + latency = run_benchmark(num_iters=1, profile=True) + else: + latency = run_benchmark(num_iters=num_iters, profile=False) + print(f"Kernel running time: {latency * 1000000:.3f} us") + + +if __name__ == '__main__': + + def to_torch_dtype(dt): + if dt == "int8": + return torch.int8 + if dt == "fp8": + return torch.float8_e4m3fn + raise ValueError(f"Unsupported dtype: {dt}") + + parser = FlexibleArgumentParser( + description="Benchmark the quantization (fp8 or int8) kernel.") + parser.add_argument("--num-tokens", type=int, default=4096) + parser.add_argument("--hidden-size", type=int, default=8192) + parser.add_argument("--static-scale", action="store_true") + parser.add_argument("--quant-dtype", + type=str, + choices=["fp8", "int8"], + default="int8") + parser.add_argument("--dtype", + type=str, + choices=["half", "bfloat16", "float"], + default="half") + + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--profile", action="store_true") + parser.add_argument("--num-warmup-iters", type=int, default=5) + parser.add_argument("--num-iters", + type=int, + default=100, + help="Number of benchmark iterations. " + "If --profile is set, this number is ignored") + + args = parser.parse_args() + print(args) + + main(num_tokens=args.num_tokens, + hidden_size=args.hidden_size, + static_scale=args.static_scale, + quant_dtype=to_torch_dtype(args.quant_dtype), + dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype], + seed=args.seed, + do_profile=args.profile, + num_warmup_iters=args.num_warmup_iters, + num_iters=args.num_iters) diff --git a/benchmarks/kernels/benchmark_rope.py b/benchmarks/kernels/benchmark_rope.py index f542684a9a2a9827c32adb8c0db6bd107919795f..784b1cf9844e4e63e85683fa96499eba1f2efea3 100644 --- a/benchmarks/kernels/benchmark_rope.py +++ b/benchmarks/kernels/benchmark_rope.py @@ -6,7 +6,7 @@ import torch from vllm.model_executor.layers.rotary_embedding import (RotaryEmbedding, get_rope) -from vllm.utils import FlexibleArgumentParser +from vllm.utils import FlexibleArgumentParser, seed_everything def benchmark_rope_kernels_multi_lora( @@ -22,9 +22,7 @@ def benchmark_rope_kernels_multi_lora( max_position: int = 8192, base: int = 10000, ) -> None: - torch.random.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) + seed_everything(seed) torch.set_default_device(device) if rotary_dim is None: rotary_dim = head_size @@ -33,7 +31,7 @@ def benchmark_rope_kernels_multi_lora( # batched RoPE can take multiple scaling factors batched_rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style, { - "type": "linear", + "rope_type": "linear", "factor": tuple(scaling_factors) }) # non-batched RoPE takes only one scaling factor, we create multiple @@ -43,7 +41,7 @@ def benchmark_rope_kernels_multi_lora( non_batched_ropes.append( get_rope(head_size, rotary_dim, max_position, base, is_neox_style, { - "type": "linear", + "rope_type": "linear", "factor": (scaling_factor, ) })) diff --git a/benchmarks/kernels/graph_machete_bench.py b/benchmarks/kernels/graph_machete_bench.py new file mode 100644 index 0000000000000000000000000000000000000000..de608fd05af708762c3179104108cc0ecec27976 --- /dev/null +++ b/benchmarks/kernels/graph_machete_bench.py @@ -0,0 +1,62 @@ +import math +import pickle +import re +from collections import defaultdict +from typing import List + +import matplotlib.pyplot as plt +import pandas as pd +import seaborn as sns +from torch.utils.benchmark import Measurement as TMeasurement + +from vllm.utils import FlexibleArgumentParser + +if __name__ == "__main__": + parser = FlexibleArgumentParser( + description='Benchmark the latency of processing a single batch of ' + 'requests till completion.') + parser.add_argument('filename', type=str) + + args = parser.parse_args() + + with open(args.filename, 'rb') as f: + data: List[TMeasurement] = pickle.load(f) + + results = defaultdict(lambda: list()) + for v in data: + result = re.search(r"MKN=\(\d+x(\d+x\d+)\)", v.task_spec.sub_label) + if result is not None: + KN = result.group(1) + else: + raise Exception("MKN not found") + result = re.search(r"MKN=\((\d+)x\d+x\d+\)", v.task_spec.sub_label) + if result is not None: + M = result.group(1) + else: + raise Exception("MKN not found") + + kernel = v.task_spec.description + results[KN].append({ + "kernel": kernel, + "batch_size": M, + "median": v.median + }) + + rows = int(math.ceil(len(results) / 2)) + fig, axs = plt.subplots(rows, 2, figsize=(12, 5 * rows)) + axs = axs.flatten() + for axs_idx, (shape, data) in enumerate(results.items()): + plt.sca(axs[axs_idx]) + df = pd.DataFrame(data) + sns.lineplot(data=df, + x="batch_size", + y="median", + hue="kernel", + style="kernel", + markers=True, + dashes=False, + palette="Dark2") + plt.title(f"Shape: {shape}") + plt.ylabel("time (median, s)") + plt.tight_layout() + plt.savefig("graph_machete_bench.pdf") diff --git a/benchmarks/kernels/requirements.txt b/benchmarks/kernels/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..1411a4a0b5ab886adfb744e685d150151ab10023 --- /dev/null +++ b/benchmarks/kernels/requirements.txt @@ -0,0 +1 @@ +pandas \ No newline at end of file diff --git a/benchmarks/kernels/weight_shapes.py b/benchmarks/kernels/weight_shapes.py new file mode 100644 index 0000000000000000000000000000000000000000..25ec9d602862742369152ad31a7b90d171980431 --- /dev/null +++ b/benchmarks/kernels/weight_shapes.py @@ -0,0 +1,43 @@ +# Weight Shapes are in the format +# ([K, N], TP_SPLIT_DIM) +# Example: +# A shape of ([14336, 4096], 0) indicates the following GEMM shape, +# - TP1 : K = 14336, N = 4096 +# - TP2 : K = 7168, N = 4096 +# A shape of ([4096, 6144], 1) indicates the following GEMM shape, +# - TP1 : K = 4096, N = 6144 +# - TP4 : K = 4096, N = 1536 + +# TP1 shapes +WEIGHT_SHAPES = { + "mistralai/Mistral-7B-v0.1": [ + ([4096, 6144], 1), + ([4096, 4096], 0), + ([4096, 28672], 1), + ([14336, 4096], 0), + ], + "meta-llama/Llama-2-7b-hf": [ + ([4096, 12288], 1), + ([4096, 4096], 0), + ([4096, 22016], 1), + ([11008, 4096], 0), + ], + "meta-llama/Llama-3-8b": [ + ([4096, 6144], 1), + ([4096, 4096], 0), + ([4096, 28672], 1), + ([14336, 4096], 0), + ], + "meta-llama/Llama-2-13b-hf": [ + ([5120, 15360], 1), + ([5120, 5120], 0), + ([5120, 27648], 1), + ([13824, 5120], 0), + ], + "meta-llama/Llama-2-70b-hf": [ + ([8192, 10240], 1), + ([8192, 8192], 0), + ([8192, 57344], 1), + ([28672, 8192], 0), + ], +} diff --git a/benchmarks/launch_tgi_server.sh b/benchmarks/launch_tgi_server.sh index f491c90d0683e6497a4b63cc62b678c1ce2f84a0..8c5cd454fbbee09e157ea428352446899292ef7e 100755 --- a/benchmarks/launch_tgi_server.sh +++ b/benchmarks/launch_tgi_server.sh @@ -6,7 +6,7 @@ TOKENS=$2 docker run -e HF_TOKEN=$HF_TOKEN --gpus all --shm-size 1g -p $PORT:80 \ -v $PWD/data:/data \ - ghcr.io/huggingface/text-generation-inference:1.4.0 \ + ghcr.io/huggingface/text-generation-inference:2.2.0 \ --model-id $MODEL \ --sharded false \ --max-input-length 1024 \ diff --git a/benchmarks/overheads/benchmark_hashing.py b/benchmarks/overheads/benchmark_hashing.py index 203699e9a8d0607569601daee74aca2845ac0386..d16d6f9fba44213c1a85d2c0ce12d4ac5bd54b84 100644 --- a/benchmarks/overheads/benchmark_hashing.py +++ b/benchmarks/overheads/benchmark_hashing.py @@ -16,7 +16,6 @@ def main(args): enforce_eager=True, enable_prefix_caching=True, tensor_parallel_size=args.tensor_parallel_size, - use_v2_block_manager=args.use_v2_block_manager, ) sampling_params = SamplingParams(temperature=0, max_tokens=args.output_len) @@ -56,8 +55,5 @@ if __name__ == "__main__": parser.add_argument('--enable-prefix-caching', action='store_true', help='enable prefix caching') - parser.add_argument('--use-v2-block-manager', - action='store_true', - help='Use BlockSpaceMangerV2') args = parser.parse_args() main(args) diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake index 3ba3a2b6a93cddbbdccc5707cc09ce31de249238..7237d246ddf552143809f081d21f59316fceddd7 100644 --- a/cmake/cpu_extension.cmake +++ b/cmake/cpu_extension.cmake @@ -1,3 +1,7 @@ +include(FetchContent) + +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_CXX_EXTENSIONS ON) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) # @@ -81,14 +85,39 @@ else() message(FATAL_ERROR "vLLM CPU backend requires AVX512 or AVX2 or Power9+ ISA support.") endif() -message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}") - -list(APPEND LIBS "numa") - - # -# Define extension targets +# Build oneDNN for W8A8 GEMM kernels (only for x86-AVX512 platforms) # +if (AVX512_FOUND AND NOT AVX512_DISABLED) + FetchContent_Declare( + oneDNN + GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git + GIT_TAG v3.5.3 + GIT_PROGRESS TRUE + GIT_SHALLOW TRUE + ) + + set(ONEDNN_LIBRARY_TYPE "STATIC") + set(ONEDNN_BUILD_DOC "OFF") + set(ONEDNN_BUILD_EXAMPLES "OFF") + set(ONEDNN_BUILD_TESTS "OFF") + set(ONEDNN_ENABLE_WORKLOAD "INFERENCE") + set(ONEDNN_ENABLE_PRIMITIVE "MATMUL;REORDER") + set(ONEDNN_BUILD_GRAPH "OFF") + set(ONEDNN_ENABLE_JIT_PROFILING "OFF") + set(ONEDNN_ENABLE_ITT_TASKS "OFF") + set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF") + set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF") + set(CMAKE_POLICY_DEFAULT_CMP0077 NEW) + + FetchContent_MakeAvailable(oneDNN) + + list(APPEND LIBS dnnl) +endif() + +message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}") + +list(APPEND LIBS numa) # # _C extension @@ -102,6 +131,16 @@ set(VLLM_EXT_SRC "csrc/cpu/pos_encoding.cpp" "csrc/cpu/torch_bindings.cpp") +if (AVX512_FOUND AND NOT AVX512_DISABLED) + set(VLLM_EXT_SRC + "csrc/cpu/quant.cpp" + ${VLLM_EXT_SRC}) +endif() + +# +# Define extension targets +# + define_gpu_extension_target( _C DESTINATION vllm @@ -114,4 +153,3 @@ define_gpu_extension_target( ) message(STATUS "Enabling C extension.") -add_dependencies(default _C) diff --git a/cmake/utils.cmake b/cmake/utils.cmake index fc75fabb6d39bffe5e8cc842997276b7e0a3ea7b..6bb19f42cdbf0831edf40ce7c1966ab2a6fe6a93 100644 --- a/cmake/utils.cmake +++ b/cmake/utils.cmake @@ -138,10 +138,181 @@ macro(string_to_ver OUT_VER IN_STR) string(REGEX REPLACE "\([0-9]+\)\([0-9]\)" "\\1.\\2" ${OUT_VER} ${IN_STR}) endmacro() +# +# Clear all `-gencode` flags from `CMAKE_CUDA_FLAGS` and store them in +# `CUDA_ARCH_FLAGS`. +# +# Example: +# CMAKE_CUDA_FLAGS="-Wall -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75" +# clear_cuda_arches(CUDA_ARCH_FLAGS) +# CUDA_ARCH_FLAGS="-gencode arch=compute_70,code=sm_70;-gencode arch=compute_75,code=sm_75" +# CMAKE_CUDA_FLAGS="-Wall" +# +macro(clear_cuda_arches CUDA_ARCH_FLAGS) + # Extract all `-gencode` flags from `CMAKE_CUDA_FLAGS` + string(REGEX MATCHALL "-gencode arch=[^ ]+" CUDA_ARCH_FLAGS + ${CMAKE_CUDA_FLAGS}) + + # Remove all `-gencode` flags from `CMAKE_CUDA_FLAGS` since they will be modified + # and passed back via the `CUDA_ARCHITECTURES` property. + string(REGEX REPLACE "-gencode arch=[^ ]+ *" "" CMAKE_CUDA_FLAGS + ${CMAKE_CUDA_FLAGS}) +endmacro() + +# +# Extract unique CUDA architectures from a list of compute capabilities codes in +# the form `[]`, convert them to the form sort +# `.`, dedupes them and then sorts them in ascending order and +# stores them in `OUT_ARCHES`. +# +# Example: +# CUDA_ARCH_FLAGS="-gencode arch=compute_75,code=sm_75;...;-gencode arch=compute_90a,code=sm_90a" +# extract_unique_cuda_archs_ascending(OUT_ARCHES CUDA_ARCH_FLAGS) +# OUT_ARCHES="7.5;...;9.0" +function(extract_unique_cuda_archs_ascending OUT_ARCHES CUDA_ARCH_FLAGS) + set(_CUDA_ARCHES) + foreach(_ARCH ${CUDA_ARCH_FLAGS}) + string(REGEX MATCH "arch=compute_\([0-9]+a?\)" _COMPUTE ${_ARCH}) + if (_COMPUTE) + set(_COMPUTE ${CMAKE_MATCH_1}) + endif() + + string_to_ver(_COMPUTE_VER ${_COMPUTE}) + list(APPEND _CUDA_ARCHES ${_COMPUTE_VER}) + endforeach() + + list(REMOVE_DUPLICATES _CUDA_ARCHES) + list(SORT _CUDA_ARCHES COMPARE NATURAL ORDER ASCENDING) + set(${OUT_ARCHES} ${_CUDA_ARCHES} PARENT_SCOPE) +endfunction() + +# +# For a specific file set the `-gencode` flag in compile options conditionally +# for the CUDA language. +# +# Example: +# set_gencode_flag_for_srcs( +# SRCS "foo.cu" +# ARCH "compute_75" +# CODE "sm_75") +# adds: "-gencode arch=compute_75,code=sm_75" to the compile options for +# `foo.cu` (only for the CUDA language). +# +macro(set_gencode_flag_for_srcs) + set(options) + set(oneValueArgs ARCH CODE) + set(multiValueArgs SRCS) + cmake_parse_arguments(arg "${options}" "${oneValueArgs}" + "${multiValueArgs}" ${ARGN} ) + set(_FLAG -gencode arch=${arg_ARCH},code=${arg_CODE}) + set_property( + SOURCE ${arg_SRCS} + APPEND PROPERTY + COMPILE_OPTIONS "$<$:${_FLAG}>" + ) + + message(DEBUG "Setting gencode flag for ${arg_SRCS}: ${_FLAG}") +endmacro(set_gencode_flag_for_srcs) + +# +# For a list of source files set the `-gencode` flags in the files specific +# compile options (specifically for the CUDA language). +# +# arguments are: +# SRCS: list of source files +# CUDA_ARCHS: list of CUDA architectures in the form `.[letter]` +# BUILD_PTX_FOR_ARCH: if set to true, then the PTX code will be built +# for architecture `BUILD_PTX_FOR_ARCH` if there is a CUDA_ARCH in CUDA_ARCHS +# that is larger than BUILD_PTX_FOR_ARCH. +# +macro(set_gencode_flags_for_srcs) + set(options) + set(oneValueArgs BUILD_PTX_FOR_ARCH) + set(multiValueArgs SRCS CUDA_ARCHS) + cmake_parse_arguments(arg "${options}" "${oneValueArgs}" + "${multiValueArgs}" ${ARGN} ) + + foreach(_ARCH ${arg_CUDA_ARCHS}) + string(REPLACE "." "" _ARCH "${_ARCH}") + set_gencode_flag_for_srcs( + SRCS ${arg_SRCS} + ARCH "compute_${_ARCH}" + CODE "sm_${_ARCH}") + endforeach() + + if (${arg_BUILD_PTX_FOR_ARCH}) + list(SORT arg_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING) + list(GET arg_CUDA_ARCHS -1 _HIGHEST_ARCH) + if (_HIGHEST_ARCH VERSION_GREATER_EQUAL ${arg_BUILD_PTX_FOR_ARCH}) + string(REPLACE "." "" _PTX_ARCH "${arg_BUILD_PTX_FOR_ARCH}") + set_gencode_flag_for_srcs( + SRCS ${arg_SRCS} + ARCH "compute_${_PTX_ARCH}" + CODE "compute_${_PTX_ARCH}") + endif() + endif() +endmacro() + +# +# For the given `SRC_CUDA_ARCHS` list of gencode versions in the form +# `.[letter]` compute the "loose intersection" with the +# `TGT_CUDA_ARCHS` list of gencodes. +# The loose intersection is defined as: +# { max{ x \in tgt | x <= y } | y \in src, { x \in tgt | x <= y } != {} } +# where `<=` is the version comparison operator. +# In other words, for each version in `TGT_CUDA_ARCHS` find the highest version +# in `SRC_CUDA_ARCHS` that is less or equal to the version in `TGT_CUDA_ARCHS`. +# We have special handling for 9.0a, if 9.0a is in `SRC_CUDA_ARCHS` and 9.0 is +# in `TGT_CUDA_ARCHS` then we should remove 9.0a from `SRC_CUDA_ARCHS` and add +# 9.0a to the result. +# The result is stored in `OUT_CUDA_ARCHS`. +# +# Example: +# SRC_CUDA_ARCHS="7.5;8.0;8.6;9.0;9.0a" +# TGT_CUDA_ARCHS="8.0;8.9;9.0" +# cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS) +# OUT_CUDA_ARCHS="8.0;8.6;9.0;9.0a" +# +function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS) + list(REMOVE_DUPLICATES SRC_CUDA_ARCHS) + + # if 9.0a is in SRC_CUDA_ARCHS and 9.0 is in CUDA_ARCHS then we should + # remove 9.0a from SRC_CUDA_ARCHS and add 9.0a to _CUDA_ARCHS + set(_CUDA_ARCHS) + if ("9.0a" IN_LIST SRC_CUDA_ARCHS) + list(REMOVE_ITEM SRC_CUDA_ARCHS "9.0a") + if ("9.0" IN_LIST TGT_CUDA_ARCHS) + set(_CUDA_ARCHS "9.0a") + endif() + endif() + + list(SORT SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING) + + # for each ARCH in CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that is + # less or eqault to ARCH + foreach(_ARCH ${CUDA_ARCHS}) + set(_TMP_ARCH) + foreach(_SRC_ARCH ${SRC_CUDA_ARCHS}) + if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH) + set(_TMP_ARCH ${_SRC_ARCH}) + else() + break() + endif() + endforeach() + if (_TMP_ARCH) + list(APPEND _CUDA_ARCHS ${_TMP_ARCH}) + endif() + endforeach() + + list(REMOVE_DUPLICATES _CUDA_ARCHS) + set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE) +endfunction() + # # Override the GPU architectures detected by cmake/torch and filter them by # `GPU_SUPPORTED_ARCHES`. Sets the final set of architectures in -# `GPU_ARCHES`. +# `GPU_ARCHES`. This only applies to the HIP language since for CUDA we set +# the architectures on a per file basis. # # Note: this is defined as a macro since it updates `CMAKE_CUDA_FLAGS`. # @@ -179,109 +350,7 @@ macro(override_gpu_arches GPU_ARCHES GPU_LANG GPU_SUPPORTED_ARCHES) "None of the detected ROCm architectures: ${HIP_ARCHITECTURES} is" " supported. Supported ROCm architectures are: ${_GPU_SUPPORTED_ARCHES_LIST}.") endif() - - elseif(${GPU_LANG} STREQUAL "CUDA") - # - # Setup/process CUDA arch flags. - # - # The torch cmake setup hardcodes the detected architecture flags in - # `CMAKE_CUDA_FLAGS`. Since `CMAKE_CUDA_FLAGS` is a "global" variable, it - # can't modified on a per-target basis. - # So, all the `-gencode` flags need to be extracted and removed from - # `CMAKE_CUDA_FLAGS` for processing so they can be passed by another method. - # Since it's not possible to use `target_compiler_options` for adding target - # specific `-gencode` arguments, the target's `CUDA_ARCHITECTURES` property - # must be used instead. This requires repackaging the architecture flags - # into a format that cmake expects for `CUDA_ARCHITECTURES`. - # - # This is a bit fragile in that it depends on torch using `-gencode` as opposed - # to one of the other nvcc options to specify architectures. - # - # Note: torch uses the `TORCH_CUDA_ARCH_LIST` environment variable to override - # detected architectures. - # - message(DEBUG "initial CMAKE_CUDA_FLAGS: ${CMAKE_CUDA_FLAGS}") - - # Extract all `-gencode` flags from `CMAKE_CUDA_FLAGS` - string(REGEX MATCHALL "-gencode arch=[^ ]+" _CUDA_ARCH_FLAGS - ${CMAKE_CUDA_FLAGS}) - - # Remove all `-gencode` flags from `CMAKE_CUDA_FLAGS` since they will be modified - # and passed back via the `CUDA_ARCHITECTURES` property. - string(REGEX REPLACE "-gencode arch=[^ ]+ *" "" CMAKE_CUDA_FLAGS - ${CMAKE_CUDA_FLAGS}) - - # If this error is triggered, it might mean that torch has changed how it sets - # up nvcc architecture code generation flags. - if (NOT _CUDA_ARCH_FLAGS) - message(FATAL_ERROR - "Could not find any architecture related code generation flags in " - "CMAKE_CUDA_FLAGS. (${CMAKE_CUDA_FLAGS})") - endif() - - message(DEBUG "final CMAKE_CUDA_FLAGS: ${CMAKE_CUDA_FLAGS}") - message(DEBUG "arch flags: ${_CUDA_ARCH_FLAGS}") - - # Initialize the architecture lists to empty. - set(${GPU_ARCHES}) - - # Process each `gencode` flag. - foreach(_ARCH ${_CUDA_ARCH_FLAGS}) - # For each flag, extract the version number and whether it refers to PTX - # or native code. - # Note: if a regex matches then `CMAKE_MATCH_1` holds the binding - # for that match. - - string(REGEX MATCH "arch=compute_\([0-9]+a?\)" _COMPUTE ${_ARCH}) - if (_COMPUTE) - set(_COMPUTE ${CMAKE_MATCH_1}) - endif() - - string(REGEX MATCH "code=sm_\([0-9]+a?\)" _SM ${_ARCH}) - if (_SM) - set(_SM ${CMAKE_MATCH_1}) - endif() - - string(REGEX MATCH "code=compute_\([0-9]+a?\)" _CODE ${_ARCH}) - if (_CODE) - set(_CODE ${CMAKE_MATCH_1}) - endif() - - # Make sure the virtual architecture can be matched. - if (NOT _COMPUTE) - message(FATAL_ERROR - "Could not determine virtual architecture from: ${_ARCH}.") - endif() - - # One of sm_ or compute_ must exist. - if ((NOT _SM) AND (NOT _CODE)) - message(FATAL_ERROR - "Could not determine a codegen architecture from: ${_ARCH}.") - endif() - - if (_SM) - # -real suffix let CMake to only generate elf code for the kernels. - # we want this, otherwise the added ptx (default) will increase binary size. - set(_VIRT "-real") - set(_CODE_ARCH ${_SM}) - else() - # -virtual suffix let CMake to generate ptx code for the kernels. - set(_VIRT "-virtual") - set(_CODE_ARCH ${_CODE}) - endif() - - # Check if the current version is in the supported arch list. - string_to_ver(_CODE_VER ${_CODE_ARCH}) - if (NOT _CODE_VER IN_LIST _GPU_SUPPORTED_ARCHES_LIST) - message(STATUS "discarding unsupported CUDA arch ${_VER}.") - continue() - endif() - - # Add it to the arch list. - list(APPEND ${GPU_ARCHES} "${_CODE_ARCH}${_VIRT}") - endforeach() endif() - message(STATUS "${GPU_LANG} target arches: ${${GPU_ARCHES}}") endmacro() # @@ -355,17 +424,19 @@ function (define_gpu_extension_target GPU_MOD_NAME) target_include_directories(${GPU_MOD_NAME} PRIVATE csrc ${GPU_INCLUDE_DIRECTORIES}) - target_link_libraries(${GPU_MOD_NAME} PRIVATE torch ${torch_python_LIBRARY} - ${GPU_LIBRARIES}) + target_link_libraries(${GPU_MOD_NAME} PRIVATE torch ${GPU_LIBRARIES}) # Don't use `TORCH_LIBRARIES` for CUDA since it pulls in a bunch of # dependencies that are not necessary and may not be installed. if (GPU_LANGUAGE STREQUAL "CUDA") + if ("${CUDA_CUDA_LIB}" STREQUAL "") + set(CUDA_CUDA_LIB "${CUDA_CUDA_LIBRARY}") + endif() target_link_libraries(${GPU_MOD_NAME} PRIVATE ${CUDA_CUDA_LIB} ${CUDA_LIBRARIES}) else() target_link_libraries(${GPU_MOD_NAME} PRIVATE ${TORCH_LIBRARIES}) endif() - install(TARGETS ${GPU_MOD_NAME} LIBRARY DESTINATION ${GPU_DESTINATION}) + install(TARGETS ${GPU_MOD_NAME} LIBRARY DESTINATION ${GPU_DESTINATION} COMPONENT ${GPU_MOD_NAME}) endfunction() \ No newline at end of file diff --git a/collect_env.py b/collect_env.py index 244e4ddd5aed57df9c941f2247cce689b379c90d..80403d576d78f906181289cb52d90f3b07daa027 100644 --- a/collect_env.py +++ b/collect_env.py @@ -66,6 +66,8 @@ DEFAULT_CONDA_PATTERNS = { "nccl", "transformers", "zmq", + "nvidia", + "pynvml", } DEFAULT_PIP_PATTERNS = { @@ -79,6 +81,8 @@ DEFAULT_PIP_PATTERNS = { "nccl", "transformers", "zmq", + "nvidia", + "pynvml", } @@ -263,12 +267,16 @@ def get_neuron_sdk_version(run_lambda): def get_vllm_version(): - try: - import vllm - return vllm.__version__ - except ImportError: - return 'N/A' + from vllm import __version__, __version_tuple__ + + if __version__ == "dev": + return "N/A (dev)" + if len(__version_tuple__) == 4: # dev build + git_sha = __version_tuple__[-1][1:] # type: ignore + return f"{__version__} (git sha: {git_sha}" + + return __version__ def summarize_vllm_build_flags(): # This could be a static method if the flags are constant, or dynamic if you need to check environment variables, etc. @@ -280,9 +288,14 @@ def summarize_vllm_build_flags(): def get_gpu_topo(run_lambda): + output = None + if get_platform() == 'linux': - return run_and_read_all(run_lambda, 'nvidia-smi topo -m') - return None + output = run_and_read_all(run_lambda, 'nvidia-smi topo -m') + if output is None: + output = run_and_read_all(run_lambda, 'rocm-smi --showtopo') + + return output # example outputs of CPU infos diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index d15dc2e9a657cf63129fc907e82a133870d543db..53b55307f855637b31389a9dfcea9b7bc39b664a 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -757,6 +757,9 @@ void paged_attention_v1_launcher( case 128: LAUNCH_PAGED_ATTENTION_V1(128); break; + case 160: + LAUNCH_PAGED_ATTENTION_V1(160); + break; case 192: LAUNCH_PAGED_ATTENTION_V1(192); break; @@ -921,6 +924,9 @@ void paged_attention_v2_launcher( case 128: LAUNCH_PAGED_ATTENTION_V2(128); break; + case 160: + LAUNCH_PAGED_ATTENTION_V2(160); + break; case 192: LAUNCH_PAGED_ATTENTION_V2(192); break; diff --git a/csrc/attention/attention_kernels_opt_tc.cu b/csrc/attention/attention_kernels_opt_tc.cu new file mode 100644 index 0000000000000000000000000000000000000000..62d0b4ac3eae73da22be0e600e172b5f18568ecd --- /dev/null +++ b/csrc/attention/attention_kernels_opt_tc.cu @@ -0,0 +1,1160 @@ +#include +#include +#include +#include + +#include "attention_dtypes.h" +#include "attention_utils.cuh" + +#ifdef USE_ROCM + #include + #include "../quantization/fp8/amd/quant_utils.cuh" +typedef __hip_bfloat16 __nv_bfloat16; +#else + #include "../quantization/fp8/nvidia/quant_utils.cuh" +#endif + +#ifndef USE_ROCM + #define WARP_SIZE 32 +#else + #define WARP_SIZE warpSize +#endif + +#include "static_switch_tc.h" +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) + +inline std::string get_device_name() +{ + hipDeviceProp_t props{}; + int device; + auto status = hipGetDevice(&device); + if(status != hipSuccess) + { + return std::string(); + } + + status = hipGetDeviceProperties(&props, device); + if(status != hipSuccess) + { + return std::string(); + } + const std::string raw_name(props.gcnArchName); + return raw_name.substr(0, raw_name.find(':')); // str.substr(0, npos) returns str. +} +static inline int get_env_(const char *env_var) { + if (char *value = std::getenv(env_var)) { + return atoi(value); + } + return 0; +} + +static const int PA_REUSE_KV_TIMES = get_env_("PA_REUSE_KV_TIMES"); +static const int PA_BLOCK_SIZE = get_env_("PA_BLOCK_SIZE"); +static const int PA_PRINT_PARAM = get_env_("PA_PRINT_PARAM"); +namespace vllm { + +// Utility function for attention softmax. +template +inline __device__ float block_sum(float* red_smem, float sum) { + // Decompose the thread index into warp / lane. + int warp = threadIdx.x / WARP_SIZE; + int lane = threadIdx.x % WARP_SIZE; + + // Compute the sum per warp. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + sum += VLLM_SHFL_XOR_SYNC(sum, mask); + } + + // Warp leaders store the data to shared memory. + if (lane == 0) { + red_smem[warp] = sum; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // The warps compute the final sums. + if (lane < NUM_WARPS) { + sum = red_smem[lane]; + } + + // Parallel reduction inside the warp. +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + sum += VLLM_SHFL_XOR_SYNC(sum, mask); + } + + // Broadcast to other threads. + return VLLM_SHFL_SYNC(sum, 0); +} + +using half4_t = __attribute__( (__vector_size__(4 * sizeof(_Float16)) )) _Float16; +using v4bh = __attribute__( (__vector_size__(4 * sizeof(short)) )) short; +using float4_t = __attribute__( (__vector_size__(4 * sizeof(float)) )) float; +struct half4x2{ + half4_t data[2]; +}; + +template +inline __device__ void float4_2_half4(half4_t& dst,const float4_t& src) +{ + if constexpr(is_half){ + #pragma unroll + for(int i=0;i<4;i++){ + dst[i]=src[i]; + } + } + else{ + __nv_bfloat16* out = reinterpret_cast<__nv_bfloat16 *>(&dst); + #pragma unroll + for(int i=0;i<4;i++){ + out[i]=__float2bfloat16(src[i]); + } + } +} + +template +inline __device__ void v_mmac_f32_16x16x16_f16(const half4_t& reg_a, const half4_t& reg_b, float4_t& reg_c) +{ + + if constexpr (is_half){ + asm volatile("v_mmac_f32_16x16x16_f16 %0, %1, %2, %0" : + "=v"(reg_c) : "v"(reg_a), "v"(reg_b), "0"(reg_c)); + } + else{ + asm volatile("v_mmac_f32_16x16x16_bf16 %0, %1, %2, %0" : + "=v"(reg_c) : "v"(reg_a), "v"(reg_b), "0"(reg_c)); + } +} + +template +inline __device__ void builtin_amdgcn_mmac(const half4_t& reg_a, const half4_t& reg_b, float4_t& reg_c) +{ + if constexpr (use_vmac){v_mmac_f32_16x16x16_f16(reg_a,reg_b,reg_c);} + else{ + if constexpr (is_half){reg_c=__builtin_amdgcn_mmac_f32_16x16x16f16(reg_a,reg_b,reg_c);} + else{ + reg_c=__builtin_amdgcn_mmac_f32_16x16x16bf16(*(v4bh*)®_a,*(v4bh*)®_b,reg_c); + } + } +} + +// TODO(woosuk): Merge the last two dimensions of the grid. +// Grid: (num_heads, num_seqs, max_num_partitions). +template // Zero means no partitioning. +__device__ void paged_attention_kernel_TC( + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_heads, + const int num_kv_heads, // [num_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const float k_scale, const float v_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { + const int seq_idx = blockIdx.z; + const int partition_idx = blockIdx.y; + const int max_num_partitions = gridDim.y; + constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0; + const int seq_len = __builtin_amdgcn_readfirstlane(seq_lens[seq_idx]); + if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) { + // No work to do. Terminate the thread block. + return; + } + constexpr bool is_half = std::is_same::value; + static_assert(HEAD_SIZE<=4*NUM_THREADS,"HEAD_SIZE<=4*NUM_THREADS"); + const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE); + const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks; + const int partition_size = USE_PARTITIONING ? PARTITION_SIZE : num_seq_blocks * BLOCK_SIZE; + // [start_block_idx, end_block_idx) is the range of blocks to process. + const int start_block_idx = partition_idx * num_blocks_per_partition;//0,64,128… + const int end_block_idx =MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks);//64,128,192… + const int num_blocks = end_block_idx - start_block_idx;//64 or 1-63 + + // [start_token_idx, end_token_idx) is the range of tokens to process. + const int start_token_idx = start_block_idx * BLOCK_SIZE;//0,1024,2048… + const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len);//1024,2048,3072… + const int num_tokens = end_token_idx - start_token_idx;//1024 or 1-1023 + // divides NUM_THREADS + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;//4 + constexpr int x = 16 / sizeof(cache_t);//8 + const int thread_idx = threadIdx.x; + const int warp_idx = __builtin_amdgcn_readfirstlane(thread_idx / WARP_SIZE); + const int lane = thread_idx % WARP_SIZE; + const int rowid = lane%16; + const int rows = lane/16; + + const int num_queries_per_kv = num_heads / num_kv_heads; + const int num_blocks_per_kv = ((num_queries_per_kv + REUSE_KV_TIMES -1) / REUSE_KV_TIMES); + const int odd_tg_round = (((blockIdx.z * gridDim.y * gridDim.x) + blockIdx.y * gridDim.x) / 128) % 2; + const int mid_x = gridDim.x / 2; + const int blockIdx_shift = (odd_tg_round | (gridDim.x & 1)) ? blockIdx.x : (blockIdx.x < mid_x ? (blockIdx.x + mid_x) : (blockIdx.x - mid_x)); + const int head_idx = (blockIdx_shift / num_blocks_per_kv) * num_queries_per_kv + (blockIdx_shift % num_blocks_per_kv) * REUSE_KV_TIMES; + //const int head_idx=(blockIdx.x / num_blocks_per_kv) * num_queries_per_kv + (blockIdx.x % num_blocks_per_kv) * REUSE_KV_TIMES; + + int q_boundary=REUSE_KV_TIMES; + if(num_heads < REUSE_KV_TIMES*gridDim.x && (num_blocks_per_kv-1)*REUSE_KV_TIMES == head_idx%num_queries_per_kv) + q_boundary=num_queries_per_kv-(num_blocks_per_kv-1)*REUSE_KV_TIMES; + const int kv_head_idx = head_idx / num_queries_per_kv; + constexpr int reuse_group=(REUSE_KV_TIMES-1)/4+1; + float alibi_slope[reuse_group]={0.f}; + if(alibi_slopes != nullptr){ + for(int i=0;i(q_ptr+i*HEAD_SIZE+thread_idx*8); + } + } + __syncthreads(); + // Memory planning. + extern __shared__ char shared_mem[]; + // NOTE(woosuk): We use FP32 for the softmax logits for better accuracy. + scalar_t* logits = reinterpret_cast(shared_mem); + // Workspace for reduction. + __shared__ float red_smem[2 * NUM_WARPS]; + + // Iterate over the key blocks. + // Each warp fetches a block of keys for each iteration. + // Each thread group in a warp fetches a key from the block, and computes + // dot product with the query. + const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; + + // blocksparse specific vars + int bs_block_offset; + int q_bs_block_id; + const cache_t* k_ptr_base = k_cache+kv_head_idx * kv_head_stride+lane*8; + + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; + block_idx += NUM_WARPS) { + + const int64_t physical_block_number = static_cast(block_table[block_idx]); + const cache_t* k_ptr=k_ptr_base + physical_block_number * kv_block_stride; + float4_t qk_vec={0,0,0,0}; + + half4x2 k_vec[2]; + k_vec[0]=*reinterpret_cast(k_ptr); + #pragma unroll + for(int i=0;i<3;i++){ + if(rowid(k_ptr+(i+1)*512); + builtin_amdgcn_mmac(k_vec[i%2].data[0],q_vec.data[0],qk_vec); + builtin_amdgcn_mmac(k_vec[i%2].data[1],q_vec.data[1],qk_vec); + } + //tail + { + if(rowid(k_vec[1].data[0],q_vec.data[0],qk_vec); + v_mmac_f32_16x16x16_f16(k_vec[1].data[1],q_vec.data[1],qk_vec); + } + #pragma unroll + for(int i=0;i=q_boundary)qk_vec[i]=0; + else qk_vec[i]*=scale; + const int token_idx = block_idx * BLOCK_SIZE+rowid; + if(alibi_slope[i] != 0){ + float alibi=alibi_slope[i]* (token_idx - seq_len + 1); + qk_vec[i] += alibi; + } + const bool mask = (token_idx >= seq_len); + if(mask){ + from_float(logits[partition_size*reuse_kv_idx+token_idx - start_token_idx] , 0.f); + } + else{ + from_float(logits[partition_size*reuse_kv_idx+token_idx - start_token_idx] , qk_vec[i]); + qk_max[i] = fmaxf(qk_max[i], qk_vec[i]); + } + } + } + } + // if(blockIdx.x==0)printf("%d,qkmax=%f\n",threadIdx.x,qk_max[0]); + // Perform reduction across the threads in the same warp to get the + // max qk value for each "warp" (not across the thread block yet). + // The 0-th thread of each thread group already has its max qk value. + for(int reuse_kv_idx=0; reuse_kv_idx= 1; mask /= 2) { + qk_max_tmp = fmaxf(qk_max_tmp, VLLM_SHFL_XOR_SYNC(qk_max_tmp, mask)); + } + if (rowid==0 && reuse_kv_idx%4==rows) { + red_smem[warp_idx] = qk_max_tmp; + } + __syncthreads(); + + // TODO(woosuk): Refactor this part. + // Get the max qk value for the sequence. + qk_max_tmp = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; + #pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + qk_max_tmp = fmaxf(qk_max_tmp, VLLM_SHFL_XOR_SYNC(qk_max_tmp, mask)); + } + // Broadcast the max qk value to all threads. + qk_max_tmp = VLLM_SHFL_SYNC(qk_max_tmp, 0); + for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { + float val = __expf(to_float(logits[(reuse_kv_idx * partition_size) + i]) - qk_max_tmp); + from_float(logits[(reuse_kv_idx * partition_size) + i] , val); + exp_sum += val; + } + exp_sum = block_sum(&red_smem[NUM_WARPS], exp_sum); + // Compute softmax. + const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); + for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { + from_float(logits[(reuse_kv_idx * partition_size) + i] ,to_float(logits[(reuse_kv_idx * partition_size) + i])*inv_sum); + } + __syncthreads(); + + // If partitioning is enabled, store the max logit and exp_sum. + if (USE_PARTITIONING && thread_idx == 0) { + float* max_logits_ptr = max_logits + + seq_idx * num_heads * max_num_partitions + + head_idx_ * max_num_partitions + partition_idx; + *max_logits_ptr = qk_max_tmp; + float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions + + head_idx_ * max_num_partitions + partition_idx; + *exp_sums_ptr = exp_sum; + } + } + + constexpr int NUM_ROWS_PER_THREAD =DIVIDE_ROUND_UP(HEAD_SIZE, WARP_SIZE);//2 + if constexpr(REUSE_KV_TIMES<=2){ + float accs[REUSE_KV_TIMES][NUM_ROWS_PER_THREAD]; + #pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + #pragma unroll + for(int k=0;k(block_table[block_idx]); + const int token_idx = block_idx * BLOCK_SIZE +rows*4; + half4_t logits_vec={0,0,0,0}; + if(rowid<4*q_boundary){ + logits_vec=*reinterpret_cast(logits + rowid/4 * partition_size+token_idx - start_token_idx); + } + const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride + rows*4+rowid*16; + #pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + #pragma unroll + for(int k=0;k<4;k++){ + int offset=i*1024+k*256; + half4_t v_vec=*reinterpret_cast(v_ptr + offset); + if (block_idx == num_seq_blocks - 1) { + scalar_t* v_vec_ptr = reinterpret_cast(&v_vec); + #pragma unroll + for (int j = 0; j < 4; j++) { + v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value; + } + } + float4_t out_vec={0,0,0,0}; + builtin_amdgcn_mmac(v_vec,logits_vec,out_vec); + if(rows==k){ + for(int resuseid=0;resuseid64){ + floatV_t* out_smem = reinterpret_cast(shared_mem); + #pragma unroll + for (int i = NUM_WARPS; i > 1; i /= 2) { + int mid = i / 2; + // Upper warps write to shared memory. + if (warp_idx >= mid && warp_idx < i) { + out_smem[(warp_idx - mid) * 64+lane]=*(floatV_t*)(accs[reuse_kv_idx]); + } + __syncthreads(); + // Lower warps update the output. + if (warp_idx < mid) { + floatV_t tmp=out_smem[thread_idx]; + #pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + accs[reuse_kv_idx][i] += tmp[i]; + } + } + __syncthreads(); + } + } + // Write the final output. + if (warp_idx == 0) { + scalar_t* out_ptr = + out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + (head_idx+reuse_kv_idx) * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE; + #pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane + i * WARP_SIZE; + from_float(*(out_ptr + row_idx), accs[reuse_kv_idx][i]); + } + } + } + } + else{ + constexpr int GROUPS=reuse_group*4; + // NOTE(woosuk): We use FP32 for the accumulator for better accuracy. + float accs[GROUPS][NUM_ROWS_PER_THREAD]; + #pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + #pragma unroll + for(int k=0;k(block_table[block_idx]); + const int token_idx = block_idx * BLOCK_SIZE +rows*4; + half4_t logits_vec={0,0,0,0}; + if(rowid(logits + rowid * partition_size+token_idx - start_token_idx); + } + const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride + rows*4+rowid*16; + #pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + #pragma unroll + for(int k=0;k<4;k++){ + int offset=i*1024+k*256; + half4_t v_vec=*reinterpret_cast(v_ptr + offset); + if (block_idx == num_seq_blocks - 1) { + scalar_t* v_vec_ptr = reinterpret_cast(&v_vec); + #pragma unroll + for (int j = 0; j < 4; j++) { + v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value; + } + } + float4_t out_vec={0,0,0,0}; + builtin_amdgcn_mmac(v_vec,logits_vec,out_vec); + for(int g=0;g64){ + __syncthreads(); + using floatV_t = __attribute__( (__vector_size__(NUM_ROWS_PER_THREAD * sizeof(float)) )) float; + // Perform reduction across warps. + + for(int reuse_kv_idx=0; reuse_kv_idx(shared_mem); + #pragma unroll + for (int i = NUM_WARPS; i > 1; i /= 2) { + int mid = i / 2; + // Upper warps write to shared memory. + if (warp_idx >= mid && warp_idx < i) { + out_smem[(warp_idx - mid) * 64+lane]=*(floatV_t*)(accs[reuse_kv_idx]); + } + __syncthreads(); + // Lower warps update the output. + if (warp_idx < mid) { + floatV_t tmp=out_smem[thread_idx]; + #pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + accs[reuse_kv_idx][i] += tmp[i]; + } + } + __syncthreads(); + } + } + } + if (warp_idx == 0) { + for(int g=0;g +__global__ void paged_attention_v1_kernel_TC( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_heads, + const int num_kv_heads, // [num_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const float k_scale, const float v_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { + #ifdef __gfx928__ + paged_attention_kernel_TC( + /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache, + v_cache, num_heads,num_kv_heads, scale, block_tables, seq_lens, + max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, + kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks, + blocksparse_vert_stride, blocksparse_block_size, + blocksparse_head_sliding_step); + #endif + } + +// Grid: (num_heads, num_seqs, max_num_partitions). +template +__global__ __launch_bounds__(256, 1) void paged_attention_v2_kernel_TC( + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, + // max_num_partitions, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_heads, // [num_heads] + const int num_kv_heads, // [num_kv_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const float k_scale, const float v_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { + #ifdef __gfx928__ + paged_attention_kernel_TC( + exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_heads, + num_kv_heads, scale, block_tables, seq_lens, max_num_blocks_per_seq, + alibi_slopes, q_stride, kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, + blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, + blocksparse_head_sliding_step); + #endif +} + +// Grid: (num_heads, num_seqs). +template +__global__ __launch_bounds__(256, 1) void paged_attention_v2_reduce_kernel_opt_tc( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const float* __restrict__ exp_sums, // [num_seqs, num_heads, + // max_num_partitions] + const float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, + // max_num_partitions, head_size] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_partitions) { + const int num_heads = gridDim.x; + const int head_idx = blockIdx.x; + const int seq_idx = blockIdx.y; + const int seq_len = seq_lens[seq_idx]; + const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE); + if (num_partitions == 1) { + // No need to reduce. Only copy tmp_out to out. + scalar_t* out_ptr = + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + const scalar_t* tmp_out_ptr = + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE; + for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) { + out_ptr[i] = tmp_out_ptr[i]; + } + // Terminate the thread block. + return; + } + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + const int warp_idx = threadIdx.x / WARP_SIZE; + const int lane = threadIdx.x % WARP_SIZE; + + // Size: 2 * num_partitions. + extern __shared__ char shared_mem[]; + // Workspace for reduction. + __shared__ float red_smem[2 * NUM_WARPS]; + + // Load max logits to shared memory. + float* shared_max_logits = reinterpret_cast(shared_mem); + const float* max_logits_ptr = max_logits + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + float max_logit = -FLT_MAX; + for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { + const float l = max_logits_ptr[i]; + shared_max_logits[i] = l; + max_logit = fmaxf(max_logit, l); + } + __syncthreads(); + + // Get the global max logit. + // Reduce within the warp. + #pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask)); + } + if (lane == 0) { + red_smem[warp_idx] = max_logit; + } + __syncthreads(); + // Reduce across warps. + max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; + #pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask)); + } + // Broadcast the max value to all threads. + max_logit = VLLM_SHFL_SYNC(max_logit, 0); + + // Load rescaled exp sums to shared memory. + float* shared_exp_sums = + reinterpret_cast(shared_mem + sizeof(float) * num_partitions); + const float* exp_sums_ptr = exp_sums + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + float global_exp_sum = 0.0f; + for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { + float l = shared_max_logits[i]; + float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit); + global_exp_sum += rescaled_exp_sum; + shared_exp_sums[i] = rescaled_exp_sum; + } + __syncthreads(); + global_exp_sum = block_sum(&red_smem[NUM_WARPS], global_exp_sum); + const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f); + + // Aggregate tmp_out to out. + const scalar_t* tmp_out_ptr = + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE; + scalar_t* out_ptr = + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + #pragma unroll + for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) { + float acc = 0.0f; + for (int j = 0; j < num_partitions; ++j) { + acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * + inv_global_exp_sum; + } + from_float(out_ptr[i], acc); + } +} + +} // namespace vllm + + +#define LAUNCH_PAGED_ATTENTION_V1_TC(HEAD_SIZE) \ + VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ + ((void*)vllm::paged_attention_v1_kernel_TC), \ + shared_mem_size); \ + vllm::paged_attention_v1_kernel_TC \ + <<>>( \ + out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_heads,num_kv_heads, \ + scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ + alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ + k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ + blocksparse_vert_stride, blocksparse_block_size, \ + blocksparse_head_sliding_step); + +void get_numberthread_and_reuse_kv_v1(int& num_thread,int& reusekv,int batchsize,int seq,int qheads,int kvheads){ + //mha + reusekv=1; + if(qheads==kvheads){ + //llama 7B ,其他模型未可知 + if(seq<=16||batchsize>=32)num_thread=64; + else if(batchsize<=2)num_thread=256; + else if(batchsize<8)num_thread=128; + else num_thread=64; + return; + } + // mqa + if(qheads>kvheads*4){ + if(seq<64){ + if(batchsize<=64){reusekv=1;num_thread=64;} + else if(batchsize<128){reusekv=2;num_thread=64;} + else {reusekv=4;num_thread=64;} + } + else if(seq<=400){ + if(batchsize<16){reusekv=1;num_thread=256;} + else if(batchsize<64){reusekv=2;num_thread=256;} + else if(batchsize<=128){ + reusekv=4; + if(qheads%7==0)num_thread=64;//qwen7b + else num_thread=256;//llama70b + } + else {reusekv=8;num_thread=64;} + } + else if(seq<=1000){ + if(batchsize<16){reusekv=1;num_thread=256;} + else if(qheads%7==0&&batchsize<=128){//qwen7b + if(batchsize<64){reusekv=4;num_thread=256;} + else{reusekv=4;num_thread=64;} + } + else if(batchsize<=64){reusekv=4;num_thread=256;} + else {reusekv=8;num_thread=128;} + } + else if(seq<3900) {reusekv=8;num_thread=256;} + else if(seq<7800) {reusekv=4;num_thread=256;} + else {reusekv=2;num_thread=256;} + return; + } + + if(qheads/kvheads >4 && seq<3900)reusekv=8; + else if(qheads/kvheads >2 && seq<7800)reusekv=4; + else if(qheads/kvheads >=2 && seq<15600)reusekv=2; + + if(seq<=64){ + num_thread=64; + if(batchsize<=64)reusekv=1; + } + else num_thread=256; +} + +// TODO(woosuk): Tune NUM_THREADS. +template +void paged_attention_v1_launcher_opt_tc( + torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, + const c10::optional& alibi_slopes, float k_scale, + float v_scale, const int tp_rank, const int blocksparse_local_blocks, + const int blocksparse_vert_stride, const int blocksparse_block_size, + const int blocksparse_head_sliding_step) { + int num_seqs = query.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); + int max_num_blocks_per_seq = block_tables.size(1); + int q_stride = query.stride(0); + int kv_block_stride = key_cache.stride(0); + int kv_head_stride = key_cache.stride(1); + int num_threads = 128; + // printf("paged_attention_v1\n"); + if (num_heads != num_kv_heads) { + num_threads = 256; + } + [[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); + assert(head_size % thread_group_size == 0); + + // NOTE: alibi_slopes is optional. + const float* alibi_slopes_ptr = + alibi_slopes + ? reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; + + T* out_ptr = reinterpret_cast(out.data_ptr()); + T* query_ptr = reinterpret_cast(query.data_ptr()); + CACHE_T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + CACHE_T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int* block_tables_ptr = block_tables.data_ptr(); + int* seq_lens_ptr = seq_lens.data_ptr(); + int padded_max_seq_len = DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; + const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + if constexpr(BLOCK_SIZE==16 && IS_BLOCK_SPARSE==false && sizeof(T)==2 && KV_DTYPE==vllm::Fp8KVCacheDataType::kAuto){ + constexpr int HEAD_SIZE=128; + constexpr static int use_vmac = false; + int reusekv, num_thread; + get_numberthread_and_reuse_kv_v1(num_thread,reusekv,num_seqs,padded_max_seq_len,num_heads,num_kv_heads); + if(PA_REUSE_KV_TIMES!=0&&num_heads>num_kv_heads)reusekv=PA_REUSE_KV_TIMES; + if(PA_BLOCK_SIZE!=0)num_thread=PA_BLOCK_SIZE; + REUSEKV_SWITCH(reusekv,[&] { + NUM_THREADS_SWITCH(num_thread , [&] { + //constexpr int NUM_THREADS = WARP_SIZE * REUSE_KV_TIMES; + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + int logits_size = REUSE_KV_TIMES * padded_max_seq_len * 2; + int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); + if (NUM_WARPS==64)outputs_size=0; + int shared_mem_size = ::max(logits_size, outputs_size); + dim3 grid((num_heads/num_kv_heads + REUSE_KV_TIMES - 1) / REUSE_KV_TIMES*num_kv_heads, 1,num_seqs); + dim3 block(NUM_THREADS); + if(PA_PRINT_PARAM)printf("reusekv=%d,num_thread=%d,grid={%d,%d,%d},qhead=%d,kvhead=%d,seq=%d,batch=%d\n", + reusekv,num_thread,grid.x,grid.y,grid.z,num_heads,num_kv_heads,max_seq_len,num_seqs); + LAUNCH_PAGED_ATTENTION_V1_TC(HEAD_SIZE); + }); + }); + } +} + +#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \ + paged_attention_v1_launcher_opt_tc( \ + out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ + seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \ + blocksparse_local_blocks, blocksparse_vert_stride, \ + blocksparse_block_size, blocksparse_head_sliding_step); + +#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ + switch (is_block_sparse) { \ + case true: \ + CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \ + break; \ + case false: \ + CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \ + break; \ + } + +// NOTE(woosuk): To reduce the compilation time, we omitted block sizes +// 1, 2, 4, 64, 128, 256. +#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \ + switch (block_size) { \ + case 8: \ + CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \ + break; \ + case 16: \ + CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \ + break; \ + case 32: \ + CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ + } + +void paged_attention_v1_opt( + torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& query, // [num_seqs, num_heads, head_size] + torch::Tensor& + key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& + value_cache, // [num_blocks, num_heads, head_size, block_size] + int64_t num_kv_heads, // [num_heads] + double scale, + torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor& seq_lens, // [num_seqs] + int64_t block_size, int64_t max_seq_len, + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, double k_scale, double v_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, + const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, + const int64_t blocksparse_head_sliding_step); + +void paged_attention_v1_opt_tc( + torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& query, // [num_seqs, num_heads, head_size] + torch::Tensor& + key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& + value_cache, // [num_blocks, num_heads, head_size, block_size] + int64_t num_kv_heads, // [num_heads] + double scale, + torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor& seq_lens, // [num_seqs] + int64_t block_size, int64_t max_seq_len, + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, double k_scale, double v_scale, + const int64_t tp_rank, const int64_t blocksparse_local_blocks, + const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, + const int64_t blocksparse_head_sliding_step) { + const bool is_block_sparse = (blocksparse_vert_stride > 1); + if(kv_cache_dtype != "auto"||query.dtype() == at::ScalarType::Float||is_block_sparse|| + block_size!=16||query.size(2)!=128||get_device_name()!="gfx928"){ + paged_attention_v1_opt(out,query,key_cache,value_cache,num_kv_heads, + scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype, + k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride, + blocksparse_block_size,blocksparse_head_sliding_step); + } + else{ + DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, + CALL_V1_LAUNCHER_BLOCK_SIZE) + } +} + +#define LAUNCH_PAGED_ATTENTION_V2_TC(HEAD_SIZE) \ + hipLaunchKernelGGL( \ + (vllm::paged_attention_v2_kernel_TC< \ + T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, KV_DTYPE, \ + IS_BLOCK_SPARSE, REUSE_KV_TIMES,use_vmac, PARTITION_SIZE>), \ + dim3(grid), dim3(block), shared_mem_size, stream, exp_sums_ptr, \ + max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, \ + num_heads, num_kv_heads, scale, block_tables_ptr, seq_lens_ptr, \ + max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \ + kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ + blocksparse_vert_stride, blocksparse_block_size, \ + blocksparse_head_sliding_step); \ + hipLaunchKernelGGL( \ + (vllm::paged_attention_v2_reduce_kernel_opt_tc), \ + dim3(reduce_grid), dim3(block), reduce_shared_mem_size, stream, out_ptr, \ + exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \ + max_num_partitions); + +void get_numberthread_and_reuse_kv_v2(int& num_thread,int& reusekv,int batchsize,int max_num_partitions,int qheads,int kvheads){ + reusekv=1; + int blocks=batchsize*qheads*max_num_partitions; + if(qheads==kvheads){ + if(blocks<=80||blocks>8000){num_thread=256;} + else if(blocks<=160){num_thread=128;} + else num_thread=64; + return; + } + if(qheads/kvheads>8&&blocks>4000){ + reusekv=16; + if(blocks>40000)num_thread=64; + else num_thread=128; + } + else if(qheads/kvheads==5||qheads/kvheads==7){ + if(blocks<=160){reusekv=1;num_thread=256;} + else if(blocks<640/5*qheads/kvheads){reusekv=4;num_thread=256;} + else if(blocks<1920){reusekv=8;num_thread=128;} + else {reusekv=8;num_thread=64;} + } + else if(qheads>kvheads*4){ + if(blocks<=128){reusekv=1;num_thread=256;} + else if(blocks<1536){reusekv=4;num_thread=256;} + else if(blocks<6144){reusekv=8;num_thread=128;} + else {reusekv=8;num_thread=64;} + } + else { + if(blocks<=128){reusekv=1;num_thread=256;} + else if(blocks<3000){reusekv=4;num_thread=256;} + else {reusekv=4;num_thread=64;} + } +} + +template +void paged_attention_v2_launcher_opt_tc( + torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, + torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, + const c10::optional& alibi_slopes, float k_scale, + float v_scale, const int tp_rank, const int blocksparse_local_blocks, + const int blocksparse_vert_stride, const int blocksparse_block_size, + const int blocksparse_head_sliding_step) { + int num_seqs = query.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); + int max_num_blocks_per_seq = block_tables.size(1); + int q_stride = query.stride(0); + int kv_block_stride = key_cache.stride(0); + int kv_head_stride = key_cache.stride(1); + // printf("paged_attention_v2\n"); + int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); + assert(head_size % thread_group_size == 0); + + // NOTE: alibi_slopes is optional. + const float* alibi_slopes_ptr = + alibi_slopes + ? reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; + + T* out_ptr = reinterpret_cast(out.data_ptr()); + float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); + float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); + T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); + T* query_ptr = reinterpret_cast(query.data_ptr()); + CACHE_T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + CACHE_T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int* block_tables_ptr = block_tables.data_ptr(); + int* seq_lens_ptr = seq_lens.data_ptr(); + const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + dim3 reduce_grid(num_heads, num_seqs); + int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE); + int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float); + + if constexpr(BLOCK_SIZE==16 && IS_BLOCK_SPARSE==false && sizeof(T)==2 && KV_DTYPE==vllm::Fp8KVCacheDataType::kAuto){ + //if(head_size==128&&get_device_name()=="gfx928"){ + constexpr int HEAD_SIZE=128; + constexpr static int use_vmac = false; + int reusekv, num_thread; + get_numberthread_and_reuse_kv_v2(num_thread,reusekv,num_seqs,max_num_partitions,num_heads,num_kv_heads); + if(PA_REUSE_KV_TIMES!=0&&num_heads>num_kv_heads)reusekv=PA_REUSE_KV_TIMES; + if(PA_BLOCK_SIZE!=0)num_thread=PA_BLOCK_SIZE; + REUSEKV_SWITCH(reusekv,[&] { + NUM_THREADS_SWITCH(num_thread , [&] { + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + int logits_size = REUSE_KV_TIMES*PARTITION_SIZE * 2; + int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); + dim3 grid; + grid.x = (num_heads/num_kv_heads + REUSE_KV_TIMES -1)/REUSE_KV_TIMES * num_kv_heads; + grid.y = max_num_partitions; + grid.z = num_seqs; + dim3 block(NUM_THREADS); + int shared_mem_size = ::max(logits_size, outputs_size); + if(PA_PRINT_PARAM)printf("reusekv=%d,num_thread=%d,grid={%d,%d,%d},qhead=%d,kvhead=%d,seq=%d,batch=%d\n", + reusekv,num_thread,grid.x,grid.y,grid.z,num_heads,num_kv_heads,max_seq_len,num_seqs); + LAUNCH_PAGED_ATTENTION_V2_TC(HEAD_SIZE); + }); + }); + } + //} +} + +#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \ + paged_attention_v2_launcher_opt_tc( \ + out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ + num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \ + k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ + blocksparse_vert_stride, blocksparse_block_size, \ + blocksparse_head_sliding_step); + +#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ + switch (is_block_sparse) { \ + case true: \ + CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \ + break; \ + case false: \ + CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \ + break; \ + } + +// NOTE(woosuk): To reduce the compilation time, we omitted block sizes +// 1, 2, 4, 64, 128, 256. +#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \ + switch (block_size) { \ + case 8: \ + CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \ + break; \ + case 16: \ + CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \ + break; \ + case 32: \ + CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ + } + +void paged_attention_v2_opt( + torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& + tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] + torch::Tensor& query, // [num_seqs, num_heads, head_size] + torch::Tensor& + key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& + value_cache, // [num_blocks, num_heads, head_size, block_size] + int64_t num_kv_heads, // [num_heads] + double scale, + torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor& seq_lens, // [num_seqs] + int64_t block_size, int64_t max_seq_len, + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, double k_scale, double v_scale, + const int64_t tp_rank, const int64_t blocksparse_local_blocks, + const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, + const int64_t blocksparse_head_sliding_step); + +void paged_attention_v2_opt_tc( + torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& + tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] + torch::Tensor& query, // [num_seqs, num_heads, head_size] + torch::Tensor& + key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& + value_cache, // [num_blocks, num_heads, head_size, block_size] + int64_t num_kv_heads, // [num_heads] + double scale, + torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor& seq_lens, // [num_seqs] + int64_t block_size, int64_t max_seq_len, + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, double k_scale, double v_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, + const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, + const int64_t blocksparse_head_sliding_step) { + const bool is_block_sparse = (blocksparse_vert_stride > 1); + if(kv_cache_dtype != "auto"||query.dtype() == at::ScalarType::Float||is_block_sparse|| + block_size!=16||query.size(2)!=128||get_device_name()!="gfx928"){ + paged_attention_v2_opt(out,exp_sums,max_logits,tmp_out,query,key_cache,value_cache,num_kv_heads, + scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype, + k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride, + blocksparse_block_size,blocksparse_head_sliding_step); + } + else{ + DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, + CALL_V2_LAUNCHER_BLOCK_SIZE) + } +} + +#undef WARP_SIZE +#undef MAX +#undef MIN +#undef DIVIDE_ROUND_UP \ No newline at end of file diff --git a/csrc/attention/attention_utils.cuh b/csrc/attention/attention_utils.cuh index 68d2528e11247f9c19a4e3a94342cc72039b234f..e19f1bf8144a501e45315c6159e619900018829f 100644 --- a/csrc/attention/attention_utils.cuh +++ b/csrc/attention/attention_utils.cuh @@ -122,7 +122,7 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { A_vec qk_vec = mul(q[0], k[0]); #pragma unroll for (int ii = 1; ii < N; ++ii) { - qk_vec = fma(q[ii], k[ii], qk_vec); + qk_vec = vllm::fma(q[ii], k[ii], qk_vec); } // Finalize the reduction across lanes. float qk = sum(qk_vec); diff --git a/csrc/attention/static_switch_tc.h b/csrc/attention/static_switch_tc.h new file mode 100644 index 0000000000000000000000000000000000000000..1ec170872d75dc43326e62c84269d779e192d1b9 --- /dev/null +++ b/csrc/attention/static_switch_tc.h @@ -0,0 +1,81 @@ +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + constexpr static bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() + +#define NUM_THREADS_SWITCH(NUM_THREAD, ...) \ + [&] { \ + if (NUM_THREAD == 256) { \ + constexpr static int NUM_THREADS = 256; \ + return __VA_ARGS__(); \ + }else if (NUM_THREAD == 128) { \ + constexpr static int NUM_THREADS = 128; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static int NUM_THREADS = 64; \ + return __VA_ARGS__(); \ + } \ + }() + + #define HEADSIZE_SWITCH(HEADDIM, ...) \ + [&] { \ + if (HEADDIM == 64) { \ + constexpr static int HEAD_SIZE = 64; \ + return __VA_ARGS__(); \ + } else if (HEADDIM == 80) { \ + constexpr static int HEAD_SIZE = 80; \ + return __VA_ARGS__(); \ + } else if (HEADDIM == 96) { \ + constexpr static int HEAD_SIZE = 96; \ + return __VA_ARGS__(); \ + } else if (HEADDIM == 112) { \ + constexpr static int HEAD_SIZE = 112; \ + return __VA_ARGS__(); \ + } else if (HEADDIM == 128) { \ + constexpr static int HEAD_SIZE = 128; \ + return __VA_ARGS__(); \ + } else if (HEADDIM == 256) { \ + constexpr static int HEAD_SIZE = 256; \ + return __VA_ARGS__(); \ + } \ + else { \ + TORCH_CHECK(false, "Unsupported head size: ", HEADDIM);\ + } \ + }() + +#define REUSEKV_SWITCH(reusekv,...) \ +[&] { \ + if (reusekv==16){ \ + constexpr static int REUSE_KV_TIMES = 16; \ + return __VA_ARGS__();} \ + else if (reusekv==8){ \ + constexpr static int REUSE_KV_TIMES = 8; \ + return __VA_ARGS__(); \ + }else if (reusekv==4){ \ + constexpr static int REUSE_KV_TIMES = 4; \ + return __VA_ARGS__(); \ + }else if (reusekv==2){ \ + constexpr static int REUSE_KV_TIMES = 2; \ + return __VA_ARGS__(); \ + }else { \ + constexpr static int REUSE_KV_TIMES = 1; \ + return __VA_ARGS__(); \ + } \ +}() + +#define USEVMAC_SWITCH_V1(num_blocks , ...) \ +[&] { \ + if (REUSE_KV_TIMES==1&&(num_blocks >2500 || padded_max_seq_len > 2048)){ \ + constexpr static int use_vmac = false; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static int use_vmac = true; \ + return __VA_ARGS__(); \ + } \ +}() \ No newline at end of file diff --git a/csrc/core/exception.hpp b/csrc/core/exception.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f3b2ffaef6cce0b85f25fdd5090a227b581d4d3f --- /dev/null +++ b/csrc/core/exception.hpp @@ -0,0 +1,3 @@ +#pragma once + +#define VLLM_IMPLIES(p, q) (!(p) || (q)) diff --git a/csrc/core/registration.h b/csrc/core/registration.h index e5396e9a8b1378c69729e4c421d7a8735490668f..4d0ce1c572c1c1ea947db0720ace5e7abe2a5624 100644 --- a/csrc/core/registration.h +++ b/csrc/core/registration.h @@ -12,6 +12,11 @@ // could be a macro instead of a literal token. #define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE) +// A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME +// could be a macro instead of a literal token. +#define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \ + TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE) + // REGISTER_EXTENSION allows the shared library to be loaded and initialized // via python's import statement. #define REGISTER_EXTENSION(NAME) \ diff --git a/csrc/core/scalar_type.hpp b/csrc/core/scalar_type.hpp index fce816b71a0002f38f0c316f1e7aa1565ff147a3..5658674102970d96315db6c3bbc749a54938dead 100644 --- a/csrc/core/scalar_type.hpp +++ b/csrc/core/scalar_type.hpp @@ -21,7 +21,7 @@ namespace vllm { // class ScalarType { public: - enum NanRepr : int64_t { + enum NanRepr : uint8_t { NAN_NONE = 0, // nans are not supported NAN_IEEE_754 = 1, // nans are: exp all 1s, mantissa not all 0s NAN_EXTD_RANGE_MAX_MIN = 2, // nans are: exp all 1s, mantissa all 1s @@ -29,33 +29,33 @@ class ScalarType { NAN_REPR_ID_MAX }; - constexpr ScalarType(bool signed_, int64_t exponent, int64_t mantissa, - int64_t bias, bool finite_values_only = false, + constexpr ScalarType(uint8_t exponent, uint8_t mantissa, bool signed_, + int32_t bias, bool finite_values_only = false, NanRepr nan_repr = NAN_IEEE_754) : exponent(exponent), mantissa(mantissa), - bias(bias), signed_(signed_), + bias(bias), finite_values_only(finite_values_only), nan_repr(nan_repr){}; - static constexpr ScalarType int_(int64_t size_bits, int64_t bias = 0) { - return ScalarType(true, 0, size_bits - 1, bias); + static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) { + return ScalarType(0, size_bits - 1, true, bias); } - static constexpr ScalarType uint(int64_t size_bits, int64_t bias = 0) { - return ScalarType(false, 0, size_bits, bias); + static constexpr ScalarType uint(uint8_t size_bits, int32_t bias = 0) { + return ScalarType(0, size_bits, false, bias); } // IEEE 754 compliant floating point type - static constexpr ScalarType float_IEEE754(int64_t exponent, - int64_t mantissa) { + static constexpr ScalarType float_IEEE754(uint8_t exponent, + uint8_t mantissa) { TORCH_CHECK(mantissa > 0 && exponent > 0); - return ScalarType(true, exponent, mantissa, 0, false, NAN_IEEE_754); + return ScalarType(exponent, mantissa, true, 0, false, NAN_IEEE_754); } // IEEE 754 non-compliant floating point type - static constexpr ScalarType float_(int64_t exponent, int64_t mantissa, + static constexpr ScalarType float_(uint8_t exponent, uint8_t mantissa, bool finite_values_only, NanRepr nan_repr) { TORCH_CHECK(nan_repr < NAN_REPR_ID_MAX, "Invalid NanRepr"); @@ -63,36 +63,121 @@ class ScalarType { TORCH_CHECK(nan_repr != NAN_IEEE_754, "use `float_IEEE754` constructor for floating point types that " "follow IEEE 754 conventions"); - return ScalarType(true, exponent, mantissa, 0, finite_values_only, + return ScalarType(exponent, mantissa, true, 0, finite_values_only, nan_repr); } - int64_t const exponent; // size of the exponent field (0 for integer types) - int64_t const mantissa; // size of the mantissa field (size of the integer + uint8_t const exponent; // size of the exponent field (0 for integer types) + uint8_t const mantissa; // size of the mantissa field (size of the integer // excluding the sign bit for integer types) - int64_t const bias; // stored values equal value + bias, - // used for quantized type bool const signed_; // flag if the type supports negative numbers (i.e. has a // sign bit) + int32_t const bias; // stored values equal value + bias, + // used for quantized type // Extra Floating point info bool const finite_values_only; // i.e. no +/-inf if true NanRepr const nan_repr; // how NaNs are represented // (not applicable for integer types) - int64_t size_bits() const { return mantissa + exponent + is_signed(); } - bool is_signed() const { return signed_; } - bool is_integer() const { return exponent == 0; } - bool is_floating_point() const { return exponent > 0; } - bool is_ieee_754() const { + using Id = int64_t; + + private: + // Field size in id + template + static constexpr size_t member_id_field_width() { + using T = std::decay_t; + return std::is_same_v ? 1 : sizeof(T) * 8; + } + + template + static constexpr auto reduce_members_helper(Fn f, Init val, Member member, + Rest... rest) { + auto new_val = f(val, member); + if constexpr (sizeof...(rest) > 0) { + return reduce_members_helper(f, new_val, rest...); + } else { + return new_val; + }; + } + + template + constexpr auto reduce_members(Fn f, Init init) const { + // Should be in constructor order for `from_id` + return reduce_members_helper(f, init, exponent, mantissa, signed_, bias, + finite_values_only, nan_repr); + }; + + template + static constexpr auto reduce_member_types(Fn f, Init init) { + constexpr auto dummy_type = ScalarType(0, 0, false, 0, false, NAN_NONE); + return dummy_type.reduce_members(f, init); + }; + + static constexpr auto id_size_bits() { + return reduce_member_types( + [](int acc, auto member) -> int { + return acc + member_id_field_width(); + }, + 0); + } + + public: + // unique id for this scalar type that can be computed at compile time for + // c++17 template specialization this is not needed once we migrate to + // c++20 and can pass literal classes as template parameters + constexpr Id id() const { + static_assert(id_size_bits() <= sizeof(Id) * 8, + "ScalarType id is too large to be stored"); + + auto or_and_advance = [](std::pair result, + auto member) -> std::pair { + auto [id, bit_offset] = result; + auto constexpr bits = member_id_field_width(); + return {id | (int64_t(member) & ((uint64_t(1) << bits) - 1)) + << bit_offset, + bit_offset + bits}; + }; + return reduce_members(or_and_advance, std::pair{}).first; + } + + // create a ScalarType from an id, for c++17 template specialization, + // this is not needed once we migrate to c++20 and can pass literal + // classes as template parameters + static constexpr ScalarType from_id(Id id) { + auto extract_and_advance = [id](auto result, auto member) { + using T = decltype(member); + auto [tuple, bit_offset] = result; + auto constexpr bits = member_id_field_width(); + auto extracted_val = static_cast((int64_t(id) >> bit_offset) & + ((uint64_t(1) << bits) - 1)); + auto new_tuple = std::tuple_cat(tuple, std::make_tuple(extracted_val)); + return std::pair{new_tuple, bit_offset + bits}; + }; + + auto [tuple_args, _] = reduce_member_types(extract_and_advance, + std::pair, int>{}); + return std::apply([](auto... args) { return ScalarType(args...); }, + tuple_args); + } + + constexpr int64_t size_bits() const { + return mantissa + exponent + is_signed(); + } + constexpr bool is_signed() const { return signed_; } + constexpr bool is_integer() const { return exponent == 0; } + constexpr bool is_floating_point() const { return exponent > 0; } + constexpr bool is_ieee_754() const { return is_floating_point() && finite_values_only == false && nan_repr == NAN_IEEE_754; } - bool has_nans() const { return is_floating_point() && nan_repr != NAN_NONE; } - bool has_infs() const { + constexpr bool has_nans() const { + return is_floating_point() && nan_repr != NAN_NONE; + } + constexpr bool has_infs() const { return is_floating_point() && finite_values_only == false; } - bool has_bias() const { return bias != 0; } + constexpr bool has_bias() const { return bias != 0; } private: double _floating_point_max() const { @@ -132,7 +217,7 @@ class ScalarType { return *reinterpret_cast(&double_raw); } - std::variant _raw_max() const { + constexpr std::variant _raw_max() const { if (is_floating_point()) { return {_floating_point_max()}; } else { @@ -142,7 +227,7 @@ class ScalarType { } } - std::variant _raw_min() const { + constexpr std::variant _raw_min() const { if (is_floating_point()) { TORCH_CHECK(is_signed(), "We currently assume all floating point types are signed"); @@ -169,7 +254,7 @@ class ScalarType { public: // Max representable value for this scalar type. // (accounting for bias if there is one) - std::variant max() const { + constexpr std::variant max() const { return std::visit( [this](auto x) -> std::variant { return {x - bias}; }, _raw_max()); @@ -177,7 +262,7 @@ class ScalarType { // Min representable value for this scalar type. // (accounting for bias if there is one) - std::variant min() const { + constexpr std::variant min() const { return std::visit( [this](auto x) -> std::variant { return {x - bias}; }, _raw_min()); @@ -216,7 +301,7 @@ class ScalarType { } } - bool operator==(ScalarType const& other) const { + constexpr bool operator==(ScalarType const& other) const { return mantissa == other.mantissa && exponent == other.exponent && bias == other.bias && signed_ == other.signed_ && finite_values_only == other.finite_values_only && @@ -229,6 +314,8 @@ class ScalarType { // have ScalarType inherit from torch::CustomClassHolder and have a constexpr // constructor at the same time (torch::CustomClassHolder does not have a // constexpr destructor) +// See also: +// https://docs.google.com/document/d/18fBMPuOJ0fY5ZQ6YyrHUppw9FA332CpNtgB6SOIgyuA class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType { public: ScalarTypeTorch(int64_t exponent, int64_t mantissa, int64_t bias, @@ -241,31 +328,91 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType { using Self = ScalarTypeTorch; using SelfPtr = c10::intrusive_ptr; + static void check_size_bits(int64_t size_bits, bool signed_) { + TORCH_CHECK( + size_bits <= + std::numeric_limits().mantissa)>::max(), + "size_bits bit width is too large to be represented"); + } + + static void check_bias(int64_t bias) { + using Bias = decltype(std::declval().bias); + TORCH_CHECK(bias <= std::numeric_limits::max() && + bias >= std::numeric_limits::min(), + "bias too large or small to be represented"); + } + + static void check_exponent(int64_t exponent) { + TORCH_CHECK( + exponent <= + std::numeric_limits().exponent)>::max(), + "exponent bit width is too large to be represented"); + } + + static void check_mantissa(int64_t mantissa) { + TORCH_CHECK( + mantissa <= + std::numeric_limits().mantissa)>::max(), + "mantissa bit width is too large to be represented"); + } + static SelfPtr int_(int64_t size_bits, c10::optional bias) { + check_size_bits(size_bits, true); + check_bias(bias.value_or(0)); return c10::make_intrusive( ScalarType::int_(size_bits, bias.value_or(0))); } static SelfPtr uint(int64_t size_bits, c10::optional bias) { + check_size_bits(size_bits, true); + check_bias(bias.value_or(0)); return c10::make_intrusive( ScalarType::uint(size_bits, bias.value_or(0))); } static SelfPtr float_IEEE754(int64_t exponent, int64_t mantissa) { + check_mantissa(mantissa); + check_exponent(exponent); return c10::make_intrusive( ScalarType::float_IEEE754(exponent, mantissa)); } static SelfPtr float_(int64_t exponent, int64_t mantissa, bool finite_values_only, int64_t nan_repr) { + check_mantissa(mantissa); + check_exponent(exponent); return c10::make_intrusive(ScalarType::float_( exponent, mantissa, finite_values_only, NanRepr(nan_repr))); } + // This needs to be implemented and throw a TypeError in order for + // PyTorch's opcheck to work on ops that use ScalarTypes. + int64_t len() const { + throw c10::TypeError({__func__, __FILE__, static_cast(__LINE__)}, + "__len__ not implemented"); + return 0; + } + + // Serialize a ScalarType into a tuple of pairs. Where each pair + // is a (fieldname, value). + // For simplicity, we are just going to convert to a ScalarTypeId. + std::tuple> obj_flatten() const { + return {{"ScalarType", id()}}; + } + + // Deserialize a scalar type that has been serialized by obj_flatten, + // ostensibly from a tuple of (member name, value) pairs, but in reality + // just a ScalarTypeId. + static SelfPtr obj_unflatten( + std::tuple> const& flat_type) { + return c10::make_intrusive( + from_id(std::get<1>(std::get<0>(flat_type)))); + } + template static void bind_readonly_property(torch::class_& cls, std::string const& name, T Base::*field) { - auto getter_func = [field = std::move(field)](SelfPtr const& self) { + auto getter_func_helper = [field = std::move(field)](SelfPtr const& self) { if constexpr (std::is_member_function_pointer_v) { return (self.get()->*field)(); } else { @@ -273,6 +420,18 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType { } }; + auto getter_func = [field = std::move(field), + getter_func_helper = std::move(getter_func_helper)]( + SelfPtr const& self) { + auto val = getter_func_helper(self); + // upconvert uint8_t, int32_t etc. to int64_t for python + if constexpr (std::is_integral_v) { + return static_cast(val); + } else { + return val; + } + }; + cls.def_property(name, getter_func); } @@ -325,6 +484,7 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType { self.get()->min()); }); + bind_function(cls, "__len__", &ScalarTypeTorch::len); bind_function(cls, "__str__", &Base::str); bind_function(cls, "__eq__", [](SelfPtr const& self, SelfPtr const& other) { return *self == *other; @@ -333,6 +493,10 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType { return "ScalarType." + self.get()->str(); }); + bind_function(cls, "__obj_flatten__", &ScalarTypeTorch::obj_flatten); + bind_static_function(cls, "__obj_unflatten__", + &ScalarTypeTorch::obj_unflatten); + // Bind static functions (convenience constructors) bind_static_function(cls, "int_", &ScalarTypeTorch::int_); bind_static_function(cls, "uint", &ScalarTypeTorch::uint); @@ -341,6 +505,7 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType { } }; +using ScalarTypeId = int64_t; using ScalarTypeTorchPtr = c10::intrusive_ptr; // "rust style" names generally following: @@ -380,4 +545,5 @@ static inline constexpr auto kHalf = kFE5M10; static inline constexpr auto kFloat16 = kHalf; static inline constexpr auto kBFloat16 = kFE8M7; +static inline constexpr auto kFloat16Id = kFloat16.id(); }; // namespace vllm diff --git a/csrc/cpu/cpu_types_x86.hpp b/csrc/cpu/cpu_types_x86.hpp index f50620a5287d4d09ec6b271c09d3a0ce9835902a..a325153b470ccdfa114cc54b75f0e36f0dee546b 100644 --- a/csrc/cpu/cpu_types_x86.hpp +++ b/csrc/cpu/cpu_types_x86.hpp @@ -24,8 +24,8 @@ namespace vec_op { #define CPU_KERNEL_GUARD_OUT(NAME) #else #define CPU_KERNEL_GUARD_IN(NAME) \ - std::cout << #NAME << " invoked." << std::endl; -#define CPU_KERNEL_GUARD_OUT(NAME) std::cout << #NAME << " exit." << std::endl; + RECORD_FUNCTION(#NAME, c10::ArrayRef({})); +#define CPU_KERNEL_GUARD_OUT(NAME) #endif #define FORCE_INLINE __attribute__((always_inline)) inline @@ -106,6 +106,12 @@ struct BF16Vec16 : public Vec { explicit BF16Vec16(const FP32Vec16 &); void save(void *ptr) const { *reinterpret_cast<__m256i *>(ptr) = reg; } + + void save(void* ptr, const int elem_num) const { + constexpr uint32_t M = 0xFFFFFFFF; + __mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num)); + _mm256_mask_storeu_epi16(ptr, mask, reg); + } }; #ifdef __AVX512F__ @@ -259,6 +265,30 @@ struct FP32Vec8 : public Vec { void save(float *ptr) const { _mm256_storeu_ps(ptr, reg); } }; +#ifdef __AVX512F__ +struct INT32Vec16: public Vec { + constexpr static int VEC_ELEM_NUM = 16; + union AliasReg { + __m512i reg; + int32_t values[VEC_ELEM_NUM]; + }; + + __m512i reg; + + explicit INT32Vec16(const void* data_ptr) : reg(_mm512_loadu_epi32(data_ptr)) {} + + void save(int32_t* ptr) const { + _mm512_storeu_epi32(ptr, reg); + } + + void save(int32_t* ptr, const int elem_num) const { + constexpr uint32_t M = 0xFFFFFFFF; + __mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num)); + _mm512_mask_storeu_epi32(ptr, mask, reg); + } +}; +#endif + #ifdef __AVX512F__ struct FP32Vec16 : public Vec { constexpr static int VEC_ELEM_NUM = 16; @@ -277,8 +307,6 @@ struct FP32Vec16 : public Vec { explicit FP32Vec16(__m512 data) : reg(data) {} - explicit FP32Vec16(const FP32Vec16 &data) : reg(data.reg) {} - explicit FP32Vec16(const FP32Vec4 &data) : reg((__m512)_mm512_inserti32x4( _mm512_inserti32x4( @@ -297,6 +325,9 @@ struct FP32Vec16 : public Vec { explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {} + explicit FP32Vec16(const INT32Vec16 &v) + : reg(_mm512_cvt_roundepi32_ps(v.reg, _MM_FROUND_TO_NEAREST_INT |_MM_FROUND_NO_EXC)) {} + FP32Vec16 operator*(const FP32Vec16 &b) const { return FP32Vec16(_mm512_mul_ps(reg, b.reg)); } @@ -313,8 +344,40 @@ struct FP32Vec16 : public Vec { return FP32Vec16(_mm512_div_ps(reg, b.reg)); } + FP32Vec16 clamp(const FP32Vec16& min, const FP32Vec16& max) const { + return FP32Vec16(_mm512_min_ps(max.reg, _mm512_max_ps(min.reg, reg))); + } + + FP32Vec16 max(const FP32Vec16& b) const { + return FP32Vec16(_mm512_max_ps(reg, b.reg)); + } + + FP32Vec16 max(const FP32Vec16& b, const int elem_num) const { + constexpr uint32_t M = 0xFFFFFFFF; + __mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num)); + return FP32Vec16(_mm512_mask_max_ps(reg, mask, reg, b.reg)); + } + + FP32Vec16 min(const FP32Vec16& b) const { + return FP32Vec16(_mm512_min_ps(reg, b.reg)); + } + + FP32Vec16 min(const FP32Vec16& b, const int elem_num) const { + constexpr uint32_t M = 0xFFFFFFFF; + __mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num)); + return FP32Vec16(_mm512_mask_min_ps(reg, mask, reg, b.reg)); + } + + FP32Vec16 abs() const { + return FP32Vec16(_mm512_abs_ps(reg)); + } + float reduce_sum() const { return _mm512_reduce_add_ps(reg); } + float reduce_max() const { return _mm512_reduce_max_ps(reg); } + + float reduce_min() const { return _mm512_reduce_min_ps(reg); } + template float reduce_sub_sum(int idx) { static_assert(VEC_ELEM_NUM % group_size == 0); constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size)); @@ -323,6 +386,12 @@ struct FP32Vec16 : public Vec { } void save(float *ptr) const { _mm512_storeu_ps(ptr, reg); } + + void save(float* ptr, const int elem_num) const { + constexpr uint32_t M = 0xFFFFFFFF; + __mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num)); + _mm512_mask_storeu_ps(ptr, mask, reg); + } }; #else struct FP32Vec16 : public Vec { @@ -433,6 +502,32 @@ struct FP32Vec16 : public Vec { }; #endif +#ifdef __AVX512F__ +struct INT8Vec16: public Vec { + constexpr static int VEC_ELEM_NUM = 16; + union AliasReg { + __m128i reg; + int8_t values[VEC_ELEM_NUM]; + }; + + __m128i reg; + + explicit INT8Vec16(const FP32Vec16& vec) : reg( + _mm512_cvtepi32_epi8(_mm512_cvt_roundps_epi32(vec.reg, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)) + ) {} + + void save(int8_t* ptr) const { + _mm_storeu_epi8(ptr, reg); + } + + void save(int8_t* ptr, const int elem_num) const { + constexpr uint32_t M = 0xFFFFFFFF; + __mmask16 mask = _cvtu32_mask16(M >> (32 - elem_num)); + _mm_mask_storeu_epi8(ptr, mask, reg); + } +}; +#endif + template struct VecType { using vec_type = void; }; template using vec_t = typename VecType::vec_type; diff --git a/csrc/cpu/dnnl_helper.hpp b/csrc/cpu/dnnl_helper.hpp new file mode 100644 index 0000000000000000000000000000000000000000..024ad4ae43da8c8247ffaa6c042d2f80fec74eb8 --- /dev/null +++ b/csrc/cpu/dnnl_helper.hpp @@ -0,0 +1,168 @@ +#ifndef DNNL_HELPER_HPP +#define DNNL_HELPER_HPP + +#include + +#include "oneapi/dnnl/dnnl.hpp" + +namespace { +template +struct DNNLType { + static constexpr dnnl::memory::data_type type = + dnnl::memory::data_type::undef; +}; + +template <> +struct DNNLType { + static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s8; +}; + +template <> +struct DNNLType { + static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::s32; +}; + +template <> +struct DNNLType { + static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::f32; +}; + +template <> +struct DNNLType { + static constexpr dnnl::memory::data_type type = dnnl::memory::data_type::bf16; +}; + +template +constexpr inline dnnl::memory::data_type get_dnnl_type() { + return DNNLType>::type; +} +}; // namespace + +template +class DNNLPrimitiveHelper { + public: + // I8 input GEMM kernel (C = a_scales * A @ (b_scales * B^T) + bias) + // A: [M, K], row-major + // B: [K, N], column-major + // C: [M, N], row-major + // bias: [N], row-major, optional + // a_scales: [MS] + // b_scales: [NS] + // Note: Due to the limitation of oneDNN + // (https://github.com/oneapi-src/oneDNN/issues/1636), the quantized bias is + // not supported. + template + static void gemm_s8s8_jit(const int8_t* a, const int8_t* b, OutputT* c, + const BiasT* bias, dnnl_dim_t M, dnnl_dim_t N, + dnnl_dim_t K, const float* a_scales, + const float* b_scales, dnnl_dim_t MS, + dnnl_dim_t NS) { + auto&& OutputType = get_dnnl_type(); + auto&& BiasType = get_dnnl_type(); + + dnnl::memory::desc a_md({M, K}, dnnl::memory::data_type::s8, {K, 1}); + dnnl::memory::desc b_md({K, N}, dnnl::memory::data_type::s8, {1, K}); + dnnl::memory::desc c_md({M, N}, OutputType, {N, 1}); + + dnnl::primitive_attr attr; + if constexpr (!InputNoScale) { + if (MS == 1) { + // per-tensor + attr.set_scales_mask(DNNL_ARG_SRC, 0); + } else { + // per-token + TORCH_CHECK(false, "per-token quantization is unsupported."); + } + } + + if (NS == 1) { + // per-tensor + attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0); + } else { + // per-channel + attr.set_scales_mask(DNNL_ARG_WEIGHTS, 2); + } + + dnnl::matmul::primitive_desc matmul_pd; + if (bias) { + dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1}); + matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, + bias_md, c_md, attr); + } else { + matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md, + c_md, attr); + } + dnnl::matmul matmul(matmul_pd); + + auto& engine = default_engine(); + + dnnl::memory a_m(a_md, engine, (void*)a); + dnnl::memory b_m(b_md, engine, (void*)b); + dnnl::memory c_m(c_md, engine, (void*)c); + dnnl::memory a_scales_m({{MS}, dnnl::memory::data_type::f32, {1}}, engine, + (void*)a_scales); + dnnl::memory b_scales_m({{NS}, dnnl::memory::data_type::f32, {1}}, engine, + (void*)b_scales); + + auto& stream = default_stream(); + if constexpr (InputNoScale) { + if (bias) { + dnnl::memory::desc bias_md({N}, BiasType, {1}); + dnnl::memory bias_m(bias_md, engine, (void*)bias); + matmul.execute( + stream, { + {DNNL_ARG_SRC, a_m}, + {DNNL_ARG_WEIGHTS, b_m}, + {DNNL_ARG_BIAS, bias_m}, + {DNNL_ARG_DST, c_m}, + {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}, + }); + } else { + matmul.execute( + stream, { + {DNNL_ARG_SRC, a_m}, + {DNNL_ARG_WEIGHTS, b_m}, + {DNNL_ARG_DST, c_m}, + {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}, + }); + } + } else { + if (bias) { + dnnl::memory::desc bias_md({N}, BiasType, {1}); + dnnl::memory bias_m(bias_md, engine, (void*)bias); + matmul.execute( + stream, { + {DNNL_ARG_SRC, a_m}, + {DNNL_ARG_WEIGHTS, b_m}, + {DNNL_ARG_BIAS, bias_m}, + {DNNL_ARG_DST, c_m}, + {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m}, + {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}, + }); + } else { + matmul.execute( + stream, { + {DNNL_ARG_SRC, a_m}, + {DNNL_ARG_WEIGHTS, b_m}, + {DNNL_ARG_DST, c_m}, + {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m}, + {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m}, + }); + } + } + stream.wait(); + } + + private: + static dnnl::engine& default_engine() { + static dnnl::engine engine(dnnl::engine::kind::cpu, 0); + return engine; + } + + static dnnl::stream& default_stream() { + static dnnl::stream stream(default_engine()); + return stream; + } +}; + +#endif diff --git a/csrc/cpu/quant.cpp b/csrc/cpu/quant.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b493fd793818acdabc79f06b34d653d759da558f --- /dev/null +++ b/csrc/cpu/quant.cpp @@ -0,0 +1,600 @@ +#include "cpu_types.hpp" +#include "dnnl_helper.hpp" + +namespace { +template +struct KernelVecType { + using load_vec_type = void; + using azp_adj_load_vec_type = void; + using cvt_vec_type = void; +}; + +template <> +struct KernelVecType { + using load_vec_type = vec_op::FP32Vec16; + using azp_adj_load_vec_type = vec_op::INT32Vec16; + using cvt_vec_type = vec_op::FP32Vec16; +}; + +template <> +struct KernelVecType { + using load_vec_type = vec_op::BF16Vec16; + using azp_adj_load_vec_type = vec_op::INT32Vec16; + using cvt_vec_type = vec_op::FP32Vec16; +}; + +#ifdef __AVX512F__ +template +void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, + const float* scale, const int32_t* azp, + const int num_tokens, + const int hidden_size) { + using load_vec_t = typename KernelVecType::load_vec_type; + using cvt_vec_t = typename KernelVecType::cvt_vec_type; + constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; + + constexpr float i8_min = + static_cast(std::numeric_limits::min()); + constexpr float i8_max = + static_cast(std::numeric_limits::max()); + const cvt_vec_t inv_scale(1.0 / *scale); + const cvt_vec_t i8_min_vec(i8_min); + const cvt_vec_t i8_max_vec(i8_max); + + cvt_vec_t zp_vec; + if constexpr (AZP) { + zp_vec = cvt_vec_t(static_cast(*azp)); + } + + #pragma omp parallel for + for (int i = 0; i < num_tokens; ++i) { + int j = 0; + for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { + load_vec_t elems(input + i * hidden_size + j); + cvt_vec_t elems_fp32(elems); + elems_fp32 = elems_fp32 * inv_scale; + + if constexpr (AZP) { + elems_fp32 = elems_fp32 + zp_vec; + } + + elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); + vec_op::INT8Vec16 elems_int8(elems_fp32); + elems_int8.save(output + i * hidden_size + j); + } + + load_vec_t elems(input + i * hidden_size + j); + cvt_vec_t elems_fp32(elems); + elems_fp32 = elems_fp32 * inv_scale; + + if constexpr (AZP) { + elems_fp32 = elems_fp32 + zp_vec; + } + + elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); + vec_op::INT8Vec16 elems_int8(elems_fp32); + elems_int8.save(output + i * hidden_size + j, hidden_size - j); + } +} + +template +void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, + float* scale, int32_t* azp, + const int num_tokens, + const int hidden_size) { + using load_vec_t = typename KernelVecType::load_vec_type; + using cvt_vec_t = typename KernelVecType::cvt_vec_type; + constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; + + constexpr float i8_min = + static_cast(std::numeric_limits::min()); + constexpr float i8_max = + static_cast(std::numeric_limits::max()); + const cvt_vec_t i8_min_vec(i8_min); + const cvt_vec_t i8_max_vec(i8_max); + + #pragma omp parallel for + for (int i = 0; i < num_tokens; ++i) { + cvt_vec_t max_value(std::numeric_limits::lowest()); + cvt_vec_t min_value(std::numeric_limits::max()); + { + int j = 0; + for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { + load_vec_t elems(input + i * hidden_size + j); + cvt_vec_t elems_fp32(elems); + if constexpr (AZP) { + max_value = max_value.max(elems_fp32); + min_value = min_value.min(elems_fp32); + } else { + max_value = max_value.max(elems_fp32.abs()); + } + } + + load_vec_t elems(input + i * hidden_size + j); + cvt_vec_t elems_fp32(elems); + + if (j + vec_elem_num == hidden_size) { + if constexpr (AZP) { + max_value = max_value.max(elems_fp32); + min_value = min_value.min(elems_fp32); + } else { + max_value = max_value.max(elems_fp32.abs()); + } + } else { + if constexpr (AZP) { + max_value = max_value.max(elems_fp32, hidden_size - j); + min_value = min_value.min(elems_fp32, hidden_size - j); + } else { + max_value = max_value.max(elems_fp32.abs(), hidden_size - j); + } + } + } + + float scale_val, azp_val; + if constexpr (AZP) { + float max_scalar = max_value.reduce_max(); + float min_scalar = min_value.reduce_min(); + scale_val = (max_scalar - min_scalar) / 255.0f; + azp_val = std::nearbyint(-128.0f - min_scalar / scale_val); + azp[i] = static_cast(azp_val); + scale[i] = scale_val; + } else { + scale_val = max_value.reduce_max() / 127.0f; + scale[i] = scale_val; + } + + const cvt_vec_t inv_scale(1.0 / scale_val); + const cvt_vec_t azp_vec(azp_val); + + { + int j = 0; + for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { + load_vec_t elems(input + i * hidden_size + j); + cvt_vec_t elems_fp32(elems); + elems_fp32 = (elems_fp32 * inv_scale); + + if constexpr (AZP) { + elems_fp32 = elems_fp32 + azp_vec; + } + elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); + vec_op::INT8Vec16 elems_int8(elems_fp32); + elems_int8.save(output + i * hidden_size + j); + } + + load_vec_t elems(input + i * hidden_size + j); + cvt_vec_t elems_fp32(elems); + elems_fp32 = (elems_fp32 * inv_scale); + + if constexpr (AZP) { + elems_fp32 = elems_fp32 + azp_vec; + } + elems_fp32 = elems_fp32.clamp(i8_min_vec, i8_max_vec); + vec_op::INT8Vec16 elems_int8(elems_fp32); + elems_int8.save(output + i * hidden_size + j, hidden_size - j); + } + } +} + +template +void static_quant_epilogue(const float* input, scalar_t* output, + const float a_scale, const float* b_scale, + const int32_t* azp_with_adj, const int num_tokens, + const int hidden_size) { + CPU_KERNEL_GUARD_IN(dynamic_output_scale_impl) + using load_vec_t = typename KernelVecType::load_vec_type; + using azp_adj_load_vec_t = + typename KernelVecType::azp_adj_load_vec_type; + using cvt_vec_t = typename KernelVecType::cvt_vec_type; + constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; + + #pragma omp parallel for + for (int i = 0; i < num_tokens; ++i) { + cvt_vec_t a_scale_vec(a_scale); + cvt_vec_t b_scale_vec(*b_scale); + cvt_vec_t scale_vec = a_scale_vec * b_scale_vec; + + int j = 0; + for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { + cvt_vec_t elems_fp32(input + i * hidden_size + j); + azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j); + cvt_vec_t azp_adj_fp32(azp_adj_vec); + + if constexpr (PerChannel) { + b_scale_vec = cvt_vec_t(b_scale + j); + scale_vec = b_scale_vec * a_scale_vec; + } + + elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32; + + load_vec_t elems_out(elems_fp32); + elems_out.save(output + i * hidden_size + j); + } + + cvt_vec_t elems_fp32(input + i * hidden_size + j); + azp_adj_load_vec_t azp_adj_vec(azp_with_adj + j); + cvt_vec_t azp_adj_fp32(azp_adj_vec); + + if constexpr (PerChannel) { + b_scale_vec = cvt_vec_t(b_scale + j); + scale_vec = b_scale_vec * a_scale_vec; + } + + elems_fp32 = elems_fp32 - scale_vec * azp_adj_fp32; + + load_vec_t elems_out(elems_fp32); + elems_out.save(output + i * hidden_size + j, hidden_size - j); + } +} + +template +void dynamic_quant_epilogue(const float* input, scalar_t* output, + const float* a_scale, const float* b_scale, + const int32_t* azp, const int32_t* azp_adj, + const scalar_t* bias, const int num_tokens, + const int hidden_size) { + CPU_KERNEL_GUARD_IN(dynamic_quant_epilogue) + using load_vec_t = typename KernelVecType::load_vec_type; + using azp_adj_load_vec_t = + typename KernelVecType::azp_adj_load_vec_type; + using cvt_vec_t = typename KernelVecType::cvt_vec_type; + constexpr int vec_elem_num = load_vec_t::VEC_ELEM_NUM; + + #pragma omp parallel for + for (int i = 0; i < num_tokens; ++i) { + int j = 0; + cvt_vec_t token_scale_vec(a_scale[i]); + cvt_vec_t token_zp_scale_vec; + if constexpr (AZP) { + float zp_scale_val = a_scale[i] * static_cast(azp[i]); + if constexpr (!PerChannel) { + zp_scale_val *= *b_scale; + } + token_zp_scale_vec = cvt_vec_t(zp_scale_val); + } + + for (; j < hidden_size - vec_elem_num; j += vec_elem_num) { + cvt_vec_t elems_fp32(input + i * hidden_size + j); + elems_fp32 = elems_fp32 * token_scale_vec; + + if constexpr (AZP) { + azp_adj_load_vec_t azp_adj_vec(azp_adj + j); + cvt_vec_t azp_adj_fp32(azp_adj_vec); + azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec; + + if constexpr (PerChannel) { + cvt_vec_t b_scale_vec(b_scale + j); + azp_adj_fp32 = azp_adj_fp32 * b_scale_vec; + } + + elems_fp32 = elems_fp32 - azp_adj_fp32; + } + + if constexpr (Bias) { + load_vec_t bias_vec(bias + j); + cvt_vec_t bias_vec_fp32(bias_vec); + elems_fp32 = elems_fp32 + bias_vec_fp32; + } + + load_vec_t elems_out(elems_fp32); + elems_out.save(output + i * hidden_size + j); + } + + cvt_vec_t elems_fp32(input + i * hidden_size + j); + elems_fp32 = elems_fp32 * token_scale_vec; + + if constexpr (AZP) { + azp_adj_load_vec_t azp_adj_vec(azp_adj + j); + cvt_vec_t azp_adj_fp32(azp_adj_vec); + azp_adj_fp32 = azp_adj_fp32 * token_zp_scale_vec; + + if constexpr (PerChannel) { + cvt_vec_t b_scale_vec(b_scale + j); + azp_adj_fp32 = azp_adj_fp32 * b_scale_vec; + } + + elems_fp32 = elems_fp32 - azp_adj_fp32; + } + + if constexpr (Bias) { + load_vec_t bias_vec(bias + j); + cvt_vec_t bias_vec_fp32(bias_vec); + elems_fp32 = elems_fp32 + bias_vec_fp32; + } + + load_vec_t elems_out(elems_fp32); + elems_out.save(output + i * hidden_size + j, hidden_size - j); + } +} +#else +template +void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, + const float* scale, const int32_t* azp, + const int num_tokens, + const int hidden_size) { + TORCH_CHECK(false, "static_scaled_int8_quant_impl requires AVX512 support.") +} + +template +void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output, + float* scale, int32_t* azp, + const int num_tokens, + const int hidden_size) { + TORCH_CHECK(false, "dynamic_scaled_int8_quant_impl requires AVX512 support.") +} + +template +void static_quant_epilogue(const float* input, scalar_t* output, + const float a_scale, const float* b_scale, + const int32_t* azp_with_adj, const int num_tokens, + const int hidden_size) { + TORCH_CHECK(false, "static_quant_epilogue requires AVX512 support.") +} + +template +void dynamic_quant_epilogue(const float* input, scalar_t* output, + const float* a_scale, const float* b_scale, + const int32_t* azp, const int32_t* azp_with_adj, + const scalar_t* bias, const int num_tokens, + const int hidden_size) { + TORCH_CHECK(false, "dynamic_quant_epilogue requires AVX512 support.") +} +#endif +} // namespace + +void int8_scaled_mm(torch::Tensor& c, // [M, OC], row-major + const torch::Tensor& a, // [M, IC], row-major + const torch::Tensor& b, // [IC, OC], column-major + const torch::Tensor& a_scales, // [1] or [M] + const torch::Tensor& b_scales, // [1] or [OC] + const c10::optional& bias // [OC] +) { + CPU_KERNEL_GUARD_IN(cutlass_scaled_mm) + // Checks for conformality + TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8, + "int8_scaled_mm only supports INT8 inputs.") + TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2); + TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) && + b.size(1) == c.size(1)); + TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0)); + TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1)); + + // Check for strides and alignment + TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major + TORCH_CHECK(b.stride(0) == 1); // Column-major + TORCH_CHECK(c.stride(0) % 16 == 0 && + b.stride(1) % 16 == 0); // 16 Byte Alignment + TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); + + if (bias) { + TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() && + bias->dim() == 1); + } + + VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm", [&] { + if (a_scales.numel() != 1) { + // per-token + // Note: oneDNN doesn't support per-token activation quantization + // Ideally we want to fuse the GEMM and the scale procedure with oneDNN + // JIT, the intermediate data is cached in registers or L1. But for now + // the oneDNN GEMM code generation only supports two quantization + // patterns: per-tensor or per-output-channel of weight. + // So we have to apply the per-token scale with a 'epilogue'. In C=s_a * + // s_b * (A@B) + bias, the C_inter = s_b * (A@B) is computed by oneDNN + // GEMM, then the per-token scale (and bias) is applied with the epilogue + // C=s_a * C_inter + bias. + torch::Tensor tmp_fp32_out = + torch::empty_like(c, ::at::ScalarType::Float); + // Compute C_inter=s_b * (A@B) + DNNLPrimitiveHelper::gemm_s8s8_jit( + a.data_ptr(), b.data_ptr(), + tmp_fp32_out.data_ptr(), nullptr, a.size(0), b.size(1), + a.size(1), nullptr, b_scales.data_ptr(), 0, b_scales.numel()); + if (bias.has_value()) { + // Compute C=s_a * C_inter + bias + dynamic_quant_epilogue( + tmp_fp32_out.data_ptr(), c.data_ptr(), + a_scales.data_ptr(), nullptr, nullptr, nullptr, + bias->data_ptr(), c.size(0), c.size(1)); + } else { + // Compute C=s_a * C_inter + dynamic_quant_epilogue( + tmp_fp32_out.data_ptr(), c.data_ptr(), + a_scales.data_ptr(), nullptr, nullptr, nullptr, nullptr, + c.size(0), c.size(1)); + } + } else { + // per-tensor + if (bias.has_value()) { + // Compute C=s_a * s_b * (A@B) + bias + DNNLPrimitiveHelper::gemm_s8s8_jit( + a.data_ptr(), b.data_ptr(), c.data_ptr(), + bias->data_ptr(), a.size(0), b.size(1), a.size(1), + a_scales.data_ptr(), b_scales.data_ptr(), + a_scales.numel(), b_scales.numel()); + } else { + // Compute C=s_a * s_b * (A@B) + DNNLPrimitiveHelper::gemm_s8s8_jit( + a.data_ptr(), b.data_ptr(), c.data_ptr(), + nullptr, a.size(0), b.size(1), a.size(1), + a_scales.data_ptr(), b_scales.data_ptr(), + a_scales.numel(), b_scales.numel()); + } + } + }); +} + +void int8_scaled_mm_azp(torch::Tensor& c, // [M, OC], row-major + const torch::Tensor& a, // [M, IC], row-major + const torch::Tensor& b, // [IC, OC], column-major + const torch::Tensor& a_scales, // [1] or [M] + const torch::Tensor& b_scales, // [1] or [OC] + const torch::Tensor& azp_adj, // [OC] + const c10::optional& azp, // [1] or [M] + const c10::optional& bias // [OC] +) { + CPU_KERNEL_GUARD_IN(cutlass_scaled_mm_azp) + // Checks for conformality + TORCH_CHECK(a.dtype() == torch::kInt8 && b.dtype() == torch::kInt8, + "int8_scaled_mm_azp only supports INT8 inputs.") + TORCH_CHECK(a.dim() == 2 && b.dim() == 2 && c.dim() == 2); + TORCH_CHECK(c.size(0) == a.size(0) && a.size(1) == b.size(0) && + b.size(1) == c.size(1)); + TORCH_CHECK(a_scales.numel() == 1 || a_scales.numel() == a.size(0)); + TORCH_CHECK(b_scales.numel() == 1 || b_scales.numel() == b.size(1)); + + // Check for strides and alignment + TORCH_CHECK(a.stride(1) == 1 && c.stride(1) == 1); // Row-major + TORCH_CHECK(b.stride(0) == 1); // Column-major + TORCH_CHECK(c.stride(0) % 16 == 0 && + b.stride(1) % 16 == 0); // 16 Byte Alignment + TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); + + if (bias) { + TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous()); + } + if (azp) { + TORCH_CHECK(azp->numel() == a.size(0) && azp->is_contiguous()); + } + TORCH_CHECK(azp_adj.numel() == b.size(1) && azp_adj.is_contiguous()); + + // azp & bias types + TORCH_CHECK(azp_adj.dtype() == torch::kInt32); + TORCH_CHECK(!azp || azp->dtype() == torch::kInt32); + TORCH_CHECK(!bias || bias->dtype() == c.dtype(), + "currently bias dtype must match output dtype ", c.dtype()); + + VLLM_DISPATCH_FLOATING_TYPES(c.scalar_type(), "int8_scaled_mm_azp", [&] { + torch::Tensor tmp_fp32_out = torch::empty_like(c, ::at::ScalarType::Float); + if (a_scales.numel() != 1) { + // per-token + // Note: oneDNN doesn't support per-token activation quantization + // Compute C_inter=s_b * (A@B) + DNNLPrimitiveHelper::gemm_s8s8_jit( + a.data_ptr(), b.data_ptr(), + tmp_fp32_out.data_ptr(), nullptr, a.size(0), b.size(1), + a.size(1), nullptr, b_scales.data_ptr(), 0, b_scales.numel()); + if (bias.has_value()) { + // Compute C=s_a * C_inter - s_a * s_b * azp * azp_adj + bias + if (b_scales.numel() != 1) { + // Per-Channel + dynamic_quant_epilogue( + tmp_fp32_out.data_ptr(), c.data_ptr(), + a_scales.data_ptr(), b_scales.data_ptr(), + azp->data_ptr(), azp_adj.data_ptr(), + bias->data_ptr(), c.size(0), c.size(1)); + } else { + // Per-Tensor + dynamic_quant_epilogue( + tmp_fp32_out.data_ptr(), c.data_ptr(), + a_scales.data_ptr(), b_scales.data_ptr(), + azp->data_ptr(), azp_adj.data_ptr(), + bias->data_ptr(), c.size(0), c.size(1)); + } + } else { + // Compute C=s_a * C_inter - s_a * s_b * azp * azp_adj + if (b_scales.numel() != 1) { + // Per-Channel + dynamic_quant_epilogue( + tmp_fp32_out.data_ptr(), c.data_ptr(), + a_scales.data_ptr(), b_scales.data_ptr(), + azp->data_ptr(), azp_adj.data_ptr(), nullptr, + c.size(0), c.size(1)); + } else { + // Per-Tensor + dynamic_quant_epilogue( + tmp_fp32_out.data_ptr(), c.data_ptr(), + a_scales.data_ptr(), b_scales.data_ptr(), + azp->data_ptr(), azp_adj.data_ptr(), nullptr, + c.size(0), c.size(1)); + } + } + } else { + // per-tensor + if (bias.has_value()) { + // Compute C_inter=s_a * s_b * (A@B) + bias + DNNLPrimitiveHelper::gemm_s8s8_jit( + a.data_ptr(), b.data_ptr(), + tmp_fp32_out.data_ptr(), bias->data_ptr(), + a.size(0), b.size(1), a.size(1), a_scales.data_ptr(), + b_scales.data_ptr(), a_scales.numel(), b_scales.numel()); + } else { + // Compute C_inter=s_a * s_b * (A@B) + DNNLPrimitiveHelper::gemm_s8s8_jit( + a.data_ptr(), b.data_ptr(), + tmp_fp32_out.data_ptr(), nullptr, a.size(0), b.size(1), + a.size(1), a_scales.data_ptr(), b_scales.data_ptr(), + a_scales.numel(), b_scales.numel()); + } + + // Compute C=C_inter - s_a * s_b * azp_adj + if (b_scales.numel() != 1) { + // Per-Channel + static_quant_epilogue( + tmp_fp32_out.data_ptr(), c.data_ptr(), + *a_scales.data_ptr(), b_scales.data_ptr(), + azp_adj.data_ptr(), a.size(0), b.size(1)); + } else { + // Per-Tensor + static_quant_epilogue( + tmp_fp32_out.data_ptr(), c.data_ptr(), + *a_scales.data_ptr(), b_scales.data_ptr(), + azp_adj.data_ptr(), a.size(0), b.size(1)); + } + } + }); +} + +// static-per-tensor quantization. +void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] + const torch::Tensor& input, // [..., hidden_size] + const torch::Tensor& scale, + c10::optional const& azp) { + CPU_KERNEL_GUARD_IN(static_scaled_int8_quant) + TORCH_CHECK(input.is_contiguous()); + TORCH_CHECK(out.is_contiguous()); + TORCH_CHECK(scale.numel() == 1); + TORCH_CHECK(!azp.has_value() || azp->numel() == 1); + + const int hidden_size = input.size(-1); + const int num_tokens = input.numel() / hidden_size; + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "static_scaled_int8_quant_impl", [&] { + if (azp.has_value()) { + static_scaled_int8_quant_impl( + input.data_ptr(), out.data_ptr(), + scale.data_ptr(), azp->data_ptr(), num_tokens, + hidden_size); + } else { + static_scaled_int8_quant_impl( + input.data_ptr(), out.data_ptr(), + scale.data_ptr(), nullptr, num_tokens, hidden_size); + } + }); +} + +// dynamic-per-token quantization. +void dynamic_scaled_int8_quant( + torch::Tensor& out, // [..., hidden_size] + const torch::Tensor& input, // [..., hidden_size] + torch::Tensor& scale, // [..., 1] + c10::optional const& azp) { + CPU_KERNEL_GUARD_IN(dynamic_scaled_int8_quant) + TORCH_CHECK(input.is_contiguous()); + TORCH_CHECK(out.is_contiguous()); + + int const hidden_size = input.size(-1); + int const num_tokens = input.numel() / hidden_size; + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "dynamic_scaled_int8_quant_impl", [&] { + if (azp.has_value()) { + dynamic_scaled_int8_quant_impl( + input.data_ptr(), out.data_ptr(), + scale.data_ptr(), azp->data_ptr(), num_tokens, + hidden_size); + } else { + dynamic_scaled_int8_quant_impl( + input.data_ptr(), out.data_ptr(), + scale.data_ptr(), nullptr, num_tokens, hidden_size); + } + }); +} diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index cf7d977da7c1c189ed792fff9e46bc274a64322f..03beefbc6de7d60fe23370f2fc9cefbb139de85d 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -4,7 +4,19 @@ #include -void init_cpu_threads_env(const std::string& cpu_ids); +std::string init_cpu_threads_env(const std::string& cpu_ids); + +void int8_scaled_mm(torch::Tensor& c, const torch::Tensor& a, + const torch::Tensor& b, const torch::Tensor& a_scales, + const torch::Tensor& b_scales, + const c10::optional& bias); + +void int8_scaled_mm_azp(torch::Tensor& c, const torch::Tensor& a, + const torch::Tensor& b, const torch::Tensor& a_scales, + const torch::Tensor& b_scales, + const torch::Tensor& azp_adj, + const c10::optional& azp, + const c10::optional& bias); TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // vLLM custom ops @@ -27,8 +39,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { // PagedAttention V2. ops.def( "paged_attention_v2(" - " Tensor! out, Tensor exp_sums, Tensor max_logits," - " Tensor tmp_out, Tensor query, Tensor key_cache," + " Tensor! out, Tensor! exp_sums, Tensor! max_logits," + " Tensor! tmp_out, Tensor query, Tensor key_cache," " Tensor value_cache, int num_kv_heads, float scale," " Tensor block_tables, Tensor seq_lens, int block_size," " int max_seq_len, Tensor? alibi_slopes," @@ -84,6 +96,37 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor! key, int head_size," " Tensor cos_sin_cache, bool is_neox) -> ()"); ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding); + + // Quantization +#ifdef __AVX512F__ + // Compute int8 quantized tensor for given scaling factor. + ops.def( + "static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale," + "Tensor? azp) -> ()"); + ops.impl("static_scaled_int8_quant", torch::kCPU, &static_scaled_int8_quant); + + // Compute int8 quantized tensor and scaling factor + ops.def( + "dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, " + "Tensor!? azp) -> ()"); + ops.impl("dynamic_scaled_int8_quant", torch::kCPU, + &dynamic_scaled_int8_quant); + // W8A8 GEMM, supporting symmetric per-tensor or per-row/column + // quantization. + ops.def( + "cutlass_scaled_mm(Tensor! out, Tensor a," + " Tensor b, Tensor a_scales," + " Tensor b_scales, Tensor? bias) -> ()"); + ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm); + // w8a8 GEMM, supporting asymmetric per-tensor or per-row/column + // quantization. + ops.def( + "cutlass_scaled_mm_azp(Tensor! out, Tensor a," + " Tensor b, Tensor a_scales," + " Tensor b_scales, Tensor azp_adj," + " Tensor? azp, Tensor? bias) -> ()"); + ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp); +#endif } TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { @@ -95,8 +138,8 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { // Copy the cache blocks from src to dst. cache_ops.def( - "copy_blocks(Tensor[]! key_caches, Tensor[]! value_caches, Tensor " - "block_mapping) -> ()"); + "copy_blocks(Tensor(a!)[] key_caches, Tensor[](b!) value_caches, " + "Tensor block_mapping) -> ()"); cache_ops.impl("copy_blocks", torch::kCPU, ©_blocks); // Reshape the key and value tensors and cache them. @@ -111,7 +154,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) { // CPU utils - utils.def("init_cpu_threads_env(str cpu_ids) -> ()", &init_cpu_threads_env); + utils.def("init_cpu_threads_env(str cpu_ids) -> str", &init_cpu_threads_env); } REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/csrc/cpu/utils.cpp b/csrc/cpu/utils.cpp index 5782580baa861eeaedba93e9605588665d5e1956..1138a55df2f05c6fd3d35a86936a9e2452307489 100644 --- a/csrc/cpu/utils.cpp +++ b/csrc/cpu/utils.cpp @@ -5,7 +5,7 @@ #include "cpu_types.hpp" -void init_cpu_threads_env(const std::string& cpu_ids) { +std::string init_cpu_threads_env(const std::string& cpu_ids) { bitmask* omp_cpu_mask = numa_parse_cpustring(cpu_ids.c_str()); TORCH_CHECK(omp_cpu_mask->size > 0); std::vector omp_cpu_ids; @@ -51,15 +51,40 @@ void init_cpu_threads_env(const std::string& cpu_ids) { torch::set_num_threads((int)omp_cpu_ids.size()); TORCH_CHECK_EQ(omp_cpu_ids.size(), torch::get_num_threads()); TORCH_CHECK_EQ(omp_cpu_ids.size(), omp_get_max_threads()); + + std::vector> thread_core_mapping; + thread_core_mapping.reserve(omp_cpu_ids.size()); + omp_lock_t writelock; + omp_init_lock(&writelock); + #pragma omp parallel for schedule(static, 1) for (size_t i = 0; i < omp_cpu_ids.size(); ++i) { - cpu_set_t* mask = CPU_ALLOC(omp_cpu_mask->size); - size_t size = CPU_ALLOC_SIZE(omp_cpu_mask->size); - CPU_ZERO_S(size, mask); - CPU_SET_S(omp_cpu_ids[i], size, mask); - sched_setaffinity(0, sizeof(cpu_set_t), mask); - CPU_FREE(mask); + cpu_set_t mask; + CPU_ZERO(&mask); + CPU_SET(omp_cpu_ids[i], &mask); + int ret = sched_setaffinity(0, sizeof(cpu_set_t), &mask); + if (ret == -1) { + TORCH_CHECK(false, + "sched_setaffinity failed. errno: " + std::to_string(errno)); + } + + omp_set_lock(&writelock); + thread_core_mapping.emplace_back(gettid(), omp_cpu_ids[i]); + omp_unset_lock(&writelock); } + omp_destroy_lock(&writelock); + numa_free_nodemask(omp_cpu_mask); + + std::stringstream ss; + ss << "OMP threads binding of Process " << getpid() << ":\n"; + std::sort(thread_core_mapping.begin(), thread_core_mapping.end(), + [](auto&& a, auto&& b) { return a.second < b.second; }); + for (auto&& item : thread_core_mapping) { + ss << "\t" + << "OMP tid: " << item.first << ", core " << item.second << "\n"; + } + + return ss.str(); } diff --git a/csrc/cuda_utils.h b/csrc/cuda_utils.h index 73944f4c14890d05b5b905e309e101fa419fb015..c35224218e91cba3a20d76b15531c8541e6596ba 100644 --- a/csrc/cuda_utils.h +++ b/csrc/cuda_utils.h @@ -1,5 +1,15 @@ #pragma once +#if defined(__CUDACC__) || defined(_NVHPC_CUDA) + #define HOST_DEVICE_INLINE __forceinline__ __host__ __device__ + #define DEVICE_INLINE __forceinline__ __device__ + #define HOST_INLINE __forceinline__ __host__ +#else + #define HOST_DEVICE_INLINE inline + #define DEVICE_INLINE inline + #define HOST_INLINE inline +#endif + int64_t get_device_attribute(int64_t attribute, int64_t device_id); int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id); diff --git a/csrc/custom_all_reduce.cu b/csrc/custom_all_reduce.cu index 82a3563979f16ecec97b4862d1aa3486f54cdef1..9b82bec44c3c64f9feb8e58e723c3bcb5880e16c 100644 --- a/csrc/custom_all_reduce.cu +++ b/csrc/custom_all_reduce.cu @@ -55,18 +55,6 @@ bool _is_weak_contiguous(torch::Tensor& t) { t.numel() * t.element_size()); } -bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size, - bool full_nvlink) { - auto inp_size = inp.numel() * inp.element_size(); - // custom allreduce requires input byte size to be multiples of 16 - if (inp_size % 16 != 0) return false; - if (!_is_weak_contiguous(inp)) return false; - if (world_size == 2 || full_nvlink) return inp_size <= max_size; - // for 4 or more non NVLink-capable GPUs, custom allreduce provides little - // performance improvement over NCCL. - return false; -} - void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, cudaStream_t stream) { auto fa = reinterpret_cast(_fa); diff --git a/csrc/custom_all_reduce.cuh b/csrc/custom_all_reduce.cuh index 1ed49b8aa9cae432a47c17a8c478116456bbca0a..a2f7e43300002323f799afe3cf74ea7063e668c6 100644 --- a/csrc/custom_all_reduce.cuh +++ b/csrc/custom_all_reduce.cuh @@ -6,6 +6,7 @@ #include #include +#include #include #include #include @@ -23,17 +24,23 @@ namespace vllm { -constexpr int kMaxBlocks = 64; -// note: we don't want to use atomics for signals because peer atomics are no -// supported on PCIe links +constexpr int kMaxBlocks = 36; +// Counter may overflow, but it's fine since unsigned int overflow is +// well-defined behavior. +using FlagType = uint32_t; struct Signal { - alignas(128) uint32_t start[kMaxBlocks][8]; - alignas(128) uint32_t end[kMaxBlocks][8]; + alignas(128) FlagType self_counter[kMaxBlocks][8]; + // Two sets of peer counters are needed for two syncs. The reason is that + // it's possible for peer GPU block to arrive at the second sync point while + // the current GPU block haven't passed the first sync point. Thus, peer GPU + // may write counter+1 while current GPU is busy waiting for counter. We use + // alternating counter array to avoid this possibility. + alignas(128) FlagType peer_counter[2][kMaxBlocks][8]; }; struct __align__(16) RankData { const void* __restrict__ ptrs[8]; }; -struct __align__(16) RankSignals { volatile Signal* signals[8]; }; +struct __align__(16) RankSignals { Signal* signals[8]; }; // like std::array, but aligned template @@ -123,47 +130,71 @@ DINLINE O downcast(array_t val) { } } -// This function is meant to be used as the first synchronization in the all -// reduce kernel. Thus, it doesn't need to make any visibility guarantees for -// prior memory accesses. Note: volatile writes will not be reordered against -// other volatile writes. -template -DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg, - int rank) { - if (threadIdx.x < ngpus) { - // reset flag for next time - self_sg->end[blockIdx.x][threadIdx.x] = 0; - // simultaneously write to the corresponding flag of all ranks. - // Latency = 1 p2p write - sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1; - // wait until we got true from all ranks - while (!self_sg->start[blockIdx.x][threadIdx.x]); - } - __syncthreads(); +static DINLINE void st_flag_release(FlagType* flag_addr, FlagType flag) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 + asm volatile("st.release.sys.global.u32 [%1], %0;" ::"r"(flag), + "l"(flag_addr)); +#else + asm volatile("membar.sys; st.volatile.global.u32 [%1], %0;" ::"r"(flag), + "l"(flag_addr)); +#endif +} + +static DINLINE FlagType ld_flag_acquire(FlagType* flag_addr) { + FlagType flag; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 + asm volatile("ld.acquire.sys.global.u32 %0, [%1];" + : "=r"(flag) + : "l"(flag_addr)); +#else + asm volatile("ld.volatile.global.u32 %0, [%1]; membar.gl;" + : "=r"(flag) + : "l"(flag_addr)); +#endif + return flag; } -// This function is meant to be used as the second or the final synchronization -// barrier in the all reduce kernel. If it's the final synchronization barrier, -// we don't need to make any visibility guarantees for prior memory accesses. -template -DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg, - int rank) { - __syncthreads(); - // eliminate the case that prior writes are not visible after signals become - // visible. Note that I did not managed to make this happen through a lot of - // testing. Might be the case that hardware provides stronger guarantee than - // the memory model. - if constexpr (!final_sync) __threadfence_system(); +static DINLINE void st_flag_volatile(FlagType* flag_addr, FlagType flag) { + asm volatile("st.volatile.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr)); +} + +static DINLINE FlagType ld_flag_volatile(FlagType* flag_addr) { + FlagType flag; + asm volatile("ld.volatile.global.u32 %0, [%1];" + : "=r"(flag) + : "l"(flag_addr)); + return flag; +} + +// is_start: whether this is the very first synchronization barrier. +// need_fence: whether a memory fence is needed. If true, a release-acquire +// semantic is used to enforce memory access order before and after this +// barrier. +template +DINLINE void multi_gpu_barrier(const RankSignals& sg, Signal* self_sg, + int rank) { + if constexpr (!is_start) __syncthreads(); + static_assert( + !(is_start && need_fence)); // Start barrier shouldn't need fence. if (threadIdx.x < ngpus) { - // reset flag for next time - self_sg->start[blockIdx.x][threadIdx.x] = 0; - // simultaneously write to the corresponding flag of all ranks. - // Latency = 1 p2p write - sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1; - // wait until we got true from all ranks - while (!self_sg->end[blockIdx.x][threadIdx.x]); + // Increment the counter. Technically we only need one counter, but we use + // multiple per block to eliminate the need to share the counter via smem. + auto val = self_sg->self_counter[blockIdx.x][threadIdx.x] += 1; + // Write the expected counter value to peer and wait for correct value from + // peer. + auto peer_counter_ptr = + &sg.signals[threadIdx.x]->peer_counter[val % 2][blockIdx.x][rank]; + auto self_counter_ptr = + &self_sg->peer_counter[val % 2][blockIdx.x][threadIdx.x]; + if constexpr (need_fence) { + st_flag_release(peer_counter_ptr, val); + while (ld_flag_acquire(self_counter_ptr) != val); + } else { + st_flag_volatile(peer_counter_ptr, val); + while (ld_flag_volatile(self_counter_ptr) != val); + } } - if constexpr (!final_sync) __syncthreads(); + if constexpr (is_start || need_fence) __syncthreads(); } template @@ -178,33 +209,31 @@ DINLINE P packed_reduce(const P* ptrs[], int idx) { template __global__ void __launch_bounds__(512, 1) - cross_device_reduce_1stage(RankData* _dp, RankSignals sg, - volatile Signal* self_sg, T* __restrict__ result, - int rank, int size) { + cross_device_reduce_1stage(RankData* _dp, RankSignals sg, Signal* self_sg, + T* __restrict__ result, int rank, int size) { using P = typename packed_t::P; using A = typename packed_t::A; // note: we don't reorder the address so the accumulation order is the same // for all ranks, ensuring bitwise identical results auto dp = *_dp; - start_sync(sg, self_sg, rank); + multi_gpu_barrier(sg, self_sg, rank); // do the actual reduction for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) { ((P*)result)[idx] = packed_reduce((const P**)&dp.ptrs[0], idx); } - end_sync(sg, self_sg, rank); + multi_gpu_barrier(sg, self_sg, rank); } template -DINLINE P* get_tmp_buf(volatile Signal* sg) { +DINLINE P* get_tmp_buf(Signal* sg) { return (P*)(((Signal*)sg) + 1); } template __global__ void __launch_bounds__(512, 1) - cross_device_reduce_2stage(RankData* _dp, RankSignals sg, - volatile Signal* self_sg, T* __restrict__ result, - int rank, int size) { + cross_device_reduce_2stage(RankData* _dp, RankSignals sg, Signal* self_sg, + T* __restrict__ result, int rank, int size) { int tid = blockIdx.x * blockDim.x + threadIdx.x; int stride = gridDim.x * blockDim.x; using P = typename packed_t::P; @@ -222,12 +251,12 @@ __global__ void __launch_bounds__(512, 1) tmps[i] = get_tmp_buf

    (sg.signals[target]); } auto tmp_out = tmps[0]; - start_sync(sg, self_sg, rank); + multi_gpu_barrier(sg, self_sg, rank); // stage 1: reduce scatter for (int idx = start + tid; idx < end; idx += stride) { tmp_out[idx - start] = packed_reduce(ptrs, idx); } - end_sync(sg, self_sg, rank); + multi_gpu_barrier(sg, self_sg, rank); // stage 2: allgather. Note: it's important to match the tid between // the two stages, because visibility across devices is only guaranteed @@ -437,6 +466,8 @@ class CustomAllreduce { #define KL(ngpus, name) \ name<<>>(ptrs, sg_, self_sg_, output, \ rank_, size); + // TODO(hanzhi713): Threshold is different for A100 and H100. + // Add per device threshold. #define REDUCE_CASE(ngpus) \ case ngpus: { \ if (world_size_ == 2) { \ diff --git a/csrc/custom_all_reduce_test.cu b/csrc/custom_all_reduce_test.cu index f7868233076cd917c0fe0d3fa84bee823c7e8cb2..376687e91cfda3945a39b85c21721454105be57c 100644 --- a/csrc/custom_all_reduce_test.cu +++ b/csrc/custom_all_reduce_test.cu @@ -1,15 +1,15 @@ /** * This is a standalone test for custom allreduce. * To compile, make sure you have MPI and NCCL installed in your system. - * export MPI_HOME=XXX + * export MPI_HOME=xxx * nvcc -O2 -arch=native -std=c++17 custom_all_reduce_test.cu -o - * custom_all_reduce_test -lnccl -I${MPI_HOME}/include -lmpi + * custom_all_reduce_test -lnccl -I${MPI_HOME} -lmpi * * Warning: this C++ test is not designed to be very readable and was used * during the rapid prototyping process. * * To run: - * mpirun -np 8 ./custom_all_reduce_test + * mpirun --allow-run-as-root -np 8 ./custom_all_reduce_test */ #include #include @@ -44,7 +44,14 @@ } while (0) __global__ void dummy_kernel() { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 for (int i = 0; i < 100; i++) __nanosleep(1000000); // 100ms +#else + for (int i = 0; i < 100; i++) { + long long int start = clock64(); + while (clock64() - start < 150000000); // approximately 98.4ms on P40 + } +#endif } template @@ -302,15 +309,19 @@ int main(int argc, char** argv) { bool performance_test = true; cudaProfilerStart(); - // for (int threads : {256, 512}) { + // Uncomment to scan through different block size configs. + // for (int threads : {256, 512, 1024}) { // for (int block_limit = 16; block_limit < 112; block_limit += 4) { - // run(myRank, nRanks, comm, threads, block_limit, 4096 * 1024); + // run(myRank, nRanks, comm, threads, block_limit, 1024 * 1024, + // performance_test); // } // } + // Scan through different sizes to test performance. for (int sz = 512; sz <= (8 << 20); sz *= 2) { run(myRank, nRanks, comm, 512, 36, sz + 8 * 47, performance_test); } cudaProfilerStop(); + MPICHECK(MPI_Finalize()); return EXIT_SUCCESS; } diff --git a/csrc/cutlass_extensions/cute_utils.cuh b/csrc/cutlass_extensions/cute_utils.cuh new file mode 100644 index 0000000000000000000000000000000000000000..1842fab8b2cace3e4bc9a2690449db79deb7c680 --- /dev/null +++ b/csrc/cutlass_extensions/cute_utils.cuh @@ -0,0 +1,68 @@ +#pragma once + +#include +#include +namespace cute { + +//////////////////////////////////////////////////////////////////// +// layout utils +//////////////////////////////////////////////////////////////////// + +// Permute layout based on indices, example: +// permute_layout<1, 0>(layout) will swap the two dimensions +// permute_layout<0, 2, 1>(layout) will swap the last two dimensions +template +CUTE_HOST_DEVICE static constexpr auto permute_layout(Layout l) { + static_assert(rank(l) == sizeof...(I), "Invalid permutation, rank mismatch"); + return cute::make_layout(cute::get(l)...); +} + +// is the layout f(x) = x +template +CUTE_HOST_DEVICE static constexpr bool is_identity_layout() { + if constexpr (std::is_same_v) + return true; + else { + constexpr auto coalesced_layout = coalesce(Layout{}); + if constexpr (rank(coalesced_layout) == 1 && + stride<0>(coalesced_layout) == 1) { + return true; + } + return false; + } +} + +//////////////////////////////////////////////////////////////////// +// Pointer utils +//////////////////////////////////////////////////////////////////// + +template +static constexpr auto get_logical_ptr(PointerType* ptr) { + if constexpr (cute::sizeof_bits_v < 8) { + return cute::subbyte_iterator(ptr); + } else { + return ptr; + } +} + +//////////////////////////////////////////////////////////////////// +// Misc utils +//////////////////////////////////////////////////////////////////// + +template +CUTE_HOST_DEVICE static constexpr auto create_auto_vectorizing_copy() { + constexpr auto bits = sizeof_bits_v * Elements{}; + if constexpr (bits % 128 == 0) { + return AutoVectorizingCopyWithAssumedAlignment<128>{}; + } else if constexpr (bits % 64 == 0) { + return AutoVectorizingCopyWithAssumedAlignment<64>{}; + } else if constexpr (bits % 32 == 0) { + return AutoVectorizingCopyWithAssumedAlignment<32>{}; + } else if constexpr (bits % 16 == 0) { + return AutoVectorizingCopyWithAssumedAlignment<16>{}; + } else { + return AutoVectorizingCopyWithAssumedAlignment<8>{}; + } +} + +}; // namespace cute diff --git a/csrc/cutlass_extensions/torch_utils.hpp b/csrc/cutlass_extensions/torch_utils.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2c78572521eecd958701c1b76a78c7f243a85af6 --- /dev/null +++ b/csrc/cutlass_extensions/torch_utils.hpp @@ -0,0 +1,160 @@ +#pragma once + +#include + +#include "cute/layout.hpp" +#include "cutlass/layout/matrix.h" +#include "cutlass/bfloat16.h" +#include "cutlass/half.h" + +using ColumnMajor = typename cutlass::layout::ColumnMajor; +using RowMajor = typename cutlass::layout::RowMajor; + +namespace cute { + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr auto tapply_with_idx(T&& t, F&& f, G&& g, + seq) { + return g(f(cute::get(static_cast(t)), I)...); +} + +template +CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f, seq) { + return make_shape(f(I)...); +} + +}; // namespace detail + +template +CUTE_HOST_DEVICE constexpr auto transform_with_idx(T const& t, F&& f) { + if constexpr (cute::is_tuple::value) { + return detail::tapply_with_idx( + t, f, [](auto const&... a) { return cute::make_tuple(a...); }, + tuple_seq{}); + } else { + return f(t); + } + + CUTE_GCC_UNREACHABLE; +} + +// calls: make_shape(f(0), f(1), ..., f(N-1)) +template +CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f) { + return detail::make_shape_from_idx(f, make_seq{}); +} + +}; // namespace cute + +// Make a layout from a tensor with `rank(Stride{})`, where the shape is the +// shape of the passed in tensor and the strides are of type `Stride` and +// contain the strides of the passed in tensor, checking that any static strides +// in `Stride{}` match the strides of the passed in tensor. +// If `tensor.dim() < rank(Stride{})`, the shape is padded with 1s and the extra +// strides are set to be 0 or 1. +template +static inline auto make_cute_layout(torch::Tensor const& tensor, + std::string_view name = "tensor") { + TORCH_CHECK(tensor.dim() <= rank(Stride{})); + auto stride = cute::transform_with_idx( + Stride{}, [&](auto const& stride_ele, auto const& idx) { + using StrideEle = std::decay_t; + + if (idx < tensor.dim()) { + if constexpr (cute::is_static_v) { + TORCH_CHECK(StrideEle::value == tensor.stride(idx), "Expected ", + name, ".stride(", idx, ") to be ", StrideEle::value); + return StrideEle{}; + } else { + if (tensor.size(idx) == 1) { + // use 0 stride for dim with size 1, this is easier for + // cute/cutlass to optimize (helps the TMA code flatten dims) + return StrideEle{0}; + } else { + return tensor.stride(idx); + } + } + } else { + // Extra strides are assumed to be 0 or 1 + if constexpr (cute::is_static_v) { + static_assert(StrideEle::value == 0 || StrideEle::value == 1); + } + return StrideEle{}; + } + }); + + auto shape = cute::make_shape_from_idx([&](auto const& idx) { + if (idx < tensor.dim()) + return tensor.size(idx); + else + return int64_t(1); + }); + + return make_layout(shape, stride); +} + +template +static inline auto maybe_make_cute_layout( + c10::optional const& tensor, + std::string_view name = "tensor") { + using Layout = decltype(make_cute_layout(*tensor)); + + if (tensor) { + return std::optional{make_cute_layout(*tensor, name)}; + } else { + return std::optional{}; + } +} + +// +// Torch Type to Cutlass Type (equivalent_cutlass_type) +// + +template +struct equivalent_cutlass_type { + using type = T; +}; + +template +using equivalent_cutlass_type_t = typename equivalent_cutlass_type::type; + +template <> +struct equivalent_cutlass_type { + using type = cutlass::half_t; +}; + +template <> +struct equivalent_cutlass_type { + using type = cutlass::bfloat16_t; +}; + +// +// equivalent_scalar_t (basically inverse of equivalent_cutlass_type) +// + +// Return a `c10::CppTypeToScalarType` compatible type, i.e. get the C++ from +// c10 that is equivalent to T, e.g.: `cutlass::half_t -> c10::Half` +template +struct equivalent_scalar_type { + using type = T; +}; + +template +using equivalent_scalar_type_t = typename equivalent_scalar_type::type; + +template <> +struct equivalent_scalar_type { + using type = c10::Half; +}; + +template <> +struct equivalent_scalar_type { + using type = c10::BFloat16; +}; + +// get equivalent c10::ScalarType tag from compile time type +template +static inline constexpr c10::ScalarType equivalent_scalar_type_v = + c10::CppTypeToScalarType>::value; \ No newline at end of file diff --git a/csrc/cutlass_extensions/vllm_collective_builder.cuh b/csrc/cutlass_extensions/vllm_collective_builder.cuh new file mode 100644 index 0000000000000000000000000000000000000000..085ee1290031fb88eb1c1e06dfba50eca50dceda --- /dev/null +++ b/csrc/cutlass_extensions/vllm_collective_builder.cuh @@ -0,0 +1,43 @@ +#pragma once + +#include "cutlass/gemm/collective/collective_builder.hpp" + +namespace cutlass::gemm::collective { +using namespace cute; + +// +// VLLMCollectiveBuilder is a wrapper around CollectiveBuilder that allows for +// for custom kernel tags, allowing you to build custom collectives. Without +// touching the cutlass library headers, using `CutlassKernelTag` will mean it +// will resort to using the standard cutlass collective builder. +// + +// Use the default Cutlass collective builder, i.e. use an unmodified cutless +// collective +struct CutlassKernelTag {}; + +template +struct VLLMCollectiveBuilder { + static_assert(sizeof(ElementA) == 0, + "Could not build a collective for given parameters."); +}; + +template +struct VLLMCollectiveBuilder< + CutlassKernelTag, ArchTag, OpClass, ElementA, GmemLayoutA, AlignmentA, + ElementB, GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK, + ClusterShape_MNK, StageCountType, KernelScheduleType> { + using CollectiveOp = typename CollectiveBuilder< + ArchTag, OpClass, ElementA, GmemLayoutA, AlignmentA, ElementB, + GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK, + ClusterShape_MNK, StageCountType, KernelScheduleType>::CollectiveOp; +}; + +}; // namespace cutlass::gemm::collective \ No newline at end of file diff --git a/csrc/cutlass_extensions/vllm_custom_types.cuh b/csrc/cutlass_extensions/vllm_custom_types.cuh new file mode 100644 index 0000000000000000000000000000000000000000..6146bdc1f08c68db3f8e13e55f50c1baa5430547 --- /dev/null +++ b/csrc/cutlass_extensions/vllm_custom_types.cuh @@ -0,0 +1,50 @@ +#pragma once + +#include "cutlass/integer_subbyte.h" + +namespace cutlass { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct vllm_biased_integer_subbyte : public integer_subbyte { + using Base = integer_subbyte; + + using Storage = typename Base::Storage; + using xint_t = typename Base::xint_t; + + using Base::bits_mask_; + using Base::sign_mask_; + using Base::storage; + + // + // Methods + // + + /// No operation + vllm_biased_integer_subbyte() = default; + + /// Conversion from integer type + CUTLASS_HOST_DEVICE explicit vllm_biased_integer_subbyte(int value) + : Base(value) {} + CUTLASS_HOST_DEVICE explicit vllm_biased_integer_subbyte(unsigned value) + : Base(value) {} + CUTLASS_HOST_DEVICE explicit vllm_biased_integer_subbyte(double value) + : Base(value) {} +}; +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// "GPTQ" types, i.e. symmetric quantization +using vllm_uint4b8_t = vllm_biased_integer_subbyte<4, 8>; // u4b8 +using vllm_uint8b128_t = vllm_biased_integer_subbyte<8, 128>; // u8b128 + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct sizeof_bits> { + static constexpr int value = Bits; +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/csrc/cutlass_extensions/vllm_cutlass_library_extension.py b/csrc/cutlass_extensions/vllm_cutlass_library_extension.py new file mode 100644 index 0000000000000000000000000000000000000000..4fcfcd311aa913368d42f900aa624edc67c2d53e --- /dev/null +++ b/csrc/cutlass_extensions/vllm_cutlass_library_extension.py @@ -0,0 +1,49 @@ +import enum +from typing import Dict, Union + +from cutlass_library import * + +# +# Extend cutlass library with custom types, and missing values +# + + +class VLLMDataType(enum.Enum): + u4b8 = enum_auto() + u8b128 = enum_auto() + + +class MixedInputKernelScheduleType(enum.Enum): + TmaWarpSpecializedMixedInput = enum_auto() + TmaWarpSpecializedPingpongMixedInput = enum_auto() + TmaWarpSpecializedCooperativeMixedInput = enum_auto() + + +VLLMDataTypeNames: Dict[Union[VLLMDataType, DataType], str] = { + **DataTypeNames, # type: ignore + **{ + VLLMDataType.u4b8: "u4b8", + VLLMDataType.u8b128: "u8b128", + } +} + +VLLMDataTypeTag: Dict[Union[VLLMDataType, DataType], str] = { + **DataTypeTag, # type: ignore + **{ + VLLMDataType.u4b8: "cutlass::vllm_uint4b8_t", + VLLMDataType.u8b128: "cutlass::vllm_uint8b128_t", + } +} + +VLLMKernelScheduleTag: Dict[Union[ + MixedInputKernelScheduleType, KernelScheduleType], str] = { + **KernelScheduleTag, # type: ignore + **{ + MixedInputKernelScheduleType.TmaWarpSpecializedMixedInput: + "cutlass::gemm::KernelTmaWarpSpecializedMixedInput", + MixedInputKernelScheduleType.TmaWarpSpecializedPingpongMixedInput: + "cutlass::gemm::KernelTmaWarpSpecializedPingpongMixedInput", + MixedInputKernelScheduleType.TmaWarpSpecializedCooperativeMixedInput: + "cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput", + } + } diff --git a/csrc/cutlass_extensions/vllm_numeric_conversion.cuh b/csrc/cutlass_extensions/vllm_numeric_conversion.cuh new file mode 100644 index 0000000000000000000000000000000000000000..2ad914f8e9868ef235be45e4173a74d0abec54b5 --- /dev/null +++ b/csrc/cutlass_extensions/vllm_numeric_conversion.cuh @@ -0,0 +1,795 @@ +#pragma once + +#include "cutlass/numeric_conversion.h" +#include "cutlass_extensions/vllm_custom_types.cuh" +#include "cutlass_extensions/cute_utils.cuh" + +// this file extends: +// https://github.com/NVIDIA/cutlass/blob/cutlass-3.5.0/include/cutlass/numeric_conversion.h +// with vllm specific type conversions, namely: vllm_uint4b8_t, vllm_uint8b128_t +// as well as adds interleaved numeric array converters for specific types. +// (interleaved numeric array converters can be more efficient for subbyte +// types) + +namespace cutlass { + +// InterleavedNumericArrayConverter is like NumericArrayConverter but also +// deinterleaves converted elements based on IlvBlkLayout, interleaving can +// make subbyte converts more efficient by allowing for efficient extraction +// of subbyte elements from a 32bit register. +template +struct InterleavedNumericArrayConverter { + using Converter = NumericArrayConverter; + + using result_type = typename Converter::result_type; + using source_type = typename Converter::source_type; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + CUTE_INVALID_CONTROL_PATH( + "InterleavedNumericArrayConverter not implemented\n"); + return {}; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +template +struct InterleavedNumericArrayConverter< + IlvBlkLayout, T, S, N, Round, + std::enable_if_t()>> { + using Converter = NumericArrayConverter; + + using result_type = typename Converter::result_type; + using source_type = typename Converter::source_type; + + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return Converter::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// TODO (LucasWilkinson): Implement +// for Array <= Array + +// .... + +template +struct ArrayConverterPacked32Bit { + using result_type = Array; + using source_type = Array; + + using result_packed_8_t = Array; + using result_packed_4_t = Array; + using result_packed_2_t = Array; + using src_packed_8_t = Array; + using src_packed_4_t = Array; + using src_packed_2_t = Array; + + static_assert(N % 2 == 0, "N must be a multiple of 2"); + static_assert(cutlass::sizeof_bits_v >= 4); // TODO: add 16 packed sources + static_assert(32 % cutlass::sizeof_bits_v == 0); + static constexpr auto src_elems_per_32bit_reg = + 32 / cutlass::sizeof_bits_v; + + // Maybe not Valid. ScalarConverter will not actually work unless + // NumericConverter is implemented. However it won't be used + // anyways since we assert N % 2 == 0, just here for compliance with + // VectorizedConverter. + using ScalarConverter = NumericConverter; + + template + CUTLASS_DEVICE static uint32_t to_reg(PackedSrc const& source) { + if constexpr (sizeof(PackedSrc) == 1) { + return static_cast(reinterpret_cast(source)); + } else if constexpr (sizeof(PackedSrc) == 2) { + return static_cast(reinterpret_cast(source)); + } else { + static_assert(sizeof(PackedSrc) == 4); + return reinterpret_cast(source); + } + } + + // The core converter uses bit tricks to construct a known FP16 number, then + // does a subtraction in FP16 for the final result. + template + CUTLASS_DEVICE static PackedResultType packed_convert( + PackedSrcType const& source) { + static_assert(PackedSrcType::kElements == PackedResultType::kElements); + static_assert(PackedResultType::kElements == 2 || + PackedResultType::kElements == 4 || + PackedResultType::kElements == 8, + "Invalid PackedResultType must be 2, 4 or 8."); + static_assert(std::is_same_v); + static_assert(std::is_same_v); + + return RegConvert32bit::template convert(to_reg(source)); + } + + friend class detail::VectorizedConverter; + + public: + CUTLASS_DEVICE static result_type convert(source_type const& source) { + result_type result; + using ConverterType = + ArrayConverterPacked32Bit; + + if constexpr (src_elems_per_32bit_reg >= 8) { + detail::VectorizedConverter::convert< + ConverterType, result_packed_8_t, src_packed_8_t, result_packed_4_t, + src_packed_4_t, result_packed_2_t, src_packed_2_t>(result, source); + } else if constexpr (src_elems_per_32bit_reg >= 4) { + detail::VectorizedConverter::convert(result, source); + } else { + detail::VectorizedConverter::convert(result, source); + } + + return result; + } +}; + +// for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + using RegArray = + cutlass::AlignedArray; + RegArray r; + + // Below constructs the following temporary: + // fp16s_01 = {0x00, i4_01, 0x00, i4_01} + // fp16s_23 = {0x00, i4_23, 0x00, i4_23} + // fp16s_45 = {0x00, i4_45, 0x00, i4_45} + // fp16s_67 = {0x00, i4_67, 0x00, i4_67} + // We use inline asm instead of __byte_perm intrinsic since we don't want + // the documented (& 0x7) on the index. NVCC might be able to optimize it + // out since the index is a constexpr, but we choose to be safe about it + // here. + uint32_t prmt_indices[4] = {0x4040, 0x4141, 0x4242, 0x4343}; + static_assert(RegArray::kElements <= 4, + "Too many inputs for F16 -> I4 vector converter"); + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + asm volatile( + "{\n" + " prmt.b32 %0, %1, %2, %3;\n" + "}\n" + : "=r"(r[ii]) + : "r"(src), "n"(0), "r"(prmt_indices[ii])); + } + + // Since the stored 4bit values are biased by 8 we get stored_val = (x+8) + // we are trying to construct x and a fp16 value + // The below XOR does the following: + // 1) Sets the exponent bits of the FP16 to the correct value for the + // FP16 magic_num. We will be constructing {1024+16*(x1+8), 1024+(x0+8)}, + // where x1 in the high nibble and x0 is the low nibble then using hfma + // to subtract 1032 from that + // The AND does the following: + // 1) Clear the set bits for the int4 we will ignore. + // We use lop3 so that we can use 1 instruction for AND and XOR. + static constexpr uint32_t xor_mask = 0x64006400; + static constexpr uint32_t and_mask = 0xFFF0FF0F; + static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa; + + // For each operand, computes: + // r[i] = (r[i] & and_mask) ^ xor_mask + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii]) + : "n"(and_mask), "n"(xor_mask), "n"(immLut)); + } + + // We will issue 2 hfmas that do the following: + // {x1, x0} = {1024+16*(x1+8), 1024+(x0+8)} * {1/16, 1} - {72, 1032} + // = {x1 + 1152, x0 + 1032} * {1/16, 1} - {72, 1032} + static constexpr uint32_t hfma_bias_rep = 0xD480E408; // {72, 1032} + static constexpr uint32_t hfma_scale_rep = 0x2C003C00; // {1 / 16, 1} + + const half2& hfma_bias = reinterpret_cast(hfma_bias_rep); + const half2& hfma_scale = reinterpret_cast(hfma_scale_rep); + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]); + fp16x2_val = __hfma2(hfma_scale, fp16x2_val, hfma_bias); + } + + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +// for IlvdLayout: (2, 4):(4, 1) +template +struct InterleavedNumericArrayConverter, Stride<_4, _1>>, + cutlass::half_t, vllm_uint4b8_t, N, + Round, void> { + using IlvdLayout = Layout, Stride<_4, _1>>; + static_assert(N % size(IlvdLayout{}) == 0); + + using result_type = Array; + using source_type = Array; + + static FloatRoundStyle const round_style = Round; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + using RegArray = + cutlass::AlignedArray; + RegArray r; + + static_assert(PackedResultType::kElements <= size(IlvdLayout{})); + static constexpr uint32_t xor_mask = 0x64006400; + + for (int ii = 0; ii < RegArray::kElements; ii += 2) { + auto src_ = src >> (4 * (ii)); + r[ii + 0] = src_; + r[ii + 1] = src_; + + static constexpr uint32_t and_xor_imm_lut = (0xf0 & 0xcc) ^ 0xaa; + + static constexpr uint32_t low_nib_mask = 0x000F000F; + static constexpr uint32_t high_nib_mask = 0x00F000F0; + + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii + 0]) + : "n"(low_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut)); + + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii + 1]) + : "n"(high_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut)); + + // For low nibble: + // {x1, x0} = {1024+(x1+8), 1024+(x0+8)} * {1, 1} - {1032, 1032} + // For high nibble: + // {x1, x0} = {1024+16*(x1+8), 1024+16*(x0+8)} * {1/16, 1/16} + // - {72, 72} + static constexpr uint32_t low_nib_bias = 0x64086408; // {1032, 1032} + static constexpr uint32_t high_nib_scale = 0x2C002C00; // {1/16, 1/16} + static constexpr uint32_t high_nib_bias = 0xD480D480; // {-72, -72} + + { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 0]); + fp16x2_val = + __hsub2(fp16x2_val, reinterpret_cast(low_nib_bias)); + } + + { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 1]); + fp16x2_val = __hfma2(fp16x2_val, + reinterpret_cast(high_nib_scale), + reinterpret_cast(high_nib_bias)); + } + } + + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +// for IlvdLayout: (2, 4):(4, 1) +template +struct InterleavedNumericArrayConverter, Stride<_4, _1>>, + cutlass::half_t, uint4_t, N, Round, + void> { + using IlvdLayout = Layout, Stride<_4, _1>>; + static_assert(N % size(IlvdLayout{}) == 0); + + using result_type = Array; + using source_type = Array; + + static FloatRoundStyle const round_style = Round; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + using RegArray = + cutlass::AlignedArray; + RegArray r; + + static_assert(PackedResultType::kElements <= size(IlvdLayout{})); + static constexpr uint32_t xor_mask = 0x64006400; + + for (int ii = 0; ii < RegArray::kElements; ii += 2) { + auto src_ = src >> (4 * (ii)); + r[ii + 0] = src_; + r[ii + 1] = src_; + + static constexpr uint32_t and_xor_imm_lut = (0xf0 & 0xcc) ^ 0xaa; + + static constexpr uint32_t low_nib_mask = 0x000F000F; + static constexpr uint32_t high_nib_mask = 0x00F000F0; + + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii + 0]) + : "n"(low_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut)); + + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii + 1]) + : "n"(high_nib_mask), "n"(xor_mask), "n"(and_xor_imm_lut)); + + // For low nibble: + // {x1, x0} = {1024+x1, 1024+x0} - {1024, 1024} + // For high nibble: + // {x1, x0} = {1024+16*x1, 1024+16*x0} * {1/16, 1/16} - {64, 64} + static constexpr uint32_t low_nib_bias = 0x64006400; // {1024, 1024} + static constexpr uint32_t high_nib_scale = 0x2C002C00; // {1/16, 1/16} + static constexpr uint32_t high_nib_bias = 0xD400D400; // {-64, -64} + + { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 0]); + fp16x2_val = + __hsub2(fp16x2_val, reinterpret_cast(low_nib_bias)); + } + + { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 1]); + fp16x2_val = __hfma2(fp16x2_val, + reinterpret_cast(high_nib_scale), + reinterpret_cast(high_nib_bias)); + } + } + + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + // Hold output FP16s in reg. We need 1 reg for every 2 elements + using RegArray = + cutlass::AlignedArray; + RegArray r; + + uint32_t const prmt_indices[2] = {0x5150, 0x5352}; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + for (int ii = 0; ii < RegArray::kElements; ++ii) { + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(r[ii]) + : "r"(src), "n"(start_byte_for_fp16), + "r"(prmt_indices[ii])); + } + + // -128 is folded into bias subtraction, i.e. the 0x80 in the low bytes + static constexpr uint32_t bias_rep = 0x64806480; + const half2& bias = reinterpret_cast(bias_rep); + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]); + fp16x2_val = __hsub2(fp16x2_val, bias); + } + + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + PackedResultType r; + + // __byte_perm simulates the add.u32 0x4B000000 to every u8 element of + // u8x4 source and stores the result in r (without introducing extra + // cvt.u32.u8 instruction) + uint32_t const prmt_indices[4] = {0x7650, 0x7651, 0x7652, 0x7653}; + uint32_t* result_as_int = reinterpret_cast(&r); + for (int ii = 0; ii < PackedResultType::kElements; ++ii) { + result_as_int[ii] = __byte_perm(src, 0x4B000000, prmt_indices[ii]); + // Subtract the magic number 0x4B000000 from tmp in floating-point + // arithmetic to obtain final result + r[ii] -= (8388608.f + 128.f); // fold in -128 bias + } + + return r; + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + +// for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + + static FloatRoundStyle const round_style = Round; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(uint32_t src_reg) { + // Hold output BF16s in reg. We need 1 reg for every 2 elements + using RegArray = + cutlass::AlignedArray; + RegArray r; + uint32_t src_reg_shifted = src_reg >> 4; + + // Below constructs the following temporary: + uint32_t const prmt_indices[4] = {0xF4F0, 0xF5F1, 0xF6F2, 0xF7F3}; + static_assert(RegArray::kElements <= 4, + "Too many inputs for uint4b8_t -> BF16 vector converter"); + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + asm volatile( + "{\n" + " prmt.b32 %0, %1, %2, %3;\n" + "}\n" + : "=r"(r[ii]) + : "r"(src_reg), "r"(src_reg_shifted), "r"(prmt_indices[ii])); + } + + // Since the stored 4bit values are biased by 8 we get stored_val = (x+8) + // we are trying to construct x and a BF16 value + // The below XOR does the following: + // 1) Sets the exponent bits of the BF16 to the correct value for the + // BF16 magic_num. We will be constructing {128 + (x1+8), 128 + (x0+8)} + // and subtracting 136 to get {x1, x0} + static constexpr uint32_t xor_mask = 0x43004300; + static constexpr uint32_t and_mask = 0x000F000F; + static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa; + + // For each operand, computes: + // r[i] = (r[i] & and_mask) ^ xor_mask + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii]) + : "n"(and_mask), "n"(xor_mask), "n"(immLut)); + } + + // We will issue 2 bfmas that do the following: + // high BF16: + // hi_bf16 - 136, lo_bf16 - 136 + + // This is the BF16 {136, 136} represented as an integer. + static constexpr uint32_t bias_rep = 0x43084308; + const __nv_bfloat162& bias = + reinterpret_cast(bias_rep); + + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + __nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]); + bf16x2_val = __hsub2(bf16x2_val, bias); + } + + return reinterpret_cast(r); + } + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +// for IlvdLayout: (2, 4):(4, 1) +template +struct InterleavedNumericArrayConverter, Stride<_4, _1>>, + cutlass::bfloat16_t, vllm_uint4b8_t, N, + Round, void> { + using IlvdLayout = Layout, Stride<_4, _1>>; + static_assert(N % size(IlvdLayout{}) == 0); + + using result_type = Array; + using source_type = Array; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + using RegArray = + cutlass::AlignedArray; + RegArray r; + + static_assert(PackedResultType::kElements <= size(IlvdLayout{})); + static constexpr uint32_t or_mask = 0x43004300; + + // Unlike float16 where the mantissa is large enough to contain 2 + // nibbles, bfloat16 can only fit one, so we can only convert one + // nibble at a time + for (int ii = 0; ii < RegArray::kElements; ++ii) { + r[ii] = src >> (4 * ii); + + static constexpr uint32_t and_or_imm_lut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t low_nib_mask = 0x000F000F; + + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii + 0]) + : "n"(low_nib_mask), "n"(or_mask), "n"(and_or_imm_lut)); + + // For low nibble: + // {x1, x0} = {128+(x1+8), 128+(x0+8)} * {1, 1} - {136, 136} + static constexpr uint32_t low_nib_bias = 0x43084308; // {136, 136} + + { + __nv_bfloat162& fp16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]); + fp16x2_val = + __hsub2(fp16x2_val, + reinterpret_cast(low_nib_bias)); + } + } + + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +// for IlvdLayout: (2, 4):(4, 1) +template +struct InterleavedNumericArrayConverter, Stride<_4, _1>>, + cutlass::bfloat16_t, uint4_t, N, Round, + void> { + using IlvdLayout = Layout, Stride<_4, _1>>; + static_assert(N % size(IlvdLayout{}) == 0); + + using result_type = Array; + using source_type = Array; + + private: + struct RegConvert { + template + CUTLASS_DEVICE static PackedResultType convert(uint32_t src) { + using RegArray = + cutlass::AlignedArray; + RegArray r; + + static_assert(PackedResultType::kElements <= size(IlvdLayout{})); + static constexpr uint32_t or_mask = 0x43004300; + + // Unlike float16 where the mantissa is large enough to contain 2 + // nibbles, bfloat16 can only fit one, so we can only convert one + // nibble at a time + for (int ii = 0; ii < RegArray::kElements; ++ii) { + r[ii] = src >> (4 * ii); + + static constexpr uint32_t and_or_imm_lut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t low_nib_mask = 0x000F000F; + + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii]) + : "n"(low_nib_mask), "n"(or_mask), "n"(and_or_imm_lut)); + + // For low nibble: + // {x1, x0} = {128 + x1, 128 + x0} * {1, 1} - {128, 128} + static constexpr uint32_t low_nib_bias = 0x43004300; // {128, 128} + + { + __nv_bfloat162& fp16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]); + fp16x2_val = + __hsub2(fp16x2_val, + reinterpret_cast(low_nib_bias)); + } + } + + return reinterpret_cast(r); + }; + }; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + return ArrayConverterPacked32Bit::convert(source); + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +// for Array <= Array +template +struct NumericArrayConverter { + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + private: + using result_packed_4_t = Array; + using result_packed_2_t = Array; + using src_packed_4_t = Array; + using src_packed_2_t = Array; + + // Not Valid, not supported, only here to satisfy the interface and to avoid + // a compile error. ScalarConverter will not actually work until + // NumericConverter is + // implemented + using ScalarConverter = + NumericConverter; + + template + CUTLASS_DEVICE static PackedResultType packed_convert( + PackedSrcType const& source) { + static_assert( + (platform::is_same::value && + platform::is_same::value) || + (platform::is_same::value && + platform::is_same::value), + "Invalid PackedSrcType/PackedResultType must be 2 or 4 to use private " + "convert dispatch."); + + NumericArrayConverter + convert_uint8_to_f32; + Array tmp = + convert_uint8_to_f32(source); + NumericArrayConverter + convert_f32_to_bf16_; + return convert_f32_to_bf16_(tmp); + } + + friend class detail::VectorizedConverter; + + public: + CUTLASS_DEVICE + static result_type convert(source_type const& source) { + result_type result; + using ConverterType = + NumericArrayConverter; + detail::VectorizedConverter::convert(result, source); + + return result; + } + + CUTLASS_DEVICE + result_type operator()(source_type const& s) const { return convert(s); } +}; + +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index 1f8e2dcbca6d07ca5da760aa051e21a5b99fb00f..fc6bd2770cfeef97d02d001bf5fd8ac17df8159a 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -3,13 +3,16 @@ #include #include "dispatch_utils.h" -#include "reduction_utils.cuh" #ifndef USE_ROCM #include #include + #include + #include #else #include #include + #include + #include using __nv_bfloat16 = __hip_bfloat16; using __nv_bfloat162 = __hip_bfloat162; @@ -31,7 +34,11 @@ __global__ void rms_norm_kernel( const float x = (float)input[blockIdx.x * hidden_size + idx]; variance += x * x; } - variance = blockReduceSum(variance); + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); + if (threadIdx.x == 0) { s_variance = rsqrtf(variance / hidden_size + epsilon); } @@ -228,12 +235,11 @@ fused_add_rms_norm_kernel( variance += temp.sum_squares(); residual_v[id] = temp; } - /* Keep the following if-else block in sync with the - calculation of max_block_size in fused_add_rms_norm */ - if (num_tokens < 256) { - variance = blockReduceSum(variance); - } else - variance = blockReduceSum(variance); + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); + if (threadIdx.x == 0) { s_variance = rsqrtf(variance / hidden_size + epsilon); } @@ -268,12 +274,11 @@ fused_add_rms_norm_kernel( variance += x * x; residual[blockIdx.x * hidden_size + idx] = z; } - /* Keep the following if-else block in sync with the - calculation of max_block_size in fused_add_rms_norm */ - if (num_tokens < 256) { - variance = blockReduceSum(variance); - } else - variance = blockReduceSum(variance); + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); + if (threadIdx.x == 0) { s_variance = rsqrtf(variance / hidden_size + epsilon); } diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.cu b/csrc/mamba/causal_conv1d/causal_conv1d.cu new file mode 100644 index 0000000000000000000000000000000000000000..3a464c5f327ad24d60b286a17affbbacd35eaf35 --- /dev/null +++ b/csrc/mamba/causal_conv1d/causal_conv1d.cu @@ -0,0 +1,632 @@ +// clang-format off +// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_fwd.cu +// and https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_update.cu +#include +#include +#include + +#include "causal_conv1d.h" +#include +#include +#include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK + +#include +#include + +#include "static_switch.h" + + + +#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") + +#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \ + if (ITYPE == at::ScalarType::Half) { \ + using input_t = at::Half; \ + using weight_t = at::Half; \ + __VA_ARGS__(); \ + } else if (ITYPE == at::ScalarType::BFloat16) { \ + using input_t = at::BFloat16; \ + using weight_t = at::BFloat16; \ + __VA_ARGS__(); \ + } else if (ITYPE == at::ScalarType::Float) { \ + using input_t = float; \ + using weight_t = float; \ + __VA_ARGS__(); \ + } else { \ + AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \ + } + + +template +void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); + +template +void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); + +void set_conv_params_fwd(ConvParamsBase ¶ms, + // sizes + const size_t batch, + const size_t dim, + const size_t seqlen, + const size_t width, + // device pointers + const at::Tensor x, + const at::Tensor weight, + const at::Tensor out, + const c10::optional& bias, + bool silu_activation, + int64_t pad_slot_id, + const c10::optional& query_start_loc = std::nullopt, + const c10::optional& cache_indices = std::nullopt, + const c10::optional& has_initial_state = std::nullopt) { + + // Reset the parameters + memset(¶ms, 0, sizeof(params)); + + params.batch = batch; + params.dim = dim; + params.seqlen = seqlen; + params.width = width; + params.pad_slot_id = pad_slot_id; + + params.silu_activation = silu_activation; + + // Set the pointers and strides. + params.x_ptr = x.data_ptr(); + params.weight_ptr = weight.data_ptr(); + params.bias_ptr = bias.has_value() ? bias.value().data_ptr() : nullptr; + params.out_ptr = out.data_ptr(); + // All stride are in elements, not bytes. + params.query_start_loc_ptr = query_start_loc.has_value() ? query_start_loc.value().data_ptr() : nullptr; + params.cache_indices_ptr = cache_indices.has_value() ? cache_indices.value().data_ptr() : nullptr; + params.has_initial_state_ptr = has_initial_state.has_value() ? has_initial_state.value().data_ptr() : nullptr; + const bool varlen = params.query_start_loc_ptr != nullptr; + params.x_batch_stride = x.stride(varlen ? 1 : 0); + params.x_c_stride = x.stride(varlen ? 0 : 1); + params.x_l_stride = x.stride(varlen ? 1 : -1); + params.weight_c_stride = weight.stride(0); + params.weight_width_stride = weight.stride(1); + params.out_batch_stride = out.stride(varlen ? 1 : 0); + params.out_c_stride = out.stride(varlen ? 0 : 1); + params.out_l_stride = out.stride(varlen ? 1 : -1); +} + + +void causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight, + const c10::optional &bias_, + const c10::optional &conv_states, + const c10::optional &query_start_loc, + const c10::optional &cache_indices, + const c10::optional &has_initial_state, + bool silu_activation, + // used to identify padding entries if cache_indices provided + // in case of padding, the kernel will return early + int64_t pad_slot_id) { + auto input_type = x.scalar_type(); + auto weight_type = weight.scalar_type(); + TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); + TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16); + + TORCH_CHECK(x.is_cuda()); + TORCH_CHECK(weight.is_cuda()); + + const bool varlen = query_start_loc.has_value() ? true : false; + const auto sizes = x.sizes(); + const int batch_size = varlen ? query_start_loc.value().sizes()[0] - 1 : sizes[0]; + const int dim = varlen ? sizes[0] : sizes[1]; + const int seqlen = varlen ? sizes[1] : sizes[2]; + const int width = weight.size(-1); + if (varlen){ + CHECK_SHAPE(x, dim, seqlen); + } + else { + CHECK_SHAPE(x, batch_size, dim, seqlen); + } + CHECK_SHAPE(weight, dim, width); + + + + if (bias_.has_value()) { + auto bias = bias_.value(); + TORCH_CHECK(bias.scalar_type() == weight_type); + TORCH_CHECK(bias.is_cuda()); + TORCH_CHECK(bias.stride(-1) == 1); + CHECK_SHAPE(bias, dim); + } + + + if (has_initial_state.has_value()) { + auto has_initial_state_ = has_initial_state.value(); + TORCH_CHECK(has_initial_state_.scalar_type() == at::ScalarType::Bool); + TORCH_CHECK(has_initial_state_.is_cuda()); + CHECK_SHAPE(has_initial_state_, batch_size); + } + + + if (query_start_loc.has_value()) { + auto query_start_loc_ = query_start_loc.value(); + TORCH_CHECK(query_start_loc_.scalar_type() == at::ScalarType::Int); + TORCH_CHECK(query_start_loc_.is_cuda()); + } + + + if (cache_indices.has_value()) { + auto cache_indices_ = cache_indices.value(); + TORCH_CHECK(cache_indices_.scalar_type() == at::ScalarType::Int); + TORCH_CHECK(cache_indices_.is_cuda()); + CHECK_SHAPE(cache_indices_, batch_size); + } + + at::Tensor out = x; + + ConvParamsBase params; + set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out, + bias_, + silu_activation, + pad_slot_id, + query_start_loc, + cache_indices, + has_initial_state + ); + + if (conv_states.has_value()) { + auto conv_states_ = conv_states.value(); + TORCH_CHECK(conv_states_.scalar_type() == input_type); + TORCH_CHECK(conv_states_.is_cuda()); + params.conv_states_ptr = conv_states_.data_ptr(); + params.conv_states_batch_stride = conv_states_.stride(0); + params.conv_states_c_stride = conv_states_.stride(1); + params.conv_states_l_stride = conv_states_.stride(2); + } else { + params.conv_states_ptr = nullptr; + } + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)x.get_device()}; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] { + causal_conv1d_fwd_cuda(params, stream); + }); +} + + +void causal_conv1d_update(const at::Tensor &x, + const at::Tensor &conv_state, + const at::Tensor &weight, + const c10::optional &bias_, + bool silu_activation, + const c10::optional &cache_seqlens_, + const c10::optional &conv_state_indices_, + // used to identify padding entries if cache_indices provided + // in case of padding, the kernel will return early + int64_t pad_slot_id) { + auto input_type = x.scalar_type(); + auto weight_type = weight.scalar_type(); + TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); + TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16); + TORCH_CHECK(weight_type == input_type, "weight type must equal to input type, other variations are disabled due to binary size limitations"); + TORCH_CHECK(conv_state.scalar_type() == input_type); + + TORCH_CHECK(x.is_cuda()); + TORCH_CHECK(conv_state.is_cuda()); + TORCH_CHECK(weight.is_cuda()); + + const auto sizes = x.sizes(); + const int batch_size = sizes[0]; + const int dim = sizes[1]; + const int seqlen = sizes[2]; + const int width = weight.size(-1); + const int conv_state_len = conv_state.size(2); + TORCH_CHECK(conv_state_len >= width - 1); + + CHECK_SHAPE(x, batch_size, dim, seqlen); + CHECK_SHAPE(weight, dim, width); + + TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4"); + + if (bias_.has_value()) { + auto bias = bias_.value(); + TORCH_CHECK(bias.scalar_type() == weight_type); + TORCH_CHECK(bias.is_cuda()); + TORCH_CHECK(bias.stride(-1) == 1); + CHECK_SHAPE(bias, dim); + } + + at::Tensor out = x; + + ConvParamsBase params; + set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out, + bias_, + silu_activation, + pad_slot_id); + params.conv_state_ptr = conv_state.data_ptr(); + params.conv_state_len = conv_state_len; + // All stride are in elements, not bytes. + params.conv_state_batch_stride = conv_state.stride(0); + params.conv_state_c_stride = conv_state.stride(1); + params.conv_state_l_stride = conv_state.stride(2); + + if (cache_seqlens_.has_value()) { + auto cache_seqlens = cache_seqlens_.value(); + TORCH_CHECK(cache_seqlens.scalar_type() == torch::kInt32); + TORCH_CHECK(cache_seqlens.is_cuda()); + TORCH_CHECK(cache_seqlens.stride(-1) == 1); + CHECK_SHAPE(cache_seqlens, batch_size); + params.cache_seqlens = cache_seqlens.data_ptr(); + } else { + params.cache_seqlens = nullptr; + } + + if (conv_state_indices_.has_value()) { + auto conv_state_indices = conv_state_indices_.value(); + TORCH_CHECK(conv_state_indices.scalar_type() == torch::kInt32) + TORCH_CHECK(conv_state_indices.is_cuda()); + TORCH_CHECK(conv_state_indices.stride(0) == 1) + CHECK_SHAPE(conv_state_indices, batch_size); + + int conv_state_entries = conv_state.size(0); + CHECK_SHAPE(conv_state, conv_state_entries, dim, conv_state_len); + + params.conv_state_indices_ptr = conv_state_indices.data_ptr(); + } else { + CHECK_SHAPE(conv_state, batch_size, dim, conv_state_len); + params.conv_state_indices_ptr = nullptr; + } + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)x.get_device()}; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] { + causal_conv1d_update_cuda(params, stream); + }); +} + +template +struct Causal_conv1d_fwd_kernel_traits { + using input_t = input_t_; + using weight_t = weight_t_; + static constexpr int kNThreads = kNThreads_; + static constexpr int kWidth = kWidth_; + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || kNBytes == 4); + static constexpr int kNElts = kNBytes == 4 ? 4 : 8; + static_assert(kWidth <= kNElts); + static constexpr bool kIsVecLoad = kIsVecLoad_; + using vec_t = typename BytesToType::Type; + using BlockLoadT = cub::BlockLoad; + using BlockLoadVecT = cub::BlockLoad; + using BlockStoreT = cub::BlockStore; + using BlockStoreVecT = cub::BlockStore; + static constexpr int kSmemIOSize = kIsVecLoad + ? 0 + : custom_max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)}); + static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts; + static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize; +}; + +template +__global__ __launch_bounds__(Ktraits::kNThreads) +void causal_conv1d_fwd_kernel(ConvParamsBase params) { + constexpr int kWidth = Ktraits::kWidth; + constexpr int kNThreads = Ktraits::kNThreads; + constexpr int kNElts = Ktraits::kNElts; + constexpr bool kIsVecLoad = Ktraits::kIsVecLoad; + using input_t = typename Ktraits::input_t; + using vec_t = typename Ktraits::vec_t; + using weight_t = typename Ktraits::weight_t; + + // Shared memory. + extern __shared__ char smem_[]; + auto& smem_load = reinterpret_cast(smem_); + auto& smem_load_vec = reinterpret_cast(smem_); + auto& smem_store = reinterpret_cast(smem_); + auto& smem_store_vec = reinterpret_cast(smem_); + vec_t *smem_exchange = reinterpret_cast(smem_ + Ktraits::kSmemIOSize); + + const bool kVarlen = params.query_start_loc_ptr != nullptr; + const int tidx = threadIdx.x; + const int batch_id = blockIdx.x; + const int channel_id = blockIdx.y; + const int *query_start_loc = kVarlen ? reinterpret_cast(params.query_start_loc_ptr) : nullptr; + const int sequence_start_index = kVarlen ? query_start_loc[batch_id] : batch_id; + const int seqlen = kVarlen ? query_start_loc[batch_id + 1] - sequence_start_index : params.seqlen; + + input_t *x = reinterpret_cast(params.x_ptr) + sequence_start_index * params.x_batch_stride + + channel_id * params.x_c_stride; + weight_t *weight = reinterpret_cast(params.weight_ptr) + channel_id * params.weight_c_stride; + input_t *out = reinterpret_cast(params.out_ptr) + sequence_start_index * params.out_batch_stride + + channel_id * params.out_c_stride; + float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast(params.bias_ptr)[channel_id]); + + bool has_initial_state = params.has_initial_state_ptr == nullptr ? false + : reinterpret_cast(params.has_initial_state_ptr)[batch_id]; + + int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr + : reinterpret_cast(params.cache_indices_ptr); + int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id]; + // cache_index == params.pad_slot_id is defined as padding, so we exit early + if (cache_index == params.pad_slot_id){ + return; + } + input_t *conv_states = params.conv_states_ptr == nullptr ? nullptr + : reinterpret_cast(params.conv_states_ptr) + cache_index * params.conv_states_batch_stride + channel_id * params.conv_states_c_stride; + + // Thread 0 will load the last elements of the previous chunk, so we initialize those to 0. + if (tidx == 0) { + input_t initial_state[kNElts] = {0}; + if (has_initial_state) { + #pragma unroll + for (int w = 0; w < kWidth - 1; ++w){ initial_state[kNElts - 1 - (kWidth - 2) + w ] = conv_states[w]; } + } + smem_exchange[kNThreads - 1] = reinterpret_cast(initial_state)[0]; + } + + float weight_vals[kWidth]; + #pragma unroll + for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); } + + constexpr int kChunkSize = kNThreads * kNElts; + const int n_chunks = (seqlen + kChunkSize - 1) / kChunkSize; + for (int chunk = 0; chunk < n_chunks; ++chunk) { + input_t x_vals_load[2 * kNElts] = {0}; + if constexpr(kIsVecLoad) { + typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast(x), *reinterpret_cast(&x_vals_load[kNElts]), (seqlen - chunk * kChunkSize) / kNElts); + } else { + __syncthreads(); + typename Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast(&x_vals_load[kNElts]), seqlen - chunk * kChunkSize); + } + x += kChunkSize; + __syncthreads(); + // Thread kNThreads - 1 don't write yet, so that thread 0 can read + // the last elements of the previous chunk. + if (tidx < kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast(x_vals_load)[1]; } + __syncthreads(); + reinterpret_cast(x_vals_load)[0] = smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1]; + __syncthreads(); + // Now thread kNThreads - 1 can write the last elements of the current chunk. + if (tidx == kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast(x_vals_load)[1]; } + + float x_vals[2 * kNElts]; + #pragma unroll + for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); } + + float out_vals[kNElts]; + #pragma unroll + for (int i = 0; i < kNElts; ++i) { + out_vals[i] = bias_val; + #pragma unroll + for (int w = 0; w < kWidth; ++w) { + out_vals[i] += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)]; + } + } + + if (params.silu_activation) { + #pragma unroll + for (int i = 0; i < kNElts; ++i) { + out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i])); + } + } + + input_t out_vals_store[kNElts]; + #pragma unroll + for (int i = 0; i < kNElts; ++i) { out_vals_store[i] = out_vals[i]; } + if constexpr(kIsVecLoad) { + typename Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast(out), reinterpret_cast(out_vals_store), (seqlen - chunk * kChunkSize) / kNElts); + } else { + typename Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, seqlen - chunk * kChunkSize); + } + out += kChunkSize; + } + // Final state is stored in the smem_exchange last token slot, + // in case seqlen < kWidth, we would need to take the final state from the + // initial state which is stored in conv_states + // in case seqlen > kWidth, we would need to load the last kWidth - 1 data + // and load it into conv_state accordingly + int last_thread = ((seqlen - (kWidth - 1)) - (n_chunks - 1) * kChunkSize) / kNElts; + if (conv_states != nullptr && tidx == last_thread) { + input_t x_vals_load[kNElts * 2] = {0}; + // in case we are on the first kWidth tokens + if (last_thread == 0 && seqlen < kWidth){ + // Need to take the initial state + reinterpret_cast(x_vals_load)[0] = smem_exchange[0]; + const int offset = seqlen - (kWidth - 1); + #pragma unroll + for (int w = 0; w < kWidth - 1; ++w){ + // pad the existing state + if ((w - seqlen) >= 0 && has_initial_state) { conv_states[w - seqlen] = conv_states[w]; } + else if ((w - seqlen) >= 0 && !has_initial_state) { conv_states[w - seqlen] = input_t(0.0f); } + } + #pragma unroll + for (int w = 0; w < kWidth - 1; ++w){ + if (offset + w >= 0) + conv_states[w] = x_vals_load[offset + w ]; + } + } + else { + // in case the final state is in between the threads data + reinterpret_cast(x_vals_load)[1] = smem_exchange[last_thread + 1]; + reinterpret_cast(x_vals_load)[0] = smem_exchange[last_thread]; + const int offset = ((seqlen - (kWidth - 1)) % (kNElts)); + #pragma unroll + for (int w = 0; w < kWidth - 1; ++w){ + conv_states[w] = x_vals_load[offset + w ]; + } + } + + } +} + + +template +void causal_conv1d_fwd_launch(ConvParamsBase ¶ms, cudaStream_t stream) { + static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8; + const bool kVarlen = params.query_start_loc_ptr != nullptr; + BOOL_SWITCH(params.seqlen % kNElts == 0 && !kVarlen, kIsVecLoad, [&] { + using Ktraits = Causal_conv1d_fwd_kernel_traits; + constexpr int kSmemSize = Ktraits::kSmemSize; + dim3 grid(params.batch, params.dim); + + auto kernel = &causal_conv1d_fwd_kernel; + + if (kSmemSize >= 48 * 1024) { + #ifndef USE_ROCM + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + #else + // There is a slight signature discrepancy in HIP and CUDA "FuncSetAttribute" function. + C10_CUDA_CHECK(cudaFuncSetAttribute( + (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + std::cerr << "Warning (causal_conv1d fwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl; + #endif + } + kernel<<>>(params); + + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); +} + +template +void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream) { + if (params.width == 2) { + causal_conv1d_fwd_launch<128, 2, input_t, weight_t>(params, stream); + } else if (params.width == 3) { + causal_conv1d_fwd_launch<128, 3, input_t, weight_t>(params, stream); + } else if (params.width == 4) { + causal_conv1d_fwd_launch<128, 4, input_t, weight_t>(params, stream); + } +} + + +template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); + + + + +template +struct Causal_conv1d_update_kernel_traits { + using input_t = input_t_; + using weight_t = weight_t_; + static constexpr int kNThreads = kNThreads_; + static constexpr int kWidth = kWidth_; + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || kNBytes == 4); +}; + +template +__global__ __launch_bounds__(Ktraits::kNThreads) +void causal_conv1d_update_kernel(ConvParamsBase params) { + constexpr int kWidth = Ktraits::kWidth; + constexpr int kNThreads = Ktraits::kNThreads; + using input_t = typename Ktraits::input_t; + using weight_t = typename Ktraits::weight_t; + + const int tidx = threadIdx.x; + const int batch_id = blockIdx.x; + const int channel_id = blockIdx.y * kNThreads + tidx; + if (channel_id >= params.dim) return; + + input_t *x = reinterpret_cast(params.x_ptr) + batch_id * params.x_batch_stride + + channel_id * params.x_c_stride; + + // If params.conv_state_batch_indices is set, then the conv state is gathered from the conv state tensor + // along the batch axis. Otherwise, the conv state coordinate is the same as the batch id. + const int conv_state_batch_coord = params.conv_state_indices_ptr == nullptr + ? batch_id + : params.conv_state_indices_ptr[batch_id]; + // conv_state_batch_coord == params.pad_slot_id is defined as padding so we exit early + if (conv_state_batch_coord == params.pad_slot_id){ + return; + } + input_t *conv_state = reinterpret_cast(params.conv_state_ptr) + + conv_state_batch_coord * params.conv_state_batch_stride + + channel_id * params.conv_state_c_stride; + + weight_t *weight = reinterpret_cast(params.weight_ptr) + channel_id * params.weight_c_stride; + input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride + + channel_id * params.out_c_stride; + float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast(params.bias_ptr)[channel_id]); + + int state_len = params.conv_state_len; + int advance_len = params.seqlen; + int cache_seqlen = kIsCircularBuffer ? params.cache_seqlens[batch_id] % state_len : 0; + int update_idx = cache_seqlen - (kWidth - 1); + update_idx = update_idx < 0 ? update_idx + state_len : update_idx; + + float weight_vals[kWidth] = {0}; + #pragma unroll + for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); } + + float x_vals[kWidth] = {0}; + if constexpr (!kIsCircularBuffer) { + #pragma unroll 2 + for (int i = 0; i < state_len - advance_len - (kWidth - 1); ++i) { + conv_state[i * params.conv_state_l_stride] = conv_state[(i + advance_len) * params.conv_state_l_stride]; + } + #pragma unroll + for (int i = 0; i < kWidth - 1; ++i) { + input_t state_val = conv_state[(state_len - (kWidth - 1) + i) * params.conv_state_l_stride]; + if (i < advance_len + (kWidth - 1) && state_len - advance_len - (kWidth - 1) + i >= 0) { + conv_state[(state_len - advance_len - (kWidth - 1) + i) * params.conv_state_l_stride] = state_val; + } + x_vals[i] = float(state_val); + } + } else { + #pragma unroll + for (int i = 0; i < kWidth - 1; ++i, update_idx = update_idx + 1 >= state_len ? update_idx + 1 - state_len : update_idx + 1) { + input_t state_val = conv_state[update_idx * params.conv_state_l_stride]; + x_vals[i] = float(state_val); + } + } + #pragma unroll 2 + for (int i = 0; i < params.seqlen; ++i) { + input_t x_val = x[i * params.x_l_stride]; + if constexpr (!kIsCircularBuffer) { + if (i < advance_len && state_len - advance_len + i >= 0) { + conv_state[(state_len - advance_len + i) * params.conv_state_l_stride] = x_val; + } + } else { + conv_state[update_idx * params.conv_state_l_stride] = x_val; + ++update_idx; + update_idx = update_idx >= state_len ? update_idx - state_len : update_idx; + } + x_vals[kWidth - 1] = float(x_val); + float out_val = bias_val; + #pragma unroll + for (int j = 0; j < kWidth; ++j) { out_val += weight_vals[j] * x_vals[j]; } + if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); } + out[i * params.out_l_stride] = input_t(out_val); + // Shift the input buffer by 1 + #pragma unroll + for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = x_vals[i + 1]; } + } +} + +template +void causal_conv1d_update_launch(ConvParamsBase ¶ms, cudaStream_t stream) { + using Ktraits = Causal_conv1d_update_kernel_traits; + dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads); + auto kernel = params.cache_seqlens == nullptr + ? &causal_conv1d_update_kernel + : &causal_conv1d_update_kernel; + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream) { + if (params.width == 2) { + causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream); + } else if (params.width == 3) { + causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream); + } else if (params.width == 4) { + causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream); + } +} + +template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.h b/csrc/mamba/causal_conv1d/causal_conv1d.h new file mode 100644 index 0000000000000000000000000000000000000000..e26684a2b98b8ce6f5b3358c796c099e0edc4315 --- /dev/null +++ b/csrc/mamba/causal_conv1d/causal_conv1d.h @@ -0,0 +1,159 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ +// clang-format off +// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d.h +#pragma once + +#include +#include +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct ConvParamsBase { + using index_t = uint32_t; + + int batch, dim, seqlen, width; + int64_t pad_slot_id; + bool silu_activation; + + index_t x_batch_stride; + index_t x_c_stride; + index_t x_l_stride; + index_t weight_c_stride; + index_t weight_width_stride; + index_t out_batch_stride; + index_t out_c_stride; + index_t out_l_stride; + + int conv_state_len; + index_t conv_state_batch_stride; + index_t conv_state_c_stride; + index_t conv_state_l_stride; + + // Common data pointers. + void *__restrict__ x_ptr; + void *__restrict__ weight_ptr; + void *__restrict__ bias_ptr; + void *__restrict__ out_ptr; + + void *__restrict__ conv_state_ptr; + void *__restrict__ query_start_loc_ptr; + void *__restrict__ has_initial_state_ptr; + void *__restrict__ cache_indices_ptr; + int32_t *__restrict__ cache_seqlens; + + // For the continuous batching case. Makes it so that the mamba state for + // the current batch doesn't need to be a contiguous tensor. + int32_t *__restrict__ conv_state_indices_ptr; + + void *__restrict__ seq_idx_ptr; + + // No __restrict__ since initial_states could be the same as final_states. + void * initial_states_ptr; + index_t initial_states_batch_stride; + index_t initial_states_l_stride; + index_t initial_states_c_stride; + + void * final_states_ptr; + index_t final_states_batch_stride; + index_t final_states_l_stride; + index_t final_states_c_stride; + + void * conv_states_ptr; + index_t conv_states_batch_stride; + index_t conv_states_l_stride; + index_t conv_states_c_stride; +}; + + +#ifndef USE_ROCM + #include + + template + __device__ inline T shuffle_xor(T val, int offset) { + return __shfl_xor_sync(uint32_t(-1), val, offset); + } + + constexpr size_t custom_max(std::initializer_list ilist) + { + return std::max(ilist); + } + + template + constexpr T constexpr_min(T a, T b) { + return std::min(a, b); + } + +#else + #include + + template + __device__ inline T shuffle_xor(T val, int offset) { + return __shfl_xor(val, offset); + } + constexpr size_t custom_max(std::initializer_list ilist) + { + return *std::max_element(ilist.begin(), ilist.end()); + } + + template + constexpr T constexpr_min(T a, T b) { + return a < b ? a : b; + } +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template struct BytesToType {}; + +template<> struct BytesToType<16> { + using Type = uint4; + static_assert(sizeof(Type) == 16); +}; + +template<> struct BytesToType<8> { + using Type = uint64_t; + static_assert(sizeof(Type) == 8); +}; + +template<> struct BytesToType<4> { + using Type = uint32_t; + static_assert(sizeof(Type) == 4); +}; + +template<> struct BytesToType<2> { + using Type = uint16_t; + static_assert(sizeof(Type) == 2); +}; + +template<> struct BytesToType<1> { + using Type = uint8_t; + static_assert(sizeof(Type) == 1); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SumOp { +__device__ inline T operator()(T const & x, T const & y) { return x + y; } +}; + +template +struct Allreduce { + static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); + template + static __device__ inline T run(T x, Operator &op) { + constexpr int OFFSET = THREADS / 2; + x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); + return Allreduce::run(x, op); + } +}; + +template<> +struct Allreduce<2> { +template +static __device__ inline T run(T x, Operator &op) { + x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); + return x; +} +}; diff --git a/csrc/mamba/causal_conv1d/static_switch.h b/csrc/mamba/causal_conv1d/static_switch.h new file mode 100644 index 0000000000000000000000000000000000000000..ef74bf447f84065b63026f73c091c6d2bc8625bb --- /dev/null +++ b/csrc/mamba/causal_conv1d/static_switch.h @@ -0,0 +1,28 @@ +// Inspired by +// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h +// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h +// clang-format off +// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/static_switch.h + +#pragma once + +/// @param COND - a boolean expression to switch by +/// @param CONST_NAME - a name given for the constexpr bool variable. +/// @param ... - code to execute for true and false +/// +/// Usage: +/// ``` +/// BOOL_SWITCH(flag, BoolConst, [&] { +/// some_function(...); +/// }); +/// ``` +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + static constexpr bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + static constexpr bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/csrc/mamba/mamba_ssm/selective_scan.h b/csrc/mamba/mamba_ssm/selective_scan.h new file mode 100644 index 0000000000000000000000000000000000000000..563d2fe4ef65b73aefb639f47d612ce8ff2bfd74 --- /dev/null +++ b/csrc/mamba/mamba_ssm/selective_scan.h @@ -0,0 +1,266 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ +// clang-format off +// adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/selective_scan.h + +#pragma once + +#ifndef USE_ROCM + #include +#else + #include +#endif +#include +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SSMParamsBase { + using index_t = uint32_t; + + int batch, dim, seqlen, dstate, n_groups, n_chunks; + int dim_ngroups_ratio; + bool is_variable_B; + bool is_variable_C; + int64_t pad_slot_id; + + bool delta_softplus; + + index_t A_d_stride; + index_t A_dstate_stride; + index_t B_batch_stride; + index_t B_d_stride; + index_t B_dstate_stride; + index_t B_group_stride; + index_t C_batch_stride; + index_t C_d_stride; + index_t C_dstate_stride; + index_t C_group_stride; + index_t u_batch_stride; + index_t u_d_stride; + index_t delta_batch_stride; + index_t delta_d_stride; + index_t z_batch_stride; + index_t z_d_stride; + index_t out_batch_stride; + index_t out_d_stride; + index_t out_z_batch_stride; + index_t out_z_d_stride; + + // Common data pointers. + void *__restrict__ A_ptr; + void *__restrict__ B_ptr; + void *__restrict__ C_ptr; + void *__restrict__ D_ptr; + void *__restrict__ u_ptr; + void *__restrict__ delta_ptr; + void *__restrict__ delta_bias_ptr; + void *__restrict__ out_ptr; + void *__restrict__ ssm_states_ptr; + void *__restrict__ z_ptr; + void *__restrict__ out_z_ptr; + + void *__restrict__ query_start_loc_ptr; + void *__restrict__ cache_indices_ptr; + void *__restrict__ has_initial_state_ptr; + +}; + + + + +#ifndef USE_ROCM + + constexpr size_t custom_max(std::initializer_list ilist) + { + return std::max(ilist); + } + + template + constexpr T constexpr_min(T a, T b) { + return std::min(a, b); + } + +#else + constexpr size_t custom_max(std::initializer_list ilist) + { + return *std::max_element(ilist.begin(), ilist.end()); + } + + template + constexpr T constexpr_min(T a, T b) { + return a < b ? a : b; + } +#endif + + +#define MAX_DSTATE 256 + + +inline __device__ float2 operator+(const float2 & a, const float2 & b){ + return {a.x + b.x, a.y + b.y}; +} + +inline __device__ float3 operator+(const float3 &a, const float3 &b) { + return {a.x + b.x, a.y + b.y, a.z + b.z}; +} + +inline __device__ float4 operator+(const float4 & a, const float4 & b){ + return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w}; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template struct BytesToType {}; + +template<> struct BytesToType<16> { + using Type = uint4; + static_assert(sizeof(Type) == 16); +}; + +template<> struct BytesToType<8> { + using Type = uint64_t; + static_assert(sizeof(Type) == 8); +}; + +template<> struct BytesToType<4> { + using Type = uint32_t; + static_assert(sizeof(Type) == 4); +}; + +template<> struct BytesToType<2> { + using Type = uint16_t; + static_assert(sizeof(Type) == 2); +}; + +template<> struct BytesToType<1> { + using Type = uint8_t; + static_assert(sizeof(Type) == 1); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Converter{ + static inline __device__ void to_float(const scalar_t (&src)[N], float (&dst)[N]) { + #pragma unroll + for (int i = 0; i < N; ++i) { dst[i] = src[i]; } + } +}; + +template +struct Converter{ + static inline __device__ void to_float(const at::Half (&src)[N], float (&dst)[N]) { + static_assert(N % 2 == 0); + auto &src2 = reinterpret_cast(src); + auto &dst2 = reinterpret_cast(dst); + #pragma unroll + for (int i = 0; i < N / 2; ++i) { dst2[i] = __half22float2(src2[i]); } + } +}; + +#if __CUDA_ARCH__ >= 800 +template +struct Converter{ + static inline __device__ void to_float(const at::BFloat16 (&src)[N], float (&dst)[N]) { + static_assert(N % 2 == 0); + auto &src2 = reinterpret_cast(src); + auto &dst2 = reinterpret_cast(dst); + #pragma unroll + for (int i = 0; i < N / 2; ++i) { dst2[i] = __bfloat1622float2(src2[i]); } + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +template struct SSMScanOp; + +template<> +struct SSMScanOp { + __device__ __forceinline__ float2 operator()(const float2 &ab0, const float2 &ab1) const { + return make_float2(ab1.x * ab0.x, ab1.x * ab0.y + ab1.y); + } +}; + +// A stateful callback functor that maintains a running prefix to be applied +// during consecutive scan operations. +template struct SSMScanPrefixCallbackOp { + using scan_t = std::conditional_t, float2, float4>; + scan_t running_prefix; + // Constructor + __device__ SSMScanPrefixCallbackOp(scan_t running_prefix_) : running_prefix(running_prefix_) {} + // Callback operator to be entered by the first warp of threads in the block. + // Thread-0 is responsible for returning a value for seeding the block-wide scan. + __device__ scan_t operator()(scan_t block_aggregate) { + scan_t old_prefix = running_prefix; + running_prefix = SSMScanOp()(running_prefix, block_aggregate); + return old_prefix; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void load_input(typename Ktraits::input_t *u, + typename Ktraits::input_t (&u_vals)[Ktraits::kNItems], + typename Ktraits::BlockLoadT::TempStorage &smem_load, + int seqlen) { + if constexpr (Ktraits::kIsEvenLen && !Ktraits::kVarlen) { + auto& smem_load_vec = reinterpret_cast(smem_load); + using vec_t = typename Ktraits::vec_t; + typename Ktraits::BlockLoadVecT(smem_load_vec).Load( + reinterpret_cast(u), + reinterpret_cast(u_vals) + #ifdef USE_ROCM + , Ktraits::kNThreads * Ktraits::kNLoads + #endif + + ); + } else { + typename Ktraits::BlockLoadT(smem_load).Load(u, u_vals, seqlen, 0.f); + } +} + + +template +inline __device__ void load_weight(typename Ktraits::input_t *Bvar, + typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems], + typename Ktraits::BlockLoadWeightT::TempStorage &smem_load_weight, + int seqlen) { + constexpr int kNItems = Ktraits::kNItems; + typename Ktraits::input_t B_vals_load[kNItems]; + if constexpr (Ktraits::kIsEvenLen && !Ktraits::kVarlen) { + auto& smem_load_weight_vec = reinterpret_cast(smem_load_weight); + using vec_t = typename Ktraits::vec_t; + typename Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load( + reinterpret_cast(Bvar), + reinterpret_cast(B_vals_load) + ); + } else { + typename Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f); + } + // #pragma unroll + // for (int i = 0; i < kNItems; ++i) { B_vals[i] = B_vals_load[i]; } + Converter::to_float(B_vals_load, B_vals); +} + +template +inline __device__ void store_output(typename Ktraits::input_t *out, + const float (&out_vals)[Ktraits::kNItems], + typename Ktraits::BlockStoreT::TempStorage &smem_store, + int seqlen) { + typename Ktraits::input_t write_vals[Ktraits::kNItems]; + #pragma unroll + for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; } + if constexpr (Ktraits::kIsEvenLen && !Ktraits::kVarlen) { + auto& smem_store_vec = reinterpret_cast(smem_store); + using vec_t = typename Ktraits::vec_t; + typename Ktraits::BlockStoreVecT(smem_store_vec).Store( + reinterpret_cast(out), + reinterpret_cast(write_vals) + ); + } else { + typename Ktraits::BlockStoreT(smem_store).Store(out, write_vals, seqlen); + } +} diff --git a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu new file mode 100644 index 0000000000000000000000000000000000000000..71624696338d00c8215a53ba8247046d180690e4 --- /dev/null +++ b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu @@ -0,0 +1,658 @@ +// clang-format off +// adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/selective_scan_fwd_kernel.cuh +#include +#include +#include +#include "selective_scan.h" + +#include +#include +#include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK + +#ifndef USE_ROCM + #include + #include + #include +#else + #include + namespace cub = hipcub; +#endif + +#include "selective_scan.h" +#include "static_switch.h" + +template +struct Selective_Scan_fwd_kernel_traits { + static_assert(kNItems_ % 4 == 0); + using input_t = input_t_; + using weight_t = weight_t_; + static constexpr int kNThreads = kNThreads_; + // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy. + static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3; + static constexpr int kNItems = kNItems_; + static constexpr int kNRows = kNRows_; + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || kNBytes == 4); + static constexpr int kNElts = kNBytes == 4 ? 4 : constexpr_min(8, kNItems); + static_assert(kNItems % kNElts == 0); + static constexpr int kNLoads = kNItems / kNElts; + static constexpr bool kIsEvenLen = kVarlen_ ? false : kIsEvenLen_; + static constexpr bool kIsVariableB = kIsVariableB_; + static constexpr bool kIsVariableC = kIsVariableC_; + static constexpr bool kHasZ = kHasZ_; + static constexpr bool kVarlen = kVarlen_; + + static constexpr bool kDirectIO = kVarlen_ ? false : kIsEvenLen && kNLoads == 1; + static constexpr int kNLoadsIndex = kNItems / 4; + using vec_t = typename BytesToType::Type; + using scan_t = float2; + using BlockLoadT = cub::BlockLoad; + using BlockLoadVecT = cub::BlockLoad; + using BlockLoadWeightT = cub::BlockLoad; + using BlockLoadWeightVecT = cub::BlockLoad; + using BlockStoreT = cub::BlockStore; + using BlockStoreVecT = cub::BlockStore; + // using BlockScanT = cub::BlockScan; + // using BlockScanT = cub::BlockScan; + using BlockScanT = cub::BlockScan; + static constexpr int kSmemIOSize = custom_max({sizeof(typename BlockLoadT::TempStorage), + sizeof(typename BlockLoadVecT::TempStorage), + (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage), + (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage), + sizeof(typename BlockStoreT::TempStorage), + sizeof(typename BlockStoreVecT::TempStorage)}); + static constexpr int kSmemSize = kSmemIOSize + sizeof(typename BlockScanT::TempStorage); +}; + +template +__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks) +void selective_scan_fwd_kernel(SSMParamsBase params) { + constexpr bool kIsVariableB = Ktraits::kIsVariableB; + constexpr bool kIsVariableC = Ktraits::kIsVariableC; + constexpr bool kHasZ = Ktraits::kHasZ; + constexpr bool kVarlen = Ktraits::kVarlen; + constexpr int kNThreads = Ktraits::kNThreads; + constexpr int kNItems = Ktraits::kNItems; + constexpr int kNRows = Ktraits::kNRows; + constexpr bool kDirectIO = Ktraits::kDirectIO; + using input_t = typename Ktraits::input_t; + using weight_t = typename Ktraits::weight_t; + using scan_t = typename Ktraits::scan_t; + + // Shared memory. + extern __shared__ char smem_[]; + // cast to lvalue reference of expected type + // char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t); + // auto& smem_load = reinterpret_cast(smem_ + 2 * MAX_DSTATE * sizeof(weight_t)); + // auto& smem_load = reinterpret_cast(smem_loadstorescan); + auto& smem_load = reinterpret_cast(smem_); + auto& smem_load_weight = reinterpret_cast(smem_); + auto& smem_load_weight1 = *reinterpret_cast(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage)); + auto& smem_store = reinterpret_cast(smem_); + auto& smem_scan = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize); + // weight_t *smem_a = reinterpret_cast(smem_ + smem_loadstorescan_size); + // weight_t *smem_bc = reinterpret_cast(smem_a + MAX_DSTATE); + scan_t *smem_running_prefix = reinterpret_cast(smem_ + Ktraits::kSmemSize); + + const int batch_id = blockIdx.x; + const int dim_id = blockIdx.y; + const int group_id = dim_id / (params.dim_ngroups_ratio); + int seqlen = params.seqlen; + int sequence_start_index = batch_id; + if constexpr (kVarlen){ + int *query_start_loc = reinterpret_cast(params.query_start_loc_ptr); + sequence_start_index = query_start_loc[batch_id]; + seqlen = query_start_loc[batch_id + 1] - sequence_start_index; + } + const bool has_initial_state = params.has_initial_state_ptr == nullptr ? false + : reinterpret_cast(params.has_initial_state_ptr)[batch_id]; + + const int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr + : reinterpret_cast(params.cache_indices_ptr); + const int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id]; + // cache_index == params.pad_slot_id is defined as padding, so we exit early + if (cache_index == params.pad_slot_id){ + return; + } + input_t *u = reinterpret_cast(params.u_ptr) + sequence_start_index * params.u_batch_stride + + dim_id * kNRows * params.u_d_stride; + input_t *delta = reinterpret_cast(params.delta_ptr) + sequence_start_index * params.delta_batch_stride + + dim_id * kNRows * params.delta_d_stride; + weight_t *A = reinterpret_cast(params.A_ptr) + dim_id * kNRows * params.A_d_stride; + weight_t *B = reinterpret_cast(params.B_ptr) + dim_id * kNRows * params.B_d_stride; + input_t *Bvar = reinterpret_cast(params.B_ptr) + sequence_start_index * params.B_batch_stride + group_id * params.B_group_stride; + weight_t *C = reinterpret_cast(params.C_ptr) + dim_id * kNRows * params.C_d_stride; + input_t *Cvar = reinterpret_cast(params.C_ptr) + sequence_start_index * params.C_batch_stride + group_id * params.C_group_stride; + input_t *ssm_states = reinterpret_cast(params.ssm_states_ptr) + (cache_index * params.dim + dim_id * kNRows) * params.dstate; + + float D_val[kNRows] = {0}; + if (params.D_ptr != nullptr) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + D_val[r] = reinterpret_cast(params.D_ptr)[dim_id * kNRows + r]; + } + } + float delta_bias[kNRows] = {0}; + if (params.delta_bias_ptr != nullptr) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + delta_bias[r] = reinterpret_cast(params.delta_bias_ptr)[dim_id * kNRows + r]; + } + } + + + // for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) { + // smem_a[state_idx] = A[state_idx * params.A_dstate_stride]; + // smem_bc[state_idx] = B[state_idx * params.B_dstate_stride] * C[state_idx * params.C_dstate_stride]; + // } + + constexpr int kChunkSize = kNThreads * kNItems; + const int n_chunks = (seqlen + 2048 - 1) / 2048; + for (int chunk = 0; chunk < n_chunks; ++chunk) { + input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems]; + + __syncthreads(); + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + if constexpr (!kDirectIO) { + if (r > 0) { __syncthreads(); } + } + load_input(u + r * params.u_d_stride, u_vals[r], smem_load, seqlen - chunk * kChunkSize); + if constexpr (!kDirectIO) { __syncthreads(); } + load_input(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, seqlen - chunk * kChunkSize); + } + u += kChunkSize; + delta += kChunkSize; + + float delta_vals[kNRows][kNItems], delta_u_vals[kNRows][kNItems], out_vals[kNRows][kNItems]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + float u_val = float(u_vals[r][i]); + delta_vals[r][i] = float(delta_vals_load[r][i]) + delta_bias[r]; + if (params.delta_softplus) { + delta_vals[r][i] = delta_vals[r][i] <= 20.f ? log1pf(expf(delta_vals[r][i])) : delta_vals[r][i]; + } + delta_u_vals[r][i] = delta_vals[r][i] * u_val; + out_vals[r][i] = D_val[r] * u_val; + } + } + + __syncthreads(); + for (int state_idx = 0; state_idx < params.dstate; ++state_idx) { + weight_t A_val[kNRows]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + A_val[r] = A[state_idx * params.A_dstate_stride + r * params.A_d_stride]; + // Multiply the real part of A with LOG2E so we can use exp2f instead of expf. + constexpr float kLog2e = M_LOG2E; + A_val[r] *= kLog2e; + } + // This variable holds B * C if both B and C are constant across seqlen. If only B varies + // across seqlen, this holds C. If only C varies across seqlen, this holds B. + // If both B and C vary, this is unused. + weight_t BC_val[kNRows]; + weight_t B_vals[kNItems], C_vals[kNItems]; + if constexpr (kIsVariableB) { + load_weight(Bvar + state_idx * params.B_dstate_stride, B_vals, + smem_load_weight, (seqlen - chunk * kChunkSize) * (1)); + if constexpr (!kIsVariableC) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + BC_val[r] = C[state_idx * params.C_dstate_stride + r * params.C_d_stride]; + } + } + } + if constexpr (kIsVariableC) { + auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1; + load_weight(Cvar + state_idx * params.C_dstate_stride, C_vals, + smem_load_weight_C, (seqlen - chunk * kChunkSize) * (1 )); + if constexpr (!kIsVariableB) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride]; + } + } + } + if constexpr (!kIsVariableB && !kIsVariableC) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride] * C[state_idx * params.C_dstate_stride + r * params.C_d_stride]; + } + } + + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + if (r > 0) { __syncthreads(); } // Scan could be using the same smem + scan_t thread_data[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]), + !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]); + + if (seqlen % (kNItems * kNThreads) != 0) { // So that the last state is correct + if (threadIdx.x * kNItems + i >= seqlen - chunk * kChunkSize) { + thread_data[i] = make_float2(1.f, 0.f); + } + } + } + // Initialize running total + + scan_t running_prefix = chunk > 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.0, has_initial_state ? float(ssm_states[state_idx]): 0.0); + + SSMScanPrefixCallbackOp prefix_op(running_prefix); + typename Ktraits::BlockScanT(smem_scan).InclusiveScan( + thread_data, thread_data, SSMScanOp(), prefix_op + ); + // There's a syncthreads in the scan op, so we don't need to sync here. + // Unless there's only 1 warp, but then it's the same thread (0) reading and writing. + if (threadIdx.x == 0) { + smem_running_prefix[state_idx] = prefix_op.running_prefix; + if (chunk == n_chunks - 1) { + ssm_states[state_idx] = input_t(prefix_op.running_prefix.y); + } + } + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + const weight_t C_val = !kIsVariableC + ? BC_val[r] + : (!kIsVariableB ? BC_val[r] * C_vals[i] : C_vals[i]); + out_vals[r][i] += thread_data[i].y * C_val; + } + } + } + + input_t *out = reinterpret_cast(params.out_ptr) + sequence_start_index * params.out_batch_stride + + dim_id * kNRows * params.out_d_stride + chunk * kChunkSize; + __syncthreads(); + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + if constexpr (!kDirectIO) { + if (r > 0) { __syncthreads(); } + } + store_output(out + r * params.out_d_stride, out_vals[r], smem_store, seqlen - chunk * kChunkSize); + } + + if constexpr (kHasZ) { + input_t *z = reinterpret_cast(params.z_ptr) + sequence_start_index * params.z_batch_stride + + dim_id * kNRows * params.z_d_stride + chunk * kChunkSize; + input_t *out_z = reinterpret_cast(params.out_z_ptr) + sequence_start_index * params.out_z_batch_stride + + dim_id * kNRows * params.out_z_d_stride + chunk * kChunkSize; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + input_t z_vals[kNItems]; + __syncthreads(); + load_input(z + r * params.z_d_stride, z_vals, smem_load, seqlen - chunk * kChunkSize); + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + float z_val = z_vals[i]; + out_vals[r][i] *= z_val / (1 + expf(-z_val)); + } + __syncthreads(); + store_output(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, seqlen - chunk * kChunkSize); + } + } + + Bvar += kChunkSize * 1; + Cvar += kChunkSize * 1; + } +} + +template +void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) { + // Only kNRows == 1 is tested for now, which ofc doesn't differ from previously when we had each block + // processing 1 row. + constexpr int kNRows = 1; + // kIsVariableB, kIsVariableC and kHasZ are all set to True to reduce binary size + constexpr bool kIsVariableB = true; + constexpr bool kIsVariableC = true; + constexpr bool kHasZ = true; + BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { + BOOL_SWITCH(params.query_start_loc_ptr != nullptr , kVarlen, [&] { + using Ktraits = Selective_Scan_fwd_kernel_traits; + constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t); + dim3 grid(params.batch, params.dim / kNRows); + auto kernel = &selective_scan_fwd_kernel; + if (kSmemSize >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); +} + +template +void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream) { + + #ifndef USE_ROCM + if (params.seqlen <= 128) { + selective_scan_fwd_launch<32, 4, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 256) { + selective_scan_fwd_launch<32, 8, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 512) { + selective_scan_fwd_launch<32, 16, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 1024) { + selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream); + } else { + selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream); + } + #else + if (params.seqlen <= 256) { + selective_scan_fwd_launch<64, 4, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 512) { + selective_scan_fwd_launch<64, 8, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 1024) { + selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream); + } else { + selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream); + } + #endif +} + +template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); + +#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") + +#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \ + if (ITYPE == at::ScalarType::Half) { \ + using input_t = at::Half; \ + using weight_t = float; \ + __VA_ARGS__(); \ + } else if (ITYPE == at::ScalarType::BFloat16) { \ + using input_t = at::BFloat16; \ + using weight_t = float; \ + __VA_ARGS__(); \ + } else if (ITYPE == at::ScalarType::Float) { \ + using input_t = float; \ + using weight_t = float; \ + __VA_ARGS__(); \ + } else { \ + AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \ + } + + +template +void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); + +void set_ssm_params_fwd(SSMParamsBase ¶ms, + // sizes + const size_t batch, + const size_t dim, + const size_t seqlen, + const size_t dstate, + const size_t n_groups, + const bool is_variable_B, + const bool is_variable_C, + // device pointers + const torch::Tensor u, + const torch::Tensor delta, + const torch::Tensor A, + const torch::Tensor B, + const torch::Tensor C, + const torch::Tensor out, + const torch::Tensor z, + const torch::Tensor out_z, + const c10::optional& D, + const c10::optional& delta_bias, + const torch::Tensor ssm_states, + bool has_z, + bool delta_softplus, + const c10::optional& query_start_loc, + const c10::optional& cache_indices, + const c10::optional& has_initial_state, + bool varlen, + int64_t pad_slot_id) { + + // Reset the parameters + memset(¶ms, 0, sizeof(params)); + + params.batch = batch; + params.dim = dim; + params.seqlen = seqlen; + params.dstate = dstate; + params.n_groups = n_groups; + params.dim_ngroups_ratio = dim / n_groups; + params.pad_slot_id = pad_slot_id; + + params.delta_softplus = delta_softplus; + + params.is_variable_B = is_variable_B; + params.is_variable_C = is_variable_C; + + // Set the pointers and strides. + params.u_ptr = u.data_ptr(); + params.delta_ptr = delta.data_ptr(); + params.A_ptr = A.data_ptr(); + params.B_ptr = B.data_ptr(); + params.C_ptr = C.data_ptr(); + params.D_ptr = D.has_value() ? D.value().data_ptr() : nullptr; + params.delta_bias_ptr = delta_bias.has_value() ? delta_bias.value().data_ptr() : nullptr; + params.out_ptr = out.data_ptr(); + params.ssm_states_ptr = ssm_states.data_ptr(); + params.z_ptr = has_z ? z.data_ptr() : nullptr; + params.out_z_ptr = has_z ? out_z.data_ptr() : nullptr; + params.query_start_loc_ptr = query_start_loc.has_value() ? query_start_loc.value().data_ptr() : nullptr; + params.cache_indices_ptr = cache_indices.has_value() ? cache_indices.value().data_ptr() : nullptr; + params.has_initial_state_ptr = has_initial_state.has_value() ? has_initial_state.value().data_ptr() : nullptr; + + + // All stride are in elements, not bytes. + params.A_d_stride = A.stride(0); + params.A_dstate_stride = A.stride(1); + + if (varlen){ + params.B_batch_stride = B.stride(2); + params.B_group_stride = B.stride(0); + params.B_dstate_stride = B.stride(1); + params.C_batch_stride = C.stride(2); + params.C_group_stride = C.stride(0); + params.C_dstate_stride = C.stride(1); + + params.u_batch_stride = u.stride(1); + params.u_d_stride = u.stride(0); + params.delta_batch_stride = delta.stride(1); + params.delta_d_stride = delta.stride(0); + if (has_z) { + params.z_batch_stride = z.stride(1); + params.z_d_stride = z.stride(0); + params.out_z_batch_stride = out_z.stride(1); + params.out_z_d_stride = out_z.stride(0); + } + params.out_batch_stride = out.stride(1); + params.out_d_stride = out.stride(0); + + } + else{ + if (!is_variable_B) { + params.B_d_stride = B.stride(0); + } else { + params.B_batch_stride = B.stride(0); + params.B_group_stride = B.stride(1); + } + params.B_dstate_stride = !is_variable_B ? B.stride(1) : B.stride(2); + if (!is_variable_C) { + params.C_d_stride = C.stride(0); + } else { + params.C_batch_stride = C.stride(0); + params.C_group_stride = C.stride(1); + } + params.C_dstate_stride = !is_variable_C ? C.stride(1) : C.stride(2); + params.u_batch_stride = u.stride(0); + params.u_d_stride = u.stride(1); + params.delta_batch_stride = delta.stride(0); + params.delta_d_stride = delta.stride(1); + if (has_z) { + params.z_batch_stride = z.stride(0); + params.z_d_stride = z.stride(1); + params.out_z_batch_stride = out_z.stride(0); + params.out_z_d_stride = out_z.stride(1); + } + params.out_batch_stride = out.stride(0); + params.out_d_stride = out.stride(1); + } +} + +void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, + const torch::Tensor &A, const torch::Tensor &B, const torch::Tensor &C, + const c10::optional &D_, + const c10::optional &z_, + const c10::optional &delta_bias_, + bool delta_softplus, + const c10::optional &query_start_loc, + const c10::optional &cache_indices, + const c10::optional &has_initial_state, + const torch::Tensor &ssm_states, + // used to identify padding entries if cache_indices provided + // in case of padding, the kernel will return early + int64_t pad_slot_id) { + auto input_type = u.scalar_type(); + auto weight_type = A.scalar_type(); + TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); + TORCH_CHECK(weight_type == at::ScalarType::Float); + + const bool is_variable_B = B.dim() >= 3; + const bool is_variable_C = C.dim() >= 3; + + TORCH_CHECK(delta.scalar_type() == input_type); + TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type)); + TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type)); + + TORCH_CHECK(u.is_cuda()); + TORCH_CHECK(delta.is_cuda()); + TORCH_CHECK(A.is_cuda()); + TORCH_CHECK(B.is_cuda()); + TORCH_CHECK(C.is_cuda()); + + TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1); + TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1); + + const auto sizes = u.sizes(); + const bool varlen = query_start_loc.has_value(); + const int batch_size = varlen ? query_start_loc.value().sizes()[0] - 1 : sizes[0]; + const int dim = varlen ? sizes[0] : sizes[1]; + const int seqlen = varlen ? sizes[1] : sizes[2]; + const int dstate = A.size(1); + const int n_groups = varlen ? B.size(0) : B.size(1); + + TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256"); + + if (varlen) { + CHECK_SHAPE(u, dim, seqlen); + CHECK_SHAPE(delta, dim, seqlen); + } else { + CHECK_SHAPE(u, batch_size, dim, seqlen); + CHECK_SHAPE(delta, batch_size, dim, seqlen); + } + CHECK_SHAPE(A, dim, dstate); + TORCH_CHECK(is_variable_B, "is_variable_B = False is disabled in favor of reduced binary size") + if (varlen) { + CHECK_SHAPE(B, n_groups, dstate, seqlen); + } else { + CHECK_SHAPE(B, batch_size, n_groups, dstate, seqlen); + } + TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1); + + TORCH_CHECK(is_variable_C, "is_variable_C = False is disabled in favor of reduced binary size") + if (varlen) { + CHECK_SHAPE(C, n_groups, dstate, seqlen); + } else { + CHECK_SHAPE(C, batch_size, n_groups, dstate, seqlen); + } + TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1); + + if (D_.has_value()) { + auto D = D_.value(); + TORCH_CHECK(D.scalar_type() == at::ScalarType::Float); + TORCH_CHECK(D.is_cuda()); + TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1); + CHECK_SHAPE(D, dim); + } + + if (delta_bias_.has_value()) { + auto delta_bias = delta_bias_.value(); + TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float); + TORCH_CHECK(delta_bias.is_cuda()); + TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1); + CHECK_SHAPE(delta_bias, dim); + } + + + if (has_initial_state.has_value()) { + auto has_initial_state_ = has_initial_state.value(); + TORCH_CHECK(has_initial_state_.scalar_type() == at::ScalarType::Bool); + TORCH_CHECK(has_initial_state_.is_cuda()); + CHECK_SHAPE(has_initial_state_, batch_size); + } + + + if (query_start_loc.has_value()) { + auto query_start_loc_ = query_start_loc.value(); + TORCH_CHECK(query_start_loc_.scalar_type() == at::ScalarType::Int); + TORCH_CHECK(query_start_loc_.is_cuda()); + } + + + if (cache_indices.has_value()) { + auto cache_indices_ = cache_indices.value(); + TORCH_CHECK(cache_indices_.scalar_type() == at::ScalarType::Int); + TORCH_CHECK(cache_indices_.is_cuda()); + CHECK_SHAPE(cache_indices_, batch_size); + } + + + at::Tensor z, out_z; + const bool has_z = z_.has_value(); + TORCH_CHECK(has_z, "has_z = False is disabled in favor of reduced binary size") + z = z_.value(); + TORCH_CHECK(z.scalar_type() == input_type); + TORCH_CHECK(z.is_cuda()); + TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1); + if (varlen){ + CHECK_SHAPE(z, dim, seqlen); + } else { + CHECK_SHAPE(z, batch_size, dim, seqlen); + } + + out_z = z; + + // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout + at::Tensor out = delta; + TORCH_CHECK(ssm_states.scalar_type() == input_type); + TORCH_CHECK(ssm_states.is_cuda()); + TORCH_CHECK(ssm_states.stride(-1) == 1); + + SSMParamsBase params; + set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, is_variable_B, is_variable_C, + u, delta, A, B, C, out, z, out_z, + D_, + delta_bias_, + ssm_states, + has_z, + delta_softplus, + query_start_loc, + cache_indices, + has_initial_state, + varlen, + pad_slot_id + ); + + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)u.get_device()}; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] { + selective_scan_fwd_cuda(params, stream); + }); +} + diff --git a/csrc/mamba/mamba_ssm/static_switch.h b/csrc/mamba/mamba_ssm/static_switch.h new file mode 100644 index 0000000000000000000000000000000000000000..840cb2374a2f03011957909d2a3a27fb7958c91a --- /dev/null +++ b/csrc/mamba/mamba_ssm/static_switch.h @@ -0,0 +1,28 @@ +// Inspired by +// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h +// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h + +// clang-format off +// adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/static_switch.h +#pragma once + +/// @param COND - a boolean expression to switch by +/// @param CONST_NAME - a name given for the constexpr bool variable. +/// @param ... - code to execute for true and false +/// +/// Usage: +/// ``` +/// BOOL_SWITCH(flag, BoolConst, [&] { +/// some_function(...); +/// }); +/// ``` +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + constexpr bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel.h b/csrc/moe/marlin_kernels/marlin_moe_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..a217401b3d7c29b7d2786ac84e9a4303f60ae619 --- /dev/null +++ b/csrc/moe/marlin_kernels/marlin_moe_kernel.h @@ -0,0 +1,1616 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include + +#include + +#include "core/scalar_type.hpp" + +namespace marlin_moe { + +constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + +// Instances of `Vec` are used to organize groups of >>registers<<, as needed +// for instance as inputs to tensor core operations. Consequently, all +// corresponding index accesses must be compile-time constants, which is why we +// extensively use `#pragma unroll` throughout the kernel code to guarantee +// this. +template +struct Vec { + T elems[n]; + __device__ T& operator[](int i) { return elems[i]; } +}; + +using I4 = Vec; + +// Matrix fragments for tensor core instructions; their precise layout is +// documented here: +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type +using FragA = Vec; +using FragB = Vec; +using FragC = Vec; +using FragS = Vec; // quantization scales +using FragZP = Vec; + +// Predicated asynchronous global->shared copy; used for inputs A where we apply +// predication to handle batchsizes that are not multiples of 16. +__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, + bool pred = true) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); +} + +// Asynchronous global->shared copy +__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " cp.async.cg.shared.global [%0], [%1], %2;\n" + "}\n" ::"r"(smem), + "l"(glob_ptr), "n"(BYTES)); +} + +// Async copy fence. +__device__ inline void cp_async_fence() { + asm volatile("cp.async.commit_group;\n" ::); +} + +// Wait until at most `n` async copy stages are still pending. +template +__device__ inline void cp_async_wait() { + asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); +} + +// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 +// output/accumulation. +__device__ inline void mma(const FragA& a_frag, const FragB& frag_b, + FragC& frag_c) { + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); +} + +// Instruction for loading a full 16x16 matrix fragment of operand A from shared +// memory, directly in tensor core layout. +__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_a); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) + : "r"(smem)); +} + +// Lookup-table based 3-input logical operation; explicitly used for +// dequantization as the compiler does not seem to automatically recognize it in +// all cases. +template +__device__ inline int lop3(int a, int b, int c) { + int res; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(res) + : "r"(a), "r"(b), "r"(c), "n"(lut)); + return res; +} + +// Constructs destination register by taking bytes from 2 sources (based on +// mask) +template +__device__ inline uint32_t prmt(uint32_t a) { + uint32_t res; + asm volatile("prmt.b32 %0, %1, %2, %3;\n" + : "=r"(res) + : "r"(a), "n"(start_byte), "n"(mask)); + return res; +} + +template +__device__ inline FragB dequant(int q); + +// Efficiently dequantize 4bit values packed in an int32 value into a full +// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below, +// with some small changes: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 +template <> +__device__ inline FragB dequant(int q) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point + // directly into `SUB` and `ADD`. + const int SUB = 0x64086408; + const int MUL = 0x2c002c00; + const int ADD = 0xd480d480; + FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + return frag_b; +} + +// Fast Int8ToFp16: Efficiently dequantize 8bit int values to fp16 +// Reference: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 +template <> +__device__ inline FragB dequant(int q) { + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + uint32_t lo = prmt(q); + uint32_t hi = prmt(q); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + + FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(*reinterpret_cast(&hi), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + return frag_b; +} + +template <> +__device__ inline FragB dequant(int q) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + + const int SUB = 0x64006400; + const int MUL = 0x2c002c00; + const int ADD = 0xd400d400; + FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + return frag_b; +} + +template <> +__device__ inline FragB dequant(int q) { + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + uint32_t lo = prmt(q); + uint32_t hi = prmt(q); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400; + + FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(*reinterpret_cast(&hi), + *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + return frag_b; +} + +// Multiply dequantized values by the corresponding quantization scale; used +// only for grouped quantization. +__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { + half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]); + frag_b[0] = __hmul2(frag_b[0], s); + frag_b[1] = __hmul2(frag_b[1], s); +} + +__device__ inline void sub_zp(FragB& frag_b, half2& frag_zp, int i) { + half2 zp = __half2half2(reinterpret_cast<__half*>(&frag_zp)[i]); + frag_b[0] = __hsub2(frag_b[0], zp); + frag_b[1] = __hsub2(frag_b[1], zp); +} + +// Same as above, but for act_order (each K is multiplied individually) +__device__ inline void scale4(FragB& frag_b, FragS& frag_s_1, FragS& frag_s_2, + FragS& frag_s_3, FragS& frag_s_4, int i) { + __half2 s_val_1_2; + s_val_1_2.x = reinterpret_cast<__half*>(&frag_s_1)[i]; + s_val_1_2.y = reinterpret_cast<__half*>(&frag_s_2)[i]; + + __half2 s_val_3_4; + s_val_3_4.x = reinterpret_cast<__half*>(&frag_s_3)[i]; + s_val_3_4.y = reinterpret_cast<__half*>(&frag_s_4)[i]; + + frag_b[0] = __hmul2(frag_b[0], s_val_1_2); + frag_b[1] = __hmul2(frag_b[1], s_val_3_4); +} + +// Given 2 floats multiply by 2 scales (halves) +__device__ inline void scale_float(float* c, FragS& s) { + __half* s_ptr = reinterpret_cast<__half*>(&s); + c[0] = __fmul_rn(c[0], __half2float(s_ptr[0])); + c[1] = __fmul_rn(c[1], __half2float(s_ptr[1])); +} + +// Wait until barrier reaches `count`, then lock for current threadblock. +__device__ inline void barrier_acquire(int* lock, int count) { + if (threadIdx.x == 0) { + int state = -1; + do + // Guarantee that subsequent writes by this threadblock will be visible + // globally. + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" + : "=r"(state) + : "l"(lock)); + while (state != count); + } + __syncthreads(); +} + +// Release barrier and increment visitation count. +__device__ inline void barrier_release(int* lock, bool reset = false) { + __syncthreads(); + if (threadIdx.x == 0) { + if (reset) { + lock[0] = 0; + return; + } + int val = 1; + // Make sure that all writes since acquiring this barrier are visible + // globally, while releasing the barrier. + asm volatile("fence.acq_rel.gpu;\n"); + asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" + : + : "l"(lock), "r"(val)); + } +} + +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const bool has_zp, // whether zero-points are enabled + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__device__ void MarlinMoESingle( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int* __restrict__ sorted_ids, // int32 sorted ids of experts + const float* __restrict__ topk_weights, // float topk weights + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape + // (k/groupsize)x(n/pack_factor) + const int* __restrict__ g_idx, // int32 group indices of shape k + const int* __restrict__ expert_offsets, + int num_groups, // number of scale groups per output channel + int expert_idx, // idx of current expert + int num_experts, // number of experts + int topk, // topk parameter of moe + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int tot_m, // total number of rows in A and C + int* locks, // extra global storage for barrier synchronization + bool replicate_input, // do we use the same input for each expert? + bool apply_weights, // apply weights to output + int current_m_block // current m block to start kernel computation from +) { + static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id); + constexpr int pack_factor = 32 / w_type.size_bits(); + + // For larger GEMMs we run multiple batchsize 64 versions in parallel for a + // better partitioning with less reductions + int parallel = 1; + if (prob_m > 16 * thread_m_blocks) { + parallel = prob_m / (16 * thread_m_blocks); + prob_m = 16 * thread_m_blocks; + } + + int k_tiles = prob_k / 16 / thread_k_blocks; + int n_tiles = prob_n / 16 / thread_n_blocks; + int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); + + if constexpr (!has_act_order && group_blocks != -1) { + if (group_blocks >= thread_k_blocks) { + // Ensure that the number of tiles in each stripe is a multiple of the + // groupsize; this avoids an annoying special case where a stripe starts + // in the middle of group. + iters = (group_blocks / thread_k_blocks) * + ceildiv(iters, (group_blocks / thread_k_blocks)); + } + } + + int slice_row = (iters * blockIdx.x) % k_tiles; + int slice_col_par = (iters * blockIdx.x) / k_tiles; + int slice_col = slice_col_par; + int slice_iters; // number of threadblock tiles in the current slice + int slice_count = + 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to + // top + + // We can easily implement parallel problem execution by just remapping + // indices and advancing global pointers + if (slice_col_par >= n_tiles) { + locks += (slice_col_par / n_tiles) * n_tiles; + slice_col = slice_col_par % n_tiles; + sorted_ids += (slice_col_par / n_tiles) * 16 * thread_m_blocks; + } + + // Compute all information about the current slice which is required for + // synchronization. + auto init_slice = [&]() { + slice_iters = + iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; + if (slice_iters == 0) return; + if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; + slice_count = 1; + slice_idx = 0; + int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) { + int col_off = col_first - k_tiles * slice_col_par; + slice_count = ceildiv(k_tiles - col_off, iters); + if (col_off > 0) slice_count++; + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) + slice_idx = slice_count - 1; + else { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) slice_idx--; + } + } + if (slice_col == n_tiles) { + sorted_ids += 16 * thread_m_blocks; + locks += n_tiles; + slice_col = 0; + } + }; + init_slice(); + + // A sizes/strides + + // stride of the A matrix in global memory + int a_gl_stride = prob_k / 8; + // stride of an A matrix tile in shared memory + constexpr int a_sh_stride = 16 * thread_k_blocks / 8; + // delta between subsequent A tiles in global memory + constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; + // between subsequent accesses within a tile + int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); + // between shared memory writes + constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); + // between shared memory tile reads + constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); + // within a shared memory tile + constexpr int a_sh_rd_delta_i = a_sh_stride * 16; + // overall size of a tile + constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); + // number of shared write iterations for a tile + constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta); + + // B sizes/strides + int b_gl_stride = 16 * prob_n / (pack_factor * 4); + constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; + constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 1 : 2; + constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; + + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; + int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); + constexpr int b_sh_wr_delta = threads * b_thread_vecs; + constexpr int b_sh_rd_delta = threads * b_thread_vecs; + constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + + // Scale sizes/strides without act_order + int s_gl_stride = prob_n / 8; + constexpr int s_sh_stride = 16 * thread_n_blocks / 8; + constexpr int s_tb_groups = + !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks + ? thread_k_blocks / group_blocks + : 1; + constexpr int s_sh_stage = s_tb_groups * s_sh_stride; + int s_gl_rd_delta = s_gl_stride; + // Scale size/strides with act_order + constexpr int tb_k = 16 * thread_k_blocks; + constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0; + // constexpr int act_s_row_stride = 1; + // int act_s_col_stride = act_s_row_stride * num_groups; + int act_s_col_stride = 1; + int act_s_col_warp_stride = act_s_col_stride * 8; + int tb_n_warps = thread_n_blocks / 4; + int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; + + // Zero-points sizes/strides + int zp_gl_stride = (prob_n / pack_factor) / 4; + constexpr int zp_sh_stride = ((16 * thread_n_blocks) / pack_factor) / 4; + constexpr int zp_tb_groups = s_tb_groups; + constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0; + int zp_gl_rd_delta = zp_gl_stride; + + // Global A read index of current thread. + int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + a_gl_rd += a_gl_rd_delta_o * slice_row; + // Shared write index of current thread. + int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + // Shared read index. + int a_sh_rd = + a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; + a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + + int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; + b_gl_rd += b_sh_stride * slice_col; + b_gl_rd += b_gl_rd_delta_o * slice_row; + int b_sh_wr = threadIdx.x * b_thread_vecs; + int b_sh_rd = threadIdx.x * b_thread_vecs; + + // For act_order + constexpr int k_iter_size = tb_k / b_sh_wr_iters; + int slice_k_start = tb_k * slice_row; + int slice_k_finish = slice_k_start + tb_k * slice_iters; + int slice_k_start_shared_fetch = slice_k_start; + int slice_n_offset = act_s_col_tb_stride * slice_col; + + // No act_order + int s_gl_rd; + if constexpr (!has_act_order) { + if constexpr (group_blocks == -1) { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + } else { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + s_sh_stride * slice_col + threadIdx.x; + } + } + int s_sh_wr = threadIdx.x; + bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + + // Zero-points + int zp_gl_rd; + if constexpr (has_zp) { + if constexpr (group_blocks == -1) { + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + } else { + zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + zp_sh_stride * slice_col + threadIdx.x; + } + } + int zp_sh_wr = threadIdx.x; + bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride; + + // We use a different scale layout for grouped and column-wise quantization as + // we scale a `half2` tile in column-major layout in the former and in + // row-major in the latter case. + int s_sh_rd; + if constexpr (group_blocks != -1) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + else + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) % 4; + + // Zero-points have the same read layout as the scales + // (without column-wise case) + constexpr int num_col_threads = 8; + constexpr int num_row_threads = 4; + constexpr int num_ints_per_thread = 8 / pack_factor; + int zp_sh_rd; + if constexpr (has_zp) { + zp_sh_rd = num_ints_per_thread * num_col_threads * + ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); + } + + int sh_first_group_id = -1; + int sh_num_groups = -1; + constexpr int sh_max_num_groups = 32; + + extern __shared__ int4 sh[]; + // Shared memory storage for global fetch pipelines. + int4* sh_a = sh; + int4* sh_b = sh_a + (stages * a_sh_stage); + int4* sh_g_idx = sh_b + (stages * b_sh_stage); + int4* sh_zp = sh_g_idx + (stages * g_idx_stage); + int4* sh_s = sh_zp + (stages * zp_sh_stage); + + // Precompute which thread should not read memory in which iterations; this is + // needed if there are more threads than required for a certain tilesize or + // when the batchsize is not a multiple of 16. + bool a_sh_wr_pred[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + int a_idx = a_sh_wr_delta * i + a_sh_wr; + int row = a_idx / a_gl_rd_delta_o; + if (row >= prob_m) { + a_sh_wr_pred[i] = false; + } else { + a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; + } + } + + // To ensure that writing and reading A tiles to/from shared memory, the + // latter in fragment format, is fully bank conflict free, we need to use a + // rather fancy XOR-based layout. The key here is that neither reads nor + // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the + // same shared memory banks. Further, it seems (based on NSight-Compute) that + // each warp must also write a consecutive memory segment? + auto transform_a = [&](int i) { + int row = i / a_gl_rd_delta_o; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; + }; + // Since the computation of this remapping is non-trivial and, due to our main + // loop unrolls, all shared memory accesses are static, we simply precompute + // both transformed reads and writes. + int a_sh_wr_trans[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < thread_m_blocks; j++) + a_sh_rd_trans[i][j] = + transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + } + + // Since B-accesses have non-constant stride they have to be computed at + // runtime; we break dependencies between subsequent accesses with a tile by + // maintining multiple pointers (we have enough registers), a tiny + // optimization. + const int4* B_ptr[b_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + + // Register storage for double buffer of shared memory reads. + FragA frag_a[2][thread_m_blocks]; + I4 frag_b_quant[2][b_thread_vecs]; + FragC frag_c[thread_m_blocks][4][2]; + FragS frag_s[2][4]; // No act-order + FragS act_frag_s[2][4][4]; // For act-order + int frag_qzp[2][num_ints_per_thread]; // Zero-points + FragZP frag_zp; // Zero-points in fp16 + + // Zero accumulators. + auto zero_accums = [&]() { + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) + reinterpret_cast(frag_c)[i] = 0; + }; + + auto fetch_scales_to_shared = [&](bool is_async, int first_group_id, + int last_group_id) { + sh_first_group_id = first_group_id; + sh_num_groups = last_group_id - first_group_id + 1; + + if (sh_num_groups < sh_max_num_groups) { + sh_num_groups = sh_max_num_groups; + } + + if (sh_first_group_id + sh_num_groups > num_groups) { + sh_num_groups = num_groups - sh_first_group_id; + } + + int row_offset = first_group_id * s_gl_stride; + + if (is_async) { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], + &scales_ptr[row_offset + (i * s_gl_stride) + + slice_n_offset + threadIdx.x]); + } + } + } else { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + sh_s[(i * s_sh_stride) + threadIdx.x] = + scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + + threadIdx.x]; + } + } + } + }; + // Asynchronously fetch the next A, B and s tile from global to the next + // shared memory pipeline location. + auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { + if (pred) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + int a_idx = a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off; + int row = a_idx / a_gl_stride; + int sorted_row = + replicate_input ? sorted_ids[row] / topk : sorted_ids[row]; + int new_idx = sorted_row * a_gl_stride + a_idx % a_gl_stride; + if (sorted_row < tot_m * (replicate_input ? 1 : topk) && + new_idx < a_gl_stride * tot_m * (replicate_input ? 1 : topk)) { + cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[new_idx], + a_sh_wr_pred[i]); + } + } + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < b_thread_vecs; j++) { + cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); + } + B_ptr[i] += b_gl_rd_delta_o; + } + + if constexpr (has_act_order) { + // Fetch g_idx thread-block portion + int full_pipe = a_off; + int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; + if (cur_k < prob_k && cur_k < slice_k_finish) { + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + + int4 const* cur_g_idx_stage_ptr = + reinterpret_cast(&g_idx[cur_k]); + + if (threadIdx.x < g_idx_stage) { + cp_async4_pred(&sh_g_idx_stage[threadIdx.x], + &cur_g_idx_stage_ptr[threadIdx.x]); + } + } + } else { + if constexpr (group_blocks != -1) { + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) { + // Only fetch scales if this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } else { + for (int i = 0; i < s_tb_groups; i++) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], + &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } + } + + if constexpr (has_zp && group_blocks != -1) { + int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) { + // Only fetch zero-points if this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); + } + zp_gl_rd += zp_gl_rd_delta; + } + } else { + for (int i = 0; i < zp_tb_groups; i++) { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr], + &zp_ptr[zp_gl_rd]); + } + zp_gl_rd += zp_gl_rd_delta; + } + } + } + } + } + // Insert a fence even when we are winding down the pipeline to ensure that + // waiting is also correct at this point. + cp_async_fence(); + }; + + auto fetch_zp_to_shared = [&]() { + if (zp_sh_wr_pred) { + cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]); + } + }; + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); + }; + + // Load the next sub-tile from the current location in the shared memory pipe + // into the current register buffer. + auto fetch_to_registers = [&](int k, int pipe) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) + ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + + #pragma unroll + for (int i = 0; i < b_thread_vecs; i++) { + frag_b_quant[k % 2][i] = *reinterpret_cast( + &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); + } + }; + + bool is_same_group[stages]; + int same_group_id[stages]; + + auto init_same_group = [&](int pipe) { + if constexpr (!has_act_order) { + is_same_group[pipe] = false; + same_group_id[pipe] = 0; + return; + } + + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + int group_id_1 = sh_g_idx_int_ptr[0]; + int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; + + is_same_group[pipe] = group_id_1 == group_id_2; + same_group_id[pipe] = group_id_1; + }; + + auto fetch_scales_to_registers = [&](int k, int full_pipe) { + int pipe = full_pipe % stages; + + if constexpr (!has_act_order) { + // No act-order case + if constexpr (group_blocks != -1) { + if constexpr (group_blocks >= thread_k_blocks) { + int4* sh_s_stage = + sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } else { + int warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = k_blocks / group_blocks; + + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + + reinterpret_cast(&frag_s[k % 2])[0] = + sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; + } + } + + return; + } + + // Act-order case + + // Determine K of the "current" thread-block + int cur_k = slice_k_start + tb_k * full_pipe; + if (cur_k >= prob_k || cur_k >= slice_k_finish) { + return; + } + + // Reset (to current thread-block) since we read g_idx portion from the + // shared memory + cur_k = 0; + + // Progress to current iteration + cur_k += k_iter_size * (k % b_sh_wr_iters); + + // Determine "position" inside the thread-block (based on warp and + // thread-id) + int warp_id = threadIdx.x / 32; + int n_warps = + thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N + + int warp_row = warp_id / n_warps; + int warp_col = warp_id % n_warps; + + cur_k += warp_row * 16; + + int th_id = threadIdx.x % 32; + cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix + + int s_col_shift = + /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + + (th_id / 4) * act_s_col_stride; + + if (is_same_group[pipe]) { + if (k % 2 == 0) { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + + s_col_shift]; + } else { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); + } + + for (int i = 1; i < 4; i++) { + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); + } + return; + } + + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + constexpr int k_frag_offsets[4] = {0, 1, 8, + 9}; // Tensor core offsets per thread + + #pragma unroll + for (int i = 0; i < 4; i++) { + int actual_k = cur_k + k_frag_offsets[i]; + + int group_id = sh_g_idx_int_ptr[actual_k]; + int rel_group_id = group_id - sh_first_group_id; + + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = + sh_s[rel_group_id * s_sh_stride + s_col_shift]; + } + }; + + auto fetch_zp_to_registers = [&](int k, int full_pipe) { + // This code does not handle group_blocks == 0, + // which signifies act_order. + // has_zp implies AWQ, which doesn't have act_order, + static_assert(!has_zp || group_blocks != 0); + + if constexpr (has_zp) { + int pipe = full_pipe % stages; + + if constexpr (group_blocks == -1) { + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp))[zp_sh_rd + i]; + } + + } else if constexpr (group_blocks >= thread_k_blocks) { + int4* sh_zp_stage = + sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = + (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + } + } else { + int warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = 0; + + // Suppress bogus and persistent divide-by-zero warning + #pragma nv_diagnostic push + #pragma nv_diag_suppress divide_by_zero + cur_group_id = k_blocks / group_blocks; + #pragma nv_diagnostic pop + + int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + sh_zp_stage += cur_group_id * zp_sh_stride; + + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = + (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + } + } + } + }; + + // Execute the actual tensor core matmul of a sub-tile. + auto matmul = [&](int k) { + if constexpr (has_zp) { + FragB frag_zp_0; + FragB frag_zp_1; + int zp_quant_0, zp_quant_1; + + if constexpr (w_type.size_bits() == 4) { + zp_quant_0 = frag_qzp[k % 2][0]; + zp_quant_1 = zp_quant_0 >> 8; + } else { + static_assert(w_type.size_bits() == 8); + zp_quant_0 = frag_qzp[k % 2][0]; + zp_quant_1 = frag_qzp[k % 2][1]; + } + + frag_zp_0 = dequant(zp_quant_0); + frag_zp_1 = dequant(zp_quant_1); + + frag_zp[0] = frag_zp_0[0]; + frag_zp[1] = frag_zp_0[1]; + frag_zp[2] = frag_zp_1[0]; + frag_zp[3] = frag_zp_1[1]; + } + + // We have the m dimension as the inner loop in order to encourage overlapping + // dequantization and matmul operations. + #pragma unroll + for (int j = 0; j < 4; j++) { + int b_quant_0, b_quant_1; + if constexpr (w_type.size_bits() == 4) { + b_quant_0 = frag_b_quant[k % 2][0][j]; + b_quant_1 = b_quant_0 >> 8; + } else { + static_assert(w_type.size_bits() == 8); + int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k % 2]); + b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; + b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; + } + + FragB frag_b0 = dequant(b_quant_0); + FragB frag_b1 = dequant(b_quant_1); + // Apply zero-point to frag_b0 + if constexpr (has_zp) { + sub_zp(frag_b0, frag_zp[j], 0); + } + + // Apply scale to frag_b0 + if constexpr (has_act_order) { + scale4(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], + act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 0); + } else { + if constexpr (group_blocks != -1) { + scale(frag_b0, frag_s[k % 2][j], 0); + } + } + + // Apply zero-point to frag_b1 + if constexpr (has_zp) { + sub_zp(frag_b1, frag_zp[j], 1); + } + + // Apply scale to frag_b1 + if constexpr (has_act_order) { + scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], + act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 1); + + } else { + if constexpr (group_blocks != -1) { + scale(frag_b1, frag_s[k % 2][j], 1); + } + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); + mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); + } + } + }; + + // Since we slice across the k dimension of a tile in order to increase the + // number of warps while keeping the n dimension of a tile reasonable, we have + // multiple warps that accumulate their partial sums of the same output + // location; which we have to reduce over in the end. We do in shared memory. + auto thread_block_reduce = [&]() { + constexpr int red_off = threads / b_sh_stride_threads / 2; + if (red_off >= 1) { + int red_idx = threadIdx.x / b_sh_stride_threads; + constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; + constexpr int red_sh_delta = b_sh_stride_threads; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + + (threadIdx.x % b_sh_stride_threads); + + // Parallel logarithmic shared memory reduction. We make sure to avoid any + // unnecessary read or write iterations, e.g., for two warps we write only + // once by warp 1 and read only once by warp 0. + + #pragma unroll + for (int m_block = 0; m_block < thread_m_blocks; m_block++) { + #pragma unroll + for (int i = red_off; i > 0; i /= 2) { + if (i <= red_idx && red_idx < 2 * i) { + #pragma unroll + for (int j = 0; j < 4 * 2; j++) { + int red_sh_wr = + red_sh_delta * j + (red_sh_rd - red_sh_stride * i); + if (i < red_off) { + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); + float* c_wr = reinterpret_cast(&sh[red_sh_wr]); + #pragma unroll + for (int k = 0; k < 4; k++) + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += + c_rd[k] + c_wr[k]; + } + sh[red_sh_wr] = + reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + } + } + __syncthreads(); + } + if (red_idx == 0) { + #pragma unroll + for (int i = 0; i < 4 * 2; i++) { + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); + #pragma unroll + for (int j = 0; j < 4; j++) + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += + c_rd[j]; + } + } + __syncthreads(); + } + } + }; + + // Since multiple threadblocks may process parts of the same column slice, we + // finally have to globally reduce over the results. As the striped + // partitioning minimizes the number of such reductions and our outputs are + // usually rather small, we perform this reduction serially in L2 cache. + auto global_reduce = [&](bool first = false, bool last = false) { + // We are very careful here to reduce directly in the output buffer to + // maximize L2 cache utilization in this step. To do this, we write out + // results in FP16 (but still reduce with FP32 compute). + constexpr int active_threads = 32 * thread_n_blocks / 4; + if (threadIdx.x < active_threads) { + int c_gl_stride = prob_n / 8; + int c_gl_wr_delta_o = 8 * c_gl_stride; + int c_gl_wr_delta_i = 4 * (active_threads / 32); + int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + + 4 * (threadIdx.x / 32) + threadIdx.x % 4; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + constexpr int c_sh_wr_delta = active_threads; + int c_sh_wr = threadIdx.x; + + int row = (threadIdx.x % 32) / 4; + + if (!first) { + // Interestingly, doing direct global accesses here really seems to mess up + // the compiler and lead to slowdowns, hence we also use async-copies even + // though these fetches are not actually asynchronous. + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + int c_idx = + c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); + int sorted_row = sorted_ids[c_idx / c_gl_stride]; + int new_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride; + cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], &C[new_idx], + sorted_row < tot_m * topk && + (8 * (i / 2) + row < prob_m && + (i < (thread_m_blocks - 1) * 4 || + sorted_ids[8 * (i / 2) + row] < tot_m * topk))); + } + cp_async_fence(); + cp_async_wait<0>(); + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + if (8 * (i / 2) + row < prob_m && + (i < (thread_m_blocks - 1) * 4 || + sorted_ids[8 * (i / 2) + row] < tot_m * topk)) { + if (!first) { + int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += + __half2float(reinterpret_cast<__half*>(&c_red)[j]); + } + } + if (!last) { + int4 c; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + reinterpret_cast<__half*>(&c)[j] = + __float2half(reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); + } + int c_idx = + c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); + int row = sorted_ids[c_idx / c_gl_stride]; + if (row < tot_m * topk) { + int new_idx = row * c_gl_stride + c_idx % c_gl_stride; + C[new_idx] = c; + } + } + } + } + } + }; + + // Write out the reduce final result in the correct layout. We only actually + // reshuffle matrix fragments in this step, the reduction above is performed + // in fragment layout. + auto write_result = [&]() { + int c_gl_stride = prob_n / 8; + constexpr int c_sh_stride = 2 * thread_n_blocks + 1; + int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); + constexpr int c_sh_rd_delta = + c_sh_stride * (threads / (2 * thread_n_blocks)); + + int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + c_gl_wr += (2 * thread_n_blocks) * slice_col; + int c_sh_wr = + (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; + c_sh_wr += 32 * (threadIdx.x / 32); + int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + + int c_gl_wr_end = c_gl_stride * prob_m; + + // We first reorder in shared memory to guarantee the most efficient final + // global write patterns + auto write = [&](int idx, float c0, float c1, FragS& s) { + half2 res = __halves2half2(__float2half(c0), __float2half(c1)); + + // For per-column quantization we finally apply the scale here (only for + // 4-bit) + if constexpr (!has_act_order && group_blocks == -1 && + w_type.size_bits() == 4) { + res = __hmul2(res, s[0]); + } + + ((half2*)sh)[idx] = res; + }; + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + int wr = c_sh_wr + 8 * j; + write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], + frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], + frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], + frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); + write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], + frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); + } + c_sh_wr += 16 * (4 * c_sh_stride); + } + } + __syncthreads(); + + #pragma unroll + for (int i = 0; + i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); + i++) { + if (c_gl_wr < c_gl_wr_end) { + int row = sorted_ids[c_gl_wr / c_gl_stride]; + if (row < tot_m * topk) { + int off = row * c_gl_stride + c_gl_wr % c_gl_stride; + if (!apply_weights) { + C[off] = sh[c_sh_rd]; + } else { + __half* ctrg = reinterpret_cast<__half*>(&C[off]); + __half* csrc = reinterpret_cast<__half*>(&sh[c_sh_rd]); + for (int j = 0; j < 8; ++j) { + ctrg[j] = __float2half(topk_weights[row] * __half2float(csrc[j])); + } + } + c_gl_wr += c_gl_wr_delta; + c_sh_rd += c_sh_rd_delta; + } + } + } + }; + + // Start global fetch and register load pipelines. + auto start_pipes = [&]() { + + #pragma unroll + for (int i = 0; i < stages - 1; i++) { + if (has_act_order && i == 0) { + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); + } + + if constexpr (has_zp && group_blocks == -1) { + if (i == 0) { + fetch_zp_to_shared(); + } + } + fetch_to_shared(i, i, i < slice_iters); + } + + zero_accums(); + wait_for_stage(); + init_same_group(0); + fetch_to_registers(0, 0); + fetch_scales_to_registers(0, 0); + fetch_zp_to_registers(0, 0); + a_gl_rd += a_gl_rd_delta_o * (stages - 1); + slice_k_start_shared_fetch += tb_k * (stages - 1); + }; + if (slice_iters) { + start_pipes(); + } + + // Main loop. + while (slice_iters) { + // We unroll over both the global fetch and the register load pipeline to + // ensure all shared memory accesses are static. Note that both pipelines + // have even length meaning that the next iteration will always start at + // index 0. + #pragma unroll + for (int pipe = 0; pipe < stages;) { + #pragma unroll + for (int k = 0; k < b_sh_wr_iters; k++) { + fetch_to_registers(k + 1, pipe % stages); + fetch_scales_to_registers(k + 1, pipe); + fetch_zp_to_registers(k + 1, pipe); + if (k == b_sh_wr_iters - 2) { + fetch_to_shared((pipe + stages - 1) % stages, pipe, + slice_iters >= stages); + pipe++; + wait_for_stage(); + init_same_group(pipe % stages); + } + matmul(k); + } + slice_iters--; + if (slice_iters == 0) { + break; + } + } + + a_gl_rd += a_gl_rd_delta_o * stages; + slice_k_start += tb_k * stages; + slice_k_start_shared_fetch += tb_k * stages; + + if constexpr (has_act_order) { + int first_group_id = g_idx[slice_k_start]; + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + int last_group_id = g_idx[last_g_idx]; + if (last_group_id >= sh_first_group_id + sh_num_groups) { + fetch_scales_to_shared(false, first_group_id, last_group_id); + __syncthreads(); + } + } + + // Process results and, if necessary, proceed to the next column slice. + // While this pattern may not be the most readable, other ways of writing + // the loop seemed to noticeably worse performance after compilation. + if (slice_iters == 0) { + cp_async_wait<0>(); + bool last = slice_idx == slice_count - 1; + if constexpr (!has_act_order && group_blocks == -1) { + if constexpr (w_type.size_bits() == 8) { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + } else { + // For 4-bit per-column scales, we only fetch them here in the + // final step before write-out + if (last) { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + } + } + } + + thread_block_reduce(); + if constexpr (!has_act_order && group_blocks == -1) { + if constexpr (w_type.size_bits() == 8) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + + } else { + if (last) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + } + } + } + + // For 8-bit channelwise, we apply the scale before the global reduction + // that converts the fp32 results to fp16 (so that we avoid possible + // overflow in fp16) + if constexpr (!has_act_order && group_blocks == -1 && + w_type.size_bits() == 8) { + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + scale_float(reinterpret_cast(&frag_c[i][j][0][0]), + frag_s[j / 2][2 * (j % 2) + 0]); + scale_float(reinterpret_cast(&frag_c[i][j][0][2]), + frag_s[j / 2][2 * (j % 2) + 0]); + + scale_float(reinterpret_cast(&frag_c[i][j][1][0]), + frag_s[j / 2][2 * (j % 2) + 1]); + scale_float(reinterpret_cast(&frag_c[i][j][1][2]), + frag_s[j / 2][2 * (j % 2) + 1]); + } + } + } + } + + if (slice_count > 1) { // only globally reduce if there is more than one + // block in a slice + barrier_acquire(&locks[slice_col], slice_idx); + global_reduce(slice_idx == 0, last); + barrier_release(&locks[slice_col], last); + } + if (last) // only the last block in a slice actually writes the result + write_result(); + slice_row = 0; + slice_col_par++; + slice_col++; + init_slice(); + if (slice_iters) { + a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; + if (slice_col == 0) { + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; + } + + // Update slice k/n for scales loading + if constexpr (has_act_order) { + slice_k_start = tb_k * slice_row; + slice_k_finish = slice_k_start + tb_k * slice_iters; + slice_k_start_shared_fetch = slice_k_start; + slice_n_offset = act_s_col_tb_stride * slice_col; + + } else { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + } + + start_pipes(); + } + } + } +} + +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const bool has_zp, // whether zero-points are enabled + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__global__ void MarlinMoE( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int* __restrict__ sorted_ids_base, // int32 sorted ids of experts + const float* __restrict__ topk_weights, // float topk weights + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape + // (k/groupsize)x(n/pack_factor) + const int* __restrict__ g_idx, // int32 group indices of shape k + const int* __restrict__ expert_offsets, + int num_groups, // number of scale groups per output channel + int expert_idx, // idx of current expert + int num_experts, // number of experts + int topk, // topk parameter of moe + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int tot_m, // total number of rows in A and C + int* locks, // extra global storage for barrier synchronization + bool replicate_input, // do we use the same input for each expert? + bool apply_weights, // apply weights to output + int current_m_block, // current m block to start kernel computation from + int max_par, // maximum parallelism + int cfg_max_m_blocks // upper bound on m blocks +) { + int m_block_ctr = current_m_block; + + const int* sorted_ids_expert = + sorted_ids_base + expert_offsets[expert_idx] + m_block_ctr * 4 * max_par; + int tot_its = expert_offsets[expert_idx + 1] - expert_offsets[expert_idx]; + if (tot_its == 0) { + return; + } + int tot_m_blocks = ceildiv(tot_its, 16); + int pad = 16 * tot_m_blocks - tot_its; + + if (m_block_ctr >= tot_m_blocks) { + return; + } + + int max_block = tot_m_blocks - m_block_ctr; + prob_m = tot_its - 16 * m_block_ctr; + + int par = 1; + if (max_block > cfg_max_m_blocks) { + // Note that parallel > 1 currently only works for inputs without any + // padding + par = (16 * max_block - pad) / (16 * cfg_max_m_blocks); + if (par > max_par) par = max_par; + prob_m = (16 * cfg_max_m_blocks) * par; + m_block_ctr += cfg_max_m_blocks * (par - 1); + max_block = cfg_max_m_blocks; + } + + if (max_block == 1) { + MarlinMoESingle( + A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, + expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, + prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, + current_m_block); + } else if (max_block == 2) { + MarlinMoESingle( + A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, + expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, + prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, + current_m_block); + } else if (max_block == 3) { + MarlinMoESingle( + A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, + expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, + prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, + current_m_block); + } else { + MarlinMoESingle( + A, B, C, sorted_ids_expert, topk_weights, scales_ptr, zp_ptr, g_idx, + expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, + prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, + current_m_block); + } +} + +#else + +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const bool has_zp, // whether zero-points are enabled + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__global__ void MarlinMoE( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int* __restrict__ sorted_ids, // int32 sorted ids of experts + const float* __restrict__ topk_weights, // float topk weights + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape + // (k/groupsize)x(n/pack_factor) + const int* __restrict__ g_idx, // int32 group indices of shape k + const int* __restrict__ expert_offsets, + int num_groups, // number of scale groups per output channel + int expert_idx, // idx of current expert + int num_experts, // number of experts + int topk, // topk parameter of moe + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int tot_m, // total number of rows in A and C + int* locks, // extra global storage for barrier synchronization + bool replicate_input, // do we use the same input for each expert? + bool apply_weights, // apply weights to output + int current_m_block, // current m block to start kernel computation from + int max_par, // maximum parallelism + int cfg_max_m_blocks // upper bound on m blocks +) { + // Marlin is not implemented yet for SM < 8.0 + assert(false); + return; +} + +#endif + +// 8 warps are a good choice since every SM has 4 schedulers and having more +// than 1 warp per schedule allows some more latency hiding. At the same time, +// we want relatively few warps to have many registers per warp and small tiles. +const int USER_THREADS = + 256; // Note: This is only used with user-provided thread_k/n +const int STAGES = 4; // 4 pipeline stages fit into shared memory + +static constexpr int min_thread_n = 64; +static constexpr int min_thread_k = 64; + +#define __CALL_IF_MOE(W_TYPE, THREAD_N_BLOCKS, THREAD_K_BLOCKS, HAS_ACT_ORDER, \ + HAS_ZP, GROUP_BLOCKS, NUM_THREADS) \ + else if (q_type == W_TYPE && thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \ + group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ + cudaFuncSetAttribute( \ + MarlinMoE, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + MarlinMoE \ + <<>>( \ + A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ + zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ + num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ + replicate_input, apply_weights, m_block, max_par, \ + cfg_max_m_blocks); \ + } + +#define GPTQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) + +#define AWQ_CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \ + __CALL_IF_MOE(W_TYPE, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) + +} // namespace marlin_moe diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu new file mode 100644 index 0000000000000000000000000000000000000000..77bc0dd90edde03ab80a0a77241bacfbd4955712 --- /dev/null +++ b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.cu @@ -0,0 +1,31 @@ +#include "marlin_moe_kernel_ku4.h" + +namespace marlin_moe { + +// We return bool so we can create these different kernel calls as a sequence +// of if-elseif's. +bool call_marlin_moe_kernel_ku4( + vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, + bool has_act_order, int group_blocks, int num_threads, int blocks, + int max_shared_mem, cudaStream_t stream, const int4* A_ptr, + const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, + const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, + const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, + int expert_idx, int num_experts, int topk, int prob_m, int prob_n, + int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, + int m_block, int max_par, int cfg_max_m_blocks) { + bool has_zp = true; + + if (false) { + } + AWQ_CALL_IF_MOE(vllm::kU4, 16, 4, 256) + AWQ_CALL_IF_MOE(vllm::kU4, 8, 8, 256) + AWQ_CALL_IF_MOE(vllm::kU4, 8, 4, 128) + AWQ_CALL_IF_MOE(vllm::kU4, 4, 8, 128) + else { + return false; + } + return true; +} + +} // namespace marlin_moe diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h new file mode 100644 index 0000000000000000000000000000000000000000..833fadf37721f93717e060c7c0379588704eb777 --- /dev/null +++ b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4.h @@ -0,0 +1,20 @@ +#pragma once + +#include "marlin_moe_kernel.h" + +namespace marlin_moe { + +// We return bool so we can create these different kernel calls as a sequence +// of if-elseif's. +bool call_marlin_moe_kernel_ku4( + vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, + bool has_act_order, int group_blocks, int num_threads, int blocks, + int max_shared_mem, cudaStream_t stream, const int4* A_ptr, + const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, + const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, + const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, + int expert_idx, int num_experts, int topk, int prob_m, int prob_n, + int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, + int m_block, int max_par, int cfg_max_m_blocks); + +} // namespace marlin_moe diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu new file mode 100644 index 0000000000000000000000000000000000000000..f7e57b037594539ed77d69d1968b01cc2d506afa --- /dev/null +++ b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu @@ -0,0 +1,31 @@ +#include "marlin_moe_kernel_ku4b8.h" + +namespace marlin_moe { + +// We return bool so we can create these different kernel calls as a sequence +// of if-elseif's. +bool call_marlin_moe_kernel_ku4b8( + vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, + bool has_act_order, int group_blocks, int num_threads, int blocks, + int max_shared_mem, cudaStream_t stream, const int4* A_ptr, + const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, + const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, + const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, + int expert_idx, int num_experts, int topk, int prob_m, int prob_n, + int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, + int m_block, int max_par, int cfg_max_m_blocks) { + bool has_zp = false; + + if (false) { + } + GPTQ_CALL_IF_MOE(vllm::kU4B8, 16, 4, 256) + GPTQ_CALL_IF_MOE(vllm::kU4B8, 8, 8, 256) + GPTQ_CALL_IF_MOE(vllm::kU4B8, 8, 4, 128) + GPTQ_CALL_IF_MOE(vllm::kU4B8, 4, 8, 128) + else { + return false; + } + return true; +} + +} // namespace marlin_moe diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h new file mode 100644 index 0000000000000000000000000000000000000000..494da8f10e26255d3f408ce7c8d87c53f67af845 --- /dev/null +++ b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h @@ -0,0 +1,20 @@ +#pragma once + +#include "marlin_moe_kernel.h" + +namespace marlin_moe { + +// We return bool so we can create these different kernel calls as a sequence +// of if-elseif's. +bool call_marlin_moe_kernel_ku4b8( + vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, + bool has_act_order, int group_blocks, int num_threads, int blocks, + int max_shared_mem, cudaStream_t stream, const int4* A_ptr, + const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, + const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, + const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, + int expert_idx, int num_experts, int topk, int prob_m, int prob_n, + int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, + int m_block, int max_par, int cfg_max_m_blocks); + +} // namespace marlin_moe diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu new file mode 100644 index 0000000000000000000000000000000000000000..a901f0b11cd786b9f5574823939fe3f3c40783af --- /dev/null +++ b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu @@ -0,0 +1,31 @@ +#include "marlin_moe_kernel_ku8b128.h" + +namespace marlin_moe { + +// We return bool so we can create these different kernel calls as a sequence +// of if-elseif's. +bool call_marlin_moe_kernel_ku8b128( + vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, + bool has_act_order, int group_blocks, int num_threads, int blocks, + int max_shared_mem, cudaStream_t stream, const int4* A_ptr, + const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, + const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, + const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, + int expert_idx, int num_experts, int topk, int prob_m, int prob_n, + int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, + int m_block, int max_par, int cfg_max_m_blocks) { + bool has_zp = false; + + if (false) { + } + GPTQ_CALL_IF_MOE(vllm::kU8B128, 16, 4, 256) + GPTQ_CALL_IF_MOE(vllm::kU8B128, 8, 8, 256) + GPTQ_CALL_IF_MOE(vllm::kU8B128, 8, 4, 128) + GPTQ_CALL_IF_MOE(vllm::kU8B128, 4, 8, 128) + else { + return false; + } + return true; +} + +} // namespace marlin_moe diff --git a/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h new file mode 100644 index 0000000000000000000000000000000000000000..f3018aa0c1ab7938ab99bbd0604d7b462d216721 --- /dev/null +++ b/csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h @@ -0,0 +1,18 @@ +#pragma once + +#include "marlin_moe_kernel.h" + +namespace marlin_moe { + +bool call_marlin_moe_kernel_ku8b128( + vllm::ScalarType const& q_type, int thread_n_blocks, int thread_k_blocks, + bool has_act_order, int group_blocks, int num_threads, int blocks, + int max_shared_mem, cudaStream_t stream, const int4* A_ptr, + const int4* B_ptr, int4* C_ptr, const int* sorted_ids_ptr, + const float* topk_weights_ptr, const int4* s_ptr, const int4* zp_ptr, + const int* g_idx_ptr, int* expert_offsets_ptr, int num_groups, + int expert_idx, int num_experts, int topk, int prob_m, int prob_n, + int prob_k, int tot_m, int* locks, bool replicate_input, bool apply_weights, + int m_block, int max_par, int cfg_max_m_blocks); + +} diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu new file mode 100644 index 0000000000000000000000000000000000000000..e2db4e4196b6f00cbf2cfd3872b6046d05049b32 --- /dev/null +++ b/csrc/moe/marlin_moe_ops.cu @@ -0,0 +1,587 @@ +/* + * Modified by Neural Magic + * Copyright (C) Marlin.2024 Elias Frantar + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include +#include +#include +#include +#include + +#include + +#include "core/exception.hpp" +#include "core/scalar_type.hpp" +#include "core/registration.h" +#include "marlin_kernels/marlin_moe_kernel_ku4b8.h" +#include "marlin_kernels/marlin_moe_kernel_ku8b128.h" +#include "marlin_kernels/marlin_moe_kernel_ku4.h" + +template +inline std::string str(T x) { + return std::to_string(x); +} + +namespace marlin_moe { + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + +// For a given "a" of size [M,K] performs a permutation of the K columns based +// on the given "perm" indices. +__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, + int const* __restrict__ perm_int_ptr, + int4* __restrict__ out_int4_ptr, int size_m, + int size_k, int block_rows) { + int start_row = block_rows * blockIdx.x; + int finish_row = start_row + block_rows; + if (finish_row > size_m) { + finish_row = size_m; + } + int cur_block_rows = finish_row - start_row; + + int row_stride = size_k * sizeof(half) / 16; + + auto permute_row = [&](int row) { + int iters = size_k / blockDim.x; + int rest = size_k % blockDim.x; + + int offset = row * row_stride; + + half const* a_row_half = reinterpret_cast(a_int4_ptr + offset); + half* out_half = reinterpret_cast(out_int4_ptr + offset); + + int base_k = 0; + + for (int i = 0; i < iters; i++) { + int cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + + base_k += blockDim.x; + } + + if (rest) { + if (threadIdx.x < rest) { + int cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + } + } + }; + + for (int i = 0; i < cur_block_rows; i++) { + int cur_row = start_row + i; + if (cur_row < size_m) { + permute_row(cur_row); + } + } +} + +__global__ void compute_expert_offsets(int const* __restrict__ topk_ids, + int* __restrict__ expert_offsets, + int topk_length, int block_size) { + int expert_id = threadIdx.x; + int num_experts = blockDim.x; + + int occurrences = 0; + for (int i = 0; i < topk_length; ++i) { + occurrences += (topk_ids[i] == expert_id); + } + expert_offsets[expert_id + 1] = occurrences; + __syncthreads(); + + if (threadIdx.x == 0) { + int tot_offset = 0; + expert_offsets[0] = 0; + for (int i = 0; i < num_experts; ++i) { + tot_offset += ceildiv(expert_offsets[i + 1], block_size) * block_size; + expert_offsets[i + 1] = tot_offset; + } + } + __syncthreads(); +} + +#else + +__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, + int const* __restrict__ perm_int_ptr, + int4* __restrict__ out_int4_ptr, int size_m, + int size_k, int block_rows) { + // Marlin is not implemented yet for SM < 8.0 + assert(false); + return; +} + +__global__ void compute_expert_offsets(int const* __restrict__ topk_ids, + int* __restrict__ expert_offsets, + int topk_length, int block_size) { + // Marlin is not implemented yet for SM < 8.0 + assert(false); + return; +} + +#endif + +typedef struct { + int thread_k; + int thread_n; + int num_threads; +} thread_config_t; + +typedef struct { + int max_m_blocks; + thread_config_t tb_cfg; +} exec_config_t; + +thread_config_t small_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {128, 128, 256}, // Default + {128, 64, 128}, // Reduce N 2X, same K + {64, 256, 256}, // Reduce K 2X, increase N 2X + {64, 128, 128}, // Reduce K 2X, same N + {64, 64, 128}, // Reduce both 2X +}; + +thread_config_t large_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {64, 256, 256}, // Default + {128, 128, 256}, // Reduce N 2X, increase K 2X + {64, 128, 128}, // Reduce N 2X, same K + {128, 64, 128}, // Reduce N 4X, increase K 2X + {64, 64, 128}, // Reduce N 4X, same K +}; + +int get_scales_cache_size(thread_config_t const& th_config, int prob_m, + int prob_n, int prob_k, int num_bits, int group_size, + bool has_act_order, bool is_k_full) { + bool cache_scales_chunk = has_act_order && !is_k_full; + + int tb_n = th_config.thread_n; + int tb_k = th_config.thread_k; + + // Get max scale groups per thread-block + int tb_groups; + if (group_size == -1) { + tb_groups = 1; + } else if (group_size == 0) { + tb_groups = ceildiv(tb_k, 32); // Worst case is 32 group size + } else { + tb_groups = ceildiv(tb_k, group_size); + } + + if (cache_scales_chunk) { + int load_groups = + tb_groups * STAGES * 2; // Chunk size is 2x pipeline over dim K + load_groups = max(load_groups, 32); // We load at least 32 scale groups + return load_groups * tb_n * 4; + + } else { + int tb_scales = tb_groups * tb_n * 2; + + return tb_scales * STAGES; + } +} + +bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks, + int prob_m, int prob_n, int prob_k, int num_bits, + int scales_cache_size, int max_shared_mem) { + int pack_factor = 32 / num_bits; + + // Get B size + int tb_k = th_config.thread_k; + int tb_n = th_config.thread_n; + + int b_size = (tb_k * tb_n / pack_factor) * 4; + + // Get A size + int m_blocks = ceildiv(prob_m, 16); + int tb_max_m = 16; + + while (true) { + if (m_blocks >= max_m_blocks) { + tb_max_m *= max_m_blocks; + break; + } + + max_m_blocks--; + if (max_m_blocks == 0) { + TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks); + } + } + + int a_size = (tb_max_m * tb_k) * 2; + + float pipe_size = (a_size + b_size) * STAGES; + + TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity + + return pipe_size < 0.95f * (max_shared_mem - scales_cache_size); +} + +bool is_valid_config(thread_config_t const& th_config, int max_m_blocks, + int prob_m, int prob_n, int prob_k, int num_bits, + int group_size, bool has_act_order, bool is_k_full, + int max_shared_mem) { + // Sanity + if (th_config.thread_k == -1 || th_config.thread_n == -1 || + th_config.num_threads == -1) { + return false; + } + + // Verify K/N are divisible by thread K/N + if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { + return false; + } + + // thread_k can be only 128 or 64 (because it must be less than groupsize + // which is 128) + if (th_config.thread_k != 128 && th_config.thread_k != 64) { + return false; + } + + // Verify min for thread K/N + if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { + return false; + } + + // num_threads must be at least 128 (= 4 warps) + if (th_config.num_threads < 128) { + return false; + } + + // Determine cache for scales + int scales_cache_size = + get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, + group_size, has_act_order, is_k_full); + + // Check that pipeline fits into cache + if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k, + num_bits, scales_cache_size, max_shared_mem)) { + return false; + } + + return true; +} + +exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k, + int num_bits, int group_size, + bool has_act_order, bool is_k_full, + int max_shared_mem) { + int max_m_blocks = 4; + while (max_m_blocks > 0) { + if (prob_m <= 16) { + for (auto th_config : small_batch_thread_configs) { + if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, + num_bits, group_size, has_act_order, is_k_full, + max_shared_mem)) { + return exec_config_t{max_m_blocks, th_config}; + } + } + } else { + for (auto th_config : large_batch_thread_configs) { + if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k, + num_bits, group_size, has_act_order, is_k_full, + max_shared_mem)) { + return exec_config_t{max_m_blocks, th_config}; + } + } + } + + max_m_blocks--; // Process less M blocks per invocation to reduce cache + // usage + } + + return exec_config_t{0, {-1, -1, -1}}; +} + +#define CALL_MOE_KERNEL_FUNCTION(KERNEL_FUNCTION) \ + else if (KERNEL_FUNCTION( \ + q_type, thread_n_blocks, thread_k_blocks, has_act_order, \ + group_blocks, num_threads, blocks, max_shared_mem, stream, \ + A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ + zp_ptr, g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ + num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ + replicate_input, apply_weights, m_block, max_par, \ + exec_cfg.max_m_blocks)) { \ + } + +void marlin_mm_moe(const void* A, const void* B, void* C, + const void* sorted_ids, const void* topk_weights, + const void* topk_ids, const void* s, void* zp, + const void* g_idx, const void* perm, void* a_tmp, + void* expert_offsets, int prob_m, int prob_n, int prob_k, + void* workspace, vllm::ScalarType const& q_type, + bool has_act_order, bool is_k_full, bool has_zp, + int num_groups, int group_size, int num_experts, int topk, + int moe_block_size, int dev, cudaStream_t stream, + int thread_k, int thread_n, int sms, int max_par, + bool replicate_input, bool apply_weights) { + TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, + ", ", prob_n, ", ", prob_k, "]"); + + if (sms == -1) { + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); + } + + int max_shared_mem = 0; + cudaDeviceGetAttribute(&max_shared_mem, + cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + TORCH_CHECK(max_shared_mem > 0); + + int num_bits = q_type.size_bits(); + + // Set thread config + exec_config_t exec_cfg; + if (thread_k != -1 && thread_n != -1) { + // User-defined config + exec_cfg = + exec_config_t{4, thread_config_t{thread_k, thread_n, USER_THREADS}}; + } else { + // Auto config + exec_cfg = + determine_thread_config(prob_m, prob_n, prob_k, num_bits, group_size, + has_act_order, is_k_full, max_shared_mem); + } + + TORCH_CHECK(exec_cfg.max_m_blocks > 0 && + is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks, + prob_m, prob_n, prob_k, num_bits, group_size, + has_act_order, is_k_full, max_shared_mem), + "Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks, + ", thread_k = ", exec_cfg.tb_cfg.thread_k, + ", thread_n = ", exec_cfg.tb_cfg.thread_n, + ", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [", + prob_m, ", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits, + ", group_size = ", group_size, + ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full, + ", max_shared_mem = ", max_shared_mem); + + int num_threads = exec_cfg.tb_cfg.num_threads; + thread_k = exec_cfg.tb_cfg.thread_k; + thread_n = exec_cfg.tb_cfg.thread_n; + + int thread_k_blocks = thread_k / 16; + int thread_n_blocks = thread_n / 16; + + int blocks = sms; + + TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, + " is not divisible by thread_n = ", thread_n); + TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, + " is not divisible by thread_k = ", thread_k); + + int group_blocks = 0; + if (has_act_order) { + if (is_k_full) { + TORCH_CHECK(group_size != -1); + group_blocks = group_size / 16; + TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, + " is not divisible by group_blocks = ", group_blocks); + } else { + TORCH_CHECK(group_size == 0); + group_blocks = 0; + } + + } else { + if (group_size == -1) { + group_blocks = -1; + } else { + group_blocks = group_size / 16; + TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, + " is not divisible by group_blocks = ", group_blocks); + } + } + + int tot_m = prob_m; + + const int* topk_ids_ptr = (const int*)topk_ids; + int* expert_offsets_ptr = (int*)expert_offsets; + compute_expert_offsets<<<1, num_experts, 0, stream>>>( + topk_ids_ptr, expert_offsets_ptr, tot_m * topk, moe_block_size); + + bool do_permute_a = has_act_order; + + // If we have a full K, then we can run the non-act-order version of Marlin + // (since the weight rows are reordered by increasing group ids, and by + // having a full K, we have full original groups) + if (is_k_full) { + has_act_order = false; + } + + int pack_factor = 32 / q_type.size_bits(); + + for (int expert_idx = 0; expert_idx < num_experts; ++expert_idx) { + const int4* A_ptr = (const int4*)A; + int4* a_tmp_ptr = (int4*)a_tmp; + const int4* B_ptr = + (const int4*)B + (prob_n * prob_k / (pack_factor * 4)) * expert_idx; + int4* C_ptr = (int4*)C; + const float* topk_weights_ptr = (const float*)topk_weights; + const int* sorted_ids_ptr = (const int*)sorted_ids; + const int4* s_ptr = (const int4*)s + num_groups * prob_n / 8 * expert_idx; + const int4* zp_ptr = + (const int4*)zp + num_groups * prob_n / (pack_factor * 4) * expert_idx; + const int* g_idx_ptr = (const int*)g_idx + prob_k * expert_idx; + const int* perm_ptr = (const int*)perm + prob_k * expert_idx; + int* locks = (int*)workspace; + + if (do_permute_a) { + // Permute A columns + int topk_rows = replicate_input ? tot_m : tot_m * topk; + int block_rows = ceildiv(topk_rows, blocks); + permute_cols_kernel<<>>( + A_ptr, perm_ptr, a_tmp_ptr, topk_rows, prob_k, block_rows); + A_ptr = a_tmp_ptr; + } + + int tot_m_blocks = ceildiv(tot_m, 16); + for (int m_block = 0; m_block < tot_m_blocks; + m_block += 4 * exec_cfg.max_m_blocks) { + if (false) { + } + CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4b8) + CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku8b128) + CALL_MOE_KERNEL_FUNCTION(call_marlin_moe_kernel_ku4) + else { + TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + + str(prob_n) + ", " + str(prob_k) + "]" + + ", has_act_order = " + str(has_act_order) + + ", num_groups = " + str(num_groups) + + ", group_size = " + str(group_size) + + ", thread_n_blocks = " + str(thread_n_blocks) + + ", thread_k_blocks = " + str(thread_k_blocks)); + } + } + } +} + +} // namespace marlin_moe + +torch::Tensor marlin_gemm_moe( + const torch::Tensor& a, const torch::Tensor& b_q_weights, + const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights, + const torch::Tensor& topk_ids, const torch::Tensor& b_scales, + torch::Tensor& b_zeros, const torch::Tensor& g_idx, + const torch::Tensor& perm, torch::Tensor& workspace, + vllm::ScalarTypeTorchPtr const& b_q_type, int64_t size_m, int64_t size_n, + int64_t size_k, bool is_k_full, int64_t num_experts, int64_t topk, + int64_t moe_block_size, bool replicate_input, bool apply_weights) { + bool has_zp = b_zeros.size(1) != 0; + if (has_zp) { + TORCH_CHECK( + *b_q_type == vllm::kU4, + "b_q_type must be u4 when has_zp = True. Got = ", b_q_type->str()); + } else { + TORCH_CHECK( + *b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128, + "b_q_type must be uint4b8 or uint8b128. Got = ", b_q_type->str()); + } + + int pack_factor = 32 / b_q_type->size_bits(); + + int max_par = 4; + + int dev = a.get_device(); + + auto options_dtype = + torch::TensorOptions().dtype(a.dtype()).device(a.device()); + auto options_int = + torch::TensorOptions().dtype(torch::kInt).device(a.device()); + torch::Tensor c = torch::zeros({size_m, topk, size_n}, options_dtype); + torch::Tensor a_tmp = + replicate_input ? torch::zeros({size_m, size_k}, options_dtype) + : torch::zeros({size_m, topk, size_k}, options_dtype); + torch::Tensor expert_offsets = torch::empty({num_experts + 1}, options_int); + + // thread_k: `k` size of a thread_tile in `weights` (can usually be left as + // auto -1) + int thread_k = -1; + // thread_n: `n` size of a thread_tile in `weights` (can usually be left as + // auto -1) + int thread_n = -1; + // sms: number of SMs to use for the kernel (can usually be left as auto -1) + int sms = -1; + + // Detect groupsize and act_order + int num_groups = -1; + int group_size = -1; + bool has_act_order = g_idx.size(1) != 0; + + int b_rank = b_scales.sizes().size(); + TORCH_CHECK(b_rank == 3, "b_scales rank = ", b_rank, " is not 3"); + TORCH_CHECK(b_scales.size(2) == size_n, "b_scales dim 2 = ", b_scales.size(2), + " is not size_n = ", size_n); + num_groups = b_scales.size(1); + + TORCH_CHECK(VLLM_IMPLIES(!is_k_full, has_act_order), + "if is_k_full is false, has_act_order must be true"); + + if (has_act_order) { + if (is_k_full) { + TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1"); + TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k, + ", is not divisible by num_groups = ", num_groups); + group_size = size_k / num_groups; + } else { + group_size = 0; + } + + } else { + if (num_groups > 1) { + TORCH_CHECK( + size_k % num_groups == 0, "size_k = ", size_k, + ", is not divisible by b_scales.size(0) = ", b_scales.size(0)); + group_size = size_k / num_groups; + } else { + group_size = -1; + } + } + + // Verify b_zeros + if (has_zp) { + int rank = b_zeros.sizes().size(); + TORCH_CHECK(rank == 3, "b_zeros rank = ", rank, " is not 3"); + TORCH_CHECK(b_zeros.size(1) == num_groups, + "b_zeros dim 1 = ", b_zeros.size(1), + " is not num_groups = ", num_groups); + TORCH_CHECK(b_zeros.size(2) == size_n / pack_factor, + "b_zeros dim 2 = ", b_zeros.size(2), + " is not size_n / pack_factor = ", size_n / pack_factor); + } + + marlin_moe::marlin_mm_moe( + a.data_ptr(), b_q_weights.data_ptr(), c.data_ptr(), sorted_ids.data_ptr(), + topk_weights.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(), + b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), + expert_offsets.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(), + *b_q_type, has_act_order, is_k_full, has_zp, num_groups, group_size, + num_experts, topk, moe_block_size, dev, + at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par, + replicate_input, apply_weights); + return c; +} + +TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { + m.impl("marlin_gemm_moe", &marlin_gemm_moe); +} diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 86e42af44df151b897ce06c69ba88d5d5e26f66f..18fbc57ac78343cb0dbf303cc8dbffee0d19f101 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -7,6 +7,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! " "token_expert_indices, Tensor gating_output) -> ()"); m.impl("topk_softmax", torch::kCUDA, &topk_softmax); + +#ifndef USE_ROCM + m.def( + "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, " + "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! " + "b_zeros, Tensor! g_idx, Tensor! perm, Tensor! workspace, " + "__torch__.torch.classes._core_C.ScalarType b_q_type, int size_m, " + "int size_n, int size_k, bool is_k_full, int num_experts, int topk, " + "int moe_block_size, bool replicate_input, bool apply_weights)" + " -> Tensor"); + // conditionally compiled so impl registration is in source file +#endif } REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/csrc/ops.h b/csrc/ops.h index e44cf358a29f55d64b1754f44a5023034cd10c8e..93f36e67bee57a3d5d7afd67462c38e7071916f7 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -47,6 +47,27 @@ void paged_attention_v2_opt( const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step); +void paged_attention_v1_opt_tc( + torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int64_t num_kv_heads, double scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, + int64_t max_seq_len, const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, double k_scale, double v_scale, + const int64_t tp_rank, const int64_t blocksparse_local_blocks, + const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, + const int64_t blocksparse_head_sliding_step); + +void paged_attention_v2_opt_tc( + torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, + torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int64_t num_kv_heads, double scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, + int64_t max_seq_len, const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, double k_scale, double v_scale, + const int64_t tp_rank, const int64_t blocksparse_local_blocks, + const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, + const int64_t blocksparse_head_sliding_step); + void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight, double epsilon); @@ -96,21 +117,32 @@ void gelu_quick(torch::Tensor& out, torch::Tensor& input); void trans_w16_gemm(torch::Tensor dst, torch::Tensor src, int64_t row, int64_t col); -void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size, - torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids, - torch::Tensor& input_positions, torch::Tensor& seq_lens, - torch::Tensor& slot_mapping, torch::Tensor& block_tables); +void advance_step_flashattn(int64_t num_seqs, int64_t num_queries, + int64_t block_size, torch::Tensor& input_tokens, + torch::Tensor& sampled_token_ids, + torch::Tensor& input_positions, + torch::Tensor& seq_lens, + torch::Tensor& slot_mapping, + torch::Tensor& block_tables); + +void advance_step_flashinfer( + int64_t num_seqs, int64_t num_queries, int64_t block_size, + torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids, + torch::Tensor& input_positions, torch::Tensor& seq_lens, + torch::Tensor& slot_mapping, torch::Tensor& block_tables, + torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr, + torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds); #ifndef USE_ROCM torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes, const torch::Tensor& codebooks, const torch::Tensor& scales, - const torch::Tensor& codebook_partition_sizes, + const std::vector& codebook_partition_sizes, const std::optional& bias); -torch::Tensor aqlm_dequant(const torch::Tensor& codes, - const torch::Tensor& codebooks, - const torch::Tensor& codebook_partition_sizes); +torch::Tensor aqlm_dequant( + const torch::Tensor& codes, const torch::Tensor& codebooks, + const std::vector& codebook_partition_sizes); torch::Tensor awq_gemm(torch::Tensor _in_feats, torch::Tensor _kernel, torch::Tensor _scaling_factors, torch::Tensor _zeros, @@ -121,38 +153,16 @@ torch::Tensor awq_dequantize(torch::Tensor _kernel, torch::Tensor _zeros, int64_t split_k_iters, int64_t thx, int64_t thy); -torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_scales, torch::Tensor& workspace, - int64_t size_m, int64_t size_n, int64_t size_k); - -torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_meta, - torch::Tensor& b_scales, - torch::Tensor& workspace, - vllm::ScalarTypeTorchPtr const& b_q_type, - int64_t size_m, int64_t size_n, - int64_t size_k); - -torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_scales, torch::Tensor& b_zeros, - torch::Tensor& g_idx, torch::Tensor& perm, - torch::Tensor& workspace, - vllm::ScalarTypeTorchPtr const& b_q_type, - int64_t size_m, int64_t size_n, int64_t size_k, - bool is_k_full, bool has_zp, - bool use_fp32_reduce); - -torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, - int64_t size_k, int64_t size_n, - int64_t num_bits); - -torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, - int64_t size_n, int64_t num_bits); - -torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_scales, torch::Tensor& workspace, - int64_t num_bits, int64_t size_m, int64_t size_n, - int64_t size_k); +torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm); + +torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m, + int64_t n); + +torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X, + int64_t type, int64_t row); + +torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type, + int64_t row); bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability); @@ -161,30 +171,29 @@ void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b_scales, c10::optional const& bias); -torch::Tensor marlin_qqq_gemm(torch::Tensor const& a, - torch::Tensor const& b_q_weight, - torch::Tensor const& s_tok, - torch::Tensor const& s_ch, - torch::Tensor const& s_group, - torch::Tensor& workspace, int64_t size_m, - int64_t size_n, int64_t size_k); +void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& azp_adj, + c10::optional const& azp, + c10::optional const& bias); #endif void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, - torch::Tensor const& scale); + torch::Tensor const& scale, + c10::optional const& azp); void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, - torch::Tensor& scales); + torch::Tensor& scales, + c10::optional const& azp); -void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, - torch::Tensor lookup_table); +// torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, +// torch::Tensor b_gptq_qzeros, +// torch::Tensor b_gptq_scales, torch::Tensor b_g_idx, +// bool use_exllama, int64_t bit); -torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, - torch::Tensor b_gptq_qzeros, - torch::Tensor b_gptq_scales, torch::Tensor b_g_idx, - bool use_exllama, int64_t bit); - -void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit); +// void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit); // void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input, // torch::Tensor const& scale); @@ -201,14 +210,40 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad); +void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta, + const torch::Tensor& A, const torch::Tensor& B, + const torch::Tensor& C, + const c10::optional& D_, + const c10::optional& z_, + const c10::optional& delta_bias_, + bool delta_softplus, + const c10::optional& query_start_loc, + const c10::optional& cache_indices, + const c10::optional& has_initial_state, + const torch::Tensor& ssm_states, int64_t pad_slot_id); + +void causal_conv1d_update(const at::Tensor& x, const at::Tensor& conv_state, + const at::Tensor& weight, + const c10::optional& bias_, + bool silu_activation, + const c10::optional& cache_seqlens_, + const c10::optional& conv_state_indices_, + int64_t pad_slot_id); + +void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight, + const c10::optional& bias_, + const c10::optional& conv_states, + const c10::optional& query_start_loc, + const c10::optional& cache_indices, + const c10::optional& has_initial_state, + bool silu_activation, int64_t pad_slot_id); + #ifndef USE_ROCM using fptr_t = int64_t; fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, const std::vector& handles, const std::vector& offsets, int64_t rank, bool full_nvlink); -bool should_custom_ar(torch::Tensor& inp, int64_t max_size, int64_t world_size, - bool full_nvlink); void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out); void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, torch::Tensor& out); diff --git a/csrc/opt/layernorm_kernels_opt.cu b/csrc/opt/layernorm_kernels_opt.cu index b07ce5f41cf3c151ab3afd631c615efc350e203b..cdc4f4a97de8e4f002a29566ef73cf5cfce71553 100644 --- a/csrc/opt/layernorm_kernels_opt.cu +++ b/csrc/opt/layernorm_kernels_opt.cu @@ -6,13 +6,17 @@ #include #include #include "../dispatch_utils.h" -#include "../reduction_utils.cuh" + #ifndef USE_ROCM #include #include + #include + #include #else #include #include + #include + #include using __nv_bfloat16 = __hip_bfloat16; using __nv_bfloat162 = __hip_bfloat162; @@ -34,7 +38,11 @@ __global__ void rms_norm_kernel( const float x = (float)input[blockIdx.x * hidden_size + idx]; variance += x * x; } - variance = blockReduceSum(variance); + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); + if (threadIdx.x == 0) { s_variance = rsqrtf(variance / hidden_size + epsilon); } @@ -231,12 +239,11 @@ fused_add_rms_norm_kernel( variance += temp.sum_squares(); residual_v[id] = temp; } - /* Keep the following if-else block in sync with the - calculation of max_block_size in fused_add_rms_norm */ - if (num_tokens < 256) { - variance = blockReduceSum(variance); - } else - variance = blockReduceSum(variance); + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); + if (threadIdx.x == 0) { s_variance = rsqrtf(variance / hidden_size + epsilon); } @@ -271,12 +278,11 @@ fused_add_rms_norm_kernel( variance += x * x; residual[blockIdx.x * hidden_size + idx] = z; } - /* Keep the following if-else block in sync with the - calculation of max_block_size in fused_add_rms_norm */ - if (num_tokens < 256) { - variance = blockReduceSum(variance); - } else - variance = blockReduceSum(variance); + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); + if (threadIdx.x == 0) { s_variance = rsqrtf(variance / hidden_size + epsilon); } diff --git a/csrc/permute_cols.cu b/csrc/permute_cols.cu new file mode 100644 index 0000000000000000000000000000000000000000..f51fa73298cc15b764504fc87c9580e0fd1a2d05 --- /dev/null +++ b/csrc/permute_cols.cu @@ -0,0 +1,88 @@ +#include + +#include +#include + +#include + +static constexpr int default_threads = 256; +static constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; } + +// For a given "a" of size [M,K] performs a permutation of the K columns based +// on the given "perm" indices. +// Currently only supports 16bit types (since we permute half types) +__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, + int const* __restrict__ perm_int_ptr, + int4* __restrict__ out_int4_ptr, int size_m, + int size_k, int block_rows) { + int start_row = block_rows * blockIdx.x; + int finish_row = start_row + block_rows; + if (finish_row > size_m) { + finish_row = size_m; + } + int cur_block_rows = std::max(finish_row - start_row, 0); + + int row_stride = size_k * sizeof(half) / 16; + + auto permute_row = [&](int row) { + int iters = size_k / default_threads; + int rest = size_k % default_threads; + + int offset = row * row_stride; + + half const* a_row_half = reinterpret_cast(a_int4_ptr + offset); + half* out_half = reinterpret_cast(out_int4_ptr + offset); + + int base_k = 0; + + for (int i = 0; i < iters; i++) { + int cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + + base_k += default_threads; + } + + if (rest) { + if (threadIdx.x < rest) { + int cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + } + } + }; + + for (int i = 0; i < cur_block_rows; i++) { + int cur_row = start_row + i; + if (cur_row < size_m) { + permute_row(cur_row); + } + } +} + +// More efficient version of A[..., perm] +// taken from gptq_marlin.cu +torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(A)); + auto dev = A.get_device(); + auto stream = at::cuda::getCurrentCUDAStream(dev); + + TORCH_CHECK(A.scalar_type() == at::kHalf || A.scalar_type() == at::kBFloat16, + "Currently only 16bit types are supported"); + TORCH_CHECK(A.is_contiguous(), "A must be contiguous"); + TORCH_CHECK(A.size(-1) % 8 == 0, + "A columns must be a multiple of 8 (128bits)"); + auto A_2d = A.view({-1, A.size(-1)}); + + torch::Tensor D = torch::empty_like(A); + int sms; + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); + int block_rows = div_ceil(A_2d.size(0), sms); + permute_cols_kernel<<>>( + reinterpret_cast(A_2d.const_data_ptr()), + perm.const_data_ptr(), reinterpret_cast(D.mutable_data_ptr()), + A_2d.size(0), A_2d.size(1), block_rows); + return D; +} \ No newline at end of file diff --git a/csrc/prepare_inputs/advance_step.cu b/csrc/prepare_inputs/advance_step.cu index 0e537ddd6c4cd6dad0a9d68e00d085ba927ffb32..46fef79f439fb189eb02e8b3d303afd3ee919d53 100644 --- a/csrc/prepare_inputs/advance_step.cu +++ b/csrc/prepare_inputs/advance_step.cu @@ -12,13 +12,22 @@ namespace prepare_inputs { // template -__global__ void advance_step_kernel(int num_seqs, int num_queries, - int block_size, long* input_tokens_ptr, - long const* sampled_token_ids_ptr, - long* input_positions_ptr, - int* seq_lens_ptr, long* slot_mapping_ptr, - int const* block_tables_ptr, - int64_t const block_tables_stride) { +__global__ void advance_step_flashattn_kernel( + int num_seqs, int num_queries, int block_size, long* input_tokens_ptr, + long const* sampled_token_ids_ptr, long* input_positions_ptr, + int* seq_lens_ptr, long* slot_mapping_ptr, int const* block_tables_ptr, + int64_t const block_tables_stride) { + int const n_pad = num_seqs - num_queries; + if (n_pad && blockIdx.x == 0) { + // Handle cuda graph padding + int const offset = num_queries; + for (int i = threadIdx.x; i < n_pad; i += blockDim.x) { + input_tokens_ptr[offset + i] = 0; + input_positions_ptr[offset + i] = 0; + slot_mapping_ptr[offset + i] = -1; + } + } + int num_query_blocks = div_ceil(num_queries, num_threads); if (blockIdx.x >= num_query_blocks) { @@ -54,7 +63,7 @@ __global__ void advance_step_kernel(int num_seqs, int num_queries, slot_mapping_ptr[cur_query_id] = slot_num; } -inline void verify_tensor(std::string const& name, torch::Tensor& t, +inline void verify_tensor(std::string const& name, torch::Tensor const& t, int64_t const size_0, int64_t const size_1, c10::ScalarType const type) { bool size_0_cond = true; @@ -79,16 +88,91 @@ inline void verify_tensor(std::string const& name, torch::Tensor& t, } } -void advance_step(int num_seqs, int num_queries, int block_size, - torch::Tensor& input_tokens, // type: long - torch::Tensor& sampled_token_ids, // type: long - torch::Tensor& input_positions, // type: long - torch::Tensor& seq_lens, // type: int - torch::Tensor& slot_mapping, // type: long - torch::Tensor& block_tables) { // type: int +__global__ void advance_step_flashinfer_kernel( + int num_threads, int num_seqs, int num_queries, int block_size, + long* input_tokens_ptr, long const* sampled_token_ids_ptr, + long* input_positions_ptr, int* seq_lens_ptr, long* slot_mapping_ptr, + int const* block_tables_ptr, int64_t const block_tables_stride, + int* paged_kv_last_page_len_ptr, int* block_table_bound_ptr) { + int num_query_blocks = div_ceil(num_queries, num_threads); + + if (blockIdx.x < num_query_blocks) { + int cur_query_id = blockIdx.x * num_threads + threadIdx.x; + + if (cur_query_id < num_queries) { + // Update input_tokens + input_tokens_ptr[cur_query_id] = sampled_token_ids_ptr[cur_query_id]; + + int seq_len = seq_lens_ptr[cur_query_id]; + int next_seq_len = seq_len + 1; + int next_input_pos = next_seq_len - 1; + + // Update seq_lens + seq_lens_ptr[cur_query_id] = next_seq_len; + // Update input_positions + input_positions_ptr[cur_query_id] = next_input_pos; + + int const* seq_block_tables_ptr = + block_tables_ptr + block_tables_stride * cur_query_id; + + int block_index = next_input_pos / block_size; + int block_offset = next_input_pos % block_size; + + // Update paged_kv_last_page_len + paged_kv_last_page_len_ptr[cur_query_id] = block_offset + 1; + + int slot_num = + seq_block_tables_ptr[block_index] * block_size + block_offset; + // Update slot_mapping + slot_mapping_ptr[cur_query_id] = slot_num; + block_table_bound_ptr[cur_query_id] = div_ceil(next_seq_len, block_size); + } + } +} + +__global__ void advance_step_flashinfer_indptr_kernel( + int num_threads, int num_seqs, int num_queries, int* paged_kv_indptr_ptr, + int* block_table_bound_ptr) { + int idx = blockIdx.x * num_threads + threadIdx.x; + + // Update paged_kv_indptr + if (idx < num_queries) { + int sum = 0; + for (int i = 0; i <= idx; ++i) { + sum += block_table_bound_ptr[i]; + } + paged_kv_indptr_ptr[idx + 1] = sum; + } +} + +__global__ void advance_step_flashinfer_indices_kernel( + int num_threads, int num_seqs, int num_queries, int const* block_tables_ptr, + int64_t const block_tables_stride, int* paged_kv_indices_ptr, + int* paged_kv_indptr_ptr, int* block_table_bound_ptr) { + int idx = blockIdx.x * num_threads + threadIdx.x; + int row = idx / block_tables_stride; + int col = idx % block_tables_stride; + + if (row < num_queries && col < block_table_bound_ptr[row]) { + paged_kv_indices_ptr[paged_kv_indptr_ptr[row] + col] = + block_tables_ptr[row * block_tables_stride + col]; + } + // if cudagraph, fill padded seqs with the last valid seq's indptr + if (num_queries < row && row <= num_seqs) { + paged_kv_indptr_ptr[row] = paged_kv_indptr_ptr[num_queries]; + } +} + +void advance_step_flashattn(int num_seqs, int num_queries, int block_size, + torch::Tensor& input_tokens, // type: long + torch::Tensor& sampled_token_ids, // type: long + torch::Tensor& input_positions, // type: long + torch::Tensor& seq_lens, // type: int + torch::Tensor& slot_mapping, // type: long + torch::Tensor& block_tables) { // type: int if (logging) { - printf("advance_step:\n"); + printf("advance_step_flashattn:\n"); printf(" num_seqs = %d\n", num_seqs); printf(" num_queries = %d\n", num_queries); printf(" block_size = %d\n", block_size); @@ -108,24 +192,126 @@ void advance_step(int num_seqs, int num_queries, int block_size, int blocks; cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev); - advance_step_kernel<<>>( - num_seqs, num_queries, block_size, + advance_step_flashattn_kernel + <<>>( + num_seqs, num_queries, block_size, + reinterpret_cast(input_tokens.data_ptr()), + reinterpret_cast(sampled_token_ids.data_ptr()), + reinterpret_cast(input_positions.data_ptr()), + reinterpret_cast(seq_lens.data_ptr()), + reinterpret_cast(slot_mapping.data_ptr()), + reinterpret_cast(block_tables.data_ptr()), + block_tables.stride(0)); +} + +void advance_step_flashinfer( + int num_seqs, int num_queries, int block_size, + torch::Tensor& input_tokens, // type: long + torch::Tensor& sampled_token_ids, // type: long + torch::Tensor& input_positions, // type: long + torch::Tensor& seq_lens, // type: int + torch::Tensor& slot_mapping, // type: long + torch::Tensor& block_tables, // type: int + torch::Tensor& paged_kv_indices, // type: int + torch::Tensor& paged_kv_indptr, // type: int + torch::Tensor& paged_kv_last_page_len, // type: int + torch::Tensor& block_table_bound) { // type: int + + if (logging) { + printf("advance_step_flashinfer:\n"); + printf(" num_seqs = %d\n", num_seqs); + printf(" num_queries = %d\n", num_queries); + printf(" block_size = %d\n", block_size); + printf(" block_tables.stride(0) = %zu\n", block_tables.stride(0)); + } + // Verify all tensors + verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong); + // verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1, + // at::kLong); + verify_tensor("input_positions", input_positions, num_seqs, -1, at::kLong); + verify_tensor("seq_lens", seq_lens, num_seqs, -1, at::kInt); + verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, at::kLong); + verify_tensor("block_tables", block_tables, num_seqs, -1, at::kInt); + + verify_tensor("paged_kv_indices", paged_kv_indices, -1, -1, at::kInt); + verify_tensor("paged_kv_indptr", paged_kv_indptr, num_seqs + 1, -1, at::kInt); + verify_tensor("paged_kv_last_page_len", paged_kv_last_page_len, num_seqs, -1, + at::kInt); + + verify_tensor("block_table_bound", block_table_bound, num_seqs, -1, at::kInt); + + int dev = sampled_token_ids.get_device(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev); + + int blocks; + int threads; + cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev); + cudaDeviceGetAttribute(&threads, cudaDevAttrMaxThreadsPerBlock, dev); + if (logging) { + printf("launching kernel with %d blocks\n", blocks); + } + + // TODO(will): support arbitrary block_tables stride + if ((blocks * threads) / block_tables.stride(0) < num_queries) { + TORCH_CHECK(false, + "multi-step: not enough threads to map block_table to" + "FlashInfer's paged_kv_indices on GPU. Try reducing the number " + "of seqs,", + " increasing the block size or take smaller steps.", + " num_queries = ", num_queries, + " block_tables.stride(0) = ", block_tables.stride(0), + " blocks = ", blocks, " max_threads = ", threads); + } + + advance_step_flashinfer_kernel<<>>( + threads, num_seqs, num_queries, block_size, reinterpret_cast(input_tokens.data_ptr()), reinterpret_cast(sampled_token_ids.data_ptr()), reinterpret_cast(input_positions.data_ptr()), reinterpret_cast(seq_lens.data_ptr()), reinterpret_cast(slot_mapping.data_ptr()), reinterpret_cast(block_tables.data_ptr()), - block_tables.stride(0)); + block_tables.stride(0), + reinterpret_cast(paged_kv_last_page_len.data_ptr()), + reinterpret_cast(block_table_bound.data_ptr())); + + advance_step_flashinfer_indptr_kernel<<>>( + threads, num_seqs, num_queries, + reinterpret_cast(paged_kv_indptr.data_ptr()), + reinterpret_cast(block_table_bound.data_ptr())); + + advance_step_flashinfer_indices_kernel<<>>( + threads, num_seqs, num_queries, + reinterpret_cast(block_tables.data_ptr()), + block_tables.stride(0), + reinterpret_cast(paged_kv_indices.data_ptr()), + reinterpret_cast(paged_kv_indptr.data_ptr()), + reinterpret_cast(block_table_bound.data_ptr())); } } // namespace prepare_inputs -void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size, - torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids, - torch::Tensor& input_positions, torch::Tensor& seq_lens, - torch::Tensor& slot_mapping, torch::Tensor& block_tables) { - prepare_inputs::advance_step(num_seqs, num_queries, block_size, input_tokens, - sampled_token_ids, input_positions, seq_lens, - slot_mapping, block_tables); -} \ No newline at end of file +void advance_step_flashattn(int64_t num_seqs, int64_t num_queries, + int64_t block_size, torch::Tensor& input_tokens, + torch::Tensor& sampled_token_ids, + torch::Tensor& input_positions, + torch::Tensor& seq_lens, + torch::Tensor& slot_mapping, + torch::Tensor& block_tables) { + prepare_inputs::advance_step_flashattn( + num_seqs, num_queries, block_size, input_tokens, sampled_token_ids, + input_positions, seq_lens, slot_mapping, block_tables); +} + +void advance_step_flashinfer( + int64_t num_seqs, int64_t num_queries, int64_t block_size, + torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids, + torch::Tensor& input_positions, torch::Tensor& seq_lens, + torch::Tensor& slot_mapping, torch::Tensor& block_tables, + torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr, + torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bound) { + prepare_inputs::advance_step_flashinfer( + num_seqs, num_queries, block_size, input_tokens, sampled_token_ids, + input_positions, seq_lens, slot_mapping, block_tables, paged_kv_indices, + paged_kv_indptr, paged_kv_last_page_len, block_table_bound); +} diff --git a/csrc/quantization/aqlm/gemm_kernels.cu b/csrc/quantization/aqlm/gemm_kernels.cu index 22da5e4f08a18a3024d7712f2469e69ea89f1200..79cd2c610b3c276ef16a0ae85aacd9b06fffba68 100644 --- a/csrc/quantization/aqlm/gemm_kernels.cu +++ b/csrc/quantization/aqlm/gemm_kernels.cu @@ -496,14 +496,14 @@ torch::Tensor code2x8_matmat(const torch::Tensor& input, } // Accumulate the partition sizes. -int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes) { +int4 accumulate_sizes(const std::vector& codebook_partition_sizes) { int4 cumulative_sizes; auto cumulative_size = &cumulative_sizes.x; - int i = 0; + size_t i = 0; int last = 0; - assert(codebook_partition_sizes.size(0) <= 4); - for (; i < codebook_partition_sizes.size(0); ++i, ++cumulative_size) { - *cumulative_size = codebook_partition_sizes[i].item() + last; + assert(codebook_partition_sizes.size() <= 4); + for (; i < codebook_partition_sizes.size(); ++i, ++cumulative_size) { + *cumulative_size = codebook_partition_sizes[i] + last; last = *cumulative_size; } // fill in the rest with unreachable. @@ -519,12 +519,12 @@ int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes) { torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes, const torch::Tensor& codebooks, const torch::Tensor& scales, - const torch::Tensor& codebook_partition_sizes, + const std::vector& codebook_partition_sizes, const std::optional& bias) { int4 cumulative_sizes = vllm::aqlm::accumulate_sizes(codebook_partition_sizes); - int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0); + int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(); int const entries = codebooks.size(1); if (nbooks == 1 && entries == (1 << 16)) { @@ -541,13 +541,13 @@ torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes, return {}; } -torch::Tensor aqlm_dequant(const torch::Tensor& codes, - const torch::Tensor& codebooks, - const torch::Tensor& codebook_partition_sizes) { +torch::Tensor aqlm_dequant( + const torch::Tensor& codes, const torch::Tensor& codebooks, + const std::vector& codebook_partition_sizes) { int4 cumulative_sizes = vllm::aqlm::accumulate_sizes(codebook_partition_sizes); - int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(0); + int const nbooks = codebooks.size(0) / codebook_partition_sizes.size(); int const entries = codebooks.size(1); const at::cuda::OptionalCUDAGuard device_guard(device_of(codes)); @@ -557,7 +557,8 @@ torch::Tensor aqlm_dequant(const torch::Tensor& codes, auto in_features = codes.size(1) * 8; auto out_features = codes.size(0); - assert(out_features = codebook_partition_sizes.sum().item()); + assert(out_features == std::accumulate(codebook_partition_sizes.begin(), + codebook_partition_sizes.end(), 0)); auto weights = torch::empty({out_features, in_features}, torch::TensorOptions() diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu index aa9511daa27729d839b6a8789eab8089ff671301..e9987535bd3eaa7b2e8c480aeb21073097b89eb7 100644 --- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu +++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu @@ -3,16 +3,28 @@ #include #include "../../dispatch_utils.h" -#include "../../reduction_utils.cuh" + +#ifndef USE_ROCM + #include + #include +#else + #include + #include +#endif static inline __device__ int8_t float_to_int8_rn(float x) { #ifdef USE_ROCM - static const float i8_min = + static constexpr auto i8_min = static_cast(std::numeric_limits::min()); - static const float i8_max = + static constexpr auto i8_max = static_cast(std::numeric_limits::max()); - // round + + // To match the rounding mode of CUDA, we use nearbyint. + // It uses the current rounding mode, which is always FE_TONEAREST on HIP. + // If that changes in the future, we may need to set the rounding mode + // explicitly, either at runtime or compile time. float dst = std::nearbyint(x); + // saturate dst = std::clamp(dst, i8_min, i8_max); return static_cast(dst); @@ -24,6 +36,59 @@ static inline __device__ int8_t float_to_int8_rn(float x) { #endif } +static inline __device__ int32_t float_to_int32_rn(float x) { +#ifdef USE_ROCM + // int32_max is not exactly representable as float. + // Therefore, we need to be careful and manually return int32_max on overflow. + // For symmetry, we also do the same for int32_min, even though it is exactly + // representable as float and the conversion should be exact. + static constexpr auto i32_min = std::numeric_limits::min(); + static constexpr auto i32_min_f = static_cast(i32_min); + static constexpr auto i32_max = std::numeric_limits::max(); + static constexpr auto i32_max_f = static_cast(i32_max); + + // To match the rounding mode of CUDA, we use nearbyint. + // It uses the current rounding mode, which is always FE_TONEAREST on HIP. + // If that changes in the future, we may need to set the rounding mode + // explicitly, either at runtime or compile time. + float dst = std::nearbyint(x); + + // saturate on the higher end. + if (dst >= i32_max_f) { + return i32_max; + } + // saturate on the lower end. + if (dst <= i32_min_f) { + return i32_min; + } + + return static_cast(dst); +#else + // CUDA path + uint32_t dst; + asm volatile("cvt.rni.sat.s32.f32 %0, %1;" : "=r"(dst) : "f"(x)); + return reinterpret_cast(dst); +#endif +} + +static inline __device__ int8_t int32_to_int8(int32_t x) { +#ifdef USE_ROCM + static constexpr auto i8_min = + static_cast(std::numeric_limits::min()); + static constexpr auto i8_max = + static_cast(std::numeric_limits::max()); + + // saturate + int32_t dst = std::clamp(x, i8_min, i8_max); + return static_cast(dst); +#else + // CUDA path + uint32_t dst; + asm volatile("cvt.sat.s8.s32 %0, %1;" : "=r"(dst) : "r"(x)); + return reinterpret_cast(dst); +#endif +} + namespace vllm { template @@ -31,12 +96,36 @@ __global__ void static_scaled_int8_quant_kernel( scalar_t const* __restrict__ input, int8_t* __restrict__ out, scale_type const* scale_ptr, const int hidden_size) { int const tid = threadIdx.x; - int const token_idx = blockIdx.x; + int64_t const token_idx = blockIdx.x; + scale_type const scale = *scale_ptr; + + // Must be performed using 64-bit math to avoid integer overflow. + out += token_idx * hidden_size; + input += token_idx * hidden_size; + + for (int i = tid; i < hidden_size; i += blockDim.x) { + out[i] = float_to_int8_rn(static_cast(input[i]) / scale); + } +} + +template +__global__ void static_scaled_int8_azp_quant_kernel( + scalar_t const* __restrict__ input, int8_t* __restrict__ out, + scale_type const* scale_ptr, azp_type const* azp_ptr, + const int hidden_size) { + int const tid = threadIdx.x; + int64_t const token_idx = blockIdx.x; scale_type const scale = *scale_ptr; + azp_type const azp = *azp_ptr; + + // Must be performed using 64-bit math to avoid integer overflow. + out += token_idx * hidden_size; + input += token_idx * hidden_size; for (int i = tid; i < hidden_size; i += blockDim.x) { - out[token_idx * hidden_size + i] = float_to_int8_rn( - static_cast(input[token_idx * hidden_size + i]) / scale); + auto const val = static_cast(input[i]); + auto const quant_val = int32_to_int8(float_to_int32_rn(val / scale) + azp); + out[i] = quant_val; } } @@ -45,17 +134,24 @@ __global__ void dynamic_scaled_int8_quant_kernel( scalar_t const* __restrict__ input, int8_t* __restrict__ out, scale_type* scale, const int hidden_size) { int const tid = threadIdx.x; - int const token_idx = blockIdx.x; + int64_t const token_idx = blockIdx.x; float absmax_val = 0.0f; float const zero = 0.0f; + // Must be performed using 64-bit math to avoid integer overflow. + out += token_idx * hidden_size; + input += token_idx * hidden_size; + for (int i = tid; i < hidden_size; i += blockDim.x) { - float val = static_cast(input[token_idx * hidden_size + i]); + float val = static_cast(input[i]); val = val > zero ? val : -val; absmax_val = val > absmax_val ? val : absmax_val; } - float const block_absmax_val_maybe = blockReduceMax(absmax_val); + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStorage; + float const block_absmax_val_maybe = + BlockReduce(reduceStorage).Reduce(absmax_val, cub::Max{}, blockDim.x); __shared__ float block_absmax_val; if (tid == 0) { block_absmax_val = block_absmax_val_maybe; @@ -65,8 +161,63 @@ __global__ void dynamic_scaled_int8_quant_kernel( float const tmp_scale = 127.0f / block_absmax_val; for (int i = tid; i < hidden_size; i += blockDim.x) { - out[token_idx * hidden_size + i] = float_to_int8_rn( - static_cast(input[token_idx * hidden_size + i]) * tmp_scale); + out[i] = float_to_int8_rn(static_cast(input[i]) * tmp_scale); + } +} + +template +__global__ void dynamic_scaled_int8_azp_quant_kernel( + scalar_t const* __restrict__ input, int8_t* __restrict__ out, + scale_type* scale, azp_type* azp, const int hidden_size) { + int64_t const token_idx = blockIdx.x; + + // Must be performed using 64-bit math to avoid integer overflow. + out += token_idx * hidden_size; + input += token_idx * hidden_size; + + // Scan for the min and max value for this token + float max_val = std::numeric_limits::min(); + float min_val = std::numeric_limits::max(); + for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + auto val = static_cast(input[i]); + max_val = std::max(max_val, val); + min_val = std::min(min_val, val); + } + + // Reduce the max and min values across the block + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStorage; + max_val = BlockReduce(reduceStorage).Reduce(max_val, cub::Max{}, blockDim.x); + __syncthreads(); // Make sure min doesn't mess with max shared memory + min_val = BlockReduce(reduceStorage).Reduce(min_val, cub::Min{}, blockDim.x); + + __shared__ scale_type scale_sh; + __shared__ azp_type azp_sh; + + // Compute the scale and zero point and store them, only on the first thread + if (threadIdx.x == 0) { + float const scale_val = (max_val - min_val) / 255.0f; + // Use rounding to even (same as torch.round) + auto const azp_float = std::nearbyint(-128.0f - min_val / scale_val); + auto const azp_val = static_cast(azp_float); + + // Store the scale and azp into shared and global + scale[token_idx] = scale_sh = scale_val; + azp[token_idx] = azp_sh = azp_val; + } + + // Wait for the scale and azp to be computed + __syncthreads(); + + float const scale_val = scale_sh; + azp_type const azp_val = azp_sh; + + // Quantize the values + for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + auto const val = static_cast(input[i]); + auto const quant_val = + int32_to_int8(float_to_int32_rn(val / scale_val) + azp_val); + out[i] = quant_val; } } @@ -74,10 +225,12 @@ __global__ void dynamic_scaled_int8_quant_kernel( void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] torch::Tensor const& input, // [..., hidden_size] - torch::Tensor const& scale) { + torch::Tensor const& scale, + c10::optional const& azp) { TORCH_CHECK(input.is_contiguous()); TORCH_CHECK(out.is_contiguous()); TORCH_CHECK(scale.numel() == 1); + TORCH_CHECK(!azp || azp->numel() == 1); int const hidden_size = input.size(-1); int const num_tokens = input.numel() / hidden_size; @@ -86,19 +239,29 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( input.scalar_type(), "static_scaled_int8_quant_kernel", [&] { - vllm::static_scaled_int8_quant_kernel - <<>>(input.data_ptr(), - out.data_ptr(), - scale.data_ptr(), hidden_size); + if (!azp) { + vllm::static_scaled_int8_quant_kernel + <<>>( + input.data_ptr(), out.data_ptr(), + scale.data_ptr(), hidden_size); + } else { + vllm::static_scaled_int8_azp_quant_kernel + <<>>( + input.data_ptr(), out.data_ptr(), + scale.data_ptr(), azp->data_ptr(), + hidden_size); + } }); } void dynamic_scaled_int8_quant( torch::Tensor& out, // [..., hidden_size] torch::Tensor const& input, // [..., hidden_size] - torch::Tensor& scales) { + torch::Tensor& scales, c10::optional const& azp) { TORCH_CHECK(input.is_contiguous()); TORCH_CHECK(out.is_contiguous()); + TORCH_CHECK(scales.is_contiguous()); + TORCH_CHECK(!azp || azp->is_contiguous()); int const hidden_size = input.size(-1); int const num_tokens = input.numel() / hidden_size; @@ -107,9 +270,17 @@ void dynamic_scaled_int8_quant( const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); VLLM_DISPATCH_FLOATING_TYPES( input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] { - vllm::dynamic_scaled_int8_quant_kernel - <<>>(input.data_ptr(), - out.data_ptr(), - scales.data_ptr(), hidden_size); + if (!azp) { + vllm::dynamic_scaled_int8_quant_kernel + <<>>( + input.data_ptr(), out.data_ptr(), + scales.data_ptr(), hidden_size); + } else { + vllm::dynamic_scaled_int8_azp_quant_kernel + <<>>( + input.data_ptr(), out.data_ptr(), + scales.data_ptr(), azp->data_ptr(), + hidden_size); + } }); } diff --git a/csrc/quantization/cutlass_w8a8/Epilogues.md b/csrc/quantization/cutlass_w8a8/Epilogues.md new file mode 100644 index 0000000000000000000000000000000000000000..aae04157b10de30de93fa736da1a7782c116eca2 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/Epilogues.md @@ -0,0 +1,147 @@ +# CUTLASS Epilogues + +## Introduction +This document describes the various CUTLASS epilogues implemented for fusing de-quantization operations onto GEMMs. + +Currently, we only support symmetric quantization for weights, +and symmetric and asymmetric quantization for activations. +Both can be quantized per-tensor or per-channel (weights) / per-token (activations). + +There are 4 epilogues: +1. ScaledEpilogue: symmetric quantization for activations, no bias. +1. ScaledEpilogueBias: symmetric quantization for activations, supports bias. +1. ScaledEpilogueAzp: asymmetric per-tensor quantization for activations, supports bias. +1. ScaledEpilogueAzpPerToken: asymmetric per-token quantization for activations, supports bias. + +We do not have epilogues for asymmetric quantization of activations without bias in order to reduce final binary size. +Instead, if no bias is passed, the epilogue will use 0 as the bias. +That induces a redundant addition operation (and runtime check), but the performance impact is minor. + +## Underlying Linear Algebra + +More details available in the [Activation Quantization RFC](https://github.com/vllm-project/vllm/issues/3975). + +If $` \widehat X `$ is the quantized $` X `$, our matrices become the following + +```math +A = s_a (\widehat A - J_a z_a) +``` +```math +B = s_b \widehat B +``` +```math +D = A B + C +``` +```math +D = s_a s_b \widehat D + C +``` + +Here, D is the output of the GEMM, and C is the bias. +A is the activations and supports asymmetric quantization, +and B is the weights and only supports symmetric quantization. +$ s_a $ and $s_b$ are the scales for activations and weights, respectively. +$ z_a $ is the zero-point for activations, and $ J_a $ is the matrix of all ones with dimensions of A. +Additional epilogues would be required to support asymmetric quantization for weights. + +Expanding further, we can calculate $` \widehat D `$ as follows: + +```math +A B = s_a ( \widehat A - J_a z_a ) s_b \widehat B +``` +```math +A B = s_a s_b \left( \widehat A \widehat B - J_a z_a \widehat B \right) +``` +```math +\widehat D = \widehat A \widehat B - z_a J_a \widehat B +``` + +Note that $` \widehat A \widehat B `$ is the raw output of the GEMM, +and $` J_a \widehat B `$ is known ahead of time. +Each row of it is equal to $` \mathbf 1 \widehat B `$, which is a row-vector of column sums of $` \widehat B `$. + +## Epilogues + +### ScaledEpilogue +This epilogue computes the symmetric quantization for activations without bias, meaning $` C = 0 `$ and $` z_a = 0 `$. +The output of the GEMM is: + +```math +\widehat D = \widehat A \widehat B +``` +```math +D = s_a s_b \widehat D +``` +```math +D = s_a s_b \widehat A \widehat B +``` + +Epilogue parameters: +- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector). +- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector). + +### ScaledEpilogueBias +This epilogue computes the symmetric quantization for activations with bias, meaning $` z_a = 0 `$. +The output of the GEMM is: + +```math +\widehat D = \widehat A \widehat B +``` +```math +D = s_a s_b \widehat D + C +``` +```math +D = s_a s_b \widehat A \widehat B + C +``` + + +Epilogue parameters: +- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector). +- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector). +- `bias` is the bias, is always per-channel (row-vector). + +### ScaledEpilogueAzp +This epilogue computes the asymmetric per-tensor quantization for activations with bias. +The output of the GEMM is: + +```math +\widehat D = \widehat A \widehat B - z_a J_a \widehat B +``` +```math +D = s_a s_b \widehat D + C +``` +```math +D = s_a s_b \left( \widehat A \widehat B - z_a J_a \widehat B \right) + C +``` + +Because $` z_a `$ is a scalar, the zero-point term $` z_a J_a \widehat B `$ has every row equal to $` z_a \mathbf 1 B `$. +That is precomputed and stored in `azp_with_adj` as a row-vector. + +Epilogue parameters: +- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector). + - Generally this will be per-tensor as the zero-points are per-tensor. +- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector). +- `azp_with_adj` is the precomputed zero-point term ($` z_a J_a \widehat B `$), is per-channel (row-vector). +- `bias` is the bias, is always per-channel (row-vector). + +To use these kernels efficiently, users must precompute the `azp_with_adj` term offline and pass it to the kernel. + +### ScaledEpilogueAzpPerToken +This epilogue computes the asymmetric per-token quantization for activations with bias. + +The output of the GEMM is the same as above, but the $` z_a `$ is a column-vector. +That means the zero-point term $` z_a J_a \widehat B `$ becomes an outer product of $` z_a `$ and $` \mathbf 1 \widehat B `$. + +Epilogue parameters: +- `scale_a` is the scale for activations, can be per-tensor (scalar) or per-token (column-vector). + - Generally this will be per-token as the zero-points are per-token. +- `scale_b` is the scale for weights, can be per-tensor (scalar) or per-channel (row-vector). +- `azp_adj` is the precomputed zero-point adjustment term ($` \mathbf 1 \widehat B `$), is per-channel (row-vector). +- `azp` is the zero-point (`z_a`), is per-token (column-vector). +- `bias` is the bias, is always per-channel (row-vector). + +To use these kernels efficiently, users must precompute the `azp_adj` term offline and pass it to the kernel. + +The epilogue performs the following computation (where `Dq` is the raw quantized output of the GEMM): +``` +out = scale_a * scale_b * (Dq - azp_adj * azp) + bias +``` diff --git a/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp b/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp index c4c6b18654eedbf1971822b4355aad4a29f0ab0c..d407d66ab2aa6c60adb9bb21784f91e8105e3140 100644 --- a/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp +++ b/csrc/quantization/cutlass_w8a8/broadcast_load_epilogue_c2x.hpp @@ -207,6 +207,156 @@ struct VisitorRowOrScalarBroadcast { }; +///////////////////////////////////////////////////////////////////////////////////////////////// + +// This is a modified RowBroadcast that will broadcast 0 if ptr_row is null +template< + class ThreadMap, + class Element, + class StrideMNL +> +struct VisitorRowOrZeroBroadcast { + + // This struct has been modified to remove null_default (because it's always 0) + struct Arguments { + Element const* ptr_row = nullptr; + StrideMNL dRow = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + struct SharedStorage {}; + + // Global load type + static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits::value; + using VecType = uint_bit_t; + static int constexpr VecLength = sizeof(VecType) / sizeof(Element); + + CUTLASS_HOST_DEVICE + VisitorRowOrZeroBroadcast() { } + + CUTLASS_HOST_DEVICE + VisitorRowOrZeroBroadcast(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms) { } + + Params const* params_ptr; + + template + struct Callbacks : EmptyCallbacks { + CUTLASS_DEVICE + Callbacks( + GTensor&& tC_gRow, + RTensor&& tC_rRow, + CTensor&& tC_cRow, + ProblemShape problem_shape, + Params const* params_ptr + ): + tC_gRow(cute::forward(tC_gRow)), + tC_rRow(cute::forward(tC_rRow)), + tC_cRow(cute::forward(tC_cRow)), + n(get<1>(problem_shape)), + params_ptr(params_ptr) { } + + GTensor tC_gRow; + RTensor tC_rRow; + CTensor tC_cRow; + Params const* params_ptr; + int n; + + // This function is modified from VisitorRowBroadcast + CUTLASS_DEVICE void + begin_epilogue() { + clear(tC_rRow); + auto src_v = filter(tC_gRow); + auto coord_v = filter(tC_cRow); + auto dst_v = filter(tC_rRow); + + if (params_ptr->ptr_row != nullptr) { + // In this case we are loading from a row vector and broadcasting + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(src_v); ++i) { + bool guard = get<1>(coord_v(i)) < n; + cutlass::arch::global_load( + dst_v(i), (void const*)&src_v(i), guard); + } + } else { + // In this case we are broadcasting 0 + VecType filled_vec; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < VecLength; i++) { + reinterpret_cast(&filled_vec)[i] = Element{0}; + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(src_v); ++i) { + if (get<1>(coord_v(i)) < n) { + dst_v(i) = filled_vec; + } + } + } + } + + template + CUTLASS_DEVICE auto // returns an Array + visit(int iter_idx, int row_idx, int column_idx, int frg_idx, + Array const& frg_acc) { + Tensor rRow_frg = recast>(coalesce(tC_rRow)); + return rRow_frg(column_idx); + } + }; + + template + CUTLASS_DEVICE auto + get_callbacks( + gemm::GemmCoord threadblock_tile_offset, + int thread_idx, + ProblemShape problem_shape + ) { + Tensor mRow = make_tensor( + make_gmem_ptr(params_ptr->ptr_row), + problem_shape, + params_ptr->dRow); + + // VECTOR, FRAGMENT_COLUMN + Tensor tC_gRow = recast( + ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset) + )(_,_,_0{},_0{},_0{},_0{}); + Tensor tC_rRow = make_tensor_like(tC_gRow); + + // Generate the pred tensor + Tensor cRow = make_identity_tensor(mRow.shape()); + Tensor tC_cRow = outer_partition( + ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}), + Shape>{}, + (_0{}) + ); + + return Callbacks< + decltype(tC_gRow), decltype(tC_rRow), + decltype(tC_cRow), ProblemShape>( + cute::move(tC_gRow), + cute::move(tC_rRow), + cute::move(tC_cRow), + problem_shape, + params_ptr + ); + } + +}; + + ///////////////////////////////////////////////////////////////////////////////////////////////// // Column vector broadcast @@ -217,7 +367,7 @@ template< > struct VisitorColOrScalarBroadcast { - // This struct has been modified to have a bool indicating that ptr_col is a + // This struct has been modified to have a bool indicating that ptr_col is a // scalar that must be broadcast. struct Arguments { Element const* ptr_col = nullptr; diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu index 8d0dfee7bf23a62090d56f2fcad8d7a563565657..ee801e16573d4e1af1588fe05302404833d5d677 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu @@ -50,6 +50,25 @@ void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a, } } +void cutlass_scaled_mm_azp_sm75(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + torch::Tensor const& azp_adj, + c10::optional const& azp, + c10::optional const& bias) { + TORCH_CHECK(a_scales.dtype() == torch::kFloat32); + TORCH_CHECK(b_scales.dtype() == torch::kFloat32); + + if (azp) { + return cutlass_scaled_mm_sm75_epilogue( + out, a, b, a_scales, b_scales, azp_adj, *azp, bias); + } else { + return cutlass_scaled_mm_sm75_epilogue( + out, a, b, a_scales, b_scales, azp_adj, bias); + } +} + template