Commit ad65dfe7 authored by Jun Liu's avatar Jun Liu
Browse files

Merge branch 'amd-develop' into amd-master

parents 2a27d15c 15baccf2
......@@ -509,7 +509,7 @@ struct DeviceFpAintBGemm_Wmma_CShuffle : public DeviceGemm_dequantB<ALayout,
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::is_gfx11_supported())
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, ck::half_t> ||
is_same_v<AccDataType, int32_t>))
......
......@@ -536,7 +536,7 @@ struct DeviceGemmDl : public DeviceGemm<ALayout,
}
if(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() ||
ck::is_gfx11_supported())
ck::is_gfx11_supported() || ck::is_gfx12_supported())
{
return GridwiseGemm::CheckValidity(
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_);
......
......@@ -50,8 +50,9 @@ __global__ void
const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11,
const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \
defined(__gfx12__))
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(ABDataType);
......@@ -552,7 +553,7 @@ struct DeviceGemmMultipleD_Dl : public DeviceGemmMultipleD<ALayout,
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
ck::is_gfx103_supported() || ck::is_gfx11_supported())
ck::is_gfx103_supported() || ck::is_gfx11_supported() || ck::is_gfx12_supported())
{
return GridwiseGemm::CheckValidity(
arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.e_grid_desc_m_n_);
......
......@@ -515,7 +515,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::is_gfx11_supported())
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{
......
......@@ -84,14 +84,21 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
// K1 = Max Vector Access Pixels
static constexpr auto K1Number = Number<K1>{};
static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
static constexpr auto AEnableLds_auto =
(NWaves == 1 && is_same<tensor_layout::gemm::RowMajor, ALayout>::value) ? false : true;
static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
static constexpr auto MaxVectorLoadA = K1 * sizeof(ADataType) == 16 ? true : false;
static constexpr auto MaxVectorLoadB = K1 * sizeof(BDataType) == 16 ? true : false;
static constexpr auto AEnableLds_auto = (NWaves == 1 && (MaxVectorLoadA || MRepeat == 1) &&
is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
? false
: true;
static constexpr auto BEnableLds_auto =
(MWaves == 1 && is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value) ? false : true;
(MWaves == 1 && (MaxVectorLoadB || NRepeat == 1) &&
is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
? false
: true;
// If true, LDS is used unconditionally
static constexpr auto AEnableLds_manu = false;
......@@ -443,7 +450,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
static bool IsSupportedArgument(const Argument& arg)
{
if(ck::is_gfx11_supported())
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, ck::half_t> ||
is_same_v<AccDataType, int32_t>))
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_streamk_v2.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/flush_cache.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename CDataType,
typename GemmAccDataType,
typename CShuffleDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1,
index_t BK1,
index_t MPerXDL,
index_t NPerXDL,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA>
struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_xdl_cshuffle_streamk_v3<
ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
GemmAccDataType,
CShuffleDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
GemmSpec,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB>;
using Argument = typename GridwiseGemm::Argument;
// Invoker
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(stream_config.log_level_ > 0)
{
arg.Print();
}
if(!GridwiseGemm::CheckValidity(arg))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
float ave_time = 0;
index_t k_grain = KPerBlock;
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
hipGetErrorString(hipMemsetAsync(
arg.p_c_grid, 0, arg.M * arg.N * sizeof(CDataType), stream_config.stream_id_));
const auto Run = [&](const auto& kernel) {
dim3 grid_dim;
if(arg.Grid_size < 0)
{
int occupancy, num_cu;
hipError_t rtn;
rtn = hipOccupancyMaxActiveBlocksPerMultiprocessor(
&occupancy, kernel, BlockSize, 0);
hip_check_error(rtn);
hipDeviceProp_t dev_prop;
hipDevice_t dev;
rtn = hipGetDevice(&dev);
hip_check_error(rtn);
rtn = hipGetDeviceProperties(&dev_prop, dev);
hip_check_error(rtn);
num_cu = dev_prop.multiProcessorCount;
arg.Grid_size = num_cu * occupancy;
grid_dim = arg.Grid_size;
}
else
grid_dim = arg.Grid_size;
if(stream_config.flush_cache)
{
Argument arg_ = arg;
ck::utility::RotatingMemWrapper<Argument> rotating_mem(
arg_,
stream_config.rotating_count,
arg_.M * arg_.K * sizeof(ADataType),
arg_.K * arg_.N * sizeof(BDataType));
rotating_mem.Print();
auto run_flush_cache = [&]() {
// flush icache
ck::utility::flush_icache();
// rotating mem
rotating_mem.Next();
};
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
stream_config, run_flush_cache, kernel, grid_dim, dim3(BlockSize), 0, arg_);
}
else
{
ave_time = launch_and_time_kernel(
stream_config, kernel, grid_dim, dim3(BlockSize), 0, arg);
}
};
constexpr index_t minimum_occupancy =
BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
if(has_main_k_block_loop)
{
// Tail number always full
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
}
// Tail number could be One to Seven
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
{
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::One>;
Run(kernel);
}
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Full)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Full>;
Run(kernel);
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Two>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Three)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Three>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Four)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Four>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Five)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Five>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Six>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Seven)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Seven>;
Run(kernel);
}
}
}
}
// Tail number could be Odd or Even
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
else
{
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
}
else
{
// Tail number always 1
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
}
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_supported())
{
return false;
}
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding ||
GemmSpec == GemmSpecialization::KPadding))
{
return false;
}
return GridwiseGemm::CheckValidity(arg);
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
CDataType* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
index_t streamk_sel,
index_t Grid_size,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation)
{
return Argument{
p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, streamk_sel, Grid_size}; // HS
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
index_t streamk_sel,
index_t Grid_size,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
M,
N,
K,
StrideA,
StrideB,
StrideC,
streamk_sel,
Grid_size);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
{BlockGemmPipelineVersion::v1, "v1"},
{BlockGemmPipelineVersion::v2, "v2"},
{BlockGemmPipelineVersion::v3, "v3"},
{BlockGemmPipelineVersion::v4, "v4"},
{BlockGemmPipelineVersion::v5, "v5"}};
// clang-format off
str << "DeviceGemmXdlUniversal"
<< "<"
<< getGemmSpecializationString(GemmSpec) << ", "
<< std::string(ALayout::name)[0]
<< std::string(BLayout::name)[0]
<< std::string(CLayout::name)[0]
<< ">"
<< " BlkSize: "
<< BlockSize << ", "
<< "BlkTile: "
<< MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
<< "WaveTile: "
<< MPerXDL<<"x"<<NPerXDL << ", "
<< "WaveMap: "
<< MXdlPerWave<<"x" << NXdlPerWave<<", "
<< "VmemReadVec: "
<< ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
<< "BlkGemmPipelineScheduler: "
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
<< "BlkGemmPipelineVersion: "
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
<< "BlkGemmPipelinePrefetchStages: "
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages;
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -629,7 +629,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
static bool IsSupportedArgument(const Argument& arg)
{
// check device
if(ck::is_gfx11_supported())
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -93,12 +93,12 @@ __global__ void
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
const long_index_t a_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
const long_index_t b_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
const long_index_t e_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
......
......@@ -48,18 +48,19 @@ __global__ void
const Block2CTileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__) || \
defined(__gfx12__))
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
const long_index_t a_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
const long_index_t b_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
const long_index_t c_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx));
__shared__ FloatAB p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB)];
......
......@@ -66,12 +66,12 @@ __global__ void
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
const long_index_t a_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
const long_index_t b_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
const long_index_t c_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx));
__shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)];
......
......@@ -47,24 +47,24 @@ __global__ void
#endif
kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3(
typename GridwiseGemm::Argument karg,
const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
[[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
[[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
[[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const index_t num_k_per_block)
[[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
[[maybe_unused]] const index_t num_k_per_block)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94__))
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumBatchToMerge);
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
const long_index_t a_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
const long_index_t b_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
const long_index_t e_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
......@@ -103,12 +103,12 @@ __global__ void
#endif
kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds(
typename GridwiseGemm::Argument karg,
const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
[[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
[[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
[[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const index_t num_k_per_block)
[[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
[[maybe_unused]] const index_t num_k_per_block)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
......@@ -116,12 +116,12 @@ __global__ void
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumBatchToMerge);
const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
const long_index_t a_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
const long_index_t b_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
const long_index_t e_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
// Pass two lds pointer is the key to tell compiler that ds_read/write
// operate on different lds chunk at same time without order dependecy
......@@ -1268,7 +1268,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
arg.Conv_G_;
std::array<index_t, I1> in_out_batch_strides = {
arg.compute_ptr_offset_of_batch_.BatchStrideC_};
static_cast<index_t>(arg.compute_ptr_offset_of_batch_.BatchStrideC_)};
const auto kernel = kernel_batched_elementwise<GridwiseElementwise,
ck::Tuple<CElementwiseGridDesc_M_N>,
......
......@@ -692,7 +692,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
static bool IsSupportedArgument(const Argument& arg)
{
// check device
if(ck::is_gfx11_supported())
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{
......
......@@ -61,12 +61,9 @@ __global__ void
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
const long_index_t a_batch_offset = compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
const long_index_t b_batch_offset = compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
const long_index_t c_batch_offset = compute_ptr_offset_of_batch.GetCPtrOffset(g_idx);
__shared__ FloatA p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatA)];
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -90,19 +90,20 @@ __global__ void
const Block2CTileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__) || \
defined(__gfx12__))
// offset base pointer for each work-group
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
const long_index_t a_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
const long_index_t b_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
const long_index_t c_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
......@@ -266,7 +267,8 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads);
input_right_pads,
a_g_n_c_wis_lengths[I1]);
const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
......@@ -312,8 +314,8 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
{
const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(e_g_n_k_wos_lengths,
e_g_n_k_wos_strides);
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(
e_g_n_k_wos_lengths, e_g_n_k_wos_strides, e_g_n_k_wos_lengths[I1]);
const auto out_gemmm_gemmn_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
......@@ -666,7 +668,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
// check device
if(!(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
ck::is_gfx103_supported() || ck::is_gfx11_supported()))
ck::is_gfx103_supported() || ck::is_gfx11_supported() || ck::is_gfx12_supported()))
{
return false;
}
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -107,7 +107,7 @@ __global__ void
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
defined(__gfx11__))
defined(__gfx11__) || defined(__gfx12__))
// offset base pointer for each work-group
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
......@@ -263,7 +263,8 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads);
input_right_pads,
a_g_n_c_wis_lengths[I1]);
const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
......@@ -310,8 +311,8 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides)
{
const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<CLay>(c_g_n_k_wos_lengths,
c_g_n_k_wos_strides);
conv_to_gemm_transformer.template MakeCDescriptor_M_N<CLay>(
c_g_n_k_wos_lengths, c_g_n_k_wos_strides, c_g_n_k_wos_lengths[I1]);
const auto out_gemmm_gemmn_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
......@@ -602,7 +603,7 @@ struct DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK : public DeviceGroupedConvFwd<NDimS
// check device
if(!(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() ||
ck::is_gfx11_supported()))
ck::is_gfx11_supported() || ck::is_gfx12_supported()))
{
return false;
}
......
......@@ -69,7 +69,8 @@ template <typename GridwiseGemm,
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename Block2ETileMap,
typename ComputePtrOffsetOfBatch,
typename ComputePtrOffsetOfG,
typename ComputePtrOffsetOfN,
bool HasMainKBlockLoop,
bool isMultiA,
bool isMultiB>
......@@ -85,7 +86,7 @@ __global__ void
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op,
const index_t batch_count,
const index_t groups_count,
const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1,
const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
......@@ -93,18 +94,24 @@ __global__ void
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_,
const Block2ETileMap block_2_ctile_map,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
const ComputePtrOffsetOfG compute_ptr_offset_of_groups,
const ComputePtrOffsetOfN compute_ptr_offset_of_n)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94__))
// offset base pointer for each work-group
const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / groups_count);
const index_t& num_blocks_per_n = groups_count;
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch);
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_n);
const long_index_t e_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx));
const auto& ds_batch_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx);
const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
const auto& ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
const long_index_t e_n_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
......@@ -121,13 +128,28 @@ __global__ void
AsPointer p_as_grid_grp;
BsPointer p_bs_grid_grp;
const auto& as_batch_offset = compute_ptr_offset_of_batch.GetAsPtrOffset(g_idx);
const auto& as_batch_offset = compute_ptr_offset_of_groups.GetAsPtrOffset(g_idx);
// compute_ptr_offset_of_n_ not need BatchStrideB so
// in case of MultiA is false but isMultiB is true
// BatchStrideA_ is not tuple.
if constexpr(isMultiA)
{
const auto& as_n_offset = compute_ptr_offset_of_n.GetAsPtrOffset(n_idx);
static constexpr index_t NumATensor = AGridDesc_AK0_M_AK1::Size();
static_for<0, NumATensor, 1>{}(
[&](auto i) { p_as_grid_grp(i) = p_as_grid[i] + as_batch_offset[i]; });
static constexpr index_t NumATensor = AGridDesc_AK0_M_AK1::Size();
static_for<0, NumATensor, 1>{}([&](auto i) {
p_as_grid_grp(i) = p_as_grid[i] + as_batch_offset[i] + as_n_offset[i];
});
}
else
{
const long_index_t a_n_offset = compute_ptr_offset_of_n.GetAPtrOffset(n_idx);
static_for<0, 1, 1>{}(
[&](auto i) { p_as_grid_grp(i) = p_as_grid[i] + as_batch_offset[i] + a_n_offset; });
}
const auto& bs_batch_offset = compute_ptr_offset_of_batch.GetBsPtrOffset(g_idx);
const auto& bs_batch_offset = compute_ptr_offset_of_groups.GetBsPtrOffset(g_idx);
static constexpr index_t NumBTensor = BGridDesc_BK0_N_BK1::Size();
static_for<0, NumBTensor, 1>{}(
......@@ -137,7 +159,7 @@ __global__ void
p_as_grid_grp,
p_bs_grid_grp,
p_ds_grid_grp,
p_e_grid + e_batch_offset,
p_e_grid + e_batch_offset + e_n_offset,
p_shared,
a_element_op,
b_element_op,
......@@ -150,16 +172,19 @@ __global__ void
}
else
{
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
const long_index_t a_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx));
const long_index_t b_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx));
const long_index_t a_n_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
GridwiseGemm::template Run<HasMainKBlockLoop>(
p_as_grid + a_batch_offset,
p_as_grid + a_batch_offset + a_n_offset,
p_bs_grid + b_batch_offset,
p_ds_grid_grp,
p_e_grid + e_batch_offset,
p_e_grid + e_batch_offset + e_n_offset,
p_shared,
a_element_op,
b_element_op,
......@@ -175,7 +200,7 @@ __global__ void
ignore = p_bs_grid;
ignore = p_ds_grid;
ignore = p_e_grid;
ignore = batch_count;
ignore = groups_count;
ignore = a_grid_desc_k0_m_k1;
ignore = b_grid_desc_k0_n_k1;
ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
......@@ -183,7 +208,8 @@ __global__ void
ignore = a_element_op;
ignore = b_element_op;
ignore = cde_element_op;
ignore = compute_ptr_offset_of_batch;
ignore = compute_ptr_offset_of_groups;
ignore = compute_ptr_offset_of_n;
ignore = block_2_ctile_map;
#endif
}
......@@ -309,7 +335,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads)
const std::array<index_t, NDimSpatial>& input_right_pads,
const index_t Conv_N)
{
const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(a_g_n_c_wis_lengths,
......@@ -321,7 +348,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads);
input_right_pads,
Conv_N);
const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
......@@ -347,11 +375,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
template <typename ELay>
static auto
MakeEGridDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const index_t Conv_N)
{
const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(e_g_n_k_wos_lengths,
e_g_n_k_wos_strides);
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(
e_g_n_k_wos_lengths, e_g_n_k_wos_strides, Conv_N);
const auto out_gemmm_gemmn_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
......@@ -363,24 +392,25 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// Pass e_g_n_k_wos_lengths for logical broadcast.
static auto MakeDsGridDescriptor_M_N(
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides)
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
const index_t Conv_N)
{
return generate_tuple(
[&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(e_g_n_k_wos_lengths,
ds_g_n_k_wos_strides[i]);
return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(
e_g_n_k_wos_lengths, ds_g_n_k_wos_strides[i], Conv_N);
},
Number<NumDTensor>{});
}
// desc for problem definition
using AGridDesc_M_K = remove_cvref_t<decltype(MakeAGridDescriptor_M_K<ALayout>(
{}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
{}, {}, {}, {}, {}, {}, {}, {}, {}, {}, 1))>;
using BGridDesc_N_K = remove_cvref_t<decltype(MakeBGridDescriptor_N_K<BLayout>({}, {}))>;
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}))>;
using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>;
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, 1))>;
using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}, 1))>;
// If we are using multiAB and one of the template datatype parameters is not a tuple, convert
// it to it
......@@ -468,6 +498,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e)},
num_group_{a_g_n_c_wis_lengths[0]},
conv_N_per_block_{
conv_to_gemm_transformer.template GetSplitedNSize<ADataType, EDataType>(
a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides)},
a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K<ALayout>(a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
......@@ -477,12 +513,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads)},
input_right_pads,
conv_N_per_block_)},
b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K<BLayout>(b_g_k_c_xs_lengths,
b_g_k_c_xs_strides)},
ds_grid_desc_m_n_{},
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(e_g_n_k_wos_lengths,
e_g_n_k_wos_strides)},
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(
e_g_n_k_wos_lengths, e_g_n_k_wos_strides, conv_N_per_block_)},
a_grid_desc_ak0_m_ak1_{
GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
b_grid_desc_bk0_n_bk1_{
......@@ -490,7 +527,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
compute_ptr_offset_of_batch_{},
compute_ptr_offset_of_groups_{},
compute_ptr_offset_of_n_{},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op},
......@@ -511,8 +549,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
if constexpr(isMultiA || isMultiB)
{
static_for<0, NumATensor, 1>{}([&](auto i) {
// Init compute_ptr_offset_of_batch_ for multiple AB
compute_ptr_offset_of_batch_.BatchStrideA_(i) = a_g_n_c_wis_strides[0];
// Init compute_ptr_offset_of_groups_ for multiple AB
compute_ptr_offset_of_groups_.BatchStrideA_(i) = a_g_n_c_wis_strides[0];
// Use GemmADataType/GemmBDataType to iterate over tuple (even if passed data
// type is not tuple)
......@@ -524,16 +562,23 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{
// p_as is tuple
p_as_grid_(i) = static_cast<const DataType*>(p_as[i.value]);
// compute_ptr_offset_of_n_ not need BatchStrideB so
// in case of MultiA is false but isMultiB is true
// BatchStrideA_ is not tuple.
compute_ptr_offset_of_n_.BatchStrideA_(i) =
a_g_n_c_wis_strides[1] * conv_N_per_block_;
}
else
{
// if MultiB and not MultiA then p_as is single pointer
p_as_grid_(i) = static_cast<const DataType*>(p_as);
compute_ptr_offset_of_n_.BatchStrideA_ =
a_g_n_c_wis_strides[1] * conv_N_per_block_;
}
});
static_for<0, NumBTensor, 1>{}([&](auto i) {
// Init compute_ptr_offset_of_batch_ for multiple AB
compute_ptr_offset_of_batch_.BatchStrideB_(i) = b_g_k_c_xs_strides[0];
// Init compute_ptr_offset_of_groups_ for multiple AB
compute_ptr_offset_of_groups_.BatchStrideB_(i) = b_g_k_c_xs_strides[0];
using DataType = remove_cvref_t<tuple_element_t<i.value, GemmBDataType>>;
// It is possible that one of the AB is a pointer and one is a tuple.
......@@ -553,8 +598,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
}
else
{
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0];
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0];
compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides[0];
compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides[0];
compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides[1] * conv_N_per_block_;
// p_as and p_bs are pointers
p_as_grid_(I0) = static_cast<const ADataType*>(p_as);
......@@ -570,13 +616,16 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds[i]);
// D batch stride
compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0];
compute_ptr_offset_of_groups_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0];
compute_ptr_offset_of_n_.BatchStrideDs_(i) =
ds_g_n_k_wos_strides[i][1] * conv_N_per_block_;
// D desc
ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N<DLayout>(
e_g_n_k_wos_lengths, ds_g_n_k_wos_strides[i]);
e_g_n_k_wos_lengths, ds_g_n_k_wos_strides[i], conv_N_per_block_);
});
compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0];
compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0];
compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_;
// populate desc for Ds/E
if constexpr(isMultiA || isMultiB)
......@@ -638,6 +687,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// tensor descriptors for problem definiton
index_t num_group_;
index_t conv_N_per_block_;
AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_N_K b_grid_desc_n_k_;
DsGridDesc_M_N ds_grid_desc_m_n_;
......@@ -655,7 +706,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// for computing batch offset
ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor>
compute_ptr_offset_of_batch_;
compute_ptr_offset_of_groups_;
ComputePtrOffsetOfStridedBatch<NumATensor, I1, NumDTensor> compute_ptr_offset_of_n_;
// element-wise op
AElementwiseOperation a_element_op_;
......@@ -689,8 +741,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
arg.Print();
}
const index_t grid_size =
arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * arg.num_group_;
const index_t num_workgroups_per_Conv_N =
arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_;
const index_t gdx = arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
const index_t gdy = arg.num_group_ * num_workgroups_per_Conv_N;
const index_t gdz = 1;
const auto K =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
......@@ -721,6 +777,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
Block2ETileMap,
ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor>,
ComputePtrOffsetOfStridedBatch<NumATensor, I1, NumDTensor>,
has_main_loop,
isMultiA,
isMultiB>;
......@@ -728,7 +785,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
return launch_and_time_kernel(
stream_config,
kernel,
dim3(grid_size),
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
arg.p_as_grid_,
......@@ -744,7 +801,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_etile_map_,
arg.compute_ptr_offset_of_batch_);
arg.compute_ptr_offset_of_groups_,
arg.compute_ptr_offset_of_n_);
}
else
{
......@@ -763,6 +821,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
Block2ETileMap,
ComputePtrOffsetOfStridedBatch<NumATensor, NumBTensor, NumDTensor>,
ComputePtrOffsetOfStridedBatch<NumATensor, I1, NumDTensor>,
has_main_loop,
isMultiA,
isMultiB>;
......@@ -770,7 +829,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
return launch_and_time_kernel(
stream_config,
kernel,
dim3(grid_size),
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
arg.p_as_grid_.At(I0), // Pass just A descriptor instead of tuple
......@@ -786,7 +845,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_etile_map_,
arg.compute_ptr_offset_of_batch_);
arg.compute_ptr_offset_of_groups_,
arg.compute_ptr_offset_of_n_);
}
};
......
......@@ -60,7 +60,7 @@ template <typename GridwiseGemm,
typename AGridDesc_AK0_M_K1,
typename BGridDesc_BK0_N_K1,
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename ComputePtrOffsetOfBatch,
typename ComputePtrOffset,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
index_t MinimumOccupancy = 1,
......@@ -71,24 +71,32 @@ __global__ void
#endif
kernel_grouped_conv_fwd_xdl_cshuffle_v3(
typename GridwiseGemm::Argument karg,
const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
[[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
[[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
[[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const index_t batch_count)
[[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_groups,
[[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_n,
[[maybe_unused]] const index_t groups_count)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
// offset base pointer for each work-group
const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / batch_count);
const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / groups_count);
const index_t& num_blocks_per_n = groups_count;
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch);
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_n);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
const long_index_t a_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx));
const long_index_t b_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx));
const long_index_t e_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx));
const long_index_t a_n_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
const long_index_t e_n_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
......@@ -97,9 +105,9 @@ __global__ void
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(karg.p_a_grid + a_batch_offset,
TailNum>(karg.p_a_grid + a_batch_offset + a_n_offset,
karg.p_b_grid + b_batch_offset,
karg.p_c_grid + e_batch_offset,
karg.p_c_grid + e_batch_offset + e_n_offset,
p_shared,
karg,
a_grid_desc_ak0_m_ak1,
......@@ -114,7 +122,7 @@ template <typename GridwiseGemm,
typename AGridDesc_AK0_M_K1,
typename BGridDesc_BK0_N_K1,
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename ComputePtrOffsetOfBatch,
typename ComputePtrOffset,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
index_t MinimumOccupancy = 1,
......@@ -125,24 +133,32 @@ __global__ void
#endif
kernel_grouped_conv_fwd_xdl_cshuffle_v3_2lds(
typename GridwiseGemm::Argument karg,
const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
[[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
[[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
[[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock,
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const index_t batch_count)
[[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_groups,
[[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_n,
[[maybe_unused]] const index_t groups_count)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
// offset base pointer for each work-group
const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / batch_count);
const index_t num_blocks_per_batch = __builtin_amdgcn_readfirstlane(gridDim.y / groups_count);
const index_t& num_blocks_per_n = groups_count;
const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_batch);
const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / num_blocks_per_n);
const long_index_t a_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx));
const long_index_t b_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx));
const long_index_t e_batch_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx));
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
const long_index_t a_n_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
const long_index_t e_n_offset =
amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));
// Pass two lds pointer is the key to tell compiler that ds_read/write
// operate on different lds chunk at same time without order dependecy
......@@ -154,9 +170,9 @@ __global__ void
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(karg.p_a_grid + a_batch_offset,
TailNum>(karg.p_a_grid + a_batch_offset + a_n_offset,
karg.p_b_grid + b_batch_offset,
karg.p_c_grid + e_batch_offset,
karg.p_c_grid + e_batch_offset + e_n_offset,
p_shared_0,
p_shared_1,
karg,
......@@ -294,7 +310,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads)
const std::array<index_t, NDimSpatial>& input_right_pads,
const index_t Conv_N)
{
const auto in_gemmmraw_gemmkraw_desc =
conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>(a_g_n_c_wis_lengths,
......@@ -306,7 +324,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads);
input_right_pads,
Conv_N);
const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
......@@ -350,11 +369,13 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
template <typename ELay>
static auto
MakeEGridDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
const index_t Conv_N)
{
const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(e_g_n_k_wos_lengths,
e_g_n_k_wos_strides);
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(
e_g_n_k_wos_lengths, e_g_n_k_wos_strides, Conv_N);
const auto out_gemmm_gemmn_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
......@@ -363,7 +384,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
}
// desc for problem definition
using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}))>;
using EGridDesc_M_N = remove_cvref_t<decltype(MakeEGridDescriptor_M_N<ELayout>({}, {}, 1))>;
#define GridwiseGemmV3TemplateParams \
tensor_layout::gemm::RowMajor, tensor_layout::gemm::ColumnMajor, \
......@@ -396,7 +417,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
// desc for blockwise copy
using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(MakeAGridDescriptor_AK0_M_AK1<ALayout>(
{}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
{}, {}, {}, {}, {}, {}, {}, {}, {}, {}, 1))>;
using BGridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(MakeBGridDescriptor_BK0_N_BK1<BLayout>({}, {}))>;
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
......@@ -429,6 +450,12 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
p_b_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e)},
num_group_{a_g_n_c_wis_lengths[0]},
conv_N_per_block_{
conv_to_gemm_transformer.template GetSplitedNSize<ADataType, EDataType>(
a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides)},
a_grid_desc_ak0_m_ak1_{MakeAGridDescriptor_AK0_M_AK1<ALayout>(a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
......@@ -438,13 +465,15 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads)},
input_right_pads,
conv_N_per_block_)},
b_grid_desc_bk0_n_bk1_{
MakeBGridDescriptor_BK0_N_BK1<BLayout>(b_g_k_c_xs_lengths, b_g_k_c_xs_strides)},
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(e_g_n_k_wos_lengths,
e_g_n_k_wos_strides)},
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(
e_g_n_k_wos_lengths, e_g_n_k_wos_strides, conv_N_per_block_)},
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
compute_ptr_offset_of_batch_{},
compute_ptr_offset_of_groups_{},
compute_ptr_offset_of_n_{},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op},
......@@ -459,15 +488,17 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
input_left_pads_{input_left_pads},
input_right_pads_{input_right_pads}
{
// A/B/E Batch Stride
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0];
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0];
// A/B/E Batch/N Stride
compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides[0];
compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides[0];
compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides[1] * conv_N_per_block_;
// p_as and p_bs are pointers
p_a_grid_ = static_cast<const ADataType*>(p_as);
p_b_grid_ = static_cast<const BDataType*>(p_bs);
compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0];
compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0];
compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_;
e_grid_desc_mblock_mperblock_nblock_nperblock_ =
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_);
......@@ -488,6 +519,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
// tensor descriptors for problem definiton
index_t num_group_;
index_t conv_N_per_block_;
// tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
......@@ -496,7 +528,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
// for computing batch offset
ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_batch_;
ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_groups_;
ComputePtrOffsetOfStridedBatch<I1, I1, I0> compute_ptr_offset_of_n_;
// element-wise op
AElementwiseOperation a_element_op_;
......@@ -538,11 +571,14 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
const index_t GemmK =
arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
const index_t num_workgroups_per_Conv_N =
arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_;
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) =
GridwiseGemm::CalculateGridSize(GemmM, GemmN, I1 /*arg.KBatch*/);
gdy *= arg.num_group_;
gdy *= arg.num_group_ * num_workgroups_per_Conv_N;
index_t K_split = (GemmK + KPerBlock - 1) / KPerBlock * KPerBlock;
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
......@@ -579,7 +615,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.compute_ptr_offset_of_batch_,
arg.compute_ptr_offset_of_groups_,
arg.compute_ptr_offset_of_n_,
arg.num_group_);
}
else
......@@ -594,7 +631,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.compute_ptr_offset_of_batch_,
arg.compute_ptr_offset_of_groups_,
arg.compute_ptr_offset_of_n_,
arg.num_group_);
}
};
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -161,11 +161,11 @@ __global__ void
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
const long_index_t a_batch_offset = amd_wave_read_first_lane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
const long_index_t b_batch_offset = amd_wave_read_first_lane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
const long_index_t e_batch_offset = amd_wave_read_first_lane(
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
......@@ -338,7 +338,8 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads);
input_right_pads,
a_g_n_c_wis_lengths[I1]);
const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
......@@ -367,8 +368,8 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
{
const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(e_g_n_k_wos_lengths,
e_g_n_k_wos_strides);
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(
e_g_n_k_wos_lengths, e_g_n_k_wos_strides, e_g_n_k_wos_lengths[I1]);
const auto out_gemmm_gemmn_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -163,7 +163,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads);
input_right_pads,
a_g_n_c_wis_lengths[I1]);
const auto in_gemmm_gemmk_desc =
matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
......@@ -255,8 +256,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides)
{
const auto out_gemmmraw_gemmnraw_desc =
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(e_g_n_k_wos_lengths,
e_g_n_k_wos_strides);
conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>(
e_g_n_k_wos_lengths, e_g_n_k_wos_strides, e_g_n_k_wos_lengths[I1]);
const auto out_gemmm_gemmn_desc =
matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
......@@ -581,7 +582,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
namespace ctc = tensor_layout::convolution;
// check device
if(ck::is_gfx11_supported())
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
{
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{
......
......@@ -68,14 +68,14 @@ template <index_t NumATensor, index_t NumBTensor, index_t NumDTensor>
struct ComputePtrOffsetOfStridedBatch<NumATensor,
NumBTensor,
NumDTensor,
ck::enable_if_t<(NumATensor > 1 || NumBTensor > 1)>>
enable_if_t<(NumATensor > 1 || NumBTensor > 1)>>
{
ComputePtrOffsetOfStridedBatch() = default;
ComputePtrOffsetOfStridedBatch(Array<ck::index_t, NumATensor>& BatchStrideAs,
Array<ck::index_t, NumBTensor>& BatchStrideBs,
Array<ck::index_t, NumDTensor>& BatchStrideDs,
index_t BatchStrideE)
ComputePtrOffsetOfStridedBatch(Array<long_index_t, NumATensor>& BatchStrideAs,
Array<long_index_t, NumBTensor>& BatchStrideBs,
Array<long_index_t, NumDTensor>& BatchStrideDs,
long_index_t BatchStrideE)
: BatchStrideA_(BatchStrideAs),
BatchStrideB_(BatchStrideBs),
BatchStrideDs_(BatchStrideDs),
......@@ -87,7 +87,7 @@ struct ComputePtrOffsetOfStridedBatch<NumATensor,
{
Array<long_index_t, NumATensor> as_offset;
static_for<0, NumATensor, 1>{}(
[&](auto i) { as_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideA_[i]); });
[&](auto i) { as_offset(i) = static_cast<long_index_t>(g_idx) * BatchStrideA_[i]; });
return as_offset;
}
......@@ -95,7 +95,7 @@ struct ComputePtrOffsetOfStridedBatch<NumATensor,
{
Array<long_index_t, NumBTensor> bs_offset;
static_for<0, NumBTensor, 1>{}(
[&](auto i) { bs_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideB_[i]); });
[&](auto i) { bs_offset(i) = static_cast<long_index_t>(g_idx) * BatchStrideB_[i]; });
return bs_offset;
}
......@@ -103,40 +103,40 @@ struct ComputePtrOffsetOfStridedBatch<NumATensor,
{
Array<long_index_t, NumDTensor> ds_offset;
static_for<0, NumDTensor, 1>{}(
[&](auto i) { ds_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideDs_[i]); });
[&](auto i) { ds_offset(i) = static_cast<long_index_t>(g_idx) * BatchStrideDs_[i]; });
return ds_offset;
}
[[maybe_unused]] __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideE_);
return static_cast<long_index_t>(g_idx) * BatchStrideE_;
}
// alias for kernels without multiple D
[[maybe_unused]] __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideE_);
return static_cast<long_index_t>(g_idx) * BatchStrideE_;
}
Array<ck::index_t, NumATensor> BatchStrideA_;
Array<ck::index_t, NumBTensor> BatchStrideB_;
Array<ck::index_t, NumDTensor> BatchStrideDs_;
index_t BatchStrideE_;
index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D
Array<long_index_t, NumATensor> BatchStrideA_;
Array<long_index_t, NumBTensor> BatchStrideB_;
Array<long_index_t, NumDTensor> BatchStrideDs_;
long_index_t BatchStrideE_;
long_index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D
};
template <index_t NumATensor, index_t NumBTensor, index_t NumDTensor>
struct ComputePtrOffsetOfStridedBatch<NumATensor,
NumBTensor,
NumDTensor,
ck::enable_if_t<(NumATensor == 1 && NumBTensor == 1)>>
enable_if_t<(NumATensor == 1 && NumBTensor == 1)>>
{
ComputePtrOffsetOfStridedBatch() = default;
ComputePtrOffsetOfStridedBatch(index_t BatchStrideA,
index_t BatchStrideB,
Array<ck::index_t, NumDTensor> BatchStrideDs,
index_t BatchStrideE)
ComputePtrOffsetOfStridedBatch(long_index_t BatchStrideA,
long_index_t BatchStrideB,
Array<long_index_t, NumDTensor> BatchStrideDs,
long_index_t BatchStrideE)
: BatchStrideA_(BatchStrideA),
BatchStrideB_(BatchStrideB),
BatchStrideDs_(BatchStrideDs),
......@@ -146,38 +146,38 @@ struct ComputePtrOffsetOfStridedBatch<NumATensor,
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideA_);
return static_cast<long_index_t>(g_idx) * BatchStrideA_;
}
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideB_);
return static_cast<long_index_t>(g_idx) * BatchStrideB_;
}
__host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
{
Array<long_index_t, NumDTensor> ds_offset;
static_for<0, NumDTensor, 1>{}(
[&](auto i) { ds_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideDs_[i]); });
[&](auto i) { ds_offset(i) = static_cast<long_index_t>(g_idx) * BatchStrideDs_[i]; });
return ds_offset;
}
[[maybe_unused]] __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideE_);
return static_cast<long_index_t>(g_idx) * BatchStrideE_;
}
// alias for kernels without multiple D
[[maybe_unused]] __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
{
return g_idx * static_cast<long_index_t>(BatchStrideE_);
return static_cast<long_index_t>(g_idx) * BatchStrideE_;
}
ck::index_t BatchStrideA_;
ck::index_t BatchStrideB_;
Array<ck::index_t, NumDTensor> BatchStrideDs_;
index_t BatchStrideE_;
index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D
long_index_t BatchStrideA_;
long_index_t BatchStrideB_;
Array<long_index_t, NumDTensor> BatchStrideDs_;
long_index_t BatchStrideE_;
long_index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D
};
template <bool isTuple, typename Tensors>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment