Unverified Commit 0e92deb7 authored by Chao Liu, committed by GitHub

Tile program init bulk PR (#4)



Tile Program init bulk PR

---------
Co-authored-by: zjing14 <zhangjing14@gmail.com>
Co-authored-by: Po-Yen, Chen <PoYen.Chen@amd.com>
parent 0077eeb3
@@ -51,32 +51,11 @@ __host__ __device__ constexpr auto make_embed_transform(const UpLengths& up_leng
return Embed<UpLengths, Coefficients>{up_lengths, coefficients};
}
template <typename LowLengths>
__host__ __device__ constexpr auto make_merge_transform(const LowLengths& low_lengths)
{
#if CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION
return make_merge_transform_v2_magic_division(low_lengths);
#else
return make_merge_transform_v1_carry_check(low_lengths);
#endif
}
template <typename LowLengths>
__host__ __device__ constexpr auto
make_merge_transform_v1_carry_check(const LowLengths& low_lengths)
{
return Merge_v1_carry_check<LowLengths>{low_lengths};
}
template <typename LowLengths>
__host__ __device__ constexpr auto
make_merge_transform_v2_magic_division(const LowLengths& low_lengths)
{
#if 1
return Merge_v2_magic_division<LowLengths>{low_lengths};
#else
return Merge_v2r2_magic_division<LowLengths>{low_lengths};
#endif
}
template <typename LowLengths>
@@ -86,6 +65,12 @@ make_merge_transform_v3_division_mod(const LowLengths& low_lengths)
return Merge_v3_division_mod<LowLengths>{low_lengths};
}
template <typename LowLengths>
__host__ __device__ constexpr auto make_merge_transform(const LowLengths& low_lengths)
{
return make_merge_transform_v2_magic_division(low_lengths);
}
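For orientation (not part of this diff): a merge transform folds several lower dimensions into a single upper dimension. A minimal usage sketch, assuming the usual CK descriptor helpers:
// Hypothetical sketch: view a packed (4, 8) descriptor as one flattened
// dimension of length 32 via the merge transform.
const auto desc_m_n = make_naive_tensor_descriptor_packed(make_tuple(Number<4>{}, Number<8>{}));
const auto desc_mn  = transform_tensor_descriptor(
    desc_m_n,
    make_tuple(make_merge_transform(make_tuple(Number<4>{}, Number<8>{}))),
    make_tuple(Sequence<0, 1>{}), // lower dimensions being merged
    make_tuple(Sequence<0>{}));   // resulting upper dimension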
template <typename UpLengths, bool Use24BitIntegerCalculation = false>
__host__ __device__ constexpr auto make_unmerge_transform(
const UpLengths& up_lengths,
@@ -100,10 +85,10 @@ __host__ __device__ constexpr auto make_freeze_transform(const LowerIndex& low_i
return Freeze<LowerIndex>{low_idx};
}
template <typename UpperIndex>
__host__ __device__ constexpr auto make_insert_transform(const UpperIndex& up_idx)
template <typename UpLengths>
__host__ __device__ constexpr auto make_replicate_transform(const UpLengths& up_lengths)
{
return Insert<UpperIndex>{up_idx};
return Replicate<UpLengths>{up_lengths};
}
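make_replicate_transform (which replaces make_insert_transform here) creates upper dimensions that map to no lower dimension, which is why its callers later in this PR pass an empty lower Sequence<>. A minimal sketch, assuming make_single_stage_tensor_adaptor as used below:
// Hypothetical sketch: prepend a dummy length-1 upper dimension that is
// backed by no lower dimension at all (note the empty Sequence<>).
const auto adaptor = make_single_stage_tensor_adaptor(
    make_tuple(make_replicate_transform(make_tuple(1)),
               make_pass_through_transform(n)), // n: assumed existing length
    make_tuple(Sequence<>{}, Sequence<0>{}),
    make_tuple(Sequence<0>{}, Sequence<1>{}));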
template <typename LowLength, typename SliceBegin, typename SliceEnd>
@@ -114,17 +99,18 @@ __host__ __device__ constexpr auto make_slice_transform(const LowLength& low_len
return Slice<LowLength, SliceBegin, SliceEnd>{low_length, slice_begin, slice_end};
}
template <typename VectorSize, typename UpLength>
__host__ __device__ constexpr auto make_vectorize_transform(const VectorSize& vector_size,
const UpLength& up_length)
{
return Vectorize<VectorSize, UpLength>{vector_size, up_length};
}
template <typename Modulus, typename UpLength>
__host__ __device__ constexpr auto make_modulo_transform(const Modulus& modulus,
const UpLength& up_length)
{
return Modulo<Modulus, UpLength>{modulus, up_length};
}
template <typename LowLengths, typename RightShift>
__host__ __device__ constexpr auto make_xor_transform(const LowLengths& low_lengths,
const RightShift& right_shift)
{
return Xor<LowLengths, RightShift>{low_lengths, right_shift};
}
} // namespace ck
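The xor transform declared above is typically used for swizzled shared-memory layouts; the following is a sketch of assumed usage, not code from this PR:
// Hypothetical sketch: XOR-swizzle the column index by the (shifted) row
// index to spread accesses across LDS banks; the transform is assumed to
// be 2-D on both its lower and upper side.
const auto desc_swizzled = transform_tensor_descriptor(
    desc_row_col, // assumed pre-existing 2-D (row, col) descriptor
    make_tuple(make_xor_transform(make_tuple(n_row, n_col), right_shift)),
    make_tuple(Sequence<0, 1>{}),
    make_tuple(Sequence<0, 1>{}));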
This diff is collapsed.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_adaptor_coordinate.hpp"
namespace ck {
template <index_t NDimHidden, typename TopDimensionHiddenIds>
struct TensorCoordinate
: public TensorAdaptorCoordinate<NDimHidden, Sequence<0>, TopDimensionHiddenIds>
{
using Base = TensorAdaptorCoordinate<NDimHidden, Sequence<0>, TopDimensionHiddenIds>;
// TODO make these private
static constexpr index_t ndim_top_ = TopDimensionHiddenIds::Size();
using HiddenIndex = MultiIndex<NDimHidden>;
using TopIndex = MultiIndex<ndim_top_>;
public:
__host__ __device__ constexpr TensorCoordinate() = default;
__host__ __device__ constexpr TensorCoordinate(const HiddenIndex& idx_hidden) : Base{idx_hidden}
{
}
// construct from TensorAdaptorCoordinate base class
__host__ __device__ constexpr TensorCoordinate(const Base& adaptor_coord) : Base{adaptor_coord}
{
}
__host__ __device__ constexpr auto GetIndex() const { return Base::GetTopIndex(); }
__host__ __device__ constexpr index_t GetOffset() const
{
return Base::GetBottomIndex()[Number<0>{}];
}
__host__ __device__ constexpr const auto& GetHiddenIndex() const
{
return Base::GetHiddenIndex();
}
__host__ __device__ auto& GetHiddenIndex() { return Base::GetHiddenIndex(); }
};
template <typename TensorDesc, typename TopIndex>
__host__ __device__ constexpr auto make_tensor_coordinate(const TensorDesc& tensor_desc,
const TopIndex& idx_top)
{
const auto adaptor_coord = make_tensor_adaptor_coordinate(tensor_desc, idx_top);
return TensorCoordinate<TensorDesc::GetNumOfHiddenDimension(),
remove_cvref_t<decltype(TensorDesc::GetTopDimensionHiddenIds())>>{
adaptor_coord};
}
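A minimal usage sketch for the coordinate API above, assuming the packed naive-descriptor helper from CK:
// Hypothetical sketch: place a coordinate at (1, 2) in a packed 4x8
// descriptor; the bottom (linear) offset is then 1 * 8 + 2 = 10.
const auto desc      = make_naive_tensor_descriptor_packed(make_tuple(4, 8));
auto coord           = make_tensor_coordinate(desc, make_multi_index(1, 2));
const index_t offset = coord.GetOffset();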
template <bool JudgeDoTransforms = true, typename TensorDesc, typename TensorCoord, typename Index>
__host__ __device__ constexpr void
move_tensor_coordinate(const TensorDesc& tensor_desc, TensorCoord& coord, const Index& coord_step)
{
move_tensor_adaptor_coordinate(tensor_desc, coord, coord_step);
}
template <typename TensorDesc, typename TensorCoord>
__host__ __device__ constexpr bool
coordinate_has_valid_offset_assuming_top_index_is_valid(const TensorDesc& tensor_desc,
const TensorCoord& coord)
{
return adaptor_coordinate_is_valid_assuming_top_index_is_valid(tensor_desc, coord);
}
template <typename TensorDesc, typename TensorCoord>
__host__ __device__ constexpr bool coordinate_has_valid_offset(const TensorDesc& tensor_desc,
const TensorCoord& coord)
{
return adaptor_coordinate_is_valid(tensor_desc, coord);
}
} // namespace ck
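Continuing the sketch above, stepping and validating a coordinate with the helpers just defined:
// Hypothetical sketch: advance the coordinate by (0, 1), then verify the
// new offset is in bounds before touching memory.
move_tensor_coordinate(desc, coord, make_multi_index(0, 1));
if(coordinate_has_valid_offset(desc, coord))
{
    // safe to load/store at coord.GetOffset()
}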
@@ -6,7 +6,7 @@
#include "ck/utility/math.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/utility/sequence_helper.hpp"
#include "ck/utility/statically_indexed_array_multi_index.hpp"
#include "ck/utility/multi_index.hpp"
#include "ck/utility/tuple_helper.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
......
@@ -87,7 +87,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex();
// |KRepeat |MRepeat|MWave |MLane |KPack
return make_tuple(0, 0, waveId_m, WMMA_a_idx, 0);
return make_multi_index(0, 0, waveId_m, WMMA_a_idx, 0);
}
__device__ static auto CalculateBThreadOriginDataIndex()
@@ -98,7 +98,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex();
// |KRepeat |NRepeat|NWave |NLane |KPack
return make_tuple(0, 0, waveId_n, WMMA_b_idx, 0);
return make_multi_index(0, 0, waveId_n, WMMA_b_idx, 0);
}
template <index_t m0, index_t n0>
......
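The make_tuple -> make_multi_index changes in this file matter because the coordinate machinery expects a homogeneous MultiIndex of index_t rather than a Tuple whose element types are deduced per argument. A small sketch of the assumed distinction:
// Hypothetical sketch: every element is coerced to index_t, so the result
// is a MultiIndex<5> regardless of the argument types.
const auto idx = make_multi_index(0, 0, 1, 2, 0);
static_assert(remove_cvref_t<decltype(idx)>::Size() == 5, "");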
@@ -66,7 +66,7 @@ struct BlockwiseSoftmax
reduce::Add,
false>>::type;
using ThreadClusterLengths_M_K = decltype(ThreadClusterDesc_M_K{}.GetLengths());
using ThreadClusterLengths_M_K = decltype(to_sequence(ThreadClusterDesc_M_K{}.GetLengths()));
using BlockwiseMaxReduce = PartitionedBlockwiseReduction_v2<AccDataType,
BlockSize,
......
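to_sequence converts the Tuple of compile-time lengths returned by GetLengths() into a Sequence, which the blockwise reduction expects as a template argument. A sketch of the assumed conversion:
// Hypothetical sketch: a Tuple of Number<> lengths maps to a Sequence of
// the same values.
constexpr auto lengths = make_tuple(Number<2>{}, Number<64>{});
using LengthsSeq = decltype(to_sequence(lengths)); // Sequence<2, 64>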
@@ -50,6 +50,10 @@ struct ThreadGroupTensorSliceTransfer_v4r1
using Index = MultiIndex<nDim>;
#if 1 // debug
__host__ __device__ constexpr ThreadGroupTensorSliceTransfer_v4r1() : threadwise_transfer_{} {}
#endif
__device__ constexpr ThreadGroupTensorSliceTransfer_v4r1(
const SrcDesc& src_desc,
const Index& src_block_slice_origin,
......
@@ -86,7 +86,7 @@ struct BlockToCTileMap_M00_N0_M01
const auto M00 = math::integer_divide_ceil(M0, M01);
const auto m00_n0_m01_to_m0_n0_block_cluster_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_insert_transform(1),
make_tuple(make_replicate_transform(make_tuple(1)),
make_unmerge_transform(make_tuple(M00, M01)),
make_pass_through_transform(make_tuple(N0))),
make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}),
@@ -402,7 +402,8 @@ struct BlockToCTileMap_M00_N00_M01_N01
const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
make_single_stage_tensor_adaptor(
make_tuple(make_insert_transform(1), // swallow the carry from lower dimensions
make_tuple(make_replicate_transform(
make_tuple(1)), // swallow the carry from lower dimensions
make_unmerge_transform(make_tuple(M00, M01)),
make_unmerge_transform(make_tuple(N00, N01))),
make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}),
......
@@ -630,7 +630,7 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
Gemm1KPack, // AMmaKStride
Gemm1KPack * XdlopsGemm<FloatAB, MPerXdl, NPerXdl, Gemm1KPack, false>{}.K0PerXdlops>{
// BMmaKStride
make_tuple(0, 0, 0, 0)}; // A_origin
make_multi_index(0, 0, 0, 0)}; // A_origin
auto c_thread_buf = gemm1_blockwise_gemm.GetCThreadBuffer();
......
@@ -796,7 +796,7 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
Gemm1KPack, // AMmaKStride
Gemm1KPack * XdlopsGemm<FloatAB, MPerXdl, NPerXdl, Gemm1KPack, false>{}.K0PerXdlops>{
// BMmaKStride
make_tuple(0, 0, 0, 0)}; // A_origin
make_multi_index(0, 0, 0, 0)}; // A_origin
auto acc1_thread_buf = gemm1_blockwise_gemm.GetCThreadBuffer();
@@ -953,7 +953,7 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
// works like multi-dimension static_for (static_ford), but provides both the linear
// index as well as n-d index
using Acc0TileIterator = SpaceFillingCurve<
decltype(c_thread_lengths),
decltype(to_sequence(c_thread_lengths)),
typename arithmetic_sequence_gen<0, c_thread_lengths.Size(), 1>::type,
typename uniform_sequence_gen<c_thread_lengths.Size(), 1>::type,
false>; // SnakeCurved
......
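For reference, a sketch of how such a SpaceFillingCurve is typically consumed; the member names below follow the assumed CK interface:
// Hypothetical sketch: walk a 2x2 tile of scalars, retrieving the n-d
// index for each linear access id.
using Curve = SpaceFillingCurve<Sequence<2, 2>,
                                Sequence<0, 1>, // dim access order
                                Sequence<1, 1>, // scalars per access
                                false>;         // SnakeCurved
static_for<0, Curve::GetNumOfAccess(), 1>{}([&](auto i_access) {
    constexpr auto idx_nd = Curve::GetIndex(i_access);
    // ... map idx_nd into the accumulator tile ...
});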