Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
d480a5a6
Unverified
Commit
d480a5a6
authored
Feb 03, 2025
by
Max Podkorytov
Committed by
GitHub
Feb 03, 2025
Browse files
Merge branch 'develop' into ck-flex
parents
bca939ce
9c5b2f39
Changes
94
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
158 additions
and
136 deletions
+158
-136
docs/sphinx/requirements.in
docs/sphinx/requirements.in
+1
-1
docs/sphinx/requirements.txt
docs/sphinx/requirements.txt
+1
-1
example/ck_tile/03_gemm/gemm_basic.cpp
example/ck_tile/03_gemm/gemm_basic.cpp
+14
-25
example/ck_tile/03_gemm/gemm_basic.hpp
example/ck_tile/03_gemm/gemm_basic.hpp
+1
-1
example/ck_tile/03_gemm/script/benchmark_basic.sh
example/ck_tile/03_gemm/script/benchmark_basic.sh
+2
-2
example/ck_tile/03_gemm/script/benchmark_mem_pipeline.sh
example/ck_tile/03_gemm/script/benchmark_mem_pipeline.sh
+2
-2
example/ck_tile/03_gemm/universal_gemm.cpp
example/ck_tile/03_gemm/universal_gemm.cpp
+19
-6
example/ck_tile/16_batched_gemm/batched_gemm.cpp
example/ck_tile/16_batched_gemm/batched_gemm.cpp
+17
-27
example/ck_tile/17_grouped_gemm/grouped_gemm.cpp
example/ck_tile/17_grouped_gemm/grouped_gemm.cpp
+20
-26
include/ck/ck.hpp
include/ck/ck.hpp
+2
-2
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp
...ion/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp
+2
-2
include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp
...eration/gpu/device/convolution_forward_specialization.hpp
+5
-1
include/ck/tensor_operation/gpu/device/device_base.hpp
include/ck/tensor_operation/gpu/device/device_base.hpp
+9
-4
include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp
...ation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp
+18
-4
include/ck/tensor_operation/gpu/device/gemm_specialization.hpp
...de/ck/tensor_operation/gpu/device/gemm_specialization.hpp
+3
-1
include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
...gen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
+31
-26
include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp
...operation/gpu/device/impl/device_column_to_image_impl.hpp
+1
-0
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp
...tion/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp
+2
-2
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp
...impl/device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp
+2
-2
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
...mpl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
+6
-1
No files found.
docs/sphinx/requirements.in
View file @
d480a5a6
rocm-docs-core==1.1
4.1
rocm-docs-core==1.1
5.0
sphinxcontrib-bibtex==2.6.3
sphinxcontrib-bibtex==2.6.3
docs/sphinx/requirements.txt
View file @
d480a5a6
...
@@ -199,7 +199,7 @@ requests==2.32.3
...
@@ -199,7 +199,7 @@ requests==2.32.3
# via
# via
# pygithub
# pygithub
# sphinx
# sphinx
rocm-docs-core==1.1
4.1
rocm-docs-core==1.1
5.0
# via -r requirements.in
# via -r requirements.in
rpds-py==0.22.3
rpds-py==0.22.3
# via
# via
...
...
example/ck_tile/03_gemm/gemm_basic.cpp
View file @
d480a5a6
...
@@ -20,10 +20,6 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
...
@@ -20,10 +20,6 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
constexpr
bool
kPadN
=
false
;
constexpr
bool
kPadN
=
false
;
constexpr
bool
kPadK
=
false
;
constexpr
bool
kPadK
=
false
;
constexpr
bool
kTilePermute
=
false
;
// The rank and permutation will also be generate out by the CodeGen part.
constexpr
ck_tile
::
index_t
kOutputRank
=
2
;
constexpr
int
kBlockPerCu
=
1
;
constexpr
int
kBlockPerCu
=
1
;
// This part comes from the Codegen
// This part comes from the Codegen
...
@@ -39,38 +35,31 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
...
@@ -39,38 +35,31 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
8
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
8
;
// Whether doing the CShuffle (transpose before the global memory), depending on the output
// layout.
constexpr
bool
CShuffleEpilogue
=
std
::
is_same_v
<
CLayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
;
using
CodegenGemmShape
=
using
CodegenGemmShape
=
ck_tile
::
TileGemmShape
<
ck_tile
::
sequence
<
M_Tile
,
N_Tile
,
K_Tile
>
,
ck_tile
::
TileGemmShape
<
ck_tile
::
sequence
<
M_Tile
,
N_Tile
,
K_Tile
>
,
ck_tile
::
sequence
<
M_Warp
,
N_Warp
,
K_Warp
>
,
ck_tile
::
sequence
<
M_Warp
,
N_Warp
,
K_Warp
>
,
ck_tile
::
sequence
<
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
>>
;
ck_tile
::
sequence
<
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
>>
;
using
TilePartitioner
=
ck_tile
::
GemmTile2DPartitioner
<
CodegenGemmShape
>
;
using
TilePartitioner
=
ck_tile
::
GemmTile1DPartitioner
<
CodegenGemmShape
>
;
using
GemmEpilogue
=
std
::
conditional_t
<
CShuffleEpilogue
,
ck_tile
::
CShuffleEpilogue
<
ck_tile
::
CShuffleEpilogueProblem
<
AccDataType
,
CDataType
,
kPadM
,
kPadN
,
kTilePermute
,
kOutputRank
,
1
,
0
,
TilePartitioner
::
MPerBlock
,
TilePartitioner
::
NPerBlock
>>
,
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPadM
,
kPadN
>>>
;
using
CodegenGemmTraits
=
using
CodegenGemmTraits
=
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
>
;
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
>
;
using
CodegenPipelineProblem
=
ck_tile
::
using
CodegenPipelineProblem
=
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
CodegenGemmShape
,
CodegenGemmTraits
>
;
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
CodegenGemmShape
,
CodegenGemmTraits
>
;
using
CodegenGemmPipeline
=
ck_tile
::
GemmPipelineAGmemBGmemCRegV1
<
CodegenPipelineProblem
>
;
using
CodegenGemmPipeline
=
ck_tile
::
GemmPipelineAGmemBGmemCRegV1
<
CodegenPipelineProblem
>
;
using
GemmEpilogue
=
ck_tile
::
CShuffleEpilogue
<
ck_tile
::
CShuffleEpilogueProblem
<
AccDataType
,
CDataType
,
CLayout
,
CodegenPipelineProblem
::
kBlockSize
,
TilePartitioner
::
MPerBlock
,
TilePartitioner
::
NPerBlock
,
M_Warp
,
N_Warp
,
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
,
CodegenPipelineProblem
::
TransposeC
>>
;
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using
Kernel
=
ck_tile
::
GemmKernel
<
TilePartitioner
,
CodegenGemmPipeline
,
GemmEpilogue
>
;
using
Kernel
=
ck_tile
::
GemmKernel
<
TilePartitioner
,
CodegenGemmPipeline
,
GemmEpilogue
>
;
...
...
example/ck_tile/03_gemm/gemm_basic.hpp
View file @
d480a5a6
...
@@ -79,7 +79,7 @@ auto create_args(int argc, char* argv[])
...
@@ -79,7 +79,7 @@ auto create_args(int argc, char* argv[])
.
insert
(
"n"
,
"4096"
,
"n dimension"
)
.
insert
(
"n"
,
"4096"
,
"n dimension"
)
.
insert
(
"k"
,
"2048"
,
"k dimension"
)
.
insert
(
"k"
,
"2048"
,
"k dimension"
)
.
insert
(
"a_layout"
,
"R"
,
"A tensor data layout - Row by default"
)
.
insert
(
"a_layout"
,
"R"
,
"A tensor data layout - Row by default"
)
.
insert
(
"b_layout"
,
"
R
"
,
"B tensor data layout -
Row
by default"
)
.
insert
(
"b_layout"
,
"
C
"
,
"B tensor data layout -
Column
by default"
)
.
insert
(
"c_layout"
,
"R"
,
"C tensor data layout - Row by default"
)
.
insert
(
"c_layout"
,
"R"
,
"C tensor data layout - Row by default"
)
.
insert
(
"stride_a"
,
"0"
,
"Tensor A stride"
)
.
insert
(
"stride_a"
,
"0"
,
"Tensor A stride"
)
.
insert
(
"stride_b"
,
"0"
,
"Tensor B stride"
)
.
insert
(
"stride_b"
,
"0"
,
"Tensor B stride"
)
...
...
example/ck_tile/03_gemm/script/benchmark_basic.sh
View file @
d480a5a6
#!/bin/sh
#!/bin/sh
EXE
=
"
$(
find
.
-name
tile_example_gemm_basic
-type
f |
head
-n
1
)
"
EXE
=
"
$(
find
.
-name
tile_example_gemm_basic
-type
f |
head
-n
1
)
"
VALID
=
0
VALID
=
1
for
b_matrix_layout
in
"R"
"C"
;
do
for
b_matrix_layout
in
"R"
"C"
;
do
for
m
in
"64"
"512"
"1024"
"2048"
;
do
for
m
in
"64"
"512"
"1024"
"2048"
;
do
for
n
in
"512"
"1024"
"2048"
;
do
for
n
in
"512"
"1024"
"2048"
;
do
for
k
in
"64"
"512"
"1024"
"2048"
;
do
for
k
in
"64"
"512"
"1024"
"2048"
;
do
$EXE
-prec
=
fp16
-b
=
1
-m
=
$m
-n
=
$n
-k
=
$k
-a_layout
=
"R"
-b_layout
=
"
$b_matrix_layout
"
-c_layout
=
"R"
-v
=
$VALID
$EXE
-prec
=
fp16
-m
=
$m
-n
=
$n
-k
=
$k
-a_layout
=
"R"
-b_layout
=
"
$b_matrix_layout
"
-c_layout
=
"R"
-v
=
$VALID
done
done
done
done
done
done
...
...
example/ck_tile/03_gemm/script/benchmark_mem_pipeline.sh
View file @
d480a5a6
#!/bin/sh
#!/bin/sh
EXE
=
"
$(
find
.
-name
tile_example_gemm_universal
-type
f |
head
-n
1
)
"
EXE
=
"
$(
find
.
-name
tile_example_gemm_universal
-type
f |
head
-n
1
)
"
VALID
=
0
VALID
=
1
for
b_matrix_layout
in
"R"
"C"
;
do
for
b_matrix_layout
in
"R"
"C"
;
do
for
m
in
"64"
"512"
"1024"
"2048"
;
do
for
m
in
"64"
"512"
"1024"
"2048"
;
do
for
n
in
"512"
"1024"
"2048"
;
do
for
n
in
"512"
"1024"
"2048"
;
do
for
k
in
"64"
"512"
"1024"
"2048"
;
do
for
k
in
"64"
"512"
"1024"
"2048"
;
do
$EXE
-prec
=
fp16
-b
=
1
-m
=
$m
-n
=
$n
-k
=
$k
-a_layout
=
"R"
-b_layout
=
"
$b_matrix_layout
"
-c_layout
=
"R"
-v
=
$VALID
$EXE
-prec
=
fp16
-m
=
$m
-n
=
$n
-k
=
$k
-a_layout
=
"R"
-b_layout
=
"
$b_matrix_layout
"
-c_layout
=
"R"
-v
=
$VALID
done
done
done
done
done
done
...
...
example/ck_tile/03_gemm/universal_gemm.cpp
View file @
d480a5a6
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024
-2025
, Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>
...
@@ -50,7 +50,9 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
...
@@ -50,7 +50,9 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
constexpr
bool
TransposeC
=
false
;
constexpr
bool
TransposeC
=
false
;
constexpr
int
kBlockPerCu
=
1
;
constexpr
int
kBlockPerCu
=
1
;
constexpr
ck_tile
::
index_t
TileParitionerGroupNum
=
8
;
constexpr
ck_tile
::
index_t
TileParitionerM01
=
4
;
// ===============================================
// ===============================================
...
@@ -58,10 +60,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
...
@@ -58,10 +60,8 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
ck_tile
::
TileGemmShape
<
ck_tile
::
sequence
<
M_Tile
,
N_Tile
,
K_Tile
>
,
ck_tile
::
TileGemmShape
<
ck_tile
::
sequence
<
M_Tile
,
N_Tile
,
K_Tile
>
,
ck_tile
::
sequence
<
M_Warp
,
N_Warp
,
K_Warp
>
,
ck_tile
::
sequence
<
M_Warp
,
N_Warp
,
K_Warp
>
,
ck_tile
::
sequence
<
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
>>
;
ck_tile
::
sequence
<
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
>>
;
using
TilePartitioner
=
ck_tile
::
GemmTile2DPartitioner
<
GemmShape
>
;
using
TilePartitioner
=
ck_tile
::
GemmSpatiallyLocalTilePartitioner
<
GemmShape
,
TileParitionerGroupNum
,
TileParitionerM01
>
;
using
GemmEpilogue
=
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPadM
,
kPadN
>>
;
using
Traits
=
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
>
;
using
Traits
=
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
>
;
using
GemmUniversalTraits
=
ck_tile
::
using
GemmUniversalTraits
=
ck_tile
::
...
@@ -95,6 +95,19 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
...
@@ -95,6 +95,19 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
using
GemmPipeline
=
using
GemmPipeline
=
GEMM_PIPELINE
<
UniversalGemmProblem
,
ck_tile
::
UniversalGemmPipelineAgBgCrPolicy
>
;
GEMM_PIPELINE
<
UniversalGemmProblem
,
ck_tile
::
UniversalGemmPipelineAgBgCrPolicy
>
;
using
GemmEpilogue
=
ck_tile
::
CShuffleEpilogue
<
ck_tile
::
CShuffleEpilogueProblem
<
AccDataType
,
CDataType
,
CLayout
,
GemmPipelineProblem
::
kBlockSize
,
TilePartitioner
::
MPerBlock
,
TilePartitioner
::
NPerBlock
,
M_Warp
,
N_Warp
,
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
,
UniversalGemmProblem
::
TransposeC
>>
;
using
Kernel
=
ck_tile
::
GemmKernel
<
TilePartitioner
,
GemmPipeline
,
GemmEpilogue
>
;
using
Kernel
=
ck_tile
::
GemmKernel
<
TilePartitioner
,
GemmPipeline
,
GemmEpilogue
>
;
auto
kargs
=
Kernel
::
MakeKernelArgs
(
args
);
auto
kargs
=
Kernel
::
MakeKernelArgs
(
args
);
...
...
example/ck_tile/16_batched_gemm/batched_gemm.cpp
View file @
d480a5a6
...
@@ -19,12 +19,9 @@ template <typename ALayout, typename BLayout, typename CLayout>
...
@@ -19,12 +19,9 @@ template <typename ALayout, typename BLayout, typename CLayout>
float
batched_gemm
(
const
ck_tile
::
BatchedGemmHostArgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
float
batched_gemm
(
const
ck_tile
::
BatchedGemmHostArgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
{
{
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
constexpr
bool
kPadM
=
false
;
constexpr
bool
kPadM
=
false
;
constexpr
bool
kPadN
=
false
;
constexpr
bool
kPadN
=
false
;
constexpr
bool
kPadK
=
false
;
constexpr
bool
kPadK
=
false
;
constexpr
bool
kTilePermute
=
false
;
// The rank and permutation will also be generate out by the CodeGen part.
constexpr
ck_tile
::
index_t
kOutputRank
=
2
;
constexpr
int
kBlockPerCu
=
1
;
constexpr
int
kBlockPerCu
=
1
;
...
@@ -41,38 +38,31 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
...
@@ -41,38 +38,31 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
8
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
8
;
// Whether doing the CShuffle (transpose before the global memory), depending on the output
// layout.
constexpr
bool
CShuffleEpilogue
=
std
::
is_same_v
<
CLayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
;
using
CodegenGemmShape
=
using
CodegenGemmShape
=
ck_tile
::
TileGemmShape
<
ck_tile
::
sequence
<
M_Tile
,
N_Tile
,
K_Tile
>
,
ck_tile
::
TileGemmShape
<
ck_tile
::
sequence
<
M_Tile
,
N_Tile
,
K_Tile
>
,
ck_tile
::
sequence
<
M_Warp
,
N_Warp
,
K_Warp
>
,
ck_tile
::
sequence
<
M_Warp
,
N_Warp
,
K_Warp
>
,
ck_tile
::
sequence
<
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
>>
;
ck_tile
::
sequence
<
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
>>
;
using
TilePartitioner
=
ck_tile
::
GemmTile2DPartitioner
<
CodegenGemmShape
>
;
using
TilePartitioner
=
ck_tile
::
GemmTile1DPartitioner
<
CodegenGemmShape
>
;
using
GemmEpilogue
=
std
::
conditional_t
<
CShuffleEpilogue
,
ck_tile
::
CShuffleEpilogue
<
ck_tile
::
CShuffleEpilogueProblem
<
AccDataType
,
CDataType
,
kPadM
,
kPadN
,
kTilePermute
,
kOutputRank
,
1
,
0
,
TilePartitioner
::
MPerBlock
,
TilePartitioner
::
NPerBlock
>>
,
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
kPadM
,
kPadN
>>>
;
using
CodegenGemmTraits
=
using
CodegenGemmTraits
=
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
>
;
ck_tile
::
TileGemmTraits
<
kPadM
,
kPadN
,
kPadK
,
ALayout
,
BLayout
,
CLayout
>
;
using
CodegenPipelineProblem
=
ck_tile
::
using
CodegenPipelineProblem
=
ck_tile
::
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
CodegenGemmShape
,
CodegenGemmTraits
>
;
GemmPipelineProblem
<
ADataType
,
BDataType
,
AccDataType
,
CodegenGemmShape
,
CodegenGemmTraits
>
;
using
CodegenGemmPipeline
=
ck_tile
::
GemmPipelineAGmemBGmemCRegV1
<
CodegenPipelineProblem
>
;
using
CodegenGemmPipeline
=
ck_tile
::
GemmPipelineAGmemBGmemCRegV1
<
CodegenPipelineProblem
>
;
using
GemmEpilogue
=
ck_tile
::
CShuffleEpilogue
<
ck_tile
::
CShuffleEpilogueProblem
<
AccDataType
,
CDataType
,
CLayout
,
CodegenPipelineProblem
::
kBlockSize
,
TilePartitioner
::
MPerBlock
,
TilePartitioner
::
NPerBlock
,
M_Warp
,
N_Warp
,
M_Warp_Tile
,
N_Warp_Tile
,
K_Warp_Tile
,
CodegenPipelineProblem
::
TransposeC
>>
;
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
using
Kernel
=
ck_tile
::
BatchedGemmKernel
<
TilePartitioner
,
CodegenGemmPipeline
,
GemmEpilogue
>
;
using
Kernel
=
ck_tile
::
BatchedGemmKernel
<
TilePartitioner
,
CodegenGemmPipeline
,
GemmEpilogue
>
;
...
...
example/ck_tile/17_grouped_gemm/grouped_gemm.cpp
View file @
d480a5a6
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024
-2025
, Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>
...
@@ -20,12 +20,9 @@ namespace {
...
@@ -20,12 +20,9 @@ namespace {
struct
GroupedGemmKernelParam
struct
GroupedGemmKernelParam
{
{
static
const
bool
kPadM
=
false
;
static
const
bool
kPadM
=
false
;
static
const
bool
kPadN
=
false
;
static
const
bool
kPadN
=
false
;
static
const
bool
kPadK
=
false
;
static
const
bool
kPadK
=
false
;
static
const
bool
kTilePermute
=
false
;
static
const
ck_tile
::
index_t
kOutputRank
=
2
;
static
const
int
kBlockPerCu
=
1
;
static
const
int
kBlockPerCu
=
1
;
static
const
ck_tile
::
index_t
M_Tile
=
128
;
static
const
ck_tile
::
index_t
M_Tile
=
128
;
...
@@ -54,24 +51,6 @@ using CodegenGemmShape =
...
@@ -54,24 +51,6 @@ using CodegenGemmShape =
using
TilePartitioner
=
ck_tile
::
GemmTile1DPartitioner
<
CodegenGemmShape
>
;
using
TilePartitioner
=
ck_tile
::
GemmTile1DPartitioner
<
CodegenGemmShape
>
;
template
<
typename
CLayout
>
using
GemmEpilogue
=
std
::
conditional_t
<
std
::
is_same_v
<
CLayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
,
ck_tile
::
CShuffleEpilogue
<
ck_tile
::
CShuffleEpilogueProblem
<
AccDataType
,
CDataType
,
GroupedGemmKernelParam
::
kPadM
,
GroupedGemmKernelParam
::
kPadN
,
GroupedGemmKernelParam
::
kTilePermute
,
GroupedGemmKernelParam
::
kOutputRank
,
1
,
0
,
TilePartitioner
::
MPerBlock
,
TilePartitioner
::
NPerBlock
>>
,
ck_tile
::
Default2DEpilogue
<
ck_tile
::
Default2DEpilogueProblem
<
AccDataType
,
CDataType
,
GroupedGemmKernelParam
::
kPadM
,
GroupedGemmKernelParam
::
kPadN
>>>
;
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
using
CodegenGemmTraits
=
ck_tile
::
TileGemmTraits
<
GroupedGemmKernelParam
::
kPadM
,
using
CodegenGemmTraits
=
ck_tile
::
TileGemmTraits
<
GroupedGemmKernelParam
::
kPadM
,
GroupedGemmKernelParam
::
kPadN
,
GroupedGemmKernelParam
::
kPadN
,
...
@@ -92,10 +71,25 @@ template <typename ALayout, typename BLayout, typename CLayout>
...
@@ -92,10 +71,25 @@ template <typename ALayout, typename BLayout, typename CLayout>
using
CodegenGemmPipeline
=
using
CodegenGemmPipeline
=
ck_tile
::
GemmPipelineAGmemBGmemCRegV1
<
CodegenPipelineProblem
<
ALayout
,
BLayout
,
CLayout
>>
;
ck_tile
::
GemmPipelineAGmemBGmemCRegV1
<
CodegenPipelineProblem
<
ALayout
,
BLayout
,
CLayout
>>
;
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
using
GemmEpilogue
=
ck_tile
::
CShuffleEpilogue
<
ck_tile
::
CShuffleEpilogueProblem
<
AccDataType
,
CDataType
,
CLayout
,
CodegenPipelineProblem
<
ALayout
,
BLayout
,
CLayout
>::
kBlockSize
,
TilePartitioner
::
MPerBlock
,
TilePartitioner
::
NPerBlock
,
GroupedGemmKernelParam
::
M_Warp
,
GroupedGemmKernelParam
::
N_Warp
,
GroupedGemmKernelParam
::
M_Warp_Tile
,
GroupedGemmKernelParam
::
N_Warp_Tile
,
GroupedGemmKernelParam
::
K_Warp_Tile
,
CodegenPipelineProblem
<
ALayout
,
BLayout
,
CLayout
>::
TransposeC
>>
;
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
using
Kernel
=
ck_tile
::
GroupedGemmKernel
<
TilePartitioner
,
using
Kernel
=
ck_tile
::
GroupedGemmKernel
<
TilePartitioner
,
CodegenGemmPipeline
<
ALayout
,
BLayout
,
CLayout
>
,
CodegenGemmPipeline
<
ALayout
,
BLayout
,
CLayout
>
,
GemmEpilogue
<
CLayout
>>
;
GemmEpilogue
<
ALayout
,
BLayout
,
CLayout
>>
;
};
// namespace
};
// namespace
std
::
size_t
get_workspace_size
(
const
std
::
vector
<
grouped_gemm_kargs
>&
gemm_descs
)
std
::
size_t
get_workspace_size
(
const
std
::
vector
<
grouped_gemm_kargs
>&
gemm_descs
)
...
...
include/ck/ck.hpp
View file @
d480a5a6
...
@@ -5,7 +5,7 @@
...
@@ -5,7 +5,7 @@
#include "ck/config.h"
#include "ck/config.h"
#include "ck/utility/env.hpp"
#include "ck/utility/env.hpp"
#ifndef CK_CODE_GEN_RTC
#ifndef CK_DONT_USE_HIP_RUNTIME_HEADERS
#ifndef CK_DONT_USE_HIP_RUNTIME_HEADERS
#include "hip/hip_runtime.h"
#include "hip/hip_runtime.h"
#include "hip/hip_fp16.h"
#include "hip/hip_fp16.h"
...
@@ -14,7 +14,7 @@
...
@@ -14,7 +14,7 @@
// environment variable to enable logging:
// environment variable to enable logging:
// export CK_LOGGING=ON or CK_LOGGING=1 or CK_LOGGING=ENABLED
// export CK_LOGGING=ON or CK_LOGGING=1 or CK_LOGGING=ENABLED
CK_DECLARE_ENV_VAR_BOOL
(
CK_LOGGING
)
CK_DECLARE_ENV_VAR_BOOL
(
CK_LOGGING
)
#endif
// to do: add various levels of logging with CK_LOG_LEVEL
// to do: add various levels of logging with CK_LOG_LEVEL
#ifndef CK_TIME_KERNEL
#ifndef CK_TIME_KERNEL
...
...
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp
View file @
d480a5a6
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -131,7 +131,7 @@ struct ThreadGroupTensorSliceTransfer_v7r2
...
@@ -131,7 +131,7 @@ struct ThreadGroupTensorSliceTransfer_v7r2
}
}
template
<
typename
T
>
template
<
typename
T
>
using
is_tuple
=
decltype
(
std
::
declval
<
T
&>
().
IsTuple
());
using
is_tuple
=
decltype
(
ck
::
declval
<
T
&>
().
IsTuple
());
template
<
typename
DstBuffers
,
index_t
ThreadScratchId
=
0
>
template
<
typename
DstBuffers
,
index_t
ThreadScratchId
=
0
>
__device__
void
RunWrite
(
const
DstDescs
&
dst_descs
,
__device__
void
RunWrite
(
const
DstDescs
&
dst_descs
,
...
...
include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp
View file @
d480a5a6
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#ifndef CK_CODE_GEN_RTC
#include <string>
#include <string>
#endif
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
...
@@ -18,6 +20,7 @@ enum struct ConvolutionForwardSpecialization
...
@@ -18,6 +20,7 @@ enum struct ConvolutionForwardSpecialization
Filter3x3
,
Filter3x3
,
};
};
#ifndef CK_CODE_GEN_RTC
inline
std
::
string
getConvForwardSpecializationString
(
const
ConvolutionForwardSpecialization
&
s
)
inline
std
::
string
getConvForwardSpecializationString
(
const
ConvolutionForwardSpecialization
&
s
)
{
{
switch
(
s
)
switch
(
s
)
...
@@ -30,6 +33,7 @@ inline std::string getConvForwardSpecializationString(const ConvolutionForwardSp
...
@@ -30,6 +33,7 @@ inline std::string getConvForwardSpecializationString(const ConvolutionForwardSp
default:
return
"Unrecognized specialization!"
;
default:
return
"Unrecognized specialization!"
;
}
}
}
}
#endif
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
...
...
include/ck/tensor_operation/gpu/device/device_base.hpp
View file @
d480a5a6
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#ifndef CK_CODE_GEN_RTC
#include <string>
#include <string>
#include <sstream>
#include <sstream>
#include <regex>
#include <regex>
#include <optional>
#include <optional>
#include "ck/stream_config.hpp"
#include "ck/stream_config.hpp"
#endif
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
#ifndef CK_CODE_GEN_RTC
#define GET_OBJECT_NAME_IMLP \
#define GET_OBJECT_NAME_IMLP \
std::optional<std::string> GetObjectName() const override \
std::optional<std::string> GetObjectName() const override \
{ \
{ \
...
@@ -41,7 +43,9 @@ namespace device {
...
@@ -41,7 +43,9 @@ namespace device {
}
}
#define REGISTER_EXTRA_PRINTING_METHODS GET_OBJECT_NAME_IMLP GET_TEMPLATE_INFO_IMPL
#define REGISTER_EXTRA_PRINTING_METHODS GET_OBJECT_NAME_IMLP GET_TEMPLATE_INFO_IMPL
#endif
#ifndef CK_CODE_GEN_RTC
struct
BaseArgument
struct
BaseArgument
{
{
BaseArgument
()
=
default
;
BaseArgument
()
=
default
;
...
@@ -66,13 +70,14 @@ struct BaseInvoker
...
@@ -66,13 +70,14 @@ struct BaseInvoker
virtual
~
BaseInvoker
()
{}
virtual
~
BaseInvoker
()
{}
};
};
#endif
struct
BaseOperator
struct
BaseOperator
{
{
BaseOperator
()
=
default
;
BaseOperator
()
=
default
;
BaseOperator
(
const
BaseOperator
&
)
=
default
;
BaseOperator
(
const
BaseOperator
&
)
=
default
;
BaseOperator
&
operator
=
(
const
BaseOperator
&
)
=
default
;
BaseOperator
&
operator
=
(
const
BaseOperator
&
)
=
default
;
#ifndef CK_CODE_GEN_RTC
virtual
bool
IsSupportedArgument
(
const
BaseArgument
*
)
{
return
false
;
}
virtual
bool
IsSupportedArgument
(
const
BaseArgument
*
)
{
return
false
;
}
virtual
std
::
string
GetTypeString
()
const
{
return
""
;
}
virtual
std
::
string
GetTypeString
()
const
{
return
""
;
}
...
@@ -100,7 +105,7 @@ struct BaseOperator
...
@@ -100,7 +105,7 @@ struct BaseOperator
assert
(
p_arg
);
assert
(
p_arg
);
p_arg
->
p_workspace_
=
p_workspace
;
p_arg
->
p_workspace_
=
p_workspace
;
}
}
#endif
virtual
~
BaseOperator
()
{}
virtual
~
BaseOperator
()
{}
};
};
...
...
include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp
View file @
d480a5a6
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#ifndef CK_CODE_GEN_RTC
#include <array>
#include <array>
#endif
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
...
@@ -13,8 +15,13 @@ namespace ck {
...
@@ -13,8 +15,13 @@ namespace ck {
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
#ifdef CK_CODE_GEN_RTC
template
<
typename
T
>
using
is_tuple
=
decltype
(
ck
::
declval
<
T
&>
().
IsTuple
());
#else
template
<
typename
T
>
template
<
typename
T
>
using
is_tuple
=
decltype
(
std
::
declval
<
T
&>
().
IsTuple
());
using
is_tuple
=
decltype
(
std
::
declval
<
T
&>
().
IsTuple
());
#endif
/**
/**
* \brief Grouped Convolution Forward
* \brief Grouped Convolution Forward
...
@@ -72,12 +79,18 @@ struct DeviceGroupedConvFwdMultipleABD : public BaseOperator
...
@@ -72,12 +79,18 @@ struct DeviceGroupedConvFwdMultipleABD : public BaseOperator
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
static_assert
(
NumDTensor
==
DsLayout
::
Size
(),
"wrong! Inconsistent NumDTensor"
);
static_assert
(
NumDTensor
==
DsLayout
::
Size
(),
"wrong! Inconsistent NumDTensor"
);
#ifdef CK_CODE_GEN_RTC
using
APointers
=
ck
::
conditional_t
<
isMultiA
,
ck
::
Array
<
const
void
*
,
NumATensor
>&
,
const
void
*>
;
using
BPointers
=
ck
::
conditional_t
<
isMultiB
,
ck
::
Array
<
const
void
*
,
NumBTensor
>&
,
const
void
*>
;
#else
// If DataType is tuple, user has to pass std::array with pointers.
// If DataType is tuple, user has to pass std::array with pointers.
using
APointers
=
using
APointers
=
std
::
conditional_t
<
isMultiA
,
std
::
array
<
const
void
*
,
NumATensor
>&
,
const
void
*>
;
ck
::
conditional_t
<
isMultiA
,
std
::
array
<
const
void
*
,
NumATensor
>&
,
const
void
*>
;
using
BPointers
=
using
BPointers
=
std
::
conditional_t
<
isMultiB
,
std
::
array
<
const
void
*
,
NumBTensor
>&
,
const
void
*>
;
ck
::
conditional_t
<
isMultiB
,
std
::
array
<
const
void
*
,
NumBTensor
>&
,
const
void
*>
;
#endif
#ifndef CK_CODE_GEN_RTC
/**
/**
* \brief Make argument pointer for grouped conv fwd.
* \brief Make argument pointer for grouped conv fwd.
...
@@ -150,6 +163,7 @@ struct DeviceGroupedConvFwdMultipleABD : public BaseOperator
...
@@ -150,6 +163,7 @@ struct DeviceGroupedConvFwdMultipleABD : public BaseOperator
const
CDEElementwiseOperation
&
cde_element_op
)
=
0
;
const
CDEElementwiseOperation
&
cde_element_op
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
#endif
};
};
}
// namespace device
}
// namespace device
...
...
include/ck/tensor_operation/gpu/device/gemm_specialization.hpp
View file @
d480a5a6
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -29,6 +29,7 @@ enum struct GemmSpecialization
...
@@ -29,6 +29,7 @@ enum struct GemmSpecialization
MNKOPadding
,
MNKOPadding
,
};
};
#ifndef CK_CODE_GEN_RTC
inline
std
::
string
getGemmSpecializationString
(
const
GemmSpecialization
&
s
)
inline
std
::
string
getGemmSpecializationString
(
const
GemmSpecialization
&
s
)
{
{
switch
(
s
)
switch
(
s
)
...
@@ -52,6 +53,7 @@ inline std::string getGemmSpecializationString(const GemmSpecialization& s)
...
@@ -52,6 +53,7 @@ inline std::string getGemmSpecializationString(const GemmSpecialization& s)
default:
return
"Unrecognized specialization!"
;
default:
return
"Unrecognized specialization!"
;
}
}
}
}
#endif
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
...
...
include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
View file @
d480a5a6
...
@@ -3,11 +3,17 @@
...
@@ -3,11 +3,17 @@
#pragma once
#pragma once
#ifndef CK_CODE_GEN_RTC
#include <functional>
#include <functional>
#include <iostream>
#include <iostream>
#include <iterator>
#include <iterator>
#include <numeric>
#include <numeric>
#include <sstream>
#include <sstream>
#include <stdio.h>
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#endif
#include "ck/utility/common_header.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
...
@@ -15,15 +21,12 @@
...
@@ -15,15 +21,12 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/io.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
...
@@ -259,8 +262,13 @@ __global__ void
...
@@ -259,8 +262,13 @@ __global__ void
}
// namespace
}
// namespace
#ifdef CK_CODE_GEN_RTC
template
<
typename
T
>
using
is_tuple
=
decltype
(
ck
::
declval
<
T
&>
().
IsTuple
());
#else
template
<
typename
T
>
template
<
typename
T
>
using
is_tuple
=
decltype
(
std
::
declval
<
T
&>
().
IsTuple
());
using
is_tuple
=
decltype
(
std
::
declval
<
T
&>
().
IsTuple
());
#endif
//
//
// @brief Device Convolution operation.
// @brief Device Convolution operation.
...
@@ -429,8 +437,8 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
...
@@ -429,8 +437,8 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// If we are using multiAB and one of the template datatype parameters is not a tuple, convert
// If we are using multiAB and one of the template datatype parameters is not a tuple, convert
// it to it
// it to it
using
GemmADataType
=
std
::
conditional_t
<!
isMultiA
&&
isMultiB
,
Tuple
<
ADataType
>
,
ADataType
>
;
using
GemmADataType
=
ck
::
conditional_t
<!
isMultiA
&&
isMultiB
,
Tuple
<
ADataType
>
,
ADataType
>
;
using
GemmBDataType
=
std
::
conditional_t
<!
isMultiB
&&
isMultiA
,
Tuple
<
BDataType
>
,
BDataType
>
;
using
GemmBDataType
=
ck
::
conditional_t
<!
isMultiB
&&
isMultiA
,
Tuple
<
BDataType
>
,
BDataType
>
;
#define GridwiseGemmTemplateParameters \
#define GridwiseGemmTemplateParameters \
GemmADataType, GemmBDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
GemmADataType, GemmBDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
...
@@ -449,15 +457,13 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
...
@@ -449,15 +457,13 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched
// Use appropriate gridwise gemm
// Use appropriate gridwise gemm
using
GridwiseGemm
=
using
GridwiseGemm
=
std
::
conditional_t
<
isMultiA
||
isMultiB
,
ck
::
conditional_t
<
isMultiA
||
isMultiB
,
GridwiseGemmMultipleABD_xdl_cshuffle
<
GridwiseGemmTemplateParameters
>
,
GridwiseGemmMultipleABD_xdl_cshuffle
<
GridwiseGemmTemplateParameters
>
,
GridwiseGemmMultipleD_xdl_cshuffle
<
GridwiseGemmTemplateParameters
>>
;
GridwiseGemmMultipleD_xdl_cshuffle
<
GridwiseGemmTemplateParameters
>>
;
// If ADataTypes or BDataTypes is tuple, user has to pass ck::Array with pointers.
// If ADataTypes or BDataTypes is tuple, user has to pass ck::Array with pointers.
using
APointers
=
using
APointers
=
ck
::
conditional_t
<
isMultiA
,
ck
::
Array
<
const
void
*
,
NumATensor
>&
,
const
void
*>
;
std
::
conditional_t
<
isMultiA
,
ck
::
Array
<
const
void
*
,
NumATensor
>&
,
const
void
*>
;
using
BPointers
=
ck
::
conditional_t
<
isMultiB
,
ck
::
Array
<
const
void
*
,
NumBTensor
>&
,
const
void
*>
;
using
BPointers
=
std
::
conditional_t
<
isMultiB
,
ck
::
Array
<
const
void
*
,
NumBTensor
>&
,
const
void
*>
;
// Use Tuple for the both cases for GridPointer to initialize it in Argument constructor (not
// Use Tuple for the both cases for GridPointer to initialize it in Argument constructor (not
// in initializer list what is required for single const pointer).
// in initializer list what is required for single const pointer).
using
AGridPointer
=
remove_cvref_t
<
using
AGridPointer
=
remove_cvref_t
<
...
@@ -812,7 +818,6 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
...
@@ -812,7 +818,6 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
// FIXME: layout
// FIXME: layout
if
constexpr
(
is_same_v
<
DLayout
,
ctc
::
G_NW_K
>
||
is_same_v
<
DLayout
,
ctc
::
G_NHW_K
>
||
if
constexpr
(
is_same_v
<
DLayout
,
ctc
::
G_NW_K
>
||
is_same_v
<
DLayout
,
ctc
::
G_NHW_K
>
||
is_same_v
<
DLayout
,
ctc
::
G_NDHW_K
>
||
is_same_v
<
DLayout
,
ctc
::
GNWK
>
||
is_same_v
<
DLayout
,
ctc
::
G_NDHW_K
>
||
is_same_v
<
DLayout
,
ctc
::
GNWK
>
||
...
@@ -965,18 +970,18 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
...
@@ -965,18 +970,18 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
const
BElementwiseOperation
&
b_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CDEElementwiseOperation
&
cde_element_op
)
const
CDEElementwiseOperation
&
cde_element_op
)
{
{
std
::
a
rray
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_lengths_i32
;
ck
::
A
rray
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_lengths_i32
;
std
::
a
rray
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_strides_i32
;
ck
::
A
rray
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_strides_i32
;
std
::
a
rray
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_lengths_i32
;
ck
::
A
rray
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_lengths_i32
;
std
::
a
rray
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_strides_i32
;
ck
::
A
rray
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_strides_i32
;
std
::
a
rray
<
std
::
a
rray
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_k_wos_lengths_i32
;
ck
::
A
rray
<
ck
::
A
rray
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_k_wos_lengths_i32
;
std
::
a
rray
<
std
::
a
rray
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_k_wos_strides_i32
;
ck
::
A
rray
<
ck
::
A
rray
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_k_wos_strides_i32
;
std
::
a
rray
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_lengths_i32
;
ck
::
A
rray
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_lengths_i32
;
std
::
a
rray
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_strides_i32
;
ck
::
A
rray
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_strides_i32
;
std
::
a
rray
<
index_t
,
NDimSpatial
>
conv_filter_strides_i32
;
ck
::
A
rray
<
index_t
,
NDimSpatial
>
conv_filter_strides_i32
;
std
::
a
rray
<
index_t
,
NDimSpatial
>
conv_filter_dilations_i32
;
ck
::
A
rray
<
index_t
,
NDimSpatial
>
conv_filter_dilations_i32
;
std
::
a
rray
<
index_t
,
NDimSpatial
>
input_left_pads_i32
;
ck
::
A
rray
<
index_t
,
NDimSpatial
>
input_left_pads_i32
;
std
::
a
rray
<
index_t
,
NDimSpatial
>
input_right_pads_i32
;
ck
::
A
rray
<
index_t
,
NDimSpatial
>
input_right_pads_i32
;
array_convert
(
a_g_n_c_wis_lengths_i32
,
a_g_n_c_wis_lengths
);
array_convert
(
a_g_n_c_wis_lengths_i32
,
a_g_n_c_wis_lengths
);
array_convert
(
a_g_n_c_wis_strides_i32
,
a_g_n_c_wis_strides
);
array_convert
(
a_g_n_c_wis_strides_i32
,
a_g_n_c_wis_strides
);
...
...
include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp
View file @
d480a5a6
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
#pragma once
#pragma once
#include "ck/library/utility/numeric.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_tensor_rearrange.hpp"
#include "ck/tensor_operation/gpu/device/device_conv_tensor_rearrange.hpp"
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp
View file @
d480a5a6
...
@@ -205,8 +205,8 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
...
@@ -205,8 +205,8 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
const
auto
b2c_map
=
DefaultBlock2CTileMap
{};
const
auto
b2c_map
=
DefaultBlock2CTileMap
{};
index_t
gdx
,
gdy
,
gdz
;
index_t
gdx
,
gdy
,
gdz
;
std
::
tie
(
gdx
,
gdy
,
gdz
)
=
b2c_map
.
CalculateGridSize
(
karg
.
M
,
karg
.
N
,
karg
.
k_batch
);
ck
::
tie
(
gdx
,
gdy
,
gdz
)
=
b2c_map
.
CalculateGridSize
(
karg
.
M
,
karg
.
N
,
karg
.
k_batch
);
const
auto
K0Padded
=
karg
.
K0Padded
;
const
auto
K0Padded
=
karg
.
K0Padded
;
const
bool
has_main_k0_block_loop
=
GridwiseGemm
::
CalculateHasMainK0BlockLoop
(
K0Padded
);
const
bool
has_main_k0_block_loop
=
GridwiseGemm
::
CalculateHasMainK0BlockLoop
(
K0Padded
);
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle_lds_direct_load.hpp
View file @
d480a5a6
...
@@ -183,8 +183,8 @@ struct DeviceGemmXdlSplitKCShuffle_LdsDirectLoad : public DeviceGemmSplitK<ALayo
...
@@ -183,8 +183,8 @@ struct DeviceGemmXdlSplitKCShuffle_LdsDirectLoad : public DeviceGemmSplitK<ALayo
const
auto
b2c_map
=
DefaultBlock2CTileMap
{};
const
auto
b2c_map
=
DefaultBlock2CTileMap
{};
index_t
gdx
,
gdy
,
gdz
;
index_t
gdx
,
gdy
,
gdz
;
std
::
tie
(
gdx
,
gdy
,
gdz
)
=
b2c_map
.
CalculateGridSize
(
karg
.
M
,
karg
.
N
,
karg
.
k_batch
);
ck
::
tie
(
gdx
,
gdy
,
gdz
)
=
b2c_map
.
CalculateGridSize
(
karg
.
M
,
karg
.
N
,
karg
.
k_batch
);
const
auto
K0Padded
=
karg
.
K0Padded
;
const
auto
K0Padded
=
karg
.
K0Padded
;
const
bool
has_main_k0_block_loop
=
GridwiseGemm
::
CalculateHasMainK0BlockLoop
(
K0Padded
);
const
bool
has_main_k0_block_loop
=
GridwiseGemm
::
CalculateHasMainK0BlockLoop
(
K0Padded
);
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
View file @
d480a5a6
...
@@ -9,6 +9,7 @@
...
@@ -9,6 +9,7 @@
#include <numeric>
#include <numeric>
#include <sstream>
#include <sstream>
#include "ck/library/utility/numeric.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
...
@@ -212,9 +213,13 @@ __global__ void
...
@@ -212,9 +213,13 @@ __global__ void
}
}
}
// namespace
}
// namespace
#ifdef CK_CODE_GEN_RTC
template
<
typename
T
>
using
is_tuple
=
decltype
(
ck
::
declval
<
T
&>
().
IsTuple
());
#else
template
<
typename
T
>
template
<
typename
T
>
using
is_tuple
=
decltype
(
std
::
declval
<
T
&>
().
IsTuple
());
using
is_tuple
=
decltype
(
std
::
declval
<
T
&>
().
IsTuple
());
#endif
//
//
// @brief Device Convolution operation.
// @brief Device Convolution operation.
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment