OpenDAS / bitsandbytes, commit 048a2d40 (unverified)
Authored Mar 06, 2024 by Aarni Koskela; committed by GitHub on Mar 06, 2024
Parent: a1c0844b

Deduplicate helpers & fix lint issues from #1099 (#1107)
Showing 3 changed files, with 28 additions and 34 deletions:

  tests/helpers.py            +20   -7
  tests/test_linear4bit.py     +1  -13
  tests/test_linear8bitlt.py   +7  -14
tests/helpers.py (view file @ 048a2d40)

@@ -1,3 +1,4 @@
+from io import BytesIO
 from itertools import product
 import random
 from typing import Any, List
...
@@ -7,6 +8,25 @@ import torch
 test_dims_rng = random.Random(42)
 
+TRUE_FALSE = (True, False)
+BOOLEAN_TRIPLES = list(product(TRUE_FALSE, repeat=3))  # all combinations of (bool, bool, bool)
+BOOLEAN_TUPLES = list(product(TRUE_FALSE, repeat=2))  # all combinations of (bool, bool)
+
+
+def torch_save_to_buffer(obj):
+    buffer = BytesIO()
+    torch.save(obj, buffer)
+    buffer.seek(0)
+    return buffer
+
+
+def torch_load_from_buffer(buffer):
+    buffer.seek(0)
+    obj = torch.load(buffer)
+    buffer.seek(0)
+    return obj
+
+
 
 def get_test_dims(min: int, max: int, *, n: int) -> List[int]:
     return [test_dims_rng.randint(min, max) for _ in range(n)]
 
...
@@ -42,10 +62,3 @@ DTYPE_NAMES = {
 
 def describe_dtype(dtype: torch.dtype) -> str:
     return DTYPE_NAMES.get(dtype) or str(dtype).rpartition(".")[2]
-
-
-TRUE_FALSE = (True, False)
-BOOLEAN_TRIPLES = list(
-    product(TRUE_FALSE, repeat=3)
-)  # all combinations of (bool, bool, bool)
-BOOLEAN_TUPLES = list(product(TRUE_FALSE, repeat=2))  # all combinations of (bool, bool)
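For orientation (not part of the commit): the two buffer helpers round-trip any torch-serializable object through an in-memory BytesIO, rewinding the cursor on both sides so the same buffer stays reusable. A minimal sketch, assuming the repository root is on sys.path so tests.helpers resolves:

    # Illustrative round-trip through the shared helpers; not from the commit.
    import torch

    from tests.helpers import torch_load_from_buffer, torch_save_to_buffer

    original = torch.arange(6, dtype=torch.float32).reshape(2, 3)
    buffer = torch_save_to_buffer(original)    # saved, cursor rewound to 0
    restored = torch_load_from_buffer(buffer)  # loads, then rewinds for reuse
    assert torch.equal(original, restored)
    # the trailing seek(0) means the buffer can be loaded again immediately
    assert torch.equal(torch_load_from_buffer(buffer), restored)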
tests/test_linear4bit.py (view file @ 048a2d40)

@@ -1,5 +1,4 @@
 import copy
-from io import BytesIO
 import os
 import pickle
 from tempfile import TemporaryDirectory
...
@@ -8,7 +7,7 @@ import pytest
 import torch
 
 import bitsandbytes as bnb
-from tests.helpers import TRUE_FALSE
+from tests.helpers import TRUE_FALSE, torch_load_from_buffer, torch_save_to_buffer
 
 storage = {
     "uint8": torch.uint8,
...
@@ -17,17 +16,6 @@ storage = {
     "float32": torch.float32,
 }
 
-def torch_save_to_buffer(obj):
-    buffer = BytesIO()
-    torch.save(obj, buffer)
-    buffer.seek(0)
-    return buffer
-
-def torch_load_from_buffer(buffer):
-    buffer.seek(0)
-    obj = torch.load(buffer)
-    buffer.seek(0)
-    return obj
 
 @pytest.mark.parametrize("quant_storage", ["uint8", "float16", "bfloat16", "float32"])
 @pytest.mark.parametrize("bias", TRUE_FALSE)
...
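The pattern the 4-bit tests follow after this change: parametrize over TRUE_FALSE from tests.helpers and push module state through the shared buffer helpers. A hedged sketch with a plain nn.Linear stand-in so it runs without a GPU (the real test quantizes a bnb.nn.Linear4bit on CUDA; the test name and body here are hypothetical):

    # Hypothetical sketch of the round-trip pattern, not the real test.
    import pytest
    import torch

    from tests.helpers import TRUE_FALSE, torch_load_from_buffer, torch_save_to_buffer


    @pytest.mark.parametrize("bias", TRUE_FALSE)
    def test_state_dict_roundtrip_sketch(bias):
        layer = torch.nn.Linear(8, 4, bias=bias)  # stand-in for bnb.nn.Linear4bit
        buffer = torch_save_to_buffer(layer.state_dict())
        restored = torch_load_from_buffer(buffer)
        for key, value in layer.state_dict().items():
            assert torch.equal(value, restored[key])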
tests/test_linear8bitlt.py (view file @ 048a2d40)

@@ -1,4 +1,3 @@
 from contextlib import nullcontext
-from io import BytesIO
 import os
 from tempfile import TemporaryDirectory
...
@@ -10,7 +9,12 @@ import bitsandbytes as bnb
 from bitsandbytes import functional as F
 from bitsandbytes.autograd import get_inverse_transform_indices, undo_layout
 from bitsandbytes.nn.modules import Linear8bitLt
-from tests.helpers import TRUE_FALSE, id_formatter
+from tests.helpers import (
+    TRUE_FALSE,
+    id_formatter,
+    torch_load_from_buffer,
+    torch_save_to_buffer,
+)
 
 # contributed by Alex Borzunov, see:
 # https://github.com/bigscience-workshop/petals/blob/main/tests/test_linear8bitlt.py
...
@@ -66,17 +70,6 @@ def test_linear_no_igemmlt():
     assert linear_custom.state.CB is not None
     assert linear_custom.state.CxB is None
 
-
-def torch_save_to_buffer(obj):
-    buffer = BytesIO()
-    torch.save(obj, buffer)
-    buffer.seek(0)
-    return buffer
-
-def torch_load_from_buffer(buffer):
-    buffer.seek(0)
-    obj = torch.load(buffer)
-    buffer.seek(0)
-    return obj
+
 
 @pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights"))
 @pytest.mark.parametrize("serialize_before_forward", TRUE_FALSE, ids=id_formatter("serialize_before_forward"))
...
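One more hedged sketch to show what the decorators above buy: TRUE_FALSE drives a 2x2 boolean grid, one case per element of product(TRUE_FALSE, repeat=2), and id_formatter from tests/helpers.py labels each case (assumed here to render ids like "has_fp16_weights=True"; check the helper for the exact format). The test name and body are hypothetical; the real test serializes a Linear8bitLt before or after its first forward pass:

    # Illustrative parametrization grid; not the real test body.
    import pytest

    from tests.helpers import TRUE_FALSE, id_formatter


    @pytest.mark.parametrize("has_fp16_weights", TRUE_FALSE, ids=id_formatter("has_fp16_weights"))
    @pytest.mark.parametrize("serialize_before_forward", TRUE_FALSE, ids=id_formatter("serialize_before_forward"))
    def test_boolean_grid_sketch(has_fp16_weights, serialize_before_forward):
        # 2 x 2 = 4 collected cases, each labeled by the id_formatter ids
        assert isinstance(has_fp16_weights, bool)
        assert isinstance(serialize_before_forward, bool)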