Commit c70aacd3 authored by Jing Zhang's avatar Jing Zhang
Browse files

format

parent f9b8a5d0
...@@ -272,7 +272,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) ...@@ -272,7 +272,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
if(config.time_kernel) if(config.time_kernel)
{ {
ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 20, 50, true, 50}); ave_time =
invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 20, 50, true, 50});
std::size_t flop = 2_uz * M * N * K; std::size_t flop = 2_uz * M * N * K;
std::size_t num_btype = std::size_t num_btype =
......
...@@ -168,14 +168,11 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout, ...@@ -168,14 +168,11 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
// rotating mem // rotating mem
rotating_mem.Next(); rotating_mem.Next();
// clear c mem // clear c mem
{ if(arg_.KBatch > 1)
if(arg_.KBatch > 1) hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
hipGetErrorString( 0,
hipMemsetAsync(arg_.p_c_grid, arg_.M * arg_.N * sizeof(CDataType),
0, stream_config.stream_id_));
arg_.M * arg_.N * sizeof(CDataType),
stream_config.stream_id_));
}
}; };
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>( ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
...@@ -189,13 +186,11 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout, ...@@ -189,13 +186,11 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
} }
else else
{ {
{ if(arg.KBatch > 1)
if(arg.KBatch > 1) hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
hipGetErrorString(hipMemsetAsync(arg.p_c_grid, 0,
0, arg.M * arg.N * sizeof(CDataType),
arg.M * arg.N * sizeof(CDataType), stream_config.stream_id_));
stream_config.stream_id_));
}
ave_time = launch_and_time_kernel( ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
...@@ -213,14 +208,12 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout, ...@@ -213,14 +208,12 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
{ {
if(arg.KBatch > 1) if(arg.KBatch > 1)
{ {
{ const auto kernel =
const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, true,
true, InMemoryDataOperationEnum::AtomicAdd,
InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy>;
minimum_occupancy>; Run(kernel);
Run(kernel);
}
} }
else else
{ {
...@@ -237,117 +230,113 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout, ...@@ -237,117 +230,113 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
{ {
if(arg.KBatch > 1) if(arg.KBatch > 1)
{ {
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
{ {
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::One>;
Run(kernel);
}
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Full)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Full>;
Run(kernel);
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
{ {
const auto kernel = kernel_gemm_xdl_cshuffle_v3< const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm, GridwiseGemm,
true, true,
InMemoryDataOperationEnum::AtomicAdd, InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy, minimum_occupancy,
TailNumber::One>; TailNumber::Two>;
Run(kernel); Run(kernel);
} }
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == }
TailNumber::Full)
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Three)
{ {
const auto kernel = kernel_gemm_xdl_cshuffle_v3< const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm, GridwiseGemm,
true, true,
InMemoryDataOperationEnum::AtomicAdd, InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy, minimum_occupancy,
TailNumber::Full>; TailNumber::Three>;
Run(kernel); Run(kernel);
} }
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
{ {
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Two) TailNumber::Four)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Two>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Three)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Three>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
{ {
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == const auto kernel = kernel_gemm_xdl_cshuffle_v3<
TailNumber::Four) GridwiseGemm,
{ true,
const auto kernel = kernel_gemm_xdl_cshuffle_v3< InMemoryDataOperationEnum::AtomicAdd,
GridwiseGemm, minimum_occupancy,
true, TailNumber::Four>;
InMemoryDataOperationEnum::AtomicAdd, Run(kernel);
minimum_occupancy,
TailNumber::Four>;
Run(kernel);
}
} }
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Five)
{ {
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == const auto kernel = kernel_gemm_xdl_cshuffle_v3<
TailNumber::Five) GridwiseGemm,
{ true,
const auto kernel = kernel_gemm_xdl_cshuffle_v3< InMemoryDataOperationEnum::AtomicAdd,
GridwiseGemm, minimum_occupancy,
true, TailNumber::Five>;
InMemoryDataOperationEnum::AtomicAdd, Run(kernel);
minimum_occupancy,
TailNumber::Five>;
Run(kernel);
}
} }
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
{ {
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == const auto kernel = kernel_gemm_xdl_cshuffle_v3<
TailNumber::Six) GridwiseGemm,
{ true,
const auto kernel = kernel_gemm_xdl_cshuffle_v3< InMemoryDataOperationEnum::AtomicAdd,
GridwiseGemm, minimum_occupancy,
true, TailNumber::Six>;
InMemoryDataOperationEnum::AtomicAdd, Run(kernel);
minimum_occupancy,
TailNumber::Six>;
Run(kernel);
}
} }
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Seven)
{ {
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == const auto kernel = kernel_gemm_xdl_cshuffle_v3<
TailNumber::Seven) GridwiseGemm,
{ true,
const auto kernel = kernel_gemm_xdl_cshuffle_v3< InMemoryDataOperationEnum::AtomicAdd,
GridwiseGemm, minimum_occupancy,
true, TailNumber::Seven>;
InMemoryDataOperationEnum::AtomicAdd, Run(kernel);
minimum_occupancy,
TailNumber::Seven>;
Run(kernel);
}
} }
} }
} }
...@@ -469,27 +458,25 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout, ...@@ -469,27 +458,25 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
{ {
if(arg.KBatch > 1) if(arg.KBatch > 1)
{ {
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{ {
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
{ GridwiseGemm,
const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds< true,
GridwiseGemm, InMemoryDataOperationEnum::AtomicAdd,
true, minimum_occupancy,
InMemoryDataOperationEnum::AtomicAdd, TailNumber::Odd>;
minimum_occupancy, Run(kernel);
TailNumber::Odd>; }
Run(kernel); else
} {
else const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
{ GridwiseGemm,
const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds< true,
GridwiseGemm, InMemoryDataOperationEnum::AtomicAdd,
true, minimum_occupancy,
InMemoryDataOperationEnum::AtomicAdd, TailNumber::Even>;
minimum_occupancy, Run(kernel);
TailNumber::Even>;
Run(kernel);
}
} }
} }
else else
...@@ -520,27 +507,25 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout, ...@@ -520,27 +507,25 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
{ {
if(arg.KBatch > 1) if(arg.KBatch > 1)
{ {
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{ {
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) const auto kernel =
{ kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
const auto kernel = kernel_gemm_xdl_cshuffle_v3< true,
GridwiseGemm, InMemoryDataOperationEnum::AtomicAdd,
true, minimum_occupancy,
InMemoryDataOperationEnum::AtomicAdd, TailNumber::Odd>;
minimum_occupancy, Run(kernel);
TailNumber::Odd>; }
Run(kernel); else
} {
else const auto kernel =
{ kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
const auto kernel = kernel_gemm_xdl_cshuffle_v3< true,
GridwiseGemm, InMemoryDataOperationEnum::AtomicAdd,
true, minimum_occupancy,
InMemoryDataOperationEnum::AtomicAdd, TailNumber::Even>;
minimum_occupancy, Run(kernel);
TailNumber::Even>;
Run(kernel);
}
} }
} }
else else
...@@ -576,14 +561,12 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout, ...@@ -576,14 +561,12 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
if(arg.KBatch > 1) if(arg.KBatch > 1)
{ {
{ const auto kernel =
const auto kernel = kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, false,
false, InMemoryDataOperationEnum::AtomicAdd,
InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy>;
minimum_occupancy>; Run(kernel);
Run(kernel);
}
} }
else else
{ {
......
...@@ -29,7 +29,7 @@ template <typename GridwiseGemm, ...@@ -29,7 +29,7 @@ template <typename GridwiseGemm,
TailNumber TailNum = TailNumber::Full> TailNumber TailNum = TailNumber::Full>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif #endif
// __attribute__((amdgpu_waves_per_eu(1, 1))) // __attribute__((amdgpu_waves_per_eu(1, 1)))
kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
...@@ -57,7 +57,7 @@ template <typename GridwiseGemm, ...@@ -57,7 +57,7 @@ template <typename GridwiseGemm,
TailNumber TailNum = TailNumber::Full> TailNumber TailNum = TailNumber::Full>
__global__ void __global__ void
#if CK_USE_LAUNCH_BOUNDS #if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif #endif
// __attribute__((amdgpu_waves_per_eu(1, 1))) // __attribute__((amdgpu_waves_per_eu(1, 1)))
kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg) kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
...@@ -485,20 +485,11 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -485,20 +485,11 @@ struct GridwiseGemm_xdl_cshuffle_v3
__host__ void Print() const __host__ void Print() const
{ {
std::cout << "problem {" std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
<< "M:" << M << ", " << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
<< "N:" << N << ", " << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", "
<< "K:" << K << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0
<< "SA:" << StrideA << ", " << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", "
<< "SB:" << StrideB << ", "
<< "SC:" << StrideC << ", "
<< "MP:" << MPadded << ", "
<< "NP:" << NPadded << ", "
<< "KRead:" << KRead << ", "
<< "KP:" << KPadded << ", "
<< "AK0:" << AK0 << ", "
<< "BK0:" << BK0 << ", "
<< "MBlock: " << MBlock << ", "
<< "NBlock: " << NBlock << "}" << std::endl; << "NBlock: " << NBlock << "}" << std::endl;
} }
......
...@@ -571,7 +571,8 @@ __device__ void amd_global_atomic_add_impl(const typename vector_type<T, N>::typ ...@@ -571,7 +571,8 @@ __device__ void amd_global_atomic_add_impl(const typename vector_type<T, N>::typ
static_assert(N % 2 == 0, ""); static_assert(N % 2 == 0, "");
vector_type<half_t, N> tmp{src_thread_data}; vector_type<half_t, N> tmp{src_thread_data};
static_for<0, N / 2, 1>{}([&](auto i) { static_for<0, N / 2, 1>{}([&](auto i) {
__builtin_amdgcn_global_atomic_fadd_v2f16(bit_cast<half2_t*>(addr) + i, tmp.template AsType<half2_t>()[i]); __builtin_amdgcn_global_atomic_fadd_v2f16(bit_cast<half2_t*>(addr) + i,
tmp.template AsType<half2_t>()[i]);
}); });
} }
else if constexpr(is_same<T, bhalf_t>::value) else if constexpr(is_same<T, bhalf_t>::value)
...@@ -579,7 +580,8 @@ __device__ void amd_global_atomic_add_impl(const typename vector_type<T, N>::typ ...@@ -579,7 +580,8 @@ __device__ void amd_global_atomic_add_impl(const typename vector_type<T, N>::typ
static_assert(N % 2 == 0, ""); static_assert(N % 2 == 0, "");
vector_type<bhalf_t, N> tmp{src_thread_data}; vector_type<bhalf_t, N> tmp{src_thread_data};
static_for<0, N / 2, 1>{}([&](auto i) { static_for<0, N / 2, 1>{}([&](auto i) {
__builtin_amdgcn_global_atomic_fadd_v2bf16(bit_cast<bhalf2_t*>(addr) + i, tmp.template AsType<bhalf2_t>()[i]); __builtin_amdgcn_global_atomic_fadd_v2bf16(bit_cast<bhalf2_t*>(addr) + i,
tmp.template AsType<bhalf2_t>()[i]);
}); });
} }
} }
...@@ -939,9 +941,10 @@ amd_buffer_atomic_add(const typename vector_type_maker<T, N>::type::type src_thr ...@@ -939,9 +941,10 @@ amd_buffer_atomic_add(const typename vector_type_maker<T, N>::type::type src_thr
{ {
ignore = dst_wave_buffer_resource; ignore = dst_wave_buffer_resource;
ignore = dst_thread_addr_offset; ignore = dst_thread_addr_offset;
//amd_buffer_atomic_add_impl<scalar_t, vector_size>( // amd_buffer_atomic_add_impl<scalar_t, vector_size>(
//src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); // src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
amd_global_atomic_add_impl<scalar_t, vector_size>(src_thread_data, p_dst_wave + dst_thread_element_offset); amd_global_atomic_add_impl<scalar_t, vector_size>(src_thread_data,
p_dst_wave + dst_thread_element_offset);
} }
#endif #endif
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment