Commit 36eb9b69 authored by illsilin
Browse files

fix clang format

parent b3bd7f68
......@@ -30,8 +30,13 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType,
typename ALayout, typename BLayout, typename CLayout>
template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout>
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::DeviceMem& b_k_n_dev_buf,
ck_tile::DeviceMem& c_m_n_dev_buf,
......@@ -57,9 +62,9 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
args.stride_B = stride_B;
args.stride_C = stride_C;
float ave_time = gemm_calc<ADataType, BDataType, AccDataType, CDataType,
ALayout, BLayout, CLayout>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
float ave_time =
gemm_calc<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_byte =
......@@ -69,14 +74,11 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K
<< " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C
<< " A_Layout =" << ALayout::name
<< " B_Layout =" << BLayout::name
<< " C_Layout =" << CLayout::name
<< " A Type = " << DataTypeTraits<ADataType>::name
<< " B Type = " << DataTypeTraits<BDataType>::name
<< " C Type = " << DataTypeTraits<CDataType>::name
<< " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< std::endl;
<< " A_Layout =" << ALayout::name << " B_Layout =" << BLayout::name
<< " C_Layout =" << CLayout::name << " A Type = " << DataTypeTraits<ADataType>::name
<< " B Type = " << DataTypeTraits<BDataType>::name
<< " C Type = " << DataTypeTraits<CDataType>::name << " : " << ave_time << " ms, "
<< tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
return ave_time;
}
......@@ -92,10 +94,10 @@ int run_gemm_example_with_layouts(int argc,
if(!result)
return -1;
using ADataType = typename GemmBasicTypeConfig<PrecType>::ADataType;
using BDataType = typename GemmBasicTypeConfig<PrecType>::BDataType;
using CDataType = typename GemmBasicTypeConfig<PrecType>::CDataType;
using AccDataType = typename GemmBasicTypeConfig<PrecType>::AccDataType;
using ADataType = typename GemmBasicTypeConfig<PrecType>::ADataType;
using BDataType = typename GemmBasicTypeConfig<PrecType>::BDataType;
using CDataType = typename GemmBasicTypeConfig<PrecType>::CDataType;
using AccDataType = typename GemmBasicTypeConfig<PrecType>::AccDataType;
ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n");
......@@ -133,19 +135,19 @@ int run_gemm_example_with_layouts(int argc,
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
invoke_gemm<ADataType, BDataType, AccDataType, CDataType,
ALayout, BLayout, CLayout>(a_m_k_dev_buf,
b_k_n_dev_buf,
c_m_n_dev_buf,
M,
N,
K,
stride_A,
stride_B,
stride_C,
kbatch,
n_warmup,
n_repeat);
invoke_gemm<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
a_m_k_dev_buf,
b_k_n_dev_buf,
c_m_n_dev_buf,
M,
N,
K,
stride_A,
stride_B,
stride_C,
kbatch,
n_warmup,
n_repeat);
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
bool pass = true;
......@@ -160,9 +162,9 @@ int run_gemm_example_with_layouts(int argc,
a_m_k, b_k_n, c_m_n_host_ref);
const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>
(K, kbatch, max_accumulated_value);
pass = ck_tile::check_err(c_m_n_dev_result,
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
K, kbatch, max_accumulated_value);
pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_host_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
......@@ -218,9 +220,9 @@ int run_gemm_example_with_layouts(int argc,
c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
const float max_accumulated_value =
*std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>
(K, kbatch, max_accumulated_value);
pass = ck_tile::check_err(c_m_n_dev_result,
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
K, kbatch, max_accumulated_value);
pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_gpu_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
......
......@@ -50,13 +50,14 @@ __device__ int32x4_t make_wave_buffer_resource_with_default_range(T* p_wave)
}
// Builds the __amdgpu_buffer_rsrc_t handle for a wave-level buffer by forwarding to
// __builtin_amdgcn_make_buffer_rsrc(p, stride, num, flags).
//
// p_wave             - base pointer of the buffer; cv-qualifiers are cast away before the
//                      pointer is handed to the builtin.
// element_space_size - number of T elements; multiplied by sizeof(T) to form `num`.
// Returns the value produced by the builtin, with stride fixed at 0 and flags taken from
// CK_BUFFER_RESOURCE_3RD_DWORD.
//
// NOTE(review): this listing appears to be the rendering of a formatting-only commit, so each
// reformatted statement shows up twice below (previous layout first, clang-formatted layout
// second). Only one copy of each exists in the real file.
template <typename T>
__device__ __amdgpu_buffer_rsrc_t make_wave_buffer_resource_new(T* p_wave, index_t element_space_size)
__device__ __amdgpu_buffer_rsrc_t make_wave_buffer_resource_new(T* p_wave,
index_t element_space_size)
{
// wavewise base address (64 bit)
// The builtin is given a pointer to the non-cv-qualified element type.
auto p = const_cast<remove_cv_t<T>*>(p_wave);
auto p = const_cast<remove_cv_t<T>*>(p_wave);
// Stride is always passed as 0 here.
int32_t stride = 0;
// Size of the addressable range in bytes (element count * sizeof(T)).
// NOTE(review): the product is stored in an int32_t; presumably buffers are expected to stay
// below 2 GiB -- confirm against callers.
int32_t num = element_space_size * sizeof(T);
auto flags = CK_BUFFER_RESOURCE_3RD_DWORD;
int32_t num = element_space_size * sizeof(T);
auto flags = CK_BUFFER_RESOURCE_3RD_DWORD;
return __builtin_amdgcn_make_buffer_rsrc(p, stride, num, flags);
}
......@@ -129,57 +130,57 @@ amd_buffer_load_impl_raw(__amdgpu_buffer_rsrc_t src_wave_buffer_resource,
if constexpr(N == 1)
{
return __builtin_amdgcn_raw_buffer_load_b8(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 2)
{
int16_t tmp = __builtin_amdgcn_raw_buffer_load_b16(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
return bit_cast<int8x2_t>(tmp);
}
else if constexpr(N == 4)
{
int32_t tmp = __builtin_amdgcn_raw_buffer_load_b32(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
return bit_cast<int8x4_t>(tmp);
}
else if constexpr(N == 8)
{
int32x2_t tmp = __builtin_amdgcn_raw_buffer_load_b64(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
return bit_cast<int8x8_t>(tmp);
}
else if constexpr(N == 16)
{
int32x4_t tmp = __builtin_amdgcn_raw_buffer_load_b128(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
return bit_cast<int8x16_t>(tmp);
}
else if constexpr(N == 32)
{
int32x4_t tmp0 = __builtin_amdgcn_raw_buffer_load_b128(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
int32x4_t tmp1 =
__builtin_amdgcn_raw_buffer_load_b128(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 4 * sizeof(int32_t),
static_cast<index_t>(coherence));
src_thread_addr_offset,
src_wave_addr_offset + 4 * sizeof(int32_t),
static_cast<index_t>(coherence));
vector_type<int32_t, 8> tmp;
tmp.AsType<int32x4_t>()(Number<0>{}) = tmp0;
......@@ -190,24 +191,24 @@ amd_buffer_load_impl_raw(__amdgpu_buffer_rsrc_t src_wave_buffer_resource,
else if constexpr(N == 64)
{
int32x4_t tmp0 = __builtin_amdgcn_raw_buffer_load_b128(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
src_thread_addr_offset,
src_wave_addr_offset,
static_cast<index_t>(coherence));
int32x4_t tmp1 =
__builtin_amdgcn_raw_buffer_load_b128(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 4 * sizeof(int32_t),
static_cast<index_t>(coherence));
src_thread_addr_offset,
src_wave_addr_offset + 4 * sizeof(int32_t),
static_cast<index_t>(coherence));
int32x4_t tmp2 =
__builtin_amdgcn_raw_buffer_load_b128(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 8 * sizeof(int32_t),
static_cast<index_t>(coherence));
src_thread_addr_offset,
src_wave_addr_offset + 8 * sizeof(int32_t),
static_cast<index_t>(coherence));
int32x4_t tmp3 =
__builtin_amdgcn_raw_buffer_load_b128(src_wave_buffer_resource,
src_thread_addr_offset,
src_wave_addr_offset + 12 * sizeof(int32_t),
static_cast<index_t>(coherence));
src_thread_addr_offset,
src_wave_addr_offset + 12 * sizeof(int32_t),
static_cast<index_t>(coherence));
vector_type<int32_t, 16> tmp;
......@@ -223,9 +224,10 @@ amd_buffer_load_impl_raw(__amdgpu_buffer_rsrc_t src_wave_buffer_resource,
template <typename T,
index_t N,
AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence>
__device__ typename vector_type<T, N>::type amd_buffer_load_impl(__amdgpu_buffer_rsrc_t src_wave_buffer_resource,
index_t src_thread_addr_offset,
index_t src_wave_addr_offset)
__device__ typename vector_type<T, N>::type
amd_buffer_load_impl(__amdgpu_buffer_rsrc_t src_wave_buffer_resource,
index_t src_thread_addr_offset,
index_t src_wave_addr_offset)
{
static_assert(
(is_same<T, double>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
......@@ -259,87 +261,87 @@ amd_buffer_store_impl_raw(const typename vector_type<int8_t, N>::type src_thread
if constexpr(N == 1)
{
__builtin_amdgcn_raw_buffer_store_b8(src_thread_data,
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 2)
{
__builtin_amdgcn_raw_buffer_store_b16(bit_cast<int16_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 4)
{
__builtin_amdgcn_raw_buffer_store_b32(bit_cast<int32_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 8)
{
__builtin_amdgcn_raw_buffer_store_b64(bit_cast<int32x2_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 16)
{
__builtin_amdgcn_raw_buffer_store_b128(bit_cast<int32x4_t>(src_thread_data),
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 32)
{
vector_type<int32_t, 8> tmp{bit_cast<int32x8_t>(src_thread_data)};
__builtin_amdgcn_raw_buffer_store_b128(tmp.template AsType<int32x4_t>()[Number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
__builtin_amdgcn_raw_buffer_store_b128(tmp.template AsType<int32x4_t>()[Number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + sizeof(int32_t) * 4,
static_cast<index_t>(coherence));
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + sizeof(int32_t) * 4,
static_cast<index_t>(coherence));
}
else if constexpr(N == 64)
{
vector_type<int32_t, 16> tmp{bit_cast<int32x16_t>(src_thread_data)};
__builtin_amdgcn_raw_buffer_store_b128(tmp.template AsType<int32x4_t>()[Number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
__builtin_amdgcn_raw_buffer_store_b128(tmp.template AsType<int32x4_t>()[Number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + sizeof(int32_t) * 4,
static_cast<index_t>(coherence));
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + sizeof(int32_t) * 4,
static_cast<index_t>(coherence));
__builtin_amdgcn_raw_buffer_store_b128(tmp.template AsType<int32x4_t>()[Number<2>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + sizeof(int32_t) * 8,
static_cast<index_t>(coherence));
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + sizeof(int32_t) * 8,
static_cast<index_t>(coherence));
__builtin_amdgcn_raw_buffer_store_b128(tmp.template AsType<int32x4_t>()[Number<3>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + sizeof(int32_t) * 12,
static_cast<index_t>(coherence));
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + sizeof(int32_t) * 12,
static_cast<index_t>(coherence));
}
}
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment