Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
781ce146
"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "492d329a44f24ac8e1fcb3f2bf355793cef50497"
Commit
781ce146
authored
Apr 26, 2021
by
Khalique Ahmed
Browse files
add fp16 fixes
parent
6d937d80
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
27 additions
and
11 deletions
+27
-11
src/targets/gpu/device/include/migraphx/gpu/device/launch.hpp
...targets/gpu/device/include/migraphx/gpu/device/launch.hpp
+1
-1
src/targets/gpu/device/include/migraphx/gpu/device/nary.hpp
src/targets/gpu/device/include/migraphx/gpu/device/nary.hpp
+4
-4
src/targets/gpu/gemm_impl.cpp
src/targets/gpu/gemm_impl.cpp
+11
-6
src/targets/gpu/include/migraphx/gpu/context.hpp
src/targets/gpu/include/migraphx/gpu/context.hpp
+11
-0
No files found.
src/targets/gpu/device/include/migraphx/gpu/device/launch.hpp
View file @
781ce146
...
@@ -76,7 +76,7 @@ MIGRAPHX_DEVICE_CONSTEXPR auto gs_invoke(F&& f, index_int i, index) -> decltype(
...
@@ -76,7 +76,7 @@ MIGRAPHX_DEVICE_CONSTEXPR auto gs_invoke(F&& f, index_int i, index) -> decltype(
inline
auto
gs_launch
(
hipStream_t
stream
,
index_int
n
,
index_int
local
=
1024
)
inline
auto
gs_launch
(
hipStream_t
stream
,
index_int
n
,
index_int
local
=
1024
)
{
{
index_int
groups
=
(
n
+
local
-
1
)
/
local
;
index_int
groups
=
(
n
+
local
-
1
)
/
local
;
index_int
nglobal
=
std
::
min
<
index_int
>
(
25
6
,
groups
)
*
local
;
index_int
nglobal
=
std
::
min
<
index_int
>
(
104857
6
,
groups
)
*
local
;
return
[
=
](
auto
f
)
{
return
[
=
](
auto
f
)
{
launch
(
stream
,
nglobal
,
local
)([
=
](
auto
idx
)
__device__
{
launch
(
stream
,
nglobal
,
local
)([
=
](
auto
idx
)
__device__
{
...
...
src/targets/gpu/device/include/migraphx/gpu/device/nary.hpp
View file @
781ce146
...
@@ -87,7 +87,7 @@ void nary_broadcast_vec_impl(
...
@@ -87,7 +87,7 @@ void nary_broadcast_vec_impl(
const
index_int
vec_size
=
4
;
const
index_int
vec_size
=
4
;
const
index_int
nlocal
=
1024
;
const
index_int
nlocal
=
1024
;
const
index_int
nglobal
=
256
*
nlocal
;
const
index_int
nglobal
=
512
*
nlocal
;
const
index_int
bdim_vec_len
=
bdim_len
/
vec_size
;
const
index_int
bdim_vec_len
=
bdim_len
/
vec_size
;
hip_vec_visit_all
<
vec_size
>
(
result
,
barg
,
args
...)(
hip_vec_visit_all
<
vec_size
>
(
result
,
barg
,
args
...)(
[
&
](
auto
output
,
auto
binput
,
auto
...
inputs
)
{
[
&
](
auto
output
,
auto
binput
,
auto
...
inputs
)
{
...
@@ -134,7 +134,7 @@ void nary_broadcast_impl(hipStream_t stream, F f, argument result, argument barg
...
@@ -134,7 +134,7 @@ void nary_broadcast_impl(hipStream_t stream, F f, argument result, argument barg
auto
broadcast_idx
=
create_broadcast_index
(
bdim_len
,
bdim_stride
);
auto
broadcast_idx
=
create_broadcast_index
(
bdim_len
,
bdim_stride
);
const
index_int
nlocal
=
1024
;
const
index_int
nlocal
=
1024
;
const
index_int
nglobal
=
256
*
nlocal
;
const
index_int
nglobal
=
512
*
nlocal
;
index_int
nelements
=
result
.
get_shape
().
elements
();
index_int
nelements
=
result
.
get_shape
().
elements
();
hip_visit_all
(
result
,
barg
,
args
...)([
&
](
auto
output
,
auto
binput
,
auto
...
inputs
)
{
hip_visit_all
(
result
,
barg
,
args
...)([
&
](
auto
output
,
auto
binput
,
auto
...
inputs
)
{
using
type
=
typename
decltype
(
output
)
::
value_type
;
using
type
=
typename
decltype
(
output
)
::
value_type
;
...
@@ -178,7 +178,7 @@ void nary_double_broadcast_vec_impl(
...
@@ -178,7 +178,7 @@ void nary_double_broadcast_vec_impl(
const
index_int
vec_size
=
4
;
const
index_int
vec_size
=
4
;
const
index_int
nlocal
=
1024
;
const
index_int
nlocal
=
1024
;
const
index_int
nglobal
=
256
*
nlocal
;
const
index_int
nglobal
=
512
*
nlocal
;
const
index_int
bdim_vec_len
=
bdim_len
/
vec_size
;
const
index_int
bdim_vec_len
=
bdim_len
/
vec_size
;
hip_vec_visit_all
<
vec_size
>
(
result
,
barg1
,
barg2
,
args
...)(
hip_vec_visit_all
<
vec_size
>
(
result
,
barg1
,
barg2
,
args
...)(
[
&
](
auto
output
,
auto
binput1
,
auto
binput2
,
auto
...
inputs
)
{
[
&
](
auto
output
,
auto
binput1
,
auto
binput2
,
auto
...
inputs
)
{
...
@@ -234,7 +234,7 @@ void nary_double_broadcast_impl(
...
@@ -234,7 +234,7 @@ void nary_double_broadcast_impl(
auto
broadcast_idx
=
create_broadcast_index
(
bdim_len
,
bdim_stride
);
auto
broadcast_idx
=
create_broadcast_index
(
bdim_len
,
bdim_stride
);
const
index_int
nlocal
=
1024
;
const
index_int
nlocal
=
1024
;
const
index_int
nglobal
=
256
*
nlocal
;
const
index_int
nglobal
=
512
*
nlocal
;
index_int
nelements
=
result
.
get_shape
().
elements
();
index_int
nelements
=
result
.
get_shape
().
elements
();
hip_visit_all
(
result
,
barg1
,
barg2
,
args
...)(
hip_visit_all
(
result
,
barg1
,
barg2
,
args
...)(
[
&
](
auto
output
,
auto
binput1
,
auto
binput2
,
auto
...
inputs
)
{
[
&
](
auto
output
,
auto
binput1
,
auto
binput2
,
auto
...
inputs
)
{
...
...
src/targets/gpu/gemm_impl.cpp
View file @
781ce146
...
@@ -60,12 +60,17 @@ void gemm_impl(
...
@@ -60,12 +60,17 @@ void gemm_impl(
output_type
=
rocblas_datatype_i32_r
;
output_type
=
rocblas_datatype_i32_r
;
}
}
auto
compute_type
=
output_type
;
auto
compute_type
=
output_type
;
if
(
ctx
.
get_stream
().
get_device_name
()
==
"gfx908"
)
{
if
(
args
[
0
].
get_shape
().
type
()
==
shape
::
half_type
)
compute_type
=
rocblas_datatype_f32_r
;
}
auto
a_lens
=
args
[
0
].
get_shape
().
lens
();
auto
a_lens
=
args
[
0
].
get_shape
().
lens
();
auto
b_lens
=
args
[
1
].
get_shape
().
lens
();
auto
b_lens
=
args
[
1
].
get_shape
().
lens
();
output_shape
.
visit_type
([
&
](
auto
as
)
{
output_shape
.
visit_type
([
&
](
auto
as
)
{
auto
alpha_r
=
as
(
alpha
);
auto
beta_r
=
as
(
beta
);
auto
out_lens
=
output_shape
.
lens
();
auto
out_lens
=
output_shape
.
lens
();
rocblas_int
m
=
out_lens
[
dim_0
];
rocblas_int
m
=
out_lens
[
dim_0
];
rocblas_int
n
=
out_lens
[
dim_1
];
rocblas_int
n
=
out_lens
[
dim_1
];
...
@@ -91,14 +96,14 @@ void gemm_impl(
...
@@ -91,14 +96,14 @@ void gemm_impl(
n
,
n
,
m
,
m
,
k
,
k
,
&
alpha
_r
,
&
alpha
,
to_pointer
(
args
.
at
(
1
)),
to_pointer
(
args
.
at
(
1
)),
arg_type
,
arg_type
,
ldb
,
ldb
,
to_pointer
(
args
.
at
(
0
)),
to_pointer
(
args
.
at
(
0
)),
arg_type
,
arg_type
,
lda
,
lda
,
&
beta
_r
,
&
beta
,
to_pointer
(
args
[
2
]),
to_pointer
(
args
[
2
]),
output_type
,
output_type
,
ldc
,
ldc
,
...
@@ -123,7 +128,7 @@ void gemm_impl(
...
@@ -123,7 +128,7 @@ void gemm_impl(
n
,
n
,
m
,
m
,
k
,
k
,
&
alpha
_r
,
&
alpha
,
to_pointer
(
args
.
at
(
1
)),
to_pointer
(
args
.
at
(
1
)),
arg_type
,
arg_type
,
ldb
,
ldb
,
...
@@ -132,7 +137,7 @@ void gemm_impl(
...
@@ -132,7 +137,7 @@ void gemm_impl(
arg_type
,
arg_type
,
lda
,
lda
,
m
*
k
,
m
*
k
,
&
beta
_r
,
&
beta
,
to_pointer
(
args
[
2
]),
to_pointer
(
args
[
2
]),
output_type
,
output_type
,
ldc
,
ldc
,
...
...
src/targets/gpu/include/migraphx/gpu/context.hpp
View file @
781ce146
...
@@ -87,6 +87,17 @@ struct hip_device
...
@@ -87,6 +87,17 @@ struct hip_device
return
rbhandle
.
get
();
return
rbhandle
.
get
();
}
}
std
::
string
get_device_name
()
{
hipDeviceProp_t
props
{};
// int device;
// if (not (hipGetDevice(&device) == hipSuccess))
// MIGRAPHX_THROW("Unable to get hip device");
// if (not (hipGetDeviceProperties(&props, device) == hipSuccess))
// MIGRAPHX_THROW("Unable to get hip device properties");
return
"gfx"
+
std
::
to_string
(
props
.
gcnArch
);
}
void
wait
(
hipEvent_t
event
)
void
wait
(
hipEvent_t
event
)
{
{
setup
();
setup
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment