Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
6c834296
"src/targets/vscode:/vscode.git/clone" did not exist on "d6d386f7b2c0fd8ed19b4302eaf1021f880f1d2e"
Commit
6c834296
authored
Mar 03, 2022
by
Shucai Xiao
Browse files
use fma for the mul_add and refine add_gelu implementation
parent
9e5c56da
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
9 additions
and
5 deletions
+9
-5
src/targets/gpu/device/gelu.cpp
src/targets/gpu/device/gelu.cpp
+8
-4
src/targets/gpu/device/mul_add.cpp
src/targets/gpu/device/mul_add.cpp
+1
-1
No files found.
src/targets/gpu/device/gelu.cpp
View file @
6c834296
...
...
@@ -57,11 +57,15 @@ __global__ void add_gelu_kernel(void* a, void* b, int n_dim, void* r, int n)
int
idb
=
tid
%
n_dim
;
auto
sum
=
__hadd2
(
ha
[
tid
],
hb
[
idb
]);
__half2
sqrt2
=
__float2half2_rn
(
M_SQRT1_2
);
sum
=
__hmul2
(
sum
,
sqrt2
);
auto
f2
=
__half22float2
(
sum
);
f2
+=
1.0
f
;
auto
x
=
__hmul2
(
sum
,
sqrt2
);
auto
f2
=
__half22float2
(
x
);
f2
.
x
=
::
erf
(
f2
.
x
);
f2
.
y
=
::
erf
(
f2
.
y
);
auto
h2
=
__floats2half2_rn
(
f2
.
x
,
f2
.
y
);
auto
one
=
__float2half2_rn
(
1.0
f
);
h2
=
__hadd2
(
h2
,
one
);
__half2
point5
=
__float2half2_rn
(
0.5
f
);
hr
[
tid
]
=
__hmul2
(
sum
,
__hmul2
(
point5
,
h2
));
}
...
...
@@ -83,7 +87,7 @@ void add_gelu(hipStream_t stream,
auto
last_dim
=
sr
.
lens
().
back
()
/
2
;
int
block_size
=
1024
;
int
block_num
=
(
elem_num
+
block_size
-
1
)
/
block_size
;
add_gelu_kernel
<<<
block_num
,
block_size
>>>
(
add_gelu_kernel
<<<
block_num
,
block_size
,
0
,
stream
>>>
(
arg1
.
data
(),
arg2
.
data
(),
last_dim
,
result
.
data
(),
elem_num
);
}
else
...
...
src/targets/gpu/device/mul_add.cpp
View file @
6c834296
...
...
@@ -35,7 +35,7 @@ __global__ void mul_add_kernel_dim4(void* a, void* x, void* b, int factor, int d
if
(
id
<
n
)
{
int
idb
=
id
/
(
factor
*
dim4
)
*
dim4
+
id
%
dim4
;
hr
[
id
]
=
__h
add2
(
__hmul
2
(
ha
[
id
],
hx
[
id
]
)
,
hb
[
idb
]);
hr
[
id
]
=
__h
fma
2
(
ha
[
id
],
hx
[
id
],
hb
[
idb
]);
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment