Commit 017fb2eb authored by muozturk

cmake list

parents 7abb7439 3a3b98ef
@@ -46,9 +46,9 @@ int run_layernorm4d_fwd_example()
        {0, W * C, C, 1},
        std::vector<ck::index_t>{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()},
        std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(),
                                 save_mean.mDesc.GetStrides().end()},
        std::vector<ck::index_t>{save_mean.mDesc.GetStrides().begin(),
                                 save_mean.mDesc.GetStrides().end()},
        {1, 2, 3},
        1e-4,
        x_dev.GetDeviceBuffer(),
...
@@ -134,6 +134,12 @@
// inner product using V_DOT with DPP8 modifiers
#define CK_USE_AMD_V_DOT_DPP8_INLINE_ASM 1
+// LDS direct loads using inline assembly
+#define CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM 1
+// set stochastic rounding as default for f8 conversions
+#define CK_USE_SR_F8_CONVERSION 1
// block synchronization only s_wait lgkmcnt(0), not vmcnt(0)
#define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
...
@@ -26,7 +26,7 @@ inline std::string get_device_name()
    }
    const std::string raw_name(props.gcnArchName);
-   // https://github.com/ROCmSoftwarePlatform/MIOpen/blob/8498875aef84878e04c1eabefdf6571514891086/src/target_properties.cpp#L40
+   // https://github.com/ROCm/MIOpen/blob/8498875aef84878e04c1eabefdf6571514891086/src/target_properties.cpp#L40
    static std::map<std::string, std::string> device_name_map = {
        {"Ellesmere", "gfx803"},
        {"Baffin", "gfx803"},
...
@@ -11,6 +11,6 @@ struct StreamConfig
    hipStream_t stream_id_ = nullptr;
    bool time_kernel_ = false;
    int log_level_ = 0;
-   int cold_niters_ = 50;
-   int nrepeat_ = 200;
+   int cold_niters_ = 1;
+   int nrepeat_ = 10;
};
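The default warm-up and measurement counts drop from 50/200 to 1/10. A minimal sketch, assuming only the fields visible above, of how a caller that still wants the previous timing behaviour could configure it (the invoker variable and its arguments are hypothetical):

    StreamConfig stream_config{};
    stream_config.time_kernel_ = true; // measure kernel time
    stream_config.cold_niters_ = 50;   // previous default number of warm-up runs
    stream_config.nrepeat_     = 200;  // previous default number of timed runs
    // float avg_ms = invoker.Run(argument, stream_config); // hypothetical invoker call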
@@ -59,7 +59,9 @@ struct BaseOperator
    virtual size_t GetWorkSpaceSize(const BaseArgument*) const { return 0; }
-   virtual void SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace) const
+   virtual void SetWorkSpacePointer(BaseArgument* p_arg,
+                                    void* p_workspace,
+                                    const StreamConfig& = StreamConfig{}) const
    {
        assert(p_arg);
        p_arg->p_workspace_ = p_workspace;
...
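The new StreamConfig parameter is defaulted, so existing two-argument call sites keep compiling, while overrides that need the stream (see the grouped GEMM change further below) can use it. A hedged usage sketch; `op`, `arg` and the allocation flow are hypothetical, only the overload shown above is from the commit:

    void* p_ws = nullptr;
    hip_check_error(hipMalloc(&p_ws, op.GetWorkSpaceSize(&arg)));
    op.SetWorkSpacePointer(&arg, p_ws);       // uses the default StreamConfig
    StreamConfig cfg{};
    cfg.stream_id_ = my_stream;               // hypothetical stream
    op.SetWorkSpacePointer(&arg, p_ws, cfg);  // or pass an explicit stream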
@@ -376,7 +376,9 @@ struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<XDataType,
        return (workspace_size);
    };
-   void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override
+   void SetWorkSpacePointer(BaseArgument* pArg,
+                            void* p_workspace,
+                            const StreamConfig& = StreamConfig{}) const override
    {
        Argument* pArg_ = dynamic_cast<Argument*>(pArg);
...
@@ -354,7 +354,9 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
        return (workspace_size);
    };
-   void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override
+   void SetWorkSpacePointer(BaseArgument* pArg,
+                            void* p_workspace,
+                            const StreamConfig& = StreamConfig{}) const override
    {
        Argument* pArg_ = dynamic_cast<Argument*>(pArg);
...
@@ -345,7 +345,9 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
        return (workspace_size);
    };
-   void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override
+   void SetWorkSpacePointer(BaseArgument* pArg,
+                            void* p_workspace,
+                            const StreamConfig& = StreamConfig{}) const override
    {
        Argument* pArg_ = dynamic_cast<Argument*>(pArg);
...
@@ -14,6 +14,7 @@
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
@@ -500,22 +501,29 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
            // for sanity check of vector memory access
            for(index_t i = 0; i < NumATensor; ++i)
            {
-               a_mz_stride_[i] = a_ms_ks_strides[i][NumDimM - 1];
-               a_kz_stride_[i] = a_ms_ks_strides[i][NumDimM + NumDimK - 1];
+               as_mz_consecutive_[i] = a_ms_ks_strides[i][NumDimM - 1] == 1;
+               as_kz_consecutive_[i] = a_ms_ks_strides[i][NumDimM + NumDimK - 1] == 1;
+               as_max_read_elems_[i] =
+                   CalculateMaxRead<NumDimM, NumDimK>(a_ms_ks_lengths[i], a_ms_ks_strides[i]);
            }
            for(index_t i = 0; i < NumBTensor; ++i)
            {
-               b_nz_stride_[i] = b_ns_ks_strides[i][NumDimN - 1];
-               b_kz_stride_[i] = b_ns_ks_strides[i][NumDimN + NumDimK - 1];
+               bs_nz_consecutive_[i] = b_ns_ks_strides[i][NumDimN - 1] == 1;
+               bs_kz_consecutive_[i] = b_ns_ks_strides[i][NumDimN + NumDimK - 1] == 1;
+               bs_max_read_elems_[i] =
+                   CalculateMaxRead<NumDimN, NumDimK>(b_ns_ks_lengths[i], b_ns_ks_strides[i]);
            }
            for(index_t i = 0; i < NumDTensor; ++i)
            {
-               ds_nz_stride_[i] = d_ms_ns_strides[i][NumDimM + NumDimN - 1];
+               ds_nz_consecutive_[i] = d_ms_ns_strides[i][NumDimM + NumDimN - 1] == 1;
+               ds_max_read_elems_[i] =
+                   CalculateMaxRead<NumDimM, NumDimN>(d_ms_ns_lengths[i], d_ms_ns_strides[i]);
            }
-           e_nz_stride_ = e_ms_ns_stride[NumDimM + NumDimN - 1];
+           e_nz_consecutive_ = e_ms_ns_stride[NumDimM + NumDimN - 1] == 1;
+           e_max_write_elems_ = CalculateMaxRead<NumDimM, NumDimN>(e_ms_ns_length, e_ms_ns_stride);
        }
        // pointers
@@ -545,16 +553,19 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
        BElementwiseOperation b_element_op_;
        CDEElementwiseOperation cde_element_op_;
-       // Strides for the last M/N/K dimensions of A/B/Ds/E
-       // for sanity check of vector load/store
-       std::array<index_t, NumATensor> a_mz_stride_;
-       std::array<index_t, NumATensor> a_kz_stride_;
-       std::array<index_t, NumBTensor> b_nz_stride_;
-       std::array<index_t, NumBTensor> b_kz_stride_;
-       std::array<index_t, NumDTensor> ds_nz_stride_;
-       index_t e_nz_stride_;
+       // Describe whether the last part of a given dimension of A/B/D/E is consecutive
+       // in the memory or not.
+       std::array<bool, NumATensor> as_mz_consecutive_;
+       std::array<bool, NumATensor> as_kz_consecutive_;
+       std::array<bool, NumBTensor> bs_nz_consecutive_;
+       std::array<bool, NumBTensor> bs_kz_consecutive_;
+       std::array<bool, NumDTensor> ds_nz_consecutive_;
+       bool e_nz_consecutive_;
+       std::array<index_t, NumATensor> as_max_read_elems_;
+       std::array<index_t, NumBTensor> bs_max_read_elems_;
+       std::array<index_t, NumDTensor> ds_max_read_elems_;
+       index_t e_max_write_elems_;
    };
    // Invoker
@@ -643,73 +654,65 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
        // check vector load/store
        {
-           bool all_valid = true;
-           static_for<0, NumATensor, 1>{}([&](auto i) {
-               // vector memory access of A: could be on M or AK1 dimension
-               if constexpr(ABlockTransferSrcVectorDim == 1)
-               {
-                   if(!(arg.a_mz_stride_[i] == 1 && arg.as_grid_desc_ak0_m_ak1_[i].GetLength(I1) %
-                                                            ABlockTransferSrcScalarPerVector ==
-                                                        0))
-                   {
-                       all_valid = false;
-                   }
-               }
-               else
-               {
-                   if(!(arg.a_kz_stride_[i] == 1 && arg.as_grid_desc_ak0_m_ak1_[i].GetLength(I2) %
-                                                            ABlockTransferSrcScalarPerVector ==
-                                                        0))
-                   {
-                       all_valid = false;
-                   }
-               }
-           });
+           bool valid_as_access = true;
+           static_for<0, NumATensor, 1>{}([&](auto i) {
+               const bool valid_a_vector_size =
+                   arg.as_max_read_elems_[i] % ABlockTransferSrcScalarPerVector == 0;
+               const bool valid_a_access_dim_m =
+                   ABlockTransferSrcVectorDim == 1 && arg.as_mz_consecutive_[i];
+               const bool valid_a_access_dim_k =
+                   ABlockTransferSrcVectorDim == 2 && arg.as_kz_consecutive_[i];
+               const bool valid_a_access_dim = valid_a_access_dim_m || valid_a_access_dim_k;
+               if(!(valid_a_vector_size && valid_a_access_dim))
+               {
+                   valid_as_access = false;
+               }
+           });
+           if(!valid_as_access)
+           {
+               return false;
+           }
-           // vector memory access of B: could be on N or BK1 dimension
-           static_for<0, NumBTensor, 1>{}([&](auto i) {
-               if constexpr(BBlockTransferSrcVectorDim == 1)
-               {
-                   if(!(arg.b_nz_stride_[i] == 1 && arg.bs_grid_desc_bk0_n_bk1_[i].GetLength(I1) %
-                                                            BBlockTransferSrcScalarPerVector ==
-                                                        0))
-                   {
-                       all_valid = false;
-                   }
-               }
-               else
-               {
-                   if(!(arg.b_kz_stride_[i] == 1 && arg.bs_grid_desc_bk0_n_bk1_[i].GetLength(I2) %
-                                                            BBlockTransferSrcScalarPerVector ==
-                                                        0))
-                   {
-                       all_valid = false;
-                   }
-               }
-           });
+           bool valid_bs_access = true;
+           static_for<0, NumBTensor, 1>{}([&](auto i) {
+               const bool valid_b_vector_size =
+                   arg.bs_max_read_elems_[i] % BBlockTransferSrcScalarPerVector == 0;
+               const bool valid_b_access_dim_n =
+                   BBlockTransferSrcVectorDim == 1 && arg.bs_nz_consecutive_[i];
+               const bool valid_b_access_dim_k =
+                   BBlockTransferSrcVectorDim == 2 && arg.bs_kz_consecutive_[i];
+               const bool valid_b_access_dim = valid_b_access_dim_n || valid_b_access_dim_k;
+               if(!(valid_b_vector_size && valid_b_access_dim))
+               {
+                   valid_bs_access = false;
+               }
+           });
+           if(!valid_bs_access)
+           {
+               return false;
+           }
-           // check vector load of Ds
-           static_for<0, NumDTensor, 1>{}([&](auto i) {
-               if(!(arg.ds_nz_stride_[i] == 1 &&
-                    arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[i].GetLength(I3) %
-                            CDEBlockTransferScalarPerVector_NPerBlock ==
-                        0))
-               {
-                   all_valid = false;
-               }
-           });
+           bool valid_ds_access = true;
+           static_for<0, NumDTensor, 1>{}([&](auto i) {
+               const bool valid_d_vector_size =
+                   arg.ds_max_read_elems_[i] % CDEBlockTransferScalarPerVector_NPerBlock == 0;
+               // Vector read of Ds is always on N dimension.
+               const bool valid_d_access_dim = arg.ds_nz_consecutive_[i];
+               if(!(valid_d_vector_size && valid_d_access_dim))
+               {
+                   valid_ds_access = false;
+               }
+           });
+           if(!valid_ds_access)
+           {
+               return false;
+           }
-           // vector memory access of E: always on NPerBlock dimension
-           if(!(arg.e_nz_stride_ == 1 &&
-                arg.e_grid_desc_mblock_mperblock_nblock_nperblock_.GetLength(I3) %
-                        CDEBlockTransferScalarPerVector_NPerBlock ==
-                    0))
-           {
-               all_valid = false;
-           }
-           if(!all_valid)
-           {
-               return false;
-           }
+           const bool valid_e_vector_size =
+               arg.e_max_write_elems_ % CDEBlockTransferScalarPerVector_NPerBlock == 0;
+           // Vector write of E is always on N dimension.
+           const bool valid_e_access_dim = arg.e_nz_consecutive_;
+           if(!(valid_e_vector_size && valid_e_access_dim))
+           {
+               return false;
+           }
...
@@ -13,6 +13,7 @@
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
@@ -183,7 +184,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
            return generate_tuple([&](auto i) { return vec[i]; }, num);
        };
-       const auto a_ms_ns_lengths = to_tuple(a_ms_ks_lengths_vec, Number<NumDimM + NumDimK>{});
+       const auto a_ms_ks_lengths = to_tuple(a_ms_ks_lengths_vec, Number<NumDimM + NumDimK>{});
        const auto a_ms_ks_strides = to_tuple(a_ms_ks_strides_vec, Number<NumDimM + NumDimK>{});
        // dimension Ids for M0, M1, ...
@@ -194,14 +195,14 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
            typename arithmetic_sequence_gen<NumDimM, NumDimM + NumDimK, 1>::type{};
        // lengths for M0, M1, ...
-       const auto mLengths = get_container_subset(a_ms_ns_lengths, mDimIds);
+       const auto mLengths = get_container_subset(a_ms_ks_lengths, mDimIds);
        // lengths for K0, K1, ...
-       const auto kLengths = get_container_subset(a_ms_ns_lengths, kDimIds);
+       const auto kLengths = get_container_subset(a_ms_ks_lengths, kDimIds);
        // naive tensor A[M0, M1, M2, ..., K0, K1, K2...]
        const auto a_grid_desc_ms_ks =
-           make_naive_tensor_descriptor(a_ms_ns_lengths, a_ms_ks_strides);
+           make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides);
        // transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...]
        const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor(
@@ -383,7 +384,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
                 const void* p_b_grid,
                 std::array<const void*, NumDTensor> p_ds_grid,
                 void* p_e_grid,
-                const std::vector<index_t>& a_ms_ns_lengths,
+                const std::vector<index_t>& a_ms_ks_lengths,
                 const std::vector<index_t>& a_ms_ks_strides,
                 const std::vector<index_t>& b_ns_ks_lengths,
                 const std::vector<index_t>& b_ns_ks_strides,
@@ -398,7 +399,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
              p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
              p_ds_grid_{},
              p_e_grid_{static_cast<EDataType*>(p_e_grid)},
-             a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(a_ms_ns_lengths, a_ms_ks_strides)},
+             a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(a_ms_ks_lengths, a_ms_ks_strides)},
              b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(b_ns_ks_lengths, b_ns_ks_strides)},
              ds_grid_desc_m_n_{},
              e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_ms_ns_lengths, e_ms_ns_strides)},
@@ -411,13 +412,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
              block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
-             cde_element_op_{cde_element_op},
-             a_mz_stride_{},
-             a_kz_stride_{},
-             b_nz_stride_{},
-             b_kz_stride_{},
-             ds_nz_stride_{},
-             e_nz_stride_{}
+             cde_element_op_{cde_element_op}
        {
            // populate pointer, batch stride, desc for Ds
            static_for<0, NumDTensor, 1>{}([&](auto i) {
@@ -448,18 +443,26 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
            }
            // for sanity check of vector memory access
-           a_mz_stride_ = a_ms_ks_strides[NumDimM - 1];
-           a_kz_stride_ = a_ms_ks_strides[NumDimM + NumDimK - 1];
+           a_mz_consecutive_ = a_ms_ks_strides[NumDimM - 1] == 1;
+           a_kz_consecutive_ = a_ms_ks_strides[NumDimM + NumDimK - 1] == 1;
+           a_max_read_elems_ =
+               CalculateMaxRead<NumDimM, NumDimK>(a_ms_ks_lengths, a_ms_ks_strides);
-           b_nz_stride_ = b_ns_ks_strides[NumDimN - 1];
-           b_kz_stride_ = b_ns_ks_strides[NumDimN + NumDimK - 1];
+           b_nz_consecutive_ = b_ns_ks_strides[NumDimN - 1] == 1;
+           b_kz_consecutive_ = b_ns_ks_strides[NumDimN + NumDimK - 1] == 1;
+           b_max_read_elems_ =
+               CalculateMaxRead<NumDimN, NumDimK>(b_ns_ks_lengths, b_ns_ks_strides);
            for(index_t i = 0; i < NumDTensor; ++i)
            {
-               ds_nz_stride_[i] = ds_ms_ns_strides[i][NumDimM + NumDimN - 1];
+               ds_nz_consecutive_[i] = ds_ms_ns_strides[i][NumDimM + NumDimN - 1] == 1;
+               ds_max_read_elems_[i] =
+                   CalculateMaxRead<NumDimM, NumDimN>(ds_ms_ns_lengths[i], ds_ms_ns_strides[i]);
            }
-           e_nz_stride_ = e_ms_ns_strides[NumDimM + NumDimN - 1];
+           e_nz_consecutive_ = e_ms_ns_strides[NumDimM + NumDimN - 1] == 1;
+           e_max_write_elems_ =
+               CalculateMaxRead<NumDimM, NumDimN>(e_ms_ns_lengths, e_ms_ns_strides);
        }
        void Print() const
@@ -499,15 +502,19 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
        BElementwiseOperation b_element_op_;
        CDEElementwiseOperation cde_element_op_;
-       // Strides for the last M/N/K dimensions of A/B/Ds/E
-       // for sanity check of vector load/store
-       index_t a_mz_stride_;
-       index_t a_kz_stride_;
-       index_t b_nz_stride_;
-       index_t b_kz_stride_;
-       std::array<index_t, NumDTensor> ds_nz_stride_;
-       index_t e_mz_stride_;
-       index_t e_nz_stride_;
+       // Describe whether the last part of a given dimension of A/B/D/E is consecutive
+       // in the memory or not.
+       bool a_mz_consecutive_;
+       bool a_kz_consecutive_;
+       bool b_nz_consecutive_;
+       bool b_kz_consecutive_;
+       std::array<bool, NumDTensor> ds_nz_consecutive_;
+       bool e_nz_consecutive_;
+       index_t a_max_read_elems_;
+       index_t b_max_read_elems_;
+       std::array<index_t, NumDTensor> ds_max_read_elems_;
+       index_t e_max_write_elems_;
    };
    // Invoker
@@ -616,65 +623,47 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
                          (BBlockTransferSrcVectorDim == 1 || BBlockTransferSrcVectorDim == 2),
                      "wrong!");
-       // vector memory access of A: could be on M or AK1 dimension
-       if constexpr(ABlockTransferSrcVectorDim == 1)
-       {
-           if(!(arg.a_mz_stride_ == 1 &&
-                arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) % ABlockTransferSrcScalarPerVector == 0))
-           {
-               return false;
-           }
-       }
-       else
-       {
-           if(!(arg.a_kz_stride_ == 1 &&
-                arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0))
-           {
-               return false;
-           }
-       }
+       const bool valid_a_vector_size =
+           arg.a_max_read_elems_ % ABlockTransferSrcScalarPerVector == 0;
+       const bool valid_a_access_dim_m = ABlockTransferSrcVectorDim == 1 && arg.a_mz_consecutive_;
+       const bool valid_a_access_dim_k = ABlockTransferSrcVectorDim == 2 && arg.a_kz_consecutive_;
+       const bool valid_a_access_dim = valid_a_access_dim_m || valid_a_access_dim_k;
+       if(!(valid_a_vector_size && valid_a_access_dim))
+       {
+           return false;
+       }
-       // vector memory access of B: could be on N or BK1 dimension
-       if constexpr(BBlockTransferSrcVectorDim == 1)
-       {
-           if(!(arg.b_nz_stride_ == 1 &&
-                arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) % BBlockTransferSrcScalarPerVector == 0))
-           {
-               return false;
-           }
-       }
-       else
-       {
-           if(!(arg.b_kz_stride_ == 1 &&
-                arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) % BBlockTransferSrcScalarPerVector == 0))
-           {
-               return false;
-           }
-       }
+       const bool valid_b_vector_size =
+           arg.b_max_read_elems_ % BBlockTransferSrcScalarPerVector == 0;
+       const bool valid_b_access_dim_n = BBlockTransferSrcVectorDim == 1 && arg.b_nz_consecutive_;
+       const bool valid_b_access_dim_k = BBlockTransferSrcVectorDim == 2 && arg.b_kz_consecutive_;
+       const bool valid_b_access_dim = valid_b_access_dim_n || valid_b_access_dim_k;
+       if(!(valid_b_vector_size && valid_b_access_dim))
+       {
+           return false;
+       }
-       // vector memory access of Ds: always on NPerBlock dimension
-       bool valid_d_access = true;
-       static_for<0, NumDTensor, 1>{}([&](auto i) {
-           if(!(arg.ds_nz_stride_[i] == 1 &&
-                arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[i].GetLength(I3) %
-                        CDEBlockTransferScalarPerVector_NPerBlock ==
-                    0))
-           {
-               valid_d_access = false;
-           }
-       });
-       if(valid_d_access == false)
-       {
-           return false;
-       }
+       bool valid_ds_access = true;
+       static_for<0, NumDTensor, 1>{}([&](auto i) {
+           const bool valid_d_vector_size =
+               arg.ds_max_read_elems_[i] % CDEBlockTransferScalarPerVector_NPerBlock == 0;
+           // Vector read of Ds is always on N dimension.
+           const bool valid_d_access_dim = arg.ds_nz_consecutive_[i];
+           if(!(valid_d_vector_size && valid_d_access_dim))
+           {
+               valid_ds_access = false;
+           }
+       });
+       if(!valid_ds_access)
+       {
+           return false;
+       }
-       // vector memory access of E: always on NPerBlock dimension
-       if(!(arg.e_nz_stride_ == 1 &&
-            arg.e_grid_desc_mblock_mperblock_nblock_nperblock_.GetLength(I3) %
-                    CDEBlockTransferScalarPerVector_NPerBlock ==
-                0))
-       {
-           return false;
-       }
+       const bool valid_e_vector_size =
+           arg.e_max_write_elems_ % CDEBlockTransferScalarPerVector_NPerBlock == 0;
+       // Vector write of E is always on N dimension.
+       const bool valid_e_access_dim = arg.e_nz_consecutive_;
+       if(!(valid_e_vector_size && valid_e_access_dim))
+       {
+           return false;
+       }
@@ -692,7 +681,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
                     const void* p_b,
                     std::array<const void*, NumDTensor> p_ds,
                     void* p_e,
-                    const std::vector<index_t>& a_ms_ns_lengths,
+                    const std::vector<index_t>& a_ms_ks_lengths,
                     const std::vector<index_t>& a_ms_ks_strides,
                     const std::vector<index_t>& b_ns_ks_lengths,
                     const std::vector<index_t>& b_ns_ks_strides,
@@ -708,7 +697,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
                        p_b,
                        p_ds,
                        p_e,
-                       a_ms_ns_lengths,
+                       a_ms_ks_lengths,
                        a_ms_ks_strides,
                        b_ns_ks_lengths,
                        b_ns_ks_strides,
@@ -729,7 +718,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
                     const void* p_b,
                     std::array<const void*, NumDTensor> p_ds,
                     void* p_e,
-                    const std::vector<index_t>& a_ms_ns_lengths,
+                    const std::vector<index_t>& a_ms_ks_lengths,
                     const std::vector<index_t>& a_ms_ks_strides,
                     const std::vector<index_t>& b_ns_ks_lengths,
                     const std::vector<index_t>& b_ns_ks_strides,
@@ -745,7 +734,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
                        p_b,
                        p_ds,
                        p_e,
-                       a_ms_ns_lengths,
+                       a_ms_ks_lengths,
                        a_ms_ks_strides,
                        b_ns_ks_lengths,
                        b_ns_ks_strides,
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cassert>
#include <sstream>
#include <vector>

#include "ck/ck.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

/**
 * Calculates the maximum number of subsequent elements of the fast changing dimension
 * that are consecutive in memory.
 *
 * Example:
 *   NumDimM = 2, NumDimK = 3
 *   A shape   = [  2,   3,  4, 5, 6]
 *   A strides = [360, 120, 30, 6, 1]
 *               |    M    |    K    |
 *   It follows from strides that K is FCD and all the subsequent elements of K are consecutive
 *   in memory.
 *   But if strides were [360, 120, 6, 24, 1], then only 6 subsequent elements of K would be
 *   consecutive in memory.
 *
 * Assumes that the dimensions are split into two groups of `NumDim1` and `NumDim2` dimensions.
 */
template <index_t NumDim1, index_t NumDim2>
auto CalculateMaxRead(const std::vector<index_t>& lengths, const std::vector<index_t>& strides)
{
    if(lengths.size() != NumDim1 + NumDim2)
    {
        std::ostringstream err;
        err << "Incorrect number of lengths in " << __FILE__ << ":" << __LINE__
            << ", in function: " << __func__;
        throw std::runtime_error(err.str());
    }
    if(strides.size() != NumDim1 + NumDim2)
    {
        std::ostringstream err;
        err << "Incorrect number of strides in " << __FILE__ << ":" << __LINE__
            << ", in function: " << __func__;
        throw std::runtime_error(err.str());
    }

    // Determine the beginning and end idx of the group representing the FCD.
    index_t begin_idx, end_idx;
    if(strides[NumDim1 - 1] == 1)
    {
        begin_idx = 0;
        end_idx   = NumDim1 - 1;
    }
    else if(strides[NumDim1 + NumDim2 - 1] == 1)
    {
        begin_idx = NumDim1;
        end_idx   = NumDim1 + NumDim2 - 1;
    }
    else
    {
        // The dimension consecutive in memory is not the last dimension of any group, so only
        // one element can be read/written at once.
        return 1;
    }

    index_t consecutive_stride = 1;
    for(index_t dim_idx = end_idx; dim_idx >= begin_idx; --dim_idx)
    {
        if(strides[dim_idx] == consecutive_stride)
        {
            consecutive_stride *= lengths[dim_idx];
        }
        else
        {
            break;
        }
    }
    const index_t max_subsequent_elems = consecutive_stride;
    return max_subsequent_elems;
}

} // namespace device
} // namespace tensor_operation
} // namespace ck
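A small standalone check of the two cases from the comment above. The function and its include path are the ones added by this commit; the main() harness, iostream output and the exact sizes are illustrative assumptions:

    #include <iostream>
    #include <vector>
    #include "ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp"

    int main()
    {
        using ck::tensor_operation::device::CalculateMaxRead;
        const std::vector<ck::index_t> lengths{2, 3, 4, 5, 6}; // M = [2, 3], K = [4, 5, 6]
        // K is the fast changing dimension and fully consecutive: 4 * 5 * 6 = 120
        std::cout << CalculateMaxRead<2, 3>(lengths, {360, 120, 30, 6, 1}) << '\n'; // prints 120
        // only the innermost K dimension is consecutive in memory
        std::cout << CalculateMaxRead<2, 3>(lengths, {360, 120, 6, 24, 1}) << '\n'; // prints 6
        return 0;
    }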
@@ -13,6 +13,7 @@
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/host_utility/kernel_launch.hpp"
+#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/stream_utility.hpp"
namespace ck {
@@ -292,6 +293,12 @@ struct DeviceElementwise3dImpl : public DeviceElementwise<InDataTypeTuple,
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
+       if((ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
+           ck::get_device_name() == "gfx942"))
+       {
+           return false;
+       }
        const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
        if(pArg == nullptr)
...
@@ -821,7 +821,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
        return (workspace_size);
    };
-   void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override
+   void SetWorkSpacePointer(BaseArgument* pArg,
+                            void* p_workspace,
+                            const StreamConfig& = StreamConfig{}) const override
    {
        Argument* pArg_ = dynamic_cast<Argument*>(pArg);
...
@@ -380,7 +380,9 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
            << " LoopScheduler: "
            << LoopSchedToString[LoopSched] << ", "
            << "PipelineVersion: "
-           << PipelineVersionToString[PipelineVer];
+           << PipelineVersionToString[PipelineVer] << ", "
+           << "Prefetch: "
+           << NumGemmKPrefetchStage;
        // clang-format on
        return str.str();
...
@@ -226,7 +226,9 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
        }
    }
-   void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override
+   void SetWorkSpacePointer(BaseArgument* pArg,
+                            void* p_workspace,
+                            const StreamConfig& = StreamConfig{}) const override
    {
        Argument* pArg_ = dynamic_cast<Argument*>(pArg);
...
@@ -357,15 +357,17 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
        return out_gemmm_gemmn_desc;
    }
+   // Shape of Ds and E must be aligned. Strides can be different.
+   // Pass e_g_n_k_wos_lengths for logical broadcast.
    static auto MakeDsGridDescriptor_M_N(
-       const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
+       const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
        const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides)
    {
        return generate_tuple(
            [&](auto i) {
                using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
-               return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(ds_g_n_k_wos_lengths[i],
+               return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(e_g_n_k_wos_lengths,
                                                                  ds_g_n_k_wos_strides[i]);
            },
            Number<NumDTensor>{});
@@ -569,7 +571,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                // D desc
                ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N<DLayout>(
-                   ds_g_n_k_wos_lengths[i], ds_g_n_k_wos_strides[i]);
+                   e_g_n_k_wos_lengths, ds_g_n_k_wos_strides[i]);
            });
            compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0];
@@ -916,8 +918,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                   is_same_v<DLayout, ctc::G_NDHW_K> || is_same_v<DLayout, ctc::GNWK> ||
                   is_same_v<DLayout, ctc::GNHWK> || is_same_v<DLayout, ctc::GNDHWK> ||
                   is_same_v<DLayout, ctc::NWGK> || is_same_v<DLayout, ctc::NHWGK> ||
-                  is_same_v<DLayout, ctc::NDHWGK> || is_same_v<DLayout, ctc::GK> ||
-                  is_same_v<DLayout, ctc::G_K>)
+                  is_same_v<DLayout, ctc::NDHWGK> || is_same_v<DLayout, ctc::G_K>)
                {
                    const index_t K = arg.ds_g_n_k_wos_lengths_[i][2];
@@ -925,6 +926,27 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                    {
                        valid = false;
                    }
+
+                   if constexpr(is_same_v<DLayout, ctc::G_K>)
+                   {
+                       // G and K must be the same
+                       if(arg.ds_g_n_k_wos_lengths_[i][0] != arg.e_g_n_k_wos_lengths_[0] ||
+                          arg.ds_g_n_k_wos_lengths_[i][2] != arg.e_g_n_k_wos_lengths_[2])
+                       {
+                           valid = false;
+                       }
+                   }
+                   else
+                   {
+                       // E and D must have the same shape
+                       for(index_t d = 0; d < NDimSpatial + 3; d++)
+                       {
+                           if(arg.ds_g_n_k_wos_lengths_[i][d] != arg.e_g_n_k_wos_lengths_[d])
+                           {
+                               valid = false;
+                           }
+                       }
+                   }
                }
                else
                {
...
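With the GK layout removed (last hunk below), a per-channel bias D now uses G_K, and the added check above only requires its G and K extents to match E; every other D layout must match E dimension by dimension. A hedged illustration of the argument convention this implies; the sizes and the {K, 0, 1, 0, 0} stride pattern follow the usual CK bias examples and are assumptions here, not part of this commit:

    // Hypothetical 2D conv output E with lengths ordered as [G, N, K, Ho, Wo]
    std::array<ck::index_t, 5> e_g_n_k_wos_lengths{2, 16, 64, 28, 28};
    // Per-channel bias D passed with layout G_K: same logical lengths as E, while
    // zero strides broadcast the bias over N and the spatial dimensions.
    std::array<ck::index_t, 5> d_g_n_k_wos_lengths = e_g_n_k_wos_lengths;
    std::array<ck::index_t, 5> d_g_n_k_wos_strides{64, 0, 1, 0, 0}; // {K, 0, 1, 0, 0}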
@@ -631,8 +631,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
                   is_same_v<DLayout, ctc::G_NDHW_K> || is_same_v<DLayout, ctc::GNWK> ||
                   is_same_v<DLayout, ctc::GNHWK> || is_same_v<DLayout, ctc::GNDHWK> ||
                   is_same_v<DLayout, ctc::NWGK> || is_same_v<DLayout, ctc::NHWGK> ||
-                  is_same_v<DLayout, ctc::NDHWGK> || is_same_v<DLayout, ctc::GK> ||
-                  is_same_v<DLayout, ctc::G_K>)
+                  is_same_v<DLayout, ctc::NDHWGK> || is_same_v<DLayout, ctc::G_K>)
                {
                    const index_t K = arg.ds_g_n_k_wos_lengths_[i][2];
...
@@ -817,12 +817,15 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
        return arg.group_count_ * sizeof(GroupedGemmKernelArgument<NumDTensor>);
    }
-   void SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace) const override
+   void SetWorkSpacePointer(BaseArgument* p_arg,
+                            void* p_workspace,
+                            const StreamConfig& stream_config = StreamConfig{}) const override
    {
        auto p_arg_ = dynamic_cast<Argument*>(p_arg);
        p_arg_->p_workspace_ = p_workspace;
-       hip_check_error(hipMemset(p_workspace, 0, GetWorkSpaceSize(p_arg)));
+       hip_check_error(
+           hipMemsetAsync(p_workspace, 0, GetWorkSpaceSize(p_arg), stream_config.stream_id_));
    }
    static void SetKBatch(Argument& arg, index_t k_batch) { arg.UpdateKBatch(k_batch); }
...
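Here the workspace is now zeroed with hipMemsetAsync on the stream carried in the StreamConfig instead of a synchronous hipMemset, so the clear is ordered with later kernels enqueued on that stream rather than blocking the device. A hedged sketch of the calling side; the operator, argument and stream names are hypothetical:

    hipStream_t stream;
    hip_check_error(hipStreamCreate(&stream));
    StreamConfig stream_config{};
    stream_config.stream_id_ = stream;
    grouped_gemm_op.SetWorkSpacePointer(&grouped_arg, p_workspace, stream_config); // async zeroing on `stream`
    // kernels later launched with the same stream_config observe the cleared workspace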
@@ -577,7 +577,9 @@ struct DeviceNormalizationFwdSplitKImpl : public DeviceNormalizationFwd<XDataTyp
        return (workspace_size);
    };
-   void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override
+   void SetWorkSpacePointer(BaseArgument* pArg,
+                            void* p_workspace,
+                            const StreamConfig& = StreamConfig{}) const override
    {
        Argument* pArg_ = dynamic_cast<Argument*>(pArg);
...
@@ -308,12 +308,6 @@ struct GNDHWK : public BaseTensorLayout
    static constexpr const char* name = "GNDHWK";
};
-// for output bias
-struct GK : public BaseTensorLayout
-{
-    static constexpr const char* name = "GK";
-};
// output tensor
// packed NWGK/NHWGK/NDHWGK
struct NWGK : public BaseTensorLayout
...