Commit 6b9a4bd5 authored by Jun Liu

Merge branch 'amd-develop-staging-0423' into amd-master

parents 56de337f c5f1cdf7
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -22,10 +22,12 @@ namespace device {
template <typename InDataTypeTuple,
          typename OutDataTypeTuple,
          typename ElementwiseOperation,
-          index_t NumDim,
-          index_t MPerThread,
-          typename InScalarPerVectorSeq,
-          typename OutScalarPerVectorSeq>
+          index_t NumDim,                 // The max dim of input tensors
+                                          // the tensors descs have to be aligned, such that
+                                          // the innermost dim is the contiguous one.
+          index_t MPerThread,             // How many elements per thread to read
+          typename InScalarPerVectorSeq,  // Scalar per vec for each Input
+          typename OutScalarPerVectorSeq> // Scalar per vec for each Output
struct DeviceElementwiseImpl
    : public DeviceElementwise<InDataTypeTuple, OutDataTypeTuple, ElementwiseOperation, NumDim>
{
@@ -242,13 +244,13 @@ struct DeviceElementwiseImpl
        static_for<0, NumInput, 1>{}([&](auto I) {
            if(!IsScalarPerVectorValid(
                   arg.lengths_, arg.inStridesArray_[I.value], InScalarPerVectorSeq::At(I)))
-                valid = false;
+                valid = valid && false;
        });
        static_for<0, NumOutput, 1>{}([&](auto I) {
            if(!IsScalarPerVectorValid(
                   arg.lengths_, arg.outStridesArray_[I.value], OutScalarPerVectorSeq::At(I)))
-                valid = false;
+                valid = valid && false;
        });
        return valid;
...
// SPDX-License-Identifier: MIT
-// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -254,12 +254,13 @@ template <index_t NDimSpatial,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CDEBlockTransferScalarPerVector_NPerBlock,
-          typename ComputeDataType =
+          typename AComputeDataType =
              decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value,
                                      Number<0>,
                                      ADataType>()), // ComputeType is InputType by default (first
                                                     // in tuple for MultiAB), unpack if tuple was
                                                     // passed
+          typename BComputeDataType = AComputeDataType,
          LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
    : public DeviceGroupedConvFwdMultipleABD<NDimSpatial,
@@ -274,7 +275,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                                             AElementwiseOperation,
                                             BElementwiseOperation,
                                             CDEElementwiseOperation,
-                                             ComputeDataType>
+                                             AComputeDataType,
+                                             BComputeDataType>
{
    using DeviceOp = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle;
@@ -386,7 +388,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
    using GemmBDataType = std::conditional_t<!isMultiB && isMultiA, Tuple<BDataType>, BDataType>;
#define GridwiseGemmTemplateParameters \
-    GemmADataType, GemmBDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
+    GemmADataType, GemmBDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
        EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \
        InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \
        KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, \
@@ -399,7 +401,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
        BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \
        CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
        CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
-        CDEBlockTransferScalarPerVector_NPerBlock, LoopSched
+        CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \
+        BComputeDataType
// Use appropriate gridwise gemm
using GridwiseGemm =
    std::conditional_t<isMultiA || isMultiB,
...
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -75,12 +75,13 @@ template <index_t NDimSpatial,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CDEBlockTransferScalarPerVector_NPerBlock,
-          typename ComputeDataType =
+          typename AComputeDataType =
              decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value,
                                      Number<0>,
                                      ADataType>()), // ComputeType is InputType by default (first
                                                     // in tuple for MultiAB), unpack if tuple was
                                                     // passed
+          typename BComputeDataType = AComputeDataType,
          LoopScheduler LoopSched = make_default_loop_scheduler()>
using DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
    NDimSpatial,
@@ -128,7 +129,8 @@ using DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = DeviceGroupedConvFwdMultipl
    CShuffleNXdlPerWavePerShuffle,
    CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
    CDEBlockTransferScalarPerVector_NPerBlock,
-    ComputeDataType,
+    AComputeDataType,
+    BComputeDataType,
    LoopSched>;
} // namespace device
...
@@ -23,6 +23,7 @@ namespace device {
template <typename GridwiseGemm,
          typename GemmDesc,
          GemmSpecialization GemmSpec,
+          bool Zeroing,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
@@ -106,8 +107,37 @@ __global__ void
        const auto block_2_etile_map =
            GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart, id_off);
+        if constexpr(Zeroing)
+        {
            auto barrier_count_finished =
                barrier_count + group_id * barrier_size_grp + id_local % mn_blocks;
+            GridwiseGemm::template RunWithZeroing<HasMainKBlockLoop,
+                                                  EGlobalMemoryDataOperation,
+                                                  GemmSpec,
+                                                  ALayout,
+                                                  BLayout,
+                                                  DsLayout,
+                                                  ELayout>(gemm_desc_ptr[group_id].p_a_grid,
+                                                           gemm_desc_ptr[group_id].p_b_grid,
+                                                           p_ds_grid_,
+                                                           gemm_desc_ptr[group_id].p_e_grid,
+                                                           p_shared,
+                                                           barrier_count_finished,
+                                                           a_element_op,
+                                                           b_element_op,
+                                                           c_element_op,
+                                                           M,
+                                                           N,
+                                                           K,
+                                                           StrideA,
+                                                           StrideB,
+                                                           StrideDs,
+                                                           StrideE,
+                                                           KBatch,
+                                                           block_2_etile_map);
+        }
+        else
+        {
            GridwiseGemm::template Run<HasMainKBlockLoop,
                                       EGlobalMemoryDataOperation,
@@ -120,7 +150,7 @@ __global__ void
                                                p_ds_grid_,
                                                gemm_desc_ptr[group_id].p_e_grid,
                                                p_shared,
-                                                barrier_count_finished,
+                                                nullptr,
                                                a_element_op,
                                                b_element_op,
                                                c_element_op,
@@ -133,6 +163,7 @@ __global__ void
                                                StrideE,
                                                KBatch,
                                                block_2_etile_map);
+        }
        id_off += grid_size_grp;
        id_local += grid_size_grp;
@@ -193,8 +224,11 @@ template <typename ALayout,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CDEBlockTransferScalarPerVector_NPerBlock,
+          PipelineVersion PipelineVer = PipelineVersion::v1,
+          LoopScheduler LoopSched = make_default_loop_scheduler(),
          typename ComputeType = ADataType,
-          LoopScheduler LoopSched = make_default_loop_scheduler()>
+          typename ALDSType = ComputeType,
+          typename BLDSType = ComputeType>
struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
                                                                        BLayout,
                                                                        DsLayout,
@@ -215,11 +249,15 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
+    using AComputeType = ComputeType;
+    using BComputeType = ComputeType;
    // GridwiseGemm
    using GridwiseGemm = GridwiseGemmMultipleD_xdl_splitk_cshuffle<
        ADataType, // TODO: distinguish A/B datatype
        BDataType,
-        ComputeType,
+        AComputeType,
+        BComputeType,
        AccDataType,
        CShuffleDataType,
        DsDataType,
@@ -258,7 +296,10 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
        CShuffleNXdlPerWavePerShuffle,
        CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        CDEBlockTransferScalarPerVector_NPerBlock,
-        LoopSched>;
+        LoopSched,
+        PipelineVer,
+        ALDSType,
+        BLDSType>;
    template <typename UnderlyingBlockToCTileMap>
    struct OffsettedBlockToCTileMapMLoops
@@ -613,10 +654,49 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
        float ave_time = 0;
        auto launch_kernel = [&](auto has_main_k_block_loop_, auto e_global_memory_operation_) {
+            if(arg.k_batch_ == 1)
+            {
+                const auto kernel =
+                    kernel_grouped_gemm_xdl_fixed_nk<GridwiseGemm,
+                                                     GroupedGemmKernelArgument<NumDTensor>,
+                                                     GemmSpec,
+                                                     false,
+                                                     ALayout,
+                                                     BLayout,
+                                                     DsLayout,
+                                                     ELayout,
+                                                     DsDataType,
+                                                     Block2ETileMap,
+                                                     GroupedGemmBlock2ETileMap,
+                                                     AElementwiseOperation,
+                                                     BElementwiseOperation,
+                                                     CDEElementwiseOperation,
+                                                     e_global_memory_operation_,
+                                                     has_main_k_block_loop_>;
+                return launch_and_time_kernel(
+                    stream_config,
+                    kernel,
+                    dim3(arg.grid_size_),
+                    dim3(BlockSize),
+                    0,
+                    cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev),
+                    nullptr,
+                    arg.barrier_size_grp_,
+                    arg.gemm_desc_kernel_arg_.size(),
+                    arg.grid_size_grp_,
+                    arg.k_batch_,
+                    arg.a_element_op_,
+                    arg.b_element_op_,
+                    arg.c_element_op_);
+            }
+            else
+            {
                const auto kernel =
                    kernel_grouped_gemm_xdl_fixed_nk<GridwiseGemm,
                                                     GroupedGemmKernelArgument<NumDTensor>,
                                                     GemmSpec,
+                                                     true,
                                                     ALayout,
                                                     BLayout,
                                                     DsLayout,
@@ -645,13 +725,14 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
                    arg.a_element_op_,
                    arg.b_element_op_,
                    arg.c_element_op_);
+            }
        };
        constexpr auto AtomicAdd = InMemoryDataOperationEnum::AtomicAdd;
        constexpr auto Set = InMemoryDataOperationEnum::Set;
-        // For bf16 datatype only kbatch = 1 scenario is supported. This condition is enforced
-        // in IsSupportedArgument function
+        // For bf16 datatype only kbatch = 1 scenario is supported. This condition is
+        // enforced in IsSupportedArgument function
        if constexpr(std::is_same<ADataType, ck::bhalf_t>::value)
        {
            if(has_main_k_block_loop)
@@ -719,12 +800,12 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
        bool supported = true;
-        // If we use padding we do not support vector loads for dimensions not divisible by vector
-        // load size.
+        // If we use padding we do not support vector loads for dimensions not divisible by
+        // vector load size.
        if constexpr(GemmSpec != GemmSpecialization::Default)
        {
-            // [A|B]BlockTransferSrcVectorDim value define dimension in the block {K0,M,K1} layout,
-            // thus we have to adapt it to the {M,K} or {N,K} layout.
+            // [A|B]BlockTransferSrcVectorDim value define dimension in the block {K0,M,K1}
+            // layout, thus we have to adapt it to the {M,K} or {N,K} layout.
            const auto a_raw_vector_dim = ABlockTransferSrcVectorDim != 1 ? 1 : 0;
            const auto b_raw_vector_dim = BBlockTransferSrcVectorDim != 1 ? 1 : 0;
...
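Side note: the change above routes every group through a single kernel template with a compile-time Zeroing switch, and the host picks the instantiation from the runtime k_batch_ value (k_batch_ == 1 launches the non-zeroing path with a null barrier pointer). A minimal standalone C++ sketch of that dispatch idea, with all names illustrative rather than the CK API:

#include <cstdio>

// Stand-in for the two gridwise paths; the real RunWithZeroing clears the
// output tile and uses a barrier before partial K-batches are accumulated.
template <bool Zeroing>
void run_gemm_like(int k_batch)
{
    if constexpr(Zeroing)
        std::printf("zeroing path, k_batch = %d\n", k_batch); // split-K accumulation
    else
        std::printf("plain path, k_batch = %d\n", k_batch); // single pass, no barrier
}

// Host-side dispatch: a runtime value selects which instantiation runs.
void launch(int k_batch)
{
    if(k_batch == 1)
        run_gemm_like<false>(k_batch);
    else
        run_gemm_like<true>(k_batch);
}

int main()
{
    launch(1);
    launch(4);
    return 0;
}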
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -26,13 +26,19 @@ namespace device {
template <typename GridwiseGemm,
          typename GemmDesc,
          bool HasMainKBlockLoop,
-          InMemoryDataOperationEnum CGlobalMemoryDataOperation>
+          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
+          typename AElementwiseOperation = ck::tensor_operation::element_wise::PassThrough,
+          typename BElementwiseOperation = ck::tensor_operation::element_wise::PassThrough,
+          typename CElementwiseOperation = ck::tensor_operation::element_wise::PassThrough>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
-                                       const index_t group_count)
+                                       const index_t group_count,
+                                       const AElementwiseOperation a_element_op,
+                                       const BElementwiseOperation b_element_op,
+                                       const CElementwiseOperation c_element_op)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
    defined(__gfx94__))
@@ -64,10 +70,16 @@ __global__ void
    GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
        gemm_desc_ptr[group_id].karg_,
        static_cast<void*>(p_shared),
-        gemm_desc_ptr[group_id].block_2_ctile_map_);
+        gemm_desc_ptr[group_id].block_2_ctile_map_,
+        a_element_op,
+        b_element_op,
+        c_element_op);
#else
    ignore = gemm_descs_const;
    ignore = group_count;
+    ignore = a_element_op;
+    ignore = b_element_op;
+    ignore = c_element_op;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
@@ -193,7 +205,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
    static constexpr index_t B2E_M01 = 8;
    using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMapKSplit>;
    using KernelArgument = typename GridwiseGemm::Argument;
+    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    struct GemmTransKernelArg
    {
        KernelArgument karg_;
@@ -437,7 +449,10 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
                dim3(BlockSize),
                0,
                cast_pointer_to_constant_address_space(arg.p_workspace_),
-                arg.gemm_kernel_args_.size());
+                arg.gemm_kernel_args_.size(),
+                PassThrough{},
+                PassThrough{},
+                PassThrough{});
        };
        if(all_have_main_k0_block_loop)
...
@@ -92,6 +92,110 @@ struct Add
    };
};
struct Max
{
template <typename Y, typename X0, typename X1>
__host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1) const
{
const Y x0_converted = type_convert<Y>(x0);
const Y x1_converted = type_convert<Y>(x1);
y = ck::math::max(x0_converted, x1_converted);
}
};
struct Min
{
template <typename Y, typename X0, typename X1>
__host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1) const
{
const Y x0_converted = type_convert<Y>(x0);
const Y x1_converted = type_convert<Y>(x1);
y = ck::math::min(x0_converted, x1_converted);
}
};
struct Multiply
{
template <typename Y, typename X0, typename X1>
__host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
template <>
__host__ __device__ constexpr void
operator()<float>(float& y, const float& x0, const float& x1) const
{
y = x0 * x1;
};
template <>
__host__ __device__ constexpr void
operator()<double>(double& y, const double& x0, const double& x1) const
{
y = x0 * x1;
};
template <>
__host__ __device__ constexpr void
operator()<float>(float& y, const float& x0, const half_t& x1) const
{
y = x0 * type_convert<half_t>(x1);
};
template <>
__host__ __device__ constexpr void
operator()<half_t>(half_t& y, const float& x0, const float& x1) const
{
y = type_convert<half_t>(x0 * x1);
};
template <>
__host__ __device__ constexpr void
operator()<half_t>(half_t& y, const float& x0, const half_t& x1) const
{
y = type_convert<half_t>(x0) * x1;
};
template <>
__host__ __device__ constexpr void
operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
{
y = x0 * x1;
};
template <>
__host__ __device__ constexpr void
operator()<float>(float& y, const float& x0, const bhalf_t& x1) const
{
const float x1_tmp = ck::type_convert<float>(x1);
y = x0 * x1_tmp;
}
template <>
__host__ __device__ constexpr void
operator()<bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
{
const float x1_tmp = ck::type_convert<float>(x0);
const float x2_tmp = ck::type_convert<float>(x1);
const float y_tmp = x1_tmp * x2_tmp;
y = ck::type_convert<bhalf_t>(y_tmp);
}
template <>
__host__ __device__ constexpr void
operator()<bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
{
const float x2_tmp = ck::type_convert<float>(x1);
const float y_tmp = x0 * x2_tmp;
y = ck::type_convert<bhalf_t>(y_tmp);
}
template <>
__host__ __device__ constexpr void
operator()<int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
{
y = x0 * x1;
};
};
struct ScaleAdd
{
    __host__ __device__ ScaleAdd(float scale = 1.f) : scale_(scale) {}
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
namespace tensor_operation {
namespace element_wise {
// y = UnaryOp0(UnaryOp1(...(x)))
template <typename... UnaryOpsSet>
struct UnaryCombinedOp
{
__host__ __device__ UnaryCombinedOp(UnaryOpsSet... unary_ops) : unary_ops_(unary_ops...) {}
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const
{
// Execute first unary op to copy data to y
unary_ops_.At(Number<0>{})(y, x);
static_for<1, Tuple<UnaryOpsSet...>::Size(), 1>{}([&](auto i) { unary_ops_.At(i)(y, y); });
};
Tuple<UnaryOpsSet...> unary_ops_;
};
// y = BinaryOp(UnaryOp0(x0), UnaryOp1(x1))
template <typename BinaryOp, typename UnaryOp0, typename UnaryOp1>
struct BinaryWithUnaryCombinedOp
{
__host__ __device__ BinaryWithUnaryCombinedOp(BinaryOp binary_op,
UnaryOp0 unary_op0,
UnaryOp1 unary_op1)
: binary_op_(binary_op), unary_op0_(unary_op0), unary_op1_(unary_op1)
{
}
template <typename Y, typename X0, typename X1>
__host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1) const
{
Y unary_x0_tmp_result;
Y unary_x1_tmp_result;
unary_op0_(unary_x0_tmp_result, x0);
unary_op1_(unary_x1_tmp_result, x1);
binary_op_(y, unary_x0_tmp_result, unary_x1_tmp_result);
};
private:
BinaryOp binary_op_;
UnaryOp0 unary_op0_;
UnaryOp1 unary_op1_;
};
// y = BinaryOp0(BinaryOp1(UnaryOp0(x0), UnaryOp1(x1)), UnaryOp2(x2))
template <typename BinaryOp0,
typename BinaryOp1,
typename UnaryOp0,
typename UnaryOp1,
typename UnaryOp2>
struct TrinaryWithUnaryCombinedOp
{
__host__ __device__ TrinaryWithUnaryCombinedOp(BinaryOp0 binary_op0,
BinaryOp0 binary_op1,
UnaryOp0 unary_op0,
UnaryOp1 unary_op1,
UnaryOp2 unary_op2)
: binary_op0_(binary_op0),
binary_op1_(binary_op1),
unary_op0_(unary_op0),
unary_op1_(unary_op1),
unary_op2_(unary_op2)
{
}
template <typename Y, typename X0, typename X1, typename X2>
__host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1, const X2& x2) const
{
Y unary_x0_tmp_result;
Y unary_x1_tmp_result;
Y unary_x2_tmp_result;
unary_op0_(unary_x0_tmp_result, x0);
unary_op1_(unary_x1_tmp_result, x1);
unary_op2_(unary_x2_tmp_result, x2);
binary_op0_(unary_x0_tmp_result, unary_x0_tmp_result, unary_x1_tmp_result);
binary_op1_(y, unary_x0_tmp_result, unary_x2_tmp_result);
};
private:
BinaryOp0 binary_op0_{};
BinaryOp1 binary_op1_{};
UnaryOp0 unary_op0_{};
UnaryOp1 unary_op1_{};
UnaryOp2 unary_op2_{};
};
} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
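For orientation, here is a standalone sketch of the composition pattern the new combined functors implement (a simplified re-implementation with no CK headers; PassThrough, Multiply, and CombinedOpSketch below are local stand-ins, not the CK types):

#include <iostream>

struct PassThrough
{
    template <typename Y, typename X>
    void operator()(Y& y, const X& x) const { y = static_cast<Y>(x); }
};

struct Multiply
{
    template <typename Y, typename X0, typename X1>
    void operator()(Y& y, const X0& x0, const X1& x1) const { y = x0 * x1; }
};

// y = BinaryOp(UnaryOp0(x0), UnaryOp1(x1)), mirroring BinaryWithUnaryCombinedOp above
template <typename BinaryOp, typename UnaryOp0, typename UnaryOp1>
struct CombinedOpSketch
{
    template <typename Y, typename X0, typename X1>
    void operator()(Y& y, const X0& x0, const X1& x1) const
    {
        Y t0;
        Y t1;
        unary_op0_(t0, x0); // pre-process each input independently
        unary_op1_(t1, x1);
        binary_op_(y, t0, t1); // then combine the two intermediate results
    }

    BinaryOp binary_op_{};
    UnaryOp0 unary_op0_{};
    UnaryOp1 unary_op1_{};
};

int main()
{
    float y = 0.f;
    CombinedOpSketch<Multiply, PassThrough, PassThrough>{}(y, 2.f, 3.f);
    std::cout << y << '\n'; // prints 6
    return 0;
}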
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r2.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor/static_tensor.hpp"
#include "ck/utility/common_header.hpp"
namespace ck {
template <typename GridwiseElementwiseFunctor,
typename InGridDescTuple,
typename OutGridDescTuple,
typename InDataTypePointerTuple,
typename OutDataTypePointerTuple,
typename Block2TileMap,
typename ElementwiseOperation>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_elementwise(const InGridDescTuple in_grid_desc_tuple,
const OutGridDescTuple out_grid_desc_tuple,
const InDataTypePointerTuple p_in_global_tuple,
const OutDataTypePointerTuple p_out_global_tuple,
const Block2TileMap block_2_tile_map,
const ElementwiseOperation elementwise_op)
{
GridwiseElementwiseFunctor::Run(in_grid_desc_tuple,
out_grid_desc_tuple,
p_in_global_tuple,
p_out_global_tuple,
block_2_tile_map,
elementwise_op);
}
template <typename InGridDescTuple,
typename OutGridDescTuple,
typename InDataTypePointerTuple,
typename OutDataTypePointerTuple,
typename Block2TileMap,
typename ElementwiseOperation,
index_t BlockSize,
index_t M0PerBlock,
index_t M1PerBlock,
index_t M0PerThread,
index_t M1PerThread,
typename ThreadClusterArrangeOrder,
typename InScalarPerVectorSeq,
typename OutScalarPerVectorSeq,
bool InOutSameVectorDim>
struct GridwiseElementwise
{
static constexpr index_t NumInput = InDataTypePointerTuple::Size();
static constexpr index_t NumOutput = OutDataTypePointerTuple::Size();
static_assert(NumInput == InScalarPerVectorSeq::Size() &&
NumOutput == OutScalarPerVectorSeq::Size() &&
NumInput == InGridDescTuple::Size() && NumOutput == OutGridDescTuple::Size(),
"Tuple size is inconsistent with the number of in/out!");
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
using PassThroughOp = tensor_operation::element_wise::PassThrough;
__device__ static void Run(const InGridDescTuple& in_grid_desc_tuple,
const OutGridDescTuple& out_grid_desc_tuple,
const InDataTypePointerTuple& p_in_global_tuple,
const OutDataTypePointerTuple& p_out_global_tuple,
const Block2TileMap& block_2_tile_map,
const ElementwiseOperation& elementwise_op)
{
constexpr auto src_datas = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
using DataType = remove_cv_t<remove_pointer_t<DataTypePointer>>;
return DataType{};
},
Number<NumInput>{});
constexpr auto dst_datas = generate_tuple(
[&](auto I) {
using DataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
using DataType = remove_pointer_t<DataTypePointer>;
return DataType{};
},
Number<NumOutput>{});
const auto in_global_buf_tuple = generate_tuple(
[&](auto I) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global_tuple[I], in_grid_desc_tuple[I].GetElementSpaceSize());
},
Number<NumInput>{});
auto out_global_buf_tuple = generate_tuple(
[&](auto I) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global_tuple[I], out_grid_desc_tuple[I].GetElementSpaceSize());
},
Number<NumOutput>{});
const auto block_work_idx =
block_2_tile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
const index_t m0_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * M0PerBlock);
const index_t m1_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * M1PerBlock);
const auto input_thread_grid_offset = generate_tuple(
[&](auto) {
return make_multi_index(m0_block_data_idx_on_grid, m1_block_data_idx_on_grid);
},
Number<NumInput>{});
const auto output_thread_grid_offset = generate_tuple(
[&](auto) {
return make_multi_index(m0_block_data_idx_on_grid, m1_block_data_idx_on_grid);
},
Number<NumOutput>{});
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
// If src and dst have same vector dim, then:
// M0 dim - for src and dst vector load/store
// else:
// M0 dim - for dst vector load
// M1 dim - for src vector store
using SrcDimAccessOrder = Sequence<0, 1>;
using DstDimAccessOrder =
std::conditional_t<InOutSameVectorDim, Sequence<0, 1>, Sequence<1, 0>>;
using SrcVectorDim = Number<1>;
using DstVectorDim = std::conditional_t<InOutSameVectorDim, Number<1>, Number<0>>;
using ThreadClusterLengths =
Sequence<Number<M0PerBlock / M0PerThread>{}, Number<M1PerBlock / M1PerThread>{}>;
auto global_to_global_transfer = ThreadGroupTensorSliceTransfer_v4r2<
ThisThreadBlock,
ElementwiseOperation,
uniform_sequence_gen_t<NumOutput, static_cast<index_t>(InMemoryDataOperationEnum::Set)>,
Sequence<M0PerBlock, M1PerBlock>,
ThreadClusterLengths,
ThreadClusterArrangeOrder,
decltype(src_datas),
decltype(dst_datas),
InGridDescTuple,
OutGridDescTuple,
SrcDimAccessOrder,
DstDimAccessOrder,
SrcVectorDim{},
DstVectorDim{},
InScalarPerVectorSeq,
OutScalarPerVectorSeq,
uniform_sequence_gen_t<NumInput, 1>,
uniform_sequence_gen_t<NumOutput, 1>,
uniform_sequence_gen_t<NumInput, false>,
uniform_sequence_gen_t<NumOutput, false>>{in_grid_desc_tuple,
input_thread_grid_offset,
out_grid_desc_tuple,
output_thread_grid_offset,
elementwise_op};
global_to_global_transfer.Run(
in_grid_desc_tuple, in_global_buf_tuple, out_grid_desc_tuple, out_global_buf_tuple, I0);
}
};
} // namespace ck
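For intuition, a standalone sketch of the block-to-tile offset arithmetic GridwiseElementwise::Run performs above; the real Block2TileMap may enumerate tiles in a different order, and the sizes below are made up for illustration:

#include <cstdio>

int main()
{
    constexpr int M0PerBlock = 32, M1PerBlock = 64; // per-workgroup tile (illustrative)
    constexpr int M0 = 128, M1 = 256;               // flattened 2D problem shape (illustrative)
    constexpr int tiles_m1 = M1 / M1PerBlock;       // tiles along the contiguous dim

    for(int block_id = 0; block_id < (M0 / M0PerBlock) * tiles_m1; ++block_id)
    {
        const int block_m0 = block_id / tiles_m1; // CalculateBottomIndex analogue
        const int block_m1 = block_id % tiles_m1;
        const int m0_offset = block_m0 * M0PerBlock; // m0_block_data_idx_on_grid
        const int m1_offset = block_m1 * M1PerBlock; // m1_block_data_idx_on_grid
        std::printf("block %d -> tile origin (%d, %d)\n", block_id, m0_offset, m1_offset);
    }
    return 0;
}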
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -73,7 +73,7 @@ template <typename ADataType,
          index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
          LoopScheduler LoopSched,
          PipelineVersion PipelineVer = PipelineVersion::v1,
-          typename BComputeDataType = AComputeDataType_>
+          typename BComputeDataType_ = AComputeDataType_>
struct GridwiseGemmMultipleD_xdl_cshuffle
{
    static constexpr index_t NumDTensor = DsDataType::Size();
@@ -103,8 +103,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
#if CK_WORKAROUND_DENORM_FIX
    using AComputeDataType =
        conditional_t<is_same_v<AComputeDataType_, ck::half_t>, ck::bhalf_t, AComputeDataType_>;
+    using BComputeDataType =
+        conditional_t<is_same_v<BComputeDataType_, ck::half_t>, ck::bhalf_t, BComputeDataType_>;
#else
    using AComputeDataType = AComputeDataType_;
+    using BComputeDataType = BComputeDataType_;
#endif
    __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
...