gaoqiong / composable_kernel · Commits

Commit 1fceec2f
Authored Apr 26, 2022 by Jehandad Khan

    Merge branch 'develop' into jd/dev_pkg

Parents: 698a442e, b39f07f1

Showing 18 changed files with 620 additions and 82 deletions (+620 −82)
example/01_gemm/gemm_xdl_fp16.cpp                                                    +9    −0
include/ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp                  +45   −20
include/ck/tensor_operation/gpu/device/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp  +3    −0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp              +11   −8
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp                  +9    −7
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp                  +12   −10
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp                +12   −10
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp                  +10   −8
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp                  +7    −6
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp                  +7    −6
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp    +32   −3
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r3.hpp    +23   −3
include/ck/utility/data_type.hpp                                                    +120  −0
include/ck/utility/transpose_vectors.hpp                                            +176  −0
library/include/ck/library/host_tensor/host_tensor.hpp                              +1    −1
test/CMakeLists.txt                                                                 +1    −0
test/fp16_transfer_bf16/CMakeLists.txt (new)                                        +2    −0
test/fp16_transfer_bf16/fp16_transfer_bf16.cpp (new)                                +140  −0
example/01_gemm/gemm_xdl_fp16.cpp  (+9 −0)

@@ -45,12 +45,21 @@ using CElementOp = ck::tensor_operation::element_wise::PassThrough;

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;

// clang-format off
#if 1
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
//######| ALayout| BLayout| CLayout|AData| BData| CData| GemmAcc| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//######| | | | Type| Type| Type| DataType| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector|
//######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
    < Row, Col, Row, F16, F16, F16, F32, F32, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
#else
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl
//###########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
//###########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//###########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//###########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
    < F16, F16, F16, F32, Row, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 16, 256, 4, 8, 16, 16, 1, 8, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, true, 7, 1>;
#endif
// clang-format on

using ReferenceGemmInstance = ck::tensor_operation::host::
...
include/ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp  (+45 −20)

@@ -16,6 +16,31 @@ namespace ck {
namespace tensor_operation {
namespace device {

/*
* \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
*
* \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix
* given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly
 * strided batches, but we can easily extend to other layouts. The returned offset can be either \p
* index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB
* limitations.
*
* \tparam Block2CTileMap Block2CTileMap::CalculateBottomIndex() takes in id of a workgroup and
* returns the 2D index of the tile that it computes. \see
* GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run().
*
* \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2
* tiles from different matrices. Keep in mind that these 2 matrices can share the same grid
* descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link
* device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for \link
* DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the computing of
* pointer offset into \p ComputePtrOffsetOfStridedBatch.
*
* \note \p Block2CTileMap allows customized mapping between a workgroup and the C-tile it computes.
 * Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion) to
* realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion).
*
*/
template <typename GridwiseGemm,
          typename FloatAB,
          typename FloatC,
          ...
@@ -25,7 +50,7 @@ template <typename GridwiseGemm,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
-         typename ComputeBasePrtOfBatch,
+         typename ComputePtrOffsetOfBatch,
          typename Block2CTileMap,
          bool HasMainKBlockLoop>
__global__ void
...

@@ -43,7 +68,7 @@ __global__ void
            const AElementwiseOperation a_element_op,
            const BElementwiseOperation b_element_op,
            const CElementwiseOperation c_element_op,
-           const ComputeBasePrtOfBatch compute_base_ptr_of_batch_,
+           const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
            const Block2CTileMap block_2_ctile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
...

@@ -52,11 +77,11 @@ __global__ void
    const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);

    const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
-       static_cast<long_index_t>(compute_base_ptr_of_batch_.GetABasePtr(g_idx)));
+       static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
    const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
-       static_cast<long_index_t>(compute_base_ptr_of_batch_.GetBBasePtr(g_idx)));
+       static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
    const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
-       static_cast<long_index_t>(compute_base_ptr_of_batch_.GetCBasePtr(g_idx)));
+       static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));

    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
...

@@ -256,26 +281,26 @@ struct DeviceBatchedGemmXdl
        return globalblockid_to_m0_n0_block_cluster_adaptor;
    }

-   struct ComputeBasePtrOfStridedBatch
+   struct ComputePtrOffsetOfStridedBatch
    {
-       ComputeBasePtrOfStridedBatch(index_t BatchStrideA,
+       ComputePtrOffsetOfStridedBatch(index_t BatchStrideA,
                                     index_t BatchStrideB,
                                     index_t BatchStrideC)
            : BatchStrideA_(BatchStrideA), BatchStrideB_(BatchStrideB), BatchStrideC_(BatchStrideC)
        {
        }

-       __host__ __device__ constexpr long_index_t GetABasePtr(index_t g_idx) const
+       __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
        {
            return g_idx * static_cast<long_index_t>(BatchStrideA_);
        }

-       __host__ __device__ constexpr long_index_t GetBBasePtr(index_t g_idx) const
+       __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
        {
            return g_idx * static_cast<long_index_t>(BatchStrideB_);
        }

-       __host__ __device__ constexpr long_index_t GetCBasePtr(index_t g_idx) const
+       __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
        {
            return g_idx * static_cast<long_index_t>(BatchStrideC_);
        }
...

@@ -359,7 +384,7 @@ struct DeviceBatchedGemmXdl
              b_grid_desc_k0_n_k1_{DeviceBatchedGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB)},
              c_grid_desc_m_n_{DeviceBatchedGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC)},
              c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
-             compute_base_ptr_of_batch_{a_grid_desc_k0_m_k1_.GetElementSpaceSize(),
+             compute_ptr_offset_of_batch_{a_grid_desc_k0_m_k1_.GetElementSpaceSize(),
                                         b_grid_desc_k0_n_k1_.GetElementSpaceSize(),
                                         c_grid_desc_m_n_.GetElementSpaceSize()},
              block_2_ctile_map_{},
...

@@ -388,7 +413,7 @@ struct DeviceBatchedGemmXdl
        BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
        CGridDesc_M_N c_grid_desc_m_n_;
        CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
-       ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
+       ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
        Block2CTileMap block_2_ctile_map_;
        index_t M01_;
        index_t N01_;
...

@@ -451,7 +476,7 @@ struct DeviceBatchedGemmXdl
                    AElementwiseOperation,
                    BElementwiseOperation,
                    CElementwiseOperation,
-                   ComputeBasePtrOfStridedBatch,
+                   ComputePtrOffsetOfStridedBatch,
                    remove_reference_t<Block2CTileMap>,
                    true>;
...

@@ -472,7 +497,7 @@ struct DeviceBatchedGemmXdl
                    arg.a_element_op_,
                    arg.b_element_op_,
                    arg.c_element_op_,
-                   arg.compute_base_ptr_of_batch_,
+                   arg.compute_ptr_offset_of_batch_,
                    arg.block_2_ctile_map_);
            }
            else
...

@@ -487,7 +512,7 @@ struct DeviceBatchedGemmXdl
                    AElementwiseOperation,
                    BElementwiseOperation,
                    CElementwiseOperation,
-                   ComputeBasePtrOfStridedBatch,
+                   ComputePtrOffsetOfStridedBatch,
                    remove_reference_t<Block2CTileMap>,
                    false>;
...

@@ -508,7 +533,7 @@ struct DeviceBatchedGemmXdl
                    arg.a_element_op_,
                    arg.b_element_op_,
                    arg.c_element_op_,
-                   arg.compute_base_ptr_of_batch_,
+                   arg.compute_ptr_offset_of_batch_,
                    arg.block_2_ctile_map_);
            }
...
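The doc comment at the top of this file's diff describes the ComputePtrOffsetOfBatch customization point in general terms: any type exposing GetA/B/CPtrOffset(g_idx) can decide where each workgroup's A, B, and C tiles live. As a rough illustration of the flexibility it mentions, here is a minimal, hypothetical sketch of an offset functor for batches that are not evenly strided; the struct name and members are invented for this example and are not part of the commit.

// Hypothetical sketch (not in this commit): a ptr-offset functor for unevenly spaced
// batches. It mirrors the GetA/B/CPtrOffset(g_idx) interface called by the batched-GEMM
// kernel wrapper above, returning long_index_t to avoid the 2GB addressing limit.
struct ComputePtrOffsetOfIrregularBatch
{
    const long_index_t* a_offsets_; // per-batch element offset into A (device-visible array)
    const long_index_t* b_offsets_; // per-batch element offset into B
    const long_index_t* c_offsets_; // per-batch element offset into C

    __host__ __device__ long_index_t GetAPtrOffset(index_t g_idx) const { return a_offsets_[g_idx]; }
    __host__ __device__ long_index_t GetBPtrOffset(index_t g_idx) const { return b_offsets_[g_idx]; }
    __host__ __device__ long_index_t GetCPtrOffset(index_t g_idx) const { return c_offsets_[g_idx]; }
};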
include/ck/tensor_operation/gpu/device/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp  (+3 −0)

@@ -18,6 +18,9 @@ namespace ck {
namespace tensor_operation {
namespace device {

+/*
+ * \see \link device_batched_gemm_xdl.hpp kernel_batched_gemm_xdlops_v2r3() \endlink.
+ */
template <typename GridwiseGemm,
          typename FloatAB,
          typename FloatC,
          ...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp  (+11 −8)

@@ -110,6 +110,8 @@ template <typename FloatAB,
          index_t CShuffleBlockTransferScalarPerVector_NPerBlock>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
{
+   using LDSDataType = typename TypeMap<FloatAB>::type;
+
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
...

@@ -180,7 +182,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();

-       return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(FloatAB),
+       return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(LDSDataType),
                         c_block_size * sizeof(FloatCShuffle));
    }
...

@@ -366,7 +368,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
            ABlockTransferThreadClusterLengths_AK0_M_AK1,
            ABlockTransferThreadClusterArrangeOrder,
            FloatAB,
-           FloatAB,
+           LDSDataType,
            decltype(a_grid_desc_ak0_m_ak1),
            decltype(a_block_desc_ak0_m_ak1),
            ABlockTransferSrcAccessOrder,
...

@@ -397,7 +399,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
            BBlockTransferThreadClusterLengths_BK0_N_BK1,
            BBlockTransferThreadClusterArrangeOrder,
            FloatAB,
-           FloatAB,
+           LDSDataType,
            decltype(b_grid_desc_bk0_n_bk1),
            decltype(b_block_desc_bk0_n_bk1),
            BBlockTransferSrcAccessOrder,
...

@@ -425,12 +427,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
        // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
        // register
        // sanity check
-       constexpr index_t KPack = math::max(
-           math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
+       constexpr index_t KPack = math::max(
+           math::lcm(AK1, BK1), MfmaSelector<LDSDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);

        auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<
            BlockSize,
            FloatAB,
            LDSDataType,
            FloatGemmAcc,
            decltype(a_block_desc_ak0_m_ak1),
            decltype(b_block_desc_bk0_n_bk1),
...

@@ -447,10 +450,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-           static_cast<FloatAB*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
+           static_cast<LDSDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());

        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-           static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
+           static_cast<LDSDataType*>(p_shared) + a_block_space_size_aligned,
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());

        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
...
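The same change repeats across the gridwise GEMM headers that follow: the LDS element type is derived from FloatAB through the TypeMap trait (added in include/ck/utility/data_type.hpp further down), and the sizeof() and static_cast<>() calls that touch the shared-memory tiles switch from FloatAB to LDSDataType. A condensed, illustrative sketch of that pattern follows; the struct and parameter names here are simplified stand-ins, not code from this commit.

// Illustrative only: the LDSDataType pattern used by the gridwise GEMM kernels in this
// commit. TypeMap is the trait added in include/ck/utility/data_type.hpp; everything
// else here is a simplified stand-in.
template <typename FloatAB>
struct GridwiseGemmLdsSketch
{
    // On gfx90a TypeMap maps half_t to bhalf_t, so fp16 GEMMs keep their A/B tiles in LDS as bf16.
    using LDSDataType = typename ck::TypeMap<FloatAB>::type;

    __host__ __device__ static constexpr ck::index_t
    GetSharedMemoryNumberOfByte(ck::index_t a_block_elems, ck::index_t b_block_elems)
    {
        // LDS is sized for the LDS element type, no longer for the source type FloatAB.
        return (a_block_elems + b_block_elems) * sizeof(LDSDataType);
    }

    __device__ static void Run(void* p_shared, ck::index_t a_block_elems)
    {
        // A and B tiles are carved out of the same raw LDS allocation as LDSDataType.
        LDSDataType* p_a_block = static_cast<LDSDataType*>(p_shared);
        LDSDataType* p_b_block = static_cast<LDSDataType*>(p_shared) + a_block_elems;
        (void)p_a_block;
        (void)p_b_block;
    }
};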
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp  (+9 −7)

@@ -190,6 +190,8 @@ template <index_t BlockSize,
          index_t NumPrefetch = 1>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
{
+   using LDSDataType = typename TypeMap<FloatAB>::type;
+
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
...

@@ -261,7 +263,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
        constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
            b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);

-       return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(FloatAB);
+       return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(LDSDataType);
    }

    // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
...

@@ -380,7 +382,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
        using BlockwiseGemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<
            BlockSize,
            FloatAB,
            LDSDataType,
            FloatAcc,
            decltype(a_block_desc_k0_m_k1),
            decltype(b_block_desc_k0_n_k1),
...

@@ -486,7 +488,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
            ABlockTransferThreadClusterLengths_K0_M_K1,
            ABlockTransferThreadClusterArrangeOrder,
            FloatAB,
-           FloatAB,
+           LDSDataType,
            decltype(a_grid_desc_k0_m_k1),
            decltype(a_block_desc_k0_m_k1),
            ABlockTransferSrcAccessOrder,
...

@@ -517,7 +519,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
            BBlockTransferThreadClusterLengths_K0_N_K1,
            BBlockTransferThreadClusterArrangeOrder,
            FloatAB,
-           FloatAB,
+           LDSDataType,
            decltype(b_grid_desc_k0_n_k1),
            decltype(b_block_desc_k0_n_k1),
            BBlockTransferSrcAccessOrder,
...

@@ -548,7 +550,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
        auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<
            BlockSize,
            FloatAB,
            LDSDataType,
            FloatAcc,
            decltype(a_block_desc_k0_m_k1),
            decltype(b_block_desc_k0_n_k1),
...

@@ -565,10 +567,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
            math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);

        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-           static_cast<FloatAB*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());
+           static_cast<LDSDataType*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());

        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-           static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
+           static_cast<LDSDataType*>(p_shared) + a_block_space_size_aligned,
            b_block_desc_k0_n_k1.GetElementSpaceSize());

        constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4.hpp  (+12 −10)

@@ -38,10 +38,11 @@ __global__ void
            const CBlockClusterAdaptor c_block_cluster_adaptor)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
+   using LDSDataType = typename TypeMap<FloatAB>::type;
    constexpr index_t shared_block_size =
-       GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
+       GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(LDSDataType);

-   __shared__ FloatAB p_shared_block[shared_block_size];
+   __shared__ LDSDataType p_shared_block[shared_block_size];

    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
                                                  p_b_grid,
...

@@ -108,6 +109,7 @@ template <index_t BlockSize,
          index_t CThreadTransferDstScalarPerVector>
struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
{
+   using LDSDataType = typename TypeMap<FloatAB>::type;
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
...

@@ -161,7 +163,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
        constexpr auto b_block_space_size = math::integer_least_multiple(
            b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align);

-       return (a_block_space_size + b_block_space_size) * sizeof(FloatAB);
+       return (a_block_space_size + b_block_space_size) * sizeof(LDSDataType);
    }

    // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
...

@@ -263,7 +265,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
        using BlockwiseGemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<
            BlockSize,
            FloatAB,
            LDSDataType,
            FloatAcc,
            decltype(a_k0_m_k1_block_desc),
            decltype(b_k0_n_k1_block_desc),
...

@@ -320,7 +322,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
    __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                               const FloatAB* __restrict__ p_b_grid,
                               FloatC* __restrict__ p_c_grid,
-                              FloatAB* __restrict__ p_shared_block,
+                              LDSDataType* __restrict__ p_shared_block,
                               const ABK0MK1GridDesc& a_b_k0_m_k1_grid_desc,
                               const BBK0NK1GridDesc& b_b_k0_n_k1_grid_desc,
                               const CM0N0M1N1M2M3M4N2GridDesc& c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
...

@@ -428,7 +430,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
            ABlockTransferThreadClusterLengths_K0_M_K1,
            ABlockTransferThreadClusterArrangeOrder,
            FloatAB,
-           FloatAB,
+           LDSDataType,
            decltype(a_b_k0_m_k1_grid_desc),
            decltype(a_b_k0_m_k1_block_desc),
            ABlockTransferSrcAccessOrder,
...

@@ -458,7 +460,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
            BBlockTransferThreadClusterLengths_K0_N_K1,
            BBlockTransferThreadClusterArrangeOrder,
            FloatAB,
-           FloatAB,
+           LDSDataType,
            decltype(b_b_k0_n_k1_grid_desc),
            decltype(b_b_k0_n_k1_block_desc),
            BBlockTransferSrcAccessOrder,
...

@@ -488,7 +490,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
        auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<
            BlockSize,
            FloatAB,
            LDSDataType,
            FloatAcc,
            decltype(a_k0_m_k1_block_desc),
            decltype(b_k0_n_k1_block_desc),
...

@@ -504,8 +506,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4
        constexpr auto a_block_space_size = math::integer_least_multiple(
            a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);

-       FloatAB* p_a_block = p_shared_block;
-       FloatAB* p_b_block = p_shared_block + a_block_space_size;
+       LDSDataType* p_a_block = p_shared_block;
+       LDSDataType* p_b_block = p_shared_block + a_block_space_size;

        constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp  (+12 −10)

@@ -40,10 +40,11 @@ __global__ void
            const CBlockClusterAdaptor c_block_cluster_adaptor)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
+   using LDSDataType = typename TypeMap<FloatAB>::type;
    constexpr index_t shared_block_size =
-       GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
+       GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(LDSDataType);

-   __shared__ FloatAB p_shared_block[shared_block_size];
+   __shared__ LDSDataType p_shared_block[shared_block_size];

    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
                                                  p_b_grid,
...

@@ -111,6 +112,7 @@ template <index_t BlockSize,
          typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>
struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
{
+   using LDSDataType = typename TypeMap<FloatAB>::type;
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
...

@@ -167,7 +169,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
        constexpr auto c_block_size =
            GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock().GetElementSpaceSize();

-       return math::max((a_block_space_size + b_block_space_size) * sizeof(FloatAB),
+       return math::max((a_block_space_size + b_block_space_size) * sizeof(LDSDataType),
                         c_block_size * sizeof(FloatC));
    }
...

@@ -308,7 +310,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
    __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                               const FloatAB* __restrict__ p_b_grid,
                               FloatC* __restrict__ p_c_grid,
-                              FloatAB* __restrict__ p_shared_block,
+                              LDSDataType* __restrict__ p_shared_block,
                               const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc,
                               const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc,
                               const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
...

@@ -417,7 +419,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
            ABlockTransferThreadClusterLengths_K0_M_K1,
            ABlockTransferThreadClusterArrangeOrder,
            FloatAB,
-           FloatAB,
+           LDSDataType,
            decltype(a_b_k0_m_k1_grid_desc),
            decltype(a_b_k0_m_k1_block_desc),
            ABlockTransferSrcAccessOrder,
...

@@ -447,7 +449,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
            BBlockTransferThreadClusterLengths_K0_N_K1,
            BBlockTransferThreadClusterArrangeOrder,
            FloatAB,
-           FloatAB,
+           LDSDataType,
            decltype(b_b_k0_n_k1_grid_desc),
            decltype(b_b_k0_n_k1_block_desc),
            BBlockTransferSrcAccessOrder,
...

@@ -477,7 +479,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
        auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<
            BlockSize,
            FloatAB,
            LDSDataType,
            FloatAcc,
            decltype(a_k0_m_k1_block_desc),
            decltype(b_k0_n_k1_block_desc),
...

@@ -493,8 +495,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
        constexpr auto a_block_space_size = math::integer_least_multiple(
            a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);

-       FloatAB* p_a_block = p_shared_block;
-       FloatAB* p_b_block = p_shared_block + a_block_space_size;
+       LDSDataType* p_a_block = p_shared_block;
+       LDSDataType* p_b_block = p_shared_block + a_block_space_size;

        constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
...

@@ -574,7 +576,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
            GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

        auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-           static_cast<FloatC*>(p_shared_block),
+           static_cast<FloatC*>(static_cast<void*>(p_shared_block)),
            c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

        static_assert(M1 == MWave, "");
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp  (+10 −8)

@@ -116,6 +116,7 @@ template <
          index_t NumPrefetch = 1>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
{
+   using LDSDataType = typename TypeMap<FloatAB>::type;
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
...

@@ -216,7 +217,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
            .GetElementSpaceSize();

-       return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(FloatAB),
+       return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(LDSDataType),
                         c_block_size * sizeof(FloatCShuffle));
    }
...

@@ -421,7 +422,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
            ABlockTransferThreadClusterLengths_AK0_M_AK1,
            ABlockTransferThreadClusterArrangeOrder,
            FloatAB,
-           FloatAB,
+           LDSDataType,
            decltype(a_grid_desc_ak0_m_ak1),
            decltype(a_block_desc_ak0_m_ak1),
            ABlockTransferSrcAccessOrder,
...

@@ -452,7 +453,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
            BBlockTransferThreadClusterLengths_BK0_N_BK1,
            BBlockTransferThreadClusterArrangeOrder,
            FloatAB,
-           FloatAB,
+           LDSDataType,
            decltype(b_grid_desc_bk0_n_bk1),
            decltype(b_block_desc_bk0_n_bk1),
            BBlockTransferSrcAccessOrder,
...

@@ -480,12 +481,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
        // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
        // register
        // sanity check
-       constexpr index_t k_pack = math::max(
-           math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
+       constexpr index_t k_pack = math::max(
+           math::lcm(AK1, BK1), MfmaSelector<LDSDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);

        auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<
            BlockSize,
            FloatAB,
            LDSDataType,
            FloatAcc,
            decltype(a_block_desc_ak0_m_ak1),
            decltype(b_block_desc_bk0_n_bk1),
...

@@ -502,10 +504,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-           static_cast<FloatAB*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
+           static_cast<LDSDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());

        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-           static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
+           static_cast<LDSDataType*>(p_shared) + a_block_space_size_aligned,
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());

        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r2.hpp  (+7 −6)

@@ -122,6 +122,7 @@ template <
          index_t NumPrefetch = 1>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
{
+   using LDSDataType = typename TypeMap<FloatAB>::type;
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
...

@@ -221,7 +222,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
            .GetElementSpaceSize();

-       return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(FloatAB),
+       return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(LDSDataType),
                         c_block_size * sizeof(FloatC));
    }
...

@@ -442,7 +443,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
            ABlockTransferThreadClusterLengths_K0_M_K1,
            ABlockTransferThreadClusterArrangeOrder,
            FloatAB,
-           FloatAB,
+           LDSDataType,
            decltype(a_grid_desc_k0_m_k1),
            decltype(a_block_desc_k0_m_k1),
            ABlockTransferSrcAccessOrder,
...

@@ -473,7 +474,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
            BBlockTransferThreadClusterLengths_K0_N_K1,
            BBlockTransferThreadClusterArrangeOrder,
            FloatAB,
-           FloatAB,
+           LDSDataType,
            decltype(b_grid_desc_k0_n_k1),
            decltype(b_block_desc_k0_n_k1),
            BBlockTransferSrcAccessOrder,
...

@@ -504,7 +505,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
        auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<
            BlockSize,
            FloatAB,
            LDSDataType,
            FloatAcc,
            decltype(a_block_desc_k0_m_k1),
            decltype(b_block_desc_k0_n_k1),
...

@@ -521,10 +522,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2
            math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);

        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-           static_cast<FloatAB*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());
+           static_cast<LDSDataType*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());

        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-           static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
+           static_cast<LDSDataType*>(p_shared) + a_block_space_size_aligned,
            b_block_desc_k0_n_k1.GetElementSpaceSize());

        constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp  (+7 −6)

@@ -131,6 +131,7 @@ template <
          index_t NumPrefetch = 1>
struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
{
+   using LDSDataType = typename TypeMap<FloatAB>::type;
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
...

@@ -230,7 +231,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
            .GetElementSpaceSize();

-       return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(FloatAB),
+       return math::max((a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(LDSDataType),
                         c_block_size * sizeof(FloatC));
    }
...

@@ -463,7 +464,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
            ABlockTransferThreadClusterLengths_K0_M_K1,
            ABlockTransferThreadClusterArrangeOrder,
            FloatAB,
-           FloatAB,
+           LDSDataType,
            decltype(a_grid_desc_k0_m_k1),
            decltype(a_block_desc_k0_m_k1),
            ABlockTransferSrcAccessOrder,
...

@@ -493,7 +494,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
            BBlockTransferThreadClusterLengths_K0_N_K1,
            BBlockTransferThreadClusterArrangeOrder,
            FloatAB,
-           FloatAB,
+           LDSDataType,
            decltype(b_grid_desc_k0_n_k1),
            decltype(b_block_desc_k0_n_k1),
            BBlockTransferSrcAccessOrder,
...

@@ -523,7 +524,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
        auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<
            BlockSize,
            FloatAB,
            LDSDataType,
            FloatAcc,
            decltype(a_block_desc_k0_m_k1),
            decltype(b_block_desc_k0_n_k1),
...

@@ -540,10 +541,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3
            math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);

        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-           static_cast<FloatAB*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());
+           static_cast<LDSDataType*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());

        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-           static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
+           static_cast<LDSDataType*>(p_shared) + a_block_space_size_aligned,
            b_block_desc_k0_n_k1.GetElementSpaceSize());

        constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
...
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp  (+32 −3)

@@ -278,7 +278,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
        // TODO make this logic more generic for more sub-dword datatype
        if constexpr(SrcVectorDim != DstVectorDim &&
                     ((is_same<half_t, remove_cvref_t<SrcData>>::value &&
-                      is_same<half_t, remove_cvref_t<DstData>>::value &&
+                      (is_same<half_t, remove_cvref_t<DstData>>::value ||
+                       is_same<bhalf_t, remove_cvref_t<DstData>>::value) &&
                       SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) ||
                      (is_same<int8_t, remove_cvref_t<SrcData>>::value &&
                       is_same<int8_t, remove_cvref_t<DstData>>::value &&
...

@@ -340,8 +341,36 @@ struct ThreadwiseTensorSliceTransfer_v3r1
                // do data transpose
                // TODO type_convert is not used yet!!!!!
                if constexpr(is_same<remove_cvref_t<SrcData>, remove_cvref_t<DstData>>::value)
                {
                    transpose_vectors<SrcData, DstScalarPerVector, SrcScalarPerVector>{}(
                        src_vector_refs, dst_vector_refs);
                }
                else
                {
                    transpose_convert_vectors<SrcData, DstData, DstScalarPerVector, SrcScalarPerVector>{}(
                        src_vector_refs, dst_vector_refs);
                }
            });
        }
        else if constexpr(SrcVectorDim == DstVectorDim && SrcScalarPerVector % 2 == 0 &&
                          DstScalarPerVector % 2 == 0 &&
                          is_same<half_t, remove_cvref_t<SrcData>>::value &&
                          is_same<bhalf_t, remove_cvref_t<DstData>>::value)
        {
            auto NewSliceLengths = SliceLengths{}.template Modify(
                Number<SrcVectorDim>{}, Number<SliceLengths{}[SrcVectorDim] / 2>{});
            auto VectorStep = SliceLengths{} / NewSliceLengths;

            static_ford<decltype(NewSliceLengths)>{}([&](auto idx) {
                // convert from SrcData to DstData here
                auto nidx  = idx * VectorStep;
                auto vhalf = src_thread_scratch_tuple_[thread_scratch_id].template GetAsType<half2_t>(nidx);
                dst_thread_scratch_.template SetAsType<bhalf2_t>(nidx, type_convert<bhalf2_t>(vhalf));
            });
        }
        else
...
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r3.hpp  (+23 −3)

@@ -284,7 +284,8 @@ struct ThreadwiseTensorSliceTransfer_v3r3
        // TODO make this logic more generic for more sub-dword datatype
        if constexpr(SrcVectorDim != DstVectorDim &&
                     is_same<half_t, remove_cvref_t<SrcData>>::value &&
-                    is_same<half_t, remove_cvref_t<DstData>>::value &&
+                    (is_same<half_t, remove_cvref_t<DstData>>::value ||
+                     is_same<bhalf_t, remove_cvref_t<DstData>>::value) &&
                     SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0)
        {
            // each transpose does
...

@@ -343,8 +344,27 @@ struct ThreadwiseTensorSliceTransfer_v3r3
                // do data transpose
                // TODO type_convert is not used yet!!!!!
-               transpose_vectors<SrcData, DstScalarPerVector, SrcScalarPerVector>{}(
-                   src_vector_refs, dst_vector_refs);
+               transpose_convert_vectors<SrcData, DstData, DstScalarPerVector, SrcScalarPerVector>{}(
+                   src_vector_refs, dst_vector_refs);
            });
        }
        else if constexpr(SrcVectorDim == DstVectorDim && SrcScalarPerVector % 2 == 0 &&
                          DstScalarPerVector % 2 == 0 &&
                          is_same<half_t, remove_cvref_t<SrcData>>::value &&
                          is_same<bhalf_t, remove_cvref_t<DstData>>::value)
        {
            auto NewSliceLengths = SliceLengths{}.template Modify(
                Number<SrcVectorDim>{}, Number<SliceLengths{}[SrcVectorDim] / 2>{});
            auto VectorStep = SliceLengths{} / NewSliceLengths;

            static_ford<decltype(NewSliceLengths)>{}([&](auto idx) {
                // convert from SrcData to DstData here
                auto nidx  = idx * VectorStep;
                auto vhalf = src_thread_scratch_tuple_[thread_scratch_id].template GetAsType<half2_t>(nidx);
                dst_thread_scratch_.template SetAsType<bhalf2_t>(nidx, type_convert<bhalf2_t>(vhalf));
            });
        }
        else
...
include/ck/utility/data_type.hpp  (+120 −0)

@@ -992,6 +992,113 @@ inline __host__ __device__ bhalf_t type_convert<bhalf_t, float>(float x)
    return uint16_t(u.int32 >> 16);
}

// convert fp16 to bf16
template <>
inline __host__ __device__ bhalf_t type_convert<bhalf_t, half_t>(half_t x)
{
    union
    {
        float fp32;
        uint32_t int32;
    } u = {static_cast<float>(x)};

    return uint16_t(u.int32 >> 16);
}

template <>
inline __host__ __device__ bhalf2_t type_convert<bhalf2_t, half2_t>(half2_t x)
{
    float y0{0}, y1{0};
    bhalf2_t y{0};

    asm volatile("\n \
            v_cvt_f32_f16 %0, %1 \n \
            "
                 : "=v"(y0)
                 : "v"(x));
    asm volatile("\n \
            v_cvt_f32_f16 %0, %1 src0_sel:WORD_1 \n \
            "
                 : "=v"(y1)
                 : "v"(x));
    asm volatile("\n \
            v_pack_b32_f16 %0, %1, %2 op_sel:[1, 1] \n \
            "
                 : "=v"(y)
                 : "v"(y0), "v"(y1));
    return y;
}

// TODO: deprecate this
template <typename T>
struct inner_product_with_conversion
{
    template <typename X, index_t N>
    __device__ T operator()(typename vector_type<X, N>::type a,
                            typename vector_type<X, N>::type b) const
    {
        const vector_type<X, N> a_vector{a};
        const vector_type<X, N> b_vector{b};

        T acc = 0;

        static_for<0, N, 1>{}([&](auto i) {
            acc += type_convert<T>(a_vector.Scalars()[i]) * type_convert<T>(b_vector.Scalars()[i]);
        });

        return acc;
    }

    __device__ T operator()(float_t a, float_t b) const
    {
        return type_convert<T>(a) * type_convert<T>(b);
    }

    __device__ T operator()(int8x4_t a, int8x4_t b) const
    {
        const vector_type<int8_t, 4> a_vector{a};
        const vector_type<int8_t, 4> b_vector{b};

        T acc = 0;

        static_for<0, 4, 1>{}([&](auto i) {
            acc += type_convert<T>(a_vector.AsType<int8_t>()[i]) *
                   type_convert<T>(b_vector.AsType<int8_t>()[i]);
        });

        return acc;
    }

    __device__ T operator()(int8x8_t a, int8x8_t b) const
    {
        const vector_type<int8_t, 8> a_vector{a};
        const vector_type<int8_t, 8> b_vector{b};

        T acc = 0;

        static_for<0, 8, 1>{}([&](auto i) {
            acc += type_convert<T>(a_vector.AsType<int8_t>()[i]) *
                   type_convert<T>(b_vector.AsType<int8_t>()[i]);
        });

        return acc;
    }

    __device__ T operator()(int8x16_t a, int8x16_t b) const
    {
        const vector_type<int8_t, 16> a_vector{a};
        const vector_type<int8_t, 16> b_vector{b};

        T acc = 0;

        static_for<0, 16, 1>{}([&](auto i) {
            acc += type_convert<T>(a_vector.AsType<int8_t>()[i]) *
                   type_convert<T>(b_vector.AsType<int8_t>()[i]);
        });

        return acc;
    }
};

template <typename T>
struct NumericLimits
{
...

@@ -1016,4 +1123,17 @@ struct NumericLimits<half_t>
    __host__ __device__ static constexpr half_t Lowest() { return bit_cast<half_t>(binary_lowest); }
};

template <typename T>
struct TypeMap
{
    using type = T;
};

#if defined(__gfx90a__)
template <>
struct TypeMap<ck::half_t>
{
    using type = ck::bhalf_t;
};
#endif

} // namespace ck
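For reference, the scalar fp16-to-bf16 conversion added above is a truncating conversion: the half is widened to float and only the upper 16 bits of the float bit pattern are kept (sign, exponent, and the top 7 mantissa bits), with no round-to-nearest step. A small standalone host-side sketch of that truncation step follows, for illustration only; it is not the CK implementation and starts from a float rather than a half_t.

#include <cstdint>
#include <cstring>

// Host-side model of the truncating fp32 -> bf16 step performed by the new
// type_convert<bhalf_t, half_t> specialization above (which first widens half_t to float).
static inline std::uint16_t float_to_bf16_truncate(float x)
{
    std::uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));          // reinterpret the float bits safely
    return static_cast<std::uint16_t>(bits >> 16); // keep sign + exponent + top 7 mantissa bits
}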
include/ck/utility/transpose_vectors.hpp  (+176 −0)

@@ -13,6 +13,182 @@ template <typename S,
          typename enable_if<is_scalar_type<S>::value, bool>::type = false>
struct transpose_vectors;

template <typename Sx,
          typename Sy,
          index_t NX,
          index_t NY,
          typename enable_if<is_scalar_type<Sx>::value, bool>::type = false,
          typename enable_if<is_scalar_type<Sy>::value, bool>::type = false>
struct transpose_convert_vectors;

__device__ void convert_half2_to_bhalf2(const half2_t& x, bhalf2_t& y)
{
#if 0
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    const vector_type<half_t, 2> vx{x};
    vector_type<bhalf_t, 2> vy;

    float v0 = static_cast<float>(vx.template AsType<half_t>()[I0]);
    float v1 = static_cast<float>(vx.template AsType<half_t>()[I1]);

    vy.template AsType<bhalf_t>()(I0) = ck::type_convert<bhalf_t>(v0);
    vy.template AsType<bhalf_t>()(I1) = ck::type_convert<bhalf_t>(v1);

    y = vy.template AsType<bhalf2_t>()[I0];
#else
    float y0{0}, y1{0};

    asm volatile("\n \
            v_cvt_f32_f16 %0, %1 \n \
            "
                 : "=v"(y0)
                 : "v"(x));
    asm volatile("\n \
            v_cvt_f32_f16 %0, %1 src0_sel:WORD_1 \n \
            "
                 : "=v"(y1)
                 : "v"(x));
    asm volatile("\n \
            v_pack_b32_f16 %0, %1, %2 op_sel:[1, 1] \n \
            "
                 : "=v"(y)
                 : "v"(y0), "v"(y1));
#endif
}

__device__ void
transpose_half_to_bhalf_2x2(const half2_t& x0, const half2_t& x1, bhalf2_t& y0, bhalf2_t& y1)
{
#if 0
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    const vector_type<half_t, 2> vx0{x0}, vx1{x1};
    vector_type<bhalf_t, 2> vy0, vy1;

    float v0 = static_cast<float>(vx0.template AsType<half_t>()[I0]);
    float v1 = static_cast<float>(vx1.template AsType<half_t>()[I0]);
    vy0.template AsType<bhalf_t>()(I0) = ck::type_convert<bhalf_t>(v0);
    vy0.template AsType<bhalf_t>()(I1) = ck::type_convert<bhalf_t>(v1);

    v0 = static_cast<float>(vx0.template AsType<half_t>()[I1]);
    v1 = static_cast<float>(vx1.template AsType<half_t>()[I1]);
    vy1.template AsType<bhalf_t>()(I0) = ck::type_convert<bhalf_t>(v0);
    vy1.template AsType<bhalf_t>()(I1) = ck::type_convert<bhalf_t>(v1);

    y0 = vy0.template AsType<bhalf2_t>()[I0];
    y1 = vy1.template AsType<bhalf2_t>()[I0];
#else
    float yv0{0}, yv1{0};

    asm volatile("\n \
            v_cvt_f32_f16 %0, %1 \n \
            "
                 : "=v"(yv0)
                 : "v"(x0));
    asm volatile("\n \
            v_cvt_f32_f16 %0, %1 \n \
            "
                 : "=v"(yv1)
                 : "v"(x1));
    asm volatile("\n \
            v_pack_b32_f16 %0, %1, %2 op_sel:[1, 1] \n \
            "
                 : "=v"(y0)
                 : "v"(yv0), "v"(yv1));

    asm volatile("\n \
            v_cvt_f32_f16 %0, %1 src0_sel:WORD_1 \n \
            "
                 : "=v"(yv0)
                 : "v"(x0));
    asm volatile("\n \
            v_cvt_f32_f16 %0, %1 src0_sel:WORD_1 \n \
            "
                 : "=v"(yv1)
                 : "v"(x1));
    asm volatile("\n \
            v_pack_b32_f16 %0, %1, %2 op_sel:[1, 1] \n \
            "
                 : "=v"(y1)
                 : "v"(yv0), "v"(yv1));
#endif
}

template <index_t NX, index_t NY>
struct transpose_convert_vectors<half_t, half_t, NX, NY>
{
    // we got [NY * NX] amount of S data to be transposed
    static constexpr index_t s_per_x = NY;
    static constexpr index_t s_per_y = NX;

    using S  = half_t;
    using VX = vector_type<half_t, s_per_x>;
    using VY = vector_type<half_t, s_per_y>;

    __device__ void operator()(const StaticallyIndexedArray<const VX&, NX>& vx_tuple,
                               StaticallyIndexedArray<VY&, NY>& vy_tuple)
    {
        static constexpr auto I1 = Number<1>{};
        static constexpr auto I2 = Number<2>{};

        static_assert((NX % 2 == 0 && NY % 2 == 0), "wrong!");

        // loop over 2x2 tile and transpose data from vx_tuple into vy_tuple
        static_for<0, NY, 2>{}([&](auto iy) {
            static_for<0, NX, 2>{}([&](auto ix) {
                // reference to 2 half2_t data from vx_tuple
                const auto& x_s2_0 = vx_tuple[ix].template AsType<half2_t>()[iy / I2];
                const auto& x_s2_1 = vx_tuple[ix + I1].template AsType<half2_t>()[iy / I2];

                // reference to 2 half2_t data from vy_tuple
                auto& y_s2_0 = vy_tuple(iy).template AsType<half2_t>()(ix / I2);
                auto& y_s2_1 = vy_tuple(iy + I1).template AsType<half2_t>()(ix / I2);

                // transpose
                transpose_fp16_2x2(x_s2_0, x_s2_1, y_s2_0, y_s2_1);
            });
        });
    }
};

template <index_t NX, index_t NY>
struct transpose_convert_vectors<half_t, bhalf_t, NX, NY>
{
    // we got [NY * NX] amount of S data to be transposed
    static constexpr index_t s_per_x = NY;
    static constexpr index_t s_per_y = NX;

    using S  = half_t;
    using VX = vector_type<half_t, s_per_x>;
    using VY = vector_type<bhalf_t, s_per_y>;

    __device__ void operator()(const StaticallyIndexedArray<const VX&, NX>& vx_tuple,
                               StaticallyIndexedArray<VY&, NY>& vy_tuple)
    {
        static constexpr auto I1 = Number<1>{};
        static constexpr auto I2 = Number<2>{};

        static_assert((NX % 2 == 0 && NY % 2 == 0), "wrong!");

        // loop over 2x2 tile and transpose data from vx_tuple into vy_tuple
        static_for<0, NY, 2>{}([&](auto iy) {
            static_for<0, NX, 2>{}([&](auto ix) {
                // reference to 2 half2_t data from vx_tuple
                const auto& x_s2_0 = vx_tuple[ix].template AsType<half2_t>()[iy / I2];
                const auto& x_s2_1 = vx_tuple[ix + I1].template AsType<half2_t>()[iy / I2];

                // reference to 2 bhalf2_t data from vy_tuple
                auto& y_s2_0 = vy_tuple(iy).template AsType<bhalf2_t>()(ix / I2);
                auto& y_s2_1 = vy_tuple(iy + I1).template AsType<bhalf2_t>()(ix / I2);

                // transpose
                transpose_half_to_bhalf_2x2(x_s2_0, x_s2_1, y_s2_0, y_s2_1);
            });
        });
    }
};

// transpose fp16 2x2
__device__ void transpose_fp16_2x2(const half2_t& x0, const half2_t& x1, half2_t& y0, half2_t& y1)
{
...
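Logically, transpose_half_to_bhalf_2x2 above takes two packed fp16 pairs x0 = {a, b} and x1 = {c, d}, converts each element to bf16 with the same truncating conversion, and writes the transposed pairs y0 = {a, c} and y1 = {b, d}; this is also what the disabled #if 0 reference path computes. A plain scalar model of just the data movement follows, with elements represented as float purely for illustration.

#include <array>

// Scalar model of the fused 2x2 transpose in transpose_half_to_bhalf_2x2. Elements are
// modeled as float here; in the device code each element is additionally narrowed from
// fp16 to bf16 while it is moved.
using Vec2 = std::array<float, 2>;

inline void transpose_2x2_model(const Vec2& x0, const Vec2& x1, Vec2& y0, Vec2& y1)
{
    y0 = {x0[0], x1[0]}; // first lane of each input forms output vector 0
    y1 = {x0[1], x1[1]}; // second lane of each input forms output vector 1
}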
library/include/ck/library/host_tensor/host_tensor.hpp  (+1 −1)

@@ -35,7 +35,7 @@ std::ostream& LogRangeAsType(std::ostream& os, Range&& range, std::string delim)
            first = false;
        else
            os << delim;
-       os << static_cast<T>(v);
+       os << ck::type_convert<T>(v);
    }
    return os;
}
...
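The one-line change above swaps static_cast for ck::type_convert in the logging helper, presumably because bhalf_t is stored as a raw 16-bit payload in CK: a plain static_cast<float> would print the integer bit pattern, while type_convert decodes the encoded bf16 value. A small standalone illustration of the difference, assuming a uint16_t-like bf16 storage type as in CK:

#include <cstdint>
#include <cstring>
#include <iostream>

// Decode a raw bf16 payload by placing its bits in the high half of a float.
// This is the kind of decode ck::type_convert<float>(bhalf_t) performs.
static float decode_bf16(std::uint16_t b)
{
    std::uint32_t u = static_cast<std::uint32_t>(b) << 16;
    float f;
    std::memcpy(&f, &u, sizeof(f));
    return f;
}

int main()
{
    std::uint16_t bf16_bits = 0x3F80;                   // bf16 encoding of 1.0f
    std::cout << static_cast<float>(bf16_bits) << "\n"; // prints 16256 (just the raw bits)
    std::cout << decode_bf16(bf16_bits) << "\n";        // prints 1
}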
test/CMakeLists.txt  (+1 −0)

@@ -47,3 +47,4 @@ add_subdirectory(convnd_fwd)
add_subdirectory(reduce)
add_subdirectory(conv2d_bwd_weight)
# DONOT add client_app, that is tested via CI independently
+add_subdirectory(fp16_transfer_bf16)
test/fp16_transfer_bf16/CMakeLists.txt  (new file, 0 → 100644)  (+2 −0)

add_test_executable(test_fp16_transfer_bf16 fp16_transfer_bf16.cpp)
target_link_libraries(test_fp16_transfer_bf16 PRIVATE host_tensor)
test/fp16_transfer_bf16/fp16_transfer_bf16.cpp  (new file, 0 → 100644)  (+140 −0)

#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <half.hpp>
#include "check_err.hpp"
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "device_tensor.hpp"
#include "check_err.hpp"
#include "transpose_vectors.hpp"
#include "common_header.hpp"

using SrcDataType = ck::half_t;
using DstDataType = ck::bhalf_t;

__global__ void gpu_convert_data(SrcDataType* in, DstDataType* out, int size)
{
    using namespace ck;

    ck::index_t num = blockIdx.x * blockDim.x + threadIdx.x * 2;

    const auto src_buf = ck::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(in, size);
    auto dst_buf       = ck::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(out, size);

    auto src_data = src_buf.template Get<ck::half2_t>(num, true);
    ck::bhalf2_t dst_data;
    convert_half2_to_bhalf2(src_data, dst_data);
    dst_buf.template Set<ck::bhalf2_t>(num, true, dst_data);
}

__global__ void
gpu_transpose_convert_data(SrcDataType* in, DstDataType* out, const int size, const int stride)
{
    using namespace ck;

    ck::index_t num = blockIdx.x * blockDim.x + threadIdx.x * 2;

    const auto src_buf = ck::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(in, size);
    auto dst_buf       = ck::make_dynamic_buffer<ck::AddressSpaceEnum::Global>(out, size);

    int x    = num % stride;
    int y    = num / stride;
    int num1 = (y + 1) * stride + x;

    auto src_data0 = src_buf.template Get<ck::half2_t>(num, true);
    auto src_data1 = src_buf.template Get<ck::half2_t>(num1, true);

    ck::bhalf2_t dst_data0, dst_data1;
    transpose_half_to_bhalf_2x2(src_data0, src_data1, dst_data0, dst_data1);

    // rewrite
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    const vector_type<bhalf_t, 2> vx0{dst_data0}, vx1{dst_data1};
    vector_type<bhalf_t, 2> vy0, vy1;

    vy0.template AsType<bhalf_t>()(I0) = vx0.template AsType<bhalf_t>()[I0];
    vy0.template AsType<bhalf_t>()(I1) = vx1.template AsType<bhalf_t>()[I0];
    vy1.template AsType<bhalf_t>()(I0) = vx0.template AsType<bhalf_t>()[I1];
    vy1.template AsType<bhalf_t>()(I1) = vx1.template AsType<bhalf_t>()[I1];

    dst_buf.template Set<ck::bhalf2_t>(num, true, vy0.template AsType<ck::bhalf2_t>()[I0]);
    dst_buf.template Set<ck::bhalf2_t>(num1, true, vy1.template AsType<ck::bhalf2_t>()[I0]);
}

void host_convert_data(SrcDataType* in, DstDataType* out, size_t len)
{
    for(int i = 0; i < len; i++)
    {
        out[i] = ck::type_convert<ck::bhalf_t, ck::half_t>(in[i]);
    }
}

int main(int, char*[])
{
    bool pass = true;

    constexpr int N          = 4;
    constexpr int K          = 4;
    constexpr int size       = N * K;
    constexpr int thread_num = size / 2;

    // create tensor
    Tensor<SrcDataType> src_n_k_host(
        HostTensorDescriptor(std::vector<std::size_t>({N, K}), std::vector<std::size_t>({K, 1})));
    Tensor<DstDataType> dst_n_k_host_result(
        HostTensorDescriptor(std::vector<std::size_t>({N, K}), std::vector<std::size_t>({K, 1})));
    Tensor<DstDataType> dst_n_k_device_result(
        HostTensorDescriptor(std::vector<std::size_t>({N, K}), std::vector<std::size_t>({K, 1})));

    // init data
    src_n_k_host.GenerateTensorValue(GeneratorTensor_3<SrcDataType>{-5, 5});
    dst_n_k_host_result.GenerateTensorValue(GeneratorTensor_1<DstDataType>{0});
    dst_n_k_device_result.GenerateTensorValue(GeneratorTensor_1<DstDataType>{0});

    // alloc gpu memory
    DeviceMem in_dev_buf(sizeof(SrcDataType) * src_n_k_host.mDesc.GetElementSpace());
    DeviceMem out_dev_buf(sizeof(DstDataType) * dst_n_k_host_result.mDesc.GetElementSpace());

    // init gpu memory
    in_dev_buf.ToDevice(src_n_k_host.mData.data());
    out_dev_buf.SetZero();

    // run cpu data convert
    host_convert_data(src_n_k_host.mData.data(), dst_n_k_host_result.mData.data(), size);

    // run kernel to convert data
    gpu_convert_data<<<1, thread_num>>>(static_cast<SrcDataType*>(in_dev_buf.GetDeviceBuffer()),
                                        static_cast<DstDataType*>(out_dev_buf.GetDeviceBuffer()),
                                        src_n_k_host.mDesc.GetElementSpace());

    // read from gpu
    out_dev_buf.FromDevice(dst_n_k_device_result.mData.data());
    pass = ck::utils::check_err(dst_n_k_device_result.mData, dst_n_k_host_result.mData);

    // run kernel to transpose and convert data
    gpu_transpose_convert_data<<<1, thread_num / 2>>>(
        static_cast<SrcDataType*>(in_dev_buf.GetDeviceBuffer()),
        static_cast<DstDataType*>(out_dev_buf.GetDeviceBuffer()),
        src_n_k_host.mDesc.GetElementSpace(),
        K);

    // read from gpu
    out_dev_buf.FromDevice(dst_n_k_device_result.mData.data());
    pass &= ck::utils::check_err(dst_n_k_device_result.mData, dst_n_k_host_result.mData);

#if 1
    LogRangeAsType<float>(std::cout << "in : ", src_n_k_host.mData, ",") << std::endl;
    LogRangeAsType<float>(std::cout << "out device: ", dst_n_k_device_result.mData, ",") << std::endl;
    LogRangeAsType<float>(std::cout << "out host: ", dst_n_k_host_result.mData, ",") << std::endl;
#endif

    if(pass)
    {
        std::cout << "fp16 transfer to bf16: Pass" << std::endl;
        return 0;
    }
    else
    {
        std::cout << "fp16 transfer to bf16: Fail" << std::endl;
        return -1;
    }
}