Unverified Commit f83e9701 authored by Haocong WANG, committed by GitHub

[GEMM] Gemm universal device operation (#1154)



* Optimize GEMM on MI200/300:
1. Add new blockwise gemm pipeline
2. Add irregular splitk instances

* clang format + typo fix

* Fix a bug

* initial commit

* Add more instances to irregular splitk

* blkgemm pipeline v1~4 prototype

* Sanity checked. Known issues:
1. Poor splitk performance
2. Register spill in blkgemmpipeline v3

* Sanity and performance fixes:
1. Fix a correctness bug in the grouped b2c mapping
2. Fix a correctness and performance bug in the splitk offset

* Sanity and API update:
1. Remove prefetch stage
2. Fix validity-check bug
3. Add first gemm_universal instance to ckProfiler

* Add NN instances for gemm universal

* 1. Add NT instances for gemm_universal
2. Fix a KPadding bug in gemm_universal

* Fix a bug when padding an odd K value

* remove kernel print

* Fix KPadding bug...

* Update safety check

* another try to fix kpadding..

* Sanity checked

* new instances..

* clang format+typo fix

* remove clang format script's change

* Add non-hotloop compile option

* 1. Add fp16xfp8 example
2. Pull packed f8 convert from PR 1150

* Miscellaneous optimizations and fixes

* Add pipeline description docs

* Split universal gemm instance library to cut profiler compiling time

* uncomment cmakefile

* Fix a bug caused by blockwise_gemm_pipe_v2

* reduce default splitk to 1

* Add 224x256x64 tile size

* Update, including:
1. Experimental pipelines 5~7
2. Optimization for pipeline 4
3. Organized instance library

* temp save

* temp save

* Permuted lds layout; sanity and functionality checked

* clang format

* Move OOB check from RunRead to RunWrite for better software pipelining (see the sketch after the commit message)
TODO: agpr spill with NN layout

* clangformat

* A/B splitpipe scheduler for v3

* Fix two bugs

* bug fix

* fix a bug in oob check

* Example for mixed fp16_fp8 gemm

* Clean experimental code blocks

* Add mixed precision gemm into profiler

* tempsave

* optimize m/n major lds layout

* Add RRR GEMM mixed precision instances

* Optimize f8 matrix transpose

* Add test_gemm_universal

* A/B split schedule for blkpip v5

* Include ds_read2 in the iglp scheduling scheme

* format

* fixed cmake

* Add llvm-option to the CI cmake flags

---------
Co-authored-by: Jing Zhang <jizhan@amd.com>
parent 7cdf5a96
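One item above, "Move OOB check from RunRead to RunWrite", deserves a short illustration. The read stage loads every vector unconditionally and only records a per-vector validity flag; the flag is applied (invalid data zeroed) before the write stage consumes the scratch buffer, which keeps branching out of the load loop and gives the compiler a cleaner software pipeline. The following is a minimal host-side C++ sketch of that pattern under simplifying assumptions; DeferredOobCopy and its whole-vector validity granularity are hypothetical and are not the CK device code, which records its flags in src_oob_thread_scratch_tuple_ as shown in the diff below.

```cpp
#include <array>
#include <cstddef>
#include <cstdio>
#include <vector>

// Hypothetical helper, not part of CK: one validity flag per vector,
// loosely analogous to src_oob_thread_scratch_tuple_ in the diff below.
template <typename T, std::size_t VectorSize, std::size_t NumVectors>
struct DeferredOobCopy
{
    std::array<std::array<T, VectorSize>, NumVectors> scratch{};
    std::array<bool, NumVectors> valid{};

    // "RunRead": load every vector unconditionally and only record whether
    // its source range was in bounds; no data is zeroed here.
    void RunRead(const std::vector<T>& src, std::size_t start)
    {
        for(std::size_t v = 0; v < NumVectors; ++v)
        {
            const std::size_t offset = start + v * VectorSize;
            valid[v]                 = offset + VectorSize <= src.size();
            for(std::size_t i = 0; i < VectorSize; ++i)
                scratch[v][i] = src[(offset + i) % src.size()]; // wrapped read, sanitized later
        }
    }

    // Applied between read and write: zero the vectors flagged as invalid.
    void ApplyOobCheck()
    {
        for(std::size_t v = 0; v < NumVectors; ++v)
            if(!valid[v])
                scratch[v].fill(T(0));
    }

    // "RunWrite": consume the already-sanitized scratch.
    void RunWrite(std::vector<T>& dst, std::size_t start) const
    {
        for(std::size_t v = 0; v < NumVectors; ++v)
            for(std::size_t i = 0; i < VectorSize; ++i)
                if(start + v * VectorSize + i < dst.size())
                    dst[start + v * VectorSize + i] = scratch[v][i];
    }
};

int main()
{
    std::vector<float> src{1, 2, 3, 4, 5, 6};
    std::vector<float> dst(8, -1.f);

    DeferredOobCopy<float, 4, 2> copy;
    copy.RunRead(src, 0);  // second vector (indices 4..7) runs past src, flagged invalid
    copy.ApplyOobCheck();  // zeroed here instead of branching inside the read loop
    copy.RunWrite(dst, 0);

    for(float x : dst)
        std::printf("%g ", x); // prints: 1 2 3 4 0 0 0 0
    std::printf("\n");
}
```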
......@@ -202,15 +202,17 @@ struct ThreadwiseTensorSliceTransfer_v3r1
constexpr auto src_data_idx_seq = generate_sequence_v2(
[&](auto i) { return Number<src_data_idx[i]>{}; }, Number<src_data_idx.Size()>{});
// record is_src_valid in a scratch container so that RunWrite can apply it later
const bool is_src_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
src_oob_thread_scratch_tuple_(thread_scratch_id)
.template SetAsType<bool>(src_data_idx_seq, is_src_valid);
using src_vector_type = vector_type_maker_t<SrcData, SrcScalarPerVector>;
using src_vector_t = typename src_vector_type::type;
// copy data from src_buf into src_vector_container
auto src_vector_container = src_vector_type{
src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid)};
auto src_vector_container =
src_vector_type{src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), true)};
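// note: the load above is now issued unconditionally (validity forced to true);
// the recorded is_src_valid flag is applied later, when the data is moved from the
// source thread scratch to the destination thread scratch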
using dst_vector_type = vector_type_maker_t<DstData, SrcScalarPerVector>;
using dst_vector_t = typename dst_vector_type::type;
......@@ -305,12 +307,78 @@ struct ThreadwiseTensorSliceTransfer_v3r1
dst_thread_scratch_(idx) = src_thread_scratch_tuple_[thread_scratch_id][idx];
});
#else
// OOB check: replay the access pattern and zero any vector whose source read was out of bounds
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
constexpr auto src_dim_access_order = SrcDimAccessOrder{};
constexpr auto ordered_src_access_lengths =
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
// loop over tensor and copy
static_ford<decltype(ordered_src_access_lengths)>{}([&](auto ordered_src_access_idx) {
// decide whether to sweep this dimension forward or backward
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_src_access_idx[I0];
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j];
});
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}();
// calculate src data index
constexpr auto src_data_idx = [&]() {
Index ordered_idx;
static_for<0, nDim, 1>{}([&](auto i) {
ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i]
: ordered_src_access_lengths[i] - 1 -
ordered_src_access_idx[i];
});
return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
src_scalar_per_access;
}();
constexpr auto src_data_idx_seq = generate_sequence_v2(
[&](auto i) { return Number<src_data_idx[i]>{}; }, Number<src_data_idx.Size()>{});
using vector_t = typename vector_type_maker<DstData, SrcScalarPerVector>::type::type;
auto op_r = src_thread_scratch_tuple_(thread_scratch_id)
.template GetAsType<vector_t>(src_data_idx_seq);
const bool is_src_valid = src_oob_thread_scratch_tuple_(thread_scratch_id)
.template GetAsType<bool>(src_data_idx_seq);
auto op_r_v = is_src_valid ? op_r : vector_t(0);
src_thread_scratch_tuple_(thread_scratch_id)
.template SetAsType<vector_t>(src_data_idx_seq, op_r_v);
});
// sub-dword transpose between src_thread_scratch_ and dst_thread_scratch_
// TODO make this logic more generic for more sub-dword datatypes
if constexpr(SrcVectorDim != DstVectorDim &&
((is_same<half_t, remove_cvref_t<DstData>>::value &&
SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) ||
(is_same<int8_t, remove_cvref_t<DstData>>::value &&
SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0) ||
(is_same<f8_t, remove_cvref_t<DstData>>::value &&
SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0)))
{
// each transpose does
......@@ -386,6 +454,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
// if there is transpose, it's done here
// if there is oob check, it's done here
// TODO move this elsewhere
TransferDataFromSrcThreadScratchToDstThreadScratch(thread_scratch_id);
......@@ -738,6 +807,16 @@ struct ThreadwiseTensorSliceTransfer_v3r1
return transform_tensor_descriptor(desc0, transforms, low_dim_idss, up_dim_idss);
}
__device__ static constexpr auto GetSrcOOBThreadScratchDescriptor()
{
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
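// the OOB scratch holds one validity flag per vectorized access, so its
// descriptor is packed over the per-access lengths rather than the full slice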
return make_naive_tensor_descriptor_packed(src_access_lengths);
}
__device__ static constexpr auto GetDstThreadScratchDescriptor()
{
// 1st stage of transforms
......@@ -789,6 +868,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
private:
static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){};
static constexpr auto src_oob_thread_scratch_desc_ =
decltype(GetSrcThreadScratchDescriptor()){};
static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){};
using SrcThreadScratch =
......@@ -798,6 +879,13 @@ struct ThreadwiseTensorSliceTransfer_v3r1
decltype(src_thread_scratch_desc_),
true>;
using SrcOOBThreadScratch =
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
bool, // validity flags, consumed together with the SrcThreadScratch data during conversion
1,
decltype(src_oob_thread_scratch_desc_),
true>;
using DstThreadScratch = StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
DstData,
DstScalarPerVector,
......@@ -805,6 +893,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
true>;
StaticallyIndexedArray<SrcThreadScratch, NumThreadScratch> src_thread_scratch_tuple_;
StaticallyIndexedArray<SrcOOBThreadScratch, NumThreadScratch> src_oob_thread_scratch_tuple_;
DstThreadScratch dst_thread_scratch_;
......
......@@ -163,6 +163,13 @@ struct scalar_type<bf8_t>
static constexpr index_t vector_size = 1;
};
template <>
struct scalar_type<bool>
{
using type = bool;
static constexpr index_t vector_size = 1;
};
template <typename T>
struct vector_type<T, 1>
{
......
......@@ -10,10 +10,12 @@ namespace ck {
__device__ void block_sync_lds()
{
#if CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
asm volatile("\
s_waitcnt lgkmcnt(0) \n \
s_barrier \
" ::);
// asm volatile("\
// s_waitcnt lgkmcnt(0) \n \
// s_barrier \
// " ::);
__builtin_amdgcn_s_waitcnt(0xc07f);
__builtin_amdgcn_s_barrier();
#else
__syncthreads();
#endif
......
......@@ -43,6 +43,8 @@ __host__ __device__ constexpr Y bit_cast(const X& x)
#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_BIT_CAST
Y y;
// auto t = reinterpret_cast<const Y*>(&x);
// y = *t;
__builtin_memcpy(&y, &x, sizeof(X));
return y;
......
......@@ -21,7 +21,7 @@ template <typename ADataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename ComputeTypeA = ADataType,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA>
struct ReferenceGemm : public device::BaseOperator
{
......