OpenDAS / bitsandbytes · Commit e2b523d0

change typecast behavior

Authored Sep 18, 2022 by justheuristic
Parent: 85bf5294
Showing 1 changed file with 5 additions and 7 deletions
bitsandbytes/autograd/_functions.py  +5 -7

@@ -230,16 +230,14 @@ class MatMul8bitLt(torch.autograd.Function):
             state.outlier_pool = GlobalOutlierPooler.get_instance()
 
         # Cast A to fp16
-        A_dtype = A.dtype
-        if A_dtype != torch.float16:
-            warnings.warn(f"MatMul8bitLt: input matrix will be converted from {A_dtype} to float16")
-        A = A.to(torch.float16)
+        if A.dtype != torch.float16:
+            warnings.warn(f"MatMul8bitLt: input matrix will be cast from {A.dtype} to float16")
 
         # 1. Quantize A
         if len(A.shape) == 3:
             A = A.view(-1, A.shape[-1]).contiguous()
         CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(
-            A, threshold=state.threshold
+            A.to(torch.float16), threshold=state.threshold
         )
 
         if state.threshold > 0.0 and coo_tensorA is not None:
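The first hunk is the behavioral change: the old code overwrote A with its fp16 copy and had to remember the caller's dtype in a separate A_dtype variable, while the new code passes a temporary fp16 copy into F.double_quant and leaves A, and therefore A.dtype, untouched. A minimal standalone sketch of the new pattern, using plain PyTorch; quantize_preserving_dtype is a hypothetical name and the rounding step is only a toy stand-in for the real int8 quantization done by F.double_quant:

import warnings
import torch

def quantize_preserving_dtype(A):
    # Warn about the implicit cast, but never modify the caller's tensor.
    if A.dtype != torch.float16:
        warnings.warn(f"input matrix will be cast from {A.dtype} to float16")

    if len(A.shape) == 3:
        A = A.view(-1, A.shape[-1]).contiguous()

    # Only a temporary copy is cast to fp16; the rounding below is a toy
    # stand-in for the int8 double-quantization performed by F.double_quant.
    CA = A.to(torch.float16).round()

    # A.dtype still reports the original dtype, so results can be cast back.
    return CA, A.dtype

CA, orig_dtype = quantize_preserving_dtype(torch.randn(4, 8, dtype=torch.float32))
print(CA.dtype, orig_dtype)  # torch.float16 torch.float32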
@@ -316,10 +314,10 @@ class MatMul8bitLt(torch.autograd.Function):
 
         if bias is None or bias.dtype == torch.float16:
             output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=bias)
-            output = output.to(A_dtype)
+            output = output.to(A.dtype)
         else:  # apply bias separately
             output = F.mm_dequant(out32, Sout32, SCA, state.SCB, bias=None)
-            output = output.to(A_dtype).add_(bias)
+            output = output.to(A.dtype).add_(bias)
 
         # 4. Mixed-precision decomposition matmul
         if coo_tensorA is not None and subA is not None:
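The second hunk follows mechanically: with A_dtype gone, the cast back to the caller's dtype now reads A.dtype, which still holds the original value because A is never converted in place. A rough sketch of how that output path behaves; dequant_with_bias is a hypothetical name and the inner rescaling is only a trivial stand-in for F.mm_dequant:

import torch

def dequant_with_bias(out32, scale, bias, orig_dtype):
    # Toy stand-in for F.mm_dequant: rescale the int32 accumulator to fp16,
    # optionally fusing the bias add.
    def mm_dequant(x, b):
        y = (x.float() * scale).to(torch.float16)
        return y + b if b is not None else y

    if bias is None or bias.dtype == torch.float16:
        # fp16 (or no) bias: fuse it into dequantization, then cast back.
        output = mm_dequant(out32, bias)
        output = output.to(orig_dtype)
    else:  # apply bias separately
        # Non-fp16 bias: cast back to the caller's dtype first, then add.
        output = mm_dequant(out32, None)
        output = output.to(orig_dtype).add_(bias)
    return output

out = dequant_with_bias(
    torch.ones(2, 3, dtype=torch.int32), 0.01,
    bias=torch.zeros(3, dtype=torch.float32), orig_dtype=torch.float32,
)
print(out.dtype)  # torch.float32, matching the original input dtype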