Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
1da48802
Commit
1da48802
authored
Sep 18, 2022
by
justheuristic
Browse files
change typecast behavior
parent
1145589f
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
3 additions
and
4 deletions
+3
-4
bitsandbytes/autograd/_functions.py
bitsandbytes/autograd/_functions.py
+3
-4
No files found.
bitsandbytes/autograd/_functions.py
View file @
1da48802
...
...
@@ -356,7 +356,7 @@ class MatMul8bitLt(torch.autograd.Function):
         if req_gradBias:
             # compute grad_bias first before changing grad_output dtype
-            grad_bias = grad_output.sum(0).to(ctx.dtype_bias)
+            grad_bias = grad_output.sum(0, dtype=ctx.dtype_bias)
         # Cast grad_output to fp16
         if len(grad_output.shape) == 3:
...
...
@@ -385,9 +385,8 @@ class MatMul8bitLt(torch.autograd.Function):
         elif state.CB is not None:
             CB = state.CB.to(ctx.B_dtype)
-            SCB = (state.SCB.unsqueeze(1) / 127.0).half()
-            CB *= SCB
-            grad_A = torch.mm(grad_output, CB).view(ctx.grad_shape).to(ctx.A_dtype)
+            CB.mul_(state.SCB.unsqueeze(1).div_(127.0).to(ctx.B_dtype))
+            grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.A_dtype)
         else:
             raise Exception('State must contain either CBt or CB matrix for backward')
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment