Commit d20c20a6 authored by Mirza Halilcevic

Merge remote-tracking branch 'upstream/develop' into gemm_elementwise_gemm

parents 250a89f3 10158b0f
@@ -4,15 +4,15 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
namespace ck_tile {
// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
template <typename Problem, typename Policy = BlockGemmPipelineAGmemBGmemCRegV2DefaultPolicy>
struct BlockGemmPipelineAGmemBGmemCRegV2
template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCRegV2DefaultPolicy>
struct GemmPipelineAGmemBGmemCRegV2
{
    using ADataType = remove_cvref_t<typename Problem::ADataType>;
    using BDataType = remove_cvref_t<typename Problem::BDataType>;
......
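To make the renamed pipeline's comment block concrete: A and B tiles are fetched from global memory on every K-step, while the C tile stays resident in registers until a single final store. A minimal scalar analogue (CPU-side sketch, not CK's implementation; tile sizes, leading dimensions, and row-major layout are stand-ins, and K is assumed divisible by KTile):

```cpp
#include <array>
#include <cstddef>

// Hypothetical scalar analogue of the AGmem/BGmem/CReg scheme.
template <std::size_t MTile, std::size_t NTile, std::size_t KTile>
void gemm_tile_creg(const float* a, const float* b, float* c,
                    std::size_t K, std::size_t lda, std::size_t ldb, std::size_t ldc)
{
    // "C Distributed tensor: register" -- the accumulator never touches
    // memory inside the K loop.
    std::array<float, MTile * NTile> c_reg{};

    for(std::size_t k0 = 0; k0 < K; k0 += KTile)      // one iteration per K tile
        for(std::size_t m = 0; m < MTile; ++m)
            for(std::size_t n = 0; n < NTile; ++n)
                for(std::size_t kk = 0; kk < KTile; ++kk)
                    // the a/b reads play the role of the A/B tile windows
                    c_reg[m * NTile + n] += a[m * lda + k0 + kk] * b[(k0 + kk) * ldb + n];

    // single write-out at the end of the hot loop
    for(std::size_t m = 0; m < MTile; ++m)
        for(std::size_t n = 0; n < NTile; ++n)
            c[m * ldc + n] = c_reg[m * NTile + n];
}
```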
@@ -7,12 +7,11 @@
namespace ck_tile {
// Default policy for BlockGemmPipelineAGmemBGmemCRegV2
// Default policy for GemmPipelineAGmemBGmemCRegV2
// Default policy class should not be templated; put the templates on member functions instead
// NOTE: a policy should be bound to its corresponding operation. It's just a coincidence that
// BlockGemmPipelineAGmemBGmemCRegV2DefaultPolicy is the same as
// BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
using BlockGemmPipelineAGmemBGmemCRegV2DefaultPolicy =
BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy;
// GemmPipelineAGmemBGmemCRegV2DefaultPolicy is the same as
// GemmPipelineAGmemBGmemCRegV1DefaultPolicy
using GemmPipelineAGmemBGmemCRegV2DefaultPolicy = GemmPipelineAGmemBGmemCRegV1DefaultPolicy;
} // namespace ck_tile
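The "templates on member functions" rule is what makes the alias above trivial: because the policy type itself carries no template parameters, the V2 pipeline can reuse the V1 policy verbatim. A minimal sketch of the pattern (names are hypothetical, not CK's):

```cpp
// Illustrative-only policy following the rule above: the struct has no
// template parameters, so aliasing one policy to another is a one-liner;
// all problem-specific decisions live on templated member functions.
struct ExamplePolicyV1
{
    template <typename Problem>
    static constexpr int GetVectorSize()
    {
        // depends only on what the Problem type exposes (kBlockSize is an
        // assumed member here), never on a policy template parameter
        return Problem::kBlockSize >= 256 ? 8 : 4;
    }
};

using ExamplePolicyV2 = ExamplePolicyV1; // same trivial aliasing as above
```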
@@ -13,20 +13,23 @@ template <typename ADataType_,
          typename BDataType_,
          typename CDataType_,
          typename BlockGemmShape_,
          bool kPadA_ = false,
          bool kPadB_ = false,
          bool kPadC_ = false>
struct BlockGemmPipelineProblem
          typename TileGemmTraits_>
struct GemmPipelineProblem
{
    using ADataType = remove_cvref_t<ADataType_>;
    using BDataType = remove_cvref_t<BDataType_>;
    using CDataType = remove_cvref_t<CDataType_>;
    using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
    using GemmTraits = remove_cvref_t<TileGemmTraits_>;

    static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();

    static constexpr bool kPadA = kPadA_;
    static constexpr bool kPadB = kPadB_;
    static constexpr bool kPadC = kPadC_;
    static constexpr bool kPadA = GemmTraits::kPadA;
    static constexpr bool kPadB = GemmTraits::kPadB;
    static constexpr bool kPadC = GemmTraits::kPadC;

    using LayoutA = remove_cvref_t<typename GemmTraits::LayoutA>;
    using LayoutB = remove_cvref_t<typename GemmTraits::LayoutB>;
    using LayoutC = remove_cvref_t<typename GemmTraits::LayoutC>;

    static constexpr index_t AlignmentA = kPadA ? 1 : VectorLoadSize / sizeof(ADataType);
    static constexpr index_t AlignmentB = kPadB ? 1 : VectorLoadSize / sizeof(BDataType);
......
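A worked instance of the Alignment formula at the end of the hunk, assuming a 16-byte VectorLoadSize (the real constant may differ) and a 2-byte element type: padding forces scalar (alignment-1) access, while the unpadded path allows full vector loads.

```cpp
#include <cstdint>

// fp16 stand-in for the storage type (2 bytes per element)
using half_t = std::uint16_t;

constexpr int VectorLoadSize = 16; // assumption: 16-byte global loads

constexpr int alignment(bool pad) { return pad ? 1 : VectorLoadSize / sizeof(half_t); }

static_assert(alignment(false) == 8, "unpadded: 8-element vector loads");
static_assert(alignment(true) == 1, "padded: scalar access");
```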
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <bool kPadA_,
          bool kPadB_,
          bool kPadC_,
          typename LayoutA_,
          typename LayoutB_,
          typename LayoutC_>
struct TileGemmTraits
{
    static constexpr bool kPadA = kPadA_;
    static constexpr bool kPadB = kPadB_;
    static constexpr bool kPadC = kPadC_;

    using LayoutA = LayoutA_;
    using LayoutB = LayoutB_;
    using LayoutC = LayoutC_;
};
} // namespace ck_tile
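How the new traits bundle is meant to be consumed (the layout tag types below are stand-ins; CK's actual tags live elsewhere): the three pad flags and three layouts now travel as one type instead of loose template arguments on GemmPipelineProblem.

```cpp
// Hypothetical layout tags, for illustration only.
struct RowMajor {};
struct ColumnMajor {};

using Traits = ck_tile::TileGemmTraits<true,        // kPadA
                                       false,       // kPadB
                                       false,       // kPadC
                                       RowMajor,    // LayoutA
                                       ColumnMajor, // LayoutB
                                       RowMajor>;   // LayoutC

static_assert(Traits::kPadA && !Traits::kPadB && !Traits::kPadC);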
@@ -14,17 +14,21 @@ template <typename XDataType_,
          typename YDataType_,
          typename MeanDataType_,
          typename InvStdDataType_,
          typename BlockShape_>
          typename BlockShape_,
          bool kPadM_,
          bool kPadN_>
struct BlockLayernorm2dFwdProblem
{
    using XDataType = remove_cvref_t<XDataType_>;
    using GammaDataType = remove_cvref_t<GammaDataType_>;
    using BetaDataType = remove_cvref_t<BetaDataType_>;
    using ComputeDataType = remove_cvref_t<ComputeDataType_>;
    using YDataType = remove_cvref_t<YDataType_>;
    using MeanDataType = remove_cvref_t<MeanDataType_>;
    using InvStdDataType = remove_cvref_t<InvStdDataType_>;
    using BlockShape = remove_cvref_t<BlockShape_>;
    using XDataType = remove_cvref_t<XDataType_>;
    using GammaDataType = remove_cvref_t<GammaDataType_>;
    using BetaDataType = remove_cvref_t<BetaDataType_>;
    using ComputeDataType = remove_cvref_t<ComputeDataType_>;
    using YDataType = remove_cvref_t<YDataType_>;
    using MeanDataType = remove_cvref_t<MeanDataType_>;
    using InvStdDataType = remove_cvref_t<InvStdDataType_>;
    using BlockShape = remove_cvref_t<BlockShape_>;

    static constexpr bool kPadM = kPadM_;
    static constexpr bool kPadN = kPadN_;
};
} // namespace ck_tile
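For the layernorm problem, the two new flags are appended after BlockShape. A sketch of an instantiation (the shape type is hypothetical, and the parameter order for the types elided from the hunk is inferred from the usings above):

```cpp
// Hypothetical block shape; only its presence matters for this sketch.
struct Shape {};

using Problem = ck_tile::BlockLayernorm2dFwdProblem<float, // XDataType
                                                    float, // GammaDataType
                                                    float, // BetaDataType
                                                    float, // ComputeDataType
                                                    float, // YDataType
                                                    float, // MeanDataType
                                                    float, // InvStdDataType
                                                    Shape,
                                                    /*kPadM=*/false,
                                                    /*kPadN=*/true>;

static_assert(!Problem::kPadM && Problem::kPadN);
```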
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
template <typename ALayout,
          typename BLayout,
          typename CLayout,
          typename ADataType,
          typename BDataType,
          typename CDataType,
          typename AccDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          typename ComputeTypeA,
          typename ComputeTypeB>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        naive_gemm_kernel(const ADataType* __restrict__ p_a_grid,
                          const BDataType* __restrict__ p_b_grid,
                          CDataType* __restrict__ p_c_grid,
                          index_t m,
                          index_t n,
                          index_t k,
                          const AElementwiseOperation a_element_op,
                          const BElementwiseOperation b_element_op,
                          const CDEElementwiseOperation c_element_op)
{
    using RowMajor = ck::tensor_layout::gemm::RowMajor;

    const int row_idx = blockIdx.x * blockDim.x + threadIdx.x;
    const int col_idx = blockIdx.y * blockDim.y + threadIdx.y;

    if(row_idx < m && col_idx < n)
    {
        AccDataType v_acc = static_cast<AccDataType>(0.0);
        ComputeTypeA v_a  = static_cast<ComputeTypeA>(0.0);
        ComputeTypeB v_b  = static_cast<ComputeTypeB>(0.0);
        CDataType v_c     = static_cast<CDataType>(0.0);

        for(int k_idx = 0; k_idx < k; ++k_idx)
        {
            // check input matrices layout
            int element_idx_a = 0;
            int element_idx_b = 0;

            if constexpr(std::is_same_v<ALayout, RowMajor>)
            {
                element_idx_a = row_idx * k + k_idx;
            }
            else
            {
                element_idx_a = row_idx + m * k_idx;
            }

            if constexpr(std::is_same_v<BLayout, RowMajor>)
            {
                element_idx_b = k_idx * n + col_idx;
            }
            else
            {
                element_idx_b = k_idx + k * col_idx;
            }

            // apply a_element_op
            a_element_op(v_a, p_a_grid[element_idx_a]);
            // apply b_element_op
            b_element_op(v_b, p_b_grid[element_idx_b]);

            // multiply and accumulate
            v_acc += static_cast<AccDataType>(v_a) * static_cast<AccDataType>(v_b);
        }

        // apply c_element_op
        c_element_op(v_c, v_acc);

        // check output matrix layout
        int element_idx_c = 0;
        if constexpr(std::is_same_v<CLayout, RowMajor>)
        {
            element_idx_c = row_idx * n + col_idx;
        }
        else
        {
            element_idx_c = row_idx + m * col_idx;
        }

        // prepare output
        p_c_grid[element_idx_c] = v_c;
    }
}
} // namespace ck
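The layout branches above reduce to two indexing formulas: row-major element (r, c) of an R x C matrix lives at r * C + c, column-major at r + R * c. A host-side check of the column-major branch, independent of any GPU runtime:

```cpp
#include <cassert>

int main()
{
    constexpr int m = 2, k = 3;
    // column-major 2x3 A: columns are contiguous, logical A = [[1 2 3], [4 5 6]]
    const float a[m * k] = {1, 4, 2, 5, 3, 6};
    // row_idx + m * k_idx, as in the kernel's column-major branch
    assert(a[1 + m * 2] == 6.0f); // row 1, k_idx 2 -> value 6
    // the row-major branch would instead use row_idx * k + k_idx
    return 0;
}
```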
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ALayout,
          typename BLayout,
          typename CLayout,
          typename ADataType,
          typename BDataType,
          typename CDataType,
          typename AccDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          typename ComputeTypeA = CDataType,
          typename ComputeTypeB = ComputeTypeA>
struct ReferenceGemm : public device::BaseOperator
{
    // Argument
    struct Argument : public device::BaseArgument
    {
        Argument(const void* p_a_grid,
                 const void* p_b_grid,
                 void* p_c_grid,
                 index_t m,
                 index_t n,
                 index_t k,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 CElementwiseOperation c_element_op)
            : p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
              p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
              p_c_grid_{static_cast<CDataType*>(p_c_grid)},
              m_{m},
              n_{n},
              k_{k},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              c_element_op_{c_element_op}
        {
        }

        const ADataType* p_a_grid_;
        const BDataType* p_b_grid_;
        CDataType* p_c_grid_;
        index_t m_;
        index_t n_;
        index_t k_;
        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        CElementwiseOperation c_element_op_;
    };

    // Invoker
    struct Invoker : public device::BaseInvoker
    {
        using Argument = ReferenceGemm::Argument;

        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            int block_size = 16;
            dim3 block_dim(block_size, block_size, 1);
            dim3 grid_dim(
                (arg.m_ + block_size - 1) / block_size, (arg.n_ + block_size - 1) / block_size, 1);

            auto launch_kernel = [&]() {
                const auto kernel = naive_gemm_kernel<ALayout,
                                                      BLayout,
                                                      CLayout,
                                                      ADataType,
                                                      BDataType,
                                                      CDataType,
                                                      AccDataType,
                                                      AElementwiseOperation,
                                                      BElementwiseOperation,
                                                      CElementwiseOperation,
                                                      ComputeTypeA,
                                                      ComputeTypeB>;

                return launch_and_time_kernel(stream_config,
                                              kernel,
                                              grid_dim,
                                              block_dim,
                                              0,
                                              arg.p_a_grid_,
                                              arg.p_b_grid_,
                                              arg.p_c_grid_,
                                              arg.m_,
                                              arg.n_,
                                              arg.k_,
                                              arg.a_element_op_,
                                              arg.b_element_op_,
                                              arg.c_element_op_);
            };

            return launch_kernel();
        }

        float Run(const device::BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    bool IsSupportedArgument(const device::BaseArgument*) override { return true; }

    static auto MakeArgument(const void* p_a_grid,
                             const void* p_b_grid,
                             void* p_c_grid,
                             index_t m,
                             index_t n,
                             index_t k,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             CElementwiseOperation c_element_op)
    {
        return Argument{
            p_a_grid, p_b_grid, p_c_grid, m, n, k, a_element_op, b_element_op, c_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "Device Reference Gemm"
            << std::endl;
        // clang-format on

        return str.str();
    }
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
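Typical driver code for the reference, sketched from the MakeArgument/MakeInvoker/Run API shown above (the device buffers and M, N, K are assumed to exist; PassThrough is the usual no-op element operation from the element-wise header included at the top of the file):

```cpp
using Row         = ck::tensor_layout::gemm::RowMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

using Ref = ck::tensor_operation::device::ReferenceGemm<Row, Row, Row,
                                                        float, float, float, float,
                                                        PassThrough, PassThrough, PassThrough>;

// p_a_dev / p_b_dev / p_c_dev: hypothetical, already-allocated device buffers
auto arg     = Ref::MakeArgument(p_a_dev, p_b_dev, p_c_dev, M, N, K,
                                 PassThrough{}, PassThrough{}, PassThrough{});
auto invoker = Ref::MakeInvoker();
invoker.Run(arg, StreamConfig{}); // launches one thread per C element in 16x16 blocks
```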
@@ -37,11 +37,7 @@ function(add_instance_library INSTANCE_NAME)
endforeach()
endif()
if(INSTANCES_ONLY)
set(INST_TARGETS ${DEFAULT_GPU_TARGETS})
else()
set(INST_TARGETS ${GPU_TARGETS})
endif()
set(INST_TARGETS ${SUPPORTED_GPU_TARGETS})
# Do not build DL instances if DL_KERNELS macro is not set
foreach(source IN LISTS ARGN)
@@ -64,9 +60,9 @@ function(add_instance_library INSTANCE_NAME)
list(REMOVE_ITEM ARGN "${source}")
endif()
endforeach()
# Do not build mha instances if gfx94 targets are not on the target list
# Do not build mha instances if gfx94 or gfx90a targets are not on the target list
foreach(source IN LISTS ARGN)
if(NOT INST_TARGETS MATCHES "gfx94" AND source MATCHES "mha")
if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx90a" AND source MATCHES "mha")
message("removing mha instance ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif()
@@ -75,17 +71,13 @@ function(add_instance_library INSTANCE_NAME)
if(ARGN)
set(INST_OBJ)
foreach(source IN LISTS ARGN)
if(INSTANCES_ONLY)
set(INST_TARGETS ${DEFAULT_GPU_TARGETS})
else()
set(INST_TARGETS ${GPU_TARGETS})
endif()
set(INST_TARGETS ${SUPPORTED_GPU_TARGETS})
if(source MATCHES "_xdl")
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201)
elseif(ARGN MATCHES "_wmma")
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
elseif(ARGN MATCHES "mha")
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201)
list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201)
endif()
set(offload_targets)
foreach(target IN LISTS INST_TARGETS)
@@ -191,12 +183,7 @@ FOREACH(subdir_path ${dir_list})
set(add_inst 1)
endif()
if(INSTANCES_ONLY)
set(INST_TARGETS ${DEFAULT_GPU_TARGETS})
else()
set(INST_TARGETS ${GPU_TARGETS})
endif()
set(INST_TARGETS ${SUPPORTED_GPU_TARGETS})
if(("${cmake_instance}" MATCHES "quantization") AND (DEFINED DTYPES) AND (NOT DTYPES MATCHES "int8"))
message("quantization instances will not be built!")
@@ -320,8 +307,7 @@ if(CK_DEVICE_CONV_INSTANCES)
endif()
if(CK_DEVICE_MHA_INSTANCES)
set(gpu_list ${INST_TARGETS})
list(FILTER gpu_list INCLUDE REGEX "^gfx94")
if(gpu_list)
if(gpu_list MATCHES "gfx94" OR gpu_list MATCHES "gfx90a")
add_library(device_mha_operations STATIC ${CK_DEVICE_MHA_INSTANCES})
add_library(composablekernels::device_mha_operations ALIAS device_mha_operations)
target_compile_features(device_mha_operations PUBLIC)
......
@@ -7,7 +7,8 @@ MY_PROJECT_SOURCE=$1
if [ $# -ge 2 ] ; then
GPU_TARGETS=$2
REST_ARGS=${@:3}
shift 2
REST_ARGS=$@
else
GPU_TARGETS="gfx908;gfx90a;gfx940"
REST_ARGS=
......
@@ -7,7 +7,8 @@ MY_PROJECT_SOURCE=$1
if [ $# -ge 2 ] ; then
GPU_TARGETS=$2
REST_ARGS=${@:3}
shift 2
REST_ARGS=$@
else
GPU_TARGETS="gfx908;gfx90a;gfx940"
REST_ARGS=
......