Unverified Commit 0e92deb7 authored by Chao Liu's avatar Chao Liu Committed by GitHub
Browse files

Tile program init bulk PR (#4)



Tile Program init bulk PR

---------
Co-authored-by: zjing14 <zhangjing14@gmail.com>
Co-authored-by: Po-Yen, Chen <PoYen.Chen@amd.com>
parent 0077eeb3
...@@ -51,32 +51,11 @@ __host__ __device__ constexpr auto make_embed_transform(const UpLengths& up_leng ...@@ -51,32 +51,11 @@ __host__ __device__ constexpr auto make_embed_transform(const UpLengths& up_leng
return Embed<UpLengths, Coefficients>{up_lengths, coefficients}; return Embed<UpLengths, Coefficients>{up_lengths, coefficients};
} }
template <typename LowLengths>
__host__ __device__ constexpr auto make_merge_transform(const LowLengths& low_lengths)
{
#if CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION
return make_merge_transform_v2_magic_division(low_lengths);
#else
return make_merge_transform_v1_carry_check(low_lengths);
#endif
}
template <typename LowLengths>
__host__ __device__ constexpr auto
make_merge_transform_v1_carry_check(const LowLengths& low_lengths)
{
return Merge_v1_carry_check<LowLengths>{low_lengths};
}
template <typename LowLengths> template <typename LowLengths>
__host__ __device__ constexpr auto __host__ __device__ constexpr auto
make_merge_transform_v2_magic_division(const LowLengths& low_lengths) make_merge_transform_v2_magic_division(const LowLengths& low_lengths)
{ {
#if 1
return Merge_v2_magic_division<LowLengths>{low_lengths}; return Merge_v2_magic_division<LowLengths>{low_lengths};
#else
return Merge_v2r2_magic_division<LowLengths>{low_lengths};
#endif
} }
template <typename LowLengths> template <typename LowLengths>
...@@ -86,6 +65,12 @@ make_merge_transform_v3_division_mod(const LowLengths& low_lengths) ...@@ -86,6 +65,12 @@ make_merge_transform_v3_division_mod(const LowLengths& low_lengths)
return Merge_v3_division_mod<LowLengths>{low_lengths}; return Merge_v3_division_mod<LowLengths>{low_lengths};
} }
template <typename LowLengths>
__host__ __device__ constexpr auto make_merge_transform(const LowLengths& low_lengths)
{
return make_merge_transform_v2_magic_division(low_lengths);
}
template <typename UpLengths, bool Use24BitIntegerCalculation = false> template <typename UpLengths, bool Use24BitIntegerCalculation = false>
__host__ __device__ constexpr auto make_unmerge_transform( __host__ __device__ constexpr auto make_unmerge_transform(
const UpLengths& up_lengths, const UpLengths& up_lengths,
...@@ -100,10 +85,10 @@ __host__ __device__ constexpr auto make_freeze_transform(const LowerIndex& low_i ...@@ -100,10 +85,10 @@ __host__ __device__ constexpr auto make_freeze_transform(const LowerIndex& low_i
return Freeze<LowerIndex>{low_idx}; return Freeze<LowerIndex>{low_idx};
} }
template <typename UpperIndex> template <typename UpLengths>
__host__ __device__ constexpr auto make_insert_transform(const UpperIndex& up_idx) __host__ __device__ constexpr auto make_replicate_transform(const UpLengths& up_lengths)
{ {
return Insert<UpperIndex>{up_idx}; return Replicate<UpLengths>{up_lengths};
} }
template <typename LowLength, typename SliceBegin, typename SliceEnd> template <typename LowLength, typename SliceBegin, typename SliceEnd>
...@@ -114,17 +99,18 @@ __host__ __device__ constexpr auto make_slice_transform(const LowLength& low_len ...@@ -114,17 +99,18 @@ __host__ __device__ constexpr auto make_slice_transform(const LowLength& low_len
return Slice<LowLength, SliceBegin, SliceEnd>{low_length, slice_begin, slice_end}; return Slice<LowLength, SliceBegin, SliceEnd>{low_length, slice_begin, slice_end};
} }
template <typename VectorSize, typename UpLength>
__host__ __device__ constexpr auto make_vectorize_transform(const VectorSize& vector_size,
const UpLength& up_length)
{
return Vectorize<VectorSize, UpLength>{vector_size, up_length};
}
template <typename Modulus, typename UpLength> template <typename Modulus, typename UpLength>
__host__ __device__ constexpr auto make_modulo_transform(const Modulus& modulus, __host__ __device__ constexpr auto make_modulo_transform(const Modulus& modulus,
const UpLength& up_length) const UpLength& up_length)
{ {
return Modulo<Modulus, UpLength>{modulus, up_length}; return Modulo<Modulus, UpLength>{modulus, up_length};
} }
template <typename LowLengths, typename RightShift>
__host__ __device__ constexpr auto make_xor_transform(const LowLengths& low_lengths,
const RightShift& right_shift)
{
return Xor<LowLengths, RightShift>{low_lengths, right_shift};
}
} // namespace ck } // namespace ck
This diff is collapsed.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_adaptor_coordinate.hpp"
namespace ck {
template <index_t NDimHidden, typename TopDimensionHiddenIds>
struct TensorCoordinate
    : public TensorAdaptorCoordinate<NDimHidden, Sequence<0>, TopDimensionHiddenIds>
{
    // A TensorCoordinate is a TensorAdaptorCoordinate whose bottom view is the single
    // hidden dimension with id 0; that bottom component is read back as the offset.
    using Base = TensorAdaptorCoordinate<NDimHidden, Sequence<0>, TopDimensionHiddenIds>;

    // TODO make these private
    static constexpr index_t ndim_top_ = TopDimensionHiddenIds::Size();

    using HiddenIndex = MultiIndex<NDimHidden>;
    using TopIndex    = MultiIndex<ndim_top_>;

    public:
    __host__ __device__ constexpr TensorCoordinate() = default;

    // construct from the full hidden multi-index
    __host__ __device__ constexpr TensorCoordinate(const HiddenIndex& idx_hidden) : Base{idx_hidden}
    {
    }

    // construct from TensorAdaptorCoordinate base class
    __host__ __device__ constexpr TensorCoordinate(const Base& adaptor_coord) : Base{adaptor_coord}
    {
    }

    // user-visible (top) multi-index of this coordinate
    __host__ __device__ constexpr auto GetIndex() const { return Base::GetTopIndex(); }

    // offset of the coordinate: the single bottom-dimension component of the hidden index
    __host__ __device__ constexpr index_t GetOffset() const
    {
        return Base::GetBottomIndex()[Number<0>{}];
    }

    // read-only access to the full hidden multi-index
    __host__ __device__ constexpr const auto& GetHiddenIndex() const
    {
        return Base::GetHiddenIndex();
    }

    // mutable access to the full hidden multi-index
    __host__ __device__ auto& GetHiddenIndex() { return Base::GetHiddenIndex(); }
};
// Create a TensorCoordinate for idx_top on tensor_desc by building the underlying
// adaptor coordinate and wrapping it in the coordinate type deduced from the descriptor.
template <typename TensorDesc, typename TopIndex>
__host__ __device__ constexpr auto make_tensor_coordinate(const TensorDesc& tensor_desc,
                                                          const TopIndex& idx_top)
{
    using TopIds = remove_cvref_t<decltype(TensorDesc::GetTopDimensionHiddenIds())>;
    using Coord  = TensorCoordinate<TensorDesc::GetNumOfHiddenDimension(), TopIds>;

    return Coord{make_tensor_adaptor_coordinate(tensor_desc, idx_top)};
}
// Advance coord in place by coord_step, delegating to the adaptor-level mover.
// NOTE(review): the JudgeDoTransforms template parameter is accepted here but not
// forwarded to move_tensor_adaptor_coordinate — confirm whether the callee takes a
// matching template argument and should receive it.
template <bool JudgeDoTransforms = true, typename TensorDesc, typename TensorCoord, typename Index>
__host__ __device__ constexpr void
move_tensor_coordinate(const TensorDesc& tensor_desc, TensorCoord& coord, const Index& coord_step)
{
    move_tensor_adaptor_coordinate(tensor_desc, coord, coord_step);
}
// Check that coord's offset is valid given that its top index is already known to be
// valid; the decision is delegated to the adaptor-level predicate.
template <typename TensorDesc, typename TensorCoord>
__host__ __device__ constexpr bool
coordinate_has_valid_offset_assuming_top_index_is_valid(const TensorDesc& tensor_desc,
                                                        const TensorCoord& coord)
{
    const bool offset_is_valid =
        adaptor_coordinate_is_valid_assuming_top_index_is_valid(tensor_desc, coord);

    return offset_is_valid;
}
// Full validity check for coord on tensor_desc (no assumption about the top index),
// delegated to the adaptor-level predicate.
template <typename TensorDesc, typename TensorCoord>
__host__ __device__ constexpr bool coordinate_has_valid_offset(const TensorDesc& tensor_desc,
                                                               const TensorCoord& coord)
{
    const bool coord_is_valid = adaptor_coordinate_is_valid(tensor_desc, coord);

    return coord_is_valid;
}
} // namespace ck
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#include "ck/utility/math.hpp" #include "ck/utility/math.hpp"
#include "ck/utility/sequence.hpp" #include "ck/utility/sequence.hpp"
#include "ck/utility/sequence_helper.hpp" #include "ck/utility/sequence_helper.hpp"
#include "ck/utility/statically_indexed_array_multi_index.hpp" #include "ck/utility/multi_index.hpp"
#include "ck/utility/tuple_helper.hpp" #include "ck/utility/tuple_helper.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp" #include "ck/tensor_description/tensor_adaptor.hpp"
......
...@@ -87,7 +87,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle ...@@ -87,7 +87,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex(); const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex();
// |KRepeat |MRepeat|MWave |MLane |KPack // |KRepeat |MRepeat|MWave |MLane |KPack
return make_tuple(0, 0, waveId_m, WMMA_a_idx, 0); return make_multi_index(0, 0, waveId_m, WMMA_a_idx, 0);
} }
__device__ static auto CalculateBThreadOriginDataIndex() __device__ static auto CalculateBThreadOriginDataIndex()
...@@ -98,7 +98,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle ...@@ -98,7 +98,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex(); const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex();
// |KRepeat |NRepeat|Nwave |NLane |KPack // |KRepeat |NRepeat|Nwave |NLane |KPack
return make_tuple(0, 0, waveId_n, WMMA_b_idx, 0); return make_multi_index(0, 0, waveId_n, WMMA_b_idx, 0);
} }
template <index_t m0, index_t n0> template <index_t m0, index_t n0>
......
...@@ -66,7 +66,7 @@ struct BlockwiseSoftmax ...@@ -66,7 +66,7 @@ struct BlockwiseSoftmax
reduce::Add, reduce::Add,
false>>::type; false>>::type;
using ThreadClusterLengths_M_K = decltype(ThreadClusterDesc_M_K{}.GetLengths()); using ThreadClusterLengths_M_K = decltype(to_sequence(ThreadClusterDesc_M_K{}.GetLengths()));
using BlockwiseMaxReduce = PartitionedBlockwiseReduction_v2<AccDataType, using BlockwiseMaxReduce = PartitionedBlockwiseReduction_v2<AccDataType,
BlockSize, BlockSize,
......
...@@ -50,6 +50,10 @@ struct ThreadGroupTensorSliceTransfer_v4r1 ...@@ -50,6 +50,10 @@ struct ThreadGroupTensorSliceTransfer_v4r1
using Index = MultiIndex<nDim>; using Index = MultiIndex<nDim>;
#if 1 // debug
__host__ __device__ constexpr ThreadGroupTensorSliceTransfer_v4r1() : threadwise_transfer_{} {}
#endif
__device__ constexpr ThreadGroupTensorSliceTransfer_v4r1( __device__ constexpr ThreadGroupTensorSliceTransfer_v4r1(
const SrcDesc& src_desc, const SrcDesc& src_desc,
const Index& src_block_slice_origin, const Index& src_block_slice_origin,
......
...@@ -86,7 +86,7 @@ struct BlockToCTileMap_M00_N0_M01 ...@@ -86,7 +86,7 @@ struct BlockToCTileMap_M00_N0_M01
const auto M00 = math::integer_divide_ceil(M0, M01); const auto M00 = math::integer_divide_ceil(M0, M01);
const auto m00_n0_m01_to_m0_n0_block_cluster_adaptor = make_single_stage_tensor_adaptor( const auto m00_n0_m01_to_m0_n0_block_cluster_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_insert_transform(1), make_tuple(make_replicate_transform(make_tuple(1)),
make_unmerge_transform(make_tuple(M00, M01)), make_unmerge_transform(make_tuple(M00, M01)),
make_pass_through_transform(make_tuple(N0))), make_pass_through_transform(make_tuple(N0))),
make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}),
...@@ -402,7 +402,8 @@ struct BlockToCTileMap_M00_N00_M01_N01 ...@@ -402,7 +402,8 @@ struct BlockToCTileMap_M00_N00_M01_N01
const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
make_single_stage_tensor_adaptor( make_single_stage_tensor_adaptor(
make_tuple(make_insert_transform(1), // swallow the carry from lower dimensions make_tuple(make_replicate_transform(
make_tuple(1)), // swallow the carry from lower dimensions
make_unmerge_transform(make_tuple(M00, M01)), make_unmerge_transform(make_tuple(M00, M01)),
make_unmerge_transform(make_tuple(N00, N01))), make_unmerge_transform(make_tuple(N00, N01))),
make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<>{}, Sequence<0>{}, Sequence<1>{}),
......
...@@ -630,7 +630,7 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle ...@@ -630,7 +630,7 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
Gemm1KPack, // AMmaKStride Gemm1KPack, // AMmaKStride
Gemm1KPack * XdlopsGemm<FloatAB, MPerXdl, NPerXdl, Gemm1KPack, false>{}.K0PerXdlops>{ Gemm1KPack * XdlopsGemm<FloatAB, MPerXdl, NPerXdl, Gemm1KPack, false>{}.K0PerXdlops>{
// BMmaKStride // BMmaKStride
make_tuple(0, 0, 0, 0)}; // A_origin make_multi_index(0, 0, 0, 0)}; // A_origin
auto c_thread_buf = gemm1_blockwise_gemm.GetCThreadBuffer(); auto c_thread_buf = gemm1_blockwise_gemm.GetCThreadBuffer();
......
...@@ -45,7 +45,7 @@ struct ThreadwiseTensorSliceSet_v1 ...@@ -45,7 +45,7 @@ struct ThreadwiseTensorSliceSet_v1
constexpr auto coord = make_tensor_coordinate(desc, origin_idx + access_idx); constexpr auto coord = make_tensor_coordinate(desc, origin_idx + access_idx);
constexpr bool is_valid = constexpr bool is_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(desc, coord); coordinate_has_valid_offset_assuming_top_index_is_valid(desc, coord);
constexpr index_t offset = coord.GetOffset(); constexpr index_t offset = coord.GetOffset();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment