gaoqiong / composable_kernel · Commits
"...composable_kernel_rocm.git" did not exist on "64350affc5767e7ce3fb211d8145b5c9d18017d8"
Commit bed6f33c, authored May 04, 2023 by Po-Yen, Chen (parent 2a43fc3b)

Move argument field computing logic into device op side
Showing 2 changed files with 51 additions and 29 deletions.
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp (+20 -2)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp (+31 -27)
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp
@@ -209,7 +209,20 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
                              index_t StrideB,
                              index_t StrideC)
     {
-        return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC};
+        return Argument{p_a,
+                        p_b,
+                        p_c,
+                        M,
+                        N,
+                        K,
+                        StrideA,
+                        StrideB,
+                        StrideC,
+                        GridwiseGemm::CalculateMPadded(M),
+                        GridwiseGemm::CalculateNPadded(N),
+                        GridwiseGemm::CalculateKPadded(K),
+                        GridwiseGemm::CalculateAK0(K),
+                        GridwiseGemm::CalculateBK0(K)};
     }

     static auto MakeInvoker() { return Invoker{}; }

@@ -236,7 +249,12 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
                                             K,
                                             StrideA,
                                             StrideB,
-                                            StrideC);
+                                            StrideC,
+                                            GridwiseGemm::CalculateMPadded(M),
+                                            GridwiseGemm::CalculateNPadded(N),
+                                            GridwiseGemm::CalculateKPadded(K),
+                                            GridwiseGemm::CalculateAK0(K),
+                                            GridwiseGemm::CalculateBK0(K));
     }

     // polymorphic
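Both construction sites in this file (the MakeArgument body above and the call in the @@ -236,7 +249,12 @@ hunk) now forward the same five derived quantities, so there is a single place, the device op, where they are computed. Their exact formulas live in gridwise_gemm_xdl_cshuffle_v1.hpp and are not part of this diff; the sketch below uses one plausible reading (round K up to a tile multiple, then split it as K0 x K1) purely to illustrate how the values relate, so treat the tile sizes and formulas as assumptions.

#include <cassert>
#include <cstdint>

using index_t = std::int32_t;

// Assumed tile parameters, analogous to the KPerBlock / AK1Value template arguments.
constexpr index_t KPerBlock = 32;
constexpr index_t AK1       = 8;

// Assumed helpers, standing in for GridwiseGemm::CalculateKPadded / CalculateAK0.
index_t CalculateKPadded(index_t K) { return (K + KPerBlock - 1) / KPerBlock * KPerBlock; }
index_t CalculateAK0(index_t K) { return CalculateKPadded(K) / AK1; }

int main()
{
    const index_t K       = 100;
    const index_t KPadded = CalculateKPadded(K); // 128: K rounded up to a KPerBlock multiple
    const index_t AK0     = CalculateAK0(K);     // 16: the K0 extent when K1 (AK1) is 8
    assert(AK0 * AK1 == KPadded);                // the K0 x K1 split covers the padded K
    return 0;
}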
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp
@@ -91,10 +91,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
     static constexpr auto I7 = Number<7>{};

     // K1 should be Number<...>
-    static constexpr auto AK0_ = Number<KPerBlock / AK1Value>{};
-    static constexpr auto BK0_ = Number<KPerBlock / BK1Value>{};
-    static constexpr auto AK1_ = Number<AK1Value>{};
-    static constexpr auto BK1_ = Number<BK1Value>{};
+    static constexpr auto AK0_c = Number<KPerBlock / AK1Value>{};
+    static constexpr auto BK0_c = Number<KPerBlock / BK1Value>{};
+    static constexpr auto AK1_c = Number<AK1Value>{};
+    static constexpr auto BK1_c = Number<BK1Value>{};

     using ThisThreadBlock = ThisThreadBlock<BlockSize>;

@@ -398,7 +398,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                  index_t K_,
                  index_t StrideA_,
                  index_t StrideB_,
-                 index_t StrideC_)
+                 index_t StrideC_,
+                 index_t MPadded_,
+                 index_t NPadded_,
+                 index_t KPadded_,
+                 index_t AK0_,
+                 index_t BK0_)
             : p_a_grid{p_a_grid_},
               p_b_grid{p_b_grid_},
               p_c_grid{p_c_grid_},

@@ -408,17 +413,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
               StrideA{StrideA_},
               StrideB{StrideB_},
               StrideC{StrideC_},
-              MPadded{CalculateMPadded(M_)},
-              NPadded{CalculateNPadded(N_)},
-              KPadded{CalculateKPadded(K_)},
-              AK0{CalculateAK0(K_)},
-              BK0{CalculateBK0(K_)},
-              a_grid_desc_ak0_m_ak1{MakeAGridDescriptor_AK0_M_AK1(
-                  M_, CalculateMPadded(M_), K_, CalculateKPadded(K_), StrideA_, CalculateAK0(K_))},
-              b_grid_desc_bk0_n_bk1{MakeBGridDescriptor_BK0_N_BK1(
-                  K_, CalculateKPadded(K_), N_, CalculateNPadded(N_), StrideB_, CalculateBK0(K_))},
-              c_grid_desc_m_n{MakeCGridDescriptor_M_N(
-                  M_, CalculateMPadded(M_), N_, CalculateNPadded(N_), StrideC_)}
+              MPadded{MPadded_},
+              NPadded{NPadded_},
+              KPadded{KPadded_},
+              AK0{AK0_},
+              BK0{BK0_},
+              a_grid_desc_ak0_m_ak1{
+                  MakeAGridDescriptor_AK0_M_AK1(M_, MPadded_, K_, KPadded_, StrideA_, AK0_)},
+              b_grid_desc_bk0_n_bk1{
+                  MakeBGridDescriptor_BK0_N_BK1(K_, KPadded_, N_, NPadded_, StrideB_, BK0_)},
+              c_grid_desc_m_n{MakeCGridDescriptor_M_N(M_, MPadded_, N_, NPadded_, StrideC_)}
         {
         }

@@ -470,16 +474,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
     {
         // A matrix in LDS memory, dst of blockwise copy
         return make_naive_tensor_descriptor(
-            make_tuple(AK0_, Number<MPerBlock>{}, AK1_),
-            make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1_, AK1_, I1));
+            make_tuple(AK0_c, Number<MPerBlock>{}, AK1_c),
+            make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1_c, AK1_c, I1));
     }

     __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
     {
         // B matrix in LDS memory, dst of blockwise copy
         return make_naive_tensor_descriptor(
-            make_tuple(BK0_, Number<NPerBlock>{}, BK1_),
-            make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1_, BK1_, I1));
+            make_tuple(BK0_c, Number<NPerBlock>{}, BK1_c),
+            make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1_c, BK1_c, I1));
     }

     __host__ __device__ static constexpr auto

@@ -505,7 +509,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

         // lds max alignment
-        constexpr auto max_lds_align = math::lcm(AK1_, BK1_);
+        constexpr auto max_lds_align = math::lcm(AK1_c, BK1_c);

         constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
             a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

@@ -728,7 +732,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
             __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);

         // lds max alignment
-        constexpr auto max_lds_align = math::lcm(AK1_, BK1_);
+        constexpr auto max_lds_align = math::lcm(AK1_c, BK1_c);

         // A matrix in LDS memory, dst of blockwise copy
         constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();

@@ -742,7 +746,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
             AElementwiseOperation,
             ck::tensor_operation::element_wise::PassThrough,
             InMemoryDataOperationEnum::Set,
-            Sequence<AK0_, MPerBlock, AK1_>,
+            Sequence<AK0_c, MPerBlock, AK1_c>,
             ABlockTransferThreadClusterLengths_AK0_M_AK1,
             ABlockTransferThreadClusterArrangeOrder,
             FloatAB,

@@ -773,7 +777,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
             BElementwiseOperation,
             ck::tensor_operation::element_wise::PassThrough,
             InMemoryDataOperationEnum::Set,
-            Sequence<BK0_, NPerBlock, BK1_>,
+            Sequence<BK0_c, NPerBlock, BK1_c>,
             BBlockTransferThreadClusterLengths_BK0_N_BK1,
             BBlockTransferThreadClusterArrangeOrder,
             FloatAB,

@@ -806,7 +810,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
         // register
         // sanity check
         constexpr index_t KPack =
-            math::max(math::lcm(AK1_, BK1_),
+            math::max(math::lcm(AK1_c, BK1_c),
                       MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);

         auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<

@@ -835,8 +839,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
             static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
             b_block_desc_bk0_n_bk1.GetElementSpaceSize());

-        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1_, 0, 0);
-        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1_, 0, 0);
+        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1_c, 0, 0);
+        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1_c, 0, 0);

         // gridwise GEMM pipeline
         static_assert(std::is_default_constructible_v<GridwiseGemmPipe>);
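Alongside the new constructor parameters, the compile-time constants AK0_, BK0_, AK1_ and BK1_ are renamed to AK0_c, BK0_c, AK1_c and BK1_c throughout the file, presumably so that the Number<> constants cannot be confused with, or shadowed by, the new runtime index_t parameters that now use the AK0_/BK0_ spelling. Below is a small self-contained sketch of the shadowing this avoids; the names are illustrative and not taken from the real header.

#include <cstdint>
#include <iostream>

using index_t = std::int32_t;

struct ShadowingExample
{
    // Compile-time constant named like the old static member.
    static constexpr index_t AK0_ = 8;

    // A constructor parameter with the same name shadows the constant inside
    // the constructor, so "AK0_" here means the runtime value, not 8.
    explicit ShadowingExample(index_t AK0_) : AK0{AK0_} {}

    index_t AK0;
};

int main()
{
    ShadowingExample e{5};
    std::cout << e.AK0 << ' ' << ShadowingExample::AK0_ << '\n'; // prints "5 8"
    return 0;
}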