Commit 96b0f78c authored by wangshaojie6's avatar wangshaojie6
Browse files

clang-format

parent 97dcc7b2
...@@ -371,7 +371,8 @@ int main(int argc, char* argv[]) ...@@ -371,7 +371,8 @@ int main(int argc, char* argv[])
// mask out upper triangle // mask out upper triangle
acc0_g_m_n.ForEach([&](auto& self, auto idx) { acc0_g_m_n.ForEach([&](auto& self, auto idx) {
if (idx[1] < idx[2]) self(idx) = -ck::NumericLimits<float>::Infinity(); if(idx[1] < idx[2])
self(idx) = -ck::NumericLimits<float>::Infinity();
}); });
auto ref_softmax = ReferenceSoftmaxInstance{}; auto ref_softmax = ReferenceSoftmaxInstance{};
......
...@@ -755,8 +755,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -755,8 +755,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
running_max_new = NumericLimits<FloatGemmAcc>::Lowest(); running_max_new = NumericLimits<FloatGemmAcc>::Lowest();
// decoder lower triangular mask // decoder lower triangular mask
const auto thread_cluster_idx = const auto thread_cluster_idx = threadid_to_m_n_thread_cluster_adaptor.CalculateBottomIndex(
threadid_to_m_n_thread_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id())); make_multi_index(get_thread_local_1d_id()));
const auto thread_m_cluster_id = thread_cluster_idx[I0]; const auto thread_m_cluster_id = thread_cluster_idx[I0];
const auto thread_n_cluster_id = thread_cluster_idx[I1]; const auto thread_n_cluster_id = thread_cluster_idx[I1];
const index_t MPerRepeat = MPerBlock / MXdlPerWave; const index_t MPerRepeat = MPerBlock / MXdlPerWave;
...@@ -769,8 +769,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -769,8 +769,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
{ {
if constexpr(MaskOutUpperTriangle) if constexpr(MaskOutUpperTriangle)
{ {
auto gemm0_n_block_idx = __builtin_amdgcn_readfirstlane(gemm1_k_block_outer_index * NPerBlock); auto gemm0_n_block_idx =
if((m_block_data_idx_on_grid < gemm0_n_block_idx) && ((m_block_data_idx_on_grid + MPerBlock - 1) < (gemm0_n_block_idx + NPerBlock - 1))) __builtin_amdgcn_readfirstlane(gemm1_k_block_outer_index * NPerBlock);
if((m_block_data_idx_on_grid < gemm0_n_block_idx) &&
((m_block_data_idx_on_grid + MPerBlock - 1) <
(gemm0_n_block_idx + NPerBlock - 1)))
{ {
continue; continue;
} }
...@@ -818,24 +821,29 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle ...@@ -818,24 +821,29 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
const index_t nstartxdl = nstart + n0_i * NPerRepeat; const index_t nstartxdl = nstart + n0_i * NPerRepeat;
const index_t acc_idx_n0 = acc_idx_m0 + n0_i * n2 * n4; const index_t acc_idx_n0 = acc_idx_m0 + n0_i * n2 * n4;
static_for<0, n2, 1>{}([&](auto n2_i) { static_for<0, n2, 1>{}([&](auto n2_i) {
const index_t nstartgroup = nstartxdl + thread_n_cluster_id * n4 + n2_i * AccN3 * n4; const index_t nstartgroup =
nstartxdl + thread_n_cluster_id * n4 + n2_i * AccN3 * n4;
const index_t acc_idx_n2 = acc_idx_n0 + n2_i * n4; const index_t acc_idx_n2 = acc_idx_n0 + n2_i * n4;
static_for<0, n4, 1>{}([&](auto n4_i) { static_for<0, n4, 1>{}([&](auto n4_i) {
const index_t n_global = nstartgroup + n4_i; const index_t n_global = nstartgroup + n4_i;
const auto acc_offset = Number<acc_idx_n2 + n4_i>{}; const auto acc_offset = Number<acc_idx_n2 + n4_i>{};
if (n_global > m_global) if(n_global > m_global)
{ {
acc_thread_buf(acc_offset) = -ck::NumericLimits<float>::Infinity(); acc_thread_buf(acc_offset) =
-ck::NumericLimits<float>::Infinity();
} }
else else
{ {
// Acc0 elementwise Op // Acc0 elementwise Op
#if CK_WORKAROUND_SWDEV_XXXXXX_ATTN_KERNEL_CLANG_CANNOT_SCAVENGE_REGISTER #if CK_WORKAROUND_SWDEV_XXXXXX_ATTN_KERNEL_CLANG_CANNOT_SCAVENGE_REGISTER
acc_element_op(acc_thread_buf(acc_offset), acc_thread_buf[acc_offset]); acc_element_op(acc_thread_buf(acc_offset),
#else acc_thread_buf[acc_offset]);
#else
ElementOpPredicatedResetNaNToMinusInf<PadN>{}.Run( ElementOpPredicatedResetNaNToMinusInf<PadN>{}.Run(
acc_thread_buf(acc_offset), acc_element_op, acc_thread_buf[acc_offset]); acc_thread_buf(acc_offset),
#endif acc_element_op,
acc_thread_buf[acc_offset]);
#endif
} }
}); });
}); });
......
...@@ -83,7 +83,8 @@ struct DeviceOperationInstanceFactory< ...@@ -83,7 +83,8 @@ struct DeviceOperationInstanceFactory<
is_same_v<B1DataType, half_t> && is_same_v<CDataType, half_t>) is_same_v<B1DataType, half_t> && is_same_v<CDataType, half_t>)
{ {
if constexpr(is_same_v<ALayout, Row> && is_same_v<B0Layout, Col> && if constexpr(is_same_v<ALayout, Row> && is_same_v<B0Layout, Col> &&
is_same_v<B1Layout, Row> && is_same_v<CPermuteNumDims_G_M_Gemm1N, CPermuteNumDims_G_M_O>) is_same_v<B1Layout, Row> &&
is_same_v<CPermuteNumDims_G_M_Gemm1N, CPermuteNumDims_G_M_O>)
{ {
add_device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance( add_device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
op_ptrs); op_ptrs);
......
...@@ -196,7 +196,8 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi ...@@ -196,7 +196,8 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
auto b1_element_op = B1ElementOp{}; auto b1_element_op = B1ElementOp{};
auto c_element_op = CElementOp{}; auto c_element_op = CElementOp{};
using DeviceOp = tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<ALayout, using DeviceOp =
tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<ALayout,
B0Layout, B0Layout,
B1Layout, B1Layout,
CPermuteNumDims_G_M_O, CPermuteNumDims_G_M_O,
...@@ -227,7 +228,8 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi ...@@ -227,7 +228,8 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
// mask out upper triangle // mask out upper triangle
acc0_g_m_n.ForEach([&](auto& self, auto idx) { acc0_g_m_n.ForEach([&](auto& self, auto idx) {
if (idx[1] < idx[2]) self(idx) = -ck::NumericLimits<float>::Infinity(); if(idx[1] < idx[2])
self(idx) = -ck::NumericLimits<float>::Infinity();
}); });
auto ref_softmax = ReferenceSoftmaxInstance{}; auto ref_softmax = ReferenceSoftmaxInstance{};
...@@ -319,8 +321,8 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi ...@@ -319,8 +321,8 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
{ {
c_gs_ms_os_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data()); c_gs_ms_os_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());
pass = pass & pass = pass & ck::utils::check_err(c_gs_ms_os_device_result.mData,
ck::utils::check_err(c_gs_ms_os_device_result.mData, c_gs_ms_os_host_result.mData); c_gs_ms_os_host_result.mData);
if(do_log) if(do_log)
{ {
...@@ -333,8 +335,9 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi ...@@ -333,8 +335,9 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
LogRangeAsType<float>( LogRangeAsType<float>(
std::cout << "c_gs_ms_os_host_result : ", c_gs_ms_os_host_result.mData, ",") std::cout << "c_gs_ms_os_host_result : ", c_gs_ms_os_host_result.mData, ",")
<< std::endl; << std::endl;
LogRangeAsType<float>( LogRangeAsType<float>(std::cout << "c_gs_ms_os_device_result : ",
std::cout << "c_gs_ms_os_device_result : ", c_gs_ms_os_device_result.mData, ",") c_gs_ms_os_device_result.mData,
",")
<< std::endl; << std::endl;
} }
} }
......
...@@ -5,7 +5,8 @@ ...@@ -5,7 +5,8 @@
#include "test_batched_gemm_masking_scale_softmax_gemm_permute_util.hpp" #include "test_batched_gemm_masking_scale_softmax_gemm_permute_util.hpp"
template <typename Tuple> template <typename Tuple>
class TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16 : public TestBatchedGemmMaskingScaleSoftmaxGemmPermute<Tuple> class TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16
: public TestBatchedGemmMaskingScaleSoftmaxGemmPermute<Tuple>
{ {
}; };
...@@ -158,7 +159,7 @@ TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, AdhocTest) ...@@ -158,7 +159,7 @@ TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, AdhocTest)
{49, 49, 64, 64, 4, 6}, {49, 49, 64, 64, 4, 6},
{64, 49, 64, 64, 4, 6}, {64, 49, 64, 64, 4, 6},
{1020, 1020, 64, 128, 4, 6}, {1020, 1020, 64, 128, 4, 6},
{576, 576, 64, 64, 4,6}, {576, 576, 64, 64, 4, 6},
}; };
this->bench_ = true; this->bench_ = true;
this->Run(); this->Run();
......
...@@ -42,15 +42,15 @@ struct TestBatchedGemmMaskingScaleSoftmaxGemmPermute : public ::testing::Test ...@@ -42,15 +42,15 @@ struct TestBatchedGemmMaskingScaleSoftmaxGemmPermute : public ::testing::Test
void RunSingle(int M, int N, int K, int O, int G0, int G1) void RunSingle(int M, int N, int K, int O, int G0, int G1)
{ {
bool pass = ck::profiler::profile_batched_gemm_masking_scale_softmax_gemm_permute_impl<ADataType, bool pass = ck::profiler::profile_batched_gemm_masking_scale_softmax_gemm_permute_impl<
ADataType,
B0DataType, B0DataType,
B1DataType, B1DataType,
CDataType, CDataType,
ALayout, ALayout,
B0Layout, B0Layout,
B1Layout, B1Layout,
CPermuteNumDims_G_M_O>( CPermuteNumDims_G_M_O>(verify_, 1, false, bench_, M, N, K, O, G0, G1);
verify_, 1, false, bench_, M, N, K, O, G0, G1);
EXPECT_TRUE(pass); EXPECT_TRUE(pass);
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment