Unverified Commit 29dcb956 authored by Illia Silin's avatar Illia Silin Committed by GitHub
Browse files

Merge pull request #33 from ROCm/lwpck-1292

Merge from the public repo.
parents 29deceb6 cbcc844e
...@@ -770,8 +770,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle ...@@ -770,8 +770,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
if(ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" || if(ck::is_navi3_supported())
ck::get_device_name() == "gfx1102")
{ {
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>)) if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
{ {
......
...@@ -57,7 +57,7 @@ __global__ void ...@@ -57,7 +57,7 @@ __global__ void
const Block2ETileMap block_2_etile_map) const Block2ETileMap block_2_etile_map)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) defined(__gfx94__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t num_blocks_per_batch = const index_t num_blocks_per_batch =
......
...@@ -75,7 +75,7 @@ __global__ void ...@@ -75,7 +75,7 @@ __global__ void
const Block2ETileMap block_2_etile_map) const Block2ETileMap block_2_etile_map)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) defined(__gfx94__))
const index_t num_blocks_per_batch = const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
......
...@@ -61,7 +61,7 @@ __global__ void ...@@ -61,7 +61,7 @@ __global__ void
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch) const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) defined(__gfx94__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t num_blocks_per_batch = const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
......
...@@ -84,7 +84,7 @@ __global__ void ...@@ -84,7 +84,7 @@ __global__ void
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) defined(__gfx94__))
const index_t num_blocks_per_batch = const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
......
...@@ -70,9 +70,8 @@ __global__ void ...@@ -70,9 +70,8 @@ __global__ void
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
const Block2CTileMap block_2_ctile_map) const Block2CTileMap block_2_ctile_map)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx1030__) || defined(__gfx1100__) || \ defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__))
defined(__gfx1101__) || defined(__gfx1102__))
const index_t num_blocks_per_batch = const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
...@@ -648,11 +647,8 @@ struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD<ALayout, ...@@ -648,11 +647,8 @@ struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD<ALayout,
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
// TODO: Enable for gfx90a after complier fix if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
if(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx90a" || ck::is_navi2_supported() || ck::is_navi3_supported())
ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx1030" ||
ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx1100" ||
ck::get_device_name() == "gfx1101" || ck::get_device_name() == "gfx1102")
{ {
bool pass = true; bool pass = true;
pass = pass && arg.K_ % K1 == 0; pass = pass && arg.K_ % K1 == 0;
......
...@@ -69,7 +69,7 @@ __global__ void ...@@ -69,7 +69,7 @@ __global__ void
const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch) const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) defined(__gfx94__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t num_blocks_per_batch = const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
......
...@@ -60,7 +60,7 @@ __global__ void ...@@ -60,7 +60,7 @@ __global__ void
const Block2CTileMap block_2_ctile_map) const Block2CTileMap block_2_ctile_map)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) defined(__gfx94__))
const index_t num_blocks_per_batch = const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
......
...@@ -68,7 +68,7 @@ __global__ void ...@@ -68,7 +68,7 @@ __global__ void
const C0MatrixMask c0_matrix_mask) const C0MatrixMask c0_matrix_mask)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) defined(__gfx94__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t num_blocks_per_batch = const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
......
...@@ -63,7 +63,7 @@ __global__ void ...@@ -63,7 +63,7 @@ __global__ void
const C0MatrixMask c0_matrix_mask) const C0MatrixMask c0_matrix_mask)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) defined(__gfx94__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
const index_t num_blocks_per_batch = const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
......
...@@ -53,7 +53,7 @@ __global__ void ...@@ -53,7 +53,7 @@ __global__ void
kernel_batched_gemm_xdlops_v2r3(const typename DeviceOp::Argument karg) kernel_batched_gemm_xdlops_v2r3(const typename DeviceOp::Argument karg)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) defined(__gfx94__))
const index_t num_blocks_per_batch = const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / karg.Batch); __builtin_amdgcn_readfirstlane(get_grid_size() / karg.Batch);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
......
...@@ -376,7 +376,9 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<XDataType, ...@@ -376,7 +376,9 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<XDataType,
return (workspace_size); return (workspace_size);
}; };
void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override void SetWorkSpacePointer(BaseArgument* pArg,
void* p_workspace,
const StreamConfig& = StreamConfig{}) const override
{ {
Argument* pArg_ = dynamic_cast<Argument*>(pArg); Argument* pArg_ = dynamic_cast<Argument*>(pArg);
......
...@@ -354,7 +354,9 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType, ...@@ -354,7 +354,9 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
return (workspace_size); return (workspace_size);
}; };
void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override void SetWorkSpacePointer(BaseArgument* pArg,
void* p_workspace,
const StreamConfig& = StreamConfig{}) const override
{ {
Argument* pArg_ = dynamic_cast<Argument*>(pArg); Argument* pArg_ = dynamic_cast<Argument*>(pArg);
......
...@@ -345,7 +345,9 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType, ...@@ -345,7 +345,9 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
return (workspace_size); return (workspace_size);
}; };
void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override void SetWorkSpacePointer(BaseArgument* pArg,
void* p_workspace,
const StreamConfig& = StreamConfig{}) const override
{ {
Argument* pArg_ = dynamic_cast<Argument*>(pArg); Argument* pArg_ = dynamic_cast<Argument*>(pArg);
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_abd.hpp" #include "ck/tensor_operation/gpu/device/device_contraction_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
...@@ -55,7 +56,7 @@ __global__ void ...@@ -55,7 +56,7 @@ __global__ void
const Block2ETileMap block_2_etile_map) const Block2ETileMap block_2_etile_map)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) defined(__gfx94__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_as_grid, GridwiseGemm::template Run<HasMainKBlockLoop>(p_as_grid,
...@@ -500,22 +501,29 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle ...@@ -500,22 +501,29 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
// for sanity check of vector memory access // for sanity check of vector memory access
for(index_t i = 0; i < NumATensor; ++i) for(index_t i = 0; i < NumATensor; ++i)
{ {
a_mz_stride_[i] = a_ms_ks_strides[i][NumDimM - 1]; as_mz_consecutive_[i] = a_ms_ks_strides[i][NumDimM - 1] == 1;
a_kz_stride_[i] = a_ms_ks_strides[i][NumDimM + NumDimK - 1]; as_kz_consecutive_[i] = a_ms_ks_strides[i][NumDimM + NumDimK - 1] == 1;
as_max_read_elems_[i] =
CalculateMaxRead<NumDimM, NumDimK>(a_ms_ks_lengths[i], a_ms_ks_strides[i]);
} }
for(index_t i = 0; i < NumBTensor; ++i) for(index_t i = 0; i < NumBTensor; ++i)
{ {
b_nz_stride_[i] = b_ns_ks_strides[i][NumDimN - 1]; bs_nz_consecutive_[i] = b_ns_ks_strides[i][NumDimN - 1] == 1;
b_kz_stride_[i] = b_ns_ks_strides[i][NumDimN + NumDimK - 1]; bs_kz_consecutive_[i] = b_ns_ks_strides[i][NumDimN + NumDimK - 1] == 1;
bs_max_read_elems_[i] =
CalculateMaxRead<NumDimN, NumDimK>(b_ns_ks_lengths[i], b_ns_ks_strides[i]);
} }
for(index_t i = 0; i < NumDTensor; ++i) for(index_t i = 0; i < NumDTensor; ++i)
{ {
ds_nz_stride_[i] = d_ms_ns_strides[i][NumDimM + NumDimN - 1]; ds_nz_consecutive_[i] = d_ms_ns_strides[i][NumDimM + NumDimN - 1] == 1;
ds_max_read_elems_[i] =
CalculateMaxRead<NumDimM, NumDimN>(d_ms_ns_lengths[i], d_ms_ns_strides[i]);
} }
e_nz_stride_ = e_ms_ns_stride[NumDimM + NumDimN - 1]; e_nz_consecutive_ = e_ms_ns_stride[NumDimM + NumDimN - 1] == 1;
e_max_write_elems_ = CalculateMaxRead<NumDimM, NumDimN>(e_ms_ns_length, e_ms_ns_stride);
} }
// pointers // pointers
...@@ -545,16 +553,19 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle ...@@ -545,16 +553,19 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
BElementwiseOperation b_element_op_; BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_; CDEElementwiseOperation cde_element_op_;
// Strides for the last M/N/K dimensions of A/B/Ds/E // Describe whether the last part of a given dimension of A/B/D/E is consecutive
// for sanity check of vector load/store // in the memory or not.
std::array<index_t, NumATensor> a_mz_stride_; std::array<bool, NumATensor> as_mz_consecutive_;
std::array<index_t, NumATensor> a_kz_stride_; std::array<bool, NumATensor> as_kz_consecutive_;
std::array<bool, NumBTensor> bs_nz_consecutive_;
std::array<index_t, NumBTensor> b_nz_stride_; std::array<bool, NumBTensor> bs_kz_consecutive_;
std::array<index_t, NumBTensor> b_kz_stride_; std::array<bool, NumDTensor> ds_nz_consecutive_;
bool e_nz_consecutive_;
std::array<index_t, NumDTensor> ds_nz_stride_;
index_t e_nz_stride_; std::array<index_t, NumATensor> as_max_read_elems_;
std::array<index_t, NumBTensor> bs_max_read_elems_;
std::array<index_t, NumDTensor> ds_max_read_elems_;
index_t e_max_write_elems_;
}; };
// Invoker // Invoker
...@@ -643,73 +654,65 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle ...@@ -643,73 +654,65 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
// check vector load/store // check vector load/store
{ {
bool all_valid = true; bool valid_as_access = true;
static_for<0, NumATensor, 1>{}([&](auto i) { static_for<0, NumATensor, 1>{}([&](auto i) {
// vector memory access of A: could be on M or AK1 dimension const bool valid_a_vector_size =
if constexpr(ABlockTransferSrcVectorDim == 1) arg.as_max_read_elems_[i] % ABlockTransferSrcScalarPerVector == 0;
{ const bool valid_a_access_dim_m =
if(!(arg.a_mz_stride_[i] == 1 && arg.as_grid_desc_ak0_m_ak1_[i].GetLength(I1) % ABlockTransferSrcVectorDim == 1 && arg.as_mz_consecutive_[i];
ABlockTransferSrcScalarPerVector == const bool valid_a_access_dim_k =
0)) ABlockTransferSrcVectorDim == 2 && arg.as_kz_consecutive_[i];
{ const bool valid_a_access_dim = valid_a_access_dim_m || valid_a_access_dim_k;
all_valid = false; if(!(valid_a_vector_size && valid_a_access_dim))
}
}
else
{ {
if(!(arg.a_kz_stride_[i] == 1 && arg.as_grid_desc_ak0_m_ak1_[i].GetLength(I2) % valid_as_access = false;
ABlockTransferSrcScalarPerVector ==
0))
{
all_valid = false;
}
} }
}); });
if(!valid_as_access)
{
return false;
}
// vector memory access of B: could be on N or BK1 dimension bool valid_bs_access = true;
static_for<0, NumBTensor, 1>{}([&](auto i) { static_for<0, NumBTensor, 1>{}([&](auto i) {
if constexpr(BBlockTransferSrcVectorDim == 1) const bool valid_b_vector_size =
arg.bs_max_read_elems_[i] % BBlockTransferSrcScalarPerVector == 0;
const bool valid_b_access_dim_n =
BBlockTransferSrcVectorDim == 1 && arg.bs_nz_consecutive_[i];
const bool valid_b_access_dim_k =
BBlockTransferSrcVectorDim == 2 && arg.bs_kz_consecutive_[i];
const bool valid_b_access_dim = valid_b_access_dim_n || valid_b_access_dim_k;
if(!(valid_b_vector_size && valid_b_access_dim))
{ {
if(!(arg.b_nz_stride_[i] == 1 && arg.bs_grid_desc_bk0_n_bk1_[i].GetLength(I1) % valid_bs_access = false;
BBlockTransferSrcScalarPerVector ==
0))
{
all_valid = false;
}
}
else
{
if(!(arg.b_kz_stride_[i] == 1 && arg.bs_grid_desc_bk0_n_bk1_[i].GetLength(I2) %
BBlockTransferSrcScalarPerVector ==
0))
{
all_valid = false;
}
} }
}); });
if(!valid_bs_access)
{
return false;
}
// check vector load of Ds bool valid_ds_access = true;
static_for<0, NumDTensor, 1>{}([&](auto i) { static_for<0, NumDTensor, 1>{}([&](auto i) {
if(!(arg.ds_nz_stride_[i] == 1 && const bool valid_d_vector_size =
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[i].GetLength(I3) % arg.ds_max_read_elems_[i] % CDEBlockTransferScalarPerVector_NPerBlock == 0;
CDEBlockTransferScalarPerVector_NPerBlock == // Vector read of Ds is always on N dimension.
0)) const bool valid_d_access_dim = arg.ds_nz_consecutive_[i];
if(!(valid_d_vector_size && valid_d_access_dim))
{ {
all_valid = false; valid_ds_access = false;
} }
}); });
if(!valid_ds_access)
// vector memory access of E: always on NPerBlock dimension
if(!(arg.e_nz_stride_ == 1 &&
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_.GetLength(I3) %
CDEBlockTransferScalarPerVector_NPerBlock ==
0))
{ {
all_valid = false; return false;
} }
if(!all_valid) const bool valid_e_vector_size =
arg.e_max_write_elems_ % CDEBlockTransferScalarPerVector_NPerBlock == 0;
// Vector write of E is always on N dimension.
const bool valid_e_access_dim = arg.e_nz_consecutive_;
if(!(valid_e_vector_size && valid_e_access_dim))
{ {
return false; return false;
} }
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" #include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
...@@ -53,7 +54,7 @@ __global__ void ...@@ -53,7 +54,7 @@ __global__ void
const Block2ETileMap block_2_etile_map) const Block2ETileMap block_2_etile_map)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) defined(__gfx94__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid, GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
...@@ -183,7 +184,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -183,7 +184,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
return generate_tuple([&](auto i) { return vec[i]; }, num); return generate_tuple([&](auto i) { return vec[i]; }, num);
}; };
const auto a_ms_ns_lengths = to_tuple(a_ms_ks_lengths_vec, Number<NumDimM + NumDimK>{}); const auto a_ms_ks_lengths = to_tuple(a_ms_ks_lengths_vec, Number<NumDimM + NumDimK>{});
const auto a_ms_ks_strides = to_tuple(a_ms_ks_strides_vec, Number<NumDimM + NumDimK>{}); const auto a_ms_ks_strides = to_tuple(a_ms_ks_strides_vec, Number<NumDimM + NumDimK>{});
// dimension Ids for M0, M1, ... // dimension Ids for M0, M1, ...
...@@ -194,14 +195,14 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -194,14 +195,14 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
typename arithmetic_sequence_gen<NumDimM, NumDimM + NumDimK, 1>::type{}; typename arithmetic_sequence_gen<NumDimM, NumDimM + NumDimK, 1>::type{};
// lengths for M0, M1, ... // lengths for M0, M1, ...
const auto mLengths = get_container_subset(a_ms_ns_lengths, mDimIds); const auto mLengths = get_container_subset(a_ms_ks_lengths, mDimIds);
// lengths for K0, K1, ... // lengths for K0, K1, ...
const auto kLengths = get_container_subset(a_ms_ns_lengths, kDimIds); const auto kLengths = get_container_subset(a_ms_ks_lengths, kDimIds);
// naive tensor A[M0, M1, M2, ..., K0, K1, K2...] // naive tensor A[M0, M1, M2, ..., K0, K1, K2...]
const auto a_grid_desc_ms_ks = const auto a_grid_desc_ms_ks =
make_naive_tensor_descriptor(a_ms_ns_lengths, a_ms_ks_strides); make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides);
// transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...] // transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...]
const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor( const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor(
...@@ -383,7 +384,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -383,7 +384,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
const void* p_b_grid, const void* p_b_grid,
std::array<const void*, NumDTensor> p_ds_grid, std::array<const void*, NumDTensor> p_ds_grid,
void* p_e_grid, void* p_e_grid,
const std::vector<index_t>& a_ms_ns_lengths, const std::vector<index_t>& a_ms_ks_lengths,
const std::vector<index_t>& a_ms_ks_strides, const std::vector<index_t>& a_ms_ks_strides,
const std::vector<index_t>& b_ns_ks_lengths, const std::vector<index_t>& b_ns_ks_lengths,
const std::vector<index_t>& b_ns_ks_strides, const std::vector<index_t>& b_ns_ks_strides,
...@@ -398,7 +399,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -398,7 +399,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
p_b_grid_{static_cast<const BDataType*>(p_b_grid)}, p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
p_ds_grid_{}, p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e_grid)}, p_e_grid_{static_cast<EDataType*>(p_e_grid)},
a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(a_ms_ns_lengths, a_ms_ks_strides)}, a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(a_ms_ks_lengths, a_ms_ks_strides)},
b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(b_ns_ks_lengths, b_ns_ks_strides)}, b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(b_ns_ks_lengths, b_ns_ks_strides)},
ds_grid_desc_m_n_{}, ds_grid_desc_m_n_{},
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_ms_ns_lengths, e_ms_ns_strides)}, e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_ms_ns_lengths, e_ms_ns_strides)},
...@@ -411,13 +412,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -411,13 +412,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
a_element_op_{a_element_op}, a_element_op_{a_element_op},
b_element_op_{b_element_op}, b_element_op_{b_element_op},
cde_element_op_{cde_element_op}, cde_element_op_{cde_element_op}
a_mz_stride_{},
a_kz_stride_{},
b_nz_stride_{},
b_kz_stride_{},
ds_nz_stride_{},
e_nz_stride_{}
{ {
// populate pointer, batch stride, desc for Ds // populate pointer, batch stride, desc for Ds
static_for<0, NumDTensor, 1>{}([&](auto i) { static_for<0, NumDTensor, 1>{}([&](auto i) {
...@@ -448,18 +443,26 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -448,18 +443,26 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
} }
// for sanity check of vector memory access // for sanity check of vector memory access
a_mz_stride_ = a_ms_ks_strides[NumDimM - 1]; a_mz_consecutive_ = a_ms_ks_strides[NumDimM - 1] == 1;
a_kz_stride_ = a_ms_ks_strides[NumDimM + NumDimK - 1]; a_kz_consecutive_ = a_ms_ks_strides[NumDimM + NumDimK - 1] == 1;
a_max_read_elems_ =
CalculateMaxRead<NumDimM, NumDimK>(a_ms_ks_lengths, a_ms_ks_strides);
b_nz_stride_ = b_ns_ks_strides[NumDimN - 1]; b_nz_consecutive_ = b_ns_ks_strides[NumDimN - 1] == 1;
b_kz_stride_ = b_ns_ks_strides[NumDimN + NumDimK - 1]; b_kz_consecutive_ = b_ns_ks_strides[NumDimN + NumDimK - 1] == 1;
b_max_read_elems_ =
CalculateMaxRead<NumDimN, NumDimK>(b_ns_ks_lengths, b_ns_ks_strides);
for(index_t i = 0; i < NumDTensor; ++i) for(index_t i = 0; i < NumDTensor; ++i)
{ {
ds_nz_stride_[i] = ds_ms_ns_strides[i][NumDimM + NumDimN - 1]; ds_nz_consecutive_[i] = ds_ms_ns_strides[i][NumDimM + NumDimN - 1] == 1;
ds_max_read_elems_[i] =
CalculateMaxRead<NumDimM, NumDimN>(ds_ms_ns_lengths[i], ds_ms_ns_strides[i]);
} }
e_nz_stride_ = e_ms_ns_strides[NumDimM + NumDimN - 1]; e_nz_consecutive_ = e_ms_ns_strides[NumDimM + NumDimN - 1] == 1;
e_max_write_elems_ =
CalculateMaxRead<NumDimM, NumDimN>(e_ms_ns_lengths, e_ms_ns_strides);
} }
void Print() const void Print() const
...@@ -499,15 +502,19 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -499,15 +502,19 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
BElementwiseOperation b_element_op_; BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_; CDEElementwiseOperation cde_element_op_;
// Strides for the last M/N/K dimensions of A/B/Ds/E // Describe whether the last part of a given dimension of A/B/D/E is consecutive
// for sanity check of vector load/store // in the memory or not.
index_t a_mz_stride_; bool a_mz_consecutive_;
index_t a_kz_stride_; bool a_kz_consecutive_;
index_t b_nz_stride_; bool b_nz_consecutive_;
index_t b_kz_stride_; bool b_kz_consecutive_;
std::array<index_t, NumDTensor> ds_nz_stride_; std::array<bool, NumDTensor> ds_nz_consecutive_;
index_t e_mz_stride_; bool e_nz_consecutive_;
index_t e_nz_stride_;
index_t a_max_read_elems_;
index_t b_max_read_elems_;
std::array<index_t, NumDTensor> ds_max_read_elems_;
index_t e_max_write_elems_;
}; };
// Invoker // Invoker
...@@ -616,65 +623,47 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -616,65 +623,47 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
(BBlockTransferSrcVectorDim == 1 || BBlockTransferSrcVectorDim == 2), (BBlockTransferSrcVectorDim == 1 || BBlockTransferSrcVectorDim == 2),
"wrong!"); "wrong!");
// vector memory access of A: could be on M or AK1 dimension const bool valid_a_vector_size =
if constexpr(ABlockTransferSrcVectorDim == 1) arg.a_max_read_elems_ % ABlockTransferSrcScalarPerVector == 0;
const bool valid_a_access_dim_m = ABlockTransferSrcVectorDim == 1 && arg.a_mz_consecutive_;
const bool valid_a_access_dim_k = ABlockTransferSrcVectorDim == 2 && arg.a_kz_consecutive_;
const bool valid_a_access_dim = valid_a_access_dim_m || valid_a_access_dim_k;
if(!(valid_a_vector_size && valid_a_access_dim))
{ {
if(!(arg.a_mz_stride_ == 1 && return false;
arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) % ABlockTransferSrcScalarPerVector == 0))
{
return false;
}
}
else
{
if(!(arg.a_kz_stride_ == 1 &&
arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0))
{
return false;
}
} }
// vector memory access of B: could be on N or BK1 dimension const bool valid_b_vector_size =
if constexpr(BBlockTransferSrcVectorDim == 1) arg.b_max_read_elems_ % BBlockTransferSrcScalarPerVector == 0;
{ const bool valid_b_access_dim_n = BBlockTransferSrcVectorDim == 1 && arg.b_nz_consecutive_;
if(!(arg.b_nz_stride_ == 1 && const bool valid_b_access_dim_k = BBlockTransferSrcVectorDim == 2 && arg.b_kz_consecutive_;
arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) % BBlockTransferSrcScalarPerVector == 0)) const bool valid_b_access_dim = valid_b_access_dim_n || valid_b_access_dim_k;
{ if(!(valid_b_vector_size && valid_b_access_dim))
return false;
}
}
else
{ {
if(!(arg.b_kz_stride_ == 1 && return false;
arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) % BBlockTransferSrcScalarPerVector == 0))
{
return false;
}
} }
// vector memory access of Ds: always on NPerBlock dimension bool valid_ds_access = true;
bool valid_d_access = true;
static_for<0, NumDTensor, 1>{}([&](auto i) { static_for<0, NumDTensor, 1>{}([&](auto i) {
if(!(arg.ds_nz_stride_[i] == 1 && const bool valid_d_vector_size =
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[i].GetLength(I3) % arg.ds_max_read_elems_[i] % CDEBlockTransferScalarPerVector_NPerBlock == 0;
CDEBlockTransferScalarPerVector_NPerBlock == // Vector read of Ds is always on N dimension.
0)) const bool valid_d_access_dim = arg.ds_nz_consecutive_[i];
if(!(valid_d_vector_size && valid_d_access_dim))
{ {
valid_d_access = false; valid_ds_access = false;
} }
}); });
if(!valid_ds_access)
if(valid_d_access == false)
{ {
return false; return false;
} }
// vector memory access of E: always on NPerBlock dimension const bool valid_e_vector_size =
if(!(arg.e_nz_stride_ == 1 && arg.e_max_write_elems_ % CDEBlockTransferScalarPerVector_NPerBlock == 0;
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_.GetLength(I3) % // Vector write of E is always on N dimension.
CDEBlockTransferScalarPerVector_NPerBlock == const bool valid_e_access_dim = arg.e_nz_consecutive_;
0)) if(!(valid_e_vector_size && valid_e_access_dim))
{ {
return false; return false;
} }
...@@ -692,7 +681,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -692,7 +681,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
const void* p_b, const void* p_b,
std::array<const void*, NumDTensor> p_ds, std::array<const void*, NumDTensor> p_ds,
void* p_e, void* p_e,
const std::vector<index_t>& a_ms_ns_lengths, const std::vector<index_t>& a_ms_ks_lengths,
const std::vector<index_t>& a_ms_ks_strides, const std::vector<index_t>& a_ms_ks_strides,
const std::vector<index_t>& b_ns_ks_lengths, const std::vector<index_t>& b_ns_ks_lengths,
const std::vector<index_t>& b_ns_ks_strides, const std::vector<index_t>& b_ns_ks_strides,
...@@ -708,7 +697,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -708,7 +697,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
p_b, p_b,
p_ds, p_ds,
p_e, p_e,
a_ms_ns_lengths, a_ms_ks_lengths,
a_ms_ks_strides, a_ms_ks_strides,
b_ns_ks_lengths, b_ns_ks_lengths,
b_ns_ks_strides, b_ns_ks_strides,
...@@ -729,7 +718,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -729,7 +718,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
const void* p_b, const void* p_b,
std::array<const void*, NumDTensor> p_ds, std::array<const void*, NumDTensor> p_ds,
void* p_e, void* p_e,
const std::vector<index_t>& a_ms_ns_lengths, const std::vector<index_t>& a_ms_ks_lengths,
const std::vector<index_t>& a_ms_ks_strides, const std::vector<index_t>& a_ms_ks_strides,
const std::vector<index_t>& b_ns_ks_lengths, const std::vector<index_t>& b_ns_ks_lengths,
const std::vector<index_t>& b_ns_ks_strides, const std::vector<index_t>& b_ns_ks_strides,
...@@ -745,7 +734,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -745,7 +734,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
p_b, p_b,
p_ds, p_ds,
p_e, p_e,
a_ms_ns_lengths, a_ms_ks_lengths,
a_ms_ks_strides, a_ms_ks_strides,
b_ns_ks_lengths, b_ns_ks_lengths,
b_ns_ks_strides, b_ns_ks_strides,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cassert>
#include <sstream>
#include <vector>
#include "ck/ck.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
/**
* Calculates the maximum number of subsequent elements of the fast changing dimension
* that are consecutive in memory.
*
* Example:
* NumDimM = 2, NumDimK = 3
* A shape = [ 2, 3, 4, 5, 6]
* A strides = [360, 120, 30, 6, 1]
* | M | | K |
* It follows from strides that K is FCD and all the subsequent elements of K are consecutive
* in memory.
* But if strides were [360, 120, 6, 24, 1], then only 6 subsequent elements of K would be
* consecutive in memory.
*
* Assumes that the dimensions are split into two groups of `NumDim1` and `NumDim2` dimensions.
*/
template <index_t NumDim1, index_t NumDim2>
auto CalculateMaxRead(const std::vector<index_t>& lengths, const std::vector<index_t>& strides)
{
if(lengths.size() != NumDim1 + NumDim2)
{
std::ostringstream err;
err << "Incorrect number of lengths in "
<< "device_contraction_utils.hpp"
<< ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
if(strides.size() != NumDim1 + NumDim2)
{
std::ostringstream err;
err << "Incorrect number of strides in "
<< "device_contraction_utils.hpp"
<< ":" << __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
// Determine the beginning and end idx of the group representing the FCD.
index_t begin_idx, end_idx;
if(strides[NumDim1 - 1] == 1)
{
begin_idx = 0;
end_idx = NumDim1 - 1;
}
else if(strides[NumDim1 + NumDim2 - 1] == 1)
{
begin_idx = NumDim1;
end_idx = NumDim1 + NumDim2 - 1;
}
else
{
// The dimension consecutive in memory is not the last dimension of any group, so only
// one element can be read/written at once.
return 1;
}
index_t consecutive_stride = 1;
for(index_t dim_idx = end_idx; dim_idx >= begin_idx; --dim_idx)
{
if(strides[dim_idx] == consecutive_stride)
{
consecutive_stride *= lengths[dim_idx];
}
else
{
break;
}
}
const index_t max_subsequent_elems = consecutive_stride;
return max_subsequent_elems;
}
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -56,7 +56,7 @@ __global__ void ...@@ -56,7 +56,7 @@ __global__ void
const Block2CTileMap block_2_ctile_map) const Block2CTileMap block_2_ctile_map)
{ {
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) defined(__gfx94__))
const index_t num_blocks_per_batch = const index_t num_blocks_per_batch =
__builtin_amdgcn_readfirstlane(get_grid_size() / num_batches); __builtin_amdgcn_readfirstlane(get_grid_size() / num_batches);
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
......
...@@ -1393,9 +1393,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl ...@@ -1393,9 +1393,8 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Dl
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
// check device // check device
if(!(ck::get_device_name() == "gfx906" || ck::get_device_name() == "gfx1030" || if(!(ck::get_device_name() == "gfx906" || ck::is_navi2_supported() ||
ck::get_device_name() == "gfx1100" || ck::get_device_name() == "gfx1101" || ck::is_navi3_supported()))
ck::get_device_name() == "gfx1102"))
{ {
return false; return false;
} }
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/stream_utility.hpp" #include "ck/host_utility/stream_utility.hpp"
namespace ck { namespace ck {
...@@ -292,6 +293,12 @@ struct DeviceElementwise3dImpl : public DeviceElementwise<InDataTypeTuple, ...@@ -292,6 +293,12 @@ struct DeviceElementwise3dImpl : public DeviceElementwise<InDataTypeTuple,
bool IsSupportedArgument(const BaseArgument* p_arg) override bool IsSupportedArgument(const BaseArgument* p_arg) override
{ {
if((ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942"))
{
return false;
}
const Argument* pArg = dynamic_cast<const Argument*>(p_arg); const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
if(pArg == nullptr) if(pArg == nullptr)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment