Commit 49c39b51 authored by carlushuang's avatar carlushuang
Browse files

moe pipeline

parent 03c6448b
......@@ -40,7 +40,7 @@ float matrix_core_swizzle(matrix_core_swizzle_traits t,
else if(t.permute.compare("0,1,3,4,2,5") == 0)
{
constexpr matrix_core_permute_style pstyle =
matrix_core_permute_style::permute_b_nr_kr_kw_nw_kv;
matrix_core_permute_style::b_nr_kr_kw_nw_kv;
using Kernel =
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
......@@ -83,7 +83,7 @@ float matrix_core_swizzle(matrix_core_swizzle_traits t,
else if(t.permute.compare("0,1,3,4,2,5") == 0)
{
constexpr matrix_core_permute_style pstyle =
matrix_core_permute_style::permute_b_nr_kr_kw_nw_kv;
matrix_core_permute_style::b_nr_kr_kw_nw_kv;
using Kernel =
matrix_core_swizzle_kernel<BLOCK_SIZE, NPerBlock, KPerBlock, pstyle, Inst>;
......
......@@ -42,8 +42,8 @@ enum class matrix_core_permute_style
{
permute_b_n0_k0_n1_k1_n2_k2 = 0, // 0,1,4,2,5,3,6
permute_b_n0_n1_k0_k1_n2_k2 = 1, // 0,1,2,4,5,3,6
permute_b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5
permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv,
b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5
b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv,
};
// assume this is B matrix, originally we have batch*n*k
......@@ -203,7 +203,7 @@ struct matrix_core_swizzle_kernel
else
{
// clang-format off
// permute_b_nr_kr_kw_nw_kv or permute_b_nr_kr_waveflatten
// b_nr_kr_kw_nw_kv or b_nr_kr_waveflatten
constexpr index_t Kv = Alignment;
constexpr index_t Nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t Kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
......@@ -332,7 +332,7 @@ struct matrix_core_swizzle_kernel
make_tuple(sequence<0>{}, sequence<1>{}));
return tmp_1;
#else
// permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv,
// b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv,
constexpr index_t kv = Alignment;
constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
......@@ -376,13 +376,13 @@ struct matrix_core_swizzle_kernel
else
{
#if MERGE_2D_013425
// permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv
// b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv
return make_tile_window(dst_view,
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
{i_n * NPerBlock, i_k * KPerBlock},
get_dst_dist());
#else
// permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv
// b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv
constexpr index_t kv = Alignment;
constexpr index_t nw = WarpGemm::WarpGemmAttribute::Impl::kAMLane;
constexpr index_t kw = WarpGemm::WarpGemmAttribute::Impl::kABKLane;
......
......@@ -264,7 +264,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
{
if(arg_parser.get_str("perm") == std::string("0,1,3,4,2,5"))
{
// permute_b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5
// b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5
matrix_core_swizzle_traits t;
t.data_type = data_type;
t.permute = arg_parser.get_str("perm");
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/layernorm2d.hpp"
#include <string>
// this is only a convenient structure for creating an example
// this is not part of the host API
template <typename I, typename W, typename O, typename ST, typename SW, typename SQ, typename KW>
struct FusedMoeGemmTypeConfig;
template <typename ST, typename SW, typename SQ, typename KW>
struct FusedMoeGemmTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, ST, SW, SQ, KW>;
{
using ADataType = ck_tile::bf16_t;
using GDataType = ck_tile::bf16_t;
using DDataType = ck_tile::bf16_t;
using AccDataType = float;
using ODataType = ck_tile::bf16_t;
using AScaleDataType = ck_tile::remove_cvref_t<ST>;
using W0ScaleDataType = ck_tile::remove_cvref_t<SW>;
using W1ScaleDataType = ck_tile::remove_cvref_t<SW>;
using YSmoothScaleDataType = ck_tile::remove_cvref_t<SQ>;
using TopkWeightDataType = ck_tile::remove_cvref_t<KW>;
using IndexDataType = ck_tile::index_t;
};
template <typename ST, typename SW, typename SQ, typename KW>
struct FusedMoeGemmTypeConfig<ck_tile::int8_t, ck_tile::int8_t, ck_tile::bf16_t, ST, SW, SQ, KW>;
{
using ADataType = ck_tile::int8_t;
using GDataType = ck_tile::int8_t;
using DDataType = ck_tile::int8_t;
using AccDataType = int32_t;
using ODataType = ck_tile::bf16_t;
using AScaleDataType = ck_tile::remove_cvref_t<ST>;
using W0ScaleDataType = ck_tile::remove_cvref_t<SW>;
using W1ScaleDataType = ck_tile::remove_cvref_t<SW>;
using YSmoothScaleDataType = ck_tile::remove_cvref_t<SQ>;
using TopkWeightDataType = ck_tile::remove_cvref_t<KW>;
using IndexDataType = ck_tile::index_t;
};
// runtime args
struct fused_moegemm_args : public ck_tile::Layernorm2dFwdHostArgs
{
};
// This is the public API, will be generated by script
struct fused_moegemm_traits
{
std::string prec_i; // input precision
std::string prec_w; // weight precision
std::string prec_o; // output precision
std::string prec_st; // token scale data type
std::string prec_sw; // weight scale data type
std::string prec_sq; // smooth quant scale
std::string prec_kw; // topk-weight data type
int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
};
float fused_moegemm(fused_moegemm_traits, fused_moegemm_args, const ck_tile::stream_config&);
#include "ck_tile/host.hpp"
#include "layernorm2d_fwd.hpp"
#include <algorithm>
#include <cstring>
// different threshold for different dtype
template <typename DataType>
auto get_elimit()
{
double rtol = 1e-2;
double atol = 1e-2;
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<ck_tile::bf16_t>()
{
double rtol = 1e-2;
double atol = 1e-2;
return ck_tile::make_tuple(rtol, atol);
}
// mfma_type, 0:32x32, 1:16x16
template<typename H>
auto shuffle_moe_weight(const H& t, std::string mfma_dtype, int mfma_type = 0)
{
static_assert(t.get_lengths().size() == 3);
int b_ = t.get_lengths()[0];
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[2];
if ((mfma_dtype == "bf16" || mfma_dtype == "fp16") && mfma_type == 0) {
std::vector<ck_tile::index_t> new_lens {b_, n_/32, 32, k_/16, 2, 8};
}
}
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("t", "128", "num input tokens")
.insert("e", "32", "num of experts")
.insert("k", "5", "topk")
.insert("h", "8192", "hidden_size of this model")
.insert("i", "8192", "intermediate_size between 2 gemms of FFN")
.insert("stride", "-1", "stride per row, if -1 then equal to hidden_size")
.insert("bm", "32", "blocking factor for sorted tokens")
.insert("tp", "8", "tensor parallel size")
.insert("v", "1", "cpu validation or not")
.insert("kname", "1", "print kernel name or not")
.insert("prec_i", "bf16", "input precision")
.insert("prec_w", "bf16", "weight precision")
.insert("prec_o", "bf16", "output precision")
.insert("prec_st", "auto", "token scale data type. auto will set to fp32")
.insert("prec_sw", "auto", "weight scale data type. auto will set to fp32")
.insert("prec_sq", "auto", "(dynamic) smooth quant data type. auto will set to fp32")
.insert("prec_kw", "auto", "topk-weight data type. auto will set to fp32")
.insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant")
.insert("gonly", "0", "w0(gate/up) style, 0:gate+up will double interm size, 1:only gate")
.insert("balance", "1", "if set to 1, will try balance the expert in topk-ids(convenient for testing)")
.insert("warmup", "5", "cold iter")
.insert("repeat", "20", "hot iter");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
// I:input-type, W:weight-type, O:output-type, ST:toke-scale-tpye, SW:weight-scale-type, SQ:smooth-quant-type, KW:topk-weight-type
template <typename I, typename W, typename O, typename ST, typename SW, typename SQ, typename KW>
bool run(const ck_tile::ArgParser& arg_parser)
{
ck_tile::index_t tokens = arg_parser.get_int("t");
ck_tile::index_t experts = arg_parser.get_int("e");
ck_tile::index_t topk = arg_parser.get_int("k");
ck_tile::index_t hidden_size = arg_parser.get_int("h");
ck_tile::index_t intermediate_size = arg_parser.get_int("i");
ck_tile::index_t stride = arg_parser.get_int("stride");
ck_tile::index_t block_m = arg_parser.get_int("bm");
if(stride < 0)
stride = hidden_size;
std::string prec_i = arg_parser.get_str("prec_i");
std::string prec_w = arg_parser.get_str("prec_w");
std::string prec_o = arg_parser.get_str("prec_o");
std::string prec_st = arg_parser.get_str("prec_st");
std::string prec_sw = arg_parser.get_str("prec_sw");
std::string prec_sq = arg_parser.get_str("prec_sq");
std::string prec_kw = arg_parser.get_str("prec_kw");
prec_st = (prec_st == "auto") ? "fp32" : prec_st;
prec_sw = (prec_sw == "auto") ? "fp32" : prec_sw;
prec_sq = (prec_sq == "auto") ? "fp32" : prec_sq;
prec_kw = (prec_kw == "auto") ? "fp32" : prec_kw;
int kname = arg_parser.get_int("kname");
int do_validation = arg_parser.get_int("v");
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
int fused_quant = arg_parser.get_int("fquant");
int gonly = arg_parser.get_int("gonly");
int balance = arg_parser.get_int("balance");
int tp = arg_parser.get_int("tp");
ck_tile::index_t shared_intermediate_size = intermediate_size * (gonly ? 1 : 2) / tp;
using TypeConfig = FusedMoeGemmTypeConfig<I, W, O, ST, SW, SQ, KW>;
using ADataType = typename TypeConfig::ADataType ;
using GDataType = typename TypeConfig::GDataType ;
using DDataType = typename TypeConfig::DDataType ;
using AccDataType = typename TypeConfig::AccDataType ;
using ODataType = typename TypeConfig::ODataType ;
using AScaleDataType = typename TypeConfig::AScaleDataType ;
using W0ScaleDataType = typename TypeConfig::W0ScaleDataType ;
using W1ScaleDataType = typename TypeConfig::W1ScaleDataType ;
using YSmoothScaleDataType = typename TypeConfig::YSmoothScaleDataType;
using TopkWeightDataType = typename TypeConfig::TopkWeightDataType ;
using IndexDataType = typename TypeConfig::IndexDataType ;
// host verify
ck_tile::HostTensor<ADataType> a_host({tokens, hidden_size}, {stride, 1});
ck_tile::HostTensor<ADataType> g_host({e, shared_intermediate_size, hidden_size});
ck_tile::HostTensor<ADataType> d_host({e, intermediate_size, hidden_size});
ck_tile::HostTensor<XResidualDataType> x_residual_host({m, n}, {stride, 1});
ck_tile::HostTensor<YResidualDataType> y_residual_host({m, n}, {stride, 1});
ck_tile::HostTensor<YDataType> y_host_ref({m, n}, {stride, 1});
ck_tile::HostTensor<YDataType> y_host_dev({m, n}, {stride, 1});
ck_tile::HostTensor<MeanDataType> mean_host_ref({m});
ck_tile::HostTensor<InvStdDataType> invStd_host_ref({m});
ck_tile::HostTensor<YScaleDataType> y_scale_host_ref({m});
ck_tile::HostTensor<YScaleDataType> y_scale_host_dev({m});
ck_tile::HostTensor<XScaleDataType> x_scale_host({n});
ck_tile::HostTensor<XScaleDataType> x_scale_host_dev({n});
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_host);
ck_tile::FillUniformDistribution<XResidualDataType>{-.5f, .5f}(x_residual_host);
ck_tile::FillUniformDistribution<XScaleDataType>{-1.f, 1.f}(x_scale_host);
ck_tile::FillUniformDistribution<GammaDataType>{-.5f, .5f}(gamma_host);
ck_tile::FillUniformDistribution<BetaDataType>{-.5f, .5f}(beta_host);
ck_tile::DeviceMem x_buf(a_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem beta_buf(beta_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes());
ck_tile::DeviceMem y_scale_buf(y_scale_host_dev.get_element_space_size_in_bytes());
ck_tile::DeviceMem x_scale_buf(x_scale_host_dev.get_element_space_size_in_bytes());
ck_tile::DeviceMem x_residual_buf(x_residual_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem y_residual_buf(y_residual_host.get_element_space_size_in_bytes());
x_buf.ToDevice(a_host.data());
gamma_buf.ToDevice(gamma_host.data());
beta_buf.ToDevice(beta_host.data());
x_residual_buf.ToDevice(x_residual_host.data());
x_scale_buf.ToDevice(x_scale_host.data());
auto prec_str = [&]() {
auto base_str = prec_i;
if(prec_i != prec_o)
{
base_str += "|" + prec_o;
}
if(fused_quant == 1)
{
base_str += std::string("(") + prec_sy + ")";
}
return base_str;
}();
std::cout << "[" << prec_str << "]"
<< " m:" << m << ", n:" << n << ", stride:" << stride << std::flush;
layernorm2d_fwd_traits traits{
prec_i, prec_o, prec_sx, prec_sy, SaveMeanVar, fused_add, fused_quant};
layernorm2d_fwd_args args{x_buf.GetDeviceBuffer(),
fused_add != 0 ? x_residual_buf.GetDeviceBuffer() : nullptr,
fused_quant == 1 ? x_scale_buf.GetDeviceBuffer() : nullptr,
gamma_buf.GetDeviceBuffer(),
beta_buf.GetDeviceBuffer(),
y_buf.GetDeviceBuffer(),
fused_add == 1 ? y_residual_buf.GetDeviceBuffer() : nullptr,
fused_quant != 0 ? y_scale_buf.GetDeviceBuffer() : nullptr,
nullptr, // p_mean, unsupported yet
nullptr, // p_invStd, unsupported yet
epsilon,
m,
n,
stride};
float ave_time = layernorm2d_fwd(
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
if(ave_time < 0)
{
std::cout << " not supported!" << std::endl << std::flush;
return false;
}
std::size_t num_byte = sizeof(ADataType) * m * n + sizeof(GammaDataType) * n +
sizeof(BetaDataType) * n + sizeof(YDataType) * m * n;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush;
bool pass = true;
if(do_validation)
{
// reference
if(fused_add != 0)
{
// fused pre_add/pre_add_store
// TODO we accumulate directly to a_host for simplcity here...
std::transform(a_host.mData.cbegin(),
a_host.mData.cend(),
x_residual_host.mData.cbegin(),
a_host.mData.begin(),
[](auto x_, auto r_) {
auto o_ = ck_tile::type_convert<ComputeDataType>(x_) +
ck_tile::type_convert<ComputeDataType>(r_);
return ck_tile::type_convert<ADataType>(o_);
});
}
ck_tile::reference_layernorm2d_fwd<ADataType,
GammaDataType,
BetaDataType,
ComputeDataType,
YDataType,
MeanDataType,
InvStdDataType>(
a_host, gamma_host, beta_host, y_host_ref, mean_host_ref, invStd_host_ref, epsilon);
if(fused_quant != 0)
{
auto dquant_functor = [&](int m_, auto& o_, auto& acc_) {
int N_ = acc_.mDesc.get_lengths()[1];
if(fused_quant == 1)
{
for(int n_ = 0; n_ < N_; n_++)
{
// input smooth outlier
acc_(m_, n_) =
acc_(m_, n_) * ck_tile::type_convert<ComputeDataType>(x_scale_host(n_));
}
}
ComputeDataType absmax = static_cast<ComputeDataType>(0);
for(int n_ = 0; n_ < N_; n_++)
{
const auto a = ck_tile::abs(acc_(m_, n_));
absmax = a > absmax ? a : absmax;
}
// printf("cpu:absmax:%f\n", absmax);
ComputeDataType y_scale = absmax / static_cast<ComputeDataType>(127.0);
y_scale_host_ref(m_) = ck_tile::type_convert<YScaleDataType>(y_scale);
for(int n_ = 0; n_ < N_; n_++)
{
o_(m_, n_) = ck_tile::type_convert<YDataType>(acc_(m_, n_) / y_scale);
}
};
ck_tile::reference_layernorm2d_fwd<ADataType,
GammaDataType,
BetaDataType,
ComputeDataType,
YDataType,
MeanDataType,
InvStdDataType>(a_host,
gamma_host,
beta_host,
y_host_ref,
mean_host_ref,
invStd_host_ref,
epsilon,
dquant_functor);
}
else
{
ck_tile::reference_layernorm2d_fwd<ADataType,
GammaDataType,
BetaDataType,
ComputeDataType,
YDataType,
MeanDataType,
InvStdDataType>(
a_host, gamma_host, beta_host, y_host_ref, mean_host_ref, invStd_host_ref, epsilon);
}
y_buf.FromDevice(y_host_dev.data());
ck_tile::HostTensor<YResidualDataType> y_residual_host_dev({m, n}, {stride, 1});
if(fused_add == 1)
{
y_residual_buf.FromDevice(y_residual_host_dev.data());
}
auto [rtol, atol] = get_elimit<InDataType>();
if(stride == n)
{
pass = ck_tile::check_err(
y_host_dev, y_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);
if(fused_add == 1)
{
pass &= ck_tile::check_err(y_residual_host_dev,
a_host,
std::string("ADD Error: Incorrect results!"),
rtol,
atol);
}
}
else
{
for(int i_r = 0; i_r < m; i_r++)
{
std::vector<YDataType> y_host_dev_row(y_host_dev.begin() + i_r * stride,
y_host_dev.begin() + i_r * stride + n);
std::vector<YDataType> y_host_ref_row(y_host_ref.begin() + i_r * stride,
y_host_ref.begin() + i_r * stride + n);
pass &= ck_tile::check_err(y_host_dev_row,
y_host_ref_row,
std::string("OUT[") + std::to_string(i_r) +
std::string("] Error: Incorrect results!"),
rtol,
atol);
if(fused_add == 1)
{
std::vector<YResidualDataType> y_residual_host_dev_row(
y_residual_host_dev.begin() + i_r * stride,
y_residual_host_dev.begin() + i_r * stride + n);
std::vector<YResidualDataType> y_residual_host_ref_row(
a_host.begin() + i_r * stride, a_host.begin() + i_r * stride + n);
pass &= ck_tile::check_err(y_residual_host_dev_row,
y_residual_host_ref_row,
std::string("ADD[") + std::to_string(i_r) +
std::string("] Error: Incorrect results!"),
rtol,
atol);
}
}
}
if(fused_quant == 1)
{
y_scale_buf.FromDevice(y_scale_host_dev.data());
pass &= ck_tile::check_err(y_scale_host_dev,
y_scale_host_ref,
std::string("SCALE Error: Incorrect results!"),
rtol,
atol);
}
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
}
return pass;
}
int main(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
std::string prec_i = arg_parser.get_str("prec_i");
std::string prec_o = arg_parser.get_str("prec_o");
std::string prec_sx = arg_parser.get_str("prec_sx");
std::string prec_sy = arg_parser.get_str("prec_sy");
if(prec_o == "auto")
{
prec_o = prec_i;
}
if(prec_sx == "auto")
{
prec_sx = "fp32";
}
if(prec_sy == "auto")
{
prec_sy = "fp32";
}
int save_mv = arg_parser.get_int("save_mv");
// no dynamic quant case
if(prec_i == "fp16" && prec_o == "fp16" && prec_sx == "fp32" && prec_sy == "fp32")
{
return run<ck_tile::half_t, ck_tile::half_t, float, float, true>(arg_parser) ? 0 : -2;
}
else if(prec_i == "fp16" && prec_o == "fp16" && prec_sx == "fp32" && prec_sy == "fp32")
{
return run<ck_tile::half_t, ck_tile::half_t, float, float, false>(arg_parser) ? 0 : -2;
}
else if(prec_i == "bf16" && prec_o == "bf16" && prec_sx == "fp32" && prec_sy == "fp32")
{
return run<ck_tile::bf16_t, ck_tile::bf16_t, float, float, true>(arg_parser) ? 0 : -2;
}
else if(prec_i == "bf16" && prec_o == "bf16" && prec_sx == "fp32" && prec_sy == "fp32")
{
return run<ck_tile::bf16_t, ck_tile::bf16_t, float, float, true>(arg_parser) ? 0 : -2;
}
// dynamic quant case, only in inference
else if(prec_i == "fp16" && prec_o == "int8" && prec_sx == "fp32" && prec_sy == "fp32")
{
return run<ck_tile::half_t, ck_tile::int8_t, float, float, false>(arg_parser) ? 0 : -2;
}
else if(prec_i == "bf16" && prec_o == "int8" && prec_sx == "fp32" && prec_sy == "fp32")
{
return run<ck_tile::bf16_t, ck_tile::int8_t, float, float, false>(arg_parser) ? 0 : -2;
}
return -3;
}
......@@ -53,6 +53,11 @@ template<> struct buffer_load_trait<4 , thread_buffer<bf16_t, 2>> { using payloa
// clang-format on
} // namespace impl
// TODO: this is hot-tmp fix to unblock user case. Need refactor into template arg
#ifndef CK_TILE_BUFFER_LOAD_AGPR
#define CK_TILE_BUFFER_LOAD_AGPR 0
#endif
// TODO: glc/slc/...
template <index_t bytes, bool pre_nop = false>
struct buffer_load;
......@@ -74,6 +79,19 @@ struct buffer_load<16, pre_nop>
{
static_assert(sizeof(T) == 16);
using mbuf_t = typename impl::buffer_load_trait<16, T>::payload_t;
#if CK_TILE_BUFFER_LOAD_AGPR
if constexpr(pre_nop)
asm volatile("s_nop 4\n"
"buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3"
: "=a"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
else
asm volatile("buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3"
: "=a"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
#else
if constexpr(pre_nop)
asm volatile("s_nop 4\n"
"buffer_load_dwordx4 %0, %1, %2, 0 offen offset:%3"
......@@ -85,6 +103,7 @@ struct buffer_load<16, pre_nop>
: "+v"(reinterpret_cast<mbuf_t&>(value))
: "v"(v_offset), "s"(res), "n"(i_offset)
: "memory");
#endif
}
};
......@@ -621,6 +640,60 @@ CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0)
asm volatile("s_waitcnt vmcnt(%0)" : : "n"(cnt) : "memory");
}
template <typename scalar_type, index_t N, bool pre_nop = false>
struct buffer_atomic_add_if;
template <bool pre_nop>
struct buffer_atomic_add_if<bf16_t, 2, pre_nop>
{
template <typename T>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t flag = 1)
{
static_assert(sizeof(T) == 4);
auto save_exec = __builtin_amdgcn_read_exec();
using mbuf_t = float;
asm volatile("v_cmpx_le_u32 exec, 1, %4\n"
"global_atomic_pk_add_bf16 %0, %1, %2 offset:%3\n"
"s_mov_b64 exec %5"
:
: "v"(v_offset),
"v"(bit_cast<mbuf_t>(value)),
"s"(res.xy),
"n"(i_offset),
"v"(flag),
"s"(save_exec)
: "memory");
}
};
template <typename scalar_type, index_t N, bool pre_nop = false>
struct buffer_atomic_add;
template <bool pre_nop>
struct buffer_atomic_add<bf16_t, 2, pre_nop>
{
template <typename T>
CK_TILE_DEVICE void operator()(const T& value,
int32x4_t res /*buffer resource*/,
index_t v_offset,
index_t /*s_offset*/,
index_t i_offset /*max 0xFFF*/,
index_t /*flag = 1*/)
{
static_assert(sizeof(T) == 4);
using mbuf_t = float;
asm volatile("global_atomic_pk_add_bf16 %0, %1, %2 offset:%3"
:
: "v"(v_offset), "v"(bit_cast<mbuf_t>(value)), "s"(res.xy), "n"(i_offset)
: "memory");
}
};
namespace impl {
// below type indicate the data type used for buffer load inline asm
// clang-format off
......@@ -2378,6 +2451,45 @@ CK_TILE_DEVICE void amd_buffer_atomic_add(const thread_buffer<T, N>& src_thread_
#endif
}
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE void amd_buffer_atomic_add_raw(const thread_buffer<T, N>& src_thread_data,
T* p_dst_wave,
const index_t dst_thread_element_offset,
const index_t dst_linear_element_offset,
const bool dst_thread_element_valid,
const index_t dst_element_space_size,
bool_constant<pre_nop> = {})
{
const int32x4_t dst_wave_buffer_resource =
make_wave_buffer_resource(p_dst_wave, dst_element_space_size * sizeof(T));
index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T);
index_t dst_linear_addr_offset = dst_linear_element_offset * sizeof(T);
if constexpr(oob_conditional_check)
{
buffer_atomic_add_if<T, N, pre_nop>{}(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
0,
dst_linear_addr_offset,
dst_thread_element_valid);
}
else
{
buffer_atomic_add<T, N, pre_nop>{}(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
0,
dst_linear_addr_offset,
1);
}
}
// buffer_atomic_max requires:
// 1) p_dst_wave must point to global memory
// 2) p_dst_wave must be a wavewise pointer.
......
......@@ -437,34 +437,74 @@ struct buffer_view<address_space_enum::global,
// i is offset of T, not X. i should be aligned to X
template <memory_operation_enum Op,
typename X,
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE void update(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
CK_TILE_DEVICE void update(index_t i,
index_t linear_offset,
bool is_valid_element,
const X& x,
bool_constant<oob_conditional_check> = {})
{
if constexpr(Op == memory_operation_enum::set)
{
this->template set<X>(i, linear_offset, is_valid_element, x);
this->template set<X, oob_conditional_check>(i, linear_offset, is_valid_element, x);
}
else if constexpr(Op == memory_operation_enum::atomic_add)
{
this->template atomic_add<X>(i, linear_offset, is_valid_element, x);
this->template atomic_add<X, oob_conditional_check>(
i, linear_offset, is_valid_element, x);
}
else if constexpr(Op == memory_operation_enum::atomic_max)
{
this->template atomic_max<X>(i, linear_offset, is_valid_element, x);
this->template atomic_max<X, oob_conditional_check>(
i, linear_offset, is_valid_element, x);
}
// FIXME: remove memory_operation_enum::add
else if constexpr(Op == memory_operation_enum::add)
{
auto tmp = this->template get<X>(i, linear_offset, is_valid_element);
this->template set<X>(i, linear_offset, is_valid_element, x + tmp);
auto tmp =
this->template get<X, oob_conditional_check>(i, linear_offset, is_valid_element);
this->template set<X, oob_conditional_check>(
i, linear_offset, is_valid_element, x + tmp);
// tmp += x;
// this->template set<X>(i, is_valid_element, tmp);
}
}
// i is offset of T, not X. i should be aligned to X
template <memory_operation_enum Op,
typename X,
bool oob_conditional_check = true,
bool pre_nop = false,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE void update_raw(index_t i,
index_t linear_offset,
bool is_valid_element,
const X& x,
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
if constexpr(Op == memory_operation_enum::set)
{
this->template set_raw<X, oob_conditional_check>(i, linear_offset, is_valid_element, x);
}
else if constexpr(Op == memory_operation_enum::atomic_add)
{
this->template atomic_add_raw<X, oob_conditional_check, pre_nop>(
i, linear_offset, is_valid_element, x);
}
else if constexpr(Op == memory_operation_enum::atomic_max)
{
// this->template atomic_max_raw<X>(i, linear_offset, is_valid_element, x);
}
}
// i is offset of T, not X. i should be aligned to X
template <typename X,
bool oob_conditional_check = true,
......@@ -533,6 +573,7 @@ struct buffer_view<address_space_enum::global,
}
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
......@@ -585,6 +626,57 @@ struct buffer_view<address_space_enum::global,
}
template <typename X,
bool oob_conditional_check = true,
bool pre_nop = true,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
bool>::type = false>
CK_TILE_DEVICE void
atomic_add_raw(index_t i, index_t linear_offset, bool is_valid_element, const X& x)
{
using scalar_t = typename vector_traits<remove_cvref_t<T>>::scalar_type;
// X contains multiple T
constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;
constexpr index_t scalar_per_x_vector = vector_traits<remove_cvref_t<X>>::vector_size;
static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
"wrong! X should contain multiple T");
static_assert(get_address_space() == address_space_enum::global, "only support global mem");
#if CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
bool constexpr use_amd_buffer_addressing =
std::is_same_v<remove_cvref_t<scalar_t>, int32_t> ||
std::is_same_v<remove_cvref_t<scalar_t>, float> ||
(std::is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0) ||
(std::is_same_v<remove_cvref_t<scalar_t>, bf16_t> && scalar_per_x_vector % 2 == 0);
#elif CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT)
bool constexpr use_amd_buffer_addressing =
std::is_same_v<remove_cvref_t<scalar_t>, int32_t>;
#elif(!CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
bool constexpr use_amd_buffer_addressing =
std::is_same_v<remove_cvref_t<scalar_t>, float> ||
(std::is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0) ||
(std::is_same_v<remove_cvref_t<scalar_t>, bf16_t> && scalar_per_x_vector % 2 == 0);
#else
bool constexpr use_amd_buffer_addressing = false;
#endif
constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
amd_buffer_atomic_add_raw<remove_cvref_t<T>,
t_per_x,
Coherence,
oob_conditional_check,
pre_nop>(
x, p_data_, i, linear_offset, is_valid_element, buffer_size_);
}
template <typename X,
bool oob_conditional_check = true,
typename std::enable_if<
std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
......
......@@ -22,28 +22,32 @@ template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
index_t i_access = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
return tile_window.load(number<-1>{}, bool_constant<oob_conditional_check>{});
return tile_window.load(number<i_access>{}, bool_constant<oob_conditional_check>{});
}
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
index_t i_access = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(const tile_window_linear<BottomTensorView_,
WindowLengths_,
TileDistribution_,
LinearBottomDims_>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
return tile_window.load(number<-1>{}, bool_constant<oob_conditional_check>{});
return tile_window.load(number<i_access>{}, bool_constant<oob_conditional_check>{});
}
template <typename DistributedTensor_,
......@@ -51,15 +55,17 @@ template <typename DistributedTensor_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
index_t i_access = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile,
const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
return tile_window.load(dst_tile, bool_constant<oob_conditional_check>{});
return tile_window.load(dst_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
}
/**
......@@ -76,6 +82,7 @@ template <typename T,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto load_tile_raw(T& tile,
......@@ -83,11 +90,12 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
tile_window.load_raw(
tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
tile, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
}
template <typename T,
......@@ -95,6 +103,7 @@ template <typename T,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto load_tile_raw(T& tile,
......@@ -102,11 +111,12 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
WindowLengths_,
TileDistribution_,
LinearBottomDims_>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
tile_window.load_raw(
tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
tile, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
}
template <typename LdsTileWindow_,
......@@ -114,6 +124,7 @@ template <typename LdsTileWindow_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto
......@@ -122,11 +133,14 @@ async_load_tile_raw(LdsTileWindow_&& lds_tile,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
return tile_window.async_load_raw(
lds_tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
return tile_window.async_load_raw(lds_tile,
number<i_access>{},
bool_constant<oob_conditional_check>{},
bool_constant<pre_nop>{});
}
template <typename LdsTileWindow_,
......@@ -134,6 +148,7 @@ template <typename LdsTileWindow_,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile,
......@@ -141,11 +156,14 @@ CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile,
WindowLengths_,
TileDistribution_,
LinearBottomDims_>& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
return tile_window.async_load_raw(
lds_tile, number<-1>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
return tile_window.async_load_raw(lds_tile,
number<i_access>{},
bool_constant<oob_conditional_check>{},
bool_constant<pre_nop>{});
}
CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0)
......
......@@ -333,6 +333,48 @@ struct tensor_view
coord.get_offset(), linear_offset, is_valid_element, x);
}
// X is vector of DataType.
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
template <typename X,
bool oob_conditional_check = true,
bool pre_nop = false,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void
update_vectorized_elements_raw(const TensorCoord& coord,
index_t linear_offset,
const X& x,
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
buf_.template update_raw<DstInMemOp, X, oob_conditional_check, pre_nop>(
coord.get_offset(),
linear_offset,
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
x);
}
template <typename X,
bool oob_conditional_check = true,
bool pre_nop = false,
typename std::enable_if<
std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void
update_vectorized_elements_raw(const TensorCoord& coord,
index_t linear_offset,
bool is_valid_element,
const X& x,
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
buf_.template update_raw<DstInMemOp, X, oob_conditional_check, pre_nop>(
coord.get_offset(), linear_offset, is_valid_element, x);
}
CK_TILE_HOST_DEVICE void print() const
{
printf("tensor_view{");
......
......@@ -785,6 +785,73 @@ struct tile_window_with_static_distribution
});
}
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true, bool pre_nop>
CK_TILE_DEVICE void update_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const
{
using Traits = load_store_traits;
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{};
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
/// TODO: use structure binding (to be captured later) if compiled in C++20
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
// read from distributed tensor
vector_t vec_value;
static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_tuple(
[&](auto jj) {
return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
: idx_ys_start[jj];
},
number<NDimY>{});
constexpr index_t d =
tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
vec_value.template get_as<DataType>()(j) =
dstr_tensor.get_thread_buffer().template at<d>();
});
// write into bottom tensor
get_bottom_tensor_view().template update_vectorized_elements_raw<vector_t>(
bottom_tensor_thread_coord,
0,
vec_value,
bool_constant<oob_conditional_check>{},
bool_constant<pre_nop>);
// move thread coordinate
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
}
});
});
}
// move thread's botom tensor coordiante
// [x0', x1', ... ] ==> [offset]
// also move window-origin
......
......@@ -849,6 +849,58 @@ struct tile_window_linear
WINDOW_DISPATCH_ISSUE();
}
template <index_t i_access = -1, bool oob_conditional_check = true, bool pre_nop = false>
CK_TILE_DEVICE void update_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
number<i_access> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {}) const
{
using vector_t = typename traits::vector_t;
using SFC_Ys = typename traits::SFC_Ys;
constexpr auto tile_dstr = TileDstr{};
// loop over thread tensor space [y0, y1, ...]
auto issue = [&](auto i_access_) {
constexpr auto IAccess = number<i_access_>{};
constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
constexpr auto linear_offset = get_bottom_linear_offset(IAccess);
auto bottom_tensor_flag = cached_flags_[IAccess];
// data index [y0, y1, ...]
constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess);
// read from distributed tensor
vector_t vec_value;
static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
constexpr auto idx_ys = generate_tuple(
[&](auto jj) {
return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
},
number<NDimY>{});
constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
vec_value.template get_as<DataType>()(j) =
dstr_tensor.get_thread_buffer().template at<d>();
});
// write into bottom tensor
get_bottom_tensor_view().template update_vectorized_elements_raw<vector_t>(
bottom_tensor_thread_coord,
linear_offset,
bottom_tensor_flag,
vec_value,
bool_constant<oob_conditional_check>{},
bool_constant<pre_nop>{});
};
WINDOW_DISPATCH_ISSUE();
}
// move thread's botom tensor coordiante
// [x0', x1', ... ] ==> [offset]
// also move window-origin
......
......@@ -41,15 +41,64 @@ template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
typename DataType_>
typename DataType_,
index_t i_access = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE void
update_tile(tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
tile_window.update(dstr_tensor);
tile_window.update(dstr_tensor, number<i_access>{}, bool_constant<oob_conditional_check>{});
}
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
typename DataType_,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE void
update_tile_raw(tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& tile_window,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
number<i_access> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
tile_window.update_raw(dstr_tensor,
number<i_access>{},
bool_constant<oob_conditional_check>{},
bool_constant<pre_nop>{});
}
template <typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
typename LinearBottomDims_,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto update_tile_raw(
tile_window_linear<BottomTensorView_, WindowLengths_, TileDistribution_, LinearBottomDims_>&
tile_window,
const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
number<i_access> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<pre_nop> = {})
{
tile_window.update_raw(dstr_tensor,
number<i_access>{},
bool_constant<oob_conditional_check>{},
bool_constant<pre_nop>{});
}
} // namespace ck_tile
......@@ -16,7 +16,7 @@ namespace ck_tile {
*/
template <typename DataType>
CK_TILE_HOST void
reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::vector<index_t> dims)
reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::vector<index_t> perm)
{
const auto x_len = x.mDesc.get_lengths();
const auto y_len = y.mDesc.get_lengths();
......@@ -43,7 +43,7 @@ reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::v
std::vector<size_t> tmp(rank, 0);
for(index_t i = 0; i < rank; i++)
{
tmp[dims[i]] = y_coord[i];
tmp[perm[i]] = y_coord[i];
}
return tmp;
}();
......@@ -54,4 +54,24 @@ reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::v
make_ParallelTensorFunctor(f, x_elm)(std::thread::hardware_concurrency());
}
template <typename DataType>
CK_TILE_HOST auto
reference_permute(const HostTensor<DataType>& x, std::vector<index_t> perm)
{
auto x_shape = x.get_lengths();
ck_tile::index_t rank = perm.size();
std::vector<ck_tile::index_t> y_shape = [&]() {
std::vector<ck_tile::index_t> tmp(rank, 0);
for(int i = 0; i < static_cast<int>(rank); i++)
{
tmp[i] = x_shape[perm[i]];
}
return tmp;
}();
HostTensor<DataType> y(y_shape);
reference_permute(x, y, perm);
return y;
}
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fused_moe/kernel/fused_moe_kernel.hpp"
#include "ck_tile/ops/fused_moe/kernel/fused_moe_shape.hpp"
#include "ck_tile/ops/fused_moe/kernel/fused_moe_tile_partitioner.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moe_pipeline_flatmm.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moe_pipeline_flatmm_policy.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moe_pipeline_problem.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moe_traits.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include <string>
#include <type_traits>
// clang-format off
// [indexing implementation-1]
// using M_a as constexpr block_size to partition all tokens into different slices
// each slice map to one expert, and one expert can have multiple slices
// e.g. num_experts = 6, top_k=3, M_a = 4, input_tokens = 5
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
// tok-0 tok-1 tok-2 tok-3 tok-4
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number)
//
// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 5]]
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
// max_tokens_post_padded : top_k * input_tokens + num_experts * (M_a - 1)
// * this could be larger than actual, since actual tokens are on GPU
//
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|- exp-5 -|
// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
//
// * length is max_tokens_post_padded, actual size is num_tokens_post_padded_ptr
//
// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 4, 5]
// * length is (max_tokens_post_padded + block_size - 1) / block_size
//
// num_tokens_post_padded_ptr : [28]
// num_sorted_tiles_ptr : [7]
//
// * different from vLLM
// 1) token_id stored in sorted_token_ids_ptr is actual token_id, not token_id*top_K expanded id
// 2)need sorted_weight_ptr
// 3) use num_sorted_tiles_ptr, already divided by M_a
//
// * below used for indexing
// 1) sorted_token_ids_ptr
// 2) sorted_weight_ptr
// 3) sorted_expert_ids_ptr
// 4)num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one)
//
//
// [indexing implementation-2]
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
// tok-0 tok-1 tok-2 tok-3 tok-4
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number)
//
// we generate original rol/col id as
// topk_rc_ids : [[0, 5, A], [1, 6, B], [2, 7, C], [3, 8, D], [4, 9, E]]
// let x be one element of above, we can get:
// tpok_row_id(token_id) = x % num_tokens(5)
// tpok_col_id(expert_Id) = x / num_tokens
// topk_row_id/col_id can be used to access original topk_ids/topk_weight
//
// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 5, 5]]
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
// we can get permuted_rc_ids:
// [[0], [2, 3, 4], [1, 8], [5, 6, 7, D, 9], [], [A, B, C, E]]
//
//
// clang-format on
//
namespace ck_tile {
// m: num_tokens (or token*input-batch)
// k: intermediate_size
// n: intermediate_size used between 2 FC (TP slice this)
// e: num expert
// if doing pre-shuffle
// nr : n / Block_Nr
// kr : k / Block_Kr
// w : fattened 1d wave buffer
struct FusedMoeGemmHostArgs
{
const void* a_ptr; // [m, k], input token
const void* a_scale_ptr; // [m, 1], token scale
const void* g_ptr; // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w])
const void* d_ptr; // [e, n, k], pre-shuffle([e, nr, kr, w])
const void* g_scale_ptr; // [e, 1, n], gate(up) scale
const void* d_scale_ptr; // [e, 1, k], down scale
const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input
void* o_ptr; // [m, k], output token
const void* sorted_token_ids_ptr;
const void* sorted_weight_ptr;
const void* sorted_expert_ids_ptr;
const void* num_sorted_tiles_ptr;
index_t hidden_size; // k
index_t intermediate_size; // n (TP slice this)
index_t num_tokens; // input number of tokens for current iteration
index_t num_experts; // number of groups
// index_t top_k; // need this?
index_t stride_token; // for input/output, stride for each row, should >= hidden_size
};
// This is scatter/gather b2b group-gemm
template <typename Partitioner_, typename Pipeline_, typename Epilogue_>
struct FusedMoeGemmKernel
{
using Partitioner = remove_cvref_t<Partitioner_>;
using Pipeline = remove_cvref_t<Pipeline_>;
using Epilogue = remove_cvref_t<Epilogue_>; // TODO: not used
static constexpr index_t kBlockSize = Pipeline::kBlockSize;
// static constexpr index_t kBlockPerCu = Pipeline::kBlockPerCu;
// static_assert(kBlockPerCu > 0);
using BlockShape = typename Pipeline::BlockShape; // this is FusedMoeGemmShape
using ADataType = typename Pipeline::Problem::ADataType;
using GDataType = typename Pipeline::Problem::GDataType;
using DDataType = typename Pipeline::Problem::DDataType;
using AccDataType = typename Pipeline::Problem::AccDataType;
using ODataType = typename Pipeline::Problem::ODataType;
using AScaleDataType = typename Pipeline::Problem::AScaleDataType;
using GScaleDataType = typename Pipeline::Problem::GScaleDataType;
using DScaleDataType = typename Pipeline::Problem::DScaleDataType;
using YSmoothScaleDataType = typename Pipeline::Problem::YSmoothScaleDataType;
using TopkWeightDataType = typename Pipeline::Problem::TopkWeightDataType;
using IndexDataType = typename Pipeline::Problem::IndexDataType;
using YDataType = typename Pipeline::Problem::YDataType;
using Traits = typename Pipeline::Problem::Traits;
static constexpr bool IsGateOnly = Traits::IsGateOnly;
static constexpr bool UseSmoothQuant = Traits::UseSmoothQuant;
static constexpr bool PadHiddenSize = Traits::PadHiddenSize;
static constexpr bool PadIntermediateSize = Traits::PadIntermediateSize;
// clang-format off
template <typename T> struct t2s;
template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
template <> struct t2s<fp16_t> { static constexpr const char * name = "fp16"; };
template <> struct t2s<bf16_t> { static constexpr const char * name = "bf16"; };
template <> struct t2s<fp8_t> { static constexpr const char * name = "fp8"; };
template <> struct t2s<bf8_t> { static constexpr const char * name = "bf8"; };
template <> struct t2s<int8_t> { static constexpr const char * name = "int8"; };
// clang-format on
CK_TILE_HOST static std::string GetName()
{
// sync with generate.py
// clang-format off
// clang-format on
}
struct FusedMoeGemmKargs
{
const void* a_ptr; // [m, k], input token
const void* a_scale_ptr; // [m, 1], token scale
const void* g_ptr; // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w])
const void* d_ptr; // [e, n, k], pre-shuffle([e, nr, kr, w])
const void* g_scale_ptr; // [e, 1, n], gate(up) scale
const void* d_scale_ptr; // [e, 1, k], down scale
const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input
void* o_ptr; // [m, k], output token
const void* sorted_token_ids_ptr;
const void* sorted_weight_ptr;
const void* sorted_expert_ids_ptr;
const void* num_sorted_tiles_ptr;
index_t hidden_size; // k
index_t intermediate_size; // n (TP slice this)
index_t num_tokens; // input number of tokens for current iteration
index_t num_experts; // number of groups
// index_t top_k; // need this?
index_t stride_token; // for input/output, stride for each row, should >= hidden_size
};
// TODO: switch karg based on
using Kargs = FusedMoeGemmKargs;
using Hargs = FusedMoeGemmHostArgs;
CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
{
// TODO: hargs/kargs not guranteed to be the same
return bit_cast<Kargs>(hargs);
}
CK_TILE_HOST static constexpr auto GridSize(index_t num_cu, index_t blocks_per_cu)
{
return TilePartitioner::GridSize(num_cu, blocks_per_cu);
}
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return max(Pipeline::GetSmemSize(), Epilogue::GetSmemSize());
}
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
// allocate LDS
// __shared__ char smem_ptr[GetSmemSize()];
IndexDataType num_sorted_tiles = __builtin_amdgcn_readfirstlane(
*reinterpret_cast<const IndexDataType*>(kargs.num_sorted_tiles_ptr));
constexpr index_t hidden_radio_0 = IsGateOnly ? 1 : 2;
index_t nr_0 = kargs.intermediate_size / Pipeline::Block_Nr0;
index_t kr_0 = kargs.hidden_size / Pipeline::Block_Kr0;
index_t nr_1 = kargs.hidden_size / Pipeline::Block_Nr1; // should be same as kr_0
index_t kr_1 = kargs.intermediate_size / Pipeline::Block_Kr1; // should be same as nr_0
index_t expert_stride_0 = kargs.intermediate_size * hidden_radio_0 * kargs.hidden_size;
index_t expert_stride_1 = kargs.intermediate_size * kargs.hidden_size;
__shared__ CK_TILE_LDS_ADDR ADataType smem[GetSmemSize()];
// note this is in unit of tile, need multiple tile size to get the index
const auto [sorted_tile_id, hidden_tile_id] =
TilePartitioner{}(num_sorted_tiles, kargs.intermediate_size);
if(sorted_tile_id >= num_sorted_tiles)
return;
const IndexDataType expert_id = __builtin_amdgcn_readfirstlane(
reinterpret_cast<const IndexDataType*>(kargs.sorted_expert_ids_ptr)[sorted_tile_id]);
// index along intermediate_size
index_t hidden_idx = __builtin_amdgcn_readfirstlane(hidden_tile_id * BlockShape::Block_N0);
index_t hidden_idx_nr =
__builtin_amdgcn_readfirstlane(hidden_tile_id * BlockShape::Block_Nr0);
const auto a_coord = Pipeline::GetACoord(); // 2d thread offset, [i_row, i_col]
const auto sorted_token_id = a_coord[number<0>{}] + sorted_tile_id * BlockShape::Block_M0;
index_t token_id =
reinterpret_cast<const index_t*>(kargs.sorted_token_ids_ptr)[sorted_token_id];
auto topk_weight =
reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr)[sorted_token_id];
const auto a_window = [&]() {
// A is already pre-padded in previous kernel
const ADataType* a_ptr = reinterpret_cast<const ADataType*>(kargs.a_ptr);
const auto a_view_ = make_naive_tensor_view<address_space_enum::global>(
a_ptr,
make_tuple(kargs.num_tokens, kargs.hidden_size),
make_tuple(kargs.stride_token, 1),
number<Pipeline::kAlignmentA>{},
number<1>{});
// gather is here use indexing transform
const auto a_gather_view_ = transform_tensor_view(
a_view_,
make_tuple(make_indexing_transform(kargs.num_tokens, token_id),
make_pass_through_transform(kargs.hidden_size)),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
const auto a_window_ = make_tile_window(
a_gather_view_,
make_tuple(number<BlockShape::Block_M0>{}, number<Pipeline::Block_K0>{}),
{0, 0});
return a_window_;
}();
// TODO: gtile using NSub to have less register pressure
const auto g_window = [&]() {
const GDataType* g_ptr = reinterpret_cast<const GDataType*>(kargs.g_ptr) +
static_cast<long_index_t>(expert_id) * expert_stride_0 +
hidden_idx_nr * kr_0 * BlockShape::Block_W0;
const auto g_view_ = make_naive_tensor_view<address_space_enum::global>(
g_ptr,
make_tuple(nr_0, kr_0, number<Pipeline::Block_W0>{}),
make_tuple(kr_0 * BlockShape::Block_W0, number<Pipeline::Block_W0>{}, 1),
number<Pipeline::kAlignmentG>{},
number<1>{});
const auto g_view_1_ = pad_tensor_view(g_view_,
make_tuple(number<Pipeline::Block_Nr0>{},
number<Pipeline::Block_Kr0>{},
number<Pipeline::Block_W0>{}),
sequence<PadIntermediateSize, PadHiddenSize, 0>{});
const auto g_window_ = make_tile_window(g_view_1_,
make_tuple(number<BlockShape::Block_Nr0>{},
number<Pipeline::Block_Kr0>{},
number<Pipeline::Block_W0>{}),
{0, 0, 0});
return g_window_;
}();
const auto d_window = [&]() {
const DDataType* d_ptr = [&]() {
reinterpret_cast<const DDataType*>(kargs.d_ptr) +
static_cast<long_index_t>(expert_id) * expert_stride_1 +
hidden_idx_nr* BlockShape::Block_W1;
// note hidden_idx_nr is along the gemm-k dim of 2nd gemm
}();
const auto d_view_ = make_naive_tensor_view<address_space_enum::global>(
d_ptr,
make_tuple(nr_1, kr_1, Pipeline::Block_W1),
make_tuple(kr_1 * Pipeline::Block_W1, Pipeline::Block_W1, 1),
number<Pipeline::kAlignmentD>{},
number<1>{});
const auto d_view_1_ = pad_tensor_view(d_view_,
make_tuple(number<Pipeline::kBlockNr_1>{},
number<Pipeline::kBlockKr_1>{},
number<Pipeline::Block_W1>{}),
sequence<PadHiddenSize, PadIntermediateSize, 0>{});
const auto d_window_ = make_tile_window(d_view_1_,
make_tuple(number<Pipeline::kBlockNr_1>{},
number<Pipeline::kBlockKr_1>{},
number<Pipeline::Block_W1>{}),
{0, 0, 0});
return d_window_;
}();
auto o_window = [&]() {
const ODataType* o_ptr = reinterpret_cast<const ODataType*>(kargs.o_ptr);
const auto o_view_ = make_naive_tensor_view<address_space_enum::global,
memory_operation_enum::atomic_add>(
o_ptr,
make_tuple(kargs.num_tokens, kargs.hidden_size),
make_tuple(kargs.stride_token, 1),
number<Pipeline::kAlignmentO>{},
number<1>{});
// gather is here
const auto o_scatter_view_ = transform_tensor_view(
o_view_,
make_tuple(make_indexing_transform(kargs.num_tokens, token_id),
make_pass_through_transform(kargs.hidden_size)),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
const auto o_window_ = make_tile_window(
o_scatter_view_,
make_tuple(number<BlockShape::Block_M1>{}, number<Pipeline::Block_N1>{}),
{0, 0});
return o_window_;
}();
// do compute yeah
Pipeline{}(a_window,
g_window,
d_window,
o_window,
topk_weight,
smem,
kargs.hidden_size,
kargs.intermediate_size);
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
/*
tensors:
1. act (A): input feature map
2. gate (G): B matrix for first gemm, output will do activation(Silu)
3. up (U): B matrix for first gemm
4. down (D): B matrix for second gemm
N1
/ \
+----------+ |
| Down | |
x----------x |
hidden hidden K1 | | |
N0 N0 x----------x |
| +------x-----x------+------x-----x------+ | | |
dim | | Gate | | | Up | | | | | |
contiguous | | | | | | | | | | |
| | | | | | | | | | |
v +------x-----x------+------x-----x------+ +----------+ V
K0 | | | | | contiguous
/ \ v v v v |
+---------+ +------x-----x------+------x-----x------+ |
M0 | A | | | | | | | | |
+---------+ +------x-----x------+------x-----x------+ |
----------> | | |
contiguous | V V
| x-----x +----------+
+------------> M1 | Y | ---------> | Out(O) |
ACT x-----x +----------+
K1 = N0 dim
* Note: Act could be Gelu/Silu/...
* Note: some model does not have Up
*/
template <typename BlockTile_0_,
typename WarpPerBlock_0_,
typename WarpTile_0_,
typename BlockTile_1_,
typename WarpPerBlock_1_,
typename WarpTile_1_>
struct FusedMoeGemmShape
{
using BlockTile_0 = remove_cvref_t<BlockTile_0_>;
using WarpPerBlock_0 = remove_cvref_t<WarpPerBlock_0_>;
using WarpTile_0 = remove_cvref_t<WarpTile_0_>;
using BlockTile_1 = remove_cvref_t<BlockTile_1_>;
using WarpPerBlock_1 = remove_cvref_t<WarpPerBlock_1_>;
using WarpTile_1 = remove_cvref_t<WarpTile_1_>;
static constexpr index_t NumWarps =
reduce_on_sequence(WarpPerBlock_0{}, multiplies{}, number<1>{});
static_assert(NumWarps == reduce_on_sequence(WarpPerBlock_1{}, multiplies{}, number<1>{}));
static constexpr index_t Block_M0 = BlockTile_0::at(number<0>{});
static constexpr index_t Block_N0 = BlockTile_0::at(number<1>{});
static constexpr index_t Block_K0 = BlockTile_0::at(number<2>{});
static constexpr index_t WarpPerBlock_M0 = WarpPerBlock_0::at(numner<0>{});
static constexpr index_t WarpPerBlock_N0 = WarpPerBlock_0::at(numner<1>{});
static constexpr index_t WarpPerBlock_K0 = WarpPerBlock_0::at(numner<2>{});
static constexpr index_t Warp_M0 = WarpTile_0::at(number<0>{});
static constexpr index_t Warp_N0 = WarpTile_0::at(number<1>{});
static constexpr index_t Warp_K0 = WarpTile_0::at(number<2>{});
static constexpr index_t ThreadPerBlock_M0 = Warp_M0 * WarpPerBlock_M0;
static constexpr index_t ThreadPerBlock_N0 = Warp_N0 * WarpPerBlock_N0;
static constexpr index_t ThreadPerBlock_K0 = Warp_K0 * WarpPerBlock_K0;
static_assert(Block_M0 % ThreadPerBlock_M0 == 0);
static_assert(Block_N0 % ThreadPerBlock_N0 == 0);
static_assert(Block_K0 % ThreadPerBlock_K0 == 0);
static constexpr index_t Repeat_M0 = Block_M0 / ThreadPerBlock_M0;
static constexpr index_t Repeat_N0 = Block_N0 / ThreadPerBlock_N0;
static constexpr index_t Repeat_K0 = Block_K0 / ThreadPerBlock_K0;
static constexpr index_t Block_M1 = BlockTile_1::at(number<0>{});
static constexpr index_t Block_N1 = BlockTile_1::at(number<1>{});
static constexpr index_t Block_K1 = BlockTile_1::at(number<2>{});
static constexpr index_t WarpPerBlock_M1 = WarpTile_1::at(numner<0>{});
static constexpr index_t WarpPerBlock_N1 = WarpTile_1::at(numner<1>{});
static constexpr index_t WarpPerBlock_K1 = WarpTile_1::at(numner<2>{});
static constexpr index_t Warp_M1 = WarpPerBlock_1::at(number<0>{});
static constexpr index_t Warp_N1 = WarpPerBlock_1::at(number<1>{});
static constexpr index_t Warp_K1 = WarpPerBlock_1::at(number<2>{});
static constexpr index_t ThreadPerBlock_M1 = Warp_M1 * WarpPerBlock_M1;
static constexpr index_t ThreadPerBlock_N1 = Warp_N1 * WarpPerBlock_N1;
static constexpr index_t ThreadPerBlock_K1 = Warp_K1 * WarpPerBlock_K1;
static_assert(Block_M1 % ThreadPerBlock_M1 == 0);
static_assert(Block_N1 % ThreadPerBlock_N1 == 0);
static_assert(Block_K1 % ThreadPerBlock_K1 == 0);
static constexpr index_t Repeat_M1 = Block_M1 / ThreadPerBlock_M1;
static constexpr index_t Repeat_N1 = Block_N1 / ThreadPerBlock_N1;
static constexpr index_t Repeat_K1 = Block_K1 / ThreadPerBlock_K1;
static constexpr index_t BlockSize = warpSize * NumWarps;
// some assert
static_assert(Block_M0 == Block_M1);
static_assert(Block_N0 == Block_K1 || (Block_N0 / 2) == Block_K1); // Gate Only or Gate+Up
// pre-shuffle tile size compute (assume only for B matrix)
// we flatten the each wave tile to a 1d linear tensor(at model loading time)
// e.g. originally we have Block_N*Block_K tile size, after pre-shuffle
// we can have Block_Nr*Block_Kr*Block_W, where Block_W is Warp_N*Warp_K,
// and Block_Nr=Block_N/Warp_N, Block_Kr=Block_K/Warp_K
static constexpr index_t Block_W0 = Warp_N0 * Warp_K0;
static constexpr index_t Block_Nr0 = Block_N0 / Warp_N0;
static constexpr index_t Block_Kr0 = Block_K0 / Warp_K0;
static constexpr index_t Block_W1 = Warp_N1 * Warp_K1;
static constexpr index_t Block_Nr1 = Block_N1 / Warp_N1;
static constexpr index_t Block_Kr1 = Block_K1 / Warp_K1;
static_assert(Block_W0 == Block_W1);
static_assert(Block_Nr0 == Block_Kr1);
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck_tile {
template <typename BlockShape_>
struct FusedMoeGemmTilePartitioner_Linear
{
// FusedMoeGemmShape
using BlockShape = ck_tile::remove_cvref_t<BlockShape_>;
static constexpr const char* name = "eh"; // expert x hidden
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*num_sorted_tiles*/,
ck_tile::index_t /*hidden_size*/))
{
index_t i_n = blockIdx.x;
index_t i_m = blockIdx.y;
return ck_tile::make_tuple(i_m, i_n);
}
CK_TILE_HOST static constexpr auto GridSize(index_t max_tokens, index_t hidden_size)
{
// TODO: this may need tuning
index_t ms = ck_tile::integer_divide_ceil(max_tokens, BlockShape::Block_M0);
index_t ns = ck_tile::integer_divide_ceil(hidden_size, BlockShape::Block_N0);
return dim3(ns, ms, 1);
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
// TODO: alow 2 gemm have different type
template <typename ADataType_,
typename GDataType_,
typename DDataType_,
typename AccDataType_,
typename ODataType_,
typename AScaleDataType_,
typename W0ScaleDataType_,
typename W1ScaleDataType_,
typename YSmoothScaleDataType_,
typename TopkWeightDataType_,
typename IndexDataType_, // data type for all indexing
typename GateActivation_, // = ck_tile::element_wise::Silu,
typename BlockShape_, // shoule be FusedMoeGemmShape
typename Traits_>
struct FusedMoeGemmPipelineProblem
{
using ADataType = remove_cvref_t<ADataType_>;
using GDataType = remove_cvref_t<GDataType_>;
using DDataType = remove_cvref_t<DDataType_>;
using AccDataType = remove_cvref_t<AccDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
using AScaleDataType = remove_cvref_t<AScaleDataType_>;
using GScaleDataType = remove_cvref_t<GScaleDataType_>;
using DScaleDataType = remove_cvref_t<DScaleDataType_>;
using YSmoothScaleDataType = remove_cvref_t<YSmoothScaleDataType_>;
using TopkWeightDataType = remove_cvref_t<TopkWeightDataType_>;
using IndexDataType = remove_cvref_t<IndexDataType_>;
// the input for next gemm should have same time as
using YDataType = ADataType;
using GateActivation = remove_cvref_t<GateActivation_>;
using BlockShape = remove_cvref_t<BlockShape_>;
using Traits = remove_cvref_t<Traits_>;
};
} // namespace ck_tile
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment