Commit 824809c1 authored by Adam Osewski
Browse files

Fixes after merge.

parent 4085e3d0
...@@ -48,14 +48,10 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -48,14 +48,10 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
using GemmEpilogue = ck_tile::Default2DEpilogue< using GemmEpilogue = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, false, kPadC>>; ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, false, kPadC>>;
using BaseGemmPipeline = using Traits = ck_tile::TileGemmTraits<kPadA, kPadB, kPadC, ALayout, BLayout, CLayout>;
ck_tile::BaseGemmPipelineAgBgCrMem<ck_tile::BlockGemmPipelineProblem<ADataType,
BDataType, using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<
CDataType, ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>;
GemmShape,
ALayout,
BLayout,
CLayout>>;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(args.K); const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(args.K);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
...@@ -71,14 +67,8 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -71,14 +67,8 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
ck_tile::UniversalGemmPipelineProblem<ADataType, ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType, BDataType,
AccDataType, AccDataType,
CDataType,
GemmShape, GemmShape,
ALayout, Traits,
BLayout,
CLayout,
kPadA,
kPadB,
kPadC,
ck_tile::GemmPipelineScheduler::Intrawave, ck_tile::GemmPipelineScheduler::Intrawave,
has_hot_loop_v, has_hot_loop_v,
tail_number_v>>; tail_number_v>>;
......
...@@ -164,7 +164,13 @@ int run_gemm_example(int argc, char* argv[]) ...@@ -164,7 +164,13 @@ int run_gemm_example(int argc, char* argv[])
c_m_n_gpu_ref.SetZero(); c_m_n_gpu_ref.SetZero();
c_m_n_gpu_buf_ref.SetZero(); c_m_n_gpu_buf_ref.SetZero();
ck_tile::reference_gemm_gpu<ADataType, BDataType, AccDataType, CDataType>( ck_tile::reference_gemm_gpu<ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout>(
a_m_k_dev_buf, b_k_n_dev_buf, c_m_n_gpu_buf_ref, M, N, K, stride_A, stride_B, stride_C); a_m_k_dev_buf, b_k_n_dev_buf, c_m_n_gpu_buf_ref, M, N, K, stride_A, stride_B, stride_C);
c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data()); c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
......
...@@ -18,7 +18,7 @@ struct BlockGemmASmemBSmemCRegV1 ...@@ -18,7 +18,7 @@ struct BlockGemmASmemBSmemCRegV1
using Policy = remove_cvref_t<Policy_>; using Policy = remove_cvref_t<Policy_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>; using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>; using BDataType = remove_cvref_t<typename Problem::BDataType>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>; using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>; using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t kBlockSize = Problem::kBlockSize; static constexpr index_t kBlockSize = Problem::kBlockSize;
...@@ -31,7 +31,7 @@ struct BlockGemmASmemBSmemCRegV1 ...@@ -31,7 +31,7 @@ struct BlockGemmASmemBSmemCRegV1
{ {
static_assert(std::is_same_v<ADataType, typename ABlockWindow::DataType> && static_assert(std::is_same_v<ADataType, typename ABlockWindow::DataType> &&
std::is_same_v<BDataType, typename BBlockWindow::DataType> && std::is_same_v<BDataType, typename BBlockWindow::DataType> &&
std::is_same_v<AccDataType, typename CBlockTensor::DataType>, std::is_same_v<CDataType, typename CBlockTensor::DataType>,
"wrong!"); "wrong!");
constexpr index_t MPerBlock = ABlockWindow{}.get_window_lengths()[number<0>{}]; constexpr index_t MPerBlock = ABlockWindow{}.get_window_lengths()[number<0>{}];
...@@ -195,7 +195,7 @@ struct BlockGemmASmemBSmemCRegV1 ...@@ -195,7 +195,7 @@ struct BlockGemmASmemBSmemCRegV1
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
auto c_block_tensor = make_static_distributed_tensor<AccDataType>(c_block_dstr); auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
return c_block_tensor; return c_block_tensor;
} }
......
...@@ -17,7 +17,7 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy ...@@ -17,7 +17,7 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
{ {
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> && if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
std::is_same_v<typename Problem::BDataType, half_t> && std::is_same_v<typename Problem::BDataType, half_t> &&
std::is_same_v<typename Problem::AccDataType, float>) std::is_same_v<typename Problem::CDataType, float>)
{ {
#if 0 #if 0
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
...@@ -45,7 +45,7 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy ...@@ -45,7 +45,7 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
} }
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> && else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
std::is_same_v<typename Problem::BDataType, bf16_t> && std::is_same_v<typename Problem::BDataType, bf16_t> &&
std::is_same_v<typename Problem::AccDataType, float>) std::is_same_v<typename Problem::CDataType, float>)
{ {
return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, 4, 1); return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, 4, 1);
} }
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
#pragma once #pragma once
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
namespace ck_tile { namespace ck_tile {
...@@ -87,7 +87,7 @@ struct BaseGemmPipelineAgBgCrMem ...@@ -87,7 +87,7 @@ struct BaseGemmPipelineAgBgCrMem
// LocalPreFillStages: 1 // LocalPreFillStages: 1
// LocalPreFetchStages: 0 // LocalPreFetchStages: 0
// LocalSharedMemoryBuffer: 1 // LocalSharedMemoryBuffer: 1
template <typename Problem, typename Policy = BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy> template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCRegV1DefaultPolicy>
struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem> struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
{ {
using Base = BaseGemmPipelineAgBgCrMem<Problem>; using Base = BaseGemmPipelineAgBgCrMem<Problem>;
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "ck_tile/core.hpp"
namespace ck_tile { namespace ck_tile {
template <bool kPadA_, template <bool kPadA_,
bool kPadB_, bool kPadB_,
bool kPadC_, bool kPadC_,
typename LayoutA_, typename ALayout_,
typename LayoutB_, typename BLayout_,
typename LayoutC_> typename CLayout_>
struct TileGemmTraits struct TileGemmTraits
{ {
static constexpr bool kPadA = kPadA_; static constexpr bool kPadA = kPadA_;
static constexpr bool kPadB = kPadB_; static constexpr bool kPadB = kPadB_;
static constexpr bool kPadC = kPadC_; static constexpr bool kPadC = kPadC_;
using LayoutA = LayoutA_; using ALayout = ALayout_;
using LayoutB = LayoutB_; using BLayout = BLayout_;
using LayoutC = LayoutC_; using CLayout = CLayout_;
}; };
} // namespace ck_tile } // namespace ck_tile
...@@ -69,14 +69,10 @@ class TestCkTileGemmMemPipeline : public ::testing::Test ...@@ -69,14 +69,10 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
using GemmEpilogue = ck_tile::Default2DEpilogue< using GemmEpilogue = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, false, kPadC>>; ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, false, kPadC>>;
using BaseGemmPipeline = using Traits = ck_tile::TileGemmTraits<kPadA, kPadB, kPadC, ALayout, BLayout, CLayout>;
ck_tile::BaseGemmPipelineAgBgCrMem<ck_tile::BlockGemmPipelineProblem<ADataType,
BDataType, using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<
CDataType, ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>;
GemmShape,
ALayout,
BLayout,
CLayout>>;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(args.K); const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(args.K);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
...@@ -90,14 +86,8 @@ class TestCkTileGemmMemPipeline : public ::testing::Test ...@@ -90,14 +86,8 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
ck_tile::UniversalGemmPipelineProblem<ADataType, ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType, BDataType,
AccDataType, AccDataType,
CDataType,
GemmShape, GemmShape,
ALayout, Traits,
BLayout,
CLayout,
kPadA,
kPadB,
kPadC,
ck_tile::GemmPipelineScheduler::Intrawave, ck_tile::GemmPipelineScheduler::Intrawave,
has_hot_loop_v, has_hot_loop_v,
tail_number_v>>; tail_number_v>>;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment