Commit c54b7bc9 authored by Chao Liu

Merge remote-tracking branch 'origin/develop' into group_norm

parents 9a8967a4 f584ab0c
@@ -881,9 +881,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
FloatGemmAcc c_new =
(running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c +
math::exp(max[iM] - running_max_new[iM]) * acc1) /
-running_sum_new[iM]; // O_new
+running_sum_new[iM]; // Formula by Dao et al.,
+                     // https://arxiv.org/pdf/2205.14135v2.pdf section 3.1

-c_thread_buf(I) = c_new;
+c_thread_buf(I) = c_new; // O_new
});
});
...
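For reference, the update above is the streaming-softmax rescaling described by Dao et al.; transcribed directly from the code (same variable names), it reads

$$
c_{\text{new}} = \frac{\text{running\_sum}[i_M]\, e^{\,\text{running\_max}[i_M]-\text{running\_max\_new}[i_M]}\, c \;+\; e^{\,\text{max}[i_M]-\text{running\_max\_new}[i_M]}\,\text{acc1}}{\text{running\_sum\_new}[i_M]}
$$

i.e. the previously accumulated output c is rescaled from the old to the new softmax normalizer before the new partial product acc1 is folded in.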
@@ -83,6 +83,8 @@ struct GridwiseElementwise_1D
auto in_global_buf_tuple = generate_tuple(
[&](auto I) {
+static_assert(in_grid_1d_desc_tuple[I].GetNumOfDimension() == 1);
+
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global_tuple[I], in_grid_1d_desc_tuple[I].GetElementSpaceSize());
},
@@ -90,6 +92,8 @@ struct GridwiseElementwise_1D
auto out_global_buf_tuple = generate_tuple(
[&](auto I) {
+static_assert(out_grid_1d_desc_tuple[I].GetNumOfDimension() == 1);
+
return make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global_tuple[I], out_grid_1d_desc_tuple[I].GetElementSpaceSize());
},
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <functional>
#include <numeric>
#include <iterator>
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
template <typename GridwisePermute,
typename InGridDesc,
typename OutGridDesc,
typename InDataType,
typename OutDataType,
typename ElementwiseOperation,
typename Block2TileMap>
__global__ void kernel_nd_permute(const InGridDesc in_grid_desc,
const OutGridDesc out_grid_desc,
const InDataType* p_in_global,
OutDataType* p_out_global,
const ElementwiseOperation elementwise_op,
const Block2TileMap block_2_tile_map)
{
__shared__ char p_shared[GridwisePermute::GetSharedMemoryNumberOfByte()];
GridwisePermute::Run(in_grid_desc,
out_grid_desc,
p_in_global,
p_out_global,
p_shared,
elementwise_op,
block_2_tile_map);
}
template <typename InGridDesc,
typename OutGridDesc,
typename InDataType,
typename OutDataType,
typename ElementwiseOperation,
index_t BlockSize,
index_t NPerBlock,
index_t HPerBlock,
index_t WPerBlock,
index_t InBlockLdsExtraW,
typename InBlockTransferThreadClusterLengths,
typename InBlockTransferThreadClusterArrangeOrder,
index_t SrcVectorDim,
index_t DstVectorDim,
index_t SrcScalarPerVector,
index_t DstScalarPerVector>
struct GridwisePermute
{
static_assert(InGridDesc::GetNumOfDimension() == OutGridDesc::GetNumOfDimension());
static_assert(3 <= InGridDesc::GetNumOfDimension());
static_assert((InGridDesc::GetNumOfDimension() - 2) <= SrcVectorDim &&
SrcVectorDim < InGridDesc::GetNumOfDimension());
static_assert((OutGridDesc::GetNumOfDimension() - 2) <= DstVectorDim &&
DstVectorDim < OutGridDesc::GetNumOfDimension());
static_assert(SrcVectorDim != DstVectorDim);
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
struct Block2TileMap
{
static constexpr index_t NumDim = InGridDesc::GetNumOfDimension();
static_assert(3 <= NumDim);
static constexpr auto I0 = Number<0>{};
Block2TileMap() = delete;
Block2TileMap(const Block2TileMap&) = default;
Block2TileMap(Block2TileMap&&) = delete;
~Block2TileMap() = default;
Block2TileMap& operator=(const Block2TileMap&) = delete;
Block2TileMap& operator=(Block2TileMap&&) = delete;
explicit Block2TileMap(const InGridDesc& desc) : desc_(desc) {}
__host__ constexpr index_t CalculateGridSize(const InGridDesc& desc) const
{
const auto N0 =
math::integer_divide_ceil(desc.GetLength(Number<NumDim - 3>{}), NPerBlock);
const auto H0 =
math::integer_divide_ceil(desc.GetLength(Number<NumDim - 2>{}), HPerBlock);
const auto W0 =
math::integer_divide_ceil(desc.GetLength(Number<NumDim - 1>{}), WPerBlock);
const index_t grid_size = N0 * H0 * W0;
return grid_size;
}
template <typename TopIdx>
__host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
{
static_assert(TopIdx::Size() == 1);
auto block_1d_id = idx_top[I0];
const auto N0 =
math::integer_divide_ceil(desc_.GetLength(Number<NumDim - 3>{}), NPerBlock);
const auto H0 =
math::integer_divide_ceil(desc_.GetLength(Number<NumDim - 2>{}), HPerBlock);
const auto W0 =
math::integer_divide_ceil(desc_.GetLength(Number<NumDim - 1>{}), WPerBlock);
block_1d_id = block_1d_id % (N0 * H0 * W0);
index_t idx_N0 = block_1d_id / (H0 * W0);
index_t idx_H0 = (block_1d_id % (H0 * W0)) / W0;
index_t idx_W0 = block_1d_id % W0;
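// e.g. with a (hypothetical) tile grid of N0 = 2, H0 = 3, W0 = 4, block id 17 decomposes to
// (idx_N0, idx_H0, idx_W0) = (17 / 12, (17 % 12) / 4, 17 % 4) = (1, 1, 1)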
return make_tuple(idx_N0, idx_H0, idx_W0);
}
private:
const InGridDesc desc_;
};
using DefaultBlock2TileMap = Block2TileMap;
// use an [NPerBlock, HPerBlock, WPerBlock] tensor as element-copy relay
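// note: the W stride below is padded by InBlockLdsExtraW elements per row; the likely intent
// (an assumption, not stated in this commit) is to avoid LDS bank conflicts when the tile is
// read back along the other dimension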
__host__ __device__ static constexpr auto GetInBlockDesc_NPerBlock_HPerBlock_WPerBlock()
{
return make_naive_tensor_descriptor(
make_tuple(Number<NPerBlock>{}, Number<HPerBlock>{}, Number<WPerBlock>{}),
make_tuple(Number<HPerBlock*(WPerBlock + InBlockLdsExtraW)>{},
Number<WPerBlock + InBlockLdsExtraW>{},
I1));
}
// for an N-dimensional descriptor, keep its last 2 dimensions and merge all leading dimensions
// into a single one, forming a 3-D descriptor: [d(0), d(1), ..., d(N - 2), d(N - 1)] ->
// [(d(0) x d(1) x ...), d(N - 2), d(N - 1)]
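// e.g. a (hypothetical) 4-D descriptor with lengths [2, 3, 8, 16] becomes the 3-D descriptor
// [(2 x 3), 8, 16] = [6, 8, 16]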
template <typename GridDesc>
__host__ __device__ static constexpr auto GetMergedDesc(const GridDesc& desc)
{
constexpr index_t NumDim = GridDesc::GetNumOfDimension();
static_assert(3 <= NumDim);
const auto merged_desc = transform_tensor_descriptor(
desc,
make_tuple(make_merge_transform(generate_tuple(
[&](auto I) { return desc.GetLength(I); }, Number<NumDim - 2>{})),
make_pass_through_transform(desc.GetLength(Number<NumDim - 2>{})),
make_pass_through_transform(desc.GetLength(Number<NumDim - 1>{}))),
make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<NumDim - 2>{}),
Sequence<NumDim - 2>{},
Sequence<NumDim - 1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
return merged_desc;
}
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
constexpr auto in_block_desc_nperblock_hperblock_wperblock =
GetInBlockDesc_NPerBlock_HPerBlock_WPerBlock();
return in_block_desc_nperblock_hperblock_wperblock.GetElementSpaceSize() *
sizeof(InDataType);
}
__host__ __device__ static constexpr auto MakeDefaultBlock2TileMap(const InGridDesc& desc)
{
return DefaultBlock2TileMap{desc};
}
__host__ __device__ static constexpr bool CheckValidity(const InGridDesc& in_grid_desc,
const OutGridDesc& out_grid_desc)
{
constexpr index_t NumDim = InGridDesc::GetNumOfDimension();
// check if we only swap last 2 dimensions
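// e.g. an input of lengths [N, C, H, W] is only accepted together with an output of lengths
// [N, C, W, H]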
bool valid = true;
static_for<0, NumDim - 2, 1>{}([&](auto I) {
if(valid && in_grid_desc.GetLength(I) != out_grid_desc.GetLength(I))
{
valid = false;
}
});
return valid &&
(in_grid_desc.GetLength(Number<NumDim - 1>{}) ==
out_grid_desc.GetLength(Number<NumDim - 2>{})) &&
(in_grid_desc.GetLength(Number<NumDim - 2>{}) ==
out_grid_desc.GetLength(Number<NumDim - 1>{}));
}
template <typename Block2TileMap>
__device__ static void Run(const InGridDesc in_grid_desc,
const OutGridDesc out_grid_desc,
const InDataType* p_in_global,
OutDataType* p_out_global,
void* __restrict__ p_shared,
const ElementwiseOperation elementwise_op,
const Block2TileMap& block_2_tile_map)
{
auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global, in_grid_desc.GetElementSpaceSize());
auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_out_global, out_grid_desc.GetElementSpaceSize());
// each workgroup handles an [NPerBlock, HPerBlock, WPerBlock] slice-transpose problem
const auto block_work_idx =
block_2_tile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I0] * NPerBlock);
const index_t h_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I1] * HPerBlock);
const index_t w_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_work_idx[I2] * WPerBlock);
// create [NPerBlock, HPerBlock, WPerBlock] shaped LDS buffer
constexpr auto in_block_desc_nperblock_hperblock_wperblock =
GetInBlockDesc_NPerBlock_HPerBlock_WPerBlock();
auto in_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<InDataType*>(p_shared),
in_block_desc_nperblock_hperblock_wperblock.GetElementSpaceSize());
using BlockSliceLengths = Sequence<NPerBlock, HPerBlock, WPerBlock>;
using InBlockTransferAccessOrder = Sequence<0, 1, 2>;
constexpr index_t SrcVectorDimAfterMerge =
SrcVectorDim - (InGridDesc::GetNumOfDimension() - 3);
constexpr index_t DstVectorDimAfterMerge = SrcVectorDimAfterMerge;
using ck::tensor_operation::element_wise::PassThrough;
// merge input descriptor into [(in_grid_desc.GetLength(0) x in_grid_desc.GetLength(1) x
// ...), in_grid_desc.GetLength(NumDim - 2), in_grid_desc.GetLength(NumDim - 1)]
const auto in_grid_desc_n_h_w = GetMergedDesc(in_grid_desc);
// a workgroup copies an [NPerBlock, HPerBlock, WPerBlock] slice from global memory to LDS
auto in_global_load = ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock,
ElementwiseOperation,
PassThrough,
InMemoryDataOperationEnum::Set,
BlockSliceLengths,
InBlockTransferThreadClusterLengths,
InBlockTransferThreadClusterArrangeOrder,
InDataType,
InDataType,
decltype(in_grid_desc_n_h_w),
decltype(in_block_desc_nperblock_hperblock_wperblock),
InBlockTransferAccessOrder,
InBlockTransferAccessOrder,
SrcVectorDimAfterMerge,
2,
SrcScalarPerVector,
1,
1,
1,
true,
true>(in_grid_desc_n_h_w,
make_multi_index(
n_block_data_idx_on_grid, h_block_data_idx_on_grid, w_block_data_idx_on_grid),
PassThrough{},
in_block_desc_nperblock_hperblock_wperblock,
make_multi_index(0, 0, 0),
PassThrough{});
// merge output descriptor into [(out_grid_desc.GetLength(0) x out_grid_desc.GetLength(1) x
// ...), out_grid_desc.GetLength(NumDim - 2), out_grid_desc.GetLength(NumDim - 1)]
const auto out_grid_desc_n_w_h = GetMergedDesc(out_grid_desc);
// create transposed view of output tensor
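// (the merged output descriptor is indexed as (n, w, h); swapping its two minor dimensions
// yields an (n, h, w) view, so the store below can reuse the same (n, h, w) coordinates as the
// LDS tile and the transpose happens in the address calculation)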
const auto out_grid_desc_n_h_w = transform_tensor_descriptor(
out_grid_desc_n_w_h,
make_tuple(make_pass_through_transform(out_grid_desc_n_w_h.GetLength(I0)),
make_pass_through_transform(out_grid_desc_n_w_h.GetLength(I1)),
make_pass_through_transform(out_grid_desc_n_w_h.GetLength(I2))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<2>{}, Sequence<1>{}));
// a workgroup copies an [NPerBlock, HPerBlock, WPerBlock] slice from LDS to global memory
auto out_global_store = ThreadGroupTensorSliceTransfer_v4r1<
ThisThreadBlock,
ElementwiseOperation,
PassThrough,
InMemoryDataOperationEnum::Set,
BlockSliceLengths,
InBlockTransferThreadClusterLengths,
InBlockTransferThreadClusterArrangeOrder,
InDataType,
OutDataType,
decltype(in_block_desc_nperblock_hperblock_wperblock),
decltype(out_grid_desc_n_h_w),
InBlockTransferAccessOrder,
InBlockTransferAccessOrder,
2,
DstVectorDimAfterMerge,
1,
DstScalarPerVector,
1,
1,
true,
true>(in_block_desc_nperblock_hperblock_wperblock,
make_multi_index(0, 0, 0),
PassThrough{},
out_grid_desc_n_h_w,
make_multi_index(
n_block_data_idx_on_grid, h_block_data_idx_on_grid, w_block_data_idx_on_grid),
elementwise_op);
in_global_load.Run(in_grid_desc_n_h_w,
in_global_buf,
in_block_desc_nperblock_hperblock_wperblock,
in_block_buf,
I0);
out_global_store.Run(in_block_desc_nperblock_hperblock_wperblock,
in_block_buf,
out_grid_desc_n_h_w,
out_global_buf,
I0);
}
};
} // namespace ck
@@ -6,6 +6,7 @@
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
+#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor/static_tensor.hpp"
namespace ck {
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstddef>
#include <array>
#include <type_traits>
namespace ck {
template <typename T>
class span
{
public:
using element_type = T;
using value_type = std::remove_cv_t<element_type>;
using size_type = std::size_t;
using difference_type = std::ptrdiff_t;
using pointer = element_type*;
using const_pointer = const element_type*;
using reference = element_type&;
using const_reference = const element_type&;
using iterator = pointer;
using const_iterator = pointer;
constexpr span() : span(nullptr, size_type{0}) {}
constexpr span(pointer first, size_type count) : ptr_(first), size_(count) {}
constexpr span(pointer first, pointer last) : span(first, last - first) {}
template <std::size_t N>
constexpr span(element_type (&arr)[N]) noexcept : span(arr, N)
{
}
template <std::size_t N>
constexpr span(std::array<value_type, N>& arr) noexcept : span(arr.data(), N)
{
}
template <typename Container>
constexpr span(const Container& container) : span(container.data(), container.size())
{
}
constexpr iterator begin() const noexcept { return ptr_; }
constexpr const_iterator cbegin() const noexcept { return begin(); }
constexpr iterator end() const noexcept { return begin() + size(); }
constexpr const_iterator cend() const noexcept { return end(); }
constexpr reference front() const { return *begin(); }
constexpr reference back() const { return *(end() - 1); }
constexpr reference operator[](size_type idx) const { return *(begin() + idx); }
constexpr pointer data() const noexcept { return ptr_; }
constexpr size_type size() const noexcept { return size_; }
private:
pointer ptr_;
size_type size_;
};
} // namespace ck
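A minimal usage sketch of ck::span (illustration only, not part of the commit; assumes "ck/utility/span.hpp" is included): it is a non-owning view over contiguous elements, mirroring a small subset of std::span.

#include <array>
#include <vector>

// sums the elements seen through a read-only view
inline float sum(ck::span<const float> view)
{
    float acc = 0.f;
    for(float v : view) // begin()/end() make the view range-iterable
        acc += v;
    return acc;
}

void span_usage_sketch()
{
    std::array<float, 4> arr{1.f, 2.f, 3.f, 4.f};
    std::vector<float> vec(8, 0.5f);
    (void)sum(ck::span<const float>{arr}); // from a std::array
    (void)sum(ck::span<const float>{vec}); // from any container exposing data()/size()
}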
@@ -34,17 +34,15 @@ __device__ void transpose_fp16_2x2(const half2_t& x0, const half2_t& x1, half2_t
y0 = vy0.template AsType<half2_t>()[I0];
y1 = vy1.template AsType<half2_t>()[I0];
#else
-asm volatile("\n \
-v_pack_b32_f16 %0, %1, %2 \n \
-"
-: "=v"(y0)
-: "v"(x0), "v"(x1));
-
-asm volatile("\n \
-v_pack_b32_f16 %0, %1, %2, op_sel:[1, 1] \n \
-"
-: "=v"(y1)
-: "v"(x0), "v"(x1));
+constexpr int32_t m0 = 0x05040100;
+constexpr int32_t m1 = 0x07060302;
+
+// ex: v_perm_b32(0x 11 22 33 44, 0x 55 66 77 88, 0x 05 01 04 00) -> 0x33774488
+// -- -- -- -- -- -- -- -- - - - -
+// index 7 6 5 4 3 2 1 0 33 77 44 88
+// index is reversed because of little endianness (least significant bits first)
+y0 = bit_cast<half2_t>(__builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m0));
+y1 = bit_cast<half2_t>(__builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m1));
#endif
}
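Spelled out, what the two selectors do (my reading of the v_perm_b32 / __builtin_amdgcn_perm semantics, consistent with the worked example in the comment above): each destination byte is picked from the eight bytes of {src0, src1}, with src1 supplying byte indices 0-3. Since src0 = x1 and src1 = x0 here, mask 0x05040100 gathers the low fp16 halves of x0 and x1 into y0, and mask 0x07060302 gathers the high halves into y1, i.e. a 2x2 transpose of fp16 pairs. A portable sketch of the same data movement, for illustration only:

#include <cstdint>

// reference behaviour of the two perm masks used above, expressed on plain 32-bit words
inline void transpose_fp16_2x2_reference(uint32_t x0, uint32_t x1, uint32_t& y0, uint32_t& y1)
{
    y0 = (x0 & 0x0000ffffu) | (x1 << 16); // {x0.low16,  x1.low16}  - mask 0x05040100
    y1 = (x0 >> 16) | (x1 & 0xffff0000u); // {x0.high16, x1.high16} - mask 0x07060302
}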
@@ -106,16 +104,14 @@ __device__ void transpose_int8_4x4(const int8x4_t& x0,
// -- -- -- -- -- -- -- -- - - - -
// index 7 6 5 4 3 2 1 0 33 77 44 88
// index is reversed because of little endianness (least significant bits first)
-// clang-format off
-asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(t0) : "v"(bit_cast<int32_t>(x1)), "v"(bit_cast<int32_t>(x0)), "s"(m0));
-asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(t1) : "v"(bit_cast<int32_t>(x3)), "v"(bit_cast<int32_t>(x2)), "s"(m0));
-asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(z0) : "v"(bit_cast<int32_t>(t1)), "v"(bit_cast<int32_t>(t0)), "s"(m1));
-asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(z1) : "v"(bit_cast<int32_t>(t1)), "v"(bit_cast<int32_t>(t0)), "s"(m2));
-asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(t0) : "v"(bit_cast<int32_t>(x1)), "v"(bit_cast<int32_t>(x0)), "s"(m3));
-asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(t1) : "v"(bit_cast<int32_t>(x3)), "v"(bit_cast<int32_t>(x2)), "s"(m3));
-asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(z2) : "v"(bit_cast<int32_t>(t1)), "v"(bit_cast<int32_t>(t0)), "s"(m1));
-asm volatile("v_perm_b32 %0, %1, %2, %3" : "=v"(z3) : "v"(bit_cast<int32_t>(t1)), "v"(bit_cast<int32_t>(t0)), "s"(m2));
-// clang-format on
+t0 = __builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m0);
+t1 = __builtin_amdgcn_perm(bit_cast<int32_t>(x3), bit_cast<int32_t>(x2), m0);
+z0 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m1);
+z1 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m2);
+t0 = __builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m3);
+t1 = __builtin_amdgcn_perm(bit_cast<int32_t>(x3), bit_cast<int32_t>(x2), m3);
+z2 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m1);
+z3 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m2);
y0 = bit_cast<int8x4_t>(z0);
y1 = bit_cast<int8x4_t>(z1);
...
@@ -15,6 +15,7 @@
#include "ck/ck.hpp"
#include "ck/utility/data_type.hpp"
+#include "ck/utility/span.hpp"
#include "ck/utility/type.hpp"
#include "ck/host_utility/io.hpp"
@@ -32,7 +33,7 @@ check_err(const std::vector<T>& out,
{
if(out.size() != ref.size())
{
-std::cout << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
+std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
<< std::endl;
return false;
}
@@ -50,7 +51,7 @@ check_err(const std::vector<T>& out,
err_count++;
if(err_count < 5)
{
-std::cout << msg << std::setw(12) << std::setprecision(7) << " out[" << i
+std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << out[i] << " != " << ref[i] << std::endl;
}
res = false;
@@ -58,7 +59,7 @@ check_err(const std::vector<T>& out,
}
if(!res)
{
-std::cout << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
+std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
}
return res;
}
@@ -73,7 +74,7 @@ check_err(const std::vector<T>& out,
{
if(out.size() != ref.size())
{
-std::cout << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
+std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
<< std::endl;
return false;
}
@@ -94,7 +95,7 @@ check_err(const std::vector<T>& out,
err_count++;
if(err_count < 5)
{
-std::cout << msg << std::setw(12) << std::setprecision(7) << " out[" << i
+std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
}
res = false;
@@ -102,22 +103,22 @@ check_err(const std::vector<T>& out,
}
if(!res)
{
-std::cout << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
+std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
}
return res;
}
template <typename T>
-typename std::enable_if<std::is_same<T, half_t>::value, bool>::type
-check_err(const std::vector<T>& out,
-const std::vector<T>& ref,
+typename std::enable_if<std::is_same_v<T, half_t>, bool>::type
+check_err(span<const T> out,
+span<const T> ref,
const std::string& msg = "Error: Incorrect results!",
double rtol = 1e-3,
double atol = 1e-3)
{
if(out.size() != ref.size())
{
-std::cout << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
+std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
<< std::endl;
return false;
}
@@ -137,7 +138,7 @@ check_err(const std::vector<T>& out,
err_count++;
if(err_count < 5)
{
-std::cout << msg << std::setw(12) << std::setprecision(7) << " out[" << i
+std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
<< "] != ref[" << i << "]: " << o << " != " << r << std::endl;
}
res = false;
@@ -145,11 +146,22 @@ check_err(const std::vector<T>& out,
}
if(!res)
{
-std::cout << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
+std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
}
return res;
}
template <typename T>
typename std::enable_if<std::is_same<T, half_t>::value, bool>::type
check_err(const std::vector<T>& out,
const std::vector<T>& ref,
const std::string& msg = "Error: Incorrect results!",
double rtol = 1e-3,
double atol = 1e-3)
{
return check_err(span<const T>{out}, span<const T>{ref}, msg, rtol, atol);
}
template <typename T>
std::enable_if_t<(std::is_integral_v<T> && !std::is_same_v<T, bhalf_t>)
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
@@ -194,7 +206,7 @@ check_err(const std::vector<T>& out,
}
if(!res)
{
-std::cout << "max err: " << max_err << std::endl;
+std::cerr << "max err: " << max_err << std::endl;
}
return res;
}
...
@@ -5,7 +5,10 @@
#include <algorithm>
#include <cmath>
+#include <iterator>
#include <random>
+#include <type_traits>
+#include <utility>
#include "ck/utility/data_type.hpp"
@@ -25,6 +28,15 @@ struct FillUniformDistribution
std::uniform_real_distribution<float> dis(a_, b_);
std::generate(first, last, [&dis, &gen]() { return ck::type_convert<T>(dis(gen)); });
}
+
+template <typename ForwardRange>
+auto operator()(ForwardRange&& range) -> std::void_t<decltype(
+std::declval<FillUniformDistribution>()(std::begin(std::forward<ForwardRange>(range)),
+std::end(std::forward<ForwardRange>(range))))>
+{
+(*this)(std::begin(std::forward<ForwardRange>(range)),
+std::end(std::forward<ForwardRange>(range)));
+}
};
// Normally FillUniformDistributionIntegerValue should use std::uniform_int_distribution as below.
...
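A hypothetical usage sketch of the new range overload (header path and the ck::utils namespace are assumed here, not shown in the hunk): it forwards anything with std::begin/std::end to the existing iterator-pair overload, so a whole container can be filled in one call.

#include <vector>
// #include "ck/library/utility/fill.hpp" // assumed location of FillUniformDistribution

void fill_sketch()
{
    std::vector<float> host_buf(1024);
    ck::utils::FillUniformDistribution<float>{-1.f, 1.f}(host_buf);                        // range overload
    ck::utils::FillUniformDistribution<float>{0.f, 1.f}(host_buf.begin(), host_buf.end()); // iterator overload
}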
@@ -3,15 +3,16 @@
#pragma once
-#include <thread>
-#include <vector>
-#include <numeric>
#include <algorithm>
-#include <utility>
#include <cassert>
#include <iostream>
+#include <numeric>
+#include <thread>
+#include <utility>
+#include <vector>
#include "ck/utility/data_type.hpp"
+#include "ck/utility/span.hpp"
template <typename Range>
std::ostream& LogRange(std::ostream& os, Range&& range, std::string delim)
@@ -235,6 +236,9 @@ auto make_ParallelTensorFunctor(F f, Xs... xs)
template <typename T>
struct Tensor
{
+using Descriptor = HostTensorDescriptor;
+using Data = std::vector<T>;
+
template <typename X>
Tensor(std::initializer_list<X> lens) : mDesc(lens), mData(mDesc.GetElementSpaceSize())
{
@@ -251,7 +255,7 @@ struct Tensor
{
}
-Tensor(const HostTensorDescriptor& desc) : mDesc(desc), mData(mDesc.GetElementSpaceSize()) {}
+Tensor(const Descriptor& desc) : mDesc(desc), mData(mDesc.GetElementSpaceSize()) {}
template <typename OutT>
Tensor<OutT> CopyAsType() const
@@ -278,9 +282,9 @@ struct Tensor
{
}
-const std::vector<std::size_t>& GetLengths() const { return mDesc.GetLengths(); }
-const std::vector<std::size_t>& GetStrides() const { return mDesc.GetStrides(); }
+decltype(auto) GetLengths() const { return mDesc.GetLengths(); }
+decltype(auto) GetStrides() const { return mDesc.GetStrides(); }
std::size_t GetNumOfDimension() const { return mDesc.GetNumOfDimension(); }
@@ -288,6 +292,8 @@ struct Tensor
std::size_t GetElementSpaceSize() const { return mDesc.GetElementSpaceSize(); }
+std::size_t GetElementSpaceSizeInBytes() const { return sizeof(T) * GetElementSpaceSize(); }
+
void SetZero()
{
for(auto& v : mData)
@@ -425,14 +431,40 @@ struct Tensor
return mData[mDesc.GetOffsetFromMultiIndex(idx)];
}
-typename std::vector<T>::iterator begin() { return mData.begin(); }
-typename std::vector<T>::iterator end() { return mData.end(); }
-typename std::vector<T>::const_iterator begin() const { return mData.begin(); }
-typename std::vector<T>::const_iterator end() const { return mData.end(); }
+typename Data::iterator begin() { return mData.begin(); }
+typename Data::iterator end() { return mData.end(); }
+typename Data::pointer data() { return mData.data(); }
+typename Data::const_iterator begin() const { return mData.begin(); }
+typename Data::const_iterator end() const { return mData.end(); }
+typename Data::const_pointer data() const { return mData.data(); }
+typename Data::size_type size() const { return mData.size(); }
template <typename U = T>
auto AsSpan() const
{
constexpr std::size_t FromSize = sizeof(T);
constexpr std::size_t ToSize = sizeof(U);
using Element = std::add_const_t<std::remove_reference_t<U>>;
return ck::span<Element>{reinterpret_cast<Element*>(data()), size() * FromSize / ToSize};
}
template <typename U = T>
auto AsSpan()
{
constexpr std::size_t FromSize = sizeof(T);
constexpr std::size_t ToSize = sizeof(U);
using Element = std::remove_reference_t<U>;
return ck::span<Element>{reinterpret_cast<Element*>(data()), size() * FromSize / ToSize};
}
-HostTensorDescriptor mDesc;
-std::vector<T> mData;
+Descriptor mDesc;
+Data mData;
};
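A hypothetical sketch of how the new pieces compose (Tensor, ck::span and check_err as declared in these headers; the ck::utils namespace is assumed): Tensor::AsSpan() hands a non-owning view of the host data straight to the new span-based check_err overload, with no copy into a std::vector.

// illustration only, not part of the commit
bool verify_sketch(const Tensor<ck::half_t>& out, const Tensor<ck::half_t>& ref)
{
    return ck::utils::check_err(out.AsSpan<const ck::half_t>(),
                                ref.AsSpan<const ck::half_t>(),
                                "Error: Incorrect results!",
                                1e-3,  // rtol
                                1e-3); // atol
}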
@@ -55,6 +55,22 @@ using device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_
// clang-format on
>;
using device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_irregular_k_instances =
std::tuple<
// clang-format off
//#######################################| ALayout| B0Layout| B1Layout| CLayout| AData| B0Data| B1Data| CData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#######################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#######################################| | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 256, 128, 40, 64, 32, 4, 4, 2, 32, 32, 2, 4, 2, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 256, 128, 40, 128, 32, 4, 4, 2, 32, 32, 2, 4, 4, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 256, 40, 64, 32, 4, 4, 2, 32, 32, 1, 8, 2, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 256, 40, 128, 32, 4, 4, 2, 32, 32, 1, 8, 4, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 128, 40, 64, 32, 4, 4, 2, 32, 32, 1, 4, 2, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 128, 40, 128, 32, 4, 4, 2, 32, 32, 1, 4, 4, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S<2,128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, false, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>
// clang-format on
>;
void add_device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemm<Row,
Col,
@@ -73,6 +89,9 @@ void add_device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_g
add_device_operation_instances(
instances,
device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances{});
add_device_operation_instances(
instances,
device_batched_gemm_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_irregular_k_instances{});
}
} // namespace instance
...
@@ -105,6 +105,19 @@ TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, DISABLED_Bench_FP16)
this->Run();
}
TYPED_TEST(TestBatchedGemmSoftmaxGemmFP16, DISABLED_Bench_FP16_IrregularK)
{
this->lengths_ = std::vector<std::vector<int>>{{256, 256, 160, 160, 16},
{256, 64, 160, 64, 16},
{1024, 1024, 80, 80, 16},
{1024, 64, 80, 64, 16},
{4096, 4096, 40, 40, 16},
{4096, 64, 40, 64, 16}};
this->bench_ = true;
this->verify_ = false;
this->Run();
}
using ck::tensor_operation::device::GemmSpecialization;
// TODO: enable KPadding tests when it is implemented
...
@@ -29,14 +29,19 @@ struct TestBatchedGemmSoftmaxGemm : public ::testing::Test
using B1Layout = std::tuple_element_t<6, Tuple>;
using CLayout = std::tuple_element_t<7, Tuple>;
-std::vector<std::vector<int>> lengths_ = {
-{256, 256, 64, 64, 4},
-{256, 256, 128, 128, 4},
-{512, 512, 64, 64, 2},
-{512, 512, 128, 128, 2},
-{1024, 1024, 64, 64, 1},
-{1024, 1024, 128, 128, 1},
-};
+std::vector<std::vector<int>> lengths_ = {{256, 256, 64, 64, 4},
+{256, 256, 128, 128, 4},
+{512, 512, 64, 64, 2},
+{512, 512, 128, 128, 2},
+{1024, 1024, 64, 64, 1},
+{1024, 1024, 128, 128, 1},
+{256, 256, 160, 160, 4},
+{256, 64, 160, 64, 4},
+{1024, 1024, 80, 80, 2},
+{1024, 64, 80, 64, 2},
+{4096, 4096, 40, 40, 1},
+{4096, 64, 40, 64, 1}};
bool bench_ = false;
bool verify_ = true;
...