Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
7c284291
Commit
7c284291
authored
Nov 13, 2023
by
Artur Wojcik
Browse files
Merge branch 'develop' into uif2-initial
parents
751432ca
600fc000
Changes
162
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1340 additions
and
233 deletions
+1340
-233
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp
...or_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp
+2
-1
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp
...tion/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp
+23
-9
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp
...device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp
+2
-2
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp
...vice_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp
+2
-2
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp
...ion/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp
+2
-2
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp
...ice/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp
+2
-2
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp
...vice/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp
+2
-2
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp
.../device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp
+2
-2
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
...impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
+2
-2
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp
.../impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp
+332
-150
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp
...r_operation/gpu/device/impl/device_grouped_conv_utils.hpp
+126
-3
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp
...u/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp
+10
-10
include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp
...operation/gpu/device/impl/device_image_to_column_impl.hpp
+14
-15
include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp
.../device/impl/device_normalization_bwd_gamma_beta_impl.hpp
+464
-0
include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_impl.hpp
...eration/gpu/device/impl/device_normalization_fwd_impl.hpp
+10
-10
include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_splitk_impl.hpp
.../gpu/device/impl/device_normalization_fwd_splitk_impl.hpp
+10
-10
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp
...r_operation/gpu/element/binary_element_wise_operation.hpp
+5
-2
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
...or_operation/gpu/element/unary_element_wise_operation.hpp
+51
-0
include/ck/tensor_operation/gpu/grid/gridwise_elementwise_3d.hpp
.../ck/tensor_operation/gpu/grid/gridwise_elementwise_3d.hpp
+264
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp
...tion/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp
+15
-9
No files found.
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp
View file @
7c284291
...
...
@@ -278,6 +278,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
// clang-format off
str
<<
"DeviceGemm_Xdl_CShuffle"
<<
"<"
<<
getGemmSpecializationString
(
GemmSpec
)
<<
", "
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
...
...
@@ -296,7 +297,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
<<
" LoopScheduler: "
<<
LoopSchedToString
[
LoopSched
]
<<
", "
<<
"PipelineVersion: "
<<
PipelineVersionToString
[
PipelineVer
];
;
<<
PipelineVersionToString
[
PipelineVer
];
// clang-format on
return
str
.
str
();
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp
View file @
7c284291
...
...
@@ -59,7 +59,8 @@ template <typename ADataType,
typename
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CBlockTransferScalarPerVector_NWaveNPerXDL
,
typename
ComputeType
=
CDataType
,
PipelineVersion
PipelineVer
=
PipelineVersion
::
v1
>
PipelineVersion
PipelineVer
=
PipelineVersion
::
v1
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
()>
struct
DeviceGemmXdlSplitKCShuffle
:
public
DeviceGemmSplitK
<
ALayout
,
BLayout
,
...
...
@@ -79,7 +80,6 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
// TODO: should be exposed as Tparams.
static
constexpr
index_t
NumGemmKPrefetchStage
=
1
;
static
constexpr
LoopScheduler
LoopSched
=
make_default_loop_scheduler
();
using
GridwiseGemm
=
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
<
BlockSize
,
...
...
@@ -141,7 +141,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
index_t
MPadded_
,
index_t
NPadded_
,
index_t
KPadded_
,
index_t
K0_
,
index_t
K0
Padded
_
,
index_t
k_batch_
,
AElementwiseOperation
a_element_op_
,
BElementwiseOperation
b_element_op_
,
...
...
@@ -158,7 +158,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
MPadded_
,
NPadded_
,
KPadded_
,
K0_
,
K0
Padded
_
,
k_batch_
),
a_element_op
(
a_element_op_
),
b_element_op
(
b_element_op_
),
...
...
@@ -198,9 +198,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
const
auto
b2c_map
=
DefaultBlock2CTileMap
{};
index_t
gdx
,
gdy
,
gdz
;
std
::
tie
(
gdx
,
gdy
,
gdz
)
=
b2c_map
.
CalculateGridSize
(
karg
.
M
,
karg
.
N
,
karg
.
k_batch
);
const
auto
K0
=
karg
.
K0
;
const
auto
K0
Padded
=
karg
.
K0
Padded
;
const
bool
has_main_k0_block_loop
=
GridwiseGemm
::
CalculateHasMainK0BlockLoop
(
K0
);
const
bool
has_main_k0_block_loop
=
GridwiseGemm
::
CalculateHasMainK0BlockLoop
(
K0
Padded
);
float
ave_time
=
0
;
...
...
@@ -342,7 +342,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
GridwiseGemm
::
CalculateMPadded
(
M
),
GridwiseGemm
::
CalculateNPadded
(
N
),
GridwiseGemm
::
CalculateKPadded
(
K
,
KBatch
),
GridwiseGemm
::
CalculateK0
(
K
,
KBatch
),
GridwiseGemm
::
CalculateK0
Padded
(
K
,
KBatch
),
KBatch
,
a_element_op
,
b_element_op
,
...
...
@@ -378,7 +378,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
GridwiseGemm
::
CalculateMPadded
(
M
),
GridwiseGemm
::
CalculateNPadded
(
N
),
GridwiseGemm
::
CalculateKPadded
(
K
,
KBatch
),
GridwiseGemm
::
CalculateK0
(
K
,
KBatch
),
GridwiseGemm
::
CalculateK0
Padded
(
K
,
KBatch
),
KBatch
,
a_element_op
,
b_element_op
,
...
...
@@ -392,7 +392,21 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
}
// polymorphic
std
::
string
GetTypeString
()
const
override
{
return
GridwiseGemm
::
GetTypeString
();
}
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
std
::
map
<
LoopScheduler
,
std
::
string
>
LoopSchedToString
{
{
LoopScheduler
::
Default
,
"Default"
},
{
LoopScheduler
::
Interwave
,
"Interwave"
}};
std
::
map
<
PipelineVersion
,
std
::
string
>
PipelineVersionToString
{{
PipelineVersion
::
v1
,
"v1"
},
{
PipelineVersion
::
v2
,
"v2"
}};
str
<<
GridwiseGemm
::
GetTypeString
()
<<
" LoopScheduler: "
<<
LoopSchedToString
[
LoopSched
]
<<
", PipelineVersion: "
<<
PipelineVersionToString
[
PipelineVer
];
return
str
.
str
();
}
};
}
// namespace device
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp
View file @
7c284291
...
...
@@ -517,7 +517,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
std
::
vector
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
>
block_2_ctile_map_container_
;
// for computing batch offset
ComputePtrOffsetOfStridedBatch
<
NumDTensor
>
compute_ptr_offset_of_batch_
;
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
NumDTensor
>
compute_ptr_offset_of_batch_
;
// element-wise op
AElementwiseOp
a_element_op_
;
...
...
@@ -579,7 +579,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
typename
GridwiseGemm
::
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
remove_reference_t
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
>
,
ComputePtrOffsetOfStridedBatch
<
NumDTensor
>
,
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
NumDTensor
>
,
has_main_loop
>
;
return
launch_and_time_kernel
(
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp
View file @
7c284291
...
...
@@ -677,7 +677,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
std
::
vector
<
Block2ETileMap
>
block_2_etile_map_container_
;
// for computing batch offset
ComputePtrOffsetOfStridedBatch
<
NumDTensor
>
compute_ptr_offset_of_batch_
;
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
NumDTensor
>
compute_ptr_offset_of_batch_
;
// element-wise op
AElementwiseOp
a_element_op_
;
...
...
@@ -746,7 +746,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
DeviceOp
::
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
DeviceOp
::
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
Block2ETileMap
,
ComputePtrOffsetOfStridedBatch
<
NumDTensor
>
,
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
NumDTensor
>
,
has_main_loop
>
;
return
launch_and_time_kernel
(
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp
View file @
7c284291
...
...
@@ -927,7 +927,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
Block2CTileMap
block_2_ctile_map_
;
// for computing batch offset
ComputePtrOffsetOfStridedBatch
<
I0
>
compute_ptr_offset_of_batch_
;
ComputePtrOffsetOfStridedBatch
<>
compute_ptr_offset_of_batch_
;
// element-wise op
OutElementwiseOperation
a_element_op_
;
...
...
@@ -999,7 +999,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
remove_reference_t
<
DeviceOp
::
BGridDesc_B_K0_N0_N1_K1
>
,
remove_reference_t
<
DeviceOp
::
CGridDesc_M0_M10_M11_N0_N10_N11
>
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
ComputePtrOffsetOfStridedBatch
<
I0
>
,
ComputePtrOffsetOfStridedBatch
<>
,
has_main_loop
,
has_double_loop
>
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp
View file @
7c284291
...
...
@@ -565,7 +565,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
Block2CTileMap
block_2_ctile_map_
;
// for computing batch offset
ComputePtrOffsetOfStridedBatch
<
I0
>
compute_ptr_offset_of_batch_
;
ComputePtrOffsetOfStridedBatch
<>
compute_ptr_offset_of_batch_
;
OutElementwiseOperation
a_element_op_
;
InElementwiseOperation
b_element_op_
;
...
...
@@ -647,7 +647,7 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
remove_reference_t
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
>
,
ComputePtrOffsetOfStridedBatch
<
I0
>
,
ComputePtrOffsetOfStridedBatch
<>
,
has_main_loop
>
;
using
EmptyTuple
=
Tuple
<>
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp
View file @
7c284291
...
...
@@ -1197,7 +1197,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
Block2CTileMap
block_2_ctile_map_
;
// for computing batch offset
ComputePtrOffsetOfStridedBatch
<
I0
>
compute_ptr_offset_of_batch_
;
ComputePtrOffsetOfStridedBatch
<>
compute_ptr_offset_of_batch_
;
index_t
M01_
;
index_t
N01_
;
...
...
@@ -1276,7 +1276,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
remove_reference_t
<
DeviceOp
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceOp
::
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
>
,
remove_reference_t
<
DeviceOp
::
Block2CTileMap
>
,
ComputePtrOffsetOfStridedBatch
<
I0
>
,
ComputePtrOffsetOfStridedBatch
<>
,
has_main_loop
>
;
return
launch_and_time_kernel
(
stream_config
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp
View file @
7c284291
...
...
@@ -537,7 +537,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
DefaultBlock2CTileMap
block_2_ctile_map_
;
// for computing batch offset
ComputePtrOffsetOfStridedBatch
<
NumDTensor
>
compute_ptr_offset_of_batch_
;
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
NumDTensor
>
compute_ptr_offset_of_batch_
;
// element-wise op
AElementwiseOperation
a_element_op_
;
...
...
@@ -601,7 +601,7 @@ struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
DeviceOp
::
DsGridDesc_M0_M10_M11_N0_N10_N11
,
DeviceOp
::
CGridDesc_M0_M10_M11_N0_N10_N11
,
DefaultBlock2CTileMap
,
ComputePtrOffsetOfStridedBatch
<
NumDTensor
>
,
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
NumDTensor
>
,
has_main_loop
,
has_double_loop
>
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp
View file @
7c284291
...
...
@@ -428,7 +428,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
typename
GridwiseOp
::
DefaultBlock2CTileMap
block_2_etile_map_
;
// for computing batch offset
ComputePtrOffsetOfStridedBatch
<
NumDTensor
>
compute_ptr_offset_of_batch_
;
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
NumDTensor
>
compute_ptr_offset_of_batch_
;
// element-wise op
AElementwiseOperation
a_element_op_
;
...
...
@@ -485,7 +485,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
typename
GridwiseOp
::
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseOp
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
remove_reference_t
<
typename
GridwiseOp
::
DefaultBlock2CTileMap
>
,
ComputePtrOffsetOfStridedBatch
<
NumDTensor
>
,
ComputePtrOffsetOfStridedBatch
<
I1
,
I1
,
NumDTensor
>
,
has_main_loop
>
;
return
launch_and_time_kernel
(
stream_config
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp
View file @
7c284291
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp
View file @
7c284291
...
...
@@ -9,8 +9,77 @@ namespace ck {
namespace
tensor_operation
{
namespace
device
{
template
<
index_t
Num
DTensor
>
template
<
index_t
Num
ATensor
=
1
,
index_t
NumBTensor
=
1
,
index_t
NumDTensor
=
0
,
typename
=
void
>
struct
ComputePtrOffsetOfStridedBatch
{
};
template
<
index_t
NumATensor
,
index_t
NumBTensor
,
index_t
NumDTensor
>
struct
ComputePtrOffsetOfStridedBatch
<
NumATensor
,
NumBTensor
,
NumDTensor
,
ck
::
enable_if_t
<
(
NumATensor
>
1
||
NumBTensor
>
1
)
>>
{
ComputePtrOffsetOfStridedBatch
()
=
default
;
ComputePtrOffsetOfStridedBatch
(
Array
<
ck
::
index_t
,
NumATensor
>&
BatchStrideAs
,
Array
<
ck
::
index_t
,
NumBTensor
>&
BatchStrideBs
,
Array
<
ck
::
index_t
,
NumDTensor
>&
BatchStrideDs
,
index_t
BatchStrideE
)
:
BatchStrideA_
(
BatchStrideAs
),
BatchStrideB_
(
BatchStrideBs
),
BatchStrideDs_
(
BatchStrideDs
),
BatchStrideE_
(
BatchStrideE
)
{
}
__host__
__device__
constexpr
auto
GetAsPtrOffset
(
index_t
g_idx
)
const
{
Array
<
long_index_t
,
NumATensor
>
as_offset
;
static_for
<
0
,
NumATensor
,
1
>
{}(
[
&
](
auto
i
)
{
as_offset
(
i
)
=
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideA_
[
i
]);
});
return
as_offset
;
}
__host__
__device__
constexpr
auto
GetBsPtrOffset
(
index_t
g_idx
)
const
{
Array
<
long_index_t
,
NumBTensor
>
bs_offset
;
static_for
<
0
,
NumBTensor
,
1
>
{}(
[
&
](
auto
i
)
{
bs_offset
(
i
)
=
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideB_
[
i
]);
});
return
bs_offset
;
}
__host__
__device__
constexpr
auto
GetDsPtrOffset
(
index_t
g_idx
)
const
{
Array
<
long_index_t
,
NumDTensor
>
ds_offset
;
static_for
<
0
,
NumDTensor
,
1
>
{}(
[
&
](
auto
i
)
{
ds_offset
(
i
)
=
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideDs_
[
i
]);
});
return
ds_offset
;
}
[[
maybe_unused
]]
__host__
__device__
constexpr
long_index_t
GetEPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideE_
);
}
// alias for kernels without multiple D
[[
maybe_unused
]]
__host__
__device__
constexpr
long_index_t
GetCPtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideE_
);
}
Array
<
ck
::
index_t
,
NumATensor
>
BatchStrideA_
;
Array
<
ck
::
index_t
,
NumBTensor
>
BatchStrideB_
;
Array
<
ck
::
index_t
,
NumDTensor
>
BatchStrideDs_
;
index_t
BatchStrideE_
;
index_t
&
BatchStrideC_
=
BatchStrideE_
;
// alias for kernels without multiple D
};
template
<
index_t
NumATensor
,
index_t
NumBTensor
,
index_t
NumDTensor
>
struct
ComputePtrOffsetOfStridedBatch
<
NumATensor
,
NumBTensor
,
NumDTensor
,
ck
::
enable_if_t
<
(
NumATensor
==
1
&&
NumBTensor
==
1
)
>>
{
ComputePtrOffsetOfStridedBatch
()
=
default
;
...
...
@@ -54,13 +123,67 @@ struct ComputePtrOffsetOfStridedBatch
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideE_
);
}
index_t
BatchStrideA_
;
index_t
BatchStrideB_
;
ck
::
index_t
BatchStrideA_
;
ck
::
index_t
BatchStrideB_
;
Array
<
ck
::
index_t
,
NumDTensor
>
BatchStrideDs_
;
index_t
BatchStrideE_
;
index_t
&
BatchStrideC_
=
BatchStrideE_
;
// alias for kernels without multiple D
};
template
<
bool
isTuple
,
typename
Tensors
>
constexpr
static
auto
GetNumABTensors
()
{
if
constexpr
(
isTuple
)
{
return
Number
<
Tensors
::
Size
()
>
{};
}
else
{
return
Number
<
1
>
{};
}
}
template
<
bool
isTuple
,
typename
GridwiseGemm
,
typename
DataType
>
constexpr
static
auto
GetAGridPointer
()
{
if
constexpr
(
isTuple
)
{
return
typename
GridwiseGemm
::
AsGridPointer
{};
}
else
{
return
Tuple
<
const
DataType
*>
{};
}
}
template
<
bool
isTuple
,
typename
GridwiseGemm
,
typename
DataType
>
constexpr
static
auto
GetBGridPointer
()
{
if
constexpr
(
isTuple
)
{
return
typename
GridwiseGemm
::
BsGridPointer
{};
}
else
{
return
Tuple
<
const
DataType
*>
{};
}
}
template
<
bool
isTuple
,
typename
Id
,
typename
Type
>
constexpr
static
auto
UnpackDataType
()
{
if
constexpr
(
isTuple
)
{
// unpack if tuple
return
tuple_element_t
<
Id
{},
Type
>
{};
}
else
{
// if no, return Type
return
Type
{};
}
}
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp
View file @
7c284291
...
...
@@ -265,10 +265,10 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
const
index_t
stride_b
=
gemm_descs
[
i
].
stride_B_
;
const
index_t
stride_c
=
gemm_descs
[
i
].
stride_C_
;
const
index_t
m_padded
=
GridwiseGemm
::
CalculateMPadded
(
M
);
const
index_t
n_padded
=
GridwiseGemm
::
CalculateNPadded
(
N
);
const
index_t
k_padded
=
GridwiseGemm
::
CalculateKPadded
(
K
,
K_BATCH
);
const
index_t
k0
=
GridwiseGemm
::
CalculateK0
(
K
,
K_BATCH
);
const
index_t
m_padded
=
GridwiseGemm
::
CalculateMPadded
(
M
);
const
index_t
n_padded
=
GridwiseGemm
::
CalculateNPadded
(
N
);
const
index_t
k_padded
=
GridwiseGemm
::
CalculateKPadded
(
K
,
K_BATCH
);
const
index_t
k0
_padded
=
GridwiseGemm
::
CalculateK0
Padded
(
K
,
K_BATCH
);
const
auto
c_grid_desc_m_n
=
GridwiseGemm
::
MakeCGridDescriptor_M_N
(
M
,
N
,
stride_c
);
...
...
@@ -297,7 +297,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
m_padded
,
n_padded
,
k_padded
,
k0
,
k0
_padded
,
K_BATCH
};
gemm_kernel_args_
.
emplace_back
(
...
...
@@ -320,8 +320,8 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
auto
&
karg
=
gemm_kernel_args_
[
i
].
karg_
;
const
index_t
k_padded
=
GridwiseGemm
::
CalculateKPadded
(
karg
.
K
,
K_BATCH
);
const
index_t
k0
=
GridwiseGemm
::
CalculateK0
(
karg
.
K
,
K_BATCH
);
const
index_t
k_padded
=
GridwiseGemm
::
CalculateKPadded
(
karg
.
K
,
K_BATCH
);
const
index_t
k0
_padded
=
GridwiseGemm
::
CalculateK0
Padded
(
karg
.
K
,
K_BATCH
);
const
auto
c_grid_desc_m_n
=
GridwiseGemm
::
MakeCGridDescriptor_M_N
(
karg
.
M
,
karg
.
N
,
karg
.
StrideC
);
...
...
@@ -340,7 +340,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
GroupedGemmBlock2ETileMap
(
local_b2c_tile_map
,
block_start
);
karg
.
KPadded
=
k_padded
;
karg
.
K0
=
k0
;
karg
.
K0
Padded
=
k0
_padded
;
karg
.
k_batch
=
K_BATCH
;
gemm_kernel_args_
[
i
].
block_2_ctile_map_
=
grouped_block_2_ctile_map
;
gemm_kernel_args_
[
i
].
block_start_
=
block_start
;
...
...
@@ -362,7 +362,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
{
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
index_t
K0
=
arg
.
gemm_kernel_args_
[
0
].
karg_
.
K0
;
index_t
K0
=
arg
.
gemm_kernel_args_
[
0
].
karg_
.
K0
Padded
;
bool
all_have_kbatch_gt_one
=
arg
.
gemm_kernel_args_
[
0
].
karg_
.
k_batch
>
1
;
bool
all_have_main_k0_block_loop
=
GridwiseGemm
::
CalculateHasMainK0BlockLoop
(
K0
);
...
...
@@ -384,7 +384,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
throw
std
::
runtime_error
(
err
.
str
());
}
K0
=
karg
.
K0
;
K0
=
karg
.
K0
Padded
;
bool
not_all_have_main_k0_block_loop_same
=
all_have_main_k0_block_loop
xor
GridwiseGemm
::
CalculateHasMainK0BlockLoop
(
K0
);
bool
not_all_have_kbatch_value_same
=
all_have_kbatch_gt_one
xor
(
kbatch
>
1
);
...
...
include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp
View file @
7c284291
...
...
@@ -142,19 +142,18 @@ struct DeviceImageToColumnImpl
decltype
(
BlockToCTileMap_M00_N0_M01Adapt
<
MPerBlock
,
KPerBlock
,
OutputGridDesc
>
(
OutputGridDesc
{}))
>
;
using
GridwiseTensorRearrangeKernel
=
GridwiseTensorRearrange
<
InputGridDesc
,
InputDataType
,
OutputGridDesc
,
OutputDataType
,
BlockSize
,
MPerBlock
,
KPerBlock
,
ThreadClusterLengths
,
ScalarPerVector
,
InMemoryDataOperationEnum
::
Set
,
Block2ETileMap
,
ComputePtrOffsetOfStridedBatch
<
I0
>>
;
using
GridwiseTensorRearrangeKernel
=
GridwiseTensorRearrange
<
InputGridDesc
,
InputDataType
,
OutputGridDesc
,
OutputDataType
,
BlockSize
,
MPerBlock
,
KPerBlock
,
ThreadClusterLengths
,
ScalarPerVector
,
InMemoryDataOperationEnum
::
Set
,
Block2ETileMap
,
ComputePtrOffsetOfStridedBatch
<>>
;
struct
Argument
:
public
BaseArgument
{
...
...
@@ -224,7 +223,7 @@ struct DeviceImageToColumnImpl
InputGridDesc
in_grid_desc_m_k_
;
OutputGridDesc
out_grid_desc_m_k_
;
ComputePtrOffsetOfStridedBatch
<
I0
>
compute_ptr_offset_of_batch_
;
ComputePtrOffsetOfStridedBatch
<>
compute_ptr_offset_of_batch_
;
};
struct
Invoker
:
public
BaseInvoker
...
...
@@ -246,7 +245,7 @@ struct DeviceImageToColumnImpl
OutputGridDesc
,
OutputDataType
,
Block2ETileMap
,
ComputePtrOffsetOfStridedBatch
<
I0
>
,
ComputePtrOffsetOfStridedBatch
<>
,
GridwiseTensorRearrangeKernel
>
;
float
elapsed_time
=
launch_and_time_kernel
(
stream_config
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp
0 → 100644
View file @
7c284291
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp
→
include/ck/tensor_operation/gpu/device/impl/device_normalization_
fwd_
impl.hpp
View file @
7c284291
...
...
@@ -7,7 +7,7 @@
#include <sstream>
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/device_normalization.hpp"
#include "ck/tensor_operation/gpu/device/device_normalization
_fwd
.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_selector.hpp"
...
...
@@ -46,14 +46,14 @@ template <typename XDataType,
index_t
YDstVectorSize
,
index_t
SaveMeanInvStdDstVectorSize
,
bool
UseWelford
=
true
>
struct
DeviceNormalizationImpl
:
public
DeviceNormalization
<
XDataType
,
GammaDataType
,
BetaDataType
,
YDataType
,
SaveMeanInvStdDataType
,
YElementwiseOperation
,
Rank
,
NumReduceDim
>
struct
DeviceNormalization
Fwd
Impl
:
public
DeviceNormalization
Fwd
<
XDataType
,
GammaDataType
,
BetaDataType
,
YDataType
,
SaveMeanInvStdDataType
,
YElementwiseOperation
,
Rank
,
NumReduceDim
>
{
static_assert
(
BlockSize
==
MThreadClusterSize
*
KThreadClusterSize
);
static_assert
(
...
...
@@ -461,7 +461,7 @@ struct DeviceNormalizationImpl : public DeviceNormalization<XDataType,
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"DeviceNormalizationImpl<"
<<
BlockSize
<<
","
;
str
<<
"DeviceNormalization
Fwd
Impl<"
<<
BlockSize
<<
","
;
str
<<
"Cluster_MK_"
<<
MThreadClusterSize
<<
"_"
<<
KThreadClusterSize
<<
","
;
str
<<
"Slice_MK_"
<<
MThreadSliceSize
<<
"_"
<<
KThreadSliceSize
<<
","
;
str
<<
"XYSrcVectorDim_"
<<
XYSrcVectorDim
<<
","
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_normalization_splitk_impl.hpp
→
include/ck/tensor_operation/gpu/device/impl/device_normalization_
fwd_
splitk_impl.hpp
View file @
7c284291
...
...
@@ -8,7 +8,7 @@
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/device_normalization.hpp"
#include "ck/tensor_operation/gpu/device/device_normalization
_fwd
.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/normalization/gridwise_normalization_splitk_1st.hpp"
...
...
@@ -134,14 +134,14 @@ template <typename XDataType,
index_t
BetaSrcVectorSize
,
index_t
YDstVectorSize
,
index_t
SaveMeanInvStdDstVectorSize
>
struct
DeviceNormalizationSplitKImpl
:
public
DeviceNormalization
<
XDataType
,
GammaDataType
,
BetaDataType
,
YDataType
,
SaveMeanInvStdDataType
,
YElementwiseOperation
,
Rank
,
NumReduceDim
>
struct
DeviceNormalization
Fwd
SplitKImpl
:
public
DeviceNormalization
Fwd
<
XDataType
,
GammaDataType
,
BetaDataType
,
YDataType
,
SaveMeanInvStdDataType
,
YElementwiseOperation
,
Rank
,
NumReduceDim
>
{
using
WorkspaceMeanVarDataType
=
SaveMeanInvStdDataType
;
...
...
@@ -732,7 +732,7 @@ struct DeviceNormalizationSplitKImpl : public DeviceNormalization<XDataType,
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"DeviceNormalizationSplitKImpl<"
<<
BlockSize
<<
","
;
str
<<
"DeviceNormalization
Fwd
SplitKImpl<"
<<
BlockSize
<<
","
;
str
<<
"Cluster_MK_"
<<
MThreadClusterSize
<<
"_"
<<
KThreadClusterSize
<<
","
;
str
<<
"Slice_MK_"
<<
MThreadSliceSize
<<
"_"
<<
KThreadSliceSize
<<
","
;
str
<<
"XYSrcVectorDim_"
<<
XYVectorDim
<<
","
;
...
...
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp
View file @
7c284291
...
...
@@ -85,10 +85,13 @@ struct Add
struct
ScaleAdd
{
__host__
__device__
ScaleAdd
(
float
scale
)
:
scale_
(
scale
)
{}
__host__
__device__
ScaleAdd
(
float
scale
=
1.
f
)
:
scale_
(
scale
)
{}
template
<
typename
Y
,
typename
X0
,
typename
X1
>
__host__
__device__
constexpr
void
operator
()(
Y
&
y
,
const
X0
&
x0
,
const
X1
&
x1
)
const
;
__host__
__device__
constexpr
void
operator
()(
Y
&
y
,
const
X0
&
x0
,
const
X1
&
x1
)
const
{
y
=
ck
::
type_convert
<
Y
>
(
scale_
*
ck
::
type_convert
<
float
>
(
x0
)
+
ck
::
type_convert
<
float
>
(
x1
));
}
template
<
>
__host__
__device__
void
...
...
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
View file @
7c284291
...
...
@@ -16,6 +16,57 @@ namespace element_wise {
extern
"C"
__device__
float
__ocml_native_recip_f32
(
float
);
#endif
struct
PassThroughPack2
{
template
<
typename
Y
,
typename
X
>
__host__
__device__
void
operator
()(
Y
&
y
,
const
X
&
x
)
const
;
__host__
__device__
constexpr
void
operator
()(
ck
::
f8x2_t
&
y
,
const
ck
::
half2_t
&
x
)
const
{
// fake conversion
uint16_t
t
=
ck
::
bit_cast
<
uint32_t
>
(
x
);
y
=
ck
::
bit_cast
<
ck
::
f8x2_t
>
(
t
);
}
__host__
__device__
constexpr
void
operator
()(
ck
::
half2_t
&
y
,
const
ck
::
f8x2_t
&
x
)
const
{
auto
t
=
type_convert
<
float2_t
>
(
x
);
y
=
type_convert
<
half2_t
>
(
t
);
}
__host__
__device__
constexpr
void
operator
()(
ck
::
half2_t
&
y
,
const
ck
::
half2_t
&
x
)
const
{
y
=
x
;
}
__host__
__device__
constexpr
void
operator
()(
ck
::
f8x2_t
&
y
,
const
ck
::
f8x2_t
&
x
)
const
{
y
=
x
;
}
__host__
__device__
constexpr
void
operator
()(
ck
::
float2_t
&
y
,
const
ck
::
float2_t
&
x
)
const
{
y
=
x
;
}
__host__
__device__
constexpr
void
operator
()(
ck
::
int8x2_t
&
y
,
const
ck
::
int8x2_t
&
x
)
const
{
y
=
x
;
}
__host__
__device__
constexpr
void
operator
()(
ck
::
bhalf2_t
&
y
,
const
ck
::
bhalf2_t
&
x
)
const
{
y
=
x
;
}
__host__
__device__
constexpr
void
operator
()(
ck
::
double2_t
&
y
,
const
ck
::
double2_t
&
x
)
const
{
y
=
x
;
}
constexpr
const
static
bool
is_pack2_invocable
=
true
;
};
struct
PassThrough
{
template
<
typename
Y
,
typename
X
>
...
...
include/ck/tensor_operation/gpu/grid/gridwise_elementwise_3d.hpp
0 → 100644
View file @
7c284291
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp
View file @
7c284291
This diff is collapsed.
Click to expand it.
Prev
1
2
3
4
5
6
7
8
9
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment