Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
562724bf
Commit
562724bf
authored
Feb 28, 2022
by
Shucai Xiao
Browse files
clang format
parent
83f89182
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
29 additions
and
22 deletions
+29
-22
src/targets/gpu/device/mul_add.cpp
src/targets/gpu/device/mul_add.cpp
+29
-22
No files found.
src/targets/gpu/device/mul_add.cpp
View file @
562724bf
...
...
@@ -41,7 +41,7 @@ __global__ void mul_add_kernel(void* a, void* x, void* b, void* r, int* strides,
{
__shared__
int
shared_strides
[
18
];
int
tid
=
threadIdx
.
x
*
(
blockDim
.
y
*
blockDim
.
z
)
+
threadIdx
.
y
*
blockDim
.
z
+
threadIdx
.
z
;
if
(
tid
<
18
)
if
(
tid
<
18
)
{
shared_strides
[
tid
]
=
strides
[
tid
];
}
...
...
@@ -52,12 +52,19 @@ __global__ void mul_add_kernel(void* a, void* x, void* b, void* r, int* strides,
__half2
*
hx
=
reinterpret_cast
<
__half2
*>
(
x
);
__half2
*
hr
=
reinterpret_cast
<
__half2
*>
(
r
);
tid
=
tid
+
(
blockIdx
.
x
*
(
gridDim
.
y
*
gridDim
.
z
)
+
blockIdx
.
y
*
gridDim
.
z
+
blockIdx
.
z
)
*
blockDim
.
x
*
blockDim
.
y
*
blockDim
.
z
;
tid
=
tid
+
(
blockIdx
.
x
*
(
gridDim
.
y
*
gridDim
.
z
)
+
blockIdx
.
y
*
gridDim
.
z
+
blockIdx
.
z
)
*
blockDim
.
x
*
blockDim
.
y
*
blockDim
.
z
;
if
(
tid
<
elem_num
)
{
int
tida
=
shared_strides
[
1
]
*
blockIdx
.
x
+
shared_strides
[
2
]
*
blockIdx
.
y
+
shared_strides
[
3
]
*
blockIdx
.
z
+
shared_strides
[
4
]
*
threadIdx
.
x
+
shared_strides
[
5
]
*
threadIdx
.
y
+
threadIdx
.
z
;
int
tidx
=
shared_strides
[
7
]
*
blockIdx
.
x
+
shared_strides
[
8
]
*
blockIdx
.
y
+
shared_strides
[
9
]
*
blockIdx
.
z
+
shared_strides
[
10
]
*
threadIdx
.
x
+
shared_strides
[
11
]
*
threadIdx
.
y
+
threadIdx
.
z
;
int
tidb
=
shared_strides
[
13
]
*
blockIdx
.
x
+
shared_strides
[
14
]
*
blockIdx
.
y
+
shared_strides
[
15
]
*
blockIdx
.
z
+
shared_strides
[
16
]
*
threadIdx
.
x
+
shared_strides
[
17
]
*
threadIdx
.
y
+
threadIdx
.
z
;
int
tida
=
shared_strides
[
1
]
*
blockIdx
.
x
+
shared_strides
[
2
]
*
blockIdx
.
y
+
shared_strides
[
3
]
*
blockIdx
.
z
+
shared_strides
[
4
]
*
threadIdx
.
x
+
shared_strides
[
5
]
*
threadIdx
.
y
+
threadIdx
.
z
;
int
tidx
=
shared_strides
[
7
]
*
blockIdx
.
x
+
shared_strides
[
8
]
*
blockIdx
.
y
+
shared_strides
[
9
]
*
blockIdx
.
z
+
shared_strides
[
10
]
*
threadIdx
.
x
+
shared_strides
[
11
]
*
threadIdx
.
y
+
threadIdx
.
z
;
int
tidb
=
shared_strides
[
13
]
*
blockIdx
.
x
+
shared_strides
[
14
]
*
blockIdx
.
y
+
shared_strides
[
15
]
*
blockIdx
.
z
+
shared_strides
[
16
]
*
threadIdx
.
x
+
shared_strides
[
17
]
*
threadIdx
.
y
+
threadIdx
.
z
;
hr
[
tid
]
=
__hadd2
(
__hmul2
(
ha
[
tida
],
hx
[
tidx
]),
hb
[
tidb
]);
}
}
...
...
@@ -89,28 +96,28 @@ void mul_add(hipStream_t stream,
const
argument
&
arg2
,
const
argument
&
arg3
)
{
auto
sr
=
result
.
get_shape
();
auto
s1
=
arg1
.
get_shape
();
auto
s2
=
arg2
.
get_shape
();
auto
s3
=
arg3
.
get_shape
();
auto
sr
=
result
.
get_shape
();
auto
s1
=
arg1
.
get_shape
();
auto
s2
=
arg2
.
get_shape
();
auto
s3
=
arg3
.
get_shape
();
auto
type
=
sr
.
type
();
if
(
type
==
sr
.
type
())
{
hip_visit_all
(
result
,
arg1
,
arg2
,
arg3
,
sr
,
s1
,
s2
,
s3
)([
&
](
auto
r
,
auto
i1
,
auto
i2
,
auto
i3
,
auto
dsr
,
auto
ds1
,
auto
ds2
,
auto
ds3
)
{
__half2
*
rp
=
reinterpret_cast
<
__half2
*>
(
r
.
data
());
__half2
*
i1p
=
reinterpret_cast
<
__half2
*>
(
i1
.
data
());
__half2
*
i2p
=
reinterpret_cast
<
__half2
*>
(
i2
.
data
());
__half2
*
i3p
=
reinterpret_cast
<
__half2
*>
(
i3
.
data
());
gs_launch
(
stream
,
sr
.
elements
()
/
2
)([
=
](
auto
i
)
__device__
{
auto
idx
=
dsr
.
multi
(
i
);
auto
idx1
=
ds1
.
index
(
idx
);
auto
idx2
=
ds2
.
index
(
idx
);
auto
idx3
=
ds3
.
index
(
idx
);
rp
[
i
]
=
__hadd2
(
__hmul2
(
i1p
[
idx1
],
i2p
[
idx2
]),
i3p
[
idx3
]);
hip_visit_all
(
result
,
arg1
,
arg2
,
arg3
,
sr
,
s1
,
s2
,
s3
)(
[
&
](
auto
r
,
auto
i1
,
auto
i2
,
auto
i3
,
auto
dsr
,
auto
ds1
,
auto
ds2
,
auto
ds3
)
{
__half2
*
rp
=
reinterpret_cast
<
__half2
*>
(
r
.
data
());
__half2
*
i1p
=
reinterpret_cast
<
__half2
*>
(
i1
.
data
());
__half2
*
i2p
=
reinterpret_cast
<
__half2
*>
(
i2
.
data
());
__half2
*
i3p
=
reinterpret_cast
<
__half2
*>
(
i3
.
data
());
gs_launch
(
stream
,
sr
.
elements
()
/
2
)([
=
](
auto
i
)
__device__
{
auto
idx
=
dsr
.
multi
(
i
);
auto
idx1
=
ds1
.
index
(
idx
);
auto
idx2
=
ds2
.
index
(
idx
);
auto
idx3
=
ds3
.
index
(
idx
);
rp
[
i
]
=
__hadd2
(
__hmul2
(
i1p
[
idx1
],
i2p
[
idx2
]),
i3p
[
idx3
]);
});
});
});
}
else
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment