OpenDAS / bitsandbytes

Commit b7f04e2a, authored May 30, 2023 by Tim Dettmers
Parent: ac5550a0

    Added lookup table.
Showing 3 changed files with 19 additions and 6 deletions (+19 -6).
Makefile                  +2 -2
csrc/kernels.cu           +9 -0
tests/test_functional.py  +8 -4
Makefile
@@ -47,8 +47,8 @@ CC_cublasLt110 := -gencode arch=compute_75,code=sm_75
 CC_cublasLt110 += -gencode arch=compute_80,code=sm_80
 CC_cublasLt111 := -gencode arch=compute_75,code=sm_75
-CC_cublasLt111 += -gencode arch=compute_80,code=sm_80
-CC_cublasLt111 += -gencode arch=compute_86,code=sm_86
+# CC_cublasLt111 += -gencode arch=compute_80,code=sm_80
+# CC_cublasLt111 += -gencode arch=compute_86,code=sm_86
 CC_ADA_HOPPER := -gencode arch=compute_89,code=sm_89
 CC_ADA_HOPPER += -gencode arch=compute_90,code=sm_90
csrc/kernels.cu
@@ -3297,6 +3297,7 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
 #endif
 }
 
+__device__ static float nf4_data[16] = {-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0};
 template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
 {
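The nf4_data table is the lookup table this commit adds: the 16 NormalFloat-4 (NF4) quantization levels, indexed by a 4-bit code, with 0.0 at code 7 and the endpoints at -1.0 and 1.0. Dequantization is then a table lookup scaled by the block's absmax. A minimal Python sketch of that lookup, using the exact values from the diff (the helper name and shapes are illustrative, not the bitsandbytes API):

import torch

# The 16 NF4 levels added in this commit, indexed by the 4-bit code 0..15.
NF4_DATA = [
    -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
    -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
    0.07958029955625534, 0.16093020141124725, 0.24611230194568634,
    0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
    0.7229568362236023, 1.0,
]

def dequantize_nf4(codes: torch.Tensor, absmax: float) -> torch.Tensor:
    """Dequantize 4-bit NF4 codes: table lookup, then scale by the block absmax."""
    table = torch.tensor(NF4_DATA)
    return table[codes.long()] * absmax

# Code 7 maps to 0.0; codes 0 and 15 map to -absmax and +absmax.
print(dequantize_nf4(torch.tensor([0, 7, 15]), absmax=0.5))  # tensor([-0.5000, 0.0000, 0.5000])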
@@ -3308,6 +3309,12 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
   const int half_warp_lane = threadIdx.x % 16;
   const int batch_size_warps = (WARPS-1)*2;
 
+  T quant_map[16];
+
+  #pragma unroll 16
+  for(int i = 0; i < 16; i++)
+    quant_map[i] = nf4_data[i];
+
   T local_A[2];
   T local_B[64];
   unsigned char local_B_4bit[32];
@@ -3410,6 +3417,8 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
     {
       local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(absidx);
       local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(absidx);
+      //local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(absidx);
+      //local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(absidx);
    }
  }
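Each byte of local_B_4bit packs two codes: the high nibble is read with >> 4 and the low nibble with & 0x0F, and the commented-out lines stage the switch from the arithmetic dhDequantizeNF4 decode to the new quant_map lookup. The same unpacking in Python (a hypothetical helper that mirrors the kernel's indexing, not a bitsandbytes function):

import torch

def unpack_nf4_pairs(packed: torch.Tensor) -> torch.Tensor:
    """Split each uint8 into its two 4-bit codes: high nibble first, low second."""
    high = packed >> 4        # matches local_B_4bit[col/2] >> 4
    low = packed & 0x0F       # matches local_B_4bit[col/2] & 0x0F
    return torch.stack([high, low], dim=-1).flatten()

packed = torch.tensor([0xF7, 0x08], dtype=torch.uint8)
print(unpack_nf4_pairs(packed))  # tensor([15, 7, 0, 8], dtype=torch.uint8)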
tests/test_functional.py
@@ -2297,7 +2297,8 @@ def test_4bit_compressed_stats(quant_type):
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
-@pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
+#@pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
+@pytest.mark.parametrize("quant_type", ['nf4'])
 def test_bench_4bit_dequant(quant_type):
     blocksize = 256
     a = torch.rand(1024*12*4, 1024*12, device='cuda').half()
@@ -2311,7 +2312,7 @@ def test_bench_4bit_dequant(quant_type):
     #print(max_theoretical_s*1e6)
     b = torch.randn(128, 1024*12, device='cuda').half()
 
-    iters = 5
+    iters = 100
     torch.cuda.synchronize()
     t0 = time.time()
     for i in range(iters):
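Raising iters to 100 gives the timing loop enough repetitions to average out launch jitter. The surrounding pattern is the standard way to time CUDA work from Python: synchronize before reading the start time and again before reading the end time, because kernel launches are asynchronous. A self-contained sketch of that pattern (a plain matmul stands in for the dequant call being benchmarked):

import time
import torch

def bench_cuda(fn, iters=100):
    """Average per-iteration wall time of a CUDA op, with proper synchronization."""
    torch.cuda.synchronize()   # drain previously queued kernels before t0
    t0 = time.time()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()   # wait for the timed kernels to actually finish
    return (time.time() - t0) / iters

if torch.cuda.is_available():
    a = torch.rand(4096, 4096, device='cuda').half()
    print(f"{bench_cuda(lambda: a @ a) * 1e6:.1f} us/iter")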
@@ -2438,9 +2439,11 @@ def test_gemm_4bit(dtype):
     C3 = torch.matmul(A, B.t())
     C2 = F.cutlass3_gemm(A, qB.t(), state=state)
     C1 = bnb.matmul_4bit(A, qB.t(), state)
+    C2 = F.cutlass3_gemm(A, qB.t(), state=state)
-    print(C1.shape, C2.shape)
+    print(C1)
+    print(C2)
+    #print(C1.shape, C2.shape)
 
     # tensor cores are non-deterministic
     # so we need to analyze errors around the mean
@@ -2452,6 +2455,7 @@ def test_gemm_4bit(dtype):
     max_relerr = max(relerr.max(), max_relerr)
     err = err.mean().item()
     relerr = relerr.mean().item()
+    print(err)
 
     errs.append(err)
     relerrs.append(relerr)
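Because tensor-core accumulation order is non-deterministic, the test aggregates mean absolute and mean relative error per run instead of asserting exact equality. A small sketch of that metric, using the same mean()/item() reduction as the hunk above (the threshold value is made up for illustration, not the test's actual tolerance):

import torch

def mean_errors(C_ref: torch.Tensor, C_test: torch.Tensor):
    """Mean absolute and mean relative error between a reference and a test result."""
    err = (C_ref.float() - C_test.float()).abs()
    relerr = err / C_ref.float().abs().clamp(min=1e-8)  # guard against division by zero
    return err.mean().item(), relerr.mean().item()

C_ref = torch.randn(64, 64)
C_test = C_ref + 1e-3 * torch.randn_like(C_ref)   # simulate accumulation noise
err, relerr = mean_errors(C_ref, C_test)
assert err < 1e-2  # illustrative threshold only
print(err, relerr)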