"research/adversarial_logit_pairing/adversarial_attack.py" did not exist on "a8ba923c873f9848d0f6453f3e2e3fa2dd1187dc"
Commit f3e5a74e authored by Aleksander Dudek

Gemm Kernel Refactor part1
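Part 1 of the refactor extracts the construction of the global A/B/C tensor views out of `operator()` into a new `make_gemm_tensor_views` helper, which returns all three views as a tuple. `operator()` then picks the views out with the new compile-time indices `I0`/`I1`/`I2`. The C view, previously built separately after the GEMM pipeline ran, is now built up front together with A and B.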

parent feb9a2bd
@@ -28,6 +28,10 @@ struct GemmKernel
     // using CAccDataType = remove_cvref_t<typename GemmPipeline::CDataType>;
     using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
 
+    static constexpr auto I0 = number<0>();
+    static constexpr auto I1 = number<1>();
+    static constexpr auto I2 = number<2>();
+
     __host__ static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
     {
         return TilePartitioner::GridSize(M, N, KBatch);
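A note on the new constants: `number<N>` is ck_tile's integral-constant wrapper, so `I0`/`I1`/`I2` carry their index in their type and can select elements of differing types from the tuple returned by `make_gemm_tensor_views` below. A minimal standalone sketch of the mechanism, with `std::integral_constant` and `std::get` standing in for ck_tile's `number<>` and `tuple::at` (assumed to behave analogously):

```cpp
#include <cstddef>
#include <cstdio>
#include <tuple>
#include <type_traits>

// Stand-in for ck_tile's number<N>: a value that carries its index in its type.
template <std::size_t N>
using number_t = std::integral_constant<std::size_t, N>;

inline constexpr number_t<0> I0{};
inline constexpr number_t<1> I1{};
inline constexpr number_t<2> I2{};

// Indexing resolved entirely at compile time; each tuple element may have a
// different type, which a runtime integer index could not express.
template <typename Tuple, std::size_t N>
decltype(auto) at(const Tuple& t, number_t<N>)
{
    return std::get<N>(t);
}

int main()
{
    auto views = std::make_tuple(128, 2.5f, "c-view");
    std::printf("%s\n", at(views, I2)); // compile-time pick of element 2
}
```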
@@ -139,13 +143,11 @@ struct GemmKernel
         return true;
     }
-
-    CK_TILE_DEVICE void operator()(GemmCommonKargs kargs) const
+    CK_TILE_DEVICE auto make_gemm_tensor_views(const ADataType* a_start,
+                                               const BDataType* b_start,
+                                               CDataType* c_start,
+                                               const GemmCommonKargs& kargs) const
     {
-        const auto [i_m, i_n] = TilePartitioner{}();
-        // options
-        const ADataType* a_start = static_cast<const ADataType*>(kargs.a_ptr);
-        const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr);
         // Convert pointers to tensor views
         auto a_tensor_view = [&]() {
             if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
             {
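Each view here is produced by an immediately-invoked lambda wrapping `if constexpr`. This matters because the row-major and column-major `make_naive_tensor_view` calls return different types: `if constexpr` instantiates only the branch matching the layout, so the lambda call still yields one well-defined type for `auto` to deduce. A reduced illustration of the pattern (types and names are illustrative, not ck_tile's):

```cpp
#include <cstdio>

// Two distinct stand-in "view" types, as the two layout branches would produce.
struct row_major_view { int leading_dim; };
struct col_major_view { int leading_dim; };

template <bool IsRowMajor>
auto make_view(int leading_dim)
{
    // Immediately-invoked lambda: the branches return different types, but
    // if constexpr instantiates only one, so `auto` deduction is unambiguous.
    return [&]() {
        if constexpr(IsRowMajor)
            return row_major_view{leading_dim};
        else
            return col_major_view{leading_dim};
    }();
}

int main()
{
    auto v = make_view<true>(64); // v is row_major_view
    std::printf("leading_dim = %d\n", v.leading_dim);
}
```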
@@ -188,7 +190,43 @@ struct GemmKernel
             }
         }();
 
+        auto c_tensor_view = [&]() {
+            if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
+            {
+                return make_naive_tensor_view<address_space_enum::global>(
+                    c_start,
+                    make_tuple(kargs.M, kargs.N),
+                    make_tuple(kargs.stride_C, 1),
+                    number<GemmPipeline::VectorSizeC>{},
+                    number<1>{});
+            }
+            else
+            {
+                return make_naive_tensor_view<address_space_enum::global>(
+                    c_start,
+                    make_tuple(kargs.M, kargs.N),
+                    make_tuple(1, kargs.stride_C),
+                    number<1>{},
+                    number<1>{});
+            }
+        }();
+
+        return make_tuple(a_tensor_view, b_tensor_view, c_tensor_view);
+    }
+
+    CK_TILE_DEVICE void operator()(GemmCommonKargs kargs) const
+    {
+        const auto [i_m, i_n] = TilePartitioner{}();
+        // options
+        const ADataType* a_start = static_cast<const ADataType*>(kargs.a_ptr);
+        const BDataType* b_start = static_cast<const BDataType*>(kargs.b_ptr);
+        CDataType* c_start = static_cast<CDataType*>(kargs.c_ptr);
+        // Convert pointers to tensor views
+        const auto gemm_tensor_views_tuple =
+            make_gemm_tensor_views(a_start, b_start, c_start, kargs);
+
         auto a_pad_view = [&]() {
+            auto a_tensor_view = gemm_tensor_views_tuple.at(I0);
             if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
             {
                 return pad_tensor_view(
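The two branches of each view lambda encode the memory layout: row-major C uses strides `(stride_C, 1)`, making N the contiguous axis and allowing vector accesses of width `GemmPipeline::VectorSizeC`, while column-major uses `(1, stride_C)` with unit vector width. The addressing both branches describe reduces to the following (a sketch with illustrative names):

```cpp
#include <cstddef>

enum class layout { row_major, col_major };

// Linear offset of element (m, n) in an M x N matrix with leading dimension
// `stride`: strides (stride, 1) for row-major, (1, stride) for column-major.
constexpr std::size_t linear_offset(layout l, std::size_t m, std::size_t n,
                                    std::size_t stride)
{
    return l == layout::row_major ? m * stride + n : n * stride + m;
}

// Row-major: consecutive n are adjacent in memory (the vectorizable axis).
static_assert(linear_offset(layout::row_major, 2, 3, 10) == 23);
// Column-major: consecutive m are adjacent instead.
static_assert(linear_offset(layout::col_major, 2, 3, 10) == 32);
```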
@@ -212,6 +250,7 @@ struct GemmKernel
                 {i_m, 0});
 
         auto b_pad_view = [&]() {
+            auto b_tensor_view = gemm_tensor_views_tuple.at(I1);
             if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
             {
                 return pad_tensor_view(
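For context, `pad_tensor_view` wraps each global view so that tile-granular accesses stay in bounds when M, N, or K is not a multiple of the tile shape; conceptually the extent is rounded up to the next tile multiple. A hedged sketch of that rounding (the actual ck_tile padding logic is richer than this):

```cpp
#include <cstdint>

// Round an extent up to the next multiple of the tile size, e.g. so a
// 1000-wide N dimension is covered by 8 full 128-wide tiles.
constexpr std::int64_t padded_extent(std::int64_t extent, std::int64_t tile)
{
    return ((extent + tile - 1) / tile) * tile;
}

static_assert(padded_extent(1000, 128) == 1024);
static_assert(padded_extent(1024, 128) == 1024); // already aligned: unchanged
```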
@@ -242,29 +281,8 @@ struct GemmKernel
         auto c_block_tile =
             GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr);
 
-        CDataType* c_start = static_cast<CDataType*>(kargs.c_ptr);
-        auto c_tensor_view = [&]() {
-            if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
-            {
-                return make_naive_tensor_view<address_space_enum::global>(
-                    c_start,
-                    make_tuple(kargs.M, kargs.N),
-                    make_tuple(kargs.stride_C, 1),
-                    number<GemmPipeline::VectorSizeC>{},
-                    number<1>{});
-            }
-            else
-            {
-                return make_naive_tensor_view<address_space_enum::global>(
-                    c_start,
-                    make_tuple(kargs.M, kargs.N),
-                    make_tuple(1, kargs.stride_C),
-                    number<1>{},
-                    number<1>{});
-            }
-        }();
-
         auto c_pad_view = [&]() {
+            auto c_tensor_view = gemm_tensor_views_tuple.at(I2);
             if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
             {
                 return pad_tensor_view(
......
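The remainder of the diff (truncated above) continues the same substitution inside `c_pad_view`. The net effect of part 1: A/B/C view construction exists in exactly one routine, and `operator()` consumes the views through the `I0`/`I1`/`I2` tuple indices.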