Commit 96b0f78c authored by wangshaojie6

clang-format

parent 97dcc7b2
@@ -370,8 +370,9 @@ int main(int argc, char* argv[])
     ref_gemm0_invoker.Run(ref_gemm0_argument);
     // mask out upper triangle
-    acc0_g_m_n.ForEach([&](auto& self, auto idx) {
-        if (idx[1] < idx[2]) self(idx) = -ck::NumericLimits<float>::Infinity();
+    acc0_g_m_n.ForEach([&](auto& self, auto idx) {
+        if(idx[1] < idx[2])
+            self(idx) = -ck::NumericLimits<float>::Infinity();
     });
     auto ref_softmax = ReferenceSoftmaxInstance{};
......
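Note on the hunk above: the reference check materializes the full gemm0 accumulator and sets every position with m < n (`idx[1]` < `idx[2]`, the upper triangle, i.e. attention to future tokens) to -infinity before running the reference softmax. A minimal standalone sketch of the same masking rule, in plain C++ with no CK types and an assumed 4x4 tile:

```cpp
#include <cstdio>
#include <limits>
#include <vector>

int main()
{
    constexpr int M = 4, N = 4;
    std::vector<float> acc(M * N, 1.0f); // stand-in for one batch of gemm0 output

    // Same rule as acc0_g_m_n.ForEach above (idx[1] is m, idx[2] is n):
    // positions with m < n, the upper triangle, are masked to -inf.
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
            if(m < n)
                acc[m * N + n] = -std::numeric_limits<float>::infinity();

    for(int m = 0; m < M; ++m)
    {
        for(int n = 0; n < N; ++n)
            std::printf("%6.1f ", acc[m * N + n]);
        std::printf("\n"); // only the lower triangle (n <= m) survives
    }
    return 0;
}
```

After softmax, the -inf entries become exact zeros, which is what makes this equivalent to a causal (decoder) mask.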
@@ -755,13 +755,13 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
             running_max_new = NumericLimits<FloatGemmAcc>::Lowest();
             // decoder lower triangular mask
-            const auto thread_cluster_idx =
-                threadid_to_m_n_thread_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
+            const auto thread_cluster_idx = threadid_to_m_n_thread_cluster_adaptor.CalculateBottomIndex(
+                make_multi_index(get_thread_local_1d_id()));
             const auto thread_m_cluster_id = thread_cluster_idx[I0];
             const auto thread_n_cluster_id = thread_cluster_idx[I1];
-            const index_t MPerRepeat = MPerBlock / MXdlPerWave;
-            const index_t NPerRepeat = NPerBlock / NXdlPerWave;
-            const index_t mstart = m_block_data_idx_on_grid + thread_m_cluster_id;
+            const index_t MPerRepeat = MPerBlock / MXdlPerWave;
+            const index_t NPerRepeat = NPerBlock / NXdlPerWave;
+            const index_t mstart     = m_block_data_idx_on_grid + thread_m_cluster_id;
             // gemm1 K loop
             index_t gemm1_k_block_outer_index = 0;
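For readers of the hunk above: `CalculateBottomIndex` decomposes the flat lane id from `get_thread_local_1d_id()` into a 2D (m, n) thread-cluster coordinate. A sketch of the equivalent row-major decomposition; the cluster shape used here (8x32) is an assumption for illustration, not taken from this kernel's descriptor, and the real adaptor encodes whatever dimension order the blockwise softmax cluster uses:

```cpp
#include <cassert>
#include <cstdio>

// Hypothetical thread-cluster shape, for illustration only; the real values
// come from the blockwise-softmax thread cluster descriptor in this kernel.
constexpr int MThreadClusterSize = 8;
constexpr int NThreadClusterSize = 32; // 8 * 32 = 256 threads per workgroup

int main()
{
    for(int tid = 0; tid < MThreadClusterSize * NThreadClusterSize; ++tid)
    {
        // Row-major split with n fastest-varying -- the effect
        // CalculateBottomIndex has for a 2D [M, N] cluster laid out that way.
        const int thread_m_cluster_id = tid / NThreadClusterSize;
        const int thread_n_cluster_id = tid % NThreadClusterSize;
        assert(thread_m_cluster_id * NThreadClusterSize + thread_n_cluster_id == tid);
        if(tid < 2 || tid == 255)
            std::printf("tid %3d -> (m=%d, n=%d)\n", tid, thread_m_cluster_id, thread_n_cluster_id);
    }
    return 0;
}
```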
@@ -769,8 +769,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         {
             if constexpr(MaskOutUpperTriangle)
             {
-                auto gemm0_n_block_idx = __builtin_amdgcn_readfirstlane(gemm1_k_block_outer_index * NPerBlock);
-                if((m_block_data_idx_on_grid < gemm0_n_block_idx) && ((m_block_data_idx_on_grid + MPerBlock - 1) < (gemm0_n_block_idx + NPerBlock - 1)))
+                auto gemm0_n_block_idx =
+                    __builtin_amdgcn_readfirstlane(gemm1_k_block_outer_index * NPerBlock);
+                if((m_block_data_idx_on_grid < gemm0_n_block_idx) &&
+                   ((m_block_data_idx_on_grid + MPerBlock - 1) <
+                    (gemm0_n_block_idx + NPerBlock - 1)))
                 {
                     continue;
                 }
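The rewrapped condition above skips gemm0 N-blocks that lie entirely above the causal diagonal for this M-block, so fully-masked tiles are never computed at all. The kernel phrases the test as a pair of inequalities on the first row/column; on the aligned tile grid this reduces to the single condition sketched below (tile sizes here are assumed 128x128 for illustration):

```cpp
#include <cstdio>

constexpr int MPerBlock = 128;
constexpr int NPerBlock = 128;

// A tile is entirely above the causal diagonal (all m < n, hence all -inf
// after masking) when even its last row is left of its first column.
bool tile_fully_masked(int m_block_start, int n_block_start)
{
    return (m_block_start + MPerBlock - 1) < n_block_start;
}

int main()
{
    const int M = 512, N = 512;
    for(int m0 = 0; m0 < M; m0 += MPerBlock)
    {
        for(int n0 = 0; n0 < N; n0 += NPerBlock)
            std::printf("%c ", tile_fully_masked(m0, n0) ? 'x' : '.');
        std::printf("\n"); // 'x' tiles can be skipped, as the 'continue' above does
    }
    return 0;
}
```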
@@ -807,40 +810,45 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
                 }
                 else
                 {
-                    const index_t nstart = gemm1_k_block_outer_index * NPerBlock;
-                    static_for<0, m0, 1>{}([&](auto m0_i) {
-                        const index_t m_global = mstart + m0_i * MPerRepeat;
-                        const index_t acc_idx_m0 = m0_i * n0 * n2 * n4;
-                        static_for<0, n0, 1>{}([&](auto n0_i) {
-                            // constexpr auto nrepeat_i = n0_i * NPerRepeat;
-                            // const index_t nstartxdl = nstart + nrepeat_i;
-                            const index_t nstartxdl = nstart + n0_i * NPerRepeat;
-                            const index_t acc_idx_n0 = acc_idx_m0 + n0_i * n2 * n4;
-                            static_for<0, n2, 1>{}([&](auto n2_i) {
-                                const index_t nstartgroup = nstartxdl + thread_n_cluster_id * n4 + n2_i * AccN3 * n4;
-                                const index_t acc_idx_n2 = acc_idx_n0 + n2_i * n4;
-                                static_for<0, n4, 1>{}([&](auto n4_i) {
-                                    const index_t n_global = nstartgroup + n4_i;
-                                    const auto acc_offset = Number<acc_idx_n2 + n4_i>{};
-                                    if (n_global > m_global)
-                                    {
-                                        acc_thread_buf(acc_offset) = -ck::NumericLimits<float>::Infinity();
-                                    }
-                                    else
-                                    {
-                                        // Acc0 elementwise Op
-#if CK_WORKAROUND_SWDEV_XXXXXX_ATTN_KERNEL_CLANG_CANNOT_SCAVENGE_REGISTER
-                                        acc_element_op(acc_thread_buf(acc_offset), acc_thread_buf[acc_offset]);
-#else
+                    const index_t nstart = gemm1_k_block_outer_index * NPerBlock;
+                    static_for<0, m0, 1>{}([&](auto m0_i) {
+                        const index_t m_global = mstart + m0_i * MPerRepeat;
+                        const index_t acc_idx_m0 = m0_i * n0 * n2 * n4;
+                        static_for<0, n0, 1>{}([&](auto n0_i) {
+                            // constexpr auto nrepeat_i = n0_i * NPerRepeat;
+                            // const index_t nstartxdl = nstart + nrepeat_i;
+                            const index_t nstartxdl = nstart + n0_i * NPerRepeat;
+                            const index_t acc_idx_n0 = acc_idx_m0 + n0_i * n2 * n4;
+                            static_for<0, n2, 1>{}([&](auto n2_i) {
+                                const index_t nstartgroup =
+                                    nstartxdl + thread_n_cluster_id * n4 + n2_i * AccN3 * n4;
+                                const index_t acc_idx_n2 = acc_idx_n0 + n2_i * n4;
+                                static_for<0, n4, 1>{}([&](auto n4_i) {
+                                    const index_t n_global = nstartgroup + n4_i;
+                                    const auto acc_offset = Number<acc_idx_n2 + n4_i>{};
+                                    if(n_global > m_global)
+                                    {
+                                        acc_thread_buf(acc_offset) =
+                                            -ck::NumericLimits<float>::Infinity();
+                                    }
+                                    else
+                                    {
+                                        // Acc0 elementwise Op
+#if CK_WORKAROUND_SWDEV_XXXXXX_ATTN_KERNEL_CLANG_CANNOT_SCAVENGE_REGISTER
+                                        acc_element_op(acc_thread_buf(acc_offset),
+                                                       acc_thread_buf[acc_offset]);
+#else
                                         ElementOpPredicatedResetNaNToMinusInf<PadN>{}.Run(
-                                            acc_thread_buf(acc_offset), acc_element_op, acc_thread_buf[acc_offset]);
-#endif
-                                    }
-                                });
+                                            acc_thread_buf(acc_offset),
+                                            acc_element_op,
+                                            acc_thread_buf[acc_offset]);
+#endif
+                                    }
+                                });
                             });
                         });
                     });
                 }
                 block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
......
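The per-element branch above writes -infinity into every accumulator lane with n_global > m_global; the blockwise softmax later turns those lanes into exact zeros, because exp(-inf - max) == 0 under the usual max-subtraction trick. A small sketch of that interaction on one row, in plain C++:

```cpp
#include <cmath>
#include <cstdio>
#include <limits>

int main()
{
    // One row of masked gemm0 output: the last two entries are "future" positions.
    const float inf = std::numeric_limits<float>::infinity();
    float row[4] = {2.0f, 1.0f, -inf, -inf};

    // Numerically stable softmax: subtract the running max before exponentiating,
    // the same trick the blockwise softmax in this kernel uses.
    float m = -inf;
    for(float x : row) m = std::fmax(m, x);
    float sum = 0.f;
    for(float x : row) sum += std::exp(x - m); // exp(-inf - m) == 0
    for(float& x : row) x = std::exp(x - m) / sum;

    for(float x : row) std::printf("%f ", x); // masked lanes come out exactly 0
    std::printf("\n");
    return 0;
}
```

Note that a row consisting entirely of -inf would produce NaN here (inf - inf), which is presumably why the non-workaround path routes the element op through `ElementOpPredicatedResetNaNToMinusInf<PadN>` for padded N.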
@@ -25,18 +25,18 @@ using CPermuteNumDims_G_M_O =
 void add_device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
     std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<Row,
-        Col,
-        Row,
-        CPermuteNumDims_G_M_O,
-        F16,
-        F16,
-        F16,
-        F16,
-        PassThrough,
-        PassThrough,
-        Scale,
-        PassThrough,
-        PassThrough>>>& instances);
+                                                                    Col,
+                                                                    Row,
+                                                                    CPermuteNumDims_G_M_O,
+                                                                    F16,
+                                                                    F16,
+                                                                    F16,
+                                                                    F16,
+                                                                    PassThrough,
+                                                                    PassThrough,
+                                                                    Scale,
+                                                                    PassThrough,
+                                                                    PassThrough>>>& instances);
 template <typename ALayout,
           typename B0Layout,
@@ -48,32 +48,32 @@ template <typename ALayout,
           typename CDataType>
 struct DeviceOperationInstanceFactory<
     ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<ALayout,
-        B0Layout,
-        B1Layout,
-        CPermuteNumDims_G_M_Gemm1N,
-        ADataType,
-        B0DataType,
-        B1DataType,
-        CDataType,
-        PassThrough,
-        PassThrough,
-        Scale,
-        PassThrough,
-        PassThrough>>
+                                                                      B0Layout,
+                                                                      B1Layout,
+                                                                      CPermuteNumDims_G_M_Gemm1N,
+                                                                      ADataType,
+                                                                      B0DataType,
+                                                                      B1DataType,
+                                                                      CDataType,
+                                                                      PassThrough,
+                                                                      PassThrough,
+                                                                      Scale,
+                                                                      PassThrough,
+                                                                      PassThrough>>
 {
     using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute<ALayout,
-        B0Layout,
-        B1Layout,
-        CPermuteNumDims_G_M_Gemm1N,
-        ADataType,
-        B0DataType,
-        B1DataType,
-        CDataType,
-        PassThrough,
-        PassThrough,
-        Scale,
-        PassThrough,
-        PassThrough>;
+                                                         B0Layout,
+                                                         B1Layout,
+                                                         CPermuteNumDims_G_M_Gemm1N,
+                                                         ADataType,
+                                                         B0DataType,
+                                                         B1DataType,
+                                                         CDataType,
+                                                         PassThrough,
+                                                         PassThrough,
+                                                         Scale,
+                                                         PassThrough,
+                                                         PassThrough>;
     static auto GetInstances()
     {
@@ -83,7 +83,8 @@ struct DeviceOperationInstanceFactory<
                      is_same_v<B1DataType, half_t> && is_same_v<CDataType, half_t>)
         {
             if constexpr(is_same_v<ALayout, Row> && is_same_v<B0Layout, Col> &&
-                         is_same_v<B1Layout, Row> && is_same_v<CPermuteNumDims_G_M_Gemm1N, CPermuteNumDims_G_M_O>)
+                         is_same_v<B1Layout, Row> &&
+                         is_same_v<CPermuteNumDims_G_M_Gemm1N, CPermuteNumDims_G_M_O>)
             {
                 add_device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
                     op_ptrs);
......
@@ -27,7 +27,7 @@ using CPermuteNumDims_G_M_O =
     S<2, 1, 1>; // "using CLayout = Row" has been replaced by CPermuteNumDims_G_M_O
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
-using Scale = ck::tensor_operation::element_wise::Scale;
+using Scale       = ck::tensor_operation::element_wise::Scale;
 static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
 static constexpr auto GemmPadded =
@@ -61,18 +61,18 @@ using device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f1
 void add_device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
     std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<Row,
-        Col,
-        Row,
-        CPermuteNumDims_G_M_O,
-        F16,
-        F16,
-        F16,
-        F16,
-        PassThrough,
-        PassThrough,
-        Scale,
-        PassThrough,
-        PassThrough>>>& instances)
+                                                                    Col,
+                                                                    Row,
+                                                                    CPermuteNumDims_G_M_O,
+                                                                    F16,
+                                                                    F16,
+                                                                    F16,
+                                                                    F16,
+                                                                    PassThrough,
+                                                                    PassThrough,
+                                                                    Scale,
+                                                                    PassThrough,
+                                                                    PassThrough>>>& instances)
 {
     add_device_operation_instances(
         instances,
......
@@ -31,22 +31,22 @@ template <typename ADataType,
           typename B1Layout,
           typename CPermuteNumDims_G_M_O>
 bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verification,
-    int init_method,
-    bool do_log,
-    bool time_kernel,
-    int M,
-    int N,
-    int K,
-    int O,
-    int G0,
-    int G1,
-    int StrideA = -1,
-    int StrideB0 = -1,
-    int StrideB1 = -1,
-    int BatchStrideA = -1,
-    int BatchStrideB0 = -1,
-    int BatchStrideB1 = -1,
-    float alpha = 1.f)
+                                                                  int init_method,
+                                                                  bool do_log,
+                                                                  bool time_kernel,
+                                                                  int M,
+                                                                  int N,
+                                                                  int K,
+                                                                  int O,
+                                                                  int G0,
+                                                                  int G1,
+                                                                  int StrideA = -1,
+                                                                  int StrideB0 = -1,
+                                                                  int StrideB1 = -1,
+                                                                  int BatchStrideA = -1,
+                                                                  int BatchStrideB0 = -1,
+                                                                  int BatchStrideB1 = -1,
+                                                                  float alpha = 1.f)
 {
@@ -196,19 +196,20 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
     auto b1_element_op = B1ElementOp{};
     auto c_element_op = CElementOp{};
-    using DeviceOp = tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<ALayout,
-        B0Layout,
-        B1Layout,
-        CPermuteNumDims_G_M_O,
-        ADataType,
-        B0DataType,
-        B1DataType,
-        CDataType,
-        AElementOp,
-        B0ElementOp,
-        Acc0ElementOp,
-        B1ElementOp,
-        CElementOp>;
+    using DeviceOp =
+        tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<ALayout,
+                                                                      B0Layout,
+                                                                      B1Layout,
+                                                                      CPermuteNumDims_G_M_O,
+                                                                      ADataType,
+                                                                      B0DataType,
+                                                                      B1DataType,
+                                                                      CDataType,
+                                                                      AElementOp,
+                                                                      B0ElementOp,
+                                                                      Acc0ElementOp,
+                                                                      B1ElementOp,
+                                                                      CElementOp>;
     // get device op instances
     const auto op_ptrs = tensor_operation::device::instance::DeviceOperationInstanceFactory<
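Context for the factory plumbing above: `DeviceOperationInstanceFactory::GetInstances()` returns a vector of type-erased op instances, and the profiler tries each one against the problem shape. A toy model of that pattern; every name below is an illustrative stand-in, not the CK API:

```cpp
#include <cstdio>
#include <memory>
#include <vector>

// Toy type-erased op instance: a real library instance would also carry
// argument construction, a launcher, and a support predicate for the shape.
struct OpInstance
{
    virtual ~OpInstance() = default;
    virtual bool IsSupported(int M, int N) const = 0;
    virtual const char* Name() const = 0;
};

struct TileOp : OpInstance
{
    int mpb, npb; // per-block tile sizes this "instance" was compiled for
    TileOp(int m, int n) : mpb(m), npb(n) {}
    bool IsSupported(int M, int N) const override { return M % mpb == 0 && N % npb == 0; }
    const char* Name() const override { return "TileOp"; }
};

// Analogue of GetInstances(): the factory enumerates all compiled variants.
std::vector<std::unique_ptr<OpInstance>> GetInstances()
{
    std::vector<std::unique_ptr<OpInstance>> v;
    v.emplace_back(std::make_unique<TileOp>(128, 128));
    v.emplace_back(std::make_unique<TileOp>(64, 128));
    return v;
}

int main()
{
    for(const auto& op : GetInstances())
        if(op->IsSupported(256, 128))
            std::printf("would profile %s\n", op->Name());
    return 0;
}
```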
@@ -226,8 +227,9 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
     ref_gemm0_invoker.Run(ref_gemm0_argument);
     // mask out upper triangle
-    acc0_g_m_n.ForEach([&](auto& self, auto idx) {
-        if (idx[1] < idx[2]) self(idx) = -ck::NumericLimits<float>::Infinity();
+    acc0_g_m_n.ForEach([&](auto& self, auto idx) {
+        if(idx[1] < idx[2])
+            self(idx) = -ck::NumericLimits<float>::Infinity();
     });
     auto ref_softmax = ReferenceSoftmaxInstance{};
@@ -319,8 +321,8 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
     {
         c_gs_ms_os_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());
-        pass = pass &
-               ck::utils::check_err(c_gs_ms_os_device_result.mData, c_gs_ms_os_host_result.mData);
+        pass = pass & ck::utils::check_err(c_gs_ms_os_device_result.mData,
+                                           c_gs_ms_os_host_result.mData);
         if(do_log)
         {
@@ -333,8 +335,9 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
             LogRangeAsType<float>(
                 std::cout << "c_gs_ms_os_host_result : ", c_gs_ms_os_host_result.mData, ",")
                 << std::endl;
-            LogRangeAsType<float>(
-                std::cout << "c_gs_ms_os_device_result : ", c_gs_ms_os_device_result.mData, ",")
+            LogRangeAsType<float>(std::cout << "c_gs_ms_os_device_result : ",
+                                  c_gs_ms_os_device_result.mData,
+                                  ",")
                 << std::endl;
         }
     }
......
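One detail worth keeping in mind when reading the verification hunk above: `pass = pass & check_err(...)` uses bitwise `&` rather than logical `&&`, so every comparison is evaluated (and logged) even after an earlier one has already failed. A minimal sketch with a toy `check_err` stand-in:

```cpp
#include <cstdio>
#include <vector>

// Toy stand-in for ck::utils::check_err: reports the first mismatch and
// returns false, true otherwise.
bool check_err(const std::vector<float>& out, const std::vector<float>& ref)
{
    for(std::size_t i = 0; i < out.size(); ++i)
        if(out[i] != ref[i])
        {
            std::printf("mismatch at %zu: %f vs %f\n", i, out[i], ref[i]);
            return false;
        }
    return true;
}

int main()
{
    bool pass = true;
    pass = pass & check_err({1, 2}, {1, 3}); // fails, pass becomes false
    pass = pass & check_err({5, 6}, {5, 6}); // still evaluated: '&' does not short-circuit
    std::printf("pass = %d\n", pass);
    return 0;
}
```

With `&&` the second comparison would be skipped once `pass` turned false, hiding later mismatches from the log.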
@@ -5,7 +5,8 @@
 #include "test_batched_gemm_masking_scale_softmax_gemm_permute_util.hpp"
 template <typename Tuple>
-class TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16 : public TestBatchedGemmMaskingScaleSoftmaxGemmPermute<Tuple>
+class TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16
+    : public TestBatchedGemmMaskingScaleSoftmaxGemmPermute<Tuple>
 {
 };
@@ -158,7 +159,7 @@ TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, AdhocTest)
         {49, 49, 64, 64, 4, 6},
         {64, 49, 64, 64, 4, 6},
         {1020, 1020, 64, 128, 4, 6},
-        {576, 576, 64, 64, 4,6},
+        {576, 576, 64, 64, 4, 6},
     };
     this->bench_ = true;
     this->Run();
......
@@ -42,15 +42,15 @@ struct TestBatchedGemmMaskingScaleSoftmaxGemmPermute : public ::testing::Test
     void RunSingle(int M, int N, int K, int O, int G0, int G1)
     {
-        bool pass = ck::profiler::profile_batched_gemm_masking_scale_softmax_gemm_permute_impl<ADataType,
-            B0DataType,
-            B1DataType,
-            CDataType,
-            ALayout,
-            B0Layout,
-            B1Layout,
-            CPermuteNumDims_G_M_O>(
-            verify_, 1, false, bench_, M, N, K, O, G0, G1);
+        bool pass = ck::profiler::profile_batched_gemm_masking_scale_softmax_gemm_permute_impl<
+            ADataType,
+            B0DataType,
+            B1DataType,
+            CDataType,
+            ALayout,
+            B0Layout,
+            B1Layout,
+            CPermuteNumDims_G_M_O>(verify_, 1, false, bench_, M, N, K, O, G0, G1);
         EXPECT_TRUE(pass);
     }
@@ -59,12 +59,12 @@ struct TestBatchedGemmMaskingScaleSoftmaxGemmPermute : public ::testing::Test
     {
         for(auto lengths : this->lengths_)
         {
-            int M = lengths[0];
-            int N = lengths[1];
-            int K = lengths[2];
-            int O = lengths[3];
-            int G0 = lengths[4];
-            int G1 = lengths[5];
+            int M  = lengths[0];
+            int N  = lengths[1];
+            int K  = lengths[2];
+            int O  = lengths[3];
+            int G0 = lengths[4];
+            int G1 = lengths[5];
             this->RunSingle(M, N, K, O, G0, G1);
         }
@@ -75,7 +75,7 @@ template <GemmSpecialization GemmSpec>
 struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
 {
     using PassThrough = ck::tensor_operation::element_wise::PassThrough;
-    using Scale = ck::tensor_operation::element_wise::Scale;
+    using Scale       = ck::tensor_operation::element_wise::Scale;
     using ALayout = Row;
     using B0Layout = Col;
@@ -174,8 +174,8 @@ struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
             K,
             O,
             0,               // BatchCount
-            {0, 0, M, O}, // gs ms ns lengths
-            {0, O, 0, 1}, // gs ms ns strides
+            {0, 0, M, O},    // gs ms ns lengths
+            {0, O, 0, 1},    // gs ms ns strides
             0,               // StrideA
             0,               // StrideB0
             0,               // StrideB1
@@ -184,7 +184,7 @@ struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
             0,               // BatchStrideB1
             PassThrough{},   // a_element_op
             PassThrough{},   // b0_element_op
-            Scale{1.f}, // acc0_element_op
+            Scale{1.f},      // acc0_element_op
             PassThrough{},   // b1_element_op
             PassThrough{});  // c_element_op
......