Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
571e8728
Commit
571e8728
authored
Dec 15, 2023
by
muozturk
Browse files
It used to work, but after the merge there is a problem
parent
75cf3655
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
32 additions
and
36 deletions
+32
-36
example/65_complex_contraction_scale/complex_contraction_scale_xdl_fp32.cpp
..._contraction_scale/complex_contraction_scale_xdl_fp32.cpp
+1
-0
example/65_complex_contraction_scale/run_complex_contraction_scale_example.inc
...ntraction_scale/run_complex_contraction_scale_example.inc
+31
-36
No files found.
example/65_complex_contraction_scale/complex_contraction_scale_xdl_fp32.cpp
View file @
571e8728
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
#include "ck/ck.hpp"
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "common_instances.hpp"
#include "common_instances.hpp"
...
...
example/65_complex_contraction_scale/run_complex_contraction_scale_example.inc
View file @
571e8728
...
@@ -15,7 +15,7 @@
...
@@ -15,7 +15,7 @@
#include "ck/library/utility/numeric.hpp"
#include "ck/library/utility/numeric.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_contraction.hpp"
int
run_complex_contraction_
bilinear
_example
(
int
argc
,
char
*
argv
[])
int
run_complex_contraction_
scale
_example
(
int
argc
,
char
*
argv
[])
{
{
bool
do_verification
=
true
;
bool
do_verification
=
true
;
int
init_method
=
1
;
int
init_method
=
1
;
...
@@ -159,10 +159,9 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
...
@@ -159,10 +159,9 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
DeviceMem
e_device_buf_img
(
sizeof
(
EDataType
)
*
e_ms_ns_device_result_img
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
e_device_buf_img
(
sizeof
(
EDataType
)
*
e_ms_ns_device_result_img
.
mDesc
.
GetElementSpaceSize
());
// // Intermediate Value For E Real and Img
// Intermediate Value For E Real and Img
// DeviceMem e_device_buf_re1(sizeof(EDataType) * e_ms_ns_device_result_re.mDesc.GetElementSpaceSize());
DeviceMem
e_device_buf_re1
(
sizeof
(
EDataType
)
*
e_ms_ns_device_result_re
.
mDesc
.
GetElementSpaceSize
());
// DeviceMem e_device_buf_img1(sizeof(EDataType) * e_ms_ns_device_result_img.mDesc.GetElementSpaceSize());
DeviceMem
e_device_buf_img1
(
sizeof
(
EDataType
)
*
e_ms_ns_device_result_img
.
mDesc
.
GetElementSpaceSize
());
a_device_buf_re
.
ToDevice
(
a_ms_ks_re
.
mData
.
data
());
a_device_buf_re
.
ToDevice
(
a_ms_ks_re
.
mData
.
data
());
b_device_buf_re
.
ToDevice
(
b_ns_ks_re
.
mData
.
data
());
b_device_buf_re
.
ToDevice
(
b_ns_ks_re
.
mData
.
data
());
...
@@ -175,25 +174,23 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
...
@@ -175,25 +174,23 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
// set zero
// set zero
e_device_buf_re
.
SetZero
();
e_device_buf_re
.
SetZero
();
e_device_buf_img
.
SetZero
();
e_device_buf_img
.
SetZero
();
e_device_buf_re1
.
SetZero
();
e_device_buf_img1
.
SetZero
();
// // set zero for intermediate values
// e_device_buf_re1.SetZero();
// e_device_buf_img1.SetZero();
auto
a_element_op
=
AElementOp
{};
auto
a_element_op
=
AElementOp
{};
auto
b_element_op
=
BElementOp
{};
auto
b_element_op
=
BElementOp
{};
auto
cde_element_op_scale
=
CDEElementOp_Scale
{
scale
};
auto
cde_element_op_scale
=
CDEElementOp_Scale
{
scale
};
// device operation
// device operation
//
C
_real = A_real * B_real
//
E1
_real = A_real * B_real
auto
op
=
DeviceOpInstance
{};
auto
op
_scale
=
DeviceOpInstance
_Scale
{};
auto
invoker
=
op
.
MakeInvoker
();
auto
invoker
_scale
=
op_scale
.
MakeInvoker
();
auto
argument_re1
=
op
.
MakeArgument
(
a_device_buf_re
.
GetDeviceBuffer
(),
auto
argument_re1
=
op
_scale
.
MakeArgument
(
a_device_buf_re
.
GetDeviceBuffer
(),
b_device_buf_re
.
GetDeviceBuffer
(),
b_device_buf_re
.
GetDeviceBuffer
(),
// std::array<const void*, 1>{d_device_buf_re.GetDeviceBuffer()},
std
::
array
<
const
void
*
,
0
>
{},
std
::
array
<
const
void
*
,
0
>
{},
e_device_buf_re
.
GetDeviceBuffer
(),
e_device_buf_re
1
.
GetDeviceBuffer
(),
a_ms_ks_lengths
,
a_ms_ks_lengths
,
a_ms_ks_strides
,
a_ms_ks_strides
,
b_ns_ks_lengths
,
b_ns_ks_lengths
,
...
@@ -206,29 +203,32 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
...
@@ -206,29 +203,32 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
b_element_op
,
b_element_op
,
cde_element_op_scale
);
cde_element_op_scale
);
if
(
!
op
.
IsSupportedArgument
(
argument_re1
))
if
(
!
op
_scale
.
IsSupportedArgument
(
argument_re1
))
{
{
std
::
cout
<<
op
.
GetTypeString
()
<<
" does not support this problem"
<<
std
::
endl
;
std
::
cout
<<
op
_scale
.
GetTypeString
()
<<
" does not support this problem"
<<
std
::
endl
;
return
0
;
return
0
;
}
}
float
ave_time_re1
=
invoker
.
Run
(
argument_re1
,
StreamConfig
{
nullptr
,
time_kernel
});
float
ave_time_re1
=
invoker
_scale
.
Run
(
argument_re1
,
StreamConfig
{
nullptr
,
time_kernel
});
alpha
=
-
1.
f
*
scale
;
alpha
=
-
1.
f
*
scale
;
beta
=
1.
f
;
beta
=
1.
f
;
a_element_op
=
AElementOp
{};
b_element_op
=
BElementOp
{};
auto
cde_element_op
=
CDEElementOp
{
alpha
,
beta
};
auto
cde_element_op
=
CDEElementOp
{
alpha
,
beta
};
// device operation
// For real Intermediate Value re_2
// For real Intermediate Value
// E_real = E1_real + A_img * B_img
auto
op
=
DeviceOpInstance
{}
;
auto
invoker
=
op
.
MakeInvoker
();
auto
argument_re2
=
op
.
MakeArgument
(
a_device_buf_img
.
GetDeviceBuffer
(),
auto
argument_re2
=
op
.
MakeArgument
(
a_device_buf_img
.
GetDeviceBuffer
(),
b_device_buf_img
.
GetDeviceBuffer
(),
b_device_buf_img
.
GetDeviceBuffer
(),
std
::
array
<
const
void
*
,
1
>
{
e_device_buf_re
.
GetDeviceBuffer
()},
std
::
array
<
const
void
*
,
1
>
{
e_device_buf_re
1
.
GetDeviceBuffer
()},
e_device_buf_re
.
GetDeviceBuffer
(),
e_device_buf_re
.
GetDeviceBuffer
(),
a_ms_ks_lengths
,
a_ms_ks_lengths
,
a_ms_ks_strides
,
a_ms_ks_strides
,
...
@@ -252,15 +252,9 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
...
@@ -252,15 +252,9 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
float
ave_time_re2
=
invoker
.
Run
(
argument_re2
,
StreamConfig
{
nullptr
,
time_kernel
});
float
ave_time_re2
=
invoker
.
Run
(
argument_re2
,
StreamConfig
{
nullptr
,
time_kernel
});
// scale = 1.f ;
auto
argument_img1
=
op_scale
.
MakeArgument
(
a_device_buf_re
.
GetDeviceBuffer
(),
// a_element_op = AElementOp{};
// b_element_op = BElementOp{};
// cde_element_op = CDEElementOp{alpha, beta};
auto
argument_img1
=
op
.
MakeArgument
(
a_device_buf_re
.
GetDeviceBuffer
(),
b_device_buf_img
.
GetDeviceBuffer
(),
b_device_buf_img
.
GetDeviceBuffer
(),
std
::
array
<
const
void
*
,
0
>
{}
}
,
std
::
array
<
const
void
*
,
0
>
{},
e_device_buf_img
.
GetDeviceBuffer
(),
e_device_buf_img
.
GetDeviceBuffer
(),
a_ms_ks_lengths
,
a_ms_ks_lengths
,
a_ms_ks_strides
,
a_ms_ks_strides
,
...
@@ -275,18 +269,20 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
...
@@ -275,18 +269,20 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
cde_element_op_scale
);
cde_element_op_scale
);
if
(
!
op
.
IsSupportedArgument
(
argument_img1
))
if
(
!
op
_scale
.
IsSupportedArgument
(
argument_img1
))
{
{
std
::
cout
<<
op
.
GetTypeString
()
<<
" does not support this problem"
<<
std
::
endl
;
std
::
cout
<<
op
_scale
.
GetTypeString
()
<<
" does not support this problem"
<<
std
::
endl
;
return
0
;
return
0
;
}
}
float
ave_time_img1
=
invoker
.
Run
(
argument_img1
,
StreamConfig
{
nullptr
,
time_kernel
});
float
ave_time_img1
=
invoker
_scale
.
Run
(
argument_img1
,
StreamConfig
{
nullptr
,
time_kernel
});
alpha
=
1.
f
*
scale
;
alpha
=
1.
f
*
scale
;
beta
=
1.
f
;
beta
=
1.
f
;
auto
argument_img2
=
op
.
MakeArgument
(
a_device_buf_img
.
GetDeviceBuffer
(),
auto
argument_img2
=
op
.
MakeArgument
(
a_device_buf_img
.
GetDeviceBuffer
(),
b_device_buf_re
.
GetDeviceBuffer
(),
b_device_buf_re
.
GetDeviceBuffer
(),
std
::
array
<
const
void
*
,
1
>
{
e_device_buf_img
.
GetDeviceBuffer
()},
std
::
array
<
const
void
*
,
1
>
{
e_device_buf_img
.
GetDeviceBuffer
()},
...
@@ -325,8 +321,7 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
...
@@ -325,8 +321,7 @@ int run_complex_contraction_bilinear_example(int argc, char* argv[])
a_ms_ks_lengths
.
begin
()
+
NumDimM
,
NumDimK
,
1
,
std
::
multiplies
<>
{});
a_ms_ks_lengths
.
begin
()
+
NumDimM
,
NumDimK
,
1
,
std
::
multiplies
<>
{});
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
*
2
;
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
*
2
;
std
::
size_t
num_btype
=
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
BDataType
)
*
K
*
N
+
std
::
size_t
num_btype
=
(
sizeof
(
ADataType
)
*
M
*
K
+
sizeof
(
BDataType
)
*
K
*
N
+
sizeof
(
EDataType
)
*
M
*
N
)
*
2
;
sizeof
(
DDataType
)
*
M
*
N
+
sizeof
(
EDataType
)
*
M
*
N
*
2
;
float
ave_time
=
ave_time_img2
+
ave_time_img1
+
ave_time_re2
+
ave_time_re1
;
float
ave_time
=
ave_time_img2
+
ave_time_img1
+
ave_time_re2
+
ave_time_re1
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment