OpenDAS / bitsandbytes
Commit 2dfa3ce1, authored Feb 13, 2023 by Tim Dettmers

Fixed LinearFP8 and added tests.

Parent: fa255cbc
Showing 2 changed files with 40 additions and 3 deletions:

    bitsandbytes/nn/modules.py    +3  -3
    tests/test_modules.py         +37 -0
bitsandbytes/nn/modules.py

@@ -352,10 +352,10 @@ class LinearFP8(nn.Linear):
     def forward(self, x: torch.Tensor):
         if self.fw_code is None:
-            self.bw_code = F.create_fp8_map(True, 5, 2, 8).to(x.device)
-            self.fw_code = F.create_fp8_map(True, 4, 3, 8).to(x.device)
+            self.bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device)
+            self.fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device)
 
-        out = bnb.matmul_fp8(x, self.weight.t(), fw_code=self.fw_code, code=self.bw_code)
+        out = bnb.matmul_fp8(x, self.weight.t(), fw_code=self.fw_code, bw_code=self.bw_code)
         if self.bias is not None:
             out += self.bias
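For context: after this change, LinearFP8.forward lazily builds two FP8 quantization maps on first use and routes them into bnb.matmul_fp8 through the fw_code / bw_code keywords. Reading the create_fp8_map arguments as (signed, exponent bits, mantissa bits, total bits), the forward pass gets an E4M3-style map and the backward pass an E5M2-style map. Below is a minimal standalone sketch of that call pattern, using only the functions that appear in this diff; it assumes a CUDA device, that bitsandbytes is importable as bnb, and illustrative tensor shapes that are not from the commit.

    import torch
    import bitsandbytes as bnb

    x = torch.randn(10, 1024, device="cuda")         # (batch, in_features), illustrative
    weight = torch.randn(2048, 1024, device="cuda")  # (out_features, in_features), as in nn.Linear

    # Codebooks as built in LinearFP8.forward: E4M3-style for the forward matmul,
    # E5M2-style for the backward pass.
    fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to(x.device)
    bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to(x.device)

    # FP8 matmul with separate forward/backward codebooks
    # (bw_code is the keyword fixed by this commit).
    out = bnb.matmul_fp8(x, weight.t(), fw_code=fw_code, bw_code=bw_code)
    print(out.shape)  # expected torch.Size([10, 2048])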
tests/test_modules.py

@@ -525,3 +525,40 @@ def test_linear8bitlt_fp32_bias():
     b1 = torch.randn(16, 8, 32, device="cuda").half()
     o1 = l1(b1)
     assert l1.bias is None
+
+
+def test_fp8linear():
+    b = 10
+    h = 1024
+    inp = torch.randn(b, h).cuda()
+    fp32 = torch.nn.Linear(h, h*2).cuda()
+    fp8 = bnb.nn.LinearFP8(h, h*2).cuda()
+    fp32b = torch.nn.Linear(h*2, h).cuda()
+    fp8b = bnb.nn.LinearFP8(h*2, h).cuda()
+
+    fp8.weight.data.copy_(fp32.weight.data)
+    fp8.bias.data.copy_(fp32.bias.data)
+    fp8b.weight.data.copy_(fp32b.weight.data)
+    fp8b.bias.data.copy_(fp32b.bias.data)
+
+    a = fp32b(torch.nn.functional.gelu(fp32(inp)))
+    b = fp8b(torch.nn.functional.gelu(fp8(inp)))
+
+    err = (a - b).abs().mean()
+
+    a.mean().backward()
+    b.mean().backward()
+
+    graderr = (fp8.weight.grad - fp32.weight.grad).abs().mean()
+    bgraderr = (fp8.bias.grad - fp32.bias.grad).abs().mean()
+
+    assert err < 0.05
+    assert graderr < 0.00002
+    assert bgraderr < 0.00002
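The new test exercises the layer the way a user might: build an FP32 reference network and an FP8 twin with identical parameters, run both, and require the output and gradient gaps to stay small (output error below 0.05, weight and bias gradient error below 2e-5). A stripped-down sketch of that drop-in pattern, reusing only the API the test itself calls; it assumes a CUDA device and the sizes are illustrative.

    import torch
    import bitsandbytes as bnb

    # FP32 reference layer and an FP8 twin initialized from the same parameters.
    ref = torch.nn.Linear(1024, 2048).cuda()
    fp8 = bnb.nn.LinearFP8(1024, 2048).cuda()
    fp8.weight.data.copy_(ref.weight.data)
    fp8.bias.data.copy_(ref.bias.data)

    x = torch.randn(10, 1024, device="cuda")
    err = (ref(x) - fp8(x)).abs().mean()  # the commit's test expects this to stay below 0.05

On a CUDA machine, the new test alone can presumably be run with: pytest tests/test_modules.py::test_fp8linear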