Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
4fb37d45
Commit
4fb37d45
authored
Jun 09, 2023
by
Max Ryabinin
Browse files
Extract get_tile_inds to a separate function
parent
ac5550a0
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
14 additions
and
11 deletions
+14
-11
bitsandbytes/autograd/_functions.py
bitsandbytes/autograd/_functions.py
+14
-11
No files found.
bitsandbytes/autograd/_functions.py
View file @
4fb37d45
...
@@ -232,6 +232,19 @@ def supports_igemmlt(device: torch.device) -> bool:
...
@@ -232,6 +232,19 @@ def supports_igemmlt(device: torch.device) -> bool:
return
True
return
True
def
_get_tile_size
(
format
):
assert
format
in
(
"col_turing"
,
"col_ampere"
,
),
f
"please find this assert and manually enter tile size for
{
format
}
"
return
(
8
,
32
)
if
format
==
"col_turing"
else
(
32
,
32
)
def
get_tile_inds
(
format
,
device
):
transform
=
lambda
x
:
F
.
transform
(
x
.
to
(
device
),
from_order
=
"row"
,
to_order
=
format
)[
0
].
to
(
x
.
device
)
with
torch
.
no_grad
():
return
get_inverse_transform_indices
(
transform
,
_get_tile_size
(
format
)).
to
(
device
)
@
dataclass
@
dataclass
class
MatmulLtState
:
class
MatmulLtState
:
_tile_indices
:
Optional
[
torch
.
Tensor
]
=
None
_tile_indices
:
Optional
[
torch
.
Tensor
]
=
None
...
@@ -267,20 +280,10 @@ class MatmulLtState:
...
@@ -267,20 +280,10 @@ class MatmulLtState:
self
.
SBt
=
None
self
.
SBt
=
None
self
.
CBt
=
None
self
.
CBt
=
None
def
get_tile_size
(
self
):
assert
self
.
formatB
in
(
"col_turing"
,
"col_ampere"
,
),
f
"please find this assert and manually enter tile size for
{
self
.
formatB
}
"
return
(
8
,
32
)
if
self
.
formatB
==
"col_turing"
else
(
32
,
32
)
@
property
@
property
def
tile_indices
(
self
):
def
tile_indices
(
self
):
if
self
.
_tile_indices
is
None
:
if
self
.
_tile_indices
is
None
:
device
=
self
.
CxB
.
device
self
.
_tile_indices
=
get_tile_inds
(
self
.
formatB
,
self
.
CxB
.
device
)
transform
=
lambda
x
:
F
.
transform
(
x
.
to
(
device
),
from_order
=
"row"
,
to_order
=
self
.
formatB
)[
0
].
to
(
x
.
device
)
with
torch
.
no_grad
():
self
.
_tile_indices
=
get_inverse_transform_indices
(
transform
,
self
.
get_tile_size
()).
to
(
device
)
return
self
.
_tile_indices
return
self
.
_tile_indices
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment