OpenDAS / bitsandbytes · Commit 947db7cf

Unverified commit, authored Jan 02, 2024 by Tim Dettmers; committed by GitHub on Jan 02, 2024.

Merge pull request #436 from akx/quanitze

Fix typo "quanitze"

Parents: 8c5c6689, 6b26402b
Showing 3 changed files with 11 additions and 11 deletions:

benchmarking/switchback/speed_benchmark.py (+3 / -3)
bitsandbytes/nn/triton_based_modules.py (+3 / -3)
bitsandbytes/triton/int8_matmul_mixed_dequantize.py (+2 / -2)
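In short, the commit renames the misspelled module bitsandbytes.triton.int8_matmul_mixed_dequanitze (and the function of the same name) to int8_matmul_mixed_dequantize, updating every import and call site. The user-facing effect is just the spelling of the import:

# Before this commit (misspelled module and function name):
from bitsandbytes.triton.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze

# After this commit:
from bitsandbytes.triton.int8_matmul_mixed_dequantize import int8_matmul_mixed_dequantize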
benchmarking/switchback/speed_benchmark.py
@@ -8,7 +8,7 @@ from bitsandbytes.triton.quantize_rowwise import quantize_rowwise
 from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose
 from bitsandbytes.triton.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize
 from bitsandbytes.triton.quantize_global import quantize_global, quantize_global_transpose
-from bitsandbytes.triton.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze
+from bitsandbytes.triton.int8_matmul_mixed_dequantize import int8_matmul_mixed_dequantize

 # KNOW ISSUE: need to optimize "w_quantize_colwise_transpose" when embeddim is too large.
@@ -72,8 +72,8 @@ if __name__ == '__main__':
         get_time('standard_gx', lambda : g.matmul(w), info)
         get_time('rowwise_fwd', lambda : int8_matmul_rowwise_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_columnwise, None), info)
         get_time('rowwise_bwd', lambda : int8_matmul_rowwise_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_rowwise, None), info)
-        get_time('global_fwd', lambda : int8_matmul_mixed_dequanitze(x_int8, w_int8.t(), state_x_rowwise, state_w_global, None), info)
-        get_time('global_bwd', lambda : int8_matmul_mixed_dequanitze(g_int8, wt_int8.t(), state_x_rowwise, state_w_global, None), info)
+        get_time('global_fwd', lambda : int8_matmul_mixed_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_global, None), info)
+        get_time('global_bwd', lambda : int8_matmul_mixed_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_global, None), info)
         get_time('x_quantize_rowwise', lambda : quantize_rowwise(x), info)
         get_time('g_quantize_rowwise', lambda : quantize_rowwise(g), info)
         get_time('w_quantize_rowwise', lambda : quantize_rowwise(w), info)
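Each get_time call above benchmarks one kernel variant. The helper itself is defined earlier in speed_benchmark.py and is not part of this diff; a minimal sketch of such a CUDA timing helper, assuming warmup and iteration counts and a dict-like info argument, would be:

import time
import torch

def get_time(name, fn, info):
    # Hypothetical sketch only; the real get_time in speed_benchmark.py
    # is not shown in this diff. Warm up first so compilation and
    # autotuning don't pollute the measurement.
    for _ in range(4):
        fn()
    torch.cuda.synchronize()  # drain queued kernels before starting the clock
    start = time.time()
    iters = 32
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()  # wait for all timed kernels to finish
    info[name] = (time.time() - start) / iters * 1000  # mean ms per call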
bitsandbytes/nn/triton_based_modules.py
@@ -10,7 +10,7 @@ from bitsandbytes.triton.quantize_rowwise import quantize_rowwise
 from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose
 from bitsandbytes.triton.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize
 from bitsandbytes.triton.quantize_global import quantize_global, quantize_global_transpose
-from bitsandbytes.triton.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze
+from bitsandbytes.triton.int8_matmul_mixed_dequantize import int8_matmul_mixed_dequantize


 class _switchback_global(torch.autograd.Function):
@@ -29,7 +29,7 @@ class _switchback_global(torch.autograd.Function):
         # matmult, fused dequant and add bias
         # call "mixed" because we are mixing rowwise quantized and global quantized
-        return int8_matmul_mixed_dequanitze(
+        return int8_matmul_mixed_dequantize(
             X_int8, W_int8.t(), state_X, state_W, bias
         ).view(*X_3D.size()[:-1], -1)
@@ -47,7 +47,7 @@ class _switchback_global(torch.autograd.Function):
         # so we transpose once then call .t() in the matmul
         G_int8, state_G = quantize_rowwise(G)
         W_int8, state_W = quantize_global_transpose(W)
-        grad_X = int8_matmul_mixed_dequanitze(G_int8, W_int8.t(), state_G, state_W, None).view(
+        grad_X = int8_matmul_mixed_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view(
             *G_3D.size()[:-1], -1
         )
         if ctx.needs_input_grad[1]:
@@ -119,7 +119,7 @@ class _switchback_global_mem_efficient(torch.autograd.Function):
         # matmult, fused dequant and add bias
         # call "mixed" because we are mixing rowwise quantized and global quantized
-        return int8_matmul_mixed_dequanitze(
+        return int8_matmul_mixed_dequantize(
             X_int8, W_int8.t(), state_X, state_W, bias
         ).view(*X_3D_sz[:-1], -1)
@@ -143,7 +143,7 @@ class _switchback_global_mem_efficient(torch.autograd.Function):
         G_int8, state_G = quantize_rowwise(G)
         del G
         W_int8 = W_int8.t().contiguous()
-        grad_X = int8_matmul_mixed_dequanitze(G_int8, W_int8.t(), state_G, state_W, None).view(
+        grad_X = int8_matmul_mixed_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view(
             *G_3D_sz[:-1], -1
         )
@@ -215,7 +215,7 @@ class SwitchBackLinear(nn.Linear):
                 X_int8, self.W_int8.t(), state_X, self.state_W, self.bias
             ).view(*x.size()[:-1], -1)
         else:
-            return int8_matmul_mixed_dequanitze(
+            return int8_matmul_mixed_dequantize(
                 X_int8, self.W_int8.t(), state_X, self.state_W, self.bias
             ).view(*x.size()[:-1], -1)
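SwitchBackLinear is the public entry point that dispatches to these kernels: it rowwise-quantizes the activation, runs the int8 matmul against the globally quantized weight with fused dequantize and bias, and reshapes back to the input's leading dimensions. A minimal usage sketch, assuming a CUDA device with a working Triton install (without Triton the kernel stubs return None) and illustrative shapes:

import torch
from bitsandbytes.nn.triton_based_modules import SwitchBackLinear

# Hypothetical sizes chosen for illustration only.
linear = SwitchBackLinear(128, 256).cuda().half()
x = torch.randn(4, 16, 128, device='cuda', dtype=torch.float16)

# Forward: quantize x rowwise to int8, int8 matmul with the globally
# quantized weight, fused dequantize + bias, view back to 3D.
y = linear(x)
print(y.shape)  # expected: torch.Size([4, 16, 256])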
bitsandbytes/triton/int8_matmul_mixed_dequanitze.py → bitsandbytes/triton/int8_matmul_mixed_dequantize.py
@@ -2,7 +2,7 @@ import torch
 from bitsandbytes.triton.triton_utils import is_triton_available

 if not is_triton_available():
-    def int8_matmul_mixed_dequanitze(a, b, state_x, state_w, bias): return None
+    def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias): return None
 else:

     import triton
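Note the fallback in the hunk above: when Triton is unavailable the renamed function still imports cleanly but returns None instead of raising at import time. A hypothetical caller-side guard (not part of this commit) could surface that as an explicit error:

from bitsandbytes.triton.int8_matmul_mixed_dequantize import int8_matmul_mixed_dequantize

def checked_matmul(a_int8, b_int8, state_x, state_w, bias=None):
    # Hypothetical wrapper: the no-Triton stub returns None rather than
    # raising, so convert that into a clear runtime error.
    out = int8_matmul_mixed_dequantize(a_int8, b_int8, state_x, state_w, bias)
    if out is None:
        raise RuntimeError("Triton is required for the int8 mixed-dequantize kernel")
    return out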
@@ -136,7 +136,7 @@ else:
         tl.atomic_add(C, acc, mask=mask)

-    def int8_matmul_mixed_dequanitze(a, b, state_x, state_w, bias):
+    def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias):
         device = a.device
         divfactor = 1. / (127. * 127.)
         has_bias = 0 if bias is None else 1
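The divfactor = 1. / (127. * 127.) in the hunk above undoes the two symmetric int8 quantization steps: each operand was scaled into [-127, 127], so the int32 product must be divided by 127² and multiplied by the stored scales. A pure-PyTorch reference sketch of what the kernel is expected to compute, under the assumption that state_x holds per-row absolute maxima of a and state_w a single global absolute maximum of b:

import torch

def int8_matmul_mixed_dequantize_reference(a, b, state_x, state_w, bias):
    # Reference sketch only, not the Triton kernel. Accumulate in float64
    # so the integer products are exact, then rescale:
    #   out[i, j] ≈ (a @ b)[i, j] * state_x[i] * state_w / (127 * 127)
    divfactor = 1.0 / (127.0 * 127.0)
    acc = a.double() @ b.double()
    out = acc * state_x.double().view(-1, 1) * state_w.double() * divfactor
    if bias is not None:
        out = out + bias.double()
    return out.to(torch.float16)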