Commit 7f65ac05 authored by Jun Liu

Merge branch 'develop' into amd-develop

parents 687d2b7e 7e5c81fe
-list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
-set(target 0)
-foreach(gpu IN LISTS GPU_TARGETS)
-    if(gpu IN_LIST gpu_list AND target EQUAL 0)
-        add_example_executable(example_gemm_bias_softmax_gemm_permute gemm_bias_softmax_gemm_permute.cpp)
-        set(target 1)
-    endif()
-endforeach()
+add_example_executable(example_gemm_bias_softmax_gemm_permute gemm_bias_softmax_gemm_permute_xdl.cpp)

-list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
-set(target 0)
-foreach(gpu IN LISTS GPU_TARGETS)
-    if(gpu IN_LIST gpu_list AND target EQUAL 0)
-        add_custom_target(example_im2col_col2im)
-        add_example_executable(example_image_to_column_f32 image_to_column_f32.cpp)
-        add_example_dependencies(example_im2col_col2im example_image_to_column_f32)
-        add_example_executable(example_column_to_image_f32 column_to_image_f32.cpp)
-        add_example_dependencies(example_im2col_col2im example_column_to_image_f32)
-        set(target 1)
-    endif()
-endforeach()
+add_custom_target(example_im2col_col2im)
+add_example_executable(example_image_to_column_f32 image_to_column_f32.cpp)
+add_example_dependencies(example_im2col_col2im example_image_to_column_f32)
+add_example_executable(example_column_to_image_f32 column_to_image_f32.cpp)
+add_example_dependencies(example_im2col_col2im example_column_to_image_f32)

-list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942)
-set(target 0)
-foreach(gpu IN LISTS GPU_TARGETS)
-    if(gpu IN_LIST gpu_list2 AND target EQUAL 0)
-        add_example_executable(example_gemm_multi_ABD_xdl_fp16 gemm_multi_ABD_xdl_fp16.cpp)
-        set(target 1)
-    endif()
-endforeach()
+add_example_executable(example_gemm_multi_ABD_xdl_fp16 gemm_multi_ABD_xdl_fp16.cpp)

-list(APPEND gpu_list2 gfx908 gfx90a gfx940 gfx941 gfx942)
-set(target 0)
-foreach(gpu IN LISTS GPU_TARGETS)
-    if(gpu IN_LIST gpu_list2 AND target EQUAL 0)
-        add_example_executable(example_contraction_multi_ABD_xdl_fp16 contraction_multi_ABD_xdl_fp16.cpp)
-        set(target 1)
-    endif()
-endforeach()
+add_example_executable(example_contraction_multi_ABD_xdl_fp16 contraction_multi_ABD_xdl_fp16.cpp)

@@ -2,16 +2,9 @@ add_subdirectory(binary)
 add_subdirectory(multi_AB)
 add_subdirectory(unary)
-list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
-set(target 0)
-foreach(gpu IN LISTS GPU_TARGETS)
-    if(gpu IN_LIST gpu_list AND target EQUAL 0)
-        add_custom_target(example_convnd_activ_xdl)
-        # ScaleAdd ScaleAdd Relu
-        add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16.cpp)
-        add_example_dependencies(example_convnd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16)
-        add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp)
-        add_example_dependencies(example_convnd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16)
-        set(target 1)
-    endif()
-endforeach()
+add_custom_target(example_convnd_activ_xdl)
+# ScaleAdd ScaleAdd Relu
+add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16.cpp)
+add_example_dependencies(example_convnd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_fp16)
+add_example_executable(example_convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16 convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16.cpp)
+add_example_dependencies(example_convnd_activ_xdl example_convnd_fwd_xdl_scaleadd_scaleadd_relu_bcasted_bias_fp16)

-if(GPU_TARGETS MATCHES "gfx11")
-    add_custom_target(example_fpAintB_gemm_wmma)
-    add_example_executable(example_fp16int8_gemm_wmma fp16int8_gemm_wmma.cpp)
-    add_dependencies(example_fpAintB_gemm_wmma example_fp16int8_gemm_wmma)
-endif()
+add_custom_target(example_fpAintB_gemm_wmma)
+add_example_executable(example_fp16int8_gemm_wmma fp16int8_gemm_wmma.cpp)
+add_example_dependencies(example_fpAintB_gemm_wmma example_fp16int8_gemm_wmma)
@@ -5,6 +5,12 @@ include_directories(BEFORE
 add_custom_target(examples)
+function(add_example_dependencies EXAMPLE_NAME FILE_NAME)
+    if(FILE_NAME)
+        add_dependencies(EXAMPLE_NAME FILE_NAME)
+    endif()
+endfunction(add_example_dependencies EXAMPLE_NAME)
+
 function(add_example_executable EXAMPLE_NAME FILE_NAME)
     message("adding example ${EXAMPLE_NAME}")
     set(result 1)
@@ -38,12 +44,27 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
             endif()
         endforeach()
     endif()
+    #Do not build any DL examples if DL_KERNELS not set
     foreach(source IN LISTS FILE_NAME)
        if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
            message("removing dl example ${source} ")
            list(REMOVE_ITEM FILE_NAME "${source}")
        endif()
    endforeach()
+    #Do not build any XDL examples if gfx9 targets are not on the list
+    foreach(source IN LISTS FILE_NAME)
+        if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl")
+            message("removing xdl example ${source} ")
+            list(REMOVE_ITEM FILE_NAME "${source}")
+        endif()
+    endforeach()
+    #Do not build any WMMA examples if gfx11 targets are not on the list
+    foreach(source IN LISTS FILE_NAME)
+        if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma")
+            message("removing wmma example ${source} ")
+            list(REMOVE_ITEM FILE_NAME "${source}")
+        endif()
+    endforeach()
     #only continue if there are some source files left on the list
     if(FILE_NAME)
         add_executable(${EXAMPLE_NAME} ${FILE_NAME})
@@ -97,12 +118,27 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
             endif()
         endforeach()
     endif()
+    #Do not build any DL examples if DL_KERNELS not set
     foreach(source IN LISTS FILE_NAME)
        if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
            message("removing dl example ${source} ")
            list(REMOVE_ITEM FILE_NAME "${source}")
        endif()
    endforeach()
+    #Do not build any XDL examples if gfx9 targets are not on the list
+    foreach(source IN LISTS FILE_NAME)
+        if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl")
+            message("removing xdl example ${source} ")
+            list(REMOVE_ITEM FILE_NAME "${source}")
+        endif()
+    endforeach()
+    #Do not build any WMMA examples if gfx11 targets are not on the list
+    foreach(source IN LISTS FILE_NAME)
+        if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma")
+            message("removing wmma example ${source} ")
+            list(REMOVE_ITEM FILE_NAME "${source}")
+        endif()
+    endforeach()
     #only continue if there are some source files left on the list
     if(FILE_NAME)
         add_executable(${EXAMPLE_NAME} ${FILE_NAME})
...
@@ -45,6 +45,10 @@
 #endif
 // define general macros for various architectures
+#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
+    defined(__gfx942__)
+#define __gfx9__
+#endif
 #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
 #define __gfx94__
 #endif
@@ -62,8 +66,7 @@
 // buffer resource
 #ifndef __HIP_DEVICE_COMPILE__ // for host code
 #define CK_BUFFER_RESOURCE_3RD_DWORD -1
-#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \
-    defined(__gfx90a__) || defined(__gfx94__)
+#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx9__)
 #define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
 #elif defined(__gfx103__)
 #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
@@ -75,8 +78,7 @@
 #ifndef __HIP_DEVICE_COMPILE__ // for host code, define nothing
 #elif defined(__gfx803__) || defined(__gfx900__) // for GPU code
 #define CK_USE_AMD_V_MAC_F32
-#elif defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx103__) || \
-    defined(__gfx94__) // for GPU code
+#elif defined(__gfx906__) || defined(__gfx9__) || defined(__gfx103__) // for GPU code
 #define CK_USE_AMD_V_FMAC_F32
 #define CK_USE_AMD_V_DOT2_F32_F16
 #define CK_USE_AMD_V_DOT4_I32_I8
@@ -89,7 +91,7 @@
 // MFMA instruction
 #ifndef __HIP_DEVICE_COMPILE__ // for host code
 #define CK_USE_AMD_MFMA
-#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) // for GPU code
+#elif defined(__gfx9__) // for GPU code
 #define CK_USE_AMD_MFMA
 #endif
@@ -120,7 +122,7 @@
 // buffer atomic add: floating point
 #ifndef __HIP_DEVICE_COMPILE__ // for host code
 #define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
-#elif defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) // for GPU code
+#elif defined(__gfx9__) // for GPU code
 #define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
 #else // for GPU code
 #define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r2.hpp"
namespace ck {
/**
* @brief Blockwise data transfer
*
* This version does the following things to avoid the scratch memory issue:
* 1. Use StaticallyIndexedArray instead of a C array for the thread buffer
* 2. ThreadwiseTensorSliceTransfer_v3 does not keep a reference to the tensor descriptor
* 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct a new tensor coordinate
*
*/
template <typename ThreadGroup,
typename ElementwiseOperation,
typename DstInMemOps, // Sequence
typename BlockSliceLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
typename SrcDatas,
typename DstDatas,
typename SrcDescs,
typename DstDescs,
typename SrcDimAccessOrder,
typename DstDimAccessOrder,
index_t SrcVectorDim,
index_t DstVectorDim,
typename SrcsScalarPerVector, // Sequence
typename DstsScalarPerVector, // Sequence
typename SrcsScalarStrideInVector, // Sequence
typename DstsScalarStrideInVector, // Sequence
typename ThreadTransferSrcsResetCoordinateAfterRun, // Sequence
typename ThreadTransferDstsResetCoordinateAfterRun, // Sequence
index_t NumThreadScratch = 1>
struct ThreadGroupTensorSliceTransfer_v4r2
{
static constexpr index_t nDim =
remove_reference_t<tuple_element_t<0, SrcDescs>>::GetNumOfDimension();
static constexpr index_t nSrc = SrcDescs::Size();
static constexpr index_t nDst = DstDescs::Size();
static constexpr auto thread_slice_lengths = BlockSliceLengths{} / ThreadClusterLengths{};
using Index = MultiIndex<nDim>;
__device__ constexpr ThreadGroupTensorSliceTransfer_v4r2(
const SrcDescs& src_descs,
const StaticallyIndexedArray<Index, nSrc>& src_block_slice_origins,
const DstDescs& dst_descs,
const StaticallyIndexedArray<Index, nDst>& dst_block_slice_origins,
const ElementwiseOperation& element_op)
: threadwise_transfer_(src_descs,
StaticallyIndexedArray<Index, nSrc>{},
dst_descs,
StaticallyIndexedArray<Index, nDst>{},
element_op)
{
static_assert(nDim == ThreadClusterLengths::Size() &&
nDim == ThreadClusterArrangeOrder::Size() &&
nDim == SrcDimAccessOrder::Size() && nDim == SrcDimAccessOrder::Size(),
"wrong! nDim not consistent");
static_for<0, nSrc, 1>{}([&](auto src_i) {
static_assert(nDim ==
remove_cvref_t<tuple_element_t<src_i, SrcDescs>>::GetNumOfDimension(),
"wrong! nDim not consistent");
});
static_for<0, nDst, 1>{}([&](auto dst_i) {
static_assert(nDim ==
remove_cvref_t<tuple_element_t<dst_i, DstDescs>>::GetNumOfDimension(),
"wrong! nDim not consistent");
});
static_assert(
is_same<BlockSliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
"wrong! threads should be mapped to cover entire slicing window");
static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
"wrong! ThreadGroup::GetNumOfThread() too small");
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
make_multi_index(ThreadGroup::GetThreadId()));
const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
const auto src_thread_slice_origins = generate_tuple(
[&](auto i) { return src_block_slice_origins[i] + thread_data_idx_begin; },
Number<nSrc>{});
const auto dst_thread_slice_origins = generate_tuple(
[&](auto i) { return dst_block_slice_origins[i] + thread_data_idx_begin; },
Number<nDst>{});
threadwise_transfer_.SetSrcSliceOrigins(src_descs, src_thread_slice_origins);
threadwise_transfer_.SetDstSliceOrigins(dst_descs, dst_thread_slice_origins);
}
}
template <typename SrcBuffers, index_t ThreadScratchId = 0>
__device__ void RunRead(const SrcDescs& src_descs,
const SrcBuffers& src_bufs,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunRead(src_descs, src_bufs, thread_scratch_id);
}
}
template <typename DstBuffers, index_t ThreadScratchId = 0>
__device__ void RunWrite(const DstDescs& dst_descs,
DstBuffers& dst_bufs,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunWrite(dst_descs, dst_bufs, thread_scratch_id);
}
}
template <typename SrcBuffer, typename DstBuffer, index_t ThreadScratchId>
__device__ void Run(const SrcDescs& src_descs,
const SrcBuffer& src_bufs,
const DstDescs& dst_descs,
DstBuffer& dst_bufs,
Number<ThreadScratchId> thread_scratch_id)
{
RunRead(src_descs, src_bufs, thread_scratch_id);
RunWrite(dst_descs, dst_bufs, thread_scratch_id);
}
__device__ void MoveSrcSliceWindow(const SrcDescs& src_descs, const Index& step)
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveSrcSliceWindow(src_descs, step);
}
}
__device__ void MoveDstSliceWindow(const DstDescs& dst_descs, const Index& step)
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.MoveDstSliceWindow(dst_descs, step);
}
}
private:
static constexpr auto thread_cluster_desc_ =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
using ThreadwiseTransfer =
ThreadwiseTensorSliceTransfer_v3r2<decltype(thread_slice_lengths),
ElementwiseOperation,
DstInMemOps,
SrcDatas,
DstDatas,
SrcDescs,
DstDescs,
SrcDimAccessOrder,
DstDimAccessOrder,
SrcVectorDim,
DstVectorDim,
SrcsScalarPerVector,
DstsScalarPerVector,
SrcsScalarStrideInVector,
DstsScalarStrideInVector,
ThreadTransferSrcsResetCoordinateAfterRun,
ThreadTransferDstsResetCoordinateAfterRun,
NumThreadScratch>;
ThreadwiseTransfer threadwise_transfer_;
};
} // namespace ck
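A hedged usage sketch (not part of the commit): the loop pattern a blockwise kernel would typically drive this transfer with. All names below are illustrative; BlockwiseCopy stands for a fully specialized ThreadGroupTensorSliceTransfer_v4r2, and the descriptors, buffers and step are whatever the calling kernel already owns.

#include <hip/hip_runtime.h> // for __device__ / __syncthreads(); CK headers as included above

template <typename BlockwiseCopy,
          typename SrcDescs,
          typename SrcBuffers,
          typename DstDescs,
          typename DstBuffers,
          typename Step>
__device__ void copy_k_tiles(BlockwiseCopy& blockwise_copy,
                             const SrcDescs& src_descs,
                             const SrcBuffers& src_bufs,
                             const DstDescs& dst_descs,
                             DstBuffers& dst_bufs,
                             const Step& src_slice_step,
                             ck::index_t num_k_tiles)
{
    for(ck::index_t k = 0; k < num_k_tiles; ++k)
    {
        // gather the current tile into the per-thread scratch (StaticallyIndexedArray)
        blockwise_copy.RunRead(src_descs, src_bufs);
        // write the scratch out to the destination (e.g. LDS)
        blockwise_copy.RunWrite(dst_descs, dst_bufs);
        // advance the source slice window to the next tile
        blockwise_copy.MoveSrcSliceWindow(src_descs, src_slice_step);
        __syncthreads();
    }
}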
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
@@ -40,7 +40,8 @@ using is_tuple = decltype(std::declval<T&>().IsTuple());
  * \tparam AElementwiseOperation A elementwise operation.
  * \tparam BElementwiseOperation B elementwise operation.
  * \tparam CDEElementwiseOperation CDE elementwise operation.
- * \tparam ComputeType Compute data type (default: ADataType, first if tuple passed).
+ * \tparam AComputeType Compute data type for A tensor (default: ADataType, first if tuple passed).
+ * \tparam BComputeType Compute data type for B tensor (default: AComputeType).
  */
 template <index_t NDimSpatial,
           typename ALayout,
@@ -54,12 +55,13 @@ template <index_t NDimSpatial,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CDEElementwiseOperation,
-          typename ComputeType =
+          typename AComputeType =
               decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value,
                                       Number<0>,
-                                      ADataType>())> // ComputeType is InputType by default (first
+                                      ADataType>()), // AComputeType is InputType by default (first
                                                      // in tuple for MultiAB), unpack if tuple was
                                                      // passed
+          typename BComputeType = AComputeType>
 struct DeviceGroupedConvFwdMultipleABD : public BaseOperator
 {
     static constexpr bool isMultiA = is_detected<is_tuple, ADataType>::value;
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include <iostream>
#include <vector>
#include <sstream>
#include "device_grouped_gemm.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
///
/// @brief Structure representing single GEMM problem arguments.
///
/// The pointer to the vector of those structures is passed to the GroupedGEMM entry
/// point kernel.
///
/// @tparam NumDTensor The number of D input tensors.
///
template <index_t NumDTensor = 0>
struct GroupedGemmMultipleDKernelArguments
{
__host__ __device__
GroupedGemmMultipleDKernelArguments(const void* p_a_grid_,
const void* p_b_grid_,
std::array<const void*, NumDTensor> p_ds_grid_,
void* p_e_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideE_)
: p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_ds_grid{p_ds_grid_},
p_e_grid{p_e_grid_},
M{M_},
N{N_},
K{K_},
StrideA{StrideA_},
StrideB{StrideB_},
StrideDs{StrideDs_},
StrideE{StrideE_}
{
}
const void* p_a_grid;
const void* p_b_grid;
std::array<const void*, NumDTensor> p_ds_grid;
void* p_e_grid;
index_t M;
index_t N;
index_t K;
index_t StrideA;
index_t StrideB;
std::array<index_t, NumDTensor> StrideDs;
index_t StrideE;
void Print() const
{
std::stringstream str;
for(auto sd : StrideDs)
str << sd << ",";
std::cout << "arg {"
<< "M:" << M << ", "
<< "N:" << N << ", "
<< "K:" << K << ", "
<< "SA:" << StrideA << ", "
<< "SB:" << StrideB << ", "
<< "SE:" << StrideE << ", "
<< "SDs: {" << str.str() << "}"
<< "}" << std::endl;
}
};
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
struct DeviceGroupedGemmMultipleDSplitK : public DeviceGroupedGemm<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
{
//----------------------------------------------------------------------------------------------
/// @brief Sets the k batch size.
///
/// @param p_arg Pointer to the Argument we're going to change.
/// @param[in] kbatch The kbatch value.
///
virtual void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const = 0;
//----------------------------------------------------------------------------------------------
/// @brief Sets the device kernel arguments pointer.
///
/// @param p_arg The pointer to the Argument we're going to update.
/// @param[in] p_dev_kernel_args The pointer to the device memory which contains kernel
/// arguments.
///
virtual void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const = 0;
//----------------------------------------------------------------------------------------------
/// @brief Gets the device kernel argument size.
///
/// @param[in] p_arg The pointer to the Device op Argument.
///
/// @return The device kernel argument size.
///
virtual size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
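A hedged host-side sketch (not part of the commit) of how the three split-K hooks above are intended to be driven. The concrete device-op type and the creation of p_arg via MakeArgumentPointer() are assumptions made only for illustration.

#include <hip/hip_runtime.h>

// `DeviceOp` is assumed to be a concrete implementation of DeviceGroupedGemmMultipleDSplitK<...>;
// `p_arg` is assumed to have been created by its MakeArgumentPointer() (not shown in this header).
template <typename DeviceOp>
void configure_splitk_args(DeviceOp& op,
                           ck::tensor_operation::device::BaseArgument* p_arg,
                           ck::index_t k_batch)
{
    // split the K dimension of every group k_batch ways
    op.SetKBatchSize(p_arg, k_batch);

    // the per-group kernel arguments are expected in device memory
    const size_t arg_bytes = op.GetDeviceKernelArgSize(p_arg);

    void* p_dev_kernel_args = nullptr;
    (void)hipMalloc(&p_dev_kernel_args, arg_bytes);

    // make the device-side buffer available to the op before invoking it
    op.SetDeviceKernelArgs(p_arg, p_dev_kernel_args);
}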
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/math.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_dynamic_vector_dims.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/stream_utility.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename InDataTypeTuple,
typename OutDataTypeTuple,
typename ElementwiseOperation,
index_t NumDim,
index_t BlockSize,
index_t M0PerBlock,
index_t M1PerBlock,
index_t M0PerThread,
index_t M1PerThread,
typename ThreadClusterArrangeOrder,
typename InScalarPerVectorSeq,
typename OutScalarPerVectorSeq>
struct DeviceElementwiseImpl
: public DeviceElementwise<InDataTypeTuple, OutDataTypeTuple, ElementwiseOperation, NumDim>
{
static constexpr int NumInput = InDataTypeTuple::Size();
static constexpr int NumOutput = OutDataTypeTuple::Size();
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static_assert(NumInput == InScalarPerVectorSeq::Size() &&
NumOutput == OutScalarPerVectorSeq::Size(),
"Tuple size is inconsistent with the number of in/out!");
static auto GenerateInDataTypePointerTuple()
{
return generate_tuple(
[&](auto I) {
using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;
return static_cast<const DataType*>(nullptr);
},
Number<NumInput>{});
};
static auto GenerateOutDataTypePointerTuple()
{
return generate_tuple(
[&](auto I) {
using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
return static_cast<DataType*>(nullptr);
},
Number<NumOutput>{});
};
using InDataTypePointerTuple = decltype(GenerateInDataTypePointerTuple());
using OutDataTypePointerTuple = decltype(GenerateOutDataTypePointerTuple());
static index_t GetLowestStrideDim(const std::array<index_t, NumDim>& strides)
{
index_t most_continous_dim = NumDim - 1;
index_t most_continous_dim_stride = strides[most_continous_dim];
for(index_t dim = 0; dim < NumDim; dim++)
{
if(strides[dim] < most_continous_dim_stride)
{
most_continous_dim_stride = strides[dim];
most_continous_dim = dim;
}
}
return most_continous_dim;
}
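// Illustrative note (not from the commit): for a packed row-major 3-D tensor with strides
// {N * K, K, 1}, the smallest stride belongs to the last dimension, so this returns 2,
// i.e. the dimension that can be accessed with vector loads/stores.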
template <typename InOutDescriptor>
static auto PadInputOutputDescriptor(const InOutDescriptor& desc)
{
const auto M0 = desc.GetLength(I0);
const auto M1 = desc.GetLength(I1);
const auto pad_M0 = math::integer_divide_ceil(M0, M0PerThread) * M0PerThread - M0;
const auto pad_M1 = math::integer_divide_ceil(M1, M1PerThread) * M1PerThread - M1;
const auto padded_desc = transform_tensor_descriptor(
desc,
make_tuple(make_right_pad_transform(M0, pad_M0), make_right_pad_transform(M1, pad_M1)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return padded_desc;
}
static auto GenerateBatchDimsLenghtsTuple(const std::array<index_t, NumDim>& lengths,
const index_t M0_dim,
const index_t M1_dim)
{
// Generate batch dims, they will be merged to M0
// Add one more dim than needed in case that M0 is equal to M1
// If M0 is equal to M1, then will be one more batch dim
std::array<index_t, NumDim - 1> batch_dims;
index_t batch_dim = 0;
for(index_t i = 0; i < NumDim; i++)
{
if(i != M0_dim && i != M1_dim)
{
batch_dims[batch_dim] = lengths[i];
batch_dim++;
}
}
// Add dummy dim if M0_dim is not equal to M1_dim
if(M0_dim != M1_dim && NumDim >= 2)
batch_dims[NumDim - 2] = 1;
return generate_tuple([&](auto I) { return batch_dims[I]; }, Number<NumDim - 1>{});
}
static auto MakeDescriptor(const std::array<index_t, NumDim>& lengths,
const std::array<index_t, NumDim>& in_strides,
const std::array<index_t, NumDim>& out_strides,
const std::array<index_t, NumDim>& desc_strides)
{
const auto M0_dim = GetLowestStrideDim(out_strides);
const auto M1_dim = GetLowestStrideDim(in_strides);
// If M0_dim is equal to M1_dim, then make M0_dim dummy
const auto M0 = M0_dim == M1_dim ? I1 : lengths[M0_dim];
const auto M1 = lengths[M1_dim];
const auto M0_stride = M0_dim == M1_dim ? I1 : desc_strides[M0_dim];
const auto M1_stride = desc_strides[M1_dim];
const auto batch_dims_lenghts = GenerateBatchDimsLenghtsTuple(lengths, M0_dim, M1_dim);
const auto batch_dims_strides = GenerateBatchDimsLenghtsTuple(desc_strides, M0_dim, M1_dim);
const auto desc = make_naive_tensor_descriptor(
concat_tuple(batch_dims_lenghts, make_tuple(M0), make_tuple(M1)),
concat_tuple(batch_dims_strides, make_tuple(M0_stride), make_tuple(M1_stride)));
// Merged batch dims with M0
const auto transforms =
make_tuple(make_merge_transform(concat_tuple(batch_dims_lenghts, make_tuple(M0))),
make_pass_through_transform(M1));
using BatchElemsSequence =
typename arithmetic_sequence_gen<0, decltype(batch_dims_lenghts)::Size() + 1, 1>::type;
const auto lower_dims = make_tuple(BatchElemsSequence{}, Sequence<NumDim>{});
const auto upper_dims = make_tuple(Sequence<0>{}, Sequence<1>{});
// desc: (merged_dims + M0, M1)
auto merged_desc = transform_tensor_descriptor(desc, transforms, lower_dims, upper_dims);
return PadInputOutputDescriptor(merged_desc);
}
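// Illustrative note (not from the commit): with lengths {L0, L1, L2}, an output whose
// contiguous dimension is 2 and an input whose contiguous dimension is 1, M0 = L2 and
// M1 = L1; the leftover L0 (plus a dummy 1) is merged into M0, producing a 2-D descriptor
// of roughly (L0 * L2, L1), each axis right-padded to a multiple of M0PerThread / M1PerThread.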
template <index_t NumTensors>
static auto GenerateInOutGridDescTuple()
{
std::array<index_t, NumDim> ones;
for(index_t d = 0; d < NumDim; d++)
{
ones[d] = 1;
}
return generate_tuple([&](auto) { return MakeDescriptor(ones, ones, ones, ones); },
Number<NumTensors>{});
};
using InGridDescTuple = decltype(GenerateInOutGridDescTuple<NumInput>());
using OutGridDescTuple = decltype(GenerateInOutGridDescTuple<NumOutput>());
using Block2TileMap = BlockToCTileMap_M00_N0_M01Adapt<M0PerBlock, M1PerBlock>;
using GridwiseElementwiseOp = GridwiseElementwise<InGridDescTuple,
OutGridDescTuple,
InDataTypePointerTuple,
OutDataTypePointerTuple,
Block2TileMap,
ElementwiseOperation,
BlockSize,
M0PerBlock,
M1PerBlock,
M0PerThread,
M1PerThread,
ThreadClusterArrangeOrder,
InScalarPerVectorSeq,
OutScalarPerVectorSeq,
false>;
using GridwiseElementwiseOpSameInOutVectorDim = GridwiseElementwise<InGridDescTuple,
OutGridDescTuple,
InDataTypePointerTuple,
OutDataTypePointerTuple,
Block2TileMap,
ElementwiseOperation,
BlockSize,
M0PerBlock,
M1PerBlock,
M0PerThread,
M1PerThread,
ThreadClusterArrangeOrder,
InScalarPerVectorSeq,
OutScalarPerVectorSeq,
true>;
struct Argument : public BaseArgument
{
Argument(const std::array<index_t, NumDim> lengths,
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
const std::array<const void*, NumInput> in_dev_buffers,
const std::array<void*, NumOutput> out_dev_buffers,
ElementwiseOperation elementwise_op)
: lengths_(lengths),
inStridesArray_(inStridesArray),
outStridesArray_(outStridesArray),
elementwise_op_(elementwise_op)
{
in_dev_buffers_ = generate_tuple(
[&](auto I) {
using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;
return static_cast<const DataType*>(in_dev_buffers[I.value]);
},
Number<NumInput>{});
out_dev_buffers_ = generate_tuple(
[&](auto I) {
using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
return static_cast<DataType*>(out_dev_buffers[I.value]);
},
Number<NumOutput>{});
}
InDataTypePointerTuple in_dev_buffers_;
OutDataTypePointerTuple out_dev_buffers_;
std::array<index_t, NumDim> lengths_;
std::array<std::array<index_t, NumDim>, NumInput> inStridesArray_;
std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray_;
ElementwiseOperation elementwise_op_;
};
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
auto in_grid_desc_tuple = generate_tuple(
[&](auto src_i) {
// Use Strides from first tensor to assert that M0 dim and
// M1 dim are the same for each tensor.
return MakeDescriptor(arg.lengths_,
arg.inStridesArray_[I0],
arg.outStridesArray_[I0],
arg.inStridesArray_[src_i]);
},
Number<NumInput>{});
auto out_grid_desc_tuple = generate_tuple(
[&](auto dst_i) {
return MakeDescriptor(arg.lengths_,
arg.inStridesArray_[I0],
arg.outStridesArray_[I0],
arg.outStridesArray_[dst_i]);
},
Number<NumOutput>{});
const index_t M0 = in_grid_desc_tuple.At(I0).GetLength(Number<I0>{});
const index_t M1 = in_grid_desc_tuple.At(I0).GetLength(Number<I1>{});
const auto block_2_tile_map = Block2TileMap(M0, M1);
const index_t grid_size = block_2_tile_map.CalculateGridSize(M0, M1);
const bool in_out_same_vector_dim = GetLowestStrideDim(arg.inStridesArray_[I0]) ==
GetLowestStrideDim(arg.outStridesArray_[I0]);
const auto kernel = in_out_same_vector_dim
? kernel_elementwise<GridwiseElementwiseOpSameInOutVectorDim,
InGridDescTuple,
OutGridDescTuple,
InDataTypePointerTuple,
OutDataTypePointerTuple,
Block2TileMap,
ElementwiseOperation>
: kernel_elementwise<GridwiseElementwiseOp,
InGridDescTuple,
OutGridDescTuple,
InDataTypePointerTuple,
OutDataTypePointerTuple,
Block2TileMap,
ElementwiseOperation>;
float elapsed_time = launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
in_grid_desc_tuple,
out_grid_desc_tuple,
arg.in_dev_buffers_,
arg.out_dev_buffers_,
block_2_tile_map,
arg.elementwise_op_);
return elapsed_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static bool IsSupportedArgument(const Argument& arg)
{
const index_t M0_dim = GetLowestStrideDim(arg.inStridesArray_[I0]);
const index_t M1_dim = GetLowestStrideDim(arg.outStridesArray_[I0]);
auto IsScalarPerVectorValid = [&](const std::array<index_t, NumDim>& lengths,
const std::array<index_t, NumDim>& strides,
index_t scalarPerVector,
index_t M_dim) {
if(scalarPerVector == 1)
{
return true;
}
if(strides[M_dim] == 1 && lengths[M_dim] % scalarPerVector == 0)
{
return true;
}
return false;
};
bool is_valid = true;
static_for<0, NumInput, 1>{}([&](auto I) {
static_assert(M0PerThread % InScalarPerVectorSeq::At(I) == 0 &&
M1PerThread % InScalarPerVectorSeq::At(I) == 0);
is_valid &= IsScalarPerVectorValid(
arg.lengths_, arg.inStridesArray_[I.value], InScalarPerVectorSeq::At(I), M0_dim);
});
static_for<0, NumOutput, 1>{}([&](auto I) {
static_assert(M0PerThread % OutScalarPerVectorSeq::At(I) == 0 &&
M1PerThread % OutScalarPerVectorSeq::At(I) == 0);
is_valid &= IsScalarPerVectorValid(
arg.lengths_, arg.outStridesArray_[I.value], OutScalarPerVectorSeq::At(I), M1_dim);
});
return is_valid;
};
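// Illustrative note (not from the commit): with InScalarPerVectorSeq = Sequence<8, ...>, an
// input passes the check above when the chosen vector dimension has stride 1 in that tensor
// and its length is divisible by 8, or when the requested scalars-per-vector is 1.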
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto
MakeArgument(const std::array<index_t, NumDim> lengths,
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
const std::array<const void*, NumInput> in_dev_buffers,
const std::array<void*, NumOutput> out_dev_buffers,
ElementwiseOperation elementwise_op)
{
return Argument{lengths,
inStridesArray,
outStridesArray,
in_dev_buffers,
out_dev_buffers,
elementwise_op};
}
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const std::array<index_t, NumDim> lengths,
const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
const std::array<const void*, NumInput> in_dev_buffers,
const std::array<void*, NumOutput> out_dev_buffers,
ElementwiseOperation elementwise_op) override
{
return std::make_unique<Argument>(lengths,
inStridesArray,
outStridesArray,
in_dev_buffers,
out_dev_buffers,
elementwise_op);
}
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
};
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceElementwiseImpl<";
str << NumDim << ", ";
str << BlockSize << ", ";
str << M0PerBlock << ", ";
str << M1PerBlock << ", ";
str << M0PerThread << ", ";
str << M1PerThread << ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
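A hedged instantiation sketch (not part of the commit) showing one way the dynamic-vector-dims elementwise op above could be specialized. The tuning numbers, the Add functor and the element-wise header path are illustrative assumptions.

#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

// Two fp32 inputs, one fp32 output, 3-D tensors. Block/tile/vector widths are example values;
// valid configurations depend on the problem shape and the target GPU.
using DeviceAdd3D = ck::tensor_operation::device::DeviceElementwiseImpl<
    ck::Tuple<float, float>,                 // InDataTypeTuple
    ck::Tuple<float>,                        // OutDataTypeTuple
    ck::tensor_operation::element_wise::Add, // ElementwiseOperation
    3,                                       // NumDim
    256,                                     // BlockSize
    128,                                     // M0PerBlock
    128,                                     // M1PerBlock
    8,                                       // M0PerThread
    8,                                       // M1PerThread
    ck::Sequence<1, 0>,                      // ThreadClusterArrangeOrder
    ck::Sequence<8, 8>,                      // InScalarPerVectorSeq  (one entry per input)
    ck::Sequence<8>>;                        // OutScalarPerVectorSeq (one entry per output)

// Arguments would then be built and launched the same way as for other CK device ops:
//   auto op      = DeviceAdd3D{};
//   auto arg     = DeviceAdd3D::MakeArgument(lengths, in_strides, out_strides,
//                                            in_ptrs, out_ptrs, {});
//   auto invoker = DeviceAdd3D::MakeInvoker();
//   if(DeviceAdd3D::IsSupportedArgument(arg)) { invoker.Run(arg, StreamConfig{}); }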
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
@@ -22,10 +22,12 @@ namespace device {
 template <typename InDataTypeTuple,
           typename OutDataTypeTuple,
           typename ElementwiseOperation,
-          index_t NumDim,
-          index_t MPerThread,
-          typename InScalarPerVectorSeq,
-          typename OutScalarPerVectorSeq>
+          index_t NumDim,                 // The max dim of input tensors
+                                          // the tensors descs have to be aligned, such that
+                                          // the innermost dim is the contiguous one.
+          index_t MPerThread,             // How many elements per thread to read
+          typename InScalarPerVectorSeq,  // Scalar per vec for each Input
+          typename OutScalarPerVectorSeq> // Scalar per vec for each Output
 struct DeviceElementwiseImpl
     : public DeviceElementwise<InDataTypeTuple, OutDataTypeTuple, ElementwiseOperation, NumDim>
 {
@@ -242,13 +244,13 @@ struct DeviceElementwiseImpl
         static_for<0, NumInput, 1>{}([&](auto I) {
             if(!IsScalarPerVectorValid(
                    arg.lengths_, arg.inStridesArray_[I.value], InScalarPerVectorSeq::At(I)))
-                valid = false;
+                valid = valid && false;
         });
 
         static_for<0, NumOutput, 1>{}([&](auto I) {
             if(!IsScalarPerVectorValid(
                    arg.lengths_, arg.outStridesArray_[I.value], OutScalarPerVectorSeq::At(I)))
-                valid = false;
+                valid = valid && false;
         });
 
         return valid;
...
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
@@ -254,12 +254,13 @@
           index_t CShuffleNXdlPerWavePerShuffle,
           typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CDEBlockTransferScalarPerVector_NPerBlock,
-          typename ComputeDataType =
+          typename AComputeDataType =
              decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value,
                                      Number<0>,
                                      ADataType>()), // ComputeType is InputType by default (first
                                                     // in tuple for MultiAB), unpack if tuple was
                                                     // passed
+          typename BComputeDataType = AComputeDataType,
          LoopScheduler LoopSched = make_default_loop_scheduler()>
 struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
     : public DeviceGroupedConvFwdMultipleABD<NDimSpatial,
@@ -274,7 +275,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                                              AElementwiseOperation,
                                              BElementwiseOperation,
                                              CDEElementwiseOperation,
-                                             ComputeDataType>
+                                             AComputeDataType,
+                                             BComputeDataType>
 {
     using DeviceOp = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle;
@@ -386,7 +388,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
     using GemmBDataType = std::conditional_t<!isMultiB && isMultiA, Tuple<BDataType>, BDataType>;
 
 #define GridwiseGemmTemplateParameters                                                           \
-    GemmADataType, GemmBDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType,    \
+    GemmADataType, GemmBDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType,   \
         EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation,        \
         InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock,  \
         KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave,                         \
@@ -399,7 +401,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
         BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN,                            \
         CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle,                            \
         CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,                        \
-        CDEBlockTransferScalarPerVector_NPerBlock, LoopSched
+        CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1,               \
+        BComputeDataType
 
     // Use appropriate gridwise gemm
     using GridwiseGemm =
         std::conditional_t<isMultiA || isMultiB,
...
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
@@ -75,12 +75,13 @@ template <index_t NDimSpatial,
           index_t CShuffleNXdlPerWavePerShuffle,
           typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CDEBlockTransferScalarPerVector_NPerBlock,
-          typename ComputeDataType =
+          typename AComputeDataType =
              decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value,
                                      Number<0>,
                                      ADataType>()), // ComputeType is InputType by default (first
                                                     // in tuple for MultiAB), unpack if tuple was
                                                     // passed
+          typename BComputeDataType = AComputeDataType,
           LoopScheduler LoopSched = make_default_loop_scheduler()>
 using DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
     NDimSpatial,
@@ -128,7 +129,8 @@ using DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = DeviceGroupedConvFwdMultipl
     CShuffleNXdlPerWavePerShuffle,
     CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
     CDEBlockTransferScalarPerVector_NPerBlock,
-    ComputeDataType,
+    AComputeDataType,
+    BComputeDataType,
     LoopSched>;
 
 } // namespace device
...