Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
2c8ba41a
Commit
2c8ba41a
authored
Nov 09, 2023
by
charlie
Browse files
Merge branch 'develop' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into simplify_dyn_reshape
parents
5d998fb2
6dc96db7
Changes
47
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
996 additions
and
224 deletions
+996
-224
src/targets/gpu/gemm_impl.cpp
src/targets/gpu/gemm_impl.cpp
+474
-145
src/targets/gpu/include/migraphx/gpu/gemm.hpp
src/targets/gpu/include/migraphx/gpu/gemm.hpp
+37
-6
src/targets/gpu/include/migraphx/gpu/gemm_impl.hpp
src/targets/gpu/include/migraphx/gpu/gemm_impl.hpp
+51
-13
src/targets/gpu/include/migraphx/gpu/rocblas.hpp
src/targets/gpu/include/migraphx/gpu/rocblas.hpp
+1
-1
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
+3
-0
test/gpu/codegen_literal.cpp
test/gpu/codegen_literal.cpp
+1
-1
test/gpu/gemm_tune.cpp
test/gpu/gemm_tune.cpp
+225
-0
test/onnx/.onnxrt-commit
test/onnx/.onnxrt-commit
+1
-1
test/onnx/gen_onnx.py
test/onnx/gen_onnx.py
+53
-3
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+50
-11
test/onnx/reshape_variable_input_test0.onnx
test/onnx/reshape_variable_input_test0.onnx
+17
-0
test/onnx/reshape_variable_input_test1.onnx
test/onnx/reshape_variable_input_test1.onnx
+0
-0
test/onnx/round_half_test.onnx
test/onnx/round_half_test.onnx
+13
-0
test/onnx/slice_var_input_default_steps.onnx
test/onnx/slice_var_input_default_steps.onnx
+0
-0
test/onnx/upsample_ver7_test.onnx
test/onnx/upsample_ver7_test.onnx
+0
-0
test/onnx/verify_onnx.cpp
test/onnx/verify_onnx.cpp
+37
-0
test/py/CMakeLists.txt
test/py/CMakeLists.txt
+13
-5
test/py/onnx_backend_test.py
test/py/onnx_backend_test.py
+0
-1
test/py/requirements.txt
test/py/requirements.txt
+0
-25
test/py/test_gpu.py
test/py/test_gpu.py
+20
-12
No files found.
src/targets/gpu/gemm_impl.cpp
View file @
2c8ba41a
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -21,15 +21,20 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <rocblas/rocblas.h>
#include <migraphx/gpu/gemm_impl.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/time.hpp>
using
microseconds
=
std
::
chrono
::
duration
<
double
,
std
::
micro
>
;
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
// Convert rocBLAS datatypes to equivalent Migraphx data types
rocblas_datatype
get_type
(
shape
::
type_t
type
)
{
switch
(
type
)
...
...
@@ -81,184 +86,508 @@ shape transpose_batch(const shape& s, unsigned trans_batch)
return
shape
::
from_permutation
(
s
.
type
(),
s
.
lens
(),
perm
);
}
template
<
class
R
,
class
...
Ts
,
class
...
Us
>
R
rocblas_invoke
(
R
(
*
f
)(
Ts
...),
Us
...
xs
)
/**
* Returns results of rocblas_status_success, rocblas_status_perf_degraded,
* or rocblas_status_invalid_value. Caller
* is expected to check for invalid index. Any other result causes an exception.
*
*/
template
<
class
F
,
class
Pack
,
class
...
Ts
>
auto
rocblas_invoke
(
F
f
,
Pack
p
,
Ts
...
xs
)
{
if
constexpr
(
sizeof
...(
Ts
)
==
sizeof
...(
Us
))
return
f
(
xs
...);
else
return
f
(
xs
...,
nullptr
,
nullptr
);
return
p
([
=
](
auto
...
ws
)
{
auto
status
=
f
(
ws
...,
xs
...);
if
(
status
!=
rocblas_status_success
and
status
!=
rocblas_status_invalid_value
)
{
if
(
status
==
rocblas_status_perf_degraded
)
{
std
::
cerr
<<
"WARNING: degraded perf. in rocBLAS call"
<<
std
::
endl
;
}
else
MIGRAPHX_THROW
(
"rocblas_invoke: rocBLAS call failed with status "
+
std
::
to_string
(
status
));
}
return
status
;
});
}
static
bool
is_transposed
(
const
shape
&
s
)
{
if
(
not
s
.
transposed
())
return
false
;
return
s
.
strides
().
back
()
!=
1
;
}
static
bool
is_transposed
(
const
shape
&
s
)
{
return
s
.
transposed
()
and
s
.
strides
().
back
()
!=
1
;
}
static
rocblas_int
get_batch_stride
(
const
argument
&
a
)
static
rocblas_int
get_batch_stride
(
const
shape
&
s
)
{
return
a
.
get_shape
().
strides
()[
a
.
get_shape
().
strides
().
size
()
-
3
];
// This value is not needed for non-strided inputs
if
(
s
.
strides
().
size
()
<
3
)
return
0
;
else
return
s
.
strides
()[
s
.
strides
().
size
()
-
3
];
}
template
<
class
T
>
void
gemm_impl
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
,
T
alpha
,
T
beta
,
bool
compute_fp32
)
/**
* Wrapper for multiple rocBLAS calls. The constructor creates parameters for
* these calls based on data shapes and other values contained in the associated
* instruction and operation.
*
* The template parameter T is not the type of the matrix data but of the weighting
* coefficients alpha and beta (these are float in rocBLAS internals)
*/
template
<
typename
T
>
struct
gemm_impl
{
const
bool
is_3inputs
=
(
args
.
size
()
==
4
);
if
(
not
is_3inputs
)
{
beta
=
0
;
}
bool
transa
=
is_transposed
(
args
[
0
].
get_shape
());
bool
transb
=
is_transposed
(
args
[
1
].
get_shape
());
auto
n_dim
=
output_shape
.
lens
().
size
();
auto
dim_1
=
n_dim
-
1
;
auto
dim_0
=
n_dim
-
2
;
rocblas_int
lda
=
args
[
0
].
get_shape
().
strides
()[
transa
?
dim_1
:
dim_0
];
rocblas_int
ldb
=
args
[
1
].
get_shape
().
strides
()[
transb
?
dim_1
:
dim_0
];
rocblas_int
ldc
=
args
[
2
].
get_shape
().
strides
()[
dim_0
];
rocblas_int
ldd
=
is_3inputs
?
args
[
3
].
get_shape
().
strides
()[
dim_0
]
:
ldc
;
rocblas_datatype
arg_type
=
get_type
(
args
[
0
].
get_shape
().
type
());
auto
output_type
=
arg_type
;
if
(
output_type
==
rocblas_datatype_i8_r
)
{
output_type
=
rocblas_datatype_i32_r
;
}
auto
compute_type
=
output_type
;
if
(
compute_fp32
)
gemm_impl
(
const
shape
&
output_shape
,
const
std
::
vector
<
shape
>&
input_shapes
,
T
alpha_param
,
T
beta_param
,
bool
compute_fp32_flag
)
:
alpha
(
alpha_param
),
beta
(
beta_param
),
is_3inputs
(
input_shapes
.
size
()
==
4
),
compute_fp32
(
compute_fp32_flag
)
{
if
(
arg_type
==
rocblas_datatype_f16_r
)
compute_type
=
rocblas_datatype_f32_r
;
}
if
(
not
is_3inputs
)
{
beta
=
0
;
}
rocblas_gemm_flags
flag
=
rocblas_gemm_flags_none
;
auto
a_lens
=
args
[
0
].
get_shape
().
lens
();
auto
b_lens
=
args
[
1
].
get_shape
().
lens
();
output_shape
.
visit_type
([
&
](
auto
as
)
{
auto
alpha_r
=
as
(
alpha
);
auto
beta_r
=
as
(
beta
);
// Create lambdas that will cast alpha, beta to the output shape's type
// and retain the values being pointed to
output_shape
.
visit_type
([
&
](
auto
as
)
{
auto
alpha_r
=
as
(
alpha
);
auto
beta_r
=
as
(
beta
);
if
(
compute_fp32
)
{
get_alpha
=
[
=
]
{
return
&
alpha
;
};
get_beta
=
[
=
]
{
return
&
beta
;
};
}
else
{
get_alpha
=
[
=
]
{
return
&
alpha_r
;
};
get_beta
=
[
=
]
{
return
&
beta_r
;
};
}
});
// use void pointer to select different data type if using fp32 mode
void
*
alpha_v
=
&
alpha_r
;
void
*
beta_v
=
&
beta_r
;
transa
=
is_transposed
(
input_shapes
[
0
]);
transb
=
is_transposed
(
input_shapes
[
1
]);
auto
n_dim
=
output_shape
.
lens
().
size
();
auto
dim_0
=
n_dim
-
2
;
auto
dim_1
=
n_dim
-
1
;
// Leading dimensions of matrices
lda
=
input_shapes
[
0
].
strides
()[
transa
?
dim_1
:
dim_0
];
ldb
=
input_shapes
[
1
].
strides
()[
transb
?
dim_1
:
dim_0
];
ldc
=
input_shapes
[
2
].
strides
()[
dim_0
];
ldd
=
is_3inputs
?
input_shapes
[
3
].
strides
()[
dim_0
]
:
ldc
;
arg_type
=
get_type
(
input_shapes
[
0
].
type
());
output_type
=
arg_type
;
if
(
output_type
==
rocblas_datatype_i8_r
)
{
output_type
=
rocblas_datatype_i32_r
;
}
compute_type
=
output_type
;
if
(
compute_fp32
)
{
alpha_v
=
&
alpha
;
beta_v
=
&
beta
;
if
(
arg_type
==
rocblas_datatype_f16_r
)
compute_type
=
rocblas_datatype_f32_r
;
}
auto
out_lens
=
output_shape
.
lens
();
rocblas_int
m
=
out_lens
[
dim_0
];
rocblas_int
n
=
out_lens
[
dim_1
];
rocblas_int
k
=
args
[
0
].
get_shape
().
lens
()[
dim_1
];
auto
to_pointer
=
[
&
](
auto
&&
arg
)
{
return
as
.
from
(
arg
.
data
());
};
auto
a_lens
=
input_shapes
[
0
].
lens
();
auto
b_lens
=
input_shapes
[
1
].
lens
();
auto
num_matrices
=
std
::
accumulate
(
auto
out_lens
=
output_shape
.
lens
();
m
=
out_lens
[
dim_0
];
n
=
out_lens
[
dim_1
];
k
=
input_shapes
[
0
].
lens
()[
dim_1
];
a_stride
=
get_batch_stride
(
input_shapes
[
0
]);
b_stride
=
get_batch_stride
(
input_shapes
[
1
]);
c_stride
=
get_batch_stride
(
input_shapes
[
2
]);
d_stride
=
is_3inputs
?
get_batch_stride
(
input_shapes
[
3
])
:
c_stride
;
num_matrices
=
std
::
accumulate
(
out_lens
.
rbegin
()
+
2
,
out_lens
.
rend
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
if
(
num_matrices
==
1
or
(
num_matrices
>
1
and
get_batch_stride
(
args
[
1
])
==
0
))
strided_batched
=
num_matrices
>
1
;
if
(
strided_batched
and
b_stride
==
0
and
input_shapes
[
0
].
standard
())
{
// If the batch dimension of B is broadcasted, then we can
// multiply m by the batch_size and use rocblas_gemm_ex
// instead of rocblas_gemm_strided_batched_ex.
m
*=
num_matrices
;
strided_batched
=
false
;
}
}
// the rocblas_gemm API handles inputs and output matrices as
// column-major format. When doing a C = A * B, we actually do
// C^T = (B^T) * (A^T). That is the reason we input args[1] as
// A and args[0] as B in calling the rocblas_gemm.
void
run
(
context
&
ctx
,
const
std
::
vector
<
argument
>&
input_args
,
int32_t
solution_idx
=
0
)
const
{
if
(
strided_batched
)
{
auto
common_args
=
create_strided_batched_args_common
(
ctx
,
input_args
);
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex
,
common_args
,
rocblas_gemm_algo_solution_index
,
solution_idx
,
gemm_flags
);
}
else
{
auto
common_args
=
create_gemm_ex_args_common
(
ctx
,
input_args
);
rocblas_invoke
(
&
rocblas_gemm_ex
,
ctx
.
get_stream
().
get_rocblas
(),
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
transa
?
rocblas_operation_transpose
:
rocblas_operation_none
,
n
,
m
,
k
,
alpha_v
,
to_pointer
(
args
.
at
(
1
)),
arg_type
,
ldb
,
to_pointer
(
args
.
at
(
0
)),
arg_type
,
lda
,
beta_v
,
to_pointer
(
args
[
2
]),
output_type
,
ldc
,
is_3inputs
?
to_pointer
(
args
[
3
])
:
to_pointer
(
args
[
2
]),
output_type
,
ldd
,
compute_type
,
rocblas_gemm_algo_standard
,
0
,
flag
);
common_args
,
rocblas_gemm_algo_solution_index
,
solution_idx
,
gemm_flags
);
}
}
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
auto
validate
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
input_shapes
,
int32_t
solution_idx
)
const
{
// Create dummy arguments for the shapes, and call the overloaded method
std
::
vector
<
argument
>
input_args
;
std
::
transform
(
input_shapes
.
begin
(),
input_shapes
.
end
(),
std
::
back_inserter
(
input_args
),
[](
const
shape
&
x
)
{
return
to_gpu
(
generate_argument
(
x
));
});
return
validate
(
ctx
,
input_args
,
solution_idx
);
}
/**
* Checks a particular solution for validity by running it with the flag
* rocblas_gemm_flags_check_solution_index (could be invalid if this model was
* tuned with a different rocBLAS version)
*
* @return Returns either solution_idx if valid, or else the default value 0
* if not. The default does not mean list index 0, but tells the picker
* to choose a solution.
*/
int32_t
validate
(
context
&
ctx
,
const
std
::
vector
<
argument
>&
input_args
,
int32_t
solution_idx
)
const
{
rocblas_status_
check_valid
(
rocblas_status_success
);
if
(
strided_batched
)
{
auto
common_args
=
create_strided_batched_args_common
(
ctx
,
input_args
);
check_valid
=
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex
,
common_args
,
rocblas_gemm_algo_solution_index
,
solution_idx
,
rocblas_gemm_flags_check_solution_index
);
}
else
{
auto
a_stride
=
get_batch_stride
(
args
[
0
]);
auto
b_stride
=
get_batch_stride
(
args
[
1
]);
auto
c_stride
=
get_batch_stride
(
args
[
2
]);
auto
d_stride
=
is_3inputs
?
get_batch_stride
(
args
[
3
])
:
c_stride
;
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex
,
ctx
.
get_stream
().
get_rocblas
(),
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
transa
?
rocblas_operation_transpose
:
rocblas_operation_none
,
n
,
m
,
k
,
alpha_v
,
to_pointer
(
args
.
at
(
1
)),
arg_type
,
ldb
,
b_stride
,
to_pointer
(
args
.
at
(
0
)),
arg_type
,
lda
,
a_stride
,
beta_v
,
to_pointer
(
args
[
2
]),
output_type
,
ldc
,
c_stride
,
is_3inputs
?
to_pointer
(
args
[
3
])
:
to_pointer
(
args
[
2
]),
output_type
,
ldd
,
d_stride
,
num_matrices
,
compute_type
,
rocblas_gemm_algo_standard
,
0
,
flag
);
auto
common_args
=
create_gemm_ex_args_common
(
ctx
,
input_args
);
check_valid
=
rocblas_invoke
(
&
rocblas_gemm_ex
,
common_args
,
rocblas_gemm_algo_solution_index
,
solution_idx
,
rocblas_gemm_flags_check_solution_index
);
}
});
if
(
check_valid
==
rocblas_status_invalid_value
)
{
std
::
cerr
<<
"WARNING: tuned solution is invalid; reverting to default"
<<
std
::
endl
;
return
0
;
}
return
solution_idx
;
}
#endif
/**
* Helper method to create that subset of a long rocBLAS argument list that is common
* to multiple "...strided_batched..." calls.
*
* The rocblas_gemm API handles inputs and output matrices as
* column-major format. When doing a C = A * B, we actually do
* C^T = (B^T) * (A^T). That is the reason we input args[1] as
* A and args[0] as B in calling the rocblas_gemm.
*
*/
auto
create_strided_batched_args_common
(
context
&
ctx
,
const
std
::
vector
<
argument
>&
args
)
const
{
return
pack
(
ctx
.
get_stream
().
get_rocblas
(),
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
transa
?
rocblas_operation_transpose
:
rocblas_operation_none
,
n
,
m
,
k
,
get_alpha
(),
args
[
1
].
data
(),
arg_type
,
ldb
,
b_stride
,
args
[
0
].
data
(),
arg_type
,
lda
,
a_stride
,
get_beta
(),
args
[
2
].
data
(),
output_type
,
ldc
,
c_stride
,
is_3inputs
?
args
[
3
].
data
()
:
args
[
2
].
data
(),
output_type
,
ldd
,
d_stride
,
num_matrices
,
compute_type
);
}
/**
* Helper method to create that subset of a long rocBLAS argument list that is common
* to multiple "gemm_ex..." calls.
*
* The rocblas_gemm API handles inputs and output matrices as
* column-major format. When doing a C = A * B, we actually do
* C^T = (B^T) * (A^T). That is the reason we input args[1] as
* A and args[0] as B in calling the rocblas_gemm.
*
* */
auto
create_gemm_ex_args_common
(
context
&
ctx
,
const
std
::
vector
<
argument
>&
args
)
const
{
return
pack
(
ctx
.
get_stream
().
get_rocblas
(),
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
transa
?
rocblas_operation_transpose
:
rocblas_operation_none
,
n
,
m
,
k
,
get_alpha
(),
args
[
1
].
data
(),
arg_type
,
ldb
,
args
[
0
].
data
(),
arg_type
,
lda
,
get_beta
(),
args
[
2
].
data
(),
output_type
,
ldc
,
is_3inputs
?
args
[
3
].
data
()
:
args
[
2
].
data
(),
output_type
,
ldd
,
compute_type
);
}
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
/**
* Find best rocBLAS solution: Get list of solutions and try them all, returning the index
* of the fastest one.
*/
int
tune
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
input_shapes
)
const
{
// tuning meta parameters
const
int
hot_calls
=
40
;
std
::
vector
<
argument
>
input_args
;
std
::
transform
(
input_shapes
.
begin
(),
input_shapes
.
end
(),
std
::
back_inserter
(
input_args
),
[](
const
shape
&
x
)
{
return
to_gpu
(
generate_argument
(
x
));
});
// Get the solutions list in 2 rocBLAS steps:
// 1. Find out how many solutions there are and allocate the array
// 2. Get the solutions
//
rocblas_int
list_size
=
0
;
std
::
vector
<
rocblas_int
>
solution_indices
;
if
(
strided_batched
)
{
auto
common_args
=
create_strided_batched_args_common
(
ctx
,
input_args
);
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex_get_solutions
,
common_args
,
rocblas_gemm_algo_solution_index
,
gemm_flags
,
nullptr
,
&
list_size
);
solution_indices
.
resize
(
list_size
);
auto
common_sol_args
=
create_strided_batched_args_common
(
ctx
,
input_args
);
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex_get_solutions
,
common_sol_args
,
rocblas_gemm_algo_solution_index
,
gemm_flags
,
solution_indices
.
data
(),
&
list_size
);
}
else
{
auto
common_args
=
create_gemm_ex_args_common
(
ctx
,
input_args
);
rocblas_invoke
(
&
rocblas_gemm_ex_get_solutions
,
common_args
,
rocblas_gemm_algo_solution_index
,
gemm_flags
,
nullptr
,
&
list_size
);
solution_indices
.
resize
(
list_size
);
auto
common_sol_args
=
create_gemm_ex_args_common
(
ctx
,
input_args
);
rocblas_invoke
(
&
rocblas_gemm_ex_get_solutions
,
common_sol_args
,
rocblas_gemm_algo_solution_index
,
gemm_flags
,
solution_indices
.
data
(),
&
list_size
);
}
double
best_time
=
std
::
numeric_limits
<
double
>::
max
();
double
first_time
=
-
1
;
// Initialize to default solution index
rocblas_int
best_sol
=
0
;
for
(
auto
sol
:
solution_indices
)
{
// Warmup: the first call to an op. may not be representative since there is
// more time taken initializing caches, etc. so we won't time it.
run
(
ctx
,
input_args
,
sol
);
double
host_time
=
time
<
milliseconds
>
([
&
]
{
for
([[
maybe_unused
]]
int
hc
:
range
(
hot_calls
))
run
(
ctx
,
input_args
,
sol
);
ctx
.
finish
();
});
host_time
/=
hot_calls
;
// dev/evaluation only: track time for first solution.
if
(
first_time
<
0
)
first_time
=
host_time
;
// track current best
if
(
host_time
<
best_time
)
{
best_sol
=
sol
;
best_time
=
host_time
;
}
}
std
::
cout
<<
"Winning GEMM solution: "
<<
best_sol
<<
" in "
<<
best_time
<<
" ms, beats "
<<
first_time
<<
"ms"
<<
std
::
endl
;
return
best_sol
;
}
#endif
private:
size_t
num_matrices
=
0
;
rocblas_int
m
=
0
;
rocblas_int
n
=
0
;
rocblas_int
k
=
0
;
bool
transa
=
false
;
bool
transb
=
false
;
T
alpha
=
0
;
T
beta
=
0
;
std
::
function
<
const
void
*
()
>
get_alpha
{};
std
::
function
<
const
void
*
()
>
get_beta
{};
rocblas_gemm_flags
gemm_flags
=
rocblas_gemm_flags_none
;
rocblas_int
lda
=
0
;
rocblas_int
ldb
=
0
;
rocblas_int
ldc
=
0
;
rocblas_int
ldd
=
0
;
rocblas_int
a_stride
=
0
;
rocblas_int
b_stride
=
0
;
rocblas_int
c_stride
=
0
;
rocblas_int
d_stride
=
0
;
rocblas_datatype
compute_type
=
rocblas_datatype_f32_r
;
rocblas_datatype
arg_type
=
rocblas_datatype_f32_r
;
rocblas_datatype
output_type
=
rocblas_datatype_f32_r
;
bool
strided_batched
=
true
;
bool
is_3inputs
=
true
;
bool
compute_fp32
=
true
;
};
// gemm_impl
void
gemm_compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
,
float
alpha
,
float
beta
,
bool
compute_fp32
,
int32_t
solution_idx
)
{
std
::
vector
<
shape
>
input_shapes
;
std
::
transform
(
args
.
begin
(),
args
.
end
(),
std
::
back_inserter
(
input_shapes
),
[](
const
argument
&
x
)
{
return
x
.
get_shape
();
});
auto
gemm_item
=
gemm_impl
<
float
>
(
output_shape
,
input_shapes
,
alpha
,
beta
,
compute_fp32
);
gemm_item
.
run
(
ctx
,
args
,
solution_idx
);
}
void
gemm
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
,
float
alpha
,
float
beta
,
bool
compute_fp32
)
void
gemm_compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
,
int32_t
alpha
,
int32_t
beta
,
bool
compute_fp32
,
int32_t
solution_idx
)
{
gemm_impl
(
ctx
,
output_shape
,
args
,
alpha
,
beta
,
compute_fp32
);
std
::
vector
<
shape
>
input_shapes
;
std
::
transform
(
args
.
begin
(),
args
.
end
(),
std
::
back_inserter
(
input_shapes
),
[](
const
argument
&
x
)
{
return
x
.
get_shape
();
});
auto
gemm_item
=
gemm_impl
<
int32_t
>
(
output_shape
,
input_shapes
,
alpha
,
beta
,
compute_fp32
);
gemm_item
.
run
(
ctx
,
args
,
solution_idx
);
}
/**
* Decides if the tune() or validate() method is appropriate and calls it.
* Return value is the chosen solution index, or 0 to let picker choose it.
*/
int32_t
gemm_finalize
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
shape
>&
input_shapes
,
float
alpha
,
float
beta
,
bool
compute_fp32
,
int32_t
solution_idx
)
{
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
// This code should be called only if either the environment var.
// MIGRAPHX_ENABLE_GEMM_TUNING, or option --exhaustive-tune, is set
if
(
solution_idx
==
0
)
{
auto
gemm_item
=
gemm_impl
<
float
>
(
output_shape
,
input_shapes
,
alpha
,
beta
,
compute_fp32
);
solution_idx
=
gemm_item
.
tune
(
ctx
,
input_shapes
);
}
else
{
// If a tuned solution index is already given, don't tune again but validate
// in case the data was tuned with a different rocBLAS version
auto
gemm_item
=
gemm_impl
<
float
>
(
output_shape
,
input_shapes
,
alpha
,
beta
,
compute_fp32
);
solution_idx
=
gemm_item
.
validate
(
ctx
,
input_shapes
,
solution_idx
);
}
#else
(
void
)
ctx
,
(
void
)
output_shape
,
(
void
)
input_shapes
;
(
void
)
alpha
,
(
void
)
beta
,
(
void
)
compute_fp32
;
#endif
return
solution_idx
;
}
void
gemm
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
,
int32_t
alpha
,
int32_t
beta
,
bool
compute_fp32
)
/**
* Decides if the tune() or validate() method is appropriate and calls it.
* Return value is the chosen solution index, or 0 to let picker choose it.
*/
int32_t
gemm_finalize
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
shape
>&
input_shapes
,
int32_t
alpha
,
int32_t
beta
,
bool
compute_fp32
,
int32_t
solution_idx
)
{
gemm_impl
(
ctx
,
output_shape
,
args
,
alpha
,
beta
,
compute_fp32
);
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
if
(
solution_idx
==
0
)
{
auto
gemm_item
=
gemm_impl
<
int32_t
>
(
output_shape
,
input_shapes
,
alpha
,
beta
,
compute_fp32
);
solution_idx
=
gemm_item
.
tune
(
ctx
,
input_shapes
);
}
else
{
// If a tuned solution index is already given, don't tune again but validate
// in case the data was tuned with a different rocBLAS version
auto
gemm_item
=
gemm_impl
<
int32_t
>
(
output_shape
,
input_shapes
,
alpha
,
beta
,
compute_fp32
);
solution_idx
=
gemm_item
.
validate
(
ctx
,
input_shapes
,
solution_idx
);
}
#else
(
void
)
ctx
,
(
void
)
output_shape
,
(
void
)
input_shapes
;
(
void
)
alpha
,
(
void
)
beta
,
(
void
)
compute_fp32
;
#endif
return
solution_idx
;
}
}
// namespace gpu
...
...
src/targets/gpu/include/migraphx/gpu/gemm.hpp
View file @
2c8ba41a
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -40,9 +40,8 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace
gpu
{
struct
context
;
void
blas_shape
(
const
shape
&
s
);
shape
transpose_batch
(
const
shape
&
s
,
unsigned
trans_batch
);
void
blas_shape
(
const
shape
&
s
);
template
<
class
Op
>
struct
rocblas_gemm
...
...
@@ -52,6 +51,7 @@ struct rocblas_gemm
float
beta
=
0
;
bool
compute_fp32
=
false
;
unsigned
trans_batch
=
0
;
int32_t
solution_idx
=
0
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
...
...
@@ -60,7 +60,8 @@ struct rocblas_gemm
pack
(
f
(
self
.
alpha
,
"alpha"
),
f
(
self
.
beta
,
"beta"
),
f
(
self
.
compute_fp32
,
"compute_fp32"
),
f
(
self
.
trans_batch
,
"trans_batch"
)));
f
(
self
.
trans_batch
,
"trans_batch"
),
f
(
self
.
solution_idx
,
"solution_idx"
)));
}
std
::
string
name
()
const
...
...
@@ -76,6 +77,8 @@ struct rocblas_gemm
{
std
::
vector
<
shape
>
in_shapes
(
inputs
);
in_shapes
.
pop_back
();
// When input shapes are A, B, C the GEMM equation is C = α AB+ β C where α, β are
// scalars
check_shapes
{
in_shapes
,
*
this
}.
has
(
2
,
3
);
blas_shape
(
inputs
[
0
]);
blas_shape
(
inputs
[
1
]);
...
...
@@ -111,11 +114,12 @@ struct rocblas_gemm
{
if
(
this
->
name
()
==
"gpu::gemm"
)
{
gemm
(
ctx
,
output_shape
,
args
,
alpha
,
beta
,
compute_fp32
);
gemm
_compute
(
ctx
,
output_shape
,
args
,
alpha
,
beta
,
compute_fp32
,
solution_idx
);
}
else
{
gemm
(
ctx
,
output_shape
,
args
,
int32_t
(
alpha
),
int32_t
(
beta
),
compute_fp32
);
gemm_compute
(
ctx
,
output_shape
,
args
,
int32_t
(
alpha
),
int32_t
(
beta
),
compute_fp32
,
solution_idx
);
}
return
args
.
back
();
}
...
...
@@ -124,6 +128,33 @@ struct rocblas_gemm
{
return
shapes
.
size
()
-
1
;
}
void
finalize
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
shape
>&
input_shapes
)
{
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
if
(
enabled
(
MIGRAPHX_ENABLE_GEMM_TUNING
{})
or
ctx
.
get_exhaustive_tune_flag
())
{
if
(
this
->
name
()
==
"gpu::gemm"
)
{
solution_idx
=
gemm_finalize
(
ctx
,
output_shape
,
input_shapes
,
alpha
,
beta
,
compute_fp32
,
solution_idx
);
}
else
{
solution_idx
=
gemm_finalize
(
ctx
,
output_shape
,
input_shapes
,
int32_t
(
alpha
),
int32_t
(
beta
),
compute_fp32
,
solution_idx
);
}
}
#else
// suppress compiler warnings
(
void
)
ctx
,
(
void
)
output_shape
,
(
void
)
input_shapes
;
#endif
}
};
}
// namespace gpu
...
...
src/targets/gpu/include/migraphx/gpu/gemm_impl.hpp
View file @
2c8ba41a
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -24,26 +24,64 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_GEMM_IMPL_HPP
#define MIGRAPHX_GUARD_RTGLIB_GEMM_IMPL_HPP
#include <iterator>
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/context.hpp>
// Set this environment variable to "true" to perform GEMM tuning even when the
// --exhaustive-tune option isn't set. Can be used to skip slow convolution tuning.
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_ENABLE_GEMM_TUNING
);
using
milliseconds
=
std
::
chrono
::
duration
<
double
,
std
::
milli
>
;
using
microseconds
=
std
::
chrono
::
duration
<
double
,
std
::
micro
>
;
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
void
gemm
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
,
float
alpha
,
float
beta
,
bool
compute_fp32
);
void
gemm
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
,
int32_t
alpha
,
int32_t
beta
,
bool
compute_fp32
);
/**
* @brief Templated implementations of the compute() and finalize() methods of the Gemm operator.
* For each function there are overloads using either float or int32_t for the arguments
* alpha and beta.
*
* @param ctx .
* @param output_shape .
* @param args .
* @param alpha .
* @param beta .
* @param compute_fp32 .
*/
void
gemm_compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
,
float
alpha
,
float
beta
,
bool
compute_fp32
,
int32_t
solution_idx
);
void
gemm_compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
,
int32_t
alpha
,
int32_t
beta
,
bool
compute_fp32
,
int32_t
solution_idx
);
int32_t
gemm_finalize
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
shape
>&
input_shapes
,
float
alpha
,
float
beta
,
bool
compute_fp32
);
int32_t
gemm_finalize
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
shape
>&
input_shapes
,
int32_t
alpha
,
int32_t
beta
,
bool
compute_fp32
,
int32_t
solution_idx
);
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
...
...
src/targets/gpu/include/migraphx/gpu/rocblas.hpp
View file @
2c8ba41a
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
View file @
2c8ba41a
...
...
@@ -103,6 +103,7 @@ MIGRAPHX_DEVICE_MATH(floor, ::floor)
MIGRAPHX_DEVICE_MATH
(
isnan
,
::
isnan
)
MIGRAPHX_DEVICE_MATH
(
isinf
,
::
isinf
)
MIGRAPHX_DEVICE_MATH
(
log
,
::
log
)
MIGRAPHX_DEVICE_MATH
(
nearbyint
,
::
nearbyint
)
MIGRAPHX_DEVICE_MATH
(
pow
,
::
pow
)
MIGRAPHX_DEVICE_MATH
(
remainder
,
::
remainder
)
MIGRAPHX_DEVICE_MATH
(
round
,
::
round
)
...
...
@@ -152,6 +153,7 @@ MIGRAPHX_DEVICE_MATH_HALF(atan, ::atan)
MIGRAPHX_DEVICE_MATH_HALF
(
atanh
,
::
atanh
)
MIGRAPHX_DEVICE_MATH_HALF
(
cosh
,
::
cosh
)
MIGRAPHX_DEVICE_MATH_HALF
(
erf
,
::
erf
)
MIGRAPHX_DEVICE_MATH_HALF
(
nearbyint
,
::
nearbyint
)
MIGRAPHX_DEVICE_MATH_HALF
(
pow
,
::
pow
)
MIGRAPHX_DEVICE_MATH_HALF
(
remainder
,
::
remainder
)
MIGRAPHX_DEVICE_MATH_HALF
(
round
,
::
round
)
...
...
@@ -236,6 +238,7 @@ MIGRAPHX_DEVICE_MATH_VEC(isnan)
MIGRAPHX_DEVICE_MATH_VEC
(
log
)
MIGRAPHX_DEVICE_MATH_VEC
(
max
)
MIGRAPHX_DEVICE_MATH_VEC
(
min
)
MIGRAPHX_DEVICE_MATH_VEC
(
nearbyint
)
MIGRAPHX_DEVICE_MATH_VEC
(
pow
)
MIGRAPHX_DEVICE_MATH_VEC
(
remainder
)
MIGRAPHX_DEVICE_MATH_VEC
(
round
)
...
...
test/gpu/codegen_literal.cpp
View file @
2c8ba41a
...
...
@@ -64,7 +64,7 @@ TEST_CASE(mul_literal_round_test)
auto
l1
=
mm
->
add_literal
(
1
/
0.00787402
f
);
auto
mul
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
l0
,
l1
);
auto
round
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"
round
"
),
mul
);
auto
round
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"
nearbyint
"
),
mul
);
mm
->
add_return
({
round
});
...
...
test/gpu/gemm_tune.cpp
0 → 100644
View file @
2c8ba41a
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <iostream>
#include <vector>
#include <migraphx/gpu/gemm.hpp>
#include <hip/hip_runtime_api.h>
#include <migraphx/gpu/target.hpp>
#include <migraphx/verify.hpp>
#include <test.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/iterator_for.hpp>
// includes needed for run_lowering
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/auto_contiguous.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/pass_manager.hpp>
// Abbreviated lowering; we don't need the usual cleanup passes for this test
void
run_lowering
(
migraphx
::
program
&
p
,
bool
offload_copy
=
false
)
{
auto
ctx
=
migraphx
::
gpu
::
context
{};
migraphx
::
run_passes
(
*
p
.
get_main_module
(),
{
migraphx
::
auto_contiguous
{},
migraphx
::
gpu
::
lowering
{
&
ctx
,
offload_copy
}});
}
/**
* Tests the automatic GEMM tuning feature. In the finalize() method of the gemm op,
* rocBLAS API functions are called to quickly benchmark all the GEMM solutions
* available in the currently installed rocBLAS library and choose the index of the fastest.
*/
TEST_CASE
(
gemm_tune_with_rocblas
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
sa
{
migraphx
::
shape
::
float_type
,
{
4
,
2
}};
migraphx
::
shape
sb
{
migraphx
::
shape
::
float_type
,
{
2
,
3
}};
auto
a
=
mm
->
add_parameter
(
"a"
,
sa
);
auto
b
=
mm
->
add_parameter
(
"b"
,
sb
);
migraphx
::
operation
dot_op
=
migraphx
::
make_op
(
"dot"
);
mm
->
add_instruction
(
dot_op
,
a
,
b
);
// lowering adds gemm implementation for dot operator
run_lowering
(
p
);
migraphx
::
target
gpu_t
=
migraphx
::
gpu
::
target
{};
migraphx
::
compile_options
options
;
options
.
exhaustive_tune
=
true
;
p
.
compile
(
gpu_t
,
options
);
migraphx
::
value
solution_idx
(
0
);
for
(
auto
ins
:
iterator_for
(
*
p
.
get_main_module
()))
{
if
(
ins
->
name
()
==
"gpu::gemm"
)
{
auto
gemm_op
=
migraphx
::
get_operation
(
ins
);
// tuned solution index is not deterministic, but anything other than 0
// (default, invalid, or not available) is good.
// gemm_op.to_value().debug_print();
solution_idx
=
gemm_op
.
to_value
()[
"solution_idx"
];
break
;
}
}
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
EXPECT
(
0
!=
solution_idx
.
to
<
std
::
size_t
>
());
#else
EXPECT
(
0
==
solution_idx
.
to
<
std
::
size_t
>
());
#endif
}
// GEMM tuning of a strided-batch matrix; invokes rocblas_gemm_strided_batched_ex
TEST_CASE
(
gemm_tune_strided
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
sa
{
migraphx
::
shape
::
float_type
,
{
4
,
2
,
2
}};
migraphx
::
shape
sb
{
migraphx
::
shape
::
float_type
,
{
4
,
2
,
2
}};
migraphx
::
shape
s_output
{
migraphx
::
shape
::
float_type
,
{
4
,
2
,
2
}};
auto
a
=
mm
->
add_parameter
(
"a"
,
sa
);
auto
b
=
mm
->
add_parameter
(
"b"
,
sb
);
auto
output
=
mm
->
add_parameter
(
"out"
,
s_output
);
auto
gemm_oper
=
migraphx
::
make_op
(
"gpu::gemm"
,
{{
"beta"
,
2
}});
mm
->
add_instruction
(
gemm_oper
,
a
,
b
,
output
);
migraphx
::
target
gpu_t
=
migraphx
::
gpu
::
target
{};
migraphx
::
compile_options
options
;
options
.
exhaustive_tune
=
true
;
p
.
compile
(
gpu_t
,
options
);
migraphx
::
value
solution_idx
(
0
);
for
(
auto
ins
:
iterator_for
(
*
p
.
get_main_module
()))
{
if
(
ins
->
name
()
==
"gpu::gemm"
)
{
auto
gemm_op
=
migraphx
::
get_operation
(
ins
);
auto
gemmv
=
gemm_op
.
to_value
();
// tuned solution index is not deterministic, but anything other than 0
// (default, invalid, or not available) is good.
solution_idx
=
gemm_op
.
to_value
()[
"solution_idx"
];
break
;
}
}
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
EXPECT
(
0
!=
solution_idx
.
to
<
std
::
size_t
>
());
#else
EXPECT
(
0
==
solution_idx
.
to
<
std
::
size_t
>
());
#endif
}
// GEMM tuning of a strided-batch matrix; created by lowering
TEST_CASE
(
gemm_tune_strided_lowered
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
// At time of writing this test, gemm_impl considers a shape is strided if it has
// at least three dimensions and the 3rd-to-last is nonzero, invoking
// rocblas_gemm_strided_batched_ex. Also, DOT operator requires all dimensions except the last
// two to be equal.
migraphx
::
shape
sa
{
migraphx
::
shape
::
float_type
,
{
4
,
2
,
5
}};
migraphx
::
shape
sb
{
migraphx
::
shape
::
float_type
,
{
4
,
5
,
3
}};
auto
a
=
mm
->
add_parameter
(
"a"
,
sa
);
auto
b
=
mm
->
add_parameter
(
"b"
,
sb
);
migraphx
::
operation
dot_op
=
migraphx
::
make_op
(
"dot"
);
mm
->
add_instruction
(
dot_op
,
a
,
b
);
// lowering adds gemm implementation for dot operator
run_lowering
(
p
);
migraphx
::
target
gpu_t
=
migraphx
::
gpu
::
target
{};
migraphx
::
compile_options
options
;
options
.
exhaustive_tune
=
true
;
p
.
compile
(
gpu_t
,
options
);
migraphx
::
value
solution_idx
(
0
);
for
(
auto
ins
:
iterator_for
(
*
p
.
get_main_module
()))
{
if
(
ins
->
name
()
==
"gpu::gemm"
)
{
auto
gemm_op
=
migraphx
::
get_operation
(
ins
);
// tuned solution index is not deterministic, but anything other than 0
// (default, invalid, or not available) is good.
solution_idx
=
gemm_op
.
to_value
()[
"solution_idx"
];
break
;
}
}
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
EXPECT
(
0
!=
solution_idx
.
to
<
std
::
size_t
>
());
#else
EXPECT
(
0
==
solution_idx
.
to
<
std
::
size_t
>
());
#endif
}
TEST_CASE
(
gemm_tune_invalid_sol_index
)
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
sa
{
migraphx
::
shape
::
float_type
,
{
4
,
2
}};
migraphx
::
shape
sb
{
migraphx
::
shape
::
float_type
,
{
2
,
3
}};
migraphx
::
shape
s_output
{
migraphx
::
shape
::
float_type
,
{
4
,
3
}};
auto
a
=
mm
->
add_parameter
(
"a"
,
sa
);
auto
b
=
mm
->
add_parameter
(
"b"
,
sb
);
auto
output
=
mm
->
add_parameter
(
"out"
,
s_output
);
auto
gemm_oper
=
migraphx
::
make_op
(
"gpu::gemm"
,
{{
"solution_idx"
,
987654321
}});
mm
->
add_instruction
(
gemm_oper
,
a
,
b
,
output
);
migraphx
::
target
gpu_t
=
migraphx
::
gpu
::
target
{};
migraphx
::
compile_options
options
;
options
.
exhaustive_tune
=
true
;
p
.
compile
(
gpu_t
,
options
);
migraphx
::
value
solution_idx
(
0
);
for
(
auto
ins
:
iterator_for
(
*
p
.
get_main_module
()))
{
if
(
ins
->
name
()
==
"gpu::gemm"
)
{
auto
gemm_op
=
migraphx
::
get_operation
(
ins
);
auto
gemmv
=
gemm_op
.
to_value
();
// given invalid starting index, should return default 0
solution_idx
=
gemm_op
.
to_value
()[
"solution_idx"
];
break
;
}
}
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
EXPECT
(
0
==
solution_idx
.
to
<
std
::
size_t
>
());
#else
EXPECT
(
0
!=
solution_idx
.
to
<
std
::
size_t
>
());
#endif
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
test/onnx/.onnxrt-commit
View file @
2c8ba41a
2eeafc37bca21dc8bf337dda7020b486543162d7
b7b8b5b2ce80edb33990c7ae0fedac6ae3c623f4
test/onnx/gen_onnx.py
View file @
2c8ba41a
...
...
@@ -7087,6 +7087,16 @@ def roialign_test():
return
([
node
],
[
x
,
roi
,
bi
],
[
y
])
@
onnx_test
()
def
round_half_test
():
x
=
helper
.
make_tensor_value_info
(
'x'
,
TensorProto
.
FLOAT16
,
[
4
,
4
])
y
=
helper
.
make_tensor_value_info
(
'y'
,
TensorProto
.
FLOAT16
,
[
4
,
4
])
node
=
onnx
.
helper
.
make_node
(
'Round'
,
inputs
=
[
'x'
],
outputs
=
[
'y'
])
return
([
node
],
[
x
],
[
y
])
@
onnx_test
()
def
scatter_add_test
():
x
=
helper
.
make_tensor_value_info
(
'data'
,
TensorProto
.
FLOAT
,
[
3
,
4
,
5
,
6
])
...
...
@@ -8006,6 +8016,32 @@ def slice_var_input_dyn1():
return
([
node
],
[
data
,
starts
,
ends
,
axes
],
[
output
])
@
onnx_test
()
def
slice_var_input_default_steps
():
step
=
np
.
array
([
1
,
1
])
step_tensor
=
helper
.
make_tensor
(
name
=
"step"
,
data_type
=
TensorProto
.
INT64
,
dims
=
step
.
shape
,
vals
=
step
.
astype
(
int
))
arg_step
=
helper
.
make_node
(
"Constant"
,
inputs
=
[],
outputs
=
[
'arg_step'
],
value
=
step_tensor
)
data
=
helper
.
make_tensor_value_info
(
'data'
,
TensorProto
.
FLOAT
,
[
None
,
2
])
starts
=
helper
.
make_tensor_value_info
(
'starts'
,
TensorProto
.
INT64
,
[
2
])
ends
=
helper
.
make_tensor_value_info
(
'ends'
,
TensorProto
.
INT64
,
[
2
])
axes
=
helper
.
make_tensor_value_info
(
'axes'
,
TensorProto
.
INT64
,
[
2
])
output
=
helper
.
make_tensor_value_info
(
'output'
,
TensorProto
.
FLOAT
,
[
1
,
2
])
node
=
onnx
.
helper
.
make_node
(
'Slice'
,
inputs
=
[
'data'
,
'starts'
,
'ends'
,
'axes'
,
'arg_step'
],
outputs
=
[
'output'
])
return
([
arg_step
,
node
],
[
data
,
starts
,
ends
,
axes
],
[
output
])
@
onnx_test
()
def
slice_var_input_steps_error
():
step
=
np
.
array
([
2
,
1
])
...
...
@@ -8019,9 +8055,9 @@ def slice_var_input_steps_error():
value
=
step_tensor
)
data
=
helper
.
make_tensor_value_info
(
'data'
,
TensorProto
.
FLOAT
,
[
3
,
2
])
starts
=
helper
.
make_tensor_value_info
(
'starts'
,
TensorProto
.
FLOAT
,
[
2
])
ends
=
helper
.
make_tensor_value_info
(
'ends'
,
TensorProto
.
FLOAT
,
[
2
])
axes
=
helper
.
make_tensor_value_info
(
'axes'
,
TensorProto
.
FLOAT
,
[
2
])
starts
=
helper
.
make_tensor_value_info
(
'starts'
,
TensorProto
.
INT64
,
[
2
])
ends
=
helper
.
make_tensor_value_info
(
'ends'
,
TensorProto
.
INT64
,
[
2
])
axes
=
helper
.
make_tensor_value_info
(
'axes'
,
TensorProto
.
INT64
,
[
2
])
output
=
helper
.
make_tensor_value_info
(
'output'
,
TensorProto
.
FLOAT
,
[
1
,
2
])
node
=
onnx
.
helper
.
make_node
(
...
...
@@ -9031,6 +9067,20 @@ def upsample_test():
return
([
node
],
[
X
],
[
Y
],
[
scale_tensor
])
@
onnx_test
()
def
upsample_ver7_test
():
X
=
helper
.
make_tensor_value_info
(
'X'
,
TensorProto
.
FLOAT
,
[
1
,
1
,
2
,
2
])
Y
=
helper
.
make_tensor_value_info
(
'Y'
,
TensorProto
.
FLOAT
,
[
1
,
1
,
4
,
6
])
node
=
onnx
.
helper
.
make_node
(
'Upsample'
,
inputs
=
[
'X'
],
outputs
=
[
'Y'
],
mode
=
'nearest'
,
scales
=
[
1.0
,
1.0
,
2.0
,
3.0
])
return
([
node
],
[
X
],
[
Y
])
@
onnx_test
()
def
variable_batch_test
():
x
=
helper
.
make_tensor_value_info
(
'0'
,
TensorProto
.
FLOAT
,
...
...
test/onnx/onnx_test.cpp
View file @
2c8ba41a
...
...
@@ -5788,9 +5788,9 @@ TEST_CASE(quantizelinear_test)
auto l1_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {5}}}), l1);
auto div = mm->add_instruction(migraphx::make_op("div"), l0, l1_mbcast);
auto
round
= mm->add_instruction(migraphx::make_op("
round
"), div);
auto s
= round
->get_shape();
auto clip = insert_quantizelinear_clip(*mm, div,
round
, s, 0, 255);
auto
nearbyint
= mm->add_instruction(migraphx::make_op("
nearbyint
"), div);
auto s
= nearbyint
->get_shape();
auto clip
= insert_quantizelinear_clip(*mm, div,
nearbyint
, s, 0, 255);
mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::uint8_type)}}),
...
...
@@ -5813,9 +5813,9 @@ TEST_CASE(quantizelinear_int32_test)
{{"target_type", migraphx::to_value(migraphx::shape::float_type)}}),
l0);
auto div = mm->add_instruction(migraphx::make_op("div"), l0, l1_mbcast);
auto
round
= mm->add_instruction(migraphx::make_op("
round
"), div);
auto s
= round
->get_shape();
auto clip = insert_quantizelinear_clip(*mm, div,
round
, s, 0, 255);
auto
nearbyint
= mm->add_instruction(migraphx::make_op("
nearbyint
"), div);
auto s
= nearbyint
->get_shape();
auto clip
= insert_quantizelinear_clip(*mm, div,
nearbyint
, s, 0, 255);
mm->add_instruction(
migraphx::make_op("convert",
{{"target_type", migraphx::to_value(migraphx::shape::uint8_type)}}),
...
...
@@ -5835,7 +5835,7 @@ TEST_CASE(quantizelinear_zero_point_test)
auto l1_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {5}}}), l1);
auto div = mm->add_instruction(migraphx::make_op("div"), l0, l1_mbcast);
auto round = mm->add_instruction(migraphx::make_op("
round
"), div);
auto round = mm->add_instruction(migraphx::make_op("
nearbyint
"), div);
auto l2_mbcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {5}}}), l2);
l2_mbcast = mm->add_instruction(
...
...
@@ -5868,7 +5868,7 @@ migraphx::program make_quantizelinear_axis_prog()
migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", input_lens}}), l1);
auto div = mm->add_instruction(migraphx::make_op("div"), l0, l1_bcast);
auto round = mm->add_instruction(migraphx::make_op("
round
"), div);
auto round = mm->add_instruction(migraphx::make_op("
nearbyint
"), div);
auto l2_bcast = mm->add_instruction(
migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", input_lens}}), l2);
l2_bcast = mm->add_instruction(
...
...
@@ -6557,9 +6557,8 @@ TEST_CASE(resize_nonstd_input_test)
auto tx =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), inx);
mm->add_instruction(migraphx::make_op("undefined"));
auto tx_cont = mm->add_instruction(migraphx::make_op("contiguous"), tx);
auto lrsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {8}}}), tx
_cont
);
auto lrsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {8}}}), tx);
auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), lrsp, li);
mm->add_return({r});
...
...
@@ -6998,7 +6997,7 @@ TEST_CASE(round_test)
migraphx::program p;
auto* mm = p.get_main_module();
auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::double_type, {10, 5}});
mm->add_instruction(migraphx::make_op("
round
"), input);
mm->add_instruction(migraphx::make_op("
nearbyint
"), input);
auto prog = optimize_onnx("round_test.onnx");
EXPECT(p == prog);
...
...
@@ -7654,6 +7653,25 @@ TEST_CASE(slice_var_input_dyn1)
EXPECT(p == prog);
}
TEST_CASE(slice_var_input_default_steps)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto data =
mm->add_parameter("data", migraphx::shape{migraphx::shape::float_type, {{3, 8}, {2, 2}}});
auto starts = mm->add_parameter("starts", migraphx::shape{migraphx::shape::int64_type, {2}});
auto ends = mm->add_parameter("ends", migraphx::shape{migraphx::shape::int64_type, {2}});
auto axes = mm->add_parameter("axes", migraphx::shape{migraphx::shape::int64_type, {2}});
mm->add_literal({{migraphx::shape::int64_type, {2}}, {1, 1}});
auto ret = mm->add_instruction(migraphx::make_op("slice"), data, starts, ends, axes);
mm->add_return({ret});
migraphx::onnx_options options;
options.default_dyn_dim_value = {3, 8};
auto prog = parse_onnx("slice_var_input_default_steps.onnx", options);
EXPECT(p == prog);
}
TEST_CASE(slice_var_input_steps_error)
{
EXPECT(test::throws([&] { migraphx::parse_onnx("slice_var_input_steps_error.onnx"); }));
...
...
@@ -8418,6 +8436,27 @@ TEST_CASE(upsample_test)
EXPECT(p == prog);
}
TEST_CASE(upsample_ver7_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape sx{migraphx::shape::float_type, {1, 1, 2, 2}};
auto ix = mm->add_parameter("X", sx);
migraphx::shape si{migraphx::shape::int32_type, {1, 1, 4, 6}};
std::vector<int> ind = {0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3};
auto li = mm->add_literal(migraphx::literal(si, ind));
auto rsp = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4}}}), ix);
auto r = mm->add_instruction(migraphx::make_op("gather", {{"axis", 0}}), rsp, li);
mm->add_return({r});
auto prog = migraphx::parse_onnx("upsample_ver7_test.onnx");
EXPECT(p == prog);
}
TEST_CASE(unknown_test_throw_print_error)
{
migraphx::onnx_options options;
...
...
test/onnx/reshape_variable_input_test0.onnx
0 → 100644
View file @
2c8ba41a
reshape_variable_input_test0:q
0
12"Reshapereshape_variable_input_test0Z
0
Z
1
b
2
B
\ No newline at end of file
test/onnx/reshape_variable_input_test1.onnx
0 → 100644
View file @
2c8ba41a
File added
test/onnx/round_half_test.onnx
0 → 100644
View file @
2c8ba41a
round_half_test:J
xy"Roundround_half_testZ
x
b
y
B
\ No newline at end of file
test/onnx/slice_var_input_default_steps.onnx
0 → 100644
View file @
2c8ba41a
File added
test/onnx/upsample_ver7_test.onnx
0 → 100644
View file @
2c8ba41a
File added
test/onnx/verify_onnx.cpp
View file @
2c8ba41a
...
...
@@ -2056,6 +2056,43 @@ TEST_CASE(reversesequence_time_verify_test)
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
result_vector
,
gold
));
}
TEST_CASE
(
round_half_test
)
{
migraphx
::
program
p
=
migraphx
::
parse_onnx
(
"round_half_test.onnx"
);
p
.
compile
(
migraphx
::
make_target
(
"ref"
));
migraphx
::
shape
xs
{
migraphx
::
shape
::
half_type
,
{
4
,
4
}};
std
::
vector
<
float
>
tmp
=
{
-
3.51
,
-
3.5
,
-
3.49
,
-
2.51
,
-
2.50
,
-
2.49
,
-
1.6
,
-
1.5
,
-
0.51
,
-
0.5
,
0.5
,
0.6
,
2.4
,
2.5
,
3.5
,
4.5
};
std
::
vector
<
migraphx
::
half
>
data
{
tmp
.
cbegin
(),
tmp
.
cend
()};
migraphx
::
parameter_map
param_map
;
param_map
[
"x"
]
=
migraphx
::
argument
(
xs
,
data
.
data
());
auto
result
=
p
.
eval
(
param_map
).
back
();
std
::
vector
<
migraphx
::
half
>
result_vector
;
result
.
visit
([
&
](
auto
output
)
{
result_vector
.
assign
(
output
.
begin
(),
output
.
end
());
});
tmp
=
{
-
4.0
,
-
4.0
,
-
3.0
,
-
3.0
,
-
2.0
,
-
2.0
,
-
2.0
,
-
2.0
,
-
1.0
,
0.0
,
0.0
,
1.0
,
2.0
,
2.0
,
4.0
,
4.0
};
std
::
vector
<
migraphx
::
half
>
gold
{
tmp
.
cbegin
(),
tmp
.
cend
()};
EXPECT
(
migraphx
::
verify
::
verify_rms_range
(
result_vector
,
gold
));
}
TEST_CASE
(
selu_test
)
{
migraphx
::
program
p
=
migraphx
::
parse_onnx
(
"selu_test.onnx"
);
...
...
test/py/CMakeLists.txt
View file @
2c8ba41a
...
...
@@ -41,13 +41,21 @@ function(add_py_venv_fixture FIXTURE_NAME VIRTUAL_ENV_DIR REQUIREMENTS_FILE)
set
(
PYTHON_EXECUTABLE
${
PYTHON_
${
PYTHON_VERSION
}
_EXECUTABLE
}
)
if
(
NOT TEST py_
${
PYTHON_VERSION
}
_
${
FIXTURE_NAME
}
_initialize_env
)
if
(
NOT
(
${
PYTHON_VERSION
}
STREQUAL
${
PYTHON_VERSION_TO_DISABLE_ONNX
}
))
if
(
NOT
(
${
FIXTURE_NAME
}
STREQUAL
"onnx"
AND
${
PYTHON_VERSION
}
STREQUAL
${
PYTHON_VERSION_TO_DISABLE_ONNX
}
))
add_test
(
NAME py_
${
PYTHON_VERSION
}
_
${
FIXTURE_NAME
}
_initialize_env COMMAND
${
PYTHON_EXECUTABLE
}
-m venv
${
VIRTUAL_ENV_DIR
}
/
${
PYTHON_VERSION
}
--clear
)
set_tests_properties
(
py_
${
PYTHON_VERSION
}
_
${
FIXTURE_NAME
}
_initialize_env PROPERTIES FIXTURES_SETUP
${
FIXTURE_NAME
}
_
${
PYTHON_VERSION
}
_INIT_VENV
)
set
(
PYTHON_EXECUTABLE
${
VIRTUAL_ENV_DIR
}
/
${
PYTHON_VERSION
}
/bin/python
)
add_test
(
NAME py_
${
PYTHON_VERSION
}
_
${
FIXTURE_NAME
}
_setup_env
COMMAND
${
PYTHON_EXECUTABLE
}
-m pip install -r
${
REQUIREMENTS_FILE
}
)
if
(
EXISTS
${
REQUIREMENTS_FILE
}
)
add_test
(
NAME py_
${
PYTHON_VERSION
}
_
${
FIXTURE_NAME
}
_setup_env
COMMAND
${
PYTHON_EXECUTABLE
}
-m pip install -r
${
REQUIREMENTS_FILE
}
)
else
()
# If there is no requirements file, then there are no packages to install in the virtual env.
# Just create a placeholder test for setting up the required fixture for running the tests.
add_test
(
NAME py_
${
PYTHON_VERSION
}
_
${
FIXTURE_NAME
}
_setup_env
COMMAND
${
PYTHON_EXECUTABLE
}
-m pip install --help
)
endif
()
set_tests_properties
(
py_
${
PYTHON_VERSION
}
_
${
FIXTURE_NAME
}
_setup_env PROPERTIES FIXTURES_REQUIRED
${
FIXTURE_NAME
}
_
${
PYTHON_VERSION
}
_INIT_VENV
)
set_tests_properties
(
py_
${
PYTHON_VERSION
}
_
${
FIXTURE_NAME
}
_setup_env PROPERTIES FIXTURES_SETUP
${
FIXTURE_NAME
}
_
${
PYTHON_VERSION
}
_VENV
)
endif
()
...
...
@@ -67,7 +75,7 @@ function(add_py_test NAME SCRIPT FIXTURE_NAME VENV_DIR)
else
()
set
(
PYTHON_EXECUTABLE
${
VENV_DIR
}
/
${
PYTHON_VERSION
}
/bin/python
)
endif
()
if
(
NOT
${
PYTHON_VERSION
}
STREQUAL
${
PYTHON_VERSION_TO_DISABLE_ONNX
}
)
if
(
NOT
(
${
FIXTURE_NAME
}
STREQUAL
"onnx"
AND
${
PYTHON_VERSION
}
STREQUAL
${
PYTHON_VERSION_TO_DISABLE_ONNX
}
)
)
add_test
(
NAME test_py_
${
PYTHON_VERSION
}
_
${
NAME
}
COMMAND
${
ENV_COMMAND
}
${
PYTHON_EXECUTABLE
}
${
CMAKE_CURRENT_SOURCE_DIR
}
/
${
SCRIPT
}
${
ARGN
}
)
...
...
test/py/onnx_backend_test.py
View file @
2c8ba41a
...
...
@@ -83,7 +83,6 @@ def disabled_tests_onnx_1_7_0(backend_test):
backend_test
.
exclude
(
r
'test_nonmaxsuppression_two_batches_cpu'
)
backend_test
.
exclude
(
r
'test_nonmaxsuppression_two_classes_cpu'
)
backend_test
.
exclude
(
r
'test_nonzero_example_cpu'
)
backend_test
.
exclude
(
r
'test_round_cpu'
)
backend_test
.
exclude
(
r
'test_softmax_axis_0_cpu'
)
backend_test
.
exclude
(
r
'test_softmax_axis_1_cpu'
)
backend_test
.
exclude
(
r
'test_softmax_default_axis_cpu'
)
...
...
test/py/requirements.txt
deleted
100644 → 0
View file @
5d998fb2
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
numpy==1.21.6
test/py/test_gpu.py
View file @
2c8ba41a
...
...
@@ -22,7 +22,6 @@
# THE SOFTWARE.
#####################################################################################
import
migraphx
import
numpy
as
np
def
test_conv_relu
():
...
...
@@ -51,8 +50,12 @@ def test_sub_uint64():
params
=
{}
shapes
=
p
.
get_parameter_shapes
()
params
[
"0"
]
=
np
.
arange
(
120
).
reshape
(
shapes
[
"0"
].
lens
()).
astype
(
np
.
uint64
)
params
[
"1"
]
=
np
.
arange
(
20
).
reshape
(
shapes
[
"1"
].
lens
()).
astype
(
np
.
uint64
)
params
[
"0"
]
=
migraphx
.
create_argument
(
migraphx
.
shape
(
type
=
'uint64_type'
,
lens
=
shapes
[
"0"
].
lens
()),
list
(
range
(
120
)))
params
[
"1"
]
=
migraphx
.
create_argument
(
migraphx
.
shape
(
type
=
'uint64_type'
,
lens
=
shapes
[
"1"
].
lens
()),
list
(
range
(
20
)))
r
=
p
.
run
(
params
)
print
(
r
)
...
...
@@ -67,7 +70,9 @@ def test_neg_int64():
params
=
{}
shapes
=
p
.
get_parameter_shapes
()
params
[
"0"
]
=
np
.
arange
(
6
).
reshape
(
shapes
[
"0"
].
lens
()).
astype
(
np
.
int64
)
params
[
"0"
]
=
migraphx
.
create_argument
(
migraphx
.
shape
(
type
=
'int64_type'
,
lens
=
shapes
[
"0"
].
lens
()),
list
(
range
(
6
)))
r
=
p
.
run
(
params
)
print
(
r
)
...
...
@@ -82,8 +87,9 @@ def test_nonzero():
params
=
{}
shapes
=
p
.
get_parameter_shapes
()
params
[
"data"
]
=
np
.
array
([
1
,
1
,
0
,
1
]).
reshape
(
shapes
[
"data"
].
lens
()).
astype
(
bool
)
params
[
"data"
]
=
migraphx
.
create_argument
(
migraphx
.
shape
(
type
=
'bool_type'
,
lens
=
shapes
[
"data"
].
lens
()),
[
1
,
1
,
0
,
1
])
r
=
p
.
run
(
params
)
print
(
r
)
...
...
@@ -101,8 +107,8 @@ def test_fp16_imagescaler():
params
=
{}
shapes
=
p
.
get_parameter_shapes
()
params
[
"0"
]
=
np
.
random
.
randn
(
768
).
reshape
(
shapes
[
"0"
].
lens
()).
astype
(
np
.
float16
)
params
[
"0"
]
=
migraphx
.
generate_argument
(
migraphx
.
shape
(
type
=
'half_type'
,
lens
=
shapes
[
"0"
].
lens
()),
768
)
r
=
p
.
run
(
params
)[
-
1
]
print
(
r
)
...
...
@@ -120,10 +126,12 @@ def test_if_pl():
params
=
{}
shapes
=
p
.
get_parameter_shapes
()
params
[
"x"
]
=
np
.
ones
(
6
).
reshape
(
shapes
[
"x"
].
lens
()).
astype
(
np
.
float32
)
params
[
"y"
]
=
np
.
array
([
2.0
,
2.0
,
2.0
,
2.0
,
2.0
,
2.0
,
2.0
,
2.0
,
2.0
]).
reshape
(
shapes
[
"y"
].
lens
()).
astype
(
np
.
float32
)
params
[
"cond"
]
=
np
.
array
([
1
]).
reshape
(()).
astype
(
bool
)
params
[
"x"
]
=
migraphx
.
fill_argument
(
migraphx
.
shape
(
type
=
'float_type'
,
lens
=
shapes
[
"x"
].
lens
()),
1
)
params
[
"y"
]
=
migraphx
.
fill_argument
(
migraphx
.
shape
(
type
=
'float_type'
,
lens
=
shapes
[
"y"
].
lens
()),
2.0
)
params
[
"cond"
]
=
migraphx
.
fill_argument
(
migraphx
.
shape
(
type
=
"bool"
,
lens
=
[
1
],
strides
=
[
0
]),
1
)
r
=
p
.
run
(
params
)[
-
1
]
print
(
r
)
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment