Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
dcecbb26
Commit
dcecbb26
authored
Mar 22, 2023
by
Max Ryabinin
Browse files
Add force_no_igemmlt to test params
parent
24609b66
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
3 deletions
+8
-3
tests/test_linear8bitlt.py
tests/test_linear8bitlt.py
+8
-3
No files found.
tests/test_linear8bitlt.py
View file @
dcecbb26
...
@@ -69,9 +69,9 @@ def test_linear_no_igemmlt():
...
@@ -69,9 +69,9 @@ def test_linear_no_igemmlt():
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"this test requires a GPU"
)
@
pytest
.
mark
.
skipif
(
not
torch
.
cuda
.
is_available
(),
reason
=
"this test requires a GPU"
)
@
pytest
.
mark
.
parametrize
(
"has_fp16_weights, serialize_before_forward, deserialize_before_cuda"
,
@
pytest
.
mark
.
parametrize
(
"has_fp16_weights, serialize_before_forward, deserialize_before_cuda
, force_no_igemmlt
"
,
list
(
product
([
False
,
True
],
[
False
,
True
],
[
False
,
True
])))
list
(
product
([
False
,
True
],
[
False
,
True
],
[
False
,
True
],
[
False
,
True
])))
def
test_linear_serialization
(
has_fp16_weights
,
serialize_before_forward
,
deserialize_before_cuda
):
def
test_linear_serialization
(
has_fp16_weights
,
serialize_before_forward
,
deserialize_before_cuda
,
force_no_igemmlt
):
linear
=
torch
.
nn
.
Linear
(
32
,
96
)
linear
=
torch
.
nn
.
Linear
(
32
,
96
)
x
=
torch
.
randn
(
3
,
32
,
dtype
=
torch
.
half
)
x
=
torch
.
randn
(
3
,
32
,
dtype
=
torch
.
half
)
...
@@ -82,6 +82,9 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri
...
@@ -82,6 +82,9 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri
has_fp16_weights
=
has_fp16_weights
,
has_fp16_weights
=
has_fp16_weights
,
threshold
=
6.0
,
threshold
=
6.0
,
)
)
if
force_no_igemmlt
:
linear_custom
.
state
.
force_no_igemmlt
=
True
linear_custom
.
weight
=
bnb
.
nn
.
Int8Params
(
linear_custom
.
weight
=
bnb
.
nn
.
Int8Params
(
linear
.
weight
.
data
.
clone
(),
requires_grad
=
has_fp16_weights
,
has_fp16_weights
=
has_fp16_weights
linear
.
weight
.
data
.
clone
(),
requires_grad
=
has_fp16_weights
,
has_fp16_weights
=
has_fp16_weights
)
)
...
@@ -118,6 +121,8 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri
...
@@ -118,6 +121,8 @@ def test_linear_serialization(has_fp16_weights, serialize_before_forward, deseri
has_fp16_weights
=
has_fp16_weights
,
has_fp16_weights
=
has_fp16_weights
,
threshold
=
6.0
,
threshold
=
6.0
,
)
)
if
force_no_igemmlt
:
new_linear_custom
.
state
.
force_no_igemmlt
=
True
if
deserialize_before_cuda
:
if
deserialize_before_cuda
:
with
nullcontext
()
if
has_fp16_weights
else
pytest
.
raises
(
RuntimeError
):
with
nullcontext
()
if
has_fp16_weights
else
pytest
.
raises
(
RuntimeError
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment