Unverified Commit 81e6345d authored by Matthew Douglas, committed by GitHub

LLM.int8() Refactoring: Part 1 (#1401)



* Start of int8 refactor: remove col32/col_ampere/col_turing transforms in new igemmlt implementation

* Fix unintended change

* New naive mm_dequant kernel for row-major; cleanup

* fix

* int8 refactor: initial sparse decomp, cleanup

* Int8 refactoring: remove separate NO_CUBLASLT build; more cleanup

* int8: inference optimizations, some cleanup

* int8: more tests passing, cleanup

* int8 - more cleanup, most tests passing

* int8: specify CUDA stream for int8 ops

* perf: reduce overhead from getting cudaStream ptr

* Mark some functions for deprecation.

* int8 sparse decomp: small perf improvement

* update setup.py

* Update bitsandbytes/autograd/_functions.py
Co-authored-by: Aarni Koskela <akx@iki.fi>

* Update bitsandbytes/functional.py
Co-authored-by: Aarni Koskela <akx@iki.fi>

* Update bitsandbytes/functional.py
Co-authored-by: Aarni Koskela <akx@iki.fi>

* Update bitsandbytes/research/autograd/_functions.py
Co-authored-by: Aarni Koskela <akx@iki.fi>

* int8 - perf improvement for sparse decomposition inference; deprecate get_tensor_stream() in favor of new private fn

* int8 cleanup

* Ignore ruff rule ISC001 (incompatible with formatter)

* add comment

* int8 more cleanup

* Update bitsandbytes/functional.py
Co-authored-by: Aarni Koskela <akx@iki.fi>

* int8: rename / deprecate old fn signatures

* Update bitsandbytes/functional.py
Co-authored-by: Aarni Koskela <akx@iki.fi>

* type annotation

* format update

* Update bitsandbytes/research/autograd/_functions.py
Co-authored-by: Aarni Koskela <akx@iki.fi>

* cleanup

* Add comment to explain division optimization

* more cleanup

* Update bitsandbytes/functional.py
Co-authored-by: Aarni Koskela <akx@iki.fi>

* Update bitsandbytes/functional.py
Co-authored-by: Aarni Koskela <akx@iki.fi>

* Update bitsandbytes/functional.py
Co-authored-by: Aarni Koskela <akx@iki.fi>

* cleanup

* Type annotations, cleanup

* remove unused kernels; improved type annotations

* small perf optimization for single-GPU systems

* small perf optimization for single-GPU systems

* update docstrings

* Improve docs and tests

* Update docstring

* Update test

* add benchmarking script

* test cleanup: add deprecated marker, move benchmarks out

* Add int8 dequant function; misc improvements

* int8 matmul fallback for inner dims not divisible by 4

* improve register usage of kInt8VectorQuant - especially for A100/H100

* disable fail-fast for package build

* maxwell compat

* ptxas verbose

* docs update

* doc update

* backward fix

* Bugfix sparse decomp

* Int8 fix for PEFT OLoRA init

* Fix test for deprecated spmm_coo

* test improvement

* doc update

* typo

* doc cleanup

* docs

* add inference benchmark script

* Add benchmarks, doc update

---------
Co-authored-by: Aarni Koskela <akx@iki.fi>
parent 7dca7004
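For orientation before the file diffs: this refactor drops the col32/col_turing/col_ampere layout transforms and exposes a plain row-major int8 pipeline. Below is a minimal sketch of that pipeline using the `int8_vectorwise_quant`, `int8_linear_matmul`, and `int8_mm_dequant` functions introduced in this PR; the shapes, dtypes, and single CUDA device are illustrative assumptions, not part of the change itself.

```python
# Minimal sketch of the new row-major int8 matmul path (illustrative shapes).
import torch
from bitsandbytes import functional as F

A = torch.randn(16, 4096, device="cuda", dtype=torch.float16)     # activations (rows x inner)
B = torch.randn(11008, 4096, device="cuda", dtype=torch.float16)  # weights (out_features x inner)

# Row-wise absmax quantization to int8: returns the quantized tensor,
# the per-row scales, and any outlier column indices (none at threshold=0).
CA, SCA, _ = F.int8_vectorwise_quant(A, threshold=0.0)
CB, SCB, _ = F.int8_vectorwise_quant(B)

# int8 x int8 -> int32 accumulation; computes A @ B.T with no layout transforms.
out_i32 = F.int8_linear_matmul(CA, CB)

# Dequantize the int32 result with the row/column scales back to fp16.
out = F.int8_mm_dequant(out_i32, SCA, SCB).to(A.dtype)
```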
...@@ -8,21 +8,21 @@ build_capability="50;52;60;61;70;75;80;86;89;90" ...@@ -8,21 +8,21 @@ build_capability="50;52;60;61;70;75;80;86;89;90"
[[ "${cuda_version}" == 11.7.* ]] && build_capability=${build_capability%??????} [[ "${cuda_version}" == 11.7.* ]] && build_capability=${build_capability%??????}
[[ "${cuda_version}" == 11.8.* ]] && build_capability=${build_capability%???} [[ "${cuda_version}" == 11.8.* ]] && build_capability=${build_capability%???}
[[ "${build_os}" = windows-* ]] && python3 -m pip install ninja [[ "${build_os}" = windows-* ]] && python3 -m pip install ninja
for NO_CUBLASLT in ON OFF; do
if [ "${build_os:0:6}" == ubuntu ]; then if [ "${build_os:0:6}" == ubuntu ]; then
image=nvidia/cuda:${cuda_version}-devel-ubuntu22.04 image=nvidia/cuda:${cuda_version}-devel-ubuntu22.04
echo "Using image $image" echo "Using image $image"
docker run --platform "linux/$build_arch" -i -w /src -v "$PWD:/src" "$image" sh -c \ docker run --platform "linux/$build_arch" -i -w /src -v "$PWD:/src" "$image" sh -c \
"apt-get update \ "apt-get update \
&& DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends cmake \ && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends cmake \
&& cmake -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY=\"${build_capability}\" -DNO_CUBLASLT=${NO_CUBLASLT} . \ && cmake -DPTXAS_VERBOSE=1 -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY=\"${build_capability}\" . \
&& cmake --build ." && cmake --build ."
else else
pip install cmake==3.28.3 pip install cmake==3.28.3
cmake -G Ninja -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY="${build_capability}" -DNO_CUBLASLT=${NO_CUBLASLT} -DCMAKE_BUILD_TYPE=Release -S . cmake -G Ninja -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY="${build_capability}" -DCMAKE_BUILD_TYPE=Release -S .
cmake --build . --config Release cmake --build . --config Release
fi fi
done
output_dir="output/${build_os}/${build_arch}" output_dir="output/${build_os}/${build_arch}"
mkdir -p "${output_dir}" mkdir -p "${output_dir}"
......
...@@ -60,6 +60,7 @@ jobs: ...@@ -60,6 +60,7 @@ jobs:
## ##
build-shared-libs-cuda: build-shared-libs-cuda:
strategy: strategy:
fail-fast: false
matrix: matrix:
os: [ubuntu-latest, windows-latest] os: [ubuntu-latest, windows-latest]
arch: [x86_64, aarch64] arch: [x86_64, aarch64]
......
...@@ -22,9 +22,11 @@ CMakeFiles/ ...@@ -22,9 +22,11 @@ CMakeFiles/
bitsandbytes.dir/ bitsandbytes.dir/
Debug/ Debug/
Release/ Release/
cmake-build-*/
# IDE local files # IDE local files
.vs/ .vs/
.idea/
# Distribution / packaging # Distribution / packaging
.Python .Python
......
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
# For MSVC: `cmake -B build . && cmake --build build --config Release` # For MSVC: `cmake -B build . && cmake --build build --config Release`
# You can also use the following options and variables # You can also use the following options and variables
# - COMPUTE_BACKEND: Set to `cpu`, `cuda`, or `mps` to select the backend # - COMPUTE_BACKEND: Set to `cpu`, `cuda`, or `mps` to select the backend
# - NO_CUBLASLT: Default OFF, will skip building/linking CUBLASLT support
# - CUDA_VERSION: The expected CUDA version, for sanity checking. The actual version # - CUDA_VERSION: The expected CUDA version, for sanity checking. The actual version
# is whatever CMake finds on your path. # is whatever CMake finds on your path.
# - COMPUTE_CAPABILITY: Which GPU Arch/Compute codes to provide to NVCC. # - COMPUTE_CAPABILITY: Which GPU Arch/Compute codes to provide to NVCC.
...@@ -47,10 +46,8 @@ if(${COMPUTE_BACKEND} STREQUAL "cuda") ...@@ -47,10 +46,8 @@ if(${COMPUTE_BACKEND} STREQUAL "cuda")
if(APPLE) if(APPLE)
message(FATAL_ERROR "CUDA is not supported on macOS" ) message(FATAL_ERROR "CUDA is not supported on macOS" )
endif() endif()
option(NO_CUBLASLT "Disable CUBLAS" OFF)
set(BUILD_CUDA ON) set(BUILD_CUDA ON)
set(BUILD_MPS OFF) set(BUILD_MPS OFF)
message(STATUS "NO_CUBLASLT := ${NO_CUBLASLT}")
elseif(${COMPUTE_BACKEND} STREQUAL "mps") elseif(${COMPUTE_BACKEND} STREQUAL "mps")
if(NOT APPLE) if(NOT APPLE)
message(FATAL_ERROR "MPS is only supported on macOS" ) message(FATAL_ERROR "MPS is only supported on macOS" )
...@@ -166,9 +163,6 @@ if(BUILD_CUDA) ...@@ -166,9 +163,6 @@ if(BUILD_CUDA)
list(APPEND SRC_FILES ${CUDA_FILES}) list(APPEND SRC_FILES ${CUDA_FILES})
string(APPEND BNB_OUTPUT_NAME "_cuda${CUDA_VERSION_SHORT}") string(APPEND BNB_OUTPUT_NAME "_cuda${CUDA_VERSION_SHORT}")
if(NO_CUBLASLT)
string(APPEND BNB_OUTPUT_NAME "_nocublaslt")
endif()
add_compile_definitions(BUILD_CUDA) add_compile_definitions(BUILD_CUDA)
elseif(BUILD_MPS) elseif(BUILD_MPS)
if(NOT APPLE) if(NOT APPLE)
...@@ -212,13 +206,7 @@ target_include_directories(bitsandbytes PUBLIC csrc include) ...@@ -212,13 +206,7 @@ target_include_directories(bitsandbytes PUBLIC csrc include)
if(BUILD_CUDA) if(BUILD_CUDA)
target_include_directories(bitsandbytes PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) target_include_directories(bitsandbytes PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
target_link_libraries(bitsandbytes PUBLIC CUDA::cudart CUDA::cublas CUDA::cusparse) target_link_libraries(bitsandbytes PUBLIC CUDA::cudart CUDA::cublas CUDA::cublasLt CUDA::cusparse)
if(NO_CUBLASLT)
target_compile_definitions(bitsandbytes PUBLIC NO_CUBLASLT)
else()
target_link_libraries(bitsandbytes PUBLIC CUDA::cublasLt)
endif()
set_target_properties(bitsandbytes set_target_properties(bitsandbytes
PROPERTIES PROPERTIES
CUDA_SEPARABLE_COMPILATION ON CUDA_SEPARABLE_COMPILATION ON
......
# Benchmarking
## Inference
End-to-end inference benchmarking can be performed using the 🤗 [`optimum-benchmark`](https://github.com/huggingface/optimum-benchmark) library.
See the example script in
[inference_benchmark.py](inference_benchmark.py).
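For reference, the script essentially reduces to the following `optimum-benchmark` calls; the model id, batch size, sequence length, and quantization settings shown here are illustrative placeholders (see the full script for the actual argument handling):

```python
# Sketch of the benchmarking flow in inference_benchmark.py (placeholder values).
from optimum_benchmark import Benchmark, BenchmarkConfig, InferenceConfig, ProcessConfig, PyTorchConfig

backend = PyTorchConfig(
    device="cuda",
    device_ids="0",
    device_map="auto",
    model="meta-llama/Llama-3.1-8B",  # placeholder model id
    torch_dtype="float16",
    quantization_scheme="bnb",
    quantization_config={"load_in_8bit": True, "llm_int8_threshold": 6.0},
)
config = BenchmarkConfig(
    name="benchmark-int8-decomp-bsz1",
    launcher=ProcessConfig(device_isolation=True, start_method="spawn"),
    scenario=InferenceConfig(latency=True, memory=True, input_shapes={"batch_size": 1, "sequence_length": 64}),
    backend=backend,
)
report = Benchmark.launch(config)
report.save_json("benchmark_int8-decomp_bsz1.json")
```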
### Results (as of v0.45.0)
Compared with v0.44.1, our benchmarking results provide the following insights:
#### LLM.int8()
* **Turing/Ampere/Ada**: The observed per-token throughput is improved by 60-85%, while latency is decreased by 40-45%.
* **H100**: With our benchmarking of Llama 3.1 70B, we observed the new LLM.int8() to consistently outperform NF4 at batch size >= 8.
#### NF4/FP4
* **Turing/Ampere/Ada**: With a batch size of 1, per-token throughput is _improved by 10-25%_ and per-token latency is _decreased by 10-20%_.
* **H100**: Across all batch sizes, per-token throughput is _improved by up to 28%_ and per-token latency is _decreased by up to 22%_.
Summaries of the benchmarking results are provided below.
#### NVIDIA T4 16GB
<details>
<summary>Qwen 2.5 3B Instruct</summary>
| | Batch Size | Mean Latency (s) <sub>v0.45.0.dev</sub> | Throughput (tokens/s) <sub>v0.45.0.dev</sub> | Mean Latency (s) <sub>v0.44.1</sub> | Latency Improvement | Throughput (tokens/s) <sub>v0.44.1</sub> | Throughput Improvement |
|----------------------|------------|------------------------------|------------------------|--------------------------|---------------------|--------------------|------------------------|
| FP16 | 1 | 0.0390 | 25.66 | 0.0390 | 1.00 | 25.66 | 1.000x |
| NF4 | 1 | 0.0608 | 16.45 | 0.0710 | 1.14 | 14.08 | 1.168x |
| NF4+DQ | 1 | 0.0736 | 13.58 | 0.0905 | 1.19 | 11.05 | 1.229x |
| INT8 | 1 | 0.0902 | 11.08 | 0.1609 | 1.44 | 6.21 | 1.784x |
| INT8+Decomp | 1 | 0.1672 | 5.98 | 0.2994 | 1.44 | 3.34 | 1.790x |
| FP16 | 8 | 0.0422 | 189.56 | 0.0422 | 1.00 | 189.56 | 1.000x |
| NF4 | 8 | 0.0960 | 83.37 | 0.1010 | 1.05 | 79.17 | 1.053x |
| NF4+DQ | 8 | 0.1042 | 76.80 | 0.1156 | 1.10 | 69.18 | 1.110x |
| INT8 | 8 | 0.0919 | 87.01 | 0.1640 | 1.44 | 48.78 | 1.784x |
| INT8+Decomp | 8 | 0.1812 | 44.15 | 0.3296 | 1.45 | 24.28 | 1.818x |
| FP16 | 32 | 0.0601 | 532.30 | 0.0601 | 1.00 | 532.30 | 1.000x |
| NF4 | 32 | 0.1150 | 278.32 | 0.1182 | 1.03 | 270.71 | 1.028x |
| NF4+DQ | 32 | 0.1215 | 263.36 | 0.1297 | 1.06 | 246.76 | 1.067x |
| INT8 | 32 | 0.0943 | 339.21 | 0.1640 | 1.42 | 195.14 | 1.738x |
| INT8+Decomp | 32 | 0.1912 | 167.37 | 0.3413 | 1.44 | 93.75 | 1.785x |
</details>
#### NVIDIA RTX 4090 24GB
<details>
<summary>Llama 3.1 8B</summary>
| | Batch Size | Mean Latency (s) <sub>v0.45.0.dev</sub> | Throughput (tokens/s) <sub>v0.45.0.dev</sub> | Mean Latency (s) <sub>v0.44.1</sub> | Latency Improvement | Throughput (tokens/s) <sub>v0.44.1</sub> | Throughput Improvement |
|----------------------|------------|------------------------------|------------------------|--------------------------|---------------------|--------------------|------------------------|
| BF16 | 1 | 0.0211 | 47.46 | 0.0211 | 1.00 | 47.46 | 1.000x |
| NF4 | 1 | 0.0148 | 67.71 | 0.0164 | 1.10 | 61.08 | 1.109x |
| NF4+DQ | 1 | 0.0175 | 57.08 | 0.0208 | 1.16 | 48.15 | 1.185x |
| INT8 | 1 | 0.0220 | 45.39 | 0.0395 | 1.44 | 25.32 | 1.793x |
| INT8+Decomp | 1 | 0.0449 | 22.26 | 0.0743 | 1.40 | 13.45 | 1.655x |
| BF16 | 8 | 0.0239 | 334.64 | 0.0239 | 1.00 | 334.64 | 1.000x |
| NF4 | 8 | 0.0425 | 188.08 | 0.0422 | 0.99 | 189.50 | 0.993x |
| NF4+DQ | 8 | 0.0443 | 180.68 | 0.0437 | 0.99 | 183.02 | 0.987x |
| INT8 | 8 | 0.0221 | 361.61 | 0.0389 | 1.43 | 205.82 | 1.757x |
| INT8+Decomp | 8 | 0.0478 | 164.55 | 0.0777 | 1.38 | 103.01 | 1.597x |
| BF16 | 32 | 0.0304 | 1054.35 | 0.0304 | 1.00 | 1054.35 | 1.000x |
| NF4 | 32 | 0.0461 | 694.60 | 0.0466 | 1.01 | 686.90 | 1.011x |
| NF4+DQ | 32 | 0.0471 | 678.73 | 0.0480 | 1.02 | 666.33 | 1.019x |
| INT8 | 32 | 0.0230 | 1390.54 | 0.0390 | 1.41 | 819.99 | 1.696x |
| INT8+Decomp | 32 | 0.0512 | 624.94 | 0.0835 | 1.39 | 383.18 | 1.631x |
</details>
<details>
<summary>Qwen 2.5 14B Instruct</summary>
| | Batch Size | Mean Latency (s) <sub>v0.45.0.dev</sub> | Throughput (tokens/s) <sub>v0.45.0.dev</sub> | Mean Latency (s) <sub>v0.44.1</sub> | Latency Improvement | Throughput (tokens/s) <sub>v0.44.1</sub> | Throughput Improvement |
|----------------------|------------|------------------------------|------------------------|--------------------------|---------------------|--------------------|------------------------|
| NF4 | 1 | 0.0214 | 46.74 | 0.0256 | 1.16 | 39.10 | 1.195x |
| NF4+DQ | 1 | 0.0256 | 39.03 | 0.0318 | 1.19 | 31.46 | 1.241x |
| INT8 | 1 | 0.0326 | 30.68 | 0.0596 | 1.45 | 16.79 | 1.827x |
| INT8+Decomp | 1 | 0.0648 | 15.44 | 0.1105 | 1.41 | 9.05 | 1.706x |
| NF4 | 8 | 0.0696 | 114.95 | 0.0697 | 1.00 | 114.78 | 1.001x |
| NF4+DQ | 8 | 0.0719 | 111.29 | 0.0723 | 1.01 | 110.70 | 1.005x |
| INT8 | 8 | 0.0325 | 246.22 | 0.0596 | 1.45 | 134.21 | 1.835x |
| INT8+Decomp | 8 | 0.0721 | 110.95 | 0.1201 | 1.40 | 66.62 | 1.665x |
</details>
#### NVIDIA H100 80GB SXM
<details>
<summary>Llama 3.1 8B</summary>
| | Batch Size | Mean Latency (s) <sub>v0.45.0.dev</sub> | Throughput (tokens/s) <sub>v0.45.0.dev</sub> | Mean Latency (s) <sub>v0.44.1</sub> | Latency Improvement | Throughput (tokens/s) <sub>v0.44.1</sub> | Throughput Improvement |
|----------------------|------------|------------------------------|------------------------|--------------------------|---------------------|--------------------|------------------------|
| BF16 | 1 | 0.0244 | 40.99 | 0.0244 | 1.00 | 40.99 | 1.000x |
| NF4 | 1 | 0.0331 | 30.14 | 0.0391 | 1.15 | 25.60 | 1.177x |
| NF4+DQ | 1 | 0.0411 | 24.34 | 0.0528 | 1.22 | 18.92 | 1.286x |
| INT8 | 1 | 0.0522 | 19.17 | N/A | N/A | N/A | N/A |
| INT8+Decomp | 1 | 0.0817 | 12.24 | N/A | N/A | N/A | N/A |
| BF16 | 8 | 0.0255 | 313.90 | 0.0255 | 1.00 | 313.90 | 1.000x |
| NF4 | 8 | 0.0476 | 168.05 | 0.0551 | 1.14 | 145.13 | 1.158x |
| NF4+DQ | 8 | 0.0566 | 141.27 | 0.0663 | 1.15 | 120.67 | 1.171x |
| INT8 | 8 | 0.0515 | 155.44 | N/A | N/A | N/A | N/A |
| INT8+Decomp | 8 | 0.0853 | 93.79 | N/A | N/A | N/A | N/A |
| BF16 | 32 | 0.0261 | 1227.96 | 0.0261 | 1.00 | 1227.96 | 1.000x |
| NF4 | 32 | 0.0486 | 658.65 | 0.0546 | 1.11 | 585.91 | 1.124x |
| NF4+DQ | 32 | 0.0577 | 555.06 | 0.0665 | 1.13 | 481.04 | 1.154x |
| INT8 | 32 | 0.0545 | 586.26 | N/A | N/A | N/A | N/A |
| INT8+Decomp | 32 | 0.0864 | 370.51 | N/A | N/A | N/A | N/A |
</details>
<details>
<summary>Qwen 2.5 32B Instruct</summary>
| | Batch Size | Mean Latency (s) <sub>v0.45.0.dev</sub> | Throughput (tokens/s) <sub>v0.45.0.dev</sub> |
|-------------|------------|-----------------------------------------|-----------------------------------|
| BF16 | 1 | 0.0508 | 19.67 |
| NF4 | 1 | 0.0707 | 14.14 |
| NF4+DQ | 1 | 0.0860 | 11.63 |
| INT8 | 1 | 0.1031 | 9.70 |
| INT8+Decomp | 1 | 0.1820 | 5.49 |
| BF16 | 8 | 0.0525 | 152.50 |
| NF4 | 8 | 0.1154 | 69.35 |
| NF4+DQ | 8 | 0.1209 | 66.19 |
| INT8 | 8 | 0.1078 | 74.24 |
| INT8+Decomp | 8 | 0.1958 | 40.87 |
| BF16 | 32 | 0.0547 | 584.54 |
| NF4 | 32 | 0.1246 | 256.84 |
| NF4+DQ | 32 | 0.1298 | 246.47 |
| INT8 | 32 | 0.1056 | 302.96 |
| INT8+Decomp | 32 | 0.2027 | 157.83 |
</details>
<details>
<summary>Llama 3.1 70B</summary>
| | Batch Size | Mean Latency (s) <sub>v0.45.0.dev</sub> | Throughput (tokens/s) <sub>v0.45.0.dev</sub> |
|-------------|------------|-----------------------------------------|-----------------------------------|
| NF4 | 1 | 0.0833 | 12.00 |
| NF4+DQ | 1 | 0.1052 | 9.50 |
| INT8 | 1 | 0.1294 | 7.73 |
| INT8+Decomp | 1 | 0.1985 | 5.04 |
| NF4 | 8 | 0.2348 | 34.07 |
| NF4+DQ | 8 | 0.2423 | 33.01 |
| INT8 | 8 | 0.1313 | 60.94 |
| INT8+Decomp | 8 | 0.2052 | 38.99 |
| NF4 | 32 | 0.2491 | 128.46 |
| NF4+DQ | 32 | 0.2580 | 124.04 |
| INT8 | 32 | 0.1314 | 243.45 |
| INT8+Decomp | 32 | 0.2189 | 146.19 |
</details>
#### Software Configuration
We focus on the default PyTorch CUDA backend in 🤗 [`optimum-benchmark`](https://github.com/huggingface/optimum-benchmark). We used commit [`6e6b1036`](https://github.com/huggingface/optimum-benchmark/commit/6e6b10363f3ac65926881f2c6a6113b6cefc06cd).
For all hardware configurations, we used the following dependencies:
* `transformers==4.46.3`
* `accelerate==1.1.1`
* `tokenizers==0.20.3`
* `torch==2.5.1`
* `bitsandbytes==0.44.1`
* `bitsandbytes==0.45.0.dev`
For the RTX 4090 setting we used the CUDA 12.4 build of PyTorch; for the other settings we used the CUDA 12.1 build.
"""
Inference benchmarking tool.
Requirements:
transformers
accelerate
bitsandbytes
optimum-benchmark
Usage: python inference_benchmark.py model_id
options:
-h, --help show this help message and exit
--configs {bf16,fp16,nf4,nf4-dq,int8,int8-decomp} [{bf16,fp16,nf4,nf4-dq,int8,int8-decomp} ...]
--bf16
--fp16
--nf4
--nf4-dq
--int8
--int8-decomp
--batches BATCHES [BATCHES ...]
--input-length INPUT_LENGTH
--out-dir OUT_DIR
"""
import argparse
from pathlib import Path
from optimum_benchmark import Benchmark, BenchmarkConfig, InferenceConfig, ProcessConfig, PyTorchConfig
from optimum_benchmark.logging_utils import setup_logging
import torch
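# bfloat16 compute requires a GPU with compute capability 8.0 (Ampere) or newer.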
BFLOAT16_SUPPORT = torch.cuda.get_device_capability()[0] >= 8
WEIGHTS_CONFIGS = {
"fp16": {"torch_dtype": "float16", "quantization_scheme": None, "quantization_config": {}},
"bf16": {"torch_dtype": "bfloat16", "quantization_scheme": None, "quantization_config": {}},
"nf4": {
"torch_dtype": "bfloat16" if BFLOAT16_SUPPORT else "float16",
"quantization_scheme": "bnb",
"quantization_config": {
"load_in_4bit": True,
"bnb_4bit_quant_type": "nf4",
"bnb_4bit_use_double_quant": False,
"bnb_4bit_compute_dtype": torch.bfloat16 if BFLOAT16_SUPPORT else "float16",
},
},
"nf4-dq": {
"torch_dtype": "bfloat16" if BFLOAT16_SUPPORT else "float16",
"quantization_scheme": "bnb",
"quantization_config": {
"load_in_4bit": True,
"bnb_4bit_quant_type": "nf4",
"bnb_4bit_use_double_quant": True,
"bnb_4bit_compute_dtype": torch.bfloat16 if BFLOAT16_SUPPORT else "float16",
},
},
"int8-decomp": {
"torch_dtype": "float16",
"quantization_scheme": "bnb",
"quantization_config": {
"load_in_8bit": True,
"llm_int8_threshold": 6.0,
},
},
"int8": {
"torch_dtype": "float16",
"quantization_scheme": "bnb",
"quantization_config": {
"load_in_8bit": True,
"llm_int8_threshold": 0.0,
},
},
}
if __name__ == "__main__":
setup_logging(level="INFO")
parser = argparse.ArgumentParser(description="bitsandbytes inference benchmark tool")
parser.add_argument("model_id", type=str, help="The model checkpoint to use.")
parser.add_argument(
"--configs",
nargs="+",
choices=["bf16", "fp16", "nf4", "nf4-dq", "int8", "int8-decomp"],
default=["nf4", "int8", "int8-decomp"],
)
parser.add_argument("--bf16", dest="configs", action="append_const", const="bf16")
parser.add_argument("--fp16", dest="configs", action="append_const", const="fp16")
parser.add_argument("--nf4", dest="configs", action="append_const", const="nf4")
parser.add_argument("--nf4-dq", dest="configs", action="append_const", const="nf4-dq")
parser.add_argument("--int8", dest="configs", action="append_const", const="int8")
parser.add_argument("--int8-decomp", dest="configs", action="append_const", const="int8-decomp")
parser.add_argument("--batches", nargs="+", type=int, default=[1, 8, 16, 32])
parser.add_argument("--input-length", type=int, default=64)
parser.add_argument("--out-dir", type=str, default="reports")
args = parser.parse_args()
out_dir = Path(args.out_dir)
out_dir.mkdir(parents=True, exist_ok=True)
for batch_size in args.batches:
print(f"Benchmarking batch size: {batch_size}")
for config in args.configs:
launcher_config = ProcessConfig(device_isolation=True, start_method="spawn")
scenario_config = InferenceConfig(
latency=True,
memory=True,
input_shapes={"batch_size": batch_size, "sequence_length": args.input_length},
)
backend_config = PyTorchConfig(
device="cuda",
device_ids="0",
device_map="auto",
no_weights=False,
model=args.model_id,
**WEIGHTS_CONFIGS[config],
)
benchmark_config = BenchmarkConfig(
name=f"benchmark-{config}-bsz{batch_size}",
scenario=scenario_config,
launcher=launcher_config,
backend=backend_config,
)
out_path = out_dir / f"benchmark_{config}_bsz{batch_size}.json"
benchmark_report = Benchmark.launch(benchmark_config)
benchmark_report.log()
benchmark_report.save_json(out_path)
"""
Basic benchmark for text generation.
Usage: python benchmarking/int8/int8_benchmark.py
"""
import time
import torch
from torch.profiler import ProfilerActivity, profile
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
MAX_NEW_TOKENS = 128
model_name = "meta-llama/Llama-3.1-8B"
text = "Below is a question. I need an answer.\n\nExplain machine learning: "
tokenizer = AutoTokenizer.from_pretrained(model_name)
input_ids = tokenizer([text] * 8, return_tensors="pt").input_ids.to(0)
model = AutoModelForCausalLM.from_pretrained(
model_name,
device_map="auto",
quantization_config=BitsAndBytesConfig(
load_in_8bit=True,
llm_int8_threshold=6.0,
),
attn_implementation="sdpa",
torch_dtype=torch.float16,
)
print(model)
# warmup
print("Warmup...")
for i in range(3):
generated_ids = model.generate(input_ids, max_new_tokens=MAX_NEW_TOKENS)
print("Profiler starting...")
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
with_modules=True,
with_stack=True,
) as prof:
model.generate(input_ids, max_new_tokens=1)
print(
prof.key_averages().table(
sort_by="cpu_time_total",
max_name_column_width=50,
top_level_events_only=True,
row_limit=50,
)
)
torch.cuda.synchronize()
print("Generating...")
num = 0
time_1 = time.time()
for i in range(5):
generated_ids = model.generate(input_ids, max_new_tokens=MAX_NEW_TOKENS)
num += len(generated_ids[0])
print("=" * 40)
print(f"Example:\n{tokenizer.decode(generated_ids[0])}")
print("=" * 40)
print(f"Speed: {num/(time.time() - time_1)}token/s")
"""
Extracted from tests/test_functional.py
Note: This feature is currently unused! It is kept here for archival purposes.
Usage: pytest benchmarking/int8/row_scale_benchmark.py
"""
import time
import pytest
import torch
from bitsandbytes import functional as F
k = 20
torch.set_printoptions(precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000)
@pytest.mark.parametrize(
("dim1", "dim4", "inner"),
[
pytest.param(1024, 12288 * 4, 12288, id="1024, 12288*4, 12288"),
pytest.param(2048, 4096 * 4, 4096, id="2048, 4096*4, 4096"),
],
)
@pytest.mark.skip("Row scale has some bugs for ampere")
@pytest.mark.benchmark
def test_row_scale_bench(dim1, dim4, inner):
formatB = F.get_special_format_str()
err1, err2, err3 = [], [], []
relerr1, relerr2 = [], []
scale = 1
A = torch.randn(dim1, inner, device="cuda").half()
B = torch.randn(dim4, inner, device="cuda").half()
torch.nn.init.xavier_uniform_(B)
# warmup
for i in range(k):
C1 = torch.matmul(A, B.t())
torch.cuda.synchronize()
t0 = time.time()
for i in range(k):
C1 = torch.matmul(A, B.t())
torch.cuda.synchronize()
print("16", time.time() - t0)
C1a, C1b, stats1a, stats1b, coo_tensor = F.int8_double_quant(A)
CB, absmaxB = F.vectorwise_quant(B, quant_type="linear")
A2, SA = F.nvidia_transform(C1a, "col32")
B2, SB = F.nvidia_transform(CB, formatB)
A1, maxA = F.vectorwise_quant(A, dim=1)
c = 10.0 * inner * scale
row_scale = maxA / c
torch.cuda.synchronize()
t0 = time.time()
for i in range(k):
outC32 = F.int8_linear_matmul(A2, B2, dtype=torch.int8, row_scale=row_scale)
torch.cuda.synchronize()
print("row-wise", time.time() - t0)
C2a, C2b, stats2a, stats2b, coo_tensor = F.int8_double_quant(B)
B2, SB = F.nvidia_transform(C2a, formatB)
torch.cuda.synchronize()
t0 = time.time()
for i in range(k):
outC32 = F.int8_linear_matmul(A2, B2)
torch.cuda.synchronize()
print("vector-wise", time.time() - t0)
"""
Extracted from tests/test_functional.py
Usage: pytest benchmarking/int8/training_benchmark.py
"""
import time
import pytest
import torch
from bitsandbytes import functional as F
k = 20
torch.set_printoptions(precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000)
@pytest.mark.parametrize(
("batch", "seq", "model", "hidden"),
[
pytest.param(2, 512, 4 * 1024, 3 * 4 * 1024, id="batch=2, seq=512, model=4k, hidden=12k"),
pytest.param(2, 512, 5120, 3 * 5120, id="batch=2, seq=512, model=5k, hidden=15k"),
pytest.param(2, 512, 12 * 1024, 4 * 12 * 1024, id="batch=2, seq=512, model=12k, hidden=48k"),
],
)
@pytest.mark.benchmark
def test_bench_8bit_training(batch, seq, model, hidden):
formatB = F.get_special_format_str()
A = torch.randn(batch, seq, model, device="cuda").half()
grad = torch.randn(batch, seq, model, device="cuda").half()
w1 = torch.randint(-128, 127, size=(hidden, model), device="cuda").half()
w2 = torch.randint(-128, 127, size=(model, hidden), device="cuda").half()
print("")
# torch.cuda.synchronize()
## warmup
# for i in range(100):
# torch.matmul(A, w1.t())
# torch.cuda.synchronize()
dtype = torch.int8
A = A.view(-1, A.shape[-1]).contiguous()
grad = grad.view(-1, grad.shape[-1]).contiguous()
torch.cuda.synchronize()
t0 = time.time()
for i in range(k):
out1 = torch.matmul(A, w1.t()) # fc1
# out2 = torch.matmul(out1, w2.t())# fc2
# d1 = torch.matmul(grad, w2) # delta1
# d2 = torch.matmul(d1, w1) # delta2
# grad1 = torch.einsum('bo,bh->oh', out1, grad) # grad w2
# grad2 = torch.einsum('bh,bo->ho', A, d2) # grad w1
torch.cuda.synchronize()
t16 = time.time() - t0
print(t16)
# torch.cuda.empty_cache()
# Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
# Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)
# CTw1, Sw1 = F.transform2(Cw1, formatB)
# CTw2, Sw2 = F.transform2(Cw2, formatB)
# CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
# CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
# CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
# C32A, SA = F.transform2(CA, 'col32')
## fc1
# out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype)
##out1 = F.mm_dequant(out1_32, Sout1_32, statsAt, statsw1t)
## fc2
# Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1)
# C32out1, Sout1 = F.transform2(Cout1, 'col32')
# out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype)
##out2 = F.mm_dequant(out2_32, Sout2_32, statsout1t, statsw2t)
## delta1
# Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad)
# C32grad, Sgrad = F.transform2(Cgrad, 'col32')
##d1_32, Sd1_32 = F.igemmlt(C32grad, CTw2t, Sgrad, Sw2t, dtype=dtype)
##d1 = F.mm_dequant(d1_32, Sd1_32, statsgradt, statsw2)
## delta2
# Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1)
# C32d1, Sd1 = F.transform2(Cd1, 'col32')
##d2_32, Sd2_32 = F.igemmlt(C32d1, CTw1t, Sd1, Sw1t, dtype=dtype)
##d2 = F.mm_dequant(d2_32, Sd2_32, statsd1t, statsw1)
## grad1
# C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True)
# CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True)
##grad1_32, Sgrad1_32 = F.igemmlt(C32out1t, CTgradt, Sout1t, Sgradt, dtype=dtype)
##grad1 = F.mm_dequant(grad1_32, Sgrad1_32, statsout1, statsgrad)
## grad2
# C32At, SAt = F.transform2(CAt, 'col32', transpose=True)
# CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True)
##grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype)
##grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsA, statsd1)
# Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)
# Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
# Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)
# CTw1, Sw1 = F.transform2(Cw1, formatB)
# CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
# CTw2, Sw2 = F.transform2(Cw2, formatB)
# CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
# torch.cuda.synchronize()
# t0 = time.time()
# for i in range(k):
# #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
# #CTw1, Sw1 = F.transform2(Cw1, formatB)
# #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
# #CTw1, Sw1 = F.transform2(Cw1, formatB)
# #CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=3.5)
# CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
# #CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
# #CTw2, Sw2 = F.transform2(Cw2, formatB)
# #CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
# C32A, SA = F.transform2(CA, 'col32')
# # fc1
# out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype)
# #out1dn = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)
# #print(coo_tensor.nnz)
# #out1sp = F.spmm_coo(coo_tensor, w1.t())
# #print(w1.t().shape)
# #out1 = out1dn + out1sp
# # fc2
# Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1)
# C32out1, Sout1 = F.transform2(Cout1, 'col32')
# out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype)
# #out2 = F.mm_dequant(out2_32, Sout2_32, statsout1, statsw2)
# # delta1
# Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad)
# C32grad, Sgrad = F.transform2(Cgrad, 'col32')
# d1_32, Sd1_32 = F.igemmlt(C32grad, CTw2t, Sgrad, Sw2t, dtype=dtype)
# #d1 = F.mm_dequant(d1_32, Sd1_32, statsgrad, statsw2t)
# # delta2
# Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1)
# C32d1, Sd1 = F.transform2(Cd1, 'col32')
# d2_32, Sd2_32 = F.igemmlt(C32d1, CTw1t, Sd1, Sw1t, dtype=dtype)
# #d2 = F.mm_dequant(d2_32, Sd2_32, statsd1, statsw1t)
# # grad1
# #C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True)
# #CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True)
# #grad1_32, Sgrad1_32 = F.igemmlt(C32out1t, CTgradt, Sout1t, Sgradt, dtype=dtype)
# #grad1 = F.mm_dequant(grad1_32, Sgrad1_32, statsout1t, statsgradt)
# ## grad2
# #C32At, SAt = F.transform2(CAt, 'col32', transpose=True)
# #CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True)
# #grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype)
# #grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsAt, statsd1t)
# torch.cuda.synchronize()
# t8 = time.time() - t0
# print(t8)
"""
Extracted from tests/test_functional.py
Usage: pytest benchmarking/matmul_benchmark.py
"""
import time
import pytest
import torch
import bitsandbytes as bnb
from bitsandbytes import functional as F
k = 20
torch.set_printoptions(precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000)
@pytest.mark.parametrize(
("batch", "seq", "model", "hidden"),
[
# pytest.param(1, 128, 6656, 4 * 6656, id="batch=1, seq=128, model=6656, hidden=26k"),
pytest.param(1, 1, 3584, 512, id="batch=1, seq=128, model=3584, hidden=19k"),
# pytest.param(4, 128, 6656, 4 * 6656, id="batch=4, seq=128, model=6656, hidden=26k"),
# pytest.param(16, 256, 6656, 4 * 6656, id="batch=16, seq=256, model=6656, hidden=26k")
],
)
@pytest.mark.benchmark
def test_bench_matmul(batch, seq, model, hidden):
iters = 1000
formatB = F.get_special_format_str()
A = torch.randn(batch, seq, model, device="cuda").half()
B = torch.empty(hidden, model, dtype=torch.float16, device="cuda")
torch.nn.init.xavier_uniform_(B)
B_fp4, state = F.quantize_fp4(B)
B_fp4_c, state_c = F.quantize_fp4(B, compress_statistics=True)
B_nf4, state_nf4 = F.quantize_nf4(B)
B_nf4_c, state_nf4_c = F.quantize_nf4(B, compress_statistics=True)
linear8bit = bnb.nn.Linear8bitLt(model, hidden, False, False).cuda().half()
linear8bit.eval()
outliers = torch.randint(0, model, size=(5,)).cuda()
A[:, :, outliers] = 8.0
linearMixedBit = bnb.nn.Linear8bitLt(model, hidden, False, False, threshold=6.0).cuda().half()
# linearMixedBit.eval()
linear8bit_train = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
linear8bit_train_thresh = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()
bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)
# warmup
for i in range(iters):
torch.matmul(A, B.t())
torch.cuda.synchronize()
print("")
torch.cuda.synchronize()
t0 = time.time()
for i in range(iters):
torch.matmul(A, B.t())
torch.cuda.synchronize()
print(
f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s",
)
# torch.cuda.synchronize()
# t0 = time.time()
# for i in range(iters):
# bnb.matmul_4bit(A, B_fp4.t(), quant_state=state)
# torch.cuda.synchronize()
# print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
# torch.cuda.synchronize()
# t0 = time.time()
# for i in range(iters):
# bnb.matmul_4bit(A, B_fp4.t(), quant_state=state_c)
# torch.cuda.synchronize()
# print( f"bnb fp4 + compressed stats: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
torch.cuda.synchronize()
t0 = time.time()
for i in range(iters):
bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)
torch.cuda.synchronize()
print(f"bnb nf4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
torch.cuda.synchronize()
t0 = time.time()
for i in range(iters):
bnb.matmul_4bit(A, B_nf4_c.t(), quant_state=state_nf4_c)
torch.cuda.synchronize()
print(f"bnb nf4+DQ: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
torch.cuda.synchronize()
t0 = time.time()
for i in range(iters):
bnb.matmul(A, B)
torch.cuda.synchronize()
print(
f"B -> CB (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
)
torch.cuda.synchronize()
t0 = time.time()
for i in range(iters):
bnb.matmul(A, B, threshold=6.0)
torch.cuda.synchronize()
print(
f"B -> CB + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
)
CA, SCA, _ = F.int8_vectorwise_quant(A, threshold=0.0)
CB, SCB, _ = F.int8_vectorwise_quant(B)
torch.cuda.synchronize()
t0 = time.time()
for i in range(iters):
# CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0)
out32 = F.int8_linear_matmul(CA, CB)
torch.cuda.synchronize()
print(
f"no overhead int8 [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
)
# C32A, SA = F.transform(CA, "col32")
# CxB, SB = F.transform(CB, to_order=formatB)
# torch.cuda.synchronize()
# t0 = time.time()
# for i in range(iters):
# out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
# torch.cuda.synchronize()
# print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
# CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0)
# C32A, SA = F.transform(CA, "col32")
# CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B)
# CxB, SB = F.transform(CB, to_order=formatB)
# torch.cuda.synchronize()
# t0 = time.time()
# for i in range(iters):
# out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
# torch.cuda.synchronize()
# print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
# BA, statsB = F.vectorwise_quant(B, dim=1)
# CxB, SB = F.nvidia_transform(CB, to_order=formatB)
# torch.cuda.synchronize()
# t0 = time.time()
# for i in range(iters):
# A2 = A.view(-1, A.shape[-1]).contiguous()
# CA, statsA = F.vectorwise_quant(A2, dim=1)
# C32A, SA = F.nvidia_transform(CA, "col32")
# out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
# Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
# F.vectorwise_mm_dequant(Cout, statsA, statsB.t())
# torch.cuda.synchronize()
# print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
# BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear")
# CxB, SB = F.nvidia_transform(CB, to_order=formatB)
# torch.cuda.synchronize()
# t0 = time.time()
# for i in range(iters):
# A2 = A.view(-1, A.shape[-1]).contiguous()
# CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear")
# C32A, SA = F.nvidia_transform(CA, "col32")
# out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
# Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
# out = Cout * statsB * statsA * (1.0 / (127 * 127))
# torch.cuda.synchronize()
# print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
linear8bit(A)
torch.cuda.synchronize()
t0 = time.time()
for i in range(iters):
linear8bit(A)
torch.cuda.synchronize()
print(
f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
)
linearMixedBit(A)
torch.cuda.synchronize()
t0 = time.time()
for i in range(iters):
linearMixedBit(A)
torch.cuda.synchronize()
print(
f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
)
# linear8bit_train(A)
# torch.cuda.synchronize()
# t0 = time.time()
# for i in range(iters):
# linear8bit_train(A)
# torch.cuda.synchronize()
# print( f"bnb linear8bitlt (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
# linear8bit_train_thresh(A)
# torch.cuda.synchronize()
# t0 = time.time()
# for i in range(iters):
# linear8bit_train(A)
# torch.cuda.synchronize()
# print( f"bnb linear8bitlt with threshold (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
from dataclasses import dataclass from dataclasses import dataclass
from functools import reduce # Required in Python 3 from math import prod
import operator
from typing import Callable, Optional, Tuple from typing import Callable, Optional, Tuple
import warnings import warnings
from warnings import warn from warnings import warn
import torch import torch
from typing_extensions import deprecated
import bitsandbytes.functional as F import bitsandbytes.functional as F
# math.prod not compatible with python < 3.8
def prod(iterable):
return reduce(operator.mul, iterable, 1)
# The inverse transformation for the colTuring and colAmpere format were contributed by Alex Borzunov: # The inverse transformation for the colTuring and colAmpere format were contributed by Alex Borzunov:
# https://github.com/bigscience-workshop/petals/blob/main/src/petals/utils/linear8bitlt_patch.py # https://github.com/bigscience-workshop/petals/blob/main/src/petals/utils/linear8bitlt_patch.py
...@@ -104,6 +98,10 @@ def undo_layout(permuted_tensor: torch.Tensor, tile_indices: torch.LongTensor) - ...@@ -104,6 +98,10 @@ def undo_layout(permuted_tensor: torch.Tensor, tile_indices: torch.LongTensor) -
return outputs.reshape(rows, cols).contiguous() return outputs.reshape(rows, cols).contiguous()
@deprecated(
"MatMul8bit is deprecated and will be removed in a future release. Please use MatMul8bitLt instead.",
category=FutureWarning,
)
class MatMul8bit(torch.autograd.Function): class MatMul8bit(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, A, B, out=None, quant_type="vector", precision=None): def forward(ctx, A, B, out=None, quant_type="vector", precision=None):
...@@ -215,6 +213,7 @@ bmm_cublas = MatMul8bit.apply ...@@ -215,6 +213,7 @@ bmm_cublas = MatMul8bit.apply
matmul_cublas = MatMul8bit.apply matmul_cublas = MatMul8bit.apply
@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def supports_igemmlt(device: torch.device) -> bool: def supports_igemmlt(device: torch.device) -> bool:
"""check if this device supports the optimized int8 kernel""" """check if this device supports the optimized int8 kernel"""
if torch.cuda.get_device_capability(device=device) < (7, 5): if torch.cuda.get_device_capability(device=device) < (7, 5):
...@@ -226,6 +225,7 @@ def supports_igemmlt(device: torch.device) -> bool: ...@@ -226,6 +225,7 @@ def supports_igemmlt(device: torch.device) -> bool:
return True return True
@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def _get_tile_size(format): def _get_tile_size(format):
assert format in ( assert format in (
"col_turing", "col_turing",
...@@ -234,6 +234,7 @@ def _get_tile_size(format): ...@@ -234,6 +234,7 @@ def _get_tile_size(format):
return (8, 32) if format == "col_turing" else (32, 32) return (8, 32) if format == "col_turing" else (32, 32)
@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def get_tile_inds(format, device): def get_tile_inds(format, device):
transform = lambda x: F.transform(x.to(device), from_order="row", to_order=format)[0].to(x.device) transform = lambda x: F.transform(x.to(device), from_order="row", to_order=format)[0].to(x.device)
with torch.no_grad(): with torch.no_grad():
...@@ -243,27 +244,28 @@ def get_tile_inds(format, device): ...@@ -243,27 +244,28 @@ def get_tile_inds(format, device):
@dataclass @dataclass
class MatmulLtState: class MatmulLtState:
_tile_indices: Optional[torch.Tensor] = None _tile_indices: Optional[torch.Tensor] = None
force_no_igemmlt: bool = False force_no_igemmlt: bool = False
CB = None
CxB = None
SB = None
SCB = None
CxBt = None CB: Optional[torch.Tensor] = None
SBt = None CxB: Optional[torch.Tensor] = None # TODO: Deprecate/remove
CBt = None SB: Optional[torch.Tensor] = None
SCB: Optional[torch.Tensor] = None
CxBt: Optional[torch.Tensor] = None # TODO: Deprecate/remove
SBt: Optional[torch.Tensor] = None
CBt: Optional[torch.Tensor] = None
subB = None subB: Optional[torch.Tensor] = None
outlier_pool = None outlier_pool: Optional[GlobalOutlierPooler] = None
has_accumulated_gradients = False has_accumulated_gradients = False
threshold = 0.0 threshold = 0.0
idx = None idx: Optional[torch.Tensor] = None
is_training = True is_training = True
has_fp16_weights = True has_fp16_weights = True
memory_efficient_backward = False
use_pool = False use_pool = False
formatB = F.get_special_format_str() formatB = "row" # TODO: Deprecate/remove
def reset_grads(self): def reset_grads(self):
self.CB = None self.CB = None
...@@ -283,12 +285,17 @@ class MatmulLtState: ...@@ -283,12 +285,17 @@ class MatmulLtState:
class MatMul8bitLt(torch.autograd.Function): class MatMul8bitLt(torch.autograd.Function):
# forward is the same, but we added the fallback for pre-turing GPUs
# backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
@staticmethod @staticmethod
def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState): def forward(
using_igemmlt = supports_igemmlt(A.device) and not state.force_no_igemmlt ctx: torch.autograd.function.FunctionCtx,
A: torch.Tensor,
B: torch.Tensor,
out: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
state: Optional[MatmulLtState] = None,
):
state = state or MatmulLtState()
# default of pytorch behavior if inputs are empty # default of pytorch behavior if inputs are empty
ctx.is_empty = False ctx.is_empty = False
if prod(A.shape) == 0: if prod(A.shape) == 0:
...@@ -301,123 +308,80 @@ class MatMul8bitLt(torch.autograd.Function): ...@@ -301,123 +308,80 @@ class MatMul8bitLt(torch.autograd.Function):
else: else:
return torch.empty(A.shape[:-1] + B.shape[:1], dtype=A.dtype, device=A.device) return torch.empty(A.shape[:-1] + B.shape[:1], dtype=A.dtype, device=A.device)
# 1. Quantize A
# 2. Quantize B
# 3. Matmul
# 4. Mixed-precision decomposition matmul
# 5. Save state
formatB = state.formatB
input_shape = A.shape input_shape = A.shape
if state.outlier_pool is None:
state.outlier_pool = GlobalOutlierPooler.get_instance()
# Cast A to fp16 # Cast A to fp16
if A.dtype != torch.float16: if A.dtype != torch.float16:
warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization") warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
# 1. Quantize A
if len(A.shape) == 3: if len(A.shape) == 3:
A = A.reshape(-1, A.shape[-1]) A = A.reshape(-1, A.shape[-1])
CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(torch.float16), threshold=state.threshold)
if state.threshold > 0.0 and coo_tensorA is not None: # 1. Quantize A. Note that as a side-effect, outliers are suppressed in CA/CAt.
if state.has_fp16_weights: if ctx.needs_input_grad[1]:
idx = torch.unique(coo_tensorA.colidx).long() # Slower path
CA[:, idx] = 0 CA, CAt, SCA, SCAt, outlier_cols = F.int8_double_quant(A.to(torch.float16), threshold=state.threshold)
CAt[:, idx] = 0
subA = A[:, idx]
state.subB = B[:, idx].t().contiguous()
state.idx = idx
else:
if state.CxB is None and using_igemmlt:
# B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
# we also need to convert it to the turing/ampere format
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
else: else:
if not state.has_fp16_weights and state.CxB is None and using_igemmlt: # Fast path
state.CxB, state.SB = F.transform(state.CB, to_order=formatB) CA, SCA, outlier_cols = F.int8_vectorwise_quant(A.to(torch.float16), threshold=state.threshold)
subA = None CAt = SCAt = None
# 2. Quantize B has_grad = False
if state.has_fp16_weights:
has_grad = True if (getattr(B, "grad", None) is not None) else False if state.has_fp16_weights or state.CB is None:
has_grad = getattr(B, "grad", None) is not None
is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1) is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1)
if is_transposed: if is_transposed:
B = B.contiguous() B = B.contiguous()
if (state.is_training and not has_grad) or state.CxB is None: if (state.is_training and not has_grad) or state.CB is None or state.SCB is None:
state.reset_grads() state.reset_grads()
(
CB,
state.CBt,
state.SCB,
state.SCBt,
coo_tensorB,
) = F.double_quant(B.to(torch.float16))
if using_igemmlt:
state.CxB, state.SB = F.transform(CB, to_order=formatB)
else:
state.CB = CB
else:
has_grad = False
if coo_tensorA is not None and not state.has_fp16_weights:
# extract outliers
outlier_idx = torch.unique(coo_tensorA.colidx)
state.idx = outlier_idx
# state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
# if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
# # do not use pool for 2nd FFN layer
# state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
# else:
# state.idx = outlier_idx
if state.CxB is not None:
outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
else:
outliers = state.CB[:, state.idx.long()].clone()
state.subB = (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().to(A.dtype) # 2. Quantize B
CA[:, state.idx.long()] = 0 state.CB, state.SCB, _ = F.int8_vectorwise_quant(B.to(torch.float16))
CAt[:, state.idx.long()] = 0
subA = A[:, state.idx.long()]
shapeB = state.SB[0] if state.SB else B.shape # Handle sparse decomposition. In some instances, we may have not found any
# outlier columns at all. In that case, we'll skip this part completely.
if state.threshold > 0.0 and outlier_cols is not None and outlier_cols.numel():
state.idx = outlier_cols
if len(input_shape) == 3: # Zero out the outliers in the transposed 8bit inputs.
output_shape = (input_shape[0], input_shape[1], shapeB[0]) if CAt is not None:
else: CAt[:, state.idx] = 0
output_shape = (input_shape[0], shapeB[0])
# Extract the input outliers in original precision
# 3. Matmul subA = A[:, state.idx].contiguous()
if using_igemmlt:
C32A, SA = F.transform(CA, "col32")
out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
if bias is None or bias.dtype == torch.float16:
# we apply the fused bias here
output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
output = output.to(A.dtype)
else: # apply bias separately
output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
output = output.to(A.dtype).add_(bias)
# Extract the corresponding weights
if state.has_fp16_weights:
state.subB = B[:, state.idx].t()
else:
# To dequantize our weights associated with the input outliers,
# we want to divide by 127. It's however more performant to multiply
# by the reciprocal.
outliers = state.CB[:, state.idx]
state.subB = (outliers.t() * state.SCB * 7.874015718698502e-3).to(A.dtype)
else: else:
A_wo_outliers = A.clone() subA = None
if state.idx is not None:
A_wo_outliers[:, state.idx.long()] = 0 # 3. Int8 Matmul
output = torch.nn.functional.linear(A_wo_outliers, state.CB.to(A.dtype)) out32 = F.int8_linear_matmul(CA, state.CB)
output = output.mul_(state.SCB.unsqueeze(0).mul(1.0 / 127.0))
if bias is not None: # Dequantize matmul result
output = output.add_(bias) if bias is None or bias.dtype == torch.float16:
# we apply the fused bias here
output = F.int8_mm_dequant(out32, SCA, state.SCB, bias=bias).to(A.dtype)
else: # apply bias separately
# TODO: Fused bias for fp32/bf16?
output = F.int8_mm_dequant(out32, SCA, state.SCB, bias=None).to(A.dtype).add_(bias)
# 4. Mixed-precision decomposition matmul # 4. Mixed-precision decomposition matmul
if coo_tensorA is not None and subA is not None: if subA is not None and state.subB is not None:
output += torch.matmul(subA, state.subB) output = output.addmm(subA, state.subB)
# 5. Save state # 5. Save state
ctx.state = state ctx.state = state
ctx.formatB = formatB
ctx.grad_shape = input_shape ctx.grad_shape = input_shape
ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype
...@@ -425,23 +389,27 @@ class MatMul8bitLt(torch.autograd.Function): ...@@ -425,23 +389,27 @@ class MatMul8bitLt(torch.autograd.Function):
ctx.tensors = (CAt, subA, A) ctx.tensors = (CAt, subA, A)
ctx.tensor_states = (SCAt, state.idx) ctx.tensor_states = (SCAt, state.idx)
else: else:
ctx.tensors = [None, None, A] ctx.tensors = [None, None, None]
ctx.tensor_states = (None, None) ctx.tensor_states = (None, None)
ctx.save_for_backward(None, None) ctx.save_for_backward(None, None)
clone_func = torch.clone if len(output_shape) == 3 else lambda x: x output_shape = (*input_shape[:-1], state.CB.shape[0])
return clone_func(output.view(output_shape))
if len(input_shape) == 3:
return output.reshape(output_shape)
return output
@staticmethod @staticmethod
def backward(ctx, grad_output): def backward(ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor):
if ctx.is_empty: if ctx.is_empty:
bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias) bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
CAt, subA, A = ctx.tensors CAt, subA, A = ctx.tensors
SCAt, idx = ctx.tensor_states SCAt, idx = ctx.tensor_states
formatB = ctx.formatB state: MatmulLtState = ctx.state
state = ctx.state
grad_A = grad_B = grad_bias = None grad_A = grad_B = grad_bias = None
if req_gradBias: if req_gradBias:
...@@ -452,35 +420,20 @@ class MatMul8bitLt(torch.autograd.Function): ...@@ -452,35 +420,20 @@ class MatMul8bitLt(torch.autograd.Function):
if len(grad_output.shape) == 3: if len(grad_output.shape) == 3:
grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous() grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
if req_gradB: if req_gradB:
CxAt, SAt = F.transform(CAt, formatB, transpose=True) Cgrad, _, _, SCgradt, _ = F.int8_double_quant(grad_output.to(torch.float16))
C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt) gradB32 = F.int8_linear_matmul(Cgrad.t().contiguous(), CAt.t())
grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt) grad_B = F.int8_mm_dequant(gradB32, SCgradt, SCAt)
if state.threshold > 0.0 and subA is not None: if state.threshold > 0.0 and subA is not None:
grad_B[:, idx] += torch.matmul(grad_output.t(), subA) grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
if req_gradA: if req_gradA:
if state.CBt is not None: if state.CB is not None:
C32grad, Sgrad = F.transform(Cgrad, "col32")
if state.CxBt is None:
state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True)
gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
elif state.CB is not None:
CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0)) CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A) grad_A = torch.matmul(grad_output.to(ctx.dtype_A), CB).view(ctx.grad_shape)
elif state.CxB is not None:
CB = (
undo_layout(state.CxB, state.tile_indices)
.to(ctx.dtype_A)
.mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
)
grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
else: else:
raise Exception("State must contain either CBt or CB or CxB matrix for backward") raise Exception("State must contain CB matrix for backward")
return grad_A, grad_B, None, grad_bias, None return grad_A, grad_B, None, grad_bias, None
...@@ -548,7 +501,7 @@ def matmul( ...@@ -548,7 +501,7 @@ def matmul(
out: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None,
state: Optional[MatmulLtState] = None, state: Optional[MatmulLtState] = None,
threshold=0.0, threshold=0.0,
bias=None, bias: Optional[torch.Tensor] = None,
): ):
state = state or MatmulLtState() state = state or MatmulLtState()
if threshold > 0.0: if threshold > 0.0:
...@@ -561,9 +514,10 @@ def matmul_4bit( ...@@ -561,9 +514,10 @@ def matmul_4bit(
B: torch.Tensor, B: torch.Tensor,
quant_state: F.QuantState, quant_state: F.QuantState,
out: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None,
bias=None, bias: Optional[torch.Tensor] = None,
): ):
assert quant_state is not None assert quant_state is not None
if A.numel() == A.shape[-1] and A.requires_grad == False: if A.numel() == A.shape[-1] and A.requires_grad == False:
if A.shape[-1] % quant_state.blocksize != 0: if A.shape[-1] % quant_state.blocksize != 0:
warn( warn(
......
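To summarize the new forward path above: the following is a minimal sketch of the mixed-precision (sparse) decomposition with illustrative shapes and a hypothetical threshold, mirroring the diff rather than reproducing `MatMul8bitLt` exactly.

```python
# Sketch of int8 matmul with sparse decomposition of activation outliers.
import torch
from bitsandbytes import functional as F

A = torch.randn(16, 4096, device="cuda", dtype=torch.float16)     # activations
B = torch.randn(11008, 4096, device="cuda", dtype=torch.float16)  # weights (out_features x inner)
threshold = 6.0  # illustrative outlier threshold

# Columns of A with |value| above the threshold are zeroed in the int8 tensor
# and their indices are returned so they can be handled separately in fp16.
CA, SCA, outlier_cols = F.int8_vectorwise_quant(A, threshold=threshold)
CB, SCB, _ = F.int8_vectorwise_quant(B)

out = F.int8_mm_dequant(F.int8_linear_matmul(CA, CB), SCA, SCB).to(A.dtype)

if outlier_cols is not None and outlier_cols.numel():
    subA = A[:, outlier_cols]                                           # fp16 input outliers
    subB = (CB[:, outlier_cols].t() * SCB * (1.0 / 127.0)).to(A.dtype)  # dequantized weight slice
    out = out.addmm(subA, subB)                                         # fp16 matmul for outlier columns
```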
"""
extract factors the build is dependent on:
[X] compute capability
[ ] TODO: Q - What if we have multiple GPUs of different makes?
- CUDA version
- Software:
- CPU-only: only CPU quantization functions (no optimizer, no matrix multiple)
- CuBLAS-LT: full-build 8-bit optimizer
- no CuBLAS-LT: no 8-bit matrix multiplication (`nomatmul`)
evaluation:
- if paths faulty, return meaningful error
- else:
- determine CUDA version
- determine capabilities
- based on that set the default path
"""
import ctypes as ct import ctypes as ct
import logging import logging
import os import os
...@@ -37,11 +19,7 @@ def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path: ...@@ -37,11 +19,7 @@ def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path:
The library is not guaranteed to exist at the returned path. The library is not guaranteed to exist at the returned path.
""" """
library_name = f"libbitsandbytes_cuda{cuda_specs.cuda_version_string}" library_name = f"libbitsandbytes_cuda{cuda_specs.cuda_version_string}{DYNAMIC_LIBRARY_SUFFIX}"
if not cuda_specs.has_cublaslt:
# if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt
library_name += "_nocublaslt"
library_name = f"{library_name}{DYNAMIC_LIBRARY_SUFFIX}"
override_value = os.environ.get("BNB_CUDA_VERSION") override_value = os.environ.get("BNB_CUDA_VERSION")
if override_value: if override_value:
...@@ -67,6 +45,9 @@ class BNBNativeLibrary: ...@@ -67,6 +45,9 @@ class BNBNativeLibrary:
def __getattr__(self, item): def __getattr__(self, item):
return getattr(self._lib, item) return getattr(self._lib, item)
def __getitem__(self, item):
return getattr(self._lib, item)
class CudaBNBNativeLibrary(BNBNativeLibrary): class CudaBNBNativeLibrary(BNBNativeLibrary):
compiled_with_cuda = True compiled_with_cuda = True
...@@ -114,6 +95,6 @@ python -m bitsandbytes ...@@ -114,6 +95,6 @@ python -m bitsandbytes
Inspect the output of the command and see if you can locate CUDA libraries. You might need to add them Inspect the output of the command and see if you can locate CUDA libraries. You might need to add them
to your LD_LIBRARY_PATH. If you suspect a bug, please take the information from python -m bitsandbytes to your LD_LIBRARY_PATH. If you suspect a bug, please take the information from python -m bitsandbytes
and open an issue at: https://github.com/TimDettmers/bitsandbytes/issues and open an issue at: https://github.com/bitsandbytes-foundation/bitsandbytes/issues
""", """,
) )
...@@ -11,7 +11,7 @@ class CUDASpecs: ...@@ -11,7 +11,7 @@ class CUDASpecs:
cuda_version_tuple: Tuple[int, int] cuda_version_tuple: Tuple[int, int]
@property @property
def has_cublaslt(self) -> bool: def has_imma(self) -> bool:
return self.highest_compute_capability >= (7, 5) return self.highest_compute_capability >= (7, 5)
......
...@@ -134,8 +134,8 @@ def print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None: ...@@ -134,8 +134,8 @@ def print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None:
print(f"To manually override the PyTorch CUDA version please see: {NONPYTORCH_DOC_URL}") print(f"To manually override the PyTorch CUDA version please see: {NONPYTORCH_DOC_URL}")
# 7.5 is the minimum CC for cublaslt # 7.5 is the minimum CC for int8 tensor cores
if not cuda_specs.has_cublaslt: if not cuda_specs.has_imma:
print_dedented( print_dedented(
""" """
WARNING: Compute capability < 7.5 detected! Only slow 8-bit matmul is supported for your GPU! WARNING: Compute capability < 7.5 detected! Only slow 8-bit matmul is supported for your GPU!
......
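The renamed property describes what the check actually gates: IMMA (int8 tensor core) support, introduced with compute capability 7.5. As a hedged stand-alone equivalent using only PyTorch (not the library's CUDASpecs object):

    import torch

    def has_imma(device: int = 0) -> bool:
        # int8 tensor cores (IMMA) require Turing (CC 7.5) or newer
        return torch.cuda.get_device_capability(device) >= (7, 5)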
...@@ -16,7 +16,6 @@ from bitsandbytes.functional import QuantState ...@@ -16,7 +16,6 @@ from bitsandbytes.functional import QuantState
from bitsandbytes.optim import GlobalOptimManager from bitsandbytes.optim import GlobalOptimManager
from bitsandbytes.utils import ( from bitsandbytes.utils import (
INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING, INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING,
LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING,
OutlierTracer, OutlierTracer,
) )
...@@ -481,11 +480,8 @@ class Linear4bit(nn.Linear): ...@@ -481,11 +480,8 @@ class Linear4bit(nn.Linear):
x = x.to(self.compute_dtype) x = x.to(self.compute_dtype)
bias = None if self.bias is None else self.bias.to(self.compute_dtype) bias = None if self.bias is None else self.bias.to(self.compute_dtype)
out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state)
out = out.to(inp_dtype) return bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state).to(inp_dtype)
return out
class LinearFP4(Linear4bit): class LinearFP4(Linear4bit):
...@@ -570,11 +566,11 @@ class LinearNF4(Linear4bit): ...@@ -570,11 +566,11 @@ class LinearNF4(Linear4bit):
class Int8Params(torch.nn.Parameter): class Int8Params(torch.nn.Parameter):
def __new__( def __new__(
cls, cls,
data=None, data: Optional[torch.Tensor] = None,
requires_grad=True, requires_grad=True,
has_fp16_weights=False, has_fp16_weights=False,
CB=None, CB: Optional[torch.Tensor] = None,
SCB=None, SCB: Optional[torch.Tensor] = None,
): ):
if data is None: if data is None:
data = torch.empty(0) data = torch.empty(0)
...@@ -588,12 +584,9 @@ class Int8Params(torch.nn.Parameter): ...@@ -588,12 +584,9 @@ class Int8Params(torch.nn.Parameter):
if self.has_fp16_weights: if self.has_fp16_weights:
return super().cuda(device) return super().cuda(device)
else: else:
# we store the 8-bit rows-major weight # We quantize the weight and store in 8bit row-major
# we convert this weight to the turning/ampere weight during the first inference pass
B = self.data.contiguous().half().cuda(device) B = self.data.contiguous().half().cuda(device)
CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B) CB, SCB, _ = bnb.functional.int8_vectorwise_quant(B)
del CBt
del SCBt
self.data = CB self.data = CB
self.CB = CB self.CB = CB
self.SCB = SCB self.SCB = SCB
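A rough sketch of the new quantization call used above (the function name and return structure come from this diff; shapes, device, and the dequantization check are assumptions): int8_vectorwise_quant produces the row-major int8 weight plus per-row absmax scales, and the weight can be approximately recovered as CB * SCB / 127.

    import torch
    import bitsandbytes.functional as F

    W = torch.randn(4096, 11008, dtype=torch.float16, device="cuda")
    CB, SCB, _ = F.int8_vectorwise_quant(W)      # int8 rows + per-row absmax scales
    W_approx = (CB.float() * SCB.view(-1, 1) / 127.0).to(torch.float16)
    print((W - W_approx).abs().max())            # small quantization error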
...@@ -888,7 +881,6 @@ class Linear8bitLt(nn.Linear): ...@@ -888,7 +881,6 @@ class Linear8bitLt(nn.Linear):
output_features: int, output_features: int,
bias=True, bias=True,
has_fp16_weights=True, has_fp16_weights=True,
memory_efficient_backward=False,
threshold=0.0, threshold=0.0,
index=None, index=None,
device=None, device=None,
...@@ -905,13 +897,12 @@ class Linear8bitLt(nn.Linear): ...@@ -905,13 +897,12 @@ class Linear8bitLt(nn.Linear):
Whether the linear class uses the bias term as well. Whether the linear class uses the bias term as well.
""" """
super().__init__(input_features, output_features, bias, device) super().__init__(input_features, output_features, bias, device)
assert not memory_efficient_backward, "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
self.state = bnb.MatmulLtState() self.state = bnb.MatmulLtState()
self.index = index self.index = index
self.state.threshold = threshold self.state.threshold = threshold
self.state.has_fp16_weights = has_fp16_weights self.state.has_fp16_weights = has_fp16_weights
self.state.memory_efficient_backward = memory_efficient_backward
if threshold > 0.0 and not has_fp16_weights: if threshold > 0.0 and not has_fp16_weights:
self.state.use_pool = True self.state.use_pool = True
...@@ -928,29 +919,19 @@ class Linear8bitLt(nn.Linear): ...@@ -928,29 +919,19 @@ class Linear8bitLt(nn.Linear):
param_from_weight = getattr(self.weight, scb_name) param_from_weight = getattr(self.weight, scb_name)
# case 2: self.init_8bit_state was called, SCB is in self.state # case 2: self.init_8bit_state was called, SCB is in self.state
param_from_state = getattr(self.state, scb_name) param_from_state = getattr(self.state, scb_name)
# case 3: SCB is in self.state, weight layout reordered after first forward()
layout_reordered = self.state.CxB is not None
key_name = prefix + f"{scb_name}" key_name = prefix + f"{scb_name}"
# We now only save in row-major. This format information is stored for backwards compatibility.
format_name = prefix + "weight_format" format_name = prefix + "weight_format"
if not self.state.has_fp16_weights: if not self.state.has_fp16_weights:
if param_from_weight is not None: if param_from_weight is not None:
destination[key_name] = param_from_weight if keep_vars else param_from_weight.detach() destination[key_name] = param_from_weight if keep_vars else param_from_weight.detach()
destination[format_name] = torch.tensor(0, dtype=torch.uint8) destination[format_name] = torch.tensor(0, dtype=torch.uint8)
elif param_from_state is not None and not layout_reordered:
destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
destination[format_name] = torch.tensor(0, dtype=torch.uint8)
elif param_from_state is not None: elif param_from_state is not None:
destination[key_name] = param_from_state if keep_vars else param_from_state.detach() destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
weights_format = self.state.formatB destination[format_name] = torch.tensor(0, dtype=torch.uint8)
# At this point `weights_format` is an str
if weights_format not in LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING:
raise ValueError(f"Unrecognized weights format {weights_format}")
weights_format = LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING[weights_format]
destination[format_name] = torch.tensor(weights_format, dtype=torch.uint8)
def _load_from_state_dict( def _load_from_state_dict(
self, self,
...@@ -1008,12 +989,9 @@ class Linear8bitLt(nn.Linear): ...@@ -1008,12 +989,9 @@ class Linear8bitLt(nn.Linear):
out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state) out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
if not self.state.has_fp16_weights: if not self.state.has_fp16_weights and self.state.CB is not None:
if self.state.CB is not None and self.state.CxB is not None: self.weight.data = self.state.CB
# we converted 8-bit row major to turing/ampere format in the first inference pass
# we no longer need the row-major weight
del self.state.CB
self.weight.data = self.state.CxB
return out return out
......
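A short usage sketch for the module logic above (layer sizes and threshold are illustrative assumptions): with has_fp16_weights=False the fp16 weights are quantized when the module is moved to the GPU, and only the row-major int8 CB is kept as the parameter data.

    import torch
    import bitsandbytes as bnb

    fp16_linear = torch.nn.Linear(4096, 4096).half()
    int8_linear = bnb.nn.Linear8bitLt(4096, 4096, bias=True, has_fp16_weights=False, threshold=6.0)
    int8_linear.load_state_dict(fp16_linear.state_dict())
    int8_linear = int8_linear.cuda()             # weights quantized here via Int8Params.cuda()

    x = torch.randn(8, 4096, dtype=torch.float16, device="cuda")
    with torch.no_grad():
        y = int8_linear(x)
    print(int8_linear.weight.dtype)              # torch.int8 -- only the row-major CB is kept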
...@@ -184,9 +184,9 @@ class MatMulFP8Global(torch.autograd.Function): ...@@ -184,9 +184,9 @@ class MatMulFP8Global(torch.autograd.Function):
class SwitchBackBnb(torch.autograd.Function): class SwitchBackBnb(torch.autograd.Function):
@staticmethod @staticmethod
# TODO: the B008 on the line below is a likely bug; the current implementation will def forward(ctx, A, B, out=None, bias=None, state: Optional[MatmulLtState] = None):
# have each SwitchBackBnb instance share a single MatmulLtState instance!!! state = state or MatmulLtState()
def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): # noqa: B008
# default to pytorch behavior if inputs are empty # default to pytorch behavior if inputs are empty
ctx.is_empty = False ctx.is_empty = False
if prod(A.shape) == 0: if prod(A.shape) == 0:
...@@ -204,7 +204,6 @@ class SwitchBackBnb(torch.autograd.Function): ...@@ -204,7 +204,6 @@ class SwitchBackBnb(torch.autograd.Function):
# 3. Matmul # 3. Matmul
# 4. Mixed-precision decomposition matmul # 4. Mixed-precision decomposition matmul
# 5. Save state # 5. Save state
formatB = state.formatB
input_shape = A.shape input_shape = A.shape
if state.outlier_pool is None: if state.outlier_pool is None:
state.outlier_pool = GlobalOutlierPooler.get_instance() state.outlier_pool = GlobalOutlierPooler.get_instance()
...@@ -216,25 +215,21 @@ class SwitchBackBnb(torch.autograd.Function): ...@@ -216,25 +215,21 @@ class SwitchBackBnb(torch.autograd.Function):
# 1. Quantize A # 1. Quantize A
if len(A.shape) == 3: if len(A.shape) == 3:
A = A.view(-1, A.shape[-1]).contiguous() A = A.view(-1, A.shape[-1]).contiguous()
CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(torch.float16), threshold=state.threshold) CA, CAt, SCA, SCAt, outlier_cols = F.int8_double_quant(A.to(torch.float16), threshold=state.threshold)
if state.threshold > 0.0 and coo_tensorA is not None: if state.threshold > 0.0 and outlier_cols is not None:
if state.has_fp16_weights: if state.has_fp16_weights:
idx = torch.unique(coo_tensorA.colidx).long() idx = outlier_cols
CA[:, idx] = 0 CA[:, idx] = 0
CAt[:, idx] = 0
subA = A[:, idx] subA = A[:, idx]
state.subB = B[:, idx].t().contiguous() state.subB = B[:, idx].t().contiguous()
state.idx = idx state.idx = idx
else: else:
if state.CxB is None: if state.SB is None:
# B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions state.SB = (state.CB.shape, "row")
# we also need to convert it to the turing/ampere format
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
else: else:
# print('A shape', A.shape) if not state.has_fp16_weights and state.SB is None:
if not state.has_fp16_weights and state.CxB is None: state.SB = (state.CB.shape, "row")
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
subA = None subA = None
# 2. Quantize B # 2. Quantize B
...@@ -245,34 +240,26 @@ class SwitchBackBnb(torch.autograd.Function): ...@@ -245,34 +240,26 @@ class SwitchBackBnb(torch.autograd.Function):
if is_transposed: if is_transposed:
B = B.contiguous() B = B.contiguous()
if (state.is_training and not has_grad) or state.CxB is None: if (state.is_training and not has_grad) or state.SB is None:
state.reset_grads() state.reset_grads()
( (
CB, state.CB,
state.CBt, state.CBt,
state.SCB, state.SCB,
state.SCBt, state.SCBt,
coo_tensorB, _,
) = F.double_quant(B.to(torch.float16)) ) = F.int8_double_quant(B.to(torch.float16))
state.CxB, state.SB = F.transform(CB, to_order=formatB) state.SB = (state.CB.shape, "row")
else: else:
has_grad = False has_grad = False
if coo_tensorA is not None and not state.has_fp16_weights: if outlier_cols is not None and not state.has_fp16_weights:
# extract outliers # extract outliers
state.idx = outlier_cols
outlier_idx = torch.unique(coo_tensorA.colidx) outliers = state.CB[:, state.idx.long()].clone()
state.idx = outlier_idx
# state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
# if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
# # do not use pool for 2nd FFN layer
# state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
# else:
# state.idx = outlier_idx
outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
state.subB = (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().to(A.dtype) state.subB = (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().to(A.dtype)
CA[:, state.idx.long()] = 0 CA[:, state.idx.long()] = 0
CAt[:, state.idx.long()] = 0
subA = A[:, state.idx.long()] subA = A[:, state.idx.long()]
shapeB = state.SB[0] shapeB = state.SB[0]
...@@ -283,25 +270,22 @@ class SwitchBackBnb(torch.autograd.Function): ...@@ -283,25 +270,22 @@ class SwitchBackBnb(torch.autograd.Function):
output_shape = (input_shape[0], shapeB[0]) output_shape = (input_shape[0], shapeB[0])
# 3. Matmul # 3. Matmul
C32A, SA = F.transform(CA, "col32") out32 = F.int8_linear_matmul(CA, state.CB)
out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
# we apply the fused bias here # we apply the fused bias here
if bias is None or bias.dtype == torch.float16: if bias is None or bias.dtype == torch.float16:
output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias) output = F.int8_mm_dequant(out32, SCA, state.SCB, bias=bias).to(A.dtype)
output = output.to(A.dtype)
else: # apply bias separately else: # apply bias separately
output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None) output = F.int8_mm_dequant(out32, SCA, state.SCB, bias=None).to(A.dtype)
output = output.to(A.dtype).add_(bias) output.add_(bias)
# 4. Mixed-precision decomposition matmul # 4. Mixed-precision decomposition matmul
if coo_tensorA is not None and subA is not None: if outlier_cols is not None and subA is not None:
output += torch.matmul(subA, state.subB) output += torch.matmul(subA, state.subB)
# 5. Save state # 5. Save state
ctx.state = state ctx.state = state
ctx.formatB = formatB
ctx.grad_shape = input_shape ctx.grad_shape = input_shape
ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype
...@@ -321,10 +305,10 @@ class SwitchBackBnb(torch.autograd.Function): ...@@ -321,10 +305,10 @@ class SwitchBackBnb(torch.autograd.Function):
if ctx.is_empty: if ctx.is_empty:
bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias) bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
CAt, subA, A = ctx.tensors CAt, subA, A = ctx.tensors
SCAt, idx = ctx.tensor_states SCAt, idx = ctx.tensor_states
formatB = ctx.formatB
state = ctx.state state = ctx.state
grad_A = grad_B = grad_bias = None grad_A = grad_B = grad_bias = None
...@@ -336,7 +320,7 @@ class SwitchBackBnb(torch.autograd.Function): ...@@ -336,7 +320,7 @@ class SwitchBackBnb(torch.autograd.Function):
if len(grad_output.shape) == 3: if len(grad_output.shape) == 3:
grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous() grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16)) Cgrad, Cgradt, SCgrad, SCgradt, outlier_cols = F.int8_double_quant(grad_output.to(torch.float16))
if req_gradB: if req_gradB:
# print('back A shape', A.shape) # print('back A shape', A.shape)
...@@ -344,16 +328,7 @@ class SwitchBackBnb(torch.autograd.Function): ...@@ -344,16 +328,7 @@ class SwitchBackBnb(torch.autograd.Function):
grad_B = torch.matmul(grad_output.t(), A) grad_B = torch.matmul(grad_output.t(), A)
if req_gradA: if req_gradA:
if state.CBt is not None: if state.CB is not None:
C32grad, Sgrad = F.transform(Cgrad, "col32")
if state.CxBt is None:
state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True)
# print('back B shape', state.CxBt.shape)
# print('back grad shape', C32grad.shape)
gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
elif state.CB is not None:
CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0)) CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A) grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
else: else:
......
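For reference, the int8 pipeline the rewritten forward above is built from, as a stand-alone sketch (function names and return values are taken from this diff; shapes and the threshold value are assumptions): quantize activations and weights, run the int8 GEMM into an int32 accumulator, then dequantize back to fp16 with the row and column scales.

    import torch
    import bitsandbytes.functional as F

    A = torch.randn(16, 4096, dtype=torch.float16, device="cuda")     # activations
    B = torch.randn(11008, 4096, dtype=torch.float16, device="cuda")  # weight, row-major

    CA, CAt, SCA, SCAt, outlier_cols = F.int8_double_quant(A, threshold=6.0)
    CB, SCB, _ = F.int8_vectorwise_quant(B)

    out32 = F.int8_linear_matmul(CA, CB)         # int32 accumulator, shape (16, 11008)
    out = F.int8_mm_dequant(out32, SCA, SCB)     # fp16 result, approximately A @ B.t()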
#pragma once
// TODO: Let's make some of these constexpr and put in a namespace.
#define BNB_CC_MAXWELL 500
#define BNB_CC_MAXWELL2 520
#define BNB_CC_MAXWELL2_X1 530
#define BNB_CC_PASCAL 600
#define BNB_CC_PASCAL_X2 620
#define BNB_CC_VOLTA 700
#define BNB_CC_VOLTA_XAVIER 720
#define BNB_CC_TURING 750
#define BNB_CC_AMPERE 800
#define BNB_CC_AMPERE2 860
#define BNB_CC_AMPERE2_ORIN 870
#define BNB_CC_ADA 890
#define BNB_CC_HOPPER 900
#define BNB_CC_BLACKWELL 1000
#define BNB_FP16_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_MAXWELL2_X1)
#define BNB_FP16_MMA_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_VOLTA)
#define BNB_INT8_MMA_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_VOLTA_XAVIER)
#define BNB_BF16_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_AMPERE)
#define BNB_FP8_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_ADA)
#define BNB_WARP_SIZE 32
// The maximum number of resident threads per SM varies by arch.
// For A100/H100 and all prior to Turing, it is 2048, which allows
// for 2 full blocks of 1024 threads per SM.
// Reference: https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications-technical-specifications-per-compute-capability
#if __CUDA_ARCH__ == 750
#define BNB_MAX_THREADS_PER_SM 1024
#elif __CUDA_ARCH__ >= 860 && __CUDA_ARCH__ <= 890
#define BNB_MAX_THREADS_PER_SM 1536
#else
#define BNB_MAX_THREADS_PER_SM 2048
#endif
// Maximum resident warps per SM is always directly related to the number of threads.
#define BNB_MAX_WARPS_PER_SM ((BNB_MAX_THREADS_PER_SM) / (BNB_WARP_SIZE))
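// e.g. CC 7.5: 1024/32 = 32 warps; CC 8.6-8.9: 1536/32 = 48 warps; CC 8.0/9.0: 2048/32 = 64 warps.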
// Maximum resident blocks per SM may vary.
#if __CUDA_ARCH__ == 860 || __CUDA_ARCH__ == 870
#define BNB_MAX_BLOCKS_PER_SM 16
#else
#define BNB_MAX_BLOCKS_PER_SM ((BNB_MAX_WARPS_PER_SM) / 2)
#endif
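// e.g. CC 8.6/8.7: 16 blocks (hardware limit); CC 7.5: 32/2 = 16; CC 8.9: 48/2 = 24; CC 8.0/9.0: 64/2 = 32.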
...@@ -112,12 +112,12 @@ __global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index ...@@ -112,12 +112,12 @@ __global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index
template <typename T, int SPMM_ITEMS, int BITS> __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB); template <typename T, int SPMM_ITEMS, int BITS> __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>__global__ void kdequant_mm_int32_fp16( template <int ITEMS_PER_THREAD, int THREADS>__global__ void kdequant_mm_int32_fp16(
int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats,
half *out, float* newRowStats, float* newcolStats, half * __restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n); half *out, half * __restrict__ const bias, const int numRows, const int numCols, const int n);
template<typename T, int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int SPARSE_DECOMP> __global__ void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols); template<typename T, int THREADS, int SPARSE_DECOMP> __global__ void kgetRowStats(T * __restrict__ A, float *rowStats, float threshold, int rows, int cols);
template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int SPARSE_DECOMP> __global__ void kDoubleRowColQuant(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols); template<typename T, int THREADS, int SPARSE_DECOMP> __global__ void kInt8VectorQuant(T * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols);
template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int TRANSPOSE, int FORMAT> __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols); template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int TRANSPOSE, int FORMAT> __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
......