"include/vscode:/vscode.git/clone" did not exist on "72d9bd744093e5ab8be77eb3594758d5bdd16c54"
Unverified Commit 7d50244e authored by Illia Silin's avatar Illia Silin Committed by GitHub
Browse files

Merge pull request #209 from ROCm/andriy/merge_from_public

Update develop branch from public repository
parents f221c2b0 d51701d4
#include "matrix_core_swizzle.hpp"
#include "matrix_core_swizzle_kernel.hpp"
float matrix_core_swizzle(matrix_core_swizzle_traits t,
matrix_core_swizzle_args a,
const ck_tile::stream_config& s)
{
if(t.data_type.compare("fp16") == 0)
{
if(t.inst.compare("32x32x8") == 0)
{
constexpr int BLOCK_SIZE = 256;
constexpr int NPerBlock = 256;
constexpr int KPerBlock = 128;
constexpr matrix_core_inst_enum Inst = matrix_core_inst_enum::MFMA_32x32x8_F16;
if(t.permute.compare("0,1,4,2,5,3,6") == 0)
{
constexpr matrix_core_permute_style pstyle =
matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2;
using Kernel =
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
auto k = Kernel(a);
float ave_time = ck_tile::launch_kernel(s, k);
return ave_time;
}
else if(t.permute.compare("0,1,2,4,5,3,6") == 0)
{
constexpr matrix_core_permute_style pstyle =
matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2;
using Kernel =
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
auto k = Kernel(a);
float ave_time = ck_tile::launch_kernel(s, k);
return ave_time;
}
else if(t.permute.compare("0,1,3,4,2,5") == 0)
{
constexpr matrix_core_permute_style pstyle =
matrix_core_permute_style::permute_b_nr_kr_kw_nw_kv;
using Kernel =
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
auto k = Kernel(a);
float ave_time = ck_tile::launch_kernel(s, k);
return ave_time;
}
}
else if(t.inst.compare("16x16x16") == 0)
{
constexpr int BLOCK_SIZE = 256;
constexpr int NPerBlock = 256;
constexpr int KPerBlock = 128;
constexpr matrix_core_inst_enum Inst = matrix_core_inst_enum::MFMA_16x16x16_F16;
if(t.permute.compare("0,1,4,2,5,3,6") == 0)
{
constexpr matrix_core_permute_style pstyle =
matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2;
using Kernel =
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
auto k = Kernel(a);
float ave_time = ck_tile::launch_kernel(s, k);
return ave_time;
}
else if(t.permute.compare("0,1,2,4,5,3,6") == 0)
{
constexpr matrix_core_permute_style pstyle =
matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2;
using Kernel =
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
auto k = Kernel(a);
float ave_time = ck_tile::launch_kernel(s, k);
return ave_time;
}
else if(t.permute.compare("0,1,3,4,2,5") == 0)
{
constexpr matrix_core_permute_style pstyle =
matrix_core_permute_style::permute_b_nr_kr_kw_nw_kv;
using Kernel =
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
auto k = Kernel(a);
float ave_time = ck_tile::launch_kernel(s, k);
return ave_time;
}
}
}
return -1;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "matrix_core_swizzle_kernel.hpp"
#include <string>
struct matrix_core_swizzle_traits
{
std::string data_type; // fp16 only
std::string inst; // 32x32x8, 16x16x16
std::string permute; //
};
using matrix_core_swizzle_args = matrix_core_swizzle_host_args;
// host API
float matrix_core_swizzle(matrix_core_swizzle_traits,
matrix_core_swizzle_args,
const ck_tile::stream_config&);
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm.hpp"
// if set to 1, slightly more instructions generated to calculate address
#ifndef MERGE_2D_013425
#define MERGE_2D_013425 0
#endif
enum class matrix_core_inst_enum
{
MFMA_32x32x8_F16 = 0,
MFMA_16x16x16_F16 = 1,
};
namespace detail {
template <matrix_core_inst_enum>
struct to_warp_gemm;
template <>
struct to_warp_gemm<matrix_core_inst_enum::MFMA_32x32x8_F16>
{
using type = ck_tile::WarpGemmMfmaF16F16F32M32N32K8;
};
template <>
struct to_warp_gemm<matrix_core_inst_enum::MFMA_16x16x16_F16>
{
using type = ck_tile::WarpGemmMfmaF16F16F32M16N16K16;
};
} // namespace detail
template <matrix_core_inst_enum Inst>
using to_warp_gemm_t = typename detail::to_warp_gemm<Inst>::type;
// TODO: in below permute pattern, the last 3 dim is within wave
enum class matrix_core_permute_style
{
permute_b_n0_k0_n1_k1_n2_k2 = 0, // 0,1,4,2,5,3,6
permute_b_n0_n1_k0_k1_n2_k2 = 1, // 0,1,2,4,5,3,6
permute_b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5
permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv,
};
// assume this is B matrix, originally we have batch*n*k
// now batch* n0*n1*n2*k0*k1*k2 -> batch* n0*k0*n1*k1*n2*k2
// assume using 32x32x8-f16, 4 waves and extend the KPerLane to 8xfp16(dwordx4)
//
// 4(waves) 32(mfma_m lane)
// | |
// batch* n0*n1*n2*k0*k1*k2 -> batch* n0*k0*n1*k1*n2*k2 -> 8(thread loading)
// nr kr |
// nr 4 32 kr 2 8 2(klane)
//
// permute: 0,1,4,2,5,3,6
// or
// batch* n0*n1*n2*k0*k1*k2 -> batch* n0*n1*k0*k1*n2*k2 -> 8(thread loading)
// permute: 0,1,2,4,5,3,6
//
// this kernel only deal with fp16/bf16 data(16bit), and use 2d block size to do the swizzling
// for simplicity, only consider n/k is multiple of block-size
// independend host arg with no template
struct matrix_core_swizzle_host_args
{
const void* p_src;
void* p_dst;
int32_t batch;
int32_t n;
int32_t k;
};
// NOTE: this kernel could follow the style of generic permute kernel
// but here we pass in fixed layout as template arg and generate different kernel instance
// purposely
template <int BLOCK_SIZE_ = 256,
int NPerBlock_ = 256,
int KPerBlock_ = 128,
matrix_core_permute_style pstyle_ =
matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2,
matrix_core_inst_enum Inst_ = matrix_core_inst_enum::MFMA_32x32x8_F16>
struct matrix_core_swizzle_kernel
{
using karg = matrix_core_swizzle_host_args;
using harg = matrix_core_swizzle_host_args;
static constexpr int BLOCK_SIZE = BLOCK_SIZE_;
static constexpr int WavesPerBlock_N = 4;
static constexpr int WavesPerBlock_K = 1;
static_assert(WavesPerBlock_N * WavesPerBlock_K * 64 == BLOCK_SIZE);
static constexpr int NPerBlock = NPerBlock_;
static constexpr int KPerBlock = KPerBlock_;
static constexpr matrix_core_permute_style pstyle = pstyle_;
static constexpr matrix_core_inst_enum Inst = Inst_;
static constexpr ck_tile::index_t Alignment = 8;
karg a;
dim3 grids;
using WarpGemm = to_warp_gemm_t<Inst>;
__host__ matrix_core_swizzle_kernel(harg h)
{
a = h;
ck_tile::index_t ns = (h.n + NPerBlock - 1) / NPerBlock;
ck_tile::index_t ks = (h.k + KPerBlock - 1) / KPerBlock;
grids = dim3(ks, ns, h.batch);
}
__host__ bool is_applicable(harg h) { return h.n % NPerBlock == 0 && h.k % KPerBlock == 0; }
__host__ void operator()(const ck_tile::stream_config& s) const
{
ck_tile::kentry<BLOCK_SIZE, 1, kernel><<<grids, BLOCK_SIZE, 0, s.stream_id_>>>(a);
}
struct kernel
{
__device__ static constexpr auto get_src_dist()
{
using namespace ck_tile;
constexpr index_t K2 = Alignment;
constexpr index_t N2 = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t K1 = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
constexpr index_t N1 = BLOCK_SIZE / get_warp_size();
static_assert(NPerBlock % (N1 * N2) == 0);
static_assert(KPerBlock % (K1 * K2) == 0);
constexpr index_t K0 = KPerBlock / (K1 * K2);
constexpr index_t N0 = NPerBlock / (N1 * N2);
// clang-format off
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,// 0
// 1 2 3 4 5 6
tuple<sequence<N0>, sequence<N1>, sequence<N2>, sequence<K0>, sequence<K1>, sequence<K2>>,
// N1 K1 N2
tuple<sequence<2>, sequence<5, 3>>,
tuple<sequence<0>, sequence<0, 0>>,
// N0 K0 K2
sequence<1, 4, 6>,
sequence<0, 0, 0>>{});
// clang-format on
}
__device__ static constexpr auto get_dst_dist()
{
using namespace ck_tile;
constexpr index_t K2 = Alignment;
constexpr index_t N2 = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t K1 = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
constexpr index_t N1 = BLOCK_SIZE / get_warp_size();
static_assert(NPerBlock % (N1 * N2) == 0);
static_assert(KPerBlock % (K1 * K2) == 0);
constexpr index_t K0 = KPerBlock / (K1 * K2);
constexpr index_t N0 = NPerBlock / (N1 * N2);
if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2)
{
// clang-format off
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,// 0
// 1 2 3 4 5 6
tuple<sequence<N0>, sequence<K0>, sequence<N1>, sequence<K1>, sequence<N2>, sequence<K2>>,
// N1 K1 N2
tuple<sequence<3>, sequence<4, 5>>,
tuple<sequence<0>, sequence<0, 0>>,
// N0 K0 K2
sequence<1, 2, 6>,
sequence<0, 0, 0>>{});
// clang-format on
}
else if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2)
{
// clang-format off
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,// 0
// 1 2 3 4 5 6
tuple<sequence<N0>, sequence<N1>, sequence<K0>, sequence<K1>, sequence<N2>, sequence<K2>>,
// N1 K1 N2
tuple<sequence<2>, sequence<4, 5>>,
tuple<sequence<0>, sequence<0, 0>>,
// N0 K0 K2
sequence<1, 3, 6>,
sequence<0, 0, 0>>{});
// clang-format on
}
else
{
// clang-format off
// permute_b_nr_kr_kw_nw_kv or permute_b_nr_kr_waveflatten
constexpr index_t Kv = Alignment;
constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
static_assert(KPerBlock % (K1 * K2) == 0);
constexpr index_t Nr = NPerBlock / Nw;
constexpr index_t Kr = KPerBlock / (Kv * Kw);
constexpr index_t Nr_p = WavesPerBlock_N;
constexpr index_t Kr_p = WavesPerBlock_K;
constexpr index_t Nr_y = Nr / Nr_p;
constexpr index_t Kr_y = Kr / Kr_p;
return make_static_tile_distribution(
#if MERGE_2D_013425
tile_distribution_encoding<
sequence<1>,// 0 R
// major 1 2
// minor 0 1 2 0 1 2 3
tuple<sequence<Nr_y, Nr_p, Nw>, sequence<Kr_y, Kr_p, Kw, Kv>>, // H
// Nr_p, Kr_p Kw Nw
tuple<sequence<1 , 2>, sequence<2, 1>>, // p major
tuple<sequence<1 , 1>, sequence<2, 2>>, // p minor
// Nr_y Kr_y Kv
sequence<1, 2, 2>, // Y major
sequence<0, 0, 3>>{}); // y minor
#else
tile_distribution_encoding<
sequence<1>,// 0 R
// major 1 2 3
// minor 0 1 0 1 0 1 2
tuple<sequence<Nr_y, Nr_p>, sequence<Kr_y, Kr_p>, sequence<Kw, Nw, Kv>>, // H
// Nr_p, Kr_p Kw Nw
tuple<sequence<1 , 2>, sequence<3, 3>>, // p major
tuple<sequence<1 , 1>, sequence<0, 1>>, // p minor
// Nr_y Kr_y Kv
sequence<1, 2, 3>, // Y major
sequence<0, 0, 2>>{}); // y minor
#endif
// clang-format on
}
}
__device__ void operator()(karg a_)
{
using namespace ck_tile;
index_t i_k = blockIdx.x;
index_t i_n = blockIdx.y;
index_t i_b = blockIdx.z;
constexpr index_t k2 = Alignment;
constexpr index_t n2 = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t k1 = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
constexpr index_t n1 = BLOCK_SIZE / get_warp_size();
const index_t k0 = a_.k / (k1 * k2);
const index_t n0 = a_.n / (n1 * n2);
constexpr index_t k2_tile = Alignment;
constexpr index_t n2_tile = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t k1_tile = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
constexpr index_t n1_tile = BLOCK_SIZE / get_warp_size();
constexpr index_t k0_tile = KPerBlock / (k1_tile * k2_tile);
constexpr index_t n0_tile = NPerBlock / (n1_tile * n2_tile);
const fp16_t* p_src = reinterpret_cast<const fp16_t*>(a_.p_src) + i_b * a_.k * a_.n;
fp16_t* p_dst = reinterpret_cast<fp16_t*>(a_.p_dst) + i_b * a_.k * a_.n;
const auto src_view = [&]() {
const auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
p_src,
make_tuple(n0, n1, n2, k0, k1, k2),
number<Alignment>{}); // control vector load
return tmp;
}();
const auto src_window = make_tile_window(src_view,
make_tuple(number<n0_tile>{},
number<n1_tile>{},
number<n2_tile>{},
number<k0_tile>{},
number<k1_tile>{},
number<k2_tile>{}),
{i_n * n0_tile, 0, 0, i_k * k0_tile, 0, 0},
get_src_dist());
auto dst_view = [&]() {
if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2)
{
auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
p_dst,
make_tuple(n0, k0, n1, k1, n2, k2),
number<Alignment>{}); // control vector load
return tmp;
}
else if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2)
{
auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
p_dst,
make_tuple(n0, n1, k0, k1, n2, k2),
number<Alignment>{}); // control vector load
return tmp;
}
else
{
#if MERGE_2D_013425
constexpr index_t kv = Alignment;
constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
// constexpr index_t waveflatten = kw*nw*kv;
const index_t kr = a_.k / (k1 * k2);
const index_t nr = a_.n / nw;
auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
p_dst,
make_tuple(nr, kr, number<kw>{}, number<nw>{}, number<kv>{}),
number<Alignment>{}); // control vector load
auto tmp_1 = transform_tensor_view(
tmp,
make_tuple(
make_merge_transform(make_tuple(nr, number<nw>{})),
make_merge_transform(make_tuple(kr, number<kw>{}, number<kv>{}))),
make_tuple(sequence<0, 3>{}, sequence<1, 2, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return tmp_1;
#else
// permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv,
constexpr index_t kv = Alignment;
constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
constexpr index_t waveflatten = kw * nw * kv;
const index_t kr = a_.k / (k1 * k2);
const index_t nr = a_.n / nw;
auto tmp = make_naive_tensor_view_packed<address_space_enum::global>(
p_dst,
make_tuple(nr, kr, waveflatten),
number<Alignment>{}); // control vector load
return tmp;
#endif
}
}();
auto dst_window = [&]() {
if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_k0_n1_k1_n2_k2)
{
return make_tile_window(dst_view,
make_tuple(number<n0_tile>{},
number<k0_tile>{},
number<n1_tile>{},
number<k1_tile>{},
number<n2_tile>{},
number<k2_tile>{}),
{i_n * n0_tile, i_k * k0_tile, 0, 0, 0, 0},
get_dst_dist());
}
else if constexpr(pstyle == matrix_core_permute_style::permute_b_n0_n1_k0_k1_n2_k2)
{
return make_tile_window(dst_view,
make_tuple(number<n0_tile>{},
number<n1_tile>{},
number<k0_tile>{},
number<k1_tile>{},
number<n2_tile>{},
number<k2_tile>{}),
{i_n * n0_tile, 0, i_k * k0_tile, 0, 0, 0},
get_dst_dist());
}
else
{
#if MERGE_2D_013425
// permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv
return make_tile_window(dst_view,
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
{i_n * NPerBlock, i_k * KPerBlock},
get_dst_dist());
#else
// permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv
constexpr index_t kv = Alignment;
constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
constexpr index_t waveflatten_tile = kw * nw * kv;
constexpr index_t nr_tile = NPerBlock / nw;
constexpr index_t kr_tile = KPerBlock / (kw * kv);
return make_tile_window(dst_view,
make_tuple(number<nr_tile>{},
number<kr_tile>{},
number<waveflatten_tile>{}),
{i_n * nr_tile, i_k * kr_tile, 0},
get_dst_dist());
#endif
}
}();
// actual load store
auto src_tile = load_tile(src_window);
// now we only swap the distribution from src to dst, no extra movement occurs
auto dst_tile = make_static_distributed_tensor<fp16_t>(get_dst_dist());
dst_tile.get_thread_buffer() = src_tile.get_thread_buffer();
// final store
store_tile(dst_window, dst_tile);
}
};
};
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "permute.hpp"
#include "ck_tile/host.hpp"
#include <array>
#include <cstring>
#include <functional>
#include <numeric>
#include <ostream>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#ifdef PERMUTE_USE_ALTERNATIVE_IMPL
#include "alternative_impl/matrix_core_swizzle.hpp"
#endif
namespace detail {
template <int bytes>
struct to_integer_type;
template <>
struct to_integer_type<4>
{
using type = int32_t;
};
template <>
struct to_integer_type<2>
{
using type = int16_t;
};
template <>
struct to_integer_type<1>
{
using type = int8_t;
};
} // namespace detail
template <int bytes>
using to_integer_type = typename detail::to_integer_type<bytes>::type;
// host API (shoule come from codegen)
float permute(permute_traits t, permute_args a, const ck_tile::stream_config& s)
{
if(t.data_type.compare("fp8") == 0)
{
using DataType = ck_tile::fp8_t;
using PipelineProblem = ck_tile::GenericPermuteProblem<DataType>;
using Kernel = ck_tile::GenericPermute<PipelineProblem>;
auto kargs = Kernel::MakeKargs(a);
const dim3 grids = Kernel::GridSize(a);
constexpr dim3 blocks = Kernel::BlockSize();
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, 1>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
}
else if(t.data_type.compare("fp16") == 0)
{
using DataType = ck_tile::half_t;
using PipelineProblem = ck_tile::GenericPermuteProblem<DataType>;
using Kernel = ck_tile::GenericPermute<PipelineProblem>;
auto kargs = Kernel::MakeKargs(a);
const dim3 grids = Kernel::GridSize(a);
constexpr dim3 blocks = Kernel::BlockSize();
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, 1>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
}
else if(t.data_type.compare("fp32") == 0)
{
using DataType = float;
using PipelineProblem = ck_tile::GenericPermuteProblem<DataType>;
using Kernel = ck_tile::GenericPermute<PipelineProblem>;
auto kargs = Kernel::MakeKargs(a);
const dim3 grids = Kernel::GridSize(a);
constexpr dim3 blocks = Kernel::BlockSize();
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, 1>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
}
return 0;
}
template <typename T>
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
{
using size_type = typename std::vector<T>::size_type;
os << "[";
for(size_type idx = 0; idx < v.size(); ++idx)
{
if(0 < idx)
{
os << ", ";
}
os << v[idx];
}
return os << "]";
}
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("v", "1", "weather do CPU validation or not")
.insert("prec", "fp16", "data type. fp8/fp16/fp32 (representing 8/16/32 bit data)")
.insert("shape", "2,3,4", "the shape of the input tensor")
.insert("perm", "2,1,0", "permute perm")
.insert("kname", "0", "t to 1 will print kernel name")
.insert("seed",
"11939",
"random seed used for initializing input tensors. 0 for "
"non-deterministic seed")
.insert("warmup", "5", "number of iterations before benchmark the kernel")
.insert("repeat", "20", "number of iterations to benchmark the kernel");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
// different threshold for different dtype
template <typename DataType>
auto get_elimit(std::string /*init_method*/)
{
double rtol = 1e-3;
double atol = 1e-3;
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<ck_tile::bf16_t>(std::string /*init_method*/)
{
double rtol = 1e-2;
double atol = 1e-2;
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<ck_tile::fp8_t>(std::string init_method)
{
if(init_method == "ui" || init_method == "ni")
{
unsigned max_rounding_point_distance = 0;
double atol = 2e-3;
return ck_tile::make_tuple(max_rounding_point_distance, atol);
}
else
{
unsigned max_rounding_point_distance = 1;
double atol = 0.0625;
return ck_tile::make_tuple(max_rounding_point_distance, atol);
}
}
// "1,2,3,4" -> vector{1,2,3,4}
std::vector<ck_tile::index_t> decode_vec(std::string q_val)
{
#define _S2I_(str_) static_cast<ck_tile::index_t>(std::atoi((str_).c_str()))
std::string::size_type pos = 0;
std::vector<ck_tile::index_t> v;
while(true)
{
auto found = q_val.find(',', pos);
ck_tile::index_t n =
_S2I_(q_val.substr(pos, found == std::string::npos ? found : found - pos));
v.push_back(n);
if(found == std::string::npos)
{
break;
}
pos = found + 1;
}
return v;
#undef _S2I_
}
template <typename DataType>
bool run(const ck_tile::ArgParser& arg_parser)
{
std::string data_type = arg_parser.get_str("prec");
int do_validation = arg_parser.get_int("v");
auto shape = decode_vec(arg_parser.get_str("shape"));
auto perm = decode_vec(arg_parser.get_str("perm"));
int stream_warmup = arg_parser.get_int("warmup");
int stream_repeat = arg_parser.get_int("repeat");
bool kname = arg_parser.get_bool("kname");
int seed = arg_parser.get_int("seed");
assert(shape.size() == perm.size());
ck_tile::index_t rank = perm.size();
if(rank > ck_tile::GenericPermuteHostArgs::kMaxRanks)
{
printf("rank %d permute is not support yet\n", rank);
return false;
}
ck_tile::HostTensor<DataType> x(shape);
ck_tile::FillUniformDistributionIntegerValue<DataType>{-15, 15, seed}(x);
std::vector<ck_tile::index_t> y_shape = [&]() {
std::vector<ck_tile::index_t> tmp(rank, 0);
// std::cout << "@@@@" << tmp << std::endl;
for(int i = 0; i < static_cast<int>(rank); i++)
{
// std::cout << " i:" << i << ", perm:" << perm[i] << ", rak:" <<
// static_cast<int>(rank)
// << std::endl;
tmp[i] = shape[perm[i]];
}
// std::cout << "@@@" << tmp << std::endl;
return tmp;
}();
ck_tile::HostTensor<DataType> y(y_shape);
ck_tile::DeviceMem x_buf(x.get_element_space_size_in_bytes());
ck_tile::DeviceMem y_buf(y.get_element_space_size_in_bytes());
x_buf.ToDevice(x.data());
std::cout << "[" << data_type << "] shape:" << shape << "->" << y_shape << ", permute:" << perm
<< std::flush;
ck_tile::stream_config stream_config{nullptr,
true,
/* log_level = */ (kname ? 1 : 0),
stream_warmup,
stream_repeat};
float ave_time = 0.f;
auto run_permute = [&]() {
permute_traits t;
t.data_type = data_type;
permute_args a;
a.p_src = x_buf.GetDeviceBuffer();
a.p_dst = y_buf.GetDeviceBuffer();
a.rank = rank;
std::copy(shape.begin(), shape.end(), a.shape);
std::copy(perm.begin(), perm.end(), a.perm);
return permute(t, a, stream_config);
};
#ifdef PERMUTE_USE_ALTERNATIVE_IMPL
// batch* n0*n1*n2*k0*k1*k2 -> batch* n0*k0*n1*k1*n2*k2
if((arg_parser.get_str("perm") == std::string("0,1,4,2,5,3,6") ||
arg_parser.get_str("perm") == std::string("0,1,2,4,5,3,6") ||
arg_parser.get_str("perm") == std::string("0,1,3,4,2,5")))
{
if(arg_parser.get_str("perm") == std::string("0,1,3,4,2,5"))
{
// permute_b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5
matrix_core_swizzle_traits t;
t.data_type = data_type;
t.permute = arg_parser.get_str("perm");
matrix_core_swizzle_args a;
a.p_src = x_buf.GetDeviceBuffer();
a.p_dst = y_buf.GetDeviceBuffer();
a.batch = shape[0];
auto nr = shape[1];
auto nw = shape[2];
auto kr = shape[3];
auto kw = shape[4];
auto kv = shape[5];
a.n = nr * nw;
a.k = kr * kw * kv;
if(kv == 8 && kw == 4 && nw == 16 && nr % 4 == 0 && kr % 8 == 0)
{
t.inst = "16x16x16";
std::cout << ", matrix_core_swizzle_waveflatten_" << t.inst << std::flush;
ave_time = matrix_core_swizzle(t, a, stream_config);
}
else if(kv == 8 && kw == 2 && nw == 32 && nr % 4 == 0 && kr % 8 == 0)
{
t.inst = "32x32x8";
std::cout << ", matrix_core_swizzle_waveflatten_" << t.inst << std::flush;
ave_time = matrix_core_swizzle(t, a, stream_config);
}
else
{
ave_time = run_permute();
}
}
else
{
matrix_core_swizzle_traits t;
t.data_type = data_type;
t.permute = arg_parser.get_str("perm");
matrix_core_swizzle_args a;
a.p_src = x_buf.GetDeviceBuffer();
a.p_dst = y_buf.GetDeviceBuffer();
a.batch = shape[0];
a.n = shape[1] * shape[2] * shape[3];
a.k = shape[4] * shape[5] * shape[6];
if(shape[6] == 8 && shape[3] == 32 && shape[5] == 2 && shape[2] == 4 &&
shape[4] % 8 == 0 && shape[1] % 2 == 0)
{
// 32x32x8 inst
// perm=0,1,4,2,5,3,6
// y_shape=*,2x,8x,4,2,32,8 (3,6,16,4,2,32,8)
// shape = *,2x,4,32,8x,2,8 (3,6,4,32,16,2,8)
t.inst = "32x32x8";
std::cout << ", matrix_core_swizzle_" << t.inst << std::flush;
ave_time = matrix_core_swizzle(t, a, stream_config);
}
else if(shape[6] == 8 && shape[3] == 16 && shape[5] == 4 && shape[2] == 4 &&
shape[4] % 4 == 0 && shape[1] % 4 == 0)
{
// 16x16x16 inst
// perm=0,1,4,2,5,3,6
// y_shape=*,4x,4x,4,4,16,8
// shape = *,4x,4,16,4x,4,8 (3,8,4,16,16,4,8)
t.inst = "16x16x16";
std::cout << ", matrix_core_swizzle_" << t.inst << std::flush;
ave_time = matrix_core_swizzle(t, a, stream_config);
}
else
{
ave_time = run_permute();
}
}
}
else
#endif
{
ave_time = run_permute();
}
std::cout << ", time:" << ave_time << "ms" << std::flush;
bool pass = true;
if(do_validation)
{
reference_permute(x, y, perm);
#if 0
if constexpr (std::is_same_v<float, DataType>){
// using itype = to_integer_type<sizeof(DataType)>;
fflush(stdout);
for(int zz = 0; zz < static_cast<int>(x.get_element_size()); zz++ ) {
printf("%3.0f ", x.mData[zz]);
}
printf("->\n");
for(int zz = 0; zz < static_cast<int>(x.get_element_size()); zz++ ) {
printf("%3.0f ", y.mData[zz]);
}
fflush(stdout);
}
#endif
ck_tile::HostTensor<DataType> y_dev(y.get_lengths());
y_buf.FromDevice(y_dev.data());
pass = std::equal(
y_dev.begin(), y_dev.end(), y.begin(), [&](const DataType& d, const DataType& h) {
using itype = to_integer_type<sizeof(DataType)>;
itype i_d = ck_tile::bit_cast<itype>(d);
itype i_h = ck_tile::bit_cast<itype>(h);
return i_d == i_h;
});
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush;
}
std::cout << std::endl;
return pass;
}
int main(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
const std::string data_type = arg_parser.get_str("prec");
if(data_type == "fp8")
{
return run<ck_tile::fp8_t>(arg_parser) ? 0 : -2;
}
else if(data_type == "fp16")
{
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
}
else if(data_type == "fp32")
{
return run<float>(arg_parser) ? 0 : -2;
}
return -3;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/permute.hpp"
#include <string>
struct permute_traits
{
std::string data_type;
};
using permute_args = ck_tile::GenericPermuteHostArgs;
// host API
float permute(permute_traits, permute_args, const ck_tile::stream_config&);
#!/bin/sh
# TODO: run this script from CK root
BUILD=build
EXE=$BUILD/bin/tile_example_permute
COMMON_ARGS='-v=1 -warmup=0 -repeat=1'
# mode=0
# export HIP_VISIBLE_DEVICES=4
if [ $# -ge 1 ] ; then
set -x
fi
$EXE -prec=fp16 -shape=3,6,4,32,16,2,8 -perm=0,1,4,2,5,3,6 $COMMON_ARGS
$EXE -prec=fp16 -shape=5,10,4,32,8,2,8 -perm=0,1,4,2,5,3,6 $COMMON_ARGS
$EXE -prec=fp16 -shape=3,8,4,16,16,4,8 -perm=0,1,4,2,5,3,6 $COMMON_ARGS
$EXE -prec=fp16 -shape=3,6,4,32,16,2,8 -perm=0,1,2,4,5,3,6 $COMMON_ARGS
$EXE -prec=fp16 -shape=5,10,4,32,8,2,8 -perm=0,1,2,4,5,3,6 $COMMON_ARGS
$EXE -prec=fp16 -shape=3,8,4,16,16,4,8 -perm=0,1,2,4,5,3,6 $COMMON_ARGS
$EXE -prec=fp16 -shape=2,8,16,8,4,8 -perm=0,1,3,4,2,5 $COMMON_ARGS
$EXE -prec=fp16 -shape=1,24,32,16,2,8 -perm=0,1,3,4,2,5 $COMMON_ARGS
echo "------------------------------------------------------------------"
for prec in "fp8" "fp16" "fp32" ; do
$EXE -prec=$prec -shape=3,8 -perm=1,0 $COMMON_ARGS
$EXE -prec=$prec -shape=48,6,8 -perm=2,1,0 $COMMON_ARGS
$EXE -prec=$prec -shape=24,128,3 -perm=0,2,1 $COMMON_ARGS
$EXE -prec=$prec -shape=4,10,7,6 -perm=0,2,3,1 $COMMON_ARGS
$EXE -prec=$prec -shape=8,24,36,10 -perm=3,1,2,0 $COMMON_ARGS
$EXE -prec=$prec -shape=8,1,36,4 -perm=2,1,0,3 $COMMON_ARGS
$EXE -prec=$prec -shape=5,10,16,2,36,4 -perm=4,5,2,1,0,3 $COMMON_ARGS
$EXE -prec=$prec -shape=2,32,8,3,6,2,5,4 -perm=5,2,4,7,1,6,3,0 $COMMON_ARGS
echo "------------------------------------------------------------------"
done
add_executable(tile_example_topk_softmax EXCLUDE_FROM_ALL topk_softmax.cpp topk_softmax_api.cpp)
target_include_directories(tile_example_topk_softmax PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/)
set(EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS)
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
list(APPEND EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
# list(APPEND EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
target_compile_options(tile_example_topk_softmax PRIVATE ${EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS})
# topk-softmax
This folder contains example for topk-softmax kernel using ck_tile tile-programming implementation. This kernel is often used in Moe model, before launching the fused-moe-gemm block. The input is a `token*expert` 2d matrix. The op will do a softmax per row(`expert`), then find the `topk` value for each row. Output is a `token*topk` weight(usually fp32) and index(int32) 2d tensor.
## build
```
# in the root of ck_tile
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
make tile_example_topk_softmax -j
```
This will result in an executable `build/bin/tile_example_topk_softmax`
## example
```
args:
-v weather do CPU validation or not (default:1)
-pr_i input data type. fp16/fp32 (representing 8/16/32 bit data) (default:fp16)
-pr_w output weight data type(currently only fp32 supported now) (default:fp32)
-t number of input tokens (default:32)
-e number of experts (default:8)
-k topk (default:2)
-st_i row stride of input, -1 means same as experts (default:-1)
-st_o row stride of output/indices, -1 means same as topk (default:-1)
-seed seed to be used, -1 means random every time (default:-1)
-kname when set to 1 it will print kernel name (default:0)
```
#!/bin/sh
EXE=./build/bin/tile_example_topk_softmax
for pr_i in "fp16" "bf16" ; do
$EXE -pr_i=$pr_i -t=80 -e=17
$EXE -pr_i=$pr_i -t=111 -e=117
$EXE -pr_i=$pr_i -t=1000 -e=55
$EXE -pr_i=$pr_i -t=99 -e=180
$EXE -pr_i=$pr_i -t=175 -e=64 -k=8
$EXE -pr_i=$pr_i -t=65 -e=8 -k=2
$EXE -pr_i=$pr_i -t=1 -e=25
$EXE -pr_i=$pr_i -t=31 -e=19 -k=15
$EXE -pr_i=$pr_i -t=81 -e=37 -k=7
$EXE -pr_i=$pr_i -t=199 -e=128 -k=13
$EXE -pr_i=$pr_i -t=23 -e=1 -k=1
$EXE -pr_i=$pr_i -t=127 -e=99 -k=19 -st_i=233 -st_o=31
$EXE -pr_i=$pr_i -t=71 -e=11 -k=11 -st_i=30 -st_o=12
$EXE -pr_i=$pr_i -t=1 -e=1 -k=1
$EXE -pr_i=$pr_i -t=99 -e=2 -k=1 -st_i=11 -st_o=5
$EXE -pr_i=$pr_i -t=333 -e=99 -k=13 -st_i=191 -st_o=17
done
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <vector>
#include <iostream>
#include <numeric>
#include <cassert>
#include <cstdlib>
#include <iostream>
#include <time.h>
#include <unordered_set>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/reduce.hpp"
#include "topk_softmax_api.hpp"
#if 0
template <typename T>
void dump_host_tensor_2d(const ck_tile::HostTensor<T>& x)
{
auto len = x.get_lengths();
assert(len.size() == 2);
std::cout << "[";
for(size_t i = 0; i < len[0]; i++)
{
std::cout << i << ": [";
for(size_t j = 0; j < len[1]; j++)
{
if constexpr(std::is_same_v<T, ck_tile::fp16_t>)
{
auto v = ck_tile::type_convert<float>(x(i, j));
std::cout << v;
if(j != len[1] - 1)
std::cout << ",";
}
else
{
std::cout << x(i, j) << " ";
}
}
std::cout << "]";
if(i != len[0] - 1)
std::cout << ",";
else
std::cout << "]";
std::cout << std::endl;
}
std::cout << "--------------------" << std::endl;
}
#endif
// CPU reference
template <typename InputType, typename WeightType, typename IndexType = ck_tile::index_t>
auto reference_topk_softmax(const ck_tile::HostTensor<InputType>& x,
ck_tile::index_t k,
ck_tile::index_t dim = -1,
bool largest = true,
bool sorted = true)
{
using namespace ck_tile;
auto y = reference_softmax<InputType, WeightType, WeightType>(x, dim);
auto [y_values, y_indices] = reference_topk(y, k, dim, largest, sorted);
return ck_tile::make_tuple(y_values, y_indices);
}
template <typename InputType, typename WeightType, typename IndexType = ck_tile::index_t>
auto reference_topk_softmax(const ck_tile::HostTensor<InputType>& x,
ck_tile::HostTensor<WeightType>& y_values,
ck_tile::HostTensor<IndexType>& y_indices,
ck_tile::index_t k,
ck_tile::index_t dim = -1,
bool largest = true,
bool sorted = true)
{
using namespace ck_tile;
auto y = reference_softmax<InputType, WeightType, WeightType>(x, dim);
reference_topk(y, y_values, y_indices, k, dim, largest, sorted);
}
// different threshold for different dtype
template <typename DataType>
auto get_elimit(std::string /*init_method*/)
{
double rtol = 1e-3;
double atol = 1e-3;
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<ck_tile::bf16_t>(std::string /*init_method*/)
{
double rtol = 1e-2;
double atol = 1e-2;
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<ck_tile::fp8_t>(std::string init_method)
{
if(init_method == "ui" || init_method == "ni")
{
unsigned max_rounding_point_distance = 0;
double atol = 2e-3;
return ck_tile::make_tuple(max_rounding_point_distance, atol);
}
else
{
unsigned max_rounding_point_distance = 1;
double atol = 0.0625;
return ck_tile::make_tuple(max_rounding_point_distance, atol);
}
}
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("v", "1", "weather do CPU validation or not")
.insert("pr_i", "fp16", "input data type. fp16/fp32 (representing 8/16/32 bit data)")
.insert("pr_w", "fp32", "output weight data type(currently only fp32 supported now)")
.insert("t", "32", "number of input tokens")
.insert("e", "8", "number of experts")
.insert("k", "2", "topk")
.insert("st_i", "-1", "row stride of input, -1 means same as experts")
.insert("st_o", "-1", "row stride of output/indices, -1 means same as topk")
.insert("seed", "-1", "seed to be used, -1 means random every time")
.insert("kname", "0", "when set to 1 it will print kernel name")
.insert("warmup", "5", "number of iterations before benchmark the kernel")
.insert("repeat", "20", "number of iterations to benchmark the kernel");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
template <typename InputType, typename WeightType, typename IndexType = ck_tile::index_t>
bool test_topk_softmax(ck_tile::ArgParser args)
{
int validate = args.get_int("v");
std::string input_prec = args.get_str("pr_i");
std::string weight_prec = args.get_str("pr_w");
int tokens = args.get_int("t");
int experts = args.get_int("e");
int topk = args.get_int("k");
int seed = args.get_int("seed");
int stride_input = args.get_int("st_i");
int stride_output = args.get_int("st_o");
int kname = args.get_int("kname");
int warmup = args.get_int("warmup");
int repeat = args.get_int("repeat");
if(stride_input < 0)
{
stride_input = experts;
}
if(stride_output < 0)
{
stride_output = topk;
}
assert(stride_input >= experts);
assert(stride_output >= topk);
if(seed < 0)
{
seed = std::time(nullptr);
}
if(topk > experts)
{
printf("topk:%d value should be smaller than, or equal to number of experts:%d\n",
topk,
experts);
return false;
}
// tokens already considered batch size
ck_tile::HostTensor<InputType> x_host({tokens, experts}, {stride_input, 1});
ck_tile::HostTensor<WeightType> value_host({tokens, topk}, {stride_output, 1});
ck_tile::HostTensor<IndexType> index_host({tokens, topk}, {stride_output, 1});
{
// random require per-row unique
auto rand_gen = ck_tile::FillUniformDistribution_Unique<InputType>{
-5.f, 5.f, static_cast<uint32_t>(seed)};
for(int i_t = 0; i_t < tokens; i_t++)
{
ck_tile::HostTensor<InputType> x_row({experts});
rand_gen(x_row);
std::copy(x_row.begin(), x_row.end(), x_host.begin() + i_t * stride_input);
rand_gen.clear();
}
}
ck_tile::DeviceMem x_dev(x_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem value_dev(value_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem index_dev(index_host.get_element_space_size_in_bytes());
x_dev.ToDevice(x_host.data());
topk_softmax_trait trait{input_prec, weight_prec, experts};
topk_softmax_kargs karg{x_dev.GetDeviceBuffer(),
value_dev.GetDeviceBuffer(),
index_dev.GetDeviceBuffer(),
tokens,
experts,
topk,
stride_input,
stride_output};
ck_tile::stream_config sc{nullptr,
true,
/* log_level = */ (kname ? 1 : 0),
warmup,
repeat};
auto ms = topk_softmax(trait, karg, sc);
printf("[%s|%s]tokens:%d, experts:%d, topk:%d, st_i:%d, st_o:%d, ms:%f, ",
input_prec.c_str(),
weight_prec.c_str(),
tokens,
experts,
topk,
stride_input,
stride_output,
ms);
if(ms < 0)
printf("not supported\n");
fflush(stdout);
if(ms < 0)
{
return false;
}
value_dev.FromDevice(value_host.data());
index_dev.FromDevice(index_host.data());
bool rtn = true;
if(validate)
{
ck_tile::HostTensor<WeightType> value_ref({tokens, topk}, {stride_output, 1});
ck_tile::HostTensor<IndexType> index_ref({tokens, topk}, {stride_output, 1});
reference_topk_softmax<InputType, WeightType, IndexType>(
x_host, value_ref, index_ref, topk);
auto [rtol, atol] = get_elimit<InputType>("");
for(int i_t = 0; i_t < tokens; i_t++)
{
auto s_begin = std::vector<size_t>{static_cast<size_t>(i_t), static_cast<size_t>(0)};
auto s_end =
std::vector<size_t>{static_cast<size_t>(i_t + 1), static_cast<size_t>(topk)};
auto s_value_host = value_host.slice(s_begin, s_end);
auto s_value_ref = value_ref.slice(s_begin, s_end);
rtn &= ck_tile::check_err(s_value_host,
s_value_ref,
std::string("[") + std::to_string(i_t) +
std::string("] Value Error:"),
rtol,
atol);
auto s_index_host = index_host.slice(s_begin, s_end);
auto s_index_ref = index_ref.slice(s_begin, s_end);
rtn &= ck_tile::check_err(s_index_host,
s_index_ref,
std::string("[") + std::to_string(i_t) +
std::string("] Index Error:"),
rtol,
atol);
}
}
printf("valid:%s\n", rtn ? "y" : "n");
fflush(stdout);
return rtn;
}
int main(int argc, char** argv)
{
auto [result, args] = create_args(argc, argv);
if(!result)
return -1;
std::string input_prec = args.get_str("pr_i");
std::string weight_prec = args.get_str("pr_w");
bool r = true;
if(input_prec.compare("fp16") == 0 && weight_prec.compare("fp32") == 0)
{
r &= test_topk_softmax<ck_tile::fp16_t, float, ck_tile::index_t>(args);
}
else if(input_prec.compare("bf16") == 0 && weight_prec.compare("fp32") == 0)
{
r &= test_topk_softmax<ck_tile::bf16_t, float, ck_tile::index_t>(args);
}
return r ? 0 : -1;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "topk_softmax_api.hpp"
#define TOPK_SOFTMAX_DISPATCH(experts_) \
constexpr ck_tile::index_t ts_experts = experts_; \
using ts_problem = ck_tile:: \
TopkSoftmaxWarpPerRowProblem<ts_input_type, ts_weight_type, ts_index_type, ts_experts>; \
using ts_pipeline = ck_tile::TopkSoftmaxWarpPerRowPipeline<ts_problem>; \
\
using kernel = ck_tile::TopkSoftmaxKernel<ts_pipeline>; \
\
auto kargs = kernel::MakeKargs(a); \
\
const dim3 grids = kernel::GridSize(a); \
constexpr dim3 blocks = kernel::BlockSize(); \
\
float ave_time = ck_tile::launch_kernel( \
s, ck_tile::make_kernel<blocks.x, 1>(kernel{}, grids, blocks, 0, kargs)); \
\
return ave_time;
float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_config s)
{
if(t.input_type == "fp16" && t.weight_type == "fp32")
{
using ts_input_type = ck_tile::fp16_t;
using ts_weight_type = float;
using ts_index_type = ck_tile::index_t;
#if 1
if(t.experts <= 8)
{
TOPK_SOFTMAX_DISPATCH(8)
}
else if(t.experts <= 16)
{
TOPK_SOFTMAX_DISPATCH(16)
}
else if(t.experts <= 32)
{
TOPK_SOFTMAX_DISPATCH(32)
}
else if(t.experts <= 64)
{
TOPK_SOFTMAX_DISPATCH(64)
}
else if(t.experts <= 128)
{
TOPK_SOFTMAX_DISPATCH(128)
}
else if(t.experts <= 192)
{
TOPK_SOFTMAX_DISPATCH(192)
}
#else
if(t.experts <= 128)
{
TOPK_SOFTMAX_DISPATCH(128)
}
#endif
}
else if(t.input_type == "bf16" && t.weight_type == "fp32")
{
#if 1
using ts_input_type = ck_tile::bf16_t;
using ts_weight_type = float;
using ts_index_type = ck_tile::index_t;
if(t.experts <= 8)
{
TOPK_SOFTMAX_DISPATCH(8)
}
else if(t.experts <= 16)
{
TOPK_SOFTMAX_DISPATCH(16)
}
else if(t.experts <= 32)
{
TOPK_SOFTMAX_DISPATCH(32)
}
else if(t.experts <= 64)
{
TOPK_SOFTMAX_DISPATCH(64)
}
else if(t.experts <= 128)
{
TOPK_SOFTMAX_DISPATCH(128)
}
else if(t.experts <= 192)
{
TOPK_SOFTMAX_DISPATCH(192)
}
#endif
}
return -1;
}
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/topk_softmax.hpp"
#include <string>
struct topk_softmax_trait
{
std::string input_type;
std::string weight_type; // currently always float
int experts;
};
struct topk_softmax_kargs : public ck_tile::TopkSoftmaxHostArgs
{
};
float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_config s);
set(TILE_RMSNORM2D_FWD "tile_rmsnorm2d_fwd")
# not using add_example_executable() to add this target, since we don't want this to have
# to be included in "make all/install/check"
message("adding ${TILE_RMSNORM2D_FWD}")
file(GLOB INSTANCE_SRCS instances/*.cpp)
add_executable(${TILE_RMSNORM2D_FWD} EXCLUDE_FROM_ALL rmsnorm2d_fwd.cpp)
target_include_directories(${TILE_RMSNORM2D_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${TILE_RMSNORM2D_FWD} PRIVATE ${INSTANCE_SRCS})
set(TILE_RMSNORM2D_FWD_COMPILE_OPTIONS)
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
list(APPEND TILE_RMSNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
target_compile_options(${TILE_RMSNORM2D_FWD} PRIVATE ${TILE_RMSNORM2D_FWD_COMPILE_OPTIONS})
set(EXAMPLE_RMSNORM2D_FWD "tile_example_rmsnorm2d_fwd")
add_executable(${EXAMPLE_RMSNORM2D_FWD} EXCLUDE_FROM_ALL example_rmsnorm2d_fwd.cpp)
target_compile_options(${EXAMPLE_RMSNORM2D_FWD} PRIVATE ${TILE_RMSNORM2D_FWD_COMPILE_OPTIONS})
# TODO: we have to turn off this global prop, otherwise the progress bar generated
# by cmake will print too many files, execvp: /bin/sh: Argument list too long
# however, this property may affect global
# TODO: consider codegen a makefile by us
set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)
# Rmsnorm2D forward
This folder contains example for Rmsnorm2D forward using ck_tile tile-programming implementation.
## build
```
# in the root of ck_tile
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
make tile_rmsnorm2d_fwd -j
```
This will result in an executable `build/bin/tile_rmsnorm2d_fwd`
## cmdline
```
args:
-m m dimension (default:3328)
-n m dimension (default:4096)
-e epsilon (default:1e-5)
-v cpu validation or not (default:1)
-prec precision (default:fp16)
```
#include "ck_tile/host.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/rmsnorm2d.hpp"
#include <cstring>
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "3328", "m dimension")
.insert("n", "4096", "n dimension")
.insert("stride", "-1", "stride per row, if -1 then equal to n")
.insert("e", "1e-5", "epsilon")
.insert("v", "1", "cpu validation or not")
.insert("prec", "fp16", "precision")
.insert("warmup", "0", "cold iter")
.insert("repeat", "1", "hot iter");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
template <typename DataType>
bool run(const ck_tile::ArgParser& arg_parser)
{
ck_tile::index_t m = arg_parser.get_int("m");
ck_tile::index_t n = arg_parser.get_int("n");
ck_tile::index_t stride = arg_parser.get_int("stride");
if(stride < 0)
stride = n;
float epsilon = arg_parser.get_float("e");
std::string data_type = arg_parser.get_str("prec");
int do_validation = arg_parser.get_int("v");
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
assert(stride >= n);
using XDataType = DataType;
using YDataType = DataType;
using GammaDataType = DataType;
using InvRmsDataType = ck_tile::null_type;
using ComputeDataType = float;
// host verify
ck_tile::HostTensor<XDataType> x_host({m, n}, {stride, 1});
ck_tile::HostTensor<GammaDataType> gamma_host({n});
ck_tile::HostTensor<YDataType> y_host_ref({m, n}, {stride, 1});
ck_tile::HostTensor<YDataType> y_host_dev({m, n}, {stride, 1});
ck_tile::HostTensor<InvRmsDataType> invRms_host_ref({m});
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
ck_tile::FillUniformDistribution<GammaDataType>{-.5f, .5f}(gamma_host);
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes());
x_buf.ToDevice(x_host.data());
gamma_buf.ToDevice(gamma_host.data());
constexpr bool kTwoPass = true;
using BlockWarps = ck_tile::sequence<2, 2>;
using BlockTile = ck_tile::sequence<2, 128>;
using WarpTile = ck_tile::sequence<1, 64>;
using Vector = ck_tile::sequence<1, 1>;
using Shape = ck_tile::Rmsnorm2dShape<BlockTile, BlockWarps, WarpTile, Vector>;
using Problem = ck_tile::Rmsnorm2dFwdPipelineProblem<XDataType,
GammaDataType,
ComputeDataType,
YDataType,
InvRmsDataType,
Shape,
true, // kPadN
false, // kSaveInvRms
kTwoPass>;
using OnePassPipeline = ck_tile::Rmsnorm2dFwdPipelineOnePass<Problem>;
using TwoPassPipeline = ck_tile::Rmsnorm2dFwdPipelineTwoPass<Problem>;
using Pipeline = std::conditional_t<kTwoPass, TwoPassPipeline, OnePassPipeline>;
using Kernel = ck_tile::Rmsnorm2dFwd<Pipeline>;
ck_tile::Rmsnorm2dFwdHostArgs args{x_buf.GetDeviceBuffer(),
gamma_buf.GetDeviceBuffer(),
y_buf.GetDeviceBuffer(),
nullptr,
epsilon,
m,
n,
stride};
auto kargs = Kernel::MakeKargs(args);
const dim3 grids = Kernel::GridSize(args);
constexpr dim3 blocks = Kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 1;
auto s = ck_tile::stream_config{nullptr, true, 0, warmup, repeat};
ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
bool pass = true;
if(do_validation)
{
// reference
ck_tile::reference_rmsnorm2d_fwd<XDataType,
GammaDataType,
ComputeDataType,
YDataType,
InvRmsDataType>(
x_host, gamma_host, y_host_ref, invRms_host_ref, epsilon);
y_buf.FromDevice(y_host_dev.data());
auto [rtol, atol] = ck_tile::make_tuple(1e-3, 1e-3);
if(stride == n)
{
pass = ck_tile::check_err(
y_host_dev, y_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);
}
else
{
for(int i_r = 0; i_r < m; i_r++)
{
std::vector<YDataType> y_host_dev_row(y_host_dev.begin() + i_r * stride,
y_host_dev.begin() + i_r * stride + n);
std::vector<YDataType> y_host_ref_row(y_host_ref.begin() + i_r * stride,
y_host_ref.begin() + i_r * stride + n);
pass &= ck_tile::check_err(y_host_dev_row,
y_host_ref_row,
std::string("OUT[") + std::to_string(i_r) +
std::string("] Error: Incorrect results!"),
rtol,
atol);
}
}
std::cout << "[" << data_type << "]"
<< " m:" << m << ", n:" << n << ", stride:" << stride
<< ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
}
return pass;
}
int main(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
const std::string data_type = arg_parser.get_str("prec");
if(data_type == "fp16")
{
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
}
return -3;
}
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp> #include <ck_tile/core.hpp>
#include "layernorm2d_fwd.hpp" #include "rmsnorm2d_fwd.hpp"
template <typename DataType_, template <typename DataType_,
ck_tile::index_t Repeat_M_, // each thread repeat along M ck_tile::index_t Repeat_M_, // each thread repeat along M
...@@ -11,142 +11,140 @@ template <typename DataType_, ...@@ -11,142 +11,140 @@ template <typename DataType_,
ck_tile::index_t ThreadPerBlock_N_, // num threads along N ck_tile::index_t ThreadPerBlock_N_, // num threads along N
ck_tile::index_t Vector_N_, // vector size along N ck_tile::index_t Vector_N_, // vector size along N
bool kPadN_, bool kPadN_,
bool kSaveMeanInvStd_, bool kSaveInvRms_,
bool kTwoPass_> bool kTwoPass_>
using trait_ = layernorm2d_fwd_traits_<DataType_, using trait_ = rmsnorm2d_fwd_traits_<DataType_,
Repeat_M_, Repeat_M_,
Repeat_N_, Repeat_N_,
ThreadPerBlock_M_, ThreadPerBlock_M_,
ThreadPerBlock_N_, ThreadPerBlock_N_,
Vector_N_, Vector_N_,
kPadN_, kPadN_,
kSaveMeanInvStd_, kSaveInvRms_,
kTwoPass_>; kTwoPass_>;
template <typename data_type> template <typename data_type>
float layernorm2d_fwd_b16_(layernorm2d_fwd_traits /*t*/, float rmsnorm2d_fwd_b16_(rmsnorm2d_fwd_traits /*t*/,
layernorm2d_fwd_args a, rmsnorm2d_fwd_args a,
const ck_tile::stream_config& s) const ck_tile::stream_config& s)
{ {
#if 1 #if 1
float r = -1; float r = -1;
// clang-format off // clang-format off
// rm rn tm tn vn pd mv 2p // rm rn tm tn vn pd rms 2p
if(a.n <= 64) { if(a.n <= 64) {
r = layernorm2d_fwd_<trait_<data_type, 1, 1, 4, 64, 1, true, false, false>>(s, a); r = rmsnorm2d_fwd_<trait_<data_type, 1, 1, 4, 64, 1, true, false, false>>(s, a);
} }
else if(a.n <= 128) { else if(a.n <= 128) {
if (a.n % 2 == 0) if (a.n % 2 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 1, 4, 64, 2, true, false, false>>(s, a); r = rmsnorm2d_fwd_<trait_<data_type, 1, 1, 4, 64, 2, true, false, false>>(s, a);
else else
r = layernorm2d_fwd_<trait_<data_type, 1, 2, 4, 64, 1, true, false, false>>(s, a); r = rmsnorm2d_fwd_<trait_<data_type, 1, 2, 4, 64, 1, true, false, false>>(s, a);
} }
else if(a.n <= 256) { else if(a.n <= 256) {
if (a.n % 4 == 0) if (a.n % 4 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 1, 4, 64, 4, true, false, false>>(s, a); r = rmsnorm2d_fwd_<trait_<data_type, 1, 1, 4, 64, 4, true, false, false>>(s, a);
else if (a.n % 2 == 0) else if (a.n % 2 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 2, 4, 64, 2, true, false, false>>(s, a); r = rmsnorm2d_fwd_<trait_<data_type, 1, 2, 4, 64, 2, true, false, false>>(s, a);
else else
r = layernorm2d_fwd_<trait_<data_type, 1, 4, 4, 64, 1, true, false, false>>(s, a); r = rmsnorm2d_fwd_<trait_<data_type, 1, 4, 4, 64, 1, true, false, false>>(s, a);
} }
else if(a.n <= 512) { else if(a.n <= 512) {
if (a.n % 8 == 0) if (a.n % 8 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 1, 4, 64, 8, true, false, false>>(s, a); r = rmsnorm2d_fwd_<trait_<data_type, 1, 1, 4, 64, 8, true, false, false>>(s, a);
else if (a.n % 4 == 0) else if (a.n % 4 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 2, 4, 64, 4, true, false, false>>(s, a); r = rmsnorm2d_fwd_<trait_<data_type, 1, 2, 4, 64, 4, true, false, false>>(s, a);
else if (a.n % 2 == 0) else if (a.n % 2 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 4, 4, 64, 2, true, false, false>>(s, a); r = rmsnorm2d_fwd_<trait_<data_type, 1, 4, 4, 64, 2, true, false, false>>(s, a);
else else
r = layernorm2d_fwd_<trait_<data_type, 1, 8, 4, 64, 1, true, false, false>>(s, a); r = rmsnorm2d_fwd_<trait_<data_type, 1, 8, 4, 64, 1, true, false, false>>(s, a);
} }
else if(a.n <= 768) { else if(a.n <= 768) {
if (a.n % 4 == 0) if (a.n % 4 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 3, 4, 64, 4, true, false, false>>(s, a); r = rmsnorm2d_fwd_<trait_<data_type, 1, 3, 4, 64, 4, true, false, false>>(s, a);
else if (a.n % 2 == 0) else if (a.n % 2 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 6, 4, 64, 2, true, false, false>>(s, a); r = rmsnorm2d_fwd_<trait_<data_type, 1, 6, 4, 64, 2, true, false, false>>(s, a);
else else
r = layernorm2d_fwd_<trait_<data_type, 1,12, 4, 64, 1, true, false, false>>(s, a); r = rmsnorm2d_fwd_<trait_<data_type, 1,12, 4, 64, 1, true, false, false>>(s, a);
} }
else if(a.n <= 1024) { else if(a.n <= 1024) {
if (a.n % 8 == 0) if (a.n % 8 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 1, 2, 128, 8, true, false, false>>(s, a); r = rmsnorm2d_fwd_<trait_<data_type, 1, 1, 2, 128, 8, true, false, false>>(s, a);
else if (a.n % 4 == 0) else if (a.n % 4 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 2, 2, 128, 4, true, false, false>>(s, a); r = rmsnorm2d_fwd_<trait_<data_type, 1, 2, 2, 128, 4, true, false, false>>(s, a);
else if (a.n % 2 == 0) else if (a.n % 2 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 4, 2, 128, 2, true, false, false>>(s, a); r = rmsnorm2d_fwd_<trait_<data_type, 1, 4, 2, 128, 2, true, false, false>>(s, a);
else else
r = layernorm2d_fwd_<trait_<data_type, 1, 4, 1, 256, 1, true, false, false>>(s, a); r = rmsnorm2d_fwd_<trait_<data_type, 1, 4, 1, 256, 1, true, false, false>>(s, a);
} }
else if(a.n <= 1536) { else if(a.n <= 1536) {
if (a.n % 8 == 0) if (a.n % 8 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 3, 4, 64, 8, true, false, false>>(s, a); r = rmsnorm2d_fwd_<trait_<data_type, 1, 3, 4, 64, 8, true, false, false>>(s, a);
else if (a.n % 4 == 0) else if (a.n % 4 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 3, 2, 128, 4, true, false, false>>(s, a); r = rmsnorm2d_fwd_<trait_<data_type, 1, 3, 2, 128, 4, true, false, false>>(s, a);
else if (a.n % 2 == 0) else if (a.n % 2 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 3, 1, 256, 2, true, false, false>>(s, a); r = rmsnorm2d_fwd_<trait_<data_type, 1, 3, 1, 256, 2, true, false, false>>(s, a);
else else
r = layernorm2d_fwd_<trait_<data_type, 1, 6, 1, 256, 1, true, false, false>>(s, a); r = rmsnorm2d_fwd_<trait_<data_type, 1, 6, 1, 256, 1, true, false, false>>(s, a);
} }
else if(a.n <= 2048) { else if(a.n <= 2048) {
if (a.n % 8 == 0) if (a.n % 8 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 1, 1, 256, 8, true, false, false>>(s, a); r = rmsnorm2d_fwd_<trait_<data_type, 1, 1, 1, 256, 8, true, false, false>>(s, a);
else if (a.n % 4 == 0) else if (a.n % 4 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 2, 1, 256, 4, true, false, false>>(s, a); r = rmsnorm2d_fwd_<trait_<data_type, 1, 2, 1, 256, 4, true, false, false>>(s, a);
else if (a.n % 2 == 0) else if (a.n % 2 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 4, 1, 256, 2, true, false, false>>(s, a); r = rmsnorm2d_fwd_<trait_<data_type, 1, 4, 1, 256, 2, true, false, false>>(s, a);
else else
r = layernorm2d_fwd_<trait_<data_type, 1, 8, 1, 256, 1, true, false, false>>(s, a); r = rmsnorm2d_fwd_<trait_<data_type, 1, 8, 1, 256, 1, true, false, false>>(s, a);
} }
else if(a.n <= 3072) { else if(a.n <= 3072) {
if (a.n % 8 == 0) if (a.n % 8 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 3, 1, 128, 8, true, false, false>>(s, a); r = rmsnorm2d_fwd_<trait_<data_type, 1, 3, 1, 128, 8, true, false, false>>(s, a);
else if (a.n % 4 == 0) else if (a.n % 4 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 3, 1, 256, 4, true, false, false>>(s, a); r = rmsnorm2d_fwd_<trait_<data_type, 1, 3, 1, 256, 4, true, false, false>>(s, a);
else if (a.n % 2 == 0) else if (a.n % 2 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 6, 1, 256, 2, true, false, false>>(s, a); r = rmsnorm2d_fwd_<trait_<data_type, 1, 6, 1, 256, 2, true, false, false>>(s, a);
else else
r = layernorm2d_fwd_<trait_<data_type, 1, 3, 1, 1024, 1, true, false, false>>(s, a); r = rmsnorm2d_fwd_<trait_<data_type, 1, 3, 1, 1024, 1, true, false, false>>(s, a);
} }
else if(a.n <= 4096) { else if(a.n <= 4096) {
if (a.n % 8 == 0) if (a.n % 8 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 2, 1, 256, 8, true, false, false>>(s, a); r = rmsnorm2d_fwd_<trait_<data_type, 1, 2, 1, 256, 8, true, false, false>>(s, a);
else if (a.n % 4 == 0) else if (a.n % 4 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 4, 1, 256, 4, true, false, false>>(s, a); r = rmsnorm2d_fwd_<trait_<data_type, 1, 4, 1, 256, 4, true, false, false>>(s, a);
else if (a.n % 2 == 0) else if (a.n % 2 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 2, 1, 1024, 2, true, false, false>>(s, a); r = rmsnorm2d_fwd_<trait_<data_type, 1, 2, 1, 1024, 2, true, false, false>>(s, a);
else else
r = layernorm2d_fwd_<trait_<data_type, 1, 4, 1, 1024, 1, true, false, false>>(s, a); r = rmsnorm2d_fwd_<trait_<data_type, 1, 4, 1, 1024, 1, true, false, false>>(s, a);
} }
else if(a.n > 4096) { else if(a.n > 4096) {
if (a.n % 8 == 0) if (a.n % 8 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 2, 1, 256, 8, true, false, true>>(s, a); r = rmsnorm2d_fwd_<trait_<data_type, 1, 2, 1, 256, 8, true, false, true>>(s, a);
else if (a.n % 4 == 0) else if (a.n % 4 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 4, 1, 256, 4, true, false, true>>(s, a); r = rmsnorm2d_fwd_<trait_<data_type, 1, 4, 1, 256, 4, true, false, true>>(s, a);
else if (a.n % 2 == 0) else if (a.n % 2 == 0)
r = layernorm2d_fwd_<trait_<data_type, 1, 2, 1, 1024, 2, true, false, true>>(s, a); r = rmsnorm2d_fwd_<trait_<data_type, 1, 2, 1, 1024, 2, true, false, true>>(s, a);
else else
r = layernorm2d_fwd_<trait_<data_type, 1, 4, 1, 1024, 1, true, false, true>>(s, a); r = rmsnorm2d_fwd_<trait_<data_type, 1, 4, 1, 1024, 1, true, false, true>>(s, a);
} }
return r; return r;
#else #else
return layernorm2d_fwd_<trait_<data_type, 1, 1, 1, 256, 4, true, false, false>>(s, a); return rmsnorm2d_fwd_<trait_<data_type, 1, 1, 1, 256, 4, true, false, false>>(s, a);
#endif #endif
// clang-format on // clang-format on
} }
float layernorm2d_fwd(layernorm2d_fwd_traits t, float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, rmsnorm2d_fwd_args a, const ck_tile::stream_config& s)
layernorm2d_fwd_args a,
const ck_tile::stream_config& s)
{ {
float r = -1; float r = -1;
if(t.data_type.compare("fp16") == 0) if(t.data_type.compare("fp16") == 0)
{ {
return layernorm2d_fwd_b16_<ck_tile::fp16_t>(t, a, s); return rmsnorm2d_fwd_b16_<ck_tile::fp16_t>(t, a, s);
} }
else if(t.data_type.compare("bf16") == 0) else if(t.data_type.compare("bf16") == 0)
{ {
return layernorm2d_fwd_b16_<ck_tile::bf16_t>(t, a, s); return rmsnorm2d_fwd_b16_<ck_tile::bf16_t>(t, a, s);
} }
if(r < 0) if(r < 0)
throw std::runtime_error("Without supported instances!"); throw std::runtime_error("Without supported instances!");
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
#if 0
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 8, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 4, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 8, 4, 64, 2, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 16, 4, 64, 1, true , false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 4, true , false, false>>(const S&, A);
#endif
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 2, 128, 8, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 2, 128, 4, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 2, 128, 2, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 1, true, false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 3, 4, 64, 8, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 3, 2, 128, 4, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 3, 1, 256, 2, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 6, 1, 256, 1, true, false, false>>(const S&, A);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd rms 2p
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 8, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 1, 256, 4, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 2, true, false, false>>(const S&, A);
template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 8, 1, 256, 1, true, false, false>>(const S&, A);
// clang-format on
...@@ -2,11 +2,11 @@ ...@@ -2,11 +2,11 @@
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include "layernorm2d_fwd_instance_common.hpp" #include "rmsnorm2d_fwd_instance_common.hpp"
// clang-format off // clang-format off
// rm rn tm tn vn pd mv 2p // rm rn tm tn vn pd rms 2p
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 4, true , false, false>>(const S&, A); template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 4, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 2, true , false, false>>(const S&, A); template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 2, true , false, false>>(const S&, A);
template float layernorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 1, true , false, false>>(const S&, A); template float rmsnorm2d_fwd_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 1, true , false, false>>(const S&, A);
// clang-format on // clang-format on
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment