Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
d7c8b66f
Unverified
Commit
d7c8b66f
authored
Nov 07, 2023
by
Brian Pickrell
Committed by
GitHub
Nov 07, 2023
Browse files
Blas auto-tuning for GEMMs (#1668)
parent
4bd3f4e3
Changes
7
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
847 additions
and
166 deletions
+847
-166
src/targets/gpu/CMakeLists.txt
src/targets/gpu/CMakeLists.txt
+12
-1
src/targets/gpu/gemm_impl.cpp
src/targets/gpu/gemm_impl.cpp
+474
-145
src/targets/gpu/include/migraphx/gpu/gemm.hpp
src/targets/gpu/include/migraphx/gpu/gemm.hpp
+37
-6
src/targets/gpu/include/migraphx/gpu/gemm_impl.hpp
src/targets/gpu/include/migraphx/gpu/gemm_impl.hpp
+51
-13
src/targets/gpu/include/migraphx/gpu/rocblas.hpp
src/targets/gpu/include/migraphx/gpu/rocblas.hpp
+1
-1
test/gpu/gemm_tune.cpp
test/gpu/gemm_tune.cpp
+225
-0
test/verify/gemm_2args_mm_8.cpp
test/verify/gemm_2args_mm_8.cpp
+47
-0
No files found.
src/targets/gpu/CMakeLists.txt
View file @
d7c8b66f
# ####################################################################################
# ####################################################################################
# The MIT License (MIT)
# The MIT License (MIT)
#
#
# Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
#
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# of this software and associated documentation files (the "Software"), to deal
...
@@ -245,10 +245,14 @@ else()
...
@@ -245,10 +245,14 @@ else()
endif
()
endif
()
# Check miopen find mode api
# Check miopen find mode api
include
(
CheckLibraryExists
)
include
(
CheckLibraryExists
)
get_target_property
(
MIOPEN_LOCATION MIOpen LOCATION
)
get_target_property
(
MIOPEN_LOCATION MIOpen LOCATION
)
get_target_property
(
ROCBLAS_LOCATION roc::rocblas LOCATION
)
check_library_exists
(
MIOpen
"miopenHiddenSetConvolutionFindMode"
"
${
MIOPEN_LOCATION
}
"
HAS_FIND_MODE_API
)
check_library_exists
(
MIOpen
"miopenHiddenSetConvolutionFindMode"
"
${
MIOPEN_LOCATION
}
"
HAS_FIND_MODE_API
)
check_library_exists
(
MIOpen
"miopenFindSolutions"
"
${
MIOPEN_LOCATION
}
"
HAS_FIND_2_API
)
check_library_exists
(
MIOpen
"miopenFindSolutions"
"
${
MIOPEN_LOCATION
}
"
HAS_FIND_2_API
)
# Beta API for automated GEMM tuning
check_library_exists
(
roc::rocblas
"rocblas_gemm_ex_get_solutions"
"
${
ROCBLAS_LOCATION
}
"
HAS_ROCBLAS_TUNING_BETA_FEATURE_API
)
set
(
MIGRAPHX_USE_FIND_2_API
"
${
HAS_FIND_2_API
}
"
CACHE BOOL
""
)
set
(
MIGRAPHX_USE_FIND_2_API
"
${
HAS_FIND_2_API
}
"
CACHE BOOL
""
)
...
@@ -271,6 +275,13 @@ else()
...
@@ -271,6 +275,13 @@ else()
message
(
STATUS
"MIOpen does not have find mode api"
)
message
(
STATUS
"MIOpen does not have find mode api"
)
endif
()
endif
()
if
(
HAS_ROCBLAS_TUNING_BETA_FEATURE_API
)
target_compile_definitions
(
migraphx_gpu PUBLIC -DMIGRAPHX_USE_ROCBLAS_TUNING_API -DROCBLAS_BETA_FEATURES_API -DROCBLAS_NO_DEPRECATED_WARNINGS
)
message
(
STATUS
"MIGraphx is using Beta API of rocBLAS"
)
else
()
message
(
STATUS
"rocBLAS does not have User Tuning Beta API"
)
endif
()
target_link_libraries
(
migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas
)
target_link_libraries
(
migraphx_gpu PUBLIC migraphx MIOpen roc::rocblas
)
target_link_libraries
(
migraphx_gpu PRIVATE migraphx_device migraphx_kernels
)
target_link_libraries
(
migraphx_gpu PRIVATE migraphx_device migraphx_kernels
)
if
(
MIGRAPHX_USE_COMPOSABLEKERNEL
)
if
(
MIGRAPHX_USE_COMPOSABLEKERNEL
)
...
...
src/targets/gpu/gemm_impl.cpp
View file @
d7c8b66f
This diff is collapsed.
Click to expand it.
src/targets/gpu/include/migraphx/gpu/gemm.hpp
View file @
d7c8b66f
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
@@ -40,9 +40,8 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -40,9 +40,8 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace
gpu
{
namespace
gpu
{
struct
context
;
struct
context
;
void
blas_shape
(
const
shape
&
s
);
shape
transpose_batch
(
const
shape
&
s
,
unsigned
trans_batch
);
shape
transpose_batch
(
const
shape
&
s
,
unsigned
trans_batch
);
void
blas_shape
(
const
shape
&
s
);
template
<
class
Op
>
template
<
class
Op
>
struct
rocblas_gemm
struct
rocblas_gemm
...
@@ -52,6 +51,7 @@ struct rocblas_gemm
...
@@ -52,6 +51,7 @@ struct rocblas_gemm
float
beta
=
0
;
float
beta
=
0
;
bool
compute_fp32
=
false
;
bool
compute_fp32
=
false
;
unsigned
trans_batch
=
0
;
unsigned
trans_batch
=
0
;
int32_t
solution_idx
=
0
;
template
<
class
Self
,
class
F
>
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
static
auto
reflect
(
Self
&
self
,
F
f
)
...
@@ -60,7 +60,8 @@ struct rocblas_gemm
...
@@ -60,7 +60,8 @@ struct rocblas_gemm
pack
(
f
(
self
.
alpha
,
"alpha"
),
pack
(
f
(
self
.
alpha
,
"alpha"
),
f
(
self
.
beta
,
"beta"
),
f
(
self
.
beta
,
"beta"
),
f
(
self
.
compute_fp32
,
"compute_fp32"
),
f
(
self
.
compute_fp32
,
"compute_fp32"
),
f
(
self
.
trans_batch
,
"trans_batch"
)));
f
(
self
.
trans_batch
,
"trans_batch"
),
f
(
self
.
solution_idx
,
"solution_idx"
)));
}
}
std
::
string
name
()
const
std
::
string
name
()
const
...
@@ -76,6 +77,8 @@ struct rocblas_gemm
...
@@ -76,6 +77,8 @@ struct rocblas_gemm
{
{
std
::
vector
<
shape
>
in_shapes
(
inputs
);
std
::
vector
<
shape
>
in_shapes
(
inputs
);
in_shapes
.
pop_back
();
in_shapes
.
pop_back
();
// When input shapes are A, B, C the GEMM equation is C = α AB+ β C where α, β are
// scalars
check_shapes
{
in_shapes
,
*
this
}.
has
(
2
,
3
);
check_shapes
{
in_shapes
,
*
this
}.
has
(
2
,
3
);
blas_shape
(
inputs
[
0
]);
blas_shape
(
inputs
[
0
]);
blas_shape
(
inputs
[
1
]);
blas_shape
(
inputs
[
1
]);
...
@@ -111,11 +114,12 @@ struct rocblas_gemm
...
@@ -111,11 +114,12 @@ struct rocblas_gemm
{
{
if
(
this
->
name
()
==
"gpu::gemm"
)
if
(
this
->
name
()
==
"gpu::gemm"
)
{
{
gemm
(
ctx
,
output_shape
,
args
,
alpha
,
beta
,
compute_fp32
);
gemm
_compute
(
ctx
,
output_shape
,
args
,
alpha
,
beta
,
compute_fp32
,
solution_idx
);
}
}
else
else
{
{
gemm
(
ctx
,
output_shape
,
args
,
int32_t
(
alpha
),
int32_t
(
beta
),
compute_fp32
);
gemm_compute
(
ctx
,
output_shape
,
args
,
int32_t
(
alpha
),
int32_t
(
beta
),
compute_fp32
,
solution_idx
);
}
}
return
args
.
back
();
return
args
.
back
();
}
}
...
@@ -124,6 +128,33 @@ struct rocblas_gemm
...
@@ -124,6 +128,33 @@ struct rocblas_gemm
{
{
return
shapes
.
size
()
-
1
;
return
shapes
.
size
()
-
1
;
}
}
void
finalize
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
shape
>&
input_shapes
)
{
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
if
(
enabled
(
MIGRAPHX_ENABLE_GEMM_TUNING
{})
or
ctx
.
get_exhaustive_tune_flag
())
{
if
(
this
->
name
()
==
"gpu::gemm"
)
{
solution_idx
=
gemm_finalize
(
ctx
,
output_shape
,
input_shapes
,
alpha
,
beta
,
compute_fp32
,
solution_idx
);
}
else
{
solution_idx
=
gemm_finalize
(
ctx
,
output_shape
,
input_shapes
,
int32_t
(
alpha
),
int32_t
(
beta
),
compute_fp32
,
solution_idx
);
}
}
#else
// suppress compiler warnings
(
void
)
ctx
,
(
void
)
output_shape
,
(
void
)
input_shapes
;
#endif
}
};
};
}
// namespace gpu
}
// namespace gpu
...
...
src/targets/gpu/include/migraphx/gpu/gemm_impl.hpp
View file @
d7c8b66f
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
@@ -24,26 +24,64 @@
...
@@ -24,26 +24,64 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_GEMM_IMPL_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_GEMM_IMPL_HPP
#define MIGRAPHX_GUARD_RTGLIB_GEMM_IMPL_HPP
#define MIGRAPHX_GUARD_RTGLIB_GEMM_IMPL_HPP
#include <iterator>
#include <migraphx/shape.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/context.hpp>
// Set this environment variable to "true" to perform GEMM tuning even when the
// --exhaustive-tune option isn't set. Can be used to skip slow convolution tuning.
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_ENABLE_GEMM_TUNING
);
using
milliseconds
=
std
::
chrono
::
duration
<
double
,
std
::
milli
>
;
using
microseconds
=
std
::
chrono
::
duration
<
double
,
std
::
micro
>
;
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
namespace
gpu
{
void
gemm
(
context
&
ctx
,
/**
const
shape
&
output_shape
,
* @brief Templated implementations of the compute() and finalize() methods of the Gemm operator.
const
std
::
vector
<
argument
>&
args
,
* For each function there are overloads using either float or int32_t for the arguments
float
alpha
,
* alpha and beta.
float
beta
,
*
bool
compute_fp32
);
* @param ctx .
void
gemm
(
context
&
ctx
,
* @param output_shape .
const
shape
&
output_shape
,
* @param args .
const
std
::
vector
<
argument
>&
args
,
* @param alpha .
int32_t
alpha
,
* @param beta .
int32_t
beta
,
* @param compute_fp32 .
bool
compute_fp32
);
*/
void
gemm_compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
,
float
alpha
,
float
beta
,
bool
compute_fp32
,
int32_t
solution_idx
);
void
gemm_compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
,
int32_t
alpha
,
int32_t
beta
,
bool
compute_fp32
,
int32_t
solution_idx
);
int32_t
gemm_finalize
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
shape
>&
input_shapes
,
float
alpha
,
float
beta
,
bool
compute_fp32
);
int32_t
gemm_finalize
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
shape
>&
input_shapes
,
int32_t
alpha
,
int32_t
beta
,
bool
compute_fp32
,
int32_t
solution_idx
);
}
// namespace gpu
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/targets/gpu/include/migraphx/gpu/rocblas.hpp
View file @
d7c8b66f
/*
/*
* The MIT License (MIT)
* The MIT License (MIT)
*
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* of this software and associated documentation files (the "Software"), to deal
...
...
test/gpu/gemm_tune.cpp
0 → 100644
View file @
d7c8b66f
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <iostream>
#include <vector>
#include <migraphx/gpu/gemm.hpp>
#include <hip/hip_runtime_api.h>
#include <migraphx/gpu/target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/iterator_for.hpp>
// includes needed for run_lowering
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/auto_contiguous.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/pass_manager.hpp>
// Abbreviated lowering; we don't need the usual cleanup passes for this test
void
run_lowering
(
migraphx
::
program
&
p
,
bool
offload_copy
=
false
)
{
auto
ctx
=
migraphx
::
gpu
::
context
{};
migraphx
::
run_passes
(
*
p
.
get_main_module
(),
{
migraphx
::
auto_contiguous
{},
migraphx
::
gpu
::
lowering
{
&
ctx
,
offload_copy
}});
}
/**
* Tests the automatic GEMM tuning feature. In the finalize() method of the gemm op,
* rocBLAS API functions are called to quickly benchmark all the GEMM solutions
* available in the currently installed rocBLAS library and choose the index of the fastest.
*/
TEST_CASE
(
gemm_tune_with_rocblas
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
sa
{
migraphx
::
shape
::
float_type
,
{
4
,
2
}};
migraphx
::
shape
sb
{
migraphx
::
shape
::
float_type
,
{
2
,
3
}};
auto
a
=
mm
->
add_parameter
(
"a"
,
sa
);
auto
b
=
mm
->
add_parameter
(
"b"
,
sb
);
migraphx
::
operation
dot_op
=
migraphx
::
make_op
(
"dot"
);
mm
->
add_instruction
(
dot_op
,
a
,
b
);
// lowering adds gemm implementation for dot operator
run_lowering
(
p
);
migraphx
::
target
gpu_t
=
migraphx
::
gpu
::
target
{};
migraphx
::
compile_options
options
;
options
.
exhaustive_tune
=
true
;
p
.
compile
(
gpu_t
,
options
);
migraphx
::
value
solution_idx
(
0
);
for
(
auto
ins
:
iterator_for
(
*
p
.
get_main_module
()))
{
if
(
ins
->
name
()
==
"gpu::gemm"
)
{
auto
gemm_op
=
migraphx
::
get_operation
(
ins
);
// tuned solution index is not deterministic, but anything other than 0
// (default, invalid, or not available) is good.
// gemm_op.to_value().debug_print();
solution_idx
=
gemm_op
.
to_value
()[
"solution_idx"
];
break
;
}
}
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
EXPECT
(
0
!=
solution_idx
.
to
<
std
::
size_t
>
());
#else
EXPECT
(
0
==
solution_idx
.
to
<
std
::
size_t
>
());
#endif
}
// GEMM tuning of a strided-batch matrix; invokes rocblas_gemm_strided_batched_ex
TEST_CASE
(
gemm_tune_strided
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
sa
{
migraphx
::
shape
::
float_type
,
{
4
,
2
,
2
}};
migraphx
::
shape
sb
{
migraphx
::
shape
::
float_type
,
{
4
,
2
,
2
}};
migraphx
::
shape
s_output
{
migraphx
::
shape
::
float_type
,
{
4
,
2
,
2
}};
auto
a
=
mm
->
add_parameter
(
"a"
,
sa
);
auto
b
=
mm
->
add_parameter
(
"b"
,
sb
);
auto
output
=
mm
->
add_parameter
(
"out"
,
s_output
);
auto
gemm_oper
=
migraphx
::
make_op
(
"gpu::gemm"
,
{{
"beta"
,
2
}});
mm
->
add_instruction
(
gemm_oper
,
a
,
b
,
output
);
migraphx
::
target
gpu_t
=
migraphx
::
gpu
::
target
{};
migraphx
::
compile_options
options
;
options
.
exhaustive_tune
=
true
;
p
.
compile
(
gpu_t
,
options
);
migraphx
::
value
solution_idx
(
0
);
for
(
auto
ins
:
iterator_for
(
*
p
.
get_main_module
()))
{
if
(
ins
->
name
()
==
"gpu::gemm"
)
{
auto
gemm_op
=
migraphx
::
get_operation
(
ins
);
auto
gemmv
=
gemm_op
.
to_value
();
// tuned solution index is not deterministic, but anything other than 0
// (default, invalid, or not available) is good.
solution_idx
=
gemm_op
.
to_value
()[
"solution_idx"
];
break
;
}
}
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
EXPECT
(
0
!=
solution_idx
.
to
<
std
::
size_t
>
());
#else
EXPECT
(
0
==
solution_idx
.
to
<
std
::
size_t
>
());
#endif
}
// GEMM tuning of a strided-batch matrix; created by lowering
TEST_CASE
(
gemm_tune_strided_lowered
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
// At time of writing this test, gemm_impl considers a shape is strided if it has
// at least three dimensions and the 3rd-to-last is nonzero, invoking
// rocblas_gemm_strided_batched_ex. Also, DOT operator requires all dimensions except the last
// two to be equal.
migraphx
::
shape
sa
{
migraphx
::
shape
::
float_type
,
{
4
,
2
,
5
}};
migraphx
::
shape
sb
{
migraphx
::
shape
::
float_type
,
{
4
,
5
,
3
}};
auto
a
=
mm
->
add_parameter
(
"a"
,
sa
);
auto
b
=
mm
->
add_parameter
(
"b"
,
sb
);
migraphx
::
operation
dot_op
=
migraphx
::
make_op
(
"dot"
);
mm
->
add_instruction
(
dot_op
,
a
,
b
);
// lowering adds gemm implementation for dot operator
run_lowering
(
p
);
migraphx
::
target
gpu_t
=
migraphx
::
gpu
::
target
{};
migraphx
::
compile_options
options
;
options
.
exhaustive_tune
=
true
;
p
.
compile
(
gpu_t
,
options
);
migraphx
::
value
solution_idx
(
0
);
for
(
auto
ins
:
iterator_for
(
*
p
.
get_main_module
()))
{
if
(
ins
->
name
()
==
"gpu::gemm"
)
{
auto
gemm_op
=
migraphx
::
get_operation
(
ins
);
// tuned solution index is not deterministic, but anything other than 0
// (default, invalid, or not available) is good.
solution_idx
=
gemm_op
.
to_value
()[
"solution_idx"
];
break
;
}
}
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
EXPECT
(
0
!=
solution_idx
.
to
<
std
::
size_t
>
());
#else
EXPECT
(
0
==
solution_idx
.
to
<
std
::
size_t
>
());
#endif
}
TEST_CASE
(
gemm_tune_invalid_sol_index
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
sa
{
migraphx
::
shape
::
float_type
,
{
4
,
2
}};
migraphx
::
shape
sb
{
migraphx
::
shape
::
float_type
,
{
2
,
3
}};
migraphx
::
shape
s_output
{
migraphx
::
shape
::
float_type
,
{
4
,
3
}};
auto
a
=
mm
->
add_parameter
(
"a"
,
sa
);
auto
b
=
mm
->
add_parameter
(
"b"
,
sb
);
auto
output
=
mm
->
add_parameter
(
"out"
,
s_output
);
auto
gemm_oper
=
migraphx
::
make_op
(
"gpu::gemm"
,
{{
"solution_idx"
,
987654321
}});
mm
->
add_instruction
(
gemm_oper
,
a
,
b
,
output
);
migraphx
::
target
gpu_t
=
migraphx
::
gpu
::
target
{};
migraphx
::
compile_options
options
;
options
.
exhaustive_tune
=
true
;
p
.
compile
(
gpu_t
,
options
);
migraphx
::
value
solution_idx
(
0
);
for
(
auto
ins
:
iterator_for
(
*
p
.
get_main_module
()))
{
if
(
ins
->
name
()
==
"gpu::gemm"
)
{
auto
gemm_op
=
migraphx
::
get_operation
(
ins
);
auto
gemmv
=
gemm_op
.
to_value
();
// given invalid starting index, should return default 0
solution_idx
=
gemm_op
.
to_value
()[
"solution_idx"
];
break
;
}
}
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
EXPECT
(
0
==
solution_idx
.
to
<
std
::
size_t
>
());
#else
EXPECT
(
0
!=
solution_idx
.
to
<
std
::
size_t
>
());
#endif
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/verify/gemm_2args_mm_8.cpp
0 → 100644
View file @
d7c8b66f
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
struct
gemm_2args_mm_8
:
verify_program
<
gemm_2args_mm_8
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
a_shape
{
migraphx
::
shape
::
float_type
,
{
2
,
128
,
32
},
{
4096
,
1
,
128
}};
migraphx
::
shape
b_shape
{
migraphx
::
shape
::
float_type
,
{
32
,
32
}};
auto
a
=
mm
->
add_parameter
(
"a"
,
a_shape
);
auto
b
=
mm
->
add_parameter
(
"b"
,
b_shape
);
auto
bb
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
{
2
,
32
,
32
}}}),
b
);
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
a
,
bb
);
return
p
;
}
};
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment