Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
2b840f5a
Commit
2b840f5a
authored
Nov 18, 2024
by
aska-0096
Browse files
reduce prefetch stage in blockwisepipev4
parent
925c0719
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
47 additions
and
66 deletions
+47
-66
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp
...operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp
+47
-66
No files found.
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp
View file @
2b840f5a
...
@@ -8,7 +8,7 @@
...
@@ -8,7 +8,7 @@
namespace
ck
{
namespace
ck
{
// Compute optimimal pipeline with highest resource request
// Compute optimimal pipeline with highest resource request
// GlobalPrefetchStages:
4
// GlobalPrefetchStages:
3
// LocalPreFillStages: 2
// LocalPreFillStages: 2
// LocalPreFetchStages: 1
// LocalPreFetchStages: 1
// LocalSharedMemoryBuffer: 2
// LocalSharedMemoryBuffer: 2
...
@@ -142,9 +142,9 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
...
@@ -142,9 +142,9 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
using
Base
::
AMmaKStride
;
using
Base
::
AMmaKStride
;
using
Base
::
BMmaKStride
;
using
Base
::
BMmaKStride
;
static
constexpr
index_t
PrefetchStages
=
4
;
static
constexpr
index_t
PrefetchStages
=
3
;
static
constexpr
index_t
PrefillStages
=
2
;
static
constexpr
index_t
PrefillStages
=
2
;
static
constexpr
index_t
GlobalBufferNum
=
2
;
static
constexpr
index_t
GlobalBufferNum
=
1
;
static
constexpr
index_t
HotloopUnroll
=
2
;
static
constexpr
index_t
HotloopUnroll
=
2
;
__host__
__device__
static
constexpr
bool
BlockHasHotloop
(
index_t
num_loop
)
__host__
__device__
static
constexpr
bool
BlockHasHotloop
(
index_t
num_loop
)
...
@@ -164,8 +164,7 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
...
@@ -164,8 +164,7 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
}
}
}
}
template
<
typename
ScheduleGroup
>
__device__
static
constexpr
void
HotLoopScheduler
()
__device__
static
constexpr
void
HotLoopScheduler
(
ScheduleGroup
schedule_group
)
{
{
// TODO: Take data type into consideration as pipe ver 3
// TODO: Take data type into consideration as pipe ver 3
// A-B splited schedule
// A-B splited schedule
...
@@ -195,42 +194,42 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
...
@@ -195,42 +194,42 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
ignore
=
i
;
ignore
=
i
;
static_for
<
0
,
num_dsread_per_issue_a
,
1
>
{}([
&
](
auto
idsread
)
{
static_for
<
0
,
num_dsread_per_issue_a
,
1
>
{}([
&
](
auto
idsread
)
{
ignore
=
idsread
;
ignore
=
idsread
;
__builtin_amdgcn_sched_group_barrier
(
0x100
,
1
,
schedule_group
);
// DS read
__builtin_amdgcn_sched_group_barrier
(
0x100
,
1
,
0
);
// DS read
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
schedule_group
);
// MFMA
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
});
});
static_for
<
0
,
num_dswrite_per_issue_a
,
1
>
{}([
&
](
auto
idswrite
)
{
static_for
<
0
,
num_dswrite_per_issue_a
,
1
>
{}([
&
](
auto
idswrite
)
{
ignore
=
idswrite
;
ignore
=
idswrite
;
__builtin_amdgcn_sched_group_barrier
(
0x200
,
1
,
schedule_group
);
// DS write
__builtin_amdgcn_sched_group_barrier
(
0x200
,
1
,
0
);
// DS write
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
schedule_group
);
// MFMA
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
});
});
__builtin_amdgcn_sched_group_barrier
(
0x020
,
1
,
schedule_group
);
// VMEM read
__builtin_amdgcn_sched_group_barrier
(
0x020
,
1
,
0
);
// VMEM read
__builtin_amdgcn_sched_group_barrier
(
0x008
,
__builtin_amdgcn_sched_group_barrier
(
0x008
,
num_mfma_per_issue
-
num_dsread_per_issue_a
-
num_mfma_per_issue
-
num_dsread_per_issue_a
-
num_dswrite_per_issue_a
,
num_dswrite_per_issue_a
,
schedule_group
);
// MFMA
0
);
// MFMA
});
});
static_for
<
0
,
num_issue_b
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
num_issue_b
,
1
>
{}([
&
](
auto
i
)
{
ignore
=
i
;
ignore
=
i
;
static_for
<
0
,
num_dsread_per_issue_b
,
1
>
{}([
&
](
auto
idsread
)
{
static_for
<
0
,
num_dsread_per_issue_b
,
1
>
{}([
&
](
auto
idsread
)
{
ignore
=
idsread
;
ignore
=
idsread
;
__builtin_amdgcn_sched_group_barrier
(
0x100
,
1
,
schedule_group
);
// DS read
__builtin_amdgcn_sched_group_barrier
(
0x100
,
1
,
0
);
// DS read
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
schedule_group
);
// MFMA
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
});
});
static_for
<
0
,
num_dswrite_per_issue_b
,
1
>
{}([
&
](
auto
idswrite
)
{
static_for
<
0
,
num_dswrite_per_issue_b
,
1
>
{}([
&
](
auto
idswrite
)
{
ignore
=
idswrite
;
ignore
=
idswrite
;
__builtin_amdgcn_sched_group_barrier
(
0x200
,
1
,
schedule_group
);
// DS write
__builtin_amdgcn_sched_group_barrier
(
0x200
,
1
,
0
);
// DS write
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
schedule_group
);
// MFMA
__builtin_amdgcn_sched_group_barrier
(
0x008
,
1
,
0
);
// MFMA
});
});
__builtin_amdgcn_sched_group_barrier
(
0x020
,
1
,
schedule_group
);
// VMEM read
__builtin_amdgcn_sched_group_barrier
(
0x020
,
1
,
0
);
// VMEM read
__builtin_amdgcn_sched_group_barrier
(
0x008
,
__builtin_amdgcn_sched_group_barrier
(
0x008
,
num_mfma_per_issue
-
num_dsread_per_issue_a
-
num_mfma_per_issue
-
num_dsread_per_issue_a
-
num_dswrite_per_issue_b
,
num_dswrite_per_issue_b
,
schedule_group
);
// MFMA
0
);
// MFMA
});
});
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
}
}
...
@@ -274,26 +273,15 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
...
@@ -274,26 +273,15 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
StaticallyIndexedArray
<
decltype
(
b_thread_buf
),
Number
<
2
>
{}
>
b_thread_bufs
;
StaticallyIndexedArray
<
decltype
(
b_thread_buf
),
Number
<
2
>
{}
>
b_thread_bufs
;
// Global prefetch 1
// Global prefetch 1
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
I0
);
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
I0
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// Global prefetch 2
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
I1
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
I1
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// Local prefill 1
// Local prefill 1
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
.
At
(
I0
),
I0
);
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
.
At
(
I0
));
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
.
At
(
I0
),
I0
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
.
At
(
I0
));
// Local prefill 2
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
.
At
(
I1
),
I1
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
.
At
(
I1
),
I1
);
// Local prefetch 1
// Local prefetch 1
block_sync_lds
();
block_sync_lds
();
...
@@ -316,16 +304,20 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
...
@@ -316,16 +304,20 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
});
});
});
});
// Global prefetch
3
// Global prefetch
2
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
I0
);
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
I0
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
// Global prefetch 4
// Local prefill 2
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
I1
);
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
.
At
(
I1
));
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
I1
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
.
At
(
I1
));
// Global prefetch 3
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
...
@@ -343,9 +335,7 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
...
@@ -343,9 +335,7 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
auto
LoopFunc
=
[
&
](
auto
lds_read_buf
,
auto
LoopFunc
=
[
&
](
auto
lds_read_buf
,
auto
lds_read_reg_buf
,
auto
lds_read_reg_buf
,
auto
lds_write_buf
,
auto
lds_write_buf
,
auto
vmem_buf
,
auto
mfma_reg_buf
)
{
auto
mfma_reg_buf
,
auto
schedule_group
)
{
block_sync_lds
();
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k
)
{
...
@@ -368,13 +358,11 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
...
@@ -368,13 +358,11 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
});
});
});
});
a_blockwise_copy
.
RunWrite
(
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
.
At
(
lds_write_buf
));
a_block_desc
,
a_block_buf
.
At
(
lds_write_buf
),
vmem_buf
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
.
At
(
lds_write_buf
));
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
.
At
(
lds_write_buf
),
vmem_buf
);
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
,
vmem_buf
);
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
vmem_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
a_blockwise_copy
.
MoveSrcSliceWindow
(
a_grid_desc
,
a_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
b_blockwise_copy
.
MoveSrcSliceWindow
(
b_grid_desc
,
b_block_copy_step
);
...
@@ -411,11 +399,11 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
...
@@ -411,11 +399,11 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
});
});
});
});
HotLoopScheduler
(
schedule_group
);
HotLoopScheduler
();
};
};
LoopFunc
(
I1
,
I1
,
I0
,
I0
,
I0
,
I0
);
LoopFunc
(
I1
,
I1
,
I0
,
I0
);
LoopFunc
(
I0
,
I0
,
I1
,
I1
,
I1
,
I0
);
LoopFunc
(
I0
,
I0
,
I1
,
I1
);
i
+=
HotloopUnroll
;
i
+=
HotloopUnroll
;
}
while
(
i
<
(
num_loop
-
PrefetchStages
));
}
while
(
i
<
(
num_loop
-
PrefetchStages
));
...
@@ -424,9 +412,7 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
...
@@ -424,9 +412,7 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
auto
ReadWriteCompFunc
=
[
&
](
auto
lds_read_buf
,
auto
ReadWriteCompFunc
=
[
&
](
auto
lds_read_buf
,
auto
lds_read_reg_buf
,
auto
lds_read_reg_buf
,
auto
lds_write_buf
,
auto
lds_write_buf
,
auto
vmem_buf
,
auto
mfma_reg_buf
)
{
auto
mfma_reg_buf
,
auto
schedule_group
)
{
block_sync_lds
();
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k
)
{
...
@@ -448,8 +434,8 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
...
@@ -448,8 +434,8 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
});
});
});
});
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
.
At
(
lds_write_buf
)
,
vmem_buf
);
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf
.
At
(
lds_write_buf
));
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
.
At
(
lds_write_buf
)
,
vmem_buf
);
b_blockwise_copy
.
RunWrite
(
b_block_desc
,
b_block_buf
.
At
(
lds_write_buf
));
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
...
@@ -479,13 +465,10 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
...
@@ -479,13 +465,10 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
});
});
});
});
HotLoopScheduler
(
schedule_group
);
HotLoopScheduler
();
};
};
auto
ReadCompFunc
=
[
&
](
auto
lds_read_buf
,
auto
ReadCompFunc
=
[
&
](
auto
lds_read_buf
,
auto
lds_read_reg_buf
,
auto
mfma_reg_buf
)
{
auto
lds_read_reg_buf
,
auto
mfma_reg_buf
,
auto
schedule_group
)
{
block_sync_lds
();
block_sync_lds
();
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k
)
{
...
@@ -535,7 +518,7 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
...
@@ -535,7 +518,7 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
});
});
});
});
HotLoopScheduler
(
schedule_group
);
HotLoopScheduler
();
};
};
auto
CompFunc
=
[
&
](
auto
mfma_reg_buf
)
{
auto
CompFunc
=
[
&
](
auto
mfma_reg_buf
)
{
...
@@ -570,15 +553,13 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
...
@@ -570,15 +553,13 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
// tail
// tail
if
constexpr
(
TailNum
==
TailNumber
::
Odd
)
if
constexpr
(
TailNum
==
TailNumber
::
Odd
)
{
{
ReadWriteCompFunc
(
I1
,
I1
,
I0
,
I0
,
I0
,
I1
);
ReadWriteCompFunc
(
I1
,
I1
,
I0
,
I0
);
ReadCompFunc
(
I0
,
I0
,
I1
,
I1
);
ReadCompFunc
(
I0
,
I0
,
I1
);
CompFunc
(
I0
);
CompFunc
(
I0
);
}
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Even
)
else
if
constexpr
(
TailNum
==
TailNumber
::
Even
)
{
{
ReadWriteCompFunc
(
I1
,
I1
,
I0
,
I0
,
I0
,
I1
);
ReadCompFunc
(
I1
,
I1
,
I0
);
ReadWriteCompFunc
(
I0
,
I0
,
I1
,
I1
,
I1
,
I1
);
ReadCompFunc
(
I1
,
I1
,
I0
,
I1
);
CompFunc
(
I1
);
CompFunc
(
I1
);
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment