OpenDAS / bitsandbytes

Commit 1753aa04, authored Aug 23, 2022 by dbaranchuk

refactoring

parent 8ae9bb23
Showing 1 changed file with 23 additions and 17 deletions

bitsandbytes/autograd/_functions.py (+23, -17)
...
@@ -245,10 +245,11 @@ class MatMul8bitLt(torch.autograd.Function):
                 subA = A[:, idx]
                 state.subB = B[:, idx].t().contiguous()
                 state.idx = idx
-            elif state.CxB is None:
-                # B is in 8-bit row-major; we can transform it back to 16-bit to extract outlier dimensions
-                # we also need to convert it to the turing/ampere format
-                state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
+            else:
+                if state.CxB is None:
+                    # B is in 8-bit row-major; we can transform it back to 16-bit to extract outlier dimensions
+                    # we also need to convert it to the turing/ampere format
+                    state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
         else:
             if not state.has_fp16_weights and state.CxB is None:
                 state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
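Note: this first hunk is a pure control-flow refactor. The elif state.CxB is None: branch becomes an if state.CxB is None: nested under else:, which preserves behavior while leaving room for more logic in the else arm. A minimal standalone sketch of the equivalence; the function and argument names here are hypothetical, not bitsandbytes API:

def dispatch_old(has_outliers: bool, needs_transform: bool) -> str:
    # control flow before the commit: if / elif
    if has_outliers:
        return "extract outlier columns"
    elif needs_transform:
        return "transform CB to turing/ampere layout"
    return "reuse cached CxB"

def dispatch_new(has_outliers: bool, needs_transform: bool) -> str:
    # control flow after the commit: elif unrolled into else + nested if
    if has_outliers:
        return "extract outlier columns"
    else:
        if needs_transform:
            return "transform CB to turing/ampere layout"
    return "reuse cached CxB"

# the two versions agree on every input
assert all(
    dispatch_old(a, b) == dispatch_new(a, b)
    for a in (False, True)
    for b in (False, True)
)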
...
@@ -355,19 +356,24 @@ class MatMul8bitLt(torch.autograd.Function):
         if req_gradA:
             C32grad, Sgrad = F.transform(Cgrad, "col32")
-            if state.CxBt is None and state.has_fp16_weights:
-                CBt = state.CBt
-            elif state.CxBt is None:
-                assert state.CBt is None
-                CB = state.CB.half()
-                SCB = state.SCB.unsqueeze(1).half()
-                SCBt = state.SCBt.unsqueeze(1).half()
-                Bt = (CB * SCB).t().contiguous()
-                CBt = (Bt / SCBt).t().to(torch.int8)
-
-            CxBt, SBt = F.transform(
-                CBt, to_order=formatB, transpose=True
-            )
+            if state.CxBt is None:
+                if state.has_fp16_weights:
+                    CBt = state.CBt
+                else:
+                    # Restore CBt from CB
+                    assert state.CBt is None, "CBt should not be stored in state"
+                    CB = state.CB.half()
+                    SCB = state.SCB.unsqueeze(1).half()
+                    SCBt = state.SCBt.unsqueeze(1).half()
+                    Bt = (CB * SCB).t().contiguous()
+                    CBt = (Bt / SCBt).t().to(torch.int8)
+
+                # intentionally, do not store CxBt into state
+                CxBt, SBt = F.transform(
+                    CBt, to_order=formatB, transpose=True
+                )
+            else:
+                CxBt = state.CxBt
             gradA32, SgradA32 = F.igemmlt(C32grad, CxBt, Sgrad, SBt)
             grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)
...
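Note: this second hunk carries the substantive change. When the weights are kept only as the row-wise int8 matrix CB (no fp16 copy), the backward pass now restores the column-wise quantized CBt on the fly from CB and the two absmax scale vectors SCB and SCBt, and intentionally does not cache the transformed CxBt in state, trading a little recomputation for lower memory use. Below is a plain-PyTorch sketch of that restoration arithmetic, mirroring the lines above; the function name and tensor shapes are illustrative, and the F.transform layout conversion is omitted since it requires bitsandbytes GPU kernels:

import torch

def restore_CBt(CB, SCB, SCBt):
    # CB: int8 codes from row-wise absmax quantization, scales SCB (one per row)
    # SCBt: absmax scales of the transposed matrix (one per column of CB)
    CB = CB.half()                          # int8 codes -> fp16
    SCB = SCB.unsqueeze(1).half()           # (rows, 1)
    SCBt = SCBt.unsqueeze(1).half()         # (cols, 1)
    Bt = (CB * SCB).t().contiguous()        # scale rows, then transpose
    return (Bt / SCBt).t().to(torch.int8)   # rescale along the other axis, back to int8

# illustrative shapes: a 64x128 weight with matching scale vectors
CB = torch.randint(-127, 128, (64, 128), dtype=torch.int8)
SCB = torch.rand(64) + 0.5
SCBt = torch.rand(128) + 0.5
assert restore_CBt(CB, SCB, SCBt).shape == (64, 128)

The per-127 normalization factors cancel between the dequantize and requantize steps, which is why they can be dropped, as in the committed code.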