Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
38504cf4
Commit
38504cf4
authored
May 17, 2023
by
Po-Yen, Chen
Browse files
Add readfirstlane() to copy content into SGPRs
parent
aafba9b4
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
47 additions
and
15 deletions
+47
-15
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
+15
-15
include/ck/utility/type.hpp
include/ck/utility/type.hpp
+32
-0
No files found.
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp
View file @
38504cf4
...
...
@@ -510,20 +510,20 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
__builtin_amdgcn_sched_barrier
(
0
);
#endif
const
auto
a_grid_desc_k0_m_k1
=
MakeAGridDescriptor_K0_M_K1
(
karg
.
M
,
karg
.
MPadded
,
karg
.
K
,
karg
.
K0
,
karg
.
StrideA
);
const
auto
b_grid_desc_k0_n_k1
=
MakeBGridDescriptor_K0_N_K1
(
karg
.
K
,
karg
.
N
,
karg
.
NPadded
,
karg
.
K0
,
karg
.
StrideB
);
const
auto
c_grid_desc_m_n
=
MakeCGridDescriptor_M_N
(
karg
.
M
,
karg
.
MPadded
,
karg
.
N
,
karg
.
NPadded
,
karg
.
StrideC
);
const
auto
a_grid_desc_k0_m_k1
=
readfirstlane
(
MakeAGridDescriptor_K0_M_K1
(
karg
.
M
,
karg
.
MPadded
,
karg
.
K
,
karg
.
K0
,
karg
.
StrideA
)
)
;
const
auto
b_grid_desc_k0_n_k1
=
readfirstlane
(
MakeBGridDescriptor_K0_N_K1
(
karg
.
K
,
karg
.
N
,
karg
.
NPadded
,
karg
.
K0
,
karg
.
StrideB
)
)
;
const
auto
c_grid_desc_m_n
=
readfirstlane
(
MakeCGridDescriptor_M_N
(
karg
.
M
,
karg
.
MPadded
,
karg
.
N
,
karg
.
NPadded
,
karg
.
StrideC
)
)
;
const
auto
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
c_grid_desc_m_n
);
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
*
c_grid_desc_m_n
);
const
auto
a_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_a_grid
,
a_grid_desc_k0_m_k1
.
GetElementSpaceSize
());
p_a_grid
,
a_grid_desc_k0_m_k1
->
GetElementSpaceSize
());
const
auto
b_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_b_grid
,
b_grid_desc_k0_n_k1
.
GetElementSpaceSize
());
p_b_grid
,
b_grid_desc_k0_n_k1
->
GetElementSpaceSize
());
auto
c_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_c_grid
,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
.
GetElementSpaceSize
());
...
...
@@ -572,7 +572,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
ABlockTransferThreadClusterArrangeOrder
,
FloatAB
,
FloatABAdjusted
,
decltype
(
a_grid_desc_k0_m_k1
),
decltype
(
*
a_grid_desc_k0_m_k1
),
decltype
(
a_block_desc_k0_m_k1
),
ABlockTransferSrcAccessOrder
,
Sequence
<
1
,
0
,
2
>
,
...
...
@@ -585,7 +585,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
AThreadTransferSrcResetCoordinateAfterRun
,
true
,
NumGemmKPrefetchStage
>
(
a_grid_desc_k0_m_k1
,
*
a_grid_desc_k0_m_k1
,
make_multi_index
(
0
,
m_block_data_idx_on_grid
,
0
),
a_element_op
,
a_block_desc_k0_m_k1
,
...
...
@@ -603,7 +603,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
BBlockTransferThreadClusterArrangeOrder
,
FloatAB
,
FloatABAdjusted
,
decltype
(
b_grid_desc_k0_n_k1
),
decltype
(
*
b_grid_desc_k0_n_k1
),
decltype
(
b_block_desc_k0_n_k1
),
BBlockTransferSrcAccessOrder
,
Sequence
<
1
,
0
,
2
>
,
...
...
@@ -616,7 +616,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
BThreadTransferSrcResetCoordinateAfterRun
,
true
,
NumGemmKPrefetchStage
>
(
b_grid_desc_k0_n_k1
,
*
b_grid_desc_k0_n_k1
,
make_multi_index
(
0
,
n_block_data_idx_on_grid
,
0
),
b_element_op
,
b_block_desc_k0_n_k1
,
...
...
@@ -665,13 +665,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
#if ENABLE_DUMP_CLOCK
long
loop_start
=
0
,
loop_end
=
0
;
#endif
GridwiseGemmPipe
::
template
Run
<
HasMainKBlockLoop
>(
a_grid_desc_k0_m_k1
,
GridwiseGemmPipe
::
template
Run
<
HasMainKBlockLoop
>(
*
a_grid_desc_k0_m_k1
,
a_block_desc_k0_m_k1
,
a_blockwise_copy
,
a_grid_buf
,
a_block_buf
,
a_block_slice_copy_step
,
b_grid_desc_k0_n_k1
,
*
b_grid_desc_k0_n_k1
,
b_block_desc_k0_n_k1
,
b_blockwise_copy
,
b_grid_buf
,
...
...
include/ck/utility/type.hpp
View file @
38504cf4
...
...
@@ -57,4 +57,36 @@ __host__ __device__ constexpr Y bit_cast(const X& x)
#endif
}
namespace
detail
{
template
<
typename
T
>
struct
sgpr_ptr
{
static_assert
(
!
std
::
is_const_v
<
T
>
&&
!
std
::
is_reference_v
<
T
>
&&
std
::
is_trivially_copyable_v
<
T
>
);
__device__
explicit
sgpr_ptr
(
const
T
&
obj
)
noexcept
{
/// TODO: copy object content into member 'memory' by __builtin_amdgcn_readfirstlane()
__builtin_memcpy
(
memory
,
&
obj
,
sizeof
(
obj
));
}
__device__
T
&
operator
*
()
{
return
*
(
this
->
operator
->
());
}
__device__
const
T
&
operator
*
()
const
{
return
*
(
this
->
operator
->
());
}
__device__
T
*
operator
->
()
{
return
reinterpret_cast
<
T
*>
(
memory
);
}
__device__
const
T
*
operator
->
()
const
{
return
reinterpret_cast
<
const
T
*>
(
memory
);
}
private:
alignas
(
T
)
unsigned
char
memory
[
sizeof
(
T
)
+
3
];
};
}
// namespace detail
template
<
typename
T
>
__device__
constexpr
auto
readfirstlane
(
const
T
&
obj
)
{
return
detail
::
sgpr_ptr
<
T
>
(
obj
);
}
}
// namespace ck
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment