Unverified Commit 81e6345d authored by Matthew Douglas, committed by GitHub

LLM.int8() Refactoring: Part 1 (#1401)



* Start of int8 refactor: remove col32/col_ampere/col_turing transforms in new igemmlt implementation

* Fix unintended change

* New naive mm_dequant kernel for row-major; cleanup

* fix

* int8 refactor: initial sparse decomp, cleanup

* Int8 refactoring: remove separate NO_CUBLASLT build; more cleanup

* int8: inference optimizations, some cleanup

* int8: more tests passing, cleanup

* int8 - more cleanup, most tests passing

* int8: specify CUDA stream for int8 ops

* perf: reduce overhead from getting cudaStream ptr

* Mark some functions for deprecation.

* int8 sparse decomp: small perf improvement

* update setup.py

* Update bitsandbytes/autograd/_functions.py
Co-authored-by: Aarni Koskela <akx@iki.fi>

* Update bitsandbytes/functional.py
Co-authored-by: Aarni Koskela <akx@iki.fi>

* Update bitsandbytes/functional.py
Co-authored-by: Aarni Koskela <akx@iki.fi>

* Update bitsandbytes/research/autograd/_functions.py
Co-authored-by: Aarni Koskela <akx@iki.fi>

* int8 - perf improvement for sparse decomposition inference; deprecate get_tensor_stream() in favor of new private fn

* int8 cleanup

* Ignore ruff rule ISC001 (incompatible with formatter)

* add comment

* int8 more cleanup

* Update bitsandbytes/functional.py
Co-authored-by: Aarni Koskela <akx@iki.fi>

* int8: rename / deprecate old fn signatures

* Update bitsandbytes/functional.py
Co-authored-by: Aarni Koskela <akx@iki.fi>

* type annotation

* format update

* Update bitsandbytes/research/autograd/_functions.py
Co-authored-by: Aarni Koskela <akx@iki.fi>

* cleanup

* Add comment to explain division optimization

* more cleanup

* Update bitsandbytes/functional.py
Co-authored-by: Aarni Koskela <akx@iki.fi>

* Update bitsandbytes/functional.py
Co-authored-by: Aarni Koskela <akx@iki.fi>

* Update bitsandbytes/functional.py
Co-authored-by: Aarni Koskela <akx@iki.fi>

* cleanup

* Type annotations, cleanup

* remove unused kernels; improved type annotations

* small perf optimization for single-GPU systems

* small perf optimization for single-GPU systems

* update docstrings

* Improve docs and tests

* Update docstring

* Update test

* add benchmarking script

* test cleanup: add deprecated marker, move benchmarks out

* Add int8 dequant function; misc improvements

* int8 matmul fallback for inner dims not divisible by 4

* improve register usage of kInt8VectorQuant - especially for A100/H100

* disable fail-fast for package build

* maxwell compat

* ptxas verbose

* docs update

* doc update

* backward fix

* Bugfix sparse decomp

* Int8 fix for PEFT OLoRA init

* Fix test for deprecated spmm_coo

* test improvement

* doc update

* typo

* doc cleanup

* docs

* add inference benchmark script

* Add benchmarks, doc update

---------
Co-authored-by: Aarni Koskela <akx@iki.fi>
parent 7dca7004
...@@ -8,21 +8,21 @@ build_capability="50;52;60;61;70;75;80;86;89;90" ...@@ -8,21 +8,21 @@ build_capability="50;52;60;61;70;75;80;86;89;90"
[[ "${cuda_version}" == 11.7.* ]] && build_capability=${build_capability%??????} [[ "${cuda_version}" == 11.7.* ]] && build_capability=${build_capability%??????}
[[ "${cuda_version}" == 11.8.* ]] && build_capability=${build_capability%???} [[ "${cuda_version}" == 11.8.* ]] && build_capability=${build_capability%???}
[[ "${build_os}" = windows-* ]] && python3 -m pip install ninja [[ "${build_os}" = windows-* ]] && python3 -m pip install ninja
for NO_CUBLASLT in ON OFF; do
if [ "${build_os:0:6}" == ubuntu ]; then if [ "${build_os:0:6}" == ubuntu ]; then
image=nvidia/cuda:${cuda_version}-devel-ubuntu22.04 image=nvidia/cuda:${cuda_version}-devel-ubuntu22.04
echo "Using image $image" echo "Using image $image"
docker run --platform "linux/$build_arch" -i -w /src -v "$PWD:/src" "$image" sh -c \ docker run --platform "linux/$build_arch" -i -w /src -v "$PWD:/src" "$image" sh -c \
"apt-get update \ "apt-get update \
&& DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends cmake \ && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends cmake \
&& cmake -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY=\"${build_capability}\" -DNO_CUBLASLT=${NO_CUBLASLT} . \ && cmake -DPTXAS_VERBOSE=1 -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY=\"${build_capability}\" . \
&& cmake --build ." && cmake --build ."
else else
pip install cmake==3.28.3 pip install cmake==3.28.3
cmake -G Ninja -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY="${build_capability}" -DNO_CUBLASLT=${NO_CUBLASLT} -DCMAKE_BUILD_TYPE=Release -S . cmake -G Ninja -DCOMPUTE_BACKEND=cuda -DCOMPUTE_CAPABILITY="${build_capability}" -DCMAKE_BUILD_TYPE=Release -S .
cmake --build . --config Release cmake --build . --config Release
fi fi
done
output_dir="output/${build_os}/${build_arch}" output_dir="output/${build_os}/${build_arch}"
mkdir -p "${output_dir}" mkdir -p "${output_dir}"
......
...@@ -60,6 +60,7 @@ jobs: ...@@ -60,6 +60,7 @@ jobs:
## ##
build-shared-libs-cuda: build-shared-libs-cuda:
strategy: strategy:
fail-fast: false
matrix: matrix:
os: [ubuntu-latest, windows-latest] os: [ubuntu-latest, windows-latest]
arch: [x86_64, aarch64] arch: [x86_64, aarch64]
......
...@@ -22,9 +22,11 @@ CMakeFiles/ ...@@ -22,9 +22,11 @@ CMakeFiles/
bitsandbytes.dir/ bitsandbytes.dir/
Debug/ Debug/
Release/ Release/
cmake-build-*/
# IDE local files # IDE local files
.vs/ .vs/
.idea/
# Distribution / packaging # Distribution / packaging
.Python .Python
......
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
# For MSVC: `cmake -B build . && cmake --build build --config Release` # For MSVC: `cmake -B build . && cmake --build build --config Release`
# You can also use the following options and variables # You can also use the following options and variables
# - COMPUTE_BACKEND: Set to `cpu`, `cuda`, or `mps` to select the backend # - COMPUTE_BACKEND: Set to `cpu`, `cuda`, or `mps` to select the backend
# - NO_CUBLASLT: Default OFF, will skip building/linking CUBLASLT support
# - CUDA_VERSION: The expected CUDA version, for sanity checking. The actual version # - CUDA_VERSION: The expected CUDA version, for sanity checking. The actual version
# is whatever CMake finds on your path. # is whatever CMake finds on your path.
# - COMPUTE_CAPABILITY: Which GPU Arch/Compute codes to provide to NVCC. # - COMPUTE_CAPABILITY: Which GPU Arch/Compute codes to provide to NVCC.
...@@ -47,10 +46,8 @@ if(${COMPUTE_BACKEND} STREQUAL "cuda") ...@@ -47,10 +46,8 @@ if(${COMPUTE_BACKEND} STREQUAL "cuda")
if(APPLE) if(APPLE)
message(FATAL_ERROR "CUDA is not supported on macOS" ) message(FATAL_ERROR "CUDA is not supported on macOS" )
endif() endif()
option(NO_CUBLASLT "Disable CUBLAS" OFF)
set(BUILD_CUDA ON) set(BUILD_CUDA ON)
set(BUILD_MPS OFF) set(BUILD_MPS OFF)
message(STATUS "NO_CUBLASLT := ${NO_CUBLASLT}")
elseif(${COMPUTE_BACKEND} STREQUAL "mps") elseif(${COMPUTE_BACKEND} STREQUAL "mps")
if(NOT APPLE) if(NOT APPLE)
message(FATAL_ERROR "MPS is only supported on macOS" ) message(FATAL_ERROR "MPS is only supported on macOS" )
...@@ -166,9 +163,6 @@ if(BUILD_CUDA) ...@@ -166,9 +163,6 @@ if(BUILD_CUDA)
list(APPEND SRC_FILES ${CUDA_FILES}) list(APPEND SRC_FILES ${CUDA_FILES})
string(APPEND BNB_OUTPUT_NAME "_cuda${CUDA_VERSION_SHORT}") string(APPEND BNB_OUTPUT_NAME "_cuda${CUDA_VERSION_SHORT}")
if(NO_CUBLASLT)
string(APPEND BNB_OUTPUT_NAME "_nocublaslt")
endif()
add_compile_definitions(BUILD_CUDA) add_compile_definitions(BUILD_CUDA)
elseif(BUILD_MPS) elseif(BUILD_MPS)
if(NOT APPLE) if(NOT APPLE)
...@@ -212,13 +206,7 @@ target_include_directories(bitsandbytes PUBLIC csrc include) ...@@ -212,13 +206,7 @@ target_include_directories(bitsandbytes PUBLIC csrc include)
if(BUILD_CUDA) if(BUILD_CUDA)
target_include_directories(bitsandbytes PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) target_include_directories(bitsandbytes PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
target_link_libraries(bitsandbytes PUBLIC CUDA::cudart CUDA::cublas CUDA::cusparse) target_link_libraries(bitsandbytes PUBLIC CUDA::cudart CUDA::cublas CUDA::cublasLt CUDA::cusparse)
if(NO_CUBLASLT)
target_compile_definitions(bitsandbytes PUBLIC NO_CUBLASLT)
else()
target_link_libraries(bitsandbytes PUBLIC CUDA::cublasLt)
endif()
set_target_properties(bitsandbytes set_target_properties(bitsandbytes
PROPERTIES PROPERTIES
CUDA_SEPARABLE_COMPILATION ON CUDA_SEPARABLE_COMPILATION ON
......
# Benchmarking
## Inference
End-to-end inference benchmarking can be performed using the 🤗 [`optimum-benchmark`](https://github.com/huggingface/optimum-benchmark) library.
See the example script in
[inference_benchmark.py](inference_benchmark.py).
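The script sweeps batch sizes and quantization configurations through `optimum-benchmark`. Below is a condensed sketch of what it does for a single run; the model id, sequence length, and output path are illustrative, and the script exposes them as command-line options.

```python
# Condensed sketch of one benchmark run (int8, batch size 1), mirroring what
# inference_benchmark.py does per configuration. The model id and output path
# below are illustrative placeholders.
from optimum_benchmark import Benchmark, BenchmarkConfig, InferenceConfig, ProcessConfig, PyTorchConfig
from optimum_benchmark.logging_utils import setup_logging

setup_logging(level="INFO")

benchmark_config = BenchmarkConfig(
    name="benchmark-int8-bsz1",
    launcher=ProcessConfig(device_isolation=True, start_method="spawn"),
    scenario=InferenceConfig(
        latency=True,
        memory=True,
        input_shapes={"batch_size": 1, "sequence_length": 64},
    ),
    backend=PyTorchConfig(
        device="cuda",
        device_ids="0",
        device_map="auto",
        no_weights=False,
        model="meta-llama/Llama-3.1-8B",  # illustrative
        torch_dtype="float16",
        quantization_scheme="bnb",
        quantization_config={"load_in_8bit": True, "llm_int8_threshold": 0.0},
    ),
)

report = Benchmark.launch(benchmark_config)
report.log()
report.save_json("benchmark_int8_bsz1.json")
```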
### Results (as of v0.45.0)
Our overall benchmarking results compared with v0.44.1 provide the following insights:
#### LLM.int8()
* **Turing/Ampere/Ada**: The observed per-token throughput is improved by 60-85%, while latency is decreased by 40-45%.
* **H100**: In our benchmarks of Llama 3.1 70B, the new LLM.int8() consistently outperforms NF4 at batch sizes >= 8.
#### NF4/FP4
* **Turing/Ampere/Ada**: With a batch size of 1, per-token throughput is _improved by 10-25%_ and per-token latency is _decreased by 10-20%_.
* **H100**: Across all batch sizes, per-token throughput is _improved by up to 28%_ and per-token latency is _decreased by up to 22%_.
Summaries of the benchmarking results are provided below.
#### NVIDIA T4 16GB
<details>
<summary>Qwen 2.5 3B Instruct</summary>
| | Batch Size | Mean Latency (s) <sub>v0.45.0.dev</sub> | Throughput (tokens/s) <sub>v0.45.0.dev</sub> | Mean Latency (s) <sub>v0.44.1</sub> | Latency Improvement | Throughput (tokens/s) <sub>v0.44.1</sub> | Throughput Improvement |
|----------------------|------------|------------------------------|------------------------|--------------------------|---------------------|--------------------|------------------------|
| FP16 | 1 | 0.0390 | 25.66 | 0.0390 | 1.00 | 25.66 | 1.000x |
| NF4 | 1 | 0.0608 | 16.45 | 0.0710 | 1.14 | 14.08 | 1.168x |
| NF4+DQ | 1 | 0.0736 | 13.58 | 0.0905 | 1.19 | 11.05 | 1.229x |
| INT8 | 1 | 0.0902 | 11.08 | 0.1609 | 1.44 | 6.21 | 1.784x |
| INT8+Decomp | 1 | 0.1672 | 5.98 | 0.2994 | 1.44 | 3.34 | 1.790x |
| FP16 | 8 | 0.0422 | 189.56 | 0.0422 | 1.00 | 189.56 | 1.000x |
| NF4 | 8 | 0.0960 | 83.37 | 0.1010 | 1.05 | 79.17 | 1.053x |
| NF4+DQ | 8 | 0.1042 | 76.80 | 0.1156 | 1.10 | 69.18 | 1.110x |
| INT8 | 8 | 0.0919 | 87.01 | 0.1640 | 1.44 | 48.78 | 1.784x |
| INT8+Decomp | 8 | 0.1812 | 44.15 | 0.3296 | 1.45 | 24.28 | 1.818x |
| FP16 | 32 | 0.0601 | 532.30 | 0.0601 | 1.00 | 532.30 | 1.000x |
| NF4 | 32 | 0.1150 | 278.32 | 0.1182 | 1.03 | 270.71 | 1.028x |
| NF4+DQ | 32 | 0.1215 | 263.36 | 0.1297 | 1.06 | 246.76 | 1.067x |
| INT8 | 32 | 0.0943 | 339.21 | 0.1640 | 1.42 | 195.14 | 1.738x |
| INT8+Decomp | 32 | 0.1912 | 167.37 | 0.3413 | 1.44 | 93.75 | 1.785x |
</details>
#### NVIDIA RTX 4090 24GB
<details>
<summary>Llama 3.1 8B</summary>
| | Batch Size | Mean Latency (s) <sub>v0.45.0.dev</sub> | Throughput (tokens/s) <sub>v0.45.0.dev</sub> | Mean Latency (s) <sub>v0.44.1</sub> | Latency Improvement | Throughput (tokens/s) <sub>v0.44.1</sub> | Throughput Improvement |
|----------------------|------------|------------------------------|------------------------|--------------------------|---------------------|--------------------|------------------------|
| BF16 | 1 | 0.0211 | 47.46 | 0.0211 | 1.00 | 47.46 | 1.000x |
| NF4 | 1 | 0.0148 | 67.71 | 0.0164 | 1.10 | 61.08 | 1.109x |
| NF4+DQ | 1 | 0.0175 | 57.08 | 0.0208 | 1.16 | 48.15 | 1.185x |
| INT8 | 1 | 0.0220 | 45.39 | 0.0395 | 1.44 | 25.32 | 1.793x |
| INT8+Decomp | 1 | 0.0449 | 22.26 | 0.0743 | 1.40 | 13.45 | 1.655x |
| BF16 | 8 | 0.0239 | 334.64 | 0.0239 | 1.00 | 334.64 | 1.000x |
| NF4 | 8 | 0.0425 | 188.08 | 0.0422 | 0.99 | 189.50 | 0.993x |
| NF4+DQ | 8 | 0.0443 | 180.68 | 0.0437 | 0.99 | 183.02 | 0.987x |
| INT8 | 8 | 0.0221 | 361.61 | 0.0389 | 1.43 | 205.82 | 1.757x |
| INT8+Decomp | 8 | 0.0478 | 164.55 | 0.0777 | 1.38 | 103.01 | 1.597x |
| BF16 | 32 | 0.0304 | 1054.35 | 0.0304 | 1.00 | 1054.35 | 1.000x |
| NF4 | 32 | 0.0461 | 694.60 | 0.0466 | 1.01 | 686.90 | 1.011x |
| NF4+DQ | 32 | 0.0471 | 678.73 | 0.0480 | 1.02 | 666.33 | 1.019x |
| INT8 | 32 | 0.0230 | 1390.54 | 0.0390 | 1.41 | 819.99 | 1.696x |
| INT8+Decomp | 32 | 0.0512 | 624.94 | 0.0835 | 1.39 | 383.18 | 1.631x |
</details>
<details>
<summary>Qwen 2.5 14B Instruct</summary>
| | Batch Size | Mean Latency (s) <sub>v0.45.0.dev</sub> | Throughput (tokens/s) <sub>v0.45.0.dev</sub> | Mean Latency (s) <sub>v0.44.1</sub> | Latency Improvement | Throughput (tokens/s) <sub>v0.44.1</sub> | Throughput Improvement |
|----------------------|------------|------------------------------|------------------------|--------------------------|---------------------|--------------------|------------------------|
| NF4 | 1 | 0.0214 | 46.74 | 0.0256 | 1.16 | 39.10 | 1.195x |
| NF4+DQ | 1 | 0.0256 | 39.03 | 0.0318 | 1.19 | 31.46 | 1.241x |
| INT8 | 1 | 0.0326 | 30.68 | 0.0596 | 1.45 | 16.79 | 1.827x |
| INT8+Decomp | 1 | 0.0648 | 15.44 | 0.1105 | 1.41 | 9.05 | 1.706x |
| NF4 | 8 | 0.0696 | 114.95 | 0.0697 | 1.00 | 114.78 | 1.001x |
| NF4+DQ | 8 | 0.0719 | 111.29 | 0.0723 | 1.01 | 110.70 | 1.005x |
| INT8 | 8 | 0.0325 | 246.22 | 0.0596 | 1.45 | 134.21 | 1.835x |
| INT8+Decomp | 8 | 0.0721 | 110.95 | 0.1201 | 1.40 | 66.62 | 1.665x |
</details>
#### NVIDIA H100 80GB SXM
<details>
<summary>Llama 3.1 8B</summary>
| | Batch Size | Mean Latency (s) <sub>v0.45.0.dev</sub> | Throughput (tokens/s) <sub>v0.45.0.dev</sub> | Mean Latency (s) <sub>v0.44.1</sub> | Latency Improvement | Throughput (tokens/s) <sub>v0.44.1</sub> | Throughput Improvement |
|----------------------|------------|------------------------------|------------------------|--------------------------|---------------------|--------------------|------------------------|
| BF16 | 1 | 0.0244 | 40.99 | 0.0244 | 1.00 | 40.99 | 1.000x |
| NF4 | 1 | 0.0331 | 30.14 | 0.0391 | 1.15 | 25.60 | 1.177x |
| NF4+DQ | 1 | 0.0411 | 24.34 | 0.0528 | 1.22 | 18.92 | 1.286x |
| INT8 | 1 | 0.0522 | 19.17 | N/A | N/A | N/A | N/A |
| INT8+Decomp | 1 | 0.0817 | 12.24 | N/A | N/A | N/A | N/A |
| BF16 | 8 | 0.0255 | 313.90 | 0.0255 | 1.00 | 313.90 | 1.000x |
| NF4 | 8 | 0.0476 | 168.05 | 0.0551 | 1.14 | 145.13 | 1.158x |
| NF4+DQ | 8 | 0.0566 | 141.27 | 0.0663 | 1.15 | 120.67 | 1.171x |
| INT8 | 8 | 0.0515 | 155.44 | N/A | N/A | N/A | N/A |
| INT8+Decomp | 8 | 0.0853 | 93.79 | N/A | N/A | N/A | N/A |
| BF16 | 32 | 0.0261 | 1227.96 | 0.0261 | 1.00 | 1227.96 | 1.000x |
| NF4 | 32 | 0.0486 | 658.65 | 0.0546 | 1.11 | 585.91 | 1.124x |
| NF4+DQ | 32 | 0.0577 | 555.06 | 0.0665 | 1.13 | 481.04 | 1.154x |
| INT8 | 32 | 0.0545 | 586.26 | N/A | N/A | N/A | N/A |
| INT8+Decomp | 32 | 0.0864 | 370.51 | N/A | N/A | N/A | N/A |
</details>
<details>
<summary>Qwen 2.5 32B Instruct</summary>
| | Batch Size | Mean Latency (s) <sub>v0.45.0.dev</sub> | Throughput (tokens/s) <sub>v0.45.0.dev</sub> |
|-------------|------------|-----------------------------------------|-----------------------------------|
| BF16 | 1 | 0.0508 | 19.67 |
| NF4 | 1 | 0.0707 | 14.14 |
| NF4+DQ | 1 | 0.0860 | 11.63 |
| INT8 | 1 | 0.1031 | 9.70 |
| INT8+Decomp | 1 | 0.1820 | 5.49 |
| BF16 | 8 | 0.0525 | 152.50 |
| NF4 | 8 | 0.1154 | 69.35 |
| NF4+DQ | 8 | 0.1209 | 66.19 |
| INT8 | 8 | 0.1078 | 74.24 |
| INT8+Decomp | 8 | 0.1958 | 40.87 |
| BF16 | 32 | 0.0547 | 584.54 |
| NF4 | 32 | 0.1246 | 256.84 |
| NF4+DQ | 32 | 0.1298 | 246.47 |
| INT8 | 32 | 0.1056 | 302.96 |
| INT8+Decomp | 32 | 0.2027 | 157.83 |
</details>
<details>
<summary>Llama 3.1 70B</summary>
| | Batch Size | Mean Latency (s) <sub>v0.45.0.dev</sub> | Throughput (tokens/s) <sub>v0.45.0.dev</sub> |
|-------------|------------|-----------------------------------------|-----------------------------------|
| NF4 | 1 | 0.0833 | 12.00 |
| NF4+DQ | 1 | 0.1052 | 9.50 |
| INT8 | 1 | 0.1294 | 7.73 |
| INT8+Decomp | 1 | 0.1985 | 5.04 |
| NF4 | 8 | 0.2348 | 34.07 |
| NF4+DQ | 8 | 0.2423 | 33.01 |
| INT8 | 8 | 0.1313 | 60.94 |
| INT8+Decomp | 8 | 0.2052 | 38.99 |
| NF4 | 32 | 0.2491 | 128.46 |
| NF4+DQ | 32 | 0.2580 | 124.04 |
| INT8 | 32 | 0.1314 | 243.45 |
| INT8+Decomp | 32 | 0.2189 | 146.19 |
</details>
#### Software Configuration
We focus on the default PyTorch CUDA backend in 🤗 [`optimum-benchmark`](https://github.com/huggingface/optimum-benchmark). We used commit [`6e6b1036`](https://github.com/huggingface/optimum-benchmark/commit/6e6b10363f3ac65926881f2c6a6113b6cefc06cd).
For all hardware configurations, we used the following dependencies:
* `transformers==4.46.3`
* `accelerate==1.1.1`
* `tokenizers==0.20.3`
* `torch==2.5.1`
* `bitsandbytes==0.44.1`
* `bitsandbytes==0.45.0.dev`
For the RTX 4090 setting we used the CUDA 12.4 build of PyTorch; for the other settings, we used the CUDA 12.1 build.
"""
Inference benchmarking tool.
Requirements:
transformers
accelerate
bitsandbytes
optimum-benchmark
Usage: python inference_benchmark.py model_id
options:
-h, --help show this help message and exit
--configs {bf16,fp16,nf4,nf4-dq,int8,int8-decomp} [{bf16,fp16,nf4,nf4-dq,int8,int8-decomp} ...]
--bf16
--fp16
--nf4
--nf4-dq
--int8
--int8-decomp
--batches BATCHES [BATCHES ...]
--input-length INPUT_LENGTH
--out-dir OUT_DIR
"""
import argparse
from pathlib import Path
from optimum_benchmark import Benchmark, BenchmarkConfig, InferenceConfig, ProcessConfig, PyTorchConfig
from optimum_benchmark.logging_utils import setup_logging
import torch
BFLOAT16_SUPPORT = torch.cuda.get_device_capability()[0] >= 8
WEIGHTS_CONFIGS = {
"fp16": {"torch_dtype": "float16", "quantization_scheme": None, "quantization_config": {}},
"bf16": {"torch_dtype": "bfloat16", "quantization_scheme": None, "quantization_config": {}},
"nf4": {
"torch_dtype": "bfloat16" if BFLOAT16_SUPPORT else "float16",
"quantization_scheme": "bnb",
"quantization_config": {
"load_in_4bit": True,
"bnb_4bit_quant_type": "nf4",
"bnb_4bit_use_double_quant": False,
"bnb_4bit_compute_dtype": torch.bfloat16 if BFLOAT16_SUPPORT else "float16",
},
},
"nf4-dq": {
"torch_dtype": "bfloat16" if BFLOAT16_SUPPORT else "float16",
"quantization_scheme": "bnb",
"quantization_config": {
"load_in_4bit": True,
"bnb_4bit_quant_type": "nf4",
"bnb_4bit_use_double_quant": True,
"bnb_4bit_compute_dtype": torch.bfloat16 if BFLOAT16_SUPPORT else "float16",
},
},
"int8-decomp": {
"torch_dtype": "float16",
"quantization_scheme": "bnb",
"quantization_config": {
"load_in_8bit": True,
"llm_int8_threshold": 6.0,
},
},
"int8": {
"torch_dtype": "float16",
"quantization_scheme": "bnb",
"quantization_config": {
"load_in_8bit": True,
"llm_int8_threshold": 0.0,
},
},
}
if __name__ == "__main__":
setup_logging(level="INFO")
parser = argparse.ArgumentParser(description="bitsandbytes inference benchmark tool")
parser.add_argument("model_id", type=str, help="The model checkpoint to use.")
parser.add_argument(
"--configs",
nargs="+",
choices=["bf16", "fp16", "nf4", "nf4-dq", "int8", "int8-decomp"],
default=["nf4", "int8", "int8-decomp"],
)
parser.add_argument("--bf16", dest="configs", action="append_const", const="bf16")
parser.add_argument("--fp16", dest="configs", action="append_const", const="fp16")
parser.add_argument("--nf4", dest="configs", action="append_const", const="nf4")
parser.add_argument("--nf4-dq", dest="configs", action="append_const", const="nf4-dq")
parser.add_argument("--int8", dest="configs", action="append_const", const="int8")
parser.add_argument("--int8-decomp", dest="configs", action="append_const", const="int8-decomp")
parser.add_argument("--batches", nargs="+", type=int, default=[1, 8, 16, 32])
parser.add_argument("--input-length", type=int, default=64)
parser.add_argument("--out-dir", type=str, default="reports")
args = parser.parse_args()
out_dir = Path(args.out_dir)
out_dir.mkdir(parents=True, exist_ok=True)
for batch_size in args.batches:
print(f"Benchmarking batch size: {batch_size}")
for config in args.configs:
launcher_config = ProcessConfig(device_isolation=True, start_method="spawn")
scenario_config = InferenceConfig(
latency=True,
memory=True,
input_shapes={"batch_size": batch_size, "sequence_length": args.input_length},
)
backend_config = PyTorchConfig(
device="cuda",
device_ids="0",
device_map="auto",
no_weights=False,
model=args.model_id,
**WEIGHTS_CONFIGS[config],
)
benchmark_config = BenchmarkConfig(
name=f"benchmark-{config}-bsz{batch_size}",
scenario=scenario_config,
launcher=launcher_config,
backend=backend_config,
)
out_path = out_dir / f"benchmark_{config}_bsz{batch_size}.json"
benchmark_report = Benchmark.launch(benchmark_config)
benchmark_report.log()
benchmark_report.save_json(out_path)
"""
Basic benchmark for text generation.
Usage: python benchmarking/int8/int8_benchmark.py
"""
import time
import torch
from torch.profiler import ProfilerActivity, profile
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
MAX_NEW_TOKENS = 128
model_name = "meta-llama/Llama-3.1-8B"
text = "Below is a question. I need an answer.\n\nExplain machine learning: "
tokenizer = AutoTokenizer.from_pretrained(model_name)
input_ids = tokenizer([text] * 8, return_tensors="pt").input_ids.to(0)
model = AutoModelForCausalLM.from_pretrained(
model_name,
device_map="auto",
quantization_config=BitsAndBytesConfig(
load_in_8bit=True,
llm_int8_threshold=6.0,
),
attn_implementation="sdpa",
torch_dtype=torch.float16,
)
print(model)
# warmup
print("Warmup...")
for i in range(3):
generated_ids = model.generate(input_ids, max_new_tokens=MAX_NEW_TOKENS)
print("Profiler starting...")
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
with_modules=True,
with_stack=True,
) as prof:
model.generate(input_ids, max_new_tokens=1)
print(
prof.key_averages().table(
sort_by="cpu_time_total",
max_name_column_width=50,
top_level_events_only=True,
row_limit=50,
)
)
torch.cuda.synchronize()
print("Generating...")
num = 0
time_1 = time.time()
for i in range(5):
generated_ids = model.generate(input_ids, max_new_tokens=MAX_NEW_TOKENS)
num += len(generated_ids[0])
print("=" * 40)
print(f"Example:\n{tokenizer.decode(generated_ids[0])}")
print("=" * 40)
print(f"Speed: {num/(time.time() - time_1)}token/s")
"""
Extracted from tests/test_functional.py
Note: This feature is currently unused! It is kept here for archival purposes.
Usage: pytest benchmarking/int8/row_scale_benchmark.py
"""
import time
import pytest
import torch
from bitsandbytes import functional as F
k = 20
torch.set_printoptions(precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000)
@pytest.mark.parametrize(
("dim1", "dim4", "inner"),
[
pytest.param(1024, 12288 * 4, 12288, id="1024, 12288*4, 12288"),
pytest.param(2048, 4096 * 4, 4096, id="2048, 4096*4, 4096"),
],
)
@pytest.mark.skip("Row scale has some bugs for ampere")
@pytest.mark.benchmark
def test_row_scale_bench(dim1, dim4, inner):
formatB = F.get_special_format_str()
err1, err2, err3 = [], [], []
relerr1, relerr2 = [], []
scale = 1
A = torch.randn(dim1, inner, device="cuda").half()
B = torch.randn(dim4, inner, device="cuda").half()
torch.nn.init.xavier_uniform_(B)
# warmup
for i in range(k):
C1 = torch.matmul(A, B.t())
torch.cuda.synchronize()
t0 = time.time()
for i in range(k):
C1 = torch.matmul(A, B.t())
torch.cuda.synchronize()
print("16", time.time() - t0)
C1a, C1b, stats1a, stats1b, coo_tensor = F.int8_double_quant(A)
CB, absmaxB = F.vectorwise_quant(B, quant_type="linear")
A2, SA = F.nvidia_transform(C1a, "col32")
B2, SB = F.nvidia_transform(CB, formatB)
A1, maxA = F.vectorwise_quant(A, dim=1)
c = 10.0 * inner * scale
row_scale = maxA / c
torch.cuda.synchronize()
t0 = time.time()
for i in range(k):
outC32 = F.int8_linear_matmul(A2, B2, dtype=torch.int8, row_scale=row_scale)
torch.cuda.synchronize()
print("row-wise", time.time() - t0)
C2a, C2b, stats2a, stats2b, coo_tensor = F.int8_double_quant(B)
B2, SB = F.nvidia_transform(C2a, formatB)
torch.cuda.synchronize()
t0 = time.time()
for i in range(k):
outC32 = F.int8_linear_matmul(A2, B2)
torch.cuda.synchronize()
print("vector-wise", time.time() - t0)
"""
Extracted from tests/test_functional.py
Usage: pytest benchmarking/int8/training_benchmark.py
"""
import time
import pytest
import torch
from bitsandbytes import functional as F
k = 20
torch.set_printoptions(precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000)
@pytest.mark.parametrize(
("batch", "seq", "model", "hidden"),
[
pytest.param(2, 512, 4 * 1024, 3 * 4 * 1024, id="batch=2, seq=512, model=4k, hidden=12k"),
pytest.param(2, 512, 5120, 3 * 5120, id="batch=2, seq=512, model=5k, hidden=15k"),
pytest.param(2, 512, 12 * 1024, 4 * 12 * 1024, id="batch=2, seq=512, model=12k, hidden=48k"),
],
)
@pytest.mark.benchmark
def test_bench_8bit_training(batch, seq, model, hidden):
formatB = F.get_special_format_str()
A = torch.randn(batch, seq, model, device="cuda").half()
grad = torch.randn(batch, seq, model, device="cuda").half()
w1 = torch.randint(-128, 127, size=(hidden, model), device="cuda").half()
w2 = torch.randint(-128, 127, size=(model, hidden), device="cuda").half()
print("")
# torch.cuda.synchronize()
## warmup
# for i in range(100):
# torch.matmul(A, w1.t())
# torch.cuda.synchronize()
dtype = torch.int8
A = A.view(-1, A.shape[-1]).contiguous()
grad = grad.view(-1, grad.shape[-1]).contiguous()
torch.cuda.synchronize()
t0 = time.time()
for i in range(k):
out1 = torch.matmul(A, w1.t()) # fc1
# out2 = torch.matmul(out1, w2.t())# fc2
# d1 = torch.matmul(grad, w2) # delta1
# d2 = torch.matmul(d1, w1) # delta2
# grad1 = torch.einsum('bo,bh->oh', out1, grad) # grad w2
# grad2 = torch.einsum('bh,bo->ho', A, d2) # grad w1
torch.cuda.synchronize()
t16 = time.time() - t0
print(t16)
# torch.cuda.empty_cache()
# Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
# Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)
# CTw1, Sw1 = F.transform2(Cw1, formatB)
# CTw2, Sw2 = F.transform2(Cw2, formatB)
# CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
# CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
# CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
# C32A, SA = F.transform2(CA, 'col32')
## fc1
# out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype)
##out1 = F.mm_dequant(out1_32, Sout1_32, statsAt, statsw1t)
## fc2
# Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1)
# C32out1, Sout1 = F.transform2(Cout1, 'col32')
# out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype)
##out2 = F.mm_dequant(out2_32, Sout2_32, statsout1t, statsw2t)
## delta1
# Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad)
# C32grad, Sgrad = F.transform2(Cgrad, 'col32')
##d1_32, Sd1_32 = F.igemmlt(C32grad, CTw2t, Sgrad, Sw2t, dtype=dtype)
##d1 = F.mm_dequant(d1_32, Sd1_32, statsgradt, statsw2)
## delta2
# Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1)
# C32d1, Sd1 = F.transform2(Cd1, 'col32')
##d2_32, Sd2_32 = F.igemmlt(C32d1, CTw1t, Sd1, Sw1t, dtype=dtype)
##d2 = F.mm_dequant(d2_32, Sd2_32, statsd1t, statsw1)
## grad1
# C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True)
# CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True)
##grad1_32, Sgrad1_32 = F.igemmlt(C32out1t, CTgradt, Sout1t, Sgradt, dtype=dtype)
##grad1 = F.mm_dequant(grad1_32, Sgrad1_32, statsout1, statsgrad)
## grad2
# C32At, SAt = F.transform2(CAt, 'col32', transpose=True)
# CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True)
##grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype)
##grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsA, statsd1)
# Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)
# Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
# Cw2, Cw2t, statsw2, statsw2t, coo_tensor = F.double_quant(w2)
# CTw1, Sw1 = F.transform2(Cw1, formatB)
# CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
# CTw2, Sw2 = F.transform2(Cw2, formatB)
# CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
# torch.cuda.synchronize()
# t0 = time.time()
# for i in range(k):
# #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
# #CTw1, Sw1 = F.transform2(Cw1, formatB)
# #Cw1, Cw1t, statsw1, statsw1t, coo_tensor = F.double_quant(w1)
# #CTw1, Sw1 = F.transform2(Cw1, formatB)
# #CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A, threshold=3.5)
# CA, CAt, statsA, statsAt, coo_tensor = F.double_quant(A)
# #CTw1t, Sw1t = F.transform2(Cw1t, formatB, transpose=True)
# #CTw2, Sw2 = F.transform2(Cw2, formatB)
# #CTw2t, Sw2t = F.transform2(Cw2t, formatB, transpose=True)
# C32A, SA = F.transform2(CA, 'col32')
# # fc1
# out1_32, Sout1_32 = F.igemmlt(C32A, CTw1, SA, Sw1, dtype=dtype)
# #out1dn = F.mm_dequant(out1_32, Sout1_32, statsA, statsw1)
# #print(coo_tensor.nnz)
# #out1sp = F.spmm_coo(coo_tensor, w1.t())
# #print(w1.t().shape)
# #out1 = out1dn + out1sp
# # fc2
# Cout1, Cout1t, statsout1, statsout1t, coo_tensor = F.double_quant(out1)
# C32out1, Sout1 = F.transform2(Cout1, 'col32')
# out2_32, Sout2_32 = F.igemmlt(C32out1, CTw2, Sout1, Sw2, dtype=dtype)
# #out2 = F.mm_dequant(out2_32, Sout2_32, statsout1, statsw2)
# # delta1
# Cgrad, Cgradt, statsgrad, statsgradt, coo_tensor = F.double_quant(grad)
# C32grad, Sgrad = F.transform2(Cgrad, 'col32')
# d1_32, Sd1_32 = F.igemmlt(C32grad, CTw2t, Sgrad, Sw2t, dtype=dtype)
# #d1 = F.mm_dequant(d1_32, Sd1_32, statsgrad, statsw2t)
# # delta2
# Cd1, Cd1t, statsd1, statsd1t, coo_tensor = F.double_quant(d1)
# C32d1, Sd1 = F.transform2(Cd1, 'col32')
# d2_32, Sd2_32 = F.igemmlt(C32d1, CTw1t, Sd1, Sw1t, dtype=dtype)
# #d2 = F.mm_dequant(d2_32, Sd2_32, statsd1, statsw1t)
# # grad1
# #C32out1t, Sout1t = F.transform2(Cout1t, 'col32', transpose=True)
# #CTgradt, Sgradt = F.transform2(Cgradt, formatB, transpose=True)
# #grad1_32, Sgrad1_32 = F.igemmlt(C32out1t, CTgradt, Sout1t, Sgradt, dtype=dtype)
# #grad1 = F.mm_dequant(grad1_32, Sgrad1_32, statsout1t, statsgradt)
# ## grad2
# #C32At, SAt = F.transform2(CAt, 'col32', transpose=True)
# #CTd1t, Sd1t = F.transform2(Cd1t, formatB, transpose=True)
# #grad2_32, Sgrad2_32 = F.igemmlt(C32At, CTd1t, SAt, Sd1t, dtype=dtype)
# #grad2 = F.mm_dequant(grad2_32, Sgrad2_32, statsAt, statsd1t)
# torch.cuda.synchronize()
# t8 = time.time() - t0
# print(t8)
"""
Extracted from tests/test_functional.py
Usage: pytest benchmarking/matmul_benchmark.py
"""
import time
import pytest
import torch
import bitsandbytes as bnb
from bitsandbytes import functional as F
k = 20
torch.set_printoptions(precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000)
@pytest.mark.parametrize(
("batch", "seq", "model", "hidden"),
[
# pytest.param(1, 128, 6656, 4 * 6656, id="batch=1, seq=128, model=6656, hidden=26k"),
pytest.param(1, 1, 3584, 512, id="batch=1, seq=1, model=3584, hidden=512"),
# pytest.param(4, 128, 6656, 4 * 6656, id="batch=4, seq=128, model=6656, hidden=26k"),
# pytest.param(16, 256, 6656, 4 * 6656, id="batch=16, seq=256, model=6656, hidden=26k")
],
)
@pytest.mark.benchmark
def test_bench_matmul(batch, seq, model, hidden):
iters = 1000
formatB = F.get_special_format_str()
A = torch.randn(batch, seq, model, device="cuda").half()
B = torch.empty(hidden, model, dtype=torch.float16, device="cuda")
torch.nn.init.xavier_uniform_(B)
B_fp4, state = F.quantize_fp4(B)
B_fp4_c, state_c = F.quantize_fp4(B, compress_statistics=True)
B_nf4, state_nf4 = F.quantize_nf4(B)
B_nf4_c, state_nf4_c = F.quantize_nf4(B, compress_statistics=True)
linear8bit = bnb.nn.Linear8bitLt(model, hidden, False, False).cuda().half()
linear8bit.eval()
outliers = torch.randint(0, model, size=(5,)).cuda()
A[:, :, outliers] = 8.0
linearMixedBit = bnb.nn.Linear8bitLt(model, hidden, False, False, threshold=6.0).cuda().half()
# linearMixedBit.eval()
linear8bit_train = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
linear8bit_train_thresh = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()
bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)
# warmup
for i in range(iters):
torch.matmul(A, B.t())
torch.cuda.synchronize()
print("")
torch.cuda.synchronize()
t0 = time.time()
for i in range(iters):
torch.matmul(A, B.t())
torch.cuda.synchronize()
print(
f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s",
)
# torch.cuda.synchronize()
# t0 = time.time()
# for i in range(iters):
# bnb.matmul_4bit(A, B_fp4.t(), quant_state=state)
# torch.cuda.synchronize()
# print( f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
# torch.cuda.synchronize()
# t0 = time.time()
# for i in range(iters):
# bnb.matmul_4bit(A, B_fp4.t(), quant_state=state_c)
# torch.cuda.synchronize()
# print( f"bnb fp4 + compressed stats: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s" )
torch.cuda.synchronize()
t0 = time.time()
for i in range(iters):
bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)
torch.cuda.synchronize()
print(f"bnb nf4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
torch.cuda.synchronize()
t0 = time.time()
for i in range(iters):
bnb.matmul_4bit(A, B_nf4_c.t(), quant_state=state_nf4_c)
torch.cuda.synchronize()
print(f"bnb nf4+DQ: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
torch.cuda.synchronize()
t0 = time.time()
for i in range(iters):
bnb.matmul(A, B)
torch.cuda.synchronize()
print(
f"B -> CB (each iteration): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
)
torch.cuda.synchronize()
t0 = time.time()
for i in range(iters):
bnb.matmul(A, B, threshold=6.0)
torch.cuda.synchronize()
print(
f"B -> CB + threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
)
CA, SCA, _ = F.int8_vectorwise_quant(A, threshold=0.0)
CB, SCB, _ = F.int8_vectorwise_quant(B)
torch.cuda.synchronize()
t0 = time.time()
for i in range(iters):
# CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0)
out32 = F.int8_linear_matmul(CA, CB)
torch.cuda.synchronize()
print(
f"no overhead int8 [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
)
# C32A, SA = F.transform(CA, "col32")
# CxB, SB = F.transform(CB, to_order=formatB)
# torch.cuda.synchronize()
# t0 = time.time()
# for i in range(iters):
# out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
# torch.cuda.synchronize()
# print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
# CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=0.0)
# C32A, SA = F.transform(CA, "col32")
# CB, CBt, SCB, SCBt, coo_tensorB = F.double_quant(B)
# CxB, SB = F.transform(CB, to_order=formatB)
# torch.cuda.synchronize()
# t0 = time.time()
# for i in range(iters):
# out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
# torch.cuda.synchronize()
# print(f"no overhead matmul-lt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
# BA, statsB = F.vectorwise_quant(B, dim=1)
# CxB, SB = F.nvidia_transform(CB, to_order=formatB)
# torch.cuda.synchronize()
# t0 = time.time()
# for i in range(iters):
# A2 = A.view(-1, A.shape[-1]).contiguous()
# CA, statsA = F.vectorwise_quant(A2, dim=1)
# C32A, SA = F.nvidia_transform(CA, "col32")
# out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
# Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
# F.vectorwise_mm_dequant(Cout, statsA, statsB.t())
# torch.cuda.synchronize()
# print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
# BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear")
# CxB, SB = F.nvidia_transform(CB, to_order=formatB)
# torch.cuda.synchronize()
# t0 = time.time()
# for i in range(iters):
# A2 = A.view(-1, A.shape[-1]).contiguous()
# CA, statsA = F.vectorwise_quant(A2, dim=1, quant_type="linear")
# C32A, SA = F.nvidia_transform(CA, "col32")
# out32, Sout32 = F.igemmlt(C32A, CxB, SA, SB)
# Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
# out = Cout * statsB * statsA * (1.0 / (127 * 127))
# torch.cuda.synchronize()
# print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
linear8bit(A)
torch.cuda.synchronize()
t0 = time.time()
for i in range(iters):
linear8bit(A)
torch.cuda.synchronize()
print(
f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
)
linearMixedBit(A)
torch.cuda.synchronize()
t0 = time.time()
for i in range(iters):
linearMixedBit(A)
torch.cuda.synchronize()
print(
f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s"
)
# linear8bit_train(A)
# torch.cuda.synchronize()
# t0 = time.time()
# for i in range(iters):
# linear8bit_train(A)
# torch.cuda.synchronize()
# print( f"bnb linear8bitlt (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
# linear8bit_train_thresh(A)
# torch.cuda.synchronize()
# t0 = time.time()
# for i in range(iters):
# linear8bit_train(A)
# torch.cuda.synchronize()
# print( f"bnb linear8bitlt with threshold (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
from dataclasses import dataclass from dataclasses import dataclass
from functools import reduce # Required in Python 3 from math import prod
import operator
from typing import Callable, Optional, Tuple from typing import Callable, Optional, Tuple
import warnings import warnings
from warnings import warn from warnings import warn
import torch import torch
from typing_extensions import deprecated
import bitsandbytes.functional as F import bitsandbytes.functional as F
# math.prod not compatible with python < 3.8
def prod(iterable):
return reduce(operator.mul, iterable, 1)
# The inverse transformation for the colTuring and colAmpere format were contributed by Alex Borzunov: # The inverse transformation for the colTuring and colAmpere format were contributed by Alex Borzunov:
# https://github.com/bigscience-workshop/petals/blob/main/src/petals/utils/linear8bitlt_patch.py # https://github.com/bigscience-workshop/petals/blob/main/src/petals/utils/linear8bitlt_patch.py
...@@ -104,6 +98,10 @@ def undo_layout(permuted_tensor: torch.Tensor, tile_indices: torch.LongTensor) - ...@@ -104,6 +98,10 @@ def undo_layout(permuted_tensor: torch.Tensor, tile_indices: torch.LongTensor) -
return outputs.reshape(rows, cols).contiguous() return outputs.reshape(rows, cols).contiguous()
@deprecated(
"MatMul8bit is deprecated and will be removed in a future release. Please use MatMul8bitLt instead.",
category=FutureWarning,
)
class MatMul8bit(torch.autograd.Function): class MatMul8bit(torch.autograd.Function):
@staticmethod @staticmethod
def forward(ctx, A, B, out=None, quant_type="vector", precision=None): def forward(ctx, A, B, out=None, quant_type="vector", precision=None):
...@@ -215,6 +213,7 @@ bmm_cublas = MatMul8bit.apply ...@@ -215,6 +213,7 @@ bmm_cublas = MatMul8bit.apply
matmul_cublas = MatMul8bit.apply matmul_cublas = MatMul8bit.apply
@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def supports_igemmlt(device: torch.device) -> bool: def supports_igemmlt(device: torch.device) -> bool:
"""check if this device supports the optimized int8 kernel""" """check if this device supports the optimized int8 kernel"""
if torch.cuda.get_device_capability(device=device) < (7, 5): if torch.cuda.get_device_capability(device=device) < (7, 5):
...@@ -226,6 +225,7 @@ def supports_igemmlt(device: torch.device) -> bool: ...@@ -226,6 +225,7 @@ def supports_igemmlt(device: torch.device) -> bool:
return True return True
@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def _get_tile_size(format): def _get_tile_size(format):
assert format in ( assert format in (
"col_turing", "col_turing",
...@@ -234,6 +234,7 @@ def _get_tile_size(format): ...@@ -234,6 +234,7 @@ def _get_tile_size(format):
return (8, 32) if format == "col_turing" else (32, 32) return (8, 32) if format == "col_turing" else (32, 32)
@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def get_tile_inds(format, device): def get_tile_inds(format, device):
transform = lambda x: F.transform(x.to(device), from_order="row", to_order=format)[0].to(x.device) transform = lambda x: F.transform(x.to(device), from_order="row", to_order=format)[0].to(x.device)
with torch.no_grad(): with torch.no_grad():
...@@ -243,27 +244,28 @@ def get_tile_inds(format, device): ...@@ -243,27 +244,28 @@ def get_tile_inds(format, device):
@dataclass @dataclass
class MatmulLtState: class MatmulLtState:
_tile_indices: Optional[torch.Tensor] = None _tile_indices: Optional[torch.Tensor] = None
force_no_igemmlt: bool = False force_no_igemmlt: bool = False
CB = None
CxB = None
SB = None
SCB = None
CxBt = None CB: Optional[torch.Tensor] = None
SBt = None CxB: Optional[torch.Tensor] = None # TODO: Deprecate/remove
CBt = None SB: Optional[torch.Tensor] = None
SCB: Optional[torch.Tensor] = None
CxBt: Optional[torch.Tensor] = None # TODO: Deprecate/remove
SBt: Optional[torch.Tensor] = None
CBt: Optional[torch.Tensor] = None
subB = None subB: Optional[torch.Tensor] = None
outlier_pool = None outlier_pool: Optional[GlobalOutlierPooler] = None
has_accumulated_gradients = False has_accumulated_gradients = False
threshold = 0.0 threshold = 0.0
idx = None idx: Optional[torch.Tensor] = None
is_training = True is_training = True
has_fp16_weights = True has_fp16_weights = True
memory_efficient_backward = False
use_pool = False use_pool = False
formatB = F.get_special_format_str() formatB = "row" # TODO: Deprecate/remove
def reset_grads(self): def reset_grads(self):
self.CB = None self.CB = None
...@@ -283,12 +285,17 @@ class MatmulLtState: ...@@ -283,12 +285,17 @@ class MatmulLtState:
class MatMul8bitLt(torch.autograd.Function): class MatMul8bitLt(torch.autograd.Function):
# forward is the same, but we added the fallback for pre-turing GPUs
# backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
@staticmethod @staticmethod
def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState): def forward(
using_igemmlt = supports_igemmlt(A.device) and not state.force_no_igemmlt ctx: torch.autograd.function.FunctionCtx,
A: torch.Tensor,
B: torch.Tensor,
out: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
state: Optional[MatmulLtState] = None,
):
state = state or MatmulLtState()
# default of pytorch behavior if inputs are empty # default of pytorch behavior if inputs are empty
ctx.is_empty = False ctx.is_empty = False
if prod(A.shape) == 0: if prod(A.shape) == 0:
...@@ -301,123 +308,80 @@ class MatMul8bitLt(torch.autograd.Function): ...@@ -301,123 +308,80 @@ class MatMul8bitLt(torch.autograd.Function):
else: else:
return torch.empty(A.shape[:-1] + B.shape[:1], dtype=A.dtype, device=A.device) return torch.empty(A.shape[:-1] + B.shape[:1], dtype=A.dtype, device=A.device)
# 1. Quantize A
# 2. Quantize B
# 3. Matmul
# 4. Mixed-precision decomposition matmul
# 5. Save state
formatB = state.formatB
input_shape = A.shape input_shape = A.shape
if state.outlier_pool is None:
state.outlier_pool = GlobalOutlierPooler.get_instance()
# Cast A to fp16 # Cast A to fp16
if A.dtype != torch.float16: if A.dtype != torch.float16:
warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization") warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
# 1. Quantize A
if len(A.shape) == 3: if len(A.shape) == 3:
A = A.reshape(-1, A.shape[-1]) A = A.reshape(-1, A.shape[-1])
CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A.to(torch.float16), threshold=state.threshold)
if state.threshold > 0.0 and coo_tensorA is not None: # 1. Quantize A. Note that as a side-effect, outliers are suppressed in CA/CAt.
if state.has_fp16_weights: if ctx.needs_input_grad[1]:
idx = torch.unique(coo_tensorA.colidx).long() # Slower path
CA[:, idx] = 0 CA, CAt, SCA, SCAt, outlier_cols = F.int8_double_quant(A.to(torch.float16), threshold=state.threshold)
CAt[:, idx] = 0
subA = A[:, idx]
state.subB = B[:, idx].t().contiguous()
state.idx = idx
else: else:
if state.CxB is None and using_igemmlt: # Fast path
# B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions CA, SCA, outlier_cols = F.int8_vectorwise_quant(A.to(torch.float16), threshold=state.threshold)
# we also need to convert it to the turing/ampere format CAt = SCAt = None
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
else:
if not state.has_fp16_weights and state.CxB is None and using_igemmlt:
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
subA = None
# 2. Quantize B has_grad = False
if state.has_fp16_weights:
has_grad = True if (getattr(B, "grad", None) is not None) else False if state.has_fp16_weights or state.CB is None:
has_grad = getattr(B, "grad", None) is not None
is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1) is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1)
if is_transposed: if is_transposed:
B = B.contiguous() B = B.contiguous()
if (state.is_training and not has_grad) or state.CxB is None: if (state.is_training and not has_grad) or state.CB is None or state.SCB is None:
state.reset_grads() state.reset_grads()
(
CB,
state.CBt,
state.SCB,
state.SCBt,
coo_tensorB,
) = F.double_quant(B.to(torch.float16))
if using_igemmlt:
state.CxB, state.SB = F.transform(CB, to_order=formatB)
else:
state.CB = CB
else:
has_grad = False
if coo_tensorA is not None and not state.has_fp16_weights: # 2. Quantize B
# extract outliers state.CB, state.SCB, _ = F.int8_vectorwise_quant(B.to(torch.float16))
outlier_idx = torch.unique(coo_tensorA.colidx)
state.idx = outlier_idx
# state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
# if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
# # do not use pool for 2nd FFN layer
# state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
# else:
# state.idx = outlier_idx
if state.CxB is not None:
outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
else:
outliers = state.CB[:, state.idx.long()].clone()
state.subB = (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().to(A.dtype) # Handle sparse decomposition. In some instances, we may have not found any
CA[:, state.idx.long()] = 0 # outlier columns at all. In that case, we'll skip this part completely.
CAt[:, state.idx.long()] = 0 if state.threshold > 0.0 and outlier_cols is not None and outlier_cols.numel():
subA = A[:, state.idx.long()] state.idx = outlier_cols
shapeB = state.SB[0] if state.SB else B.shape # Zero out the outliers in the transposed 8bit inputs.
if CAt is not None:
CAt[:, state.idx] = 0
if len(input_shape) == 3: # Extract the input outliers in original precision
output_shape = (input_shape[0], input_shape[1], shapeB[0]) subA = A[:, state.idx].contiguous()
# Extract the corresponding weights
if state.has_fp16_weights:
state.subB = B[:, state.idx].t()
else:
# To dequantize our weights associated with the input outliers,
# we want to divide by 127. It's however more performant to multiply
# by the reciprocal.
outliers = state.CB[:, state.idx]
state.subB = (outliers.t() * state.SCB * 7.874015718698502e-3).to(A.dtype)
else: else:
output_shape = (input_shape[0], shapeB[0]) subA = None
# 3. Int8 Matmul
out32 = F.int8_linear_matmul(CA, state.CB)
# 3. Matmul # Dequantize matmul result
if using_igemmlt:
C32A, SA = F.transform(CA, "col32")
out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
if bias is None or bias.dtype == torch.float16: if bias is None or bias.dtype == torch.float16:
# we apply the fused bias here # we apply the fused bias here
output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias) output = F.int8_mm_dequant(out32, SCA, state.SCB, bias=bias).to(A.dtype)
output = output.to(A.dtype)
else: # apply bias separately else: # apply bias separately
output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None) # TODO: Fused bias for fp32/bf16?
output = output.to(A.dtype).add_(bias) output = F.int8_mm_dequant(out32, SCA, state.SCB, bias=None).to(A.dtype).add_(bias)
else:
A_wo_outliers = A.clone()
if state.idx is not None:
A_wo_outliers[:, state.idx.long()] = 0
output = torch.nn.functional.linear(A_wo_outliers, state.CB.to(A.dtype))
output = output.mul_(state.SCB.unsqueeze(0).mul(1.0 / 127.0))
if bias is not None:
output = output.add_(bias)
# 4. Mixed-precision decomposition matmul # 4. Mixed-precision decomposition matmul
if coo_tensorA is not None and subA is not None: if subA is not None and state.subB is not None:
output += torch.matmul(subA, state.subB) output = output.addmm(subA, state.subB)
# 5. Save state # 5. Save state
ctx.state = state ctx.state = state
ctx.formatB = formatB
ctx.grad_shape = input_shape ctx.grad_shape = input_shape
ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype
...@@ -425,23 +389,27 @@ class MatMul8bitLt(torch.autograd.Function): ...@@ -425,23 +389,27 @@ class MatMul8bitLt(torch.autograd.Function):
ctx.tensors = (CAt, subA, A) ctx.tensors = (CAt, subA, A)
ctx.tensor_states = (SCAt, state.idx) ctx.tensor_states = (SCAt, state.idx)
else: else:
ctx.tensors = [None, None, A] ctx.tensors = [None, None, None]
ctx.tensor_states = (None, None) ctx.tensor_states = (None, None)
ctx.save_for_backward(None, None) ctx.save_for_backward(None, None)
clone_func = torch.clone if len(output_shape) == 3 else lambda x: x output_shape = (*input_shape[:-1], state.CB.shape[0])
return clone_func(output.view(output_shape))
if len(input_shape) == 3:
return output.reshape(output_shape)
return output
@staticmethod @staticmethod
def backward(ctx, grad_output): def backward(ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor):
if ctx.is_empty: if ctx.is_empty:
bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias) bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
CAt, subA, A = ctx.tensors CAt, subA, A = ctx.tensors
SCAt, idx = ctx.tensor_states SCAt, idx = ctx.tensor_states
formatB = ctx.formatB state: MatmulLtState = ctx.state
state = ctx.state
grad_A = grad_B = grad_bias = None grad_A = grad_B = grad_bias = None
if req_gradBias: if req_gradBias:
...@@ -452,35 +420,20 @@ class MatMul8bitLt(torch.autograd.Function): ...@@ -452,35 +420,20 @@ class MatMul8bitLt(torch.autograd.Function):
if len(grad_output.shape) == 3: if len(grad_output.shape) == 3:
grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous() grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
if req_gradB: if req_gradB:
CxAt, SAt = F.transform(CAt, formatB, transpose=True) Cgrad, _, _, SCgradt, _ = F.int8_double_quant(grad_output.to(torch.float16))
C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt) gradB32 = F.int8_linear_matmul(Cgrad.t().contiguous(), CAt.t())
grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt) grad_B = F.int8_mm_dequant(gradB32, SCgradt, SCAt)
if state.threshold > 0.0 and subA is not None: if state.threshold > 0.0 and subA is not None:
grad_B[:, idx] += torch.matmul(grad_output.t(), subA) grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
if req_gradA: if req_gradA:
if state.CBt is not None: if state.CB is not None:
C32grad, Sgrad = F.transform(Cgrad, "col32")
if state.CxBt is None:
state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True)
gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
elif state.CB is not None:
CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0)) CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A) grad_A = torch.matmul(grad_output.to(ctx.dtype_A), CB).view(ctx.grad_shape)
elif state.CxB is not None:
CB = (
undo_layout(state.CxB, state.tile_indices)
.to(ctx.dtype_A)
.mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
)
grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
else: else:
raise Exception("State must contain either CBt or CB or CxB matrix for backward") raise Exception("State must contain CB matrix for backward")
return grad_A, grad_B, None, grad_bias, None return grad_A, grad_B, None, grad_bias, None
...@@ -548,7 +501,7 @@ def matmul( ...@@ -548,7 +501,7 @@ def matmul(
out: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None,
state: Optional[MatmulLtState] = None, state: Optional[MatmulLtState] = None,
threshold=0.0, threshold=0.0,
bias=None, bias: Optional[torch.Tensor] = None,
): ):
state = state or MatmulLtState() state = state or MatmulLtState()
if threshold > 0.0: if threshold > 0.0:
...@@ -561,9 +514,10 @@ def matmul_4bit( ...@@ -561,9 +514,10 @@ def matmul_4bit(
B: torch.Tensor, B: torch.Tensor,
quant_state: F.QuantState, quant_state: F.QuantState,
out: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None,
bias=None, bias: Optional[torch.Tensor] = None,
): ):
assert quant_state is not None assert quant_state is not None
if A.numel() == A.shape[-1] and A.requires_grad == False: if A.numel() == A.shape[-1] and A.requires_grad == False:
if A.shape[-1] % quant_state.blocksize != 0: if A.shape[-1] % quant_state.blocksize != 0:
warn( warn(
......
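For reference, the refactored `MatMul8bitLt.forward` above boils down to a short sequence of calls into the new public int8 functional API. The following is a minimal sketch of that sequence for a standalone matmul; the shapes and threshold are illustrative, and in the library this state is cached by `MatmulLtState`/`Linear8bitLt` rather than recomputed on every call.

```python
# Minimal sketch of the int8 inference path after this refactor (illustrative
# shapes; assumes a CUDA device). The weight quantization in step 2 is done
# once and cached as state.CB/state.SCB in the real code path.
import torch
import bitsandbytes.functional as F

A = torch.randn(16, 4096, dtype=torch.float16, device="cuda")     # activations
B = torch.randn(11008, 4096, dtype=torch.float16, device="cuda")  # weight (out_features, in_features)

# 1. Row-wise int8 quantization of the activations. Columns whose absmax
#    exceeds the threshold are reported as outliers and zeroed in CA.
CA, SCA, outlier_cols = F.int8_vectorwise_quant(A, threshold=6.0)

# 2. Row-wise int8 quantization of the weight.
CB, SCB, _ = F.int8_vectorwise_quant(B)

# 3. Int8 matmul with int32 accumulation, then dequantize using both scales.
out32 = F.int8_linear_matmul(CA, CB)
output = F.int8_mm_dequant(out32, SCA, SCB).to(A.dtype)

# 4. Mixed-precision decomposition: the outlier columns are recomputed in
#    fp16 and added back.
if outlier_cols is not None and outlier_cols.numel():
    subA = A[:, outlier_cols]
    subB = B[:, outlier_cols].t()
    output = output.addmm(subA, subB)
```

With `threshold=0.0`, no outlier columns are extracted and step 4 is skipped, which is the pure-int8 fast path exercised in the benchmarks above.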
"""
extract factors the build is dependent on:
[X] compute capability
[ ] TODO: Q - What if we have multiple GPUs of different makes?
- CUDA version
- Software:
- CPU-only: only CPU quantization functions (no optimizer, no matrix multiple)
- CuBLAS-LT: full-build 8-bit optimizer
- no CuBLAS-LT: no 8-bit matrix multiplication (`nomatmul`)
evaluation:
- if paths faulty, return meaningful error
- else:
- determine CUDA version
- determine capabilities
- based on that set the default path
"""
import ctypes as ct import ctypes as ct
import logging import logging
import os import os
...@@ -37,11 +19,7 @@ def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path: ...@@ -37,11 +19,7 @@ def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path:
The library is not guaranteed to exist at the returned path. The library is not guaranteed to exist at the returned path.
""" """
library_name = f"libbitsandbytes_cuda{cuda_specs.cuda_version_string}" library_name = f"libbitsandbytes_cuda{cuda_specs.cuda_version_string}{DYNAMIC_LIBRARY_SUFFIX}"
if not cuda_specs.has_cublaslt:
# if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt
library_name += "_nocublaslt"
library_name = f"{library_name}{DYNAMIC_LIBRARY_SUFFIX}"
override_value = os.environ.get("BNB_CUDA_VERSION") override_value = os.environ.get("BNB_CUDA_VERSION")
if override_value: if override_value:
...@@ -67,6 +45,9 @@ class BNBNativeLibrary: ...@@ -67,6 +45,9 @@ class BNBNativeLibrary:
def __getattr__(self, item): def __getattr__(self, item):
return getattr(self._lib, item) return getattr(self._lib, item)
def __getitem__(self, item):
return getattr(self._lib, item)
class CudaBNBNativeLibrary(BNBNativeLibrary): class CudaBNBNativeLibrary(BNBNativeLibrary):
compiled_with_cuda = True compiled_with_cuda = True
...@@ -114,6 +95,6 @@ python -m bitsandbytes ...@@ -114,6 +95,6 @@ python -m bitsandbytes
Inspect the output of the command and see if you can locate CUDA libraries. You might need to add them Inspect the output of the command and see if you can locate CUDA libraries. You might need to add them
to your LD_LIBRARY_PATH. If you suspect a bug, please take the information from python -m bitsandbytes to your LD_LIBRARY_PATH. If you suspect a bug, please take the information from python -m bitsandbytes
and open an issue at: https://github.com/TimDettmers/bitsandbytes/issues and open an issue at: https://github.com/bitsandbytes-foundation/bitsandbytes/issues
""", """,
) )
...@@ -11,7 +11,7 @@ class CUDASpecs: ...@@ -11,7 +11,7 @@ class CUDASpecs:
cuda_version_tuple: Tuple[int, int] cuda_version_tuple: Tuple[int, int]
@property @property
def has_cublaslt(self) -> bool: def has_imma(self) -> bool:
return self.highest_compute_capability >= (7, 5) return self.highest_compute_capability >= (7, 5)
......
...@@ -134,8 +134,8 @@ def print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None: ...@@ -134,8 +134,8 @@ def print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None:
print(f"To manually override the PyTorch CUDA version please see: {NONPYTORCH_DOC_URL}") print(f"To manually override the PyTorch CUDA version please see: {NONPYTORCH_DOC_URL}")
# 7.5 is the minimum CC for cublaslt # 7.5 is the minimum CC for int8 tensor cores
if not cuda_specs.has_cublaslt: if not cuda_specs.has_imma:
print_dedented( print_dedented(
""" """
WARNING: Compute capability < 7.5 detected! Only slow 8-bit matmul is supported for your GPU! WARNING: Compute capability < 7.5 detected! Only slow 8-bit matmul is supported for your GPU!
......
...@@ -3,25 +3,19 @@ ...@@ -3,25 +3,19 @@
# This source code is licensed under the MIT license found in the # This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import ctypes as ct import ctypes as ct
from functools import reduce # Required in Python 3
import itertools import itertools
import operator from math import prod
from typing import Any, Dict, Optional, Tuple from typing import Any, Dict, Iterable, Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
from torch import Tensor from torch import Tensor
from typing_extensions import deprecated
from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict
from .cextension import lib from .cextension import lib
# math.prod not compatible with python < 3.8
def prod(iterable):
return reduce(operator.mul, iterable, 1)
name2qmap = {} name2qmap = {}
if lib and lib.compiled_with_cuda: if lib and lib.compiled_with_cuda:
...@@ -197,6 +191,20 @@ dtype2bytes[torch.int8] = 1 ...@@ -197,6 +191,20 @@ dtype2bytes[torch.int8] = 1
FIRST_CUDA_DEVICE = torch.device("cuda", index=0) FIRST_CUDA_DEVICE = torch.device("cuda", index=0)
# When multiple GPUs are present, we use a context manager to
# switch to the correct device of a tensor before invoking our CUDA
# kernels in the C++ library. However, when there's only one device
# there is no need to incur the overhead of cudaGetDevice/cudaSetDevice.
if torch.cuda.device_count() > 1:
def _cuda_device_of(a: torch.Tensor):
return torch.cuda.device_of(a)
else:
import contextlib
def _cuda_device_of(a: torch.Tensor):
return contextlib.nullcontext()
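To see the pattern in isolation: a minimal, self-contained sketch of the same single- vs multi-GPU dispatch. The `device_guard` and `rowwise_absmax` names are mine, used only for illustration.

import contextlib
import torch

# Pick the guard implementation once, mirroring _cuda_device_of above.
if torch.cuda.device_count() > 1:
    def device_guard(t: torch.Tensor):
        return torch.cuda.device_of(t)  # switch to t's device, restore on exit
else:
    def device_guard(t: torch.Tensor):
        return contextlib.nullcontext()  # single GPU: skip cudaGetDevice/cudaSetDevice

def rowwise_absmax(t: torch.Tensor) -> torch.Tensor:
    # Stand-in for a kernel launch, wrapped the same way the C calls are below.
    with device_guard(t):
        return t.abs().amax(dim=-1)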
def get_paged(*shape, dtype=torch.float32, device=FIRST_CUDA_DEVICE): def get_paged(*shape, dtype=torch.float32, device=FIRST_CUDA_DEVICE):
num_bytes = dtype2bytes[dtype] * prod(shape) num_bytes = dtype2bytes[dtype] * prod(shape)
...@@ -251,10 +259,12 @@ def fill(A, value, device=None, prefetch=True): ...@@ -251,10 +259,12 @@ def fill(A, value, device=None, prefetch=True):
elementwise_func("fill", A, None, value) elementwise_func("fill", A, None, value)
@deprecated("Function will be removed in a future release.", category=FutureWarning)
def arange(A, device=None): def arange(A, device=None):
elementwise_func("arange", A, None, 0) elementwise_func("arange", A, None, 0)
@deprecated("Function will be removed in a future release.", category=FutureWarning)
def _mul(A, B, device=None): def _mul(A, B, device=None):
elementwise_func("_mul", A, B, 0) elementwise_func("_mul", A, B, 0)
...@@ -421,72 +431,88 @@ def create_quantile_map(A, total_bits=8): ...@@ -421,72 +431,88 @@ def create_quantile_map(A, total_bits=8):
return q return q
@deprecated("This function is deprecated and will be removed in a future version.", category=FutureWarning)
def get_special_format_str(): def get_special_format_str():
if not torch.cuda.is_available(): return "row"
return "col_turing"
major, _minor = torch.cuda.get_device_capability()
if major <= 7:
return "col_turing"
if major == 8:
return "col_ampere"
return "col_turing"
def is_on_gpu(tensors): def is_on_gpu(tensors: Iterable[Optional[torch.Tensor]]):
"""Verifies that the input tensors are all on the same device.
An input tensor may also be marked as `paged`, in which case the device placement is ignored.
Args:
tensors (`Iterable[Optional[torch.Tensor]]`): A list of tensors to verify.
Raises:
`RuntimeError`: Raised when the verification fails.
Returns:
`Literal[True]`
"""
on_gpu = True on_gpu = True
gpu_ids = set() gpu_ids = set()
for t in tensors: for t in tensors:
if t is None: # NULL pointers and paged tensors are OK.
continue # NULL pointers are fine if t is not None and not getattr(t, "is_paged", False):
is_paged = getattr(t, "is_paged", False) on_gpu &= t.is_cuda
on_gpu &= t.device.type == "cuda" or is_paged
if not is_paged:
gpu_ids.add(t.device.index) gpu_ids.add(t.device.index)
if not on_gpu: if not on_gpu:
raise TypeError( raise RuntimeError(
f"All input tensors need to be on the same GPU, but found some tensors to not be on a GPU:\n {[(t.shape, t.device) for t in tensors]}", f"All input tensors need to be on the same GPU, but found some tensors to not be on a GPU:\n {[(t.shape, t.device) for t in tensors]}",
) )
if len(gpu_ids) > 1: if len(gpu_ids) > 1:
raise TypeError( raise RuntimeError(
f"Input tensors need to be on the same GPU, but found the following tensor and device combinations:\n {[(t.shape, t.device) for t in tensors]}", f"Input tensors need to be on the same GPU, but found the following tensor and device combinations:\n {[(t.shape, t.device) for t in tensors]}",
) )
return on_gpu return on_gpu
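A hedged usage example for the validation helper above; it assumes a CUDA build of bitsandbytes and at least one GPU.

import torch
from bitsandbytes.functional import is_on_gpu

a = torch.randn(4, 4, device="cuda", dtype=torch.float16)
b = torch.randn(4, 4, device="cuda", dtype=torch.float16)

is_on_gpu([a, b, None])    # True: None entries are skipped
# is_on_gpu([a, b.cpu()])  # would raise RuntimeError, since b.cpu() is not on a GPU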
@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def get_tensor_stream(tensor: Tensor) -> torch.cuda.Stream: def get_tensor_stream(tensor: Tensor) -> torch.cuda.Stream:
stream = torch.cuda.current_stream(tensor.device) return torch.cuda.current_stream(tensor.device)
return stream
def _get_tensor_stream(tensor: Tensor) -> ct.c_void_p:
# We use the raw stream for performance reasons.
return ct.c_void_p(torch._C._cuda_getCurrentRawStream(tensor.device.index))
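Illustration only, assuming a CUDA device: both calls below refer to the same current stream, but the private helper skips building a torch.cuda.Stream wrapper object on every kernel launch.

import torch

x = torch.empty(1, device="cuda")
print(torch.cuda.current_stream(x.device))                 # wrapper object (deprecated path)
print(torch._C._cuda_getCurrentRawStream(x.device.index))  # raw handle (new private path)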
def get_ptr(A: Optional[Tensor]) -> Optional[ct.c_void_p]: def get_ptr(A: Optional[Tensor]) -> Optional[ct.c_void_p]:
""" """Gets the memory address of the first element of a tenso
Get the ctypes pointer from a PyTorch Tensor.
Parameters Args:
---------- A (`Optional[Tensor]`): A PyTorch tensor.
A : torch.tensor
The PyTorch tensor.
Returns Returns:
------- `Optional[ct.c_void_p]`: A pointer to the underlying tensor data.
ctypes.c_void_p
""" """
if A is None: if A is None:
return None return None
else:
return ct.c_void_p(A.data.data_ptr()) return ct.c_void_p(A.data_ptr())
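A small usage check for get_ptr, assuming a CUDA device: the helper is a thin ctypes wrapper around data_ptr(), and None passes through as a NULL pointer.

import ctypes as ct
import torch
from bitsandbytes.functional import get_ptr

t = torch.arange(8, device="cuda", dtype=torch.float32)
p = get_ptr(t)
assert isinstance(p, ct.c_void_p) and p.value == t.data_ptr()
assert get_ptr(None) is None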
@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def pre_call(device): def pre_call(device):
prev_device = torch.cuda.current_device() prev_device = torch.cuda.current_device()
torch.cuda.set_device(device) torch.cuda.set_device(device)
return prev_device return prev_device
@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def post_call(prev_device): def post_call(prev_device):
torch.cuda.set_device(prev_device) torch.cuda.set_device(prev_device)
@deprecated(
"The layout transformation operations will be removed in a future release. Please use row-major layout only.",
category=FutureWarning,
)
def get_transform_func(dtype, orderA, orderOut, transpose=False): def get_transform_func(dtype, orderA, orderOut, transpose=False):
name = f'ctransform_{(8 if dtype == torch.int8 else 32)}_{orderA}_to_{orderOut}_{"t" if transpose else "n"}' name = f'ctransform_{(8 if dtype == torch.int8 else 32)}_{orderA}_to_{orderOut}_{"t" if transpose else "n"}'
if not hasattr(lib, name): if not hasattr(lib, name):
...@@ -498,6 +524,10 @@ def get_transform_func(dtype, orderA, orderOut, transpose=False): ...@@ -498,6 +524,10 @@ def get_transform_func(dtype, orderA, orderOut, transpose=False):
return getattr(lib, name) return getattr(lib, name)
@deprecated(
"The layout transformation operations will be removed in a future release. Please use row-major layout only.",
category=FutureWarning,
)
def get_transform_buffer(shape, dtype, device, to_order, from_order="row", transpose=False): def get_transform_buffer(shape, dtype, device, to_order, from_order="row", transpose=False):
# init_func = torch.empty # init_func = torch.empty
init_func = torch.zeros init_func = torch.zeros
...@@ -537,6 +567,10 @@ def get_transform_buffer(shape, dtype, device, to_order, from_order="row", trans ...@@ -537,6 +567,10 @@ def get_transform_buffer(shape, dtype, device, to_order, from_order="row", trans
raise NotImplementedError(f"To_order not supported: {to_order}") raise NotImplementedError(f"To_order not supported: {to_order}")
@deprecated(
"The layout transformation operations will be removed in a future release. Please use row-major layout only.",
category=FutureWarning,
)
def nvidia_transform( def nvidia_transform(
A, A,
to_order, to_order,
...@@ -818,37 +852,38 @@ class QuantState: ...@@ -818,37 +852,38 @@ class QuantState:
def quantize_blockwise( def quantize_blockwise(
A: Tensor, A: torch.Tensor,
code: Optional[torch.Tensor] = None, code: Optional[torch.Tensor] = None,
absmax: Optional[torch.Tensor] = None, absmax: Optional[torch.Tensor] = None,
out: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None,
blocksize=4096, blocksize=4096,
nested=False, nested=False,
) -> Tuple[Tensor, QuantState]: ) -> Tuple[torch.Tensor, QuantState]:
""" """Quantize a tensor in blocks of values.
Quantize tensor A in blocks of size 4096 values.
The input tensor is quantized by dividing it into blocks of `blocksize` values.
Quantizes tensor A by dividing it into blocks of 4096 values. Then the absolute maximum value within these blocks is calculated for scaling
Then the absolute maximum value within these blocks is calculated the non-linear quantization.
for the non-linear quantization.
Args:
Parameters A (`torch.Tensor`): The input tensor. Supports `float16`, `bfloat16`, or `float32` datatypes.
---------- code (`torch.Tensor`, *optional*):
A : torch.Tensor A mapping describing the low-bit data type. Defaults to a signed 8-bit dynamic type.
The input tensor. For more details, see [8-Bit Approximations for Parallelism in Deep Learning](https://arxiv.org/abs/1511.04561).
code : torch.Tensor absmax (`torch.Tensor`, *optional*): A tensor to use to store the absmax values.
The quantization map. out (`torch.Tensor`, *optional*): A tensor to use to store the result.
absmax : torch.Tensor blocksize (`int`, *optional*):
The absmax values. The size of the blocks. Defaults to 4096.
out : torch.Tensor Valid values are 64, 128, 256, 512, 1024, 2048, and 4096.
The output tensor (8-bit). nested (`bool`, *optional*): Whether to additionally quantize the absmax values. Defaults to False.
Returns Raises:
------- ValueError: Raised when the input data type is not supported.
torch.Tensor:
The 8-bit tensor. Returns:
tuple(torch.Tensor, torch.Tensor): `Tuple[torch.Tensor, QuantState]`: A tuple containing the quantization results.
The quantization state to undo the quantization. - `torch.Tensor`: The quantized tensor.
- [`QuantState`]: The state object used to undo the quantization.
""" """
if code is None: if code is None:
...@@ -858,8 +893,7 @@ def quantize_blockwise( ...@@ -858,8 +893,7 @@ def quantize_blockwise(
if absmax is None: if absmax is None:
n = A.numel() n = A.numel()
blocks = n // blocksize blocks = -(n // -blocksize)
blocks += 1 if n % blocksize > 0 else 0
absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32) absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)
if out is None: if out is None:
...@@ -867,40 +901,30 @@ def quantize_blockwise( ...@@ -867,40 +901,30 @@ def quantize_blockwise(
if A.device.type != "cpu": if A.device.type != "cpu":
assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
cblocksize = ct.c_int32(blocksize)
prev_device = pre_call(A.device)
code = code.to(A.device) code = code.to(A.device)
is_on_gpu([code, A, out, absmax])
if A.dtype == torch.float32: is_on_gpu([A, out, absmax])
lib.cquantize_blockwise_fp32(
get_ptr(code), with _cuda_device_of(A):
get_ptr(A), args = (
get_ptr(absmax),
get_ptr(out),
cblocksize,
ct.c_int(A.numel()),
)
elif A.dtype == torch.float16:
lib.cquantize_blockwise_fp16(
get_ptr(code), get_ptr(code),
get_ptr(A), get_ptr(A),
get_ptr(absmax), get_ptr(absmax),
get_ptr(out), get_ptr(out),
cblocksize, ct.c_int32(blocksize),
ct.c_int(A.numel()), ct.c_int(A.numel()),
) )
if A.dtype == torch.float16:
lib.cquantize_blockwise_fp16(*args)
elif A.dtype == torch.bfloat16: elif A.dtype == torch.bfloat16:
lib.cquantize_blockwise_bf16( lib.cquantize_blockwise_bf16(*args)
get_ptr(code), elif A.dtype == torch.float32:
get_ptr(A), lib.cquantize_blockwise_fp32(*args)
get_ptr(absmax),
get_ptr(out),
cblocksize,
ct.c_int(A.numel()),
)
else: else:
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
post_call(A.device)
else: else:
# cpu # cpu
code = code.cpu() code = code.cpu()
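The new block count `-(n // -blocksize)` is a ceiling division expressed with floor division; a quick check of the identity (my own illustration, not part of the diff):

# ceil(n / b) written three equivalent ways for positive integers.
for n, b in [(4096, 4096), (4097, 4096), (1, 64), (65, 64)]:
    assert -(n // -b) == (n + b - 1) // b == -(-n // b)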
...@@ -932,39 +956,46 @@ def quantize_blockwise( ...@@ -932,39 +956,46 @@ def quantize_blockwise(
def dequantize_blockwise( def dequantize_blockwise(
A: Tensor, A: torch.Tensor,
quant_state: Optional[QuantState] = None, quant_state: Optional[QuantState] = None,
absmax: Optional[torch.Tensor] = None, absmax: Optional[torch.Tensor] = None,
code: Optional[torch.Tensor] = None, code: Optional[torch.Tensor] = None,
out: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None,
blocksize: int = 4096, blocksize: int = 4096,
nested=False, nested=False,
) -> Tensor: ) -> torch.Tensor:
"""Dequantize a tensor in blocks of values.
The input tensor is dequantized by dividing it into blocks of `blocksize` values.
Then the absolute maximum value within these blocks is used for scaling
the non-linear dequantization.
Args:
A (`torch.Tensor`): The quantized input tensor.
quant_state ([`QuantState`], *optional*):
The quantization state as returned by [`quantize_blockwise`].
Required if `absmax` is not provided.
absmax (`torch.Tensor`, *optional*):
A tensor containing the scaling values.
Required if `quant_state` is not provided and ignored otherwise.
code (`torch.Tensor`, *optional*):
A mapping describing the low-bit data type. Defaults to a signed 8-bit dynamic type.
For more details, see [8-Bit Approximations for Parallelism in Deep Learning](https://arxiv.org/abs/1511.04561).
Ignored when `quant_state` is provided.
out (`torch.Tensor`, *optional*): A tensor to use to store the result.
blocksize (`int`, *optional*):
The size of the blocks. Defaults to 4096.
Valid values are 64, 128, 256, 512, 1024, 2048, and 4096.
Ignored when `quant_state` is provided.
Raises:
ValueError: Raised when the input data type is not supported.
Returns:
`torch.Tensor`:
The dequantized tensor. The datatype is indicated by `quant_state.dtype` and defaults to `torch.float32`.
""" """
Dequantizes blockwise quantized values.
Dequantizes the tensor A with maximum absolute values absmax in
blocks of size 4096.
Parameters
----------
A : torch.Tensor
The input 8-bit tensor.
quant_state : QuantState
Object with code, absmax and other quantization state components.
absmax : torch.Tensor
The absmax values.
code : torch.Tensor
The quantization map.
out : torch.Tensor
Dequantized output tensor (default: float32)
Returns
-------
torch.Tensor:
Dequantized tensor (default: float32)
"""
assert quant_state is not None or absmax is not None assert quant_state is not None or absmax is not None
if code is None and quant_state is None: if code is None and quant_state is None:
if "dynamic" not in name2qmap: if "dynamic" not in name2qmap:
...@@ -985,47 +1016,33 @@ def dequantize_blockwise( ...@@ -985,47 +1016,33 @@ def dequantize_blockwise(
out = torch.empty(A.shape, dtype=quant_state.dtype, device=A.device) out = torch.empty(A.shape, dtype=quant_state.dtype, device=A.device)
if A.device.type != "cpu": if A.device.type != "cpu":
device = pre_call(A.device)
code = quant_state.code.to(A.device) code = quant_state.code.to(A.device)
if quant_state.blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]: if quant_state.blocksize not in [4096, 2048, 1024, 512, 256, 128, 64]:
raise ValueError( raise ValueError(
f"The blockwise of {quant_state.blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]", f"The blocksize of {quant_state.blocksize} is not supported. Supported values: [4096, 2048, 1024, 512, 256, 128, 64]",
) )
is_on_gpu([A, absmax, out]) is_on_gpu([A, absmax, out])
stream = get_tensor_stream(A)
if out.dtype == torch.float32: with _cuda_device_of(A):
lib.cdequantize_blockwise_fp32( args = (
get_ptr(quant_state.code),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int(quant_state.blocksize),
ct.c_int(A.numel()),
stream, # Used the _as_parameter_ attribute of torch.cuda.Stream, Similarly for the following
)
elif out.dtype == torch.float16:
lib.cdequantize_blockwise_fp16(
get_ptr(quant_state.code), get_ptr(quant_state.code),
get_ptr(A), get_ptr(A),
get_ptr(absmax), get_ptr(absmax),
get_ptr(out), get_ptr(out),
ct.c_int(quant_state.blocksize), ct.c_int(quant_state.blocksize),
ct.c_int(A.numel()), ct.c_int(A.numel()),
stream, _get_tensor_stream(A),
) )
if out.dtype == torch.float16:
lib.cdequantize_blockwise_fp16(*args)
elif out.dtype == torch.bfloat16: elif out.dtype == torch.bfloat16:
lib.cdequantize_blockwise_bf16( lib.cdequantize_blockwise_bf16(*args)
get_ptr(quant_state.code), elif out.dtype == torch.float32:
get_ptr(A), lib.cdequantize_blockwise_fp32(*args)
get_ptr(absmax),
get_ptr(out),
ct.c_int(quant_state.blocksize),
ct.c_int(A.numel()),
stream,
)
else: else:
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}")
post_call(A.device)
else: else:
code = quant_state.code.cpu() code = quant_state.code.cpu()
lib.cdequantize_blockwise_cpu_fp32( lib.cdequantize_blockwise_cpu_fp32(
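A hedged round trip through the blockwise API documented above (assumes a CUDA device):

import torch
import bitsandbytes.functional as F

x = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
x_q, state = F.quantize_blockwise(x, blocksize=256)      # 8-bit codes + QuantState
x_hat = F.dequantize_blockwise(x_q, quant_state=state)   # dtype follows state.dtype (fp16 here)
print((x - x_hat).abs().mean())                          # small blockwise quantization error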
...@@ -1123,7 +1140,7 @@ def get_4bit_type(typename, device=None, blocksize=64): ...@@ -1123,7 +1140,7 @@ def get_4bit_type(typename, device=None, blocksize=64):
def quantize_fp4( def quantize_fp4(
A: Tensor, A: torch.Tensor,
absmax: Optional[torch.Tensor] = None, absmax: Optional[torch.Tensor] = None,
out: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None,
blocksize=64, blocksize=64,
...@@ -1134,7 +1151,7 @@ def quantize_fp4( ...@@ -1134,7 +1151,7 @@ def quantize_fp4(
def quantize_nf4( def quantize_nf4(
A: Tensor, A: torch.Tensor,
absmax: Optional[torch.Tensor] = None, absmax: Optional[torch.Tensor] = None,
out: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None,
blocksize=64, blocksize=64,
...@@ -1145,39 +1162,38 @@ def quantize_nf4( ...@@ -1145,39 +1162,38 @@ def quantize_nf4(
def quantize_4bit( def quantize_4bit(
A: Tensor, A: torch.Tensor,
absmax: Optional[torch.Tensor] = None, absmax: Optional[torch.Tensor] = None,
out: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None,
blocksize=64, blocksize=64,
compress_statistics=False, compress_statistics=False,
quant_type="fp4", quant_type="fp4",
quant_storage=torch.uint8, quant_storage=torch.uint8,
) -> Tuple[Tensor, QuantState]: ) -> Tuple[torch.Tensor, QuantState]:
"""Quantize tensor A in blocks of 4-bit values.
Quantizes tensor A by dividing it into blocks which are independently quantized.
Args:
A (`torch.Tensor`): The input tensor. Supports `float16`, `bfloat16`, or `float32` datatypes.
absmax (`torch.Tensor`, *optional*): A tensor to use to store the absmax values.
out (`torch.Tensor`, *optional*): A tensor to use to store the result.
blocksize (`int`, *optional*):
The size of the blocks. Defaults to 64.
Valid values are 64, 128, 256, 512, 1024, 2048, and 4096.
compress_statistics (`bool`, *optional*): Whether to additionally quantize the absmax values. Defaults to False.
quant_type (`str`, *optional*): The data type to use: `nf4` or `fp4`. Defaults to `fp4`.
quant_storage (`torch.dtype`, *optional*): The dtype of the tensor used to store the result. Defaults to `torch.uint8`.
Raises:
ValueError: Raised when the input data type is not supported.
Returns:
Tuple[`torch.Tensor`, `QuantState`]: A tuple containing the quantization results.
- `torch.Tensor`: The quantized tensor with packed 4-bit values.
- [`QuantState`]: The state object used to undo the quantization.
""" """
Quantize tensor A in blocks of 4-bit values.
Quantizes tensor A by dividing it into blocks which are independently quantized to FP4.
Parameters
----------
A : torch.Tensor
The input tensor.
absmax : torch.Tensor
The absmax values.
out : torch.Tensor
The output tensor.
blocksize : int
The blocksize used in quantization.
quant_type : str
The 4-bit quantization data type {fp4, nf4}
Returns
-------
torch.Tensor:
Tensor with packed 4-bit values.
tuple(torch.Tensor, torch.Size, torch.dtype, int):
The quantization state to undo the quantization.
"""
if A.device.type != "cuda": if A.device.type != "cuda":
raise NotImplementedError(f"Device type not supported for FP4 quantization: {A.device.type}") raise NotImplementedError(f"Device type not supported for FP4 quantization: {A.device.type}")
if quant_type not in ["fp4", "nf4"]: if quant_type not in ["fp4", "nf4"]:
...@@ -1187,8 +1203,7 @@ def quantize_4bit( ...@@ -1187,8 +1203,7 @@ def quantize_4bit(
input_shape = A.shape input_shape = A.shape
if absmax is None: if absmax is None:
blocks = n // blocksize blocks = -(n // -blocksize)
blocks += 1 if n % blocksize > 0 else 0
absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32) absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)
if out is None: if out is None:
...@@ -1197,68 +1212,35 @@ def quantize_4bit( ...@@ -1197,68 +1212,35 @@ def quantize_4bit(
assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
prev_device = pre_call(A.device)
is_on_gpu([A, out, absmax]) is_on_gpu([A, out, absmax])
if A.dtype == torch.float32:
if quant_type == "fp4": with _cuda_device_of(A):
lib.cquantize_blockwise_fp32_fp4( args = (
get_ptr(None), None,
get_ptr(A), get_ptr(A),
get_ptr(absmax), get_ptr(absmax),
get_ptr(out), get_ptr(out),
ct.c_int32(blocksize), ct.c_int32(blocksize),
ct.c_int(n), ct.c_int(n),
) )
if A.dtype == torch.bfloat16:
if quant_type == "fp4":
lib.cquantize_blockwise_bf16_fp4(*args)
else: else:
lib.cquantize_blockwise_fp32_nf4( lib.cquantize_blockwise_bf16_nf4(*args)
get_ptr(None),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int32(blocksize),
ct.c_int(n),
)
elif A.dtype == torch.float16: elif A.dtype == torch.float16:
if quant_type == "fp4": if quant_type == "fp4":
lib.cquantize_blockwise_fp16_fp4( lib.cquantize_blockwise_fp16_fp4(*args)
get_ptr(None),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int32(blocksize),
ct.c_int(n),
)
else: else:
lib.cquantize_blockwise_fp16_nf4( lib.cquantize_blockwise_fp16_nf4(*args)
get_ptr(None), elif A.dtype == torch.float32:
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int32(blocksize),
ct.c_int(n),
)
elif A.dtype == torch.bfloat16:
if quant_type == "fp4": if quant_type == "fp4":
lib.cquantize_blockwise_bf16_fp4( lib.cquantize_blockwise_fp32_fp4(*args)
get_ptr(None),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int32(blocksize),
ct.c_int(n),
)
else: else:
lib.cquantize_blockwise_bf16_nf4( lib.cquantize_blockwise_fp32_nf4(*args)
get_ptr(None),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int32(blocksize),
ct.c_int(n),
)
else: else:
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
post_call(A.device)
code = get_4bit_type(quant_type, device=A.device) code = get_4bit_type(quant_type, device=A.device)
...@@ -1291,59 +1273,60 @@ def quantize_4bit( ...@@ -1291,59 +1273,60 @@ def quantize_4bit(
def dequantize_fp4( def dequantize_fp4(
A: Tensor, A: torch.Tensor,
quant_state: Optional[QuantState] = None, quant_state: Optional[QuantState] = None,
absmax: Optional[torch.Tensor] = None, absmax: Optional[torch.Tensor] = None,
out: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None,
blocksize: int = 64, blocksize: int = 64,
) -> Tensor: ) -> torch.Tensor:
return dequantize_4bit(A, quant_state, absmax, out, blocksize, "fp4") return dequantize_4bit(A, quant_state, absmax, out, blocksize, "fp4")
def dequantize_nf4( def dequantize_nf4(
A: Tensor, A: torch.Tensor,
quant_state: Optional[QuantState] = None, quant_state: Optional[QuantState] = None,
absmax: Optional[torch.Tensor] = None, absmax: Optional[torch.Tensor] = None,
out: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None,
blocksize: int = 64, blocksize: int = 64,
) -> Tensor: ) -> torch.Tensor:
return dequantize_4bit(A, quant_state, absmax, out, blocksize, "nf4") return dequantize_4bit(A, quant_state, absmax, out, blocksize, "nf4")
def dequantize_4bit( def dequantize_4bit(
A: Tensor, A: torch.Tensor,
quant_state: Optional[QuantState] = None, quant_state: Optional[QuantState] = None,
absmax: Optional[torch.Tensor] = None, absmax: Optional[torch.Tensor] = None,
out: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None,
blocksize: int = 64, blocksize: int = 64,
quant_type="fp4", quant_type="fp4",
) -> Tensor: ) -> torch.Tensor:
"""Dequantizes a packed 4-bit quantized tensor.
The input tensor is dequantized by dividing it into blocks of `blocksize` values.
Then the absolute maximum value within these blocks is used for scaling
the non-linear dequantization.
Args:
A (`torch.Tensor`): The quantized input tensor.
quant_state ([`QuantState`], *optional*):
The quantization state as returned by [`quantize_4bit`].
Required if `absmax` is not provided.
absmax (`torch.Tensor`, *optional*):
A tensor containing the scaling values.
Required if `quant_state` is not provided and ignored otherwise.
out (`torch.Tensor`, *optional*): A tensor to use to store the result.
blocksize (`int`, *optional*):
The size of the blocks. Defaults to 64.
Valid values are 64, 128, 256, 512, 1024, 2048, and 4096.
quant_type (`str`, *optional*): The data type to use: `nf4` or `fp4`. Defaults to `fp4`.
Raises:
ValueError: Raised when the input data type or blocksize is not supported.
Returns:
`torch.Tensor`: The dequantized tensor.
""" """
Dequantizes FP4 blockwise quantized values.
Dequantizes the tensor A with maximum absolute values absmax in blocks of size blocksize.
Parameters
----------
A : torch.Tensor
The input tensor (packed 4-bit values).
quant_state : QuantState
object with quantisation stats, incl. absmax values, original tensor shape and original dtype.
absmax : torch.Tensor
The absmax values.
out : torch.Tensor
Dequantized output tensor.
blocksize : int
The blocksize used in quantization.
quant_type : str
The 4-bit quantization data type {fp4, nf4}
Returns
-------
torch.Tensor:
Dequantized tensor.
"""
if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]: if blocksize not in [2048, 4096, 1024, 512, 256, 128, 64]:
raise ValueError( raise ValueError(
f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]", f"The blockwise of {blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]",
...@@ -1376,13 +1359,12 @@ def dequantize_4bit( ...@@ -1376,13 +1359,12 @@ def dequantize_4bit(
n = out.numel() n = out.numel()
device = pre_call(A.device)
is_on_gpu([A, absmax, out]) is_on_gpu([A, absmax, out])
stream = get_tensor_stream(A) stream = _get_tensor_stream(A)
if out.dtype == torch.float32:
if quant_state.quant_type == "fp4": with _cuda_device_of(A):
lib.cdequantize_blockwise_fp32_fp4( args = (
get_ptr(None), None,
get_ptr(A), get_ptr(A),
get_ptr(absmax), get_ptr(absmax),
get_ptr(out), get_ptr(out),
...@@ -1390,69 +1372,31 @@ def dequantize_4bit( ...@@ -1390,69 +1372,31 @@ def dequantize_4bit(
ct.c_int(n), ct.c_int(n),
stream, stream,
) )
if out.dtype == torch.bfloat16:
if quant_state.quant_type == "fp4":
lib.cdequantize_blockwise_bf16_fp4(*args)
else: else:
lib.cdequantize_blockwise_fp32_nf4( lib.cdequantize_blockwise_bf16_nf4(*args)
get_ptr(None),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int(quant_state.blocksize),
ct.c_int(n),
stream,
)
elif out.dtype == torch.float16: elif out.dtype == torch.float16:
if quant_state.quant_type == "fp4": if quant_state.quant_type == "fp4":
lib.cdequantize_blockwise_fp16_fp4( lib.cdequantize_blockwise_fp16_fp4(*args)
get_ptr(None),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int(quant_state.blocksize),
ct.c_int(n),
stream,
)
else: else:
lib.cdequantize_blockwise_fp16_nf4( lib.cdequantize_blockwise_fp16_nf4(*args)
get_ptr(None), elif out.dtype == torch.float32:
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int(quant_state.blocksize),
ct.c_int(n),
stream,
)
elif out.dtype == torch.bfloat16:
if quant_state.quant_type == "fp4": if quant_state.quant_type == "fp4":
lib.cdequantize_blockwise_bf16_fp4( lib.cdequantize_blockwise_fp32_fp4(*args)
get_ptr(None),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int(quant_state.blocksize),
ct.c_int(n),
stream,
)
else: else:
lib.cdequantize_blockwise_bf16_nf4( lib.cdequantize_blockwise_fp32_nf4(*args)
get_ptr(None),
get_ptr(A),
get_ptr(absmax),
get_ptr(out),
ct.c_int(quant_state.blocksize),
ct.c_int(n),
stream,
)
else: else:
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}")
post_call(A.device)
is_transposed = True if A.shape[0] == 1 else False if A.shape[0] == 1: # is transposed, transpose back
if is_transposed:
return out.t() return out.t()
else:
return out return out
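A hedged NF4 round trip through the 4-bit API above (assumes a CUDA device):

import torch
import bitsandbytes.functional as F

w = torch.randn(256, 128, device="cuda", dtype=torch.float16)
packed, state = F.quantize_4bit(w, blocksize=64, quant_type="nf4")
w_hat = F.dequantize_4bit(packed, quant_state=state)

print(packed.dtype, packed.shape)   # uint8 storage holding two 4-bit values per byte
print((w - w_hat).abs().mean())     # small reconstruction error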
@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def quantize( def quantize(
A: Tensor, A: Tensor,
code: Optional[torch.Tensor] = None, code: Optional[torch.Tensor] = None,
...@@ -1472,6 +1416,7 @@ def quantize( ...@@ -1472,6 +1416,7 @@ def quantize(
return out, (absmax, code) return out, (absmax, code)
@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def dequantize( def dequantize(
A: Tensor, A: Tensor,
state: Optional[Tuple[Tensor, Tensor]] = None, state: Optional[Tuple[Tensor, Tensor]] = None,
...@@ -1492,6 +1437,7 @@ def dequantize( ...@@ -1492,6 +1437,7 @@ def dequantize(
return out * state[0] return out * state[0]
@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def quantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = None) -> Tensor: def quantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = None) -> Tensor:
""" """
Quantizes input tensor to 8-bit. Quantizes input tensor to 8-bit.
...@@ -1522,6 +1468,7 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = No ...@@ -1522,6 +1468,7 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = No
return out return out
@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def dequantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = None) -> Tensor: def dequantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = None) -> Tensor:
""" """
Dequantizes the 8-bit tensor to 32-bit. Dequantizes the 8-bit tensor to 32-bit.
...@@ -1547,7 +1494,7 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = ...@@ -1547,7 +1494,7 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] =
if out is None: if out is None:
out = torch.zeros_like(A, dtype=torch.float32) out = torch.zeros_like(A, dtype=torch.float32)
is_on_gpu([code, A, out]) is_on_gpu([code, A, out])
stream = get_tensor_stream(A) stream = _get_tensor_stream(A)
lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()), stream) lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()), stream)
post_call(prev_device) post_call(prev_device)
return out return out
...@@ -1632,7 +1579,8 @@ def optimizer_update_32bit( ...@@ -1632,7 +1579,8 @@ def optimizer_update_32bit(
) )
is_on_gpu([g, p, state1, state2, unorm_vec]) is_on_gpu([g, p, state1, state2, unorm_vec])
prev_device = pre_call(g.device)
with _cuda_device_of(g):
optim_func( optim_func(
get_ptr(g), get_ptr(g),
get_ptr(p), get_ptr(p),
...@@ -1653,9 +1601,13 @@ def optimizer_update_32bit( ...@@ -1653,9 +1601,13 @@ def optimizer_update_32bit(
ct.c_bool(skip_zeros), ct.c_bool(skip_zeros),
ct.c_int32(g.numel()), ct.c_int32(g.numel()),
) )
post_call(prev_device)
@deprecated(
"This function is deprecated and will be removed in a future release. "
"Please use optimizer_update_8bit_blockwise instead. ",
category=FutureWarning,
)
def optimizer_update_8bit( def optimizer_update_8bit(
optimizer_name: str, optimizer_name: str,
g: Tensor, g: Tensor,
...@@ -1811,8 +1763,7 @@ def optimizer_update_8bit_blockwise( ...@@ -1811,8 +1763,7 @@ def optimizer_update_8bit_blockwise(
skip_zeros=False, skip_zeros=False,
) -> None: ) -> None:
optim_func = None optim_func = None
prev_device = pre_call(g.device)
is_on_gpu([g, p, state1, state2, qmap1, qmap2, absmax1, absmax2])
if g.dtype == torch.float32 and state1.dtype == torch.uint8: if g.dtype == torch.float32 and state1.dtype == torch.uint8:
optim_func = str2optimizer8bit_blockwise[optimizer_name][0] optim_func = str2optimizer8bit_blockwise[optimizer_name][0]
elif g.dtype == torch.float16 and state1.dtype == torch.uint8: elif g.dtype == torch.float16 and state1.dtype == torch.uint8:
...@@ -1827,11 +1778,10 @@ def optimizer_update_8bit_blockwise( ...@@ -1827,11 +1778,10 @@ def optimizer_update_8bit_blockwise(
raise ValueError( raise ValueError(
f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}", f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}",
) )
post_call(prev_device)
is_on_gpu([p, g, state1, state2, qmap1, qmap2, absmax1, absmax2]) is_on_gpu([p, g, state1, state2, qmap1, qmap2, absmax1, absmax2])
prev_device = pre_call(g.device) with _cuda_device_of(g):
optim_func( optim_func(
get_ptr(p), get_ptr(p),
get_ptr(g), get_ptr(g),
...@@ -1853,9 +1803,9 @@ def optimizer_update_8bit_blockwise( ...@@ -1853,9 +1803,9 @@ def optimizer_update_8bit_blockwise(
ct.c_bool(skip_zeros), ct.c_bool(skip_zeros),
ct.c_int32(g.numel()), ct.c_int32(g.numel()),
) )
post_call(prev_device)
@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile: int = 5): def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile: int = 5):
"""Applies percentile clipping """Applies percentile clipping
...@@ -2008,10 +1958,9 @@ def gemv_4bit( ...@@ -2008,10 +1958,9 @@ def gemv_4bit(
transposed_B=False, transposed_B=False,
state=None, state=None,
): ):
prev_device = pre_call(A.device)
# sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype) # sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype)
if state is None: if state is None:
raise ValueError("state cannot None. gem_4bit( ) requires the state from quantize_4bit( )") raise ValueError("state cannot be None. gemv_4bit() requires the state from quantize_4bit()")
if A.numel() != A.shape[-1]: if A.numel() != A.shape[-1]:
raise ValueError( raise ValueError(
...@@ -2044,7 +1993,9 @@ def gemv_4bit( ...@@ -2044,7 +1993,9 @@ def gemv_4bit(
lda = ct.c_int32(lda) lda = ct.c_int32(lda)
ldb = ct.c_int32(ldb) ldb = ct.c_int32(ldb)
ldc = ct.c_int32(ldc) ldc = ct.c_int32(ldc)
stream = get_tensor_stream(A) stream = _get_tensor_stream(A)
with _cuda_device_of(A):
if B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32]: if B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32]:
if A.dtype == torch.float16: if A.dtype == torch.float16:
lib.cgemm_4bit_inference_naive_fp16( lib.cgemm_4bit_inference_naive_fp16(
...@@ -2100,8 +2051,6 @@ def gemv_4bit( ...@@ -2100,8 +2051,6 @@ def gemv_4bit(
else: else:
raise NotImplementedError(f"Matmul not implemented for data type {A.dtype}") raise NotImplementedError(f"Matmul not implemented for data type {A.dtype}")
post_call(prev_device)
return out return out
...@@ -2302,179 +2251,288 @@ def batched_igemm( ...@@ -2302,179 +2251,288 @@ def batched_igemm(
return out return out
def igemmlt(A, B, SA, SB, out=None, Sout=None, dtype=torch.int32): @deprecated(
shapeA = SA[0] "igemmlt is deprecated and will be removed in a future release. Please use int8_linear_matmul instead.",
shapeB = SB[0] category=FutureWarning,
dimsA = len(shapeA) )
dimsB = len(shapeB) def igemmlt(
assert dimsB == 2, "Only two dimensional matrices are supported for argument B" A: torch.Tensor,
if dimsA == 2: B: torch.Tensor,
m = shapeA[0] SA: Tuple[torch.Size, str],
elif dimsA == 3: SB: Tuple[torch.Size, str],
m = shapeA[0] * shapeA[1] out: Optional[torch.Tensor] = None,
Sout: Optional[Tuple[torch.Size, str]] = None,
rows = n = shapeB[0] dtype=torch.int32,
assert prod(list(shapeA)) > 0, f"Input tensor dimensions need to be > 0: {shapeA}" ):
if SA is not None and SA[1] != "row":
# if the tensor is empty, return a transformed empty tensor with the right dimensions raise NotImplementedError(f"Only row-major format inputs are supported, but got format `{SA[1]}`")
if shapeA[0] == 0 and dimsA == 2: if SB is not None and SB[1] != "row":
return torch.empty((0, shapeB[0]), device=A.device, dtype=torch.float16) raise NotImplementedError(f"Only row-major format is supported for matrix B, but got format `{SB[1]}`")
elif shapeA[1] == 0 and dimsA == 3: result = int8_linear_matmul(A, B, out=out, dtype=dtype)
return torch.empty(tuple(shapeA[:2] + [shapeB[0]]), device=A.device, dtype=torch.float16) return result, (result.shape, "row")
if dimsA == 2 and out is None:
out, Sout = get_transform_buffer((shapeA[0], shapeB[0]), dtype, A.device, "col32", "row") def int8_linear_matmul(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None, dtype=torch.int32):
elif dimsA == 3 and out is None: """Performs an 8-bit integer matrix multiplication.
out, Sout = get_transform_buffer((shapeA[0], shapeA[1], shapeB[0]), dtype, A.device, "col32", "row")
A linear transformation is applied such that `out = A @ B.T`. When possible, integer tensor core hardware is
assert dimsB != 3, "len(B.shape)==3 not supported" utilized to accelerate the operation.
assert A.device.type == "cuda"
assert B.device.type == "cuda" Args:
A (`torch.Tensor`): The first matrix operand with the data type `torch.int8`.
B (`torch.Tensor`): The second matrix operand with the data type `torch.int8`.
out (`torch.Tensor`, *optional*): A pre-allocated tensor used to store the result.
dtype (`torch.dtype`, *optional*): The expected data type of the output. Defaults to `torch.int32`.
Raises:
`NotImplementedError`: The operation is not supported in the current environment.
`RuntimeError`: Raised when the operation cannot be completed for any other reason.
Returns:
`torch.Tensor`: The result of the operation.
"""
#
# To use the IMMA tensor core kernels without special Turing/Ampere layouts,
# cublasLt has some rules, namely: A must be transposed, B must not be transposed.
# The C++ API will calculate `C = A.T @ B` with A, B, C in col-major.
# This will typically be used with row-major tensors to efficiently
# calculate the linear layer with `C = B @ A.T` without any transformations.
# We will swap A and B in the API invocation, so that we get `C = A @ B.T`.
#
# Quick explanation:
# With row-major A and B tensors, `C = A.T.T @ B.T = A @ B.T`.
# To get row-major output, `C.T = (A @ B.T).T = B @ A.T`.
#
A, B = B, A
shapeA = A.shape
shapeB = B.shape
assert A.dtype == torch.int8 assert A.dtype == torch.int8
assert B.dtype == torch.int8 assert B.dtype == torch.int8
assert out.dtype == dtype assert A.ndim == 2, "Only two dimensional matrices are supported for argument B"
assert SA[1] == "col32" assert B.ndim in [2, 3], "Only two or three dimensional matrices are supported for argument A"
assert SB[1] in ["col_turing", "col_ampere"] assert prod(shapeB) > 0, f"Input tensor dimensions need to be > 0: {shapeB}"
assert Sout[1] == "col32" assert out is None or out.dtype == dtype
shapeC = (*shapeB[:-1], shapeA[0])
k, m = shapeA
n = prod(shapeB[:-1])
lda = shapeA[-1] # Weights (outputs, inputs)
ldb = shapeB[-1] # Activations (batch, tokens, inputs)
ldc = shapeC[-1] # Output (batch, tokens, outputs)
assert ( assert (
shapeA[-1] == shapeB[-1] lda == ldb
), f"Matmullt only supports A @ B^T. Inner matrix dimensions do not match: A @ B = {shapeA} @ {shapeB}" ), f"int8_linear_matmul only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}"
formatB = SB[1]
prev_device = A.device # cuBLASLt does not support int8 matmul with inner dimensions that are not divisible by 4.
torch.cuda.set_device(A.device) # We'll fall back to a slower fp32 calculation in this circumstance.
# Fortunately, this should not be very common.
if lda % 4 != 0:
result = torch.matmul(B.float(), A.float().t()).to(torch.int32)
if out is not None:
result = out.copy_(result)
return result
ptr = CUBLAS_Context.get_instance().get_context(A.device) if out is None:
out = torch.empty(shapeC, device=A.device, dtype=dtype)
is_on_gpu([A, B, out])
with _cuda_device_of(A):
ctx = CUBLAS_Context.get_instance().get_context(A.device)
ptrA = get_ptr(A) ptrA = get_ptr(A)
ptrB = get_ptr(B) ptrB = get_ptr(B)
ptrC = get_ptr(out) ptrC = get_ptr(out)
ptrRowScale = None
k = shapeA[-1]
lda = ct.c_int32(m * 32)
if formatB == "col_turing":
# turing: tiles with rows filled up to multiple of 8 rows by 32 columns
# n = rows
ldb = ct.c_int32(((rows + 7) // 8) * 8 * 32)
else:
# ampere: tiles with rows filled up to multiple of 32 rows by 32 columns
# n = rows
ldb = ct.c_int32(((rows + 31) // 32) * 32 * 32)
ldc = ct.c_int32(m * 32)
m = ct.c_int32(m) m = ct.c_int32(m)
n = ct.c_int32(n) n = ct.c_int32(n)
k = ct.c_int32(k) k = ct.c_int32(k)
lda = ct.c_int32(lda)
ldb = ct.c_int32(ldb)
ldc = ct.c_int32(ldc)
stream = _get_tensor_stream(A)
has_error = 0
ptrRowScale = get_ptr(None)
is_on_gpu([A, B, out])
if formatB == "col_turing":
if dtype == torch.int32: if dtype == torch.int32:
has_error = lib.cigemmlt_turing_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) has_error = lib.cigemmlt_32(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream)
else: else:
has_error = lib.cigemmlt_turing_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc) has_error = lib.cigemmlt_8(ctx, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc, stream)
elif formatB == "col_ampere":
if dtype == torch.int32:
has_error = lib.cigemmlt_ampere_32(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
else:
has_error = lib.cigemmlt_ampere_8(ptr, m, n, k, ptrA, ptrB, ptrC, ptrRowScale, lda, ldb, ldc)
if has_error == 100: # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu` if has_error == 100: # `ERR_NOT_IMPLEMENTED` is defined as 100 in `ops.cu`
raise NotImplementedError("igemmlt not available (probably built with NO_CUBLASLT)") raise NotImplementedError("int8_linear_matmul not implemented!")
if has_error: if has_error:
print(f"A: {shapeA}, B: {shapeB}, C: {Sout[0]}; (lda, ldb, ldc): {(lda, ldb, ldc)}; (m, n, k): {(m, n, k)}") raise RuntimeError(
raise Exception("cublasLt ran into an error!") f"cublasLt ran into an error!\n"
f"\t{shapeA=}, {shapeB=}, {shapeC=}\n"
f"\t{(lda, ldb, ldc)=}\n"
f"\t{(m, n, k)=}"
)
return out
torch.cuda.set_device(prev_device)
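A hedged usage sketch for the new int8 matmul; it assumes a GPU with int8 tensor core support (compute capability 7.5+ per the diagnostics above), and the int32 result should match a plain integer reference exactly.

import torch
import bitsandbytes.functional as F

A = torch.randint(-128, 128, (16, 64), dtype=torch.int8, device="cuda")
B = torch.randint(-128, 128, (32, 64), dtype=torch.int8, device="cuda")

C = F.int8_linear_matmul(A, B)   # int32 accumulator, C = A @ B.T
assert C.shape == (16, 32)
assert torch.equal(C, A.int() @ B.int().t())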
return out, Sout def int8_mm_dequant(
A: torch.Tensor,
row_stats: torch.Tensor,
col_stats: torch.Tensor,
out: Optional[torch.Tensor] = None,
bias: Optional[torch.Tensor] = None,
):
"""Performs dequantization on the result of a quantized int8 matrix multiplication.
Args:
A (`torch.Tensor` with dtype `torch.int32`): The result of a quantized int8 matrix multiplication.
row_stats (`torch.Tensor`): The row-wise quantization statistics for the lhs operand of the matrix multiplication.
col_stats (`torch.Tensor`): The column-wise quantization statistics for the rhs operand of the matrix multiplication.
out (`torch.Tensor`, *optional*): A pre-allocated tensor to store the output of the operation.
bias (`torch.Tensor`, *optional*): An optional bias vector to add to the result.
Returns:
`torch.Tensor`: The dequantized result with an optional bias, with dtype `torch.float16`.
"""
def mm_dequant(A, quant_state, row_stats, col_stats, out=None, new_row_stats=None, new_col_stats=None, bias=None):
assert A.dtype == torch.int32 assert A.dtype == torch.int32
if bias is not None: if bias is not None:
assert bias.dtype == torch.float16 assert bias.dtype == torch.float16
out_shape = quant_state[0]
if len(out_shape) == 3:
out_shape = (out_shape[0] * out_shape[1], out_shape[2])
if out is None: if out is None:
out = torch.empty(out_shape, dtype=torch.float16, device=A.device) out = torch.empty_like(A, dtype=torch.float16)
if new_row_stats is None:
new_row_stats = torch.empty(out_shape[0], dtype=torch.float32, device=A.device)
if new_col_stats is None:
new_col_stats = torch.empty(out_shape[1], dtype=torch.float32, device=A.device)
assert new_row_stats.shape[0] == row_stats.shape[0], f"{new_row_stats.shape} vs {row_stats.shape}"
assert new_col_stats.shape[0] == col_stats.shape[0], f"{new_col_stats.shape} vs {col_stats.shape}"
prev_device = pre_call(A.device)
ptrA = get_ptr(A) ptrA = get_ptr(A)
ptrOut = get_ptr(out) ptrOut = get_ptr(out)
ptrRowStats = get_ptr(row_stats) ptrRowStats = get_ptr(row_stats)
ptrColStats = get_ptr(col_stats) ptrColStats = get_ptr(col_stats)
ptrNewRowStats = get_ptr(new_row_stats)
ptrNewColStats = get_ptr(new_col_stats)
ptrBias = get_ptr(bias) ptrBias = get_ptr(bias)
numRows = ct.c_int32(out_shape[0]) numRows = ct.c_int32(prod(A.shape[:-1]))
numCols = ct.c_int32(out_shape[1]) numCols = ct.c_int32(A.shape[-1])
is_on_gpu([A, row_stats, col_stats, out, new_row_stats, new_col_stats, bias]) is_on_gpu([A, row_stats, col_stats, out, bias])
with _cuda_device_of(A):
lib.cdequant_mm_int32_fp16( lib.cdequant_mm_int32_fp16(
ptrA, ptrA, ptrRowStats, ptrColStats, ptrOut, ptrBias, numRows, numCols, _get_tensor_stream(A)
ptrRowStats,
ptrColStats,
ptrOut,
ptrNewRowStats,
ptrNewColStats,
ptrBias,
numRows,
numCols,
) )
post_call(prev_device)
return out return out
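A hedged sketch of how these pieces are expected to chain together for int8 inference without sparse decomposition: quantize activations and weights row-wise, run the int8 matmul, then dequantize the int32 accumulator with both sets of scales. The function names are from this diff; the wiring is my own minimal example.

import torch
import bitsandbytes.functional as F

x = torch.randn(8, 128, device="cuda", dtype=torch.float16)    # activations
w = torch.randn(256, 128, device="cuda", dtype=torch.float16)  # weights (out, in)

x_i8, x_scale, _ = F.int8_vectorwise_quant(x)
w_i8, w_scale, _ = F.int8_vectorwise_quant(w)

acc = F.int8_linear_matmul(x_i8, w_i8)          # int32, shape (8, 256)
y = F.int8_mm_dequant(acc, x_scale, w_scale)    # float16, approximately x @ w.t()
print((y.float() - x.float() @ w.float().t()).abs().mean())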
def get_colrow_absmax(A, row_stats=None, col_stats=None, nnz_block_ptr=None, threshold=0.0): @deprecated("mm_dequant is deprecated. Please use int8_mm_dequant() instead.", category=FutureWarning)
assert A.dtype == torch.float16 def mm_dequant(
device = A.device A: torch.Tensor,
quant_state: Optional[Tuple[torch.Size, str]], # Not used
row_stats: torch.Tensor,
col_stats: torch.Tensor,
out: Optional[torch.Tensor] = None,
new_row_stats=None, # Not used
new_col_stats=None, # Not used
bias: Optional[torch.Tensor] = None,
):
return int8_mm_dequant(A, row_stats, col_stats, out, bias)
def get_colrow_absmax(
A: torch.Tensor,
row_stats: Optional[torch.Tensor] = None,
col_stats: Optional[torch.Tensor] = None,
nnz_block_ptr: Optional[torch.Tensor] = None,
threshold=0.0,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
""" "Determine the quantization statistics for input matrix `A` in accordance to the `LLM.int8()` algorithm.
The row-wise and column-wise absmax values are determined.
For more information, see the [LLM.int8() paper](https://arxiv.org/abs/2208.07339).
<Tip>
This function is useful for training, but for inference it is advised to use [`get_row_absmax`] instead.
The column-wise quantization scales are not typically needed in inference scenarios.
</Tip>
Args:
A (`torch.Tensor` with dtype `torch.float16`): Input tensor.
row_stats (`torch.Tensor`, *optional*): If provided, calculation of row statistics is skipped.
col_stats (`torch.Tensor`, *optional*): If provided, calculation of column statistics is skipped.
nnz_block_ptr (`torch.Tensor`, *optional*): Not used.
threshold (`float`, *optional*):
An optional threshold for sparse decomposition of outlier features.
No outliers are held back when 0.0. Defaults to 0.0.
Returns:
`Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]`: A tuple containing quantization statistics.
- `torch.Tensor` with dtype `torch.float32`: The row-wise quantization statistics.
- `torch.Tensor` with dtype `torch.float32`: The column-wise quantization statistics.
- `torch.Tensor` with dtype `torch.bool`, *optional*: A mask indicating the locations of outliers in the input tensor.
"""
assert A.is_floating_point()
cols = A.shape[-1] outlier_mask = None
if len(A.shape) == 3:
rows = A.shape[0] * A.shape[1] if row_stats is None or col_stats is None:
else: absA = A.abs().view(-1, A.shape[-1])
rows = A.shape[0]
if threshold > 0.0:
# Filter outliers from stats when enabled
outlier_mask = absA >= threshold
absA.masked_fill_(outlier_mask, 0.0)
col_tiles = (cols + 255) // 256
tiled_rows = ((rows + 15) // 16) * 16
if row_stats is None: if row_stats is None:
row_stats = torch.empty((rows,), dtype=torch.float32, device=device).fill_(-50000.0) # shape [rows]; unsqueeze(-1) gives [rows,1]
# We have a CUDA kernel for row max, but not yet for cols.
row_stats = get_row_absmax(A, threshold)
if col_stats is None: if col_stats is None:
col_stats = torch.empty((cols,), dtype=torch.float32, device=device).fill_(-50000.0) # shape [cols]; unsqueeze(0) gives [1,cols]
col_stats = absA.amax(dim=0, keepdim=False).float()
if nnz_block_ptr is None and threshold > 0.0: return row_stats, col_stats, outlier_mask
nnz_block_ptr = torch.zeros(((tiled_rows * col_tiles) + 1,), dtype=torch.int32, device=device)
ptrA = get_ptr(A)
ptrRowStats = get_ptr(row_stats)
ptrColStats = get_ptr(col_stats)
ptrNnzrows = get_ptr(nnz_block_ptr)
rows = ct.c_int32(rows)
cols = ct.c_int32(cols)
prev_device = pre_call(A.device) def get_row_absmax(A: torch.Tensor, threshold=0.0):
is_on_gpu([A, row_stats, col_stats, nnz_block_ptr]) """Determine the quantization statistics for input matrix `A` in accordance to the `LLM.int8()` algorithm.
lib.cget_col_row_stats(ptrA, ptrRowStats, ptrColStats, ptrNnzrows, ct.c_float(threshold), rows, cols)
post_call(prev_device)
if threshold > 0.0: For more information, see the [LLM.int8() paper](https://arxiv.org/abs/2208.07339).
nnz_block_ptr.cumsum_(0)
Args:
A (`torch.Tensor` with dtype `torch.float16`): The input matrix.
threshold (`float`, *optional*):
An optional threshold for sparse decomposition of outlier features.
No outliers are held back when 0.0. Defaults to 0.0.
Returns:
`torch.Tensor` with dtype `torch.float32`: The absolute maximum value for each row, with outliers ignored.
"""
assert A.dtype == torch.float16
rows = prod(A.shape[:-1])
cols = A.shape[-1]
row_stats = torch.empty((rows,), dtype=torch.float32, device=A.device)
is_on_gpu([A])
with _cuda_device_of(A):
lib.cget_row_stats(
get_ptr(A),
get_ptr(row_stats),
ct.c_float(threshold),
ct.c_int32(rows),
ct.c_int32(cols),
_get_tensor_stream(A),
)
return row_stats, col_stats, nnz_block_ptr return row_stats
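A pure-PyTorch illustration (mine, not the CUDA kernel) of what the row/column statistics and outlier mask represent for `threshold=6.0`:

import torch

A = torch.tensor([[1.0, -2.0, 8.0],
                  [0.5,  7.0, 3.0]], dtype=torch.float16)

absA = A.abs()
outlier_mask = absA >= 6.0                    # entries excluded from the scales
masked = absA.masked_fill(outlier_mask, 0.0)
row_stats = masked.amax(dim=1).float()        # tensor([2., 3.])
col_stats = masked.amax(dim=0).float()        # tensor([1., 2., 3.])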
class COOSparseTensor: class COOSparseTensor:
def __init__(self, rows, cols, nnz, rowidx, colidx, values): def __init__(
self, rows: int, cols: int, nnz: int, rowidx: torch.Tensor, colidx: torch.Tensor, values: torch.Tensor
):
assert rowidx.dtype == torch.int32 assert rowidx.dtype == torch.int32
assert colidx.dtype == torch.int32 assert colidx.dtype == torch.int32
assert values.dtype == torch.float16 assert values.dtype == torch.float16
...@@ -2552,96 +2610,204 @@ def coo_zeros(rows, cols, nnz, device, dtype=torch.half): ...@@ -2552,96 +2610,204 @@ def coo_zeros(rows, cols, nnz, device, dtype=torch.half):
return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values) return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values)
def double_quant(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0): @deprecated("This function is deprecated. Please use `int8_double_quant` instead.", category=FutureWarning)
device = A.device def double_quant(
A: torch.Tensor,
col_stats: Optional[torch.Tensor] = None,
row_stats: Optional[torch.Tensor] = None,
out_col: Optional[torch.Tensor] = None,
out_row: Optional[torch.Tensor] = None,
threshold=0.0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[COOSparseTensor]]:
"""Determine the quantization statistics for input matrix `A` in accordance to the `LLM.int8()` algorithm.
The statistics are determined both row-wise and column-wise (transposed).
For more information, see the [LLM.int8() paper](https://arxiv.org/abs/2208.07339).
<Tip warning={true}>
This function exists for backwards compatibility only. It is advised to use [`int8_double_quant`] instead.
The difference is that this function will return a [`COOSparseTensor`] for outliers instead of a column index.
</Tip>
Args:
A (`torch.Tensor` with dtype `torch.float16`): The input matrix.
col_stats (`torch.Tensor`, *optional*): A pre-allocated tensor to hold the column-wise quantization scales.
row_stats (`torch.Tensor`, *optional*): A pre-allocated tensor to hold the row-wise quantization scales.
out_col (`torch.Tensor`, *optional*): A pre-allocated tensor to hold the column-wise quantized data.
out_row (`torch.Tensor`, *optional*): A pre-allocated tensor to hold the row-wise quantized data.
threshold (`float`, *optional*):
An optional threshold for sparse decomposition of outlier features.
No outliers are held back when 0.0. Defaults to 0.0.
Returns:
`Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]`: A tuple containing the quantized tensor and relevant statistics.
- `torch.Tensor` with dtype `torch.int8`: The row-wise quantized data.
- `torch.Tensor` with dtype `torch.int8`: The column-wise quantized data.
- `torch.Tensor` with dtype `torch.float32`: The row-wise quantization scales.
- `torch.Tensor` with dtype `torch.float32`: The column-wise quantization scales.
- `COOSparseTensor`, *optional*: A structure representing the outlier values from the input tensor.
"""
coo_tensor = None
quant_row, quant_col, row_stats, col_stats, outlier_cols = int8_double_quant(
A,
col_stats,
row_stats,
out_col,
out_row,
threshold=threshold,
)
if threshold > 0.0 and outlier_cols is not None:
# Build a COO tensor including all of the outlier columns.
outlier_rows = torch.arange(0, A.shape[0], device=A.device, dtype=torch.int32)
outliers = A[:, outlier_cols]
coo_tensor = COOSparseTensor(
A.shape[0],
A.shape[1],
outliers.numel(),
outlier_rows.repeat_interleave(outliers.size(1)),
outlier_cols.repeat(outliers.size(0)).int(),
outliers,
)
return quant_row, quant_col, row_stats, col_stats.flatten().float(), coo_tensor
def int8_double_quant(
A: torch.Tensor,
col_stats: Optional[torch.Tensor] = None,
row_stats: Optional[torch.Tensor] = None,
out_col: Optional[torch.Tensor] = None,
out_row: Optional[torch.Tensor] = None,
threshold=0.0,
):
"""Determine the quantization statistics for input matrix `A` in accordance to the `LLM.int8()` algorithm.
The statistics are determined both row-wise and column-wise (transposed).
For more information, see the [LLM.int8() paper](https://arxiv.org/abs/2208.07339).
<Tip>
This function is useful for training, but for inference it is advised to use [`int8_vectorwise_quant`] instead.
This implementation performs additional column-wise transposed calculations which are not optimized.
</Tip>
Args:
A (`torch.Tensor` with dtype `torch.float16`): The input matrix.
col_stats (`torch.Tensor`, *optional*): A pre-allocated tensor to hold the column-wise quantization scales.
row_stats (`torch.Tensor`, *optional*): A pre-allocated tensor to hold the row-wise quantization scales.
out_col (`torch.Tensor`, *optional*): A pre-allocated tensor to hold the column-wise quantized data.
out_row (`torch.Tensor`, *optional*): A pre-allocated tensor to hold the row-wise quantized data.
threshold (`float`, *optional*):
An optional threshold for sparse decomposition of outlier features.
No outliers are held back when 0.0. Defaults to 0.0.
Returns:
`Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]`: A tuple containing the quantized tensor and relevant statistics.
- `torch.Tensor` with dtype `torch.int8`: The row-wise quantized data.
- `torch.Tensor` with dtype `torch.int8`: The column-wise quantized data.
- `torch.Tensor` with dtype `torch.float32`: The row-wise quantization scales.
- `torch.Tensor` with dtype `torch.float32`: The column-wise quantization scales.
- `torch.Tensor` with dtype `torch.int32`, *optional*: A list of column indices which contain outlier features.
"""
# TODO: Optimize/write CUDA kernel for this?
# Use CUDA kernel for rowwise and COO tensor
quant_row, row_stats, outlier_cols = int8_vectorwise_quant(A, threshold=threshold)
# PyTorch impl for colwise
_, col_stats, outlier_mask = get_colrow_absmax(A, threshold=threshold)
if threshold > 0.0 and outlier_mask is not None:
A = A.masked_fill(outlier_mask, 0.0)
quant_col = torch.round(A.mul(C) / col_stats.unsqueeze(0)).to(torch.int8)
if out_row is not None:
quant_row = out_row.copy_(quant_row)
if out_col is not None:
quant_col = out_col.copy_(quant_col)
return quant_row, quant_col, row_stats, col_stats.flatten().float(), outlier_cols
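As a quick sanity check on the returned statistics (a sketch with illustrative shapes, CUDA device assumed): the row-wise pair reconstructs `A` up to rounding error, and likewise for the column-wise pair.

import torch
import bitsandbytes.functional as F

A = torch.randn(8, 32, dtype=torch.float16, device="cuda")
quant_row, quant_col, row_stats, col_stats, _ = F.int8_double_quant(A)

# Row-wise: each row was scaled by 127 / row_absmax before rounding.
A_row = quant_row.float() * (row_stats / 127.0).unsqueeze(1)
# Column-wise: each column was scaled by 127 / col_absmax before rounding.
A_col = quant_col.float() * (col_stats / 127.0).unsqueeze(0)

print((A.float() - A_row).abs().max(), (A.float() - A_col).abs().max())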
def int8_vectorwise_dequant(A: torch.Tensor, stats: torch.Tensor):
"""Dequantizes a tensor with dtype `torch.int8` to `torch.float32`.
Args:
A (`torch.Tensor` with dtype `torch.int8`): The quantized int8 tensor.
stats (`torch.Tensor` with dtype `torch.float32`): The row-wise quantization statistics.
Returns:
`torch.Tensor` with dtype `torch.float32`: The dequantized tensor.
"""
# To dequantize we divide by 127, or multiply by the reciprocal.
return A * stats.view(-1, 1) * 7.874015718698502e-3
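The constant `7.874015718698502e-3` is simply `1/127` at float32 precision. A minimal round-trip sketch (illustrative shapes, CUDA device assumed):

import torch
import bitsandbytes.functional as F

A = torch.randn(4, 16, dtype=torch.float16, device="cuda")
quant, stats, _ = F.int8_vectorwise_quant(A)

A_hat = F.int8_vectorwise_dequant(quant, stats)  # int8 value * (row_absmax / 127)
print((A.float() - A_hat).abs().max())  # on the order of stats.max() / 254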
def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0):
"""Quantizes a tensor with dtype `torch.float16` to `torch.int8` in accordance to the `LLM.int8()` algorithm.
For more information, see the [LLM.int8() paper](https://arxiv.org/abs/2208.07339).
Args:
A (`torch.Tensor` with dtype `torch.float16`): The input tensor.
threshold (`float`, *optional*):
An optional threshold for sparse decomposition of outlier features.
No outliers are held back when 0.0. Defaults to 0.0.
Returns:
`Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]`: A tuple containing the quantized tensor and relevant statistics.
- `torch.Tensor` with dtype `torch.int8`: The quantized data.
- `torch.Tensor` with dtype `torch.float32`: The quantization scales.
- `torch.Tensor` with dtype `torch.int32`, *optional*: A list of column indices which contain outlier features.
"""
    assert A.dtype == torch.half
    is_on_gpu([A])

    rows = prod(A.shape[:-1])
    cols = A.shape[-1]

    row_stats = torch.empty(rows, device=A.device, dtype=torch.float32)
    out_row = torch.empty(A.shape, device=A.device, dtype=torch.int8)

    outlier_cols = None

    if threshold > 0.0:
        # TODO we could improve perf of this
        outliers = A.abs() >= threshold

        if outliers.any():
            outlier_cols = torch.argwhere(outliers.any(dim=0)).view(-1)

    with _cuda_device_of(A):
        lib.cint8_vector_quant(
            get_ptr(A),
            get_ptr(out_row),
            get_ptr(row_stats),
            ct.c_float(threshold),
            ct.c_int32(rows),
            ct.c_int32(cols),
            _get_tensor_stream(A),
        )

    # Zero out values from outlier columns across all rows.
    # The kernel will handle this for outliers themselves, so we can optimize for rows=1.
    if rows > 1 and outlier_cols is not None:
        out_row[:, outlier_cols] = 0

    return out_row, row_stats, outlier_cols
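A short sketch of the sparse-decomposition path (threshold and shapes are illustrative, CUDA device assumed): columns containing values at or above the threshold are zeroed in the int8 output and reported back so the caller can keep them in fp16.

import torch
import bitsandbytes.functional as F

A = torch.randn(8, 64, dtype=torch.float16, device="cuda")
A[:, 3] = 42.0  # make one column an outlier

quant, row_stats, outlier_cols = F.int8_vectorwise_quant(A, threshold=6.0)

print(outlier_cols)              # tensor([3], device='cuda:0')
print(quant[:, 3].unique())      # zeroed out in the int8 data
A_outliers = A[:, outlier_cols]  # kept in fp16 for the decomposition matmul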
@deprecated(
"The layout transformation operations will be removed in a future release. Please use row-major layout only.",
category=FutureWarning,
)
def transform(A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None):
    prev_device = pre_call(A.device)
    if state is None:
@@ -2690,7 +2856,26 @@ def transform(A, to_order, from_order="row", out=None, transpose=False, state=No
    return out, new_state
def spmm_coo(
cooA: Union[COOSparseTensor, torch.Tensor],
B: torch.Tensor,
out: Optional[torch.Tensor] = None,
):
if not isinstance(cooA, COOSparseTensor):
assert (
cooA.is_sparse and cooA.layout == torch.sparse_coo
), "Tensor must be `COOSparseTensor or a PyTorch COO tensor."
# Convert to custom COOSparseTensor
cooA = COOSparseTensor(
rows=cooA.shape[0],
cols=cooA.shape[1],
nnz=cooA._nnz(),
rowidx=cooA.indices()[0].int(),
colidx=cooA.indices()[1].int(),
values=cooA.values(),
)
    if out is None:
        out = torch.empty((cooA.rows, B.shape[1]), device=B.device, dtype=B.dtype)
    nnz = cooA.nnz
@@ -2823,6 +3008,11 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
C = 127.0


@deprecated(
    "This function is deprecated and will be removed in a future release. "
    "Consider using `int8_vectorwise_quant` instead.",
    category=FutureWarning,
)
def vectorwise_quant(x, dim=1, quant_type="vector"):
    if quant_type == "linear":
        max1 = torch.abs(x).max().float()
@@ -2867,6 +3057,10 @@ def vectorwise_quant(x, dim=1, quant_type="vector"):
    return None


@deprecated(
    "This function is deprecated and will be removed in a future release. Consider using `int8_vectorwise_dequant` instead.",
    category=FutureWarning,
)
def vectorwise_dequant(xq, max1, quant_type="vector"):
    if quant_type == "vector":
        x = (xq / C * max1).to(torch.float32)
@@ -2875,6 +3069,10 @@ def vectorwise_dequant(xq, max1, quant_type="vector"):
    return None


@deprecated(
    "This function is deprecated and will be removed in a future release. Consider using `int8_mm_dequant` instead.",
    category=FutureWarning,
)
def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half, quant_type="vector"):
    if quant_type == "linear":
        norm = S1 * S2 / (C * C)
@@ -2934,6 +3132,7 @@ def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half, quant_type="vector"):
    return None


@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def dequant_min_max(xq, A, B, SA, SB, dtype=torch.half):
    offset = B.float().t().sum(0) * (SA[0] + SA[1])
    x = xq.float()
@@ -2948,6 +3147,7 @@ def dequant_min_max(xq, A, B, SA, SB, dtype=torch.half):
    return x.to(dtype)


@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def extract_outliers(A, SA, idx):
    shapeA = SA[0]
    formatA = SA[1]
@@ -2973,6 +3173,7 @@ def extract_outliers(A, SA, idx):
    return out


@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def pipeline_test(A, batch_size):
    out = torch.zeros_like(A)
    lib.cpipeline_test(get_ptr(A), get_ptr(out), ct.c_size_t(A.numel()), ct.c_size_t(batch_size))
...
@@ -16,7 +16,6 @@ from bitsandbytes.functional import QuantState
from bitsandbytes.optim import GlobalOptimManager
from bitsandbytes.utils import (
    INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING,
    OutlierTracer,
)
@@ -481,11 +480,8 @@ class Linear4bit(nn.Linear):
        x = x.to(self.compute_dtype)

        bias = None if self.bias is None else self.bias.to(self.compute_dtype)

        return bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state).to(inp_dtype)


class LinearFP4(Linear4bit):
@@ -570,11 +566,11 @@ class LinearNF4(Linear4bit):
class Int8Params(torch.nn.Parameter):
    def __new__(
        cls,
        data: Optional[torch.Tensor] = None,
        requires_grad=True,
        has_fp16_weights=False,
        CB: Optional[torch.Tensor] = None,
        SCB: Optional[torch.Tensor] = None,
    ):
        if data is None:
            data = torch.empty(0)
@@ -588,12 +584,9 @@ class Int8Params(torch.nn.Parameter):
        if self.has_fp16_weights:
            return super().cuda(device)
        else:
            # We quantize the weight and store in 8-bit row-major format.
            B = self.data.contiguous().half().cuda(device)
            CB, SCB, _ = bnb.functional.int8_vectorwise_quant(B)
            self.data = CB
            self.CB = CB
            self.SCB = SCB
@@ -888,7 +881,6 @@ class Linear8bitLt(nn.Linear):
        output_features: int,
        bias=True,
        has_fp16_weights=True,
        threshold=0.0,
        index=None,
        device=None,
@@ -905,13 +897,12 @@
            Whether the linear class uses the bias term as well.
        """
        super().__init__(input_features, output_features, bias, device)
        self.state = bnb.MatmulLtState()
        self.index = index

        self.state.threshold = threshold
        self.state.has_fp16_weights = has_fp16_weights
        if threshold > 0.0 and not has_fp16_weights:
            self.state.use_pool = True
@@ -928,29 +919,19 @@
            param_from_weight = getattr(self.weight, scb_name)
            # case 2: self.init_8bit_state was called, SCB is in self.state
            param_from_state = getattr(self.state, scb_name)

            key_name = prefix + f"{scb_name}"

            # We now only save in row-major. This format information is stored for backwards compatibility.
            format_name = prefix + "weight_format"

            if not self.state.has_fp16_weights:
                if param_from_weight is not None:
                    destination[key_name] = param_from_weight if keep_vars else param_from_weight.detach()
                    destination[format_name] = torch.tensor(0, dtype=torch.uint8)
                elif param_from_state is not None:
                    destination[key_name] = param_from_state if keep_vars else param_from_state.detach()
                    destination[format_name] = torch.tensor(0, dtype=torch.uint8)
    def _load_from_state_dict(
        self,
@@ -1008,12 +989,9 @@
        out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)

        if not self.state.has_fp16_weights and self.state.CB is not None:
            self.weight.data = self.state.CB

        return out
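A minimal inference sketch for the 8-bit linear layer (hypothetical sizes, CUDA device assumed); with `has_fp16_weights=False` the weight is quantized once when it is moved to the GPU, as in `Int8Params.cuda()` above:

import torch
import bitsandbytes as bnb

fp16_linear = torch.nn.Linear(1024, 1024, bias=True).half()

int8_linear = bnb.nn.Linear8bitLt(1024, 1024, bias=True, has_fp16_weights=False, threshold=6.0)
int8_linear.load_state_dict(fp16_linear.state_dict())
int8_linear = int8_linear.cuda()  # weight is quantized to int8 row-major here

x = torch.randn(4, 1024, dtype=torch.float16, device="cuda")
y = int8_linear(x)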
...
@@ -184,9 +184,9 @@ class MatMulFP8Global(torch.autograd.Function):
class SwitchBackBnb(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, B, out=None, bias=None, state: Optional[MatmulLtState] = None):
        state = state or MatmulLtState()

        # default to pytorch behavior if inputs are empty
        ctx.is_empty = False
        if prod(A.shape) == 0:
@@ -204,7 +204,6 @@ class SwitchBackBnb(torch.autograd.Function):
        # 3. Matmul
        # 4. Mixed-precision decomposition matmul
        # 5. Save state
        input_shape = A.shape
        if state.outlier_pool is None:
            state.outlier_pool = GlobalOutlierPooler.get_instance()
@@ -216,25 +215,21 @@ class SwitchBackBnb(torch.autograd.Function):
        # 1. Quantize A
        if len(A.shape) == 3:
            A = A.view(-1, A.shape[-1]).contiguous()
        CA, CAt, SCA, SCAt, outlier_cols = F.int8_double_quant(A.to(torch.float16), threshold=state.threshold)

        if state.threshold > 0.0 and outlier_cols is not None:
            if state.has_fp16_weights:
                idx = outlier_cols
                CA[:, idx] = 0
                subA = A[:, idx]
                state.subB = B[:, idx].t().contiguous()
                state.idx = idx
            else:
                if state.SB is None:
                    state.SB = (state.CB.shape, "row")
        else:
            if not state.has_fp16_weights and state.SB is None:
                state.SB = (state.CB.shape, "row")
            subA = None

        # 2. Quantize B
@@ -245,34 +240,26 @@ class SwitchBackBnb(torch.autograd.Function):
            if is_transposed:
                B = B.contiguous()

            if (state.is_training and not has_grad) or state.SB is None:
                state.reset_grads()
                (
                    state.CB,
                    state.CBt,
                    state.SCB,
                    state.SCBt,
                    _,
                ) = F.int8_double_quant(B.to(torch.float16))
                state.SB = (state.CB.shape, "row")
        else:
            has_grad = False

        if outlier_cols is not None and not state.has_fp16_weights:
            # extract outliers
            state.idx = outlier_cols
            outliers = state.CB[:, state.idx.long()].clone()
            state.subB = (outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().to(A.dtype)
            CA[:, state.idx.long()] = 0
            subA = A[:, state.idx.long()]

        shapeB = state.SB[0]
@@ -283,25 +270,22 @@ class SwitchBackBnb(torch.autograd.Function):
            output_shape = (input_shape[0], shapeB[0])

        # 3. Matmul
        out32 = F.int8_linear_matmul(CA, state.CB)

        # we apply the fused bias here
        if bias is None or bias.dtype == torch.float16:
            output = F.int8_mm_dequant(out32, SCA, state.SCB, bias=bias).to(A.dtype)
        else:  # apply bias separately
            output = F.int8_mm_dequant(out32, SCA, state.SCB, bias=None).to(A.dtype)
            output.add_(bias)

        # 4. Mixed-precision decomposition matmul
        if outlier_cols is not None and subA is not None:
            output += torch.matmul(subA, state.subB)

        # 5. Save state
        ctx.state = state

        ctx.grad_shape = input_shape
        ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype
@@ -321,10 +305,10 @@ class SwitchBackBnb(torch.autograd.Function):
        if ctx.is_empty:
            bias_grad = None if ctx.bias is None else torch.zeros_like(ctx.bias)
            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None

        req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
        CAt, subA, A = ctx.tensors
        SCAt, idx = ctx.tensor_states
        state = ctx.state
        grad_A = grad_B = grad_bias = None
@@ -336,7 +320,7 @@ class SwitchBackBnb(torch.autograd.Function):
        if len(grad_output.shape) == 3:
            grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()

        Cgrad, Cgradt, SCgrad, SCgradt, outlier_cols = F.int8_double_quant(grad_output.to(torch.float16))

        if req_gradB:
            # print('back A shape', A.shape)
@@ -344,16 +328,7 @@ class SwitchBackBnb(torch.autograd.Function):
            grad_B = torch.matmul(grad_output.t(), A)

        if req_gradA:
            if state.CB is not None:
                CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
                grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
            else:
...
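Schematically, the forward path above reduces to the following steps. This is a simplified sketch using the functional helpers from the diff (`int8_vectorwise_quant`, `int8_linear_matmul`, `int8_mm_dequant`); it ignores autograd, caching in `MatmulLtState`, and bias handling, and the shapes are illustrative.

import torch
import bitsandbytes.functional as F

def int8_forward(A, W, threshold=6.0):
    # 1. Quantize the activations row-wise; collect outlier column indices.
    CA, SCA, outlier_cols = F.int8_vectorwise_quant(A, threshold=threshold)
    # 2. Quantize the weight row-wise.
    CW, SCW, _ = F.int8_vectorwise_quant(W)
    # 3. int8 matmul, then dequantize with the row/column scales.
    out = F.int8_mm_dequant(F.int8_linear_matmul(CA, CW), SCA, SCW).to(A.dtype)
    # 4. Mixed-precision decomposition: outlier columns stay in fp16.
    if outlier_cols is not None:
        out += A[:, outlier_cols] @ W[:, outlier_cols].t()
    return out

A = torch.randn(4, 1024, dtype=torch.float16, device="cuda")
W = torch.randn(2048, 1024, dtype=torch.float16, device="cuda")
out = int8_forward(A, W)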
#pragma once
// TODO: Let's make some of these constexpr and put in a namespace.
#define BNB_CC_MAXWELL 500
#define BNB_CC_MAXWELL2 520
#define BNB_CC_MAXWELL2_X1 530
#define BNB_CC_PASCAL 600
#define BNB_CC_PASCAL_X2 620
#define BNB_CC_VOLTA 700
#define BNB_CC_VOLTA_XAVIER 720
#define BNB_CC_TURING 750
#define BNB_CC_AMPERE 800
#define BNB_CC_AMPERE2 860
#define BNB_CC_AMPERE2_ORIN 870
#define BNB_CC_ADA 890
#define BNB_CC_HOPPER 900
#define BNB_CC_BLACKWELL 1000
#define BNB_FP16_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_MAXWELL2_X1)
#define BNB_FP16_MMA_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_VOLTA)
#define BNB_INT8_MMA_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_VOLTA_XAVIER)
#define BNB_BF16_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_AMPERE)
#define BNB_FP8_AVAILABLE (__CUDA_ARCH__ >= BNB_CC_ADA)
#define BNB_WARP_SIZE 32
// The maximum number of resident threads per SM varies by arch.
// For A100/H100 and all prior to Turing, it is 2048, which allows
// for 2 full blocks of 1024 threads per SM.
// Reference: https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications-technical-specifications-per-compute-capability
#if __CUDA_ARCH__ == 750
#define BNB_MAX_THREADS_PER_SM 1024
#elif __CUDA_ARCH__ >= 860 && __CUDA_ARCH__ <= 890
#define BNB_MAX_THREADS_PER_SM 1536
#else
#define BNB_MAX_THREADS_PER_SM 2048
#endif
// Maximum resident warps per SM is always directly related to the number of threads.
#define BNB_MAX_WARPS_PER_SM ((BNB_MAX_THREADS_PER_SM) / (BNB_WARP_SIZE))
// Maximum resident blocks per SM may vary.
#if __CUDA_ARCH__ == 860 || __CUDA_ARCH__ == 870
#define BNB_MAX_BLOCKS_PER_SM 16
#else
#define BNB_MAX_BLOCKS_PER_SM ((BNB_MAX_WARPS_PER_SM) / 2)
#endif
@@ -3,7 +3,9 @@
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.

#include "kernels.cuh"
#include "common.cuh"

#include <cuda_fp16.h>
#include <cub/block/block_radix_sort.cuh>
#include <cub/warp/warp_reduce.cuh>
#include <cub/block/block_load.cuh>
@@ -219,7 +221,7 @@ __device__ half dhDequantizeNF4(unsigned char val)
}

__device__ __forceinline__ float dDequantizeNF4(unsigned char val)
{
  // the values for this tree were generated by test_normal_map_tree
@@ -627,7 +629,7 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
    for(int i = threadIdx.x; i < 256; i+=blockDim.x)
      smem_code[i] = code[i];

  for (int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE)
  {
    valid_items = n - i > BLOCK_SIZE ? BLOCK_SIZE : n - i;
    local_abs_max = -FLT_MAX;
@@ -645,20 +647,14 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
      local_abs_max = BlockReduce(reduce).Reduce(local_abs_max, cub::Max(), valid_items);

    if (threadIdx.x == 0) {
      smem_absmax_value[0] = 1.0f / local_abs_max;
      absmax[i / BLOCK_SIZE] = local_abs_max;
    }

    __syncthreads();
    local_abs_max = smem_absmax_value[0];

    if(STOCHASTIC)
    {
      local_rand_idx = ((blockIdx.x*NUM_BLOCK) + (threadIdx.x*NUM) + rand_offset) % (1024-4);
@@ -722,24 +718,28 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * abs
  __shared__ typename LoadChar::TempStorage loadchar;
  __shared__ typename StoreT::TempStorage storet;

  for (int i = base_idx; i < n_load; i += gridDim.x*TILE_SIZE)
  {
    if (DATA_TYPE > 0)
    {
      valid_items_load = min(TILE_SIZE, (n + 1) / 2 - i);
      valid_items_store = min(TILE_SIZE * 2, n - i * 2);
    }
    else
    {
      valid_items_load = min(TILE_SIZE, n - i);
      valid_items_store = valid_items_load;
    }

    // Since blocksize will always be a power-of-2, we avoid more expensive
    // division by the blocksize and instead use a shift operation.
    // This is equivalent to (i+threadIdx.x*NUM_PER_TH)/blocksize.
    local_abs_max = __ldg(&absmax[(i+threadIdx.x*NUM_PER_TH) >> (31 - __clz(blocksize))]);

    __syncthreads();
    LoadChar(loadchar).Load(&(A[i]), qvals, valid_items_load, 128);

    switch (DATA_TYPE)
    {
      case General8bit:
        // load code through read-only cache via __ldg
@@ -2134,386 +2134,182 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
  }
}
// Inputs:
//  A        [rows, cols]
// Outputs:
//  rowStats [rows]
//  out      [rows, cols]
template<typename T, int THREADS, int SPARSE_DECOMP>
__launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024)
__global__ void kInt8VectorQuant(T * __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols) {

  // For sm50/sm52 and CUDA < 12.2 we need to do the reduction in fp32.
  // Otherwise `T` is `fp16`. This can be removed when Maxwell is dropped.
#if (__CUDACC_VER_MAJOR__ >= 12 && __CUDACC_VER_MINOR__ >= 2) || BNB_FP16_AVAILABLE
  using TReduction = T;
#else
  using TReduction = float;
#endif

  using BlockReduceT = cub::BlockReduce<TReduction, THREADS>;

  // One block per row.
  // Threads load column values in a striped arrangement.
  // e.g. t0 reads row[0], row[0+nthreads], ..
  // and t1 reads row[1], row[1+nthreads], ..
  // Each thread will determine its local absmax.
  // We then do a blockwise reduction to determine the row's absmax.

  __shared__ typename BlockReduceT::TempStorage temp_storage;
  __shared__ TReduction smem_row_absmax;

  const int row_id = blockIdx.x;
  const T* row_data = A + (row_id * cols);

  // Threads will read the row values in a striped access pattern and find a local absmax.
  TReduction row_local_absmax = -FLT_MIN;
  for (int i = threadIdx.x; i < cols; i += THREADS) {
    const TReduction absval = fabsf(__ldcs(&(row_data[i])));

    // For sparse decomposition, values outside of the threshold are not to be
    // included when calculating the row's absmax.
    if constexpr (SPARSE_DECOMP) {
      row_local_absmax = fmaxf(row_local_absmax, absval < TReduction(threshold) ? absval : row_local_absmax);
    } else {
      row_local_absmax = fmaxf(row_local_absmax, absval);
    }
  }

  // Reduce thread-local absmax across the block.
  const TReduction row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, cub::Max(), cols);
  if (threadIdx.x == 0) {
    // Save our block's absmax to shared memory for the quantization step.
    rowStats[row_id] = smem_row_absmax = row_absmax;
  }
  __syncthreads();

  // Quantize row-wise.
  const float scale = __fdividef(127.0f, smem_row_absmax);
  for (int i = threadIdx.x; i < cols; i += THREADS) {
    float val = row_data[i];

    if constexpr (SPARSE_DECOMP) {
      // For sparse decomposition, we do not want to quantize the outliers.
      // Instead they're zeroed out.
      out[row_id * cols + i] = fabs(val) < threshold ? __float2int_rn(val * scale) : 0;
    } else {
      out[row_id * cols + i] = __float2int_rn(val * scale);
    }
  }
}

template<typename T, int THREADS, int SPARSE_DECOMP>
__launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024)
__global__ void kgetRowStats(T * __restrict__ A, float *rowStats, float threshold, int rows, int cols) {
  using BlockReduceT = cub::BlockReduce<float, THREADS>;

  // One block per row.
  // Threads load column values in a striped arrangement.
  // e.g. t0 reads row[0], row[0+nthreads], ..
  // and t1 reads row[1], row[1+nthreads], ..
  // Each thread will determine its local absmax.
  // We then do a blockwise reduction to determine the row's absmax.

  __shared__ typename BlockReduceT::TempStorage temp_storage;

  const int row_id = blockIdx.x;
  const T* __restrict__ row_data = A + (row_id * cols);

  // Threads will read the row values in a striped access pattern and find a local absmax.
  float row_local_absmax = -FLT_MIN;
  for (int i = threadIdx.x; i < cols; i += THREADS) {
    const float absval = fabsf(row_data[i]);

    // For sparse decomposition, values outside of the threshold are not to be
    // included when calculating the row's absmax.
    if constexpr (SPARSE_DECOMP) {
      row_local_absmax = fmaxf(row_local_absmax, absval < threshold ? absval : row_local_absmax);
    } else {
      row_local_absmax = fmaxf(row_local_absmax, absval);
    }
  }

  // Reduce thread-local absmax across the block and write out the row's absmax.
  // TODO: Consider algorithm BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY
  const float row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, cub::Max(), cols);
  if (threadIdx.x == 0) {
    rowStats[row_id] = row_absmax;
  }
}

template __global__ void kgetRowStats<half, 1024, 0>(half * __restrict__ A, float *rowStats, float threshold, int rows, int cols);
template __global__ void kgetRowStats<half, 1024, 1>(half * __restrict__ A, float *rowStats, float threshold, int rows, int cols);

template __global__ void kInt8VectorQuant<half, 1024, 0>(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols);
template __global__ void kInt8VectorQuant<half, 1024, 1>(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols);

#define MM_DEQUANT_CONST 6.200012e-05f //1.0f/(127.0f*127.0f)
template <int ITEMS_PER_THREAD, int THREADS>
__global__ void kdequant_mm_int32_fp16(
  int* __restrict__ const A,
  float *__restrict__ const rowStats,
  float *__restrict__ const colStats,
  half *out,
  half *__restrict__ const bias,
  const int numRows,
  const int numCols,
  const int n
) {
  const int n_out = numRows * numCols;

  int block_offset = blockIdx.x * THREADS * ITEMS_PER_THREAD;
  int thread_offset = threadIdx.x * ITEMS_PER_THREAD;

  int local_values[ITEMS_PER_THREAD];
  half local_output[ITEMS_PER_THREAD];
  float local_rowStats[ITEMS_PER_THREAD];
  float local_colStats[ITEMS_PER_THREAD];
  float local_biasValue[ITEMS_PER_THREAD];

  typedef cub::BlockLoad<int, THREADS, ITEMS_PER_THREAD, cub::BLOCK_LOAD_VECTORIZE> LoadInt32;
  __shared__ typename LoadInt32::TempStorage loadint32;

  int row_idx, col_idx;

  // Gather the row/col statistics and bias value for each element handled by this thread.
  #pragma unroll ITEMS_PER_THREAD
  for (int j = 0; j < ITEMS_PER_THREAD; ++j) {
    row_idx = (block_offset + thread_offset + j) / numCols;
    col_idx = (block_offset + thread_offset + j) % numCols;

    local_colStats[j] = col_idx >= numCols ? 0.0f : __ldg(&colStats[col_idx]);
    local_rowStats[j] = row_idx >= numRows ? 0.0f : __ldg(&rowStats[row_idx]);
    local_biasValue[j] = ((bias == nullptr) || col_idx >= numCols) ? 0.0f : __half2float(bias[col_idx]);
  }

  // Each block loads THREADS * ITEMS_PER_THREAD values from A
  int valid_items = block_offset + THREADS * ITEMS_PER_THREAD < n_out
    ? THREADS * ITEMS_PER_THREAD
    : n_out - block_offset;
  LoadInt32(loadint32).Load(&(A[block_offset]), local_values, valid_items, 0);

  // Dequantize: int32 value * row_stat * col_stat / (127*127), plus the fused bias.
  #pragma unroll ITEMS_PER_THREAD
  for (int j = 0; j < ITEMS_PER_THREAD; ++j) {
    local_output[j] = __float2half(
      fmaf(local_values[j] * local_rowStats[j] * local_colStats[j], MM_DEQUANT_CONST, local_biasValue[j])
    );
  }

  // Store the fp16 result in row-major order.
  #pragma unroll ITEMS_PER_THREAD
  for (int j = 0; j < ITEMS_PER_THREAD; j++) {
    int outIdx = block_offset + thread_offset + j;
    if (outIdx < n_out) {
      out[outIdx] = local_output[j];
    }
  }
}
@@ -3516,6 +3312,7 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
  const int warp_idx = threadIdx.x / 32;
  const int warp_lane = threadIdx.x % 32;
  const int row_B = (THREADS/32)*blockIdx.x + warp_idx;
  const int offset_B = ldb*row_B;
  const int num_values_8bit = num_values_4bit/2;
  float local_C = 0.0f;
@@ -3525,17 +3322,23 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
  __shared__ T quant_map[16];
  T local_absmax = T(0.0f);

  if (threadIdx.x < 16)
    quant_map[threadIdx.x] = T(__ldg(&datatype[threadIdx.x]));

  __syncthreads();

  // A: [1, K]
  // B: [N, K]
  for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += 32*num_values_4bit)
  {
    const int inner_idx_halved = inner_idx/2;

    // Since blocksize will always be a power-of-2, we avoid more expensive
    // division by the blocksize and instead use a shift operation.
    // This is equivalent to ((2*offset_B)+inner_idx)/blocksize.
    const int absidx = ((2*offset_B)+inner_idx) >> (31 - __clz(blocksize));
    local_absmax = __ldg(&(absmax[absidx]));

    if(row_B < M)
@@ -3567,7 +3370,7 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
#pragma unroll #pragma unroll
for(int k = 0; k < num_values_8bit/4; k++) for(int k = 0; k < num_values_8bit/4; k++)
{ {
#if __CUDA_ARCH__ >= 800 #if BNB_BF16_AVAILABLE
local_B[k*2] = quant_map[local_B_4bit[(i*num_values_8bit/4) + k] >> 4]*local_absmax; local_B[k*2] = quant_map[local_B_4bit[(i*num_values_8bit/4) + k] >> 4]*local_absmax;
local_B[k*2 + 1] = quant_map[local_B_4bit[(i*num_values_8bit/4) + k] & 0x0F]*local_absmax; local_B[k*2 + 1] = quant_map[local_B_4bit[(i*num_values_8bit/4) + k] & 0x0F]*local_absmax;
#else #else
...@@ -3604,7 +3407,7 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
       #pragma unroll
       for(int k = 0; k < num_values_4bit/4; k++)
       {
-        #if __CUDA_ARCH__ >= 800
+        #if BNB_BF16_AVAILABLE
           local_C += (float)(local_A[k]*local_B[k]);
         #else
           // bf16 multiplication not supported
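In the hunks above, each byte of `local_B_4bit` packs two 4-bit codes: the high nibble (`>> 4`) and the low nibble (`& 0x0F`) are looked up in the 16-entry `quant_map` and scaled by the block's absmax, and the `BNB_BF16_AVAILABLE` guard chooses between native arithmetic in `T` and accumulation through fp32 where bf16 math is unavailable. The following standalone sketch shows that unpack-and-accumulate pattern with a hypothetical helper, using fp32 throughout for simplicity; it is not the kernel's code path.

```cuda
#include <cstdint>

// Sketch: dequantize packed 4-bit codes with a 16-entry lookup table and a
// per-block absmax scale, then accumulate the dot product in fp32.
float dot_with_packed_4bit(const float* A, const uint8_t* B_packed,
                           const float quant_map[16], float absmax, int K) {
    float acc = 0.0f;
    for (int i = 0; i < K / 2; ++i) {
        // High nibble is the first value, low nibble the second.
        float b_hi = quant_map[B_packed[i] >> 4] * absmax;
        float b_lo = quant_map[B_packed[i] & 0x0F] * absmax;
        // Accumulating in float mirrors the #else fallback path, which avoids
        // bf16 arithmetic on devices where it is not supported.
        acc += A[2 * i] * b_hi;
        acc += A[2 * i + 1] * b_lo;
    }
    return acc;
}
```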
...@@ -3810,10 +3613,7 @@ template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_TURING>(
 template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 0, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
 template __global__ void kTransformRowToFormat<256, 8, 32, 32*8, 1, COL_AMPERE>(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
-template __global__ void kdequant_mm_int32_fp16<4, 128, 512>(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, float* newRowStats, float* newcolStats, half * __restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n);
+template __global__ void kdequant_mm_int32_fp16<4, 512>(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, half * __restrict__ const bias, const int numRows, const int numCols, const int n);
-template __global__ void kDoubleRowColQuant<64, 4, 16, 64*4, 0>(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols);
-template __global__ void kDoubleRowColQuant<64, 4, 16, 64*4, 1>(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols);
 template __device__ unsigned char dQuantize<0>(float* smem_code, const float rand, float x);
 template __device__ unsigned char dQuantize<1>(float* smem_code, const float rand, float x);
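The surviving instantiation `kdequant_mm_int32_fp16<4, 512>` no longer carries a tile parameter, so the host side only needs a flat 1D launch over the `numRows * numCols` output elements. Below is a hedged sketch of how such an instantiation might be launched; the wrapper name and grid math are assumptions for illustration, presume the kernel declaration from the library's kernels header is in scope, and do not reproduce the library's actual host code.

```cuda
#include <cuda_fp16.h>
#include "kernels.cuh"  // assumed: provides the kdequant_mm_int32_fp16 declaration

// Illustrative launch only: each block covers THREADS * ITEMS_PER_THREAD
// elements of the flattened [numRows, numCols] output.
void launch_int8_dequant(int* A, float* rowStats, float* colStats, half* out,
                         half* bias, int numRows, int numCols, cudaStream_t stream) {
    constexpr int ITEMS_PER_THREAD = 4;
    constexpr int THREADS = 512;
    const int n = numRows * numCols;
    const int per_block = THREADS * ITEMS_PER_THREAD;        // 2048 elements per block
    const int num_blocks = (n + per_block - 1) / per_block;  // ceiling division
    kdequant_mm_int32_fp16<ITEMS_PER_THREAD, THREADS>
        <<<num_blocks, THREADS, 0, stream>>>(A, rowStats, colStats, out, bias,
                                             numRows, numCols, n);
}
```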
...
...@@ -112,12 +112,12 @@ __global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index
 template <typename T, int SPMM_ITEMS, int BITS> __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
-template <int ITEMS_PER_THREAD, int SUBTILE_ROWS, int THREADS>__global__ void kdequant_mm_int32_fp16(
+template <int ITEMS_PER_THREAD, int THREADS>__global__ void kdequant_mm_int32_fp16(
 int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats,
-half *out, float* newRowStats, float* newcolStats, half * __restrict__ const bias, const int numRows, const int numCols, const int tileCols, const int n);
+half *out, half * __restrict__ const bias, const int numRows, const int numCols, const int n);
-template<typename T, int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int SPARSE_DECOMP> __global__ void kgetColRowStats(T * __restrict__ A, float *rowStats, float *colStats, int * nnz_count_row, float nnz_threshold, int rows, int cols, int tiledRows, int tiledCols);
+template<typename T, int THREADS, int SPARSE_DECOMP> __global__ void kgetRowStats(T * __restrict__ A, float *rowStats, float threshold, int rows, int cols);
-template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int SPARSE_DECOMP> __global__ void kDoubleRowColQuant(half *__restrict__ const A, float *__restrict__ const rowStats, float * __restrict__ const colStats, char *out_col_normed, char *out_row_normed, int *rowidx, int *colidx, half *val, int * __restrict__ nnz_block_ptr, float threshold, int rows, int cols, int tiledCols);
+template<typename T, int THREADS, int SPARSE_DECOMP> __global__ void kInt8VectorQuant(T * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols);
 template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int TRANSPOSE, int FORMAT> __global__ void kTransformRowToFormat(char *__restrict__ const A, char *out, int rows, int cols, int tiledCols, int outRows, int outCols);
...
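The new `kInt8VectorQuant` declaration replaces the old `kDoubleRowColQuant`: each row of `A` is quantized against its own absmax, and a nonzero `threshold` excludes outliers from that statistic so the sparse-decomposition path can keep them in fp16. Below is a plain per-row reference of that idea; it is a sketch only (the function name is hypothetical, and how the kernel treats the outlier entries themselves is omitted here).

```cuda
#include <algorithm>
#include <cmath>
#include <cstdint>

// Reference for one row: absmax-based int8 quantization with an optional
// outlier threshold, matching the role of the row statistic in LLM.int8().
void quantize_row_int8(const float* row, int cols, float threshold,
                       int8_t* out, float* rowStat) {
    float absmax = 0.0f;
    for (int j = 0; j < cols; ++j) {
        const float v = std::fabs(row[j]);
        // Values at or above the threshold are treated as outliers and do not
        // contribute to the row statistic; they are carried in fp16 by the
        // sparse decomposition instead.
        if (threshold == 0.0f || v < threshold)
            absmax = std::max(absmax, v);
    }
    *rowStat = absmax;
    const float scale = absmax > 0.0f ? 127.0f / absmax : 0.0f;
    for (int j = 0; j < cols; ++j)
        out[j] = (int8_t)std::lrintf(row[j] * scale);
}
```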