Commit e1b457ec authored by letaoqin's avatar letaoqin
Browse files

g and d: add padding

parent e97fdbc3
......@@ -292,9 +292,14 @@ struct FusedMoeGemmGlKernel
number<Pipeline::kAlignmentG>{},
number<1>{});
const auto g_window_ = make_tile_window(
const auto g_view_1_ = pad_tensor_view(
g_view_,
make_tuple(number<BlockShape::Block_N0>{}, number<BlockShape::Block_K0>{}),
sequence<PadIntermediateSize, PadHiddenSize>{});
const auto g_window_ = make_tile_window(
g_view_1_,
make_tuple(number<BlockShape::Block_N0>{}, number<BlockShape::Block_K0>{}),
{idx_n0, 0});
return g_window_;
......@@ -328,9 +333,14 @@ struct FusedMoeGemmGlKernel
number<Pipeline::kAlignmentD>{},
number<1>{});
const auto d_window_ = make_tile_window(
const auto d_view_1_ = pad_tensor_view(
d_view_,
make_tuple(number<BlockShape::Block_N1>{}, number<BlockShape::Block_K1>{}),
sequence<PadHiddenSize, PadIntermediateSize>{});
const auto d_window_ = make_tile_window(
d_view_1_,
make_tuple(number<BlockShape::Block_N1>{}, number<BlockShape::Block_K1>{}),
{0, idx_n0});
return d_window_;
}();
......
......@@ -391,7 +391,7 @@ struct FusedMoeGemmKernel
number<Pipeline::kAlignmentO>{},
number<1>{});
// gather is here
// scatter is here
auto o_scatter_view_ = transform_tensor_view(
o_view_,
make_tuple(make_indexing_transform(kargs.num_tokens, token_id),
......
......@@ -71,9 +71,7 @@ struct FusedMoeGemmPipeline_General
// Returns the smem size, in bytes, used for matrix A (the tokens).
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeA()
{
// matrix a or tokens smem
// Size of one Block_M0 x Block_K0 tile of ADataType elements, in bytes.
constexpr index_t smem_mat_a =
BlockShape::Block_M0 * BlockShape::Block_K0 * sizeof(ADataType);
return smem_mat_a;
// NOTE(review): the statement below is unreachable as written, because the
// function has already returned smem_mat_a. This text looks like both sides
// of a diff hunk shown together; presumably the Policy-based size is meant
// to replace the hand-computed one above -- TODO confirm against the repo.
return Policy::template GetSmemSize_A<Problem>();
}
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
......@@ -131,11 +129,8 @@ struct FusedMoeGemmPipeline_General
CK_TILE_LDS_ADDR void* smem,
index_t hidden_size,
index_t /*intermediate_size*/,
CWindow& c_window_)
CWindow& /*c_window_*/)
{
ignore = c_window_;
ignore = hidden_size;
ignore = w_window_;
CK_TILE_LDS_ADDR ADataType* smem_0 = reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem);
CK_TILE_LDS_ADDR GDataType* smem_1 = reinterpret_cast<CK_TILE_LDS_ADDR GDataType*>(
smem_0 + GetSmemSizeA() / sizeof(ADataType));
......@@ -234,11 +229,11 @@ struct FusedMoeGemmPipeline_General
#if 0
PrintMem(y_pre, "Y_pre", 0);
#endif
if(blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
{
block_sync_lds();
store_tile(c_window_, y_pre);
}
// if(blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
// {
// block_sync_lds();
// store_tile(c_window_, y_pre);
// }
// save to lds
auto bridge_lds_view = make_tensor_view<address_space_enum::lds>(
smem_0, Policy::template MakeBridgeLdsBlockDesc<Problem>());
......
......@@ -312,12 +312,6 @@ struct FusedMoeGemmPipelineGeneralPolicy
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
// constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
// make_tuple(number<Block_M>{}, number<Block_K>{}),
// make_tuple(number<Block_K>{}, number<1>{}),
// number<8>{},
// number<1>{});
return a_lds_block_desc;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment