Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
94168d79
Commit
94168d79
authored
Jul 09, 2023
by
Tim Dettmers
Browse files
Added FP4 fast inference support.
parent
4b88d69d
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
14 additions
and
16 deletions
+14
-16
bitsandbytes/autograd/_functions.py
bitsandbytes/autograd/_functions.py
+2
-2
bitsandbytes/functional.py
bitsandbytes/functional.py
+1
-2
csrc/kernels.cu
csrc/kernels.cu
+2
-4
tests/test_functional.py
tests/test_functional.py
+9
-8
No files found.
bitsandbytes/autograd/_functions.py
View file @
94168d79
...
@@ -509,7 +509,7 @@ class MatMul4Bit(torch.autograd.Function):
...
@@ -509,7 +509,7 @@ class MatMul4Bit(torch.autograd.Function):
# 1. Dequantize
# 1. Dequantize
# 2. Matmul
# 2. Matmul
output
=
torch
.
nn
.
functional
.
linear
(
A
,
F
.
dequantize_
fp4
(
B
,
state
).
to
(
A
.
dtype
).
t
(),
bias
)
output
=
torch
.
nn
.
functional
.
linear
(
A
,
F
.
dequantize_
4bit
(
B
,
state
).
to
(
A
.
dtype
).
t
(),
bias
)
# 3. Save state
# 3. Save state
ctx
.
state
=
state
ctx
.
state
=
state
...
@@ -540,7 +540,7 @@ class MatMul4Bit(torch.autograd.Function):
...
@@ -540,7 +540,7 @@ class MatMul4Bit(torch.autograd.Function):
# not supported by PyTorch. TODO: create work-around
# not supported by PyTorch. TODO: create work-around
#if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
#if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
if
req_gradA
:
grad_A
=
torch
.
matmul
(
grad_output
,
F
.
dequantize_
fp4
(
B
,
ctx
.
state
).
to
(
grad_output
.
dtype
).
t
())
if
req_gradA
:
grad_A
=
torch
.
matmul
(
grad_output
,
F
.
dequantize_
4bit
(
B
,
ctx
.
state
).
to
(
grad_output
.
dtype
).
t
())
return
grad_A
,
grad_B
,
None
,
grad_bias
,
None
return
grad_A
,
grad_B
,
None
,
grad_bias
,
None
...
...
bitsandbytes/functional.py
View file @
94168d79
...
@@ -1459,8 +1459,7 @@ def gemv_4bit(
...
@@ -1459,8 +1459,7 @@ def gemv_4bit(
out
:
Tensor
=
None
,
out
:
Tensor
=
None
,
transposed_A
=
False
,
transposed_A
=
False
,
transposed_B
=
False
,
transposed_B
=
False
,
state
=
None
,
state
=
None
storage_type
=
'nf4'
):
):
#sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype)
#sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype)
if
state
is
None
:
if
state
is
None
:
...
...
csrc/kernels.cu
View file @
94168d79
...
@@ -3546,7 +3546,8 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(in
...
@@ -3546,7 +3546,8 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(in
T
local_absmax
=
T
(
0.0
f
);
T
local_absmax
=
T
(
0.0
f
);
for
(
int
i
=
threadIdx
.
x
;
i
<
16
;
i
++
)
for
(
int
i
=
threadIdx
.
x
;
i
<
16
;
i
++
)
quant_map
[
i
]
=
nf4_data
[
i
];
quant_map
[
i
]
=
datatype
[
i
];
__syncthreads
();
__syncthreads
();
// A: [1, K]
// A: [1, K]
...
@@ -3580,9 +3581,6 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(in
...
@@ -3580,9 +3581,6 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(in
{
{
local_B
[
k
*
2
]
=
quant_map
[
local_B_4bit
[
k
]
>>
4
]
*
local_absmax
;
local_B
[
k
*
2
]
=
quant_map
[
local_B_4bit
[
k
]
>>
4
]
*
local_absmax
;
local_B
[
k
*
2
+
1
]
=
quant_map
[
local_B_4bit
[
k
]
&
0x0F
]
*
local_absmax
;
local_B
[
k
*
2
+
1
]
=
quant_map
[
local_B_4bit
[
k
]
&
0x0F
]
*
local_absmax
;
//if(threadIdx.x == 0)
//printf("%f %f %f %f\n", (float)local_B[k*2], (float)dDequantizeNF4(local_B_4bit[k] >> 4)*(float)local_absmax, (float)local_B[k*2]- ((float)dDequantizeNF4(local_B_4bit[k] >> 4)*(float)local_absmax), (float)local_absmax);
}
}
if
(
inner_idx
+
num_values_4bit
)
if
(
inner_idx
+
num_values_4bit
)
...
...
tests/test_functional.py
View file @
94168d79
...
@@ -2351,12 +2351,13 @@ def test_normal_map_tree():
...
@@ -2351,12 +2351,13 @@ def test_normal_map_tree():
print
(
pivots
)
print
(
pivots
)
@
pytest
.
mark
.
parametrize
(
"storage_type"
,
[
'nf4'
,
'fp4'
],
ids
=
[
'nf4'
,
'fp4'
])
#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
#@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
,
torch
.
bfloat16
],
ids
=
[
'fp16'
,
'bf16'
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
float16
,
torch
.
bfloat16
],
ids
=
[
'fp16'
,
'bf16'
])
#@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=['bf16'])
#@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=['bf16'])
def
test_gemv_4bit
(
dtype
):
def
test_gemv_4bit
(
dtype
,
storage_type
):
print
(
''
)
print
(
''
)
for
dim
in
[
64
,
128
,
256
,
512
,
1024
,
2048
,
4096
]:
for
dim
in
[
128
,
256
,
512
,
1024
,
2048
,
4096
]:
#for dim in [4*1024]:
#for dim in [4*1024]:
#for dim in [1*16]:
#for dim in [1*16]:
errs
=
[]
errs
=
[]
...
@@ -2364,7 +2365,7 @@ def test_gemv_4bit(dtype):
...
@@ -2364,7 +2365,7 @@ def test_gemv_4bit(dtype):
max_err
=
0
max_err
=
0
max_relerr
=
0
max_relerr
=
0
for
i
in
range
(
1
00
):
for
i
in
range
(
1
):
#A = torch.rand(2, 4092, dtype=dtype, device='cuda')
#A = torch.rand(2, 4092, dtype=dtype, device='cuda')
#B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
#B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
#A = torch.rand(1, 4096, dtype=dtype, device='cuda')
#A = torch.rand(1, 4096, dtype=dtype, device='cuda')
...
@@ -2381,8 +2382,8 @@ def test_gemv_4bit(dtype):
...
@@ -2381,8 +2382,8 @@ def test_gemv_4bit(dtype):
#A.flatten()[:-1] = 0
#A.flatten()[:-1] = 0
#B.flatten()[:-1] = 0
#B.flatten()[:-1] = 0
qB
,
state
=
F
.
quantize_
nf4
(
B
)
qB
,
state
=
F
.
quantize_
4bit
(
B
,
quant_type
=
storage_type
)
F
.
dequantize_
nf4
(
qB
,
state
)
F
.
dequantize_
4bit
(
qB
,
state
)
C2
=
F
.
gemv_4bit
(
A
,
qB
.
t
(),
state
=
state
)
C2
=
F
.
gemv_4bit
(
A
,
qB
.
t
(),
state
=
state
)
C3
=
torch
.
matmul
(
A
,
B
.
t
())
C3
=
torch
.
matmul
(
A
,
B
.
t
())
...
@@ -2396,7 +2397,6 @@ def test_gemv_4bit(dtype):
...
@@ -2396,7 +2397,6 @@ def test_gemv_4bit(dtype):
#print(A)
#print(A)
#print(B)
#print(B)
#print('='*89)
#print('='*89)
#print(C3.flatten()[-20:])
#print(C3)
#print(C3)
#print(C1.shape, C2.shape)
#print(C1.shape, C2.shape)
...
@@ -2425,8 +2425,9 @@ def test_gemv_4bit(dtype):
...
@@ -2425,8 +2425,9 @@ def test_gemv_4bit(dtype):
#print(dim, (max_err.item(), max_relerr.item()))
#print(dim, (max_err.item(), max_relerr.item()))
print
(
C1
.
flatten
()[
-
20
:])
print
(
C1
.
flatten
()[
-
20
:])
print
(
C2
.
flatten
()[
-
20
:])
print
(
C2
.
flatten
()[
-
20
:])
print
(
sum
(
errs
)
/
len
(
errs
)
/
math
.
sqrt
(
dim
)
,
0.00015
)
print
(
C3
.
flatten
()[
-
20
:])
print
(
sum
(
relerrs
)
/
len
(
relerrs
)
/
math
.
sqrt
(
dim
)
,
0.0015
)
print
(
sum
(
errs
)
/
len
(
errs
)
/
math
.
sqrt
(
dim
)
,
dim
)
print
(
sum
(
relerrs
)
/
len
(
relerrs
)
/
math
.
sqrt
(
dim
)
,
dim
)
if
dtype
==
torch
.
float16
:
if
dtype
==
torch
.
float16
:
assert
sum
(
errs
)
/
len
(
errs
)
/
math
.
sqrt
(
dim
)
<
5e-5
assert
sum
(
errs
)
/
len
(
errs
)
/
math
.
sqrt
(
dim
)
<
5e-5
assert
sum
(
relerrs
)
/
len
(
relerrs
)
/
math
.
sqrt
(
dim
)
<
0.0005
assert
sum
(
relerrs
)
/
len
(
relerrs
)
/
math
.
sqrt
(
dim
)
<
0.0005
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment