Unverified Commit c0c52268 authored by Dan Yao's avatar Dan Yao Committed by GitHub
Browse files

Merge pull request #905 from ROCmSoftwarePlatform/mha-train-develop-grad-bias

flash attention output   bias grad
parents f04ec574 c88d1173
......@@ -282,6 +282,29 @@ struct TransformBatchedContractionContractionToBatchedGemmGemm
return matrix_padder.PadCDescriptor_M_N(
MakeCGridDescriptorPair(c_gs_ms_os_lengths_vec, c_gs_ms_os_strides_vec).second);
}
//
// C0
//
// Builds both C0 grid descriptors in a single call by forwarding the length and
// stride vectors to MakeGridDescriptorPair, instantiated with NumDimG / NumDimM /
// NumDimN and CSpec. The pair is returned untouched; the helpers below select
// the member they need through .first / .second.
// NOTE(review): C0 reuses CSpec rather than a dedicated spec - confirm this is intended.
static auto MakeC0GridDescriptorPair(const std::vector<index_t>& c_gs_ms_ns_lengths_vec,
                                     const std::vector<index_t>& c_gs_ms_ns_strides_vec)
{
    const auto& lengths = c_gs_ms_ns_lengths_vec;
    const auto& strides = c_gs_ms_ns_strides_vec;

    return MakeGridDescriptorPair<NumDimG, NumDimM, NumDimN, CSpec>(lengths, strides);
}
// TODO: rename to G_MRaw_NRaw
// Returns the first element of the C0 descriptor pair produced by
// MakeC0GridDescriptorPair for the given lengths and strides. No padding is
// applied in this function.
static auto MakeC0GridDescriptor_G_M_N(const std::vector<index_t>& c_gs_ms_ns_lengths_vec,
                                       const std::vector<index_t>& c_gs_ms_ns_strides_vec)
{
    const auto c0_descriptor_pair =
        MakeC0GridDescriptorPair(c_gs_ms_ns_lengths_vec, c_gs_ms_ns_strides_vec);

    return c0_descriptor_pair.first;
}
// Takes the second element of the C0 descriptor pair produced by
// MakeC0GridDescriptorPair and returns it after passing it through
// matrix_padder.PadC0Descriptor_M_N.
static auto MakeC0GridDescriptor_M_N(const std::vector<index_t>& c_gs_ms_ns_lengths_vec,
                                     const std::vector<index_t>& c_gs_ms_ns_strides_vec)
{
    // Descriptor as built from the caller's lengths/strides, before the padder runs.
    const auto c0_grid_desc_unpadded =
        MakeC0GridDescriptorPair(c_gs_ms_ns_lengths_vec, c_gs_ms_ns_strides_vec).second;

    return matrix_padder.PadC0Descriptor_M_N(c0_grid_desc_unpadded);
}
};
} // namespace tensor_operation
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment