OpenDAS / bitsandbytes · Commits · c3d87e44

Commit c3d87e44, authored Apr 12, 2023 by Tim Dettmers

    Added is_available_triton guard.

Parent: 7140c014

Showing 7 changed files with 589 additions and 553 deletions (+589 −553)
bitsandbytes/research/autograd/_functions.py              +2   −2
bitsandbytes/triton/dequantize_rowwise.py                 +56  −50
bitsandbytes/triton/int8_matmul_mixed_dequanitze.py       +148 −143
bitsandbytes/triton/int8_matmul_rowwise_dequantize.py     +161 −156
bitsandbytes/triton/quantize_columnwise_and_transpose.py  +64  −58
bitsandbytes/triton/quantize_global.py                    +94  −87
bitsandbytes/triton/quantize_rowwise.py                   +64  −57
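Each of the six files under bitsandbytes/triton/ is restructured the same way: the triton-dependent imports and kernels move under an `if not is_triton_available(): ... else:` split, and when triton is missing the public entry points become stubs that return None. The guard itself lives in bitsandbytes/triton/triton_utils.py, which this commit does not show; a minimal sketch of such a check, assuming it only needs to probe whether the package is importable:

```python
# Hypothetical sketch of bitsandbytes/triton/triton_utils.py (not shown in
# this diff): report whether the optional triton dependency is importable
# without actually importing it at module-load time.
import importlib.util

def is_triton_available() -> bool:
    # find_spec returns None when the package is not installed
    return importlib.util.find_spec("triton") is not None
```

With a probe like this, importing bitsandbytes no longer hard-fails on machines without triton; callers of the guarded functions simply receive the None stubs.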
bitsandbytes/research/autograd/_functions.py — view file @ c3d87e44

```diff
@@ -184,7 +184,7 @@ class MatMulFP8Global(torch.autograd.Function):
         return grad_A, grad_B, None, None, None, None, None


-class MatMul8bitMixed(torch.autograd.Function):
+class SwitchBackBnb(torch.autograd.Function):
     @staticmethod
     def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState()):
         # default to pytorch behavior if inputs are empty
@@ -408,4 +408,4 @@ def switchback_bnb(
     state = state or MatmulLtState()
     if threshold > 0.0:
         state.threshold = threshold
-    return MatMul8bitMixed.apply(A, B, out, bias, state)
+    return SwitchBackBnb.apply(A, B, out, bias, state)
```
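The rename is visible only through the autograd.Function class name; the `switchback_bnb` entry point keeps its interface. A hypothetical call, assuming the matmul-style signature suggested by the hunk (shapes, dtypes, and the CUDA requirement are illustration-only assumptions):

```python
# Hypothetical usage sketch: switchback_bnb now dispatches to SwitchBackBnb.
# The keyword name `threshold` is inferred from the hunk above, not confirmed.
import torch
from bitsandbytes.research.autograd._functions import switchback_bnb

A = torch.randn(8, 64, device="cuda", dtype=torch.float16, requires_grad=True)
B = torch.randn(64, 32, device="cuda", dtype=torch.float16, requires_grad=True)

out = switchback_bnb(A, B, threshold=6.0)  # -> SwitchBackBnb.apply(A, B, out, bias, state)
out.sum().backward()                       # backward yields grad_A, grad_B, None, ...
```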
bitsandbytes/triton/dequantize_rowwise.py — view file @ c3d87e44

The triton imports and the kernel move under the availability guard; the kernel and wrapper bodies are unchanged, only re-indented. New contents:

```python
import math
import torch
import time
from bitsandbytes.triton.triton_utils import is_triton_available

if not is_triton_available():
    def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor): return None
else:

    import triton
    import triton.language as tl
    from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time

    # rowwise quantize

    # TODO: autotune this better.
    @triton.autotune(
        configs=[
            triton.Config({}, num_stages=1, num_warps=8),
            triton.Config({}, num_stages=2, num_warps=8),
            triton.Config({}, num_stages=4, num_warps=8),
            triton.Config({}, num_stages=8, num_warps=8),
            triton.Config({}, num_stages=1),
            triton.Config({}, num_stages=2),
            triton.Config({}, num_stages=4),
            triton.Config({}, num_stages=8),
            triton.Config({}, num_warps=1),
            triton.Config({}, num_warps=2),
            triton.Config({}, num_warps=4),
            triton.Config({}, num_warps=8),
        ],
        key=['n_elements']
    )
    @triton.jit
    def _dequantize_rowwise(
        x_ptr,
        state_x,
        output_ptr,
        inv_127,
        n_elements,
        BLOCK_SIZE: tl.constexpr,
        P2: tl.constexpr,
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        arange = tl.arange(0, P2)
        offsets = block_start + arange
        row_mask = arange < BLOCK_SIZE
        x = tl.load(x_ptr + offsets, mask=row_mask)
        max_val = tl.load(state_x + pid)
        output = max_val * x * inv_127
        tl.store(output_ptr + offsets, output, mask=row_mask)

    def dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor):
        output = torch.empty(*x.shape, device=x.device, dtype=torch.float16)

        P2 = int(2 ** (math.ceil(math.log2(x.shape[1]))))

        assert x.is_cuda and output.is_cuda
        n_elements = output.numel()
        grid = lambda meta: (x.shape[0],)
        _dequantize_rowwise[grid](x, state_x, output, 1./127, n_elements, BLOCK_SIZE=x.shape[1], P2=P2)
        return output
```
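For reference, the per-row math in `_dequantize_rowwise` is just `max_val * x / 127`. A plain-PyTorch equivalent (a sketch, not part of this commit) shows the same arithmetic without the kernel machinery:

```python
# Plain-PyTorch reference for the rowwise dequantize math (a sketch, not part
# of this commit): each int8 row is scaled back by its stored absolute
# maximum divided by 127.
import torch

def dequantize_rowwise_ref(x: torch.Tensor, state_x: torch.Tensor) -> torch.Tensor:
    # x: (rows, cols) int8; state_x: (rows,) per-row absmax
    return (state_x.unsqueeze(1).float() * x.float() * (1.0 / 127)).to(torch.float16)
```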
bitsandbytes/triton/int8_matmul_mixed_dequanitze.py — view file @ c3d87e44

Same guard pattern: the kernel, its autotune configuration, and the launcher are unchanged apart from being indented under the else branch. New structure (long unchanged bodies elided with `...`):

```python
import torch
from bitsandbytes.triton.triton_utils import is_triton_available

if not is_triton_available():
    def int8_matmul_mixed_dequanitze(a, b, state_x, state_w, bias): return None
else:
    import triton
    import triton.language as tl
    from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time

    # This is a matmul kernel based on triton.ops.matmul
    # It is modified to support rowwise quantized input and global quantized weight
    # It's purpose is fused matmul then dequantize
    # It does support bias.

    def init_to_zero(name):
        return lambda nargs: nargs[name].zero_()

    def get_configs_io_bound():
        configs = []
        for num_stages in [2, 3, 4, 5, 6]:
            for block_m in [16, 32]:
                for block_k in [32, 64]:
                    for block_n in [32, 64, 128, 256]:
                        num_warps = 2 if block_n <= 64 else 4
                        configs.append(
                            triton.Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': 1},
                                          num_stages=num_stages, num_warps=num_warps))
                        # split_k
                        for split_k in [2, 4, 8, 16]:
                            configs.append(triton.Config(
                                {'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},
                                num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))
        return configs

    @triton.autotune(
        configs=[
            # basic configs for compute-bound matmuls
            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 32, 'SPLIT_K': 1}, num_stages=3, num_warps=8),
            # ... 16 more BLOCK_M/BLOCK_N/BLOCK_K combinations, including the
            # "good for int8" set with BLOCK_K in {64, 128} (all unchanged) ...
            triton.Config({'BLOCK_M': 64, 'BLOCK_N': 32, 'BLOCK_K': 64, 'SPLIT_K': 1}, num_stages=5, num_warps=2),
        ] + get_configs_io_bound(),
        key=['M', 'N', 'K'],
        prune_configs_by={'early_config_prune': early_config_prune, 'perf_model': estimate_matmul_time, 'top_k': 10},
    )
    @triton.heuristics({'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0})
    @triton.jit
    def _int8_matmul_mixed_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K,
                                      divfactor: tl.constexpr, has_bias: tl.constexpr,
                                      stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
                                      BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
                                      GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,
                                      ACC_TYPE: tl.constexpr):
        # ... body unchanged: program-ID swizzling for L2 reuse, int32
        # accumulation over K tiles, the fused epilogue
        # acc = w_factor * (x_factor * (acc * divfactor)) with a single global
        # w_factor = tl.load(state_w_ptr), optional bias add, and
        # tl.store (SPLIT_K == 1) or tl.atomic_add (SPLIT_K > 1) write-back ...
        ...

    def int8_matmul_mixed_dequanitze(a, b, state_x, state_w, bias):
        device = a.device
        divfactor = 1. / (127. * 127.)
        has_bias = 0 if bias is None else 1
        # handle non-contiguous inputs if necessary
        if a.stride(0) > 1 and a.stride(1) > 1:
            a = a.contiguous()
        if b.stride(0) > 1 and b.stride(1) > 1:
            b = b.contiguous()
        # checks constraints
        assert a.shape[1] == b.shape[0], "incompatible dimensions"
        M, K = a.shape
        _, N = b.shape
        # allocates output
        c = torch.empty((M, N), device=device, dtype=torch.float16)
        # accumulator types
        ACC_TYPE = tl.float32  #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
        # launch int8_matmul_mixed_dequantize kernel
        grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
        _int8_matmul_mixed_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias,
                                            a.stride(0), a.stride(1),
                                            b.stride(0), b.stride(1),
                                            c.stride(0), c.stride(1),
                                            GROUP_M=8, ACC_TYPE=ACC_TYPE)
        return c
```
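The fused epilogue rescales the int32 accumulator as `w_factor * (x_factor * (acc / (127. * 127.)))` with one global weight factor. A plain-PyTorch sketch of the same math (shapes assumed for illustration; not part of the commit):

```python
# Reference sketch of the mixed dequantize epilogue (not part of this commit):
# int32 accumulation, then one global weight scale, a per-row input scale,
# and 1/(127*127). Integer matmul here is a CPU reference.
import torch

def int8_matmul_mixed_dequantize_ref(a, b, state_x, state_w, bias=None):
    # a: (M, K) int8, b: (K, N) int8, state_x: (M,) rowwise absmax,
    # state_w: scalar global absmax
    acc = a.int() @ b.int()
    out = state_w * (state_x.unsqueeze(1) * (acc.float() / (127.0 * 127.0)))
    if bias is not None:
        out = out + bias
    return out.to(torch.float16)
```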
bitsandbytes/triton/int8_matmul_rowwise_dequantize.py — view file @ c3d87e44

Identical restructuring. The stub and the else branch mirror the mixed variant; the functional difference is that the weight scale is loaded per output column rather than once globally. New structure (unchanged bodies elided with `...`):

```python
import torch
from bitsandbytes.triton.triton_utils import is_triton_available

if not is_triton_available():
    def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias): return None
else:
    import triton
    import triton.language as tl
    from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time

    # This is a matmul kernel based on triton.ops.matmul
    # It is modified to support rowwise quantized input and columnwise quantized weight
    # It's purpose is fused matmul then dequantize
    # It does support bias.

    def init_to_zero(name):
        return lambda nargs: nargs[name].zero_()

    def get_configs_io_bound():
        # ... same io-bound config sweep as in int8_matmul_mixed_dequanitze.py ...
        ...

    @triton.autotune(
        configs=[
            # ... same compute-bound and "good for int8" configs as in the
            # mixed variant ...
        ] + get_configs_io_bound(),
        key=['M', 'N', 'K'],
        prune_configs_by={'early_config_prune': early_config_prune, 'perf_model': estimate_matmul_time, 'top_k': 10},
    )
    @triton.heuristics({'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0})
    @triton.jit
    def _int8_matmul_rowwise_dequantize(A, B, C, bias, state_x_ptr, state_w_ptr, M, N, K,
                                        divfactor, has_bias: tl.constexpr,
                                        stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
                                        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
                                        GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr,
                                        ACC_TYPE: tl.constexpr):
        # ... body unchanged; per-column weight scales:
        #     w_factor = tl.load(state_w_ptr + rbn)[None, :]
        #     x_factor = tl.load(state_x_ptr + ram)[:, None]
        #     acc = (w_factor * (x_factor * (acc * divfactor))) ...
        ...

    def int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias):
        divfactor = 1. / (127. * 127.)
        has_bias = 0 if bias is None else 1
        device = a.device
        # handle non-contiguous inputs if necessary
        if a.stride(0) > 1 and a.stride(1) > 1:
            a = a.contiguous()
        if b.stride(0) > 1 and b.stride(1) > 1:
            b = b.contiguous()
        # checks constraints
        assert a.shape[1] == b.shape[0], "incompatible dimensions"
        M, K = a.shape
        _, N = b.shape
        # allocates output
        c = torch.empty((M, N), device=device, dtype=torch.float16)
        # accumulator types
        ACC_TYPE = tl.float32  #if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32
        # launch int8_matmul_rowwise_dequantize kernel
        grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), META['SPLIT_K'])
        _int8_matmul_rowwise_dequantize[grid](a, b, c, bias, state_x, state_w, M, N, K, divfactor, has_bias,
                                              a.stride(0), a.stride(1),
                                              b.stride(0), b.stride(1),
                                              c.stride(0), c.stride(1),
                                              GROUP_M=8, ACC_TYPE=ACC_TYPE)
        return c
```
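In plain PyTorch the rowwise-by-columnwise epilogue would read as below (a sketch under assumed shapes, not part of the commit):

```python
# Reference sketch (not part of this commit): rowwise input scales combined
# with columnwise weight scales on the int32 accumulator. CPU reference only.
import torch

def int8_matmul_rowwise_dequantize_ref(a, b, state_x, state_w, bias=None):
    # a: (M, K) int8, b: (K, N) int8, state_x: (M,), state_w: (N,)
    acc = a.int() @ b.int()
    out = state_w.unsqueeze(0) * (state_x.unsqueeze(1) * (acc.float() / (127.0 * 127.0)))
    if bias is not None:
        out = out + bias
    return out.to(torch.float16)
```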
bitsandbytes/triton/quantize_columnwise_and_transpose.py — view file @ c3d87e44

Same guard pattern; kernel and wrapper unchanged, re-indented under the else branch. New contents:

```python
import math
import torch
import time
from bitsandbytes.triton.triton_utils import is_triton_available

if not is_triton_available():
    def quantize_columnwise_and_transpose(x: torch.Tensor): return None
else:
    import triton
    import triton.language as tl
    from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time

    # This kernel does fused columnwise quantization and transpose.

    # TODO: autotune this better.
    @triton.autotune(
        configs=[
            triton.Config({}, num_stages=1),
            triton.Config({}, num_stages=2),
            triton.Config({}, num_stages=4),
            triton.Config({}, num_stages=8),
            triton.Config({}, num_stages=16),
            triton.Config({}, num_stages=1, num_warps=8),
            triton.Config({}, num_stages=2, num_warps=8),
            triton.Config({}, num_stages=4, num_warps=8),
            triton.Config({}, num_stages=8, num_warps=8),
            triton.Config({}, num_stages=16, num_warps=8),
            triton.Config({}, num_warps=1),
            triton.Config({}, num_warps=2),
            triton.Config({}, num_warps=4),
            triton.Config({}, num_warps=8),
        ],
        key=['n_elements']
    )
    @triton.jit
    def _quantize_columnwise_and_transpose(
        x_ptr,
        output_ptr,
        output_maxs,
        n_elements,
        M: tl.constexpr, N: tl.constexpr,
        BLOCK_SIZE: tl.constexpr,
        P2: tl.constexpr,
    ):
        pid = tl.program_id(axis=0)
        block_start = pid
        p2_arange = tl.arange(0, P2)
        p2_arange_mask = p2_arange < M
        arange = p2_arange * N
        offsets = block_start + arange
        x = tl.load(x_ptr + offsets, mask=p2_arange_mask)
        abs_x = tl.abs(x)
        max_val = tl.max(tl.where(p2_arange_mask, abs_x, 0), axis=0)
        output = tl.libdevice.llrint(127. * (x / max_val))

        new_start = pid * M
        new_offsets = new_start + p2_arange
        tl.store(output_ptr + new_offsets, output, mask=p2_arange_mask)
        tl.store(output_maxs + pid, max_val)

    def quantize_columnwise_and_transpose(x: torch.Tensor):
        M, N = x.shape
        output = torch.empty(N, M, device=x.device, dtype=torch.int8)
        output_maxs = torch.empty(x.shape[1], device=x.device, dtype=torch.float16)

        P2 = int(2 ** (math.ceil(math.log2(M))))

        assert x.is_cuda and output.is_cuda
        n_elements = output.numel()
        grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
        _quantize_columnwise_and_transpose[grid](x, output, output_maxs, n_elements, M, N, BLOCK_SIZE=M, P2=P2)
        return output, output_maxs
```
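As a cross-check on the kernel's semantics, the fused operation is equivalent to a per-column absmax quantize followed by a transpose. A plain-PyTorch sketch (not part of the commit; torch.round's half-to-even behavior may differ from llrint at exact midpoints):

```python
# Plain-PyTorch reference for fused columnwise quantize + transpose (a sketch,
# not part of this commit): per-column absmax, scale into int8, emit (N, M).
import torch

def quantize_columnwise_and_transpose_ref(x: torch.Tensor):
    # x: (M, N) -> (N, M) int8 plus (N,) per-column absmax
    max_val = x.abs().amax(dim=0)
    output = torch.round(127.0 * (x / max_val)).to(torch.int8).t().contiguous()
    return output, max_val.to(torch.float16)
```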
bitsandbytes/triton/quantize_global.py — view file @ c3d87e44

Same guard pattern, with two stubs since this module exposes two entry points. New contents:

```python
import math
import torch
import time
from bitsandbytes.triton.triton_utils import is_triton_available

if not is_triton_available():
    def quantize_global_transpose(input): return None
    def quantize_global(x: torch.Tensor): return None
else:
    import triton
    import triton.language as tl
    from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time

    # global quantize
    @triton.autotune(
        configs=[
            triton.Config({'BLOCK_SIZE': 1024,}, num_warps=4),
            triton.Config({'BLOCK_SIZE': 2048,}, num_stages=1),
        ],
        key=['n_elements']
    )
    @triton.jit
    def _quantize_global(x_ptr, absmax_inv_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        offsets = block_start + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask)
        absmax_inv = tl.load(absmax_inv_ptr)
        output = tl.libdevice.llrint(127. * (x * absmax_inv))
        tl.store(output_ptr + offsets, output, mask=mask)

    def quantize_global(x: torch.Tensor):
        absmax = x.abs().max().unsqueeze(0)
        absmax_inv = 1. / absmax
        output = torch.empty(*x.shape, device='cuda', dtype=torch.int8)
        assert x.is_cuda and output.is_cuda
        n_elements = output.numel()
        grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
        _quantize_global[grid](x, absmax_inv, output, n_elements)
        return output, absmax

    # global quantize and transpose
    @triton.autotune(
        configs=[
            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4),
            triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'GROUP_M': 8}, num_warps=4),
            # ...
        ],
        key=['M', 'N']
    )
    @triton.jit
    def _quantize_global_transpose(A, absmax_inv_ptr, B, stride_am, stride_an, stride_bn, stride_bm, M, N,
                                   BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, GROUP_M: tl.constexpr):
        pid = tl.program_id(0)
        grid_m = (M + BLOCK_M - 1) // BLOCK_M
        grid_n = (N + BLOCK_N - 1) // BLOCK_N
        width = GROUP_M * grid_n
        group_id = pid // width
        group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
        pid_m = group_id * GROUP_M + (pid % group_size)
        pid_n = (pid % width) // group_size

        rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        A = A + (rm[:, None] * stride_am + rn[None, :] * stride_an)
        mask = (rm < M)[:, None] & (rn < N)[None, :]
        a = tl.load(A, mask=mask)
        absmax_inv = tl.load(absmax_inv_ptr)

        # rematerialize to save registers
        rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        B = B + (rm[:, None] * stride_bm + rn[None, :] * stride_bn)
        mask = (rm < M)[:, None] & (rn < N)[None, :]

        output = tl.libdevice.llrint(127. * (a * absmax_inv))

        tl.store(B, output, mask=mask)

    def quantize_global_transpose(input):
        absmax = input.abs().max().unsqueeze(0)
        absmax_inv = 1. / absmax
        M, N = input.shape
        out = torch.empty(N, M, device='cuda', dtype=torch.int8)

        assert out.size(0) == N and out.size(1) == M
        assert input.stride(0) == 1 or input.stride(1) == 1
        assert out.stride(0) == 1 or out.stride(1) == 1

        grid = lambda META: (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),)
        _quantize_global_transpose[grid](input, absmax_inv, out,
                                         input.stride(0), input.stride(1),
                                         out.stride(0), out.stride(1), M, N)
        return out, absmax
```
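The global path uses one absmax for the whole tensor, computed on the host before the launch. A plain-PyTorch sketch of the same quantize (not part of the commit):

```python
# Plain-PyTorch reference for the global quantize path (a sketch, not part of
# this commit): a single absmax scales every element into int8.
import torch

def quantize_global_ref(x: torch.Tensor):
    absmax = x.abs().max().unsqueeze(0)
    output = torch.round(127.0 * (x / absmax)).to(torch.int8)
    return output, absmax
```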
bitsandbytes/triton/quantize_rowwise.py — view file @ c3d87e44

Same guard pattern; kernel and wrapper unchanged, re-indented. New contents:

```python
import math
import torch
import time
from bitsandbytes.triton.triton_utils import is_triton_available

if not is_triton_available():
    def quantize_rowwise(x: torch.Tensor): return None
else:
    import triton
    import triton.language as tl
    from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time

    # rowwise quantize

    # TODO: autotune this better.
    @triton.autotune(
        configs=[
            triton.Config({}, num_stages=1, num_warps=8),
            triton.Config({}, num_stages=2, num_warps=8),
            triton.Config({}, num_stages=4, num_warps=8),
            triton.Config({}, num_stages=8, num_warps=8),
            triton.Config({}, num_stages=1),
            triton.Config({}, num_stages=2),
            triton.Config({}, num_stages=4),
            triton.Config({}, num_stages=8),
            triton.Config({}, num_warps=1),
            triton.Config({}, num_warps=2),
            triton.Config({}, num_warps=4),
            triton.Config({}, num_warps=8),
        ],
        key=['n_elements']
    )
    @triton.jit
    def _quantize_rowwise(
        x_ptr,
        output_ptr,
        output_maxs,
        n_elements,
        BLOCK_SIZE: tl.constexpr,
        P2: tl.constexpr,
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        arange = tl.arange(0, P2)
        offsets = block_start + arange
        row_mask = arange < BLOCK_SIZE
        x = tl.load(x_ptr + offsets, mask=row_mask)

        abs_x = tl.abs(x)
        max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0)
        output = tl.libdevice.llrint(127. * (x / max_val))
        tl.store(output_ptr + offsets, output, mask=row_mask)
        tl.store(output_maxs + pid, max_val)

    def quantize_rowwise(x: torch.Tensor):
        output = torch.empty(*x.shape, device=x.device, dtype=torch.int8)
        output_maxs = torch.empty(x.shape[0], device=x.device, dtype=torch.float16)

        P2 = int(2 ** (math.ceil(math.log2(x.shape[1]))))

        assert x.is_cuda and output.is_cuda
        n_elements = output.numel()
        grid = lambda meta: (x.shape[0],)
        _quantize_rowwise[grid](x, output, output_maxs, n_elements, BLOCK_SIZE=x.shape[1], P2=P2)
        return output, output_maxs
```
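Together with dequantize_rowwise above, the rowwise pair round-trips a tensor through int8 with per-element error bounded by the row's absmax over 127. A hypothetical smoke test, assuming a CUDA device and an installed triton so the else branches (rather than the None stubs) are active:

```python
# Hypothetical round-trip smoke test (assumes CUDA + triton are available).
import torch
from bitsandbytes.triton.quantize_rowwise import quantize_rowwise
from bitsandbytes.triton.dequantize_rowwise import dequantize_rowwise

x = torch.randn(4, 128, device="cuda", dtype=torch.float16)
q, maxs = quantize_rowwise(x)        # int8 tensor + per-row absmax
x_hat = dequantize_rowwise(q, maxs)  # back to float16
# generous bound: one quantization step per element
assert (x - x_hat).abs().max() <= maxs.float().max() / 127
```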