Unverified Commit 0d0150db authored by zjing14's avatar zjing14 Committed by GitHub
Browse files

bf16A_Int8B with fastgelu/bias (#1264)

* changed the copy function to v7r2

* adding multi_abd

* in-progress

* add post-load oob check

* debugging

* adjust instances

* add run_lds

* add elementwise_op

* replace multi_abd_device with v3

* clean up

* clean

* clean

* Added LDSType

* profiling

* adjust oobcheck

* add missing file

* refactor

* clean

* add examples
parent b4032629
...@@ -42,7 +42,8 @@ template <typename SrcDatas, ...@@ -42,7 +42,8 @@ template <typename SrcDatas,
index_t SrcScalarPerVector, index_t SrcScalarPerVector,
index_t DstScalarPerVector, index_t DstScalarPerVector,
typename SrcResetCoordinateAfterRunFlags, // Sequence<bool ...> typename SrcResetCoordinateAfterRunFlags, // Sequence<bool ...>
typename DstResetCoordinateAfterRunFlags> // Sequence<bool ...> typename DstResetCoordinateAfterRunFlags, // Sequence<bool ...>
index_t NumThreadScratch = 1>
struct ThreadwiseTensorSliceTransfer_v7r2 struct ThreadwiseTensorSliceTransfer_v7r2
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
...@@ -139,14 +140,19 @@ struct ThreadwiseTensorSliceTransfer_v7r2 ...@@ -139,14 +140,19 @@ struct ThreadwiseTensorSliceTransfer_v7r2
// SrcDescs: Tuple<const SrcDesc0&, const SrcDesc1&, ...> // SrcDescs: Tuple<const SrcDesc0&, const SrcDesc1&, ...>
// SrcBuffers: Tuple<const SrcBuffer0&, const SrcBuffer1&, ...> // SrcBuffers: Tuple<const SrcBuffer0&, const SrcBuffer1&, ...>
template <typename SrcBuffers, template <typename SrcBuffers,
index_t ThreadScratchId = 0,
enable_if_t<SrcDescs::Size() == SrcBuffers::Size(), bool> = false> enable_if_t<SrcDescs::Size() == SrcBuffers::Size(), bool> = false>
__device__ void RunRead(const SrcDescs& src_descs, const SrcBuffers& src_bufs) __device__ void RunRead(const SrcDescs& src_descs,
const SrcBuffers& src_bufs,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{ {
// loop over space-filling curve // loop over space-filling curve
static_for<0, src_num_access, 1>{}([&](auto iAccess) { static_for<0, src_num_access, 1>{}([&](auto iAccess) {
auto src_vectors = generate_vectors<SrcDatas, SrcScalarPerVector>(); auto src_vectors = generate_vectors<SrcDatas, SrcScalarPerVector>();
auto elm_vectors = generate_vectors<DstDatas, SrcScalarPerVector>(); auto elm_vectors = generate_vectors<DstDatas, SrcScalarPerVector>();
bool oob_val = true;
// copy data from src_bufs into src_vectors // copy data from src_bufs into src_vectors
static_for<0, nSrc, 1>{}([&](auto i) { static_for<0, nSrc, 1>{}([&](auto i) {
using src_vector_t = typename remove_cvref_t<decltype(src_vectors[i])>::type; using src_vector_t = typename remove_cvref_t<decltype(src_vectors[i])>::type;
...@@ -155,9 +161,10 @@ struct ThreadwiseTensorSliceTransfer_v7r2 ...@@ -155,9 +161,10 @@ struct ThreadwiseTensorSliceTransfer_v7r2
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_descs[i], coordinate_has_valid_offset_assuming_visible_index_is_valid(src_descs[i],
src_coords_[i]); src_coords_[i]);
oob_val = oob_val & is_src_valid;
src_vectors(i).template AsType<src_vector_t>()(I0) = src_vectors(i).template AsType<src_vector_t>()(I0) =
src_bufs[i].template Get<src_vector_t>(src_coords_[i].GetOffset(), src_bufs[i].template Get<src_vector_t>(src_coords_[i].GetOffset(), true);
is_src_valid);
}); });
constexpr auto get_elem_op_vec_len = []() { constexpr auto get_elem_op_vec_len = []() {
...@@ -218,7 +225,8 @@ struct ThreadwiseTensorSliceTransfer_v7r2 ...@@ -218,7 +225,8 @@ struct ThreadwiseTensorSliceTransfer_v7r2
unpack2(element_op_, dst_data_refs, src_data_refs); unpack2(element_op_, dst_data_refs, src_data_refs);
}); });
elm_vectors_tuple_(iAccess) = elm_vectors; elm_vectors_tuple_(thread_scratch_id)(iAccess) = elm_vectors;
oob_vectors_tuple_(thread_scratch_id)(iAccess) = oob_val;
// move coordinate // move coordinate
if constexpr(iAccess.value != src_num_access - 1) if constexpr(iAccess.value != src_num_access - 1)
...@@ -245,17 +253,38 @@ struct ThreadwiseTensorSliceTransfer_v7r2 ...@@ -245,17 +253,38 @@ struct ThreadwiseTensorSliceTransfer_v7r2
}); });
} }
__device__ void TransposeFromElmToDst() #if 1
// Post-load out-of-bounds masking pass over one thread-scratch slot.
//
// RunRead loads every source access unconditionally (it passes `true` as the
// validity flag to Get) and records, per access, a single combined validity
// flag in oob_vectors_tuple_. This method then zeroes the element vectors of
// any access whose flag is false, so the subsequent transpose/write sees
// zeros instead of garbage for OOB accesses. RunWrite calls this before
// TransposeFromElmToDst.
//
// thread_scratch_id selects which of the NumThreadScratch scratch slots to
// process (compile-time index; defaults to slot 0).
template <index_t ThreadScratchId = 0>
__device__ void OOBCheck(Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
// loop over space-filling curve
static_for<0, src_num_access, 1>{}([&](auto iAccess) {
// Copy of this access's element vectors and its recorded validity flag.
auto elm_vectors = elm_vectors_tuple_[thread_scratch_id][iAccess];
auto oob_val = oob_vectors_tuple_[thread_scratch_id][iAccess];
static_for<0, nDst, 1>{}([&](auto i) {
using elm_vector_t = typename remove_cvref_t<decltype(elm_vectors[i])>::type;
// Keep the loaded vector only if the whole access was in bounds;
// otherwise replace it with a zero-initialized vector.
elm_vectors(i).template AsType<elm_vector_t>()(I0) =
oob_val ? elm_vectors(i).template AsType<elm_vector_t>()[I0] : elm_vector_t{0};
});
// Write the (possibly zeroed) vectors back into the scratch slot.
elm_vectors_tuple_(thread_scratch_id)(iAccess) = elm_vectors;
});
}
#endif
template <index_t ThreadScratchId = 0>
__device__ void
TransposeFromElmToDst(Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{ {
using DstData = remove_cvref_t<decltype(DstDatas{}[I0])>; using DstData = remove_cvref_t<decltype(DstDatas{}[I0])>;
using SrcThreadScratch = using ElmThreadScratch =
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr, StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
DstData, DstData,
SrcScalarPerVector, SrcScalarPerVector,
decltype(GetSrcThreadScratchDescriptor()), decltype(GetSrcThreadScratchDescriptor()),
true>; true>;
using DstThreadScratch = using DstThreadScratch =
StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr, StaticTensorTupleOfVectorBuffer<AddressSpaceEnum::Vgpr,
DstData, DstData,
...@@ -263,15 +292,17 @@ struct ThreadwiseTensorSliceTransfer_v7r2 ...@@ -263,15 +292,17 @@ struct ThreadwiseTensorSliceTransfer_v7r2
decltype(GetDstThreadScratchDescriptor()), decltype(GetDstThreadScratchDescriptor()),
true>; true>;
SrcThreadScratch elm_thread_scratch_; ElmThreadScratch elm_thread_scratch_;
DstThreadScratch dst_thread_scratch_; DstThreadScratch dst_thread_scratch_;
elm_thread_scratch_.data_ = elm_thread_scratch_.data_ =
bit_cast<decltype(elm_thread_scratch_.data_)>(elm_vectors_tuple_); bit_cast<decltype(elm_thread_scratch_.data_)>(elm_vectors_tuple_[thread_scratch_id]);
if constexpr(SrcVectorDim != DstVectorDim && if constexpr(SrcVectorDim != DstVectorDim &&
((is_same<half_t, remove_cvref_t<DstData>>::value && ((is_same<half_t, remove_cvref_t<DstData>>::value &&
SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) || SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) ||
(is_same<f8_t, remove_cvref_t<DstData>>::value &&
SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0) ||
(is_same<int8_t, remove_cvref_t<DstData>>::value && (is_same<int8_t, remove_cvref_t<DstData>>::value &&
SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0))) SrcScalarPerVector % 4 == 0 && DstScalarPerVector % 4 == 0)))
{ {
...@@ -338,20 +369,24 @@ struct ThreadwiseTensorSliceTransfer_v7r2 ...@@ -338,20 +369,24 @@ struct ThreadwiseTensorSliceTransfer_v7r2
[&](auto idx) { dst_thread_scratch_(idx) = elm_thread_scratch_[idx]; }); [&](auto idx) { dst_thread_scratch_(idx) = elm_thread_scratch_[idx]; });
} }
dst_vectors_tuple_ = bit_cast<decltype(dst_vectors_tuple_)>(dst_thread_scratch_.data_); dst_vectors_tuple_(thread_scratch_id) = bit_cast<DstVectorTuple>(dst_thread_scratch_.data_);
} }
// DstDescs: Tuple<const DstDesc0&, const DstDesc1&, ...> // DstDescs: Tuple<const DstDesc0&, const DstDesc1&, ...>
// DstBuffers: Tuple<const DstBuffer0&, const DstBuffer1&, ...> // DstBuffers: Tuple<const DstBuffer0&, const DstBuffer1&, ...>
template <typename DstBuffers, template <typename DstBuffers,
index_t ThreadScratchId = 0,
enable_if_t<DstDescs::Size() == 1 && DstBuffers::Size() == 1, bool> = false> enable_if_t<DstDescs::Size() == 1 && DstBuffers::Size() == 1, bool> = false>
__device__ void RunWrite(const DstDescs& dst_descs, DstBuffers dst_bufs) __device__ void RunWrite(const DstDescs& dst_descs,
DstBuffers dst_bufs,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{ {
TransposeFromElmToDst(); OOBCheck(thread_scratch_id);
TransposeFromElmToDst(thread_scratch_id);
// loop over space-filling curve // loop over space-filling curve
static_for<0, dst_num_access, 1>{}([&](auto iAccess) { static_for<0, dst_num_access, 1>{}([&](auto iAccess) {
auto dst_vectors = dst_vectors_tuple_[Number<iAccess>{}]; auto dst_vectors = dst_vectors_tuple_[thread_scratch_id][iAccess];
// copy data from buf_vectors into dst_bufs // copy data from buf_vectors into dst_bufs
static_for<0, nDst, 1>{}([&](auto i) { static_for<0, nDst, 1>{}([&](auto i) {
...@@ -578,8 +613,14 @@ struct ThreadwiseTensorSliceTransfer_v7r2 ...@@ -578,8 +613,14 @@ struct ThreadwiseTensorSliceTransfer_v7r2
static constexpr auto src_num_access = SrcSpaceFillingCurve::GetNumOfAccess(); static constexpr auto src_num_access = SrcSpaceFillingCurve::GetNumOfAccess();
static constexpr auto dst_num_access = DstSpaceFillingCurve::GetNumOfAccess(); static constexpr auto dst_num_access = DstSpaceFillingCurve::GetNumOfAccess();
StaticallyIndexedArray<ElmVectorsType, src_num_access> elm_vectors_tuple_; using ElmVectorTuple = StaticallyIndexedArray<ElmVectorsType, src_num_access>;
StaticallyIndexedArray<DstVectorsType, dst_num_access> dst_vectors_tuple_; using DstVectorTuple = StaticallyIndexedArray<DstVectorsType, dst_num_access>;
StaticallyIndexedArray<ElmVectorTuple, NumThreadScratch> elm_vectors_tuple_;
StaticallyIndexedArray<DstVectorTuple, NumThreadScratch> dst_vectors_tuple_;
using OOBVectorTuple = StaticallyIndexedArray<bool, src_num_access>;
StaticallyIndexedArray<OOBVectorTuple, NumThreadScratch> oob_vectors_tuple_;
SrcCoords src_coords_; SrcCoords src_coords_;
DstCoords dst_coords_; DstCoords dst_coords_;
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment