Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
c93a90d0
"git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "23e756af21fc22b486d1a658400412ca1e776295"
Commit
c93a90d0
authored
Feb 14, 2023
by
Tim Dettmers
Browse files
Fixed FP4 import and data type conversion in backward.
parent
7f0773ae
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
2 additions
and
6 deletions
+2
-6
bitsandbytes/autograd/_functions.py
bitsandbytes/autograd/_functions.py
+1
-5
bitsandbytes/nn/__init__.py
bitsandbytes/nn/__init__.py
+1
-1
No files found.
bitsandbytes/autograd/_functions.py
View file @
c93a90d0
...
@@ -525,13 +525,9 @@ class MatMulFP4(torch.autograd.Function):
...
@@ -525,13 +525,9 @@ class MatMulFP4(torch.autograd.Function):
# compute grad_bias first before changing grad_output dtype
# compute grad_bias first before changing grad_output dtype
grad_bias
=
grad_output
.
sum
(
0
,
dtype
=
ctx
.
dtype_bias
)
grad_bias
=
grad_output
.
sum
(
0
,
dtype
=
ctx
.
dtype_bias
)
# Cast grad_output to fp16
if
len
(
grad_output
.
shape
)
==
3
:
grad_output
=
grad_output
.
reshape
(
-
1
,
grad_output
.
shape
[
-
1
]).
contiguous
()
# not supported by PyTorch. TODO: create work-around
# not supported by PyTorch. TODO: create work-around
#if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
#if req_gradB: grad_B = torch.matmul(grad_output.t(), A)
if
req_gradA
:
grad_A
=
torch
.
matmul
(
grad_output
,
F
.
dequantize_fp4
(
B
,
ctx
.
state
).
to
(
ctx
.
dtype
_A
).
t
())
if
req_gradA
:
grad_A
=
torch
.
matmul
(
grad_output
,
F
.
dequantize_fp4
(
B
,
ctx
.
state
).
to
(
grad_output
.
dtype
).
t
())
return
grad_A
,
grad_B
,
None
,
grad_bias
,
None
return
grad_A
,
grad_B
,
None
,
grad_bias
,
None
...
...
bitsandbytes/nn/__init__.py
View file @
c93a90d0
...
@@ -2,4 +2,4 @@
...
@@ -2,4 +2,4 @@
#
#
# This source code is licensed under the MIT license found in the
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# LICENSE file in the root directory of this source tree.
from
.modules
import
Int8Params
,
Linear8bitLt
,
StableEmbedding
,
LinearFP4
from
.modules
import
Int8Params
,
Linear8bitLt
,
StableEmbedding
,
LinearFP4
,
FP4Params
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment