Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
d7c8b66f
Unverified
Commit
d7c8b66f
authored
Nov 07, 2023
by
Brian Pickrell
Committed by
GitHub
Nov 07, 2023
Browse files
Blas auto-tuning for GEMMs (#1668)
parent
4bd3f4e3
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
847 additions
and
166 deletions
+847
-166
src/targets/gpu/CMakeLists.txt
src/targets/gpu/CMakeLists.txt
+12
-1
src/targets/gpu/gemm_impl.cpp
src/targets/gpu/gemm_impl.cpp
+474
-145
src/targets/gpu/include/migraphx/gpu/gemm.hpp
src/targets/gpu/include/migraphx/gpu/gemm.hpp
+37
-6
src/targets/gpu/include/migraphx/gpu/gemm_impl.hpp
src/targets/gpu/include/migraphx/gpu/gemm_impl.hpp
+51
-13
src/targets/gpu/include/migraphx/gpu/rocblas.hpp
src/targets/gpu/include/migraphx/gpu/rocblas.hpp
+1
-1
test/gpu/gemm_tune.cpp
test/gpu/gemm_tune.cpp
+225
-0
test/verify/gemm_2args_mm_8.cpp
test/verify/gemm_2args_mm_8.cpp
+47
-0
No files found.
src/targets/gpu/CMakeLists.txt
View file @
d7c8b66f
# ####################################################################################
# ####################################################################################
# The MIT License (MIT)
# The MIT License (MIT)
#
#
# Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
#
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# of this software and associated documentation files (the "Software"), to deal
...
@@ -245,10 +245,14 @@ else()
...
@@ -245,10 +245,14 @@ else()
endif
()
endif
()
# Check miopen find mode api
# Check miopen find mode api
include
(
CheckLibraryExists
)
include
(
CheckLibraryExists
)
get_target_property
(
MIOPEN_LOCATION MIOpen LOCATION
)
get_target_property
(
MIOPEN_LOCATION MIOpen LOCATION
)
get_target_property
(
ROCBLAS_LOCATION roc::rocblas LOCATION
)
check_library_exists
(
MIOpen
"miopenHiddenSetConvolutionFindMode"
"
${
MIOPEN_LOCATION
}
"
HAS_FIND_MODE_API
)
check_library_exists
(
MIOpen
"miopenHiddenSetConvolutionFindMode"
"
${
MIOPEN_LOCATION
}
"
HAS_FIND_MODE_API
)
check_library_exists
(
MIOpen
"miopenFindSolutions"
"
${
MIOPEN_LOCATION
}
"
HAS_FIND_2_API
)
check_library_exists
(
MIOpen
"miopenFindSolutions"
"
${
MIOPEN_LOCATION
}
"
HAS_FIND_2_API
)
# Beta API for automated GEMM tuning
check_library_exists
(
roc::rocblas
"rocblas_gemm_ex_get_solutions"
"
${
ROCBLAS_LOCATION
}
"
HAS_ROCBLAS_TUNING_BETA_FEATURE_API
)
set
(
MIGRAPHX_USE_FIND_2_API
"
${
HAS_FIND_2_API
}
"
CACHE BOOL
""
)
set
(
MIGRAPHX_USE_FIND_2_API
"
${
HAS_FIND_2_API
}
"
CACHE BOOL
""
)
...
@@ -271,6 +275,13 @@ else()
...
@@ -271,6 +275,13 @@ else()
message
(
STATUS
"MIOpen does not have find mode api"
)
message
(
STATUS
"MIOpen does not have find mode api"
)
endif
()
endif
()
if
(
HAS_ROCBLAS_TUNING_BETA_FEATURE_API
)
target_compile_definitions
(
migraphx_gpu PUBLIC -DMIGRAPHX_USE_ROCBLAS_TUNING_API -DROCBLAS_BETA_FEATURES_API -DROCBLAS_NO_DEPRECATED_WARNINGS
)
message
(
STATUS
"MIGraphx is using Beta API of rocBLAS"
)
else
()
message
(
STATUS
"rocBLAS does not have User Tuning Beta API"
)
endif
()
target_link_libraries
(
migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas
)
target_link_libraries
(
migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas
)
target_link_libraries
(
migraphx_gpu PRIVATE migraphx_device migraphx_kernels
)
target_link_libraries
(
migraphx_gpu PRIVATE migraphx_device migraphx_kernels
)
if
(
MIGRAPHX_USE_COMPOSABLEKERNEL
)
if
(
MIGRAPHX_USE_COMPOSABLEKERNEL
)
...
...
src/targets/gpu/gemm_impl.cpp
View file @
d7c8b66f
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
@@ -21,15 +21,20 @@
...
@@ -21,15 +21,20 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
* THE SOFTWARE.
*/
*/
#include <rocblas/rocblas.h>
#include <rocblas/rocblas.h>
#include <migraphx/gpu/gemm_impl.hpp>
#include <migraphx/gpu/gemm_impl.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/time.hpp>
using
microseconds
=
std
::
chrono
::
duration
<
double
,
std
::
micro
>
;
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
namespace
gpu
{
// Convert rocBLAS datatypes to equivalent Migraphx data types
rocblas_datatype
get_type
(
shape
::
type_t
type
)
rocblas_datatype
get_type
(
shape
::
type_t
type
)
{
{
switch
(
type
)
switch
(
type
)
...
@@ -81,184 +86,508 @@ shape transpose_batch(const shape& s, unsigned trans_batch)
...
@@ -81,184 +86,508 @@ shape transpose_batch(const shape& s, unsigned trans_batch)
return
shape
::
from_permutation
(
s
.
type
(),
s
.
lens
(),
perm
);
return
shape
::
from_permutation
(
s
.
type
(),
s
.
lens
(),
perm
);
}
}
template
<
class
R
,
class
...
Ts
,
class
...
Us
>
/**
R
rocblas_invoke
(
R
(
*
f
)(
Ts
...),
Us
...
xs
)
* Returns results of rocblas_status_success, rocblas_status_perf_degraded,
* or rocblas_status_invalid_value. Caller
* is expected to check for invalid index. Any other result causes an exception.
*
*/
template
<
class
F
,
class
Pack
,
class
...
Ts
>
auto
rocblas_invoke
(
F
f
,
Pack
p
,
Ts
...
xs
)
{
{
if
constexpr
(
sizeof
...(
Ts
)
==
sizeof
...(
Us
))
return
p
([
=
](
auto
...
ws
)
{
return
f
(
xs
...);
auto
status
=
f
(
ws
...,
xs
...);
else
if
(
status
!=
rocblas_status_success
and
status
!=
rocblas_status_invalid_value
)
return
f
(
xs
...,
nullptr
,
nullptr
);
{
if
(
status
==
rocblas_status_perf_degraded
)
{
std
::
cerr
<<
"WARNING: degraded perf. in rocBLAS call"
<<
std
::
endl
;
}
else
MIGRAPHX_THROW
(
"rocblas_invoke: rocBLAS call failed with status "
+
std
::
to_string
(
status
));
}
return
status
;
});
}
}
static
bool
is_transposed
(
const
shape
&
s
)
static
bool
is_transposed
(
const
shape
&
s
)
{
return
s
.
transposed
()
and
s
.
strides
().
back
()
!=
1
;
}
{
if
(
not
s
.
transposed
())
return
false
;
return
s
.
strides
().
back
()
!=
1
;
}
static
rocblas_int
get_batch_stride
(
const
argument
&
a
)
static
rocblas_int
get_batch_stride
(
const
shape
&
s
)
{
{
return
a
.
get_shape
().
strides
()[
a
.
get_shape
().
strides
().
size
()
-
3
];
// This value is not needed for non-strided inputs
if
(
s
.
strides
().
size
()
<
3
)
return
0
;
else
return
s
.
strides
()[
s
.
strides
().
size
()
-
3
];
}
}
template
<
class
T
>
/**
void
gemm_impl
(
context
&
ctx
,
* Wrapper for multiple rocBLAS calls. The constructor creates parameters for
const
shape
&
output_shape
,
* these calls based on data shapes and other values contained in the associated
const
std
::
vector
<
argument
>&
args
,
* instruction and operation.
T
alpha
,
*
T
beta
,
* The template parameter T is not the type of the matrix data but of the weighting
bool
compute_fp32
)
* coefficients alpha and beta (these are float in rocBLAS internals)
*/
template
<
typename
T
>
struct
gemm_impl
{
{
const
bool
is_3inputs
=
(
args
.
size
()
==
4
);
gemm_impl
(
const
shape
&
output_shape
,
if
(
not
is_3inputs
)
const
std
::
vector
<
shape
>&
input_shapes
,
{
T
alpha_param
,
beta
=
0
;
T
beta_param
,
}
bool
compute_fp32_flag
)
:
alpha
(
alpha_param
),
bool
transa
=
is_transposed
(
args
[
0
].
get_shape
());
beta
(
beta_param
),
bool
transb
=
is_transposed
(
args
[
1
].
get_shape
());
is_3inputs
(
input_shapes
.
size
()
==
4
),
auto
n_dim
=
output_shape
.
lens
().
size
();
compute_fp32
(
compute_fp32_flag
)
auto
dim_1
=
n_dim
-
1
;
auto
dim_0
=
n_dim
-
2
;
rocblas_int
lda
=
args
[
0
].
get_shape
().
strides
()[
transa
?
dim_1
:
dim_0
];
rocblas_int
ldb
=
args
[
1
].
get_shape
().
strides
()[
transb
?
dim_1
:
dim_0
];
rocblas_int
ldc
=
args
[
2
].
get_shape
().
strides
()[
dim_0
];
rocblas_int
ldd
=
is_3inputs
?
args
[
3
].
get_shape
().
strides
()[
dim_0
]
:
ldc
;
rocblas_datatype
arg_type
=
get_type
(
args
[
0
].
get_shape
().
type
());
auto
output_type
=
arg_type
;
if
(
output_type
==
rocblas_datatype_i8_r
)
{
output_type
=
rocblas_datatype_i32_r
;
}
auto
compute_type
=
output_type
;
if
(
compute_fp32
)
{
{
if
(
arg_type
==
rocblas_datatype_f16_r
)
if
(
not
is_3inputs
)
compute_type
=
rocblas_datatype_f32_r
;
{
}
beta
=
0
;
}
rocblas_gemm_flags
flag
=
rocblas_gemm_flags_none
;
// Create lambdas that will cast alpha, beta to the output shape's type
auto
a_lens
=
args
[
0
].
get_shape
().
lens
();
// and retain the values being pointed to
auto
b_lens
=
args
[
1
].
get_shape
().
lens
();
output_shape
.
visit_type
([
&
](
auto
as
)
{
output_shape
.
visit_type
([
&
](
auto
as
)
{
auto
alpha_r
=
as
(
alpha
);
auto
alpha_r
=
as
(
alpha
);
auto
beta_r
=
as
(
beta
);
auto
beta_r
=
as
(
beta
);
if
(
compute_fp32
)
{
get_alpha
=
[
=
]
{
return
&
alpha
;
};
get_beta
=
[
=
]
{
return
&
beta
;
};
}
else
{
get_alpha
=
[
=
]
{
return
&
alpha_r
;
};
get_beta
=
[
=
]
{
return
&
beta_r
;
};
}
});
// use void pointer to select different data type if using fp32 mode
transa
=
is_transposed
(
input_shapes
[
0
]);
void
*
alpha_v
=
&
alpha_r
;
transb
=
is_transposed
(
input_shapes
[
1
]);
void
*
beta_v
=
&
beta_r
;
auto
n_dim
=
output_shape
.
lens
().
size
();
auto
dim_0
=
n_dim
-
2
;
auto
dim_1
=
n_dim
-
1
;
// Leading dimensions of matrices
lda
=
input_shapes
[
0
].
strides
()[
transa
?
dim_1
:
dim_0
];
ldb
=
input_shapes
[
1
].
strides
()[
transb
?
dim_1
:
dim_0
];
ldc
=
input_shapes
[
2
].
strides
()[
dim_0
];
ldd
=
is_3inputs
?
input_shapes
[
3
].
strides
()[
dim_0
]
:
ldc
;
arg_type
=
get_type
(
input_shapes
[
0
].
type
());
output_type
=
arg_type
;
if
(
output_type
==
rocblas_datatype_i8_r
)
{
output_type
=
rocblas_datatype_i32_r
;
}
compute_type
=
output_type
;
if
(
compute_fp32
)
if
(
compute_fp32
)
{
{
alpha_v
=
&
alpha
;
if
(
arg_type
==
rocblas_datatype_f16_r
)
beta_v
=
&
beta
;
compute_type
=
rocblas_datatype_f32_r
;
}
}
auto
out_lens
=
output_shape
.
lens
();
auto
a_lens
=
input_shapes
[
0
].
lens
();
rocblas_int
m
=
out_lens
[
dim_0
];
auto
b_lens
=
input_shapes
[
1
].
lens
();
rocblas_int
n
=
out_lens
[
dim_1
];
rocblas_int
k
=
args
[
0
].
get_shape
().
lens
()[
dim_1
];
auto
to_pointer
=
[
&
](
auto
&&
arg
)
{
return
as
.
from
(
arg
.
data
());
};
auto
num_matrices
=
std
::
accumulate
(
auto
out_lens
=
output_shape
.
lens
();
m
=
out_lens
[
dim_0
];
n
=
out_lens
[
dim_1
];
k
=
input_shapes
[
0
].
lens
()[
dim_1
];
a_stride
=
get_batch_stride
(
input_shapes
[
0
]);
b_stride
=
get_batch_stride
(
input_shapes
[
1
]);
c_stride
=
get_batch_stride
(
input_shapes
[
2
]);
d_stride
=
is_3inputs
?
get_batch_stride
(
input_shapes
[
3
])
:
c_stride
;
num_matrices
=
std
::
accumulate
(
out_lens
.
rbegin
()
+
2
,
out_lens
.
rend
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
out_lens
.
rbegin
()
+
2
,
out_lens
.
rend
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
if
(
num_matrices
==
1
or
(
num_matrices
>
1
and
get_batch_stride
(
args
[
1
])
==
0
))
strided_batched
=
num_matrices
>
1
;
if
(
strided_batched
and
b_stride
==
0
and
input_shapes
[
0
].
standard
())
{
{
// If the batch dimension of B is broadcasted, then we can
// If the batch dimension of B is broadcasted, then we can
// multiply m by the batch_size and use rocblas_gemm_ex
// multiply m by the batch_size and use rocblas_gemm_ex
// instead of rocblas_gemm_strided_batched_ex.
// instead of rocblas_gemm_strided_batched_ex.
m
*=
num_matrices
;
m
*=
num_matrices
;
strided_batched
=
false
;
}
}
// the rocblas_gemm API handles inputs and output matrices as
void
run
(
context
&
ctx
,
const
std
::
vector
<
argument
>&
input_args
,
int32_t
solution_idx
=
0
)
const
// column-major format. When doing a C = A * B, we actually do
{
// C^T = (B^T) * (A^T). That is the reason we input args[1] as
if
(
strided_batched
)
// A and args[0] as B in calling the rocblas_gemm.
{
auto
common_args
=
create_strided_batched_args_common
(
ctx
,
input_args
);
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex
,
common_args
,
rocblas_gemm_algo_solution_index
,
solution_idx
,
gemm_flags
);
}
else
{
auto
common_args
=
create_gemm_ex_args_common
(
ctx
,
input_args
);
rocblas_invoke
(
&
rocblas_gemm_ex
,
rocblas_invoke
(
&
rocblas_gemm_ex
,
ctx
.
get_stream
().
get_rocblas
(),
common_args
,
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
rocblas_gemm_algo_solution_index
,
transa
?
rocblas_operation_transpose
:
rocblas_operation_none
,
solution_idx
,
n
,
gemm_flags
);
m
,
}
k
,
}
alpha_v
,
to_pointer
(
args
.
at
(
1
)),
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
arg_type
,
auto
validate
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
input_shapes
,
int32_t
solution_idx
)
const
ldb
,
{
to_pointer
(
args
.
at
(
0
)),
// Create dummy arguments for the shapes, and call the overloaded method
arg_type
,
std
::
vector
<
argument
>
input_args
;
lda
,
std
::
transform
(
input_shapes
.
begin
(),
beta_v
,
input_shapes
.
end
(),
to_pointer
(
args
[
2
]),
std
::
back_inserter
(
input_args
),
output_type
,
[](
const
shape
&
x
)
{
return
to_gpu
(
generate_argument
(
x
));
});
ldc
,
is_3inputs
?
to_pointer
(
args
[
3
])
:
to_pointer
(
args
[
2
]),
return
validate
(
ctx
,
input_args
,
solution_idx
);
output_type
,
}
ldd
,
compute_type
,
/**
rocblas_gemm_algo_standard
,
* Checks a particular solution for validity by running it with the flag
0
,
* rocblas_gemm_flags_check_solution_index (could be invalid if this model was
flag
);
* tuned with a different rocBLAS version)
*
* @return Returns either solution_idx if valid, or else the default value 0
* if not. The default does not mean list index 0, but tells the picker
* to choose a solution.
*/
int32_t
validate
(
context
&
ctx
,
const
std
::
vector
<
argument
>&
input_args
,
int32_t
solution_idx
)
const
{
rocblas_status_
check_valid
(
rocblas_status_success
);
if
(
strided_batched
)
{
auto
common_args
=
create_strided_batched_args_common
(
ctx
,
input_args
);
check_valid
=
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex
,
common_args
,
rocblas_gemm_algo_solution_index
,
solution_idx
,
rocblas_gemm_flags_check_solution_index
);
}
}
else
else
{
{
auto
a_stride
=
get_batch_stride
(
args
[
0
]);
auto
common_args
=
create_gemm_ex_args_common
(
ctx
,
input_args
);
auto
b_stride
=
get_batch_stride
(
args
[
1
]);
check_valid
=
rocblas_invoke
(
&
rocblas_gemm_ex
,
auto
c_stride
=
get_batch_stride
(
args
[
2
]);
common_args
,
auto
d_stride
=
is_3inputs
?
get_batch_stride
(
args
[
3
])
:
c_stride
;
rocblas_gemm_algo_solution_index
,
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex
,
solution_idx
,
ctx
.
get_stream
().
get_rocblas
(),
rocblas_gemm_flags_check_solution_index
);
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
transa
?
rocblas_operation_transpose
:
rocblas_operation_none
,
n
,
m
,
k
,
alpha_v
,
to_pointer
(
args
.
at
(
1
)),
arg_type
,
ldb
,
b_stride
,
to_pointer
(
args
.
at
(
0
)),
arg_type
,
lda
,
a_stride
,
beta_v
,
to_pointer
(
args
[
2
]),
output_type
,
ldc
,
c_stride
,
is_3inputs
?
to_pointer
(
args
[
3
])
:
to_pointer
(
args
[
2
]),
output_type
,
ldd
,
d_stride
,
num_matrices
,
compute_type
,
rocblas_gemm_algo_standard
,
0
,
flag
);
}
}
});
if
(
check_valid
==
rocblas_status_invalid_value
)
{
std
::
cerr
<<
"WARNING: tuned solution is invalid; reverting to default"
<<
std
::
endl
;
return
0
;
}
return
solution_idx
;
}
#endif
/**
* Helper method to create that subset of a long rocBLAS argument list that is common
* to multiple "...strided_batched..." calls.
*
* The rocblas_gemm API handles inputs and output matrices as
* column-major format. When doing a C = A * B, we actually do
* C^T = (B^T) * (A^T). That is the reason we input args[1] as
* A and args[0] as B in calling the rocblas_gemm.
*
*/
auto
create_strided_batched_args_common
(
context
&
ctx
,
const
std
::
vector
<
argument
>&
args
)
const
{
return
pack
(
ctx
.
get_stream
().
get_rocblas
(),
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
transa
?
rocblas_operation_transpose
:
rocblas_operation_none
,
n
,
m
,
k
,
get_alpha
(),
args
[
1
].
data
(),
arg_type
,
ldb
,
b_stride
,
args
[
0
].
data
(),
arg_type
,
lda
,
a_stride
,
get_beta
(),
args
[
2
].
data
(),
output_type
,
ldc
,
c_stride
,
is_3inputs
?
args
[
3
].
data
()
:
args
[
2
].
data
(),
output_type
,
ldd
,
d_stride
,
num_matrices
,
compute_type
);
}
/**
* Helper method to create that subset of a long rocBLAS argument list that is common
* to multiple "gemm_ex..." calls.
*
* The rocblas_gemm API handles inputs and output matrices as
* column-major format. When doing a C = A * B, we actually do
* C^T = (B^T) * (A^T). That is the reason we input args[1] as
* A and args[0] as B in calling the rocblas_gemm.
*
* */
auto
create_gemm_ex_args_common
(
context
&
ctx
,
const
std
::
vector
<
argument
>&
args
)
const
{
return
pack
(
ctx
.
get_stream
().
get_rocblas
(),
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
transa
?
rocblas_operation_transpose
:
rocblas_operation_none
,
n
,
m
,
k
,
get_alpha
(),
args
[
1
].
data
(),
arg_type
,
ldb
,
args
[
0
].
data
(),
arg_type
,
lda
,
get_beta
(),
args
[
2
].
data
(),
output_type
,
ldc
,
is_3inputs
?
args
[
3
].
data
()
:
args
[
2
].
data
(),
output_type
,
ldd
,
compute_type
);
}
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
/**
* Find best rocBLAS solution: Get list of solutions and try them all, returning the index
* of the fastest one.
*/
int
tune
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
input_shapes
)
const
{
// tuning meta parameters
const
int
hot_calls
=
40
;
std
::
vector
<
argument
>
input_args
;
std
::
transform
(
input_shapes
.
begin
(),
input_shapes
.
end
(),
std
::
back_inserter
(
input_args
),
[](
const
shape
&
x
)
{
return
to_gpu
(
generate_argument
(
x
));
});
// Get the solutions list in 2 rocBLAS steps:
// 1. Find out how many solutions there are and allocate the array
// 2. Get the solutions
//
rocblas_int
list_size
=
0
;
std
::
vector
<
rocblas_int
>
solution_indices
;
if
(
strided_batched
)
{
auto
common_args
=
create_strided_batched_args_common
(
ctx
,
input_args
);
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex_get_solutions
,
common_args
,
rocblas_gemm_algo_solution_index
,
gemm_flags
,
nullptr
,
&
list_size
);
solution_indices
.
resize
(
list_size
);
auto
common_sol_args
=
create_strided_batched_args_common
(
ctx
,
input_args
);
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex_get_solutions
,
common_sol_args
,
rocblas_gemm_algo_solution_index
,
gemm_flags
,
solution_indices
.
data
(),
&
list_size
);
}
else
{
auto
common_args
=
create_gemm_ex_args_common
(
ctx
,
input_args
);
rocblas_invoke
(
&
rocblas_gemm_ex_get_solutions
,
common_args
,
rocblas_gemm_algo_solution_index
,
gemm_flags
,
nullptr
,
&
list_size
);
solution_indices
.
resize
(
list_size
);
auto
common_sol_args
=
create_gemm_ex_args_common
(
ctx
,
input_args
);
rocblas_invoke
(
&
rocblas_gemm_ex_get_solutions
,
common_sol_args
,
rocblas_gemm_algo_solution_index
,
gemm_flags
,
solution_indices
.
data
(),
&
list_size
);
}
double
best_time
=
std
::
numeric_limits
<
double
>::
max
();
double
first_time
=
-
1
;
// Initialize to default solution index
rocblas_int
best_sol
=
0
;
for
(
auto
sol
:
solution_indices
)
{
// Warmup: the first call to an op. may not be representative since there is
// more time taken initializing caches, etc. so we won't time it.
run
(
ctx
,
input_args
,
sol
);
double
host_time
=
time
<
milliseconds
>
([
&
]
{
for
([[
maybe_unused
]]
int
hc
:
range
(
hot_calls
))
run
(
ctx
,
input_args
,
sol
);
ctx
.
finish
();
});
host_time
/=
hot_calls
;
// dev/evaluation only: track time for first solution.
if
(
first_time
<
0
)
first_time
=
host_time
;
// track current best
if
(
host_time
<
best_time
)
{
best_sol
=
sol
;
best_time
=
host_time
;
}
}
std
::
cout
<<
"Winning GEMM solution: "
<<
best_sol
<<
" in "
<<
best_time
<<
" ms, beats "
<<
first_time
<<
"ms"
<<
std
::
endl
;
return
best_sol
;
}
#endif
private:
size_t
num_matrices
=
0
;
rocblas_int
m
=
0
;
rocblas_int
n
=
0
;
rocblas_int
k
=
0
;
bool
transa
=
false
;
bool
transb
=
false
;
T
alpha
=
0
;
T
beta
=
0
;
std
::
function
<
const
void
*
()
>
get_alpha
{};
std
::
function
<
const
void
*
()
>
get_beta
{};
rocblas_gemm_flags
gemm_flags
=
rocblas_gemm_flags_none
;
rocblas_int
lda
=
0
;
rocblas_int
ldb
=
0
;
rocblas_int
ldc
=
0
;
rocblas_int
ldd
=
0
;
rocblas_int
a_stride
=
0
;
rocblas_int
b_stride
=
0
;
rocblas_int
c_stride
=
0
;
rocblas_int
d_stride
=
0
;
rocblas_datatype
compute_type
=
rocblas_datatype_f32_r
;
rocblas_datatype
arg_type
=
rocblas_datatype_f32_r
;
rocblas_datatype
output_type
=
rocblas_datatype_f32_r
;
bool
strided_batched
=
true
;
bool
is_3inputs
=
true
;
bool
compute_fp32
=
true
;
};
// gemm_impl
void
gemm_compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
,
float
alpha
,
float
beta
,
bool
compute_fp32
,
int32_t
solution_idx
)
{
std
::
vector
<
shape
>
input_shapes
;
std
::
transform
(
args
.
begin
(),
args
.
end
(),
std
::
back_inserter
(
input_shapes
),
[](
const
argument
&
x
)
{
return
x
.
get_shape
();
});
auto
gemm_item
=
gemm_impl
<
float
>
(
output_shape
,
input_shapes
,
alpha
,
beta
,
compute_fp32
);
gemm_item
.
run
(
ctx
,
args
,
solution_idx
);
}
}
void
gemm
(
context
&
ctx
,
void
gemm_compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
,
const
std
::
vector
<
argument
>&
args
,
float
alpha
,
int32_t
alpha
,
float
beta
,
int32_t
beta
,
bool
compute_fp32
)
bool
compute_fp32
,
int32_t
solution_idx
)
{
{
gemm_impl
(
ctx
,
output_shape
,
args
,
alpha
,
beta
,
compute_fp32
);
std
::
vector
<
shape
>
input_shapes
;
std
::
transform
(
args
.
begin
(),
args
.
end
(),
std
::
back_inserter
(
input_shapes
),
[](
const
argument
&
x
)
{
return
x
.
get_shape
();
});
auto
gemm_item
=
gemm_impl
<
int32_t
>
(
output_shape
,
input_shapes
,
alpha
,
beta
,
compute_fp32
);
gemm_item
.
run
(
ctx
,
args
,
solution_idx
);
}
/**
* Decides if the tune() or validate() method is appropriate and calls it.
* Return value is the chosen solution index, or 0 to let picker choose it.
*/
int32_t
gemm_finalize
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
shape
>&
input_shapes
,
float
alpha
,
float
beta
,
bool
compute_fp32
,
int32_t
solution_idx
)
{
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
// This code should be called only if either the environment var.
// MIGRAPHX_ENABLE_GEMM_TUNING, or option --exhaustive-tune, is set
if
(
solution_idx
==
0
)
{
auto
gemm_item
=
gemm_impl
<
float
>
(
output_shape
,
input_shapes
,
alpha
,
beta
,
compute_fp32
);
solution_idx
=
gemm_item
.
tune
(
ctx
,
input_shapes
);
}
else
{
// If a tuned solution index is already given, don't tune again but validate
// in case the data was tuned with a different rocBLAS version
auto
gemm_item
=
gemm_impl
<
float
>
(
output_shape
,
input_shapes
,
alpha
,
beta
,
compute_fp32
);
solution_idx
=
gemm_item
.
validate
(
ctx
,
input_shapes
,
solution_idx
);
}
#else
(
void
)
ctx
,
(
void
)
output_shape
,
(
void
)
input_shapes
;
(
void
)
alpha
,
(
void
)
beta
,
(
void
)
compute_fp32
;
#endif
return
solution_idx
;
}
}
void
gemm
(
context
&
ctx
,
/**
const
shape
&
output_shape
,
* Decides if the tune() or validate() method is appropriate and calls it.
const
std
::
vector
<
argument
>&
args
,
* Return value is the chosen solution index, or 0 to let picker choose it.
int32_t
alpha
,
*/
int32_t
beta
,
int32_t
gemm_finalize
(
context
&
ctx
,
bool
compute_fp32
)
const
shape
&
output_shape
,
const
std
::
vector
<
shape
>&
input_shapes
,
int32_t
alpha
,
int32_t
beta
,
bool
compute_fp32
,
int32_t
solution_idx
)
{
{
gemm_impl
(
ctx
,
output_shape
,
args
,
alpha
,
beta
,
compute_fp32
);
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
if
(
solution_idx
==
0
)
{
auto
gemm_item
=
gemm_impl
<
int32_t
>
(
output_shape
,
input_shapes
,
alpha
,
beta
,
compute_fp32
);
solution_idx
=
gemm_item
.
tune
(
ctx
,
input_shapes
);
}
else
{
// If a tuned solution index is already given, don't tune again but validate
// in case the data was tuned with a different rocBLAS version
auto
gemm_item
=
gemm_impl
<
int32_t
>
(
output_shape
,
input_shapes
,
alpha
,
beta
,
compute_fp32
);
solution_idx
=
gemm_item
.
validate
(
ctx
,
input_shapes
,
solution_idx
);
}
#else
(
void
)
ctx
,
(
void
)
output_shape
,
(
void
)
input_shapes
;
(
void
)
alpha
,
(
void
)
beta
,
(
void
)
compute_fp32
;
#endif
return
solution_idx
;
}
}
}
// namespace gpu
}
// namespace gpu
...
...
src/targets/gpu/include/migraphx/gpu/gemm.hpp
View file @
d7c8b66f
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
@@ -40,9 +40,8 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -40,9 +40,8 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace
gpu
{
namespace
gpu
{
struct
context
;
struct
context
;
void
blas_shape
(
const
shape
&
s
);
shape
transpose_batch
(
const
shape
&
s
,
unsigned
trans_batch
);
shape
transpose_batch
(
const
shape
&
s
,
unsigned
trans_batch
);
void
blas_shape
(
const
shape
&
s
);
template
<
class
Op
>
template
<
class
Op
>
struct
rocblas_gemm
struct
rocblas_gemm
...
@@ -52,6 +51,7 @@ struct rocblas_gemm
...
@@ -52,6 +51,7 @@ struct rocblas_gemm
float
beta
=
0
;
float
beta
=
0
;
bool
compute_fp32
=
false
;
bool
compute_fp32
=
false
;
unsigned
trans_batch
=
0
;
unsigned
trans_batch
=
0
;
int32_t
solution_idx
=
0
;
template
<
class
Self
,
class
F
>
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
static
auto
reflect
(
Self
&
self
,
F
f
)
...
@@ -60,7 +60,8 @@ struct rocblas_gemm
...
@@ -60,7 +60,8 @@ struct rocblas_gemm
pack
(
f
(
self
.
alpha
,
"alpha"
),
pack
(
f
(
self
.
alpha
,
"alpha"
),
f
(
self
.
beta
,
"beta"
),
f
(
self
.
beta
,
"beta"
),
f
(
self
.
compute_fp32
,
"compute_fp32"
),
f
(
self
.
compute_fp32
,
"compute_fp32"
),
f
(
self
.
trans_batch
,
"trans_batch"
)));
f
(
self
.
trans_batch
,
"trans_batch"
),
f
(
self
.
solution_idx
,
"solution_idx"
)));
}
}
std
::
string
name
()
const
std
::
string
name
()
const
...
@@ -76,6 +77,8 @@ struct rocblas_gemm
...
@@ -76,6 +77,8 @@ struct rocblas_gemm
{
{
std
::
vector
<
shape
>
in_shapes
(
inputs
);
std
::
vector
<
shape
>
in_shapes
(
inputs
);
in_shapes
.
pop_back
();
in_shapes
.
pop_back
();
// When input shapes are A, B, C the GEMM equation is C = α AB+ β C where α, β are
// scalars
check_shapes
{
in_shapes
,
*
this
}.
has
(
2
,
3
);
check_shapes
{
in_shapes
,
*
this
}.
has
(
2
,
3
);
blas_shape
(
inputs
[
0
]);
blas_shape
(
inputs
[
0
]);
blas_shape
(
inputs
[
1
]);
blas_shape
(
inputs
[
1
]);
...
@@ -111,11 +114,12 @@ struct rocblas_gemm
...
@@ -111,11 +114,12 @@ struct rocblas_gemm
{
{
if
(
this
->
name
()
==
"gpu::gemm"
)
if
(
this
->
name
()
==
"gpu::gemm"
)
{
{
gemm
(
ctx
,
output_shape
,
args
,
alpha
,
beta
,
compute_fp32
);
gemm
_compute
(
ctx
,
output_shape
,
args
,
alpha
,
beta
,
compute_fp32
,
solution_idx
);
}
}
else
else
{
{
gemm
(
ctx
,
output_shape
,
args
,
int32_t
(
alpha
),
int32_t
(
beta
),
compute_fp32
);
gemm_compute
(
ctx
,
output_shape
,
args
,
int32_t
(
alpha
),
int32_t
(
beta
),
compute_fp32
,
solution_idx
);
}
}
return
args
.
back
();
return
args
.
back
();
}
}
...
@@ -124,6 +128,33 @@ struct rocblas_gemm
...
@@ -124,6 +128,33 @@ struct rocblas_gemm
{
{
return
shapes
.
size
()
-
1
;
return
shapes
.
size
()
-
1
;
}
}
void
finalize
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
shape
>&
input_shapes
)
{
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
if
(
enabled
(
MIGRAPHX_ENABLE_GEMM_TUNING
{})
or
ctx
.
get_exhaustive_tune_flag
())
{
if
(
this
->
name
()
==
"gpu::gemm"
)
{
solution_idx
=
gemm_finalize
(
ctx
,
output_shape
,
input_shapes
,
alpha
,
beta
,
compute_fp32
,
solution_idx
);
}
else
{
solution_idx
=
gemm_finalize
(
ctx
,
output_shape
,
input_shapes
,
int32_t
(
alpha
),
int32_t
(
beta
),
compute_fp32
,
solution_idx
);
}
}
#else
// suppress compiler warnings
(
void
)
ctx
,
(
void
)
output_shape
,
(
void
)
input_shapes
;
#endif
}
};
};
}
// namespace gpu
}
// namespace gpu
...
...
src/targets/gpu/include/migraphx/gpu/gemm_impl.hpp
View file @
d7c8b66f
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
@@ -24,26 +24,64 @@
...
@@ -24,26 +24,64 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_GEMM_IMPL_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_GEMM_IMPL_HPP
#define MIGRAPHX_GUARD_RTGLIB_GEMM_IMPL_HPP
#define MIGRAPHX_GUARD_RTGLIB_GEMM_IMPL_HPP
#include <iterator>
#include <migraphx/shape.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/context.hpp>
// Set this environment variable to "true" to perform GEMM tuning even when the
// --exhaustive-tune option isn't set. Can be used to skip slow convolution tuning.
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_ENABLE_GEMM_TUNING
);
using
milliseconds
=
std
::
chrono
::
duration
<
double
,
std
::
milli
>
;
using
microseconds
=
std
::
chrono
::
duration
<
double
,
std
::
micro
>
;
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
namespace
gpu
{
void
gemm
(
context
&
ctx
,
/**
const
shape
&
output_shape
,
* @brief Templated implementations of the compute() and finalize() methods of the Gemm operator.
const
std
::
vector
<
argument
>&
args
,
* For each function there are overloads using either float or int32_t for the arguments
float
alpha
,
* alpha and beta.
float
beta
,
*
bool
compute_fp32
);
* @param ctx .
void
gemm
(
context
&
ctx
,
* @param output_shape .
const
shape
&
output_shape
,
* @param args .
const
std
::
vector
<
argument
>&
args
,
* @param alpha .
int32_t
alpha
,
* @param beta .
int32_t
beta
,
* @param compute_fp32 .
bool
compute_fp32
);
*/
void
gemm_compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
,
float
alpha
,
float
beta
,
bool
compute_fp32
,
int32_t
solution_idx
);
void
gemm_compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
,
int32_t
alpha
,
int32_t
beta
,
bool
compute_fp32
,
int32_t
solution_idx
);
int32_t
gemm_finalize
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
shape
>&
input_shapes
,
float
alpha
,
float
beta
,
bool
compute_fp32
);
int32_t
gemm_finalize
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
shape
>&
input_shapes
,
int32_t
alpha
,
int32_t
beta
,
bool
compute_fp32
,
int32_t
solution_idx
);
}
// namespace gpu
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/targets/gpu/include/migraphx/gpu/rocblas.hpp
View file @
d7c8b66f
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
...
test/gpu/gemm_tune.cpp
0 → 100644
View file @
d7c8b66f
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <iostream>
#include <vector>
#include <migraphx/gpu/gemm.hpp>
#include <hip/hip_runtime_api.h>
#include <migraphx/gpu/target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/iterator_for.hpp>
// includes needed for run_lowering
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/auto_contiguous.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/pass_manager.hpp>
// Abbreviated lowering; we don't need the usual cleanup passes for this test
void
run_lowering
(
migraphx
::
program
&
p
,
bool
offload_copy
=
false
)
{
auto
ctx
=
migraphx
::
gpu
::
context
{};
migraphx
::
run_passes
(
*
p
.
get_main_module
(),
{
migraphx
::
auto_contiguous
{},
migraphx
::
gpu
::
lowering
{
&
ctx
,
offload_copy
}});
}
/**
* Tests the automatic GEMM tuning feature. In the finalize() method of the gemm op,
* rocBLAS API functions are called to quickly benchmark all the GEMM solutions
* available in the currently installed rocBLAS library and choose the index of the fastest.
*/
TEST_CASE
(
gemm_tune_with_rocblas
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
sa
{
migraphx
::
shape
::
float_type
,
{
4
,
2
}};
migraphx
::
shape
sb
{
migraphx
::
shape
::
float_type
,
{
2
,
3
}};
auto
a
=
mm
->
add_parameter
(
"a"
,
sa
);
auto
b
=
mm
->
add_parameter
(
"b"
,
sb
);
migraphx
::
operation
dot_op
=
migraphx
::
make_op
(
"dot"
);
mm
->
add_instruction
(
dot_op
,
a
,
b
);
// lowering adds gemm implementation for dot operator
run_lowering
(
p
);
migraphx
::
target
gpu_t
=
migraphx
::
gpu
::
target
{};
migraphx
::
compile_options
options
;
options
.
exhaustive_tune
=
true
;
p
.
compile
(
gpu_t
,
options
);
migraphx
::
value
solution_idx
(
0
);
for
(
auto
ins
:
iterator_for
(
*
p
.
get_main_module
()))
{
if
(
ins
->
name
()
==
"gpu::gemm"
)
{
auto
gemm_op
=
migraphx
::
get_operation
(
ins
);
// tuned solution index is not deterministic, but anything other than 0
// (default, invalid, or not available) is good.
// gemm_op.to_value().debug_print();
solution_idx
=
gemm_op
.
to_value
()[
"solution_idx"
];
break
;
}
}
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
EXPECT
(
0
!=
solution_idx
.
to
<
std
::
size_t
>
());
#else
EXPECT
(
0
==
solution_idx
.
to
<
std
::
size_t
>
());
#endif
}
// GEMM tuning of a strided-batch matrix; invokes rocblas_gemm_strided_batched_ex
TEST_CASE
(
gemm_tune_strided
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
sa
{
migraphx
::
shape
::
float_type
,
{
4
,
2
,
2
}};
migraphx
::
shape
sb
{
migraphx
::
shape
::
float_type
,
{
4
,
2
,
2
}};
migraphx
::
shape
s_output
{
migraphx
::
shape
::
float_type
,
{
4
,
2
,
2
}};
auto
a
=
mm
->
add_parameter
(
"a"
,
sa
);
auto
b
=
mm
->
add_parameter
(
"b"
,
sb
);
auto
output
=
mm
->
add_parameter
(
"out"
,
s_output
);
auto
gemm_oper
=
migraphx
::
make_op
(
"gpu::gemm"
,
{{
"beta"
,
2
}});
mm
->
add_instruction
(
gemm_oper
,
a
,
b
,
output
);
migraphx
::
target
gpu_t
=
migraphx
::
gpu
::
target
{};
migraphx
::
compile_options
options
;
options
.
exhaustive_tune
=
true
;
p
.
compile
(
gpu_t
,
options
);
migraphx
::
value
solution_idx
(
0
);
for
(
auto
ins
:
iterator_for
(
*
p
.
get_main_module
()))
{
if
(
ins
->
name
()
==
"gpu::gemm"
)
{
auto
gemm_op
=
migraphx
::
get_operation
(
ins
);
auto
gemmv
=
gemm_op
.
to_value
();
// tuned solution index is not deterministic, but anything other than 0
// (default, invalid, or not available) is good.
solution_idx
=
gemm_op
.
to_value
()[
"solution_idx"
];
break
;
}
}
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
EXPECT
(
0
!=
solution_idx
.
to
<
std
::
size_t
>
());
#else
EXPECT
(
0
==
solution_idx
.
to
<
std
::
size_t
>
());
#endif
}
// GEMM tuning of a strided-batch matrix; created by lowering
TEST_CASE
(
gemm_tune_strided_lowered
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
// At time of writing this test, gemm_impl considers a shape is strided if it has
// at least three dimensions and the 3rd-to-last is nonzero, invoking
// rocblas_gemm_strided_batched_ex. Also, DOT operator requires all dimensions except the last
// two to be equal.
migraphx
::
shape
sa
{
migraphx
::
shape
::
float_type
,
{
4
,
2
,
5
}};
migraphx
::
shape
sb
{
migraphx
::
shape
::
float_type
,
{
4
,
5
,
3
}};
auto
a
=
mm
->
add_parameter
(
"a"
,
sa
);
auto
b
=
mm
->
add_parameter
(
"b"
,
sb
);
migraphx
::
operation
dot_op
=
migraphx
::
make_op
(
"dot"
);
mm
->
add_instruction
(
dot_op
,
a
,
b
);
// lowering adds gemm implementation for dot operator
run_lowering
(
p
);
migraphx
::
target
gpu_t
=
migraphx
::
gpu
::
target
{};
migraphx
::
compile_options
options
;
options
.
exhaustive_tune
=
true
;
p
.
compile
(
gpu_t
,
options
);
migraphx
::
value
solution_idx
(
0
);
for
(
auto
ins
:
iterator_for
(
*
p
.
get_main_module
()))
{
if
(
ins
->
name
()
==
"gpu::gemm"
)
{
auto
gemm_op
=
migraphx
::
get_operation
(
ins
);
// tuned solution index is not deterministic, but anything other than 0
// (default, invalid, or not available) is good.
solution_idx
=
gemm_op
.
to_value
()[
"solution_idx"
];
break
;
}
}
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
EXPECT
(
0
!=
solution_idx
.
to
<
std
::
size_t
>
());
#else
EXPECT
(
0
==
solution_idx
.
to
<
std
::
size_t
>
());
#endif
}
TEST_CASE
(
gemm_tune_invalid_sol_index
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
sa
{
migraphx
::
shape
::
float_type
,
{
4
,
2
}};
migraphx
::
shape
sb
{
migraphx
::
shape
::
float_type
,
{
2
,
3
}};
migraphx
::
shape
s_output
{
migraphx
::
shape
::
float_type
,
{
4
,
3
}};
auto
a
=
mm
->
add_parameter
(
"a"
,
sa
);
auto
b
=
mm
->
add_parameter
(
"b"
,
sb
);
auto
output
=
mm
->
add_parameter
(
"out"
,
s_output
);
auto
gemm_oper
=
migraphx
::
make_op
(
"gpu::gemm"
,
{{
"solution_idx"
,
987654321
}});
mm
->
add_instruction
(
gemm_oper
,
a
,
b
,
output
);
migraphx
::
target
gpu_t
=
migraphx
::
gpu
::
target
{};
migraphx
::
compile_options
options
;
options
.
exhaustive_tune
=
true
;
p
.
compile
(
gpu_t
,
options
);
migraphx
::
value
solution_idx
(
0
);
for
(
auto
ins
:
iterator_for
(
*
p
.
get_main_module
()))
{
if
(
ins
->
name
()
==
"gpu::gemm"
)
{
auto
gemm_op
=
migraphx
::
get_operation
(
ins
);
auto
gemmv
=
gemm_op
.
to_value
();
// given invalid starting index, should return default 0
solution_idx
=
gemm_op
.
to_value
()[
"solution_idx"
];
break
;
}
}
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
EXPECT
(
0
==
solution_idx
.
to
<
std
::
size_t
>
());
#else
EXPECT
(
0
!=
solution_idx
.
to
<
std
::
size_t
>
());
#endif
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/verify/gemm_2args_mm_8.cpp
0 → 100644
View file @
d7c8b66f
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
gemm_2args_mm_8
:
verify_program
<
gemm_2args_mm_8
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
128
,
32
},
{
4096
,
1
,
128
}};
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_type
,
{
32
,
32
}};
auto
a
=
mm
->
add_parameter
(
"a"
,
a_shape
);
auto
b
=
mm
->
add_parameter
(
"b"
,
b_shape
);
auto
bb
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
2
,
32
,
32
}}}),
b
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
a
,
bb
);
return
p
;
}
};
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment