OpenDAS / bitsandbytes

Unverified commit 947db7cf, authored Jan 02, 2024 by Tim Dettmers, committed by GitHub on Jan 02, 2024.
Parents: 8c5c6689, 6b26402b

Merge pull request #436 from akx/quanitze

Fix typo "quanitze"
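This commit fixes the misspelling "quanitze" to "quantize" in both the Triton module path and the function it exports, so any code importing the old name needs the matching one-character rename. Before/after, taken directly from the diff below:

    # before: misspelled module and symbol
    from bitsandbytes.triton.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze
    # after this commit
    from bitsandbytes.triton.int8_matmul_mixed_dequantize import int8_matmul_mixed_dequantize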
Showing 3 changed files with 11 additions and 11 deletions (+11 -11):

    benchmarking/switchback/speed_benchmark.py            +3 -3
    bitsandbytes/nn/triton_based_modules.py               +6 -6
    bitsandbytes/triton/int8_matmul_mixed_dequantize.py   +2 -2
benchmarking/switchback/speed_benchmark.py

@@ -8,7 +8,7 @@ from bitsandbytes.triton.quantize_rowwise import quantize_rowwise
 from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose
 from bitsandbytes.triton.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize
 from bitsandbytes.triton.quantize_global import quantize_global, quantize_global_transpose
-from bitsandbytes.triton.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze
+from bitsandbytes.triton.int8_matmul_mixed_dequantize import int8_matmul_mixed_dequantize

 # KNOW ISSUE: need to optimize "w_quantize_colwise_transpose" when embeddim is too large.

@@ -72,8 +72,8 @@ if __name__ == '__main__':
         get_time('standard_gx', lambda: g.matmul(w), info)
         get_time('rowwise_fwd', lambda: int8_matmul_rowwise_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_columnwise, None), info)
         get_time('rowwise_bwd', lambda: int8_matmul_rowwise_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_rowwise, None), info)
-        get_time('global_fwd', lambda: int8_matmul_mixed_dequanitze(x_int8, w_int8.t(), state_x_rowwise, state_w_global, None), info)
-        get_time('global_bwd', lambda: int8_matmul_mixed_dequanitze(g_int8, wt_int8.t(), state_x_rowwise, state_w_global, None), info)
+        get_time('global_fwd', lambda: int8_matmul_mixed_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_global, None), info)
+        get_time('global_bwd', lambda: int8_matmul_mixed_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_global, None), info)
         get_time('x_quantize_rowwise', lambda: quantize_rowwise(x), info)
         get_time('g_quantize_rowwise', lambda: quantize_rowwise(g), info)
         get_time('w_quantize_rowwise', lambda: quantize_rowwise(w), info)
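For context, the renamed kernel is exercised in this benchmark roughly as follows: quantize the activations rowwise, quantize the weight with a single global scale, then run the fused int8 matmul that dequantizes on the way out. The sketch below is illustrative rather than a copy of the file; it assumes a CUDA GPU with Triton installed, hypothetical sizes, and that quantize_global, like quantize_rowwise, returns the int8 tensor together with its absmax scale.

    import torch
    from bitsandbytes.triton.quantize_rowwise import quantize_rowwise
    from bitsandbytes.triton.quantize_global import quantize_global
    from bitsandbytes.triton.int8_matmul_mixed_dequantize import int8_matmul_mixed_dequantize

    batch, dim_in, dim_out = 32, 1024, 4096  # hypothetical sizes
    x = torch.randn(batch, dim_in, dtype=torch.float16, device='cuda')    # activations
    w = torch.randn(dim_out, dim_in, dtype=torch.float16, device='cuda')  # linear weight

    x_int8, state_x = quantize_rowwise(x)  # int8 activations + per-row absmax scales
    w_int8, state_w = quantize_global(w)   # int8 weight + one tensor-wide scale

    # Fused int8 matmul; the scales undo the quantization inside the kernel (bias=None here).
    out = int8_matmul_mixed_dequantize(x_int8, w_int8.t(), state_x, state_w, None)
    print(out.shape)  # (batch, dim_out)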
bitsandbytes/nn/triton_based_modules.py

@@ -10,7 +10,7 @@ from bitsandbytes.triton.quantize_rowwise import quantize_rowwise
 from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose
 from bitsandbytes.triton.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize
 from bitsandbytes.triton.quantize_global import quantize_global, quantize_global_transpose
-from bitsandbytes.triton.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze
+from bitsandbytes.triton.int8_matmul_mixed_dequantize import int8_matmul_mixed_dequantize

 class _switchback_global(torch.autograd.Function):

@@ -29,7 +29,7 @@ class _switchback_global(torch.autograd.Function):
         # matmult, fused dequant and add bias
         # call "mixed" because we are mixing rowwise quantized and global quantized
-        return int8_matmul_mixed_dequanitze(
+        return int8_matmul_mixed_dequantize(
             X_int8, W_int8.t(), state_X, state_W, bias
         ).view(*X_3D.size()[:-1], -1)

@@ -47,7 +47,7 @@ class _switchback_global(torch.autograd.Function):
         # so we transpose once then call .t() in the matmul
         G_int8, state_G = quantize_rowwise(G)
         W_int8, state_W = quantize_global_transpose(W)
-        grad_X = int8_matmul_mixed_dequanitze(G_int8, W_int8.t(), state_G, state_W, None).view(
+        grad_X = int8_matmul_mixed_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view(
             *G_3D.size()[:-1], -1
         )
         if ctx.needs_input_grad[1]:

@@ -119,7 +119,7 @@ class _switchback_global_mem_efficient(torch.autograd.Function):
         # matmult, fused dequant and add bias
         # call "mixed" because we are mixing rowwise quantized and global quantized
-        return int8_matmul_mixed_dequanitze(
+        return int8_matmul_mixed_dequantize(
             X_int8, W_int8.t(), state_X, state_W, bias
         ).view(*X_3D_sz[:-1], -1)

@@ -143,7 +143,7 @@ class _switchback_global_mem_efficient(torch.autograd.Function):
         G_int8, state_G = quantize_rowwise(G)
         del G
         W_int8 = W_int8.t().contiguous()
-        grad_X = int8_matmul_mixed_dequanitze(G_int8, W_int8.t(), state_G, state_W, None).view(
+        grad_X = int8_matmul_mixed_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view(
             *G_3D_sz[:-1], -1
         )

@@ -215,7 +215,7 @@ class SwitchBackLinear(nn.Linear):
                 X_int8, self.W_int8.t(), state_X, self.state_W, self.bias
             ).view(*x.size()[:-1], -1)
         else:
-            return int8_matmul_mixed_dequanitze(
+            return int8_matmul_mixed_dequantize(
                 X_int8, self.W_int8.t(), state_X, self.state_W, self.bias
             ).view(*x.size()[:-1], -1)
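The renamed function is what SwitchBackLinear, defined in this file, dispatches to in its forward pass (the else branch in the hunk above). A minimal usage sketch, assuming a CUDA device with Triton available and a layer moved to fp16; the sizes are made up and chosen as multiples of 32 to stay friendly to the Triton kernels:

    import torch
    from bitsandbytes.nn.triton_based_modules import SwitchBackLinear

    # Drop-in replacement for nn.Linear whose matmuls run through the
    # int8_matmul_* Triton kernels imported at the top of this file.
    layer = SwitchBackLinear(1024, 4096).cuda().half()

    x = torch.randn(16, 256, 1024, dtype=torch.float16, device='cuda')
    y = layer(x)    # forward quantizes x rowwise and calls int8_matmul_mixed_dequantize
    print(y.shape)  # torch.Size([16, 256, 4096])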
bitsandbytes/triton/int8_matmul_mixed_dequanitze.py → bitsandbytes/triton/int8_matmul_mixed_dequantize.py

@@ -2,7 +2,7 @@ import torch
 from bitsandbytes.triton.triton_utils import is_triton_available

 if not is_triton_available():
-    def int8_matmul_mixed_dequanitze(a, b, state_x, state_w, bias): return None
+    def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias): return None
 else:

     import triton

@@ -136,7 +136,7 @@ else:
         tl.atomic_add(C, acc, mask=mask)

-    def int8_matmul_mixed_dequanitze(a, b, state_x, state_w, bias):
+    def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias):
         device = a.device
         divfactor = 1. / (127. * 127.)
         has_bias = 0 if bias is None else 1
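The divfactor = 1. / (127. * 127.) visible in the changed function is the constant that undoes both int8 scalings at once: the activations carry a per-row absmax scale normalized to 127, and the weight a single global absmax scale normalized to 127. A rough pure-PyTorch reference for what the kernel computes, as a sketch under those assumptions (it ignores the Triton tiling, the int32 accumulation, and the fp16 output path) and not the library's implementation:

    import torch

    def int8_matmul_mixed_dequantize_reference(a, b, state_x, state_w, bias):
        # a: (M, K) int8 rowwise-quantized activations, state_x: (M,) per-row absmax
        # b: (K, N) int8 globally quantized weight (already transposed), state_w: scalar absmax
        divfactor = 1. / (127. * 127.)
        acc = a.float() @ b.float()  # the real kernel accumulates in int32
        out = acc * (state_x.view(-1, 1) * state_w * divfactor)
        if bias is not None:
            out = out + bias
        return out

    # Tiny self-check against the unquantized matmul:
    M, K, N = 4, 8, 5
    x, w = torch.randn(M, K), torch.randn(K, N)
    state_x, state_w = x.abs().amax(dim=1), w.abs().amax()
    a = torch.round(x / state_x.view(-1, 1) * 127).to(torch.int8)
    b = torch.round(w / state_w * 127).to(torch.int8)
    approx = int8_matmul_mixed_dequantize_reference(a, b, state_x, state_w, None)
    print((approx - x @ w).abs().max())  # small quantization error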