bitsandbytes · Commit 9e7cdc9e
Authored Apr 12, 2023 by Tim Dettmers
Parent: 008dfff9

Added last SwitchBack refactors. All tests green.
Showing 5 changed files with 26 additions and 19 deletions (+26 −19):
CHANGELOG.md                               +7  −0
bitsandbytes/nn/__init__.py                +1  −1
bitsandbytes/nn/triton_based_modules.py    +9  −9
setup.py                                   +1  −1
tests/test_triton.py                       +8  −8
CHANGELOG.md @ 9e7cdc9e

@@ -221,3 +221,10 @@ Improvements:
 Deprecated:
 - Devices with compute capability 3.0 (GTX 700s, K10) and 3.2 (Tegra K1, Jetson TK1) are now deprecated and support will be removed in 0.39.0.
 - Support for CUDA 10.0 and 10.2 will be removed in bitsandbytes 0.39.0
+
+### 0.38.1
+
+Features:
+ - Added Int8 SwitchBack layers
+ - Added Fake FP8 layers for research purposes (available under `bnb.research.nn. ...`)
+
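For orientation, the Int8 SwitchBack layers announced above act as drop-in replacements for torch.nn.Linear. A minimal usage sketch, not from the commit itself: it assumes a CUDA GPU with compute capability 8.0+ and Triton installed, and the layer sizes and input shape are made up for illustration.

import torch
from bitsandbytes.nn import SwitchBackLinear

# Illustrative sizes; the default mode uses global weight quantization.
layer = SwitchBackLinear(512, 2048).cuda().half()
x = torch.randn(4, 512, dtype=torch.half, device="cuda")
y = layer(x)    # int8 matmul; inputs are quantized row-wise on the fly
print(y.shape)  # torch.Size([4, 2048])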
bitsandbytes/nn/__init__.py @ 9e7cdc9e

@@ -3,4 +3,4 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 from .modules import Int8Params, Linear8bitLt, StableEmbedding, OutlierAwareLinear, SwitchBackLinearBnb
-from .triton_based_modules import SwitchBackLinear, SwitchBackLinearGlobal, SwitchBackLinearVectorized, StandardLinear
+from .triton_based_modules import SwitchBackLinear, SwitchBackLinearGlobal, SwitchBackLinearVectorwise, StandardLinear
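The user-visible effect of this one-line hunk is the renamed export. A sketch of the corresponding import change in downstream code (hypothetical caller):

# Before this commit:
#   from bitsandbytes.nn import SwitchBackLinearVectorized
# After this commit:
from bitsandbytes.nn import SwitchBackLinearVectorwise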
bitsandbytes/nn/triton_based_modules.py @ 9e7cdc9e

@@ -157,7 +157,7 @@ class SwitchBackLinear(nn.Linear):
         bias: bool = True,
         device=None,
         dtype=None,
-        vectorize: bool = False,
+        vector_wise_quantization: bool = False,
         mem_efficient: bool = False,
     ):
         super().__init__(in_features, out_features, bias, device, dtype)
@@ -167,11 +167,11 @@ class SwitchBackLinear(nn.Linear):
                 Alternatively, you can use bnb.nn.SwitchBackLinearBnb, but it will be slower''')

         # By default, we use the global quantization.
-        self.vectorize = vectorize
-        if self.vectorize:
+        self.vector_wise_quantization = vector_wise_quantization
+        if self.vector_wise_quantization:
             self._fn = _switchback_vectorrize
             if mem_efficient:
-                print('mem efficient is not supported for vectorize mode.')
+                print('mem efficient is not supported for vector-wise quantization.')
                 exit(1)
         else:
             if mem_efficient:
@@ -188,7 +188,7 @@ class SwitchBackLinear(nn.Linear):
         #         m.prepare_for_eval()
         # model.apply(cond_prepare)
         print('=> preparing for eval.')
-        if self.vectorize:
+        if self.vector_wise_quantization:
             W_int8, state_W = quantize_rowwise(self.weight)
         else:
             W_int8, state_W = quantize_global(self.weight)
@@ -210,7 +210,7 @@ class SwitchBackLinear(nn.Linear):
         X = x.view(-1, x.size(-1))
         X_int8, state_X = quantize_rowwise(X)
-        if self.vectorize:
+        if self.vector_wise_quantization:
             return int8_matmul_rowwise_dequantize(
                 X_int8, self.W_int8.t(), state_X, self.state_W, self.bias
             ).view(*x.size()[:-1], -1)
@@ -219,9 +219,9 @@ class SwitchBackLinear(nn.Linear):
                 X_int8, self.W_int8.t(), state_X, self.state_W, self.bias
             ).view(*x.size()[:-1], -1)

-SwitchBackLinearGlobal = partial(SwitchBackLinear, vectorize=False)
-SwitchBackLinearGlobalMemEfficient = partial(SwitchBackLinear, vectorize=False, mem_efficient=True)
-SwitchBackLinearVectorized = partial(SwitchBackLinear, vectorize=True)
+SwitchBackLinearGlobal = partial(SwitchBackLinear, vector_wise_quantization=False)
+SwitchBackLinearGlobalMemEfficient = partial(SwitchBackLinear, vector_wise_quantization=False, mem_efficient=True)
+SwitchBackLinearVectorwise = partial(SwitchBackLinear, vector_wise_quantization=True)

 # This is just the standard linear function.
 class StandardLinearFunction(torch.autograd.Function):
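To summarize the rename's effect here: the three partial-based aliases now bake in the vector_wise_quantization keyword, with SwitchBackLinearVectorwise selecting per-row scales via quantize_rowwise and the Global variants a single scale via quantize_global. A short sketch of picking a mode, under the same GPU/Triton assumptions as above (sizes illustrative):

from bitsandbytes.nn import SwitchBackLinearGlobal, SwitchBackLinearVectorwise

# One scale for the entire weight matrix:
fc_global = SwitchBackLinearGlobal(512, 2048).cuda().half()

# Per-row (vector-wise) scales; note the diff above rejects
# mem_efficient=True in this mode and exits.
fc_vector = SwitchBackLinearVectorwise(512, 2048).cuda().half()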
setup.py @ 9e7cdc9e

@@ -18,7 +18,7 @@ def read(fname):
 setup(
     name=f"bitsandbytes",
-    version=f"0.38.0.post2",
+    version=f"0.38.1",
     author="Tim Dettmers",
     author_email="dettmers@cs.washington.edu",
     description="8-bit optimizers and matrix multiplication routines.",
tests/test_triton.py @ 9e7cdc9e

 import pytest
 import torch

+from bitsandbytes.triton.triton_utils import is_triton_available
 from bitsandbytes.nn.triton_based_modules import SwitchBackLinear
 from bitsandbytes.nn import Linear8bitLt

-@pytest.mark.skipif(not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] >= 8,
-                    reason="This test requires a GPU with compute capability 8.0 or higher.")
-@pytest.mark.parametrize("vectorrize", [False, True])
-def test_switchback(vectorrize):
-    for dim in [83, 17, 128]:
-        for batch in [13, 128, 256]:
+@pytest.mark.skipif(not is_triton_available() or not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] >= 8,
+                    reason="This test requires triton and a GPU with compute capability 8.0 or higher.")
+@pytest.mark.parametrize("vector_wise_quantization", [False, True])
+def test_switchback(vector_wise_quantization):
+    for dim in [83]:
+        for batch in [13]:

             standard = torch.nn.Linear(dim, 4 * dim).cuda().half()
-            print('vectorrize', vectorrize)
-            switchback = SwitchBackLinear(dim, 4 * dim, vectorize=vectorrize).cuda().half()
+            switchback = SwitchBackLinear(dim, 4 * dim, vector_wise_quantization=vector_wise_quantization).cuda().half()
             baseline = Linear8bitLt(dim, 4 * dim).cuda().half()

             switchback.weight.data.copy_(standard.weight)
             switchback.bias.data.copy_(standard.bias)
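With Triton installed and a compute-capability-8.0+ GPU available, the updated test can be run on its own to verify the rename end to end:

pytest tests/test_triton.py -k test_switchback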