OpenDAS / bitsandbytes · Commit d15822a5

Refactor _tile_indices into a cached property, fix device bug

Authored Feb 25, 2023 by Max Ryabinin
Parent: cc608c04
Showing 2 changed files with 11 additions and 22 deletions (+11 -22):

- bitsandbytes/autograd/_functions.py (+10 -8)
- bitsandbytes/nn/modules.py (+1 -14)
bitsandbytes/autograd/_functions.py

```diff
@@ -223,7 +223,7 @@ matmul_cublas = MatMul8bit.apply
 @dataclass
 class MatmulLtState:
-    tile_indices: Optional[torch.Tensor] = None
+    _tile_indices: Optional[torch.Tensor] = None
     force_no_igemmlt: bool = False
     CB = None
     CxB = None
@@ -263,6 +263,15 @@ class MatmulLtState:
         ), f"please find this assert and manually enter tile size for {self.formatB}"
         return (8, 32) if self.formatB == "col_turing" else (32, 32)

+    @property
+    def tile_indices(self):
+        if self._tile_indices is None:
+            device = self.CxB.device
+            transform = lambda x: F.transform(x.to(device), from_order="row", to_order=self.formatB)[0].to(x.device)
+            with torch.no_grad():
+                self._tile_indices = get_inverse_transform_indices(transform, self.get_tile_size()).to(device)
+        return self._tile_indices
+

 class MatMul8bitLt(torch.autograd.Function):
     # forward is the same, but we added the fallback for pre-turing GPUs
```
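The new property is a lazy cache: the expensive `get_inverse_transform_indices` call runs only on first access, and the result is memoized in the private `_tile_indices` backing field. For reference, a minimal self-contained sketch of the same pattern; `expensive_compute`, `State`, and `indices` are hypothetical illustrative names, not the library's API:

```python
from dataclasses import dataclass
from typing import Optional


def expensive_compute() -> list:
    # stand-in for get_inverse_transform_indices(); assume it is costly
    print("computing tile indices...")
    return [8, 32]


@dataclass
class State:
    _indices: Optional[list] = None  # private backing field, starts empty

    @property
    def indices(self) -> list:
        # compute on first access, then serve the memoized value
        if self._indices is None:
            self._indices = expensive_compute()
        return self._indices


s = State()
s.indices  # prints once: the value is computed and cached
s.indices  # silent: served from the cache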
```diff
@@ -455,13 +464,6 @@ class MatMul8bitLt(torch.autograd.Function):
                 CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
                 grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
             elif state.CxB is not None:
-                if state.tile_indices is None:
-                    order, tile_size = state.formatB, state.get_tile_size()
-                    transform = lambda x: F.transform(x.cuda(), from_order="row", to_order=order)[0].to(x.device)
-                    with torch.no_grad():
-                        state.tile_indices = get_inverse_transform_indices(transform, tile_size).to(state.CxB.device)
-
                 CB = (
                     undo_layout(state.CxB, state.tile_indices)
                     .to(ctx.dtype_A)
```
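The deleted call site also carried the device bug named in the commit message: `x.cuda()` always moves a tensor to the current default CUDA device, whereas the new property uses `self.CxB.device`, so the indices are built on whichever GPU actually holds the weights. A small sketch of the difference; it assumes a machine with at least two CUDA devices:

```python
import torch

assert torch.cuda.device_count() >= 2  # sketch assumes a multi-GPU box

weights = torch.zeros(4, 4, device="cuda:1")  # weights live on GPU 1
x = torch.zeros(4, 4)                         # CPU tensor to be transformed

# old behaviour: .cuda() targets the default device (cuda:0),
# which can mismatch the weights' device
print(x.cuda().device)               # cuda:0

# new behaviour: move to the device the weights actually occupy
print(x.to(weights.device).device)   # cuda:1
```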
bitsandbytes/nn/modules.py

```diff
@@ -236,20 +236,7 @@ class Linear8bitLt(nn.Linear):
         try:
             if reorder_layout:
-                if self.state.tile_indices is None:
-                    order, tile_size = self.state.formatB, self.state.get_tile_size()
-                    transform = lambda x: \
-                        bitsandbytes.functional.transform(x.to(self.weight.data.device), from_order="row", to_order=order)[0].to(x.device)
-                    with torch.no_grad():
-                        self.state.tile_indices = get_inverse_transform_indices(transform, tile_size).to(self.state.CxB.device)
-
-                CB = (
-                    undo_layout(self.state.CxB, self.state.tile_indices)
-                )
-
-                self.weight.data = CB
+                self.weight.data = undo_layout(self.state.CxB, self.state.tile_indices)

             super()._save_to_state_dict(destination, prefix, keep_vars)
```
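With `tile_indices` exposed as a cached property on `MatmulLtState`, this call site no longer needs its own lazy-initialization block: the first read of `self.state.tile_indices` computes the inverse transform indices on the device of `state.CxB` and memoizes them in `_tile_indices`, and every later read, including this one in `_save_to_state_dict`, reuses the cached tensor. That removes the duplicated initialization logic that previously lived in both files.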