Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
e0b473b6
Commit
e0b473b6
authored
Oct 12, 2024
by
aska-0096
Browse files
Update warpshuffle
parent
4ee40bcc
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
153 additions
and
113 deletions
+153
-113
example/66_complex_contraction_bilinear/run_complex_contraction_bilinear_example.inc
...ion_bilinear/run_complex_contraction_bilinear_example.inc
+110
-113
include/ck_tile/core/arch/utility.hpp
include/ck_tile/core/arch/utility.hpp
+43
-0
No files found.
example/66_complex_contraction_bilinear/run_complex_contraction_bilinear_example.inc
100755 → 100644
View file @
e0b473b6
...
...
@@ -154,17 +154,20 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
DeviceMem
a_device_buf_re
(
sizeof
(
ADataType
)
*
a_ms_ks_re
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
b_device_buf_re
(
sizeof
(
BDataType
)
*
b_ns_ks_re
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
d_device_buf_re
(
sizeof
(
DDataType
)
*
d_ms_ns_re
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
e_device_buf_re
(
sizeof
(
EDataType
)
*
e_ms_ns_device_result_re
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
e_device_buf_re
(
sizeof
(
EDataType
)
*
e_ms_ns_device_result_re
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
a_device_buf_img
(
sizeof
(
ADataType
)
*
a_ms_ks_img
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
b_device_buf_img
(
sizeof
(
BDataType
)
*
b_ns_ks_img
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
d_device_buf_img
(
sizeof
(
DDataType
)
*
d_ms_ns_img
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
e_device_buf_img
(
sizeof
(
EDataType
)
*
e_ms_ns_device_result_img
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
e_device_buf_img
(
sizeof
(
EDataType
)
*
e_ms_ns_device_result_img
.
mDesc
.
GetElementSpaceSize
());
// Intermediate Value For E Real and Img
DeviceMem
e_device_buf_re1
(
sizeof
(
EDataType
)
*
e_ms_ns_device_result_re
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
e_device_buf_img1
(
sizeof
(
EDataType
)
*
e_ms_ns_device_result_img
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
e_device_buf_re1
(
sizeof
(
EDataType
)
*
e_ms_ns_device_result_re
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
e_device_buf_img1
(
sizeof
(
EDataType
)
*
e_ms_ns_device_result_img
.
mDesc
.
GetElementSpaceSize
());
a_device_buf_re
.
ToDevice
(
a_ms_ks_re
.
mData
.
data
());
b_device_buf_re
.
ToDevice
(
b_ns_ks_re
.
mData
.
data
());
...
...
@@ -191,7 +194,8 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
auto
op
=
DeviceOpInstance
{};
auto
invoker
=
op
.
MakeInvoker
();
auto
argument_re1
=
op
.
MakeArgument
(
a_device_buf_re
.
GetDeviceBuffer
(),
auto
argument_re1
=
op
.
MakeArgument
(
a_device_buf_re
.
GetDeviceBuffer
(),
b_device_buf_re
.
GetDeviceBuffer
(),
std
::
array
<
const
void
*
,
1
>
{
d_device_buf_re
.
GetDeviceBuffer
()},
e_device_buf_re1
.
GetDeviceBuffer
(),
...
...
@@ -216,7 +220,6 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
float
ave_time_re1
=
invoker
.
Run
(
argument_re1
,
StreamConfig
{
nullptr
,
time_kernel
});
alpha
=
-
1.
f
;
beta
=
1.
f
;
...
...
@@ -228,7 +231,8 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
// For real Intermediate Value re_2
// auto op = DeviceOpInstance{};
// auto invoker = op.MakeInvoker();
auto
argument_re2
=
op
.
MakeArgument
(
a_device_buf_img
.
GetDeviceBuffer
(),
auto
argument_re2
=
op
.
MakeArgument
(
a_device_buf_img
.
GetDeviceBuffer
(),
b_device_buf_img
.
GetDeviceBuffer
(),
std
::
array
<
const
void
*
,
1
>
{
e_device_buf_re1
.
GetDeviceBuffer
()},
e_device_buf_re
.
GetDeviceBuffer
(),
...
...
@@ -253,7 +257,6 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
float
ave_time_re2
=
invoker
.
Run
(
argument_re2
,
StreamConfig
{
nullptr
,
time_kernel
});
alpha
=
1.
f
;
beta
=
1.
f
;
...
...
@@ -261,7 +264,8 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
b_element_op
=
BElementOp
{};
cde_element_op
=
CDEElementOp
{
alpha
,
beta
};
auto
argument_img1
=
op
.
MakeArgument
(
a_device_buf_re
.
GetDeviceBuffer
(),
auto
argument_img1
=
op
.
MakeArgument
(
a_device_buf_re
.
GetDeviceBuffer
(),
b_device_buf_img
.
GetDeviceBuffer
(),
std
::
array
<
const
void
*
,
1
>
{
d_device_buf_img
.
GetDeviceBuffer
()},
e_device_buf_img1
.
GetDeviceBuffer
(),
...
...
@@ -277,7 +281,6 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
b_element_op
,
cde_element_op
);
if
(
!
op
.
IsSupportedArgument
(
argument_img1
))
{
std
::
cout
<<
op
.
GetTypeString
()
<<
" does not support this problem"
<<
std
::
endl
;
...
...
@@ -290,7 +293,8 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
alpha
=
1.
f
;
beta
=
1.
f
;
auto
argument_img2
=
op
.
MakeArgument
(
a_device_buf_img
.
GetDeviceBuffer
(),
auto
argument_img2
=
op
.
MakeArgument
(
a_device_buf_img
.
GetDeviceBuffer
(),
b_device_buf_re
.
GetDeviceBuffer
(),
std
::
array
<
const
void
*
,
1
>
{
e_device_buf_img1
.
GetDeviceBuffer
()},
e_device_buf_img
.
GetDeviceBuffer
(),
...
...
@@ -306,8 +310,6 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
b_element_op
,
cde_element_op
);
if
(
!
op
.
IsSupportedArgument
(
argument_img2
))
{
std
::
cout
<<
op
.
GetTypeString
()
<<
" does not support this problem"
<<
std
::
endl
;
...
...
@@ -317,7 +319,6 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
float
ave_time_img2
=
invoker
.
Run
(
argument_img2
,
StreamConfig
{
nullptr
,
time_kernel
});
ck
::
index_t
M
=
ck
::
accumulate_n
<
ck
::
index_t
>
(
e_ms_ns_lengths
.
begin
(),
NumDimM
,
1
,
std
::
multiplies
<>
{});
...
...
@@ -331,7 +332,7 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
std
::
size_t
num_btype
=
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
BDataType
)
*
K
*
N
+
sizeof
(
DDataType
)
*
M
*
N
+
sizeof
(
EDataType
)
*
M
*
N
*
2
;
float
ave_time
=
ave_time_img2
+
ave_time_img1
+
ave_time_re2
+
ave_time_re1
;
float
ave_time
=
ave_time_img2
+
ave_time_img1
+
ave_time_re2
+
ave_time_re1
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
ave_time
;
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
...
...
@@ -366,8 +367,8 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
auto
ref_op
=
ReferenceOpInstance
{};
auto
ref_invoker
=
ref_op
.
MakeInvoker
();
auto
ref_argument_re
=
ref_op
.
MakeArgument
(
a_ms_ks_re
,
b_ns_ks_re
,
c_ms_ns_host_result_re
,
a_element_op
,
b_element_op
);
auto
ref_argument_re
=
ref_op
.
MakeArgument
(
a_ms_ks_re
,
b_ns_ks_re
,
c_ms_ns_host_result_re
,
a_element_op
,
b_element_op
);
ref_invoker
.
Run
(
ref_argument_re
);
...
...
@@ -376,7 +377,6 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
cde_element_op
=
CDEElementOp
{
alpha
,
beta
};
for
(
size_t
m0
=
0
;
m0
<
e_ms_ns_host_result_re
.
mDesc
.
GetLengths
()[
0
];
++
m0
)
{
for
(
size_t
m1
=
0
;
m1
<
e_ms_ns_host_result_re
.
mDesc
.
GetLengths
()[
1
];
++
m1
)
...
...
@@ -398,8 +398,8 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
cde_element_op
=
CDEElementOp
{
alpha
,
beta
};
auto
ref_argument_re1
=
ref_op
.
MakeArgument
(
a_ms_ks_img
,
b_ns_ks_img
,
c_ms_ns_host_result_re1
,
a_element_op
,
b_element_op
);
auto
ref_argument_re1
=
ref_op
.
MakeArgument
(
a_ms_ks_img
,
b_ns_ks_img
,
c_ms_ns_host_result_re1
,
a_element_op
,
b_element_op
);
ref_invoker
.
Run
(
ref_argument_re1
);
...
...
@@ -421,15 +421,12 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
isRealOk
=
ck
::
utils
::
check_err
(
e_ms_ns_device_result_re
,
e_ms_ns_host_result_re
)
?
0
:
1
;
// Img Part Verification
Tensor
<
CShuffleDataType
>
c_ms_ns_host_result_img
(
e_ms_ns_lengths
,
e_ms_ns_strides
);
Tensor
<
CShuffleDataType
>
c_ms_ns_host_result_img1
(
e_ms_ns_lengths
,
e_ms_ns_strides
);
auto
ref_argument_img
=
ref_op
.
MakeArgument
(
a_ms_ks_re
,
b_ns_ks_img
,
c_ms_ns_host_result_img
,
a_element_op
,
b_element_op
);
auto
ref_argument_img
=
ref_op
.
MakeArgument
(
a_ms_ks_re
,
b_ns_ks_img
,
c_ms_ns_host_result_img
,
a_element_op
,
b_element_op
);
ref_invoker
.
Run
(
ref_argument_img
);
...
...
@@ -454,8 +451,8 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
}
}
auto
ref_argument_img1
=
ref_op
.
MakeArgument
(
a_ms_ks_img
,
b_ns_ks_re
,
c_ms_ns_host_result_img1
,
a_element_op
,
b_element_op
);
auto
ref_argument_img1
=
ref_op
.
MakeArgument
(
a_ms_ks_img
,
b_ns_ks_re
,
c_ms_ns_host_result_img1
,
a_element_op
,
b_element_op
);
ref_invoker
.
Run
(
ref_argument_img1
);
...
...
include/ck_tile/core/arch/utility.hpp
View file @
e0b473b6
...
...
@@ -59,4 +59,47 @@ CK_TILE_DEVICE T warp_shuffle_down(const T& v_local, uint32_t lane_delta)
#endif
}
template
<
typename
T
>
CK_TILE_DEVICE
T
warp_shuffle
(
const
T
&
v_local
,
uint32_t
src_lane
)
{
#if 0
return __shfl(v_local, src_lane);
#elif
1
if
constexpr
(
sizeof
(
int32_t
)
>
sizeof
(
T
))
{
union
packet
{
int32_t
x
;
T
v
;
};
packet
p
;
p
.
v
=
v_local
;
packet
p_remote
;
p_remote
.
x
=
__builtin_amdgcn_ds_bpermute
(
src_lane
<<
2
,
bit_cast
<
int32_t
>
(
p
));
return
p_remote
.
v
;
}
else
if
constexpr
(
sizeof
(
int32_t
)
==
sizeof
(
T
))
{
const
int32_t
v_remote_tmp
=
__builtin_amdgcn_ds_bpermute
(
src_lane
<<
2
,
bit_cast
<
int32_t
>
(
v_local
));
return
bit_cast
<
T
>
(
v_remote_tmp
);
}
else
{
static_assert
(
sizeof
(
T
)
%
sizeof
(
int32_t
)
==
0
,
"wrong!"
);
constexpr
index_t
elm
=
sizeof
(
T
)
/
sizeof
(
int32_t
);
using
vector_type
=
thread_buffer
<
int32_t
,
elm
>
;
auto
vs
=
bit_cast
<
vector_type
>
(
v_local
);
auto
vs_remote
=
vector_type
{};
static_for
<
0
,
elm
,
1
>
{}([
&
](
auto
i_e
)
{
int32_t
tmp
=
__builtin_amdgcn_ds_bpermute
(
src_lane
<<
2
,
bit_cast
<
int32_t
>
(
vs
[
i_e
]));
vs_remote
(
i_e
)
=
tmp
;
});
return
bit_cast
<
T
>
(
vs_remote
);
}
#endif
}
}
// namespace ck_tile
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment