OpenDAS / bitsandbytes / Commits / 6b26402b

Commit 6b26402b, authored May 25, 2023 by Aarni Koskela

    Fix typo "quanitze"

parent 0f40fa3f
Showing 3 changed files with 11 additions and 11 deletions:

    benchmarking/switchback/speed_benchmark.py             +3 −3
    bitsandbytes/nn/triton_based_modules.py                +6 −6
    bitsandbytes/triton/int8_matmul_mixed_dequantize.py    +2 −2
benchmarking/switchback/speed_benchmark.py

@@ -8,7 +8,7 @@ from bitsandbytes.triton.quantize_rowwise import quantize_rowwise
 from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose
 from bitsandbytes.triton.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize
 from bitsandbytes.triton.quantize_global import quantize_global, quantize_global_transpose
-from bitsandbytes.triton.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze
+from bitsandbytes.triton.int8_matmul_mixed_dequantize import int8_matmul_mixed_dequantize

 # KNOW ISSUE: need to optimize "w_quantize_colwise_transpose" when embeddim is too large.

@@ -72,8 +72,8 @@ if __name__ == '__main__':
             get_time('standard_gx', lambda : g.matmul(w), info)
             get_time('rowwise_fwd', lambda : int8_matmul_rowwise_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_columnwise, None), info)
             get_time('rowwise_bwd', lambda : int8_matmul_rowwise_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_rowwise, None), info)
-            get_time('global_fwd', lambda : int8_matmul_mixed_dequanitze(x_int8, w_int8.t(), state_x_rowwise, state_w_global, None), info)
-            get_time('global_bwd', lambda : int8_matmul_mixed_dequanitze(g_int8, wt_int8.t(), state_x_rowwise, state_w_global, None), info)
+            get_time('global_fwd', lambda : int8_matmul_mixed_dequantize(x_int8, w_int8.t(), state_x_rowwise, state_w_global, None), info)
+            get_time('global_bwd', lambda : int8_matmul_mixed_dequantize(g_int8, wt_int8.t(), state_x_rowwise, state_w_global, None), info)
             get_time('x_quantize_rowwise', lambda : quantize_rowwise(x), info)
             get_time('g_quantize_rowwise', lambda : quantize_rowwise(g), info)
             get_time('w_quantize_rowwise', lambda : quantize_rowwise(w), info)
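The get_time helper used above is defined earlier in speed_benchmark.py and is not part of this diff. As a rough sketch of what such a helper does (the name, the lambda argument, and the info dict mirror the calls above; the warm-up count, iteration count, and CUDA-event timing are assumptions for illustration):

import torch

def get_time(name, fn, info):
    # Warm up so kernel compilation / autotuning does not pollute the measurement.
    for _ in range(5):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(100):
        fn()
    end.record()
    torch.cuda.synchronize()
    info[name] = start.elapsed_time(end) / 100  # mean milliseconds per call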
bitsandbytes/nn/triton_based_modules.py

@@ -10,7 +10,7 @@ from bitsandbytes.triton.quantize_rowwise import quantize_rowwise
 from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose
 from bitsandbytes.triton.int8_matmul_rowwise_dequantize import int8_matmul_rowwise_dequantize
 from bitsandbytes.triton.quantize_global import quantize_global, quantize_global_transpose
-from bitsandbytes.triton.int8_matmul_mixed_dequanitze import int8_matmul_mixed_dequanitze
+from bitsandbytes.triton.int8_matmul_mixed_dequantize import int8_matmul_mixed_dequantize


 class _switchback_global(torch.autograd.Function):

@@ -29,7 +29,7 @@ class _switchback_global(torch.autograd.Function):
         # matmult, fused dequant and add bias
         # call "mixed" because we are mixing rowwise quantized and global quantized
-        return int8_matmul_mixed_dequanitze(
+        return int8_matmul_mixed_dequantize(
             X_int8, W_int8.t(), state_X, state_W, bias
         ).view(*X_3D.size()[:-1], -1)

@@ -47,7 +47,7 @@ class _switchback_global(torch.autograd.Function):
             # so we transpose once then call .t() in the matmul
             G_int8, state_G = quantize_rowwise(G)
             W_int8, state_W = quantize_global_transpose(W)
-            grad_X = int8_matmul_mixed_dequanitze(G_int8, W_int8.t(), state_G, state_W, None).view(
+            grad_X = int8_matmul_mixed_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view(
                 *G_3D.size()[:-1], -1
             )
         if ctx.needs_input_grad[1]:

@@ -119,7 +119,7 @@ class _switchback_global_mem_efficient(torch.autograd.Function):
         # matmult, fused dequant and add bias
         # call "mixed" because we are mixing rowwise quantized and global quantized
-        return int8_matmul_mixed_dequanitze(
+        return int8_matmul_mixed_dequantize(
             X_int8, W_int8.t(), state_X, state_W, bias
         ).view(*X_3D_sz[:-1], -1)

@@ -143,7 +143,7 @@ class _switchback_global_mem_efficient(torch.autograd.Function):
             G_int8, state_G = quantize_rowwise(G)
             del G
             W_int8 = W_int8.t().contiguous()
-            grad_X = int8_matmul_mixed_dequanitze(G_int8, W_int8.t(), state_G, state_W, None).view(
+            grad_X = int8_matmul_mixed_dequantize(G_int8, W_int8.t(), state_G, state_W, None).view(
                 *G_3D_sz[:-1], -1
             )

@@ -215,7 +215,7 @@ class SwitchBackLinear(nn.Linear):
                 X_int8, self.W_int8.t(), state_X, self.state_W, self.bias
             ).view(*x.size()[:-1], -1)
         else:
-            return int8_matmul_mixed_dequanitze(
+            return int8_matmul_mixed_dequantize(
                 X_int8, self.W_int8.t(), state_X, self.state_W, self.bias
             ).view(*x.size()[:-1], -1)
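The hunks above all touch the same call pattern: row-wise-quantized activations are multiplied against a globally-quantized weight, and the kernel dequantizes and adds the bias in one pass. A minimal sketch of that forward path, using the function names and argument order from the diff (the shapes, dtypes, and the use of quantize_global for the weight are assumptions for illustration, not the module's exact code):

import torch
from bitsandbytes.triton.quantize_rowwise import quantize_rowwise
from bitsandbytes.triton.quantize_global import quantize_global
from bitsandbytes.triton.int8_matmul_mixed_dequantize import int8_matmul_mixed_dequantize

X_3D = torch.randn(4, 128, 1024, device='cuda', dtype=torch.float16)  # (batch, seq, in_features)
W = torch.randn(2048, 1024, device='cuda', dtype=torch.float16)       # (out_features, in_features)
bias = torch.randn(2048, device='cuda', dtype=torch.float16)

X = X_3D.view(-1, X_3D.size(-1))        # flatten leading dims for the GEMM
X_int8, state_X = quantize_rowwise(X)   # int8 activations + per-row scale state
W_int8, state_W = quantize_global(W)    # int8 weights + one tensor-wide scale state
out = int8_matmul_mixed_dequantize(     # "mixed": rowwise state_X with global state_W
    X_int8, W_int8.t(), state_X, state_W, bias
).view(*X_3D.size()[:-1], -1)           # restore (batch, seq, out_features)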
bitsandbytes/triton/int8_matmul_mixed_dequanitze.py → bitsandbytes/triton/int8_matmul_mixed_dequantize.py

@@ -2,7 +2,7 @@ import torch
 from bitsandbytes.triton.triton_utils import is_triton_available

 if not is_triton_available():
-    def int8_matmul_mixed_dequanitze(a, b, state_x, state_w, bias): return None
+    def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias): return None
 else:
     import triton

@@ -136,7 +136,7 @@ else:
         tl.atomic_add(C, acc, mask=mask)

-    def int8_matmul_mixed_dequanitze(a, b, state_x, state_w, bias):
+    def int8_matmul_mixed_dequantize(a, b, state_x, state_w, bias):
         device = a.device
         divfactor = 1. / (127. * 127.)
         has_bias = 0 if bias is None else 1
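The divfactor = 1. / (127. * 127.) visible in the second hunk indicates how the kernel rescales the int8 products: each row of the accumulator is multiplied by the activation's per-row state, the weight's global state, and 1 / (127 * 127), then the bias is added. A pure-PyTorch reference of that dequantization (a CPU sketch for illustration only; the Triton kernel fuses this into the GEMM, and the exact contents of state_x and state_w are an assumption):

import torch

def int8_matmul_mixed_dequantize_reference(a, b, state_x, state_w, bias):
    # a: (M, K) int8 rowwise-quantized activations, state_x: (M,) per-row absmax
    # b: (K, N) int8 globally-quantized weights,    state_w: scalar absmax
    divfactor = 1. / (127. * 127.)
    acc = torch.matmul(a.to(torch.int32), b.to(torch.int32)).float()  # int32 accumulate
    out = acc * state_x.view(-1, 1).float() * float(state_w) * divfactor
    if bias is not None:
        out = out + bias.float()
    return out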