Unverified commit f83e9701, authored by Haocong WANG, committed by GitHub

[GEMM] Gemm universal device operation (#1154)



* Optimize GEMM on MI200/300:
1. Add new blockwise gemm pipeline
2. Add irregular splitk instances

* clang format + typo fix

* Fix a bug

* initial commit

* Add more instances to irregular splitk

* blkgemm pipeline v1~4 prototype

* Sanity checked. Known issues:
1. Poor performance of splitk
2. Register spill on blkgemmpipeline v3

* Sanity and Performance fix:
1. fix a bug related to sanity in grouped b2c mapping
2. fix a bug related to sanity and performance in splitk offset

* Sanity and API update:
1. Remove prefetch stage
2. Fix valid check bug
3. Add first gemm_universal instance to ckProfiler

* Add NN instances for gemm universal

* 1. Add NT instances for gemm_universal
2. Fix a bug about Kpadding in gemm_universal

* Fix a bug regarding padding Odd K number

* remove kernel print

* Fix KPadding bug...

* Update safety check

* another try to fix kpadding..

* Sanity checked

* new instances..

* clang format+typo fix

* remove clang format script's change

* Add non-hotloop compile option

* 1. Add fp16xfp8 example
2. Pull packed f8 convert from PR 1150

* Miscellaneous optimizations and fixes

* Add pipeline description docs

* Split universal gemm instance library to cut profiler compile time

* uncomment cmakefile

* Fix a bug caused by blockwise_gemm_pipe_v2

* reduce default splitk to 1

* Add 224x256x64 tile size

* update, including:
1. Experimental pipelines 5~7
2. Optimization for pipeline 4
3. Organized instance library

* temp save

* temp save

* Permuted LDS layout, sanity and function checked

* clang format

* Move OOB check from RunRead to RunWrite, for better software pipelining.
TODO: AGPR spill with NN layout

* clangformat

* A/B splitpipe scheduler for v3

* Fix two bugs

* bug fix

* fix a bug in oob check

* Example for mixed fp16_fp8 gemm

* Clean experimental code blocks

* Add mixed precision gemm into profiler

* tempsave

* Optimize M/N-major LDS layout

* Add RRR GEMM mixed-precision instances

* Optimize f8 matrix transpose

* Add test_gemm_universal

* A/B split schedule for blkpip v5

* Take ds_read2 into iglp scheduling scheme

* format

* fixed cmake

* Add llvm-option into CI cmake flag

---------
Co-authored-by: Jing Zhang <jizhan@amd.com>
parent 7cdf5a96
@@ -202,15 +202,17 @@ struct ThreadwiseTensorSliceTransfer_v3r1
         constexpr auto src_data_idx_seq = generate_sequence_v2(
             [&](auto i) { return Number<src_data_idx[i]>{}; }, Number<src_data_idx.Size()>{});

+        // record is_src_valid in a scratch container, for later use by RunWrite
         const bool is_src_valid =
             coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
+        src_oob_thread_scratch_tuple_(thread_scratch_id)
+            .template SetAsType<bool>(src_data_idx_seq, is_src_valid);

         using src_vector_type = vector_type_maker_t<SrcData, SrcScalarPerVector>;
         using src_vector_t    = typename src_vector_type::type;

-        // copy data from src_buf into src_vector_container
-        auto src_vector_container = src_vector_type{
-            src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid)};
+        auto src_vector_container =
+            src_vector_type{src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), true)};

         using dst_vector_type = vector_type_maker_t<DstData, SrcScalarPerVector>;
         using dst_vector_t    = typename dst_vector_type::type;
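The read half of this change defers the out-of-bounds check: the validity predicate is recorded into src_oob_thread_scratch_tuple_, and the buffer load itself is issued unconditionally (the trailing argument becomes true). Below is a minimal host-side sketch of that deferred-predication pattern; the Buffer type and all names are illustrative, not CK's API, and the sketch assumes, like an AMD buffer-resource load, that reading an out-of-range offset is harmless:

#include <array>
#include <cstddef>
#include <cstdio>
#include <vector>

// Hypothetical buffer whose load() tolerates out-of-range offsets by returning
// a garbage value instead of faulting, the way an AMD buffer-resource load does.
struct Buffer
{
    std::vector<float> data;
    float load(std::size_t i) const { return i < data.size() ? data[i] : -1.0f; }
};

template <std::size_t NumAccess>
struct DeferredOobCopy
{
    std::array<float, NumAccess> scratch{}; // values read speculatively
    std::array<bool, NumAccess> valid{};    // predicates recorded at read time

    void run_read(const Buffer& src, const std::array<std::size_t, NumAccess>& offs)
    {
        for(std::size_t i = 0; i < NumAccess; ++i)
        {
            valid[i]   = offs[i] < src.data.size(); // record the predicate only
            scratch[i] = src.load(offs[i]);         // issue the load unconditionally
        }
    }

    void run_write(std::array<float, NumAccess>& dst)
    {
        // the select resolves here, off the load's critical path
        for(std::size_t i = 0; i < NumAccess; ++i)
            dst[i] = valid[i] ? scratch[i] : 0.0f;
    }
};

int main()
{
    Buffer src{{1.0f, 2.0f, 3.0f}};
    DeferredOobCopy<4> copy;
    std::array<float, 4> dst{};
    copy.run_read(src, {0, 1, 2, 3}); // offset 3 is out of bounds
    copy.run_write(dst);
    std::printf("%g %g %g %g\n", dst[0], dst[1], dst[2], dst[3]); // 1 2 3 0
    return 0;
}

Resolving the select at write time keeps the predicate computation off the load's critical path, which is what the commit message means by moving the OOB check from RunRead to RunWrite "for better software pipelining".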
@@ -305,12 +307,78 @@ struct ThreadwiseTensorSliceTransfer_v3r1
             dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx];
         });
 #else
+        // OOB check
+        constexpr auto src_scalar_per_access = generate_sequence(
+            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
+
+        constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
+
+        constexpr auto src_dim_access_order = SrcDimAccessOrder{};
+
+        constexpr auto ordered_src_access_lengths =
+            container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
+
+        // loop over the tensor and copy
+        static_ford<decltype(ordered_src_access_lengths)>{}([&](auto ordered_src_access_idx) {
+            // decide, per dimension, whether this pass sweeps forward or backward
+            constexpr auto forward_sweep = [&]() {
+                StaticallyIndexedArray<bool, nDim> forward_sweep_;
+
+                forward_sweep_(I0) = true;
+
+                static_for<1, nDim, 1>{}([&](auto i) {
+                    index_t tmp = ordered_src_access_idx[I0];
+
+                    static_for<1, i, 1>{}([&](auto j) {
+                        tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j];
+                    });
+
+                    forward_sweep_(i) = tmp % 2 == 0;
+                });
+
+                return forward_sweep_;
+            }();
+
+            // calculate src data index
+            constexpr auto src_data_idx = [&]() {
+                Index ordered_idx;
+
+                static_for<0, nDim, 1>{}([&](auto i) {
+                    ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i]
+                                                      : ordered_src_access_lengths[i] - 1 -
+                                                            ordered_src_access_idx[i];
+                });
+
+                return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
+                       src_scalar_per_access;
+            }();
+
+            constexpr auto src_data_idx_seq = generate_sequence_v2(
+                [&](auto i) { return Number<src_data_idx[i]>{}; }, Number<src_data_idx.Size()>{});
+
+            using vector_t = typename vector_type_maker<DstData, SrcScalarPerVector>::type::type;
+
+            auto op_r = src_thread_scratch_tuple_(thread_scratch_id)
+                            .template GetAsType<vector_t>(src_data_idx_seq);
+
+            const bool is_src_valid = src_oob_thread_scratch_tuple_(thread_scratch_id)
+                                          .template GetAsType<bool>(src_data_idx_seq);
+
+            // zero-fill the vector if its source access was out of bounds
+            auto op_r_v = is_src_valid ? op_r : vector_t(0);
+
+            src_thread_scratch_tuple_(thread_scratch_id)
+                .template SetAsType<vector_t>(src_data_idx_seq, op_r_v);
+        });
+
         // sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_
         // TODO make this logic more generic for more sub-dword datatype
         if constexpr(SrcVectorDim != DstVectorDim &&
                      ((is_same<half_t, remove_cvref_t<DstData>>::value &&
                        SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) ||
                       (is_same<int8_t, remove_cvref_t<DstData>>::value &&
-                       SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0)))
+                       SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0) ||
+                      (is_same<f8_t, remove_cvref_t<DstData>>::value &&
+                       SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0)))
         {
             // each transpose does
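The forward_sweep block in this hunk encodes a snake (boustrophedon) traversal: a dimension is swept forward when the flattened index of the enclosing dimensions is even, backward when it is odd, so consecutive accesses differ by a single step and the coordinate updates stay cheap. A host-side sketch of the same parity rule on a 2x4 access grid (illustrative only):

#include <cstdio>

int main()
{
    constexpr int len0 = 2, len1 = 4;
    for(int i0 = 0; i0 < len0; ++i0)
    {
        // dimension 1 sweeps forward when the outer index is even, backward
        // when it is odd -- the same parity rule as forward_sweep_(i) above
        const bool forward = (i0 % 2 == 0);
        for(int j = 0; j < len1; ++j)
        {
            const int i1 = forward ? j : len1 - 1 - j;
            std::printf("(%d,%d) ", i0, i1);
        }
    }
    // prints: (0,0) (0,1) (0,2) (0,3) (1,3) (1,2) (1,1) (1,0)
    std::printf("\n");
    return 0;
}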
@@ -386,6 +454,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
                  Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
     {
         // if there is transpose, it's done here
+        // if there is oob check, it's done here
         // TODO move this elsewhere
         TransferDataFromSrcThreadScratchToDstThreadScratch(thread_scratch_id);
@@ -738,6 +807,16 @@ struct ThreadwiseTensorSliceTransfer_v3r1
         return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
     }

+    __device__ static constexpr auto GetSrcOOBThreadScratchDescriptor()
+    {
+        constexpr auto src_scalar_per_access = generate_sequence(
+            detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
+
+        constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
+
+        return make_naive_tensor_descriptor_packed(src_access_lengths);
+    }
+
     __device__ static constexpr auto GetDstThreadScratchDescriptor()
     {
         // 1st stage of transforms
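GetSrcOOBThreadScratchDescriptor sizes the flag scratch by access lengths rather than element lengths: each vectorized access shares one validity predicate, so a single bool covers SrcScalarPerVector elements. A standalone check of that arithmetic with assumed example numbers (a 4x8 per-thread slice read 4-wide along dimension 1; the real lengths come from SliceLengths and lambda_scalar_per_access above):

#include <cstddef>

constexpr std::size_t slice_lengths[2]     = {4, 8}; // 32 elements per thread
constexpr std::size_t scalar_per_access[2] = {1, 4}; // SrcVectorDim = 1, 4-wide
constexpr std::size_t access_lengths[2]    = {slice_lengths[0] / scalar_per_access[0],
                                              slice_lengths[1] / scalar_per_access[1]};
static_assert(access_lengths[0] * access_lengths[1] == 8,
              "8 validity flags cover 32 elements");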
@@ -789,6 +868,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
     private:
     static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){};
+    static constexpr auto src_oob_thread_scratch_desc_ =
+        decltype(GetSrcOOBThreadScratchDescriptor()){};
     static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){};

     using SrcThreadScratch =
@@ -798,6 +879,13 @@ struct ThreadwiseTensorSliceTransfer_v3r1
                                         decltype(src_thread_scratch_desc_),
                                         true>;

+    using SrcOOBThreadScratch =
+        StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
+                                        bool, // apply data_convert with SrcThreadScratch
+                                        1,
+                                        decltype(src_oob_thread_scratch_desc_),
+                                        true>;
+
     using DstThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
                                                              DstData,
                                                              DstScalarPerVector,
@@ -805,6 +893,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
                                                              true>;

     StaticallyIndexedArray<SrcThreadScratch, NumThreadScratch> src_thread_scratch_tuple_;
+    StaticallyIndexedArray<SrcOOBThreadScratch, NumThreadScratch> src_oob_thread_scratch_tuple_;

     DstThreadScratch dst_thread_scratch_;
......
This diff is collapsed.
@@ -163,6 +163,13 @@ struct scalar_type<bf8_t>
     static constexpr index_t vector_size = 1;
 };

+template <>
+struct scalar_type<bool>
+{
+    using type                           = bool;
+    static constexpr index_t vector_size = 1;
+};
+
 template <typename T>
 struct vector_type<T, 1>
 {
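This specialization exists so the new bool-typed SrcOOBThreadScratch can instantiate CK's static-buffer machinery, which queries scalar_type<T> for an element type and vector width. A stripped-down model of the trait and a generic consumer (hypothetical shapes, not the actual CK headers):

using index_t = int; // stand-in for ck::index_t

template <typename T>
struct scalar_type; // primary template: one specialization per supported type

template <>
struct scalar_type<bool>
{
    using type                           = bool;
    static constexpr index_t vector_size = 1;
};

// a generic buffer can now ask any element type for its scalar width
template <typename T>
constexpr index_t scalar_width = scalar_type<T>::vector_size;

static_assert(scalar_width<bool> == 1, "bool behaves as a 1-wide scalar");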
......
@@ -10,10 +10,12 @@ namespace ck {
 __device__ void block_sync_lds()
 {
 #if CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
-    asm volatile("\
-    s_waitcnt lgkmcnt(0) \n \
-    s_barrier \
-    " ::);
+    // asm volatile("\
+    // s_waitcnt lgkmcnt(0) \n \
+    // s_barrier \
+    // " ::);
+    __builtin_amdgcn_s_waitcnt(0xc07f);
+    __builtin_amdgcn_s_barrier();
 #else
     __syncthreads();
 #endif
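The builtins keep the semantics of the inline asm they replace: 0xc07f is the s_waitcnt immediate with lgkmcnt = 0 and vmcnt/expcnt left at their maxima, i.e. wait only for outstanding LDS traffic and do not order vector-memory loads (hence the flag name CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM). A sketch of the encoding using the gfx9 field layout; treat the bit positions as an assumption to check against the ISA guide for your target:

#include <cstdint>

// gfx9-style s_waitcnt immediate: vmcnt[3:0] in bits 3:0, expcnt in bits 6:4,
// lgkmcnt in bits 11:8, vmcnt[5:4] in bits 15:14
constexpr std::uint16_t waitcnt(unsigned vmcnt, unsigned expcnt, unsigned lgkmcnt)
{
    return static_cast<std::uint16_t>((vmcnt & 0xf) | ((expcnt & 0x7) << 4) |
                                      ((lgkmcnt & 0xf) << 8) |
                                      ((vmcnt >> 4) << 14));
}

// lgkmcnt(0) with vmcnt = 63 and expcnt = 7 (their "don't wait" maxima)
static_assert(waitcnt(63, 7, 0) == 0xc07f, "0xc07f waits on LDS ops only");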
......
This diff is collapsed.
@@ -43,6 +43,8 @@ __host__ __device__ constexpr Y bit_cast(const X& x)
 #if CK_EXPERIMENTAL_USE_MEMCPY_FOR_BIT_CAST
     Y y;
+    // auto t = reinterpret_cast<const Y*>(&x);
+    // y      = *t;
     __builtin_memcpy(&y, &x, sizeof(X));

     return y;
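The commented-out reinterpret_cast experiment is the classic type-punning pitfall the memcpy path avoids: dereferencing the cast pointer violates strict aliasing, while memcpy of a trivially copyable type is well-defined and typically compiles to the same register move. A minimal standalone model of the function and a typical use:

#include <cstdint>
#include <cstdio>
#include <cstring>

// simplified model of the bit_cast above (the real one is __host__ __device__
// and guarded by CK_EXPERIMENTAL_USE_MEMCPY_FOR_BIT_CAST)
template <typename Y, typename X>
Y bit_cast(const X& x)
{
    static_assert(sizeof(X) == sizeof(Y), "bit_cast requires equal sizes");
    Y y;
    std::memcpy(&y, &x, sizeof(X)); // well-defined, unlike *reinterpret_cast
    return y;
}

int main()
{
    // inspect the raw IEEE-754 bits of 1.0f: sign 0, exponent 127, mantissa 0
    std::printf("0x%08x\n", bit_cast<std::uint32_t>(1.0f)); // 0x3f800000
    return 0;
}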
......
@@ -21,7 +21,7 @@ template <typename ADataType,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation,
-          typename ComputeTypeA = ADataType,
+          typename ComputeTypeA = CDataType,
           typename ComputeTypeB = ComputeTypeA>
 struct ReferenceGemm : public device::BaseOperator
 {
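This default change matters for the new mixed-precision paths: with fp16 x fp8 inputs, ComputeTypeA now falls back to CDataType rather than ADataType, so the reference GEMM converts both operands toward the output type unless explicitly overridden. A hypothetical instantiation sketch (the parameter order is assumed from the context above; verify against reference_gemm.hpp):

#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

// fp16 A, fp8 B, fp16 C; ComputeTypeA now defaults to CDataType (half_t)
// and ComputeTypeB defaults to ComputeTypeA -- no explicit override needed
using ReferenceMixedGemm =
    ck::tensor_operation::host::ReferenceGemm<ck::half_t, // ADataType
                                              ck::f8_t,   // BDataType
                                              ck::half_t, // CDataType
                                              float,      // AccDataType
                                              PassThrough,
                                              PassThrough,
                                              PassThrough>;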
......