"docs/vscode:/vscode.git/clone" did not exist on "44c8d6cc60f081470ee2fb290edde49d427042f9"
Commit 3784329b authored by coderfeli's avatar coderfeli
Browse files

can run

parent 4a1ec815
...@@ -363,15 +363,15 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave, ...@@ -363,15 +363,15 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec; vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec; vector_type<ComputeDataType, KPack> b_thread_vec =
b_blockwise_copy.template GetSrcThreadScratchIdx<Sequence<0, k0, 0>>();
static_for<0, KPack, 1>{}([&](auto ik) { static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) = a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset( a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}]; make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) = // b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset( // b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}]; // make_tuple(n0, I0, k0, ik))>{}];
}); });
using mfma_input_type = using mfma_input_type =
...@@ -423,15 +423,15 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave, ...@@ -423,15 +423,15 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec; vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec; vector_type<ComputeDataType, KPack> b_thread_vec =
b_blockwise_copy.template GetSrcThreadScratchIdx<Sequence<0, k0, 0>>();
static_for<0, KPack, 1>{}([&](auto ik) { static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) = a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset( a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}]; make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) = // b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset( // b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}]; // make_tuple(n0, I0, k0, ik))>{}];
}); });
using mfma_input_type = using mfma_input_type =
......
...@@ -109,6 +109,12 @@ struct ThreadGroupTensorSliceTransfer_v4r1 ...@@ -109,6 +109,12 @@ struct ThreadGroupTensorSliceTransfer_v4r1
} }
} }
template <typename SeqIdx>
__device__ auto GetSrcThreadScratchIdx()
{
return threadwise_transfer_.template GetSrcThreadScratchIdx<SeqIdx>();
}
template <typename SrcBuffer, index_t ThreadScratchId = 0> template <typename SrcBuffer, index_t ThreadScratchId = 0>
__device__ void RunRead(const SrcDesc& src_desc, __device__ void RunRead(const SrcDesc& src_desc,
const SrcBuffer& src_buf, const SrcBuffer& src_buf,
......
...@@ -267,6 +267,13 @@ struct ThreadwiseTensorSliceTransfer_v3r1 ...@@ -267,6 +267,13 @@ struct ThreadwiseTensorSliceTransfer_v3r1
} }
} }
template <typename SeqIdx, index_t ThreadScratchId = 0>
__device__ auto GetSrcThreadScratchIdx(Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
using vector_t = typename vector_type_maker<SrcData, SrcScalarPerVector>::type::type;
return src_thread_scratch_tuple_(thread_scratch_id).template GetAsType<vector_t>(SeqIdx{});
}
template <index_t ThreadScratchId> template <index_t ThreadScratchId>
__device__ void __device__ void
TransferDataFromSrcThreadScratchToDstThreadScratch(Number<ThreadScratchId> thread_scratch_id) TransferDataFromSrcThreadScratchToDstThreadScratch(Number<ThreadScratchId> thread_scratch_id)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment