Commit f0759faf authored by Jun Liu's avatar Jun Liu
Browse files

Merge branch 'develop' into amd-develop

parents 20ddaeba 764164b4
...@@ -1080,6 +1080,19 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -1080,6 +1080,19 @@ struct GridwiseGemm_xdl_cshuffle_v3
} }
} }
if constexpr(is_same<remove_cvref_t<CDataType>, bhalf_t>::value)
{
    // Per the message below, KBatch > 1 (split-K) is not supported yet for a
    // bhalf_t C tensor, so such configurations are rejected here.
    if(karg.KBatch > 1)
    {
#if DEBUG_LOG
        // Log only when the unsupported configuration is actually hit
        // (previously this printed unconditionally, even for KBatch == 1).
        std::cout << " KBatch: " << karg.KBatch << " > 1 is not supported yet " << __FILE__
                  << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
#endif // DEBUG_LOG
        return false;
    }
}
// check gridwise gemm pipeline // check gridwise gemm pipeline
const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
......
...@@ -42,7 +42,8 @@ template <typename SrcDatas, ...@@ -42,7 +42,8 @@ template <typename SrcDatas,
index_t SrcScalarPerVector, index_t SrcScalarPerVector,
index_t DstScalarPerVector, index_t DstScalarPerVector,
typename SrcResetCoordinateAfterRunFlags, // Sequence<bool ...> typename SrcResetCoordinateAfterRunFlags, // Sequence<bool ...>
typename DstResetCoordinateAfterRunFlags> // Sequence<bool ...> typename DstResetCoordinateAfterRunFlags, // Sequence<bool ...>
index_t NumThreadScratch = 1>
struct ThreadwiseTensorSliceTransfer_v7r2 struct ThreadwiseTensorSliceTransfer_v7r2
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -139,14 +140,19 @@ struct ThreadwiseTensorSliceTransfer_v7r2 ...@@ -139,14 +140,19 @@ struct ThreadwiseTensorSliceTransfer_v7r2
// SrcDescs: Tuple<const SrcDesc0&, const SrcDesc1&, ...> // SrcDescs: Tuple<const SrcDesc0&, const SrcDesc1&, ...>
// SrcBuffers: Tuple<const SrcBuffer0&, const SrcBuffer1&, ...> // SrcBuffers: Tuple<const SrcBuffer0&, const SrcBuffer1&, ...>
template <typename SrcBuffers, template <typename SrcBuffers,
index_t ThreadScratchId = 0,
enable_if_t<SrcDescs::Size() == SrcBuffers::Size(), bool> = false> enable_if_t<SrcDescs::Size() == SrcBuffers::Size(), bool> = false>
__device__ void RunRead(const SrcDescs& src_descs, const SrcBuffers& src_bufs) __device__ void RunRead(const SrcDescs& src_descs,
const SrcBuffers& src_bufs,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{ {
// loop over space-filling curve // loop over space-filling curve
static_for<0, src_num_access, 1>{}([&](auto iAccess) { static_for<0, src_num_access, 1>{}([&](auto iAccess) {
auto src_vectors = generate_vectors<SrcDatas, SrcScalarPerVector>(); auto src_vectors = generate_vectors<SrcDatas, SrcScalarPerVector>();
auto elm_vectors = generate_vectors<DstDatas, SrcScalarPerVector>(); auto elm_vectors = generate_vectors<DstDatas, SrcScalarPerVector>();
bool oob_val = true;
// copy data from src_bufs into src_vectors // copy data from src_bufs into src_vectors
static_for<0, nSrc, 1>{}([&](auto i) { static_for<0, nSrc, 1>{}([&](auto i) {
using src_vector_t = typename remove_cvref_t<decltype(src_vectors[i])>::type; using src_vector_t = typename remove_cvref_t<decltype(src_vectors[i])>::type;
...@@ -155,9 +161,10 @@ struct ThreadwiseTensorSliceTransfer_v7r2 ...@@ -155,9 +161,10 @@ struct ThreadwiseTensorSliceTransfer_v7r2
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_descs[i], coordinate_has_valid_offset_assuming_visible_index_is_valid(src_descs[i],
src_coords_[i]); src_coords_[i]);
oob_val = oob_val & is_src_valid;
src_vectors(i).template AsType<src_vector_t>()(I0) = src_vectors(i).template AsType<src_vector_t>()(I0) =
src_bufs[i].template Get<src_vector_t>(src_coords_[i].GetOffset(), src_bufs[i].template Get<src_vector_t>(src_coords_[i].GetOffset(), true);
is_src_valid);
}); });
constexpr auto get_elem_op_vec_len = []() { constexpr auto get_elem_op_vec_len = []() {
...@@ -218,7 +225,8 @@ struct ThreadwiseTensorSliceTransfer_v7r2 ...@@ -218,7 +225,8 @@ struct ThreadwiseTensorSliceTransfer_v7r2
unpack2(element_op_, dst_data_refs, src_data_refs); unpack2(element_op_, dst_data_refs, src_data_refs);
}); });
elm_vectors_tuple_(iAccess) = elm_vectors; elm_vectors_tuple_(thread_scratch_id)(iAccess) = elm_vectors;
oob_vectors_tuple_(thread_scratch_id)(iAccess) = oob_val;
// move coordinate // move coordinate
if constexpr(iAccess.value != src_num_access - 1) if constexpr(iAccess.value != src_num_access - 1)
...@@ -245,17 +253,38 @@ struct ThreadwiseTensorSliceTransfer_v7r2 ...@@ -245,17 +253,38 @@ struct ThreadwiseTensorSliceTransfer_v7r2
}); });
} }
__device__ void TransposeFromElmToDst() #if 1
// Masks out the results of out-of-bounds source reads for scratch slot
// `thread_scratch_id`: RunRead() loads with the per-element validity check
// disabled (it passes `true` to Get) and records per-access validity in
// oob_vectors_tuple_; this pass zero-fills every element-op result vector
// whose source access was invalid, so RunWrite() never stores garbage data.
template <index_t ThreadScratchId = 0>
__device__ void OOBCheck(Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
// loop over space-filling curve
static_for<0, src_num_access, 1>{}([&](auto iAccess) {
auto elm_vectors = elm_vectors_tuple_[thread_scratch_id][iAccess];
auto oob_val = oob_vectors_tuple_[thread_scratch_id][iAccess];
static_for<0, nDst, 1>{}([&](auto i) {
using elm_vector_t = typename remove_cvref_t<decltype(elm_vectors[i])>::type;
// Keep the vector when the whole access was in-bounds; otherwise
// replace it with a zero-initialized vector.
elm_vectors(i).template AsType<elm_vector_t>()(I0) =
oob_val ? elm_vectors(i).template AsType<elm_vector_t>()[I0] : elm_vector_t{0};
});
// Write the (possibly zeroed) vectors back into the scratch slot.
elm_vectors_tuple_(thread_scratch_id)(iAccess) = elm_vectors;
});
}
#endif
template <index_t ThreadScratchId = 0>
__device__ void
TransposeFromElmToDst(Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{ {
using DstData = remove_cvref_t<decltype(DstDatas{}[I0])>; using DstData = remove_cvref_t<decltype(DstDatas{}[I0])>;
using SrcThreadScratch = using ElmThreadScratch =
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr, StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
DstData, DstData,
SrcScalarPerVector, SrcScalarPerVector,
decltype(GetSrcThreadScratchDescriptor()), decltype(GetSrcThreadScratchDescriptor()),
true>; true>;
using DstThreadScratch = using DstThreadScratch =
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr, StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
DstData, DstData,
...@@ -263,15 +292,17 @@ struct ThreadwiseTensorSliceTransfer_v7r2 ...@@ -263,15 +292,17 @@ struct ThreadwiseTensorSliceTransfer_v7r2
decltype(GetDstThreadScratchDescriptor()), decltype(GetDstThreadScratchDescriptor()),
true>; true>;
SrcThreadScratch elm_thread_scratch_; ElmThreadScratch elm_thread_scratch_;
DstThreadScratch dst_thread_scratch_; DstThreadScratch dst_thread_scratch_;
elm_thread_scratch_.data_ = elm_thread_scratch_.data_ =
bit_cast<decltype(elm_thread_scratch_.data_)>(elm_vectors_tuple_); bit_cast<decltype(elm_thread_scratch_.data_)>(elm_vectors_tuple_[thread_scratch_id]);
if constexpr(SrcVectorDim != DstVectorDim && if constexpr(SrcVectorDim != DstVectorDim &&
((is_same<half_t, remove_cvref_t<DstData>>::value && ((is_same<half_t, remove_cvref_t<DstData>>::value &&
SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) || SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) ||
(is_same<f8_t, remove_cvref_t<DstData>>::value &&
SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0) ||
(is_same<int8_t, remove_cvref_t<DstData>>::value && (is_same<int8_t, remove_cvref_t<DstData>>::value &&
SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0))) SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0)))
{ {
...@@ -338,20 +369,24 @@ struct ThreadwiseTensorSliceTransfer_v7r2 ...@@ -338,20 +369,24 @@ struct ThreadwiseTensorSliceTransfer_v7r2
[&](auto idx) { dst_thread_scratch_(idx) = elm_thread_scratch_[idx]; }); [&](auto idx) { dst_thread_scratch_(idx) = elm_thread_scratch_[idx]; });
} }
dst_vectors_tuple_ = bit_cast<decltype(dst_vectors_tuple_)>(dst_thread_scratch_.data_); dst_vectors_tuple_(thread_scratch_id) = bit_cast<DstVectorTuple>(dst_thread_scratch_.data_);
} }
// DstDescs: Tuple<const DstDesc0&, const DstDesc1&, ...> // DstDescs: Tuple<const DstDesc0&, const DstDesc1&, ...>
// DstBuffers: Tuple<const DstBuffer0&, const DstBuffer1&, ...> // DstBuffers: Tuple<const DstBuffer0&, const DstBuffer1&, ...>
template <typename DstBuffers, template <typename DstBuffers,
index_t ThreadScratchId = 0,
enable_if_t<DstDescs::Size() == 1 && DstBuffers::Size() == 1, bool> = false> enable_if_t<DstDescs::Size() == 1 && DstBuffers::Size() == 1, bool> = false>
__device__ void RunWrite(const DstDescs& dst_descs, DstBuffers dst_bufs) __device__ void RunWrite(const DstDescs& dst_descs,
DstBuffers dst_bufs,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{ {
TransposeFromElmToDst(); OOBCheck(thread_scratch_id);
TransposeFromElmToDst(thread_scratch_id);
// loop over space-filling curve // loop over space-filling curve
static_for<0, dst_num_access, 1>{}([&](auto iAccess) { static_for<0, dst_num_access, 1>{}([&](auto iAccess) {
auto dst_vectors = dst_vectors_tuple_[Number<iAccess>{}]; auto dst_vectors = dst_vectors_tuple_[thread_scratch_id][iAccess];
// copy data from buf_vectors into dst_bufs // copy data from buf_vectors into dst_bufs
static_for<0, nDst, 1>{}([&](auto i) { static_for<0, nDst, 1>{}([&](auto i) {
...@@ -578,8 +613,14 @@ struct ThreadwiseTensorSliceTransfer_v7r2 ...@@ -578,8 +613,14 @@ struct ThreadwiseTensorSliceTransfer_v7r2
static constexpr auto src_num_access = SrcSpaceFillingCurve::GetNumOfAccess(); static constexpr auto src_num_access = SrcSpaceFillingCurve::GetNumOfAccess();
static constexpr auto dst_num_access = DstSpaceFillingCurve::GetNumOfAccess(); static constexpr auto dst_num_access = DstSpaceFillingCurve::GetNumOfAccess();
StaticallyIndexedArray<ElmVectorsType, src_num_access> elm_vectors_tuple_; using ElmVectorTuple = StaticallyIndexedArray<ElmVectorsType, src_num_access>;
StaticallyIndexedArray<DstVectorsType, dst_num_access> dst_vectors_tuple_; using DstVectorTuple = StaticallyIndexedArray<DstVectorsType, dst_num_access>;
StaticallyIndexedArray<ElmVectorTuple, NumThreadScratch> elm_vectors_tuple_;
StaticallyIndexedArray<DstVectorTuple, NumThreadScratch> dst_vectors_tuple_;
using OOBVectorTuple = StaticallyIndexedArray<bool, src_num_access>;
StaticallyIndexedArray<OOBVectorTuple, NumThreadScratch> oob_vectors_tuple_;
SrcCoords src_coords_; SrcCoords src_coords_;
DstCoords dst_coords_; DstCoords dst_coords_;
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#ifndef UTILITY_DEBUG_HPP #ifndef UTILITY_DEBUG_HPP
#define UTILITY_DEBUG_HPP #define UTILITY_DEBUG_HPP
...@@ -79,6 +79,13 @@ __device__ void print_shared(T const* p_shared, index_t num_elements) ...@@ -79,6 +79,13 @@ __device__ void print_shared(T const* p_shared, index_t num_elements)
__syncthreads(); __syncthreads();
} }
// Returns true when the calling thread's 1-D local id equals any of the
// compile-time ids in the Ids pack (false for an empty pack).
template <index_t... Ids>
__device__ static bool is_thread_local_1d_id_idx()
{
    const auto local_tid = get_thread_local_1d_id();

    bool found = false;
    ((found = found || (local_tid == Ids)), ...);
    return found;
}
} // namespace debug } // namespace debug
} // namespace ck } // namespace ck
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <hip/hip_runtime.h>
namespace ck {
// Kernel whose sole effect is executing the AMD scalar instruction
// `s_icache_inv`, which invalidates the instruction cache; the trailing
// `s_nop`s pad the pipeline after the invalidation.
// NOTE(review): presumably launched between timed runs so each iteration
// starts with a cold i-cache — confirm against callers.
static __global__ void flush_icache()
{
asm __volatile__("s_icache_inv \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t"
"s_nop 0 \n\t" ::
:);
}
} // namespace ck
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <ostream>
#pragma once #pragma once
...@@ -24,3 +25,14 @@ constexpr LoopScheduler make_default_loop_scheduler() ...@@ -24,3 +25,14 @@ constexpr LoopScheduler make_default_loop_scheduler()
} }
} // namespace ck } // namespace ck
// Streams a human-readable name for a LoopScheduler value; unrecognized
// values produce no output.
inline std::ostream& operator<<(std::ostream& os, const ck::LoopScheduler& s)
{
    if(s == ck::LoopScheduler::Default)
    {
        os << "Default";
    }
    else if(s == ck::LoopScheduler::Interwave)
    {
        os << "Interwave";
    }
    else
    {
        os << "";
    }
    return os;
}
This diff is collapsed.
This diff is collapsed.
...@@ -4,7 +4,8 @@ set(GEMM_MULTI_ABD_INSTANCES) ...@@ -4,7 +4,8 @@ set(GEMM_MULTI_ABD_INSTANCES)
list(APPEND GEMM_MULTI_ABD_INSTANCES list(APPEND GEMM_MULTI_ABD_INSTANCES
device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_nk_mn_v1_instance.cpp device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_mk_nk_mn_v1_instance.cpp
device_gemm_xdl_multi_abd_bias_gelu_bf16_i8_bf16_km_kn_mn_v1_instance.cpp
device_gemm_xdl_multi_abd_multiply_bias_gelu_bf16_i8_bf16_mk_kn_mn_v1_instance.cpp
) )
add_instance_library(device_gemm_multi_abd_instance ${GEMM_MULTI_ABD_INSTANCES}) add_instance_library(device_gemm_multi_abd_instance ${GEMM_MULTI_ABD_INSTANCES})
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment