OpenDAS / bitsandbytes · Commits

Commit 06029dd6 (unverified)
Authored Mar 13, 2024 by Titus; committed by GitHub on Mar 13, 2024

Merge pull request #1081 from akx/ruff-format

Reformat Python code with Ruff

Parents: fd723b78, 5a4263f4

Showing 1 changed file with 11 additions and 8 deletions (+11 −8):

tests/test_triton.py
```diff
--- a/tests/test_triton.py
+++ b/tests/test_triton.py
@@ -7,15 +7,18 @@ from bitsandbytes.triton.triton_utils import is_triton_available
 from tests.helpers import TRUE_FALSE
 
 
-@pytest.mark.skipif(not is_triton_available() or not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] >= 8,
-                    reason="This test requires triton and a GPU with compute capability 8.0 or higher.")
+@pytest.mark.skipif(
+    not is_triton_available() or not torch.cuda.is_available() or not torch.cuda.get_device_capability()[0] >= 8,
+    reason="This test requires triton and a GPU with compute capability 8.0 or higher.",
+)
 @pytest.mark.parametrize("vector_wise_quantization", TRUE_FALSE)
 def test_switchback(vector_wise_quantization):
     for dim in [83]:
         for batch in [13]:
             standard = torch.nn.Linear(dim, 4 * dim).cuda().half()
-            switchback = SwitchBackLinear(dim, 4 * dim,
-                                          vector_wise_quantization=vector_wise_quantization).cuda().half()
+            switchback = (
+                SwitchBackLinear(dim, 4 * dim, vector_wise_quantization=vector_wise_quantization).cuda().half()
+            )
             baseline = Linear8bitLt(dim, 4 * dim).cuda().half()
             switchback.weight.data.copy_(standard.weight)
             switchback.bias.data.copy_(standard.bias)
```
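A side note on the skipif condition above (not part of the diff): Python's comparison operators bind tighter than `not`, so `not torch.cuda.get_device_capability()[0] >= 8` means "major compute capability below 8", which is the intended skip condition. A minimal sketch, with plain booleans and an int standing in for the torch/triton calls (the helper name `should_skip` is hypothetical):

```python
def should_skip(triton_available: bool, cuda_available: bool, cap_major: int) -> bool:
    # Mirrors the structure of the skipif condition: `not cap_major >= 8`
    # parses as `not (cap_major >= 8)`, i.e. cap_major < 8.
    return not triton_available or not cuda_available or not cap_major >= 8

print(should_skip(True, True, 8))   # False: capability 8.x runs the test
print(should_skip(True, True, 7))   # True: pre-8.0 GPUs are skipped
print(should_skip(False, True, 9))  # True: no triton, so skipped
```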
```diff
@@ -38,23 +41,23 @@ def test_switchback(vector_wise_quantization):
             err_sb = (out_standard - out_sb).abs().mean()
             err_baseline = (out_standard - out_baseline).abs().mean()
-            print('OUT', err_sb, err_baseline)
+            print("OUT", err_sb, err_baseline)
             assert err_sb < 2 * err_baseline
 
             err_sb = (standard.bias.grad - switchback.bias.grad).abs().mean()
             err_baseline = (standard.bias.grad - baseline.bias.grad).abs().mean()
-            print('GW2', err_sb, err_baseline)
+            print("GW2", err_sb, err_baseline)
             assert err_sb < 2 * err_baseline
 
             err_sb = (standard.weight.grad - switchback.weight.grad).abs().mean()
             err_baseline = (standard.weight.grad - baseline.weight.grad).abs().mean()
-            print('GW1', err_sb, err_baseline)
+            print("GW1", err_sb, err_baseline)
             assert err_sb < 2 * err_baseline
 
             err_sb = (x1.grad - x2.grad).abs().mean()
             err_baseline = (x1.grad - x3.grad).abs().mean()
-            print('GX1', err_sb, err_baseline)
+            print("GX1", err_sb, err_baseline)
             assert err_sb < 2 * err_baseline
```
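The test's acceptance criterion in this hunk (unchanged by the reformat) is that the mean absolute error of SwitchBackLinear against the fp16 reference stays within 2× the error of the Linear8bitLt baseline. A minimal sketch of that pattern using plain Python lists; the sample values and the `mean_abs_err` helper are hypothetical stand-ins, not from the test:

```python
def mean_abs_err(a, b):
    # Same role as (a - b).abs().mean() on torch tensors.
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)

reference = [1.0, 2.0, 3.0, 4.0]   # stands in for the fp16 Linear output
candidate = [1.1, 2.0, 2.9, 4.1]   # stands in for the SwitchBackLinear output
baseline = [1.2, 1.8, 3.2, 4.2]    # stands in for the Linear8bitLt output

err_sb = mean_abs_err(reference, candidate)
err_baseline = mean_abs_err(reference, baseline)
print("OUT", err_sb, err_baseline)
assert err_sb < 2 * err_baseline   # same acceptance criterion as the test
```

A relative bound like this is a deliberate choice: int8-style quantization always introduces some error, so the test checks SwitchBack is not much worse than an established 8-bit baseline rather than demanding a fixed absolute tolerance.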