gaoqiong / composable_kernel_ROCM · Commits

Commit 5ec6a912, authored Jun 27, 2024 by Jun Liu
Merge branch 'develop' into amd-develop
Parents: d39c3f5d, 3bb0fe6c

Changes: 226 changed files in total; the 20 files shown below account for 1619 additions and 324 deletions (+1619 −324).
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp  +4 −3
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp  +356 −188
include/ck/tensor_operation/gpu/device/impl/device_grouped_query_attention_forward_wmma.hpp  +3 −2
include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp  +3 −2
include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp  +3 −2
include/ck/tensor_operation/gpu/device/matrix_padder.hpp  +13 −0
include/ck/tensor_operation/gpu/element/element_wise_operation.hpp  +0 −20
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp  +41 −0
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp  +45 −0
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp  +14 −6
include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp  +15 −7
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp  +39 −31
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp  +27 −15
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp  +83 −29
include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp  +3 −2
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp  +108 −1
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp  +146 −1
include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp  +645 −14
include/ck/utility/amd_buffer_addressing.hpp  +2 −1
include/ck/utility/amd_smfmac.hpp  +69 −0
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_dl.hpp

@@ -40,7 +40,8 @@ __global__ void
     const CDEElementwiseOperation cde_element_op)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
-    defined(__gfx90a__) || defined(__gfx103__) || defined(__gfx11__) || defined(__gfx94__))
+    defined(__gfx90a__) || defined(__gfx103__) || defined(__gfx11__) || defined(__gfx94__) || \
+    defined(__gfx12__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     const index_t block_id = get_block_1d_id();

@@ -673,7 +674,7 @@ struct DeviceGroupedGemmMultipleD_Dl : public DeviceGroupedGemm<ALayout,
         }

         if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
-           ck::is_gfx103_supported() || ck::is_gfx11_supported())
+           ck::is_gfx103_supported() || ck::is_gfx11_supported() || ck::is_gfx12_supported())
         {
             for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
             {
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp

@@ -19,6 +19,7 @@
 #include "ck/tensor_operation/gpu/device/device_grouped_gemm_tile_loop.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
+#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp" // stare wywalic (old, to be removed)
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
@@ -42,16 +43,22 @@ namespace device {
 template <typename GridwiseGemm,
           typename GemmDesc,
           GemmSpecialization GemmSpec,
           typename ADataType,
           typename BDataType,
           typename DsDataType,
           typename EDataType,
           typename ALayout,
           typename BLayout,
           typename DsLayout,
           typename ELayout,
           index_t KPerBlock,
           typename OffsettedBlockToCTileMap,
           typename LocalBlock2ETileMap,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
-          typename CDEElementwiseOperation>
+          typename CDEElementwiseOperation,
+          BlockGemmPipelineScheduler BlkGemmPipeSched,
+          BlockGemmPipelineVersion BlkGemmPipelineVer>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
@@ -67,6 +74,7 @@ __global__ void
     constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();

     __shared__ uint8_t p_shared[shared_size];
+    __shared__ uint8_t p_shared1[shared_size];

     const auto gemm_desc_ptr =
         reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
@@ -81,27 +89,8 @@ __global__ void
     index_t gemm_tile_id_start = 0;
     index_t gemm_tile_id_end   = 0;

-    using AGridDescMK = remove_cvref_t<decltype(
-        GridwiseGemm::template MakeAGridDescriptor_M_K<ALayout, GemmSpec>(1, 1, 1))>;
-    using BGridDescNK = remove_cvref_t<decltype(
-        GridwiseGemm::template MakeBGridDescriptor_N_K<BLayout, GemmSpec>(1, 1, 1))>;
-    using EGridDescMN = remove_cvref_t<decltype(
-        GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(1, 1, 1))>;
-    using DsGridDescMN = remove_cvref_t<decltype(
-        GridwiseGemm::template MakeDsGridDescriptor_M_N<DsLayout, GemmSpec>({}, {}, {}))>;
-
     index_t M = 0, N = 0, K = 0;
     index_t StrideA, StrideB, StrideE;
     std::array<index_t, NumDTensor> StrideDs;

-    AGridDescMK a_grid_desc_mk;
-    BGridDescNK b_grid_desc_nk;
-    EGridDescMN e_grid_desc_mn;
-    DsGridDescMN ds_grid_desc_mn;
-
     auto b2c_tile_map = OffsettedBlockToCTileMap(LocalBlock2ETileMap(1, 1), 1, 1);

     do

@@ -127,31 +116,13 @@ __global__ void
             }

-            b2c_tile_map = OffsettedBlockToCTileMap(LocalBlock2ETileMap(M, N), group_offset, tile_offset);
+            b2c_tile_map = OffsettedBlockToCTileMap(LocalBlock2ETileMap(M, N, 4), group_offset, tile_offset);
             grid_size_grp = b2c_tile_map.CalculateGridSize(M, N);

             gemm_tile_id_start = group_offset;
             gemm_tile_id_end   = group_offset + grid_size_grp;
         }

-        StrideA  = gemm_desc_ptr[group_id].StrideA;
-        StrideB  = gemm_desc_ptr[group_id].StrideB;
-        StrideDs = gemm_desc_ptr[group_id].StrideDs;
-        StrideE  = gemm_desc_ptr[group_id].StrideE;
-
-        a_grid_desc_mk = GridwiseGemm::template MakeAGridDescriptor_M_K<ALayout, GemmSpec>(M, K, StrideA);
-        b_grid_desc_nk = GridwiseGemm::template MakeBGridDescriptor_N_K<BLayout, GemmSpec>(K, N, StrideB);
-        e_grid_desc_mn = GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
-
-        static_for<0, NumDTensor, 1>{}([&](auto j) {
-            using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
-            ds_grid_desc_mn(j) =
-                GridwiseGemm::template MakeEGridDescriptor_M_N<DLayout, GemmSpec>(M, N, StrideDs[j]);
-        });
-
         using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer());
         DsGridPointer p_ds_grid;
@@ -160,43 +131,269 @@ __global__ void
             p_ds_grid(i) = static_cast<const DDataType*>(gemm_desc_ptr[group_id].p_ds_grid[i]);
         });

-        bool has_main_kblock_loop =
-            GridwiseGemm::CalculateHasMainKBlockLoop(a_grid_desc_mk.GetLength(Number<1>{}));
+        static constexpr index_t kbatch  = 1;
+        static constexpr index_t k_grain = kbatch * KPerBlock;
+        index_t K_split                  = (K + k_grain - 1) / k_grain * KPerBlock;
+        const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
+
+        // Update tile offset if we have moved within group
+        b2c_tile_map.UpdateTileOffset(tile_offset);

-        if(has_main_kblock_loop)
+        using Problem = typename GridwiseGemm::Problem;
+        auto problem  = Problem(gemm_desc_ptr[group_id].M,
+                                gemm_desc_ptr[group_id].N,
+                                gemm_desc_ptr[group_id].K,
+                                gemm_desc_ptr[group_id].StrideA,
+                                gemm_desc_ptr[group_id].StrideB,
+                                gemm_desc_ptr[group_id].StrideDs,
+                                gemm_desc_ptr[group_id].StrideE,
+                                kbatch);
+
+        if(has_main_k_block_loop)
         {
+            if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
+                         BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
+            {
+                GridwiseGemm::template Run<OffsettedBlockToCTileMap, true, InMemoryDataOperationEnum::Set, TailNumber::Full>(
+                    static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
+                    static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
+                    p_ds_grid,
+                    static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
+                    static_cast<void*>(p_shared),
+                    problem, a_element_op, b_element_op, cde_element_op, b2c_tile_map);
+            }
+            else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
+            {
+                if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
+                {
+                    GridwiseGemm::template Run<OffsettedBlockToCTileMap, true, InMemoryDataOperationEnum::Set, TailNumber::One>(
+                        static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
+                        static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
+                        p_ds_grid,
+                        static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
+                        static_cast<void*>(p_shared),
+                        problem, a_element_op, b_element_op, cde_element_op, b2c_tile_map);
+                }
+                else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full)
+                {
+                    GridwiseGemm::template Run<OffsettedBlockToCTileMap, true, InMemoryDataOperationEnum::Set, TailNumber::Full>(
+                        static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
+                        static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
+                        p_ds_grid,
+                        static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
+                        static_cast<void*>(p_shared),
+                        problem, a_element_op, b_element_op, cde_element_op, b2c_tile_map);
+                }
+                if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
+                {
+                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
+                    {
+                        GridwiseGemm::template Run<OffsettedBlockToCTileMap, true, InMemoryDataOperationEnum::Set, TailNumber::Two>(
+                            static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
+                            static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
+                            p_ds_grid,
+                            static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
+                            static_cast<void*>(p_shared),
+                            problem, a_element_op, b_element_op, cde_element_op, b2c_tile_map);
+                    }
+                }
+                if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
+                {
+                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Three)
+                    {
+                        GridwiseGemm::template Run<OffsettedBlockToCTileMap, true, InMemoryDataOperationEnum::Set, TailNumber::Three>(
+                            static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
+                            static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
+                            p_ds_grid,
+                            static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
+                            static_cast<void*>(p_shared),
+                            problem, a_element_op, b_element_op, cde_element_op, b2c_tile_map);
+                    }
+                }
+                if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
+                {
+                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Four)
+                    {
+                        GridwiseGemm::template Run<OffsettedBlockToCTileMap, true, InMemoryDataOperationEnum::Set, TailNumber::Four>(
+                            static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
+                            static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
+                            p_ds_grid,
+                            static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
+                            static_cast<void*>(p_shared),
+                            problem, a_element_op, b_element_op, cde_element_op, b2c_tile_map);
+                    }
+                }
+                if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
+                {
+                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Five)
+                    {
+                        GridwiseGemm::template Run<OffsettedBlockToCTileMap, true, InMemoryDataOperationEnum::Set, TailNumber::Five>(
+                            static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
+                            static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
+                            p_ds_grid,
+                            static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
+                            static_cast<void*>(p_shared),
+                            problem, a_element_op, b_element_op, cde_element_op, b2c_tile_map);
+                    }
+                }
+                if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
+                {
+                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
+                    {
+                        GridwiseGemm::template Run<OffsettedBlockToCTileMap, true, InMemoryDataOperationEnum::Set, TailNumber::Six>(
+                            static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
+                            static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
+                            p_ds_grid,
+                            static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
+                            static_cast<void*>(p_shared),
+                            problem, a_element_op, b_element_op, cde_element_op, b2c_tile_map);
+                    }
+                }
+                if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
+                {
+                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Seven)
+                    {
+                        GridwiseGemm::template Run<OffsettedBlockToCTileMap, true, InMemoryDataOperationEnum::Set, TailNumber::Seven>(
+                            static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
+                            static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
+                            p_ds_grid,
+                            static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
+                            static_cast<void*>(p_shared),
+                            problem, a_element_op, b_element_op, cde_element_op, b2c_tile_map);
+                    }
+                }
+            }
+            // Tail number could be Odd or Even
+            else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
+            {
+                if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
+                {
+                    GridwiseGemm::template Run_2Lds<OffsettedBlockToCTileMap, true, InMemoryDataOperationEnum::Set, TailNumber::Odd>(
+                        static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
+                        static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
+                        p_ds_grid,
+                        static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
+                        static_cast<void*>(p_shared),
+                        static_cast<void*>(p_shared1),
+                        problem, a_element_op, b_element_op, cde_element_op, b2c_tile_map);
+                }
+                else
+                {
-                    GridwiseGemm::template Run<true>(gemm_desc_ptr[group_id].p_a_grid,
-                                                     gemm_desc_ptr[group_id].p_b_grid,
+                    GridwiseGemm::template Run_2Lds<OffsettedBlockToCTileMap, true, InMemoryDataOperationEnum::Set, TailNumber::Even>(
+                        static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
+                        static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
+                        p_ds_grid,
-                        gemm_desc_ptr[group_id].p_e_grid,
+                        static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
+                        static_cast<void*>(p_shared),
+                        static_cast<void*>(p_shared1),
+                        problem,
+                        a_element_op,
+                        b_element_op,
+                        cde_element_op,
-                        a_grid_desc_mk,
-                        b_grid_desc_nk,
-                        ds_grid_desc_mn,
-                        e_grid_desc_mn,
+                        b2c_tile_map);
+                }
+            }
         }
         else
         {
-            GridwiseGemm::template Run<false>(gemm_desc_ptr[group_id].p_a_grid,
-                                              gemm_desc_ptr[group_id].p_b_grid,
+            if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
+            {
+                GridwiseGemm::template Run<OffsettedBlockToCTileMap, false, InMemoryDataOperationEnum::Set, TailNumber::Full>(
+                    static_cast<const ADataType*>(gemm_desc_ptr[group_id].p_a_grid),
+                    static_cast<const BDataType*>(gemm_desc_ptr[group_id].p_b_grid),
+                    p_ds_grid,
-                    gemm_desc_ptr[group_id].p_e_grid,
+                    static_cast<EDataType*>(gemm_desc_ptr[group_id].p_e_grid),
+                    static_cast<void*>(p_shared),
+                    problem,
+                    a_element_op,
+                    b_element_op,
+                    cde_element_op,
-                    a_grid_desc_mk,
-                    b_grid_desc_nk,
-                    ds_grid_desc_mn,
-                    e_grid_desc_mn,
+                    b2c_tile_map);
+            }
         }

         tile_id += get_grid_size();
         tile_offset += get_grid_size();
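The kernel body above selects a Run instantiation whose TailNumber template argument matches what CalculateKBlockLoopTailNum reports at run time, with each branch guarded by the pipeline's compile-time PrefetchStages so that unreachable specializations are never instantiated. A stripped-down sketch of that dispatch pattern follows; the Tail enum, the Pipe class, and the printed message are illustrative stand-ins, not CK's actual types.

// Minimal sketch of the compile-time + run-time tail dispatch (not CK code).
#include <cstdio>

enum class Tail { One, Two, Full };

template <int PrefetchStages>
struct Pipe
{
    template <Tail T>
    static void run() { std::printf("run with tail %d\n", static_cast<int>(T)); }

    static void dispatch(Tail tail)
    {
        if(tail == Tail::One) { run<Tail::One>(); }
        else if(tail == Tail::Full) { run<Tail::Full>(); }
        // Tail::Two is only reachable (and only compiled) when the pipeline prefetches deep enough.
        if constexpr(PrefetchStages > 2)
        {
            if(tail == Tail::Two) { run<Tail::Two>(); }
        }
    }
};

int main()
{
    Pipe<3>::dispatch(Tail::Two); // compiles the Tail::Two branch because PrefetchStages > 2
    return 0;
}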
@@ -253,10 +450,12 @@ template <typename ALayout,
           index_t CShuffleMXdlPerWavePerShuffle,
           index_t CShuffleNXdlPerWavePerShuffle,
           typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
-          index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
-          PipelineVersion PipelineVer = PipelineVersion::v1,
-          LoopScheduler LoopSched     = make_default_loop_scheduler(),
-          typename ComputeDataType    = EDataType>
+          typename CDEShuffleBlockTransferScalarPerVectors,
+          BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
+          BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
+          typename ComputeTypeA = EDataType,
+          typename ComputeTypeB = ComputeTypeA>
 struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
     : public DeviceGroupedGemmTileLoop<ALayout,
                                        BLayout,

@@ -273,10 +472,13 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
     using DeviceOp = DeviceGroupedGemmMultipleDXdlCShuffleTileLoop;

     static constexpr index_t NumDTensor = DsDataType::Size();

-    using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
+    using GridwiseGemm = GridwiseGemmMultiD_xdl_cshuffle_v3<
+        ALayout,
+        BLayout,
+        DsLayout,
+        ELayout,
         ADataType,
         BDataType,
-        ComputeDataType,
         AccDataType,
         CShuffleDataType,
         DsDataType,

@@ -284,8 +486,7 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
         AElementwiseOperation,
         BElementwiseOperation,
         CDEElementwiseOperation,
-        InMemoryDataOperationEnum::Set,
-        NumGemmKPrefetchStage,
+        GemmSpec,
         BlockSize,
         MPerBlock,
         NPerBlock,
@@ -315,58 +516,15 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
         CShuffleMXdlPerWavePerShuffle,
         CShuffleNXdlPerWavePerShuffle,
         CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
-        CDEShuffleBlockTransferScalarPerVector_NPerBlock,
-        LoopSched,
-        PipelineVer>;
-
-    template <typename UnderlyingBlockToCTileMap>
-    struct OffsettedBlockToCTileMap
-    {
-        using underlying_type = UnderlyingBlockToCTileMap;
-
-        __host__ __device__ OffsettedBlockToCTileMap(UnderlyingBlockToCTileMap block_to_ctile_map,
-                                                     index_t group_offset,
-                                                     index_t tile_offset)
-            : block_to_ctile_map_{block_to_ctile_map},
-              group_offset_{group_offset},
-              tile_offset_{tile_offset}
-        {
-        }
-
-        template <typename TopIdx>
-        __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
-        {
-            return block_to_ctile_map_.CalculateBottomIndex(
-                make_multi_index(idx_top[Number<0>{}] + tile_offset_ - group_offset_));
-        }
-
-        template <typename CTileIdx, typename CTileDim>
-        __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
-                                                 const CTileDim& c_tile_dim) const
-        {
-            return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
-        }
-
-        template <typename CGridDesc_M_N>
-        __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
-        {
-            return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
-        }
-
-        __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const
-        {
-            return block_to_ctile_map_.CalculateGridSize(M, N);
-        }
-
-        __device__ void UpdateTileOffset(index_t offset) { tile_offset_ = offset; }
-
-        UnderlyingBlockToCTileMap block_to_ctile_map_;
-        index_t group_offset_;
-        index_t tile_offset_;
-    };
+        CDEShuffleBlockTransferScalarPerVectors,
+        BlkGemmPipeSched,
+        BlkGemmPipelineVer,
+        ComputeTypeA,
+        ComputeTypeB>;

     using KernelArguments = GroupedGemmTileLoopKernelArguments<NumDTensor>;
-    using Block2ETileMap  = BlockToCTileMap_N00_M0_N01Adapt<MPerBlock, NPerBlock>;
-    using OffsetedLocalBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMap>;
+    using Block2ETileMap  = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
+    using OffsettedLocalBlock2ETileMap = OffsettedBlockToCTileMap2<Block2ETileMap>;

     // Argument
     struct Argument : public BaseArgument
@@ -403,7 +561,6 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
         const void* p_dev_gemm_args_;
         int occupancy_num_blocks_;
         int gpu_cu_count_;

         const std::vector<GemmDesc>& gemm_descs_;
         AElementwiseOperation a_element_op_;
         BElementwiseOperation b_element_op_;
@@ -496,16 +653,22 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
             const auto kernel = kernel_grouped_gemm_multiple_d_xdl<GridwiseGemm,
                                                                    KernelArguments,
                                                                    GemmSpec,
                                                                    ADataType,
                                                                    BDataType,
                                                                    DsDataType,
                                                                    EDataType,
                                                                    ALayout,
                                                                    BLayout,
                                                                    DsLayout,
                                                                    ELayout,
-                                                                   OffsetedLocalBlock2ETileMap,
+                                                                   KPerBlock,
+                                                                   OffsettedLocalBlock2ETileMap,
+                                                                   Block2ETileMap,
                                                                    AElementwiseOperation,
                                                                    BElementwiseOperation,
-                                                                   CDEElementwiseOperation>;
+                                                                   CDEElementwiseOperation,
+                                                                   BlkGemmPipeSched,
+                                                                   BlkGemmPipelineVer>;

             return LaunchKernel(kernel, arg, dev_gemm_args, stream_config);
         }
@@ -546,6 +709,8 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
                           << std::endl;
             }

+            // run multiple kernels
             return launch_and_time_kernel(stream_config,
                                           kernel,
                                           dim3(grid_size),
@@ -572,63 +737,41 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
             return false;
         }

-        using DsGridDescMN = remove_cvref_t<decltype(
-            GridwiseGemm::template MakeDsGridDescriptor_M_N<DsLayout, GemmSpec>({}, {}, {}))>;
-
         bool supported = true;
-        for(const auto& gdesc : arg.gemm_descs_)
-        {
-            const auto M = gdesc.M_;
-            const auto N = gdesc.N_;
-            const auto K = gdesc.K_;
-
-            const auto StrideA   = gdesc.stride_A_;
-            const auto StrideB   = gdesc.stride_B_;
-            const auto StrideE   = gdesc.stride_C_;
-            const auto& StrideDs = gdesc.stride_Ds_;
-
-            // If M dimension is unknown at launch time then validate just NK.
-            // If N or K dim is zero (or unknown) then the vector loads responsibility lies on
-            // the user.
-            if(N * K == 0)
-                continue;
-
-            const auto a_grid_desc_mk =
-                GridwiseGemm::template MakeAGridDescriptor_M_K<ALayout, GemmSpec>(M, K, StrideA);
-            const auto b_grid_desc_nk =
-                GridwiseGemm::template MakeBGridDescriptor_N_K<BLayout, GemmSpec>(K, N, StrideB);
-            const auto e_grid_desc_mn =
-                GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
-
-            DsGridDescMN ds_grid_desc_mn;
-            static_for<0, NumDTensor, 1>{}([&](auto j) {
-                using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
-                ds_grid_desc_mn(j) =
-                    GridwiseGemm::template MakeEGridDescriptor_M_N<DLayout, GemmSpec>(M, N, StrideDs[j]);
-            });
-
-            const auto b2c_tile_map = Block2ETileMap(M, N);
+        constexpr index_t k_batch = 1;
+        for(index_t i = 0; i < arg.group_count_; ++i)
+        {
+            std::array<const void*, NumDTensor> placeholder_p_ds_grid{};
+            std::array<index_t, NumDTensor> stride_Ds;
+            std::copy_n(arg.gemm_descs_[i].stride_Ds_.begin(), NumDTensor, stride_Ds.begin());
+
+            using GridArg = typename GridwiseGemm::Argument;
+            GridArg gridwise_arg(nullptr,               // p_a_grid,
+                                 nullptr,               // p_b_grid,
+                                 placeholder_p_ds_grid, // p_ds_grid,
+                                 nullptr,               // p_e_grid ,
+                                 arg.gemm_descs_[i].M_,
+                                 arg.gemm_descs_[i].N_,
+                                 arg.gemm_descs_[i].K_,
+                                 arg.gemm_descs_[i].stride_A_,
+                                 arg.gemm_descs_[i].stride_B_,
+                                 stride_Ds,
+                                 arg.gemm_descs_[i].stride_C_,
+                                 k_batch,
+                                 arg.a_element_op_,
+                                 arg.b_element_op_,
+                                 arg.cde_element_op_);

-            if(!(GridwiseGemm::template CheckValidity(
-                     a_grid_desc_mk, b_grid_desc_nk, ds_grid_desc_mn, e_grid_desc_mn, b2c_tile_map) &&
-                 GridwiseGemm::template CheckTensorTransfersValidity<ALayout, BLayout, ELayout>(M, N, K)))
-            {
-                if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
-                {
-                    std::cout << "The provided GEMM problem size (M,N,K) [" << M << "," << N << ","
-                              << K << "] are not supported by current template parameters!"
-                              << " In " << __FILE__ << ":" << __LINE__
-                              << ", in function: " << __func__;
-                }
-                supported = false;
-            }
+            if((arg.gemm_descs_[i].K_ % AK1 != 0 || arg.gemm_descs_[i].K_ % BK1 != 0) &&
+               !(GemmSpec == GemmSpecialization::MKPadding ||
+                 GemmSpec == GemmSpecialization::NKPadding ||
+                 GemmSpec == GemmSpecialization::MNKPadding ||
+                 GemmSpec == GemmSpecialization::KPadding))
+            {
+                return false;
+            }
+
+            supported = supported && GridwiseGemm::CheckValidity(gridwise_arg);
         }
         return supported;
@@ -651,16 +794,22 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
        const auto kernel = kernel_grouped_gemm_multiple_d_xdl<GridwiseGemm,
                                                               KernelArguments,
                                                               GemmSpec,
                                                               ADataType,
                                                               BDataType,
                                                               DsDataType,
                                                               EDataType,
                                                               ALayout,
                                                               BLayout,
                                                               DsLayout,
                                                               ELayout,
-                                                              OffsetedLocalBlock2ETileMap,
+                                                              KPerBlock,
+                                                              OffsettedLocalBlock2ETileMap,
+                                                              Block2ETileMap,
                                                               AElementwiseOperation,
                                                               BElementwiseOperation,
-                                                              CDEElementwiseOperation>;
+                                                              CDEElementwiseOperation,
+                                                              BlkGemmPipeSched,
+                                                              BlkGemmPipelineVer>;
        int occupancy, num_cu;
        hip_check_error(
            hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));

@@ -696,16 +845,22 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
        const auto kernel = kernel_grouped_gemm_multiple_d_xdl<GridwiseGemm,
                                                               KernelArguments,
                                                               GemmSpec,
                                                               ADataType,
                                                               BDataType,
                                                               DsDataType,
                                                               EDataType,
                                                               ALayout,
                                                               BLayout,
                                                               DsLayout,
                                                               ELayout,
-                                                              OffsetedLocalBlock2ETileMap,
+                                                              KPerBlock,
+                                                              OffsettedLocalBlock2ETileMap,
+                                                              Block2ETileMap,
                                                               AElementwiseOperation,
                                                               BElementwiseOperation,
-                                                              CDEElementwiseOperation>;
+                                                              CDEElementwiseOperation,
+                                                              BlkGemmPipeSched,
+                                                              BlkGemmPipelineVer>;
        int occupancy, num_cu;
        hip_check_error(
            hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, BlockSize, 0));

@@ -739,6 +894,17 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
    {
        auto str = std::ostringstream();
+       std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
+           {BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
+           {BlockGemmPipelineScheduler::Interwave, "Interwave"}};
+
+       std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
+           {BlockGemmPipelineVersion::v1, "v1"},
+           {BlockGemmPipelineVersion::v2, "v2"},
+           {BlockGemmPipelineVersion::v3, "v3"},
+           {BlockGemmPipelineVersion::v4, "v4"},
+           {BlockGemmPipelineVersion::v5, "v5"}};
+
        // clang-format off
        str << "DeviceGroupedGemmMultipleDXdlCShuffleTileLoop"
            << "<"

@@ -760,8 +926,10 @@ struct DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
            << CShuffleMXdlPerWavePerShuffle << ", "
            << CShuffleNXdlPerWavePerShuffle << ", "
            << getGemmSpecializationString(GemmSpec) << ", "
-           << PipelineVer << ", "
-           << LoopSched
+           << "BlkGemmPipelineScheduler: " << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
+           << "BlkGemmPipelineVersion: " << BlkGemmPipelineVersionToString[BlkGemmPipelineVer]
            << ">";
        // clang-format on
include/ck/tensor_operation/gpu/device/impl/device_grouped_query_attention_forward_wmma.hpp

@@ -61,7 +61,7 @@ __global__ void
         bool input_permute,
         bool output_permute)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
     // clang-format off
// ***************************************************

@@ -166,6 +166,7 @@ __global__ void
     ignore = O;
     ignore = G0;
     ignore = G1;
     ignore = alpha;
     ignore = input_permute;
     ignore = output_permute;
 #endif // end of if (defined(__gfx11__))

@@ -596,7 +597,7 @@ struct DeviceGroupedQueryAttentionForward_Wmma
     static bool IsSupportedArgument(const RawArg& arg)
     {
-        if(ck::is_gfx11_supported())
+        if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
         {
             if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
             {
include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -108,7 +108,8 @@ struct DeviceImageToColumnImpl
                                                                  conv_filter_strides,
                                                                  conv_filter_dilations,
                                                                  input_left_pads,
-                                                                 input_right_pads);
+                                                                 input_right_pads,
+                                                                 N);

        const auto in_gemmm_gemmk_desc = matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
include/ck/tensor_operation/gpu/device/impl/device_multi_query_attention_forward_wmma.hpp

@@ -60,7 +60,7 @@ __global__ void
         bool input_permute,
         bool output_permute)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
     // clang-format off
// ***************************************************

@@ -165,6 +165,7 @@ __global__ void
     ignore = O;
     ignore = G0;
     ignore = G1;
     ignore = alpha;
     ignore = input_permute;
     ignore = output_permute;
 #endif // end of if (defined(__gfx11__))

@@ -594,7 +595,7 @@ struct DeviceMultiQueryAttentionForward_Wmma
     static bool IsSupportedArgument(const RawArg& arg)
     {
-        if(ck::is_gfx11_supported())
+        if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
         {
             if constexpr(!(is_same_v<Acc0DataType, float> || is_same_v<Acc0DataType, int32_t>))
             {
include/ck/tensor_operation/gpu/device/matrix_padder.hpp

@@ -180,6 +180,19 @@ struct MatrixPadder : public GemmPadder<GemmSpec, MPerTileType, NPerTileType, KP
 {
 };

+// function to take in a struct of type MatrixPadder and call the appropriate function to get
+// the output descriptor at runtime for codegen
+template <GemmSpecialization GemmSpec,
+          typename MPerTileType,
+          typename NPerTileType,
+          typename KPerTileType,
+          typename CDesc_MRaw_NRaw>
+auto grid_desc(MatrixPadder<GemmSpec, MPerTileType, NPerTileType, KPerTileType> matrix_padder,
+               CDesc_MRaw_NRaw conv_desc)
+{
+    auto res = matrix_padder.PadCDescriptor_M_N(conv_desc);
+    return res;
+}
+
 // M/N/KPerTileType could be index_t or Number<>
 template <bool PadM,
           bool PadN,
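The new grid_desc free function simply forwards to the padder's PadCDescriptor_M_N. A minimal host-side sketch of how it might be called follows; the tile sizes (128/128/32), the example dimensions, and the helper function name are illustrative assumptions, not taken from the commit, and it presumes the CK headers are on the include path.

// Sketch only: pad an M x N output descriptor up to multiples of a (hypothetical) 128x128x32 tile.
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"

using namespace ck::tensor_operation::device;

auto make_padded_output_desc(ck::index_t M, ck::index_t N)
{
    const auto c_desc_m_n = ck::make_naive_tensor_descriptor_packed(ck::make_tuple(M, N));
    MatrixPadder<GemmSpecialization::MNKPadding, ck::index_t, ck::index_t, ck::index_t> padder{128, 128, 32};
    return grid_desc(padder, c_desc_m_n); // e.g. M=1000, N=500 -> lengths padded to 1024 x 512
}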
include/ck/tensor_operation/gpu/element/element_wise_operation.hpp

@@ -528,26 +528,6 @@ struct UnaryTypeConvert<ck::bhalf_t, float>
     }
 };

-struct ConvInvscale
-{
-    /// @brief Op to multiply convolution results by inverted scale factors
-    /// @param e Output after scaling
-    /// @param c Convolution result
-    /// @param d0 Input scale factor
-    /// @param d1 Weights scale factor
-    /// @param d2 Output scale factor
-    template <typename E, typename C, typename D0, typename D1, typename D2>
-    __host__ __device__ void
-    operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const;
-
-    template <>
-    __host__ __device__ void operator()<f8_t, float, float, float, float>(
-        f8_t& e, const float& c, const float& d0, const float& d1, const float& d2) const
-    {
-        e = type_convert<f8_t>(c / d0 / d1 / d2);
-    };
-};
-
 } // namespace element_wise
 } // namespace tensor_operation
 } // namespace ck
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp

@@ -961,6 +961,47 @@ struct Elu
     const float alpha_;
 };

+struct Logistic
+{
+    Logistic(float alpha = 1.f) : alpha_(alpha){};
+
+    template <typename T>
+    __host__ __device__ void operator()(T& y, const T& x) const
+    {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
+                          is_same<T, int8_t>::value,
+                      "Data type is not supported by this operation!");
+        T casted_alpha  = type_convert<T>(alpha_);
+        constexpr T one = type_convert<T>(1);
+        y               = casted_alpha / (one + ck::math::exp(-x) * casted_alpha);
+    }
+    const float alpha_;
+};
+
+struct ConvInvscale
+{
+    __host__ __device__ ConvInvscale(float scale_in = 1.f, float scale_wei = 1.f, float scale_out = 1.f)
+        : scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out)
+    {
+    }
+
+    template <typename E, typename C>
+    __host__ __device__ void operator()(E& e, const C& c) const;
+
+    template <>
+    __host__ __device__ void operator()<f8_t, float>(f8_t& e, const float& c) const
+    {
+        e = type_convert<f8_t>(c / scale_in_ / scale_wei_ / scale_out_);
+    };
+
+    float scale_in_;
+    float scale_wei_;
+    float scale_out_;
+};
+
 struct ConvScale
 {
     __host__ __device__ ConvScale(float scale_in = 1.f,
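The added Logistic functor computes a scaled sigmoid, y = alpha / (1 + alpha * exp(-x)); with alpha = 1 it reduces to the plain sigmoid. A small host-side sketch of applying it element by element follows; the buffer names and loop are illustrative, not part of the commit, and it assumes the op's host code path (ck::math::exp on the host) is available in the build.

// Illustrative host-side use of the new element-wise op (buffer names are hypothetical).
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include <vector>

void apply_logistic(std::vector<float>& out, const std::vector<float>& in)
{
    ck::tensor_operation::element_wise::Logistic logistic{1.f}; // alpha = 1 -> plain sigmoid
    for(std::size_t i = 0; i < in.size(); ++i)
    {
        logistic(out[i], in[i]); // out = 1 / (1 + exp(-in))
    }
}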
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp

@@ -908,6 +908,51 @@ struct OffsettedBlockToCTileMap
     UnderlyingBlockToCTileMap block_to_ctile_map_;
     index_t block_start_;
 };

+// second version with 2 offsets
+template <typename UnderlyingBlockToCTileMap>
+struct OffsettedBlockToCTileMap2
+{
+    using underlying_type = UnderlyingBlockToCTileMap;
+
+    __host__ __device__ OffsettedBlockToCTileMap2(UnderlyingBlockToCTileMap block_to_ctile_map,
+                                                  index_t group_offset,
+                                                  index_t tile_offset)
+        : block_to_ctile_map_{block_to_ctile_map},
+          group_offset_{group_offset},
+          tile_offset_{tile_offset}
+    {
+    }
+
+    template <typename TopIdx>
+    __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
+    {
+        return block_to_ctile_map_.CalculateBottomIndex(
+            make_multi_index(idx_top[Number<0>{}] + tile_offset_ - group_offset_));
+    }
+
+    template <typename CTileIdx, typename CTileDim>
+    __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
+                                             const CTileDim& c_tile_dim) const
+    {
+        return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
+    }
+
+    template <typename CGridDesc_M_N>
+    __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
+    {
+        return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
+    }
+
+    __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const
+    {
+        return block_to_ctile_map_.CalculateGridSize(M, N);
+    }
+
+    __device__ void UpdateTileOffset(index_t offset) { tile_offset_ = offset; }
+
+    UnderlyingBlockToCTileMap block_to_ctile_map_;
+    index_t group_offset_;
+    index_t tile_offset_;
+};

 /**
 * @brief Simple tile mapping which creates 3D grid of block of threads.
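CalculateBottomIndex in the struct above shifts the incoming 1-D block index by tile_offset_ minus group_offset_ before delegating to the wrapped map. A small standalone sketch of that arithmetic follows; the numbers and function name are made up for illustration and do not use CK types.

// Standalone illustration of the offset arithmetic (not CK code).
#include <cassert>

int local_tile_index(int block_id, int group_offset, int tile_offset)
{
    // The wrapped map sees indices relative to the group's first tile.
    return block_id + tile_offset - group_offset;
}

int main()
{
    // Hypothetical values: a group whose tiles start at global tile 64,
    // with the kernel currently positioned at tile offset 70.
    assert(local_tile_index(/*block_id=*/0, /*group_offset=*/64, /*tile_offset=*/70) == 6);
    return 0;
}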
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp

@@ -373,10 +373,14 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
         // BK0_L_BK1 -> BK0_LRepeat_Lwaves_LPerWmma_BK1
         constexpr auto B_K0 = B0BlockDesc_{}.GetLength(I0);
         constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I2);
+#ifdef __gfx12__
+        constexpr auto B_KRow = I2;
+#else
+        constexpr auto B_KRow = I1;
+#endif
         return transform_tensor_descriptor(
             B0BlockDesc_{},
-            make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
+            make_tuple(make_unmerge_transform(make_tuple(Number<B_K0 / B_KRow>{}, B_KRow)),
                        make_unmerge_transform(
                            make_tuple(Number<LRepeat>{}, Number<LWaves>{}, Number<LPerWmma>{})),
                        make_pass_through_transform(Number<B_K1>{})),

@@ -430,10 +434,14 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
         // BL0_N_BL1 -> BL0_NRepeat_Nwaves_NPerWmma_BL1
         constexpr auto B_L0 = B1BlockDesc_{}.GetLength(I0);
         constexpr auto B_L1 = B1BlockDesc_{}.GetLength(I2);
+#ifdef __gfx12__
+        constexpr auto B_LRow = I2;
+#else
+        constexpr auto B_LRow = I1;
+#endif
         return transform_tensor_descriptor(
             B1BlockDesc_{},
-            make_tuple(make_unmerge_transform(make_tuple(Number<B_L0>{}, B_LRow)),
+            make_tuple(make_unmerge_transform(make_tuple(Number<B_L0 / B_LRow>{}, B_LRow)),
                        make_unmerge_transform(
                            make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
                        make_pass_through_transform(Number<B_L1>{})),
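On gfx12 the K0 (or L0) extent of the block descriptor is split across two rows instead of one, which is why the unmerge now uses K0 / KRow. A tiny compile-time sketch of the resulting extents follows; the K0 value is an illustrative number, not one taken from the kernel.

// Standalone illustration of the K0 split (not CK code; numbers are made up).
#include <cstdio>

int main()
{
    constexpr int B_K0  = 16;          // K0 length of the block descriptor
    constexpr int KRow  = 2;           // 2 on gfx12, 1 otherwise
    constexpr int outer = B_K0 / KRow; // unmerged outer K0 extent: 8
    static_assert(outer * KRow == B_K0, "K0 must be divisible by KRow");
    std::printf("K0=%d -> (K0/KRow, KRow) = (%d, %d)\n", B_K0, outer, KRow);
    return 0;
}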
include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp

@@ -50,7 +50,7 @@ __global__ void
         const CElementwiseOperation c_element_op,
         const Block2CTileMap block_2_ctile_map)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
     __shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size];

     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,

@@ -304,10 +304,14 @@ struct GridwiseFpAintBGemm_Wmma
         // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1
         constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
         constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
+#ifdef __gfx12__
+        constexpr auto A_KRow = I2;
+#else
+        constexpr auto A_KRow = I1;
+#endif
         return transform_tensor_descriptor(
             ABlockDesc_{},
-            make_tuple(make_unmerge_transform(make_tuple(Number<A_K0>{}, A_KRow)),
+            make_tuple(make_unmerge_transform(make_tuple(Number<A_K0 / A_KRow>{}, A_KRow)),
                        make_unmerge_transform(
                            make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
                        make_pass_through_transform(Number<A_K1>{})),

@@ -362,10 +366,14 @@ struct GridwiseFpAintBGemm_Wmma
         // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
         constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
         constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
+#ifdef __gfx12__
+        constexpr auto B_KRow = I2;
+#else
+        constexpr auto B_KRow = I1;
+#endif
         return transform_tensor_descriptor(
             BBlockDesc_{},
-            make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
+            make_tuple(make_unmerge_transform(make_tuple(Number<B_K0 / B_KRow>{}, B_KRow)),
                        make_unmerge_transform(
                            make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
                        make_pass_through_transform(Number<B_K1>{})),
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -54,18 +54,18 @@ __global__ void
         const Block2CTileMap block_2_ctile_map,
         const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
     // offset base pointer for each work-group
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);

-    const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
-    const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
-    const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
+    const long_index_t a_batch_offset =
+        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
+    const long_index_t b_batch_offset =
+        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
+    const long_index_t e_batch_offset =
+        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));

     const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);

@@ -147,7 +147,7 @@ __global__ void
         const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
         const Block2CTileMap block_2_etile_map)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
     // printf("entry kernel launch");
     __shared__ char p_shared[GridwiseOp::SharedMemTrait::lds_size];

@@ -155,12 +155,12 @@ __global__ void
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);

-    const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
-    const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
-    const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
-        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
+    const long_index_t a_batch_offset =
+        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
+    const long_index_t b_batch_offset =
+        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
+    const long_index_t e_batch_offset =
+        amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx));

     const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);

@@ -237,7 +237,7 @@ __global__ void
         const CDEElementwiseOperation cde_element_op,
         const Block2CTileMap block_2_ctile_map)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
     __shared__ char p_shared[GridwiseOp::SharedMemTrait::lds_size];

     GridwiseOp::template Run<HasMainKBlockLoop>(p_a_grid,

@@ -375,8 +375,9 @@ struct GridwiseGemmMultipleD_Wmma
         }
         else
         {
+            constexpr auto A_KRow        = I2;
             constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
-            constexpr auto K0PerWmma     = WmmaK / 2 / K1;
+            constexpr auto K0PerWmma     = WmmaK / A_KRow / K1;
             // KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
             return make_naive_tensor_descriptor(
                 make_tuple(Number<KWmmaPerblock>{},

@@ -422,8 +423,9 @@ struct GridwiseGemmMultipleD_Wmma
         }
         else
         {
+            constexpr auto B_KRow        = I2;
             constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
-            constexpr auto K0PerWmma     = WmmaK / 2 / K1;
+            constexpr auto K0PerWmma     = WmmaK / B_KRow / K1;
             // KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
             return make_naive_tensor_descriptor(
                 make_tuple(Number<KWmmaPerblock>{},

@@ -497,10 +499,14 @@ struct GridwiseGemmMultipleD_Wmma
         // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1
         constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
         constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
+#ifdef __gfx12__
+        constexpr auto A_KRow = I2;
+#else
+        constexpr auto A_KRow = I1;
+#endif
         return transform_tensor_descriptor(
             ABlockDesc_{},
-            make_tuple(make_unmerge_transform(make_tuple(Number<A_K0>{}, A_KRow)),
+            make_tuple(make_unmerge_transform(make_tuple(Number<A_K0 / A_KRow>{}, A_KRow)),
                        make_unmerge_transform(
                            make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
                        make_pass_through_transform(Number<A_K1>{})),

@@ -536,10 +542,14 @@ struct GridwiseGemmMultipleD_Wmma
         // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
         constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
         constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
+#ifdef __gfx12__
+        constexpr auto B_KRow = I2;
+#else
+        constexpr auto B_KRow = I1;
+#endif
         return transform_tensor_descriptor(
             BBlockDesc_{},
-            make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
+            make_tuple(make_unmerge_transform(make_tuple(Number<B_K0 / B_KRow>{}, B_KRow)),
                        make_unmerge_transform(
                            make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
                        make_pass_through_transform(Number<B_K1>{})),

@@ -571,15 +581,12 @@ struct GridwiseGemmMultipleD_Wmma
     // *Caution Here repeat is shuffle repeat
     GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
     {
-        constexpr index_t MWave = MPerBlock / (MRepeat * MPerWmma);
-        constexpr index_t NWave = NPerBlock / (NRepeat * NPerWmma);
-
         constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
             make_naive_tensor_descriptor_packed(
                 make_tuple(I1,
-                           Number<CShuffleMRepeatPerShuffle * MWave * MPerWmma>{},
+                           Number<CShuffleMRepeatPerShuffle * MWaves * MPerWmma>{},
                            I1,
-                           Number<CShuffleNRepeatPerShuffle * NWave * NPerWmma>{}));
+                           Number<CShuffleNRepeatPerShuffle * NWaves * NPerWmma>{}));

         return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat;
     }

@@ -801,6 +808,7 @@ struct GridwiseGemmMultipleD_Wmma
         const auto MBlock = M / MPerBlock;
         const auto NBlock = N / NPerBlock;

         const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
             e_grid_desc_m_n,
             make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
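The shuffle-block descriptor above now uses the struct-level MWaves/NWaves constants rather than locally recomputed MWave/NWave values. A small constexpr sketch of the extents involved follows; the block, repeat, and WMMA tile sizes are illustrative numbers, not the kernel's actual template arguments.

// Standalone illustration of the wave-count arithmetic (numbers are made up).
constexpr int MPerBlock = 128, NPerBlock = 128;
constexpr int MRepeat = 2, NRepeat = 2;
constexpr int MPerWmma = 16, NPerWmma = 16;

constexpr int MWaves = MPerBlock / (MRepeat * MPerWmma); // 4 waves along M
constexpr int NWaves = NPerBlock / (NRepeat * NPerWmma); // 4 waves along N

constexpr int CShuffleMRepeatPerShuffle = 1, CShuffleNRepeatPerShuffle = 1;
// One shuffle tile covers (1, MRepeatPerShuffle*MWaves*MPerWmma, 1, NRepeatPerShuffle*NWaves*NPerWmma).
constexpr int shuffle_m = CShuffleMRepeatPerShuffle * MWaves * MPerWmma; // 64
constexpr int shuffle_n = CShuffleNRepeatPerShuffle * NWaves * NPerWmma; // 64
static_assert(shuffle_m == 64 && shuffle_n == 64, "illustrative extents");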
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp

@@ -45,7 +45,7 @@ __global__ void
         const CElementwiseOperation c_element_op,
         const Block2CTileMap block_2_ctile_map)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
     __shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size];

     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,

@@ -170,8 +170,9 @@ struct GridwiseGemm_Wmma
         }
         else
         {
+            constexpr auto A_KRow        = I2;
             constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
-            constexpr auto K0PerWmma     = WmmaK / 2 / K1;
+            constexpr auto K0PerWmma     = WmmaK / A_KRow / K1;
             // KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
             return make_naive_tensor_descriptor(
                 make_tuple(Number<KWmmaPerblock>{},

@@ -217,8 +218,10 @@ struct GridwiseGemm_Wmma
         }
         else
         {
+            constexpr auto B_KRow        = I2;
             constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
-            constexpr auto K0PerWmma     = WmmaK / 2 / K1;
+            constexpr auto K0PerWmma     = WmmaK / B_KRow / K1;
             // KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
             return make_naive_tensor_descriptor(
                 make_tuple(Number<KWmmaPerblock>{},

@@ -292,10 +295,15 @@ struct GridwiseGemm_Wmma
         // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1
         constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
         constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
+#ifdef __gfx12__
+        constexpr auto A_KRow = I2;
+#else
+        constexpr auto A_KRow = I1;
+#endif
         return transform_tensor_descriptor(
             ABlockDesc_{},
-            make_tuple(make_unmerge_transform(make_tuple(Number<A_K0>{}, A_KRow)),
+            make_tuple(make_unmerge_transform(make_tuple(Number<A_K0 / A_KRow>{}, A_KRow)),
                        make_unmerge_transform(
                            make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
                        make_pass_through_transform(Number<A_K1>{})),

@@ -350,10 +358,14 @@ struct GridwiseGemm_Wmma
         // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
         constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
         constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
+#ifdef __gfx12__
+        constexpr auto B_KRow = I2;
+#else
+        constexpr auto B_KRow = I1;
+#endif
         return transform_tensor_descriptor(
             BBlockDesc_{},
-            make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
+            make_tuple(make_unmerge_transform(make_tuple(Number<B_K0 / B_KRow>{}, B_KRow)),
                        make_unmerge_transform(
                            make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
                        make_pass_through_transform(Number<B_K1>{})),

@@ -522,12 +534,6 @@ struct GridwiseGemm_Wmma
             c_grid_desc_m_n);
     }

-    using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
-        MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
-
-    using DefaultBlock2CTileMap =
-        remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
-
     struct SharedMemTrait
     {
         // LDS allocation for A and B: be careful of alignment

@@ -559,6 +565,12 @@ struct GridwiseGemm_Wmma
                                                  b_block_space_size_aligned * sizeof(BDataType));
     };

+    using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
+        MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
+
+    using DefaultBlock2CTileMap =
+        remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
+
     template <bool HasMainKBlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
     __device__ static void Run(const ADataType* __restrict__ p_a_grid,
                                const BDataType* __restrict__ p_b_grid,
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -189,55 +189,55 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
     __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
     {
-        return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch);
+        return std::make_tuple(Block2CTileMapDefault::CalculateGridSize(M, N), 1, KBatch);
     }

-    __host__ static auto CalculateMPadded(index_t M)
+    __host__ __device__ static auto CalculateMPadded(index_t M)
     {
         return math::integer_least_multiple(M, MPerBlock);
     }

-    __host__ static auto CalculateNPadded(index_t N)
+    __host__ __device__ static auto CalculateNPadded(index_t N)
     {
         return math::integer_least_multiple(N, NPerBlock);
     }

-    __host__ static auto CalculateKPadded(index_t K)
+    __host__ __device__ static auto CalculateKPadded(index_t K)
     {
         return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
     }

-    __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
+    __host__ __device__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
     {
         auto K_t = K_Batch * KPerBlock;
         return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
     }

-    __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
+    __host__ __device__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
     {
         auto K_t = K_Batch * KPerBlock;
         return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
     }

-    __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
+    __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
     {
         auto K_t = K_Batch * KPerBlock;
         return (K + K_t - 1) / K_t * KPerBlock;
     }

-    __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
+    __host__ __device__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
     {
         constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
         auto K_t                = K_Batch * KReadVec;
         return (K + K_t - 1) / K_t * KReadVec;
     }

-    __host__ static auto CalculateMBlock(index_t M)
+    __host__ __device__ static auto CalculateMBlock(index_t M)
     {
         return math::integer_divide_ceil(M, MPerBlock);
     }

-    __host__ static auto CalculateNBlock(index_t N)
+    __host__ __device__ static auto CalculateNBlock(index_t N)
     {
         return math::integer_divide_ceil(N, NPerBlock);
     }

@@ -520,7 +520,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
     struct Problem
     {
-        __host__ Problem(index_t M_,
+        __host__ __device__ Problem(index_t M_,
                          index_t N_,
                          index_t K_,
                          index_t StrideA_,

@@ -1180,14 +1180,14 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
         return true;
     }

-    __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
+    __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
     {
         const index_t num_loop = K / KPerBlock;

         return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
     }

-    __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
+    __host__ __device__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
     {
         const index_t num_loop = K / KPerBlock;

@@ -1210,8 +1210,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
     // return block_id to C matrix tile idx (m0, n0) mapping
     // if arch = gfx942
-    using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
-    // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>;
+    using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;

     template <bool HasMainKBlockLoop,
               InMemoryDataOperationEnum CGlobalMemoryDataOperation,

@@ -1225,6 +1224,35 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
                                AElementwiseOperation a_element_op,
                                BElementwiseOperation b_element_op,
                                CElementwiseOperation c_element_op)
     {
+        const auto block_2_ctile_map = Block2CTileMapDefault{problem.M, problem.N, 4};
+        Run<Block2CTileMapDefault, HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
+            p_a_grid,
+            p_b_grid,
+            p_ds_grid,
+            p_c_grid,
+            p_shared,
+            problem,
+            a_element_op,
+            b_element_op,
+            c_element_op,
+            block_2_ctile_map);
+    }
+
+    template <typename Block2CTileMap,
+              bool HasMainKBlockLoop,
+              InMemoryDataOperationEnum CGlobalMemoryDataOperation,
+              TailNumber TailNum = TailNumber::Odd>
+    __device__ static void Run(const ADataType* p_a_grid,
+                               const BDataType* p_b_grid,
+                               DsGridPointer& p_ds_grid,
+                               CDataType* p_c_grid,
+                               void* p_shared,
+                               const Problem& problem,
+                               AElementwiseOperation a_element_op,
+                               BElementwiseOperation b_element_op,
+                               CElementwiseOperation c_element_op,
+                               const Block2CTileMap& block_2_ctile_map)
+    {
         const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
             problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);

@@ -1244,9 +1272,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
         auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

-        // divide block work by [M, N]
-        const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
-
         const auto block_work_idx =
             block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

@@ -1653,6 +1678,38 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
                                     AElementwiseOperation a_element_op,
                                     BElementwiseOperation b_element_op,
                                     CElementwiseOperation c_element_op)
     {
+        // divide block work by [M, N]
+        const auto block_2_ctile_map = Block2CTileMapDefault{problem.M, problem.N, 4};
+        Run_2Lds<Block2CTileMapDefault, HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
+            p_a_grid,
+            p_b_grid,
+            p_ds_grid,
+            p_c_grid,
+            p_shared_0,
+            p_shared_1,
+            problem,
+            a_element_op,
+            b_element_op,
+            c_element_op,
+            block_2_ctile_map);
+    }
+
+    template <typename Block2CTileMap,
+              bool HasMainKBlockLoop,
+              InMemoryDataOperationEnum CGlobalMemoryDataOperation,
+              TailNumber TailNum = TailNumber::Odd>
+    __device__ static void Run_2Lds(const ADataType* p_a_grid,
+                                    const BDataType* p_b_grid,
+                                    DsGridPointer& p_ds_grid,
+                                    CDataType* p_c_grid,
+                                    void* p_shared_0,
+                                    void* p_shared_1,
+                                    const Problem& problem,
+                                    AElementwiseOperation a_element_op,
+                                    BElementwiseOperation b_element_op,
+                                    CElementwiseOperation c_element_op,
+                                    const Block2CTileMap& block_2_ctile_map)
+    {
         const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
             problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);

@@ -1672,9 +1729,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
         auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

-        // divide block work by [M, N]
-        const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
-
         const auto block_work_idx =
             block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
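The Calculate*Padded helpers above round K up per K-batch slice using the same (K + K_t - 1) / K_t pattern seen in the diff. A small worked example of that arithmetic follows; KPerBlock, AK1, and the inputs are illustrative numbers, not values from any particular instantiation.

// Standalone illustration of the K padding math (not CK code; numbers are made up).
#include <cassert>

int main()
{
    const int K = 1000, KPerBlock = 64, AK1 = 8, KBatch = 2;

    const int K_t     = KBatch * KPerBlock;                      // 128
    const int KPadded = (K + K_t - 1) / K_t * KPerBlock;         // ceil(1000/128) * 64 = 512 per batch
    const int AK0     = (K + K_t - 1) / K_t * (KPerBlock / AK1); // 8 * 8 = 64

    assert(KPadded == 512 && AK0 == 64);
    return 0;
}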
include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp

@@ -36,7 +36,8 @@ __global__ void
         const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
-    defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__))
+    defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \
+    defined(__gfx12__))
     GridwiseTensorRearrangeKernel::Run(in_grid_desc,
                                        p_in_global,
                                        out_grid_desc,
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
View file @
5ec6a912
...
...
@@ -1304,7 +1304,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
    ElementwiseOperation element_op_;
};

// Specialized for WMMA
// Specialized for WMMA-Navi3
// A single Wave32 is composed by double row
// Data exchange allowed between these two rows
// This RowLane Dst buf will be filled from two Src buf
...
...
@@ -1439,4 +1439,111 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
    ElementwiseOperation element_op_{};
};

// Specialized for WMMA-Navi4
template <typename SrcData,
          typename DstData,
          typename SrcDesc,
          typename DstDesc,
          typename ElementwiseOperation,
          typename SliceLengths,
          typename DimAccessOrder,
          index_t DstVectorDim,
          index_t DstScalarPerVector,
          bool IntraRowSwizzlePerm,
          typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                             bool>::type = false>
struct ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow
{
    static constexpr index_t nDim = SliceLengths::Size();

    using Index = MultiIndex<nDim>;

    __device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow(const Index& src_idx)
    {
        static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                      "wrong! Desc need to known at compile-time");

        static_assert(SliceLengths::At(Number<DstVectorDim>{}) % DstScalarPerVector == 0,
                      "wrong! Not divisible");
        ignore = src_idx;
    }

    template <typename SrcSliceOriginIdx,
              typename DstSliceOriginIdx,
              typename SrcBuffer,
              typename DstBuffer>
    __device__ void Run(const SrcDesc&,
                        const SrcSliceOriginIdx&,
                        const SrcBuffer& src_buf,
                        const DstDesc&,
                        const DstSliceOriginIdx&,
                        DstBuffer& dst_buf) const
    {
        static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                      "wrong! Desc need to known at compile-time");

        static_assert(is_known_at_compile_time<remove_cvref_t<SrcSliceOriginIdx>>::value &&
                          is_known_at_compile_time<remove_cvref_t<DstSliceOriginIdx>>::value,
                      "wrong! SliceOrigin need to known at compile-time");

        static_assert(SrcBuffer::IsStaticBuffer() && DstBuffer::IsStaticBuffer(),
                      "wrong! Buffer need to be StaticBuffer");

        // SrcDesc and src_slice_origin_idx are known at compile-time
        constexpr auto src_desc             = remove_cvref_t<SrcDesc>{};
        constexpr auto dst_desc             = remove_cvref_t<DstDesc>{};
        constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
        constexpr auto dst_slice_origin_idx = to_multi_index(DstSliceOriginIdx{});

        // scalar per access on each dim
        constexpr auto dst_scalar_per_access = generate_sequence(
            detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});

        constexpr auto dst_scalar_step_in_vector =
            generate_sequence(detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});

        using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
                                                    DimAccessOrder,
                                                    remove_cv_t<decltype(dst_scalar_per_access)>>;

        static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector,
                      "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector");

        constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();

        static_for<0, num_access, 1>{}([&](auto idx_1d) {
            constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d);

            // copy data from src_buf into dst_vector
            static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
                // src_desc error, non constexpr, caused by merge transform
                constexpr index_t src_offset = src_desc.CalculateOffset(
                    src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);

                constexpr index_t dst_offset = dst_desc.CalculateOffset(
                    dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);

                SrcData v_this_row;
                // int type temp value due to intrinsic requirement
                int temp = 0;

                // apply element-wise operation
                element_op_(v_this_row, src_buf[Number<src_offset>{}]);

                // apply intra-row permute.
                if constexpr(IntraRowSwizzlePerm)
                {
                    temp = __builtin_amdgcn_permlane16(
                        temp, type_convert_sp<int>(v_this_row), 0xb3a29180, 0xf7e6d5c4, 1, 0);
                    v_this_row = type_convert_sp<SrcData>(temp);
                }

                // apply type convert
                dst_buf(Number<dst_offset>{}) = type_convert_sp<DstData>(v_this_row);
            });
        });
    }
    ElementwiseOperation element_op_{};
};

} // namespace ck
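The intra-row swizzle above relies on `__builtin_amdgcn_permlane16` with the selector pair `0xb3a29180` / `0xf7e6d5c4`. Below is a host-side sketch of my own (not part of the commit) that decodes such a pair, assuming the usual convention that nibble i of the concatenated sel1:sel0 names the source lane for destination lane i within a 16-lane row.

```cpp
#include <cstdint>
#include <cstdio>

// Decode the two 32-bit selectors of v_permlane16_b32 into a 16-entry lane map.
// Assumption: nibble i of (sel1:sel0) selects the source lane for destination lane i.
void decode_permlane16(uint32_t sel0, uint32_t sel1)
{
    const uint64_t sel = (static_cast<uint64_t>(sel1) << 32) | sel0;
    for(int dst_lane = 0; dst_lane < 16; ++dst_lane)
    {
        const int src_lane = static_cast<int>((sel >> (4 * dst_lane)) & 0xF);
        std::printf("row lane %2d <- row lane %2d\n", dst_lane, src_lane);
    }
}

int main()
{
    // The selector pair used by the intra-row swizzle above.
    decode_permlane16(0xb3a29180u, 0xf7e6d5c4u);
    return 0;
}
```

Under that nibble convention the pair above interleaves the two 8-lane halves of a row (sources 0, 8, 1, 9, ..., 7, 15).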
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
View file @ 5ec6a912
...
...
@@ -11,12 +11,17 @@ namespace ck {
enum struct WmmaInstr
{
    // gfx11
    wmma_f32_16x16x16_f16 = 0,
    wmma_f32_16x16x16_bf16,
    wmma_f16_16x16x16_f16,
    wmma_bf16_16x16x16_bf16,
    wmma_i32_16x16x16_iu8,
    wmma_i32_16x16x16_iu4
    wmma_i32_16x16x16_iu4,
    // gfx12
    wmma_f32_16x16x16_f16_gfx12,
    wmma_f32_16x16x16_bf16_gfx12,
    wmma_i32_16x16x16_iu8_gfx12,
};
/*
...
...
@@ -279,6 +284,122 @@ struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
    }
};

// gfx12
// A-swizzled
template <index_t WaveSize>
struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16_gfx12,
                 WaveSize,
                 typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
{
    // Absolute fixing property
    // * Data Pixel
    static constexpr index_t m_per_wmma = 16;
    static constexpr index_t n_per_wmma = 16;
    static constexpr index_t k_per_wmma = 16;
    // static constexpr index_t src_a_data_size = 2;
    // static constexpr index_t src_b_data_size = 2;
    // static constexpr index_t acc_data_size = 4;
    // * Thread mapping inside wave, num_thread_per_subgroups always along N direction
    static constexpr index_t acc_data_size            = 4;
    static constexpr index_t acc_pack_number          = 1;
    static constexpr index_t num_thread_per_subgroups = n_per_wmma;

    // Wave mode dependent property
    static constexpr index_t wave_size = Number<WaveSize>{};
    // * Fixed in Navi3x, will be wave mode dependent on Navi4x
    // static constexpr index_t num_src_a_vgprs_per_wave = k_per_wmma / 2 * src_a_data_size / 4;
    // static constexpr index_t num_src_b_vgprs_per_wave = k_per_wmma / 2 * src_b_data_size / 4;
    // * num_acc_vgprs_per_wave along M direction
    // * num_subgroups along M direction
    static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
    static constexpr index_t num_subgroups          = wave_size / num_thread_per_subgroups;

    template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
    {
        static_assert(wave_size == 32, "only support wave32 for gfx12 wmma");
        if constexpr(wave_size == 32)
        {
            intrin_wmma_f32_16x16x16_f16_w32_gfx12<MPerWmma, NPerWmma>::Run(a, b, reg_c);
        }
    }
};

template <index_t WaveSize>
struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf16_gfx12,
                 WaveSize,
                 typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
{
    // Absolute fixing property
    static constexpr index_t m_per_wmma = 16;
    static constexpr index_t n_per_wmma = 16;
    static constexpr index_t k_per_wmma = 16;
    // static constexpr index_t src_a_data_size = 2;
    // static constexpr index_t src_b_data_size = 2;
    static constexpr index_t acc_data_size            = 4;
    static constexpr index_t acc_pack_number          = 1;
    static constexpr index_t num_thread_per_subgroups = n_per_wmma;

    // Wave mode dependent property
    static constexpr index_t wave_size = Number<WaveSize>{};
    // static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
    // static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
    static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
    static constexpr index_t num_subgroups          = wave_size / num_thread_per_subgroups;

    template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
    {
        static_assert(wave_size == 32, "only support wave32 for gfx12 wmma");
        if constexpr(wave_size == 32)
        {
            intrin_wmma_f32_16x16x16_bf16_w32_gfx12<MPerWmma, NPerWmma>::Run(a, b, reg_c);
        }
    }
};

template <index_t WaveSize>
struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8_gfx12,
                 WaveSize,
                 typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
{
    // Absolute fixing property
    static constexpr index_t m_per_wmma = 16;
    static constexpr index_t n_per_wmma = 16;
    static constexpr index_t k_per_wmma = 16;
    // static constexpr index_t src_a_data_size = 2;
    // static constexpr index_t src_b_data_size = 2;
    static constexpr index_t acc_data_size            = 4;
    static constexpr index_t acc_pack_number          = 1;
    static constexpr index_t num_thread_per_subgroups = n_per_wmma;

    // Wave mode dependent property
    static constexpr index_t wave_size = Number<WaveSize>{};
    // static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
    // static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
    static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
    static constexpr index_t num_subgroups          = wave_size / num_thread_per_subgroups;

    template <index_t MPerWmma,
              index_t NPerWmma,
              class FloatA,
              class FloatB,
              class FloatC,
              bool neg_a = false,
              bool neg_b = false,
              bool clamp = false>
    __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
    {
        static_assert(wave_size == 32, "only support wave32 for gfx12 wmma");
        if constexpr(wave_size == 32)
        {
            intrin_wmma_i32_16x16x16_iu8_w32_gfx12<MPerWmma, NPerWmma, neg_a, neg_b, clamp>::Run(
                a, b, reg_c);
        }
    }
};

template <typename src_type_a,
          typename src_type_b,
          typename dst_type,
...
...
@@ -296,13 +417,21 @@ struct WmmaSelector
    template <>
    static constexpr auto GetWmma<half_t, half_t, float, 16, 16>()
    {
#ifdef __gfx12__
        return WmmaInstr::wmma_f32_16x16x16_f16_gfx12;
#else
        return WmmaInstr::wmma_f32_16x16x16_f16;
#endif
    }

    template <>
    static constexpr auto GetWmma<bhalf_t, bhalf_t, float, 16, 16>()
    {
#ifdef __gfx12__
        return WmmaInstr::wmma_f32_16x16x16_bf16_gfx12;
#else
        return WmmaInstr::wmma_f32_16x16x16_bf16;
#endif
    }

    template <>
...
...
@@ -320,8 +449,13 @@ struct WmmaSelector
    template <>
    static constexpr auto GetWmma<int8_t, int8_t, int, 16, 16>()
    {
#ifdef __gfx12__
        return WmmaInstr::wmma_i32_16x16x16_iu8_gfx12;
#else
        return WmmaInstr::wmma_i32_16x16x16_iu8;
#endif
    }

#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
    template <>
    static constexpr auto GetWmma<int4_t, int4_t, int, 16, 16>()
...
...
@@ -502,6 +636,9 @@ struct WmmaGemm
    __device__ static auto GetSubGroupId()
    {
        static_assert(wmma_instr.num_thread_per_subgroups * wmma_instr.num_subgroups ==
                          wmma_instr.wave_size,
                      "");
        return (GetLaneId() / wmma_instr.num_thread_per_subgroups) % wmma_instr.num_subgroups;
    }
...
...
@@ -516,12 +653,20 @@ struct WmmaGemm
    __host__ __device__ static auto CalculateAThreadOriginDataIndex()
    {
#ifdef __gfx12__
        return GetLaneIdUnderSubGroup();
#else
        return TransposeC ? GetLaneIdUnderSubGroup() : GetSwizzledLaneIdLow();
#endif
    }

    __host__ __device__ static auto CalculateBThreadOriginDataIndex()
    {
#ifdef __gfx12__
        return GetLaneIdUnderSubGroup();
#else
        return TransposeC ? GetSwizzledLaneIdLow() : GetLaneIdUnderSubGroup();
#endif
    }

    __device__ static CIndex GetBeginOfThreadBlk()
...
...
include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp
View file @ 5ec6a912
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -14,12 +14,88 @@
namespace ck {
namespace tensor_operation {

// function to be used on device, emulates std::accumulate
template <typename T, typename ForwardIterator, typename Size>
__host__ __device__ auto mult_accumulate_n(ForwardIterator first, Size count, T init)
{
    for(ForwardIterator x = first; x != first + count; x++)
    {
        init *= *x;
    }
    return init;
}

template <index_t NDimSpatial, device::ConvolutionForwardSpecialization ConvForwardSpecialization>
struct TransformConvFwdToGemm
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    static long_index_t
    calculate_element_space_size_impl(const std::array<index_t, NDimSpatial + 3>& lengths,
                                      const std::array<index_t, NDimSpatial + 3>& strides,
                                      index_t i)
    {
        long_index_t acc = 1;
        for(; i < (NDimSpatial + 3); i++)
        {
            acc +=
                static_cast<long_index_t>(lengths[i] - I1) * static_cast<long_index_t>(strides[i]);
        }

        return acc;
    }

    template <typename ADataType, typename CDataType>
    static index_t GetSplitedNSize(const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
                                   const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
                                   const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
                                   const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides)
    {
        const long_index_t a_element_space_size =
            calculate_element_space_size_impl(a_g_n_c_wis_lengths, a_g_n_c_wis_strides, I1);
        const long_index_t c_element_space_size =
            calculate_element_space_size_impl(c_g_n_k_wos_lengths, c_g_n_k_wos_strides, I1);
        const long_index_t element_space_size = math::max(a_element_space_size * sizeof(ADataType),
                                                          c_element_space_size * sizeof(CDataType));
        constexpr long_index_t TwoGB = (long_index_t{1} << 31);

        const index_t N = a_g_n_c_wis_lengths[I1];

        if(element_space_size > TwoGB)
        {
            // Minimum divisor of N to not exceed 2GB
            const auto divisor = math::integer_divide_ceil(element_space_size, TwoGB);

            if(divisor <= static_cast<double>(N))
            {
                // Find least divisor of N larger than element_space_size / TwoGB
                // Iterate up to sqrt(N). There are no divisors above this value.
                for(index_t least_divisor = divisor; least_divisor * least_divisor <= N;
                    least_divisor++)
                {
                    if(N % least_divisor == 0)
                    {
                        return N / least_divisor;
                    }
                }
                // Not found, process one Convolution N per block
                return 1;
            }
            else
            {
                // Not possible to support even after split N.
                // Too large tensor.
                return N;
            }
        }
        else
        {
            // Split N is not needed.
            return N;
        }
    }

    // TODO: implement ck::tensor_layout::convolution that describe packed/strided dimension as
    // properties
    template <typename ALayout,
...
...
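`GetSplitedNSize` above looks for the smallest divisor of N such that splitting the convolution batch keeps both the input and output tensors under the 2 GB buffer-addressing limit. The snippet below is a minimal host-side paraphrase of that search, written for illustration rather than taken from the CK sources; the helper name and its simplified byte-size argument are assumptions.

```cpp
#include <cstdint>
#include <cstdio>

// Return the per-split N (N / least_divisor) such that tensor_bytes / least_divisor
// fits under 2 GB, or 1 if no divisor of N up to sqrt(N) is large enough.
int64_t split_n_for_2gb(int64_t tensor_bytes, int64_t N)
{
    const int64_t two_gb = int64_t{1} << 31;
    if(tensor_bytes <= two_gb)
        return N; // no split needed

    const int64_t min_divisor = (tensor_bytes + two_gb - 1) / two_gb; // ceil division
    if(min_divisor > N)
        return N; // splitting cannot help; the tensor is simply too large

    for(int64_t d = min_divisor; d * d <= N; ++d)
        if(N % d == 0)
            return N / d; // least divisor found

    return 1; // fall back to one convolution sample per split
}

int main()
{
    // A 6 GB tensor with N = 128: min divisor is 3, least divisor of 128 above it is 4,
    // so each split processes N = 32 samples.
    const int64_t six_gb = int64_t{6} << 30;
    std::printf("per-split N = %lld\n", static_cast<long long>(split_n_for_2gb(six_gb, 128)));
    return 0;
}
```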
@@ -38,9 +114,9 @@ struct TransformConvFwdToGemm
                        const std::array<index_t, NDimSpatial>& conv_filter_strides,
                        const std::array<index_t, NDimSpatial>& conv_filter_dilations,
                        const std::array<index_t, NDimSpatial>& input_left_pads,
                        const std::array<index_t, NDimSpatial>& input_right_pads)
                        const std::array<index_t, NDimSpatial>& input_right_pads,
                        const index_t N)
    {
        const index_t N  = a_g_n_c_wis_lengths[1];
        const index_t C  = a_g_n_c_wis_lengths[2];
        const index_t Wi = a_g_n_c_wis_lengths[3];
...
...
@@ -151,9 +227,10 @@ struct TransformConvFwdToGemm
                        const std::array<index_t, NDimSpatial>& conv_filter_strides,
                        const std::array<index_t, NDimSpatial>& conv_filter_dilations,
                        const std::array<index_t, NDimSpatial>& input_left_pads,
                        const std::array<index_t, NDimSpatial>& input_right_pads)
                        const std::array<index_t, NDimSpatial>& input_right_pads,
                        const index_t N)
    {
        const index_t N  = a_g_n_c_wis_lengths[1];
        const index_t C  = a_g_n_c_wis_lengths[2];
        const index_t Hi = a_g_n_c_wis_lengths[3];
...
...
@@ -276,13 +353,14 @@ struct TransformConvFwdToGemm
                        const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
                        const std::array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
                        const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
                        const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides
                        */,
                        const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides*/,
                        const std::array<index_t, NDimSpatial>& conv_filter_strides,
                        const std::array<index_t, NDimSpatial>& conv_filter_dilations,
                        const std::array<index_t, NDimSpatial>& input_left_pads,
                        const std::array<index_t, NDimSpatial>& input_right_pads)
                        const std::array<index_t, NDimSpatial>& input_right_pads,
                        const index_t N)
    {
        const index_t N  = a_g_n_c_wis_lengths[1];
        const index_t C  = a_g_n_c_wis_lengths[2];
        const index_t Di = a_g_n_c_wis_lengths[3];
...
...
@@ -478,9 +556,9 @@ struct TransformConvFwdToGemm
                                      bool>::type = false>
    static auto
    MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
                        const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */)
                        const std::array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */,
                        const index_t N)
    {
        const index_t N = c_g_n_k_wos_lengths[1];
        const index_t K = c_g_n_k_wos_lengths[2];
        const index_t NHoWo =
...
...
@@ -502,9 +580,9 @@ struct TransformConvFwdToGemm
                                          is_same_v<CLayout, tensor_layout::convolution::NDHWGK>,
                                      bool>::type = false>
    static auto
    MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
                        const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides)
                        const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides,
                        const index_t N)
    {
        const index_t N = c_g_n_k_wos_lengths[1];
        const index_t K = c_g_n_k_wos_lengths[2];

        const auto KStride = I1;
...
...
@@ -525,9 +603,9 @@ struct TransformConvFwdToGemm
              typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::G_K>,
                                      bool>::type = false>
    static auto
    MakeCDescriptor_M_N(const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
                        const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides)
                        const std::array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides,
                        const index_t N)
    {
        const index_t N       = c_g_n_k_wos_lengths[1];
        const index_t K       = c_g_n_k_wos_lengths[2];
        const index_t KStride = c_g_n_k_wos_strides[2];
...
...
@@ -540,6 +618,559 @@ struct TransformConvFwdToGemm
        return out_gemmm_gemmn_desc;
    }

    // Overloaded functions for hipRTC purposes
    template <typename ALayout,
              typename std::enable_if<NDimSpatial == 1 &&
                                          (is_same_v<ALayout, tensor_layout::convolution::G_NW_C> ||
                                           is_same_v<ALayout, tensor_layout::convolution::NWGC> ||
                                           is_same_v<ALayout, tensor_layout::convolution::GNWC>),
                                      bool>::type = false>
    __host__ __device__ static auto
    MakeADescriptor_M_K(const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
                        const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
                        const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
                        const ck::Array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
                        const ck::Array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
                        const ck::Array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */,
                        const ck::Array<index_t, NDimSpatial>& conv_filter_strides,
                        const ck::Array<index_t, NDimSpatial>& conv_filter_dilations,
                        const ck::Array<index_t, NDimSpatial>& input_left_pads,
                        const ck::Array<index_t, NDimSpatial>& input_right_pads)
    {
        const index_t N = a_g_n_c_wis_lengths[1];
        const index_t C = a_g_n_c_wis_lengths[2];

        const index_t Wi = a_g_n_c_wis_lengths[3];
        const index_t Wo = c_g_n_k_wos_lengths[3];

        const index_t ConvStrideW = conv_filter_strides[0];

        if constexpr(ConvForwardSpecialization ==
                     device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
        {
            const index_t NHoWo =
                N * ck::accumulate_n<index_t>(
                        c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());

            // This is different
            const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial];
            const auto CStride     = I1;

            const auto in_gemmm_gemmk_desc =
                make_naive_tensor_descriptor(make_tuple(NHoWo, C), make_tuple(WiStride, CStride));
            return in_gemmm_gemmk_desc;
        }
        else if constexpr(ConvForwardSpecialization ==
                          device::ConvolutionForwardSpecialization::Filter1x1Pad0)
        {
            // This is different
            const index_t NStride  = a_g_n_c_wis_strides[1];
            const index_t WiStride = a_g_n_c_wis_strides[3];
            const auto CStride     = I1;

            const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
                make_tuple(N, Wi, C), make_tuple(NStride, WiStride, CStride));

            const auto in_n_wo_c_desc = transform_tensor_descriptor(
                in_n_wi_c_desc,
                make_tuple(make_pass_through_transform(N),
                           make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
                           make_pass_through_transform(C)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

            const auto in_gemmm_gemmk_desc = transform_tensor_descriptor(
                in_n_wo_c_desc,
                make_tuple(make_merge_transform(make_tuple(N, Wo)), make_pass_through_transform(C)),
                make_tuple(Sequence<0, 1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            return in_gemmm_gemmk_desc;
        }
        else
        {
            const index_t X = b_g_k_c_xs_lengths[3];

            const index_t ConvDilationW = conv_filter_dilations[0];

            const index_t InLeftPadW  = input_left_pads[0];
            const index_t InRightPadW = input_right_pads[0];

            // This is different
            const index_t NStride  = a_g_n_c_wis_strides[1];
            const index_t WiStride = a_g_n_c_wis_strides[3];
            const auto CStride     = I1;

            const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
                make_tuple(N, Wi, C), make_tuple(NStride, WiStride, CStride));

            const auto in_n_wip_c_desc = transform_tensor_descriptor(
                in_n_wi_c_desc,
                make_tuple(make_pass_through_transform(N),
                           make_pad_transform(Wi, InLeftPadW, InRightPadW),
                           make_pass_through_transform(C)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

            const auto in_n_x_wo_c_desc = transform_tensor_descriptor(
                in_n_wip_c_desc,
                make_tuple(
                    make_pass_through_transform(N),
                    make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
                    make_pass_through_transform(C)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));

            const auto in_gemmm_gemmk_desc = transform_tensor_descriptor(
                in_n_x_wo_c_desc,
                make_tuple(make_merge_transform(make_tuple(N, Wo)),
                           make_merge_transform(make_tuple(X, C))),
                make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            return in_gemmm_gemmk_desc;
        }
    }
    template <typename ALayout,
              typename std::enable_if<
                  NDimSpatial == 2 &&
                      (is_same_v<ALayout, tensor_layout::convolution::G_NHW_C> ||
                       is_same_v<ALayout, tensor_layout::convolution::NHWGC> ||
                       is_same_v<ALayout, tensor_layout::convolution::GNHWC>),
                  bool>::type = false>
    __host__ __device__ static auto
    MakeADescriptor_M_K(const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
                        const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
                        const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
                        const ck::Array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
                        const ck::Array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
                        const ck::Array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */,
                        const ck::Array<index_t, NDimSpatial>& conv_filter_strides,
                        const ck::Array<index_t, NDimSpatial>& conv_filter_dilations,
                        const ck::Array<index_t, NDimSpatial>& input_left_pads,
                        const ck::Array<index_t, NDimSpatial>& input_right_pads)
    {
        const index_t N = a_g_n_c_wis_lengths[1];
        const index_t C = a_g_n_c_wis_lengths[2];

        const index_t Hi = a_g_n_c_wis_lengths[3];
        const index_t Wi = a_g_n_c_wis_lengths[4];

        const index_t Ho = c_g_n_k_wos_lengths[3];
        const index_t Wo = c_g_n_k_wos_lengths[4];

        const index_t ConvStrideH = conv_filter_strides[0];
        const index_t ConvStrideW = conv_filter_strides[1];

        if constexpr(ConvForwardSpecialization ==
                     device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
        {
            const index_t NHoWo =
                N * ck::accumulate_n<index_t>(
                        c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());

            // This is different
            const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial];
            const auto CStride     = I1;

            const auto in_gemmm_gemmk_desc =
                make_naive_tensor_descriptor(make_tuple(NHoWo, C), make_tuple(WiStride, CStride));
            return in_gemmm_gemmk_desc;
        }
        else if constexpr(ConvForwardSpecialization ==
                          device::ConvolutionForwardSpecialization::Filter1x1Pad0)
        {
            // This is different
            const index_t NStride  = a_g_n_c_wis_strides[1];
            const index_t HiStride = a_g_n_c_wis_strides[3];
            const index_t WiStride = a_g_n_c_wis_strides[4];
            const auto CStride     = I1;

            const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor(
                make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride));

            const auto in_n_ho_wo_c_desc = transform_tensor_descriptor(
                in_n_hi_wi_c_desc,
                make_tuple(make_pass_through_transform(N),
                           make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
                           make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
                           make_pass_through_transform(C)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

            const auto in_gemmm_gemmk_desc = transform_tensor_descriptor(
                in_n_ho_wo_c_desc,
                make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
                           make_pass_through_transform(C)),
                make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            return in_gemmm_gemmk_desc;
        }
        else
        {
            const index_t Y = b_g_k_c_xs_lengths[3];
            const index_t X = b_g_k_c_xs_lengths[4];

            const index_t ConvDilationH = conv_filter_dilations[0];
            const index_t ConvDilationW = conv_filter_dilations[1];

            const index_t InLeftPadH = input_left_pads[0];
            const index_t InLeftPadW = input_left_pads[1];

            const index_t InRightPadH = input_right_pads[0];
            const index_t InRightPadW = input_right_pads[1];

            // This is different
            const index_t NStride  = a_g_n_c_wis_strides[1];
            const index_t HiStride = a_g_n_c_wis_strides[3];
            const index_t WiStride = a_g_n_c_wis_strides[4];
            const auto CStride     = I1;

            const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor(
                make_tuple(N, Hi, Wi, C), make_tuple(NStride, HiStride, WiStride, CStride));

            const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
                in_n_hi_wi_c_desc,
                make_tuple(make_pass_through_transform(N),
                           make_pad_transform(Hi, InLeftPadH, InRightPadH),
                           make_pad_transform(Wi, InLeftPadW, InRightPadW),
                           make_pass_through_transform(C)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

            const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor(
                in_n_hip_wip_c_desc,
                make_tuple(
                    make_pass_through_transform(N),
                    make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
                    make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
                    make_pass_through_transform(C)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));

            const auto in_gemmm_gemmk_desc = transform_tensor_descriptor(
                in_n_y_ho_x_wo_c_desc,
                make_tuple(make_merge_transform(make_tuple(N, Ho, Wo)),
                           make_merge_transform(make_tuple(Y, X, C))),
                make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            return in_gemmm_gemmk_desc;
        }
    }
    template <typename ALayout,
              typename std::enable_if<
                  NDimSpatial == 3 &&
                      (is_same_v<ALayout, tensor_layout::convolution::G_NDHW_C> ||
                       is_same_v<ALayout, tensor_layout::convolution::NDHWGC> ||
                       is_same_v<ALayout, tensor_layout::convolution::GNDHWC>),
                  bool>::type = false>
    static auto
    MakeADescriptor_M_K(const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
                        const ck::Array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
                        const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
                        const ck::Array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */,
                        const ck::Array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
                        const ck::Array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */,
                        const ck::Array<index_t, NDimSpatial>& conv_filter_strides,
                        const ck::Array<index_t, NDimSpatial>& conv_filter_dilations,
                        const ck::Array<index_t, NDimSpatial>& input_left_pads,
                        const ck::Array<index_t, NDimSpatial>& input_right_pads)
    {
        const index_t N = a_g_n_c_wis_lengths[1];
        const index_t C = a_g_n_c_wis_lengths[2];

        const index_t Di = a_g_n_c_wis_lengths[3];
        const index_t Hi = a_g_n_c_wis_lengths[4];
        const index_t Wi = a_g_n_c_wis_lengths[5];

        const index_t Do = c_g_n_k_wos_lengths[3];
        const index_t Ho = c_g_n_k_wos_lengths[4];
        const index_t Wo = c_g_n_k_wos_lengths[5];

        const index_t ConvStrideD = conv_filter_strides[0];
        const index_t ConvStrideH = conv_filter_strides[1];
        const index_t ConvStrideW = conv_filter_strides[2];

        if constexpr(ConvForwardSpecialization ==
                     device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
        {
            const index_t NDoHoWo =
                N * ck::accumulate_n<index_t>(
                        c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1, std::multiplies<>());

            // This is different
            const index_t WiStride = a_g_n_c_wis_strides[2 + NDimSpatial];
            const auto CStride     = I1;

            const auto in_gemmm_gemmk_desc =
                make_naive_tensor_descriptor(make_tuple(NDoHoWo, C), make_tuple(WiStride, CStride));
            return in_gemmm_gemmk_desc;
        }
        else if constexpr(ConvForwardSpecialization ==
                          device::ConvolutionForwardSpecialization::Filter1x1Pad0)
        {
            // This is different
            const index_t NStride  = a_g_n_c_wis_strides[1];
            const index_t DiStride = a_g_n_c_wis_strides[3];
            const index_t HiStride = a_g_n_c_wis_strides[4];
            const index_t WiStride = a_g_n_c_wis_strides[5];
            const auto CStride     = I1;

            const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
                make_tuple(N, Di, Hi, Wi, C),
                make_tuple(NStride, DiStride, HiStride, WiStride, CStride));

            const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor(
                in_n_di_hi_wi_c_desc,
                make_tuple(make_pass_through_transform(N),
                           make_embed_transform(make_tuple(Do), make_tuple(ConvStrideD)),
                           make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
                           make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
                           make_pass_through_transform(C)),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));

            const auto in_gemmm_gemmk_desc = transform_tensor_descriptor(
                in_n_do_ho_wo_c_desc,
                make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
                           make_pass_through_transform(C)),
                make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            return in_gemmm_gemmk_desc;
        }
        else
        {
            const index_t Z = b_g_k_c_xs_lengths[3];
            const index_t Y = b_g_k_c_xs_lengths[4];
            const index_t X = b_g_k_c_xs_lengths[5];

            const index_t ConvDilationD = conv_filter_dilations[0];
            const index_t ConvDilationH = conv_filter_dilations[1];
            const index_t ConvDilationW = conv_filter_dilations[2];

            const index_t InLeftPadD = input_left_pads[0];
            const index_t InLeftPadH = input_left_pads[1];
            const index_t InLeftPadW = input_left_pads[2];

            const index_t InRightPadD = input_right_pads[0];
            const index_t InRightPadH = input_right_pads[1];
            const index_t InRightPadW = input_right_pads[2];

            // This is different
            const index_t NStride  = a_g_n_c_wis_strides[1];
            const index_t DiStride = a_g_n_c_wis_strides[3];
            const index_t HiStride = a_g_n_c_wis_strides[4];
            const index_t WiStride = a_g_n_c_wis_strides[5];
            const auto CStride     = I1;

            const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
                make_tuple(N, Di, Hi, Wi, C),
                make_tuple(NStride, DiStride, HiStride, WiStride, CStride));

            const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
                in_n_di_hi_wi_c_desc,
                make_tuple(make_pass_through_transform(N),
                           make_pad_transform(Di, InLeftPadD, InRightPadD),
                           make_pad_transform(Hi, InLeftPadH, InRightPadH),
                           make_pad_transform(Wi, InLeftPadW, InRightPadW),
                           make_pass_through_transform(C)),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));

            const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor(
                in_n_hip_wip_c_desc,
                make_tuple(
                    make_pass_through_transform(N),
                    make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)),
                    make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
                    make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
                    make_pass_through_transform(C)),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
                make_tuple(Sequence<0>{},
                           Sequence<1, 2>{},
                           Sequence<3, 4>{},
                           Sequence<5, 6>{},
                           Sequence<7>{}));

            const auto in_gemmm_gemmk_desc = transform_tensor_descriptor(
                in_n_z_do_y_ho_x_wo_c_desc,
                make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo)),
                           make_merge_transform(make_tuple(Z, Y, X, C))),
                make_tuple(Sequence<0, 2, 4, 6>{}, Sequence<1, 3, 5, 7>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            return in_gemmm_gemmk_desc;
        }
    }
    template <typename BLayout,
              typename std::enable_if<is_same_v<BLayout, tensor_layout::convolution::GKXC> ||
                                          is_same_v<BLayout, tensor_layout::convolution::GKYXC> ||
                                          is_same_v<BLayout, tensor_layout::convolution::GKZYXC>,
                                      bool>::type = false>
    __host__ __device__ static auto
    MakeBDescriptor_N_K(const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
                        const ck::Array<index_t, NDimSpatial + 3>& /* b_g_k_c_xs_strides */)
    {
        const index_t K = b_g_k_c_xs_lengths[1];
        const index_t C = b_g_k_c_xs_lengths[2];
        const index_t YX =
            mult_accumulate_n<index_t>(b_g_k_c_xs_lengths.begin() + 3, NDimSpatial, 1);

        const auto wei_gemmn_gemmk_desc =
            make_naive_tensor_descriptor_packed(make_tuple(K, YX * C));

        return wei_gemmn_gemmk_desc;
    }

    template <typename BLayout,
              typename std::enable_if<is_same_v<BLayout, tensor_layout::convolution::G_K_X_C> ||
                                          is_same_v<BLayout, tensor_layout::convolution::G_K_YX_C> ||
                                          is_same_v<BLayout, tensor_layout::convolution::G_K_ZYX_C> ||
                                          is_same_v<BLayout, tensor_layout::convolution::KXGC> ||
                                          is_same_v<BLayout, tensor_layout::convolution::KYXGC> ||
                                          is_same_v<BLayout, tensor_layout::convolution::KZYXGC>,
                                      bool>::type = false>
    __host__ __device__ static auto
    MakeBDescriptor_N_K(const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
                        const ck::Array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides)
    {
        const index_t K = b_g_k_c_xs_lengths[1];
        const index_t C = b_g_k_c_xs_lengths[2];
        const index_t YX =
            mult_accumulate_n<index_t>(b_g_k_c_xs_lengths.begin() + 3, NDimSpatial, 1);

        const index_t KStride = b_g_k_c_xs_strides[1];
        const index_t XStride = b_g_k_c_xs_strides[2 + NDimSpatial];
        const auto CStride    = I1;

        const auto wei_k_yx_c_desc = make_naive_tensor_descriptor(
            make_tuple(K, YX, C), make_tuple(KStride, XStride, CStride));

        const auto wei_gemmn_gemmk_desc = transform_tensor_descriptor(
            wei_k_yx_c_desc,
            make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(YX, C))),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        return wei_gemmn_gemmk_desc;
    }
    template <typename CLayout,
              typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::GNWK> ||
                                          is_same_v<CLayout, tensor_layout::convolution::GNHWK> ||
                                          is_same_v<CLayout, tensor_layout::convolution::GNDHWK>,
                                      bool>::type = false>
    __host__ __device__ static auto
    MakeCDescriptor_M_N(const ck::Array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
                        const ck::Array<index_t, NDimSpatial + 3>& /* c_g_n_k_wos_strides */)
    {
        const index_t N = c_g_n_k_wos_lengths[1];
        const index_t K = c_g_n_k_wos_lengths[2];

        const index_t NHoWo =
            N * mult_accumulate_n<index_t>(c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1);

        const auto out_gemmm_gemmn_desc =
            make_naive_tensor_descriptor_packed(make_tuple(NHoWo, K));

        return out_gemmm_gemmn_desc;
    }

    template <typename CLayout,
              typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::G_NW_K> ||
                                          is_same_v<CLayout, tensor_layout::convolution::G_NHW_K> ||
                                          is_same_v<CLayout, tensor_layout::convolution::G_NDHW_K> ||
                                          is_same_v<CLayout, tensor_layout::convolution::NWGK> ||
                                          is_same_v<CLayout, tensor_layout::convolution::NHWGK> ||
                                          is_same_v<CLayout, tensor_layout::convolution::NDHWGK>,
                                      bool>::type = false>
    __host__ __device__ static auto
    MakeCDescriptor_M_N(const ck::Array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
                        const ck::Array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides)
    {
        const index_t N = c_g_n_k_wos_lengths[1];
        const index_t K = c_g_n_k_wos_lengths[2];

        const auto KStride     = I1;
        const index_t WoStride = c_g_n_k_wos_strides[NDimSpatial + 2];

        const index_t NHoWo =
            N * mult_accumulate_n<index_t>(c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1);

        const auto out_gemmm_gemmn_desc =
            make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(WoStride, KStride));

        return out_gemmm_gemmn_desc;
    }

    // for output bias
    template <typename CLayout,
              typename std::enable_if<is_same_v<CLayout, tensor_layout::convolution::G_K>,
                                      bool>::type = false>
    __host__ __device__ static auto
    MakeCDescriptor_M_N(const ck::Array<index_t, NDimSpatial + 3>& c_g_n_k_wos_lengths,
                        const ck::Array<index_t, NDimSpatial + 3>& c_g_n_k_wos_strides)
    {
        const index_t N       = c_g_n_k_wos_lengths[1];
        const index_t K       = c_g_n_k_wos_lengths[2];
        const index_t KStride = c_g_n_k_wos_strides[2];

        const index_t NHoWo =
            N * mult_accumulate_n<index_t>(c_g_n_k_wos_lengths.begin() + 3, NDimSpatial, 1);

        const auto out_gemmm_gemmn_desc =
            make_naive_tensor_descriptor(make_tuple(NHoWo, K), make_tuple(I0, KStride));

        return out_gemmm_gemmn_desc;
    }
};
// wrapper class to call member functions on TransformConvToGemm struct at runtime
// TODO: figure out a way to properly pass in layout as an argument
struct TransformConv
{
    TransformConv() {}

    template <index_t NDimSpatial,
              device::ConvolutionForwardSpecialization ConvForwardSpecialization>
    auto transform_func(
        ck::Array<index_t, NDimSpatial + 3> out_lengths,
        ck::Array<index_t, NDimSpatial + 3> out_strides,
        TransformConvFwdToGemm<NDimSpatial, ConvForwardSpecialization> conv_fwd_to_gemm)
    {
        if(NDimSpatial == 2)
        {
            return conv_fwd_to_gemm
                .template MakeCDescriptor_M_N<ck::tensor_layout::convolution::NHWGK>(out_lengths,
                                                                                     out_strides);
        }
        else if(NDimSpatial == 3)
        {
            return conv_fwd_to_gemm
                .template MakeCDescriptor_M_N<tensor_layout::convolution::NDHWGK>(out_lengths,
                                                                                  out_strides);
        }
        else if(NDimSpatial == 1)
        {
            return conv_fwd_to_gemm
                .template MakeCDescriptor_M_N<tensor_layout::convolution::NWGK>(out_lengths,
                                                                                out_strides);
        }
    }
};

} // namespace tensor_operation
...
...
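For reference, the Wo/Ho/Do extents fed into `make_embed_transform` in the descriptors above follow the standard convolution output-length relation. The small standalone helper below is my addition (not part of the commit) for sanity-checking those values; the function name and parameter order are illustrative only.

```cpp
#include <cassert>
#include <cstdio>

// Standard convolution output length:
// Wo = (Wi + left_pad + right_pad - dilation * (X - 1) - 1) / stride + 1
int conv_out_length(int Wi, int X, int stride, int dilation, int left_pad, int right_pad)
{
    const int effective_filter = dilation * (X - 1) + 1;
    return (Wi + left_pad + right_pad - effective_filter) / stride + 1;
}

int main()
{
    // 1D example: Wi = 32, 3-tap filter, stride 2, dilation 1, pad 1 on both sides -> Wo = 16
    assert(conv_out_length(32, 3, 2, 1, 1, 1) == 16);
    std::printf("Wo = %d\n", conv_out_length(32, 3, 2, 1, 1, 1));
    return 0;
}
```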
include/ck/utility/amd_buffer_addressing.hpp
View file @ 5ec6a912
...
...
@@ -991,7 +991,8 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
    asm volatile("s_mov_b32 m0, %0; \n\t"
                 "buffer_load_dword %1, %2, 0 offen lds;\n\t" ::"s"(lds_ptr_sgpr),
                 "v"(global_offset_bytes),
                 "s"(src_resource));
                 "s"(src_resource)
                 : "memory");
#else
    // LDS pointer must be attributed with the LDS address space.
    __attribute__((address_space(3))) uint32_t* lds_ptr =
...
...
include/ck/utility/amd_smfmac.hpp
0 → 100644
View file @ 5ec6a912
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "ck/ck.hpp"

#pragma once

namespace ck {

template <index_t MPerWave, index_t NPerWave>
struct intrin_smfmac_f32_16x16x32f16;

template <>
struct intrin_smfmac_f32_16x16x32f16<16, 16>
{
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a,
                               const half8_t& reg_b,
                               const int32_t& reg_idx,
                               FloatC& reg_c)
    {
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_16x16x32_f16(
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], reg_idx, 0, 0);
    }
};

template <index_t MPerWave, index_t NPerWave>
struct intrin_smfmac_f32_16x16x32bf16;

template <>
struct intrin_smfmac_f32_16x16x32bf16<16, 16>
{
    template <class FloatC>
    __device__ static void Run(const bhalf4_t& reg_a,
                               const bhalf8_t& reg_b,
                               const int32_t& reg_idx,
                               FloatC& reg_c)
    {
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_16x16x32_bf16(
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], reg_idx, 0, 0);
    }
};

template <index_t MPerWave, index_t NPerWave>
struct intrin_smfmac_f32_32x32x16f16;

template <>
struct intrin_smfmac_f32_32x32x16f16<32, 32>
{
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a,
                               const half8_t& reg_b,
                               const int32_t& reg_idx,
                               FloatC& reg_c)
    {
        reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_32x32x16_f16(
            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], reg_idx, 0, 0);
    }
};

template <index_t MPerWave, index_t NPerWave>
struct intrin_smfmac_f32_32x32x16bf16;

template <>
struct intrin_smfmac_f32_32x32x16bf16<32, 32>
{
    template <class FloatC>
    __device__ static void Run(const bhalf4_t& reg_a,
                               const bhalf8_t& reg_b,
                               const int32_t& reg_idx,
                               FloatC& reg_c)
    {
        reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_32x32x16_bf16(
            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], reg_idx, 0, 0);
    }
};

} // namespace ck
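The smfmac builtins wrapped above are the sparse MFMA instructions: `reg_a` carries the two surviving values of each group of four A elements and `reg_idx` packs their original positions (2:4 structured sparsity). The host-side snippet below illustrates that compression scheme; the struct, helper name, and exact bit layout are assumptions for illustration, not the encoding the hardware consumes.

```cpp
#include <array>
#include <cstdint>
#include <cstdio>

// Compress one group of four values under 2:4 structured sparsity:
// keep the two non-zero entries and record their positions as 2-bit indices.
// NOTE: illustrative only; the packing consumed by smfmac may differ.
struct Compressed24
{
    std::array<float, 2> values;
    uint8_t indices; // bits [1:0] = position of values[0], bits [3:2] = position of values[1]
};

Compressed24 compress_2of4(const std::array<float, 4>& dense)
{
    Compressed24 out{{0.f, 0.f}, 0};
    int kept = 0;
    for(int i = 0; i < 4 && kept < 2; ++i)
    {
        if(dense[i] != 0.f)
        {
            out.values[kept] = dense[i];
            out.indices |= static_cast<uint8_t>(i) << (2 * kept);
            ++kept;
        }
    }
    return out;
}

int main()
{
    const auto c = compress_2of4({0.f, 1.5f, 0.f, -2.f});
    std::printf("kept %.1f (pos %d) and %.1f (pos %d)\n",
                c.values[0], c.indices & 0x3, c.values[1], (c.indices >> 2) & 0x3);
    return 0;
}
```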