Unverified commit 888788d7 authored by pnunna93, committed by GitHub

Enable ROCm backend with custom ops integration (#1683)



* Port ROCm changes from multi-backend-refactor branch

* Update ops.py

* Update functional.py

* Update ops.py

* Update ops.py

* Update ops.py

* Update ops.py

* Update functional.py

* Update ops.py

* Update ops.py

* Update ops.py

* Update ops.py

* Update functional.py

* Update functional.py

* Update functional.py

* Update functional.py

* Update ops.py

* Update ops.py

* Update ops.py

* Update ops.py

* Update ops.py

* Update ops.py

* Update ops.py

* Update ops.py

* Update ops.py

* Update functional.py

* Update functional.py

* Update functional.py

* Update test_ops.py

* Update test_functional.py

* Update test_ops.py

* Update test_functional.py

* Update test_functional.py

* Update functional.py

* Update functional.py

* Update ops.py

* Update ops.py

* Update test_functional.py

* Update test_functional.py

* Update cextension.py

* Update cuda_specs.py

* Update cuda_specs.py

* Update test_functional.py

* Update test_linear4bit.py

* Update test_cuda_setup_evaluator.py

* Update test_functional.py

* Update modules.py

* Update modules.py

* Update ops.py

* Update test_linear4bit.py

* Update ops.py

* Update ops.py

* Update test_linear4bit.py

* Update test_linear4bit.py

* Update python-package.yml

* Update python-package.yml

* Update python-package.yml

* Update python-package.yml

* Create build-rocm.sh

* Update cuda_specs.py

* Fix trailing whitespace

* Remove conflicts.diff

* update for hipblasVersionMajor >=3

* Update test_functional.py

* Update test_linear4bit.py

* Update test_ops.py

* Update main.py

* Update test_functional.py

* Update test_linear4bit.py

* Update test_ops.py

* Update test_linear4bit.py

* Lint

* Lint

* Update helpers.py

* Update test_functional.py

* Update test_linear4bit.py

* Update test_ops.py

* Lint

* Update pythonInterface.cpp

* lint fix

* lint

* Update pythonInterface.cpp

* revert permissions change

* Fix indentation

* Update kernels_hip.cuh

* Update kernels.hip

* Update ops.hip

* Update ops_hip.cuh

* Update kernels_hip.cuh

* Update kernels.hip

* Update kernels.hip

* Update ops.hip

* Update ops_hip.cuh

* Update ops.hip

* Update CMakeLists.txt

* Update functional.py

* Update cextension.py

* Update cextension.py

---------
Co-authored-by: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com>
Co-authored-by: MISHANMAUYRA <mishanmaurya31081@gmail.com>
Co-authored-by: amcamd <andrew.chapman@amd.com>
Co-authored-by: Prasanth Nunna <root@banff-cyxtera-s78-1.amd.com>
parent a1cd3f6e
#!/bin/bash
# Build the bitsandbytes HIP shared library inside a ROCm dev container.
# Inputs (environment variables): build_arch, build_os, rocm_version.
declare build_arch
declare build_os
declare rocm_version

set -xeuo pipefail

# Target AMD GPU architectures: CDNA 2 (gfx90a), CDNA 3 (gfx942), RDNA 3 (gfx1100).
bnb_rocm_arch="gfx90a;gfx942;gfx1100"

if [ "${build_os:0:6}" == ubuntu ]; then
    image=rocm/dev-ubuntu-22.04:${rocm_version}-complete
    echo "Using image $image"
    docker run --rm --platform "linux/$build_arch" -i \
        -w /src -v "$PWD:/src" "$image" sh -c \
        "apt-get update \
        && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends cmake \
        && cmake -DCOMPUTE_BACKEND=hip -DBNB_ROCM_ARCH=\"${bnb_rocm_arch}\" . \
        && cmake --build ."
fi

# Collect the built shared libraries for upload as a CI artifact.
output_dir="output/${build_os}/${build_arch}"
mkdir -p "${output_dir}"
(shopt -s nullglob && cp bitsandbytes/*.{so,dylib,dll} "${output_dir}")
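The script reads its parameters from the environment rather than from argv, so the CI workflow below can pass its matrix values straight through, and it can be driven locally the same way. A minimal local driver, as a sketch (requires Docker; the matrix values are copied from the workflow, and the script path is assumed from this commit):

# Hypothetical local driver for .github/scripts/build-rocm.sh; assumes Docker is installed.
import os
import subprocess

env = dict(
    os.environ,
    build_os="ubuntu-22.04",   # one cell of the CI matrix
    build_arch="x86_64",
    rocm_version="6.3.2",
)
subprocess.run(["bash", ".github/scripts/build-rocm.sh"], env=env, check=True)
# On success, the shared library lands in output/ubuntu-22.04/x86_64/.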
@@ -102,10 +102,55 @@ jobs:
           path: output/*
           retention-days: 7
 
+  build-shared-libs-rocm:
+    strategy:
+      matrix:
+        os: [ubuntu-22.04]
+        arch: [x86_64]
+        rocm_version: ["6.1.2", "6.2.4", "6.3.2"]
+    runs-on: ${{ matrix.os }}
+    steps:
+      - uses: actions/checkout@v4
+      - name: Set up Docker multiarch
+        uses: docker/setup-qemu-action@v3
+      - name: Clean up disk space
+        run: |
+          sudo rm -rf \
+            /usr/share/dotnet \
+            /opt/ghc \
+            "/usr/local/share/boost" \
+            "$AGENT_TOOLSDIRECTORY" \
+            /opt/hostedtoolcache \
+            /opt/google/chrome \
+            /opt/microsoft/msedge \
+            /opt/microsoft/powershell \
+            /opt/pipx \
+            /usr/lib/mono \
+            /usr/local/julia* \
+            /usr/local/lib/android \
+            /usr/local/lib/node_modules \
+            /usr/local/share/chromium \
+            /usr/local/share/powershell \
+            /usr/share/swift
+      - name: Build C++
+        run: bash .github/scripts/build-rocm.sh
+        env:
+          build_os: ${{ matrix.os }}
+          build_arch: ${{ matrix.arch }}
+          rocm_version: ${{ matrix.rocm_version }}
+      - name: Upload build artifact
+        uses: actions/upload-artifact@v4
+        with:
+          name: shared_library_rocm_${{ matrix.os }}_${{ matrix.arch }}_${{ matrix.rocm_version }}
+          path: output/*
+          retention-days: 7
+
   build-wheels:
     needs:
       - build-shared-libs
       - build-shared-libs-cuda
+      - build-shared-libs-rocm
     strategy:
       matrix:
         os: [ubuntu-22.04, ubuntu-22.04-arm, windows-latest, macos-latest]
@@ -173,6 +218,7 @@ jobs:
           merge-multiple: true
       - name: Inspect tmp directory after downloading artifacts
         run: |
           ls -alFR tmp/
           WHEEL_COUNT=$(find tmp/ -type f -name "*.whl" | wc -l)
@@ -210,6 +256,7 @@ jobs:
       - uses: actions/checkout@v4
         with:
           path: repo
       - name: Delete old pre-release (if exists)
         run: |
           cd repo && gh release delete continuous-release_main --cleanup-tag -y
...
@@ -25,13 +25,14 @@ endif()
 # Define included source files
 set(CPP_FILES csrc/common.cpp csrc/cpu_ops.cpp csrc/pythonInterface.cpp)
 set(CUDA_FILES csrc/ops.cu csrc/kernels.cu)
+set(HIP_FILES csrc/ops.hip csrc/kernels.hip)
 set(MPS_FILES csrc/mps_ops.mm)
 set(METAL_FILES csrc/mps_kernels.metal)
 
 # C++ sources are always included
 list(APPEND SRC_FILES ${CPP_FILES})
 
-set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, mps)")
-set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda mps)
+set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, hip, mps)")
+set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda hip mps)
 option(PTXAS_VERBOSE "Pass through -v flag to PTX Assembler" OFF)
 
 if(APPLE)
@@ -47,15 +48,25 @@ if(${COMPUTE_BACKEND} STREQUAL "cuda")
         message(FATAL_ERROR "CUDA is not supported on macOS" )
     endif()
     set(BUILD_CUDA ON)
+    set(BUILD_HIP OFF)
+    set(BUILD_MPS OFF)
+elseif(${COMPUTE_BACKEND} STREQUAL "hip")
+    if(APPLE)
+        message(FATAL_ERROR "HIP is not supported on macOS" )
+    endif()
+    set(BUILD_CUDA OFF)
+    set(BUILD_HIP ON)
     set(BUILD_MPS OFF)
 elseif(${COMPUTE_BACKEND} STREQUAL "mps")
     if(NOT APPLE)
         message(FATAL_ERROR "MPS is only supported on macOS" )
     endif()
     set(BUILD_CUDA OFF)
+    set(BUILD_HIP OFF)
     set(BUILD_MPS ON)
 else()
     set(BUILD_CUDA OFF)
+    set(BUILD_HIP OFF)
     set(BUILD_MPS OFF)
 endif()
@@ -160,6 +171,33 @@ if(BUILD_CUDA)
     string(APPEND BNB_OUTPUT_NAME "_cuda${CUDA_VERSION_SHORT}")
     add_compile_definitions(BUILD_CUDA)
+elseif(BUILD_HIP)
+    enable_language(HIP)
+    message(STATUS "HIP Compiler: ${CMAKE_HIP_COMPILER}")
+
+    if(DEFINED BNB_ROCM_ARCH)
+        set(CMAKE_HIP_ARCHITECTURES ${BNB_ROCM_ARCH})
+    else()
+        if(NOT AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES)
+            set(CMAKE_HIP_ARCHITECTURES "gfx90a;gfx942;gfx1100")
+        elseif(AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES)
+            set(CMAKE_HIP_ARCHITECTURES ${AMDGPU_TARGETS})
+        endif()
+    endif()
+    message(STATUS "HIP Targets: ${CMAKE_HIP_ARCHITECTURES}")
+
+    list(APPEND SRC_FILES ${HIP_FILES})
+
+    string(APPEND BNB_OUTPUT_NAME "_rocm")
+
+    # Get the HIP version and append it (major.minor without the dot) to the library name.
+    execute_process(COMMAND hipconfig --version OUTPUT_VARIABLE HIP_CONFIG_VERSION)
+    string(REGEX MATCH "[0-9]+\\.[0-9]+" HIP_VERSION "${HIP_CONFIG_VERSION}")
+    string(REPLACE "." "" HIP_VERSION_SHORT "${HIP_VERSION}")
+    string(APPEND BNB_OUTPUT_NAME "${HIP_VERSION_SHORT}")
+
+    add_compile_definitions(__HIP_PLATFORM_AMD__)
+    add_compile_definitions(__HIP_PLATFORM_HCC__)
+    add_compile_definitions(BUILD_HIP)
 elseif(BUILD_MPS)
     if(NOT APPLE)
         message(FATAL_ERROR "MPS is only supported on macOS" )
@@ -208,6 +246,41 @@ if(BUILD_CUDA)
         CUDA_SEPARABLE_COMPILATION ON
     )
 endif()
+if(BUILD_HIP)
+    if(NOT DEFINED ENV{ROCM_PATH})
+        set(ROCM_PATH /opt/rocm)
+    else()
+        set(ROCM_PATH $ENV{ROCM_PATH})
+    endif()
+    list(APPEND CMAKE_PREFIX_PATH ${ROCM_PATH})
+
+    macro(find_package_and_print_version PACKAGE_NAME)
+        find_package("${PACKAGE_NAME}" ${ARGN})
+        message("${PACKAGE_NAME} VERSION: ${${PACKAGE_NAME}_VERSION}")
+    endmacro()
+
+    find_package_and_print_version(hipblas REQUIRED)
+    find_package_and_print_version(hiprand REQUIRED)
+    find_package_and_print_version(hipsparse REQUIRED)
+
+    # Hacky way of excluding hip::amdhip64: with it linked, many tests (e.g. adam8bit)
+    # unexpectedly fail because of numerical inaccuracies.
+    set_target_properties(hip::host PROPERTIES INTERFACE_LINK_LIBRARIES "")
+    set_target_properties(hip-lang::host PROPERTIES INTERFACE_LINK_LIBRARIES "")
+    set(CMAKE_HIP_IMPLICIT_LINK_LIBRARIES "")
+
+    target_include_directories(bitsandbytes PRIVATE ${CMAKE_SOURCE_DIR} ${CMAKE_SOURCE_DIR}/include ${ROCM_PATH}/include /include)
+    target_link_directories(bitsandbytes PRIVATE ${ROCM_PATH}/lib /lib)
+    target_link_libraries(bitsandbytes PUBLIC roc::hipblas hip::hiprand roc::hipsparse)
+
+    target_compile_definitions(bitsandbytes PUBLIC BNB_USE_HIP)
+    set_source_files_properties(${HIP_FILES} PROPERTIES LANGUAGE HIP)
+    set_target_properties(bitsandbytes PROPERTIES LINKER_LANGUAGE CXX)
+
+    if(HIP_VERSION VERSION_LESS "6.1")
+        target_compile_definitions(bitsandbytes PUBLIC NO_HIPBLASLT)
+    else()
+        find_package(hipblaslt)
+        target_link_libraries(bitsandbytes PUBLIC roc::hipblaslt)
+    endif()
+endif()
 if(BUILD_MPS)
     add_dependencies(bitsandbytes metallib)
     target_link_libraries(bitsandbytes objc "-framework Foundation" "-framework Metal" "-framework MetalPerformanceShaders" "-framework MetalPerformanceShadersGraph")
...
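On the ROCm path the build derives the library suffix from hipconfig --version, so a ROCm 6.3 toolchain yields libbitsandbytes_rocm63.so. A small Python sketch of that naming rule (the sample hipconfig output string is illustrative, not captured from a real system):

import re

def rocm_suffix(hipconfig_version: str) -> str:
    # Mimic the CMake logic: keep major.minor, then drop the dot.
    match = re.search(r"[0-9]+\.[0-9]+", hipconfig_version)
    return match.group(0).replace(".", "") if match else ""

# e.g. "6.3.42133-1b9c17779" -> "63" -> libbitsandbytes_rocm63.so
assert rocm_suffix("6.3.42133-1b9c17779") == "63"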
@@ -8,7 +8,7 @@ import torch
 from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr
 
 from ..._ops import register_kernel
-from ...cextension import lib
+from ...cextension import HIP_ENVIRONMENT, lib
 
 
 @register_kernel("bitsandbytes::int8_linear_matmul", "cuda")
@@ -210,7 +210,12 @@ def _get_col_absmax(
 @register_kernel("bitsandbytes::quantize_blockwise", "cuda")
 def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
     torch._check_is_size(blocksize)
-    torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
+
+    if HIP_ENVIRONMENT:
+        torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
+    else:
+        torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
     torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}")
 
     n = A.numel()
@@ -264,7 +269,11 @@ def _(
 def _dequantize_blockwise_impl(
     A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor
 ) -> None:
-    torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
+    if HIP_ENVIRONMENT:
+        torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
+    else:
+        torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
     torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
     torch._check(
         dtype in [torch.float16, torch.bfloat16, torch.float32],
@@ -294,7 +303,11 @@ def _dequantize_blockwise_impl(
 def _(
     A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
 ) -> tuple[torch.Tensor, torch.Tensor]:
-    torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
+    if HIP_ENVIRONMENT:
+        torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
+    else:
+        torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
     torch._check(quant_type in ["fp4", "nf4"])
     torch._check(
         A.dtype in [torch.bfloat16, torch.float16, torch.float32],
@@ -372,7 +385,11 @@ def _dequantize_4bit_impl(
     dtype: torch.dtype,
     out: torch.Tensor,
 ) -> None:
-    torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
+    if HIP_ENVIRONMENT:
+        torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
+    else:
+        torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
     torch._check(quant_type in ["fp4", "nf4"])
     torch._check(
         dtype in [torch.bfloat16, torch.float16, torch.float32],
...
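The net effect of the HIP_ENVIRONMENT branches above: ROCm builds reject the blocksize of 64 that CUDA builds accept, so 128 becomes the smallest valid block. A standalone sketch of the shared guard (names are local to this example, not from the commit):

import torch

HIP_ENVIRONMENT = bool(torch.version.hip)  # same signal the library uses

def check_blocksize(blocksize: int) -> None:
    # ROCm builds drop 64 from the supported list; everything else is shared.
    if HIP_ENVIRONMENT:
        supported = [4096, 2048, 1024, 512, 256, 128]
    else:
        supported = [4096, 2048, 1024, 512, 256, 128, 64]
    torch._check(blocksize in supported, lambda: f"unsupported blocksize {blocksize}")

check_blocksize(128)  # accepted on both backends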
@@ -9,7 +9,7 @@ from typing import Optional
 import torch
 
 from bitsandbytes.consts import DYNAMIC_LIBRARY_SUFFIX, PACKAGE_DIR
-from bitsandbytes.cuda_specs import CUDASpecs, get_cuda_specs, get_cuda_version_tuple
+from bitsandbytes.cuda_specs import CUDASpecs, get_cuda_specs, get_cuda_version_tuple, get_rocm_gpu_arch
 
 logger = logging.getLogger(__name__)
@@ -28,6 +28,11 @@ def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path:
     override_value = os.environ.get("BNB_CUDA_VERSION")
     if override_value:
         library_name = re.sub(r"cuda\d+", f"cuda{override_value}", library_name, count=1)
+        if torch.version.hip:
+            raise RuntimeError(
+                f"BNB_CUDA_VERSION={override_value} detected for ROCm!! \n"
+                f"Clear the variable and retry: export BNB_CUDA_VERSION=\n"
+            )
         logger.warning(
             f"WARNING: BNB_CUDA_VERSION={override_value} environment variable detected; loading {library_name}.\n"
             "This can be used to load a bitsandbytes version built with a CUDA version that is different from the PyTorch CUDA version.\n"
@@ -75,10 +80,11 @@ class CudaBNBNativeLibrary(BNBNativeLibrary):
 def get_available_cuda_binary_versions() -> list[str]:
     """Get formatted CUDA versions from existing library files using cuda_specs logic"""
-    lib_pattern = f"libbitsandbytes_cuda*{DYNAMIC_LIBRARY_SUFFIX}"
+    lib_pattern = f"libbitsandbytes_{BNB_BACKEND.lower()}*{DYNAMIC_LIBRARY_SUFFIX}"
     versions = []
     for lib in Path(__file__).parent.glob(lib_pattern):
-        match = re.search(r"cuda(\d{3})", lib.name)
+        pattern = rf"{BNB_BACKEND.lower()}(\d+)"
+        match = re.search(pattern, lib.name)
         if match:
             ver_code = int(match.group(1))
             major = ver_code // 10
@@ -89,8 +95,8 @@ def get_available_cuda_binary_versions() -> list[str]:
 def parse_cuda_version(version_str: str) -> str:
     """Convert raw version string (e.g. '118' from env var) to formatted version (e.g. '11.8')"""
-    if version_str.isdigit() and len(version_str) == 3:
-        return f"{version_str[:2]}.{version_str[2]}"
+    if version_str.isdigit():
+        return f"{version_str[:-1]}.{version_str[-1]}"
     return version_str  # fallback as safety net
@@ -151,7 +157,7 @@ class ErrorHandlerMockBNBNativeLibrary(BNBNativeLibrary):
         """Format detailed error message for library loading failures"""
         analysis = ""
         no_cpu_lib_found = "libbitsandbytes_cpu.so: cannot open" in original_error
-        no_cuda_lib_found = "CUDA binary not found" in original_error
+        no_cuda_lib_found = f"{BNB_BACKEND} binary not found" in original_error
 
         if no_cpu_lib_found:
             analysis = "\n🚨 Failed to load CPU-only bitsandbytes library 🚨\n\n"
@@ -160,9 +166,9 @@ class ErrorHandlerMockBNBNativeLibrary(BNBNativeLibrary):
             version_list_str = "\n  - " + "\n  - ".join(available_versions) if available_versions else "NONE"
             analysis = (
                 (
-                    f"\n🚨 CUDA VERSION MISMATCH 🚨\n"
-                    f"Requested CUDA version: {requested_version}\n"
-                    f"Detected PyTorch CUDA version: {user_cuda_version}\n"
+                    f"\n🚨 {BNB_BACKEND} VERSION MISMATCH 🚨\n"
+                    f"Requested {BNB_BACKEND} version: {requested_version}\n"
+                    f"Detected PyTorch {BNB_BACKEND} version: {user_cuda_version}\n"
                     f"Available pre-compiled versions: {version_list_str}\n\n"
                     "This means:\n"
                     "The version you're trying to use is NOT distributed with this package\n\n"
@@ -177,42 +183,47 @@ class ErrorHandlerMockBNBNativeLibrary(BNBNativeLibrary):
         troubleshooting = (
             (
-                "This typically happens when:\n"
-                "1. bitsandbytes doesn't ship with a pre-compiled binary for your CUDA version\n"
-                "2. The library wasn't compiled properly during installation from source\n\n"
+                f"This typically happens when:\n"
+                f"1. bitsandbytes doesn't ship with a pre-compiled binary for your {BNB_BACKEND} version\n"
+                f"2. The library wasn't compiled properly during installation from source\n\n"
             )
             if no_cuda_lib_found
-            else "This typically happens when you checked the code out from source and your torch installation doesn't detect CUDA on your machine.\n\n"
+            else f"This typically happens when you checked the code out from source and your torch installation doesn't detect {BNB_BACKEND} on your machine.\n\n"
         )
 
         note = (
             (
-                "To make bitsandbytes work, the compiled library version MUST exactly match the linked CUDA version.\n"
-                "If your CUDA version doesn't have a pre-compiled binary, you MUST compile from source.\n\n"
+                f"To make bitsandbytes work, the compiled library version MUST exactly match the linked {BNB_BACKEND} version.\n"
+                f"If your {BNB_BACKEND} version doesn't have a pre-compiled binary, you MUST compile from source.\n\n"
             )
             if no_cuda_lib_found
             else ""
         )
 
         compile_instructions = (
-            (
+            ("COMPILE FROM SOURCE for CPU-only:\n  `cmake -DCOMPUTE_BACKEND=cpu -S . && make`\n\n")
+            if not no_cuda_lib_found
+            else (
                 "You have two options:\n"
                 "1. COMPILE FROM SOURCE (required if no binary exists):\n"
                 "   https://huggingface.co/docs/bitsandbytes/main/en/installation#cuda-compile\n"
                 "2. Use BNB_CUDA_VERSION to specify a DIFFERENT CUDA version from the detected one, which is installed on your machine and matching an available pre-compiled version listed above\n\n"
             )
-            if no_cuda_lib_found
-            else "COMPILE FROM SOURCE for CPU-only:\n  `cmake -DCOMPUTE_BACKEND=cpu -S . && make`\n\n"
+            if not HIP_ENVIRONMENT
+            else (
+                "You can COMPILE FROM SOURCE as mentioned here:\n"
+                "  https://huggingface.co/docs/bitsandbytes/main/en/installation?backend=AMD+ROCm#amd-gpu\n"
+            )
         )
 
         diagnostics = (
-            "🔍 Run this command for detailed diagnostics:\n"
-            "python -m bitsandbytes\n\n"
-            "If you've tried everything and still have issues:\n"
-            "1. Include ALL version info (operating system, bitsandbytes, pytorch, cuda, python)\n"
-            "2. Describe what you've tried in detail\n"
-            "3. Open an issue with this information:\n"
-            "   https://github.com/bitsandbytes-foundation/bitsandbytes/issues\n\n"
+            f"🔍 Run this command for detailed diagnostics:\n"
+            f"python -m bitsandbytes\n\n"
+            f"If you've tried everything and still have issues:\n"
+            f"1. Include ALL version info (operating system, bitsandbytes, pytorch, {BNB_BACKEND.lower()}, python)\n"
+            f"2. Describe what you've tried in detail\n"
+            f"3. Open an issue with this information:\n"
+            f"   https://github.com/bitsandbytes-foundation/bitsandbytes/issues\n\n"
         )
 
         return f"{analysis}{base_msg}{troubleshooting}{note}{compile_instructions}{original_error}\n{diagnostics}"
@@ -227,18 +238,19 @@ class ErrorHandlerMockBNBNativeLibrary(BNBNativeLibrary):
             )
 
             return (
-                f"\n🚨 CUDA SETUP ERROR: Missing dependency: {missing_lib} 🚨\n\n"
-                f"CUDA {cuda_major_version}.x runtime libraries were not found in the LD_LIBRARY_PATH.\n\n"
+                f"\n🚨 {BNB_BACKEND} SETUP ERROR: Missing dependency: {missing_lib} 🚨\n\n"
+                f"{BNB_BACKEND} {cuda_major_version}.x runtime libraries were not found in the LD_LIBRARY_PATH.\n\n"
                 f"To fix this, make sure that:\n"
-                f"1. You have installed CUDA {cuda_major_version}.x toolkit on your system\n"
-                f"2. The CUDA runtime libraries are in your LD_LIBRARY_PATH\n\n"
+                f"1. You have installed {BNB_BACKEND} {cuda_major_version}.x toolkit on your system\n"
+                f"2. The {BNB_BACKEND} runtime libraries are in your LD_LIBRARY_PATH\n\n"
                 f"You can add them with (and persist the change by adding the line to your .bashrc):\n"
-                f"  export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/path/to/cuda-{cuda_major_version}.x/lib64\n\n"
+                f"  export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/path/to/{BNB_BACKEND.lower()}-{cuda_major_version}.x/\
+{'lib64' if not HIP_ENVIRONMENT else 'lib'}\n\n"
                 f"Original error: {self.error_msg}\n\n"
                 f"🔍 Run this command for detailed diagnostics:\n"
                 f"python -m bitsandbytes\n\n"
                 f"If you've tried everything and still have issues:\n"
-                f"1. Include ALL version info (operating system, bitsandbytes, pytorch, cuda, python)\n"
+                f"1. Include ALL version info (operating system, bitsandbytes, pytorch, {BNB_BACKEND.lower()}, python)\n"
                 f"2. Describe what you've tried in detail\n"
                 f"3. Open an issue with this information:\n"
                 f"   https://github.com/bitsandbytes-foundation/bitsandbytes/issues\n\n"
@@ -267,7 +279,7 @@ def get_native_library() -> BNBNativeLibrary:
         cuda_binary_path = get_cuda_bnb_library_path(cuda_specs)
 
         if not cuda_binary_path.exists():
-            raise RuntimeError(f"Configured CUDA binary not found at {cuda_binary_path}")
+            raise RuntimeError(f"Configured {BNB_BACKEND} binary not found at {cuda_binary_path}")
 
         binary_path = cuda_binary_path
@@ -286,6 +298,8 @@ def get_native_library() -> BNBNativeLibrary:
     return BNBNativeLibrary(dll)
 
 
+ROCM_GPU_ARCH = get_rocm_gpu_arch()
+
 try:
     # to support Intel CPU/GPU (XPU) backend
     import intel_extension_for_pytorch as ipex
@@ -296,8 +310,12 @@ except BaseException:
     ipex_cpu = None
     ipex_xpu = None
 
 try:
+    if torch.version.hip:
+        HIP_ENVIRONMENT, BNB_BACKEND = True, "ROCm"
+    else:
+        HIP_ENVIRONMENT, BNB_BACKEND = False, "CUDA"
+
     lib = get_native_library()
 except Exception as e:
     error_msg = str(e)
...
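parse_cuda_version is the piece that now serves both backends: instead of requiring exactly three digits (CUDA codes like "118"), it splits any digit string before its last digit, which also handles two-digit ROCm codes like "63". A standalone copy of the new behavior:

def parse_cuda_version(version_str: str) -> str:
    # New behavior: split any all-digit code before its last digit.
    if version_str.isdigit():
        return f"{version_str[:-1]}.{version_str[-1]}"
    return version_str

assert parse_cuda_version("118") == "11.8"  # CUDA: libbitsandbytes_cuda118.so
assert parse_cuda_version("63") == "6.3"    # ROCm: libbitsandbytes_rocm63.so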
 import dataclasses
 from functools import lru_cache
+import logging
+import re
+import subprocess
 from typing import Optional
 
 import torch
@@ -73,3 +76,27 @@ def get_cuda_specs() -> Optional[CUDASpecs]:
         )
     except Exception:
         return None
+
+
+def get_rocm_gpu_arch() -> str:
+    """Get ROCm GPU architecture."""
+    logger = logging.getLogger(__name__)
+    try:
+        if torch.version.hip:
+            result = subprocess.run(["rocminfo"], capture_output=True, text=True)
+            match = re.search(r"Name:\s+gfx([a-zA-Z\d]+)", result.stdout)
+            if match:
+                return "gfx" + match.group(1)
+            else:
+                return "unknown"
+        else:
+            return "unknown"
+    except Exception as e:
+        logger.error(f"Could not detect ROCm GPU architecture: {e}")
+        if torch.cuda.is_available():
+            logger.warning(
+                """
+                ROCm GPU architecture detection failed despite ROCm being available.
+                """,
+            )
+        return "unknown"
@@ -5,7 +5,7 @@ from pathlib import Path
 
 import torch
 
-from bitsandbytes.cextension import get_cuda_bnb_library_path
+from bitsandbytes.cextension import HIP_ENVIRONMENT, get_cuda_bnb_library_path
 from bitsandbytes.cuda_specs import CUDASpecs
 from bitsandbytes.diagnostics.utils import print_dedented
@@ -32,9 +32,13 @@ CUDART_PATH_IGNORED_ENVVARS = {
 }
 
 CUDA_RUNTIME_LIB_PATTERNS = (
-    "cudart64*.dll",  # Windows
-    "libcudart*.so*",  # libcudart.so, libcudart.so.11.0, libcudart.so.12.0, libcudart.so.12.1, libcudart.so.12.2 etc.
-    "nvcuda*.dll",  # Windows
+    ("libamdhip64.so*",)
+    if HIP_ENVIRONMENT
+    else (
+        "cudart64*.dll",  # Windows
+        "libcudart*.so*",  # libcudart.so, libcudart.so.11.0, libcudart.so.12.0, libcudart.so.12.1, libcudart.so.12.2 etc.
+        "nvcuda*.dll",  # Windows
+    )
 )
 
 logger = logging.getLogger(__name__)
@@ -56,7 +60,7 @@ def find_cuda_libraries_in_path_list(paths_list_candidate: str) -> Iterable[Path
                 pass
             for lib_pattern in CUDA_RUNTIME_LIB_PATTERNS:
                 for pth in dir.glob(lib_pattern):
-                    if pth.is_file():
+                    if pth.is_file() and not pth.is_symlink():
                         yield pth
         except (OSError, PermissionError):
             pass
@@ -103,7 +107,7 @@ def find_cudart_libraries() -> Iterator[Path]:
         yield from find_cuda_libraries_in_path_list(value)
 
 
-def print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None:
+def _print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None:
     print(
         f"PyTorch settings found: CUDA_VERSION={cuda_specs.cuda_version_string}, "
         f"Highest Compute Capability: {cuda_specs.highest_compute_capability}.",
@@ -128,7 +132,37 @@ def print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None:
     )
 
 
-def print_cuda_runtime_diagnostics() -> None:
+def _print_hip_diagnostics(cuda_specs: CUDASpecs) -> None:
+    print(f"PyTorch settings found: ROCM_VERSION={cuda_specs.cuda_version_string}")
+
+    binary_path = get_cuda_bnb_library_path(cuda_specs)
+    if not binary_path.exists():
+        print_dedented(
+            f"""
+            Library not found: {binary_path}.
+            Maybe you need to compile it from source? If you compiled from source, check that ROCm version
+            in PyTorch Settings matches your ROCm install. If not, reinstall PyTorch for your ROCm version
+            and rebuild bitsandbytes.
+            """,
+        )
+
+    hip_major, hip_minor = cuda_specs.cuda_version_tuple
+    if (hip_major, hip_minor) < (6, 1):
+        print_dedented(
+            """
+            WARNING: bitsandbytes is fully supported only from ROCm 6.1.
+            """,
+        )
+
+
+def print_diagnostics(cuda_specs: CUDASpecs) -> None:
+    if HIP_ENVIRONMENT:
+        _print_hip_diagnostics(cuda_specs)
+    else:
+        _print_cuda_diagnostics(cuda_specs)
+
+
+def _print_cuda_runtime_diagnostics() -> None:
     cudart_paths = list(find_cudart_libraries())
     if not cudart_paths:
         print("CUDA SETUP: WARNING! CUDA runtime files not found in any environmental path.")
@@ -153,3 +187,33 @@ def print_cuda_runtime_diagnostics() -> None:
     )
     for pth in cudart_paths:
         print(f"* Found CUDA runtime at: {pth}")
+
+
+def _print_hip_runtime_diagnostics() -> None:
+    cudart_paths = list(find_cudart_libraries())
+    if not cudart_paths:
+        print("WARNING! ROCm runtime files not found in any environmental path.")
+    elif len(cudart_paths) > 1:
+        print_dedented(
+            f"""
+            Found duplicate ROCm runtime files (see below).
+            We select the PyTorch default ROCm runtime, which is {torch.version.hip},
+            but this might mismatch with the ROCm version that is needed for bitsandbytes.
+            To resolve it, install PyTorch built for the ROCm version you want to use
+            and set LD_LIBRARY_PATH to your ROCm install path, e.g.
+            export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm-6.1.2/lib,
+            """,
+        )
+
+        for pth in cudart_paths:
+            print(f"* Found ROCm runtime at: {pth}")
+
+
+def print_runtime_diagnostics() -> None:
+    if HIP_ENVIRONMENT:
+        _print_hip_runtime_diagnostics()
+    else:
+        _print_cuda_runtime_diagnostics()
@@ -6,10 +6,11 @@ import traceback
 import torch
 
 from bitsandbytes import __version__ as bnb_version
+from bitsandbytes.cextension import BNB_BACKEND
 from bitsandbytes.consts import PACKAGE_GITHUB_URL
 from bitsandbytes.cuda_specs import get_cuda_specs
 from bitsandbytes.diagnostics.cuda import (
-    print_cuda_diagnostics,
+    print_diagnostics,
 )
 from bitsandbytes.diagnostics.utils import print_dedented, print_header
@@ -77,19 +78,19 @@ def main():
     cuda_specs = get_cuda_specs()
 
     if cuda_specs:
-        print_cuda_diagnostics(cuda_specs)
+        print_diagnostics(cuda_specs)
 
     # TODO: There's a lot of noise in this; needs improvement.
     # print_cuda_runtime_diagnostics()
 
     if not torch.cuda.is_available():
-        print("PyTorch says CUDA is not available. Possible reasons:")
-        print("1. CUDA driver not installed")
+        print(f"PyTorch says {BNB_BACKEND} is not available. Possible reasons:")
+        print(f"1. {BNB_BACKEND} driver not installed")
         print("2. Using a CPU-only PyTorch build")
         print("3. No GPU detected")
     else:
-        print("Checking that the library is importable and CUDA is callable...")
+        print(f"Checking that the library is importable and {BNB_BACKEND} is callable...")
 
         try:
             sanity_check()
...
@@ -15,7 +15,7 @@ from typing_extensions import deprecated
 
 from bitsandbytes.utils import _reverse_4bit_compress_format, pack_dict_to_tensor, unpack_tensor_to_dict
 
-from .cextension import ipex_cpu, ipex_xpu, lib
+from .cextension import HIP_ENVIRONMENT, ipex_cpu, ipex_xpu, lib
 
 name2qmap = {}
@@ -868,10 +868,12 @@ def quantize_fp4(
     A: torch.Tensor,
     absmax: Optional[torch.Tensor] = None,
     out: Optional[torch.Tensor] = None,
-    blocksize=64,
+    blocksize=None,
     compress_statistics=False,
     quant_storage=torch.uint8,
 ):
+    if blocksize is None:
+        blocksize = 64 if not HIP_ENVIRONMENT else 128
     return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "fp4", quant_storage)
@@ -879,10 +881,12 @@ def quantize_nf4(
     A: torch.Tensor,
     absmax: Optional[torch.Tensor] = None,
     out: Optional[torch.Tensor] = None,
-    blocksize=64,
+    blocksize=None,
     compress_statistics=False,
     quant_storage=torch.uint8,
 ):
+    if blocksize is None:
+        blocksize = 64 if not HIP_ENVIRONMENT else 128
     return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "nf4", quant_storage)
@@ -890,7 +894,7 @@ def quantize_4bit(
     A: torch.Tensor,
     absmax: Optional[torch.Tensor] = None,
     out: Optional[torch.Tensor] = None,
-    blocksize=64,
+    blocksize=None,
     compress_statistics=False,
     quant_type="fp4",
     quant_storage=torch.uint8,
@@ -904,7 +908,7 @@ def quantize_4bit(
         absmax (`torch.Tensor`, *optional*): A tensor to use to store the absmax values.
         out (`torch.Tensor`, *optional*): A tensor to use to store the result.
         blocksize (`int`, *optional*):
-            The size of the blocks. Defaults to 64.
+            The size of the blocks. Defaults to 128 on ROCm and 64 otherwise.
             Valid values are 64, 128, 256, 512, 1024, 2048, and 4096.
         compress_statistics (`bool`, *optional*): Whether to additionally quantize the absmax values. Defaults to False.
         quant_type (`str`, *optional*): The data type to use: `nf4` or `fp4`. Defaults to `fp4`.
@@ -918,6 +922,10 @@ def quantize_4bit(
         - `torch.Tensor`: The quantized tensor with packed 4-bit values.
         - [`QuantState`]: The state object used to undo the quantization.
     """
+    if blocksize is None:
+        blocksize = 64 if not HIP_ENVIRONMENT else 128
+
     input_shape = A.shape
 
     _out, _absmax = torch.ops.bitsandbytes.quantize_4bit.default(
@@ -968,8 +976,10 @@ def dequantize_fp4(
     quant_state: Optional[QuantState] = None,
     absmax: Optional[torch.Tensor] = None,
     out: Optional[torch.Tensor] = None,
-    blocksize: int = 64,
+    blocksize: Optional[int] = None,
 ) -> torch.Tensor:
+    if blocksize is None:
+        blocksize = 64 if not HIP_ENVIRONMENT else 128
     return dequantize_4bit(A, quant_state, absmax, out, blocksize, "fp4")
@@ -978,8 +988,10 @@ def dequantize_nf4(
     quant_state: Optional[QuantState] = None,
     absmax: Optional[torch.Tensor] = None,
     out: Optional[torch.Tensor] = None,
-    blocksize: int = 64,
+    blocksize: Optional[int] = None,
 ) -> torch.Tensor:
+    if blocksize is None:
+        blocksize = 64 if not HIP_ENVIRONMENT else 128
     return dequantize_4bit(A, quant_state, absmax, out, blocksize, "nf4")
@@ -988,7 +1000,7 @@ def dequantize_4bit(
     quant_state: Optional[QuantState] = None,
    absmax: Optional[torch.Tensor] = None,
     out: Optional[torch.Tensor] = None,
-    blocksize: int = 64,
+    blocksize: Optional[int] = None,
     quant_type="fp4",
 ) -> torch.Tensor:
     """Dequantizes a packed 4-bit quantized tensor.
@@ -1007,7 +1019,7 @@ def dequantize_4bit(
             Required if `quant_state` is not provided and ignored otherwise.
         out (`torch.Tensor`, *optional*): A tensor to use to store the result.
         blocksize (`int`, *optional*):
-            The size of the blocks. Defaults to 64.
+            The size of the blocks. Defaults to 128 on ROCm and 64 otherwise.
             Valid values are 64, 128, 256, 512, 1024, 2048, and 4096.
         quant_type (`str`, *optional*): The data type to use: `nf4` or `fp4`. Defaults to `fp4`.
@@ -1017,6 +1029,10 @@ def dequantize_4bit(
 
     Returns:
         `torch.Tensor`: The dequantized tensor.
     """
+    if blocksize is None:
+        blocksize = 64 if not HIP_ENVIRONMENT else 128
+
     if quant_state is None:
         assert absmax is not None and out is not None
...
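All of these entry points now resolve the blocksize default at call time, so callers that never pass blocksize get 128 on ROCm and keep 64 elsewhere. A sketch of the shared resolution rule (the helper name is ours, not from the commit):

import torch

HIP_ENVIRONMENT = bool(torch.version.hip)

def resolve_blocksize(blocksize=None) -> int:
    # Same default rule used by quantize_fp4/nf4, dequantize_fp4/nf4 and Params4bit.
    if blocksize is None:
        blocksize = 64 if not HIP_ENVIRONMENT else 128
    return blocksize

print(resolve_blocksize())     # 64 on CUDA builds, 128 on ROCm builds
print(resolve_blocksize(256))  # explicit values pass through unchanged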
@@ -11,6 +11,7 @@ from torch import Tensor, device, dtype, nn
 import torch.nn.functional as F
 
 import bitsandbytes as bnb
+from bitsandbytes.cextension import HIP_ENVIRONMENT
 from bitsandbytes.functional import QuantState, _enable_ipex_fusion, ipex_cpu, ipex_xpu
 from bitsandbytes.optim import GlobalOptimManager
 from bitsandbytes.utils import (
@@ -213,7 +214,7 @@ class Params4bit(torch.nn.Parameter):
         data: Optional[torch.Tensor] = None,
         requires_grad=False,  # quantized weights should be frozen by default
         quant_state: Optional[QuantState] = None,
-        blocksize: int = 64,
+        blocksize: Optional[int] = None,
         compress_statistics: bool = True,
         quant_type: str = "fp4",
         quant_storage: torch.dtype = torch.uint8,
@@ -223,6 +224,9 @@ class Params4bit(torch.nn.Parameter):
         if data is None:
             data = torch.empty(0)
 
+        if blocksize is None:
+            blocksize = 64 if not HIP_ENVIRONMENT else 128
+
         self = torch.Tensor._make_subclass(cls, data, requires_grad)
         self.blocksize = blocksize
         self.compress_statistics = compress_statistics
...
#pragma once
#define BNB_WARP_SIZE warpSize
// These are set based on current BNB support for CDNA 2 & RDNA 3. Update as needed for future archs
#define BNB_MAX_THREADS_PER_SM 2048
#define BNB_BF16_AVAILABLE true
...
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
#include <float.h>
#include <ops_hip.cuh>
#ifndef kernels
#define kernels
__global__ void kQuantize(float* code, float* __restrict__ const A, unsigned char* out, const int n);
__global__ void kDequantize(float* code, unsigned char* A, float* out, const int n);
template <typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC, int DATA_TYPE>
__global__ void kQuantizeBlockwise(
float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand,
const int rand_offset, const int n
);
template <typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH, int DATA_TYPE>
__global__ void
kDequantizeBlockwise(float* code, unsigned char* A, float* absmax, T* out, const int blocksize, const int n);
template <typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
__global__ void kPreconditionOptimizer32bit2State(
T* g, T* p, float* state1, float* state2, float* unorm, const float beta1, const float beta2, const float eps,
const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n
);
template <typename T, int OPTIMIZER>
__global__ void kOptimizer32bit2State(
T* g, T* p, float* state1, float* state2, float* unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float beta3, const float alpha, const float eps,
const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros,
const int n
);
template <typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
__global__ void kPreconditionOptimizer32bit1State(
T* g, T* p, float* state1, float* unorm, const float beta1, const float beta2, const float eps,
const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n
);
template <typename T, int OPTIMIZER>
__global__ void kOptimizer32bit1State(
T* g, T* p, float* state1, float* unorm, const float max_unorm, const float param_norm, const float beta1,
const float beta2, const float eps, const float weight_decay, const int step, const float lr,
const float gnorm_scale, const bool skip_zeros, const int n
);
template <typename T, int OPTIMIZER>
__global__ void kPreconditionOptimizerStatic8bit1State(
T* p, T* __restrict__ const g, unsigned char* __restrict__ const state1, float* unorm, const float beta1,
const float beta2, const float eps, const int step, float* __restrict__ const quantiles1, float* max1,
float* new_max1, const float weight_decay, const float gnorm_scale, const int n
);
template <typename T, int OPTIMIZER>
__global__ void kOptimizerStatic8bit1State(
T* p, T* const g, unsigned char* state1, const float* unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float eps, const int step, const float lr,
float* __restrict__ const quantiles1, float* max1, float* new_max1, float weight_decay, const float gnorm_scale,
const int n
);
template <typename T, int OPTIMIZER>
__global__ void kPreconditionOptimizerStatic8bit2State(
T* p, T* __restrict__ const g, unsigned char* __restrict__ const state1, unsigned char* __restrict__ const state2,
float* unorm, const float beta1, const float beta2, const float eps, const int step,
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* max1, float* max2,
float* new_max1, float* new_max2, const float gnorm_scale, const int n
);
template <typename T, int OPTIMIZER>
__global__ void kOptimizerStatic8bit2State(
T* p, T* const g, unsigned char* state1, unsigned char* state2, const float* unorm, const float max_unorm,
const float param_norm, const float beta1, const float beta2, const float eps, const int step, const float lr,
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* max1, float* max2,
float* new_max1, float* new_max2, float weight_decay, const float gnorm_scale, const int n
);
template <typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH>
__global__ void kOptimizerStatic8bit2StateBlockwise(
T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2, const float beta1, const float beta2,
const float beta3, const float alpha, const float eps, const int step, const float lr,
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* absmax1, float* absmax2,
float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n
);
template <typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH>
__global__ void kOptimizerStatic8bit1StateBlockwise(
T* p, T* __restrict__ const g, unsigned char* state1, const float beta1, const float beta2, const float eps,
const int step, const float lr, float* __restrict__ const quantiles1, float* absmax1, float weight_decay,
const float gnorm_scale, const bool skip_zeros, const int n
);
template <typename T, int BLOCK_SIZE, int NUM_VALS>
__global__ void kPercentileClipping(T* __restrict__ g, float* gnorm_vec, int step, const int n);
template <typename T, int SPMM_ITEMS, int BITS>
__global__ void kspmm_coo_very_sparse_naive(
int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, T* B, half* out,
float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB
);
template <int ITEMS_PER_THREAD, int THREADS>
__global__ void kdequant_mm_int32_fp16(
int* __restrict__ const A, float* __restrict__ const rowStats, float* __restrict__ const colStats, half* out,
half* __restrict__ const bias, const int numRows, const int numCols, const int n
);
template <typename T, int THREADS, int SPARSE_DECOMP>
__global__ void kgetRowStats(T* __restrict__ A, float* rowStats, float threshold, int rows, int cols);
template <typename T, int THREADS, int SPARSE_DECOMP>
__global__ void kInt8VectorQuant(T* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols);
template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int TRANSPOSE, int FORMAT>
__global__ void kTransformRowToFormat(
char* __restrict__ const A, char* out, int rows, int cols, int tiledCols, int outRows, int outCols
);
template <typename T, int BITS, int THREADS>
__global__ void gemm_device(int M, int N, int K, T* __restrict__ const A, T* B, T* out, int lda, int ldb, int ldc);
template <typename T, int THREADS>
__global__ void kgemm_4bit_inference(
int M, int N, int K, T* __restrict__ const A, unsigned char* B, float* absmax, T* out, int lda, int ldb, int ldc,
int blocksize
);
template <typename T, int THREADS, int BITS>
__global__ void kgemm_4bit_inference_naive(
int M, int N, int K, T* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype, T* out,
int lda, int ldb, int ldc, int blocksize
);
template <typename T, int FUNC> __global__ void kfunc(T* A, T* B, T value, long n);
#endif
...
@@ -6,11 +6,29 @@
 #if BUILD_CUDA
 #include <ops.cuh>
 #endif
+#if BUILD_HIP
+#include <ops_hip.cuh>
+#endif
 #if BUILD_MPS
 // #include <mps_ops.h>
 #endif
 #include <cpu_ops.h>
 
+// Compatibility between HIP/CUDA APIs
+#if BUILD_HIP
+#define cudaStream_t hipStream_t
+#define __nv_bfloat16 hip_bfloat16
+#define cublasLtHandle_t hipblasLtHandle_t
+#define ContextCusparse ContextHipsparse
+#define cusparseHandle_t hipsparseHandle_t
+#define cudaMallocManaged hipMallocManaged
+#define cudaMemAttachHost hipMemAttachHost
+#define cudaPeekAtLastError hipPeekAtLastError
+#define cudaDeviceGetAttribute hipDeviceGetAttribute
+#define cudaDevAttrConcurrentManagedAccess hipDeviceAttributeConcurrentManagedAccess
+#define cudaMemPrefetchAsync hipMemPrefetchAsync
+#endif
+
 // We cannot call templated code from C, so we wrap the template in a C compatible call here if necessary.
 // We use macro functions to expand all the different optimizers. Looks ugly, and is ugly, but its better than to
 // maintain all that boilerplate
@@ -18,7 +36,7 @@
 // UNMANGLED CALLS
 //===================================================================================
 
-#if BUILD_CUDA
+#if BUILD_CUDA || BUILD_HIP
 // void gemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc)
 //{ gemm_host<float>(M, N, K, A, B, out, lda, ldb, ldc, 32); }
@@ -291,7 +309,7 @@ void spmm_coo_very_sparse_naive_int8(
 #endif
 
 extern "C" {
-#if BUILD_CUDA
+#if BUILD_CUDA || BUILD_HIP
     void cquantize(float* code, float* A, unsigned char* out, int n) { quantize(code, A, out, n); }
 
     void cdequantize(float* code, unsigned char* A, float* out, int n, cudaStream_t stream) {
...
@@ -7,6 +7,8 @@ from typing import Any
 
 import torch
 
+from bitsandbytes.cextension import HIP_ENVIRONMENT
+
 test_dims_rng = random.Random(42)
@@ -21,7 +23,7 @@ def get_available_devices():
         # If the environment variable is set, use it directly.
         return [os.environ["BNB_TEST_DEVICE"]]
 
-    devices = ["cpu"]
+    devices = [] if HIP_ENVIRONMENT else ["cpu"]
 
     if hasattr(torch, "accelerator"):
         # PyTorch 2.6+ - determine accelerator using agnostic API.
...
 import pytest
 
-from bitsandbytes.cextension import get_cuda_bnb_library_path
+from bitsandbytes.cextension import HIP_ENVIRONMENT, get_cuda_bnb_library_path
 from bitsandbytes.cuda_specs import CUDASpecs
@@ -13,11 +13,13 @@ def cuda120_spec() -> CUDASpecs:
     )
 
 
+@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm")
 def test_get_cuda_bnb_library_path(monkeypatch, cuda120_spec):
     monkeypatch.delenv("BNB_CUDA_VERSION", raising=False)
     assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda120"
 
 
+@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm")
 def test_get_cuda_bnb_library_path_override(monkeypatch, cuda120_spec, caplog):
     monkeypatch.setenv("BNB_CUDA_VERSION", "110")
     assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda110"
...
...