Commit 83074f4c authored by Astha Rai's avatar Astha Rai
Browse files

Merge branch 'develop' into codegen_hiprtc

parents 7c56cd01 5fb150db
...@@ -145,20 +145,20 @@ message("hip_version_flat=${hip_VERSION_FLAT}") ...@@ -145,20 +145,20 @@ message("hip_version_flat=${hip_VERSION_FLAT}")
message("checking which targets are supported") message("checking which targets are supported")
#In order to build just the CK library (without tests and examples) for all supported GPU targets #In order to build just the CK library (without tests and examples) for all supported GPU targets
#use -D GPU_ARCHS="gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201" #use -D GPU_ARCHS="gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201"
#the GPU_TARGETS flag will be reset in this case in order to avoid conflicts. #the GPU_TARGETS flag will be reset in this case in order to avoid conflicts.
# #
#In order to build CK along with all tests and examples it should be OK to set GPU_TARGETS to just 1 or 2 similar architectures. #In order to build CK along with all tests and examples it should be OK to set GPU_TARGETS to just 1 or 2 similar architectures.
if(NOT ENABLE_ASAN_PACKAGING) if(NOT ENABLE_ASAN_PACKAGING)
if(NOT WIN32 AND ${hip_VERSION_FLAT} LESS 600300000) if(NOT WIN32 AND ${hip_VERSION_FLAT} LESS 600300000)
# WORKAROUND: compiler does not yet fully support gfx12 targets, need to fix version above # WORKAROUND: compiler does not yet fully support gfx12 targets, need to fix version above
set(CK_GPU_TARGETS "gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102") set(CK_GPU_TARGETS "gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102")
else() else()
set(CK_GPU_TARGETS "gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201") set(CK_GPU_TARGETS "gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201")
endif() endif()
else() else()
#build CK only for xnack-supported targets when using ASAN #build CK only for xnack-supported targets when using ASAN
set(CK_GPU_TARGETS "gfx908:xnack+;gfx90a:xnack+;gfx940:xnack+;gfx941:xnack+;gfx942:xnack+") set(CK_GPU_TARGETS "gfx908:xnack+;gfx90a:xnack+;gfx942:xnack+")
endif() endif()
#if user set GPU_ARCHS on the cmake command line, overwrite default target list with user's list #if user set GPU_ARCHS on the cmake command line, overwrite default target list with user's list
......
...@@ -1101,11 +1101,11 @@ pipeline { ...@@ -1101,11 +1101,11 @@ pipeline {
agent{ label rocmnode("gfx90a") } agent{ label rocmnode("gfx90a") }
environment{ environment{
setup_args = """ -DCMAKE_INSTALL_PREFIX=../install \ setup_args = """ -DCMAKE_INSTALL_PREFIX=../install \
-DGPU_TARGETS="gfx908;gfx90a;gfx940;gfx941;gfx942" \ -DGPU_TARGETS="gfx908;gfx90a;gfx942" \
-DCMAKE_CXX_FLAGS=" -O3 " """ -DCMAKE_CXX_FLAGS=" -O3 " """
execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \
cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
-DGPU_TARGETS="gfx908;gfx90a;gfx940;gfx941;gfx942" \ -DGPU_TARGETS="gfx908;gfx90a;gfx942" \
-DCMAKE_CXX_COMPILER="${build_compiler()}" \ -DCMAKE_CXX_COMPILER="${build_compiler()}" \
-DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """
} }
...@@ -1165,7 +1165,7 @@ pipeline { ...@@ -1165,7 +1165,7 @@ pipeline {
execute_args = """ cmake -D CMAKE_PREFIX_PATH=/opt/rocm \ execute_args = """ cmake -D CMAKE_PREFIX_PATH=/opt/rocm \
-D CMAKE_CXX_COMPILER="${build_compiler()}" \ -D CMAKE_CXX_COMPILER="${build_compiler()}" \
-D CMAKE_BUILD_TYPE=Release \ -D CMAKE_BUILD_TYPE=Release \
-D GPU_ARCHS="gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102" \ -D GPU_ARCHS="gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102" \
-D CMAKE_CXX_FLAGS=" -O3 " .. && make -j64 """ -D CMAKE_CXX_FLAGS=" -O3 " .. && make -j64 """
} }
steps{ steps{
......
...@@ -68,7 +68,7 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle ...@@ -68,7 +68,7 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle
using DeviceReduceInstance = using DeviceReduceInstance =
ck::tensor_operation::device::DeviceReduceMultiBlock<OutputDataType, ck::tensor_operation::device::DeviceReduceMultiBlock<OutputDataType,
OutputDataType, ScaleDataType,
OutputDataType, OutputDataType,
NumDim, NumDim,
NumDim, NumDim,
...@@ -108,7 +108,8 @@ void reference_scale_permute_amax(Tensor<InputDataType>& input, ...@@ -108,7 +108,8 @@ void reference_scale_permute_amax(Tensor<InputDataType>& input,
host_output_scaled_casted_transposed(m, k) = y1; host_output_scaled_casted_transposed(m, k) = y1;
const OutputDataType y_fabs = const OutputDataType y_fabs =
ck::type_convert<OutputDataType>(ck::math::abs(ck::type_convert<float>(y0))); ck::type_convert<OutputDataType>(ck::math::abs(ck::type_convert<float>(y0)));
host_output_amax(0) = ck::math::max(y_fabs, host_output_amax(0)); host_output_amax(0) = ck::type_convert<OutputDataType>(ck::math::max(
ck::type_convert<float>(y_fabs), ck::type_convert<float>(host_output_amax(0))));
} }
} }
} }
......
...@@ -85,9 +85,9 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) ...@@ -85,9 +85,9 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
#only continue if there are some source files left on the list #only continue if there are some source files left on the list
if(FILE_NAME) if(FILE_NAME)
if(FILE_NAME MATCHES "_xdl") if(FILE_NAME MATCHES "_xdl")
list(REMOVE_ITEM EX_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201) list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
elseif(FILE_NAME MATCHES "_wmma") elseif(FILE_NAME MATCHES "_wmma")
list(REMOVE_ITEM EX_TARGETS gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
endif() endif()
set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP) set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP)
add_executable(${EXAMPLE_NAME} ${FILE_NAME}) add_executable(${EXAMPLE_NAME} ${FILE_NAME})
...@@ -169,9 +169,9 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) ...@@ -169,9 +169,9 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
#only continue if there are some source files left on the list #only continue if there are some source files left on the list
if(FILE_NAME) if(FILE_NAME)
if(FILE_NAME MATCHES "_xdl") if(FILE_NAME MATCHES "_xdl")
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201) list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic)
elseif(FILE_NAME MATCHES "_wmma") elseif(FILE_NAME MATCHES "_wmma")
list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
endif() endif()
set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP) set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP)
add_executable(${EXAMPLE_NAME} ${FILE_NAME}) add_executable(${EXAMPLE_NAME} ${FILE_NAME})
......
...@@ -47,6 +47,9 @@ def list_blobs(output_file : Optional[str], api_list : List[str], kernel_filter ...@@ -47,6 +47,9 @@ def list_blobs(output_file : Optional[str], api_list : List[str], kernel_filter
assert output_file is not None assert output_file is not None
file_path = Path(output_file) file_path = Path(output_file)
# create an empty file / drop its contents if it exists
open(file_path, "w").close()
for api in api_list: for api in api_list:
handler = handlers[api][HandlerId.LIST_BLOBS] handler = handlers[api][HandlerId.LIST_BLOBS]
handler(file_path, kernel_filter, receipt, mask_impl) handler(file_path, kernel_filter, receipt, mask_impl)
......
...@@ -29,14 +29,14 @@ while getopts ":sa" opt; do ...@@ -29,14 +29,14 @@ while getopts ":sa" opt; do
done done
run_fp16_bf16_tests() { run_fp16_bf16_tests() {
local NUM_SPLITS=(1) local NUM_SPLITS="1"
local PAGE_BLOCK_SIZE=(0) local PAGE_BLOCK_SIZE="0"
local CACHE_BATCH_IDX=(0) local CACHE_BATCH_IDX="0"
if [ $TEST_SPLITKV -eq 1 ] ; then if [ $TEST_SPLITKV -eq 1 ] ; then
NUM_SPLITS+=(2 3) NUM_SPLITS="$NUM_SPLITS 2 3"
PAGE_BLOCK_SIZE+=(128) PAGE_BLOCK_SIZE="$PAGE_BLOCK_SIZE 128"
CACHE_BATCH_IDX+=(1) CACHE_BATCH_IDX="$CACHE_BATCH_IDX 1"
fi fi
for prec in "fp16" "bf16" ; do for prec in "fp16" "bf16" ; do
...@@ -47,9 +47,9 @@ run_fp16_bf16_tests() { ...@@ -47,9 +47,9 @@ run_fp16_bf16_tests() {
for lse in 0 1 ; do for lse in 0 1 ; do
for bias in "n" "e" "a" ; do for bias in "n" "e" "a" ; do
for p_drop in 0.0 0.2 ; do for p_drop in 0.0 0.2 ; do
for num_splits in "${NUM_SPLITS[@]}" ; do for num_splits in $NUM_SPLITS ; do
for page_block_size in "${PAGE_BLOCK_SIZE[@]}" ; do for page_block_size in $PAGE_BLOCK_SIZE ; do
for cache_batch_idx in "${CACHE_BATCH_IDX[@]}" ; do for cache_batch_idx in $CACHE_BATCH_IDX ; do
# $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS # $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS
$EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS $EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
...@@ -103,4 +103,4 @@ if [ $TEST_APPENDKV -eq 1 ] ; then ...@@ -103,4 +103,4 @@ if [ $TEST_APPENDKV -eq 1 ] ; then
run_fp16_appendkv_tests run_fp16_appendkv_tests
fi fi
set +x set +x
\ No newline at end of file
...@@ -57,6 +57,7 @@ template <typename XDataType_, ...@@ -57,6 +57,7 @@ template <typename XDataType_,
ck_tile::index_t Vector_N_, // vector size along N ck_tile::index_t Vector_N_, // vector size along N
bool kPadN_, bool kPadN_,
bool kSaveMeanInvStd_, bool kSaveMeanInvStd_,
bool kFastFDiv_,
bool kTwoPass_, bool kTwoPass_,
ck_tile::index_t kFusedAdd_ = 0, ck_tile::index_t kFusedAdd_ = 0,
ck_tile::index_t kFusedQuant_ = 0> ck_tile::index_t kFusedQuant_ = 0>
...@@ -118,6 +119,7 @@ struct layernorm2d_fwd_traits_ ...@@ -118,6 +119,7 @@ struct layernorm2d_fwd_traits_
static constexpr bool kPadN = kPadN_; static constexpr bool kPadN = kPadN_;
static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_; static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
static constexpr bool kFastFDiv = kFastFDiv_;
static constexpr bool kTwoPass = kTwoPass_; static constexpr bool kTwoPass = kTwoPass_;
static constexpr ck_tile::index_t kFusedAdd = kFusedAdd_; static constexpr ck_tile::index_t kFusedAdd = kFusedAdd_;
static constexpr ck_tile::index_t kFusedQuant = kFusedQuant_; static constexpr ck_tile::index_t kFusedQuant = kFusedQuant_;
...@@ -134,6 +136,7 @@ template <typename XDataType_, ...@@ -134,6 +136,7 @@ template <typename XDataType_,
ck_tile::index_t Vector_N_, // vector size along N ck_tile::index_t Vector_N_, // vector size along N
bool kPadN_, bool kPadN_,
bool kSaveMeanInvStd_, bool kSaveMeanInvStd_,
bool kFastFDiv_,
bool kTwoPass_, bool kTwoPass_,
int kFusedAdd_, int kFusedAdd_,
int kFusedQuant_> int kFusedQuant_>
...@@ -148,6 +151,7 @@ using traits_ = layernorm2d_fwd_traits_<XDataType_, ...@@ -148,6 +151,7 @@ using traits_ = layernorm2d_fwd_traits_<XDataType_,
Vector_N_, Vector_N_,
kPadN_, kPadN_,
kSaveMeanInvStd_, kSaveMeanInvStd_,
kFastFDiv_,
kTwoPass_, kTwoPass_,
kFusedAdd_, kFusedAdd_,
kFusedQuant_>; kFusedQuant_>;
...@@ -179,6 +183,7 @@ float layernorm2d_fwd_(const S& s, A a) ...@@ -179,6 +183,7 @@ float layernorm2d_fwd_(const S& s, A a)
using PipelineTraits = ck_tile::Layernorm2dFwdTraits<Traits_::kPadN, using PipelineTraits = ck_tile::Layernorm2dFwdTraits<Traits_::kPadN,
Traits_::kSaveMeanInvStd, Traits_::kSaveMeanInvStd,
Traits_::kFastFDiv,
Traits_::kTwoPass, Traits_::kTwoPass,
static_cast<ck_tile::Layernorm2dFusedAddEnum>(Traits_::kFusedAdd), static_cast<ck_tile::Layernorm2dFusedAddEnum>(Traits_::kFusedAdd),
static_cast<ck_tile::Layernorm2dFusedQuantEnum>(Traits_::kFusedQuant)>; static_cast<ck_tile::Layernorm2dFusedQuantEnum>(Traits_::kFusedQuant)>;
...@@ -269,7 +274,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, ...@@ -269,7 +274,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
#include "layernorm2d_fwd_api_common.hpp" #include "layernorm2d_fwd_api_common.hpp"
// clang-format off // clang-format off
// prec_i prec_o prec_sy rm rn tm tn vn pd mv 2p add sweep // prec_i prec_o prec_sy rm rn tm tn vn pd mv rpcf 2p add sweep
{F_instance_def} {F_instance_def}
// clang-format on // clang-format on
...@@ -356,6 +361,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, ...@@ -356,6 +361,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
F_Vector_N : int F_Vector_N : int
F_kPadN : bool F_kPadN : bool
F_kSaveMeanInvStd_ : bool F_kSaveMeanInvStd_ : bool
F_kFastFDiv_ : bool
F_kTwoPass_ : bool F_kTwoPass_ : bool
F_kFusedAdd : int F_kFusedAdd : int
F_kFusedQuant : int F_kFusedQuant : int
...@@ -363,7 +369,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, ...@@ -363,7 +369,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
@property @property
def trait_name(self) ->str: def trait_name(self) ->str:
t_ = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_XScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}' t_ = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_XScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}'
t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveMeanInvStd_):5}' t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveMeanInvStd_):5}, {BOOL_MAP(self.F_kFastFDiv_):5}'
t_ += f', {BOOL_MAP(self.F_kTwoPass_):5}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}' t_ += f', {BOOL_MAP(self.F_kTwoPass_):5}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}'
return t_ return t_
...@@ -483,52 +489,55 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, ...@@ -483,52 +489,55 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
fused_add_list = [0, 1] fused_add_list = [0, 1]
fused_sweep_list = [0, 1] # NOTE: only single pass can use fused dynamic quant fused_sweep_list = [0, 1] # NOTE: only single pass can use fused dynamic quant
# rm rn tm tn vn pd mv 2p add sweep # rm rn tm tn vn pd mv fdiv 2p add sweep
h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 1, True, False, False, 0, 0)], h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 8, 8, 8, True, False, True, False, 0, 0),
'128' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 2, True, False, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 4, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 1, True, False, False, 0, 0)], h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 1, True, False, True, False, 0, 0)],
'256' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 4, True, False, False, 0, 0), '128' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 8, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 2, True, False, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 2, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 1, True, False, False, 0, 0)], h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 1, True, False, True, False, 0, 0)],
'512' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 8, True, False, False, 0, 0), '256' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 4, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 4, True, False, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 2, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 2, True, False, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 1, True, False, True, False, 0, 0)],
h_traits('x', 'y', 'xs', 'ys', 1, 8, 4, 64, 1, True, False, False, 0, 0)], '512' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 8, True, False, True, False, 0, 0),
'768' : [ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 4, True, False, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 4, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 6, 4, 64, 2, True, False, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 2, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 12, 4, 64, 1, True, False, False, 0, 0)], h_traits('x', 'y', 'xs', 'ys', 1, 8, 4, 64, 1, True, False, True, False, 0, 0)],
'1024' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 2, 128, 8, True, False, False, 0, 0), '768' : [ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 4, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 2, 128, 4, True, False, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 6, 4, 64, 2, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 2, 128, 2, True, False, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 12, 4, 64, 1, True, False, True, False, 0, 0)],
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 1, True, False, False, 0, 0)], '1024' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 2, 128, 8, True, False, True, False, 0, 0),
'1536' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 8, True, False, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 2, 2, 128, 4, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 2, 128, 4, True, False, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 4, 2, 128, 2, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 2, True, False, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 1, True, False, True, False, 0, 0)],
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 1, True, False, False, 0, 0)], '1536' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 8, True, False, True, False, 0, 0),
'2048' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1, 256, 8, True, False, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 3, 2, 128, 4, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 4, True, False, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 2, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 2, True, False, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 1, True, False, True, False, 0, 0)],
h_traits('x', 'y', 'xs', 'ys', 1, 8, 1, 256, 1, True, False, False, 0, 0)], '2048' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1, 256, 8, True, False, True, False, 0, 0),
'3072' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 128, 8, True, False, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 4, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 4, True, False, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 2, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 2, True, False, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 8, 1, 256, 1, True, False, True, False, 0, 0)],
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 1, True, False, False, 0, 0)], '3072' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 128, 8, True, False, True, False, 0, 0),
'4096' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 4, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 2, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 1, True, False, True, False, 0, 0)],
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, False, 0, 0)], '4096' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, False, 0, 0),
'6144' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 8, True, False, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 512, 4, True, False, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 2, True, False, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, False, 0, 0)],
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1,1024, 1, True, False, False, 0, 0)], '6144' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 8, True, False, True, False, 0, 0),
'8192' :[ h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 8, True, False, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 512, 4, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 512, 4, True, False, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 2, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 2, True, False, False, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 6, 1,1024, 1, True, False, True, False, 0, 0)],
h_traits('x', 'y', 'xs', 'ys', 1, 8, 1,1024, 1, True, False, False, 0, 0)], '8192' :[ h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 8, True, False, True, False, 0, 0),
'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 512, 4, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 2, True, False, True, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, 0, 0), h_traits('x', 'y', 'xs', 'ys', 1, 8, 1,1024, 1, True, False, True, False, 0, 0)],
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, 0, 0)]} 'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, True, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, True, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, True, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, True, 0, 0)]}
total_blob = list() total_blob = list()
for hs_key in h_trait_dict: for hs_key in h_trait_dict:
hs = h_trait_dict[hs_key] hs = h_trait_dict[hs_key]
...@@ -559,7 +568,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, ...@@ -559,7 +568,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
w_p = Path(self.working_path) w_p = Path(self.working_path)
list_p = w_p / 'layernorm2d_fwd_blobs.txt' list_p = w_p / 'layernorm2d_fwd_blobs.txt'
blobs = self.get_blobs() blobs = self.get_blobs()
with list_p.open('a') as list_f: with list_p.open('w') as list_f:
# api related file # api related file
list_f.write(str(w_p / (self.name_api + ".cpp")) + "\n") list_f.write(str(w_p / (self.name_api + ".cpp")) + "\n")
list_f.write(str(w_p / (self.name_common_header + ".hpp")) + "\n") list_f.write(str(w_p / (self.name_common_header + ".hpp")) + "\n")
......
...@@ -25,7 +25,10 @@ auto create_args(int argc, char* argv[]) ...@@ -25,7 +25,10 @@ auto create_args(int argc, char* argv[])
ck_tile::ArgParser arg_parser; ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "3328", "m dimension") arg_parser.insert("m", "3328", "m dimension")
.insert("n", "4096", "n dimension") .insert("n", "4096", "n dimension")
.insert("stride", "-1", "stride per row, if -1 then equal to n") .insert("x_stride", "-1", "x row_stride, if -1 then equal to n")
.insert("xr_stride", "-1", "x residule row_stride, if -1 then equal to n")
.insert("y_stride", "-1", "y row_stride, if -1 then equal to n")
.insert("yr_stride", "-1", "y residule row_stride, if -1 then equal to n")
.insert("e", "1e-5", "epsilon") .insert("e", "1e-5", "epsilon")
.insert("save_mv", "0", "save mean/variance(invstd) or not. set to 1 in training case") .insert("save_mv", "0", "save mean/variance(invstd) or not. set to 1 in training case")
.insert("v", "1", "cpu validation or not") .insert("v", "1", "cpu validation or not")
...@@ -54,11 +57,20 @@ template <typename InDataType, ...@@ -54,11 +57,20 @@ template <typename InDataType,
bool SaveMeanVar> bool SaveMeanVar>
bool run(const ck_tile::ArgParser& arg_parser) bool run(const ck_tile::ArgParser& arg_parser)
{ {
ck_tile::index_t m = arg_parser.get_int("m"); ck_tile::index_t m = arg_parser.get_int("m");
ck_tile::index_t n = arg_parser.get_int("n"); ck_tile::index_t n = arg_parser.get_int("n");
ck_tile::index_t stride = arg_parser.get_int("stride"); ck_tile::index_t x_stride = arg_parser.get_int("x_stride");
if(stride < 0) if(x_stride < 0)
stride = n; x_stride = n;
ck_tile::index_t xr_stride = arg_parser.get_int("xr_stride");
if(xr_stride < 0)
xr_stride = n;
ck_tile::index_t y_stride = arg_parser.get_int("y_stride");
if(y_stride < 0)
y_stride = n;
ck_tile::index_t yr_stride = arg_parser.get_int("yr_stride");
if(yr_stride < 0)
yr_stride = n;
float epsilon = arg_parser.get_float("e"); float epsilon = arg_parser.get_float("e");
std::string prec_i = arg_parser.get_str("prec_i"); std::string prec_i = arg_parser.get_str("prec_i");
std::string prec_o = arg_parser.get_str("prec_o"); std::string prec_o = arg_parser.get_str("prec_o");
...@@ -89,7 +101,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -89,7 +101,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
return false; return false;
} }
assert(stride >= n); assert(x_stride >= n);
using TypeConfig = LayerNormTypeConfig<InDataType, OutDataType, XScaleDataType, YScaleDataType>; using TypeConfig = LayerNormTypeConfig<InDataType, OutDataType, XScaleDataType, YScaleDataType>;
...@@ -108,15 +120,15 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -108,15 +120,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
using ComputeDataType = typename TypeConfig::ComputeDataType; using ComputeDataType = typename TypeConfig::ComputeDataType;
// host verify // host verify
ck_tile::HostTensor<XDataType> x_host({m, n}, {stride, 1}); ck_tile::HostTensor<XDataType> x_host({m, n}, {x_stride, 1});
ck_tile::HostTensor<GammaDataType> gamma_host({n}); ck_tile::HostTensor<GammaDataType> gamma_host({n});
ck_tile::HostTensor<BetaDataType> beta_host({n}); ck_tile::HostTensor<BetaDataType> beta_host({n});
ck_tile::HostTensor<XResidualDataType> x_residual_host({m, n}, {stride, 1}); ck_tile::HostTensor<XResidualDataType> x_residual_host({m, n}, {xr_stride, 1});
ck_tile::HostTensor<YResidualDataType> y_residual_host({m, n}, {stride, 1}); ck_tile::HostTensor<YResidualDataType> y_residual_host({m, n}, {yr_stride, 1});
ck_tile::HostTensor<YDataType> y_host_ref({m, n}, {stride, 1}); ck_tile::HostTensor<YDataType> y_host_ref({m, n}, {y_stride, 1});
ck_tile::HostTensor<YDataType> y_host_dev({m, n}, {stride, 1}); ck_tile::HostTensor<YDataType> y_host_dev({m, n}, {y_stride, 1});
ck_tile::HostTensor<MeanDataType> mean_host_ref({m}); ck_tile::HostTensor<MeanDataType> mean_host_ref({m});
ck_tile::HostTensor<InvStdDataType> invStd_host_ref({m}); ck_tile::HostTensor<InvStdDataType> invStd_host_ref({m});
...@@ -162,7 +174,9 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -162,7 +174,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
}(); }();
std::cout << "[" << prec_str << "]" std::cout << "[" << prec_str << "]"
<< " m:" << m << ", n:" << n << ", stride:" << stride << std::flush; << " m:" << m << ", n:" << n << ", x_stride:" << x_stride
<< ", xr_stride:" << xr_stride << ", y_stride:" << y_stride
<< ", yr_stride:" << yr_stride << std::flush;
layernorm2d_fwd_traits traits{ layernorm2d_fwd_traits traits{
prec_i, prec_o, prec_sx, prec_sy, SaveMeanVar, fused_add, fused_quant}; prec_i, prec_o, prec_sx, prec_sy, SaveMeanVar, fused_add, fused_quant};
...@@ -182,7 +196,10 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -182,7 +196,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
epsilon, epsilon,
m, m,
n, n,
stride}; x_stride, // x row_stride
xr_stride, // x residule row stride
y_stride, // y row stride
yr_stride}; // y residule row stride
float ave_time = layernorm2d_fwd( float ave_time = layernorm2d_fwd(
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat}); traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
...@@ -285,7 +302,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -285,7 +302,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
y_buf.FromDevice(y_host_dev.data()); y_buf.FromDevice(y_host_dev.data());
ck_tile::HostTensor<YResidualDataType> y_residual_host_dev({m, n}, {stride, 1}); ck_tile::HostTensor<YResidualDataType> y_residual_host_dev({m, n}, {yr_stride, 1});
if(fused_add == 1) if(fused_add == 1)
{ {
y_residual_buf.FromDevice(y_residual_host_dev.data()); y_residual_buf.FromDevice(y_residual_host_dev.data());
...@@ -293,7 +310,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -293,7 +310,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
auto [rtol, atol] = get_elimit<InDataType>(); auto [rtol, atol] = get_elimit<InDataType>();
if(stride == n) if(x_stride == n)
{ {
pass = ck_tile::check_err( pass = ck_tile::check_err(
y_host_dev, y_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol); y_host_dev, y_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);
...@@ -310,10 +327,10 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -310,10 +327,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
{ {
for(int i_r = 0; i_r < m; i_r++) for(int i_r = 0; i_r < m; i_r++)
{ {
std::vector<YDataType> y_host_dev_row(y_host_dev.begin() + i_r * stride, std::vector<YDataType> y_host_dev_row(y_host_dev.begin() + i_r * y_stride,
y_host_dev.begin() + i_r * stride + n); y_host_dev.begin() + i_r * y_stride + n);
std::vector<YDataType> y_host_ref_row(y_host_ref.begin() + i_r * stride, std::vector<YDataType> y_host_ref_row(y_host_ref.begin() + i_r * y_stride,
y_host_ref.begin() + i_r * stride + n); y_host_ref.begin() + i_r * y_stride + n);
pass &= ck_tile::check_err(y_host_dev_row, pass &= ck_tile::check_err(y_host_dev_row,
y_host_ref_row, y_host_ref_row,
std::string("OUT[") + std::to_string(i_r) + std::string("OUT[") + std::to_string(i_r) +
...@@ -323,10 +340,10 @@ bool run(const ck_tile::ArgParser& arg_parser) ...@@ -323,10 +340,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(fused_add == 1) if(fused_add == 1)
{ {
std::vector<YResidualDataType> y_residual_host_dev_row( std::vector<YResidualDataType> y_residual_host_dev_row(
y_residual_host_dev.begin() + i_r * stride, y_residual_host_dev.begin() + i_r * yr_stride,
y_residual_host_dev.begin() + i_r * stride + n); y_residual_host_dev.begin() + i_r * yr_stride + n);
std::vector<YResidualDataType> y_residual_host_ref_row( std::vector<YResidualDataType> y_residual_host_ref_row(
x_host.begin() + i_r * stride, x_host.begin() + i_r * stride + n); x_host.begin() + i_r * yr_stride, x_host.begin() + i_r * yr_stride + n);
pass &= ck_tile::check_err(y_residual_host_dev_row, pass &= ck_tile::check_err(y_residual_host_dev_row,
y_residual_host_ref_row, y_residual_host_ref_row,
std::string("ADD[") + std::to_string(i_r) + std::string("ADD[") + std::to_string(i_r) +
......
# moe-sorting example target: host driver (moe_sorting.cpp) + kernel dispatch
# (moe_sorting_api.cpp); EXCLUDE_FROM_ALL so it is only built on request.
add_executable(tile_example_moe_sorting EXCLUDE_FROM_ALL moe_sorting.cpp moe_sorting_api.cpp)
target_include_directories(tile_example_moe_sorting PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/)
set(EXAMPLE_MOE_SORTING_COMPILE_OPTIONS)
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
list(APPEND EXAMPLE_MOE_SORTING_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
# list(APPEND EXAMPLE_MOE_SORTING_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
target_compile_options(tile_example_moe_sorting PRIVATE ${EXAMPLE_MOE_SORTING_COMPILE_OPTIONS})
# moe-sorting
This folder contains an example for the moe-sorting kernel using the ck_tile tile-programming implementation. This kernel is often used in MoE models, before launching the fused-moe-gemm block. The input & weight is a `token*topk` 2d matrix. The op rearranges the input weight ids into different experts and feeds them into the fused moe gemm kernel.
## build
```
# in the root of ck_tile
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch> # replace <arch> with your target, e.g. gfx90a, gfx942, ...
make tile_example_moe_sorting -j
```
This will result in an executable `build/bin/tile_example_moe_sorting`
## example
```
args:
-v whether to do CPU validation or not (default:1)
-pr_i index data type. (currently only fp32 supported now) (default:int32)
-pr_w output weight data type(currently only fp32 supported now) (default:fp32)
-t number of input tokens (default:32)
-e number of experts (default:8)
-k topk (default:2)
-st_i row stride of input, -1 means same as experts (default:-1)
-seed seed to be used, -1 means random every time (default:-1)
-kname when set to 1 it will print kernel name (default:0)
```
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <set>
#include <vector>
#include <iostream>
#include <numeric>
#include <cassert>
#include <cstdlib>
#include <iostream>
#include <time.h>
#include <unordered_set>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/reduce.hpp"
#include "moe_sorting_api.hpp"
// Build the command-line parser for the moe-sorting example and parse argv.
// Returns {parse_ok, parser}; parse_ok is false on malformed arguments.
auto create_args(int argc, char* argv[])
{
    ck_tile::ArgParser arg_parser;
    // fix: help text said "weather" instead of "whether"
    arg_parser.insert("v", "1", "whether to do CPU validation or not")
        .insert("pr_i", "int32", "index data type. (currently only int32 supported now)")
        .insert("pr_w", "fp32", "output weight data type(currently only fp32 supported now)")
        .insert("t", "128", "number of input tokens")
        .insert("e", "8", "number of num_experts")
        .insert("k", "4", "topk")
        .insert("unit", "32", "unit_size")
        .insert("moe_buf_size", "0", "moe_buf_size")
        .insert("seed", "-1", "seed to be used, -1 means random every time")
        .insert("kname", "0", "when set to 1 it will print kernel name")
        .insert("warmup", "5", "number of iterations before benchmark the kernel")
        .insert("repeat", "20", "number of iterations to benchmark the kernel");

    bool result = arg_parser.parse(argc, argv);
    return std::make_tuple(result, arg_parser);
}
// Fill `host_tensor` (a flattened tokens x topk matrix) with random expert
// ids such that the `topk` ids chosen for any single token are pairwise
// distinct. Seeded via std::srand for reproducibility.
template <typename IndexType>
void topid_unique_gen(
    std::vector<IndexType>& host_tensor, int tokens, int topk, int num_expert, int seed)
{
    std::srand(seed);
    const size_t total = static_cast<size_t>(topk) * tokens;
    std::set<IndexType> seen; // expert ids already used by the current token
    for(size_t idx = 0; idx < total; idx++)
    {
        // a new token starts every `topk` entries: reset the uniqueness set
        if(idx % topk == 0)
        {
            seen.clear();
        }
        IndexType candidate;
        do
        {
            // re-draw until we find an id this token has not used yet
            candidate = std::rand() % num_expert;
        } while(seen.count(candidate) != 0);
        seen.insert(candidate);
        host_tensor[idx] = candidate;
    }
}
// Host-side driver for the moe-sorting example.
//
// Generates random (token, topk) expert ids and routing weights, uploads
// them, runs the device moe_sorting kernel, and (optionally) validates the
// sorted ids / weights / expert ids and the padded token count against a
// CPU reference implementation.
//
// Template parameters:
//   WeightType - element type of the routing weights (e.g. float)
//   IndexType  - element type of the expert ids (default ck_tile::index_t)
//
// Returns true when the kernel ran and (if validation is enabled) all
// outputs matched the reference; false on unsupported configs or mismatch.
template <typename WeightType, typename IndexType = ck_tile::index_t>
bool test_moe_sorting(ck_tile::ArgParser args)
{
    // ---- parse cli options ----
    int validate = args.get_int("v");
    std::string index_prec = args.get_str("pr_i");
    std::string weight_prec = args.get_str("pr_w");
    int tokens = args.get_int("t");
    int num_experts = args.get_int("e");
    int topk = args.get_int("k");
    int seed = args.get_int("seed");
    int unit_size = args.get_int("unit");
    int moe_buf_size = args.get_int("moe_buf_size");
    int kname = args.get_int("kname");
    int warmup = args.get_int("warmup");
    int repeat = args.get_int("repeat");
    // worst-case length of the sorted output when every expert's id list is
    // padded up to a multiple of unit_size
    int max_output_ids =
        ck_tile::integer_least_multiple(topk * tokens + num_experts * unit_size - topk, unit_size);
    if(seed < 0)
    {
        // seed == -1 means "randomize on every run"
        seed = std::time(nullptr);
    }
    if(topk > num_experts)
    {
        printf("topk:%d value should be smaller than, or equal to number of num_experts:%d\n",
               topk,
               num_experts);
        return false;
    }
    // tokens already considered batch size
    ck_tile::HostTensor<IndexType> topk_ids_host({tokens, topk}, {topk, 1});
    ck_tile::HostTensor<WeightType> weights_host({tokens, topk}, {topk, 1});
    ck_tile::HostTensor<IndexType> sorted_ids_host({max_output_ids}, {1});
    ck_tile::HostTensor<WeightType> sorted_weights_host({max_output_ids}, {1});
    ck_tile::HostTensor<IndexType> sorted_expert_ids_host({max_output_ids / unit_size}, {1});
    ck_tile::HostTensor<IndexType> sorted_id_cnt_host({1}, {1});
    ck_tile::HostTensor<float> moe_buf_host({moe_buf_size});
    // random weights; expert ids must be unique per token (a token cannot
    // route to the same expert twice)
    ck_tile::FillUniformDistribution<WeightType>{-.5f, .5f}(weights_host);
    ck_tile::FillUniformDistribution<WeightType>{-.5f, .5f}(moe_buf_host);
    topid_unique_gen<IndexType>(topk_ids_host.mData, tokens, topk, num_experts, seed);
    // ---- allocate device buffers and upload the inputs ----
    ck_tile::DeviceMem topk_ids_dev(topk_ids_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem weights_dev(weights_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem sorted_ids_dev(sorted_ids_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem sorted_weights_dev(sorted_weights_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem sorted_expert_ids_dev(
        sorted_expert_ids_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem sorted_id_cnt_dev(sorted_id_cnt_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem moe_buf_dev(moe_buf_host.get_element_space_size_in_bytes());
    topk_ids_dev.ToDevice(topk_ids_host.data());
    weights_dev.ToDevice(weights_host.data());
    if(moe_buf_size > 0)
    {
        moe_buf_dev.ToDevice(moe_buf_host.data());
    }
    // ---- build kernel arguments and launch ----
    moe_sorting_trait trait{index_prec, weight_prec};
    // side buffer pointer is null when no zero-fill buffer was requested
    moe_sorting_args karg{topk_ids_dev.GetDeviceBuffer(),
                          weights_dev.GetDeviceBuffer(),
                          sorted_ids_dev.GetDeviceBuffer(),
                          sorted_weights_dev.GetDeviceBuffer(),
                          sorted_expert_ids_dev.GetDeviceBuffer(),
                          sorted_id_cnt_dev.GetDeviceBuffer(),
                          moe_buf_size > 0 ? moe_buf_dev.GetDeviceBuffer() : nullptr,
                          tokens,
                          unit_size,
                          num_experts,
                          topk,
                          static_cast<ck_tile::index_t>(moe_buf_size * sizeof(float))};
    ck_tile::stream_config sc{nullptr,
                              true,
                              /* log_level = */ (kname ? 1 : 0),
                              warmup,
                              repeat};
    auto ms = moe_sorting(trait, karg, sc);
    printf("[%s|%s]tokens:%d, num_experts:%d, topk:%d, ms:%f , ",
           index_prec.c_str(),
           weight_prec.c_str(),
           tokens,
           num_experts,
           topk,
           ms);
    // a negative time signals an unsupported configuration
    if(ms < 0)
        printf("not supported\n");
    fflush(stdout);
    if(ms < 0)
    {
        return false;
    }
    // ---- download kernel results ----
    sorted_ids_dev.FromDevice(sorted_ids_host.data());
    sorted_weights_dev.FromDevice(sorted_weights_host.data());
    sorted_expert_ids_dev.FromDevice(sorted_expert_ids_host.data());
    sorted_id_cnt_dev.FromDevice(sorted_id_cnt_host.data());
    if(moe_buf_size > 0)
    {
        moe_buf_dev.FromDevice(moe_buf_host.data());
    }
    bool rtn = true;
    if(validate)
    {
        // ---- compute CPU reference and compare every output ----
        ck_tile::HostTensor<IndexType> sorted_ids_ref({max_output_ids}, {1});
        ck_tile::HostTensor<WeightType> sorted_weights_ref({max_output_ids}, {1});
        ck_tile::HostTensor<IndexType> sorted_expert_ids_ref({max_output_ids / unit_size}, {1});
        int32_t ref_total_tokens_post_pad = 0;
        ck_tile::reference_moe_sorting<WeightType, IndexType>(topk_ids_host,
                                                              weights_host,
                                                              sorted_ids_ref,
                                                              sorted_weights_ref,
                                                              sorted_expert_ids_ref,
                                                              ref_total_tokens_post_pad,
                                                              num_experts,
                                                              unit_size);
        rtn &= ck_tile::check_err(
            sorted_ids_host, sorted_ids_ref, std::string("OUT Error: Incorrect ids!"), 1e-6, 1e-6);
        rtn &= ck_tile::check_err(sorted_weights_host,
                                  sorted_weights_ref,
                                  std::string("OUT Error: Incorrect w!"),
                                  1e-6,
                                  1e-6);
        rtn &= ck_tile::check_err(sorted_expert_ids_host,
                                  sorted_expert_ids_ref,
                                  std::string("OUT Error: Incorrect eid!"),
                                  1e-6,
                                  1e-6);
        if(moe_buf_size)
        {
            // side buffer is expected to come back zero-filled; compare
            // against a freshly constructed reference with zero tolerance
            ck_tile::HostTensor<WeightType> moe_buf_ref({moe_buf_size});
            rtn &= ck_tile::check_err(
                moe_buf_host, moe_buf_ref, std::string("OUT Error: Incorrect zero buf!"), 0, 0);
        }
        // the total padded token count must match the reference as well
        rtn &= ref_total_tokens_post_pad == sorted_id_cnt_host.mData[0];
    }
    printf("valid:%s\n", rtn ? "y" : "n");
    fflush(stdout);
    return rtn;
}
// Entry point: parse CLI arguments, then run the moe-sorting example for
// the requested precision pair (currently fp32 weights + int32 indices).
int main(int argc, char** argv)
{
    const auto [parsed_ok, parser] = create_args(argc, argv);
    if(!parsed_ok)
    {
        return -1;
    }

    const std::string idx_type    = parser.get_str("pr_i");
    const std::string weight_type = parser.get_str("pr_w");

    bool ok = true;
    if(weight_type.compare("fp32") == 0 && idx_type.compare("int32") == 0)
    {
        ok &= test_moe_sorting<float, ck_tile::index_t>(parser);
    }
    return ok ? 0 : -1;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "moe_sorting_api.hpp"
// Instantiate and launch the moe-sorting kernel for a compile-time unroll
// factor `unroll_num_`, then `return` the measured average time. Expects
// `index_t`, `ms_weight_type`, the host args `a`, and the stream config `s`
// to be in scope at the expansion site (see moe_sorting() below).
#define MOE_SORTING_DISPATCH(unroll_num_) \
    constexpr ck_tile::index_t unroll_num = unroll_num_; \
    using ms_problem = ck_tile::MoeSortingProblem<index_t, ms_weight_type, unroll_num>; \
    using kernel = ck_tile::MoeSortingKernel<ms_problem>; \
    auto kargs = kernel::MakeKargs(a); \
    const dim3 grids = kernel::GridSize(a); \
    const dim3 blocks = kernel::BlockSize(a); \
    const auto lds_bytes = kernel::GetSmemSize(a); \
    float ave_time = ck_tile::launch_kernel( \
        s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \
    return ave_time;
// Dispatch the moe-sorting kernel for the requested precisions.
// Returns the average kernel time in ms, or -1 when the configuration is
// unsupported (wrong types, too many experts, or a misaligned buffer size).
float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s)
{
    if(t.weight_type == "fp32" && t.index_type == "int32")
    {
        if(a.num_experts > 127)
        {
            // fix: message previously said "<127" but the guard above
            // accepts up to and including 127 experts
            printf("lds size exceed, only support experts <= 127\n");
            return -1;
        }
        if(a.moe_buf_bytes % 16)
        {
            printf("buf set size %d unaligned, must be multiple of 16\n", a.moe_buf_bytes);
            return -1;
        }
        using index_t        = ck_tile::index_t;
        using ms_weight_type = float;
        // choose the compile-time smem io unroll factor from the workload
        // size; every MOE_SORTING_DISPATCH expands to a launch + return, so
        // no break statements are needed
        index_t smem_io_unroll_num = ck_tile::integer_divide_ceil(a.tokens * a.topk, 64);
        switch(smem_io_unroll_num)
        {
        case(1): {
            MOE_SORTING_DISPATCH(1);
        }
        case(2): {
            MOE_SORTING_DISPATCH(2);
        }
        case(3): {
            MOE_SORTING_DISPATCH(3);
        }
        case(5): {
            MOE_SORTING_DISPATCH(5);
        }
        case(6): {
            MOE_SORTING_DISPATCH(6);
        }
        case(7): {
            MOE_SORTING_DISPATCH(7);
        }
        case(8): {
            MOE_SORTING_DISPATCH(8);
        }
        case(9): {
            MOE_SORTING_DISPATCH(9);
        }
        case(10): {
            MOE_SORTING_DISPATCH(10);
        }
        case(11): {
            MOE_SORTING_DISPATCH(11);
        }
        default: {
            // 4 doubles as both the case-4 path and the fallback for any
            // workload larger than the enumerated factors
            MOE_SORTING_DISPATCH(4);
        }
        }
    }
    // unsupported precision combination
    return -1;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/moe_sorting.hpp"
// Precision selectors used by moe_sorting() to pick a kernel instantiation.
struct moe_sorting_trait
{
    std::string index_type;
    std::string weight_type; // currently always float
};
// Host-side kernel arguments; all fields come unchanged from
// ck_tile::MoeSortingHostArgs.
struct moe_sorting_args : public ck_tile::MoeSortingHostArgs
{
};
// Launch the moe-sorting kernel. Returns the average time in ms, or -1 when
// the requested configuration is not supported.
float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s);
#!/bin/sh
# Smoke tests for the moe-sorting example: odd token/expert/topk
# combinations plus several zero-fill side-buffer sizes.
# (fix: the shebang was previously commented out, so the script could be
# run under whatever shell invoked it rather than /bin/sh.)
EXE=./build/bin/tile_example_moe_sorting

$EXE -t=80 -e=17 -moe_buf_size=16
$EXE -t=111 -e=117 -moe_buf_size=4
$EXE -t=1000 -e=55 -moe_buf_size=1024
$EXE -t=99 -e=120 -moe_buf_size=10244
$EXE -t=175 -e=64 -k=8
$EXE -t=65 -e=8 -k=2
$EXE -t=1 -e=25
$EXE -t=31 -e=19 -k=15
$EXE -t=81 -e=37 -k=7
$EXE -t=23 -e=1 -k=1
$EXE -t=127 -e=99 -k=19
$EXE -t=71 -e=11 -k=11
$EXE -t=1 -e=1 -k=1
$EXE -t=99 -e=2 -k=1
$EXE -t=333 -e=99 -k=13
\ No newline at end of file
...@@ -12,3 +12,4 @@ add_subdirectory(09_topk_softmax) ...@@ -12,3 +12,4 @@ add_subdirectory(09_topk_softmax)
add_subdirectory(10_rmsnorm2d) add_subdirectory(10_rmsnorm2d)
add_subdirectory(11_add_rmsnorm2d_rdquant) add_subdirectory(11_add_rmsnorm2d_rdquant)
add_subdirectory(12_smoothquant) add_subdirectory(12_smoothquant)
add_subdirectory(13_moe_sorting)
...@@ -63,13 +63,15 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) ...@@ -63,13 +63,15 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
#define __gfx101__ #define __gfx101__
#endif #endif
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || \ #if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || \
defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || \
defined(__gfx10_3_generic__)
#define __gfx103__ #define __gfx103__
#endif #endif
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || \
defined(__gfx1103__) || defined(__gfx11_generic__)
#define __gfx11__ #define __gfx11__
#endif #endif
#if defined(__gfx1200__) || defined(__gfx1201__) #if defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__)
#define __gfx12__ #define __gfx12__
#endif #endif
......
...@@ -93,12 +93,12 @@ __global__ void ...@@ -93,12 +93,12 @@ __global__ void
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset = const long_index_t a_batch_offset = amd_wave_read_first_lane(
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
const long_index_t b_batch_offset = const long_index_t b_batch_offset = amd_wave_read_first_lane(
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
const long_index_t e_batch_offset = const long_index_t e_batch_offset = amd_wave_read_first_lane(
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
......
...@@ -60,12 +60,12 @@ __global__ void ...@@ -60,12 +60,12 @@ __global__ void
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge); const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge);
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
const long_index_t a_batch_offset = const long_index_t a_batch_offset = amd_wave_read_first_lane(
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
const long_index_t b_batch_offset = const long_index_t b_batch_offset = amd_wave_read_first_lane(
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
const long_index_t e_batch_offset = const long_index_t e_batch_offset = amd_wave_read_first_lane(
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
...@@ -117,12 +117,12 @@ __global__ void ...@@ -117,12 +117,12 @@ __global__ void
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge); const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge);
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
const long_index_t a_batch_offset = const long_index_t a_batch_offset = amd_wave_read_first_lane(
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
const long_index_t b_batch_offset = const long_index_t b_batch_offset = amd_wave_read_first_lane(
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
const long_index_t e_batch_offset = const long_index_t e_batch_offset = amd_wave_read_first_lane(
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
// Pass two lds pointer is the key to tell compiler that ds_read/write // Pass two lds pointer is the key to tell compiler that ds_read/write
// operate on different lds chunk at same time without order dependecy // operate on different lds chunk at same time without order dependecy
......
...@@ -98,12 +98,12 @@ __global__ void ...@@ -98,12 +98,12 @@ __global__ void
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset = const long_index_t a_batch_offset = amd_wave_read_first_lane(
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
const long_index_t b_batch_offset = const long_index_t b_batch_offset = amd_wave_read_first_lane(
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
const long_index_t c_batch_offset = const long_index_t c_batch_offset = amd_wave_read_first_lane(
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
......
...@@ -60,12 +60,12 @@ __global__ void ...@@ -60,12 +60,12 @@ __global__ void
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset = const long_index_t a_batch_offset = amd_wave_read_first_lane(
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
const long_index_t b_batch_offset = const long_index_t b_batch_offset = amd_wave_read_first_lane(
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
const long_index_t e_batch_offset = const long_index_t e_batch_offset = amd_wave_read_first_lane(
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
...@@ -155,12 +155,12 @@ __global__ void ...@@ -155,12 +155,12 @@ __global__ void
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset = const long_index_t a_batch_offset = amd_wave_read_first_lane(
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
const long_index_t b_batch_offset = const long_index_t b_batch_offset = amd_wave_read_first_lane(
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
const long_index_t e_batch_offset = const long_index_t e_batch_offset = amd_wave_read_first_lane(
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment