Commit d5c5d2a3 authored by Jakub Piasecki
Browse files

add reviewers comments

parent c2945b96
...@@ -18,188 +18,68 @@ float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& a, const ck_tile:: ...@@ -18,188 +18,68 @@ float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& a, const ck_tile::
{ {
if(a.M > 512) if(a.M > 512)
{ {
return gemm_<gemm_traits_<FP16, // clang-format off
FP16, // ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
FP32, return gemm_<gemm_traits_< FP16, FP16, FP32, FP16, Row, Row, Row, 256, 256, 32, 2, 2, 1, 32, 32, 16, false, false, false>>(a, s);
FP16, // clang-format on
Row,
Row,
Row,
256,
256,
32,
2,
2,
1,
32,
32,
16,
false,
false,
false>>(a, s);
} }
else else
{ {
return gemm_<gemm_traits_<FP16, // clang-format off
FP16, // ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
FP32, return gemm_<gemm_traits_< FP16, FP16, FP32, FP16, Row, Row, Row, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, false, false>>(a, s);
FP16, // clang-format on
Row,
Row,
Row,
128,
32,
64,
4,
1,
1,
32,
32,
8,
false,
false,
false>>(a, s);
} }
} }
else if(t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor) else if(t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor)
{ {
if(a.M > 512) if(a.M > 512)
{ {
return gemm_<gemm_traits_<FP16, // clang-format off
FP16, // ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
FP32, return gemm_<gemm_traits_< FP16, FP16, FP32, FP16, Row, Col, Row, 256, 256, 32, 2, 2, 1, 32, 32, 16, false, false, false>>(a, s);
FP16, // clang-format on
Row,
Col,
Row,
256,
256,
32,
2,
2,
1,
32,
32,
16,
false,
false,
false>>(a, s);
} }
else else
{ {
return gemm_<gemm_traits_<FP16, // clang-format off
FP16, // ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
FP32, return gemm_<gemm_traits_< FP16, FP16, FP32, FP16, Row, Col, Row, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, false, false>>(a, s);
FP16, // clang-format on
Row,
Col,
Row,
128,
32,
64,
4,
1,
1,
32,
32,
8,
false,
false,
false>>(a, s);
} }
} }
else if(!t.is_a_rowmajor && t.is_b_rowmajor && t.is_c_rowmajor) else if(!t.is_a_rowmajor && t.is_b_rowmajor && t.is_c_rowmajor)
{ {
if(a.M > 512) if(a.M > 512)
{ {
return gemm_<gemm_traits_<FP16, // clang-format off
FP16, // ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
FP32, return gemm_<gemm_traits_< FP16, FP16, FP32, FP16, Col, Row, Row, 256, 256, 32, 2, 2, 1, 32, 32, 16, false, false, false>>(a, s);
FP16, // clang-format on
Col,
Row,
Row,
256,
256,
32,
2,
2,
1,
32,
32,
16,
false,
false,
false>>(a, s);
} }
else else
{ {
return gemm_<gemm_traits_<FP16, // clang-format off
FP16, // ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
FP32, return gemm_<gemm_traits_< FP16, FP16, FP32, FP16, Col, Row, Row, 128, 128, 32, 2, 2, 1, 32, 32, 8, false, false, false>>(a, s);
FP16, // clang-format on
Col,
Row,
Row,
128,
128,
32,
2,
2,
1,
32,
32,
8,
false,
false,
false>>(a, s);
} }
} }
else if(!t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor) else if(!t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor)
{ {
if(a.M > 512) if(a.M > 512)
{ {
return gemm_<gemm_traits_<FP16, // clang-format off
FP16, // ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
FP32, return gemm_<gemm_traits_< FP16, FP16, FP32, FP16, Col, Col, Row, 256, 256, 32, 2, 2, 1, 32, 32, 16, false, false, false>>(a, s);
FP16, // clang-format on
Col,
Col,
Row,
256,
256,
32,
2,
2,
1,
32,
32,
16,
false,
false,
false>>(a, s);
} }
else else
{ {
return gemm_<gemm_traits_<FP16, // clang-format off
FP16, // ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
FP32, return gemm_<gemm_traits_< FP16, FP16, FP32, FP16, Col, Col, Row, 128, 128, 32, 2, 2, 1, 32, 32, 8, false, false, false>>(a, s);
FP16, // clang-format on
Col,
Col,
Row,
128,
128,
32,
2,
2,
1,
32,
32,
8,
false,
false,
false>>(a, s);
} }
} }
else else
...@@ -213,188 +93,68 @@ float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& a, const ck_tile:: ...@@ -213,188 +93,68 @@ float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& a, const ck_tile::
{ {
if(a.M > 512) if(a.M > 512)
{ {
return gemm_<gemm_traits_<BF16, // clang-format off
BF16, // ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
FP32, return gemm_<gemm_traits_< BF16, BF16, FP32, BF16, Row, Row, Row, 256, 256, 32, 2, 2, 1, 32, 32, 16, false, false, false>>(a, s);
BF16, // clang-format on
Row,
Row,
Row,
256,
256,
32,
2,
2,
1,
32,
32,
16,
false,
false,
false>>(a, s);
} }
else else
{ {
return gemm_<gemm_traits_<BF16, // clang-format off
BF16, // ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
FP32, return gemm_<gemm_traits_< BF16, BF16, FP32, BF16, Row, Row, Row, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, false, false>>(a, s);
BF16, // clang-format on
Row,
Row,
Row,
128,
32,
64,
4,
1,
1,
32,
32,
8,
false,
false,
false>>(a, s);
} }
} }
else if(t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor) else if(t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor)
{ {
if(a.M > 512) if(a.M > 512)
{ {
return gemm_<gemm_traits_<BF16, // clang-format off
BF16, // ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
FP32, return gemm_<gemm_traits_< BF16, BF16, FP32, BF16, Row, Col, Row, 256, 256, 32, 2, 2, 1, 32, 32, 16, false, false, false>>(a, s);
BF16, // clang-format on
Row,
Col,
Row,
256,
256,
32,
2,
2,
1,
32,
32,
16,
false,
false,
false>>(a, s);
} }
else else
{ {
return gemm_<gemm_traits_<BF16, // clang-format off
BF16, // ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
FP32, return gemm_<gemm_traits_< BF16, BF16, FP32, BF16, Row, Col, Row, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, false, false>>(a, s);
BF16, // clang-format on
Row,
Col,
Row,
128,
32,
64,
4,
1,
1,
32,
32,
8,
false,
false,
false>>(a, s);
} }
} }
else if(!t.is_a_rowmajor && t.is_b_rowmajor && t.is_c_rowmajor) else if(!t.is_a_rowmajor && t.is_b_rowmajor && t.is_c_rowmajor)
{ {
if(a.M > 512) if(a.M > 512)
{ {
return gemm_<gemm_traits_<BF16, // clang-format off
BF16, // ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
FP32, return gemm_<gemm_traits_< BF16, BF16, FP32, BF16, Col, Row, Row, 256, 256, 32, 2, 2, 1, 32, 32, 16, false, false, false>>(a, s);
BF16, // clang-format on
Col,
Row,
Row,
256,
256,
32,
2,
2,
1,
32,
32,
16,
false,
false,
false>>(a, s);
} }
else else
{ {
return gemm_<gemm_traits_<BF16, // clang-format off
BF16, // ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
FP32, return gemm_<gemm_traits_< BF16, BF16, FP32, BF16, Col, Row, Row, 128, 128, 32, 2, 2, 1, 32, 32, 8, false, false, false>>(a, s);
BF16, // clang-format on
Col,
Row,
Row,
128,
128,
32,
2,
2,
1,
32,
32,
8,
false,
false,
false>>(a, s);
} }
} }
else if(!t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor) else if(!t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor)
{ {
if(a.M > 512) if(a.M > 512)
{ {
return gemm_<gemm_traits_<BF16, // clang-format off
BF16, // ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
FP32, return gemm_<gemm_traits_< BF16, BF16, FP32, BF16, Col, Col, Row, 256, 256, 32, 2, 2, 1, 32, 32, 16, false, false, false>>(a, s);
BF16, // clang-format on
Col,
Col,
Row,
256,
256,
32,
2,
2,
1,
32,
32,
16,
false,
false,
false>>(a, s);
} }
else else
{ {
return gemm_<gemm_traits_<BF16, // clang-format off
BF16, // ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
FP32, return gemm_<gemm_traits_< BF16, BF16, FP32, BF16, Col, Col, Row, 128, 128, 32, 2, 2, 1, 32, 32, 8, false, false, false>>(a, s);
BF16, // clang-format on
Col,
Col,
Row,
128,
128,
32,
2,
2,
1,
32,
32,
8,
false,
false,
false>>(a, s);
} }
} }
else else
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
auto create_args(int argc, char* argv[]) auto create_args(int argc, char* argv[])
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment