Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
8820cf9f
Commit
8820cf9f
authored
May 04, 2023
by
Po-Yen, Chen
Browse files
Merge branch 'develop' into feature/integrage-karg-simplification-pr
parents
cb46ef7a
4feebedd
Changes
157
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
325 additions
and
119 deletions
+325
-119
docs/.sphinx/requirements.txt
docs/.sphinx/requirements.txt
+13
-17
example/15_grouped_gemm/CMakeLists.txt
example/15_grouped_gemm/CMakeLists.txt
+3
-1
example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp
example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp
+97
-0
example/15_grouped_gemm/run_grouped_gemm_example.inc
example/15_grouped_gemm/run_grouped_gemm_example.inc
+1
-0
example/31_batched_gemm_gemm/CMakeLists.txt
example/31_batched_gemm_gemm/CMakeLists.txt
+3
-1
example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perchannel_quantization_example.inc
...n/run_conv2d_fwd_bias_perchannel_quantization_example.inc
+3
-3
example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perlayer_quantization_example.inc
...ion/run_conv2d_fwd_bias_perlayer_quantization_example.inc
+3
-3
example/40_conv2d_fwd_quantization/run_conv2d_fwd_perchannel_quantization_example.inc
...zation/run_conv2d_fwd_perchannel_quantization_example.inc
+3
-3
example/40_conv2d_fwd_quantization/run_conv2d_fwd_perlayer_quantization_example.inc
...tization/run_conv2d_fwd_perlayer_quantization_example.inc
+3
-3
example/41_grouped_conv_conv_fwd/CMakeLists.txt
example/41_grouped_conv_conv_fwd/CMakeLists.txt
+3
-2
example/42_groupnorm/CMakeLists.txt
example/42_groupnorm/CMakeLists.txt
+2
-1
example/42_groupnorm/common.hpp
example/42_groupnorm/common.hpp
+23
-0
example/42_groupnorm/groupnorm_sigmoid_mul_fp16.cpp
example/42_groupnorm/groupnorm_sigmoid_mul_fp16.cpp
+56
-0
example/42_groupnorm/groupnorm_swish_fp16.cpp
example/42_groupnorm/groupnorm_swish_fp16.cpp
+40
-0
example/42_groupnorm/run_groupnorm_example.inc
example/42_groupnorm/run_groupnorm_example.inc
+7
-72
include/ck/ck.hpp
include/ck/ck.hpp
+17
-8
include/ck/tensor_operation/gpu/device/device_gemm_xdl_waveletmodel_cshuffle.hpp
...tion/gpu/device/device_gemm_xdl_waveletmodel_cshuffle.hpp
+4
-2
include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp
...de/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp
+1
-1
include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
...device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
+4
-2
include/ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp
...ensor_operation/gpu/device/device_grouped_gemm_splitk.hpp
+39
-0
No files found.
docs/.sphinx/requirements.txt
View file @
8820cf9f
...
@@ -2,9 +2,9 @@
...
@@ -2,9 +2,9 @@
# This file is autogenerated by pip-compile with Python 3.10
# This file is autogenerated by pip-compile with Python 3.10
# by the following command:
# by the following command:
#
#
# pip-compile requirements.in
# pip-compile
.sphinx/
requirements.in
#
#
accessible-pygments==0.0.
4
accessible-pygments==0.0.
3
# via pydata-sphinx-theme
# via pydata-sphinx-theme
alabaster==0.7.13
alabaster==0.7.13
# via sphinx
# via sphinx
...
@@ -20,7 +20,7 @@ babel==2.12.1
...
@@ -20,7 +20,7 @@ babel==2.12.1
# sphinx
# sphinx
backcall==0.2.0
backcall==0.2.0
# via ipython
# via ipython
beautifulsoup4==4.1
2.0
beautifulsoup4==4.1
1.2
# via pydata-sphinx-theme
# via pydata-sphinx-theme
breathe==4.34.0
breathe==4.34.0
# via rocm-docs-core
# via rocm-docs-core
...
@@ -34,7 +34,7 @@ click==8.1.3
...
@@ -34,7 +34,7 @@ click==8.1.3
# via
# via
# jupyter-cache
# jupyter-cache
# sphinx-external-toc
# sphinx-external-toc
comm==0.1.
3
comm==0.1.
2
# via ipykernel
# via ipykernel
debugpy==1.6.6
debugpy==1.6.6
# via ipykernel
# via ipykernel
...
@@ -65,13 +65,11 @@ idna==3.4
...
@@ -65,13 +65,11 @@ idna==3.4
# via requests
# via requests
imagesize==1.4.1
imagesize==1.4.1
# via sphinx
# via sphinx
importlib-metadata==6.
1
.0
importlib-metadata==6.
0
.0
# via
# via
# jupyter-cache
# jupyter-cache
# myst-nb
# myst-nb
importlib-resources==5.10.4
ipykernel==6.21.3
# via rocm-docs-core
ipykernel==6.22.0
# via myst-nb
# via myst-nb
ipython==8.11.0
ipython==8.11.0
# via
# via
...
@@ -87,7 +85,7 @@ jsonschema==4.17.3
...
@@ -87,7 +85,7 @@ jsonschema==4.17.3
# via nbformat
# via nbformat
jupyter-cache==0.5.0
jupyter-cache==0.5.0
# via myst-nb
# via myst-nb
jupyter-client==8.
1.0
jupyter-client==8.
0.3
# via
# via
# ipykernel
# ipykernel
# nbclient
# nbclient
...
@@ -124,7 +122,7 @@ nbclient==0.5.13
...
@@ -124,7 +122,7 @@ nbclient==0.5.13
# via
# via
# jupyter-cache
# jupyter-cache
# myst-nb
# myst-nb
nbformat==5.
8.0
nbformat==5.
7.3
# via
# via
# jupyter-cache
# jupyter-cache
# myst-nb
# myst-nb
...
@@ -187,7 +185,7 @@ pyyaml==6.0
...
@@ -187,7 +185,7 @@ pyyaml==6.0
# myst-parser
# myst-parser
# pybtex
# pybtex
# sphinx-external-toc
# sphinx-external-toc
pyzmq==25.0.
2
pyzmq==25.0.
1
# via
# via
# ipykernel
# ipykernel
# jupyter-client
# jupyter-client
...
@@ -195,8 +193,8 @@ requests==2.28.2
...
@@ -195,8 +193,8 @@ requests==2.28.2
# via
# via
# pygithub
# pygithub
# sphinx
# sphinx
rocm-docs-core
@ git+https://github.com/RadeonOpenCompute/rocm-docs-core.git
rocm-docs-core
==0.2.0
# via -r requirements.in
# via -r
.sphinx/
requirements.in
six==1.16.0
six==1.16.0
# via
# via
# asttokens
# asttokens
...
@@ -235,9 +233,7 @@ sphinx-notfound-page==0.8.3
...
@@ -235,9 +233,7 @@ sphinx-notfound-page==0.8.3
sphinxcontrib-applehelp==1.0.4
sphinxcontrib-applehelp==1.0.4
# via sphinx
# via sphinx
sphinxcontrib-bibtex==2.5.0
sphinxcontrib-bibtex==2.5.0
# via
# via -r .sphinx/requirements.in
# -r requirements.in
# rocm-docs-core
sphinxcontrib-devhelp==1.0.2
sphinxcontrib-devhelp==1.0.2
# via sphinx
# via sphinx
sphinxcontrib-htmlhelp==2.0.1
sphinxcontrib-htmlhelp==2.0.1
...
@@ -248,7 +244,7 @@ sphinxcontrib-qthelp==1.0.3
...
@@ -248,7 +244,7 @@ sphinxcontrib-qthelp==1.0.3
# via sphinx
# via sphinx
sphinxcontrib-serializinghtml==1.1.5
sphinxcontrib-serializinghtml==1.1.5
# via sphinx
# via sphinx
sqlalchemy==1.4.4
7
sqlalchemy==1.4.4
6
# via jupyter-cache
# via jupyter-cache
stack-data==0.6.2
stack-data==0.6.2
# via ipython
# via ipython
...
...
example/15_grouped_gemm/CMakeLists.txt
View file @
8820cf9f
...
@@ -5,6 +5,7 @@ add_example_executable(example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp)
...
@@ -5,6 +5,7 @@ add_example_executable(example_grouped_gemm_xdl_fp16 grouped_gemm_xdl_fp16.cpp)
add_example_executable
(
example_grouped_gemm_xdl_bfp16 grouped_gemm_xdl_bfp16.cpp
)
add_example_executable
(
example_grouped_gemm_xdl_bfp16 grouped_gemm_xdl_bfp16.cpp
)
add_example_executable
(
example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp
)
add_example_executable
(
example_grouped_gemm_xdl_int8 grouped_gemm_xdl_int8.cpp
)
add_example_executable
(
example_grouped_gemm_multiple_d_dl_fp16 grouped_gemm_multiple_d_dl_fp16.cpp
)
add_example_executable
(
example_grouped_gemm_multiple_d_dl_fp16 grouped_gemm_multiple_d_dl_fp16.cpp
)
add_example_executable
(
example_grouped_gemm_xdl_splitk_fp16 grouped_gemm_xdl_splitk_fp16.cpp
)
add_dependencies
(
example_grouped_gemm_xdl
add_dependencies
(
example_grouped_gemm_xdl
...
@@ -12,7 +13,8 @@ add_dependencies(example_grouped_gemm_xdl
...
@@ -12,7 +13,8 @@ add_dependencies(example_grouped_gemm_xdl
example_grouped_gemm_xdl_fp16
example_grouped_gemm_xdl_fp16
example_grouped_gemm_xdl_bfp16
example_grouped_gemm_xdl_bfp16
example_grouped_gemm_xdl_int8
example_grouped_gemm_xdl_int8
example_grouped_gemm_multiple_d_dl_fp16
)
example_grouped_gemm_multiple_d_dl_fp16
example_grouped_gemm_xdl_splitk_fp16
)
if
(
USE_BITINT_EXTENSION_INT4
)
if
(
USE_BITINT_EXTENSION_INT4
)
add_example_executable
(
example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp
)
add_example_executable
(
example_grouped_gemm_xdl_int4 grouped_gemm_xdl_int4.cpp
)
...
...
example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp
0 → 100644
View file @
8820cf9f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ADataType
=
F16
;
using
BDataType
=
F16
;
using
AccDataType
=
F32
;
using
CShuffleDataType
=
F16
;
using
DsDataType
=
ck
::
Tuple
<>
;
using
EDataType
=
F16
;
using
ALayout
=
Row
;
using
BLayout
=
Col
;
using
DsLayout
=
ck
::
Tuple
<>
;
using
ELayout
=
Row
;
using
AElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
using
CDEElementOp
=
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGroupedGemmXdlSplitKCShuffle
// clang-format off
//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
ADataType
,
BDataType
,
AccDataType
,
CShuffleDataType
,
DsDataType
,
EDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
GemmDefault
,
1
,
256
,
256
,
128
,
32
,
8
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
;
// clang-format on
#include "run_grouped_gemm_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
ProblemSize
problem_size
;
ExecutionConfig
config
;
problem_size
.
group_count
=
16
;
problem_size
.
Ms
=
{
167
,
183
,
177
,
181
,
153
,
139
,
156
,
173
,
163
,
150
,
204
,
184
,
168
,
156
,
168
,
148
};
for
(
int
i
=
0
;
i
<
problem_size
.
group_count
;
i
++
)
{
problem_size
.
Ns
.
push_back
(
768
);
problem_size
.
Ks
.
push_back
(
4608
);
problem_size
.
stride_As
.
push_back
(
problem_size
.
Ks
[
i
]);
problem_size
.
stride_Bs
.
push_back
(
problem_size
.
Ks
[
i
]);
problem_size
.
stride_Cs
.
push_back
(
problem_size
.
Ns
[
i
]);
}
if
(
argc
==
4
)
{
config
.
do_verification
=
std
::
stoi
(
argv
[
1
]);
config
.
init_method
=
std
::
stoi
(
argv
[
2
]);
config
.
time_kernel
=
std
::
stoi
(
argv
[
3
]);
}
else
{
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: time kernel (0=n0, 1=yes)
\n
"
);
exit
(
0
);
}
return
!
run_grouped_gemm
(
problem_size
,
config
);
}
example/15_grouped_gemm/run_grouped_gemm_example.inc
View file @
8820cf9f
...
@@ -147,6 +147,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
...
@@ -147,6 +147,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
#else
#else
a_tensors_device
[
i
]
->
ToDevice
(
a_tensors
[
i
]
.
mData
.
data
());
a_tensors_device
[
i
]
->
ToDevice
(
a_tensors
[
i
]
.
mData
.
data
());
b_tensors_device
[
i
]
->
ToDevice
(
b_tensors
[
i
]
.
mData
.
data
());
b_tensors_device
[
i
]
->
ToDevice
(
b_tensors
[
i
]
.
mData
.
data
());
c_tensors_device
[
i
]
->
SetZero
();
#endif
#endif
p_a
.
push_back
(
a_tensors_device
[
i
]
->
GetDeviceBuffer
());
p_a
.
push_back
(
a_tensors_device
[
i
]
->
GetDeviceBuffer
());
...
...
example/31_batched_gemm_gemm/CMakeLists.txt
View file @
8820cf9f
add_example_executable
(
example_batched_gemm_gemm_xdl_fp32 batched_gemm_gemm_xdl_fp32.cpp
)
add_example_executable
(
example_batched_gemm_gemm_xdl_fp32 batched_gemm_gemm_xdl_fp32.cpp
)
add_example_executable
(
example_batched_gemm_gemm_xdl_fp16 batched_gemm_gemm_xdl_fp16.cpp
)
add_example_executable
(
example_batched_gemm_gemm_xdl_fp16 batched_gemm_gemm_xdl_fp16.cpp
)
add_example_executable
(
example_batched_gemm_gemm_xdl_bf16 batched_gemm_gemm_xdl_bf16.cpp
)
add_example_executable
(
example_batched_gemm_gemm_xdl_bf16 batched_gemm_gemm_xdl_bf16.cpp
)
add_example_executable
(
example_batched_gemm_gemm_xdl_int8 batched_gemm_gemm_xdl_int8.cpp
)
if
(
NOT GPU_TARGETS MATCHES
"gfx940"
)
add_example_executable
(
example_batched_gemm_gemm_xdl_int8 batched_gemm_gemm_xdl_int8.cpp
)
endif
()
if
(
USE_BITINT_EXTENSION_INT4
)
if
(
USE_BITINT_EXTENSION_INT4
)
add_example_executable
(
example_batched_gemm_gemm_xdl_int4 batched_gemm_gemm_xdl_int4.cpp
)
add_example_executable
(
example_batched_gemm_gemm_xdl_int4 batched_gemm_gemm_xdl_int4.cpp
)
...
...
example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perchannel_quantization_example.inc
View file @
8820cf9f
...
@@ -190,11 +190,11 @@ int run_conv2d_fwd_bias_perchannel_quantization_example(const OutElementOp& out_
...
@@ -190,11 +190,11 @@ int run_conv2d_fwd_bias_perchannel_quantization_example(const OutElementOp& out_
const
auto
in_element_op
=
InElementOp
{};
const
auto
in_element_op
=
InElementOp
{};
const
auto
wei_element_op
=
WeiElementOp
{};
const
auto
wei_element_op
=
WeiElementOp
{};
using
InLayout
=
ck
::
tensor_layout
::
convolution
::
G
NHWC
;
using
InLayout
=
ck
::
tensor_layout
::
convolution
::
NHW
G
C
;
using
WeiLayout
=
ck
::
tensor_layout
::
convolution
::
G
KYXC
;
using
WeiLayout
=
ck
::
tensor_layout
::
convolution
::
KYX
G
C
;
using
BiasLayout
=
ck
::
tensor_layout
::
convolution
::
G_K
;
using
BiasLayout
=
ck
::
tensor_layout
::
convolution
::
G_K
;
using
RequantScaleLayout
=
ck
::
tensor_layout
::
convolution
::
G_K
;
using
RequantScaleLayout
=
ck
::
tensor_layout
::
convolution
::
G_K
;
using
OutLayout
=
ck
::
tensor_layout
::
convolution
::
G
NHWK
;
using
OutLayout
=
ck
::
tensor_layout
::
convolution
::
NHW
G
K
;
const
auto
in_g_n_c_wis_desc
=
const
auto
in_g_n_c_wis_desc
=
ck
::
utils
::
conv
::
make_input_host_tensor_descriptor_g_n_c_wis_packed
<
InLayout
>
(
conv_param
);
ck
::
utils
::
conv
::
make_input_host_tensor_descriptor_g_n_c_wis_packed
<
InLayout
>
(
conv_param
);
...
...
example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perlayer_quantization_example.inc
View file @
8820cf9f
...
@@ -178,10 +178,10 @@ int run_conv2d_fwd_bias_perlayer_quantization_example(const OutElementOp& out_el
...
@@ -178,10 +178,10 @@ int run_conv2d_fwd_bias_perlayer_quantization_example(const OutElementOp& out_el
const
auto
in_element_op
=
InElementOp
{};
const
auto
in_element_op
=
InElementOp
{};
const
auto
wei_element_op
=
WeiElementOp
{};
const
auto
wei_element_op
=
WeiElementOp
{};
using
InLayout
=
ck
::
tensor_layout
::
convolution
::
G
NHWC
;
using
InLayout
=
ck
::
tensor_layout
::
convolution
::
NHW
G
C
;
using
WeiLayout
=
ck
::
tensor_layout
::
convolution
::
G
KYXC
;
using
WeiLayout
=
ck
::
tensor_layout
::
convolution
::
KYX
G
C
;
using
BiasLayout
=
ck
::
tensor_layout
::
convolution
::
G_K
;
using
BiasLayout
=
ck
::
tensor_layout
::
convolution
::
G_K
;
using
OutLayout
=
ck
::
tensor_layout
::
convolution
::
G
NHWK
;
using
OutLayout
=
ck
::
tensor_layout
::
convolution
::
NHW
G
K
;
const
auto
in_g_n_c_wis_desc
=
const
auto
in_g_n_c_wis_desc
=
ck
::
utils
::
conv
::
make_input_host_tensor_descriptor_g_n_c_wis_packed
<
InLayout
>
(
conv_param
);
ck
::
utils
::
conv
::
make_input_host_tensor_descriptor_g_n_c_wis_packed
<
InLayout
>
(
conv_param
);
...
...
example/40_conv2d_fwd_quantization/run_conv2d_fwd_perchannel_quantization_example.inc
View file @
8820cf9f
...
@@ -180,10 +180,10 @@ int run_conv2d_fwd_perchannel_quantization_example(const OutElementOp& out_eleme
...
@@ -180,10 +180,10 @@ int run_conv2d_fwd_perchannel_quantization_example(const OutElementOp& out_eleme
const
auto
in_element_op
=
InElementOp
{};
const
auto
in_element_op
=
InElementOp
{};
const
auto
wei_element_op
=
WeiElementOp
{};
const
auto
wei_element_op
=
WeiElementOp
{};
using
InLayout
=
ck
::
tensor_layout
::
convolution
::
G
NHWC
;
using
InLayout
=
ck
::
tensor_layout
::
convolution
::
NHW
G
C
;
using
WeiLayout
=
ck
::
tensor_layout
::
convolution
::
G
KYXC
;
using
WeiLayout
=
ck
::
tensor_layout
::
convolution
::
KYX
G
C
;
using
RequantScaleLayout
=
ck
::
tensor_layout
::
convolution
::
G_K
;
using
RequantScaleLayout
=
ck
::
tensor_layout
::
convolution
::
G_K
;
using
OutLayout
=
ck
::
tensor_layout
::
convolution
::
G
NHWK
;
using
OutLayout
=
ck
::
tensor_layout
::
convolution
::
NHW
G
K
;
const
auto
in_g_n_c_wis_desc
=
const
auto
in_g_n_c_wis_desc
=
ck
::
utils
::
conv
::
make_input_host_tensor_descriptor_g_n_c_wis_packed
<
InLayout
>
(
conv_param
);
ck
::
utils
::
conv
::
make_input_host_tensor_descriptor_g_n_c_wis_packed
<
InLayout
>
(
conv_param
);
...
...
example/40_conv2d_fwd_quantization/run_conv2d_fwd_perlayer_quantization_example.inc
View file @
8820cf9f
...
@@ -162,9 +162,9 @@ int run_conv2d_fwd_perlayer_quantization_example(const OutElementOp& out_element
...
@@ -162,9 +162,9 @@ int run_conv2d_fwd_perlayer_quantization_example(const OutElementOp& out_element
const
auto
in_element_op
=
InElementOp
{};
const
auto
in_element_op
=
InElementOp
{};
const
auto
wei_element_op
=
WeiElementOp
{};
const
auto
wei_element_op
=
WeiElementOp
{};
using
InLayout
=
ck
::
tensor_layout
::
convolution
::
G
NHWC
;
using
InLayout
=
ck
::
tensor_layout
::
convolution
::
NHW
G
C
;
using
WeiLayout
=
ck
::
tensor_layout
::
convolution
::
G
KYXC
;
using
WeiLayout
=
ck
::
tensor_layout
::
convolution
::
KYX
G
C
;
using
OutLayout
=
ck
::
tensor_layout
::
convolution
::
G
NHWK
;
using
OutLayout
=
ck
::
tensor_layout
::
convolution
::
NHW
G
K
;
const
auto
in_g_n_c_wis_desc
=
const
auto
in_g_n_c_wis_desc
=
ck
::
utils
::
conv
::
make_input_host_tensor_descriptor_g_n_c_wis_packed
<
InLayout
>
(
conv_param
);
ck
::
utils
::
conv
::
make_input_host_tensor_descriptor_g_n_c_wis_packed
<
InLayout
>
(
conv_param
);
...
...
example/41_grouped_conv_conv_fwd/CMakeLists.txt
View file @
8820cf9f
add_example_executable
(
example_grouped_conv_conv_fwd_xdl_fp32 grouped_conv_conv_fwd_xdl_fp32.cpp
)
add_example_executable
(
example_grouped_conv_conv_fwd_xdl_fp32 grouped_conv_conv_fwd_xdl_fp32.cpp
)
add_example_executable
(
example_grouped_conv_conv_fwd_xdl_fp16 grouped_conv_conv_fwd_xdl_fp16.cpp
)
add_example_executable
(
example_grouped_conv_conv_fwd_xdl_fp16 grouped_conv_conv_fwd_xdl_fp16.cpp
)
add_example_executable
(
example_grouped_conv_conv_fwd_xdl_bf16 grouped_conv_conv_fwd_xdl_bf16.cpp
)
add_example_executable
(
example_grouped_conv_conv_fwd_xdl_bf16 grouped_conv_conv_fwd_xdl_bf16.cpp
)
add_example_executable
(
example_grouped_conv_conv_fwd_xdl_int8 grouped_conv_conv_fwd_xdl_int8.cpp
)
if
(
NOT GPU_TARGETS MATCHES
"gfx940"
)
add_example_executable
(
example_grouped_conv_conv_fwd_xdl_int8 grouped_conv_conv_fwd_xdl_int8.cpp
)
endif
()
if
(
USE_BITINT_EXTENSION_INT4
)
if
(
USE_BITINT_EXTENSION_INT4
)
add_example_executable
(
example_grouped_conv_conv_fwd_xdl_int4 grouped_conv_conv_fwd_xdl_int4.cpp
)
add_example_executable
(
example_grouped_conv_conv_fwd_xdl_int4 grouped_conv_conv_fwd_xdl_int4.cpp
)
endif
(
USE_BITINT_EXTENSION_INT4
)
endif
(
USE_BITINT_EXTENSION_INT4
)
example/42_groupnorm/CMakeLists.txt
View file @
8820cf9f
add_example_executable
(
example_groupnorm_sigmoid_fp16 groupnorm_sigmoid_fp16.cpp
)
add_example_executable
(
example_groupnorm_sigmoid_mul_fp16 groupnorm_sigmoid_mul_fp16.cpp
)
add_example_executable
(
example_groupnorm_swish_fp16 groupnorm_swish_fp16.cpp
)
example/42_groupnorm/common.hpp
0 → 100644
View file @
8820cf9f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <getopt.h>
#include "ck/ck.hpp"
#include "ck/utility/reduction_enums.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_groupnorm.hpp"
example/42_groupnorm/groupnorm_sigmoid_mul_fp16.cpp
0 → 100644
View file @
8820cf9f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
constexpr
int
Rank
=
5
;
constexpr
int
NumReduceDim
=
3
;
using
XDataType
=
ck
::
half_t
;
using
GammaDataType
=
ck
::
half_t
;
using
BetaDataType
=
ck
::
half_t
;
using
YDataType
=
ck
::
half_t
;
using
ComputeDataType
=
float
;
struct
YElementOp
{
template
<
typename
T
>
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
ck
::
is_same
<
T
,
float
>::
value
||
ck
::
is_same
<
T
,
double
>::
value
||
ck
::
is_same
<
T
,
ck
::
half_t
>::
value
,
"Data type is not supported by this operation!"
);
T
a
;
ck
::
tensor_operation
::
element_wise
::
Sigmoid
{}(
a
,
x
);
y
=
x
*
a
;
};
};
using
DeviceInstance
=
ck
::
tensor_operation
::
device
::
DeviceNormalizationImpl
<
XDataType
,
GammaDataType
,
BetaDataType
,
ComputeDataType
,
YDataType
,
YElementOp
,
Rank
,
NumReduceDim
,
1024
,
// BlockSize
1
,
// ClusterM
1024
,
// ClusterK
1
,
// SliceM
32
,
// SliceK
1
,
// SrcVecDim (0=M, 1=K)
2
,
// SrcScalarPerVector
1
,
// GammaVecDim (0=M, 1=K)
2
,
// GammaScalarPerVector
1
,
// BetaVecDim (0=M, 1=K)
2
,
// BetaScalarPerVector
2
>
;
// OutScalarPerVector
#include "run_groupnorm_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
run_groupnorm_example
(
argc
,
argv
);
}
example/42_groupnorm/groupnorm_swish_fp16.cpp
0 → 100644
View file @
8820cf9f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
constexpr
int
Rank
=
5
;
constexpr
int
NumReduceDim
=
3
;
using
XDataType
=
ck
::
half_t
;
using
GammaDataType
=
ck
::
half_t
;
using
BetaDataType
=
ck
::
half_t
;
using
YDataType
=
ck
::
half_t
;
using
ComputeDataType
=
float
;
using
YElementOp
=
ck
::
tensor_operation
::
element_wise
::
Swish
;
using
DeviceInstance
=
ck
::
tensor_operation
::
device
::
DeviceNormalizationImpl
<
XDataType
,
GammaDataType
,
BetaDataType
,
ComputeDataType
,
YDataType
,
YElementOp
,
Rank
,
NumReduceDim
,
1024
,
// BlockSize
1
,
// ClusterM
1024
,
// ClusterK
1
,
// SliceM
32
,
// SliceK
1
,
// SrcVecDim (0=M, 1=K)
2
,
// SrcScalarPerVector
1
,
// GammaVecDim (0=M, 1=K)
2
,
// GammaScalarPerVector
1
,
// BetaVecDim (0=M, 1=K)
2
,
// BetaScalarPerVector
2
>
;
// OutScalarPerVector
#include "run_groupnorm_example.inc"
int
main
(
int
argc
,
char
*
argv
[])
{
run_groupnorm_example
(
argc
,
argv
);
}
example/42_groupnorm/groupnorm_
sigmoid_fp16.cpp
→
example/42_groupnorm/
run_
groupnorm_
example.inc
View file @
8820cf9f
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#pragma once
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <getopt.h>
#include "ck/ck.hpp"
#include "ck/utility/reduction_enums.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_groupnorm.hpp"
constexpr
int
Rank
=
5
;
constexpr
int
NumReduceDim
=
3
;
using
XDataType
=
ck
::
half_t
;
using
GammaDataType
=
ck
::
half_t
;
using
BetaDataType
=
ck
::
half_t
;
using
YDataType
=
ck
::
half_t
;
using
ComputeDataType
=
float
;
struct
YElementOp
{
template
<
typename
T
>
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
ck
::
is_same
<
T
,
float
>::
value
||
ck
::
is_same
<
T
,
double
>::
value
||
ck
::
is_same
<
T
,
ck
::
half_t
>::
value
,
"Data type is not supported by this operation!"
);
T
a
;
ck
::
tensor_operation
::
element_wise
::
Sigmoid
{}(
a
,
x
);
int
run_groupnorm_example
(
int
argc
,
char
*
argv
[])
y
=
x
*
a
;
};
};
using
DeviceInstance
=
ck
::
tensor_operation
::
device
::
DeviceNormalizationImpl
<
XDataType
,
GammaDataType
,
BetaDataType
,
ComputeDataType
,
YDataType
,
YElementOp
,
Rank
,
NumReduceDim
,
1024
,
// BlockSize
1
,
// ClusterM
1024
,
// ClusterK
1
,
// SliceM
32
,
// SliceK
1
,
// SrcVecDim (0=M, 1=K)
2
,
// SrcScalarPerVector
1
,
// GammaVecDim (0=M, 1=K)
2
,
// GammaScalarPerVector
1
,
// BetaVecDim (0=M, 1=K)
2
,
// BetaScalarPerVector
2
>
;
// OutScalarPerVector
int
main
(
int
argc
,
char
*
argv
[])
{
{
ck
::
index_t
N
=
2
;
ck
::
index_t
N
=
3
2
;
ck
::
index_t
H
=
32
;
ck
::
index_t
H
=
16
;
ck
::
index_t
W
=
32
;
ck
::
index_t
W
=
16
;
ck
::
index_t
G
=
32
;
ck
::
index_t
G
=
64
;
ck
::
index_t
C
=
30
;
ck
::
index_t
C
=
128
;
if
(
argc
==
1
)
if
(
argc
==
1
)
{
{
...
...
include/ck/ck.hpp
View file @
8820cf9f
...
@@ -31,20 +31,20 @@
...
@@ -31,20 +31,20 @@
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_BUFFER_RESOURCE_3RD_DWORD -1
#define CK_BUFFER_RESOURCE_3RD_DWORD -1
#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \
#elif defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) // for GPU code
defined(__gfx90a__)
|| defined(__gfx940__)
// for GPU code
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
#elif defined(__gfx1030__) // for GPU code
#elif defined(__gfx1030__) // for GPU code
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
#elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) // for GPU code
#elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) // for GPU code
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x100
20
000
#define CK_BUFFER_RESOURCE_3RD_DWORD 0x
3
100
4
000
#endif
#endif
// FMA instruction
// FMA instruction
#ifndef __HIP_DEVICE_COMPILE__ // for host code, define nothing
#ifndef __HIP_DEVICE_COMPILE__ // for host code, define nothing
#elif defined(__gfx803__) || defined(__gfx900__) // for GPU code
#elif defined(__gfx803__) || defined(__gfx900__) // for GPU code
#define CK_USE_AMD_V_MAC_F32
#define CK_USE_AMD_V_MAC_F32
#elif defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || \
#elif defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) ||
defined(__gfx1030__) ||
\
defined(__gfx
103
0__) // for GPU code
defined(__gfx
94
0__) // for GPU code
#define CK_USE_AMD_V_FMAC_F32
#define CK_USE_AMD_V_FMAC_F32
#define CK_USE_AMD_V_DOT2_F32_F16
#define CK_USE_AMD_V_DOT2_F32_F16
#define CK_USE_AMD_V_DOT4_I32_I8
#define CK_USE_AMD_V_DOT4_I32_I8
...
@@ -53,14 +53,18 @@
...
@@ -53,14 +53,18 @@
// MFMA instruction
// MFMA instruction
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_USE_AMD_MFMA
#define CK_USE_AMD_MFMA
#elif defined(__gfx908__) || defined(__gfx90a__) // for GPU code
#elif defined(__gfx908__) || defined(__gfx90a__)
|| defined(__gfx940__)
// for GPU code
#define CK_USE_AMD_MFMA
#define CK_USE_AMD_MFMA
#endif
#endif
#if
defined(__gfx90a__)
#if
(
defined(__gfx90a__)
|| defined(__gfx940__))
#define CK_USE_AMD_MFMA_BF16_1K_OP
#define CK_USE_AMD_MFMA_BF16_1K_OP
#endif
#endif
#if defined(__gfx940__)
#define CK_USE_AMD_MFMA_GFX940
#endif
// WMMA instruction
// WMMA instruction
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_USE_AMD_WMMA
#define CK_USE_AMD_WMMA
...
@@ -80,13 +84,13 @@
...
@@ -80,13 +84,13 @@
// buffer atomic add: floating point
// buffer atomic add: floating point
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#ifndef __HIP_DEVICE_COMPILE__ // for host code
#define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
#define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
#elif defined(__gfx908__) || defined(__gfx90a__) // for GPU code
#elif defined(__gfx908__) || defined(__gfx90a__)
|| defined(__gfx940__)
// for GPU code
#define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
#define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 1
#else // for GPU code
#else // for GPU code
#define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0
#define CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT 0
#endif
#endif
#if
defined(__gfx90a__) // for GPU code
#if
(
defined(__gfx90a__)
|| defined(__gfx940__))
// for GPU code
#define CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 1
#define CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 1
#else
#else
#define CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 0
#define CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64 0
...
@@ -168,6 +172,11 @@
...
@@ -168,6 +172,11 @@
// flag to enable (1) or disable (0) the debugging output in some kernels
// flag to enable (1) or disable (0) the debugging output in some kernels
#define DEBUG_LOG 0
#define DEBUG_LOG 0
// denorm test fix, required to work around dissue
#ifndef CK_WORKAROUND_DENORM_FIX
#define CK_WORKAROUND_DENORM_FIX 0
#endif
namespace
ck
{
namespace
ck
{
enum
struct
InMemoryDataOperationEnum
enum
struct
InMemoryDataOperationEnum
...
...
include/ck/tensor_operation/gpu/device/device_gemm_xdl_waveletmodel_cshuffle.hpp
View file @
8820cf9f
...
@@ -47,7 +47,8 @@ __global__ void
...
@@ -47,7 +47,8 @@ __global__ void
e_grid_desc_mblock_mperblock_nblock_nperblock
,
e_grid_desc_mblock_mperblock_nblock_nperblock
,
const
Block2ETileMap
block_2_etile_map
)
const
Block2ETileMap
block_2_etile_map
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
,
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
,
...
@@ -416,7 +417,8 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm<ALayout,
...
@@ -416,7 +417,8 @@ struct DeviceGemm_Xdl_WaveletModel_CShuffle : public DeviceGemm<ALayout,
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
))
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
||
ck
::
get_device_name
()
==
"gfx940"
))
{
{
return
false
;
return
false
;
}
}
...
...
include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp
View file @
8820cf9f
...
@@ -31,7 +31,7 @@ struct DeviceGroupedGemm : public BaseOperator
...
@@ -31,7 +31,7 @@ struct DeviceGroupedGemm : public BaseOperator
{
{
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
static_assert
(
DsLayout
::
Size
()
==
DsDataType
::
Size
(),
"wrong! inconsis
i
ten NumDTensor"
);
static_assert
(
DsLayout
::
Size
()
==
DsDataType
::
Size
(),
"wrong! inconsisten
t
NumDTensor"
);
virtual
std
::
unique_ptr
<
BaseArgument
>
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
std
::
vector
<
const
void
*>&
p_a
,
MakeArgumentPointer
(
std
::
vector
<
const
void
*>&
p_a
,
...
...
include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
View file @
8820cf9f
...
@@ -43,7 +43,8 @@ __global__ void
...
@@ -43,7 +43,8 @@ __global__ void
const
B1ElementwiseOperation
b1_element_op
,
const
B1ElementwiseOperation
b1_element_op
,
const
CElementwiseOperation
c_element_op
)
const
CElementwiseOperation
c_element_op
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
const
index_t
block_id
=
get_block_1d_id
();
const
index_t
block_id
=
get_block_1d_id
();
...
@@ -678,7 +679,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
...
@@ -678,7 +679,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
{
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
))
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
||
ck
::
get_device_name
()
==
"gfx940"
))
{
{
return
false
;
return
false
;
}
}
...
...
include/ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp
0 → 100644
View file @
8820cf9f
#pragma once
#include <iostream>
#include <vector>
#include "device_grouped_gemm.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
typename
ALayout
,
typename
BLayout
,
typename
DsLayout
,
typename
ELayout
,
typename
ADataType
,
typename
BDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
struct
DeviceGroupedGemmSplitK
:
public
DeviceGroupedGemm
<
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
ADataType
,
BDataType
,
DsDataType
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
{
virtual
void
SetKBatchSize
(
BaseArgument
*
p_arg
,
index_t
kbatch
)
const
=
0
;
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
Prev
1
2
3
4
5
6
…
8
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment