Unverified Commit aa61ccf0 authored by arai713's avatar arai713 Committed by GitHub
Browse files

Merge branch 'develop' into hip_tensor_permute

parents 4498e2a1 bec84efb
...@@ -127,7 +127,50 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout, ...@@ -127,7 +127,50 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
PipelineVer, PipelineVer,
ComputeType>; ComputeType>;
using Argument = typename GridwiseGemm::Argument; struct Argument : public GridwiseGemm::Argument
{
Argument(const ADataType* p_a_grid_,
const BDataType* p_b_grid_,
CDataType* p_c_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_,
index_t MPadded_,
index_t NPadded_,
index_t KPadded_,
index_t K0_,
index_t k_batch_,
AElementwiseOperation a_element_op_,
BElementwiseOperation b_element_op_,
CElementwiseOperation c_element_op_)
: GridwiseGemm::Argument(p_a_grid_,
p_b_grid_,
p_c_grid_,
M_,
N_,
K_,
StrideA_,
StrideB_,
StrideC_,
MPadded_,
NPadded_,
KPadded_,
K0_,
k_batch_),
a_element_op(a_element_op_),
b_element_op(b_element_op_),
c_element_op(c_element_op_)
{
}
AElementwiseOperation a_element_op;
BElementwiseOperation b_element_op;
CElementwiseOperation c_element_op;
};
using DefaultBlock2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap; using DefaultBlock2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap;
// Invoker // Invoker
...@@ -168,8 +211,17 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout, ...@@ -168,8 +211,17 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
karg.M * karg.N * sizeof(CDataType), karg.M * karg.N * sizeof(CDataType),
stream_config.stream_id_)); stream_config.stream_id_));
ave_time = launch_and_time_kernel( ave_time =
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg, b2c_map); launch_and_time_kernel(stream_config,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
static_cast<typename GridwiseGemm::Argument>(karg),
b2c_map,
karg.a_element_op,
karg.b_element_op,
karg.c_element_op);
}; };
if(has_main_k0_block_loop) if(has_main_k0_block_loop)
...@@ -180,7 +232,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout, ...@@ -180,7 +232,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm, kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
true, true,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
DefaultBlock2CTileMap>; DefaultBlock2CTileMap,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>;
Run(kernel); Run(kernel);
} }
...@@ -190,7 +245,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout, ...@@ -190,7 +245,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm, kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
true, true,
InMemoryDataOperationEnum::AtomicAdd, InMemoryDataOperationEnum::AtomicAdd,
DefaultBlock2CTileMap>; DefaultBlock2CTileMap,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>;
Run(kernel); Run(kernel);
} }
...@@ -203,7 +261,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout, ...@@ -203,7 +261,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm, kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
false, false,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
DefaultBlock2CTileMap>; DefaultBlock2CTileMap,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>;
Run(kernel); Run(kernel);
} }
...@@ -213,7 +274,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout, ...@@ -213,7 +274,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm, kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
false, false,
InMemoryDataOperationEnum::AtomicAdd, InMemoryDataOperationEnum::AtomicAdd,
DefaultBlock2CTileMap>; DefaultBlock2CTileMap,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>;
Run(kernel); Run(kernel);
} }
...@@ -261,12 +325,12 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout, ...@@ -261,12 +325,12 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
index_t StrideA, index_t StrideA,
index_t StrideB, index_t StrideB,
index_t StrideC, index_t StrideC,
AElementwiseOperation, AElementwiseOperation a_element_op,
BElementwiseOperation, BElementwiseOperation b_element_op,
CElementwiseOperation, CElementwiseOperation c_element_op,
index_t KBatch) index_t KBatch)
{ {
return Argument{p_a, return Argument(p_a,
p_b, p_b,
p_c, p_c,
M, M,
...@@ -279,7 +343,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout, ...@@ -279,7 +343,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
GridwiseGemm::CalculateNPadded(N), GridwiseGemm::CalculateNPadded(N),
GridwiseGemm::CalculateKPadded(K, KBatch), GridwiseGemm::CalculateKPadded(K, KBatch),
GridwiseGemm::CalculateK0(K, KBatch), GridwiseGemm::CalculateK0(K, KBatch),
KBatch}; KBatch,
a_element_op,
b_element_op,
c_element_op);
} }
static auto MakeInvoker() { return Invoker{}; } static auto MakeInvoker() { return Invoker{}; }
...@@ -294,9 +361,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout, ...@@ -294,9 +361,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
index_t StrideA, index_t StrideA,
index_t StrideB, index_t StrideB,
index_t StrideC, index_t StrideC,
AElementwiseOperation, AElementwiseOperation a_element_op,
BElementwiseOperation, BElementwiseOperation b_element_op,
CElementwiseOperation, CElementwiseOperation c_element_op,
ck::index_t KBatch = 1) override ck::index_t KBatch = 1) override
{ {
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a), return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
...@@ -312,7 +379,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout, ...@@ -312,7 +379,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
GridwiseGemm::CalculateNPadded(N), GridwiseGemm::CalculateNPadded(N),
GridwiseGemm::CalculateKPadded(K, KBatch), GridwiseGemm::CalculateKPadded(K, KBatch),
GridwiseGemm::CalculateK0(K, KBatch), GridwiseGemm::CalculateK0(K, KBatch),
KBatch); KBatch,
a_element_op,
b_element_op,
c_element_op);
} }
// polymorphic // polymorphic
......
...@@ -565,7 +565,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle ...@@ -565,7 +565,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
auto launch_kernel = [&](auto has_main_k_block_loop) { auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value; constexpr bool has_main_loop = has_main_k_block_loop.value;
const auto kernel = kernel_grouped_conv_fwd_multiple_d_wmma_cshuffle< const auto kernel = kernel_grouped_conv_multiple_d_wmma_cshuffle<
GridwiseGemm, GridwiseGemm,
ADataType, ADataType,
BDataType, BDataType,
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" #include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
...@@ -22,32 +23,6 @@ namespace ck { ...@@ -22,32 +23,6 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace {
struct ComputePtrOffsetOfStridedBatch
{
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideA_);
}
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideB_);
}
__host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideC_);
}
index_t BatchStrideA_;
index_t BatchStrideB_;
index_t BatchStrideC_;
};
} // namespace
template <typename GridwiseGemm, template <typename GridwiseGemm,
typename FloatAB, typename FloatAB,
typename FloatC, typename FloatC,
...@@ -952,7 +927,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa ...@@ -952,7 +927,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
Block2CTileMap block_2_ctile_map_; Block2CTileMap block_2_ctile_map_;
// for computing batch offset // for computing batch offset
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; ComputePtrOffsetOfStridedBatch<I0> compute_ptr_offset_of_batch_;
// element-wise op // element-wise op
OutElementwiseOperation a_element_op_; OutElementwiseOperation a_element_op_;
...@@ -1024,7 +999,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa ...@@ -1024,7 +999,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
remove_reference_t<DeviceOp::BGridDesc_B_K0_N0_N1_K1>, remove_reference_t<DeviceOp::BGridDesc_B_K0_N0_N1_K1>,
remove_reference_t<DeviceOp::CGridDesc_M0_M10_M11_N0_N10_N11>, remove_reference_t<DeviceOp::CGridDesc_M0_M10_M11_N0_N10_N11>,
remove_reference_t<DeviceOp::Block2CTileMap>, remove_reference_t<DeviceOp::Block2CTileMap>,
ComputePtrOffsetOfStridedBatch, ComputePtrOffsetOfStridedBatch<I0>,
has_main_loop, has_main_loop,
has_double_loop>; has_double_loop>;
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp" #include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp" #include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
...@@ -21,32 +22,6 @@ namespace ck { ...@@ -21,32 +22,6 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace {
struct ComputePtrOffsetOfStridedBatch
{
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideA_);
}
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideB_);
}
__host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideC_);
}
index_t BatchStrideA_;
index_t BatchStrideB_;
index_t BatchStrideC_;
};
} // namespace
template <typename GridwiseGemm, template <typename GridwiseGemm,
typename FloatA, typename FloatA,
typename FloatB, typename FloatB,
...@@ -1222,7 +1197,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -1222,7 +1197,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
Block2CTileMap block_2_ctile_map_; Block2CTileMap block_2_ctile_map_;
// for computing batch offset // for computing batch offset
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; ComputePtrOffsetOfStridedBatch<I0> compute_ptr_offset_of_batch_;
index_t M01_; index_t M01_;
index_t N01_; index_t N01_;
...@@ -1301,7 +1276,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -1301,7 +1276,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>, remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>, remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
remove_reference_t<DeviceOp::Block2CTileMap>, remove_reference_t<DeviceOp::Block2CTileMap>,
ComputePtrOffsetOfStridedBatch, ComputePtrOffsetOfStridedBatch<I0>,
has_main_loop>; has_main_loop>;
return launch_and_time_kernel(stream_config, return launch_and_time_kernel(stream_config,
...@@ -1348,6 +1323,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle ...@@ -1348,6 +1323,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(!ck::is_xdl_supported())
{
return false;
}
if constexpr(NDimSpatial == 1) if constexpr(NDimSpatial == 1)
{ {
if constexpr(!is_GNWK_GKXC_GNWC) if constexpr(!is_GNWK_GKXC_GNWC)
......
...@@ -471,7 +471,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -471,7 +471,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
auto launch_kernel = [&](auto has_main_k_block_loop) { auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value; constexpr bool has_main_loop = has_main_k_block_loop.value;
const auto kernel = kernel_grouped_conv_fwd_multiple_d_wmma_cshuffle< const auto kernel = kernel_grouped_conv_multiple_d_wmma_cshuffle<
GridwiseOp, GridwiseOp,
ADataType, ADataType,
BDataType, BDataType,
......
...@@ -43,7 +43,13 @@ struct ComputePtrOffsetOfStridedBatch ...@@ -43,7 +43,13 @@ struct ComputePtrOffsetOfStridedBatch
return ds_offset; return ds_offset;
} }
__host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const [[maybe_unused]] __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideE_);
}
// alias for kernels without multiple D
[[maybe_unused]] __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
{ {
return g_idx * static_cast<long_index_t>(BatchStrideE_); return g_idx * static_cast<long_index_t>(BatchStrideE_);
} }
...@@ -52,6 +58,7 @@ struct ComputePtrOffsetOfStridedBatch ...@@ -52,6 +58,7 @@ struct ComputePtrOffsetOfStridedBatch
index_t BatchStrideB_; index_t BatchStrideB_;
Array<ck::index_t, NumDTensor> BatchStrideDs_; Array<ck::index_t, NumDTensor> BatchStrideDs_;
index_t BatchStrideE_; index_t BatchStrideE_;
index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D
}; };
} // namespace device } // namespace device
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_AMD_XDLOPS_HPP #pragma once
#define CK_AMD_XDLOPS_HPP
#include "data_type.hpp"
namespace ck { namespace ck {
...@@ -355,7 +352,6 @@ struct intrin_mfma_f64_16x16x4f64<16, 16> ...@@ -355,7 +352,6 @@ struct intrin_mfma_f64_16x16x4f64<16, 16>
} }
}; };
#if defined CK_ENABLE_FP8
template <index_t MPerWave, index_t NPerWave> template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16f8f8; struct intrin_mfma_f32_32x32x16f8f8;
...@@ -418,9 +414,7 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16> ...@@ -418,9 +414,7 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16>
#endif #endif
} }
}; };
#endif
#if defined CK_ENABLE_BF8
template <index_t MPerWave, index_t NPerWave> template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16bf8bf8; struct intrin_mfma_f32_32x32x16bf8bf8;
...@@ -483,9 +477,7 @@ struct intrin_mfma_f32_16x16x32bf8bf8<16, 16> ...@@ -483,9 +477,7 @@ struct intrin_mfma_f32_16x16x32bf8bf8<16, 16>
#endif #endif
} }
}; };
#endif
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
template <index_t MPerWave, index_t NPerWave> template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16f8bf8; struct intrin_mfma_f32_32x32x16f8bf8;
...@@ -548,9 +540,7 @@ struct intrin_mfma_f32_16x16x32f8bf8<16, 16> ...@@ -548,9 +540,7 @@ struct intrin_mfma_f32_16x16x32f8bf8<16, 16>
#endif #endif
} }
}; };
#endif
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
template <index_t MPerWave, index_t NPerWave> template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16bf8f8; struct intrin_mfma_f32_32x32x16bf8f8;
...@@ -613,6 +603,5 @@ struct intrin_mfma_f32_16x16x32bf8f8<16, 16> ...@@ -613,6 +603,5 @@ struct intrin_mfma_f32_16x16x32bf8f8<16, 16>
#endif #endif
} }
}; };
#endif
} // namespace ck } // namespace ck
#endif
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment