Commit f9b8a5d0 authored by Jing Zhang

added bf16 atomic_add

parent b0f295cb
@@ -19,7 +19,7 @@ using AElementOp = PassThrough;
 using BElementOp = PassThrough;
 using CElementOp = PassThrough;
-static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNPadding;
 // clang-format off
 using DeviceGemmV2Instance =
...
@@ -272,7 +272,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
     if(config.time_kernel)
     {
-        ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
+        ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 20, 50, true, 50});
         std::size_t flop = 2_uz * M * N * K;
         std::size_t num_btype =
...
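A hedged reading of the longer StreamConfig initializer above, written out field by field. The comments are assumptions inferred from this diff's rotating-memory changes and the 20/50 warm-up/repeat pattern, not a verified description of ck::StreamConfig; the snippet would slot into run_gemm in the example file above.

// Assumed field order; only the stream and time_kernel fields come from the old call.
StreamConfig stream_cfg{
    nullptr,            // HIP stream (default stream)
    config.time_kernel, // time the kernel and report the average
    0,                  // assumed: log level
    20,                 // assumed: warm-up (cold) iterations
    50,                 // assumed: timed repetitions
    true,               // assumed: flush/rotate buffers between timed runs
    50};                // assumed: number of rotating argument copies
ave_time = invoker.Run(argument, stream_cfg);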
@@ -168,7 +168,6 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
                 // rotating mem
                 rotating_mem.Next();
                 // clear c mem
-                if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
                 {
                     if(arg_.KBatch > 1)
                         hipGetErrorString(
@@ -190,7 +189,6 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
             }
             else
             {
-                if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
                 {
                     if(arg.KBatch > 1)
                         hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
@@ -215,7 +213,6 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
         {
             if(arg.KBatch > 1)
             {
-                if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
                 {
                     const auto kernel =
                         kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
@@ -240,7 +237,6 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
         {
             if(arg.KBatch > 1)
             {
-                if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
                 {
                     if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
                     {
@@ -473,7 +469,6 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
         {
             if(arg.KBatch > 1)
             {
-                if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
                 {
                     if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
                     {
@@ -525,7 +520,6 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
         {
             if(arg.KBatch > 1)
             {
-                if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
                 {
                     if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
                     {
@@ -582,7 +576,6 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
             if(arg.KBatch > 1)
             {
-                if constexpr(!is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
                 {
                     const auto kernel =
                         kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
...
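These hunks all remove the same bhalf_t exclusion: once bf16 has an atomic-add path, the split-K (KBatch > 1) launch can treat a bf16 C buffer like fp16/fp32, zero-filling it up front and letting every K-slice accumulate into it. A minimal sketch of that host-side pattern follows; launch_split_k_gemm is a hypothetical placeholder standing in for the guarded kernel_gemm_xdl_cshuffle_v3 launch, not a CK API.

#include <hip/hip_runtime.h>
#include <cstddef>

// Hypothetical stand-in for the per-K-slice GEMM kernel launch; each slice
// atomically adds its partial C tile into c_dev.
static void launch_split_k_gemm(void* /*c_dev*/, int /*k_batch*/, hipStream_t /*stream*/) {}

// Because the K-slices accumulate with atomic adds, C must start from zero,
// so it is cleared on the same stream before the launch.
static void run_split_k(void* c_dev, std::size_t c_bytes, int k_batch, hipStream_t stream)
{
    if(k_batch > 1)
    {
        (void)hipMemsetAsync(c_dev, 0, c_bytes, stream);
    }
    launch_split_k_gemm(c_dev, k_batch, stream);
}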
@@ -1105,7 +1105,9 @@ struct GridwiseGemm_xdl_cshuffle_v3
         }
         if constexpr(!(is_same<remove_cvref_t<CDataType>, half_t>::value ||
-                       is_same<remove_cvref_t<CDataType>, float>::value))
+                       is_same<remove_cvref_t<CDataType>, float>::value ||
+                       is_same<remove_cvref_t<CDataType>, bhalf_t>::value ||
+                       is_same<remove_cvref_t<CDataType>, int32_t>::value))
         {
             if(!karg.IsReduceAdd())
             {
...
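This widens the set of C data types that the KBatch > 1 argument check accepts without the two-stage reduce, since bf16 and int32 partials can now be combined with atomic adds. A standalone sketch of the same if constexpr guard shape, using stand-in types rather than CK's definitions:

#include <cstdint>
#include <type_traits>

// Stand-in 16-bit storage types for this sketch (not CK's half_t/bhalf_t).
struct half_t  { std::uint16_t bits; };
struct bhalf_t { std::uint16_t bits; };

// C data types whose split-K partial results can be combined with atomic adds.
template <typename C>
inline constexpr bool supports_atomic_add_v =
    std::is_same_v<C, float> || std::is_same_v<C, half_t> ||
    std::is_same_v<C, bhalf_t> || std::is_same_v<C, std::int32_t>;

// Same shape as the guard above: without an atomic-add path, KBatch > 1 is
// only valid when the partial results are reduced in a separate pass.
template <typename CDataType>
bool is_split_k_arg_valid(int k_batch, bool reduce_add)
{
    if constexpr(!supports_atomic_add_v<CDataType>)
    {
        if(k_batch > 1 && !reduce_add)
            return false;
    }
    return true;
}

static_assert(supports_atomic_add_v<bhalf_t>, "bf16 now passes the check");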
@@ -568,32 +568,19 @@ __device__ void amd_global_atomic_add_impl(const typename vector_type<T, N>::typ
 {
     if constexpr(is_same<T, half_t>::value)
     {
-#if 0
-        if constexpr(N == 2)
-        {
-            __builtin_amdgcn_global_atomic_fadd_v2f16(addr, src_thread_data);
-        }
-        else if constexpr(N == 4)
-        {
-            vector_type<half_t, 4> tmp{src_thread_data};
-            static_for<0, 2, 1>{}([&](auto i) {
-                __builtin_amdgcn_global_atomic_fadd_v2f16(addr + i, tmp.AsType<half2_t>()[i]);
-            });
-        }
-        else if constexpr(N == 8)
-        {
-            vector_type<half_t, 8> tmp{src_thread_data};
-            static_for<0, 4, 1>{}([&](auto i) {
-                __builtin_amdgcn_global_atomic_fadd_v2f16(addr + i, tmp.AsType<half2_t>()[i]);
-            });
-        }
-#else
         static_assert(N % 2 == 0, "");
         vector_type<half_t, N> tmp{src_thread_data};
         static_for<0, N / 2, 1>{}([&](auto i) {
             __builtin_amdgcn_global_atomic_fadd_v2f16(bit_cast<half2_t*>(addr) + i, tmp.template AsType<half2_t>()[i]);
         });
-#endif
     }
+    else if constexpr(is_same<T, bhalf_t>::value)
+    {
+        static_assert(N % 2 == 0, "");
+        vector_type<bhalf_t, N> tmp{src_thread_data};
+        static_for<0, N / 2, 1>{}([&](auto i) {
+            __builtin_amdgcn_global_atomic_fadd_v2bf16(bit_cast<bhalf2_t*>(addr) + i, tmp.template AsType<bhalf2_t>()[i]);
+        });
+    }
 }
...
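The new branch mirrors the fp16 path: bf16 atomics are only issued on packed pairs, which is why N must stay even and the data is reinterpreted as bhalf2_t. Below is a rough standalone sketch of driving the same builtin from a plain HIP kernel; the vector typedef and kernel name are this sketch's own stand-ins, not CK code, and it assumes a target that actually exposes packed bf16 global atomics (e.g. built with --offload-arch for an MI200/MI300-class GPU).

#include <hip/hip_runtime.h>

// Local stand-in for CK's bhalf2_t: two bf16 values packed into 32 bits.
typedef unsigned short bf16x2_t __attribute__((ext_vector_type(2)));

// Every thread adds its bf16 pair into the single output slot; the hardware
// performs the add atomically on the packed 32-bit pair.
__global__ void atomic_accumulate_bf16(bf16x2_t* out, const bf16x2_t* in, int n)
{
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i < n)
    {
        __builtin_amdgcn_global_atomic_fadd_v2bf16(out, in[i]);
    }
}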