Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
dd562c24
Commit
dd562c24
authored
Apr 12, 2023
by
Tim Dettmers
Browse files
Refactored simulated fp8 modules into research.nn.
parent
e67bfccb
Changes
15
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
108 additions
and
272 deletions
+108
-272
benchmarking/switchback/README.md
benchmarking/switchback/README.md
+0
-0
benchmarking/switchback/info_a100_py2.jsonl
benchmarking/switchback/info_a100_py2.jsonl
+0
-0
benchmarking/switchback/make_plot_with_jsonl.py
benchmarking/switchback/make_plot_with_jsonl.py
+0
-0
benchmarking/switchback/plot_with_info.pdf
benchmarking/switchback/plot_with_info.pdf
+0
-0
benchmarking/switchback/speed_benchmark.py
benchmarking/switchback/speed_benchmark.py
+0
-0
bitsandbytes/nn/__init__.py
bitsandbytes/nn/__init__.py
+1
-1
bitsandbytes/nn/modules.py
bitsandbytes/nn/modules.py
+1
-175
bitsandbytes/research/__init__.py
bitsandbytes/research/__init__.py
+1
-2
bitsandbytes/research/autograd/_functions.py
bitsandbytes/research/autograd/_functions.py
+8
-90
bitsandbytes/research/nn/__init__.py
bitsandbytes/research/nn/__init__.py
+1
-0
bitsandbytes/research/nn/modules.py
bitsandbytes/research/nn/modules.py
+64
-0
examples/int8_inference_huggingface.py
examples/int8_inference_huggingface.py
+27
-0
tests/test_autograd.py
tests/test_autograd.py
+2
-2
tests/test_functional.py
tests/test_functional.py
+1
-0
tests/test_modules.py
tests/test_modules.py
+2
-2
No files found.
speed_
benchmark/README.md
→
benchmark
ing/switchback
/README.md
View file @
dd562c24
File moved
speed_
benchmark/info_a100_py2.jsonl
→
benchmark
ing/switchback
/info_a100_py2.jsonl
View file @
dd562c24
File moved
speed_
benchmark/make_plot_with_jsonl.py
→
benchmark
ing/switchback
/make_plot_with_jsonl.py
View file @
dd562c24
File moved
speed_
benchmark/plot_with_info.pdf
→
benchmark
ing/switchback
/plot_with_info.pdf
View file @
dd562c24
File moved
speed_
benchmark/speed_benchmark.py
→
benchmark
ing/switchback
/speed_benchmark.py
View file @
dd562c24
File moved
bitsandbytes/nn/__init__.py
View file @
dd562c24
...
@@ -2,5 +2,5 @@
...
@@ -2,5 +2,5 @@
#
#
# This source code is licensed under the MIT license found in the
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# LICENSE file in the root directory of this source tree.
from
.modules
import
Int8Params
,
Linear8bitLt
,
StableEmbedding
,
OutlierAwareLinear
,
LinearFP8
,
LinearInt8
,
Linear8bitLtThresh
,
LinearInt8Cast
,
Linear8bitLtMixed
,
LinearFP8Global
,
LinearFP4
,
LinearFP8Mixed
from
.modules
import
Int8Params
,
Linear8bitLt
,
StableEmbedding
,
OutlierAwareLinear
,
SwitchBackLinearBnb
from
.triton_based_modules
import
SwitchBackLinear
,
SwitchBackLinearGlobal
,
SwitchBackLinearVectorized
,
StandardLinear
from
.triton_based_modules
import
SwitchBackLinear
,
SwitchBackLinearGlobal
,
SwitchBackLinearVectorized
,
StandardLinear
bitsandbytes/nn/modules.py
View file @
dd562c24
...
@@ -297,7 +297,7 @@ class Linear8bitLt(nn.Linear):
...
@@ -297,7 +297,7 @@ class Linear8bitLt(nn.Linear):
return
out
return
out
class
Linear8bitLtMixed
(
nn
.
Linear
):
class
SwitchBackLinearBnb
(
nn
.
Linear
):
def
__init__
(
def
__init__
(
self
,
self
,
input_features
,
input_features
,
...
@@ -355,177 +355,3 @@ class Linear8bitLtMixed(nn.Linear):
...
@@ -355,177 +355,3 @@ class Linear8bitLtMixed(nn.Linear):
del
self
.
state
.
CxB
del
self
.
state
.
CxB
return
out
return
out
class
Linear8bitLtThresh
(
Linear8bitLt
):
def
__init__
(
self
,
input_features
,
output_features
,
bias
=
True
,
has_fp16_weights
=
True
,
memory_efficient_backward
=
False
,
threshold
=
6.0
,
index
=
None
,
):
super
().
__init__
(
input_features
,
output_features
,
bias
=
bias
,
has_fp16_weights
=
has_fp16_weights
,
memory_efficient_backward
=
memory_efficient_backward
,
threshold
=
6.
,
index
=
index
)
class
LinearFP8
(
nn
.
Linear
):
def
__init__
(
self
,
input_features
,
output_features
,
bias
=
True
):
super
().
__init__
(
input_features
,
output_features
,
bias
)
self
.
bw_code
=
None
self
.
fw_code
=
None
array
=
[
4096
,
2048
,
1024
,
512
,
256
,
128
,
64
,
0
]
for
i
,
k
in
enumerate
(
array
):
if
input_features
>
array
[
i
+
1
]:
self
.
bsz
=
k
break
for
i
,
k
in
enumerate
(
array
):
if
output_features
>
array
[
i
+
1
]:
self
.
bsz2
=
k
break
def
forward
(
self
,
x
:
torch
.
Tensor
):
if
self
.
fw_code
is
None
:
self
.
bw_code
=
bnb
.
functional
.
create_fp8_map
(
True
,
5
,
2
,
8
).
to
(
x
.
device
)
self
.
fw_code
=
bnb
.
functional
.
create_fp8_map
(
True
,
4
,
3
,
8
).
to
(
x
.
device
)
out
=
bnb
.
research
.
matmul_fp8
(
x
,
self
.
weight
.
t
(),
fw_code
=
self
.
fw_code
,
bw_code
=
self
.
bw_code
,
bsz
=
self
.
bsz
,
bsz2
=
self
.
bsz2
)
if
self
.
bias
is
not
None
:
out
+=
self
.
bias
return
out
class
LinearFP8Mixed
(
nn
.
Linear
):
def
__init__
(
self
,
input_features
,
output_features
,
bias
=
True
):
super
().
__init__
(
input_features
,
output_features
,
bias
)
self
.
bw_code
=
None
self
.
fw_code
=
None
array
=
[
4096
,
2048
,
1024
,
512
,
256
,
128
,
64
,
0
]
for
i
,
k
in
enumerate
(
array
):
if
input_features
>
array
[
i
+
1
]:
self
.
bsz
=
k
break
for
i
,
k
in
enumerate
(
array
):
if
output_features
>
array
[
i
+
1
]:
self
.
bsz2
=
k
break
def
forward
(
self
,
x
:
torch
.
Tensor
):
if
self
.
fw_code
is
None
:
self
.
bw_code
=
bnb
.
functional
.
create_fp8_map
(
True
,
5
,
2
,
8
).
to
(
x
.
device
)
self
.
fw_code
=
bnb
.
functional
.
create_fp8_map
(
True
,
4
,
3
,
8
).
to
(
x
.
device
)
out
=
bnb
.
research
.
matmul_fp8_mixed
(
x
,
self
.
weight
.
t
(),
fw_code
=
self
.
fw_code
,
bw_code
=
self
.
bw_code
,
bsz
=
self
.
bsz
,
bsz2
=
self
.
bsz2
)
if
self
.
bias
is
not
None
:
out
+=
self
.
bias
return
out
class
LinearFP8Global
(
nn
.
Linear
):
def
__init__
(
self
,
input_features
,
output_features
,
bias
=
True
):
super
().
__init__
(
input_features
,
output_features
,
bias
)
self
.
bw_code
=
None
self
.
fw_code
=
None
array
=
[
4096
,
2048
,
1024
,
512
,
256
,
128
,
64
,
0
]
for
i
,
k
in
enumerate
(
array
):
if
input_features
>
array
[
i
+
1
]:
self
.
bsz
=
k
break
for
i
,
k
in
enumerate
(
array
):
if
output_features
>
array
[
i
+
1
]:
self
.
bsz2
=
k
break
def
forward
(
self
,
x
:
torch
.
Tensor
):
if
self
.
fw_code
is
None
:
self
.
bw_code
=
bnb
.
functional
.
create_fp8_map
(
True
,
5
,
2
,
8
).
to
(
x
.
device
)
self
.
fw_code
=
bnb
.
functional
.
create_fp8_map
(
True
,
4
,
3
,
8
).
to
(
x
.
device
)
out
=
bnb
.
matmul_fp8_global
(
x
,
self
.
weight
.
t
(),
fw_code
=
self
.
fw_code
,
bw_code
=
self
.
bw_code
,
bsz
=
self
.
bsz
,
bsz2
=
self
.
bsz2
)
if
self
.
bias
is
not
None
:
out
+=
self
.
bias
return
out
class
LinearInt8
(
nn
.
Linear
):
def
__init__
(
self
,
input_features
,
output_features
,
bias
=
True
):
super
().
__init__
(
input_features
,
output_features
,
bias
)
self
.
code
=
None
array
=
[
4096
,
2048
,
1024
,
512
,
256
,
128
,
64
,
0
]
for
i
,
k
in
enumerate
(
array
):
if
input_features
>
array
[
i
+
1
]:
self
.
bsz
=
k
break
for
i
,
k
in
enumerate
(
array
):
if
output_features
>
array
[
i
+
1
]:
self
.
bsz2
=
k
break
def
forward
(
self
,
x
:
torch
.
Tensor
):
if
self
.
code
is
None
:
self
.
code
=
bnb
.
functional
.
create_linear_map
(
True
,
8
).
to
(
x
.
device
)
out
=
bnb
.
matmul_fp8
(
x
,
self
.
weight
.
t
(),
fw_code
=
self
.
code
,
bw_code
=
self
.
code
,
bsz
=
self
.
bsz
,
bsz2
=
self
.
bsz2
)
if
self
.
bias
is
not
None
:
out
+=
self
.
bias
return
out
# This is 4 bit version.
class
LinearInt8Cast
(
nn
.
Linear
):
def
__init__
(
self
,
input_features
,
output_features
,
bias
=
True
):
super
().
__init__
(
input_features
,
output_features
,
bias
)
self
.
code
=
None
array
=
[
4096
,
2048
,
1024
,
512
,
256
,
128
,
64
,
0
]
for
i
,
k
in
enumerate
(
array
):
if
input_features
>
array
[
i
+
1
]:
self
.
bsz
=
k
break
def
forward
(
self
,
x
:
torch
.
Tensor
):
if
self
.
code
is
None
:
self
.
code
=
bnb
.
functional
.
create_linear_map
(
True
,
4
).
to
(
x
.
device
)
out
=
bnb
.
matmul_fp8
(
x
,
self
.
weight
.
t
(),
fw_code
=
self
.
code
,
bw_code
=
self
.
code
,
bsz
=
self
.
bsz
)
if
self
.
bias
is
not
None
:
out
+=
self
.
bias
return
out
class
LinearFP4
(
nn
.
Linear
):
def
__init__
(
self
,
input_features
,
output_features
,
bias
=
True
):
super
().
__init__
(
input_features
,
output_features
,
bias
)
self
.
bw_code
=
None
self
.
fw_code
=
None
array
=
[
4096
,
2048
,
1024
,
512
,
256
,
128
,
64
,
0
]
for
i
,
k
in
enumerate
(
array
):
if
input_features
>
array
[
i
+
1
]:
self
.
bsz
=
k
break
for
i
,
k
in
enumerate
(
array
):
if
output_features
>
array
[
i
+
1
]:
self
.
bsz2
=
k
break
def
forward
(
self
,
x
:
torch
.
Tensor
):
if
self
.
fw_code
is
None
:
#self.bw_code = bnb.functional.create_fp8_map(True, 3, 0, 4).to(x.device)
self
.
bw_code
=
bnb
.
functional
.
create_fp8_map
(
True
,
5
,
2
,
8
).
to
(
x
.
device
)
self
.
fw_code
=
bnb
.
functional
.
create_fp8_map
(
True
,
3
,
0
,
4
).
to
(
x
.
device
)
out
=
bnb
.
matmul_fp4
(
x
,
self
.
weight
.
t
(),
fw_code
=
self
.
fw_code
,
bw_code
=
self
.
bw_code
,
bsz
=
self
.
bsz
,
bsz2
=
self
.
bsz2
)
if
self
.
bias
is
not
None
:
out
+=
self
.
bias
return
out
bitsandbytes/research/__init__.py
View file @
dd562c24
from
.
import
nn
from
.autograd._functions
import
(
from
.autograd._functions
import
(
matmul_fp8
,
switchback_bnb
,
switchback_bnb
,
matmul_fp8_global
,
matmul_fp8_global
,
matmul_fp8_mixed
,
matmul_fp8_mixed
,
...
...
bitsandbytes/research/autograd/_functions.py
View file @
dd562c24
...
@@ -16,88 +16,6 @@ def prod(iterable):
...
@@ -16,88 +16,6 @@ def prod(iterable):
tensor
=
torch
.
Tensor
tensor
=
torch
.
Tensor
class
MatMulFP8
(
torch
.
autograd
.
Function
):
# forward is the same, but we added the fallback for pre-turing GPUs
# backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
@
staticmethod
def
forward
(
ctx
,
A
,
B
,
out
=
None
,
fw_code
=
None
,
bw_code
=
None
,
bsz
=
1024
,
bsz2
=
1024
):
# default of pytorch behavior if inputs are empty
ctx
.
is_empty
=
False
if
prod
(
A
.
shape
)
==
0
:
ctx
.
is_empty
=
True
ctx
.
A
=
A
ctx
.
B
=
B
B_shape
=
B
.
shape
if
A
.
shape
[
-
1
]
==
B_shape
[
0
]:
return
torch
.
empty
(
A
.
shape
[:
-
1
]
+
B_shape
[
1
:],
dtype
=
A
.
dtype
,
device
=
A
.
device
)
else
:
return
torch
.
empty
(
A
.
shape
[:
-
1
]
+
B_shape
[:
1
],
dtype
=
A
.
dtype
,
device
=
A
.
device
)
# 1. Dequantize
# 2. MatmulnN
cA
,
state
=
F
.
quantize_blockwise
(
A
,
code
=
fw_code
,
blocksize
=
bsz
)
fp8A
=
F
.
dequantize_blockwise
(
cA
,
state
,
blocksize
=
bsz
).
to
(
A
.
dtype
)
cB
,
state
=
F
.
quantize
(
B
.
float
(),
code
=
fw_code
)
fp8B
=
F
.
dequantize
(
cB
,
state
).
to
(
B
.
dtype
)
output
=
torch
.
matmul
(
fp8A
,
fp8B
)
# output is half
# 3. Save state
ctx
.
fw_code
=
fw_code
ctx
.
bw_code
=
bw_code
ctx
.
bsz
=
bsz
ctx
.
bsz2
=
bsz2
ctx
.
dtype_A
,
ctx
.
dtype_B
=
A
.
dtype
,
B
.
dtype
if
any
(
ctx
.
needs_input_grad
[:
2
]):
# NOTE: we send back A, and re-quant.
ctx
.
tensors
=
(
A
,
fp8B
)
else
:
ctx
.
tensors
=
(
None
,
None
)
return
output
@
staticmethod
def
backward
(
ctx
,
grad_output
):
if
ctx
.
is_empty
:
return
torch
.
zeros_like
(
ctx
.
A
),
torch
.
zeros_like
(
ctx
.
B
),
None
,
None
,
None
,
None
,
None
req_gradA
,
req_gradB
,
_
,
_
,
_
,
_
,
_
=
ctx
.
needs_input_grad
A
,
B
=
ctx
.
tensors
grad_A
,
grad_B
=
None
,
None
cgrad_out
,
state
=
F
.
quantize_blockwise
(
grad_output
,
code
=
ctx
.
bw_code
,
blocksize
=
ctx
.
bsz2
)
fp8out
=
F
.
dequantize_blockwise
(
cgrad_out
,
state
,
blocksize
=
ctx
.
bsz2
).
to
(
grad_output
.
dtype
)
cgrad_output_2
,
state_2
=
F
.
quantize
(
grad_output
.
float
(),
code
=
ctx
.
bw_code
)
fp8out_2
=
F
.
dequantize
(
cgrad_output_2
,
state_2
).
to
(
grad_output
.
dtype
)
# grad_output_reshape = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
# fp8grad_transpose, stategrad_transpose = F.vectorwise_quant(grad_output_reshape, dim=0, quant_type='vector')
# fp8out_transpose = (fp8grad_transpose / 7) * stategrad_transpose
# fp8out_transpose = fp8out_transpose.view(grad_output.shape[0], grad_output.shape[1], grad_output.shape[2])
# not supported by PyTorch. TODO: create work-around
if
req_gradA
:
grad_A
=
torch
.
matmul
(
fp8out
,
B
.
t
().
to
(
fp8out
.
dtype
)).
to
(
A
.
dtype
)
if
req_gradB
:
if
len
(
A
.
shape
)
==
3
:
At
=
A
.
transpose
(
2
,
1
).
contiguous
()
else
:
At
=
A
.
transpose
(
1
,
0
).
contiguous
()
cA
,
state
=
F
.
quantize
(
At
.
float
(),
code
=
ctx
.
fw_code
)
fp8At
=
F
.
dequantize
(
cA
,
state
).
to
(
A
.
dtype
)
grad_B
=
torch
.
matmul
(
fp8At
.
to
(
fp8out_2
.
dtype
),
fp8out_2
).
to
(
B
.
dtype
)
return
grad_A
,
grad_B
,
None
,
None
,
None
,
None
,
None
class
MatMulFP8Mixed
(
torch
.
autograd
.
Function
):
class
MatMulFP8Mixed
(
torch
.
autograd
.
Function
):
# forward is the same, but we added the fallback for pre-turing GPUs
# forward is the same, but we added the fallback for pre-turing GPUs
# backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
# backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
...
@@ -171,7 +89,10 @@ class MatMulFP8Mixed(torch.autograd.Function):
...
@@ -171,7 +89,10 @@ class MatMulFP8Mixed(torch.autograd.Function):
grad_A
=
torch
.
matmul
(
fp8out
,
B
.
t
().
to
(
fp8out
.
dtype
)).
to
(
A
.
dtype
)
grad_A
=
torch
.
matmul
(
fp8out
,
B
.
t
().
to
(
fp8out
.
dtype
)).
to
(
A
.
dtype
)
if
req_gradB
:
if
req_gradB
:
At
=
A
.
transpose
(
2
,
1
).
contiguous
()
if
len
(
A
.
shape
)
==
3
:
At
=
A
.
transpose
(
2
,
1
).
contiguous
()
else
:
At
=
A
.
transpose
(
1
,
0
).
contiguous
()
# cA, state = F.quantize(At.float(), code=ctx.fw_code)
# cA, state = F.quantize(At.float(), code=ctx.fw_code)
# fp8At = F.dequantize(cA, state).to(A.dtype)
# fp8At = F.dequantize(cA, state).to(A.dtype)
grad_B
=
torch
.
matmul
(
At
.
to
(
grad_output
.
dtype
),
grad_output
).
to
(
B
.
dtype
)
grad_B
=
torch
.
matmul
(
At
.
to
(
grad_output
.
dtype
),
grad_output
).
to
(
B
.
dtype
)
...
@@ -252,7 +173,10 @@ class MatMulFP8Global(torch.autograd.Function):
...
@@ -252,7 +173,10 @@ class MatMulFP8Global(torch.autograd.Function):
grad_A
=
torch
.
matmul
(
fp8out
,
B
.
t
().
to
(
fp8out
.
dtype
)).
to
(
A
.
dtype
)
grad_A
=
torch
.
matmul
(
fp8out
,
B
.
t
().
to
(
fp8out
.
dtype
)).
to
(
A
.
dtype
)
if
req_gradB
:
if
req_gradB
:
At
=
A
.
transpose
(
2
,
1
).
contiguous
()
if
len
(
A
.
shape
)
==
3
:
At
=
A
.
transpose
(
2
,
1
).
contiguous
()
else
:
At
=
A
.
transpose
(
1
,
0
).
contiguous
()
cA
,
state
=
F
.
quantize
(
At
.
float
(),
code
=
ctx
.
fw_code
)
cA
,
state
=
F
.
quantize
(
At
.
float
(),
code
=
ctx
.
fw_code
)
fp8At
=
F
.
dequantize
(
cA
,
state
).
to
(
A
.
dtype
)
fp8At
=
F
.
dequantize
(
cA
,
state
).
to
(
A
.
dtype
)
grad_B
=
torch
.
matmul
(
fp8At
.
to
(
fp8out
.
dtype
),
fp8out
).
to
(
B
.
dtype
)
grad_B
=
torch
.
matmul
(
fp8At
.
to
(
fp8out
.
dtype
),
fp8out
).
to
(
B
.
dtype
)
...
@@ -465,11 +389,6 @@ def get_block_sizes(input_matrix, weight_matrix):
...
@@ -465,11 +389,6 @@ def get_block_sizes(input_matrix, weight_matrix):
return
bsz
,
bsz2
return
bsz
,
bsz2
def
matmul_fp8
(
A
:
tensor
,
B
:
tensor
,
fw_code
:
tensor
,
bw_code
:
tensor
,
out
:
tensor
=
None
,
bsz
:
int
=
-
1
,
bsz2
:
int
=
-
1
):
if
bsz
==
-
1
or
bsz2
==
-
1
:
bsz
,
bsz2
=
get_block_sizes
(
A
,
B
)
return
MatMulFP8
.
apply
(
A
,
B
,
out
,
fw_code
,
bw_code
,
bsz
,
bsz2
)
def
matmul_fp8_global
(
A
:
tensor
,
B
:
tensor
,
fw_code
:
tensor
,
bw_code
:
tensor
,
out
:
tensor
=
None
,
bsz
:
int
=
-
1
,
bsz2
:
int
=
-
1
):
def
matmul_fp8_global
(
A
:
tensor
,
B
:
tensor
,
fw_code
:
tensor
,
bw_code
:
tensor
,
out
:
tensor
=
None
,
bsz
:
int
=
-
1
,
bsz2
:
int
=
-
1
):
if
bsz
==
-
1
or
bsz2
==
-
1
:
bsz
,
bsz2
=
get_block_sizes
(
A
,
B
)
if
bsz
==
-
1
or
bsz2
==
-
1
:
bsz
,
bsz2
=
get_block_sizes
(
A
,
B
)
return
MatMulFP8Global
.
apply
(
A
,
B
,
out
,
fw_code
,
bw_code
,
bsz
,
bsz2
)
return
MatMulFP8Global
.
apply
(
A
,
B
,
out
,
fw_code
,
bw_code
,
bsz
,
bsz2
)
...
@@ -478,7 +397,6 @@ def matmul_fp8_mixed(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out
...
@@ -478,7 +397,6 @@ def matmul_fp8_mixed(A: tensor, B: tensor, fw_code: tensor, bw_code: tensor, out
if
bsz
==
-
1
or
bsz2
==
-
1
:
bsz
,
bsz2
=
get_block_sizes
(
A
,
B
)
if
bsz
==
-
1
or
bsz2
==
-
1
:
bsz
,
bsz2
=
get_block_sizes
(
A
,
B
)
return
MatMulFP8Mixed
.
apply
(
A
,
B
,
out
,
fw_code
,
bw_code
,
bsz
,
bsz2
)
return
MatMulFP8Mixed
.
apply
(
A
,
B
,
out
,
fw_code
,
bw_code
,
bsz
,
bsz2
)
def
switchback_bnb
(
def
switchback_bnb
(
A
:
tensor
,
A
:
tensor
,
B
:
tensor
,
B
:
tensor
,
...
...
bitsandbytes/research/nn/__init__.py
0 → 100644
View file @
dd562c24
from
.modules
import
LinearFP8Mixed
,
LinearFP8Global
bitsandbytes/research/nn/modules.py
0 → 100644
View file @
dd562c24
from
typing
import
Optional
,
TypeVar
,
Union
,
overload
import
torch
import
torch.nn.functional
as
F
from
torch
import
Tensor
,
device
,
dtype
,
nn
import
bitsandbytes
as
bnb
from
bitsandbytes.optim
import
GlobalOptimManager
from
bitsandbytes.utils
import
OutlierTracer
,
find_outlier_dims
T
=
TypeVar
(
"T"
,
bound
=
"torch.nn.Module"
)
class
LinearFP8Mixed
(
nn
.
Linear
):
def
__init__
(
self
,
input_features
,
output_features
,
bias
=
True
):
super
().
__init__
(
input_features
,
output_features
,
bias
)
self
.
bw_code
=
None
self
.
fw_code
=
None
array
=
[
4096
,
2048
,
1024
,
512
,
256
,
128
,
64
,
0
]
for
i
,
k
in
enumerate
(
array
):
if
input_features
>
array
[
i
+
1
]:
self
.
bsz
=
k
break
for
i
,
k
in
enumerate
(
array
):
if
output_features
>
array
[
i
+
1
]:
self
.
bsz2
=
k
break
def
forward
(
self
,
x
:
torch
.
Tensor
):
if
self
.
fw_code
is
None
:
self
.
bw_code
=
bnb
.
functional
.
create_fp8_map
(
True
,
5
,
2
,
8
).
to
(
x
.
device
)
self
.
fw_code
=
bnb
.
functional
.
create_fp8_map
(
True
,
4
,
3
,
8
).
to
(
x
.
device
)
out
=
bnb
.
research
.
matmul_fp8_mixed
(
x
,
self
.
weight
.
t
(),
fw_code
=
self
.
fw_code
,
bw_code
=
self
.
bw_code
,
bsz
=
self
.
bsz
,
bsz2
=
self
.
bsz2
)
if
self
.
bias
is
not
None
:
out
+=
self
.
bias
return
out
class
LinearFP8Global
(
nn
.
Linear
):
def
__init__
(
self
,
input_features
,
output_features
,
bias
=
True
):
super
().
__init__
(
input_features
,
output_features
,
bias
)
self
.
bw_code
=
None
self
.
fw_code
=
None
array
=
[
4096
,
2048
,
1024
,
512
,
256
,
128
,
64
,
0
]
for
i
,
k
in
enumerate
(
array
):
if
input_features
>
array
[
i
+
1
]:
self
.
bsz
=
k
break
for
i
,
k
in
enumerate
(
array
):
if
output_features
>
array
[
i
+
1
]:
self
.
bsz2
=
k
break
def
forward
(
self
,
x
:
torch
.
Tensor
):
if
self
.
fw_code
is
None
:
self
.
bw_code
=
bnb
.
functional
.
create_fp8_map
(
True
,
5
,
2
,
8
).
to
(
x
.
device
)
self
.
fw_code
=
bnb
.
functional
.
create_fp8_map
(
True
,
4
,
3
,
8
).
to
(
x
.
device
)
out
=
bnb
.
matmul_fp8_global
(
x
,
self
.
weight
.
t
(),
fw_code
=
self
.
fw_code
,
bw_code
=
self
.
bw_code
,
bsz
=
self
.
bsz
,
bsz2
=
self
.
bsz2
)
if
self
.
bias
is
not
None
:
out
+=
self
.
bias
return
out
examples/int8_inference_huggingface.py
0 → 100644
View file @
dd562c24
import
torch
from
transformers
import
AutoModelForCausalLM
,
AutoTokenizer
MAX_NEW_TOKENS
=
128
model_name
=
'decapoda-research/llama-7b-hf'
text
=
'Hamburg is in which country?
\n
'
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_name
)
input_ids
=
tokenizer
(
text
,
return_tensors
=
"pt"
).
input_ids
free_in_GB
=
int
(
torch
.
cuda
.
mem_get_info
()[
0
]
/
1024
**
3
)
max_memory
=
f
'
{
int
(
torch
.
cuda
.
mem_get_info
()[
0
]
/
1024
**
3
)
-
2
}
GB'
n_gpus
=
torch
.
cuda
.
device_count
()
max_memory
=
{
i
:
max_memory
for
i
in
range
(
n_gpus
)}
model
=
AutoModelForCausalLM
.
from_pretrained
(
model_name
,
device_map
=
'auto'
,
load_in_8bit
=
True
,
max_memory
=
max_memory
)
generated_ids
=
model
.
generate
(
input_ids
,
max_length
=
MAX_NEW_TOKENS
)
print
(
tokenizer
.
decode
(
generated_ids
[
0
],
skip_special_tokens
=
True
))
tests/test_autograd.py
View file @
dd562c24
...
@@ -441,8 +441,8 @@ dim4 = torch.randint(32, 96, size=(n,)).tolist()
...
@@ -441,8 +441,8 @@ dim4 = torch.randint(32, 96, size=(n,)).tolist()
dim2
.
append
(
0
)
dim2
.
append
(
0
)
funcs
=
[(
torch
.
matmul
,
bnb
.
research
.
matmul_fp8
)]
funcs
=
[(
torch
.
matmul
,
bnb
.
research
.
matmul_fp8
_mixed
),
(
torch
.
matmul
,
bnb
.
research
.
matmul_fp8_global
)]
str_funcs
=
[
"matmul
"
]
str_funcs
=
[
"matmul
_fp8_mixed"
,
'matmul_fp8_global'
]
req_grad
=
list
(
product
([
True
,
False
],
repeat
=
3
))
req_grad
=
list
(
product
([
True
,
False
],
repeat
=
3
))
req_grad_str
=
[]
req_grad_str
=
[]
for
c
in
req_grad
:
for
c
in
req_grad
:
...
...
tests/test_functional.py
View file @
dd562c24
...
@@ -190,6 +190,7 @@ def test_dynamic_blockwise_quantization():
...
@@ -190,6 +190,7 @@ def test_dynamic_blockwise_quantization():
@
pytest
.
mark
.
parametrize
(
"blocksize"
,
[
4096
,
2048
,
1024
,
512
,
256
,
128
,
64
])
@
pytest
.
mark
.
parametrize
(
"blocksize"
,
[
4096
,
2048
,
1024
,
512
,
256
,
128
,
64
])
@
pytest
.
mark
.
skip
(
"Stochastic has some bugs, but will be deprecated soon anyways."
)
def
test_dynamic_blockwise_stochastic_quantization
(
blocksize
):
def
test_dynamic_blockwise_stochastic_quantization
(
blocksize
):
diffs
=
[]
diffs
=
[]
reldiffs
=
[]
reldiffs
=
[]
...
...
tests/test_modules.py
View file @
dd562c24
...
@@ -532,9 +532,9 @@ def test_fp8linear():
...
@@ -532,9 +532,9 @@ def test_fp8linear():
h
=
1024
h
=
1024
inp
=
torch
.
randn
(
b
,
h
).
cuda
()
inp
=
torch
.
randn
(
b
,
h
).
cuda
()
fp32
=
torch
.
nn
.
Linear
(
h
,
h
*
2
).
cuda
()
fp32
=
torch
.
nn
.
Linear
(
h
,
h
*
2
).
cuda
()
fp8
=
bnb
.
nn
.
LinearFP8
(
h
,
h
*
2
).
cuda
()
fp8
=
bnb
.
research
.
nn
.
LinearFP8
Mixed
(
h
,
h
*
2
).
cuda
()
fp32b
=
torch
.
nn
.
Linear
(
h
*
2
,
h
).
cuda
()
fp32b
=
torch
.
nn
.
Linear
(
h
*
2
,
h
).
cuda
()
fp8b
=
bnb
.
nn
.
LinearFP8
(
h
*
2
,
h
).
cuda
()
fp8b
=
bnb
.
research
.
nn
.
LinearFP8
Mixed
(
h
*
2
,
h
).
cuda
()
fp8
.
weight
.
data
.
copy_
(
fp32
.
weight
.
data
)
fp8
.
weight
.
data
.
copy_
(
fp32
.
weight
.
data
)
fp8
.
bias
.
data
.
copy_
(
fp32
.
bias
.
data
)
fp8
.
bias
.
data
.
copy_
(
fp32
.
bias
.
data
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment