Commit 8f9c0243 authored by Alan Turner

Merge branch 'develop' into migx-jit-lib

parents 181ea79a c8a8385f
......@@ -142,8 +142,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
using GridwiseGemmPipe = remove_cvref_t<
decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
__host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
{
......@@ -323,13 +323,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
}
using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl =
remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
remove_cvref_t<
decltype(MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
CGridDesc_M_N{}))>;
using C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl =
remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
remove_cvref_t<
decltype(MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
C0GridDesc_M_N{}))>;
using DefaultBlock2CTileMap =
......@@ -654,12 +654,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
FloatC, // typename Src0Data,
FloatC, // typename Src1Data,
FloatC, // typename DstData,
decltype(
c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
decltype(
c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
decltype(
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
decltype(c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
decltype(c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
decltype(c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
Sequence<0, 1, 2, 3, 4, 5>, // typename DimAccessOrder,
5, // index_t VectorDim,
CBlockTransferScalarPerVector_NWaveNPerXdl, // index_t ScalarPerVector,
......
......@@ -151,8 +151,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using GridwiseGemmPipe = remove_cvref_t<decltype(
GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
using GridwiseGemmPipe = remove_cvref_t<
decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
__host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
{
......@@ -331,18 +331,18 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
c_grid_desc_m_n);
}
using CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl =
remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
remove_cvref_t<
decltype(MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
CGridDesc_M_N{}))>;
using C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl =
remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
remove_cvref_t<
decltype(MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
C0GridDesc_M_N{}))>;
using C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl =
remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
remove_cvref_t<
decltype(MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
C1GridDesc_M_N{}))>;
using DefaultBlock2CTileMap =
......@@ -674,14 +674,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
FloatC, // typename Src1Data,
FloatC, // typename Src2Data,
FloatC, // typename DstData,
decltype(
c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
decltype(
c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
decltype(
c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
decltype(
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
decltype(c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
decltype(c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
decltype(c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
decltype(c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
Sequence<0, 1, 2, 3, 4, 5>, // typename DimAccessOrder,
5, // index_t VectorDim,
CBlockTransferScalarPerVector_NWaveNPerXdl, // index_t ScalarPerVector,
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
namespace ck {
template <typename GridwisePutElementwise1dFunctor,
typename InGrid1dDesc,
typename InDataType,
typename IndexDataType,
typename OutDataType,
typename ElementwiseOperation>
__global__ void kernel_put_element_1d(const InGrid1dDesc in_grid_1d_desc,
const InDataType* __restrict__ p_in_global,
const IndexDataType* __restrict__ p_indices_global,
OutDataType* __restrict__ p_out_global,
const ElementwiseOperation elementwise_op)
{
GridwisePutElementwise1dFunctor::Run(
in_grid_1d_desc, p_in_global, p_indices_global, p_out_global, elementwise_op);
}
// output[indices] = input
template <typename InGrid1dDesc,
typename InDataType,
typename IndexDataType,
typename OutDataType,
typename ElementwiseOperation,
InMemoryDataOperationEnum MemOp,
index_t InVectorSize>
struct GridwisePutElement_1D
{
static constexpr auto I0 = Number<0>{};
static constexpr auto thread_buffer_desc_m =
make_naive_tensor_descriptor_packed(make_tuple(Number<InVectorSize>{}));
__device__ static void Run(const InGrid1dDesc& in_grid_1d_desc,
const InDataType* __restrict__ p_in_global,
const IndexDataType* __restrict__ p_indices_global,
OutDataType* __restrict__ p_out_global,
const ElementwiseOperation& elementwise_op)
{
// Global Memory
const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_in_global, in_grid_1d_desc.GetElementSpaceSize());
const auto indices_global_buf =
make_dynamic_buffer<AddressSpaceEnum::Global>(p_indices_global,
in_grid_1d_desc.GetElementSpaceSize(),
NumericLimits<IndexDataType>::Lowest());
// VGPR
StaticBuffer<AddressSpaceEnum::Vgpr, InDataType, InVectorSize, true> in_thread_buf;
StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, InVectorSize, true> indices_thread_buf;
// Thread id, Block id and index
const index_t thread_global_id = get_thread_global_1d_id();
const auto thread_global_offset = make_multi_index(thread_global_id * InVectorSize);
const index_t blockSize = get_block_size();
const index_t blockPerGrid = get_grid_size();
const auto M = in_grid_1d_desc.GetLength(I0);
const index_t loop_step = blockPerGrid * blockSize * InVectorSize;
const auto loop_step_index = make_multi_index(loop_step);
auto in_global_load =
ThreadwiseTensorSliceTransfer_v2<InDataType,
InDataType,
decltype(in_grid_1d_desc),
decltype(thread_buffer_desc_m),
Sequence<InVectorSize>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
InVectorSize, // ScalarPerVector
1, // SrcScalarStrideInVector
false>{in_grid_1d_desc, thread_global_offset};
auto indices_global_load =
ThreadwiseTensorSliceTransfer_v2<IndexDataType,
IndexDataType,
decltype(in_grid_1d_desc),
decltype(thread_buffer_desc_m),
Sequence<InVectorSize>, // SliceLengths
Sequence<0>, // DimAccessOrder
0, // SrcVectorDim
InVectorSize, // ScalarPerVector
1, // SrcScalarStrideInVector
false>{in_grid_1d_desc, thread_global_offset};
index_t num_iter = M / loop_step;
do
{
in_global_load.Run(in_grid_1d_desc,
in_global_buf,
thread_buffer_desc_m,
make_tuple(I0),
in_thread_buf);
in_global_load.MoveSrcSliceWindow(in_grid_1d_desc, loop_step_index);
static_for<0, InVectorSize, 1>{}(
[&](auto iM) { elementwise_op(in_thread_buf(iM), in_thread_buf[iM]); });
indices_global_load.Run(in_grid_1d_desc,
indices_global_buf,
thread_buffer_desc_m,
make_tuple(I0),
indices_thread_buf);
indices_global_load.MoveSrcSliceWindow(in_grid_1d_desc, loop_step_index);
static_for<0, InVectorSize, 1>{}([&](auto iM) {
if(indices_thread_buf[iM] >= 0)
{
if constexpr(MemOp == InMemoryDataOperationEnum::Set)
{
// The user must guarantee that all indices in p_indices_global are unique
*(p_out_global + indices_thread_buf[iM]) =
ck::type_convert<OutDataType>(in_thread_buf[iM]);
}
else if constexpr(MemOp == InMemoryDataOperationEnum::AtomicAdd)
{
atomic_add<OutDataType>(p_out_global + indices_thread_buf[iM],
ck::type_convert<OutDataType>(in_thread_buf[iM]));
}
else if constexpr(MemOp == InMemoryDataOperationEnum::AtomicMax)
{
atomic_max<OutDataType>(p_out_global + indices_thread_buf[iM],
ck::type_convert<OutDataType>(in_thread_buf[iM]));
}
else if constexpr(MemOp == InMemoryDataOperationEnum::Add)
{
// The user must guarantee that all indices in p_indices_global are unique
*(p_out_global + indices_thread_buf[iM]) +=
ck::type_convert<OutDataType>(in_thread_buf[iM]);
}
else
{
static_assert(MemOp == InMemoryDataOperationEnum::Set ||
MemOp == InMemoryDataOperationEnum::AtomicAdd ||
MemOp == InMemoryDataOperationEnum::AtomicMax ||
MemOp == InMemoryDataOperationEnum::Add);
}
}
});
} while(--num_iter);
}
};
} // namespace ck
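For reference, a minimal host-side sketch of the scatter semantics the kernel above implements (output[indices] = input under the chosen MemOp); the function name, container types, and the accumulate flag are illustrative only and are not CK APIs.
#include <cassert>
#include <cstddef>
#include <vector>
// Reference behaviour: negative indices are skipped; for Set/Add the caller must
// guarantee unique indices, as the comments in GridwisePutElement_1D::Run note.
template <typename InT, typename IndexT, typename OutT, typename Op>
void put_element_reference(const std::vector<InT>& in,
                           const std::vector<IndexT>& idx,
                           std::vector<OutT>& out,
                           Op op,
                           bool accumulate) // true ~ Add/AtomicAdd, false ~ Set
{
    assert(in.size() == idx.size());
    for(std::size_t i = 0; i < in.size(); ++i)
    {
        if(idx[i] < 0)
            continue; // a negative index means "drop this element"
        InT v = in[i];
        op(v, in[i]); // elementwise_op(dst, src), mirroring the kernel's in-place use
        if(accumulate)
            out[idx[i]] += static_cast<OutT>(v);
        else
            out[idx[i]] = static_cast<OutT>(v);
    }
}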
......@@ -78,8 +78,8 @@ struct GridwiseSparseEmbeddingsForwardLayernorm
using ThreadwiseWolfordDesc2D = decltype(make_naive_tensor_descriptor_packed(make_tuple(
Number<DimSubBlocks * DimThreadSize>{}, Number<RowSubBlocks * RowVectorSize>{})));
using ThreadwiseWolfordDescReduce = decltype(
make_naive_tensor_descriptor_packed(make_tuple(Number<DimSubBlocks * DimThreadSize>{})));
using ThreadwiseWolfordDescReduce = decltype(make_naive_tensor_descriptor_packed(
make_tuple(Number<DimSubBlocks * DimThreadSize>{})));
using ThreadwiseWelford =
ThreadwiseWelford<AccDataType, ThreadwiseWolfordDesc2D, ThreadwiseWolfordDescReduce>;
......
......@@ -78,17 +78,18 @@ struct GridwiseNormalizationSplitK1st
static constexpr auto ThreadBufferNumber = Number<KThreadSliceSize / XSrcVectorSize>{};
__device__ static int
GetKPerThread(int kRaw, int kGridSize, int block_k_cluster_id, int thread_k_cluster_id)
GetKPerThread(int k, int kRaw, int kGridSize, int block_k_cluster_id, int thread_k_cluster_id)
{
bool is_rightmost_block = block_k_cluster_id == kGridSize - 1;
if(is_rightmost_block)
{
int left_kPerBlock = math::integer_divide_ceil(kRaw, kGridSize);
int kPerBlock = kRaw % kGridSize == 0 ? left_kPerBlock : kRaw % left_kPerBlock;
int kPerThread =
kPerBlock < K_BlockTileSize ? 0 : KThreadSliceSize * (kPerBlock / K_BlockTileSize);
int kPerBlockTail = kPerBlock - kPerThread * KThreadClusterSize;
int left_kPerBlock = math::integer_divide_ceil(k, kGridSize);
int kRightmostBlock = kRaw - left_kPerBlock * (kGridSize - 1);
int kPerThread = kRightmostBlock < K_BlockTileSize
? 0
: KThreadSliceSize * (kRightmostBlock / K_BlockTileSize);
int kPerBlockTail = kRightmostBlock - kPerThread * KThreadClusterSize;
if(kPerBlockTail > 0)
{
......@@ -105,7 +106,7 @@ struct GridwiseNormalizationSplitK1st
}
else
{
int kPerBlock = math::integer_divide_ceil(kRaw, kGridSize);
int kPerBlock = math::integer_divide_ceil(k, kGridSize);
return KThreadSliceSize * (kPerBlock / K_BlockTileSize);
}
}
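A standalone restatement of the rightmost-block arithmetic above, with illustrative sizes (kRaw, k, kGridSize, and the tile constants are made up; the tail handling in the elided lines is not reproduced):
#include <cassert>
int main()
{
    // illustrative split-K configuration: padded K = 1024, raw K = 1000, 2 blocks along K,
    // KThreadClusterSize = 32, KThreadSliceSize = 8 -> K_BlockTileSize = 256
    const int kRaw = 1000, k = 1024, kGridSize = 2;
    const int KThreadClusterSize = 32, KThreadSliceSize = 8;
    const int K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
    const int left_kPerBlock  = (k + kGridSize - 1) / kGridSize;         // 512 per leading block
    const int kRightmostBlock = kRaw - left_kPerBlock * (kGridSize - 1); // 488 left for last block
    const int kPerThread      = kRightmostBlock < K_BlockTileSize
                                    ? 0
                                    : KThreadSliceSize * (kRightmostBlock / K_BlockTileSize); // 8
    const int kPerBlockTail   = kRightmostBlock - kPerThread * KThreadClusterSize;            // 232
    assert(left_kPerBlock == 512 && kRightmostBlock == 488);
    assert(kPerThread == 8 && kPerBlockTail == 232);
    return 0;
}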
......@@ -195,8 +196,11 @@ struct GridwiseNormalizationSplitK1st
auto threadwise_welford = ThreadwiseWelford();
int kRaw = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0];
threadwise_welford.max_count_ =
GetKPerThread(kRaw, k_grid_size, block_k_cluster_id, thread_k_cluster_id);
threadwise_welford.max_count_ = GetKPerThread(x_grid_desc_m_k.GetLength(I1),
kRaw,
k_grid_size,
block_k_cluster_id,
thread_k_cluster_id);
static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
mean_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/amd_gemm_dpp.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/utility/inner_product_dpp8.hpp"
#include "ck/utility/math.hpp"
namespace ck {
/**
* Threadwise contraction using dot instructions with DPP8 modifier.
*
* Assumptions:
* 1. `AThreadDesc_TK0_TM0_TM1_TK1`, `BThreadDesc_TK0_TN0_TN1_TK1`, `CThreadDesc_TM0_TM1_TN0_TN1`
* are known at compile-time;
* 2. `AOriginIdx`, `BOriginIdx`, `COriginIdx` are known at compile-time;
* 3. `TM0` is equal to 1 and `TN0` is equal to 1;
* 4. When `ShareA` is set, `TM1` must be divisible by the lane-group size
* (`dpp8::lane_group_size`); when it is unset, `TN1` must be.
*/
template <typename FloatA,
typename FloatB,
typename FloatC,
typename AThreadDesc_TK0_TM0_TM1_TK1,
typename BThreadDesc_TK0_TN0_TN1_TK1,
typename CThreadDesc_TM0_TM1_TN0_TN1,
typename TKLengths,
typename TMLengths,
typename TNLengths,
bool ShareA,
typename enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
bool>::type = false>
struct ThreadwiseContractionDlDpp8_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr index_t TK0 = TKLengths{}[I0];
static constexpr index_t TK1 = TKLengths{}[I1];
static constexpr index_t TM0 = TMLengths{}[I0];
static constexpr index_t TM1 = TMLengths{}[I1];
static constexpr index_t TN0 = TNLengths{}[I0];
static constexpr index_t TN1 = TNLengths{}[I1];
static_assert(TM0 == 1 && TN0 == 1);
static_assert((ShareA && TM1 % dpp8::lane_group_size == 0) ||
(!ShareA && TN1 % dpp8::lane_group_size == 0));
static constexpr index_t shared_elems_per_lane =
ShareA ? TM1 / dpp8::lane_group_size : TN1 / dpp8::lane_group_size;
__device__ constexpr ThreadwiseContractionDlDpp8_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1()
{
static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(TKLengths::Size() == 2 && TMLengths::Size() == 2 && TNLengths::Size() == 2,
"wrong!");
}
template <typename ABuffer,
typename AOriginIdx,
typename BBuffer,
typename BOriginIdx,
typename CBuffer,
typename COriginIdx>
__device__ static void Run(const ABuffer& a_buf,
AOriginIdx,
const BBuffer& b_buf,
BOriginIdx,
CBuffer& c_buf,
COriginIdx)
{
static_assert(is_known_at_compile_time<remove_cvref_t<AOriginIdx>>::value &&
is_known_at_compile_time<remove_cvref_t<BOriginIdx>>::value &&
is_known_at_compile_time<remove_cvref_t<COriginIdx>>::value,
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");
static_assert(
is_same<remove_cvref_t<typename ABuffer::type>, remove_cvref_t<FloatA>>::value &&
is_same<remove_cvref_t<typename BBuffer::type>, remove_cvref_t<FloatB>>::value &&
is_same<remove_cvref_t<typename CBuffer::type>, remove_cvref_t<FloatC>>::value,
"wrong! inconsistent type");
constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
constexpr auto c_origin_idx = to_multi_index(COriginIdx{});
static_for<0, TK0, 1>{}([&](auto tk0) {
static_for<0, TM1, 1>{}([&](auto tm1) {
static_for<0, TN1, 1>{}([&](auto tn1) {
vector_type<FloatA, TK1> a_vec;
vector_type<FloatB, TK1> b_vec;
static_for<0, TK1, 1>{}([&](auto tk1) {
constexpr index_t local_tm1 = ShareA ? tm1 % shared_elems_per_lane : tm1;
constexpr index_t a_offset = AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset(
a_origin_idx + make_multi_index(tk0, 0, local_tm1, tk1));
constexpr index_t local_tn1 = ShareA ? tn1 : tn1 % shared_elems_per_lane;
constexpr index_t b_offset = BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset(
b_origin_idx + make_multi_index(tk0, 0, local_tn1, tk1));
a_vec.template AsType<FloatA>()(tk1) = a_buf[Number<a_offset>{}];
b_vec.template AsType<FloatB>()(tk1) = b_buf[Number<b_offset>{}];
});
using a_vector_t = typename vector_type<FloatA, TK1>::type;
using b_vector_t = typename vector_type<FloatB, TK1>::type;
constexpr index_t c_offset = CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset(
c_origin_idx + make_multi_index(0, tm1, 0, tn1));
constexpr int src_lane =
ShareA ? (tm1 / shared_elems_per_lane) % dpp8::lane_group_size
: (tn1 / shared_elems_per_lane) % dpp8::lane_group_size;
dpp8::inner_product_dpp<a_vector_t, b_vector_t, FloatC, src_lane, ShareA>(
a_vec.template AsType<a_vector_t>()[I0],
b_vec.template AsType<b_vector_t>()[I0],
c_buf(Number<c_offset>{}));
});
});
});
}
};
} // namespace ck
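A host-side sketch of the lane-group index math used in Run above; it assumes dpp8::lane_group_size == 8 (the DPP8 lane-group width) and an illustrative TM1, so treat the constants as examples rather than CK values.
#include <cassert>
int main()
{
    constexpr int lane_group_size       = 8;  // assumed value of dpp8::lane_group_size
    constexpr int TM1                   = 16; // illustrative; must be divisible by 8 when ShareA
    constexpr int shared_elems_per_lane = TM1 / lane_group_size;
    for(int tm1 = 0; tm1 < TM1; ++tm1)
    {
        // each lane stores only shared_elems_per_lane of the TM1 values of A; the value for
        // logical index tm1 is fetched via DPP8 from src_lane at local offset local_tm1
        const int local_tm1 = tm1 % shared_elems_per_lane;
        const int src_lane  = (tm1 / shared_elems_per_lane) % lane_group_size;
        assert(local_tm1 < shared_elems_per_lane && src_lane < lane_group_size);
    }
    return 0;
}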
......@@ -129,6 +129,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
static_assert(SliceLengths::At(SrcVectorDim) % SrcScalarPerVector == 0,
"SliceLengths[SrcVectorDim] must be divisible by SrcScalarPerVector");
constexpr auto src_dim_access_order = SrcDimAccessOrder{};
constexpr auto ordered_src_access_lengths =
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_space_filling_curve.hpp"
namespace ck {
// Do the following things to avoid "alloca" in LLVM-IR, which would cause scratch memory
// and sometimes useless instructions:
// 1. Don't save a reference to the tensor descriptor in the class; pass the tensor
// descriptor in as an argument instead
// 2. Don't construct a new tensor coordinate every time it is used; update and reuse the
// same tensor coordinate instead
// 3. Don't use a pointer to the VGPR buffer; use a vector instead
// Assume:
// 1. src_desc and dst_desc are not known at compile-time
// 2. SrcBuffer and DstBuffer are DynamicBuffer
// 3. src_slice_origin and dst_slice_origin are not known at compile-time
template <typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename ElementwiseOperation,
typename SliceLengths,
typename DimAccessOrder,
index_t VectorDim,
index_t ScalarPerVector,
bool SrcResetCoordinateAfterRun,
bool DstResetCoordinateAfterRun>
struct ThreadwiseTensorSliceTransfer_v6r1r2
{
static constexpr index_t nDim = SliceLengths::Size();
using Index = MultiIndex<nDim>;
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
static constexpr auto I0 = Number<0>{};
__device__ constexpr ThreadwiseTensorSliceTransfer_v6r1r2(
const SrcDesc& src_desc,
const Index& src_slice_origin,
const DstDesc& dst_desc,
const Index& dst_slice_origin,
const ElementwiseOperation& element_op)
: src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)),
dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)),
element_op_(element_op)
{
static_assert(SliceLengths::At(Number<VectorDim>{}) % ScalarPerVector == 0,
"wrong! cannot evenly divide");
}
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
{
src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx);
}
__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
{
dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
}
template <typename SrcBuffer, typename DstBuffer, InMemoryDataOperationEnum DstInMemOp>
__device__ void Run(const SrcDesc& src_desc,
const SrcBuffer& src_buf,
const DstDesc& dst_desc,
DstBuffer& dst_buf)
{
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<VectorDim, ScalarPerVector>{}, Number<nDim>{});
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
DimAccessOrder,
remove_cv_t<decltype(scalar_per_access)>>;
// loop over space-filling curve
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
static_for<0, num_access, 1>{}([&](auto idx_1d) {
using src_vector_type = vector_type_maker_t<SrcData, ScalarPerVector>;
using src_vector_t = typename src_vector_type::type;
using dst_vector_type = vector_type_maker_t<DstData, ScalarPerVector>;
using dst_vector_t = typename dst_vector_type::type;
const bool is_src_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
// copy data from src_buf into src_vector_container
auto src_vector_container = src_vector_type{
src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid)};
auto dst_vector_container = dst_vector_type{};
// apply pointwise operation
static_for<0, ScalarPerVector, 1>{}([&](auto i) {
SrcData v;
// apply element-wise operation
element_op_(v, src_vector_container.template AsType<SrcData>()[i]);
// apply type convert
dst_vector_container.template AsType<DstData>()(i) = type_convert<DstData>(v);
});
const bool is_dst_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);
// copy data from dst_vector into dst_buf
dst_buf.template Update<DstInMemOp, dst_vector_t>(
dst_coord_.GetOffset(),
is_dst_valid,
dst_vector_container.template AsType<dst_vector_t>()[I0]);
// move coordinate
if constexpr(idx_1d.value != num_access - 1)
{
constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
move_tensor_coordinate(
src_desc, src_coord_, make_tensor_coordinate_step(src_desc, forward_step));
move_tensor_coordinate(
dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
}
});
// move coordinate back to slice origin (or not)
if constexpr(SrcResetCoordinateAfterRun)
{
const auto src_reset_step =
make_tensor_coordinate_step(src_desc, GetCoordinateResetStep());
move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
}
if constexpr(DstResetCoordinateAfterRun)
{
const auto dst_reset_step =
make_tensor_coordinate_step(dst_desc, GetCoordinateResetStep());
move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
}
}
__device__ static constexpr auto GetCoordinateResetStep()
{
constexpr auto scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<VectorDim, ScalarPerVector>{}, Number<nDim>{});
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
DimAccessOrder,
remove_cv_t<decltype(scalar_per_access)>>;
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
if constexpr(num_access == 0)
{
return typename SpaceFillingCurve::Index{};
}
else
{
constexpr auto reset_step =
SpaceFillingCurve::GetStepBetween(Number<num_access - 1>{}, Number<0>{});
return reset_step;
}
}
// src_slice_origin_step_idx need to be known at compile-time, for performance reason
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc,
const Index& src_slice_origin_step_idx)
{
// if src coord was not reset by RunRead(), then need to adjust the step here
const auto adjusted_step_idx = SrcResetCoordinateAfterRun
? src_slice_origin_step_idx
: src_slice_origin_step_idx + GetCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
}
// dst_slice_origin_step_idx need to be known at compile-time, for performance reason
__device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
const Index& dst_slice_origin_step_idx)
{
// if dst coord was not reset by Run(), then need to adjust the step here
const auto adjusted_step_idx = DstResetCoordinateAfterRun
? dst_slice_origin_step_idx
: dst_slice_origin_step_idx + GetCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);
move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
}
private:
SrcCoord src_coord_;
DstCoord dst_coord_;
const ElementwiseOperation element_op_;
};
} // namespace ck
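A hedged device-side usage sketch of ThreadwiseTensorSliceTransfer_v6r1r2, assuming it is compiled inside namespace ck; the pass-through functor, the 256-element slice, and ScalarPerVector = 4 are illustrative choices, not values from this diff.
// hypothetical element-wise functor with the (dst, src) signature Run() expects
struct PassThroughOp
{
    __device__ void operator()(float& y, const float& x) const { y = x; }
};
template <typename SrcDesc, typename DstDesc, typename SrcBuf, typename DstBuf>
__device__ void copy_slice_example(const SrcDesc& src_desc,
                                   const SrcBuf& src_buf,
                                   const DstDesc& dst_desc,
                                   DstBuf& dst_buf)
{
    using Transfer = ThreadwiseTensorSliceTransfer_v6r1r2<float,         // SrcData
                                                          float,         // DstData
                                                          SrcDesc,
                                                          DstDesc,
                                                          PassThroughOp, // ElementwiseOperation
                                                          Sequence<256>, // SliceLengths
                                                          Sequence<0>,   // DimAccessOrder
                                                          0,             // VectorDim
                                                          4,             // ScalarPerVector
                                                          true,          // SrcResetCoordinateAfterRun
                                                          true>;         // DstResetCoordinateAfterRun
    auto transfer = Transfer{
        src_desc, make_multi_index(0), dst_desc, make_multi_index(0), PassThroughOp{}};
    // one Run() walks the whole slice along the space-filling curve; MoveSrcSliceWindow()
    // and MoveDstSliceWindow() would then advance both windows to the next slice
    transfer.template Run<SrcBuf, DstBuf, InMemoryDataOperationEnum::Set>(
        src_desc, src_buf, dst_desc, dst_buf);
}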
......@@ -29,7 +29,9 @@ enum struct MfmaInstr
mfma_i32_16x16x16i8,
mfma_i32_32x32x16i8,
mfma_i32_16x16x32i8,
mfma_f64_16x16x4f64
mfma_f64_16x16x4f64,
mfma_f32_32x32x16f8f8,
mfma_f32_16x16x32f8f8
};
template <MfmaInstr instr>
......@@ -454,6 +456,50 @@ struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64>
}
};
template <>
struct mfma_type<MfmaInstr::mfma_f32_32x32x16f8f8>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 4;
static constexpr index_t num_regs_per_blk = 16;
static constexpr index_t num_threads_per_blk = 32;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 2;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 32;
static constexpr index_t n_per_blk = 32;
static constexpr index_t k_per_blk = 8;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_32x32x16f8f8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
template <>
struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8f8>
{
static constexpr index_t group_size = 4;
static constexpr index_t num_groups_per_blk = 1;
static constexpr index_t num_regs_per_blk = 4;
static constexpr index_t num_threads_per_blk = 16;
static constexpr index_t wave_size = 64;
static constexpr index_t num_input_blks = 4;
static constexpr index_t num_output_blks = 1;
static constexpr index_t m_per_blk = 16;
static constexpr index_t n_per_blk = 16;
static constexpr index_t k_per_blk = 8;
static constexpr bool is_k_reduction = true;
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
intrin_mfma_f32_16x16x32f8f8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
}
};
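A small sanity-check sketch for the two new f8 entries (the namespace and the checks below are illustrative, not part of the diff): with is_k_reduction, a wave is split into num_input_blks lane groups along K, so the per-instruction K is k_per_blk * num_input_blks (16 and 32, matching the mnemonics).
namespace f8_mfma_sanity {
using mfma32x32 = mfma_type<MfmaInstr::mfma_f32_32x32x16f8f8>;
using mfma16x16 = mfma_type<MfmaInstr::mfma_f32_16x16x32f8f8>;
static_assert(mfma32x32::num_threads_per_blk * mfma32x32::num_input_blks == mfma32x32::wave_size &&
              mfma16x16::num_threads_per_blk * mfma16x16::num_input_blks == mfma16x16::wave_size);
static_assert(mfma32x32::k_per_blk * mfma32x32::num_input_blks == 16 &&
              mfma16x16::k_per_blk * mfma16x16::num_input_blks == 32);
static_assert(mfma32x32::group_size * mfma32x32::num_groups_per_blk == mfma32x32::num_regs_per_blk &&
              mfma16x16::group_size * mfma16x16::num_groups_per_blk == mfma16x16::num_regs_per_blk);
} // namespace f8_mfma_sanity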
template <typename base_type, index_t MPerXdlops, index_t NPerXdlops>
struct MfmaSelector
{
......@@ -594,6 +640,18 @@ struct MfmaSelector
}
#endif
template <>
static constexpr auto GetMfma<f8_t, 32, 32>()
{
return MfmaInstr::mfma_f32_32x32x16f8f8;
}
template <>
static constexpr auto GetMfma<f8_t, 16, 16>()
{
return MfmaInstr::mfma_f32_16x16x32f8f8;
}
static constexpr auto selected_mfma = mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops>()>{};
__host__ __device__ constexpr MfmaSelector()
......@@ -794,7 +852,7 @@ struct XdlopsGemm
{
static_assert(is_same<base_type, double>::value || is_same<base_type, float>::value ||
is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value ||
is_same<base_type, int8_t>::value,
is_same<base_type, int8_t>::value || is_same<base_type, f8_t>::value,
"base base_type must be double, float, half, bfloat16, and int8_t!");
static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
......
......@@ -13,6 +13,150 @@
namespace ck {
namespace tensor_operation {
namespace {
template <
index_t NDimSpatial,
typename ALayout,
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization ConvBwdDataSpecialization>
constexpr auto make_out_grid_desc(const index_t N,
const index_t Do,
const index_t Ho,
const index_t Wo,
const index_t K,
const std::array<index_t, NDimSpatial + 3>& out_g_n_k_wos_strides)
{
const auto KStride = Number<1>{};
if constexpr(is_same_v<ALayout, tensor_layout::convolution::NHWGK>)
{
const index_t NStride = out_g_n_k_wos_strides[1];
const index_t HiStride = out_g_n_k_wos_strides[3];
const index_t WiStride = out_g_n_k_wos_strides[4];
if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0)
{
return make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, K),
make_tuple(WiStride, KStride));
}
else
{
return make_naive_tensor_descriptor(make_tuple(N, Ho, Wo, K),
make_tuple(NStride, HiStride, WiStride, KStride));
}
}
else if constexpr(is_same_v<ALayout, tensor_layout::convolution::NDHWGK>)
{
const index_t NStride = out_g_n_k_wos_strides[1];
const index_t DoStride = out_g_n_k_wos_strides[3];
const index_t HoStride = out_g_n_k_wos_strides[4];
const index_t WoStride = out_g_n_k_wos_strides[5];
if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0)
{
return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, K),
make_tuple(WoStride, KStride));
}
else
{
return make_naive_tensor_descriptor(
make_tuple(N, Do, Ho, Wo, K),
make_tuple(NStride, DoStride, HoStride, WoStride, KStride));
}
}
else if constexpr(is_same_v<ALayout, tensor_layout::convolution::GNHWK>)
{
// assume packed
if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0)
{
return make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
}
else
{
return make_naive_tensor_descriptor_packed(make_tuple(N, Ho, Wo, K));
}
}
else if constexpr(is_same_v<ALayout, tensor_layout::convolution::GNDHWK>)
{
// assume packed
if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0)
{
return make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K));
}
else
{
return make_naive_tensor_descriptor_packed(make_tuple(N, Do, Ho, Wo, K));
}
}
else
{
throw std::runtime_error("wrong! unsupported layout: " + ALayout::name());
}
}
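As a concrete (made-up) example of what the NHWGK branch reads: for a packed NHWGK tensor with N = 2, Ho = 4, Wo = 4, G = 1, K = 8, the g_n_k_wos-ordered stride array and the resulting (N, Ho, Wo, K) offsets are sketched below.
#include <array>
#include <cassert>
int main()
{
    // packed NHWGK, stride array ordered as (G, N, K, Ho, Wo) like out_g_n_k_wos_strides
    const std::array<int, 5> out_g_n_k_wos_strides = {8, 128, 1, 32, 8};
    const int NStride  = out_g_n_k_wos_strides[1]; // 128 (NStride in the helper above)
    const int HoStride = out_g_n_k_wos_strides[3]; // 32  (named HiStride in the helper)
    const int WoStride = out_g_n_k_wos_strides[4]; // 8   (named WiStride in the helper)
    const int KStride  = 1;
    // offset of element (n, ho, wo, k) in the non-1x1 descriptor built above
    auto offset = [&](int n, int ho, int wo, int k) {
        return n * NStride + ho * HoStride + wo * WoStride + k * KStride;
    };
    assert(offset(1, 2, 3, 4) == 128 + 2 * 32 + 3 * 8 + 4);
    return 0;
}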
template <typename BLayout>
constexpr auto make_wei_grid_desc(
const index_t K, const index_t Z, const index_t Y, const index_t X, const index_t C)
{
if constexpr(is_same_v<BLayout, tensor_layout::convolution::GKYXC>)
{
return make_naive_tensor_descriptor_packed(make_tuple(K, Y, X, C));
}
else if constexpr(is_same_v<BLayout, tensor_layout::convolution::GKZYXC>)
{
return make_naive_tensor_descriptor_packed(make_tuple(K, Z, Y, X, C));
}
else
{
throw std::runtime_error("wrong! unsupported layout: " + BLayout::name());
}
}
template <index_t NDimSpatial, typename CLayout>
constexpr auto make_in_grid_desc(const index_t N,
const index_t Di,
const index_t Hi,
const index_t Wi,
const index_t C,
const std::array<index_t, NDimSpatial + 3>& in_g_n_c_wis_strides)
{
if constexpr(is_same_v<CLayout, tensor_layout::convolution::GNHWC> ||
is_same_v<CLayout, tensor_layout::convolution::NHWGC> ||
is_same_v<CLayout, tensor_layout::convolution::G_NHW_C>)
{
return make_naive_tensor_descriptor(make_tuple(N, Hi, Wi, C),
make_tuple(in_g_n_c_wis_strides[1],
in_g_n_c_wis_strides[3],
in_g_n_c_wis_strides[4],
in_g_n_c_wis_strides[2]));
}
else if constexpr(is_same_v<CLayout, tensor_layout::convolution::GNDHWC> ||
is_same_v<CLayout, tensor_layout::convolution::NDHWGC>)
{
return make_naive_tensor_descriptor(make_tuple(N, Di, Hi, Wi, C),
make_tuple(in_g_n_c_wis_strides[1],
in_g_n_c_wis_strides[3],
in_g_n_c_wis_strides[4],
in_g_n_c_wis_strides[5],
in_g_n_c_wis_strides[2]));
}
else
{
throw std::runtime_error("wrong! unsupported layout: " + CLayout::name());
}
}
} // namespace
template <
index_t NDimSpatial,
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization ConvBwdDataSpecialization,
......@@ -27,13 +171,30 @@ struct TransformConvBwdDataToGemm_v1
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto NonSpatialDimsNum = Number<3>{};
static constexpr auto DIdx = Number<NonSpatialDimsNum>{};
static constexpr auto HIdx =
NDimSpatial == 2 ? Number<NonSpatialDimsNum>{} : Number<NonSpatialDimsNum + 1>{};
static constexpr auto WIdx =
NDimSpatial == 2 ? Number<NonSpatialDimsNum + 1>{} : Number<NonSpatialDimsNum + 2>{};
static constexpr auto ZIdx = Number<NonSpatialDimsNum>{};
static constexpr auto YIdx =
NDimSpatial == 2 ? Number<NonSpatialDimsNum>{} : Number<NonSpatialDimsNum + 1>{};
static constexpr auto XIdx =
NDimSpatial == 2 ? Number<NonSpatialDimsNum + 1>{} : Number<NonSpatialDimsNum + 2>{};
template <typename ALayout,
typename std::enable_if<NDimSpatial == 2 &&
is_same_v<ALayout, tensor_layout::convolution::GNHWK>,
typename std::enable_if<(NDimSpatial == 2 || NDimSpatial == 3) &&
(is_same_v<ALayout, tensor_layout::convolution::GNHWK> ||
is_same_v<ALayout, tensor_layout::convolution::GNDHWK> ||
is_same_v<ALayout, tensor_layout::convolution::NHWGK> ||
is_same_v<ALayout, tensor_layout::convolution::NDHWGK>),
bool>::type = false>
static auto MakeADescriptor_AK0_M_AK1(
const std::array<index_t, NDimSpatial + 3>& out_g_n_k_wos_lengths,
const std::array<index_t, NDimSpatial + 3>& /* out_g_n_k_wos_strides */,
const std::array<index_t, NDimSpatial + 3>& out_g_n_k_wos_strides,
const std::array<index_t, NDimSpatial + 3>& wei_g_k_c_xs_lengths,
const std::array<index_t, NDimSpatial + 3>& /* wei_g_k_c_xs_strides */,
const std::array<index_t, NDimSpatial + 3>& in_g_n_c_wis_lengths,
......@@ -44,44 +205,52 @@ struct TransformConvBwdDataToGemm_v1
const std::array<index_t, NDimSpatial>& /* input_right_pads */,
const std::array<index_t, NDimSpatial>& tildes)
{
index_t i_ytilde = tildes[0];
index_t i_xtilde = tildes[1];
index_t i_ztilde = tildes[ZIdx - NonSpatialDimsNum];
index_t i_ytilde = tildes[YIdx - NonSpatialDimsNum];
index_t i_xtilde = tildes[XIdx - NonSpatialDimsNum];
const index_t N = in_g_n_c_wis_lengths[1];
const index_t K = wei_g_k_c_xs_lengths[1];
const index_t Hi = in_g_n_c_wis_lengths[3];
const index_t Wi = in_g_n_c_wis_lengths[4];
const index_t Di = NDimSpatial == 3 ? in_g_n_c_wis_lengths[DIdx] : 1;
const index_t Hi = in_g_n_c_wis_lengths[HIdx];
const index_t Wi = in_g_n_c_wis_lengths[WIdx];
const index_t Ho = out_g_n_k_wos_lengths[3];
const index_t Wo = out_g_n_k_wos_lengths[4];
const index_t Do = NDimSpatial == 3 ? out_g_n_k_wos_lengths[DIdx] : 1;
const index_t Ho = out_g_n_k_wos_lengths[HIdx];
const index_t Wo = out_g_n_k_wos_lengths[WIdx];
const index_t Y = wei_g_k_c_xs_lengths[3];
const index_t X = wei_g_k_c_xs_lengths[4];
const index_t Z = NDimSpatial == 3 ? wei_g_k_c_xs_lengths[ZIdx] : 1;
const index_t Y = wei_g_k_c_xs_lengths[YIdx];
const index_t X = wei_g_k_c_xs_lengths[XIdx];
const index_t InLeftPadH = input_left_pads[0];
const index_t InLeftPadW = input_left_pads[1];
const index_t ConvStrideH = conv_filter_strides[0];
const index_t ConvStrideW = conv_filter_strides[1];
const index_t InLeftPadD = input_left_pads[DIdx - NonSpatialDimsNum];
const index_t InLeftPadH = input_left_pads[HIdx - NonSpatialDimsNum];
const index_t InLeftPadW = input_left_pads[WIdx - NonSpatialDimsNum];
const index_t ConvDilationH = conv_filter_dilations[0];
const index_t ConvDilationW = conv_filter_dilations[1];
const index_t ConvStrideD = conv_filter_strides[DIdx - NonSpatialDimsNum];
const index_t ConvStrideH = conv_filter_strides[HIdx - NonSpatialDimsNum];
const index_t ConvStrideW = conv_filter_strides[WIdx - NonSpatialDimsNum];
const index_t AK0 = K / AK1;
const index_t ConvDilationD = conv_filter_dilations[DIdx - NonSpatialDimsNum];
const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum];
const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum];
// assume packed
const auto out_n_ho_wo_k_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(N, Ho, Wo, K));
// n_do_ho_wo_k for 3d or n_ho_wo_k for 2d
const auto out_grid_desc =
make_out_grid_desc<NDimSpatial, ALayout, ConvBwdDataSpecialization>(
N, Do, Ho, Wo, K, out_g_n_k_wos_strides);
if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0)
{
const index_t AK0 = math::integer_divide_ceil(K, AK1);
// A: output tensor
const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor(
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)),
make_tuple(make_pass_through_transform(N * Ho * Wo),
out_grid_desc,
make_tuple(make_pass_through_transform(N * Do * Ho * Wo),
make_unmerge_transform(make_tuple(AK0, AK1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
......@@ -96,41 +265,57 @@ struct TransformConvBwdDataToGemm_v1
}
else
{
const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD);
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto ZTilde = ConvStrideD / GcdStrideDilationD;
const auto YTilde = ConvStrideH / GcdStrideDilationH;
const auto XTilde = ConvStrideW / GcdStrideDilationW;
const auto ZDot = math::integer_divide_ceil(Z, ZTilde);
const auto YDot = math::integer_divide_ceil(Y, YTilde);
const auto XDot = math::integer_divide_ceil(X, XTilde);
const auto DTilde =
Do + math::integer_divide_ceil(ConvDilationD * (Z - I1), ConvStrideD);
const auto HTilde =
Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
const auto WTilde =
Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
// only work on DTilde, HTilde and WTilde that contribute to the non-padding area of the input tensor
const auto IDTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadD - ConvDilationD * (ZTilde - I1)), ConvStrideD);
const auto IHTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
const auto IWTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
const auto IDTildeSliceEnd = math::min(
DTilde, math::integer_divide_ceil(InLeftPadD + Di - I1, ConvStrideD) + I1);
const auto IHTildeSliceEnd = math::min(
HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
const auto IWTildeSliceEnd = math::min(
WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin;
const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
// GemmK is different for each GEMM
const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde);
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
const index_t AK0 =
math::integer_divide_ceil(ZDotSlice * YDotSlice * XDotSlice * K, AK1);
if constexpr(NDimSpatial == 2)
{
// A: output tensor
const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
out_n_ho_wo_k_grid_desc,
out_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Ho, I0, I0),
make_pad_transform(Wo, I0, I0),
......@@ -150,7 +335,7 @@ struct TransformConvBwdDataToGemm_v1
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_ak0_ak1_grid_desc =
const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc =
transform_tensor_descriptor(
out_n_ydot_htilde_xdot_wtilde_k_grid_desc,
make_tuple(make_pass_through_transform(N),
......@@ -158,7 +343,7 @@ struct TransformConvBwdDataToGemm_v1
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_slice_transform(XDot, I0, XDotSlice),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_unmerge_transform(make_tuple(AK0, AK1))),
make_pass_through_transform(K)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
......@@ -170,29 +355,135 @@ struct TransformConvBwdDataToGemm_v1
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5, 6>{}));
Sequence<5>{}));
const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor(
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_ak0_ak1_grid_desc,
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, AK0)),
make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
make_pass_through_transform(AK1)),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}, Sequence<6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto out_gemmk_gemmmraw_grid_desc = transform_tensor_descriptor(
out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc,
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K)),
make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice))),
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmak0_gemmm_gemmak1_grid_desc =
const auto out_gemmk_gemmm_padded_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
out_gemmak0_gemmmraw_gemmak1_grid_desc,
make_tuple(AK0, GemmMPerBlock, AK1),
Sequence<false, DoPadGemmM, false>{});
out_gemmk_gemmmraw_grid_desc,
make_tuple(AK1, GemmMPerBlock),
Sequence<true, DoPadGemmM>{});
const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor(
out_gemmk_gemmm_padded_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(
out_gemmk_gemmm_padded_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return out_gemmak0_gemmm_gemmak1_grid_desc;
}
else if constexpr(NDimSpatial == 3)
{
// A: output tensor
const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
out_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Do, I0, I0),
make_pad_transform(Ho, I0, I0),
make_pad_transform(Wo, I0, I0),
make_pass_through_transform(K)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc =
transform_tensor_descriptor(
out_n_hop_wop_k_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(
make_tuple(ZDot, DTilde),
make_tuple(-ConvDilationD / GcdStrideDilationD, I1)),
make_embed_transform(
make_tuple(YDot, HTilde),
make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
make_embed_transform(
make_tuple(XDot, WTilde),
make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
make_pass_through_transform(K)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{}),
make_tuple(Sequence<0>{},
Sequence<1, 2>{},
Sequence<3, 4>{},
Sequence<5, 6>{},
Sequence<7>{}));
const auto
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc =
transform_tensor_descriptor(
out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc,
make_tuple(make_pass_through_transform(N),
make_slice_transform(ZDot, I0, ZDotSlice),
make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice),
make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_slice_transform(XDot, I0, XDotSlice),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_pass_through_transform(K)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6>{},
Sequence<7>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6>{},
Sequence<7>{}));
const auto out_gemmk_gemmmraw_grid_desc = transform_tensor_descriptor(
out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc,
make_tuple(
make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K)),
make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice))),
make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmk_gemmm_padded_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
out_gemmk_gemmmraw_grid_desc,
make_tuple(AK1, GemmMPerBlock),
Sequence<true, DoPadGemmM>{});
const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor(
out_gemmk_gemmm_padded_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
make_pass_through_transform(
out_gemmk_gemmm_padded_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return out_gemmak0_gemmm_gemmak1_grid_desc;
}
else
{
throw std::runtime_error("wrong! only implemented for 2D and 3D now");
}
}
}
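The "tilde" decomposition used above is easier to follow with numbers; the following standalone sketch uses an illustrative 3x3 filter, stride 2, dilation 1 and 4x4 output (values are made up, and only the H direction is shown):
#include <cassert>
#include <numeric> // std::gcd
int main()
{
    const int Y = 3, ConvStrideH = 2, ConvDilationH = 1, Ho = 4;
    const int GcdStrideDilationH = std::gcd(ConvStrideH, ConvDilationH); // 1
    // backward data runs YTilde * XTilde sub-GEMMs; i_ytilde (from the tildes argument)
    // selects which residue class of filter rows this sub-GEMM reduces over
    const int YTilde = ConvStrideH / GcdStrideDilationH;                               // 2
    const int YDot   = (Y + YTilde - 1) / YTilde;                                      // 2
    const int HTilde = Ho + (ConvDilationH * (Y - 1) + ConvStrideH - 1) / ConvStrideH; // 5
    const int YDotSlice_i0 = (Y - 0 + YTilde - 1) / YTilde; // i_ytilde = 0 -> 2 filter taps
    const int YDotSlice_i1 = (Y - 1 + YTilde - 1) / YTilde; // i_ytilde = 1 -> 1 filter tap
    assert(YTilde == 2 && YDot == 2 && HTilde == 5);
    assert(YDotSlice_i0 == 2 && YDotSlice_i1 == 1);
    return 0;
}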
template <typename BLayout,
typename std::enable_if<NDimSpatial == 2 &&
is_same_v<BLayout, tensor_layout::convolution::GKYXC>,
typename std::enable_if<(NDimSpatial == 2 || NDimSpatial == 3) &&
(is_same_v<BLayout, tensor_layout::convolution::GKYXC> ||
is_same_v<BLayout, tensor_layout::convolution::GKZYXC>),
bool>::type = false>
static auto MakeBDescriptor_BK0_N_BK1(
const std::array<index_t, NDimSpatial + 3>& out_g_n_k_wos_lengths,
......@@ -207,35 +498,40 @@ struct TransformConvBwdDataToGemm_v1
const std::array<index_t, NDimSpatial>& /* input_right_pads */,
const std::array<index_t, NDimSpatial>& tildes)
{
index_t i_ytilde = tildes[0];
index_t i_xtilde = tildes[1];
index_t i_ztilde = tildes[ZIdx - NonSpatialDimsNum];
index_t i_ytilde = tildes[YIdx - NonSpatialDimsNum];
index_t i_xtilde = tildes[XIdx - NonSpatialDimsNum];
const index_t N = in_g_n_c_wis_lengths[1];
const index_t K = wei_g_k_c_xs_lengths[1];
const index_t C = wei_g_k_c_xs_lengths[2];
const index_t Ho = out_g_n_k_wos_lengths[3];
const index_t Wo = out_g_n_k_wos_lengths[4];
const index_t Y = wei_g_k_c_xs_lengths[3];
const index_t X = wei_g_k_c_xs_lengths[4];
const index_t Do = NDimSpatial == 3 ? out_g_n_k_wos_lengths[DIdx] : 1;
const index_t Ho = out_g_n_k_wos_lengths[HIdx];
const index_t Wo = out_g_n_k_wos_lengths[WIdx];
const index_t ConvStrideH = conv_filter_strides[0];
const index_t ConvStrideW = conv_filter_strides[1];
const index_t Z = NDimSpatial == 3 ? wei_g_k_c_xs_lengths[ZIdx] : 1;
const index_t Y = wei_g_k_c_xs_lengths[YIdx];
const index_t X = wei_g_k_c_xs_lengths[XIdx];
const index_t ConvDilationH = conv_filter_dilations[0];
const index_t ConvDilationW = conv_filter_dilations[1];
const index_t ConvStrideD = conv_filter_strides[DIdx - NonSpatialDimsNum];
const index_t ConvStrideH = conv_filter_strides[HIdx - NonSpatialDimsNum];
const index_t ConvStrideW = conv_filter_strides[WIdx - NonSpatialDimsNum];
const index_t BK0 = K / BK1;
const index_t ConvDilationD = conv_filter_dilations[DIdx - NonSpatialDimsNum];
const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum];
const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum];
// assume packed
const auto wei_k_y_x_c_grid_desc =
make_naive_tensor_descriptor_packed(make_tuple(K, Y, X, C));
// k_y_x_c for 2d or k_z_y_x_c for 3d
const auto wei_grid_desc = make_wei_grid_desc<BLayout>(K, Z, Y, X, C);
if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0)
{
const index_t BK0 = math::integer_divide_ceil(K, BK1);
// B: weight tensor
const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc =
transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K, C)),
......@@ -243,7 +539,7 @@ struct TransformConvBwdDataToGemm_v1
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, C), make_tuple(I0, I1));
make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, C), make_tuple(I0, I1));
const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
......@@ -255,23 +551,33 @@ struct TransformConvBwdDataToGemm_v1
}
else
{
const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD);
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto ZTilde = ConvStrideD / GcdStrideDilationD;
const auto YTilde = ConvStrideH / GcdStrideDilationH;
const auto XTilde = ConvStrideW / GcdStrideDilationW;
const auto ZDot = math::integer_divide_ceil(Z, ZTilde);
const auto YDot = math::integer_divide_ceil(Y, YTilde);
const auto XDot = math::integer_divide_ceil(X, XTilde);
// GemmK is different for each GEMM
const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde);
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
const index_t BK0 =
math::integer_divide_ceil(ZDotSlice * YDotSlice * XDotSlice * K, BK1);
// B weight tensor
if constexpr(NDimSpatial == 2)
{
const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor(
wei_k_y_x_c_grid_desc,
make_tuple(make_pass_through_transform(K),
wei_grid_desc,
make_tuple(
make_pass_through_transform(K),
make_embed_transform(make_tuple(YDot, YTilde),
make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
make_embed_transform(make_tuple(XDot, XTilde),
......@@ -280,9 +586,9 @@ struct TransformConvBwdDataToGemm_v1
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto wei_bk0_bk1_ydotslice_xdotslice_c_grid_desc =
transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
const auto wei_k_ydotslice_xdotslice_c_grid_desc = transform_tensor_descriptor(
wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc,
make_tuple(make_pass_through_transform(K),
make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(XDot, I0, XDotSlice),
make_freeze_transform(i_ytilde),
......@@ -294,36 +600,125 @@ struct TransformConvBwdDataToGemm_v1
Sequence<2>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0, 1>{},
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<>{},
Sequence<>{},
Sequence<3>{}));
const auto wei_gemmk_gemmnraw_grid_desc = transform_tensor_descriptor(
wei_k_ydotslice_xdotslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K)),
make_pass_through_transform(C)),
make_tuple(Sequence<1, 2, 0>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto wei_gemmk_gemmn_padded_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
wei_gemmk_gemmnraw_grid_desc,
make_tuple(BK1, GemmNPerBlock),
Sequence<true, DoPadGemmN>{});
const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc = transform_tensor_descriptor(
wei_gemmk_gemmn_padded_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(
wei_gemmk_gemmn_padded_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return wei_gemmbk0_gemmn_gemmbk1_grid_desc;
}
else if constexpr(NDimSpatial == 3)
{
const auto wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc =
transform_tensor_descriptor(
wei_grid_desc,
make_tuple(
make_pass_through_transform(K),
make_embed_transform(make_tuple(ZDot, ZTilde),
make_tuple(ConvStrideD / GcdStrideDilationD, I1)),
make_embed_transform(make_tuple(YDot, YTilde),
make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
make_embed_transform(make_tuple(XDot, XTilde),
make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{}),
make_tuple(Sequence<0>{},
Sequence<1, 2>{},
Sequence<3, 4>{},
Sequence<5, 6>{},
Sequence<7>{}));
const auto wei_gemmk_zdotslice_ydotslice_xdotslice_c_grid_desc =
transform_tensor_descriptor(
wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc,
make_tuple(make_pass_through_transform(K),
make_slice_transform(ZDot, I0, ZDotSlice),
make_slice_transform(YDot, I0, YDotSlice),
make_slice_transform(XDot, I0, XDotSlice),
make_freeze_transform(i_ztilde),
make_freeze_transform(i_ytilde),
make_freeze_transform(i_xtilde),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<3>{},
Sequence<5>{},
Sequence<2>{},
Sequence<4>{},
Sequence<6>{},
Sequence<7>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<>{},
Sequence<>{},
Sequence<>{},
Sequence<4>{}));
const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc = transform_tensor_descriptor(
wei_bk0_bk1_ydotslice_xdotslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, BK0)),
make_pass_through_transform(C),
make_pass_through_transform(BK1)),
make_tuple(Sequence<2, 3, 0>{}, Sequence<4>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto wei_gemmk_gemmnraw_grid_desc = transform_tensor_descriptor(
wei_gemmk_zdotslice_ydotslice_xdotslice_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K)),
make_pass_through_transform(C)),
make_tuple(Sequence<1, 2, 3, 0>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc =
const auto wei_gemmk_gemm_padded_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
wei_gemmbk0_gemmnraw_gemmbk1_grid_desc,
wei_gemmk_gemmnraw_grid_desc,
make_tuple(BK1, GemmNPerBlock),
Sequence<true, DoPadGemmN>{});
const auto wei_gemmbk0_gemm_gemmbk1_grid_desc = transform_tensor_descriptor(
wei_gemmk_gemm_padded_grid_desc,
make_tuple(
wei_gemmbk0_gemmnraw_gemmbk1_grid_desc.GetLength(I0), GemmNPerBlock, BK1),
Sequence<false, DoPadGemmN, false>{});
make_unmerge_transform(make_tuple(BK0, BK1)),
make_pass_through_transform(wei_gemmk_gemm_padded_grid_desc.GetLength(I1))),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return wei_gemmbk0_gemmn_gemmbk1_grid_desc;
return wei_gemmbk0_gemm_gemmbk1_grid_desc;
}
else
{
throw std::runtime_error("wrong! only implemented for 2D and 3D now");
}
}
}
template <typename CLayout,
typename std::enable_if<NDimSpatial == 2 &&
typename std::enable_if<(NDimSpatial == 2 || NDimSpatial == 3) &&
(is_same_v<CLayout, tensor_layout::convolution::GNHWC> ||
is_same_v<CLayout, tensor_layout::convolution::GNDHWC> ||
is_same_v<CLayout, tensor_layout::convolution::NHWGC> ||
is_same_v<CLayout, tensor_layout::convolution::NDHWGC> ||
is_same_v<CLayout, tensor_layout::convolution::G_NHW_C>),
bool>::type = false>
static auto
......@@ -339,49 +734,57 @@ struct TransformConvBwdDataToGemm_v1
const std::array<index_t, NDimSpatial>& input_right_pads,
const std::array<index_t, NDimSpatial>& tildes)
{
index_t i_ytilde = tildes[0];
index_t i_xtilde = tildes[1];
index_t i_ztilde = tildes[ZIdx - NonSpatialDimsNum];
index_t i_ytilde = tildes[YIdx - NonSpatialDimsNum];
index_t i_xtilde = tildes[XIdx - NonSpatialDimsNum];
const index_t N = in_g_n_c_wis_lengths[1];
const index_t C = wei_g_k_c_xs_lengths[2];
const index_t Hi = in_g_n_c_wis_lengths[3];
const index_t Wi = in_g_n_c_wis_lengths[4];
const index_t Di = NDimSpatial == 3 ? in_g_n_c_wis_lengths[DIdx] : 1;
const index_t Hi = in_g_n_c_wis_lengths[HIdx];
const index_t Wi = in_g_n_c_wis_lengths[WIdx];
const index_t Ho = out_g_n_k_wos_lengths[3];
const index_t Wo = out_g_n_k_wos_lengths[4];
const index_t Do = NDimSpatial == 3 ? out_g_n_k_wos_lengths[DIdx] : 1;
const index_t Ho = out_g_n_k_wos_lengths[HIdx];
const index_t Wo = out_g_n_k_wos_lengths[WIdx];
const index_t Y = wei_g_k_c_xs_lengths[3];
const index_t X = wei_g_k_c_xs_lengths[4];
const index_t Z = NDimSpatial == 3 ? wei_g_k_c_xs_lengths[ZIdx] : 1;
const index_t Y = wei_g_k_c_xs_lengths[YIdx];
const index_t X = wei_g_k_c_xs_lengths[XIdx];
const index_t InLeftPadH = input_left_pads[0];
const index_t InLeftPadW = input_left_pads[1];
const index_t InLeftPadD = input_left_pads[DIdx - NonSpatialDimsNum];
const index_t InLeftPadH = input_left_pads[HIdx - NonSpatialDimsNum];
const index_t InLeftPadW = input_left_pads[WIdx - NonSpatialDimsNum];
const index_t InRightPadH = input_right_pads[0];
const index_t InRightPadW = input_right_pads[1];
const index_t InRightPadD = input_right_pads[DIdx - NonSpatialDimsNum];
const index_t InRightPadH = input_right_pads[HIdx - NonSpatialDimsNum];
const index_t InRightPadW = input_right_pads[WIdx - NonSpatialDimsNum];
const index_t ConvStrideH = conv_filter_strides[0];
const index_t ConvStrideW = conv_filter_strides[1];
const index_t ConvStrideD = conv_filter_strides[DIdx - NonSpatialDimsNum];
const index_t ConvStrideH = conv_filter_strides[HIdx - NonSpatialDimsNum];
const index_t ConvStrideW = conv_filter_strides[WIdx - NonSpatialDimsNum];
const index_t ConvDilationH = conv_filter_dilations[0];
const index_t ConvDilationW = conv_filter_dilations[1];
const index_t ConvDilationD = conv_filter_dilations[DIdx - NonSpatialDimsNum];
const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum];
const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum];
// assume strided
const auto in_n_hi_wi_c_grid_desc =
make_naive_tensor_descriptor(make_tuple(N, Hi, Wi, C),
make_tuple(in_g_n_c_wis_strides[1],
in_g_n_c_wis_strides[3],
in_g_n_c_wis_strides[4],
in_g_n_c_wis_strides[2]));
// n_hi_wi_c for 2d or n_di_hi_wi_c for 3d
const auto in_grid_desc =
make_in_grid_desc<NDimSpatial, CLayout>(N, Di, Hi, Wi, C, in_g_n_c_wis_strides);
if constexpr(ConvBwdDataSpecialization ==
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
Filter1x1Stride1Pad0)
{
// C: input tensor
if constexpr(NDimSpatial == 2)
{
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
make_tuple(make_pass_through_transform(N),
in_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)),
make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
make_pass_through_transform(C)),
......@@ -397,7 +800,51 @@ struct TransformConvBwdDataToGemm_v1
make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}),
make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{}));
const auto in_gemmm_gemmn_grid_desc = ck::tensor_operation::device::PadTensorDescriptor(
const auto in_gemmm_gemmn_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
in_gemmmraw_gemmnraw_grid_desc,
make_tuple(GemmMPerBlock, GemmNPerBlock),
Sequence<DoPadGemmM, DoPadGemmN>{});
return in_gemmm_gemmn_grid_desc;
}
else if constexpr(NDimSpatial == 3)
{
// C: input tensor
const auto in_n_x_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(I1, Do), make_tuple(I1, ConvStrideD)),
make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)),
make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{},
Sequence<1, 2>{},
Sequence<3, 4>{},
Sequence<5, 6>{},
Sequence<7>{}));
const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor(
in_n_x_do_y_ho_x_wo_c_grid_desc,
make_tuple(make_freeze_transform(I0),
make_freeze_transform(I0),
make_freeze_transform(I0),
make_merge_transform(make_tuple(N, Do, Ho, Wo)),
make_pass_through_transform(C)),
make_tuple(Sequence<1>{},
Sequence<3>{},
Sequence<5>{},
Sequence<0, 2, 4, 6>{},
Sequence<7>{}),
make_tuple(
Sequence<>{}, Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{}));
const auto in_gemmm_gemmn_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
in_gemmmraw_gemmnraw_grid_desc,
make_tuple(GemmMPerBlock, GemmNPerBlock),
Sequence<DoPadGemmM, DoPadGemmN>{});
......@@ -406,34 +853,51 @@ struct TransformConvBwdDataToGemm_v1
}
else
{
throw std::runtime_error("wrong! only implemented for 2D and 3D now");
}
}
else
{
const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD);
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
const auto ZTilde = ConvStrideD / GcdStrideDilationD;
const auto YTilde = ConvStrideH / GcdStrideDilationH;
const auto XTilde = ConvStrideW / GcdStrideDilationW;
const auto DTilde =
Do + math::integer_divide_ceil(ConvDilationD * (Z - I1), ConvStrideD);
const auto HTilde =
Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
const auto WTilde =
Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
// only work on HTilde and WTilde that contribute to non-padding area of input tensor
// only work on DTilde, HTilde and WTilde that contribute to
// non-padding area of input tensor
const auto IDTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadD - ConvDilationD * (ZTilde - I1)), ConvStrideD);
const auto IHTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
const auto IWTildeSliceBegin = math::integer_divide_floor(
math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
const auto IDTildeSliceEnd = math::min(
DTilde, math::integer_divide_ceil(InLeftPadD + Di - I1, ConvStrideD) + I1);
const auto IHTildeSliceEnd = math::min(
HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
const auto IWTildeSliceEnd = math::min(
WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin;
const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
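// Worked example (illustrative, not in the original source): with ConvStrideH = 2,
// ConvDilationH = 1 and Y = 3, GcdStrideDilationH = 1, so YTilde = 2 (the filter rows are
// split into two interleaved sub-filters) and HTilde = Ho + ceil(1 * (3 - 1) / 2) = Ho + 1.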
// C: input tensor
if constexpr(NDimSpatial == 2)
{
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_n_hi_wi_c_grid_desc,
in_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
......@@ -441,7 +905,8 @@ struct TransformConvBwdDataToGemm_v1
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor(
const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc =
transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(YTilde, HTilde),
......@@ -450,7 +915,8 @@ struct TransformConvBwdDataToGemm_v1
make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
make_tuple(
Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor(
in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc,
......@@ -480,13 +946,98 @@ struct TransformConvBwdDataToGemm_v1
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmm_gemmn_grid_desc = ck::tensor_operation::device::PadTensorDescriptor(
const auto in_gemmm_gemmn_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
in_gemmmraw_gemmnraw_grid_desc,
make_tuple(GemmMPerBlock, GemmNPerBlock),
Sequence<DoPadGemmM, DoPadGemmN>{});
return in_gemmm_gemmn_grid_desc;
}
else if(NDimSpatial == 3)
{
const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Di, InLeftPadD, InRightPadD),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc =
transform_tensor_descriptor(
in_n_dip_hip_wip_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_embed_transform(make_tuple(ZTilde, DTilde),
make_tuple(ConvDilationD, ConvStrideD)),
make_embed_transform(make_tuple(YTilde, HTilde),
make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(XTilde, WTilde),
make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{}),
make_tuple(Sequence<0>{},
Sequence<1, 2>{},
Sequence<3, 4>{},
Sequence<5, 6>{},
Sequence<7>{}));
const auto in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc =
transform_tensor_descriptor(
in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc,
make_tuple(make_pass_through_transform(N),
make_freeze_transform(i_ztilde),
make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice),
make_freeze_transform(i_ytilde),
make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
make_freeze_transform(i_xtilde),
make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6>{},
Sequence<7>{}),
make_tuple(Sequence<0>{},
Sequence<>{},
Sequence<1>{},
Sequence<>{},
Sequence<2>{},
Sequence<>{},
Sequence<3>{},
Sequence<4>{}));
const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor(
in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc,
make_tuple(
make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice)),
make_pass_through_transform(C)),
make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmm_gemmn_grid_desc =
ck::tensor_operation::device::PadTensorDescriptor(
in_gemmmraw_gemmnraw_grid_desc,
make_tuple(GemmMPerBlock, GemmNPerBlock),
Sequence<DoPadGemmM, DoPadGemmN>{});
return in_gemmm_gemmn_grid_desc;
}
else
{
throw std::runtime_error("wrong! only implemented for 2D and 3D now");
}
}
}
// for input bias
......
......@@ -629,7 +629,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
{
static_assert(
(is_same<T, double>::value && (N == 1 || N == 2)) ||
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, bhalf_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)) ||
......@@ -682,6 +682,20 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
dst_wave_addr_offset,
static_cast<index_t>(coherence));
}
else if constexpr(N == 8)
{
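// split the 8-float vector into two fp32x4 buffer stores; the second store
// writes elements 4..7 at a 16-byte (4 * sizeof(float)) offset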
vector_type<float, 8> tmp{src_thread_data};
llvm_amdgcn_raw_buffer_store_fp32x4(tmp.AsType<float4_t>()[Number<0>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset,
static_cast<index_t>(coherence));
llvm_amdgcn_raw_buffer_store_fp32x4(tmp.AsType<float4_t>()[Number<1>{}],
dst_wave_buffer_resource,
dst_thread_addr_offset,
dst_wave_addr_offset + 4 * sizeof(float),
static_cast<index_t>(coherence));
}
}
else if constexpr(is_same<T, half_t>::value)
{
......@@ -1114,13 +1128,30 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000;
if constexpr(is_same<scalar_t, f8_t>::value)
{
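// f8 has the same size as int8, so reuse the int8 buffer-load path and
// bit_cast the loaded data back to the f8 vector type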
auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>(
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
return bit_cast<vector_t>(tmp);
}
else
{
return amd_buffer_load_impl<scalar_t, vector_size, coherence>(
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
}
#else
if constexpr(is_same<scalar_t, f8_t>::value)
{
auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>(
src_wave_buffer_resource, src_thread_addr_offset, 0);
return src_thread_element_valid ? bit_cast<vector_t>(tmp) : vector_t(0);
}
else
{
vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size, coherence>(
src_wave_buffer_resource, src_thread_addr_offset, 0);
return src_thread_element_valid ? tmp : vector_t(0);
}
#endif
}
......@@ -1179,14 +1210,34 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
if constexpr(is_same<scalar_t, f8_t>::value)
{
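// likewise reinterpret the f8 data as an int8 vector so the existing
// int8 store overloads can be reused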
auto tmp =
bit_cast<typename vector_type_maker<int8_t, vector_size>::type::type>(src_thread_data);
amd_buffer_store_impl<int8_t, vector_size, coherence>(
tmp, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
}
else
{
amd_buffer_store_impl<scalar_t, vector_size, coherence>(
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
}
#else
if(dst_thread_element_valid)
{
if constexpr(is_same<scalar_t, f8_t>::value)
{
auto tmp = bit_cast<typename vector_type_maker<int8_t, vector_size>::type::type>(
src_thread_data);
amd_buffer_store_impl<int8_t, vector_size, coherence>(
tmp, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
}
else
{
amd_buffer_store_impl<scalar_t, vector_size, coherence>(
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
}
}
#endif
}
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/utility/math.hpp"
#include "ck/utility/amd_gemm_dpp.hpp"
namespace ck {
namespace dpp8 {
/// Number of lanes that can share data using DPP8 modifiers.
constexpr index_t lane_group_size = 8;
__device__ index_t get_lane_group_local_idx() { return threadIdx.x / lane_group_size; }
__device__ index_t get_thread_idx_in_lane_group() { return threadIdx.x % lane_group_size; }
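// Illustrative example (not in the original source): for threadIdx.x == 13,
// get_lane_group_local_idx() == 1 and get_thread_idx_in_lane_group() == 5.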
} // namespace dpp8
} // namespace ck
......@@ -354,5 +354,68 @@ struct intrin_mfma_f64_16x16x4f64<16, 16>
#endif
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16f8f8;
template <>
struct intrin_mfma_f32_32x32x16f8f8<32, 32>
{
template <class FloatC>
__device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
reg_c.template AsType<float16_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
bit_cast<long>(reg_a),
bit_cast<long>(reg_b),
reg_c.template AsType<float16_t>()[Number<0>{}],
0,
0,
0);
#else
vector_type<f8_t, 8> reg_a_v(reg_a);
vector_type<f8_t, 8> reg_b_v(reg_b);
static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[Number<k>{}]);
float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<f8_t>()[Number<k>{}]);
intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c);
});
#endif
}
};
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x32f8f8;
template <>
struct intrin_mfma_f32_16x16x32f8f8<16, 16>
{
template <class FloatC>
__device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
bit_cast<long>(reg_a),
bit_cast<long>(reg_b),
reg_c.template AsType<float4_t>()[Number<0>{}],
0,
0,
0);
#else
vector_type<f8_t, 8> reg_a_v(reg_a);
vector_type<f8_t, 8> reg_b_v(reg_b);
static_for<0, 8, 1>{}([&](auto k) {
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[Number<k>{}]);
float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<f8_t>()[Number<k>{}]);
intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c);
});
#endif
}
};
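// Illustrative usage sketch (not part of this commit); the function and parameter names are
// assumptions for the example. Each lane feeds 8 fp8 A values and 8 fp8 B values into the
// 16x16x32 MFMA wrapper above and accumulates 4 floats of the 16x16 output tile; on
// non-gfx94x targets the emulation path converts element-wise to float and falls back to
// the f32 MFMA.
__device__ inline void example_mfma_fp8_16x16x32(const f8x8_t& reg_a,
                                                 const f8x8_t& reg_b,
                                                 vector_type<float, 4>& reg_c)
{
    intrin_mfma_f32_16x16x32f8f8<16, 16>::Run(reg_a, reg_b, reg_c);
}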
} // namespace ck
#endif
......@@ -24,6 +24,7 @@
#include "ck/utility/tuple.hpp"
#include "ck/utility/tuple_helper.hpp"
#include "ck/utility/type.hpp"
#include "ck/utility/type_convert.hpp"
#include "ck/utility/magic_division.hpp"
#include "ck/utility/c_style_pointer_cast.hpp"
#include "ck/utility/is_known_at_compile_time.hpp"
......
......@@ -12,6 +12,7 @@ using half_t = _Float16;
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
using int4_t = _BitInt(4);
#endif
using f8_t = uint8_t;
// vector_type
template <typename T, index_t N>
......@@ -142,6 +143,13 @@ struct scalar_type<int4_t>
};
#endif
template <>
struct scalar_type<f8_t>
{
using type = f8_t;
static constexpr index_t vector_size = 1;
};
//
template <typename T>
struct vector_type<T, 1>
......@@ -944,151 +952,13 @@ using int8x16_t = typename vector_type<int8_t, 16>::type;
using int8x32_t = typename vector_type<int8_t, 32>::type;
using int8x64_t = typename vector_type<int8_t, 64>::type;
// Convert X to Y
template <typename Y, typename X>
__host__ __device__ constexpr Y type_convert(X x)
{
static_assert(!std::is_reference_v<Y> && !std::is_reference_v<X>);
return static_cast<Y>(x);
}
// convert bfp16 to fp32
template <>
inline __host__ __device__ constexpr float type_convert<float, bhalf_t>(bhalf_t x)
{
union
{
uint32_t int32;
float fp32;
} u = {uint32_t(x) << 16};
return u.fp32;
}
// convert fp32 to bfp16
template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float x)
{
union
{
float fp32;
uint32_t int32;
} u = {x};
return uint16_t(u.int32 >> 16);
}
// convert bfp16 to fp16 via fp32
template <>
inline __host__ __device__ constexpr half_t type_convert<half_t, bhalf_t>(bhalf_t x)
{
float x_fp32 = type_convert<float>(x);
return static_cast<half_t>(x_fp32);
}
// convert fp16 to bfp16 via fp32
template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, half_t>(half_t x)
{
float x_fp32 = static_cast<float>(x);
return type_convert<bhalf_t>(x_fp32);
}
// convert bfp16 to int32 via fp32
template <>
inline __host__ __device__ constexpr int32_t type_convert<int32_t, bhalf_t>(bhalf_t x)
{
float x_fp32 = type_convert<float>(x);
return static_cast<int32_t>(x_fp32);
}
// convert int32 to bfp16 via fp32
template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int32_t>(int32_t x)
{
float x_fp32 = static_cast<float>(x);
return type_convert<bhalf_t>(x_fp32);
}
// convert bfp16 to int8 via fp32
template <>
inline __host__ __device__ constexpr int8_t type_convert<int8_t, bhalf_t>(bhalf_t x)
{
float x_fp32 = type_convert<float>(x);
return static_cast<int8_t>(x_fp32);
}
// convert int8 to bfp16 via fp32
template <>
inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, int8_t>(int8_t x)
{
float x_fp32 = static_cast<float>(x);
return type_convert<bhalf_t>(x_fp32);
}
// Declare a template function for bf16 conversion using RTN
template <typename Y, typename X>
__host__ __device__ constexpr Y bf16_convert_rtn(X x);
// Convert fp32 to bf16 with RTN if higher precision is needed
template <>
inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, float>(float x)
{
union
{
float fp32;
uint32_t int32;
} u = {x};
// When the exponent bits are not all 1s, then the value is zero, normal,
// or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
// 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
// This causes the bfloat16's mantissa to be incremented by 1 if the 16
// least significant bits of the float mantissa are greater than 0x8000,
// or if they are equal to 0x8000 and the least significant bit of the
// bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
// the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
// has the value 0x7f, then incrementing it causes it to become 0x00 and
// the exponent is incremented by one, which is the next higher FP value
// to the unrounded bfloat16 value. When the bfloat16 value is subnormal
// with an exponent of 0x00 and a mantissa of 0x7f, it may be rounded up
// to a normal value with an exponent of 0x01 and a mantissa of 0x00.
// When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
// incrementing it causes it to become an exponent of 0xFF and a mantissa
// of 0x00, which is Inf, the next higher value to the unrounded value.
bool flag0 = ~u.int32 & 0x7f800000;
// When all of the exponent bits are 1, the value is Inf or NaN.
// Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
// mantissa bit. Quiet NaN is indicated by the most significant mantissa
// bit being 1. Signaling NaN is indicated by the most significant
// mantissa bit being 0 but some other bit(s) being 1. If any of the
// lower 16 bits of the mantissa are 1, we set the least significant bit
// of the bfloat16 mantissa, in order to preserve signaling NaN in case
// the bfloat16's mantissa bits are all 0.
bool flag1 = !flag0 && (u.int32 & 0xffff);
u.int32 += flag0 ? 0x7fff + ((u.int32 >> 16) & 1) : 0; // Round to nearest, round to even
u.int32 |= flag1 ? 0x10000 : 0x0; // Preserve signaling NaN
return uint16_t(u.int32 >> 16);
}
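// Worked example (illustrative, not in the original source): the float bit pattern
// 0x3F818000 (1.01171875) lies exactly halfway between bf16 codes 0x3F81 (1.0078125) and
// 0x3F82 (1.015625). Its bf16 LSB, (u.int32 >> 16) & 1, is 1, so 0x7fff + 1 = 0x8000 is
// added and the value rounds up to the even code 0x3F82, whereas plain truncation would
// have returned 0x3F81.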
// convert fp16 to bfp16 via fp32 with RTN if higher precision is needed
template <>
inline __host__ __device__ constexpr bhalf_t bf16_convert_rtn<bhalf_t, half_t>(half_t x)
{
float x_fp32 = static_cast<float>(x);
return bf16_convert_rtn<bhalf_t>(x_fp32);
}
// f8
using f8x2_t = typename vector_type<f8_t, 2>::type;
using f8x4_t = typename vector_type<f8_t, 4>::type;
using f8x8_t = typename vector_type<f8_t, 8>::type;
using f8x16_t = typename vector_type<f8_t, 16>::type;
using f8x32_t = typename vector_type<f8_t, 32>::type;
using f8x64_t = typename vector_type<f8_t, 64>::type;
template <typename T>
struct NumericLimits
......@@ -1136,4 +1006,21 @@ struct NumericLimits<int4_t>
};
#endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <>
struct NumericLimits<f8_t>
{
static constexpr uint8_t binary_min = 0x08; // 0b00001000
static constexpr uint8_t binary_max = 0x77; // 0b01110111
static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111
static constexpr uint8_t binary_qnan = 0x80; // 0b10000000
__host__ __device__ static constexpr f8_t Min() { return bit_cast<f8_t>(binary_min); }
__host__ __device__ static constexpr f8_t Max() { return bit_cast<f8_t>(binary_max); }
__host__ __device__ static constexpr f8_t Lowest() { return bit_cast<f8_t>(binary_lowest); }
__host__ __device__ static constexpr f8_t QuietNaN() { return bit_cast<f8_t>(binary_qnan); }
};
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/data_type.hpp"
namespace ck {
// fp8 rounding modes:
// - standard: round to nearest (the faster option)
// - stochastic: stochastic rounding, which helps to avoid error accumulation
enum class f8_rounding_mode
{
standard,
stochastic
};
} // namespace ck
namespace ck::utils {
namespace {
template <typename T, bool negative_zero_nan, bool clip, bool stoch>
__host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
{
// check data type
constexpr bool is_half = std::is_same<T, half_t>::value;
constexpr bool is_float = std::is_same<T, float>::value;
// fp8 exponent/mantissa layout
constexpr int f8_exp = 4;
constexpr int f8_mant = 3;
// resulting type exponent/mantissa layout
constexpr int type_exp = is_half ? 5 : 8;
constexpr int type_mant = is_half ? 10 : 23;
int exponent;
uint32_t head, mantissa, sign;
// the NaN code is the same for float and half
constexpr uint8_t nan_code = 0x80;
constexpr uint32_t nan_mask = is_half ? 0x7C00 : 0x7F800000;
// convert to bitwise
typedef typename std::conditional<std::is_same<T, half_t>::value, uint16_t, uint32_t>::type
T_bitwise;
T_bitwise x_bitwise = *(reinterpret_cast<T_bitwise*>(&x));
// unpack the input; the bit layout depends on the data type
if constexpr(is_float)
{
head = x_bitwise & 0xFF800000;
mantissa = x_bitwise & 0x7FFFFF;
exponent = (head >> type_mant) & 0xFF;
sign = head >> (type_exp + type_mant);
}
else if constexpr(is_half)
{
head = x_bitwise & 0xFC00;
mantissa = x_bitwise & 0x3FF;
exponent = (head >> type_mant) & 0x1F;
sign = head >> (type_exp + type_mant);
}
uint32_t signed_inf = (sign << (type_exp + type_mant)) + (((1 << type_exp) - 1) << type_mant);
uint32_t drop_mask = (1 << (type_mant - f8_mant)) - 1;
constexpr int max_exp = (1 << f8_exp) - (negative_zero_nan ? 1 : 2);
constexpr int exp_low_cutoff =
(1 << (type_exp - 1)) - (1 << (f8_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
if constexpr(negative_zero_nan)
{
if((x_bitwise & nan_mask) == nan_mask)
return nan_code;
}
else
{
if((x_bitwise & nan_mask) == nan_mask)
return signed_inf + (mantissa != 0 ? 1 : 0);
}
// check if x is 0.0
if(x_bitwise == 0)
return 0;
exponent -= exp_low_cutoff - 1;
if(exponent <= 0)
drop_mask = (1 << (type_mant - f8_mant + 1 - exponent)) - 1;
mantissa += 1 << type_mant;
// apply random number if needed
mantissa += (stoch ? rng : mantissa) & drop_mask;
if(mantissa >= (2 << type_mant))
{
mantissa >>= 1;
exponent++;
}
mantissa >>= (type_mant - f8_mant);
// check negative exponent
if(exponent <= 0)
{
if(x_bitwise == 0)
return 0;
else
{
// subnormal range; represented by a subnormal float8 (exponent 0)
// and involves loss of accuracy
mantissa >>= 1 - exponent;
exponent = 0;
}
}
// above range: quantize to maximum possible float of the same sign
else if(exponent > max_exp)
{
if(clip)
{
mantissa = (1 << f8_mant) - 1;
exponent = max_exp;
}
else
{
return signed_inf;
}
}
// check if x is 0.0 or -0.0
if(exponent == 0 && mantissa == 0)
return negative_zero_nan ? 0 : (sign << (f8_exp + f8_mant));
mantissa &= (1 << f8_mant) - 1;
return (sign << (f8_exp + f8_mant)) | (exponent << f8_mant) | mantissa;
}
template <typename T, bool negative_zero_nan>
__host__ __device__ T run_cast_from_f8(f8_t x)
{
// check data type
constexpr bool is_half = std::is_same<T, half_t>::value;
constexpr bool is_float = std::is_same<T, float>::value;
// fp8 exponent/mantissa layout
constexpr int f8_exp = 4;
constexpr int f8_mant = 3;
// resulting type exponent/mantissa layout
constexpr int type_exp = is_half ? 5 : 8;
constexpr int type_mant = is_half ? 10 : 23;
// prepare the codes
constexpr uint8_t nan_code = 0x80;
T fInf, fNegInf, fNaN, fNeg0;
if constexpr(is_half)
{
constexpr uint16_t ihInf = 0x7C00;
constexpr uint16_t ihNegInf = 0xFC00;
constexpr uint16_t ihNaN = 0x7C01;
constexpr uint16_t ihNeg0 = 0x8000;
fInf = *(reinterpret_cast<const half_t*>(&ihInf));
fNegInf = *(reinterpret_cast<const half_t*>(&ihNegInf));
fNaN = *(reinterpret_cast<const half_t*>(&ihNaN));
fNeg0 = *(reinterpret_cast<const half_t*>(&ihNeg0));
}
else if constexpr(is_float)
{
constexpr uint32_t ifInf = 0x7F800000;
constexpr uint32_t ifNegInf = 0xFF800000;
constexpr uint32_t ifNaN = 0x7F800001;
constexpr uint32_t ifNeg0 = 0x80000000;
fInf = *(reinterpret_cast<const float*>(&ifInf));
fNegInf = *(reinterpret_cast<const float*>(&ifNegInf));
fNaN = *(reinterpret_cast<const float*>(&ifNaN));
fNeg0 = *(reinterpret_cast<const float*>(&ifNeg0));
}
// unpack the input
uint32_t sign = x >> (f8_exp + f8_mant);
uint32_t mantissa = x & ((1 << f8_mant) - 1);
int exponent = (x & 0x7F) >> f8_mant;
constexpr int exp_low_cutoff =
(1 << (type_exp - 1)) - (1 << (f8_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
typename std::conditional<std::is_same<T, half_t>::value, uint16_t, uint32_t>::type retval;
if constexpr(negative_zero_nan)
{
if(x == nan_code)
return fNaN;
}
else
{
if(x == nan_code)
return fNeg0;
if(exponent == ((1 << f8_exp) - 1))
return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN;
}
// subnormal input
if(exponent == 0)
{
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
int sh = 1 + __builtin_clz(mantissa) - ((1 + type_exp + type_mant) - f8_mant);
mantissa <<= sh;
mantissa &= ((1 << f8_mant) - 1);
exponent += 1 - sh;
}
exponent += exp_low_cutoff - 1;
mantissa <<= type_mant - f8_mant;
// subnormal output (occurs when T = half, whose exponent width is 5, and negative_zero_nan = true)
if(exponent <= 0)
{
mantissa |= 1 << type_mant;
mantissa >>= 1 - exponent;
exponent = 0;
}
retval = (sign << (type_exp + type_mant)) | (exponent << type_mant) | mantissa;
return *(reinterpret_cast<const T*>(&retval));
}
} // namespace
template <typename T, bool negative_zero_nan, bool clip, bool stoch>
__host__ __device__ f8_t cast_to_f8(T x, uint32_t rng)
{
// check datatype
constexpr bool is_half = std::is_same<T, half_t>::value;
constexpr bool is_float = std::is_same<T, float>::value;
static_assert(is_half || is_float, "Only half and float can be casted to f8.");
return run_cast_to_f8<T, negative_zero_nan, clip, stoch>(x, rng);
}
template <typename T, bool negative_zero_nan>
__host__ __device__ T cast_from_f8(f8_t x)
{
// check datatype
constexpr bool is_half = std::is_same<T, half_t>::value;
constexpr bool is_float = std::is_same<T, float>::value;
static_assert(is_half || is_float, "only half and float are supported.");
// check if x is 0.0
if(x == 0)
return static_cast<T>(0);
return run_cast_from_f8<T, negative_zero_nan>(x);
}
} // namespace ck::utils
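// Illustrative usage sketch (not part of this commit); the function name and the chosen
// template flags are assumptions for the example. It round-trips a float through fp8 using
// the helpers above, with negative_zero_nan = true, clipping enabled, and stochastic
// rounding driven by `rng`.
__host__ __device__ inline float example_f8_roundtrip(float x, uint32_t rng)
{
    ck::f8_t y = ck::utils::cast_to_f8<float, true, true, true>(x, rng);
    return ck::utils::cast_from_f8<float, true>(y);
}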
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck {
template <index_t N>
static constexpr __device__ index_t get_shift()
{
return (get_shift<N / 2>() + 1);
};
template <>
constexpr __device__ index_t get_shift<1>()
{
return (0);
}
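// Illustrative expansion (not in the original source): for power-of-two N, get_shift<N>()
// evaluates to log2(N), e.g.
// get_shift<8>() == get_shift<4>() + 1 == get_shift<2>() + 2 == get_shift<1>() + 3 == 3.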
} // namespace ck
......@@ -3,6 +3,7 @@
#pragma once
#include "data_type.hpp"
#include "type_convert.hpp"
namespace ck {
......@@ -12,13 +13,13 @@ __device__ void inner_product(const TA& a, const TB& b, TC& c);
template <>
__device__ void inner_product<float, float, float>(const float& a, const float& b, float& c)
{
#if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM && defined(CK_USE_AMD_V_MAC_F32)
#if CK_USE_AMD_V_MAC_INLINE_ASM && defined(CK_USE_AMD_V_MAC_F32)
asm volatile("\n \
v_mac_f32 %0, %1, %2 \n \
"
: "=v"(c)
: "v"(a), "v"(b), "0"(c));
#elif CK_USE_AMD_INNER_PRODUCT_INLINE_ASM && defined(CK_USE_AMD_V_FMAC_F32)
#elif CK_USE_AMD_V_MAC_INLINE_ASM && defined(CK_USE_AMD_V_FMAC_F32)
asm volatile("\n \
v_fmac_f32 %0, %1, %2 \n \
"
......@@ -75,22 +76,26 @@ template <>
__device__ void inner_product<half2_t, half2_t, float>(const half2_t& a, const half2_t& b, float& c)
{
#if defined(CK_USE_AMD_V_DOT2_F32_F16)
#if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM
#if CK_USE_AMD_V_DOT_INLINE_ASM
// Use 3 x s_nop to avoid a data hazard (MI200 CDNA2 ISA, page 47:
// https://www.amd.com/system/files/TechDocs/instinct-mi200-cdna2-instruction-set-architecture.pdf).
// s_nop with an immediate of 2 is equivalent to 3 x s_nop.
asm volatile("\n \
v_dot2_f32_f16 %0, %1, %2, %0\n \
s_nop 2 \n \
"
: "=v"(c)
: "v"(a), "v"(b), "0"(c));
#else
c = __builtin_amdgcn_sdot2(a, b, c, false);
c = __builtin_amdgcn_fdot2(a, b, c, false);
#endif
#else
const vector_type<half_t, 2> a_vector{a};
const vector_type<half_t, 2> b_vector{b};
static_for<0, 2, 1>{}([&](auto i) {
c += type_convert<int32_t>(a_vector.AsType<half_t>()[i]) *
type_convert<int32_t>(b_vector.AsType<half_t>()[i]);
c += type_convert<float>(a_vector.AsType<half_t>()[i]) *
type_convert<float>(b_vector.AsType<half_t>()[i]);
});
#endif
}
......@@ -162,9 +167,13 @@ __device__ void
inner_product<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a, const int8x4_t& b, int32_t& c)
{
#if defined(CK_USE_AMD_V_DOT4_I32_I8)
#if CK_USE_AMD_INNER_PRODUCT_INLINE_ASM
#if CK_USE_AMD_V_DOT_INLINE_ASM
// Use 3 x s_nop to avoid a data hazard (MI200 CDNA2 ISA, page 47:
// https://www.amd.com/system/files/TechDocs/instinct-mi200-cdna2-instruction-set-architecture.pdf).
// s_nop with an immediate of 2 is equivalent to 3 x s_nop.
asm volatile("\n \
v_dot4_i32_i8 %0, %1, %2, %0\n \
s_nop 2 \n \
"
: "=v"(c)
: "v"(bit_cast<int32_t>(a)), "v"(bit_cast<int32_t>(b)), "0"(c));
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "amd_gemm_dpp.hpp"
#include "data_type.hpp"
#include "type_convert.hpp"
namespace ck {
namespace dpp8 {
template <int SrcLaneIdx>
__device__ void inline_v_dot2c_dpp8_instr(const half2_t& a, const half2_t& b, float& c);
// clang-format off
template <>
__device__ void inline_v_dot2c_dpp8_instr<0>(const half2_t& a, const half2_t& b, float& c){
asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[0, 0, 0, 0, 0, 0, 0, 0]" : "=v"(c) : "v"(a), "v"(b), "0"(c));
}
template <>
__device__ void inline_v_dot2c_dpp8_instr<1>(const half2_t& a, const half2_t& b, float& c){
asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[1, 1, 1, 1, 1, 1, 1, 1]" : "=v"(c) : "v"(a), "v"(b), "0"(c));
}
template <>
__device__ void inline_v_dot2c_dpp8_instr<2>(const half2_t& a, const half2_t& b, float& c){
asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[2, 2, 2, 2, 2, 2, 2, 2]" : "=v"(c) : "v"(a), "v"(b), "0"(c));
}
template <>
__device__ void inline_v_dot2c_dpp8_instr<3>(const half2_t& a, const half2_t& b, float& c){
asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[3, 3, 3, 3, 3, 3, 3, 3]" : "=v"(c) : "v"(a), "v"(b), "0"(c));
}
template <>
__device__ void inline_v_dot2c_dpp8_instr<4>(const half2_t& a, const half2_t& b, float& c){
asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[4, 4, 4, 4, 4, 4, 4, 4]" : "=v"(c) : "v"(a), "v"(b), "0"(c));
}
template <>
__device__ void inline_v_dot2c_dpp8_instr<5>(const half2_t& a, const half2_t& b, float& c){
asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[5, 5, 5, 5, 5, 5, 5, 5]" : "=v"(c) : "v"(a), "v"(b), "0"(c));
}
template <>
__device__ void inline_v_dot2c_dpp8_instr<6>(const half2_t& a, const half2_t& b, float& c){
asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[6, 6, 6, 6, 6, 6, 6, 6]" : "=v"(c) : "v"(a), "v"(b), "0"(c));
}
template <>
__device__ void inline_v_dot2c_dpp8_instr<7>(const half2_t& a, const half2_t& b, float& c){
asm volatile("\n v_dot2c_f32_f16_dpp %0, %1, %2 dpp8:[7, 7, 7, 7, 7, 7, 7, 7]" : "=v"(c) : "v"(a), "v"(b), "0"(c));
}
// clang-format on
/**
 * Dot product of two vectors using the `v_dot` instruction with a DPP8 modifier, submitted as
 * inline assembly.
 */
template <int SrcLaneIdx, bool ShareA>
__device__ void inline_v_dot2c_dpp8(const half2_t& a, const half2_t& b, float& c)
{
static_assert(SrcLaneIdx >= 0 && SrcLaneIdx < dpp8::lane_group_size,
"DPP8 src broadcast lane out of range <0, 7>.");
if constexpr(ShareA)
{
inline_v_dot2c_dpp8_instr<SrcLaneIdx>(a, b, c);
}
else
{
inline_v_dot2c_dpp8_instr<SrcLaneIdx>(b, a, c);
}
}
/**
 * DPP8 intrinsics expect an integer `sel` mask; the integers below hardcode the masks for
 * specific broadcast patterns.
 */
constexpr std::array<int, dpp8::lane_group_size> IntrinsicMaskDpp8 = {
0, // 0, 0, 0, 0, 0, 0, 0, 0
2396745, // 1, 1, 1, 1, 1, 1, 1, 1
4793490, // 2, 2, 2, 2, 2, 2, 2, 2
7190235, // 3, 3, 3, 3, 3, 3, 3, 3
9586980, // 4, 4, 4, 4, 4, 4, 4, 4
11983725, // 5, 5, 5, 5, 5, 5, 5, 5
14380470, // 6, 6, 6, 6, 6, 6, 6, 6
16777215, // 7, 7, 7, 7, 7, 7, 7, 7
};
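// Illustrative note (not in the original source): a DPP8 `sel` mask packs eight 3-bit lane
// selectors, so broadcasting lane k yields k * 0b001'001'001'001'001'001'001'001
// = k * 0x249249 = k * 2396745, e.g. 7 * 2396745 = 16777215 = 0xFFFFFF for lane 7.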
/**
* Returns DPP8 sel modifier as an integer required for the intrinsic instruction.
*/
template <int SrcLaneIdx>
constexpr int get_dpp_sel_mask_broadcast()
{
static_assert(SrcLaneIdx >= 0 && SrcLaneIdx < dpp8::lane_group_size,
"DPP8 src broadcast lane out of range <0, 7>.");
return IntrinsicMaskDpp8[SrcLaneIdx];
}
template <int SrcLaneIdx>
__device__ void intrinsic_fdot2_impl(const half2_t& a, const half2_t& b, float& c)
{
constexpr int sel_mask = get_dpp_sel_mask_broadcast<SrcLaneIdx>();
const half2_t val_from_other_lane =
bit_cast<half2_t>(__builtin_amdgcn_mov_dpp8(bit_cast<int>(a), sel_mask));
c = __builtin_amdgcn_fdot2(val_from_other_lane, b, c, false);
}
/**
 * Dot product of two vectors using the `v_dot` instruction with a DPP8 modifier, submitted via
 * compiler intrinsics.
 */
template <int SrcLaneIdx, bool ShareA>
__device__ void intrinsic_fdot2(const half2_t& a, const half2_t& b, float& c)
{
if constexpr(ShareA)
{
intrinsic_fdot2_impl<SrcLaneIdx>(a, b, c);
}
else
{
intrinsic_fdot2_impl<SrcLaneIdx>(b, a, c);
}
}
/**
 * Dot product of two input vectors `a`, `b` using `v_dot` instructions with a DPP modifier.
 *
 * The DPP modifier allows one of the vectors to be shared between lanes in a lane group.
 * When `ShareA` is set, the instruction uses vector `a` from lane `SrcLaneIdx` of the same
 * lane group (8 lanes per lane group in DPP8). When `ShareA` is not set, vector `b` is shared.
 * Note that all the threads in a lane group use the same vector, i.e., a broadcast pattern.
 *
 * `SrcLaneIdx` must be in the range 0 to 7.
*/
template <typename TA, typename TB, typename TC, int SrcLaneIdx, bool ShareA>
__device__ void inner_product_dpp(const TA& a, const TB& b, TC& c)
{
#if CK_USE_AMD_V_DOT_DPP8_INLINE_ASM
inline_v_dot2c_dpp8<SrcLaneIdx, ShareA>(a, b, c);
#else
intrinsic_fdot2<SrcLaneIdx, ShareA>(a, b, c);
#endif
}
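// Illustrative usage sketch (not part of this commit); the function and variable names are
// assumptions for the example. Every lane multiplies the half2 vector `a` held by lane 3 of
// its lane group with its own `b` and accumulates into `c`.
__device__ inline void
example_inner_product_dpp_broadcast(const half2_t& a, const half2_t& b, float& c)
{
    inner_product_dpp<half2_t, half2_t, float, /*SrcLaneIdx=*/3, /*ShareA=*/true>(a, b, c);
}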
} // namespace dpp8
} // namespace ck
......@@ -157,4 +157,76 @@ struct MagicDivision
}
};
struct MDiv
{
// 1 dword -> 3 dword storage
uint32_t divisor;
uint32_t multiplier;
uint32_t shift; // TODO: 8 bit is enough
// prefer to construct on the host
__host__ __device__ MDiv(uint32_t divisor_) : divisor(divisor_)
{
auto tmp = MagicDivision::CalculateMagicNumbers(divisor_);
multiplier = tmp[Number<0>{}];
shift = tmp[Number<1>{}];
}
__host__ __device__ MDiv() : divisor(0), multiplier(0), shift(0) {}
__host__ __device__ void update(uint32_t divisor_)
{
divisor = divisor_;
auto tmp = MagicDivision::CalculateMagicNumbers(divisor_);
multiplier = tmp[Number<0>{}];
shift = tmp[Number<1>{}];
}
__host__ __device__ uint32_t div(uint32_t dividend_) const
{
return MagicDivision::DoMagicDivision(dividend_, multiplier, shift);
}
__host__ __device__ void
divmod(uint32_t dividend_, uint32_t& quotient_, uint32_t& remainder_) const
{
quotient_ = div(dividend_);
remainder_ = dividend_ - (quotient_ * divisor);
}
__host__ __device__ uint32_t get() const { return divisor; }
};
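// Illustrative usage sketch (not part of this commit); the names below are assumptions for
// the example. One MDiv constructed with a spatial size HW splits a flat index into (n, hw)
// without a hardware integer divide.
__host__ __device__ inline void
example_mdiv_decompose(uint32_t flat_idx, const MDiv& hw_size, uint32_t& n, uint32_t& hw)
{
    // n = flat_idx / HW, hw = flat_idx % HW, both via one magic multiplication
    hw_size.divmod(flat_idx, n, hw);
}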
struct MDiv2
{
// 1 dword -> 2 dword storage; the divisor is not stored and must be passed to divmod() at runtime
uint32_t multiplier;
uint32_t shift; // TODO: 8 bit is enough
// prefer to construct on the host
__host__ __device__ MDiv2(uint32_t divisor_)
{
auto tmp = MagicDivision::CalculateMagicNumbers(divisor_);
multiplier = tmp[Number<0>{}];
shift = tmp[Number<1>{}];
}
__host__ __device__ MDiv2() : multiplier(0), shift(0) {}
__host__ __device__ uint32_t div(uint32_t dividend_) const
{
return MagicDivision::DoMagicDivision(dividend_, multiplier, shift);
}
__host__ __device__ void
divmod(uint32_t dividend_, uint32_t divisor_, uint32_t& quotient_, uint32_t& remainder_) const
{
quotient_ = div(dividend_);
remainder_ = dividend_ - (quotient_ * divisor_);
}
};
} // namespace ck