Commit da8e50dd authored by Jakub Piasecki
Browse files

tmp save

parent dd21c599
...@@ -9,6 +9,8 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor; ...@@ -9,6 +9,8 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using FP32 = float; using FP32 = float;
using FP16 = ck_tile::half_t; using FP16 = ck_tile::half_t;
using BF16 = ck_tile::bf16_t; using BF16 = ck_tile::bf16_t;
using FP8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& a, const ck_tile::stream_config& s) float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& a, const ck_tile::stream_config& s)
{ {
...@@ -27,7 +29,7 @@ float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& a, const ck_tile:: ...@@ -27,7 +29,7 @@ float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& a, const ck_tile::
{ {
// clang-format off // clang-format off
// ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK // ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
return gemm_<gemm_traits_< FP16, FP16, FP32, FP16, Row, Row, Row, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, false, false>>(a, s); return gemm_<gemm_traits_< FP16, FP16, FP32, FP16, Row, Row, Row, 128, 128, 32, 2, 2, 1, 32, 32, 8, false, false, false>>(a, s);
// clang-format on // clang-format on
} }
} }
...@@ -44,7 +46,7 @@ float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& a, const ck_tile:: ...@@ -44,7 +46,7 @@ float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& a, const ck_tile::
{ {
// clang-format off // clang-format off
// ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK // ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
return gemm_<gemm_traits_< FP16, FP16, FP32, FP16, Row, Col, Row, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, false, false>>(a, s); return gemm_<gemm_traits_< FP16, FP16, FP32, FP16, Row, Col, Row, 128, 128, 32, 2, 2, 1, 32, 32, 8, false, false, false>>(a, s);
// clang-format on // clang-format on
} }
} }
...@@ -102,7 +104,7 @@ float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& a, const ck_tile:: ...@@ -102,7 +104,7 @@ float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& a, const ck_tile::
{ {
// clang-format off // clang-format off
// ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK // ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
return gemm_<gemm_traits_< BF16, BF16, FP32, BF16, Row, Row, Row, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, false, false>>(a, s); return gemm_<gemm_traits_< BF16, BF16, FP32, BF16, Row, Row, Row, 128, 128, 32, 2, 2, 1, 32, 32, 8, false, false, false>>(a, s);
// clang-format on // clang-format on
} }
} }
...@@ -119,7 +121,7 @@ float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& a, const ck_tile:: ...@@ -119,7 +121,7 @@ float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& a, const ck_tile::
{ {
// clang-format off // clang-format off
// ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK // ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
return gemm_<gemm_traits_< BF16, BF16, FP32, BF16, Row, Col, Row, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, false, false>>(a, s); return gemm_<gemm_traits_< BF16, BF16, FP32, BF16, Row, Col, Row, 128, 128, 32, 2, 2, 1, 32, 32, 8, false, false, false>>(a, s);
// clang-format on // clang-format on
} }
} }
...@@ -162,6 +164,81 @@ float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& a, const ck_tile:: ...@@ -162,6 +164,81 @@ float gemm(const gemm_traits& t, const ck_tile::GemmHostArgs& a, const ck_tile::
throw std::runtime_error("Wrong! ColumnMajor layout not supported for C Matrix!\n"); throw std::runtime_error("Wrong! ColumnMajor layout not supported for C Matrix!\n");
} }
} }
else if(t.data_type.compare("fp8") == 0)
{
if(t.is_a_rowmajor && t.is_b_rowmajor && t.is_c_rowmajor)
{
if(a.M > 512)
{
// clang-format off
// ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
return gemm_<gemm_traits_< FP8, FP8, FP32, FP8, Row, Row, Row, 256, 256, 64, 2, 2, 1, 32, 32, 16, false, false, false>>(a, s);
// clang-format on
}
else
{
// clang-format off
// ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
return gemm_<gemm_traits_< FP8, FP8, FP32, FP8, Row, Row, Row, 128, 128, 64, 2, 2, 1, 32, 32, 16, false, false, false>>(a, s);
// clang-format on
}
}
else if(t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor)
{
if(a.M > 512)
{
// clang-format off
// ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
return gemm_<gemm_traits_< FP8, FP8, FP32, FP8, Row, Col, Row, 256, 256, 64, 2, 2, 1, 32, 32, 16, false, false, false>>(a, s);
// clang-format on
}
else
{
// clang-format off
// ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
return gemm_<gemm_traits_< FP8, FP8, FP32, FP8, Row, Col, Row, 128, 128, 64, 2, 2, 1, 32, 32, 16, false, false, false>>(a, s);
// clang-format on
}
}
else if(!t.is_a_rowmajor && t.is_b_rowmajor && t.is_c_rowmajor)
{
if(a.M > 512)
{
// clang-format off
// ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
return gemm_<gemm_traits_< FP8, FP8, FP32, FP8, Col, Row, Row, 256, 256, 64, 2, 2, 1, 32, 32, 16, false, false, false>>(a, s);
// clang-format on
}
else
{
// clang-format off
// ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
return gemm_<gemm_traits_< FP8, FP8, FP32, FP8, Col, Row, Row, 128, 128, 64, 2, 2, 1, 32, 32, 16, false, false, false>>(a, s);
// clang-format on
}
}
else if(!t.is_a_rowmajor && !t.is_b_rowmajor && t.is_c_rowmajor)
{
if(a.M > 512)
{
// clang-format off
// ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
return gemm_<gemm_traits_< FP8, FP8, FP32, FP8, Col, Col, Row, 256, 256, 64, 2, 2, 1, 32, 32, 16, false, false, false>>(a, s);
// clang-format on
}
else
{
// clang-format off
// ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout, M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp, M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
return gemm_<gemm_traits_< FP8, FP8, FP32, FP8, Col, Col, Row, 128, 128, 64, 2, 2, 1, 32, 32, 16, false, false, false>>(a, s);
// clang-format on
}
}
else
{
throw std::runtime_error("Wrong! ColumnMajor layout not supported for C Matrix!\n");
}
}
else else
{ {
throw std::runtime_error("Wrong! DataTypes not supported!\n"); throw std::runtime_error("Wrong! DataTypes not supported!\n");
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_universal_comp_instance_common.hpp"
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
// Explicit instantiation of gemm_ for one fp8 configuration. Template arguments, in order:
//   ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout,
//   M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp,
//   M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
// Here: fp8 A/B/C with float accumulation; A = Col, B = Row, C = Row;
// M/N/K tile 256x256x64, M/N/K warp 2x2x1, M/N/K warp tile 32x32x16; PadM/PadN/PadK all false.
// This is the trait set the fp8 branch of gemm() requests for this layout when a.M > 512.
// NOTE(review): A and S are presumably aliases supplied by the included common header -- confirm.
// clang-format off
template float gemm_<gemm_traits_<ck_tile::fp8_t, ck_tile::fp8_t, float, ck_tile::fp8_t, Col, Row, Row, 256, 256, 64, 2, 2, 1, 32, 32, 16, false, false, false>>(const A&, const S&);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_universal_comp_instance_common.hpp"
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
// Explicit instantiation of gemm_ for one fp8 configuration. Template arguments, in order:
//   ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout,
//   M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp,
//   M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
// Here: fp8 A/B/C with float accumulation; A = Col, B = Col, C = Row;
// M/N/K tile 256x256x64, M/N/K warp 2x2x1, M/N/K warp tile 32x32x16; PadM/PadN/PadK all false.
// This is the trait set the fp8 branch of gemm() requests for this layout when a.M > 512.
// NOTE(review): A and S are presumably aliases supplied by the included common header -- confirm.
// clang-format off
template float gemm_<gemm_traits_<ck_tile::fp8_t, ck_tile::fp8_t, float, ck_tile::fp8_t, Col, Col, Row, 256, 256, 64, 2, 2, 1, 32, 32, 16, false, false, false>>(const A&, const S&);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_universal_comp_instance_common.hpp"
using Row = ck_tile::tensor_layout::gemm::RowMajor;
// Explicit instantiation of gemm_ for one fp8 configuration. Template arguments, in order:
//   ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout,
//   M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp,
//   M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
// Here: fp8 A/B/C with float accumulation; A = Row, B = Row, C = Row;
// M/N/K tile 256x256x64, M/N/K warp 2x2x1, M/N/K warp tile 32x32x16; PadM/PadN/PadK all false.
// This is the trait set the fp8 branch of gemm() requests for this layout when a.M > 512.
// NOTE(review): A and S are presumably aliases supplied by the included common header -- confirm.
// clang-format off
template float gemm_<gemm_traits_<ck_tile::fp8_t, ck_tile::fp8_t, float, ck_tile::fp8_t, Row, Row, Row, 256, 256, 64, 2, 2, 1, 32, 32, 16, false, false, false>>(const A&, const S&);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_universal_comp_instance_common.hpp"
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
// Explicit instantiation of gemm_ for one fp8 configuration. Template arguments, in order:
//   ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout,
//   M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp,
//   M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
// Here: fp8 A/B/C with float accumulation; A = Row, B = Col, C = Row;
// M/N/K tile 256x256x64, M/N/K warp 2x2x1, M/N/K warp tile 32x32x16; PadM/PadN/PadK all false.
// This is the trait set the fp8 branch of gemm() requests for this layout when a.M > 512.
// NOTE(review): A and S are presumably aliases supplied by the included common header -- confirm.
// clang-format off
template float gemm_<gemm_traits_<ck_tile::fp8_t, ck_tile::fp8_t, float, ck_tile::fp8_t, Row, Col, Row, 256, 256, 64, 2, 2, 1, 32, 32, 16, false, false, false>>(const A&, const S&);
// clang-format on
...@@ -6,5 +6,5 @@ ...@@ -6,5 +6,5 @@
using Row = ck_tile::tensor_layout::gemm::RowMajor; using Row = ck_tile::tensor_layout::gemm::RowMajor;
// clang-format off // clang-format off
template float gemm_<gemm_traits_<ck_tile::bf16_t, ck_tile::bf16_t, float, ck_tile::bf16_t, Row, Row, Row, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, false, false>>(const A&, const S&); template float gemm_<gemm_traits_<ck_tile::bf16_t, ck_tile::bf16_t, float, ck_tile::bf16_t, Row, Row, Row, 128, 128, 32, 2, 2, 1, 32, 32, 8, false, false, false>>(const A&, const S&);
// clang-format on // clang-format on
...@@ -7,5 +7,5 @@ using Row = ck_tile::tensor_layout::gemm::RowMajor; ...@@ -7,5 +7,5 @@ using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
// clang-format off // clang-format off
template float gemm_<gemm_traits_<ck_tile::bf16_t, ck_tile::bf16_t, float, ck_tile::bf16_t, Row, Col, Row, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, false, false>>(const A&, const S&); template float gemm_<gemm_traits_<ck_tile::bf16_t, ck_tile::bf16_t, float, ck_tile::bf16_t, Row, Col, Row, 128, 128, 32, 2, 2, 1, 32, 32, 8, false, false, false>>(const A&, const S&);
// clang-format on // clang-format on
...@@ -6,5 +6,5 @@ ...@@ -6,5 +6,5 @@
using Row = ck_tile::tensor_layout::gemm::RowMajor; using Row = ck_tile::tensor_layout::gemm::RowMajor;
// clang-format off // clang-format off
template float gemm_<gemm_traits_<ck_tile::half_t, ck_tile::half_t, float, ck_tile::half_t, Row, Row, Row, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, false, false>>(const A&, const S&); template float gemm_<gemm_traits_<ck_tile::half_t, ck_tile::half_t, float, ck_tile::half_t, Row, Row, Row, 128, 128, 32, 2, 2, 1, 32, 32, 8, false, false, false>>(const A&, const S&);
// clang-format on // clang-format on
...@@ -7,5 +7,5 @@ using Row = ck_tile::tensor_layout::gemm::RowMajor; ...@@ -7,5 +7,5 @@ using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
// clang-format off // clang-format off
template float gemm_<gemm_traits_<ck_tile::half_t, ck_tile::half_t, float, ck_tile::half_t, Row, Col, Row, 128, 32, 64, 4, 1, 1, 32, 32, 8, false, false, false>>(const A&, const S&); template float gemm_<gemm_traits_<ck_tile::half_t, ck_tile::half_t, float, ck_tile::half_t, Row, Col, Row, 128, 128, 32, 2, 2, 1, 32, 32, 8, false, false, false>>(const A&, const S&);
// clang-format on // clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_universal_mem_instance_common.hpp"
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
// Explicit instantiation of gemm_ for one fp8 configuration. Template arguments, in order:
//   ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout,
//   M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp,
//   M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
// Here: fp8 A/B/C with float accumulation; A = Col, B = Row, C = Row;
// M/N/K tile 128x128x64, M/N/K warp 2x2x1, M/N/K warp tile 32x32x16; PadM/PadN/PadK all false.
// This is the trait set the fp8 branch of gemm() requests for this layout when a.M <= 512.
// NOTE(review): A and S are presumably aliases supplied by the included common header -- confirm.
// clang-format off
template float gemm_<gemm_traits_<ck_tile::fp8_t, ck_tile::fp8_t, float, ck_tile::fp8_t, Col, Row, Row, 128, 128, 64, 2, 2, 1, 32, 32, 16, false, false, false>>(const A&, const S&);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_universal_mem_instance_common.hpp"
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
// Explicit instantiation of gemm_ for one fp8 configuration. Template arguments, in order:
//   ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout,
//   M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp,
//   M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
// Here: fp8 A/B/C with float accumulation; A = Col, B = Col, C = Row;
// M/N/K tile 128x128x64, M/N/K warp 2x2x1, M/N/K warp tile 32x32x16; PadM/PadN/PadK all false.
// This is the trait set the fp8 branch of gemm() requests for this layout when a.M <= 512.
// NOTE(review): A and S are presumably aliases supplied by the included common header -- confirm.
// clang-format off
template float gemm_<gemm_traits_<ck_tile::fp8_t, ck_tile::fp8_t, float, ck_tile::fp8_t, Col, Col, Row, 128, 128, 64, 2, 2, 1, 32, 32, 16, false, false, false>>(const A&, const S&);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_universal_mem_instance_common.hpp"
using Row = ck_tile::tensor_layout::gemm::RowMajor;
// Explicit instantiation of gemm_ for one fp8 configuration. Template arguments, in order:
//   ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout,
//   M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp,
//   M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
// Here: fp8 A/B/C with float accumulation; A = Row, B = Row, C = Row;
// M/N/K tile 128x128x64, M/N/K warp 2x2x1, M/N/K warp tile 32x32x16; PadM/PadN/PadK all false.
// This is the trait set the fp8 branch of gemm() requests for this layout when a.M <= 512.
// NOTE(review): A and S are presumably aliases supplied by the included common header -- confirm.
// clang-format off
template float gemm_<gemm_traits_<ck_tile::fp8_t, ck_tile::fp8_t, float, ck_tile::fp8_t, Row, Row, Row, 128, 128, 64, 2, 2, 1, 32, 32, 16, false, false, false>>(const A&, const S&);
// clang-format on
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_universal_mem_instance_common.hpp"
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
// Explicit instantiation of gemm_ for one fp8 configuration. Template arguments, in order:
//   ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout,
//   M_Tile, N_Tile, K_Tile, M_Warp, N_Warp, K_Warp,
//   M_Warp_Tile, N_Warp_Tile, K_Warp_Tile, PadM, PadN, PadK
// Here: fp8 A/B/C with float accumulation; A = Row, B = Col, C = Row;
// M/N/K tile 128x128x64, M/N/K warp 2x2x1, M/N/K warp tile 32x32x16; PadM/PadN/PadK all false.
// This is the trait set the fp8 branch of gemm() requests for this layout when a.M <= 512.
// NOTE(review): A and S are presumably aliases supplied by the included common header -- confirm.
// clang-format off
template float gemm_<gemm_traits_<ck_tile::fp8_t, ck_tile::fp8_t, float, ck_tile::fp8_t, Row, Col, Row, 128, 128, 64, 2, 2, 1, 32, 32, 16, false, false, false>>(const A&, const S&);
// clang-format on
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment