Commit 96b0f78c authored by wangshaojie6

clang-format

parent 97dcc7b2
@@ -370,8 +370,9 @@ int main(int argc, char* argv[])
     ref_gemm0_invoker.Run(ref_gemm0_argument);

     // mask out upper triangle
     acc0_g_m_n.ForEach([&](auto& self, auto idx) {
-        if (idx[1] < idx[2]) self(idx) = -ck::NumericLimits<float>::Infinity();
+        if(idx[1] < idx[2])
+            self(idx) = -ck::NumericLimits<float>::Infinity();
     });

     auto ref_softmax = ReferenceSoftmaxInstance{};
...
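A note on the masked reference path above: idx[1] is the query row m and idx[2] the key column n, so the condition m < n sends exactly the strictly-upper-triangular scores to -inf, and the softmax that follows assigns those positions zero weight. A minimal standalone sketch of the same causal mask on a plain row-major buffer (hypothetical helper, not part of this commit):

#include <limits>
#include <vector>

// Causal (decoder) mask on a row-major M x N score matrix: any entry with
// key index n strictly greater than query index m is set to -inf so the
// following softmax assigns it zero probability.
void apply_causal_mask(std::vector<float>& scores, int M, int N)
{
    const float neg_inf = -std::numeric_limits<float>::infinity();
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
            if(n > m)
                scores[m * N + n] = neg_inf;
}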
@@ -755,13 +755,13 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         running_max_new = NumericLimits<FloatGemmAcc>::Lowest();

         // decoder lower triangular mask
-        const auto thread_cluster_idx =
-            threadid_to_m_n_thread_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
+        const auto thread_cluster_idx = threadid_to_m_n_thread_cluster_adaptor.CalculateBottomIndex(
+            make_multi_index(get_thread_local_1d_id()));
         const auto thread_m_cluster_id = thread_cluster_idx[I0];
         const auto thread_n_cluster_id = thread_cluster_idx[I1];
         const index_t MPerRepeat = MPerBlock / MXdlPerWave;
         const index_t NPerRepeat = NPerBlock / NXdlPerWave;
         const index_t mstart = m_block_data_idx_on_grid + thread_m_cluster_id;

         // gemm1 K loop
         index_t gemm1_k_block_outer_index = 0;
@@ -769,8 +769,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         {
             if constexpr(MaskOutUpperTriangle)
             {
-                auto gemm0_n_block_idx = __builtin_amdgcn_readfirstlane(gemm1_k_block_outer_index * NPerBlock);
-                if((m_block_data_idx_on_grid < gemm0_n_block_idx) && ((m_block_data_idx_on_grid + MPerBlock - 1) < (gemm0_n_block_idx + NPerBlock - 1)))
+                auto gemm0_n_block_idx =
+                    __builtin_amdgcn_readfirstlane(gemm1_k_block_outer_index * NPerBlock);
+                if((m_block_data_idx_on_grid < gemm0_n_block_idx) &&
+                   ((m_block_data_idx_on_grid + MPerBlock - 1) <
+                    (gemm0_n_block_idx + NPerBlock - 1)))
                 {
                     continue;
                 }
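For orientation: gemm0_n_block_idx is the first key column handled by this outer gemm1 iteration, and the branch skips the whole gemm0 tile when every element it covers lies above the diagonal. A minimal sketch of the intent behind that test, assuming a strictly causal mask (n > m masked); the kernel's actual predicate above is phrased slightly differently:

// A MPerBlock x NPerBlock tile starting at (m_block_start, n_block_start)
// contains only masked elements when even its largest row index is still
// below its smallest column index, so it can be skipped entirely.
bool tile_fully_masked(int m_block_start, int n_block_start, int MPerBlock)
{
    return (m_block_start + MPerBlock - 1) < n_block_start;
}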
@@ -807,40 +810,45 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                 }
                 else
                 {
                     const index_t nstart = gemm1_k_block_outer_index * NPerBlock;
                     static_for<0, m0, 1>{}([&](auto m0_i) {
                         const index_t m_global = mstart + m0_i * MPerRepeat;
                         const index_t acc_idx_m0 = m0_i * n0 * n2 * n4;
                         static_for<0, n0, 1>{}([&](auto n0_i) {
                             // constexpr auto nrepeat_i = n0_i * NPerRepeat;
                             // const index_t nstartxdl = nstart + nrepeat_i;
                             const index_t nstartxdl = nstart + n0_i * NPerRepeat;
                             const index_t acc_idx_n0 = acc_idx_m0 + n0_i * n2 * n4;
                             static_for<0, n2, 1>{}([&](auto n2_i) {
-                                const index_t nstartgroup = nstartxdl + thread_n_cluster_id * n4 + n2_i * AccN3 * n4;
+                                const index_t nstartgroup =
+                                    nstartxdl + thread_n_cluster_id * n4 + n2_i * AccN3 * n4;
                                 const index_t acc_idx_n2 = acc_idx_n0 + n2_i * n4;
                                 static_for<0, n4, 1>{}([&](auto n4_i) {
                                     const index_t n_global = nstartgroup + n4_i;
                                     const auto acc_offset = Number<acc_idx_n2 + n4_i>{};
-                                    if (n_global > m_global)
+                                    if(n_global > m_global)
                                     {
-                                        acc_thread_buf(acc_offset) = -ck::NumericLimits<float>::Infinity();
+                                        acc_thread_buf(acc_offset) =
+                                            -ck::NumericLimits<float>::Infinity();
                                     }
                                     else
                                     {
                                         // Acc0 elementwise Op
 #if CK_WORKAROUND_SWDEV_XXXXXX_ATTN_KERNEL_CLANG_CANNOT_SCAVENGE_REGISTER
-                                        acc_element_op(acc_thread_buf(acc_offset), acc_thread_buf[acc_offset]);
+                                        acc_element_op(acc_thread_buf(acc_offset),
+                                                       acc_thread_buf[acc_offset]);
 #else
                                         ElementOpPredicatedResetNaNToMinusInf<PadN>{}.Run(
-                                            acc_thread_buf(acc_offset), acc_element_op, acc_thread_buf[acc_offset]);
+                                            acc_thread_buf(acc_offset),
+                                            acc_element_op,
+                                            acc_thread_buf[acc_offset]);
 #endif
                                     }
                                 });
                             });
                         });
                     });
                 }

                 block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
...
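The nested static_for chain above reconstructs the global (m, n) coordinate of every accumulator element from the XDL register layout (the m0 x n0 x n2 x n4 decomposition plus the thread's n-cluster offset) and masks each element individually. Stripped of the layout arithmetic, the per-element logic reduces to the following simplified sketch, where row_of and col_of are hypothetical stand-ins for that index decomposition:

#include <cstddef>
#include <limits>
#include <vector>

// Simplified scalar view of the per-element masking: map each accumulator
// element back to its global (m, n); mask it to -inf above the diagonal,
// otherwise apply the acc0 elementwise op (Scale in this kernel).
template <typename RowOf, typename ColOf, typename Op>
void mask_acc_elements(std::vector<float>& acc, RowOf row_of, ColOf col_of, Op op)
{
    for(std::size_t i = 0; i < acc.size(); ++i)
    {
        if(col_of(i) > row_of(i))
            acc[i] = -std::numeric_limits<float>::infinity();
        else
            acc[i] = op(acc[i]);
    }
}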
@@ -25,18 +25,18 @@ using CPermuteNumDims_G_M_O =
 void add_device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
     std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<Row,
                                                                     Col,
                                                                     Row,
                                                                     CPermuteNumDims_G_M_O,
                                                                     F16,
                                                                     F16,
                                                                     F16,
                                                                     F16,
                                                                     PassThrough,
                                                                     PassThrough,
                                                                     Scale,
                                                                     PassThrough,
                                                                     PassThrough>>>& instances);

 template <typename ALayout,
           typename B0Layout,
@@ -48,32 +48,32 @@ template <typename ALayout,
           typename CDataType>
 struct DeviceOperationInstanceFactory<
     ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<ALayout,
                                                                       B0Layout,
                                                                       B1Layout,
                                                                       CPermuteNumDims_G_M_Gemm1N,
                                                                       ADataType,
                                                                       B0DataType,
                                                                       B1DataType,
                                                                       CDataType,
                                                                       PassThrough,
                                                                       PassThrough,
                                                                       Scale,
                                                                       PassThrough,
                                                                       PassThrough>>
 {
     using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute<ALayout,
                                                          B0Layout,
                                                          B1Layout,
                                                          CPermuteNumDims_G_M_Gemm1N,
                                                          ADataType,
                                                          B0DataType,
                                                          B1DataType,
                                                          CDataType,
                                                          PassThrough,
                                                          PassThrough,
                                                          Scale,
                                                          PassThrough,
                                                          PassThrough>;

     static auto GetInstances()
     {
@@ -83,7 +83,8 @@ struct DeviceOperationInstanceFactory<
                      is_same_v<B1DataType, half_t> && is_same_v<CDataType, half_t>)
         {
             if constexpr(is_same_v<ALayout, Row> && is_same_v<B0Layout, Col> &&
-                         is_same_v<B1Layout, Row> && is_same_v<CPermuteNumDims_G_M_Gemm1N, CPermuteNumDims_G_M_O>)
+                         is_same_v<B1Layout, Row> &&
+                         is_same_v<CPermuteNumDims_G_M_Gemm1N, CPermuteNumDims_G_M_O>)
             {
                 add_device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
                     op_ptrs);
...
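For context, instances registered through this factory are consumed by enumerating GetInstances() and filtering on IsSupportedArgument before running. A hedged skeleton of that idiom follows; argument construction is elided (MakeArgumentPointer takes the full problem description: pointers, lengths, strides, element ops), and the same pattern appears in the profiler file below:

// Skeleton only, not compilable as-is: the problem description passed to
// MakeArgumentPointer is elided.
using DeviceOp = ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<
    Row, Col, Row, CPermuteNumDims_G_M_O, F16, F16, F16, F16,
    PassThrough, PassThrough, Scale, PassThrough, PassThrough>;

const auto op_ptrs = ck::tensor_operation::device::instance::
    DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

for(auto& op_ptr : op_ptrs)
{
    auto argument = op_ptr->MakeArgumentPointer(/* problem description */);
    if(op_ptr->IsSupportedArgument(argument.get()))
    {
        auto invoker = op_ptr->MakeInvokerPointer();
        invoker->Run(argument.get(), StreamConfig{nullptr, /* time_kernel */ false});
    }
}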
@@ -27,7 +27,7 @@ using CPermuteNumDims_G_M_O =
     S<2, 1, 1>; // "using CLayout = Row" has been replaced by CPermuteNumDims_G_M_O

 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 using Scale = ck::tensor_operation::element_wise::Scale;

 static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
 static constexpr auto GemmPadded =
@@ -61,18 +61,18 @@ using device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f1
 void add_device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
     std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<Row,
                                                                     Col,
                                                                     Row,
                                                                     CPermuteNumDims_G_M_O,
                                                                     F16,
                                                                     F16,
                                                                     F16,
                                                                     F16,
                                                                     PassThrough,
                                                                     PassThrough,
                                                                     Scale,
                                                                     PassThrough,
                                                                     PassThrough>>>& instances)
 {
     add_device_operation_instances(
         instances,
...
@@ -31,22 +31,22 @@ template <typename ADataType,
           typename B1Layout,
           typename CPermuteNumDims_G_M_O>
 bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verification,
                                                                   int init_method,
                                                                   bool do_log,
                                                                   bool time_kernel,
                                                                   int M,
                                                                   int N,
                                                                   int K,
                                                                   int O,
                                                                   int G0,
                                                                   int G1,
                                                                   int StrideA = -1,
                                                                   int StrideB0 = -1,
                                                                   int StrideB1 = -1,
                                                                   int BatchStrideA = -1,
                                                                   int BatchStrideB0 = -1,
                                                                   int BatchStrideB1 = -1,
                                                                   float alpha = 1.f)
 {
@@ -196,19 +196,20 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
     auto b1_element_op = B1ElementOp{};
     auto c_element_op = CElementOp{};

-    using DeviceOp = tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<ALayout,
-                                                                                   B0Layout,
-                                                                                   B1Layout,
-                                                                                   CPermuteNumDims_G_M_O,
-                                                                                   ADataType,
-                                                                                   B0DataType,
-                                                                                   B1DataType,
-                                                                                   CDataType,
-                                                                                   AElementOp,
-                                                                                   B0ElementOp,
-                                                                                   Acc0ElementOp,
-                                                                                   B1ElementOp,
-                                                                                   CElementOp>;
+    using DeviceOp =
+        tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<ALayout,
+                                                                      B0Layout,
+                                                                      B1Layout,
+                                                                      CPermuteNumDims_G_M_O,
+                                                                      ADataType,
+                                                                      B0DataType,
+                                                                      B1DataType,
+                                                                      CDataType,
+                                                                      AElementOp,
+                                                                      B0ElementOp,
+                                                                      Acc0ElementOp,
+                                                                      B1ElementOp,
+                                                                      CElementOp>;

     // get device op instances
     const auto op_ptrs = tensor_operation::device::instance::DeviceOperationInstanceFactory<
@@ -226,8 +227,9 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
     ref_gemm0_invoker.Run(ref_gemm0_argument);

     // mask out upper triangle
     acc0_g_m_n.ForEach([&](auto& self, auto idx) {
-        if (idx[1] < idx[2]) self(idx) = -ck::NumericLimits<float>::Infinity();
+        if(idx[1] < idx[2])
+            self(idx) = -ck::NumericLimits<float>::Infinity();
     });

     auto ref_softmax = ReferenceSoftmaxInstance{};
@@ -319,8 +321,8 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
     {
         c_gs_ms_os_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());

-        pass = pass &
-               ck::utils::check_err(c_gs_ms_os_device_result.mData, c_gs_ms_os_host_result.mData);
+        pass = pass & ck::utils::check_err(c_gs_ms_os_device_result.mData,
+                                           c_gs_ms_os_host_result.mData);

         if(do_log)
         {
@@ -333,8 +335,9 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
             LogRangeAsType<float>(
                 std::cout << "c_gs_ms_os_host_result : ", c_gs_ms_os_host_result.mData, ",")
                 << std::endl;
-            LogRangeAsType<float>(
-                std::cout << "c_gs_ms_os_device_result : ", c_gs_ms_os_device_result.mData, ",")
+            LogRangeAsType<float>(std::cout << "c_gs_ms_os_device_result : ",
+                                            c_gs_ms_os_device_result.mData,
+                                            ",")
                 << std::endl;
         }
     }
...
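The verification branch above composes the reference ops in the same order the fused kernel uses: acc0 = alpha * (A x B0^T), causal mask, row-wise softmax, then the second GEMM against B1. A self-contained single-batch sketch of that host pipeline (hypothetical helper on row-major buffers; it mirrors the CK reference ops only in spirit):

#include <algorithm>
#include <cmath>
#include <vector>

// Single-batch host reference for the fused op: acc0 = alpha * (A x B0^T),
// causal mask, numerically stable row softmax, then C = P x B1.
// A is M x K (row-major, gmk), B0 is N x K (gnk), B1 is N x O (gno).
std::vector<float> reference_attention(const std::vector<float>& A,
                                       const std::vector<float>& B0,
                                       const std::vector<float>& B1,
                                       int M, int N, int K, int O, float alpha)
{
    std::vector<float> P(M * N), C(M * O, 0.f);
    for(int m = 0; m < M; ++m)
    {
        // scaled scores with the causal mask applied on the fly
        float row_max = -INFINITY;
        for(int n = 0; n < N; ++n)
        {
            float s = 0.f;
            for(int k = 0; k < K; ++k)
                s += A[m * K + k] * B0[n * K + k];
            P[m * N + n] = (n > m) ? -INFINITY : alpha * s;
            row_max = std::max(row_max, P[m * N + n]);
        }
        // numerically stable softmax over the row (masked entries exp to 0)
        float sum = 0.f;
        for(int n = 0; n < N; ++n)
        {
            P[m * N + n] = std::exp(P[m * N + n] - row_max);
            sum += P[m * N + n];
        }
        for(int n = 0; n < N; ++n)
            P[m * N + n] /= sum;
        // second GEMM against B1
        for(int n = 0; n < N; ++n)
            for(int o = 0; o < O; ++o)
                C[m * O + o] += P[m * N + n] * B1[n * O + o];
    }
    return C;
}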
@@ -5,7 +5,8 @@
 #include "test_batched_gemm_masking_scale_softmax_gemm_permute_util.hpp"

 template <typename Tuple>
-class TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16 : public TestBatchedGemmMaskingScaleSoftmaxGemmPermute<Tuple>
+class TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16
+    : public TestBatchedGemmMaskingScaleSoftmaxGemmPermute<Tuple>
 {
 };
@@ -158,7 +159,7 @@ TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, AdhocTest)
         {49, 49, 64, 64, 4, 6},
         {64, 49, 64, 64, 4, 6},
         {1020, 1020, 64, 128, 4, 6},
-        {576, 576, 64, 64, 4,6},
+        {576, 576, 64, 64, 4, 6},
     };
     this->bench_ = true;
     this->Run();
...
@@ -42,15 +42,15 @@ struct TestBatchedGemmMaskingScaleSoftmaxGemmPermute : public ::testing::Test
     void RunSingle(int M, int N, int K, int O, int G0, int G1)
     {
-        bool pass = ck::profiler::profile_batched_gemm_masking_scale_softmax_gemm_permute_impl<ADataType,
-                                                                                               B0DataType,
-                                                                                               B1DataType,
-                                                                                               CDataType,
-                                                                                               ALayout,
-                                                                                               B0Layout,
-                                                                                               B1Layout,
-                                                                                               CPermuteNumDims_G_M_O>(
-            verify_, 1, false, bench_, M, N, K, O, G0, G1);
+        bool pass = ck::profiler::profile_batched_gemm_masking_scale_softmax_gemm_permute_impl<
+            ADataType,
+            B0DataType,
+            B1DataType,
+            CDataType,
+            ALayout,
+            B0Layout,
+            B1Layout,
+            CPermuteNumDims_G_M_O>(verify_, 1, false, bench_, M, N, K, O, G0, G1);

         EXPECT_TRUE(pass);
     }
@@ -59,12 +59,12 @@ struct TestBatchedGemmMaskingScaleSoftmaxGemmPermute : public ::testing::Test
     {
         for(auto lengths : this->lengths_)
         {
             int M = lengths[0];
             int N = lengths[1];
             int K = lengths[2];
             int O = lengths[3];
             int G0 = lengths[4];
             int G1 = lengths[5];

             this->RunSingle(M, N, K, O, G0, G1);
         }
@@ -75,7 +75,7 @@ template <GemmSpecialization GemmSpec>
 struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
 {
     using PassThrough = ck::tensor_operation::element_wise::PassThrough;
     using Scale = ck::tensor_operation::element_wise::Scale;

     using ALayout = Row;
     using B0Layout = Col;
@@ -174,8 +174,8 @@ struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
             K,
             O,
             0,             // BatchCount
             {0, 0, M, O},  // gs ms ns lengths
             {0, O, 0, 1},  // gs ms ns strides
             0,             // StrideA
             0,             // StrideB0
             0,             // StrideB1
@@ -184,7 +184,7 @@ struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
             0,              // BatchStrideB1
             PassThrough{},  // a_element_op
             PassThrough{},  // b0_element_op
             Scale{1.f},     // acc0_element_op
             PassThrough{},  // b1_element_op
             PassThrough{}); // c_element_op
...