OpenDAS / bitsandbytes
"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "face20bdbd29b78976ed4d01524b2de4b7d77a2f"
Commit d6e25b5f authored Sep 18, 2022 by justheuristic

change typecast behavior

Parent: e2b523d0
Showing 1 changed file with 10 additions and 15 deletions.

bitsandbytes/autograd/_functions.py (+10, -15)
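The hunks below touch MatMul8bitLt.forward and MatMul8bitLt.backward, which operate on int8-quantized tensors (CAt, Cgrad, state.CB, ...) together with their per-row scaling statistics (SCAt, SCgrad, state.SCB, ...). For orientation, here is a hedged, plain-PyTorch sketch of row-wise absmax int8 quantization with the same 127-based scale convention that appears in the last hunk. It only illustrates the general scheme, not the actual F.double_quant implementation (which additionally produces column-wise statistics and an outlier COO tensor), and every name in it is made up:

```python
import torch

def quantize_rowwise(x: torch.Tensor):
    """Illustrative row-wise absmax int8 quantization (not F.double_quant itself)."""
    absmax = x.abs().amax(dim=1, keepdim=True)        # per-row scale statistic
    scale = absmax.clamp(min=1e-8) / 127.0
    q = torch.clamp(torch.round(x / scale), -127, 127).to(torch.int8)
    return q, absmax.squeeze(1)                        # int8 codes + per-row absmax

def dequantize_rowwise(q: torch.Tensor, absmax: torch.Tensor, dtype=torch.float32):
    scale = (absmax.unsqueeze(1) / 127.0).to(dtype)    # same /127 convention as SCB below
    return q.to(dtype) * scale

x = torch.randn(4, 16)
q, absmax = quantize_rowwise(x)
x_hat = dequantize_rowwise(q, absmax)
print((x - x_hat).abs().max())                         # small quantization error
```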
@@ -321,7 +321,6 @@ class MatMul8bitLt(torch.autograd.Function):
 
         # 4. Mixed-precision decomposition matmul
         if coo_tensorA is not None and subA is not None:
-            assert subA.dtype == state.subB.dtype == output.dtype, (subA.dtype, state.subB.dtype, output.dtype)
             output.addmm_(subA, state.subB)
 
         # 5. Save state
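For reference, Tensor.addmm_(mat1, mat2) accumulates a matrix product into the tensor in place (out = out + mat1 @ mat2 with the default alpha and beta of 1), which is how the high-precision outlier product subA @ state.subB is folded into output above; the explicit dtype assert on subA, state.subB and output is simply dropped here. A minimal, self-contained demo with stand-in tensors, not the bitsandbytes state:

```python
import torch

out = torch.zeros(4, 8)            # stand-in for `output`
subA = torch.randn(4, 3)           # stand-in for the outlier columns of A
subB = torch.randn(3, 8)           # stand-in for the matching weight slice

out.addmm_(subA, subB)             # in place: out = out + subA @ subB
assert torch.allclose(out, subA @ subB)
```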
@@ -330,6 +329,7 @@ class MatMul8bitLt(torch.autograd.Function):
         ctx.formatB = formatB
         ctx.grad_shape = input_shape
         ctx.req_grads = [requires_gradA, requires_gradB, requires_gradBias]
+        ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype
 
         if requires_gradA or requires_gradB:
             ctx.tensors = (CAt, subA)
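The line added in this hunk records the original dtypes of A, B and the optional bias on ctx; the backward changes further down use them to cast each gradient back to its input's dtype. A toy sketch of that pattern with a plain torch.autograd.Function, assuming an ordinary matmul in a fixed compute dtype rather than the int8 pipeline used here (all names illustrative):

```python
import torch

COMPUTE_DTYPE = torch.float32   # the real kernels work in fp16/int8; float32 keeps this sketch CPU-runnable

class CastingMatMul(torch.autograd.Function):
    """Toy sketch, not MatMul8bitLt: compute internally in COMPUTE_DTYPE,
    but hand every gradient back in the dtype its input originally had."""

    @staticmethod
    def forward(ctx, A, B, bias):
        # same bookkeeping as the added line above
        ctx.dtype_A, ctx.dtype_B, ctx.dtype_bias = A.dtype, B.dtype, None if bias is None else bias.dtype
        ctx.save_for_backward(A, B)
        out = (A.to(COMPUTE_DTYPE) @ B.to(COMPUTE_DTYPE)).to(A.dtype)
        return out if bias is None else out + bias

    @staticmethod
    def backward(ctx, grad_output):
        A, B = ctx.saved_tensors
        g = grad_output.to(COMPUTE_DTYPE)
        grad_A = (g @ B.to(COMPUTE_DTYPE).t()).to(ctx.dtype_A)
        grad_B = (A.to(COMPUTE_DTYPE).t() @ g).to(ctx.dtype_B)
        grad_bias = None if ctx.dtype_bias is None else grad_output.sum(0).to(ctx.dtype_bias)
        return grad_A, grad_B, grad_bias

A = torch.randn(4, 8, dtype=torch.bfloat16, requires_grad=True)
B = torch.randn(8, 5, requires_grad=True)                  # float32
bias = torch.randn(5, requires_grad=True)                  # float32
CastingMatMul.apply(A, B, bias).sum().backward()
print(A.grad.dtype, B.grad.dtype, bias.grad.dtype)         # torch.bfloat16 torch.float32 torch.float32
```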
@@ -348,7 +348,7 @@ class MatMul8bitLt(torch.autograd.Function):
         if ctx.is_empty:
             bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias))
             return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
-        req_gradA, req_gradB, req_gradBias = ctx.req_grads
+        req_gradA, req_gradB, _, req_gradBias, _ = ctx.needs_input_grad
         CAt, subA = ctx.tensors
         SCAt, idx = ctx.tensor_states
         formatB = ctx.formatB
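ctx.needs_input_grad is filled in by autograd with one boolean per argument of forward(), in order, so the hand-maintained ctx.req_grads list is no longer needed; the two _ placeholders presumably line up with forward() arguments whose gradients are never requested. A small sketch of how the tuple maps onto forward's signature (toy Function, illustrative names):

```python
import torch

class Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, factor, bias):
        ctx.factor = factor
        return x * factor + bias

    @staticmethod
    def backward(ctx, grad_output):
        # One boolean per forward() argument, in order: (x, factor, bias).
        # Non-tensor arguments such as `factor` always come back False.
        req_x, _, req_bias = ctx.needs_input_grad
        grad_x = grad_output * ctx.factor if req_x else None
        grad_bias = grad_output if req_bias else None
        return grad_x, None, grad_bias

x = torch.randn(3, requires_grad=True)
bias = torch.randn(3)                          # requires_grad defaults to False
Scale.apply(x, 2.0, bias).sum().backward()     # needs_input_grad == (True, False, False)
assert x.grad is not None and bias.grad is None
```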
@@ -357,25 +357,22 @@ class MatMul8bitLt(torch.autograd.Function):
 
         if req_gradBias:
             # compute grad_bias first before changing grad_output dtype
-            grad_bias = grad_output.sum(0)
+            grad_bias = grad_output.sum(0).to(ctx.bias_dtype)
 
         # Cast grad_output to fp16
-        grad_output_dtype = grad_output.dtype
-        grad_output = grad_output.to(torch.float16)
-
         if len(grad_output.shape) == 3:
             grad_output = grad_output.reshape(
                 -1, grad_output.shape[-1]
             ).contiguous()
 
-        Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output)
+        Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output.to(torch.float16))
         if req_gradB:
             CxAt, SAt = F.transform(CAt, formatB, transpose=True)
             C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
             gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt)
-            grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
+            grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt).to(ctx.B_dtype)
             if state.threshold > 0.0 and subA is not None:
-                grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
+                grad_B[:, idx].addmm_(grad_output.t(), subA)
 
         if req_gradA:
             if state.CBt is not None:
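In this hunk the bias gradient is taken before grad_output is handed to the int8 path and is now cast to the stored bias dtype; for y = x @ W.t() + b, the bias gradient is just grad_output summed over the batch dimension, which a quick autograd check confirms (toy shapes, illustrative only):

```python
import torch

x = torch.randn(8, 16)
W = torch.randn(32, 16)
b = torch.zeros(32, requires_grad=True)

y = x @ W.t() + b
grad_output = torch.randn_like(y)
y.backward(grad_output)

# d(loss)/db is grad_output summed over the batch dimension ...
manual = grad_output.sum(0)
assert torch.allclose(b.grad, manual)

# ... and, mirroring the change above, it can then be cast to the bias dtype
# (ctx.bias_dtype in the diff) without touching grad_output itself.
bias_dtype = torch.bfloat16            # illustrative; whatever dtype the bias had
grad_bias = manual.to(bias_dtype)
```

Along the same lines, F.double_quant now receives an fp16 copy (grad_output.to(torch.float16)) at the call site instead of grad_output being rebound to fp16 for the rest of backward, so grad_output keeps its incoming dtype.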
@@ -385,18 +382,16 @@ class MatMul8bitLt(torch.autograd.Function):
                     state.CBt, to_order=formatB, transpose=True
                 )
                 gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
-                grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)
+                grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.A_dtype)
             elif state.CB is not None:
-                CB = state.CB.half()
+                CB = state.CB.to(ctx.B_dtype)
                 SCB = (state.SCB.unsqueeze(1) / 127.0).half()
                 CB *= SCB
-                grad_A = torch.mm(grad_output, CB).view(ctx.grad_shape)
+                grad_A = torch.mm(grad_output, CB).view(ctx.grad_shape).to(ctx.A_dtype)
             else:
                 raise Exception('State must contain either CBt or CB matrix for backward')
 
-        # Cast grad_A back to grad_output_dtype
-        grad_output = grad_output.to(grad_output_dtype)
 
         return grad_A, grad_B, None, grad_bias, None
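The fallback branch of the last hunk dequantizes the int8 weight matrix state.CB with its per-row statistics (SCB / 127), computes grad_A as an ordinary matmul against grad_output, and now casts the result back to A's original dtype. A rough, CPU-runnable sketch of that arithmetic in plain PyTorch; the function and variable names are illustrative stand-ins, and float32 stands in for ctx.B_dtype:

```python
import torch

def grad_A_from_int8_weight(grad_output, CB_int8, SCB_absmax, grad_shape, A_dtype):
    """Sketch of the `state.CB is not None` branch: dequantize, matmul, cast back."""
    CB = CB_int8.to(torch.float32)                 # the diff casts to ctx.B_dtype instead
    SCB = SCB_absmax.unsqueeze(1) / 127.0          # per-row dequantization scale
    CB *= SCB                                      # dequantized weight, roughly W again
    return torch.mm(grad_output.to(CB.dtype), CB).view(grad_shape).to(A_dtype)

# toy shapes: grad_output is (batch, out_features), CB_int8 is (out_features, in_features)
grad_output = torch.randn(4, 32)
CB_int8 = torch.randint(-127, 128, (32, 16), dtype=torch.int8)
SCB_absmax = torch.rand(32) * 3 + 0.1
grad_A = grad_A_from_int8_weight(grad_output, CB_int8, SCB_absmax, (4, 16), torch.float32)
print(grad_A.shape, grad_A.dtype)                  # torch.Size([4, 16]) torch.float32
```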