Commit ff5115be authored by Aleksander Dudek's avatar Aleksander Dudek
Browse files

[CK_TILE] Add GetName for grouped gemm

parent 0c4cf86e
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h>
......@@ -91,7 +91,7 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args:"
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:"
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h>
......@@ -128,7 +128,7 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args:"
std::cout << "Launching kernel: " << GroupedGemmKernel::GetName() << " with args:"
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -57,6 +57,30 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
using BLayout = typename Base::BLayout;
using CLayout = typename Base::CLayout;
CK_TILE_HOST static std::string GetName()
{
#define _SS_ std::string
#define _TS_ std::to_string
// clang-format off
using P_ = GemmPipeline;
auto prec_str = [&] () {
std::string base_str = _SS_(Base::template t2s<ADataType>::name);
if (!std::is_same_v<ADataType, BDataType>) {
base_str += _SS_("_") + _SS_(Base::template t2s<BDataType>::name);
}
return base_str;
}();
return _SS_("gemm_batched_") + _SS_(prec_str) + "_" +
_TS_(P_::kMPerBlock) + "x" + _TS_(P_::kNPerBlock) + "x" + _TS_(P_::kKPerBlock) + "_" +
_TS_(P_::VectorSizeA) + "x" + _TS_(P_::VectorSizeB) + "x" + _TS_(P_::VectorSizeC) + "_" +
_TS_(P_::kPadM) + "x" + _TS_(P_::kPadN) + "x" + _TS_(P_::kPadK);
#undef _SS_
#undef _TS_
// clang-format on
}
struct BatchedGemmKernelArgs : GemmKernelArgs
{
index_t batch_stride_A;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -44,6 +44,40 @@ struct GroupedGemmKernel
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
// clang-format off
template <typename T> struct t2s;
template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
template <> struct t2s<fp16_t> { static constexpr const char * name = "fp16"; };
template <> struct t2s<bf16_t> { static constexpr const char * name = "bf16"; };
template <> struct t2s<fp8_t> { static constexpr const char * name = "fp8"; };
template <> struct t2s<bf8_t> { static constexpr const char * name = "bf8"; };
template <> struct t2s<int8_t> { static constexpr const char * name = "int8"; };
// clang-format on
CK_TILE_HOST static std::string GetName()
{
#define _SS_ std::string
#define _TS_ std::to_string
// clang-format off
using P_ = GemmPipeline;
auto prec_str = [&] () {
std::string base_str = _SS_(t2s<ADataType>::name);
if (!std::is_same_v<ADataType, BDataType>) {
base_str += _SS_("_") + _SS_(t2s<BDataType>::name);
}
return base_str;
}();
return _SS_("gemm_grouped_") + _SS_(prec_str) + "_" +
_TS_(P_::kMPerBlock) + "x" + _TS_(P_::kNPerBlock) + "x" + _TS_(P_::kKPerBlock) + "_" +
_TS_(P_::VectorSizeA) + "x" + _TS_(P_::VectorSizeB) + "x" + _TS_(P_::VectorSizeC) + "_" +
_TS_(P_::kPadM) + "x" + _TS_(P_::kPadN) + "x" + _TS_(P_::kPadK);
#undef _SS_
#undef _TS_
// clang-format on
}
struct GemmTransKernelArg
{
GroupedGemmHostArgs group_karg;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment