OpenDAS / ColossalAI · Commits

Commit c3bef204 (unverified)
Authored Sep 28, 2023 by Xu Kai; committed via GitHub on Sep 28, 2023
Parent: 822051d8

add autotune (#4822)

2 changed files, 182 additions and 5 deletions:
- colossalai/kernel/triton/custom_autotune.py: +176 / -0
- colossalai/kernel/triton/gptq_triton.py: +6 / -5
colossalai/kernel/triton/custom_autotune.py (new file, mode 100644)
```python
# code from AutoGPTQ auto_gptq: https://github.com/PanQiWei/AutoGPTQ/blob/main/auto_gptq/nn_modules/triton_utils/custom_autotune.py

import builtins
import math
import time
from typing import Dict

import triton


class CustomizedTritonAutoTuner(triton.KernelInterface):
    def __init__(
        self,
        fn,
        arg_names,
        configs,
        key,
        reset_to_zero,
        prune_configs_by: Dict = None,
        nearest_power_of_two: bool = False,
    ):
        if not configs:
            self.configs = [triton.Config({}, num_warps=4, num_stages=2)]
        else:
            self.configs = configs
        self.key_idx = [arg_names.index(k) for k in key]
        self.nearest_power_of_two = nearest_power_of_two
        self.cache = {}
        # hook to reset all required tensors to zeros before relaunching a kernel
        self.hook = lambda args: 0
        if reset_to_zero is not None:
            self.reset_idx = [arg_names.index(k) for k in reset_to_zero]

            def _hook(args):
                for i in self.reset_idx:
                    args[i].zero_()

            self.hook = _hook
        self.arg_names = arg_names
        # prune configs
        if prune_configs_by:
            perf_model, top_k = prune_configs_by["perf_model"], prune_configs_by["top_k"]
            # default to None so early_config_prune is always bound below
            early_config_prune = prune_configs_by.get("early_config_prune")
        else:
            perf_model, top_k, early_config_prune = None, None, None
        self.perf_model, self.configs_top_k = perf_model, top_k
        self.early_config_prune = early_config_prune
        self.fn = fn

    def _bench(self, *args, config, **meta):
        # check for conflicts, i.e. meta-parameters both provided
        # as kwargs and by the autotuner
        conflicts = meta.keys() & config.kwargs.keys()
        if conflicts:
            raise ValueError(
                f"Conflicting meta-parameters: {', '.join(conflicts)}."
                " Make sure that you don't re-define auto-tuned symbols."
            )
        # augment meta-parameters with tunable ones
        current = dict(meta, **config.kwargs)

        def kernel_call():
            if config.pre_hook:
                config.pre_hook(self.nargs)
            self.hook(args)
            self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)

        try:
            # In testing, using only 40 reps seems to be close enough, and it appears to be what PyTorch uses
            # PyTorch also sets fast_flush to True, but I didn't see any speedup so I'll leave the default
            return triton.testing.do_bench(kernel_call, percentiles=(0.5, 0.2, 0.8), rep=40)
        except triton.compiler.OutOfResources:
            return (float("inf"), float("inf"), float("inf"))

    def run(self, *args, **kwargs):
        self.nargs = dict(zip(self.arg_names, args))
        if len(self.configs) > 1:
            key = tuple(args[i] for i in self.key_idx)

            # This reduces the amount of autotuning by rounding the keys to the nearest power of two
            # In my testing this gives decent results, and greatly reduces the amount of tuning required
            if self.nearest_power_of_two:
                key = tuple([2 ** int(math.log2(x) + 0.5) for x in key])

            if key not in self.cache:
                # prune configs
                pruned_configs = self.prune_configs(kwargs)
                bench_start = time.time()
                timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
                bench_end = time.time()
                self.bench_time = bench_end - bench_start
                self.cache[key] = builtins.min(timings, key=timings.get)
                self.hook(args)
                self.configs_timings = timings
            config = self.cache[key]
        else:
            config = self.configs[0]
        self.best_config = config
        if config.pre_hook is not None:
            config.pre_hook(self.nargs)
        return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs)

    def prune_configs(self, kwargs):
        pruned_configs = self.configs
        if self.early_config_prune:
            pruned_configs = self.early_config_prune(self.configs, self.nargs)
        if self.perf_model:
            top_k = self.configs_top_k
            if isinstance(top_k, float) and top_k <= 1.0:
                top_k = int(len(self.configs) * top_k)
            if len(pruned_configs) > top_k:
                est_timing = {
                    config: self.perf_model(
                        **self.nargs,
                        **kwargs,
                        **config.kwargs,
                        num_stages=config.num_stages,
                        num_warps=config.num_warps,
                    )
                    for config in pruned_configs
                }
                pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
        return pruned_configs

    def warmup(self, *args, **kwargs):
        self.nargs = dict(zip(self.arg_names, args))
        for config in self.prune_configs(kwargs):
            self.fn.warmup(
                *args,
                num_warps=config.num_warps,
                num_stages=config.num_stages,
                **kwargs,
                **config.kwargs,
            )
        self.nargs = None


def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, nearest_power_of_two=False):
    def decorator(fn):
        return CustomizedTritonAutoTuner(
            fn, fn.arg_names, configs, key, reset_to_zero, prune_configs_by, nearest_power_of_two
        )

    return decorator


def matmul248_kernel_config_pruner(configs, nargs):
    """
    The main purpose of this function is to shrink BLOCK_SIZE_* when the corresponding dimension is smaller.
    """
    m = max(2 ** int(math.ceil(math.log2(nargs["M"]))), 16)
    n = max(2 ** int(math.ceil(math.log2(nargs["N"]))), 16)
    k = max(2 ** int(math.ceil(math.log2(nargs["K"]))), 16)

    used = set()
    for config in configs:
        block_size_m = min(m, config.kwargs["BLOCK_SIZE_M"])
        block_size_n = min(n, config.kwargs["BLOCK_SIZE_N"])
        block_size_k = min(k, config.kwargs["BLOCK_SIZE_K"])
        group_size_m = config.kwargs["GROUP_SIZE_M"]
        if (block_size_m, block_size_n, block_size_k, group_size_m, config.num_stages, config.num_warps) in used:
            continue

        used.add((block_size_m, block_size_n, block_size_k, group_size_m, config.num_stages, config.num_warps))
        yield triton.Config(
            {
                "BLOCK_SIZE_M": block_size_m,
                "BLOCK_SIZE_N": block_size_n,
                "BLOCK_SIZE_K": block_size_k,
                "GROUP_SIZE_M": group_size_m,
            },
            num_stages=config.num_stages,
            num_warps=config.num_warps,
        )
```
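For orientation, here is a minimal usage sketch (not part of the commit; `my_kernel` and its arguments are hypothetical placeholders). The decorator arguments mirror the ones gptq_triton.py adopts in the diff below. Note that with `nearest_power_of_two=True`, a shape key such as M=1500 is rounded by `2 ** int(math.log2(1500) + 0.5)` to 2048, so nearby shapes share one cached tuning result instead of each triggering a fresh benchmark sweep.

```python
import triton
import triton.language as tl

from colossalai.kernel.triton.custom_autotune import autotune, matmul248_kernel_config_pruner


@autotune(
    configs=[
        # candidate tile shapes; the pruner below shrinks them for small problems
        triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8},
                      num_stages=4, num_warps=4),
        triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8},
                      num_stages=2, num_warps=8),
    ],
    key=["M", "N", "K"],            # tuning results are cached per (M, N, K)
    nearest_power_of_two=True,      # round the key to cut down on re-tuning
    prune_configs_by={
        "early_config_prune": matmul248_kernel_config_pruner,
        "perf_model": None,
        "top_k": None,
    },
)
@triton.jit
def my_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
              BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
              BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
    # hypothetical kernel body, elided
    pass
```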
colossalai/kernel/triton/gptq_triton.py
```diff
@@ -3,7 +3,8 @@
 import torch
 import triton
 import triton.language as tl
-from auto_gptq.nn_modules.triton_utils import custom_autotune
+from .custom_autotune import autotune, matmul248_kernel_config_pruner


 @triton.jit
@@ -94,7 +95,7 @@ def silu(x):
     return x * tl.sigmoid(x)


-@custom_autotune.autotune(
+@autotune(
     configs=[
         triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
@@ -124,7 +125,7 @@ def silu(x):
     key=["M", "N", "K"],
     nearest_power_of_two=True,
     prune_configs_by={
-        "early_config_prune": custom_autotune.matmul248_kernel_config_pruner,
+        "early_config_prune": matmul248_kernel_config_pruner,
         "perf_model": None,
         "top_k": None,
     },
@@ -266,7 +267,7 @@ def cai_gptq_matmul_248_kernel(
     tl.store(c_ptrs, accumulator, mask=c_mask)


-@custom_autotune.autotune(
+@autotune(
     configs=[
         triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8}, num_stages=4, num_warps=4
@@ -296,7 +297,7 @@ def cai_gptq_matmul_248_kernel(
     key=["M", "N", "K"],
     nearest_power_of_two=True,
     prune_configs_by={
-        "early_config_prune": custom_autotune.matmul248_kernel_config_pruner,
+        "early_config_prune": matmul248_kernel_config_pruner,
         "perf_model": None,
         "top_k": None,
     },
```
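To make the pruner's effect concrete, here is a small hypothetical check (not part of the commit): it feeds one candidate config and a short M dimension through `matmul248_kernel_config_pruner`, which clamps `BLOCK_SIZE_M` to the power-of-two padding of M while leaving the larger dimensions untouched.

```python
import triton

from colossalai.kernel.triton.custom_autotune import matmul248_kernel_config_pruner

# One candidate config; values taken from the decorator in the diff above.
configs = [
    triton.Config(
        {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8},
        num_stages=4,
        num_warps=4,
    )
]

# M=24 rounds up to 32 (floor of 16), so BLOCK_SIZE_M is clamped from 64 to 32;
# N and K are large enough that their block sizes pass through unchanged.
for cfg in matmul248_kernel_config_pruner(configs, {"M": 24, "N": 4096, "K": 4096}):
    print(cfg.kwargs)
    # {'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}
```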