Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
bitsandbytes
Commits
389f66ca
Commit
389f66ca
authored
Jul 27, 2022
by
Tim Dettmers
Browse files
Fixed direct extraction masking.
parent
a4092136
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
10 deletions
+12
-10
bitsandbytes/autograd/_functions.py
bitsandbytes/autograd/_functions.py
+12
-10
No files found.
bitsandbytes/autograd/_functions.py
View file @
389f66ca
...
@@ -191,6 +191,7 @@ class MatMul8bitLt(torch.autograd.Function):
...
@@ -191,6 +191,7 @@ class MatMul8bitLt(torch.autograd.Function):
# B is in 8-bit row-major, so we can transform it back to 16-bit to extract outlier dimensions
# B is in 8-bit row-major, so we can transform it back to 16-bit to extract outlier dimensions
# we also need to convert it to the turing/ampere format
# we also need to convert it to the turing/ampere format
state
.
CxB
,
state
.
SB
=
F
.
transform
(
state
.
CB
,
to_order
=
formatB
)
state
.
CxB
,
state
.
SB
=
F
.
transform
(
state
.
CB
,
to_order
=
formatB
)
#state.B = (state.CB.float()*(state.SCB.view(-1, 1)/127)).half()
#if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None:
#if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None:
# # generate outlier index and subB
# # generate outlier index and subB
# outlier_idx = torch.unique(coo_tensorA.colidx).long()
# outlier_idx = torch.unique(coo_tensorA.colidx).long()
...
@@ -214,7 +215,6 @@ class MatMul8bitLt(torch.autograd.Function):
...
@@ -214,7 +215,6 @@ class MatMul8bitLt(torch.autograd.Function):
state
.
CxB
,
state
.
SB
=
F
.
transform
(
state
.
CB
,
to_order
=
formatB
)
state
.
CxB
,
state
.
SB
=
F
.
transform
(
state
.
CB
,
to_order
=
formatB
)
subA
=
None
subA
=
None
C32A
,
SA
=
F
.
transform
(
CA
,
'col32'
)
# 2. Quantize B
# 2. Quantize B
if
state
.
has_fp16_weights
:
if
state
.
has_fp16_weights
:
...
@@ -233,14 +233,15 @@ class MatMul8bitLt(torch.autograd.Function):
...
@@ -233,14 +233,15 @@ class MatMul8bitLt(torch.autograd.Function):
# extract outliers
# extract outliers
outlier_idx
=
torch
.
unique
(
coo_tensorA
.
colidx
)
outlier_idx
=
torch
.
unique
(
coo_tensorA
.
colidx
)
state
.
outlier_pool
.
add_outliers
(
outlier_idx
,
A
.
shape
[
-
1
])
state
.
idx
=
outlier_idx
if
state
.
use_pool
and
state
.
outlier_pool
.
model_dim
==
A
.
shape
[
-
1
]:
#state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
# do not use pool for 2nd FFN layer
#if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
state
.
idx
=
state
.
outlier_pool
.
get_current_outlier_idx
().
to
(
A
.
device
)
# # do not use pool for 2nd FFN layer
else
:
# state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
state
.
idx
=
outlier_idx
#else:
outliers
=
F
.
extract_outliers
(
state
.
CxB
,
state
.
SB
,
outlier_idx
).
half
()
# state.idx = outlier_idx
state
.
subB
=
(
outliers
*
state
.
SCB
.
view
(
-
1
,
1
).
half
()
/
127.0
).
t
().
contiguous
()
outliers
=
F
.
extract_outliers
(
state
.
CxB
,
state
.
SB
,
state
.
idx
.
int
())
state
.
subB
=
(
outliers
*
state
.
SCB
.
view
(
-
1
,
1
)
/
127.0
).
t
().
contiguous
().
half
()
CA
[:,
state
.
idx
.
long
()]
=
0
CA
[:,
state
.
idx
.
long
()]
=
0
CAt
[:,
state
.
idx
.
long
()]
=
0
CAt
[:,
state
.
idx
.
long
()]
=
0
subA
=
A
[:,
state
.
idx
.
long
()]
subA
=
A
[:,
state
.
idx
.
long
()]
...
@@ -253,11 +254,12 @@ class MatMul8bitLt(torch.autograd.Function):
...
@@ -253,11 +254,12 @@ class MatMul8bitLt(torch.autograd.Function):
output_shape
=
(
input_shape
[
0
],
shapeB
[
0
])
output_shape
=
(
input_shape
[
0
],
shapeB
[
0
])
# 3. Matmul
# 3. Matmul
C32A
,
SA
=
F
.
transform
(
CA
,
'col32'
)
out32
,
Sout32
=
F
.
igemmlt
(
C32A
,
state
.
CxB
,
SA
,
state
.
SB
)
out32
,
Sout32
=
F
.
igemmlt
(
C32A
,
state
.
CxB
,
SA
,
state
.
SB
)
output
=
F
.
mm_dequant
(
out32
,
Sout32
,
SCA
,
state
.
SCB
)
output
=
F
.
mm_dequant
(
out32
,
Sout32
,
SCA
,
state
.
SCB
)
# 4. Mixed-precision decomposition matmul
# 4. Mixed-precision decomposition matmul
if
state
.
threshold
>
0.0
and
coo_tensorA
is
not
None
and
subA
is
not
None
:
if
coo_tensorA
is
not
None
and
subA
is
not
None
:
output
+=
torch
.
matmul
(
subA
,
state
.
subB
)
output
+=
torch
.
matmul
(
subA
,
state
.
subB
)
# 5. Save state
# 5. Save state
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment