Project: gaoqiong/composable_kernel

Commit a8629a98 (unverified)
Authored Sep 26, 2023 by zjing14; committed via GitHub on Sep 26, 2023.
Parents: 8dc713ea, 94bfa502

    Merge branch 'develop' into gemm_v2r3_kpad_fix

Changes: 334
Showing 20 changed files with 1787 additions and 372 deletions (+1787, -372).
Changed files:
  include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp  +38  -5
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp  +16  -55
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp  +702  -0
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp  +2  -0
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp  +18  -28
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp  +2  -1
  include/ck/tensor_operation/gpu/grid/gridwise_image_to_column.hpp  +97  -0
  include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl_dpp8.hpp  +0  -136
  include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp  +4  -5
  include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp  +538  -0
  include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp  +9  -1
  include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp  +18  -15
  include/ck/utility/amd_buffer_addressing.hpp  +16  -2
  include/ck/utility/amd_gemm_dpp.hpp  +51  -5
  include/ck/utility/amd_xdlops.hpp  +2  -0
  include/ck/utility/data_type.hpp  +123  -8
  include/ck/utility/f8_utils.hpp  +109  -111
  include/ck/utility/inner_product.hpp  +12  -0
  include/ck/utility/inner_product_dpp8.hpp  +4  -0
  include/ck/utility/loop_scheduler.hpp  +26  -0
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp (+38, -5)

@@ -27,6 +27,12 @@ struct PassThrough
         y = x;
     }

+    template <>
+    __host__ __device__ void operator()<float, double>(float& y, const double& x) const
+    {
+        y = type_convert<float>(x);
+    }
+
     template <>
     __host__ __device__ void operator()<float, float>(float& y, const float& x) const
     {

@@ -69,18 +75,36 @@ struct PassThrough
         y = type_convert<bhalf_t>(x);
     }

+    template <>
+    __host__ __device__ void operator()<float, half_t>(float& y, const half_t& x) const
+    {
+        y = type_convert<float>(x);
+    }
+
     template <>
     __host__ __device__ void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
     {
         y = x;
     }

+    template <>
+    __host__ __device__ void operator()<half_t, int8_t>(half_t& y, const int8_t& x) const
+    {
+        y = type_convert<half_t>(x);
+    }
+
     template <>
     __host__ __device__ void operator()<int8_t, int32_t>(int8_t& y, const int32_t& x) const
     {
         y = type_convert<int8_t>(x);
     }

+    template <>
+    __host__ __device__ void operator()<int8_t, float>(int8_t& y, const float& x) const
+    {
+        y = type_convert<int8_t>(x);
+    }
+
 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
     template <>
     __host__ __device__ void operator()<int4_t, int4_t>(int4_t& y, const int4_t& x) const

@@ -89,6 +113,7 @@ struct PassThrough
     }
 #endif

+#if defined CK_ENABLE_FP8
     template <>
     __host__ __device__ void operator()<f8_t, f8_t>(f8_t& y, const f8_t& x) const
     {

@@ -118,6 +143,7 @@ struct PassThrough
     {
         y = type_convert<f8_t>(x);
     }
+#endif
 };

 struct UnaryConvert

@@ -146,6 +172,7 @@ struct ConvertBF16RTN
     }
 };

+#if defined CK_ENABLE_FP8
 struct ConvertF8SR
 {
     // convert to fp8 using stochastic rounding (SR)

@@ -162,6 +189,7 @@ struct ConvertF8SR
         y = f8_convert_sr<Y>(x);
     }
 };
+#endif

 struct Scale
 {

@@ -412,14 +440,19 @@ struct Swish
 {
     Swish(float beta = 1.0f) : beta_(beta) {}

-    template <typename T>
-    __host__ __device__ void operator()(T& y, const T& x) const
+    template <typename Y, typename X>
+    __host__ __device__ void operator()(Y& y, const X& x) const
     {
-        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
-                          is_same<T, ck::half_t>::value,
+        static_assert(is_same<X, float>::value || is_same<X, double>::value ||
+                          is_same<X, ck::half_t>::value,
+                      "Data type is not supported by this operation!");
+        static_assert(is_same<Y, float>::value || is_same<Y, double>::value ||
+                          is_same<Y, ck::half_t>::value,
                       "Data type is not supported by this operation!");

-        y = x / (ck::type_convert<T>(1) + ck::math::exp(-beta_ * x));
+        float bx = -beta_ * type_convert<float>(x);
+        y        = type_convert<Y>(x / (1.f + ck::math::exp(bx)));
    };

     float beta_ = 1.0f;
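The Swish hunk above changes both the signature (separate output type Y and input type X, in line with the mixed-type PassThrough specializations added earlier in this file) and the arithmetic: the exponent is now evaluated in float before converting to Y. As background only, a minimal standalone sketch of the same formula, swish(x) = x * sigmoid(beta * x) = x / (1 + exp(-beta * x)), in plain C++ with no CK types:

#include <cmath>
#include <cstdio>

// Standalone reference for the Swish activation used above (no CK types).
// swish(x) = x * sigmoid(beta * x) = x / (1 + exp(-beta * x))
static float swish_ref(float x, float beta = 1.0f)
{
    const float bx = -beta * x;       // exponent computed in float, as in the new operator()
    return x / (1.0f + std::exp(bx));
}

int main()
{
    const float xs[] = {-2.0f, -0.5f, 0.0f, 0.5f, 2.0f};
    for(float x : xs)
    {
        std::printf("swish(% .1f) = % .6f\n", x, swish_ref(x));
    }
    return 0;
}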
include/ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp (+16, -55)

@@ -7,11 +7,9 @@
 #include "ck/tensor_description/multi_index_transform_helper.hpp"
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
-#include "ck/tensor_operation/gpu/device/gemm_dl_algorithm.hpp"
 #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp"
 #include "ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp"
-#include "ck/tensor_operation/gpu/block/blockwise_gemm_dl_dpp8.hpp"
 #include "ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp"
 #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
 #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_set.hpp"

@@ -19,8 +17,6 @@
 namespace ck {

-using GemmDlAlgorithm = tensor_operation::device::GemmDlAlgorithm;
-
 template <typename GridwiseGemm,
           typename FloatAB,
           typename FloatC,

@@ -29,8 +25,7 @@ template <typename GridwiseGemm,
           typename CGridDesc_M0_M10_M11_N0_N10_N11,
           typename Block2CTileMap,
           bool HasMainKBlockLoop,
-          bool HasDoubleTailKBlockLoop,
-          GemmDlAlgorithm GemmDlAlg = GemmDlAlgorithm::Default>
+          bool HasDoubleTailKBlockLoop>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)

@@ -43,13 +38,6 @@ __global__ void
     const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11,
     const Block2CTileMap block_2_ctile_map)
 {
-    // DPP8 is currently only supported on gfx1030
-#if !defined(__gfx1030__)
-    if(GemmDlAlg == GemmDlAlgorithm::Dpp8)
-    {
-        return;
-    }
-#endif
     constexpr index_t shared_block_size =
         GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);

@@ -100,8 +88,7 @@ template <index_t BlockSize,
           typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
           typename CThreadTransferSrcDstAccessOrder,
           index_t CThreadTransferSrcDstVectorDim,
-          index_t CThreadTransferDstScalarPerVector,
-          GemmDlAlgorithm GemmDlAlg = GemmDlAlgorithm::Default>
+          index_t CThreadTransferDstScalarPerVector>
 struct GridwiseGemmDl_km_kn_mn_v1r3
 {
     static constexpr auto I0 = Number<0>{};

@@ -257,45 +244,6 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
             c_grid_desc_m_n);
     }

-    template <typename ABlockDesc_BK0_BM_BK1, typename BBlockDesc_BK0_BN_BK1>
-    __host__ __device__ static constexpr auto GetBlockwiseGemm()
-    {
-        if constexpr(GemmDlAlg == GemmDlAlgorithm::Dpp8)
-        {
-            return BlockwiseGemmDlDpp8_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_loop_BM0_BN0<
-                BlockSize, FloatAB, FloatAB, FloatAcc,
-                ABlockDesc_BK0_BM_BK1, BBlockDesc_BK0_BN_BK1,
-                M1PerThreadM111, N1PerThreadN111, KPerThread,
-                M11N11ThreadClusterM110Xs, M11N11ThreadClusterN110Xs,
-                M1PerThreadM111, N1PerThreadN111>{};
-        }
-        else
-        {
-            return BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
-                BlockSize, FloatAB, FloatAB, FloatAcc,
-                ABlockDesc_BK0_BM_BK1, BBlockDesc_BK0_BN_BK1,
-                M1PerThreadM111, N1PerThreadN111, KPerThread,
-                M11N11ThreadClusterM110Xs, M11N11ThreadClusterN110Xs,
-                M1PerThreadM111, N1PerThreadN111>{};
-        }
-    }
-
     using AGridDesc_K0_M0_M1_K1 = decltype(MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{}));
     using BGridDesc_K0_N0_N1_K1 = decltype(MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{}));

     using CGridDesc_M0_M10_M11_N0_N10_N11 =

@@ -424,7 +372,20 @@ struct GridwiseGemmDl_km_kn_mn_v1r3
         //   c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
         //     register
         const auto blockwise_gemm =
-            GetBlockwiseGemm<decltype(a_k0_m_k1_block_desc), decltype(b_k0_n_k1_block_desc)>();
+            BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
+                BlockSize,
+                FloatAB,
+                FloatAB,
+                FloatAcc,
+                decltype(a_k0_m_k1_block_desc),
+                decltype(b_k0_n_k1_block_desc),
+                M1PerThreadM111,
+                N1PerThreadN111,
+                KPerThread,
+                M11N11ThreadClusterM110Xs,
+                M11N11ThreadClusterN110Xs,
+                M1PerThreadM111,
+                N1PerThreadN111>{};

         constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
             decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();
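The removed GetBlockwiseGemm() relied on a C++17 idiom: inside a function with a deduced return type, only the taken `if constexpr` branch is instantiated, so the two branches may return objects of different types. A minimal standalone sketch of just that technique, with hypothetical stand-in types (not CK classes):

#include <cstdio>
#include <type_traits>

// Stand-ins for two different blockwise-GEMM implementations (hypothetical, not CK types).
struct DefaultGemmImpl { static constexpr const char* name = "default"; };
struct Dpp8GemmImpl    { static constexpr const char* name = "dpp8"; };

enum class Algorithm { Default, Dpp8 };

// `auto` return + `if constexpr`: each instantiation returns exactly one type.
template <Algorithm Alg>
constexpr auto MakeGemmImpl()
{
    if constexpr(Alg == Algorithm::Dpp8) { return Dpp8GemmImpl{}; }
    else                                 { return DefaultGemmImpl{}; }
}

int main()
{
    auto a = MakeGemmImpl<Algorithm::Default>();
    auto b = MakeGemmImpl<Algorithm::Dpp8>();
    static_assert(!std::is_same_v<decltype(a), decltype(b)>, "branches yield distinct types");
    std::printf("%s / %s\n", a.name, b.name);
    return 0;
}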
include/ck/tensor_operation/gpu/grid/gridwise_gemm_dpp.hpp (new file, mode 0 → 100644, +702)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {

template <typename GridwiseGemm, bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
#if CK_USE_WAVES_PER_EU
    __attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU)))
#endif
        kernel_gemm_dpp(const typename GridwiseGemm::Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1030__) || defined(__gfx1100__) || \
    defined(__gfx1101__) || defined(__gfx1102__))
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    const auto a_grid_desc_ak0_m_ak1 = amd_wave_read_first_lane(
        GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(karg.M, karg.K, karg.AK0, karg.StrideA));
    const auto b_grid_desc_bk0_n_bk1 = amd_wave_read_first_lane(
        GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(karg.K, karg.N, karg.BK0, karg.StrideB));
    const auto c_grid_desc_m_n = amd_wave_read_first_lane(
        GridwiseGemm::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC));

    GridwiseGemm::template Run<HasMainKBlockLoop>(karg.p_a_grid,
                                                  karg.p_b_grid,
                                                  karg.p_c_grid,
                                                  p_shared,
                                                  a_grid_desc_ak0_m_ak1,
                                                  b_grid_desc_bk0_n_bk1,
                                                  c_grid_desc_m_n);
#else
    ignore = karg;
#endif
}

template <index_t BlockSize,
          typename ABDataType, typename AccDataType, typename CDataType,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          typename ALayout, typename BLayout, typename CLayout,
          typename AElementwiseOperation, typename BElementwiseOperation,
          typename CElementwiseOperation,
          tensor_operation::device::GemmSpecialization GemmSpec,
          index_t MPerBlock, index_t NPerBlock, index_t KPerBlock,
          index_t MPerDpp, index_t NPerDpp,
          index_t AK1Value, index_t BK1Value,
          index_t MDppPerWave, index_t NDppPerWave,
          typename ABlockTransferThreadClusterLengths_K0_M_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_K1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          bool ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_K0_N_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_K1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          bool BBlockLdsExtraN,
          typename CThreadTransferSrcDstAccessOrder,
          index_t CThreadTransferSrcDstVectorDim,
          index_t CThreadTransferDstScalarPerVector,
          index_t NumGemmKPrefetchStage = 1,
          PipelineVersion PipelineVer   = PipelineVersion::v1>
struct GridwiseGemm_ak0mak1_bk0nbk1_mn_dpp
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};

    static constexpr auto AK1         = Number<AK1Value>{};
    static constexpr auto BK1         = Number<BK1Value>{};
    static constexpr auto AK0PerBlock = Number<KPerBlock / AK1Value>{};
    static constexpr auto BK0PerBlock = Number<KPerBlock / BK1Value>{};

    static constexpr auto max_lds_align = math::lcm(AK1, BK1);

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;

    // return block_id to C matrix tile idx (m0, n0) mapping
    using Block2CTileMap = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;

    __host__ static auto CalculateGridSize(index_t M, index_t N)
    {
        return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1);
    }

    __host__ static auto CalculateMPadded(index_t M)
    {
        return math::integer_divide_ceil(M, MPerBlock) * MPerBlock;
    }

    __host__ static auto CalculateNPadded(index_t N)
    {
        return math::integer_divide_ceil(N, NPerBlock) * NPerBlock;
    }

    __host__ static auto CalculateAK0(index_t K) { return math::integer_divide_floor(K, AK1Value); }
    __host__ static auto CalculateBK0(index_t K) { return math::integer_divide_floor(K, BK1Value); }

    // Argument
    struct Problem
    {
        __host__ Problem(index_t M_, index_t N_, index_t K_,
                         index_t StrideA_, index_t StrideB_, index_t StrideC_)
            : M{M_}, N{N_}, K{K_},
              StrideA{StrideA_}, StrideB{StrideB_}, StrideC{StrideC_},
              MPadded{CalculateMPadded(M_)}, NPadded{CalculateNPadded(N_)},
              AK0{CalculateAK0(K)}, BK0{CalculateBK0(K)}
        {
        }

        __host__ void Print() const
        {
            std::cout << "problem {"
                      << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
                      << "SA:" << StrideA << ", " << "SB:" << StrideB << ", "
                      << "SC:" << StrideC << ", " << "MP:" << MPadded << ", "
                      << "NP:" << NPadded << ", " << "AK0:" << AK0 << ", "
                      << "BK0:" << BK0 << "}" << std::endl;
        }

        index_t M;
        index_t N;
        index_t K;
        index_t StrideA;
        index_t StrideB;
        index_t StrideC;
        index_t MPadded;
        index_t NPadded;
        index_t AK0;
        index_t BK0;
    };

    // Argument
    struct Argument : public Problem, public tensor_operation::device::BaseArgument
    {
        __host__ Argument(const ABDataType* p_a_grid_,
                          const ABDataType* p_b_grid_,
                          CDataType* p_c_grid_,
                          index_t M_, index_t N_, index_t K_,
                          index_t StrideA_, index_t StrideB_, index_t StrideC_)
            : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_},
              p_a_grid{p_a_grid_}, p_b_grid{p_b_grid_}, p_c_grid{p_c_grid_}
        {
        }

        const ABDataType* p_a_grid;
        const ABDataType* p_b_grid;
        CDataType* p_c_grid;
    };

    using GridwiseGemmPipe = remove_cvref_t<
        decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;

    __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
    {
        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_block_desc_ak0_m_ak1 = [&]() {
            if constexpr(ABlockLdsExtraM)
            {
                return make_naive_tensor_descriptor(
                    make_tuple(Number<AK0PerBlock>{}, Number<MPerBlock>{}, AK1),
                    make_tuple(Number<MPerBlock + 1>{} * AK1, AK1, I1));
            }
            else
            {
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(Number<AK0PerBlock>{}, Number<MPerBlock>{}, AK1), max_lds_align);
            }
        }();

        return a_block_desc_ak0_m_ak1;
    }

    __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
    {
        // B matrix in LDS memory, dst of blockwise copy
        constexpr auto b_block_desc_bk0_n_bk1 = [&]() {
            if constexpr(BBlockLdsExtraN)
            {
                return make_naive_tensor_descriptor(
                    make_tuple(Number<BK0PerBlock>{}, Number<NPerBlock>{}, BK1),
                    make_tuple(Number<NPerBlock + 1>{} * BK1, BK1, I1));
            }
            else
            {
                return make_naive_tensor_descriptor_aligned(
                    make_tuple(Number<BK0PerBlock>{}, Number<NPerBlock>{}, BK1), max_lds_align);
            }
        }();

        return b_block_desc_bk0_n_bk1;
    }

    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
        constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
        constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
            b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);

        return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(ABDataType);
    }

    __host__ static constexpr bool CheckValidity(const Problem& problem)
    {
        static_assert(is_known_at_compile_time<remove_cv_t<decltype(AK1)>>::value,
                      "Wrong! AK1 must be known at the time of compilation.");
        static_assert(is_known_at_compile_time<remove_cv_t<decltype(BK1)>>::value,
                      "Wrong! BK1 must be known at the time of compilation.");

        static_assert(
            MPerBlock % (MPerDpp * MDppPerWave) == 0,
            "Invalid tuning parameters! MPerBlock must be divisible by MPerDpp * MDppPerWave.");
        static_assert(
            NPerBlock % (NPerDpp * NDppPerWave) == 0,
            "Invalid tuning parameters! NPerBlock must be divisible by NPerDpp * NDppPerWave.");

        static_assert(
            KPerBlock % AK1Value == 0 && KPerBlock % BK1Value == 0,
            "Invalid tuning parameters! KPerBlock must be divisible by both AK1 and BK1.");

        static_assert(AK1Value % ABlockTransferDstScalarPerVector_K1 == 0,
                      "Invalid tuning parameters! AK1Value must be divisible by "
                      "ABlockTransferDstScalarPerVector_K1");
        static_assert(BK1Value % BBlockTransferDstScalarPerVector_K1 == 0,
                      "Invalid tuning parameters! BK1Value must be divisible by "
                      "BBlockTransferDstScalarPerVector_K1");

        if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
        {
            if(!(problem.M % MPerBlock == 0))
            {
                return false;
            }
        }

        if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
                       GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
        {
            if(!(problem.N % NPerBlock == 0))
            {
                return false;
            }
        }

        if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
        {
            if(problem.K % ABlockTransferSrcScalarPerVector != 0)
            {
                return false;
            }
        }
        else
        {
            if(problem.M % ABlockTransferSrcScalarPerVector != 0)
            {
                return false;
            }
        }

        if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
        {
            if(problem.N % BBlockTransferSrcScalarPerVector != 0)
            {
                return false;
            }
        }
        else
        {
            if(problem.K % BBlockTransferSrcScalarPerVector != 0)
            {
                return false;
            }
        }

        if(problem.K % KPerBlock != 0)
        {
            return false;
        }

        // check gridwise gemm pipeline
        const auto num_k_loop = problem.K / KPerBlock;
        if(!GridwiseGemmPipe::IsSupported(num_k_loop))
        {
            return false;
        }

        return true;
    }

    __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
    {
        const auto num_loop = K / KPerBlock;
        return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
    }

    template <typename CGridDesc>
    __host__ __device__ static constexpr auto
    MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(const CGridDesc& c_grid_desc_m_n)
    {
        constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
        constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        constexpr index_t KPack = math::max(
            math::lcm(AK1, BK1), DppSelector<ABDataType, MPerDpp, NPerDpp>::selected_dpp.k_per_dpp);

        using BlockwiseGemm =
            BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2<BlockSize, ABDataType, AccDataType,
                                                          decltype(a_block_desc_ak0_m_ak1),
                                                          decltype(b_block_desc_bk0_n_bk1),
                                                          MPerDpp, NPerDpp, MDppPerWave,
                                                          NDppPerWave, KPack>;

        return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(c_grid_desc_m_n);
    }

    static constexpr auto matrix_padder =
        ck::tensor_operation::device::MatrixPadder<GemmSpec, index_t, index_t, index_t>{
            MPerBlock, NPerBlock, KPerBlock};

    __device__ static auto
    MakeAGridDescriptor_AK0_M_AK1(index_t M, index_t K, index_t AK0, index_t StrideA)
    {
        const auto a_grid_desc_mraw_kraw = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
            }
        }();

        const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);

        return transform_tensor_descriptor(
            a_grid_desc_m_k,
            make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
                       make_pass_through_transform(M)),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
    }

    __device__ static auto
    MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t N, index_t BK0, index_t StrideB)
    {
        const auto b_grid_desc_nraw_kraw = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
            }
        }();

        const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);

        return transform_tensor_descriptor(
            b_grid_desc_n_k,
            make_tuple(make_pass_through_transform(N),
                       make_unmerge_transform(make_tuple(BK0, BK1Value))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
    }

    __device__ static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
    {
        const auto c_grid_desc_mraw_nraw = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
            }
        }();

        return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw);
    }

    template <bool HasMainKBlockLoop,
              typename AGridDesc_AK0_M_AK1,
              typename BGridDesc_BK0_N_BK1,
              typename CGridDesc_M_N>
    __device__ static void Run(const ABDataType* __restrict__ p_a_grid,
                               const ABDataType* __restrict__ p_b_grid,
                               CDataType* __restrict__ p_c_grid,
                               void* __restrict__ p_shared,
                               const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
                               const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
                               const CGridDesc_M_N& c_grid_desc_m_n)
    {
        const auto c_grid_desc_m0_n0_m1_n1_m2_n2 =
            MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(c_grid_desc_m_n);

        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_n2.GetElementSpaceSize());

        const AElementwiseOperation a_element_op{};
        const BElementwiseOperation b_element_op{};
        const CElementwiseOperation c_element_op{};

        const auto block_2_ctile_map =
            Block2CTileMap{c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)};

        // divide block work by [M, N]
        const auto block_work_idx =
            block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        if(!block_2_ctile_map.ValidCTileIndex(
               block_work_idx,
               make_tuple(c_grid_desc_m0_n0_m1_n1_m2_n2.GetLength(I0),
                          c_grid_desc_m0_n0_m1_n1_m2_n2.GetLength(I1))))
        {
            return;
        }

        // HACK: this force m/n_block_data_idx_on_grid into SGPR
        const index_t m_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);

        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();

        // B matrix in LDS memory, dst of blockwise copy
        constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
            ThisThreadBlock,
            AElementwiseOperation,
            ck::tensor_operation::element_wise::PassThrough,
            InMemoryDataOperationEnum::Set,
            Sequence<AK0PerBlock, MPerBlock, AK1>,
            ABlockTransferThreadClusterLengths_K0_M_K1,
            ABlockTransferThreadClusterArrangeOrder,
            ABDataType, ABDataType,
            decltype(a_grid_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
            ABlockTransferSrcAccessOrder,
            Sequence<1, 0, 2>,
            ABlockTransferSrcVectorDim, 2,
            ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1,
            1, 1,
            AThreadTransferSrcResetCoordinateAfterRun, true,
            NumGemmKPrefetchStage>(a_grid_desc_ak0_m_ak1,
                                   make_multi_index(0, m_block_data_idx_on_grid, 0),
                                   a_element_op,
                                   a_block_desc_ak0_m_ak1,
                                   make_multi_index(0, 0, 0),
                                   ck::tensor_operation::element_wise::PassThrough{});

        auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
            ThisThreadBlock,
            BElementwiseOperation,
            ck::tensor_operation::element_wise::PassThrough,
            InMemoryDataOperationEnum::Set,
            Sequence<BK0PerBlock, NPerBlock, BK1>,
            BBlockTransferThreadClusterLengths_K0_N_K1,
            BBlockTransferThreadClusterArrangeOrder,
            ABDataType, ABDataType,
            decltype(b_grid_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
            BBlockTransferSrcAccessOrder,
            Sequence<1, 0, 2>,
            BBlockTransferSrcVectorDim, 2,
            BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1,
            1, 1,
            BThreadTransferSrcResetCoordinateAfterRun, true,
            NumGemmKPrefetchStage>(b_grid_desc_bk0_n_bk1,
                                   make_multi_index(0, n_block_data_idx_on_grid, 0),
                                   b_element_op,
                                   b_block_desc_bk0_n_bk1,
                                   make_multi_index(0, 0, 0),
                                   ck::tensor_operation::element_wise::PassThrough{});

        // GEMM definition
        //   c_mtx += transpose(a_mtx) * b_mtx
        //     a_mtx[AK0PerBlock, MPerBlock] is in LDS
        //     b_mtx[BK0PerBlock, NPerBlock] is in LDS
        //     c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
        //       register
        constexpr index_t KPack = math::max(
            math::lcm(AK1, BK1), DppSelector<ABDataType, MPerDpp, NPerDpp>::selected_dpp.k_per_dpp);

        auto blockwise_gemm =
            BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2<BlockSize, ABDataType, AccDataType,
                                                          decltype(a_block_desc_ak0_m_ak1),
                                                          decltype(b_block_desc_bk0_n_bk1),
                                                          MPerDpp, NPerDpp, MDppPerWave,
                                                          NDppPerWave, KPack>();

        auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();

        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ABDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());

        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ABDataType*>(p_shared) + a_block_space_size_aligned,
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());

        constexpr auto a_block_slice_copy_step = make_multi_index(AK0PerBlock, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(BK0PerBlock, 0, 0);

        // gridwise GEMM pipeline
        const auto AK0 = a_grid_desc_ak0_m_ak1.GetLength(I0);
        // (AK0 / AK0PerBlock) is always equal to (BK0 / BK0PerBlock)
        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(AK0 / AK0PerBlock);

        GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
                                                          a_block_desc_ak0_m_ak1,
                                                          a_blockwise_copy,
                                                          a_grid_buf,
                                                          a_block_buf,
                                                          a_block_slice_copy_step,
                                                          b_grid_desc_bk0_n_bk1,
                                                          b_block_desc_bk0_n_bk1,
                                                          b_blockwise_copy,
                                                          b_grid_buf,
                                                          b_block_buf,
                                                          b_block_slice_copy_step,
                                                          blockwise_gemm,
                                                          c_thread_buf,
                                                          num_k_block_main_loop);

        // output: register to global memory
        {
            constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2 =
                blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2();
            constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
                blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2();

            constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
            constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
            constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
            constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
            constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I4);
            constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I5);

            constexpr auto MPerThread = c_thread_desc_m0_n0_m1_n1_m2_n2.GetLength(I4);
            constexpr auto NPerThread = c_thread_desc_m0_n0_m1_n1_m2_n2.GetLength(I5);

            // calculate origin of thread output tensor on global memory
            //     blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block =
                blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0);

            const index_t m_thread_data_on_grid =
                m_block_data_idx_on_grid + c_thread_mtx_on_block[I0];
            const index_t n_thread_data_on_grid =
                n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];

            const auto m_thread_data_on_grid_to_m0_m1_m2_adaptor =
                make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(M0, M1, M2))),
                    make_tuple(Sequence<0, 1, 2>{}),
                    make_tuple(Sequence<0>{}));

            const auto m_thread_data_on_grid_idx =
                m_thread_data_on_grid_to_m0_m1_m2_adaptor.CalculateBottomIndex(
                    make_multi_index(m_thread_data_on_grid));

            const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor =
                make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
                    make_tuple(Sequence<0, 1, 2>{}),
                    make_tuple(Sequence<0>{}));

            const auto n_thread_data_on_grid_idx =
                n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                    make_multi_index(n_thread_data_on_grid));

            auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3<
                AccDataType, CDataType,
                decltype(c_thread_desc_m0_n0_m1_n1_m2_n2), decltype(c_grid_desc_m0_n0_m1_n1_m2_n2),
                CElementwiseOperation,
                Sequence<M0, N0, I1, I1, MPerThread, NPerThread>,
                CThreadTransferSrcDstAccessOrder,
                CThreadTransferSrcDstVectorDim,
                CThreadTransferDstScalarPerVector,
                CGlobalMemoryDataOperation,
                1, true>{c_grid_desc_m0_n0_m1_n1_m2_n2,
                         make_multi_index(m_thread_data_on_grid_idx[I0],
                                          n_thread_data_on_grid_idx[I0],
                                          m_thread_data_on_grid_idx[I1],
                                          n_thread_data_on_grid_idx[I1],
                                          m_thread_data_on_grid_idx[I2],
                                          n_thread_data_on_grid_idx[I2]),
                         c_element_op};

            c_thread_copy.Run(c_thread_desc_m0_n0_m1_n1_m2_n2,
                              make_tuple(I0, I0, I0, I0, I0, I0),
                              c_thread_buf,
                              c_grid_desc_m0_n0_m1_n1_m2_n2,
                              c_grid_buf);
        }
    }
};

} // namespace ck
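The descriptor helpers above pad M and N up to whole block tiles and split K into (AK0, AK1) with AK0 = floor(K / AK1). As a point of reference only, a standalone sketch of that index arithmetic in plain C++ (the tile sizes below are example tuning parameters, not CK defaults, and no CK descriptor machinery is used):

#include <cstdio>

// Standalone sketch of the padding / K-split arithmetic used by the gridwise DPP GEMM above.
static int integer_divide_ceil(int a, int b) { return (a + b - 1) / b; }

int main()
{
    const int MPerBlock = 128, NPerBlock = 64, AK1 = 8; // hypothetical tile parameters

    const int M = 1000, N = 200, K = 72;                // hypothetical problem sizes

    const int MPadded = integer_divide_ceil(M, MPerBlock) * MPerBlock; // 1024
    const int NPadded = integer_divide_ceil(N, NPerBlock) * NPerBlock; // 256
    const int AK0     = K / AK1;                                       // floor(K / AK1) = 9

    // The A descriptor is then viewed as [AK0, M, AK1], i.e. K is unmerged into AK0 * AK1.
    std::printf("MPadded=%d NPadded=%d AK0=%d (AK0*AK1=%d of K=%d)\n",
                MPadded, NPadded, AK0, AK0 * AK1, K);
    return 0;
}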
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp (+2, -0)

@@ -268,6 +268,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
         static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
                           (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
                       "Invalid tuning param!");
+        static_assert(KPerBlock % AK1Value == 0 && KPerBlock % BK1Value == 0,
+                      "KPerBlock must be divisible by AK1Value and BK1Value!");

         const auto M = a_grid_desc_m_k.GetLength(I0);
         const auto N = b_grid_desc_n_k.GetLength(I0);
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp (+18, -28)

@@ -29,7 +29,9 @@ namespace ck {
 // E = cde_op(C, D0, D1, ...)
 // Assume:
 //   D0, D1, ... and E have the same layout
-template <typename ABDataType, // FIXME: don't assume A/B have same datatype
+template <typename ADataType,
+          typename BDataType,
+          typename ComputeType,
           typename AccDataType,
           typename CShuffleDataType,
           typename DsDataType,

@@ -96,17 +98,6 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
     using GridwiseGemmPipe = remove_cvref_t<
         decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;

-    // denorm test fix, required to work around fp16 mfma issue
-    // we convert fp16->fp32->bf16 and execute bf16 mfma instruction
-    // when mfma if fixed, remove this section and update
-    // ABDataTypeAdjusted -> ABDataType throughout this file
-#if CK_WORKAROUND_DENORM_FIX
-    using ABDataTypeAdjusted =
-        conditional_t<is_same_v<ABDataType, ck::half_t>, ck::bhalf_t, ABDataType>;
-#else
-    using ABDataTypeAdjusted = ABDataType;
-#endif
-
     __host__ __device__ static constexpr auto GetABlockDescriptor_KBatch_AK0PerBlock_MPerBlock_AK1()
     {
         // A matrix in LDS memory, dst of blockwise copy

@@ -196,7 +187,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
             c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();

         return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
-                             sizeof(ABDataType),
+                             sizeof(ComputeType),
                          c_block_size * sizeof(CShuffleDataType));
     }

@@ -401,8 +392,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
         // check tensor size: cannot be larger than 2GB each
         constexpr long_index_t TwoGB = (long_index_t{1} << 31);

-        if(!(a_grid_desc_kbatch_ak0_m_ak1.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB &&
-             b_grid_desc_kbatch_bk0_n_bk1.GetElementSpaceSize() * sizeof(ABDataType) <= TwoGB &&
+        if(!(a_grid_desc_kbatch_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
+             b_grid_desc_kbatch_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB &&
              e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
         {
             return false;

@@ -470,8 +461,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
               typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
               typename CDEElementwiseOperation_,
               typename Block2ETileMap>
-    __device__ static void Run(const ABDataType* __restrict__ p_a_grid,
-                               const ABDataType* __restrict__ p_b_grid,
+    __device__ static void Run(const ADataType* __restrict__ p_a_grid,
+                               const BDataType* __restrict__ p_b_grid,
                                DsGridPointer p_ds_grid,
                                EDataType* __restrict__ p_e_grid,
                                void* __restrict__ p_shared,

@@ -538,8 +529,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
             Sequence<1, AK0PerBlock, MPerBlock, AK1>,
             ABlockTransferThreadClusterLengths_KBatch_AK0_M_AK1,
             ABlockTransferThreadClusterArrangeOrder,
-            ABDataType,
-            ABDataTypeAdjusted,
+            ADataType,
+            ComputeType,
             decltype(a_grid_desc_kbatch_ak0_m_ak1),
             decltype(a_block_desc_kbatch_ak0_m_ak1),
             ABlockTransferSrcAccessOrder,

@@ -569,8 +560,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
             Sequence<1, BK0PerBlock, NPerBlock, BK1>,
             BBlockTransferThreadClusterLengths_KBatch_BK0_N_BK1,
             BBlockTransferThreadClusterArrangeOrder,
-            ABDataType,
-            ABDataTypeAdjusted,
+            BDataType,
+            ComputeType,
             decltype(b_grid_desc_kbatch_bk0_n_bk1),
             decltype(b_block_desc_kbatch_bk0_n_bk1),
             BBlockTransferSrcAccessOrder,

@@ -606,11 +597,11 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
         // sanity check
         constexpr index_t KPack =
             math::max(math::lcm(AK1, BK1),
-                      MfmaSelector<ABDataTypeAdjusted, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
+                      MfmaSelector<ComputeType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);

         auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
             BlockSize,
-            ABDataTypeAdjusted,
+            ComputeType,
             AccDataType,
             decltype(a_block_desc_ak0_m_ak1),
             decltype(b_block_desc_bk0_n_bk1),

@@ -683,11 +674,10 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
             a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

         auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<ABDataTypeAdjusted*>(p_shared),
+            static_cast<ComputeType*>(p_shared),
             a_block_desc_ak0_m_ak1.GetElementSpaceSize());

         auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<ABDataTypeAdjusted*>(p_shared) + a_block_space_size_aligned,
+            static_cast<ComputeType*>(p_shared) + a_block_space_size_aligned,
             b_block_desc_bk0_n_bk1.GetElementSpaceSize());

         constexpr auto a_block_slice_copy_step = make_multi_index(0, KPerBlock / AK1, 0, 0);

@@ -999,8 +989,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
                                const index_t KBatch,
                                const Block2ETileMap& block_2_etile_map)
     {
-        const auto p_a_grid = reinterpret_cast<const ABDataType*>(p_a_grid_);
-        const auto p_b_grid = reinterpret_cast<const ABDataType*>(p_b_grid_);
+        const auto p_a_grid = reinterpret_cast<const ADataType*>(p_a_grid_);
+        const auto p_b_grid = reinterpret_cast<const BDataType*>(p_b_grid_);
         const auto p_e_grid = reinterpret_cast<EDataType*>(p_e_grid_);

         using DsGridDesc_M_N =
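The per-tensor size check kept above requires each tensor's element space, in bytes, to stay below 2 GiB (long_index_t{1} << 31), so that 32-bit signed byte offsets cannot overflow. A standalone illustration of that arithmetic only (the element count below is an arbitrary example, not taken from the diff):

#include <cstdint>
#include <cstdio>

// Illustration of the 2 GiB per-tensor limit checked above.
int main()
{
    const int64_t TwoGB = int64_t{1} << 31; // 2147483648 bytes

    const int64_t element_space_size = int64_t{300} * 1024 * 1024; // hypothetical element count
    const int64_t bytes_fp16         = element_space_size * 2;     // 2 bytes per fp16 element

    std::printf("bytes=%lld, fits under 2GiB: %s\n",
                static_cast<long long>(bytes_fp16),
                (bytes_fp16 <= TwoGB) ? "yes" : "no");
    return 0;
}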
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp (+2, -1)

@@ -4,7 +4,8 @@
 #pragma once

 #include "ck/utility/common_header.hpp"
-#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
+#include "ck/utility/loop_scheduler.hpp"
+#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"

 namespace ck {
include/ck/tensor_operation/gpu/grid/gridwise_image_to_column.hpp (new file, mode 0 → 100644, +97)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {

template <typename InputGridDesc,
          typename InputDataType,
          typename OutputGridDesc,
          typename OutputDataType,
          index_t BlockSize,
          index_t MPerBlock,
          index_t KPerBlock,
          typename ThreadClusterLengths,
          index_t ScalarPerVector,
          typename Block2ETileMap>
struct GridwiseImageToColumn
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;

    __device__ static void Run(const InputGridDesc& in_grid_desc,
                               const InputDataType* __restrict__ p_in_global,
                               const OutputGridDesc& out_grid_desc,
                               OutputDataType* __restrict__ p_out_global,
                               const Block2ETileMap& block_2_tile_map)
    {
        const auto block_work_idx =
            block_2_tile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        const index_t m_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
        const index_t k_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * KPerBlock);

        // Global Memory
        const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_in_global, in_grid_desc.GetElementSpaceSize());
        auto out_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_out_global, out_grid_desc.GetElementSpaceSize());

        auto copy_global_to_global = ThreadGroupTensorSliceTransfer_v7<
            ThisThreadBlock,
            Tuple<InputDataType>,
            Tuple<OutputDataType>,
            decltype(tie(in_grid_desc)),
            decltype(tie(out_grid_desc)),
            tensor_operation::element_wise::PassThrough,
            Sequence<static_cast<index_t>(InMemoryDataOperationEnum::Set)>,
            Sequence<MPerBlock, KPerBlock>,
            ThreadClusterLengths,
            Sequence<0, 1>,
            Sequence<0, 1>,
            I1,
            ScalarPerVector,
            Sequence<true>,
            Sequence<true>>{
            in_grid_desc,
            make_tuple(make_multi_index(m_block_data_idx_on_grid, k_block_data_idx_on_grid)),
            out_grid_desc,
            make_tuple(make_multi_index(m_block_data_idx_on_grid, k_block_data_idx_on_grid)),
            tensor_operation::element_wise::PassThrough{}};

        copy_global_to_global.Run(
            tie(in_grid_desc), tie(in_global_buf), tie(out_grid_desc), tie(out_global_buf));
    }

    __host__ static constexpr bool CheckValidity(const InputGridDesc& in_grid_desc,
                                                 const OutputGridDesc& out_grid_desc)
    {
        if(in_grid_desc.GetLength(I0) % MPerBlock != 0 ||
           in_grid_desc.GetLength(I1) % KPerBlock != 0)
            return false;

        if(out_grid_desc.GetLength(I0) % MPerBlock != 0 ||
           out_grid_desc.GetLength(I1) % KPerBlock != 0)
            return false;

        return true;
    }
};

} // namespace ck
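The new GridwiseImageToColumn kernel is a tiled global-to-global copy; the image-to-column (im2col) layout itself is encoded in the input/output tensor descriptors passed to it. As conceptual background only, a minimal host-side sketch of im2col for one channel, stride 1, no padding (plain C++, hypothetical shapes, no CK descriptors):

#include <cstdio>
#include <vector>

// Host-side background sketch of im2col: each output row gathers the K x K window
// feeding one output pixel, so a convolution becomes a GEMM over these rows.
std::vector<float> im2col(const std::vector<float>& img, int H, int W, int K)
{
    const int Ho = H - K + 1, Wo = W - K + 1;
    std::vector<float> col(static_cast<size_t>(Ho) * Wo * K * K);
    for(int ho = 0; ho < Ho; ++ho)
        for(int wo = 0; wo < Wo; ++wo)
            for(int kh = 0; kh < K; ++kh)
                for(int kw = 0; kw < K; ++kw)
                    col[((ho * Wo + wo) * K + kh) * K + kw] = img[(ho + kh) * W + (wo + kw)];
    return col;
}

int main()
{
    const int H = 4, W = 4, K = 3; // hypothetical example sizes
    std::vector<float> img(H * W);
    for(int i = 0; i < H * W; ++i) img[i] = static_cast<float>(i);

    const auto col = im2col(img, H, W, K);
    std::printf("im2col rows=%d, cols=%d, col[0]=%g, col[1]=%g\n",
                (H - K + 1) * (W - K + 1), K * K, col[0], col[1]);
    return 0;
}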
include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl_dpp8.hpp (deleted, mode 100644 → 0, -136; shown at parent 8dc713ea)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/amd_gemm_dpp.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/utility/inner_product_dpp8.hpp"
#include "ck/utility/math.hpp"

namespace ck {

/**
 * Threadwise contraction using dot instructions with DPP8 modifier.
 *
 * Assumptions:
 *   1. `AThreadDesc_TK0_TM0_TM1_TK1`, `BThreadDesc_TK0_TN0_TN1_TK1`, `CThreadDesc_TM0_TM1_TN0_TN1`
 *      are known at compile-time;
 *   2. `AOriginIdx`, `BOriginIdx`, `COriginIdx` are known at compile-time;
 *   3. `TM0` is equal to 1 and `TN0` is equal to 1;
 *   4. When `ShareA` is set (unset, respectively), `TM1` (`TN1`, respectively) is divisible by
 *      the size of the lane group (`dpp8::lane_group_size`).
 */
template <typename FloatA,
          typename FloatB,
          typename FloatC,
          typename AThreadDesc_TK0_TM0_TM1_TK1,
          typename BThreadDesc_TK0_TN0_TN1_TK1,
          typename CThreadDesc_TM0_TM1_TN0_TN1,
          typename TKLengths,
          typename TMLengths,
          typename TNLengths,
          bool ShareA,
          typename enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
                                 BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
                                 CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
                             bool>::type = false>
struct ThreadwiseContractionDlDpp8_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    static constexpr index_t TK0 = TKLengths{}[I0];
    static constexpr index_t TK1 = TKLengths{}[I1];
    static constexpr index_t TM0 = TMLengths{}[I0];
    static constexpr index_t TM1 = TMLengths{}[I1];
    static constexpr index_t TN0 = TNLengths{}[I0];
    static constexpr index_t TN1 = TNLengths{}[I1];

    static_assert(TM0 == 1 && TN0 == 1);
    static_assert((ShareA && TM1 % dpp8::lane_group_size == 0) ||
                  (!ShareA && TN1 % dpp8::lane_group_size == 0));
    static constexpr index_t shared_elems_per_lane =
        ShareA ? TM1 / dpp8::lane_group_size : TN1 / dpp8::lane_group_size;

    __device__ constexpr ThreadwiseContractionDlDpp8_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1()
    {
        static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
                          BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
                          CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
                      "wrong! Desc should be known at compile-time");

        static_assert(TKLengths::Size() == 2 && TMLengths::Size() == 2 && TNLengths::Size() == 2,
                      "wrong!");
    }

    template <typename ABuffer,
              typename AOriginIdx,
              typename BBuffer,
              typename BOriginIdx,
              typename CBuffer,
              typename COriginIdx>
    __device__ static void
    Run(const ABuffer& a_buf, AOriginIdx, const BBuffer& b_buf, BOriginIdx, CBuffer& c_buf, COriginIdx)
    {
        static_assert(is_known_at_compile_time<remove_cvref_t<AOriginIdx>>::value &&
                          is_known_at_compile_time<remove_cvref_t<BOriginIdx>>::value &&
                          is_known_at_compile_time<remove_cvref_t<COriginIdx>>::value,
                      "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");

        static_assert(
            is_same<remove_cvref_t<typename ABuffer::type>, remove_cvref_t<FloatA>>::value &&
            is_same<remove_cvref_t<typename BBuffer::type>, remove_cvref_t<FloatB>>::value &&
            is_same<remove_cvref_t<typename CBuffer::type>, remove_cvref_t<FloatC>>::value &&
            "wrong! inconsistent type");

        constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
        constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
        constexpr auto c_origin_idx = to_multi_index(COriginIdx{});

        static_for<0, TK0, 1>{}([&](auto tk0) {
            static_for<0, TM1, 1>{}([&](auto tm1) {
                static_for<0, TN1, 1>{}([&](auto tn1) {
                    vector_type<FloatA, TK1> a_vec;
                    vector_type<FloatB, TK1> b_vec;

                    static_for<0, TK1, 1>{}([&](auto tk1) {
                        constexpr index_t local_tm1 =
                            ShareA ? tm1 % shared_elems_per_lane : tm1;
                        constexpr index_t a_offset = AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset(
                            a_origin_idx + make_multi_index(tk0, 0, local_tm1, tk1));

                        constexpr index_t local_tn1 =
                            ShareA ? tn1 : tn1 % shared_elems_per_lane;
                        constexpr index_t b_offset = BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset(
                            b_origin_idx + make_multi_index(tk0, 0, local_tn1, tk1));

                        a_vec.template AsType<FloatA>()(tk1) = a_buf[Number<a_offset>{}];
                        b_vec.template AsType<FloatB>()(tk1) = b_buf[Number<b_offset>{}];
                    });

                    using a_vector_t = typename vector_type<FloatA, TK1>::type;
                    using b_vector_t = typename vector_type<FloatB, TK1>::type;

                    constexpr index_t c_offset = CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset(
                        c_origin_idx + make_multi_index(0, tm1, 0, tn1));

                    constexpr int src_lane =
                        ShareA ? (tm1 / shared_elems_per_lane) % dpp8::lane_group_size
                               : (tn1 / shared_elems_per_lane) % dpp8::lane_group_size;

                    dpp8::inner_product_dpp<a_vector_t, b_vector_t, FloatC, src_lane, ShareA>(
                        a_vec.template AsType<a_vector_t>()[I0],
                        b_vec.template AsType<b_vector_t>()[I0],
                        c_buf(Number<c_offset>{}));
                });
            });
        });
    }
};

} // namespace ck
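The deleted contraction's key indexing is the src_lane computation: each lane of a DPP8 lane group holds a slice of the shared operand, and element tm1 (ShareA case) is fetched from lane (tm1 / shared_elems_per_lane) % lane_group_size. A plain illustration of that arithmetic only, with lane_group_size = 8 as the doc comment states and TM1 chosen as a hypothetical example:

#include <cstdio>

// Illustration of the src_lane arithmetic from the deleted DPP8 contraction (ShareA case).
int main()
{
    const int lane_group_size       = 8;
    const int TM1                   = 16; // hypothetical thread tile size along M
    const int shared_elems_per_lane = TM1 / lane_group_size; // 2 A-elements per lane

    for(int tm1 = 0; tm1 < TM1; ++tm1)
    {
        const int src_lane = (tm1 / shared_elems_per_lane) % lane_group_size;
        std::printf("tm1=%2d -> src_lane=%d\n", tm1, src_lane);
    }
    return 0;
}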
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp (+4, -5)

@@ -137,13 +137,12 @@ struct ThreadwiseTensorSliceTransfer_v1r3
                 constexpr index_t src_offset = src_desc.CalculateOffset(
                     src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);

-                SrcData v;
+                DstData v;

                 // apply element-wise operation
                 element_op_(v, src_buf[Number<src_offset>{}]);

-                // apply type convert
-                dst_vector.template AsType<DstData>()(i) = type_convert<DstData>(v);
+                dst_vector.template AsType<DstData>()(i) = v;
             });

             const bool is_dst_valid =

@@ -1289,13 +1288,13 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
                 constexpr index_t dst_offset = dst_desc.CalculateOffset(
                     dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);

-                SrcData v;
+                DstData v;

                 // apply element-wise operation
                 element_op_(v, src_buf[Number<src_offset>{}]);

                 // apply type convert
-                dst_buf(Number<dst_offset>{}) = type_convert<DstData>(v);
+                dst_buf(Number<dst_offset>{}) = v;
             });
         });
     }
include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp
0 → 100644
View file @ a8629a98

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/amd_gemm_dpp.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/utility/math.hpp"

namespace ck {

enum struct DppInstr
{
    dpp8_f16_1x32x2 = 0,
    dpp8_f16_2x16x2,
    dpp8_f16_2x32x2,
    dpp8_f16_4x16x2,
    dpp8_f16_4x32x2,
    dpp8_f16_8x16x2,
    dpp8_f16_8x32x2,
    dpp8_f16_16x16x2,
    dpp8_f16_32x8x2
};
/**
 * Structure representing DPP GEMM executed by a single wavefront.
 *
 * Each structure instantiation must contain the following fields:
 * - wave_size - number of threads that execute a single DPP GEMM operation, usually equal to the
 * number of threads in a wavefront;
 * - lanegroup_size - number of threads (lanes) that share data using DPP instruction modifiers,
 * it's 8 in case of DPP8;
 * - m_per_wave - size along M dimension of matrix C that is processed in a single DPP GEMM
 * operation;
 * - n_per_wave - size along N dimension of matrix C that is processed in a single DPP GEMM
 * operation;
 * - m_per_lanegroup - size along M dimension that is processed by a single lanegroup;
 * - n_per_lanegroup - size along N dimension that is processed by a single lanegroup;
 * - m_per_thread - size along M dimension of the tile calculated by a single thread;
 * - n_per_thread - size along N dimension of the tile calculated by a single thread;
 * - k_per_dpp - size along K dimension that is reduced in a single DPP GEMM operation;
 * - share_a - indicates whether we share matrix A or matrix B between lanes using DPP modifiers.
 *
 * Not all the combinations are supported now; for the current restrictions see the static asserts
 * in DppSelector's constructor.
 */
template <DppInstr instr>
struct dpp_type;
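The fields described above are tied together by a simple bookkeeping rule that the specializations below all satisfy: a wavefront owns an m_per_wave x n_per_wave tile of C, each lanegroup owns an m_per_lanegroup x n_per_lanegroup sub-tile, and every lane accumulates m_per_thread * n_per_thread values. A minimal host-side sketch of that invariant (not part of the header; the parameter names simply mirror the fields above):

    constexpr bool check_dpp_shape(int wave_size, int lanegroup_size,
                                   int m_per_wave, int n_per_wave,
                                   int m_per_lanegroup, int n_per_lanegroup,
                                   int m_per_thread, int n_per_thread)
    {
        // lanegroups per wave must match how many lanegroup sub-tiles fit into the wave tile,
        // and one lanegroup sub-tile must be exactly one accumulator set per lane
        const int lanegroups_per_wave   = wave_size / lanegroup_size;
        const int c_elems_per_wave      = m_per_wave * n_per_wave;
        const int c_elems_per_lanegroup = m_per_lanegroup * n_per_lanegroup;
        return lanegroups_per_wave == c_elems_per_wave / c_elems_per_lanegroup &&
               c_elems_per_lanegroup == lanegroup_size * m_per_thread * n_per_thread;
    }
    // e.g. dpp8_f16_32x8x2: a 32x8 C tile on a wave of 32 split into lanegroups of 8
    static_assert(check_dpp_shape(32, 8, 32, 8, 8, 8, 8, 1), "shape bookkeeping");

These are the same relations that DppSelector's constructor later enforces with static asserts.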
template <>
struct dpp_type<DppInstr::dpp8_f16_32x8x2>
{
    static constexpr index_t wave_size       = 32;
    static constexpr index_t lanegroup_size  = 8;
    static constexpr index_t m_per_wave      = 32;
    static constexpr index_t n_per_wave      = 8;
    static constexpr index_t m_per_lanegroup = 8;
    static constexpr index_t n_per_lanegroup = 8;
    static constexpr index_t m_per_thread    = 8;
    static constexpr index_t n_per_thread    = 1;
    static constexpr index_t k_per_dpp       = 2;
    static constexpr bool share_a            = true;
    using BaseType                           = half_t;

    template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
    __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
    {
        dpp8::DppLanegroupGemm<m_per_thread, n_per_thread, k_per_dpp, BaseType, ADataType,
                               BDataType, CDataType, share_a>{}
            .Run(a, b, reg_c);
    }
};

template <>
struct dpp_type<DppInstr::dpp8_f16_8x32x2>
{
    static constexpr index_t wave_size       = 32;
    static constexpr index_t lanegroup_size  = 8;
    static constexpr index_t m_per_wave      = 8;
    static constexpr index_t n_per_wave      = 32;
    static constexpr index_t m_per_lanegroup = 8;
    static constexpr index_t n_per_lanegroup = 8;
    static constexpr index_t m_per_thread    = 8;
    static constexpr index_t n_per_thread    = 1;
    static constexpr index_t k_per_dpp       = 2;
    static constexpr bool share_a            = true;
    using BaseType                           = half_t;

    template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
    __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
    {
        dpp8::DppLanegroupGemm<m_per_thread, n_per_thread, k_per_dpp, BaseType, ADataType,
                               BDataType, CDataType, share_a>{}
            .Run(a, b, reg_c);
    }
};

template <>
struct dpp_type<DppInstr::dpp8_f16_8x16x2>
{
    static constexpr index_t wave_size       = 32;
    static constexpr index_t lanegroup_size  = 8;
    static constexpr index_t m_per_wave      = 8;
    static constexpr index_t n_per_wave      = 16;
    static constexpr index_t m_per_lanegroup = 4;
    static constexpr index_t n_per_lanegroup = 8;
    static constexpr index_t m_per_thread    = 4;
    static constexpr index_t n_per_thread    = 1;
    static constexpr index_t k_per_dpp       = 2;
    static constexpr bool share_a            = true;
    using BaseType                           = half_t;

    template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
    __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
    {
        dpp8::DppLanegroupGemm<m_per_thread, n_per_thread, k_per_dpp, BaseType, ADataType,
                               BDataType, CDataType, share_a>{}
            .Run(a, b, reg_c);
    }
};

template <>
struct dpp_type<DppInstr::dpp8_f16_16x16x2>
{
    static constexpr index_t wave_size       = 32;
    static constexpr index_t lanegroup_size  = 8;
    static constexpr index_t m_per_wave      = 16;
    static constexpr index_t n_per_wave      = 16;
    static constexpr index_t m_per_lanegroup = 8;
    static constexpr index_t n_per_lanegroup = 8;
    static constexpr index_t m_per_thread    = 8;
    static constexpr index_t n_per_thread    = 1;
    static constexpr index_t k_per_dpp       = 2;
    static constexpr bool share_a            = true;
    using BaseType                           = half_t;

    template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
    __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
    {
        dpp8::DppLanegroupGemm<m_per_thread, n_per_thread, k_per_dpp, BaseType, ADataType,
                               BDataType, CDataType, share_a>{}
            .Run(a, b, reg_c);
    }
};

template <>
struct dpp_type<DppInstr::dpp8_f16_4x32x2>
{
    static constexpr index_t wave_size       = 32;
    static constexpr index_t lanegroup_size  = 8;
    static constexpr index_t m_per_wave      = 4;
    static constexpr index_t n_per_wave      = 32;
    static constexpr index_t m_per_lanegroup = 4;
    static constexpr index_t n_per_lanegroup = 8;
    static constexpr index_t m_per_thread    = 4;
    static constexpr index_t n_per_thread    = 1;
    static constexpr index_t k_per_dpp       = 2;
    static constexpr bool share_a            = true;
    using BaseType                           = half_t;

    template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
    __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
    {
        dpp8::DppLanegroupGemm<m_per_thread, n_per_thread, k_per_dpp, BaseType, ADataType,
                               BDataType, CDataType, share_a>{}
            .Run(a, b, reg_c);
    }
};

template <>
struct dpp_type<DppInstr::dpp8_f16_4x16x2>
{
    static constexpr index_t wave_size       = 32;
    static constexpr index_t lanegroup_size  = 8;
    static constexpr index_t m_per_wave      = 4;
    static constexpr index_t n_per_wave      = 16;
    static constexpr index_t m_per_lanegroup = 2;
    static constexpr index_t n_per_lanegroup = 8;
    static constexpr index_t m_per_thread    = 2;
    static constexpr index_t n_per_thread    = 1;
    static constexpr index_t k_per_dpp       = 2;
    static constexpr bool share_a            = true;
    using BaseType                           = half_t;

    template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
    __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
    {
        dpp8::DppLanegroupGemm<m_per_thread, n_per_thread, k_per_dpp, BaseType, ADataType,
                               BDataType, CDataType, share_a>{}
            .Run(a, b, reg_c);
    }
};

template <>
struct dpp_type<DppInstr::dpp8_f16_1x32x2>
{
    static constexpr index_t wave_size       = 32;
    static constexpr index_t lanegroup_size  = 8;
    static constexpr index_t m_per_wave      = 1;
    static constexpr index_t n_per_wave      = 32;
    static constexpr index_t m_per_lanegroup = 1;
    static constexpr index_t n_per_lanegroup = 8;
    static constexpr index_t m_per_thread    = 1;
    static constexpr index_t n_per_thread    = 1;
    static constexpr index_t k_per_dpp       = 2;
    static constexpr bool share_a            = true;
    using BaseType                           = half_t;

    template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
    __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
    {
        dpp8::DppLanegroupGemm<m_per_thread, n_per_thread, k_per_dpp, BaseType, ADataType,
                               BDataType, CDataType, share_a>{}
            .Run(a, b, reg_c);
    }
};

template <>
struct dpp_type<DppInstr::dpp8_f16_2x32x2>
{
    static constexpr index_t wave_size       = 32;
    static constexpr index_t lanegroup_size  = 8;
    static constexpr index_t m_per_wave      = 2;
    static constexpr index_t n_per_wave      = 32;
    static constexpr index_t m_per_lanegroup = 2;
    static constexpr index_t n_per_lanegroup = 8;
    static constexpr index_t m_per_thread    = 2;
    static constexpr index_t n_per_thread    = 1;
    static constexpr index_t k_per_dpp       = 2;
    static constexpr bool share_a            = true;
    using BaseType                           = half_t;

    template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
    __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
    {
        dpp8::DppLanegroupGemm<m_per_thread, n_per_thread, k_per_dpp, BaseType, ADataType,
                               BDataType, CDataType, share_a>{}
            .Run(a, b, reg_c);
    }
};

template <>
struct dpp_type<DppInstr::dpp8_f16_2x16x2>
{
    static constexpr index_t wave_size       = 32;
    static constexpr index_t lanegroup_size  = 8;
    static constexpr index_t m_per_wave      = 2;
    static constexpr index_t n_per_wave      = 16;
    static constexpr index_t m_per_lanegroup = 1;
    static constexpr index_t n_per_lanegroup = 8;
    static constexpr index_t m_per_thread    = 1;
    static constexpr index_t n_per_thread    = 1;
    static constexpr index_t k_per_dpp       = 2;
    static constexpr bool share_a            = true;
    using BaseType                           = half_t;

    template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
    __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
    {
        dpp8::DppLanegroupGemm<m_per_thread, n_per_thread, k_per_dpp, BaseType, ADataType,
                               BDataType, CDataType, share_a>{}
            .Run(a, b, reg_c);
    }
};
template <typename BaseType, index_t MPerDpp, index_t NPerDpp>
struct DppSelector
{
    template <typename BaseType_, index_t MPerDpp_, index_t NPerDpp_>
    static constexpr auto GetDpp();

    template <>
    static constexpr auto GetDpp<half_t, 8, 32>() { return DppInstr::dpp8_f16_8x32x2; }

    template <>
    static constexpr auto GetDpp<half_t, 8, 16>() { return DppInstr::dpp8_f16_8x16x2; }

    template <>
    static constexpr auto GetDpp<half_t, 16, 16>() { return DppInstr::dpp8_f16_16x16x2; }

    template <>
    static constexpr auto GetDpp<half_t, 32, 8>() { return DppInstr::dpp8_f16_32x8x2; }

    template <>
    static constexpr auto GetDpp<half_t, 1, 32>() { return DppInstr::dpp8_f16_1x32x2; }

    template <>
    static constexpr auto GetDpp<half_t, 2, 32>() { return DppInstr::dpp8_f16_2x32x2; }

    template <>
    static constexpr auto GetDpp<half_t, 2, 16>() { return DppInstr::dpp8_f16_2x16x2; }

    template <>
    static constexpr auto GetDpp<half_t, 4, 16>() { return DppInstr::dpp8_f16_4x16x2; }

    template <>
    static constexpr auto GetDpp<half_t, 4, 32>() { return DppInstr::dpp8_f16_4x32x2; }

    static constexpr auto selected_dpp = dpp_type<GetDpp<BaseType, MPerDpp, NPerDpp>()>{};

    __host__ __device__ constexpr DppSelector()
    {
        static_assert(selected_dpp.m_per_wave % selected_dpp.m_per_lanegroup == 0);
        static_assert(selected_dpp.n_per_wave % selected_dpp.n_per_lanegroup == 0);
        static_assert(selected_dpp.k_per_dpp % 2 == 0);

        static_assert(selected_dpp.wave_size % selected_dpp.lanegroup_size == 0);
        constexpr index_t num_dpp_per_wave = selected_dpp.wave_size / selected_dpp.lanegroup_size;
        constexpr index_t num_wave_c_elems = selected_dpp.m_per_wave * selected_dpp.n_per_wave;
        constexpr index_t num_dpp_c_elems =
            selected_dpp.m_per_lanegroup * selected_dpp.n_per_lanegroup;
        static_assert(num_wave_c_elems % num_dpp_c_elems == 0);
        static_assert(num_dpp_per_wave == num_wave_c_elems / num_dpp_c_elems);

        if constexpr(selected_dpp.share_a)
        {
            static_assert(selected_dpp.m_per_lanegroup == selected_dpp.m_per_thread);
            static_assert(selected_dpp.n_per_lanegroup % selected_dpp.n_per_thread == 0);
            static_assert(selected_dpp.n_per_lanegroup / selected_dpp.n_per_thread ==
                          selected_dpp.lanegroup_size);
        }
        else
        {
            static_assert(selected_dpp.m_per_lanegroup % selected_dpp.n_per_thread == 0);
            static_assert(selected_dpp.m_per_lanegroup / selected_dpp.n_per_thread ==
                          selected_dpp.lanegroup_size);
            static_assert(selected_dpp.n_per_lanegroup == selected_dpp.n_per_thread);
        }

        // Below checks come from the restrictions of the current implementation, could be removed
        // in the future when the implementation is more generalized.
        static_assert(selected_dpp.share_a);
        static_assert(selected_dpp.n_per_thread == 1);
        static_assert(selected_dpp.m_per_lanegroup == selected_dpp.m_per_thread);
        static_assert(selected_dpp.n_per_lanegroup ==
                      selected_dpp.n_per_thread * selected_dpp.lanegroup_size);
    }

    static constexpr index_t GetK1PerDpp() { return selected_dpp.k_per_dpp; }
};
template <typename BaseType, index_t MPerDpp, index_t NPerDpp, index_t KPack>
struct DppGemm
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};

    using CIndex   = MultiIndex<2>;
    using CIndex4D = MultiIndex<4>;

    __host__ __device__ constexpr DppGemm()
    {
        static_assert(KPack % dpp_instr.k_per_dpp == 0, "KPack must be divisible by k_per_dpp.");
    }

    __device__ static constexpr index_t GetRegSizePerDpp()
    {
        return MPerDpp * NPerDpp / dpp_instr.wave_size;
    }

    template <class ADataType, class BDataType, class CDataType>
    __device__ void Run(const ADataType& p_a_wave, const BDataType& p_b_wave,
                        CDataType& p_c_thread) const
    {
        static_assert(is_same<BaseType, double>::value || is_same<BaseType, float>::value ||
                          is_same<BaseType, half_t>::value || is_same<BaseType, bhalf_t>::value ||
                          is_same<BaseType, int8_t>::value || is_same<BaseType, f8_t>::value,
                      "base BaseType must be double, float, half, bfloat16, and int8_t!");

        static_for<0, KPack / dpp_instr.k_per_dpp, 1>{}([&](auto k) {
            dpp_instr.template run<MPerDpp, NPerDpp>(p_a_wave[k], p_b_wave[k], p_c_thread);
        });
    }

    __device__ static auto GetLaneIdInWave()
    {
        return get_thread_local_1d_id() % dpp_instr.wave_size;
    }

    __device__ static auto GetWaveId()
    {
        return get_thread_local_1d_id() / dpp_instr.wave_size;
    }

    __device__ static auto GetLaneIdInLaneGroup()
    {
        return get_thread_local_1d_id() % dpp_instr.lanegroup_size;
    }

    __device__ static auto GetLaneGroupIdInWave()
    {
        return GetLaneIdInWave() / dpp_instr.lanegroup_size;
    }

    __device__ static auto GetDppOpIdx()
    {
        const auto lanegroupId = GetLaneGroupIdInWave();

        constexpr auto lanegroup_idx_1d_to_dpp_idx_2d_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_merge_transform(
                make_tuple(dpp_instr.m_per_wave / dpp_instr.m_per_lanegroup,
                           dpp_instr.n_per_wave / dpp_instr.n_per_lanegroup))),
            make_tuple(Sequence<0, 1>{}),
            make_tuple(Sequence<0>{}));

        const auto dpp_idx = lanegroup_idx_1d_to_dpp_idx_2d_adaptor.CalculateBottomIndex(
            make_multi_index(lanegroupId));

        const auto m_dpp_idx = dpp_idx[I0];
        const auto n_dpp_idx = dpp_idx[I1];

        return make_tuple(m_dpp_idx, n_dpp_idx);
    }

    __host__ __device__ static auto CalculateAThreadOriginDataIndex_K_M()
    {
        const auto laneId   = get_thread_local_1d_id();
        const auto wave_row = laneId / dpp_instr.n_per_wave;
        auto m_idx          = dpp_instr.m_per_thread * wave_row + GetLaneIdInLaneGroup();
        return make_tuple(0, m_idx % dpp_instr.m_per_wave);
    }

    __host__ __device__ static auto CalculateBThreadOriginDataIndex_K_N()
    {
        const auto laneId = get_thread_local_1d_id();
        return make_tuple(0, laneId % dpp_instr.n_per_wave);
    }

    __device__ static CIndex GetBeginOfThreadBlk()
    {
        const auto dpp_op_idx = GetDppOpIdx();

        const auto m_dpp_op_idx = dpp_op_idx[I0];
        const auto n_dpp_op_idx = dpp_op_idx[I1];

        index_t n_offset = n_dpp_op_idx * dpp_instr.n_per_lanegroup + GetLaneIdInLaneGroup();
        index_t m_offset = m_dpp_op_idx * dpp_instr.m_per_lanegroup;

        return CIndex{m_offset, n_offset};
    }

    static constexpr auto dpp = DppSelector<BaseType, MPerDpp, NPerDpp>{};

    static constexpr auto dpp_instr = dpp.selected_dpp;

    static constexpr auto K0PerDpp = 1;
    static constexpr auto K1PerDpp = dpp.GetK1PerDpp();

    __host__ __device__ static constexpr auto GetCMNThreadBlkLengths()
    {
        return make_tuple(Number<dpp_instr.m_per_thread>{}, Number<dpp_instr.n_per_thread>{});
    }
};

} // namespace ck
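For readers not used to the tensor-adaptor machinery, the 1-D-to-2-D mapping inside GetDppOpIdx() amounts to a row-major decomposition of the lanegroup id, assuming the merge transform follows its usual convention. A sketch of the equivalent index math (not library code):

    // with M_groups = m_per_wave / m_per_lanegroup and N_groups = n_per_wave / n_per_lanegroup
    struct DppOpIdx { int m; int n; };
    constexpr DppOpIdx decompose_lanegroup_id(int lanegroup_id, int n_groups)
    {
        return {lanegroup_id / n_groups, lanegroup_id % n_groups};
    }

GetBeginOfThreadBlk() then scales these two indices by m_per_lanegroup and n_per_lanegroup respectively and adds the lane's position inside its lanegroup to the N offset, which is exactly the CIndex it returns.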
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
View file @ a8629a98

@@ -456,6 +456,7 @@ struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64>
     }
 };

+#if defined CK_ENABLE_FP8
 template <>
 struct mfma_type<MfmaInstr::mfma_f32_32x32x16f8f8>
 {
@@ -499,6 +500,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8f8>
         intrin_mfma_f32_16x16x32f8f8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
     }
 };
+#endif

 template <typename base_type, index_t MPerXdlops, index_t NPerXdlops>
 struct MfmaSelector
@@ -640,6 +642,7 @@ struct MfmaSelector
     }
 #endif

+#if defined CK_ENABLE_FP8
     template <>
     static constexpr auto GetMfma<f8_t, 32, 32>()
     {
@@ -651,6 +654,7 @@ struct MfmaSelector
     {
         return MfmaInstr::mfma_f32_16x16x32f8f8;
     }
+#endif

     static constexpr auto selected_mfma = mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops>()>{};
@@ -852,7 +856,11 @@ struct XdlopsGemm
     {
         static_assert(is_same<base_type, double>::value || is_same<base_type, float>::value ||
                           is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value ||
-                          is_same<base_type, int8_t>::value || is_same<base_type, f8_t>::value,
+                          is_same<base_type, int8_t>::value
+#if defined CK_ENABLE_FP8
+                          || is_same<base_type, f8_t>::value
+#endif
+                      ,
                       "base base_type must be double, float, half, bfloat16, and int8_t!");

         static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
include/ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp
View file @ a8629a98

@@ -164,6 +164,7 @@ template <
           index_t BK1,
           index_t GemmMPerBlock,
           index_t GemmNPerBlock,
+          index_t GemmKPerBlock,
           bool DoPadGemmM,
           bool DoPadGemmN>
 struct TransformConvBwdDataToGemm_v1
@@ -308,9 +309,6 @@ struct TransformConvBwdDataToGemm_v1
         const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
         const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);

-        const index_t AK0 = math::integer_divide_ceil(ZDotSlice * YDotSlice * XDotSlice * K, AK1);
-
         if constexpr(NDimSpatial == 2)
         {
             // A: output tensor
@@ -367,9 +365,11 @@ struct TransformConvBwdDataToGemm_v1
             const auto out_gemmk_gemmm_padded_grid_desc =
                 ck::tensor_operation::device::PadTensorDescriptor(
                     out_gemmk_gemmmraw_grid_desc,
-                    make_tuple(AK1, GemmMPerBlock),
+                    make_tuple(GemmKPerBlock, GemmMPerBlock),
                     Sequence<true, DoPadGemmM>{});

+            const index_t AK0 = out_gemmk_gemmm_padded_grid_desc.GetLength(I0) / AK1;
+
             const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor(
                 out_gemmk_gemmm_padded_grid_desc,
                 make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
@@ -460,9 +460,11 @@ struct TransformConvBwdDataToGemm_v1
             const auto out_gemmk_gemmm_padded_grid_desc =
                 ck::tensor_operation::device::PadTensorDescriptor(
                     out_gemmk_gemmmraw_grid_desc,
-                    make_tuple(AK1, GemmMPerBlock),
+                    make_tuple(GemmKPerBlock, GemmMPerBlock),
                     Sequence<true, DoPadGemmM>{});

+            const index_t AK0 = out_gemmk_gemmm_padded_grid_desc.GetLength(I0) / AK1;
+
             const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor(
                 out_gemmk_gemmm_padded_grid_desc,
                 make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
@@ -568,9 +570,6 @@ struct TransformConvBwdDataToGemm_v1
         const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
         const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);

-        const index_t BK0 = math::integer_divide_ceil(ZDotSlice * YDotSlice * XDotSlice * K, BK1);
-
         // B weight tensor
         if constexpr(NDimSpatial == 2)
         {
@@ -617,9 +616,11 @@ struct TransformConvBwdDataToGemm_v1
             const auto wei_gemmk_gemmn_padded_grid_desc =
                 ck::tensor_operation::device::PadTensorDescriptor(
                     wei_gemmk_gemmnraw_grid_desc,
-                    make_tuple(BK1, GemmNPerBlock),
+                    make_tuple(GemmKPerBlock, GemmNPerBlock),
                     Sequence<true, DoPadGemmN>{});

+            const index_t BK0 = wei_gemmk_gemmn_padded_grid_desc.GetLength(I0) / BK1;
+
             const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc = transform_tensor_descriptor(
                 wei_gemmk_gemmn_padded_grid_desc,
                 make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
@@ -690,17 +691,19 @@ struct TransformConvBwdDataToGemm_v1
                 make_tuple(Sequence<1, 2, 3, 0>{}, Sequence<4>{}),
                 make_tuple(Sequence<0>{}, Sequence<1>{}));

-            const auto wei_gemmk_gemm_padded_grid_desc =
+            const auto wei_gemmk_gemmn_padded_grid_desc =
                 ck::tensor_operation::device::PadTensorDescriptor(
                     wei_gemmk_gemmnraw_grid_desc,
-                    make_tuple(BK1, GemmNPerBlock),
+                    make_tuple(GemmKPerBlock, GemmNPerBlock),
                     Sequence<true, DoPadGemmN>{});

+            const index_t BK0 = wei_gemmk_gemmn_padded_grid_desc.GetLength(I0) / BK1;
+
             const auto wei_gemmbk0_gemm_gemmbk1_grid_desc = transform_tensor_descriptor(
-                wei_gemmk_gemm_padded_grid_desc,
+                wei_gemmk_gemmn_padded_grid_desc,
                 make_tuple(
                     make_unmerge_transform(make_tuple(BK0, BK1)),
                     make_pass_through_transform(
-                        wei_gemmk_gemm_padded_grid_desc.GetLength(I1))),
+                        wei_gemmk_gemmn_padded_grid_desc.GetLength(I1))),
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
                 make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
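The point of moving the AK0/BK0 computation appears to be visible in the padding call: the K extent is now padded with GemmKPerBlock (rather than AK1/BK1) before AK0/BK0 are derived from the padded length. A worked example with made-up sizes: if the raw GEMM K extent is ZDotSlice * YDotSlice * XDotSlice * K = 300, AK1 = 8 and GemmKPerBlock = 32, the old code gave AK0 = ceil(300 / 8) = 38, so AK0 * AK1 = 304 is not a whole number of GemmKPerBlock tiles; after padding K up to 320, AK0 = 320 / 8 = 40 and the K loop sees complete GemmKPerBlock chunks.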
include/ck/utility/amd_buffer_addressing.hpp
View file @ a8629a98

@@ -1127,7 +1127,7 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
 #if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
     uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000;

+#if defined CK_ENABLE_FP8
     if constexpr(is_same<scalar_t, f8_t>::value)
     {
         auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>(
@@ -1136,10 +1136,14 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
     }
     else
     {
+#endif
         return amd_buffer_load_impl<scalar_t, vector_size, coherence>(
             src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
+#if defined CK_ENABLE_FP8
     }
+#endif
 #else
+#if defined CK_ENABLE_FP8
     if constexpr(is_same<scalar_t, f8_t>::value)
     {
         auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>(
@@ -1148,11 +1152,14 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
     }
     else
     {
+#endif
         vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size, coherence>(
             src_wave_buffer_resource, src_thread_addr_offset, 0);

         return src_thread_element_valid ? tmp : vector_t(0);
+#if defined CK_ENABLE_FP8
     }
+#endif
 #endif
 }

 // buffer_load requires:
@@ -1209,7 +1216,7 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
 #if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
     uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;

+#if defined CK_ENABLE_FP8
     if constexpr(is_same<scalar_t, f8_t>::value)
     {
         auto tmp =
@@ -1219,12 +1226,16 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
     }
     else
     {
+#endif
         amd_buffer_store_impl<scalar_t, vector_size, coherence>(
             src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
+#if defined CK_ENABLE_FP8
     }
+#endif
 #else
     if(dst_thread_element_valid)
     {
+#if defined CK_ENABLE_FP8
         if constexpr(is_same<scalar_t, f8_t>::value)
         {
             auto tmp = bit_cast<typename vector_type_maker<int8_t, vector_size>::type::type>(
@@ -1234,9 +1245,12 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
         }
         else
         {
+#endif
             amd_buffer_store_impl<scalar_t, vector_size, coherence>(
                 src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
+#if defined CK_ENABLE_FP8
         }
+#endif
     }
 #endif
 }
include/ck/utility/amd_gemm_dpp.hpp
View file @ a8629a98

@@ -5,17 +5,63 @@
 #include "ck/utility/common_header.hpp"
 #include "ck/utility/math.hpp"
-#include "ck/utility/amd_gemm_dpp.hpp"
+#include "ck/utility/inner_product_dpp8.hpp"

 namespace ck {
 namespace dpp8 {

-/// Number of lanes that can share data using DPP8 modifiers.
-constexpr index_t lane_group_size = 8;
-
-__device__ index_t get_lane_group_local_idx() { return threadIdx.x / lane_group_size; }
-
-__device__ index_t get_thread_idx_in_lane_group() { return threadIdx.x % lane_group_size; }
+template <class ABDataType>
+struct dpp_datatypes;
+
+template <>
+struct dpp_datatypes<half_t>
+{
+    // Dot product of `half2_t` and `half2_t` to get `float`. Reducing 2 elements from K in a
+    // single instruction.
+    using a_dtype = half_t;
+    using b_dtype = half_t;
+    using c_dtype = float;
+    static constexpr index_t k_per_instr = 2;
+};
+template <index_t MPerThread,
+          index_t NPerThread,
+          index_t KPerThread,
+          class BaseInputType,
+          class AVecDataType,
+          class BVecDataType,
+          class CVecDataType,
+          bool ShareA>
+struct DppLanegroupGemm
+{
+    using datatypes_conf = dpp_datatypes<BaseInputType>;
+
+    using ADataType = typename datatypes_conf::a_dtype;
+    using BDataType = typename datatypes_conf::b_dtype;
+    using CDataType = typename datatypes_conf::c_dtype;
+
+    __device__ void Run(const AVecDataType& a_vec, const BVecDataType& b_vec, CVecDataType& c_vec)
+    {
+        constexpr index_t num_c_elems_per_thread = ShareA ? MPerThread : NPerThread;
+
+        const vector_type<ADataType, KPerThread> a_vector{a_vec};
+        const vector_type<BDataType, KPerThread> b_vector{b_vec};
+
+        static_for<0, num_c_elems_per_thread, 1>{}([&](auto c_idx) {
+            float c = c_vec.template AsType<CDataType>()(c_idx);
+            // Next `c_idx` implies that we need to pull data from the next lane.
+            constexpr index_t source_lane = c_idx;
+            static_for<0, KPerThread / datatypes_conf::k_per_instr, 1>{}([&](auto k_chunk) {
+                const auto a_k_vec = a_vector.template AsType<AVecDataType>()[k_chunk];
+                const auto b_k_vec = b_vector.template AsType<BVecDataType>()[k_chunk];
+                ck::dpp8::inner_product_dpp<AVecDataType, BVecDataType, CDataType, source_lane,
+                                            ShareA>(a_k_vec, b_k_vec, c);
+            });
+            c_vec.template AsType<CDataType>()(c_idx) = c;
+        });
+    }
+};

 } // namespace dpp8
include/ck/utility/amd_xdlops.hpp
View file @ a8629a98

@@ -355,6 +355,7 @@ struct intrin_mfma_f64_16x16x4f64<16, 16>
     }
 };

+#if defined CK_ENABLE_FP8
 template <index_t MPerWave, index_t NPerWave>
 struct intrin_mfma_f32_32x32x16f8f8;
@@ -417,5 +418,6 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16>
 #endif
     }
 };
+#endif
 } // namespace ck
 #endif
include/ck/utility/data_type.hpp
View file @ a8629a98

@@ -12,7 +12,12 @@ using half_t = _Float16;
 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
 using int4_t = _BitInt(4);
 #endif
-using f8_t = uint8_t;
+#if defined CK_ENABLE_FP8
+using f8_t = _BitInt(8);
+#endif
+#if defined CK_ENABLE_BF8
+using bf8_t = unsigned _BitInt(8);
+#endif

 // vector_type
 template <typename T, index_t N>
@@ -143,14 +148,24 @@ struct scalar_type<int4_t>
 };
 #endif

+#if defined CK_ENABLE_FP8
 template <>
 struct scalar_type<f8_t>
 {
     using type                           = f8_t;
     static constexpr index_t vector_size = 1;
 };
+#endif
+
+#if defined CK_ENABLE_BF8
+template <>
+struct scalar_type<bf8_t>
+{
+    using type                           = bf8_t;
+    static constexpr index_t vector_size = 1;
+};
+#endif

 //
 template <typename T>
 struct vector_type<T, 1>
 {
@@ -953,12 +968,24 @@ using int8x32_t = typename vector_type<int8_t, 32>::type;
 using int8x64_t = typename vector_type<int8_t, 64>::type;

 // f8
+#if defined CK_ENABLE_FP8
 using f8x2_t  = typename vector_type<f8_t, 2>::type;
 using f8x4_t  = typename vector_type<f8_t, 4>::type;
 using f8x8_t  = typename vector_type<f8_t, 8>::type;
 using f8x16_t = typename vector_type<f8_t, 16>::type;
 using f8x32_t = typename vector_type<f8_t, 32>::type;
 using f8x64_t = typename vector_type<f8_t, 64>::type;
+#endif
+
+// bf8
+#if defined CK_ENABLE_BF8
+using bf8x2_t  = typename vector_type<bf8_t, 2>::type;
+using bf8x4_t  = typename vector_type<bf8_t, 4>::type;
+using bf8x8_t  = typename vector_type<bf8_t, 8>::type;
+using bf8x16_t = typename vector_type<bf8_t, 16>::type;
+using bf8x32_t = typename vector_type<bf8_t, 32>::type;
+using bf8x64_t = typename vector_type<bf8_t, 64>::type;
+#endif

 template <typename T>
 struct NumericLimits
@@ -1006,21 +1033,109 @@ struct NumericLimits<int4_t>
 };
 #endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4

+#if defined CK_ENABLE_FP8
 template <>
 struct NumericLimits<f8_t>
 {
+    // negative zero nan mode with exp bias = 8
     static constexpr uint8_t binary_min    = 0x08; // 0b00001000
-    static constexpr uint8_t binary_max    = 0x77; // 0b01110111
-    static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111
+    static constexpr uint8_t binary_max    = 0x7F; // 0b01111111
+    static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111
     static constexpr uint8_t binary_qnan   = 0x80; // 0b10000000

-    __host__ __device__ static constexpr f8_t Min() { return bit_cast<f8_t>(binary_min); }
-    __host__ __device__ static constexpr f8_t Max() { return bit_cast<f8_t>(binary_max); }
-    __host__ __device__ static constexpr f8_t Lowest() { return bit_cast<f8_t>(binary_lowest); }
-    __host__ __device__ static constexpr f8_t QuietNaN() { return bit_cast<f8_t>(binary_qnan); }
+    // ieee mode with exp bias = 7
+    // static constexpr uint8_t binary_min    = 0x08; // 0b00001000
+    // static constexpr uint8_t binary_max    = 0x77; // 0b01110111
+    // static constexpr uint8_t binary_lowest = 0xF7; // 0b11110111
+    // static constexpr uint8_t binary_qnan   = 0x79; // any sign, exp=1111, mant!=0
+
+    __host__ __device__ static constexpr f8_t Min() { return f8_t(binary_min); }
+    __host__ __device__ static constexpr f8_t Max() { return f8_t(binary_max); }
+    __host__ __device__ static constexpr f8_t Lowest() { return f8_t(binary_lowest); }
+    __host__ __device__ static constexpr f8_t QuietNaN() { return f8_t(binary_qnan); }
 };
+#endif
+
+#if defined CK_ENABLE_BF8
+template <>
+struct NumericLimits<bf8_t>
+{
+    // negative zero nan mode with exp bias = 16
+    static constexpr uint8_t binary_min    = 0x04; // 0b00000100
+    static constexpr uint8_t binary_max    = 0x7F; // 0b01111111
+    static constexpr uint8_t binary_lowest = 0xFF; // 0b11111111
+    static constexpr uint8_t binary_qnan   = 0x80; // 0b10000000
+
+    // ieee mode with exp bias = 15
+    // static constexpr uint8_t binary_min    = 0x04; // 0b00000100
+    // static constexpr uint8_t binary_max    = 0x7B; // 0b01111011
+    // static constexpr uint8_t binary_lowest = 0xFB; // 0b11111011
+    // static constexpr uint8_t binary_qnan   = 0x79; // any sign, exp=1111, mant!=0
+
+    __host__ __device__ static constexpr bf8_t Min() { return bf8_t(binary_min); }
+    __host__ __device__ static constexpr bf8_t Max() { return bf8_t(binary_max); }
+    __host__ __device__ static constexpr bf8_t Lowest() { return bf8_t(binary_lowest); }
+    __host__ __device__ static constexpr bf8_t QuietNaN() { return bf8_t(binary_qnan); }
+};
+#endif
+
+template <typename T>
+struct NumericUtils
+{
+};
+
+template <>
+struct NumericUtils<float>
+{
+    static constexpr int exp            = 8;
+    static constexpr int mant           = 23;
+    static constexpr uint32_t nan_mask  = 0x7F800000;
+    static constexpr uint32_t head_mask = 0xFF800000;
+    static constexpr uint32_t mant_mask = 0x7FFFFF;
+    static constexpr uint32_t exp_mask  = 0xFF;
+    static constexpr uint32_t Inf       = 0x7F800000;
+    static constexpr uint32_t NegInf    = 0xFF800000;
+    static constexpr uint32_t NaN       = 0x7F800001;
+    static constexpr uint32_t Neg0      = 0x80000000;
+    using bitwise_type = uint32_t;
+};
+
+template <>
+struct NumericUtils<half_t>
+{
+    static constexpr int exp            = 5;
+    static constexpr int mant           = 10;
+    static constexpr uint16_t nan_mask  = 0x7C00;
+    static constexpr uint16_t head_mask = 0xFC00;
+    static constexpr uint16_t mant_mask = 0x3FF;
+    static constexpr uint16_t exp_mask  = 0x1F;
+    static constexpr uint32_t Inf       = 0x7C00;
+    static constexpr uint32_t NegInf    = 0xFC00;
+    static constexpr uint32_t NaN       = 0x7C01;
+    static constexpr uint32_t Neg0      = 0x8000;
+    using bitwise_type = uint16_t;
+};
+
+#if defined CK_ENABLE_FP8
+template <>
+struct NumericUtils<f8_t>
+{
+    static constexpr int exp  = 4;
+    static constexpr int mant = 3;
+};
+#endif
+
+#if defined CK_ENABLE_BF8
+template <>
+struct NumericUtils<bf8_t>
+{
+    static constexpr int exp  = 5;
+    static constexpr int mant = 2;
+};
+#endif
 } // namespace ck
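To make the NumericLimits constants above concrete, here is a small host-side decoder for the "negative zero nan" (bias-8) f8 layout the comments describe; it is a sketch for illustration only, not part of data_type.hpp:

    #include <cmath>
    #include <cstdint>
    // Decode an e4m3 byte under the bias-8, single-NaN ("fnuz") interpretation above.
    double decode_f8_fnuz(uint8_t x)
    {
        if(x == 0x80) return NAN;          // binary_qnan: the one reserved NaN encoding
        const int sign = x >> 7;
        const int exp  = (x >> 3) & 0xF;    // 4 exponent bits, bias 8
        const int mant = x & 0x7;           // 3 mantissa bits
        const double mag = (exp == 0) ? std::ldexp(mant / 8.0, 1 - 8)          // subnormal
                                      : std::ldexp(1.0 + mant / 8.0, exp - 8);  // normal
        return sign ? -mag : mag;
    }
    // decode_f8_fnuz(0x7F) == 240.0  (binary_max), decode_f8_fnuz(0x08) == 2^-7 (binary_min)

The bf8_t table works the same way with 5 exponent bits, 2 mantissa bits and bias 16, which puts its Max() at 2^15 * 1.75 = 57344.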
include/ck/utility/f8_utils.hpp
View file @ a8629a98

@@ -5,6 +5,9 @@
 #include "ck/utility/data_type.hpp"

+// these conversions are disabled if native conversions available
+#if !defined(__gfx940__) && !defined(__gfx941__) && !defined(__gfx942__)
+#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
+
 namespace ck {

 // fp8 rounding modes
@@ -22,53 +25,38 @@ namespace ck::utils {
 namespace {

-template <typename T, bool negative_zero_nan, bool clip, bool stoch>
-__host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
+template <typename X, typename Y, bool negative_zero_nan, bool clip, bool stoch>
+__host__ __device__ Y run_cast_to_f8(X x, uint32_t rng)
 {
-    // check data type
-    constexpr bool is_half  = std::is_same<T, half_t>::value;
-    constexpr bool is_float = std::is_same<T, float>::value;
-
-    // fp8 exponent/mantissa layout
-    constexpr int f8_exp  = 4;
-    constexpr int f8_mant = 3;
-
-    // resulting type exponent/mantissa layout
-    constexpr int type_exp  = is_half ? 5 : 8;
-    constexpr int type_mant = is_half ? 10 : 23;
+    // fp8/bf8 exponent/mantissa layout
+    constexpr int out_exp  = NumericUtils<Y>::exp;
+    constexpr int out_mant = NumericUtils<Y>::mant;
+
+    // original type exponent/mantissa layout
+    constexpr int in_exp  = NumericUtils<X>::exp;
+    constexpr int in_mant = NumericUtils<X>::mant;

     int exponent;
     uint32_t head, mantissa, sign;
     // nan code is same for float and half
-    constexpr uint8_t nan_code  = 0x80;
-    constexpr uint32_t nan_mask = is_half ? 0x7C00 : 0x7F800000;
+    constexpr Y nan_code        = 0x80;
+    constexpr uint32_t nan_mask = NumericUtils<X>::nan_mask;

     // convert to bitwise
-    typedef typename std::conditional<std::is_same<T, half_t>::value, uint16_t, uint32_t>::type
-        T_bitwise;
+    using T_bitwise     = typename NumericUtils<X>::bitwise_type;
     T_bitwise x_bitwise = *(reinterpret_cast<T_bitwise*>(&x));

     // unpack the input, depends on datatype
-    if constexpr(is_float)
-    {
-        head     = x_bitwise & 0xFF800000;
-        mantissa = x_bitwise & 0x7FFFFF;
-        exponent = (head >> type_mant) & 0xFF;
-        sign     = head >> (type_exp + type_mant);
-    }
-    else if constexpr(is_half)
-    {
-        head     = x_bitwise & 0xFC00;
-        mantissa = x_bitwise & 0x3FF;
-        exponent = (head >> type_mant) & 0x1F;
-        sign     = head >> (type_exp + type_mant);
-    }
-
-    uint32_t signed_inf = (sign << (type_exp + type_mant)) + (((1 << type_exp) - 1) << type_mant);
-    uint32_t drop_mask  = (1 << (type_mant - f8_mant)) - 1;
-    constexpr int max_exp = (1 << f8_exp) - (negative_zero_nan ? 1 : 2);
+    head     = x_bitwise & NumericUtils<X>::head_mask;
+    mantissa = x_bitwise & NumericUtils<X>::mant_mask;
+    exponent = (head >> in_mant) & NumericUtils<X>::exp_mask;
+    sign     = head >> (in_exp + in_mant);
+
+    uint32_t signed_inf   = (sign << (in_exp + in_mant)) + (((1 << in_exp) - 1) << in_mant);
+    uint32_t drop_mask    = (1 << (in_mant - out_mant)) - 1;
+    constexpr int max_exp = (1 << out_exp) - (negative_zero_nan ? 1 : 2);
     constexpr int exp_low_cutoff =
-        (1 << (type_exp - 1)) - (1 << (f8_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
+        (1 << (in_exp - 1)) - (1 << (out_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);

     if constexpr(negative_zero_nan)
     {
@@ -81,22 +69,35 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
             return signed_inf + (mantissa != 0 ? 1 : 0);
     }

+    // if input is half and output is bf8
+    if((NumericUtils<X>::mant == 10) && (NumericUtils<Y>::mant == 2) && negative_zero_nan &&
+       exponent == 0)
+    {
+        exponent += 1;
+        while(mantissa < (1 << in_mant))
+        {
+            mantissa <<= 1;
+            exponent -= 1;
+        }
+        mantissa &= ~(1 << in_mant);
+    }
+
     // check if x is 0.0
     if(x_bitwise == 0)
         return 0;

     exponent -= exp_low_cutoff - 1;
     if(exponent <= 0)
-        drop_mask = (1 << (type_mant - f8_mant + 1 - exponent)) - 1;
-    mantissa += 1 << type_mant;
+        drop_mask = (1 << (in_mant - out_mant + 1 - exponent)) - 1;
+    mantissa += 1 << in_mant;
     // apply random number if needed
     mantissa += (stoch ? rng : mantissa) & drop_mask;
-    if(mantissa >= (2 << type_mant))
+    if(mantissa >= (2 << in_mant))
     {
         mantissa >>= 1;
         exponent++;
     }
-    mantissa >>= (type_mant - f8_mant);
+    mantissa >>= (in_mant - out_mant);
     // check negative exponent
     if(exponent <= 0)
@@ -116,7 +117,7 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
     {
         if(clip)
         {
-            mantissa = (1 << f8_mant) - 1;
+            mantissa = (1 << out_mant) - 1;
             exponent = max_exp;
         }
         else
@@ -127,124 +128,121 @@ __host__ __device__ f8_t run_cast_to_f8(T x, uint32_t rng)
     // check if x is 0.0 or -0.0
     if(exponent == 0 && mantissa == 0)
-        return negative_zero_nan ? 0 : (sign << (f8_exp + f8_mant));
-    mantissa &= (1 << f8_mant) - 1;
-    return (sign << (f8_exp + f8_mant)) | (exponent << f8_mant) | mantissa;
+        return negative_zero_nan ? 0 : (sign << (out_exp + out_mant));
+    mantissa &= (1 << out_mant) - 1;
+    return (sign << (out_exp + out_mant)) | (exponent << out_mant) | mantissa;
 }

-template <typename T, bool negative_zero_nan>
-__host__ __device__ T run_cast_from_f8(f8_t x)
+template <typename X, typename Y, bool negative_zero_nan>
+__host__ __device__ Y run_cast_from_f8(X x)
 {
-    // check data type
-    constexpr bool is_half  = std::is_same<T, half_t>::value;
-    constexpr bool is_float = std::is_same<T, float>::value;
-
-    // fp8 exponent/mantissa layout
-    constexpr int f8_exp  = 4;
-    constexpr int f8_mant = 3;
+    // fp8/bf8 exponent/mantissa layout
+    constexpr int in_exp  = NumericUtils<X>::exp;
+    constexpr int in_mant = NumericUtils<X>::mant;

     // resulting type exponent/mantissa layout
-    constexpr int type_exp  = is_half ? 5 : 8;
-    constexpr int type_mant = is_half ? 10 : 23;
+    constexpr int out_exp  = NumericUtils<Y>::exp;
+    constexpr int out_mant = NumericUtils<Y>::mant;

     // prepare the codes
-    constexpr uint8_t nan_code = 0x80;
-    T fInf, fNegInf, fNaN, fNeg0;
-    if constexpr(is_half)
-    {
-        constexpr uint16_t ihInf    = 0x7C00;
-        constexpr uint16_t ihNegInf = 0xFC00;
-        constexpr uint16_t ihNaN    = 0x7C01;
-        constexpr uint16_t ihNeg0   = 0x8000;
-        fInf    = *(reinterpret_cast<const half_t*>(&ihInf));
-        fNegInf = *(reinterpret_cast<const half_t*>(&ihNegInf));
-        fNaN    = *(reinterpret_cast<const half_t*>(&ihNaN));
-        fNeg0   = *(reinterpret_cast<const half_t*>(&ihNeg0));
-    }
-    else if constexpr(is_float)
-    {
-        constexpr uint32_t ifInf    = 0x7F800000;
-        constexpr uint32_t ifNegInf = 0xFF800000;
-        constexpr uint32_t ifNaN    = 0x7F800001;
-        constexpr uint32_t ifNeg0   = 0x80000000;
-        fInf    = *(reinterpret_cast<const float*>(&ifInf));
-        fNegInf = *(reinterpret_cast<const float*>(&ifNegInf));
-        fNaN    = *(reinterpret_cast<const float*>(&ifNaN));
-        fNeg0   = *(reinterpret_cast<const float*>(&ifNeg0));
-    }
+    constexpr X nan_code = 0x80;
+    Y Inf, NegInf, NaN, Neg0;
+    using T_bitwise = typename NumericUtils<Y>::bitwise_type;
+
+    constexpr T_bitwise Inf_bitwise    = NumericUtils<Y>::Inf;
+    constexpr T_bitwise NegInf_bitwise = NumericUtils<Y>::NegInf;
+    constexpr T_bitwise NaN_bitwise    = NumericUtils<Y>::NaN;
+    constexpr T_bitwise Neg0_bitwise   = NumericUtils<Y>::Neg0;
+
+    Inf    = *(reinterpret_cast<const Y*>(&Inf_bitwise));
+    NegInf = *(reinterpret_cast<const Y*>(&NegInf_bitwise));
+    NaN    = *(reinterpret_cast<const Y*>(&NaN_bitwise));
+    Neg0   = *(reinterpret_cast<const Y*>(&Neg0_bitwise));
+
+    // check if x is 0.0
+    if(x == 0)
+        return static_cast<Y>(0);

     // unpack the input
-    uint32_t sign     = x >> (f8_exp + f8_mant);
-    uint32_t mantissa = x & ((1 << f8_mant) - 1);
-    int exponent      = (x & 0x7F) >> f8_mant;
+    uint32_t sign     = x >> (in_exp + in_mant);
+    uint32_t mantissa = x & ((1 << in_mant) - 1);
+    int exponent      = (x & 0x7F) >> in_mant;
     constexpr int exp_low_cutoff =
-        (1 << (type_exp - 1)) - (1 << (f8_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
-    typename std::conditional<std::is_same<T, half_t>::value, uint16_t, uint32_t>::type retval;
+        (1 << (out_exp - 1)) - (1 << (in_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
+    T_bitwise retval;

     if constexpr(negative_zero_nan)
     {
         if(x == nan_code)
-            return fNaN;
+            return NaN;
     }
     else
     {
         if(x == nan_code)
-            return fNeg0;
-        if(exponent == ((1 << f8_exp) - 1))
-            return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN;
+            return Neg0;
+        if(exponent == ((1 << in_exp) - 1))
+            return (mantissa == 0) ? (sign ? NegInf : Inf) : NaN;
+        if((NumericUtils<Y>::mant == 10) && (NumericUtils<X>::mant == 2) && !negative_zero_nan)
+        {
+            retval = x;
+            retval <<= 8;
+            return *(reinterpret_cast<const Y*>(&retval));
+        }
     }

     // subnormal input
     if(exponent == 0)
     {
         // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
-        int sh = 1 + __builtin_clz(mantissa) - ((1 + type_exp + type_mant) - f8_mant);
-        mantissa <<= sh;
-        exponent += 1 - sh;
-        mantissa &= ((1 << f8_mant) - 1);
+        exponent++;
+        while(mantissa < (1 << in_mant))
+        {
+            mantissa <<= 1;
+            exponent--;
+        }
+        mantissa &= ((1 << in_mant) - 1);
     }
     exponent += exp_low_cutoff - 1;
-    mantissa <<= type_mant - f8_mant;
+    mantissa <<= out_mant - in_mant;

     // subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
     if(exponent <= 0)
     {
-        mantissa |= 1 << type_mant;
+        mantissa |= 1 << out_mant;
         mantissa >>= 1 - exponent;
         exponent = 0;
     }

-    retval = (sign << (type_exp + type_mant)) | (exponent << type_mant) | mantissa;
-    return *(reinterpret_cast<const T*>(&retval));
+    retval = (sign << (out_exp + out_mant)) | (exponent << out_mant) | mantissa;
+    return *(reinterpret_cast<const Y*>(&retval));
 }

 } // namespace

-template <typename T, bool negative_zero_nan, bool clip, bool stoch>
-__host__ __device__ f8_t cast_to_f8(T x, uint32_t rng)
+template <typename X, typename Y, bool negative_zero_nan, bool clip, bool stoch>
+__host__ __device__ Y cast_to_f8(X x, uint32_t rng)
 {
     // check datatype
-    constexpr bool is_half  = std::is_same<T, half_t>::value;
-    constexpr bool is_float = std::is_same<T, float>::value;
-    static_assert(is_half || is_float, "Only half and float can be casted to f8.");
-    return run_cast_to_f8<T, negative_zero_nan, clip, stoch>(x, rng);
+    constexpr bool is_half  = std::is_same<X, half_t>::value;
+    constexpr bool is_float = std::is_same<X, float>::value;
+    static_assert(is_half || is_float, "Only half and float can be casted.");
+    return run_cast_to_f8<X, Y, negative_zero_nan, clip, stoch>(x, rng);
 }

-template <typename T, bool negative_zero_nan>
-__host__ __device__ T cast_from_f8(f8_t x)
+template <typename X, typename Y, bool negative_zero_nan>
+__host__ __device__ Y cast_from_f8(X x)
 {
     // check datatype
-    constexpr bool is_half  = std::is_same<T, half_t>::value;
-    constexpr bool is_float = std::is_same<T, float>::value;
+    constexpr bool is_half  = std::is_same<Y, half_t>::value;
+    constexpr bool is_float = std::is_same<Y, float>::value;
     static_assert(is_half || is_float, "only half and float are supported.");

-    // check if x is 0.0
-    if(x == 0)
-        return static_cast<T>(0);
-    return run_cast_from_f8<T, negative_zero_nan>(x);
+    return run_cast_from_f8<X, Y, negative_zero_nan>(x);
 }

 } // namespace ck::utils
+
+#endif // #if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
+#endif // #if !defined(__gfx940__) && !defined(__gfx941__) && !defined(__gfx942__)
include/ck/utility/inner_product.hpp
View file @ a8629a98

@@ -72,6 +72,18 @@ inner_product<float4_t, float4_t, float>(const float4_t& a, const float4_t& b, f
     c);
 }

+template <>
+__device__ void inner_product<bhalf_t, bhalf_t, float>(const bhalf_t& a, const bhalf_t& b, float& c)
+{
+    inner_product(type_convert<float>(a), type_convert<float>(b), c);
+}
+
+template <>
+__device__ void inner_product<half_t, half_t, float>(const half_t& a, const half_t& b, float& c)
+{
+    inner_product(type_convert<float>(a), type_convert<float>(b), c);
+}
+
 template <>
 __device__ void inner_product<half2_t, half2_t, float>(const half2_t& a, const half2_t& b, float& c)
 {
include/ck/utility/inner_product_dpp8.hpp
View file @ a8629a98

@@ -2,6 +2,7 @@
 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
 #include "amd_gemm_dpp.hpp"
 #include "data_type.hpp"
 #include "type_convert.hpp"
@@ -10,6 +11,9 @@ namespace ck {

 namespace dpp8 {

+/// Number of lanes that can share data using DPP8 modifiers.
+constexpr index_t lane_group_size = 8;
+
 template <int SrcLaneIdx>
 __device__ void inline_v_dot2c_dpp8_instr(const half2_t& a, const half2_t& b, float& c);
include/ck/utility/loop_scheduler.hpp
0 → 100644
View file @ a8629a98

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"

namespace ck {

enum struct LoopScheduler
{
    Default,
    Interwave,
};

constexpr LoopScheduler make_default_loop_scheduler()
{
#if CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING
    return LoopScheduler::Interwave;
#else
    return LoopScheduler::Default;
#endif // if CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING
}

} // namespace ck
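A minimal illustration of how a kernel template might consume this enum; the wrapper below is hypothetical and only LoopScheduler and make_default_loop_scheduler() come from the header above:

    template <ck::LoopScheduler Scheduler = ck::make_default_loop_scheduler()>
    struct MyGridwiseGemm
    {
        // pick the pipeline / scheduling variant at compile time
        static constexpr bool UseInterwave = (Scheduler == ck::LoopScheduler::Interwave);
    };

Because make_default_loop_scheduler() is constexpr, the CK_EXPERIMENTAL_DEFAULT_TO_INTER_WAVE_SCHEDULING macro switches the default template argument at compile time without touching call sites.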