OpenDAS / bitsandbytes

Commit b7f04e2a, authored May 30, 2023 by Tim Dettmers
Parent: ac5550a0

    Added lookup table.

Showing 3 changed files, with 19 additions and 6 deletions:

    Makefile                  +2 -2
    csrc/kernels.cu           +9 -0
    tests/test_functional.py  +8 -4
Makefile

@@ -47,8 +47,8 @@ CC_cublasLt110 := -gencode arch=compute_75,code=sm_75
 CC_cublasLt110 += -gencode arch=compute_80,code=sm_80

 CC_cublasLt111 := -gencode arch=compute_75,code=sm_75
-CC_cublasLt111 += -gencode arch=compute_80,code=sm_80
-CC_cublasLt111 += -gencode arch=compute_86,code=sm_86
+# CC_cublasLt111 += -gencode arch=compute_80,code=sm_80
+# CC_cublasLt111 += -gencode arch=compute_86,code=sm_86

 CC_ADA_HOPPER := -gencode arch=compute_89,code=sm_89
 CC_ADA_HOPPER += -gencode arch=compute_90,code=sm_90
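Each -gencode arch=compute_XX,code=sm_XX pair tells nvcc to emit binary code for one GPU architecture, so commenting out the compute_80/compute_86 lines drops the Ampere binaries from the cublasLt111 build target (the Ada and Hopper targets live in the separate CC_ADA_HOPPER variable). As a hedged aside, here is a quick way to check which sm_XX target your own GPU needs; torch.cuda.get_device_capability is standard PyTorch, the rest is illustrative:

    import torch

    if torch.cuda.is_available():
        # Returns the compute capability, e.g. (8, 6) for an RTX 3090-class GPU,
        # which would require -gencode arch=compute_86,code=sm_86 in the build.
        major, minor = torch.cuda.get_device_capability(0)
        print(f"this GPU needs sm_{major}{minor}")
    else:
        print("no CUDA device visible")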
csrc/kernels.cu

@@ -3297,6 +3297,7 @@ template <typename T, int BITS, int THREADS> __global__ void gemm_device(int M,
 #endif
 }

+__device__ static float nf4_data[16] = {-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0};
 template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, int N, int K, T * __restrict__ const A, unsigned char *B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize)
 {

@@ -3308,6 +3309,12 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
   const int half_warp_lane = threadIdx.x % 16;
   const int batch_size_warps = (WARPS-1)*2;
+
+  T quant_map[16];
+  #pragma unroll 16
+  for(int i = 0; i < 16; i++)
+    quant_map[i] = nf4_data[i];
+
   T local_A[2];
   T local_B[64];
   unsigned char local_B_4bit[32];

@@ -3410,6 +3417,8 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference(int M, i
   {
     local_B[col] = dhDequantizeNF4(local_B_4bit[col/2] >> 4)*T(absidx);
     local_B[col+1] = dhDequantizeNF4(local_B_4bit[col/2] & 0x0F)*T(absidx);
+    //local_B[col] = quant_map[(local_B_4bit[col/2] >> 4)]*T(absidx);
+    //local_B[col+1] = quant_map[(local_B_4bit[col/2] & 0x0F)]*T(absidx);
   }
 }
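This commit adds nf4_data, the 16-entry NF4 (4-bit NormalFloat) codebook, as a device-side lookup table, and has each thread of kgemm_4bit_inference copy it into a local quant_map array. The active dequantization path still calls dhDequantizeNF4; the two commented-out lines show the intended replacement, where each packed byte of B yields two codes (high nibble first, then low nibble) that index the table and are scaled by the block's absmax. Below is a minimal Python/PyTorch sketch of that lookup semantics, with hypothetical names (dequantize_nf4_lut is not the bitsandbytes API; the codebook values are copied from nf4_data):

    import torch

    # The 16 NF4 codebook values from nf4_data in csrc/kernels.cu.
    NF4_DATA = torch.tensor([
        -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453,
        -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0,
        0.07958029955625534, 0.16093020141124725, 0.24611230194568634,
        0.33791524171829224, 0.44070982933044434, 0.5626170039176941,
        0.7229568362236023, 1.0])

    def dequantize_nf4_lut(packed, absmax, blocksize=64):
        # packed: 1-D uint8 tensor, two 4-bit codes per byte (high nibble first).
        # absmax: one float scale per block of blocksize dequantized values.
        hi = (packed >> 4).long()      # mirrors: quant_map[local_B_4bit[col/2] >> 4]
        lo = (packed & 0x0F).long()    # mirrors: quant_map[local_B_4bit[col/2] & 0x0F]
        codes = torch.stack([hi, lo], dim=-1).flatten()   # col, col+1 ordering
        scales = absmax.repeat_interleave(blocksize)      # per-block absmax scaling
        return NF4_DATA[codes] * scales

A register-resident table lookup replaces the branch-heavy comparison tree inside dhDequantizeNF4 with a single indexed load per value, which is presumably the motivation for staging the table here even though the lookup path is still disabled.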
tests/test_functional.py

@@ -2297,7 +2297,8 @@ def test_4bit_compressed_stats(quant_type):
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
-@pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
+#@pytest.mark.parametrize("quant_type", ['fp4', 'nf4'])
+@pytest.mark.parametrize("quant_type", ['nf4'])
 def test_bench_4bit_dequant(quant_type):
     blocksize = 256
     a = torch.rand(1024*12*4, 1024*12, device='cuda').half()

@@ -2311,7 +2312,7 @@ def test_bench_4bit_dequant(quant_type):
     #print(max_theoretical_s*1e6)
     b = torch.randn(128, 1024*12, device='cuda').half()
-    iters = 5
+    iters = 100
     torch.cuda.synchronize()
     t0 = time.time()
     for i in range(iters):

@@ -2438,9 +2439,11 @@ def test_gemm_4bit(dtype):
         C3 = torch.matmul(A, B.t())
         C2 = F.cutlass3_gemm(A, qB.t(), state=state)
         C1 = bnb.matmul_4bit(A, qB.t(), state)
-        print(C1.shape, C2.shape)
+        C2 = F.cutlass3_gemm(A, qB.t(), state=state)
+        print(C1)
+        print(C2)
+        #print(C1.shape, C2.shape)
         # tensor cores are non-deterministic
         # so we need to analyze errors around the mean

@@ -2452,6 +2455,7 @@ def test_gemm_4bit(dtype):
         max_relerr = max(relerr.max(), max_relerr)
         err = err.mean().item()
         relerr = relerr.mean().item()
+        print(err)
         errs.append(err)
         relerrs.append(relerr)
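On the test side, the dequantization benchmark is narrowed to nf4 and iters is raised from 5 to 100, which amortizes kernel-launch overhead and steadies the average, while test_gemm_4bit gains a second cutlass3_gemm call plus prints of C1, C2, and the mean error; as the existing comments note, tensor-core results are non-deterministic, so errors are judged around the mean rather than element by element. The timing idiom the benchmark depends on is sketched below; bench and run_kernel are illustrative names, not part of the test suite:

    import time
    import torch

    def bench(run_kernel, iters=100):
        # Drain any queued CUDA work so it is not billed to the timed window.
        torch.cuda.synchronize()
        t0 = time.time()
        for _ in range(iters):
            run_kernel()
        # Kernel launches are asynchronous; wait for completion before reading
        # the clock, otherwise only launch overhead would be measured.
        torch.cuda.synchronize()
        return (time.time() - t0) / iters  # mean seconds per iteration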