Commit b15eecba authored by letaoqin's avatar letaoqin
Browse files

Merge branch 'mha-train-develop' of...

Merge branch 'mha-train-develop' of https://github.com/ROCmSoftwarePlatform/composable_kernel into mha-train-develop
parents 87d1e073 04c206da
...@@ -282,6 +282,29 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm ...@@ -282,6 +282,29 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
return matrix_padder.PadCDescriptor_M_N( return matrix_padder.PadCDescriptor_M_N(
MakeCGridDescriptorPair(c_gs_ms_os_lengths_vec, c_gs_ms_os_strides_vec).second); MakeCGridDescriptorPair(c_gs_ms_os_lengths_vec, c_gs_ms_os_strides_vec).second);
} }
//
// C0
//
static auto MakeC0GridDescriptorPair(const std::vector<index_t>& c_gs_ms_ns_lengths_vec,
const std::vector<index_t>& c_gs_ms_ns_strides_vec)
{
return MakeGridDescriptorPair<NumDimG, NumDimM, NumDimN, CSpec>(c_gs_ms_ns_lengths_vec,
c_gs_ms_ns_strides_vec);
}
// TODO: rename to G_MRaw_NRaw
static auto MakeC0GridDescriptor_G_M_N(const std::vector<index_t>& c_gs_ms_ns_lengths_vec,
const std::vector<index_t>& c_gs_ms_ns_strides_vec)
{
return MakeC0GridDescriptorPair(c_gs_ms_ns_lengths_vec, c_gs_ms_ns_strides_vec).first;
}
static auto MakeC0GridDescriptor_M_N(const std::vector<index_t>& c_gs_ms_ns_lengths_vec,
const std::vector<index_t>& c_gs_ms_ns_strides_vec)
{
return matrix_padder.PadC0Descriptor_M_N(
MakeC0GridDescriptorPair(c_gs_ms_ns_lengths_vec, c_gs_ms_ns_strides_vec).second);
}
}; };
} // namespace tensor_operation } // namespace tensor_operation
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment