gaoqiong / composable_kernel / Commits

Commit 1dbdab56, authored Aug 18, 2022 by Jing Zhang

    merge develop

Parents: d2e49b23, bac7df8f
Changes: 192

Showing 20 changed files with 6519 additions and 364 deletions (+6519 / -364).
include/ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d.hpp (+64 / -0)
include/ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d_xdl_cshuffle.hpp (+1015 / -0)
include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp (+86 / -0)
include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp (+915 / -0)
include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp (+87 / -0)
include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp (+916 / -0)
include/ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp (+44 / -0)
include/ck/tensor_operation/gpu/device/device_batchnorm_infer.hpp (+41 / -0)
include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp (+0 / -247)
include/ck/tensor_operation/gpu/device/device_cgemm_4gemm_xdl_cshuffle.hpp (+86 / -91)
include/ck/tensor_operation/gpu/device/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp (+0 / -1)
include/ck/tensor_operation/gpu/device/device_convnd_bwd_weight_nwc_kxc_nwk_xdl_cshuffle.hpp (+0 / -1)
include/ck/tensor_operation/gpu/device/device_elementwise.hpp (+272 / -24)
include/ck/tensor_operation/gpu/device/device_elementwise_base.hpp (+45 / -0)
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r.hpp (+85 / -0)
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp (+873 / -0)
include/ck/tensor_operation/gpu/device/device_gemm_xdl_skip_b_lds.hpp (+523 / -0)
include/ck/tensor_operation/gpu/device/device_grouped_contraction_multiple_d.hpp (+72 / -0)
include/ck/tensor_operation/gpu/device/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp (+908 / -0)
include/ck/tensor_operation/gpu/device/device_layernorm_impl.hpp (+487 / -0)
include/ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d.hpp (new file, mode 100644, at commit 1dbdab56)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {

// Tensor Contraction:
//   input : A
//   input : B
//   input : D0, D1, ...
//   output : E
//   C = a_op(A) * b_op(B)
//   E = cde_op(C, D0, D1, ...)
// Assume:
//   A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...]
//   B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...]
//   D[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
//   E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
template <index_t NumDimG,
          index_t NumDimM,
          index_t NumDimN,
          index_t NumDimK,
          typename ADataType,
          typename BDataType,
          typename DsDataType,
          typename EDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation>
struct DeviceBatchedContractionMultipleD : public BaseOperator
{
    static constexpr index_t NumDTensor = DsDataType::Size();

    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_a,
                        const void* p_b,
                        std::array<const void*, NumDTensor> p_ds,
                        void* p_e,
                        const std::vector<index_t>& a_gs_ms_ns_lengths,
                        const std::vector<index_t>& a_gs_ms_ks_strides,
                        const std::vector<index_t>& b_gs_ns_ks_lengths,
                        const std::vector<index_t>& b_gs_ns_ks_strides,
                        const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_lengths,
                        const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_strides,
                        const std::vector<index_t>& e_gs_ms_ns_lengths,
                        const std::vector<index_t>& e_gs_ms_ns_strides,
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
                        CDEElementwiseOperation cde_element_op) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
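The interface above passes every tensor's G-, M-, N- and K-dimensions flattened into a single lengths vector and a single strides vector. The following standalone sketch (not part of the commit; all dimension counts and sizes are hypothetical) illustrates that convention and the merge of G0*G1 -> G, M0*M1 -> M, K0*K1 -> K that the concrete device ops perform internally.

#include <cstdint>
#include <iostream>
#include <vector>

int main()
{
    using index_t = int32_t;

    // Hypothetical NumDimG = 2, NumDimM = 2, NumDimK = 2, so
    // A[G0, G1, M0, M1, K0, K1] is described by 6 lengths.
    const std::vector<index_t> a_gs_ms_ks_lengths = {4, 2, 8, 16, 3, 32};

    // Packed (row-major) strides: the stride of dimension i is the product
    // of all lengths to its right.
    std::vector<index_t> a_gs_ms_ks_strides(a_gs_ms_ks_lengths.size());
    index_t stride = 1;
    for(int i = static_cast<int>(a_gs_ms_ks_lengths.size()) - 1; i >= 0; --i)
    {
        a_gs_ms_ks_strides[i] = stride;
        stride *= a_gs_ms_ks_lengths[i];
    }

    // The device op merges the G-, M- and K-groups into single extents.
    const index_t G = a_gs_ms_ks_lengths[0] * a_gs_ms_ks_lengths[1];
    const index_t M = a_gs_ms_ks_lengths[2] * a_gs_ms_ks_lengths[3];
    const index_t K = a_gs_ms_ks_lengths[4] * a_gs_ms_ks_lengths[5];

    std::cout << "G = " << G << ", M = " << M << ", K = " << K << '\n'; // G = 8, M = 128, K = 96
    return 0;
}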
include/ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d_xdl_cshuffle.hpp (new file, mode 100644, at commit 1dbdab56)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {

template <typename GridwiseGemm,
          typename FloatAB,
          typename FloatDsPointer,
          typename FloatE,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          typename AGridDesc_AK0_M_AK1,
          typename BGridDesc_BK0_N_BK1,
          typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
          typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
          typename ComputePtrOffsetOfBatch,
          typename Block2ETileMap,
          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_contraction_multiple_d_xdl_cshuffle(
            const FloatAB* __restrict__ p_a_grid,
            const FloatAB* __restrict__ p_b_grid,
            FloatDsPointer p_ds_grid,
            FloatE* __restrict__ p_e_grid,
            const index_t batch_count,
            const AElementwiseOperation a_element_op,
            const BElementwiseOperation b_element_op,
            const CDEElementwiseOperation cde_element_op,
            const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
            const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
            const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                ds_grid_desc_mblock_mperblock_nblock_nperblock,
            const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                e_grid_desc_mblock_mperblock_nblock_nperblock,
            const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
            const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    const index_t num_blocks_per_batch =
        __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
    const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);

    const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
    const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
    const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));

    const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);

    FloatDsPointer p_ds_grid_grp;

    static constexpr index_t NumDTensor =
        DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock::Size();

    static_for<0, NumDTensor, 1>{}(
        [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });

    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
                                                  p_b_grid + b_batch_offset,
                                                  p_ds_grid_grp,
                                                  p_e_grid + e_batch_offset,
                                                  p_shared,
                                                  a_element_op,
                                                  b_element_op,
                                                  cde_element_op,
                                                  a_grid_desc_ak0_m_ak1,
                                                  b_grid_desc_bk0_n_bk1,
                                                  ds_grid_desc_mblock_mperblock_nblock_nperblock,
                                                  e_grid_desc_mblock_mperblock_nblock_nperblock,
                                                  block_2_etile_map);
#else
    ignore = p_a_grid;
    ignore = p_b_grid;
    ignore = p_ds_grid;
    ignore = p_e_grid;
    ignore = batch_count;
    ignore = a_element_op;
    ignore = b_element_op;
    ignore = cde_element_op;
    ignore = a_grid_desc_ak0_m_ak1;
    ignore = b_grid_desc_bk0_n_bk1;
    ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
    ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
    ignore = block_2_etile_map;
    ignore = compute_ptr_offset_of_batch;
#endif
}

} // namespace ck
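The kernel above splits the 1-D launch grid evenly across batches: each workgroup derives its batch index g_idx from its block id, then works on one E-tile within that batch. A standalone sketch of that arithmetic (not in the commit; grid and batch sizes are hypothetical):

#include <cstdio>

int main()
{
    const int grid_size   = 24; // total number of workgroups launched
    const int batch_count = 4;  // G = G0 * G1 * ...

    const int num_blocks_per_batch = grid_size / batch_count; // 6

    for(int block_id = 0; block_id < grid_size; ++block_id)
    {
        const int g_idx         = block_id / num_blocks_per_batch; // which batch
        const int tile_in_batch = block_id % num_blocks_per_batch; // which E-tile
        std::printf("block %2d -> batch %d, tile %d\n", block_id, g_idx, tile_in_batch);
    }
    return 0;
}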
namespace ck {
namespace tensor_operation {
namespace device {

// Tensor Contraction:
//   input : A
//   input : B
//   input : D0, D1, ...
//   output : E
//   C = a_op(A) * b_op(B)
//   E = cde_op(C, D0, D1, ...)
// Assume:
//   A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...]
//   B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...]
//   D[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
//   E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
template <index_t NumDimG,
          index_t NumDimM,
          index_t NumDimN,
          index_t NumDimK,
          typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CShuffleDataType,
          typename DsDataType,
          typename EDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          GemmSpecialization GemmSpec,
          TensorSpecialization ASpec,
          TensorSpecialization BSpec,
          TensorSpecialization DESpec,
          index_t NumGemmKPrefetchStage,
          index_t BlockSize,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t AK1,
          index_t BK1,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MXdlPerWave,
          index_t NXdlPerWave,
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          bool ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          bool BBlockLdsExtraN,
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CDEBlockTransferScalarPerVector_NPerBlock,
          LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
    : public DeviceBatchedContractionMultipleD<NumDimG,
                                               NumDimM,
                                               NumDimN,
                                               NumDimK,
                                               ADataType,
                                               BDataType,
                                               DsDataType,
                                               EDataType,
                                               AElementwiseOperation,
                                               BElementwiseOperation,
                                               CDEElementwiseOperation>
{
    using DeviceOp = DeviceBatchedContractionMultipleD_Xdl_CShuffle;

    static constexpr index_t NumDTensor = DsDataType::Size();

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    static constexpr auto matrix_padder =
        MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
    // Assume: A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...]
    static auto MakeAGridDescriptor_M_K(const std::vector<index_t>& a_gs_ms_ks_lengths_vec,
                                        const std::vector<index_t>& a_gs_ms_ks_strides_vec)
    {
        assert(a_gs_ms_ks_lengths_vec.size() == NumDimG + NumDimM + NumDimK &&
               a_gs_ms_ks_strides_vec.size() == NumDimG + NumDimM + NumDimK);

        const auto to_tuple = [&](auto& vec, auto start, auto end) {
            return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
        };

        const auto a_ms_ks_lengths = to_tuple(
            a_gs_ms_ks_lengths_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimK>{});
        const auto a_ms_ks_strides = to_tuple(
            a_gs_ms_ks_strides_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimK>{});

        // dimension Ids for M0, M1, ...
        constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};

        // dimension Ids for K0, K1, ...
        constexpr auto kDimIds =
            typename arithmetic_sequence_gen<NumDimM, NumDimM + NumDimK, 1>::type{};

        // lengths for M0, M1, ...
        const auto mLengths = get_container_subset(a_ms_ks_lengths, mDimIds);

        // lengths for K0, K1, ...
        const auto kLengths = get_container_subset(a_ms_ks_lengths, kDimIds);

        if constexpr(ASpec == TensorSpecialization::Packed)
        {
            auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{});
            auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{});
            const auto a_grid_desc_mraw_kraw = make_naive_tensor_descriptor(
                make_tuple(M, K),
                make_tuple(a_ms_ks_strides[Number<NumDimM - 1>{}],
                           a_ms_ks_strides[Number<NumDimM + NumDimK - 1>{}]));
            return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
        }
        else
        {
            // naive tensor A[M0, M1, M2, ..., K0, K1, K2...]
            const auto a_grid_desc_ms_ks =
                make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides);

            // transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...]
            const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor(
                a_grid_desc_ms_ks,
                make_tuple(make_merge_transform(mLengths), make_merge_transform(kLengths)),
                make_tuple(mDimIds, kDimIds),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
        }
    }
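    // Worked example (hypothetical sizes, added for clarity): with NumDimM = 2,
    // NumDimK = 2, mLengths = (3, 4) and kLengths = (5, 2), the merge
    // transforms above collapse A to a 2-D view with MRaw = 3 * 4 = 12 and
    // KRaw = 5 * 2 = 10, which matrix_padder then pads up to multiples of
    // MPerBlock / KPerBlock.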
    // Assume: B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...]
    static auto MakeBGridDescriptor_N_K(const std::vector<index_t>& b_gs_ns_ks_lengths_vec,
                                        const std::vector<index_t>& b_gs_ns_ks_strides_vec)
    {
        assert(b_gs_ns_ks_lengths_vec.size() == NumDimG + NumDimN + NumDimK &&
               b_gs_ns_ks_strides_vec.size() == NumDimG + NumDimN + NumDimK);

        const auto to_tuple = [&](auto& vec, auto start, auto end) {
            return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
        };

        const auto b_ns_ks_lengths = to_tuple(
            b_gs_ns_ks_lengths_vec, Number<NumDimG>{}, Number<NumDimG + NumDimN + NumDimK>{});
        const auto b_ns_ks_strides = to_tuple(
            b_gs_ns_ks_strides_vec, Number<NumDimG>{}, Number<NumDimG + NumDimN + NumDimK>{});

        // dimension Ids for N0, N1, ...
        constexpr auto nDimIds = typename arithmetic_sequence_gen<0, NumDimN, 1>::type{};

        // dimension Ids for K0, K1, ...
        constexpr auto kDimIds =
            typename arithmetic_sequence_gen<NumDimN, NumDimN + NumDimK, 1>::type{};

        // lengths for K0, K1, ...
        const auto kLengths = get_container_subset(b_ns_ks_lengths, kDimIds);

        // lengths for N0, N1, ...
        const auto nLengths = get_container_subset(b_ns_ks_lengths, nDimIds);

        if constexpr(BSpec == TensorSpecialization::Packed)
        {
            auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{});
            auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{});
            const auto b_grid_desc_nraw_kraw = make_naive_tensor_descriptor(
                make_tuple(N, K),
                make_tuple(b_ns_ks_strides[Number<NumDimN - 1>{}],
                           b_ns_ks_strides[Number<NumDimN + NumDimK - 1>{}]));
            return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
        }
        else
        {
            // naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...]
            const auto b_grid_desc_ns_ks =
                make_naive_tensor_descriptor(b_ns_ks_lengths, b_ns_ks_strides);

            // transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...]
            const auto b_grid_desc_nraw_kraw = transform_tensor_descriptor(
                b_grid_desc_ns_ks,
                make_tuple(make_merge_transform(nLengths), make_merge_transform(kLengths)),
                make_tuple(nDimIds, kDimIds),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
        }
    }
    // assume E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
    static auto MakeEGridDescriptor_M_N(const std::vector<index_t>& e_gs_ms_ns_lengths_vec,
                                        const std::vector<index_t>& e_gs_ms_ns_strides_vec)
    {
        assert(e_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
               e_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN);

        const auto to_tuple = [&](auto& vec, auto start, auto end) {
            return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
        };

        const auto e_ms_ns_lengths = to_tuple(
            e_gs_ms_ns_lengths_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});
        const auto e_ms_ns_strides = to_tuple(
            e_gs_ms_ns_strides_vec, Number<NumDimG>{}, Number<NumDimG + NumDimM + NumDimN>{});

        // dimension Ids for M0, M1, ...
        constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};

        // dimension Ids for N0, N1, ...
        constexpr auto nDimIds =
            typename arithmetic_sequence_gen<NumDimM, NumDimM + NumDimN, 1>::type{};

        // lengths for M0, M1, ...
        const auto mLengths = get_container_subset(e_ms_ns_lengths, mDimIds);

        // lengths for N0, N1, ...
        const auto nLengths = get_container_subset(e_ms_ns_lengths, nDimIds);

        if constexpr(DESpec == TensorSpecialization::Packed)
        {
            auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{});
            auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{});
            const auto e_grid_desc_mraw_nraw = make_naive_tensor_descriptor(
                make_tuple(M, N),
                make_tuple(e_ms_ns_strides[Number<NumDimM - 1>{}],
                           e_ms_ns_strides[Number<NumDimM + NumDimN - 1>{}]));
            return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
        }
        else
        {
            // naive tensor E[M0, M1, M2, ..., N0, N1, N2...]
            const auto e_grid_desc_ms_ns =
                make_naive_tensor_descriptor(e_ms_ns_lengths, e_ms_ns_strides);

            // transformed tensor E[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...]
            const auto e_grid_desc_mraw_nraw = transform_tensor_descriptor(
                e_grid_desc_ms_ns,
                make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)),
                make_tuple(mDimIds, nDimIds),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
        }
    }
    // assume E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
    static auto MakeEGridDescriptor_G_M_N(const std::vector<index_t>& e_gs_ms_ns_lengths_vec,
                                          const std::vector<index_t>& e_gs_ms_ns_strides_vec)
    {
        assert(e_gs_ms_ns_lengths_vec.size() == NumDimG + NumDimM + NumDimN &&
               e_gs_ms_ns_strides_vec.size() == NumDimG + NumDimM + NumDimN);

        const auto to_tuple = [&](auto& vec, auto start, auto end) {
            return generate_tuple([&](auto i) { return vec[start + i]; }, Number<end - start>{});
        };

        const auto e_gs_ms_ns_lengths =
            to_tuple(e_gs_ms_ns_lengths_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});
        const auto e_gs_ms_ns_strides =
            to_tuple(e_gs_ms_ns_strides_vec, Number<0>{}, Number<NumDimG + NumDimM + NumDimN>{});

        // dimension Ids for G0, G1, ...
        constexpr auto gDimIds = typename arithmetic_sequence_gen<0, NumDimG, 1>::type{};

        // dimension Ids for M0, M1, ...
        constexpr auto mDimIds =
            typename arithmetic_sequence_gen<NumDimG, NumDimG + NumDimM, 1>::type{};

        // dimension Ids for N0, N1, ...
        constexpr auto nDimIds = typename arithmetic_sequence_gen<NumDimG + NumDimM,
                                                                  NumDimG + NumDimM + NumDimN,
                                                                  1>::type{};

        // lengths for G0, G1, ...
        const auto gLengths = get_container_subset(e_gs_ms_ns_lengths, gDimIds);

        // lengths for M0, M1, ...
        const auto mLengths = get_container_subset(e_gs_ms_ns_lengths, mDimIds);

        // lengths for N0, N1, ...
        const auto nLengths = get_container_subset(e_gs_ms_ns_lengths, nDimIds);

        if constexpr(DESpec == TensorSpecialization::Packed)
        {
            auto G = container_reduce(gLengths, math::multiplies{}, Number<1>{});
            auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{});
            auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{});
            const auto e_grid_desc_g_mraw_nraw = make_naive_tensor_descriptor(
                make_tuple(G, M, N),
                make_tuple(e_gs_ms_ns_strides[Number<NumDimG - 1>{}],
                           e_gs_ms_ns_strides[Number<NumDimG + NumDimM - 1>{}],
                           e_gs_ms_ns_strides[Number<NumDimG + NumDimM + NumDimN - 1>{}]));
            // return matrix_padder.PadCDescriptor_M_N(e_grid_desc_g_mraw_nraw);
            return e_grid_desc_g_mraw_nraw;
        }
        else
        {
            // naive tensor E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2...]
            const auto e_grid_desc_gs_ms_ns =
                make_naive_tensor_descriptor(e_gs_ms_ns_lengths, e_gs_ms_ns_strides);

            // transformed tensor E[G = G0 * G1 * ..., MRaw = M0 * M1 * M2 * ..., NRaw = N0 * N1 *
            // N2 * ...]
            const auto e_grid_desc_g_mraw_nraw = transform_tensor_descriptor(
                e_grid_desc_gs_ms_ns,
                make_tuple(make_merge_transform(gLengths),
                           make_merge_transform(mLengths),
                           make_merge_transform(nLengths)),
                make_tuple(gDimIds, mDimIds, nDimIds),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

            // return matrix_padder.PadCDescriptor_M_N(e_grid_desc_g_mraw_nraw);
            return e_grid_desc_g_mraw_nraw;
        }
    }
    static auto MakeDsGridDescriptor_M_N(
        const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_lengths_vec,
        const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_strides_vec)
    {
        return generate_tuple(
            [&](auto i) {
                return DeviceOp::MakeEGridDescriptor_M_N(ds_gs_ms_ns_lengths_vec[i],
                                                         ds_gs_ms_ns_strides_vec[i]);
            },
            Number<NumDTensor>{});
    }

    static auto MakeDsGridDescriptor_G_M_N(
        const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_lengths_vec,
        const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_strides_vec)
    {
        return generate_tuple(
            [&](auto i) {
                return DeviceOp::MakeEGridDescriptor_G_M_N(ds_gs_ms_ns_lengths_vec[i],
                                                           ds_gs_ms_ns_strides_vec[i]);
            },
            Number<NumDTensor>{});
    }
    using AGridDesc_M_K    = decltype(MakeAGridDescriptor_M_K({}, {}));
    using BGridDesc_N_K    = decltype(MakeBGridDescriptor_N_K({}, {}));
    using DsGridDesc_M_N   = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({{}}, {{}}))>;
    using EGridDesc_M_N    = decltype(MakeEGridDescriptor_M_N({}, {}));
    using DsGridDesc_G_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_G_M_N({}, {}))>;
    using EGridDesc_G_M_N  = decltype(MakeEGridDescriptor_G_M_N({}, {}));
    struct ComputePtrOffsetOfStridedBatch
    {
        ComputePtrOffsetOfStridedBatch(index_t batch_stride_A,
                                       index_t batch_stride_B,
                                       DsGridDesc_G_M_N ds_grid_desc_g_m_n,
                                       EGridDesc_G_M_N e_grid_desc_g_m_n)
            : batch_stride_A_(batch_stride_A),
              batch_stride_B_(batch_stride_B),
              ds_grid_desc_g_m_n_(ds_grid_desc_g_m_n),
              e_grid_desc_g_m_n_(e_grid_desc_g_m_n)
        {
        }

        __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
        {
            return g_idx * static_cast<long_index_t>(batch_stride_A_);
        }

        __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
        {
            return g_idx * static_cast<long_index_t>(batch_stride_B_);
        }

        __host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
        {
            std::array<long_index_t, NumDTensor> ds_offset;
            static_for<0, NumDTensor, 1>{}([&](auto i) {
                ds_offset[i] =
                    ds_grid_desc_g_m_n_[i].CalculateOffset(make_multi_index(g_idx, 0, 0));
            });
            return ds_offset;
        }

        __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
        {
            return e_grid_desc_g_m_n_.CalculateOffset(make_multi_index(g_idx, 0, 0));
        }

        private:
        index_t batch_stride_A_;
        index_t batch_stride_B_;
        DsGridDesc_G_M_N ds_grid_desc_g_m_n_;
        EGridDesc_G_M_N e_grid_desc_g_m_n_;
    };
    // GridwiseGemm
    using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
        ADataType, // TODO: distinguish A/B datatype
        AccDataType,
        CShuffleDataType,
        DsDataType,
        EDataType,
        AElementwiseOperation,
        BElementwiseOperation,
        CDEElementwiseOperation,
        InMemoryDataOperationEnum::Set,
        AGridDesc_M_K,
        BGridDesc_N_K,
        DsGridDesc_M_N,
        EGridDesc_M_N,
        NumGemmKPrefetchStage,
        BlockSize,
        MPerBlock,
        NPerBlock,
        KPerBlock,
        AK1,
        BK1,
        MPerXDL,
        NPerXDL,
        MXdlPerWave,
        NXdlPerWave,
        ABlockTransferThreadClusterLengths_AK0_M_AK1,
        ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder,
        ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_AK1,
        false,
        ABlockLdsExtraM,
        BBlockTransferThreadClusterLengths_BK0_N_BK1,
        BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder,
        BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_BK1,
        false,
        BBlockLdsExtraN,
        CShuffleMXdlPerWavePerShuffle,
        CShuffleNXdlPerWavePerShuffle,
        CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        CDEBlockTransferScalarPerVector_NPerBlock,
        LoopSched>;

    using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
        GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
    using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
        GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
    using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
    // Argument
    struct Argument : public BaseArgument
    {
        Argument(const void* p_a_grid,
                 const void* p_b_grid,
                 std::array<const void*, NumDTensor> p_ds_grid,
                 void* p_e_grid,
                 const std::vector<index_t>& a_gs_ms_ns_lengths,
                 const std::vector<index_t>& a_gs_ms_ks_strides,
                 const std::vector<index_t>& b_gs_ns_ks_lengths,
                 const std::vector<index_t>& b_gs_ns_ks_strides,
                 const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_lengths,
                 const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_strides,
                 const std::vector<index_t>& e_gs_ms_ns_lengths,
                 const std::vector<index_t>& e_gs_ms_ns_strides,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 CDEElementwiseOperation cde_element_op)
            : p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
              p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
              p_ds_grid_{},
              p_e_grid_{static_cast<EDataType*>(p_e_grid)},
              a_grid_desc_m_k_{
                  DeviceOp::MakeAGridDescriptor_M_K(a_gs_ms_ns_lengths, a_gs_ms_ks_strides)},
              b_grid_desc_n_k_{
                  DeviceOp::MakeBGridDescriptor_N_K(b_gs_ns_ks_lengths, b_gs_ns_ks_strides)},
              ds_grid_desc_m_n_{},
              e_grid_desc_m_n_{
                  DeviceOp::MakeEGridDescriptor_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides)},
              ds_grid_desc_g_m_n_{
                  DeviceOp::MakeDsGridDescriptor_G_M_N(ds_gs_ms_ns_lengths, ds_gs_ms_ns_strides)},
              e_grid_desc_g_m_n_{
                  DeviceOp::MakeEGridDescriptor_G_M_N(e_gs_ms_ns_lengths, e_gs_ms_ns_strides)},
              a_grid_desc_ak0_m_ak1_{
                  GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
              b_grid_desc_bk0_n_bk1_{
                  GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
              ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
              e_grid_desc_mblock_mperblock_nblock_nperblock_{},
              block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              cde_element_op_{cde_element_op},
              a_mz_stride_{},
              a_kz_stride_{},
              b_nz_stride_{},
              b_kz_stride_{},
              ds_nz_stride_{},
              e_nz_stride_{},
              a_batch_stride_{a_gs_ms_ks_strides[NumDimG - 1]},
              b_batch_stride_{b_gs_ns_ks_strides[NumDimG - 1]},
              compute_ptr_offset_of_batch_{
                  a_batch_stride_, b_batch_stride_, ds_grid_desc_g_m_n_, e_grid_desc_g_m_n_}
        {
            static_assert(NumDimG > 0 && NumDimM > 0 && NumDimN > 0 && NumDimK > 0, "");

            // populate pointer, batch stride, desc for Ds
            static_for<0, NumDTensor, 1>{}([&](auto i) {
                using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;

                // D pointer
                p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);

                // D desc
                ds_grid_desc_m_n_(i) = DeviceOp::MakeEGridDescriptor_M_N(ds_gs_ms_ns_lengths[i],
                                                                         ds_gs_ms_ns_strides[i]);
            });

            // populate desc for Ds/E
            if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_,
                                           b_grid_desc_n_k_,
                                           ds_grid_desc_m_n_,
                                           e_grid_desc_m_n_,
                                           block_2_etile_map_))
            {
                e_grid_desc_mblock_mperblock_nblock_nperblock_ =
                    GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                        e_grid_desc_m_n_);

                ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
                    GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                        ds_grid_desc_m_n_);
            }

            // for sanity check of vector memory access
            a_mz_stride_ = a_gs_ms_ks_strides[NumDimG + NumDimM - 1];
            a_kz_stride_ = a_gs_ms_ks_strides[NumDimG + NumDimM + NumDimK - 1];
            b_nz_stride_ = b_gs_ns_ks_strides[NumDimG + NumDimN - 1];
            b_kz_stride_ = b_gs_ns_ks_strides[NumDimG + NumDimN + NumDimK - 1];

            for(index_t i = 0; i < NumDTensor; ++i)
            {
                ds_nz_stride_[i] = ds_gs_ms_ns_strides[i][NumDimG + NumDimM + NumDimN - 1];
            }

            e_nz_stride_ = e_gs_ms_ns_strides[NumDimG + NumDimM + NumDimN - 1];
        }

        void Print() const
        {
            std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl;
            std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl;
            static_for<0, NumDTensor, 1>{}(
                [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; });
            std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl;
        }

        //  private:
        // pointers
        const ADataType* p_a_grid_;
        const BDataType* p_b_grid_;
        typename GridwiseGemm::DsGridPointer p_ds_grid_;
        EDataType* p_e_grid_;

        // tensor descriptors for problem definition
        AGridDesc_M_K a_grid_desc_m_k_;
        BGridDesc_N_K b_grid_desc_n_k_;
        DsGridDesc_M_N ds_grid_desc_m_n_;
        EGridDesc_M_N e_grid_desc_m_n_;
        DsGridDesc_G_M_N ds_grid_desc_g_m_n_;
        EGridDesc_G_M_N e_grid_desc_g_m_n_;

        // tensor descriptors for block/thread-wise copy
        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
        BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
        typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
            ds_grid_desc_mblock_mperblock_nblock_nperblock_;
        typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
            e_grid_desc_mblock_mperblock_nblock_nperblock_;

        // block-to-e-tile map
        Block2ETileMap block_2_etile_map_;

        // element-wise op
        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        CDEElementwiseOperation cde_element_op_;

        // Strides for the last M/N/K dimensions of A/B/Ds/E
        //   for sanity check of vector load/store
        index_t a_mz_stride_;
        index_t a_kz_stride_;
        index_t b_nz_stride_;
        index_t b_kz_stride_;
        std::array<index_t, NumDTensor> ds_nz_stride_;
        index_t e_mz_stride_;
        index_t e_nz_stride_;

        index_t a_batch_stride_;
        index_t b_batch_stride_;

        ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
    };
    // Invoker
    struct Invoker : public BaseInvoker
    {
        using Argument = DeviceOp::Argument;

        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
                                            arg.b_grid_desc_n_k_,
                                            arg.ds_grid_desc_m_n_,
                                            arg.e_grid_desc_m_n_,
                                            arg.block_2_etile_map_))
            {
                throw std::runtime_error(
                    "wrong! GridwiseGemmMultipleD_xdl_cshuffle has invalid setting");
            }

            const index_t G = arg.e_grid_desc_g_m_n_.GetLength(I0);

            const index_t grid_size =
                arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_) * G;

            const auto K =
                arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);

            auto launch_kernel = [&](auto has_main_k_block_loop) {
                constexpr bool has_main_loop = has_main_k_block_loop.value;

                const auto kernel = kernel_contraction_multiple_d_xdl_cshuffle<
                    GridwiseGemm,
                    ADataType, // TODO: distinguish A/B datatype
                    typename GridwiseGemm::DsGridPointer,
                    EDataType,
                    AElementwiseOperation,
                    BElementwiseOperation,
                    CDEElementwiseOperation,
                    DeviceOp::AGridDesc_AK0_M_AK1,
                    DeviceOp::BGridDesc_BK0_N_BK1,
                    typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                    typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                    ComputePtrOffsetOfStridedBatch,
                    typename GridwiseGemm::DefaultBlock2ETileMap,
                    has_main_loop>;

                return launch_and_time_kernel(stream_config,
                                              kernel,
                                              dim3(grid_size),
                                              dim3(BlockSize),
                                              0,
                                              arg.p_a_grid_,
                                              arg.p_b_grid_,
                                              arg.p_ds_grid_,
                                              arg.p_e_grid_,
                                              G,
                                              arg.a_element_op_,
                                              arg.b_element_op_,
                                              arg.cde_element_op_,
                                              arg.a_grid_desc_ak0_m_ak1_,
                                              arg.b_grid_desc_bk0_n_bk1_,
                                              arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
                                              arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
                                              arg.compute_ptr_offset_of_batch_,
                                              arg.block_2_etile_map_);
            };

            if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
            {
                return launch_kernel(integral_constant<bool, true>{});
            }
            else
            {
                return launch_kernel(integral_constant<bool, false>{});
            }
        }

        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };
    static bool IsSupportedArgument(const Argument& arg)
    {
        if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
        {
            return false;
        }

        if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
                                        arg.b_grid_desc_n_k_,
                                        arg.ds_grid_desc_m_n_,
                                        arg.e_grid_desc_m_n_,
                                        arg.block_2_etile_map_))
        {
            return false;
        }

        // check vector access
        static_assert((ABlockTransferSrcVectorDim == 1 || ABlockTransferSrcVectorDim == 2) &&
                          (BBlockTransferSrcVectorDim == 1 || BBlockTransferSrcVectorDim == 2),
                      "wrong!");

        // vector memory access of A: could be on M or AK1 dimension
        if constexpr(ABlockTransferSrcVectorDim == 1)
        {
            if(!(arg.a_mz_stride_ == 1 &&
                 arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) % ABlockTransferSrcScalarPerVector == 0))
            {
                return false;
            }
        }
        else
        {
            if(!(arg.a_kz_stride_ == 1 &&
                 arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0))
            {
                return false;
            }
        }

        // vector memory access of B: could be on N or BK1 dimension
        if constexpr(BBlockTransferSrcVectorDim == 1)
        {
            if(!(arg.b_nz_stride_ == 1 &&
                 arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) % BBlockTransferSrcScalarPerVector == 0))
            {
                return false;
            }
        }
        else
        {
            if(!(arg.b_kz_stride_ == 1 &&
                 arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) % BBlockTransferSrcScalarPerVector == 0))
            {
                return false;
            }
        }

        // vector memory access of Ds: always on NPerBlock dimension
        bool valid_d_access = true;

        static_for<0, NumDTensor, 1>{}([&](auto i) {
            if(!(arg.ds_nz_stride_[i] == 1 &&
                 arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[i].GetLength(I3) %
                         CDEBlockTransferScalarPerVector_NPerBlock ==
                     0))
            {
                valid_d_access = false;
            }
        });

        if(valid_d_access == false)
        {
            return false;
        }

        // vector memory access of E: always on NPerBlock dimension
        if(!((arg.e_nz_stride_ == 1 &&
              arg.e_grid_desc_mblock_mperblock_nblock_nperblock_.GetLength(I3) %
                      CDEBlockTransferScalarPerVector_NPerBlock ==
                  0) ||
             CDEBlockTransferScalarPerVector_NPerBlock == 1))
        {
            return false;
        }

        return true;
    }

    // polymorphic
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }
    static auto
    MakeArgument(const void* p_a,
                 const void* p_b,
                 std::array<const void*, NumDTensor> p_ds,
                 void* p_e,
                 const std::vector<index_t>& a_gs_ms_ns_lengths,
                 const std::vector<index_t>& a_gs_ms_ks_strides,
                 const std::vector<index_t>& b_gs_ns_ks_lengths,
                 const std::vector<index_t>& b_gs_ns_ks_strides,
                 const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_lengths,
                 const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_strides,
                 const std::vector<index_t>& e_gs_ms_ns_lengths,
                 const std::vector<index_t>& e_gs_ms_ns_strides,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 CDEElementwiseOperation cde_element_op)
    {
        return Argument{p_a,
                        p_b,
                        p_ds,
                        p_e,
                        a_gs_ms_ns_lengths,
                        a_gs_ms_ks_strides,
                        b_gs_ns_ks_lengths,
                        b_gs_ns_ks_strides,
                        ds_gs_ms_ns_lengths,
                        ds_gs_ms_ns_strides,
                        e_gs_ms_ns_lengths,
                        e_gs_ms_ns_strides,
                        a_element_op,
                        b_element_op,
                        cde_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }
    // polymorphic
    std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_a,
                        const void* p_b,
                        std::array<const void*, NumDTensor> p_ds,
                        void* p_e,
                        const std::vector<index_t>& a_gs_ms_ns_lengths,
                        const std::vector<index_t>& a_gs_ms_ks_strides,
                        const std::vector<index_t>& b_gs_ns_ks_lengths,
                        const std::vector<index_t>& b_gs_ns_ks_strides,
                        const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_lengths,
                        const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_strides,
                        const std::vector<index_t>& e_gs_ms_ns_lengths,
                        const std::vector<index_t>& e_gs_ms_ns_strides,
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
                        CDEElementwiseOperation cde_element_op) override
    {
        return std::make_unique<Argument>(p_a,
                                          p_b,
                                          p_ds,
                                          p_e,
                                          a_gs_ms_ns_lengths,
                                          a_gs_ms_ks_strides,
                                          b_gs_ns_ks_lengths,
                                          b_gs_ns_ks_strides,
                                          ds_gs_ms_ns_lengths,
                                          ds_gs_ms_ns_strides,
                                          e_gs_ms_ns_lengths,
                                          e_gs_ms_ns_strides,
                                          a_element_op,
                                          b_element_op,
                                          cde_element_op);
    }

    // polymorphic
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }
    // polymorphic
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "DeviceBatchedContractionMultipleD_Xdl_CShuffle"
            << "<"
            << NumDimG << ", "
            << NumDimM << ", "
            << NumDimN << ", "
            << NumDimK << ", "
            << BlockSize << ", "
            << MPerBlock << ", "
            << NPerBlock << ", "
            << KPerBlock << ", "
            << AK1 << ", "
            << BK1 << ", "
            << ABlockTransferSrcVectorDim << ", "
            << BBlockTransferSrcVectorDim
            << ">";
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
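For reference, the computation this device op performs can be stated as a naive CPU loop. The sketch below (not in the commit; all sizes are hypothetical, dimensions are already flattened to G/M/N/K, a_op and b_op are identity, and cde_op adds a single D tensor) mirrors C = a_op(A) * b_op(B) followed by E = cde_op(C, D0, ...).

#include <iostream>
#include <vector>

int main()
{
    const int G = 2, M = 4, N = 3, K = 5;

    // Packed layouts: A[G, M, K], B[G, N, K], D[G, M, N], E[G, M, N].
    std::vector<float> a(G * M * K, 1.0f);
    std::vector<float> b(G * N * K, 2.0f);
    std::vector<float> d(G * M * N, 0.5f);
    std::vector<float> e(G * M * N, 0.0f);

    for(int g = 0; g < G; ++g)
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
            {
                // C = a_op(A) * b_op(B), with identity element-wise ops
                float c = 0.0f;
                for(int k = 0; k < K; ++k)
                    c += a[(g * M + m) * K + k] * b[(g * N + n) * K + k];

                // E = cde_op(C, D0), here cde_op(c, d0) = c + d0
                e[(g * M + m) * N + n] = c + d[(g * M + m) * N + n];
            }

    std::cout << "E[0] = " << e[0] << '\n'; // 1 * 2 * K + 0.5 = 10.5
    return 0;
}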
include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp (new file, mode 100644, at commit 1dbdab56)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {

template <typename ALayout,
          typename B0Layout,
          typename B1Layout,
          typename CLayout,
          typename ADataType,
          typename B0DataType,
          typename B1DataType,
          typename CDataType,
          typename AElementwiseOperation,
          typename B0ElementwiseOperation,
          typename Acc0ElementwiseOperation,
          typename B1ElementwiseOperation,
          typename CElementwiseOperation>
struct DeviceBatchedGemmGemm : public BaseOperator
{
    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_a,
                        const void* p_b0,
                        const void* p_b1,
                        void* p_c,
                        ck::index_t M,
                        ck::index_t N,
                        ck::index_t K,
                        ck::index_t O,
                        ck::index_t Batch,
                        ck::index_t StrideA,
                        ck::index_t StrideB0,
                        ck::index_t StrideB1,
                        ck::index_t StrideC,
                        ck::index_t BatchStrideA,
                        ck::index_t BatchStrideB0,
                        ck::index_t BatchStrideB1,
                        ck::index_t BatchStrideC,
                        AElementwiseOperation a_element_op,
                        B0ElementwiseOperation b0_element_op,
                        Acc0ElementwiseOperation acc0_element_op,
                        B1ElementwiseOperation b1_element_op,
                        CElementwiseOperation c_element_op) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

template <typename ALayout,
          typename B0Layout,
          typename B1Layout,
          typename CLayout,
          typename ADataType,
          typename B0DataType,
          typename B1DataType,
          typename CDataType,
          typename AElementwiseOperation,
          typename B0ElementwiseOperation,
          typename Acc0ElementwiseOperation,
          typename B1ElementwiseOperation,
          typename CElementwiseOperation>
using DeviceBatchedGemmGemmPtr = std::unique_ptr<DeviceBatchedGemmGemm<ALayout,
                                                                       B0Layout,
                                                                       B1Layout,
                                                                       CLayout,
                                                                       ADataType,
                                                                       B0DataType,
                                                                       B1DataType,
                                                                       CDataType,
                                                                       AElementwiseOperation,
                                                                       B0ElementwiseOperation,
                                                                       Acc0ElementwiseOperation,
                                                                       B1ElementwiseOperation,
                                                                       CElementwiseOperation>>;

} // namespace device
} // namespace tensor_operation
} // namespace ck
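The M/N/K/O problem sizes in this interface describe a fused pair of GEMMs: C[M, O] = (A[M, K] * B0[K, N]) * B1[N, O], where the first product (Acc0) feeds the second GEMM. A CPU-only reference of that computation (not in the commit; sizes are hypothetical and element-wise ops are identity):

#include <iostream>
#include <vector>

int main()
{
    const int M = 4, N = 3, K = 5, O = 2;

    std::vector<float> a(M * K, 1.0f), b0(K * N, 1.0f), b1(N * O, 1.0f);
    std::vector<float> acc0(M * N, 0.0f), c(M * O, 0.0f);

    // Gemm0: Acc0 = A * B0
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
            for(int k = 0; k < K; ++k)
                acc0[m * N + n] += a[m * K + k] * b0[k * N + n];

    // Gemm1: C = Acc0 * B1
    for(int m = 0; m < M; ++m)
        for(int o = 0; o < O; ++o)
            for(int n = 0; n < N; ++n)
                c[m * O + o] += acc0[m * N + n] * b1[n * O + o];

    std::cout << "C[0] = " << c[0] << '\n'; // K * N = 15
    return 0;
}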
include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp (new file, mode 100644, at commit 1dbdab56)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {

template <typename GridwiseGemm,
          typename FloatAB,
          typename FloatC,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename AccElementwiseOperation,
          typename B1ElementwiseOperation,
          typename CElementwiseOperation,
          typename AGridDesc_AK0_M_AK1,
          typename BGridDesc_BK0_N_BK1,
          typename B1GridDesc_BK0_N_BK1,
          typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
          typename Block2CTileMap,
          typename ComputeBasePtrOfStridedBatch,
          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_gemm_gemm_xdl_cshuffle_v1(
            const FloatAB* __restrict__ p_a_grid,
            const FloatAB* __restrict__ p_b_grid,
            const FloatAB* __restrict__ p_b1_grid,
            FloatC* __restrict__ p_c_grid,
            const AElementwiseOperation a_element_op,
            const BElementwiseOperation b_element_op,
            const AccElementwiseOperation acc_element_op,
            const B1ElementwiseOperation b1_element_op,
            const CElementwiseOperation c_element_op,
            const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
            const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
            const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
            const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                c_grid_desc_mblock_mperblock_nblock_nperblock,
            const Block2CTileMap block_2_ctile_map,
            const index_t batch_count,
            const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    const index_t num_blocks_per_batch =
        __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
    const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);

    const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx)));
    const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_base_ptr_of_batch.GetBBasePtr(g_idx)));
    const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
    const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));

    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
                                                  p_b_grid + b_batch_offset,
                                                  p_b1_grid + b1_batch_offset,
                                                  p_c_grid + c_batch_offset,
                                                  p_shared,
                                                  a_element_op,
                                                  b_element_op,
                                                  acc_element_op,
                                                  b1_element_op,
                                                  c_element_op,
                                                  a_grid_desc_ak0_m_ak1,
                                                  b_grid_desc_bk0_n_bk1,
                                                  b1_grid_desc_bk0_n_bk1,
                                                  c_grid_desc_mblock_mperblock_nblock_nperblock,
                                                  block_2_ctile_map);
#else
    ignore = p_a_grid;
    ignore = p_b_grid;
    ignore = p_b1_grid;
    ignore = p_c_grid;
    ignore = a_element_op;
    ignore = b_element_op;
    ignore = acc_element_op;
    ignore = b1_element_op;
    ignore = c_element_op;
    ignore = a_grid_desc_ak0_m_ak1;
    ignore = b_grid_desc_bk0_n_bk1;
    ignore = b1_grid_desc_bk0_n_bk1;
    ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
    ignore = block_2_ctile_map;
    ignore = batch_count;
    ignore = compute_base_ptr_of_batch;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}

// Computes C = A * B0 * B1
//              ^^^^^^ (Acc0)
//         ^^^^^^^^^^^ (Acc1)
template <typename ALayout,
          typename BLayout, // B0Layout
          typename B1Layout,
          typename CLayout,
          typename ADataType,
          typename BDataType,
          typename B1DataType,
          typename CDataType,
          typename GemmAccDataType,
          typename CShuffleDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename AccElementwiseOperation,
          typename B1ElementwiseOperation,
          typename CElementwiseOperation,
          GemmSpecialization GemmSpec,
          index_t NumGemmKPrefetchStage,
          index_t BlockSize,
          index_t MPerBlock,
          index_t NPerBlock, // Gemm0NPerBlock
          index_t KPerBlock, // Gemm0KPerBlock
          index_t Gemm1NPerBlock,
          index_t Gemm1KPerBlock,
          index_t AK1,
          index_t BK1,
          index_t B1K1,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MXdlPerWave,
          index_t NXdlPerWave,
          index_t Gemm1NXdlPerWave,
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          bool ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          bool BBlockLdsExtraN,
          typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
          typename B1BlockTransferThreadClusterArrangeOrder,
          typename B1BlockTransferSrcAccessOrder,
          index_t B1BlockTransferSrcVectorDim,
          index_t B1BlockTransferSrcScalarPerVector,
          index_t B1BlockTransferDstScalarPerVector_BK1,
          bool B1BlockLdsExtraN,
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
          LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedGemmGemm_Xdl_CShuffle : public DeviceBatchedGemmGemm<ALayout,
                                                                         BLayout,
                                                                         B1Layout,
                                                                         CLayout,
                                                                         ADataType,
                                                                         BDataType,
                                                                         B1DataType,
                                                                         CDataType,
                                                                         AElementwiseOperation,
                                                                         BElementwiseOperation,
                                                                         AccElementwiseOperation,
                                                                         B1ElementwiseOperation,
                                                                         CElementwiseOperation>
{
    using DeviceOp = DeviceBatchedGemmGemm_Xdl_CShuffle;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
    {
        const auto a_grid_desc_mraw_kraw = [&]() {
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
            {
                return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
                                                    make_tuple(StrideA, I1));
            }
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
            {
                return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
                                                    make_tuple(I1, StrideA));
            }
        }();

        const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
        const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;

        const auto MPad = M - MRaw;
        const auto KPad = K - KRaw;

        if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // pad both M and K
            assert(K % AK1 == 0);

            const auto AK0 = K / AK1;

            const auto a_grid_desc_m_k =
                transform_tensor_descriptor(a_grid_desc_mraw_kraw,
                                            make_tuple(make_right_pad_transform(MRaw, MPad),
                                                       make_right_pad_transform(KRaw, KPad)),
                                            make_tuple(Sequence<0>{}, Sequence<1>{}),
                                            make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_grid_desc_m_k,
                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
                           make_pass_through_transform(M)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return a_grid_desc_ak0_m_ak1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
                          GemmSpec == GemmSpecialization::MNPadding)
        {
            // pad M, but not K
            assert(KRaw % AK1 == 0);

            const auto AK0 = KRaw / AK1;

            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_grid_desc_mraw_kraw,
                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
                           make_right_pad_transform(MRaw, MPad)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return a_grid_desc_ak0_m_ak1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
                          GemmSpec == GemmSpecialization::NKPadding)
        {
            // pad K, but not M
            assert(K % AK1 == 0);

            const auto AK0 = K / AK1;

            const auto a_grid_desc_m_k =
                transform_tensor_descriptor(a_grid_desc_mraw_kraw,
                                            make_tuple(make_pass_through_transform(MRaw),
                                                       make_right_pad_transform(KRaw, KPad)),
                                            make_tuple(Sequence<0>{}, Sequence<1>{}),
                                            make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_grid_desc_m_k,
                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
                           make_pass_through_transform(MRaw)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return a_grid_desc_ak0_m_ak1;
        }
        else
        {
            // not pad M or K
            assert(KRaw % AK1 == 0);

            const auto AK0 = KRaw / AK1;

            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_grid_desc_mraw_kraw,
                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
                           make_pass_through_transform(MRaw)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return a_grid_desc_ak0_m_ak1;
        }
    }
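    // Worked example (hypothetical sizes, added for clarity): with MRaw = 100,
    // KRaw = 70, MPerBlock = 32 and KPerBlock = 32, M = ceil(100/32)*32 = 128
    // and K = ceil(70/32)*32 = 96, so MPad = 28 and KPad = 26; the right-pad
    // transforms above extend the raw descriptor to these block-aligned extents.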
    static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
    {
        const auto b_grid_desc_nraw_kraw = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
                                                    make_tuple(I1, StrideB));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
                                                    make_tuple(StrideB, I1));
            }
        }();

        const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
        const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;

        const auto NPad = N - NRaw;
        const auto KPad = K - KRaw;

        if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // pad both N and K
            const auto BK0 = K / BK1;

            const auto b_grid_desc_n_k =
                transform_tensor_descriptor(b_grid_desc_nraw_kraw,
                                            make_tuple(make_right_pad_transform(NRaw, NPad),
                                                       make_right_pad_transform(KRaw, KPad)),
                                            make_tuple(Sequence<0>{}, Sequence<1>{}),
                                            make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b_grid_desc_n_k,
                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
                           make_pass_through_transform(N)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return b_grid_desc_bk0_n_bk1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
                          GemmSpec == GemmSpecialization::MNPadding)
        {
            // pad N, but not K
            const auto BK0 = KRaw / BK1;

            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b_grid_desc_nraw_kraw,
                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
                           make_right_pad_transform(NRaw, NPad)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return b_grid_desc_bk0_n_bk1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
                          GemmSpec == GemmSpecialization::MKPadding)
        {
            // pad K, but not N
            const auto BK0 = K / BK1;

            const auto b_grid_desc_n_k =
                transform_tensor_descriptor(b_grid_desc_nraw_kraw,
                                            make_tuple(make_pass_through_transform(NRaw),
                                                       make_right_pad_transform(KRaw, KPad)),
                                            make_tuple(Sequence<0>{}, Sequence<1>{}),
                                            make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b_grid_desc_n_k,
                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
                           make_pass_through_transform(NRaw)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return b_grid_desc_bk0_n_bk1;
        }
        else
        {
            // not pad N or K
            const auto BK0 = KRaw / BK1;

            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b_grid_desc_nraw_kraw,
                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
                           make_pass_through_transform(NRaw)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return b_grid_desc_bk0_n_bk1;
        }
    }
    // Args: Gemm1KRaw, Gemm1NRaw, StrideB1
    static auto MakeB1GridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
    {
        const auto b1_grid_desc_nraw_kraw = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, B1Layout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
                                                    make_tuple(I1, StrideB));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, B1Layout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
                                                    make_tuple(StrideB, I1));
            }
        }();

        const auto N = math::integer_divide_ceil(NRaw, Gemm1NPerBlock) * Gemm1NPerBlock;
        const auto K = math::integer_divide_ceil(KRaw, Gemm1KPerBlock) * Gemm1KPerBlock;

        const auto NPad = N - NRaw;
        const auto KPad = K - KRaw;

        // TODO: implement finer-grained padding
        if constexpr(GemmSpec == GemmSpecialization::Default)
        {
            const auto B1K0 = KRaw / B1K1;

            const auto b1_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b1_grid_desc_nraw_kraw,
                make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
                           make_pass_through_transform(NRaw)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return b1_grid_desc_bk0_n_bk1;
        }
        else
        {
            // pad both B1N and B1K
            const auto B1K0 = K / B1K1;

            const auto b1_grid_desc_n_k =
                transform_tensor_descriptor(b1_grid_desc_nraw_kraw,
                                            make_tuple(make_right_pad_transform(NRaw, NPad),
                                                       make_right_pad_transform(KRaw, KPad)),
                                            make_tuple(Sequence<0>{}, Sequence<1>{}),
                                            make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto b1_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b1_grid_desc_n_k,
                make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
                           make_pass_through_transform(N)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return b1_grid_desc_bk0_n_bk1;
        }
    }
    static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
    {
        const auto c_grid_desc_mraw_nraw = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
                                                    make_tuple(StrideC, I1));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
                                                    make_tuple(I1, StrideC));
            }
        }();

        const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
        const auto N = math::integer_divide_ceil(NRaw, Gemm1NPerBlock) * Gemm1NPerBlock;

        const auto MPad = M - MRaw;
        const auto NPad = N - NRaw;

        if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // pad M and N
            return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
                                               make_tuple(make_right_pad_transform(MRaw, MPad),
                                                          make_right_pad_transform(NRaw, NPad)),
                                               make_tuple(Sequence<0>{}, Sequence<1>{}),
                                               make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
                          GemmSpec == GemmSpecialization::MKPadding)
        {
            // pad M, but not N
            return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
                                               make_tuple(make_right_pad_transform(MRaw, MPad),
                                                          make_pass_through_transform(NRaw)),
                                               make_tuple(Sequence<0>{}, Sequence<1>{}),
                                               make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
                          GemmSpec == GemmSpecialization::NKPadding)
        {
            // pad N, but not M
            return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
                                               make_tuple(make_pass_through_transform(MRaw),
                                                          make_right_pad_transform(NRaw, NPad)),
                                               make_tuple(Sequence<0>{}, Sequence<1>{}),
                                               make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else
        {
            // not pad M or N
            return c_grid_desc_mraw_nraw;
        }
    }
struct
ComputeBasePtrOfStridedBatch
{
ComputeBasePtrOfStridedBatch
(
index_t
BatchStrideA
,
index_t
BatchStrideB
,
index_t
BatchStrideB1
,
index_t
BatchStrideC
)
:
BatchStrideA_
(
BatchStrideA
),
BatchStrideB_
(
BatchStrideB
),
BatchStrideB1_
(
BatchStrideB1
),
BatchStrideC_
(
BatchStrideC
)
{
}
__host__
__device__
constexpr
long_index_t
GetABasePtr
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideA_
);
}
__host__
__device__
constexpr
long_index_t
GetBBasePtr
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideB_
);
}
__host__
__device__
constexpr
long_index_t
GetB1BasePtr
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideB1_
);
}
__host__
__device__
constexpr
long_index_t
GetCBasePtr
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideC_
);
}
private:
index_t
BatchStrideA_
;
index_t
BatchStrideB_
;
index_t
BatchStrideB1_
;
index_t
BatchStrideC_
;
};
using
AGridDesc_AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
(
1
,
1
,
1
));
using
BGridDesc_BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
(
1
,
1
,
1
));
using
B1GridDesc_BK0_N_BK1
=
decltype
(
MakeB1GridDescriptor_BK0_N_BK1
(
1
,
1
,
1
));
using
CGridDesc_M_N
=
decltype
(
MakeCGridDescriptor_M_N
(
1
,
1
,
1
));
// GridwiseGemm
using
GridwiseGemm
=
GridwiseBatchedGemmGemm_Xdl_CShuffle
<
ADataType
,
// TODO: distinguish A/B datatype
GemmAccDataType
,
CShuffleDataType
,
CDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
AccElementwiseOperation
,
B1ElementwiseOperation
,
CElementwiseOperation
,
InMemoryDataOperationEnum
::
Set
,
AGridDesc_AK0_M_AK1
,
BGridDesc_BK0_N_BK1
,
B1GridDesc_BK0_N_BK1
,
CGridDesc_M_N
,
NumGemmKPrefetchStage
,
BlockSize
,
MPerBlock
,
NPerBlock
,
KPerBlock
,
Gemm1NPerBlock
,
Gemm1KPerBlock
,
AK1
,
BK1
,
B1K1
,
MPerXDL
,
NPerXDL
,
MXdlPerWave
,
NXdlPerWave
,
Gemm1NXdlPerWave
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_AK1
,
true
,
ABlockLdsExtraM
,
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferDstScalarPerVector_BK1
,
true
,
BBlockLdsExtraN
,
B1BlockTransferThreadClusterLengths_BK0_N_BK1
,
B1BlockTransferThreadClusterArrangeOrder
,
B1BlockTransferSrcAccessOrder
,
B1BlockTransferSrcVectorDim
,
B1BlockTransferSrcScalarPerVector
,
B1BlockTransferDstScalarPerVector_BK1
,
false
,
B1BlockLdsExtraN
,
CShuffleMXdlPerWavePerShuffle
,
CShuffleNXdlPerWavePerShuffle
,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
LoopSched
>
;
// Argument
struct
Argument
:
public
BaseArgument
{
Argument
(
const
ADataType
*
p_a_grid
,
const
BDataType
*
p_b_grid
,
const
B1DataType
*
p_b1_grid
,
CDataType
*
p_c_grid
,
index_t
MRaw
,
index_t
NRaw
,
index_t
KRaw
,
index_t
Gemm1NRaw
,
// = ORaw
index_t
Batch
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideB1
,
index_t
StrideC
,
index_t
BatchStrideA
,
index_t
BatchStrideB
,
index_t
BatchStrideB1
,
index_t
BatchStrideC
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
AccElementwiseOperation
acc_element_op
,
B1ElementwiseOperation
b1_element_op
,
CElementwiseOperation
c_element_op
)
:
p_a_grid_
{
p_a_grid
},
p_b_grid_
{
p_b_grid
},
p_b1_grid_
{
p_b1_grid
},
p_c_grid_
{
p_c_grid
},
a_grid_desc_ak0_m_ak1_
{
DeviceOp
::
MakeAGridDescriptor_AK0_M_AK1
(
MRaw
,
KRaw
,
StrideA
)},
b_grid_desc_bk0_n_bk1_
{
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
KRaw
,
NRaw
,
StrideB
)},
b1_grid_desc_bk0_n_bk1_
{
DeviceOp
::
MakeB1GridDescriptor_BK0_N_BK1
(
NRaw
,
Gemm1NRaw
,
StrideB1
)},
c_grid_desc_m_n_
{
DeviceOp
::
MakeCGridDescriptor_M_N
(
MRaw
,
Gemm1NRaw
,
StrideC
)},
c_grid_desc_mblock_mperblock_nblock_nperblock_
{},
block_2_ctile_map_
{
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n_
)},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
acc_element_op_
{
acc_element_op
},
b1_element_op_
{
b1_element_op
},
c_element_op_
{
c_element_op
},
batch_count_
(
Batch
),
compute_base_ptr_of_batch_
{
BatchStrideA
,
BatchStrideB
,
BatchStrideB1
,
BatchStrideC
}
{
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_ak0_m_ak1_
,
b_grid_desc_bk0_n_bk1_
,
b1_grid_desc_bk0_n_bk1_
,
c_grid_desc_m_n_
,
block_2_ctile_map_
))
{
c_grid_desc_mblock_mperblock_nblock_nperblock_
=
GridwiseGemm
::
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n_
);
}
}
// private:
const
ADataType
*
p_a_grid_
;
const
BDataType
*
p_b_grid_
;
const
B1DataType
*
p_b1_grid_
;
CDataType
*
p_c_grid_
;
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
B1GridDesc_BK0_N_BK1
b1_grid_desc_bk0_n_bk1_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_
;
typename
GridwiseGemm
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
AccElementwiseOperation
acc_element_op_
;
B1ElementwiseOperation
b1_element_op_
;
CElementwiseOperation
c_element_op_
;
index_t
batch_count_
;
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch_
;
};
// Invoker
struct
Invoker
:
public
BaseInvoker
{
using
Argument
=
DeviceOp
::
Argument
;
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
b1_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_m_n_
,
arg
.
block_2_ctile_map_
))
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemm has invalid setting"
);
}
const
index_t
grid_size
=
arg
.
block_2_ctile_map_
.
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
)
*
arg
.
batch_count_
;
// Gemm0_K
const
auto
K
=
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I0
)
*
arg
.
a_grid_desc_ak0_m_ak1_
.
GetLength
(
I2
);
float
ave_time
=
0
;
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
)
{
const
auto
kernel
=
kernel_gemm_gemm_xdl_cshuffle_v1
<
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
AccElementwiseOperation
,
B1ElementwiseOperation
,
CElementwiseOperation
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
B1GridDesc_BK0_N_BK1
,
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
GridwiseGemm
::
DefaultBlock2CTileMap
,
ComputeBasePtrOfStridedBatch
,
has_main_k_block_loop_
>
;
return
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_b1_grid_
,
arg
.
p_c_grid_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
acc_element_op_
,
arg
.
b1_element_op_
,
arg
.
c_element_op_
,
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
b1_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock_
,
arg
.
block_2_ctile_map_
,
arg
.
batch_count_
,
arg
.
compute_base_ptr_of_batch_
);
};
// Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
// to concern Gemm0's loop
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
true
>
{});
}
else
{
ave_time
=
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
}
return
ave_time
;
}
// polymorphic
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
}
};
static
constexpr
bool
IsValidCompilationParameter
()
{
// TODO: properly implement this check
return
true
;
}
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
))
{
return
false
;
}
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_ak0_m_ak1_
,
arg
.
b_grid_desc_bk0_n_bk1_
,
arg
.
b1_grid_desc_bk0_n_bk1_
,
arg
.
c_grid_desc_m_n_
,
arg
.
block_2_ctile_map_
);
}
// polymorphic
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
static
auto
MakeArgument
(
const
ADataType
*
p_a
,
const
BDataType
*
p_b
,
const
B1DataType
*
p_b1
,
CDataType
*
p_c
,
index_t
MRaw
,
index_t
NRaw
,
index_t
KRaw
,
index_t
Gemm1NRaw
,
index_t
Batch
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideB1
,
index_t
StrideC
,
index_t
BatchStrideA
,
index_t
BatchStrideB
,
index_t
BatchStrideB1
,
index_t
BatchStrideC
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
AccElementwiseOperation
acc_element_op
,
B1ElementwiseOperation
b1_element_op
,
CElementwiseOperation
c_element_op
)
{
return
Argument
{
p_a
,
p_b
,
p_b1
,
p_c
,
MRaw
,
NRaw
,
KRaw
,
Gemm1NRaw
,
Batch
,
StrideA
,
StrideB
,
StrideB1
,
StrideC
,
BatchStrideA
,
BatchStrideB
,
BatchStrideB1
,
BatchStrideC
,
a_element_op
,
b_element_op
,
acc_element_op
,
b1_element_op
,
c_element_op
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
// polymorphic
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
const
void
*
p_b1
,
void
*
p_c
,
index_t
MRaw
,
index_t
NRaw
,
index_t
KRaw
,
index_t
Gemm1NRaw
,
index_t
Batch
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideB1
,
index_t
StrideC
,
index_t
BatchStrideA
,
index_t
BatchStrideB
,
index_t
BatchStrideB1
,
index_t
BatchStrideC
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
AccElementwiseOperation
acc_element_op
,
B1ElementwiseOperation
b1_element_op
,
CElementwiseOperation
c_element_op
)
override
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
static_cast
<
const
BDataType
*>
(
p_b
),
static_cast
<
const
B1DataType
*>
(
p_b1
),
static_cast
<
CDataType
*>
(
p_c
),
MRaw
,
NRaw
,
KRaw
,
Gemm1NRaw
,
Batch
,
StrideA
,
StrideB
,
StrideB1
,
StrideC
,
BatchStrideA
,
BatchStrideB
,
BatchStrideB1
,
BatchStrideC
,
a_element_op
,
b_element_op
,
acc_element_op
,
b1_element_op
,
c_element_op
);
}
// polymorphic
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
// polymorphic
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"DeviceBatchedGemmGemm_Xdl_CShuffle"
<<
"<"
<<
BlockSize
<<
", "
<<
MPerBlock
<<
", "
<<
NPerBlock
<<
", "
<<
KPerBlock
<<
", "
<<
AK1
<<
", "
<<
BK1
<<
", "
<<
MPerBlock
<<
", "
<<
Gemm1NPerBlock
<<
", "
<<
Gemm1KPerBlock
<<
", "
<<
B1K1
<<
">"
;
// clang-format on
return
str
.
str
();
}
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
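
A minimal host-side sketch of how a fully specialized instance of the operator above is typically driven through the type-erased Argument/Invoker interface. DeviceOpInstance, the device pointers, and the problem sizes are illustrative placeholders (assumptions), not part of this commit:

// Hypothetical usage sketch; PassThrough is the no-op elementwise functor.
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

auto op  = DeviceOpInstance{}; // stands for some DeviceBatchedGemmGemm_Xdl_CShuffle<...>
auto arg = op.MakeArgumentPointer(p_a, p_b, p_b1, p_c,
                                  M, N, K, O /* Gemm1N */, BatchCount,
                                  StrideA, StrideB, StrideB1, StrideC,
                                  BatchStrideA, BatchStrideB, BatchStrideB1, BatchStrideC,
                                  PassThrough{}, PassThrough{}, PassThrough{},
                                  PassThrough{}, PassThrough{});
auto invoker = op.MakeInvokerPointer();

if(!op.IsSupportedArgument(arg.get()))
    throw std::runtime_error("this instance does not support the problem");

const float ave_time = invoker->Run(arg.get(), StreamConfig{nullptr, true});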
include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <vector>

#include "device_base.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename ALayout,
          typename B0Layout,
          typename B1Layout,
          typename CLayout,
          typename ADataType,
          typename B0DataType,
          typename B1DataType,
          typename CDataType,
          typename AElementwiseOperation,
          typename B0ElementwiseOperation,
          typename Acc0ElementwiseOperation,
          typename B1ElementwiseOperation,
          typename CElementwiseOperation>
struct DeviceBatchedGemmSoftmaxGemm : public BaseOperator
{
    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_a,
                        const void* p_b0,
                        const void* p_b1,
                        void* p_c,
                        ck::index_t M,
                        ck::index_t N,
                        ck::index_t K,
                        ck::index_t O,
                        ck::index_t Batch,
                        ck::index_t StrideA,
                        ck::index_t StrideB0,
                        ck::index_t StrideB1,
                        ck::index_t StrideC,
                        ck::index_t BatchStrideA,
                        ck::index_t BatchStrideB0,
                        ck::index_t BatchStrideB1,
                        ck::index_t BatchStrideC,
                        AElementwiseOperation a_element_op,
                        B0ElementwiseOperation b0_element_op,
                        Acc0ElementwiseOperation acc0_element_op,
                        B1ElementwiseOperation b1_element_op,
                        CElementwiseOperation c_element_op) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

template <typename ALayout,
          typename B0Layout,
          typename B1Layout,
          typename CLayout,
          typename ADataType,
          typename B0DataType,
          typename B1DataType,
          typename CDataType,
          typename AElementwiseOperation,
          typename B0ElementwiseOperation,
          typename Acc0ElementwiseOperation,
          typename B1ElementwiseOperation,
          typename CElementwiseOperation>
using DeviceBatchedGemmSoftmaxGemmPtr =
    std::unique_ptr<DeviceBatchedGemmSoftmaxGemm<ALayout,
                                                 B0Layout,
                                                 B1Layout,
                                                 CLayout,
                                                 ADataType,
                                                 B0DataType,
                                                 B1DataType,
                                                 CDataType,
                                                 AElementwiseOperation,
                                                 B0ElementwiseOperation,
                                                 Acc0ElementwiseOperation,
                                                 B1ElementwiseOperation,
                                                 CElementwiseOperation>>;

} // namespace device
} // namespace tensor_operation
} // namespace ck
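
Reading this interface together with the operator name: per batch the operator computes a first GEMM, applies the Acc0 elementwise op and a row-wise softmax to its accumulator, then feeds the result into a second GEMM, the attention-style pattern. A hedged sketch of the data flow, inferred from the name and parameter roles rather than stated in this header (the exact op placement lives in the gridwise kernel):

// acc0 = acc0_element_op(A * B0);  // M x N; typically carries a 1/sqrt(K) scale
// p    = softmax(acc0);            // row-wise over N
// C    = c_element_op(p * B1);     // M x O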
include/ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <sstream>

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename GridwiseGemm,
          typename FloatAB,
          typename FloatC,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename AccElementwiseOperation,
          typename B1ElementwiseOperation,
          typename CElementwiseOperation,
          typename AGridDesc_AK0_M_AK1,
          typename BGridDesc_BK0_N_BK1,
          typename B1GridDesc_BK0_N_BK1,
          typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
          typename Block2CTileMap,
          typename ComputeBasePtrOfStridedBatch,
          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1(
            const FloatAB* __restrict__ p_a_grid,
            const FloatAB* __restrict__ p_b_grid,
            const FloatAB* __restrict__ p_b1_grid,
            FloatC* __restrict__ p_c_grid,
            const AElementwiseOperation a_element_op,
            const BElementwiseOperation b_element_op,
            const AccElementwiseOperation acc_element_op,
            const B1ElementwiseOperation b1_element_op,
            const CElementwiseOperation c_element_op,
            const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
            const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
            const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1,
            const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                c_grid_desc_mblock_mperblock_nblock_nperblock,
            const Block2CTileMap block_2_ctile_map,
            const index_t batch_count,
            const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    const index_t num_blocks_per_batch =
        __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
    const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);

    const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_base_ptr_of_batch.GetABasePtr(g_idx)));
    const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_base_ptr_of_batch.GetBBasePtr(g_idx)));
    const long_index_t b1_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_base_ptr_of_batch.GetB1BasePtr(g_idx)));
    const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(compute_base_ptr_of_batch.GetCBasePtr(g_idx)));

    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
                                                  p_b_grid + b_batch_offset,
                                                  p_b1_grid + b1_batch_offset,
                                                  p_c_grid + c_batch_offset,
                                                  p_shared,
                                                  a_element_op,
                                                  b_element_op,
                                                  acc_element_op,
                                                  b1_element_op,
                                                  c_element_op,
                                                  a_grid_desc_ak0_m_ak1,
                                                  b_grid_desc_bk0_n_bk1,
                                                  b1_grid_desc_bk0_n_bk1,
                                                  c_grid_desc_mblock_mperblock_nblock_nperblock,
                                                  block_2_ctile_map);
#else
    ignore = p_a_grid;
    ignore = p_b_grid;
    ignore = p_b1_grid;
    ignore = p_c_grid;
    ignore = a_element_op;
    ignore = b_element_op;
    ignore = acc_element_op;
    ignore = b1_element_op;
    ignore = c_element_op;
    ignore = a_grid_desc_ak0_m_ak1;
    ignore = b_grid_desc_bk0_n_bk1;
    ignore = b1_grid_desc_bk0_n_bk1;
    ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
    ignore = block_2_ctile_map;
    ignore = batch_count;
    ignore = compute_base_ptr_of_batch;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}

// Computes C = A * B0 * B1
//              ^^^^^^ (Acc0)
//              ^^^^^^^^^^^ (Acc1)
template <typename ALayout,
          typename BLayout, // B0Layout
          typename B1Layout,
          typename CLayout,
          typename ADataType,
          typename BDataType,
          typename B1DataType,
          typename CDataType,
          typename GemmAccDataType,
          typename CShuffleDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename AccElementwiseOperation,
          typename B1ElementwiseOperation,
          typename CElementwiseOperation,
          GemmSpecialization GemmSpec,
          index_t NumGemmKPrefetchStage,
          index_t BlockSize,
          index_t MPerBlock,
          index_t NPerBlock, // Gemm0NPerBlock
          index_t KPerBlock, // Gemm0KPerBlock
          index_t Gemm1NPerBlock,
          index_t Gemm1KPerBlock,
          index_t AK1,
          index_t BK1,
          index_t B1K1,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MXdlPerWave,
          index_t NXdlPerWave,
          index_t Gemm1NXdlPerWave,
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          bool ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          bool BBlockLdsExtraN,
          typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
          typename B1BlockTransferThreadClusterArrangeOrder,
          typename B1BlockTransferSrcAccessOrder,
          index_t B1BlockTransferSrcVectorDim,
          index_t B1BlockTransferSrcScalarPerVector,
          index_t B1BlockTransferDstScalarPerVector_BK1,
          bool B1BlockLdsExtraN,
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
          LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
    : public DeviceBatchedGemmSoftmaxGemm<ALayout,
                                          BLayout,
                                          B1Layout,
                                          CLayout,
                                          ADataType,
                                          BDataType,
                                          B1DataType,
                                          CDataType,
                                          AElementwiseOperation,
                                          BElementwiseOperation,
                                          AccElementwiseOperation,
                                          B1ElementwiseOperation,
                                          CElementwiseOperation>
{
    using DeviceOp = DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};

    static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
    {
        const auto a_grid_desc_mraw_kraw = [&]() {
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
            {
                return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
                                                    make_tuple(StrideA, I1));
            }
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
            {
                return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
                                                    make_tuple(I1, StrideA));
            }
        }();

        const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
        const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;

        const auto MPad = M - MRaw;
        const auto KPad = K - KRaw;

        if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // pad both M and K
            assert(K % AK1 == 0);

            const auto AK0 = K / AK1;

            const auto a_grid_desc_m_k = transform_tensor_descriptor(
                a_grid_desc_mraw_kraw,
                make_tuple(make_right_pad_transform(MRaw, MPad),
                           make_right_pad_transform(KRaw, KPad)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_grid_desc_m_k,
                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
                           make_pass_through_transform(M)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return a_grid_desc_ak0_m_ak1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
                          GemmSpec == GemmSpecialization::MNPadding)
        {
            // pad M, but not K
            assert(KRaw % AK1 == 0);

            const auto AK0 = KRaw / AK1;

            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_grid_desc_mraw_kraw,
                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
                           make_right_pad_transform(MRaw, MPad)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return a_grid_desc_ak0_m_ak1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
                          GemmSpec == GemmSpecialization::NKPadding)
        {
            // pad K, but not M
            assert(K % AK1 == 0);

            const auto AK0 = K / AK1;

            const auto a_grid_desc_m_k = transform_tensor_descriptor(
                a_grid_desc_mraw_kraw,
                make_tuple(make_pass_through_transform(MRaw),
                           make_right_pad_transform(KRaw, KPad)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_grid_desc_m_k,
                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
                           make_pass_through_transform(MRaw)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return a_grid_desc_ak0_m_ak1;
        }
        else
        {
            // not pad M or K
            assert(KRaw % AK1 == 0);

            const auto AK0 = KRaw / AK1;

            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_grid_desc_mraw_kraw,
                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
                           make_pass_through_transform(MRaw)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return a_grid_desc_ak0_m_ak1;
        }
    }

    static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
    {
        const auto b_grid_desc_nraw_kraw = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
                                                    make_tuple(I1, StrideB));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
                                                    make_tuple(StrideB, I1));
            }
        }();

        const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
        const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;

        const auto NPad = N - NRaw;
        const auto KPad = K - KRaw;

        if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // pad both N and K
            const auto BK0 = K / BK1;

            const auto b_grid_desc_n_k = transform_tensor_descriptor(
                b_grid_desc_nraw_kraw,
                make_tuple(make_right_pad_transform(NRaw, NPad),
                           make_right_pad_transform(KRaw, KPad)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b_grid_desc_n_k,
                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
                           make_pass_through_transform(N)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return b_grid_desc_bk0_n_bk1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
                          GemmSpec == GemmSpecialization::MNPadding)
        {
            // pad N, but not K
            const auto BK0 = KRaw / BK1;

            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b_grid_desc_nraw_kraw,
                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
                           make_right_pad_transform(NRaw, NPad)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return b_grid_desc_bk0_n_bk1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
                          GemmSpec == GemmSpecialization::MKPadding)
        {
            // pad K, but not N
            const auto BK0 = K / BK1;

            const auto b_grid_desc_n_k = transform_tensor_descriptor(
                b_grid_desc_nraw_kraw,
                make_tuple(make_pass_through_transform(NRaw),
                           make_right_pad_transform(KRaw, KPad)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b_grid_desc_n_k,
                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
                           make_pass_through_transform(NRaw)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return b_grid_desc_bk0_n_bk1;
        }
        else
        {
            // not pad N or K
            const auto BK0 = KRaw / BK1;

            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b_grid_desc_nraw_kraw,
                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
                           make_pass_through_transform(NRaw)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return b_grid_desc_bk0_n_bk1;
        }
    }

    // Args: Gemm1KRaw, Gemm1NRaw, StrideB1
    static auto MakeB1GridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
    {
        const auto b1_grid_desc_nraw_kraw = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, B1Layout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
                                                    make_tuple(I1, StrideB));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, B1Layout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
                                                    make_tuple(StrideB, I1));
            }
        }();

        const auto N = math::integer_divide_ceil(NRaw, Gemm1NPerBlock) * Gemm1NPerBlock;
        const auto K = math::integer_divide_ceil(KRaw, Gemm1KPerBlock) * Gemm1KPerBlock;

        const auto NPad = N - NRaw;
        const auto KPad = K - KRaw;

        // TODO: implement finer-grained padding
        if constexpr(GemmSpec == GemmSpecialization::Default)
        {
            const auto B1K0 = KRaw / B1K1;

            const auto b1_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b1_grid_desc_nraw_kraw,
                make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
                           make_pass_through_transform(NRaw)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return b1_grid_desc_bk0_n_bk1;
        }
        else
        {
            // pad both B1N and B1K
            const auto B1K0 = K / B1K1;

            const auto b1_grid_desc_n_k = transform_tensor_descriptor(
                b1_grid_desc_nraw_kraw,
                make_tuple(make_right_pad_transform(NRaw, NPad),
                           make_right_pad_transform(KRaw, KPad)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto b1_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b1_grid_desc_n_k,
                make_tuple(make_unmerge_transform(make_tuple(B1K0, B1K1)),
                           make_pass_through_transform(N)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return b1_grid_desc_bk0_n_bk1;
        }
    }

    static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
    {
        const auto c_grid_desc_mraw_nraw = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
                                                    make_tuple(StrideC, I1));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
                                                    make_tuple(I1, StrideC));
            }
        }();

        const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
        const auto N = math::integer_divide_ceil(NRaw, Gemm1NPerBlock) * Gemm1NPerBlock;

        const auto MPad = M - MRaw;
        const auto NPad = N - NRaw;

        if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // pad M and N
            return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
                                               make_tuple(make_right_pad_transform(MRaw, MPad),
                                                          make_right_pad_transform(NRaw, NPad)),
                                               make_tuple(Sequence<0>{}, Sequence<1>{}),
                                               make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
                          GemmSpec == GemmSpecialization::MKPadding)
        {
            // pad M, but not N
            return transform_tensor_descriptor(
                c_grid_desc_mraw_nraw,
                make_tuple(make_right_pad_transform(MRaw, MPad), make_pass_through_transform(NRaw)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
                          GemmSpec == GemmSpecialization::NKPadding)
        {
            // pad N, but not M
            return transform_tensor_descriptor(
                c_grid_desc_mraw_nraw,
                make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(NRaw, NPad)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else
        {
            // not pad M or N
            return c_grid_desc_mraw_nraw;
        }
    }

    struct ComputeBasePtrOfStridedBatch
    {
        ComputeBasePtrOfStridedBatch(index_t BatchStrideA,
                                     index_t BatchStrideB,
                                     index_t BatchStrideB1,
                                     index_t BatchStrideC)
            : BatchStrideA_(BatchStrideA),
              BatchStrideB_(BatchStrideB),
              BatchStrideB1_(BatchStrideB1),
              BatchStrideC_(BatchStrideC)
        {
        }

        __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
        {
            return g_idx * static_cast<long_index_t>(BatchStrideA_);
        }

        __host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const
        {
            return g_idx * static_cast<long_index_t>(BatchStrideB_);
        }

        __host__ __device__ constexpr long_index_t GetB1BasePtr(index_t g_idx) const
        {
            return g_idx * static_cast<long_index_t>(BatchStrideB1_);
        }

        __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
        {
            return g_idx * static_cast<long_index_t>(BatchStrideC_);
        }

        private:
        index_t BatchStrideA_;
        index_t BatchStrideB_;
        index_t BatchStrideB1_;
        index_t BatchStrideC_;
    };

    using AGridDesc_AK0_M_AK1  = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
    using BGridDesc_BK0_N_BK1  = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
    using B1GridDesc_BK0_N_BK1 = decltype(MakeB1GridDescriptor_BK0_N_BK1(1, 1, 1));
    using CGridDesc_M_N        = decltype(MakeCGridDescriptor_M_N(1, 1, 1));

    // GridwiseGemm
    using GridwiseGemm = GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle<
        ADataType, // TODO: distinguish A/B datatype
        GemmAccDataType,
        CShuffleDataType,
        CDataType,
        AElementwiseOperation,
        BElementwiseOperation,
        AccElementwiseOperation,
        B1ElementwiseOperation,
        CElementwiseOperation,
        InMemoryDataOperationEnum::Set,
        AGridDesc_AK0_M_AK1,
        BGridDesc_BK0_N_BK1,
        B1GridDesc_BK0_N_BK1,
        CGridDesc_M_N,
        NumGemmKPrefetchStage,
        BlockSize,
        MPerBlock,
        NPerBlock,
        KPerBlock,
        Gemm1NPerBlock,
        Gemm1KPerBlock,
        AK1,
        BK1,
        B1K1,
        MPerXDL,
        NPerXDL,
        MXdlPerWave,
        NXdlPerWave,
        Gemm1NXdlPerWave,
        ABlockTransferThreadClusterLengths_AK0_M_AK1,
        ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder,
        ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_AK1,
        true,
        ABlockLdsExtraM,
        BBlockTransferThreadClusterLengths_BK0_N_BK1,
        BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder,
        BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_BK1,
        true,
        BBlockLdsExtraN,
        B1BlockTransferThreadClusterLengths_BK0_N_BK1,
        B1BlockTransferThreadClusterArrangeOrder,
        B1BlockTransferSrcAccessOrder,
        B1BlockTransferSrcVectorDim,
        B1BlockTransferSrcScalarPerVector,
        B1BlockTransferDstScalarPerVector_BK1,
        false,
        B1BlockLdsExtraN,
        CShuffleMXdlPerWavePerShuffle,
        CShuffleNXdlPerWavePerShuffle,
        CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        CShuffleBlockTransferScalarPerVector_NPerBlock,
        LoopSched>;

    // Argument
    struct Argument : public BaseArgument
    {
        Argument(const ADataType* p_a_grid,
                 const BDataType* p_b_grid,
                 const B1DataType* p_b1_grid,
                 CDataType* p_c_grid,
                 index_t MRaw,
                 index_t NRaw,
                 index_t KRaw,
                 index_t Gemm1NRaw, // = ORaw
                 index_t Batch,
                 index_t StrideA,
                 index_t StrideB,
                 index_t StrideB1,
                 index_t StrideC,
                 index_t BatchStrideA,
                 index_t BatchStrideB,
                 index_t BatchStrideB1,
                 index_t BatchStrideC,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 AccElementwiseOperation acc_element_op,
                 B1ElementwiseOperation b1_element_op,
                 CElementwiseOperation c_element_op)
            : p_a_grid_{p_a_grid},
              p_b_grid_{p_b_grid},
              p_b1_grid_{p_b1_grid},
              p_c_grid_{p_c_grid},
              a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
              b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
              b1_grid_desc_bk0_n_bk1_{
                  DeviceOp::MakeB1GridDescriptor_BK0_N_BK1(NRaw, Gemm1NRaw, StrideB1)},
              c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, Gemm1NRaw, StrideC)},
              c_grid_desc_mblock_mperblock_nblock_nperblock_{},
              block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              acc_element_op_{acc_element_op},
              b1_element_op_{b1_element_op},
              c_element_op_{c_element_op},
              batch_count_(Batch),
              compute_base_ptr_of_batch_{BatchStrideA, BatchStrideB, BatchStrideB1, BatchStrideC}
        {
            if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
                                           b_grid_desc_bk0_n_bk1_,
                                           b1_grid_desc_bk0_n_bk1_,
                                           c_grid_desc_m_n_,
                                           block_2_ctile_map_))
            {
                c_grid_desc_mblock_mperblock_nblock_nperblock_ =
                    GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                        c_grid_desc_m_n_);
            }
        }

        // private:
        const ADataType* p_a_grid_;
        const BDataType* p_b_grid_;
        const B1DataType* p_b1_grid_;
        CDataType* p_c_grid_;
        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
        BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
        B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
        CGridDesc_M_N c_grid_desc_m_n_;
        typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
            c_grid_desc_mblock_mperblock_nblock_nperblock_;
        typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        AccElementwiseOperation acc_element_op_;
        B1ElementwiseOperation b1_element_op_;
        CElementwiseOperation c_element_op_;
        index_t batch_count_;
        ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
    };

    // Invoker
    struct Invoker : public BaseInvoker
    {
        using Argument = DeviceOp::Argument;

        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
                                            arg.b_grid_desc_bk0_n_bk1_,
                                            arg.b1_grid_desc_bk0_n_bk1_,
                                            arg.c_grid_desc_m_n_,
                                            arg.block_2_ctile_map_))
            {
                throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
            }

            const index_t grid_size =
                arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.batch_count_;

            // Gemm0_K
            const auto K = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) *
                           arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);

            float ave_time = 0;

            auto launch_kernel = [&](auto has_main_k_block_loop_) {
                const auto kernel = kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1<
                    GridwiseGemm,
                    ADataType, // TODO: distinguish A/B datatype
                    CDataType,
                    AElementwiseOperation,
                    BElementwiseOperation,
                    AccElementwiseOperation,
                    B1ElementwiseOperation,
                    CElementwiseOperation,
                    DeviceOp::AGridDesc_AK0_M_AK1,
                    DeviceOp::BGridDesc_BK0_N_BK1,
                    DeviceOp::B1GridDesc_BK0_N_BK1,
                    typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                    typename GridwiseGemm::DefaultBlock2CTileMap,
                    ComputeBasePtrOfStridedBatch,
                    has_main_k_block_loop_>;

                return launch_and_time_kernel(stream_config,
                                              kernel,
                                              dim3(grid_size),
                                              dim3(BlockSize),
                                              0,
                                              arg.p_a_grid_,
                                              arg.p_b_grid_,
                                              arg.p_b1_grid_,
                                              arg.p_c_grid_,
                                              arg.a_element_op_,
                                              arg.b_element_op_,
                                              arg.acc_element_op_,
                                              arg.b1_element_op_,
                                              arg.c_element_op_,
                                              arg.a_grid_desc_ak0_m_ak1_,
                                              arg.b_grid_desc_bk0_n_bk1_,
                                              arg.b1_grid_desc_bk0_n_bk1_,
                                              arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
                                              arg.block_2_ctile_map_,
                                              arg.batch_count_,
                                              arg.compute_base_ptr_of_batch_);
            };

            // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
            // to concern Gemm0's loop
            if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
            {
                ave_time = launch_kernel(integral_constant<bool, true>{});
            }
            else
            {
                ave_time = launch_kernel(integral_constant<bool, false>{});
            }

            return ave_time;
        }

        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    static bool IsSupportedArgument(const Argument& arg)
    {
        if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
        {
            return false;
        }

        return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
                                           arg.b_grid_desc_bk0_n_bk1_,
                                           arg.b1_grid_desc_bk0_n_bk1_,
                                           arg.c_grid_desc_m_n_,
                                           arg.block_2_ctile_map_);
    }

    // polymorphic
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }

    static auto MakeArgument(const ADataType* p_a,
                             const BDataType* p_b,
                             const B1DataType* p_b1,
                             CDataType* p_c,
                             index_t MRaw,
                             index_t NRaw,
                             index_t KRaw,
                             index_t Gemm1NRaw,
                             index_t Batch,
                             index_t StrideA,
                             index_t StrideB,
                             index_t StrideB1,
                             index_t StrideC,
                             index_t BatchStrideA,
                             index_t BatchStrideB,
                             index_t BatchStrideB1,
                             index_t BatchStrideC,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             AccElementwiseOperation acc_element_op,
                             B1ElementwiseOperation b1_element_op,
                             CElementwiseOperation c_element_op)
    {
        return Argument{p_a,
                        p_b,
                        p_b1,
                        p_c,
                        MRaw,
                        NRaw,
                        KRaw,
                        Gemm1NRaw,
                        Batch,
                        StrideA,
                        StrideB,
                        StrideB1,
                        StrideC,
                        BatchStrideA,
                        BatchStrideB,
                        BatchStrideB1,
                        BatchStrideC,
                        a_element_op,
                        b_element_op,
                        acc_element_op,
                        b1_element_op,
                        c_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    // polymorphic
    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                      const void* p_b,
                                                      const void* p_b1,
                                                      void* p_c,
                                                      index_t MRaw,
                                                      index_t NRaw,
                                                      index_t KRaw,
                                                      index_t Gemm1NRaw,
                                                      index_t Batch,
                                                      index_t StrideA,
                                                      index_t StrideB,
                                                      index_t StrideB1,
                                                      index_t StrideC,
                                                      index_t BatchStrideA,
                                                      index_t BatchStrideB,
                                                      index_t BatchStrideB1,
                                                      index_t BatchStrideC,
                                                      AElementwiseOperation a_element_op,
                                                      BElementwiseOperation b_element_op,
                                                      AccElementwiseOperation acc_element_op,
                                                      B1ElementwiseOperation b1_element_op,
                                                      CElementwiseOperation c_element_op) override
    {
        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                          static_cast<const BDataType*>(p_b),
                                          static_cast<const B1DataType*>(p_b1),
                                          static_cast<CDataType*>(p_c),
                                          MRaw,
                                          NRaw,
                                          KRaw,
                                          Gemm1NRaw,
                                          Batch,
                                          StrideA,
                                          StrideB,
                                          StrideB1,
                                          StrideC,
                                          BatchStrideA,
                                          BatchStrideB,
                                          BatchStrideB1,
                                          BatchStrideC,
                                          a_element_op,
                                          b_element_op,
                                          acc_element_op,
                                          b1_element_op,
                                          c_element_op);
    }

    // polymorphic
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    // polymorphic
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
            << NPerBlock << ", "
            << KPerBlock << ", "
            << AK1 << ", "
            << BK1 << ", "
            << MPerBlock << ", "
            << Gemm1NPerBlock << ", "
            << Gemm1KPerBlock << ", "
            << B1K1
            << ">";
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
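
A worked example of the padding arithmetic shared by all the descriptor builders above; the numeric values are illustrative, not from this commit:

// integer_divide_ceil rounds the raw extent up to a tile multiple; the
// difference becomes the right-pad length consumed by make_right_pad_transform.
constexpr ck::index_t MPerBlock = 128;   // tile size (example)
const ck::index_t MRaw          = 1000;  // raw problem extent (example)

const auto M    = ck::math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; // 1024
const auto MPad = M - MRaw;                                                   // 24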
include/ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <array>
#include <memory>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <index_t Rank, index_t NumBatchNormReduceDim>
struct DeviceBatchNormFwd : public BaseOperator
{
    virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
        const std::array<index_t, Rank> xyLengths,
        const std::array<index_t, Rank> xStrides,
        const std::array<index_t, Rank> yStrides,
        const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
        const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarStrides,
        const void* p_x,
        const void* bnScale,
        const void* bnBias,
        void* p_y,
        double exponentialAverageFactor,
        void* resultRunningMean,
        void* resultRunningVariance,
        double epsilon,
        void* resultSaveMean,
        void* resultSaveInvVariance) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

template <index_t Rank, index_t NumBatchNormReduceDim>
using DeviceBatchNormFwdPtr = std::unique_ptr<DeviceBatchNormFwd<Rank, NumBatchNormReduceDim>>;

} // namespace device
} // namespace tensor_operation
} // namespace ck
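
A hypothetical call sketch (pointer names and extents are assumptions, not from this commit): forward batchnorm over a packed NHWC tensor with Rank = 4 and NumBatchNormReduceDim = 3, so N/H/W are reduced and the scale/bias/mean/variance tensors are 1-D over C:

// bn_fwd_op is assumed to be a DeviceBatchNormFwdPtr<4, 3> instance.
const std::array<ck::index_t, 4> xyLengths = {N, H, W, C};
const std::array<ck::index_t, 4> xStrides  = {H * W * C, W * C, C, 1}; // packed NHWC
const std::array<ck::index_t, 1> sbmvLengths = {C};
const std::array<ck::index_t, 1> sbmvStrides = {1};

auto arg = bn_fwd_op->MakeArgumentPointer(xyLengths, xStrides, xStrides,
                                          sbmvLengths, sbmvStrides,
                                          p_x, p_scale, p_bias, p_y,
                                          0.1, p_running_mean, p_running_var,
                                          1e-5, p_save_mean, p_save_inv_var);
auto invoker = bn_fwd_op->MakeInvokerPointer();
invoker->Run(arg.get());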
include/ck/tensor_operation/gpu/device/device_batchnorm_infer.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <array>
#include <memory>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <index_t Rank, index_t NumBatchNormReduceDim>
struct DeviceBatchNormInfer : public BaseOperator
{
    virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
        const std::array<index_t, Rank> xyLengths,
        const std::array<index_t, Rank> xStrides,
        const std::array<index_t, Rank> yStrides,
        const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
        const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarStrides,
        const void* p_x,
        const void* bnScale,
        const void* bnBias,
        double epsilon,
        const void* estimatedMean,
        const void* estimatedInvVariance,
        void* p_y) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

template <index_t Rank, index_t NumBatchNormReduceDim>
using DeviceBatchNormInferPtr = std::unique_ptr<DeviceBatchNormInfer<Rank, NumBatchNormReduceDim>>;

} // namespace device
} // namespace tensor_operation
} // namespace ck
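
For reference, the per-element transform this interface encodes, written as a sketch; exact epsilon handling lives in the implementation, and here it is assumed to be folded into the inverse-variance estimate:

// y ≈ bnScale[c] * (x - estimatedMean[c]) * invVar[c] + bnBias[c]
// with invVar[c] on the order of 1 / sqrt(variance[c] + epsilon)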
include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp
deleted
100644 → 0
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <vector>

#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename ComputeDataType,
          typename ElementwiseFunctor,
          index_t NDim,
          index_t MPerThread,
          index_t AScalarPerVector,
          index_t BScalarPerVector,
          index_t CScalarPerVector>
struct DeviceBinaryElementwise : public DeviceElementwise<2, 1, NDim, ElementwiseFunctor>
{
    static constexpr auto I0 = Number<0>{};

    template <typename Desc_M>
    static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize)
    {
        const auto M            = desc_m.GetLength(I0);
        const index_t loop_step = gridSize * blockSize * MPerThread;
        const auto pad          = math::integer_least_multiple(M, loop_step) - M;
        const auto desc_m_pad =
            transform_tensor_descriptor(desc_m,
                                        make_tuple(make_right_pad_transform(M, pad)),
                                        make_tuple(Sequence<0>{}),
                                        make_tuple(Sequence<0>{}));
        return desc_m_pad;
    }

    static auto MakeDescriptor_M(const std::vector<index_t>& lengths,
                                 const std::vector<index_t>& strides,
                                 index_t gridSize,
                                 index_t blockSize)
    {
        auto tupleOfShape  = generate_tuple([&](auto I) { return lengths[I]; }, Number<NDim>{});
        auto tupleOfStride = generate_tuple([&](auto I) { return strides[I]; }, Number<NDim>{});

        // nd desc - [s0, s1, s2, ...]
        const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);

        // merge nd to 1d desc - [s0 * s1 * ...]
        if constexpr(NDim > 1)
        {
            const auto desc_m = transform_tensor_descriptor(
                desc,
                make_tuple(make_merge_transform(tupleOfShape)),
                make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<NDim>{})),
                make_tuple(Sequence<0>{}));

            return PadDescriptor_M_1d(desc_m, gridSize, blockSize);
        }
        else
            return PadDescriptor_M_1d(desc, gridSize, blockSize);
    }

    using AGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
    using BGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
    using CGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));

    using GridwiseBinEltwise = GridwiseBinaryElementwise_1D<ADataType,
                                                            BDataType,
                                                            CDataType,
                                                            ComputeDataType,
                                                            AGridDesc_M,
                                                            BGridDesc_M,
                                                            CGridDesc_M,
                                                            ElementwiseFunctor,
                                                            MPerThread,
                                                            AScalarPerVector,
                                                            BScalarPerVector,
                                                            CScalarPerVector>;

    struct Argument : public BaseArgument
    {
        Argument(const ADataType* p_a,
                 const BDataType* p_b,
                 CDataType* p_c,
                 const std::vector<index_t>& lengths,
                 const std::vector<index_t>& a_strides,
                 const std::vector<index_t>& b_strides,
                 const std::vector<index_t>& c_strides,
                 ElementwiseFunctor functor)
            : p_a_(p_a),
              p_b_(p_b),
              p_c_(p_c),
              lengths_(lengths),
              a_strides_(a_strides),
              b_strides_(b_strides),
              c_strides_(c_strides),
              functor_(functor),
              blockSize_(256),
              gridSize_(120) // FIXME - Calculate the grid size by number of CU in the future
        {
            a_grid_desc_m_ = MakeDescriptor_M(lengths, a_strides, gridSize_, blockSize_);
            b_grid_desc_m_ = MakeDescriptor_M(lengths, b_strides, gridSize_, blockSize_);
            c_grid_desc_m_ = MakeDescriptor_M(lengths, c_strides, gridSize_, blockSize_);
        }

        const ADataType* p_a_;
        const BDataType* p_b_;
        CDataType* p_c_;
        std::vector<int> lengths_;
        AGridDesc_M a_grid_desc_m_;
        BGridDesc_M b_grid_desc_m_;
        CGridDesc_M c_grid_desc_m_;
        std::vector<index_t> a_strides_;
        std::vector<index_t> b_strides_;
        std::vector<index_t> c_strides_;
        ElementwiseFunctor functor_;
        index_t blockSize_;
        index_t gridSize_;
    };

    struct Invoker : public BaseInvoker
    {
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            const auto kernel = kernel_binary_elementwise_1d<GridwiseBinEltwise,
                                                             ADataType,
                                                             BDataType,
                                                             CDataType,
                                                             AGridDesc_M,
                                                             BGridDesc_M,
                                                             CGridDesc_M,
                                                             ElementwiseFunctor>;

            float elapsed_time = launch_and_time_kernel(stream_config,
                                                        kernel,
                                                        dim3(arg.gridSize_),
                                                        dim3(arg.blockSize_),
                                                        0,
                                                        arg.p_a_,
                                                        arg.p_b_,
                                                        arg.p_c_,
                                                        arg.a_grid_desc_m_,
                                                        arg.b_grid_desc_m_,
                                                        arg.c_grid_desc_m_,
                                                        arg.functor_);
            return elapsed_time;
        }

        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        const Argument* pArg = dynamic_cast<const Argument*>(p_arg);

        if(pArg == nullptr)
            return false;

        if(pArg->lengths_.size() != NDim)
            return false;

        if(pArg->lengths_.back() % MPerThread != 0)
            return false;

        auto IsScalarPerVectorValid = [](bool isLastDimensionCoalesced, int scalarPerVector) {
            bool ret = true;

            if(!isLastDimensionCoalesced)
                ret = scalarPerVector == 1;
            else
                ret = MPerThread % scalarPerVector == 0;

            return ret;
        };

        if(!IsScalarPerVectorValid(pArg->a_strides_.back() == 1, AScalarPerVector))
            return false;

        if(!IsScalarPerVectorValid(pArg->b_strides_.back() == 1, BScalarPerVector))
            return false;

        if(!IsScalarPerVectorValid(pArg->c_strides_.back() == 1, CScalarPerVector))
            return false;

        return true;
    };

    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(std::array<const void*, 2> p_inputs,
                        std::array<void*, 1> p_outputs,
                        std::vector<index_t> lengths,
                        std::vector<std::vector<index_t>> input_strides,
                        std::vector<std::vector<index_t>> output_strides,
                        ElementwiseFunctor functor) override
    {
        return std::make_unique<Argument>(static_cast<const ADataType*>(p_inputs[0]),
                                          static_cast<const BDataType*>(p_inputs[1]),
                                          static_cast<CDataType*>(p_outputs[0]),
                                          lengths,
                                          input_strides[0],
                                          input_strides[1],
                                          output_strides[0],
                                          functor);
    }

    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>();
    }

    // polymorphic
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "DeviceBinaryElementwise"
            << "<"
            << "NDim = " << NDim
            << "MPerThread = " << MPerThread
            << "AScalarPerVector = " << AScalarPerVector
            << "BScalarPerVector = " << BScalarPerVector
            << "CScalarPerVector = " << CScalarPerVector
            << ">";
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
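
The deletion above is a migration rather than a removal of functionality: the binary-only gridwise op is subsumed by the variadic GridwiseElementwise_1D, which carries its inputs and outputs as Tuples (the device_cgemm diff below shows the mechanical rewrite). A hedged sketch of the equivalent instantiation under the new interface, using the type names from the deleted file:

using Add = ck::tensor_operation::element_wise::Add;

// Same Add functor, now expressed through the variadic interface.
using GridwiseAdd = GridwiseElementwise_1D<Tuple<AGridDesc_M, BGridDesc_M>, // input descriptors
                                           Tuple<CGridDesc_M>,              // output descriptors
                                           Tuple<const ADataType*, const BDataType*>,
                                           Tuple<CDataType*>,
                                           Add,
                                           MPerThread,
                                           Sequence<AScalarPerVector, BScalarPerVector>,
                                           Sequence<CScalarPerVector>>;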
include/ck/tensor_operation/gpu/device/device_cgemm_4gemm_xdl_cshuffle.hpp
...
...
@@ -14,7 +14,7 @@
#include "ck/tensor_operation/gpu/device/device_cgemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_
binary_
elementwise_1d.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
...
...
@@ -538,48 +538,43 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
float
ave_time
=
0
;
using
Add
=
ck
::
tensor_operation
::
element_wise
::
Add
;
using
Subtract
=
ck
::
tensor_operation
::
element_wise
::
Subtract
;
using
GridwiseBinAdd
=
GridwiseBinaryElementwise_1D
<
CDataType
,
CDataType
,
CDataType
,
CDataType
,
CGridDesc_M
,
CGridDesc_M
,
CGridDesc_M
,
Add
,
MPerThread
,
AScalarPerVector
,
BScalarPerVector
,
CScalarPerVector
>
;
using
GridwiseBinSubtract
=
GridwiseBinaryElementwise_1D
<
CDataType
,
CDataType
,
CDataType
,
CDataType
,
CGridDesc_M
,
CGridDesc_M
,
CGridDesc_M
,
Subtract
,
MPerThread
,
AScalarPerVector
,
BScalarPerVector
,
CScalarPerVector
>
;
const
auto
add_kernel
=
kernel_binary_elementwise_1d
<
GridwiseBinAdd
,
CDataType
,
CDataType
,
CDataType
,
CGridDesc_M
,
CGridDesc_M
,
CGridDesc_M
,
Add
>
;
const
auto
subtract_kernel
=
kernel_binary_elementwise_1d
<
GridwiseBinSubtract
,
CDataType
,
CDataType
,
CDataType
,
CGridDesc_M
,
CGridDesc_M
,
CGridDesc_M
,
Subtract
>
;
using
Add
=
ck
::
tensor_operation
::
element_wise
::
Add
;
using
Subtract
=
ck
::
tensor_operation
::
element_wise
::
Subtract
;
using
GridwiseBinAdd
=
GridwiseElementwise_1D
<
Tuple
<
CGridDesc_M
,
CGridDesc_M
>
,
Tuple
<
CGridDesc_M
>
,
Tuple
<
const
CDataType
*
,
const
CDataType
*>
,
Tuple
<
CDataType
*>
,
Add
,
MPerThread
,
Sequence
<
AScalarPerVector
,
BScalarPerVector
>
,
Sequence
<
CScalarPerVector
>>
;
using
GridwiseBinSubtract
=
GridwiseElementwise_1D
<
Tuple
<
CGridDesc_M
,
CGridDesc_M
>
,
Tuple
<
CGridDesc_M
>
,
Tuple
<
const
CDataType
*
,
const
CDataType
*>
,
Tuple
<
CDataType
*>
,
Subtract
,
MPerThread
,
Sequence
<
AScalarPerVector
,
BScalarPerVector
>
,
Sequence
<
CScalarPerVector
>>
;
const
auto
add_kernel
=
kernel_elementwise_1d
<
GridwiseBinAdd
,
Tuple
<
CGridDesc_M
,
CGridDesc_M
>
,
Tuple
<
CGridDesc_M
>
,
Tuple
<
const
CDataType
*
,
const
CDataType
*>
,
Tuple
<
CDataType
*>
,
Add
>
;
const
auto
subtract_kernel
=
kernel_elementwise_1d
<
GridwiseBinSubtract
,
Tuple
<
CGridDesc_M
,
CGridDesc_M
>
,
Tuple
<
CGridDesc_M
>
,
Tuple
<
const
CDataType
*
,
const
CDataType
*>
,
Tuple
<
CDataType
*>
,
Subtract
>
;
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
{
...
...
@@ -631,18 +626,18 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
arg
.
block_2_ctile_map_
);
// c_real = aux - aux_2
ave_time
+=
launch_and_time_kernel
(
stream_config
,
s
ubtract_kernel
,
dim3
(
grid_size
)
,
dim3
(
BlockS
ize
),
0
,
arg
.
p_aux_grid_
,
arg
.
p_aux_2_grid_
,
arg
.
p_
c_grid_
real_
,
arg
.
c_grid_desc_m_
,
arg
.
c_grid_desc_m_
,
arg
.
c_grid_
desc_m_
,
Subtract
{});
ave_time
+=
launch_and_time_kernel
(
s
tream_config
,
subtract_kernel
,
dim3
(
grid_s
ize
),
dim3
(
BlockSize
)
,
0
,
make_tuple
(
arg
.
c_grid_desc_m_
,
arg
.
c_grid_desc_m_
)
,
make_tuple
(
arg
.
c_grid_
desc_m_
)
,
make_tuple
(
const_cast
<
const
CDataType
*>
(
arg
.
p_aux_grid_
)
,
const_cast
<
const
CDataType
*>
(
arg
.
p_aux_2_grid_
))
,
make_tuple
(
arg
.
p_
c_grid_
real_
)
,
Subtract
{});
ave_time
+=
launch_and_time_kernel
(
stream_config
,
...
...
@@ -679,18 +674,18 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
                                              arg.block_2_ctile_map_);

            // c_imag = aux + aux_2
-           ave_time += launch_and_time_kernel(stream_config,
-                                              add_kernel,
-                                              dim3(grid_size),
-                                              dim3(BlockSize),
-                                              0,
-                                              arg.p_aux_grid_,
-                                              arg.p_aux_2_grid_,
-                                              arg.p_c_grid_imag_,
-                                              arg.c_grid_desc_m_,
-                                              arg.c_grid_desc_m_,
-                                              arg.c_grid_desc_m_,
-                                              Add{});
+           ave_time += launch_and_time_kernel(stream_config,
+                                              add_kernel,
+                                              dim3(grid_size),
+                                              dim3(BlockSize),
+                                              0,
+                                              make_tuple(arg.c_grid_desc_m_, arg.c_grid_desc_m_),
+                                              make_tuple(arg.c_grid_desc_m_),
+                                              make_tuple(const_cast<const CDataType*>(arg.p_aux_grid_),
+                                                         const_cast<const CDataType*>(arg.p_aux_2_grid_)),
+                                              make_tuple(arg.p_c_grid_imag_),
+                                              Add{});
        }
        else
        {
...
...
@@ -742,18 +737,18 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
                                              arg.block_2_ctile_map_);

            // c_real = aux - aux_2
-           ave_time += launch_and_time_kernel(stream_config,
-                                              subtract_kernel,
-                                              dim3(grid_size),
-                                              dim3(BlockSize),
-                                              0,
-                                              arg.p_aux_grid_,
-                                              arg.p_aux_2_grid_,
-                                              arg.p_c_grid_real_,
-                                              arg.c_grid_desc_m_,
-                                              arg.c_grid_desc_m_,
-                                              arg.c_grid_desc_m_,
-                                              Subtract{});
+           ave_time += launch_and_time_kernel(stream_config,
+                                              subtract_kernel,
+                                              dim3(grid_size),
+                                              dim3(BlockSize),
+                                              0,
+                                              make_tuple(arg.c_grid_desc_m_, arg.c_grid_desc_m_),
+                                              make_tuple(arg.c_grid_desc_m_),
+                                              make_tuple(const_cast<const CDataType*>(arg.p_aux_grid_),
+                                                         const_cast<const CDataType*>(arg.p_aux_2_grid_)),
+                                              make_tuple(arg.p_c_grid_real_),
+                                              Subtract{});

            ave_time += launch_and_time_kernel(stream_config,
...
...
@@ -790,18 +785,18 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
                                              arg.block_2_ctile_map_);

            // c_imag = aux + aux_2
-           ave_time += launch_and_time_kernel(stream_config,
-                                              add_kernel,
-                                              dim3(grid_size),
-                                              dim3(BlockSize),
-                                              0,
-                                              arg.p_aux_grid_,
-                                              arg.p_aux_2_grid_,
-                                              arg.p_c_grid_imag_,
-                                              arg.c_grid_desc_m_,
-                                              arg.c_grid_desc_m_,
-                                              arg.c_grid_desc_m_,
-                                              Add{});
+           ave_time += launch_and_time_kernel(stream_config,
+                                              add_kernel,
+                                              dim3(grid_size),
+                                              dim3(BlockSize),
+                                              0,
+                                              make_tuple(arg.c_grid_desc_m_, arg.c_grid_desc_m_),
+                                              make_tuple(arg.c_grid_desc_m_),
+                                              make_tuple(const_cast<const CDataType*>(arg.p_aux_grid_),
+                                                         const_cast<const CDataType*>(arg.p_aux_2_grid_)),
+                                              make_tuple(arg.p_c_grid_imag_),
+                                              Add{});
        }

        return ave_time;
...
...
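For context, the subtract/add launches patched above are the recombination step of DeviceCGemm_4Gemm_Xdl_CShuffle, which computes a complex GEMM as four real GEMMs. A host-side sketch of the math behind the `c_real = aux - aux_2` and `c_imag = aux + aux_2` comments (m_total is a hypothetical element count; the stage assignments follow the usual 4-GEMM complex multiply):

// (A_re + i*A_im) * (B_re + i*B_im), computed with two auxiliary buffers:
for(ck::index_t i = 0; i < m_total; ++i) // after GEMMs 1 and 2:
    c_real[i] = aux[i] - aux_2[i];       //   aux = A_re*B_re, aux_2 = A_im*B_im
// GEMMs 3 and 4 then overwrite aux and aux_2 ...
for(ck::index_t i = 0; i < m_total; ++i) // ... so that afterwards:
    c_imag[i] = aux[i] + aux_2[i];       //   aux = A_re*B_im, aux_2 = A_im*B_re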
include/ck/tensor_operation/gpu/device/device_conv2d_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
...
...
@@ -13,7 +13,6 @@
#include "ck/tensor_operation/gpu/device/device_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_unary_elementwise_1d.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
...
...
include/ck/tensor_operation/gpu/device/device_convnd_bwd_weight_nwc_kxc_nwk_xdl_cshuffle.hpp
...
...
@@ -13,7 +13,6 @@
#include "ck/tensor_operation/gpu/device/device_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_unary_elementwise_1d.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
...
...
include/ck/tensor_operation/gpu/device/device_elementwise.hpp
...
...
@@ -2,38 +2,286 @@
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
 #include <iostream>
 #include <vector>
+#include <sstream>
+
+#include "ck/utility/math.hpp"
+#include "ck/utility/sequence.hpp"
+#include "ck/tensor_operation/gpu/device/device_elementwise_base.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp"
+#include "ck/tensor_description/tensor_descriptor_helper.hpp"
-#include "device_base.hpp"
+#include "ck/host_utility/kernel_launch.hpp"

 namespace ck {
 namespace tensor_operation {
 namespace device {

-template <ck::index_t NumInputTensor,
-          ck::index_t NumOutputTensor,
-          index_t NDim,
-          typename ElementwiseFunctor>
-struct DeviceElementwise : public BaseOperator
-{
-    virtual std::unique_ptr<BaseArgument>
-    MakeArgumentPointer(std::array<const void*, NumInputTensor> p_inputs,
-                        std::array<void*, NumOutputTensor> p_outputs,
-                        std::vector<index_t> lengths,
-                        std::vector<std::vector<index_t>> input_strides,
-                        std::vector<std::vector<index_t>> output_strides,
-                        ElementwiseFunctor functor) = 0;
-
-    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
-};
-
-template <ck::index_t NumInputTensor,
-          ck::index_t NumOutputTensor,
-          index_t NDim,
-          typename ElementwiseFunctor>
-using DeviceElementwisePtr = std::unique_ptr<
-    DeviceElementwise<NumInputTensor, NumOutputTensor, NDim, ElementwiseFunctor>>;
+template <typename InDataTypeTuple,
+          typename OutDataTypeTuple,
+          typename ElementwiseOperation,
+          index_t NumDim,
+          index_t MPerThread,
+          typename InScalarPerVectorSeq,
+          typename OutScalarPerVectorSeq>
+struct DeviceElementwise : public DeviceElementwiseBase<InDataTypeTuple,
+                                                        OutDataTypeTuple,
+                                                        ElementwiseOperation,
+                                                        NumDim>
+{
+    static constexpr int NumInput  = InDataTypeTuple::Size();
+    static constexpr int NumOutput = OutDataTypeTuple::Size();
+
+    static_assert(NumInput == InScalarPerVectorSeq::Size() &&
+                      NumOutput == OutScalarPerVectorSeq::Size(),
+                  "Tuple size is inconsistent with the number of in/out!");
+
+    static auto GenerateInDataTypePointerTuple()
+    {
+        return generate_tuple(
+            [&](auto I) {
+                using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;
+                return static_cast<const DataType*>(nullptr);
+            },
+            Number<NumInput>{});
+    };
+
+    static auto GenerateOutDataTypePointerTuple()
+    {
+        return generate_tuple(
+            [&](auto I) {
+                using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
+                return static_cast<DataType*>(nullptr);
+            },
+            Number<NumOutput>{});
+    };
+
+    using InDataTypePointerTuple  = decltype(GenerateInDataTypePointerTuple());
+    using OutDataTypePointerTuple = decltype(GenerateOutDataTypePointerTuple());
+
+    template <typename Desc_M>
+    static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize)
+    {
+        constexpr auto I0 = Number<0>{};
+
+        const auto m            = desc_m.GetLength(I0);
+        const index_t loop_step = gridSize * blockSize * MPerThread;
+        const auto pad          = math::integer_least_multiple(m, loop_step) - m;
+        const auto desc_m_pad =
+            transform_tensor_descriptor(desc_m,
+                                        make_tuple(make_right_pad_transform(m, pad)),
+                                        make_tuple(Sequence<0>{}),
+                                        make_tuple(Sequence<0>{}));
+        return desc_m_pad;
+    }
+
+    static auto MakeDescriptor_M(const std::array<index_t, NumDim>& lengths,
+                                 const std::array<index_t, NumDim>& stride,
+                                 index_t gridSize,
+                                 index_t blockSize)
+    {
+        auto tupleOfShape  = generate_tuple([&](auto I) { return lengths[I]; }, Number<NumDim>{});
+        auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number<NumDim>{});
+
+        // nd desc - [s0, s1, s2, ...]
+        const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
+
+        // merge nd to 1d desc - [s0 * s1 * ...]
+        if constexpr(NumDim > 1)
+        {
+            const auto desc_m = transform_tensor_descriptor(
+                desc,
+                make_tuple(make_merge_transform(tupleOfShape)),
+                make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<NumDim>{})),
+                make_tuple(Sequence<0>{}));
+
+            return PadDescriptor_M_1d(desc_m, gridSize, blockSize);
+        }
+        else
+            return PadDescriptor_M_1d(desc, gridSize, blockSize);
+    }
+
+    template <index_t TupleSize>
+    static auto GenerateInOutGrid1dDescTuple(Number<TupleSize>)
+    {
+        return generate_tuple(
+            [&](auto) {
+                if constexpr(NumDim > 1)
+                {
+                    return MakeDescriptor_M({1, 1}, {1, 1}, 1, 1);
+                }
+                else
+                {
+                    return MakeDescriptor_M({1}, {1}, 1, 1);
+                };
+            },
+            Number<TupleSize>{});
+    };
+
+    using InGrid1dDescTuple  = decltype(GenerateInOutGrid1dDescTuple(Number<NumInput>{}));
+    using OutGrid1dDescTuple = decltype(GenerateInOutGrid1dDescTuple(Number<NumOutput>{}));
+
+    using GridwiseElementwise = GridwiseElementwise_1D<InGrid1dDescTuple,
+                                                       OutGrid1dDescTuple,
+                                                       InDataTypePointerTuple,
+                                                       OutDataTypePointerTuple,
+                                                       ElementwiseOperation,
+                                                       MPerThread,
+                                                       InScalarPerVectorSeq,
+                                                       OutScalarPerVectorSeq>;
+
+    struct Argument : public BaseArgument
+    {
+        Argument(const std::array<index_t, NumDim> lengths,
+                 const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
+                 const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
+                 const std::array<const void*, NumInput> in_dev_buffers,
+                 const std::array<void*, NumOutput> out_dev_buffers,
+                 ElementwiseOperation elementwise_op)
+            : lengths_(lengths),
+              inStridesArray_(inStridesArray),
+              outStridesArray_(outStridesArray),
+              elementwise_op_(elementwise_op),
+              blockSize_(256),
+              gridSize_(120) // FIXME - Calculate the grid size by number of CU in the future
+        {
+            in_dev_buffers_ = generate_tuple(
+                [&](auto I) {
+                    using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;
+                    return static_cast<const DataType*>(in_dev_buffers[I.value]);
+                },
+                Number<NumInput>{});
+
+            out_dev_buffers_ = generate_tuple(
+                [&](auto I) {
+                    using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
+                    return static_cast<DataType*>(out_dev_buffers[I.value]);
+                },
+                Number<NumOutput>{});
+
+            in_grid_1d_desc_tuple_ = generate_tuple(
+                [&](auto I) {
+                    return MakeDescriptor_M(lengths, inStridesArray[I.value], gridSize_, blockSize_);
+                },
+                Number<NumInput>{});
+
+            out_grid_1d_desc_tuple_ = generate_tuple(
+                [&](auto I) {
+                    return MakeDescriptor_M(lengths, outStridesArray[I.value], gridSize_, blockSize_);
+                },
+                Number<NumOutput>{});
+        }
+
+        InDataTypePointerTuple in_dev_buffers_;
+        OutDataTypePointerTuple out_dev_buffers_;
+
+        InGrid1dDescTuple in_grid_1d_desc_tuple_;
+        OutGrid1dDescTuple out_grid_1d_desc_tuple_;
+
+        std::array<index_t, NumDim> lengths_;
+        std::array<std::array<index_t, NumDim>, NumInput> inStridesArray_;
+        std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray_;
+
+        ElementwiseOperation elementwise_op_;
+
+        index_t blockSize_;
+        index_t gridSize_;
+    };
+
+    struct Invoker : public BaseInvoker
+    {
+        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
+        {
+            const auto kernel = kernel_elementwise_1d<GridwiseElementwise,
+                                                      InGrid1dDescTuple,
+                                                      OutGrid1dDescTuple,
+                                                      InDataTypePointerTuple,
+                                                      OutDataTypePointerTuple,
+                                                      ElementwiseOperation>;
+
+            float elapsed_time = launch_and_time_kernel(stream_config,
+                                                        kernel,
+                                                        dim3(arg.gridSize_),
+                                                        dim3(arg.blockSize_),
+                                                        0,
+                                                        arg.in_grid_1d_desc_tuple_,
+                                                        arg.out_grid_1d_desc_tuple_,
+                                                        arg.in_dev_buffers_,
+                                                        arg.out_dev_buffers_,
+                                                        arg.elementwise_op_);
+            return elapsed_time;
+        }
+
+        // polymorphic
+        float Run(const BaseArgument* p_arg,
+                  const StreamConfig& stream_config = StreamConfig{}) override
+        {
+            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
+        }
+    };
+
+    bool IsSupportedArgument(const BaseArgument* p_arg) override
+    {
+        const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
+
+        if(pArg == nullptr)
+            return false;
+
+        if(pArg->lengths_.back() % MPerThread != 0)
+            return false;
+
+        auto IsScalarPerVectorValid = [&](const std::array<index_t, NumDim>& lengths,
+                                          const std::array<index_t, NumDim>& strides,
+                                          index_t scalarPerVector) {
+            if(strides.back() == 1 && lengths.back() % scalarPerVector == 0)
+                return true;
+
+            if(strides.back() != 1 && scalarPerVector == 1)
+                return true;
+
+            return false;
+        };
+
+        bool valid = true;
+        static_for<0, NumInput, 1>{}([&](auto I) {
+            if(!IsScalarPerVectorValid(
+                   pArg->lengths_, pArg->inStridesArray_[I.value], InScalarPerVectorSeq::At(I)))
+                valid = false;
+        });
+
+        static_for<0, NumOutput, 1>{}([&](auto I) {
+            if(!IsScalarPerVectorValid(
+                   pArg->lengths_, pArg->outStridesArray_[I.value], OutScalarPerVectorSeq::At(I)))
+                valid = false;
+        });
+
+        return valid;
+    }
+
+    std::unique_ptr<BaseArgument>
+    MakeArgumentPointer(const std::array<index_t, NumDim> lengths,
+                        const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
+                        const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
+                        const std::array<const void*, NumInput> in_dev_buffers,
+                        const std::array<void*, NumOutput> out_dev_buffers,
+                        ElementwiseOperation elementwise_op) override
+    {
+        return std::make_unique<Argument>(lengths,
+                                          inStridesArray,
+                                          outStridesArray,
+                                          in_dev_buffers,
+                                          out_dev_buffers,
+                                          elementwise_op);
+    }
+
+    static auto MakeInvoker() { return Invoker{}; }
+
+    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
+    {
+        return std::make_unique<Invoker>();
+    }
+};

 } // namespace device
 } // namespace tensor_operation
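A minimal host-side sketch of driving the rewritten interface (not part of this commit): a two-input float Add over an M x N row-major tensor. MPerThread = 8 and the scalar-per-vector widths are illustrative assumptions, and p_a/p_b/p_c stand for real device allocations.

using F32 = float;
using Add = ck::tensor_operation::element_wise::Add;
using DeviceAddInstance =
    ck::tensor_operation::device::DeviceElementwise<ck::Tuple<F32, F32>, // input data types
                                                    ck::Tuple<F32>,      // output data types
                                                    Add,
                                                    2,                   // NumDim
                                                    8,                   // MPerThread (illustrative)
                                                    ck::Sequence<8, 8>,  // InScalarPerVectorSeq
                                                    ck::Sequence<8>>;    // OutScalarPerVectorSeq

const void* p_a = nullptr; // placeholders for hipMalloc'ed device buffers
const void* p_b = nullptr;
void* p_c       = nullptr;

const ck::index_t M = 1024, N = 1024;
std::array<ck::index_t, 2> lengths{M, N};
std::array<std::array<ck::index_t, 2>, 2> in_strides{{{N, 1}, {N, 1}}}; // packed row-major
std::array<std::array<ck::index_t, 2>, 1> out_strides{{{N, 1}}};

auto op       = DeviceAddInstance{};
auto argument = op.MakeArgumentPointer(lengths, in_strides, out_strides,
                                       {p_a, p_b}, {p_c}, Add{});
if(op.IsSupportedArgument(argument.get()))
    op.MakeInvokerPointer()->Run(argument.get(), StreamConfig{});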
...
...
include/ck/tensor_operation/gpu/device/device_elementwise_base.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <memory>
#include <array>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {

template <typename InDataTypeTuple,
          typename OutDataTypeTuple,
          typename ElementwiseOperation,
          index_t NumDim>
struct DeviceElementwiseBase : public BaseOperator
{
    static constexpr int NumInput  = InDataTypeTuple::Size();
    static constexpr int NumOutput = OutDataTypeTuple::Size();

    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const std::array<index_t, NumDim> lengths,
                        const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
                        const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
                        const std::array<const void*, NumInput> in_dev_buffers,
                        const std::array<void*, NumOutput> out_dev_buffers,
                        ElementwiseOperation elementwise_op) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

template <typename InDataTypeTuple,
          typename OutDataTypeTuple,
          typename ElementwiseOperation,
          index_t NumDim>
using DeviceElementwiseBasePtr = std::unique_ptr<DeviceElementwiseBase<InDataTypeTuple,
                                                                       OutDataTypeTuple,
                                                                       ElementwiseOperation,
                                                                       NumDim>>;

} // namespace device
} // namespace tensor_operation
} // namespace ck
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include "device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {

// FIXME: DeviceGemmReduce type need to well define the problem
template <typename ALayout,
          typename BLayout,
          typename DELayout,
          typename ADataType,
          typename BDataType,
          typename DsDataType,
          typename EDataType,
          typename RsDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          typename QsElementwiseOperation,
          typename RsElementwiseOperation>
struct DeviceGemmMultipleDMultipleR : public BaseOperator
{
    static constexpr index_t NumDTensor = DsDataType::Size();
    static constexpr index_t NumRTensor = RsDataType::Size();

    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_a,
                        const void* p_b,
                        std::array<const void*, NumDTensor> p_ds,
                        void* p_e,
                        std::array<void*, NumRTensor> p_rs,
                        ck::index_t M,
                        ck::index_t N,
                        ck::index_t K,
                        ck::index_t StrideA,
                        ck::index_t StrideB,
                        std::array<ck::index_t, NumDTensor> StrideDs,
                        ck::index_t StrideE,
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
                        CDEElementwiseOperation cde_element_op,
                        QsElementwiseOperation qs_element_op,
                        RsElementwiseOperation rs_element_op) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

template <typename ALayout,
          typename BLayout,
          typename DELayout,
          typename ADataType,
          typename BDataType,
          typename DsDataType,
          typename EDataType,
          typename RsDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          typename QsElementwiseOperation,
          typename RsElementwiseOperation>
using DeviceGemmMultipleDMultipleRPtr =
    std::unique_ptr<DeviceGemmMultipleDMultipleR<ALayout,
                                                 BLayout,
                                                 DELayout,
                                                 ADataType,
                                                 BDataType,
                                                 DsDataType,
                                                 EDataType,
                                                 RsDataType,
                                                 AElementwiseOperation,
                                                 BElementwiseOperation,
                                                 CDEElementwiseOperation,
                                                 QsElementwiseOperation,
                                                 RsElementwiseOperation>>;

} // namespace device
} // namespace tensor_operation
} // namespace ck
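The interface above only declares the problem. As a reference for what an implementation must compute, here is a scalar host sketch for the single-D, single-R case, assuming a_op/b_op are identity and the reduction is a sum over N (illustrative only, not library code):

#include <vector>

// E = cde_op(A*B, D0); R0 = r_op(sum_n(q_op(E))).
template <typename CDEOp, typename QOp, typename ROp>
void reference_gemm_multiple_d_multiple_r(const std::vector<float>& a, // M x K, row-major
                                          const std::vector<float>& b, // K x N, row-major
                                          const std::vector<float>& d, // M x N
                                          std::vector<float>& e,       // M x N (output)
                                          std::vector<float>& r,       // M     (output)
                                          int M, int N, int K,
                                          CDEOp cde_op, QOp q_op, ROp r_op)
{
    for(int m = 0; m < M; ++m)
    {
        float q_acc = 0.f;
        for(int n = 0; n < N; ++n)
        {
            float c = 0.f; // C = a_op(A) * b_op(B), a_op/b_op taken as identity here
            for(int k = 0; k < K; ++k)
                c += a[m * K + k] * b[k * N + n];
            e[m * N + n] = cde_op(c, d[m * N + n]); // E  = cde_op(C, D0)
            q_acc += q_op(e[m * N + n]);            // Q0 = reduce0(q_op0(E)), here a sum over N
        }
        r[m] = r_op(q_acc);                         // R0 = r_op0(Q0)
    }
}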
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {

template <typename GridwiseGemm,
          typename FloatAB,
          typename FloatDsPointer,
          typename FloatE,
          typename FloatRsPointer,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          typename QsElementwiseOperation,
          typename RsElementwiseOperation,
          typename AGridDesc_AK0_M_AK1,
          typename BGridDesc_BK0_N_BK1,
          typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
          typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
          typename RsGridDescriptor_MBlock_MPerBlock,
          typename Block2ETileMap,
          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_gemm_multiple_d_multiple_r_xdl_cshuffle(
            const FloatAB* __restrict__ p_a_grid,
            const FloatAB* __restrict__ p_b_grid,
            FloatDsPointer p_ds_grid,
            FloatE* __restrict__ p_e_grid,
            FloatRsPointer p_rs_grid,
            const AElementwiseOperation a_element_op,
            const BElementwiseOperation b_element_op,
            const CDEElementwiseOperation cde_element_op,
            const QsElementwiseOperation qs_element_op,
            const RsElementwiseOperation rs_element_op,
            const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
            const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
            const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                ds_grid_desc_mblock_mperblock_nblock_nperblock,
            const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                e_grid_desc_mblock_mperblock_nblock_nperblock,
            const RsGridDescriptor_MBlock_MPerBlock rs_grid_desc_mblock_mperblock,
            const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
                                                  p_b_grid,
                                                  p_ds_grid,
                                                  p_e_grid,
                                                  p_rs_grid,
                                                  p_shared,
                                                  a_element_op,
                                                  b_element_op,
                                                  cde_element_op,
                                                  qs_element_op,
                                                  rs_element_op,
                                                  a_grid_desc_ak0_m_ak1,
                                                  b_grid_desc_bk0_n_bk1,
                                                  ds_grid_desc_mblock_mperblock_nblock_nperblock,
                                                  e_grid_desc_mblock_mperblock_nblock_nperblock,
                                                  rs_grid_desc_mblock_mperblock,
                                                  block_2_etile_map);
#else
    ignore = p_a_grid;
    ignore = p_b_grid;
    ignore = p_ds_grid;
    ignore = p_e_grid;
    ignore = p_rs_grid;
    ignore = a_element_op;
    ignore = b_element_op;
    ignore = cde_element_op;
    ignore = qs_element_op;
    ignore = rs_element_op;
    ignore = a_grid_desc_ak0_m_ak1;
    ignore = b_grid_desc_bk0_n_bk1;
    ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
    ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
    ignore = rs_grid_desc_mblock_mperblock;
    ignore = block_2_etile_map;
#endif
}

} // namespace ck

namespace ck {
namespace tensor_operation {
namespace device {

// GEMM:
//   input : A[AK0, M, AK1]
//   input : B[AK0, N, AK1]
//   input : D0[M, N], D1[M, N], ...
//   output : E[M, N]
//   output : R0[M], R1[M], ...
//   C = a_op(A) * b_op(B)
//   E = cde_op(C, D0, D1, ...)
//   Q0 = reduce0(q_op0(E)), Q1 = reduce1(q_op0(E)), ...
//   R0 = r_op0(Q0), R1 = r_op1(Q1), ...
// Assume:
//   D0, D1, ... and E have the same layout
template <typename ALayout, typename BLayout, typename DELayout,
          typename ADataType, typename BDataType,
          typename GemmAccDataType, typename CShuffleDataType,
          typename DsDataType, typename EDataType,
          typename ReduceAccDataType, typename RsDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          typename QsElementwiseOperation,
          typename RsElementwiseOperation,
          typename ThreadReduceOperations,
          typename RsGlobalMemoryDataOperation,
          GemmSpecialization GemmSpec,
          index_t NumGemmKPrefetchStage,
          index_t BlockSize,
          index_t MPerBlock, index_t NPerBlock, index_t KPerBlock,
          index_t AK1, index_t BK1,
          index_t MPerXDL, index_t NPerXDL,
          index_t MXdlPerWave, index_t NXdlPerWave,
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          bool ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          bool BBlockLdsExtraN,
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CDRThreadTransferClusterLengths_MPerBlock_NPerBlock,
          index_t CDEReduceThreadTransferScalarPerVector_NPerBlock,
          index_t RThreadTransferDstScalarPerVector_MPerBlock,
          LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
    : public DeviceGemmMultipleDMultipleR<ALayout, BLayout, DELayout,
                                          ADataType, BDataType, DsDataType, EDataType, RsDataType,
                                          AElementwiseOperation, BElementwiseOperation,
                                          CDEElementwiseOperation,
                                          QsElementwiseOperation, RsElementwiseOperation>
{
    using DeviceOp = DeviceGemmMultipleDMultipleR_Xdl_CShuffle;

    static constexpr index_t NumDTensor = DsDataType::Size();
    static constexpr index_t NumRTensor = RsDataType::Size();

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
    {
        const auto a_grid_desc_mraw_kraw = [&]() {
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
            {
                return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
                                                    make_tuple(StrideA, I1));
            }
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
            {
                return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
                                                    make_tuple(I1, StrideA));
            }
        }();

        const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
        const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;

        const auto MPad = M - MRaw;
        const auto KPad = K - KRaw;

        if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // pad both M and K
            assert(K % AK1 == 0);

            const auto AK0 = K / AK1;

            const auto a_grid_desc_m_k = transform_tensor_descriptor(
                a_grid_desc_mraw_kraw,
                make_tuple(make_right_pad_transform(MRaw, MPad),
                           make_right_pad_transform(KRaw, KPad)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_grid_desc_m_k,
                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
                           make_pass_through_transform(M)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return a_grid_desc_ak0_m_ak1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
                          GemmSpec == GemmSpecialization::MNPadding)
        {
            // pad M, but not K
            assert(KRaw % AK1 == 0);

            const auto AK0 = KRaw / AK1;

            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_grid_desc_mraw_kraw,
                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
                           make_right_pad_transform(MRaw, MPad)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return a_grid_desc_ak0_m_ak1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
                          GemmSpec == GemmSpecialization::NKPadding)
        {
            // pad K, but not M
            assert(K % AK1 == 0);

            const auto AK0 = K / AK1;

            const auto a_grid_desc_m_k = transform_tensor_descriptor(
                a_grid_desc_mraw_kraw,
                make_tuple(make_pass_through_transform(MRaw),
                           make_right_pad_transform(KRaw, KPad)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_grid_desc_m_k,
                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
                           make_pass_through_transform(MRaw)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return a_grid_desc_ak0_m_ak1;
        }
        else
        {
            // not pad M or K
            assert(KRaw % AK1 == 0);

            const auto AK0 = KRaw / AK1;

            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_grid_desc_mraw_kraw,
                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
                           make_pass_through_transform(MRaw)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return a_grid_desc_ak0_m_ak1;
        }
    }

    static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
    {
        const auto b_grid_desc_nraw_kraw = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
                                                    make_tuple(I1, StrideB));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
                                                    make_tuple(StrideB, I1));
            }
        }();

        const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
        const auto K = math::integer_divide_ceil(KRaw, KPerBlock) * KPerBlock;

        const auto NPad = N - NRaw;
        const auto KPad = K - KRaw;

        if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // pad both N and K
            assert(K % BK1 == 0);

            const auto BK0 = K / BK1;

            const auto b_grid_desc_n_k = transform_tensor_descriptor(
                b_grid_desc_nraw_kraw,
                make_tuple(make_right_pad_transform(NRaw, NPad),
                           make_right_pad_transform(KRaw, KPad)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b_grid_desc_n_k,
                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
                           make_pass_through_transform(N)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return b_grid_desc_bk0_n_bk1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
                          GemmSpec == GemmSpecialization::MNPadding)
        {
            // pad N, but not K
            assert(KRaw % BK1 == 0);

            const auto BK0 = KRaw / BK1;

            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b_grid_desc_nraw_kraw,
                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
                           make_right_pad_transform(NRaw, NPad)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return b_grid_desc_bk0_n_bk1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
                          GemmSpec == GemmSpecialization::MKPadding)
        {
            // pad K, but not N
            assert(K % BK1 == 0);

            const auto BK0 = K / BK1;

            const auto b_grid_desc_n_k = transform_tensor_descriptor(
                b_grid_desc_nraw_kraw,
                make_tuple(make_pass_through_transform(NRaw),
                           make_right_pad_transform(KRaw, KPad)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b_grid_desc_n_k,
                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
                           make_pass_through_transform(NRaw)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return b_grid_desc_bk0_n_bk1;
        }
        else
        {
            // not pad N or K
            assert(KRaw % BK1 == 0);

            const auto BK0 = KRaw / BK1;

            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b_grid_desc_nraw_kraw,
                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
                           make_pass_through_transform(NRaw)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return b_grid_desc_bk0_n_bk1;
        }
    }

    static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
    {
        const auto e_grid_desc_mraw_nraw = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, DELayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
                                                    make_tuple(StrideE, I1));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, DELayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
                                                    make_tuple(I1, StrideE));
            }
        }();

        const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
        const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;

        const auto MPad = M - MRaw;
        const auto NPad = N - NRaw;

        if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // pad M and N
            return transform_tensor_descriptor(e_grid_desc_mraw_nraw,
                                               make_tuple(make_right_pad_transform(MRaw, MPad),
                                                          make_right_pad_transform(NRaw, NPad)),
                                               make_tuple(Sequence<0>{}, Sequence<1>{}),
                                               make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
                          GemmSpec == GemmSpecialization::MKPadding)
        {
            // pad M, but not N
            return transform_tensor_descriptor(e_grid_desc_mraw_nraw,
                                               make_tuple(make_right_pad_transform(MRaw, MPad),
                                                          make_pass_through_transform(NRaw)),
                                               make_tuple(Sequence<0>{}, Sequence<1>{}),
                                               make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
                          GemmSpec == GemmSpecialization::NKPadding)
        {
            // pad N, but not M
            return transform_tensor_descriptor(e_grid_desc_mraw_nraw,
                                               make_tuple(make_pass_through_transform(MRaw),
                                                          make_right_pad_transform(NRaw, NPad)),
                                               make_tuple(Sequence<0>{}, Sequence<1>{}),
                                               make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else
        {
            // not pad M or N
            return e_grid_desc_mraw_nraw;
        }
    }

    // assume D is packed tensor
    static auto MakeRGridDescriptor_M(index_t MRaw)
    {
        const auto r_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));

        const auto M    = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
        const auto MPad = M - MRaw;

        if constexpr(GemmSpec == GemmSpecialization::MPadding ||
                     GemmSpec == GemmSpecialization::MNPadding ||
                     GemmSpec == GemmSpecialization::MKPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // pad M
            return transform_tensor_descriptor(r_grid_desc_mraw,
                                               make_tuple(make_right_pad_transform(MRaw, MPad)),
                                               make_tuple(Sequence<0>{}),
                                               make_tuple(Sequence<0>{}));
        }
        else
        {
            // not pad M
            return r_grid_desc_mraw;
        }
    }

    using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
    using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
    using EGridDesc_M_N       = decltype(MakeEGridDescriptor_M_N(1, 1, 1));
    using RGridDesc_M         = decltype(MakeRGridDescriptor_M(1));

    // GridwiseGemm
    using GridwiseGemm = GridwiseGemmMultipleDMultipleR_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
        ADataType, // TODO: distinguish A/B datatype
        GemmAccDataType,
        CShuffleDataType,
        DsDataType,
        EDataType,
        ReduceAccDataType,
        RsDataType,
        AElementwiseOperation,
        BElementwiseOperation,
        CDEElementwiseOperation,
        QsElementwiseOperation,
        RsElementwiseOperation,
        ThreadReduceOperations,
        InMemoryDataOperationEnum::Set,
        RsGlobalMemoryDataOperation,
        AGridDesc_AK0_M_AK1,
        BGridDesc_BK0_N_BK1,
        EGridDesc_M_N,
        RGridDesc_M,
        NumGemmKPrefetchStage,
        BlockSize,
        MPerBlock, NPerBlock, KPerBlock,
        AK1, BK1,
        MPerXDL, NPerXDL,
        MXdlPerWave, NXdlPerWave,
        ABlockTransferThreadClusterLengths_AK0_M_AK1,
        ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder,
        ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_AK1,
        false,
        ABlockLdsExtraM,
        BBlockTransferThreadClusterLengths_BK0_N_BK1,
        BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder,
        BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_BK1,
        false,
        BBlockLdsExtraN,
        CShuffleMXdlPerWavePerShuffle,
        CShuffleNXdlPerWavePerShuffle,
        CDRThreadTransferClusterLengths_MPerBlock_NPerBlock,
        CDEReduceThreadTransferScalarPerVector_NPerBlock,
        RThreadTransferDstScalarPerVector_MPerBlock,
        LoopSched>;

    // Argument
    struct Argument : public BaseArgument
    {
        Argument(const void* p_a_grid,
                 const void* p_b_grid,
                 std::array<const void*, NumDTensor> p_ds_grid,
                 void* p_e_grid,
                 std::array<void*, NumRTensor> p_rs_grid,
                 index_t MRaw, index_t NRaw, index_t KRaw,
                 index_t StrideA, index_t StrideB,
                 std::array<index_t, NumDTensor> StrideDs,
                 index_t StrideE,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 CDEElementwiseOperation cde_element_op,
                 QsElementwiseOperation qs_element_op,
                 RsElementwiseOperation rs_element_op)
            : p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
              p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
              p_ds_grid_{}, // FIXME
              p_e_grid_{static_cast<EDataType*>(p_e_grid)},
              p_rs_grid_{}, // FIXME
              a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
              b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
              ds_grid_desc_mblock_mperblock_nblock_nperblock_{},
              e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N(MRaw, NRaw, StrideE)},
              e_grid_desc_mblock_mperblock_nblock_nperblock_{},
              r_grid_desc_m_{DeviceOp::MakeRGridDescriptor_M(MRaw)},
              rs_grid_desc_mblock_mperblock_{},
              block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              cde_element_op_{cde_element_op},
              qs_element_op_{qs_element_op},
              rs_element_op_{rs_element_op}
        {
            if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
                                           b_grid_desc_bk0_n_bk1_,
                                           e_grid_desc_m_n_,
                                           r_grid_desc_m_,
                                           block_2_etile_map_))
            {
                e_grid_desc_mblock_mperblock_nblock_nperblock_ =
                    GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                        e_grid_desc_m_n_);

                static_for<0, NumDTensor, 1>{}([&](auto i) {
                    using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;

                    p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);

                    const auto d_grid_desc_m_n =
                        DeviceOp::MakeEGridDescriptor_M_N(MRaw, NRaw, StrideDs[i]);

                    ds_grid_desc_mblock_mperblock_nblock_nperblock_(i) =
                        GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                            d_grid_desc_m_n);
                });

                static_for<0, NumRTensor, 1>{}([&](auto i) {
                    using RDataType = remove_cvref_t<tuple_element_t<i.value, RsDataType>>;

                    p_rs_grid_(i) = static_cast<RDataType*>(p_rs_grid[i]);

                    rs_grid_desc_mblock_mperblock_(i) =
                        GridwiseGemm::MakeRGridDescriptor_MBlock_MPerBlock(r_grid_desc_m_);
                });
            }
        }

        //  private:
        // pointers
        const ADataType* p_a_grid_;
        const BDataType* p_b_grid_;
        typename GridwiseGemm::DsGridPointer p_ds_grid_;
        EDataType* p_e_grid_;
        typename GridwiseGemm::RsGridPointer p_rs_grid_;

        // tensor descriptors
        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
        BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
        StaticallyIndexedArray<
            typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
            NumDTensor>
            ds_grid_desc_mblock_mperblock_nblock_nperblock_; // FIXME: Ds desc may be of different
                                                             // type from E
        EGridDesc_M_N e_grid_desc_m_n_;
        typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
            e_grid_desc_mblock_mperblock_nblock_nperblock_;
        RGridDesc_M r_grid_desc_m_;
        StaticallyIndexedArray<typename GridwiseGemm::RGridDescriptor_MBlock_MPerBlock, NumRTensor>
            rs_grid_desc_mblock_mperblock_;

        // block-to-e-tile map
        typename GridwiseGemm::DefaultBlock2ETileMap block_2_etile_map_;

        // element-wise op
        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        CDEElementwiseOperation cde_element_op_;
        QsElementwiseOperation qs_element_op_;
        RsElementwiseOperation rs_element_op_;
    };

    // Invoker
    struct Invoker : public BaseInvoker
    {
        using Argument = DeviceOp::Argument;

        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
                                            arg.b_grid_desc_bk0_n_bk1_,
                                            arg.e_grid_desc_m_n_,
                                            arg.r_grid_desc_m_,
                                            arg.block_2_etile_map_))
            {
                throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
            }

            const index_t grid_size =
                arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);

            const auto K = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) *
                           arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);

            auto launch_kernel = [&](auto has_main_k_block_loop) {
                constexpr bool has_main_loop = has_main_k_block_loop.value;

                const auto kernel = kernel_gemm_multiple_d_multiple_r_xdl_cshuffle<
                    GridwiseGemm,
                    ADataType, // TODO: distiguish A/B datatype
                    typename GridwiseGemm::DsGridPointer,
                    EDataType,
                    typename GridwiseGemm::RsGridPointer,
                    AElementwiseOperation,
                    BElementwiseOperation,
                    CDEElementwiseOperation,
                    QsElementwiseOperation,
                    RsElementwiseOperation,
                    DeviceOp::AGridDesc_AK0_M_AK1,
                    DeviceOp::BGridDesc_BK0_N_BK1,
                    ck::StaticallyIndexedArray<
                        typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                        NumDTensor>,
                    typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
                    ck::StaticallyIndexedArray<
                        typename GridwiseGemm::RGridDescriptor_MBlock_MPerBlock,
                        NumRTensor>,
                    typename GridwiseGemm::DefaultBlock2ETileMap,
                    has_main_loop>;

                return launch_and_time_kernel(stream_config,
                                              kernel,
                                              dim3(grid_size),
                                              dim3(BlockSize),
                                              0,
                                              arg.p_a_grid_,
                                              arg.p_b_grid_,
                                              arg.p_ds_grid_,
                                              arg.p_e_grid_,
                                              arg.p_rs_grid_,
                                              arg.a_element_op_,
                                              arg.b_element_op_,
                                              arg.cde_element_op_,
                                              arg.qs_element_op_,
                                              arg.rs_element_op_,
                                              arg.a_grid_desc_ak0_m_ak1_,
                                              arg.b_grid_desc_bk0_n_bk1_,
                                              arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_,
                                              arg.e_grid_desc_mblock_mperblock_nblock_nperblock_,
                                              arg.rs_grid_desc_mblock_mperblock_,
                                              arg.block_2_etile_map_);
            };

            float ave_time = 0;

            if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
            {
                ave_time = launch_kernel(integral_constant<bool, true>{});
            }
            else
            {
                ave_time = launch_kernel(integral_constant<bool, false>{});
            }

            return ave_time;
        }

        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    static bool IsSupportedArgument(const Argument& arg)
    {
        if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
        {
            return false;
        }

        return GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
                                           arg.b_grid_desc_bk0_n_bk1_,
                                           arg.e_grid_desc_m_n_,
                                           arg.r_grid_desc_m_,
                                           arg.block_2_etile_map_);
    }

    // polymorphic
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }

    static auto MakeArgument(const void* p_a,
                             const void* p_b,
                             std::array<const void*, NumDTensor> p_ds,
                             void* p_e,
                             std::array<void*, NumRTensor> p_rs,
                             index_t MRaw, index_t NRaw, index_t KRaw,
                             index_t StrideA, index_t StrideB,
                             std::array<index_t, NumDTensor> StrideDs,
                             index_t StrideE,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             CDEElementwiseOperation cde_element_op,
                             QsElementwiseOperation qs_element_op,
                             RsElementwiseOperation rs_element_op)
    {
        return Argument{p_a, p_b, p_ds, p_e, p_rs,
                        MRaw, NRaw, KRaw,
                        StrideA, StrideB, StrideDs, StrideE,
                        a_element_op, b_element_op, cde_element_op,
                        qs_element_op, rs_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }

    // polymorphic
    std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_a,
                        const void* p_b,
                        std::array<const void*, NumDTensor> p_ds,
                        void* p_e,
                        std::array<void*, NumRTensor> p_rs,
                        index_t MRaw, index_t NRaw, index_t KRaw,
                        index_t StrideA, index_t StrideB,
                        std::array<index_t, NumDTensor> StrideDs,
                        index_t StrideE,
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
                        CDEElementwiseOperation cde_element_op,
                        QsElementwiseOperation qs_element_op,
                        RsElementwiseOperation rs_element_op) override
    {
        return std::make_unique<Argument>(p_a, p_b, p_ds, p_e, p_rs,
                                          MRaw, NRaw, KRaw,
                                          StrideA, StrideB, StrideDs, StrideE,
                                          a_element_op, b_element_op, cde_element_op,
                                          qs_element_op, rs_element_op);
    }

    // polymorphic
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    // polymorphic
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "DeviceGemmMultipleDMultipleR_Xdl_CShuffle"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
            << NPerBlock << ", "
            << KPerBlock << ", "
            << AK1 << ", "
            << BK1 << ", "
            << getGemmSpecializationString(GemmSpec)
            << ">";
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
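A worked example of the padding arithmetic in MakeAGridDescriptor_AK0_M_AK1 above, with illustrative sizes; this mirrors the math::integer_divide_ceil computation in plain C++ and is not library code:

#include <cassert>

int main()
{
    const int MRaw = 1000, KRaw = 100;         // raw problem sizes (illustrative)
    const int MPerBlock = 256, KPerBlock = 32; // tile sizes (illustrative)
    const int AK1 = 8;

    const int M = (MRaw + MPerBlock - 1) / MPerBlock * MPerBlock; // 1024
    const int K = (KRaw + KPerBlock - 1) / KPerBlock * KPerBlock; // 128

    const int MPad = M - MRaw; // 24 rows of right-padding
    const int KPad = K - KRaw; // 28 columns of right-padding

    assert(K % AK1 == 0);
    const int AK0 = K / AK1; // 16; the padded A is then viewed as [AK0, M, AK1]

    (void)MPad; (void)KPad; (void)AK0;
    return 0;
}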
include/ck/tensor_operation/gpu/device/device_gemm_xdl_skip_b_lds.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_skip_b_lds_v1.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
typename
ADataType
,
typename
BDataType
,
typename
CDataType
,
typename
AccDataType
,
typename
ALayout
,
typename
BLayout
,
typename
CLayout
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
GemmSpecialization
GemmSpec
,
ck
::
index_t
BlockSize
,
ck
::
index_t
MPerBlock
,
ck
::
index_t
NPerBlock
,
ck
::
index_t
K0PerBlock
,
ck
::
index_t
K1
,
ck
::
index_t
MPerXDL
,
ck
::
index_t
NPerXDL
,
ck
::
index_t
MXdlPerWave
,
ck
::
index_t
NXdlPerWave
,
typename
ABlockTransferThreadClusterLengths_K0_M_K1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
ck
::
index_t
ABlockTransferSrcVectorDim
,
ck
::
index_t
ABlockTransferSrcScalarPerVector
,
ck
::
index_t
ABlockTransferDstScalarPerVector_K1
,
bool
ABlockLdsAddExtraM
,
ck
::
index_t
BBlockTransferSrcScalarPerVector
,
ck
::
index_t
BBlockBufferSize
,
ck
::
index_t
CThreadTransferSrcDstVectorDim
,
ck
::
index_t
CThreadTransferDstScalarPerVector
>
struct
DeviceGemmXdlSkipBLds
:
public
DeviceGemm
<
ALayout
,
BLayout
,
CLayout
,
ADataType
,
BDataType
,
CDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
K1Number
=
Number
<
K1
>
{};
static_assert
(
BBlockBufferSize
>=
2
);
static
auto
MakeAGridDescriptor_K0_M_K1
(
index_t
M
,
index_t
K
,
index_t
StrideA
)
{
assert
(
K
%
K1
==
0
);
const
index_t
K0
=
K
/
K1
;
const
auto
a_grid_desc_m_k
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
K
),
make_tuple
(
StrideA
,
I1
));
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
ALayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
K
),
make_tuple
(
I1
,
StrideA
));
}
}();
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
const
auto
PadM
=
(
MPerBlock
-
M
%
MPerBlock
)
%
MPerBlock
;
return
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
)),
make_right_pad_transform
(
M
,
PadM
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
else
{
return
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
}
static
auto
MakeBGridDescriptor_K0_N_K1
(
index_t
K
,
index_t
N
,
index_t
StrideB
)
{
assert
(
K
%
K1
==
0
);
const
index_t
K0
=
K
/
K1
;
const
auto
b_grid_desc_k_n
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
K
,
N
),
make_tuple
(
StrideB
,
I1
));
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
K
,
N
),
make_tuple
(
I1
,
StrideB
));
}
}();
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
const
auto
PadN
=
(
NPerBlock
-
N
%
NPerBlock
)
%
NPerBlock
;
return
transform_tensor_descriptor
(
b_grid_desc_k_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
)),
make_right_pad_transform
(
N
,
PadN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
else
{
return
transform_tensor_descriptor
(
b_grid_desc_k_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
}
static
auto
MakeCGridDescriptor_M_N
(
index_t
M
,
index_t
N
,
index_t
StrideC
)
{
const
auto
c_grid_desc_m_n
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
CLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
StrideC
,
I1
));
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
CLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
I1
,
StrideC
));
}
}();
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
const
auto
PadM
=
(
MPerBlock
-
M
%
MPerBlock
)
%
MPerBlock
;
const
auto
PadN
=
(
NPerBlock
-
N
%
NPerBlock
)
%
NPerBlock
;
return
transform_tensor_descriptor
(
c_grid_desc_m_n
,
make_tuple
(
make_right_pad_transform
(
M
,
PadM
),
make_right_pad_transform
(
N
,
PadN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
{
return
transform_tensor_descriptor
(
c_grid_desc_m_n
,
make_tuple
(
make_pass_through_transform
(
M
),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
}
using
AGridDesc_K0_M_K1
=
decltype
(
MakeAGridDescriptor_K0_M_K1
(
1
,
1
,
1
));
using
BGridDesc_K0_N_K1
=
decltype
(
MakeBGridDescriptor_K0_N_K1
(
1
,
1
,
1
));
using
CGridDesc_M_N
=
decltype
(
MakeCGridDescriptor_M_N
(
1
,
1
,
1
));
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_skip_b_lds_v1
<
BlockSize
,
ADataType
,
// TODO: distinguish A/B datatype
AccDataType
,
CDataType
,
InMemoryDataOperationEnum
::
Set
,
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
CGridDesc_M_N
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
MPerBlock
,
NPerBlock
,
K0PerBlock
,
MPerXDL
,
NPerXDL
,
K1
,
MXdlPerWave
,
NXdlPerWave
,
ABlockTransferThreadClusterLengths_K0_M_K1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_K1
,
false
,
// AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsAddExtraM
,
BBlockTransferSrcScalarPerVector
,
false
,
// BThreadTransferSrcResetCoordinateAfterRun,
BBlockBufferSize
,
Sequence
<
0
,
2
,
4
,
5
,
6
,
1
,
3
,
7
>
,
// CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim
,
CThreadTransferDstScalarPerVector
>
;
// Argument
struct
Argument
:
public
BaseArgument
{
Argument
(
const
ADataType
*
p_a_grid
,
const
BDataType
*
p_b_grid
,
CDataType
*
p_c_grid
,
index_t
M
,
index_t
N
,
index_t
K
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideC
,
index_t
M01
,
index_t
N01
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
:
p_a_grid_
{
p_a_grid
},
p_b_grid_
{
p_b_grid
},
p_c_grid_
{
p_c_grid
},
a_grid_desc_k0_m_k1_
{},
b_grid_desc_k0_n_k1_
{},
c_grid_desc_m_n_
{},
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
{},
block_2_ctile_map_
{},
M01_
{
M01
},
N01_
{
N01
},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
c_element_op_
{
c_element_op
}
{
a_grid_desc_k0_m_k1_
=
DeviceGemmXdlSkipBLds
::
MakeAGridDescriptor_K0_M_K1
(
M
,
K
,
StrideA
);
b_grid_desc_k0_n_k1_
=
DeviceGemmXdlSkipBLds
::
MakeBGridDescriptor_K0_N_K1
(
K
,
N
,
StrideB
);
c_grid_desc_m_n_
=
DeviceGemmXdlSkipBLds
::
MakeCGridDescriptor_M_N
(
M
,
N
,
StrideC
);
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_k0_m_k1_
,
b_grid_desc_k0_n_k1_
,
c_grid_desc_m_n_
,
M01_
,
N01_
))
{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
=
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
c_grid_desc_m_n_
);
block_2_ctile_map_
=
GridwiseGemm
::
MakeDefaultBlock2CTileMap
(
c_grid_desc_m_n_
,
M01
,
N01
);
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3_
=
GridwiseGemm
::
MakeBGridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3
(
b_grid_desc_k0_n_k1_
);
}
}
// private:
const
ADataType
*
p_a_grid_
;
const
BDataType
*
p_b_grid_
;
CDataType
*
p_c_grid_
;
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1_
;
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
typename
GridwiseGemm
::
BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3_
;
typename
GridwiseGemm
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
;
typename
GridwiseGemm
::
DefaultBlock2CTileMap
block_2_ctile_map_
;
index_t
M01_
;
index_t
N01_
;
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
CElementwiseOperation
c_element_op_
;
};
// Invoker
struct
Invoker
:
public
BaseInvoker
{
using
Argument
=
DeviceGemmXdlSkipBLds
::
Argument
;
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
{
std
::
cout
<<
"arg.a_grid_desc_k0_m_k1_{"
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I2
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.b_grid_desc_k0_n_k1_{"
<<
arg
.
b_grid_desc_k0_n_k1_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
b_grid_desc_k0_n_k1_
.
GetLength
(
I1
)
<<
", "
<<
arg
.
b_grid_desc_k0_n_k1_
.
GetLength
(
I2
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.c_grid_desc_m_n_{ "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
}
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_m_n_
,
arg
.
M01_
,
arg
.
N01_
))
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"
);
}
const
index_t
grid_size
=
GridwiseGemm
::
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
);
const
auto
K0
=
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
);
const
bool
has_main_k0_block_loop
=
GridwiseGemm
::
CalculateHasMainK0BlockLoop
(
K0
);
float
ave_time
=
0
;
if
(
has_main_k0_block_loop
)
{
const
auto
kernel
=
kernel_gemm_xdlops_skip_b_lds_v1
<
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
remove_reference_t
<
DeviceGemmXdlSkipBLds
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceGemmXdlSkipBLds
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
typename
GridwiseGemm
::
BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3
>
,
remove_reference_t
<
typename
GridwiseGemm
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
remove_reference_t
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
>
,
true
>
;
ave_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
block_2_ctile_map_
);
}
else
{
const
auto
kernel
=
kernel_gemm_xdlops_skip_b_lds_v1
<
GridwiseGemm
,
ADataType
,
// TODO: distiguish A/B datatype
CDataType
,
remove_reference_t
<
DeviceGemmXdlSkipBLds
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceGemmXdlSkipBLds
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
typename
GridwiseGemm
::
BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3
>
,
remove_reference_t
<
typename
GridwiseGemm
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
remove_reference_t
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
>
,
false
>
;
ave_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
block_2_ctile_map_
);
}
return
ave_time
;
}
// polymorphic
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
}
};
static
constexpr
bool
IsValidCompilationParameter
()
{
// TODO: properly implement this check
return
true
;
}
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_m_n_
,
arg
.
M01_
,
arg
.
N01_
);
}
// polymorphic
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
static
auto
MakeArgument
(
const
ADataType
*
p_a
,
const
BDataType
*
p_b
,
CDataType
*
p_c
,
index_t
M
,
index_t
N
,
index_t
K
,
index_t
StrideA
,
index_t
StrideB
,
index_t
StrideC
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
{
return
Argument
{
p_a
,
p_b
,
p_c
,
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
,
1
,
1
,
a_element_op
,
b_element_op
,
c_element_op
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
    // polymorphic
    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                      const void* p_b,
                                                      void* p_c,
                                                      index_t M,
                                                      index_t N,
                                                      index_t K,
                                                      index_t StrideA,
                                                      index_t StrideB,
                                                      index_t StrideC,
                                                      AElementwiseOperation a_element_op,
                                                      BElementwiseOperation b_element_op,
                                                      CElementwiseOperation c_element_op) override
    {
        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                          static_cast<const BDataType*>(p_b),
                                          static_cast<CDataType*>(p_c),
                                          M,
                                          N,
                                          K,
                                          StrideA,
                                          StrideB,
                                          StrideC,
                                          1,
                                          1,
                                          a_element_op,
                                          b_element_op,
                                          c_element_op);
    }

    // polymorphic
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }
    // polymorphic
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "DeviceGemmXdlSkipBLds"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
            << NPerBlock << ", "
            << K0PerBlock << ", "
            << K1 << ", "
            << MPerXDL << ", "
            << NPerXDL << ", "
            << MXdlPerWave << ", "
            << NXdlPerWave
            << ">";
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
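For orientation: judging by the descriptor names above (B travels through a per-thread K0_K1_K2_N0_N1_N2_N3_K3 descriptor instead of a block-level LDS descriptor), this variant appears to stream B straight into registers rather than staging it in LDS. A minimal host-side sketch of driving this device op, using only the MakeArgument/MakeInvoker API shown above (the instance name and pointers are illustrative assumptions, not from this file):

// Hypothetical usage sketch; DeviceGemmXdlSkipBLdsInstance stands in for a
// concrete template instantiation of DeviceGemmXdlSkipBLds.
auto gemm    = DeviceGemmXdlSkipBLdsInstance{};
auto arg     = gemm.MakeArgument(p_a, p_b, p_c,
                                 M, N, K,
                                 StrideA, StrideB, StrideC,
                                 AElementwiseOperation{},
                                 BElementwiseOperation{},
                                 CElementwiseOperation{});
auto invoker = gemm.MakeInvoker();

// Run only if the problem shape satisfies GridwiseGemm::CheckValidity.
if(gemm.IsSupportedArgument(arg))
{
    float ave_time = invoker.Run(arg, StreamConfig{});
}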
include/ck/tensor_operation/gpu/device/device_grouped_contraction_multiple_d.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {

template <index_t NumDTensor>
struct ContractionDesc
{
    std::vector<index_t> a_ms_ks_lengths;
    std::vector<index_t> a_ms_ks_strides;

    std::vector<index_t> b_ns_ks_lengths;
    std::vector<index_t> b_ns_ks_strides;

    std::array<std::vector<index_t>, NumDTensor> ds_ms_ns_lengths;
    std::array<std::vector<index_t>, NumDTensor> ds_ms_ns_strides;

    std::vector<index_t> e_ms_ns_lengths;
    std::vector<index_t> e_ms_ns_strides;
};
// Tensor Contraction:
// input : A
// input : B
// input : D0, D1, ...
// output : E
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// A[M0, M1, M2, ..., K0, K1, K2, ...]
// B[N0, N1, N2, ..., K0, K1, K2, ...]
// D[M0, M1, M2, ..., N0, N1, N2, ...]
// E[M0, M1, M2, ..., N0, N1, N2, ...]
template <index_t NumDimM,
          index_t NumDimN,
          index_t NumDimK,
          typename ADataType,
          typename BDataType,
          typename DsDataType,
          typename EDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation>
struct DeviceGroupedContractionMultipleD : public BaseOperator
{
    static constexpr index_t NumDTensor = DsDataType::Size();

    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(std::vector<const void*> p_a_vec,
                        std::vector<const void*> p_b_vec,
                        std::vector<std::array<const void*, NumDTensor>> p_ds_vec,
                        std::vector<void*> p_e_vec,
                        std::vector<ContractionDesc<NumDTensor>> contraction_descs,
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
                        CDEElementwiseOperation cde_element_op) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
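As a concrete (hypothetical) illustration of how ContractionDesc maps onto the A/B/D/E layout comment above: a single group with NumDimM = 2, NumDimN = 2, NumDimK = 1 and one packed, row-major D tensor could be described as follows. All shapes and strides here are made up for the example.

// Illustrative only: E[M0=8, M1=16, N0=4, N1=64] from A[8, 16, K0=32] and
// B[4, 64, 32], with NumDTensor = 1 and dense (packed) strides throughout.
ck::tensor_operation::device::ContractionDesc<1> desc;
desc.a_ms_ks_lengths = {8, 16, 32};
desc.a_ms_ks_strides = {16 * 32, 32, 1};       // row-major packed A
desc.b_ns_ks_lengths = {4, 64, 32};
desc.b_ns_ks_strides = {64 * 32, 32, 1};       // row-major packed B
desc.ds_ms_ns_lengths[0] = {8, 16, 4, 64};     // D0 has E's shape
desc.ds_ms_ns_strides[0] = {16 * 4 * 64, 4 * 64, 64, 1};
desc.e_ms_ns_lengths = {8, 16, 4, 64};
desc.e_ms_ns_strides = {16 * 4 * 64, 4 * 64, 64, 1};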
include/ck/tensor_operation/gpu/device/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_contraction_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {

template <typename GridwiseGemm,
          typename ContractionMultiDKernelArg,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_grouped_contraction_multiple_d_xdl_cshuffle(
            const void CK_CONSTANT_ADDRESS_SPACE* contraction_args,
            const index_t group_count,
            const AElementwiseOperation a_element_op,
            const BElementwiseOperation b_element_op,
            const CDEElementwiseOperation cde_element_op)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    const index_t block_id = get_block_1d_id();

    const auto contraction_arg_ptr = reinterpret_cast<const ContractionMultiDKernelArg*>(
        cast_pointer_to_generic_address_space(contraction_args));

    // binary-search the sorted, contiguous [block_start_, block_end_) ranges
    // for the group this workgroup belongs to
    index_t left     = 0;
    index_t right    = group_count;
    index_t group_id = index_t((left + right) / 2);

    while((!(block_id >= contraction_arg_ptr[group_id].block_start_ &&
             block_id < contraction_arg_ptr[group_id].block_end_)) &&
          left <= right)
    {
        if(block_id < contraction_arg_ptr[group_id].block_start_)
        {
            right = group_id;
        }
        else
        {
            left = group_id;
        }
        group_id = index_t((left + right) / 2);
    }

    GridwiseGemm::template Run<HasMainKBlockLoop>(
        contraction_arg_ptr[group_id].p_a_grid_,
        contraction_arg_ptr[group_id].p_b_grid_,
        contraction_arg_ptr[group_id].p_ds_grid_,
        contraction_arg_ptr[group_id].p_e_grid_,
        p_shared,
        a_element_op,
        b_element_op,
        cde_element_op,
        contraction_arg_ptr[group_id].a_grid_desc_ak0_m_ak1_,
        contraction_arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
        contraction_arg_ptr[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_,
        contraction_arg_ptr[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
        contraction_arg_ptr[group_id].block_2_etile_map_);
#else
    ignore = contraction_args;
    ignore = group_count;
    ignore = a_element_op;
    ignore = b_element_op;
    ignore = cde_element_op;
#endif
}
} // namespace ck
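The group lookup above is a plain binary search over the per-group [block_start_, block_end_) windows, which the Argument constructor lays out contiguously and in ascending order. A standalone host-side sketch of the same logic, useful for convincing yourself it terminates (the Range struct is a hypothetical stand-in for ContractionMultiDKernelArg):

#include <vector>

struct Range { int block_start_, block_end_; }; // hypothetical stand-in

// Same search as in the kernel: find the group whose [start, end) contains id.
// Assumes the ranges are sorted, contiguous, and cover id.
int find_group(const std::vector<Range>& groups, int id)
{
    int left = 0, right = static_cast<int>(groups.size());
    int g    = (left + right) / 2;
    while(!(id >= groups[g].block_start_ && id < groups[g].block_end_) && left <= right)
    {
        if(id < groups[g].block_start_) { right = g; }
        else                            { left = g; }
        g = (left + right) / 2;
    }
    return g;
}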
namespace ck {
namespace tensor_operation {
namespace device {
// Tensor Contraction:
// input : A
// input : B
// input : D0, D1, ...
// output : E
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// A[M0, M1, M2, ..., K0, K1, K2, ...]
// B[N0, N1, N2, ..., K0, K1, K2, ...]
// D[M0, M1, M2, ..., N0, N1, N2, ...]
// E[M0, M1, M2, ..., N0, N1, N2, ...]
template <index_t NumDimM,
          index_t NumDimN,
          index_t NumDimK,
          typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CShuffleDataType,
          typename DsDataType,
          typename EDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          GemmSpecialization GemmSpec,
          TensorSpecialization ASpec,
          TensorSpecialization BSpec,
          TensorSpecialization DESpec,
          index_t NumGemmKPrefetchStage,
          index_t BlockSize,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t AK1,
          index_t BK1,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MXdlPerWave,
          index_t NXdlPerWave,
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          bool ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          bool BBlockLdsExtraN,
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CDEBlockTransferScalarPerVector_NPerBlock,
          LoopScheduler LoopSched = make_default_loop_scheduler()>
struct DeviceGroupedContractionMultipleD_Xdl_CShuffle
    : public DeviceGroupedContractionMultipleD<NumDimM,
                                               NumDimN,
                                               NumDimK,
                                               ADataType,
                                               BDataType,
                                               DsDataType,
                                               EDataType,
                                               AElementwiseOperation,
                                               BElementwiseOperation,
                                               CDEElementwiseOperation>
{
    using DeviceOp = DeviceGroupedContractionMultipleD_Xdl_CShuffle;

    static constexpr index_t NumDTensor = DsDataType::Size();

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    static constexpr auto matrix_padder =
        MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
    // Assume: A[M0, M1, M2, ..., K0, K1, K2, ...]
    static auto MakeAGridDescriptor_M_K(const std::vector<index_t>& a_ms_ks_lengths_vec,
                                        const std::vector<index_t>& a_ms_ks_strides_vec)
    {
        assert(a_ms_ks_lengths_vec.size() == NumDimM + NumDimK &&
               a_ms_ks_strides_vec.size() == NumDimM + NumDimK);

        const auto to_tuple = [&](auto& vec, auto num) {
            return generate_tuple([&](auto i) { return vec[i]; }, num);
        };

        const auto a_ms_ks_lengths = to_tuple(a_ms_ks_lengths_vec, Number<NumDimM + NumDimK>{});
        const auto a_ms_ks_strides = to_tuple(a_ms_ks_strides_vec, Number<NumDimM + NumDimK>{});

        // dimension Ids for M0, M1, ...
        constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};

        // dimension Ids for K0, K1, ...
        constexpr auto kDimIds =
            typename arithmetic_sequence_gen<NumDimM, NumDimM + NumDimK, 1>::type{};

        // lengths for M0, M1, ...
        const auto mLengths = get_container_subset(a_ms_ks_lengths, mDimIds);

        // lengths for K0, K1, ...
        const auto kLengths = get_container_subset(a_ms_ks_lengths, kDimIds);

        if constexpr(ASpec == TensorSpecialization::Packed)
        {
            // packed (dense) tensor: the M and K dimension groups can be collapsed
            // directly into a 2-d descriptor, using only the innermost M/K strides
            auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{});
            auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{});
            const auto a_grid_desc_mraw_kraw = make_naive_tensor_descriptor(
                make_tuple(M, K),
                make_tuple(a_ms_ks_strides[Number<NumDimM - 1>{}],
                           a_ms_ks_strides[Number<NumDimM + NumDimK - 1>{}]));
            return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
        }
        else
        {
            // naive tensor A[M0, M1, M2, ..., K0, K1, K2...]
            const auto a_grid_desc_ms_ks =
                make_naive_tensor_descriptor(a_ms_ks_lengths, a_ms_ks_strides);

            // transformed tensor A[MRaw = M0 * M1 * M2 * ... , KRaw = K0 * K1 * K2 * ...]
            const auto a_grid_desc_mraw_kraw = transform_tensor_descriptor(
                a_grid_desc_ms_ks,
                make_tuple(make_merge_transform(mLengths), make_merge_transform(kLengths)),
                make_tuple(mDimIds, kDimIds),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
        }
    }
    // Assume: B[N0, N1, N2, ..., K0, K1, K2, ...]
    static auto MakeBGridDescriptor_N_K(const std::vector<index_t>& b_ns_ks_lengths_vec,
                                        const std::vector<index_t>& b_ns_ks_strides_vec)
    {
        assert(b_ns_ks_lengths_vec.size() == NumDimN + NumDimK &&
               b_ns_ks_strides_vec.size() == NumDimN + NumDimK);

        const auto to_tuple = [&](auto& vec, auto num) {
            return generate_tuple([&](auto i) { return vec[i]; }, num);
        };

        const auto b_ns_ks_lengths = to_tuple(b_ns_ks_lengths_vec, Number<NumDimN + NumDimK>{});
        const auto b_ns_ks_strides = to_tuple(b_ns_ks_strides_vec, Number<NumDimN + NumDimK>{});

        // dimension Ids for N0, N1, ...
        constexpr auto nDimIds = typename arithmetic_sequence_gen<0, NumDimN, 1>::type{};

        // dimension Ids for K0, K1, ...
        constexpr auto kDimIds =
            typename arithmetic_sequence_gen<NumDimN, NumDimN + NumDimK, 1>::type{};

        // lengths for K0, K1, ...
        const auto kLengths = get_container_subset(b_ns_ks_lengths, kDimIds);

        // lengths for N0, N1, ...
        const auto nLengths = get_container_subset(b_ns_ks_lengths, nDimIds);

        if constexpr(BSpec == TensorSpecialization::Packed)
        {
            auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{});
            auto K = container_reduce(kLengths, math::multiplies{}, Number<1>{});
            const auto b_grid_desc_nraw_kraw = make_naive_tensor_descriptor(
                make_tuple(N, K),
                make_tuple(b_ns_ks_strides[Number<NumDimN - 1>{}],
                           b_ns_ks_strides[Number<NumDimN + NumDimK - 1>{}]));
            return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
        }
        else
        {
            // naive tensor B[N0, N1, N2, ..., K0, K1, K2, ...]
            const auto b_grid_desc_ns_ks =
                make_naive_tensor_descriptor(b_ns_ks_lengths, b_ns_ks_strides);

            // transformed tensor B[NRaw = N0 * N1 * N2 * ..., KRaw = K0 * K1 * K2 * ...]
            const auto b_grid_desc_nraw_kraw = transform_tensor_descriptor(
                b_grid_desc_ns_ks,
                make_tuple(make_merge_transform(nLengths), make_merge_transform(kLengths)),
                make_tuple(nDimIds, kDimIds),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
        }
    }
    // Assume: E[M0, M1, M2, ..., N0, N1, N2...]
    static auto MakeEGridDescriptor_M_N(const std::vector<index_t>& e_ms_ns_lengths_vec,
                                        const std::vector<index_t>& e_ms_ns_strides_vec)
    {
        assert(e_ms_ns_lengths_vec.size() == NumDimM + NumDimN &&
               e_ms_ns_strides_vec.size() == NumDimM + NumDimN);

        const auto to_tuple = [&](auto& vec, auto num) {
            return generate_tuple([&](auto i) { return vec[i]; }, num);
        };

        const auto e_ms_ns_lengths = to_tuple(e_ms_ns_lengths_vec, Number<NumDimM + NumDimN>{});
        const auto e_ms_ns_strides = to_tuple(e_ms_ns_strides_vec, Number<NumDimM + NumDimN>{});

        // dimension Ids for M0, M1, ...
        constexpr auto mDimIds = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};

        // dimension Ids for N0, N1, ...
        constexpr auto nDimIds =
            typename arithmetic_sequence_gen<NumDimM, NumDimM + NumDimN, 1>::type{};

        // lengths for M0, M1, ...
        const auto mLengths = get_container_subset(e_ms_ns_lengths, mDimIds);

        // lengths for N0, N1, ...
        const auto nLengths = get_container_subset(e_ms_ns_lengths, nDimIds);

        if constexpr(DESpec == TensorSpecialization::Packed)
        {
            auto M = container_reduce(mLengths, math::multiplies{}, Number<1>{});
            auto N = container_reduce(nLengths, math::multiplies{}, Number<1>{});
            const auto e_grid_desc_mraw_nraw = make_naive_tensor_descriptor(
                make_tuple(M, N),
                make_tuple(e_ms_ns_strides[Number<NumDimM - 1>{}],
                           e_ms_ns_strides[Number<NumDimM + NumDimN - 1>{}]));
            return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
        }
        else
        {
            // naive tensor E[M0, M1, M2, ..., N0, N1, N2...]
            const auto e_grid_desc_ms_ns =
                make_naive_tensor_descriptor(e_ms_ns_lengths, e_ms_ns_strides);

            // transformed tensor E[MRaw = M0 * M1 * M2 * ... , NRaw = N0 * N1 * N2 * ...]
            const auto e_grid_desc_mraw_nraw = transform_tensor_descriptor(
                e_grid_desc_ms_ns,
                make_tuple(make_merge_transform(mLengths), make_merge_transform(nLengths)),
                make_tuple(mDimIds, nDimIds),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
        }
    }
    static auto MakeDsGridDescriptor_M_N(
        const std::array<std::vector<index_t>, NumDTensor>& ds_ms_ns_lengths_vec,
        const std::array<std::vector<index_t>, NumDTensor>& ds_ms_ns_strides_vec)
    {
        return generate_tuple(
            [&](auto i) {
                return DeviceOp::MakeEGridDescriptor_M_N(ds_ms_ns_lengths_vec[i],
                                                         ds_ms_ns_strides_vec[i]);
            },
            Number<NumDTensor>{});
    }
    using AGridDesc_M_K  = decltype(MakeAGridDescriptor_M_K({}, {}));
    using BGridDesc_N_K  = decltype(MakeBGridDescriptor_N_K({}, {}));
    using DsGridDesc_M_N = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({{}}, {{}}))>;
    using EGridDesc_M_N  = decltype(MakeEGridDescriptor_M_N({}, {}));
    // GridwiseGemm
    using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle<
        ADataType, // TODO: distinguish A/B datatype
        AccDataType,
        CShuffleDataType,
        DsDataType,
        EDataType,
        AElementwiseOperation,
        BElementwiseOperation,
        CDEElementwiseOperation,
        InMemoryDataOperationEnum::Set,
        AGridDesc_M_K,
        BGridDesc_N_K,
        DsGridDesc_M_N,
        EGridDesc_M_N,
        NumGemmKPrefetchStage,
        BlockSize,
        MPerBlock,
        NPerBlock,
        KPerBlock,
        AK1,
        BK1,
        MPerXDL,
        NPerXDL,
        MXdlPerWave,
        NXdlPerWave,
        ABlockTransferThreadClusterLengths_AK0_M_AK1,
        ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder,
        ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_AK1,
        false,
        ABlockLdsExtraM,
        BBlockTransferThreadClusterLengths_BK0_N_BK1,
        BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder,
        BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_BK1,
        false,
        BBlockLdsExtraN,
        CShuffleMXdlPerWavePerShuffle,
        CShuffleNXdlPerWavePerShuffle,
        CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        CDEBlockTransferScalarPerVector_NPerBlock,
        LoopSched>;

    using AGridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
        GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))>;
    using BGridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
        GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))>;
    struct GroupedContractionBlock2ETileMap
    {
        static_assert(
            std::is_same<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{})),
                         typename GridwiseGemm::DefaultBlock2ETileMap>::value,
            "Wrong! Should be the same type name");

        GroupedContractionBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n,
                                         ck::index_t BlockStart)
        {
            default_block_2_etile_map_ = GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n);
            block_start_               = BlockStart;
        }

        template <typename TopIdx>
        __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
        {
            return default_block_2_etile_map_.CalculateBottomIndex(
                make_multi_index(idx_top[I0] - block_start_));
        }

        // it's actually E-Tile
        template <typename CTileIdx, typename CTileDim>
        __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
                                                 const CTileDim& c_tile_dim) const
        {
            return default_block_2_etile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
        }

        __host__ bool CheckValidity(const EGridDesc_M_N& e_grid_desc_m_n) const
        {
            return default_block_2_etile_map_.CheckValidity(e_grid_desc_m_n);
        }

        typename GridwiseGemm::DefaultBlock2ETileMap default_block_2_etile_map_;
        ck::index_t block_start_;
    };
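This wrapper is what lets many independent contractions share one grid launch: each group's tile map wraps GridwiseGemm's default map and rebases the global block index into the group's own window. A toy sketch of the rebasing with hypothetical numbers:

// Suppose group 2 owns grid blocks [block_start_ = 48, block_end_ = 80).
// The workgroup with get_block_1d_id() == 53 then works on local E-tile
// 53 - 48 == 5, which is exactly what CalculateBottomIndex computes above.
const int block_start = 48;
const int block_id    = 53;
const int local_tile  = block_id - block_start; // == 5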
    struct ContractionMultiDKernelArg
    {
        // pointers
        const ADataType* p_a_grid_;
        const BDataType* p_b_grid_;
        typename GridwiseGemm::DsGridPointer p_ds_grid_;
        EDataType* p_e_grid_;

        // tensor descriptors for block/thread-wise copy
        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
        BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
        typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
            ds_grid_desc_mblock_mperblock_nblock_nperblock_;
        typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
            e_grid_desc_mblock_mperblock_nblock_nperblock_;

        // block-to-e-tile map
        GroupedContractionBlock2ETileMap block_2_etile_map_;

        ck::index_t block_start_, block_end_;
    };

    struct ContractionMultiDDeviceArg
    {
        // tensor descriptors for problem definition
        AGridDesc_M_K a_grid_desc_m_k_;
        BGridDesc_N_K b_grid_desc_n_k_;
        DsGridDesc_M_N ds_grid_desc_m_n_;
        EGridDesc_M_N e_grid_desc_m_n_;

        // strides for the last M/N/K dimensions of A/B/Ds/E,
        // for sanity check of vector load/store
        index_t a_mz_stride_;
        index_t a_kz_stride_;
        index_t b_nz_stride_;
        index_t b_kz_stride_;
        std::array<index_t, NumDTensor> ds_nz_stride_;
        // index_t e_mz_stride_;
        index_t e_nz_stride_;
    };
    // Argument
    struct Argument : public BaseArgument
    {
        Argument(std::vector<const void*> p_a_vec,
                 std::vector<const void*> p_b_vec,
                 std::vector<std::array<const void*, NumDTensor>> p_ds_vec,
                 std::vector<void*> p_e_vec,
                 std::vector<ContractionDesc<NumDTensor>> contraction_descs,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 CDEElementwiseOperation cde_element_op)
            : a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              cde_element_op_{cde_element_op}
        {
            group_count_ = contraction_descs.size();

            if(!(group_count_ == p_a_vec.size() && group_count_ == p_b_vec.size() &&
                 group_count_ == p_e_vec.size()))
            {
                throw std::runtime_error("wrong! group_count_ != a/b/e_vec.size");
            }

            contraction_multi_d_kernel_args_.reserve(group_count_);

            grid_size_ = 0;

            for(std::size_t i = 0; i < group_count_; i++)
            {
                const auto p_a_grid = static_cast<const ADataType*>(p_a_vec[i]);
                const auto p_b_grid = static_cast<const BDataType*>(p_b_vec[i]);
                const auto p_e_grid = static_cast<EDataType*>(p_e_vec[i]);

                const auto a_grid_desc_m_k = DeviceOp::MakeAGridDescriptor_M_K(
                    contraction_descs[i].a_ms_ks_lengths, contraction_descs[i].a_ms_ks_strides);
                const auto b_grid_desc_n_k = DeviceOp::MakeBGridDescriptor_N_K(
                    contraction_descs[i].b_ns_ks_lengths, contraction_descs[i].b_ns_ks_strides);

                DsGridDesc_M_N ds_grid_desc_m_n;
                typename GridwiseGemm::DsGridPointer p_ds_grid;

                // populate pointer, batch stride, desc for Ds
                static_for<0, NumDTensor, 1>{}([&](auto j) {
                    using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;

                    // D pointer
                    p_ds_grid(j) = static_cast<const DDataType*>(p_ds_vec[i][j]);

                    // D desc
                    ds_grid_desc_m_n(j) = DeviceOp::MakeEGridDescriptor_M_N(
                        contraction_descs[i].ds_ms_ns_lengths[j],
                        contraction_descs[i].ds_ms_ns_strides[j]);
                });

                const auto e_grid_desc_m_n = DeviceOp::MakeEGridDescriptor_M_N(
                    contraction_descs[i].e_ms_ns_lengths, contraction_descs[i].e_ms_ns_strides);

                const auto a_grid_desc_ak0_m_ak1 =
                    GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k);
                const auto b_grid_desc_bk0_n_bk1 =
                    GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);

                const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
                    GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                        ds_grid_desc_m_n);
                const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
                    GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                        e_grid_desc_m_n);

                const index_t grid_size_grp =
                    GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n)
                        .CalculateGridSize(e_grid_desc_m_n);

                const index_t BlockStart = grid_size_;
                const index_t BlockEnd   = grid_size_ + grid_size_grp;

                grid_size_ += grid_size_grp;

                const auto block_2_etile_map =
                    GroupedContractionBlock2ETileMap(e_grid_desc_m_n, BlockStart);

                // for sanity check of vector memory access
                const index_t a_mz_stride = contraction_descs[i].a_ms_ks_strides[NumDimM - 1];
                const index_t a_kz_stride =
                    contraction_descs[i].a_ms_ks_strides[NumDimM + NumDimK - 1];
                const index_t b_nz_stride = contraction_descs[i].b_ns_ks_strides[NumDimN - 1];
                const index_t b_kz_stride =
                    contraction_descs[i].b_ns_ks_strides[NumDimN + NumDimK - 1];

                std::array<index_t, NumDTensor> ds_nz_stride;
                for(index_t j = 0; j < NumDTensor; ++j)
                {
                    ds_nz_stride[j] =
                        contraction_descs[i].ds_ms_ns_strides[j][NumDimM + NumDimN - 1];
                }

                const index_t e_nz_stride =
                    contraction_descs[i].e_ms_ns_strides[NumDimM + NumDimN - 1];

                if(GridwiseGemm::CheckValidity(a_grid_desc_m_k,
                                               b_grid_desc_n_k,
                                               ds_grid_desc_m_n,
                                               e_grid_desc_m_n,
                                               block_2_etile_map))
                {
                    contraction_multi_d_kernel_args_.push_back(
                        {p_a_grid,
                         p_b_grid,
                         p_ds_grid,
                         p_e_grid,
                         a_grid_desc_ak0_m_ak1,
                         b_grid_desc_bk0_n_bk1,
                         ds_grid_desc_mblock_mperblock_nblock_nperblock,
                         e_grid_desc_mblock_mperblock_nblock_nperblock,
                         block_2_etile_map,
                         BlockStart,
                         BlockEnd});

                    contraction_multi_d_device_args_.push_back({a_grid_desc_m_k,
                                                                b_grid_desc_n_k,
                                                                ds_grid_desc_m_n,
                                                                e_grid_desc_m_n,
                                                                a_mz_stride,
                                                                a_kz_stride,
                                                                b_nz_stride,
                                                                b_kz_stride,
                                                                ds_nz_stride,
                                                                e_nz_stride});
                }
            }
        }

        std::vector<ContractionMultiDKernelArg> contraction_multi_d_kernel_args_;
        std::vector<ContractionMultiDDeviceArg> contraction_multi_d_device_args_;

        std::size_t group_count_;

        index_t grid_size_;

        // element-wise op
        AElementwiseOperation a_element_op_;
        BElementwiseOperation b_element_op_;
        CDEElementwiseOperation cde_element_op_;
    };
    // Invoker
    struct Invoker : public BaseInvoker
    {
        using Argument = DeviceOp::Argument;

        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            bool has_main_k_block_loop = true;

            for(std::size_t i = 0; i < arg.group_count_; i++)
            {
                const auto K =
                    arg.contraction_multi_d_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I0) *
                    arg.contraction_multi_d_kernel_args_[i].a_grid_desc_ak0_m_ak1_.GetLength(I2);

                if(GridwiseGemm::CalculateHasMainKBlockLoop(K) != has_main_k_block_loop)
                {
                    throw std::runtime_error("wrong! not all gemm has_main_k_block_loop");
                }
            }

            // copy the per-group kernel arguments into the caller-provided
            // device workspace (sized via GetWorkSpaceSize below)
            hipGetErrorString(hipMemcpy(arg.p_workspace_,
                                        arg.contraction_multi_d_kernel_args_.data(),
                                        arg.contraction_multi_d_kernel_args_.size() *
                                            sizeof(ContractionMultiDKernelArg),
                                        hipMemcpyHostToDevice));

            float ave_time = 0;

            auto launch_kernel = [&](auto has_main_k_block_loop_) {
                const auto kernel =
                    kernel_grouped_contraction_multiple_d_xdl_cshuffle<GridwiseGemm,
                                                                       ContractionMultiDKernelArg,
                                                                       AElementwiseOperation,
                                                                       BElementwiseOperation,
                                                                       CDEElementwiseOperation,
                                                                       has_main_k_block_loop_>;

                return launch_and_time_kernel(
                    stream_config,
                    kernel,
                    dim3(arg.grid_size_),
                    dim3(BlockSize),
                    0,
                    cast_pointer_to_constant_address_space(arg.p_workspace_),
                    arg.group_count_,
                    arg.a_element_op_,
                    arg.b_element_op_,
                    arg.cde_element_op_);
            };

            if(has_main_k_block_loop)
            {
                ave_time = launch_kernel(integral_constant<bool, true>{});
            }
            else
            {
                ave_time = launch_kernel(integral_constant<bool, false>{});
            }

            return ave_time;
        }
        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };
    static bool IsSupportedArgument(const Argument& arg)
    {
        if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
        {
            return false;
        }

        for(std::size_t i = 0; i < arg.group_count_; i++)
        {
            const auto a_grid_desc_m_k_  = arg.contraction_multi_d_device_args_[i].a_grid_desc_m_k_;
            const auto b_grid_desc_n_k_  = arg.contraction_multi_d_device_args_[i].b_grid_desc_n_k_;
            const auto ds_grid_desc_m_n_ =
                arg.contraction_multi_d_device_args_[i].ds_grid_desc_m_n_;
            const auto e_grid_desc_m_n_ = arg.contraction_multi_d_device_args_[i].e_grid_desc_m_n_;

            const auto a_grid_desc_ak0_m_ak1_ =
                arg.contraction_multi_d_kernel_args_[i].a_grid_desc_ak0_m_ak1_;
            const auto b_grid_desc_bk0_n_bk1_ =
                arg.contraction_multi_d_kernel_args_[i].b_grid_desc_bk0_n_bk1_;
            const auto ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
                arg.contraction_multi_d_kernel_args_[i]
                    .ds_grid_desc_mblock_mperblock_nblock_nperblock_;
            const auto e_grid_desc_mblock_mperblock_nblock_nperblock_ =
                arg.contraction_multi_d_kernel_args_[i]
                    .e_grid_desc_mblock_mperblock_nblock_nperblock_;

            const auto block_2_etile_map_ =
                arg.contraction_multi_d_kernel_args_[i].block_2_etile_map_;

            const auto a_mz_stride_  = arg.contraction_multi_d_device_args_[i].a_mz_stride_;
            const auto a_kz_stride_  = arg.contraction_multi_d_device_args_[i].a_kz_stride_;
            const auto b_nz_stride_  = arg.contraction_multi_d_device_args_[i].b_nz_stride_;
            const auto b_kz_stride_  = arg.contraction_multi_d_device_args_[i].b_kz_stride_;
            const auto ds_nz_stride_ = arg.contraction_multi_d_device_args_[i].ds_nz_stride_;
            const auto e_nz_stride_  = arg.contraction_multi_d_device_args_[i].e_nz_stride_;

            if(!GridwiseGemm::CheckValidity(a_grid_desc_m_k_,
                                            b_grid_desc_n_k_,
                                            ds_grid_desc_m_n_,
                                            e_grid_desc_m_n_,
                                            block_2_etile_map_))
            {
                return false;
            }

            // check vector access
            static_assert((ABlockTransferSrcVectorDim == 1 || ABlockTransferSrcVectorDim == 2) &&
                              (BBlockTransferSrcVectorDim == 1 || BBlockTransferSrcVectorDim == 2),
                          "wrong!");

            // vector memory access of A: could be on M or AK1 dimension
            if constexpr(ABlockTransferSrcVectorDim == 1)
            {
                if(!(a_mz_stride_ == 1 &&
                     a_grid_desc_ak0_m_ak1_.GetLength(I1) % ABlockTransferSrcScalarPerVector == 0))
                {
                    return false;
                }
            }
            else
            {
                if(!(a_kz_stride_ == 1 &&
                     a_grid_desc_ak0_m_ak1_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0))
                {
                    return false;
                }
            }

            // vector memory access of B: could be on N or BK1 dimension
            if constexpr(BBlockTransferSrcVectorDim == 1)
            {
                if(!(b_nz_stride_ == 1 &&
                     b_grid_desc_bk0_n_bk1_.GetLength(I1) % BBlockTransferSrcScalarPerVector == 0))
                {
                    return false;
                }
            }
            else
            {
                if(!(b_kz_stride_ == 1 &&
                     b_grid_desc_bk0_n_bk1_.GetLength(I2) % BBlockTransferSrcScalarPerVector == 0))
                {
                    return false;
                }
            }

            // vector memory access of Ds: always on NPerBlock dimension
            bool valid_d_access = true;

            static_for<0, NumDTensor, 1>{}([&](auto j) {
                if(!(ds_nz_stride_[j] == 1 &&
                     ds_grid_desc_mblock_mperblock_nblock_nperblock_[j].GetLength(I3) %
                             CDEBlockTransferScalarPerVector_NPerBlock ==
                         0))
                {
                    valid_d_access = false;
                }
            });

            if(valid_d_access == false)
            {
                return false;
            }

            // vector memory access of E: always on NPerBlock dimension
            if(!(e_nz_stride_ == 1 &&
                 e_grid_desc_mblock_mperblock_nblock_nperblock_.GetLength(I3) %
                         CDEBlockTransferScalarPerVector_NPerBlock ==
                     0))
            {
                return false;
            }
        }

        return true;
    }
    // polymorphic
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }

    static auto MakeArgument(std::vector<const void*> p_a_vec,
                             std::vector<const void*> p_b_vec,
                             std::vector<std::array<const void*, NumDTensor>> p_ds_vec,
                             std::vector<void*> p_e_vec,
                             std::vector<ContractionDesc<NumDTensor>> contraction_descs,
                             AElementwiseOperation a_element_op,
                             BElementwiseOperation b_element_op,
                             CDEElementwiseOperation cde_element_op)
    {
        return Argument{p_a_vec,
                        p_b_vec,
                        p_ds_vec,
                        p_e_vec,
                        contraction_descs,
                        a_element_op,
                        b_element_op,
                        cde_element_op};
    }

    static auto MakeInvoker() { return Invoker{}; }
    // polymorphic
    std::unique_ptr<BaseArgument>
    MakeArgumentPointer(std::vector<const void*> p_a_vec,
                        std::vector<const void*> p_b_vec,
                        std::vector<std::array<const void*, NumDTensor>> p_ds_vec,
                        std::vector<void*> p_e_vec,
                        std::vector<ContractionDesc<NumDTensor>> contraction_descs,
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
                        CDEElementwiseOperation cde_element_op) override
    {
        return std::make_unique<Argument>(p_a_vec,
                                          p_b_vec,
                                          p_ds_vec,
                                          p_e_vec,
                                          contraction_descs,
                                          a_element_op,
                                          b_element_op,
                                          cde_element_op);
    }

    // polymorphic
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }
    // polymorphic
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "DeviceGroupedContractionMultipleD_Xdl_CShuffle"
            << "<"
            << NumDimM << ", "
            << NumDimN << ", "
            << NumDimK << ", "
            << BlockSize << ", "
            << MPerBlock << ", "
            << NPerBlock << ", "
            << KPerBlock << ", "
            << AK1 << ", "
            << BK1 << ", "
            << ABlockTransferSrcVectorDim << ", "
            << BBlockTransferSrcVectorDim
            << ">";
        // clang-format on

        return str.str();
    }

    size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
    {
        return dynamic_cast<const Argument*>(p_arg)->group_count_ *
               sizeof(ContractionMultiDKernelArg);
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
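Because Invoker::Run uploads the per-group kernel arguments through arg.p_workspace_ (a BaseArgument member), the caller must allocate that buffer before running. A hedged sketch of the expected call sequence; "op" stands in for a concrete instantiation of this device op and error handling is omitted:

// Sketch only: one ContractionMultiDKernelArg per group lives in the workspace.
auto arg_ptr     = op.MakeArgumentPointer(p_a_vec, p_b_vec, p_ds_vec, p_e_vec,
                                          contraction_descs,
                                          a_element_op, b_element_op, cde_element_op);
auto invoker_ptr = op.MakeInvokerPointer();

void* ws = nullptr;
hipMalloc(&ws, op.GetWorkSpaceSize(arg_ptr.get()));
arg_ptr->p_workspace_ = ws;

if(op.IsSupportedArgument(arg_ptr.get()))
{
    invoker_ptr->Run(arg_ptr.get(), StreamConfig{});
}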
include/ck/tensor_operation/gpu/device/device_layernorm.hpp → include/ck/tensor_operation/gpu/device/device_layernorm_impl.hpp
...
@@ -7,15 +7,50 @@
 #include <sstream>

 #include "ck/utility/reduction_operator.hpp"
-#include "ck/tensor_operation/gpu/device/device_base.hpp"
+#include "ck/tensor_operation/gpu/device/device_normalization.hpp"
 #include "ck/tensor_operation/gpu/device/device_reduce.hpp"
 #include "ck/tensor_operation/gpu/device/device_reduce_multiblock.hpp"
 #include "ck/tensor_operation/gpu/device/device_reduce_common.hpp"
-#include "ck/tensor_operation/gpu/grid/gridwise_layernorm.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_layernorm_welford_variance.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"

 namespace ck {
+template <typename GridwiseReduction,
+          typename XDataType,
+          typename GammaDataType,
+          typename BetaDataType,
+          typename YDataType,
+          typename AccDataType,
+          typename AccElementwiseOperation,
+          typename GridDesc_M_K,
+          typename GridDesc_K>
+__global__ void kernel_layernorm(const GridDesc_M_K x_grid_desc_m_k,
+                                 const GridDesc_K gamma_grid_desc_k,
+                                 const GridDesc_K beta_grid_desc_k,
+                                 const GridDesc_M_K y_grid_desc_m_k,
+                                 index_t num_k_block_tile_iteration,
+                                 AccDataType epsilon,
+                                 const XDataType* const __restrict__ p_x_global,
+                                 const GammaDataType* const __restrict__ p_gamma_global,
+                                 const BetaDataType* const __restrict__ p_beta_global,
+                                 YDataType* const __restrict__ p_y_global,
+                                 const AccElementwiseOperation acc_elementwise_op)
+{
+    GridwiseReduction::Run(x_grid_desc_m_k,
+                           gamma_grid_desc_k,
+                           beta_grid_desc_k,
+                           y_grid_desc_m_k,
+                           num_k_block_tile_iteration,
+                           epsilon,
+                           p_x_global,
+                           p_gamma_global,
+                           p_beta_global,
+                           p_y_global,
+                           acc_elementwise_op);
+};
 } // namespace ck

 namespace ck {
 namespace tensor_operation {
 namespace device {
...
@@ -39,7 +74,14 @@ template <typename XDataType,
           index_t GammaSrcVectorSize,
           index_t BetaSrcVectorSize,
           index_t YDstVectorSize>
-struct DeviceLayernorm : public BaseOperator
+struct DeviceLayernormImpl : public DeviceLayernorm<XDataType,
+                                                    GammaDataType,
+                                                    BetaDataType,
+                                                    AccDataType,
+                                                    YDataType,
+                                                    AccElementwiseOperation,
+                                                    Rank,
+                                                    NumReduceDim>
 {
     static_assert((KThreadSliceSize % GammaSrcVectorSize == 0),
...
@@ -51,27 +93,74 @@ struct DeviceLayernorm : public BaseOperator
     using PassThrough = tensor_operation::element_wise::PassThrough;

-    // Used for freeloading of some handy functions from DeviceReduceMultiBlock
-    using Reduction = DeviceReduceMultiBlock<XDataType,
-                                             AccDataType,
-                                             YDataType,
-                                             Rank,
-                                             NumReduceDim,
-                                             reduce::Add,
-                                             PassThrough,             // InElementwiseOperation
-                                             AccElementwiseOperation, // AccElementwiseOperation
-                                             InMemoryDataOperationEnum::Set,
-                                             false, // PropagateNan
-                                             false, // OutputIndex
-                                             false, // HaveIndexInputIfOutputIndex
-                                             BlockSize,
-                                             MThreadClusterSize,
-                                             KThreadClusterSize,
-                                             MThreadSliceSize,
-                                             KThreadSliceSize,
-                                             XYSrcVectorDim,
-                                             XSrcVectorSize,
-                                             1>; // YDstVectorSize
+    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
+    static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;

+    static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths,
+                                    const std::vector<index_t>& inStrides,
+                                    int blkGroupSize,
+                                    int numBlockTileIteration)
+    {
+        constexpr index_t NumInvariantDim  = Rank - NumReduceDim;
+        static constexpr index_t numSrcDim = Rank;
+        static constexpr bool reduceAllDim = (NumInvariantDim == 0);
+
+        const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
+        const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});
+
+        const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
+
+        const auto in_grid_desc_m_k = [&]() {
+            if constexpr(reduceAllDim)
+            {
+                const auto one_dim_inDesc = transform_tensor_descriptor(
+                    inDesc,
+                    make_tuple(make_merge_transform(tupleSrcLengths)),
+                    make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}),
+                    make_tuple(Sequence<0>{}));
+
+                return transform_tensor_descriptor(
+                    one_dim_inDesc,
+                    make_tuple(make_unmerge_transform(
+                        make_tuple(1, one_dim_inDesc.GetLength(Number<0>{})))),
+                    make_tuple(Sequence<0>{}),
+                    make_tuple(Sequence<0, 1>{}));
+            }
+            else
+            {
+                using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
+                using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;
+
+                const auto reduceDimLengths =
+                    make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
+                const auto invariantDimLengths =
+                    make_tuple_from_array_and_index_seq(inLengths, InvariantDims{});
+
+                return transform_tensor_descriptor(
+                    inDesc,
+                    make_tuple(make_merge_transform(invariantDimLengths),
+                               make_merge_transform(reduceDimLengths)),
+                    make_tuple(InvariantDims{}, ReduceDims{}),
+                    make_tuple(Sequence<0>{}, Sequence<1>{}));
+            }
+        }();
+
+        const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
+        const auto reduceLength    = in_grid_desc_m_k.GetLength(Number<1>{});
+
+        const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration;
+        const auto inPad_M =
+            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
+        const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength;
+
+        auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
+            in_grid_desc_m_k,
+            make_tuple(make_right_pad_transform(invariantLength, inPad_M),
+                       make_right_pad_transform(reduceLength, inPad_K)),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}));
+
+        return (in_grid_desc_m_k_padded);
+    };

     static auto MakeAffine1dDescriptor(const std::vector<index_t>& Lengths,
                                        const std::vector<index_t>& Strides,
...
@@ -90,7 +179,7 @@ struct DeviceLayernorm : public BaseOperator
                                    make_tuple(Sequence<0>{}));

         const auto reduceTotalLength = grid_desc_k.GetLength(Number<0>{});
-        const int reduceSizePerBlock = Reduction::K_BlockTileSize * numBlockTileIteration;
+        const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration;
         const auto Pad_K             = reduceSizePerBlock * blkGroupSize - reduceTotalLength;
...
@@ -103,57 +192,60 @@ struct DeviceLayernorm : public BaseOperator
         return (grid_desc_k_padded);
     };

-    using GridDesc_M_K = decltype(Reduction::MakeSrc2dDescriptor({1}, {1}, 1, 1));
+    using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1));
     using GridDesc_K   = decltype(MakeAffine1dDescriptor({1}, {1}, 1, 1));

-    using GridwiseReduceLayernormGeneric =
-        GridwiseLayernorm_mk_to_mk<XDataType, GammaDataType, BetaDataType, YDataType, AccDataType,
-                                   AccElementwiseOperation, GridDesc_M_K, GridDesc_K, BlockSize,
-                                   MThreadClusterSize, KThreadClusterSize, MThreadSliceSize,
-                                   KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize,
-                                   GammaSrcVectorSize, BetaSrcVectorSize, XYSrcVectorDim,
-                                   YDstVectorSize, false>;
-    using GridwiseReduceLayernormSweepOnce =
-        GridwiseLayernorm_mk_to_mk<XDataType, GammaDataType, BetaDataType, YDataType, AccDataType,
-                                   AccElementwiseOperation, GridDesc_M_K, GridDesc_K, BlockSize,
-                                   MThreadClusterSize, KThreadClusterSize, MThreadSliceSize,
-                                   KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize,
-                                   GammaSrcVectorSize, BetaSrcVectorSize, XYSrcVectorDim,
-                                   YDstVectorSize, true>;
+    using GridwiseReduceLayernormGeneric =
+        GridwiseLayernormWelfordVariance_mk_to_mk<XDataType, GammaDataType, BetaDataType,
+                                                  YDataType, AccDataType, AccElementwiseOperation,
+                                                  GridDesc_M_K, GridDesc_K, BlockSize,
+                                                  MThreadClusterSize, KThreadClusterSize,
+                                                  MThreadSliceSize, KThreadSliceSize,
+                                                  XYSrcVectorDim, XSrcVectorSize,
+                                                  GammaSrcVectorSize, BetaSrcVectorSize,
+                                                  XYSrcVectorDim, YDstVectorSize, false>;
+    using GridwiseReduceLayernormSweepOnce =
+        GridwiseLayernormWelfordVariance_mk_to_mk<XDataType, GammaDataType, BetaDataType,
+                                                  YDataType, AccDataType, AccElementwiseOperation,
+                                                  GridDesc_M_K, GridDesc_K, BlockSize,
+                                                  MThreadClusterSize, KThreadClusterSize,
+                                                  MThreadSliceSize, KThreadSliceSize,
+                                                  XYSrcVectorDim, XSrcVectorSize,
+                                                  GammaSrcVectorSize, BetaSrcVectorSize,
+                                                  XYSrcVectorDim, YDstVectorSize, true>;

-    struct Argument : public Reduction::Argument
+    struct Argument : public BaseArgument
     {
         Argument(const std::vector<index_t> lengths,
                  const std::vector<index_t> xStrides,
                  const std::vector<index_t> gammaStrides,
                  const std::vector<index_t> betaStrides,
+                 const std::vector<index_t> yStrides,
                  const std::vector<index_t> reduceDims,
                  AccElementwiseOperation acc_elementwise_op,
                  AccDataType epsilon,
...
@@ -161,53 +253,76 @@ struct DeviceLayernorm : public BaseOperator
                  const GammaDataType* p_gamma,
                  const BetaDataType* p_beta,
                  YDataType* p_y)
-            : Reduction::Argument(lengths,
-                                  xStrides,
-                                  {},
-                                  {},
-                                  reduceDims,
-                                  0.0f, // alpha
-                                  0.0f, // beta
-                                  p_x,
-                                  nullptr,
-                                  p_y,
-                                  nullptr,
-                                  acc_elementwise_op,
-                                  PassThrough{}),
-              epsilon_(epsilon),
+            : epsilon_(epsilon),
               p_x_(p_x),
               p_gamma_(p_gamma),
               p_beta_(p_beta),
               p_y_(p_y),
               gammaStrides_(gammaStrides),
-              betaStrides_(betaStrides)
+              betaStrides_(betaStrides),
+              acc_elementwise_op_(acc_elementwise_op)
         {
-            reduceLength_.resize(NumReduceDim);
+            Lengths_  = shuffle_tensor_dimensions<Rank, NumReduceDim>(lengths, reduceDims);
+            xStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(xStrides, reduceDims);
+            yStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(yStrides, reduceDims);
+
+            long_index_t invariant_total_length;
+            long_index_t reduce_total_length;
+
+            std::tie(invariant_total_length, reduce_total_length) =
+                get_2d_lengths<Rank, NumReduceDim>(Lengths_);
+
+            blkGroupSize_          = 1;
+            numBlockTileIteration_ = (reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize;
+
+            gridSize_ = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
+                        M_BlockTileSize * blkGroupSize_;
+
+            reduceLengths_.resize(NumReduceDim);
+
             for(int i = 0; i < NumReduceDim; ++i)
             {
-                reduceLength_[i]  = lengths[reduceDims[i]];
+                reduceLengths_[i] = lengths[reduceDims[i]];
             }
         }

         AccDataType epsilon_;

         const XDataType* p_x_;
         const GammaDataType* p_gamma_;
         const BetaDataType* p_beta_;
-        std::vector<index_t> reduceLength_;
         YDataType* p_y_;

+        std::vector<index_t> Lengths_;
+        std::vector<index_t> xStrides_;
+        std::vector<index_t> reduceLengths_;
         std::vector<index_t> gammaStrides_;
         std::vector<index_t> betaStrides_;
+        std::vector<index_t> yStrides_;
+
+        AccElementwiseOperation acc_elementwise_op_;
+
+        int blkGroupSize_;
+        int numBlockTileIteration_;
+        size_t gridSize_;
     };

     struct Invoker : public BaseInvoker
     {
         float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
         {
-            const auto x_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
-                arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
-            const auto gamma_grid_desc_k = MakeAffine1dDescriptor(
-                arg.reduceLength_, arg.gammaStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
-            const auto beta_grid_desc_k = MakeAffine1dDescriptor(
-                arg.reduceLength_, arg.betaStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
-            const auto y_grid_desc_m_k = Reduction::MakeSrc2dDescriptor(
-                arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
+            const auto x_grid_desc_m_k = MakeSrc2dDescriptor(
+                arg.Lengths_, arg.xStrides_, arg.blkGroupSize_, arg.numBlockTileIteration_);
+            const auto gamma_grid_desc_k = MakeAffine1dDescriptor(
+                arg.reduceLengths_, arg.gammaStrides_, arg.blkGroupSize_,
+                arg.numBlockTileIteration_);
+            const auto beta_grid_desc_k = MakeAffine1dDescriptor(
+                arg.reduceLengths_, arg.betaStrides_, arg.blkGroupSize_,
+                arg.numBlockTileIteration_);
+            const auto y_grid_desc_m_k = MakeSrc2dDescriptor(
+                arg.Lengths_, arg.yStrides_, arg.blkGroupSize_, arg.numBlockTileIteration_);

             bool sweep_once =
                 x_grid_desc_m_k.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize;
...
@@ -234,19 +349,19 @@ struct DeviceLayernorm : public BaseOperator
             float avg_time = 0;

             avg_time += launch_and_time_kernel(stream_config,
                                                kernel_main,
-                                               dim3(arg.gridSize),
+                                               dim3(arg.gridSize_),
                                                dim3(BlockSize),
                                                0,
                                                x_grid_desc_m_k,
                                                gamma_grid_desc_k,
                                                beta_grid_desc_k,
                                                y_grid_desc_m_k,
-                                               arg.numBlockTileIteration,
+                                               arg.numBlockTileIteration_,
                                                arg.epsilon_,
-                                               arg.in_dev_,
+                                               arg.p_x_,
                                                arg.p_gamma_,
                                                arg.p_beta_,
-                                               arg.out_dev_,
+                                               arg.p_y_,
                                                arg.acc_elementwise_op_);

             return (avg_time);
...
@@ -263,12 +378,33 @@ struct DeviceLayernorm : public BaseOperator
     {
         const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg);

-        if(!Reduction::IsSupportedArgument(p_arg_))
+        constexpr index_t NumInvariantDim = Rank - NumReduceDim;
+
+        if constexpr(XYSrcVectorDim == 0)
         {
-            return false;
+            if constexpr(NumInvariantDim == 0)
+            {
+                return false;
+            }
+            else
+            {
+                if(p_arg_->xStrides_[NumInvariantDim - 1] != 1)
+                    return false;
+
+                if(p_arg_->invariant_lowest_length % XSrcVectorSize != 0)
+                    return false;
+            };
         }
+        else
+        {
+            if(p_arg_->xStrides_[Rank - 1] != 1)
+                return false;

-        if(p_arg_->inLengths_[Rank - 1] % YDstVectorSize != 0)
+            if(p_arg_->Lengths_[Rank - 1] % XSrcVectorSize != 0)
+                return false;
+        };
+
+        if(p_arg_->Lengths_[Rank - 1] % YDstVectorSize != 0)
         {
             return false;
         }
...
@@ -297,22 +433,25 @@ struct DeviceLayernorm : public BaseOperator
         return true;
     };

     std::unique_ptr<BaseArgument> MakeArgumentPointer(const std::vector<index_t> lengths,
                                                       const std::vector<index_t> xStrides,
                                                       const std::vector<index_t> gammaStrides,
                                                       const std::vector<index_t> betaStrides,
+                                                      const std::vector<index_t> yStrides,
                                                       const std::vector<index_t> reduceDims,
                                                       AccDataType epsilon,
                                                       const void* p_x,
                                                       const void* p_gamma,
                                                       const void* p_beta,
                                                       void* p_y,
-                                                      AccElementwiseOperation acc_elementwise_op)
+                                                      AccElementwiseOperation acc_elementwise_op)
+        override
     {
         return std::make_unique<Argument>(lengths,
                                           xStrides,
                                           gammaStrides,
                                           betaStrides,
                                           yStrides,
                                           reduceDims,
                                           acc_elementwise_op,
                                           epsilon,
...
@@ -322,17 +461,19 @@ struct DeviceLayernorm : public BaseOperator
                                           static_cast<YDataType*>(p_y));
     };

-    std::unique_ptr<BaseInvoker> MakeInvokerPointer()
+    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
     {
         return std::make_unique<Invoker>();
     };

     std::string GetTypeString() const override
     {
         auto str = std::stringstream();

         // clang-format off
-        str << "DeviceLayernorm<" << BlockSize << ",";
+        str << "DeviceLayernormImpl<" << BlockSize << ",";
         str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
         str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
         str << "XYSrcVectorDim_" << XYSrcVectorDim << ",";
         str << "VectorSize_X" << XSrcVectorSize << "_Gamma" << GammaSrcVectorSize << "_Beta" << BetaSrcVectorSize << "_Y" << YDstVectorSize << ">";
         // clang-format on
...
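One substantive change hiding in this rename: the gridwise implementation switches from GridwiseLayernorm_mk_to_mk to GridwiseLayernormWelfordVariance_mk_to_mk, i.e. mean and variance are now accumulated with Welford's single-pass algorithm instead of a direct sum-of-squares reduction. For reference, the math involved (our notation, not taken from the diff itself):

% Layernorm over the reduced (K) dimension:
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}}\,\gamma + \beta

% Welford running updates after observing the k-th element x_k:
\mu_k = \mu_{k-1} + \frac{x_k - \mu_{k-1}}{k}, \qquad
M_k = M_{k-1} + (x_k - \mu_{k-1})(x_k - \mu_k), \qquad
\mathrm{Var}[x] = \frac{M_n}{n}

Welford's form avoids the catastrophic cancellation that E[x^2] - E[x]^2 suffers in low-precision accumulators, which is presumably why the commit adopts it.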