Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
65ef1423
Commit
65ef1423
authored
Feb 17, 2022
by
Shucai Xiao
Browse files
Merge branch 'rocblas_api_opt' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into branch_for_ort
parents
96595c17
d45bd3ba
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
157 additions
and
68 deletions
+157
-68
src/targets/gpu/gemm_impl.cpp
src/targets/gpu/gemm_impl.cpp
+131
-62
src/targets/gpu/include/migraphx/gpu/gemm.hpp
src/targets/gpu/include/migraphx/gpu/gemm.hpp
+9
-2
src/targets/gpu/include/migraphx/gpu/gemm_impl.hpp
src/targets/gpu/include/migraphx/gpu/gemm_impl.hpp
+4
-2
src/targets/gpu/lowering.cpp
src/targets/gpu/lowering.cpp
+13
-2
No files found.
src/targets/gpu/gemm_impl.cpp
100755 → 100644
View file @
65ef1423
...
@@ -42,7 +42,8 @@ void gemm_impl(context& ctx,
...
@@ -42,7 +42,8 @@ void gemm_impl(context& ctx,
const
std
::
vector
<
argument
>&
args
,
const
std
::
vector
<
argument
>&
args
,
T
alpha
,
T
alpha
,
T
beta
,
T
beta
,
bool
int8_x4_format
)
bool
int8_x4_format
,
bool
compute_fp32
)
{
{
bool
transa
=
args
[
0
].
get_shape
().
transposed
();
bool
transa
=
args
[
0
].
get_shape
().
transposed
();
bool
transb
=
args
[
1
].
get_shape
().
transposed
();
bool
transb
=
args
[
1
].
get_shape
().
transposed
();
...
@@ -65,6 +66,11 @@ void gemm_impl(context& ctx,
...
@@ -65,6 +66,11 @@ void gemm_impl(context& ctx,
output_type
=
rocblas_datatype_i32_r
;
output_type
=
rocblas_datatype_i32_r
;
}
}
auto
compute_type
=
output_type
;
auto
compute_type
=
output_type
;
if
(
compute_fp32
)
{
if
(
arg_type
==
rocblas_datatype_f16_r
)
compute_type
=
rocblas_datatype_f32_r
;
}
#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
rocblas_gemm_flags
flag
=
rocblas_gemm_flags
flag
=
...
@@ -77,8 +83,9 @@ void gemm_impl(context& ctx,
...
@@ -77,8 +83,9 @@ void gemm_impl(context& ctx,
auto
a_lens
=
args
[
0
].
get_shape
().
lens
();
auto
a_lens
=
args
[
0
].
get_shape
().
lens
();
auto
b_lens
=
args
[
1
].
get_shape
().
lens
();
auto
b_lens
=
args
[
1
].
get_shape
().
lens
();
output_shape
.
visit_type
([
&
](
auto
as
)
{
output_shape
.
visit_type
([
&
](
auto
as
)
{
auto
alpha_r
=
as
(
alpha
);
auto
alpha_r
=
as
(
alpha
);
auto
beta_r
=
as
(
beta
);
auto
beta_r
=
as
(
beta
);
auto
out_lens
=
output_shape
.
lens
();
auto
out_lens
=
output_shape
.
lens
();
rocblas_int
m
=
out_lens
[
dim_0
];
rocblas_int
m
=
out_lens
[
dim_0
];
rocblas_int
n
=
out_lens
[
dim_1
];
rocblas_int
n
=
out_lens
[
dim_1
];
...
@@ -97,64 +104,124 @@ void gemm_impl(context& ctx,
...
@@ -97,64 +104,124 @@ void gemm_impl(context& ctx,
// column-major format. When doing a C = A * B, we actually do
// column-major format. When doing a C = A * B, we actually do
// C^T = (B^T) * (A^T). That is the reason we input args[1] as
// C^T = (B^T) * (A^T). That is the reason we input args[1] as
// A and args[0] as B in calling the rocblas_gemm.
// A and args[0] as B in calling the rocblas_gemm.
rocblas_invoke
(
&
rocblas_gemm_ex
,
ctx
.
get_stream
().
get_rocblas
(),
if
(
compute_fp32
)
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
rocblas_invoke
(
&
rocblas_gemm_ex
,
transa
?
rocblas_operation_transpose
:
rocblas_operation_none
,
ctx
.
get_stream
().
get_rocblas
(),
n
,
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
m
,
transa
?
rocblas_operation_transpose
:
rocblas_operation_none
,
k
,
n
,
&
alpha_r
,
m
,
to_pointer
(
args
.
at
(
1
)),
k
,
arg_type
,
&
alpha
,
ldb
,
to_pointer
(
args
.
at
(
1
)),
to_pointer
(
args
.
at
(
0
)),
arg_type
,
arg_type
,
ldb
,
lda
,
to_pointer
(
args
.
at
(
0
)),
&
beta_r
,
arg_type
,
to_pointer
(
args
[
2
]),
lda
,
output_type
,
&
beta
,
ldc
,
to_pointer
(
args
[
2
]),
is_3inputs
?
to_pointer
(
args
[
3
])
:
to_pointer
(
args
[
2
]),
output_type
,
output_type
,
ldc
,
ldc
,
is_3inputs
?
to_pointer
(
args
[
3
])
:
to_pointer
(
args
[
2
]),
compute_type
,
output_type
,
rocblas_gemm_algo_standard
,
ldc
,
0
,
compute_type
,
flag
);
rocblas_gemm_algo_standard
,
0
,
flag
);
else
rocblas_invoke
(
&
rocblas_gemm_ex
,
ctx
.
get_stream
().
get_rocblas
(),
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
transa
?
rocblas_operation_transpose
:
rocblas_operation_none
,
n
,
m
,
k
,
&
alpha_r
,
to_pointer
(
args
.
at
(
1
)),
arg_type
,
ldb
,
to_pointer
(
args
.
at
(
0
)),
arg_type
,
lda
,
&
beta_r
,
to_pointer
(
args
[
2
]),
output_type
,
ldc
,
is_3inputs
?
to_pointer
(
args
[
3
])
:
to_pointer
(
args
[
2
]),
output_type
,
ldc
,
compute_type
,
rocblas_gemm_algo_standard
,
0
,
flag
);
}
}
else
else
{
{
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex
,
if
(
compute_fp32
)
ctx
.
get_stream
().
get_rocblas
(),
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex
,
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
ctx
.
get_stream
().
get_rocblas
(),
transa
?
rocblas_operation_transpose
:
rocblas_operation_none
,
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
n
,
transa
?
rocblas_operation_transpose
:
rocblas_operation_none
,
m
,
n
,
k
,
m
,
&
alpha_r
,
k
,
to_pointer
(
args
.
at
(
1
)),
&
alpha
,
arg_type
,
to_pointer
(
args
.
at
(
1
)),
ldb
,
arg_type
,
k
*
n
,
ldb
,
to_pointer
(
args
.
at
(
0
)),
k
*
n
,
arg_type
,
to_pointer
(
args
.
at
(
0
)),
lda
,
arg_type
,
m
*
k
,
lda
,
&
beta_r
,
m
*
k
,
to_pointer
(
args
[
2
]),
&
beta
,
output_type
,
to_pointer
(
args
[
2
]),
ldc
,
output_type
,
m
*
n
,
ldc
,
is_3inputs
?
to_pointer
(
args
[
3
])
:
to_pointer
(
args
[
2
]),
m
*
n
,
output_type
,
is_3inputs
?
to_pointer
(
args
[
3
])
:
to_pointer
(
args
[
2
]),
ldc
,
output_type
,
m
*
n
,
ldc
,
num_matrices
,
m
*
n
,
compute_type
,
num_matrices
,
rocblas_gemm_algo_standard
,
compute_type
,
0
,
rocblas_gemm_algo_standard
,
flag
);
0
,
flag
);
else
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex
,
ctx
.
get_stream
().
get_rocblas
(),
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
transa
?
rocblas_operation_transpose
:
rocblas_operation_none
,
n
,
m
,
k
,
&
alpha_r
,
to_pointer
(
args
.
at
(
1
)),
arg_type
,
ldb
,
k
*
n
,
to_pointer
(
args
.
at
(
0
)),
arg_type
,
lda
,
m
*
k
,
&
beta_r
,
to_pointer
(
args
[
2
]),
output_type
,
ldc
,
m
*
n
,
is_3inputs
?
to_pointer
(
args
[
3
])
:
to_pointer
(
args
[
2
]),
output_type
,
ldc
,
m
*
n
,
num_matrices
,
compute_type
,
rocblas_gemm_algo_standard
,
0
,
flag
);
}
}
});
});
}
}
...
@@ -164,9 +231,10 @@ void gemm(context& ctx,
...
@@ -164,9 +231,10 @@ void gemm(context& ctx,
const
std
::
vector
<
argument
>&
args
,
const
std
::
vector
<
argument
>&
args
,
float
alpha
,
float
alpha
,
float
beta
,
float
beta
,
bool
int8_x4_format
)
bool
int8_x4_format
,
bool
compute_fp32
)
{
{
gemm_impl
(
ctx
,
output_shape
,
args
,
alpha
,
beta
,
int8_x4_format
);
gemm_impl
(
ctx
,
output_shape
,
args
,
alpha
,
beta
,
int8_x4_format
,
compute_fp32
);
}
}
void
gemm
(
context
&
ctx
,
void
gemm
(
context
&
ctx
,
...
@@ -174,9 +242,10 @@ void gemm(context& ctx,
...
@@ -174,9 +242,10 @@ void gemm(context& ctx,
const
std
::
vector
<
argument
>&
args
,
const
std
::
vector
<
argument
>&
args
,
int32_t
alpha
,
int32_t
alpha
,
int32_t
beta
,
int32_t
beta
,
bool
int8_x4_format
)
bool
int8_x4_format
,
bool
compute_fp32
)
{
{
gemm_impl
(
ctx
,
output_shape
,
args
,
alpha
,
beta
,
int8_x4_format
);
gemm_impl
(
ctx
,
output_shape
,
args
,
alpha
,
beta
,
int8_x4_format
,
compute_fp32
);
}
}
}
// namespace gpu
}
// namespace gpu
...
...
src/targets/gpu/include/migraphx/gpu/gemm.hpp
View file @
65ef1423
...
@@ -25,6 +25,7 @@ struct rocblas_gemm
...
@@ -25,6 +25,7 @@ struct rocblas_gemm
float
alpha
=
1
;
float
alpha
=
1
;
float
beta
=
0
;
float
beta
=
0
;
bool
int8_x4_format
=
true
;
bool
int8_x4_format
=
true
;
bool
compute_fp32
=
false
;
template
<
class
Self
,
class
F
>
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
static
auto
reflect
(
Self
&
self
,
F
f
)
...
@@ -80,11 +81,17 @@ struct rocblas_gemm
...
@@ -80,11 +81,17 @@ struct rocblas_gemm
{
{
if
(
this
->
name
()
==
"gpu::gemm"
)
if
(
this
->
name
()
==
"gpu::gemm"
)
{
{
gemm
(
ctx
,
output_shape
,
args
,
alpha
,
beta
,
int8_x4_format
);
gemm
(
ctx
,
output_shape
,
args
,
alpha
,
beta
,
int8_x4_format
,
compute_fp32
);
}
}
else
else
{
{
gemm
(
ctx
,
output_shape
,
args
,
int32_t
(
alpha
),
int32_t
(
beta
),
int8_x4_format
);
gemm
(
ctx
,
output_shape
,
args
,
int32_t
(
alpha
),
int32_t
(
beta
),
int8_x4_format
,
compute_fp32
);
}
}
return
args
.
back
();
return
args
.
back
();
}
}
...
...
src/targets/gpu/include/migraphx/gpu/gemm_impl.hpp
View file @
65ef1423
...
@@ -14,13 +14,15 @@ void gemm(context& ctx,
...
@@ -14,13 +14,15 @@ void gemm(context& ctx,
const
std
::
vector
<
argument
>&
args
,
const
std
::
vector
<
argument
>&
args
,
float
alpha
,
float
alpha
,
float
beta
,
float
beta
,
bool
int8_x4_format
);
bool
int8_x4_format
,
bool
compute_fp32
);
void
gemm
(
context
&
ctx
,
void
gemm
(
context
&
ctx
,
const
shape
&
output_shape
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
,
const
std
::
vector
<
argument
>&
args
,
int32_t
alpha
,
int32_t
alpha
,
int32_t
beta
,
int32_t
beta
,
bool
int8_x4_format
);
bool
int8_x4_format
,
bool
compute_fp32
);
}
// namespace gpu
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/targets/gpu/lowering.cpp
View file @
65ef1423
...
@@ -24,6 +24,7 @@
...
@@ -24,6 +24,7 @@
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/convolution.hpp>
#include <migraphx/gpu/convolution.hpp>
#include <migraphx/gpu/deconvolution.hpp>
#include <migraphx/gpu/deconvolution.hpp>
#include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/elu.hpp>
#include <migraphx/gpu/elu.hpp>
#include <migraphx/gpu/equal.hpp>
#include <migraphx/gpu/equal.hpp>
#include <migraphx/gpu/gemm.hpp>
#include <migraphx/gpu/gemm.hpp>
...
@@ -60,6 +61,7 @@ struct miopen_apply
...
@@ -60,6 +61,7 @@ struct miopen_apply
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
prog_output_names
{};
std
::
unordered_map
<
instruction_ref
,
std
::
string
>
prog_output_names
{};
bool
offload_copy
=
false
;
bool
offload_copy
=
false
;
bool
int8_x4_format
=
true
;
bool
int8_x4_format
=
true
;
bool
compute_fp32
=
false
;
context
&
get_context
()
const
context
&
get_context
()
const
{
{
...
@@ -96,13 +98,22 @@ struct miopen_apply
...
@@ -96,13 +98,22 @@ struct miopen_apply
}
}
}
}
const
std
::
unordered_set
<
std
::
string
>&
get_rocblas_fp32_archs
()
{
static
std
::
unordered_set
<
std
::
string
>
supported_archs
{
"gfx908"
,
"gfx90a"
};
return
supported_archs
;
}
void
init
()
void
init
()
{
{
assert
(
mod
!=
nullptr
);
assert
(
mod
!=
nullptr
);
assert
(
pass
!=
nullptr
);
assert
(
pass
!=
nullptr
);
#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
#if ROCBLAS_VERSION_MAJOR >= 2 && ROCBLAS_VERSION_MINOR >= 38
auto
&
ctx
=
get_context
();
auto
&
ctx
=
get_context
();
const
auto
device_name
=
trim
(
split_string
(
get_device_name
(),
':'
).
front
());
if
(
contains
(
get_rocblas_fp32_archs
(),
device_name
))
compute_fp32
=
true
;
rocblas_gemm_flags
flag
;
rocblas_gemm_flags
flag
;
rocblas_query_int8_layout_flag
(
ctx
.
get_stream
().
get_rocblas
(),
&
flag
);
rocblas_query_int8_layout_flag
(
ctx
.
get_stream
().
get_rocblas
(),
&
flag
);
int8_x4_format
=
(
flag
==
rocblas_gemm_flags_pack_int8x4
);
int8_x4_format
=
(
flag
==
rocblas_gemm_flags_pack_int8x4
);
...
@@ -337,7 +348,7 @@ struct miopen_apply
...
@@ -337,7 +348,7 @@ struct miopen_apply
}
}
}
}
return
mod
->
replace_instruction
(
return
mod
->
replace_instruction
(
ins
,
rocblas_gemm
<
Op
>
{
Op
{},
1
,
0
,
int8_x4_format
},
refs
);
ins
,
rocblas_gemm
<
Op
>
{
Op
{},
1
,
0
,
int8_x4_format
,
compute_fp32
},
refs
);
});
});
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment