GitLab: OpenDAS / bitsandbytes / Commits

Commit cfe4705e, authored Feb 04, 2023 by Tim Dettmers
parent 13c0a4dc

    Added matmul_fp4 to the benchmark.

Showing 4 changed files with 57 additions and 45 deletions (+57 −45):
bitsandbytes/autograd/_functions.py  +3 −2
bitsandbytes/functional.py           +1 −4
tests/test_autograd.py               +3 −3
tests/test_functional.py             +50 −36
bitsandbytes/autograd/_functions.py

@@ -495,7 +495,7 @@ class MatMulFP4(torch.autograd.Function):

         # 1. Dequantize
         # 2. Matmul
-        output = torch.nn.functional.linear(A, F.dequantize_fp4(B, state).to(A.dtype), bias)
+        output = torch.nn.functional.linear(A, F.dequantize_fp4(B, state).to(A.dtype), bias)

         # 3. Save state
@@ -550,5 +550,6 @@ def matmul(
     return MatMul8bitLt.apply(A, B, out, bias, state)

-def matmul_fp4(A: tensor, B: tensor, out: tensor = None, quant_state: List = None, bias=None):
+def matmul_fp4(A: tensor, B: tensor, quant_state: List, out: tensor = None, bias=None):
+    assert quant_state is not None
     return MatMulFP4.apply(A, B, out, bias, quant_state)
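With this change, quant_state becomes a required positional argument of matmul_fp4 instead of an optional keyword defaulting to None, and the new assert rejects a missing state; the call sites in tests/test_autograd.py below pass it positionally accordingly. A minimal usage sketch, with shapes borrowed from the benchmark added in this commit:

    import torch
    import bitsandbytes as bnb
    import bitsandbytes.functional as F

    A = torch.randn(1, 1, 768, dtype=torch.float16, device="cuda")
    B = torch.empty(4 * 768, 768, dtype=torch.float16, device="cuda")
    torch.nn.init.xavier_uniform_(B)

    # quantize_fp4 returns the packed fp4 tensor and its quantization state
    B_fp4, state = F.quantize_fp4(B)

    # quant_state may be passed positionally or by keyword, but must not be None
    out = bnb.matmul_fp4(A, B_fp4, quant_state=state)  # -> (1, 1, 4*768)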
bitsandbytes/functional.py

@@ -169,7 +169,6 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
     lst = list(itertools.product([0, 1], repeat=precision_bits))
     #for ev in evalues:
     bias = 2**(exponent_bits-1)+1
-    print(bias)
     for evalue in range(2**(exponent_bits)):
         for bit_pattern in lst:
             value = (1 if evalue != 0 else 0)
@@ -180,9 +179,7 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
                 value = value*2**-(bias)
             else:
                 # normals
-                print(value, 1)
                 value = value*2**-(evalue-bias-1)
-                print(value, 2)
             values.append(value)
             if signed:
                 values.append(-value)
@@ -196,7 +193,7 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
         values.append(0)
     values.sort()
     code = torch.Tensor(values)
-    #code /= code.max()
+    code /= code.max()
     return code
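Beyond dropping debug prints, this hunk enables the final normalization, so the returned code map is scaled into [-1, 1]. For reference, the value construction visible in these hunks reads roughly as below; the mantissa accumulation is elided from the diff, so that step is filled in here under the usual floating-point reading of the fraction bits and should be treated as an assumption, not the exact library code:

    import itertools
    import torch

    def fp8_map_sketch(signed=True, exponent_bits=5, precision_bits=2):
        # all mantissa bit patterns, as in the hunk above
        lst = list(itertools.product([0, 1], repeat=precision_bits))
        bias = 2 ** (exponent_bits - 1) + 1
        values = []
        for evalue in range(2 ** exponent_bits):
            for bit_pattern in lst:
                # implicit leading 1 for normals, 0 for subnormals
                value = 1 if evalue != 0 else 0
                # fraction accumulation (assumed; not shown in the diff)
                for i, bit in enumerate(bit_pattern):
                    value += bit * 2 ** -(i + 1)
                if evalue == 0:
                    value = value * 2 ** -bias                 # subnormals
                else:
                    value = value * 2 ** -(evalue - bias - 1)  # normals
                values.append(value)
                if signed:
                    values.append(-value)
        values.append(0)
        values.sort()
        code = torch.Tensor(values)
        code /= code.max()  # normalization enabled by this commit
        return code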
tests/test_autograd.py

@@ -485,10 +485,10 @@ def test_matmul_fp4( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
         if not transpose[0] and transpose[1]:
             out_torch = funcs[0](A, B.t())
-            out_bnb = funcs[1](A, B2, quant_state=quant_state, bias=bias2)
+            out_bnb = funcs[1](A, B2, quant_state, bias=bias2)
         elif not transpose[0] and not transpose[1]:
             out_torch = funcs[0](A, B)
-            out_bnb = funcs[1](A, B2.t(), quant_state=quant_state, bias=bias2)
+            out_bnb = funcs[1](A, B2.t(), quant_state, bias=bias2)

         if has_bias:
             out_torch += bias
@@ -498,7 +498,7 @@ def test_matmul_fp4( dim1, dim2, dim3, dim4, funcs, dtype, req_grad, transpose,
         n = out_bnb.numel()
         err = torch.abs(out_bnb - out_torch).float().mean().item()
         if n > 0:
-            assert err < 0.11
+            assert err < 0.115

         if any(req_grad):
             out_bnb.data.copy_(out_torch)
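The tolerance bump from 0.11 to 0.115 applies to the mean absolute error between the fp4 and fp16 outputs. Factored out as a standalone helper (a sketch, not part of the test suite), the metric is:

    import torch

    def mean_abs_err(out_bnb: torch.Tensor, out_torch: torch.Tensor) -> float:
        # same metric as the assertion above: mean absolute difference,
        # computed in float32 to avoid fp16 accumulation error
        return torch.abs(out_bnb - out_torch).float().mean().item()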
tests/test_functional.py

@@ -1788,18 +1788,14 @@ batch_size = 1
 seqdim = 1
 values = []
 values.append((batch_size, seqdim, 768, 4*768))
 #values.append((batch_size, seqdim, 1024, 4*1024))
 #values.append((batch_size, seqdim, 1536, 4*1536))
 #values.append((batch_size, seqdim, 2048, 4*2048))
 #values.append((batch_size, seqdim, 2560, 4*2560))
 #values.append((batch_size, seqdim, 4096, 4*4096))
 #values.append((batch_size, seqdim, 5140, 4*5140))
 #values.append((batch_size, seqdim, 12288, 4*12288))
 names = ["batch_{}_seq_{}_model_{}_hidden_{}".format(*vals) for vals in values]

 @pytest.mark.parametrize("batch, seq, model, hidden", values, ids=names)
 def test_bench_matmul(batch, seq, model, hidden):
     iters = 128
@@ -1809,17 +1805,20 @@ def test_bench_matmul(batch, seq, model, hidden):
     B = torch.empty(hidden, model, dtype=torch.float16, device="cuda")
     torch.nn.init.xavier_uniform_(B)
+    B_fp4, state = F.quantize_fp4(B)

     linear8bit = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
     linear8bit.eval()

     outliers = torch.randint(0, model, size=(5,)).cuda()
     A[:, :, outliers] = 8.0

     linearMixedBit = (bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half())
     linearMixedBit.eval()

+    linear8bit_train = bnb.nn.Linear8bitLt(model, hidden, False).cuda().half()
+    linear8bit_train_thresh = bnb.nn.Linear8bitLt(model, hidden, False, threshold=6.0).cuda().half()

     # warmup
     for i in range(iters):
         torch.matmul(A, B.t())
@@ -1831,9 +1830,14 @@ def test_bench_matmul(batch, seq, model, hidden):
     for i in range(iters):
         torch.matmul(A, B.t())
     torch.cuda.synchronize()
     print(f"pytorch fp16: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

+    torch.cuda.synchronize()
+    t0 = time.time()
+    for i in range(iters):
+        bnb.matmul_fp4(A, B_fp4, quant_state=state)
+    torch.cuda.synchronize()
+    print(f"bnb fp4: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

     torch.cuda.synchronize()
     t0 = time.time()
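Every timing in test_bench_matmul, including the fp4 benchmark added above, follows the same pattern: a warmup loop, then torch.cuda.synchronize() around a timed loop so that asynchronously queued GPU work is fully counted. Factored out, the pattern looks like this (a sketch; the test inlines these loops):

    import time
    import torch

    def bench(fn, iters=128):
        for _ in range(iters):      # warmup
            fn()
        torch.cuda.synchronize()    # drain queued kernels before timing
        t0 = time.time()
        for _ in range(iters):
            fn()
        torch.cuda.synchronize()    # wait for the timed kernels to finish
        return time.time() - t0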
@@ -1872,7 +1876,7 @@ def test_bench_matmul(batch, seq, model, hidden):
     Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
     F.vectorwise_mm_dequant(Cout, statsA, statsB.t())
     torch.cuda.synchronize()
-    #print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    print(f"vector pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

     BA, statsB = F.vectorwise_quant(B, dim=1, quant_type="linear")
     CxB, SB = F.nvidia_transform(CB, to_order=formatB)
@@ -1886,7 +1890,7 @@ def test_bench_matmul(batch, seq, model, hidden):
     Cout, Sout = F.nvidia_transform(out32, "row", state=Sout32)
     out = Cout * statsB * statsA * (1.0 / (127 * 127))
     torch.cuda.synchronize()
-    #print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    print(f"linear pytorch + nvidia: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

     linear8bit(A)
     torch.cuda.synchronize()
@@ -1894,9 +1898,7 @@ def test_bench_matmul(batch, seq, model, hidden):
     for i in range(iters):
         linear8bit(A)
     torch.cuda.synchronize()
-    print(f"bnb linear8bitlt: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    print(f"bnb linear8bitlt (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

     linearMixedBit(A)
     torch.cuda.synchronize()
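The line out = Cout * statsB * statsA * (1.0 / (127 * 127)) undoes the symmetric int8 scaling: if CA is roughly A / statsA * 127 and CB is roughly B / statsB * 127, then A @ B is approximately (CA @ CB) * statsA * statsB / (127 * 127). A minimal CPU sketch of that identity (the test itself uses F.vectorwise_quant, F.nvidia_transform, and the CUDA igemm path; the symmetric rounding below is an assumption for illustration):

    import torch

    A = torch.randn(4, 8)
    B = torch.randn(8, 3)
    statsA = A.abs().amax(dim=1, keepdim=True)       # per-row absmax of A
    statsB = B.abs().amax(dim=0, keepdim=True)       # per-column absmax of B
    CA = torch.round(A / statsA * 127).to(torch.int8)
    CB = torch.round(B / statsB * 127).to(torch.int8)
    C32 = CA.long() @ CB.long()                      # integer accumulation
    out = C32.float() * statsA * statsB * (1.0 / (127 * 127))
    print((out - A @ B).abs().mean())                # small quantization error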
@@ -1904,9 +1906,23 @@ def test_bench_matmul(batch, seq, model, hidden):
     for i in range(iters):
         linearMixedBit(A)
     torch.cuda.synchronize()
-    print(f"bnb linear8bitlt with threshold: [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")
+    print(f"bnb linear8bitlt with threshold (eval): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

+    linear8bit_train(A)
+    torch.cuda.synchronize()
+    t0 = time.time()
+    for i in range(iters):
+        linear8bit_train(A)
+    torch.cuda.synchronize()
+    print(f"bnb linear8bitlt (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

+    linear8bit_train_thresh(A)
+    torch.cuda.synchronize()
+    t0 = time.time()
+    for i in range(iters):
+        linear8bit_train(A)
+    torch.cuda.synchronize()
+    print(f"bnb linear8bitlt with threshold (training): [{batch},{seq},{model}], [{model},{hidden}]->[{batch},{seq},{hidden}]: {time.time()-t0:.4f}s")

 def test_zeropoint():
     def quant_zp(x):
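With the two training-mode layers added here, the benchmark now covers eval and training paths, each with and without the outlier threshold. Note that the timed loop in the threshold section still calls linear8bit_train rather than linear8bit_train_thresh, so as committed both (training) numbers measure the same layer. Using the bench helper sketched earlier, the four measurements would reduce to (a hypothetical wrapper; the test inlines the loops):

    for name, layer in [("eval", linear8bit),
                        ("eval, threshold", linearMixedBit),
                        ("training", linear8bit_train),
                        ("training, threshold", linear8bit_train_thresh)]:
        print(name, bench(lambda: layer(A), iters=128))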
@@ -2050,7 +2066,6 @@ def test_fp8_quant():
         p_bits = 7-e_bits
         code = F.create_fp8_map(True, e_bits, p_bits).cuda()
-        print(e_bits, p_bits)
         abserr = []
         relerr = []
         for i in range(100):
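test_fp8_quant sweeps signed 8-bit splits with e_bits + p_bits = 7 (one bit for the sign); the removed print only echoed the current split. The shape of the sweep is roughly (the loop bounds are an assumption; the diff shows only the loop body):

    # e.g. e_bits=5, p_bits=2 is an E5M2-style format,
    #      e_bits=4, p_bits=3 an E4M3-style one
    for e_bits in range(1, 7):
        p_bits = 7 - e_bits
        code = F.create_fp8_map(True, e_bits, p_bits).cuda()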
@@ -2189,7 +2204,6 @@ def test_bench_dequantization():
     torch.cuda.synchronize()
     t0 = time.time()
     for i in range(100):
-        #F.dequantize_blockwise(qa, SA, blocksize=2048)
         qa, SA = F.quantize_blockwise(a)
     torch.cuda.synchronize()
     #print((time.time()-t0)/1e6)
@@ -2240,7 +2254,7 @@ def test_bench_fp4_dequant():
     num_bytes = input_size+output_size
     GB = num_bytes/1e9
     max_theoretical_s = GB/768
-    print(max_theoretical_s*1e6)
+    #print(max_theoretical_s*1e6)
     b = torch.randn(128, 1024*12, device='cuda').half()
     iters = 5
@@ -2250,14 +2264,14 @@ def test_bench_fp4_dequant():
         F.dequantize_fp4(qa, SA, blocksize=blocksize)
         #b.copy_(a)
     torch.cuda.synchronize()
-    print((time.time()-t0)/iters*1e6)
+    #print((time.time()-t0)/iters*1e6)

-    torch.cuda.synchronize()
-    t0 = time.time()
-    for i in range(iters):
-        torch.matmul(b, a.t())
-    torch.cuda.synchronize()
-    print((time.time()-t0)/iters*1e6)
+    #torch.cuda.synchronize()
+    #t0 = time.time()
+    #for i in range(iters):
+    #    torch.matmul(b, a.t())
+    #torch.cuda.synchronize()
+    #print((time.time()-t0)/iters*1e6)
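The silenced max_theoretical_s print was a roofline-style bound: fp4 dequantization is memory-bound, so the best achievable time is bytes moved divided by memory bandwidth (768 GB/s here, presumably the effective bandwidth of the benchmark GPU). A worked instance with hypothetical sizes:

    # hypothetical: dequantize n fp4 values to fp16
    n = 1024 * 1024 * 12
    input_size = n // 2             # fp4 packs two values per byte (state tensors ignored)
    output_size = n * 2             # fp16 output, two bytes per value
    GB = (input_size + output_size) / 1e9
    max_theoretical_s = GB / 768    # assuming ~768 GB/s effective bandwidth
    print(f"lower bound: {max_theoretical_s * 1e6:.1f} us")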