OpenDAS / bitsandbytes · commit 94168d79

Added FP4 fast inference support.

Authored Jul 09, 2023 by Tim Dettmers
Parent: 4b88d69d

Changes: 4 changed files, with 14 additions and 16 deletions
bitsandbytes/autograd/_functions.py   +2  -2
bitsandbytes/functional.py            +1  -2
csrc/kernels.cu                       +2  -4
tests/test_functional.py              +9  -8
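In short, this commit swaps the NF4-only entry points for the generic 4-bit ones so that the fast inference GEMV path serves both FP4 and NF4 storage. A minimal usage sketch in the style of the updated test (assuming a CUDA device; quantize_4bit, gemv_4bit, and the quant_type argument all appear in the diffs below):

    import torch
    import bitsandbytes.functional as F

    A = torch.randn(1, 4096, dtype=torch.float16, device='cuda')     # one row => GEMV path
    B = torch.randn(4096, 4096, dtype=torch.float16, device='cuda')  # weight matrix

    # Quantize to 4-bit storage; 'fp4' and 'nf4' are both supported formats.
    qB, state = F.quantize_4bit(B, quant_type='fp4')

    C = F.gemv_4bit(A, qB.t(), state=state)  # fused 4-bit inference matvec
    C_ref = torch.matmul(A, B.t())           # full-precision reference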
bitsandbytes/autograd/_functions.py

@@ -509,7 +509,7 @@ class MatMul4Bit(torch.autograd.Function):
         # 1. Dequantize
         # 2. MatmulnN
-        output = torch.nn.functional.linear(A, F.dequantize_fp4(B, state).to(A.dtype).t(), bias)
+        output = torch.nn.functional.linear(A, F.dequantize_4bit(B, state).to(A.dtype).t(), bias)

         # 3. Save state
         ctx.state = state
@@ -540,7 +540,7 @@ class MatMul4Bit(torch.autograd.Function):
         # not supported by PyTorch. TODO: create work-around
         #if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
-        if req_gradA: grad_A = torch.matmul(grad_output, F.dequantize_fp4(B, ctx.state).to(grad_output.dtype).t())
+        if req_gradA: grad_A = torch.matmul(grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype).t())

         return grad_A, grad_B, None, grad_bias, None
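Both the forward and backward passes of MatMul4Bit now call the storage-agnostic F.dequantize_4bit instead of F.dequantize_fp4; the quantization state saved on ctx carries the format, so one autograd function serves FP4 and NF4 alike. A condensed sketch of the pattern (taken directly from the hunks above, not the full class):

    # Forward: dequantize the 4-bit weight, then a regular linear op.
    output = torch.nn.functional.linear(
        A, F.dequantize_4bit(B, state).to(A.dtype).t(), bias)
    ctx.state = state  # state records which 4-bit format B uses

    # Backward: dequantize again with the saved state.
    if req_gradA:
        grad_A = torch.matmul(
            grad_output, F.dequantize_4bit(B, ctx.state).to(grad_output.dtype).t())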
bitsandbytes/functional.py

@@ -1459,8 +1459,7 @@ def gemv_4bit(
     out: Tensor = None,
     transposed_A=False,
     transposed_B=False,
-    state=None,
-    storage_type='nf4'
+    state=None
 ):
     #sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype)
     if state is None:
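Dropping storage_type from the gemv_4bit signature means callers no longer name the format at call time; it travels with the state object returned by quantize_4bit. A call-site sketch of the before/after:

    # Old (removed in this commit): the caller named the format explicitly.
    #   C = F.gemv_4bit(A, qB.t(), state=state, storage_type='nf4')

    # New: the state from quantize_4bit already records FP4 vs NF4.
    qB, state = F.quantize_4bit(B, quant_type='fp4')
    C = F.gemv_4bit(A, qB.t(), state=state)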
csrc/kernels.cu

@@ -3546,7 +3546,8 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(in
   T local_absmax = T(0.0f);

   for(int i = threadIdx.x; i < 16; i++)
-    quant_map[i] = nf4_data[i];
+    quant_map[i] = datatype[i];
   __syncthreads();

   // A: [1, K]
@@ -3580,9 +3581,6 @@ template <typename T, int THREADS> __global__ void kgemm_4bit_inference_naive(in
         {
           local_B[k*2] = quant_map[local_B_4bit[k] >> 4]*local_absmax;
           local_B[k*2 + 1] = quant_map[local_B_4bit[k] & 0x0F]*local_absmax;
-          //if(threadIdx.x == 0)
-          //printf("%f %f %f %f\n", (float)local_B[k*2], (float)dDequantizeNF4(local_B_4bit[k] >> 4)*(float)local_absmax, (float)local_B[k*2]- ((float)dDequantizeNF4(local_B_4bit[k] >> 4)*(float)local_absmax), (float)local_absmax);
         }

         if(inner_idx + num_values_4bit)
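In the kernel, the shared-memory codebook quant_map was previously filled from the hard-coded nf4_data table; it is now loaded from the datatype pointer passed into the kernel, which is what lets the same GEMV kernel dequantize FP4 as well as NF4. A hedged Python emulation of the per-byte lookup the inner loop performs (function and parameter names are illustrative, not part of the library):

    def dequantize_nibbles(packed_bytes, quant_map, absmax):
        # Each byte packs two 4-bit codes. Look each nibble up in the
        # 16-entry codebook and rescale by the block's absmax, mirroring
        # the local_B[k*2] / local_B[k*2 + 1] lines in the kernel.
        out = []
        for byte in packed_bytes:
            out.append(quant_map[byte >> 4] * absmax)    # high nibble
            out.append(quant_map[byte & 0x0F] * absmax)  # low nibble
        return out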
tests/test_functional.py

@@ -2351,12 +2351,13 @@ def test_normal_map_tree():
     print(pivots)

+@pytest.mark.parametrize("storage_type", ['nf4', 'fp4'], ids=['nf4', 'fp4'])
 #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=['fp16', 'bf16'])
 #@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=['bf16'])
-def test_gemv_4bit(dtype):
+def test_gemv_4bit(dtype, storage_type):
     print('')
-    for dim in [64, 128, 256, 512, 1024, 2048, 4096]:
+    for dim in [128, 256, 512, 1024, 2048, 4096]:
     #for dim in [4*1024]:
     #for dim in [1*16]:
         errs = []
@@ -2364,7 +2365,7 @@ def test_gemv_4bit(dtype):
         max_err = 0
         max_relerr = 0

-        for i in range(100):
+        for i in range(1):
             #A = torch.rand(2, 4092, dtype=dtype, device='cuda')
             #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
             #A = torch.rand(1, 4096, dtype=dtype, device='cuda')
@@ -2381,8 +2382,8 @@ def test_gemv_4bit(dtype):
             #A.flatten()[:-1] = 0
             #B.flatten()[:-1] = 0

-            qB, state = F.quantize_nf4(B)
-            F.dequantize_nf4(qB, state)
+            qB, state = F.quantize_4bit(B, quant_type=storage_type)
+            F.dequantize_4bit(qB, state)

             C2 = F.gemv_4bit(A, qB.t(), state=state)
             C3 = torch.matmul(A, B.t())
@@ -2396,7 +2397,6 @@ def test_gemv_4bit(dtype):
             #print(A)
             #print(B)
             #print('='*89)
-            #print(C3.flatten()[-20:])
             #print(C3)
             #print(C1.shape, C2.shape)
@@ -2425,8 +2425,9 @@ def test_gemv_4bit(dtype):
         #print(dim, (max_err.item(), max_relerr.item()))
         print(C1.flatten()[-20:])
         print(C2.flatten()[-20:])
-        print(sum(errs)/len(errs)/math.sqrt(dim), 0.00015)
-        print(sum(relerrs)/len(relerrs)/math.sqrt(dim), 0.0015)
+        print(C3.flatten()[-20:])
+        print(sum(errs)/len(errs)/math.sqrt(dim), dim)
+        print(sum(relerrs)/len(relerrs)/math.sqrt(dim), dim)
         if dtype == torch.float16:
             assert sum(errs)/len(errs)/math.sqrt(dim) < 5e-5
             assert sum(relerrs)/len(relerrs)/math.sqrt(dim) < 0.0005
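The test now parametrizes over storage_type so both formats run through quantize_4bit and gemv_4bit, and it prints the mean error normalized by sqrt(dim) alongside dim. The err/relerr accumulation itself falls outside the shown hunks; a plausible reconstruction of the asserted metric, assuming errs and relerrs collect per-trial mean absolute and relative errors between the 4-bit and full-precision results:

    import math

    # Assumed per-trial accumulation (not visible in this diff):
    err = (C2 - C3).abs().float().mean().item()
    relerr = ((C2 - C3) / C3).abs().float().mean().item()
    errs.append(err)
    relerrs.append(relerr)

    # fp16 thresholds asserted by the test:
    assert sum(errs) / len(errs) / math.sqrt(dim) < 5e-5
    assert sum(relerrs) / len(relerrs) / math.sqrt(dim) < 0.0005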