OpenDAS / bitsandbytes · Commit 51a21df7
Authored Apr 01, 2023 by Tim Dettmers
Parent: c4cfe4fb

Added 8-bit compression to quantization statistics.
Showing 5 changed files with 88 additions and 27 deletions (+88 -27):

    bitsandbytes/functional.py    +26  -12
    bitsandbytes/nn/modules.py     +6   -4
    tests/test_autograd.py         +7   -6
    tests/test_functional.py      +48   -4
    tests/test_modules.py          +1   -1
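
For scale: with the FP4 defaults in this diff (blocksize=64, one float32 absmax per block), the quantization statistics cost 32/64 = 0.5 extra bits per parameter. Compressing the absmax values to 8 bits in blocks of 256 (with a single float32 offset and a second-level float32 absmax) cuts that to roughly 0.127 bits per parameter. The back-of-the-envelope below is a sketch derived from those defaults, not a measurement:

# Rough per-parameter overhead of quantization statistics, using the defaults
# visible in this diff: FP4 blocksize=64, absmax compressed with blocksize=256.
# Assumes absmax is stored as float32; this is an estimate, not a benchmark.
fp4_blocksize = 64      # weights covered by one absmax value
stats_blocksize = 256   # absmax values covered by one second-level absmax

uncompressed = 32 / fp4_blocksize                                        # 0.5 bits/param
compressed = 8 / fp4_blocksize + 32 / (fp4_blocksize * stats_blocksize)  # ~0.127 bits/param

print(f"uncompressed stats: {uncompressed:.3f} bits/param")
print(f"compressed stats:   {compressed:.3f} bits/param")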
bitsandbytes/functional.py  (+26 -12)

@@ -155,7 +155,7 @@ def create_linear_map(signed=True, total_bits=8, add_zero=True):
     #return torch.Tensor(values[:l].tolist() + [-1e-6]*((gap//2)-1) + [0]*2 + [1e-6]*((gap//2)-1) + values[l:].tolist())
     return torch.Tensor(values[:l].tolist() + [0]*gap + values[l:].tolist())

-def custom_map(seed=0, scale=0.01):
+def create_custom_map(seed=0, scale=0.01):
     v = [12, 10, 8, 6, 3, 2, 1]

     # 16-bit 7B 22.33, 4-bit best 22.88, FP4 23.25, 4-bit 95 22.97, 4-bit evo 22.45
     # 16-bit 13B 70.35, 4-bit best 67.16, FP4 100.78, 4-bit-95 69.39, 4-bit evo 70.48

@@ -191,13 +191,13 @@ def custom_map(seed=0, scale=0.01):
     # 13B evo start
     #v = [1.6077535089716468, 1.1914902148179205, 0.8999752421085561, 0.6967904489387543, 0.4949093928311768, 0.30920472033044544, 0.15391602735952042]
     #v = [1.586363722436466, 1.202610827188916, 0.9003332576346587, 0.6904888715206972, 0.49490974688233724, 0.2971151461329376, 0.15683230810738283]
-    v = [1.5842247437829478, 1.2037228884260156, 0.900369059187269, 0.6898587137788914, 0.4949097822874533, 0.2959061887131868, 0.15712393618216908]
+    #v = [1.5842247437829478, 1.2037228884260156, 0.900369059187269, 0.6898587137788914, 0.4949097822874533, 0.2959061887131868, 0.15712393618216908]

     # mean evo 7B + 13B
     #v = [1.5993337549066253, 1.1965624035328402, 0.9000864380418481, 0.6925840978034195, 0.5011181210961458, 0.32040328389777434, 0.13570386022711237]

     # theoretically optiomal (0.93333)
-    #v = [1.501085946044025, 1.1331700302595604, 0.8761428492468408, 0.6670160135425023, 0.48373855304610314, 0.3155014472579608, 0.15580024666388428] # 0.9333333333333333
+    v = [1.501085946044025, 1.1331700302595604, 0.8761428492468408, 0.6670160135425023, 0.48373855304610314, 0.3155014472579608, 0.15580024666388428] # 0.9333333333333333

@@ -599,7 +599,9 @@ def quantize_blockwise(A: Tensor, code: Tensor = None, absmax: Tensor = None, ra
         assert rand is None
         lib.cquantize_blockwise_cpu_fp32(get_ptr(code), get_ptr(A), get_ptr(absmax), get_ptr(out), ct.c_longlong(blocksize), ct.c_longlong(A.numel()))

-    return out, (absmax, code)
+    state = (absmax, code, blocksize)
+
+    return out, state

 def dequantize_blockwise(

@@ -644,9 +646,9 @@ def dequantize_blockwise(
     if out is None:
         out = torch.zeros_like(A, dtype=torch.float32)
     if quant_state is None:
-        quant_state = (absmax, code)
+        quant_state = (absmax, code, blocksize)
     else:
-        absmax, code = quant_state
+        absmax, code, blocksize = quant_state

     if A.device.type != 'cpu':

@@ -669,7 +671,7 @@ def dequantize_blockwise(
     return out

-def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64) -> Tensor:
+def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statistics=False) -> Tensor:
     """
     Quantize tensor A in blocks of FP4 values.

@@ -704,12 +706,11 @@ def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize
         blocks += 1 if n % blocksize > 0 else 0
         absmax = torch.zeros((blocks,), device=A.device)

-    state = (absmax, input_shape, A.dtype, blocksize)

     if out is None:
         out = torch.zeros(((n+1)//2, 1), dtype=torch.uint8, device=A.device)

-    assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]
+    assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64, 32]

     prev_device = pre_call(A.device)
     is_on_gpu([A, out, absmax])

@@ -722,6 +723,17 @@ def quantize_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize
         raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
     post_call(A.device)

+    if compress_statistics:
+        offset = absmax.mean()
+        absmax -= offset
+        #code = create_custom_map().to(absmax.device)
+        #qabsmax, state2 = quantize_blockwise(absmax, code=code, blocksize=256)
+        qabsmax, state2 = quantize_blockwise(absmax, blocksize=256)
+        del absmax
+        state = (qabsmax, input_shape, A.dtype, blocksize, (offset, state2))
+    else:
+        state = (absmax, input_shape, A.dtype, blocksize, None)
+
     return out, state

@@ -756,8 +768,12 @@ def dequantize_fp4(A: Tensor,quant_state: Tuple[Tensor, Tensor] = None, absmax:
         shape = out.shape
         dtype = out.dtype
     else:
-        absmax, shape, dtype, blocksize = quant_state
+        absmax, shape, dtype, blocksize, compressed_stats = quant_state
+        if compressed_stats is not None:
+            offset, state2 = compressed_stats
+            absmax = dequantize_blockwise(absmax, state2)
+            absmax += offset

     if out is None:
         out = torch.empty(shape, dtype=dtype, device=A.device)

@@ -1986,8 +2002,6 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
     ccolsB = ct.c_int32(B.shape[1])
     cldb = ct.c_int32(ldb)
     cldc = ct.c_int32(ldc)
-    # print(cooA.rowidx[:64])
-    # print(cooA.colidx[:64].sort()[0])
     is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out, dequant_stats])
     if B.dtype == torch.float16:
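
Taken together, quantize_fp4 now returns a 5-element state whose last entry is either None or (offset, state2), and dequantize_fp4 decompresses the absmax before use when that entry is present. A minimal round-trip sketch of the new API, mirroring the calls exercised in the tests further down; it assumes a CUDA GPU and this version of the library:

import torch
import bitsandbytes.functional as F

A = torch.randn(1024, 1024, device='cuda').half()

# compress_statistics defaults to False in quantize_fp4, so the 8-bit
# compression of the absmax statistics has to be requested explicitly.
q, state = F.quantize_fp4(A, blocksize=64, compress_statistics=True)

# state is (qabsmax, input_shape, dtype, blocksize, (offset, state2));
# dequantize_fp4 rebuilds absmax from the compressed form before dequantizing.
A_deq = F.dequantize_fp4(q, state)
print((A - A_deq).abs().float().mean())  # blockwise FP4 reconstruction error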
bitsandbytes/nn/modules.py  (+6 -4)

@@ -134,15 +134,17 @@ class Embedding(torch.nn.Embedding):
         return emb

 class FP4Params(torch.nn.Parameter):
-    def __new__(cls, data=None, requires_grad=True, quant_state=None):
+    def __new__(cls, data=None, requires_grad=True, quant_state=None, blocksize=64, compress_statistics=True):
         cls.quant_state = None
+        cls.blocksize = blocksize
+        cls.compress_statistics = compress_statistics
         if data is None:
             data = torch.empty(0)
         return torch.Tensor._make_subclass(cls, data, requires_grad)

     def cuda(self, device):
         w = self.data.contiguous().half().cuda(device)
-        w_fp4, quant_state = bnb.functional.quantize_fp4(w)
+        w_fp4, quant_state = bnb.functional.quantize_fp4(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics)
         self.data = w_fp4
         self.quant_state = quant_state

@@ -173,10 +175,10 @@ class FP4Params(torch.nn.Parameter):

 class LinearFP4(nn.Linear):
-    def __init__(self, input_features, output_features, bias=True, compute_dtype=None):
+    def __init__(self, input_features, output_features, bias=True, compute_dtype=None, compress_statistics=True):
         super().__init__(input_features, output_features, bias)
         self.state = bnb.MatmulLtState()
-        self.weight = FP4Params(self.weight.data, requires_grad=False)
+        self.weight = FP4Params(self.weight.data, requires_grad=False, compress_statistics=compress_statistics)
         self.compute_dtype = compute_dtype

     def init_8bit_state(self):
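
At the module level the compression is on by default (compress_statistics=True in both FP4Params and LinearFP4), so FP4 weights pick it up automatically when a layer is moved to CUDA. A usage sketch; the constructor arguments are the ones added in this diff, while the surrounding code is illustrative only:

import torch
import bitsandbytes as bnb

# Drop-in replacement for torch.nn.Linear; the weight is quantized to FP4 when
# the module is moved to CUDA (FP4Params.cuda calls quantize_fp4 with the
# stored blocksize / compress_statistics settings).
layer = bnb.nn.LinearFP4(1024, 4096, bias=True, compute_dtype=torch.float16,
                         compress_statistics=True)  # True is already the default
layer = layer.cuda()

x = torch.randn(8, 1024, device='cuda', dtype=torch.float16)
y = layer(x)
print(y.shape)  # torch.Size([8, 4096])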
tests/test_autograd.py  (+7 -6)

@@ -454,14 +454,15 @@ for c in req_grad:
 transpose = [(False, True), (False, False)]
 str_transpose = ["NT", "NN"]
 dtype = [torch.float16, torch.float32]
+compress_statistics = [False, True]
 has_fp16_weights = [True, False]
 has_bias = [True, False]
-values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias))
-str_values = list(product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose, has_bias))
-names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}_has_bias_{}".format(*vals) for vals in str_values]
+values = list(product(dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics))
+str_values = list(product(dim1, dim2, dim3, dim4, str_funcs, dtype, req_grad_str, str_transpose, has_bias, compress_statistics))
+names = ["dim1_{}_dim2_{}_dim3_{}_dim4_{}_func_{}_dtype_{}_requires_grad_{}_transpose_{}_has_bias_{}_compress_statistics".format(*vals) for vals in str_values]
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
-@pytest.mark.parametrize("dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias", values, ids=names)
-def test_matmul_fp4( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias):
+@pytest.mark.parametrize("dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics", values, ids=names)
+def test_matmul_fp4( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose, has_bias, compress_statistics):
     dimA = (dim2, dim3) if not transpose[0] else (dim3, dim2)
     dimB = (dim3, dim4) if not transpose[1] else (dim4, dim3)
     if has_bias == False:

@@ -481,7 +482,7 @@ def test_matmul_fp4( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
         bias2 = bias.clone()
     torch.nn.init.xavier_uniform_(B)

-    B2, quant_state = bnb.functional.quantize_fp4(B)
+    B2, quant_state = bnb.functional.quantize_fp4(B, compress_statistics=compress_statistics)

     if not transpose[0] and transpose[1]:
         out_torch = funcs[0](A, B.t())
tests/test_functional.py  (+48 -4)

@@ -167,8 +167,8 @@ def test_dynamic_blockwise_quantization():
     relerr = sum(reldiffs)/len(reldiffs)
     assert abserr < 0.011
     assert relerr < 0.018
-    print('randn', blocksize, sum(diffs)/len(diffs))
-    print('randn', blocksize, sum(reldiffs)/len(reldiffs))
+    #print('randn', blocksize, sum(diffs)/len(diffs))
+    #print('randn', blocksize, sum(reldiffs)/len(reldiffs))

     diffs = []
     for i in range(100):

@@ -184,8 +184,8 @@ def test_dynamic_blockwise_quantization():
     relerr = sum(reldiffs)/len(reldiffs)
     assert abserr < 0.0035
     assert relerr < 0.015
-    print('rand', blocksize, sum(diffs)/len(diffs))
-    print('rand', blocksize, sum(reldiffs)/len(reldiffs))
+    #print('rand', blocksize, sum(diffs)/len(diffs))
+    #print('rand', blocksize, sum(reldiffs)/len(reldiffs))

 def test_dynamic_blockwise_stochastic_quantization():

@@ -1806,6 +1806,7 @@ def test_bench_matmul(batch, seq, model, hidden):
     torch.nn.init.xavier_uniform_(B)

     B_fp4, state = F.quantize_fp4(B)
+    B_fp4_c, state_c = F.quantize_fp4(B, compress_statistics=True)

     linear8bit = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
     linear8bit.eval()

@@ -1839,6 +1840,13 @@ def test_bench_matmul(batch, seq, model, hidden):
     torch.cuda.synchronize()
     print(f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

+    torch.cuda.synchronize()
+    t0 = time.time()
+    for i in range(iters):
+        bnb.matmul_fp4(A, B_fp4.t(), quant_state=state_c)
+    torch.cuda.synchronize()
+    print(f"bnb fp4 + compressed stats: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+
     #torch.cuda.synchronize()
     #t0 = time.time()
     #for i in range(iters):

@@ -2244,6 +2252,42 @@ def test_fp4_quant():
     assert relerr.item() < 0.28

+def test_fp4_compressed_stats():
+    for blocksize in [128, 64]:
+        errs1 = []
+        errs2 = []
+        for i in range(10):
+            A1 = torch.randn(1024, 1024, device='cuda').half()
+            q2, SA2 = F.quantize_fp4(A1, blocksize=blocksize)
+            q3, SA3 = F.quantize_fp4(A1, blocksize=blocksize, compress_statistics=True)
+            A2 = F.dequantize_fp4(q2, SA2)
+            A3 = F.dequantize_fp4(q3, SA3)
+
+            err = (A1 - A2).abs().float()
+            relerr = (err/(A1.abs().float()+1e-15)).mean()
+            err = err.mean()
+
+            errs1.append(err.item())
+
+            assert err.item() < 0.11
+            assert relerr.item() < 0.28
+
+            err = (A1 - A3).abs().float()
+            relerr = (err/(A1.abs().float()+1e-15)).mean()
+            err = err.mean()
+
+            errs2.append(err.item())
+
+            assert err.item() < 0.11
+            assert relerr.item() < 0.28
+
+        #print(sum(errs1)/len(errs1), blocksize)
+        #print(sum(errs2)/len(errs2), blocksize)
+
+
 def test_bench_fp4_dequant():
     blocksize = 256
     a = torch.rand(1024*12*4, 1024*12, device='cuda').half()
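
The new test_fp4_compressed_stats asserts that compressing the statistics keeps the absolute and relative reconstruction errors under the same bounds used for plain FP4 (0.11 and 0.28). For inspecting the difference directly, a small sketch that compares the two quant states (assumes a CUDA GPU; the tuple layout is the one constructed in functional.py above):

import torch
import bitsandbytes.functional as F

A = torch.randn(1024, 1024, device='cuda').half()
_, s_plain = F.quantize_fp4(A, blocksize=64)
_, s_comp  = F.quantize_fp4(A, blocksize=64, compress_statistics=True)

# The last state element distinguishes the two paths: None for plain FP4,
# (offset, state2) when the absmax values were quantized to 8 bits.
print(s_plain[-1])                        # None
offset, state2 = s_comp[-1]
print(offset.item())                      # mean subtracted from absmax before compression
print(s_plain[0].dtype, s_comp[0].dtype)  # float32 absmax vs. 8-bit compressed absmax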
tests/test_modules.py  (+1 -1)

@@ -507,7 +507,7 @@ def test_linear_kbit_fp32_bias(module):
     assert l1.bias is None

 @pytest.mark.skipif(not torch.cuda.is_available(), reason="this test requires a GPU")
-@pytest.mark.parametrize("module", [bnb.nn.Linear8bitLt, bnb.nn.LinearFP4], ids=['Int8Lt', 'FP4'])
+@pytest.mark.parametrize("module", [bnb.nn.Linear8bitLt, bnb.nn.LinearFP4, lambda d1, d2: bnb.nn.LinearFP4(d1, d2, compress_statistics=True)], ids=['Int8Lt', 'FP4', 'FP4+C'])
 def test_kbit_backprop(module):
     b = 17
     dim1 = 37