Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
151dd91a
Commit
151dd91a
authored
Oct 11, 2023
by
turneram
Browse files
Merge remote-tracking branch 'origin/ck-flash-attn' into gemm-perf
parents
280e76d0
5b2b7489
Changes
67
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
175 additions
and
67 deletions
+175
-67
src/targets/cpu/include/migraphx/cpu/pointwise.hpp
src/targets/cpu/include/migraphx/cpu/pointwise.hpp
+1
-0
src/targets/gpu/CMakeLists.txt
src/targets/gpu/CMakeLists.txt
+9
-1
src/targets/gpu/compile_hip.cpp
src/targets/gpu/compile_hip.cpp
+3
-5
src/targets/gpu/compile_hip_code_object.cpp
src/targets/gpu/compile_hip_code_object.cpp
+9
-13
src/targets/gpu/device/targets.hpp.in
src/targets/gpu/device/targets.hpp.in
+5
-1
src/targets/gpu/hiprtc/main.cpp
src/targets/gpu/hiprtc/main.cpp
+1
-0
src/targets/gpu/include/migraphx/gpu/ck.hpp
src/targets/gpu/include/migraphx/gpu/ck.hpp
+8
-8
src/targets/gpu/include/migraphx/gpu/compile_hip.hpp
src/targets/gpu/include/migraphx/gpu/compile_hip.hpp
+1
-4
src/targets/gpu/jit/ck_gemm.cpp
src/targets/gpu/jit/ck_gemm.cpp
+1
-1
src/targets/gpu/mlir.cpp
src/targets/gpu/mlir.cpp
+4
-1
src/tf/parse_reshape.cpp
src/tf/parse_reshape.cpp
+1
-2
test/gpu/jit.cpp
test/gpu/jit.cpp
+1
-1
test/gpu/stream_sync.cpp
test/gpu/stream_sync.cpp
+1
-1
test/include/test.hpp
test/include/test.hpp
+2
-0
test/jit.cpp
test/jit.cpp
+1
-3
test/msgpack.cpp
test/msgpack.cpp
+11
-8
test/multi_target/multitarget_test.cpp
test/multi_target/multitarget_test.cpp
+0
-1
test/onnx/gen_onnx.py
test/onnx/gen_onnx.py
+55
-0
test/onnx/onnx_test.cpp
test/onnx/onnx_test.cpp
+61
-17
test/onnx/qlinearadd_bcast_test.onnx
test/onnx/qlinearadd_bcast_test.onnx
+0
-0
No files found.
src/targets/cpu/include/migraphx/cpu/pointwise.hpp
View file @
151dd91a
...
...
@@ -24,6 +24,7 @@
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_CPU_POINTWISE_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_CPU_POINTWISE_HPP
#include <array>
#include <migraphx/config.hpp>
#include <migraphx/context.hpp>
#include <migraphx/check_shapes.hpp>
...
...
src/targets/gpu/CMakeLists.txt
View file @
151dd91a
...
...
@@ -48,10 +48,18 @@ else()
set
(
MIGRAPHX_USE_HIPRTC ON CACHE BOOL
"Use hipRTC APIs"
)
endif
()
include
(
Embed
)
file
(
GLOB KERNEL_FILES CONFIGURE_DEPENDS
${
CMAKE_CURRENT_SOURCE_DIR
}
/kernels/include/migraphx/kernels/*.hpp
)
message
(
STATUS
"KERNEL_FILES:
${
KERNEL_FILES
}
"
)
if
(
WIN32
)
# TODO: re-enable when CK is ported to Windows
list
(
REMOVE_ITEM KERNEL_FILES
${
CMAKE_CURRENT_SOURCE_DIR
}
/kernels/include/migraphx/kernels/ck_gemm.hpp
${
CMAKE_CURRENT_SOURCE_DIR
}
/kernels/include/migraphx/kernels/ck.hpp
)
endif
()
include
(
Embed
)
add_embed_library
(
migraphx_kernels
${
KERNEL_FILES
}
RELATIVE
${
CMAKE_CURRENT_SOURCE_DIR
}
/kernels/include/
)
configure_file
(
device/targets.hpp.in include/migraphx/gpu/device/targets.hpp
)
...
...
src/targets/gpu/compile_hip.cpp
View file @
151dd91a
...
...
@@ -248,7 +248,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
{
if
(
src
.
path
.
extension
()
!=
".cpp"
)
continue
;
std
::
cout
<<
std
::
string
(
src
.
content
.
first
,
src
.
len
()
)
<<
std
::
endl
;
std
::
cout
<<
std
::
string
(
src
.
content
)
<<
std
::
endl
;
}
}
auto
p
=
dynamic_loader
::
path
(
&
compile_hip_src_with_hiprtc
);
...
...
@@ -338,7 +338,7 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
{
if
(
src
.
path
.
extension
()
!=
".cpp"
)
continue
;
std
::
cout
<<
std
::
string
(
src
.
content
.
first
,
src
.
len
()
)
<<
std
::
endl
;
std
::
cout
<<
std
::
string
(
src
.
content
)
<<
std
::
endl
;
}
}
...
...
@@ -359,9 +359,7 @@ bool hip_has_flags(const std::vector<std::string>& flags)
join_strings
(
flags
,
" "
)
+
" -x hip -c --offload-arch=gfx900 --cuda-device-only"
;
std
::
string
src
;
src_file
input
;
input
.
path
=
"main.cpp"
;
input
.
content
=
std
::
make_pair
(
src
.
data
(),
src
.
data
()
+
src
.
size
());
src_file
input
{
"main.cpp"
,
src
};
try
{
...
...
src/targets/gpu/compile_hip_code_object.cpp
View file @
151dd91a
...
...
@@ -172,21 +172,17 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option
assert
(
options
.
inputs
.
size
()
==
options
.
virtual_inputs
.
size
()
or
options
.
virtual_inputs
.
empty
());
std
::
vector
<
src_file
>
srcs
=
options
.
additional_src_files
;
std
::
transform
(
migraphx_kernels
().
begin
(),
migraphx_kernels
().
end
(),
std
::
back_inserter
(
srcs
),
[](
auto
&&
p
)
{
auto
&&
name
=
p
.
first
;
auto
&&
c
=
p
.
second
;
auto
path
=
name
;
return
src_file
{
path
,
c
};
});
srcs
.
push_back
(
src_file
{
fs
::
path
{
"main.cpp"
},
std
::
make_pair
(
content
.
data
(),
content
.
data
()
+
content
.
size
())});
static
auto
kernels
{
::
migraphx_kernels
()};
std
::
transform
(
kernels
.
begin
(),
kernels
.
end
(),
std
::
back_inserter
(
srcs
),
[](
const
std
::
pair
<
std
::
string_view
,
std
::
string_view
>&
elem
)
{
return
src_file
{
elem
};
});
srcs
.
emplace_back
(
"main.cpp"
,
content
);
auto
args_hpp
=
generate_args_hpp
(
options
.
virtual_inputs
.
empty
()
?
options
.
inputs
:
options
.
virtual_inputs
);
srcs
.
push_back
(
src_file
{
fs
::
path
{
"
args
.
hpp
"
},
std
::
make_pair
(
args_hpp
.
data
(),
args_hpp
.
data
()
+
args_hpp
.
size
())});
srcs
.
emplace_back
(
"args.hpp"
,
args
_
hpp
);
options
.
params
+=
" -DMIGRAPHX_NGLOBAL="
+
std
::
to_string
(
options
.
global
);
options
.
params
+=
" -DMIGRAPHX_NLOCAL="
+
std
::
to_string
(
options
.
local
);
options
.
params
+=
" "
+
join_strings
(
compiler_warnings
(),
" "
);
...
...
src/targets/gpu/device/targets.hpp.in
View file @
151dd91a
...
...
@@ -24,7 +24,7 @@
#ifndef MIGRAPHX_GUARD_DEVICE_TARGETS_CPP
#define MIGRAPHX_GUARD_DEVICE_TARGETS_CPP
#include <migraphx/config.hpp>
#include <migraphx/
gpu/device/
config.hpp>
#include <string>
#include <vector>
...
...
@@ -34,9 +34,13 @@ namespace gpu {
namespace device {
#define MIGRAPHX_GPU_TARGETS "@GPU_TARGETS@" // NOLINT
MIGRAPHX_DEVICE_EXPORT
const std::vector<std::string>& get_targets();
MIGRAPHX_DEVICE_EXPORT
std::string get_targets_as_string();
MIGRAPHX_DEVICE_EXPORT
std::string get_device_name();
} // namespace device
...
...
src/targets/gpu/hiprtc/main.cpp
View file @
151dd91a
...
...
@@ -27,6 +27,7 @@
#include <migraphx/msgpack.hpp>
#include <migraphx/file_buffer.hpp>
#include <migraphx/ranges.hpp>
#include <array>
#include <iostream>
#include <cstring>
...
...
src/targets/gpu/include/migraphx/gpu/ck.hpp
View file @
151dd91a
...
...
@@ -28,6 +28,7 @@
#include <migraphx/env.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/stringutils.hpp>
#include <string_view>
#include "ck/host/device_gemm_multiple_d.hpp"
#include "ck/host/device_batched_gemm_softmax_gemm.hpp"
...
...
@@ -55,7 +56,7 @@ template <class P>
std
::
string
ck_disable_warnings
(
P
p
)
{
return
interpolate_string
(
disable_warning_pragma
,
{{
"content"
,
std
::
string
{
p
.
first
,
p
.
second
}}});
{{
"content"
,
std
::
string
{
p
.
data
(),
p
.
size
()
}}});
}
static
std
::
unordered_map
<
std
::
string
,
std
::
string
>
create_ck_header_strings
()
...
...
@@ -64,8 +65,8 @@ static std::unordered_map<std::string, std::string> create_ck_header_strings()
auto
ck_headers
=
ck
::
host
::
GetHeaders
();
std
::
transform
(
ck_headers
.
begin
(),
ck_headers
.
end
(),
std
::
inserter
(
result
,
result
.
begin
()),
[
&
](
auto
&
&
p
)
{
return
std
::
make_pair
(
p
.
first
,
ck_disable_warnings
(
p
.
second
));
ck_headers
.
begin
(),
ck_headers
.
end
(),
std
::
inserter
(
result
,
result
.
begin
()),
[
&
](
auto
&
p
)
{
return
std
::
pair
<
std
::
string
,
std
::
string
>
(
p
.
first
,
ck_disable_warnings
(
p
.
second
));
});
return
result
;
}
...
...
@@ -74,11 +75,10 @@ static std::vector<src_file> create_ck_headers()
{
static
const
auto
&
header_strings
=
create_ck_header_strings
();
std
::
vector
<
src_file
>
srcs
;
std
::
transform
(
header_strings
.
begin
(),
header_strings
.
end
(),
std
::
back_inserter
(
srcs
),
[
&
](
auto
&&
p
)
{
return
src_file
{
fs
::
path
{
p
.
first
},
{
p
.
second
.
data
(),
p
.
second
.
data
()
+
p
.
second
.
size
()}};
});
std
::
transform
(
header_strings
.
begin
(),
header_strings
.
end
(),
std
::
back_inserter
(
srcs
),
[
&
](
auto
&
p
)
{
return
src_file
{
p
};
});
return
srcs
;
}
...
...
src/targets/gpu/include/migraphx/gpu/compile_hip.hpp
View file @
151dd91a
...
...
@@ -45,10 +45,7 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS);
struct
hiprtc_src_file
{
hiprtc_src_file
()
=
default
;
hiprtc_src_file
(
const
src_file
&
s
)
:
path
(
s
.
path
.
string
()),
content
(
s
.
content
.
first
,
s
.
content
.
second
)
{
}
hiprtc_src_file
(
const
src_file
&
s
)
:
path
(
s
.
path
.
string
()),
content
(
s
.
content
)
{}
std
::
string
path
;
std
::
string
content
;
template
<
class
Self
,
class
F
>
...
...
src/targets/gpu/jit/ck_gemm.cpp
View file @
151dd91a
...
...
@@ -135,7 +135,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
{
const
auto
&
c_shape
=
inputs
.
back
();
auto
tuning_value
=
v
.
get
(
"tuning_value"
,
0
);
auto
tuning_value
=
v
.
get
(
"tuning_value"
,
34
);
auto
batch_count
=
get_batch_count
(
c_shape
);
auto
problem
=
create_problem
(
inputs
,
v
);
...
...
src/targets/gpu/mlir.cpp
View file @
151dd91a
...
...
@@ -320,7 +320,10 @@ struct mlir_program
MlirType
make_tensor
(
const
shape
&
s
)
const
{
assert
(
s
.
standard
());
if
(
not
s
.
standard
())
MIGRAPHX_THROW
(
"MLIR expects all tensors to be in standard shape"
);
if
(
s
.
dynamic
())
MIGRAPHX_THROW
(
"MLIR does not support dynamic shapes"
);
std
::
vector
<
int64_t
>
lens
(
s
.
lens
().
begin
(),
s
.
lens
().
end
());
return
mlirRankedTensorTypeGet
(
lens
.
size
(),
lens
.
data
(),
make_type
(
s
.
type
()),
mlirAttributeGetNull
());
...
...
src/tf/parse_reshape.cpp
View file @
151dd91a
...
...
@@ -45,8 +45,7 @@ struct parse_reshape : op_parser<parse_reshape>
auto
s
=
args
[
1
]
->
eval
();
std
::
vector
<
int64_t
>
dims
;
s
.
visit
([
&
](
auto
v
)
{
copy
(
v
,
std
::
back_inserter
(
dims
));
});
return
info
.
add_instruction
(
make_op
(
"reshape"
,
{{
"dims"
,
dims
}}),
info
.
make_contiguous
(
args
[
0
]));
return
info
.
add_instruction
(
make_op
(
"reshape"
,
{{
"dims"
,
dims
}}),
args
[
0
]);
}
};
...
...
test/gpu/jit.cpp
View file @
151dd91a
...
...
@@ -155,7 +155,7 @@ int main() {}
migraphx
::
src_file
make_src_file
(
const
std
::
string
&
name
,
const
std
::
string
&
content
)
{
return
{
name
,
std
::
make_pair
(
content
.
data
(),
content
.
data
()
+
content
.
size
())
};
return
{
name
,
content
};
}
TEST_CASE
(
simple_compile_hip
)
...
...
test/gpu/stream_sync.cpp
View file @
151dd91a
...
...
@@ -64,7 +64,7 @@ int main() {}
migraphx
::
src_file
make_src_file
(
const
std
::
string
&
name
,
const
std
::
string
&
content
)
{
return
{
name
,
std
::
make_pair
(
content
.
data
(),
content
.
data
()
+
content
.
size
())
};
return
{
name
,
content
};
}
hip_stream_ptr
get_stream
()
...
...
test/include/test.hpp
View file @
151dd91a
...
...
@@ -339,6 +339,8 @@ inline std::ostream& operator<<(std::ostream& os, const color& c)
static
const
bool
use_color
=
isatty
(
STDOUT_FILENO
)
!=
0
;
if
(
use_color
)
return
os
<<
"
\033
["
<<
static_cast
<
std
::
size_t
>
(
c
)
<<
"m"
;
#else
(
void
)
c
;
#endif
return
os
;
}
...
...
test/jit.cpp
View file @
151dd91a
...
...
@@ -48,9 +48,7 @@ compile_function(const std::string& src, const std::string& flags, const std::st
migraphx
::
src_compiler
compiler
;
compiler
.
flags
=
flags
+
"-std=c++14 -fPIC -shared"
;
compiler
.
output
=
"libsimple.so"
;
migraphx
::
src_file
f
;
f
.
path
=
"main.cpp"
;
f
.
content
=
std
::
make_pair
(
src
.
data
(),
src
.
data
()
+
src
.
size
());
migraphx
::
src_file
f
{
"main.cpp"
,
src
};
auto
image
=
compiler
.
compile
({
f
});
return
migraphx
::
dynamic_loader
{
image
}.
get_function
<
F
>
(
fname
);
}
...
...
test/msgpack.cpp
View file @
151dd91a
...
...
@@ -97,9 +97,12 @@ TEST_CASE(test_msgpack_bool)
TEST_CASE
(
test_msgpack_float
)
{
migraphx
::
value
v
=
3.0
;
// changed all double values in this code to not end with .0 because on msgpack for Windows if
// input type is double and ends with .0 it could be converted to uint64_t or int64_t and the
// goal of these functions is to test double without conversions
migraphx
::
value
v
=
3.01
;
auto
buffer
=
migraphx
::
to_msgpack
(
v
);
EXPECT
(
buffer
==
msgpack_buffer
(
3.0
));
EXPECT
(
buffer
==
msgpack_buffer
(
3.0
1
));
EXPECT
(
migraphx
::
from_msgpack
(
buffer
)
==
v
);
}
...
...
@@ -129,10 +132,10 @@ TEST_CASE(test_msgpack_empty_array)
TEST_CASE
(
test_msgpack_object
)
{
migraphx
::
value
v
=
{{
"one"
,
1.0
},
{
"three"
,
3.0
},
{
"two"
,
2.0
}};
migraphx
::
value
v
=
{{
"one"
,
1.0
1
},
{
"three"
,
3.0
1
},
{
"two"
,
2.0
1
}};
auto
buffer
=
migraphx
::
to_msgpack
(
v
);
EXPECT
(
buffer
==
msgpack_buffer
(
std
::
map
<
std
::
string
,
double
>
{
{
"one"
,
1.0
},
{
"three"
,
3.0
},
{
"two"
,
2.0
}}));
{
"one"
,
1.0
1
},
{
"three"
,
3.0
1
},
{
"two"
,
2.0
1
}}));
EXPECT
(
migraphx
::
from_msgpack
(
buffer
)
==
v
);
}
...
...
@@ -157,17 +160,17 @@ struct foo
TEST_CASE
(
test_msgpack_object_class
)
{
migraphx
::
value
v
=
{{
"a"
,
1.0
},
{
"b"
,
"abc"
}};
migraphx
::
value
v
=
{{
"a"
,
1.0
1
},
{
"b"
,
"abc"
}};
auto
buffer
=
migraphx
::
to_msgpack
(
v
);
EXPECT
(
buffer
==
msgpack_buffer
(
foo
{
1.0
,
"abc"
}));
EXPECT
(
buffer
==
msgpack_buffer
(
foo
{
1.0
1
,
"abc"
}));
EXPECT
(
migraphx
::
from_msgpack
(
buffer
)
==
v
);
}
TEST_CASE
(
test_msgpack_array_class
)
{
migraphx
::
value
v
=
{{{
"a"
,
1.0
},
{
"b"
,
"abc"
}},
{{
"a"
,
3.0
},
{
"b"
,
"xyz"
}}};
migraphx
::
value
v
=
{{{
"a"
,
1.0
1
},
{
"b"
,
"abc"
}},
{{
"a"
,
3.0
1
},
{
"b"
,
"xyz"
}}};
auto
buffer
=
migraphx
::
to_msgpack
(
v
);
EXPECT
(
buffer
==
msgpack_buffer
(
std
::
vector
<
foo
>
{
foo
{
1.0
,
"abc"
},
foo
{
3.0
,
"xyz"
}}));
EXPECT
(
buffer
==
msgpack_buffer
(
std
::
vector
<
foo
>
{
foo
{
1.0
1
,
"abc"
},
foo
{
3.0
1
,
"xyz"
}}));
EXPECT
(
migraphx
::
from_msgpack
(
buffer
)
==
v
);
}
...
...
test/multi_target/multitarget_test.cpp
View file @
151dd91a
...
...
@@ -37,7 +37,6 @@
#include <migraphx/make_op.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/functional.hpp>
#include <basic_ops.hpp>
#include <migraphx/compile_options.hpp>
#include <migraphx/register_target.hpp>
#include <migraphx/generate.hpp>
...
...
test/onnx/gen_onnx.py
View file @
151dd91a
...
...
@@ -5149,6 +5149,61 @@ def prelu_brcst_test():
return
([
node
],
[
arg0
,
arg1
],
[
arg_out
])
@
onnx_test
()
def
qlinearadd_test
():
a
=
helper
.
make_tensor_value_info
(
'A'
,
TensorProto
.
UINT8
,
[
64
])
sc_a
=
helper
.
make_tensor
(
'A_scale'
,
TensorProto
.
FLOAT
,
[],
[
0.05
])
zero_pt_a
=
helper
.
make_tensor
(
'A_zero_point'
,
TensorProto
.
UINT8
,
[],
[
0
])
b
=
helper
.
make_tensor_value_info
(
'B'
,
TensorProto
.
UINT8
,
[
64
])
sc_b
=
helper
.
make_tensor
(
'B_scale'
,
TensorProto
.
FLOAT
,
[],
[
0.05
])
zero_pt_b
=
helper
.
make_tensor
(
'B_zero_point'
,
TensorProto
.
UINT8
,
[],
[
128
])
sc_c
=
helper
.
make_tensor
(
'C_scale'
,
TensorProto
.
FLOAT
,
[],
[
0.05
])
zero_pt_c
=
helper
.
make_tensor
(
'C_zero_point'
,
TensorProto
.
UINT8
,
[],
[
64
])
c
=
helper
.
make_tensor_value_info
(
'C'
,
TensorProto
.
UINT8
,
[
64
])
node
=
onnx
.
helper
.
make_node
(
'QLinearAdd'
,
inputs
=
[
'A'
,
'A_scale'
,
'A_zero_point'
,
'B'
,
'B_scale'
,
'B_zero_point'
,
'C_scale'
,
'C_zero_point'
],
outputs
=
[
'C'
],
)
return
([
node
],
[
a
,
b
],
[
c
],
[
sc_a
,
zero_pt_a
,
sc_b
,
zero_pt_b
,
sc_c
,
zero_pt_c
])
@
onnx_test
()
def
qlinearadd_bcast_test
():
a
=
helper
.
make_tensor_value_info
(
'A'
,
TensorProto
.
INT8
,
[
64
])
sc_a
=
helper
.
make_tensor
(
'A_scale'
,
TensorProto
.
FLOAT
,
[],
[
0.05
])
zero_pt_a
=
helper
.
make_tensor
(
'A_zero_point'
,
TensorProto
.
INT8
,
[],
[
0
])
b
=
helper
.
make_tensor_value_info
(
'B'
,
TensorProto
.
INT8
,
[
1
,
1
,
64
])
sc_b
=
helper
.
make_tensor
(
'B_scale'
,
TensorProto
.
FLOAT
,
[],
[
0.05
])
zero_pt_b
=
helper
.
make_tensor
(
'B_zero_point'
,
TensorProto
.
INT8
,
[],
[
32
])
sc_c
=
helper
.
make_tensor
(
'C_scale'
,
TensorProto
.
FLOAT
,
[],
[
0.05
])
zero_pt_c
=
helper
.
make_tensor
(
'C_zero_point'
,
TensorProto
.
INT8
,
[],
[
-
64
])
c
=
helper
.
make_tensor_value_info
(
'C'
,
TensorProto
.
INT8
,
[
1
,
1
,
64
])
node
=
onnx
.
helper
.
make_node
(
'QLinearAdd'
,
inputs
=
[
'A'
,
'A_scale'
,
'A_zero_point'
,
'B'
,
'B_scale'
,
'B_zero_point'
,
'C_scale'
,
'C_zero_point'
],
outputs
=
[
'C'
],
)
return
([
node
],
[
a
,
b
],
[
c
],
[
sc_a
,
zero_pt_a
,
sc_b
,
zero_pt_b
,
sc_c
,
zero_pt_c
])
@
onnx_test
()
def
quantizelinear_test
():
arg0
=
helper
.
make_tensor_value_info
(
'0'
,
TensorProto
.
FLOAT
,
[
5
])
...
...
test/onnx/onnx_test.cpp
View file @
151dd91a
...
...
@@ -1772,8 +1772,7 @@ TEST_CASE(depthtospace_test)
mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), l0);
auto tmp2 = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), tmp1);
auto tmp3 = mm->add_instruction(migraphx::make_op("contiguous"), tmp2);
mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 10, 10}}}), tmp3);
mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 10, 10}}}), tmp2);
auto prog = optimize_onnx("depthtospace_test.onnx");
EXPECT(p == prog);
}
...
...
@@ -1787,8 +1786,7 @@ TEST_CASE(depthtospace_crd_test)
mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 2, 2, 5, 5}}}), l0);
auto tmp2 = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 1, 4, 2, 5, 3}}}), tmp1);
auto tmp3 = mm->add_instruction(migraphx::make_op("contiguous"), tmp2);
mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 10, 10}}}), tmp3);
mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 10, 10}}}), tmp2);
auto prog = optimize_onnx("depthtospace_crd_test.onnx");
EXPECT(p == prog);
}
...
...
@@ -1802,8 +1800,7 @@ TEST_CASE(depthtospace_simple_test)
mm->add_instruction(migraphx::make_op("reshape", {{"dims", {1, 2, 2, 2, 2, 3}}}), l0);
auto tmp2 = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 4, 1, 5, 2}}}), tmp1);
auto tmp3 = mm->add_instruction(migraphx::make_op("contiguous"), tmp2);
mm->add_instruction(migraphx::make_op("reshape", {{"dims", {1, 2, 4, 6}}}), tmp3);
mm->add_instruction(migraphx::make_op("reshape", {{"dims", {1, 2, 4, 6}}}), tmp2);
auto prog = optimize_onnx("depthtospace_simple_test.onnx");
EXPECT(p == prog);
}
...
...
@@ -1817,8 +1814,7 @@ TEST_CASE(spacetodepth_test)
mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 2, 5, 2, 5, 2}}}), l0);
auto tmp2 = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 5, 1, 2, 4}}}), tmp1);
auto tmp3 = mm->add_instruction(migraphx::make_op("contiguous"), tmp2);
mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 8, 5, 5}}}), tmp3);
mm->add_instruction(migraphx::make_op("reshape", {{"dims", {2, 8, 5, 5}}}), tmp2);
auto prog = optimize_onnx("spacetodepth_test.onnx");
EXPECT(p == prog);
}
...
...
@@ -1832,8 +1828,7 @@ TEST_CASE(spacetodepth_simple_test)
mm->add_instruction(migraphx::make_op("reshape", {{"dims", {1, 2, 2, 2, 3, 2}}}), l0);
auto tmp2 = mm->add_instruction(
migraphx::make_op("transpose", {{"permutation", {0, 3, 5, 1, 2, 4}}}), tmp1);
auto tmp3 = mm->add_instruction(migraphx::make_op("contiguous"), tmp2);
mm->add_instruction(migraphx::make_op("reshape", {{"dims", {1, 8, 2, 3}}}), tmp3);
mm->add_instruction(migraphx::make_op("reshape", {{"dims", {1, 8, 2, 3}}}), tmp2);
auto prog = optimize_onnx("spacetodepth_simple_test.onnx");
EXPECT(p == prog);
}
...
...
@@ -4856,6 +4851,59 @@ TEST_CASE(prelu_brcst_test)
EXPECT(p == prog);
}
TEST_CASE(qlinearadd_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto a = mm->add_parameter("A", {migraphx::shape::uint8_type, {64}});
auto b = mm->add_parameter("B", {migraphx::shape::uint8_type, {64}});
auto sc_a = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {0.05}});
auto z_pt_a = mm->add_literal(migraphx::literal{migraphx::shape::uint8_type, {0}});
auto sc_b = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {0.05}});
auto z_pt_b = mm->add_literal(migraphx::literal{migraphx::shape::uint8_type, {128}});
auto sc_c = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {0.05}});
auto z_pt_c = mm->add_literal(migraphx::literal{migraphx::shape::uint8_type, {64}});
auto scale_a_bcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {64}}}), sc_a);
auto z_pt_a_bcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {64}}}), z_pt_a);
auto fp_a =
mm->add_instruction(migraphx::make_op("dequantizelinear"), a, scale_a_bcast, z_pt_a_bcast);
auto scale_b_bcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {64}}}), sc_b);
auto z_pt_b_bcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {64}}}), z_pt_b);
auto fp_b =
mm->add_instruction(migraphx::make_op("dequantizelinear"), b, scale_b_bcast, z_pt_b_bcast);
auto fp_c = mm->add_instruction(migraphx::make_op("add"), fp_a, fp_b);
auto scale_c_bcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {64}}}), sc_c);
auto z_pt_c_bcast =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {64}}}), z_pt_c);
auto c =
mm->add_instruction(migraphx::make_op("quantizelinear"), fp_c, scale_c_bcast, z_pt_c_bcast);
mm->add_return({c});
auto prog = migraphx::parse_onnx("qlinearadd_test.onnx");
EXPECT(p.sort() == prog.sort());
}
TEST_CASE(quantizelinear_test)
{
migraphx::program p;
...
...
@@ -5438,12 +5486,9 @@ TEST_CASE(reshape_test)
migraphx::literal{migraphx::shape{migraphx::shape::int64_type, {2}}, reshape_dims});
auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {4, 2, 3}});
op.dims = reshape_dims;
auto c0 = mm->add_instruction(migraphx::make_op("contiguous"), l0);
mm->add_instruction(op, c0);
auto c1 = mm->add_instruction(migraphx::make_op("contiguous"), l0);
mm->add_instruction(op, c1);
mm->add_instruction(op, l0);
mm->add_instruction(op, l0);
auto prog = optimize_onnx("reshape_test.onnx");
EXPECT(p == prog);
}
...
...
@@ -5456,8 +5501,7 @@ TEST_CASE(reshape_non_standard_test)
auto x = mm->add_parameter("x", s);
auto tran_x =
mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}}), x);
auto cont_x = mm->add_instruction(migraphx::make_op("contiguous"), tran_x);
mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4, 3, 2}}}), cont_x);
mm->add_instruction(migraphx::make_op("reshape", {{"dims", {4, 3, 2}}}), tran_x);
auto prog = optimize_onnx("reshape_non_standard_test.onnx");
EXPECT(p == prog);
...
...
test/onnx/qlinearadd_bcast_test.onnx
0 → 100644
View file @
151dd91a
File added
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment