Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
81e68c64
"src/include/blockwise_gemm.hip.hpp" did not exist on "e7b8705b913c1bb7d216255f1f233ea03c096f1e"
Unverified
Commit
81e68c64
authored
Jan 22, 2025
by
Max Podkorytov
Browse files
copy over fmha example
parent
c5fff071
Changes
26
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
485 additions
and
0 deletions
+485
-0
example/ck_tile/18_flexattn/script/benchmark_fwd.sh
example/ck_tile/18_flexattn/script/benchmark_fwd.sh
+31
-0
example/ck_tile/18_flexattn/script/run_full_test.sh
example/ck_tile/18_flexattn/script/run_full_test.sh
+46
-0
example/ck_tile/18_flexattn/script/smoke_test_bwd.sh
example/ck_tile/18_flexattn/script/smoke_test_bwd.sh
+35
-0
example/ck_tile/18_flexattn/script/smoke_test_fwd.sh
example/ck_tile/18_flexattn/script/smoke_test_fwd.sh
+106
-0
example/ck_tile/18_flexattn/utils.hpp
example/ck_tile/18_flexattn/utils.hpp
+266
-0
example/ck_tile/CMakeLists.txt
example/ck_tile/CMakeLists.txt
+1
-0
No files found.
example/ck_tile/18_flexattn/script/benchmark_fwd.sh
0 → 100755
View file @
81e68c64
#!/bin/sh
# TODO: run this script from CK root or build directory
EXE
=
"
$(
find
.
-name
tile_example_fmha_fwd
-type
f |
head
-n
1
)
"
VALID
=
0
for
prec
in
"fp16"
"bf16"
;
do
for
perm
in
0 1
;
do
for
hdim
in
64 128 256
;
do
nhead
=
$((
2048
/
$hdim
))
# follow fav2 setup
$EXE
-prec
=
$prec
-b
=
32
-h
=
$nhead
-d
=
$hdim
-s
=
512
-iperm
=
$perm
-operm
=
$perm
-kname
=
1
-v
=
$VALID
;
sleep
3
$EXE
-prec
=
$prec
-b
=
16
-h
=
$nhead
-d
=
$hdim
-s
=
1024
-iperm
=
$perm
-operm
=
$perm
-kname
=
1
-v
=
$VALID
;
sleep
3
$EXE
-prec
=
$prec
-b
=
8
-h
=
$nhead
-d
=
$hdim
-s
=
2048
-iperm
=
$perm
-operm
=
$perm
-kname
=
1
-v
=
$VALID
;
sleep
3
$EXE
-prec
=
$prec
-b
=
4
-h
=
$nhead
-d
=
$hdim
-s
=
4096
-iperm
=
$perm
-operm
=
$perm
-kname
=
1
-v
=
$VALID
;
sleep
3
$EXE
-prec
=
$prec
-b
=
2
-h
=
$nhead
-d
=
$hdim
-s
=
8192
-iperm
=
$perm
-operm
=
$perm
-kname
=
1
-v
=
$VALID
;
sleep
3
$EXE
-prec
=
$prec
-b
=
1
-h
=
$nhead
-d
=
$hdim
-s
=
16384
-iperm
=
$perm
-operm
=
$perm
-kname
=
1
-v
=
$VALID
;
sleep
3
done
done
done
for
perm
in
0 1
;
do
$EXE
-prec
=
fp8
-squant
=
1
-b
=
32
-h
=
16
-d
=
128
-s
=
512
-iperm
=
$perm
-operm
=
$perm
-vlayout
=
c
-range_q
=
240
-range_k
=
240
-range_v
=
240
-range_p
=
240
-range_o
=
240
-kname
=
1
-v
=
$VALID
;
sleep
3
$EXE
-prec
=
fp8
-squant
=
1
-b
=
16
-h
=
16
-d
=
128
-s
=
1024
-iperm
=
$perm
-operm
=
$perm
-vlayout
=
c
-range_q
=
240
-range_k
=
240
-range_v
=
240
-range_p
=
240
-range_o
=
240
-kname
=
1
-v
=
$VALID
;
sleep
3
$EXE
-prec
=
fp8
-squant
=
1
-b
=
8
-h
=
16
-d
=
128
-s
=
2048
-iperm
=
$perm
-operm
=
$perm
-vlayout
=
c
-range_q
=
240
-range_k
=
240
-range_v
=
240
-range_p
=
240
-range_o
=
240
-kname
=
1
-v
=
$VALID
;
sleep
3
$EXE
-prec
=
fp8
-squant
=
1
-b
=
4
-h
=
16
-d
=
128
-s
=
4096
-iperm
=
$perm
-operm
=
$perm
-vlayout
=
c
-range_q
=
240
-range_k
=
240
-range_v
=
240
-range_p
=
240
-range_o
=
240
-kname
=
1
-v
=
$VALID
;
sleep
3
$EXE
-prec
=
fp8
-squant
=
1
-b
=
2
-h
=
16
-d
=
128
-s
=
8192
-iperm
=
$perm
-operm
=
$perm
-vlayout
=
c
-range_q
=
240
-range_k
=
240
-range_v
=
240
-range_p
=
240
-range_o
=
240
-kname
=
1
-v
=
$VALID
;
sleep
3
$EXE
-prec
=
fp8
-squant
=
1
-b
=
1
-h
=
16
-d
=
128
-s
=
16384
-iperm
=
$perm
-operm
=
$perm
-vlayout
=
c
-range_q
=
240
-range_k
=
240
-range_v
=
240
-range_p
=
240
-range_o
=
240
-kname
=
1
-v
=
$VALID
;
sleep
3
done
\ No newline at end of file
example/ck_tile/18_flexattn/script/run_full_test.sh
0 → 100755
View file @
81e68c64
#!/bin/bash
#
# in order to run this script you'd first need to build the tile_example_fmha_fwd and tile_eaxmple_fmha_bwd executables in ../build/bin/
#
# run the script as "./run_full_test.sh <tag for your test environment> <branch name> <host name> <gpu_arch>
# input arguments:
# environment tag : a string describing the specifics of your test environment
# branch name : name of the branch in git repo (git status | grep -e 'On branch')
# host name : $hostname
# gpu architecture: e.g., gfx90a, or gfx942, etc.
#get the command line arguments:
export
env_type
=
$1
echo
'Environment type: '
$env_type
export
branch
=
$2
echo
'Branch name: '
$branch
export
host_name
=
$3
echo
'Host name: '
$host_name
export
GPU_arch
=
$4
echo
'GPU_arch: '
$GPU_arch
function
print_log_header
(){
rm
-f
$1
;
echo
'On branch '
$3
&>
$1
;
echo
'Node name: '
$4
>>
$1
;
#get GPU_arch and number of compute units from rocminfo
echo
-n
"GPU_arch: "
>>
$1
;
rocminfo |
grep
"Name:"
|
grep
"gfx"
>>
$1
;
rocminfo |
grep
"Compute Unit:"
>>
$1
;
hipcc
--version
|
grep
-e
'HIP version'
>>
$1
;
echo
'Environment type: '
$2
>>
$1
;
/opt/rocm/bin/amdclang++
--version
|
grep
-e
'InstalledDir'
>>
$1
;
}
#run verification tests
example/ck_tile/01_fmha/script/smoke_test_fwd.sh
example/ck_tile/01_fmha/script/smoke_test_bwd.sh
#run performance benchmarks
export
fmha_fwd_log
=
"perf_fmha_fwd_
$GPU_arch
.log"
print_log_header
$fmha_fwd_log
$env_type
$branch
$host_name
example/ck_tile/01_fmha/script/benchmark_fwd.sh 2>&1 |
tee
-a
$fmha_fwd_log
export
fmha_bwd_log
=
"perf_fmha_bwd_
$GPU_arch
.log"
print_log_header
$fmha_bwd_log
$env_type
$branch
$host_name
example/ck_tile/01_fmha/script/benchmark_bwd.sh 2>&1 |
tee
-a
$fmha_bwd_log
example/ck_tile/18_flexattn/script/smoke_test_bwd.sh
0 → 100755
View file @
81e68c64
#!/bin/sh
# TODO: run this script from CK root or build directory
EXE
=
"
$(
find
.
-name
tile_example_fmha_bwd
-type
f |
head
-n
1
)
"
KNAME
=
1
export
CK_WARMUP
=
0
export
CK_REPEAT
=
1
COMMON_ARGS
=
'-v=1'
set
-x
for
prec
in
"fp16"
"bf16"
;
do
for
perm
in
0 1
;
do
for
hdim
in
32 64 128 256
;
do
for
mode
in
0 1
;
do
for
bias
in
"n"
"a"
;
do
for
dbias
in
0
;
do
for
p_drop
in
0.0 0.2
;
do
for
deterministic
in
0
;
do
$EXE
-prec
=
$prec
-b
=
1
-h
=
4
-h_k
=
2
-d
=
$hdim
-s
=
259
-bias
=
$bias
-dbias
=
$dbias
-p_drop
=
$p_drop
-iperm
=
$perm
-operm
=
$perm
-deterministic
=
$deterministic
-v
=
1
-mode
=
$mode
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-b
=
2
-h
=
2
-d
=
$hdim
-s
=
516
-s_k
=
253
-bias
=
$bias
-dbias
=
$dbias
-p_drop
=
$p_drop
-iperm
=
$perm
-operm
=
$perm
-deterministic
=
$deterministic
-v
=
1
-mode
=
$mode
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-b
=
1
-h
=
4
-h_k
=
1
-d
=
$hdim
-s
=
500
-s_k
=
251
-bias
=
$bias
-dbias
=
$dbias
-p_drop
=
$p_drop
-iperm
=
$perm
-operm
=
$perm
-mask
=
1
-deterministic
=
$deterministic
-v
=
1
-mode
=
$mode
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-b
=
1
-h
=
2
-d
=
$hdim
-s
=
900
-s_k
=
258
-bias
=
$bias
-dbias
=
$dbias
-p_drop
=
$p_drop
-iperm
=
$perm
-operm
=
$perm
-mask
=
2
-v
=
1
-deterministic
=
$deterministic
-mode
=
$mode
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-b
=
2
-h
=
1
-d
=
$hdim
-s
=
987
-s_k
=
219
-bias
=
$bias
-dbias
=
$dbias
-p_drop
=
$p_drop
-iperm
=
$perm
-operm
=
$perm
-mask
=
t:128,30
-deterministic
=
$deterministic
-v
=
1
-mode
=
$mode
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-b
=
2
-h
=
3
-h_k
=
1
-d
=
$hdim
-s
=
244
-s_k
=
499
-bias
=
$bias
-dbias
=
$dbias
-p_drop
=
$p_drop
-iperm
=
$perm
-operm
=
$perm
-mask
=
b:4,35
-deterministic
=
$deterministic
-v
=
1
-mode
=
$mode
-kname
=
$KNAME
$COMMON_ARGS
done
done
done
done
done
done
done
done
set
+x
example/ck_tile/18_flexattn/script/smoke_test_fwd.sh
0 → 100755
View file @
81e68c64
#!/bin/bash
# TODO: run this script from CK root or build directory
EXE
=
"
$(
find
.
-name
tile_example_fmha_fwd
-type
f |
head
-n
1
)
"
KNAME
=
1
export
CK_WARMUP
=
0
export
CK_REPEAT
=
1
COMMON_ARGS
=
'-v=1 -warmup=0 -repeat=1'
# mode=0
# export HIP_VISIBLE_DEVICES=4
TEST_SPLITKV
=
0
TEST_APPENDKV
=
0
# options:
# -s: run splitkv tests
# -a: run appendkv tests
while
getopts
":sa"
opt
;
do
case
"
${
opt
}
"
in
s
)
TEST_SPLITKV
=
1
;;
a
)
TEST_APPENDKV
=
1
;;
*
)
;;
esac
done
run_fp16_bf16_tests
()
{
local
NUM_SPLITS
=
"1"
local
PAGE_BLOCK_SIZE
=
"0"
local
CACHE_BATCH_IDX
=
"0"
if
[
$TEST_SPLITKV
-eq
1
]
;
then
NUM_SPLITS
=
"
$NUM_SPLITS
2 3"
PAGE_BLOCK_SIZE
=
"
$PAGE_BLOCK_SIZE
128"
CACHE_BATCH_IDX
=
"
$CACHE_BATCH_IDX
1"
fi
for
prec
in
"fp16"
"bf16"
;
do
for
mode
in
1 0
;
do
for
perm
in
0 1
;
do
for
vlayout
in
"r"
"c"
;
do
for
hdim
in
32 64 128 256
;
do
for
lse
in
0 1
;
do
for
bias
in
"n"
"e"
"a"
;
do
for
p_drop
in
0.0 0.2
;
do
for
num_splits
in
$NUM_SPLITS
;
do
for
page_block_size
in
$PAGE_BLOCK_SIZE
;
do
for
cache_batch_idx
in
$CACHE_BATCH_IDX
;
do
# $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
2
-h
=
2
-h_k
=
1
-d
=
16,
-d_v
=
$hdim
-s
=
55
-s_k
=
256
-bias
=
$bias
-p_drop
=
$p_drop
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-vlayout
=
$vlayout
-num_splits
=
$num_splits
-page_block_size
=
$page_block_size
-cache_batch_idx
=
$cache_batch_idx
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
1
-h
=
3
-d
=
$hdim
-s
=
100
-s_k
=
51
-bias
=
$bias
-p_drop
=
$p_drop
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-vlayout
=
$vlayout
-num_splits
=
$num_splits
-page_block_size
=
$page_block_size
-cache_batch_idx
=
$cache_batch_idx
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
2
-h
=
1
-d
=
16
-d_v
=
$hdim
-s
=
99
-s_k
=
256
-bias
=
$bias
-p_drop
=
$p_drop
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-mask
=
1
-vlayout
=
$vlayout
-num_splits
=
$num_splits
-page_block_size
=
$page_block_size
-cache_batch_idx
=
$cache_batch_idx
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
1
-h
=
2
-h_k
=
1
-d
=
$hdim
-s
=
1024
-s_k
=
256
-bias
=
$bias
-p_drop
=
$p_drop
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-mask
=
2
-vlayout
=
$vlayout
-num_splits
=
$num_splits
-page_block_size
=
$page_block_size
-cache_batch_idx
=
$cache_batch_idx
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
2
-h
=
1
-d
=
$hdim
-d_v
=
24
-s
=
3
-s_k
=
99
-bias
=
$bias
-p_drop
=
$p_drop
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-mask
=
2
-vlayout
=
$vlayout
-num_splits
=
$num_splits
-page_block_size
=
$page_block_size
-cache_batch_idx
=
$cache_batch_idx
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
3
-h
=
2
-h_k
=
1
-d
=
$hdim
-s
=
200
-s_k
=
520
-bias
=
$bias
-p_drop
=
$p_drop
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-mask
=
t:128,30
-vlayout
=
$vlayout
-num_splits
=
$num_splits
-page_block_size
=
$page_block_size
-cache_batch_idx
=
$cache_batch_idx
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
2
-h
=
1
-d
=
$hdim
-s
=
99
-s_k
=
32
-bias
=
$bias
-p_drop
=
$p_drop
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-mask
=
b:4,35
-vlayout
=
$vlayout
-num_splits
=
$num_splits
-page_block_size
=
$page_block_size
-cache_batch_idx
=
$cache_batch_idx
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
1
-h
=
2
-h_k
=
1
-d
=
$hdim
-s
=
33
-s_k
=
0
-bias
=
$bias
-p_drop
=
$p_drop
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-mask
=
2
-vlayout
=
$vlayout
-num_splits
=
$num_splits
-page_block_size
=
$page_block_size
-cache_batch_idx
=
$cache_batch_idx
-kname
=
$KNAME
$COMMON_ARGS
$EXE
-prec
=
$prec
-mode
=
$mode
-b
=
1
-h
=
2
-h_k
=
1
-d
=
$hdim
-s
=
1
-s_k
=
10
-s_kpad
=
32
-bias
=
$bias
-p_drop
=
$p_drop
-lse
=
$lse
-iperm
=
$perm
-operm
=
$perm
-mask
=
2
-vlayout
=
$vlayout
-num_splits
=
$num_splits
-page_block_size
=
$page_block_size
-cache_batch_idx
=
$cache_batch_idx
-kname
=
$KNAME
$COMMON_ARGS
done
;
done
;
done
;
done
;
done
done
;
done
;
done
;
done
;
done
done
;
}
run_fp8_tests
()
{
for
perm
in
0 1
;
do
for
bias
in
"n"
"e"
"a"
;
do
for
b
in
1 2
;
do
for
hdim
in
64 128 256
;
do
$EXE
-prec
=
fp8
-init
=
3
-b
=
$b
-h
=
1
-d
=
128
-s
=
128
-bias
=
$bias
-iperm
=
$perm
-operm
=
$perm
-vlayout
=
c
-squant
=
1
-kname
=
$KNAME
$COMMON_ARGS
done
;
done
;
done
;
done
}
run_fp16_appendkv_tests
()
{
for
s
in
$(
seq
63 1 65
)
;
do
for
s_k
in
65 129
;
do
for
s_knew
in
0 64
$s_k
;
do
for
hdim
in
32 64 128 256
;
do
for
ri
in
0 1
;
do
for
rdim
in
0 16 32
$hdim
;
do
for
page_block_size
in
0 128
;
do
for
cache_batch_idx
in
0 1
;
do
$EXE
-prec
=
fp16
-b
=
3
-h
=
3
-d
=
$hdim
-s
=
$s
-s_k
=
$s_k
-s_knew
=
$s_knew
-rotary_dim
=
$rdim
-rotary_interleaved
=
$ri
-page_block_size
=
$page_block_size
-cache_batch_idx
=
$cache_batch_idx
-iperm
=
1
-operm
=
1
-kname
=
1
$COMMON_ARGS
done
;
done
;
done
;
done
;
done
done
;
done
;
done
}
set
-x
run_fp16_bf16_tests
run_fp8_tests
if
[
$TEST_APPENDKV
-eq
1
]
;
then
run_fp16_appendkv_tests
fi
set
+x
example/ck_tile/18_flexattn/utils.hpp
0 → 100644
View file @
81e68c64
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <optional>
#include <ostream>
#include <sstream>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include "ck_tile/core/container/span.hpp"
enum
class
mode_enum
{
batch
=
0
,
group
};
std
::
ostream
&
operator
<<
(
std
::
ostream
&
stream
,
mode_enum
mode
)
{
return
stream
<<
(
mode
==
mode_enum
::
batch
?
"batch"
:
"group"
);
}
std
::
vector
<
int32_t
>
to_seqstarts
(
ck_tile
::
span
<
const
int32_t
>
seqlens
)
{
std
::
vector
<
int32_t
>
seqstarts
=
{
0
};
for
(
int32_t
seqlen
:
seqlens
)
{
seqstarts
.
push_back
(
seqstarts
.
back
()
+
seqlen
);
}
assert
(
seqstarts
.
size
()
==
seqlens
.
size
()
+
1
);
return
seqstarts
;
}
std
::
vector
<
int32_t
>
generate_seqlens
(
mode_enum
mode
,
unsigned
count
,
int32_t
seqlen_avg
,
int32_t
seqlen_min
=
-
1
,
// if not negative, clamp min
int32_t
seqlen_max
=
-
1
,
// if not negative, clamp max
std
::
optional
<
unsigned
>
seed
=
std
::
nullopt
)
{
assert
(
0
<
count
);
seqlen_min
=
(
0
<
seqlen_min
?
seqlen_min
:
1
);
seqlen_max
=
(
0
<
seqlen_max
?
seqlen_max
:
std
::
numeric_limits
<
int32_t
>::
max
());
assert
(
seqlen_min
<=
seqlen_max
);
std
::
vector
<
int32_t
>
seqlens
(
count
,
std
::
clamp
(
seqlen_avg
,
seqlen_min
,
seqlen_max
));
if
(
mode
==
mode_enum
::
group
&&
1
<
count
)
{
using
size_type
=
std
::
vector
<
int32_t
>::
size_type
;
std
::
mt19937
random_engine
(
seed
.
has_value
()
?
*
seed
:
std
::
random_device
{}());
std
::
uniform_int_distribution
<
size_type
>
idx_dist
(
0
,
count
-
1
);
auto
next_idx
=
std
::
bind
(
idx_dist
,
std
::
ref
(
random_engine
));
std
::
uniform_int_distribution
<
size_type
>
step_dist
(
1
,
count
-
1
);
auto
next_step
=
std
::
bind
(
step_dist
,
std
::
ref
(
random_engine
));
for
(
unsigned
repeat
=
seqlen_avg
*
(
count
/
2
);
0
<
repeat
;
--
repeat
)
{
const
size_type
to_decrease
=
next_idx
();
// make sure each elements of seqlens is in range [seqlen_min, seqlen_max]
if
(
seqlens
[
to_decrease
]
==
seqlen_min
)
{
continue
;
}
const
size_type
to_increase
=
(
to_decrease
+
next_step
())
%
count
;
if
(
seqlens
[
to_increase
]
>=
seqlen_max
)
{
continue
;
}
--
seqlens
[
to_decrease
];
++
seqlens
[
to_increase
];
}
}
return
seqlens
;
}
std
::
vector
<
int32_t
>
generate_seqstarts
(
mode_enum
mode
,
unsigned
count
,
int32_t
seqlen_avg
,
int32_t
seqlen_min
=
-
1
,
int32_t
seqlen_max
=
-
1
,
std
::
optional
<
unsigned
>
seed
=
std
::
nullopt
)
{
return
to_seqstarts
(
generate_seqlens
(
mode
,
count
,
seqlen_avg
,
seqlen_min
,
seqlen_max
,
seed
));
}
// return random integer generated uniformly in range [low, high]
template
<
typename
Int
=
int
>
auto
randint
(
Int
low
,
Int
high
,
std
::
optional
<
unsigned
>
seed
=
std
::
nullopt
)
->
std
::
enable_if_t
<
std
::
is_integral_v
<
Int
>
,
Int
>
{
std
::
mt19937
engine
(
seed
.
has_value
()
?
*
seed
:
std
::
random_device
{}());
std
::
uniform_int_distribution
<
Int
>
dist
(
low
,
high
);
return
dist
(
engine
);
}
// return random integers generated uniformly in range [low, high]
template
<
typename
Int
,
typename
ForwardIterator
>
auto
randints
(
ForwardIterator
first
,
ForwardIterator
last
,
Int
low
,
Int
high
,
std
::
optional
<
unsigned
>
seed
=
std
::
nullopt
)
->
std
::
enable_if_t
<
std
::
is_integral_v
<
Int
>>
{
std
::
mt19937
engine
(
seed
.
has_value
()
?
*
seed
:
std
::
random_device
{}());
std
::
uniform_int_distribution
<
Int
>
dist
(
low
,
high
);
std
::
generate
(
first
,
last
,
[
&
]
{
return
dist
(
engine
);
});
}
/*
* decode the seqlen string from cmdline
* example (assume batch=3)
* q_val=1,2,3 k_val=4,5,6 -> OK
* q_val=1,2,3 -> OK, k same as q
* q_val=1,2 -> OK, q will rand remaining 1 element, k same as q
* q_val=1,2 k_val=4,5 -> OK, q/k will rand remaining 1 element
* q_val=1,2,3,4 -> OK, but ignore exceed one
*
* q_val=1,2 k_val=4,5,6 -> not OK, k must have same splits with q
* q_val=1,2 k_val=4 -> not OK, k must have same splits with q
*/
std
::
tuple
<
std
::
vector
<
ck_tile
::
index_t
>
,
std
::
vector
<
ck_tile
::
index_t
>
,
std
::
vector
<
ck_tile
::
index_t
>>
decode_seqlen
(
mode_enum
mode
,
ck_tile
::
index_t
batch
,
std
::
string
q_val
,
std
::
string
k_val
,
std
::
string
k_pad_val
,
ck_tile
::
index_t
seqlen_k_min
=
0
,
bool
need_append_kvcache
=
false
,
std
::
optional
<
unsigned
>
seed
=
std
::
nullopt
)
{
#define _S2I_(str_) static_cast<ck_tile::index_t>(std::atoi((str_).c_str()))
if
(
mode
==
mode_enum
::
batch
)
{
ck_tile
::
index_t
q
=
_S2I_
(
q_val
);
ck_tile
::
index_t
k
=
_S2I_
(
k_val
);
auto
s_q
=
std
::
vector
<
ck_tile
::
index_t
>
(
batch
,
q
);
auto
s_k
=
[
&
]
{
const
ck_tile
::
index_t
seqlen_k_max
=
(
k
<
0
?
q
:
k
);
std
::
vector
<
ck_tile
::
index_t
>
seqlen_ks
(
batch
,
seqlen_k_max
);
if
(
1
<
batch
&&
need_append_kvcache
)
{
// to keep the original s_k value, we always use seqlen_k_max in first batch
randints
(
std
::
next
(
seqlen_ks
.
begin
()),
seqlen_ks
.
end
(),
seqlen_k_min
,
seqlen_k_max
,
seed
);
return
seqlen_ks
;
}
return
seqlen_ks
;
}();
auto
s_kpad
=
std
::
vector
<
ck_tile
::
index_t
>
(
batch
,
-
1
);
// TODO: batch not support k_padding
// s_k should be greater than or equal to seqlen_k_min if provided
if
(
s_k
.
back
()
<
seqlen_k_min
)
{
std
::
ostringstream
msg
;
msg
<<
__FILE__
<<
":"
<<
__LINE__
<<
": seqlen_k (="
<<
s_k
.
back
()
<<
") is less than minimum seqlen_k (="
<<
seqlen_k_min
<<
")"
;
throw
std
::
runtime_error
(
msg
.
str
());
}
return
std
::
make_tuple
(
s_q
,
s_k
,
s_kpad
);
}
else
{
ck_tile
::
index_t
idx
=
0
;
std
::
string
::
size_type
pos_q
=
0
;
std
::
string
::
size_type
pos_k
=
0
;
std
::
string
::
size_type
pos_kp
=
0
;
std
::
vector
<
ck_tile
::
index_t
>
s_q
;
std
::
vector
<
ck_tile
::
index_t
>
s_k
;
std
::
vector
<
ck_tile
::
index_t
>
s_kpad
;
while
(
true
)
{
auto
found_q
=
q_val
.
find
(
','
,
pos_q
);
auto
found_k
=
k_val
.
find
(
','
,
pos_k
);
auto
found_kp
=
k_pad_val
.
find
(
','
,
pos_kp
);
ck_tile
::
index_t
q
=
_S2I_
(
q_val
.
substr
(
pos_q
,
found_q
==
std
::
string
::
npos
?
found_q
:
found_q
-
pos_q
));
ck_tile
::
index_t
k
=
_S2I_
(
k_val
.
substr
(
pos_k
,
found_k
==
std
::
string
::
npos
?
found_k
:
found_k
-
pos_k
));
ck_tile
::
index_t
kp
=
_S2I_
(
k_pad_val
.
substr
(
pos_kp
,
found_kp
==
std
::
string
::
npos
?
found_kp
:
found_kp
-
pos_kp
));
s_q
.
push_back
(
q
);
s_k
.
push_back
(
k
<
0
?
q
:
k
);
s_kpad
.
push_back
(
kp
);
// s_k should be greater than or equal to seqlen_k_min
if
(
s_k
.
back
()
<
seqlen_k_min
)
{
std
::
ostringstream
msg
;
msg
<<
__FILE__
<<
":"
<<
__LINE__
<<
": seqlen_k (="
<<
s_k
.
back
()
<<
") is less than minimum seqlen_k (="
<<
seqlen_k_min
<<
")"
;
throw
std
::
runtime_error
(
msg
.
str
());
}
idx
++
;
if
(
found_q
==
std
::
string
::
npos
||
idx
>=
batch
)
{
break
;
}
pos_q
=
found_q
+
1
;
pos_k
=
found_k
==
std
::
string
::
npos
?
pos_k
:
found_k
+
1
;
pos_kp
=
found_kp
==
std
::
string
::
npos
?
pos_kp
:
found_kp
+
1
;
}
if
(
idx
<
batch
)
{
auto
rem_q
=
generate_seqlens
(
mode
,
batch
-
idx
,
s_q
.
back
(),
1
,
s_kpad
.
back
(),
seed
);
auto
rem_k
=
generate_seqlens
(
mode
,
batch
-
idx
,
s_k
.
back
(),
seqlen_k_min
,
s_kpad
.
back
(),
seed
);
s_q
.
insert
(
s_q
.
end
(),
rem_q
.
begin
(),
rem_q
.
end
());
s_k
.
insert
(
s_k
.
end
(),
rem_k
.
begin
(),
rem_k
.
end
());
s_kpad
.
insert
(
s_kpad
.
end
(),
batch
-
idx
,
s_kpad
.
back
());
}
return
std
::
make_tuple
(
s_q
,
s_k
,
s_kpad
);
}
#undef _S2I_
}
int
env_get_int
(
const
char
*
var_name
,
int
default_int
)
{
char
*
v
=
getenv
(
var_name
);
int
r
=
default_int
;
if
(
v
)
r
=
std
::
atoi
(
v
);
return
r
;
}
template
<
typename
RandomAccessIterator
,
typename
Int
>
std
::
enable_if_t
<
std
::
is_integral_v
<
Int
>>
iota_shuffle
(
RandomAccessIterator
first
,
RandomAccessIterator
last
,
Int
value
,
std
::
optional
<
unsigned
>
seed
=
std
::
nullopt
)
{
std
::
iota
(
first
,
last
,
value
);
std
::
mt19937
engine
(
seed
.
has_value
()
?
*
seed
:
std
::
random_device
{}());
std
::
shuffle
(
first
,
last
,
engine
);
}
example/ck_tile/CMakeLists.txt
View file @
81e68c64
...
@@ -17,4 +17,5 @@ add_subdirectory(14_moe_smoothquant)
...
@@ -17,4 +17,5 @@ add_subdirectory(14_moe_smoothquant)
add_subdirectory
(
15_fused_moe
)
add_subdirectory
(
15_fused_moe
)
add_subdirectory
(
16_batched_gemm
)
add_subdirectory
(
16_batched_gemm
)
add_subdirectory
(
17_grouped_gemm
)
add_subdirectory
(
17_grouped_gemm
)
add_subdirectory
(
18_flexattn
)
add_subdirectory
(
35_batched_transpose
)
add_subdirectory
(
35_batched_transpose
)
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment