Commit b885995c authored by letaoqin

first right version

parent 40df5c8b
@@ -87,32 +87,42 @@ template <typename IndexType>
 void output_matrix_2d(ck_tile::HostTensor<IndexType>& data, int m, int n)
 {
     std::cout << std::endl;
+    std::cout << "[";
     for(int i = 0; i < m; i++)
     {
-        std::cout << "Line " << i << "\t";
+        std::cout << "[";
         for(int j = 0; j < n; j++)
         {
-            std::cout << ck_tile::type_convert<float>(data(i, j)) << "\t";
+            std::cout << ck_tile::type_convert<float>(data(i, j));
+            if(j != n - 1)
+                std::cout << ", ";
         }
-        std::cout << std::endl;
+        std::cout << "],\n";
     }
+    std::cout << "]\n";
 }
 template <typename IndexType>
 void output_matrix_3d(ck_tile::HostTensor<IndexType>& data, int M, int N, int J)
 {
     std::cout << std::endl;
+    std::cout << "[";
     for(int m = 0; m < M; m++)
     {
+        std::cout << "[";
         for(int n = 0; n < N; n++)
        {
-            std::cout << "experts: " << m << " Line: " << n << "\t";
+            std::cout << "[";
             for(int j = 0; j < J; j++)
             {
-                std::cout << ck_tile::type_convert<float>(data(m, n, j)) << "\t";
+                std::cout << ck_tile::type_convert<float>(data(m, n, j));
+                if(j != J - 1)
+                    std::cout << ", ";
             }
-            std::cout << std::endl;
+            std::cout << "],\n";
         }
+        std::cout << "],\n";
     }
+    std::cout << "]\n";
 }
 auto create_args(int argc, char* argv[])
 {
@@ -237,6 +247,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
     // host verify
     ck_tile::HostTensor<ADataType> a_host({tokens, hidden_size}, {stride, 1});
     ck_tile::HostTensor<GDataType> g_host({experts, shared_intermediate_size_0, hidden_size});
+    ck_tile::HostTensor<ODataType> c_host({tokens, intermediate_size});
     ck_tile::HostTensor<DDataType> d_host({experts, hidden_size, shared_intermediate_size_1});
     ck_tile::HostTensor<ODataType> o_host({tokens, hidden_size}, {stride, 1});
     ck_tile::HostTensor<AScaleDataType> sa_host({tokens});
@@ -269,6 +280,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
     ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f, seed, true}(a_host);
     ck_tile::FillUniformDistribution<GDataType>{-.5f, .5f, seed, true}(g_host);
     ck_tile::FillUniformDistribution<DDataType>{-.5f, .5f, seed, true}(d_host);
+    // ck_tile::FillConstant<ADataType>{1}(a_host);
+    // ck_tile::FillConstant<GDataType>{1}(g_host);
+    // ck_tile::FillConstant<DDataType>{1}(d_host);
     ck_tile::FillUniformDistribution<AScaleDataType>{-.5f, .5f, seed, true}(sa_host);
     ck_tile::FillUniformDistribution<GScaleDataType>{-.5f, .5f, seed, true}(sg_host);
     ck_tile::FillUniformDistribution<DScaleDataType>{-.5f, .5f, seed, true}(sd_host);
@@ -389,6 +403,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
     ck_tile::DeviceMem sd_buf(sd_host);
     ck_tile::DeviceMem sy_buf(sy_host);
     ck_tile::DeviceMem o_buf(o_host);
+    ck_tile::DeviceMem c_buf(c_host);
+    c_buf.SetZero();
+    std::cout << "\nc size: " << c_buf.GetBufferSize()
+              << " tokens * intermediate_size: " << tokens * intermediate_size << std::endl;

     // manually clear output buffer for atomic
     o_buf.SetZero();
@@ -428,7 +446,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
         experts,
         topk,
         stride,
-        max_num_tokens_padded};
+        max_num_tokens_padded,
+        c_buf.GetDeviceBuffer()};

     float ave_time = fused_moegemm(
         traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
@@ -469,6 +488,22 @@ bool run(const ck_tile::ArgParser& arg_parser)
         gate_only);

     auto o_dev = o_buf.ToHost<ODataType>();
+    auto c_dev = c_buf.ToHost<ADataType>();
+    std::cout << std::endl;
+    std::cout << o_dev << std::endl;
+    // std::cout << c_dev << std::endl;
+    // int count = 0;
+    // std::cout << "[";
+    // for(int i = 0; i < tokens; i++)
+    // {
+    //     std::cout << "[";
+    //     for(int j = 0; j < intermediate_size; j++)
+    //     {
+    //         std::cout << ck_tile::type_convert<float>(c_dev(count++)) << ",";
+    //     }
+    //     std::cout << "],\n";
+    // }
+    // std::cout << "]\n";
     // o_dev.savetxt("gpu-out.txt", "float");
     auto [rtol, atol] = get_elimit<ADataType>();
     pass &= ck_tile::check_err(
......
@@ -340,7 +340,7 @@ template <typename T>
 struct FillConstant
 {
     T value_{0};
+    FillConstant(float value) : value_(ck_tile::type_convert<T>(value)) {}

     template <typename ForwardIter>
     void operator()(ForwardIter first, ForwardIter last) const
     {
......
@@ -586,7 +586,7 @@ struct HostTensor
         }
         if constexpr(std::is_same_v<T, bf16_t> || std::is_same_v<T, fp16_t>)
         {
-            os << type_convert<float>(t.mData[idx]) << " #### ";
+            os << type_convert<float>(t.mData[idx]) << " ";
         }
         else
         {
......
@@ -137,6 +137,7 @@ void reference_fused_moe(
                     ", 1:" + std::to_string(intermediate_size_1));
             for(ck_tile::index_t i_n = 0; i_n < intermediate_size_1; i_n++)
             {
+                // y(0, i_n) = acc_0(0, i_n);
                 Activation{}(y(0, i_n), acc_0(0, i_n));
                 // if(i_expert == 0)
                 //     printf("in:%d, %f\t", i_n, y(0, i_n));
......
@@ -199,6 +199,7 @@ struct FusedMoeGemmGlKernel
         index_t stride_token;          // for input/output, stride for each row, should >= hidden_size
         index_t max_num_tokens_padded; // size of sorted_token_ids_ptr
+        void* c_ptr;
     };

     // TODO: switch karg based on
@@ -251,12 +252,10 @@ struct FusedMoeGemmGlKernel
         index_t idx_n0 =
             __builtin_amdgcn_readfirstlane(intermediate_tile_id * BlockShape::Block_N0);

         const auto a_coord = Pipeline::GetACoord(); // 2d thread offset, [i_row, i_col]
         const auto sorted_token_id = a_coord[number<0>{}] + idx_m0; // start block_m
                                                                     // position
-        // index_t token_id =
-        //     reinterpret_cast<const index_t*>(kargs.sorted_token_ids_ptr)[sorted_token_id];
         auto topk_weight =
             reinterpret_cast<const TopkWeightDataType*>(kargs.sorted_weight_ptr)[sorted_token_id];
@@ -305,19 +304,26 @@ struct FusedMoeGemmGlKernel
                 make_tuple(number<BlockShape::Block_N0>{}, number<BlockShape::Block_K0>{}),
                 {idx_n0, 0});
-            // if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
-            // {
-            //     for(int i = 0; i < 16; i++)
-            //     {
-            //         printf("in G index is %d , value is: %f\n",
-            //                i,
-            //                ck_tile::type_convert<float>(g_ptr[i]));
-            //     }
-            // }
             return g_window_;
         }();

+        auto c_window = [&]() {
+            YDataType* c_ptr = reinterpret_cast<YDataType*>(kargs.c_ptr);
+            // note interm_idx_nr is along the gemm-k dim of 2nd gemm
+            auto c_view_ = make_naive_tensor_view<address_space_enum::global>(
+                c_ptr,
+                make_tuple(kargs.num_tokens, kargs.intermediate_size),
+                make_tuple(kargs.intermediate_size, 1),
+                number<Pipeline::kAlignmentD>{},
+                number<1>{});
+            auto c_window_ = make_tile_window(
+                c_view_,
+                make_tuple(number<BlockShape::Block_M0>{}, number<BlockShape::Block_N0>{}),
+                {0, 0});
+            return c_window_;
+        }();
+
         const auto d_window = [&]() {
             const DDataType* d_ptr = reinterpret_cast<const DDataType*>(kargs.d_ptr) +
                                      static_cast<long_index_t>(expert_id) * expert_stride_1;
@@ -371,7 +377,8 @@ struct FusedMoeGemmGlKernel
                            topk_weight,
                            smem,
                            kargs.hidden_size,
-                           kargs.intermediate_size);
+                           kargs.intermediate_size,
+                           c_window);
     }
 };
......
@@ -118,6 +118,7 @@ struct FusedMoeGemmHostArgs
     index_t stride_token;          // for input/output, stride for each row, should >= hidden_size
     index_t max_num_tokens_padded; // size of sorted_token_ids_ptr
+    void* c_ptr;
 };

 // This is scatter/gather b2b group-gemm
......
@@ -68,16 +68,24 @@ struct FusedMoeGemmPipeline_General
     static constexpr const char* name = "flatmm_gl";

-    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
+    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeA()
     {
         // matrix a or tokens smem
         constexpr index_t smem_mat_a =
             BlockShape::Block_M0 * BlockShape::Block_K0 * sizeof(ADataType);
+        return smem_mat_a;
+    }
+
+    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
+    {
+        // matrix a or tokens smem
+        constexpr index_t smem_mat_a = GetSmemSizeA();
+        constexpr index_t smem_mat_d =
+            BlockShape::Block_N0 * BlockShape::Block_K0 * sizeof(GDataType);

         // shuffle C matrix
         constexpr index_t smem_bridge =
             BlockShape::Block_M0 * BlockShape::Block_N0 * sizeof(YDataType);
-        return max(smem_mat_a, smem_bridge);
+        return max(smem_mat_a + smem_mat_d, smem_bridge);
         // return Policy::template GetSmemSize<Problem>();
     }
@@ -117,7 +125,11 @@ struct FusedMoeGemmPipeline_General
             });
         });
     }
-    template <typename AWindow, typename GWindow, typename DWindow, typename OWindow>
+    template <typename AWindow,
+              typename GWindow,
+              typename DWindow,
+              typename OWindow,
+              typename CWindow>
     CK_TILE_DEVICE auto operator()(const AWindow& a_window_,
                                    const GWindow& g_window_,
                                    const DWindow& d_window_,
@@ -125,16 +137,30 @@ struct FusedMoeGemmPipeline_General
                                    TopkWeightDataType topk_weight,
                                    CK_TILE_LDS_ADDR void* smem,
                                    index_t hidden_size,
-                                   index_t intermediate_size)
+                                   index_t /*intermediate_size*/,
+                                   CWindow& c_window_)
     {
+        ignore = topk_weight;
+        ignore = c_window_;
+        ignore = hidden_size;
         CK_TILE_LDS_ADDR ADataType* smem_0 = reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem);
+        CK_TILE_LDS_ADDR GDataType* smem_1 = reinterpret_cast<CK_TILE_LDS_ADDR GDataType*>(
+            smem_0 + GetSmemSizeA() / sizeof(ADataType));

         auto a_lds_view = make_tensor_view<address_space_enum::lds>(
             smem_0, Policy::template MakeLdsBlockDesc_A<Problem>());
         auto a_lds_win = make_tile_window(
             a_lds_view,
             make_tuple(number<BlockShape::Block_M0>{}, number<BlockShape::Block_K0>{}),
             {0, 0});
+        auto g_lds_view = make_tensor_view<address_space_enum::lds>(
+            smem_1, Policy::template MakeLdsBlockDesc_G<Problem>());
+        auto g_lds_win = make_tile_window(
+            g_lds_view,
+            make_tuple(number<BlockShape::Block_N0>{}, number<BlockShape::Block_K0>{}),
+            {0, 0});

         auto a_global_to_dram_window = make_tile_window(
             a_window_.get_bottom_tensor_view(),
             make_tuple(number<BlockShape::Block_M0>{}, number<BlockShape::Block_K0>{}),
@@ -148,69 +174,85 @@ struct FusedMoeGemmPipeline_General
             g_window_.get_window_origin(),
             Policy::template MakeGlobalTileDistribution_G<Problem>());

-        // gemm gate
+#if 0
+        PrintMem(g_dram_block, "G", 0);
+#endif
+        // gemm0(gate)
         constexpr auto gemm_0 = Policy::template GetBlockGemm0<Problem>();
         using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
         auto s_acc = SaccBlockTileType{};

-        // save tokens to lds
         auto a_dram_block = load_tile(a_global_to_dram_window);
-        store_tile(a_lds_win, a_dram_block);
-#if 0
-        PrintMem(a_dram_block, "A", 0, 1);
-#endif
         auto g_dram_block = load_tile(g_global_to_dram_window);
-#if 1
-        PrintMem(g_dram_block, "G", 0, 1);
+        // block_sync_load_raw();
+        // save tokens to lds
+        store_tile(a_lds_win, a_dram_block);
+        store_tile(g_lds_win, g_dram_block);
+#if 0
+        PrintMem(a_dram_block, "A", 0);
 #endif

         clear_tile(s_acc); // initialize C
         constexpr index_t kK0 = BlockShape::Block_K0;
         const index_t k0_loops = ck_tile::integer_divide_ceil(hidden_size, kK0);
         index_t iCounter0 = k0_loops - 1;
-        while(iCounter0 > 0)
+        while(iCounter0 >= 0)
         {
-            block_sync_lds();
-            gemm_0(s_acc, a_lds_win, g_dram_block);
+            if(iCounter0 > 0)
+            {
+                move_tile_window(a_global_to_dram_window, {0, kK0});
+                move_tile_window(g_global_to_dram_window, {0, kK0});
+                a_dram_block = load_tile(a_global_to_dram_window);
+                g_dram_block = load_tile(g_global_to_dram_window);
+            }
             block_sync_lds();
-            move_tile_window(a_global_to_dram_window, {0, kK0});
-            move_tile_window(g_global_to_dram_window, {0, kK0});
-            a_dram_block = load_tile(a_global_to_dram_window);
-            g_dram_block = load_tile(g_global_to_dram_window);
-            store_tile(a_lds_win, a_dram_block);
+            gemm_0(s_acc, a_lds_win, g_lds_win);
+            // gemm_0(s_acc, a_lds_win, g_dram_block);
+            block_sync_lds();
+            if(iCounter0 > 0)
+            {
+                store_tile(a_lds_win, a_dram_block);
+                store_tile(g_lds_win, g_dram_block);
+            }
             iCounter0--;
         }
         // tail
-        {
-            block_sync_lds();
-            gemm_0(s_acc, a_lds_win, g_dram_block);
-        }
+        // {
+        //     block_sync_lds();
+        //     // gemm_0(s_acc, a_lds_win, g_dram_block);
+        //     gemm_0(s_acc, a_lds_win, g_lds_win);
+        //     block_sync_lds();
+        // }
 #if 0
-        PrintMem(s_acc);
+        PrintMem(s_acc, "S", 0);
 #endif
         // relu
-        const auto activation = ck_tile::element_wise::Gelu{};
-        tile_elementwise_inout(activation, s_acc, s_acc);
-
-        // move sacc to LDS
+        // const auto activation = ck_tile::element_wise::Gelu{};
+        // tile_elementwise_inout(activation, s_acc, s_acc);
+        // cast data to YDataType
+        auto y_pre = cast_tile<YDataType>(s_acc);
+#if 0
+        PrintMem(y_pre, "Y_pre", 0);
+#endif
+        if(blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
+        {
+            block_sync_lds();
+            store_tile(c_window_, y_pre);
+        }
+        // save to lds
         auto bridge_lds_view = make_tensor_view<address_space_enum::lds>(
             smem_0, Policy::template MakeBridgeLdsBlockDesc<Problem>());
         auto bridge_slds_win =
             make_tile_window(bridge_lds_view,
                              Policy::template MakeBridgeLdsBlockDesc<Problem>().get_lengths(),
                              {0, 0});

-        // cast data to YDataType
-        auto y_pre = cast_tile<YDataType>(s_acc);
-#if 0
-        PrintMem(y_pre);
-#endif
-        // save to lds
         store_tile(bridge_slds_win, y_pre);
         block_sync_lds();
@@ -225,7 +267,20 @@ struct FusedMoeGemmPipeline_General
                              {0, 0},
                              Policy::template MakeYTileDistribution<Problem>());
         auto y = load_tile(bridge_llds_win);
+        block_sync_lds();
+#if 0
+        PrintMem(y, "Y", 0);
+        // PrintMem(y, "Y", 32);
+        if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
+        {
+            for(int i = 0; i < 16; i++)
+            {
+                printf("\n smem_0[%d]: %f ", i, type_convert<float>(smem_0[i]));
+            }
+        }
+        // store_tile(c_window_, y);
+#endif

         // d data
         auto d_global_to_dram_window = make_tile_window(
             d_window_.get_bottom_tensor_view(),
@@ -234,20 +289,20 @@ struct FusedMoeGemmPipeline_General
             Policy::template MakeGlobalTileDistribution_D<Problem>());
         auto d = load_tile(d_global_to_dram_window);
 #if 0
-        PrintMem(d, "D", 64);
+        PrintMem(d, "D", 0);
 #endif

         // add to LDS
-        auto o_alds_view =
+        auto o_lds_view =
             make_naive_tensor_view<address_space_enum::lds, memory_operation_enum::atomic_add>(
                 smem_0,
-                make_tuple(number<32>{}, number<32>{}),
+                make_tuple(number<128>{}, number<32>{}),
                 make_tuple(32, 1),
                 number<8>{},
                 number<1>{});
         auto o_alds_win =
-            make_tile_window(o_alds_view, make_tuple(number<32>{}, number<32>{}), {0, 0});
+            make_tile_window(o_lds_view, make_tuple(number<128>{}, number<32>{}), {0, 0});
         auto o_olds_win =
-            make_tile_window(o_alds_view,
+            make_tile_window(o_lds_view,
                              make_tuple(number<32>{}, number<32>{}),
                              {0, 0},
                              Policy::template MakeGlobalTileDistribution_O<Problem>());
@@ -278,31 +333,38 @@ struct FusedMoeGemmPipeline_General
             gemm_1(o_acc, y, d);
             // block_sync_lds();
-            tile_elementwise_inout(
-                [&topk_weight](auto& x) { x = x * type_convert<float>(topk_weight); }, o_acc);
+            // tile_elementwise_inout(
+            //     [&topk_weight](auto& x) { x = x * type_convert<float>(topk_weight); }, o_acc);

             auto o = cast_tile<ODataType>(o_acc);
+#if 0
+            PrintMem(o, "O", 65);
+#endif
             store_tile(o_alds_win, o);
             block_sync_lds();
-            // if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0)
-            // {
-            //     for(int i = 0; i < 42; i++)
-            //     {
-            //         printf("\n%d value is %f\t", i, type_convert<float>(smem_0[i]));
-            //     }
-            // }
-            if(threadIdx.x < 64)
+            if(blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
             {
-                auto o_out = load_tile(o_olds_win);
-                block_sync_lds();
-                store_tile(o_window_, o_out);
+                if(threadIdx.x < 64)
+                {
+                    auto o0 = load_tile(o_olds_win);
+                    for(int step = 1; step < 4; step++)
+                    {
+                        move_tile_window(o_olds_win, {32, 0});
+                        auto o1 = load_tile(o_olds_win);
+                        for(int i = 0; i < 16; i++)
+                        {
+                            o0.get_thread_buffer()(i) = type_convert<ODataType>(
+                                type_convert<float>(o0.get_thread_buffer()[i]) +
+                                type_convert<float>(o1.get_thread_buffer()[i]));
+                        }
+                    }
+                    update_tile(o_window_, o0);
+                }
             }
+            // ignore = o_olds_win;
             // store_tile(o_window_, o);
 #if 0
             PrintMem(o, "O");
 #endif
         }
         // store_tile(o_window_, a_dram_block);
     }
 };
......
@@ -10,6 +10,8 @@
 #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
 #include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
 #include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
+#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
+#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp"
 #include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp"
 #include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp"
 #include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp"
@@ -198,7 +200,7 @@ struct FusedMoeGemmPipelineGeneralPolicy
                                        tuple<sequence<0, 1>, sequence<1, 2>>,
                                        tuple<sequence<0, 0>, sequence<2, 0>>,
                                        sequence<1, 2>,
-                                       sequence<0, 1>>{});
+                                       sequence<1, 1>>{});
     }

     template <typename Problem>
@@ -214,13 +216,17 @@ struct FusedMoeGemmPipelineGeneralPolicy
                                                typename S_::WarpTile_0>>;
         constexpr auto warp_gemm = GetWarpGemm0<Problem>();

-        using BlockGemmPolicy = BlockGemmASmemBRegCRegV1CustomPolicy<typename Problem::ADataType,
-                                                                     typename Problem::GDataType,
-                                                                     typename Problem::AccDataType,
-                                                                     typename S_::WarpPerBlock_0,
-                                                                     decltype(warp_gemm)>;
-
-        return BlockGemmASmemBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
+        using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
+        // using BlockGemmPolicy =
+        //     BlockGemmASmemBRegCRegV1CustomPolicy<typename
+        //     Problem::ADataType,
+                                                                      typename Problem::GDataType,
+                                                                      typename Problem::AccDataType,
+                                                                      typename S_::WarpPerBlock_0,
+                                                                      decltype(warp_gemm)>;
+
+        return BlockGemmASmemBSmemCRegV1<GemmProblem, BlockGemmPolicy>{};
+        // return BlockGemmASmemBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
     }

     template <typename Problem>
@@ -288,28 +294,6 @@ struct FusedMoeGemmPipelineGeneralPolicy
         return d_block_dstr;
     }

-    // template <typename Problem>
-    // CK_TILE_HOST_DEVICE static constexpr auto MakeGlobalTileDistribution_O()
-    // {
-    //     using S_ = remove_cvref_t<typename Problem::BlockShape>;
-    //     using WarpGemm = remove_cvref_t<decltype(GetWarpGemm1<Problem>())>;
-    //     // using CDataType = typename WarpGemm::CDataType;
-
-    //     constexpr auto c_block_outer_dstr_encoding =
-    //         tile_distribution_encoding<sequence<>,
-    //                                    tuple<sequence<S_::Repeat_M1, S_::WarpPerBlock_M1>,
-    //                                          sequence<S_::Repeat_N1, S_::WarpPerBlock_N1>>,
-    //                                    tuple<sequence<1, 2>>,
-    //                                    tuple<sequence<1, 1>>,
-    //                                    sequence<1, 2>,
-    //                                    sequence<0, 0>>{};
-
-    //     constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
-    //         c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
-    //     constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
-    //     return c_block_dstr;
-    // }
-
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDesc_A()
     {
@@ -322,7 +306,7 @@ struct FusedMoeGemmPipelineGeneralPolicy
         constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
             make_tuple(number<kK0>{}, number<Block_M>{}, number<kK1>{}),
-            make_tuple(number<(Block_M + 1) * kK1>{}, number<kK1>{}, number<1>{}),
+            make_tuple(number<Block_M * kK1>{}, number<kK1>{}, number<1>{}),
             number<8>{},
             number<1>{});
@@ -333,9 +317,47 @@ struct FusedMoeGemmPipelineGeneralPolicy
             make_tuple(sequence<1>{}, sequence<0, 2>{}),
             make_tuple(sequence<0>{}, sequence<1>{}));

+        // constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
+        //     make_tuple(number<Block_M>{}, number<Block_K>{}),
+        //     make_tuple(number<Block_K>{}, number<1>{}),
+        //     number<8>{},
+        //     number<1>{});
+
         return a_lds_block_desc;
     }

+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDesc_G()
+    {
+        constexpr index_t Block_N = Problem::BlockShape::Block_N0;
+        constexpr index_t Block_K = Problem::BlockShape::Block_K0;
+
+        constexpr index_t kK1 = GetSmemKPack_A<Problem>(); // LDS
+        constexpr index_t kK0 = Block_K / kK1;
+
+        static_assert(Block_K % kK1 == 0);
+
+        constexpr auto d_lds_block_desc_0 = make_naive_tensor_descriptor(
+            make_tuple(number<kK0>{}, number<Block_N>{}, number<kK1>{}),
+            make_tuple(number<Block_N * kK1>{}, number<kK1>{}, number<1>{}),
+            number<8>{},
+            number<1>{});
+
+        constexpr auto d_lds_block_desc = transform_tensor_descriptor(
+            d_lds_block_desc_0,
+            make_tuple(make_pass_through_transform(number<Block_N>{}),
+                       make_merge_transform(make_tuple(number<kK0>{}, number<kK1>{}))),
+            make_tuple(sequence<1>{}, sequence<0, 2>{}),
+            make_tuple(sequence<0>{}, sequence<1>{}));
+
+        // constexpr auto d_lds_block_desc = make_naive_tensor_descriptor(
+        //     make_tuple(number<Block_N>{}, number<Block_K>{}),
+        //     make_tuple(number<Block_K>{}, number<1>{}),
+        //     number<8>{},
+        //     number<1>{});
+
+        return d_lds_block_desc;
+    }
+
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto MakeBridgeLdsBlockDesc()
     {
@@ -343,11 +365,10 @@ struct FusedMoeGemmPipelineGeneralPolicy
         constexpr index_t Block_N = Problem::BlockShape::Block_N0;

         constexpr index_t KVector = GetSmemKPack_Y<Problem>();
-        constexpr index_t KPad    = 0;

         constexpr auto desc =
             make_naive_tensor_descriptor(make_tuple(number<Block_M>{}, number<Block_N>{}),
-                                         make_tuple(number<Block_N + KPad>{}, number<1>{}),
+                                         make_tuple(number<Block_N>{}, number<1>{}),
                                          number<KVector>{},
                                          number<1>{});
         return desc;
......