Commit 96b0f78c authored by wangshaojie6's avatar wangshaojie6
Browse files

clang-format

parent 97dcc7b2
......@@ -371,7 +371,8 @@ int main(int argc, char* argv[])
// mask out upper triangle
acc0_g_m_n.ForEach([&](auto& self, auto idx) {
if (idx[1] < idx[2]) self(idx) = -ck::NumericLimits<float>::Infinity();
if(idx[1] < idx[2])
self(idx) = -ck::NumericLimits<float>::Infinity();
});
auto ref_softmax = ReferenceSoftmaxInstance{};
......
......@@ -755,8 +755,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
running_max_new = NumericLimits<FloatGemmAcc>::Lowest();
// decoder lower triangular mask
const auto thread_cluster_idx =
threadid_to_m_n_thread_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
const auto thread_cluster_idx = threadid_to_m_n_thread_cluster_adaptor.CalculateBottomIndex(
make_multi_index(get_thread_local_1d_id()));
const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_n_cluster_id = thread_cluster_idx[I1];
const index_t MPerRepeat = MPerBlock / MXdlPerWave;
......@@ -769,8 +769,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
{
if constexpr(MaskOutUpperTriangle)
{
auto gemm0_n_block_idx = __builtin_amdgcn_readfirstlane(gemm1_k_block_outer_index * NPerBlock);
if((m_block_data_idx_on_grid < gemm0_n_block_idx) && ((m_block_data_idx_on_grid + MPerBlock - 1) < (gemm0_n_block_idx + NPerBlock - 1)))
auto gemm0_n_block_idx =
__builtin_amdgcn_readfirstlane(gemm1_k_block_outer_index * NPerBlock);
if((m_block_data_idx_on_grid < gemm0_n_block_idx) &&
((m_block_data_idx_on_grid + MPerBlock - 1) <
(gemm0_n_block_idx + NPerBlock - 1)))
{
continue;
}
......@@ -818,24 +821,29 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
const index_t nstartxdl = nstart + n0_i * NPerRepeat;
const index_t acc_idx_n0 = acc_idx_m0 + n0_i * n2 * n4;
static_for<0, n2, 1>{}([&](auto n2_i) {
const index_t nstartgroup = nstartxdl + thread_n_cluster_id * n4 + n2_i * AccN3 * n4;
const index_t nstartgroup =
nstartxdl + thread_n_cluster_id * n4 + n2_i * AccN3 * n4;
const index_t acc_idx_n2 = acc_idx_n0 + n2_i * n4;
static_for<0, n4, 1>{}([&](auto n4_i) {
const index_t n_global = nstartgroup + n4_i;
const auto acc_offset = Number<acc_idx_n2 + n4_i>{};
if (n_global > m_global)
if(n_global > m_global)
{
acc_thread_buf(acc_offset) = -ck::NumericLimits<float>::Infinity();
acc_thread_buf(acc_offset) =
-ck::NumericLimits<float>::Infinity();
}
else
{
// Acc0 elementwise Op
#if CK_WORKAROUND_SWDEV_XXXXXX_ATTN_KERNEL_CLANG_CANNOT_SCAVENGE_REGISTER
acc_element_op(acc_thread_buf(acc_offset), acc_thread_buf[acc_offset]);
#else
// Acc0 elementwise Op
#if CK_WORKAROUND_SWDEV_XXXXXX_ATTN_KERNEL_CLANG_CANNOT_SCAVENGE_REGISTER
acc_element_op(acc_thread_buf(acc_offset),
acc_thread_buf[acc_offset]);
#else
ElementOpPredicatedResetNaNToMinusInf<PadN>{}.Run(
acc_thread_buf(acc_offset), acc_element_op, acc_thread_buf[acc_offset]);
#endif
acc_thread_buf(acc_offset),
acc_element_op,
acc_thread_buf[acc_offset]);
#endif
}
});
});
......
......@@ -83,7 +83,8 @@ struct DeviceOperationInstanceFactory<
is_same_v<B1DataType, half_t> && is_same_v<CDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<B0Layout, Col> &&
is_same_v<B1Layout, Row> && is_same_v<CPermuteNumDims_G_M_Gemm1N, CPermuteNumDims_G_M_O>)
is_same_v<B1Layout, Row> &&
is_same_v<CPermuteNumDims_G_M_Gemm1N, CPermuteNumDims_G_M_O>)
{
add_device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
op_ptrs);
......
......@@ -196,7 +196,8 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
auto b1_element_op = B1ElementOp{};
auto c_element_op = CElementOp{};
using DeviceOp = tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<ALayout,
using DeviceOp =
tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<ALayout,
B0Layout,
B1Layout,
CPermuteNumDims_G_M_O,
......@@ -227,7 +228,8 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
// mask out upper triangle
acc0_g_m_n.ForEach([&](auto& self, auto idx) {
if (idx[1] < idx[2]) self(idx) = -ck::NumericLimits<float>::Infinity();
if(idx[1] < idx[2])
self(idx) = -ck::NumericLimits<float>::Infinity();
});
auto ref_softmax = ReferenceSoftmaxInstance{};
......@@ -319,8 +321,8 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
{
c_gs_ms_os_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());
pass = pass &
ck::utils::check_err(c_gs_ms_os_device_result.mData, c_gs_ms_os_host_result.mData);
pass = pass & ck::utils::check_err(c_gs_ms_os_device_result.mData,
c_gs_ms_os_host_result.mData);
if(do_log)
{
......@@ -333,8 +335,9 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
LogRangeAsType<float>(
std::cout << "c_gs_ms_os_host_result : ", c_gs_ms_os_host_result.mData, ",")
<< std::endl;
LogRangeAsType<float>(
std::cout << "c_gs_ms_os_device_result : ", c_gs_ms_os_device_result.mData, ",")
LogRangeAsType<float>(std::cout << "c_gs_ms_os_device_result : ",
c_gs_ms_os_device_result.mData,
",")
<< std::endl;
}
}
......
......@@ -5,7 +5,8 @@
#include "test_batched_gemm_masking_scale_softmax_gemm_permute_util.hpp"
template <typename Tuple>
class TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16 : public TestBatchedGemmMaskingScaleSoftmaxGemmPermute<Tuple>
class TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16
: public TestBatchedGemmMaskingScaleSoftmaxGemmPermute<Tuple>
{
};
......@@ -158,7 +159,7 @@ TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, AdhocTest)
{49, 49, 64, 64, 4, 6},
{64, 49, 64, 64, 4, 6},
{1020, 1020, 64, 128, 4, 6},
{576, 576, 64, 64, 4,6},
{576, 576, 64, 64, 4, 6},
};
this->bench_ = true;
this->Run();
......
......@@ -42,15 +42,15 @@ struct TestBatchedGemmMaskingScaleSoftmaxGemmPermute : public ::testing::Test
void RunSingle(int M, int N, int K, int O, int G0, int G1)
{
bool pass = ck::profiler::profile_batched_gemm_masking_scale_softmax_gemm_permute_impl<ADataType,
bool pass = ck::profiler::profile_batched_gemm_masking_scale_softmax_gemm_permute_impl<
ADataType,
B0DataType,
B1DataType,
CDataType,
ALayout,
B0Layout,
B1Layout,
CPermuteNumDims_G_M_O>(
verify_, 1, false, bench_, M, N, K, O, G0, G1);
CPermuteNumDims_G_M_O>(verify_, 1, false, bench_, M, N, K, O, G0, G1);
EXPECT_TRUE(pass);
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment