Unverified Commit 81e68c64 authored by Max Podkorytov's avatar Max Podkorytov
Browse files

copy over fmha example

parent c5fff071
#!/bin/sh
# TODO: run this script from CK root or build directory
EXE="$(find . -name tile_example_fmha_fwd -type f | head -n 1)"
VALID=0
for prec in "fp16" "bf16" ; do
for perm in 0 1 ; do
for hdim in 64 128 256 ; do
nhead=$((2048 / $hdim)) # follow fav2 setup
$EXE -prec=$prec -b=32 -h=$nhead -d=$hdim -s=512 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
$EXE -prec=$prec -b=16 -h=$nhead -d=$hdim -s=1024 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
$EXE -prec=$prec -b=8 -h=$nhead -d=$hdim -s=2048 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
$EXE -prec=$prec -b=4 -h=$nhead -d=$hdim -s=4096 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
$EXE -prec=$prec -b=2 -h=$nhead -d=$hdim -s=8192 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
$EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
done
done
done
for perm in 0 1 ; do
$EXE -prec=fp8 -squant=1 -b=32 -h=16 -d=128 -s=512 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3
$EXE -prec=fp8 -squant=1 -b=16 -h=16 -d=128 -s=1024 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3
$EXE -prec=fp8 -squant=1 -b=8 -h=16 -d=128 -s=2048 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3
$EXE -prec=fp8 -squant=1 -b=4 -h=16 -d=128 -s=4096 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3
$EXE -prec=fp8 -squant=1 -b=2 -h=16 -d=128 -s=8192 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3
$EXE -prec=fp8 -squant=1 -b=1 -h=16 -d=128 -s=16384 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3
done
\ No newline at end of file
#!/bin/bash
#
# in order to run this script you'd first need to build the tile_example_fmha_fwd and tile_eaxmple_fmha_bwd executables in ../build/bin/
#
# run the script as "./run_full_test.sh <tag for your test environment> <branch name> <host name> <gpu_arch>
# input arguments:
# environment tag : a string describing the specifics of your test environment
# branch name : name of the branch in git repo (git status | grep -e 'On branch')
# host name : $hostname
# gpu architecture: e.g., gfx90a, or gfx942, etc.
#get the command line arguments:
export env_type=$1
echo 'Environment type: ' $env_type
export branch=$2
echo 'Branch name: ' $branch
export host_name=$3
echo 'Host name: ' $host_name
export GPU_arch=$4
echo 'GPU_arch: ' $GPU_arch
function print_log_header(){
rm -f $1;
echo 'On branch ' $3 &> $1;
echo 'Node name: ' $4 >> $1;
#get GPU_arch and number of compute units from rocminfo
echo -n "GPU_arch: " >> $1; rocminfo | grep "Name:" | grep "gfx" >> $1;
rocminfo | grep "Compute Unit:" >> $1;
hipcc --version | grep -e 'HIP version' >> $1;
echo 'Environment type: ' $2 >> $1;
/opt/rocm/bin/amdclang++ --version | grep -e 'InstalledDir' >> $1;
}
#run verification tests
example/ck_tile/01_fmha/script/smoke_test_fwd.sh
example/ck_tile/01_fmha/script/smoke_test_bwd.sh
#run performance benchmarks
export fmha_fwd_log="perf_fmha_fwd_$GPU_arch.log"
print_log_header $fmha_fwd_log $env_type $branch $host_name
example/ck_tile/01_fmha/script/benchmark_fwd.sh 2>&1 | tee -a $fmha_fwd_log
export fmha_bwd_log="perf_fmha_bwd_$GPU_arch.log"
print_log_header $fmha_bwd_log $env_type $branch $host_name
example/ck_tile/01_fmha/script/benchmark_bwd.sh 2>&1 | tee -a $fmha_bwd_log
#!/bin/sh
# TODO: run this script from CK root or build directory
EXE="$(find . -name tile_example_fmha_bwd -type f | head -n 1)"
KNAME=1
export CK_WARMUP=0
export CK_REPEAT=1
COMMON_ARGS='-v=1'
set -x
for prec in "fp16" "bf16" ; do
for perm in 0 1 ; do
for hdim in 32 64 128 256 ; do
for mode in 0 1 ; do
for bias in "n" "a" ; do
for dbias in 0 ; do
for p_drop in 0.0 0.2 ; do
for deterministic in 0 ; do
$EXE -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=259 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=2 -h=2 -d=$hdim -s=516 -s_k=253 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=1 -h=4 -h_k=1 -d=$hdim -s=500 -s_k=251 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=1 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=1 -h=2 -d=$hdim -s=900 -s_k=258 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=2 -v=1 -deterministic=$deterministic -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=2 -h=1 -d=$hdim -s=987 -s_k=219 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=t:128,30 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -b=2 -h=3 -h_k=1 -d=$hdim -s=244 -s_k=499 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=b:4,35 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
done
done
done
done
done
done
done
done
set +x
#!/bin/bash
# TODO: run this script from CK root or build directory
EXE="$(find . -name tile_example_fmha_fwd -type f | head -n 1)"
KNAME=1
export CK_WARMUP=0
export CK_REPEAT=1
COMMON_ARGS='-v=1 -warmup=0 -repeat=1'
# mode=0
# export HIP_VISIBLE_DEVICES=4
TEST_SPLITKV=0
TEST_APPENDKV=0
# options:
# -s: run splitkv tests
# -a: run appendkv tests
while getopts ":sa" opt; do
case "${opt}" in
s)
TEST_SPLITKV=1
;;
a)
TEST_APPENDKV=1
;;
*)
;;
esac
done
run_fp16_bf16_tests() {
local NUM_SPLITS="1"
local PAGE_BLOCK_SIZE="0"
local CACHE_BATCH_IDX="0"
if [ $TEST_SPLITKV -eq 1 ] ; then
NUM_SPLITS="$NUM_SPLITS 2 3"
PAGE_BLOCK_SIZE="$PAGE_BLOCK_SIZE 128"
CACHE_BATCH_IDX="$CACHE_BATCH_IDX 1"
fi
for prec in "fp16" "bf16" ; do
for mode in 1 0 ; do
for perm in 0 1 ; do
for vlayout in "r" "c" ; do
for hdim in 32 64 128 256 ; do
for lse in 0 1 ; do
for bias in "n" "e" "a" ; do
for p_drop in 0.0 0.2 ; do
for num_splits in $NUM_SPLITS ; do
for page_block_size in $PAGE_BLOCK_SIZE ; do
for cache_batch_idx in $CACHE_BATCH_IDX ; do
# $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
done ; done ; done ; done ; done
done ; done ; done ; done ; done
done ;
}
run_fp8_tests() {
for perm in 0 1 ; do
for bias in "n" "e" "a" ; do
for b in 1 2 ; do
for hdim in 64 128 256 ; do
$EXE -prec=fp8 -init=3 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=c -squant=1 -kname=$KNAME $COMMON_ARGS
done ; done ; done ; done
}
run_fp16_appendkv_tests() {
for s in $(seq 63 1 65) ; do
for s_k in 65 129 ; do
for s_knew in 0 64 $s_k ; do
for hdim in 32 64 128 256 ; do
for ri in 0 1 ; do
for rdim in 0 16 32 $hdim ; do
for page_block_size in 0 128 ; do
for cache_batch_idx in 0 1 ; do
$EXE -prec=fp16 -b=3 -h=3 -d=$hdim -s=$s -s_k=$s_k -s_knew=$s_knew -rotary_dim=$rdim -rotary_interleaved=$ri -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -iperm=1 -operm=1 -kname=1 $COMMON_ARGS
done ; done ; done ; done ; done
done ; done ; done
}
set -x
run_fp16_bf16_tests
run_fp8_tests
if [ $TEST_APPENDKV -eq 1 ] ; then
run_fp16_appendkv_tests
fi
set +x
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <optional>
#include <ostream>
#include <sstream>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include "ck_tile/core/container/span.hpp"
enum class mode_enum
{
batch = 0,
group
};
std::ostream& operator<<(std::ostream& stream, mode_enum mode)
{
return stream << (mode == mode_enum::batch ? "batch" : "group");
}
std::vector<int32_t> to_seqstarts(ck_tile::span<const int32_t> seqlens)
{
std::vector<int32_t> seqstarts = {0};
for(int32_t seqlen : seqlens)
{
seqstarts.push_back(seqstarts.back() + seqlen);
}
assert(seqstarts.size() == seqlens.size() + 1);
return seqstarts;
}
std::vector<int32_t> generate_seqlens(mode_enum mode,
unsigned count,
int32_t seqlen_avg,
int32_t seqlen_min = -1, // if not negative, clamp min
int32_t seqlen_max = -1, // if not negative, clamp max
std::optional<unsigned> seed = std::nullopt)
{
assert(0 < count);
seqlen_min = (0 < seqlen_min ? seqlen_min : 1);
seqlen_max = (0 < seqlen_max ? seqlen_max : std::numeric_limits<int32_t>::max());
assert(seqlen_min <= seqlen_max);
std::vector<int32_t> seqlens(count, std::clamp(seqlen_avg, seqlen_min, seqlen_max));
if(mode == mode_enum::group && 1 < count)
{
using size_type = std::vector<int32_t>::size_type;
std::mt19937 random_engine(seed.has_value() ? *seed : std::random_device{}());
std::uniform_int_distribution<size_type> idx_dist(0, count - 1);
auto next_idx = std::bind(idx_dist, std::ref(random_engine));
std::uniform_int_distribution<size_type> step_dist(1, count - 1);
auto next_step = std::bind(step_dist, std::ref(random_engine));
for(unsigned repeat = seqlen_avg * (count / 2); 0 < repeat; --repeat)
{
const size_type to_decrease = next_idx();
// make sure each elements of seqlens is in range [seqlen_min, seqlen_max]
if(seqlens[to_decrease] == seqlen_min)
{
continue;
}
const size_type to_increase = (to_decrease + next_step()) % count;
if(seqlens[to_increase] >= seqlen_max)
{
continue;
}
--seqlens[to_decrease];
++seqlens[to_increase];
}
}
return seqlens;
}
std::vector<int32_t> generate_seqstarts(mode_enum mode,
unsigned count,
int32_t seqlen_avg,
int32_t seqlen_min = -1,
int32_t seqlen_max = -1,
std::optional<unsigned> seed = std::nullopt)
{
return to_seqstarts(generate_seqlens(mode, count, seqlen_avg, seqlen_min, seqlen_max, seed));
}
// return random integer generated uniformly in range [low, high]
template <typename Int = int>
auto randint(Int low, Int high, std::optional<unsigned> seed = std::nullopt)
-> std::enable_if_t<std::is_integral_v<Int>, Int>
{
std::mt19937 engine(seed.has_value() ? *seed : std::random_device{}());
std::uniform_int_distribution<Int> dist(low, high);
return dist(engine);
}
// return random integers generated uniformly in range [low, high]
template <typename Int, typename ForwardIterator>
auto randints(ForwardIterator first,
ForwardIterator last,
Int low,
Int high,
std::optional<unsigned> seed = std::nullopt)
-> std::enable_if_t<std::is_integral_v<Int>>
{
std::mt19937 engine(seed.has_value() ? *seed : std::random_device{}());
std::uniform_int_distribution<Int> dist(low, high);
std::generate(first, last, [&] { return dist(engine); });
}
/*
* decode the seqlen string from cmdline
* example (assume batch=3)
* q_val=1,2,3 k_val=4,5,6 -> OK
* q_val=1,2,3 -> OK, k same as q
* q_val=1,2 -> OK, q will rand remaining 1 element, k same as q
* q_val=1,2 k_val=4,5 -> OK, q/k will rand remaining 1 element
* q_val=1,2,3,4 -> OK, but ignore exceed one
*
* q_val=1,2 k_val=4,5,6 -> not OK, k must have same splits with q
* q_val=1,2 k_val=4 -> not OK, k must have same splits with q
*/
std::tuple<std::vector<ck_tile::index_t>,
std::vector<ck_tile::index_t>,
std::vector<ck_tile::index_t>>
decode_seqlen(mode_enum mode,
ck_tile::index_t batch,
std::string q_val,
std::string k_val,
std::string k_pad_val,
ck_tile::index_t seqlen_k_min = 0,
bool need_append_kvcache = false,
std::optional<unsigned> seed = std::nullopt)
{
#define _S2I_(str_) static_cast<ck_tile::index_t>(std::atoi((str_).c_str()))
if(mode == mode_enum::batch)
{
ck_tile::index_t q = _S2I_(q_val);
ck_tile::index_t k = _S2I_(k_val);
auto s_q = std::vector<ck_tile::index_t>(batch, q);
auto s_k = [&] {
const ck_tile::index_t seqlen_k_max = (k < 0 ? q : k);
std::vector<ck_tile::index_t> seqlen_ks(batch, seqlen_k_max);
if(1 < batch && need_append_kvcache)
{
// to keep the original s_k value, we always use seqlen_k_max in first batch
randints(std::next(seqlen_ks.begin()),
seqlen_ks.end(),
seqlen_k_min,
seqlen_k_max,
seed);
return seqlen_ks;
}
return seqlen_ks;
}();
auto s_kpad = std::vector<ck_tile::index_t>(batch, -1); // TODO: batch not support k_padding
// s_k should be greater than or equal to seqlen_k_min if provided
if(s_k.back() < seqlen_k_min)
{
std::ostringstream msg;
msg << __FILE__ << ":" << __LINE__ << ": seqlen_k (=" << s_k.back()
<< ") is less than minimum seqlen_k (=" << seqlen_k_min << ")";
throw std::runtime_error(msg.str());
}
return std::make_tuple(s_q, s_k, s_kpad);
}
else
{
ck_tile::index_t idx = 0;
std::string::size_type pos_q = 0;
std::string::size_type pos_k = 0;
std::string::size_type pos_kp = 0;
std::vector<ck_tile::index_t> s_q;
std::vector<ck_tile::index_t> s_k;
std::vector<ck_tile::index_t> s_kpad;
while(true)
{
auto found_q = q_val.find(',', pos_q);
auto found_k = k_val.find(',', pos_k);
auto found_kp = k_pad_val.find(',', pos_kp);
ck_tile::index_t q = _S2I_(
q_val.substr(pos_q, found_q == std::string::npos ? found_q : found_q - pos_q));
ck_tile::index_t k = _S2I_(
k_val.substr(pos_k, found_k == std::string::npos ? found_k : found_k - pos_k));
ck_tile::index_t kp = _S2I_(k_pad_val.substr(
pos_kp, found_kp == std::string::npos ? found_kp : found_kp - pos_kp));
s_q.push_back(q);
s_k.push_back(k < 0 ? q : k);
s_kpad.push_back(kp);
// s_k should be greater than or equal to seqlen_k_min
if(s_k.back() < seqlen_k_min)
{
std::ostringstream msg;
msg << __FILE__ << ":" << __LINE__ << ": seqlen_k (=" << s_k.back()
<< ") is less than minimum seqlen_k (=" << seqlen_k_min << ")";
throw std::runtime_error(msg.str());
}
idx++;
if(found_q == std::string::npos || idx >= batch)
{
break;
}
pos_q = found_q + 1;
pos_k = found_k == std::string::npos ? pos_k : found_k + 1;
pos_kp = found_kp == std::string::npos ? pos_kp : found_kp + 1;
}
if(idx < batch)
{
auto rem_q = generate_seqlens(mode, batch - idx, s_q.back(), 1, s_kpad.back(), seed);
auto rem_k =
generate_seqlens(mode, batch - idx, s_k.back(), seqlen_k_min, s_kpad.back(), seed);
s_q.insert(s_q.end(), rem_q.begin(), rem_q.end());
s_k.insert(s_k.end(), rem_k.begin(), rem_k.end());
s_kpad.insert(s_kpad.end(), batch - idx, s_kpad.back());
}
return std::make_tuple(s_q, s_k, s_kpad);
}
#undef _S2I_
}
int env_get_int(const char* var_name, int default_int)
{
char* v = getenv(var_name);
int r = default_int;
if(v)
r = std::atoi(v);
return r;
}
template <typename RandomAccessIterator, typename Int>
std::enable_if_t<std::is_integral_v<Int>> iota_shuffle(RandomAccessIterator first,
RandomAccessIterator last,
Int value,
std::optional<unsigned> seed = std::nullopt)
{
std::iota(first, last, value);
std::mt19937 engine(seed.has_value() ? *seed : std::random_device{}());
std::shuffle(first, last, engine);
}
...@@ -17,4 +17,5 @@ add_subdirectory(14_moe_smoothquant) ...@@ -17,4 +17,5 @@ add_subdirectory(14_moe_smoothquant)
add_subdirectory(15_fused_moe) add_subdirectory(15_fused_moe)
add_subdirectory(16_batched_gemm) add_subdirectory(16_batched_gemm)
add_subdirectory(17_grouped_gemm) add_subdirectory(17_grouped_gemm)
add_subdirectory(18_flexattn)
add_subdirectory(35_batched_transpose) add_subdirectory(35_batched_transpose)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment