Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
30d8a363
Commit
30d8a363
authored
Jul 31, 2023
by
Rostyslav Geyyer
Browse files
Enable f16/f8 mixed precision
parent
844b215d
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
37 additions
and
21 deletions
+37
-21
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp
...nsor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp
+37
-21
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp
View file @
30d8a363
...
@@ -35,13 +35,17 @@ __global__ void
...
@@ -35,13 +35,17 @@ __global__ void
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
}
template
<
typename
GridwiseGemm
,
typename
FloatAB
,
typename
FloatC
,
bool
HasMainKBlockLoop
>
template
<
typename
GridwiseGemm
,
typename
FloatA
,
typename
FloatB
,
typename
FloatC
,
bool
HasMainKBlockLoop
>
__global__
void
__global__
void
#if CK_USE_LAUNCH_BOUNDS
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
#endif
kernel_gemm_xdl_cshuffle_v1
(
const
FloatA
B
*
__restrict__
p_a_grid
,
kernel_gemm_xdl_cshuffle_v1
(
const
FloatA
*
__restrict__
p_a_grid
,
const
Float
A
B
*
__restrict__
p_b_grid
,
const
FloatB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
FloatC
*
__restrict__
p_c_grid
,
typename
GridwiseGemm
::
Problem
problem
)
typename
GridwiseGemm
::
Problem
problem
)
{
{
...
@@ -61,7 +65,8 @@ __global__ void
...
@@ -61,7 +65,8 @@ __global__ void
template
<
typename
ALayout
,
template
<
typename
ALayout
,
typename
BLayout
,
typename
BLayout
,
typename
CLayout
,
typename
CLayout
,
typename
FloatAB
,
typename
FloatA
,
typename
FloatB
,
typename
FloatGemmAcc
,
typename
FloatGemmAcc
,
typename
FloatCShuffle
,
typename
FloatCShuffle
,
typename
FloatC
,
typename
FloatC
,
...
@@ -102,7 +107,8 @@ template <typename ALayout,
...
@@ -102,7 +107,8 @@ template <typename ALayout,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
LoopScheduler
LoopSched
,
LoopScheduler
LoopSched
,
PipelineVersion
PipelineVer
=
PipelineVersion
::
v1
>
PipelineVersion
PipelineVer
=
PipelineVersion
::
v1
,
bool
EnableMixedPrecision
=
false
>
struct
GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
struct
GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
{
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
...
@@ -122,6 +128,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
...
@@ -122,6 +128,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
// convert f8 to f16 if mixed precision enabled
using
FloatAAdjusted
=
conditional_t
<
EnableMixedPrecision
&&
is_same_v
<
FloatA
,
ck
::
f8_t
>
,
ck
::
half_t
,
FloatA
>
;
using
FloatBAdjusted
=
conditional_t
<
EnableMixedPrecision
&&
is_same_v
<
FloatB
,
ck
::
f8_t
>
,
ck
::
half_t
,
FloatB
>
;
__host__
static
auto
CalculateGridSize
(
index_t
M
,
index_t
N
)
__host__
static
auto
CalculateGridSize
(
index_t
M
,
index_t
N
)
{
{
return
std
::
make_tuple
(
Block2CTileMap
::
CalculateGridSize
(
M
,
N
),
1
,
1
);
return
std
::
make_tuple
(
Block2CTileMap
::
CalculateGridSize
(
M
,
N
),
1
,
1
);
...
@@ -463,8 +475,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
...
@@ -463,8 +475,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// Argument
// Argument
struct
Argument
:
public
tensor_operation
::
device
::
BaseArgument
,
public
Problem
struct
Argument
:
public
tensor_operation
::
device
::
BaseArgument
,
public
Problem
{
{
__host__
Argument
(
const
FloatA
B
*
p_a_grid_
,
__host__
Argument
(
const
FloatA
*
p_a_grid_
,
const
Float
A
B
*
p_b_grid_
,
const
FloatB
*
p_b_grid_
,
FloatC
*
p_c_grid_
,
FloatC
*
p_c_grid_
,
index_t
M_
,
index_t
M_
,
index_t
N_
,
index_t
N_
,
...
@@ -479,8 +491,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
...
@@ -479,8 +491,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
{
{
}
}
const
FloatA
B
*
p_a_grid
;
const
FloatA
*
p_a_grid
;
const
Float
A
B
*
p_b_grid
;
const
FloatB
*
p_b_grid
;
FloatC
*
p_c_grid
;
FloatC
*
p_c_grid
;
};
};
...
@@ -541,8 +553,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
...
@@ -541,8 +553,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
constexpr
auto
c_block_size
=
constexpr
auto
c_block_size
=
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
();
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
.
GetElementSpaceSize
();
return
math
::
max
((
a_block_space_size_aligned
+
b_block_space_size_align
ed
)
*
return
math
::
max
((
a_block_space_size_aligned
*
sizeof
(
FloatAAdjust
ed
)
+
sizeof
(
Float
A
B
),
b_block_space_size_aligned
*
sizeof
(
FloatB
Adjusted
)
),
c_block_size
*
sizeof
(
FloatCShuffle
));
c_block_size
*
sizeof
(
FloatCShuffle
));
}
}
...
@@ -676,8 +688,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
...
@@ -676,8 +688,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
using
Block2CTileMap
=
BlockToCTileMap_M00_N0_M01Adapt
<
MPerBlock
,
NPerBlock
>
;
using
Block2CTileMap
=
BlockToCTileMap_M00_N0_M01Adapt
<
MPerBlock
,
NPerBlock
>
;
template
<
bool
HasMainKBlockLoop
>
template
<
bool
HasMainKBlockLoop
>
__device__
static
void
Run
(
const
FloatA
B
*
__restrict__
p_a_grid
,
__device__
static
void
Run
(
const
FloatA
*
__restrict__
p_a_grid
,
const
Float
A
B
*
__restrict__
p_b_grid
,
const
FloatB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
FloatC
*
__restrict__
p_c_grid
,
void
*
__restrict__
p_shared
,
void
*
__restrict__
p_shared
,
const
Problem
&
problem
)
const
Problem
&
problem
)
...
@@ -743,8 +755,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
...
@@ -743,8 +755,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
Sequence
<
AK0Number
,
MPerBlock
,
AK1Number
>
,
Sequence
<
AK0Number
,
MPerBlock
,
AK1Number
>
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferThreadClusterArrangeOrder
,
FloatA
B
,
FloatA
,
FloatA
B
,
FloatA
Adjusted
,
decltype
(
a_grid_desc_ak0_m_ak1
),
decltype
(
a_grid_desc_ak0_m_ak1
),
decltype
(
a_block_desc_ak0_m_ak1
),
decltype
(
a_block_desc_ak0_m_ak1
),
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcAccessOrder
,
...
@@ -774,8 +786,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
...
@@ -774,8 +786,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
Sequence
<
BK0Number
,
NPerBlock
,
BK1Number
>
,
Sequence
<
BK0Number
,
NPerBlock
,
BK1Number
>
,
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferThreadClusterArrangeOrder
,
BBlockTransferThreadClusterArrangeOrder
,
Float
A
B
,
FloatB
,
Float
A
B
,
FloatB
Adjusted
,
decltype
(
b_grid_desc_bk0_n_bk1
),
decltype
(
b_grid_desc_bk0_n_bk1
),
decltype
(
b_block_desc_bk0_n_bk1
),
decltype
(
b_block_desc_bk0_n_bk1
),
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcAccessOrder
,
...
@@ -796,6 +808,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
...
@@ -796,6 +808,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
make_multi_index
(
0
,
0
,
0
),
make_multi_index
(
0
,
0
,
0
),
ck
::
tensor_operation
::
element_wise
::
PassThrough
{});
ck
::
tensor_operation
::
element_wise
::
PassThrough
{});
// check if A and B types are same
static_assert
(
std
::
is_same
<
FloatAAdjusted
,
FloatBAdjusted
>::
value
,
"The A and B types do not match!"
);
// GEMM definition
// GEMM definition
// c_mtx += transpose(a_mtx) * b_mtx
// c_mtx += transpose(a_mtx) * b_mtx
// a_mtx[K0PerBlock, MPerBlock] is in LDS
// a_mtx[K0PerBlock, MPerBlock] is in LDS
...
@@ -805,11 +821,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
...
@@ -805,11 +821,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// sanity check
// sanity check
constexpr
index_t
KPack
=
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1Number
,
BK1Number
),
math
::
max
(
math
::
lcm
(
AK1Number
,
BK1Number
),
MfmaSelector
<
FloatA
B
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
MfmaSelector
<
FloatA
Adjusted
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector
<
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector
<
BlockSize
,
BlockSize
,
FloatA
B
,
FloatA
Adjusted
,
FloatGemmAcc
,
FloatGemmAcc
,
decltype
(
a_block_desc_ak0_m_ak1
),
decltype
(
a_block_desc_ak0_m_ak1
),
decltype
(
b_block_desc_bk0_n_bk1
),
decltype
(
b_block_desc_bk0_n_bk1
),
...
@@ -827,10 +843,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
...
@@ -827,10 +843,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
a_block_desc_ak0_m_ak1
.
GetElementSpaceSize
(),
max_lds_align
);
a_block_desc_ak0_m_ak1
.
GetElementSpaceSize
(),
max_lds_align
);
auto
a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
auto
a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
FloatA
B
*>
(
p_shared
),
a_block_desc_ak0_m_ak1
.
GetElementSpaceSize
());
static_cast
<
FloatA
Adjusted
*>
(
p_shared
),
a_block_desc_ak0_m_ak1
.
GetElementSpaceSize
());
auto
b_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
auto
b_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static_cast
<
Float
A
B
*>
(
p_shared
)
+
a_block_space_size_aligned
,
static_cast
<
FloatB
Adjusted
*>
(
p_shared
)
+
a_block_space_size_aligned
,
b_block_desc_bk0_n_bk1
.
GetElementSpaceSize
());
b_block_desc_bk0_n_bk1
.
GetElementSpaceSize
());
constexpr
auto
a_block_slice_copy_step
=
make_multi_index
(
KPerBlock
/
AK1Number
,
0
,
0
);
constexpr
auto
a_block_slice_copy_step
=
make_multi_index
(
KPerBlock
/
AK1Number
,
0
,
0
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment