OpenDAS / bitsandbytes

Commit c9f50506, authored Jan 28, 2023 by Tim Dettmers
Added outlier detector and fake quantization layer.
Parent: 1341fb44

Showing 6 changed files with 225 additions and 5 deletions:
    bitsandbytes/functional.py     +3    -3
    bitsandbytes/nn/__init__.py    +1    -1
    bitsandbytes/nn/modules.py     +78   -0
    bitsandbytes/utils.py          +136  -0
    csrc/kernels.cu                +2    -0
    tests/test_functional.py       +5    -1
bitsandbytes/functional.py

@@ -168,7 +168,7 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
     values = []
     lst = list(itertools.product([0, 1], repeat=precision_bits))
     #for ev in evalues:
-    bias = 2**(exponent_bits-1)-1
+    bias = 2**(exponent_bits-1)
     for evalue in range(2**(exponent_bits)):
         for bit_pattern in lst:
             value = (1 if evalue != 0 else 0)
@@ -176,10 +176,10 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
                 value += pval*(2**-(i+1))
             if evalue == 0:
                 # subnormals
-                value = value*2**-(bias-1)
+                value = value*2**-(bias)
             else:
                 # normals
-                value = value*2**-(evalue-bias-2)
+                value = value*2**-(evalue-bias-1)
             values.append(value)
             if signed:
                 values.append(-value)
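Taken together, the two hunks redefine the bias as 2**(exponent_bits-1) and shift the exponent offsets by one: the normal-range values come out arithmetically unchanged, while the subnormal range is scaled down by a factor of four. The sketch below (not part of the commit) restates the per-entry decode rule after the change as a standalone function; exponent_bits=5 and precision_bits=2 simply mirror the defaults in the hunk header, and the full create_fp8_map may still sort and rescale the resulting code afterwards.

# Hedged sketch of the per-entry decode rule in create_fp8_map after this commit.
exponent_bits, precision_bits = 5, 2
bias = 2**(exponent_bits - 1)        # new bias; the old code used 2**(exponent_bits-1) - 1

def decode(evalue, mantissa_bits):
    # mantissa_bits is a bit pattern such as (1, 0) for precision_bits=2
    value = 1 if evalue != 0 else 0  # implicit leading bit only for normal values
    for i, pval in enumerate(mantissa_bits):
        value += pval * 2**-(i + 1)
    if evalue == 0:                  # subnormal branch
        return value * 2**-(bias)
    else:                            # normal branch
        return value * 2**-(evalue - bias - 1)

print(decode(0, (0, 1)))                     # small subnormal value
print(decode(1, (1, 1)))                     # largest-magnitude entry before any later rescaling
print(decode(2**exponent_bits - 1, (0, 0)))  # smallest-magnitude normal entry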
bitsandbytes/nn/__init__.py

@@ -2,4 +2,4 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-from .modules import Int8Params, Linear8bitLt, StableEmbedding
+from .modules import Int8Params, Linear8bitLt, StableEmbedding, OutlierAwareLinear, Fake4bitLinear
bitsandbytes/nn/modules.py

@@ -10,6 +10,7 @@ from torch import Tensor, device, dtype, nn
 import bitsandbytes as bnb
 from bitsandbytes.optim import GlobalOptimManager
+from bitsandbytes.utils import OutlierTracer, find_outlier_dims

 T = TypeVar("T", bound="torch.nn.Module")

@@ -133,6 +134,83 @@ class Embedding(torch.nn.Embedding):
        return emb

class OutlierAwareLinear(nn.Linear):
    def __init__(self, input_features, output_features, bias=True):
        super().__init__(input_features, output_features, bias)
        self.outlier_dim = None
        self.is_quantized = False

    def forward_with_outliers(self, x, outlier_idx):
        raise NotImplementedError('Please override the `forward_with_outliers(self, x, outlier_idx)` function')

    def quantize_weight(self, w, outlier_idx):
        raise NotImplementedError('Please override the `quantize_weights(self, w, outlier_idx)` function')

    def forward(self, x):
        if self.outlier_dim is None:
            tracer = OutlierTracer.get_instance()
            if not tracer.is_initialized():
                print('Please use OutlierTracer.initialize(model) before using the OutlierAwareLinear layer')
            outlier_idx = tracer.get_outliers(self.weight)
            #print(outlier_idx, tracer.get_hvalue(self.weight))
            self.outlier_dim = outlier_idx

        if not self.is_quantized:
            w = self.quantize_weight(self.weight, self.outlier_dim)
            self.weight.data.copy_(w)
            self.is_quantized = True

        return self.forward_with_outliers(x, self.outlier_dim)
class Fake4bitLinear(OutlierAwareLinear):
    def __init__(self, input_features, output_features, bias=True, codebook=bnb.functional.create_fp8_map(True, 3, 0, total_bits=4)):
        super().__init__(input_features, output_features, bias)
        self.codebook = codebook

    def quantize_weight(self, w, outlier_idx):
        if outlier_idx.numel() > 0:
            subw = w[:, outlier_idx].clone()
            w[:, outlier_idx] = 0
        wdtype = w.dtype
        code = self.codebook.to(w.device)
        cw, state = bnb.functional.quantize_blockwise(w, code=code, blocksize=64)
        w = bnb.functional.dequantize_blockwise(cw, state, blocksize=64)
        w = w.to(wdtype)
        if outlier_idx.numel() > 0:
            w[:, outlier_idx] = subw
        self.is_quantized = True
        return w

    def forward_with_outliers(self, x, outlier_idx):
        dims = torch.abs(x > 4).sum(dim=list(range(len(x.shape)-1)))
        outlier_idx2 = torch.where(dims > 0)[0]
        outlier_idx = torch.cat([outlier_idx, outlier_idx2]).unique()
        n = x.shape[-1]
        idx = torch.arange(n, device=x.device)
        idx[outlier_idx] = -1
        inverse_idx = torch.where(idx >= 0)[0]
        if outlier_idx.numel() > 0:
            subx = x[..., outlier_idx].clone()
            #print(1, subx, 1)
            #x[..., outlier_idx] = 0
        inverse_x = x[..., inverse_idx]
        xdtype = x.dtype
        #code = bnb.functional.create_fp8_map(True, 4-3, 2, 4).to(x.device)
        #code = bnb.functional.create_quantile_map(x, 4).to(x.device)
        code = bnb.functional.create_dynamic_map(True, total_bits=4.0).to(x.device)
        c, state = bnb.functional.quantize_blockwise(inverse_x, code=code, blocksize=64)
        inverse_x = bnb.functional.dequantize_blockwise(c, state, blocksize=64)
        #c, state = bnb.functional.quantize_blockwise(x, code=code, blocksize=64)
        #x = bnb.functional.dequantize_blockwise(c, state, blocksize=64)
        x = x.to(xdtype)
        x[..., inverse_idx] = inverse_x.to(x.dtype)
        #if outlier_idx.numel() > 0:
            #x[..., outlier_idx] = subx
        return torch.nn.functional.linear(x, self.weight, self.bias)
class Int8Params(torch.nn.Parameter):
    def __new__(
        ...
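Fake4bitLinear fake-quantizes its weight through a blockwise 4-bit codebook while leaving tracked outlier columns untouched. A minimal sketch of exercising that path in isolation (not part of the diff, untested; assumes a CUDA device since the blockwise quantization kernels run on GPU, and reuses the same blocksize=64 and FP8 codebook as the code above):

import torch
from bitsandbytes.nn import Fake4bitLinear

layer = Fake4bitLinear(64, 64).cuda().half()
w_ref = layer.weight.data.clone()

# An empty outlier index means every column of the weight gets fake-quantized.
no_outliers = torch.tensor([], dtype=torch.long, device=w_ref.device)
w_q = layer.quantize_weight(layer.weight.data, no_outliers)

# Mean absolute error introduced by the round-trip through the 4-bit code.
print((w_ref - w_q.to(w_ref.dtype)).abs().mean())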
bitsandbytes/utils.py

import shlex
import subprocess
import torch
from typing import Tuple

def outlier_hook(module, input):
    assert isinstance(module, torch.nn.Linear)
    tracer = OutlierTracer.get_instance()
    hvalue = tracer.get_hvalue(module.weight)
    if hvalue not in tracer.hvalue2outlier_idx:
        outlier_idx = find_outlier_dims(module.weight)
        tracer.outliers.append(outlier_idx)
        tracer.hvalues.append(hvalue)
        if len(tracer.outliers) > 1:
            # assign the current layer the outlier idx found from the weight
            # of the previous linear layer
            if tracer.outliers[-1].numel() > 0:
                assert tracer.outliers[-1].max() < module.weight.shape[1]
            tracer.hvalue2outlier_idx[hvalue] = tracer.outliers[-1]
        else:
            # first layer, we cannot use the weight for outlier detection
            # we follow a mixed approach:
            # (1) zscore test of std of hidden dimension
            # (2) magnitude > 6 test
            merged = input[0].view(-1, input[0].shape[-1])
            # (1) zscore test of std of hidden dimension
            outlier_idx = find_outlier_dims(merged, reduction_dim=1, zscore=3)
            # (2) magnitude > 6 test
            dims = (torch.abs(input[0]) > 6).sum(dim=list(range(len(input[0].shape)-1)))
            outlier_idx2 = torch.where(dims > 0)[0]
            outlier_idx = torch.cat([outlier_idx, outlier_idx2]).unique()
            tracer.hvalue2outlier_idx[hvalue] = outlier_idx
    else:
        for hook in tracer.hooks:
            hook.remove()
class OutlierTracer(object):
    _instance = None

    def __init__(self):
        raise RuntimeError("Call get_instance() instead")

    def initialize(self, model):
        self.last_w = None
        self.current_outlier_dims = None
        self.hvalues = []
        self.outliers = []
        self.hvalue2outlier_idx = {}
        self.initialized = True
        self.hooks = []

        for n, m in model.named_modules():
            if isinstance(m, torch.nn.Linear):
                self.hooks.append(m.register_forward_pre_hook(outlier_hook))

    def is_initialized(self):
        return getattr(self, 'initialized', False)

    def get_hvalue(self, weight):
        return weight.data.storage().data_ptr()

    def get_outliers(self, weight):
        if not self.is_initialized():
            print('Outlier tracer is not initialized...')
            return None
        hvalue = self.get_hvalue(weight)
        if hvalue in self.hvalue2outlier_idx:
            return self.hvalue2outlier_idx[hvalue]
        else:
            return None

    @classmethod
    def get_instance(cls):
        if cls._instance is None:
            cls._instance = cls.__new__(cls)
        return cls._instance
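OutlierTracer is a conventional singleton: the constructor refuses direct instantiation and get_instance() lazily creates one shared object via __new__. Note that initialize takes self, so going through get_instance() is what populates the shared tracer that OutlierAwareLinear.forward later queries (the warning string there abbreviates this as OutlierTracer.initialize(model)). A tiny illustrative sketch, not part of the diff, where model stands for any torch.nn.Module you already have:

tracer = OutlierTracer.get_instance()        # always returns the same object
assert OutlierTracer.get_instance() is tracer
tracer.initialize(model)                     # registers a forward pre-hook on every nn.Linear
print(tracer.is_initialized())               # True once initialize() has run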
def find_outlier_dims(weight, reduction_dim=0, zscore=4.0, topk=None, rdm=False):
    if rdm:
        return torch.randint(0, weight.shape[1], size=(topk,), device=weight.device).long()

    m = weight.mean(reduction_dim)
    mm = m.mean()
    mstd = m.std()
    zm = (m-mm)/mstd

    std = weight.std(reduction_dim)
    stdm = std.mean()
    stdstd = std.std()

    zstd = (std-stdm)/stdstd

    if topk is not None:
        val, idx = torch.topk(std.abs(), k=topk, dim=0)
    else:
        idx = torch.where(zstd > zscore)[0]

    return idx
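find_outlier_dims flags dimensions whose per-dimension standard deviation is a z-score outlier, or simply the top-k largest when topk is given (rdm=True returns a random subset, presumably for ablations). A quick illustrative check, not part of the commit: plant one high-variance column and confirm its index is returned.

import torch

w = torch.randn(1024, 1024)
w[:, 42] *= 20                                    # inflate the std of one input dimension
idx = find_outlier_dims(w, reduction_dim=0, zscore=4.0)
print(idx)                                        # expected to contain 42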
def replace_linear(model, linear_replacement, skip_modules=["lm_head"], copy_weights=False, post_processing_function=None):
    """
    Replace linear modules with a new Linear module.

    Parameters:
        model (`torch.nn.Module`):
            Input model or `torch.nn.Module` as the function is run recursively.
        linear_replacement (`torch.nn.Module`):
            The linear module that replaces the old one. Only expects standard arguments.
            If other arguments need to be passed, use a lambda.
        skip_modules (`List[str]`, *optional*, defaults to `lm_head`):
            List of module names not to convert. Defaults to `lm_head`.
        copy_weights (`bool`):
            Copy the weights from the old linear module to the new one.
        post_processing_function (`str`):
            A function name of the replacement linear class that is called
            after processing.
    """
    for name, module in model.named_children():
        if len(list(module.children())) > 0:
            replace_linear(module, linear_replacement, skip_modules, copy_weights, post_processing_function)

        if isinstance(module, torch.nn.Linear) and name not in skip_modules:
            old_module = model._modules[name]
            model._modules[name] = linear_replacement(
                module.in_features,
                module.out_features,
                module.bias is not None,
            )
            if copy_weights:
                model._modules[name].weight = old_module.weight
                model._modules[name].bias = old_module.bias

            if post_processing_function is not None:
                func = getattr(module, post_processing_function, None)
                if func is not None: func(module)
    return model
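Putting the pieces together, the intended flow appears to be: trace outliers on the original model with a calibration batch, then swap its nn.Linear modules for an OutlierAwareLinear subclass such as Fake4bitLinear. A rough end-to-end sketch based only on the code in this commit (untested; assumes a CUDA device, and relies on copy_weights=True so the tracer can still look outliers up by the weight's storage pointer):

import torch
import torch.nn as nn
from bitsandbytes.nn import Fake4bitLinear
from bitsandbytes.utils import OutlierTracer, replace_linear

model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 128)).cuda().half()

# 1) Register forward pre-hooks on every nn.Linear, then run one calibration batch
#    so outlier_hook can record outlier dimensions for each weight.
OutlierTracer.get_instance().initialize(model)
with torch.no_grad():
    model(torch.randn(4, 128, device='cuda', dtype=torch.half))

# 2) Swap the linear layers; keeping the original weight tensors preserves the
#    storage pointers that OutlierTracer uses as lookup keys.
model = replace_linear(model, Fake4bitLinear, copy_weights=True)
print(model)  # both nn.Linear modules are now Fake4bitLinear

# Subsequent forward passes then go through Fake4bitLinear.forward, which
# fake-quantizes each weight once and routes outlier dimensions separately.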
def execute_and_return(command_string: str) -> Tuple[str, str]:
    def _decode(subprocess_err_out_tuple):
        ...
csrc/kernels.cu

@@ -543,7 +543,9 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * abs
    // load code through read-only cache via __ldg
    #pragma unroll NUM_PER_TH
    for(int j = 0; j < NUM_PER_TH; j++)
    {
      vals[j] = __ldg(&code[qvals[j]])*local_abs_max;
    }

    __syncthreads();
    StoreT(storet).Store(&(out[i]), vals, valid_items);
tests/test_functional.py

@@ -2109,6 +2109,7 @@ def test_few_bit_quant():
            ebits = math.ceil(bits/2)
            pbits = bits-ebits-1
            code = F.create_fp8_map(True, ebits, pbits, bits).cuda()
            print(code)
        elif method == 'dynamic':
            code = F.create_dynamic_map(True, bits-0, bits).cuda()
        elif method == 'quantile':

@@ -2181,7 +2182,9 @@ def test_kbit_quantile_estimation():
def test_bench_dequantization():
    a = torch.rand(1024, 1024, device='cuda').half()
    qa, SA = F.quantize_blockwise(a)
    code = F.create_fp8_map(True, 3, 0, 4).cuda()
    qa, SA = F.quantize_blockwise(a, code=code)
    print(qa.max())
    max_theoretical_mu = 1024*1024*2/1024**3/672*1000*1000
    #print(max_theoretical_mu)

@@ -2193,3 +2196,4 @@ def test_bench_dequantization():
    torch.cuda.synchronize()
    #print((time.time()-t0)/1e6)
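The max_theoretical_mu expression in test_bench_dequantization reads as a back-of-the-envelope lower bound on dequantization time in microseconds: a 1024x1024 half-precision tensor is 2 MiB, and dividing by what is presumably the benchmark GPU's memory bandwidth (672 GiB/s is an assumption; the diff does not say) gives roughly 2.9 µs. A quick check of that arithmetic:

bytes_moved = 1024 * 1024 * 2      # 1024x1024 fp16 tensor, 2 bytes per element
gib = bytes_moved / 1024**3        # ~0.00195 GiB
seconds = gib / 672                # assuming ~672 GiB/s of memory bandwidth
print(seconds * 1000 * 1000)       # ~2.9 microseconds, matching max_theoretical_mu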