Unverified Commit 694c3001 authored by Thomas Ning's avatar Thomas Ning Committed by GitHub
Browse files

Ck tile gemm padding dim (#1516)

* Support the N dimension padding

* Finished the padding feature for different dimension of K
parent e84adec3
...@@ -179,9 +179,9 @@ float invoke_gemm(ck_tile::DeviceMem& a_buf, ...@@ -179,9 +179,9 @@ float invoke_gemm(ck_tile::DeviceMem& a_buf,
std::cout << "The overall perfomance of the GEMM with " std::cout << "The overall perfomance of the GEMM with "
<< "[" << data_type << "]" << "[" << data_type << "]"
<< "batch size: " << batch_size << ". m:" << M << ",n:" << N << ", k:" << K << "batch size: " << batch_size << ". m:" << M << ", n:" << N << ", k:" << K
<< "is: \n"; << " is: \n";
std::cout << "Running time :" << ave_time << "ms, Throughput" << gb_per_sec << "GB/s \n" std::cout << "Running time: " << ave_time << "ms, Throughput " << gb_per_sec << "GB/s \n"
<< std::flush; << std::flush;
return ave_time; return ave_time;
...@@ -235,7 +235,7 @@ int main(int argc, char* argv[]) ...@@ -235,7 +235,7 @@ int main(int argc, char* argv[])
// The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part. // The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
constexpr bool kPadA = true; constexpr bool kPadA = true;
constexpr bool kPadB = true; constexpr bool kPadB = true;
constexpr bool kPadC = false; constexpr bool kPadC = true;
// This part comes from the Codegen // This part comes from the Codegen
constexpr ck_tile::index_t M_Tile = 128; constexpr ck_tile::index_t M_Tile = 128;
...@@ -348,7 +348,7 @@ int main(int argc, char* argv[]) ...@@ -348,7 +348,7 @@ int main(int argc, char* argv[])
pass_gpu = ck_tile::check_err(c_host_dev, c_host_gpu_ref); pass_gpu = ck_tile::check_err(c_host_dev, c_host_gpu_ref);
std::cout << "The GPU veification result is:" << (pass_gpu ? "correct" : "fail") std::cout << "The GPU veification result is: " << (pass_gpu ? "correct" : "fail")
<< std::flush; << std::flush;
} }
......
...@@ -123,14 +123,26 @@ struct GemmKernel ...@@ -123,14 +123,26 @@ struct GemmKernel
} }
}(); }();
auto ABlockWindow = make_tile_window( auto a_pad_view = pad_tensor_view(
a_tensor_view, a_tensor_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}), make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
sequence < 0,
GemmPipeline::kPadA ? 1 : 0 > {});
auto ABlockWindow = make_tile_window(
a_pad_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
{i_m, 0}); {i_m, 0});
auto BBlockWindow = make_tile_window( auto b_pad_view = pad_tensor_view(
b_tensor_view, b_tensor_view,
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}), make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
sequence < 0,
GemmPipeline::kPadB ? 1 : 0 > {});
auto BBlockWindow = make_tile_window(
b_pad_view,
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
{i_n, 0}); {i_n, 0});
// allocate LDS // allocate LDS
...@@ -163,12 +175,16 @@ struct GemmKernel ...@@ -163,12 +175,16 @@ struct GemmKernel
} }
}(); }();
auto CBlockWindow = make_tile_window( auto c_pad_view = pad_tensor_view(
c_tensor_view, c_tensor_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}), make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
sequence < 0,
GemmPipeline::kPadC ? 1 : 0 > {});
auto CBlockWindow_pad = make_tile_window(
c_pad_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
{i_m, i_n}); {i_m, i_n});
// epilogue. EpiloguePipeline{}(CBlockWindow_pad, acc);
EpiloguePipeline{}(CBlockWindow, acc);
} }
}; };
......
...@@ -29,6 +29,10 @@ struct BlockGemmPipelineAGmemBGmemCRegV1 ...@@ -29,6 +29,10 @@ struct BlockGemmPipelineAGmemBGmemCRegV1
static constexpr index_t AlignmentB = Problem::AlignmentB; static constexpr index_t AlignmentB = Problem::AlignmentB;
static constexpr index_t AlignmentC = Problem::AlignmentC; static constexpr index_t AlignmentC = Problem::AlignmentC;
static constexpr bool kPadA = Problem::kPadA;
static constexpr bool kPadB = Problem::kPadB;
static constexpr bool kPadC = Problem::kPadC;
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize() CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize()
{ {
return ck_tile::integer_divide_ceil( return ck_tile::integer_divide_ceil(
......
...@@ -28,9 +28,9 @@ struct BlockGemmPipelineProblem ...@@ -28,9 +28,9 @@ struct BlockGemmPipelineProblem
static constexpr bool kPadB = kPadB_; static constexpr bool kPadB = kPadB_;
static constexpr bool kPadC = kPadC_; static constexpr bool kPadC = kPadC_;
static constexpr index_t AlignmentA = kPadA ? VectorLoadSize / sizeof(ADataType) : 1; static constexpr index_t AlignmentA = kPadA ? 1 : VectorLoadSize / sizeof(ADataType);
static constexpr index_t AlignmentB = kPadB ? VectorLoadSize / sizeof(BDataType) : 1; static constexpr index_t AlignmentB = kPadB ? 1 : VectorLoadSize / sizeof(BDataType);
static constexpr index_t AlignmentC = kPadC ? VectorLoadSize / sizeof(CDataType) : 1; static constexpr index_t AlignmentC = kPadC ? 1 : VectorLoadSize / sizeof(CDataType);
}; };
} // namespace ck_tile } // namespace ck_tile
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment