Commit 300337cd authored by letaoqin's avatar letaoqin
Browse files

Merge branch 'develop' into jizhan/reduce_threadwise_multi_d

parents f306d02e 02fa2c29
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/library/utility/numeric.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
namespace ck {
namespace tensor_operation {
/**
* @brief Transform conv bwd weight to gemm v2
*
* This version does following things:
* 1. Merge KBatch with K0 to align descriptor with universal gemm
* 2. Merge Batch with M and N dimension. It allows to increase compute in
* case of small M and N. It also allows to vector load and store in case of
* K = 1, C = 1 and NHWGC layout.
*/
template <index_t NDimSpatial,
index_t MPerBlock,
index_t NPerBlock,
index_t GemmK1Number,
index_t K0PerBlock,
index_t NumBatchToMerge,
device::ConvolutionBackwardWeightSpecialization ConvBackwardWeightSpecialization>
struct TransformConvBwdWeightToGemmV2
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
template <index_t NDim, typename enable_if<NDim == 2, bool>::type = false>
constexpr static auto
make_out_grid_desc(const index_t N,
const index_t Ho,
const index_t Wo,
const index_t K,
const std::array<index_t, NDimSpatial + 3>& output_strides)
{
const index_t BatchStride = output_strides[0];
const index_t WoStride = output_strides[4];
const auto KStride = Number<1>{};
return make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, NumBatchToMerge, K),
make_tuple(WoStride, BatchStride, KStride));
}
template <index_t NDim, typename enable_if<NDim == 2, bool>::type = false>
constexpr static auto
make_in_grid_desc(const index_t N,
const index_t Hi,
const index_t Wi,
const index_t C,
const std::array<index_t, NDimSpatial + 3>& input_strides)
{
const index_t BatchStride = input_strides[0];
const index_t NStride = input_strides[1];
const index_t HiStride = input_strides[3];
const index_t WiStride = input_strides[4];
const auto CStride = input_strides[2];
if constexpr(ConvBackwardWeightSpecialization ==
device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
{
return make_naive_tensor_descriptor(make_tuple(N * Hi * Wi, NumBatchToMerge, C),
make_tuple(WiStride, BatchStride, CStride));
}
else
{
return make_naive_tensor_descriptor(
make_tuple(N, Hi, Wi, NumBatchToMerge, C),
make_tuple(NStride, HiStride, WiStride, BatchStride, CStride));
}
}
template <index_t NDim, typename enable_if<NDim == 2, bool>::type = false>
constexpr static auto
make_wei_grid_desc(const index_t K,
const index_t Y,
const index_t X,
const index_t C,
const std::array<index_t, NDimSpatial + 3>& weights_strides)
{
const auto CStride = Number<1>{};
const auto KStride = weights_strides[1];
const auto XStride = weights_strides[4];
const auto BatchStride = weights_strides[0];
// Add NumBatchToMerge for Batch+M dimension and, 1 as a placehorder
// for Batch+N dimension
const auto desc = make_naive_tensor_descriptor(
make_tuple(NumBatchToMerge, K, Y * X, 1, C),
make_tuple(BatchStride, KStride, XStride, BatchStride, CStride));
// Padd 1 to NumBatchToMerge
const auto padded_desc = transform_tensor_descriptor(
desc,
make_tuple(make_pass_through_transform(NumBatchToMerge),
make_pass_through_transform(K),
make_pass_through_transform(Y * X),
make_pad_transform(1, 0, NumBatchToMerge - 1),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
// We need only matrices from diagonal. Xor returns 0 for the same
// values. So if matrices is not on diagonal then it will be stored in padding.
// To avoid use of modulo after xor we assume that NumBatch to merge is power of 2.
static_assert(NumBatchToMerge == 1 || NumBatchToMerge == 2 || NumBatchToMerge == 4 ||
NumBatchToMerge == 8 || NumBatchToMerge == 16 || NumBatchToMerge == 32 ||
NumBatchToMerge == 64);
const auto unmerged_padded_desc = transform_tensor_descriptor(
padded_desc,
make_tuple(make_xor_transform(make_tuple(NumBatchToMerge, NumBatchToMerge)),
make_pass_through_transform(K),
make_pass_through_transform(Y * X),
make_pass_through_transform(C)),
make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}),
make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}));
// Merge To M, N
return transform_tensor_descriptor(
unmerged_padded_desc,
make_tuple(make_merge_transform(make_tuple(NumBatchToMerge, K)),
make_merge_transform(make_tuple(Y * X, NumBatchToMerge, C))),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
template <index_t NDim, typename enable_if<NDim == 3, bool>::type = false>
constexpr static auto
make_out_grid_desc(const index_t N,
const index_t Do,
const index_t Ho,
const index_t Wo,
const index_t K,
const std::array<index_t, NDimSpatial + 3>& output_strides)
{
const index_t BatchStride = output_strides[0];
const index_t WoStride = output_strides[5];
const auto KStride = Number<1>{};
return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, NumBatchToMerge, K),
make_tuple(WoStride, BatchStride, KStride));
}
template <index_t NDim, typename enable_if<NDim == 3, bool>::type = false>
constexpr static auto
make_in_grid_desc(const index_t N,
const index_t Di,
const index_t Hi,
const index_t Wi,
const index_t C,
const std::array<index_t, NDimSpatial + 3>& input_strides)
{
const index_t BatchStride = input_strides[0];
const index_t NStride = input_strides[1];
const index_t DiStride = input_strides[3];
const index_t HiStride = input_strides[4];
const index_t WiStride = input_strides[5];
const auto CStride = input_strides[2];
if constexpr(ConvBackwardWeightSpecialization ==
device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
{
return make_naive_tensor_descriptor(make_tuple(N * Di * Hi * Wi, NumBatchToMerge, C),
make_tuple(WiStride, BatchStride, CStride));
}
else
{
return make_naive_tensor_descriptor(
make_tuple(N, Di, Hi, Wi, NumBatchToMerge, C),
make_tuple(NStride, DiStride, HiStride, WiStride, BatchStride, CStride));
}
}
template <index_t NDim, typename enable_if<NDim == 3, bool>::type = false>
constexpr static auto
make_wei_grid_desc(const index_t K,
const index_t Z,
const index_t Y,
const index_t X,
const index_t C,
const std::array<index_t, NDimSpatial + 3>& weights_strides)
{
const auto CStride = Number<1>{};
const auto KStride = weights_strides[1];
const auto XStride = weights_strides[5];
const auto BatchStride = weights_strides[0];
// Add NumBatchToMerge for Batch+M dimension and, 1 for placehord for Batch+N dimension
const auto desc = make_naive_tensor_descriptor(
make_tuple(NumBatchToMerge, K, Z * Y * X, 1, C),
make_tuple(BatchStride, KStride, XStride, BatchStride, CStride));
// Padd 1 to NumBatchToMerge
const auto padded_desc = transform_tensor_descriptor(
desc,
make_tuple(make_pass_through_transform(NumBatchToMerge),
make_pass_through_transform(K),
make_pass_through_transform(Z * Y * X),
make_pad_transform(1, 0, NumBatchToMerge - 1),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
// We need only matrices from diagonal. Xor returns 0 for the same
// values. So if matrices is not on diagonal then it will be stored in padding.
// To avoid use of modulo after xor we assume that NumBatch to merge is power of 2.
static_assert(NumBatchToMerge == 1 || NumBatchToMerge == 2 || NumBatchToMerge == 4 ||
NumBatchToMerge == 8 || NumBatchToMerge == 16 || NumBatchToMerge == 32 ||
NumBatchToMerge == 64);
const auto unmerged_padded_desc = transform_tensor_descriptor(
padded_desc,
make_tuple(make_xor_transform(make_tuple(NumBatchToMerge, NumBatchToMerge)),
make_pass_through_transform(K),
make_pass_through_transform(Z * Y * X),
make_pass_through_transform(C)),
make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}),
make_tuple(Sequence<0, 3>{}, Sequence<1>{}, Sequence<2>{}, Sequence<4>{}));
// Merge To M, N
return transform_tensor_descriptor(
unmerged_padded_desc,
make_tuple(make_merge_transform(make_tuple(NumBatchToMerge, K)),
make_merge_transform(make_tuple(Z * Y * X, NumBatchToMerge, C))),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3, 4>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
}
template <index_t NDim, typename enable_if<NDim == 2, bool>::type = false>
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
const index_t N,
const index_t K,
const index_t C,
const std::array<index_t, NDimSpatial>& input_spatial_lengths,
const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
const std::array<index_t, NDimSpatial + 3>& input_strides,
const std::array<index_t, NDimSpatial + 3>& weights_strides,
const std::array<index_t, NDimSpatial + 3>& output_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const index_t batch_k)
{
using namespace ck;
const index_t Hi = input_spatial_lengths[0];
const index_t Wi = input_spatial_lengths[1];
const index_t Ho = output_spatial_lengths[0];
const index_t Wo = output_spatial_lengths[1];
const index_t Y = filter_spatial_lengths[0];
const index_t X = filter_spatial_lengths[1];
const index_t ConvStrideH = conv_filter_strides[0];
const index_t ConvStrideW = conv_filter_strides[1];
const index_t ConvDilationH = conv_filter_dilations[0];
const index_t ConvDilationW = conv_filter_dilations[1];
const index_t InLeftPadH = input_left_pads[0];
const index_t InLeftPadW = input_left_pads[1];
const index_t InRightPadH = input_right_pads[0];
const index_t InRightPadW = input_right_pads[1];
const index_t GemmKTotal = N * Ho * Wo;
const index_t GemmM = K * NumBatchToMerge;
const index_t GemmN = C * X * Y * NumBatchToMerge;
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
const index_t GemmKBatch = batch_k;
const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
K0PerBlock;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
const auto out_grid_desc = make_out_grid_desc<NDim>(N, Ho, Wo, K, output_strides);
const auto in_grid_desc = make_in_grid_desc<NDim>(N, Hi, Wi, C, input_strides);
const auto wei_grid_desc = make_wei_grid_desc<NDim>(K, Y, X, C, weights_strides);
if constexpr(ConvBackwardWeightSpecialization ==
device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
{
// A: output tensor
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
out_grid_desc,
make_tuple(
make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_merge_transform(make_tuple(NumBatchToMerge, GemmM / NumBatchToMerge))),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// B: input tensor
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
in_grid_desc,
make_tuple(
make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_merge_transform(make_tuple(NumBatchToMerge, GemmN / NumBatchToMerge))),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_grid_desc);
}
else
{
// A: output tensor
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
out_grid_desc,
make_tuple(
make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_merge_transform(make_tuple(NumBatchToMerge, GemmM / NumBatchToMerge))),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// B: input tensor
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(NumBatchToMerge),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_hip_wip_c_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(NumBatchToMerge),
make_pass_through_transform(C)),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
make_tuple(Sequence<0>{},
Sequence<1, 2>{},
Sequence<3, 4>{},
Sequence<5>{},
Sequence<6>{}));
const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor(
in_n_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(Y, X, NumBatchToMerge, C)),
make_merge_transform(make_tuple(N, Ho, Wo))),
make_tuple(Sequence<1, 3, 5, 6>{}, Sequence<0, 2, 4>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
in_gemmktotal_gemmn_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// Padd
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc =
transform_tensor_descriptor(
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
make_tuple(make_pass_through_transform(GemmKBatch * GemmK0),
make_right_pad_transform(GemmM, PadGemmM),
make_pass_through_transform(GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc =
transform_tensor_descriptor(
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
make_tuple(make_pass_through_transform(GemmKBatch * GemmK0),
make_right_pad_transform(GemmN, PadGemmN),
make_pass_through_transform(GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto wei_gemmm_gemmn_pad_grid_desc =
transform_tensor_descriptor(wei_grid_desc,
make_tuple(make_right_pad_transform(GemmM, PadGemmM),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc,
wei_gemmm_gemmn_pad_grid_desc);
}
}
template <index_t NDim, typename enable_if<NDim == 3, bool>::type = false>
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
const index_t N,
const index_t K,
const index_t C,
const std::array<index_t, NDimSpatial>& input_spatial_lengths,
const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
const std::array<index_t, NDimSpatial + 3>& input_strides,
const std::array<index_t, NDimSpatial + 3>& weights_strides,
const std::array<index_t, NDimSpatial + 3>& output_strides,
const std::array<index_t, NDimSpatial>& conv_filter_strides,
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
const std::array<index_t, NDimSpatial>& input_left_pads,
const std::array<index_t, NDimSpatial>& input_right_pads,
const index_t batch_k)
{
using namespace ck;
const index_t Di = input_spatial_lengths[0];
const index_t Hi = input_spatial_lengths[1];
const index_t Wi = input_spatial_lengths[2];
const index_t Do = output_spatial_lengths[0];
const index_t Ho = output_spatial_lengths[1];
const index_t Wo = output_spatial_lengths[2];
const index_t Z = filter_spatial_lengths[0];
const index_t Y = filter_spatial_lengths[1];
const index_t X = filter_spatial_lengths[2];
const index_t ConvStrideD = conv_filter_strides[0];
const index_t ConvStrideH = conv_filter_strides[1];
const index_t ConvStrideW = conv_filter_strides[2];
const index_t ConvDilationD = conv_filter_dilations[0];
const index_t ConvDilationH = conv_filter_dilations[1];
const index_t ConvDilationW = conv_filter_dilations[2];
const index_t InLeftPadD = input_left_pads[0];
const index_t InLeftPadH = input_left_pads[1];
const index_t InLeftPadW = input_left_pads[2];
const index_t InRightPadD = input_right_pads[0];
const index_t InRightPadH = input_right_pads[1];
const index_t InRightPadW = input_right_pads[2];
const index_t GemmKTotal = N * Do * Ho * Wo;
const index_t GemmM = K * NumBatchToMerge;
const index_t GemmN = C * Z * X * Y * NumBatchToMerge;
const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
const index_t GemmKBatch = batch_k;
const index_t GemmK0 =
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
K0PerBlock;
const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
const auto out_grid_desc = make_out_grid_desc<NDim>(N, Do, Ho, Wo, K, output_strides);
const auto in_grid_desc = make_in_grid_desc<NDim>(N, Di, Hi, Wi, C, input_strides);
const auto wei_grid_desc = make_wei_grid_desc<NDim>(K, Z, Y, X, C, weights_strides);
if constexpr(ConvBackwardWeightSpecialization ==
device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
{
// A: output tensor
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
out_grid_desc,
make_tuple(
make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_merge_transform(make_tuple(NumBatchToMerge, GemmM / NumBatchToMerge))),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// B: input tensor
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
in_grid_desc,
make_tuple(
make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_merge_transform(make_tuple(NumBatchToMerge, GemmN / NumBatchToMerge))),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
wei_grid_desc);
}
else
{
// A: output tensor
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
out_grid_desc,
make_tuple(
make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_merge_transform(make_tuple(NumBatchToMerge, GemmM / NumBatchToMerge))),
make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
out_gemmkpad_gemmm_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_pass_through_transform(GemmM)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// B: input tensor
const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor(
in_grid_desc,
make_tuple(make_pass_through_transform(N),
make_pad_transform(Di, InLeftPadD, InRightPadD),
make_pad_transform(Hi, InLeftPadH, InRightPadH),
make_pad_transform(Wi, InLeftPadW, InRightPadW),
make_pass_through_transform(NumBatchToMerge),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}));
const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
in_n_dip_hip_wip_c_grid_desc,
make_tuple(
make_pass_through_transform(N),
make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)),
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
make_pass_through_transform(NumBatchToMerge),
make_pass_through_transform(C)),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{}),
make_tuple(Sequence<0>{},
Sequence<1, 2>{},
Sequence<3, 4>{},
Sequence<5, 6>{},
Sequence<7>{},
Sequence<8>{}));
const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor(
in_n_z_do_y_ho_x_wo_c_grid_desc,
make_tuple(make_merge_transform(make_tuple(Z, Y, X, NumBatchToMerge, C)),
make_merge_transform(make_tuple(N, Do, Ho, Wo))),
make_tuple(Sequence<1, 3, 5, 7, 8>{}, Sequence<0, 2, 4, 6>{}),
make_tuple(Sequence<1>{}, Sequence<0>{}));
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
in_gemmktotal_gemmn_grid_desc,
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
in_gemmkpad_gemmn_grid_desc,
make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
make_pass_through_transform(GemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
// Padd
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc =
transform_tensor_descriptor(
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
make_tuple(make_pass_through_transform(GemmKBatch * GemmK0),
make_right_pad_transform(GemmM, PadGemmM),
make_pass_through_transform(GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc =
transform_tensor_descriptor(
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
make_tuple(make_pass_through_transform(GemmKBatch * GemmK0),
make_right_pad_transform(GemmN, PadGemmN),
make_pass_through_transform(GemmK1Number)),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
const auto wei_gemmm_gemmn_pad_grid_desc =
transform_tensor_descriptor(wei_grid_desc,
make_tuple(make_right_pad_transform(GemmM, PadGemmM),
make_right_pad_transform(GemmN, PadGemmN)),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc,
in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc,
wei_gemmm_gemmn_pad_grid_desc);
}
} // function end
};
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "data_type.hpp" #include "data_type.hpp"
...@@ -297,6 +297,17 @@ enum struct AmdBufferCoherenceEnum ...@@ -297,6 +297,17 @@ enum struct AmdBufferCoherenceEnum
GLC = 1, GLC = 1,
SLC = 2, SLC = 2,
GLC_SLC = 3, GLC_SLC = 3,
// gfx94: bit 0 = sc0, bit 1 = nt, bit 3 = swz, bit 4 = sc1
// SC[1:0] System Cache level: 0=wave, 1=group, 2=device, 3=system
// NT Non-Temporal: 0=expect temporal reuse; 1=do not expect temporal reuse
WAVE_NT0 = 0,
WAVE_NT1 = 2,
GROUP_NT0 = 1,
GROUP_NT1 = 3,
DEVICE_NT0 = 8,
DEVICE_NT1 = 10,
SYSTEM_NT0 = 9,
SYSTEM_NT1 = 11,
}; };
template <index_t N, AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence> template <index_t N, AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence>
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
#pragma once #pragma once
namespace ck { namespace ck {
// Define the common macro for MI300 models // Define the common macro for gfx94x models
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#define __gfx94__ #define __gfx94__
#endif #endif
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
namespace ck {
enum struct BlockGemmPipelineScheduler
{
Intrawave,
Interwave,
};
enum struct TailNumber
{
// Single / Double buffer pipeline
Odd,
Even,
// Long prefetch pipeline, up to 8
One,
Two,
Three,
Four,
Five,
Six,
Seven,
// Unroll stages > Prefetch stages, number of loop is multiple of unroll stages
Empty,
// Unroll stages <= Prefetch stages, number of loop is multiple of unroll stages add
// prefetchstages
Full,
};
template <index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t ABufferLoadWidth,
index_t BBufferLoadWidth,
index_t ALDSWriteWidth,
index_t BLDSWriteWidth,
index_t ALDSReadWidth,
index_t BLDSReadWidth,
index_t MRepeat,
index_t NRepeat,
index_t MPerXDL,
index_t NPerXDL,
index_t KPerXDL>
struct BlockwiseGemmXdlops_pipeline_hotloop_inst
{
static constexpr index_t WaveSize = 64;
static constexpr index_t WaveNumM = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t WaveNumN = NPerBlock / (NRepeat * NPerXDL);
static constexpr index_t A_LDS_Read_Width = ALDSReadWidth;
static constexpr index_t B_LDS_Read_Width = BLDSReadWidth;
static constexpr index_t A_Buffer_Load_Inst_Num =
MPerBlock * KPerBlock / (BlockSize * ABufferLoadWidth);
static constexpr index_t B_Buffer_Load_Inst_Num =
NPerBlock * KPerBlock / (BlockSize * BBufferLoadWidth);
static constexpr index_t A_LDS_Write_Inst_Num =
MPerBlock * KPerBlock / (BlockSize * ALDSWriteWidth);
static constexpr index_t B_LDS_Write_Inst_Num =
NPerBlock * KPerBlock / (BlockSize * BLDSWriteWidth);
static constexpr index_t A_LDS_Read_Inst_Num =
WaveNumN * MPerBlock * KPerBlock / (BlockSize * ALDSReadWidth);
static constexpr index_t B_LDS_Read_Inst_Num =
WaveNumM * MPerBlock * KPerBlock / (BlockSize * BLDSReadWidth);
static constexpr index_t C_MFMA_Inst_Num =
MPerBlock * NPerBlock * KPerBlock / (BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL);
static constexpr auto Print()
{
printf(" Blk/Wave Size: %d, %d, M/N/K PerBlk: %d, %d, %d, M/N/K PerXdl: %d, %d, %d\n",
BlockSize,
WaveSize,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
KPerXDL);
printf(" A/B buffer load inst: %d, %d\n A/B LDS write inst: %d, %d\n A/B LDS read inst: "
"%d, %d\n C MFMA inst: %d\n",
A_Buffer_Load_Inst_Num,
B_Buffer_Load_Inst_Num,
A_LDS_Write_Inst_Num,
B_LDS_Write_Inst_Num,
A_LDS_Read_Inst_Num,
B_LDS_Read_Inst_Num,
C_MFMA_Inst_Num);
}
};
} // namespace ck
...@@ -163,6 +163,13 @@ struct scalar_type<bf8_t> ...@@ -163,6 +163,13 @@ struct scalar_type<bf8_t>
static constexpr index_t vector_size = 1; static constexpr index_t vector_size = 1;
}; };
template <>
struct scalar_type<bool>
{
using type = bool;
static constexpr index_t vector_size = 1;
};
template <typename T> template <typename T>
struct vector_type<T, 1> struct vector_type<T, 1>
{ {
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#ifndef UTILITY_DEBUG_HPP #ifndef UTILITY_DEBUG_HPP
#define UTILITY_DEBUG_HPP #define UTILITY_DEBUG_HPP
...@@ -79,6 +79,13 @@ __device__ void print_shared(T const* p_shared, index_t num_elements) ...@@ -79,6 +79,13 @@ __device__ void print_shared(T const* p_shared, index_t num_elements)
__syncthreads(); __syncthreads();
} }
template <index_t... Ids>
__device__ static bool is_thread_local_1d_id_idx()
{
const auto tid = get_thread_local_1d_id();
return ((tid == Ids) || ...);
}
} // namespace debug } // namespace debug
} // namespace ck } // namespace ck
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <cstring>
#include <string>
#include <string_view>
namespace ck {
namespace internal {
template <typename T>
struct ParseEnvVal
{
};
template <>
struct ParseEnvVal<bool>
{
static bool parse_env_var_value(const char* vp)
{
std::string value_env_str{vp};
for(auto& c : value_env_str)
{
if(std::isalpha(c) != 0)
{
c = std::tolower(static_cast<unsigned char>(c));
}
}
if(value_env_str == "disable" || value_env_str == "disabled" || value_env_str == "0" ||
value_env_str == "no" || value_env_str == "off" || value_env_str == "false")
{
return false;
}
else if(value_env_str == "enable" || value_env_str == "enabled" || value_env_str == "1" ||
value_env_str == "yes" || value_env_str == "on" || value_env_str == "true")
{
return true;
}
else
{
throw std::runtime_error("Invalid value for env variable");
}
return false; // shouldn't reach here
}
};
// Supports hexadecimals (with leading "0x"), octals (if prefix is "0") and decimals (default).
// Returns 0 if environment variable is in wrong format (strtoull fails to parse the string).
template <>
struct ParseEnvVal<uint64_t>
{
static uint64_t parse_env_var_value(const char* vp) { return std::strtoull(vp, nullptr, 0); }
};
template <>
struct ParseEnvVal<std::string>
{
static std::string parse_env_var_value(const char* vp) { return std::string{vp}; }
};
template <typename T>
struct EnvVar
{
private:
T value{};
bool is_unset = true;
public:
const T& GetValue() const { return value; }
bool IsUnset() const { return is_unset; }
void Unset() { is_unset = true; }
void UpdateValue(const T& val)
{
is_unset = false;
value = val;
}
explicit EnvVar(const char* const name, const T& def_val)
{
// NOLINTNEXTLINE (concurrency-mt-unsafe)
const char* vp = std::getenv(name);
if(vp != nullptr) // a value was provided
{
is_unset = false;
value = ParseEnvVal<T>::parse_env_var_value(vp);
}
else // no value provided, use default value
{
value = def_val;
}
}
};
} // end namespace internal
// static inside function hides the variable and provides
// thread-safety/locking
// Used in global namespace
#define CK_DECLARE_ENV_VAR(name, type, default_val) \
namespace ck::env { \
struct name \
{ \
static_assert(std::is_same_v<name, ::ck::env::name>, \
"CK_DECLARE_ENV* must be used in the global namespace"); \
using value_type = type; \
static ck::internal::EnvVar<type>& Ref() \
{ \
static ck::internal::EnvVar<type> var{#name, default_val}; \
return var; \
} \
}; \
}
#define CK_DECLARE_ENV_VAR_BOOL(name) CK_DECLARE_ENV_VAR(name, bool, false)
#define CK_DECLARE_ENV_VAR_UINT64(name) CK_DECLARE_ENV_VAR(name, uint64_t, 0)
#define CK_DECLARE_ENV_VAR_STR(name) CK_DECLARE_ENV_VAR(name, std::string, "")
#define CK_ENV(name) \
ck::env::name {}
template <class EnvVar>
inline const std::string& EnvGetString(EnvVar)
{
static_assert(std::is_same_v<typename EnvVar::value_type, std::string>);
return EnvVar::Ref().GetValue();
}
template <class EnvVar>
inline bool EnvIsEnabled(EnvVar)
{
static_assert(std::is_same_v<typename EnvVar::value_type, bool>);
return !EnvVar::Ref().IsUnset() && EnvVar::Ref().GetValue();
}
template <class EnvVar>
inline bool EnvIsDisabled(EnvVar)
{
static_assert(std::is_same_v<typename EnvVar::value_type, bool>);
return !EnvVar::Ref().IsUnset() && !EnvVar::Ref().GetValue();
}
template <class EnvVar>
inline uint64_t EnvValue(EnvVar)
{
static_assert(std::is_same_v<typename EnvVar::value_type, uint64_t>);
return EnvVar::Ref().GetValue();
}
template <class EnvVar>
inline bool EnvIsUnset(EnvVar)
{
return EnvVar::Ref().IsUnset();
}
template <class EnvVar>
void EnvUnset(EnvVar)
{
EnvVar::Ref().Unset();
}
/// updates the cached value of an environment variable
template <typename EnvVar, typename ValueType>
void UpdateEnvVar(EnvVar, const ValueType& val)
{
static_assert(std::is_same_v<typename EnvVar::value_type, ValueType>);
EnvVar::Ref().UpdateValue(val);
}
template <typename EnvVar>
void UpdateEnvVar(EnvVar, const std::string_view& val)
{
EnvVar::Ref().UpdateValue(
ck::internal::ParseEnvVal<typename EnvVar::value_type>::parse_env_var_value(val.data()));
}
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <hip/hip_runtime.h>
namespace ck {
static __global__ void flush_icache()
{
asm __volatile__("s_icache_inv \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t" ::
:);
}
} // namespace ck
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ostream>
#pragma once #pragma once
...@@ -24,3 +25,14 @@ constexpr LoopScheduler make_default_loop_scheduler() ...@@ -24,3 +25,14 @@ constexpr LoopScheduler make_default_loop_scheduler()
} }
} // namespace ck } // namespace ck
inline std::ostream& operator<<(std::ostream& os, const ck::LoopScheduler& s)
{
switch(s)
{
case ck::LoopScheduler::Default: os << "Default"; break;
case ck::LoopScheduler::Interwave: os << "Interwave"; break;
default: os << "";
}
return os;
}
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include <ostream>
#include "ck/utility/integral_constant.hpp" #include "ck/utility/integral_constant.hpp"
#include "ck/utility/type.hpp" #include "ck/utility/type.hpp"
#include "ck/utility/functional.hpp" #include "ck/utility/functional.hpp"
...@@ -897,3 +899,14 @@ template <index_t NSize, index_t I> ...@@ -897,3 +899,14 @@ template <index_t NSize, index_t I>
using uniform_sequence_gen_t = typename uniform_sequence_gen<NSize, I>::type; using uniform_sequence_gen_t = typename uniform_sequence_gen<NSize, I>::type;
} // namespace ck } // namespace ck
template <ck::index_t... Is>
std::ostream& operator<<(std::ostream& os, const ck::Sequence<Is...>)
{
using S = ck::Sequence<Is...>;
os << "{";
ck::static_for<0, S::Size() - ck::Number<1>{}, 1>{}(
[&](auto i) { os << S::At(i).value << ", "; });
os << S::At(S::Size() - ck::Number<1>{}).value << "}";
return os;
}
...@@ -10,10 +10,12 @@ namespace ck { ...@@ -10,10 +10,12 @@ namespace ck {
__device__ void block_sync_lds() __device__ void block_sync_lds()
{ {
#if CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM #if CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
asm volatile("\ // asm volatile("\
s_waitcnt lgkmcnt(0) \n \ // s_waitcnt lgkmcnt(0) \n \
s_barrier \ // s_barrier \
" ::); // " ::);
__builtin_amdgcn_s_waitcnt(0xc07f);
__builtin_amdgcn_s_barrier();
#else #else
__syncthreads(); __syncthreads();
#endif #endif
......
...@@ -162,4 +162,83 @@ struct transpose_vectors<int8_t, NX, NY> ...@@ -162,4 +162,83 @@ struct transpose_vectors<int8_t, NX, NY>
} }
}; };
// transpose f8 4x4
__device__ void transpose_f8_4x4(const f8x4_t& x0,
const f8x4_t& x1,
const f8x4_t& x2,
const f8x4_t& x3,
f8x4_t& y0,
f8x4_t& y1,
f8x4_t& y2,
f8x4_t& y3)
{
int32_t t0, t1;
int32_t z0, z1, z2, z3;
constexpr int32_t m0 = 0x05010400;
constexpr int32_t m1 = 0x05040100;
constexpr int32_t m2 = 0x07060302;
constexpr int32_t m3 = 0x07030602;
// ex: v_perm_b32(0x 11 22 33 44, 0x 55 66 77 88, 0x 05 01 04 00) -> 0x33774488
// -- -- -- -- -- -- -- -- - - - -
// index 7 6 5 4 3 2 1 0 33 77 44 88
// index is reversed because of little endianness (least significant bits first)
t0 = __builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m0);
t1 = __builtin_amdgcn_perm(bit_cast<int32_t>(x3), bit_cast<int32_t>(x2), m0);
z0 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m1);
z1 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m2);
t0 = __builtin_amdgcn_perm(bit_cast<int32_t>(x1), bit_cast<int32_t>(x0), m3);
t1 = __builtin_amdgcn_perm(bit_cast<int32_t>(x3), bit_cast<int32_t>(x2), m3);
z2 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m1);
z3 = __builtin_amdgcn_perm(bit_cast<int32_t>(t1), bit_cast<int32_t>(t0), m2);
y0 = bit_cast<f8x4_t>(z0);
y1 = bit_cast<f8x4_t>(z1);
y2 = bit_cast<f8x4_t>(z2);
y3 = bit_cast<f8x4_t>(z3);
}
template <index_t NX, index_t NY>
struct transpose_vectors<f8_t, NX, NY>
{
// we got [NY * NX] amount of S data to be transposed
static constexpr index_t s_per_x = NY;
static constexpr index_t s_per_y = NX;
using S = f8_t;
using VX = vector_type<f8_t, s_per_x>;
using VY = vector_type<f8_t, s_per_y>;
__device__ void operator()(const StaticallyIndexedArray<const VX&, NX>& vx_tuple,
StaticallyIndexedArray<VY&, NY>& vy_tuple)
{
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I4 = Number<4>{};
static_assert((NX % 4 == 0 && NY % 4 == 0), "wrong!");
// loop over 4x4 tile and transpose data from vx_tuple into vy_tuple
static_for<0, NY, 4>{}([&](auto iy) {
static_for<0, NX, 4>{}([&](auto ix) {
// reference to 4 f8 data from vx_tuple
const auto& x_s4_0 = vx_tuple[ix].template AsType<f8x4_t>()[iy / I4];
const auto& x_s4_1 = vx_tuple[ix + I1].template AsType<f8x4_t>()[iy / I4];
const auto& x_s4_2 = vx_tuple[ix + I2].template AsType<f8x4_t>()[iy / I4];
const auto& x_s4_3 = vx_tuple[ix + I3].template AsType<f8x4_t>()[iy / I4];
// reference to 4 f8 data from vy_tuple
auto& y_s4_0 = vy_tuple(iy).template AsType<f8x4_t>()(ix / I4);
auto& y_s4_1 = vy_tuple(iy + I1).template AsType<f8x4_t>()(ix / I4);
auto& y_s4_2 = vy_tuple(iy + I2).template AsType<f8x4_t>()(ix / I4);
auto& y_s4_3 = vy_tuple(iy + I3).template AsType<f8x4_t>()(ix / I4);
// transpose
transpose_f8_4x4(x_s4_0, x_s4_1, x_s4_2, x_s4_3, y_s4_0, y_s4_1, y_s4_2, y_s4_3);
});
});
}
};
} // namespace ck } // namespace ck
...@@ -40,21 +40,10 @@ inline constexpr bool is_pointer_v = std::is_pointer<T>::value; ...@@ -40,21 +40,10 @@ inline constexpr bool is_pointer_v = std::is_pointer<T>::value;
template <typename Y, typename X, typename enable_if<sizeof(X) == sizeof(Y), bool>::type = false> template <typename Y, typename X, typename enable_if<sizeof(X) == sizeof(Y), bool>::type = false>
__host__ __device__ constexpr Y bit_cast(const X& x) __host__ __device__ constexpr Y bit_cast(const X& x)
{ {
#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_BIT_CAST static_assert(__has_builtin(__builtin_bit_cast), "");
Y y; static_assert(sizeof(X) == sizeof(Y), "Do not support cast between different size of type");
__builtin_memcpy(&y, &x, sizeof(X)); return __builtin_bit_cast(Y, x);
return y;
#else
union AsType
{
X x;
Y y;
};
return AsType{x}.y;
#endif
} }
} // namespace ck } // namespace ck
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
#include "ck/utility/random_gen.hpp" #include "ck/utility/random_gen.hpp"
namespace ck { namespace ck {
// Define the common macro for MI300 models // Define the common macro for gfx94x models
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#define __gfx94__ #define __gfx94__
#endif #endif
......
# ck_tile
## concept
`ck_tile` provides a programming model with templated abstractions to enable users to implement performance-critical kernels for machine learning workloads. introduces following basic concepts to help users building your own operator
- tensor coordinate transformation, this is the core concept of layout/index transform abstraction in both compiler time and run time.
- tile-based programming model, including tile-level api and the concept of distributed tensor.
`ck_tile` is independently from the old ck, located under [/include/ck_tile](/include/ck_tile). You don't need to include anything from old CK, `ck_tile` has similiar (indeed almost the same) implementations for users to build operators. We will have a transition period to pull everything from old ck into `ck_tile`, stay tuned.
## component
`ck_tile` is splitted into several componenets including `core`, `host`, `ops/gemm`, `ops/fmha`... each component you only need to include a single header (e.g `#include "ck_tile/core.hpp"`, `#include "ck_tile/ops/fmha.hpp"`) then you are able to use the function/structure inside (different from old `ck`)
**[core]**
`ck_tile/core` contains all the basic data structure and function to build the kernel, you can only include this header and build your own operators that utilizing all the basic building blocks introduced in ck.
`core/container`
- array, store runtime variables with fixed length (tensor index, register buffer, etc...)
- tuple, same as std::tuple, hold different type of data, and one of the solution to achieve multiple buffer.
- sequence, compile time integer sequence used to build various internal structures, or to describe tile size
- other convenient structure build on top of above 3
`core/numeric`
- gpu data type like `fp16_t`, `bf16_t`, `fp8_t`... and the conversion between each other
- constexpr integer similiar to std::integral_constant to be used as compile time integer.
- math functions and numeric utilities
`core/algorithm`
- coordinate transformation system, used to build tensor transform and compile time indexing. This is the core idea introduced in old `ck` to describe how a tensor is build by several basic transform primitives like `merge`/`unmerge`/`embed` etc... and how we indexing into a ND tensor that finally mapped to 1D memory offset.
`core/tensor`
- tensor descriptor, to describe how a ND tensor
- distributed tensor, describe the storage of this tensor, and the distribution of how a collection of threads collaborately work for this tensor.
- tile level API, including `load_tile`, `store_tile`, `shuffle_tile`, `slice_tile`, etc...
**[host]**
`ck_tile/host` contains all the host side utilities to launch a kernel, create the device buffer, and some reference implementations. This can be used to create examples (like that under ck_tile example folder) and simple executable to invoke this kernel, so if you only need `ck_tile` to build your own device library then it's OK to not include this. Based on this, it is recommended to include the specific header you needed under this folder to avoid including unwanted headers (e.g, only include `ck_tile/host/kernel_launch.hpp`), unless you are writing a host executable.
**[ops/gemm, ops/fmha, ops/reduce...]**
our implementation of different device operators.
- warp, warp tile level operator
- block, block tile level operator
- pipeline, pipeline that can achieve a customized tile level mainloop (or epilogue). By switching different pipeline to the kernel template you can have different kind of pipeline optimizations.
- kernel, template interface for users to instantiate a particular kernel
**[ops/epilogue]**
epilogue part of our kernel. We may extend this epilogue part to let users to build their own cutomized epilogues.
## examples
currently we put all ck_tile related example under [/example/ck_tile](/example/ck_tile/) folder. Please check each example's subfolder.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/algorithm/cluster_descriptor.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/arch/amd_buffer_addressing.hpp"
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/utility.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/container/map.hpp"
#include "ck_tile/core/container/meta_data_buffer.hpp"
#include "ck_tile/core/container/multi_index.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/span.hpp"
#include "ck_tile/core/container/statically_indexed_array.hpp"
#include "ck_tile/core/container/thread_buffer.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include "ck_tile/core/numeric/type_convert.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include "ck_tile/core/tensor/buffer_view.hpp"
#include "ck_tile/core/tensor/load_tile.hpp"
#include "ck_tile/core/tensor/null_tensor.hpp"
#include "ck_tile/core/tensor/null_tile_window.hpp"
#include "ck_tile/core/tensor/shuffle_tile.hpp"
#include "ck_tile/core/tensor/slice_tile.hpp"
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
#include "ck_tile/core/tensor/store_tile.hpp"
#include "ck_tile/core/tensor/sweep_tile.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/tensor/tensor_adaptor_coordinate.hpp"
#include "ck_tile/core/tensor/tensor_coordinate.hpp"
#include "ck_tile/core/tensor/tensor_descriptor.hpp"
#include "ck_tile/core/tensor/tensor_view.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/tensor/tile_distribution_encoding.hpp"
#include "ck_tile/core/tensor/tile_elementwise.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/ignore.hpp"
#include "ck_tile/core/utility/magic_div.hpp"
#include "ck_tile/core/utility/random.hpp"
#include "ck_tile/core/utility/to_sequence.hpp"
#include "ck_tile/core/utility/transpose_vectors.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/utility/unary_element_function.hpp"
# ck_tile/core #
`ck_tile/core` contains every basic functions and structures to create a GPU kernel using `ck_tile`. User should only include `ck_tile/core.hpp` this single header to use all the functionality. Everything is under `ck_tile` namespace. The coding style under this folder should be similar to `std` (`snake_case` for structure/function, Camel for template types...)
```
algorithm/
coordinate transform and some other reusable algorithm
arch/
contains some basic device building block like mma, buffer addressing, etc...
container/
contains basic container data structure, array/sequence/tuple/...
numeric/
data type, and data type related math
tensor/
tensor descriptors and tile level API
utility/
other utility function for both host/device
```
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/algorithm/coordinate_transform.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
template <typename Lengths,
typename ArrangeOrder = typename arithmetic_sequence_gen<0, Lengths::size(), 1>::type>
CK_TILE_HOST_DEVICE constexpr auto make_cluster_descriptor(
const Lengths& lengths,
ArrangeOrder order = typename arithmetic_sequence_gen<0, Lengths::size(), 1>::type{})
{
constexpr index_t ndim_low = Lengths::size();
const auto reordered_lengths = container_reorder_given_new2old(lengths, order);
const auto low_lengths = generate_tuple(
[&](auto idim_low) { return reordered_lengths[idim_low]; }, number<ndim_low>{});
const auto transform = make_merge_transform(low_lengths);
constexpr auto low_dim_old_top_ids = ArrangeOrder{};
constexpr auto up_dim_new_top_ids = sequence<0>{};
return make_single_stage_tensor_adaptor(
make_tuple(transform), make_tuple(low_dim_old_top_ids), make_tuple(up_dim_new_top_ids));
}
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/multi_index.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/utility/magic_div.hpp"
namespace ck_tile {
enum struct coord_transform_enum
{
undefined,
pass_through,
pad,
embed,
merge,
unmerge,
replicate,
xor_t,
offset,
};
template <index_t NDimLow, index_t NDimUp>
struct base_transform
{
CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
{
return coord_transform_enum::undefined;
}
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_lower_dimension() { return NDimLow; }
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_upper_dimension() { return NDimUp; }
// return safe value for vector length/stride, based on compile-time known only
// variables
// MUST be static function
template <typename LowVectorLengths, typename LowVectorStrides>
CK_TILE_HOST_DEVICE static constexpr auto
calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths&,
const LowVectorStrides&)
{
if constexpr(NDimUp > 0)
{
array<index_t, NDimUp> up_vector_lengths{-1};
array<index_t, NDimUp> up_vector_strides{-1};
return make_tuple(up_vector_lengths, up_vector_strides);
}
else
{
return make_tuple(array<index_t, 0>{}, array<index_t, 0>{});
}
}
};
template <typename LowLength>
struct pass_through : public base_transform<1, 1>
{
static constexpr auto type_enum = coord_transform_enum::pass_through;
using LowerIndex = multi_index<1>;
using UpperIndex = multi_index<1>;
using UpLengths = decltype(make_tuple(LowLength{}));
UpLengths up_lengths_;
CK_TILE_HOST_DEVICE constexpr pass_through() = default;
CK_TILE_HOST_DEVICE constexpr pass_through(const LowLength& low_length)
: up_lengths_{make_tuple(low_length)}
{
}
CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
{
return coord_transform_enum::pass_through;
}
CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
template <typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE static constexpr void calculate_lower_index(LowIdx& idx_low,
const UpIdx& idx_up)
{
static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
"wrong! inconsistent # of dimension");
idx_low(number<0>{}) = idx_up[number<0>{}];
}
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE static void update_lower_index(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up,
LowIdx& idx_low,
const UpIdx&)
{
static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
UpIdx::size() == 1,
"wrong! inconsistent # of dimension");
constexpr auto I0 = number<0>{};
idx_diff_low[I0] = idx_diff_up[I0];
idx_low += idx_diff_low;
}
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_always_mapped_to_valid_lower_index()
{
return true;
}
template <typename UpIdx>
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */)
{
return true;
}
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
{
return ck_tile::is_known_at_compile_time<UpLengths>::value;
}
// MUST be static function
template <typename LowVectorLengths, typename LowVectorStrides>
CK_TILE_HOST_DEVICE static constexpr auto
calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths,
const LowVectorStrides& low_vector_strides)
{
return make_tuple(low_vector_lengths, low_vector_strides);
}
CK_TILE_HOST_DEVICE void print() const
{
printf("pass_through{");
//
printf("up_lengths_:");
print(up_lengths_);
//
printf("}");
}
};
template <typename LowLength,
typename LeftPadLength,
typename RightPadLength,
bool SkipIsValidCheck = false>
struct pad : public base_transform<1, 1>
{
using LowerIndex = multi_index<1>;
using UpperIndex = multi_index<1>;
using UpLengths = decltype(make_tuple(LowLength{} + LeftPadLength{} + RightPadLength{}));
UpLengths up_lengths_;
LeftPadLength left_pad_length_;
RightPadLength right_pad_length_;
CK_TILE_HOST_DEVICE constexpr pad() : up_lengths_{}, left_pad_length_{}, right_pad_length_{} {}
CK_TILE_HOST_DEVICE constexpr pad(const LowLength& low_length,
const LeftPadLength& left_pad_length,
const RightPadLength& right_pad_length)
: up_lengths_{make_tuple(low_length + left_pad_length + right_pad_length)},
left_pad_length_{left_pad_length},
right_pad_length_{right_pad_length}
{
}
CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
template <typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
const UpIdx& idx_up) const
{
static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
"wrong! inconsistent # of dimension");
idx_low(number<0>{}) = idx_up[number<0>{}] - left_pad_length_;
}
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE static void update_lower_index(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up,
LowIdx& idx_low,
const UpIdx&)
{
static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
UpIdx::size() == 1,
"wrong! inconsistent # of dimension");
constexpr auto I0 = number<0>{};
idx_diff_low[I0] = idx_diff_up[I0];
idx_low += idx_diff_low;
}
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_always_mapped_to_valid_lower_index()
{
return SkipIsValidCheck;
}
template <typename UpIdx>
CK_TILE_HOST_DEVICE constexpr bool
is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& idx_up) const
{
return SkipIsValidCheck ||
((idx_up[number<0>{}] >= left_pad_length_) &&
(idx_up[number<0>{}] < up_lengths_[number<0>{}] - right_pad_length_));
}
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
{
return ck_tile::is_known_at_compile_time<UpLengths>::value &&
ck_tile::is_known_at_compile_time<LeftPadLength>::value &&
ck_tile::is_known_at_compile_time<RightPadLength>::value;
}
CK_TILE_HOST_DEVICE void print() const
{
printf("pad{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
//
printf("left_pad_length_: ");
print(left_pad_length_);
printf(", ");
//
printf("right_pad_length_: ");
print(right_pad_length_);
printf("}");
}
};
template <typename LowLength, typename LeftPadLength, bool SkipIsValidCheck = false>
struct left_pad
{
using LowerIndex = multi_index<1>;
using UpperIndex = multi_index<1>;
using UpLengths = decltype(make_tuple(LowLength{} + LeftPadLength{}));
UpLengths up_lengths_;
LeftPadLength left_pad_length_;
CK_TILE_HOST_DEVICE constexpr left_pad() = default;
CK_TILE_HOST_DEVICE constexpr left_pad(const LowLength& low_length,
const LeftPadLength& left_pad_length)
: up_lengths_{make_tuple(low_length + left_pad_length)}, left_pad_length_{left_pad_length}
{
}
CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
template <typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
const UpIdx& idx_up) const
{
static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
"wrong! inconsistent # of dimension");
idx_low(number<0>{}) = idx_up[number<0>{}] - left_pad_length_;
}
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE static void update_lower_index(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up,
LowIdx& idx_low,
const UpIdx&)
{
static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
UpIdx::size() == 1,
"wrong! inconsistent # of dimension");
constexpr auto I0 = number<0>{};
idx_diff_low[I0] = idx_diff_up[I0];
idx_low += idx_diff_low;
}
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_always_mapped_to_valid_lower_index()
{
return SkipIsValidCheck;
}
template <typename UpIdx>
CK_TILE_HOST_DEVICE constexpr bool
is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& idx_up) const
{
return SkipIsValidCheck || (idx_up[number<0>{}] >= left_pad_length_);
}
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
{
return ck_tile::is_known_at_compile_time<UpLengths>::value &&
ck_tile::is_known_at_compile_time<LeftPadLength>::value;
}
// MUST be static function
template <typename LowVectorLengths, typename LowVectorStrides>
CK_TILE_HOST_DEVICE static constexpr auto
calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths,
const LowVectorStrides& low_vector_strides)
{
// TODO: we allow pass through this vector length. If one need per-pixel check,
// should change the guaranteed vector length while creating the tensor view.
// It's up to runtime to check the padding length should be multiple of vector length
return make_tuple(low_vector_lengths, low_vector_strides);
}
CK_TILE_HOST_DEVICE void print() const
{
printf("left_pad{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
//
printf("left_pad_length_: ");
print(left_pad_length_);
printf("}");
}
};
template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck = false>
struct right_pad : public base_transform<1, 1>
{
using LowerIndex = multi_index<1>;
using UpperIndex = multi_index<1>;
using UpLengths = decltype(make_tuple(LowLength{} + RightPadLength{}));
UpLengths up_lengths_;
LowLength low_length_;
RightPadLength right_pad_length_;
CK_TILE_HOST_DEVICE constexpr right_pad() = default;
CK_TILE_HOST_DEVICE constexpr right_pad(const LowLength& low_length,
const RightPadLength& right_pad_length)
: up_lengths_{make_tuple(low_length + right_pad_length)},
low_length_{low_length},
right_pad_length_{right_pad_length}
{
}
CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
template <typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE static constexpr void calculate_lower_index(LowIdx& idx_low,
const UpIdx& idx_up)
{
static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
"wrong! inconsistent # of dimension");
idx_low(number<0>{}) = idx_up[number<0>{}];
}
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE static void update_lower_index(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up,
LowIdx& idx_low,
const UpIdx&)
{
static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
UpIdx::size() == 1,
"wrong! inconsistent # of dimension");
constexpr auto I0 = number<0>{};
idx_diff_low[I0] = idx_diff_up[I0];
idx_low += idx_diff_low;
}
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_always_mapped_to_valid_lower_index()
{
return SkipIsValidCheck;
}
template <typename UpIdx>
CK_TILE_HOST_DEVICE constexpr bool
is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& idx_up) const
{
return SkipIsValidCheck || (idx_up[number<0>{}] < low_length_);
}
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
{
return ck_tile::is_known_at_compile_time<UpLengths>::value &&
ck_tile::is_known_at_compile_time<LowLength>::value &&
ck_tile::is_known_at_compile_time<RightPadLength>::value;
}
// MUST be static function
template <typename LowVectorLengths, typename LowVectorStrides>
CK_TILE_HOST_DEVICE static constexpr auto
calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths,
const LowVectorStrides& low_vector_strides)
{
// TODO: we allow pass through this vector length. If one need per-pixel check,
// should change the guaranteed vector length while creating the tensor view.
// It's up to runtime to check the padding length should be multiple of vector length
return make_tuple(low_vector_lengths, low_vector_strides);
}
CK_TILE_HOST_DEVICE void print() const
{
printf("right_pad{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
//
printf("right_pad_length_: ");
print(right_pad_length_);
printf("}");
}
};
// idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1]
// UpLengths and Coefficients can be either of the followings:
// 1) Tuple of index_t, which is known at run-time, or
// 2) Tuple of number, which is known at compile-time, or
// 3) Tuple of mixture of index_t and number, which is known partially at run-time and partially
// at compile-time
template <typename UpLengths,
typename Coefficients,
typename std::enable_if<UpLengths::size() == Coefficients::size(), bool>::type = false>
struct embed : public base_transform<1, UpLengths::size()>
{
static constexpr index_t NDimUp = UpLengths::size();
using LowerIndex = multi_index<1>;
using UpperIndex = multi_index<NDimUp>;
UpLengths up_lengths_;
Coefficients coefficients_;
CK_TILE_HOST_DEVICE constexpr embed() = default;
CK_TILE_HOST_DEVICE constexpr embed(const UpLengths& up_lengths,
const Coefficients& coefficients)
: up_lengths_{up_lengths}, coefficients_{coefficients}
{
}
CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
{
return coord_transform_enum::embed;
}
CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
template <typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
const UpIdx& idx_up) const
{
static_assert(LowIdx::size() == 1 && UpIdx::size() == NDimUp,
"wrong! inconsistent # of dimension");
idx_low(number<0>{}) = 0;
static_for<0, NDimUp, 1>{}([&idx_low, &idx_up, this](auto i) {
idx_low(number<0>{}) += idx_up[i] * this->coefficients_[i];
});
}
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up,
LowIdx& idx_low,
const UpIdx&) const
{
static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == NDimUp &&
LowIdx::size() == 1 && UpIdx::size() == NDimUp,
"wrong! inconsistent # of dimension");
idx_diff_low(number<0>{}) = 0;
static_for<0, NDimUp, 1>{}(
[&](auto i) { idx_diff_low(number<0>{}) += idx_diff_up[i] * coefficients_[i]; });
idx_low += idx_diff_low;
}
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_always_mapped_to_valid_lower_index()
{
return true;
}
template <typename UpIdx>
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */)
{
return true;
}
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
{
return ck_tile::is_known_at_compile_time<UpLengths>::value &&
ck_tile::is_known_at_compile_time<Coefficients>::value;
}
CK_TILE_HOST_DEVICE void print() const
{
printf("embed{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
//
printf("coefficients_: ");
print(coefficients_);
printf("}");
}
};
template <typename LowLengths>
struct lambda_merge_generate_MagicDivision_calculate_magic_divisor
{
template <index_t I>
CK_TILE_HOST_DEVICE constexpr auto operator()(number<I> i) const
{
return magic_division::calculate_magic_numbers(LowLengths{}[i]);
}
};
// Implementation of "merge" transformation primitive that uses magic-number-division to do lowering
// of both multi-index and delta of multi-index
// Caution:
// 1. The magic number division implementation being used would produce correct result if the
// dividended is uint32_t and its value is with in 31-bit value range of uint32_t.
// 2. The magic number division for int32_t dividened has not been implemented, the int32_t
// dividend would be bit-wise interpreted as uint32_t and magic number division implementation for
// uint32_t is then used.
// 3. For merge primitive, upper-index is the dividend.
// 4. When upper-index is uint32_t, its value need to be within 31-bit range.
// 5. When upper-index is int32_t type (when index_t is int32_t), its value need to be
// non-negative.
template <typename LowLengths>
struct merge_v2_magic_division : public base_transform<LowLengths::size(), 1>
{
static constexpr index_t NDimLow = LowLengths::size();
using LowerIndex = multi_index<NDimLow>;
using UpperIndex = multi_index<1>;
using UpLengths =
decltype(make_tuple(container_reduce(LowLengths{}, multiplies{}, number<1>{})));
using LowLengthsMagicDivisor = decltype(generate_tuple(
lambda_merge_generate_MagicDivision_calculate_magic_divisor<LowLengths>{},
number<NDimLow>{}));
LowLengths low_lengths_;
LowLengthsMagicDivisor low_lengths_magic_divisor_;
UpLengths up_lengths_;
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
CK_TILE_HOST_DEVICE constexpr merge_v2_magic_division() = default;
CK_TILE_HOST_DEVICE constexpr merge_v2_magic_division(const LowLengths& low_lengths)
: low_lengths_{low_lengths},
low_lengths_magic_divisor_{generate_tuple(
[&](auto i) { return magic_division::calculate_magic_numbers(low_lengths[i]); },
number<NDimLow>{})},
up_lengths_{make_tuple(container_reduce(low_lengths, multiplies{}, I1))}
{
static_assert(LowerIndex::size() == NDimLow, "wrong!");
}
CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
{
return coord_transform_enum::merge;
}
CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
template <typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
const UpIdx& idx_up) const
{
static_assert(LowIdx::size() == NDimLow && UpIdx::size() == 1,
"wrong! inconsistent # of dimension");
index_t tmp = idx_up[I0];
static_for<NDimLow - 1, 0, -1>{}([&, this](auto i) {
index_t tmp2 =
magic_division::do_magic_division(tmp,
this->low_lengths_magic_divisor_[i][I0],
this->low_lengths_magic_divisor_[i][I1]);
idx_low(i) = tmp - tmp2 * this->low_lengths_[i];
tmp = tmp2;
});
idx_low(number<0>{}) = tmp;
}
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
const UpIdxDiff&,
LowIdx& idx_low,
const UpIdx& idx_up_new) const
{
static_assert(LowIdxDiff::size() == NDimLow && UpIdxDiff::size() == 1 &&
LowIdx::size() == NDimLow && UpIdx::size() == 1,
"wrong! inconsistent # of dimension");
index_t tmp = idx_up_new[number<0>{}];
static_for<NDimLow - 1, 0, -1>{}([&, this](auto i) {
index_t tmp2 =
magic_division::do_magic_division(tmp,
this->low_lengths_magic_divisor_[i][I0],
this->low_lengths_magic_divisor_[i][I1]);
index_t idx_low_old = idx_low[i];
idx_low(i) = tmp - tmp2 * this->low_lengths_[i];
tmp = tmp2;
idx_diff_low(i) = idx_low[i] - idx_low_old;
});
idx_diff_low(number<0>{}) = tmp - idx_low(number<0>{});
idx_low(number<0>{}) = tmp;
}
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_always_mapped_to_valid_lower_index()
{
return true;
}
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
{
return ck_tile::is_known_at_compile_time<LowLengths>::value &&
ck_tile::is_known_at_compile_time<LowLengthsMagicDivisor>::value &&
ck_tile::is_known_at_compile_time<UpLengths>::value;
}
template <typename UpIdx>
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */)
{
return true;
}
// MUST be static function
template <typename LowVectorLengths, typename LowVectorStrides>
CK_TILE_HOST_DEVICE static constexpr auto
calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths,
const LowVectorStrides& low_vector_strides)
{
array<index_t, 1> up_vector_lengths{-1};
array<index_t, 1> up_vector_strides{-1};
up_vector_lengths[0] = low_vector_lengths[number<NDimLow - 1>{}];
up_vector_strides[0] = low_vector_strides[number<NDimLow - 1>{}];
return make_tuple(up_vector_lengths, up_vector_strides);
}
CK_TILE_HOST_DEVICE void print() const
{
printf("merge_v2_magic_division{");
//
printf("low_lengths_ ");
print(low_lengths_);
printf(", ");
//
printf("up_lengths_ ");
print(up_lengths_);
printf("}");
}
};
// Implementation of "merge" transformation primitive that uses division and mod. It is supposed to
// be used for low_lengths that are known at compile time and are power of 2, otherwise performance
// will be very bad
template <typename LowLengths>
struct merge_v3_division_mod : public base_transform<LowLengths::size(), 1>
{
static constexpr index_t NDimLow = LowLengths::size();
using LowerIndex = multi_index<NDimLow>;
using UpperIndex = multi_index<1>;
using LowLengthsScan =
decltype(container_reverse_exclusive_scan(LowLengths{}, multiplies{}, number<1>{}));
using UpLengths =
decltype(make_tuple(container_reduce(LowLengths{}, multiplies{}, number<1>{})));
LowLengths low_lengths_;
LowLengthsScan low_lengths_scan_;
UpLengths up_lengths_;
CK_TILE_HOST_DEVICE constexpr merge_v3_division_mod() = default;
CK_TILE_HOST_DEVICE constexpr merge_v3_division_mod(const LowLengths& low_lengths)
: low_lengths_{low_lengths},
low_lengths_scan_{
container_reverse_exclusive_scan(low_lengths, multiplies{}, number<1>{})},
up_lengths_{make_tuple(container_reduce(low_lengths, multiplies{}, number<1>{}))}
{
static_assert(LowerIndex::size() == NDimLow, "wrong!");
}
CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
template <typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
const UpIdx& idx_up) const
{
static_assert(LowIdx::size() == NDimLow && UpIdx::size() == 1,
"wrong! inconsistent # of dimension");
index_t tmp = idx_up[number<0>{}];
// division and mod
static_for<0, NDimLow - 1, 1>{}([&](auto i) {
idx_low(i) = tmp / this->low_lengths_scan_[i];
tmp %= this->low_lengths_scan_[i];
});
idx_low(number<NDimLow - 1>{}) = tmp;
}
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
const UpIdxDiff&,
LowIdx& idx_low,
const UpIdx& idx_up_new) const
{
static_assert(LowIdxDiff::size() == NDimLow && UpIdxDiff::size() == 1 &&
LowIdx::size() == NDimLow && UpIdx::size() == 1,
"wrong! inconsistent # of dimension");
constexpr auto I0 = number<0>{};
constexpr auto INm1 = number<NDimLow - 1>{};
index_t tmp = idx_up_new[I0];
static_for<0, NDimLow - 1, 1>{}([&](auto i) {
const index_t tmp2 = idx_low[i];
idx_low(i) = tmp / this->low_lengths_scan_[i];
idx_diff_low(i) = idx_low[i] - tmp2;
tmp %= this->low_lengths_scan_[i];
});
const index_t tmp2 = idx_low[INm1];
idx_low(INm1) = tmp;
idx_diff_low(INm1) = idx_low[INm1] - tmp2;
}
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_always_mapped_to_valid_lower_index()
{
return true;
}
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
{
return ck_tile::is_known_at_compile_time<LowLengths>::value &&
ck_tile::is_known_at_compile_time<LowLengthsScan>::value &&
ck_tile::is_known_at_compile_time<UpLengths>::value;
}
template <typename UpIdx>
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */)
{
return true;
}
// MUST be static function
template <typename LowVectorLengths, typename LowVectorStrides>
CK_TILE_HOST_DEVICE static constexpr auto
calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths,
const LowVectorStrides& low_vector_strides)
{
array<index_t, 1> up_vector_lengths{-1};
array<index_t, 1> up_vector_strides{-1};
up_vector_lengths[0] = low_vector_lengths[number<NDimLow - 1>{}];
up_vector_strides[0] = low_vector_strides[number<NDimLow - 1>{}];
return make_tuple(up_vector_lengths, up_vector_strides);
}
CK_TILE_HOST_DEVICE void print() const
{
printf("Merge_v3_direct_division_mod{");
//
printf("low_lengths_ ");
print(low_lengths_);
printf(", ");
//
printf("low_lengths_scan_ ");
print(low_lengths_scan_);
printf(", ");
//
printf("up_lengths_ ");
print(up_lengths_);
printf("}");
}
};
template <typename UpLengths, bool Use24BitIntegerCalculation>
struct unmerge : public base_transform<1, UpLengths::size()>
{
static constexpr index_t NDimUp = UpLengths::size();
using LowerIndex = multi_index<1>;
using UpperIndex = multi_index<NDimUp>;
using UpLengthsScan =
decltype(container_reverse_exclusive_scan(UpLengths{}, multiplies{}, number<1>{}));
UpLengths up_lengths_;
UpLengthsScan up_lengths_scan_;
CK_TILE_HOST_DEVICE constexpr unmerge() = default;
CK_TILE_HOST_DEVICE constexpr unmerge(const UpLengths& up_lengths)
: up_lengths_{up_lengths},
up_lengths_scan_{container_reverse_exclusive_scan(up_lengths, multiplies{}, number<1>{})}
{
}
CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
{
return coord_transform_enum::unmerge;
}
CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
template <typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
const UpIdx& idx_up) const
{
if constexpr(!Use24BitIntegerCalculation)
{
idx_low(number<0>{}) = idx_up[number<NDimUp - 1>{}];
static_for<0, NDimUp - 1, 1>{}(
[&](auto i) { idx_low(number<0>{}) += idx_up[i] * up_lengths_scan_[i]; });
}
else
{
idx_low(number<0>{}) = idx_up[number<NDimUp - 1>{}];
static_for<0, NDimUp - 1, 1>{}([&](auto i) {
idx_low(number<0>{}) =
(0x00ffffff & idx_low[number<0>{}]) +
(0x00ffffff & idx_up[i]) * (0x00ffffff & up_lengths_scan_[i]);
});
}
}
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up,
LowIdx& idx_low,
const UpIdx&) const
{
calculate_lower_index(idx_diff_low, idx_diff_up);
idx_low += idx_diff_low;
}
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_always_mapped_to_valid_lower_index()
{
return true;
}
template <typename UpIdx>
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */)
{
return true;
}
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
{
return ck_tile::is_known_at_compile_time<UpLengths>::value &&
ck_tile::is_known_at_compile_time<UpLengthsScan>::value;
}
// MUST be static function
template <typename LowVectorLengths, typename LowVectorStrides>
CK_TILE_HOST_DEVICE static constexpr auto
calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths,
const LowVectorStrides& low_vector_strides)
{
array<index_t, NDimUp> up_vector_lengths{-1};
array<index_t, NDimUp> up_vector_strides{-1};
constexpr auto up_length_last = UpLengths{}[number<NDimUp - 1>{}];
if constexpr(ck_tile::is_known_at_compile_time<decltype(up_length_last)>::value)
{
if(low_vector_lengths[0] != -1)
{
up_vector_lengths(NDimUp - 1) = gcd(low_vector_lengths[0], up_length_last);
}
}
up_vector_strides(NDimUp - 1) = low_vector_strides[0];
return make_tuple(up_vector_lengths, up_vector_strides);
}
CK_TILE_HOST_DEVICE void print() const
{
printf("unmerge{");
//
printf("up_lengths_");
print(up_lengths_);
printf(", ");
//
printf("up_lengths_scan_");
print(up_lengths_scan_);
printf("}");
}
};
template <typename LowerIndex>
struct freeze : public base_transform<1, 0>
{
LowerIndex low_idx_;
CK_TILE_HOST_DEVICE constexpr freeze() = default;
CK_TILE_HOST_DEVICE constexpr freeze(const LowerIndex& low_idx) : low_idx_{low_idx} {}
CK_TILE_HOST_DEVICE static constexpr auto get_upper_lengths() { return tuple<>{}; }
template <typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
const UpIdx& /* idx_up */) const
{
static_assert(LowIdx::size() == 1 && UpIdx::size() == 0,
"wrong! inconsistent # of dimension");
idx_low(number<0>{}) = low_idx_;
}
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE static void update_lower_index(LowIdxDiff& idx_diff_low,
const UpIdxDiff& /* idx_diff_up */,
LowIdx& /* idx_low */,
const UpIdx& /* idx_up_new */)
{
idx_diff_low(number<0>{}) = 0;
}
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_always_mapped_to_valid_lower_index()
{
return true;
}
template <typename UpIdx>
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */)
{
return true;
}
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
{
return ck_tile::is_known_at_compile_time<LowerIndex>::value;
}
CK_TILE_HOST_DEVICE void print() const
{
printf("freeze{");
//
printf("low_idx_: ");
print(low_idx_);
printf("}");
}
};
// insert a dangling upper dimension without lower dimension
template <typename UpperLength>
struct insert : public base_transform<0, 1>
{
using UpLengths = decltype(make_tuple(UpperLength{}));
UpLengths up_lengths_;
CK_TILE_HOST_DEVICE constexpr insert() = default;
CK_TILE_HOST_DEVICE constexpr insert(const UpperLength& up_length)
: up_lengths_{make_tuple(up_length)}
{
}
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_lower_dimension() { return 0; }
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_upper_dimension() { return 1; }
CK_TILE_HOST_DEVICE constexpr auto get_upper_lengths() const { return up_lengths_; }
template <typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx&, const UpIdx&) const
{
static_assert(LowIdx::size() == 0 && UpIdx::size() == 1,
"wrong! inconsistent # of dimension");
}
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE static void
update_lower_index(LowIdxDiff&, const UpIdxDiff&, LowIdx&, const UpIdx&)
{
static_assert(LowIdxDiff::size() == 0 && UpIdxDiff::size() == 1 && LowIdx::size() == 0 &&
UpIdx::size() == 1,
"wrong! inconsistent # of dimension");
}
CK_TILE_HOST_DEVICE static constexpr bool IsLinearTransform() { return true; }
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_always_mapped_to_valid_lower_index()
{
return true;
}
template <typename UpIdx>
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */)
{
return true;
}
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
{
return ck_tile::is_known_at_compile_time<UpperLength>::value;
}
CK_TILE_HOST_DEVICE void print() const
{
printf("insert{");
//
print(up_lengths_);
printf("}");
}
};
// replicate the original tensor and create a higher dimensional tensor
template <typename UpLengths>
struct replicate : public base_transform<0, UpLengths::size()>
{
static constexpr index_t NDimUp = UpLengths::size();
CK_TILE_HOST_DEVICE constexpr replicate() = default;
CK_TILE_HOST_DEVICE constexpr replicate(const UpLengths& up_lengths) : up_lengths_{up_lengths}
{
}
CK_TILE_HOST_DEVICE constexpr auto get_upper_lengths() const { return up_lengths_; }
template <typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx&, const UpIdx&) const
{
static_assert(LowIdx::size() == 0 && UpIdx::size() == NDimUp,
"wrong! inconsistent # of dimension");
}
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE static void
update_lower_index(LowIdxDiff&, const UpIdxDiff&, LowIdx&, const UpIdx&)
{
static_assert(LowIdxDiff::size() == 0 && UpIdxDiff::size() == NDimUp &&
LowIdx::size() == 0 && UpIdx::size() == NDimUp,
"wrong! inconsistent # of dimension");
}
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_always_mapped_to_valid_lower_index()
{
return true;
}
template <typename UpIdx>
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */)
{
return true;
}
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
{
return ck_tile::is_known_at_compile_time<UpLengths>::value;
}
CK_TILE_HOST_DEVICE void print() const
{
printf("replicate{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf("}");
}
//
UpLengths up_lengths_;
};
template <typename LowLength, typename SliceBegin, typename SliceEnd>
struct slice : public base_transform<1, 1>
{
using LowerIndex = multi_index<1>;
using UpperIndex = multi_index<1>;
using UpLengths = decltype(make_tuple(SliceEnd{} - SliceBegin{}));
UpLengths up_lengths_;
SliceBegin slice_begin_;
SliceEnd slice_end_;
CK_TILE_HOST_DEVICE constexpr slice() = default;
CK_TILE_HOST_DEVICE constexpr slice(const LowLength&,
const SliceBegin& slice_begin,
const SliceEnd& slice_end)
: up_lengths_{make_tuple(slice_end - slice_begin)},
slice_begin_{slice_begin},
slice_end_{slice_end}
{
}
CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
template <typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
const UpIdx& idx_up) const
{
static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
"wrong! inconsistent # of dimension");
idx_low(number<0>{}) = idx_up[number<0>{}] + slice_begin_;
}
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE static void update_lower_index(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up,
LowIdx& idx_low,
const UpIdx&)
{
static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
UpIdx::size() == 1,
"wrong! inconsistent # of dimension");
constexpr auto I0 = number<0>{};
idx_diff_low[I0] = idx_diff_up[I0];
idx_low += idx_diff_low;
}
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_always_mapped_to_valid_lower_index()
{
return true;
}
template <typename UpIdx>
CK_TILE_HOST_DEVICE constexpr bool
is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx&) const
{
return true;
}
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
{
return ck_tile::is_known_at_compile_time<UpLengths>::value &&
ck_tile::is_known_at_compile_time<SliceBegin>::value &&
ck_tile::is_known_at_compile_time<SliceEnd>::value;
}
CK_TILE_HOST_DEVICE void print() const
{
printf("slice{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
//
printf("slice_begin_: ");
print(slice_begin_);
printf(", ");
//
printf("slice_end_: ");
print(slice_end_);
printf("}");
} // namespace ck
}; // namespace ck
/*
* \brief lower_idx = upper_idx % modulus.
* TODO: Need an improved implementation since the modulo operation is expensive.
*/
template <typename Modulus, typename UpLength>
struct modulo : public base_transform<1, 1>
{
using LowerIndex = multi_index<1>;
using UpperIndex = multi_index<1>;
using UpLengths = decltype(make_tuple(UpLength{}));
Modulus modulus_;
UpLengths up_lengths_;
CK_TILE_HOST_DEVICE constexpr modulo() = default;
CK_TILE_HOST_DEVICE constexpr modulo(const Modulus& modulus, const UpLength& up_length)
: modulus_{modulus}, up_lengths_{make_tuple(up_length)}
{
}
CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
template <typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
const UpIdx& idx_up) const
{
static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
"wrong! inconsistent # of dimension");
idx_low(number<0>{}) = idx_up[number<0>{}] % modulus_;
}
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up,
LowIdx& idx_low,
const UpIdx& up_idx) const
{
static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
UpIdx::size() == 1,
"wrong! inconsistent # of dimension");
constexpr auto I0 = number<0>{};
const auto idx_low_old = idx_low;
idx_low[I0] = (up_idx[I0] + idx_diff_up[I0]) % modulus_;
idx_diff_low[I0] = idx_low - idx_low_old;
}
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_always_mapped_to_valid_lower_index()
{
return true;
}
template <typename UpIdx>
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */)
{
return true;
}
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
{
return ck_tile::is_known_at_compile_time<UpLengths>::value;
}
CK_TILE_HOST_DEVICE void print() const
{
printf("Modulus{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf("}");
}
};
// 2D XOR, NOTE: "xor" is a keyword
template <typename LowLengths, typename RightShift>
struct xor_t : public base_transform<2, 2>
{
static constexpr auto type_enum = coord_transform_enum::xor_t;
using LowerIndex = multi_index<2>;
using UpperIndex = multi_index<2>;
using UpLengths = LowLengths;
UpLengths up_lengths_;
RightShift right_shift_;
CK_TILE_HOST_DEVICE constexpr xor_t() : up_lengths_{}, right_shift_{} {}
CK_TILE_HOST_DEVICE constexpr xor_t(const LowLengths& low_lengths,
const RightShift& right_shift)
: up_lengths_{low_lengths}, right_shift_{right_shift}
{
}
CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
{
return coord_transform_enum::xor_t;
}
CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
template <typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
const UpIdx& idx_up) const
{
static_assert(LowIdx::size() == 2 && UpIdx::size() == 2,
"wrong! inconsistent # of dimension");
idx_low(number<0>{}) = idx_up[number<0>{}];
const auto idx_low_1_tmp =
(idx_up[number<1>{}] - idx_up[number<0>{}] * right_shift_) % up_lengths_[number<1>{}];
const auto idx_low_1 =
(idx_low_1_tmp >= 0) ? idx_low_1_tmp : up_lengths_[number<1>{}] + idx_low_1_tmp;
idx_low(number<1>{}) = idx_low_1;
}
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
const UpIdxDiff&,
LowIdx& idx_low,
const UpIdx& idx_up) const
{
static_assert(LowIdxDiff::size() == 2 && UpIdxDiff::size() == 2 && LowIdx::size() == 2 &&
UpIdx::size() == 2,
"wrong! inconsistent # of dimension");
const auto idx_low_old = idx_low;
calculate_lower_index(idx_low, idx_up);
idx_diff_low = idx_low - idx_low_old;
}
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_always_mapped_to_valid_lower_index()
{
return true;
}
template <typename UpIdx>
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */)
{
return true;
}
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
{
return ck_tile::is_known_at_compile_time<UpLengths>::value &&
ck_tile::is_known_at_compile_time<RightShift>::value;
}
// MUST be static function
template <typename LowVectorLengths, typename LowVectorStrides>
CK_TILE_HOST_DEVICE constexpr auto calculate_upper_dimension_safe_vector_length_strides(
const LowVectorLengths& low_vector_lengths,
const LowVectorStrides& low_vector_strides) const
{
array<index_t, 2> up_vector_lengths = low_vector_lengths;
array<index_t, 2> up_vector_strides = low_vector_strides;
if constexpr(ck_tile::is_known_at_compile_time<RightShift>::value)
{
if(low_vector_lengths[1] != -1)
{
up_vector_lengths(1) = gcd(low_vector_lengths[1], abs(right_shift_));
}
}
return make_tuple(up_vector_lengths, up_vector_strides);
}
CK_TILE_HOST_DEVICE void print() const
{
printf("xor_t{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
//
printf("right_shift_: ");
print(right_shift_);
printf("}");
}
};
template <typename LowLength, typename OffsetLength>
struct offset : public base_transform<1, 1>
{
using LowerIndex = multi_index<1>;
using UpperIndex = multi_index<1>;
using UpLengths = decltype(make_tuple(LowLength{}));
UpLengths up_lengths_;
OffsetLength offset_length_;
CK_TILE_HOST_DEVICE constexpr offset() = default;
CK_TILE_HOST_DEVICE constexpr offset(const LowLength& low_length,
const OffsetLength& offset_length)
: up_lengths_{make_tuple(low_length)}, offset_length_{offset_length}
{
}
CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
{
return coord_transform_enum::offset;
}
CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
template <typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
const UpIdx& idx_up) const
{
static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
"wrong! inconsistent # of dimension");
idx_low(number<0>{}) = idx_up[number<0>{}] + offset_length_;
}
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
CK_TILE_HOST_DEVICE static void update_lower_index(LowIdxDiff& idx_diff_low,
const UpIdxDiff& idx_diff_up,
LowIdx& idx_low,
const UpIdx&)
{
static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
UpIdx::size() == 1,
"wrong! inconsistent # of dimension");
constexpr auto I0 = number<0>{};
idx_diff_low[I0] = idx_diff_up[I0];
idx_low += idx_diff_low;
}
CK_TILE_HOST_DEVICE static constexpr bool
is_valid_upper_index_always_mapped_to_valid_lower_index()
{
return true;
}
template <typename UpIdx>
CK_TILE_HOST_DEVICE constexpr bool
is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx&) const
{
return true;
}
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
{
return ck_tile::is_known_at_compile_time<UpLengths>::value &&
ck_tile::is_known_at_compile_time<OffsetLength>::value;
}
CK_TILE_HOST_DEVICE void print() const
{
printf("offset{");
//
printf("up_lengths_: ");
print(up_lengths_);
printf(", ");
//
printf("offset_length_: ");
print(offset_length_);
printf("}");
}
};
//*******************************************************************************************************
template <typename LowLength>
CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength& low_length)
{
return pass_through<LowLength>{low_length};
}
template <typename LowLength, typename LeftPad, typename RightPad, bool SkipIsValidCheck = false>
CK_TILE_HOST_DEVICE constexpr auto
make_pad_transform(const LowLength& low_length,
const LeftPad& left_pad,
const RightPad& right_pad,
bool_constant<SkipIsValidCheck> = bool_constant<false>{})
{
return pad<LowLength, LeftPad, RightPad, SkipIsValidCheck>{low_length, left_pad, right_pad};
}
template <typename LowLength, typename LeftPadLength, bool SkipIsValidCheck = false>
CK_TILE_HOST_DEVICE constexpr auto
make_left_pad_transform(const LowLength& low_length,
const LeftPadLength& left_pad_,
bool_constant<SkipIsValidCheck> = bool_constant<false>{})
{
return left_pad<LowLength, LeftPadLength, SkipIsValidCheck>{low_length, left_pad_};
}
template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck = false>
CK_TILE_HOST_DEVICE constexpr auto
make_right_pad_transform(const LowLength& low_length,
const RightPadLength& right_pad_,
bool_constant<SkipIsValidCheck> = bool_constant<false>{})
{
return right_pad<LowLength, RightPadLength, SkipIsValidCheck>{low_length, right_pad_};
}
template <typename UpLengths,
typename Coefficients,
typename std::enable_if<UpLengths::size() == Coefficients::size(), bool>::type = false>
CK_TILE_HOST_DEVICE constexpr auto make_embed_transform(const UpLengths& up_lengths,
const Coefficients& coefficients)
{
return embed<UpLengths, Coefficients>{up_lengths, coefficients};
}
template <typename LowLengths>
CK_TILE_HOST_DEVICE constexpr auto
make_merge_transform_v2_magic_division(const LowLengths& low_lengths)
{
return merge_v2_magic_division<LowLengths>{low_lengths};
}
template <typename LowLengths>
CK_TILE_HOST_DEVICE constexpr auto
make_merge_transform_v3_division_mod(const LowLengths& low_lengths)
{
return merge_v3_division_mod<LowLengths>{low_lengths};
}
template <typename LowLengths>
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform(const LowLengths& low_lengths)
{
return make_merge_transform_v2_magic_division(low_lengths);
}
template <typename UpLengths, bool Use24BitIntegerCalculation = false>
CK_TILE_HOST_DEVICE constexpr auto
make_unmerge_transform(const UpLengths& up_lengths,
bool_constant<Use24BitIntegerCalculation> = bool_constant<false>{})
{
return unmerge<UpLengths, Use24BitIntegerCalculation>{up_lengths};
}
template <typename LowerIndex>
CK_TILE_HOST_DEVICE constexpr auto make_freeze_transform(const LowerIndex& low_idx)
{
return freeze<LowerIndex>{low_idx};
}
template <typename UpperIndex>
CK_TILE_HOST_DEVICE constexpr auto make_insert_transform(const UpperIndex& up_idx)
{
return insert<UpperIndex>{up_idx};
}
template <typename UpLengths>
CK_TILE_HOST_DEVICE constexpr auto make_replicate_transform(const UpLengths& up_lengths)
{
return replicate<UpLengths>{up_lengths};
}
template <typename LowLength, typename SliceBegin, typename SliceEnd>
CK_TILE_HOST_DEVICE constexpr auto make_slice_transform(const LowLength& low_length,
const SliceBegin& slice_begin,
const SliceEnd& slice_end)
{
return slice<LowLength, SliceBegin, SliceEnd>{low_length, slice_begin, slice_end};
}
template <typename Modulus, typename UpLength>
CK_TILE_HOST_DEVICE constexpr auto make_modulo_transform(const Modulus& modulus,
const UpLength& up_length)
{
return modulo<Modulus, UpLength>{modulus, up_length};
}
template <typename LowLengths, typename RightShift>
CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths& low_lengths,
const RightShift& right_shift)
{
return xor_t<LowLengths, RightShift>{low_lengths, right_shift};
}
template <typename LowLength, typename OffsetLength>
CK_TILE_HOST_DEVICE constexpr auto make_offset_transform(const LowLength& low_length,
const OffsetLength& offset_length)
{
return offset<LowLength, OffsetLength>{low_length, offset_length};
}
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/multi_index.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/container/statically_indexed_array.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
template <typename TensorLengths,
typename DimAccessOrder,
typename ScalarsPerAccess,
bool SnakeCurved = true> // # of scalars per access in each dimension
struct space_filling_curve
{
static constexpr index_t TensorSize =
reduce_on_sequence(TensorLengths{}, multiplies{}, number<1>{});
static_assert(0 < TensorSize,
"space_filling_curve should be used to access a non-empty tensor");
static constexpr index_t nDim = TensorLengths::size();
using Index = multi_index<nDim>;
static constexpr index_t ScalarPerVector =
reduce_on_sequence(ScalarsPerAccess{}, multiplies{}, number<1>{});
static constexpr auto access_lengths = TensorLengths{} / ScalarsPerAccess{};
static constexpr auto dim_access_order = DimAccessOrder{};
static constexpr auto ordered_access_lengths =
container_reorder_given_new2old(access_lengths, dim_access_order);
static constexpr auto to_index_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_merge_transform(ordered_access_lengths)),
make_tuple(typename arithmetic_sequence_gen<0, nDim, 1>::type{}),
make_tuple(sequence<0>{}));
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_access()
{
static_assert(TensorLengths::size() == ScalarsPerAccess::size());
static_assert(TensorLengths{} % ScalarsPerAccess{} ==
typename uniform_sequence_gen<TensorLengths::size(), 0>::type{});
return reduce_on_sequence(TensorLengths{}, multiplies{}, number<1>{}) / ScalarPerVector;
}
template <index_t AccessIdx1dHead, index_t AccessIdx1dTail>
static CK_TILE_HOST_DEVICE constexpr auto get_step_between(number<AccessIdx1dHead>,
number<AccessIdx1dTail>)
{
static_assert(AccessIdx1dHead >= 0 && AccessIdx1dHead < get_num_of_access(),
"1D index out of range");
static_assert(AccessIdx1dTail >= 0 && AccessIdx1dTail < get_num_of_access(),
"1D index out of range");
constexpr auto idx_head = get_index(number<AccessIdx1dHead>{});
constexpr auto idx_tail = get_index(number<AccessIdx1dTail>{});
return idx_tail - idx_head;
}
template <index_t AccessIdx1d>
static CK_TILE_HOST_DEVICE constexpr auto get_forward_step(number<AccessIdx1d>)
{
static_assert(AccessIdx1d < get_num_of_access(), "1D index should be larger than 0");
return get_step_between(number<AccessIdx1d>{}, number<AccessIdx1d + 1>{});
}
template <index_t AccessIdx1d>
static CK_TILE_HOST_DEVICE constexpr auto get_backward_step(number<AccessIdx1d>)
{
static_assert(AccessIdx1d > 0, "1D index should be larger than 0");
return get_step_between(number<AccessIdx1d>{}, number<AccessIdx1d - 1>{});
}
template <index_t AccessIdx1d>
static CK_TILE_HOST_DEVICE constexpr Index get_index(number<AccessIdx1d>)
{
#if 0
/*
* \todo: tensor_adaptor::calculate_bottom_index does NOT return constexpr as expected.
*/
constexpr auto ordered_access_idx = to_index_adaptor.calculate_bottom_index(make_multi_index(number<AccessIdx1d>{}));
#else
constexpr auto access_strides =
container_reverse_exclusive_scan(ordered_access_lengths, multiplies{}, number<1>{});
constexpr auto idx_1d = number<AccessIdx1d>{};
// Given tensor strides \p access_lengths, and 1D index of space-filling-curve, compute the
// idim-th element of multidimensional index.
// All constexpr variables have to be captured by VALUE.
constexpr auto compute_index = [ idx_1d, access_strides ](auto idim) constexpr
{
constexpr auto compute_index_impl = [ idx_1d, access_strides ](auto jdim) constexpr
{
auto res = idx_1d.value;
auto id = 0;
static_for<0, jdim.value + 1, 1>{}([&](auto kdim) {
id = res / access_strides[kdim].value;
res -= id * access_strides[kdim].value;
});
return id;
};
constexpr auto id = compute_index_impl(idim);
return number<id>{};
};
constexpr auto ordered_access_idx = generate_tuple(compute_index, number<nDim>{});
#endif
constexpr auto forward_sweep = [&]() {
statically_indexed_array<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto idim) {
index_t tmp = ordered_access_idx[I0];
static_for<1, idim, 1>{}(
[&](auto j) { tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j]; });
forward_sweep_(idim) = tmp % 2 == 0;
});
return forward_sweep_;
}();
// calculate multi-dim tensor index
auto idx_md = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto idim) {
ordered_idx(idim) =
!SnakeCurved || forward_sweep[idim]
? ordered_access_idx[idim]
: ordered_access_lengths[idim] - 1 - ordered_access_idx[idim];
});
return container_reorder_given_old2new(ordered_idx, dim_access_order) *
ScalarsPerAccess{};
}();
return idx_md;
}
// FIXME: rename this function
template <index_t AccessIdx1d>
static CK_TILE_HOST_DEVICE constexpr auto get_index_tuple_of_number(number<AccessIdx1d>)
{
constexpr auto idx = get_index(number<AccessIdx1d>{});
return generate_tuple([&](auto i) { return number<idx[i]>{}; }, number<nDim>{});
}
};
} // namespace ck_tile
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment