Commit 5de8ecfe authored by Astha Rai's avatar Astha Rai
Browse files

added header guards for gridwise gemm files:...

added header guards for gridwise gemm files: gridwise_gemm_multiple_abd_xdl_cshuffle.hpp and gridwise_gemm_multiple_d_xdl_cshuffle.hpp
parent 02153e24
...@@ -423,10 +423,17 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle ...@@ -423,10 +423,17 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
} }
template <typename AsLayout, GemmSpecialization GemmSpec> template <typename AsLayout, GemmSpecialization GemmSpec>
__host__ __device__ static auto __host__ __device__ static auto MakeAsGridDescriptor_M_K(
MakeAsGridDescriptor_M_K(const std::array<index_t, NumATensor>& MRaws, #ifdef CK_CODE_GEN_RTC
const ck::Array<index_t, NumATensor>& MRaws,
const ck::Array<index_t, NumATensor>& KRaws,
const ck::Array<index_t, NumATensor>& AsStride
#else
const std::array<index_t, NumATensor>& MRaws,
const std::array<index_t, NumATensor>& KRaws, const std::array<index_t, NumATensor>& KRaws,
const std::array<index_t, NumATensor>& AsStride) const std::array<index_t, NumATensor>& AsStride
#endif
)
{ {
return generate_tuple( return generate_tuple(
[&](auto i) { [&](auto i) {
...@@ -462,10 +469,17 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle ...@@ -462,10 +469,17 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
} }
template <typename BsLayout, GemmSpecialization GemmSpec> template <typename BsLayout, GemmSpecialization GemmSpec>
__host__ __device__ static auto __host__ __device__ static auto MakeBsGridDescriptor_N_K(
MakeBsGridDescriptor_N_K(const std::array<index_t, NumBTensor>& NRaws, #ifdef CK_CODE_GEN_RTC
const ck::Array<index_t, NumBTensor>& NRaws,
const ck::Array<index_t, NumBTensor>& KRaws,
const ck::Array<index_t, NumBTensor>& BsStride
#else
const std::array<index_t, NumBTensor>& NRaws,
const std::array<index_t, NumBTensor>& KRaws, const std::array<index_t, NumBTensor>& KRaws,
const std::array<index_t, NumBTensor>& BsStride) const std::array<index_t, NumBTensor>& BsStride
#endif
)
{ {
return generate_tuple( return generate_tuple(
[&](auto i) { [&](auto i) {
...@@ -500,10 +514,17 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle ...@@ -500,10 +514,17 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
} }
template <typename DsLayout, GemmSpecialization GemmSpec> template <typename DsLayout, GemmSpecialization GemmSpec>
__host__ __device__ static auto __host__ __device__ static auto MakeDsGridDescriptor_M_N(
MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws, #ifdef CK_CODE_GEN_RTC
const ck::Array<index_t, NumDTensor>& MRaws,
const ck::Array<index_t, NumDTensor>& NRaws,
const ck::Array<index_t, NumDTensor>& DsStride
#else
const std::array<index_t, NumDTensor>& MRaws,
const std::array<index_t, NumDTensor>& NRaws, const std::array<index_t, NumDTensor>& NRaws,
const std::array<index_t, NumDTensor>& DsStride) const std::array<index_t, NumDTensor>& DsStride
#endif
)
{ {
return generate_tuple( return generate_tuple(
[&](auto i) { [&](auto i) {
...@@ -969,9 +990,15 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle ...@@ -969,9 +990,15 @@ struct GridwiseGemmMultipleABD_xdl_cshuffle
const index_t M, const index_t M,
const index_t N, const index_t N,
const index_t K, const index_t K,
#ifdef CK_CODE_GEN_RTC
const ck::Array<index_t, NumATensor> StrideAs,
const ck::Array<index_t, NumBTensor> StrideBs,
const ck::Array<index_t, NumDTensor> StrideDs,
#else
const std::array<index_t, NumATensor> StrideAs, const std::array<index_t, NumATensor> StrideAs,
const std::array<index_t, NumBTensor> StrideBs, const std::array<index_t, NumBTensor> StrideBs,
const std::array<index_t, NumDTensor> StrideDs, const std::array<index_t, NumDTensor> StrideDs,
#endif
const index_t StrideE, const index_t StrideE,
const Block2ETileMap& block_2_etile_map) const Block2ETileMap& block_2_etile_map)
{ {
......
...@@ -473,11 +473,19 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -473,11 +473,19 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
} }
#ifdef CK_CODE_GEN_RTC
template <typename DsLayout, GemmSpecialization GemmSpec>
__host__ __device__ static auto
MakeDsGridDescriptor_M_N(const ck::Array<index_t, NumDTensor>& MRaws,
const ck::Array<index_t, NumDTensor>& NRaws,
const ck::Array<index_t, NumDTensor>& DsStride)
#else
template <typename DsLayout, GemmSpecialization GemmSpec> template <typename DsLayout, GemmSpecialization GemmSpec>
__host__ __device__ static auto __host__ __device__ static auto
MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws, MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
const std::array<index_t, NumDTensor>& NRaws, const std::array<index_t, NumDTensor>& NRaws,
const std::array<index_t, NumDTensor>& DsStride) const std::array<index_t, NumDTensor>& DsStride)
#endif
{ {
return generate_tuple( return generate_tuple(
[&](auto i) { [&](auto i) {
...@@ -941,7 +949,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle ...@@ -941,7 +949,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
const index_t K, const index_t K,
const index_t StrideA, const index_t StrideA,
const index_t StrideB, const index_t StrideB,
#ifdef CK_CODE_GEN_RTC
const ck::Array<index_t, NumDTensor> StrideDs,
#else
const std::array<index_t, NumDTensor> StrideDs, const std::array<index_t, NumDTensor> StrideDs,
#endif
const index_t StrideE, const index_t StrideE,
const Block2ETileMap& block_2_etile_map) const Block2ETileMap& block_2_etile_map)
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment