Commit d5c5d2a3 authored by Jakub Piasecki's avatar Jakub Piasecki
Browse files

add reviewers comments

parent c2945b96
......@@ -18,188 +18,68 @@ float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& a, const ck_tile::
{
if(a.M > 512)
{
return gemm_<gemm_traits_<FP16,
FP16,
FP32,
FP16,
Row,
Row,
Row,
256,
256,
32,
2,
2,
1,
32,
32,
16,
false,
false,
false>>(a, s);
// clang-format off
// ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
return gemm_<gemm_traits_< FP16, FP16, FP32, FP16, Row, Row, Row, 256, 256, 32, 2, 2, 1, 32, 32, 16, false, false, false>>(a, s);
// clang-format on
}
else
{
return gemm_<gemm_traits_<FP16,
FP16,
FP32,
FP16,
Row,
Row,
Row,
128,
32,
64,
4,
1,
1,
32,
32,
8,
false,
false,
false>>(a, s);
// clang-format off
// ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
return gemm_<gemm_traits_< FP16, FP16, FP32, FP16, Row, Row, Row, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, false, false>>(a, s);
// clang-format on
}
}
else if(t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor)
{
if(a.M > 512)
{
return gemm_<gemm_traits_<FP16,
FP16,
FP32,
FP16,
Row,
Col,
Row,
256,
256,
32,
2,
2,
1,
32,
32,
16,
false,
false,
false>>(a, s);
// clang-format off
// ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
return gemm_<gemm_traits_< FP16, FP16, FP32, FP16, Row, Col, Row, 256, 256, 32, 2, 2, 1, 32, 32, 16, false, false, false>>(a, s);
// clang-format on
}
else
{
return gemm_<gemm_traits_<FP16,
FP16,
FP32,
FP16,
Row,
Col,
Row,
128,
32,
64,
4,
1,
1,
32,
32,
8,
false,
false,
false>>(a, s);
// clang-format off
// ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
return gemm_<gemm_traits_< FP16, FP16, FP32, FP16, Row, Col, Row, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, false, false>>(a, s);
// clang-format on
}
}
else if(!t.is_a_rowmajor && t.is_b_rowmajor && t.is_c_rowmajor)
{
if(a.M > 512)
{
return gemm_<gemm_traits_<FP16,
FP16,
FP32,
FP16,
Col,
Row,
Row,
256,
256,
32,
2,
2,
1,
32,
32,
16,
false,
false,
false>>(a, s);
// clang-format off
// ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
return gemm_<gemm_traits_< FP16, FP16, FP32, FP16, Col, Row, Row, 256, 256, 32, 2, 2, 1, 32, 32, 16, false, false, false>>(a, s);
// clang-format on
}
else
{
return gemm_<gemm_traits_<FP16,
FP16,
FP32,
FP16,
Col,
Row,
Row,
128,
128,
32,
2,
2,
1,
32,
32,
8,
false,
false,
false>>(a, s);
// clang-format off
// ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
return gemm_<gemm_traits_< FP16, FP16, FP32, FP16, Col, Row, Row, 128, 128, 32, 2, 2, 1, 32, 32, 8, false, false, false>>(a, s);
// clang-format on
}
}
else if(!t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor)
{
if(a.M > 512)
{
return gemm_<gemm_traits_<FP16,
FP16,
FP32,
FP16,
Col,
Col,
Row,
256,
256,
32,
2,
2,
1,
32,
32,
16,
false,
false,
false>>(a, s);
// clang-format off
// ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
return gemm_<gemm_traits_< FP16, FP16, FP32, FP16, Col, Col, Row, 256, 256, 32, 2, 2, 1, 32, 32, 16, false, false, false>>(a, s);
// clang-format on
}
else
{
return gemm_<gemm_traits_<FP16,
FP16,
FP32,
FP16,
Col,
Col,
Row,
128,
128,
32,
2,
2,
1,
32,
32,
8,
false,
false,
false>>(a, s);
// clang-format off
// ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
return gemm_<gemm_traits_< FP16, FP16, FP32, FP16, Col, Col, Row, 128, 128, 32, 2, 2, 1, 32, 32, 8, false, false, false>>(a, s);
// clang-format on
}
}
else
......@@ -213,188 +93,68 @@ float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& a, const ck_tile::
{
if(a.M > 512)
{
return gemm_<gemm_traits_<BF16,
BF16,
FP32,
BF16,
Row,
Row,
Row,
256,
256,
32,
2,
2,
1,
32,
32,
16,
false,
false,
false>>(a, s);
// clang-format off
// ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
return gemm_<gemm_traits_< BF16, BF16, FP32, BF16, Row, Row, Row, 256, 256, 32, 2, 2, 1, 32, 32, 16, false, false, false>>(a, s);
// clang-format on
}
else
{
return gemm_<gemm_traits_<BF16,
BF16,
FP32,
BF16,
Row,
Row,
Row,
128,
32,
64,
4,
1,
1,
32,
32,
8,
false,
false,
false>>(a, s);
// clang-format off
// ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
return gemm_<gemm_traits_< BF16, BF16, FP32, BF16, Row, Row, Row, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, false, false>>(a, s);
// clang-format on
}
}
else if(t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor)
{
if(a.M > 512)
{
return gemm_<gemm_traits_<BF16,
BF16,
FP32,
BF16,
Row,
Col,
Row,
256,
256,
32,
2,
2,
1,
32,
32,
16,
false,
false,
false>>(a, s);
// clang-format off
// ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
return gemm_<gemm_traits_< BF16, BF16, FP32, BF16, Row, Col, Row, 256, 256, 32, 2, 2, 1, 32, 32, 16, false, false, false>>(a, s);
// clang-format on
}
else
{
return gemm_<gemm_traits_<BF16,
BF16,
FP32,
BF16,
Row,
Col,
Row,
128,
32,
64,
4,
1,
1,
32,
32,
8,
false,
false,
false>>(a, s);
// clang-format off
// ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
return gemm_<gemm_traits_< BF16, BF16, FP32, BF16, Row, Col, Row, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, false, false>>(a, s);
// clang-format on
}
}
else if(!t.is_a_rowmajor && t.is_b_rowmajor && t.is_c_rowmajor)
{
if(a.M > 512)
{
return gemm_<gemm_traits_<BF16,
BF16,
FP32,
BF16,
Col,
Row,
Row,
256,
256,
32,
2,
2,
1,
32,
32,
16,
false,
false,
false>>(a, s);
// clang-format off
// ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
return gemm_<gemm_traits_< BF16, BF16, FP32, BF16, Col, Row, Row, 256, 256, 32, 2, 2, 1, 32, 32, 16, false, false, false>>(a, s);
// clang-format on
}
else
{
return gemm_<gemm_traits_<BF16,
BF16,
FP32,
BF16,
Col,
Row,
Row,
128,
128,
32,
2,
2,
1,
32,
32,
8,
false,
false,
false>>(a, s);
// clang-format off
// ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
return gemm_<gemm_traits_< BF16, BF16, FP32, BF16, Col, Row, Row, 128, 128, 32, 2, 2, 1, 32, 32, 8, false, false, false>>(a, s);
// clang-format on
}
}
else if(!t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor)
{
if(a.M > 512)
{
return gemm_<gemm_traits_<BF16,
BF16,
FP32,
BF16,
Col,
Col,
Row,
256,
256,
32,
2,
2,
1,
32,
32,
16,
false,
false,
false>>(a, s);
// clang-format off
// ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
return gemm_<gemm_traits_< BF16, BF16, FP32, BF16, Col, Col, Row, 256, 256, 32, 2, 2, 1, 32, 32, 16, false, false, false>>(a, s);
// clang-format on
}
else
{
return gemm_<gemm_traits_<BF16,
BF16,
FP32,
BF16,
Col,
Col,
Row,
128,
128,
32,
2,
2,
1,
32,
32,
8,
false,
false,
false>>(a, s);
// clang-format off
// ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
return gemm_<gemm_traits_< BF16, BF16, FP32, BF16, Col, Col, Row, 128, 128, 32, 2, 2, 1, 32, 32, 8, false, false, false>>(a, s);
// clang-format on
}
}
else
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
auto create_args(int argc, char* argv[])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment