Unverified Commit 7d50244e authored by Illia Silin's avatar Illia Silin Committed by GitHub
Browse files

Merge pull request #209 from ROCm/andriy/merge_from_public

Update develop branch from public repository
parents f221c2b0 d51701d4
#include "matrix_core_swizzle.hpp"
#include "matrix_core_swizzle_kernel.hpp"
float matrix_core_swizzle(matrix_core_swizzle_traits t,
matrix_core_swizzle_args a,
const ck_tile::stream_config& s)
{
if(t.data_type.compare("fp16") == 0)
{
if(t.inst.compare("32x32x8") == 0)
{
constexpr int BLOCK_SIZE = 256;
constexpr int NPerBlock = 256;
constexpr int KPerBlock = 128;
constexpr matrix_core_inst_enum Inst = matrix_core_inst_enum::MFMA_32x32x8_F16;
if(t.permute.compare("0,1,4,2,5,3,6") == 0)
{
constexpr matrix_core_permute_style pstyle =
matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2;
using Kernel =
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
auto k = Kernel(a);
float ave_time = ck_tile::launch_kernel(s, k);
return ave_time;
}
else if(t.permute.compare("0,1,2,4,5,3,6") == 0)
{
constexpr matrix_core_permute_style pstyle =
matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2;
using Kernel =
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
auto k = Kernel(a);
float ave_time = ck_tile::launch_kernel(s, k);
return ave_time;
}
else if(t.permute.compare("0,1,3,4,2,5") == 0)
{
constexpr matrix_core_permute_style pstyle =
matrix_core_permute_style::permute_b_nr_kr_kw_nw_kv;
using Kernel =
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
auto k = Kernel(a);
float ave_time = ck_tile::launch_kernel(s, k);
return ave_time;
}
}
else if(t.inst.compare("16x16x16") == 0)
{
constexpr int BLOCK_SIZE = 256;
constexpr int NPerBlock = 256;
constexpr int KPerBlock = 128;
constexpr matrix_core_inst_enum Inst = matrix_core_inst_enum::MFMA_16x16x16_F16;
if(t.permute.compare("0,1,4,2,5,3,6") == 0)
{
constexpr matrix_core_permute_style pstyle =
matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2;
using Kernel =
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
auto k = Kernel(a);
float ave_time = ck_tile::launch_kernel(s, k);
return ave_time;
}
else if(t.permute.compare("0,1,2,4,5,3,6") == 0)
{
constexpr matrix_core_permute_style pstyle =
matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2;
using Kernel =
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
auto k = Kernel(a);
float ave_time = ck_tile::launch_kernel(s, k);
return ave_time;
}
else if(t.permute.compare("0,1,3,4,2,5") == 0)
{
constexpr matrix_core_permute_style pstyle =
matrix_core_permute_style::permute_b_nr_kr_kw_nw_kv;
using Kernel =
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
auto k = Kernel(a);
float ave_time = ck_tile::launch_kernel(s, k);
return ave_time;
}
}
}
return -1;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "matrix_core_swizzle_kernel.hpp"
#include <string>
// Runtime selector for the matrix-core swizzle host API. Every field is a string
// that the host dispatcher compares against a fixed set of accepted values.
struct matrix_core_swizzle_traits
{
std::string data_type; // fp16 only
std::string inst; // 32x32x8, 16x16x16
std::string permute; // comma-separated dim order: 0,1,4,2,5,3,6 / 0,1,2,4,5,3,6 / 0,1,3,4,2,5
};
// Argument bundle of the host API; an alias of the kernel's template-free host-args struct.
using matrix_core_swizzle_args = matrix_core_swizzle_host_args;
// host API
// Picks a kernel instance from the traits strings and launches it with the given
// arguments and stream config. Returns the value reported by ck_tile::launch_kernel,
// or -1 when the traits do not match any supported kernel instance.
float matrix_core_swizzle(matrix_core_swizzle_traits,
matrix_core_swizzle_args,
const ck_tile::stream_config&);
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm.hpp"
// Compile-time switch for the permute_b_nr_kr_kw_nw_kv path of the swizzle kernel:
//   0 (default): the destination is viewed as a 3-d (nr, kr, waveflatten) tensor
//   1          : the destination is viewed as a 5-d tensor merged down to 2-d (n, k)
// if set to 1, slightly more instructions are generated to calculate the address
#ifndef MERGE_2D_013425
#define MERGE_2D_013425 0
#endif
// MFMA instruction shapes the swizzle kernel can be instantiated for.
// Enumerators are numbered implicitly from zero.
enum class matrix_core_inst_enum
{
    MFMA_32x32x8_F16,  // value 0
    MFMA_16x16x16_F16, // value 1
};
namespace detail {
// Compile-time map from a matrix_core_inst_enum value to a ck_tile warp-gemm type.
// The primary template is only declared, so asking for ::type with an enum value
// that has no specialization below is a compile error.
template <matrix_core_inst_enum>
struct to_warp_gemm;
// 32x32x8 fp16 MFMA
template <>
struct to_warp_gemm<matrix_core_inst_enum::MFMA_32x32x8_F16>
{
using type = ck_tile::WarpGemmMfmaF16F16F32M32N32K8;
};
// 16x16x16 fp16 MFMA
template <>
struct to_warp_gemm<matrix_core_inst_enum::MFMA_16x16x16_F16>
{
using type = ck_tile::WarpGemmMfmaF16F16F32M16N16K16;
};
} // namespace detail
// Convenience alias: the warp-gemm type selected for Inst.
template <matrix_core_inst_enum Inst>
using to_warp_gemm_t = typename detail::to_warp_gemm<Inst>::type;
// Destination layouts supported by the swizzle kernel. The trailing comment on each
// enumerator gives its numeric value and the permutation string that selects it.
// TODO: in the permute patterns below, the last 3 dims are within a wave
enum class matrix_core_permute_style
{
    permute_b_n0_k0_n1_k1_n2_k2, // value 0, permute 0,1,4,2,5,3,6
    permute_b_n0_n1_k0_k1_n2_k2, // value 1, permute 0,1,2,4,5,3,6
    permute_b_nr_kr_kw_nw_kv,    // value 2, permute 0,1,3,4,2,5
    // alternative name for the same layout (same value as permute_b_nr_kr_kw_nw_kv)
    permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv,
};
// assume this is B matrix, originally we have batch*n*k
// now batch* n0*n1*n2*k0*k1*k2 -> batch* n0*k0*n1*k1*n2*k2
// assume using 32x32x8-f16, 4 waves and extend the KPerLane to 8xfp16(dwordx4)
//
// 4(waves) 32(mfma_m lane)
// | |
// batch* n0*n1*n2*k0*k1*k2 -> batch* n0*k0*n1*k1*n2*k2 -> 8(thread loading)
// nr kr |
// nr 4 32 kr 2 8 2(klane)
//
// permute: 0,1,4,2,5,3,6
// or
// batch* n0*n1*n2*k0*k1*k2 -> batch* n0*n1*k0*k1*n2*k2 -> 8(thread loading)
// permute: 0,1,2,4,5,3,6
//
// this kernel only deals with fp16/bf16 data (16-bit), and uses a 2d block size to do the swizzling
// for simplicity, only n/k that are multiples of the block size are considered
// independent host args with no template parameters
// Plain argument bundle handed from the host API to the swizzle kernel.
// The kernel reinterprets both pointers as fp16 and offsets them by
// batch_index * k * n elements, i.e. the data is batch * n * k elements.
struct matrix_core_swizzle_host_args
{
const void* p_src; // source tensor (read through a global-address-space view)
void* p_dst; // destination tensor, written in the permuted layout
int32_t batch; // number of batches; becomes the grid's z extent
int32_t n; // N extent of one batch
int32_t k; // K extent of one batch
};
// NOTE: this kernel could follow the style of generic permute kernel
// but here we pass in fixed layout as template arg and generate different kernel instance
// purposely
// Swizzle (layout permutation) kernel for a batch*n*k 16-bit tensor.
//
// Template parameters:
//   BLOCK_SIZE_ : threads per block (must equal WavesPerBlock_N * WavesPerBlock_K * 64)
//   NPerBlock_  : N extent of the tile handled by one block
//   KPerBlock_  : K extent of the tile handled by one block
//   pstyle_     : destination layout, see matrix_core_permute_style
//   Inst_       : MFMA instruction whose lane layout defines the innermost dims
//
// The object is built on the host from the host args (which also fixes the launch
// grid) and is itself the callable that launches the nested device functor `kernel`.
//
// NOTE: this kernel could follow the style of generic permute kernel
// but here we pass in fixed layout as template arg and generate different kernel instance
// purposely
template <int BLOCK_SIZE_ = 256,
          int NPerBlock_  = 256,
          int KPerBlock_  = 128,
          matrix_core_permute_style pstyle_ =
              matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2,
          matrix_core_inst_enum Inst_ = matrix_core_inst_enum::MFMA_32x32x8_F16>
struct matrix_core_swizzle_kernel
{
    using karg = matrix_core_swizzle_host_args;
    using harg = matrix_core_swizzle_host_args;

    static constexpr int BLOCK_SIZE      = BLOCK_SIZE_;
    static constexpr int WavesPerBlock_N = 4;
    static constexpr int WavesPerBlock_K = 1;
    static_assert(WavesPerBlock_N * WavesPerBlock_K * 64 == BLOCK_SIZE);

    static constexpr int NPerBlock                    = NPerBlock_;
    static constexpr int KPerBlock                    = KPerBlock_;
    static constexpr matrix_core_permute_style pstyle = pstyle_;
    static constexpr matrix_core_inst_enum Inst       = Inst_;

    // innermost K extent handled per thread; also passed to the tensor views to
    // control the vector access width
    static constexpr ck_tile::index_t Alignment = 8;

    karg a;     // kernel arguments, copied from the host args
    dim3 grids; // launch grid: x = K tiles, y = N tiles, z = batch

    using WarpGemm = to_warp_gemm_t<Inst>;

    // Stores the arguments and derives the launch grid (ceil-div of n/k by the tile sizes).
    __host__ matrix_core_swizzle_kernel(harg h)
    {
        a                   = h;
        ck_tile::index_t ns = (h.n + NPerBlock - 1) / NPerBlock;
        ck_tile::index_t ks = (h.k + KPerBlock - 1) / KPerBlock;
        grids               = dim3(ks, ns, h.batch);
    }

    // True when n and k of `h` are exact multiples of the tile sizes.
    // Does not touch the stored state, hence const.
    __host__ bool is_applicable(harg h) const
    {
        return h.n % NPerBlock == 0 && h.k % KPerBlock == 0;
    }

    // Launches the device functor on the stream carried by `s`.
    __host__ void operator()(const ck_tile::stream_config& s) const
    {
        ck_tile::kentry<BLOCK_SIZE, 1, kernel><<<grids, BLOCK_SIZE, 0, s.stream_id_>>>(a);
    }

    // Device-side functor executed by every block.
    struct kernel
    {
        // Tile distribution used to load the source tile, laid out as
        // (N0, N1, N2, K0, K1, K2).
        __device__ static constexpr auto get_src_dist()
        {
            using namespace ck_tile;
            constexpr index_t K2 = Alignment;
            constexpr index_t N2 = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
            constexpr index_t K1 = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
            constexpr index_t N1 = BLOCK_SIZE / get_warp_size();
            static_assert(NPerBlock % (N1 * N2) == 0);
            static_assert(KPerBlock % (K1 * K2) == 0);
            constexpr index_t K0 = KPerBlock / (K1 * K2);
            constexpr index_t N0 = NPerBlock / (N1 * N2);

            // clang-format off
            return make_static_tile_distribution(
                tile_distribution_encoding<
                    sequence<1>,// 0
                    //             1             2             3             4             5             6
                    tuple<sequence<N0>, sequence<N1>, sequence<N2>, sequence<K0>, sequence<K1>, sequence<K2>>,
                    //            N1           K1  N2
                    tuple<sequence<2>, sequence<5, 3>>,
                    tuple<sequence<0>, sequence<0, 0>>,
                    //       N0 K0 K2
                    sequence<1, 4, 6>,
                    sequence<0, 0, 0>>{});
            // clang-format on
        }

        // Tile distribution used to store the destination tile; the dimension order
        // depends on pstyle.
        __device__ static constexpr auto get_dst_dist()
        {
            using namespace ck_tile;
            constexpr index_t K2 = Alignment;
            constexpr index_t N2 = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
            constexpr index_t K1 = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
            constexpr index_t N1 = BLOCK_SIZE / get_warp_size();
            static_assert(NPerBlock % (N1 * N2) == 0);
            static_assert(KPerBlock % (K1 * K2) == 0);
            constexpr index_t K0 = KPerBlock / (K1 * K2);
            constexpr index_t N0 = NPerBlock / (N1 * N2);

            if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2)
            {
                // clang-format off
                return make_static_tile_distribution(
                    tile_distribution_encoding<
                        sequence<1>,// 0
                        //             1             2             3             4             5             6
                        tuple<sequence<N0>, sequence<K0>, sequence<N1>, sequence<K1>, sequence<N2>, sequence<K2>>,
                        //            N1           K1  N2
                        tuple<sequence<3>, sequence<4, 5>>,
                        tuple<sequence<0>, sequence<0, 0>>,
                        //       N0 K0 K2
                        sequence<1, 2, 6>,
                        sequence<0, 0, 0>>{});
                // clang-format on
            }
            else if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2)
            {
                // clang-format off
                return make_static_tile_distribution(
                    tile_distribution_encoding<
                        sequence<1>,// 0
                        //             1             2             3             4             5             6
                        tuple<sequence<N0>, sequence<N1>, sequence<K0>, sequence<K1>, sequence<N2>, sequence<K2>>,
                        //            N1           K1  N2
                        tuple<sequence<2>, sequence<4, 5>>,
                        tuple<sequence<0>, sequence<0, 0>>,
                        //       N0 K0 K2
                        sequence<1, 3, 6>,
                        sequence<0, 0, 0>>{});
                // clang-format on
            }
            else
            {
                // clang-format off
                // permute_b_nr_kr_kw_nw_kv or permute_b_nr_kr_waveflatten
                constexpr index_t Kv = Alignment;
                constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
                constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
                static_assert(KPerBlock % (K1 * K2) == 0);
                constexpr index_t Nr = NPerBlock / Nw;
                constexpr index_t Kr = KPerBlock / (Kv * Kw);
                constexpr index_t Nr_p = WavesPerBlock_N;
                constexpr index_t Kr_p = WavesPerBlock_K;
                constexpr index_t Nr_y = Nr / Nr_p;
                constexpr index_t Kr_y = Kr / Kr_p;
                return make_static_tile_distribution(
#if MERGE_2D_013425
                    tile_distribution_encoding<
                        sequence<1>,// 0 R
                        // major       1                          2
                        // minor       0     1     2              0     1     2   3
                        tuple<sequence<Nr_y, Nr_p, Nw>, sequence<Kr_y, Kr_p, Kw, Kv>>, // H
                        //            Nr_p, Kr_p     Kw Nw
                        tuple<sequence<1 , 2>, sequence<2, 1>>, // p major
                        tuple<sequence<1 , 1>, sequence<2, 2>>, // p minor
                        //       Nr_y Kr_y Kv
                        sequence<1, 2, 2>, // Y major
                        sequence<0, 0, 3>>{}); // y minor
#else
                    tile_distribution_encoding<
                        sequence<1>,// 0 R
                        // major       1                     2                     3
                        // minor       0     1               0     1               0   1   2
                        tuple<sequence<Nr_y, Nr_p>, sequence<Kr_y, Kr_p>, sequence<Kw, Nw, Kv>>, // H
                        //            Nr_p, Kr_p     Kw Nw
                        tuple<sequence<1 , 2>, sequence<3, 3>>, // p major
                        tuple<sequence<1 , 1>, sequence<0, 1>>, // p minor
                        //       Nr_y Kr_y Kv
                        sequence<1, 2, 3>, // Y major
                        sequence<0, 0, 2>>{}); // y minor
#endif
                // clang-format on
            }
        }

        // Per-block body: loads one NPerBlock x KPerBlock tile of batch blockIdx.z from
        // the source layout and stores it through the destination layout chosen by pstyle.
        __device__ void operator()(karg a_)
        {
            using namespace ck_tile;
            index_t i_k = blockIdx.x;
            index_t i_n = blockIdx.y;
            index_t i_b = blockIdx.z;

            // whole-tensor decomposition of n and k (outermost extents are runtime values)
            constexpr index_t k2 = Alignment;
            constexpr index_t n2 = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
            constexpr index_t k1 = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
            constexpr index_t n1 = BLOCK_SIZE / get_warp_size();
            const index_t k0     = a_.k / (k1 * k2);
            const index_t n0     = a_.n / (n1 * n2);

            // per-block tile decomposition (all compile-time)
            constexpr index_t k2_tile = Alignment;
            constexpr index_t n2_tile = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
            constexpr index_t k1_tile = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
            constexpr index_t n1_tile = BLOCK_SIZE / get_warp_size();
            constexpr index_t k0_tile = KPerBlock / (k1_tile * k2_tile);
            constexpr index_t n0_tile = NPerBlock / (n1_tile * n2_tile);

            // The batch offset is computed in 64 bits: i_b * k * n evaluated in 32-bit
            // index_t overflows once the tensor holds 2^31 or more elements.
            const int64_t batch_offset = static_cast<int64_t>(i_b) * a_.k * a_.n;

            const fp16_t* p_src = reinterpret_cast<const fp16_t*>(a_.p_src) + batch_offset;
            fp16_t* p_dst       = reinterpret_cast<fp16_t*>(a_.p_dst) + batch_offset;

            const auto src_view = [&]() {
                const auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
                    p_src,
                    make_tuple(n0, n1, n2, k0, k1, k2),
                    number<Alignment>{}); // control vector load
                return tmp;
            }();

            const auto src_window = make_tile_window(src_view,
                                                     make_tuple(number<n0_tile>{},
                                                                number<n1_tile>{},
                                                                number<n2_tile>{},
                                                                number<k0_tile>{},
                                                                number<k1_tile>{},
                                                                number<k2_tile>{}),
                                                     {i_n * n0_tile, 0, 0, i_k * k0_tile, 0, 0},
                                                     get_src_dist());

            auto dst_view = [&]() {
                if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2)
                {
                    auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
                        p_dst,
                        make_tuple(n0, k0, n1, k1, n2, k2),
                        number<Alignment>{}); // control vector load
                    return tmp;
                }
                else if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2)
                {
                    auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
                        p_dst,
                        make_tuple(n0, n1, k0, k1, n2, k2),
                        number<Alignment>{}); // control vector load
                    return tmp;
                }
                else
                {
#if MERGE_2D_013425
                    constexpr index_t kv = Alignment;
                    constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
                    constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
                    // constexpr index_t waveflatten = kw*nw*kv;
                    const index_t kr = a_.k / (k1 * k2);
                    const index_t nr = a_.n / nw;
                    auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
                        p_dst,
                        make_tuple(nr, kr, number<kw>{}, number<nw>{}, number<kv>{}),
                        number<Alignment>{}); // control vector load
                    // merge (nr, nw) -> n and (kr, kw, kv) -> k to expose a 2-d view
                    auto tmp_1 = transform_tensor_view(
                        tmp,
                        make_tuple(
                            make_merge_transform(make_tuple(nr, number<nw>{})),
                            make_merge_transform(make_tuple(kr, number<kw>{}, number<kv>{}))),
                        make_tuple(sequence<0, 3>{}, sequence<1, 2, 4>{}),
                        make_tuple(sequence<0>{}, sequence<1>{}));
                    return tmp_1;
#else
                    // permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv,
                    constexpr index_t kv          = Alignment;
                    constexpr index_t nw          = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
                    constexpr index_t kw          = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
                    constexpr index_t waveflatten = kw * nw * kv;
                    const index_t kr              = a_.k / (k1 * k2);
                    const index_t nr              = a_.n / nw;
                    auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
                        p_dst,
                        make_tuple(nr, kr, waveflatten),
                        number<Alignment>{}); // control vector load
                    return tmp;
#endif
                }
            }();

            auto dst_window = [&]() {
                if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2)
                {
                    return make_tile_window(dst_view,
                                            make_tuple(number<n0_tile>{},
                                                       number<k0_tile>{},
                                                       number<n1_tile>{},
                                                       number<k1_tile>{},
                                                       number<n2_tile>{},
                                                       number<k2_tile>{}),
                                            {i_n * n0_tile, i_k * k0_tile, 0, 0, 0, 0},
                                            get_dst_dist());
                }
                else if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2)
                {
                    return make_tile_window(dst_view,
                                            make_tuple(number<n0_tile>{},
                                                       number<n1_tile>{},
                                                       number<k0_tile>{},
                                                       number<k1_tile>{},
                                                       number<n2_tile>{},
                                                       number<k2_tile>{}),
                                            {i_n * n0_tile, 0, i_k * k0_tile, 0, 0, 0},
                                            get_dst_dist());
                }
                else
                {
#if MERGE_2D_013425
                    // permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv
                    return make_tile_window(dst_view,
                                            make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
                                            {i_n * NPerBlock, i_k * KPerBlock},
                                            get_dst_dist());
#else
                    // permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv
                    constexpr index_t kv = Alignment;
                    constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
                    constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
                    constexpr index_t waveflatten_tile = kw * nw * kv;
                    constexpr index_t nr_tile          = NPerBlock / nw;
                    constexpr index_t kr_tile          = KPerBlock / (kw * kv);
                    return make_tile_window(dst_view,
                                            make_tuple(number<nr_tile>{},
                                                       number<kr_tile>{},
                                                       number<waveflatten_tile>{}),
                                            {i_n * nr_tile, i_k * kr_tile, 0},
                                            get_dst_dist());
#endif
                }
            }();

            // actual load store
            auto src_tile = load_tile(src_window);

            // now we only swap the distribution from src to dst, no extra movement occurs
            auto dst_tile = make_static_distributed_tensor<fp16_t>(get_dst_dist());
            dst_tile.get_thread_buffer() = src_tile.get_thread_buffer();

            // final store
            store_tile(dst_window, dst_tile);
        }
    };
};
This diff is collapsed.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/permute.hpp"
#include <string>
// Runtime selector for the generic permute host API.
struct permute_traits
{
std::string data_type; // element type name, passed as a string
};
// Argument bundle of the host API; an alias of ck_tile's generic permute host args.
using permute_args = ck_tile::GenericPermuteHostArgs;
// host API
// Runs the generic permute for the given traits/arguments on the given stream config.
// NOTE(review): the float result is presumably the measured kernel time, as in the
// other example host APIs — confirm against the implementation.
float permute(permute_traits, permute_args, const ck_tile::stream_config&);
#!/bin/sh
# Smoke test for the ck_tile permute example: runs the executable over a fixed list
# of shape/permutation combinations.
# Pass any argument to enable command tracing (set -x).
# TODO: run this script from CK root

BUILD=build
EXE=$BUILD/bin/tile_example_permute
COMMON_ARGS='-v=1 -warmup=0 -repeat=1'
# mode=0
# export HIP_VISIBLE_DEVICES=4

# Fail once, up front, instead of printing "not found" for every invocation below.
if [ ! -x "$EXE" ] ; then
    echo "error: '$EXE' not found or not executable; build tile_example_permute and run this script from the CK root" >&2
    exit 1
fi

if [ $# -ge 1 ] ; then
    set -x
fi

# NOTE: $COMMON_ARGS is intentionally left unquoted so that it splits into separate options.

# 7-d swizzle-style permutations, fp16 only
for perm in 0,1,4,2,5,3,6 0,1,2,4,5,3,6 ; do
    for shape in 3,6,4,32,16,2,8 5,10,4,32,8,2,8 3,8,4,16,16,4,8 ; do
        "$EXE" -prec=fp16 -shape=$shape -perm=$perm $COMMON_ARGS
    done
done

# 6-d nr/kr/kw/nw/kv permutation, fp16 only
for shape in 2,8,16,8,4,8 1,24,32,16,2,8 ; do
    "$EXE" -prec=fp16 -shape=$shape -perm=0,1,3,4,2,5 $COMMON_ARGS
done

echo "------------------------------------------------------------------"

# generic permutations across all supported precisions
for prec in "fp8" "fp16" "fp32" ; do
    "$EXE" -prec=$prec -shape=3,8 -perm=1,0 $COMMON_ARGS
    "$EXE" -prec=$prec -shape=48,6,8 -perm=2,1,0 $COMMON_ARGS
    "$EXE" -prec=$prec -shape=24,128,3 -perm=0,2,1 $COMMON_ARGS
    "$EXE" -prec=$prec -shape=4,10,7,6 -perm=0,2,3,1 $COMMON_ARGS
    "$EXE" -prec=$prec -shape=8,24,36,10 -perm=3,1,2,0 $COMMON_ARGS
    "$EXE" -prec=$prec -shape=8,1,36,4 -perm=2,1,0,3 $COMMON_ARGS
    "$EXE" -prec=$prec -shape=5,10,16,2,36,4 -perm=4,5,2,1,0,3 $COMMON_ARGS
    "$EXE" -prec=$prec -shape=2,32,8,3,6,2,5,4 -perm=5,2,4,7,1,6,3,0 $COMMON_ARGS
    echo "------------------------------------------------------------------"
done
# topk-softmax example executable; EXCLUDE_FROM_ALL keeps it out of the default build,
# so it has to be requested explicitly (e.g. `make tile_example_topk_softmax`).
add_executable(tile_example_topk_softmax EXCLUDE_FROM_ALL
    topk_softmax.cpp
    topk_softmax_api.cpp)
target_include_directories(tile_example_topk_softmax PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/)

# NOTE: undefined-func-template is turned off so the sources compile without
# explicitly declaring the function template specializations.
set(EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS
    -Wno-undefined-func-template
    -Wno-float-equal)
# list(APPEND EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
target_compile_options(tile_example_topk_softmax PRIVATE ${EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS})
# topk-softmax
This folder contains an example of a topk-softmax kernel implemented with the ck_tile tile-programming model. This kernel is often used in MoE (mixture-of-experts) models, before launching the fused-moe-gemm block. The input is a `token*expert` 2D matrix. The op performs a softmax over each row (the `expert` dimension), then finds the `topk` values of each row. The output is a `token*topk` weight tensor (usually fp32) and a matching `token*topk` index tensor (int32).
## build
```
# in the root of ck_tile
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch> # replace <arch> with your target, e.g. gfx90a, gfx942...
make tile_example_topk_softmax -j
```
This will result in an executable `build/bin/tile_example_topk_softmax`
## example
```
args:
-v whether to do CPU validation or not (default:1)
-pr_i input data type. fp16/fp32 (representing 8/16/32 bit data) (default:fp16)
-pr_w output weight data type (currently only fp32 is supported) (default:fp32)
-t number of input tokens (default:32)
-e number of experts (default:8)
-k topk (default:2)
-st_i row stride of input, -1 means same as experts (default:-1)
-st_o row stride of output/indices, -1 means same as topk (default:-1)
-seed seed to be used, -1 means random every time (default:-1)
-kname when set to 1 it will print kernel name (default:0)
```
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment