"vscode:/vscode.git/clone" did not exist on "cdbf802860390265dd9be6ca42d043587efcf59f"
Unverified Commit 888788d7 authored by pnunna93, committed by GitHub

Enable ROCm backend with custom ops integration (#1683)



* Port ROCm changes from multi-backend-refactor branch

* Update ops.py

* Update functional.py

* Update ops.py

* Update ops.py

* Update ops.py

* Update ops.py

* Update functional.py

* Update ops.py

* Update ops.py

* Update ops.py

* Update ops.py

* Update functional.py

* Update functional.py

* Update functional.py

* Update functional.py

* Update ops.py

* Update ops.py

* Update ops.py

* Update ops.py

* Update ops.py

* Update ops.py

* Update ops.py

* Update ops.py

* Update ops.py

* Update functional.py

* Update functional.py

* Update functional.py

* Update test_ops.py

* Update test_functional.py

* Update test_ops.py

* Update test_functional.py

* Update test_functional.py

* Update functional.py

* Update functional.py

* Update ops.py

* Update ops.py

* Update test_functional.py

* Update test_functional.py

* Update cextension.py

* Update cuda_specs.py

* Update cuda_specs.py

* Update test_functional.py

* Update test_linear4bit.py

* Update test_cuda_setup_evaluator.py

* Update test_functional.py

* Update modules.py

* Update modules.py

* Update ops.py

* Update test_linear4bit.py

* Update ops.py

* Update ops.py

* Update test_linear4bit.py

* Update test_linear4bit.py

* Update python-package.yml

* Update python-package.yml

* Update python-package.yml

* Update python-package.yml

* Create build-rocm.sh

* Update cuda_specs.py

* Fix trailing whitespace

* Remove conflicts.diff

* update for hipblasVersionMajor >=3

* Update test_functional.py

* Update test_linear4bit.py

* Update test_ops.py

* Update main.py

* Update test_functional.py

* Update test_linear4bit.py

* Update test_ops.py

* Update test_linear4bit.py

* Lint

* Lint

* Update helpers.py

* Update test_functional.py

* Update test_linear4bit.py

* Update test_ops.py

* Lint

* Update pythonInterface.cpp

* lint fix

* lint

* Update pythonInterface.cpp

* revert permissions change

* Fix indentation

* Update kernels_hip.cuh

* Update kernels.hip

* Update ops.hip

* Update ops_hip.cuh

* Update kernels_hip.cuh

* Update kernels.hip

* Update kernels.hip

* Update ops.hip

* Update ops_hip.cuh

* Update ops.hip

* Update CMakeLists.txt

* Update functional.py

* Update cextension.py

* Update cextension.py

---------
Co-authored-by: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com>
Co-authored-by: MISHANMAUYRA <mishanmaurya31081@gmail.com>
Co-authored-by: amcamd <andrew.chapman@amd.com>
Co-authored-by: Prasanth Nunna <root@banff-cyxtera-s78-1.amd.com>
parent a1cd3f6e
#!/bin/bash
declare build_arch
declare build_os
declare rocm_version

set -xeuo pipefail
bnb_rocm_arch="gfx90a;gfx942;gfx1100"
if [ "${build_os:0:6}" == ubuntu ]; then
    image=rocm/dev-ubuntu-22.04:${rocm_version}-complete
    echo "Using image $image"
    docker run --rm --platform "linux/$build_arch" -i \
        -w /src -v "$PWD:/src" "$image" sh -c \
        "apt-get update \
        && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends cmake \
        && cmake -DCOMPUTE_BACKEND=hip -DBNB_ROCM_ARCH=\"${bnb_rocm_arch}\" . \
        && cmake --build ."
fi

output_dir="output/${build_os}/${build_arch}"
mkdir -p "${output_dir}"
(shopt -s nullglob && cp bitsandbytes/*.{so,dylib,dll} "${output_dir}")
...@@ -102,10 +102,55 @@ jobs:
          path: output/*
          retention-days: 7

  build-shared-libs-rocm:
    strategy:
      matrix:
        os: [ubuntu-22.04]
        arch: [x86_64]
        rocm_version:
          ["6.1.2", "6.2.4", "6.3.2"]
    runs-on: ${{ matrix.os }}
    steps:
      - uses: actions/checkout@v4
      - name: Set up Docker multiarch
        uses: docker/setup-qemu-action@v3
      - name: Clean up disk space
        run: |
          sudo rm -rf \
            /usr/share/dotnet \
            /opt/ghc \
            "/usr/local/share/boost" \
            "$AGENT_TOOLSDIRECTORY" \
            /opt/hostedtoolcache \
            /opt/google/chrome \
            /opt/microsoft/msedge \
            /opt/microsoft/powershell \
            /opt/pipx \
            /usr/lib/mono \
            /usr/local/julia* \
            /usr/local/lib/android \
            /usr/local/lib/node_modules \
            /usr/local/share/chromium \
            /usr/local/share/powershell \
            /usr/share/swift
      - name: Build C++
        run: bash .github/scripts/build-rocm.sh
        env:
          build_os: ${{ matrix.os }}
          build_arch: ${{ matrix.arch }}
          rocm_version: ${{ matrix.rocm_version }}
      - name: Upload build artifact
        uses: actions/upload-artifact@v4
        with:
          name: shared_library_rocm_${{ matrix.os }}_${{ matrix.arch }}_${{ matrix.rocm_version }}
          path: output/*
          retention-days: 7

  build-wheels:
    needs:
      - build-shared-libs
      - build-shared-libs-cuda
      - build-shared-libs-rocm
    strategy:
      matrix:
        os: [ubuntu-22.04, ubuntu-22.04-arm, windows-latest, macos-latest]
...@@ -173,6 +218,7 @@ jobs:
          merge-multiple: true
      - name: Inspect tmp directory after downloading artifacts
        run: |
          ls -alFR tmp/
          WHEEL_COUNT=$(find tmp/ -type f -name "*.whl" | wc -l)
...@@ -210,6 +256,7 @@ jobs:
      - uses: actions/checkout@v4
        with:
          path: repo
      - name: Delete old pre-release (if exists)
        run: |
          cd repo && gh release delete continuous-release_main --cleanup-tag -y
...
...@@ -25,13 +25,14 @@ endif()
# Define included source files
set(CPP_FILES csrc/common.cpp csrc/cpu_ops.cpp csrc/pythonInterface.cpp)
set(CUDA_FILES csrc/ops.cu csrc/kernels.cu)
set(HIP_FILES csrc/ops.hip csrc/kernels.hip)
set(MPS_FILES csrc/mps_ops.mm)
set(METAL_FILES csrc/mps_kernels.metal)
# C++ sources are always included
list(APPEND SRC_FILES ${CPP_FILES})

set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, hip, mps)")
set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda hip mps)
option(PTXAS_VERBOSE "Pass through -v flag to PTX Assembler" OFF)

if(APPLE)
...@@ -47,15 +48,25 @@ if(${COMPUTE_BACKEND} STREQUAL "cuda")
        message(FATAL_ERROR "CUDA is not supported on macOS" )
    endif()
    set(BUILD_CUDA ON)
    set(BUILD_HIP OFF)
    set(BUILD_MPS OFF)
elseif(${COMPUTE_BACKEND} STREQUAL "hip")
    if(APPLE)
        message(FATAL_ERROR "HIP is not supported on macOS" )
    endif()
    set(BUILD_CUDA OFF)
    set(BUILD_HIP ON)
    set(BUILD_MPS OFF)
elseif(${COMPUTE_BACKEND} STREQUAL "mps")
    if(NOT APPLE)
        message(FATAL_ERROR "MPS is only supported on macOS" )
    endif()
    set(BUILD_CUDA OFF)
    set(BUILD_HIP OFF)
    set(BUILD_MPS ON)
else()
    set(BUILD_CUDA OFF)
    set(BUILD_HIP OFF)
    set(BUILD_MPS OFF)
endif()
...@@ -160,6 +171,33 @@ if(BUILD_CUDA)
    string(APPEND BNB_OUTPUT_NAME "_cuda${CUDA_VERSION_SHORT}")
    add_compile_definitions(BUILD_CUDA)
elseif(BUILD_HIP)
    enable_language(HIP)
    message(STATUS "HIP Compiler: ${CMAKE_HIP_COMPILER}")
    if(DEFINED BNB_ROCM_ARCH)
        set(CMAKE_HIP_ARCHITECTURES ${BNB_ROCM_ARCH})
    else()
        if (NOT AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES)
            set(CMAKE_HIP_ARCHITECTURES "gfx90a;gfx942;gfx1100")
        elseif (AMDGPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES)
            set(CMAKE_HIP_ARCHITECTURES ${AMDGPU_TARGETS})
        endif()
    endif()
    message(STATUS "HIP Targets: ${CMAKE_HIP_ARCHITECTURES}")

    list(APPEND SRC_FILES ${HIP_FILES})

    string(APPEND BNB_OUTPUT_NAME "_rocm")

    # get hip version
    execute_process(COMMAND hipconfig --version OUTPUT_VARIABLE HIP_CONFIG_VERSION)
    string(REGEX MATCH "[0-9]+\\.[0-9]+" HIP_VERSION "${HIP_CONFIG_VERSION}")
    string(REPLACE "." "" HIP_VERSION_SHORT "${HIP_VERSION}")
    string(APPEND BNB_OUTPUT_NAME "${HIP_VERSION_SHORT}")

    add_compile_definitions(__HIP_PLATFORM_AMD__)
    add_compile_definitions(__HIP_PLATFORM_HCC__)
    add_compile_definitions(BUILD_HIP)
elseif(BUILD_MPS)
    if(NOT APPLE)
        message(FATAL_ERROR "MPS is only supported on macOS" )
...@@ -208,6 +246,41 @@ if(BUILD_CUDA)
        CUDA_SEPARABLE_COMPILATION ON
    )
endif()
if(BUILD_HIP)
    if(NOT DEFINED ENV{ROCM_PATH})
        set(ROCM_PATH /opt/rocm)
    else()
        set(ROCM_PATH $ENV{ROCM_PATH})
    endif()
    list(APPEND CMAKE_PREFIX_PATH ${ROCM_PATH})
    macro(find_package_and_print_version PACKAGE_NAME)
        find_package("${PACKAGE_NAME}" ${ARGN})
        message("${PACKAGE_NAME} VERSION: ${${PACKAGE_NAME}_VERSION}")
    endmacro()
    find_package_and_print_version(hipblas REQUIRED)
    find_package_and_print_version(hiprand REQUIRED)
    find_package_and_print_version(hipsparse REQUIRED)

    ## hacky way of excluding hip::amdhip64 (with it linked many tests unexpectedly fail e.g. adam8bit because of inaccuracies)
    set_target_properties(hip::host PROPERTIES INTERFACE_LINK_LIBRARIES "")
    set_target_properties(hip-lang::host PROPERTIES INTERFACE_LINK_LIBRARIES "")
    set(CMAKE_HIP_IMPLICIT_LINK_LIBRARIES "")

    target_include_directories(bitsandbytes PRIVATE ${CMAKE_SOURCE_DIR} ${CMAKE_SOURCE_DIR}/include ${ROCM_PATH}/include /include)
    target_link_directories(bitsandbytes PRIVATE ${ROCM_PATH}/lib /lib)
    target_link_libraries(bitsandbytes PUBLIC roc::hipblas hip::hiprand roc::hipsparse)

    target_compile_definitions(bitsandbytes PUBLIC BNB_USE_HIP)
    set_source_files_properties(${HIP_FILES} PROPERTIES LANGUAGE HIP)
    set_target_properties(bitsandbytes PROPERTIES LINKER_LANGUAGE CXX)

    if(HIP_VERSION VERSION_LESS "6.1")
        target_compile_definitions(bitsandbytes PUBLIC NO_HIPBLASLT)
    else()
        find_package(hipblaslt)
        target_link_libraries(bitsandbytes PUBLIC roc::hipblaslt)
    endif()
endif()
if(BUILD_MPS)
    add_dependencies(bitsandbytes metallib)
    target_link_libraries(bitsandbytes objc "-framework Foundation" "-framework Metal" "-framework MetalPerformanceShaders" "-framework MetalPerformanceShadersGraph")
...
...@@ -8,7 +8,7 @@ import torch
from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr

from ..._ops import register_kernel
from ...cextension import HIP_ENVIRONMENT, lib


@register_kernel("bitsandbytes::int8_linear_matmul", "cuda")
...@@ -210,7 +210,12 @@ def _get_col_absmax(

@register_kernel("bitsandbytes::quantize_blockwise", "cuda")
def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
    torch._check_is_size(blocksize)
    if HIP_ENVIRONMENT:
        torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
    else:
        torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
    torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}")

    n = A.numel()
...@@ -264,7 +269,11 @@ def _(
def _dequantize_blockwise_impl(
    A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor
) -> None:
    if HIP_ENVIRONMENT:
        torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
    else:
        torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
    torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
    torch._check(
        dtype in [torch.float16, torch.bfloat16, torch.float32],
...@@ -294,7 +303,11 @@ def _dequantize_blockwise_impl(
def _(
    A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
) -> tuple[torch.Tensor, torch.Tensor]:
    if HIP_ENVIRONMENT:
        torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
    else:
        torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
    torch._check(quant_type in ["fp4", "nf4"])
    torch._check(
        A.dtype in [torch.bfloat16, torch.float16, torch.float32],
...@@ -372,7 +385,11 @@ def _dequantize_4bit_impl(
    dtype: torch.dtype,
    out: torch.Tensor,
) -> None:
    if HIP_ENVIRONMENT:
        torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
    else:
        torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
    torch._check(quant_type in ["fp4", "nf4"])
    torch._check(
        dtype in [torch.bfloat16, torch.float16, torch.float32],
...
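Editor's note: a brief illustration of what these guards permit at the Python API level (not part of the diff; assumes a CUDA or ROCm GPU is present). On ROCm builds the smallest supported blockwise blocksize is 128.

# Illustrative only, not part of this PR.
import torch
import bitsandbytes.functional as F

A = torch.randn(4096, dtype=torch.float16, device="cuda")  # "cuda" maps to the HIP device on ROCm builds
q, state = F.quantize_blockwise(A, blocksize=128)           # accepted on both backends
out = F.dequantize_blockwise(q, state)                      # blocksize=64 would fail the torch._check on ROCm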
...@@ -9,7 +9,7 @@ from typing import Optional
import torch

from bitsandbytes.consts import DYNAMIC_LIBRARY_SUFFIX, PACKAGE_DIR
from bitsandbytes.cuda_specs import CUDASpecs, get_cuda_specs, get_cuda_version_tuple, get_rocm_gpu_arch

logger = logging.getLogger(__name__)
...@@ -28,6 +28,11 @@ def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path:
    override_value = os.environ.get("BNB_CUDA_VERSION")
    if override_value:
        library_name = re.sub(r"cuda\d+", f"cuda{override_value}", library_name, count=1)
        if torch.version.hip:
            raise RuntimeError(
                f"BNB_CUDA_VERSION={override_value} detected for ROCm!! \n"
                f"Clear the variable and retry: export BNB_CUDA_VERSION=\n"
            )
        logger.warning(
            f"WARNING: BNB_CUDA_VERSION={override_value} environment variable detected; loading {library_name}.\n"
            "This can be used to load a bitsandbytes version built with a CUDA version that is different from the PyTorch CUDA version.\n"
...@@ -75,10 +80,11 @@ class CudaBNBNativeLibrary(BNBNativeLibrary):
def get_available_cuda_binary_versions() -> list[str]:
    """Get formatted CUDA versions from existing library files using cuda_specs logic"""
    lib_pattern = f"libbitsandbytes_{BNB_BACKEND.lower()}*{DYNAMIC_LIBRARY_SUFFIX}"
    versions = []
    for lib in Path(__file__).parent.glob(lib_pattern):
        pattern = rf"{BNB_BACKEND.lower()}(\d+)"
        match = re.search(pattern, lib.name)
        if match:
            ver_code = int(match.group(1))
            major = ver_code // 10
...@@ -89,8 +95,8 @@ def get_available_cuda_binary_versions() -> list[str]:
def parse_cuda_version(version_str: str) -> str:
    """Convert raw version string (e.g. '118' from env var) to formatted version (e.g. '11.8')"""
    if version_str.isdigit():
        return f"{version_str[:-1]}.{version_str[-1]}"
    return version_str  # fallback as safety net
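Editor's note: a quick sanity check of the relaxed parsing above (illustrative local mirror of parse_cuda_version, not part of the diff). The last digit is taken as the minor version, so three-digit CUDA codes and two-digit ROCm codes are both handled.

# Illustrative mirror of the new parse_cuda_version behavior.
def parse(version_str: str) -> str:
    if version_str.isdigit():
        return f"{version_str[:-1]}.{version_str[-1]}"
    return version_str  # fallback as safety net

assert parse("118") == "11.8"   # CUDA-style three-digit code
assert parse("63") == "6.3"     # ROCm-style two-digit code
assert parse("12.4") == "12.4"  # already formatted, passed through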
...@@ -151,7 +157,7 @@ class ErrorHandlerMockBNBNativeLibrary(BNBNativeLibrary):
        """Format detailed error message for library loading failures"""
        analysis = ""
        no_cpu_lib_found = "libbitsandbytes_cpu.so: cannot open" in original_error
        no_cuda_lib_found = f"{BNB_BACKEND} binary not found" in original_error

        if no_cpu_lib_found:
            analysis = "\n🚨 Failed to load CPU-only bitsandbytes library 🚨\n\n"
...@@ -160,9 +166,9 @@ class ErrorHandlerMockBNBNativeLibrary(BNBNativeLibrary):
            version_list_str = "\n - " + "\n - ".join(available_versions) if available_versions else "NONE"
            analysis = (
                (
                    f"\n🚨 {BNB_BACKEND} VERSION MISMATCH 🚨\n"
                    f"Requested {BNB_BACKEND} version: {requested_version}\n"
                    f"Detected PyTorch {BNB_BACKEND} version: {user_cuda_version}\n"
                    f"Available pre-compiled versions: {version_list_str}\n\n"
                    "This means:\n"
                    "The version you're trying to use is NOT distributed with this package\n\n"
...@@ -177,42 +183,47 @@ class ErrorHandlerMockBNBNativeLibrary(BNBNativeLibrary):
        troubleshooting = (
            (
                f"This typically happens when:\n"
                f"1. bitsandbytes doesn't ship with a pre-compiled binary for your {BNB_BACKEND} version\n"
                f"2. The library wasn't compiled properly during installation from source\n\n"
            )
            if no_cuda_lib_found
            else f"This typically happens when you checked the code out from source and your torch installation doesn't detect {BNB_BACKEND} on your machine.\n\n"
        )

        note = (
            (
                f"To make bitsandbytes work, the compiled library version MUST exactly match the linked {BNB_BACKEND} version.\n"
                f"If your {BNB_BACKEND} version doesn't have a pre-compiled binary, you MUST compile from source.\n\n"
            )
            if no_cuda_lib_found
            else ""
        )

        compile_instructions = (
            ("COMPILE FROM SOURCE for CPU-only:\n `cmake -DCOMPUTE_BACKEND=cpu -S . && make`\n\n")
            if not no_cuda_lib_found
            else (
                "You have two options:\n"
                "1. COMPILE FROM SOURCE (required if no binary exists):\n"
                " https://huggingface.co/docs/bitsandbytes/main/en/installation#cuda-compile\n"
                "2. Use BNB_CUDA_VERSION to specify a DIFFERENT CUDA version from the detected one, which is installed on your machine and matching an available pre-compiled version listed above\n\n"
            )
            if not HIP_ENVIRONMENT
            else (
                "You can COMPILE FROM SOURCE as mentioned here:\n"
                " https://huggingface.co/docs/bitsandbytes/main/en/installation?backend=AMD+ROCm#amd-gpu\n"
            )
        )

        diagnostics = (
            f"🔍 Run this command for detailed diagnostics:\n"
            f"python -m bitsandbytes\n\n"
            f"If you've tried everything and still have issues:\n"
            f"1. Include ALL version info (operating system, bitsandbytes, pytorch, {BNB_BACKEND.lower()}, python)\n"
            f"2. Describe what you've tried in detail\n"
            f"3. Open an issue with this information:\n"
            f" https://github.com/bitsandbytes-foundation/bitsandbytes/issues\n\n"
        )

        return f"{analysis}{base_msg}{troubleshooting}{note}{compile_instructions}{original_error}\n{diagnostics}"
...@@ -227,18 +238,19 @@ class ErrorHandlerMockBNBNativeLibrary(BNBNativeLibrary):
        )

        return (
            f"\n🚨 {BNB_BACKEND} SETUP ERROR: Missing dependency: {missing_lib} 🚨\n\n"
            f"{BNB_BACKEND} {cuda_major_version}.x runtime libraries were not found in the LD_LIBRARY_PATH.\n\n"
            f"To fix this, make sure that:\n"
            f"1. You have installed {BNB_BACKEND} {cuda_major_version}.x toolkit on your system\n"
            f"2. The {BNB_BACKEND} runtime libraries are in your LD_LIBRARY_PATH\n\n"
            f"You can add them with (and persist the change by adding the line to your .bashrc):\n"
            f" export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/path/to/{BNB_BACKEND.lower()}-{cuda_major_version}.x/\
{'lib64' if not HIP_ENVIRONMENT else 'lib'}\n\n"
            f"Original error: {self.error_msg}\n\n"
            f"🔍 Run this command for detailed diagnostics:\n"
            f"python -m bitsandbytes\n\n"
            f"If you've tried everything and still have issues:\n"
            f"1. Include ALL version info (operating system, bitsandbytes, pytorch, {BNB_BACKEND.lower()}, python)\n"
            f"2. Describe what you've tried in detail\n"
            f"3. Open an issue with this information:\n"
            f" https://github.com/bitsandbytes-foundation/bitsandbytes/issues\n\n"
...@@ -267,7 +279,7 @@ def get_native_library() -> BNBNativeLibrary:
        cuda_binary_path = get_cuda_bnb_library_path(cuda_specs)

        if not cuda_binary_path.exists():
            raise RuntimeError(f"Configured {BNB_BACKEND} binary not found at {cuda_binary_path}")

        binary_path = cuda_binary_path
...@@ -286,6 +298,8 @@ def get_native_library() -> BNBNativeLibrary:
    return BNBNativeLibrary(dll)


ROCM_GPU_ARCH = get_rocm_gpu_arch()

try:
    # to support Intel CPU/GPU (XPU) backend
    import intel_extension_for_pytorch as ipex
...@@ -296,8 +310,12 @@ except BaseException:
    ipex_cpu = None
    ipex_xpu = None

try:
    if torch.version.hip:
        HIP_ENVIRONMENT, BNB_BACKEND = True, "ROCm"
    else:
        HIP_ENVIRONMENT, BNB_BACKEND = False, "CUDA"

    lib = get_native_library()
except Exception as e:
    error_msg = str(e)
...
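Editor's note: a rough sketch of how the backend-aware discovery above resolves versions from library file names (the file names below are illustrative examples built from the CMake naming scheme earlier in this diff, not an exhaustive list).

# Illustrative only, not part of this PR.
import re

BNB_BACKEND = "ROCm"   # set to "CUDA" when torch.version.hip is None
pattern = rf"{BNB_BACKEND.lower()}(\d+)"

for name in ("libbitsandbytes_rocm61.so", "libbitsandbytes_rocm63.so"):
    ver_code = int(re.search(pattern, name).group(1))
    print(f"{name} -> {ver_code // 10}.{ver_code % 10}")   # 6.1, 6.3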
import dataclasses
from functools import lru_cache
import logging
import re
import subprocess
from typing import Optional

import torch
...@@ -73,3 +76,27 @@ def get_cuda_specs() -> Optional[CUDASpecs]:
        )
    except Exception:
        return None


def get_rocm_gpu_arch() -> str:
    """Get ROCm GPU architecture."""
    logger = logging.getLogger(__name__)
    try:
        if torch.version.hip:
            result = subprocess.run(["rocminfo"], capture_output=True, text=True)
            match = re.search(r"Name:\s+gfx([a-zA-Z\d]+)", result.stdout)
            if match:
                return "gfx" + match.group(1)
            else:
                return "unknown"
        else:
            return "unknown"
    except Exception as e:
        logger.error(f"Could not detect ROCm GPU architecture: {e}")
        if torch.cuda.is_available():
            logger.warning(
                """
                ROCm GPU architecture detection failed despite ROCm being available.
                """,
            )
        return "unknown"
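Editor's note: a minimal check of the rocminfo parsing above, run against a made-up excerpt of rocminfo output (the real command prints one such block per agent).

# Illustrative only, not part of this PR; sample_stdout is fabricated for the example.
import re

sample_stdout = """\
  Name:                    gfx90a
  Marketing Name:          AMD Instinct MI210
"""
match = re.search(r"Name:\s+gfx([a-zA-Z\d]+)", sample_stdout)
print("gfx" + match.group(1) if match else "unknown")   # -> gfx90a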
...@@ -5,7 +5,7 @@ from pathlib import Path
import torch

from bitsandbytes.cextension import HIP_ENVIRONMENT, get_cuda_bnb_library_path
from bitsandbytes.cuda_specs import CUDASpecs
from bitsandbytes.diagnostics.utils import print_dedented
...@@ -32,9 +32,13 @@ CUDART_PATH_IGNORED_ENVVARS = {
}

CUDA_RUNTIME_LIB_PATTERNS = (
    ("libamdhip64.so*",)
    if HIP_ENVIRONMENT
    else (
        "cudart64*.dll",  # Windows
        "libcudart*.so*",  # libcudart.so, libcudart.so.11.0, libcudart.so.12.0, libcudart.so.12.1, libcudart.so.12.2 etc.
        "nvcuda*.dll",  # Windows
    )
)

logger = logging.getLogger(__name__)
...@@ -56,7 +60,7 @@ def find_cuda_libraries_in_path_list(paths_list_candidate: str) -> Iterable[Path
                pass
            for lib_pattern in CUDA_RUNTIME_LIB_PATTERNS:
                for pth in dir.glob(lib_pattern):
                    if pth.is_file() and not pth.is_symlink():
                        yield pth
        except (OSError, PermissionError):
            pass
...@@ -103,7 +107,7 @@ def find_cudart_libraries() -> Iterator[Path]:
        yield from find_cuda_libraries_in_path_list(value)


def _print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None:
    print(
        f"PyTorch settings found: CUDA_VERSION={cuda_specs.cuda_version_string}, "
        f"Highest Compute Capability: {cuda_specs.highest_compute_capability}.",
...@@ -128,7 +132,37 @@ def print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None:
        )


def _print_hip_diagnostics(cuda_specs: CUDASpecs) -> None:
    print(f"PyTorch settings found: ROCM_VERSION={cuda_specs.cuda_version_string}")

    binary_path = get_cuda_bnb_library_path(cuda_specs)
    if not binary_path.exists():
        print_dedented(
            f"""
            Library not found: {binary_path}.
            Maybe you need to compile it from source? If you compiled from source, check that ROCm version
            in PyTorch Settings matches your ROCm install. If not, reinstall PyTorch for your ROCm version
            and rebuild bitsandbytes.
            """,
        )

    hip_major, hip_minor = cuda_specs.cuda_version_tuple
    if (hip_major, hip_minor) < (6, 1):
        print_dedented(
            """
            WARNING: bitsandbytes is fully supported only from ROCm 6.1.
            """,
        )


def print_diagnostics(cuda_specs: CUDASpecs) -> None:
    if HIP_ENVIRONMENT:
        _print_hip_diagnostics(cuda_specs)
    else:
        _print_cuda_diagnostics(cuda_specs)


def _print_cuda_runtime_diagnostics() -> None:
    cudart_paths = list(find_cudart_libraries())
    if not cudart_paths:
        print("CUDA SETUP: WARNING! CUDA runtime files not found in any environmental path.")
...@@ -153,3 +187,33 @@ def print_cuda_runtime_diagnostics() -> None:
        )
        for pth in cudart_paths:
            print(f"* Found CUDA runtime at: {pth}")


def _print_hip_runtime_diagnostics() -> None:
    cudart_paths = list(find_cudart_libraries())
    if not cudart_paths:
        print("WARNING! ROCm runtime files not found in any environmental path.")
    elif len(cudart_paths) > 1:
        print_dedented(
            f"""
            Found duplicate ROCm runtime files (see below).
            We select the PyTorch default ROCm runtime, which is {torch.version.hip},
            but this might mismatch with the ROCm version that is needed for bitsandbytes.
            To resolve it, install PyTorch built for the ROCm version you want to use
            and set LD_LIBRARY_PATH to your ROCm install path, e.g.
            export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm-6.1.2/lib,
            """,
        )

        for pth in cudart_paths:
            print(f"* Found ROCm runtime at: {pth}")


def print_runtime_diagnostics() -> None:
    if HIP_ENVIRONMENT:
        _print_hip_runtime_diagnostics()
    else:
        _print_cuda_runtime_diagnostics()
...@@ -6,10 +6,11 @@ import traceback
import torch

from bitsandbytes import __version__ as bnb_version
from bitsandbytes.cextension import BNB_BACKEND
from bitsandbytes.consts import PACKAGE_GITHUB_URL
from bitsandbytes.cuda_specs import get_cuda_specs
from bitsandbytes.diagnostics.cuda import (
    print_diagnostics,
)
from bitsandbytes.diagnostics.utils import print_dedented, print_header
...@@ -77,19 +78,19 @@ def main():
    cuda_specs = get_cuda_specs()
    if cuda_specs:
        print_diagnostics(cuda_specs)
    # TODO: There's a lot of noise in this; needs improvement.
    # print_cuda_runtime_diagnostics()

    if not torch.cuda.is_available():
        print(f"PyTorch says {BNB_BACKEND} is not available. Possible reasons:")
        print(f"1. {BNB_BACKEND} driver not installed")
        print("2. Using a CPU-only PyTorch build")
        print("3. No GPU detected")
    else:
        print(f"Checking that the library is importable and {BNB_BACKEND} is callable...")

        try:
            sanity_check()
...
...@@ -15,7 +15,7 @@ from typing_extensions import deprecated
from bitsandbytes.utils import _reverse_4bit_compress_format, pack_dict_to_tensor, unpack_tensor_to_dict

from .cextension import HIP_ENVIRONMENT, ipex_cpu, ipex_xpu, lib

name2qmap = {}
...@@ -868,10 +868,12 @@ def quantize_fp4(
    A: torch.Tensor,
    absmax: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    blocksize=None,
    compress_statistics=False,
    quant_storage=torch.uint8,
):
    if blocksize is None:
        blocksize = 64 if not HIP_ENVIRONMENT else 128
    return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "fp4", quant_storage)
...@@ -879,10 +881,12 @@ def quantize_nf4(
    A: torch.Tensor,
    absmax: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    blocksize=None,
    compress_statistics=False,
    quant_storage=torch.uint8,
):
    if blocksize is None:
        blocksize = 64 if not HIP_ENVIRONMENT else 128
    return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "nf4", quant_storage)
...@@ -890,7 +894,7 @@ def quantize_4bit(
    A: torch.Tensor,
    absmax: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    blocksize=None,
    compress_statistics=False,
    quant_type="fp4",
    quant_storage=torch.uint8,
...@@ -904,7 +908,7 @@ def quantize_4bit(
        absmax (`torch.Tensor`, *optional*): A tensor to use to store the absmax values.
        out (`torch.Tensor`, *optional*): A tensor to use to store the result.
        blocksize (`int`, *optional*):
            The size of the blocks. Defaults to 128 on ROCm and 64 otherwise.
            Valid values are 64, 128, 256, 512, 1024, 2048, and 4096.
        compress_statistics (`bool`, *optional*): Whether to additionally quantize the absmax values. Defaults to False.
        quant_type (`str`, *optional*): The data type to use: `nf4` or `fp4`. Defaults to `fp4`.
...@@ -918,6 +922,10 @@ def quantize_4bit(
        - `torch.Tensor`: The quantized tensor with packed 4-bit values.
        - [`QuantState`]: The state object used to undo the quantization.
    """
    if blocksize is None:
        blocksize = 64 if not HIP_ENVIRONMENT else 128

    input_shape = A.shape

    _out, _absmax = torch.ops.bitsandbytes.quantize_4bit.default(
...@@ -968,8 +976,10 @@ def dequantize_fp4(
    quant_state: Optional[QuantState] = None,
    absmax: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    blocksize: Optional[int] = None,
) -> torch.Tensor:
    if blocksize is None:
        blocksize = 64 if not HIP_ENVIRONMENT else 128
    return dequantize_4bit(A, quant_state, absmax, out, blocksize, "fp4")
...@@ -978,8 +988,10 @@ def dequantize_nf4(
    quant_state: Optional[QuantState] = None,
    absmax: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    blocksize: Optional[int] = None,
) -> torch.Tensor:
    if blocksize is None:
        blocksize = 64 if not HIP_ENVIRONMENT else 128
    return dequantize_4bit(A, quant_state, absmax, out, blocksize, "nf4")
...@@ -988,7 +1000,7 @@ def dequantize_4bit(
    quant_state: Optional[QuantState] = None,
    absmax: Optional[torch.Tensor] = None,
    out: Optional[torch.Tensor] = None,
    blocksize: Optional[int] = None,
    quant_type="fp4",
) -> torch.Tensor:
    """Dequantizes a packed 4-bit quantized tensor.
...@@ -1007,7 +1019,7 @@ def dequantize_4bit(
            Required if `quant_state` is not provided and ignored otherwise.
        out (`torch.Tensor`, *optional*): A tensor to use to store the result.
        blocksize (`int`, *optional*):
            The size of the blocks. Defaults to 128 on ROCm and 64 otherwise.
            Valid values are 64, 128, 256, 512, 1024, 2048, and 4096.
        quant_type (`str`, *optional*): The data type to use: `nf4` or `fp4`. Defaults to `fp4`.
...@@ -1017,6 +1029,10 @@ def dequantize_4bit(
    Returns:
        `torch.Tensor`: The dequantized tensor.
    """
    if blocksize is None:
        blocksize = 64 if not HIP_ENVIRONMENT else 128

    if quant_state is None:
        assert absmax is not None and out is not None
...
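Editor's note: taken together, the ROCm-aware defaults behave as sketched below (illustration only, not part of the diff; assumes a CUDA or ROCm GPU). Leaving blocksize unset now resolves to 128 under HIP and 64 elsewhere.

# Illustrative only, not part of this PR.
import torch
import bitsandbytes.functional as F

A = torch.randn(1024, 1024, dtype=torch.float16, device="cuda")
q, state = F.quantize_nf4(A)            # blocksize defaults to 128 on ROCm, 64 on CUDA
restored = F.dequantize_nf4(q, state)   # the blocksize is carried in the quant state
print(state.blocksize)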
...@@ -11,6 +11,7 @@ from torch import Tensor, device, dtype, nn
import torch.nn.functional as F

import bitsandbytes as bnb
from bitsandbytes.cextension import HIP_ENVIRONMENT
from bitsandbytes.functional import QuantState, _enable_ipex_fusion, ipex_cpu, ipex_xpu
from bitsandbytes.optim import GlobalOptimManager
from bitsandbytes.utils import (
...@@ -213,7 +214,7 @@ class Params4bit(torch.nn.Parameter):
        data: Optional[torch.Tensor] = None,
        requires_grad=False,  # quantized weights should be frozen by default
        quant_state: Optional[QuantState] = None,
        blocksize: Optional[int] = None,
        compress_statistics: bool = True,
        quant_type: str = "fp4",
        quant_storage: torch.dtype = torch.uint8,
...@@ -223,6 +224,9 @@ class Params4bit(torch.nn.Parameter):
        if data is None:
            data = torch.empty(0)

        if blocksize is None:
            blocksize = 64 if not HIP_ENVIRONMENT else 128

        self = torch.Tensor._make_subclass(cls, data, requires_grad)
        self.blocksize = blocksize
        self.compress_statistics = compress_statistics
...
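Editor's note: the same default surfaces at the parameter level; a small sketch (illustration only, assumes an importable bitsandbytes build with either backend).

# Illustrative only, not part of this PR.
import torch
from bitsandbytes.nn import Params4bit

w = Params4bit(torch.randn(256, 256, dtype=torch.float16), quant_type="nf4")  # blocksize not passed
print(w.blocksize)   # 128 under HIP_ENVIRONMENT, 64 otherwise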
#pragma once
#define BNB_WARP_SIZE warpSize
// These are set based on current BNB support for CDNA 2 & RDNA 3. Update as needed for future archs
#define BNB_MAX_THREADS_PER_SM 2048
#define BNB_BF16_AVAILABLE true
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
#include "kernels_hip.cuh"
#include "common_hip.cuh"
#include <hip/hip_runtime.h>
#include <hipcub/hipcub.hpp>
#include <hip/hip_math_constants.h>
//#include <mma.h>
#define HLF_MAX 65504
#define TH 1024
#define NUM 4
#define NUM_BLOCK 4096
__device__ static float nf4_data[16] = {-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0};
// source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda
// Luckily we have atomicmax and atomicmin in ROCm
__device__ float dDequantizeFP4Tree(unsigned char val, float absmax)
{
float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f;
if((val & 0b0100) == 4) // 0
if((val & 0b0010) == 2) //01
if((val & 0b0001) == 1) // 111
return 0.25000000f*absmax*sign; // 1111
else
return 0.16666667f*absmax*sign; // 1110
else
if((val & 0b0001) == 1) // 110
return 0.50000000f*absmax*sign; // 1101
else
return 0.33333333f*absmax*sign; // 1100
else
if((val & 0b0010) == 2) //10
if((val & 0b0001) == 1) // 101
return 1.00000000f*absmax*sign; // 1011
else
return 0.66666667f*absmax*sign; // 1010
else
if((val & 0b0001) == 1) // 100
return 5.208333333e-03f*absmax*sign; // 1001
else
return 0.00000000f*absmax*sign; // 1000
}
__device__ unsigned char dQuantizeFP4(float x)
{
// FP4 with bias of 3
// first bit is a sign
// subnormals
// 0b000 = 0
// 0b001 = 0.0625
// 0b110 = 2
// 0b111 = 3
// 0b100 = 4
// 0b101 = 6
// 0b010 = 8
// 0b011 = 12
// we do a binary search
// the pivots are divided by 12 (the FP4 absmax)
// since we assume input data is in [-1.0, 1.0]
// !be careful here, its easy to make a mistake
// that is difficult to notice if you add an extra
// zero somewhere!
int sign = x < 0 ? 0b1000 : 0b0000;
x = fabsf(x);
if(x > 0.29166667f)
if( x > 0.583333f)
if( x > 0.8333333f)
return 0b0011+sign;
else
return 0b0010+sign;
else
if(x > 0.4166667f)
return 0b101+sign;
else
return 0b100+sign;
else
if(x > 0.0859375f)
if(x > 0.20833333f)
return 0b0111+sign;
else
return 0b0110+sign;
else
if(x > 0.00260417f)
return 0b0001+sign;
else
return 0b0000+sign;
}
__device__ __forceinline__ float dDequantizeNF4(unsigned char val)
{
// the values for this tree was generated by test_normal_map_tree
// in the file tests/test_functional.py
if((val & 0b1000) == 8)
if((val & 0b0100) == 4) // 1
if((val & 0b0010) == 2) // 11
if((val & 0b0001) == 1) // 111
return 1.0f;
else
return 0.7229568362236023f;
else
if((val & 0b0001) == 1) // 110
return 0.5626170039176941f;
else
return 0.44070982933044434f;
else
if((val & 0b0010) == 2) //10
if((val & 0b0001) == 1) // 101
return 0.33791524171829224f;
else
return 0.24611230194568634f;
else
if((val & 0b0001) == 1) // 100
return 0.16093020141124725f;
else
return 0.07958029955625534f;
else
if((val & 0b0100) == 4) // 0
if((val & 0b0010) == 2) //01
if((val & 0b0001) == 1) // 011
return 0.0f;
else
return -0.09105003625154495f;
else
if((val & 0b0001) == 1) // 010
return -0.18477343022823334f;
else
return -0.28444138169288635f;
else
if((val & 0b0010) == 2) //00
if((val & 0b0001) == 1) // 001
return -0.39491748809814453f;
else
return -0.5250730514526367f;
else
if((val & 0b0001) == 1) // 000
return -0.6961928009986877f;
else
return -1.0f;
}
__device__ unsigned char dQuantizeNF4(float x)
{
// the values for this tree was generated by test_normal_map_tree
// in the file tests/test_functional.py
if(x > 0.03979014977812767f)
if(x > 0.3893125355243683f) // 1
if(x > 0.6427869200706482f) // 11
if(x > 0.8614784181118011f) // 111
return 0b1111;
else
return 0b1110;
else
if(x > 0.5016634166240692f) // 110
return 0b1101;
else
return 0b1100;
else
if(x > 0.2035212516784668f) // 10
if(x > 0.2920137718319893f) // 101
return 0b1011;
else
return 0b1010;
else
if(x > 0.1202552504837513f) // 100
return 0b1001;
else
return 0b1000;
else
if(x > -0.33967943489551544f) // 0
if(x > -0.13791173323988914f) // 01
if(x > -0.045525018125772476f) // 011
return 0b0111;
else
return 0b0110;
else
if(x > -0.23460740596055984f) // 010
return 0b0101;
else
return 0b0100;
else
if(x > -0.6106329262256622f) // 00
if(x > -0.4599952697753906f) // 001
return 0b0011;
else
return 0b0010;
else
if(x > -0.8480964004993439f) // 000
return 0b0001;
else
return 0b0000;
}
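// Editor's note: the Python snippet below (illustration only, not part of the diff) mirrors the
// NF4 codebook from the nf4_data table at the top of this file; the hard-coded decision tree in
// dQuantizeNF4 is equivalent to a nearest-entry lookup over these 16 values.

# Illustrative only, not part of this PR.
NF4_CODE = [
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634,
    0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
    0.7229568362236023, 1.0,
]

def quantize_nf4_scalar(x: float) -> int:
    # nearest codebook entry; the kernel hard-codes this search as an if/else tree
    return min(range(16), key=lambda i: abs(NF4_CODE[i] - x))

assert quantize_nf4_scalar(-1.0) == 0b0000
assert quantize_nf4_scalar(0.0) == 0b0111
assert quantize_nf4_scalar(0.5) == 0b1100   # dequantizes to 0.44070982933044434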
// sign function for lion
// taken from https://stackoverflow.com/a/4609795, but not sure if there's a proper way to do this in CUDA
template <typename T> __device__ int sgn(T val)
{
return (T(0) < val) - (val < T(0));
}
template <int STOCHASTIC>
__device__ unsigned char dQuantize(float* smem_code, const float rand, float x)
{
int pivot = 127;
int upper_pivot = 255;
int lower_pivot = 0;
float lower = -1.0f;
float upper = 1.0f;
float val = smem_code[pivot];
// i>>=1 = {32, 16, 8, 4, 2, 1}
for(int i = 64; i > 0; i>>=1)
{
if(x > val)
{
lower_pivot = pivot;
lower = val;
pivot+=i;
}
else
{
upper_pivot = pivot;
upper = val;
pivot-=i;
}
val = smem_code[pivot];
}
if(upper_pivot == 255)
upper = smem_code[upper_pivot];
if(lower_pivot == 0)
lower = smem_code[lower_pivot];
if(!STOCHASTIC)
{
if(x > val)
{
float midpoint = (upper+val)*0.5f;
if(x > midpoint)
{
return upper_pivot;
}
else
return pivot;
}
else
{
float midpoint = (lower+val)*0.5f;
if(x < midpoint)
return lower_pivot;
else
return pivot;
}
}
else
{
if(x > val)
{
float dist_to_upper = fabsf(upper-x);
float dist_full = upper-val;
if(rand >= dist_to_upper/dist_full) return upper_pivot;
else return pivot;
}
else
{
float dist_to_lower = fabsf(lower-x);
float dist_full = val-lower;
if(rand >= dist_to_lower/dist_full) return lower_pivot;
else return pivot;
}
}
}
template <int SIGNED>
__device__ __forceinline__ unsigned char quantize_2D(float *__restrict__ quadrants, float *__restrict__ const smem_code, float x)
{
int pivot = 127;
int upper_pivot = 255;
int lower_pivot = 0;
float lower = SIGNED ? -1.0f : 0.0f;
float upper = 1.0f;
float midpoint;
float val = quadrants[1];
int local_pivot = 1;
int offset = 1;
// i>>=1 = {32, 16, 8, 4, 2, 1}
for(int i = 64; i > 0; i>>=1)
{
if(x > val)
{
lower_pivot = pivot;
lower = val;
pivot+=i;
//val = i == 64 ? quadrants[2] : smem_code[pivot];
local_pivot += offset;
}
else
{
upper_pivot = pivot;
upper = val;
pivot-=i;
//val = i == 64 ? quadrants[0] : smem_code[pivot];
local_pivot -= offset;
}
val = i >= 64 ? quadrants[local_pivot] : smem_code[pivot];
offset -= 1;
}
if(x > val)
{
midpoint = (upper+val)*0.5f;
if(x > midpoint)
return upper_pivot;
else
return pivot;
}
else
{
midpoint = (lower+val)*0.5f;
if(x < midpoint)
return lower_pivot;
else
return pivot;
}
}
__launch_bounds__(TH, 4)
__global__ void kQuantize(float * code, float * __restrict__ const A, unsigned char *out, const int n)
{
const int n_full = (NUM_BLOCK*(n/NUM_BLOCK)) + (n % NUM_BLOCK == 0 ? 0 : NUM_BLOCK);
int valid_items = (blockIdx.x+1 == gridDim.x) ? n - (blockIdx.x*NUM_BLOCK) : NUM_BLOCK;
const int base_idx = (blockIdx.x * NUM_BLOCK);
float vals[NUM];
unsigned char qvals[NUM];
//const int lane_id = threadIdx.x % 2;
typedef hipcub::BlockLoad<float, TH, NUM, hipcub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
typedef hipcub::BlockStore<unsigned char, TH, NUM, hipcub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
__shared__ typename LoadFloat::TempStorage loadf;
__shared__ typename StoreChar::TempStorage storec;
__shared__ float smem_code[256];
//__shared__ float smem_code[2][257];
if(threadIdx.x < 256)
{
smem_code[threadIdx.x] = code[threadIdx.x];
//smem_code[0][threadIdx.x] = code[threadIdx.x];
//smem_code[1][threadIdx.x] = smem_code[0][threadIdx.x];
}
for (unsigned int i = base_idx; i < n_full; i += gridDim.x*NUM_BLOCK)
{
// number of values already processed in blocks +
// number of values already processed in this block +
// rand_offset % mod value
valid_items = n - i > NUM_BLOCK ? NUM_BLOCK : n - i;
__syncthreads();
LoadFloat(loadf).Load(&(A[i]), vals, valid_items);
#pragma unroll 4
for(int j = 0; j < NUM; j++)
qvals[j] = dQuantize<0>(smem_code, 0.0f, vals[j]);
__syncthreads();
StoreChar(storec).Store(&(out[i]), qvals, valid_items);
}
}
template<typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC, int DATA_TYPE>
//__launch_bounds__(TH, 4)
__global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n)
{
const int n_full = gridDim.x * BLOCK_SIZE;
int valid_items = 0;
const int base_idx = (blockIdx.x * BLOCK_SIZE);
T vals[NUM_PER_TH];
float rand_vals[NUM_PER_TH];
unsigned char qvals[(DATA_TYPE > 0) ? NUM_PER_TH/2 : NUM_PER_TH];
//float local_abs_max = -FLT_MAX;
float local_abs_max = 0.0f;
int local_rand_idx = 0;
typedef hipcub::BlockLoad<T, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH, hipcub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
typedef hipcub::BlockStore<unsigned char, BLOCK_SIZE/NUM_PER_TH, (DATA_TYPE > 0) ? NUM_PER_TH/2 : NUM_PER_TH, hipcub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
typedef hipcub::BlockReduce<float, BLOCK_SIZE/NUM_PER_TH> BlockReduce;
typedef hipcub::BlockLoad<float, BLOCK_SIZE/NUM_PER_TH, NUM_PER_TH, hipcub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
__shared__ typename LoadT::TempStorage loadt;
__shared__ typename LoadFloat::TempStorage loadf;
__shared__ typename StoreChar::TempStorage storec;
__shared__ typename BlockReduce::TempStorage reduce;
__shared__ float smem_code[256];
__shared__ float smem_absmax_value[1];
if(DATA_TYPE == General8bit)
for(int i = threadIdx.x; i < 256; i+=blockDim.x)
smem_code[i] = code[i];
for (int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE)
{
valid_items = n - i > BLOCK_SIZE ? BLOCK_SIZE : n - i;
local_abs_max = -FLT_MAX;
__syncthreads();
LoadT(loadt).Load(&(A[i]), vals, valid_items, (T)0.0f);
// 1. compute local max
// 2. broadcast local max
// 3. normalize inputs and quantize
#pragma unroll NUM_PER_TH
for(int j = 0; j < NUM_PER_TH; j++)
local_abs_max = fmaxf(local_abs_max, fabsf((float)vals[j]));
local_abs_max = BlockReduce(reduce).Reduce(local_abs_max, hipcub::Max(), valid_items);
if(threadIdx.x == 0) {
smem_absmax_value[0] = 1.0f / local_abs_max;
absmax[i / BLOCK_SIZE] = local_abs_max;
}
__syncthreads();
local_abs_max = smem_absmax_value[0];
if(STOCHASTIC)
{
local_rand_idx = ((blockIdx.x*NUM_BLOCK) + (threadIdx.x*NUM) + rand_offset) % (1024-4);
LoadFloat(loadf).Load(&rand[local_rand_idx], rand_vals, BLOCK_SIZE, 0);
}
unsigned char packed_4bit = 0;
switch(DATA_TYPE)
{
case General8bit:
#pragma unroll NUM_PER_TH
for(int j = 0; j < NUM_PER_TH; j++)
{
if(!STOCHASTIC)
qvals[j] = dQuantize<0>(smem_code, 0.0f, ((float)vals[j])*local_abs_max);
else
qvals[j] = dQuantize<1>(smem_code, rand_vals[j], ((float)vals[j])*local_abs_max);
}
break;
case FP4:
#pragma unroll NUM_PER_TH
for(int j = 0; j < NUM_PER_TH/2; j++)
{
packed_4bit |= dQuantizeFP4(((float)vals[2*j])*local_abs_max) << 4;
packed_4bit |= dQuantizeFP4(((float)vals[2*j+1])*local_abs_max);
qvals[j] = packed_4bit;
}
break;
case NF4:
#pragma unroll NUM_PER_TH
for(int j = 0; j < NUM_PER_TH/2; j++)
{
packed_4bit |= dQuantizeNF4(((float)vals[2*j])*local_abs_max) << 4;
packed_4bit |= dQuantizeNF4(((float)vals[2*j+1])*local_abs_max);
qvals[j] = packed_4bit;
}
break;
}
__syncthreads();
StoreChar(storec).Store(&(out[(DATA_TYPE > 0) ? i/2 : i]), qvals, (DATA_TYPE > 0) ? (valid_items+1)/2 : valid_items);
}
}
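// Editor's note: as a reading aid for kQuantizeBlockwise, a tensor-level Python sketch of the
// General8bit path follows (illustration only, not part of the diff); the kernel additionally
// handles stochastic rounding, FP4/NF4 packing, and partial blocks.

# Illustrative only, not part of this PR. Assumes A.numel() is a multiple of blocksize
# and that no block is all zeros.
import torch

def quantize_blockwise_ref(A: torch.Tensor, code: torch.Tensor, blocksize: int):
    blocks = A.float().flatten().reshape(-1, blocksize)
    absmax = blocks.abs().amax(dim=1)                      # one scale per block
    normalized = blocks / absmax.unsqueeze(1)              # maps each block into [-1, 1]
    # nearest code entry, matching the kernel's binary search with midpoint comparison
    idx = (normalized.unsqueeze(-1) - code.view(1, 1, -1)).abs().argmin(dim=-1)
    return idx.to(torch.uint8).reshape(A.shape), absmax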
template<typename T, int TILE_SIZE, int THREADS, int NUM_PER_TH, int DATA_TYPE>
__global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * absmax, T *out, const int blocksize, const int n)
{
const int n_load = (gridDim.x * TILE_SIZE);
int valid_items_load = 0;
int valid_items_store = 0;
const int base_idx = (blockIdx.x * TILE_SIZE);
T vals[NUM_PER_TH*((DATA_TYPE > 0) ? 2 : 1)];
unsigned char qvals[NUM_PER_TH];
float local_abs_max = -FLT_MAX;
typedef hipcub::BlockLoad<unsigned char, THREADS, NUM_PER_TH, hipcub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
typedef hipcub::BlockStore<T, THREADS, NUM_PER_TH*((DATA_TYPE > 0) ? 2 : 1), hipcub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;
__shared__ typename LoadChar::TempStorage loadchar;
__shared__ typename StoreT::TempStorage storet;
for (int i = base_idx; i < n_load; i += gridDim.x*TILE_SIZE)
{
if (DATA_TYPE > 0)
{
valid_items_load = min(TILE_SIZE, (n + 1) / 2 - i);
valid_items_store = min(TILE_SIZE * 2, n - i * 2);
}
else
{
valid_items_load = min(TILE_SIZE, n - i);
valid_items_store = valid_items_load;
}
// Since blocksize will always be a power-of-2, we avoid more expensive
// division by the blocksize and instead use a shift operation.
// This is equivalent to (i+threadId.x*NUM_PER_TH)/blocksize.
local_abs_max = __ldg(&absmax[(i+threadIdx.x*NUM_PER_TH) >> (31 - __clz(blocksize))]);
__syncthreads();
LoadChar(loadchar).Load(&(A[i]), qvals, valid_items_load, 128);
switch (DATA_TYPE)
{
case General8bit:
// load code through read-only cache via __ldg
#pragma unroll NUM_PER_TH
for(int j = 0; j < NUM_PER_TH; j++)
vals[j] = __ldg(&code[qvals[j]])*local_abs_max;
break;
case FP4:
#pragma unroll NUM_PER_TH
for(int j = 0; j < NUM_PER_TH; j++)
{
vals[j*2] = dDequantizeFP4Tree(qvals[j] >> 4, local_abs_max);
vals[j*2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F, local_abs_max);
}
break;
case NF4:
#pragma unroll NUM_PER_TH
for(int j = 0; j < NUM_PER_TH; j++)
{
vals[j*2] = dDequantizeNF4(qvals[j] >> 4)* local_abs_max;
vals[j*2 + 1] = dDequantizeNF4(qvals[j] & 0x0F)* local_abs_max;
}
break;
}
__syncthreads();
StoreT(storet).Store(&(out[(DATA_TYPE > 0) ? i*2 : i]), vals, valid_items_store);
}
}
__global__ void kDequantize(float *code, unsigned char *A, float *out, const int n)
{
const unsigned int numThreads = blockDim.x * gridDim.x;
const int idx = (blockIdx.x * blockDim.x) + threadIdx.x;
__shared__ float smem_code[256];
if(threadIdx.x < 256)
{
smem_code[threadIdx.x] = code[threadIdx.x];
}
__syncthreads();
for (int i = idx;i < n; i += numThreads)
{
out[i] = smem_code[A[i]];
}
}
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
__launch_bounds__(BLOCK_SIZE/NUM_VALS, 1)
__global__ void kPreconditionOptimizer32bit2State(T* g, T* p,
float* state1, float* state2, float *unorm,
const float beta1, const float beta2, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, const int n)
{
const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE);
const int base_idx = (blockIdx.x * blockDim.x * NUM_VALS);
int valid_items = 0;
T g_vals[NUM_VALS];
float s1_vals[NUM_VALS];
float s2_vals[NUM_VALS];
const float correction1 = 1.0f/(1.0f - powf(beta1, step));
const float correction2 = 1.0f/(1.0f - powf(beta2, step));
typedef hipcub::BlockLoad<T, BLOCK_SIZE/NUM_VALS, NUM_VALS, hipcub::BLOCK_LOAD_WARP_TRANSPOSE> Load;
typedef hipcub::BlockLoad<float, BLOCK_SIZE/NUM_VALS, NUM_VALS, hipcub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
typedef hipcub::BlockReduce<float, BLOCK_SIZE/NUM_VALS> BlockReduce;
__shared__ union {
typename Load::TempStorage load;
typename LoadFloat::TempStorage loadf;
typename BlockReduce::TempStorage reduce;
} temp_storage;
for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE)
{
valid_items = n - i >= (BLOCK_SIZE) ? (BLOCK_SIZE) : n - i;
__syncthreads();
Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items, 0.0f);
__syncthreads();
LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items, 0.0f);
__syncthreads();
LoadFloat(temp_storage.loadf).Load(&(state2[i]), s2_vals, valid_items, 0.0f);
# pragma unroll NUM_VALS
for(unsigned int j = 0; j < NUM_VALS; j++)
g_vals[j] = gnorm_scale*((float)g_vals[j]);
# pragma unroll NUM_VALS
for(unsigned int j = 0; j < NUM_VALS; j++)
{
switch(OPTIMIZER)
{
case ADAM:
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j]));
s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j])));
s1_vals[j] *= correction1;
s2_vals[j] *= correction2;
s1_vals[j] = s1_vals[j]/(sqrtf(s2_vals[j])+eps); // update
s1_vals[j] *= s1_vals[j]; // update l2 norm (update*update)
break;
}
}
# pragma unroll NUM_VALS-1
for(unsigned int j = 1; j < NUM_VALS; j++)
s1_vals[0] += s1_vals[j];
__syncthreads();
s1_vals[0] = BlockReduce(temp_storage.reduce).Sum(s1_vals[0]);
if(threadIdx.x == 0)
atomicAdd(&unorm[0], s1_vals[0]);
//__syncwarp();
}
}
#define NUM_PER_THREAD 4
template<typename T, int OPTIMIZER>
__launch_bounds__(TH, 1)
__global__ void kOptimizer32bit2State(T* g, T* p,
float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n)
{
const int n_full = ((TH*NUM_PER_THREAD)*(n/(TH*NUM_PER_THREAD))) + (n % (TH*NUM_PER_THREAD) == 0 ? 0 : (TH*NUM_PER_THREAD));
const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD);
int valid_items = 0;
float update_scale = 0.0f;
T g_vals[NUM_PER_THREAD];
T p_vals[NUM_PER_THREAD];
float s1_vals[NUM_PER_THREAD];
float s2_vals[NUM_PER_THREAD];
// AdEMAMix has an additional state buffer, which we packed
// into state1. We need thread-local storage here for these.
// TODO: Mark with [[maybe_unused]] after upgrade to min compiler.
float s3_vals[NUM_PER_THREAD];
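// Layout for ADEMAMIX: state1 holds 2*n floats, [0, n) is the fast EMA (s1_vals) and
// [n, 2n) is the slow EMA loaded into s3_vals below; state2 holds the second moment as usual.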
const float correction1 = 1.0f - powf(beta1, step);
const float correction2 = sqrtf(1.0f - powf(beta2, step));
const float step_size = -lr*correction2/correction1;
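// With correction1 = 1 - beta1^t and correction2 = sqrt(1 - beta2^t), the Adam update below,
// p += step_size * m / (sqrt(v) + eps*correction2), is algebraically the standard bias-corrected
// form p -= lr * (m / (1 - beta1^t)) / (sqrt(v / (1 - beta2^t)) + eps).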
if(max_unorm > 0.0f)
{
update_scale = max_unorm > 0.0f ? sqrtf(unorm[0]) : 1.0f;
if(update_scale > max_unorm*param_norm){ update_scale = (max_unorm*param_norm)/update_scale; }
else{ update_scale = 1.0f; }
}
else{ update_scale = 1.0f; }
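// max_unorm > 0 enables update clipping: unorm[0] (accumulated by the precondition kernel) is
// the squared update norm, and the update is scaled down so that its norm does not exceed
// max_unorm * param_norm.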
typedef hipcub::BlockLoad<T, TH, NUM_PER_THREAD, hipcub::BLOCK_LOAD_WARP_TRANSPOSE> Load;
typedef hipcub::BlockStore<T, TH, NUM_PER_THREAD, hipcub::BLOCK_STORE_WARP_TRANSPOSE> Store;
typedef hipcub::BlockLoad<float, TH, NUM_PER_THREAD, hipcub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
typedef hipcub::BlockStore<float, TH, NUM_PER_THREAD, hipcub::BLOCK_STORE_WARP_TRANSPOSE> StoreFloat;
__shared__ union {
typename Load::TempStorage load;
typename Store::TempStorage store;
typename LoadFloat::TempStorage loadf;
typename StoreFloat::TempStorage storef;
} temp_storage;
for (unsigned int i = base_idx; i < n_full; i += gridDim.x*TH*NUM_PER_THREAD)
{
valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i;
__syncthreads();
Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items);
__syncthreads();
LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items);
__syncthreads();
LoadFloat(temp_storage.loadf).Load(&(state2[i]), s2_vals, valid_items);
__syncthreads();
Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items);
// Load additional state1 data for AdEMAMix
// TODO: Make constexpr after updating min compiler
if (OPTIMIZER == ADEMAMIX) {
__syncthreads();
LoadFloat(temp_storage.loadf).Load(&(state1[n + i]), s3_vals, valid_items);
}
# pragma unroll 4
for(unsigned int j = 0; j < NUM_PER_THREAD; j++)
g_vals[j] = gnorm_scale*((float)g_vals[j]);
# pragma unroll 4
for(unsigned int j = 0; j < NUM_PER_THREAD; j++)
{
switch(OPTIMIZER)
{
case ADEMAMIX:
// m1 update: m1 = beta1 * m1 + (1-beta1) * g
s1_vals[j] = (s1_vals[j] * beta1) + ((1.0f - beta1) * (float)g_vals[j]);
// m2 update: m2 = m2 * beta3 + (1-beta3) * g
s3_vals[j] = (s3_vals[j] * beta3) + ((1.0f - beta3) * (float)g_vals[j]);
// nu update: nu = beta2 * nu + (1-beta2) * g^2
s2_vals[j] = (s2_vals[j] * beta2) + ((1.0f - beta2) * (float)g_vals[j] * (float)g_vals[j]);
p_vals[j] = (float)p_vals[j] - lr * (
((s1_vals[j] / correction1) + (alpha * s3_vals[j])) / (
(sqrtf(s2_vals[j]) / correction2) + eps
)
);
if (weight_decay > 0.0f)
p_vals[j] = ((float)p_vals[j]) * (1.0f - (lr * weight_decay));
break;
case ADAM:
if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
{
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f -beta1)*((float)g_vals[j]));
s2_vals[j] = s2_vals[j]*beta2 + ((1.0f -beta2)*(((float)g_vals[j])*((float)g_vals[j])));
p_vals[j] = ((float)p_vals[j]) + (update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(eps*correction2))));
if(weight_decay > 0.0f)
p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay));
}
break;
}
}
__syncthreads();
Store(temp_storage.store).Store(&(p[i]), p_vals, valid_items);
__syncthreads();
StoreFloat(temp_storage.storef).Store(&(state1[i]), s1_vals, valid_items);
__syncthreads();
StoreFloat(temp_storage.storef).Store(&(state2[i]), s2_vals, valid_items);
if (OPTIMIZER == ADEMAMIX) {
__syncthreads();
StoreFloat(temp_storage.storef).Store(&(state1[n + i]), s3_vals, valid_items);
}
}
}
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
__launch_bounds__(BLOCK_SIZE/NUM_VALS, 1)
__global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
float* state1, float *unorm,
const float beta1, const float beta2, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, const int n)
{
const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE);
const int base_idx = (blockIdx.x * blockDim.x * NUM_VALS);
int valid_items = 0;
T g_vals[NUM_VALS];
float s1_vals[NUM_VALS];
typedef hipcub::BlockLoad<T, BLOCK_SIZE/NUM_VALS, NUM_VALS, hipcub::BLOCK_LOAD_WARP_TRANSPOSE> Load;
typedef hipcub::BlockLoad<float, BLOCK_SIZE/NUM_VALS, NUM_VALS, hipcub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
typedef hipcub::BlockReduce<float, BLOCK_SIZE/NUM_VALS> BlockReduce;
__shared__ union {
typename Load::TempStorage load;
typename LoadFloat::TempStorage loadf;
typename BlockReduce::TempStorage reduce;
} temp_storage;
for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE)
{
valid_items = n - i >= (BLOCK_SIZE) ? (BLOCK_SIZE) : n - i;
__syncthreads();
Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items, 0.0f);
__syncthreads();
LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items, 0.0f);
# pragma unroll NUM_VALS
for(unsigned int j = 0; j < NUM_VALS; j++)
g_vals[j] = gnorm_scale*((float)g_vals[j]);
# pragma unroll NUM_VALS
for(unsigned int j = 0; j < NUM_VALS; j++)
{
switch(OPTIMIZER)
{
case MOMENTUM:
if(step == 1)
s1_vals[j] = (float)g_vals[j]; // state update
else
s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]); // state update
s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm
break;
case LION:
s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*(float)g_vals[j]); // state update
break;
case RMSPROP:
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j])); // state update
s1_vals[j] = __fdividef((float)g_vals[j],sqrtf(s1_vals[j])+eps); // update value
s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm
break;
case ADAGRAD:
s1_vals[j] = s1_vals[j] + ((float)g_vals[j])*((float)g_vals[j]); // state update
s1_vals[j] = __fdividef((float)g_vals[j],sqrtf(s1_vals[j])+eps); // update value
s1_vals[j] = s1_vals[j]*s1_vals[j]; // update norm
break;
}
}
# pragma unroll
for(unsigned int j = 1; j < NUM_VALS; j++)
s1_vals[0] += s1_vals[j];
__syncthreads();
s1_vals[0] = BlockReduce(temp_storage.reduce).Sum(s1_vals[0], valid_items);
if(threadIdx.x == 0)
atomicAdd(&unorm[0], s1_vals[0]);
//__syncwarp();
}
}
template<typename T, int OPTIMIZER>
__launch_bounds__(TH, 1)
__global__ void kOptimizer32bit1State(T *g, T *p,
float *state1, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n)
{
const int n_full = ((TH*NUM_PER_THREAD)*(n/(TH*NUM_PER_THREAD))) + (n % (TH*NUM_PER_THREAD) == 0 ? 0 : (TH*NUM_PER_THREAD));
const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD);
int valid_items = 0;
float update_scale = 0.0f;
if(max_unorm > 0.0f)
{
update_scale = max_unorm > 0.0f ? sqrtf(unorm[0]) : 1.0f;
if(update_scale > max_unorm*param_norm+eps){ update_scale = (max_unorm*param_norm+eps)/update_scale; }
else{ update_scale = 1.0f; }
}
else{ update_scale = 1.0f; }
T g_vals[NUM_PER_THREAD];
T p_vals[NUM_PER_THREAD];
float s1_vals[NUM_PER_THREAD];
typedef hipcub::BlockLoad<T, TH, NUM_PER_THREAD, hipcub::BLOCK_LOAD_WARP_TRANSPOSE> Load;
typedef hipcub::BlockStore<T, TH, NUM_PER_THREAD, hipcub::BLOCK_STORE_WARP_TRANSPOSE> Store;
typedef hipcub::BlockLoad<float, TH, NUM_PER_THREAD, hipcub::BLOCK_LOAD_WARP_TRANSPOSE> LoadFloat;
typedef hipcub::BlockStore<float, TH, NUM_PER_THREAD, hipcub::BLOCK_STORE_WARP_TRANSPOSE> StoreFloat;
__shared__ union {
typename Load::TempStorage load;
typename Store::TempStorage store;
typename LoadFloat::TempStorage loadf;
typename StoreFloat::TempStorage storef;
} temp_storage;
for (unsigned int i = base_idx; i < n_full; i += gridDim.x*TH*NUM_PER_THREAD)
{
valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i;
__syncthreads();
Load(temp_storage.load).Load(&(g[i]), g_vals, valid_items);
__syncthreads();
LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items);
__syncthreads();
Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items);
# pragma unroll 4
for(unsigned int j = 0; j < NUM_PER_THREAD; j++)
{
g_vals[j] = gnorm_scale*((float)g_vals[j]);
if(weight_decay > 0.0f)
g_vals[j] = (float)g_vals[j] + (((float)p_vals[j])*weight_decay);
}
# pragma unroll 4
for(unsigned int j = 0; j < NUM_PER_THREAD; j++)
{
if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
{
switch(OPTIMIZER)
{
case MOMENTUM:
if(step == 1)
s1_vals[j] = (float)g_vals[j];
else
s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]);
p_vals[j] = ((float)p_vals[j]) + update_scale*(-lr*(s1_vals[j]));
break;
case LION:
p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_vals[j]))));
s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*((float)g_vals[j]));
break;
case RMSPROP:
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j]));
p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps));
break;
case ADAGRAD:
s1_vals[j] = s1_vals[j] + ((float)g_vals[j])*((float)g_vals[j]);
p_vals[j] = ((float)p_vals[j]) - lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps);
break;
}
}
}
__syncthreads();
Store(temp_storage.store).Store(&(p[i]), p_vals, valid_items);
__syncthreads();
StoreFloat(temp_storage.storef).Store(&(state1[i]), s1_vals, valid_items);
}
}
#define NUM8BIT 16
#define NUM_THREADS 256
#define NUM_PER_BLOCK 4096
template<typename T, int OPTIMIZER>
__global__ void
__launch_bounds__(NUM_THREADS, 2)
kPreconditionOptimizerStatic8bit2State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2,
float *unorm,
const float beta1, const float beta2,
const float eps, const int step,
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2,
float* max1, float* max2, float* new_max1, float* new_max2,
const float gnorm_scale, const int n)
{
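// Precondition pass for the static 8-bit optimizer: dequantize state1/state2 through the
// quantile tables and the current max1/max2, apply one EMA step, and record the new per-tensor
// absmax values (new_max1/new_max2) plus, when requested, the squared update norm in unorm.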
const int n_full = gridDim.x * NUM_PER_BLOCK;
const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD);
int valid_items = n - (blockIdx.x*NUM_PER_BLOCK) > NUM_PER_BLOCK ? NUM_PER_BLOCK : n - (blockIdx.x*NUM_PER_BLOCK);
float g_val = 0.0f;
float local_max_s1 = -FLT_MAX;
float local_max_s2 = -FLT_MAX;
float local_unorm = 0.0f;
float s2_vals[NUM8BIT];
float s1_vals[NUM8BIT];
T g_vals[NUM8BIT];
unsigned char m_c1[NUM8BIT];
unsigned char r_c2[NUM8BIT];
typedef hipcub::BlockLoad<T, NUM_THREADS, NUM8BIT, hipcub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
typedef hipcub::BlockLoad<unsigned char, NUM_THREADS, NUM8BIT, hipcub::BLOCK_LOAD_WARP_TRANSPOSE> LoadUInt8;
typedef hipcub::BlockReduce<float, NUM_THREADS> BlockReduce;
__shared__ union {
typename LoadT::TempStorage loadh;
typename LoadUInt8::TempStorage loadc;
typename BlockReduce::TempStorage reduce;
} temp_storage;
__shared__ float smem_quantiles1[256];
__shared__ float smem_quantiles2[256];
if(threadIdx.x < 256)
{
smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x];
smem_quantiles2[threadIdx.x] = quantiles2[threadIdx.x];
}
__syncthreads();
for (unsigned int i = base_idx; i < n_full; i += NUM_THREADS*gridDim.x*NUM8BIT)
{
valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i;
LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f);
__syncthreads();
LoadUInt8(temp_storage.loadc).Load(&(state1[i]), m_c1, valid_items, 128);
__syncthreads();
LoadUInt8(temp_storage.loadc).Load(&(state2[i]), r_c2, valid_items, 128);
__syncthreads();
#pragma unroll 16
for(int j = 0; j < NUM8BIT; j++)
{
g_val = g_vals[j];
g_val *= gnorm_scale;
s1_vals[j] = smem_quantiles1[m_c1[j]]*max1[0]*beta1;
s1_vals[j] += (1.0f-beta1)*g_val;
local_max_s1 = fmaxf(local_max_s1, fabsf(s1_vals[j]));
}
#pragma unroll 16
for(int j = 0; j < NUM8BIT; j++)
{
g_val = g_vals[j];
g_val *= gnorm_scale;
s2_vals[j] = smem_quantiles2[r_c2[j]]*max2[0]*beta2;
s2_vals[j] += (1.0f-beta2)*g_val*g_val;
local_max_s2 = fmaxf(local_max_s2, fabsf(s2_vals[j]));
}
if(unorm != NULL)
{
#pragma unroll 16
for(int j = 0; j < NUM8BIT; j++)
{
float correction1 = __fdividef(1.0f, 1.0f - powf(beta1, step));
float correction2 = __fdividef(1.0f, 1.0f - powf(beta2, step));
s1_vals[j] *= correction1;
s2_vals[j] *= correction2;
float update_val = s1_vals[j]/(sqrtf(s2_vals[j])+eps); // update
local_unorm += update_val*update_val;
}
}
}
__syncthreads();
local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, hipcub::Max(), valid_items);
__syncthreads();
local_max_s2 = BlockReduce(temp_storage.reduce).Reduce(local_max_s2, hipcub::Max(), valid_items);
if(unorm != NULL)
{
__syncthreads();
local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, hipcub::Sum(), valid_items);
}
if(threadIdx.x == 0)
{
atomicMax(&new_max1[0], local_max_s1);
atomicMax(&new_max2[0], local_max_s2);
if(unorm != NULL){ atomicAdd(&unorm[0], local_unorm); }
}
}
#define NUM_PER_THREAD2 4
#define NUM_THREADS2 1024
#define NUM_PER_BLOCK2 4096
template<typename T, int OPTIMIZER>
__global__ void
__launch_bounds__(NUM_THREADS2, 1)
kOptimizerStatic8bit2State(T* p, T* const g, unsigned char* state1, unsigned char* state2,
const float *unorm, const float max_unorm, const float param_norm, \
const float beta1, const float beta2,
const float eps, const int step, const float lr,
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2,
float* max1, float* max2, float* new_max1, float* new_max2,
float weight_decay,
const float gnorm_scale, const int n)
{
const int n_full = (blockDim.x * gridDim.x)*NUM_PER_THREAD2;
const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD2);
int valid_items = 0;
float g_val = 0.0f;
float s1_vals[NUM_PER_THREAD2];
float s2_vals[NUM_PER_THREAD2];
const float correction1 = 1.0f - powf(beta1, step);
const float correction2 = sqrtf(1.0f - powf(beta2, step));
const float step_size = -lr*correction2/correction1;
//const float step_size = -lr*correction2/correction1;
float new_max_val1 = 1.0f/new_max1[0];
float new_max_val2 = 1.0f/new_max2[0];
float update_scale = 1.0f;
if(max_unorm > 0.0f)
{
update_scale = max_unorm > 0.0f ? sqrtf(unorm[0]) : 1.0f;
if(update_scale > max_unorm*param_norm){ update_scale = (max_unorm*param_norm)/update_scale; }
else{ update_scale = 1.0f; }
}
else{ update_scale = 1.0f; }
unsigned char c1s[NUM_PER_THREAD2];
unsigned char c2s[NUM_PER_THREAD2];
T p_vals[NUM_PER_THREAD2];
T g_vals[NUM_PER_THREAD2];
typedef hipcub::BlockLoad<T, NUM_THREADS2, NUM_PER_THREAD2, hipcub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
typedef hipcub::BlockLoad<unsigned char, NUM_THREADS2, NUM_PER_THREAD2, hipcub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
typedef hipcub::BlockStore<unsigned char, NUM_THREADS2, NUM_PER_THREAD2, hipcub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
typedef hipcub::BlockStore<T, NUM_THREADS2, NUM_PER_THREAD2, hipcub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;
__shared__ float smem_quantiles1[256];
__shared__ float smem_quantiles2[256];
__shared__ union {
typename LoadT::TempStorage loadh;
typename LoadChar::TempStorage loadc;
typename StoreChar::TempStorage storec;
typename StoreT::TempStorage storeh;
} temp_storage;
if(threadIdx.x < 512)
{
if(threadIdx.x < 256)
smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x];
else
smem_quantiles2[threadIdx.x-256] = quantiles2[threadIdx.x-256];
}
__syncthreads();
for (unsigned int i = base_idx; i < n_full; i += gridDim.x*NUM_THREADS2*NUM_PER_THREAD2)
{
valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i;
LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f);
__syncthreads();
LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128);
__syncthreads();
LoadChar(temp_storage.loadc).Load(&(state2[i]), c2s, valid_items, 0);
__syncthreads();
LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items);
if((i + (threadIdx.x*NUM_PER_THREAD2) + NUM_PER_THREAD2) > n){ continue; }
# pragma unroll 4
for(unsigned int j = 0; j < NUM_PER_THREAD2; j++)
{
g_val = float(g_vals[j]);
g_val *= gnorm_scale;
s1_vals[j] = smem_quantiles1[c1s[j]];
s1_vals[j] = s1_vals[j]*max1[0];
s1_vals[j] = (s1_vals[j]*beta1) + (((1.0f-beta1)*g_val));
c1s[j] = dQuantize<0>(smem_quantiles1, 0.0f, s1_vals[j]*new_max_val1);
// make sure the state1 term still has the same sign after quantization
// (not needed for the state2 term, which has only positive values)
if(signbit(smem_quantiles1[c1s[j]]) != signbit(s1_vals[j]))
{
if(s1_vals[j] > 0.0f)
c1s[j] += 1;
else
c1s[j] -= 1;
}
s2_vals[j] = smem_quantiles2[c2s[j]];
s2_vals[j] = s2_vals[j]*max2[0];
s2_vals[j] = (s2_vals[j]*beta2) + (((1.0f-beta2)*g_val*g_val));
c2s[j] = dQuantize<0>(smem_quantiles2, 0.0f, s2_vals[j]*new_max_val2);
}
# pragma unroll 4
for(unsigned int j = 0; j < NUM_PER_THREAD2; j++)
{
p_vals[j] = (T)(((float)p_vals[j]) + ((update_scale*step_size*(s1_vals[j]/(sqrtf(s2_vals[j])+(correction2*eps))))));
if(weight_decay > 0.0f)
p_vals[j] = update_scale*((float)p_vals[j])*(1.0f-(lr*weight_decay));
}
StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items);
__syncthreads();
StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items);
__syncthreads();
StoreChar(temp_storage.storec).Store(&(state2[i]), c2s, valid_items);
__syncthreads();
}
}
template<typename T, int OPTIMIZER>
__global__ void
__launch_bounds__(NUM_THREADS, 2)
kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned char*__restrict__ const state1,
float *unorm,
const float beta1, const float beta2,
const float eps, const int step,
float* __restrict__ const quantiles1,
float* max1, float* new_max1,
const float weight_decay,
const float gnorm_scale, const int n)
{
const int n_full = gridDim.x * NUM_PER_BLOCK;
const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD);
int valid_items = n - (blockIdx.x*NUM_PER_BLOCK) > NUM_PER_BLOCK ? NUM_PER_BLOCK : n - (blockIdx.x*NUM_PER_BLOCK);
float g_val = 0.0f;
float local_max_s1 = -FLT_MAX;
float local_unorm = 0.0f;
float s1_vals[NUM8BIT];
T g_vals[NUM8BIT];
unsigned char m_c1[NUM8BIT];
typedef hipcub::BlockLoad<T, NUM_THREADS, NUM8BIT, hipcub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
typedef hipcub::BlockLoad<unsigned char, NUM_THREADS, NUM8BIT, hipcub::BLOCK_LOAD_WARP_TRANSPOSE> LoadUInt8;
typedef hipcub::BlockReduce<float, NUM_THREADS> BlockReduce;
__shared__ union {
typename LoadT::TempStorage loadh;
typename LoadUInt8::TempStorage loadc;
typename BlockReduce::TempStorage reduce;
} temp_storage;
__shared__ float smem_quantiles1[256];
if(threadIdx.x < 256)
smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x];
__syncthreads();
for (unsigned int i = base_idx; i < n_full; i += gridDim.x*NUM_THREADS*NUM8BIT)
{
valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i;
__syncthreads();
LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f);
__syncthreads();
LoadUInt8(temp_storage.loadc).Load(&(state1[i]), m_c1, valid_items, 128);
#pragma unroll 16
for(int j = 0; j < NUM8BIT; j++)
{
g_val = g_vals[j];
g_val *= gnorm_scale;
s1_vals[j] = smem_quantiles1[m_c1[j]]*max1[0];
switch(OPTIMIZER)
{
case ADAGRAD:
case MOMENTUM:
if(step == 1)
s1_vals[j] = (float)g_vals[j];
else
s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]);
if(unorm != NULL)
local_unorm += s1_vals[j]*s1_vals[j];
break;
case LION:
s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val);
break;
case RMSPROP:
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
break;
}
local_max_s1 = fmaxf(local_max_s1, fabsf(s1_vals[j]));
}
}
__syncthreads();
local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, hipcub::Max(), valid_items);
if(threadIdx.x == 0){ atomicMax(&new_max1[0], local_max_s1); }
if(unorm != NULL)
{
__syncthreads();
local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, hipcub::Sum(), valid_items);
if(threadIdx.x == 0){ atomicAdd(&unorm[0], local_unorm); }
}
}
template<typename T, int OPTIMIZER>
__global__ void
__launch_bounds__(1024, 1)
kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
const float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2,
const float eps, const int step, const float lr,
float* __restrict__ const quantiles1,
float* max1, float* new_max1,
float weight_decay,
const float gnorm_scale, const int n)
{
const int n_full = (blockDim.x * gridDim.x)*NUM_PER_THREAD2;
const int base_idx = (blockIdx.x * blockDim.x * NUM_PER_THREAD2);
int valid_items = 0;
float g_val = 0.0f;
float s1_vals[NUM_PER_THREAD2];
float new_max_val1 = 1.0f/new_max1[0];
float update_scale = 1.0f;
if(max_unorm > 0.0f)
{
update_scale = max_unorm > 0.0f ? sqrtf(unorm[0]) : 1.0f;
if(update_scale > max_unorm*param_norm){ update_scale = (max_unorm*param_norm)/update_scale; }
else{ update_scale = 1.0f; }
}
else{ update_scale = 1.0f; }
unsigned char c1s[NUM_PER_THREAD2];
T p_vals[NUM_PER_THREAD2];
T g_vals[NUM_PER_THREAD2];
typedef hipcub::BlockLoad<T, NUM_THREADS2, NUM_PER_THREAD2, hipcub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
typedef hipcub::BlockLoad<unsigned char, NUM_THREADS2, NUM_PER_THREAD2, hipcub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
typedef hipcub::BlockStore<unsigned char, NUM_THREADS2, NUM_PER_THREAD2, hipcub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
typedef hipcub::BlockStore<T, NUM_THREADS2, NUM_PER_THREAD2, hipcub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;
__shared__ float smem_quantiles1[256];
__shared__ union {
typename LoadT::TempStorage loadh;
typename LoadChar::TempStorage loadc;
typename StoreChar::TempStorage storec;
typename StoreT::TempStorage storeh;
} temp_storage;
if(threadIdx.x < 256)
smem_quantiles1[threadIdx.x] = quantiles1[threadIdx.x];
__syncthreads();
for (unsigned int i = base_idx; i < n_full; i += gridDim.x*NUM_THREADS2*NUM_PER_THREAD2)
{
valid_items = n - i >= (TH*NUM_PER_THREAD) ? (TH*NUM_PER_THREAD) : n - i;
LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f);
__syncthreads();
LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128);
__syncthreads();
LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items);
if((i + (threadIdx.x*NUM_PER_THREAD2) + NUM_PER_THREAD2) > n){ continue; }
# pragma unroll 4
for(unsigned int j = 0; j < NUM_PER_THREAD2; j++)
{
g_val = float(g_vals[j]);
g_val *= gnorm_scale;
if(weight_decay > 0.0f) {
switch(OPTIMIZER) {
case ADAGRAD:
case MOMENTUM:
case RMSPROP:
g_val += ((float)p_vals[j])*weight_decay;
break;
case LION:
p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay);
break;
}
}
s1_vals[j] = smem_quantiles1[c1s[j]]*max1[0];
switch(OPTIMIZER){
case ADAGRAD:
case MOMENTUM:
if(step == 1)
s1_vals[j] = g_vals[j];
else
s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]);
p_vals[j] = ((float)p_vals[j]) + (-lr*update_scale*(s1_vals[j]));
break;
case LION:
p_vals[j] = ((float)p_vals[j]) - (lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_val))));
s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val);
break;
case RMSPROP:
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
p_vals[j] = ((float)p_vals[j]) - (lr*__fdividef(g_val,sqrtf(s1_vals[j])+eps));
break;
}
c1s[j] = dQuantize<0>(smem_quantiles1, 0.0f, s1_vals[j]*new_max_val1);
// make sure the state1 term still has the same sign after quantization
if(signbit(smem_quantiles1[c1s[j]]) != signbit(s1_vals[j]))
{
if(s1_vals[j] > 0.0f)
c1s[j] += 1;
else
c1s[j] -= 1;
}
}
StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items);
__syncthreads();
StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items);
__syncthreads();
}
}
template<typename T, int BLOCK_SIZE, int NUM_VALS>
__global__ void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int step, const int n)
{
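// Accumulates the squared L2 norm of the gradient into gnorm_vec[step % 100], a 100-entry
// circular history used for percentile-based gradient clipping; at step 1 every slot is seeded
// with the same value.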
const int n_full = (BLOCK_SIZE*(n/BLOCK_SIZE)) + (n % BLOCK_SIZE == 0 ? 0 : BLOCK_SIZE);
int valid_items = 0;
typedef hipcub::BlockReduce<float, BLOCK_SIZE/NUM_VALS> BlockReduce;
typedef hipcub::BlockLoad<T, BLOCK_SIZE/NUM_VALS, NUM_VALS, hipcub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
__shared__ typename BlockReduce::TempStorage reduce;
__shared__ typename LoadT::TempStorage loadT;
T vals[NUM_VALS];
float local_sum = 0.0f;
for (unsigned int i = (blockIdx.x * BLOCK_SIZE); i < n_full; i += gridDim.x*BLOCK_SIZE)
{
valid_items = n - i > BLOCK_SIZE ? BLOCK_SIZE : n - i;
local_sum = 0.0f;
__syncthreads();
LoadT(loadT).Load(&(g[i]), vals, valid_items, (T)0.0f);
#pragma unroll NUM_VALS
for(int j = 0; j < NUM_VALS; j++)
local_sum += ((float)vals[j])*((float)vals[j]);
local_sum = BlockReduce(reduce).Sum(local_sum, valid_items);
if(threadIdx.x == 0)
{
if(step == 1)
{
// initialize with the same norm for all positions
//#pragma unroll 10
for(int j = 0; j < 100; j++)
atomicAdd(&gnorm_vec[j], local_sum);
}
else
atomicAdd(&gnorm_vec[step % 100], local_sum);
}
}
}
#define LANES 2
#define QUAD 3
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH>
__launch_bounds__(256, 3)
__global__ void
kOptimizerStatic8bit2StateBlockwise(
T* p,
T* __restrict__ const g,
unsigned char* state1,
unsigned char* state2,
const float beta1,
const float beta2,
const float beta3,
const float alpha,
const float eps,
const int step,
const float lr,
float* __restrict__ const quantiles1,
float* __restrict__ const quantiles2,
float* absmax1,
float* absmax2,
float weight_decay,
const float gnorm_scale,
const bool skip_zeros,
const int n
) {
//const int n_full = n + (n%BLOCK_SIZE);
const int n_full = gridDim.x * BLOCK_SIZE;
const int base_idx = (blockIdx.x * BLOCK_SIZE);
int valid_items = 0;
float g_val = 0.0f;
float s1_vals[N_PER_TH];
float s2_vals[N_PER_TH];
float s3_vals[N_PER_TH];
// 2-5%
const float correction1 = 1.0f - __powf(beta1, step);
const float correction2 = sqrtf(1.0f -__powf(beta2, step));
const float step_size = __fdividef(-lr*correction2,correction1);
const int lane_id = threadIdx.x % LANES;
float new_local_abs_max1 = -FLT_MAX;
float new_local_abs_max2 = -FLT_MAX;
float new_local_abs_max3 = -FLT_MAX;
float quadrants1[QUAD];
float quadrants2[QUAD];
unsigned char c1s[N_PER_TH];
unsigned char c2s[N_PER_TH];
unsigned char c3s[N_PER_TH];
T g_vals[N_PER_TH];
T p_vals[N_PER_TH];
typedef hipcub::BlockLoad<T, BLOCK_SIZE/N_PER_TH, N_PER_TH, hipcub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
typedef hipcub::BlockLoad<unsigned char, BLOCK_SIZE/N_PER_TH, N_PER_TH, hipcub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
typedef hipcub::BlockStore<unsigned char, BLOCK_SIZE/N_PER_TH, N_PER_TH, hipcub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
typedef hipcub::BlockStore<T, BLOCK_SIZE/N_PER_TH, N_PER_TH, hipcub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;
__shared__ float smem_quantiles1[LANES][257];
__shared__ float smem_quantiles2[LANES][257];
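// The inner dimension is 257 rather than 256, presumably to pad the per-lane stride and reduce
// shared-memory bank conflicts between the LANES copies of the quantile tables.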
typedef hipcub::BlockReduce<float, BLOCK_SIZE/N_PER_TH> BlockReduce1;
typedef hipcub::BlockReduce<float, BLOCK_SIZE/N_PER_TH> BlockReduce2;
typedef hipcub::BlockReduce<float, BLOCK_SIZE/N_PER_TH> BlockReduce3;
__shared__ typename BlockReduce1::TempStorage reduce1;
__shared__ typename BlockReduce2::TempStorage reduce2;
__shared__ typename BlockReduce3::TempStorage reduce3;
__shared__ float smem_exchange1[1];
__shared__ float smem_exchange2[1];
__shared__ float smem_exchange3[1]; // [[maybe_unused]]
__shared__ union {
typename LoadT::TempStorage loadh;
typename LoadChar::TempStorage loadc;
typename StoreChar::TempStorage storec;
typename StoreT::TempStorage storeh;
} temp_storage;
// init: 0.2 -> 0.23
// 0.23 -> 0.23
smem_quantiles1[0][threadIdx.x] = quantiles1[threadIdx.x];
smem_quantiles2[0][threadIdx.x] = quantiles2[threadIdx.x];
# pragma unroll
for(unsigned int j = 1; j < LANES; j++)
{
smem_quantiles1[j][threadIdx.x] = smem_quantiles1[0][threadIdx.x];
smem_quantiles2[j][threadIdx.x] = smem_quantiles2[0][threadIdx.x];
}
__syncthreads();
#pragma unroll
for(int k = 0; k < QUAD; k++)
{
quadrants1[k] = smem_quantiles1[lane_id][(k*256/(QUAD+1)) + (256/(QUAD+1)-1)];
quadrants2[k] = smem_quantiles2[lane_id][(k*256/(QUAD+1)) + (256/(QUAD+1)-1)];
}
for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE)
{
// loads: 0.23 -> 0.85/1.44
valid_items = n - i >= BLOCK_SIZE ? BLOCK_SIZE : n - i;
__syncthreads();
LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f);
__syncthreads();
LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128);
__syncthreads();
LoadChar(temp_storage.loadc).Load(&(state2[i]), c2s, valid_items, 0);
// AdEMAMix has an additional state packed into state1.
if (OPTIMIZER == ADEMAMIX) {
__syncthreads();
LoadChar(temp_storage.loadc).Load(&(state1[n + i]), c3s, valid_items, 128);
}
new_local_abs_max1 = -FLT_MAX;
new_local_abs_max2 = -FLT_MAX;
new_local_abs_max3 = -FLT_MAX;
// update: 2.48/1.57 -> 2.51/1.60
# pragma unroll N_PER_TH
for(unsigned int j = 0; j < N_PER_TH; j++)
{
if(!isnan((float)g_vals[j]) && !isinf((float)g_vals[j]))
{
s2_vals[j] = smem_quantiles2[lane_id][c2s[j]]*absmax2[i/BLOCK_SIZE];
g_val = g_vals[j];
//float ratio = (g_val*g_val)/fmaxf(s2_vals[j], eps*eps);
//g_val = ratio > 2.0f ? 2.0f*g_val/ratio : g_val;
g_val *= gnorm_scale;
s2_vals[j] = (s2_vals[j]*beta2) + (((1.0f-beta2)*g_val*g_val));
s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE];
s1_vals[j] = (s1_vals[j]*beta1) + (((1.0f-beta1)*g_val));
if (OPTIMIZER == ADEMAMIX) {
// The absmax for the third state is appended to absmax1
s3_vals[j] = smem_quantiles1[lane_id][c3s[j]] * absmax1[(n + i)/BLOCK_SIZE];
s3_vals[j] = (s3_vals[j] * beta3) + (((1.0f - beta3) * g_val));
}
}
else
{
s1_vals[j] = 0.0f;
s2_vals[j] = 0.0f;
if (OPTIMIZER == ADEMAMIX) {
s3_vals[j] = 0.0f;
}
}
new_local_abs_max1 = fmaxf(new_local_abs_max1, fabsf(s1_vals[j]));
new_local_abs_max2 = fmaxf(new_local_abs_max2, fabsf(s2_vals[j]));
if (OPTIMIZER == ADEMAMIX) {
new_local_abs_max3 = fmaxf(new_local_abs_max3, fabsf(s3_vals[j]));
}
}
// reduce: 2.51/1.60 -> 2.67/1.69
new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, hipcub::Max());
new_local_abs_max2 = BlockReduce2(reduce2).Reduce(new_local_abs_max2, hipcub::Max());
if (OPTIMIZER == ADEMAMIX) {
new_local_abs_max3 = BlockReduce3(reduce3).Reduce(new_local_abs_max3, hipcub::Max());
}
if(threadIdx.x == 0)
{
smem_exchange1[0] = new_local_abs_max1;
smem_exchange2[0] = new_local_abs_max2;
if (OPTIMIZER == ADEMAMIX) {
smem_exchange3[0] = new_local_abs_max3;
}
}
__syncthreads();
if(threadIdx.x == 0)
{
absmax1[i/BLOCK_SIZE] = new_local_abs_max1;
absmax2[i/BLOCK_SIZE] = new_local_abs_max2;
if (OPTIMIZER == ADEMAMIX) {
absmax1[(n + i)/BLOCK_SIZE] = new_local_abs_max3;
}
}
else
{
new_local_abs_max1 = smem_exchange1[0];
new_local_abs_max2 = smem_exchange2[0];
if (OPTIMIZER == ADEMAMIX) {
new_local_abs_max3 = smem_exchange3[0];
}
}
__syncthreads();
LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f);
// reduce: 2.67/1.69 -> 2.67/1.70
# pragma unroll N_PER_TH
for(unsigned int j = 0; j < N_PER_TH; j++)
{
//if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
if(!isnan((float)g_vals[j]) && !isinf((float)g_vals[j]))
{
if (OPTIMIZER == ADEMAMIX) {
p_vals[j] = T((float)p_vals[j] - lr * (
((s1_vals[j] / correction1) + (alpha * s3_vals[j])) / (
(sqrtf(s2_vals[j]) / correction2) + eps
)
));
} else {
p_vals[j] = (T)(((float)p_vals[j]) + ((step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps)))))));
}
if(weight_decay > 0.0f)
p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay));
}
}
// store: 0.85/1.44 -> 2.48/1.57
__syncthreads();
StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items);
// quantization: 2.67/1.70 -> 3.4/3.3
# pragma unroll N_PER_TH
for(unsigned int j = 0; j < N_PER_TH; j++)
{
c1s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s1_vals[j],new_local_abs_max1));
c2s[j] = quantize_2D<0>(quadrants2, smem_quantiles2[lane_id], __fdividef(s2_vals[j],new_local_abs_max2));
// make sure the state1 term still has the same sign after quantization
// (not needed for the state2 term, which has only positive values)
if(signbit(smem_quantiles1[lane_id][c1s[j]]) != signbit(s1_vals[j]))
{
if(s1_vals[j] > 0.0f)
c1s[j] += 1;
else
c1s[j] -= 1;
}
if (OPTIMIZER == ADEMAMIX) {
c3s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s3_vals[j],new_local_abs_max3));
if (signbit(smem_quantiles1[lane_id][c3s[j]]) != signbit(s3_vals[j])) {
c3s[j] += (s3_vals[j] > 0.0f) ? 1 : -1;
}
}
}
__syncthreads();
StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items);
__syncthreads();
StoreChar(temp_storage.storec).Store(&(state2[i]), c2s, valid_items);
if (OPTIMIZER == ADEMAMIX) {
__syncthreads();
StoreChar(temp_storage.storec).Store(&(state1[n + i]), c3s, valid_items);
}
}
}
#define LANES 2
#define QUAD 3
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH>
__launch_bounds__(256, 3)
__global__ void
kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char* state1,
const float beta1, const float beta2,
const float eps, const int step, const float lr,
float* __restrict__ const quantiles1,
float* absmax1,
float weight_decay,
const float gnorm_scale, const bool skip_zeros, const int n)
{
//const int n_full = n + (n%BLOCK_SIZE);
const int n_full = gridDim.x * BLOCK_SIZE;
const int base_idx = (blockIdx.x * BLOCK_SIZE);
int valid_items = 0;
float g_val = 0.0f;
float s1_vals[N_PER_TH];
// 2-5%
const int lane_id = threadIdx.x % LANES;
float new_local_abs_max1 = -FLT_MAX;
float quadrants1[QUAD];
unsigned char c1s[N_PER_TH];
T g_vals[N_PER_TH];
T p_vals[N_PER_TH];
typedef hipcub::BlockLoad<T, BLOCK_SIZE/N_PER_TH, N_PER_TH, hipcub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
typedef hipcub::BlockLoad<unsigned char, BLOCK_SIZE/N_PER_TH, N_PER_TH, hipcub::BLOCK_LOAD_WARP_TRANSPOSE> LoadChar;
typedef hipcub::BlockStore<unsigned char, BLOCK_SIZE/N_PER_TH, N_PER_TH, hipcub::BLOCK_STORE_WARP_TRANSPOSE> StoreChar;
typedef hipcub::BlockStore<T, BLOCK_SIZE/N_PER_TH, N_PER_TH, hipcub::BLOCK_STORE_WARP_TRANSPOSE> StoreT;
__shared__ float smem_quantiles1[LANES][257];
typedef hipcub::BlockReduce<float, BLOCK_SIZE/N_PER_TH> BlockReduce1;
__shared__ typename BlockReduce1::TempStorage reduce1;
__shared__ float smem_exchange1[1];
__shared__ union {
typename LoadT::TempStorage loadh;
typename LoadChar::TempStorage loadc;
typename StoreChar::TempStorage storec;
typename StoreT::TempStorage storeh;
} temp_storage;
// init: 0.2 -> 0.23
// 0.23 -> 0.23
smem_quantiles1[0][threadIdx.x] = quantiles1[threadIdx.x];
# pragma unroll
for(unsigned int j = 1; j < LANES; j++)
smem_quantiles1[j][threadIdx.x] = smem_quantiles1[0][threadIdx.x];
__syncthreads();
#pragma unroll
for(int k = 0; k < QUAD; k++)
quadrants1[k] = smem_quantiles1[lane_id][(k*256/(QUAD+1)) + (256/(QUAD+1)-1)];
for (unsigned int i = base_idx; i < n_full; i += gridDim.x*BLOCK_SIZE)
{
// loads: 0.23 -> 0.85/1.44
valid_items = n - i >= BLOCK_SIZE ? BLOCK_SIZE : n - i;
__syncthreads();
LoadT(temp_storage.loadh).Load(&(g[i]), g_vals, valid_items, (T)0.0f);
__syncthreads();
LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128);
__syncthreads();
LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f);
new_local_abs_max1 = -FLT_MAX;
// update: 2.48/1.57 -> 2.51/1.60
# pragma unroll N_PER_TH
for(unsigned int j = 0; j < N_PER_TH; j++)
{
g_val = float(g_vals[j]);
g_val *= gnorm_scale;
if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
{
if(weight_decay > 0.0f) {
switch(OPTIMIZER) {
case MOMENTUM:
case ADAGRAD:
case RMSPROP:
g_val += ((float)p_vals[j])*weight_decay;
break;
case LION:
p_vals[j] = ((float)p_vals[j])*(1.0f-lr*weight_decay);
break;
}
}
s1_vals[j] = smem_quantiles1[lane_id][c1s[j]]*absmax1[i/BLOCK_SIZE];
switch(OPTIMIZER)
{
case MOMENTUM:
if(step == 1)
s1_vals[j] = g_val;
else
s1_vals[j] = (s1_vals[j]*beta1) + g_val;
break;
case LION:
// here, using gvals[j] to store the gradient smoothed by beta1 for the following parameter update, before the momentum is updated by beta2
g_vals[j] = lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*g_val));
s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val);
break;
case RMSPROP:
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
break;
case ADAGRAD:
s1_vals[j] = s1_vals[j] + (g_val*g_val);
break;
}
}
new_local_abs_max1 = fmaxf(new_local_abs_max1, fabsf(s1_vals[j]));
}
// reduce: 2.51/1.60 -> 2.67/1.69
new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, hipcub::Max());
if(threadIdx.x == 0)
smem_exchange1[0] = new_local_abs_max1;
__syncthreads();
if(threadIdx.x == 0)
absmax1[i/BLOCK_SIZE] = new_local_abs_max1;
else
new_local_abs_max1 = smem_exchange1[0];
// reduce: 2.67/1.69 -> 2.67/1.70
# pragma unroll N_PER_TH
for(unsigned int j = 0; j < N_PER_TH; j++)
{
if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
{
switch(OPTIMIZER)
{
case MOMENTUM:
p_vals[j] = ((float)p_vals[j]) - lr*(s1_vals[j]);
break;
case LION:
p_vals[j] = ((float)p_vals[j]) - ((float)g_vals[j]);
break;
case RMSPROP:
g_val = g_vals[j];
p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps));
break;
case ADAGRAD:
g_val = g_vals[j];
p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps));
break;
}
}
}
// store: 0.85/1.44 -> 2.48/1.57
__syncthreads();
StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items);
// quantization: 2.67/1.70 -> 3.4/3.3
# pragma unroll N_PER_TH
for(unsigned int j = 0; j < N_PER_TH; j++)
{
c1s[j] = quantize_2D<1>(quadrants1, smem_quantiles1[lane_id], __fdividef(s1_vals[j],new_local_abs_max1));
// make sure the state1 term still has the same sign after quantization
// (not needed for the state2 term, which has only positive values)
if(signbit(smem_quantiles1[lane_id][c1s[j]]) != signbit(s1_vals[j]))
{
if(s1_vals[j] > 0.0f)
c1s[j] += 1;
else
c1s[j] -= 1;
}
}
__syncthreads();
StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items);
}
}
// Inputs:
// A [rows, cols]
// Outputs:
// rowStats [rows]
// out [rows, cols]
template<typename T, int THREADS, int SPARSE_DECOMP>
__launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024)
__global__ void kInt8VectorQuant(T * __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols) {
// For sm50/sm52 and CUDA < 12.2 we need to do the reduction in fp32.
// Otherwise `T` is `fp16`. This can be removed when Maxwell is dropped.
#if (__CUDACC_VER_MAJOR__ >= 12 && __CUDACC_VER_MINOR__ >= 2) || BNB_FP16_AVAILABLE
using TReduction = T;
#else
using TReduction = float;
#endif
using BlockReduceT = hipcub::BlockReduce<TReduction, THREADS>;
// One block per row.
// Threads load column values in a striped arrangement.
// e.g. t0 reads row[0], row[0+nthreads], ..
// and t1 reads row[1], row[1+nthreads], ..
// Each thread will determine its local absmax.
// We then do a blockwise reduction to determine the row's absmax.
__shared__ typename BlockReduceT::TempStorage temp_storage;
__shared__ TReduction smem_row_absmax;
const int row_id = blockIdx.x;
const T* row_data = A + (row_id * cols);
// Threads will read the row values in a striped access pattern and find a local absmax.
TReduction row_local_absmax = -FLT_MIN;
for (int i = threadIdx.x; i < cols; i += THREADS) {
const TReduction absval = fabsf(__ldcs(&(row_data[i])));
// For sparse decomposition, values outside of the threshold are not to be
// included when calculating the row's absmax.
if constexpr (SPARSE_DECOMP) {
row_local_absmax = fmaxf(row_local_absmax, absval < TReduction(threshold) ? absval : row_local_absmax);
} else {
row_local_absmax = fmaxf(row_local_absmax, absval);
}
}
// Reduce thread-local absmax across the block.
const TReduction row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, hipcub::Max(), cols);
if (threadIdx.x == 0) {
// Save our block's absmax to shared memory for the quantization step.
rowStats[row_id] = smem_row_absmax = row_absmax;
}
__syncthreads();
// Quantize row-wise.
const float scale = __fdividef(127.0f, smem_row_absmax);
for (int i = threadIdx.x; i < cols; i += THREADS) {
float val = row_data[i];
if constexpr (SPARSE_DECOMP) {
// For sparse decomposition, we do not want to quantize the outliers.
// Instead they're zeroed out.
out[row_id * cols + i] = fabs(val) < threshold ? __float2int_rn(val * scale) : 0;
} else {
out[row_id * cols + i] = __float2int_rn(val * scale);
}
}
}
template<typename T, int THREADS, int SPARSE_DECOMP>
__launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024)
__global__ void kgetRowStats(T * __restrict__ A, float *rowStats, float threshold, int rows, int cols) {
using BlockReduceT = hipcub::BlockReduce<float, THREADS>;
// One block per row.
// Threads load column values in a striped arrangement.
// e.g. t0 reads row[0], row[0+nthreads], ..
// and t1 reads row[1], row[1+nthreads], ..
// Each thread will determine its local absmax.
// We then do a blockwise reduction to determine the row's absmax.
__shared__ typename BlockReduceT::TempStorage temp_storage;
const int row_id = blockIdx.x;
const T* __restrict__ row_data = A + (row_id * cols);
// Threads will read the row values in a striped access pattern and find a local absmax.
float row_local_absmax = -FLT_MIN;
for (int i = threadIdx.x; i < cols; i += THREADS) {
const float absval = fabsf(row_data[i]);
// For sparse decomposition, values outside of the threshold are not to be
// included when calculating the row's absmax.
if constexpr (SPARSE_DECOMP) {
row_local_absmax = fmaxf(row_local_absmax, absval < threshold ? absval : row_local_absmax);
} else {
row_local_absmax = fmaxf(row_local_absmax, absval);
}
}
// Reduce thread-local absmax across the block.
// TODO: Consider algorithm BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY
const float row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, hipcub::Max(), cols);
if (threadIdx.x == 0) {
// Save our block's absmax to shared memory for the quantization step.
rowStats[row_id] = row_absmax;
}
}
template __global__ void kgetRowStats<half, 1024, 0>(half * __restrict__ A, float *rowStats, float threshold, int rows, int cols);
template __global__ void kgetRowStats<half, 1024, 1>(half * __restrict__ A, float *rowStats, float threshold, int rows, int cols);
template __global__ void kInt8VectorQuant<half, 1024, 0>(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols);
template __global__ void kInt8VectorQuant<half, 1024, 1>(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols);
#define MM_DEQUANT_CONST 6.200012e-05f //1.0f/(127.0f*127.0f)
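// The int8 operands were scaled by 127/rowAbsmax and 127/colAbsmax, so the int32 accumulator is
// multiplied by rowStats * colStats / (127*127) to recover the fp16 value; MM_DEQUANT_CONST is
// that 1/(127*127) factor folded into a single fmaf in the dequantization loop below.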
template <int ITEMS_PER_THREAD, int THREADS>
__global__ void kdequant_mm_int32_fp16(
int* __restrict__ const A,
float *__restrict__ const rowStats,
float *__restrict__ const colStats,
half *out,
half *__restrict__ const bias,
const int numRows,
const int numCols,
const int n
) {
const int n_out = numRows * numCols;
int block_offset = blockIdx.x * THREADS * ITEMS_PER_THREAD;
int thread_offset = threadIdx.x * ITEMS_PER_THREAD;
int local_values[ITEMS_PER_THREAD];
half local_output[ITEMS_PER_THREAD];
float local_rowStats[ITEMS_PER_THREAD];
float local_colStats[ITEMS_PER_THREAD];
float local_biasValue[ITEMS_PER_THREAD];
typedef hipcub::BlockLoad<int, THREADS, ITEMS_PER_THREAD, hipcub::BLOCK_LOAD_DIRECT> LoadInt32;
__shared__ typename LoadInt32::TempStorage loadint32;
int row_idx, col_idx;
#pragma unroll ITEMS_PER_THREAD
for(int j = 0; j < ITEMS_PER_THREAD; j++)
{
row_idx = (block_offset + thread_offset + j) / numCols;
col_idx = (block_offset + thread_offset + j) % numCols;
local_colStats[j] = col_idx >= numCols ? 0.0f : colStats[col_idx];
local_rowStats[j] = row_idx >= numRows ? 0.0f : rowStats[row_idx];
local_biasValue[j] = ((bias == nullptr) || (col_idx >= numCols)) ? 0.0f : __half2float(bias[col_idx]);
}
// Each block loads THREADS * ITEMS_PER_THREAD values from A
int valid_items = block_offset + THREADS * ITEMS_PER_THREAD < n_out
? THREADS * ITEMS_PER_THREAD
: n_out - block_offset;
LoadInt32(loadint32).Load(&(A[block_offset]), local_values, valid_items, 0);
#pragma unroll ITEMS_PER_THREAD
for (int j = 0; j < ITEMS_PER_THREAD; ++j) {
local_output[j] = __float2half(
fmaf(local_values[j] * local_rowStats[j] * local_colStats[j], MM_DEQUANT_CONST, local_biasValue[j])
);
}
#pragma unroll ITEMS_PER_THREAD
for (int j = 0; j < ITEMS_PER_THREAD; j++) {
int outIdx = block_offset + thread_offset + j;
if (outIdx < n_out) {
out[outIdx] = local_output[j];
}
}
}
#define DENORM 1.0f/127.0f
#define MAX_SPARSE_COUNT 32
#define SMEM_SIZE 8*256
#define WARP_SIZE warpSize
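// On ROCm, warpSize is the wavefront size of the target GPU (64 on CDNA parts such as
// gfx90a/gfx942, 32 on RDNA), so WARP_SIZE is resolved per target rather than hard-coded to 32.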
template <typename T, int SPMM_ITEMS, int BITS>
__global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB)
{
// 0. load balancing: we process rows with the most columns first (count_vec) and we process one row per block.
// If a block finishes, the next one is scheduled. Since the last blocks likely have fewer
// elements, they finish faster, "filling up" the gaps left by larger blocks
// without tensor cores
// 1. use rowidx_length to find what to load (as many blocks as there are rows)
// 2. Load A into registers
// 3. each warp loads all required rows of B but each warp is offset by k
// 4. Do mma operations that accumulate into registers
// 5. Each warp stores its output row into matrix C
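// Each block therefore handles one sparse row of A, chosen in order of decreasing nonzero count
// for load balancing, and accumulates its partial results into the dense output in place.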
const int count = max_count[blockIdx.x];
const int local_max_idx = max_idx[blockIdx.x];
const int offset = local_max_idx == 0 ? 0 : offset_rowidx[local_max_idx-1];
const int local_row_idx = rowidx[offset];
const int warp_id = threadIdx.x / WARP_SIZE;
const int warp_idx = threadIdx.x % WARP_SIZE;
const int warp_offset = (warp_id*WARP_SIZE)*SPMM_ITEMS;
const int num_items = BITS == 8 ? 8 : 8;
int idx_col_B = warp_offset;
int local_idx_col_B_offset = 0;
half local_valA[MAX_SPARSE_COUNT];
int local_colidxA[MAX_SPARSE_COUNT];
half local_valC[SPMM_ITEMS];
T local_valsB[num_items];
half local_valOut[num_items];
// 128 byte loads per warp == 4 bytes per thread
// 2. Load A into registers
for(int j = 0; j < MAX_SPARSE_COUNT; j++)
{
local_valA[j] = j < count ? values[offset+j] : __float2half(0.0f);
local_colidxA[j] = j < count ? colidx[offset+j] : 0;
}
// each thread processes SPMM_ITEMS=32 per iteration. We have 256 threads. 32*256=8192
// we expect each warp to be SPMM_ITEMS*WARP_SIZE apart
// we have a total of 128 bytes for the bank with a bank size of 4 bytes
// added 3 bytes = 6 values between warps should reduce bank conflicts
__shared__ half smem_dequant_stats[SMEM_SIZE];
while(idx_col_B < colsB)
{
if(dequant_stats != NULL)
{
for(int i = threadIdx.x; i < SMEM_SIZE; i+=blockDim.x)
if((idx_col_B+i-local_idx_col_B_offset) < colsB)
smem_dequant_stats[i] = dequant_stats[idx_col_B+i-local_idx_col_B_offset];
__syncthreads();
}
#pragma unroll SPMM_ITEMS
for(int j = 0; j < SPMM_ITEMS; j++)
local_valC[j] = 0.0f;
#pragma unroll
for(int i = 0; i < count; i++)
{
// 3. each warp loads all required rows of B but each warp is offset by k
int row_offset = colsB*local_colidxA[i];
#pragma unroll SPMM_ITEMS
for(int j = 0; j < SPMM_ITEMS; j+=num_items)
{
// 4. Multiply the tile -> accumulate outputs in shared memory until 128 bytes is reached
int idx = idx_col_B + (warp_idx*SPMM_ITEMS) + j;
if(idx >= colsB){ break; }
if((idx+num_items < colsB))
{
if(BITS == 8)
reinterpret_cast<float2(&)[num_items]>(local_valsB)[0] = reinterpret_cast<float2*>(B)[(row_offset+ idx)/num_items];
else
reinterpret_cast<float4(&)[num_items]>(local_valsB)[0] = reinterpret_cast<float4*>(B)[(row_offset+ idx)/num_items];
}
else
{
#pragma unroll num_items
for(int k = 0; k < num_items; k++)
if(idx+k < colsB)
local_valsB[k] = B[row_offset+idx+k];
else
local_valsB[k] = 0.0f;
}
#pragma unroll num_items
for(int k = 0; k < num_items; k++)
{
if(BITS == 8 && dequant_stats != NULL)
// we do texture cache reads (__ldg) on dequant_stats which should be super fast
{
float valB = local_valsB[k];
float valA = local_valA[i];
if(valB != 0.0 && valA != 0.0)
local_valC[j+k] = (float)local_valC[j+k] + ((float)smem_dequant_stats[idx+k-local_idx_col_B_offset])*DENORM*valB*valA;
}
else
local_valC[j+k] = (float)local_valC[j+k] + (float)local_valsB[k]*(float)local_valA[i];
}
}
}
int idx_row_C = (colsB*local_row_idx);
#pragma unroll SPMM_ITEMS
for(int j = 0; j < SPMM_ITEMS; j+=num_items)
{
//int idx_col_C = idx_col_B + (32*j) + warp_idx;
int idx_col_C = idx_col_B + warp_idx*SPMM_ITEMS + j;
int idx_val = idx_col_C + idx_row_C;
if(idx_col_C +num_items < colsB)
{
// load outputs to do inplace addition
reinterpret_cast<float4(&)[num_items/4]>(local_valOut)[0] = reinterpret_cast<float4*>(out)[idx_val/num_items];
#pragma unroll num_items
for(int k = 0; k < num_items; k++)
local_valC[(j/num_items) + k] = (float)local_valC[(j/num_items) + k] + (float)local_valOut[k];
reinterpret_cast<float4*>(out)[idx_val/num_items] = reinterpret_cast<float4(&)[num_items]>(local_valC)[j/num_items];
}
else
{
#pragma unroll num_items
for(int k = 0; k < num_items; k++)
if(idx_col_C + k < colsB)
out[idx_val+k] = (float)out[idx_val+k]+(float)local_valC[j+k];
}
}
idx_col_B += blockDim.x*SPMM_ITEMS;
local_idx_col_B_offset += blockDim.x*SPMM_ITEMS;
}
}
#define WARPS 3
template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M, int N, int K, T * __restrict__ const A, T* B, T * out, int lda, int ldb, int ldc)
{
#if __CUDA_ARCH__ >= 750
using namespace nvcuda;
int col_offset = blockIdx.x *32;
const int warp_id = threadIdx.x / 32;
const int half_warp_id = threadIdx.x / 16;
const int half_warp_lane = threadIdx.x % 16;
const int batch_size_warps = (WARPS-1)*2;
const int val_per_iter = blockDim.x-32;
T local_A[4];
T local_B[128];
const int a_tile_offset = 16;
const int b_tile_offset = (16*32 + 16);
__shared__ T smem_A[8*16 + (2*16*(batch_size_warps-1))];
__shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))];
//__shared__ T smem_C[8*32];
rocwmma::fragment<rocwmma::matrix_a, 8, 32, 16, half, rocwmma::row_major> a_frag;
rocwmma::fragment<rocwmma::matrix_b, 8, 32, 16, half, rocwmma::col_major> b_frag;
rocwmma::fragment<rocwmma::accumulator, 8, 32, 16, half> c_frag;
rocwmma::fill_fragment(c_frag, 0.0f);
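// Pipeline sketch: the first WARPS-1 warps stream A and B tiles into double-buffered shared
// memory (ticktock selects the buffer), while the last warp consumes the previously filled
// buffer with 8x32x16 rocwmma fragments, accumulating into c_frag; the result is staged through
// smem_A and written to out at the end of the kernel.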
int ticktock = 0;
int idx = 0 + threadIdx.x;
int loaded_values = 0;
// prefetch
if(idx < K && warp_id < (WARPS-1))
{
if(loaded_values == 0)
{
local_A[0] = A[idx];
local_A[1] = A[idx+(1*val_per_iter)];
local_A[2] = A[idx+(2*val_per_iter)];
local_A[3] = A[idx+(3*val_per_iter)];
#pragma unroll 32
for(int col = 0; col < 32; col++)
{
local_B[col] = B[(col_offset+col)*ldb+idx];
local_B[col+32] = B[(col_offset+col)*ldb+idx+(1*val_per_iter)];
local_B[col+64] = B[(col_offset+col)*ldb+idx+(2*val_per_iter)];
local_B[col+96] = B[(col_offset+col)*ldb+idx+(3*val_per_iter)];
}
loaded_values = 3;
}
else
{
if(loaded_values == 3)
{
local_A[0] = local_A[1];
#pragma unroll 32
for(int col = 0; col < 32; col++)
local_B[col] = local_B[col+(32)];
}
else if(loaded_values == 2)
{
local_A[0] = local_A[2];
#pragma unroll 32
for(int col = 0; col < 32; col++)
local_B[col] = local_B[col+(64)];
}
else
{
local_A[0] = local_A[3];
#pragma unroll 32
for(int col = 0; col < 32; col++)
local_B[col] = local_B[col+(96)];
}
loaded_values--;
}
smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0];
#pragma unroll 32
for(int col = 0; col < 32; col++)
smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col];
}
else if(warp_id < (WARPS-1))
{
local_A[0] = T(0.0);
smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f;
#pragma unroll 32
for(int col = 0; col < 32; col++)
local_B[col] = 0.0f;
#pragma unroll 32
for(int col = 0; col < 32; col++)
smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f;
}
ticktock = ticktock == 0 ? 1 : 0;
//for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32)
for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32)
{
idx = base_idx + threadIdx.x;
__syncthreads();
if(idx < K && warp_id < (WARPS-1))
{
//local_A[0] = A[idx];
//#pragma unroll 32
//for(int col = 0; col < 32; col++)
// local_B[col] = B[(col_offset+col)*ldb+idx];
if(loaded_values == 0)
{
local_A[0] = A[idx];
local_A[1] = A[idx+(1*val_per_iter)];
local_A[2] = A[idx+(2*val_per_iter)];
local_A[3] = A[idx+(3*val_per_iter)];
#pragma unroll 32
for(int col = 0; col < 32; col++)
{
local_B[col] = B[(col_offset+col)*ldb+idx];
local_B[col+32] = B[(col_offset+col)*ldb+idx+(1*val_per_iter)];
local_B[col+64] = B[(col_offset+col)*ldb+idx+(2*val_per_iter)];
local_B[col+96] = B[(col_offset+col)*ldb+idx+(3*val_per_iter)];
}
loaded_values = 3;
}
else
{
if(loaded_values == 3)
{
local_A[0] = local_A[1];
#pragma unroll 32
for(int col = 0; col < 32; col++)
local_B[col] = local_B[col+(32)];
}
else if(loaded_values == 2)
{
local_A[0] = local_A[2];
#pragma unroll 32
for(int col = 0; col < 32; col++)
local_B[col] = local_B[col+(64)];
}
else
{
local_A[0] = local_A[3];
#pragma unroll 32
for(int col = 0; col < 32; col++)
local_B[col] = local_B[col+(96)];
}
loaded_values--;
}
smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0];
#pragma unroll 32
for(int col = 0; col < 32; col++)
smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col];
}
else if(warp_id < (WARPS-1))
{
local_A[0] = T(0.0);
smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f;
#pragma unroll 32
for(int col = 0; col < 32; col++)
local_B[col] = 0.0f;
#pragma unroll 32
for(int col = 0; col < 32; col++)
smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f;
}
ticktock = ticktock == 0 ? 1 : 0;
if(warp_id == (WARPS-1))
for(int k = 0; k < batch_size_warps; k++)
{
rocwmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu
rocwmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu
rocwmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
}
}
__syncthreads();
if(warp_id != (WARPS-1)){ return; }
// only warp_id == (WARPS-1) from here
int warp_lane = threadIdx.x % 32;
ticktock = ticktock == 0 ? 1 : 0;
for(int k = 0; k < batch_size_warps; k++)
{
rocwmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu
rocwmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu
rocwmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
}
// 129 mu
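// the accumulator tile is staged through smem_A (smem_C is commented out above) before the final write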
if(warp_id == (WARPS-1))
rocwmma::store_matrix_sync(&(smem_A[0]), c_frag, 32, rocwmma::mem_row_major);
if(col_offset + warp_lane < M)
out[col_offset + warp_lane] = smem_A[warp_lane];
#endif
}
template <typename T> __device__ void printnonzero(T *A, int num_values, const char * strval)
{
for(int i = 0; i < num_values; i++)
if((float)A[i] != 0.0)
printf("%s %i %f\n", strval, i, (float)A[i]);
}
template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
{
//// element-wise kernel
//// 1. Load batch x k into registers
//// 2. Load k x k into registers
//// 3. dequantize and store in second pair of k x k registers
//// 4. matmul
//// 5. sum with cub
//// 6. store outputs
//// TC kernel
//// use k warps per thread block
//// 1. the threadblock uses the read-only cache to read a register tile of A into shared memory
//// 2. each warp loops over shared-memory tiles of A of size 8x16 and loads them into fragments
//// 3. each warp reads a 16x32 segment of values from B
//// 4. dequantize the registers of B into a second pair of registers
//// 5. store (4) into a fragment
//// 6. matmul and aggregate into fragment C
//// 7. aggregate the tiles of C into shared memory block C
//// 8. sum (7)
//// 9. write outputs to the matmul output matrix
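//// Fragment geometry used below (rocwmma M, N, K = 8, 32, 16): A fragments are 8x16
//// row-major, B fragments are 16x32 col-major, and the accumulator C is 8x32.
//// smem_A stages 16-wide half-warp rows of A (a_tile_offset = 16) and smem_B stages
//// 16x32 tiles of B (b_tile_offset = 16*32 + 16; the extra 16 elements look like
//// padding, presumably to avoid bank conflicts). Both buffers are double-buffered
//// via the ticktock index.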
#if __CUDA_ARCH__ >= 750
using namespace nvcuda;
int col_offset = blockIdx.x *32;
const int warp_id = threadIdx.x / 32;
const int warp_idx = threadIdx.x % 32;
const int half_warp_id = threadIdx.x / 16;
const int half_warp_lane = threadIdx.x % 16;
const int batch_size_warps = (WARPS-1)*2;
T quant_map[16];
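// per-thread register copy of the 16-entry NF4 codebook (nf4_data) used to dequantize B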
#pragma unroll 16
for(int i = 0; i < 16; i++)
quant_map[i] = nf4_data[i];
//__shared__ T quant_map[16*160];
T local_A[2];
T local_B[64];
unsigned char local_B_4bit[32];
const int a_tile_offset = 16;
const int b_tile_offset = (16*32 + 16);
__shared__ T smem_A[8*16 + (16*(batch_size_warps-1))];
__shared__ T smem_B[2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1))];
__shared__ T smem_C[8*32];
rocwmma::fragment<rocwmma::matrix_a, 8, 32, 16, half, rocwmma::row_major> a_frag;
rocwmma::fragment<rocwmma::matrix_b, 8, 32, 16, half, rocwmma::col_major> b_frag;
rocwmma::fragment<rocwmma::accumulator, 8, 32, 16, half> c_frag;
rocwmma::fill_fragment(c_frag, 0.0f);
for(int i = threadIdx.x; i < (8*32); i+=blockDim.x)
smem_C[i] = 0.0f;
__syncthreads();
int ticktock = 0;
int idx = 0 + threadIdx.x;
int loaded_values = 0;
// prefetch
if(idx < K && warp_id < (WARPS-1))
{
if(loaded_values == 0)
{
local_A[0] = A[idx];
local_A[1] = A[idx+blockDim.x-32];
#pragma unroll 32
for(int col = 0; col < 32; col++)
local_B_4bit[col] = B[(col_offset+col)*ldb+idx];
loaded_values = 1;
}
else
{
local_A[0] = local_A[1];
loaded_values--;
#pragma unroll 64
for(int col = 0; col < 64; col+=2)
{
//local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(1.0f);
//local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(1.0f);
//local_B[col] = d2DequantizeFP4(local_B_4bit[col/2] >> 4)*(float)(17.0);
//local_B[col+1] = d2DequantizeFP4(local_B_4bit[col/2] & 0x0F)*(float)(17.0);
//local_B[col] = 127*(local_B_4bit[col/2] >> 4)*(float)(17.0);
//local_B[col+1] = 127*(local_B_4bit[col/2] & 0x0F)*(float)(17.0);
//local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(17.0);
//local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(17.0);
local_B[col] = quant_map[160*(local_B_4bit[col/2] >> 4)+warp_idx]*T(17.0);
local_B[col+1] = quant_map[160*(local_B_4bit[col/2] & 0x0F)+warp_idx]*T(17.0);
}
}
smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0];
#pragma unroll 32
for(int col = 0; col < 32; col++)
smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col];
}
else if(warp_id < (WARPS-1))
{
local_A[0] = T(0.0);
smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f;
#pragma unroll 32
for(int col = 0; col < 32; col++)
local_B[col] = 0.0f;
#pragma unroll 32
for(int col = 0; col < 32; col++)
smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f;
}
ticktock = ticktock == 0 ? 1 : 0;
//if(threadIdx.x == 0)
//printf("aa %i %i\n", idx, loaded_values);
//for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32)
for(int base_idx = blockDim.x-32; base_idx < K; base_idx+=blockDim.x-32)
{
idx = base_idx + threadIdx.x;
//if(threadIdx.x == 0)
//printf("%i %i\n", idx, loaded_values);
//__syncthreads();
if(idx < K && warp_id < (WARPS-1))
{
if(loaded_values == 0)
{
local_A[0] = A[idx];
local_A[1] = A[idx+blockDim.x-32];
#pragma unroll 32
for(int col = 0; col < 32; col++)
{
local_B_4bit[col] = B[(col_offset+col)*ldb+idx];
local_B_4bit[col+16] = B[(col_offset+col)*ldb+idx];
}
loaded_values = 1;
}
else
{
local_A[0] = local_A[1];
loaded_values--;
int absidx = (idx + col_offset)/blocksize;
half local_absmax = __ldg(&(absmax[absidx]));
#pragma unroll 64
for(int col = 0; col < 64; col+=2)
{
//local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(absidx);
//local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(absidx);
//local_B[col] = T(127)*T(local_B_4bit[col/2] >> 4)*T(absidx);
//local_B[col+1] = T(127)*T(local_B_4bit[col/2] & 0x0F)*T(absidx);
//local_B[col] = quant_map[160*(local_B_4bit[col/2] >> 4)+warp_idx]*T(local_absmax);
//local_B[col+1] = quant_map[160*(local_B_4bit[col/2] & 0x0F)+warp_idx]*T(local_absmax);
local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(absidx);
local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(absidx);
}
//printnonzero<T>(local_B, 128, "");
}
smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = local_A[0];
#pragma unroll 32
for(int col = 0; col < 32; col++)
smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = local_B[col];
}
else if(warp_id < (WARPS-1))
{
local_A[0] = T(0.0);
smem_A[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*a_tile_offset)] = 0.0f;
#pragma unroll 32
for(int col = 0; col < 32; col++)
local_B[col] = 0.0f;
#pragma unroll 32
for(int col = 0; col < 32; col++)
smem_B[half_warp_lane + (((batch_size_warps*ticktock)+half_warp_id)*b_tile_offset) + (col*16)] = 0.0f;
}
ticktock = ticktock == 0 ? 1 : 0;
if(warp_id == (WARPS-1))
for(int k = 0; k < batch_size_warps; k++)
{
rocwmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu
rocwmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu
rocwmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
}
}
__syncthreads();
//if(threadIdx.x == 0)
//{
// printnonzero<T>(smem_A, 8*16 + (2*16*(batch_size_warps-1)), "A: ");
// printnonzero<T>(smem_B, 2*batch_size_warps*16*32 + (2*16*(batch_size_warps-1)), "B: ");
//}
if(warp_id != (WARPS-1)){ return; }
// only warp_id == (WARPS-1) from here
int warp_lane = threadIdx.x % 32;
ticktock = ticktock == 0 ? 1 : 0;
for(int k = 0; k < batch_size_warps; k++)
{
//if(warp_lane == 0)
//printf("%i %i %i %i\n", (ticktock*batch_size_warps + k)*a_tile_offset, k, ticktock, threadIdx.x);
rocwmma::load_matrix_sync(a_frag, &(smem_A[(ticktock*batch_size_warps + k)*a_tile_offset]), 16); // 111 mu
rocwmma::load_matrix_sync(b_frag, &(smem_B[(ticktock*batch_size_warps + k)*b_tile_offset]), 16); // 35 mu
rocwmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
}
// 129 mu
if(warp_id == (WARPS-1))
rocwmma::store_matrix_sync(&(smem_C[0]), c_frag, 32, rocwmma::mem_row_major);
//printnonzero<T>(smem_C, 32, "");
if(col_offset + warp_lane < M)
out[col_offset + warp_lane] = smem_C[warp_lane];
#endif
}
// Number of 4-bit values processed by each thread
#define num_values_4bit 32
template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inference_naive(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, T * out, int lda, int ldb, int ldc, int blocksize)
{
// per threadblock:
// load step-by-step in chunks of [warp_size,warps]: 1xwarp_size * [warp_size,warps] -> [1,warps]
// 4 warps -> 4 loads per iter
// 1 x warp_size * [warp_size x 4] -> 1x4 outputs per thread block
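// Example (illustrative): with THREADS = 128 (as instantiated below) and warpSize = 64
// (CDNA; 32 on RDNA), each block holds 2 warps, each warp computes one output row row_B,
// and each lane covers num_values_4bit = 32 elements of K per step before the final
// WarpReduce sums the per-lane partial dot products.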
typedef hipcub::WarpReduce<float, warpSize> WarpReduce;
__shared__ typename WarpReduce::TempStorage temp_storage[THREADS/warpSize];
const int warp_idx = threadIdx.x / warpSize;
const int warp_lane = threadIdx.x % warpSize;
const int row_B = (THREADS/warpSize)*blockIdx.x + warp_idx;
const int offset_B = ldb*row_B;
const int num_values_8bit = num_values_4bit/2;
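// the 32 packed 4-bit values per thread occupy num_values_8bit = 16 bytes of B,
// i.e. exactly one int4 vector load in the fast path below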
float local_C = 0.0f;
unsigned char local_B_4bit[num_values_8bit];
T local_B[num_values_4bit/4];
T local_A[num_values_4bit/4];
__shared__ T quant_map[16];
T local_absmax = T(0.0f);
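// the first 16 threads stage the 16-entry quantization codebook (datatype) into shared memory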
if (threadIdx.x < 16)
quant_map[threadIdx.x] = T(__ldg(&datatype[threadIdx.x]));
//for(int i = threadIdx.x; i < 16; i++)
//quant_map[i] = T(datatype[i]);
__syncthreads();
// A: [1, K]
// B: [M, K]
for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += warpSize*num_values_4bit)
{
const int inner_idx_halved = inner_idx/2;
// Since blocksize will always be a power of 2, we avoid the more expensive
// division by the blocksize and instead use a shift operation.
// This is equivalent to ((2*offset_B)+inner_idx)/blocksize.
const int absidx = ((2*offset_B)+inner_idx) >> (31 - __clz(blocksize));
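// e.g. blocksize = 64: 31 - __clz(64) = 31 - 25 = 6, so x >> 6 == x / 64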
local_absmax = __ldg(&(absmax[absidx]));
if(row_B < M)
{
if((inner_idx_halved + num_values_8bit) < (K/2))
{
// this vectorized 16-byte (int4) load is the most important path for performance
reinterpret_cast<int4(&)[num_values_8bit]>(local_B_4bit)[0] = reinterpret_cast<int4*>(B)[(offset_B+(inner_idx_halved))/(num_values_8bit)];
}
else
{
#pragma unroll
for(int j = 0; j < (num_values_8bit); j++)
if((inner_idx_halved) + j < (K/2))
local_B_4bit[j] = B[offset_B+inner_idx_halved + j];
else
local_B_4bit[j] = 0b01110111;
}
}
else
{
#pragma unroll
for(int j = 0; j < (num_values_8bit); j++)
local_B_4bit[j] = 0b01110111;
}
for(int i = 0; i < 4; i++)
{
#pragma unroll
for(int k = 0; k < num_values_8bit/4; k++)
{
#if BNB_BF16_AVAILABLE
local_B[k*2] = quant_map[local_B_4bit[(i*num_values_8bit/4) + k] >> 4]*local_absmax;
local_B[k*2 + 1] = quant_map[local_B_4bit[(i*num_values_8bit/4) + k] & 0x0F]*local_absmax;
#else
// bf16 multiplication not supported; compute in float instead
local_B[k*2] = T((float)quant_map[local_B_4bit[(i*num_values_8bit/4) + k] >> 4]*(float)local_absmax);
local_B[k*2 + 1] = T((float)quant_map[local_B_4bit[(i*num_values_8bit/4) + k] & 0x0F]*(float)local_absmax);
#endif
}
if(inner_idx+(num_values_4bit/4) + (i*num_values_4bit/4) < K)
{
// this is also relatively important for performance
if(BITS==16)
{
reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[0] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/4) + i];
}
else
{
reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[0] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/8) + (2*i) + 0];
reinterpret_cast<int4(&)[num_values_4bit]>(local_A)[1] = reinterpret_cast<int4*>(A)[inner_idx/(num_values_4bit/8) + (2*i) + 1];
}
}
else
#pragma unroll
for(int k = 0; k < num_values_4bit/4; k++)
if(inner_idx + (i*num_values_4bit/4) + k < K)
local_A[k] = A[inner_idx + k + (i*num_values_4bit/4)];
else
local_A[k] = T(0.0f);
// accumulate in float; small performance hit for Ampere, but lower error for outputs
#pragma unroll
for(int k = 0; k < num_values_4bit/4; k++)
{
#if BNB_BF16_AVAILABLE
local_C += (float)(local_A[k]*local_B[k]);
#else
// bf16 multiplication not supported; compute in float instead
local_C += ((float)local_A[k]*(float)local_B[k]);
#endif
}
}
}
local_C = WarpReduce(temp_storage[warp_idx]).Sum(local_C);
if(row_B < M && warp_lane == 0)
out[row_B] = T(local_C);
}
template <typename T, int FUNC> __global__ void kfunc(T *A, T *B, T value, long n)
{
for(long i = (blockDim.x*blockIdx.x) + threadIdx.x; i < n; i+=(blockDim.x*gridDim.x))
{
switch(FUNC)
{
case FILL:
A[i] = (T)value;
break;
case ARANGE:
A[i] = (T)i;
break;
case _MUL:
A[i] = A[i]*B[i];
break;
}
}
}
//==============================================================
// TEMPLATE DEFINITIONS
//==============================================================
template __global__ void kfunc<float, FILL>(float *A, float *B, float value, long n);
template __global__ void kfunc<unsigned char, FILL>(unsigned char *A, unsigned char *B, unsigned char value, long n);
template __global__ void kfunc<float, ARANGE>(float *A, float *B, float value, long n);
template __global__ void kfunc<float, _MUL>(float *A, float *B, float value, long n);
// these instantiations are unused, but the compiler needs them so the symbols are emitted
//template __global__ void gemm_device<float, 16, 128>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 32, 256>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 32, 192>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 32, 160>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 32, 128>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
//template __global__ void gemm_device<float, 16, 32>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 32, 32>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 32, 64>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 32, 96>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
// these instantiations are unused, but the compiler needs them so the symbols are emitted
//template __global__ void gemm_device<float, 32, 128>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 16, 256>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 16, 192>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 16, 160>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 16, 128>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
//template __global__ void gemm_device<float, 32, 32>(int M, int N, int K, float * __restrict__ const A, float* B, float * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 16, 32>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 16, 64>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
template __global__ void gemm_device<half, 16, 96>(int M, int N, int K, half * __restrict__ const A, half* B, half * out, int lda, int ldb, int ldc);
template __global__ void kgemm_4bit_inference<half, 96>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
template __global__ void kgemm_4bit_inference<half, 128>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
template __global__ void kgemm_4bit_inference<half, 160>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
template __global__ void kgemm_4bit_inference<half, 256>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
template __global__ void kgemm_4bit_inference_naive<half, 128, 16>(int M, int N, int K, half * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, half * out, int lda, int ldb, int ldc, int blocksize);
template __global__ void kgemm_4bit_inference_naive<hip_bfloat16, 128, 16>(int M, int N, int K, hip_bfloat16 * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, hip_bfloat16 * out, int lda, int ldb, int ldc, int blocksize);
template __global__ void kgemm_4bit_inference_naive<float, 128, 32>(int M, int N, int K, float * __restrict__ const A, unsigned char *B, float *absmax, const float *datatype, float * out, int lda, int ldb, int ldc, int blocksize);
template __global__ void kspmm_coo_very_sparse_naive<half, 8, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<half, 16, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<half, 32, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<signed char, 8, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<signed char, 16, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kspmm_coo_very_sparse_naive<signed char, 32, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB);
template __global__ void kdequant_mm_int32_fp16<4, 512>(int *__restrict__ const A, float *__restrict__ const rowStats, float *__restrict__ const colStats, half *out, half * __restrict__ const bias, const int numRows, const int numCols, const int n);
template __device__ unsigned char dQuantize<0>(float* smem_code, const float rand, float x);
template __device__ unsigned char dQuantize<1>(float* smem_code, const float rand, float x);
#define MAKE_PreconditionOptimizer32bit1State(oname, gtype) \
template __global__ void kPreconditionOptimizer32bit1State<gtype, oname, 4096, 8>(gtype* g, gtype* p, \
float* state1, float *unorm, \
const float beta1, const float beta2, const float eps, const float weight_decay, \
const int step, const float lr, const float gnorm_scale, const int n); \
MAKE_PreconditionOptimizer32bit1State(MOMENTUM, half)
MAKE_PreconditionOptimizer32bit1State(MOMENTUM, float)
MAKE_PreconditionOptimizer32bit1State(MOMENTUM, hip_bfloat16)
MAKE_PreconditionOptimizer32bit1State(RMSPROP, half)
MAKE_PreconditionOptimizer32bit1State(RMSPROP, float)
MAKE_PreconditionOptimizer32bit1State(RMSPROP, hip_bfloat16)
MAKE_PreconditionOptimizer32bit1State(LION, half)
MAKE_PreconditionOptimizer32bit1State(LION, float)
MAKE_PreconditionOptimizer32bit1State(LION, hip_bfloat16)
MAKE_PreconditionOptimizer32bit1State(ADAGRAD, half)
MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float)
MAKE_PreconditionOptimizer32bit1State(ADAGRAD, hip_bfloat16)
#define MAKE_Optimizer32bit1State(oname, gtype) \
template __global__ void kOptimizer32bit1State<gtype, oname>(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \
const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); \
MAKE_Optimizer32bit1State(MOMENTUM, half)
MAKE_Optimizer32bit1State(MOMENTUM, float)
MAKE_Optimizer32bit1State(MOMENTUM, hip_bfloat16)
MAKE_Optimizer32bit1State(RMSPROP, half)
MAKE_Optimizer32bit1State(RMSPROP, float)
MAKE_Optimizer32bit1State(RMSPROP, hip_bfloat16)
MAKE_Optimizer32bit1State(LION, half)
MAKE_Optimizer32bit1State(LION, float)
MAKE_Optimizer32bit1State(LION, hip_bfloat16)
MAKE_Optimizer32bit1State(ADAGRAD, half)
MAKE_Optimizer32bit1State(ADAGRAD, float)
MAKE_Optimizer32bit1State(ADAGRAD, hip_bfloat16)
#define MAKE_PreconditionOptimizer32bit2State(oname, gtype) \
template __global__ void kPreconditionOptimizer32bit2State<gtype, oname, 4096, 8>(gtype* g, gtype* p, \
float* state1, float* state2, float *unorm, \
const float beta1, const float beta2, const float eps, const float weight_decay, \
const int step, const float lr, const float gnorm_scale, const int n); \
MAKE_PreconditionOptimizer32bit2State(ADAM, float)
MAKE_PreconditionOptimizer32bit2State(ADAM, half)
MAKE_PreconditionOptimizer32bit2State(ADAM, hip_bfloat16)
MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, float)
MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, half)
MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, hip_bfloat16)
template __global__ void kOptimizer32bit2State<float, ADAM>(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
template __global__ void kOptimizer32bit2State<half, ADAM>(half* g, half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
template __global__ void kOptimizer32bit2State<hip_bfloat16, ADAM>(hip_bfloat16* g, hip_bfloat16* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
template __global__ void kOptimizer32bit2State<float, ADEMAMIX>(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
template __global__ void kOptimizer32bit2State<half, ADEMAMIX>(half* g, half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
template __global__ void kOptimizer32bit2State<hip_bfloat16, ADEMAMIX>(hip_bfloat16* g, hip_bfloat16* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
#define MAKE_PreconditionStatic8bit1State(oname, gtype) \
template __global__ void kPreconditionOptimizerStatic8bit1State<gtype, oname>(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, \
float *unorm, \
const float beta1, \
const float beta2, \
const float eps, const int step, \
float* __restrict__ const quantiles1, \
float* max1, float* new_max1, \
const float weight_decay, \
const float gnorm_scale, \
const int n); \
MAKE_PreconditionStatic8bit1State(MOMENTUM, half)
MAKE_PreconditionStatic8bit1State(MOMENTUM, float)
MAKE_PreconditionStatic8bit1State(RMSPROP, half)
MAKE_PreconditionStatic8bit1State(RMSPROP, float)
MAKE_PreconditionStatic8bit1State(LION, half)
MAKE_PreconditionStatic8bit1State(LION, float)
MAKE_PreconditionStatic8bit1State(ADAGRAD, half)
MAKE_PreconditionStatic8bit1State(ADAGRAD, float)
#define MAKE_optimizerStatic8bit1State(oname, gtype) \
template __global__ void kOptimizerStatic8bit1State<gtype, oname>(gtype* p, gtype* const g, unsigned char* state1, \
const float *unorm, const float max_unorm, const float param_norm, \
const float beta1, \
const float beta2, \
const float eps, const int step, const float lr, \
float* __restrict__ const quantiles1, \
float* max1, float* new_max1, \
float weight_decay, \
const float gnorm_scale, \
const int n); \
MAKE_optimizerStatic8bit1State(MOMENTUM, half)
MAKE_optimizerStatic8bit1State(MOMENTUM, float)
MAKE_optimizerStatic8bit1State(RMSPROP, half)
MAKE_optimizerStatic8bit1State(RMSPROP, float)
MAKE_optimizerStatic8bit1State(LION, half)
MAKE_optimizerStatic8bit1State(LION, float)
MAKE_optimizerStatic8bit1State(ADAGRAD, half)
MAKE_optimizerStatic8bit1State(ADAGRAD, float)
#define MAKE_PreconditionStatic8bit2State(oname, gtype) \
template __global__ void kPreconditionOptimizerStatic8bit2State<gtype, oname>(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, unsigned char* __restrict__ const state2, \
float *unorm, \
const float beta1, const float beta2, \
const float eps, const int step, \
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \
float* max1, float* max2, float* new_max1, float* new_max2, \
const float gnorm_scale, \
const int n); \
MAKE_PreconditionStatic8bit2State(ADAM, half)
MAKE_PreconditionStatic8bit2State(ADAM, float)
#define MAKE_optimizerStatic8bit2State(oname, gtype) \
template __global__ void kOptimizerStatic8bit2State<gtype, oname>(gtype* p, gtype* const g, unsigned char* state1, unsigned char* state2, \
const float *unorm, const float max_unorm, const float param_norm, \
const float beta1, const float beta2, \
const float eps, const int step, const float lr, \
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \
float* max1, float* max2, float* new_max1, float* new_max2, \
float weight_decay, \
const float gnorm_scale, \
const int n); \
MAKE_optimizerStatic8bit2State(ADAM, half)
MAKE_optimizerStatic8bit2State(ADAM, float)
template __global__ void kPercentileClipping<float, 2048, 4>(float * __restrict__ g, float *gnorm_vec, int step, const int n);
template __global__ void kPercentileClipping<half, 2048, 4>(half * __restrict__ g, float *gnorm_vec, int step, const int n);
#define MAKE_kQuantizeBlockwise(dtype, blocksize, num_per_thread, stochastic, data_type_name) \
template __global__ void kQuantizeBlockwise<dtype, blocksize, num_per_thread, stochastic, data_type_name>(float * code, dtype * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); \
MAKE_kQuantizeBlockwise(half, 4096, 4, 0, General8bit)
MAKE_kQuantizeBlockwise(half, 4096, 4, 1, General8bit)
MAKE_kQuantizeBlockwise(half, 2048, 4, 0, General8bit)
MAKE_kQuantizeBlockwise(half, 1024, 4, 0, General8bit)
MAKE_kQuantizeBlockwise(half, 512, 2, 0, General8bit)
MAKE_kQuantizeBlockwise(half, 256, 2, 0, General8bit)
MAKE_kQuantizeBlockwise(half, 128, 2, 0, General8bit)
//MAKE_kQuantizeBlockwise(half, 64, 2, 0, General8bit)
MAKE_kQuantizeBlockwise(half, 4096, 4, 0, FP4)
MAKE_kQuantizeBlockwise(half, 2048, 4, 0, FP4)
MAKE_kQuantizeBlockwise(half, 1024, 4, 0, FP4)
MAKE_kQuantizeBlockwise(half, 512, 2, 0, FP4)
MAKE_kQuantizeBlockwise(half, 256, 2, 0, FP4)
MAKE_kQuantizeBlockwise(half, 128, 2, 0, FP4)
//MAKE_kQuantizeBlockwise(half, 64, 2, 0, FP4)
MAKE_kQuantizeBlockwise(half, 4096, 4, 0, NF4)
MAKE_kQuantizeBlockwise(half, 2048, 4, 0, NF4)
MAKE_kQuantizeBlockwise(half, 1024, 4, 0, NF4)
MAKE_kQuantizeBlockwise(half, 512, 2, 0, NF4)
MAKE_kQuantizeBlockwise(half, 256, 2, 0, NF4)
MAKE_kQuantizeBlockwise(half, 128, 2, 0, NF4)
//MAKE_kQuantizeBlockwise(half, 64, 2, 0, NF4)
MAKE_kQuantizeBlockwise(float, 4096, 4, 0, General8bit)
MAKE_kQuantizeBlockwise(float, 4096, 4, 1, General8bit)
MAKE_kQuantizeBlockwise(float, 2048, 4, 0, General8bit)
MAKE_kQuantizeBlockwise(float, 1024, 4, 0, General8bit)
MAKE_kQuantizeBlockwise(float, 512, 2, 0, General8bit)
MAKE_kQuantizeBlockwise(float, 256, 2, 0, General8bit)
MAKE_kQuantizeBlockwise(float, 128, 2, 0, General8bit)
//MAKE_kQuantizeBlockwise(float, 64, 2, 0, General8bit)
MAKE_kQuantizeBlockwise(float, 4096, 4, 0, FP4)
MAKE_kQuantizeBlockwise(float, 2048, 4, 0, FP4)
MAKE_kQuantizeBlockwise(float, 1024, 4, 0, FP4)
MAKE_kQuantizeBlockwise(float, 512, 2, 0, FP4)
MAKE_kQuantizeBlockwise(float, 256, 2, 0, FP4)
MAKE_kQuantizeBlockwise(float, 128, 2, 0, FP4)
//MAKE_kQuantizeBlockwise(float, 64, 2, 0, FP4)
MAKE_kQuantizeBlockwise(float, 4096, 4, 0, NF4)
MAKE_kQuantizeBlockwise(float, 2048, 4, 0, NF4)
MAKE_kQuantizeBlockwise(float, 1024, 4, 0, NF4)
MAKE_kQuantizeBlockwise(float, 512, 2, 0, NF4)
MAKE_kQuantizeBlockwise(float, 256, 2, 0, NF4)
MAKE_kQuantizeBlockwise(float, 128, 2, 0, NF4)
//MAKE_kQuantizeBlockwise(float, 64, 2, 0, NF4)
MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 0, General8bit)
MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 1, General8bit)
MAKE_kQuantizeBlockwise(hip_bfloat16, 2048, 4, 0, General8bit)
MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, General8bit)
MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, General8bit)
MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, General8bit)
MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, General8bit)
//MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, General8bit)
MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 0, FP4)
MAKE_kQuantizeBlockwise(hip_bfloat16, 2048, 4, 0, FP4)
MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, FP4)
MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, FP4)
MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, FP4)
MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, FP4)
//MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, FP4)
MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 0, NF4)
MAKE_kQuantizeBlockwise(hip_bfloat16, 2048, 4, 0, NF4)
MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, NF4)
MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, NF4)
MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, NF4)
MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, NF4)
//MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, NF4)
template __global__ void kDequantizeBlockwise<half, 512, 64, 8, FP4>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
template __global__ void kDequantizeBlockwise<half, 512, 64, 8, General8bit>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
template __global__ void kDequantizeBlockwise<half, 512, 64, 8, NF4>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
template __global__ void kDequantizeBlockwise<float, 512, 64, 8, FP4>(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n);
template __global__ void kDequantizeBlockwise<float, 512, 64, 8, General8bit>(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n);
template __global__ void kDequantizeBlockwise<float, 512, 64, 8, NF4>(float *code, unsigned char * A, float * absmax, float *out, const int blocksize, const int n);
template __global__ void kDequantizeBlockwise<hip_bfloat16, 512, 64, 8, FP4>(float *code, unsigned char * A, float * absmax, hip_bfloat16 *out, const int blocksize, const int n);
template __global__ void kDequantizeBlockwise<hip_bfloat16, 512, 64, 8, General8bit>(float *code, unsigned char * A, float * absmax, hip_bfloat16 *out, const int blocksize, const int n);
template __global__ void kDequantizeBlockwise<hip_bfloat16, 512, 64, 8, NF4>(float *code, unsigned char * A, float * absmax, hip_bfloat16 *out, const int blocksize, const int n);
#define MAKE_OptimizerStatic8bit2StateBlockwise(oname, gtype, block_size, num_per_thread) \
template __global__ void kOptimizerStatic8bit2StateBlockwise<gtype, oname, block_size, num_per_thread>(gtype* p, gtype* __restrict__ const g, unsigned char* state1, unsigned char* state2, \
const float beta1, const float beta2, const float beta3, const float alpha, \
const float eps, const int step, const float lr, \
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \
float* absmax1, float* absmax2, \
float weight_decay, \
const float gnorm_scale, const bool skip_zeros, const int n); \
MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, float, 256, 1)
MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, half, 256, 1)
MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, hip_bfloat16, 256, 1)
MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, float, 256, 1)
MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, half, 256, 1)
MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, hip_bfloat16, 256, 1)
#define MAKE_OptimizerStatic8bit1StateBlockwise(oname, gtype, block_size, num_per_thread) \
template __global__ void kOptimizerStatic8bit1StateBlockwise<gtype, oname, block_size, num_per_thread>( \
gtype* p, gtype* __restrict__ const g, unsigned char* state1, \
const float beta1, const float beta2, \
const float eps, const int step, const float lr, \
float* __restrict__ const quantiles1, \
float* absmax1, \
float weight_decay, \
const float gnorm_scale, const bool skip_zeros, const int n); \
MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, half, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, hip_bfloat16, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, half, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, hip_bfloat16, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(LION, float, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(LION, half, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(LION, hip_bfloat16, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, hip_bfloat16, 256, 1)
template __device__ void printnonzero<float>(float *A, int num_values, const char*strval);
template __device__ void printnonzero<half>(half *A, int num_values, const char*strval);
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
#include <float.h>
#include <ops_hip.cuh>
#ifndef kernels
#define kernels
__global__ void kQuantize(float* code, float* __restrict__ const A, unsigned char* out, const int n);
__global__ void kDequantize(float* code, unsigned char* A, float* out, const int n);
template <typename T, int BLOCK_SIZE, int NUM_PER_TH, int STOCHASTIC, int DATA_TYPE>
__global__ void kQuantizeBlockwise(
float* code, T* __restrict__ const A, float* absmax, unsigned char* out, float* __restrict__ const rand,
const int rand_offset, const int n
);
template <typename T, int BLOCK_SIZE, int THREADS, int NUM_PER_TH, int DATA_TYPE>
__global__ void
kDequantizeBlockwise(float* code, unsigned char* A, float* absmax, T* out, const int blocksize, const int n);
template <typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
__global__ void kPreconditionOptimizer32bit2State(
T* g, T* p, float* state1, float* state2, float* unorm, const float beta1, const float beta2, const float eps,
const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n
);
template <typename T, int OPTIMIZER>
__global__ void kOptimizer32bit2State(
T* g, T* p, float* state1, float* state2, float* unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float beta3, const float alpha, const float eps,
const float weight_decay, const int step, const float lr, const float gnorm_scale, const bool skip_zeros,
const int n
);
template <typename T, int OPTIMIZER, int BLOCK_SIZE, int NUM_VALS>
__global__ void kPreconditionOptimizer32bit1State(
T* g, T* p, float* state1, float* unorm, const float beta1, const float beta2, const float eps,
const float weight_decay, const int step, const float lr, const float gnorm_scale, const int n
);
template <typename T, int OPTIMIZER>
__global__ void kOptimizer32bit1State(
T* g, T* p, float* state1, float* unorm, const float max_unorm, const float param_norm, const float beta1,
const float beta2, const float eps, const float weight_decay, const int step, const float lr,
const float gnorm_scale, const bool skip_zeros, const int n
);
template <typename T, int OPTIMIZER>
__global__ void kPreconditionOptimizerStatic8bit1State(
T* p, T* __restrict__ const g, unsigned char* __restrict__ const state1, float* unorm, const float beta1,
const float beta2, const float eps, const int step, float* __restrict__ const quantiles1, float* max1,
float* new_max1, const float weight_decay, const float gnorm_scale, const int n
);
template <typename T, int OPTIMIZER>
__global__ void kOptimizerStatic8bit1State(
T* p, T* const g, unsigned char* state1, const float* unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float eps, const int step, const float lr,
float* __restrict__ const quantiles1, float* max1, float* new_max1, float weight_decay, const float gnorm_scale,
const int n
);
template <typename T, int OPTIMIZER>
__global__ void kPreconditionOptimizerStatic8bit2State(
T* p, T* __restrict__ const g, unsigned char* __restrict__ const state1, unsigned char* __restrict__ const state2,
float* unorm, const float beta1, const float beta2, const float eps, const int step,
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* max1, float* max2,
float* new_max1, float* new_max2, const float gnorm_scale, const int n
);
template <typename T, int OPTIMIZER>
__global__ void kOptimizerStatic8bit2State(
T* p, T* const g, unsigned char* state1, unsigned char* state2, const float* unorm, const float max_unorm,
const float param_norm, const float beta1, const float beta2, const float eps, const int step, const float lr,
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* max1, float* max2,
float* new_max1, float* new_max2, float weight_decay, const float gnorm_scale, const int n
);
template <typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH>
__global__ void kOptimizerStatic8bit2StateBlockwise(
T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2, const float beta1, const float beta2,
const float beta3, const float alpha, const float eps, const int step, const float lr,
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* absmax1, float* absmax2,
float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n
);
template <typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH>
__global__ void kOptimizerStatic8bit1StateBlockwise(
T* p, T* __restrict__ const g, unsigned char* state1, const float beta1, const float beta2, const float eps,
const int step, const float lr, float* __restrict__ const quantiles1, float* absmax1, float weight_decay,
const float gnorm_scale, const bool skip_zeros, const int n
);
template <typename T, int BLOCK_SIZE, int NUM_VALS>
__global__ void kPercentileClipping(T* __restrict__ g, float* gnorm_vec, int step, const int n);
template <typename T, int SPMM_ITEMS, int BITS>
__global__ void kspmm_coo_very_sparse_naive(
int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, T* B, half* out,
float* __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB
);
template <int ITEMS_PER_THREAD, int THREADS>
__global__ void kdequant_mm_int32_fp16(
int* __restrict__ const A, float* __restrict__ const rowStats, float* __restrict__ const colStats, half* out,
half* __restrict__ const bias, const int numRows, const int numCols, const int n
);
template <typename T, int THREADS, int SPARSE_DECOMP>
__global__ void kgetRowStats(T* __restrict__ A, float* rowStats, float threshold, int rows, int cols);
template <typename T, int THREADS, int SPARSE_DECOMP>
__global__ void kInt8VectorQuant(T* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols);
template <int THREADS, int ITEMS_PER_THREAD, int TILE_ROWS, int TILE_COLS, int TRANSPOSE, int FORMAT>
__global__ void kTransformRowToFormat(
char* __restrict__ const A, char* out, int rows, int cols, int tiledCols, int outRows, int outCols
);
template <typename T, int BITS, int THREADS>
__global__ void gemm_device(int M, int N, int K, T* __restrict__ const A, T* B, T* out, int lda, int ldb, int ldc);
template <typename T, int THREADS>
__global__ void kgemm_4bit_inference(
int M, int N, int K, T* __restrict__ const A, unsigned char* B, float* absmax, T* out, int lda, int ldb, int ldc,
int blocksize
);
template <typename T, int THREADS, int BITS>
__global__ void kgemm_4bit_inference_naive(
int M, int N, int K, T* __restrict__ const A, unsigned char* B, float* absmax, const float* datatype, T* out,
int lda, int ldb, int ldc, int blocksize
);
template <typename T, int FUNC> __global__ void kfunc(T* A, T* B, T value, long n);
#endif
// !!! This is a file automatically generated by hipify!!!
#include "hip/hip_runtime.h"
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
#include <ops_hip.cuh>
#include <kernels_hip.cuh>
#include <hipcub/hipcub.hpp>
#include <hipblas/hipblas.h>
#include <hipsparse/hipsparse.h>
#ifndef NO_HIPBLASLT
#include <hipblaslt/hipblaslt.h>
#endif
#include <limits>
#include <BinSearch.h>
#include <cassert>
#include <common.h>
#define ERR_NOT_IMPLEMENTED 100
using namespace BinSearch;
using std::cout;
using std::endl;
void quantize(float *code, float *A, unsigned char *out, int n)
{
int num_blocks = n/1024;
num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1;
hipLaunchKernelGGL(( kQuantize), dim3(num_blocks), dim3(1024), 0, 0, code, A, out, n);
CUDA_CHECK_RETURN(hipPeekAtLastError());
}
void dequantize(float *code, unsigned char *A, float *out, int n, hipStream_t stream)
{
int num_blocks = n/1024;
num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1;
hipLaunchKernelGGL(( kDequantize), dim3(num_blocks), dim3(1024), 0, stream, code, A, out, n);
CUDA_CHECK_RETURN(hipPeekAtLastError());
}
template <typename T, int STOCHASTIC, int DATA_TYPE> void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float *rand, int rand_offset, int blocksize, const int n)
{
int num_blocks = n/blocksize;
num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;
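// note: the STOCHASTIC template flag is only forwarded on the blocksize == 4096 path;
// the other branches instantiate the non-stochastic kernel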
if(blocksize == 4096)
hipLaunchKernelGGL(( kQuantizeBlockwise<T, 4096, 4, STOCHASTIC, DATA_TYPE>), dim3(num_blocks), dim3(1024), 0, 0, code, A, absmax, out, rand, rand_offset, n);
else if(blocksize == 2048)
hipLaunchKernelGGL(( kQuantizeBlockwise<T, 2048, 4, 0, DATA_TYPE>), dim3(num_blocks), dim3(512), 0, 0, code, A, absmax, out, rand, rand_offset, n);
else if(blocksize == 1024)
hipLaunchKernelGGL(( kQuantizeBlockwise<T, 1024, 4, 0, DATA_TYPE>), dim3(num_blocks), dim3(256), 0, 0, code, A, absmax, out, rand, rand_offset, n);
else if(blocksize == 512)
hipLaunchKernelGGL(( kQuantizeBlockwise<T, 512, 2, 0, DATA_TYPE>), dim3(num_blocks), dim3(256), 0, 0, code, A, absmax, out, rand, rand_offset, n);
else if(blocksize == 256)
hipLaunchKernelGGL(( kQuantizeBlockwise<T, 256, 2, 0, DATA_TYPE>), dim3(num_blocks), dim3(128), 0, 0, code, A, absmax, out, rand, rand_offset, n);
else if(blocksize == 128)
hipLaunchKernelGGL(( kQuantizeBlockwise<T, 128, 2, 0, DATA_TYPE>), dim3(num_blocks), dim3(64), 0, 0, code, A, absmax, out, rand, rand_offset, n);
//else if(blocksize == 64)
// hipLaunchKernelGGL(( kQuantizeBlockwise<T, 64, 2, 0, DATA_TYPE>), dim3(num_blocks), dim3(32), 0, 0, code, A, absmax, out, rand, rand_offset, n);
CUDA_CHECK_RETURN(hipPeekAtLastError());
}
template<typename T, int DATA_TYPE> void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n, hipStream_t stream)
{
int num_blocks = n/blocksize;
num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1;
int tile_size = (DATA_TYPE > 0) ? 1024 : 512;
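// for 4-bit data types (FP4/NF4, DATA_TYPE > 0) each byte of A packs two quantized values,
// so each block covers 1024 outputs (tile_size = 1024) and the kernel receives blocksize/2,
// i.e. the blocksize in bytes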
if(DATA_TYPE > 0)
hipLaunchKernelGGL(( kDequantizeBlockwise<T, 512, 64, 8, DATA_TYPE>), dim3((n+tile_size-1)/tile_size), dim3(64), 0, stream, code, A, absmax, out, blocksize/2, n);
else
hipLaunchKernelGGL(( kDequantizeBlockwise<T, 512, 64, 8, DATA_TYPE>), dim3((n+tile_size-1)/tile_size), dim3(64), 0, stream, code, A, absmax, out, blocksize, n);
CUDA_CHECK_RETURN(hipPeekAtLastError());
}
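// Illustrative usage sketch only (hypothetical helper, not called by the library): dequantize
// an NF4-quantized tensor of n half values, assuming a blocksize of 64 and device buffers that
// the caller has already allocated and filled (e.g. via hipMalloc/hipMemcpy).
static inline void example_dequantize_nf4_sketch(float* code, unsigned char* A, float* absmax, half* out, int n, hipStream_t stream)
{
  const int blocksize = 64; // assumed quantization blocksize
  dequantizeBlockwise<half, NF4>(code, A, absmax, out, blocksize, n, stream);
}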
template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p,
float* state1, float* state2, float *unorm, float max_unorm, float param_norm,
const float beta1, const float beta2, const float beta3, const float alpha,
const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n)
{
int num_blocks = n/4096;
num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1;
switch(OPTIMIZER)
{
case ADAM:
case ADEMAMIX:
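// when max_unorm-based update clipping is enabled, a precondition pass first accumulates
// the update norm into unorm before the actual optimizer step runs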
if(max_unorm > 0.0f)
{
CUDA_CHECK_RETURN(hipMemset(unorm, 0, 1*sizeof(float)));
hipLaunchKernelGGL(( kPreconditionOptimizer32bit2State<T, OPTIMIZER, 4096, 8>), dim3(num_blocks), dim3(512), 0, 0, g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
CUDA_CHECK_RETURN(hipPeekAtLastError());
}
hipLaunchKernelGGL(( kOptimizer32bit2State<T, OPTIMIZER>), dim3(num_blocks), dim3(1024), 0, 0, g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
CUDA_CHECK_RETURN(hipPeekAtLastError());
break;
case MOMENTUM:
case RMSPROP:
case ADAGRAD:
if(max_unorm > 0.0f)
{
CUDA_CHECK_RETURN(hipMemset(unorm, 0, 1*sizeof(float)));
hipLaunchKernelGGL(( kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8>), dim3(num_blocks), dim3(512), 0, 0, g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
CUDA_CHECK_RETURN(hipPeekAtLastError());
}
hipLaunchKernelGGL(( kOptimizer32bit1State<T, OPTIMIZER>), dim3(num_blocks), dim3(1024), 0, 0, g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
CUDA_CHECK_RETURN(hipPeekAtLastError());
break;
case LION:
// in lion, the momentum update happens after the parameter update
hipLaunchKernelGGL(( kOptimizer32bit1State<T, OPTIMIZER>), dim3(num_blocks), dim3(1024), 0, 0, g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
CUDA_CHECK_RETURN(hipPeekAtLastError());
if(max_unorm > 0.0f)
{
CUDA_CHECK_RETURN(hipMemset(unorm, 0, 1*sizeof(float)));
hipLaunchKernelGGL(( kPreconditionOptimizer32bit1State<T, OPTIMIZER, 4096, 8>), dim3(num_blocks), dim3(512), 0, 0, g, p, state1, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
CUDA_CHECK_RETURN(hipPeekAtLastError());
}
break;
}
}
template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
unsigned char* state1, unsigned char* state2,
float *unorm, float max_unorm, float param_norm,
float beta1, float beta2,
float eps, int step, float lr,
float* quantiles1, float* quantiles2,
float* max1, float* max2, float* new_max1, float* new_max2,
float weight_decay,
const float gnorm_scale, int n)
{
int num_blocks = n/4096;
num_blocks = n % 4096 == 0 ? num_blocks : num_blocks + 1;
if(max_unorm > 0.0f){ CUDA_CHECK_RETURN(hipMemset(unorm, 0, 1*sizeof(float))); }
switch(OPTIMIZER)
{
case ADAM:
CUDA_CHECK_RETURN(hipMemset(new_max1, 0, 1*sizeof(float)));
CUDA_CHECK_RETURN(hipMemset(new_max2, 0, 1*sizeof(float)));
hipLaunchKernelGGL(( kPreconditionOptimizerStatic8bit2State<T, OPTIMIZER>), dim3(num_blocks), dim3(256), 0, 0, p, g, state1, state2, unorm, beta1, beta2, eps, step, quantiles1, quantiles2, max1, max2, new_max1, new_max2, gnorm_scale, n);
CUDA_CHECK_RETURN(hipPeekAtLastError());
hipLaunchKernelGGL(( kOptimizerStatic8bit2State<T, OPTIMIZER>), dim3(num_blocks), dim3(1024), 0, 0, p, g, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
quantiles1, quantiles2, max1, max2, new_max1, new_max2, weight_decay, gnorm_scale, n);
CUDA_CHECK_RETURN(hipPeekAtLastError());
break;
case MOMENTUM:
case RMSPROP:
case ADAGRAD:
CUDA_CHECK_RETURN(hipMemset(new_max1, 0, 1*sizeof(float)));
hipLaunchKernelGGL(( kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER>), dim3(num_blocks), dim3(256), 0, 0, p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
CUDA_CHECK_RETURN(hipPeekAtLastError());
hipLaunchKernelGGL(( kOptimizerStatic8bit1State<T, OPTIMIZER>), dim3(num_blocks), dim3(1024), 0, 0, p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
CUDA_CHECK_RETURN(hipPeekAtLastError());
break;
case LION:
// in lion, the momentum update happens after the parameter update
hipLaunchKernelGGL(( kOptimizerStatic8bit1State<T, OPTIMIZER>), dim3(num_blocks), dim3(1024), 0, 0, p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
CUDA_CHECK_RETURN(hipPeekAtLastError());
CUDA_CHECK_RETURN(hipMemset(new_max1, 0, 1*sizeof(float)));
hipLaunchKernelGGL(( kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER>), dim3(num_blocks), dim3(256), 0, 0, p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
CUDA_CHECK_RETURN(hipPeekAtLastError());
break;
default:
break;
}
}
#define BLOCKSIZE_2STATE 256
#define NUM_2STATE 1
#define BLOCKSIZE_1STATE 256
#define NUM_1STATE 1
template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(
T* p,
T* g,
unsigned char* state1,
unsigned char* state2,
float beta1,
float beta2,
float beta3,
float alpha,
float eps,
int step,
float lr,
float* quantiles1,
float* quantiles2,
float* absmax1,
float* absmax2,
float weight_decay,
const float gnorm_scale,
bool skip_zeros,
int n
) {
int num_blocks = 0;
switch(OPTIMIZER)
{
case ADAM:
case ADEMAMIX:
num_blocks = n/BLOCKSIZE_2STATE;
num_blocks = n % BLOCKSIZE_2STATE == 0 ? num_blocks : num_blocks + 1;
hipLaunchKernelGGL(( kOptimizerStatic8bit2StateBlockwise<T, OPTIMIZER, BLOCKSIZE_2STATE, NUM_2STATE>), dim3(num_blocks), dim3(BLOCKSIZE_2STATE/NUM_2STATE), 0, 0, p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr,
quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n);
CUDA_CHECK_RETURN(hipPeekAtLastError());
break;
case MOMENTUM:
case RMSPROP:
case ADAGRAD:
case LION:
num_blocks = n/BLOCKSIZE_1STATE;
num_blocks = n % BLOCKSIZE_1STATE == 0 ? num_blocks : num_blocks + 1;
hipLaunchKernelGGL(( kOptimizerStatic8bit1StateBlockwise<T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE>), dim3(num_blocks), dim3(BLOCKSIZE_1STATE/NUM_1STATE), 0, 0, p, g, state1, beta1, beta2, eps, step, lr,
quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n);
CUDA_CHECK_RETURN(hipPeekAtLastError());
break;
}
}
template<typename T> void percentileClipping(T * g, float *gnorm_vec, int step, const int n)
{
int num_blocks = n/2048;
num_blocks = n % 2048 == 0 ? num_blocks : num_blocks + 1;
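// gnorm_vec is used as a 100-entry ring buffer of gradient-norm history; clear the slot
// for the current step before the kernel accumulates into it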
CUDA_CHECK_RETURN(hipMemset(&gnorm_vec[step % 100], 0, 1*sizeof(float)));
hipLaunchKernelGGL(( kPercentileClipping<T, 2048, 4>), dim3(num_blocks), dim3(512), 0, 0, g, gnorm_vec, step, n);
CUDA_CHECK_RETURN(hipPeekAtLastError());
}
void gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc)
{
const int falpha = 1;
const int fbeta = 0;
const void * alpha = &falpha;
const void * beta = &fbeta;
hipblasStatus_t status;
#if hipblasVersionMajor >= 3
status = hipblasGemmEx(context->m_handle,
transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N,
transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N,
m, n, k,
alpha, A, HIP_R_8I, lda, B, HIP_R_8I, ldb, beta,
C, HIP_R_32I, ldc,
HIPBLAS_COMPUTE_32I, HIPBLAS_GEMM_DEFAULT);
#else
status = hipblasGemmEx(context->m_handle,
transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N,
transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N,
m, n, k,
alpha, A, HIPBLAS_R_8I, lda, B, HIPBLAS_R_8I, ldb, beta,
C, HIPBLAS_R_32I, ldc,
HIPBLAS_R_32I, HIPBLAS_GEMM_DEFAULT);
#endif
if (status != HIPBLAS_STATUS_SUCCESS)
{
std::cout << "HIPBLAS ERROR: Status " << status << std::endl;
}
}
void strided_gemmex(Context *context, bool transposeA, bool transposeB, int m, int n, int k, void *A, void *B, void *C, int lda, int ldb, int ldc,
long long int strideA, long long int strideB, long long int strideC, int batchCount)
{
const int falpha = 1;
const int fbeta = 0;
const void * alpha = &falpha;
const void * beta = &fbeta;
hipblasStatus_t status;
//cout << transposeA << transposeB << endl;
//printf("%i %i %i\n", m,n,k);
//printf("%i %i %i\n", lda,ldb,ldc);
//printf("%i %i %i\n", strideA, strideB, strideC);
//printf("%i\n", batchCount);
#if hipblasVersionMajor >= 3
status = hipblasGemmStridedBatchedEx(context->m_handle,
transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N,
transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N,
m, n, k,
alpha, A, HIP_R_8I, lda, (long long int)strideA, B, HIP_R_8I, ldb, (long long int)strideB, beta,
C, HIP_R_32I, ldc, (long long int)strideC, batchCount,
HIPBLAS_COMPUTE_32I, HIPBLAS_GEMM_DEFAULT);
#else
status = hipblasGemmStridedBatchedEx(context->m_handle,
transposeA ? HIPBLAS_OP_T : HIPBLAS_OP_N,
transposeB ? HIPBLAS_OP_T : HIPBLAS_OP_N,
m, n, k,
alpha, A, HIPBLAS_R_8I, lda, (long long int)strideA, B, HIPBLAS_R_8I, ldb, (long long int)strideB, beta,
C, HIPBLAS_R_32I, ldc, (long long int)strideC, batchCount,
HIPBLAS_R_32I, HIPBLAS_GEMM_DEFAULT);
#endif
if (status != HIPBLAS_STATUS_SUCCESS)
{
std::cout << "HIPBLAS ERROR: Status " << status << std::endl;
}
}
int roundoff(int v, int d) {
return (v + d - 1) / d * d;
}
#ifdef NO_HIPBLASLT
#else
template<int ORDER> hipblasLtOrder_t get_order()
{
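// the CUDA-specific tiled orders (COL32, COL4_4R2_8C, COL32_2R_4R4) are apparently not
// available through hipBLASLt here, so those cases fall back to plain column-major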
switch(ORDER)
{
case ROW:
return HIPBLASLT_ORDER_ROW;
break;
case COL:
return HIPBLASLT_ORDER_COL;
break;
case COL32:
//return HIPBLASLT_ORDER_COL32;
return HIPBLASLT_ORDER_COL;
break;
case COL_TURING:
//return HIPBLASLT_ORDER_COL4_4R2_8C;
return HIPBLASLT_ORDER_COL;
break;
case COL_AMPERE:
//return HIPBLASLT_ORDER_COL32_2R_4R4;
return HIPBLASLT_ORDER_COL;
break;
default:
break;
}
return HIPBLASLT_ORDER_ROW;
}
template hipblasLtOrder_t get_order<ROW>();
template hipblasLtOrder_t get_order<COL>();
template hipblasLtOrder_t get_order<COL32>();
//template hipblasLtOrder_t get_order<COL_TURING>();
//template hipblasLtOrder_t get_order<COL_AMPERE>();
#endif
template<int ORDER> int get_leading_dim(int dim1, int dim2)
{
switch(ORDER)
{
case ROW:
return dim2;
break;
case COL:
return dim1;
break;
default:
return dim1;
break;
/*case COL32:
// 32*row tiles
return dim1*32;
break;
case COL_TURING:
return 32*roundoff(dim1, 8);
break;
case COL_AMPERE:
// 32*32 tiles
return 32*roundoff(dim1, 32);
break;
default:
return 0;
break;
*/
}
}
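// Example: for a dim1 x dim2 matrix, ROW order uses dim2 (elements per row) as
// the leading dimension and COL order uses dim1; the tiled layouts fall through
// to the default here.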
static std::string hipError_to_string(const hipError_t ret)
{
switch(ret)
{
case hipSuccess:
return "hipSuccess";
case hipErrorInvalidContext:
return "hipErrorInvalidContext";
case hipErrorInvalidKernelFile:
return "hipErrorInvalidKernelFile";
case hipErrorMemoryAllocation:
return "hipErrorMemoryAllocation";
case hipErrorInitializationError:
return "hipErrorInitializationError";
case hipErrorLaunchFailure:
return "hipErrorLaunchFailure";
case hipErrorLaunchOutOfResources:
return "hipErrorLaunchOutOfResources";
case hipErrorInvalidDevice:
return "hipErrorInvalidDevice";
case hipErrorInvalidValue:
return "hipErrorInvalidValue";
case hipErrorInvalidDevicePointer:
return "hipErrorInvalidDevicePointer";
case hipErrorInvalidMemcpyDirection:
return "hipErrorInvalidMemcpyDirection";
case hipErrorUnknown:
return "hipErrorUnknown";
case hipErrorInvalidResourceHandle:
return "hipErrorInvalidResourceHandle";
case hipErrorNotReady:
return "hipErrorNotReady";
case hipErrorNoDevice:
return "hipErrorNoDevice";
case hipErrorPeerAccessAlreadyEnabled:
return "hipErrorPeerAccessAlreadyEnabled";
case hipErrorPeerAccessNotEnabled:
return "hipErrorPeerAccessNotEnabled";
case hipErrorRuntimeMemory:
return "hipErrorRuntimeMemory";
case hipErrorRuntimeOther:
return "hipErrorRuntimeOther";
case hipErrorHostMemoryAlreadyRegistered:
return "hipErrorHostMemoryAlreadyRegistered";
case hipErrorHostMemoryNotRegistered:
return "hipErrorHostMemoryNotRegistered";
case hipErrorMapBufferObjectFailed:
return "hipErrorMapBufferObjectFailed";
case hipErrorTbd:
return "hipErrorTbd";
default:
throw std::runtime_error("unknown hipError");
}
}
template <int DTYPE_OUT, int SCALE_ROWS> int igemmlt(
hipblasLtHandle_t ltHandle,
int m, int n, int k,
const int8_t *A,
const int8_t *B,
void *C,
float *row_scale,
int lda, int ldb, int ldc,
hipStream_t stream
) {
#ifdef NO_HIPBLASLT
return ERR_NOT_IMPLEMENTED;
#else
// Calculate C = A^T @ B in col-major layout.
//
// Using the IMMA kernels requires:
// * A must be transposed and B must be non-transposed.
// * Dimensions m and k must be multiples of 4.
// * All pointers must be 4-byte aligned; 16-byte alignment is preferred.
int has_error = 0;
const int64_t max_workspace_size = 0;//set to 0 to avoid choosing GSU kernel
hipblasLtMatmulDesc_t matmulDesc;
hipblasLtMatrixLayout_t aDesc, bDesc, cDesc;
hipblasOperation_t opT = HIPBLAS_OP_T;
hipDataType outType = DTYPE_OUT == 32 ? HIP_R_32I : HIP_R_8I;
hipDataType scaleType = DTYPE_OUT == 32 ? HIP_R_32I : HIP_R_32F;
hipblasLtPointerMode_t pointerMode = HIPBLASLT_POINTER_MODE_ALPHA_DEVICE_VECTOR_BETA_HOST;
has_error |= checkHipblasStatus(hipblasLtMatrixLayoutCreate(&aDesc, HIP_R_8I, m, k, lda));
has_error |= checkHipblasStatus(hipblasLtMatrixLayoutCreate(&bDesc, HIP_R_8I, m, n, ldb));
has_error |= checkHipblasStatus(hipblasLtMatrixLayoutCreate(&cDesc, outType, k, n, ldc));
// Default layout order is col major
has_error |= checkHipblasStatus(hipblasLtMatmulDescCreate(&matmulDesc, HIPBLAS_COMPUTE_32I, scaleType));
has_error |= checkHipblasStatus(hipblasLtMatmulDescSetAttribute(matmulDesc, HIPBLASLT_MATMUL_DESC_TRANSA, &opT, sizeof(opT)));
if (DTYPE_OUT == 32) {
/* Algo and workspace TODO: need to rework to not be duplicated */
// Set User Preference attributes
hipblasLtMatmulPreference_t pref;
checkHipblasStatus(hipblasLtMatmulPreferenceCreate(&pref));
checkHipblasStatus(
hipblasLtMatmulPreferenceSetAttribute(pref,
HIPBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&max_workspace_size,
sizeof(max_workspace_size)));
const int request_solutions = 1;
hipblasLtMatmulHeuristicResult_t heuristicResult[request_solutions];
int returnedAlgoCount = 0;
checkHipblasStatus(hipblasLtMatmulAlgoGetHeuristic(ltHandle,
matmulDesc,
aDesc,
bDesc,
cDesc,
cDesc,
pref,
request_solutions,
heuristicResult,
&returnedAlgoCount));
if (returnedAlgoCount == 0)
{
has_error = 1;
fprintf(stderr, "Error: Matmul Algo Heuristic didn't return algorithms\n");
} else {
int alpha = 1, beta = 0;
has_error |= checkHipblasStatus(hipblasLtMatmul(
ltHandle, matmulDesc,
&alpha, A, aDesc,
B, bDesc, &beta,
(int32_t*)C, cDesc,
(int32_t*)C, cDesc,
&heuristicResult[0].algo, NULL, 0, stream
));
}
} else {
// This path is unlikely to be used, as 8-bit accumulation is prone to overflow.
if (!SCALE_ROWS) {
float alpha = 1.0f, beta = 0.0f;
has_error |= checkHipblasStatus(hipblasLtMatmul(
ltHandle, matmulDesc,
&alpha, A, aDesc,
B, bDesc, &beta,
(int8_t*)C, cDesc,
(int8_t*)C, cDesc,
NULL, NULL, 0, stream
));
} else {
// Row-wise scaling: alpha is read from the device vector row_scale, beta stays on the host.
float beta = 0.0f;
has_error |= checkHipblasStatus(hipblasLtMatmulDescSetAttribute(
matmulDesc,
HIPBLASLT_MATMUL_DESC_POINTER_MODE,
&pointerMode,
sizeof(pointerMode)
));
has_error |= checkHipblasStatus(hipblasLtMatmul(
ltHandle, matmulDesc,
row_scale, A, aDesc,
B, bDesc, &beta,
(int8_t*)C, cDesc,
(int8_t*)C, cDesc,
NULL, NULL, 0, stream
));
}
}
has_error |= checkHipblasStatus(hipblasLtMatrixLayoutDestroy(cDesc));
has_error |= checkHipblasStatus(hipblasLtMatrixLayoutDestroy(bDesc));
has_error |= checkHipblasStatus(hipblasLtMatrixLayoutDestroy(aDesc));
has_error |= checkHipblasStatus(hipblasLtMatmulDescDestroy(matmulDesc));
if(has_error == 1)
printf("error detected");
return has_error;
#endif // NO_HIPBLASLT
}
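// --- Illustrative usage sketch (not part of the original source) --------------
// Reading the layout descriptors above: A is stored m x k and B is m x n (both
// int8, column-major); with TRANSA set, the call computes C = A^T * B as a
// k x n int32 matrix. m and k should be multiples of 4, and row_scale is unused
// when SCALE_ROWS == 0. Packed leading dimensions are lda = ldb = m, ldc = k.
// The function name here is hypothetical.
void example_igemmlt_int32(const int8_t *dA, const int8_t *dB, int32_t *dC, int m, int n, int k, hipStream_t stream)
{
hipblasLtHandle_t lt;
checkHipblasStatus(hipblasLtCreate(&lt));
igemmlt<32, 0>(lt, m, n, k, dA, dB, dC, /*row_scale=*/NULL, m, m, k, stream);
checkHipblasStatus(hipblasLtDestroy(lt));
}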
int fill_up_to_nearest_multiple(int value, int multiple)
{
return value + (value % multiple == 0 ? 0 : (multiple - (value % multiple)));
}
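// Example: fill_up_to_nearest_multiple(1000, 32) == 1024, the same result as
// roundoff(1000, 32) above; both round a value up to the next multiple.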
void dequant_mm_int32_fp16(int *A, float *rowStats, float *colStats, half *out, half *bias, int numRows, int numCols, hipStream_t stream)
{
const int threads = 512;
const int num_per_thread = 4;
const int num_per_block = threads * num_per_thread;
const int n = numRows*numCols;
const int num_blocks = (n + num_per_block - 1) / num_per_block;
hipLaunchKernelGGL(( kdequant_mm_int32_fp16<num_per_thread, threads>), dim3(num_blocks), dim3(threads), 0, stream, A, rowStats, colStats, out, bias, numRows, numCols, n);
CUDA_CHECK_RETURN(hipPeekAtLastError());
}
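// Grid-sizing example: with 512 threads x 4 elements per thread (2048 elements
// per block), a 4096 x 4096 int32 output (n = 16,777,216) launches 8192 blocks.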
void int8VectorQuant(half * __restrict__ A, int8_t *out, float *rowStats, float threshold, int rows, int cols, hipStream_t stream) {
if (threshold == 0.0) {
kInt8VectorQuant<half, 1024, 0><<<rows, 1024, 0, stream>>>(A, out, rowStats, threshold, rows, cols);
} else {
kInt8VectorQuant<half, 1024, 1><<<rows, 1024, 0, stream>>>(A, out, rowStats, threshold, rows, cols);
}
CUDA_CHECK_RETURN(hipPeekAtLastError());
}
void getRowStats(half *A, float *rowStats, float threshold, int rows, int cols, hipStream_t stream) {
if (threshold == 0.0)
kgetRowStats<half, 1024, 0><<<rows, 1024, 0, stream>>>(A, rowStats, threshold, rows, cols);
else
kgetRowStats<half, 1024, 1><<<rows, 1024, 0, stream>>>(A, rowStats, threshold, rows, cols);
CUDA_CHECK_RETURN(hipPeekAtLastError());
}
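// --- Illustrative CPU reference (not part of the original source) -------------
// The two launchers above feed the LLM.int8() row-wise quantization path:
// getRowStats computes a per-row absolute maximum and int8VectorQuant maps each
// row into int8 using that maximum. A simplified fp32 sketch of the intended
// math (threshold/outlier handling omitted; requires <cmath>):
void reference_int8_rowwise_quant(const float *A, int8_t *out, float *rowStats, int rows, int cols)
{
for (int r = 0; r < rows; r++)
{
float absmax = 0.0f;
for (int c = 0; c < cols; c++)
absmax = fmaxf(absmax, fabsf(A[r * cols + c]));
rowStats[r] = absmax;
float scale = absmax > 0.0f ? 127.0f / absmax : 0.0f;
for (int c = 0; c < cols; c++)
out[r * cols + c] = (int8_t)lroundf(A[r * cols + c] * scale);
}
}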
void spmm_coo(hipsparseHandle_t handle, int *A_rowidx, int *A_colidx, half *A_vals, int A_nnz, int A_rows, int A_cols, int B_cols, int ldb, half *B, int ldc, half* C, bool transposed_B)
{
#ifdef NO_HIPBLASLT
#else
hipsparseSpMatDescr_t descA;
hipsparseDnMatDescr_t descB, descC;
float alpha = 1.0f;
float beta = 0.0f;
void *dBuffer = NULL;
size_t bufferSize = 0;
CHECK_HIPSPARSE( hipsparseCreateCoo(&descA, A_rows, A_cols, A_nnz,
A_rowidx, A_colidx, A_vals,
HIPSPARSE_INDEX_32I,
HIPSPARSE_INDEX_BASE_ZERO, HIP_R_16F) );
// Create dense matrix C
CHECK_HIPSPARSE( hipsparseCreateDnMat(&descC, A_rows, B_cols, ldc, C,
HIP_R_16F, HIPSPARSE_ORDER_ROW) );
// Create dense matrix B
if(transposed_B)
{
int tmp = A_cols;
A_cols = B_cols;
B_cols = tmp;
}
CHECK_HIPSPARSE( hipsparseCreateDnMat(&descB, A_cols, B_cols, ldb, B,
HIP_R_16F, HIPSPARSE_ORDER_ROW) );
// allocate an external buffer if needed
CHECK_HIPSPARSE( hipsparseSpMM_bufferSize(
handle,
HIPSPARSE_OPERATION_NON_TRANSPOSE,
transposed_B ? HIPSPARSE_OPERATION_TRANSPOSE : HIPSPARSE_OPERATION_NON_TRANSPOSE,
&alpha, descA, descB, &beta, descC, HIP_R_32F,
HIPSPARSE_SPMM_ALG_DEFAULT, &bufferSize) );
CUDA_CHECK_RETURN( hipMalloc(&dBuffer, bufferSize) );
// execute SpMM
CHECK_HIPSPARSE( hipsparseSpMM(handle,
HIPSPARSE_OPERATION_NON_TRANSPOSE,
transposed_B ? HIPSPARSE_OPERATION_TRANSPOSE : HIPSPARSE_OPERATION_NON_TRANSPOSE,
&alpha, descA, descB, &beta, descC, HIP_R_32F,
HIPSPARSE_SPMM_ALG_DEFAULT, dBuffer));
// destroy matrix/vector descriptors
CHECK_HIPSPARSE( hipsparseDestroySpMat(descA) );
CHECK_HIPSPARSE( hipsparseDestroyDnMat(descB) );
CHECK_HIPSPARSE( hipsparseDestroyDnMat(descC) );
CUDA_CHECK_RETURN( hipFree(dBuffer) );
#endif
}
template <typename T, int BITS> void spmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB)
{
hipLaunchKernelGGL(( kspmm_coo_very_sparse_naive<T, 8, BITS>), dim3(nnz_rows), dim3(256), 0, 0, max_count, max_idx, offset_rowidx, rowidx, colidx, values, B, out, dequant_stats, nnz, rowsA, rowsB, colsB);
CUDA_CHECK_RETURN(hipPeekAtLastError());
}
template <typename T> void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits)
{
int num_blocks = (m+31)/32;
if(bits == 32)
hipLaunchKernelGGL(( gemm_device<T, 32, 32>), dim3(num_blocks), dim3(32), 0, 0, m, n, k, A, B, out, lda, ldb, ldc);
if(bits == 16)
hipLaunchKernelGGL(( gemm_device<T, 16, 160>), dim3(num_blocks), dim3(160), 0, 0, m, n, k, A, B, out, lda, ldb, ldc);
}
template <typename T> void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
{
int num_blocks = (m+31)/32;
hipLaunchKernelGGL(( kgemm_4bit_inference<T, 96>), dim3(num_blocks), dim3(96), 0, 0, m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize);
}
template <typename T, int BITS> void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize, hipStream_t stream)
{
// Wavefront size 32: 128 threads per block = 4 waves, so each block covers 4 of the m outputs.
int num_blocks = (m+3)/4;
// Wavefront size 64 (CDNA GPUs such as gfx90a/gfx942): 128 threads = 2 waves, so each block covers 2.
if (warpSize == 64) {
num_blocks = (m+1)/2;
}
hipLaunchKernelGGL(( kgemm_4bit_inference_naive<T, 128, BITS>), dim3(num_blocks), dim3(128), 0, stream, m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize);
CUDA_CHECK_RETURN(hipPeekAtLastError());
}
template <typename T, int FUNC> void func(T *A, T *B, T value, long n)
{
int threads = 512;
int blocks = n/threads;
blocks = n % threads == 0 ? blocks : blocks + 1;
blocks = blocks > 65535 ? 65535 : blocks;
hipLaunchKernelGGL(( kfunc<T, FUNC>), dim3(blocks), dim3(512), 0, 0, A, B, value, n);
CUDA_CHECK_RETURN(hipPeekAtLastError());
}
//==============================================================
// TEMPLATE DEFINITIONS
//==============================================================
template void func<float, FILL>(float *A, float *B, float value, long n);
template void func<unsigned char, FILL>(unsigned char *A, unsigned char *B, unsigned char value, long n);
template void func<float, ARANGE>(float *A, float *B, float value, long n);
template void func<float, _MUL>(float *A, float *B, float value, long n);
template void gemm_4bit_inference<half>(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize);
template void gemm_4bit_inference_naive<half, 16>(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize, hipStream_t stream);
template void gemm_4bit_inference_naive<hip_bfloat16, 16>(int m, int n, int k, hip_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, hip_bfloat16 * out, int lda, int ldb, int ldc, int blocksize, hipStream_t stream);
template void gemm_4bit_inference_naive<float, 32>(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize, hipStream_t stream);
//template void gemm_host<float>(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits);
template void gemm_host<half>(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits);
template void spmm_coo_very_sparse_naive<half, 16>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, half *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
template void spmm_coo_very_sparse_naive<signed char, 8>(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, signed char *B, half *out, float *dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB);
template int igemmlt<32, 0>(hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, hipStream_t stream);
template int igemmlt<8, 0>(hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, hipStream_t stream);
template int igemmlt<8, 1>(hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t *A, const int8_t *B, void *C, float *row_scale, int lda, int ldb, int ldc, hipStream_t stream);
template void quantizeBlockwise<half, 1, General8bit>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<half, 0, General8bit>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<half, 0, FP4>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<half, 0, NF4>(float * code, half *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<float, 1, General8bit>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<float, 0, General8bit>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<float, 0, FP4>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<float, 0, NF4>(float * code, float *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<hip_bfloat16, 1, General8bit>(float * code, hip_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<hip_bfloat16, 0, General8bit>(float * code, hip_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<hip_bfloat16, 0, FP4>(float * code, hip_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void quantizeBlockwise<hip_bfloat16, 0, NF4>(float * code, hip_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n);
template void dequantizeBlockwise<float, General8bit>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, hipStream_t stream);
template void dequantizeBlockwise<float, FP4>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, hipStream_t stream);
template void dequantizeBlockwise<float, NF4>(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, hipStream_t stream);
template void dequantizeBlockwise<half, General8bit>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, hipStream_t stream);
template void dequantizeBlockwise<half, FP4>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, hipStream_t stream);
template void dequantizeBlockwise<half, NF4>(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, hipStream_t stream);
template void dequantizeBlockwise<hip_bfloat16, General8bit>(float *code, unsigned char *A, float *absmax, hip_bfloat16 *out, int blocksize, const int n, hipStream_t stream);
template void dequantizeBlockwise<hip_bfloat16, FP4>(float *code, unsigned char *A, float *absmax, hip_bfloat16 *out, int blocksize, const int n, hipStream_t stream);
template void dequantizeBlockwise<hip_bfloat16, NF4>(float *code, unsigned char *A, float *absmax, hip_bfloat16 *out, int blocksize, const int n, hipStream_t stream);
#define MAKE_optimizer32bit(name, gtype) \
template void optimizer32bit<gtype, name>(gtype* g, gtype* p, \
float* state1, float* state2, float* unorm, float max_unorm, float param_norm, \
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay, \
const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
MAKE_optimizer32bit(ADAM, half)
MAKE_optimizer32bit(ADAM, float)
MAKE_optimizer32bit(ADAM, hip_bfloat16)
MAKE_optimizer32bit(MOMENTUM, half)
MAKE_optimizer32bit(MOMENTUM, float)
MAKE_optimizer32bit(MOMENTUM, hip_bfloat16)
MAKE_optimizer32bit(RMSPROP, half)
MAKE_optimizer32bit(RMSPROP, float)
MAKE_optimizer32bit(RMSPROP, hip_bfloat16)
MAKE_optimizer32bit(LION, half)
MAKE_optimizer32bit(LION, float)
MAKE_optimizer32bit(LION, hip_bfloat16)
MAKE_optimizer32bit(ADAGRAD, half)
MAKE_optimizer32bit(ADAGRAD, float)
MAKE_optimizer32bit(ADAGRAD, hip_bfloat16)
MAKE_optimizer32bit(ADEMAMIX, half)
MAKE_optimizer32bit(ADEMAMIX, hip_bfloat16)
MAKE_optimizer32bit(ADEMAMIX, float)
#define MAKE_optimizerStatic8bit(name, gtype) \
template void optimizerStatic8bit<gtype, name>(gtype* p, gtype* g, unsigned char* state1, unsigned char* state2, \
float *unorm, float max_unorm, float param_norm, \
float beta1, float beta2, \
float eps, int step, float lr, \
float* quantiles1, float* quantiles2, \
float* max1, float* max2, float* new_max1, float* new_max2, \
float weight_decay, \
const float gnorm_scale, int n);
MAKE_optimizerStatic8bit(ADAM, half)
MAKE_optimizerStatic8bit(ADAM, float)
MAKE_optimizerStatic8bit(MOMENTUM, half)
MAKE_optimizerStatic8bit(MOMENTUM, float)
MAKE_optimizerStatic8bit(RMSPROP, half)
MAKE_optimizerStatic8bit(RMSPROP, float)
MAKE_optimizerStatic8bit(LION, half)
MAKE_optimizerStatic8bit(LION, float)
MAKE_optimizerStatic8bit(ADAGRAD, half)
MAKE_optimizerStatic8bit(ADAGRAD, float)
#define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name) \
template void optimizerStatic8bitBlockwise<gtype, optim_name>(gtype* p, gtype* g, \
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, \
float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n);
MAKE_optimizerStatic8bitBlockwise(half, ADAM);
MAKE_optimizerStatic8bitBlockwise(float, ADAM);
MAKE_optimizerStatic8bitBlockwise(hip_bfloat16, ADAM);
MAKE_optimizerStatic8bitBlockwise(half, MOMENTUM);
MAKE_optimizerStatic8bitBlockwise(float, MOMENTUM);
MAKE_optimizerStatic8bitBlockwise(hip_bfloat16, MOMENTUM);
MAKE_optimizerStatic8bitBlockwise(half, RMSPROP);
MAKE_optimizerStatic8bitBlockwise(float, RMSPROP);
MAKE_optimizerStatic8bitBlockwise(hip_bfloat16, RMSPROP);
MAKE_optimizerStatic8bitBlockwise(half, LION);
MAKE_optimizerStatic8bitBlockwise(float, LION);
MAKE_optimizerStatic8bitBlockwise(hip_bfloat16, LION);
MAKE_optimizerStatic8bitBlockwise(half, ADAGRAD);
MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD);
MAKE_optimizerStatic8bitBlockwise(hip_bfloat16, ADAGRAD);
MAKE_optimizerStatic8bitBlockwise(half, ADEMAMIX);
MAKE_optimizerStatic8bitBlockwise(hip_bfloat16, ADEMAMIX);
MAKE_optimizerStatic8bitBlockwise(float, ADEMAMIX);
template void percentileClipping(float * g, float *gnorm_vec, int step, const int n);
template void percentileClipping(half * g, float *gnorm_vec, int step, const int n);
template int get_leading_dim<ROW>(int dim1, int dim2);
template int get_leading_dim<COL>(int dim1, int dim2);
template int get_leading_dim<COL32>(int dim1, int dim2);
// !!! This is a file automatically generated by hipify!!!
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
#ifndef ops_H
#define ops_H
#include <assert.h>
#include <cstdint>
#include <iostream>
#include <stdio.h>
#include <unistd.h>
#include <functional>
#include <hip/hip_fp16.h>
#include <hip/hip_runtime_api.h>
#include <hipblaslt/hipblaslt.h>
#include <hipsparse/hipsparse.h>
#include <rocblas/rocblas.h>
#include <vector>
#define CUDA_CHECK_RETURN(value) \
{ \
hipError_t _m_cudaStat = value; \
if (_m_cudaStat != hipSuccess) { \
fprintf(stderr, "Error %s at line %d in file %s\n", hipGetErrorString(_m_cudaStat), __LINE__, __FILE__); \
exit(1); \
} \
}
#define CHECK_HIPSPARSE(value) \
{ \
hipsparseStatus_t _m_hipStat = value; \
if (_m_hipStat != HIPSPARSE_STATUS_SUCCESS) { \
fprintf( \
stderr, "Error %s at line %d in file %s\n", hipsparseGetErrorString(_m_hipStat), __LINE__, __FILE__ \
); \
exit(1); \
} \
}
inline void checkHipStatus(hipError_t status) {
if (status != hipSuccess) {
printf("hip API failed with status %d: %s\n", status, hipGetErrorString(status));
throw std::logic_error("hip API failed");
}
}
inline int checkHipblasStatus(hipblasStatus_t status) {
if (status != HIPBLAS_STATUS_SUCCESS) {
printf("hipBLAS API failed with status %d\n", status);
// throw std::logic_error("cuBLAS API failed");
return 1;
}
return 0;
}
typedef enum Operations_t {
ksmul = 0,
} Operations_t;
typedef enum Optimizer_t {
ADAM = 0,
MOMENTUM = 1,
RMSPROP = 2,
LARS = 3,
ADAGRAD = 4,
LION = 5,
ADEMAMIX = 6,
} Optimizer_t;
typedef enum Transform_t {
ROW = 0,
COL = 1,
COL32 = 2,
COL_TURING = 3,
COL_AMPERE = 4,
} Transform_t;
typedef enum DataType_t {
General8bit = 0,
FP4 = 1,
NF4 = 2,
} DataType_t;
typedef enum Funcs_t {
FILL = 0,
ARANGE = 1,
_MUL = 2,
} Funcs_t;
class Context {
public:
rocblas_handle m_handle;
Context() {
rocblas_handle handle;
rocblas_create_handle(&handle);
m_handle = handle;
}
};
class ContextLt {
public:
hipblasLtHandle_t m_handle;
ContextLt() {
hipblasLtHandle_t handle;
hipblasLtCreate(&handle);
m_handle = handle;
}
};
class ContextHipsparse {
public:
hipsparseHandle_t m_handle;
ContextHipsparse() {
hipsparseHandle_t handle;
hipsparseCreate(&handle);
m_handle = handle;
}
};
void quantize(float* code, float* A, unsigned char* out, int n);
void dequantize(float* code, unsigned char* A, float* out, int n, hipStream_t stream);
template <typename T, int STOCHASTIC, int DATA_TYPE>
void quantizeBlockwise(
float* code, T* A, float* absmax, unsigned char* out, float* rand, int rand_offset, int blocksize, const int n
);
template <typename T, int DATA_TYPE>
void dequantizeBlockwise(
float* code, unsigned char* A, float* absmax, T* out, int block_size, const int n, hipStream_t stream
);
template <typename T, int OPTIMIZER>
void optimizer32bit(
T* g, T* p, float* state1, float* state2, float* unorm, float max_unorm, float param_norm, float beta1, float beta2,
float beta3, float alpha, float eps, float weight_decay, int step, float lr, const float gnorm_scale,
bool skip_zeros, int n
);
template <typename T, int OPTIMIZER>
void optimizerStatic8bit(
T* p, T* g, unsigned char* state1, unsigned char* state2, float* unorm, float max_unorm, float param_norm,
float beta1, float beta2, float eps, int step, float lr, float* quantiles1, float* quantiles2, float* max1,
float* max2, float* new_max1, float* new_max2, float weight_decay, const float gnorm_scale, int n
);
template <typename T, int OPTIMIZER>
void optimizerStatic8bitBlockwise(
T* p, T* g, unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha,
float eps, int step, float lr, float* quantiles1, float* quantiles2, float* absmax1, float* absmax2,
float weight_decay, const float gnorm_scale, bool skip_zeros, int n
);
template <typename T> void percentileClipping(T* g, float* gnorm_vec, int step, const int n);
void gemmex(
Context* context, bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda,
int ldb, int ldc
);
void strided_gemmex(
Context* context, bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda,
int ldb, int ldc, long long int strideA, long long int strideB, long long int strideC, int batchCount
);
template <int DTYPE_OUT, int SCALE_ROWS>
int igemmlt(
hipblasLtHandle_t ltHandle, int m, int n, int k, const int8_t* A, const int8_t* B, void* C, float* row_scale,
int lda, int ldb, int ldc, hipStream_t stream
);
void cutlass_igemm(
bool transposeA, bool transposeB, int m, int n, int k, void* A, void* B, void* C, int lda, int ldb, int ldc
);
void dequant_mm_int32_fp16(
int* A, float* rowStats, float* colStats, half* out, half* bias, int numRows, int numCols, hipStream_t stream
);
void getRowStats(half* A, float* rowStats, float threshold, int rows, int cols, hipStream_t stream);
void int8VectorQuant(
half* __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols, hipStream_t stream
);
void spmm_coo(
hipsparseHandle_t handle, int* A_rowidx, int* A_colidx, half* A_vals, int A_nnz, int A_rows, int A_cols, int B_cols,
int ldb, half* B, int ldc, half* C, bool transposed_B
);
template <typename T, int BITS>
void spmm_coo_very_sparse_naive(
int* max_count, int* max_idx, int* offset_rowidx, int* rowidx, int* colidx, half* values, T* B, half* out,
float* dequant_stats, int nnz_rows, int nnz, int rowsA, int rowsB, int colsB
);
void matmul4bite(half* A, unsigned char* B, half* out, int lda, int ldb, int rowsA, int colsA, int colsB);
template <typename T> void gemm_host(int m, int n, int k, T* A, T* B, T* out, int lda, int ldb, int ldc, int bits);
template <typename T>
void gemm_4bit_inference(
int m, int n, int k, T* A, unsigned char* B, float* absmax, T* out, int lda, int ldb, int ldc, int blocksize
);
template <typename T, int BITS>
void gemm_4bit_inference_naive(
int m, int n, int k, T* A, unsigned char* B, float* absmax, float* datatype, T* out, int lda, int ldb, int ldc,
int blocksize, hipStream_t stream
);
template <typename T, int FUNC> void func(T* A, T* B, T value, long n);
#endif
@@ -6,11 +6,29 @@
#if BUILD_CUDA
#include <ops.cuh>
#endif
#if BUILD_HIP
#include <ops_hip.cuh>
#endif
#if BUILD_MPS
// #include <mps_ops.h>
#endif
#include <cpu_ops.h>
// Compatibility between HIP/CUDA APIs
#if BUILD_HIP
#define cudaStream_t hipStream_t
#define __nv_bfloat16 hip_bfloat16
#define cublasLtHandle_t hipblasLtHandle_t
#define ContextCusparse ContextHipsparse
#define cusparseHandle_t hipsparseHandle_t
#define cudaMallocManaged hipMallocManaged
#define cudaMemAttachHost hipMemAttachHost
#define cudaPeekAtLastError hipPeekAtLastError
#define cudaDeviceGetAttribute hipDeviceGetAttribute
#define cudaDevAttrConcurrentManagedAccess hipDeviceAttributeConcurrentManagedAccess
#define cudaMemPrefetchAsync hipMemPrefetchAsync
#endif
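// Illustrative sketch (names are hypothetical, not part of the diff): with the
// aliases above, code written against the CUDA spellings compiles unchanged for
// ROCm, e.g.
//
// static void example_prefetch(void *p, size_t bytes, cudaStream_t stream)
// { cudaMemPrefetchAsync(p, bytes, 0, stream); } // becomes hipMemPrefetchAsync under BUILD_HIP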
// We cannot call templated code from C, so we wrap the template in a C compatible call here if necessary.
// We use macro functions to expand all the different optimizers. Looks ugly, and is ugly, but it's better than
// maintaining all that boilerplate.
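// A minimal sketch of that wrapping pattern (illustrative name, not one of the
// library's actual exported symbols):
//
// extern "C" void cfill_fp32_example(float *A, float *B, float value, long n)
// { func<float, FILL>(A, B, value, n); } // plain C symbol forwarding to the C++ template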
@@ -18,7 +36,7 @@
// UNMANGLED CALLS
//===================================================================================
#if BUILD_CUDA || BUILD_HIP
// void gemm_host_fp32(int M, int N, int K, float * A, float* B, float * out, int lda, int ldb, int ldc)
//{ gemm_host<float>(M, N, K, A, B, out, lda, ldb, ldc, 32); }
@@ -291,7 +309,7 @@ void spmm_coo_very_sparse_naive_int8(
#endif
extern "C" {
#if BUILD_CUDA || BUILD_HIP
void cquantize(float* code, float* A, unsigned char* out, int n) { quantize(code, A, out, n); }
void cdequantize(float* code, unsigned char* A, float* out, int n, cudaStream_t stream) {
......
@@ -7,6 +7,8 @@ from typing import Any
import torch
from bitsandbytes.cextension import HIP_ENVIRONMENT
test_dims_rng = random.Random(42)
@@ -21,7 +23,7 @@ def get_available_devices():
# If the environment variable is set, use it directly.
return [os.environ["BNB_TEST_DEVICE"]]
devices = [] if HIP_ENVIRONMENT else ["cpu"]
if hasattr(torch, "accelerator"):
# PyTorch 2.6+ - determine accelerator using agnostic API.
......
import pytest
from bitsandbytes.cextension import HIP_ENVIRONMENT, get_cuda_bnb_library_path
from bitsandbytes.cuda_specs import CUDASpecs
@@ -13,11 +13,13 @@ def cuda120_spec() -> CUDASpecs:
)
@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm")
def test_get_cuda_bnb_library_path(monkeypatch, cuda120_spec):
monkeypatch.delenv("BNB_CUDA_VERSION", raising=False)
assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda120"
@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm")
def test_get_cuda_bnb_library_path_override(monkeypatch, cuda120_spec, caplog):
monkeypatch.setenv("BNB_CUDA_VERSION", "110")
assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda110"
......
@@ -9,6 +9,7 @@ import torch
import bitsandbytes as bnb
from bitsandbytes import functional as F
from bitsandbytes.cextension import HIP_ENVIRONMENT, ROCM_GPU_ARCH
from tests.helpers import (
BOOLEAN_TUPLES,
TRUE_FALSE,
@@ -92,7 +93,10 @@ class Test8BitBlockwiseQuantizeFunctional:
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
@pytest.mark.parametrize("nested", TRUE_FALSE, ids=id_formatter("nested"))
@pytest.mark.parametrize(
"blocksize",
[4096, 2048, 1024, 512, 256, 128, 64] if not HIP_ENVIRONMENT else [4096, 2048, 1024, 512, 256, 128],
)
@pytest.mark.parametrize("signed", TRUE_FALSE, ids=id_formatter("signed")) @pytest.mark.parametrize("signed", TRUE_FALSE, ids=id_formatter("signed"))
def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize, signed): def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize, signed):
iters = 100 iters = 100
...@@ -823,6 +827,7 @@ class TestLLMInt8Functional: ...@@ -823,6 +827,7 @@ class TestLLMInt8Functional:
torch.testing.assert_close(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2) torch.testing.assert_close(A * (idx == 0), A2, rtol=0.05, atol=1.5e-2)
@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required")
class TestSpMMFunctional:
@pytest.mark.parametrize("dim1", [256, 1024], ids=id_formatter("dim1"))
@@ -1100,7 +1105,10 @@ class TestQuantize4BitFunctional:
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@pytest.mark.parametrize(
"blocksize",
[64, 128, 256, 512, 1024, 2048, 4096] if not HIP_ENVIRONMENT else [128, 256, 512, 1024, 2048, 4096],
)
def test_4bit_quant(self, device, dtype, quant_type, blocksize):
if device == "hpu" and not is_supported_on_hpu(quant_type, dtype):
pytest.skip("This configuration is not supported on HPU.")
@@ -1135,7 +1143,7 @@ class TestQuantize4BitFunctional:
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128], ids=id_formatter("blocksize"))
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=describe_dtype)
def test_4bit_compressed_stats(self, device, quant_type, blocksize, dtype):
if device == "hpu" and not is_supported_on_hpu(quant_type, dtype):
@@ -1201,6 +1209,9 @@ class TestQuantize4BitFunctional:
# torch.cuda.synchronize()
# print((time.time()-t0)/iters*1e6)
@pytest.mark.skipif(
HIP_ENVIRONMENT, reason="gemv 4bit tests are partially enabled on MI300, others being fixed for warpsize 64"
)
@pytest.mark.parametrize("device", get_available_devices()) @pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("double_quant", TRUE_FALSE, ids=lambda double_quant: f"DQ_{double_quant}") @pytest.mark.parametrize("double_quant", TRUE_FALSE, ids=lambda double_quant: f"DQ_{double_quant}")
@pytest.mark.parametrize("storage_type", ["nf4", "fp4"]) @pytest.mark.parametrize("storage_type", ["nf4", "fp4"])
...@@ -1361,6 +1372,10 @@ class TestQuantize4BitFunctional: ...@@ -1361,6 +1372,10 @@ class TestQuantize4BitFunctional:
@pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"]) @pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
@pytest.mark.parametrize("double_quant", [False], ids=["DQ_True"]) @pytest.mark.parametrize("double_quant", [False], ids=["DQ_True"])
@pytest.mark.skipif(
HIP_ENVIRONMENT and ROCM_GPU_ARCH == "gfx90a",
reason="this test is not supported on ROCm with gfx90a architecture yet",
)
def test_gemv_eye_4bit(self, device, storage_type, dtype, double_quant):
if device == "cpu" and dtype == torch.bfloat16 and torch.__version__ < (2, 3):
pytest.skip("eye does not support bfloat16 on CPU in torch < 2.3")
......
@@ -8,6 +8,7 @@ import pytest
import torch
import bitsandbytes as bnb
from bitsandbytes.cextension import HIP_ENVIRONMENT
from tests.helpers import (
TRUE_FALSE,
describe_dtype,
@@ -191,7 +192,7 @@ def test_linear_serialization(
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128])
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
def test_copy_param(device, quant_type, blocksize, compress_statistics):
if device == "hpu" and not is_supported_on_hpu(quant_type):
@@ -213,7 +214,7 @@ def test_copy_param(device, quant_type, blocksize, compress_statistics):
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128])
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
def test_deepcopy_param(device, quant_type, blocksize, compress_statistics):
if device == "hpu" and not is_supported_on_hpu(quant_type):
@@ -242,7 +243,7 @@ def test_deepcopy_param(device, quant_type, blocksize, compress_statistics):
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128])
@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
def test_params4bit_real_serialization(device, quant_type, blocksize, compress_statistics):
if device == "hpu" and not is_supported_on_hpu(quant_type):
......