Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
56a074f6
Commit
56a074f6
authored
Sep 17, 2022
by
justheuristic
Browse files
un-fuse bias
parent
d9ca0ed9
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
7 additions
and
6 deletions
+7
-6
bitsandbytes/autograd/_functions.py
bitsandbytes/autograd/_functions.py
+7
-6
No files found.
bitsandbytes/autograd/_functions.py
View file @
56a074f6
...
@@ -314,10 +314,13 @@ class MatMul8bitLt(torch.autograd.Function):
...
@@ -314,10 +314,13 @@ class MatMul8bitLt(torch.autograd.Function):
out32
,
Sout32
=
F
.
igemmlt
(
C32A
,
state
.
CxB
,
SA
,
state
.
SB
)
out32
,
Sout32
=
F
.
igemmlt
(
C32A
,
state
.
CxB
,
SA
,
state
.
SB
)
# we apply the fused bias here
# we apply the fused bias here
fused_bias
=
bias
if
bias
.
dtype
==
torch
.
float16
else
None
if
bias
is
None
or
bias
.
dtype
==
torch
.
float16
:
output
=
F
.
mm_dequant
(
out32
,
Sout32
,
SCA
,
state
.
SCB
,
bias
=
fused_bias
)
output
=
F
.
mm_dequant
(
out32
,
Sout32
,
SCA
,
state
.
SCB
,
bias
=
bias
)
if
fused_bias
is
None
and
bias
is
not
None
:
output
=
output
.
to
(
A_dtype
)
output
.
add_
(
bias
.
to
(
output
.
dtype
))
else
:
# apply bias separately
output
=
F
.
mm_dequant
(
out32
,
Sout32
,
SCA
,
state
.
SCB
,
bias
=
None
)
output
=
output
.
to
(
A_dtype
).
add_
(
bias
)
# 4. Mixed-precision decomposition matmul
# 4. Mixed-precision decomposition matmul
if
coo_tensorA
is
not
None
and
subA
is
not
None
:
if
coo_tensorA
is
not
None
and
subA
is
not
None
:
...
@@ -338,8 +341,6 @@ class MatMul8bitLt(torch.autograd.Function):
...
@@ -338,8 +341,6 @@ class MatMul8bitLt(torch.autograd.Function):
ctx
.
tensor_states
=
(
None
,
None
)
ctx
.
tensor_states
=
(
None
,
None
)
ctx
.
save_for_backward
(
None
,
None
)
ctx
.
save_for_backward
(
None
,
None
)
# Cast fp16 output back to A.dtype
output
=
output
.
to
(
A_dtype
)
clone_func
=
torch
.
clone
if
len
(
output_shape
)
==
3
else
lambda
x
:
x
clone_func
=
torch
.
clone
if
len
(
output_shape
)
==
3
else
lambda
x
:
x
return
clone_func
(
output
.
view
(
output_shape
))
return
clone_func
(
output
.
view
(
output_shape
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment