Commit c70aacd3 authored by Jing Zhang

format

parent f9b8a5d0
@@ -272,7 +272,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
     if(config.time_kernel)
     {
-        ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 20, 50, true, 50});
+        ave_time =
+            invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 20, 50, true, 50});
         std::size_t flop = 2_uz * M * N * K;
         std::size_t num_btype =
......
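Note: the timed path above reports ave_time from invoker.Run and counts 2 * M * N * K floating-point operations for the GEMM. A minimal sketch of how such a measurement is typically converted to throughput, assuming ave_time is in milliseconds; the helper name tflops_from_ms is illustrative, not part of this commit:

#include <cstddef>

// Hypothetical helper: convert an averaged GEMM kernel time (milliseconds)
// into TFLOPS, using flop = 2 * M * N * K as in the example above.
inline double tflops_from_ms(std::size_t m, std::size_t n, std::size_t k, double ave_time_ms)
{
    const double flop = 2.0 * m * n * k; // one multiply and one add per MAC
    return flop / 1.0e9 / ave_time_ms;   // (flop / 1e12) / (ave_time_ms * 1e-3 s)
}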
@@ -168,14 +168,11 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
                 // rotating mem
                 rotating_mem.Next();
                 // clear c mem
-                {
                 if(arg_.KBatch > 1)
-                    hipGetErrorString(
-                        hipMemsetAsync(arg_.p_c_grid,
+                    hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
                                        0,
                                        arg_.M * arg_.N * sizeof(CDataType),
                                        stream_config.stream_id_));
-                }
             };

             ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
@@ -188,14 +185,12 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
                     arg_);
             }
             else
-            {
             {
                 if(arg.KBatch > 1)
                     hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
                                                      0,
                                                      arg.M * arg.N * sizeof(CDataType),
                                                      stream_config.stream_id_));
-            }

                 ave_time = launch_and_time_kernel(
                     stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
@@ -212,7 +207,6 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
                              BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
                 {
                     if(arg.KBatch > 1)
-                    {
                     {
                         const auto kernel =
                             kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
@@ -221,7 +215,6 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
                                                         minimum_occupancy>;
                         Run(kernel);
                     }
-                    }
                     else
                     {
                         const auto kernel =
@@ -236,12 +229,11 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
                 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
                 {
                     if(arg.KBatch > 1)
-                    {
                     {
                         if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
                         {
-                            const auto kernel = kernel_gemm_xdl_cshuffle_v3<
-                                GridwiseGemm,
+                            const auto kernel =
+                                kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                             true,
                                                             InMemoryDataOperationEnum::AtomicAdd,
                                                             minimum_occupancy,
@@ -251,8 +243,8 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
                         else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
                                 TailNumber::Full)
                         {
-                            const auto kernel = kernel_gemm_xdl_cshuffle_v3<
-                                GridwiseGemm,
+                            const auto kernel =
+                                kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                             true,
                                                             InMemoryDataOperationEnum::AtomicAdd,
                                                             minimum_occupancy,
@@ -262,8 +254,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
                         if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
                         {
-                            if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
-                               TailNumber::Two)
+                            if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
                             {
                                 const auto kernel = kernel_gemm_xdl_cshuffle_v3<
                                     GridwiseGemm,
@@ -322,8 +313,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
                         if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
                         {
-                            if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
-                               TailNumber::Six)
+                            if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
                             {
                                 const auto kernel = kernel_gemm_xdl_cshuffle_v3<
                                     GridwiseGemm,
@@ -350,7 +340,6 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
                             }
                         }
                     }
-                    }
                     else
                     {
                         if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
@@ -468,7 +457,6 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
                 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
                 {
                     if(arg.KBatch > 1)
-                    {
                     {
                         if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
                         {
@@ -491,7 +479,6 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
                             Run(kernel);
                         }
                     }
-                    }
                     else
                     {
                         if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
@@ -519,12 +506,11 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
                 else
                 {
                     if(arg.KBatch > 1)
-                    {
                     {
                         if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
                         {
-                            const auto kernel = kernel_gemm_xdl_cshuffle_v3<
-                                GridwiseGemm,
+                            const auto kernel =
+                                kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                             true,
                                                             InMemoryDataOperationEnum::AtomicAdd,
                                                             minimum_occupancy,
@@ -533,8 +519,8 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
                         }
                         else
                         {
-                            const auto kernel = kernel_gemm_xdl_cshuffle_v3<
-                                GridwiseGemm,
+                            const auto kernel =
+                                kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
                                                             true,
                                                             InMemoryDataOperationEnum::AtomicAdd,
                                                             minimum_occupancy,
@@ -542,7 +528,6 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
                             Run(kernel);
                         }
                     }
-                    }
                     else
                     {
                         if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
@@ -575,7 +560,6 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
                 {
                     if(arg.KBatch > 1)
-                    {
                     {
                         const auto kernel =
                             kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
@@ -584,7 +568,6 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
                                                         minimum_occupancy>;
                         Run(kernel);
                     }
-                    }
                     else
                     {
                         const auto kernel =
......
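Note: the hunks above drop redundant braces around the KBatch > 1 memset but keep its purpose: with split-K (KBatch > 1), each K-slice accumulates its partial product into C through the AtomicAdd kernel variants, so C must be zeroed before the launch. A minimal standalone sketch of that precondition, with hypothetical names (prepare_splitk_output, c_dev, kbatch) that are not CK API:

#include <hip/hip_runtime.h>
#include <cstddef>

// Clear the output buffer before a split-K launch that accumulates atomically.
inline void prepare_splitk_output(void* c_dev,
                                  std::size_t m,
                                  std::size_t n,
                                  std::size_t elem_bytes,
                                  int kbatch,
                                  hipStream_t stream)
{
    if(kbatch > 1)
    {
        // Same operation the device op issues above via hipGetErrorString(hipMemsetAsync(...)).
        (void)hipMemsetAsync(c_dev, 0, m * n * elem_bytes, stream);
    }
    // With kbatch == 1 the kernel writes C without accumulating, so no clear is issued.
}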
@@ -29,7 +29,7 @@ template <typename GridwiseGemm,
           TailNumber TailNum = TailNumber::Full>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
 #endif
     // __attribute__((amdgpu_waves_per_eu(1, 1)))
     kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
@@ -57,7 +57,7 @@ template <typename GridwiseGemm,
           TailNumber TailNum = TailNumber::Full>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
-__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
 #endif
     // __attribute__((amdgpu_waves_per_eu(1, 1)))
     kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
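Note: the two hunks above only touch formatting around the launch-bounds attribute, but the pattern itself is worth spelling out: the occupancy hint is compiled in only when CK_USE_LAUNCH_BOUNDS is defined, so the same kernel builds with or without it. A standalone sketch with illustrative macros (MY_*) standing in for CK_MAX_THREAD_PER_BLOCK and MinimumOccupancy; HIP interprets the second argument as a minimum waves-per-execution-unit hint, CUDA as a minimum blocks-per-SM hint:

#include <hip/hip_runtime.h>

#define MY_USE_LAUNCH_BOUNDS 1
#define MY_MAX_THREADS 256
#define MY_MIN_OCCUPANCY 2

__global__ void
#if MY_USE_LAUNCH_BOUNDS
    __launch_bounds__(MY_MAX_THREADS, MY_MIN_OCCUPANCY)
#endif
        scale_kernel(float* p, float s, int n)
{
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i < n)
        p[i] *= s; // trivial body; only the attribute placement matters here
}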
@@ -485,20 +485,11 @@ struct GridwiseGemm_xdl_cshuffle_v3
     __host__ void Print() const
     {
-        std::cout << "problem {"
-                  << "M:" << M << ", "
-                  << "N:" << N << ", "
-                  << "K:" << K << ", "
-                  << "SA:" << StrideA << ", "
-                  << "SB:" << StrideB << ", "
-                  << "SC:" << StrideC << ", "
-                  << "MP:" << MPadded << ", "
-                  << "NP:" << NPadded << ", "
-                  << "KRead:" << KRead << ", "
-                  << "KP:" << KPadded << ", "
-                  << "AK0:" << AK0 << ", "
-                  << "BK0:" << BK0 << ", "
-                  << "MBlock: " << MBlock << ", "
+        std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
+                  << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
+                  << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", "
+                  << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0
+                  << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", "
                   << "NBlock: " << NBlock << "}" << std::endl;
     }
......
@@ -571,7 +571,8 @@ __device__ void amd_global_atomic_add_impl(const typename vector_type<T, N>::typ
         static_assert(N % 2 == 0, "");
         vector_type<half_t, N> tmp{src_thread_data};
         static_for<0, N / 2, 1>{}([&](auto i) {
-            __builtin_amdgcn_global_atomic_fadd_v2f16(bit_cast<half2_t*>(addr) + i, tmp.template AsType<half2_t>()[i]);
+            __builtin_amdgcn_global_atomic_fadd_v2f16(bit_cast<half2_t*>(addr) + i,
+                                                      tmp.template AsType<half2_t>()[i]);
         });
     }
     else if constexpr(is_same<T, bhalf_t>::value)
@@ -579,7 +580,8 @@ __device__ void amd_global_atomic_add_impl(const typename vector_type<T, N>::typ
         static_assert(N % 2 == 0, "");
         vector_type<bhalf_t, N> tmp{src_thread_data};
         static_for<0, N / 2, 1>{}([&](auto i) {
-            __builtin_amdgcn_global_atomic_fadd_v2bf16(bit_cast<bhalf2_t*>(addr) + i, tmp.template AsType<bhalf2_t>()[i]);
+            __builtin_amdgcn_global_atomic_fadd_v2bf16(bit_cast<bhalf2_t*>(addr) + i,
+                                                       tmp.template AsType<bhalf2_t>()[i]);
         });
     }
 }
@@ -939,9 +941,10 @@ amd_buffer_atomic_add(const typename vector_type_maker<T, N>::type::type src_thr
     {
         ignore = dst_wave_buffer_resource;
         ignore = dst_thread_addr_offset;
-        //amd_buffer_atomic_add_impl<scalar_t, vector_size>(
-        //src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
-        amd_global_atomic_add_impl<scalar_t, vector_size>(src_thread_data, p_dst_wave + dst_thread_element_offset);
+        // amd_buffer_atomic_add_impl<scalar_t, vector_size>(
+        // src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
+        amd_global_atomic_add_impl<scalar_t, vector_size>(src_thread_data,
+                                                          p_dst_wave + dst_thread_element_offset);
     }
 #endif
 }
......
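Note: amd_global_atomic_add_impl above commits N half (or bhalf) values as N/2 packed 2-element atomic adds. A rough standalone equivalent of that pattern for fp16, assuming the target hardware and HIP version provide the __half2 atomicAdd overload; the helper name atomic_add_half_pairs is illustrative, and this sketch ignores the buffer-resource path that the code above replaces:

#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>

// Accumulate n half values (n assumed even) into dst using one packed atomic per pair.
__device__ inline void atomic_add_half_pairs(__half* dst, const __half* src, int n)
{
    for(int i = 0; i < n / 2; ++i)
    {
        const __half2 v = *reinterpret_cast<const __half2*>(src + 2 * i);
        atomicAdd(reinterpret_cast<__half2*>(dst + 2 * i), v); // packed fp16 add
    }
}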