Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
25e35b59
Commit
25e35b59
authored
Jun 11, 2022
by
Chao Liu
Browse files
rename, clean
parent
8a60a329
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
98 additions
and
90 deletions
+98
-90
example/03_gemm_bias_add_fastgelu/CMakeLists.txt
example/03_gemm_bias_add_fastgelu/CMakeLists.txt
+1
-0
example/03_gemm_bias_add_fastgelu/gemm_bias_add_fastgelu_xdl_fp16.cpp
...emm_bias_add_fastgelu/gemm_bias_add_fastgelu_xdl_fp16.cpp
+0
-0
example/03_gemm_bias_fastgelu/CMakeLists.txt
example/03_gemm_bias_fastgelu/CMakeLists.txt
+0
-1
example/CMakeLists.txt
example/CMakeLists.txt
+1
-1
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp
...ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp
+1
-1
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp
...ration/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp
+12
-11
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
...ration/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
+83
-76
No files found.
example/03_gemm_bias_add_fastgelu/CMakeLists.txt
0 → 100644
View file @
25e35b59
add_example_executable
(
example_gemm_bias_add_fastgelu_xdl_fp16 gemm_bias_add_fastgelu_xdl_fp16.cpp
)
example/03_gemm_bias_fastgelu/gemm_bias_fastgelu_xdl_fp16.cpp
→
example/03_gemm_bias_
add_
fastgelu/gemm_bias_
add_
fastgelu_xdl_fp16.cpp
View file @
25e35b59
File moved
example/03_gemm_bias_fastgelu/CMakeLists.txt
deleted
100644 → 0
View file @
8a60a329
add_example_executable
(
example_gemm_bias_fastgelu_xdl_fp16 gemm_bias_fastgelu_xdl_fp16.cpp
)
example/CMakeLists.txt
View file @
25e35b59
...
@@ -39,7 +39,7 @@ endfunction(add_example_executable_no_testing EXAMPLE_NAME)
...
@@ -39,7 +39,7 @@ endfunction(add_example_executable_no_testing EXAMPLE_NAME)
add_subdirectory
(
01_gemm
)
add_subdirectory
(
01_gemm
)
add_subdirectory
(
02_gemm_alpha_beta
)
add_subdirectory
(
02_gemm_alpha_beta
)
add_subdirectory
(
03_gemm_bias_relu
)
add_subdirectory
(
03_gemm_bias_relu
)
add_subdirectory
(
03_gemm_bias_fastgelu
)
add_subdirectory
(
03_gemm_bias_
add_
fastgelu
)
add_subdirectory
(
04_gemm_bias_relu_add
)
add_subdirectory
(
04_gemm_bias_relu_add
)
add_subdirectory
(
06_conv2d_fwd_bias_relu
)
add_subdirectory
(
06_conv2d_fwd_bias_relu
)
add_subdirectory
(
07_conv2d_fwd_bias_relu_add
)
add_subdirectory
(
07_conv2d_fwd_bias_relu_add
)
...
...
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp
View file @
25e35b59
...
@@ -33,7 +33,7 @@ struct DeviceGemmMultipleD : public BaseOperator
...
@@ -33,7 +33,7 @@ struct DeviceGemmMultipleD : public BaseOperator
ck
::
index_t
StrideE
,
ck
::
index_t
StrideE
,
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
CDEElementwiseOperation
cde_element_op
);
CDEElementwiseOperation
cde_element_op
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
};
...
...
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp
View file @
25e35b59
...
@@ -489,7 +489,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType:
...
@@ -489,7 +489,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType:
ds_grid_desc_mblock_mperblock_nblock_nperblock_
{},
ds_grid_desc_mblock_mperblock_nblock_nperblock_
{},
e_grid_desc_m_n_
{
DeviceOp
::
MakeCGridDescriptor_M_N
(
MRaw
,
NRaw
,
StrideE
)},
e_grid_desc_m_n_
{
DeviceOp
::
MakeCGridDescriptor_M_N
(
MRaw
,
NRaw
,
StrideE
)},
e_grid_desc_mblock_mperblock_nblock_nperblock_
{},
e_grid_desc_mblock_mperblock_nblock_nperblock_
{},
block_2_etile_map_
{
GridwiseGemm
::
MakeDefaultBlock2
C
TileMap
(
e_grid_desc_m_n_
)},
block_2_etile_map_
{
GridwiseGemm
::
MakeDefaultBlock2
E
TileMap
(
e_grid_desc_m_n_
)},
a_element_op_
{
a_element_op
},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
b_element_op_
{
b_element_op
},
cde_element_op_
{
cde_element_op
}
cde_element_op_
{
cde_element_op
}
...
@@ -500,7 +500,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType:
...
@@ -500,7 +500,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType:
block_2_etile_map_
))
block_2_etile_map_
))
{
{
e_grid_desc_mblock_mperblock_nblock_nperblock_
=
e_grid_desc_mblock_mperblock_nblock_nperblock_
=
GridwiseGemm
::
Make
C
GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
GridwiseGemm
::
Make
E
GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
e_grid_desc_m_n_
);
e_grid_desc_m_n_
);
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
...
@@ -512,7 +512,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType:
...
@@ -512,7 +512,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType:
DeviceOp
::
MakeCGridDescriptor_M_N
(
MRaw
,
NRaw
,
StrideDs
[
i
]);
DeviceOp
::
MakeCGridDescriptor_M_N
(
MRaw
,
NRaw
,
StrideDs
[
i
]);
ds_grid_desc_mblock_mperblock_nblock_nperblock_
(
i
)
=
ds_grid_desc_mblock_mperblock_nblock_nperblock_
(
i
)
=
GridwiseGemm
::
Make
C
GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
GridwiseGemm
::
Make
E
GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
d_grid_desc_m_n
);
d_grid_desc_m_n
);
});
});
}
}
...
@@ -538,13 +538,14 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType:
...
@@ -538,13 +538,14 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType:
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
StaticallyIndexedArray
<
StaticallyIndexedArray
<
typename
GridwiseGemm
::
C
GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
E
GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
NumDTensor
>
NumDTensor
>
ds_grid_desc_mblock_mperblock_nblock_nperblock_
;
ds_grid_desc_mblock_mperblock_nblock_nperblock_
;
// FIXME: Ds desc may be of different
// type from E
EGridDesc_M_N
e_grid_desc_m_n_
;
EGridDesc_M_N
e_grid_desc_m_n_
;
typename
GridwiseGemm
::
C
GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
typename
GridwiseGemm
::
E
GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_
;
e_grid_desc_mblock_mperblock_nblock_nperblock_
;
typename
GridwiseGemm
::
DefaultBlock2
C
TileMap
block_2_etile_map_
;
typename
GridwiseGemm
::
DefaultBlock2
E
TileMap
block_2_etile_map_
;
AElementwiseOperation
a_element_op_
;
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
BElementwiseOperation
b_element_op_
;
CDEElementwiseOperation
cde_element_op_
;
CDEElementwiseOperation
cde_element_op_
;
...
@@ -625,10 +626,10 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType:
...
@@ -625,10 +626,10 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType:
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
ck
::
StaticallyIndexedArray
<
ck
::
StaticallyIndexedArray
<
typename
GridwiseGemm
::
C
GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
E
GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
NumDTensor
>
,
NumDTensor
>
,
typename
GridwiseGemm
::
C
GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
E
GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
DefaultBlock2
C
TileMap
,
typename
GridwiseGemm
::
DefaultBlock2
E
TileMap
,
has_main_loop
>
;
has_main_loop
>
;
return
launch_and_time_kernel
(
stream_config
,
return
launch_and_time_kernel
(
stream_config
,
...
@@ -782,7 +783,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType:
...
@@ -782,7 +783,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<DsDataType:
return
str
.
str
();
return
str
.
str
();
}
}
};
// namespace device
};
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
View file @
25e35b59
...
@@ -14,18 +14,24 @@
...
@@ -14,18 +14,24 @@
namespace
ck
{
namespace
ck
{
// input : A[AK0, M, AK1]
// input : B[AK0, N, AK1]
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
template
<
typename
FloatAB
,
template
<
typename
FloatAB
,
typename
FloatGemmAcc
,
typename
FloatGemmAcc
,
typename
FloatCShuffle
,
typename
FloatCShuffle
,
typename
DsDataType
,
typename
DsDataType
,
typename
Float
C
,
typename
Float
E
,
typename
AElementwiseOperation
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
C
DE
ElementwiseOperation
,
InMemoryDataOperationEnum
C
GlobalMemoryDataOperation
,
InMemoryDataOperationEnum
E
GlobalMemoryDataOperation
,
typename
AGridDesc_AK0_M_AK1
,
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
C
GridDesc_M_N
,
typename
E
GridDesc_M_N
,
index_t
NumGemmKPrefetchStage
,
index_t
NumGemmKPrefetchStage
,
index_t
BlockSize
,
index_t
BlockSize
,
index_t
MPerBlock
,
index_t
MPerBlock
,
...
@@ -55,8 +61,8 @@ template <typename FloatAB,
...
@@ -55,8 +61,8 @@ template <typename FloatAB,
index_t
BBlockLdsExtraN
,
index_t
BBlockLdsExtraN
,
index_t
CShuffleMXdlPerWavePerShuffle
,
index_t
CShuffleMXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
C
Shuffle
BlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
C
DE
BlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
index_t
C
DE
ShuffleBlockTransferScalarPerVector_NPerBlock
,
LoopScheduler
LoopSched
>
LoopScheduler
LoopSched
>
struct
GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
struct
GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
{
{
...
@@ -153,12 +159,12 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
...
@@ -153,12 +159,12 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
}
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template
<
typename
Block2
C
TileMap
>
template
<
typename
Block2
E
TileMap
>
__host__
__device__
static
constexpr
bool
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
AGridDesc_AK0_M_AK1
&
a_grid_desc_ak0_m_ak1
,
CheckValidity
(
const
AGridDesc_AK0_M_AK1
&
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
&
b_grid_desc_bk0_n_bk1
,
const
BGridDesc_BK0_N_BK1
&
b_grid_desc_bk0_n_bk1
,
const
C
GridDesc_M_N
&
c
_grid_desc_m_n
,
const
E
GridDesc_M_N
&
e
_grid_desc_m_n
,
const
Block2
C
TileMap
&
block_2_
c
tile_map
)
const
Block2
E
TileMap
&
block_2_
e
tile_map
)
{
{
static_assert
((
MPerBlock
%
(
MPerXdl
*
MXdlPerWave
)
==
0
)
&&
static_assert
((
MPerBlock
%
(
MPerXdl
*
MXdlPerWave
)
==
0
)
&&
(
NPerBlock
%
(
NXdlPerWave
*
NPerXdl
))
==
0
,
(
NPerBlock
%
(
NXdlPerWave
*
NPerXdl
))
==
0
,
...
@@ -168,7 +174,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
...
@@ -168,7 +174,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
const
auto
N
=
b_grid_desc_bk0_n_bk1
.
GetLength
(
I1
);
const
auto
N
=
b_grid_desc_bk0_n_bk1
.
GetLength
(
I1
);
const
auto
K
=
a_grid_desc_ak0_m_ak1
.
GetLength
(
I0
)
*
a_grid_desc_ak0_m_ak1
.
GetLength
(
I2
);
const
auto
K
=
a_grid_desc_ak0_m_ak1
.
GetLength
(
I0
)
*
a_grid_desc_ak0_m_ak1
.
GetLength
(
I2
);
if
(
!
(
M
==
c
_grid_desc_m_n
.
GetLength
(
I0
)
&&
N
==
c
_grid_desc_m_n
.
GetLength
(
I1
)))
if
(
!
(
M
==
e
_grid_desc_m_n
.
GetLength
(
I0
)
&&
N
==
e
_grid_desc_m_n
.
GetLength
(
I1
)))
return
false
;
return
false
;
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K
%
KPerBlock
==
0
))
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K
%
KPerBlock
==
0
))
...
@@ -182,7 +188,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
...
@@ -182,7 +188,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
return
false
;
return
false
;
}
}
if
(
!
block_2_
c
tile_map
.
CheckValidity
(
c
_grid_desc_m_n
))
if
(
!
block_2_
e
tile_map
.
CheckValidity
(
e
_grid_desc_m_n
))
{
{
return
false
;
return
false
;
}
}
...
@@ -199,58 +205,59 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
...
@@ -199,58 +205,59 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
}
}
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
Make
C
GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
const
C
GridDesc_M_N
&
c
_grid_desc_m_n
)
Make
E
GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
const
E
GridDesc_M_N
&
e
_grid_desc_m_n
)
{
{
const
auto
M
=
c
_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
M
=
e
_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
c
_grid_desc_m_n
.
GetLength
(
I1
);
const
auto
N
=
e
_grid_desc_m_n
.
GetLength
(
I1
);
const
auto
MBlock
=
M
/
MPerBlock
;
const
auto
MBlock
=
M
/
MPerBlock
;
const
auto
NBlock
=
N
/
NPerBlock
;
const
auto
NBlock
=
N
/
NPerBlock
;
const
auto
c
_grid_desc_mblock_mperblock_nblock_nperblock
=
transform_tensor_descriptor
(
const
auto
e
_grid_desc_mblock_mperblock_nblock_nperblock
=
transform_tensor_descriptor
(
c
_grid_desc_m_n
,
e
_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MBlock
,
Number
<
MPerBlock
>
{})),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MBlock
,
Number
<
MPerBlock
>
{})),
make_unmerge_transform
(
make_tuple
(
NBlock
,
Number
<
NPerBlock
>
{}))),
make_unmerge_transform
(
make_tuple
(
NBlock
,
Number
<
NPerBlock
>
{}))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
,
3
>
{}));
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
,
3
>
{}));
return
c
_grid_desc_mblock_mperblock_nblock_nperblock
;
return
e
_grid_desc_mblock_mperblock_nblock_nperblock
;
}
}
// return block_id to C matrix tile idx (m0, n0) mapping
// return block_id to C matrix tile idx (m0, n0) mapping
__host__
__device__
static
constexpr
auto
__host__
__device__
static
constexpr
auto
MakeDefaultBlock2
C
TileMap
(
const
C
GridDesc_M_N
&
c
_grid_desc_m_n
)
MakeDefaultBlock2
E
TileMap
(
const
E
GridDesc_M_N
&
e
_grid_desc_m_n
)
{
{
return
BlockToCTileMap_M00_N0_M01Adapt
<
MPerBlock
,
NPerBlock
,
C
GridDesc_M_N
>
(
return
BlockToCTileMap_M00_N0_M01Adapt
<
MPerBlock
,
NPerBlock
,
E
GridDesc_M_N
>
(
c
_grid_desc_m_n
);
e
_grid_desc_m_n
);
}
}
using
C
GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
using
E
GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
Make
C
GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
C
GridDesc_M_N
{}))
>
;
Make
E
GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
E
GridDesc_M_N
{}))
>
;
using
DefaultBlock2
C
TileMap
=
using
DefaultBlock2
E
TileMap
=
remove_cvref_t
<
decltype
(
MakeDefaultBlock2
C
TileMap
(
C
GridDesc_M_N
{}))
>
;
remove_cvref_t
<
decltype
(
MakeDefaultBlock2
E
TileMap
(
E
GridDesc_M_N
{}))
>
;
using
DsGridPointer
=
decltype
(
MakeDsGridPointer
());
using
DsGridPointer
=
decltype
(
MakeDsGridPointer
());
template
<
bool
HasMainKBlockLoop
,
typename
Block2
C
TileMap
>
template
<
bool
HasMainKBlockLoop
,
typename
Block2
E
TileMap
>
__device__
static
void
__device__
static
void
Run
(
const
FloatAB
*
__restrict__
p_a_grid
,
Run
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
DsGridPointer
p_ds_grid
,
DsGridPointer
p_ds_grid
,
Float
C
*
__restrict__
p_
c
_grid
,
Float
E
*
__restrict__
p_
e
_grid
,
void
*
__restrict__
p_shared
,
void
*
__restrict__
p_shared
,
const
AElementwiseOperation
&
a_element_op
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CElementwiseOperation
&
c_element_op
,
const
C
DE
ElementwiseOperation
&
c
de
_element_op
,
const
AGridDesc_AK0_M_AK1
&
a_grid_desc_ak0_m_ak1
,
const
AGridDesc_AK0_M_AK1
&
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
&
b_grid_desc_bk0_n_bk1
,
const
BGridDesc_BK0_N_BK1
&
b_grid_desc_bk0_n_bk1
,
const
StaticallyIndexedArray
<
C
GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
const
StaticallyIndexedArray
<
E
GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
NumDTensor
>&
NumDTensor
>&
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
// FIXME: use tuple
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
// FIXME: Ds desc may be of different
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
// type from E
c_grid_desc_mblock_mperblock_nblock_nperblock
,
const
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
&
const
Block2CTileMap
&
block_2_ctile_map
)
e_grid_desc_mblock_mperblock_nblock_nperblock
,
const
Block2ETileMap
&
block_2_etile_map
)
{
{
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_a_grid
,
a_grid_desc_ak0_m_ak1
.
GetElementSpaceSize
());
p_a_grid
,
a_grid_desc_ak0_m_ak1
.
GetElementSpaceSize
());
...
@@ -266,17 +273,17 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
...
@@ -266,17 +273,17 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
},
},
Number
<
NumDTensor
>
{});
Number
<
NumDTensor
>
{});
auto
c
_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
auto
e
_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_
c
_grid
,
c
_grid_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
p_
e
_grid
,
e
_grid_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
());
// divide block work by [M, N]
// divide block work by [M, N]
const
auto
block_work_idx
=
const
auto
block_work_idx
=
block_2_
c
tile_map
.
CalculateBottomIndex
(
make_multi_index
(
get_block_1d_id
()));
block_2_
e
tile_map
.
CalculateBottomIndex
(
make_multi_index
(
get_block_1d_id
()));
if
(
!
block_2_
c
tile_map
.
ValidCTileIndex
(
if
(
!
block_2_
e
tile_map
.
ValidCTileIndex
(
block_work_idx
,
block_work_idx
,
make_tuple
(
c
_grid_desc_mblock_mperblock_nblock_nperblock
.
GetLength
(
I0
),
make_tuple
(
e
_grid_desc_mblock_mperblock_nblock_nperblock
.
GetLength
(
I0
),
c
_grid_desc_mblock_mperblock_nblock_nperblock
.
GetLength
(
I2
))))
e
_grid_desc_mblock_mperblock_nblock_nperblock
.
GetLength
(
I2
))))
{
{
return
;
return
;
}
}
...
@@ -537,27 +544,27 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
...
@@ -537,27 +544,27 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
// shuffle: blockwise copy C from LDS to global
// shuffle: blockwise copy C from LDS to global
#if 1
#if 1
auto
c
_shuffl
e_block_copy_lds_
to
_global
=
ThreadGroupTensorSliceTransfer_v6r3
<
auto
c
d
e_block_copy_lds_
and
_global
=
ThreadGroupTensorSliceTransfer_v6r3
<
ThisThreadBlock
,
// ThreadGroup
ThisThreadBlock
,
// ThreadGroup
CElementwiseOperation
,
// ElementwiseOperation,
C
DE
ElementwiseOperation
,
// ElementwiseOperation,
C
GlobalMemoryDataOperation
,
// DstInMemOp,
E
GlobalMemoryDataOperation
,
// DstInMemOp,
Sequence
<
1
,
Sequence
<
1
,
CShuffleMXdlPerWavePerShuffle
*
MWave
*
MPerXdl
,
CShuffleMXdlPerWavePerShuffle
*
MWave
*
MPerXdl
,
1
,
1
,
CShuffleNXdlPerWavePerShuffle
*
NWave
*
NPerXdl
>
,
// BlockSliceLengths,
CShuffleNXdlPerWavePerShuffle
*
NWave
*
NPerXdl
>
,
// BlockSliceLengths,
C
Shuffle
BlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
C
DE
BlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename ThreadClusterArrangeOrder,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename ThreadClusterArrangeOrder,
FloatCShuffle
,
// typename Src0Data,
FloatCShuffle
,
// typename Src0Data,
remove_cvref_t
<
decltype
(
DsDataType
{}[
I0
])
>
,
// typename Src1Data,
remove_cvref_t
<
decltype
(
DsDataType
{}[
I0
])
>
,
// typename Src1Data,
remove_cvref_t
<
decltype
(
DsDataType
{}[
I1
])
>
,
// typename Src2Data,
remove_cvref_t
<
decltype
(
DsDataType
{}[
I1
])
>
,
// typename Src2Data,
Float
C
,
// typename DstData,
Float
E
,
// typename DstData,
decltype
(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
),
decltype
(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
),
decltype
(
ds_grid_desc_mblock_mperblock_nblock_nperblock
[
I0
]),
decltype
(
ds_grid_desc_mblock_mperblock_nblock_nperblock
[
I0
]),
decltype
(
ds_grid_desc_mblock_mperblock_nblock_nperblock
[
I1
]),
decltype
(
ds_grid_desc_mblock_mperblock_nblock_nperblock
[
I1
]),
decltype
(
c
_grid_desc_mblock_mperblock_nblock_nperblock
),
decltype
(
e
_grid_desc_mblock_mperblock_nblock_nperblock
),
Sequence
<
0
,
1
,
2
,
3
>
,
// typename DimAccessOrder,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename DimAccessOrder,
3
,
// index_t VectorDim,
3
,
// index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
// index_t ScalarPerVector,
C
DE
ShuffleBlockTransferScalarPerVector_NPerBlock
,
// index_t ScalarPerVector,
true
,
// bool ThreadTransferSrc0ResetCoordinateAfterRun,
true
,
// bool ThreadTransferSrc0ResetCoordinateAfterRun,
false
,
// bool ThreadTransferSrc1ResetCoordinateAfterRun,
false
,
// bool ThreadTransferSrc1ResetCoordinateAfterRun,
false
,
// bool ThreadTransferSrc2ResetCoordinateAfterRun,
false
,
// bool ThreadTransferSrc2ResetCoordinateAfterRun,
...
@@ -568,37 +575,37 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
...
@@ -568,37 +575,37 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
make_multi_index
(
block_work_idx
[
I0
],
0
,
block_work_idx
[
I1
],
0
),
make_multi_index
(
block_work_idx
[
I0
],
0
,
block_work_idx
[
I1
],
0
),
ds_grid_desc_mblock_mperblock_nblock_nperblock
[
I1
],
ds_grid_desc_mblock_mperblock_nblock_nperblock
[
I1
],
make_multi_index
(
block_work_idx
[
I0
],
0
,
block_work_idx
[
I1
],
0
),
make_multi_index
(
block_work_idx
[
I0
],
0
,
block_work_idx
[
I1
],
0
),
c
_grid_desc_mblock_mperblock_nblock_nperblock
,
e
_grid_desc_mblock_mperblock_nblock_nperblock
,
make_multi_index
(
block_work_idx
[
I0
],
0
,
block_work_idx
[
I1
],
0
),
make_multi_index
(
block_work_idx
[
I0
],
0
,
block_work_idx
[
I1
],
0
),
c_element_op
};
c
de
_element_op
};
#else
#else
auto
c
_shuffl
e_block_copy_lds_
to
_global
=
ThreadGroupTensorSliceTransfer_v6r1
<
auto
c
d
e_block_copy_lds_
and
_global
=
ThreadGroupTensorSliceTransfer_v6r1
<
ThisThreadBlock
,
// ThreadGroup
ThisThreadBlock
,
// ThreadGroup
CElementwiseOperation
,
// ElementwiseOperation,
C
DE
ElementwiseOperation
,
// ElementwiseOperation,
C
GlobalMemoryDataOperation
,
// DstInMemOp,
E
GlobalMemoryDataOperation
,
// DstInMemOp,
Sequence
<
1
,
Sequence
<
1
,
CShuffleMXdlPerWavePerShuffle
*
MWave
*
MPerXdl
,
CShuffleMXdlPerWavePerShuffle
*
MWave
*
MPerXdl
,
1
,
1
,
CShuffleNXdlPerWavePerShuffle
*
NWave
*
NPerXdl
>
,
// BlockSliceLengths,
CShuffleNXdlPerWavePerShuffle
*
NWave
*
NPerXdl
>
,
// BlockSliceLengths,
C
Shuffle
BlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
C
DE
BlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename ThreadClusterArrangeOrder,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename ThreadClusterArrangeOrder,
FloatCShuffle
,
// typename Src0Data,
FloatCShuffle
,
// typename Src0Data,
Float
C
,
// typename DstData,
Float
E
,
// typename DstData,
decltype
(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
),
decltype
(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
),
decltype
(
c
_grid_desc_mblock_mperblock_nblock_nperblock
),
decltype
(
e
_grid_desc_mblock_mperblock_nblock_nperblock
),
Sequence
<
0
,
1
,
2
,
3
>
,
// typename DimAccessOrder,
Sequence
<
0
,
1
,
2
,
3
>
,
// typename DimAccessOrder,
3
,
// index_t VectorDim,
3
,
// index_t VectorDim,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
// index_t ScalarPerVector,
C
DE
ShuffleBlockTransferScalarPerVector_NPerBlock
,
// index_t ScalarPerVector,
true
,
// bool ThreadTransferSrc0ResetCoordinateAfterRun,
true
,
// bool ThreadTransferSrc0ResetCoordinateAfterRun,
false
>
// bool ThreadTransferDstResetCoordinateAfterRun>
false
>
// bool ThreadTransferDstResetCoordinateAfterRun>
{
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
,
{
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
,
make_multi_index
(
0
,
0
,
0
,
0
),
make_multi_index
(
0
,
0
,
0
,
0
),
c
_grid_desc_mblock_mperblock_nblock_nperblock
,
e
_grid_desc_mblock_mperblock_nblock_nperblock
,
make_multi_index
(
block_work_idx
[
I0
],
0
,
block_work_idx
[
I1
],
0
),
make_multi_index
(
block_work_idx
[
I0
],
0
,
block_work_idx
[
I1
],
0
),
c_element_op
};
c
de
_element_op
};
#endif
#endif
// space filling curve for threadwise C in VGPR
// space filling curve for threadwise C in VGPR
before shuffle
constexpr
auto
sfc_c_vgpr
=
constexpr
auto
sfc_c_vgpr
=
SpaceFillingCurve
<
Sequence
<
MXdlPerWave
,
NXdlPerWave
,
1
,
1
,
M2
,
1
,
M4
,
1
>
,
SpaceFillingCurve
<
Sequence
<
MXdlPerWave
,
NXdlPerWave
,
1
,
1
,
M2
,
1
,
M4
,
1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
,
7
>
,
...
@@ -611,8 +618,8 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
...
@@ -611,8 +618,8 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
M4
,
M4
,
1
>>
{};
1
>>
{};
// space filling curve for shuffled blockwise C
in global mem
// space filling curve for shuffled blockwise C
/D/E
constexpr
auto
sfc_c
_global
=
constexpr
auto
sfc_c
de_block
=
SpaceFillingCurve
<
Sequence
<
1
,
MPerBlock
,
1
,
NPerBlock
>
,
SpaceFillingCurve
<
Sequence
<
1
,
MPerBlock
,
1
,
NPerBlock
>
,
Sequence
<
0
,
2
,
1
,
3
>
,
Sequence
<
0
,
2
,
1
,
3
>
,
Sequence
<
1
,
Sequence
<
1
,
...
@@ -622,7 +629,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
...
@@ -622,7 +629,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
constexpr
index_t
num_access
=
sfc_c_vgpr
.
GetNumOfAccess
();
constexpr
index_t
num_access
=
sfc_c_vgpr
.
GetNumOfAccess
();
static_assert
(
num_access
==
sfc_c
_global
.
GetNumOfAccess
(),
"wrong!"
);
static_assert
(
num_access
==
sfc_c
de_block
.
GetNumOfAccess
(),
"wrong!"
);
static_for
<
0
,
num_access
,
1
>
{}([
&
](
auto
access_id
)
{
static_for
<
0
,
num_access
,
1
>
{}([
&
](
auto
access_id
)
{
// make sure it's safe to write to LDS
// make sure it's safe to write to LDS
...
@@ -640,37 +647,37 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
...
@@ -640,37 +647,37 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_xdl_cshuffle
// each block copy its data from LDS to global
// each block copy its data from LDS to global
#if 1
#if 1
c
_shuffl
e_block_copy_lds_
to
_global
.
Run
(
c
d
e_block_copy_lds_
and
_global
.
Run
(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
,
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
,
c_shuffle_block_buf
,
c_shuffle_block_buf
,
ds_grid_desc_mblock_mperblock_nblock_nperblock
[
I0
],
ds_grid_desc_mblock_mperblock_nblock_nperblock
[
I0
],
ds_grid_buf
[
I0
],
ds_grid_buf
[
I0
],
ds_grid_desc_mblock_mperblock_nblock_nperblock
[
I1
],
ds_grid_desc_mblock_mperblock_nblock_nperblock
[
I1
],
ds_grid_buf
[
I1
],
ds_grid_buf
[
I1
],
c
_grid_desc_mblock_mperblock_nblock_nperblock
,
e
_grid_desc_mblock_mperblock_nblock_nperblock
,
c
_grid_buf
);
e
_grid_buf
);
#else
#else
c
_shuffl
e_block_copy_lds_
to
_global
.
Run
(
c
d
e_block_copy_lds_
and
_global
.
Run
(
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
,
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
,
c_shuffle_block_buf
,
c_shuffle_block_buf
,
c
_grid_desc_mblock_mperblock_nblock_nperblock
,
e
_grid_desc_mblock_mperblock_nblock_nperblock
,
c
_grid_buf
);
e
_grid_buf
);
#endif
#endif
if
constexpr
(
access_id
<
num_access
-
1
)
if
constexpr
(
access_id
<
num_access
-
1
)
{
{
constexpr
auto
c_global_step
=
sfc_c
_global
.
GetForwardStep
(
access_id
);
constexpr
auto
c_global_step
=
sfc_c
de_block
.
GetForwardStep
(
access_id
);
// move on Ds
// move on Ds
c
_shuffl
e_block_copy_lds_
to
_global
.
MoveSrc1SliceWindow
(
c
d
e_block_copy_lds_
and
_global
.
MoveSrc1SliceWindow
(
ds_grid_desc_mblock_mperblock_nblock_nperblock
[
I0
],
c_global_step
);
ds_grid_desc_mblock_mperblock_nblock_nperblock
[
I0
],
c_global_step
);
c
_shuffl
e_block_copy_lds_
to
_global
.
MoveSrc2SliceWindow
(
c
d
e_block_copy_lds_
and
_global
.
MoveSrc2SliceWindow
(
ds_grid_desc_mblock_mperblock_nblock_nperblock
[
I1
],
c_global_step
);
ds_grid_desc_mblock_mperblock_nblock_nperblock
[
I1
],
c_global_step
);
// move on
C
// move on
E
c
_shuffl
e_block_copy_lds_
to
_global
.
MoveDstSliceWindow
(
c
d
e_block_copy_lds_
and
_global
.
MoveDstSliceWindow
(
c
_grid_desc_mblock_mperblock_nblock_nperblock
,
c_global_step
);
e
_grid_desc_mblock_mperblock_nblock_nperblock
,
c_global_step
);
}
}
});
});
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment