Unverified commit b72b766e authored by pnunna93, committed by GitHub

Fix for warpSize deprecation in ROCm 7.0 (#1762)



* Port ROCm changes from multi-backend-refactor branch

* Update ops.py

* Update functional.py

* Update ops.py

* Update ops.py

* Update ops.py

* Update ops.py

* Update functional.py

* Update ops.py

* Update ops.py

* Update ops.py

* Update ops.py

* Update functional.py

* Update functional.py

* Update functional.py

* Update functional.py

* Update ops.py

* Update ops.py

* Update ops.py

* Update ops.py

* Update ops.py

* Update ops.py

* Update ops.py

* Update ops.py

* Update ops.py

* Update functional.py

* Update functional.py

* Update functional.py

* Update test_ops.py

* Update test_functional.py

* Update test_ops.py

* Update test_functional.py

* Update test_functional.py

* Update functional.py

* Update functional.py

* Update ops.py

* Update ops.py

* Update test_functional.py

* Update test_functional.py

* Update cextension.py

* Update cuda_specs.py

* Update cuda_specs.py

* Update test_functional.py

* Update test_linear4bit.py

* Update test_cuda_setup_evaluator.py

* Update test_functional.py

* Update modules.py

* Update modules.py

* Update ops.py

* Update test_linear4bit.py

* Update ops.py

* Update ops.py

* Update test_linear4bit.py

* Update test_linear4bit.py

* Update python-package.yml

* Update python-package.yml

* Update python-package.yml

* Update python-package.yml

* Create build-rocm.sh

* Update cuda_specs.py

* Fix trailing whitespace

* Remove conflicts.diff

* update for hipblasVersionMajor >=3

* Update test_functional.py

* Update test_linear4bit.py

* Update test_ops.py

* Update main.py

* Update test_functional.py

* Update test_linear4bit.py

* Update test_ops.py

* Update test_linear4bit.py

* Lint

* Lint

* Update helpers.py

* Update test_functional.py

* Update test_linear4bit.py

* Update test_ops.py

* Lint

* Update pythonInterface.cpp

* lint fix

* lint

* Update pythonInterface.cpp

* revert permissions change

* Fix indentation

* Update kernels_hip.cuh

* Update kernels.hip

* Update ops.hip

* Update ops_hip.cuh

* Update kernels_hip.cuh

* Update kernels.hip

* Update kernels.hip

* Update ops.hip

* Update ops_hip.cuh

* Update ops.hip

* Update CMakeLists.txt

* Update functional.py

* Update cextension.py

* Update cextension.py

* warpSize is being made non-constexpr in ROCm 7.0 (see the WARP_SIZE sketch after the commit metadata below)

* Merge pull request #90 from ROCm/IFU-rocm_enabled-09-23-2025

Ifu rocm enabled 09 23 2025

* Fix typo

* unskip test_4bit_quant

---------
Co-authored-by: MISHANMAURYA <118961433+MISHANMAURYA@users.noreply.github.com>
Co-authored-by: MISHANMAUYRA <mishanmaurya31081@gmail.com>
Co-authored-by: amcamd <andrew.chapman@amd.com>
Co-authored-by: Prasanth Nunna <root@banff-cyxtera-s78-1.amd.com>
Co-authored-by: sstamenk <strahinja.stamenkovic@amd.com>
parent bdb8b2b7
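For readers skimming the diff: the central issue is that ROCm 7.0 turns warpSize into a runtime device variable rather than a compile-time constant, so it can no longer appear where a constant expression is required, such as a template argument (hipcub::WarpReduce<float, warpSize>) or a static array extent. The patch switches those sites to a preprocessor WARP_SIZE selected per target architecture. Below is a minimal standalone C++ sketch of that pattern; it is not part of the diff, the main driver is purely illustrative, and only __GFX9__ and the macro values mirror the changes that follow.

#include <cstdio>

// Compile-time warp size per target, mirroring the diff: gfx9 (CDNA/Vega)
// wavefronts are 64 lanes wide; other targets are assumed to be 32 here.
#if defined(__GFX9__)
#define WARP_SIZE 64
#else
#define WARP_SIZE 32
#endif

// Anything that feeds a template parameter or a static array extent needs a
// constant expression, e.g. (shown only as a comment, from the kernel below):
//   typedef hipcub::WarpReduce<float, WARP_SIZE> WarpReduce;
//   __shared__ typename WarpReduce::TempStorage temp_storage[THREADS / WARP_SIZE];
// A non-constexpr warpSize cannot be used in those positions.

static_assert(WARP_SIZE == 32 || WARP_SIZE == 64, "unexpected warp size");

int main() {
    printf("compile-time WARP_SIZE = %d\n", WARP_SIZE);
    return 0;
}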
@@ -70,6 +70,7 @@ elseif(${COMPUTE_BACKEND} STREQUAL "xpu")
         message(FATAL_ERROR "XPU is not supported on macOS" )
     endif()
     set(BUILD_CUDA OFF)
+    set(BUILD_HIP OFF)
     set(BUILD_MPS OFF)
     set(BUILD_XPU ON)
 else()
......
@@ -19,37 +19,42 @@
 #define NUM 4
 #define NUM_BLOCK 4096
-__device__ static float nf4_data[16] = {-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0};
+__device__ static float fp4_dequantization_lut[8] = {
+    0.0f,            // 0b000
+    0.005208333333f, // 0b001
+    0.66666667f,     // 0b010
+    1.0f,            // 0b011
+    0.33333333f,     // 0b100
+    0.5f,            // 0b101
+    0.16666667f,     // 0b110
+    0.25f            // 0b111
+};
+__device__ static float nf4_dequantization_lut[16] = {
+    -1.0f,                 // 0b0000
+    -0.6961928009986877f,  // 0b0001
+    -0.5250730514526367f,  // 0b0010
+    -0.39491748809814453f, // 0b0011
+    -0.28444138169288635f, // 0b0100
+    -0.18477343022823334f, // 0b0101
+    -0.09105003625154495f, // 0b0110
+    0.0f,                  // 0b0111
+    0.07958029955625534f,  // 0b1000
+    0.16093020141124725f,  // 0b1001
+    0.24611230194568634f,  // 0b1010
+    0.33791524171829224f,  // 0b1011
+    0.44070982933044434f,  // 0b1100
+    0.5626170039176941f,   // 0b1101
+    0.7229568362236023f,   // 0b1110
+    1.0f                   // 0b1111
+};
 // source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda
 // Luckily we have atomicmax and atomicmin in ROCm
-__device__ float dDequantizeFP4Tree(unsigned char val, float absmax)
-{
-    float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f;
-    if((val & 0b0100) == 4) // 0
-        if((val & 0b0010) == 2) //01
-            if((val & 0b0001) == 1) // 111
-                return 0.25000000f*absmax*sign; // 1111
-            else
-                return 0.16666667f*absmax*sign; // 1110
-        else
-            if((val & 0b0001) == 1) // 110
-                return 0.50000000f*absmax*sign; // 1101
-            else
-                return 0.33333333f*absmax*sign; // 1100
-    else
-        if((val & 0b0010) == 2) //10
-            if((val & 0b0001) == 1) // 101
-                return 1.00000000f*absmax*sign; // 1011
-            else
-                return 0.66666667f*absmax*sign; // 1010
-        else
-            if((val & 0b0001) == 1) // 100
-                return 5.208333333e-03f*absmax*sign; // 1001
-            else
-                return 0.00000000f*absmax*sign; // 1000
-}
+__device__ __forceinline__ float dDequantizeFP4Tree(unsigned char val) {
+    float sign = 1.0f - 2 * ((val & 0b1000) >> 3);
+    return fp4_dequantization_lut[val & 0b111] * sign;
+}
 __device__ unsigned char dQuantizeFP4(float x)
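The hunk above replaces the FP4 decode tree with a sign bit plus an 8-entry magnitude table (and, further down, the NF4 tree with a direct 16-entry lookup). The following standalone host-side sketch is not part of the diff; the helper name is hypothetical, and only the table values are taken from the change above.

#include <cstdio>

// Magnitude table copied from the diff; the low three bits of an FP4 code index it.
static const float fp4_dequantization_lut[8] = {
    0.0f, 0.005208333333f, 0.66666667f, 1.0f,
    0.33333333f, 0.5f, 0.16666667f, 0.25f,
};

// Hypothetical host-side mirror of dDequantizeFP4Tree: bit 3 carries the sign,
// bits 0..2 pick the magnitude; the kernel multiplies by absmax afterwards.
static float dequantize_fp4(unsigned char val) {
    float sign = 1.0f - 2.0f * ((val & 0b1000) >> 3);  // 0 -> +1, 1 -> -1
    return fp4_dequantization_lut[val & 0b111] * sign;
}

int main() {
    printf("%f %f %f\n",
           dequantize_fp4(0b0011),   // +1.0
           dequantize_fp4(0b1011),   // -1.0
           dequantize_fp4(0b0110));  // +0.16666667
    return 0;
}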
@@ -101,61 +106,7 @@ __device__ unsigned char dQuantizeFP4(float x)
     return 0b0000+sign;
 }
-__device__ __forceinline__ float dDequantizeNF4(unsigned char val)
-{
-    // the values for this tree was generated by test_normal_map_tree
-    // in the file tests/test_functional.py
-    if((val & 0b1000) == 8)
-        if((val & 0b0100) == 4) // 1
-            if((val & 0b0010) == 2) // 11
-                if((val & 0b0001) == 1) // 111
-                    return 1.0f;
-                else
-                    return 0.7229568362236023f;
-            else
-                if((val & 0b0001) == 1) // 110
-                    return 0.5626170039176941f;
-                else
-                    return 0.44070982933044434f;
-        else
-            if((val & 0b0010) == 2) //10
-                if((val & 0b0001) == 1) // 101
-                    return 0.33791524171829224f;
-                else
-                    return 0.24611230194568634f;
-            else
-                if((val & 0b0001) == 1) // 100
-                    return 0.16093020141124725f;
-                else
-                    return 0.07958029955625534f;
-    else
-        if((val & 0b0100) == 4) // 0
-            if((val & 0b0010) == 2) //01
-                if((val & 0b0001) == 1) // 011
-                    return 0.0f;
-                else
-                    return -0.09105003625154495f;
-            else
-                if((val & 0b0001) == 1) // 010
-                    return -0.18477343022823334f;
-                else
-                    return -0.28444138169288635f;
-        else
-            if((val & 0b0010) == 2) //00
-                if((val & 0b0001) == 1) // 001
-                    return -0.39491748809814453f;
-                else
-                    return -0.5250730514526367f;
-            else
-                if((val & 0b0001) == 1) // 000
-                    return -0.6961928009986877f;
-                else
-                    return -1.0f;
-}
+__device__ __forceinline__ float dDequantizeNF4(unsigned char val) { return nf4_dequantization_lut[val & 0x0F]; }
 __device__ unsigned char dQuantizeNF4(float x)
 {
@@ -456,7 +407,6 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
         LoadFloat(loadf).Load(&rand[local_rand_idx], rand_vals, BLOCK_SIZE, 0);
     }
-    unsigned char packed_4bit = 0;
     switch(DATA_TYPE)
     {
         case General8bit:
@@ -473,18 +423,16 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
             #pragma unroll NUM_PER_TH
             for(int j = 0; j < NUM_PER_TH/2; j++)
             {
-                packed_4bit |= dQuantizeFP4(((float)vals[2*j])*local_abs_max) << 4;
-                packed_4bit |= dQuantizeFP4(((float)vals[2*j+1])*local_abs_max);
-                qvals[j] = packed_4bit;
+                qvals[j] = dQuantizeFP4(((float)vals[2 * j]) * local_abs_max) << 4;
+                qvals[j] |= dQuantizeFP4(((float)vals[2 * j + 1]) * local_abs_max);
             }
             break;
         case NF4:
             #pragma unroll NUM_PER_TH
             for(int j = 0; j < NUM_PER_TH/2; j++)
             {
-                packed_4bit |= dQuantizeNF4(((float)vals[2*j])*local_abs_max) << 4;
-                packed_4bit |= dQuantizeNF4(((float)vals[2*j+1])*local_abs_max);
-                qvals[j] = packed_4bit;
+                qvals[j] = dQuantizeNF4(((float)vals[2 * j]) * local_abs_max) << 4;
+                qvals[j] |= dQuantizeNF4(((float)vals[2 * j + 1]) * local_abs_max);
             }
             break;
     }
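Besides the renamed assignments, the hunk above changes how two 4-bit codes are packed per output byte: the old shared packed_4bit accumulator was only OR-ed and never cleared between loop iterations, so when a thread packed more than one byte, bits from earlier codes could bleed into later ones. Writing into qvals[j] directly (assign the high nibble, then OR in the low nibble) sidesteps that. A tiny standalone illustration of the packing follows; it is not part of the diff and the helper name is hypothetical.

#include <cstdio>

// Pack two 4-bit quantization codes into one byte, high nibble first,
// the way the updated loop writes qvals[j].
static unsigned char pack_two_codes(unsigned char hi, unsigned char lo) {
    unsigned char q = (unsigned char)((hi & 0x0F) << 4);  // element 2*j -> high nibble
    q |= (unsigned char)(lo & 0x0F);                       // element 2*j + 1 -> low nibble
    return q;
}

int main() {
    unsigned char q = pack_two_codes(0b0011, 0b1010);
    printf("packed 0x%02X -> high %u, low %u\n", q, q >> 4, q & 0x0F);  // 0x3A -> 3, 10
    return 0;
}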
@@ -546,8 +494,8 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * abs
             #pragma unroll NUM_PER_TH
             for(int j = 0; j < NUM_PER_TH; j++)
             {
-                vals[j*2] = dDequantizeFP4Tree(qvals[j] >> 4, local_abs_max);
-                vals[j*2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F, local_abs_max);
+                vals[j * 2] = dDequantizeFP4Tree(qvals[j] >> 4) * local_abs_max;
+                vals[j * 2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F) * local_abs_max;
             }
             break;
         case NF4:
@@ -2109,7 +2057,11 @@ __global__ void kdequant_mm_int32_fp16(
 #define DENORM 1.0f/127.0f
 #define MAX_SPARSE_COUNT 32
 #define SMEM_SIZE 8*256
-#define WARP_SIZE warpSize
+#if defined(__GFX9__)
+#define WARP_SIZE 64
+#else
+#define WARP_SIZE 32
+#endif
 template <typename T, int SPMM_ITEMS, int BITS>
 __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB)
 {
@@ -2503,7 +2455,7 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
     #pragma unroll 16
     for(int i = 0; i < 16; i++)
-        quant_map[i] = nf4_data[i];
+        quant_map[i] = nf4_dequantization_lut[i];
     //__shared__ T quant_map[16*160];
     T local_A[2];
@@ -2708,13 +2660,13 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
     // load step-by-step in chunks of [warp_size,warps]: 1xwarp_size * [warp_size,warps] -> [1,warps]
     // 4 warps -> 4 loads per iter
     // 1xwarp_size * warp_sizex4 -> 1x4 outputs per thread block
-    typedef hipcub::WarpReduce<float, warpSize> WarpReduce;
-    __shared__ typename WarpReduce::TempStorage temp_storage[THREADS/warpSize];
-    const int warp_idx = threadIdx.x / warpSize;
-    const int warp_lane = threadIdx.x % warpSize;
-    const int row_B = (THREADS/warpSize)*blockIdx.x + warp_idx;
-    const int offset_B = ldb*row_B;
+    typedef hipcub::WarpReduce<float, WARP_SIZE> WarpReduce;
+    __shared__ typename WarpReduce::TempStorage temp_storage[THREADS/WARP_SIZE];
+    const int warp_idx = threadIdx.x / WARP_SIZE;
+    const int warp_lane = threadIdx.x % WARP_SIZE;
+    const int row_B = (THREADS/WARP_SIZE)*blockIdx.x + warp_idx;
+    const int offset_B = ldb * row_B;
     const int num_values_8bit = num_values_4bit/2;
     float local_C = 0.0f;
@@ -2732,7 +2684,7 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
     // A: [1, K]
     // B: [M, K]
-    for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += warpSize*num_values_4bit)
+    for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += WARP_SIZE*num_values_4bit)
     {
         const int inner_idx_halved = inner_idx/2;
......
@@ -20,6 +20,12 @@
 #define ERR_NOT_IMPLEMENTED 100
+#if defined(__GFX9__)
+#define WARP_SIZE 64
+#else
+#define WARP_SIZE 32
+#endif
 using namespace BinSearch;
 using std::cout;
 using std::endl;
@@ -692,7 +698,7 @@ template <typename T, int BITS> void gemm_4bit_inference_naive(int m, int n, int
     //warpsize - 32
     int num_blocks = (m+3)/4;
     //warpsize - 64
-    if (warpSize == 64) {
+    if (WARP_SIZE == 64) {
         num_blocks = (m+1)/2;
     }
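A hedged reading of the block-count change above: the naive 4-bit GEMM assigns one warp (wavefront) per output row, so the grid needs ceil(m / warps_per_block) blocks. Assuming the kernel's 128-thread blocks (THREADS / WARP_SIZE warps each), a 32-lane warp gives 4 warps per block, hence (m + 3) / 4 == ceil(m / 4); a 64-lane wavefront gives 2 warps per block, hence (m + 1) / 2 == ceil(m / 2).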
......
@@ -10,7 +10,7 @@ import torch
 import bitsandbytes as bnb
 from bitsandbytes import functional as F
-from bitsandbytes.cextension import HIP_ENVIRONMENT, ROCM_GPU_ARCH
+from bitsandbytes.cextension import HIP_ENVIRONMENT
 from tests.helpers import (
     BOOLEAN_TUPLES,
     TRUE_FALSE,
@@ -463,6 +463,7 @@ class TestIGEMMFunctional:
     @pytest.mark.parametrize("hidden_dim", [32, 1024 * 4], ids=id_formatter("hidden_dim"))
     @pytest.mark.parametrize("batch_dim", [2, 16], ids=id_formatter("batch_dim"))
     @pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose"))
+    @pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
     def test_minmax_igemm(self, seq_dim, hidden_dim, batch_dim, transpose):
         def min_max(x):
             maxA = torch.amax(x, dim=2, keepdim=True)
@@ -1408,10 +1409,7 @@ class TestQuantize4BitFunctional:
     @pytest.mark.parametrize("device", get_available_devices())
     @pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"])
     @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
-    @pytest.mark.skipif(
-        HIP_ENVIRONMENT and ROCM_GPU_ARCH == "gfx90a",
-        reason="this test is not supported on ROCm with gfx90a architecture yet",
-    )
+    @pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
     def test_gemv_eye_4bit(self, device, storage_type, dtype):
         if device == "hpu" and not is_supported_on_hpu(storage_type, dtype):
             pytest.skip("This configuration is not supported on HPU.")
......
@@ -9,6 +9,7 @@ import pytest
 import torch
 import bitsandbytes as bnb
+from bitsandbytes.cextension import HIP_ENVIRONMENT
 from bitsandbytes.nn.modules import Linear8bitLt
 from tests.helpers import (
     TRUE_FALSE,
@@ -233,6 +234,7 @@ def test_linear8bit_serialization(linear8bit):
 @pytest.mark.parametrize("fullgraph", TRUE_FALSE, ids=id_formatter("fullgraph"))
 @pytest.mark.parametrize("mode", ["default", "reduce-overhead"], ids=id_formatter("mode"))
 @pytest.mark.skipif(torch.__version__ < (2, 4), reason="Not supported in torch < 2.4")
+@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
 def test_linear8bitlt_torch_compile(device, threshold, bias, fullgraph, mode):
     if device == "cuda" and platform.system() == "Windows":
         pytest.skip("Triton is not officially supported on Windows")
......
@@ -211,6 +211,7 @@ class Test4bitBlockwiseQuantOps:
     @pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))
     @pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
     @pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512])
+    @pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
     def test_gemv_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
         if device == "hpu" and not is_supported_on_hpu(quant_type, dtype, storage_dtype):
             pytest.skip("This configuration is not supported on HPU.")
......