Unverified commit 81e6345d authored by Matthew Douglas, committed by GitHub

LLM.int8() Refactoring: Part 1 (#1401)



* Start of int8 refactor: remove col32/col_ampere/col_turing transforms in new igemmlt implementation

* Fix unintended change

* New naive mm_dequant kernel for row-major; cleanup

* fix

* int8 refactor: initial sparse decomp, cleanup

* Int8 refactoring: remove separate NO_CUBLASLT build; more cleanup

* int8: inference optimizations, some cleanup

* int8: more tests passing, cleanup

* int8 - more cleanup, most tests passing

* int8: specify CUDA stream for int8 ops

* perf: reduce overhead from getting cudaStream ptr

* Mark some functions for deprecation.

* int8 sparse decomp: small perf improvement

* update setup.py

* Update bitsandbytes/autograd/_functions.py
Co-authored-by: Aarni Koskela <akx@iki.fi>

* Update bitsandbytes/functional.py
Co-authored-by: Aarni Koskela <akx@iki.fi>

* Update bitsandbytes/functional.py
Co-authored-by: Aarni Koskela <akx@iki.fi>

* Update bitsandbytes/research/autograd/_functions.py
Co-authored-by: Aarni Koskela <akx@iki.fi>

* int8 - perf improvement for sparse decomposition inference; deprecate get_tensor_stream() in favor of new private fn

* int8 cleanup

* Ignore ruff rule ISC001 (incompatible with formatter)

* add comment

* int8 more cleanup

* Update bitsandbytes/functional.py
Co-authored-by: Aarni Koskela <akx@iki.fi>

* int8: rename / deprecate old fn signatures

* Update bitsandbytes/functional.py
Co-authored-by: Aarni Koskela <akx@iki.fi>

* type annotation

* format update

* Update bitsandbytes/research/autograd/_functions.py
Co-authored-by: Aarni Koskela <akx@iki.fi>

* cleanup

* Add comment to explain division optimization

* more cleanup

* Update bitsandbytes/functional.py
Co-authored-by: Aarni Koskela <akx@iki.fi>

* Update bitsandbytes/functional.py
Co-authored-by: Aarni Koskela <akx@iki.fi>

* Update bitsandbytes/functional.py
Co-authored-by: Aarni Koskela <akx@iki.fi>

* cleanup

* Type annotations, cleanup

* remove unused kernels; improved type annotations

* small perf optimization for single-GPU systems

* small perf optimization for single-GPU systems

* update docstrings

* Improve docs and tests

* Update docstring

* Update test

* add benchmarking script

* test cleanup: add deprecated marker, move benchmarks out

* Add int8 dequant function; misc improvements

* int8 matmul fallback for inner dims not divisible by 4

* improve register usage of kInt8VectorQuant - especially for A100/H100

* disable fail-fast for package build

* maxwell compat

* ptxas verbose

* docs update

* doc update

* backward fix

* Bugfix sparse decomp

* Int8 fix for PEFT OLoRA init

* Fix test for deprecated spmm_coo

* test improvement

* doc update

* typo

* doc cleanup

* docs

* add inference benchmark script

* Add benchmarks, doc update

---------
Co-authored-by: Aarni Koskela <akx@iki.fi>
parent 7dca7004
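For orientation before the file diffs: this refactor drops the col32 / col_turing / col_ampere layout transforms in favor of a plain row-major int8 path. Below is a minimal sketch of that path using the functional API introduced here (`int8_vectorwise_quant`, `int8_linear_matmul`, `int8_mm_dequant`), as exercised throughout the diffs that follow; the shapes and the threshold value are illustrative only.

```python
import torch
import bitsandbytes.functional as F

# Illustrative shapes: (tokens x hidden) fp16 activations, (out x in) fp16 weights.
A = torch.randn(16, 4096, dtype=torch.float16, device="cuda")
B = torch.randn(11008, 4096, dtype=torch.float16, device="cuda")

# 1. Row-wise int8 quantization. With threshold > 0, outlier columns are reported
#    separately and suppressed in CA (sparse decomposition).
CA, SCA, outlier_cols = F.int8_vectorwise_quant(A, threshold=6.0)
CB, SCB, _ = F.int8_vectorwise_quant(B)

# 2. Row-major int8 matmul with int32 accumulation, then dequantize using the scales.
out32 = F.int8_linear_matmul(CA, CB)
out = F.int8_mm_dequant(out32, SCA, SCB).to(A.dtype)

# 3. Mixed-precision decomposition: outlier columns are computed in fp16 and added back.
if outlier_cols is not None and outlier_cols.numel():
    subA = A[:, outlier_cols]
    # Multiply by the reciprocal of 127, as the PR does for performance.
    subB = (CB[:, outlier_cols].t() * SCB * (1.0 / 127.0)).to(A.dtype)
    out = out.addmm(subA, subB)
```

Compared with the previous implementation there is no `transform(..., to_order="col32")` step and no separate `_nocublaslt` build; the same row-major tensors are used throughout.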
......@@ -8,21 +8,21 @@ build_capability="50;52;60;61;70;75;80;86;89;90"
[[ "${cuda_version}" == 11.7.* ]] && build_capability=${build_capability%??????}
[[ "${cuda_version}" == 11.8.* ]] && build_capability=${build_capability%???}
[[ "${build_os}" = windows-* ]] && python3 -m pip install ninja
for NO_CUBLASLT in ON OFF; do
if [ "${build_os:0:6}" == ubuntu ]; then
image=nvidia/cuda:${cuda_version}-devel-ubuntu22.04
echo "Using image $image"
docker run --platform "linux/$build_arch" -i -w /src -v "$PWD:/src" "$image" sh -c \
"apt-get update \
&& DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends cmake \
&& cmake -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY=\"${build_capability}\" -DNO_CUBLASLT=${NO_CUBLASLT} . \
&& cmake --build ."
else
pip install cmake==3.28.3
cmake -G Ninja -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY="${build_capability}" -DNO_CUBLASLT=${NO_CUBLASLT} -DCMAKE_BUILD_TYPE=Release -S .
cmake --build . --config Release
fi
done
if [ "${build_os:0:6}" == ubuntu ]; then
image=nvidia/cuda:${cuda_version}-devel-ubuntu22.04
echo "Using image $image"
docker run --platform "linux/$build_arch" -i -w /src -v "$PWD:/src" "$image" sh -c \
"apt-get update \
&& DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends cmake \
&& cmake -DPTXAS_VERBOSE=1 -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY=\"${build_capability}\" . \
&& cmake --build ."
else
pip install cmake==3.28.3
cmake -G Ninja -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY="${build_capability}" -DCMAKE_BUILD_TYPE=Release -S .
cmake --build . --config Release
fi
output_dir="output/${build_os}/${build_arch}"
mkdir -p "${output_dir}"
......
......@@ -60,6 +60,7 @@ jobs:
##
build-shared-libs-cuda:
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, windows-latest]
arch: [x86_64, aarch64]
......
......@@ -22,9 +22,11 @@ CMakeFiles/
bitsandbytes.dir/
Debug/
Release/
cmake-build-*/
# IDE local files
.vs/
.idea/
# Distribution / packaging
.Python
......
......@@ -4,7 +4,6 @@
# For MSVC: `cmake -B build . && cmake --build build --config Release`
# You can also use the following options and variables
# - COMPUTE_BACKEND: Set to `cpu`, `cuda`, or `mps` to select the backend
# - NO_CUBLASLT: Default OFF, will skip building/linking CUBLASLT support
# - CUDA_VERSION: The expected CUDA version, for sanity checking. The actual version
# is whatever CMake finds on your path.
# - COMPUTE_CAPABILITY: Which GPU Arch/Compute codes to provide to NVCC.
......@@ -47,10 +46,8 @@ if(${COMPUTE_BACKEND} STREQUAL "cuda")
if(APPLE)
message(FATAL_ERROR "CUDA is not supported on macOS" )
endif()
option(NO_CUBLASLT "Disable CUBLAS" OFF)
set(BUILD_CUDA ON)
set(BUILD_MPS OFF)
message(STATUS "NO_CUBLASLT := ${NO_CUBLASLT}")
elseif(${COMPUTE_BACKEND} STREQUAL "mps")
if(NOT APPLE)
message(FATAL_ERROR "MPS is only supported on macOS" )
......@@ -166,9 +163,6 @@ if(BUILD_CUDA)
list(APPEND SRC_FILES ${CUDA_FILES})
string(APPEND BNB_OUTPUT_NAME "_cuda${CUDA_VERSION_SHORT}")
if(NO_CUBLASLT)
string(APPEND BNB_OUTPUT_NAME "_nocublaslt")
endif()
add_compile_definitions(BUILD_CUDA)
elseif(BUILD_MPS)
if(NOT APPLE)
......@@ -212,13 +206,7 @@ target_include_directories(bitsandbytes PUBLIC csrc include)
if(BUILD_CUDA)
target_include_directories(bitsandbytes PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
target_link_libraries(bitsandbytes PUBLIC CUDA::cudart CUDA::cublas CUDA::cusparse)
if(NO_CUBLASLT)
target_compile_definitions(bitsandbytes PUBLIC NO_CUBLASLT)
else()
target_link_libraries(bitsandbytes PUBLIC CUDA::cublasLt)
endif()
target_link_libraries(bitsandbytes PUBLIC CUDA::cudart CUDA::cublas CUDA::cublasLt CUDA::cusparse)
set_target_properties(bitsandbytes
PROPERTIES
CUDA_SEPARABLE_COMPILATION ON
......
# Benchmarking
## Inference
End-to-end inference benchmarking can be performed using the 🤗 [`optimum-benchmark`](https://github.com/huggingface/optimum-benchmark) library.
See the example script in
[inference_benchmark.py](inference_benchmark.py).
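In the result tables below, **INT8** corresponds to LLM.int8() with the outlier decomposition disabled (`llm_int8_threshold=0.0`), **INT8+Decomp** uses the threshold of 6.0, and **NF4+DQ** enables double quantization. The following is a minimal sketch of the corresponding 🤗 Transformers quantization configs, mirroring the `WEIGHTS_CONFIGS` used by the benchmark script (the script falls back to `float16` compute on GPUs without bfloat16 support):

```python
import torch
from transformers import BitsAndBytesConfig

int8_config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_threshold=0.0)         # "INT8"
int8_decomp_config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_threshold=6.0)  # "INT8+Decomp"
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)  # "NF4"
nf4_dq_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)  # "NF4+DQ"
```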
### Results (as of v0.45.0)
Our overall benchmarking results compared with v0.44.1 provide the following insights:
#### LLM.int8()
* **Turing/Ampere/Ada**: The observed per-token throughput is improved by 60-85%, while latency is decreased by 40-45%.
* **H100**: In our benchmarking of Llama 3.1 70B, we observed that the new LLM.int8() consistently outperforms NF4 at batch sizes >= 8.
#### NF4/FP4
* **Turing/Ampere/Ada**: With a batch size of 1, per-token throughput is _improved by 10-25%_ and per-token latency is _decreased by 10-20%_.
* **H100**: Across all batch sizes, per-token throughput is _improved by up to 28%_ and per-token latency is _decreased by up to 22%_.
Summaries of the benchmarking results are provided below.
#### NVIDIA T4 16GB
<details>
<summary>Qwen 2.5 3B Instruct</summary>
| Configuration | Batch Size | Mean Latency (s) <sub>v0.45.0.dev</sub> | Throughput (tokens/s) <sub>v0.45.0.dev</sub> | Mean Latency (s) <sub>v0.44.1</sub> | Latency Improvement | Throughput (tokens/s) <sub>v0.44.1</sub> | Throughput Improvement |
|----------------------|------------|------------------------------|------------------------|--------------------------|---------------------|--------------------|------------------------|
| FP16 | 1 | 0.0390 | 25.66 | 0.0390 | 1.00 | 25.66 | 1.000x |
| NF4 | 1 | 0.0608 | 16.45 | 0.0710 | 1.14 | 14.08 | 1.168x |
| NF4+DQ | 1 | 0.0736 | 13.58 | 0.0905 | 1.19 | 11.05 | 1.229x |
| INT8 | 1 | 0.0902 | 11.08 | 0.1609 | 1.44 | 6.21 | 1.784x |
| INT8+Decomp | 1 | 0.1672 | 5.98 | 0.2994 | 1.44 | 3.34 | 1.790x |
| FP16 | 8 | 0.0422 | 189.56 | 0.0422 | 1.00 | 189.56 | 1.000x |
| NF4 | 8 | 0.0960 | 83.37 | 0.1010 | 1.05 | 79.17 | 1.053x |
| NF4+DQ | 8 | 0.1042 | 76.80 | 0.1156 | 1.10 | 69.18 | 1.110x |
| INT8 | 8 | 0.0919 | 87.01 | 0.1640 | 1.44 | 48.78 | 1.784x |
| INT8+Decomp | 8 | 0.1812 | 44.15 | 0.3296 | 1.45 | 24.28 | 1.818x |
| FP16 | 32 | 0.0601 | 532.30 | 0.0601 | 1.00 | 532.30 | 1.000x |
| NF4 | 32 | 0.1150 | 278.32 | 0.1182 | 1.03 | 270.71 | 1.028x |
| NF4+DQ | 32 | 0.1215 | 263.36 | 0.1297 | 1.06 | 246.76 | 1.067x |
| INT8 | 32 | 0.0943 | 339.21 | 0.1640 | 1.42 | 195.14 | 1.738x |
| INT8+Decomp | 32 | 0.1912 | 167.37 | 0.3413 | 1.44 | 93.75 | 1.785x |
</details>
#### NVIDIA RTX 4090 24GB
<details>
<summary>Llama 3.1 8B</summary>
| Configuration | Batch Size | Mean Latency (s) <sub>v0.45.0.dev</sub> | Throughput (tokens/s) <sub>v0.45.0.dev</sub> | Mean Latency (s) <sub>v0.44.1</sub> | Latency Improvement | Throughput (tokens/s) <sub>v0.44.1</sub> | Throughput Improvement |
|----------------------|------------|------------------------------|------------------------|--------------------------|---------------------|--------------------|------------------------|
| BF16 | 1 | 0.0211 | 47.46 | 0.0211 | 1.00 | 47.46 | 1.000x |
| NF4 | 1 | 0.0148 | 67.71 | 0.0164 | 1.10 | 61.08 | 1.109x |
| NF4+DQ | 1 | 0.0175 | 57.08 | 0.0208 | 1.16 | 48.15 | 1.185x |
| INT8 | 1 | 0.0220 | 45.39 | 0.0395 | 1.44 | 25.32 | 1.793x |
| INT8+Decomp | 1 | 0.0449 | 22.26 | 0.0743 | 1.40 | 13.45 | 1.655x |
| BF16 | 8 | 0.0239 | 334.64 | 0.0239 | 1.00 | 334.64 | 1.000x |
| NF4 | 8 | 0.0425 | 188.08 | 0.0422 | 0.99 | 189.50 | 0.993x |
| NF4+DQ | 8 | 0.0443 | 180.68 | 0.0437 | 0.99 | 183.02 | 0.987x |
| INT8 | 8 | 0.0221 | 361.61 | 0.0389 | 1.43 | 205.82 | 1.757x |
| INT8+Decomp | 8 | 0.0478 | 164.55 | 0.0777 | 1.38 | 103.01 | 1.597x |
| BF16 | 32 | 0.0304 | 1054.35 | 0.0304 | 1.00 | 1054.35 | 1.000x |
| NF4 | 32 | 0.0461 | 694.60 | 0.0466 | 1.01 | 686.90 | 1.011x |
| NF4+DQ | 32 | 0.0471 | 678.73 | 0.0480 | 1.02 | 666.33 | 1.019x |
| INT8 | 32 | 0.0230 | 1390.54 | 0.0390 | 1.41 | 819.99 | 1.696x |
| INT8+Decomp | 32 | 0.0512 | 624.94 | 0.0835 | 1.39 | 383.18 | 1.631x |
</details>
<details>
<summary>Qwen 2.5 14B Instruct</summary>
| Configuration | Batch Size | Mean Latency (s) <sub>v0.45.0.dev</sub> | Throughput (tokens/s) <sub>v0.45.0.dev</sub> | Mean Latency (s) <sub>v0.44.1</sub> | Latency Improvement | Throughput (tokens/s) <sub>v0.44.1</sub> | Throughput Improvement |
|----------------------|------------|------------------------------|------------------------|--------------------------|---------------------|--------------------|------------------------|
| NF4 | 1 | 0.0214 | 46.74 | 0.0256 | 1.16 | 39.10 | 1.195x |
| NF4+DQ | 1 | 0.0256 | 39.03 | 0.0318 | 1.19 | 31.46 | 1.241x |
| INT8 | 1 | 0.0326 | 30.68 | 0.0596 | 1.45 | 16.79 | 1.827x |
| INT8+Decomp | 1 | 0.0648 | 15.44 | 0.1105 | 1.41 | 9.05 | 1.706x |
| NF4 | 8 | 0.0696 | 114.95 | 0.0697 | 1.00 | 114.78 | 1.001x |
| NF4+DQ | 8 | 0.0719 | 111.29 | 0.0723 | 1.01 | 110.70 | 1.005x |
| INT8 | 8 | 0.0325 | 246.22 | 0.0596 | 1.45 | 134.21 | 1.835x |
| INT8+Decomp | 8 | 0.0721 | 110.95 | 0.1201 | 1.40 | 66.62 | 1.665x |
</details>
#### NVIDIA H100 80GB SXM
<details>
<summary>Llama 3.1 8B</summary>
| Configuration | Batch Size | Mean Latency (s) <sub>v0.45.0.dev</sub> | Throughput (tokens/s) <sub>v0.45.0.dev</sub> | Mean Latency (s) <sub>v0.44.1</sub> | Latency Improvement | Throughput (tokens/s) <sub>v0.44.1</sub> | Throughput Improvement |
|----------------------|------------|------------------------------|------------------------|--------------------------|---------------------|--------------------|------------------------|
| BF16 | 1 | 0.0244 | 40.99 | 0.0244 | 1.00 | 40.99 | 1.000x |
| NF4 | 1 | 0.0331 | 30.14 | 0.0391 | 1.15 | 25.60 | 1.177x |
| NF4+DQ | 1 | 0.0411 | 24.34 | 0.0528 | 1.22 | 18.92 | 1.286x |
| INT8 | 1 | 0.0522 | 19.17 | N/A | N/A | N/A | N/A |
| INT8+Decomp | 1 | 0.0817 | 12.24 | N/A | N/A | N/A | N/A |
| BF16 | 8 | 0.0255 | 313.90 | 0.0255 | 1.00 | 313.90 | 1.000x |
| NF4 | 8 | 0.0476 | 168.05 | 0.0551 | 1.14 | 145.13 | 1.158x |
| NF4+DQ | 8 | 0.0566 | 141.27 | 0.0663 | 1.15 | 120.67 | 1.171x |
| INT8 | 8 | 0.0515 | 155.44 | N/A | N/A | N/A | N/A |
| INT8+Decomp | 8 | 0.0853 | 93.79 | N/A | N/A | N/A | N/A |
| BF16 | 32 | 0.0261 | 1227.96 | 0.0261 | 1.00 | 1227.96 | 1.000x |
| NF4 | 32 | 0.0486 | 658.65 | 0.0546 | 1.11 | 585.91 | 1.124x |
| NF4+DQ | 32 | 0.0577 | 555.06 | 0.0665 | 1.13 | 481.04 | 1.154x |
| INT8 | 32 | 0.0545 | 586.26 | N/A | N/A | N/A | N/A |
| INT8+Decomp | 32 | 0.0864 | 370.51 | N/A | N/A | N/A | N/A |
</details>
<details>
<summary>Qwen 2.5 32B Instruct</summary>
| Configuration | Batch Size | Mean Latency (s) <sub>v0.45.0.dev</sub> | Throughput (tokens/s) <sub>v0.45.0.dev</sub> |
|-------------|------------|-----------------------------------------|-----------------------------------|
| BF16 | 1 | 0.0508 | 19.67 |
| NF4 | 1 | 0.0707 | 14.14 |
| NF4+DQ | 1 | 0.0860 | 11.63 |
| INT8 | 1 | 0.1031 | 9.70 |
| INT8+Decomp | 1 | 0.1820 | 5.49 |
| BF16 | 8 | 0.0525 | 152.50 |
| NF4 | 8 | 0.1154 | 69.35 |
| NF4+DQ | 8 | 0.1209 | 66.19 |
| INT8 | 8 | 0.1078 | 74.24 |
| INT8+Decomp | 8 | 0.1958 | 40.87 |
| BF16 | 32 | 0.0547 | 584.54 |
| NF4 | 32 | 0.1246 | 256.84 |
| NF4+DQ | 32 | 0.1298 | 246.47 |
| INT8 | 32 | 0.1056 | 302.96 |
| INT8+Decomp | 32 | 0.2027 | 157.83 |
</details>
<details>
<summary>Llama 3.1 70B</summary>
| Configuration | Batch Size | Mean Latency (s) <sub>v0.45.0.dev</sub> | Throughput (tokens/s) <sub>v0.45.0.dev</sub> |
|-------------|------------|-----------------------------------------|-----------------------------------|
| NF4 | 1 | 0.0833 | 12.00 |
| NF4+DQ | 1 | 0.1052 | 9.50 |
| INT8 | 1 | 0.1294 | 7.73 |
| INT8+Decomp | 1 | 0.1985 | 5.04 |
| NF4 | 8 | 0.2348 | 34.07 |
| NF4+DQ | 8 | 0.2423 | 33.01 |
| INT8 | 8 | 0.1313 | 60.94 |
| INT8+Decomp | 8 | 0.2052 | 38.99 |
| NF4 | 32 | 0.2491 | 128.46 |
| NF4+DQ | 32 | 0.2580 | 124.04 |
| INT8 | 32 | 0.1314 | 243.45 |
| INT8+Decomp | 32 | 0.2189 | 146.19 |
</details>
#### Software Configuration
We focus on the default PyTorch CUDA backend in 🤗 [`optimum-benchmark`](https://github.com/huggingface/optimum-benchmark). We used commit [`6e6b1036`](https://github.com/huggingface/optimum-benchmark/commit/6e6b10363f3ac65926881f2c6a6113b6cefc06cd).
For all hardware configurations, we used the following dependencies:
* `transformers==4.46.3`
* `accelerate==1.1.1`
* `tokenizers==0.20.3`
* `torch==2.5.1`
* `bitsandbytes==0.44.1`
* `bitsandbytes==0.45.0.dev`
In the RTX 4090 setting we used the CUDA 12.4 build of PyTorch; in the other settings we used the CUDA 12.1 build.
"""
Inference benchmarking tool.
Requirements:
transformers
accelerate
bitsandbytes
optimum-benchmark
Usage: python inference_benchmark.py model_id
options:
-h, --help show this help message and exit
--configs {bf16,fp16,nf4,nf4-dq,int8,int8-decomp} [{bf16,fp16,nf4,nf4-dq,int8,int8-decomp} ...]
--bf16
--fp16
--nf4
--nf4-dq
--int8
--int8-decomp
--batches BATCHES [BATCHES ...]
--input-length INPUT_LENGTH
--out-dir OUT_DIR
"""
import argparse
from pathlib import Path
from optimum_benchmark import Benchmark, BenchmarkConfig, InferenceConfig, ProcessConfig, PyTorchConfig
from optimum_benchmark.logging_utils import setup_logging
import torch
BFLOAT16_SUPPORT = torch.cuda.get_device_capability()[0] >= 8
WEIGHTS_CONFIGS = {
"fp16": {"torch_dtype": "float16", "quantization_scheme": None, "quantization_config": {}},
"bf16": {"torch_dtype": "bfloat16", "quantization_scheme": None, "quantization_config": {}},
"nf4": {
"torch_dtype": "bfloat16" if BFLOAT16_SUPPORT else "float16",
"quantization_scheme": "bnb",
"quantization_config": {
"load_in_4bit": True,
"bnb_4bit_quant_type": "nf4",
"bnb_4bit_use_double_quant": False,
"bnb_4bit_compute_dtype": torch.bfloat16 if BFLOAT16_SUPPORT else "float16",
},
},
"nf4-dq": {
"torch_dtype": "bfloat16" if BFLOAT16_SUPPORT else "float16",
"quantization_scheme": "bnb",
"quantization_config": {
"load_in_4bit": True,
"bnb_4bit_quant_type": "nf4",
"bnb_4bit_use_double_quant": True,
"bnb_4bit_compute_dtype": torch.bfloat16 if BFLOAT16_SUPPORT else "float16",
},
},
"int8-decomp": {
"torch_dtype": "float16",
"quantization_scheme": "bnb",
"quantization_config": {
"load_in_8bit": True,
"llm_int8_threshold": 6.0,
},
},
"int8": {
"torch_dtype": "float16",
"quantization_scheme": "bnb",
"quantization_config": {
"load_in_8bit": True,
"llm_int8_threshold": 0.0,
},
},
}
if __name__ == "__main__":
setup_logging(level="INFO")
parser = argparse.ArgumentParser(description="bitsandbytes inference benchmark tool")
parser.add_argument("model_id", type=str, help="The model checkpoint to use.")
parser.add_argument(
"--configs",
nargs="+",
choices=["bf16", "fp16", "nf4", "nf4-dq", "int8", "int8-decomp"],
default=["nf4", "int8", "int8-decomp"],
)
parser.add_argument("--bf16", dest="configs", action="append_const", const="bf16")
parser.add_argument("--fp16", dest="configs", action="append_const", const="fp16")
parser.add_argument("--nf4", dest="configs", action="append_const", const="nf4")
parser.add_argument("--nf4-dq", dest="configs", action="append_const", const="nf4-dq")
parser.add_argument("--int8", dest="configs", action="append_const", const="int8")
parser.add_argument("--int8-decomp", dest="configs", action="append_const", const="int8-decomp")
parser.add_argument("--batches", nargs="+", type=int, default=[1, 8, 16, 32])
parser.add_argument("--input-length", type=int, default=64)
parser.add_argument("--out-dir", type=str, default="reports")
args = parser.parse_args()
out_dir = Path(args.out_dir)
out_dir.mkdir(parents=True, exist_ok=True)
for batch_size in args.batches:
print(f"Benchmarking batch size: {batch_size}")
for config in args.configs:
launcher_config = ProcessConfig(device_isolation=True, start_method="spawn")
scenario_config = InferenceConfig(
latency=True,
memory=True,
input_shapes={"batch_size": batch_size, "sequence_length": args.input_length},
)
backend_config = PyTorchConfig(
device="cuda",
device_ids="0",
device_map="auto",
no_weights=False,
model=args.model_id,
**WEIGHTS_CONFIGS[config],
)
benchmark_config = BenchmarkConfig(
name=f"benchmark-{config}-bsz{batch_size}",
scenario=scenario_config,
launcher=launcher_config,
backend=backend_config,
)
out_path = out_dir / f"benchmark_{config}_bsz{batch_size}.json"
benchmark_report = Benchmark.launch(benchmark_config)
benchmark_report.log()
benchmark_report.save_json(out_path)
"""
Basic benchmark for text generation.
Usage: python benchmarking/int8/int8_benchmark.py
"""
import time
import torch
from torch.profiler import ProfilerActivity, profile
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
MAX_NEW_TOKENS = 128
model_name = "meta-llama/Llama-3.1-8B"
text = "Below is a question. I need an answer.\n\nExplain machine learning: "
tokenizer = AutoTokenizer.from_pretrained(model_name)
input_ids = tokenizer([text] * 8, return_tensors="pt").input_ids.to(0)
model = AutoModelForCausalLM.from_pretrained(
model_name,
device_map="auto",
quantization_config=BitsAndBytesConfig(
load_in_8bit=True,
llm_int8_threshold=6.0,
),
attn_implementation="sdpa",
torch_dtype=torch.float16,
)
print(model)
# warmup
print("Warmup...")
for i in range(3):
generated_ids = model.generate(input_ids, max_new_tokens=MAX_NEW_TOKENS)
print("Profiler starting...")
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
with_modules=True,
with_stack=True,
) as prof:
model.generate(input_ids, max_new_tokens=1)
print(
prof.key_averages().table(
sort_by="cpu_time_total",
max_name_column_width=50,
top_level_events_only=True,
row_limit=50,
)
)
torch.cuda.synchronize()
print("Generating...")
num = 0
time_1 = time.time()
for i in range(5):
generated_ids = model.generate(input_ids, max_new_tokens=MAX_NEW_TOKENS)
num += len(generated_ids[0])
print("=" * 40)
print(f"Example:\n{tokenizer.decode(generated_ids[0])}")
print("=" * 40)
print(f"Speed: {num/(time.time() - time_1)}token/s")
"""
Extracted from tests/test_functional.py
Note: This feature is currently unused! It is kept here for archival purposes.
Usage: pytest benchmarking/int8/row_scale_benchmark.py
"""
import time
import pytest
import torch
from bitsandbytes import functional as F
k = 20
torch.set_printoptions(precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000)
@pytest.mark.parametrize(
("dim1", "dim4", "inner"),
[
pytest.param(1024, 12288 * 4, 12288, id="1024, 12288*4, 12288"),
pytest.param(2048, 4096 * 4, 4096, id="2048, 4096*4, 4096"),
],
)
@pytest.mark.skip("Row scale has some bugs for ampere")
@pytest.mark.benchmark
def test_row_scale_bench(dim1, dim4, inner):
formatB = F.get_special_format_str()
err1, err2, err3 = [], [], []
relerr1, relerr2 = [], []
scale = 1
A = torch.randn(dim1, inner, device="cuda").half()
B = torch.randn(dim4, inner, device="cuda").half()
torch.nn.init.xavier_uniform_(B)
# warmup
for i in range(k):
C1 = torch.matmul(A, B.t())
torch.cuda.synchronize()
t0 = time.time()
for i in range(k):
C1 = torch.matmul(A, B.t())
torch.cuda.synchronize()
print("16", time.time() - t0)
C1a, C1b, stats1a, stats1b, coo_tensor = F.int8_double_quant(A)
CB, absmaxB = F.vectorwise_quant(B, quant_type="linear")
A2, SA = F.nvidia_transform(C1a, "col32")
B2, SB = F.nvidia_transform(CB, formatB)
A1, maxA = F.vectorwise_quant(A, dim=1)
c = 10.0 * inner * scale
row_scale = maxA / c
torch.cuda.synchronize()
t0 = time.time()
for i in range(k):
outC32 = F.int8_linear_matmul(A2, B2, dtype=torch.int8, row_scale=row_scale)
torch.cuda.synchronize()
print("row-wise", time.time() - t0)
C2a, C2b, stats2a, stats2b, coo_tensor = F.int8_double_quant(B)
B2, SB = F.nvidia_transform(C2a, formatB)
torch.cuda.synchronize()
t0 = time.time()
for i in range(k):
outC32 = F.int8_linear_matmul(A2, B2)
torch.cuda.synchronize()
print("vector-wise", time.time() - t0)
"""
Extracted from tests/test_functional.py
Usage: pytest benchmarking/int8/training_benchmark.py
"""
import time
import pytest
import torch
from bitsandbytes import functional as F
k = 20
torch.set_printoptions(precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000)
@pytest.mark.parametrize(
("batch", "seq", "model", "hidden"),
[
pytest.param(2, 512, 4 * 1024, 3 * 4 * 1024, id="batch=2, seq=512, model=4k, hidden=12k"),
pytest.param(2, 512, 5120, 3 * 5120, id="batch=2, seq=512, model=5k, hidden=15k"),
pytest.param(2, 512, 12 * 1024, 4 * 12 * 1024, id="batch=2, seq=512, model=12k, hidden=48k"),
],
)
@pytest.mark.benchmark
def test_bench_8bit_training(batch, seq, model, hidden):
formatB = F.get_special_format_str()
A = torch.randn(batch, seq, model, device="cuda").half()
grad = torch.randn(batch, seq, model, device="cuda").half()
w1 = torch.randint(-128, 127, size=(hidden, model), device="cuda").half()
w2 = torch.randint(-128, 127, size=(model, hidden), device="cuda").half()
print("")
# torch.cuda.synchronize()
## warmup
# for i in range(100):
# torch.matmul(A, w1.t())
# torch.cuda.synchronize()
dtype = torch.int8
A = A.view(-1, A.shape[-1]).contiguous()
grad = grad.view(-1, grad.shape[-1]).contiguous()
torch.cuda.synchronize()
t0 = time.time()
for i in range(k):
out1 = torch.matmul(A, w1.t()) # fc1
# out2 = torch.matmul(out1, w2.t())# fc2
# d1 = torch.matmul(grad, w2) # delta1
# d2 = torch.matmul(d1, w1) # delta2
# grad1 = torch.einsum('bo,bh->oh', out1, grad) # grad w2
# grad2 = torch.einsum('bh,bo->ho', A, d2) # grad w1
torch.cuda.synchronize()
t16 = time.time() - t0
print(t16)
# torch.cuda.empty_cache()
# Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
# Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)
# CTw1, Sw1 = F.transform2(Cw1, formatB)
# CTw2, Sw2 = F.transform2(Cw2, formatB)
# CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
# CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
# CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
# C32A, SA = F.transform2(CA, 'col32')
## fc1
# out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype)
##out1 = F.mm_dequant(out1_32, Sout1_32, statsAt, statsw1t)
## fc2
# Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1)
# C32out1, Sout1 = F.transform2(Cout1, 'col32')
# out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype)
##out2 = F.mm_dequant(out2_32, Sout2_32, statsout1t, statsw2t)
## delta1
# Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad)
# C32grad, Sgrad = F.transform2(Cgrad, 'col32')
##d1_32, Sd1_32 = F.igemmlt(C32grad, CTw2t, Sgrad, Sw2t, dtype=dtype)
##d1 = F.mm_dequant(d1_32, Sd1_32, statsgradt, statsw2)
## delta2
# Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1)
# C32d1, Sd1 = F.transform2(Cd1, 'col32')
##d2_32, Sd2_32 = F.igemmlt(C32d1, CTw1t, Sd1, Sw1t, dtype=dtype)
##d2 = F.mm_dequant(d2_32, Sd2_32, statsd1t, statsw1)
## grad1
# C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True)
# CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True)
##grad1_32, Sgrad1_32 = F.igemmlt(C32out1t, CTgradt, Sout1t, Sgradt, dtype=dtype)
##grad1 = F.mm_dequant(grad1_32, Sgrad1_32, statsout1, statsgrad)
## grad2
# C32At, SAt = F.transform2(CAt, 'col32', transpose=True)
# CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True)
##grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype)
##grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsA, statsd1)
# Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)
# Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
# Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)
# CTw1, Sw1 = F.transform2(Cw1, formatB)
# CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
# CTw2, Sw2 = F.transform2(Cw2, formatB)
# CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
# torch.cuda.synchronize()
# t0 = time.time()
# for i in range(k):
# #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
# #CTw1, Sw1 = F.transform2(Cw1, formatB)
# #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
# #CTw1, Sw1 = F.transform2(Cw1, formatB)
# #CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=3.5)
# CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
# #CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
# #CTw2, Sw2 = F.transform2(Cw2, formatB)
# #CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
# C32A, SA = F.transform2(CA, 'col32')
# # fc1
# out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype)
# #out1dn = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)
# #print(coo_tensor.nnz)
# #out1sp = F.spmm_coo(coo_tensor, w1.t())
# #print(w1.t().shape)
# #out1 = out1dn + out1sp
# # fc2
# Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1)
# C32out1, Sout1 = F.transform2(Cout1, 'col32')
# out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype)
# #out2 = F.mm_dequant(out2_32, Sout2_32, statsout1, statsw2)
# # delta1
# Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad)
# C32grad, Sgrad = F.transform2(Cgrad, 'col32')
# d1_32, Sd1_32 = F.igemmlt(C32grad, CTw2t, Sgrad, Sw2t, dtype=dtype)
# #d1 = F.mm_dequant(d1_32, Sd1_32, statsgrad, statsw2t)
# # delta2
# Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1)
# C32d1, Sd1 = F.transform2(Cd1, 'col32')
# d2_32, Sd2_32 = F.igemmlt(C32d1, CTw1t, Sd1, Sw1t, dtype=dtype)
# #d2 = F.mm_dequant(d2_32, Sd2_32, statsd1, statsw1t)
# # grad1
# #C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True)
# #CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True)
# #grad1_32, Sgrad1_32 = F.igemmlt(C32out1t, CTgradt, Sout1t, Sgradt, dtype=dtype)
# #grad1 = F.mm_dequant(grad1_32, Sgrad1_32, statsout1t, statsgradt)
# ## grad2
# #C32At, SAt = F.transform2(CAt, 'col32', transpose=True)
# #CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True)
# #grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype)
# #grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsAt, statsd1t)
# torch.cuda.synchronize()
# t8 = time.time() - t0
# print(t8)
"""
Extracted from tests/test_functional.py
Usage: pytest benchmarking/matmul_benchmark.py
"""
import time
import pytest
import torch
import bitsandbytes as bnb
from bitsandbytes import functional as F
k = 20
torch.set_printoptions(precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000)
@pytest.mark.parametrize(
("batch", "seq", "model", "hidden"),
[
# pytest.param(1, 128, 6656, 4 * 6656, id="batch=1, seq=128, model=6656, hidden=26k"),
pytest.param(1, 1, 3584, 512, id="batch=1, seq=1, model=3584, hidden=512"),
# pytest.param(4, 128, 6656, 4 * 6656, id="batch=4, seq=128, model=6656, hidden=26k"),
# pytest.param(16, 256, 6656, 4 * 6656, id="batch=16, seq=256, model=6656, hidden=26k")
],
)
@pytest.mark.benchmark
def test_bench_matmul(batch, seq, model, hidden):
iters = 1000
formatB = F.get_special_format_str()
A = torch.randn(batch, seq, model, device="cuda").half()
B = torch.empty(hidden, model, dtype=torch.float16, device="cuda")
torch.nn.init.xavier_uniform_(B)
B_fp4, state = F.quantize_fp4(B)
B_fp4_c, state_c = F.quantize_fp4(B, compress_statistics=True)
B_nf4, state_nf4 = F.quantize_nf4(B)
B_nf4_c, state_nf4_c = F.quantize_nf4(B, compress_statistics=True)
linear8bit = bnb.nn.Linear8bitLt(model, hidden, False, False).cuda().half()
linear8bit.eval()
outliers = torch.randint(0, model, size=(5,)).cuda()
A[:, :, outliers] = 8.0
linearMixedBit = bnb.nn.Linear8bitLt(model, hidden, False, False, threshold=6.0).cuda().half()
# linearMixedBit.eval()
linear8bit_train = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
linear8bit_train_thresh = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()
bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)
# warmup
for i in range(iters):
torch.matmul(A, B.t())
torch.cuda.synchronize()
print("")
torch.cuda.synchronize()
t0 = time.time()
for i in range(iters):
torch.matmul(A, B.t())
torch.cuda.synchronize()
print(
f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s",
)
# torch.cuda.synchronize()
# t0 = time.time()
# for i in range(iters):
# bnb.matmul_4bit(A, B_fp4.t(), quant_state=state)
# torch.cuda.synchronize()
# print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
# torch.cuda.synchronize()
# t0 = time.time()
# for i in range(iters):
# bnb.matmul_4bit(A, B_fp4.t(), quant_state=state_c)
# torch.cuda.synchronize()
# print( f"bnb fp4 + compressed stats: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
torch.cuda.synchronize()
t0 = time.time()
for i in range(iters):
bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)
torch.cuda.synchronize()
print(f"bnb nf4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
torch.cuda.synchronize()
t0 = time.time()
for i in range(iters):
bnb.matmul_4bit(A, B_nf4_c.t(), quant_state=state_nf4_c)
torch.cuda.synchronize()
print(f"bnb nf4+DQ: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
torch.cuda.synchronize()
t0 = time.time()
for i in range(iters):
bnb.matmul(A, B)
torch.cuda.synchronize()
print(
f"B -> CB (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
)
torch.cuda.synchronize()
t0 = time.time()
for i in range(iters):
bnb.matmul(A, B, threshold=6.0)
torch.cuda.synchronize()
print(
f"B -> CB + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
)
CA, SCA, _ = F.int8_vectorwise_quant(A, threshold=0.0)
CB, SCB, _ = F.int8_vectorwise_quant(B)
torch.cuda.synchronize()
t0 = time.time()
for i in range(iters):
# CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0)
out32 = F.int8_linear_matmul(CA, CB)
torch.cuda.synchronize()
print(
f"no overhead int8 [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
)
# C32A, SA = F.transform(CA, "col32")
# CxB, SB = F.transform(CB, to_order=formatB)
# torch.cuda.synchronize()
# t0 = time.time()
# for i in range(iters):
# out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
# torch.cuda.synchronize()
# print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
# CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0)
# C32A, SA = F.transform(CA, "col32")
# CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B)
# CxB, SB = F.transform(CB, to_order=formatB)
# torch.cuda.synchronize()
# t0 = time.time()
# for i in range(iters):
# out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
# torch.cuda.synchronize()
# print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
# BA, statsB = F.vectorwise_quant(B, dim=1)
# CxB, SB = F.nvidia_transform(CB, to_order=formatB)
# torch.cuda.synchronize()
# t0 = time.time()
# for i in range(iters):
# A2 = A.view(-1, A.shape[-1]).contiguous()
# CA, statsA = F.vectorwise_quant(A2, dim=1)
# C32A, SA = F.nvidia_transform(CA, "col32")
# out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
# Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
# F.vectorwise_mm_dequant(Cout, statsA, statsB.t())
# torch.cuda.synchronize()
# print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
# BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear")
# CxB, SB = F.nvidia_transform(CB, to_order=formatB)
# torch.cuda.synchronize()
# t0 = time.time()
# for i in range(iters):
# A2 = A.view(-1, A.shape[-1]).contiguous()
# CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear")
# C32A, SA = F.nvidia_transform(CA, "col32")
# out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
# Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
# out = Cout * statsB * statsA * (1.0 / (127 * 127))
# torch.cuda.synchronize()
# print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
linear8bit(A)
torch.cuda.synchronize()
t0 = time.time()
for i in range(iters):
linear8bit(A)
torch.cuda.synchronize()
print(
f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
)
linearMixedBit(A)
torch.cuda.synchronize()
t0 = time.time()
for i in range(iters):
linearMixedBit(A)
torch.cuda.synchronize()
print(
f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
)
# linear8bit_train(A)
# torch.cuda.synchronize()
# t0 = time.time()
# for i in range(iters):
# linear8bit_train(A)
# torch.cuda.synchronize()
# print( f"bnb linear8bitlt (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
# linear8bit_train_thresh(A)
# torch.cuda.synchronize()
# t0 = time.time()
# for i in range(iters):
# linear8bit_train(A)
# torch.cuda.synchronize()
# print( f"bnb linear8bitlt with threshold (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
from dataclasses import dataclass
from functools import reduce # Required in Python 3
import operator
from math import prod
from typing import Callable, Optional, Tuple
import warnings
from warnings import warn
import torch
from typing_extensions import deprecated
import bitsandbytes.functional as F
# math.prod is not available in Python < 3.8
def prod(iterable):
return reduce(operator.mul, iterable, 1)
# The inverse transformation for the colTuring and colAmpere format were contributed by Alex Borzunov:
# https://github.com/bigscience-workshop/petals/blob/main/src/petals/utils/linear8bitlt_patch.py
......@@ -104,6 +98,10 @@ def undo_layout(permuted_tensor: torch.Tensor, tile_indices: torch.LongTensor) -
return outputs.reshape(rows, cols).contiguous()
@deprecated(
"MatMul8bit is deprecated and will be removed in a future release. Please use MatMul8bitLt instead.",
category=FutureWarning,
)
class MatMul8bit(torch.autograd.Function):
@staticmethod
def forward(ctx, A, B, out=None, quant_type="vector", precision=None):
......@@ -215,6 +213,7 @@ bmm_cublas = MatMul8bit.apply
matmul_cublas = MatMul8bit.apply
@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def supports_igemmlt(device: torch.device) -> bool:
"""check if this device supports the optimized int8 kernel"""
if torch.cuda.get_device_capability(device=device) < (7, 5):
......@@ -226,6 +225,7 @@ def supports_igemmlt(device: torch.device) -> bool:
return True
@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def _get_tile_size(format):
assert format in (
"col_turing",
......@@ -234,6 +234,7 @@ def _get_tile_size(format):
return (8, 32) if format == "col_turing" else (32, 32)
@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def get_tile_inds(format, device):
transform = lambda x: F.transform(x.to(device), from_order="row", to_order=format)[0].to(x.device)
with torch.no_grad():
......@@ -243,27 +244,28 @@ def get_tile_inds(format, device):
@dataclass
class MatmulLtState:
_tile_indices: Optional[torch.Tensor] = None
force_no_igemmlt: bool = False
CB = None
CxB = None
SB = None
SCB = None
CxBt = None
SBt = None
CBt = None
CB: Optional[torch.Tensor] = None
CxB: Optional[torch.Tensor] = None # TODO: Deprecate/remove
SB: Optional[torch.Tensor] = None
SCB: Optional[torch.Tensor] = None
CxBt: Optional[torch.Tensor] = None # TODO: Deprecate/remove
SBt: Optional[torch.Tensor] = None
CBt: Optional[torch.Tensor] = None
subB = None
subB: Optional[torch.Tensor] = None
outlier_pool = None
outlier_pool: Optional[GlobalOutlierPooler] = None
has_accumulated_gradients = False
threshold = 0.0
idx = None
idx: Optional[torch.Tensor] = None
is_training = True
has_fp16_weights = True
memory_efficient_backward = False
use_pool = False
formatB = F.get_special_format_str()
formatB = "row" # TODO: Deprecate/remove
def reset_grads(self):
self.CB = None
......@@ -283,12 +285,17 @@ class MatmulLtState:
class MatMul8bitLt(torch.autograd.Function):
# forward is the same, but we added the fallback for pre-turing GPUs
# backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
@staticmethod
def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
using_igemmlt = supports_igemmlt(A.device) and not state.force_no_igemmlt
def forward(
ctx: torch.autograd.function.FunctionCtx,
A: torch.Tensor,
B: torch.Tensor,
out: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
state: Optional[MatmulLtState] = None,
):
state = state or MatmulLtState()
# default of pytorch behavior if inputs are empty
ctx.is_empty = False
if prod(A.shape) == 0:
......@@ -301,123 +308,80 @@ class MatMul8bitLt(torch.autograd.Function):
else:
return torch.empty(A.shape[:-1] + B.shape[:1], dtype=A.dtype, device=A.device)
# 1. Quantize A
# 2. Quantize B
# 3. Matmul
# 4. Mixed-precision decomposition matmul
# 5. Save state
formatB = state.formatB
input_shape = A.shape
if state.outlier_pool is None:
state.outlier_pool = GlobalOutlierPooler.get_instance()
# Cast A to fp16
if A.dtype != torch.float16:
warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
# 1. Quantize A
if len(A.shape) == 3:
A = A.reshape(-1, A.shape[-1])
CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(torch.float16), threshold=state.threshold)
if state.threshold > 0.0 and coo_tensorA is not None:
if state.has_fp16_weights:
idx = torch.unique(coo_tensorA.colidx).long()
CA[:, idx] = 0
CAt[:, idx] = 0
subA = A[:, idx]
state.subB = B[:, idx].t().contiguous()
state.idx = idx
else:
if state.CxB is None and using_igemmlt:
# B is in 8-bit row-major; we can transform it back to 16-bit to extract outlier dimensions
# we also need to convert it to the turing/ampere format
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
# 1. Quantize A. Note that as a side-effect, outliers are suppressed in CA/CAt.
if ctx.needs_input_grad[1]:
# Slower path
CA, CAt, SCA, SCAt, outlier_cols = F.int8_double_quant(A.to(torch.float16), threshold=state.threshold)
else:
if not state.has_fp16_weights and state.CxB is None and using_igemmlt:
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
subA = None
# Fast path
CA, SCA, outlier_cols = F.int8_vectorwise_quant(A.to(torch.float16), threshold=state.threshold)
CAt = SCAt = None
# 2. Quantize B
if state.has_fp16_weights:
has_grad = True if (getattr(B, "grad", None) is not None) else False
has_grad = False
if state.has_fp16_weights or state.CB is None:
has_grad = getattr(B, "grad", None) is not None
is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1)
if is_transposed:
B = B.contiguous()
if (state.is_training and not has_grad) or state.CxB is None:
if (state.is_training and not has_grad) or state.CB is None or state.SCB is None:
state.reset_grads()
(
CB,
state.CBt,
state.SCB,
state.SCBt,
coo_tensorB,
) = F.double_quant(B.to(torch.float16))
if using_igemmlt:
state.CxB, state.SB = F.transform(CB, to_order=formatB)
else:
state.CB = CB
else:
has_grad = False
if coo_tensorA is not None and not state.has_fp16_weights:
# extract outliers
outlier_idx = torch.unique(coo_tensorA.colidx)
state.idx = outlier_idx
# state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
# if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
# # do not use pool for 2nd FFN layer
# state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
# else:
# state.idx = outlier_idx
if state.CxB is not None:
outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
else:
outliers = state.CB[:, state.idx.long()].clone()
state.subB = (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().to(A.dtype)
CA[:, state.idx.long()] = 0
CAt[:, state.idx.long()] = 0
subA = A[:, state.idx.long()]
# 2. Quantize B
state.CB, state.SCB, _ = F.int8_vectorwise_quant(B.to(torch.float16))
shapeB = state.SB[0] if state.SB else B.shape
# Handle sparse decomposition. In some instances, we may have not found any
# outlier columns at all. In that case, we'll skip this part completely.
if state.threshold > 0.0 and outlier_cols is not None and outlier_cols.numel():
state.idx = outlier_cols
if len(input_shape) == 3:
output_shape = (input_shape[0], input_shape[1], shapeB[0])
else:
output_shape = (input_shape[0], shapeB[0])
# 3. Matmul
if using_igemmlt:
C32A, SA = F.transform(CA, "col32")
out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
if bias is None or bias.dtype == torch.float16:
# we apply the fused bias here
output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
output = output.to(A.dtype)
else: # apply bias separately
output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
output = output.to(A.dtype).add_(bias)
# Zero out the outliers in the transposed 8bit inputs.
if CAt is not None:
CAt[:, state.idx] = 0
# Extract the input outliers in original precision
subA = A[:, state.idx].contiguous()
# Extract the corresponding weights
if state.has_fp16_weights:
state.subB = B[:, state.idx].t()
else:
# To dequantize our weights associated with the input outliers,
# we want to divide by 127. It's however more performant to multiply
# by the reciprocal.
outliers = state.CB[:, state.idx]
state.subB = (outliers.t() * state.SCB * 7.874015718698502e-3).to(A.dtype)
else:
A_wo_outliers = A.clone()
if state.idx is not None:
A_wo_outliers[:, state.idx.long()] = 0
output = torch.nn.functional.linear(A_wo_outliers, state.CB.to(A.dtype))
output = output.mul_(state.SCB.unsqueeze(0).mul(1.0 / 127.0))
if bias is not None:
output = output.add_(bias)
subA = None
# 3. Int8 Matmul
out32 = F.int8_linear_matmul(CA, state.CB)
# Dequantize matmul result
if bias is None or bias.dtype == torch.float16:
# we apply the fused bias here
output = F.int8_mm_dequant(out32, SCA, state.SCB, bias=bias).to(A.dtype)
else: # apply bias separately
# TODO: Fused bias for fp32/bf16?
output = F.int8_mm_dequant(out32, SCA, state.SCB, bias=None).to(A.dtype).add_(bias)
# 4. Mixed-precision decomposition matmul
if coo_tensorA is not None and subA is not None:
output += torch.matmul(subA, state.subB)
if subA is not None and state.subB is not None:
output = output.addmm(subA, state.subB)
# 5. Save state
ctx.state = state
ctx.formatB = formatB
ctx.grad_shape = input_shape
ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype
......@@ -425,23 +389,27 @@ class MatMul8bitLt(torch.autograd.Function):
ctx.tensors = (CAt, subA, A)
ctx.tensor_states = (SCAt, state.idx)
else:
ctx.tensors = [None, None, A]
ctx.tensors = [None, None, None]
ctx.tensor_states = (None, None)
ctx.save_for_backward(None, None)
clone_func = torch.clone if len(output_shape) == 3 else lambda x: x
return clone_func(output.view(output_shape))
output_shape = (*input_shape[:-1], state.CB.shape[0])
if len(input_shape) == 3:
return output.reshape(output_shape)
return output
@staticmethod
def backward(ctx, grad_output):
def backward(ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor):
if ctx.is_empty:
bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
CAt, subA, A = ctx.tensors
SCAt, idx = ctx.tensor_states
formatB = ctx.formatB
state = ctx.state
state: MatmulLtState = ctx.state
grad_A = grad_B = grad_bias = None
if req_gradBias:
......@@ -452,35 +420,20 @@ class MatMul8bitLt(torch.autograd.Function):
if len(grad_output.shape) == 3:
grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
if req_gradB:
CxAt, SAt = F.transform(CAt, formatB, transpose=True)
C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt)
grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
Cgrad, _, _, SCgradt, _ = F.int8_double_quant(grad_output.to(torch.float16))
gradB32 = F.int8_linear_matmul(Cgrad.t().contiguous(), CAt.t())
grad_B = F.int8_mm_dequant(gradB32, SCgradt, SCAt)
if state.threshold > 0.0 and subA is not None:
grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
if req_gradA:
if state.CBt is not None:
C32grad, Sgrad = F.transform(Cgrad, "col32")
if state.CxBt is None:
state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True)
gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
elif state.CB is not None:
if state.CB is not None:
CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
elif state.CxB is not None:
CB = (
undo_layout(state.CxB, state.tile_indices)
.to(ctx.dtype_A)
.mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
)
grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
grad_A = torch.matmul(grad_output.to(ctx.dtype_A), CB).view(ctx.grad_shape)
else:
raise Exception("State must contain either CBt or CB or CxB matrix for backward")
raise Exception("State must contain CB matrix for backward")
return grad_A, grad_B, None, grad_bias, None
......@@ -548,7 +501,7 @@ def matmul(
out: Optional[torch.Tensor] = None,
state: Optional[MatmulLtState] = None,
threshold=0.0,
bias=None,
bias: Optional[torch.Tensor] = None,
):
state = state or MatmulLtState()
if threshold > 0.0:
......@@ -561,9 +514,10 @@ def matmul_4bit(
B: torch.Tensor,
quant_state: F.QuantState,
out: Optional[torch.Tensor] = None,
bias=None,
bias: Optional[torch.Tensor] = None,
):
assert quant_state is not None
if A.numel() == A.shape[-1] and A.requires_grad == False:
if A.shape[-1] % quant_state.blocksize != 0:
warn(
......
"""
extract factors the build is dependent on:
[X] compute capability
[ ] TODO: Q - What if we have multiple GPUs of different makes?
- CUDA version
- Software:
- CPU-only: only CPU quantization functions (no optimizer, no matrix multiplication)
- CuBLAS-LT: full-build 8-bit optimizer
- no CuBLAS-LT: no 8-bit matrix multiplication (`nomatmul`)
evaluation:
- if paths faulty, return meaningful error
- else:
- determine CUDA version
- determine capabilities
- based on that set the default path
"""
import ctypes as ct
import logging
import os
......@@ -37,11 +19,7 @@ def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path:
The library is not guaranteed to exist at the returned path.
"""
library_name = f"libbitsandbytes_cuda{cuda_specs.cuda_version_string}"
if not cuda_specs.has_cublaslt:
# if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt
library_name += "_nocublaslt"
library_name = f"{library_name}{DYNAMIC_LIBRARY_SUFFIX}"
library_name = f"libbitsandbytes_cuda{cuda_specs.cuda_version_string}{DYNAMIC_LIBRARY_SUFFIX}"
override_value = os.environ.get("BNB_CUDA_VERSION")
if override_value:
......@@ -67,6 +45,9 @@ class BNBNativeLibrary:
def __getattr__(self, item):
return getattr(self._lib, item)
def __getitem__(self, item):
return getattr(self._lib, item)
class CudaBNBNativeLibrary(BNBNativeLibrary):
compiled_with_cuda = True
......@@ -114,6 +95,6 @@ python -m bitsandbytes
Inspect the output of the command and see if you can locate CUDA libraries. You might need to add them
to your LD_LIBRARY_PATH. If you suspect a bug, please take the information from python -m bitsandbytes
and open an issue at: https://github.com/TimDettmers/bitsandbytes/issues
and open an issue at: https://github.com/bitsandbytes-foundation/bitsandbytes/issues
""",
)
......@@ -11,7 +11,7 @@ class CUDASpecs:
cuda_version_tuple: Tuple[int, int]
@property
def has_cublaslt(self) -> bool:
def has_imma(self) -> bool:
return self.highest_compute_capability >= (7, 5)
......
......@@ -134,8 +134,8 @@ def print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None:
print(f"To manually override the PyTorch CUDA version please see: {NONPYTORCH_DOC_URL}")
# 7.5 is the minimum CC for cublaslt
if not cuda_specs.has_cublaslt:
# 7.5 is the minimum CC for int8 tensor cores
if not cuda_specs.has_imma:
print_dedented(
"""
WARNING: Compute capability < 7.5 detected! Only slow 8-bit matmul is supported for your GPU!
......
This diff is collapsed.
......@@ -16,7 +16,6 @@ from bitsandbytes.functional import QuantState
from bitsandbytes.optim import GlobalOptimManager
from bitsandbytes.utils import (
INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING,
LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING,
OutlierTracer,
)
......@@ -481,11 +480,8 @@ class Linear4bit(nn.Linear):
x = x.to(self.compute_dtype)
bias = None if self.bias is None else self.bias.to(self.compute_dtype)
out = bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state)
out = out.to(inp_dtype)
return out
return bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state).to(inp_dtype)
class LinearFP4(Linear4bit):
......@@ -570,11 +566,11 @@ class LinearNF4(Linear4bit):
class Int8Params(torch.nn.Parameter):
def __new__(
cls,
data=None,
data: Optional[torch.Tensor] = None,
requires_grad=True,
has_fp16_weights=False,
CB=None,
SCB=None,
CB: Optional[torch.Tensor] = None,
SCB: Optional[torch.Tensor] = None,
):
if data is None:
data = torch.empty(0)
......@@ -588,12 +584,9 @@ class Int8Params(torch.nn.Parameter):
if self.has_fp16_weights:
return super().cuda(device)
else:
# we store the 8-bit rows-major weight
# we convert this weight to the turing/ampere weight during the first inference pass
# We quantize the weight and store in 8bit row-major
B = self.data.contiguous().half().cuda(device)
CB, CBt, SCB, SCBt, coo_tensorB = bnb.functional.double_quant(B)
del CBt
del SCBt
CB, SCB, _ = bnb.functional.int8_vectorwise_quant(B)
self.data = CB
self.CB = CB
self.SCB = SCB
......@@ -888,7 +881,6 @@ class Linear8bitLt(nn.Linear):
output_features: int,
bias=True,
has_fp16_weights=True,
memory_efficient_backward=False,
threshold=0.0,
index=None,
device=None,
......@@ -905,13 +897,12 @@ class Linear8bitLt(nn.Linear):
Whether the linear class uses the bias term as well.
"""
super().__init__(input_features, output_features, bias, device)
assert not memory_efficient_backward, "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
self.state = bnb.MatmulLtState()
self.index = index
self.state.threshold = threshold
self.state.has_fp16_weights = has_fp16_weights
self.state.memory_efficient_backward = memory_efficient_backward
if threshold > 0.0 and not has_fp16_weights:
self.state.use_pool = True
......@@ -928,29 +919,19 @@ class Linear8bitLt(nn.Linear):
param_from_weight = getattr(self.weight, scb_name)
# case 2: self.init_8bit_state was called, SCB is in self.state
param_from_state = getattr(self.state, scb_name)
# case 3: SCB is in self.state, weight layout reordered after first forward()
layout_reordered = self.state.CxB is not None
key_name = prefix + f"{scb_name}"
# We now only save in row-major. This format information is stored for backwards compatibility.
format_name = prefix + "weight_format"
if not self.state.has_fp16_weights:
if param_from_weight is not None:
destination[key_name] = param_from_weight if keep_vars else param_from_weight.detach()
destination[format_name] = torch.tensor(0, dtype=torch.uint8)
elif param_from_state is not None and not layout_reordered:
destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
destination[format_name] = torch.tensor(0, dtype=torch.uint8)
elif param_from_state is not None:
destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
weights_format = self.state.formatB
# At this point `weights_format` is a str
if weights_format not in LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING:
raise ValueError(f"Unrecognized weights format {weights_format}")
weights_format = LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING[weights_format]
destination[format_name] = torch.tensor(weights_format, dtype=torch.uint8)
destination[format_name] = torch.tensor(0, dtype=torch.uint8)
def _load_from_state_dict(
self,
......@@ -1008,12 +989,9 @@ class Linear8bitLt(nn.Linear):
out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
if not self.state.has_fp16_weights:
if self.state.CB is not None and self.state.CxB is not None:
# we converted 8-bit row major to turing/ampere format in the first inference pass
# we no longer need the row-major weight
del self.state.CB
self.weight.data = self.state.CxB
if not self.state.has_fp16_weights and self.state.CB is not None:
self.weight.data = self.state.CB
return out
......
......@@ -184,9 +184,9 @@ class MatMulFP8Global(torch.autograd.Function):
class SwitchBackBnb(torch.autograd.Function):
@staticmethod
# TODO: the B008 on the line below is a likely bug; the current implementation will
# have each SwitchBackBnb instance share a single MatmulLtState instance!!!
def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()): # noqa: B008
def forward(ctx, A, B, out=None, bias=None, state: Optional[MatmulLtState] = None):
state = state or MatmulLtState()
# default to pytorch behavior if inputs are empty
ctx.is_empty = False
if prod(A.shape) == 0:
......@@ -204,7 +204,6 @@ class SwitchBackBnb(torch.autograd.Function):
# 3. Matmul
# 4. Mixed-precision decomposition matmul
# 5. Save state
formatB = state.formatB
input_shape = A.shape
if state.outlier_pool is None:
state.outlier_pool = GlobalOutlierPooler.get_instance()
......@@ -216,25 +215,21 @@ class SwitchBackBnb(torch.autograd.Function):
# 1. Quantize A
if len(A.shape) == 3:
A = A.view(-1, A.shape[-1]).contiguous()
CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(torch.float16), threshold=state.threshold)
CA, CAt, SCA, SCAt, outlier_cols = F.int8_double_quant(A.to(torch.float16), threshold=state.threshold)
if state.threshold > 0.0 and coo_tensorA is not None:
if state.threshold > 0.0 and outlier_cols is not None:
if state.has_fp16_weights:
idx = torch.unique(coo_tensorA.colidx).long()
idx = outlier_cols
CA[:, idx] = 0
CAt[:, idx] = 0
subA = A[:, idx]
state.subB = B[:, idx].t().contiguous()
state.idx = idx
else:
if state.CxB is None:
# B is in 8-bit row-major; we can transform it back to 16-bit to extract outlier dimensions
# we also need to convert it to the turing/ampere format
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
if state.SB is None:
state.SB = (state.CB.shape, "row")
else:
# print('A shape', A.shape)
if not state.has_fp16_weights and state.CxB is None:
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
if not state.has_fp16_weights and state.SB is None:
state.SB = (state.CB.shape, "row")
subA = None
# 2. Quantize B
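`int8_double_quant` now reports outliers as a tensor of column indices (`outlier_cols`) rather than a COO tensor, so the `torch.unique(coo_tensorA.colidx)` step disappears. A hedged sketch of what the call above returns when a threshold is set (shapes and values are illustrative, CUDA device assumed):
import torch
import bitsandbytes.functional as F

A = torch.randn(8, 32, dtype=torch.float16, device="cuda")
A[:, 3] = 40.0  # force column 3 above the outlier threshold

CA, CAt, SCA, SCAt, outlier_cols = F.int8_double_quant(A, threshold=6.0)
print(CA.dtype)       # torch.int8 row-wise quantized activations
print(SCA.shape)      # per-row absmax statistics
print(outlier_cols)   # column indices, e.g. tensor([3]) -- no COO tensor anymore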
......@@ -245,34 +240,26 @@ class SwitchBackBnb(torch.autograd.Function):
if is_transposed:
B = B.contiguous()
if (state.is_training and not has_grad) or state.CxB is None:
if (state.is_training and not has_grad) or state.SB is None:
state.reset_grads()
(
CB,
state.CB,
state.CBt,
state.SCB,
state.SCBt,
coo_tensorB,
) = F.double_quant(B.to(torch.float16))
state.CxB, state.SB = F.transform(CB, to_order=formatB)
_,
) = F.int8_double_quant(B.to(torch.float16))
state.SB = (state.CB.shape, "row")
else:
has_grad = False
if coo_tensorA is not None and not state.has_fp16_weights:
if outlier_cols is not None and not state.has_fp16_weights:
# extract outliers
outlier_idx = torch.unique(coo_tensorA.colidx)
state.idx = outlier_idx
# state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
# if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
# # do not use pool for 2nd FFN layer
# state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
# else:
# state.idx = outlier_idx
outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
state.idx = outlier_cols
outliers = state.CB[:, state.idx.long()].clone()
state.subB = (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().to(A.dtype)
CA[:, state.idx.long()] = 0
CAt[:, state.idx.long()] = 0
subA = A[:, state.idx.long()]
shapeB = state.SB[0]
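Because `state.CB` stays row-major, the outlier columns of the weight can be sliced directly and dequantized with `SCB / 127`, replacing the old `F.extract_outliers` call on the tiled buffer. A small sketch of that dequantization under the shapes assumed here (`CB` is int8 `[out_features, in_features]`, `SCB` holds per-output-row absmax values):
import torch

def dequant_outlier_columns(CB, SCB, outlier_cols, dtype=torch.float16):
    # Slice the int8 outlier columns and undo the row-wise absmax scaling so the
    # mixed-precision path can matmul them against the fp16 outlier activations.
    outliers = CB[:, outlier_cols].float()      # [out_features, n_outliers]
    subB = outliers * SCB.view(-1, 1) / 127.0   # dequantize row-wise
    return subB.t().contiguous().to(dtype)      # [n_outliers, out_features]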
......@@ -283,25 +270,22 @@ class SwitchBackBnb(torch.autograd.Function):
output_shape = (input_shape[0], shapeB[0])
# 3. Matmul
C32A, SA = F.transform(CA, "col32")
out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
out32 = F.int8_linear_matmul(CA, state.CB)
# we apply the fused bias here
if bias is None or bias.dtype == torch.float16:
output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
output = output.to(A.dtype)
output = F.int8_mm_dequant(out32, SCA, state.SCB, bias=bias).to(A.dtype)
else: # apply bias separately
output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
output = output.to(A.dtype).add_(bias)
output = F.int8_mm_dequant(out32, SCA, state.SCB, bias=None).to(A.dtype)
output.add_(bias)
# 4. Mixed-precision decomposition matmul
if coo_tensorA is not None and subA is not None:
if outlier_cols is not None and subA is not None:
output += torch.matmul(subA, state.subB)
# 5. Save state
ctx.state = state
ctx.formatB = formatB
ctx.grad_shape = input_shape
ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype
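Steps 3 and 4 above now run entirely in row-major: `int8_linear_matmul` consumes `CA` and `state.CB` directly, and `int8_mm_dequant` rescales the int32 result with the two absmax vectors, fusing an fp16 bias when one is given. A hedged end-to-end sketch of that path on toy tensors (no threshold, so the outlier branch is skipped; CUDA device assumed):
import torch
import bitsandbytes.functional as F

A = torch.randn(16, 64, dtype=torch.float16, device="cuda")
B = torch.randn(32, 64, dtype=torch.float16, device="cuda")   # weight as [out, in]

CA, _, SCA, _, _ = F.int8_double_quant(A)
CB, _, SCB, _, _ = F.int8_double_quant(B)

out32 = F.int8_linear_matmul(CA, CB)       # int32 accumulators, row-major throughout
out = F.int8_mm_dequant(out32, SCA, SCB)   # rescale with row/col absmax, fp16 result
print(out.shape)                           # torch.Size([16, 32])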
......@@ -321,10 +305,10 @@ class SwitchBackBnb(torch.autograd.Function):
if ctx.is_empty:
bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
CAt, subA, A = ctx.tensors
SCAt, idx = ctx.tensor_states
formatB = ctx.formatB
state = ctx.state
grad_A = grad_B = grad_bias = None
......@@ -336,7 +320,7 @@ class SwitchBackBnb(torch.autograd.Function):
if len(grad_output.shape) == 3:
grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
Cgrad, Cgradt, SCgrad, SCgradt, outlier_cols = F.int8_double_quant(grad_output.to(torch.float16))
if req_gradB:
# print('back A shape', A.shape)
......@@ -344,16 +328,7 @@ class SwitchBackBnb(torch.autograd.Function):
grad_B = torch.matmul(grad_output.t(), A)
if req_gradA:
if state.CBt is not None:
C32grad, Sgrad = F.transform(Cgrad, "col32")
if state.CxBt is None:
state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True)
# print('back B shape', state.CxBt.shape)
# print('back grad shape', C32grad.shape)
gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
elif state.CB is not None:
if state.CB is not None:
CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
else:
......
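In backward, `grad_A` is now always computed from the row-major `CB`: the int8 weight is dequantized on the fly with `SCB / 127` and multiplied with the gradient in fp16, removing the `transform`/`igemmlt` path over `CBt`. A sketch of that computation with the shapes used above:
import torch

def grad_input_from_int8_weight(grad_output, CB, SCB, grad_shape, dtype):
    # grad_output: [N, out_features]; CB: int8 [out_features, in_features];
    # SCB: per-row absmax of the weight. Dequantize, then a plain matmul in `dtype`.
    W = CB.to(dtype, copy=True).mul_(SCB.unsqueeze(1).mul(1.0 / 127.0))
    return torch.matmul(grad_output, W).view(grad_shape).to(dtype)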
#pragma once
// TODO: Let's make some of these constexpr and put in a namespace.
#define BNB_CC_MAXWELL 500
#define BNB_CC_MAXWELL2 520
#define BNB_CC_MAXWELL2_X1 530
#define BNB_CC_PASCAL 600
#define BNB_CC_PASCAL_X2 620
#define BNB_CC_VOLTA 700
#define BNB_CC_VOLTA_XAVIER 720
#define BNB_CC_TURING 750
#define BNB_CC_AMPERE 800
#define BNB_CC_AMPERE2 860
#define BNB_CC_AMPERE2_ORIN 870
#define BNB_CC_ADA 890
#define BNB_CC_HOPPER 900
#define BNB_CC_BLACKWELL 1000
#define BNB_FP16_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_MAXWELL2_X1)
#define BNB_FP16_MMA_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_VOLTA)
#define BNB_INT8_MMA_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_VOLTA_XAVIER)
#define BNB_BF16_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_AMPERE)
#define BNB_FP8_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_ADA)
#define BNB_WARP_SIZE 32
// The maximum number of resident threads per SM varies by arch.
// For A100/H100 and all prior to Turing, it is 2048, which allows
// for 2 full blocks of 1024 threads per SM.
// Reference: https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications-technical-specifications-per-compute-capability
#if __CUDA_ARCH__ == 750
#define BNB_MAX_THREADS_PER_SM 1024
#elif __CUDA_ARCH__ >= 860 && __CUDA_ARCH__ <= 890
#define BNB_MAX_THREADS_PER_SM 1536
#else
#define BNB_MAX_THREADS_PER_SM 2048
#endif
// The maximum number of resident warps per SM is simply the thread limit divided by the warp size.
#define BNB_MAX_WARPS_PER_SM ((BNB_MAX_THREADS_PER_SM) / (BNB_WARP_SIZE))
// Maximum resident blocks per SM may vary.
#if __CUDA_ARCH__ == 860 || __CUDA_ARCH__ == 870
#define BNB_MAX_BLOCKS_PER_SM 16
#else
#define BNB_MAX_BLOCKS_PER_SM ((BNB_MAX_WARPS_PER_SM) / 2)
#endif
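These limits exist mainly so kernels can pass a realistic minimum-blocks value to __launch_bounds__, letting ptxas budget registers for cc 7.5 (1024 threads/SM) and cc 8.6-8.9 (1536 threads/SM) instead of assuming 2048 everywhere. A usage sketch with a made-up kernel, not code from this header:
#define EXAMPLE_THREADS 256

// Illustrative kernel: __launch_bounds__ promises at most EXAMPLE_THREADS threads
// per block and requests BNB_MAX_THREADS_PER_SM / EXAMPLE_THREADS resident blocks,
// so register allocation adapts to the per-arch thread limits defined above.
__global__ void __launch_bounds__(EXAMPLE_THREADS, BNB_MAX_THREADS_PER_SM / EXAMPLE_THREADS)
kExampleScale(const float* __restrict__ in, float* __restrict__ out, float scale, int n) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
        out[i] = in[i] * scale;
}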
......@@ -112,12 +112,12 @@ __global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index
template <typename T, int SPMM_ITEMS, int BITS> __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>__global__ void kdequant_mm_int32_fp16(
template <int ITEMS_PER_THREAD, int THREADS>__global__ void kdequant_mm_int32_fp16(
int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats,
half *out, float* newRowStats, float* newcolStats, half * __restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n);
half *out, half * __restrict__ const bias, const int numRows, const int numCols, const int n);
template<typename T, int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int SPARSE_DECOMP> __global__ void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols);
template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int SPARSE_DECOMP> __global__ void kDoubleRowColQuant(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols);
template<typename T, int THREADS, int SPARSE_DECOMP> __global__ void kgetRowStats(T * __restrict__ A, float *rowStats, float threshold, int rows, int cols);
template<typename T, int THREADS, int SPARSE_DECOMP> __global__ void kInt8VectorQuant(T * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols);
template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int TRANSPOSE, int FORMAT> __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
......
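The new kdequant_mm_int32_fp16 declaration drops the SUBTILE_ROWS parameter and the col32 statistics buffers because the int32 matmul output is now plain row-major. A naive, one-thread-per-element sketch of the contract (the real kernel tiles and vectorizes; the scale layout follows the absmax convention used by the quantizers):
#include <cuda_fp16.h>

// Each int32 accumulator is rescaled by rowStat * colStat / 127^2 (the absmax
// factors applied at quantization time) and an optional fp16 bias is fused in.
__global__ void kDequantMMInt32Fp16Naive(
    const int* __restrict__ A, const float* __restrict__ rowStats,
    const float* __restrict__ colStats, half* out, const half* __restrict__ bias,
    int numRows, int numCols) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= numRows * numCols) return;
    const int row = idx / numCols;
    const int col = idx % numCols;
    float v = static_cast<float>(A[idx]) * rowStats[row] * colStats[col] * (1.0f / (127.0f * 127.0f));
    if (bias != nullptr) v += __half2float(bias[col]);
    out[idx] = __float2half(v);
}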