OpenDAS / bitsandbytes

Commit c3d87e44
Authored Apr 12, 2023 by Tim Dettmers
Added is_available_triton guard.
Parent: 7140c014

Showing 7 changed files with 589 additions and 553 deletions (+589, −553). The hunks below collapse unchanged context at the "@@" markers, and whitespace-only re-indentation (from moving code into the new "else:" blocks) appears as unmarked context rather than as +/− lines.
bitsandbytes/research/autograd/_functions.py               +2   −2
bitsandbytes/triton/dequantize_rowwise.py                  +56  −50
bitsandbytes/triton/int8_matmul_mixed_dequanitze.py        +148 −143
bitsandbytes/triton/int8_matmul_rowwise_dequantize.py      +161 −156
bitsandbytes/triton/quantize_columnwise_and_transpose.py   +64  −58
bitsandbytes/triton/quantize_global.py                     +94  −87
bitsandbytes/triton/quantize_rowwise.py                    +64  −57
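Every file in this commit applies the same guard: import is_triton_available from bitsandbytes.triton.triton_utils, define a stub that returns None when Triton is missing, and otherwise import triton and define the real kernels inside the else: branch, so importing bitsandbytes no longer fails on machines without Triton. The helper itself is not part of this diff (it comes from the parent commit 7140c014); a minimal sketch of how such a check can be written, assuming it only probes for the package:

from importlib.util import find_spec

def is_triton_available() -> bool:
    # Hypothetical sketch, not the code from this commit: report whether
    # `import triton` would succeed, without actually importing it.
    return find_spec("triton") is not None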
bitsandbytes/research/autograd/_functions.py

@@ -184,7 +184,7 @@ class MatMulFP8Global(torch.autograd.Function):
         return grad_A, grad_B, None, None, None, None, None


-class MatMul8bitMixed(torch.autograd.Function):
+class SwitchBackBnb(torch.autograd.Function):
     @staticmethod
     def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):
         # default to pytorch behavior if inputs are empty
@@ -408,4 +408,4 @@ def switchback_bnb(
     state = state or MatmulLtState()
     if threshold > 0.0:
         state.threshold = threshold
-    return MatMul8bitMixed.apply(A, B, out, bias, state)
+    return SwitchBackBnb.apply(A, B, out, bias, state)
bitsandbytes/triton/dequantize_rowwise.py

 import math
 import torch
 import time
-import triton
-import triton.language as tl
-from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
-
-# rowwise quantize
+from bitsandbytes.triton.triton_utils import is_triton_available

-# TODO: autotune this better.
-@triton.autotune(
+if not is_triton_available():
+    def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor): return None
+else:
+    import triton
+    import triton.language as tl
+    from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
+
+    # rowwise quantize
+
+    # TODO: autotune this better.
+    @triton.autotune(
         configs=[
             triton.Config({}, num_stages=1, num_warps=8),
             triton.Config({}, num_stages=2, num_warps=8),
@@ -24,9 +30,9 @@
             triton.Config({}, num_warps=8),
         ],
         key=['n_elements']
-)
-@triton.jit
-def _dequantize_rowwise(
+    )
+    @triton.jit
+    def _dequantize_rowwise(
         x_ptr,
         state_x,
         output_ptr,
@@ -34,7 +40,7 @@ def _dequantize_rowwise(
         n_elements,
         BLOCK_SIZE: tl.constexpr,
         P2: tl.constexpr,
-):
+    ):
         pid = tl.program_id(axis=0)
         block_start = pid * BLOCK_SIZE
         arange = tl.arange(0, P2)
@@ -46,7 +52,7 @@ def _dequantize_rowwise(
         tl.store(output_ptr + offsets, output, mask=row_mask)

-def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor):
+    def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor):
         output = torch.empty(*x.shape, device=x.device, dtype=torch.float16)
         P2 = int(2 ** (math.ceil(math.log2(x.shape[1]))))
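The kernel appears to launch one program per row and rescales each stored int8 value by that row's fp16 maximum. A hedged round-trip sketch, assuming quantize_rowwise (the last file in this commit) returns the int8 tensor together with the per-row maxima, and that Triton and a CUDA device are present (otherwise both functions are stubs returning None):

import torch
from bitsandbytes.triton.quantize_rowwise import quantize_rowwise
from bitsandbytes.triton.dequantize_rowwise import dequantize_rowwise

x = torch.randn(4, 1024, device="cuda", dtype=torch.float16)
x_int8, state_x = quantize_rowwise(x)           # int8 values + fp16 row maxima
x_approx = dequantize_rowwise(x_int8, state_x)  # close to x, up to int8 rounding
assert x_approx.shape == x.shape and x_approx.dtype == torch.float16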
bitsandbytes/triton/int8_matmul_mixed_dequanitze.py

 import torch
+from bitsandbytes.triton.triton_utils import is_triton_available

-import triton
-import triton.language as tl
-from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
-
-# This is a matmul kernel based on triton.ops.matmul
-# It is modified to support rowwise quantized input and global quantized weight
-# It's purpose is fused matmul then dequantize
-# It does support bias.
-def init_to_zero(name):
+if not is_triton_available():
+    def int8_matmul_mixed_dequanitze(a, b, state_x, state_w, bias): return None
+else:
+    import triton
+    import triton.language as tl
+    from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
+
+    # This is a matmul kernel based on triton.ops.matmul
+    # It is modified to support rowwise quantized input and global quantized weight
+    # It's purpose is fused matmul then dequantize
+    # It does support bias.
+    def init_to_zero(name):
         return lambda nargs: nargs[name].zero_()

-def get_configs_io_bound():
+    def get_configs_io_bound():
         configs = []
         for num_stages in [2, 3, 4, 5, 6]:
             for block_m in [16, 32]:
@@ -30,7 +35,7 @@ def get_configs_io_bound():
         return configs

-@triton.autotune(
+    @triton.autotune(
         configs=[
             # basic configs for compute-bound matmuls
             triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
@@ -59,12 +64,12 @@ def get_configs_io_bound():
             'perf_model': estimate_matmul_time,
             'top_k': 10,
         },
-)
-@triton.heuristics({
+    )
+    @triton.heuristics({
         'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
-})
-@triton.jit
-def _int8_matmul_mixed_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor: tl.constexpr, has_bias: tl.constexpr,
+    })
+    @triton.jit
+    def _int8_matmul_mixed_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor: tl.constexpr, has_bias: tl.constexpr,
                         stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
@@ -131,7 +136,7 @@ def _int8_matmul_mixed_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N,
             tl.atomic_add(C, acc, mask=mask)

-def int8_matmul_mixed_dequanitze(a, b, state_x, state_w, bias):
+    def int8_matmul_mixed_dequanitze(a, b, state_x, state_w, bias):
         device = a.device
         divfactor = 1. / (127. * 127.)
         has_bias = 0 if bias is None else 1
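The header comments state the intent: a triton.ops.matmul-derived kernel that fuses the int8 matmul with dequantization, where the input carries per-row scales (state_x) and the weight a single global scale (state_w); divfactor = 1/(127·127) undoes the two 127-scalings applied at quantization time. A hedged pure-PyTorch reference of that arithmetic, assuming a is (M, K) int8 with state_x of shape (M,), b is (K, N) int8 with a scalar state_w, and fp16 output:

import torch

def int8_matmul_mixed_dequantize_ref(a, b, state_x, state_w, bias=None):
    # Reference sketch, not the Triton kernel.
    divfactor = 1. / (127. * 127.)
    acc = a.float() @ b.float()  # stands in for the kernel's int32 accumulation
    out = acc * state_x.float().unsqueeze(1) * float(state_w) * divfactor
    if bias is not None:
        out += bias.float()
    return out.half()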
bitsandbytes/triton/int8_matmul_rowwise_dequantize.py

 import torch
+from bitsandbytes.triton.triton_utils import is_triton_available

-import triton
-import triton.language as tl
-from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
-
-# This is a matmul kernel based on triton.ops.matmul
-# It is modified to support rowwise quantized input and columnwise quantized weight
-# It's purpose is fused matmul then dequantize
-# It does support bias.
-def init_to_zero(name):
+if not is_triton_available():
+    def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias): return None
+else:
+    import triton
+    import triton.language as tl
+    from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
+
+    # This is a matmul kernel based on triton.ops.matmul
+    # It is modified to support rowwise quantized input and columnwise quantized weight
+    # It's purpose is fused matmul then dequantize
+    # It does support bias.
+    def init_to_zero(name):
         return lambda nargs: nargs[name].zero_()

-def get_configs_io_bound():
+    def get_configs_io_bound():
         configs = []
         for num_stages in [2, 3, 4, 5, 6]:
             for block_m in [16, 32]:
@@ -30,7 +35,7 @@ def get_configs_io_bound():
         return configs

-@triton.autotune(
+    @triton.autotune(
         configs=[
             # basic configs for compute-bound matmuls
             triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
@@ -59,12 +64,12 @@ def get_configs_io_bound():
             'perf_model': estimate_matmul_time,
             'top_k': 10,
         },
-)
-@triton.heuristics({
+    )
+    @triton.heuristics({
         'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0,
-})
-@triton.jit
-def _int8_matmul_rowwise_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor, has_bias: tl.constexpr,
+    })
+    @triton.jit
+    def _int8_matmul_rowwise_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K, divfactor, has_bias: tl.constexpr,
                         stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
@@ -130,7 +135,7 @@ def _int8_matmul_rowwise_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M,
             tl.atomic_add(C, acc, mask=mask)

-def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias):
+    def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias):
         divfactor = 1. / (127. * 127.)
         has_bias = 0 if bias is None else 1
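Identical structure to the mixed variant, but here the weight is columnwise-quantized, so state_w holds one scale per output column rather than a single scalar. The same hedged reference with state_w of shape (N,):

import torch

def int8_matmul_rowwise_dequantize_ref(a, b, state_x, state_w, bias=None):
    # Reference sketch: per-row input scales and per-column weight scales.
    divfactor = 1. / (127. * 127.)
    acc = a.float() @ b.float()  # stands in for the int32 accumulation
    out = acc * state_x.float().unsqueeze(1) * state_w.float().unsqueeze(0) * divfactor
    if bias is not None:
        out += bias.float()
    return out.half()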
bitsandbytes/triton/quantize_columnwise_and_transpose.py

 import math
 import torch
 import time
-import triton
-import triton.language as tl
-from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
-
-# This kernel does fused columnwise quantization and transpose.
+from bitsandbytes.triton.triton_utils import is_triton_available

-# TODO: autotune this better.
-@triton.autotune(
+if not is_triton_available():
+    def quantize_columnwise_and_transpose(x: torch.Tensor): return None
+else:
+    import triton
+    import triton.language as tl
+    from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
+
+    # This kernel does fused columnwise quantization and transpose.
+
+    # TODO: autotune this better.
+    @triton.autotune(
         configs=[
             triton.Config({}, num_stages=1),
             triton.Config({}, num_stages=2),
@@ -26,9 +32,9 @@
             triton.Config({}, num_warps=8),
         ],
         key=['n_elements']
-)
-@triton.jit
-def _quantize_columnwise_and_transpose(
+    )
+    @triton.jit
+    def _quantize_columnwise_and_transpose(
         x_ptr,
         output_ptr,
         output_maxs,
@@ -36,7 +42,7 @@ def _quantize_columnwise_and_transpose(
         M: tl.constexpr,
         N: tl.constexpr,
         BLOCK_SIZE: tl.constexpr,
         P2: tl.constexpr,
-):
+    ):
         pid = tl.program_id(axis=0)
         block_start = pid
         p2_arange = tl.arange(0, P2)
@@ -53,7 +59,7 @@ def _quantize_columnwise_and_transpose(
         tl.store(output_ptr + new_offsets, output, mask=p2_arange_mask)
         tl.store(output_maxs + pid, max_val)

-def quantize_columnwise_and_transpose(x: torch.Tensor):
+    def quantize_columnwise_and_transpose(x: torch.Tensor):
         M, N = x.shape
         output = torch.empty(N, M, device=x.device, dtype=torch.int8)
         output_maxs = torch.empty(x.shape[1], device=x.device, dtype=torch.float16)
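The fused kernel quantizes each column of the (M, N) input and writes it out as a row of the (N, M) int8 output, so the quantize and the transpose happen in one pass. A hedged eager-mode equivalent of the arithmetic, assuming round-to-nearest like the other quantizers in this commit:

import torch

def quantize_columnwise_and_transpose_ref(x):
    # Reference sketch: per-column absmax, scale to int8, emit transposed.
    absmax = x.abs().amax(dim=0)                    # one max per column, shape (N,)
    out = torch.round(127. * (x / absmax)).to(torch.int8)
    return out.t().contiguous(), absmax.half()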
bitsandbytes/triton/quantize_global.py

 import math
 import torch
 import time
-import triton
-import triton.language as tl
-from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
-
-# global quantize
-@triton.autotune(
+from bitsandbytes.triton.triton_utils import is_triton_available

+if not is_triton_available():
+    def quantize_global_transpose(input): return None
+    def quantize_global(x: torch.Tensor): return None
+else:
+    import triton
+    import triton.language as tl
+    from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
+
+    # global quantize
+    @triton.autotune(
         configs=[
             triton.Config({'BLOCK_SIZE': 1024,}, num_warps=4),
             triton.Config({'BLOCK_SIZE': 2048,}, num_stages=1),
         ],
         key=['n_elements']
-)
-@triton.jit
-def _quantize_global(
+    )
+    @triton.jit
+    def _quantize_global(
         x_ptr,
         absmax_inv_ptr,
         output_ptr,
         n_elements,
         BLOCK_SIZE: tl.constexpr,
-):
+    ):
         pid = tl.program_id(axis=0)
         block_start = pid * BLOCK_SIZE
         offsets = block_start + tl.arange(0, BLOCK_SIZE)
@@ -31,7 +38,7 @@ def _quantize_global(
         output = tl.libdevice.llrint(127. * (x * absmax_inv))
         tl.store(output_ptr + offsets, output, mask=mask)

-def quantize_global(x: torch.Tensor):
+    def quantize_global(x: torch.Tensor):
         absmax = x.abs().max().unsqueeze(0)
         absmax_inv = 1. / absmax
         output = torch.empty(*x.shape, device='cuda', dtype=torch.int8)
@@ -42,8 +49,8 @@ def quantize_global(x: torch.Tensor):
         return output, absmax

-# global quantize and transpose
-@triton.autotune(
+    # global quantize and transpose
+    @triton.autotune(
         configs=[
             triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4),
             triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4),
@@ -51,9 +58,9 @@ def quantize_global(x: torch.Tensor):
             # ...
         ],
         key=['M', 'N']
-)
-@triton.jit
-def _quantize_global_transpose(A, absmax_inv_ptr, B, stride_am, stride_an, stride_bn, stride_bm, M, N,
+    )
+    @triton.jit
+    def _quantize_global_transpose(A, absmax_inv_ptr, B, stride_am, stride_an, stride_bn, stride_bm, M, N,
                        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, GROUP_M: tl.constexpr):
@@ -84,7 +91,7 @@ def _quantize_global_transpose(A, absmax_inv_ptr, B, stride_am, stride_an, strid
         tl.store(B, output, mask=mask)

-def quantize_global_transpose(input):
+    def quantize_global_transpose(input):
         absmax = input.abs().max().unsqueeze(0)
         absmax_inv = 1. / absmax
         M, N = input.shape
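quantize_global uses a single absmax over the whole tensor: the wrapper computes absmax_inv = 1/absmax, and the kernel rounds 127 · x · absmax_inv to the nearest integer with llrint. A hedged eager-mode equivalent:

import torch

def quantize_global_ref(x):
    # Reference sketch of the per-tensor quantizer: one scale for all values.
    absmax = x.abs().max().unsqueeze(0)
    out = torch.round(127. * (x / absmax)).to(torch.int8)
    return out, absmax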
bitsandbytes/triton/quantize_rowwise.py

 import math
 import torch
 import time
-import triton
-import triton.language as tl
-from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
-
-# rowwise quantize
+from bitsandbytes.triton.triton_utils import is_triton_available

-# TODO: autotune this better.
-@triton.autotune(
+if not is_triton_available():
+    def quantize_rowwise(x: torch.Tensor): return None
+else:
+    import triton
+    import triton.language as tl
+    from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time
+
+    # rowwise quantize
+
+    # TODO: autotune this better.
+    @triton.autotune(
         configs=[
             triton.Config({}, num_stages=1, num_warps=8),
             triton.Config({}, num_stages=2, num_warps=8),
@@ -24,16 +31,16 @@
             triton.Config({}, num_warps=8),
         ],
         key=['n_elements']
-)
-@triton.jit
-def _quantize_rowwise(
+    )
+    @triton.jit
+    def _quantize_rowwise(
         x_ptr,
         output_ptr,
         output_maxs,
         n_elements,
         BLOCK_SIZE: tl.constexpr,
         P2: tl.constexpr,
-):
+    ):
         pid = tl.program_id(axis=0)
         block_start = pid * BLOCK_SIZE
         arange = tl.arange(0, P2)
@@ -47,7 +54,7 @@ def _quantize_rowwise(
         tl.store(output_ptr + offsets, output, mask=row_mask)
         tl.store(output_maxs + pid, max_val)

-def quantize_rowwise(x: torch.Tensor):
+    def quantize_rowwise(x: torch.Tensor):
         output = torch.empty(*x.shape, device=x.device, dtype=torch.int8)
         output_maxs = torch.empty(x.shape[0], device=x.device, dtype=torch.float16)
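As in dequantize_rowwise, the P2 parameter is the row length rounded up to the next power of two, because tl.arange requires a power-of-two extent; the row mask keeps the padding lanes from being loaded or stored. A hedged eager-mode equivalent of the per-row quantization:

import torch

def quantize_rowwise_ref(x):
    # Reference sketch: one absmax per row, values rounded into int8.
    absmax = x.abs().amax(dim=1, keepdim=True)   # per-row maxima, shape (M, 1)
    out = torch.round(127. * (x / absmax)).to(torch.int8)
    return out, absmax.squeeze(1).half()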