Commit b885995c authored by letaoqin

First working version

parent 40df5c8b
......@@ -87,32 +87,42 @@ template <typename IndexType>
void output_matrix_2d(ck_tile::HostTensor<IndexType>& data, int m, int n)
{
std::cout << std::endl;
std::cout << "[";
for(int i = 0; i < m; i++)
{
std::cout << "Line " << i << "\t";
std::cout << "[";
for(int j = 0; j < n; j++)
{
std::cout << ck_tile::type_convert<float>(data(i, j)) << "\t";
std::cout << ck_tile::type_convert<float>(data(i, j));
if(j != n - 1)
std::cout << ", ";
}
std::cout << std::endl;
std::cout << "],\n";
}
std::cout << "]\n";
}
template <typename IndexType>
void output_matrix_3d(ck_tile::HostTensor<IndexType>& data, int M, int N, int J)
{
std::cout << std::endl;
std::cout << "[";
for(int m = 0; m < M; m++)
{
std::cout << "[";
for(int n = 0; n < N; n++)
{
std::cout << "experts: " << m << " Line: " << n << "\t";
std::cout << "[";
for(int j = 0; j < J; j++)
{
std::cout << ck_tile::type_convert<float>(data(m, n, j)) << "\t";
std::cout << ck_tile::type_convert<float>(data(m, n, j));
if(j != J - 1)
std::cout << ", ";
}
std::cout << std::endl;
std::cout << "],\n";
}
std::cout << "],\n";
}
std::cout << "]\n";
}
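// Both helpers print the tensor as a nested, bracketed, comma-separated list, presumably so
// that host and device dumps can be compared (or pasted into an interpreter) more easily
// while debugging the verification path below.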
auto create_args(int argc, char* argv[])
{
......@@ -237,6 +247,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
// host verify
ck_tile::HostTensor<ADataType> a_host({tokens, hidden_size}, {stride, 1});
ck_tile::HostTensor<GDataType> g_host({experts, shared_intermediate_size_0, hidden_size});
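// c_host mirrors the new device-side c buffer: a tokens x intermediate_size scratch tensor
// that the kernel (currently only workgroup 0) fills with the GEMM0 output so it can be
// inspected on the host.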
ck_tile::HostTensor<ODataType> c_host({tokens, intermediate_size});
ck_tile::HostTensor<DDataType> d_host({experts, hidden_size, shared_intermediate_size_1});
ck_tile::HostTensor<ODataType> o_host({tokens, hidden_size}, {stride, 1});
ck_tile::HostTensor<AScaleDataType> sa_host({tokens});
......@@ -269,6 +280,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f, seed, true}(a_host);
ck_tile::FillUniformDistribution<GDataType>{-.5f, .5f, seed, true}(g_host);
ck_tile::FillUniformDistribution<DDataType>{-.5f, .5f, seed, true}(d_host);
// ck_tile::FillConstant<ADataType>{1}(a_host);
// ck_tile::FillConstant<GDataType>{1}(g_host);
// ck_tile::FillConstant<DDataType>{1}(d_host);
ck_tile::FillUniformDistribution<AScaleDataType>{-.5f, .5f, seed, true}(sa_host);
ck_tile::FillUniformDistribution<GScaleDataType>{-.5f, .5f, seed, true}(sg_host);
ck_tile::FillUniformDistribution<DScaleDataType>{-.5f, .5f, seed, true}(sd_host);
......@@ -389,6 +403,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
ck_tile::DeviceMem sd_buf(sd_host);
ck_tile::DeviceMem sy_buf(sy_host);
ck_tile::DeviceMem o_buf(o_host);
ck_tile::DeviceMem c_buf(c_host);
c_buf.SetZero();
std::cout << "\nc size: " << c_buf.GetBufferSize()
<< " tokens * intermediate_size: " << tokens * intermediate_size << std::endl;
// manually clear output buffer for atomic
o_buf.SetZero();
......@@ -428,7 +446,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
experts,
topk,
stride,
max_num_tokens_padded};
max_num_tokens_padded,
c_buf.GetDeviceBuffer()};
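// c_buf.GetDeviceBuffer() feeds the new c_ptr field of the host args, handing the kernel a
// global scratch buffer to dump the GEMM0 intermediate into.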
float ave_time = fused_moegemm(
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
......@@ -469,6 +488,22 @@ bool run(const ck_tile::ArgParser& arg_parser)
gate_only);
auto o_dev = o_buf.ToHost<ODataType>();
auto c_dev = c_buf.ToHost<ADataType>();
std::cout << std::endl;
std::cout << o_dev << std::endl;
// std::cout << c_dev << std::endl;
// int count = 0;
// std::cout << "[";
// for(int i = 0; i < tokens; i++)
// {
// std::cout << "[";
// for(int j = 0; j < intermediate_size; j++)
// {
// std::cout << ck_tile::type_convert<float>(c_dev(count++)) << ",";
// }
// std::cout << "],\n";
// }
// std::cout << "]\n";
// o_dev.savetxt("gpu-out.txt", "float");
auto [rtol, atol] = get_elimit<ADataType>();
pass &= ck_tile::check_err(
......
......@@ -340,7 +340,7 @@ template <typename T>
struct FillConstant
{
T value_{0};
FillConstant(float value) : value_(ck_tile::type_convert<T>(value)) {}
template <typename ForwardIter>
void operator()(ForwardIter first, ForwardIter last) const
{
......
......@@ -586,7 +586,7 @@ struct HostTensor
}
if constexpr(std::is_same_v<T, bf16_t> || std::is_same_v<T, fp16_t>)
{
os << type_convert<float>(t.mData[idx]) << " #### ";
os << type_convert<float>(t.mData[idx]) << " ";
}
else
{
......
......@@ -137,6 +137,7 @@ void reference_fused_moe(
", 1:" + std::to_string(intermediate_size_1));
for(ck_tile::index_t i_n = 0; i_n < intermediate_size_1; i_n++)
{
//y(0, i_n) = acc_0(0, i_n);
Activation{}(y(0, i_n), acc_0(0, i_n));
// if(i_expert == 0)
// printf("in:%d, %f\t", i_n, y(0, i_n));
......
......@@ -199,6 +199,7 @@ struct FusedMoeGemmGlKernel
index_t stride_token; // for input/output, stride for each row, should be >= hidden_size
index_t max_num_tokens_padded; // size of sorted_token_ids_ptr
void* c_ptr;
};
// TODO: switch karg based on
......@@ -255,8 +256,6 @@ struct FusedMoeGemmGlKernel
const auto sorted_token_id = a_coord[number<0>{}] + idx_m0; // start block_m
// position
// index_t token_id =
// reinterpret_cast<const index_t*>(kargs.sorted_token_ids_ptr)[sorted_token_id];
auto topk_weight =
reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr)[sorted_token_id];
......@@ -305,19 +304,26 @@ struct FusedMoeGemmGlKernel
make_tuple(number<BlockShape::Block_N0>{}, number<BlockShape::Block_K0>{}),
{idx_n0, 0});
// if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
// {
// for(int i = 0; i < 16; i++)
// {
// printf("in G index is %d , value is: %f\n",
// i,
// ck_tile::type_convert<float>(g_ptr[i]));
// }
// }
return g_window_;
}();
auto c_window = [&]() {
YDataType* c_ptr = reinterpret_cast<YDataType*>(kargs.c_ptr);
// debug window over the global c buffer (num_tokens x intermediate_size), written by GEMM0
auto c_view_ = make_naive_tensor_view<address_space_enum::global>(
c_ptr,
make_tuple(kargs.num_tokens, kargs.intermediate_size),
make_tuple(kargs.intermediate_size, 1),
number<Pipeline::kAlignmentD>{},
number<1>{});
auto c_window_ = make_tile_window(
c_view_,
make_tuple(number<BlockShape::Block_M0>{}, number<BlockShape::Block_N0>{}),
{0, 0});
return c_window_;
}();
const auto d_window = [&]() {
const DDataType* d_ptr = reinterpret_cast<const DDataType*>(kargs.d_ptr) +
static_cast<long_index_t>(expert_id) * expert_stride_1;
......@@ -371,7 +377,8 @@ struct FusedMoeGemmGlKernel
topk_weight,
smem,
kargs.hidden_size,
kargs.intermediate_size);
kargs.intermediate_size,
c_window);
}
};
......
......@@ -118,6 +118,7 @@ struct FusedMoeGemmHostArgs
index_t stride_token; // for input/output, stride for each row, should be >= hidden_size
index_t max_num_tokens_padded; // size of sorted_token_ids_ptr
void* c_ptr;
};
// This is scatter/gather b2b group-gemm
......
......@@ -68,16 +68,24 @@ struct FusedMoeGemmPipeline_General
static constexpr const char* name = "flatmm_gl";
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeA()
{
// matrix a or tokens smem
constexpr index_t smem_mat_a =
BlockShape::Block_M0 * BlockShape::Block_K0 * sizeof(ADataType);
return smem_mat_a;
}
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
// matrix a or tokens smem
constexpr index_t smem_mat_a = GetSmemSizeA();
constexpr index_t smem_mat_d =
BlockShape::Block_N0 * BlockShape::Block_K0 * sizeof(GDataType);
// shuffle C matrix
constexpr index_t smem_bridge =
BlockShape::Block_M0 * BlockShape::Block_N0 * sizeof(YDataType);
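// With GEMM0 now reading both operands from LDS, the A tile and the G tile (smem_mat_d)
// are resident at the same time, so their footprints add; the bridge/shuffle buffer is only
// needed after GEMM0 and can reuse the same LDS region, hence the max().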
return max(smem_mat_a, smem_bridge);
return max(smem_mat_a + smem_mat_d, smem_bridge);
// return Policy::template GetSmemSize<Problem>();
}
......@@ -117,7 +125,11 @@ struct FusedMoeGemmPipeline_General
});
});
}
template <typename AWindow, typename GWindow, typename DWindow, typename OWindow>
template <typename AWindow,
typename GWindow,
typename DWindow,
typename OWindow,
typename CWindow>
CK_TILE_DEVICE auto operator()(const AWindow& a_window_,
const GWindow& g_window_,
const DWindow& d_window_,
......@@ -125,9 +137,16 @@ struct FusedMoeGemmPipeline_General
TopkWeightDataType topk_weight,
CK_TILE_LDS_ADDR void* smem,
index_t hidden_size,
index_t intermediate_size)
index_t /*intermediate_size*/,
CWindow& c_window_)
{
ignore = topk_weight;
ignore = c_window_;
ignore = hidden_size;
CK_TILE_LDS_ADDR ADataType* smem_0 = reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem);
CK_TILE_LDS_ADDR GDataType* smem_1 = reinterpret_cast<CK_TILE_LDS_ADDR GDataType*>(
smem_0 + GetSmemSizeA() / sizeof(ADataType));
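// LDS layout for GEMM0: [ A tile | G tile ]; smem_1 begins right after the A tile,
// offset by GetSmemSizeA() bytes, so both operands can be staged simultaneously.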
auto a_lds_view = make_tensor_view<address_space_enum::lds>(
smem_0, Policy::template MakeLdsBlockDesc_A<Problem>());
auto a_lds_win = make_tile_window(
......@@ -135,6 +154,13 @@ struct FusedMoeGemmPipeline_General
make_tuple(number<BlockShape::Block_M0>{}, number<BlockShape::Block_K0>{}),
{0, 0});
auto g_lds_view = make_tensor_view<address_space_enum::lds>(
smem_1, Policy::template MakeLdsBlockDesc_G<Problem>());
auto g_lds_win = make_tile_window(
g_lds_view,
make_tuple(number<BlockShape::Block_N0>{}, number<BlockShape::Block_K0>{}),
{0, 0});
auto a_global_to_dram_window = make_tile_window(
a_window_.get_bottom_tensor_view(),
make_tuple(number<BlockShape::Block_M0>{}, number<BlockShape::Block_K0>{}),
......@@ -148,69 +174,85 @@ struct FusedMoeGemmPipeline_General
g_window_.get_window_origin(),
Policy::template MakeGlobalTileDistribution_G<Problem>());
// gemm gate
#if 0
PrintMem(g_dram_block, "G", 0);
#endif
// gemm0(gate)
constexpr auto gemm_0 = Policy::template GetBlockGemm0<Problem>();
using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
auto s_acc = SaccBlockTileType{};
// save tokens to lds
auto a_dram_block = load_tile(a_global_to_dram_window);
store_tile(a_lds_win, a_dram_block);
#if 0
PrintMem(a_dram_block,"A", 0, 1);
#endif
auto g_dram_block = load_tile(g_global_to_dram_window);
// block_sync_load_raw();
// save tokens to lds
store_tile(a_lds_win, a_dram_block);
store_tile(g_lds_win, g_dram_block);
#if 1
PrintMem(g_dram_block, "G", 0, 1);
#if 0
PrintMem(a_dram_block,"A", 0);
#endif
clear_tile(s_acc); // initialize C
constexpr index_t kK0 = BlockShape::Block_K0;
const index_t k0_loops = ck_tile::integer_divide_ceil(hidden_size, kK0);
index_t iCounter0 = k0_loops - 1;
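// Software-pipelined K loop: except on the last pass, the next K0 slice of A and G is
// fetched from global before the block GEMM runs on the tiles already resident in LDS,
// and is committed to LDS afterwards for the following iteration; the final pass
// (iCounter0 == 0) only computes.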
while(iCounter0 > 0)
while(iCounter0 >= 0)
{
if(iCounter0 > 0)
{
block_sync_lds();
gemm_0(s_acc, a_lds_win, g_dram_block);
block_sync_lds();
move_tile_window(a_global_to_dram_window, {0, kK0});
move_tile_window(g_global_to_dram_window, {0, kK0});
a_dram_block = load_tile(a_global_to_dram_window);
g_dram_block = load_tile(g_global_to_dram_window);
}
block_sync_lds();
gemm_0(s_acc, a_lds_win, g_lds_win);
// gemm_0(s_acc, a_lds_win, g_dram_block);
block_sync_lds();
if(iCounter0 > 0)
{
store_tile(a_lds_win, a_dram_block);
store_tile(g_lds_win, g_dram_block);
}
iCounter0--;
}
// tail
{
block_sync_lds();
gemm_0(s_acc, a_lds_win, g_dram_block);
}
// {
// block_sync_lds();
// // gemm_0(s_acc, a_lds_win, g_dram_block);
// gemm_0(s_acc, a_lds_win, g_lds_win);
// block_sync_lds();
// }
#if 0
PrintMem(s_acc);
PrintMem(s_acc, "S", 0);
#endif
// activation
const auto activation = ck_tile::element_wise::Gelu{};
tile_elementwise_inout(activation, s_acc, s_acc);
// const auto activation = ck_tile::element_wise::Gelu{};
// tile_elementwise_inout(activation, s_acc, s_acc);
// cast data to YDataType
auto y_pre = cast_tile<YDataType>(s_acc);
// move sacc to LDS
#if 0
PrintMem(y_pre, "Y_pre", 0);
#endif
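// Debug path: workgroup (0,0,0) copies the raw GEMM0 block tile (the activation call above
// is commented out in this state of the code) into the global c buffer so the host side
// can print it via c_buf.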
if(blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
{
block_sync_lds();
store_tile(c_window_, y_pre);
}
// save to lds
auto bridge_lds_view = make_tensor_view<address_space_enum::lds>(
smem_0, Policy::template MakeBridgeLdsBlockDesc<Problem>());
auto bridge_slds_win =
make_tile_window(bridge_lds_view,
Policy::template MakeBridgeLdsBlockDesc<Problem>().get_lengths(),
{0, 0});
// cast data to YDataType
auto y_pre = cast_tile<YDataType>(s_acc);
#if 0
PrintMem(y_pre);
#endif
// save to lds
store_tile(bridge_slds_win, y_pre);
block_sync_lds();
......@@ -225,7 +267,20 @@ struct FusedMoeGemmPipeline_General
{0, 0},
Policy::template MakeYTileDistribution<Problem>());
auto y = load_tile(bridge_llds_win);
block_sync_lds();
#if 0
PrintMem(y,"Y",0);
//PrintMem(y,"Y",32);
if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
{
for(int i = 0; i < 16; i++)
{
printf("\n smem_0[%d]: %f ", i, type_convert<float>(smem_0[i]));
}
}
//store_tile(c_window_, y);
#endif
// d data
auto d_global_to_dram_window = make_tile_window(
d_window_.get_bottom_tensor_view(),
......@@ -234,20 +289,20 @@ struct FusedMoeGemmPipeline_General
Policy::template MakeGlobalTileDistribution_D<Problem>());
auto d = load_tile(d_global_to_dram_window);
#if 0
PrintMem(d,"D",64);
PrintMem(d,"D",0);
#endif
// add to LDS
auto o_alds_view =
auto o_lds_view =
make_naive_tensor_view<address_space_enum::lds, memory_operation_enum::atomic_add>(
smem_0,
make_tuple(number<32>{}, number<32>{}),
make_tuple(number<128>{}, number<32>{}),
make_tuple(32, 1),
number<8>{},
number<1>{});
auto o_alds_win =
make_tile_window(o_alds_view, make_tuple(number<32>{}, number<32>{}), {0, 0});
make_tile_window(o_lds_view, make_tuple(number<128>{}, number<32>{}), {0, 0});
auto o_olds_win =
make_tile_window(o_alds_view,
make_tile_window(o_lds_view,
make_tuple(number<32>{}, number<32>{}),
{0, 0},
Policy::template MakeGlobalTileDistribution_O<Problem>());
......@@ -278,31 +333,38 @@ struct FusedMoeGemmPipeline_General
gemm_1(o_acc, y, d);
// block_sync_lds();
tile_elementwise_inout(
[&topk_weight](auto& x) { x = x * type_convert<float>(topk_weight); }, o_acc);
// tile_elementwise_inout(
// [&topk_weight](auto& x) { x = x * type_convert<float>(topk_weight); }, o_acc);
auto o = cast_tile<ODataType>(o_acc);
#if 0
PrintMem(o, "O", 65);
#endif
store_tile(o_alds_win, o);
block_sync_lds();
// if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0)
// {
// for(int i = 0; i < 42; i++)
// {
// printf("\n%d value is %f\t", i, type_convert<float>(smem_0[i]));
// }
// }
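// The 128x32 LDS view above stacks four 32x32 partial output tiles; below, the first 64
// threads of workgroup (0,0,0) step the 32x32 window down in increments of 32 rows, sum the
// partials element-wise in registers, and write the accumulated tile to the output window
// via update_tile.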
if(blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
{
if(threadIdx.x < 64)
{
auto o_out = load_tile(o_olds_win);
block_sync_lds();
store_tile(o_window_, o_out);
auto o0 = load_tile(o_olds_win);
for(int step = 1; step < 4; step++)
{
move_tile_window(o_olds_win, {32, 0});
auto o1 = load_tile(o_olds_win);
for(int i = 0; i < 16; i++)
{
o0.get_thread_buffer()(i) = type_convert<ODataType>(
type_convert<float>(o0.get_thread_buffer()[i]) +
type_convert<float>(o1.get_thread_buffer()[i]));
}
}
update_tile(o_window_, o0);
}
}
// ignore = o_olds_win;
// store_tile(o_window_, o);
#if 0
PrintMem(o,"O");
#endif
}
// store_tile(o_window_, a_dram_block);
}
};
......
......@@ -10,6 +10,8 @@
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp"
......@@ -198,7 +200,7 @@ struct FusedMoeGemmPipelineGeneralPolicy
tuple<sequence<0, 1>, sequence<1, 2>>,
tuple<sequence<0, 0>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
sequence<1, 1>>{});
}
template <typename Problem>
......@@ -214,13 +216,17 @@ struct FusedMoeGemmPipelineGeneralPolicy
typename S_::WarpTile_0>>;
constexpr auto warp_gemm = GetWarpGemm0<Problem>();
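// GEMM0 switches to the A-in-SMEM / B-in-SMEM block GEMM (BlockGemmASmemBSmemCRegV1),
// matching the new LDS staging of the G tile in the pipeline; the B-in-register variant
// is kept commented out for reference.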
using BlockGemmPolicy = BlockGemmASmemBRegCRegV1CustomPolicy<typename Problem::ADataType,
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
// using BlockGemmPolicy =
// BlockGemmASmemBRegCRegV1CustomPolicy<typename
// Problem::ADataType,
typename Problem::GDataType,
typename Problem::AccDataType,
typename S_::WarpPerBlock_0,
decltype(warp_gemm)>;
return BlockGemmASmemBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
return BlockGemmASmemBSmemCRegV1<GemmProblem, BlockGemmPolicy>{};
// return BlockGemmASmemBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
}
template <typename Problem>
......@@ -288,28 +294,6 @@ struct FusedMoeGemmPipelineGeneralPolicy
return d_block_dstr;
}
// template <typename Problem>
// CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_O()
// {
// using S_ = remove_cvref_t<typename Problem::BlockShape>;
// using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;
// // using CDataType = typename WarpGemm::CDataType;
// constexpr auto c_block_outer_dstr_encoding =
// tile_distribution_encoding<sequence<>,
// tuple<sequence<S_::Repeat_M1, S_::WarpPerBlock_M1>,
// sequence<S_::Repeat_N1, S_::WarpPerBlock_N1>>,
// tuple<sequence<1, 2>>,
// tuple<sequence<1, 1>>,
// sequence<1, 2>,
// sequence<0, 0>>{};
// constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
// c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
// constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
// return c_block_dstr;
// }
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDesc_A()
{
......@@ -322,7 +306,7 @@ struct FusedMoeGemmPipelineGeneralPolicy
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kK0>{}, number<Block_M>{}, number<kK1>{}),
make_tuple(number<(Block_M + 1) * kK1>{}, number<kK1>{}, number<1>{}),
make_tuple(number<Block_M * kK1>{}, number<kK1>{}, number<1>{}),
number<8>{},
number<1>{});
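// The leading stride is now dense (Block_M * kK1, no extra padding row), presumably so the
// A tile occupies exactly the Block_M0 * Block_K0 elements budgeted by GetSmemSizeA().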
......@@ -333,9 +317,47 @@ struct FusedMoeGemmPipelineGeneralPolicy
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
// constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
// make_tuple(number<Block_M>{}, number<Block_K>{}),
// make_tuple(number<Block_K>{}, number<1>{}),
// number<8>{},
// number<1>{});
return a_lds_block_desc;
}
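// LDS block descriptor for the G (gate weight) tile, the counterpart of MakeLdsBlockDesc_A:
// a (kK0, Block_N, kK1) staging layout merged back into a (Block_N, Block_K) view.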
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDesc_G()
{
constexpr index_t Block_N = Problem::BlockShape::Block_N0;
constexpr index_t Block_K = Problem::BlockShape::Block_K0;
constexpr index_t kK1 = GetSmemKPack_A<Problem>(); // LDS
constexpr index_t kK0 = Block_K / kK1;
static_assert(Block_K % kK1 == 0);
constexpr auto d_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kK0>{}, number<Block_N>{}, number<kK1>{}),
make_tuple(number<Block_N * kK1>{}, number<kK1>{}, number<1>{}),
number<8>{},
number<1>{});
constexpr auto d_lds_block_desc = transform_tensor_descriptor(
d_lds_block_desc_0,
make_tuple(make_pass_through_transform(number<Block_N>{}),
make_merge_transform(make_tuple(number<kK0>{}, number<kK1>{}))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
// constexpr auto d_lds_block_desc = make_naive_tensor_descriptor(
// make_tuple(number<Block_N>{}, number<Block_K>{}),
// make_tuple(number<Block_K>{}, number<1>{}),
// number<8>{},
// number<1>{});
return d_lds_block_desc;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBridgeLdsBlockDesc()
{
......@@ -343,11 +365,10 @@ struct FusedMoeGemmPipelineGeneralPolicy
constexpr index_t Block_N = Problem::BlockShape::Block_N0;
constexpr index_t KVector = GetSmemKPack_Y<Problem>();
constexpr index_t KPad = 0;
constexpr auto desc =
make_naive_tensor_descriptor(make_tuple(number<Block_M>{}, number<Block_N>{}),
make_tuple(number<Block_N + KPad>{}, number<1>{}),
make_tuple(number<Block_N>{}, number<1>{}),
number<KVector>{},
number<1>{});
return desc;
......