Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
352e6668
Commit
352e6668
authored
Feb 10, 2023
by
Brian Pickrell
Browse files
Brian/Umang changes, still doesn't build
parent
d46b6972
Pipeline
#666
failed with stages
in 0 seconds
Changes
2
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
82 additions
and
62 deletions
+82
-62
src/targets/gpu/gemm_impl.cpp
src/targets/gpu/gemm_impl.cpp
+60
-59
src/targets/gpu/include/migraphx/gpu/gemm_impl.hpp
src/targets/gpu/include/migraphx/gpu/gemm_impl.hpp
+22
-3
No files found.
src/targets/gpu/gemm_impl.cpp
View file @
352e6668
...
@@ -108,16 +108,6 @@ static rocblas_int get_batch_stride(const argument& a)
...
@@ -108,16 +108,6 @@ static rocblas_int get_batch_stride(const argument& a)
return
a
.
get_shape
().
strides
()[
a
.
get_shape
().
strides
().
size
()
-
3
];
return
a
.
get_shape
().
strides
()[
a
.
get_shape
().
strides
().
size
()
-
3
];
}
}
/**
* The rocblas API calls we may be interested in. Each one takes a slightly different
* argument list, generated by create_gemm_args().
*/
enum
ROCBLAS_CALL
{
ROCBLAS_GEMM_EX
,
ROCBLAS_GEMM_STRIDED_BATCHED_EX
,
ROCBLAS_GEMM_EX_GET_SOLUTIONS
,
};
template
<
class
T
>
template
<
class
T
>
void
gemm_impl
(
context
&
ctx
,
void
gemm_impl
(
context
&
ctx
,
...
@@ -129,6 +119,7 @@ void gemm_impl(context& ctx,
...
@@ -129,6 +119,7 @@ void gemm_impl(context& ctx,
bool
compute_fp32
)
bool
compute_fp32
)
{
{
output_shape
.
visit_type
([
&
](
auto
as
)
{
// TODO: not needed?
output_shape
.
visit_type
([
&
](
auto
as
)
{
// TODO: not needed?
(
void
)
as
;
auto
out_lens
=
output_shape
.
lens
();
auto
out_lens
=
output_shape
.
lens
();
auto
num_matrices
=
std
::
accumulate
(
auto
num_matrices
=
std
::
accumulate
(
out_lens
.
rbegin
()
+
2
,
out_lens
.
rend
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
out_lens
.
rbegin
()
+
2
,
out_lens
.
rend
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
...
@@ -142,7 +133,7 @@ void gemm_impl(context& ctx,
...
@@ -142,7 +133,7 @@ void gemm_impl(context& ctx,
// column-major format. When doing a C = A * B, we actually do
// column-major format. When doing a C = A * B, we actually do
// C^T = (B^T) * (A^T). That is the reason we input args[1] as
// C^T = (B^T) * (A^T). That is the reason we input args[1] as
// A and args[0] as B in calling the rocblas_gemm.
// A and args[0] as B in calling the rocblas_gemm.
//
auto to_invoke =
auto
to_invoke
=
create_gemm_args
(
ctx
,
ROCBLAS_CALL
::
ROCBLAS_GEMM_EX
,
output_shape
,
args
,
create_gemm_args
(
ctx
,
ROCBLAS_CALL
::
ROCBLAS_GEMM_EX
,
output_shape
,
args
,
alpha
,
beta
,
int8_x4_format
,
compute_fp32
);
alpha
,
beta
,
int8_x4_format
,
compute_fp32
);
// rocblas_invoke(&rocblas_gemm_ex,
// rocblas_invoke(&rocblas_gemm_ex,
...
@@ -150,7 +141,7 @@ void gemm_impl(context& ctx,
...
@@ -150,7 +141,7 @@ void gemm_impl(context& ctx,
}
}
else
else
{
{
//
auto to_invoke =
auto
to_invoke
=
create_gemm_args
(
ctx
,
ROCBLAS_CALL
::
ROCBLAS_GEMM_STRIDED_BATCHED_EX
,
create_gemm_args
(
ctx
,
ROCBLAS_CALL
::
ROCBLAS_GEMM_STRIDED_BATCHED_EX
,
output_shape
,
args
,
alpha
,
beta
,
int8_x4_format
,
compute_fp32
);
output_shape
,
args
,
alpha
,
beta
,
int8_x4_format
,
compute_fp32
);
// rocblas_invoke(&rocblas_gemm_strided_batched_ex,
// rocblas_invoke(&rocblas_gemm_strided_batched_ex,
...
@@ -225,7 +216,7 @@ void gemm(context& ctx,
...
@@ -225,7 +216,7 @@ void gemm(context& ctx,
* a set of MigraphX arguments.
* a set of MigraphX arguments.
*/
*/
template
<
class
T
>
template
<
class
T
>
auto
create_gemm_args
(
context
&
ctx
,
static
auto
create_gemm_args
(
context
&
ctx
,
ROCBLAS_CALL
rocblas_call
,
ROCBLAS_CALL
rocblas_call
,
const
shape
&
output_shape
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
inputs
,
const
std
::
vector
<
argument
>&
inputs
,
...
@@ -273,25 +264,27 @@ auto create_gemm_args(context& ctx,
...
@@ -273,25 +264,27 @@ auto create_gemm_args(context& ctx,
auto
a_lens
=
inputs
[
0
].
get_shape
().
lens
();
auto
a_lens
=
inputs
[
0
].
get_shape
().
lens
();
auto
b_lens
=
inputs
[
1
].
get_shape
().
lens
();
auto
b_lens
=
inputs
[
1
].
get_shape
().
lens
();
return
output_shape
.
visit_type
([
&
](
auto
as
)
{
void
*
alpha_v
=
nullptr
;
void
*
beta_v
=
nullptr
;
output_shape
.
visit_type
([
&
](
auto
as
)
{
auto
alpha_r
=
as
(
alpha
);
auto
alpha_r
=
as
(
alpha
);
auto
beta_r
=
as
(
beta
);
auto
beta_r
=
as
(
beta
);
// use void pointer to select different data type if using fp32 mode
// use void pointer to select different data type if using fp32 mode
void
*
alpha_v
=
&
alpha_r
;
alpha_v
=
&
alpha_r
;
void
*
beta_v
=
&
beta_r
;
beta_v
=
&
beta_r
;
if
(
compute_fp32
)
if
(
compute_fp32
)
{
{
alpha_v
=
&
alpha
;
alpha_v
=
&
alpha
;
beta_v
=
&
beta
;
beta_v
=
&
beta
;
}
}
});
auto
out_lens
=
output_shape
.
lens
();
auto
out_lens
=
output_shape
.
lens
();
rocblas_int
m
=
out_lens
[
dim_0
];
rocblas_int
m
=
out_lens
[
dim_0
];
rocblas_int
n
=
out_lens
[
dim_1
];
rocblas_int
n
=
out_lens
[
dim_1
];
rocblas_int
k
=
inputs
[
0
].
get_shape
().
lens
()[
dim_1
];
rocblas_int
k
=
inputs
[
0
].
get_shape
().
lens
()[
dim_1
];
auto
to_pointer
=
[
&
](
auto
&&
arg
)
{
return
as
.
from
(
arg
.
data
());
};
auto
to_pointer
=
[
&
](
auto
&&
arg
)
{
return
reinterpret_cast
<
T
*>
(
arg
.
data
());
};
if
(
inputs
[
0
].
get_shape
().
type
()
==
shape
::
int8_type
and
(
k
%
4
)
!=
0
and
int8_x4_format
)
if
(
inputs
[
0
].
get_shape
().
type
()
==
shape
::
int8_type
and
(
k
%
4
)
!=
0
and
int8_x4_format
)
{
{
MIGRAPHX_THROW
(
"create_gemm_args: k size of int8 type input must be multiple of 4!"
);
MIGRAPHX_THROW
(
"create_gemm_args: k size of int8 type input must be multiple of 4!"
);
...
@@ -302,7 +295,7 @@ auto create_gemm_args(context& ctx,
...
@@ -302,7 +295,7 @@ auto create_gemm_args(context& ctx,
switch
(
rocblas_call
){
switch
(
rocblas_call
){
case
ROCBLAS_GEMM_EX
:
case
ROCBLAS_GEMM_EX
:
{
{
m
*=
num_matrices
;
m
*=
num_matrices
;
return
pack
(
return
pack
(
...
@@ -337,46 +330,49 @@ auto create_gemm_args(context& ctx,
...
@@ -337,46 +330,49 @@ auto create_gemm_args(context& ctx,
0
,
0
,
flag
);
flag
);
}
}
// case ROCBLAS_GEMM_STRIDED_BATCHED_EX:
// {
case
ROCBLAS_GEMM_STRIDED_BATCHED_EX
:
// auto a_stride = get_batch_stride(inputs[0]);
default:
// auto b_stride = get_batch_stride(inputs[1]);
{
// auto c_stride = get_batch_stride(inputs[2]);
auto
a_stride
=
get_batch_stride
(
inputs
[
0
]);
// auto d_stride = is_3inputs ? get_batch_stride(inputs[3]) : c_stride;
auto
b_stride
=
get_batch_stride
(
inputs
[
1
]);
// return pack(
auto
c_stride
=
get_batch_stride
(
inputs
[
2
]);
// // rocblas_invoke( &rocblas_gemm_strided_batched_ex,
auto
d_stride
=
is_3inputs
?
get_batch_stride
(
inputs
[
3
])
:
c_stride
;
// ctx.get_stream().get_rocblas(),
return
pack
(
// transb ? rocblas_operation_transpose : rocblas_operation_none,
// rocblas_invoke( &rocblas_gemm_strided_batched_ex,
// transa ? rocblas_operation_transpose : rocblas_operation_none,
ctx
.
get_stream
().
get_rocblas
(),
// n,
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
// m,
transa
?
rocblas_operation_transpose
:
rocblas_operation_none
,
// k,
n
,
// alpha_v,
m
,
// to_pointer(inputs.at(1)),
k
,
// arg_type,
alpha_v
,
// ldb,
to_pointer
(
inputs
.
at
(
1
)),
// b_stride,
arg_type
,
// to_pointer(inputs.at(0)),
ldb
,
// arg_type,
b_stride
,
// lda,
to_pointer
(
inputs
.
at
(
0
)),
// a_stride,
arg_type
,
// beta_v,
lda
,
// to_pointer(inputs[2]),
a_stride
,
// output_type,
beta_v
,
// ldc,
to_pointer
(
inputs
[
2
]),
// c_stride,
output_type
,
// is_3inputs ? to_pointer(inputs[3]) : to_pointer(inputs[2]),
ldc
,
// output_type,
c_stride
,
// ldd,
is_3inputs
?
to_pointer
(
inputs
[
3
])
:
to_pointer
(
inputs
[
2
]),
// d_stride,
output_type
,
// num_matrices,
ldd
,
// compute_type,
d_stride
,
// rocblas_gemm_algo_standard,
num_matrices
,
// 0,
compute_type
,
// flag);
rocblas_gemm_algo_standard
,
// }
0
,
flag
);
}
// case ROCBLAS_GEMM_EX_GET_SOLUTIONS:
// case ROCBLAS_GEMM_EX_GET_SOLUTIONS:
// default:
// // the original macro in rocBLAS-internal/rocBLAS/clients/samples/example_user_driven_tuning.cpp is
// // the original macro in rocBLAS-internal/rocBLAS/clients/samples/example_user_driven_tuning.cpp is
// // Note different order of m, n, k
// // Note different order of m, n, k
// // #define GEMM_EX_ARGS \
// // #define GEMM_EX_ARGS \
...
@@ -386,10 +382,15 @@ auto create_gemm_args(context& ctx,
...
@@ -386,10 +382,15 @@ auto create_gemm_args(context& ctx,
// #define GEMM_EX_ARGS \
// #define GEMM_EX_ARGS \
// handle, transa, transb, m, n, k, alpha_v, da, type, lda, db, type, ldb, beta_v, dc, type, ldc, \
// handle, transa, transb, m, n, k, alpha_v, da, type, lda, db, type, ldb, beta_v, dc, type, ldc, \
// dc, type, ldc, type, rocblas_gemm_algo_solution_index
// dc, type, ldc, type, rocblas_gemm_algo_solution_index
// return pack(ctx.get_stream().get_rocblas());
// return pack(ctx.get_stream().get_rocblas());
// Get number of solutions
// rocblas_int size;
// CHECK_ROCBLAS_ERROR(
// rocblas_gemm_ex_get_solutions(GEMM_EX_ARGS, rocblas_gemm_flags_none, NULL, &size));
}
// end switch
// default:
// default:
// MIGRAPHX_THROW ("create_gemm_args(): rocBLAS command not supported");
// MIGRAPHX_THROW ("create_gemm_args(): rocBLAS command not supported");
}});
}
}
}
// namespace gpu
}
// namespace gpu
...
...
src/targets/gpu/include/migraphx/gpu/gemm_impl.hpp
View file @
352e6668
...
@@ -31,6 +31,16 @@
...
@@ -31,6 +31,16 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
namespace
gpu
{
/**
* The rocblas API calls we may be interested in. Each one takes a slightly different
* argument list, generated by create_gemm_args().
*/
enum
ROCBLAS_CALL
{
ROCBLAS_GEMM_EX
,
ROCBLAS_GEMM_STRIDED_BATCHED_EX
,
ROCBLAS_GEMM_EX_GET_SOLUTIONS
,
};
void
gemm
(
context
&
ctx
,
void
gemm
(
context
&
ctx
,
const
shape
&
output_shape
,
const
shape
&
output_shape
,
...
@@ -47,13 +57,22 @@ void gemm(context& ctx,
...
@@ -47,13 +57,22 @@ void gemm(context& ctx,
bool
int8_x4_format
,
bool
int8_x4_format
,
bool
compute_fp32
);
bool
compute_fp32
);
template
<
class
T
>
//
template <class T>
auto
create_gemm_args
(
context
&
ctx
,
//
auto create_gemm_args(context& ctx,
const
std
::
vector
<
argument
>&
inputs
);
//
const std::vector<argument>& inputs);
// The version with just shapes will use null pointers for the buffers
// The version with just shapes will use null pointers for the buffers
template
<
class
T
>
template
<
class
T
>
auto
create_gemm_args
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
);
auto
create_gemm_args
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
);
template
<
class
T
>
static
auto
create_gemm_args
(
context
&
ctx
,
ROCBLAS_CALL
rocblas_call
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
inputs
,
T
alpha
,
T
beta
,
bool
int8_x4_format
,
bool
compute_fp32
);
}
// namespace gpu
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment