OpenDAS / bitsandbytes · Commits

Commit d358999e
authored Sep 11, 2022 by dbaranchuk

refactoring

parent ee325f02
Showing 1 changed file with 2 additions and 9 deletions
bitsandbytes/autograd/_functions.py  +2 -9  (view file @ d358999e)
@@ -185,11 +185,10 @@ class MatmulLtState:
     idx = None
     is_training = True
     has_fp16_weights = True
-    memory_efficient_backward = False
     use_pool = False
     formatB = F.get_special_format_str()
+    memory_efficient_backward = False
 
     def reset_grads(self):
         self.CB = None
         self.CxB = None
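For orientation, MatmulLtState is the container that MatMul8bitLt threads through forward and backward, and this hunk only moves where memory_efficient_backward is declared. A minimal usage sketch, under assumptions not shown in this diff (that the class constructs with no arguments and that bitsandbytes.matmul accepts a state keyword):

import torch
import bitsandbytes as bnb
from bitsandbytes.autograd._functions import MatmulLtState

# Sketch only: the field names come from the hunk above; the state= keyword
# on bnb.matmul and the no-argument constructor are assumptions.
state = MatmulLtState()
state.has_fp16_weights = True
state.is_training = True
state.memory_efficient_backward = False  # flag this commit moves below formatB

A = torch.randn(4, 32, dtype=torch.float16, device="cuda")
B = torch.randn(64, 32, dtype=torch.float16, device="cuda", requires_grad=True)
out = bnb.matmul(A, B, state=state)  # 8-bit matmul driven by the state above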
@@ -198,6 +197,7 @@ class MatmulLtState:
         self.CxBt = None
         self.SBt = None
+        self.CBt = None
 
 
 class MatMul8bitLt(torch.autograd.Function):
@@ -232,10 +232,6 @@ class MatMul8bitLt(torch.autograd.Function):
         A_dtype = A.dtype
         A = A.to(torch.float16)
-        assert (
-            A.dtype == torch.float16
-        ), f"The input data type needs to be fp16 but {A.dtype} was found!"
-
         # 1. Quantize A
         if len(A.shape) == 3:
             A = A.view(-1, A.shape[-1]).contiguous()
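The lines kept above already force A into float16 and remember the original dtype in A_dtype, so the fp16-only assertion removed here had become redundant. A minimal sketch of that cast-and-restore pattern, assuming (not shown in this hunk) that the saved A_dtype is later used to cast the result back:

import torch

def cast_then_restore(A: torch.Tensor) -> torch.Tensor:
    # Mirrors the two kept lines: remember the caller's dtype, run the fp16
    # path, then cast back. The restore step is an assumption; the multiply
    # is just a stand-in for the quantized matmul.
    A_dtype = A.dtype
    A = A.to(torch.float16)
    out = A * 2.0
    return out.to(A_dtype)

x = torch.randn(4, 32)             # float32 input is no longer rejected by an assert
print(cast_then_restore(x).dtype)  # torch.float32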
@@ -398,9 +394,6 @@ class MatMul8bitLt(torch.autograd.Function):
         return grad_A, grad_B, None, grad_bias, None
 
 
-matmul = MatMul8bitLt.apply
-
-
 def matmul(
     A: tensor,
     B: tensor,
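With the bare alias matmul = MatMul8bitLt.apply removed, the def matmul(...) wrapper that follows is left as the functional entry point. Its body is outside this diff; a minimal sketch of what such a wrapper plausibly does, with the (A, B, out, bias, state) ordering inferred from the five gradients returned by backward above, and the default arguments and threshold field assumed rather than shown:

import torch
from bitsandbytes.autograd._functions import MatMul8bitLt, MatmulLtState

# Sketch only: argument names and defaults beyond A and B are assumptions.
def matmul_sketch(A: torch.Tensor, B: torch.Tensor, out=None,
                  state=None, threshold=0.0, bias=None):
    state = state or MatmulLtState()   # reuse caller-provided state or start fresh
    if threshold > 0.0:
        state.threshold = threshold    # assumed outlier-threshold field
    # Positional order matches the five backward outputs: A, B, out, bias, state.
    return MatMul8bitLt.apply(A, B, out, bias, state)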