Project: gaoqiong/composable_kernel_ROCM

Commit 72c9f129, authored Sep 20, 2024 by Jun Liu

Merge branch 'amd-develop' into amd-master

Parents: 241c261f, ded0d83d
Changes: 235 files
Showing 20 changed files with 362 additions and 127 deletions.
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp          +2   -2
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp           +5   -5
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp     +82  -34
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp  +1  -1
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp                             +30  -14
include/ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp                                  +3   -3
include/ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp                              +2   -2
include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise.hpp                              +2   -2
include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise_multi_d.hpp                      +1   -1
include/ck/tensor_operation/gpu/device/tensor_layout.hpp                                              +33  -1
include/ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp                           +3   -1
include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp                                      +52  -2
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp                                +15  -3
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp                        +18  -10
include/ck/utility/amd_buffer_addressing.hpp                                                          +47  -8
include/ck/utility/dynamic_buffer.hpp                                                                 +4   -2
include/ck_tile/core/algorithm/coordinate_transform.hpp                                               +9   -33
include/ck_tile/core/config.hpp                                                                       +9   -0
include/ck_tile/core/numeric/bfloat16.hpp                                                             +34  -0
include/ck_tile/core/numeric/math.hpp                                                                 +10  -3
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp

@@ -710,8 +710,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
             return false;
         }
-        if constexpr(!(is_NDHWGK_GKZYXC_NDHWGC<InLayout, WeiLayout, OutLayout>() ||
-                       is_GNDHWK_GKZYXC_GNDHWC<InLayout, WeiLayout, OutLayout>()))
+        if constexpr(!(is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>() ||
+                       is_GNDHWC_GKZYXC_GNDHWK<InLayout, WeiLayout, OutLayout>()))
         {
             return false;
         }
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp

@@ -586,23 +586,23 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
         }
         if constexpr(NDimSpatial == 1)
         {
-            if constexpr(!is_GNWK_GKXC_GNWC<InLayout, WeiLayout, OutLayout>())
+            if constexpr(!is_GNWC_GKXC_GNWK<InLayout, WeiLayout, OutLayout>())
             {
                 return false;
             }
         }
         else if constexpr(NDimSpatial == 2)
         {
-            if constexpr(!(is_NHWGK_GKYXC_NHWGC<InLayout, WeiLayout, OutLayout>() ||
-                           is_GNHWK_GKYXC_GNHWC<InLayout, WeiLayout, OutLayout>()))
+            if constexpr(!(is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>() ||
+                           is_GNHWC_GKYXC_GNHWK<InLayout, WeiLayout, OutLayout>()))
             {
                 return false;
             }
         }
         else if constexpr(NDimSpatial == 3)
         {
-            if constexpr(!(is_NDHWGK_GKZYXC_NDHWGC<InLayout, WeiLayout, OutLayout>() ||
-                           is_GNDHWK_GKZYXC_GNDHWC<InLayout, WeiLayout, OutLayout>()))
+            if constexpr(!(is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>() ||
+                           is_GNDHWC_GKZYXC_GNDHWK<InLayout, WeiLayout, OutLayout>()))
             {
                 return false;
             }
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp

@@ -102,10 +102,9 @@ __global__ void
     // offset base pointer for each work-group
     const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
     const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z);

-    const long_index_t e_batch_offset =
+    const long_index_t e_group_offset =
         amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx));
-    const auto& ds_batch_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx);
+    const auto& ds_group_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx);

     const long_index_t e_n_offset =
         amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));

@@ -118,14 +117,14 @@ __global__ void
         DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size();

     static_for<0, NumDTensor, 1>{}(
-        [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
+        [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_group_offset[i]; });

     if constexpr(isMultiA || isMultiB)
     {
         AsPointer p_as_grid_grp;
         BsPointer p_bs_grid_grp;

-        const auto& as_batch_offset = compute_ptr_offset_of_groups.GetAsPtrOffset(g_idx);
+        const auto& as_group_offset = compute_ptr_offset_of_groups.GetAsPtrOffset(g_idx);

         // compute_ptr_offset_of_n_ not need BatchStrideB so
         // in case of MultiA is false but isMultiB is true

@@ -136,27 +135,27 @@ __global__ void
             static constexpr index_t NumATensor = AGridDesc_AK0_M_AK1::Size();
             static_for<0, NumATensor, 1>{}([&](auto i) {
-                p_as_grid_grp(i) = p_as_grid[i] + as_batch_offset[i] + as_n_offset[i];
+                p_as_grid_grp(i) = p_as_grid[i] + as_group_offset[i] + as_n_offset[i];
             });
         }
         else
         {
             const long_index_t a_n_offset = compute_ptr_offset_of_n.GetAPtrOffset(n_idx);
             static_for<0, 1, 1>{}(
-                [&](auto i) { p_as_grid_grp(i) = p_as_grid[i] + as_batch_offset[i] + a_n_offset; });
+                [&](auto i) { p_as_grid_grp(i) = p_as_grid[i] + as_group_offset[i] + a_n_offset; });
         }

-        const auto& bs_batch_offset = compute_ptr_offset_of_groups.GetBsPtrOffset(g_idx);
+        const auto& bs_group_offset = compute_ptr_offset_of_groups.GetBsPtrOffset(g_idx);

         static constexpr index_t NumBTensor = BGridDesc_BK0_N_BK1::Size();
         static_for<0, NumBTensor, 1>{}(
-            [&](auto i) { p_bs_grid_grp(i) = p_bs_grid[i] + bs_batch_offset[i]; });
+            [&](auto i) { p_bs_grid_grp(i) = p_bs_grid[i] + bs_group_offset[i]; });

         GridwiseGemm::template Run<HasMainKBlockLoop>(p_as_grid_grp,
                                                       p_bs_grid_grp,
                                                       p_ds_grid_grp,
-                                                      p_e_grid + e_batch_offset + e_n_offset,
+                                                      p_e_grid + e_group_offset + e_n_offset,
                                                       p_shared,
                                                       a_element_op,
                                                       b_element_op,

@@ -169,19 +168,19 @@ __global__ void
     }
     else
     {
-        const long_index_t a_batch_offset =
+        const long_index_t a_group_offset =
             amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx));
-        const long_index_t b_batch_offset =
+        const long_index_t b_group_offset =
             amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx));
         const long_index_t a_n_offset =
             amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));

-        GridwiseGemm::template Run<HasMainKBlockLoop>(p_as_grid + a_batch_offset + a_n_offset,
-                                                      p_bs_grid + b_batch_offset,
+        GridwiseGemm::template Run<HasMainKBlockLoop>(p_as_grid + a_group_offset + a_n_offset,
+                                                      p_bs_grid + b_group_offset,
                                                       p_ds_grid_grp,
-                                                      p_e_grid + e_batch_offset + e_n_offset,
+                                                      p_e_grid + e_group_offset + e_n_offset,
                                                       p_shared,
                                                       a_element_op,
                                                       b_element_op,

@@ -283,7 +282,8 @@ template <index_t NDimSpatial,
           // in tuple for MultiAB), unpack if tuple was
           // passed
           typename BComputeDataType = AComputeDataType,
-          LoopScheduler LoopSched   = make_default_loop_scheduler()>
+          LoopScheduler LoopSched   = make_default_loop_scheduler(),
+          index_t NumGroupsToMerge  = 1>
struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
    : public DeviceGroupedConvFwdMultipleABD<NDimSpatial,
                                             ALayout,

@@ -302,6 +302,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{
    using DeviceOp = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle;

+    static_assert(NumGroupsToMerge >= 1);
+
    static constexpr bool isMultiA = is_detected<is_tuple, ADataType>::value;
    static constexpr bool isMultiB = is_detected<is_tuple, BDataType>::value;

@@ -318,7 +320,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                                                      ConvForwardSpecialization,
                                                      true /*SplitN*/,
                                                      ADataType,
-                                                      EDataType>;
+                                                      EDataType,
+                                                      NumGroupsToMerge>;

    static constexpr auto matrix_padder =
        MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};

@@ -517,7 +520,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
            {
                static_for<0, NumATensor, 1>{}([&](auto i) {
                    // Init compute_ptr_offset_of_groups_ for multiple AB
-                    compute_ptr_offset_of_groups_.BatchStrideA_(i) = a_g_n_c_wis_strides[0];
+                    compute_ptr_offset_of_groups_.BatchStrideA_(i) =
+                        a_g_n_c_wis_strides[0] * NumGroupsToMerge;

                    // Use GemmADataType/GemmBDataType to iterate over tuple (even if passed data
                    // type is not tuple)

@@ -545,7 +549,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                });
                static_for<0, NumBTensor, 1>{}([&](auto i) {
                    // Init compute_ptr_offset_of_groups_ for multiple AB
-                    compute_ptr_offset_of_groups_.BatchStrideB_(i) = b_g_k_c_xs_strides[0];
+                    compute_ptr_offset_of_groups_.BatchStrideB_(i) =
+                        b_g_k_c_xs_strides[0] * NumGroupsToMerge;

                    using DataType = remove_cvref_t<tuple_element_t<i.value, GemmBDataType>>;
                    // It is possible that one of the AB is a pointer and one is a tuple.

@@ -565,8 +570,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
            }
            else
            {
-                compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides[0];
-                compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides[0];
+                compute_ptr_offset_of_groups_.BatchStrideA_ =
+                    a_g_n_c_wis_strides[0] * NumGroupsToMerge;
+                compute_ptr_offset_of_groups_.BatchStrideB_ =
+                    b_g_k_c_xs_strides[0] * NumGroupsToMerge;
                compute_ptr_offset_of_n_.BatchStrideA_ =
                    a_g_n_c_wis_strides[1] * conv_N_per_block_;

                // p_as and p_bs are pointers

@@ -583,7 +590,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                p_ds_grid_(i) = static_cast<const DDataType*>(p_ds[i]);

                // D batch stride
-                compute_ptr_offset_of_groups_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0];
+                compute_ptr_offset_of_groups_.BatchStrideDs_(i) =
+                    ds_g_n_k_wos_strides[i][0] * NumGroupsToMerge;
                compute_ptr_offset_of_n_.BatchStrideDs_(i) =
                    ds_g_n_k_wos_strides[i][1] * conv_N_per_block_;

@@ -602,7 +610,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                ds_grid_desc_m_n_(i) =
                    DeviceOp::MakeEGridDescriptor_M_N<DLayout>(conv_to_gemm_transformer_d);
            });

-            compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0];
+            compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0] * NumGroupsToMerge;
            compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_;

            // populate desc for Ds/E

@@ -726,7 +734,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                arg.a_g_n_c_wis_lengths_[I1] / arg.conv_N_per_block_;
            const index_t gdx = arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
-            const index_t gdy = arg.num_group_;
+            const index_t gdy = arg.num_group_ / NumGroupsToMerge;
            const index_t gdz = num_workgroups_per_Conv_N;

            const auto K =

@@ -850,6 +858,10 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
    {
        namespace ctc = tensor_layout::convolution;

+        const index_t G = arg.b_g_k_c_xs_lengths_[I0];
+        const index_t K = arg.b_g_k_c_xs_lengths_[I1];
+        const index_t C = arg.b_g_k_c_xs_lengths_[I2];
+
        // check device
        if(get_device_name() == "gfx908")
        {

@@ -898,6 +910,42 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                }
            }
        }
+        else if constexpr(ConvForwardSpecialization ==
+                          ConvolutionForwardSpecialization::Filter3x3)
+        {
+            if(C != 1)
+            {
+                return false;
+            }
+            for(index_t i = 0; i < NDimSpatial; ++i)
+            {
+                const index_t filter_spatial_dim = arg.b_g_k_c_xs_lengths_[i + I3];
+                if(filter_spatial_dim != I3)
+                {
+                    return false;
+                }
+            }
+            if constexpr(!is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>())
+            {
+                return false;
+            }
+        }
+        if constexpr(NumGroupsToMerge > 1)
+        {
+            if(!(C == 1))
+            {
+                return false;
+            }
+            if(G % NumGroupsToMerge != 0)
+            {
+                return false;
+            }
+            if constexpr(!is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>())
+            {
+                return false;
+            }
+        }

        // check vector access of A
        // FIXME: layout

@@ -907,13 +955,18 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                     is_same_v<ALayout, ctc::NWGC> || is_same_v<ALayout, ctc::NHWGC> ||
                     is_same_v<ALayout, ctc::NDHWGC>)
        {
            const index_t C = arg.a_g_n_c_wis_lengths_[2];
+            // Check access per C
            if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0))
            {
-                return false;
+                // If not possible, check access per G
+                if(!(ABlockTransferSrcVectorDim == 1 && C == 1 &&
+                     is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>() &&
+                     G % ABlockTransferSrcScalarPerVector == 0))
+                {
+                    return false;
+                }
            }
        }
        else
        {
            return false;

@@ -928,8 +981,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                     is_same_v<BLayout, ctc::KZYXGC>)
        {
-            const index_t C = arg.b_g_k_c_xs_lengths_[2];
-
            if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0))
            {
                return false;

@@ -953,8 +1004,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                         is_same_v<DLayout, ctc::NWGK> || is_same_v<DLayout, ctc::NHWGK> ||
                         is_same_v<DLayout, ctc::NDHWGK> || is_same_v<DLayout, ctc::G_K>)
            {
-                const index_t K = arg.ds_g_n_k_wos_lengths_[i][2];
-
                if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
                {
                    valid = false;

@@ -999,8 +1048,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
                     is_same_v<ELayout, ctc::NWGK> || is_same_v<ELayout, ctc::NHWGK> ||
                     is_same_v<ELayout, ctc::NDHWGK>)
        {
-            const index_t K = arg.e_g_n_k_wos_lengths_[2];
-
            if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
            {
                return false;

@@ -1298,7 +1345,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
            << BBlockTransferSrcScalarPerVector << ", "
            << CDEBlockTransferScalarPerVector_NPerBlock << ", "
            << CShuffleMXdlPerWavePerShuffle << ", "
-            << CShuffleNXdlPerWavePerShuffle
+            << CShuffleNXdlPerWavePerShuffle << ", "
+            << NumGroupsToMerge
            << ">";
        // clang-format on
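The new NumGroupsToMerge parameter threads through the whole device op: every per-group batch stride is scaled by it, and the group dimension of the launch grid (gdy) shrinks by the same factor, so a single workgroup now walks NumGroupsToMerge consecutive groups. A hedged, standalone sketch of that bookkeeping follows; the names are illustrative, not the composable_kernel types.

#include <cassert>
#include <cstdint>

// Sketch only: how merging groups changes the launch arithmetic.
struct MergedGroupLaunch
{
    int64_t group_stride;    // elements between consecutive groups (stride of the G dimension)
    int     num_groups;      // G
    int     groups_to_merge; // NumGroupsToMerge (a compile-time constant in CK)

    // Stride one workgroup jumps by along the group axis.
    int64_t per_workgroup_stride() const { return group_stride * groups_to_merge; }

    // Number of workgroups along the group (blockIdx.y) axis.
    int grid_dim_y() const
    {
        // Mirrors the G % NumGroupsToMerge == 0 requirement in IsSupportedArgument.
        assert(num_groups % groups_to_merge == 0);
        return num_groups / groups_to_merge;
    }
};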
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp

@@ -713,7 +713,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
                return false;
            }
        }
-        if constexpr(!is_NSpatialGK_GKSpatial_NSpatialGC<ALayout, BLayout, ELayout>())
+        if constexpr(!is_NSpatialGC_GKSpatial_NSpatialGK<ALayout, BLayout, ELayout>())
        {
            return false;
        }
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp

@@ -12,7 +12,7 @@ namespace device {
 // 1d
 template <typename InLayout, typename WeiLayout, typename OutLayout>
-constexpr bool is_NWGK_GKXC_NWGC()
+constexpr bool is_NWGC_GKXC_NWGK()
 {
     return is_same_v<InLayout, tensor_layout::convolution::NWGC> &&
            is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&

@@ -20,7 +20,7 @@ constexpr bool is_NWGK_GKXC_NWGC()
 }
 template <typename InLayout, typename WeiLayout, typename OutLayout>
-constexpr bool is_GNWK_GKXC_GNWC()
+constexpr bool is_GNWC_GKXC_GNWK()
 {
     return is_same_v<InLayout, tensor_layout::convolution::GNWC> &&
            is_same_v<WeiLayout, tensor_layout::convolution::GKXC> &&

@@ -28,7 +28,7 @@ constexpr bool is_GNWK_GKXC_GNWC()
 }
 // 2d
 template <typename InLayout, typename WeiLayout, typename OutLayout>
-constexpr bool is_NHWGK_GKYXC_NHWGC()
+constexpr bool is_NHWGC_GKYXC_NHWGK()
 {
     return is_same_v<InLayout, tensor_layout::convolution::NHWGC> &&
            is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&

@@ -36,15 +36,23 @@ constexpr bool is_NHWGK_GKYXC_NHWGC()
 }
 template <typename InLayout, typename WeiLayout, typename OutLayout>
-constexpr bool is_GNHWK_GKYXC_GNHWC()
+constexpr bool is_GNHWC_GKYXC_GNHWK()
 {
     return is_same_v<InLayout, tensor_layout::convolution::GNHWC> &&
            is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&
            is_same_v<OutLayout, tensor_layout::convolution::GNHWK>;
 }
+
+template <typename InLayout, typename WeiLayout, typename OutLayout>
+constexpr bool is_NGCHW_GKYXC_NGKHW()
+{
+    return is_same_v<InLayout, tensor_layout::convolution::NGCHW> &&
+           is_same_v<WeiLayout, tensor_layout::convolution::GKYXC> &&
+           is_same_v<OutLayout, tensor_layout::convolution::NGKHW>;
+}
+
 // 3d
 template <typename InLayout, typename WeiLayout, typename OutLayout>
-constexpr bool is_NDHWGK_GKZYXC_NDHWGC()
+constexpr bool is_NDHWGC_GKZYXC_NDHWGK()
 {
     return is_same_v<InLayout, tensor_layout::convolution::NDHWGC> &&
            is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&

@@ -52,7 +60,7 @@ constexpr bool is_NDHWGK_GKZYXC_NDHWGC()
 }
 template <typename InLayout, typename WeiLayout, typename OutLayout>
-constexpr bool is_GNDHWK_GKZYXC_GNDHWC()
+constexpr bool is_GNDHWC_GKZYXC_GNDHWK()
 {
     return is_same_v<InLayout, tensor_layout::convolution::GNDHWC> &&
            is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&

@@ -60,19 +68,27 @@ constexpr bool is_GNDHWK_GKZYXC_GNDHWC()
 }
+
+template <typename InLayout, typename WeiLayout, typename OutLayout>
+constexpr bool is_NGCDHW_GKZYXC_NGKDHW()
+{
+    return is_same_v<InLayout, tensor_layout::convolution::NGCDHW> &&
+           is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
+           is_same_v<OutLayout, tensor_layout::convolution::NGKDHW>;
+}
+
 template <typename InLayout, typename WeiLayout, typename OutLayout>
-constexpr bool is_NSpatialGK_GKSpatial_NSpatialGC()
+constexpr bool is_NSpatialGC_GKSpatial_NSpatialGK()
 {
-    return is_NWGK_GKXC_NWGC<InLayout, WeiLayout, OutLayout>() ||
-           is_NHWGK_GKYXC_NHWGC<InLayout, WeiLayout, OutLayout>() ||
-           is_NDHWGK_GKZYXC_NDHWGC<InLayout, WeiLayout, OutLayout>();
+    return is_NWGC_GKXC_NWGK<InLayout, WeiLayout, OutLayout>() ||
+           is_NHWGC_GKYXC_NHWGK<InLayout, WeiLayout, OutLayout>() ||
+           is_NDHWGC_GKZYXC_NDHWGK<InLayout, WeiLayout, OutLayout>();
 }
 template <typename InLayout, typename WeiLayout, typename OutLayout>
-constexpr bool is_GNSpatialK_GKSpatial_GNSpatialC()
+constexpr bool is_GNSpatialC_GKSpatial_GNSpatialK()
 {
-    return is_GNWK_GKXC_GNWC<InLayout, WeiLayout, OutLayout>() ||
-           is_GNHWK_GKYXC_GNHWC<InLayout, WeiLayout, OutLayout>() ||
-           is_GNDHWK_GKZYXC_GNDHWC<InLayout, WeiLayout, OutLayout>();
+    return is_GNWC_GKXC_GNWK<InLayout, WeiLayout, OutLayout>() ||
+           is_GNHWC_GKYXC_GNHWK<InLayout, WeiLayout, OutLayout>() ||
+           is_GNDHWC_GKZYXC_GNDHWK<InLayout, WeiLayout, OutLayout>();
 }

 template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0, typename = void>
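These constexpr layout predicates are what the IsSupportedArgument checks in the earlier files call: a device instance compares the layout tags it was specialized for against the ones the caller passes, entirely at compile time. A hedged, self-contained sketch of the same pattern with illustrative tag types (not the CK headers):

#include <type_traits>

// Illustrative layout tags, stand-ins for ck::tensor_layout::convolution::*.
struct NHWGC {};
struct GKYXC {};
struct NHWGK {};

template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_NHWGC_GKYXC_NHWGK_sketch()
{
    return std::is_same_v<InLayout, NHWGC> &&
           std::is_same_v<WeiLayout, GKYXC> &&
           std::is_same_v<OutLayout, NHWGK>;
}

template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_supported_sketch()
{
    // Reject unsupported layout combinations at compile time, as the device
    // implementations above do inside IsSupportedArgument().
    if constexpr(!is_NHWGC_GKYXC_NHWGK_sketch<InLayout, WeiLayout, OutLayout>())
    {
        return false;
    }
    return true;
}

static_assert(is_supported_sketch<NHWGC, GKYXC, NHWGK>());
static_assert(!is_supported_sketch<NHWGK, GKYXC, NHWGC>());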
include/ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -19,7 +19,7 @@ namespace device {
 template <index_t Rank, int NumReduceDim>
 std::pair<long_index_t, long_index_t> get_2d_lengths(const std::vector<index_t>& inLengths)
 {
-    static_assert(Rank <= 6, "bigger Rank size not supported!");
+    static_assert(Rank <= 12, "bigger Rank size not supported!");

     long_index_t invariant_total_length = 1;
     long_index_t reduce_total_length    = 1;

@@ -38,7 +38,7 @@ std::pair<long_index_t, long_index_t> get_2d_lengths(const std::vector<index_t>&
 template <index_t Rank, int NumReduceDim>
 std::pair<long_index_t, long_index_t> get_2d_lengths(const std::array<index_t, Rank>& inLengths)
 {
-    static_assert(Rank <= 6, "bigger Rank size not supported!");
+    static_assert(Rank <= 12, "bigger Rank size not supported!");

     long_index_t invariant_total_length = 1;
     long_index_t reduce_total_length    = 1;
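get_2d_lengths collapses a Rank-dimensional shape into two scalars, the product of the kept (invariant) dimensions and the product of the reduced dimensions; the commit only raises the supported Rank from 6 to 12. A hedged sketch of that collapse, assuming the reduced dimensions are the trailing NumReduceDim ones, which is how the CK helper is used after dimension reordering:

#include <array>
#include <cstdint>
#include <utility>

// Sketch of get_2d_lengths: product of leading (kept) dims vs.
// product of trailing NumReduceDim (reduced) dims.
template <int Rank, int NumReduceDim>
std::pair<int64_t, int64_t> get_2d_lengths_sketch(const std::array<int, Rank>& lengths)
{
    static_assert(Rank <= 12, "bigger Rank size not supported!");

    int64_t invariant_total_length = 1;
    int64_t reduce_total_length    = 1;

    for(int i = 0; i < Rank - NumReduceDim; ++i)
        invariant_total_length *= lengths[i];

    for(int i = Rank - NumReduceDim; i < Rank; ++i)
        reduce_total_length *= lengths[i];

    return {invariant_total_length, reduce_total_length};
}

// Example: a 2x3x4x5 tensor reduced over its last two dimensions
// collapses to a 6 x 20 problem:
//   auto [inv, red] = get_2d_lengths_sketch<4, 2>({2, 3, 4, 5});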
include/ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -51,7 +51,7 @@ struct DeviceReduceMultiBlock : public DeviceReduce<InDataType,
                                                     PropagateNan,
                                                     OutputIndex>
{
-    static_assert(Rank <= 6, "Bigger Rank size is not supported!");
+    static_assert(Rank <= 12, "Bigger Rank size is not supported!");
    static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
                  "Invalid thread cluster size assignments!");
include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -47,7 +47,7 @@ struct DeviceReduceThreadWise : public DeviceReduce<InDataType,
                                                     OutputIndex>
{
-    static_assert(Rank <= 6, "Bigger Rank size is not supported!");
+    static_assert(Rank <= 12, "Bigger Rank size is not supported!");
    static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
                   (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise_multi_d.hpp

@@ -45,7 +45,7 @@ struct DeviceReduceThreadWiseMultiD : public DeviceReduceMultiD<InDataType,
                                                                 OutElementwiseOperation>
{
-    static_assert(Rank <= 6, "Bigger Rank size is not supported!");
+    static_assert(Rank <= 12, "Bigger Rank size is not supported!");
    static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
                   (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
include/ck/tensor_operation/gpu/device/tensor_layout.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -115,6 +115,23 @@ struct NDHWGC : public BaseTensorLayout
{
    static constexpr const char* name = "NDHWGC";
};

+// input tensor
+// packed NGCW/NGCHW/NGCDHW
+struct NGCW : public BaseTensorLayout
+{
+    static constexpr const char* name = "NGCW";
+};
+
+struct NGCHW : public BaseTensorLayout
+{
+    static constexpr const char* name = "NGCHW";
+};
+
+struct NGCDHW : public BaseTensorLayout
+{
+    static constexpr const char* name = "NGCDHW";
+};
+
 // input tensor
 // strided layout
 struct G_NW_C : public BaseTensorLayout

@@ -325,6 +342,21 @@ struct NDHWGK : public BaseTensorLayout
{
    static constexpr const char* name = "NDHWGK";
};

+struct NGKW : public BaseTensorLayout
+{
+    static constexpr const char* name = "NGKW";
+};
+
+struct NGKHW : public BaseTensorLayout
+{
+    static constexpr const char* name = "NGKHW";
+};
+
+struct NGKDHW : public BaseTensorLayout
+{
+    static constexpr const char* name = "NGKDHW";
+};
+
 // output tensor
 // strided layout
 struct G_NW_K : public BaseTensorLayout
include/ck/tensor_operation/gpu/element/combined_element_wise_operation.hpp

@@ -3,7 +3,6 @@
 #pragma once

 #include "ck/utility/data_type.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

 namespace ck {

@@ -107,6 +106,9 @@ struct TrinaryWithUnaryCombinedOp
     UnaryOp2 unary_op2_{};
 };

+using ScaleScalePass = UnaryCombinedOp<Scale, Scale, PassThrough>;
+using ScaleScaleRelu = UnaryCombinedOp<Scale, Scale, Relu>;
+
 } // namespace element_wise
 } // namespace tensor_operation
 } // namespace ck
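The new aliases compose two Scale operators with a final PassThrough or Relu into a single element-wise functor. A hedged sketch of roughly what such a combined Scale, Scale, Relu operator would apply per element, written with a plain struct rather than the CK operator types:

#include <algorithm>

// Sketch only: scale twice, then apply Relu, one output element at a time.
struct ScaleScaleReluSketch
{
    float scale0;
    float scale1;

    void operator()(float& y, const float& x) const
    {
        const float scaled = x * scale0 * scale1; // Scale, then Scale
        y = std::max(scaled, 0.0f);               // Relu
    }
};

// Usage: ScaleScaleReluSketch{0.5f, 2.0f}(y, x) leaves positive x unchanged
// and clamps negative x to zero.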
include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp

@@ -41,6 +41,55 @@ __global__ void
        elementwise_op);
}

+template <typename GridwiseElementwiseFunctor,
+          typename InAGridDescTuple,
+          typename InBGridDescTuple,
+          typename OutAGridDescTuple,
+          typename OutBGridDescTuple,
+          typename InDataTypePointerTuple,
+          typename OutDataTypePointerTuple,
+          typename Block2TileMapA,
+          typename Block2TileMapB,
+          typename ElementwiseOperation>
+__global__ void
+#if CK_USE_LAUNCH_BOUNDS
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+#endif
+        kernel_elementwise_dual(const InBGridDescTuple in_grid_desc_tuple_a,
+                                const InBGridDescTuple in_grid_desc_tuple_b,
+                                const OutAGridDescTuple out_grid_desc_tuple_a,
+                                const OutBGridDescTuple out_grid_desc_tuple_b,
+                                const InDataTypePointerTuple p_in_global_tuple_a,
+                                const InDataTypePointerTuple p_in_global_tuple_b,
+                                const OutDataTypePointerTuple p_out_global_tuple_a,
+                                const OutDataTypePointerTuple p_out_global_tuple_b,
+                                const Block2TileMapA block_2_tile_map_a,
+                                const Block2TileMapB block_2_tile_map_b,
+                                const ElementwiseOperation elementwise_op,
+                                const index_t a_grid_size)
+{
+    if(get_block_1d_id() < a_grid_size)
+    {
+        GridwiseElementwiseFunctor::Run(in_grid_desc_tuple_a,
+                                        out_grid_desc_tuple_a,
+                                        p_in_global_tuple_a,
+                                        p_out_global_tuple_a,
+                                        block_2_tile_map_a,
+                                        elementwise_op,
+                                        get_block_1d_id());
+    }
+    else
+    {
+        GridwiseElementwiseFunctor::Run(in_grid_desc_tuple_b,
+                                        out_grid_desc_tuple_b,
+                                        p_in_global_tuple_b,
+                                        p_out_global_tuple_b,
+                                        block_2_tile_map_b,
+                                        elementwise_op,
+                                        get_block_1d_id() - a_grid_size);
+    }
+}
+
 template <typename GridwiseElementwiseFunctor,
           typename InGridDescTuple,
           typename OutGridDescTuple,

@@ -133,7 +182,8 @@ struct GridwiseElementwise
                    const InDataTypePointerTuple& p_in_global_tuple,
                    const OutDataTypePointerTuple& p_out_global_tuple,
                    const Block2TileMap& block_2_tile_map,
-                    const ElementwiseOperation& elementwise_op)
+                    const ElementwiseOperation& elementwise_op,
+                    const index_t block_id = get_block_1d_id())
    {
        constexpr auto src_datas = generate_tuple(

@@ -169,7 +219,7 @@ struct GridwiseElementwise
            Number<NumOutput>{});

        const auto block_work_idx =
-            block_2_tile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
+            block_2_tile_map.CalculateBottomIndex(make_multi_index(block_id));

        const index_t m0_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I0] * M0PerBlock);
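kernel_elementwise_dual packs two independent element-wise problems into one launch: blocks whose id is below a_grid_size run problem A, the rest run problem B with their block id rebased to start at zero. A hedged host-side sketch of that dispatch rule, with illustrative names rather than the CK kernel:

#include <cstdio>

// Sketch of the dual-dispatch rule: one flat grid, split by a_grid_size.
void run_dual_sketch(int total_grid_size, int a_grid_size)
{
    for(int block_id = 0; block_id < total_grid_size; ++block_id)
    {
        if(block_id < a_grid_size)
        {
            // Problem A sees block ids 0 .. a_grid_size-1.
            std::printf("block %d -> problem A, local id %d\n", block_id, block_id);
        }
        else
        {
            // Problem B sees block ids rebased to start at 0.
            std::printf("block %d -> problem B, local id %d\n",
                        block_id, block_id - a_grid_size);
        }
    }
}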
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp

@@ -417,6 +417,13 @@ struct GridwiseGemm_xdl_cshuffle_v3
            }
        }();

        // pad M and N
+        return transform_tensor_descriptor(
+            c_grid_desc_mraw_nraw,
+            make_tuple(make_right_pad_transform(M, MPad - M),
+                       make_right_pad_transform(N, NPad - N)),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}));
+#if 0
        using GemmSpecialization = tensor_operation::device::GemmSpecialization;

        if constexpr(GemmSpec == GemmSpecialization::MNPadding ||

@@ -454,6 +461,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
            // not pad M or N
            return c_grid_desc_mraw_nraw;
        }
+#endif
    }

    struct Problem

@@ -953,7 +961,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
        if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
-                       GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
+                       GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) &&
+                     !(is_same<tensor_layout::gemm::RowMajor, ALayout>::value))
        {
            if(!(karg.M % MPerBlock == 0))
            {

@@ -970,7 +979,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
        if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
-                       GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
+                       GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) &&
+                     (is_same<tensor_layout::gemm::RowMajor, BLayout>::value))
        {
            if(!(karg.N % NPerBlock == 0))
            {

@@ -1105,7 +1115,9 @@ struct GridwiseGemm_xdl_cshuffle_v3
        }

        if constexpr(!(is_same<remove_cvref_t<CDataType>, half_t>::value ||
-                       is_same<remove_cvref_t<CDataType>, float>::value))
+                       is_same<remove_cvref_t<CDataType>, float>::value ||
+                       is_same<remove_cvref_t<CDataType>, bhalf_t>::value ||
+                       is_same<remove_cvref_t<CDataType>, int32_t>::value))
        {
            if(!karg.IsReduceAdd())
            {
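The C descriptor construction now always right-pads M and N (the GemmSpec-conditional variant is kept only under #if 0), while the divisibility checks on M and N are required only when padding is not requested and the layout makes the unpadded path necessary. A small sketch of the padding arithmetic itself, assuming MPad and NPad are the raw lengths rounded up to multiples of MPerBlock and NPerBlock as elsewhere in CK:

#include <cassert>

// Right-pad a raw length up to the next multiple of the per-block tile size,
// which is what make_right_pad_transform(M, MPad - M) encodes for the C descriptor.
constexpr int right_padded_length(int raw, int per_block)
{
    return ((raw + per_block - 1) / per_block) * per_block;
}

static_assert(right_padded_length(1000, 128) == 1024); // pad amount is 24
static_assert(right_padded_length(1024, 128) == 1024); // already aligned, no pad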
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp

@@ -36,10 +36,9 @@ __global__ void
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
    // __attribute__((amdgpu_waves_per_eu(1, 1)))
-    kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
+    kernel_gemm_xdl_cshuffle_v3_multi_d(typename GridwiseGemm::Argument karg)
{
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);

@@ -56,7 +55,7 @@ __global__ void
                                                  karg.c_element_op);
#else
    ignore = karg;
-#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
+#endif // end of if (defined(__gfx9__))
}

template <typename GridwiseGemm,

@@ -69,10 +68,9 @@ __global__ void
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
    // __attribute__((amdgpu_waves_per_eu(1, 1)))
-    kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
+    kernel_gemm_xdl_cshuffle_v3_multi_d_2lds(typename GridwiseGemm::Argument karg)
{
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
    // Pass two lds pointer is the key to tell compiler that ds_read/write
    // operate on different lds chunk at same time without order dependecy
    __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];

@@ -93,7 +91,7 @@ __global__ void
                                                  karg.c_element_op);
#else
    ignore = karg;
-#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
+#endif // end of if (defined(__gfx9__))
}

template <typename ALayout,

@@ -454,6 +452,13 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
            }
        }();

        // pad M and N
+        return transform_tensor_descriptor(
+            c_grid_desc_mraw_nraw,
+            make_tuple(make_right_pad_transform(M, MPad - M),
+                       make_right_pad_transform(N, NPad - N)),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}));
+#if 0
        using GemmSpecialization = tensor_operation::device::GemmSpecialization;

        if constexpr(GemmSpec == GemmSpecialization::MNPadding ||

@@ -491,6 +496,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
            // not pad M or N
            return c_grid_desc_mraw_nraw;
        }
+#endif
    }

    __host__ __device__ static auto MakeDsGridDescriptor_M_N(

@@ -1016,7 +1022,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
        if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
-                       GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
+                       GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) &&
+                     !(is_same<tensor_layout::gemm::RowMajor, ALayout>::value))
        {
            if(!(karg.M % MPerBlock == 0))
            {

@@ -1033,7 +1040,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
        if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
-                       GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
+                       GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) &&
+                     (is_same<tensor_layout::gemm::RowMajor, BLayout>::value))
        {
            if(!(karg.N % NPerBlock == 0))
            {
include/ck/utility/amd_buffer_addressing.hpp

@@ -562,6 +562,34 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
                                 dst_wave_addr_offset);
    }
}

+template <typename T, index_t N>
+__device__ void amd_global_atomic_add_impl(const typename vector_type<T, N>::type src_thread_data,
+                                           T* addr)
+{
+    static_assert((is_same<T, bhalf_t>::value && (N == 2 || N == 4 || N == 8)) ||
+                      (is_same<T, half_t>::value && (N == 2 || N == 4 || N == 8)),
+                  "wrong! not implemented");
+
+    if constexpr(is_same<T, half_t>::value)
+    {
+        vector_type<half_t, N> tmp{src_thread_data};
+        static_for<0, N / 2, 1>{}([&](auto i) {
+            __builtin_amdgcn_global_atomic_fadd_v2f16(bit_cast<half2_t*>(addr) + i,
+                                                      tmp.template AsType<half2_t>()[i]);
+        });
+    }
+#if defined(__gfx942__)
+    else if constexpr(is_same<T, bhalf_t>::value)
+    {
+        vector_type<bhalf_t, N> tmp{src_thread_data};
+        static_for<0, N / 2, 1>{}([&](auto i) {
+            __builtin_amdgcn_global_atomic_fadd_v2bf16(bit_cast<bhalf2_t*>(addr) + i,
+                                                       tmp.template AsType<bhalf2_t>()[i]);
+        });
+    }
+#endif
+}
+
 template <typename T, index_t N>
 __device__ void amd_buffer_atomic_add_impl(const typename vector_type<T, N>::type src_thread_data,
                                            int32x4_t dst_wave_buffer_resource,

@@ -907,6 +935,16 @@ amd_buffer_atomic_add(const typename vector_type_maker<T, N>::type::type src_thread_data,
    using scalar_t                = typename scalar_type<vector_t>::type;
    constexpr index_t vector_size = scalar_type<vector_t>::vector_size;

+    if constexpr(is_same<T, bhalf_t>::value)
+    {
+        if(dst_thread_element_valid)
+        {
+            amd_global_atomic_add_impl<scalar_t, vector_size>(
+                src_thread_data, p_dst_wave + dst_thread_element_offset);
+        }
+    }
+    else
+    {
#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK
    uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;

@@ -919,6 +957,7 @@ amd_buffer_atomic_add(const typename vector_type_maker<T, N>::type::type src_thread_data,
            src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
    }
#endif
+    }
}

// buffer_atomic_max requires:
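amd_global_atomic_add_impl commits half_t/bhalf_t vectors through packed 2x16-bit global atomic-add instructions, pair by pair (the bf16 flavour only on gfx942), and amd_buffer_atomic_add now routes bhalf_t through that path instead of buffer addressing. A hedged, host-only emulation of what one packed bf16x2 atomic add means, using a 32-bit compare-exchange loop as a stand-in for the GPU instruction; truncation rounding is used here for brevity and may differ from the hardware:

#include <atomic>
#include <cstdint>
#include <cstring>

static float bf16_bits_to_float(uint16_t b)
{
    uint32_t u = uint32_t(b) << 16;
    float f;
    std::memcpy(&f, &u, sizeof(f));
    return f;
}

static uint16_t float_to_bf16_bits_truncate(float f)
{
    uint32_t u;
    std::memcpy(&u, &f, sizeof(u));
    return uint16_t(u >> 16); // truncation only, for the sketch
}

// Emulate an atomic add of two packed bf16 lanes stored in one 32-bit word.
void packed_bf16x2_atomic_add_sketch(std::atomic<uint32_t>* dst, uint32_t src_packed)
{
    uint32_t expected = dst->load(std::memory_order_relaxed);
    uint32_t desired;
    do
    {
        // Unpack both 16-bit lanes, add in float, repack.
        const uint16_t lo = float_to_bf16_bits_truncate(
            bf16_bits_to_float(uint16_t(expected)) + bf16_bits_to_float(uint16_t(src_packed)));
        const uint16_t hi = float_to_bf16_bits_truncate(
            bf16_bits_to_float(uint16_t(expected >> 16)) +
            bf16_bits_to_float(uint16_t(src_packed >> 16)));
        desired = (uint32_t(hi) << 16) | lo;
    } while(!dst->compare_exchange_weak(expected, desired, std::memory_order_relaxed));
}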
include/ck/utility/dynamic_buffer.hpp

@@ -358,13 +358,15 @@ struct DynamicBuffer
            bool constexpr use_amd_buffer_addressing =
                is_same_v<remove_cvref_t<scalar_t>, int32_t> ||
                is_same_v<remove_cvref_t<scalar_t>, float> ||
-                (is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0);
+                (is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0) ||
+                (is_same_v<remove_cvref_t<scalar_t>, bhalf_t> && scalar_per_x_vector % 2 == 0);
#elif CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT)
            bool constexpr use_amd_buffer_addressing =
                is_same_v<remove_cvref_t<scalar_t>, int32_t>;
#elif(!CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
            bool constexpr use_amd_buffer_addressing =
                is_same_v<remove_cvref_t<scalar_t>, float> ||
-                (is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0);
+                (is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0) ||
+                (is_same_v<remove_cvref_t<scalar_t>, bhalf_t> && scalar_per_x_vector % 2 == 0);
#else
            bool constexpr use_amd_buffer_addressing = false;
#endif
include/ck_tile/core/algorithm/coordinate_transform.hpp

@@ -1341,7 +1341,7 @@ struct modulo : public base_transform<1, 1>
};

// 2D XOR, NOTE: "xor" is a keyword
-template <typename LowLengths, typename RightShift>
+template <typename LowLengths>
struct xor_t : public base_transform<2, 2>
{
    static constexpr auto type_enum = coord_transform_enum::xor_t;

@@ -1352,15 +1352,10 @@ struct xor_t : public base_transform<2, 2>
    using UpLengths = LowLengths;

    UpLengths up_lengths_;
-    RightShift right_shift_;

-    CK_TILE_HOST_DEVICE constexpr xor_t() : up_lengths_{}, right_shift_{} {}
+    CK_TILE_HOST_DEVICE constexpr xor_t() : up_lengths_{} {}

-    CK_TILE_HOST_DEVICE constexpr xor_t(const LowLengths& low_lengths, const RightShift& right_shift)
-        : up_lengths_{low_lengths}, right_shift_{right_shift}
-    {
-    }
+    CK_TILE_HOST_DEVICE constexpr xor_t(const LowLengths& low_lengths) : up_lengths_{low_lengths} {}

    CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
    {

@@ -1378,13 +1373,8 @@ struct xor_t : public base_transform<2, 2>
        idx_low(number<0>{}) = idx_up[number<0>{}];

-        const auto idx_low_1_tmp =
-            (idx_up[number<1>{}] - idx_up[number<0>{}] * right_shift_) % up_lengths_[number<1>{}];
-
-        const auto idx_low_1 =
-            (idx_low_1_tmp >= 0) ? idx_low_1_tmp : up_lengths_[number<1>{}] + idx_low_1_tmp;
-
-        idx_low(number<1>{}) = idx_low_1;
+        idx_low(number<1>{}) =
+            idx_up[number<1>{}] ^ (idx_up[number<0>{}] % up_lengths_[number<1>{}]);
    }

    template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>

@@ -1419,8 +1409,7 @@ struct xor_t : public base_transform<2, 2>
    CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
    {
-        return ck_tile::is_known_at_compile_time<UpLengths>::value &&
-               ck_tile::is_known_at_compile_time<RightShift>::value;
+        return ck_tile::is_known_at_compile_time<UpLengths>::value;
    }

    // MUST be static function

@@ -1432,14 +1421,6 @@ struct xor_t : public base_transform<2, 2>
        array<index_t, 2> up_vector_lengths = low_vector_lengths;
        array<index_t, 2> up_vector_strides = low_vector_strides;

-        if constexpr(ck_tile::is_known_at_compile_time<RightShift>::value)
-        {
-            if(low_vector_lengths[1] != -1)
-            {
-                up_vector_lengths(1) = gcd(low_vector_lengths[1], abs(right_shift_));
-            }
-        }
-
        return make_tuple(up_vector_lengths, up_vector_strides);
    }

@@ -1452,10 +1433,6 @@ struct xor_t : public base_transform<2, 2>
        print(up_lengths_);
        printf(", ");

-        //
-        printf("right_shift_: ");
-        print(right_shift_);
-
        printf("}");
    }
};

@@ -1655,11 +1632,10 @@ CK_TILE_HOST_DEVICE constexpr auto make_modulo_transform(const Modulus& modulus,
    return modulo<Modulus, UpLength>{modulus, up_length};
}

-template <typename LowLengths, typename RightShift>
-CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths& low_lengths,
-                                                      const RightShift& right_shift)
+template <typename LowLengths>
+CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths& low_lengths)
{
-    return xor_t<LowLengths, RightShift>{low_lengths, right_shift};
+    return xor_t<LowLengths>{low_lengths};
}

template <typename LowLength, typename OffsetLength>
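The xor_t transform no longer carries a RightShift operand: the second lower coordinate is now simply the second upper coordinate XOR-ed with (first upper coordinate modulo the row length). A small standalone sketch of that mapping, which is its own inverse because XOR with a fixed mask is an involution; the struct name and types here are illustrative:

#include <cassert>

// Sketch of the simplified xor_t lower-index calculation:
//   low0 = up0
//   low1 = up1 ^ (up0 % len1)
struct Xor2DSketch
{
    int len1; // length of the second (swizzled) dimension

    void calculate_lower_index(int (&idx_low)[2], const int (&idx_up)[2]) const
    {
        idx_low[0] = idx_up[0];
        idx_low[1] = idx_up[1] ^ (idx_up[0] % len1);
    }
};

int main()
{
    Xor2DSketch xf{8};
    int up[2] = {5, 3};
    int low[2];
    xf.calculate_lower_index(low, up);
    // Applying the same mapping again returns the original index.
    int up2[2] = {low[0], low[1]};
    int low2[2];
    xf.calculate_lower_index(low2, up2);
    assert(low2[0] == up[0] && low2[1] == up[1]);
    return 0;
}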
include/ck_tile/core/config.hpp

@@ -46,6 +46,7 @@
 #define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD 0
 #define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE_WITH_NAN 1
 #define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE 2
+#define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD_ASM 3

 #ifndef CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT
 #define CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE

@@ -156,6 +157,14 @@
 #endif
 #endif

+#ifndef CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE
+#if HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 2 && HIP_VERSION_PATCH >= 41133
+#define CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE 1
+#else
+#define CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE 0
+#endif
+#endif
+
 #ifndef CK_TILE_DEBUG_LOG
 #define CK_TILE_DEBUG_LOG 0
 #endif
include/ck_tile/core/numeric/bfloat16.hpp

@@ -17,6 +17,7 @@ enum class bf16_rounding_mode
    standard = 0, // rtn
    truncate_with_nan,
    truncate,
+    standard_asm,
};

template <bf16_rounding_mode rounding =

@@ -148,6 +149,37 @@ constexpr uint16_t float_to_bf16_rtn_raw(float f)
    return uint16_t(u.int32 >> 16);
}

+CK_TILE_HOST constexpr uint16_t float_to_bf16_rtn_asm(float f) { return float_to_bf16_rtn_raw(f); }
+
+CK_TILE_DEVICE uint16_t float_to_bf16_rtn_asm(float f)
+{
+    union
+    {
+        float fp32;
+        uint32_t int32;
+    } u = {f};
+
+    static constexpr uint32_t FP32_NAN            = 0x7fff0000;
+    static constexpr uint32_t ROUND_BIAS_FOR_BF16 = 0x7fff;
+
+    using uint32x2_t = uint32_t __attribute__((ext_vector_type(2)));
+    uint32x2_t check_nan;
+    uint32_t tmp;
+
+    asm volatile("\n \
+            v_cmp_u_f32 %0, %2, %2 \n \
+            v_bfe_u32 %1, %2, 16, 1 \n \
+            v_add3_u32 %1, %2, %1, %3 \n \
+            v_cndmask_b32 %2, %1, %4, %0 \n \
+            v_lshrrev_b32 %2, 16, %2 \n \
+            "
+                 : "=s"(check_nan), "+v"(tmp), "+v"(u.fp32)
+                 : "v"(ROUND_BIAS_FOR_BF16), "v"(FP32_NAN));
+
+    return uint16_t(u.int32);
+}
+
 // Truncate instead of rounding, preserving SNaN
 CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_truc_nan_raw(float f)

@@ -177,6 +209,8 @@ CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_raw(float f, constant<rounding>)
{
    if constexpr(rounding == bf16_rounding_mode::standard)
        return float_to_bf16_rtn_raw(f);
+    else if constexpr(rounding == bf16_rounding_mode::standard_asm)
+        return float_to_bf16_rtn_asm(f);
    else if constexpr(rounding == bf16_rounding_mode::truncate_with_nan)
        return float_to_bf16_truc_nan_raw(f);
    else
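The new standard_asm mode performs the same round-to-nearest-even conversion as float_to_bf16_rtn_raw, just spelled as a handful of VALU instructions: add 0x7FFF plus the current LSB of the would-be bf16 mantissa to the float's bit pattern, take the upper 16 bits, and replace NaN inputs with a quiet bf16 NaN. A compact sketch of that recipe in plain C++:

#include <cmath>
#include <cstdint>
#include <cstring>

// Round-to-nearest-even float -> bf16 conversion, the same arithmetic the
// inline-asm path encodes with v_bfe_u32 / v_add3_u32 / v_cndmask_b32.
uint16_t float_to_bf16_rtne_sketch(float f)
{
    uint32_t u;
    std::memcpy(&u, &f, sizeof(u));

    if(std::isnan(f))
    {
        return uint16_t(0x7fff0000u >> 16); // quiet bf16 NaN, as the asm's v_cndmask selects
    }

    const uint32_t lsb = (u >> 16) & 1u; // current LSB of the would-be bf16 mantissa
    u += 0x7fffu + lsb;                  // round half to even
    return uint16_t(u >> 16);
}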
include/ck_tile/core/numeric/math.hpp

@@ -536,13 +536,20 @@ float log(float x) { return __logf(x); };
 CK_TILE_HOST
 float log(float x) { return std::logf(x); };

 CK_TILE_DEVICE
-uint32_t sad(uint32_t x, uint32_t y, uint32_t acc)
+uint16_t sad_u16(uint16_t x, uint16_t y, uint16_t acc)
 {
     // TODO: this is hacky, we use u16
     return __builtin_amdgcn_sad_u16(x, y, acc);
 }

-CK_TILE_HOST
-uint32_t sad(uint32_t x, uint32_t y, uint32_t acc)
+CK_TILE_DEVICE
+uint32_t sad_u32(uint32_t x, uint32_t y, uint32_t acc)
+{
+    /// TODO: replace inline asm when intrinsic is available
+    uint32_t res;
+    asm volatile("v_sad_u32 %0, %1, %2, %3" : "=v"(res) : "v"(x), "v"(y), "v"(acc));
+    return res;
+}
+
+CK_TILE_HOST
+uint32_t sad_u32(uint32_t x, uint32_t y, uint32_t acc)
 {
     return (x > y ? (x - y) : (y - x)) + acc;
 }