Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
67903751
Commit
67903751
authored
Feb 27, 2022
by
Shucai Xiao
Browse files
reimplementation of mul_add
parent
287f7e9f
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
63 additions
and
28 deletions
+63
-28
src/targets/gpu/device/mul_add.cpp
src/targets/gpu/device/mul_add.cpp
+63
-28
No files found.
src/targets/gpu/device/mul_add.cpp
View file @
67903751
#include "migraphx/gpu/device/launch.hpp"
#include <hip/amd_detail/amd_device_functions.h>
#include <hip/amd_detail/amd_hip_runtime.h>
#include <migraphx/gpu/device/mul_add.hpp>
#include <migraphx/gpu/device/mul_add.hpp>
#include <migraphx/gpu/device/nary.hpp>
#include <migraphx/gpu/device/nary.hpp>
#include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>
...
@@ -21,49 +24,81 @@ namespace device {
...
@@ -21,49 +24,81 @@ namespace device {
// }
// }
//}
//}
__global__
void
mul_add_kernel
(
void
*
a
,
int
an
,
void
*
x
,
int
xn
,
void
*
b
,
int
bn
,
void
*
r
,
int
n
)
// __global__ void mul_add_kernel(void* a, int an, void* x, int xn, void* b, int bn, void* r, int n)
// {
// int id = blockDim.x * blockIdx.x + threadIdx.x;
// __half2* ha = reinterpret_cast<__half2*>(a);
// __half2* hb = reinterpret_cast<__half2*>(b);
// __half2* hx = reinterpret_cast<__half2*>(x);
// __half2* hr = reinterpret_cast<__half2*>(r);
// if(id < n)
// {
// hr[id] = __hadd2(__hmul2(ha[id % an], hx[id % xn]), hb[id % bn]);
// }
// }
__global__
void
mul_add_kernel
(
void
*
a
,
void
*
x
,
void
*
b
,
void
*
r
,
int
*
strides
,
int
elem_num
)
{
{
int
id
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
__shared__
int
shared_strides
[
18
];
int
tid
=
threadIdx
.
x
*
(
blockDim
.
y
*
blockDim
.
z
)
+
threadIdx
.
y
*
blockDim
.
z
+
threadIdx
.
z
;
if
(
tid
<
18
)
{
shared_strides
[
tid
]
=
strides
[
tid
];
}
__syncthreads
();
__half2
*
ha
=
reinterpret_cast
<
__half2
*>
(
a
);
__half2
*
ha
=
reinterpret_cast
<
__half2
*>
(
a
);
__half2
*
hb
=
reinterpret_cast
<
__half2
*>
(
b
);
__half2
*
hb
=
reinterpret_cast
<
__half2
*>
(
b
);
__half2
*
hx
=
reinterpret_cast
<
__half2
*>
(
x
);
__half2
*
hx
=
reinterpret_cast
<
__half2
*>
(
x
);
__half2
*
hr
=
reinterpret_cast
<
__half2
*>
(
r
);
__half2
*
hr
=
reinterpret_cast
<
__half2
*>
(
r
);
if
(
id
<
n
)
tid
=
tid
+
(
blockIdx
.
x
*
(
gridDim
.
y
*
gridDim
.
z
)
+
blockIdx
.
y
*
gridDim
.
z
+
blockIdx
.
z
)
*
blockDim
.
x
*
blockDim
.
y
*
blockDim
.
z
;
if
(
tid
<
elem_num
)
{
{
hr
[
id
]
=
__hadd2
(
__hmul2
(
ha
[
id
%
an
],
hx
[
id
%
xn
]),
hb
[
id
%
bn
]);
int
tida
=
shared_strides
[
1
]
*
blockIdx
.
x
+
shared_strides
[
2
]
*
blockIdx
.
y
+
shared_strides
[
3
]
*
blockIdx
.
z
+
shared_strides
[
4
]
*
threadIdx
.
x
+
shared_strides
[
5
]
*
threadIdx
.
y
+
threadIdx
.
z
;
int
tidx
=
shared_strides
[
7
]
*
blockIdx
.
x
+
shared_strides
[
8
]
*
blockIdx
.
y
+
shared_strides
[
9
]
*
blockIdx
.
z
+
shared_strides
[
10
]
*
threadIdx
.
x
+
shared_strides
[
11
]
*
threadIdx
.
y
+
threadIdx
.
z
;
int
tidb
=
shared_strides
[
13
]
*
blockIdx
.
x
+
shared_strides
[
14
]
*
blockIdx
.
y
+
shared_strides
[
15
]
*
blockIdx
.
z
+
shared_strides
[
16
]
*
threadIdx
.
x
+
shared_strides
[
17
]
*
threadIdx
.
y
+
threadIdx
.
z
;
hr
[
tid
]
=
__hadd2
(
__hmul2
(
ha
[
tida
],
hx
[
tidx
]),
hb
[
tidb
]);
}
}
}
}
// void mul_add(hipStream_t stream,
// const argument& result,
// const argument& arg1,
// const argument& arg2,
// const argument& arg3)
// {
// auto type = result.get_shape().type();
// if(type == shape::half_type)
// {
// std::cout << "case1" << std::endl;
// mul_add_kernel<<<block_num, block_size>>>(
// arg1.data(), s1e, arg2.data(), s2e, arg3.data(), s3e, result.data(), elem_num);
// }
// else
// {
// std::cout << "mul_add" << std::endl;
// nary(stream, result, arg1, arg2, arg3)([](auto x, auto a, auto b)
// __device__ { return a * x + b; });
// }
// }
void
mul_add
(
hipStream_t
stream
,
void
mul_add
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
result
,
const
argument
&
arg1
,
const
argument
&
arg1
,
const
argument
&
arg2
,
const
argument
&
arg2
,
const
argument
&
arg3
)
const
argument
&
arg3
)
{
{
auto
type
=
result
.
get_shape
().
type
();
auto
sr
=
result
.
get_shape
();
if
(
type
==
shape
::
half_type
)
auto
s2
=
arg2
.
get_shape
();
{
auto
s3
=
arg3
.
get_shape
();
std
::
cout
<<
"case1"
<<
std
::
endl
;
int
s1e
=
arg1
.
get_shape
().
element_space
()
/
2
;
hip_visit_all
(
result
,
arg1
,
arg2
,
arg3
,
sr
)([
&
](
auto
r
,
auto
i1
,
auto
i2
,
auto
i3
,
auto
dsr
)
{
int
s2e
=
arg2
.
get_shape
().
element_space
()
/
2
;
gs_launch
(
stream
,
sr
.
elements
())([
=
](
auto
i
)
__device__
{
int
s3e
=
arg3
.
get_shape
().
element_space
()
/
2
;
auto
idx
=
dsr
.
multi
(
i
);
int
elem_num
=
result
.
get_shape
().
elements
()
/
2
;
r
[
i
]
=
i1
[
i
]
*
i2
[
idx
]
+
i3
[
idx
];
s1e
=
(
s1e
==
0
?
1
:
s1e
);
});
s2e
=
(
s2e
==
0
?
1
:
s2e
);
});
s3e
=
(
s3e
==
0
?
1
:
s3e
);
std
::
cout
<<
"re ="
<<
elem_num
<<
", s1e = "
<<
s1e
<<
", s2e = "
<<
s2e
<<
", s3e = "
<<
s3e
<<
std
::
endl
;
int
block_size
=
1024
;
int
block_num
=
(
elem_num
+
block_size
-
1
)
/
block_size
;
mul_add_kernel
<<<
block_num
,
block_size
>>>
(
arg1
.
data
(),
s1e
,
arg2
.
data
(),
s2e
,
arg3
.
data
(),
s3e
,
result
.
data
(),
elem_num
);
}
else
{
std
::
cout
<<
"case2"
<<
std
::
endl
;
nary
(
stream
,
result
,
arg1
,
arg2
,
arg3
)([](
auto
x
,
auto
a
,
auto
b
)
__device__
{
return
a
*
x
+
b
;
});
}
}
}
}
// namespace device
}
// namespace device
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment