OpenDAS / bitsandbytes

Commit 8d34d36f authored Aug 29, 2022 by dbaranchuk
req_gradA for casted & more efficient and accurate fp16 backward
parent b3fee1ed
Showing 1 changed file with 12 additions and 11 deletions
bitsandbytes/autograd/_functions.py  +12 -11
@@ -213,10 +213,6 @@ class MatMul8bitLt(torch.autograd.Function):
             else:
                 return torch.empty(A.shape[:-1] + B.shape[:1], dtype=torch.float16, device=A.device)
 
-        # Cast A to fp16
-        A_dtype = A.dtype
-        A = A.to(torch.float16)
-
         # 1. Quantize A
         # 2. Quantize B
         # 3. Matmul
@@ -229,6 +225,11 @@ class MatMul8bitLt(torch.autograd.Function):
         input_shape = A.shape
         if state.outlier_pool is None:
             state.outlier_pool = GlobalOutlierPooler.get_instance()
+
+        # Cast A to fp16
+        A_dtype = A.dtype
+        A = A.to(torch.float16)
+
         assert (
             A.dtype == torch.float16
         ), f"The input data type needs to be fp16 but {A.dtype} was found!"
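Note: a minimal standalone sketch of the cast these two hunks move into the forward pass, assuming only plain PyTorch; remember_and_cast_to_fp16 is a hypothetical helper for illustration, not a bitsandbytes API. The original dtype is recorded (A_dtype in the diff) so it can be restored later, and non-fp16 activations such as bf16 then satisfy the fp16 assertion.

import torch

def remember_and_cast_to_fp16(A: torch.Tensor):
    # Hypothetical helper mirroring the forward-pass change: record the
    # original dtype, then cast the activation to fp16 before the int8 path.
    A_dtype = A.dtype
    A = A.to(torch.float16)
    assert A.dtype == torch.float16, f"The input data type needs to be fp16 but {A.dtype} was found!"
    return A, A_dtype

# A bf16 activation passes the fp16 assertion after the cast.
x = torch.randn(4, 8, dtype=torch.bfloat16)
x_fp16, orig_dtype = remember_and_cast_to_fp16(x)
print(x_fp16.dtype, orig_dtype)  # torch.float16 torch.bfloat16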
@@ -337,14 +338,14 @@ class MatMul8bitLt(torch.autograd.Function):
             bias_grad = (None if ctx.bias is None else torch.zeros_like(ctx.bias))
             return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, bias_grad, None
 
-        # Cast grad_output to fp16
-        grad_output_dtype = grad_output.dtype
-        grad_output.to(torch.float16)
-
         req_gradA, req_gradB, req_gradBias = ctx.req_grads
         assert not req_gradB, "TODO: support weight updates as well"
         state = ctx.state
 
+        # Cast grad_output to fp16
+        grad_output_dtype = grad_output.dtype
+        grad_output = grad_output.to(torch.float16)
+
         if len(grad_output.shape) == 3:
             grad_output = grad_output.reshape(
                 -1, grad_output.shape[-1]
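Note: besides relocating the block, this hunk fixes a silent bug: torch.Tensor.to() is not in-place, so the old grad_output.to(torch.float16) discarded its result and grad_output kept its original dtype. A minimal sketch of the difference, using a toy tensor:

import torch

g = torch.randn(2, 3, dtype=torch.float32)

# Old pattern: the tensor returned by .to() is thrown away, so g stays fp32.
g.to(torch.float16)
print(g.dtype)  # torch.float32

# New pattern: rebind the name to the cast result, as the hunk above does.
g = g.to(torch.float16)
print(g.dtype)  # torch.float16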
@@ -354,9 +355,9 @@ class MatMul8bitLt(torch.autograd.Function):
 
         if req_gradA:
             CB = state.CB.half()
-            SCB = state.SCB.unsqueeze(1).half()
-            B = (CB * SCB) / 127.0
-            grad_A = torch.mm(grad_output, B).view(ctx.grad_shape)
+            SCB = (state.SCB.unsqueeze(1) / 127.0).half()
+            CB *= SCB
+            grad_A = torch.mm(grad_output, CB).view(ctx.grad_shape)
 
         if req_gradBias:
             grad_bias = grad_output.sum(0)
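Note: a standalone sketch (toy CPU tensors, not real bitsandbytes state) of why the reordered dequantization in this hunk is cheaper and more accurate: the per-row scale SCB is divided by 127 while it is still fp32 and then folded into CB in place, instead of materializing a separate fp16 matrix B and dividing the already-rounded product afterwards.

import torch

torch.manual_seed(0)

# Toy stand-ins for the saved int8 state (names follow the diff): CB holds the
# int8 weight values and SCB the per-row absmax scales, kept in fp32.
CB_int8 = torch.randint(-127, 128, (16, 8), dtype=torch.int8)
SCB_fp32 = torch.rand(16, dtype=torch.float32) * 4 + 0.5

# Old formulation: dequantize into a separate matrix B, dividing by 127 last,
# with the whole expression evaluated in fp16.
CB = CB_int8.half()
SCB = SCB_fp32.unsqueeze(1).half()
B = (CB * SCB) / 127.0

# New formulation: fold 1/127 into the scale while it is still fp32, then scale
# CB in place, so no extra dequantized matrix is materialized and the division
# happens at higher precision.
CB2 = CB_int8.half()
SCB2 = (SCB_fp32.unsqueeze(1) / 127.0).half()
CB2 *= SCB2

# The two dequantized weights agree up to fp16 rounding; the backward then only
# needs the single torch.mm(grad_output, CB) shown on the + side of the hunk.
print(torch.allclose(B.float(), CB2.float(), rtol=1e-2, atol=1e-2))  # True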