Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
d2cd5658
"include/ck/utility/common_header.hpp" did not exist on "f63a23acb14867da6f4a234aae19227a0847b4e6"
Commit
d2cd5658
authored
Oct 27, 2023
by
Muhammed Ozturk
Browse files
Tensor Contraction Complex Data Type is working
parent
160cf6ed
Changes
2
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
286 additions
and
167 deletions
+286
-167
example/complex_contraction/4D_kernel.hpp
example/complex_contraction/4D_kernel.hpp
+218
-139
example/complex_contraction/main.cpp
example/complex_contraction/main.cpp
+68
-28
No files found.
example/complex_contraction/4D_kernel.hpp
View file @
d2cd5658
This diff is collapsed.
Click to expand it.
example/complex_contraction/main.cpp
View file @
d2cd5658
//
// Sample Code:
//
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "4D_kernel.hpp"
//#define DEBUG_CORRECTNESS
//
#define DEBUG_CORRECTNESS
//#define DEBUG_SIMPLE_CORRECTNESS
void
pre_Initializing_Input_Tensors
();
void
post_Correctness
();
// Initialize t3 (t3_temp), 9 t2 and 9 v2.
void
pre_Initializing_Input_Tensors
(
float
*
h_C
,
float
*
h_C_chk
,
int
size_C
,
float
*
h_A
,
int
size_A
,
float
*
h_B
,
int
size_B
)
void
pre_Initializing_Input_Tensors
(
Complex
*
h_C
,
Complex
*
h_C_chk
,
int
size_C
,
Complex
*
h_A
,
int
size_A
,
Complex
*
h_B
,
int
size_B
)
{
// t3
int
i
,
j
;
for
(
i
=
0
;
i
<
size_C
;
i
++
)
{
h_C
[
i
]
=
0.0
;
h_C_chk
[
i
]
=
0.0
;
h_C
[
i
].
re
=
0.0
;
h_C_chk
[
i
].
re
=
0.0
;
h_C
[
i
].
im
=
0.0
;
h_C_chk
[
i
].
im
=
0.0
;
}
for
(
j
=
0
;
j
<
size_A
;
j
++
)
{
h_A
[
j
]
=
((
float
)
rand
()
/
RAND_MAX
);
h_A
[
j
].
re
=
((
float
)
rand
()
/
RAND_MAX
);
h_A
[
j
].
im
=
((
float
)
rand
()
/
RAND_MAX
);
}
for
(
j
=
0
;
j
<
size_B
;
j
++
)
{
h_B
[
j
]
=
((
float
)
rand
()
/
RAND_MAX
);
h_B
[
j
].
re
=
((
float
)
rand
()
/
RAND_MAX
);
h_B
[
j
].
im
=
((
float
)
rand
()
/
RAND_MAX
);
}
}
//
void
post_Correctness
(
float
*
h_C
,
float
*
h_C_chk
,
float
*
h_A
,
float
*
h_B
,
int
size_idx_a
,
int
size_idx_b
,
int
size_idx_c
,
int
size_idx_d
,
int
size_idx_e
,
int
size_idx_f
)
void
post_Correctness
(
Complex
*
h_C
,
Complex
*
h_C_chk
,
Complex
*
h_A
,
Complex
*
h_B
,
int
size_idx_a
,
int
size_idx_b
,
int
size_idx_c
,
int
size_idx_d
,
int
size_idx_e
,
int
size_idx_f
)
{
// t3 [a,16,b,16,c,16,d,16] += sum(e,16,f,16) * t2 [a,e,b,f] * v2 [d,f,c,e];
int
size_C
=
size_idx_a
*
size_idx_b
*
size_idx_c
*
size_idx_d
;
...
...
@@ -59,8 +66,18 @@ void post_Correctness(float* h_C, float* h_C_chk, float* h_A, float* h_B, int si
{
for
(
idx_f
=
0
;
idx_f
<
size_idx_f
;
idx_f
++
)
{
h_C_chk
[
tmp_r_idx
]
+=
h_A
[
idx_a
+
(
idx_e
+
(
idx_b
+
(
idx_f
)
*
size_idx_b
)
*
size_idx_e
)
*
size_idx_a
]
*
h_B
[
idx_d
+
(
idx_f
+
(
idx_c
+
(
idx_e
)
*
size_idx_c
)
*
size_idx_f
)
*
size_idx_d
];
h_C_chk
[
tmp_r_idx
].
re
+=
(
h_A
[
idx_a
+
(
idx_e
+
(
idx_b
+
(
idx_f
)
*
size_idx_b
)
*
size_idx_e
)
*
size_idx_a
].
re
*
h_B
[
idx_d
+
(
idx_f
+
(
idx_c
+
(
idx_e
)
*
size_idx_c
)
*
size_idx_f
)
*
size_idx_d
].
re
)
-
(
h_A
[
idx_a
+
(
idx_e
+
(
idx_b
+
(
idx_f
)
*
size_idx_b
)
*
size_idx_e
)
*
size_idx_a
].
im
*
h_B
[
idx_d
+
(
idx_f
+
(
idx_c
+
(
idx_e
)
*
size_idx_c
)
*
size_idx_f
)
*
size_idx_d
].
im
);
h_C_chk
[
tmp_r_idx
].
im
+=
(
h_A
[
idx_a
+
(
idx_e
+
(
idx_b
+
(
idx_f
)
*
size_idx_b
)
*
size_idx_e
)
*
size_idx_a
].
re
*
h_B
[
idx_d
+
(
idx_f
+
(
idx_c
+
(
idx_e
)
*
size_idx_c
)
*
size_idx_f
)
*
size_idx_d
].
im
)
+
(
h_A
[
idx_a
+
(
idx_e
+
(
idx_b
+
(
idx_f
)
*
size_idx_b
)
*
size_idx_e
)
*
size_idx_a
].
im
*
h_B
[
idx_d
+
(
idx_f
+
(
idx_c
+
(
idx_e
)
*
size_idx_c
)
*
size_idx_f
)
*
size_idx_d
].
re
);
ops
++
;
}
tmp_ops
=
tmp_ops
+
ops
;
...
...
@@ -68,28 +85,51 @@ void post_Correctness(float* h_C, float* h_C_chk, float* h_A, float* h_B, int si
}
printf
(
"======================================= Correctness Check ==========================================
\n
"
);
float
epsilon
=
0.00000001
;
int
diff
=
0
;
int
same
=
0
;
float
epsilon
=
0.01
;
int
diff_re
=
0
;
int
diff_im
=
0
;
int
same_re
=
0
;
int
same_im
=
0
;
int
i
;
for
(
i
=
0
;
i
<
size_C
;
i
++
)
{
float
check
=
h_C_chk
[
i
]
-
h_C
[
i
];
if
(
check
<
0
)
check
*=
-
1
;
if
(
check
>
epsilon
)
float
check_re
=
h_C_chk
[
i
].
re
-
h_C
[
i
].
re
;
float
check_im
=
h_C_chk
[
i
].
im
-
h_C
[
i
].
im
;
if
(
check_re
<
0
)
check_re
*=
-
1
;
if
(
check_re
>
epsilon
)
{
diff_re
++
;
if
(
diff_re
<
8
)
printf
(
"Index: %5d, (Host) %8.4f, (Dev.) %8.4f >> (Diff.) %8.4f
\n
"
,
i
,
h_C_chk
[
i
].
re
,
h_C
[
i
].
re
,
check_re
);
}
else
{
same_re
++
;
}
if
(
check_im
<
0
)
check_im
*=
-
1
;
if
(
check_im
>
epsilon
)
{
diff
++
;
if
(
diff
<
8
)
printf
(
"Index: %5d, (Host) %8.4f, (Dev.) %8.4f >> (Diff.) %8.4f
\n
"
,
i
,
h_C_chk
[
i
],
h_C
[
i
],
check
);
diff
_im
++
;
if
(
diff
_im
<
8
)
printf
(
"Index: %5d, (Host) %8.4f, (Dev.) %8.4f >> (Diff.) %8.4f
\n
"
,
i
,
h_C_chk
[
i
]
.
im
,
h_C
[
i
]
.
im
,
check
_im
);
}
else
{
same
++
;
same
_im
++
;
}
}
printf
(
" >>> PASSED: %'10d among %'10d in t3
\n
"
,
same
,
size_C
);
printf
(
" >>> ERROR : %'10d among %'10d in t3
\n
"
,
diff
,
size_C
);
printf
(
" >>> PASSED on Re: %'10d among %'10d in t3
\n
"
,
same_re
,
size_C
);
printf
(
" >>> PASSED on Im: %'10d among %'10d in t3
\n
"
,
same_im
,
size_C
);
printf
(
" >>> ERROR on Re : %'10d among %'10d in t3
\n
"
,
diff_re
,
size_C
);
printf
(
" >>> ERROR on Im : %'10d among %'10d in t3
\n
"
,
diff_im
,
size_C
);
printf
(
" >>> Total Operations: %'lld
\n
"
,
tmp_ops
*
2
);
printf
(
"====================================================================================================
\n
"
);
}
...
...
@@ -101,9 +141,9 @@ void post_Correctness(float* h_C, float* h_C_chk, float* h_A, float* h_B, int si
int
main
(
int
argc
,
char
**
argv
)
{
// for sd2
float
*
host_C
,
*
host_C_chk
;
float
*
host_A
;
float
*
host_B
;
Complex
*
host_C
,
*
host_C_chk
;
Complex
*
host_A
;
Complex
*
host_B
;
int
size_idx_a
,
size_idx_b
,
size_idx_c
,
size_idx_d
,
size_idx_e
,
size_idx_f
;
// Problem Size
...
...
@@ -137,10 +177,10 @@ int main(int argc, char** argv)
size_B
=
size_idx_d
*
size_idx_f
*
size_idx_c
*
size_idx_e
;
//
host_C
=
(
float
*
)
malloc
(
sizeof
(
float
)
*
size_C
);
host_C_chk
=
(
float
*
)
malloc
(
sizeof
(
float
)
*
size_C
);
host_A
=
(
float
*
)
malloc
(
sizeof
(
float
)
*
size_A
);
host_B
=
(
float
*
)
malloc
(
sizeof
(
float
)
*
size_B
);
host_C
=
(
Complex
*
)
malloc
(
sizeof
(
Complex
)
*
size_C
);
host_C_chk
=
(
Complex
*
)
malloc
(
sizeof
(
Complex
)
*
size_C
);
host_A
=
(
Complex
*
)
malloc
(
sizeof
(
Complex
)
*
size_A
);
host_B
=
(
Complex
*
)
malloc
(
sizeof
(
Complex
)
*
size_B
);
printf
(
"==========================================================================================================
\n
"
);
printf
(
">>> abcd-aebf-dfce
\n
"
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment