Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
7610e049
"src/include/ConstantTensorDescriptor.hpp" did not exist on "0a386c46a9f05b95b8bc5d710ffce664f5fd0d41"
Commit
7610e049
authored
Apr 21, 2022
by
Chao Liu
Browse files
refactor
parent
1a24ad25
Changes
24
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
119 additions
and
650 deletions
+119
-650
include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp
...on/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp
+3
-4
include/ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp
...k/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp
+3
-4
include/ck/tensor_operation/gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp
.../gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp
+3
-4
include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp
..._fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp
+3
-4
include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp
...nv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp
+4
-8
include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
...device/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
+3
-4
include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp
...ation/gpu/device/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp
+3
-4
include/ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp
...u/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp
+3
-4
include/ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp
...ation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp
+3
-4
include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp
..._operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp
+3
-4
include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp
include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp
+3
-3
include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle.hpp
...tensor_operation/gpu/device/device_gemm_xdl_c_shuffle.hpp
+0
-478
include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_2d.hpp
...peration/gpu/device/device_gemm_xdl_c_shuffle_bias_2d.hpp
+5
-9
include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_activation.hpp
.../gpu/device/device_gemm_xdl_c_shuffle_bias_activation.hpp
+3
-4
include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_activation_add.hpp
.../device/device_gemm_xdl_c_shuffle_bias_activation_add.hpp
+3
-4
include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp
.../tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp
+1
-3
include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle_v2.hpp
...nsor_operation/gpu/device/device_gemm_xdl_cshuffle_v2.hpp
+1
-3
include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp
...k/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp
+23
-29
include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp
...eration/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp
+45
-71
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp
+4
-2
No files found.
include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp
View file @
7610e049
...
...
@@ -726,11 +726,10 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwi
const
index_t
grid_size
=
GridwiseGemm
::
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
)
*
arg
.
BatchCount_
;
const
auto
K0
=
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I0
);
const
auto
K
=
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I0
)
*
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I2
);
const
bool
has_main_k0_block_loop
=
GridwiseGemm
::
CalculateHasMainK0BlockLoop
(
K0
);
if
(
has_main_k0_block_loop
)
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
{
const
auto
kernel
=
kernel_batched_gemm_reduce_xdl_cshuffle_v1
<
GridwiseGemm
,
...
...
include/ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp
View file @
7610e049
...
...
@@ -430,13 +430,12 @@ struct DeviceBatchedGemmXdl
const
index_t
grid_size
=
GridwiseGemm
::
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
)
*
arg
.
BatchCount_
;
const
auto
K0
=
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
);
const
bool
has_main_k0_block_loop
=
GridwiseGemm
::
CalculateHasMainK0BlockLoop
(
K0
);
const
auto
K
=
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
*
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I2
);
float
ave_time
=
0
;
if
(
has_main_k0_b
lock
_l
oop
)
if
(
GridwiseGemm
::
CalculateHasMainKB
lock
L
oop
(
K
)
)
{
const
auto
kernel
=
kernel_batched_gemm_xdlops_v2r3
<
GridwiseGemm
,
...
...
include/ck/tensor_operation/gpu/device/device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk.hpp
View file @
7610e049
...
...
@@ -582,11 +582,10 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
const
index_t
grid_size
=
GridwiseGemm
::
CalculateGridSize
(
arg
.
c_grid_desc_m_n_container_
[
i
]);
const
auto
K0
=
arg
.
a_grid_desc_k0_m_k1_container_
[
i
].
GetLength
(
I0
);
const
auto
K
=
arg
.
a_grid_desc_k0_m_k1_container_
[
i
].
GetLength
(
I0
)
*
arg
.
a_grid_desc_k0_m_k1_container_
[
i
].
GetLength
(
I2
);
const
bool
has_main_k0_block_loop
=
GridwiseGemm
::
CalculateHasMainK0BlockLoop
(
K0
);
if
(
has_main_k0_block_loop
)
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
{
const
auto
kernel
=
kernel_gemm_xdlops_v2r3
<
GridwiseGemm
,
...
...
include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_bias_activation_add_nhwc_kyxc_nhwk.hpp
View file @
7610e049
...
...
@@ -698,13 +698,12 @@ struct
const
index_t
grid_size
=
GridwiseGemm
::
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
);
const
auto
K0
=
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
);
const
bool
has_main_k0_block_loop
=
GridwiseGemm
::
CalculateHasMainK0BlockLoop
(
K0
);
const
auto
K
=
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
*
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I2
);
float
ave_time
=
0
;
if
(
has_main_k0_b
lock
_l
oop
)
if
(
GridwiseGemm
::
CalculateHasMainKB
lock
L
oop
(
K
)
)
{
const
auto
kernel
=
kernel_gemm_xdlops_v3r3
<
GridwiseGemm
,
...
...
include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_bias_activation_nhwc_kyxc_nhwk.hpp
View file @
7610e049
#ifndef DEVICE_CONV2D_FWD_XDL_C_SHUFFLE_BIAS_ACTIVATION_NHWC_KYXC_NHWK_HPP
#define DEVICE_CONV2D_FWD_XDL_C_SHUFFLE_BIAS_ACTIVATION_NHWC_KYXC_NHWK_HPP
#pragma once
#include <iostream>
#include <sstream>
#include "device.hpp"
...
...
@@ -660,13 +658,12 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
const
index_t
grid_size
=
GridwiseGemm
::
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
);
const
auto
K0
=
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
);
const
bool
has_main_k0_block_loop
=
GridwiseGemm
::
CalculateHasMainK0BlockLoop
(
K0
);
const
auto
K
=
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
*
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I2
);
float
ave_time
=
0
;
if
(
has_main_k0_b
lock
_l
oop
)
if
(
GridwiseGemm
::
CalculateHasMainKB
lock
L
oop
(
K
)
)
{
const
auto
kernel
=
kernel_gemm_xdlops_v3r2
<
GridwiseGemm
,
...
...
@@ -919,4 +916,3 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
#endif
include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
View file @
7610e049
...
...
@@ -640,13 +640,12 @@ struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
const
index_t
grid_size
=
GridwiseGemm
::
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
);
const
auto
K0
=
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
);
const
bool
has_main_k0_block_loop
=
GridwiseGemm
::
CalculateHasMainK0BlockLoop
(
K0
);
const
auto
K
=
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
*
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I2
);
float
ave_time
=
0
;
if
(
has_main_k0_b
lock
_l
oop
)
if
(
GridwiseGemm
::
CalculateHasMainKB
lock
L
oop
(
K
)
)
{
const
auto
kernel
=
kernel_gemm_xdlops_v3r1
<
GridwiseGemm
,
...
...
include/ck/tensor_operation/gpu/device/device_conv2d_fwd_xdl_nhwc_kyxc_nhwk.hpp
View file @
7610e049
...
...
@@ -478,13 +478,12 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
const
index_t
grid_size
=
GridwiseGemm
::
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
);
const
auto
K0
=
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
);
const
bool
has_main_k0_block_loop
=
GridwiseGemm
::
CalculateHasMainK0BlockLoop
(
K0
);
const
auto
K
=
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
*
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I2
);
float
ave_time
=
0
;
if
(
has_main_k0_b
lock
_l
oop
)
if
(
GridwiseGemm
::
CalculateHasMainKB
lock
L
oop
(
K
)
)
{
const
auto
kernel
=
kernel_gemm_xdlops_v2r3
<
GridwiseGemm
,
...
...
include/ck/tensor_operation/gpu/device/device_convnd_bwd_data_xdl_ndhwc_kzyxc_ndhwk.hpp
View file @
7610e049
...
...
@@ -1296,11 +1296,10 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
const
index_t
grid_size
=
GridwiseGemm
::
CalculateGridSize
(
arg
.
c_grid_desc_m_n_container_
[
i
]);
const
auto
K0
=
arg
.
a_grid_desc_k0_m_k1_container_
[
i
].
GetLength
(
I0
);
const
auto
K
=
arg
.
a_grid_desc_k0_m_k1_container_
[
i
].
GetLength
(
I0
)
*
arg
.
a_grid_desc_k0_m_k1_container_
[
i
].
GetLength
(
I2
);
const
bool
has_main_k0_block_loop
=
GridwiseGemm
::
CalculateHasMainK0BlockLoop
(
K0
);
if
(
has_main_k0_block_loop
)
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
{
const
auto
kernel
=
kernel_gemm_xdlops_v2r3
<
GridwiseGemm
,
...
...
include/ck/tensor_operation/gpu/device/device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp
View file @
7610e049
...
...
@@ -775,13 +775,12 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
const
index_t
grid_size
=
GridwiseGemm
::
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
);
const
auto
K0
=
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
);
const
bool
has_main_k0_block_loop
=
GridwiseGemm
::
CalculateHasMainK0BlockLoop
(
K0
);
const
auto
K
=
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
*
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I2
);
float
ave_time
=
0
;
if
(
has_main_k0_b
lock
_l
oop
)
if
(
GridwiseGemm
::
CalculateHasMainKB
lock
L
oop
(
K
)
)
{
const
auto
kernel
=
kernel_gemm_xdlops_v2r3
<
GridwiseGemm
,
...
...
include/ck/tensor_operation/gpu/device/device_gemm_reduce_xdl_cshuffle.hpp
View file @
7610e049
...
...
@@ -530,11 +530,10 @@ struct DeviceGemmReduce_Xdl_CShuffle : public DeviceGemmReduce<AElementwiseOpera
const
index_t
grid_size
=
GridwiseGemm
::
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
);
const
auto
K0
=
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I0
);
const
auto
K
=
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I0
)
*
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I2
);
const
bool
has_main_k0_block_loop
=
GridwiseGemm
::
CalculateHasMainK0BlockLoop
(
K0
);
if
(
has_main_k0_block_loop
)
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
{
const
auto
kernel
=
kernel_gemm_reduce_xdl_cshuffle_v1
<
GridwiseGemm
,
...
...
include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp
View file @
7610e049
...
...
@@ -292,6 +292,7 @@ struct DeviceGemmXdl
float
Run
(
const
Argument
&
arg
,
int
nrepeat
=
1
)
{
#if 0
{
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
...
...
@@ -304,6 +305,7 @@ struct DeviceGemmXdl
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
}
#endif
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
...
...
@@ -320,11 +322,9 @@ struct DeviceGemmXdl
const
auto
K
=
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
*
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I2
);
const
bool
has_main_k_block_loop
=
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
);
float
ave_time
=
0
;
if
(
has_main_k_b
lock
_l
oop
)
if
(
GridwiseGemm
::
CalculateHasMainKB
lock
L
oop
(
K
)
)
{
const
auto
kernel
=
kernel_gemm_xdlops_v2r3
<
GridwiseGemm
,
...
...
include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle.hpp
deleted
100644 → 0
View file @
1a24ad25
#ifndef DEVICE_GEMM_XDL_C_SHUFFLE_HPP
#define DEVICE_GEMM_XDL_C_SHUFFLE_HPP
#include <iostream>
#include <sstream>
#include "device.hpp"
#include "device_gemm.hpp"
#include "common_header.hpp"
#include "tensor_layout.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_v3r1.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
typename
ADataType
,
typename
BDataType
,
typename
CDataType
,
typename
AccDataType
,
typename
CShuffleDataType
,
typename
ALayout
,
typename
BLayout
,
typename
CLayout
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
ck
::
index_t
BlockSize
,
ck
::
index_t
MPerBlock
,
ck
::
index_t
NPerBlock
,
ck
::
index_t
KPerBlock
,
ck
::
index_t
AK1
,
ck
::
index_t
BK1
,
ck
::
index_t
MPerXDL
,
ck
::
index_t
NPerXDL
,
ck
::
index_t
MXdlPerWave
,
ck
::
index_t
NXdlPerWave
,
typename
ABlockTransferThreadClusterLengths_K0_M_K1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
ck
::
index_t
ABlockTransferSrcVectorDim
,
ck
::
index_t
ABlockTransferSrcScalarPerVector
,
ck
::
index_t
ABlockTransferDstScalarPerVector_K1
,
bool
ABlockLdsAddExtraM
,
typename
BBlockTransferThreadClusterLengths_K0_N_K1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
ck
::
index_t
BBlockTransferSrcVectorDim
,
ck
::
index_t
BBlockTransferSrcScalarPerVector
,
ck
::
index_t
BBlockTransferDstScalarPerVector_K1
,
bool
BBlockLdsAddExtraN
,
index_t
CShuffleMXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
,
index_t
CBlockTransferScalarPerVector_NWaveNPerXdl
,
index_t
NumPrefetch
=
1
>
struct
DeviceGemmXdl_C_Shuffle
:
public
DeviceGemm
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
auto
MakeAGridDescriptor_K0_M_K1
(
index_t
M
,
index_t
K
,
index_t
StrideA
)
{
assert
(
K
%
AK1
==
0
);
const
index_t
K0
=
K
/
AK1
;
const
auto
a_grid_desc_m_k
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
K
),
make_tuple
(
StrideA
,
I1
));
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
ALayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
K
),
make_tuple
(
I1
,
StrideA
));
}
}();
const
auto
a_grid_desc_k0_m_k1
=
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
AK1
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_k0_m_k1
;
}
static
auto
MakeBGridDescriptor_K0_N_K1
(
index_t
K
,
index_t
N
,
index_t
StrideB
)
{
assert
(
K
%
BK1
==
0
);
const
index_t
K0
=
K
/
BK1
;
const
auto
b_grid_desc_k_n
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
K
,
N
),
make_tuple
(
StrideB
,
I1
));
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
K
,
N
),
make_tuple
(
I1
,
StrideB
));
}
}();
const
auto
b_grid_desc_k0_n_k1
=
transform_tensor_descriptor
(
b_grid_desc_k_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
BK1
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_k0_n_k1
;
}
static
auto
MakeCGridDescriptor_M_N
(
index_t
M
,
index_t
N
,
index_t
StrideC
)
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
CLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
StrideC
,
I1
));
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
CLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
I1
,
StrideC
));
}
}
using
AGridDesc_K0_M_K1
=
decltype
(
MakeAGridDescriptor_K0_M_K1
(
1
,
1
,
1
));
using
BGridDesc_K0_N_K1
=
decltype
(
MakeBGridDescriptor_K0_N_K1
(
1
,
1
,
1
));
using
CGridDesc_M_N
=
decltype
(
MakeCGridDescriptor_M_N
(
1
,
1
,
1
));
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
<
BlockSize
,
ADataType
,
// TODO: distinguish A/B datatype
AccDataType
,
CShuffleDataType
,
CDataType
,
InMemoryDataOperationEnum
::
Set
,
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
CGridDesc_M_N
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
AK1
,
BK1
,
MPerXDL
,
NPerXDL
,
MXdlPerWave
,
NXdlPerWave
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_K1
,
false
,
ABlockLdsAddExtraM
,
BBlockTransferThreadClusterLengths_K0_N_K1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_K1
,
false
,
BBlockLdsAddExtraN
,
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
,
CBlockTransferScalarPerVector_NWaveNPerXdl
,
NumPrefetch
>
;
// Argument
struct
Argument
:
public
BaseArgument
{
Argument
(
const
ADataType
*
p_a_grid
,
const
BDataType
*
p_b_grid
,
CDataType
*
p_c_grid
,
index_t
M
,
index_t
N
,
index_t
K
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideC
,
index_t
M01
,
index_t
N01
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
:
p_a_grid_
{
p_a_grid
},
p_b_grid_
{
p_b_grid
},
p_c_grid_
{
p_c_grid
},
a_grid_desc_k0_m_k1_
{},
b_grid_desc_k0_n_k1_
{},
c_grid_desc_m_n_
{},
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_
{},
block_2_ctile_map_
{},
M01_
{
M01
},
N01_
{
N01
},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
c_element_op_
{
c_element_op
}
{
a_grid_desc_k0_m_k1_
=
DeviceGemmXdl_C_Shuffle
::
MakeAGridDescriptor_K0_M_K1
(
M
,
K
,
StrideA
);
b_grid_desc_k0_n_k1_
=
DeviceGemmXdl_C_Shuffle
::
MakeBGridDescriptor_K0_N_K1
(
K
,
N
,
StrideB
);
c_grid_desc_m_n_
=
DeviceGemmXdl_C_Shuffle
::
MakeCGridDescriptor_M_N
(
M
,
N
,
StrideC
);
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_k0_m_k1_
,
b_grid_desc_k0_n_k1_
,
c_grid_desc_m_n_
,
M01_
,
N01_
))
{
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_
=
GridwiseGemm
::
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
(
c_grid_desc_m_n_
);
block_2_ctile_map_
=
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n_
,
M01
,
N01
);
}
}
// private:
const
ADataType
*
p_a_grid_
;
const
BDataType
*
p_b_grid_
;
CDataType
*
p_c_grid_
;
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1_
;
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_
;
typename
GridwiseGemm
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
index_t
M01_
;
index_t
N01_
;
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
CElementwiseOperation
c_element_op_
;
};
// Invoker
struct
Invoker
:
public
BaseInvoker
{
using
Argument
=
DeviceGemmXdl_C_Shuffle
::
Argument
;
float
Run
(
const
Argument
&
arg
,
int
nrepeat
=
1
)
{
{
std
::
cout
<<
"arg.a_grid_desc_k0_m_k1_{"
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I2
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.b_grid_desc_k0_n_k1_{"
<<
arg
.
b_grid_desc_k0_n_k1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
b_grid_desc_k0_n_k1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
b_grid_desc_k0_n_k1_
.
GetLength
(
I2
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.c_grid_desc_m_n_{ "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
}
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_m_n_
,
arg
.
M01_
,
arg
.
N01_
))
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"
);
}
const
index_t
grid_size
=
GridwiseGemm
::
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
);
const
auto
K
=
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
*
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I2
);
const
bool
has_main_k_block_loop
=
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
);
float
ave_time
=
0
;
if
(
has_main_k_block_loop
)
{
const
auto
kernel
=
kernel_gemm_xdlops_v3r1
<
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
remove_reference_t
<
DeviceGemmXdl_C_Shuffle
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceGemmXdl_C_Shuffle
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
>
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
remove_reference_t
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
>
,
true
>
;
ave_time
=
launch_and_time_kernel
(
kernel
,
nrepeat
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
block_2_ctile_map_
);
}
else
{
const
auto
kernel
=
kernel_gemm_xdlops_v3r1
<
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
remove_reference_t
<
DeviceGemmXdl_C_Shuffle
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceGemmXdl_C_Shuffle
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
>
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
remove_reference_t
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
>
,
false
>
;
ave_time
=
launch_and_time_kernel
(
kernel
,
nrepeat
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
block_2_ctile_map_
);
}
return
ave_time
;
}
// polymorphic
float
Run
(
const
BaseArgument
*
p_arg
,
int
nrepeat
=
1
)
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
nrepeat
);
}
};
static
constexpr
bool
IsValidCompilationParameter
()
{
// TODO: properly implement this check
return
true
;
}
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_m_n_
,
arg
.
M01_
,
arg
.
N01_
);
}
// polymorphic
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
static
auto
MakeArgument
(
const
ADataType
*
p_a
,
const
BDataType
*
p_b
,
CDataType
*
p_c
,
index_t
M
,
index_t
N
,
index_t
K
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideC
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
{
return
Argument
{
p_a
,
p_b
,
p_c
,
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
,
1
,
1
,
a_element_op
,
b_element_op
,
c_element_op
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
// polymorphic
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
void
*
p_c
,
index_t
M
,
index_t
N
,
index_t
K
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideC
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
index_t
/* KBatch */
=
1
)
override
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
static_cast
<
const
BDataType
*>
(
p_b
),
static_cast
<
CDataType
*>
(
p_c
),
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
,
1
,
1
,
a_element_op
,
b_element_op
,
c_element_op
);
}
// polymorphic
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
// polymorphic
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"DeviceGemmXdl_C_Shuffle"
<<
"<"
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
KPerBlock
<<
", "
<<
AK1
<<
", "
<<
BK1
<<
">"
;
// clang-format on
return
str
.
str
();
}
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
#endif
include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_2d.hpp
View file @
7610e049
#ifndef DEVICE_GEMM_XDL_C_SHUFFLE_BIAS_2D_HPP
#define DEVICE_GEMM_XDL_C_SHUFFLE_BIAS_2D_HPP
#pragma once
#include <iostream>
#include <sstream>
#include "device.hpp"
...
...
@@ -291,18 +289,17 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d
arg
.
N01_
))
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v
2r3
has invalid setting"
);
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v
3r2
has invalid setting"
);
}
const
index_t
grid_size
=
GridwiseGemm
::
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
);
const
auto
K0
=
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
);
const
bool
has_main_k0_block_loop
=
GridwiseGemm
::
CalculateHasMainK0BlockLoop
(
K0
);
const
auto
K
=
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
*
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I2
);
float
ave_time
=
0
;
if
(
has_main_k0_b
lock
_l
oop
)
if
(
GridwiseGemm
::
CalculateHasMainKB
lock
L
oop
(
K
)
)
{
const
auto
kernel
=
kernel_gemm_xdlops_v3r2
<
GridwiseGemm
,
...
...
@@ -505,4 +502,3 @@ struct DeviceGemmXdl_C_Shuffle_Bias_2d
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
#endif
include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_activation.hpp
View file @
7610e049
...
...
@@ -303,13 +303,12 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation
const
index_t
grid_size
=
GridwiseGemm
::
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
);
const
auto
K0
=
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
);
const
bool
has_main_k0_block_loop
=
GridwiseGemm
::
CalculateHasMainK0BlockLoop
(
K0
);
const
auto
K
=
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
*
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I2
);
float
ave_time
=
0
;
if
(
has_main_k0_b
lock
_l
oop
)
if
(
GridwiseGemm
::
CalculateHasMainKB
lock
L
oop
(
K
)
)
{
const
auto
kernel
=
kernel_gemm_xdlops_v3r2
<
GridwiseGemm
,
...
...
include/ck/tensor_operation/gpu/device/device_gemm_xdl_c_shuffle_bias_activation_add.hpp
View file @
7610e049
...
...
@@ -345,13 +345,12 @@ struct DeviceGemmXdl_C_Shuffle_Bias_Activation_Add
const
index_t
grid_size
=
GridwiseGemm
::
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
);
const
auto
K0
=
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
);
const
bool
has_main_k0_block_loop
=
GridwiseGemm
::
CalculateHasMainK0BlockLoop
(
K0
);
const
auto
K
=
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
*
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I2
);
float
ave_time
=
0
;
if
(
has_main_k0_b
lock
_l
oop
)
if
(
GridwiseGemm
::
CalculateHasMainKB
lock
L
oop
(
K
)
)
{
const
auto
kernel
=
kernel_gemm_xdlops_v3r3
<
GridwiseGemm
,
...
...
include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle.hpp
View file @
7610e049
...
...
@@ -465,11 +465,9 @@ struct DeviceGemm_Xdl_CShuffle
const
auto
K
=
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I0
)
*
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I2
);
const
bool
has_main_k_block_loop
=
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
);
float
ave_time
=
0
;
if
(
has_main_k_b
lock
_l
oop
)
if
(
GridwiseGemm
::
CalculateHasMainKB
lock
L
oop
(
K
)
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v1
<
GridwiseGemm
,
...
...
include/ck/tensor_operation/gpu/device/device_gemm_xdl_cshuffle_v2.hpp
View file @
7610e049
...
...
@@ -467,11 +467,9 @@ struct DeviceGemm_Xdl_CShuffle_v2
const
auto
K
=
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I0
)
*
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I2
);
const
bool
has_main_k_block_loop
=
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
);
float
ave_time
=
0
;
if
(
has_main_k_b
lock
_l
oop
)
if
(
GridwiseGemm
::
CalculateHasMainKB
lock
L
oop
(
K
)
)
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v2
<
GridwiseGemm
,
...
...
include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp
View file @
7610e049
...
...
@@ -450,59 +450,53 @@ struct DeviceGroupedGemmXdl
float
Run
(
const
Argument
&
arg
,
int
nrepeat
=
1
)
{
StaticallyIndexedArray
<
GemmDescKernelArg
,
MaxGroupCount
>
gemm_desc_kernel_arg
_arg
;
StaticallyIndexedArray
<
GemmDescKernelArg
,
MaxGroupCount
>
gemm_desc_kernel_arg
s
;
bool
has_main_k
0
_block_loop
=
true
;
bool
has_main_k_block_loop
=
true
;
static_for
<
0
,
MaxGroupCount
,
1
>
{}([
&
](
auto
i
)
{
if
(
i
<
arg
.
gemm_desc_kernel_arg_
.
size
())
{
gemm_desc_kernel_arg
_arg
(
i
)
=
arg
.
gemm_desc_kernel_arg_
[
i
];
gemm_desc_kernel_arg
s
(
i
)
=
arg
.
gemm_desc_kernel_arg_
[
i
];
std
::
cout
<<
"group: "
<<
i
<<
" arg.a_grid_desc_k0_m_k1_{"
<<
gemm_desc_kernel_arg_arg
[
i
].
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
<<
", "
<<
gemm_desc_kernel_arg_arg
[
i
].
a_grid_desc_k0_m_k1_
.
GetLength
(
I1
)
<<
", "
<<
gemm_desc_kernel_arg_arg
[
i
].
a_grid_desc_k0_m_k1_
.
GetLength
(
I2
)
<<
"}"
;
<<
gemm_desc_kernel_args
[
i
].
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
<<
", "
<<
gemm_desc_kernel_args
[
i
].
a_grid_desc_k0_m_k1_
.
GetLength
(
I1
)
<<
", "
<<
gemm_desc_kernel_args
[
i
].
a_grid_desc_k0_m_k1_
.
GetLength
(
I2
)
<<
"}"
;
std
::
cout
<<
", arg.b_grid_desc_k0_n_k1_{"
<<
gemm_desc_kernel_arg_arg
[
i
].
b_grid_desc_k0_n_k1_
.
GetLength
(
I0
)
<<
", "
<<
gemm_desc_kernel_arg_arg
[
i
].
b_grid_desc_k0_n_k1_
.
GetLength
(
I1
)
<<
", "
<<
gemm_desc_kernel_arg_arg
[
i
].
b_grid_desc_k0_n_k1_
.
GetLength
(
I2
)
<<
"}"
;
<<
gemm_desc_kernel_args
[
i
].
b_grid_desc_k0_n_k1_
.
GetLength
(
I0
)
<<
", "
<<
gemm_desc_kernel_args
[
i
].
b_grid_desc_k0_n_k1_
.
GetLength
(
I1
)
<<
", "
<<
gemm_desc_kernel_args
[
i
].
b_grid_desc_k0_n_k1_
.
GetLength
(
I2
)
<<
"}"
;
std
::
cout
<<
", arg.c_grid_desc_m_n_{ "
<<
gemm_desc_kernel_arg
_arg
[
i
].
c_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
gemm_desc_kernel_arg
_arg
[
i
].
c_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
gemm_desc_kernel_arg
s
[
i
].
c_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
gemm_desc_kernel_arg
s
[
i
].
c_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
if
(
!
GridwiseGemm
::
CheckValidity
(
gemm_desc_kernel_arg_arg
[
i
].
a_grid_desc_k0_m_k1_
,
gemm_desc_kernel_arg_arg
[
i
].
b_grid_desc_k0_n_k1_
,
gemm_desc_kernel_arg_arg
[
i
].
c_grid_desc_m_n_
,
arg
.
M01_
,
arg
.
N01_
))
if
(
!
GridwiseGemm
::
CheckValidity
(
gemm_desc_kernel_args
[
i
].
a_grid_desc_k0_m_k1_
,
gemm_desc_kernel_args
[
i
].
b_grid_desc_k0_n_k1_
,
gemm_desc_kernel_args
[
i
].
c_grid_desc_m_n_
,
arg
.
M01_
,
arg
.
N01_
))
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"
);
}
const
auto
K0
=
gemm_desc_kernel_arg_arg
[
i
].
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
);
const
auto
K
=
gemm_desc_kernel_args
[
i
].
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
*
gemm_desc_kernel_args
[
i
].
a_grid_desc_k0_m_k1_
.
GetLength
(
I2
);
if
(
GridwiseGemm
::
CalculateHasMainK
0
BlockLoop
(
K
0
)
!=
has_main_k
0
_block_loop
)
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
)
!=
has_main_k_block_loop
)
{
throw
std
::
runtime_error
(
"wrong! not all gemm has_main_k
0
_block_loop"
);
throw
std
::
runtime_error
(
"wrong! not all gemm has_main_k_block_loop"
);
}
}
});
float
ave_time
=
0
;
if
(
has_main_k
0
_block_loop
)
if
(
has_main_k_block_loop
)
{
const
auto
kernel
=
kernel_grouped_gemm_xdlops_v2r3
<
GridwiseGemm
,
...
...
@@ -520,7 +514,7 @@ struct DeviceGroupedGemmXdl
dim3
(
arg
.
grid_size_
),
dim3
(
BlockSize
),
0
,
gemm_desc_kernel_arg
_arg
,
gemm_desc_kernel_arg
s
,
arg
.
gemm_desc_kernel_arg_
.
size
(),
arg
.
a_element_op_
,
arg
.
b_element_op_
,
...
...
@@ -544,7 +538,7 @@ struct DeviceGroupedGemmXdl
dim3
(
arg
.
grid_size_
),
dim3
(
BlockSize
),
0
,
gemm_desc_kernel_arg
_arg
,
gemm_desc_kernel_arg
s
,
arg
.
gemm_desc_kernel_arg_
.
size
(),
arg
.
a_element_op_
,
arg
.
b_element_op_
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_reduce_xdl_cshuffle_v1.hpp
View file @
7610e049
...
...
@@ -25,7 +25,7 @@ template <typename GridwiseGemm,
typename
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
DGridDescriptor_MBlock_MPerBlock
,
typename
Block2CTileMap
,
bool
HasMainK
0
BlockLoop
>
bool
HasMainKBlockLoop
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
...
...
@@ -51,22 +51,22 @@ __global__ void
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
GridwiseGemm
::
template
Run
<
HasMainK
0
BlockLoop
>(
p_a_grid
,
p_b_grid
,
p_c_grid
,
p_d0_grid
,
p_d1_grid
,
p_shared
,
a_element_op
,
b_element_op
,
c_element_op
,
d0_reduce_op
,
d1_reduce_op
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
d_grid_desc_mblock_mperblock
,
block_2_ctile_map
);
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
,
p_b_grid
,
p_c_grid
,
p_d0_grid
,
p_d1_grid
,
p_shared
,
a_element_op
,
b_element_op
,
c_element_op
,
d0_reduce_op
,
d1_reduce_op
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
d_grid_desc_mblock_mperblock
,
block_2_ctile_map
);
#else
ignore
=
p_a_grid
;
ignore
=
p_b_grid
;
...
...
@@ -154,6 +154,10 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
static
constexpr
auto
AK1
=
Number
<
AK1Value
>
{};
static
constexpr
auto
BK1
=
Number
<
BK1Value
>
{};
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
using
GridwiseGemmPipe
=
GridwiseGemmPipeline_v1
<
NumGemmKPrefetchStage
>
;
__host__
__device__
static
constexpr
auto
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
()
{
// A matrix in LDS memory, dst of blockwise copy
...
...
@@ -237,21 +241,10 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K
%
KPerBlock
==
0
))
return
false
;
// check NumGemmKPrefetchStage
if
constexpr
(
NumGemmKPrefetchStage
==
1
)
{
// 1-stage prefetch always supported
}
else
if
constexpr
(
NumGemmKPrefetchStage
==
2
)
{
// 2-stage prefetch currently only support even number of K0 loop
// TODO: add support for odd number of K0 loop
if
(
!
((
K
/
KPerBlock
)
%
2
==
0
))
{
return
false
;
}
}
else
// check gridwise gemm pipeline
const
auto
num_k_loop
=
K
/
KPerBlock
;
if
(
!
GridwiseGemmPipe
::
IsSupported
(
num_k_loop
))
{
return
false
;
}
...
...
@@ -271,12 +264,11 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
return
grid_size
;
}
// TODO move this function into GEMM-pipeline class
__host__
__device__
static
constexpr
bool
CalculateHasMainK0BlockLoop
(
index_t
K0
)
__host__
__device__
static
constexpr
bool
CalculateHasMainKBlockLoop
(
index_t
K
)
{
const
bool
has_main_k0_block_loop
=
((
K0
*
AK1
)
/
(
NumGemmKPrefetchStage
*
KPerBlock
))
>
1
;
const
index_t
num_loop
=
K
/
KPerBlock
;
return
has_main_k0_block
_loop
;
return
GridwiseGemmPipe
::
CalculateHasMainLoop
(
num
_loop
)
;
}
__host__
__device__
static
constexpr
auto
...
...
@@ -362,7 +354,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
using
DefaultBlock2CTileMap
=
remove_cvref_t
<
decltype
(
MakeDefaultBlock2CTileMap
(
CGridDesc_M_N
{}))
>
;
template
<
bool
HasMainK
0
BlockLoop
,
typename
Block2CTileMap
>
template
<
bool
HasMainKBlockLoop
,
typename
Block2CTileMap
>
__device__
static
void
Run
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
...
...
@@ -485,7 +477,7 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
math
::
lcm
(
AK1
,
BK1
),
MfmaSelector
<
FloatAB
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
Block
Size
,
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
ThisThread
Block
,
FloatAB
,
FloatGemmAcc
,
decltype
(
a_block_desc_ak0_m_ak1
),
...
...
@@ -513,43 +505,25 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
constexpr
auto
b_block_slice_copy_step
=
make_multi_index
(
KPerBlock
/
BK1
,
0
,
0
);
// gridwise GEMM pipeline
const
auto
gridwise_gemm_pipeline
=
GridwiseGemmPipeline_v1
<
remove_cvref_t
<
decltype
(
a_grid_desc_ak0_m_ak1
)
>
,
remove_cvref_t
<
decltype
(
a_block_desc_ak0_m_ak1
)
>
,
remove_cvref_t
<
decltype
(
a_blockwise_copy
)
>
,
remove_cvref_t
<
decltype
(
a_grid_buf
)
>
,
remove_cvref_t
<
decltype
(
a_block_buf
)
>
,
remove_cvref_t
<
decltype
(
a_block_slice_copy_step
)
>
,
remove_cvref_t
<
decltype
(
b_grid_desc_bk0_n_bk1
)
>
,
remove_cvref_t
<
decltype
(
b_block_desc_bk0_n_bk1
)
>
,
remove_cvref_t
<
decltype
(
b_blockwise_copy
)
>
,
remove_cvref_t
<
decltype
(
b_grid_buf
)
>
,
remove_cvref_t
<
decltype
(
b_block_buf
)
>
,
remove_cvref_t
<
decltype
(
b_block_slice_copy_step
)
>
,
remove_cvref_t
<
decltype
(
blockwise_gemm
)
>
,
remove_cvref_t
<
decltype
(
c_thread_buf
)
>
,
NumGemmKPrefetchStage
,
HasMainK0BlockLoop
>
{};
const
index_t
num_k_block_main_loop
=
__builtin_amdgcn_readfirstlane
(
(
a_grid_desc_ak0_m_ak1
.
GetLength
(
I0
)
*
a_grid_desc_ak0_m_ak1
.
GetLength
(
I2
))
/
KPerBlock
);
g
ridwise
_g
emm
_p
ipe
line
.
Run
(
a_grid_desc_ak0_m_ak1
,
a_block_desc_ak0_m_ak1
,
a_blockwise_copy
,
a_grid_buf
,
a_block_buf
,
a_block_slice_copy_step
,
b_grid_desc_bk0_n_bk1
,
b_block_desc_bk0_n_bk1
,
b_blockwise_copy
,
b_grid_buf
,
b_block_buf
,
b_block_slice_copy_step
,
blockwise_gemm
,
c_thread_buf
,
num_k_block_main_loop
);
G
ridwise
G
emm
P
ipe
::
template
Run
<
HasMainKBlockLoop
>
(
a_grid_desc_ak0_m_ak1
,
a_block_desc_ak0_m_ak1
,
a_blockwise_copy
,
a_grid_buf
,
a_block_buf
,
a_block_slice_copy_step
,
b_grid_desc_bk0_n_bk1
,
b_block_desc_bk0_n_bk1
,
b_blockwise_copy
,
b_grid_buf
,
b_block_buf
,
b_block_slice_copy_step
,
blockwise_gemm
,
c_thread_buf
,
num_k_block_main_loop
);
// shuffle C and write out
{
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp
View file @
7610e049
...
...
@@ -120,6 +120,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
// K1 should be Number<...>
static
constexpr
auto
K1
=
Number
<
K1Value
>
{};
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
__host__
__device__
static
constexpr
index_t
GetSharedMemoryNumberOfByte
()
{
constexpr
auto
max_lds_align
=
K1
;
...
...
@@ -262,7 +264,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
}();
using
BlockwiseGemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
Block
Size
,
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
ThisThread
Block
,
FloatAB
,
FloatAcc
,
decltype
(
a_k0_m_k1_block_desc
),
...
...
@@ -487,7 +489,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
// sanity check
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
Block
Size
,
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
<
ThisThread
Block
,
FloatAB
,
FloatAcc
,
decltype
(
a_k0_m_k1_block_desc
),
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment