Commit 824809c1 authored by Adam Osewski's avatar Adam Osewski
Browse files

Fixes after merge.

parent 4085e3d0
......@@ -48,14 +48,10 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
using GemmEpilogue = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, false, kPadC>>;
using BaseGemmPipeline =
ck_tile::BaseGemmPipelineAgBgCrMem<ck_tile::BlockGemmPipelineProblem<ADataType,
BDataType,
CDataType,
GemmShape,
ALayout,
BLayout,
CLayout>>;
using Traits = ck_tile::TileGemmTraits<kPadA, kPadB, kPadC, ALayout, BLayout, CLayout>;
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(args.K);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
......@@ -71,14 +67,8 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
CDataType,
GemmShape,
ALayout,
BLayout,
CLayout,
kPadA,
kPadB,
kPadC,
Traits,
ck_tile::GemmPipelineScheduler::Intrawave,
has_hot_loop_v,
tail_number_v>>;
......
......@@ -164,7 +164,13 @@ int run_gemm_example(int argc, char* argv[])
c_m_n_gpu_ref.SetZero();
c_m_n_gpu_buf_ref.SetZero();
ck_tile::reference_gemm_gpu<ADataType, BDataType, AccDataType, CDataType>(
ck_tile::reference_gemm_gpu<ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout>(
a_m_k_dev_buf, b_k_n_dev_buf, c_m_n_gpu_buf_ref, M, N, K, stride_A, stride_B, stride_C);
c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
......
......@@ -18,7 +18,7 @@ struct BlockGemmASmemBSmemCRegV1
using Policy = remove_cvref_t<Policy_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
......@@ -31,7 +31,7 @@ struct BlockGemmASmemBSmemCRegV1
{
static_assert(std::is_same_v<ADataType, typename ABlockWindow::DataType> &&
std::is_same_v<BDataType, typename BBlockWindow::DataType> &&
std::is_same_v<AccDataType, typename CBlockTensor::DataType>,
std::is_same_v<CDataType, typename CBlockTensor::DataType>,
"wrong!");
constexpr index_t MPerBlock = ABlockWindow{}.get_window_lengths()[number<0>{}];
......@@ -195,7 +195,7 @@ struct BlockGemmASmemBSmemCRegV1
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
auto c_block_tensor = make_static_distributed_tensor<AccDataType>(c_block_dstr);
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
return c_block_tensor;
}
......
......@@ -17,7 +17,7 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
{
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
std::is_same_v<typename Problem::BDataType, half_t> &&
std::is_same_v<typename Problem::AccDataType, float>)
std::is_same_v<typename Problem::CDataType, float>)
{
#if 0
constexpr index_t kBlockSize = Problem::kBlockSize;
......@@ -45,7 +45,7 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
}
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
std::is_same_v<typename Problem::BDataType, bf16_t> &&
std::is_same_v<typename Problem::AccDataType, float>)
std::is_same_v<typename Problem::CDataType, float>)
{
return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, 4, 1);
}
......
......@@ -4,7 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
namespace ck_tile {
......@@ -87,7 +87,7 @@ struct BaseGemmPipelineAgBgCrMem
// LocalPreFillStages: 1
// LocalPreFetchStages: 0
// LocalSharedMemoryBuffer: 1
template <typename Problem, typename Policy = BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy>
template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCRegV1DefaultPolicy>
struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
{
using Base = BaseGemmPipelineAgBgCrMem<Problem>;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <bool kPadA_,
bool kPadB_,
bool kPadC_,
typename LayoutA_,
typename LayoutB_,
typename LayoutC_>
typename ALayout_,
typename BLayout_,
typename CLayout_>
struct TileGemmTraits
{
static constexpr bool kPadA = kPadA_;
static constexpr bool kPadB = kPadB_;
static constexpr bool kPadC = kPadC_;
using LayoutA = LayoutA_;
using LayoutB = LayoutB_;
using LayoutC = LayoutC_;
using ALayout = ALayout_;
using BLayout = BLayout_;
using CLayout = CLayout_;
};
} // namespace ck_tile
......@@ -69,14 +69,10 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
using GemmEpilogue = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, false, kPadC>>;
using BaseGemmPipeline =
ck_tile::BaseGemmPipelineAgBgCrMem<ck_tile::BlockGemmPipelineProblem<ADataType,
BDataType,
CDataType,
GemmShape,
ALayout,
BLayout,
CLayout>>;
using Traits = ck_tile::TileGemmTraits<kPadA, kPadB, kPadC, ALayout, BLayout, CLayout>;
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(args.K);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
......@@ -90,14 +86,8 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
CDataType,
GemmShape,
ALayout,
BLayout,
CLayout,
kPadA,
kPadB,
kPadC,
Traits,
ck_tile::GemmPipelineScheduler::Intrawave,
has_hot_loop_v,
tail_number_v>>;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment