OpenDAS / bitsandbytes · Commits · c9f50506

Commit c9f50506, authored Jan 28, 2023 by Tim Dettmers
Added outlier detector and fake quantization layer.
Parent: 1341fb44

Showing 6 changed files with 225 additions and 5 deletions (+225, -5):
bitsandbytes/functional.py     +3    -3
bitsandbytes/nn/__init__.py    +1    -1
bitsandbytes/nn/modules.py     +78   -0
bitsandbytes/utils.py          +136  -0
csrc/kernels.cu                +2    -0
tests/test_functional.py       +5    -1
bitsandbytes/functional.py

@@ -168,7 +168,7 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
    values = []
    lst = list(itertools.product([0, 1], repeat=precision_bits))
    #for ev in evalues:
-   bias = 2**(exponent_bits-1)-1
+   bias = 2**(exponent_bits-1)
    for evalue in range(2**(exponent_bits)):
        for bit_pattern in lst:
            value = (1 if evalue != 0 else 0)
@@ -176,10 +176,10 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
                value += pval*(2**-(i+1))
            if evalue == 0:
                # subnormals
-               value = value*2**-(bias-1)
+               value = value*2**-(bias)
            else:
                # normals
-               value = value*2**-(evalue-bias-2)
+               value = value*2**-(evalue-bias-1)
            values.append(value)
            if signed:
                values.append(-value)
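Note (not part of the commit): the bias change above only affects the subnormal branch; for the normal branch the old and new exponents coincide. A quick Python check, assuming exponent_bits = 3 as used by Fake4bitLinear below:

# Sketch only: compare the old and new exponent arithmetic in create_fp8_map.
E = 3
old_bias, new_bias = 2**(E - 1) - 1, 2**(E - 1)

# Normal branch (evalue != 0): identical before and after,
# since -(evalue - old_bias - 2) == -(evalue - new_bias - 1).
for evalue in range(1, 2**E):
    assert -(evalue - old_bias - 2) == -(evalue - new_bias - 1)

# Subnormal branch (evalue == 0): the scale shrinks from 2**-(old_bias - 1)
# to 2**-new_bias, i.e. from 0.25 to 0.0625 here, pulling the smallest
# code values closer to zero.
print(2.0**-(old_bias - 1), 2.0**-new_bias)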
bitsandbytes/nn/__init__.py

@@ -2,4 +2,4 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-from .modules import Int8Params, Linear8bitLt, StableEmbedding
+from .modules import Int8Params, Linear8bitLt, StableEmbedding, OutlierAwareLinear, Fake4bitLinear
bitsandbytes/nn/modules.py

@@ -10,6 +10,7 @@ from torch import Tensor, device, dtype, nn
import bitsandbytes as bnb
from bitsandbytes.optim import GlobalOptimManager
+from bitsandbytes.utils import OutlierTracer, find_outlier_dims

T = TypeVar("T", bound="torch.nn.Module")

@@ -133,6 +134,83 @@ class Embedding(torch.nn.Embedding):
        return emb


class OutlierAwareLinear(nn.Linear):
    def __init__(self, input_features, output_features, bias=True):
        super().__init__(input_features, output_features, bias)
        self.outlier_dim = None
        self.is_quantized = False

    def forward_with_outliers(self, x, outlier_idx):
        raise NotImplementedError('Please override the `forward_with_outliers(self, x, outlier_idx)` function')

    def quantize_weight(self, w, outlier_idx):
        raise NotImplementedError('Please override the `quantize_weights(self, w, outlier_idx)` function')

    def forward(self, x):
        if self.outlier_dim is None:
            tracer = OutlierTracer.get_instance()
            if not tracer.is_initialized():
                print('Please use OutlierTracer.initialize(model) before using the OutlierAwareLinear layer')
            outlier_idx = tracer.get_outliers(self.weight)
            #print(outlier_idx, tracer.get_hvalue(self.weight))
            self.outlier_dim = outlier_idx

        if not self.is_quantized:
            w = self.quantize_weight(self.weight, self.outlier_dim)
            self.weight.data.copy_(w)
            self.is_quantized = True

        return self.forward_with_outliers(x, self.outlier_dim)


class Fake4bitLinear(OutlierAwareLinear):
    def __init__(self, input_features, output_features, bias=True, codebook=bnb.functional.create_fp8_map(True, 3, 0, total_bits=4)):
        super().__init__(input_features, output_features, bias)
        self.codebook = codebook

    def quantize_weight(self, w, outlier_idx):
        if outlier_idx.numel() > 0:
            subw = w[:, outlier_idx].clone()
            w[:, outlier_idx] = 0
        wdtype = w.dtype
        code = self.codebook.to(w.device)
        cw, state = bnb.functional.quantize_blockwise(w, code=code, blocksize=64)
        w = bnb.functional.dequantize_blockwise(cw, state, blocksize=64)
        w = w.to(wdtype)
        if outlier_idx.numel() > 0:
            w[:, outlier_idx] = subw
        self.is_quantized = True
        return w

    def forward_with_outliers(self, x, outlier_idx):
        dims = torch.abs(x > 4).sum(dim=list(range(len(x.shape) - 1)))
        outlier_idx2 = torch.where(dims > 0)[0]
        outlier_idx = torch.cat([outlier_idx, outlier_idx2]).unique()
        n = x.shape[-1]
        idx = torch.arange(n, device=x.device)
        idx[outlier_idx] = -1
        inverse_idx = torch.where(idx >= 0)[0]
        if outlier_idx.numel() > 0:
            subx = x[..., outlier_idx].clone()
            #print(1, subx, 1)
            #x[..., outlier_idx] = 0
        inverse_x = x[..., inverse_idx]
        xdtype = x.dtype
        #code = bnb.functional.create_fp8_map(True, 4-3, 2, 4).to(x.device)
        #code = bnb.functional.create_quantile_map(x, 4).to(x.device)
        code = bnb.functional.create_dynamic_map(True, total_bits=4.0).to(x.device)
        c, state = bnb.functional.quantize_blockwise(inverse_x, code=code, blocksize=64)
        inverse_x = bnb.functional.dequantize_blockwise(c, state, blocksize=64)
        #c, state = bnb.functional.quantize_blockwise(x, code=code, blocksize=64)
        #x = bnb.functional.dequantize_blockwise(c, state, blocksize=64)
        x = x.to(xdtype)
        x[..., inverse_idx] = inverse_x.to(x.dtype)
        #if outlier_idx.numel() > 0:
            #x[..., outlier_idx] = subx
        return torch.nn.functional.linear(x, self.weight, self.bias)


class Int8Params(torch.nn.Parameter):
    def __new__(
...
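Note (not part of the commit): OutlierAwareLinear is an abstract base; subclasses override quantize_weight and forward_with_outliers, and OutlierTracer must be initialized on the model before the first forward pass (Fake4bitLinear above is the commit's own concrete subclass). A minimal sketch of that contract, with a deliberately naive placeholder quantizer:

# Sketch only: the override points of OutlierAwareLinear. The rounding below is
# a placeholder, not the library's quantization method.
import torch
from bitsandbytes.nn import OutlierAwareLinear
from bitsandbytes.utils import OutlierTracer

class NaiveOutlierLinear(OutlierAwareLinear):
    def quantize_weight(self, w, outlier_idx):
        # Round non-outlier weights to a coarse grid; keep outlier columns exact.
        wq = (w * 16).round() / 16
        if outlier_idx is not None and outlier_idx.numel() > 0:
            wq[:, outlier_idx] = w[:, outlier_idx]
        return wq

    def forward_with_outliers(self, x, outlier_idx):
        # The outlier dimensions of x could get special treatment here;
        # this placeholder just applies the (already quantized) weight.
        return torch.nn.functional.linear(x, self.weight, self.bias)

model = torch.nn.Sequential(
    NaiveOutlierLinear(64, 64),
    torch.nn.ReLU(),
    NaiveOutlierLinear(64, 64),
)
OutlierTracer.get_instance().initialize(model)  # registers the forward pre-hooks
out = model(torch.randn(2, 8, 64))              # first call detects outliers and quantizes the weights once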
bitsandbytes/utils.py

import shlex
import subprocess
import torch
from typing import Tuple


def outlier_hook(module, input):
    assert isinstance(module, torch.nn.Linear)
    tracer = OutlierTracer.get_instance()
    hvalue = tracer.get_hvalue(module.weight)
    if hvalue not in tracer.hvalue2outlier_idx:
        outlier_idx = find_outlier_dims(module.weight)
        tracer.outliers.append(outlier_idx)
        tracer.hvalues.append(hvalue)
        if len(tracer.outliers) > 1:
            # assign the current layer the outlier idx found from the weight
            # of the previous linear layer
            if tracer.outliers[-1].numel() > 0:
                assert tracer.outliers[-1].max() < module.weight.shape[1]
            tracer.hvalue2outlier_idx[hvalue] = tracer.outliers[-1]
        else:
            # first layer, we cannot use the weight for outlier detection
            # we follow a mixed approach:
            # (1) zscore test of std of hidden dimension
            # (2) magnitude > 6 test
            merged = input[0].view(-1, input[0].shape[-1])
            # (1) zscore test of std of hidden dimension
            outlier_idx = find_outlier_dims(merged, reduction_dim=1, zscore=3)
            # (2) magnitude > 6 test
            dims = (torch.abs(input[0]) > 6).sum(dim=list(range(len(input[0].shape) - 1)))
            outlier_idx2 = torch.where(dims > 0)[0]
            outlier_idx = torch.cat([outlier_idx, outlier_idx2]).unique()
            tracer.hvalue2outlier_idx[hvalue] = outlier_idx
    else:
        for hook in tracer.hooks:
            hook.remove()


class OutlierTracer(object):
    _instance = None

    def __init__(self):
        raise RuntimeError("Call get_instance() instead")

    def initialize(self, model):
        self.last_w = None
        self.current_outlier_dims = None
        self.hvalues = []
        self.outliers = []
        self.hvalue2outlier_idx = {}
        self.initialized = True
        self.hooks = []

        for n, m in model.named_modules():
            if isinstance(m, torch.nn.Linear):
                self.hooks.append(m.register_forward_pre_hook(outlier_hook))

    def is_initialized(self):
        return getattr(self, 'initialized', False)

    def get_hvalue(self, weight):
        return weight.data.storage().data_ptr()

    def get_outliers(self, weight):
        if not self.is_initialized():
            print('Outlier tracer is not initialized...')
            return None
        hvalue = self.get_hvalue(weight)
        if hvalue in self.hvalue2outlier_idx:
            return self.hvalue2outlier_idx[hvalue]
        else:
            return None

    @classmethod
    def get_instance(cls):
        if cls._instance is None:
            cls._instance = cls.__new__(cls)
        return cls._instance
def find_outlier_dims(weight, reduction_dim=0, zscore=4.0, topk=None, rdm=False):
    if rdm:
        return torch.randint(0, weight.shape[1], size=(topk,), device=weight.device).long()

    m = weight.mean(reduction_dim)
    mm = m.mean()
    mstd = m.std()
    zm = (m - mm) / mstd

    std = weight.std(reduction_dim)
    stdm = std.mean()
    stdstd = std.std()
    zstd = (std - stdm) / stdstd

    if topk is not None:
        val, idx = torch.topk(std.abs(), k=topk, dim=0)
    else:
        idx = torch.where(zstd > zscore)[0]

    return idx
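Note (not part of the commit): a small sanity check of the z-score test. A weight matrix with one artificially inflated hidden dimension should be flagged, while a plain random matrix usually yields an empty index tensor:

# Sketch only: exercising find_outlier_dims on a synthetic weight.
import torch
from bitsandbytes.utils import find_outlier_dims

w = torch.randn(4096, 1024)
w[:, 123] *= 20                       # inflate the std of one column

print(find_outlier_dims(w))           # expected to contain index 123
print(find_outlier_dims(w, topk=2))   # indices of the two largest-std columns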
def replace_linear(model, linear_replacement, skip_modules=["lm_head"], copy_weights=False, post_processing_function=None):
    """
    Replace linear modules with a new Linear module.

    Parameters:
        model (`torch.nn.Module`):
            Input model or `torch.nn.Module` as the function is run recursively.
        linear_replacement (`torch.nn.Module`):
            The linear module that replaces the old one. Only expects standard arguments.
            If other arguments need to be passed, use a lambda.
        skip_modules (`List[str]`, *optional*, defaults to `["lm_head"]`):
            List of module names not to convert.
        copy_weights (`bool`):
            Copy the weights from the old linear module to the new one.
        post_processing_function (`str`):
            The name of a function of the replacement linear class that is called
            after processing.
    """
    for name, module in model.named_children():
        if len(list(module.children())) > 0:
            replace_linear(module, linear_replacement, skip_modules, copy_weights, post_processing_function)

        if isinstance(module, torch.nn.Linear) and name not in skip_modules:
            old_module = model._modules[name]
            model._modules[name] = linear_replacement(
                module.in_features,
                module.out_features,
                module.bias is not None,
            )
            if copy_weights:
                model._modules[name].weight = old_module.weight
                model._modules[name].bias = old_module.bias

            if post_processing_function is not None:
                func = getattr(module, post_processing_function, None)
                if func is not None:
                    func(module)
    return model


def execute_and_return(command_string: str) -> Tuple[str, str]:
    def _decode(subprocess_err_out_tuple):
...
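Note (not part of the commit): a usage sketch for replace_linear together with the new Fake4bitLinear; the toy module tree is illustrative only, and lm_head is left alone by the default skip_modules.

# Sketch only: swap nn.Linear layers for Fake4bitLinear, keeping the weights.
import torch
from bitsandbytes.nn import Fake4bitLinear
from bitsandbytes.utils import OutlierTracer, replace_linear

model = torch.nn.ModuleDict({
    'encoder': torch.nn.Sequential(torch.nn.Linear(256, 256), torch.nn.GELU()),
    'lm_head': torch.nn.Linear(256, 512),
})

model = replace_linear(model, Fake4bitLinear, copy_weights=True)
print(type(model['encoder'][0]))  # bitsandbytes.nn.modules.Fake4bitLinear
print(type(model['lm_head']))     # torch.nn.Linear, skipped by default

# The tracer still needs to be initialized before the first forward pass:
OutlierTracer.get_instance().initialize(model)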
csrc/kernels.cu

@@ -543,7 +543,9 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * abs
      // load code through read-only cache via __ldg
      #pragma unroll NUM_PER_TH
      for(int j = 0; j < NUM_PER_TH; j++)
      {
        vals[j] = __ldg(&code[qvals[j]])*local_abs_max;
      }

+     __syncthreads();
      StoreT(storet).Store(&(out[i]), vals, valid_items);
tests/test_functional.py

@@ -2109,6 +2109,7 @@ def test_few_bit_quant():
        ebits = math.ceil(bits/2)
        pbits = bits - ebits - 1
        code = F.create_fp8_map(True, ebits, pbits, bits).cuda()
        print(code)
    elif method == 'dynamic':
        code = F.create_dynamic_map(True, bits-0, bits).cuda()
    elif method == 'quantile':

@@ -2181,7 +2182,9 @@ def test_kbit_quantile_estimation():
def test_bench_dequantization():
    a = torch.rand(1024, 1024, device='cuda').half()
    qa, SA = F.quantize_blockwise(a)
    code = F.create_fp8_map(True, 3, 0, 4).cuda()
    qa, SA = F.quantize_blockwise(a, code=code)
    print(qa.max())

    max_theoretical_mu = 1024*1024*2/1024**3/672*1000*1000
    #print(max_theoretical_mu)

@@ -2193,3 +2196,4 @@ def test_bench_dequantization():
    torch.cuda.synchronize()
    #print((time.time()-t0)/1e6)
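Note (not part of the commit): max_theoretical_mu reads like a back-of-the-envelope lower bound on dequantization time, i.e. the microseconds needed to stream 1024x1024 fp16 values once at roughly 672 GB/s; that reading is an assumption, the commit itself only states the arithmetic:

# Assumed reading of max_theoretical_mu; the 672 GB/s bandwidth figure is taken
# from the expression itself, its meaning is not stated in the commit.
bytes_moved = 1024 * 1024 * 2            # 1024x1024 half-precision values
seconds = bytes_moved / 1024**3 / 672    # GiB moved / (GiB/s assumed bandwidth)
print(seconds * 1000 * 1000)             # ~2.9 microseconds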