"...resnet50_tensorflow.git" did not exist on "e02da6578851cfea3e1f4faeaf692c3eef88fe61"
Commit e1b457ec authored by letaoqin's avatar letaoqin
Browse files

g and d add padding

parent e97fdbc3
...@@ -292,9 +292,14 @@ struct FusedMoeGemmGlKernel ...@@ -292,9 +292,14 @@ struct FusedMoeGemmGlKernel
number<Pipeline::kAlignmentG>{}, number<Pipeline::kAlignmentG>{},
number<1>{}); number<1>{});
const auto g_window_ = make_tile_window( const auto g_view_1_ = pad_tensor_view(
g_view_, g_view_,
make_tuple(number<BlockShape::Block_N0>{}, number<BlockShape::Block_K0>{}), make_tuple(number<BlockShape::Block_N0>{}, number<BlockShape::Block_K0>{}),
sequence<PadIntermediateSize, PadHiddenSize>{});
const auto g_window_ = make_tile_window(
g_view_1_,
make_tuple(number<BlockShape::Block_N0>{}, number<BlockShape::Block_K0>{}),
{idx_n0, 0}); {idx_n0, 0});
return g_window_; return g_window_;
...@@ -328,9 +333,14 @@ struct FusedMoeGemmGlKernel ...@@ -328,9 +333,14 @@ struct FusedMoeGemmGlKernel
number<Pipeline::kAlignmentD>{}, number<Pipeline::kAlignmentD>{},
number<1>{}); number<1>{});
const auto d_window_ = make_tile_window( const auto d_view_1_ = pad_tensor_view(
d_view_, d_view_,
make_tuple(number<BlockShape::Block_N1>{}, number<BlockShape::Block_K1>{}), make_tuple(number<BlockShape::Block_N1>{}, number<BlockShape::Block_K1>{}),
sequence<PadHiddenSize, PadIntermediateSize>{});
const auto d_window_ = make_tile_window(
d_view_1_,
make_tuple(number<BlockShape::Block_N1>{}, number<BlockShape::Block_K1>{}),
{0, idx_n0}); {0, idx_n0});
return d_window_; return d_window_;
}(); }();
......
...@@ -391,7 +391,7 @@ struct FusedMoeGemmKernel ...@@ -391,7 +391,7 @@ struct FusedMoeGemmKernel
number<Pipeline::kAlignmentO>{}, number<Pipeline::kAlignmentO>{},
number<1>{}); number<1>{});
// gather is here // scatter is here
auto o_scatter_view_ = transform_tensor_view( auto o_scatter_view_ = transform_tensor_view(
o_view_, o_view_,
make_tuple(make_indexing_transform(kargs.num_tokens, token_id), make_tuple(make_indexing_transform(kargs.num_tokens, token_id),
......
...@@ -71,9 +71,7 @@ struct FusedMoeGemmPipeline_General ...@@ -71,9 +71,7 @@ struct FusedMoeGemmPipeline_General
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeA() CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeA()
{ {
// matrix a or tokens smem // matrix a or tokens smem
constexpr index_t smem_mat_a = return Policy::template GetSmemSize_A<Problem>();
BlockShape::Block_M0 * BlockShape::Block_K0 * sizeof(ADataType);
return smem_mat_a;
} }
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{ {
...@@ -131,11 +129,8 @@ struct FusedMoeGemmPipeline_General ...@@ -131,11 +129,8 @@ struct FusedMoeGemmPipeline_General
CK_TILE_LDS_ADDR void* smem, CK_TILE_LDS_ADDR void* smem,
index_t hidden_size, index_t hidden_size,
index_t /*intermediate_size*/, index_t /*intermediate_size*/,
CWindow& c_window_) CWindow& /*c_window_*/)
{ {
ignore = c_window_;
ignore = hidden_size;
ignore = w_window_;
CK_TILE_LDS_ADDR ADataType* smem_0 = reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem); CK_TILE_LDS_ADDR ADataType* smem_0 = reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem);
CK_TILE_LDS_ADDR GDataType* smem_1 = reinterpret_cast<CK_TILE_LDS_ADDR GDataType*>( CK_TILE_LDS_ADDR GDataType* smem_1 = reinterpret_cast<CK_TILE_LDS_ADDR GDataType*>(
smem_0 + GetSmemSizeA() / sizeof(ADataType)); smem_0 + GetSmemSizeA() / sizeof(ADataType));
...@@ -234,11 +229,11 @@ struct FusedMoeGemmPipeline_General ...@@ -234,11 +229,11 @@ struct FusedMoeGemmPipeline_General
#if 0 #if 0
PrintMem(y_pre, "Y_pre", 0); PrintMem(y_pre, "Y_pre", 0);
#endif #endif
if(blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) // if(blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
{ // {
block_sync_lds(); // block_sync_lds();
store_tile(c_window_, y_pre); // store_tile(c_window_, y_pre);
} // }
// save to lds // save to lds
auto bridge_lds_view = make_tensor_view<address_space_enum::lds>( auto bridge_lds_view = make_tensor_view<address_space_enum::lds>(
smem_0, Policy::template MakeBridgeLdsBlockDesc<Problem>()); smem_0, Policy::template MakeBridgeLdsBlockDesc<Problem>());
......
...@@ -312,12 +312,6 @@ struct FusedMoeGemmPipelineGeneralPolicy ...@@ -312,12 +312,6 @@ struct FusedMoeGemmPipelineGeneralPolicy
make_tuple(sequence<1>{}, sequence<0, 2>{}), make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{})); make_tuple(sequence<0>{}, sequence<1>{}));
// constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
// make_tuple(number<Block_M>{}, number<Block_K>{}),
// make_tuple(number<Block_K>{}, number<1>{}),
// number<8>{},
// number<1>{});
return a_lds_block_desc; return a_lds_block_desc;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment