OpenDAS / bitsandbytes / Commits
Commit 0f0390ac authored Jul 09, 2023 by Tim Dettmers

Added double quantization support and tests.

Parent: 94168d79
Showing 2 changed files with 38 additions and 18 deletions:

    bitsandbytes/functional.py  +18 -7
    tests/test_functional.py    +20 -11
bitsandbytes/functional.py
@@ -1461,16 +1461,25 @@ def gemv_4bit(
     transposed_B=False,
     state=None
 ):
+    prev_device = pre_call(A.device)
     #sout = check_matmul(A, B, out, transposed_A, transposed_B, expected_type=A.dtype)
     if state is None:
-        Bshape = B.shape
-        bout = Bshape[1]
-    else:
-        Bshape = state[1]
-        bout = Bshape[0]
+        raise ValueError(f'state cannot None. gem_4bit( ) requires the state from quantize_4bit( )')
+
+    Bshape = state[1]
+    bout = Bshape[0]
+    absmax, shape, dtype, blocksize, compressed_stats, quant_type, data_type = state
+    if compressed_stats is not None:
+        offset, state2 = compressed_stats
+        absmax = dequantize_blockwise(absmax, state2)
+        absmax += offset

     if out is None:
         out = torch.zeros(size=(A.shape[0], bout), dtype=A.dtype, device=A.device)

     sA = A.shape
     sB = B.shape
     if transposed_A and len(sA) == 2:
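The compressed_stats branch above is the consumer side of double quantization: quantize_4bit with compress_statistics=True stores the per-block absmax constants in quantized form together with an offset, and gemv_4bit now has to undo that before calling into the kernel. As a rough sketch of the producer side, assuming the mean-centred blockwise scheme described in the QLoRA paper (the helper name and the blocksize are illustrative, not taken from this commit):

# Rough sketch of how compressed absmax statistics could be produced
# (hypothetical helper; the real producer lives inside quantize_4bit).
import bitsandbytes.functional as F

def compress_absmax(absmax):
    offset = absmax.mean()                    # keep the mean as a float offset
    qabsmax, state2 = F.quantize_blockwise(absmax - offset, blocksize=256)
    return qabsmax, (offset, state2)

# gemv_4bit (above) reverses this:
#     offset, state2 = compressed_stats
#     absmax = dequantize_blockwise(absmax, state2)
#     absmax += offset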
@@ -1557,14 +1566,16 @@ def gemv_4bit(
     if B.dtype == torch.uint8:
         if A.dtype == torch.float16:
-            lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
+            lib.cgemm_4bit_inference_naive_fp16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
         elif A.dtype == torch.bfloat16:
-            lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(state[0]), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
+            lib.cgemm_4bit_inference_naive_bf16(m, n, k, get_ptr(A), get_ptr(B), get_ptr(absmax), get_ptr(state[-1]), get_ptr(out), lda, ldb, ldc, ct.c_int32(state[3]))
         else:
             raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}')
     else:
         raise NotImplementedError(f'Matmul not implemented for data type {A.dtype}')

+    post_call(prev_device)
     return out
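Note that after this hunk both naive inference kernels receive the materialised absmax tensor rather than the raw state[0], so any decompression happens once per call in Python regardless of dtype. Only float16 and bfloat16 activations are dispatched here; reusing A, qB, and state from the sketch near the top of the page, an unsupported dtype fails as follows (illustrative, not from the diff):

# Illustrative only: the dispatch in the hunk above supports fp16 and bf16;
# any other activation dtype raises NotImplementedError.
A32 = A.float()
try:
    F.gemv_4bit(A32, qB.t(), state=state)
except NotImplementedError as err:
    print(err)   # Matmul not implemented for data type torch.float32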
tests/test_functional.py
@@ -1776,7 +1776,7 @@ def test_spmm_coo_dequant(dim1, dim2, dtype):
     print("partial matmul", time.time() - t0)

-batch_size = 1
+batch_size = 5
 seqdim = 1
 values = []
 #values.append((batch_size, seqdim, 768, 4 * 768))
@@ -1786,8 +1786,8 @@ values = []
 #values.append((batch_size, seqdim, 2560, 4*2560))
 #values.append((batch_size, seqdim, 4096, 4*4096))
 #values.append((batch_size, seqdim, 5120, 4*5120))
-#values.append((batch_size, seqdim, 6656, 4*6656))
-values.append((batch_size, seqdim, 8192, 4*8192))
+values.append((batch_size, seqdim, 6656, 4*6656))
+#values.append((batch_size, seqdim, 8192, 4*8192))
 #values.append((batch_size, seqdim, 5140, 4*5140))
 #values.append((batch_size, seqdim, 12288, 4*12288))
 names = ["batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values]
@@ -1804,6 +1804,7 @@ def test_bench_matmul(batch, seq, model, hidden):
     B_fp4_c, state_c = F.quantize_fp4(B, compress_statistics=True)
     B_nf4, state_nf4 = F.quantize_nf4(B)
+    B_nf4_c, state_nf4_c = F.quantize_nf4(B, compress_statistics=True)

     linear8bit = bnb.nn.Linear8bitLt(model, hidden, False, False).cuda().half()
     linear8bit.eval()
@@ -1816,7 +1817,7 @@ def test_bench_matmul(batch, seq, model, hidden):
     linear8bit_train = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
     linear8bit_train_thresh = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()

-    F.gemv_4bit(A, B_nf4.t(), state=state_nf4)
+    bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)

     # warmup
     for i in range(iters):
@@ -1848,11 +1849,18 @@ def test_bench_matmul(batch, seq, model, hidden):
     torch.cuda.synchronize()
     t0 = time.time()
     for i in range(iters):
-        #bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)
-        F.gemv_4bit(A, B_nf4.t(), state=state_nf4)
+        bnb.matmul_4bit(A, B_nf4.t(), quant_state=state_nf4)
     torch.cuda.synchronize()
     print(f"bnb nf4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

+    torch.cuda.synchronize()
+    t0 = time.time()
+    for i in range(iters):
+        bnb.matmul_4bit(A, B_nf4_c.t(), quant_state=state_nf4_c)
+    torch.cuda.synchronize()
+    print(f"bnb nf4+DQ: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+
     #torch.cuda.synchronize()
     #t0 = time.time()
     #for i in range(iters):
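Condensed, the benchmark block added above times the high-level matmul_4bit entry point with and without double-quantized statistics. The following is a self-contained sketch of the same pattern; sizes and the iteration count are illustrative, not the test's values.

# Self-contained sketch of the nf4 vs. nf4+DQ timing added to test_bench_matmul;
# sizes and iteration count are illustrative.
import time
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

iters = 100
A = torch.randn(1, 1, 4096, dtype=torch.float16, device='cuda')
B = torch.randn(4 * 4096, 4096, dtype=torch.float16, device='cuda')

B_nf4, state_nf4 = F.quantize_nf4(B)
B_nf4_c, state_nf4_c = F.quantize_nf4(B, compress_statistics=True)

def bench(label, weight, qstate):
    torch.cuda.synchronize()
    t0 = time.time()
    for _ in range(iters):
        bnb.matmul_4bit(A, weight.t(), quant_state=qstate)
    torch.cuda.synchronize()
    print(f"{label}: {time.time() - t0:.4f}s")

bench("bnb nf4", B_nf4, state_nf4)
bench("bnb nf4+DQ", B_nf4_c, state_nf4_c)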
@@ -2351,11 +2359,12 @@ def test_normal_map_tree():
     print(pivots)

+@pytest.mark.parametrize("double_quant", [True, False], ids=['DQ_True', 'DQ_False'])
 @pytest.mark.parametrize("storage_type", ['nf4', 'fp4'], ids=['nf4', 'fp4'])
 #@pytest.mark.parametrize("dtype", [torch.float32, torch.float16], ids=['fp32', 'fp16'])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=['fp16', 'bf16'])
 #@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=['bf16'])
-def test_gemv_4bit(dtype, storage_type):
+def test_gemv_4bit(dtype, storage_type, double_quant):
     print('')
     for dim in [128, 256, 512, 1024, 2048, 4096]:
     #for dim in [4*1024]:
@@ -2365,7 +2374,7 @@ def test_gemv_4bit(dtype, storage_type):
         max_err = 0
         max_relerr = 0
-        for i in range(1):
+        for i in range(100):
             #A = torch.rand(2, 4092, dtype=dtype, device='cuda')
             #B = torch.rand(4*4092, 4092, dtype=dtype, device='cuda')
             #A = torch.rand(1, 4096, dtype=dtype, device='cuda')
@@ -2382,11 +2391,11 @@ def test_gemv_4bit(dtype, storage_type):
             #A.flatten()[:-1] = 0
             #B.flatten()[:-1] = 0
-            qB, state = F.quantize_4bit(B, quant_type=storage_type)
-            F.dequantize_4bit(qB, state)
-            C2 = F.gemv_4bit(A, qB.t(), state=state)
+            qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant)
+            #F.dequantize_4bit(qB, state)
             C3 = torch.matmul(A, B.t())
+            C2 = F.gemv_4bit(A, qB.t(), state=state)

             A.requires_grad = True
             C1 = bnb.matmul_4bit(A, qB.t(), state)
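Taken together, the last three hunks make test_gemv_4bit sweep double quantization on and off across both storage types and both dtypes, and raise the inner repetition count from 1 to 100. A compact sketch of the path each parameter combination now exercises follows; the real test repeats this per dimension and tracks error statistics, which are omitted here.

# Compact sketch of the path each (dtype, storage_type, double_quant) case in
# test_gemv_4bit now exercises; shapes are illustrative and tolerances omitted.
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

def run_case(dtype, storage_type, double_quant, dim=4096):
    A = torch.randn(1, dim, dtype=dtype, device='cuda')
    B = torch.randn(4 * dim, dim, dtype=dtype, device='cuda')
    qB, state = F.quantize_4bit(B, quant_type=storage_type,
                                compress_statistics=double_quant)
    C3 = torch.matmul(A, B.t())                # full-precision reference
    C2 = F.gemv_4bit(A, qB.t(), state=state)   # low-level GEMV under test
    A.requires_grad = True
    C1 = bnb.matmul_4bit(A, qB.t(), state)     # autograd entry point under test
    return C1, C2, C3

C1, C2, C3 = run_case(torch.float16, 'nf4', double_quant=True)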