Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
9e7cdc9e
"vscode:/vscode.git/clone" did not exist on "e8620a86bd2c3ebc0e891cd0e822c9b37104b7d0"
Commit
9e7cdc9e
authored
Apr 12, 2023
by
Tim Dettmers
Browse files
Added last SwitchBack refactors. All tests green.
parent
008dfff9
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
26 additions
and
19 deletions
+26
-19
CHANGELOG.md
CHANGELOG.md
+7
-0
bitsandbytes/nn/__init__.py
bitsandbytes/nn/__init__.py
+1
-1
bitsandbytes/nn/triton_based_modules.py
bitsandbytes/nn/triton_based_modules.py
+9
-9
setup.py
setup.py
+1
-1
tests/test_triton.py
tests/test_triton.py
+8
-8
No files found.
CHANGELOG.md
View file @
9e7cdc9e
...
...
@@ -221,3 +221,10 @@ Improvements:
Deprecated:
-
Devices with compute capability 3.0 (GTX 700s, K10) and 3.2 (Tegra K1, Jetson TK1) are now deprecated and support will be removed in 0.39.0.
-
Support for CUDA 10.0 and 10.2 will be removed in bitsandbytes 0.39.0
### 0.38.1
Features:
-
Added Int8 SwitchBack layers
-
Added Fake FP8 layers for research purposes (available under
`bnb.research.nn. ...`
)
bitsandbytes/nn/__init__.py
View file @
9e7cdc9e
...
...
@@ -3,4 +3,4 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from
.modules
import
Int8Params
,
Linear8bitLt
,
StableEmbedding
,
OutlierAwareLinear
,
SwitchBackLinearBnb
from
.triton_based_modules
import
SwitchBackLinear
,
SwitchBackLinearGlobal
,
SwitchBackLinearVector
ized
,
StandardLinear
from
.triton_based_modules
import
SwitchBackLinear
,
SwitchBackLinearGlobal
,
SwitchBackLinearVector
wise
,
StandardLinear
bitsandbytes/nn/triton_based_modules.py
View file @
9e7cdc9e
...
...
@@ -157,7 +157,7 @@ class SwitchBackLinear(nn.Linear):
bias
:
bool
=
True
,
device
=
None
,
dtype
=
None
,
vector
ize
:
bool
=
False
,
vector
_wise_quantization
:
bool
=
False
,
mem_efficient
:
bool
=
False
,
):
super
().
__init__
(
in_features
,
out_features
,
bias
,
device
,
dtype
)
...
...
@@ -167,11 +167,11 @@ class SwitchBackLinear(nn.Linear):
Alternatively, you can use bnb.nn.SwitchBackLinearBnb, but it will be slower'''
)
# By default, we use the global quantization.
self
.
vector
ize
=
vectorize
if
self
.
vector
ize
:
self
.
vector
_wise_quantization
=
vector_wise_quantization
if
self
.
vector
_wise_quantization
:
self
.
_fn
=
_switchback_vectorrize
if
mem_efficient
:
print
(
'mem efficient is not supported for vector
ize mode
.'
)
print
(
'mem efficient is not supported for vector
-wise quantization
.'
)
exit
(
1
)
else
:
if
mem_efficient
:
...
...
@@ -188,7 +188,7 @@ class SwitchBackLinear(nn.Linear):
# m.prepare_for_eval()
# model.apply(cond_prepare)
print
(
'=> preparing for eval.'
)
if
self
.
vector
ize
:
if
self
.
vector
_wise_quantization
:
W_int8
,
state_W
=
quantize_rowwise
(
self
.
weight
)
else
:
W_int8
,
state_W
=
quantize_global
(
self
.
weight
)
...
...
@@ -210,7 +210,7 @@ class SwitchBackLinear(nn.Linear):
X
=
x
.
view
(
-
1
,
x
.
size
(
-
1
))
X_int8
,
state_X
=
quantize_rowwise
(
X
)
if
self
.
vector
ize
:
if
self
.
vector
_wise_quantization
:
return
int8_matmul_rowwise_dequantize
(
X_int8
,
self
.
W_int8
.
t
(),
state_X
,
self
.
state_W
,
self
.
bias
).
view
(
*
x
.
size
()[:
-
1
],
-
1
)
...
...
@@ -219,9 +219,9 @@ class SwitchBackLinear(nn.Linear):
X_int8
,
self
.
W_int8
.
t
(),
state_X
,
self
.
state_W
,
self
.
bias
).
view
(
*
x
.
size
()[:
-
1
],
-
1
)
SwitchBackLinearGlobal
=
partial
(
SwitchBackLinear
,
vector
ize
=
False
)
SwitchBackLinearGlobalMemEfficient
=
partial
(
SwitchBackLinear
,
vector
ize
=
False
,
mem_efficient
=
True
)
SwitchBackLinearVector
ized
=
partial
(
SwitchBackLinear
,
vector
ize
=
True
)
SwitchBackLinearGlobal
=
partial
(
SwitchBackLinear
,
vector
_wise_quantization
=
False
)
SwitchBackLinearGlobalMemEfficient
=
partial
(
SwitchBackLinear
,
vector
_wise_quantization
=
False
,
mem_efficient
=
True
)
SwitchBackLinearVector
wise
=
partial
(
SwitchBackLinear
,
vector
_wise_quantization
=
True
)
# This is just the standard linear function.
class
StandardLinearFunction
(
torch
.
autograd
.
Function
):
...
...
setup.py
View file @
9e7cdc9e
...
...
@@ -18,7 +18,7 @@ def read(fname):
setup
(
name
=
f
"bitsandbytes"
,
version
=
f
"0.38.
0.post2
"
,
version
=
f
"0.38.
1
"
,
author
=
"Tim Dettmers"
,
author_email
=
"dettmers@cs.washington.edu"
,
description
=
"8-bit optimizers and matrix multiplication routines."
,
...
...
tests/test_triton.py
View file @
9e7cdc9e
import
pytest
import
torch
from
bitsandbytes.triton.triton_utils
import
is_triton_available
from
bitsandbytes.nn.triton_based_modules
import
SwitchBackLinear
from
bitsandbytes.nn
import
Linear8bitLt
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
()
or
not
torch
.
cuda
.
get_device_capability
()[
0
]
>=
8
,
reason
=
"This test requires a GPU with compute capability 8.0 or higher."
)
@
pytest
.
mark
.
parametrize
(
"vector
rize
"
,
[
False
,
True
])
def
test_switchback
(
vector
rize
):
for
dim
in
[
83
,
17
,
128
]:
for
batch
in
[
13
,
128
,
256
]:
@
pytest
.
mark
.
skipif
(
not
is_triton_available
()
or
not
torch
.
cuda
.
is_available
()
or
not
torch
.
cuda
.
get_device_capability
()[
0
]
>=
8
,
reason
=
"This test requires
triton and
a GPU with compute capability 8.0 or higher."
)
@
pytest
.
mark
.
parametrize
(
"vector
_wise_quantization
"
,
[
False
,
True
])
def
test_switchback
(
vector
_wise_quantization
):
for
dim
in
[
83
]:
for
batch
in
[
13
]:
standard
=
torch
.
nn
.
Linear
(
dim
,
4
*
dim
).
cuda
().
half
()
print
(
'vectorrize'
,
vectorrize
)
switchback
=
SwitchBackLinear
(
dim
,
4
*
dim
,
vectorize
=
vectorrize
).
cuda
().
half
()
switchback
=
SwitchBackLinear
(
dim
,
4
*
dim
,
vector_wise_quantization
=
vector_wise_quantization
).
cuda
().
half
()
baseline
=
Linear8bitLt
(
dim
,
4
*
dim
).
cuda
().
half
()
switchback
.
weight
.
data
.
copy_
(
standard
.
weight
)
switchback
.
bias
.
data
.
copy_
(
standard
.
bias
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment