Commit 057140b1 authored by Adam Osewski's avatar Adam Osewski
Browse files

Merge remote-tracking branch 'origin/develop' into aosewski/ggemm_multi_d2

parents 134fc2e7 12a8883c
...@@ -345,7 +345,9 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType, ...@@ -345,7 +345,9 @@ struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
return (workspace_size); return (workspace_size);
}; };
void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override void SetWorkSpacePointer(BaseArgument* pArg,
void* p_workspace,
const StreamConfig& = StreamConfig{}) const override
{ {
Argument* pArg_ = dynamic_cast<Argument*>(pArg); Argument* pArg_ = dynamic_cast<Argument*>(pArg);
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_abd.hpp" #include "ck/tensor_operation/gpu/device/device_contraction_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
...@@ -500,22 +501,29 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle ...@@ -500,22 +501,29 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
// for sanity check of vector memory access // for sanity check of vector memory access
for(index_t i = 0; i < NumATensor; ++i) for(index_t i = 0; i < NumATensor; ++i)
{ {
a_mz_stride_[i] = a_ms_ks_strides[i][NumDimM - 1]; as_mz_consecutive_[i] = a_ms_ks_strides[i][NumDimM - 1] == 1;
a_kz_stride_[i] = a_ms_ks_strides[i][NumDimM + NumDimK - 1]; as_kz_consecutive_[i] = a_ms_ks_strides[i][NumDimM + NumDimK - 1] == 1;
as_max_read_elems_[i] =
CalculateMaxRead<NumDimM, NumDimK>(a_ms_ks_lengths[i], a_ms_ks_strides[i]);
} }
for(index_t i = 0; i < NumBTensor; ++i) for(index_t i = 0; i < NumBTensor; ++i)
{ {
b_nz_stride_[i] = b_ns_ks_strides[i][NumDimN - 1]; bs_nz_consecutive_[i] = b_ns_ks_strides[i][NumDimN - 1] == 1;
b_kz_stride_[i] = b_ns_ks_strides[i][NumDimN + NumDimK - 1]; bs_kz_consecutive_[i] = b_ns_ks_strides[i][NumDimN + NumDimK - 1] == 1;
bs_max_read_elems_[i] =
CalculateMaxRead<NumDimN, NumDimK>(b_ns_ks_lengths[i], b_ns_ks_strides[i]);
} }
for(index_t i = 0; i < NumDTensor; ++i) for(index_t i = 0; i < NumDTensor; ++i)
{ {
ds_nz_stride_[i] = d_ms_ns_strides[i][NumDimM + NumDimN - 1]; ds_nz_consecutive_[i] = d_ms_ns_strides[i][NumDimM + NumDimN - 1] == 1;
ds_max_read_elems_[i] =
CalculateMaxRead<NumDimM, NumDimN>(d_ms_ns_lengths[i], d_ms_ns_strides[i]);
} }
e_nz_stride_ = e_ms_ns_stride[NumDimM + NumDimN - 1]; e_nz_consecutive_ = e_ms_ns_stride[NumDimM + NumDimN - 1] == 1;
e_max_write_elems_ = CalculateMaxRead<NumDimM, NumDimN>(e_ms_ns_length, e_ms_ns_stride);
} }
// pointers // pointers
...@@ -545,16 +553,19 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle ...@@ -545,16 +553,19 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
BElementwiseOperation b_element_op_; BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_; CDEElementwiseOperation cde_element_op_;
// Strides for the last M/N/K dimensions of A/B/Ds/E // Describe whether the last part of a given dimension of A/B/D/E is consecutive
// for sanity check of vector load/store // in the memory or not.
std::array<index_t, NumATensor> a_mz_stride_; std::array<bool, NumATensor> as_mz_consecutive_;
std::array<index_t, NumATensor> a_kz_stride_; std::array<bool, NumATensor> as_kz_consecutive_;
std::array<bool, NumBTensor> bs_nz_consecutive_;
std::array<index_t, NumBTensor> b_nz_stride_; std::array<bool, NumBTensor> bs_kz_consecutive_;
std::array<index_t, NumBTensor> b_kz_stride_; std::array<bool, NumDTensor> ds_nz_consecutive_;
bool e_nz_consecutive_;
std::array<index_t, NumDTensor> ds_nz_stride_;
index_t e_nz_stride_; std::array<index_t, NumATensor> as_max_read_elems_;
std::array<index_t, NumBTensor> bs_max_read_elems_;
std::array<index_t, NumDTensor> ds_max_read_elems_;
index_t e_max_write_elems_;
}; };
// Invoker // Invoker
...@@ -643,73 +654,65 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle ...@@ -643,73 +654,65 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
// check vector load/store // check vector load/store
{ {
bool all_valid = true; bool valid_as_access = true;
static_for<0, NumATensor, 1>{}([&](auto i) { static_for<0, NumATensor, 1>{}([&](auto i) {
// vector memory access of A: could be on M or AK1 dimension const bool valid_a_vector_size =
if constexpr(ABlockTransferSrcVectorDim == 1) arg.as_max_read_elems_[i] % ABlockTransferSrcScalarPerVector == 0;
{ const bool valid_a_access_dim_m =
if(!(arg.a_mz_stride_[i] == 1 && arg.as_grid_desc_ak0_m_ak1_[i].GetLength(I1) % ABlockTransferSrcVectorDim == 1 && arg.as_mz_consecutive_[i];
ABlockTransferSrcScalarPerVector == const bool valid_a_access_dim_k =
0)) ABlockTransferSrcVectorDim == 2 && arg.as_kz_consecutive_[i];
{ const bool valid_a_access_dim = valid_a_access_dim_m || valid_a_access_dim_k;
all_valid = false; if(!(valid_a_vector_size && valid_a_access_dim))
}
}
else
{ {
if(!(arg.a_kz_stride_[i] == 1 && arg.as_grid_desc_ak0_m_ak1_[i].GetLength(I2) % valid_as_access = false;
ABlockTransferSrcScalarPerVector ==
0))
{
all_valid = false;
}
} }
}); });
if(!valid_as_access)
{
return false;
}
// vector memory access of B: could be on N or BK1 dimension bool valid_bs_access = true;
static_for<0, NumBTensor, 1>{}([&](auto i) { static_for<0, NumBTensor, 1>{}([&](auto i) {
if constexpr(BBlockTransferSrcVectorDim == 1) const bool valid_b_vector_size =
arg.bs_max_read_elems_[i] % BBlockTransferSrcScalarPerVector == 0;
const bool valid_b_access_dim_n =
BBlockTransferSrcVectorDim == 1 && arg.bs_nz_consecutive_[i];
const bool valid_b_access_dim_k =
BBlockTransferSrcVectorDim == 2 && arg.bs_kz_consecutive_[i];
const bool valid_b_access_dim = valid_b_access_dim_n || valid_b_access_dim_k;
if(!(valid_b_vector_size && valid_b_access_dim))
{ {
if(!(arg.b_nz_stride_[i] == 1 && arg.bs_grid_desc_bk0_n_bk1_[i].GetLength(I1) % valid_bs_access = false;
BBlockTransferSrcScalarPerVector ==
0))
{
all_valid = false;
}
}
else
{
if(!(arg.b_kz_stride_[i] == 1 && arg.bs_grid_desc_bk0_n_bk1_[i].GetLength(I2) %
BBlockTransferSrcScalarPerVector ==
0))
{
all_valid = false;
}
} }
}); });
if(!valid_bs_access)
{
return false;
}
// check vector load of Ds bool valid_ds_access = true;
static_for<0, NumDTensor, 1>{}([&](auto i) { static_for<0, NumDTensor, 1>{}([&](auto i) {
if(!(arg.ds_nz_stride_[i] == 1 && const bool valid_d_vector_size =
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[i].GetLength(I3) % arg.ds_max_read_elems_[i] % CDEBlockTransferScalarPerVector_NPerBlock == 0;
CDEBlockTransferScalarPerVector_NPerBlock == // Vector read of Ds is always on N dimension.
0)) const bool valid_d_access_dim = arg.ds_nz_consecutive_[i];
if(!(valid_d_vector_size && valid_d_access_dim))
{ {
all_valid = false; valid_ds_access = false;
} }
}); });
if(!valid_ds_access)
// vector memory access of E: always on NPerBlock dimension
if(!(arg.e_nz_stride_ == 1 &&
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_.GetLength(I3) %
CDEBlockTransferScalarPerVector_NPerBlock ==
0))
{ {
all_valid = false; return false;
} }
if(!all_valid) const bool valid_e_vector_size =
arg.e_max_write_elems_ % CDEBlockTransferScalarPerVector_NPerBlock == 0;
// Vector write of E is always on N dimension.
const bool valid_e_access_dim = arg.e_nz_consecutive_;
if(!(valid_e_vector_size && valid_e_access_dim))
{ {
return false; return false;
} }
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
#include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp" #include "ck/tensor_operation/gpu/device/device_contraction_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp" #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
...@@ -183,7 +184,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -183,7 +184,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
return generate_tuple([&](auto i) { return vec[i]; }, num); return generate_tuple([&](auto i) { return vec[i]; }, num);
}; };
const auto a_ms_ns_lengths = to_tuple(a_ms_ks_lengths_vec, Number<NumDimM + NumDimK>{}); const auto a_ms_ks_lengths = to_tuple(a_ms_ks_lengths_vec, Number<NumDimM + NumDimK>{});
const auto a_ms_ks_strides = to_tuple(a_ms_ks_strides_vec, Number<NumDimM + NumDimK>{}); const auto a_ms_ks_strides = to_tuple(a_ms_ks_strides_vec, Number<NumDimM + NumDimK>{});
// dimension Ids for M0, M1, ... // dimension Ids for M0, M1, ...
...@@ -194,14 +195,14 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -194,14 +195,14 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
typename arithmetic_sequence_gen<NumDimM, NumDimM + NumDimK, 1>::type{}; typename arithmetic_sequence_gen<NumDimM, NumDimM + NumDimK, 1>::type{};
// lengths for M0, M1, ... // lengths for M0, M1, ...
const auto mLengths = get_container_subset(a_ms_ns_lengths, mDimIds); const auto mLengths = get_container_subset(a_ms_ks_lengths, mDimIds);
// lengths for K0, K1, ... // lengths for K0, K1, ...
const auto kLengths = get_container_subset(a_ms_ns_lengths, kDimIds); const auto kLengths = get_container_subset(a_ms_ks_lengths, kDimIds);
// naive tensor A[M0, M1, M2, ..., K0, K1, K2...] // naive tensor A[M0, M1, M2, ..., K0, K1, K2...]
const auto a_grid_desc_ms_ks = const auto a_grid_desc_ms_ks =
make_naive_tensor_descriptor(a_ms_ns_lengths, a_ms_ks_strides); make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides);
// transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...] // transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...]
const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor( const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor(
...@@ -383,7 +384,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -383,7 +384,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
const void* p_b_grid, const void* p_b_grid,
std::array<const void*, NumDTensor> p_ds_grid, std::array<const void*, NumDTensor> p_ds_grid,
void* p_e_grid, void* p_e_grid,
const std::vector<index_t>& a_ms_ns_lengths, const std::vector<index_t>& a_ms_ks_lengths,
const std::vector<index_t>& a_ms_ks_strides, const std::vector<index_t>& a_ms_ks_strides,
const std::vector<index_t>& b_ns_ks_lengths, const std::vector<index_t>& b_ns_ks_lengths,
const std::vector<index_t>& b_ns_ks_strides, const std::vector<index_t>& b_ns_ks_strides,
...@@ -398,7 +399,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -398,7 +399,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
p_b_grid_{static_cast<const BDataType*>(p_b_grid)}, p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
p_ds_grid_{}, p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e_grid)}, p_e_grid_{static_cast<EDataType*>(p_e_grid)},
a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(a_ms_ns_lengths, a_ms_ks_strides)}, a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(a_ms_ks_lengths, a_ms_ks_strides)},
b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(b_ns_ks_lengths, b_ns_ks_strides)}, b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(b_ns_ks_lengths, b_ns_ks_strides)},
ds_grid_desc_m_n_{}, ds_grid_desc_m_n_{},
e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_ms_ns_lengths, e_ms_ns_strides)}, e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(e_ms_ns_lengths, e_ms_ns_strides)},
...@@ -411,13 +412,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -411,13 +412,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)}, block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
a_element_op_{a_element_op}, a_element_op_{a_element_op},
b_element_op_{b_element_op}, b_element_op_{b_element_op},
cde_element_op_{cde_element_op}, cde_element_op_{cde_element_op}
a_mz_stride_{},
a_kz_stride_{},
b_nz_stride_{},
b_kz_stride_{},
ds_nz_stride_{},
e_nz_stride_{}
{ {
// populate pointer, batch stride, desc for Ds // populate pointer, batch stride, desc for Ds
static_for<0, NumDTensor, 1>{}([&](auto i) { static_for<0, NumDTensor, 1>{}([&](auto i) {
...@@ -448,18 +443,26 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -448,18 +443,26 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
} }
// for sanity check of vector memory access // for sanity check of vector memory access
a_mz_stride_ = a_ms_ks_strides[NumDimM - 1]; a_mz_consecutive_ = a_ms_ks_strides[NumDimM - 1] == 1;
a_kz_stride_ = a_ms_ks_strides[NumDimM + NumDimK - 1]; a_kz_consecutive_ = a_ms_ks_strides[NumDimM + NumDimK - 1] == 1;
a_max_read_elems_ =
CalculateMaxRead<NumDimM, NumDimK>(a_ms_ks_lengths, a_ms_ks_strides);
b_nz_stride_ = b_ns_ks_strides[NumDimN - 1]; b_nz_consecutive_ = b_ns_ks_strides[NumDimN - 1] == 1;
b_kz_stride_ = b_ns_ks_strides[NumDimN + NumDimK - 1]; b_kz_consecutive_ = b_ns_ks_strides[NumDimN + NumDimK - 1] == 1;
b_max_read_elems_ =
CalculateMaxRead<NumDimN, NumDimK>(b_ns_ks_lengths, b_ns_ks_strides);
for(index_t i = 0; i < NumDTensor; ++i) for(index_t i = 0; i < NumDTensor; ++i)
{ {
ds_nz_stride_[i] = ds_ms_ns_strides[i][NumDimM + NumDimN - 1]; ds_nz_consecutive_[i] = ds_ms_ns_strides[i][NumDimM + NumDimN - 1] == 1;
ds_max_read_elems_[i] =
CalculateMaxRead<NumDimM, NumDimN>(ds_ms_ns_lengths[i], ds_ms_ns_strides[i]);
} }
e_nz_stride_ = e_ms_ns_strides[NumDimM + NumDimN - 1]; e_nz_consecutive_ = e_ms_ns_strides[NumDimM + NumDimN - 1] == 1;
e_max_write_elems_ =
CalculateMaxRead<NumDimM, NumDimN>(e_ms_ns_lengths, e_ms_ns_strides);
} }
void Print() const void Print() const
...@@ -499,15 +502,19 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -499,15 +502,19 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
BElementwiseOperation b_element_op_; BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_; CDEElementwiseOperation cde_element_op_;
// Strides for the last M/N/K dimensions of A/B/Ds/E // Describe whether the last part of a given dimension of A/B/D/E is consecutive
// for sanity check of vector load/store // in the memory or not.
index_t a_mz_stride_; bool a_mz_consecutive_;
index_t a_kz_stride_; bool a_kz_consecutive_;
index_t b_nz_stride_; bool b_nz_consecutive_;
index_t b_kz_stride_; bool b_kz_consecutive_;
std::array<index_t, NumDTensor> ds_nz_stride_; std::array<bool, NumDTensor> ds_nz_consecutive_;
index_t e_mz_stride_; bool e_nz_consecutive_;
index_t e_nz_stride_;
index_t a_max_read_elems_;
index_t b_max_read_elems_;
std::array<index_t, NumDTensor> ds_max_read_elems_;
index_t e_max_write_elems_;
}; };
// Invoker // Invoker
...@@ -616,65 +623,47 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -616,65 +623,47 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
(BBlockTransferSrcVectorDim == 1 || BBlockTransferSrcVectorDim == 2), (BBlockTransferSrcVectorDim == 1 || BBlockTransferSrcVectorDim == 2),
"wrong!"); "wrong!");
// vector memory access of A: could be on M or AK1 dimension const bool valid_a_vector_size =
if constexpr(ABlockTransferSrcVectorDim == 1) arg.a_max_read_elems_ % ABlockTransferSrcScalarPerVector == 0;
const bool valid_a_access_dim_m = ABlockTransferSrcVectorDim == 1 && arg.a_mz_consecutive_;
const bool valid_a_access_dim_k = ABlockTransferSrcVectorDim == 2 && arg.a_kz_consecutive_;
const bool valid_a_access_dim = valid_a_access_dim_m || valid_a_access_dim_k;
if(!(valid_a_vector_size && valid_a_access_dim))
{ {
if(!(arg.a_mz_stride_ == 1 && return false;
arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) % ABlockTransferSrcScalarPerVector == 0))
{
return false;
}
}
else
{
if(!(arg.a_kz_stride_ == 1 &&
arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0))
{
return false;
}
} }
// vector memory access of B: could be on N or BK1 dimension const bool valid_b_vector_size =
if constexpr(BBlockTransferSrcVectorDim == 1) arg.b_max_read_elems_ % BBlockTransferSrcScalarPerVector == 0;
{ const bool valid_b_access_dim_n = BBlockTransferSrcVectorDim == 1 && arg.b_nz_consecutive_;
if(!(arg.b_nz_stride_ == 1 && const bool valid_b_access_dim_k = BBlockTransferSrcVectorDim == 2 && arg.b_kz_consecutive_;
arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) % BBlockTransferSrcScalarPerVector == 0)) const bool valid_b_access_dim = valid_b_access_dim_n || valid_b_access_dim_k;
{ if(!(valid_b_vector_size && valid_b_access_dim))
return false;
}
}
else
{ {
if(!(arg.b_kz_stride_ == 1 && return false;
arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) % BBlockTransferSrcScalarPerVector == 0))
{
return false;
}
} }
// vector memory access of Ds: always on NPerBlock dimension bool valid_ds_access = true;
bool valid_d_access = true;
static_for<0, NumDTensor, 1>{}([&](auto i) { static_for<0, NumDTensor, 1>{}([&](auto i) {
if(!(arg.ds_nz_stride_[i] == 1 && const bool valid_d_vector_size =
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[i].GetLength(I3) % arg.ds_max_read_elems_[i] % CDEBlockTransferScalarPerVector_NPerBlock == 0;
CDEBlockTransferScalarPerVector_NPerBlock == // Vector read of Ds is always on N dimension.
0)) const bool valid_d_access_dim = arg.ds_nz_consecutive_[i];
if(!(valid_d_vector_size && valid_d_access_dim))
{ {
valid_d_access = false; valid_ds_access = false;
} }
}); });
if(!valid_ds_access)
if(valid_d_access == false)
{ {
return false; return false;
} }
// vector memory access of E: always on NPerBlock dimension const bool valid_e_vector_size =
if(!(arg.e_nz_stride_ == 1 && arg.e_max_write_elems_ % CDEBlockTransferScalarPerVector_NPerBlock == 0;
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_.GetLength(I3) % // Vector write of E is always on N dimension.
CDEBlockTransferScalarPerVector_NPerBlock == const bool valid_e_access_dim = arg.e_nz_consecutive_;
0)) if(!(valid_e_vector_size && valid_e_access_dim))
{ {
return false; return false;
} }
...@@ -692,7 +681,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -692,7 +681,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
const void* p_b, const void* p_b,
std::array<const void*, NumDTensor> p_ds, std::array<const void*, NumDTensor> p_ds,
void* p_e, void* p_e,
const std::vector<index_t>& a_ms_ns_lengths, const std::vector<index_t>& a_ms_ks_lengths,
const std::vector<index_t>& a_ms_ks_strides, const std::vector<index_t>& a_ms_ks_strides,
const std::vector<index_t>& b_ns_ks_lengths, const std::vector<index_t>& b_ns_ks_lengths,
const std::vector<index_t>& b_ns_ks_strides, const std::vector<index_t>& b_ns_ks_strides,
...@@ -708,7 +697,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -708,7 +697,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
p_b, p_b,
p_ds, p_ds,
p_e, p_e,
a_ms_ns_lengths, a_ms_ks_lengths,
a_ms_ks_strides, a_ms_ks_strides,
b_ns_ks_lengths, b_ns_ks_lengths,
b_ns_ks_strides, b_ns_ks_strides,
...@@ -729,7 +718,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -729,7 +718,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
const void* p_b, const void* p_b,
std::array<const void*, NumDTensor> p_ds, std::array<const void*, NumDTensor> p_ds,
void* p_e, void* p_e,
const std::vector<index_t>& a_ms_ns_lengths, const std::vector<index_t>& a_ms_ks_lengths,
const std::vector<index_t>& a_ms_ks_strides, const std::vector<index_t>& a_ms_ks_strides,
const std::vector<index_t>& b_ns_ks_lengths, const std::vector<index_t>& b_ns_ks_lengths,
const std::vector<index_t>& b_ns_ks_strides, const std::vector<index_t>& b_ns_ks_strides,
...@@ -745,7 +734,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -745,7 +734,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
p_b, p_b,
p_ds, p_ds,
p_e, p_e,
a_ms_ns_lengths, a_ms_ks_lengths,
a_ms_ks_strides, a_ms_ks_strides,
b_ns_ks_lengths, b_ns_ks_lengths,
b_ns_ks_strides, b_ns_ks_strides,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cassert>
#include <sstream>
#include <vector>
#include "ck/ck.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
/**
* Calculates the maximum number of subsequent elements of the fast changing dimension
* that are consecutive in memory.
*
* Example:
* NumDimM = 2, NumDimK = 3
* A shape = [ 2, 3, 4, 5, 6]
* A strides = [360, 120, 30, 6, 1]
* | M | | K |
* It follows from strides that K is FCD and all the subsequent elements of K are consecutive
* in memory.
* But if strides were [360, 120, 6, 24, 1], then only 6 subsequent elements of K would be
* consecutive in memory.
*
* Assumes that the dimensions are split into two groups of `NumDim1` and `NumDim2` dimensions.
*/
template <index_t NumDim1, index_t NumDim2>
auto CalculateMaxRead(const std::vector<index_t>& lengths, const std::vector<index_t>& strides)
{
if(lengths.size() != NumDim1 + NumDim2)
{
std::ostringstream err;
err << "Incorrect number of lengths in " << __FILE__ << ":" << __LINE__
<< ", in function: " << __func__;
throw std::runtime_error(err.str());
}
if(strides.size() != NumDim1 + NumDim2)
{
std::ostringstream err;
err << "Incorrect number of strides in " << __FILE__ << ":" << __LINE__
<< ", in function: " << __func__;
throw std::runtime_error(err.str());
}
// Determine the beginning and end idx of the group representing the FCD.
index_t begin_idx, end_idx;
if(strides[NumDim1 - 1] == 1)
{
begin_idx = 0;
end_idx = NumDim1 - 1;
}
else if(strides[NumDim1 + NumDim2 - 1] == 1)
{
begin_idx = NumDim1;
end_idx = NumDim1 + NumDim2 - 1;
}
else
{
// The dimension consecutive in memory is not the last dimension of any group, so only
// one element can be read/written at once.
return 1;
}
index_t consecutive_stride = 1;
for(index_t dim_idx = end_idx; dim_idx >= begin_idx; --dim_idx)
{
if(strides[dim_idx] == consecutive_stride)
{
consecutive_stride *= lengths[dim_idx];
}
else
{
break;
}
}
const index_t max_subsequent_elems = consecutive_stride;
return max_subsequent_elems;
}
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
#include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/stream_utility.hpp" #include "ck/host_utility/stream_utility.hpp"
namespace ck { namespace ck {
...@@ -292,6 +293,12 @@ struct DeviceElementwise3dImpl : public DeviceElementwise<InDataTypeTuple, ...@@ -292,6 +293,12 @@ struct DeviceElementwise3dImpl : public DeviceElementwise<InDataTypeTuple,
bool IsSupportedArgument(const BaseArgument* p_arg) override bool IsSupportedArgument(const BaseArgument* p_arg) override
{ {
if((ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
ck::get_device_name() == "gfx942"))
{
return false;
}
const Argument* pArg = dynamic_cast<const Argument*>(p_arg); const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
if(pArg == nullptr) if(pArg == nullptr)
......
...@@ -821,7 +821,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle ...@@ -821,7 +821,9 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
return (workspace_size); return (workspace_size);
}; };
void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override void SetWorkSpacePointer(BaseArgument* pArg,
void* p_workspace,
const StreamConfig& = StreamConfig{}) const override
{ {
Argument* pArg_ = dynamic_cast<Argument*>(pArg); Argument* pArg_ = dynamic_cast<Argument*>(pArg);
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1,
index_t BK1,
index_t MPerXDL,
index_t NPerXDL,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferScalarPerVector,
index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferScalarPerVector,
index_t BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler(),
PipelineVersion PipelineVer = PipelineVersion::v4,
typename ComputeDataType = EDataType>
struct DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
: public DeviceGemmMultipleD<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
{
static constexpr auto I1 = Number<1>{};
static constexpr index_t NumDTensor = DsDataType::Size();
using GridwiseGemm = GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad<
ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
CShuffleDataType,
DsDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
InMemoryDataOperationEnum::Set,
GemmSpec,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferScalarPerVector,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferScalarPerVector,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched,
PipelineVer>;
using Argument = typename GridwiseGemm::Argument;
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_,
arg.block_2_etile_map_))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
const index_t grid_size =
arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
const auto kernel = kernel_gemm_multiple_d_xdl_cshuffle_lds_direct_load<
GridwiseGemm,
ADataType,
BDataType,
typename GridwiseGemm::DsGridPointer,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
typename GridwiseGemm::AGridDesc_AK0_M_AK1,
typename GridwiseGemm::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::Block2ETileMap,
has_main_loop>;
return launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_ds_grid_,
arg.p_e_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_etile_map_);
};
const auto K = arg.a_grid_desc_m_k_.GetLength(I1);
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
return launch_kernel(integral_constant<bool, true>{});
}
else
{
return launch_kernel(integral_constant<bool, false>{});
}
}
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_supported())
{
return false;
}
if(!ck::is_lds_direct_load_supported())
{
return false;
}
// Check vector load/store.
{
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
// Check vector load of A.
if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2)
{
if(arg.KRaw_ % ABlockTransferScalarPerVector != 0)
{
return false;
}
}
else if constexpr(is_same_v<ALayout, Col> && ABlockTransferSrcVectorDim == 1)
{
if(arg.MRaw_ % ABlockTransferScalarPerVector != 0)
{
return false;
}
}
else
{
return false;
}
// Check vector load of B.
if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2)
{
if(arg.KRaw_ % BBlockTransferScalarPerVector != 0)
{
return false;
}
}
else if constexpr(is_same_v<BLayout, Row> && BBlockTransferSrcVectorDim == 1)
{
if(arg.NRaw_ % BBlockTransferScalarPerVector != 0)
{
return false;
}
}
else
{
return false;
}
// Check vector load of Ds.
// For now, only the RowMajor layout is supported.
bool all_valid = true;
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
if constexpr(!is_same_v<DLayout, Row>)
{
all_valid = false;
}
});
if(!all_valid)
{
return false;
}
// Check vector load of E.
// For now, only the RowMajor layout is supported.
if constexpr(is_same_v<ELayout, Row>)
{
if(arg.NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0)
{
return false;
}
}
else
{
return false;
}
}
return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_,
arg.block_2_etile_map_);
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t StrideA,
index_t StrideB,
std::array<index_t, NumDTensor> StrideDs,
index_t StrideE,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
return Argument{p_a,
p_b,
p_ds,
p_e,
MRaw,
NRaw,
KRaw,
StrideA,
StrideB,
StrideDs,
StrideE,
a_element_op,
b_element_op,
cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t StrideA,
index_t StrideB,
std::array<ck::index_t, NumDTensor> StrideDs,
index_t StrideE,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) override
{
return std::make_unique<Argument>(p_a,
p_b,
p_ds,
p_e,
MRaw,
NRaw,
KRaw,
StrideA,
StrideB,
StrideDs,
StrideE,
a_element_op,
b_element_op,
cde_element_op);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
std::map<LoopScheduler, std::string> LoopSchedToString{
{LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
std::map<PipelineVersion, std::string> PipelineVersionToString{
{PipelineVersion::v1, "v1"}, {PipelineVersion::v2, "v2"}, {PipelineVersion::v4, "v4"}};
// clang-format off
str << "DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< AK1 << ", "
<< BK1 << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferScalarPerVector << ", "
<< BBlockTransferScalarPerVector << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle << ", "
<< getGemmSpecializationString(GemmSpec)
<< ">"
<< " LoopScheduler: "
<< LoopSchedToString[LoopSched] << ", "
<< "PipelineVersion: "
<< PipelineVersionToString[PipelineVer];
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ALayout,
typename BLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename EDataType,
typename AccDataType,
typename CShuffleDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1,
index_t BK1,
index_t MPerXDL,
index_t NPerXDL,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferScalarPerVector,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferScalarPerVector,
bool BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched = make_default_loop_scheduler(),
PipelineVersion PipelineVer = PipelineVersion::v4,
typename ComputeDataType = EDataType>
struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm<ALayout,
BLayout,
ELayout,
ADataType,
BDataType,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
{
static constexpr auto I1 = Number<1>{};
using GridwiseGemm = GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad<
ALayout,
BLayout,
ck::Tuple<>,
ELayout,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
CShuffleDataType,
ck::Tuple<>,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
InMemoryDataOperationEnum::Set,
GemmSpec,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferScalarPerVector,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferScalarPerVector,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEBlockTransferScalarPerVector_NPerBlock,
LoopSched,
PipelineVer>;
using Argument = typename GridwiseGemm::Argument;
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_,
arg.block_2_etile_map_))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
const index_t grid_size =
arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
auto launch_kernel = [&](auto has_main_k_block_loop) {
constexpr bool has_main_loop = has_main_k_block_loop.value;
const auto kernel = kernel_gemm_multiple_d_xdl_cshuffle_lds_direct_load<
GridwiseGemm,
ADataType,
BDataType,
typename GridwiseGemm::DsGridPointer,
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
typename GridwiseGemm::AGridDesc_AK0_M_AK1,
typename GridwiseGemm::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::Block2ETileMap,
has_main_loop>;
return launch_and_time_kernel(stream_config,
kernel,
dim3(grid_size),
dim3(BlockSize),
0,
arg.p_a_grid_,
arg.p_b_grid_,
// arg.p_ds_grid_,
ck::Tuple<>{},
arg.p_e_grid_,
arg.a_element_op_,
arg.b_element_op_,
arg.cde_element_op_,
arg.a_grid_desc_ak0_m_ak1_,
arg.b_grid_desc_bk0_n_bk1_,
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
arg.block_2_etile_map_);
};
const auto K = arg.a_grid_desc_m_k_.GetLength(I1);
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
return launch_kernel(integral_constant<bool, true>{});
}
else
{
return launch_kernel(integral_constant<bool, false>{});
}
}
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_supported())
{
return false;
}
if(!ck::is_lds_direct_load_supported())
{
return false;
}
// Check vector load/store.
{
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
// Check vector load of A.
if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2)
{
if(arg.KRaw_ % ABlockTransferScalarPerVector != 0)
{
return false;
}
}
else if constexpr(is_same_v<ALayout, Col> && ABlockTransferSrcVectorDim == 1)
{
if(arg.MRaw_ % ABlockTransferScalarPerVector != 0)
{
return false;
}
}
else
{
return false;
}
// Check vector load of B.
if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2)
{
if(arg.KRaw_ % BBlockTransferScalarPerVector != 0)
{
return false;
}
}
else if constexpr(is_same_v<BLayout, Row> && BBlockTransferSrcVectorDim == 1)
{
if(arg.NRaw_ % BBlockTransferScalarPerVector != 0)
{
return false;
}
}
else
{
return false;
}
// Check vector load of E.
// For now, only the RowMajor layout is supported.
if constexpr(is_same_v<ELayout, Row>)
{
if(arg.NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0)
{
return false;
}
}
else
{
return false;
}
}
return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
arg.b_grid_desc_n_k_,
arg.ds_grid_desc_m_n_,
arg.e_grid_desc_m_n_,
arg.block_2_etile_map_);
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const void* p_a,
const void* p_b,
void* p_e,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t StrideA,
index_t StrideB,
index_t StrideE,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
{
using EmptyDsPointers = std::array<const void*, 0>;
using EmptyDsStrides = std::array<ck::index_t, 0>;
return Argument{p_a,
p_b,
EmptyDsPointers{},
p_e,
MRaw,
NRaw,
KRaw,
StrideA,
StrideB,
EmptyDsStrides{},
StrideE,
a_element_op,
b_element_op,
cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_e,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t StrideA,
index_t StrideB,
index_t StrideE,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) override
{
using EmptyDsPointers = std::array<const void*, 0>;
using EmptyDsStrides = std::array<ck::index_t, 0>;
return std::make_unique<Argument>(p_a,
p_b,
EmptyDsPointers{},
p_e,
MRaw,
NRaw,
KRaw,
StrideA,
StrideB,
EmptyDsStrides{},
StrideE,
a_element_op,
b_element_op,
cde_element_op);
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
std::map<LoopScheduler, std::string> LoopSchedToString{
{LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
std::map<PipelineVersion, std::string> PipelineVersionToString{
{PipelineVersion::v1, "v1"}, {PipelineVersion::v2, "v2"}, {PipelineVersion::v4, "v4"}};
// clang-format off
str << "DeviceGemm_Xdl_CShuffle_LdsDirectLoad"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "
<< NPerBlock << ", "
<< KPerBlock << ", "
<< AK1 << ", "
<< BK1 << ", "
<< MPerXDL << ", "
<< NPerXDL << ", "
<< MXdlPerWave << ", "
<< NXdlPerWave << ", "
<< ABlockTransferScalarPerVector << ", "
<< BBlockTransferScalarPerVector << ", "
<< CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle << ", "
<< getGemmSpecializationString(GemmSpec)
<< ">"
<< " LoopScheduler: "
<< LoopSchedToString[LoopSched] << ", "
<< "PipelineVersion: "
<< PipelineVersionToString[PipelineVer] << ", "
<< "Prefetch: "
<< NumGemmKPrefetchStage;
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -226,7 +226,9 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout, ...@@ -226,7 +226,9 @@ struct DeviceGemmXdlStreamK : public DeviceGemmStreamK<ALayout,
} }
} }
void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override void SetWorkSpacePointer(BaseArgument* pArg,
void* p_workspace,
const StreamConfig& = StreamConfig{}) const override
{ {
Argument* pArg_ = dynamic_cast<Argument*>(pArg); Argument* pArg_ = dynamic_cast<Argument*>(pArg);
......
...@@ -357,15 +357,17 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -357,15 +357,17 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
return out_gemmm_gemmn_desc; return out_gemmm_gemmn_desc;
} }
// Shape of Ds and E must be aligned. Strides can be different.
// Pass e_g_n_k_wos_lengths for logical broadcast.
static auto MakeDsGridDescriptor_M_N( static auto MakeDsGridDescriptor_M_N(
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths, const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides) const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides)
{ {
return generate_tuple( return generate_tuple(
[&](auto i) { [&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>; using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(ds_g_n_k_wos_lengths[i], return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(e_g_n_k_wos_lengths,
ds_g_n_k_wos_strides[i]); ds_g_n_k_wos_strides[i]);
}, },
Number<NumDTensor>{}); Number<NumDTensor>{});
...@@ -569,7 +571,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -569,7 +571,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// D desc // D desc
ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N<DLayout>( ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N<DLayout>(
ds_g_n_k_wos_lengths[i], ds_g_n_k_wos_strides[i]); e_g_n_k_wos_lengths, ds_g_n_k_wos_strides[i]);
}); });
compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0]; compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0];
...@@ -836,7 +838,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -836,7 +838,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// check if it's 1x1, stride=1 conv // check if it's 1x1, stride=1 conv
for(index_t i = 0; i < NDimSpatial; ++i) for(index_t i = 0; i < NDimSpatial; ++i)
{ {
const index_t X = arg.b_g_k_c_xs_lengths_[i + 2]; const index_t X = arg.b_g_k_c_xs_lengths_[i + 3];
const index_t ConvStride = arg.conv_filter_strides_[i]; const index_t ConvStride = arg.conv_filter_strides_[i];
const index_t LeftPad = arg.input_left_pads_[i]; const index_t LeftPad = arg.input_left_pads_[i];
const index_t RightPad = arg.input_right_pads_[i]; const index_t RightPad = arg.input_right_pads_[i];
...@@ -853,7 +855,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -853,7 +855,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// check if it's 1x1 conv // check if it's 1x1 conv
for(index_t i = 0; i < NDimSpatial; ++i) for(index_t i = 0; i < NDimSpatial; ++i)
{ {
const index_t X = arg.b_g_k_c_xs_lengths_[i + 2]; const index_t X = arg.b_g_k_c_xs_lengths_[i + 3];
const index_t LeftPad = arg.input_left_pads_[i]; const index_t LeftPad = arg.input_left_pads_[i];
const index_t RightPad = arg.input_right_pads_[i]; const index_t RightPad = arg.input_right_pads_[i];
...@@ -916,8 +918,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -916,8 +918,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
is_same_v<DLayout, ctc::G_NDHW_K> || is_same_v<DLayout, ctc::GNWK> || is_same_v<DLayout, ctc::G_NDHW_K> || is_same_v<DLayout, ctc::GNWK> ||
is_same_v<DLayout, ctc::GNHWK> || is_same_v<DLayout, ctc::GNDHWK> || is_same_v<DLayout, ctc::GNHWK> || is_same_v<DLayout, ctc::GNDHWK> ||
is_same_v<DLayout, ctc::NWGK> || is_same_v<DLayout, ctc::NHWGK> || is_same_v<DLayout, ctc::NWGK> || is_same_v<DLayout, ctc::NHWGK> ||
is_same_v<DLayout, ctc::NDHWGK> || is_same_v<DLayout, ctc::GK> || is_same_v<DLayout, ctc::NDHWGK> || is_same_v<DLayout, ctc::G_K>)
is_same_v<DLayout, ctc::G_K>)
{ {
const index_t K = arg.ds_g_n_k_wos_lengths_[i][2]; const index_t K = arg.ds_g_n_k_wos_lengths_[i][2];
...@@ -925,6 +926,27 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -925,6 +926,27 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{ {
valid = false; valid = false;
} }
if constexpr(is_same_v<DLayout, ctc::G_K>)
{
// G and K must be the same
if(arg.ds_g_n_k_wos_lengths_[i][0] != arg.e_g_n_k_wos_lengths_[0] ||
arg.ds_g_n_k_wos_lengths_[i][2] != arg.e_g_n_k_wos_lengths_[2])
{
valid = false;
}
}
else
{
// E and D must have the same shape
for(index_t d = 0; d < NDimSpatial + 3; d++)
{
if(arg.ds_g_n_k_wos_lengths_[i][d] != arg.e_g_n_k_wos_lengths_[d])
{
valid = false;
}
}
}
} }
else else
{ {
...@@ -1094,6 +1116,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle ...@@ -1094,6 +1116,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
<< NXdlPerWave << ", " << NXdlPerWave << ", "
<< ABlockTransferSrcScalarPerVector << ", " << ABlockTransferSrcScalarPerVector << ", "
<< BBlockTransferSrcScalarPerVector << ", " << BBlockTransferSrcScalarPerVector << ", "
<< CDEBlockTransferScalarPerVector_NPerBlock << ", "
<< CShuffleMXdlPerWavePerShuffle << ", " << CShuffleMXdlPerWavePerShuffle << ", "
<< CShuffleNXdlPerWavePerShuffle << CShuffleNXdlPerWavePerShuffle
<< ">"; << ">";
......
...@@ -834,7 +834,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle ...@@ -834,7 +834,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
// check if it's 1x1, stride=1 conv // check if it's 1x1, stride=1 conv
for(index_t i = 0; i < NDimSpatial; ++i) for(index_t i = 0; i < NDimSpatial; ++i)
{ {
const index_t X = arg.b_g_k_c_xs_lengths_[i + 2]; const index_t X = arg.b_g_k_c_xs_lengths_[i + 3];
const index_t ConvStride = arg.conv_filter_strides_[i]; const index_t ConvStride = arg.conv_filter_strides_[i];
const index_t LeftPad = arg.input_left_pads_[i]; const index_t LeftPad = arg.input_left_pads_[i];
const index_t RightPad = arg.input_right_pads_[i]; const index_t RightPad = arg.input_right_pads_[i];
...@@ -851,7 +851,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle ...@@ -851,7 +851,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
// check if it's 1x1 conv // check if it's 1x1 conv
for(index_t i = 0; i < NDimSpatial; ++i) for(index_t i = 0; i < NDimSpatial; ++i)
{ {
const index_t X = arg.b_g_k_c_xs_lengths_[i + 2]; const index_t X = arg.b_g_k_c_xs_lengths_[i + 3];
const index_t LeftPad = arg.input_left_pads_[i]; const index_t LeftPad = arg.input_left_pads_[i];
const index_t RightPad = arg.input_right_pads_[i]; const index_t RightPad = arg.input_right_pads_[i];
......
...@@ -631,8 +631,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle ...@@ -631,8 +631,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
is_same_v<DLayout, ctc::G_NDHW_K> || is_same_v<DLayout, ctc::GNWK> || is_same_v<DLayout, ctc::G_NDHW_K> || is_same_v<DLayout, ctc::GNWK> ||
is_same_v<DLayout, ctc::GNHWK> || is_same_v<DLayout, ctc::GNDHWK> || is_same_v<DLayout, ctc::GNHWK> || is_same_v<DLayout, ctc::GNDHWK> ||
is_same_v<DLayout, ctc::NWGK> || is_same_v<DLayout, ctc::NHWGK> || is_same_v<DLayout, ctc::NWGK> || is_same_v<DLayout, ctc::NHWGK> ||
is_same_v<DLayout, ctc::NDHWGK> || is_same_v<DLayout, ctc::GK> || is_same_v<DLayout, ctc::NDHWGK> || is_same_v<DLayout, ctc::G_K>)
is_same_v<DLayout, ctc::G_K>)
{ {
const index_t K = arg.ds_g_n_k_wos_lengths_[i][2]; const index_t K = arg.ds_g_n_k_wos_lengths_[i][2];
......
...@@ -817,12 +817,15 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout, ...@@ -817,12 +817,15 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
return arg.group_count_ * sizeof(GroupedGemmKernelArgument<NumDTensor>); return arg.group_count_ * sizeof(GroupedGemmKernelArgument<NumDTensor>);
} }
void SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace) const override void SetWorkSpacePointer(BaseArgument* p_arg,
void* p_workspace,
const StreamConfig& stream_config = StreamConfig{}) const override
{ {
auto p_arg_ = dynamic_cast<Argument*>(p_arg); auto p_arg_ = dynamic_cast<Argument*>(p_arg);
p_arg_->p_workspace_ = p_workspace; p_arg_->p_workspace_ = p_workspace;
hip_check_error(hipMemset(p_workspace, 0, GetWorkSpaceSize(p_arg))); hip_check_error(
hipMemsetAsync(p_workspace, 0, GetWorkSpaceSize(p_arg), stream_config.stream_id_));
} }
static void SetKBatch(Argument& arg, index_t k_batch) { arg.UpdateKBatch(k_batch); } static void SetKBatch(Argument& arg, index_t k_batch) { arg.UpdateKBatch(k_batch); }
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "ck/tensor_operation/gpu/device/device_normalization_bwd_data.hpp"
#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_bwd_data.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
// M is Invariant dimension, K is reduced dimension
namespace ck {
namespace tensor_operation {
namespace device {
template <typename GridwiseNormalizationBwd,
typename DYDataType,
typename XDataType,
typename GammaDataType,
typename MeanInvStdDataType,
typename DXDataType,
typename GridDesc_M_K>
__global__ void
kernel_normalization_bwd_data(const GridDesc_M_K dy_grid_desc_m_k,
const GridDesc_M_K x_grid_desc_m_k,
const GridDesc_M_K gamma_grid_desc_m_k,
const GridDesc_M_K mean_grid_desc_m_k,
const GridDesc_M_K inv_std_grid_desc_m_k,
const GridDesc_M_K dx_grid_desc_m_k,
index_t num_k_block_tile_iteration,
const DYDataType* const __restrict__ p_dy_global,
const XDataType* const __restrict__ p_x_global,
const GammaDataType* const __restrict__ p_gamma_global,
const MeanInvStdDataType* const __restrict__ p_mean_global,
const MeanInvStdDataType* const __restrict__ p_inv_std_global,
DXDataType* const __restrict__ p_dx_global)
{
GridwiseNormalizationBwd::Run(dy_grid_desc_m_k,
x_grid_desc_m_k,
gamma_grid_desc_m_k,
mean_grid_desc_m_k,
inv_std_grid_desc_m_k,
dx_grid_desc_m_k,
num_k_block_tile_iteration,
p_dy_global,
p_x_global,
p_gamma_global,
p_mean_global,
p_inv_std_global,
p_dx_global);
};
template <typename DYDataType,
typename XDataType,
typename GammaDataType,
typename MeanInvStdDataType,
typename ComputeDataType,
typename DXDataType,
index_t Rank,
index_t NumReduceDim,
index_t BlockSize,
index_t MThreadClusterSize,
index_t KThreadClusterSize,
index_t MThreadSliceSize,
index_t KThreadSliceSize,
bool IsDYFastestDimReduced,
index_t DYSrcVectorSize,
bool IsXFastestDimReduced,
index_t XSrcVectorSize,
bool IsGammaFastestDimReduced,
index_t GammaSrcVectorSize,
bool IsMeanInvStdFastestDimReduced,
index_t MeanInvStdSrcVectorSize,
bool IsDxFastestDimReduced,
index_t DXDstVectorSize>
struct DeviceNormalizationBwdDataImpl : public DeviceNormalizationBwdData<DYDataType,
XDataType,
GammaDataType,
MeanInvStdDataType,
DXDataType,
Rank,
NumReduceDim>
{
static constexpr index_t DYSrcVectorDim = IsDYFastestDimReduced ? 1 : 0;
static constexpr index_t XSrcVectorDim = IsXFastestDimReduced ? 1 : 0;
static constexpr index_t GammaSrcVectorDim = IsGammaFastestDimReduced ? 1 : 0;
static constexpr index_t MeanInvStdSrcVectorDim = IsMeanInvStdFastestDimReduced ? 1 : 0;
static constexpr index_t DXDstVectorDim = IsDxFastestDimReduced ? 1 : 0;
static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize);
static_assert(((DYSrcVectorDim == 0 && MThreadSliceSize % DYSrcVectorSize == 0) ||
(DYSrcVectorDim == 1 && KThreadSliceSize % DYSrcVectorSize == 0)),
"Invalid thread slice sizes and/or dy vector sizes configuration, please check!");
static_assert(((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
(XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0)),
"Invalid thread slice sizes and/or x vector sizes configuration, please check!");
static_assert(
((GammaSrcVectorDim == 0 && MThreadSliceSize % GammaSrcVectorSize == 0) ||
(GammaSrcVectorDim == 1 && KThreadSliceSize % GammaSrcVectorSize == 0)),
"Invalid thread slice sizes and/or gamma vector sizes configuration, please check!");
static_assert(
(MeanInvStdSrcVectorDim == 0 && MThreadSliceSize % MeanInvStdSrcVectorSize == 0) ||
(MeanInvStdSrcVectorDim == 1 && KThreadSliceSize % MeanInvStdSrcVectorSize == 0),
"Invalid thread slice sizes and/or mean and inverse std vector sizes configuration, please "
"check!");
static_assert(((DXDstVectorDim == 0 && MThreadSliceSize % DXDstVectorSize == 0) ||
(DXDstVectorDim == 1 && KThreadSliceSize % DXDstVectorSize == 0)),
"Invalid thread slice sizes and/or dx vector sizes configuration, please check!");
static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
static constexpr bool reduceAllDim = (NumInvariantDim == 0);
static_assert(!reduceAllDim);
static auto Make2dDescriptor(const std::vector<index_t>& lengths,
const std::vector<index_t>& strides,
int numBlockTileIteration)
{
const auto tupleLengths = make_tuple_from_array(lengths, Number<Rank>{});
const auto tupleStrides = make_tuple_from_array(strides, Number<Rank>{});
const auto desc = make_naive_tensor_descriptor(tupleLengths, tupleStrides);
const auto grid_desc_m_k = [&]() {
using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
const auto reduceDimLengths =
make_tuple_from_array_and_index_seq(lengths, ReduceDims{});
const auto invariantDimLengths =
make_tuple_from_array_and_index_seq(lengths, InvariantDims{});
return transform_tensor_descriptor(desc,
make_tuple(make_merge_transform(invariantDimLengths),
make_merge_transform(reduceDimLengths)),
make_tuple(InvariantDims{}, ReduceDims{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}();
const auto invariantLength = grid_desc_m_k.GetLength(Number<0>{});
const auto reduceLength = grid_desc_m_k.GetLength(Number<1>{});
const auto pad_M =
math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
const auto pad_K = K_BlockTileSize * numBlockTileIteration - reduceLength;
auto grid_desc_m_k_padded =
transform_tensor_descriptor(grid_desc_m_k,
make_tuple(make_right_pad_transform(invariantLength, pad_M),
make_right_pad_transform(reduceLength, pad_K)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return grid_desc_m_k_padded;
}
using GridDesc_M_K = decltype(Make2dDescriptor({1}, {1}, 1));
using GridwiseNormalizationBwdDataGeneric =
GridwiseNormalizationBwdData_mk_to_mk<DYDataType,
XDataType,
GammaDataType,
MeanInvStdDataType,
ComputeDataType,
DXDataType,
GridDesc_M_K,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
DYSrcVectorDim,
DYSrcVectorSize,
XSrcVectorDim,
XSrcVectorSize,
GammaSrcVectorDim,
GammaSrcVectorSize,
MeanInvStdSrcVectorDim,
MeanInvStdSrcVectorSize,
DXDstVectorDim,
DXDstVectorSize,
false>;
using GridwiseNormalizationBwdDataSweepOnce =
GridwiseNormalizationBwdData_mk_to_mk<DYDataType,
XDataType,
GammaDataType,
MeanInvStdDataType,
ComputeDataType,
DXDataType,
GridDesc_M_K,
BlockSize,
MThreadClusterSize,
KThreadClusterSize,
MThreadSliceSize,
KThreadSliceSize,
DYSrcVectorDim,
DYSrcVectorSize,
XSrcVectorDim,
XSrcVectorSize,
GammaSrcVectorDim,
GammaSrcVectorSize,
MeanInvStdSrcVectorDim,
MeanInvStdSrcVectorSize,
DXDstVectorDim,
DXDstVectorSize,
true>;
struct Argument : public BaseArgument
{
Argument(const std::vector<index_t> lengths,
const std::vector<index_t> dyStrides,
const std::vector<index_t> xStrides,
const std::vector<index_t> gammaStrides,
const std::vector<index_t> meanStrides,
const std::vector<index_t> invStdStrides,
const std::vector<index_t> dxStrides,
const std::vector<index_t> reduceDims,
const DYDataType* p_dy,
const XDataType* p_x,
const GammaDataType* p_gamma,
const MeanInvStdDataType* p_mean,
const MeanInvStdDataType* p_invStd,
DXDataType* p_dx)
: p_dy_(p_dy),
p_x_(p_x),
p_gamma_(p_gamma),
p_mean_(p_mean),
p_invStd_(p_invStd),
p_dx_(p_dx)
{
lengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(lengths, reduceDims);
dyStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(dyStrides, reduceDims);
xStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(xStrides, reduceDims);
gammaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(gammaStrides, reduceDims);
meanStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(meanStrides, reduceDims);
invStdStrides_ =
shuffle_tensor_dimensions<Rank, NumReduceDim>(invStdStrides, reduceDims);
dxStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(dxStrides, reduceDims);
std::tie(MRaw_, KRaw_) = get_2d_lengths<Rank, NumReduceDim>(lengths_);
numBlockTileIteration_ = math::integer_divide_ceil(KRaw_, K_BlockTileSize);
gridSize_ = math::integer_divide_ceil(MRaw_, M_BlockTileSize);
dy_grid_desc_m_k_ = Make2dDescriptor(lengths_, dyStrides_, numBlockTileIteration_);
x_grid_desc_m_k_ = Make2dDescriptor(lengths_, xStrides_, numBlockTileIteration_);
gamma_grid_desc_m_k_ =
Make2dDescriptor(lengths_, gammaStrides_, numBlockTileIteration_);
mean_grid_desc_m_k_ = Make2dDescriptor(lengths_, meanStrides_, numBlockTileIteration_);
inv_std_grid_desc_m_k_ =
Make2dDescriptor(lengths_, invStdStrides_, numBlockTileIteration_);
dx_grid_desc_m_k_ = Make2dDescriptor(lengths_, dxStrides_, numBlockTileIteration_);
isSweeponce_ = dy_grid_desc_m_k_.GetLength(Number<1>{}) <= K_BlockTileSize;
}
const DYDataType* p_dy_;
const XDataType* p_x_;
const GammaDataType* p_gamma_;
const MeanInvStdDataType* p_mean_;
const MeanInvStdDataType* p_invStd_;
DXDataType* p_dx_;
std::vector<index_t> lengths_;
std::vector<index_t> dyStrides_;
std::vector<index_t> xStrides_;
std::vector<index_t> gammaStrides_;
std::vector<index_t> meanStrides_;
std::vector<index_t> invStdStrides_;
std::vector<index_t> dxStrides_;
int numBlockTileIteration_;
size_t gridSize_;
// tensor descriptor
GridDesc_M_K dy_grid_desc_m_k_;
GridDesc_M_K x_grid_desc_m_k_;
GridDesc_M_K gamma_grid_desc_m_k_;
GridDesc_M_K mean_grid_desc_m_k_;
GridDesc_M_K inv_std_grid_desc_m_k_;
GridDesc_M_K dx_grid_desc_m_k_;
bool isSweeponce_;
index_t MRaw_; // Invariant length
index_t KRaw_; // reduce length
};
struct Invoker : public BaseInvoker
{
auto KernelSelector(bool isSweepOnce)
{
return isSweepOnce
? kernel_normalization_bwd_data<GridwiseNormalizationBwdDataSweepOnce,
DYDataType,
XDataType,
GammaDataType,
MeanInvStdDataType,
DXDataType,
GridDesc_M_K>
: kernel_normalization_bwd_data<GridwiseNormalizationBwdDataGeneric,
DYDataType,
XDataType,
GammaDataType,
MeanInvStdDataType,
DXDataType,
GridDesc_M_K>;
}
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
const auto kernel_main = KernelSelector(arg.isSweeponce_);
return launch_and_time_kernel(stream_config,
kernel_main,
dim3(arg.gridSize_),
dim3(BlockSize),
0,
arg.dy_grid_desc_m_k_,
arg.x_grid_desc_m_k_,
arg.gamma_grid_desc_m_k_,
arg.mean_grid_desc_m_k_,
arg.inv_std_grid_desc_m_k_,
arg.dx_grid_desc_m_k_,
arg.numBlockTileIteration_,
arg.p_dy_,
arg.p_x_,
arg.p_gamma_,
arg.p_mean_,
arg.p_invStd_,
arg.p_dx_);
}
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
template <index_t SrcVectorDim, index_t SrcVectorSize>
bool IsVectorDimSizeValid(const std::vector<index_t>& lengths,
const std::vector<index_t>& strides)
{
if constexpr(SrcVectorSize == 1)
return true;
// Fastest dimension is not reduced
if constexpr(SrcVectorDim == 0)
{
if constexpr(NumInvariantDim == 0)
return false;
if(strides[NumInvariantDim - 1] != 1)
return false;
if(lengths[NumInvariantDim - 1] % SrcVectorSize != 0)
return false;
}
else // Fastest dimension is reduced
{
if(strides[Rank - 1] != 1)
return false;
if(lengths[Rank - 1] % SrcVectorSize != 0)
return false;
};
return true;
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg);
bool pass = true;
pass &= IsVectorDimSizeValid<DYSrcVectorDim, DYSrcVectorSize>(p_arg_->lengths_,
p_arg_->dyStrides_);
pass &= IsVectorDimSizeValid<XSrcVectorDim, XSrcVectorSize>(p_arg_->lengths_,
p_arg_->xStrides_);
pass &= IsVectorDimSizeValid<GammaSrcVectorDim, GammaSrcVectorSize>(p_arg_->lengths_,
p_arg_->gammaStrides_);
pass &= IsVectorDimSizeValid<MeanInvStdSrcVectorDim, MeanInvStdSrcVectorSize>(
p_arg_->lengths_, p_arg_->meanStrides_);
pass &= IsVectorDimSizeValid<MeanInvStdSrcVectorDim, MeanInvStdSrcVectorSize>(
p_arg_->lengths_, p_arg_->invStdStrides_);
pass &= IsVectorDimSizeValid<DXDstVectorDim, DXDstVectorSize>(p_arg_->lengths_,
p_arg_->dxStrides_);
return pass;
}
std::unique_ptr<BaseArgument> MakeArgumentPointer(const std::vector<index_t> lengths,
const std::vector<index_t> dyStrides,
const std::vector<index_t> xStrides,
const std::vector<index_t> gammaStrides,
const std::vector<index_t> meanStrides,
const std::vector<index_t> invStdStrides,
const std::vector<index_t> dxStrides,
const std::vector<index_t> reduceDims,
const void* p_dy,
const void* p_x,
const void* p_gamma,
const void* p_mean,
const void* p_invStd,
void* p_dx) override
{
if(lengths.size() != Rank || dyStrides.size() != Rank || xStrides.size() != Rank ||
gammaStrides.size() != Rank || meanStrides.size() != Rank ||
invStdStrides.size() != Rank || dxStrides.size() != Rank)
throw std::runtime_error("dimension is incorrect");
return std::make_unique<Argument>(lengths,
dyStrides,
xStrides,
gammaStrides,
meanStrides,
invStdStrides,
dxStrides,
reduceDims,
static_cast<const DYDataType*>(p_dy),
static_cast<const XDataType*>(p_x),
static_cast<const GammaDataType*>(p_gamma),
static_cast<const MeanInvStdDataType*>(p_mean),
static_cast<const MeanInvStdDataType*>(p_invStd),
static_cast<DXDataType*>(p_dx));
}
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>();
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceNormalizationBwdDataImpl<" << BlockSize << ",";
str << "Cluster_MK_" << MThreadClusterSize << "_" << KThreadClusterSize << ",";
str << "Slice_MK_" << MThreadSliceSize << "_" << KThreadSliceSize << ",";
str << "DYSrcVectorSize" << DYSrcVectorSize << "_X" << XSrcVectorSize << "_Gamma" << GammaSrcVectorSize << "_MeanRstd" << MeanInvStdSrcVectorSize << "_Dx" << DXDstVectorSize;
str << ">";
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
#include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/kernel_launch.hpp"
// M is invarient dimension, K is reduced dimension // M is Invariant dimension, K is reduced dimension
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
...@@ -87,7 +87,6 @@ struct DeviceNormalizationBwdGammaBetaImpl ...@@ -87,7 +87,6 @@ struct DeviceNormalizationBwdGammaBetaImpl
Rank, Rank,
NumReduceDim> NumReduceDim>
{ {
static constexpr index_t DYSrcVectorDim = IsDYFastestDimReduced ? 1 : 0; static constexpr index_t DYSrcVectorDim = IsDYFastestDimReduced ? 1 : 0;
static constexpr index_t XSrcVectorDim = IsXFastestDimReduced ? 1 : 0; static constexpr index_t XSrcVectorDim = IsXFastestDimReduced ? 1 : 0;
static constexpr index_t MeanInvStdSrcVectorDim = IsMeanInvStdFastestDimReduced ? 1 : 0; static constexpr index_t MeanInvStdSrcVectorDim = IsMeanInvStdFastestDimReduced ? 1 : 0;
...@@ -102,18 +101,18 @@ struct DeviceNormalizationBwdGammaBetaImpl ...@@ -102,18 +101,18 @@ struct DeviceNormalizationBwdGammaBetaImpl
(XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0)), (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0)),
"Invalid thread slice sizes and/or x vector sizes configuration, please check!"); "Invalid thread slice sizes and/or x vector sizes configuration, please check!");
static_assert(
((MThreadSliceSize % DGammaDstVectorSize == 0) ||
(MThreadSliceSize % DBetaDstVectorSize == 0)),
"Invalid thread slice sizes and/or Gamma and beta vector sizes configuration, please "
"check!");
static_assert( static_assert(
(MeanInvStdSrcVectorDim == 0 && MThreadSliceSize % MeanInvStdSrcVectorSize == 0) || (MeanInvStdSrcVectorDim == 0 && MThreadSliceSize % MeanInvStdSrcVectorSize == 0) ||
(MeanInvStdSrcVectorDim == 1 && KThreadSliceSize % MeanInvStdSrcVectorSize == 0), (MeanInvStdSrcVectorDim == 1 && KThreadSliceSize % MeanInvStdSrcVectorSize == 0),
"Invalid thread slice sizes and/or mean and inverse std vector sizes configuration, please " "Invalid thread slice sizes and/or mean and inverse std vector sizes configuration, please "
"check!"); "check!");
static_assert(
((MThreadSliceSize % DGammaDstVectorSize == 0) ||
(MThreadSliceSize % DBetaDstVectorSize == 0)),
"Invalid thread slice sizes and/or Gamma and beta vector sizes configuration, please "
"check!");
static constexpr index_t NumInvariantDim = Rank - NumReduceDim; static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
...@@ -298,7 +297,7 @@ struct DeviceNormalizationBwdGammaBetaImpl ...@@ -298,7 +297,7 @@ struct DeviceNormalizationBwdGammaBetaImpl
GridDesc_M dgamma_grid_desc_m_; GridDesc_M dgamma_grid_desc_m_;
GridDesc_M dbeta_grid_desc_m_; GridDesc_M dbeta_grid_desc_m_;
index_t MRaw_; // invarient length index_t MRaw_; // Invariant length
index_t KRaw_; // reduce length index_t KRaw_; // reduce length
}; };
...@@ -457,6 +456,21 @@ struct DeviceNormalizationBwdGammaBetaImpl ...@@ -457,6 +456,21 @@ struct DeviceNormalizationBwdGammaBetaImpl
{ {
return std::make_unique<Invoker>(); return std::make_unique<Invoker>();
} }
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "DeviceNormalizationBwdGammaBetaImpl<" << BlockSize << ",";
str << "Cluster_MK_" << MThreadClusterSize << "_" << KThreadClusterSize << ",";
str << "Slice_MK_" << MThreadSliceSize << "_" << KThreadSliceSize << ",";
str << "VectorSize_DY" << DYSrcVectorSize << "_X" << XSrcVectorSize ;
str << "_DGamma" << DGammaDstVectorSize << "_DBeta" << DBetaDstVectorSize << ">";
// clang-format on
return str.str();
}
}; };
} // namespace device } // namespace device
......
...@@ -19,7 +19,7 @@ namespace tensor_operation { ...@@ -19,7 +19,7 @@ namespace tensor_operation {
namespace device { namespace device {
// Y = Normalization(X, Beta, Gamma) // Y = Normalization(X, Beta, Gamma)
// M: Invarient length // M: Invariant length
// K: Reduce length (Calculate mean and variance along K dimension) // K: Reduce length (Calculate mean and variance along K dimension)
// eg. Length = [N, C, H, W], reduce dim = [C, H, W] // eg. Length = [N, C, H, W], reduce dim = [C, H, W]
// Then, M = N, K = C * H * W // Then, M = N, K = C * H * W
...@@ -263,7 +263,7 @@ struct DeviceNormalizationFwdImpl : public DeviceNormalizationFwd<XDataType, ...@@ -263,7 +263,7 @@ struct DeviceNormalizationFwdImpl : public DeviceNormalizationFwd<XDataType,
GridDesc_M save_inv_std_grid_desc_m_; GridDesc_M save_inv_std_grid_desc_m_;
bool isSweeponce_; bool isSweeponce_;
index_t MRaw_; // invarient length index_t MRaw_; // Invariant length
index_t KRaw_; // reduce length index_t KRaw_; // reduce length
index_t invariant_lowest_length_; index_t invariant_lowest_length_;
...@@ -342,8 +342,6 @@ struct DeviceNormalizationFwdImpl : public DeviceNormalizationFwd<XDataType, ...@@ -342,8 +342,6 @@ struct DeviceNormalizationFwdImpl : public DeviceNormalizationFwd<XDataType,
} }
else else
{ {
printf("!!!! %d\n", p_arg_->invariant_lowest_length_);
if(p_arg_->xStrides_[NumInvariantDim - 1] != 1) if(p_arg_->xStrides_[NumInvariantDim - 1] != 1)
return false; return false;
......
...@@ -108,7 +108,7 @@ namespace tensor_operation { ...@@ -108,7 +108,7 @@ namespace tensor_operation {
namespace device { namespace device {
// Y = Normalization(X, Beta, Gamma) // Y = Normalization(X, Beta, Gamma)
// M: Invarient length // M: Invariant length
// K: Reduce length (Calculate mean and variance along K dimension) // K: Reduce length (Calculate mean and variance along K dimension)
// eg. Length = [N, C, H, W], reduce dim = [C, H, W] // eg. Length = [N, C, H, W], reduce dim = [C, H, W]
// Then, M = N, K = C * H * W // Then, M = N, K = C * H * W
...@@ -468,7 +468,7 @@ struct DeviceNormalizationFwdSplitKImpl : public DeviceNormalizationFwd<XDataTyp ...@@ -468,7 +468,7 @@ struct DeviceNormalizationFwdSplitKImpl : public DeviceNormalizationFwd<XDataTyp
Kernel2MeanVarGridDesc_M_KBlock kernel2_mean_var_grid_desc_m_kblock_; Kernel2MeanVarGridDesc_M_KBlock kernel2_mean_var_grid_desc_m_kblock_;
Kernel2CountGridDesc_M_KBlock kernel2_count_grid_desc_m_kblock_; Kernel2CountGridDesc_M_KBlock kernel2_count_grid_desc_m_kblock_;
index_t MRaw_; // invarient length index_t MRaw_; // Invariant length
index_t KRaw_; // reduce length index_t KRaw_; // reduce length
index_t invariant_lowest_length_; index_t invariant_lowest_length_;
...@@ -577,7 +577,9 @@ struct DeviceNormalizationFwdSplitKImpl : public DeviceNormalizationFwd<XDataTyp ...@@ -577,7 +577,9 @@ struct DeviceNormalizationFwdSplitKImpl : public DeviceNormalizationFwd<XDataTyp
return (workspace_size); return (workspace_size);
}; };
void SetWorkSpacePointer(BaseArgument* pArg, void* p_workspace) const override void SetWorkSpacePointer(BaseArgument* pArg,
void* p_workspace,
const StreamConfig& = StreamConfig{}) const override
{ {
Argument* pArg_ = dynamic_cast<Argument*>(pArg); Argument* pArg_ = dynamic_cast<Argument*>(pArg);
......
...@@ -308,12 +308,6 @@ struct GNDHWK : public BaseTensorLayout ...@@ -308,12 +308,6 @@ struct GNDHWK : public BaseTensorLayout
static constexpr const char* name = "GNDHWK"; static constexpr const char* name = "GNDHWK";
}; };
// for output bias
struct GK : public BaseTensorLayout
{
static constexpr const char* name = "GK";
};
// output tensor // output tensor
// packed NWGK/NHWGK/NDHWGK // packed NWGK/NHWGK/NDHWGK
struct NWGK : public BaseTensorLayout struct NWGK : public BaseTensorLayout
......
...@@ -281,6 +281,24 @@ struct ConvertF8SR ...@@ -281,6 +281,24 @@ struct ConvertF8SR
} }
}; };
struct ConvertF8RNE
{
// convert to fp8 using rounding to nearest even
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const
{
// check Y datatype
static_assert(is_same<Y, f8_t>::value || is_same<Y, bf8_t>::value,
"Data type is not supported by this operation!");
// check X datatype
static_assert(is_same<X, float>::value || is_same<X, half_t>::value,
"Data type is not supported by this operation!");
y = f8_convert_rne<Y>(x);
}
};
struct Scale struct Scale
{ {
__host__ __device__ Scale(float scale) : scale_(scale) {} __host__ __device__ Scale(float scale) : scale_(scale) {}
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
namespace ck {
template <typename GridwiseGemm,
typename ADataType,
typename BDataType,
typename DsPointer,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename Block2ETileMap,
bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_gemm_multiple_d_xdl_cshuffle_lds_direct_load(
const ADataType* __restrict__ p_a_grid,
const BDataType* __restrict__ p_b_grid,
DsPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
const AElementwiseOperation a_element_op,
const BElementwiseOperation b_element_op,
const CDEElementwiseOperation cde_element_op,
const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx90a__) || defined(__gfx940__) || \
defined(__gfx941__) || defined(__gfx942__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
p_b_grid,
p_ds_grid,
p_e_grid,
p_shared,
a_element_op,
b_element_op,
cde_element_op,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
ds_grid_desc_mblock_mperblock_nblock_nperblock,
e_grid_desc_mblock_mperblock_nblock_nperblock,
block_2_etile_map);
#else
ignore = p_a_grid;
ignore = p_b_grid;
ignore = p_ds_grid;
ignore = p_e_grid;
ignore = a_element_op;
ignore = b_element_op;
ignore = cde_element_op;
ignore = a_grid_desc_ak0_m_ak1;
ignore = b_grid_desc_bk0_n_bk1;
ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
ignore = block_2_etile_map;
#endif
}
// GEMM:
// input : A[M, K]
// input : B[N, K]
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// D0, D1, ... and E have the same layout
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename BDataType,
typename AComputeDataType_,
typename AccDataType,
typename CShuffleDataType,
typename DsDataType,
typename EDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
tensor_operation::device::GemmSpecialization GemmSpec,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1Value,
index_t BK1Value,
index_t MPerXdl,
index_t NPerXdl,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferScalarPerVector,
index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferScalarPerVector,
index_t BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
LoopScheduler LoopSched,
PipelineVersion PipelineVer = PipelineVersion::v4,
typename BComputeDataType = AComputeDataType_>
struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
{
static constexpr index_t NumDTensor = DsDataType::Size();
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
static constexpr auto I7 = Number<7>{};
static constexpr auto AK1 = Number<AK1Value>{};
static constexpr auto BK1 = Number<BK1Value>{};
static constexpr auto AK0PerBlock = Number<KPerBlock / AK1Value>{};
static constexpr auto BK0PerBlock = Number<KPerBlock / BK1Value>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<
decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
#if CK_WORKAROUND_DENORM_FIX
using AComputeDataType =
conditional_t<is_same_v<AComputeDataType_, ck::half_t>, ck::bhalf_t, AComputeDataType_>;
#else
using AComputeDataType = AComputeDataType_;
#endif
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
// A matrix in LDS memory, destination of blockwise copy.
return make_naive_tensor_descriptor(
make_tuple(AK0PerBlock, Number<MPerBlock>{}, AK1),
make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1, AK1, I1));
}
__host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{
// B matrix in LDS memory, destination of blockwise copy.
return make_naive_tensor_descriptor(
make_tuple(BK0PerBlock, Number<NPerBlock>{}, BK1),
make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
}
__host__ __device__ static constexpr auto
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
{
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
make_naive_tensor_descriptor_packed(
make_tuple(I1,
Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{},
I1,
Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{}));
return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
}
// ck::Tuple<const D0DataType*, const D1DataType*, ...>
static constexpr auto MakeDsGridPointer()
{
return generate_tuple(
[&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
return static_cast<const DDataType*>(nullptr);
},
Number<NumDTensor>{});
}
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
// LDS allocation for A and B: be careful of alignment.
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
constexpr auto max_lds_align = math::lcm(AK1, BK1);
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
// LDS allocation for C shuffle.
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
constexpr auto c_block_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
return math::max(
NumGemmKPrefetchStage * a_block_space_size_aligned * sizeof(AComputeDataType) +
NumGemmKPrefetchStage * b_block_space_size_aligned * sizeof(BComputeDataType),
c_block_size * sizeof(CShuffleDataType));
}
__host__ __device__ static auto
MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
{
constexpr auto matrix_padder =
ck::tensor_operation::device::MatrixPadder<GemmSpec, index_t, index_t, index_t>{
MPerBlock, NPerBlock, KPerBlock};
const auto a_grid_desc_mraw_kraw = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(StrideA, I1));
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
make_tuple(I1, StrideA));
}
}();
return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
}
__host__ __device__ static auto
MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
{
constexpr auto matrix_padder =
ck::tensor_operation::device::MatrixPadder<GemmSpec, index_t, index_t, index_t>{
MPerBlock, NPerBlock, KPerBlock};
const auto b_grid_desc_nraw_kraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(I1, StrideB));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
make_tuple(StrideB, I1));
}
}();
return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
}
__host__ __device__ static auto
MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
{
constexpr auto matrix_padder =
ck::tensor_operation::device::MatrixPadder<GemmSpec, index_t, index_t, index_t>{
MPerBlock, NPerBlock, KPerBlock};
const auto e_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(StrideE, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ELayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
make_tuple(I1, StrideE));
}
}();
return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
}
__host__ __device__ static auto
MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
const std::array<index_t, NumDTensor>& NRaws,
const std::array<index_t, NumDTensor>& DsStride)
{
return generate_tuple(
[&](auto i) { return MakeEGridDescriptor_M_N(MRaws[i], NRaws[i], DsStride[i]); },
Number<NumDTensor>{});
}
using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N(1, 1, 1));
// A desc for source in blockwise copy.
__host__ __device__ static constexpr auto
MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k)
{
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto K = a_grid_desc_m_k.GetLength(I1);
const auto AK0 = K / AK1;
return transform_tensor_descriptor(a_grid_desc_m_k,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
// B desc for source in blockwise copy.
__host__ __device__ static constexpr auto
MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k)
{
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto K = b_grid_desc_n_k.GetLength(I1);
const auto BK0 = K / BK1;
return transform_tensor_descriptor(b_grid_desc_n_k,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(N)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
}
// E desc for destination in blockwise copy.
__host__ __device__ static constexpr auto
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n)
{
const auto M = e_grid_desc_m_n.GetLength(I0);
const auto N = e_grid_desc_m_n.GetLength(I1);
const auto MBlock = M / MPerBlock;
const auto NBlock = N / NPerBlock;
const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
e_grid_desc_m_n,
make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
return e_grid_desc_mblock_mperblock_nblock_nperblock;
}
// Ds desc for source in blockwise copy.
__host__ __device__ static constexpr auto
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N& ds_grid_desc_m_n)
{
return generate_tuple(
[&](auto i) {
return MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[i]);
},
Number<NumDTensor>{});
}
__host__ __device__ static constexpr auto
MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
{
return BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, EGridDesc_M_N>(
e_grid_desc_m_n);
}
using AGridDesc_AK0_M_AK1 =
remove_cvref_t<decltype(MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
using BGridDesc_BK0_N_BK1 =
remove_cvref_t<decltype(MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
DsGridDesc_M_N{}))>;
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
remove_cvref_t<decltype(MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
EGridDesc_M_N{}))>;
using Block2ETileMap = remove_cvref_t<decltype(MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))>;
__host__ __device__ static constexpr bool CheckValidity(const AGridDesc_M_K& a_grid_desc_m_k,
const BGridDesc_N_K& b_grid_desc_n_k,
const DsGridDesc_M_N& ds_grid_desc_m_n,
const EGridDesc_M_N& e_grid_desc_m_n,
const Block2ETileMap& block_2_etile_map)
{
static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!");
static_assert(KPerBlock % AK1Value == 0 && KPerBlock % BK1Value == 0,
"KPerBlock must be divisible by AK1Value and BK1Value!");
static_assert(
std::is_same_v<AElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough> &&
std::is_same_v<BElementwiseOperation,
ck::tensor_operation::element_wise::PassThrough>,
"Direct load transfers do not support elementwise operations other than passthrough.");
const auto M = a_grid_desc_m_k.GetLength(I0);
const auto N = b_grid_desc_n_k.GetLength(I0);
const auto AK = a_grid_desc_m_k.GetLength(I1);
const auto BK = b_grid_desc_n_k.GetLength(I1);
// Check the consistency of descriptors.
if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) && AK == BK))
{
return false;
}
bool valid = true;
static_for<0, NumDTensor, 1>{}([&](auto i) {
valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) &&
N == ds_grid_desc_m_n[i].GetLength(I1));
});
if(!valid)
{
return false;
}
// Check the tile size.
if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && AK % KPerBlock == 0))
{
return false;
}
// Check gridwise gemm pipeline.
const auto num_k_loop = AK / KPerBlock;
if(!GridwiseGemmPipe::IsSupported(num_k_loop))
{
return false;
}
// Check block-to-E-tile.
if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n))
{
return false;
}
// Check tensor size: cannot exceed 2GB.
constexpr long_index_t TwoGB = (long_index_t{1} << 31);
if(!(a_grid_desc_m_k.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
b_grid_desc_n_k.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB &&
e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
{
return false;
}
return true;
}
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
{
const index_t num_loop = K / KPerBlock;
return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
}
using DsGridPointer = decltype(MakeDsGridPointer());
__device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; }
template <typename DataType>
__device__ static auto AllocateBlockBuffers(void* p_shared,
int32_t num_elems,
int32_t offset_elems,
int32_t max_lds_align)
{
const int32_t single_buffer_offset = math::integer_least_multiple(num_elems, max_lds_align);
return generate_tuple(
[&](auto i) {
const int32_t local_offset = i * single_buffer_offset;
return make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<DataType*>(p_shared) + local_offset + offset_elems, num_elems);
},
Number<NumGemmKPrefetchStage>{});
}
template <bool HasMainKBlockLoop,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>
__device__ static void Run(const ADataType* __restrict__ p_a_grid,
const BDataType* __restrict__ p_b_grid,
DsGridPointer p_ds_grid,
EDataType* __restrict__ p_e_grid,
void* __restrict__ p_shared,
const AElementwiseOperation& a_element_op,
const BElementwiseOperation& b_element_op,
const CDEElementwiseOperation& cde_element_op,
const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
e_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2ETileMap& block_2_etile_map)
{
// Elementwise operations are not supported for A and B, arguments left only for the API
// consistency.
(void)a_element_op;
(void)b_element_op;
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
const auto ds_grid_buf = generate_tuple(
[&](auto i) {
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_ds_grid[i],
ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
},
Number<NumDTensor>{});
auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
// Divide block work by [M, N].
const auto block_work_idx =
block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_etile_map.ValidCTileIndex(
block_work_idx,
make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{
return;
}
// This forces m/n_block_data_idx_on_grid into SGPR.
const index_t m_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
constexpr auto max_lds_align = math::lcm(AK1, BK1);
// A matrix in LDS memory, destination of blockwise copy.
constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
// B matrix in LDS memory, destination of blockwise copy.
constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
auto a_blockwise_copy =
ThreadGroupTensorSliceTransfer_DirectLoad<ThisThreadBlock,
Sequence<AK0PerBlock, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ADataType,
AComputeDataType,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcVectorDim,
2,
ABlockTransferScalarPerVector>(
a_grid_desc_ak0_m_ak1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_block_desc_ak0_m_ak1,
make_multi_index(0, 0, 0));
auto b_blockwise_copy =
ThreadGroupTensorSliceTransfer_DirectLoad<ThisThreadBlock,
Sequence<BK0PerBlock, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BDataType,
BComputeDataType,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcVectorDim,
2,
BBlockTransferScalarPerVector>(
b_grid_desc_bk0_n_bk1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_block_desc_bk0_n_bk1,
make_multi_index(0, 0, 0));
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[K0PerBlock, MPerBlock] is in LDS
// b_mtx[K0PerBlock, NPerBlock] is in LDS
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
constexpr index_t KPack = math::max(
math::lcm(AK1, BK1),
MfmaSelector<AComputeDataType, MPerXdl, NPerXdl, BComputeDataType>::selected_mfma
.k_per_blk);
auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
BlockSize,
AComputeDataType,
BComputeDataType,
AccDataType,
decltype(a_block_desc_ak0_m_ak1),
decltype(b_block_desc_bk0_n_bk1),
MPerXdl,
NPerXdl,
MXdlPerWave,
NXdlPerWave,
KPack,
LoopSched>();
auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
// LDS allocation for A and B: be careful of alignment.
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
auto a_block_buffers = AllocateBlockBuffers<AComputeDataType>(
p_shared, a_block_desc_ak0_m_ak1.GetElementSpaceSize(), 0, max_lds_align);
const auto b_buffers_offset = a_block_space_size_aligned * NumGemmKPrefetchStage;
auto b_block_buffers =
AllocateBlockBuffers<BComputeDataType>(p_shared,
b_block_desc_bk0_n_bk1.GetElementSpaceSize(),
b_buffers_offset,
max_lds_align);
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
const auto gridwise_gemm_pipeline =
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>();
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock);
gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_buffers,
a_block_slice_copy_step,
b_grid_desc_bk0_n_bk1,
b_block_desc_bk0_n_bk1,
b_blockwise_copy,
b_grid_buf,
b_block_buffers,
b_block_slice_copy_step,
blockwise_gemm,
c_thread_buf,
num_k_block_main_loop);
// Shuffle C and write out.
{
static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
"wrong!");
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<CShuffleDataType*>(p_shared),
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
make_tuple(
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
M1, // M1 = MWave
M2, // M2 * M3 * M4 = MPerXdl
M3,
M4)),
make_freeze_transform(I0),
make_unmerge_transform(make_tuple(
Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
N1, // N1 = NWave
N2))), // N2 = NPerXdl
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(
Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
// Calculate the origin of thread output tensor on global memory.
const auto c_thread_mtx_on_block =
blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
make_tuple(Sequence<0, 1, 2, 3, 4>{}),
make_tuple(Sequence<0>{}));
const auto m_thread_data_on_block_idx =
m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
make_multi_index(m_thread_data_on_block));
const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
make_tuple(Sequence<0, 1, 2>{}),
make_tuple(Sequence<0>{}));
const auto n_thread_data_on_block_idx =
n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
make_multi_index(n_thread_data_on_block));
// Shuffle: threadwise copy C from VGPR to LDS.
auto c_thread_copy_vgpr_to_lds =
ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
CShuffleDataType,
decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
ck::tensor_operation::element_wise::PassThrough,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
I1,
I1,
M2,
I1,
M4,
I1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
7,
1,
InMemoryDataOperationEnum::Set,
1,
true>{
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
make_multi_index(0,
0,
m_thread_data_on_block_idx[I1],
n_thread_data_on_block_idx[I1],
m_thread_data_on_block_idx[I2],
m_thread_data_on_block_idx[I3],
m_thread_data_on_block_idx[I4],
n_thread_data_on_block_idx[I2]),
ck::tensor_operation::element_wise::PassThrough{}};
// A tuple of reference to C/Ds tensor descriptors.
const auto c_ds_desc_refs = concat_tuple_of_reference(
tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
generate_tie(
[&](auto i) -> const auto& // return type should be reference
{ return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
Number<NumDTensor>{}));
// A tuple of reference to C/Ds grid buffers.
const auto c_ds_buf_refs = concat_tuple_of_reference(
tie(c_shuffle_block_buf),
generate_tie(
[&](auto i) -> const auto& // return type should be reference
{ return ds_grid_buf[i]; },
Number<NumDTensor>{}));
// A tuple of starting index of C/Ds blockwise copy.
const auto idx_c_ds_block_begin = container_concat(
make_tuple(make_multi_index(0, 0, 0, 0)),
generate_tuple(
[&](auto) {
return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0);
},
Number<NumDTensor>{}));
// Blockwise copy C/D/E between LDS and global.
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7<
ThisThreadBlock,
decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
Tuple<EDataType>,
decltype(c_ds_desc_refs),
decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
CDEElementwiseOperation,
Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
3, // index_t VectorDim,
CDEShuffleBlockTransferScalarPerVector_NPerBlock,
sequence_merge_t<
Sequence<true>,
uniform_sequence_gen_t<NumDTensor,
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
{c_ds_desc_refs,
idx_c_ds_block_begin,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)),
cde_element_op};
// Space filling curve for threadwise C in VGPR before shuffle.
constexpr auto sfc_c_vgpr =
SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
Sequence<CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
1,
1,
M2,
1,
M4,
1>>{};
// Space filling curve for shuffled blockwise C/D/E.
constexpr auto sfc_cde_block =
SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
Sequence<0, 2, 1, 3>,
Sequence<1,
CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1,
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
static_for<0, num_access, 1>{}([&](auto access_id) {
// Make sure it's safe to write to LDS.
block_sync_lds();
// Each thread write its data from VGPR to LDS.
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
c_thread_buf,
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
c_shuffle_block_buf);
// Make sure it's safe to read from LDS.
block_sync_lds();
// Each block copy its data from LDS to global.
cde_block_copy_lds_and_global.Run(
c_ds_desc_refs,
c_ds_buf_refs,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
tie(e_grid_buf));
if constexpr(access_id < num_access - 1)
{
constexpr auto cde_lds_and_global_step =
sfc_cde_block.GetForwardStep(access_id);
// Move on Ds.
static_for<0, NumDTensor, 1>{}([&](auto i) {
cde_block_copy_lds_and_global.MoveSrcSliceWindow(
c_ds_desc_refs, i + I1, cde_lds_and_global_step);
});
// Move on E.
cde_block_copy_lds_and_global.MoveDstSliceWindow(
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
I0,
cde_lds_and_global_step);
}
});
}
}
struct Argument : public tensor_operation::device::BaseArgument
{
Argument(const void* p_a_grid,
const void* p_b_grid,
std::array<const void*, NumDTensor> p_ds_grid,
void* p_e_grid,
index_t MRaw,
index_t NRaw,
index_t KRaw,
index_t StrideA,
index_t StrideB,
std::array<index_t, NumDTensor> StrideDs,
index_t StrideE,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op)
: p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
p_ds_grid_{},
p_e_grid_{static_cast<EDataType*>(p_e_grid)},
a_grid_desc_m_k_{MakeAGridDescriptor_M_K(MRaw, KRaw, StrideA)},
b_grid_desc_n_k_{MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)},
ds_grid_desc_m_n_{},
e_grid_desc_m_n_{MakeEGridDescriptor_M_N(MRaw, NRaw, StrideE)},
a_grid_desc_ak0_m_ak1_{MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
b_grid_desc_bk0_n_bk1_{MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
e_grid_desc_mblock_mperblock_nblock_nperblock_{},
block_2_etile_map_{MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
cde_element_op_{cde_element_op},
MRaw_{MRaw},
NRaw_{NRaw},
KRaw_{KRaw}
{
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
ds_grid_desc_m_n_(i) = MakeEGridDescriptor_M_N(MRaw, NRaw, StrideDs[i]);
});
if(CheckValidity(a_grid_desc_m_k_,
b_grid_desc_n_k_,
ds_grid_desc_m_n_,
e_grid_desc_m_n_,
block_2_etile_map_))
{
ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n_);
e_grid_desc_mblock_mperblock_nblock_nperblock_ =
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n_);
}
}
void Print() const
{
std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl;
std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl;
static_for<0, NumDTensor, 1>{}(
[&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; });
std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl;
}
// Pointers
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
DsGridPointer p_ds_grid_;
EDataType* p_e_grid_;
// Tensor descriptors for problem definiton
AGridDesc_M_K a_grid_desc_m_k_;
BGridDesc_N_K b_grid_desc_n_k_;
DsGridDesc_M_N ds_grid_desc_m_n_;
EGridDesc_M_N e_grid_desc_m_n_;
// Tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock_;
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_;
// block-to-e-tile map
Block2ETileMap block_2_etile_map_;
// element-wise ops
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_;
// For checking vector load/store
index_t MRaw_;
index_t NRaw_;
index_t KRaw_;
};
};
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment