Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
5459e4d8
Commit
5459e4d8
authored
Mar 01, 2022
by
Shucai Xiao
Browse files
clang format
parent
2d5e45b8
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
20 additions
and
16 deletions
+20
-16
src/targets/gpu/device/add.cpp
src/targets/gpu/device/add.cpp
+4
-3
src/targets/gpu/device/mul.cpp
src/targets/gpu/device/mul.cpp
+4
-3
src/targets/gpu/device/mul_add.cpp
src/targets/gpu/device/mul_add.cpp
+12
-10
No files found.
src/targets/gpu/device/add.cpp
View file @
5459e4d8
...
...
@@ -25,7 +25,7 @@ __global__ void add_kernel(void* a, void* b, int n_dim, void* r, int n)
__half2
*
ha
=
reinterpret_cast
<
__half2
*>
(
a
);
__half2
*
hb
=
reinterpret_cast
<
__half2
*>
(
b
);
__half2
*
hr
=
reinterpret_cast
<
__half2
*>
(
r
);
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
tid
<
n
)
{
int
idb
=
tid
%
n_dim
;
...
...
@@ -42,10 +42,11 @@ void add(hipStream_t stream, const argument& result, const argument& arg1, const
if
(
sr
.
type
()
==
shape
::
half_type
and
is_bert
(
ss
))
{
auto
elem_num
=
sr
.
elements
()
/
2
;
auto
last_dim
=
sr
.
lens
().
back
()
/
2
;
auto
last_dim
=
sr
.
lens
().
back
()
/
2
;
int
block_size
=
1024
;
int
block_num
=
(
elem_num
+
block_size
-
1
)
/
block_size
;
add_kernel
<<<
block_num
,
block_size
>>>
(
arg1
.
data
(),
arg2
.
data
(),
last_dim
,
result
.
data
(),
elem_num
);
add_kernel
<<<
block_num
,
block_size
>>>
(
arg1
.
data
(),
arg2
.
data
(),
last_dim
,
result
.
data
(),
elem_num
);
}
else
{
...
...
src/targets/gpu/device/mul.cpp
View file @
5459e4d8
...
...
@@ -25,7 +25,7 @@ __global__ void mul_kernel(void* a, void* b, int n_dim, void* r, int n)
__half2
*
ha
=
reinterpret_cast
<
__half2
*>
(
a
);
__half2
*
hb
=
reinterpret_cast
<
__half2
*>
(
b
);
__half2
*
hr
=
reinterpret_cast
<
__half2
*>
(
r
);
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
tid
<
n
)
{
int
idb
=
tid
%
n_dim
;
...
...
@@ -42,10 +42,11 @@ void mul(hipStream_t stream, const argument& result, const argument& arg1, const
if
(
sr
.
type
()
==
shape
::
half_type
and
is_bert
(
ss
))
{
auto
elem_num
=
sr
.
elements
()
/
2
;
auto
last_dim
=
sr
.
lens
().
back
()
/
2
;
auto
last_dim
=
sr
.
lens
().
back
()
/
2
;
int
block_size
=
1024
;
int
block_num
=
(
elem_num
+
block_size
-
1
)
/
block_size
;
mul_kernel
<<<
block_num
,
block_size
>>>
(
arg1
.
data
(),
arg2
.
data
(),
last_dim
,
result
.
data
(),
elem_num
);
mul_kernel
<<<
block_num
,
block_size
>>>
(
arg1
.
data
(),
arg2
.
data
(),
last_dim
,
result
.
data
(),
elem_num
);
}
else
{
...
...
src/targets/gpu/device/mul_add.cpp
View file @
5459e4d8
...
...
@@ -21,7 +21,7 @@ __global__ void mul_add_kernel_dim3(void* a, void* x, void* b, int dim3, void* r
if
(
id
<
n
)
{
auto
id1
=
id
%
dim3
;
hr
[
id
]
=
__hadd2
(
__hmul2
(
ha
[
id
],
hx
[
id1
]),
hb
[
id1
]);
hr
[
id
]
=
__hadd2
(
__hmul2
(
ha
[
id
],
hx
[
id1
]),
hb
[
id1
]);
}
}
...
...
@@ -35,7 +35,7 @@ __global__ void mul_add_kernel_dim4(void* a, void* x, void* b, int factor, int d
if
(
id
<
n
)
{
int
idb
=
id
/
factor
+
id
%
dim4
;
hr
[
id
]
=
__hadd2
(
__hmul2
(
ha
[
id
],
hx
[
id
]),
hb
[
idb
]);
hr
[
id
]
=
__hadd2
(
__hmul2
(
ha
[
id
],
hx
[
id
]),
hb
[
idb
]);
}
}
...
...
@@ -72,20 +72,22 @@ void mul_add(hipStream_t stream,
ss
.
push_back
(
arg3
.
get_shape
());
if
(
type
==
shape
::
half_type
and
is_bert
(
ss
))
{
auto
elem_num
=
sr
.
elements
()
/
2
;
auto
lens
=
sr
.
lens
();
int
last_dim
=
lens
.
back
()
/
2
;
auto
n_dim
=
lens
.
size
();
auto
elem_num
=
sr
.
elements
()
/
2
;
auto
lens
=
sr
.
lens
();
int
last_dim
=
lens
.
back
()
/
2
;
auto
n_dim
=
lens
.
size
();
int
block_size
=
1024
;
int
block_num
=
(
elem_num
+
block_size
-
1
)
/
block_size
;
if
(
n_dim
==
2
)
int
block_num
=
(
elem_num
+
block_size
-
1
)
/
block_size
;
if
(
n_dim
==
2
)
{
mul_add_kernel_dim3
<<<
block_num
,
block_size
>>>
(
arg1
.
data
(),
arg2
.
data
(),
arg3
.
data
(),
last_dim
,
result
.
data
(),
elem_num
);
mul_add_kernel_dim3
<<<
block_num
,
block_size
>>>
(
arg1
.
data
(),
arg2
.
data
(),
arg3
.
data
(),
last_dim
,
result
.
data
(),
elem_num
);
}
else
{
int
factor
=
lens
[
1
];
mul_add_kernel_dim4
<<<
block_num
,
block_size
>>>
(
arg1
.
data
(),
arg2
.
data
(),
arg3
.
data
(),
factor
,
last_dim
,
result
.
data
(),
elem_num
);
mul_add_kernel_dim4
<<<
block_num
,
block_size
>>>
(
arg1
.
data
(),
arg2
.
data
(),
arg3
.
data
(),
factor
,
last_dim
,
result
.
data
(),
elem_num
);
}
}
else
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment