Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
4525c5d7
Commit
4525c5d7
authored
Dec 02, 2024
by
coderfeli
Browse files
merge upstream
parents
a8d88d8d
44828b7c
Changes
308
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1497 additions
and
581 deletions
+1497
-581
include/ck/library/utility/iterator.hpp
include/ck/library/utility/iterator.hpp
+0
-0
include/ck/library/utility/literals.hpp
include/ck/library/utility/literals.hpp
+0
-0
include/ck/library/utility/numeric.hpp
include/ck/library/utility/numeric.hpp
+0
-0
include/ck/library/utility/ranges.hpp
include/ck/library/utility/ranges.hpp
+0
-0
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp
...operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp
+64
-66
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp
...operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp
+32
-33
include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp
...de/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp
+131
-1
include/ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp
...sor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp
+13
-37
include/ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp
...tion/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp
+0
-136
include/ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp
...ensor_operation/gpu/device/device_grouped_gemm_splitk.hpp
+18
-2
include/ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp
...or_operation/gpu/device/device_grouped_gemm_tile_loop.hpp
+6
-86
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp
...n/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp
+309
-73
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp
...grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp
+58
-35
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp
...device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp
+21
-3
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp
...sor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp
+19
-2
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp
...tion/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp
+55
-17
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp
...u/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp
+36
-4
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp
...ration/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp
+733
-85
include/ck/utility/loop_scheduler.hpp
include/ck/utility/loop_scheduler.hpp
+0
-1
include/ck_tile/core.hpp
include/ck_tile/core.hpp
+2
-0
No files found.
library/
include/ck/library/utility/iterator.hpp
→
include/ck/library/utility/iterator.hpp
View file @
4525c5d7
File moved
library/
include/ck/library/utility/literals.hpp
→
include/ck/library/utility/literals.hpp
View file @
4525c5d7
File moved
library/
include/ck/library/utility/numeric.hpp
→
include/ck/library/utility/numeric.hpp
View file @
4525c5d7
File moved
library/
include/ck/library/utility/ranges.hpp
→
include/ck/library/utility/ranges.hpp
View file @
4525c5d7
File moved
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp
View file @
4525c5d7
...
@@ -269,15 +269,14 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
...
@@ -269,15 +269,14 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
a_thread_desc_
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k
,
I0
),
make_tuple
(
m0
,
I0
,
k
,
I0
),
a_thread_buf
);
a_thread_buf
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
});
b_thread_copy_
.
Run
(
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_block_desc_n0_n1_n2_k
,
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k
*
BMmaKStride
>
{}),
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k
*
BMmaKStride
>
{}),
b_block_buf
,
b_block_buf
,
b_thread_desc_
,
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k
,
I0
),
make_tuple
(
n0
,
I0
,
k
,
I0
),
b_thread_buf
);
b_thread_buf
);
});
});
});
});
});
...
@@ -341,14 +340,14 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
...
@@ -341,14 +340,14 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
a_thread_desc_
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k
,
I0
),
make_tuple
(
m0
,
I0
,
k
,
I0
),
a_thread_buf
);
a_thread_buf
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
});
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k
*
BMmaKStride
>
{})
,
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
b_block_buf
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k
*
BMmaKStride
>
{})
,
b_thread_desc_
,
b_block_buf
,
make_tuple
(
n0
,
I0
,
k
,
I0
)
,
b_thread_desc_
,
b_thread_buf
);
make_tuple
(
n0
,
I0
,
k
,
I0
),
}
);
b_thread_buf
);
});
});
});
});
...
@@ -396,14 +395,14 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
...
@@ -396,14 +395,14 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
a_thread_desc_
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k
,
I0
),
make_tuple
(
m0
,
I0
,
k
,
I0
),
a_thread_buf
);
a_thread_buf
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
});
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k
*
BMmaKStride
>
{})
,
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
b_block_buf
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k
*
BMmaKStride
>
{})
,
b_thread_desc_
,
b_block_buf
,
make_tuple
(
n0
,
I0
,
k
,
I0
)
,
b_thread_desc_
,
b_thread_buf
);
make_tuple
(
n0
,
I0
,
k
,
I0
),
}
);
b_thread_buf
);
});
});
});
});
...
@@ -447,14 +446,14 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
...
@@ -447,14 +446,14 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
a_thread_desc_
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k
,
I0
),
make_tuple
(
m0
,
I0
,
k
,
I0
),
a_thread_buf
);
a_thread_buf
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
});
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k
*
BMmaKStride
>
{})
,
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
b_block_buf
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k
*
BMmaKStride
>
{})
,
b_thread_desc_
,
b_block_buf
,
make_tuple
(
n0
,
I0
,
k
,
I0
)
,
b_thread_desc_
,
b_thread_buf
);
make_tuple
(
n0
,
I0
,
k
,
I0
),
}
);
b_thread_buf
);
});
});
});
});
...
@@ -760,15 +759,14 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
...
@@ -760,15 +759,14 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
a_thread_desc_
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k0
,
I0
),
make_tuple
(
m0
,
I0
,
k0
,
I0
),
a_thread_buf
);
a_thread_buf
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
});
b_thread_copy_
.
Run
(
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_block_desc_n0_n1_n2_k
,
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k0
*
KPerInnerLoop
>
{}),
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k0
*
KPerInnerLoop
>
{}),
b_block_buf
,
b_block_buf
,
b_thread_desc_
,
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k0
,
I0
),
make_tuple
(
n0
,
I0
,
k0
,
I0
),
b_thread_buf
);
b_thread_buf
);
});
});
});
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
// NOTE: Synchronize threads in a workgroup at the start of each MAC
// NOTE: Synchronize threads in a workgroup at the start of each MAC
...
@@ -866,14 +864,14 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
...
@@ -866,14 +864,14 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
a_thread_desc_
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k0
,
I0
),
make_tuple
(
m0
,
I0
,
k0
,
I0
),
a_thread_buf
);
a_thread_buf
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
});
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k0
*
KPerInnerLoop
>
{})
,
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
b_block_buf
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k0
*
KPerInnerLoop
>
{})
,
b_thread_desc_
,
b_block_buf
,
make_tuple
(
n0
,
I0
,
k0
,
I0
)
,
b_thread_desc_
,
b_thread_buf
);
make_tuple
(
n0
,
I0
,
k0
,
I0
),
}
);
b_thread_buf
);
});
});
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
...
@@ -942,14 +940,14 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
...
@@ -942,14 +940,14 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
a_thread_desc_
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k0
,
I0
),
make_tuple
(
m0
,
I0
,
k0
,
I0
),
a_thread_buf
);
a_thread_buf
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
});
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k0
*
KPerInnerLoop
>
{})
,
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
b_block_buf
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k0
*
KPerInnerLoop
>
{})
,
b_thread_desc_
,
b_block_buf
,
make_tuple
(
n0
,
I0
,
k0
,
I0
)
,
b_thread_desc_
,
b_thread_buf
);
make_tuple
(
n0
,
I0
,
k0
,
I0
),
}
);
b_thread_buf
);
});
});
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
...
@@ -1018,14 +1016,14 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
...
@@ -1018,14 +1016,14 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
a_thread_desc_
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k0
,
I0
),
make_tuple
(
m0
,
I0
,
k0
,
I0
),
a_thread_buf
);
a_thread_buf
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
});
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k0
*
KPerInnerLoop
>
{})
,
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
b_block_buf
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k0
*
KPerInnerLoop
>
{})
,
b_thread_desc_
,
b_block_buf
,
make_tuple
(
n0
,
I0
,
k0
,
I0
)
,
b_thread_desc_
,
b_thread_buf
);
make_tuple
(
n0
,
I0
,
k0
,
I0
),
}
);
b_thread_buf
);
});
});
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp
View file @
4525c5d7
...
@@ -305,14 +305,14 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
...
@@ -305,14 +305,14 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
a_thread_desc_
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k
,
I0
),
make_tuple
(
m0
,
I0
,
k
,
I0
),
a_thread_bufs
(
I0
));
a_thread_bufs
(
I0
));
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
});
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k
*
BMmaKStride
>
{})
,
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
b_block_buf
.
At
(
I0
),
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k
*
BMmaKStride
>
{}
),
b_thread_desc_
,
b_block_buf
.
At
(
I0
)
,
make_tuple
(
n0
,
I0
,
k
,
I0
)
,
b_thread_desc_
,
b_thread_bufs
(
I0
));
make_tuple
(
n0
,
I0
,
k
,
I0
),
}
);
b_thread_bufs
(
I0
)
);
});
});
});
});
...
@@ -356,15 +356,14 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
...
@@ -356,15 +356,14 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
a_thread_desc_
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k
,
I0
),
make_tuple
(
m0
,
I0
,
k
,
I0
),
a_thread_bufs
(
lds_read_reg_buf
));
a_thread_bufs
(
lds_read_reg_buf
));
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
});
b_thread_copy_
.
Run
(
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
b_block_desc_n0_n1_n2_k
,
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k
*
BMmaKStride
>
{}),
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k
*
BMmaKStride
>
{}),
b_block_buf
.
At
(
lds_read_buf
),
b_block_buf
.
At
(
lds_read_buf
),
b_thread_desc_
,
b_thread_desc_
,
make_tuple
(
n0
,
I0
,
k
,
I0
),
make_tuple
(
n0
,
I0
,
k
,
I0
),
b_thread_bufs
(
lds_read_reg_buf
));
b_thread_bufs
(
lds_read_reg_buf
));
});
});
});
});
});
...
@@ -437,14 +436,14 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
...
@@ -437,14 +436,14 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
a_thread_desc_
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k
,
I0
),
make_tuple
(
m0
,
I0
,
k
,
I0
),
a_thread_bufs
(
lds_read_reg_buf
));
a_thread_bufs
(
lds_read_reg_buf
));
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
});
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k
*
BMmaKStride
>
{})
,
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
b_block_buf
.
At
(
lds_read_buf
),
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k
*
BMmaKStride
>
{}
),
b_thread_desc_
,
b_block_buf
.
At
(
lds_read_buf
)
,
make_tuple
(
n0
,
I0
,
k
,
I0
)
,
b_thread_desc_
,
b_thread_bufs
(
lds_read_reg_buf
));
make_tuple
(
n0
,
I0
,
k
,
I0
),
}
);
b_thread_bufs
(
lds_read_reg_buf
)
);
});
});
});
});
...
@@ -496,14 +495,14 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
...
@@ -496,14 +495,14 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
a_thread_desc_
,
a_thread_desc_
,
make_tuple
(
m0
,
I0
,
k
,
I0
),
make_tuple
(
m0
,
I0
,
k
,
I0
),
a_thread_bufs
(
lds_read_reg_buf
));
a_thread_bufs
(
lds_read_reg_buf
));
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
});
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k
*
BMmaKStride
>
{})
,
b_thread_copy_
.
Run
(
b_block_desc_n0_n1_n2_k
,
b_block_buf
.
At
(
lds_read_buf
),
make_tuple
(
n0
,
I0
,
I0
,
Number
<
k
*
BMmaKStride
>
{}
),
b_thread_desc_
,
b_block_buf
.
At
(
lds_read_buf
)
,
make_tuple
(
n0
,
I0
,
k
,
I0
)
,
b_thread_desc_
,
b_thread_bufs
(
lds_read_reg_buf
));
make_tuple
(
n0
,
I0
,
k
,
I0
),
}
);
b_thread_bufs
(
lds_read_reg_buf
)
);
});
});
});
});
...
...
include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp
View file @
4525c5d7
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#include <array>
#include <iostream>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <vector>
#include <vector>
#include "device_base.hpp"
#include "device_base.hpp"
#include "ck/utility/ignore.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
///
/// @brief Structure representing single GEMM problem arguments.
///
/// The pointer to the vector of those structures is passed to the GroupedGEMM entry
/// point kernel.
///
/// @tparam NumDTensor The number of D input tensors.
///
template
<
index_t
NumDTensor
=
0
>
struct
GroupedGemmKernelArgument
{
__host__
__device__
GroupedGemmKernelArgument
(
const
void
*
p_a_grid_
,
const
void
*
p_b_grid_
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds_grid_
,
void
*
p_e_grid_
,
index_t
M_
,
index_t
N_
,
index_t
K_
,
index_t
StrideA_
,
index_t
StrideB_
,
std
::
array
<
index_t
,
NumDTensor
>
StrideDs_
,
index_t
StrideE_
)
:
p_a_grid
{
p_a_grid_
},
p_b_grid
{
p_b_grid_
},
p_ds_grid
{
p_ds_grid_
},
p_e_grid
{
p_e_grid_
},
M
{
M_
},
N
{
N_
},
K
{
K_
},
StrideA
{
StrideA_
},
StrideB
{
StrideB_
},
StrideDs
{
StrideDs_
},
StrideE
{
StrideE_
}
{
}
const
void
*
p_a_grid
;
const
void
*
p_b_grid
;
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds_grid
;
void
*
p_e_grid
;
index_t
M
;
index_t
N
;
index_t
K
;
index_t
StrideA
;
index_t
StrideB
;
std
::
array
<
index_t
,
NumDTensor
>
StrideDs
;
index_t
StrideE
;
void
Print
()
const
{
std
::
stringstream
str
;
for
(
auto
sd
:
StrideDs
)
str
<<
sd
<<
","
;
std
::
cout
<<
"arg {"
<<
"M:"
<<
M
<<
", "
<<
"N:"
<<
N
<<
", "
<<
"K:"
<<
K
<<
", "
<<
"SA:"
<<
StrideA
<<
", "
<<
"SB:"
<<
StrideB
<<
", "
<<
"SE:"
<<
StrideE
<<
", "
<<
"SDs: {"
<<
str
.
str
()
<<
"}"
<<
"}"
<<
std
::
endl
;
}
};
struct
GemmDesc
struct
GemmDesc
{
{
ck
::
index_t
M_
,
N_
,
K_
;
ck
::
index_t
M_
,
N_
,
K_
;
...
@@ -48,6 +118,66 @@ struct DeviceGroupedGemm : public BaseOperator
...
@@ -48,6 +118,66 @@ struct DeviceGroupedGemm : public BaseOperator
CElementwiseOperation
c_element_op
)
=
0
;
CElementwiseOperation
c_element_op
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
//---------------------------------------------------------------------------------------------
/// @brief Sets the device kernel arguments pointer and may copy data to device.
///
/// TODO: Add which kernels are using this (TileLoop * FixedNK ??)
///
/// @param p_arg The pointer to the Argument we're going to update.
/// @param[in] p_dev_kernel_args The pointer to the device memory which will contain kernel
/// arguments.
/// @param[in] p_host_kernel_args The pointer to the host memory which contains kernel
/// arguments that should be copied to device memory.
///
virtual
void
SetDeviceKernelArgs
(
BaseArgument
*
p_arg
,
void
*
p_dev_kernel_args
,
const
void
*
p_host_kernel_args
)
const
{
ignore
=
p_arg
;
ignore
=
p_dev_kernel_args
;
ignore
=
p_host_kernel_args
;
std
::
ostringstream
err
;
err
<<
"This function is not implemented by the kernel: "
<<
this
->
GetTypeString
()
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
;
throw
std
::
runtime_error
(
err
.
str
());
}
//----------------------------------------------------------------------------------------------
/// @brief Sets the device kernel arguments pointer and may copy data to device.
///
/// @param p_arg The pointer to the Argument we're going to update.
/// @param[in] p_dev_kernel_args The pointer to the device memory which contains kernel
/// arguments.
///
virtual
void
SetDeviceKernelArgs
(
BaseArgument
*
p_arg
,
void
*
p_dev_kernel_args
)
const
{
ignore
=
p_arg
;
ignore
=
p_dev_kernel_args
;
std
::
ostringstream
err
;
err
<<
"This function is not implemented by the kernel: "
<<
this
->
GetTypeString
()
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
;
throw
std
::
runtime_error
(
err
.
str
());
}
//----------------------------------------------------------------------------------------------
/// @brief Gets the device kernel argument size.
///
/// @param[in] p_arg The pointer to the Device op Argument.
///
/// @return The device kernel argument size.
///
virtual
size_t
GetDeviceKernelArgSize
(
const
BaseArgument
*
p_arg
)
const
{
ignore
=
p_arg
;
std
::
ostringstream
err
;
err
<<
"This function is not implemented by the kernel: "
<<
this
->
GetTypeString
()
<<
__FILE__
<<
":"
<<
__LINE__
<<
", in function: "
<<
__func__
;
throw
std
::
runtime_error
(
err
.
str
());
}
};
};
}
// namespace device
}
// namespace device
...
...
include/ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp
View file @
4525c5d7
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#include <iostream>
#include "device_grouped_gemm_splitk.hpp"
#include <array>
#include "device_grouped_gemm.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
template
<
index_t
NumDTensor
=
0
>
struct
GroupedGemmKernelArgument
{
const
void
*
p_a_grid
;
const
void
*
p_b_grid
;
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds_grid
;
void
*
p_e_grid
;
index_t
M
;
index_t
N
;
index_t
K
;
index_t
StrideA
;
index_t
StrideB
;
std
::
array
<
index_t
,
NumDTensor
>
StrideDs
;
index_t
StrideE
;
};
template
<
typename
ALayout
,
template
<
typename
ALayout
,
typename
BLayout
,
typename
BLayout
,
typename
DsLayout
,
typename
DsLayout
,
...
@@ -41,21 +20,18 @@ template <typename ALayout,
...
@@ -41,21 +20,18 @@ template <typename ALayout,
typename
AElementwiseOperation
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
typename
CElementwiseOperation
>
struct
DeviceGroupedGemmFixedNK
:
DeviceGroupedGemm
<
ALayout
,
struct
DeviceGroupedGemmFixedNK
:
DeviceGroupedGemm
SplitK
<
ALayout
,
BLayout
,
BLayout
,
DsLayout
,
DsLayout
,
ELayout
,
ELayout
,
ADataType
,
ADataType
,
BDataType
,
BDataType
,
DsDataType
,
DsDataType
,
EDataType
,
EDataType
,
AElementwiseOperation
,
AElementwiseOperation
,
BElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
CElementwiseOperation
>
{
{
virtual
void
SetDeviceKernelArgs
(
BaseArgument
*
p_arg
,
const
void
*
kernel_args
)
const
=
0
;
virtual
size_t
GetDeviceKernelArgSize
(
const
BaseArgument
*
p_arg
)
const
=
0
;
virtual
void
SetKBatch
(
BaseArgument
*
p_arg
,
index_t
k_batch
)
const
=
0
;
};
};
}
// namespace device
}
// namespace device
...
...
include/ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp
deleted
100644 → 0
View file @
a8d88d8d
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include <iostream>
#include <vector>
#include <sstream>
#include "device_grouped_gemm.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
///
/// @brief Structure representing single GEMM problem arguments.
///
/// The pointer to the vector of those structures is passed to the GroupedGEMM entry
/// point kernel.
///
/// @tparam NumDTensor The number of D input tensors.
///
template
<
index_t
NumDTensor
=
0
>
struct
GroupedGemmMultipleDKernelArguments
{
__host__
__device__
GroupedGemmMultipleDKernelArguments
(
const
void
*
p_a_grid_
,
const
void
*
p_b_grid_
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds_grid_
,
void
*
p_e_grid_
,
index_t
M_
,
index_t
N_
,
index_t
K_
,
index_t
StrideA_
,
index_t
StrideB_
,
std
::
array
<
index_t
,
NumDTensor
>
StrideDs_
,
index_t
StrideE_
)
:
p_a_grid
{
p_a_grid_
},
p_b_grid
{
p_b_grid_
},
p_ds_grid
{
p_ds_grid_
},
p_e_grid
{
p_e_grid_
},
M
{
M_
},
N
{
N_
},
K
{
K_
},
StrideA
{
StrideA_
},
StrideB
{
StrideB_
},
StrideDs
{
StrideDs_
},
StrideE
{
StrideE_
}
{
}
const
void
*
p_a_grid
;
const
void
*
p_b_grid
;
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds_grid
;
void
*
p_e_grid
;
index_t
M
;
index_t
N
;
index_t
K
;
index_t
StrideA
;
index_t
StrideB
;
std
::
array
<
index_t
,
NumDTensor
>
StrideDs
;
index_t
StrideE
;
void
Print
()
const
{
std
::
stringstream
str
;
for
(
auto
sd
:
StrideDs
)
str
<<
sd
<<
","
;
std
::
cout
<<
"arg {"
<<
"M:"
<<
M
<<
", "
<<
"N:"
<<
N
<<
", "
<<
"K:"
<<
K
<<
", "
<<
"SA:"
<<
StrideA
<<
", "
<<
"SB:"
<<
StrideB
<<
", "
<<
"SE:"
<<
StrideE
<<
", "
<<
"SDs: {"
<<
str
.
str
()
<<
"}"
<<
"}"
<<
std
::
endl
;
}
};
template
<
typename
ALayout
,
typename
BLayout
,
typename
DsLayout
,
typename
ELayout
,
typename
ADataType
,
typename
BDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
>
struct
DeviceGroupedGemmMultipleDSplitK
:
public
DeviceGroupedGemm
<
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
ADataType
,
BDataType
,
DsDataType
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
>
{
//----------------------------------------------------------------------------------------------
/// @brief Sets the k batch size.
///
/// @param p_arg Pointer to the Argument we're going to change.
/// @param[in] kbatch The kbatch value.
///
virtual
void
SetKBatchSize
(
BaseArgument
*
p_arg
,
index_t
kbatch
)
const
=
0
;
//----------------------------------------------------------------------------------------------
/// @brief Sets the device kernel arguments pointer.
///
/// @param p_arg The pointer to the Argument we're going to update.
/// @param[in] p_dev_kernel_args The pointer to the device memory which contains kernel
/// arguments.
///
virtual
void
SetDeviceKernelArgs
(
BaseArgument
*
p_arg
,
void
*
p_dev_kernel_args
)
const
=
0
;
//----------------------------------------------------------------------------------------------
/// @brief Gets the device kernel argument size.
///
/// @param[in] p_arg The pointer to the Device op Argument.
///
/// @return The device kernel argument size.
///
virtual
size_t
GetDeviceKernelArgSize
(
const
BaseArgument
*
p_arg
)
const
=
0
;
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp
View file @
4525c5d7
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#include <iostream>
#include <vector>
#include "device_grouped_gemm.hpp"
#include "device_grouped_gemm.hpp"
...
@@ -31,7 +31,23 @@ struct DeviceGroupedGemmSplitK : public DeviceGroupedGemm<ALayout,
...
@@ -31,7 +31,23 @@ struct DeviceGroupedGemmSplitK : public DeviceGroupedGemm<ALayout,
BElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
CElementwiseOperation
>
{
{
//----------------------------------------------------------------------------------------------
/// @brief Sets the k batch size.
///
/// @param p_arg Pointer to the Argument we're going to change.
/// @param[in] kbatch The kbatch value.
///
virtual
void
SetKBatchSize
(
BaseArgument
*
p_arg
,
index_t
kbatch
)
const
=
0
;
virtual
void
SetKBatchSize
(
BaseArgument
*
p_arg
,
index_t
kbatch
)
const
=
0
;
//----------------------------------------------------------------------------------------------
/// @brief Sets the k batch size.
///
/// @param p_arg Pointer to the Argument we're going to change.
/// @param[in] kbatch The kbatch value.
///
virtual
void
SetKBatch
(
BaseArgument
*
p_arg
,
index_t
kbatch
)
const
{
this
->
SetKBatchSize
(
p_arg
,
kbatch
);
};
};
};
}
// namespace device
}
// namespace device
...
...
include/ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp
View file @
4525c5d7
...
@@ -3,83 +3,20 @@
...
@@ -3,83 +3,20 @@
#pragma once
#pragma once
#include <array>
#include <iostream>
#include <vector>
#include <sstream>
#include "device_grouped_gemm.hpp"
#include "device_grouped_gemm.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
/// @brief Grouped GEMM kernel using output Tile Looping algorithm
///
///
/// @brief Structure representing single GEMM problem arguments.
/// @par This kernel does not require any knowledge about input data sizes (GEMM M/N/K)
///
/// It requires only the number of groups to launch. Other information like
/// The pointer to the vector of those structures is passed to the GroupedGEMM entry
/// data pointers and GEMM sizes, packed into gemm kernel args may be all dynamic
/// point kernel.
/// (known only at kernel run-time).
///
/// @tparam NumDTensor The number of D input tensors.
///
///
template
<
index_t
NumDTensor
=
0
>
/// @note This kernel does not support SplitK.
struct
GroupedGemmTileLoopKernelArguments
{
__host__
__device__
GroupedGemmTileLoopKernelArguments
(
const
void
*
p_a_grid_
,
const
void
*
p_b_grid_
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds_grid_
,
void
*
p_e_grid_
,
index_t
M_
,
index_t
N_
,
index_t
K_
,
index_t
StrideA_
,
index_t
StrideB_
,
std
::
array
<
index_t
,
NumDTensor
>
StrideDs_
,
index_t
StrideE_
)
:
p_a_grid
{
p_a_grid_
},
p_b_grid
{
p_b_grid_
},
p_ds_grid
{
p_ds_grid_
},
p_e_grid
{
p_e_grid_
},
M
{
M_
},
N
{
N_
},
K
{
K_
},
StrideA
{
StrideA_
},
StrideB
{
StrideB_
},
StrideDs
{
StrideDs_
},
StrideE
{
StrideE_
}
{
}
const
void
*
p_a_grid
;
const
void
*
p_b_grid
;
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds_grid
;
void
*
p_e_grid
;
index_t
M
;
index_t
N
;
index_t
K
;
index_t
StrideA
;
index_t
StrideB
;
std
::
array
<
index_t
,
NumDTensor
>
StrideDs
;
index_t
StrideE
;
void
Print
()
const
{
std
::
stringstream
str
;
for
(
auto
sd
:
StrideDs
)
str
<<
sd
<<
","
;
std
::
cout
<<
"arg {"
<<
"M:"
<<
M
<<
", "
<<
"N:"
<<
N
<<
", "
<<
"K:"
<<
K
<<
", "
<<
"SA:"
<<
StrideA
<<
", "
<<
"SB:"
<<
StrideB
<<
", "
<<
"SE:"
<<
StrideE
<<
", "
<<
"SDs: {"
<<
str
.
str
()
<<
"}"
<<
"}"
<<
std
::
endl
;
}
};
template
<
typename
ALayout
,
template
<
typename
ALayout
,
typename
BLayout
,
typename
BLayout
,
...
@@ -104,23 +41,6 @@ struct DeviceGroupedGemmTileLoop : public DeviceGroupedGemm<ALayout,
...
@@ -104,23 +41,6 @@ struct DeviceGroupedGemmTileLoop : public DeviceGroupedGemm<ALayout,
BElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
>
CDEElementwiseOperation
>
{
{
//----------------------------------------------------------------------------------------------
/// @brief Sets the device kernel arguments pointer.
///
/// @param p_arg The pointer to the Argument we're going to update.
/// @param[in] p_dev_kernel_args The pointer to the device memory which contains kernel
/// arguments.
///
virtual
void
SetDeviceKernelArgs
(
BaseArgument
*
p_arg
,
void
*
p_dev_kernel_args
)
const
=
0
;
//----------------------------------------------------------------------------------------------
/// @brief Gets the device kernel argument size.
///
/// @param[in] p_arg The pointer to the Device op Argument.
///
/// @return The device kernel argument size.
///
virtual
size_t
GetDeviceKernelArgSize
(
const
BaseArgument
*
p_arg
)
const
=
0
;
};
};
}
// namespace device
}
// namespace device
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_streamk_v3.hpp
100644 → 100755
View file @
4525c5d7
...
@@ -131,6 +131,7 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
...
@@ -131,6 +131,7 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
{
{
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
{
if
(
stream_config
.
log_level_
>
0
)
if
(
stream_config
.
log_level_
>
0
)
{
{
arg
.
Print
();
arg
.
Print
();
...
@@ -147,26 +148,27 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
...
@@ -147,26 +148,27 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
index_t
K_split
=
(
arg
.
K
+
k_grain
-
1
)
/
k_grain
*
KPerBlock
;
index_t
K_split
=
(
arg
.
K
+
k_grain
-
1
)
/
k_grain
*
KPerBlock
;
const
bool
has_main_k_block_loop
=
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K_split
);
const
bool
has_main_k_block_loop
=
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K_split
);
hipGetErrorString
(
hipMemsetAsync
(
arg
.
p_c_grid
,
0
,
arg
.
M
*
arg
.
N
*
sizeof
(
CDataType
),
stream_config
.
stream_id_
));
if
constexpr
(
GridwiseGemm
::
Block2CTileMap_streamk
::
ReductionStrategy
==
StreamKReductionStrategy
::
Atomic
)
{
hip_check_error
(
hipMemsetAsync
(
arg
.
p_c_grid
,
0
,
arg
.
M
*
arg
.
N
*
sizeof
(
CDataType
),
stream_config
.
stream_id_
));
}
const
auto
Run
=
[
&
](
const
auto
&
kernel
)
{
const
auto
Run
=
[
&
](
const
auto
&
kernel
)
{
dim3
grid_dim
;
dim3
grid_dim
;
if
(
arg
.
Grid_size
<
0
)
if
(
arg
.
Grid_size
<
0
)
{
{
int
occupancy
,
num_cu
;
int
occupancy
,
num_cu
;
hipError_t
rtn
;
hip_check_error
(
hipOccupancyMaxActiveBlocksPerMultiprocessor
(
rtn
=
hipOccupancyMaxActiveBlocksPerMultiprocessor
(
&
occupancy
,
kernel
,
BlockSize
,
0
));
&
occupancy
,
kernel
,
BlockSize
,
0
);
hip_check_error
(
rtn
);
hipDeviceProp_t
dev_prop
;
hipDeviceProp_t
dev_prop
;
hipDevice_t
dev
;
hipDevice_t
dev
;
rtn
=
hipGetDevice
(
&
dev
);
hip_check_error
(
hipGetDevice
(
&
dev
));
hip_check_error
(
rtn
);
hip_check_error
(
hipGetDeviceProperties
(
&
dev_prop
,
dev
));
rtn
=
hipGetDeviceProperties
(
&
dev_prop
,
dev
);
num_cu
=
dev_prop
.
multiProcessorCount
;
hip_check_error
(
rtn
);
num_cu
=
dev_prop
.
multiProcessorCount
;
arg
.
Grid_size
=
num_cu
*
occupancy
;
arg
.
Grid_size
=
num_cu
*
occupancy
;
grid_dim
=
arg
.
Grid_size
;
grid_dim
=
arg
.
Grid_size
;
}
}
...
@@ -196,8 +198,31 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
...
@@ -196,8 +198,31 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
else
else
{
{
ave_time
=
launch_and_time_kernel
(
if
constexpr
(
GridwiseGemm
::
Block2CTileMap_streamk
::
ReductionStrategy
==
stream_config
,
kernel
,
grid_dim
,
dim3
(
BlockSize
),
0
,
arg
);
StreamKReductionStrategy
::
Atomic
)
{
ave_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
grid_dim
,
dim3
(
BlockSize
),
0
,
arg
);
}
else
if
constexpr
(
GridwiseGemm
::
Block2CTileMap_streamk
::
ReductionStrategy
==
StreamKReductionStrategy
::
Reduction
)
{
char
*
workspace_semaphore
=
reinterpret_cast
<
char
*>
(
arg
.
p_workspace_
)
+
arg
.
block_2_ctile_map_streamk
.
get_workspace_size_for_acc
(
sizeof
(
GemmAccDataType
));
auto
preprocess
=
[
&
]()
{
hipMemsetAsync
(
workspace_semaphore
,
0
,
// sizeof(uint32_t),
arg
.
block_2_ctile_map_streamk
.
get_workspace_size_for_semaphore
(),
stream_config
.
stream_id_
);
};
ave_time
=
launch_and_time_kernel_with_preprocess
(
stream_config
,
preprocess
,
kernel
,
grid_dim
,
dim3
(
BlockSize
),
0
,
arg
);
}
}
}
};
};
...
@@ -211,14 +236,12 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
...
@@ -211,14 +236,12 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v3
)
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v3
)
{
{
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
const
auto
kernel
=
true
,
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
InMemoryDataOperationEnum
::
Set
,
true
,
minimum_occupancy
>
;
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
>
;
Run
(
kernel
);
Run
(
kernel
);
}
}
}
// Tail number could be One to Seven
// Tail number could be One to Seven
else
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v2
)
else
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v2
)
...
@@ -340,53 +363,49 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
...
@@ -340,53 +363,49 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
else
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v4
)
else
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v4
)
{
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Odd
)
{
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Odd
)
const
auto
kernel
=
{
kernel_gemm_xdl_cshuffle_v3_2lds
<
GridwiseGemm
,
const
auto
kernel
=
true
,
kernel_gemm_xdl_cshuffle_v3_2lds
<
GridwiseGemm
,
InMemoryDataOperationEnum
::
Set
,
true
,
minimum_occupancy
,
InMemoryDataOperationEnum
::
Set
,
TailNumber
::
Odd
>
;
minimum_occupancy
,
Run
(
kernel
);
TailNumber
::
Odd
>
;
}
Run
(
kernel
);
else
}
{
else
const
auto
kernel
=
{
kernel_gemm_xdl_cshuffle_v3_2lds
<
GridwiseGemm
,
const
auto
kernel
=
true
,
kernel_gemm_xdl_cshuffle_v3_2lds
<
GridwiseGemm
,
InMemoryDataOperationEnum
::
Set
,
true
,
minimum_occupancy
,
InMemoryDataOperationEnum
::
Set
,
TailNumber
::
Even
>
;
minimum_occupancy
,
Run
(
kernel
);
TailNumber
::
Even
>
;
Run
(
kernel
);
}
}
}
}
}
else
else
{
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Odd
)
{
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Odd
)
const
auto
kernel
=
{
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
const
auto
kernel
=
true
,
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
InMemoryDataOperationEnum
::
Set
,
true
,
minimum_occupancy
,
InMemoryDataOperationEnum
::
Set
,
TailNumber
::
Odd
>
;
minimum_occupancy
,
Run
(
kernel
);
TailNumber
::
Odd
>
;
}
Run
(
kernel
);
else
}
{
else
const
auto
kernel
=
{
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
const
auto
kernel
=
true
,
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
InMemoryDataOperationEnum
::
Set
,
true
,
minimum_occupancy
,
InMemoryDataOperationEnum
::
Set
,
TailNumber
::
Even
>
;
minimum_occupancy
,
Run
(
kernel
);
TailNumber
::
Even
>
;
Run
(
kernel
);
}
}
}
}
}
}
}
...
@@ -396,14 +415,11 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
...
@@ -396,14 +415,11 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v1
)
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v1
)
{
{
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
const
auto
kernel
=
false
,
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
InMemoryDataOperationEnum
::
Set
,
false
,
minimum_occupancy
>
;
InMemoryDataOperationEnum
::
Set
,
Run
(
kernel
);
minimum_occupancy
>
;
Run
(
kernel
);
}
}
}
}
}
...
@@ -418,6 +434,29 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
...
@@ -418,6 +434,29 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
}
}
};
};
size_t
GetWorkSpaceSize
(
const
BaseArgument
*
pArg
)
const
override
{
const
Argument
*
p_arg
=
dynamic_cast
<
const
Argument
*>
(
pArg
);
if
constexpr
(
GridwiseGemm
::
Block2CTileMap_streamk
::
ReductionStrategy
==
StreamKReductionStrategy
::
Reduction
)
{
return
p_arg
->
block_2_ctile_map_streamk
.
get_workspace_size
(
sizeof
(
GemmAccDataType
));
}
else
{
return
0
;
}
}
void
SetWorkSpacePointer
(
BaseArgument
*
pArg
,
void
*
p_workspace
,
const
StreamConfig
&
=
StreamConfig
{})
const
override
{
Argument
*
pArg_
=
dynamic_cast
<
Argument
*>
(
pArg
);
pArg_
->
p_workspace_
=
p_workspace
;
}
static
constexpr
bool
IsValidCompilationParameter
()
static
constexpr
bool
IsValidCompilationParameter
()
{
{
// TODO: properly implement this check
// TODO: properly implement this check
...
@@ -464,8 +503,205 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
...
@@ -464,8 +503,205 @@ struct DeviceGemm_Xdl_CShuffle_Streamk_V3 : public DeviceGemm_Streamk_V2<ALayout
CElementwiseOperation
)
CElementwiseOperation
)
{
{
return
Argument
{
constexpr
index_t
minimum_occupancy
=
p_a
,
p_b
,
p_c
,
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
,
streamk_sel
,
Grid_size
};
// HS
BlkGemmPipeSched
==
BlockGemmPipelineScheduler
::
Intrawave
?
1
:
2
;
index_t
K_split
=
(
K
+
KPerBlock
-
1
)
/
KPerBlock
*
KPerBlock
;
const
bool
has_main_k_block_loop
=
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K_split
);
int
occupancy
,
num_cu
;
const
auto
calculate_grid_size
=
[
&
](
const
auto
&
kernel
)
{
hip_check_error
(
hipOccupancyMaxActiveBlocksPerMultiprocessor
(
&
occupancy
,
kernel
,
BlockSize
,
0
));
hipDeviceProp_t
dev_prop
;
hipDevice_t
dev
;
hip_check_error
(
hipGetDevice
(
&
dev
));
hip_check_error
(
hipGetDeviceProperties
(
&
dev_prop
,
dev
));
num_cu
=
dev_prop
.
multiProcessorCount
;
Grid_size
=
num_cu
*
occupancy
;
};
if
(
has_main_k_block_loop
)
{
// Tail number always full
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v1
||
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v3
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
>
;
calculate_grid_size
(
kernel
);
}
// Tail number could be One to Seven
else
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v2
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
One
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
One
>
;
calculate_grid_size
(
kernel
);
}
else
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Full
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Full
>
;
calculate_grid_size
(
kernel
);
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
2
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Two
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Two
>
;
calculate_grid_size
(
kernel
);
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
3
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Three
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Three
>
;
calculate_grid_size
(
kernel
);
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
4
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Four
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Four
>
;
calculate_grid_size
(
kernel
);
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
5
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Five
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Five
>
;
calculate_grid_size
(
kernel
);
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
6
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Six
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Six
>
;
calculate_grid_size
(
kernel
);
}
}
if
constexpr
(
GridwiseGemm
::
BlockwiseGemmPipe
::
PrefetchStages
>
7
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Seven
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Seven
>
;
calculate_grid_size
(
kernel
);
}
}
}
// Tail number could be Odd or Even
else
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v4
)
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Odd
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_2lds
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Odd
>
;
calculate_grid_size
(
kernel
);
}
else
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3_2lds
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Even
>
;
calculate_grid_size
(
kernel
);
}
}
else
{
if
(
GridwiseGemm
::
CalculateKBlockLoopTailNum
(
K_split
)
==
TailNumber
::
Odd
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Odd
>
;
calculate_grid_size
(
kernel
);
}
else
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
true
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
,
TailNumber
::
Even
>
;
calculate_grid_size
(
kernel
);
}
}
}
else
{
// Tail number always 1
if
constexpr
(
BlkGemmPipelineVer
==
BlockGemmPipelineVersion
::
v1
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v3
<
GridwiseGemm
,
false
,
InMemoryDataOperationEnum
::
Set
,
minimum_occupancy
>
;
calculate_grid_size
(
kernel
);
}
}
return
Argument
{
p_a
,
p_b
,
p_c
,
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
,
streamk_sel
,
Grid_size
};
}
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp
View file @
4525c5d7
...
@@ -18,7 +18,6 @@
...
@@ -18,7 +18,6 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_multiple_d_splitk.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
...
@@ -78,17 +77,17 @@ template <typename ALayout,
...
@@ -78,17 +77,17 @@ template <typename ALayout,
// TODO: change gridwise_gemm_v2r4r2 to support AK1 & BK1
// TODO: change gridwise_gemm_v2r4r2 to support AK1 & BK1
enable_if_t
<
AK1
==
BK1
,
bool
>
=
false
>
enable_if_t
<
AK1
==
BK1
,
bool
>
=
false
>
struct
DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
struct
DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
:
public
DeviceGroupedGemm
MultipleD
SplitK
<
ALayout
,
:
public
DeviceGroupedGemmSplitK
<
ALayout
,
BLayout
,
BLayout
,
DsLayout
,
DsLayout
,
ELayout
,
ELayout
,
ADataType
,
ADataType
,
BDataType
,
BDataType
,
DsDataType
,
DsDataType
,
EDataType
,
EDataType
,
AElementwiseOperation
,
AElementwiseOperation
,
BElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
>
CDEElementwiseOperation
>
{
{
using
DeviceOp
=
DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
;
using
DeviceOp
=
DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
;
...
@@ -530,7 +529,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
...
@@ -530,7 +529,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
index_t
skipped_group_count_
;
index_t
skipped_group_count_
;
index_t
grid_size_
;
index_t
grid_size_
;
// Pointer to device memory with GEMM kernel arguments.
// Pointer to device memory with GEMM kernel arguments.
const
void
*
p_dev_gemm_args_
;
void
*
p_dev_gemm_
k
args_
;
AElementwiseOperation
a_element_op_
;
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
BElementwiseOperation
b_element_op_
;
...
@@ -566,7 +565,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
...
@@ -566,7 +565,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
/// @return The average kernel execution time (if time measurement is enabled.)
/// @return The average kernel execution time (if time measurement is enabled.)
///
///
float
Run
(
const
Argument
&
arg
,
float
Run
(
const
Argument
&
arg
,
const
void
*
dev_gemm_args
,
void
*
dev_gemm_args
,
void
*
dev_gemm_workspace
,
void
*
dev_gemm_workspace
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
{
...
@@ -621,7 +620,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
...
@@ -621,7 +620,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
///
///
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
{
if
(
arg
.
p_dev_gemm_args_
==
nullptr
)
if
(
arg
.
p_dev_gemm_
k
args_
==
nullptr
)
{
{
std
::
ostringstream
err
;
std
::
ostringstream
err
;
err
<<
"The gemm arguments device buffer is not allocated!"
err
<<
"The gemm arguments device buffer is not allocated!"
...
@@ -637,7 +636,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
...
@@ -637,7 +636,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
throw
std
::
runtime_error
(
err
.
str
());
throw
std
::
runtime_error
(
err
.
str
());
}
}
return
Run
(
arg
,
arg
.
p_dev_gemm_args_
,
arg
.
p_workspace_
,
stream_config
);
return
Run
(
arg
,
arg
.
p_dev_gemm_
k
args_
,
arg
.
p_workspace_
,
stream_config
);
}
}
float
Run
(
const
BaseArgument
*
p_arg
,
float
Run
(
const
BaseArgument
*
p_arg
,
...
@@ -723,7 +722,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
...
@@ -723,7 +722,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
template
<
bool
HasMainKBlockLoop
>
template
<
bool
HasMainKBlockLoop
>
float
DispatchKernel
(
const
Argument
&
arg
,
float
DispatchKernel
(
const
Argument
&
arg
,
const
void
*
dev_gemm_args
,
void
*
dev_gemm_
k
args
,
void
*
dev_gemm_workspace
,
void
*
dev_gemm_workspace
,
const
StreamConfig
&
stream_config
)
const
const
StreamConfig
&
stream_config
)
const
{
{
...
@@ -746,7 +745,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
...
@@ -746,7 +745,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
return
LaunchKernel
(
gemm_kernel
,
return
LaunchKernel
(
gemm_kernel
,
elementwise_kernel
,
elementwise_kernel
,
arg
,
arg
,
dev_gemm_args
,
dev_gemm_
k
args
,
dev_gemm_workspace
,
dev_gemm_workspace
,
stream_config
);
stream_config
);
}
}
...
@@ -755,12 +754,19 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
...
@@ -755,12 +754,19 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
float
LaunchKernel
(
const
KernelFunction
&
gemm_kernel
,
float
LaunchKernel
(
const
KernelFunction
&
gemm_kernel
,
const
KernelFunction2
&
elementwise_kernel
,
const
KernelFunction2
&
elementwise_kernel
,
const
Argument
&
arg
,
const
Argument
&
arg
,
const
void
*
dev_gemm_args
,
void
*
dev_gemm_
k
args
,
[[
maybe_unused
]]
void
*
dev_gemm_workspace
,
[[
maybe_unused
]]
void
*
dev_gemm_workspace
,
const
StreamConfig
&
stream_config
)
const
const
StreamConfig
&
stream_config
)
const
{
{
float
time
{
0.
f
};
float
time
{
0.
f
};
hip_check_error
(
hipMemcpyWithStream
(
dev_gemm_kargs
,
arg
.
gemm_kernel_args_
.
data
(),
arg
.
gemm_kernel_args_
.
size
()
*
sizeof
(
GemmTransKernelArg
),
hipMemcpyHostToDevice
,
stream_config
.
stream_id_
));
auto
preprocess
=
[
&
]()
{
auto
preprocess
=
[
&
]()
{
hip_check_error
(
hipMemsetAsync
(
hip_check_error
(
hipMemsetAsync
(
dev_gemm_workspace
,
0
,
arg
.
GetWorkspaceSizeBytes
(),
stream_config
.
stream_id_
));
dev_gemm_workspace
,
0
,
arg
.
GetWorkspaceSizeBytes
(),
stream_config
.
stream_id_
));
...
@@ -774,7 +780,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
...
@@ -774,7 +780,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
dim3
(
arg
.
grid_size_
),
dim3
(
arg
.
grid_size_
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
cast_pointer_to_constant_address_space
(
dev_gemm_args
),
cast_pointer_to_constant_address_space
(
dev_gemm_
k
args
),
arg
.
gemm_kernel_args_
.
size
(),
arg
.
gemm_kernel_args_
.
size
(),
arg
.
a_element_op_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
b_element_op_
,
...
@@ -930,18 +936,30 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
...
@@ -930,18 +936,30 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
return
str
.
str
();
return
str
.
str
();
}
}
void
SetDeviceKernelArgs
(
Argument
&
arg
,
void
*
p_dev_kernel_args
)
const
void
SetDeviceKernelArgs
(
Base
Argument
*
p_
arg
,
void
*
p_dev_kernel_args
)
const
override
{
{
arg
.
p_dev_gemm_args_
=
p_dev_kernel_args
;
auto
arg_ptr
=
dynamic_cast
<
Argument
*>
(
p_arg
);
hip_check_error
(
hipMemcpy
(
p_dev_kernel_args
,
if
(
arg_ptr
)
arg
.
gemm_kernel_args_
.
data
(),
{
GetDeviceKernelArgSize
(
&
arg
),
arg_ptr
->
p_dev_gemm_kargs_
=
p_dev_kernel_args
;
hipMemcpyHostToDevice
));
}
else
throw
std
::
runtime_error
(
"The argument pointer is not an object of "
"DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage::Argument structure!"
);
}
}
void
S
etDeviceKernelArg
s
(
BaseArgument
*
p_arg
,
void
*
p_dev_kernel_args
)
const
override
size_t
G
etDeviceKernelArg
Size
(
const
BaseArgument
*
p_arg
)
const
override
{
{
return
SetDeviceKernelArgs
(
*
dynamic_cast
<
Argument
*>
(
p_arg
),
p_dev_kernel_args
);
auto
arg
=
dynamic_cast
<
const
Argument
*>
(
p_arg
);
if
(
arg
)
{
return
arg
->
gemm_kernel_args_
.
size
()
*
sizeof
(
GemmTransKernelArg
);
}
else
throw
std
::
runtime_error
(
"The argument pointer is not an object of "
"DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage::Argument structure!"
);
}
}
size_t
GetWorkSpaceSize
(
const
BaseArgument
*
p_arg
)
const
override
size_t
GetWorkSpaceSize
(
const
BaseArgument
*
p_arg
)
const
override
...
@@ -974,17 +992,22 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
...
@@ -974,17 +992,22 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
"DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage::Argument structure!"
);
"DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage::Argument structure!"
);
}
}
static
void
SetKBatchSize
(
Argument
&
arg
,
index_t
kbatch
)
{
arg
.
UpdateKBatch
(
kbatch
);
}
[[
deprecated
]]
static
void
SetKBatchSize
(
Argument
&
arg
,
index_t
kbatch
)
void
SetKBatchSize
(
BaseArgument
*
p_arg
,
index_t
kbatch
)
const
override
{
{
return
SetKBatchSize
(
*
dynamic_cast
<
Argument
*>
(
p_arg
),
kbatch
);
arg
.
UpdateKBatch
(
kbatch
);
}
}
size_t
GetDeviceKernelArgSize
(
const
BaseArgument
*
p_arg
)
const
override
void
SetKBatchSize
(
BaseArgument
*
p_arg
,
index_t
kbatch
)
const
override
{
{
return
dynamic_cast
<
const
Argument
*>
(
p_arg
)
->
gemm_kernel_args_
.
size
()
*
auto
p_arg_
=
dynamic_cast
<
Argument
*>
(
p_arg
);
sizeof
(
GemmTransKernelArg
);
if
(
p_arg_
)
{
p_arg_
->
UpdateKBatch
(
kbatch
);
}
else
throw
std
::
runtime_error
(
"The argument pointer is not an object of "
"DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage::Argument structure!"
);
}
}
};
};
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp
View file @
4525c5d7
...
@@ -20,7 +20,6 @@
...
@@ -20,7 +20,6 @@
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp" // stare wywalic
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp" // stare wywalic
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
namespace
ck
{
namespace
ck
{
...
@@ -522,7 +521,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
...
@@ -522,7 +521,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
ComputeTypeA
,
ComputeTypeA
,
ComputeTypeB
>
;
ComputeTypeB
>
;
using
KernelArguments
=
GroupedGemm
TileLoop
KernelArgument
s
<
NumDTensor
>
;
using
KernelArguments
=
GroupedGemmKernelArgument
<
NumDTensor
>
;
using
Block2ETileMap
=
BlockToCTileMap_Grouped_M00_N0_M01Adapt
<
8
,
MPerBlock
,
NPerBlock
>
;
using
Block2ETileMap
=
BlockToCTileMap_Grouped_M00_N0_M01Adapt
<
8
,
MPerBlock
,
NPerBlock
>
;
using
OffsettedLocalBlock2ETileMap
=
OffsettedBlockToCTileMap2
<
Block2ETileMap
>
;
using
OffsettedLocalBlock2ETileMap
=
OffsettedBlockToCTileMap2
<
Block2ETileMap
>
;
...
@@ -936,12 +935,31 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
...
@@ -936,12 +935,31 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
return
str
.
str
();
return
str
.
str
();
}
}
void
SetDeviceKernelArgs
(
Argument
&
arg
,
void
*
p_dev_kernel_args
,
const
void
*
p_host_kernel_args
)
const
{
arg
.
p_dev_gemm_args_
=
p_dev_kernel_args
;
hip_check_error
(
hipMemcpy
(
p_dev_kernel_args
,
p_host_kernel_args
,
GetDeviceKernelArgSize
(
&
arg
),
hipMemcpyHostToDevice
));
}
virtual
void
SetDeviceKernelArgs
(
BaseArgument
*
p_arg
,
void
*
p_dev_kernel_args
,
const
void
*
p_host_kernel_args
)
const
override
{
return
SetDeviceKernelArgs
(
*
dynamic_cast
<
Argument
*>
(
p_arg
),
p_dev_kernel_args
,
p_host_kernel_args
);
}
void
SetDeviceKernelArgs
(
Argument
&
arg
,
void
*
p_dev_kernel_args
)
const
void
SetDeviceKernelArgs
(
Argument
&
arg
,
void
*
p_dev_kernel_args
)
const
{
{
arg
.
p_dev_gemm_args_
=
p_dev_kernel_args
;
arg
.
p_dev_gemm_args_
=
p_dev_kernel_args
;
}
}
void
SetDeviceKernelArgs
(
BaseArgument
*
p_arg
,
void
*
p_dev_kernel_args
)
const
override
virtual
void
SetDeviceKernelArgs
(
BaseArgument
*
p_arg
,
void
*
p_dev_kernel_args
)
const
override
{
{
return
SetDeviceKernelArgs
(
*
dynamic_cast
<
Argument
*>
(
p_arg
),
p_dev_kernel_args
);
return
SetDeviceKernelArgs
(
*
dynamic_cast
<
Argument
*>
(
p_arg
),
p_dev_kernel_args
);
}
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp
View file @
4525c5d7
#pragma once
#pragma once
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -717,7 +717,24 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
...
@@ -717,7 +717,24 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
size_t
GetWorkSpaceSize
(
const
BaseArgument
*
p_arg
)
const
override
size_t
GetWorkSpaceSize
(
const
BaseArgument
*
p_arg
)
const
override
{
{
return
dynamic_cast
<
const
Argument
*>
(
p_arg
)
->
group_count_
*
sizeof
(
GemmBiasTransKernelArg
);
auto
p_arg_
=
dynamic_cast
<
const
Argument
*>
(
p_arg
);
if
(
p_arg_
)
{
return
p_arg_
->
group_count_
*
sizeof
(
GemmBiasTransKernelArg
);
}
else
throw
std
::
runtime_error
(
"The argument pointer is not an object of "
"DeviceGroupedGemmMultipleDXdlCShuffle::Argument structure!"
);
}
size_t
GetDeviceKernelArgSize
(
const
BaseArgument
*
p_arg
)
const
override
{
return
GetWorkSpaceSize
(
p_arg
);
}
void
SetDeviceKernelArgs
(
BaseArgument
*
p_arg
,
void
*
p_dev_kernel_args
)
const
override
{
return
this
->
SetWorkSpacePointer
(
p_arg
,
p_dev_kernel_args
);
}
}
};
};
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp
View file @
4525c5d7
...
@@ -445,6 +445,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
...
@@ -445,6 +445,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
using
Block2ETileMap
=
BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops
<
MPerBlock
,
NPerBlock
>
;
using
Block2ETileMap
=
BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops
<
MPerBlock
,
NPerBlock
>
;
using
GroupedGemmBlock2ETileMap
=
OffsettedBlockToCTileMapMLoops
<
Block2ETileMap
>
;
using
GroupedGemmBlock2ETileMap
=
OffsettedBlockToCTileMapMLoops
<
Block2ETileMap
>
;
// TODO: replace with GroupedGemmKernelArgument
struct
GemmBiasTransKernelArg
struct
GemmBiasTransKernelArg
{
{
// pointers
// pointers
...
@@ -900,40 +901,58 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
...
@@ -900,40 +901,58 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
return
str
.
str
();
return
str
.
str
();
}
}
static
void
SetDeviceKernelArgs
(
Argument
&
arg
,
const
void
*
kernel_args
)
{
arg
.
grouped_gemm_kernel_args_dev
=
kernel_args
;
}
// polymorphic
// polymorphic
void
SetDeviceKernelArgs
(
BaseArgument
*
p_arg
,
const
void
*
kernel_args
)
const
override
void
SetDeviceKernelArgs
(
BaseArgument
*
p_arg
,
void
*
kernel_args
)
const
override
{
{
return
SetDeviceKernelArgs
(
*
dynamic_cast
<
Argument
*>
(
p_arg
),
kernel_args
);
auto
arg_ptr
=
dynamic_cast
<
Argument
*>
(
p_arg
);
if
(
arg_ptr
)
{
arg_ptr
->
grouped_gemm_kernel_args_dev
=
kernel_args
;
}
else
throw
std
::
runtime_error
(
"The argument pointer is not an object of "
"DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!"
);
}
}
size_t
GetWorkSpaceSize
(
const
BaseArgument
*
p_arg
)
const
override
size_t
GetWorkSpaceSize
(
const
BaseArgument
*
p_arg
)
const
override
{
{
auto
arg
=
*
dynamic_cast
<
const
Argument
*>
(
p_arg
);
auto
arg_ptr
=
dynamic_cast
<
const
Argument
*>
(
p_arg
);
if
(
arg_ptr
)
return
arg
.
group_count_
*
arg
.
barrier_size_grp_
*
sizeof
(
uint32_t
);
{
return
arg_ptr
->
group_count_
*
arg_ptr
->
barrier_size_grp_
*
sizeof
(
uint32_t
);
}
else
throw
std
::
runtime_error
(
"The argument pointer is not an object of "
"DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!"
);
}
}
size_t
GetDeviceKernelArgSize
(
const
BaseArgument
*
p_arg
)
const
override
size_t
GetDeviceKernelArgSize
(
const
BaseArgument
*
p_arg
)
const
override
{
{
auto
arg
=
*
dynamic_cast
<
const
Argument
*>
(
p_arg
);
auto
arg_ptr
=
dynamic_cast
<
const
Argument
*>
(
p_arg
);
if
(
arg_ptr
)
return
arg
.
group_count_
*
sizeof
(
GroupedGemmKernelArgument
<
NumDTensor
>
);
{
return
arg_ptr
->
group_count_
*
sizeof
(
GroupedGemmKernelArgument
<
NumDTensor
>
);
}
else
throw
std
::
runtime_error
(
"The argument pointer is not an object of "
"DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!"
);
}
}
void
SetWorkSpacePointer
(
BaseArgument
*
p_arg
,
void
SetWorkSpacePointer
(
BaseArgument
*
p_arg
,
void
*
p_workspace
,
void
*
p_workspace
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
const
override
const
StreamConfig
&
stream_config
=
StreamConfig
{})
const
override
{
{
auto
p_arg_
=
dynamic_cast
<
Argument
*>
(
p_arg
);
auto
arg_ptr
=
dynamic_cast
<
Argument
*>
(
p_arg
);
p_arg_
->
p_workspace_
=
p_workspace
;
if
(
arg_ptr
)
{
arg_ptr
->
p_workspace_
=
p_workspace
;
}
else
throw
std
::
runtime_error
(
"The argument pointer is not an object of "
"DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!"
);
hip_check_error
(
hip_check_error
(
hipMemsetAsync
(
p_workspace
,
0
,
GetWorkSpaceSize
(
p_
arg
),
stream_config
.
stream_id_
));
hipMemsetAsync
(
p_workspace
,
0
,
GetWorkSpaceSize
(
arg
_ptr
),
stream_config
.
stream_id_
));
}
}
static
void
SetKBatch
(
Argument
&
arg
,
index_t
k_batch
)
{
arg
.
UpdateKBatch
(
k_batch
);
}
static
void
SetKBatch
(
Argument
&
arg
,
index_t
k_batch
)
{
arg
.
UpdateKBatch
(
k_batch
);
}
...
@@ -941,7 +960,26 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
...
@@ -941,7 +960,26 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
// polymorphic
// polymorphic
void
SetKBatch
(
BaseArgument
*
p_arg
,
index_t
k_batch
)
const
override
void
SetKBatch
(
BaseArgument
*
p_arg
,
index_t
k_batch
)
const
override
{
{
return
SetKBatch
(
*
dynamic_cast
<
Argument
*>
(
p_arg
),
k_batch
);
auto
arg_ptr
=
dynamic_cast
<
Argument
*>
(
p_arg
);
if
(
arg_ptr
)
{
arg_ptr
->
UpdateKBatch
(
k_batch
);
}
else
throw
std
::
runtime_error
(
"The argument pointer is not an object of "
"DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!"
);
}
void
SetKBatchSize
(
BaseArgument
*
p_arg
,
index_t
kbatch
)
const
override
{
auto
arg_ptr
=
dynamic_cast
<
Argument
*>
(
p_arg
);
if
(
arg_ptr
)
{
arg_ptr
->
UpdateKBatch
(
kbatch
);
}
else
throw
std
::
runtime_error
(
"The argument pointer is not an object of "
"DeviceGroupedGemm_Xdl_Fixed_NK::Argument structure!"
);
}
}
};
};
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp
View file @
4525c5d7
...
@@ -538,10 +538,16 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
...
@@ -538,10 +538,16 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
return
false
;
return
false
;
}
}
if
(
std
::
is_same_v
<
EDataType
,
ck
::
bhalf_t
>
&&
arg
.
K_BATCH
>
1
&&
!
is_bf16_atomic_supported
())
{
return
false
;
}
bool
supported
=
true
;
bool
supported
=
true
;
for
(
std
::
size_t
i
=
0
;
i
<
arg
.
gemm_kernel_args_
.
size
();
++
i
)
for
(
std
::
size_t
i
=
0
;
i
<
arg
.
gemm_kernel_args_
.
size
();
++
i
)
{
{
const
auto
&
a
=
arg
.
gemm_kernel_args_
[
i
].
karg_
;
const
auto
&
a
=
arg
.
gemm_kernel_args_
[
i
].
karg_
;
bool
group_arg_valid
=
GridwiseGemm
::
CheckValidity
(
a
);
bool
group_arg_valid
=
GridwiseGemm
::
CheckValidity
(
a
);
if
(
not
group_arg_valid
)
if
(
not
group_arg_valid
)
{
{
...
@@ -631,16 +637,42 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
...
@@ -631,16 +637,42 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
size_t
GetWorkSpaceSize
(
const
BaseArgument
*
p_arg
)
const
override
size_t
GetWorkSpaceSize
(
const
BaseArgument
*
p_arg
)
const
override
{
{
return
dynamic_cast
<
const
Argument
*>
(
p_arg
)
->
gemm_kernel_args_
.
size
()
*
auto
p_arg_
=
dynamic_cast
<
const
Argument
*>
(
p_arg
);
sizeof
(
GemmTransKernelArg
);
if
(
p_arg_
)
{
return
p_arg_
->
gemm_kernel_args_
.
size
()
*
sizeof
(
GemmTransKernelArg
);
}
else
throw
std
::
runtime_error
(
"The argument pointer is not an object of "
"DeviceGroupedGemmMultipleDSplitKXdlCShuffle::Argument structure!"
);
}
}
size_t
GetDeviceKernelArgSize
(
const
BaseArgument
*
p_arg
)
const
override
{
return
GetWorkSpaceSize
(
p_arg
);
}
// TODO: deperecation notice.
static
void
SetKBatchSize
(
Argument
&
arg
,
index_t
kbatch
)
{
arg
.
UpdateKBatch
(
kbatch
);
}
static
void
SetKBatchSize
(
Argument
&
arg
,
index_t
kbatch
)
{
arg
.
UpdateKBatch
(
kbatch
);
}
// polymorphic
// polymorphic
void
SetKBatchSize
(
BaseArgument
*
p_arg
,
index_t
kbatch
)
const
override
void
SetKBatchSize
(
BaseArgument
*
p_arg
,
index_t
kbatch
)
const
override
{
{
return
SetKBatchSize
(
*
dynamic_cast
<
Argument
*>
(
p_arg
),
kbatch
);
auto
p_arg_
=
dynamic_cast
<
Argument
*>
(
p_arg
);
if
(
p_arg_
)
{
p_arg_
->
UpdateKBatch
(
kbatch
);
}
else
throw
std
::
runtime_error
(
"The argument pointer is not an object of "
"DeviceGroupedGemmMultipleDSplitKXdlCShuffle::Argument structure!"
);
}
void
SetDeviceKernelArgs
(
BaseArgument
*
p_arg
,
void
*
p_dev_kernel_args
)
const
override
{
return
this
->
SetWorkSpacePointer
(
p_arg
,
p_dev_kernel_args
);
}
}
};
};
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_streamk_v3.hpp
100644 → 100755
View file @
4525c5d7
...
@@ -14,6 +14,8 @@
...
@@ -14,6 +14,8 @@
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1r2.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1r2.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/workgroup_barrier.hpp"
#include "ck/utility/reduction_functions_accumulate.hpp"
namespace
ck
{
namespace
ck
{
...
@@ -38,7 +40,7 @@ __global__ void
...
@@ -38,7 +40,7 @@ __global__ void
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
CGlobalMemoryDataOperation
,
TailNum
>(
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
CGlobalMemoryDataOperation
,
TailNum
>(
karg
.
p_a_grid
,
karg
.
p_b_grid
,
karg
.
p_c_grid
,
p_shared
,
karg
);
karg
.
p_a_grid
,
karg
.
p_b_grid
,
karg
.
p_c_grid
,
p_shared
,
karg
,
karg
.
p_workspace_
);
#else
#else
ignore
=
karg
;
ignore
=
karg
;
#endif // end of if (defined(__gfx9__))
#endif // end of if (defined(__gfx9__))
...
@@ -62,7 +64,13 @@ __global__ void
...
@@ -62,7 +64,13 @@ __global__ void
__shared__
char
p_shared_1
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
__shared__
char
p_shared_1
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
GridwiseGemm
::
template
Run_2Lds
<
HasMainKBlockLoop
,
CGlobalMemoryDataOperation
,
TailNum
>(
GridwiseGemm
::
template
Run_2Lds
<
HasMainKBlockLoop
,
CGlobalMemoryDataOperation
,
TailNum
>(
karg
.
p_a_grid
,
karg
.
p_b_grid
,
karg
.
p_c_grid
,
p_shared_0
,
p_shared_1
,
karg
);
karg
.
p_a_grid
,
karg
.
p_b_grid
,
karg
.
p_c_grid
,
p_shared_0
,
p_shared_1
,
karg
,
karg
.
p_workspace_
);
#else
#else
ignore
=
karg
;
ignore
=
karg
;
#endif // end of if (defined(__gfx9__))
#endif // end of if (defined(__gfx9__))
...
@@ -521,7 +529,9 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
...
@@ -521,7 +529,9 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
:
Problem
{
M_
,
N_
,
K_
,
StrideA_
,
StrideB_
,
StrideC_
,
Streamk_sel_
,
Grid_size_
},
:
Problem
{
M_
,
N_
,
K_
,
StrideA_
,
StrideB_
,
StrideC_
,
Streamk_sel_
,
Grid_size_
},
p_a_grid
{
p_a_grid_
},
p_a_grid
{
p_a_grid_
},
p_b_grid
{
p_b_grid_
},
p_b_grid
{
p_b_grid_
},
p_c_grid
{
p_c_grid_
}
p_c_grid
{
p_c_grid_
},
block_2_ctile_map_streamk
(
M_
,
N_
,
AK0Number
*
CalculateKPadded
(
K_
,
1
),
Grid_size_
,
Streamk_sel_
)
{
{
}
}
...
@@ -529,6 +539,13 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
...
@@ -529,6 +539,13 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
const
ADataType
*
p_a_grid
;
const
ADataType
*
p_a_grid
;
const
BDataType
*
p_b_grid
;
const
BDataType
*
p_b_grid
;
CDataType
*
p_c_grid
;
CDataType
*
p_c_grid
;
BlockToCTileMap_GemmStreamK_v2
<
MPerBlock
,
NPerBlock
,
KPerBlock
,
StreamKReductionStrategy
::
Atomic
,
8
,
4
>
block_2_ctile_map_streamk
;
};
};
struct
SplitKBatchOffset
struct
SplitKBatchOffset
...
@@ -853,6 +870,19 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
...
@@ -853,6 +870,19 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
return
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
;
return
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
;
}
}
__host__
__device__
static
constexpr
auto
GetCBlockDescriptor_MShuffle_MPerShuffle_NShuffle_NPerShuffle
()
{
constexpr
index_t
MWave
=
MPerBlock
/
(
MXdlPerWave
*
MPerXdl
);
constexpr
index_t
NWave
=
NPerBlock
/
(
NXdlPerWave
*
NPerXdl
);
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MXdlPerWave
/
CShuffleMXdlPerWavePerShuffle
>
{},
Number
<
CShuffleMXdlPerWavePerShuffle
*
MWave
*
MPerXdl
>
{},
Number
<
NXdlPerWave
/
CShuffleNXdlPerWavePerShuffle
>
{},
Number
<
CShuffleNXdlPerWavePerShuffle
*
NWave
*
NPerXdl
>
{}));
}
using
BlockwiseGemmPipe
=
using
BlockwiseGemmPipe
=
remove_cvref_t
<
decltype
(
BlockGemmPipeline_Selector
<
remove_cvref_t
<
decltype
(
BlockGemmPipeline_Selector
<
BlkGemmPipelineVer
,
BlkGemmPipelineVer
,
...
@@ -1118,6 +1148,34 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
...
@@ -1118,6 +1148,34 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
return
c_grid_desc_mblock_mperblock_nblock_nperblock
;
return
c_grid_desc_mblock_mperblock_nblock_nperblock
;
}
}
__host__
__device__
static
constexpr
auto
GetClusterLengthReduction
()
{
// TODO: assume C is row major
// TODO: we always first loop over N, then M
constexpr
auto
NPerBlockPow2
=
math
::
next_power_of_two
<
NPerBlock
>
();
constexpr
auto
NPerBlockReduction
=
NPerBlockPow2
/
CShuffleBlockTransferScalarPerVector_NPerBlock
;
constexpr
auto
MPerBlockReduction
=
(
BlockSize
+
NPerBlockReduction
-
1
)
/
NPerBlockReduction
;
return
Sequence
<
MPerBlockReduction
,
NPerBlockReduction
>
{};
}
__host__
__device__
static
constexpr
auto
GetPartialAccBlockDescriptor
()
{
const
auto
c_partial_acc_block_m_n
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
CLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
MPerBlock
,
NPerBlock
),
make_tuple
(
NPerBlock
,
I1
));
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
CLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
MPerBlock
,
NPerBlock
),
make_tuple
(
I1
,
MPerBlock
));
}
}();
return
c_partial_acc_block_m_n
;
}
using
Block2CTileMap_streamk
=
BlockToCTileMap_GemmStreamK_v2
<
MPerBlock
,
using
Block2CTileMap_streamk
=
BlockToCTileMap_GemmStreamK_v2
<
MPerBlock
,
NPerBlock
,
NPerBlock
,
KPerBlock
,
KPerBlock
,
...
@@ -1132,22 +1190,42 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
...
@@ -1132,22 +1190,42 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
const
BDataType
*
p_b_grid
,
const
BDataType
*
p_b_grid
,
CDataType
*
p_c_grid
,
CDataType
*
p_c_grid
,
void
*
p_shared
,
void
*
p_shared
,
Problem
&
problem
)
Problem
&
problem
,
void
*
p_workspace
)
{
{
const
AElementwiseOperation
a_element_op
{};
const
AElementwiseOperation
a_element_op
{};
const
BElementwiseOperation
b_element_op
{};
const
BElementwiseOperation
b_element_op
{};
const
CElementwiseOperation
c_element_op
{};
const
CElementwiseOperation
c_element_op
{};
const
auto
a_grid_desc_ak0_m_ak1
=
MakeAGridDescriptor_AK0_M_AK1
(
problem
.
M
,
problem
.
MPadded
,
problem
.
K
,
problem
.
KPadded
,
problem
.
StrideA
,
problem
.
AK0
);
const
auto
b_grid_desc_bk0_n_bk1
=
MakeBGridDescriptor_BK0_N_BK1
(
problem
.
K
,
problem
.
KPadded
,
problem
.
N
,
problem
.
NPadded
,
problem
.
StrideB
,
problem
.
BK0
);
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_a_grid
,
a_grid_desc_ak0_m_ak1
.
GetElementSpaceSize
());
const
auto
b_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_b_grid
,
b_grid_desc_bk0_n_bk1
.
GetElementSpaceSize
());
Block2CTileMap_streamk
block_2_ctile_map_streamk
(
problem
.
M
,
Block2CTileMap_streamk
block_2_ctile_map_streamk
(
problem
.
M
,
problem
.
N
,
problem
.
N
,
AK0Number
*
problem
.
KPadded
,
AK0Number
*
problem
.
KPadded
,
problem
.
Grid_size
,
problem
.
Grid_size
,
problem
.
Streamk_sel
);
problem
.
Streamk_sel
);
uint32_t
iter_start
,
iter_end
;
uint32_t
iter_start
,
iter_end
;
bool
is_sk_block
,
is_dp_block
;
bool
is_sk_block
,
is_dp_block
,
is_reduction_block
;
index_t
num_k_block_main_loop
;
index_t
num_k_block_main_loop
;
const
auto
c_grid_desc_m_n
=
MakeCGridDescriptor_M_N
(
problem
.
M
,
problem
.
MPadded
,
problem
.
N
,
problem
.
NPadded
,
problem
.
StrideC
);
const
auto
c_grid_desc_mblock_mperblock_nblock_nperblock
=
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n
,
problem
.
MBlock
,
problem
.
NBlock
);
auto
c_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_c_grid
,
c_grid_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
uint32_t
*
p_semaphore
=
reinterpret_cast
<
uint32_t
*>
(
reinterpret_cast
<
char
*>
(
p_workspace
)
+
block_2_ctile_map_streamk
.
get_workspace_size_for_acc
(
sizeof
(
AccDataType
)));
for
(
auto
block_idx
=
get_block_1d_id
();
for
(
auto
block_idx
=
get_block_1d_id
();
block_idx
<
block_2_ctile_map_streamk
.
get_grid_dims
();
block_idx
<
block_2_ctile_map_streamk
.
get_grid_dims
();
block_idx
+=
gridDim
.
x
)
block_idx
+=
gridDim
.
x
)
...
@@ -1163,6 +1241,214 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
...
@@ -1163,6 +1241,214 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
block_2_ctile_map_streamk
.
get_block_itr
(
block_idx
,
iter_start
,
iter_end
);
block_2_ctile_map_streamk
.
get_block_itr
(
block_idx
,
iter_start
,
iter_end
);
num_k_block_main_loop
=
iter_end
-
iter_start
;
num_k_block_main_loop
=
iter_end
-
iter_start
;
if
constexpr
(
Block2CTileMap_streamk
::
ReductionStrategy
==
StreamKReductionStrategy
::
Reduction
)
{
is_reduction_block
=
static_cast
<
uint32_t
>
(
block_idx
)
>=
block_2_ctile_map_streamk
.
reduction_start_block_idx
;
if
(
is_reduction_block
)
{
// descriptors
constexpr
auto
cluster_length_reduce
=
GetClusterLengthReduction
();
constexpr
auto
reduce_desc
=
make_cluster_descriptor
(
cluster_length_reduce
);
const
auto
reduce_thread_cluster_idx
=
reduce_desc
.
CalculateBottomIndex
(
make_multi_index
(
block_idx
));
const
auto
thread_m_cluster_id
=
reduce_thread_cluster_idx
[
I0
];
const
auto
thread_n_cluster_id
=
reduce_thread_cluster_idx
[
I1
];
constexpr
auto
MReduceIters
=
math
::
integer_divide_ceil
(
Number
<
MPerBlock
>
{},
cluster_length_reduce
.
At
(
I0
));
constexpr
auto
NReduceIters
=
math
::
integer_divide_ceil
(
Number
<
NPerBlock
>
{},
cluster_length_reduce
.
At
(
I1
)
*
Number
<
CShuffleBlockTransferScalarPerVector_NPerBlock
>
{});
constexpr
auto
acc_thread_buf_load_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
CShuffleBlockTransferScalarPerVector_NPerBlock
>
{}));
constexpr
auto
acc_thread_buf_store_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
I1
,
I1
,
Number
<
CShuffleBlockTransferScalarPerVector_NPerBlock
>
{}));
constexpr
auto
c_partial_acc_block_m_n
=
GetPartialAccBlockDescriptor
();
constexpr
auto
partial_acc_load_step_n
=
make_multi_index
(
0
,
cluster_length_reduce
.
At
(
I1
)
*
CShuffleBlockTransferScalarPerVector_NPerBlock
);
constexpr
auto
partial_acc_load_step_n_reverse
=
make_multi_index
(
0
,
-
1
*
cluster_length_reduce
.
At
(
I1
).
value
*
(
NReduceIters
-
1
)
*
CShuffleBlockTransferScalarPerVector_NPerBlock
);
constexpr
auto
partial_acc_load_step_m
=
make_multi_index
(
cluster_length_reduce
.
At
(
I0
),
0
);
constexpr
auto
partial_acc_store_step_n
=
make_multi_index
(
0
,
0
,
0
,
cluster_length_reduce
.
At
(
I1
)
*
CShuffleBlockTransferScalarPerVector_NPerBlock
);
constexpr
auto
partial_acc_store_step_n_reverse
=
make_multi_index
(
0
,
0
,
0
,
-
1
*
cluster_length_reduce
.
At
(
I1
).
value
*
(
NReduceIters
-
1
)
*
CShuffleBlockTransferScalarPerVector_NPerBlock
);
constexpr
auto
partial_acc_store_step_m
=
make_multi_index
(
0
,
cluster_length_reduce
.
At
(
I0
),
0
,
0
);
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
true
>
parcial_acc_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
true
>
acc_buf
;
// start to compute
auto
reduction_idx
=
block_idx
-
block_2_ctile_map_streamk
.
reduction_start_block_idx
;
auto
spatial_idx
=
block_2_ctile_map_streamk
.
tile_to_spatial
(
reduction_idx
,
problem
.
M
,
problem
.
N
);
workgroup_barrier
wg_barrier
(
p_semaphore
);
uint32_t
tile_acc_offset_start
=
block_2_ctile_map_streamk
.
get_acc_buffer_offset_from_tile
(
reduction_idx
);
uint32_t
tile_acc_offset_end
=
block_2_ctile_map_streamk
.
get_acc_buffer_offset_from_tile
(
reduction_idx
+
1
);
__syncthreads
();
auto
acc_load
=
ThreadwiseTensorSliceTransfer_v2
<
AccDataType
,
// SrcData,
AccDataType
,
// DstData,
decltype
(
c_partial_acc_block_m_n
),
// SrcDesc,
decltype
(
acc_thread_buf_load_desc
),
// DstDesc,
Sequence
<
1
,
CShuffleBlockTransferScalarPerVector_NPerBlock
>
,
// SliceLengths,
Sequence
<
0
,
1
>
,
// DimAccessOrder,
1
,
// SrcVectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
// SrcScalarPerVector,
1
,
// SrcScalarStrideInVector,
false
// SrcResetCoordinateAfterRun,
>
{
c_partial_acc_block_m_n
,
make_multi_index
(
thread_m_cluster_id
,
thread_n_cluster_id
*
CShuffleBlockTransferScalarPerVector_NPerBlock
)};
auto
acc_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
AccDataType
,
// SrcData,
CDataType
,
// DstData,
decltype
(
acc_thread_buf_store_desc
),
// SrcDesc,
decltype
(
c_grid_desc_mblock_mperblock_nblock_nperblock
),
// DstDesc,
CElementwiseOperation
,
// ElementwiseOperation,
Sequence
<
1
,
1
,
1
,
CShuffleBlockTransferScalarPerVector_NPerBlock
>
,
// SliceLengths,
Sequence
<
0
,
1
,
2
,
3
>
,
// DimAccessOrder,
3
,
// DstVectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
// DstScalarPerVector,
InMemoryDataOperationEnum
::
Set
,
// InMemoryDataOperationEnum DstInMemOp,
1
,
// DstScalarStrideInVector,
false
// DstResetCoordinateAfterRun,
>
{
c_grid_desc_mblock_mperblock_nblock_nperblock
,
make_multi_index
(
__builtin_amdgcn_readfirstlane
(
spatial_idx
[
I0
]),
thread_m_cluster_id
,
__builtin_amdgcn_readfirstlane
(
spatial_idx
[
I1
]),
thread_n_cluster_id
*
CShuffleBlockTransferScalarPerVector_NPerBlock
),
CElementwiseOperation
{}};
wg_barrier
.
wait_eq
(
reduction_idx
,
tile_acc_offset_end
-
tile_acc_offset_start
);
if
(
threadIdx
.
x
==
0
)
{
p_semaphore
[
reduction_idx
]
=
0
;
}
using
Accumulation
=
ck
::
detail
::
AccumulateWithNanCheck
<
false
/*PropagateNan*/
,
reduce
::
Add
,
AccDataType
>
;
for
(
int
i_m
=
0
;
i_m
<
MReduceIters
;
i_m
++
)
{
static_for
<
0
,
NReduceIters
,
1
>
{}([
&
](
auto
i_n_reduce
)
{
acc_buf
.
Clear
();
for
(
auto
i
=
tile_acc_offset_start
;
i
<
tile_acc_offset_end
;
i
++
)
{
auto
c_partial_acc_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
,
AmdBufferCoherenceEnum
::
GLC
>
(
reinterpret_cast
<
AccDataType
*>
(
p_workspace
)
+
i
*
c_partial_acc_block_m_n
.
GetElementSpaceSize
(),
c_partial_acc_block_m_n
.
GetElementSpaceSize
());
acc_load
.
Run
(
c_partial_acc_block_m_n
,
c_partial_acc_buf
,
acc_thread_buf_load_desc
,
make_tuple
(
I0
,
I0
),
parcial_acc_buf
);
static_for
<
0
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
1
>
{}(
[
&
](
auto
i_vec
)
{
constexpr
auto
offset
=
acc_thread_buf_load_desc
.
CalculateOffset
(
make_tuple
(
0
,
i_vec
));
Accumulation
::
Calculate
(
acc_buf
(
Number
<
offset
>
{}),
parcial_acc_buf
[
Number
<
offset
>
{}]);
});
}
if
(
thread_n_cluster_id
*
CShuffleBlockTransferScalarPerVector_NPerBlock
<
NPerBlock
)
{
acc_store
.
Run
(
acc_thread_buf_store_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
acc_buf
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_buf
);
}
if
constexpr
(
NReduceIters
!=
1
)
{
if
constexpr
(
i_n_reduce
!=
(
NReduceIters
-
1
))
{
acc_load
.
MoveSrcSliceWindow
(
c_partial_acc_block_m_n
,
partial_acc_load_step_n
);
acc_store
.
MoveDstSliceWindow
(
c_grid_desc_mblock_mperblock_nblock_nperblock
,
partial_acc_store_step_n
);
}
else
{
acc_load
.
MoveSrcSliceWindow
(
c_partial_acc_block_m_n
,
partial_acc_load_step_n_reverse
);
acc_store
.
MoveDstSliceWindow
(
c_grid_desc_mblock_mperblock_nblock_nperblock
,
partial_acc_store_step_n_reverse
);
}
}
});
{
acc_load
.
MoveSrcSliceWindow
(
c_partial_acc_block_m_n
,
partial_acc_load_step_m
);
acc_store
.
MoveDstSliceWindow
(
c_grid_desc_mblock_mperblock_nblock_nperblock
,
partial_acc_store_step_m
);
}
}
continue
;
}
}
// offset for last acc buffer of this block
uint32_t
block_acc_offset
=
(
block_2_ctile_map_streamk
.
get_acc_buffer_offset_from_block
(
block_idx
+
1
)
-
1
)
*
MPerBlock
*
NPerBlock
;
while
(
true
)
while
(
true
)
{
{
uint32_t
current_iter_length
=
__builtin_amdgcn_readfirstlane
(
uint32_t
current_iter_length
=
__builtin_amdgcn_readfirstlane
(
...
@@ -1173,33 +1459,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
...
@@ -1173,33 +1459,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
iter_end
-
1
,
tile_idx
,
iter_offset
);
iter_end
-
1
,
tile_idx
,
iter_offset
);
iter_offset
=
__builtin_amdgcn_readfirstlane
(
iter_offset
-
current_iter_length
+
1
);
iter_offset
=
__builtin_amdgcn_readfirstlane
(
iter_offset
-
current_iter_length
+
1
);
const
auto
a_grid_desc_ak0_m_ak1
=
MakeAGridDescriptor_AK0_M_AK1
(
problem
.
M
,
problem
.
MPadded
,
problem
.
K
,
problem
.
KPadded
,
problem
.
StrideA
,
problem
.
AK0
);
const
auto
b_grid_desc_bk0_n_bk1
=
MakeBGridDescriptor_BK0_N_BK1
(
problem
.
K
,
problem
.
KPadded
,
problem
.
N
,
problem
.
NPadded
,
problem
.
StrideB
,
problem
.
BK0
);
const
auto
c_grid_desc_m_n
=
MakeCGridDescriptor_M_N
(
problem
.
M
,
problem
.
MPadded
,
problem
.
N
,
problem
.
NPadded
,
problem
.
StrideC
);
const
auto
c_grid_desc_mblock_mperblock_nblock_nperblock
=
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n
,
problem
.
MBlock
,
problem
.
NBlock
);
auto
c_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_c_grid
,
c_grid_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_a_grid
,
a_grid_desc_ak0_m_ak1
.
GetElementSpaceSize
());
const
auto
b_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_b_grid
,
b_grid_desc_bk0_n_bk1
.
GetElementSpaceSize
());
auto
block_work_idx
=
auto
block_work_idx
=
block_2_ctile_map_streamk
.
tile_to_spatial
(
tile_idx
,
problem
.
M
,
problem
.
N
);
block_2_ctile_map_streamk
.
tile_to_spatial
(
tile_idx
,
problem
.
M
,
problem
.
N
);
...
@@ -1363,11 +1622,20 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
...
@@ -1363,11 +1622,20 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
constexpr
auto
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
=
constexpr
auto
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
=
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
();
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
();
constexpr
auto
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle
=
GetCBlockDescriptor_MShuffle_MPerShuffle_NShuffle_NPerShuffle
();
auto
c_shuffle_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
auto
c_shuffle_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
CShuffleDataType
*>
(
p_shared
),
static_cast
<
CShuffleDataType
*>
(
p_shared
),
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
.
GetElementSpaceSize
());
auto
c_partial_acc_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
,
AmdBufferCoherenceEnum
::
GLC
>
(
reinterpret_cast
<
AccDataType
*>
(
p_workspace
)
+
block_acc_offset
,
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle
.
GetElementSpaceSize
());
constexpr
auto
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
constexpr
auto
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
transform_tensor_descriptor
(
transform_tensor_descriptor
(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
,
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
,
...
@@ -1477,7 +1745,34 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
...
@@ -1477,7 +1745,34 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
make_multi_index
(
block_m_id
,
0
,
block_n_id
,
0
),
make_multi_index
(
block_m_id
,
0
,
block_n_id
,
0
),
c_element_op
};
c_element_op
};
// LDS to global partial acc
auto
c_block_copy_lds_to_partial_acc
=
ThreadGroupTensorSliceTransfer_v6r1r2
<
ThisThreadBlock
,
// index_t BlockSize,
CElementwiseOperation
,
// ElementwiseOperation,
// InMemoryDataOperationEnum::Set, // DstInMemOp,
Sequence
<
1
,
CShuffleMXdlPerWavePerShuffle
*
MWave
*
MPerXdl
,
1
,
CShuffleNXdlPerWavePerShuffle
*
NWave
*
NPerXdl
>
,
// BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename ThreadClusterArrangeOrder,
CShuffleDataType
,
// typename SrcData,
CShuffleDataType
,
// typename DstData,
decltype
(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
),
decltype
(
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle
),
Sequence
<
0
,
1
,
2
,
3
>
,
// typename DimAccessOrder,
3
,
// index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
// index_t ScalarPerVector,
false
,
// bool ThreadTransferSrcResetCoordinateAfterRun, => need to be
// false, othre wise has scratch
false
>
// bool ThreadTransferDstResetCoordinateAfterRun, => need to be
// false, othre wise has scratch
{
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
,
make_multi_index
(
0
,
0
,
0
,
0
),
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle
,
make_multi_index
(
0
,
0
,
0
,
0
),
c_element_op
};
// space filling curve for threadwise C in VGPR
// space filling curve for threadwise C in VGPR
constexpr
auto
sfc_c_vgpr
=
constexpr
auto
sfc_c_vgpr
=
SpaceFillingCurve
<
Sequence
<
MXdlPerWave
,
NXdlPerWave
,
1
,
1
,
M2
,
1
,
M4
,
1
>
,
SpaceFillingCurve
<
Sequence
<
MXdlPerWave
,
NXdlPerWave
,
1
,
1
,
M2
,
1
,
M4
,
1
>
,
...
@@ -1535,15 +1830,40 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
...
@@ -1535,15 +1830,40 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
}
}
else
if
(
is_sk_block
)
else
if
(
is_sk_block
)
{
{
// each block copy its data from LDS to global
if
constexpr
(
Block2CTileMap_streamk
::
ReductionStrategy
==
c_shuffle_block_copy_lds_to_global
StreamKReductionStrategy
::
Atomic
)
.
template
Run
<
decltype
(
c_shuffle_block_buf
),
{
decltype
(
c_grid_buf
),
// each block copy its data from LDS to global
InMemoryDataOperationEnum
::
AtomicAdd
>(
c_shuffle_block_copy_lds_to_global
.
template
Run
<
decltype
(
c_shuffle_block_buf
),
decltype
(
c_grid_buf
),
InMemoryDataOperationEnum
::
AtomicAdd
>(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
,
c_shuffle_block_buf
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_buf
);
}
else
if
constexpr
(
Block2CTileMap_streamk
::
ReductionStrategy
==
StreamKReductionStrategy
::
Reduction
)
{
// constexpr offset
c_block_copy_lds_to_partial_acc
.
SetSrcSliceOrigin
(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
,
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
,
c_shuffle_block_buf
,
make_tuple
(
0
,
0
,
0
,
0
));
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_buf
);
c_block_copy_lds_to_partial_acc
.
SetDstSliceOrigin
(
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle
,
make_tuple
(
MXdlPerWave
,
0
,
NXdlPerWave
,
0
));
c_block_copy_lds_to_partial_acc
.
template
Run
<
decltype
(
c_shuffle_block_buf
),
decltype
(
c_partial_acc_buf
),
InMemoryDataOperationEnum
::
Set
>(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
,
c_shuffle_block_buf
,
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle
,
c_partial_acc_buf
);
}
}
}
if
constexpr
(
access_id
<
num_access
-
1
)
if
constexpr
(
access_id
<
num_access
-
1
)
...
@@ -1555,15 +1875,33 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
...
@@ -1555,15 +1875,33 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_global_step
);
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_global_step
);
}
}
});
});
}
if
constexpr
(
Block2CTileMap_streamk
::
ReductionStrategy
==
StreamKReductionStrategy
::
Reduction
)
{
if
(
is_sk_block
)
{
// increase the counter for this tile
workgroup_barrier
wg_barrier
(
p_semaphore
);
wg_barrier
.
inc
(
tile_idx
);
}
}
}
// shuffle c and write-out end
// exit condition
// exit condition
iter_end
-=
current_iter_length
;
iter_end
-=
current_iter_length
;
if
(
iter_end
<=
iter_start
)
if
(
iter_end
<=
iter_start
)
break
;
break
;
if
constexpr
(
Block2CTileMap_streamk
::
ReductionStrategy
==
StreamKReductionStrategy
::
Reduction
)
{
block_acc_offset
-=
MPerBlock
*
NPerBlock
;
}
// make sure next loop LDS is ready for use
// make sure next loop LDS is ready for use
block_sync_lds
();
block_sync_lds
();
}
}
// while loop
}
}
// for loop
}
}
template
<
bool
HasMainKBlockLoop
,
template
<
bool
HasMainKBlockLoop
,
...
@@ -1574,19 +1912,43 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
...
@@ -1574,19 +1912,43 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
CDataType
*
p_c_grid
,
CDataType
*
p_c_grid
,
void
*
p_shared_0
,
void
*
p_shared_0
,
void
*
p_shared_1
,
void
*
p_shared_1
,
Problem
&
problem
)
Problem
&
problem
,
void
*
p_workspace
)
{
{
const
AElementwiseOperation
a_element_op
{};
const
AElementwiseOperation
a_element_op
{};
const
BElementwiseOperation
b_element_op
{};
const
BElementwiseOperation
b_element_op
{};
const
CElementwiseOperation
c_element_op
{};
const
CElementwiseOperation
c_element_op
{};
Block2CTileMap_streamk
block_2_ctile_map_streamk
(
const
auto
a_grid_desc_ak0_m_ak1
=
MakeAGridDescriptor_AK0_M_AK1
(
problem
.
M
,
problem
.
N
,
AK0Number
*
problem
.
KPadded
,
problem
.
Grid_size
);
problem
.
M
,
problem
.
MPadded
,
problem
.
K
,
problem
.
KPadded
,
problem
.
StrideA
,
problem
.
AK0
);
const
auto
b_grid_desc_bk0_n_bk1
=
MakeBGridDescriptor_BK0_N_BK1
(
problem
.
K
,
problem
.
KPadded
,
problem
.
N
,
problem
.
NPadded
,
problem
.
StrideB
,
problem
.
BK0
);
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_a_grid
,
a_grid_desc_ak0_m_ak1
.
GetElementSpaceSize
());
const
auto
b_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_b_grid
,
b_grid_desc_bk0_n_bk1
.
GetElementSpaceSize
());
uint32_t
iter_start
,
iter_end
;
uint32_t
iter_start
,
iter_end
;
bool
is_sk_block
,
is_dp_block
;
//, is_padding_block; //
, is_reduction_block;
bool
is_sk_block
,
is_dp_block
,
is_reduction_block
;
index_t
num_k_block_main_loop
;
index_t
num_k_block_main_loop
;
const
auto
c_grid_desc_m_n
=
MakeCGridDescriptor_M_N
(
problem
.
M
,
problem
.
MPadded
,
problem
.
N
,
problem
.
NPadded
,
problem
.
StrideC
);
const
auto
c_grid_desc_mblock_mperblock_nblock_nperblock
=
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n
,
problem
.
MBlock
,
problem
.
NBlock
);
auto
c_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_c_grid
,
c_grid_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
Block2CTileMap_streamk
block_2_ctile_map_streamk
(
problem
.
M
,
problem
.
N
,
AK0Number
*
problem
.
KPadded
,
problem
.
Grid_size
,
problem
.
Streamk_sel
);
for
(
auto
block_idx
=
get_block_1d_id
();
for
(
auto
block_idx
=
get_block_1d_id
();
block_idx
<
block_2_ctile_map_streamk
.
get_grid_dims
();
block_idx
<
block_2_ctile_map_streamk
.
get_grid_dims
();
block_idx
+=
gridDim
.
x
)
block_idx
+=
gridDim
.
x
)
...
@@ -1601,6 +1963,235 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
...
@@ -1601,6 +1963,235 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
block_2_ctile_map_streamk
.
get_block_itr
(
block_idx
,
iter_start
,
iter_end
);
block_2_ctile_map_streamk
.
get_block_itr
(
block_idx
,
iter_start
,
iter_end
);
num_k_block_main_loop
=
iter_end
-
iter_start
;
num_k_block_main_loop
=
iter_end
-
iter_start
;
uint32_t
*
p_semaphore
=
reinterpret_cast
<
uint32_t
*>
(
reinterpret_cast
<
char
*>
(
p_workspace
)
+
block_2_ctile_map_streamk
.
get_workspace_size_for_acc
(
sizeof
(
AccDataType
)));
if
constexpr
(
Block2CTileMap_streamk
::
ReductionStrategy
==
StreamKReductionStrategy
::
Reduction
)
{
is_reduction_block
=
static_cast
<
uint32_t
>
(
block_idx
)
>=
block_2_ctile_map_streamk
.
reduction_start_block_idx
;
if
(
is_reduction_block
)
{
// descriptors
constexpr
auto
cluster_length_reduce
=
GetClusterLengthReduction
();
constexpr
auto
reduce_desc
=
make_cluster_descriptor
(
cluster_length_reduce
);
const
auto
reduce_thread_cluster_idx
=
reduce_desc
.
CalculateBottomIndex
(
make_multi_index
(
block_idx
));
const
auto
thread_m_cluster_id
=
reduce_thread_cluster_idx
[
I0
];
const
auto
thread_n_cluster_id
=
reduce_thread_cluster_idx
[
I1
];
constexpr
auto
MReduceIters
=
math
::
integer_divide_ceil
(
Number
<
MPerBlock
>
{},
cluster_length_reduce
.
At
(
I0
));
constexpr
auto
NReduceIters
=
math
::
integer_divide_ceil
(
Number
<
NPerBlock
>
{},
cluster_length_reduce
.
At
(
I1
)
*
Number
<
CShuffleBlockTransferScalarPerVector_NPerBlock
>
{});
constexpr
auto
acc_thread_buf_load_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
Number
<
CShuffleBlockTransferScalarPerVector_NPerBlock
>
{}));
constexpr
auto
acc_thread_buf_store_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
I1
,
I1
,
I1
,
Number
<
CShuffleBlockTransferScalarPerVector_NPerBlock
>
{}));
constexpr
auto
c_partial_acc_block_m_n
=
GetPartialAccBlockDescriptor
();
constexpr
auto
partial_acc_load_step_n
=
make_multi_index
(
0
,
cluster_length_reduce
.
At
(
I1
)
*
CShuffleBlockTransferScalarPerVector_NPerBlock
);
constexpr
auto
partial_acc_load_step_n_reverse
=
make_multi_index
(
0
,
-
1
*
cluster_length_reduce
.
At
(
I1
).
value
*
(
NReduceIters
-
1
)
*
CShuffleBlockTransferScalarPerVector_NPerBlock
);
constexpr
auto
partial_acc_load_step_m
=
make_multi_index
(
cluster_length_reduce
.
At
(
I0
),
0
);
constexpr
auto
partial_acc_store_step_n
=
make_multi_index
(
0
,
0
,
0
,
cluster_length_reduce
.
At
(
I1
)
*
CShuffleBlockTransferScalarPerVector_NPerBlock
);
constexpr
auto
partial_acc_store_step_n_reverse
=
make_multi_index
(
0
,
0
,
0
,
-
1
*
cluster_length_reduce
.
At
(
I1
).
value
*
(
NReduceIters
-
1
)
*
CShuffleBlockTransferScalarPerVector_NPerBlock
);
constexpr
auto
partial_acc_store_step_m
=
make_multi_index
(
0
,
cluster_length_reduce
.
At
(
I0
),
0
,
0
);
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
true
>
parcial_acc_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
true
>
acc_buf
;
// start to compute
auto
reduction_idx
=
block_idx
-
block_2_ctile_map_streamk
.
reduction_start_block_idx
;
auto
spatial_idx
=
block_2_ctile_map_streamk
.
tile_to_spatial
(
reduction_idx
,
problem
.
M
,
problem
.
N
);
workgroup_barrier
wg_barrier
(
p_semaphore
);
uint32_t
tile_acc_offset_start
=
block_2_ctile_map_streamk
.
get_acc_buffer_offset_from_tile
(
reduction_idx
);
uint32_t
tile_acc_offset_end
=
block_2_ctile_map_streamk
.
get_acc_buffer_offset_from_tile
(
reduction_idx
+
1
);
uint32_t
expected_count
=
tile_acc_offset_end
-
tile_acc_offset_start
;
if
(
threadIdx
.
x
==
0
)
{
p_semaphore
[
reduction_idx
]
=
0
;
}
__syncthreads
();
auto
acc_load
=
ThreadwiseTensorSliceTransfer_v2
<
AccDataType
,
// SrcData,
AccDataType
,
// DstData,
decltype
(
c_partial_acc_block_m_n
),
// SrcDesc,
decltype
(
acc_thread_buf_load_desc
),
// DstDesc,
Sequence
<
1
,
CShuffleBlockTransferScalarPerVector_NPerBlock
>
,
// SliceLengths,
Sequence
<
0
,
1
>
,
// DimAccessOrder,
1
,
// SrcVectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
// SrcScalarPerVector,
1
,
// SrcScalarStrideInVector,
false
// SrcResetCoordinateAfterRun,
>
{
c_partial_acc_block_m_n
,
make_multi_index
(
thread_m_cluster_id
,
thread_n_cluster_id
*
CShuffleBlockTransferScalarPerVector_NPerBlock
)};
auto
acc_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
AccDataType
,
// SrcData,
CDataType
,
// DstData,
decltype
(
acc_thread_buf_store_desc
),
// SrcDesc,
decltype
(
c_grid_desc_mblock_mperblock_nblock_nperblock
),
// DstDesc,
CElementwiseOperation
,
// ElementwiseOperation,
Sequence
<
1
,
1
,
1
,
CShuffleBlockTransferScalarPerVector_NPerBlock
>
,
// SliceLengths,
Sequence
<
0
,
1
,
2
,
3
>
,
// DimAccessOrder,
3
,
// DstVectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
// DstScalarPerVector,
InMemoryDataOperationEnum
::
Set
,
// InMemoryDataOperationEnum DstInMemOp,
1
,
// DstScalarStrideInVector,
false
// DstResetCoordinateAfterRun,
>
{
c_grid_desc_mblock_mperblock_nblock_nperblock
,
make_multi_index
(
__builtin_amdgcn_readfirstlane
(
spatial_idx
[
I0
]),
thread_m_cluster_id
,
__builtin_amdgcn_readfirstlane
(
spatial_idx
[
I1
]),
thread_n_cluster_id
*
CShuffleBlockTransferScalarPerVector_NPerBlock
),
CElementwiseOperation
{}};
#if 0
if(threadIdx.x == 0) {
printf("bid:%d, rid:%d, os:%d,%d, spatial:%d,%d\n", static_cast<int>(blockIdx.x),
reduction_idx, __builtin_amdgcn_readfirstlane(tile_acc_offset_start), __builtin_amdgcn_readfirstlane(tile_acc_offset_end),
__builtin_amdgcn_readfirstlane(spatial_idx[I0]),
__builtin_amdgcn_readfirstlane(spatial_idx[I1]));
}
#endif
if
(
threadIdx
.
x
==
0
)
{
atomicAdd
(
&
p_semaphore
[
reduction_idx
],
1
);
}
wg_barrier
.
wait_eq
(
p_semaphore
[
reduction_idx
],
expected_count
);
using
Accumulation
=
ck
::
detail
::
AccumulateWithNanCheck
<
false
/*PropagateNan*/
,
reduce
::
Add
,
AccDataType
>
;
for
(
int
i_m
=
0
;
i_m
<
MReduceIters
;
i_m
++
)
{
static_for
<
0
,
NReduceIters
,
1
>
{}([
&
](
auto
i_n_reduce
)
{
acc_buf
.
Clear
();
for
(
auto
i
=
tile_acc_offset_start
;
i
<
tile_acc_offset_end
;
i
++
)
{
auto
c_partial_acc_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
,
AmdBufferCoherenceEnum
::
GLC
>
(
reinterpret_cast
<
AccDataType
*>
(
p_workspace
)
+
i
*
c_partial_acc_block_m_n
.
GetElementSpaceSize
(),
c_partial_acc_block_m_n
.
GetElementSpaceSize
());
acc_load
.
Run
(
c_partial_acc_block_m_n
,
c_partial_acc_buf
,
acc_thread_buf_load_desc
,
make_tuple
(
I0
,
I0
),
parcial_acc_buf
);
static_for
<
0
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
1
>
{}(
[
&
](
auto
i_vec
)
{
constexpr
auto
offset
=
acc_thread_buf_load_desc
.
CalculateOffset
(
make_tuple
(
0
,
i_vec
));
Accumulation
::
Calculate
(
acc_buf
(
Number
<
offset
>
{}),
parcial_acc_buf
[
Number
<
offset
>
{}]);
});
}
if
(
thread_n_cluster_id
*
CShuffleBlockTransferScalarPerVector_NPerBlock
<
NPerBlock
)
{
acc_store
.
Run
(
acc_thread_buf_store_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
acc_buf
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_buf
);
}
if
constexpr
(
NReduceIters
!=
1
)
{
if
constexpr
(
i_n_reduce
!=
(
NReduceIters
-
1
))
{
acc_load
.
MoveSrcSliceWindow
(
c_partial_acc_block_m_n
,
partial_acc_load_step_n
);
acc_store
.
MoveDstSliceWindow
(
c_grid_desc_mblock_mperblock_nblock_nperblock
,
partial_acc_store_step_n
);
}
else
{
acc_load
.
MoveSrcSliceWindow
(
c_partial_acc_block_m_n
,
partial_acc_load_step_n_reverse
);
acc_store
.
MoveDstSliceWindow
(
c_grid_desc_mblock_mperblock_nblock_nperblock
,
partial_acc_store_step_n_reverse
);
}
}
});
{
acc_load
.
MoveSrcSliceWindow
(
c_partial_acc_block_m_n
,
partial_acc_load_step_m
);
acc_store
.
MoveDstSliceWindow
(
c_grid_desc_mblock_mperblock_nblock_nperblock
,
partial_acc_store_step_m
);
}
}
continue
;
}
}
// offset for last acc buffer of this block
uint32_t
block_acc_offset
=
(
block_2_ctile_map_streamk
.
get_acc_buffer_offset_from_block
(
block_idx
+
1
)
-
1
)
*
MPerBlock
*
NPerBlock
;
while
(
true
)
{
{
uint32_t
current_iter_length
=
__builtin_amdgcn_readfirstlane
(
uint32_t
current_iter_length
=
__builtin_amdgcn_readfirstlane
(
...
@@ -1611,33 +2202,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
...
@@ -1611,33 +2202,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
iter_end
-
1
,
tile_idx
,
iter_offset
);
iter_end
-
1
,
tile_idx
,
iter_offset
);
iter_offset
=
__builtin_amdgcn_readfirstlane
(
iter_offset
-
current_iter_length
+
1
);
iter_offset
=
__builtin_amdgcn_readfirstlane
(
iter_offset
-
current_iter_length
+
1
);
const
auto
a_grid_desc_ak0_m_ak1
=
MakeAGridDescriptor_AK0_M_AK1
(
problem
.
M
,
problem
.
MPadded
,
problem
.
K
,
problem
.
KPadded
,
problem
.
StrideA
,
problem
.
AK0
);
const
auto
b_grid_desc_bk0_n_bk1
=
MakeBGridDescriptor_BK0_N_BK1
(
problem
.
K
,
problem
.
KPadded
,
problem
.
N
,
problem
.
NPadded
,
problem
.
StrideB
,
problem
.
BK0
);
const
auto
c_grid_desc_m_n
=
MakeCGridDescriptor_M_N
(
problem
.
M
,
problem
.
MPadded
,
problem
.
N
,
problem
.
NPadded
,
problem
.
StrideC
);
const
auto
c_grid_desc_mblock_mperblock_nblock_nperblock
=
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n
,
problem
.
MBlock
,
problem
.
NBlock
);
auto
c_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_c_grid
,
c_grid_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_a_grid
,
a_grid_desc_ak0_m_ak1
.
GetElementSpaceSize
());
const
auto
b_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_b_grid
,
b_grid_desc_bk0_n_bk1
.
GetElementSpaceSize
());
auto
block_work_idx
=
auto
block_work_idx
=
block_2_ctile_map_streamk
.
tile_to_spatial
(
tile_idx
,
problem
.
M
,
problem
.
N
);
block_2_ctile_map_streamk
.
tile_to_spatial
(
tile_idx
,
problem
.
M
,
problem
.
N
);
...
@@ -1811,11 +2375,20 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
...
@@ -1811,11 +2375,20 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
constexpr
auto
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
=
constexpr
auto
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
=
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
();
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
();
constexpr
auto
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle
=
GetCBlockDescriptor_MShuffle_MPerShuffle_NShuffle_NPerShuffle
();
auto
c_shuffle_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
auto
c_shuffle_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
CShuffleDataType
*>
(
p_shared_0
),
static_cast
<
CShuffleDataType
*>
(
p_shared_0
),
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
.
GetElementSpaceSize
());
auto
c_partial_acc_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
,
AmdBufferCoherenceEnum
::
GLC
>
(
reinterpret_cast
<
AccDataType
*>
(
p_workspace
)
+
block_acc_offset
,
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle
.
GetElementSpaceSize
());
constexpr
auto
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
constexpr
auto
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
transform_tensor_descriptor
(
transform_tensor_descriptor
(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
,
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
,
...
@@ -1925,6 +2498,35 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
...
@@ -1925,6 +2498,35 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
make_multi_index
(
block_m_id
,
0
,
block_n_id
,
0
),
make_multi_index
(
block_m_id
,
0
,
block_n_id
,
0
),
c_element_op
};
c_element_op
};
// LDS to global partial acc
auto
c_block_copy_lds_to_partial_acc
=
ThreadGroupTensorSliceTransfer_v6r1r2
<
ThisThreadBlock
,
// index_t BlockSize,
CElementwiseOperation
,
// ElementwiseOperation,
// InMemoryDataOperationEnum::Set, // DstInMemOp,
Sequence
<
1
,
CShuffleMXdlPerWavePerShuffle
*
MWave
*
MPerXdl
,
1
,
CShuffleNXdlPerWavePerShuffle
*
NWave
*
NPerXdl
>
,
// BlockSliceLengths,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename ThreadClusterArrangeOrder,
CShuffleDataType
,
// typename SrcData,
CShuffleDataType
,
// typename DstData,
decltype
(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
),
decltype
(
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle
),
Sequence
<
0
,
1
,
2
,
3
>
,
// typename DimAccessOrder,
3
,
// index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
// index_t ScalarPerVector,
false
,
// bool ThreadTransferSrcResetCoordinateAfterRun, => need to be
// false, othre wise has scratch
false
>
// bool ThreadTransferDstResetCoordinateAfterRun, => need to be
// false, othre wise has scratch
{
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
,
make_multi_index
(
0
,
0
,
0
,
0
),
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle
,
make_multi_index
(
0
,
0
,
0
,
0
),
c_element_op
};
// space filling curve for threadwise C in VGPR
// space filling curve for threadwise C in VGPR
constexpr
auto
sfc_c_vgpr
=
constexpr
auto
sfc_c_vgpr
=
SpaceFillingCurve
<
Sequence
<
MXdlPerWave
,
NXdlPerWave
,
1
,
1
,
M2
,
1
,
M4
,
1
>
,
SpaceFillingCurve
<
Sequence
<
MXdlPerWave
,
NXdlPerWave
,
1
,
1
,
M2
,
1
,
M4
,
1
>
,
...
@@ -1982,15 +2584,40 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
...
@@ -1982,15 +2584,40 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
}
}
else
if
(
is_sk_block
)
else
if
(
is_sk_block
)
{
{
// each block copy its data from LDS to global
if
constexpr
(
Block2CTileMap_streamk
::
ReductionStrategy
==
c_shuffle_block_copy_lds_to_global
StreamKReductionStrategy
::
Atomic
)
.
template
Run
<
decltype
(
c_shuffle_block_buf
),
{
decltype
(
c_grid_buf
),
// each block copy its data from LDS to global
InMemoryDataOperationEnum
::
AtomicAdd
>(
c_shuffle_block_copy_lds_to_global
.
template
Run
<
decltype
(
c_shuffle_block_buf
),
decltype
(
c_grid_buf
),
InMemoryDataOperationEnum
::
AtomicAdd
>(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
,
c_shuffle_block_buf
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_buf
);
}
else
if
constexpr
(
Block2CTileMap_streamk
::
ReductionStrategy
==
StreamKReductionStrategy
::
Reduction
)
{
// constexpr offset
c_block_copy_lds_to_partial_acc
.
SetSrcSliceOrigin
(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
,
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
,
c_shuffle_block_buf
,
make_tuple
(
0
,
0
,
0
,
0
));
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_buf
);
c_block_copy_lds_to_partial_acc
.
SetDstSliceOrigin
(
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle
,
make_tuple
(
MXdlPerWave
,
0
,
NXdlPerWave
,
0
));
c_block_copy_lds_to_partial_acc
.
template
Run
<
decltype
(
c_shuffle_block_buf
),
decltype
(
c_partial_acc_buf
),
InMemoryDataOperationEnum
::
Set
>(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
,
c_shuffle_block_buf
,
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle
,
c_partial_acc_buf
);
}
}
}
if
constexpr
(
access_id
<
num_access
-
1
)
if
constexpr
(
access_id
<
num_access
-
1
)
{
{
...
@@ -2002,6 +2629,27 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
...
@@ -2002,6 +2629,27 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
}
}
});
});
}
}
// exit condition
iter_end
-=
current_iter_length
;
if
(
iter_end
<=
iter_start
)
break
;
if
constexpr
(
Block2CTileMap_streamk
::
ReductionStrategy
==
StreamKReductionStrategy
::
Reduction
)
{
block_acc_offset
-=
MPerBlock
*
NPerBlock
;
}
// make sure next loop LDS is ready for use
block_sync_lds
();
}
if
constexpr
(
Block2CTileMap_streamk
::
ReductionStrategy
==
StreamKReductionStrategy
::
Reduction
)
{
if
(
is_sk_block
)
{
// increase the counter for this tile
workgroup_barrier
wg_barrier
(
p_semaphore
);
wg_barrier
.
inc
(
0
);
}
}
}
}
}
}
}
...
...
include/ck/utility/loop_scheduler.hpp
View file @
4525c5d7
...
@@ -5,7 +5,6 @@
...
@@ -5,7 +5,6 @@
#pragma once
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
namespace
ck
{
namespace
ck
{
...
...
include/ck_tile/core.hpp
View file @
4525c5d7
...
@@ -52,6 +52,7 @@
...
@@ -52,6 +52,7 @@
#include "ck_tile/core/tensor/tile_elementwise.hpp"
#include "ck_tile/core/tensor/tile_elementwise.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/tensor/tile_window_linear.hpp"
#include "ck_tile/core/tensor/tile_window_linear.hpp"
#include "ck_tile/core/tensor/tile_window_utils.hpp"
#include "ck_tile/core/tensor/update_tile.hpp"
#include "ck_tile/core/tensor/update_tile.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/functional.hpp"
...
@@ -62,6 +63,7 @@
...
@@ -62,6 +63,7 @@
#include "ck_tile/core/utility/philox_rand.hpp"
#include "ck_tile/core/utility/philox_rand.hpp"
#include "ck_tile/core/utility/random.hpp"
#include "ck_tile/core/utility/random.hpp"
#include "ck_tile/core/utility/reduce_operator.hpp"
#include "ck_tile/core/utility/reduce_operator.hpp"
#include "ck_tile/core/utility/static_counter.hpp"
#include "ck_tile/core/utility/to_sequence.hpp"
#include "ck_tile/core/utility/to_sequence.hpp"
#include "ck_tile/core/utility/transpose_vectors.hpp"
#include "ck_tile/core/utility/transpose_vectors.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
...
...
Prev
1
2
3
4
5
6
7
8
9
10
…
16
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment