gaoqiong / MIGraphX · Commit 13a7d458 (Unverified)

Merge branch 'develop' into perk-kernel

Authored Aug 19, 2022 by Paul Fultz II; committed by GitHub on Aug 19, 2022.
Parents: 3852e43b, 8045f7c8
Showing 20 changed files with 686 additions and 151 deletions.
.github/workflows/performance.yaml (+1 −3)
src/api/include/migraphx/migraphx.hpp (+18 −18)
src/py/migraphx_py.cpp (+30 −11)
src/simplify_reshapes.cpp (+92 −3)
src/targets/gpu/compile_gen.cpp (+65 −19)
src/targets/gpu/fuse_ops.cpp (+29 −2)
src/targets/gpu/include/migraphx/gpu/compile_gen.hpp (+5 −0)
src/targets/gpu/jit/layernorm.cpp (+133 −0)
src/targets/gpu/jit/pointwise.cpp (+8 −40)
src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp (+3 −2)
src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp (+83 −0)
src/targets/gpu/kernels/include/migraphx/kernels/ops.hpp (+2 −2)
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp (+19 −0)
src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp (+5 −3)
src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp (+1 −1)
src/targets/gpu/prefuse_ops.cpp (+49 −37)
test/include/test.hpp (+12 −9)
test/simplify_reshapes_test.cpp (+87 −0)
test/verify/test_layernorm.cpp (+1 −1)
test/verify/test_softmax_large1.cpp (+43 −0)
.github/workflows/performance.yaml (+1 −3)

 name: MIGraphX Performance Tests
 on:
   push:
     branches: [ develop ]
   pull_request:
     branches: [ develop ]
+    types: [ opened, synchronize, closed ]
   schedule:
     - cron: "0 5 * * 1-6"
...
src/api/include/migraphx/migraphx.hpp (+18 −18)

@@ -517,7 +517,7 @@ struct shape : MIGRAPHX_CONST_HANDLE_BASE(shape)
     MIGRAPHX_DEPRECATED("Contructor without lifetime annotation is deprecated.")
     shape(const migraphx_shape* p) { this->set_handle(p, borrow{}); }

-    MIGRAPHX_HANDLE_CONSTRUCTOR(shape);
+    MIGRAPHX_HANDLE_CONSTRUCTOR(shape)

     /// Construct a scalar shape
     shape(migraphx_shape_datatype_t type)

The same one-line change, dropping the trailing semicolon after MIGRAPHX_HANDLE_CONSTRUCTOR(...), is applied to every other handle wrapper in this header:

@@ -601,7 +601,7 @@ struct argument : MIGRAPHX_CONST_HANDLE_BASE(argument) (argument)
@@ -655,7 +655,7 @@ struct target : MIGRAPHX_HANDLE_BASE(target) (target)
@@ -665,7 +665,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes) (program_parameter_shapes)
@@ -695,7 +695,7 @@ struct program_parameter_shapes : MIGRAPHX_HANDLE_BASE(program_parameter_shapes) (program_parameters)
@@ -722,7 +722,7 @@ struct program_parameters : MIGRAPHX_HANDLE_BASE(program_parameters) (arguments)
@@ -741,7 +741,7 @@ struct arguments : MIGRAPHX_HANDLE_BASE(arguments), array_base<arguments> (shapes)
@@ -760,7 +760,7 @@ struct shapes : MIGRAPHX_HANDLE_BASE(shapes), array_base<shapes> (operation)
@@ -778,12 +778,12 @@ struct operation : MIGRAPHX_HANDLE_BASE(operation) (instruction and instructions)
@@ -797,7 +797,7 @@ struct module; (modules)
@@ -911,7 +911,7 @@ struct compile_options : MIGRAPHX_HANDLE_BASE(compile_options) (compile_options)
@@ -935,7 +935,7 @@ struct program : MIGRAPHX_HANDLE_BASE(program) (program)
@@ -1021,7 +1021,7 @@ struct program : MIGRAPHX_HANDLE_BASE(program) (file_options)
@@ -1063,7 +1063,7 @@ struct onnx_options : MIGRAPHX_HANDLE_BASE(onnx_options) (onnx_options)
@@ -1145,7 +1145,7 @@ struct tf_options : MIGRAPHX_HANDLE_BASE(tf_options) (tf_options)
@@ -1198,7 +1198,7 @@ struct quantize_op_names : MIGRAPHX_HANDLE_BASE(quantize_op_names) (quantize_op_names)
@@ -1223,7 +1223,7 @@ struct quantize_int8_options : MIGRAPHX_HANDLE_BASE(quantize_int8_options) (quantize_int8_options)
...
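A plausible motivation for removing the semicolons (an assumption, not stated in the commit) is that the macro already expands to a complete member definition, so the extra semicolon leaves an empty declaration that pedantic builds flag (for example with -Wextra-semi). A minimal standalone sketch of that situation, using an invented MAKE_DEFAULT_CTOR macro rather than the real MIGraphX one:

#include <cstdio>

// Hypothetical macro: expands to a full constructor definition, braces included.
#define MAKE_DEFAULT_CTOR(cls) \
    cls() { std::puts(#cls " constructed"); }

struct widget
{
    MAKE_DEFAULT_CTOR(widget) // no trailing ';' needed; adding one is an empty declaration
};

int main() { widget w; }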
src/py/migraphx_py.cpp (+30 −11)

@@ -40,6 +40,7 @@
 #include <migraphx/register_target.hpp>
 #include <migraphx/json.hpp>
 #include <migraphx/make_op.hpp>
+#include <migraphx/op/common.hpp>
 #ifdef HAVE_GPU
 #include <migraphx/gpu/hip.hpp>

@@ -82,7 +83,7 @@ void visit_py(T x, F f)
     {
         f(x.template cast<bool>());
     }
-    else if(py::isinstance<py::int_>(x))
+    else if(py::isinstance<py::int_>(x) || py::hasattr(x, "__index__"))
     {
         f(x.template cast<int>());
     }

@@ -324,6 +325,7 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
         .def("get_parameter_names", &migraphx::program::get_parameter_names)
         .def("get_parameter_shapes", &migraphx::program::get_parameter_shapes)
         .def("get_output_shapes", &migraphx::program::get_output_shapes)
+        .def("is_compiled", &migraphx::program::is_compiled)
         .def(
             "compile",
             [](migraphx::program& p, const migraphx::target& t, bool offload_copy, bool fast_math) {

@@ -358,8 +360,8 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
         .def("__ne__", std::not_equal_to<migraphx::program>{})
         .def("__repr__", [](const migraphx::program& p) { return migraphx::to_string(p); });

-    py::class_<migraphx::operation>(m, "op")
-        .def(py::init([](const std::string& name, py::kwargs kwargs) {
+    py::class_<migraphx::operation> op(m, "op");
+    op.def(py::init([](const std::string& name, py::kwargs kwargs) {
         migraphx::value v = migraphx::value::object{};
         if(kwargs)
         {

@@ -367,9 +369,26 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
         }
         return migraphx::make_op(name, v);
     }))
         .def("name", &migraphx::operation::name);

+    py::enum_<migraphx::op::pooling_mode>(op, "pooling_mode")
+        .value("average", migraphx::op::pooling_mode::average)
+        .value("max", migraphx::op::pooling_mode::max)
+        .value("lpnorm", migraphx::op::pooling_mode::lpnorm);
+    py::enum_<migraphx::op::rnn_direction>(op, "rnn_direction")
+        .value("forward", migraphx::op::rnn_direction::forward)
+        .value("reverse", migraphx::op::rnn_direction::reverse)
+        .value("bidirectional", migraphx::op::rnn_direction::bidirectional);
+
+    m.def(
+        "argument_from_pointer",
+        [](const migraphx::shape shape, const int64_t address) {
+            return migraphx::argument(shape, reinterpret_cast<void*>(address));
+        },
+        py::arg("shape"),
+        py::arg("address"));
+
     m.def(
         "parse_tf",
         [](const std::string& filename,
...
src/simplify_reshapes.cpp (+92 −3)

@@ -151,8 +151,11 @@ struct find_transpose
 {
     auto matcher() const
     {
-        return match::name("transpose")(match::none_of(
-            match::skip_output(match::name("contiguous"))(match::name("transpose"))));
+        auto output_not_transpose =
+            match::none_of(match::skip_output(match::name("contiguous"))(match::name("transpose")));
+        auto input_has_transpose =
+            match::args(match::skip(match::name("contiguous"))(match::name("transpose")));
+        return match::name("transpose")(output_not_transpose, input_has_transpose);
     }

     void apply(module& m, const match::matcher_result& mr) const

@@ -664,9 +667,94 @@ struct find_slice_transpose
     }
 };

+struct find_transpose_slice
+{
+    auto matcher() const
+    {
+        return match::name("transpose")(match::all_of[match::outputs()](match::name("slice")));
+    }
+
+    static std::vector<int64_t> slice_distance(const op::slice& op)
+    {
+        assert(op.starts.size() == op.ends.size());
+        std::vector<int64_t> result(op.starts.size());
+        std::transform(
+            op.ends.begin(), op.ends.end(), op.starts.begin(), result.begin(), std::minus<>{});
+        return result;
+    }
+
+    void apply(module& m, const match::matcher_result& r) const
+    {
+        auto ins    = r.result;
+        auto slices = ins->outputs();
+        if(slices.empty())
+            return;
+        auto slice     = any_cast<op::slice>(slices.front()->get_operator());
+        auto sdistance = slice_distance(slice);
+        // Check all distances and axes are the same
+        if(std::any_of(slices.begin(), slices.end(), [&](auto sins) {
+               auto s = any_cast<op::slice>(sins->get_operator());
+               return s.axes != slice.axes or slice_distance(s) != sdistance;
+           }))
+            return;
+        // Check distances are divisible by lens of corresponding axes
+        auto mod_by_distance = [&](const auto& v, auto f) {
+            return std::inner_product(v.begin(),
+                                      v.end(),
+                                      sdistance.begin(),
+                                      0,
+                                      std::plus<>{},
+                                      [&](auto x, auto d) -> uint64_t {
+                                          if(d == 0)
+                                              return 1;
+                                          return f(x) % d;
+                                      });
+        };
+        if(mod_by_distance(slice.axes, [&](auto x) { return ins->get_shape().lens()[x]; }) != 0 or
+           mod_by_distance(slice.starts, id{}) != 0 or mod_by_distance(slice.ends, id{}) != 0)
+            return;
+        // TODO: Handle multiple axes
+        if(sdistance.size() != 1)
+            return;
+        auto axis = slice.axes.front();
+        // Skip if axis would be packed
+        if(std::all_of(ins->get_shape().lens().begin(),
+                       ins->get_shape().lens().begin() + axis,
+                       [](auto x) { return x == 1; }))
+            return;
+        // Compute axis before transpose to use for unsqueeze
+        auto perm    = ins->get_operator().to_value()["permutation"].to_vector<int64_t>();
+        auto preaxis = std::find(perm.begin(), perm.end(), axis) - perm.begin();
+        // Make unsqeeze
+        auto unsqueeze = m.insert_instruction(
+            ins, make_op("unsqueeze", {{"axes", {preaxis}}, {"steps", sdistance}}), ins->inputs());
+        // Make transpose
+        std::transform(perm.begin(), perm.end(), perm.begin(), [&](auto i) {
+            if(i > preaxis)
+                return i + 1;
+            return i;
+        });
+        perm.insert(perm.begin(), preaxis + 1);
+        auto transpose =
+            m.insert_instruction(ins, make_op("transpose", {{"permutation", perm}}), unsqueeze);
+        // Slice and squeeze
+        for(auto s : slices)
+        {
+            auto op   = any_cast<op::slice>(s->get_operator());
+            op.axes   = {0};
+            op.starts = {op.starts.front() / sdistance.front()};
+            op.ends   = {op.ends.front() / sdistance.front()};
+            auto slice_ins = m.insert_instruction(ins, op, transpose);
+            auto squeeze   = m.insert_instruction(ins, make_op("squeeze", {{"axes", {0}}}), slice_ins);
+            m.replace_instruction(s, squeeze);
+        }
+    }
+};
+
 void simplify_reshapes::apply(module& m) const
 {
-    for(int i = 0; i < 2; i++)
+    for(int i = 0; i < 4; i++)
     {
         match::find_matches(m,
                             find_where_op{},

@@ -679,6 +767,7 @@ void simplify_reshapes::apply(module& m) const
                             find_nested_convert{},
                             find_nested_slice{},
                             find_nested_concat{},
+                            find_transpose_slice{},
                             find_slice_transpose{},
                             find_transpose_contiguous_reshaper_unary{});
         dead_code_elimination{}.apply(m);
...
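The divisibility check in find_transpose_slice relies on std::inner_product summing a per-element remainder, so a total of zero means every length, start, and end is a multiple of the common slice distance. A small standalone sketch of that idiom (the values below are illustrative, not taken from the pass):

#include <cstdint>
#include <cstdio>
#include <numeric>
#include <vector>

int main()
{
    std::vector<int64_t> values    = {36, 24, 12}; // e.g. axis length, start, end
    std::vector<int64_t> distances = {12, 12, 12}; // common slice distance per entry
    // Accumulate value % distance over the pairs; 0 means everything divides evenly.
    auto remainder_sum = std::inner_product(
        values.begin(), values.end(), distances.begin(), int64_t{0}, std::plus<>{},
        [](int64_t x, int64_t d) -> int64_t { return d == 0 ? 1 : x % d; });
    std::printf("all divisible: %s\n", remainder_sum == 0 ? "yes" : "no");
    return 0;
}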
src/targets/gpu/compile_gen.cpp (+65 −19)

@@ -25,6 +25,13 @@
 #include <migraphx/shape.hpp>
 #include <migraphx/permutation.hpp>
 #include <migraphx/stringutils.hpp>
+#include <migraphx/module.hpp>
+#include <migraphx/dead_code_elimination.hpp>
+#include <migraphx/eliminate_common_subexpression.hpp>
+#include <migraphx/cpp_generator.hpp>
+#include <migraphx/pass_manager.hpp>
+#include <migraphx/instruction.hpp>
+#include <migraphx/ranges.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

@@ -75,25 +82,25 @@ std::string vectorize::str() const
 preload preload::broadcasts(std::size_t axis, const std::vector<shape>& inputs)
 {
     const std::size_t max_lds_bytes = 4096;
-    std::vector<bool> result;
-    std::transform(inputs.begin(),
-                   inputs.end(),
-                   std::back_inserter(result),
-                   [&](const shape& input) { return input.strides()[axis] == 0; });
-    auto bytes = std::inner_product(inputs.begin(),
-                                    inputs.end(),
-                                    result.begin(),
-                                    std::size_t{0},
-                                    std::plus<>{},
-                                    [](const shape& s, bool b) -> std::size_t {
-                                        if(b)
-                                            return s.bytes();
-                                        return 0;
-                                    });
-    if(bytes < max_lds_bytes)
-        return {result};
-    // TODO: Try to partially preload items
-    std::fill(result.begin(), result.end(), false);
+    std::vector<bool> result(inputs.size());
+    std::vector<std::size_t> preloaded;
+    auto idxs = range(inputs.size());
+    std::copy_if(idxs.begin(), idxs.end(), std::back_inserter(preloaded), [&](auto i) {
+        return inputs[i].strides()[axis] == 0;
+    });
+    std::sort(preloaded.begin(), preloaded.end(), by(std::less<>{}, [&](auto i) {
+        return inputs[i].bytes();
+    }));
+    std::size_t bytes = 0;
+    for(auto i : preloaded)
+    {
+        auto input = inputs[i];
+        bytes += input.bytes();
+        if(bytes > max_lds_bytes)
+            break;
+        result[i] = true;
+    }
     return {result};
 }

@@ -125,6 +132,45 @@ std::string make_transformer_args(std::vector<std::string> transformers)
     return join_strings(std::move(transformers), ", ");
 }

+std::string generate_pointwise(const module& pm, const std::string& name)
+{
+    module m = pm;
+    run_passes(m, {eliminate_common_subexpression{}, dead_code_elimination{}});
+    cpp_generator g;
+    g.fmap([](const std::string& fname) { return "migraphx::" + fname; });
+    g.add_point_op("where", "${function:where}(${0}, ${1}, ${2})");
+    g.add_point_op("prelu", "${function:where}(${0} < 0, ${0} * ${1}, ${0})");
+    g.add_point_op("sign", "${function:where}(${0} > 0, 1, ${function:where}(${0} < 0, -1, 0))");
+    g.add_point_op("equal", "migraphx::abs(${0} == ${1})");
+    g.add_point_op("less", "migraphx::abs(${0} < ${1})");
+    g.add_point_op("greater", "migraphx::abs(${0} > ${1})");
+    g.add_point_op("not", "migraphx::abs(not ${0})");
+    // Add explict conversions
+    g.fresult(
+        [](const shape& s) { return "migraphx::convert<" + shape::cpp_type(s.type()) + ">"; });
+    g.create_function(
+        g.generate_module(m).set_attributes({"__device__"}).set_generic_types(m).set_name(name));
+    return g.str();
+}
+
+static std::vector<std::string> get_op_names(const module& m)
+{
+    std::vector<std::string> result;
+    for(auto& ins : m)
+    {
+        if(starts_with(ins.name(), "@"))
+            continue;
+        result.push_back(ins.name());
+    }
+    return result;
+}
+
+std::string generate_name_from_ops(const module& m)
+{
+    auto op_names = get_op_names(m);
+    return join_strings(op_names, "_");
+}
+
 } // namespace gen
 } // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
...
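The rewritten preload::broadcasts switches from an all-or-nothing decision to a greedy selection: broadcast inputs are sorted by size and admitted smallest-first until the 4096-byte LDS budget runs out. A standalone sketch of that selection strategy (simplified, with made-up byte counts):

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <numeric>
#include <vector>

int main()
{
    const std::size_t max_lds_bytes = 4096;
    // Hypothetical sizes of the broadcast inputs, in bytes.
    std::vector<std::size_t> candidate_bytes = {3000, 512, 1024};

    // Visit candidates smallest-first.
    std::vector<std::size_t> order(candidate_bytes.size());
    std::iota(order.begin(), order.end(), std::size_t{0});
    std::sort(order.begin(), order.end(), [&](std::size_t a, std::size_t b) {
        return candidate_bytes[a] < candidate_bytes[b];
    });

    // Greedily mark inputs for preloading until the budget is exceeded.
    std::vector<bool> preload(candidate_bytes.size(), false);
    std::size_t used = 0;
    for(auto i : order)
    {
        used += candidate_bytes[i];
        if(used > max_lds_bytes)
            break;
        preload[i] = true;
    }
    for(std::size_t i = 0; i < preload.size(); ++i)
        std::printf("input %zu preloaded: %d\n", i, static_cast<int>(preload[i]));
    return 0;
}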
src/targets/gpu/fuse_ops.cpp (+29 −2)

@@ -827,13 +827,14 @@ void apply_conv_bias(context& ctx, module& m, const match::matcher_result& r)
     m.replace_instruction(ins, cb, input_ins, weights_ins, old_ws_ins, bias_ins, alloc_ins);
 }

-inline auto precompile_name(std::string s) // NOLINT
+template <class... Strings>
+inline auto precompile_name(Strings... names) // NOLINT
 {
     return match::make_basic_pred_matcher([=](instruction_ref ins) {
         if(ins->name() != "gpu::precompile_op")
             return false;
         auto op = from_value<operation>(ins->get_operator().to_value().at("op"));
-        return (op.name() == s);
+        return (contains({names...}, op.name()));
     });
 }

@@ -1041,6 +1042,31 @@ struct find_contiguous_pointwise
     }
 };

+struct find_layernorm_pointwise
+{
+    auto matcher() const
+    {
+        return precompile_name("pointwise")(match::arg(0)(
+            precompile_name("gpu::prelayernorm", "gpu::preadd_layernorm").bind("layernorm")));
+    }
+
+    void apply(module& m, const match::matcher_result& r) const
+    {
+        auto ins       = r.result;
+        auto layernorm = r.instructions["layernorm"];
+        auto* pm       = ins->module_inputs().front();
+        if(not layernorm->module_inputs().empty())
+            return;
+        auto inputs = layernorm->inputs();
+        inputs.pop_back();
+        inputs.insert(inputs.end(), ins->inputs().begin() + 1, ins->inputs().end());
+
+        m.replace_instruction(ins, layernorm->get_operator(), inputs, {pm});
+    }
+};
+
 void fuse_ops::apply(module& m) const
 {
     match::find_matches(m, find_contiguous_pointwise{}, find_gelu{}, find_gelu_new{fast_math});

@@ -1063,6 +1089,7 @@ void fuse_ops::apply(module& m) const
     match::find_matches(m,
                         find_triadd_layernorm{},
                         find_gemm_add{},
+                        find_layernorm_pointwise{},
                         find_gemm_pointwise{},
                         find_commutative_broadcast{});
     match::find_matches(m, find_contiguous{});
...
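The variadic precompile_name above expands its parameter pack into a braced list so a single matcher can accept several precompiled operator names. A minimal self-contained sketch of that expansion pattern (the helper name and values are invented for illustration):

#include <algorithm>
#include <cstdio>
#include <initializer_list>
#include <string>

// Accepts any number of candidate names and checks membership, mirroring contains({names...}, ...).
template <class... Strings>
bool name_is_one_of(const std::string& name, Strings... names)
{
    auto candidates = {std::string(names)...}; // pack expanded into an initializer_list
    return std::any_of(candidates.begin(), candidates.end(),
                       [&](const std::string& s) { return s == name; });
}

int main()
{
    bool hit = name_is_one_of("gpu::preadd_layernorm", "gpu::prelayernorm", "gpu::preadd_layernorm");
    std::printf("matched: %d\n", static_cast<int>(hit)); // prints 1
    return 0;
}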
src/targets/gpu/include/migraphx/gpu/compile_gen.hpp (+5 −0)

@@ -25,6 +25,7 @@
 #define MIGRAPHX_GUARD_GPU_COMPILE_GEN_HPP

 #include <migraphx/config.hpp>
+#include <migraphx/module_ref.hpp>
 #include <string>
 #include <unordered_map>
 #include <vector>

@@ -62,6 +63,10 @@ std::string make_transformer_args(Ts... xs)
     return make_transformer_args({xs.str()...});
 }

+std::string generate_pointwise(const module& pm, const std::string& name);
+
+std::string generate_name_from_ops(const module& m);
+
 } // namespace gen
 } // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
...
src/targets/gpu/jit/layernorm.cpp (new file, +133 −0)

/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

using namespace migraphx::gpu::gen; // NOLINT

static const char* const layernorm_kernel = R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/layernorm.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <migraphx/kernels/preload.hpp>
#include <args.hpp>

namespace migraphx {

${preamble}

extern "C" {
__global__ void ${kernel}(${params})
{
    auto idx = make_index();
    transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto... xs) {
        ${layernorm}<${axis}>(${post}, xs...);
    });
}

}

} // namespace migraphx

)__migraphx__";

struct layernorm_compiler : compiler<layernorm_compiler>
{
    std::vector<std::string> names() const
    {
        return {"layernorm", "gpu::prelayernorm", "gpu::preadd_layernorm"};
    }

    operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
    {
        // TODO: Use reduce_dims
        auto axis  = inputs.front().lens().size() - 1;
        auto faxis = find_fast_axis({inputs.front()});
        vectorize vec{};
        // Vectorize if the axis is a reduction axis
        if(axis == faxis)
        {
            vec = vectorize::elements(faxis, inputs);
        }
        auto preloads  = preload::broadcasts(axis, inputs);
        auto relements = inputs[0].lens()[axis] / vec.size;
        auto nelements = (inputs.back().elements() / inputs[0].lens()[axis]);

        auto block_size = compute_block_size(relements, 256);
        hip_compile_options options;
        options.set_launch_params(
            v, compute_global_for(ctx, nelements * block_size, 256), block_size);
        options.output      = inputs.back();
        options.inputs      = inputs;
        options.kernel_name = v.get("kernel", "layernorm_kernel");

        auto src = interpolate_string(layernorm_kernel,
                                      {{"kernel", options.kernel_name},
                                       {"params", enum_params(inputs.size(), "void * private_p")},
                                       {"args", enum_params(inputs.size(), "private_p")},
                                       {"transformers", make_transformer_args(preloads, vec)},
                                       {"post", v.get("post", std::string{"op::id{}"})},
                                       {"preamble", v.get("preamble", std::string{})},
                                       {"layernorm", v.get("layernorm", std::string{"layernorm"})},
                                       {"axis", to_string(axis)}});

        return compile_hip_code_object(src, options);
    }

    compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
    {
        auto v         = op.to_value();
        v["layernorm"] = "layernorm";
        v["kernel"]    = "layernorm_kernel";
        if(op.name() == "gpu::preadd_layernorm")
        {
            v["layernorm"] = "add_layernorm";
            v["kernel"]    = "add_layernorm_kernel";
        }
        if(not ins->module_inputs().empty())
        {
            auto* pm      = ins->module_inputs().front();
            v["preamble"] = generate_pointwise(*pm, "post_layernorm");
            v["post"]     = "MIGRAPHX_LIFT(post_layernorm)";
            v["kernel"]   = v["layernorm"].to<std::string>() + "_" + generate_name_from_ops(*pm) +
                            "_kernel";
        }
        return replace(compile_op(ctx, to_shapes(ins->inputs()), v));
    }
};

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
src/targets/gpu/jit/pointwise.cpp (+8 −40)

@@ -65,18 +65,6 @@ __global__ void ${kernel}(${params})
 )__migraphx__";

-static std::vector<std::string> get_op_names(const module& m)
-{
-    std::vector<std::string> result;
-    for(auto& ins : m)
-    {
-        if(starts_with(ins.name(), "@"))
-            continue;
-        result.push_back(ins.name());
-    }
-    return result;
-}
-
 struct pointwise_compiler : compiler<pointwise_compiler>
 {
     std::vector<std::string> names() const { return {"pointwise", "contiguous"}; }

@@ -127,33 +115,13 @@ struct pointwise_compiler : compiler<pointwise_compiler>
     {
         assert(not ins->module_inputs().empty());
         auto* pm = ins->module_inputs().front();
         run_passes(*pm, {eliminate_common_subexpression{}, dead_code_elimination{}});
-        cpp_generator g;
-        g.fmap([](const std::string& fname) { return "migraphx::" + fname; });
-        g.add_point_op("where", "${function:where}(${0}, ${1}, ${2})");
-        g.add_point_op("prelu", "${function:where}(${0} < 0, ${0} * ${1}, ${0})");
-        g.add_point_op("sign", "${function:where}(${0} > 0, 1, ${function:where}(${0} < 0, -1, 0))");
-        g.add_point_op("equal", "migraphx::abs(${0} == ${1})");
-        g.add_point_op("less", "migraphx::abs(${0} < ${1})");
-        g.add_point_op("greater", "migraphx::abs(${0} > ${1})");
-        g.add_point_op("not", "migraphx::abs(not ${0})");
-        g.add_point_op("mod", "migraphx::mod(${0}, ${1})");
-        g.add_point_op("fmod", "migraphx::fmod(${0}, ${1})");
-        // Add explict conversions
-        g.fresult(
-            [](const shape& s) { return "migraphx::convert<" + shape::cpp_type(s.type()) + ">"; });
-        auto name = g.create_function(
-            g.generate_module(*pm).set_attributes({"__device__"}).set_generic_types(*pm));
-        std::string lambda = "MIGRAPHX_LIFT(" + name + ")";
-        auto op_names      = get_op_names(*pm);
-        op_names.push_back("kernel");
-        auto op_name_string = join_strings(op_names, "_");
-        return replace(compile_op(ctx,
-                                  to_shapes(ins->inputs()),
-                                  {{"lambda", lambda}, {"preamble", g.str()}, {"kernel", op_name_string}}));
+        auto pf            = generate_pointwise(*pm, "inner_pointwise");
+        std::string lambda = "MIGRAPHX_LIFT(inner_pointwise)";
+        auto kernel_name   = generate_name_from_ops(*pm) + "_kernel";
+        return replace(compile_op(ctx,
+                                  to_shapes(ins->inputs()),
+                                  {{"lambda", lambda}, {"preamble", pf}, {"kernel", kernel_name}}));
     }
 };
...
src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp (+3 −2)

@@ -32,7 +32,8 @@
 // NOLINTNEXTLINE
 #define MIGRAPHX_LIFT(...) \
-    [](auto&&... xs) MIGRAPHX_RETURNS((__VA_ARGS__)(static_cast<decltype(xs)>(xs)...))
+    [](auto&&... private_lisft_xs) MIGRAPHX_RETURNS( \
+        (__VA_ARGS__)(static_cast<decltype(private_lisft_xs)>(private_lisft_xs)...))

 namespace migraphx {
...
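Renaming the lifted parameter pack to an unlikely identifier is a macro-hygiene measure: if the expression handed to MIGRAPHX_LIFT itself mentioned a variable named xs, the generated lambda's own pack would shadow it. That motivation is an assumption; the sketch below demonstrates the general lifting pattern with an invented LIFT macro, not the MIGraphX one:

#include <cstdio>
#include <utility>

// Wraps any callable name (including overload sets) into a forwarding lambda.
#define LIFT(...) \
    [](auto&&... private_lift_args) -> decltype(auto) { \
        return (__VA_ARGS__)(std::forward<decltype(private_lift_args)>(private_lift_args)...); \
    }

int sum3(int a, int b, int c) { return a + b + c; }

int main()
{
    auto lifted = LIFT(sum3);             // sum3 can now be passed around as an object
    std::printf("%d\n", lifted(1, 2, 3)); // prints 6
    return 0;
}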
src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp (new file, +83 −0)

/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_LAYERNORM_HPP
#define MIGRAPHX_GUARD_KERNELS_LAYERNORM_HPP
#include <migraphx/kernels/reduce.hpp>
#include <migraphx/kernels/ops.hpp>
#include <migraphx/kernels/print.hpp>
namespace migraphx {

template <index_int Axis,
          class F,
          class BinOp,
          class Output,
          class Input1,
          class Input2,
          class... Inputs>
__device__ void generic_binary_layernorm(
    F compute, BinOp op, Output output, Input1 input1, Input2 input2, Inputs... inputs)
{
    using reduce_output = reduce::with_axis<Input1, Axis>;
    reduce::block::run<reduce_output>([&](auto, auto r) {
        using value_type         = typename Input1::type;
        constexpr auto relements = r.template elements<Input1>();
        auto mean                = [&](auto f) {
            return r.reduce(op::sum{}, 0, [&](auto x1, auto x2) {
                return f(x1, x2) / value_type{relements};
            })(input1, input2);
        };

        // mean(x)
        auto mean_x = mean(op);
        // mean(m ^ 2)
        auto mean_m2 = mean([&](auto x1, auto x2) {
            auto m = op(x1, x2) - mean_x;
            return m * m;
        });

        r.inner([&](auto& y, auto x1, auto x2, auto... xs) {
            auto m = op(x1, x2) - mean_x;
            // m * rsqrt(mean(m ^ 2) + 1e-12)
            y = compute(m * rsqrt(mean_m2 + value_type{1e-12}), xs...);
        })(output, input1, input2, inputs...);
    });
}

template <index_int Axis, class F, class Output, class Input, class... Inputs>
__device__ void layernorm(F compute, Output output, Input input, Inputs... inputs)
{
    generic_binary_layernorm<Axis>(
        compute, [](auto x, auto) { return x; }, output, input, input, inputs...);
}

template <index_int Axis, class F, class Output, class Input1, class Input2, class... Inputs>
__device__ void
add_layernorm(F compute, Output output, Input1 input1, Input2 input2, Inputs... inputs)
{
    generic_binary_layernorm<Axis>(
        compute, [](auto x1, auto x2) { return x1 + x2; }, output, input1, input2, inputs...);
}

} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_LAYERNORM_HPP
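For reference, the arithmetic performed per reduction group in generic_binary_layernorm is ordinary layer normalization: subtract the mean, then scale by the reciprocal square root of the variance plus a 1e-12 epsilon. A scalar host-side sketch of that math (not the device kernel):

#include <cmath>
#include <cstdio>
#include <vector>

int main()
{
    std::vector<float> x = {1.0f, 2.0f, 3.0f, 4.0f};

    // mean(x)
    float mean_x = 0.0f;
    for(float v : x)
        mean_x += v / x.size();

    // mean(m ^ 2) where m = x - mean(x)
    float mean_m2 = 0.0f;
    for(float v : x)
    {
        float m = v - mean_x;
        mean_m2 += m * m / x.size();
    }

    // y = m * rsqrt(mean(m ^ 2) + 1e-12)
    for(float v : x)
        std::printf("%f\n", (v - mean_x) / std::sqrt(mean_m2 + 1e-12f));
    return 0;
}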
src/targets/gpu/kernels/include/migraphx/kernels/ops.hpp (+2 −2)

@@ -90,7 +90,7 @@ struct lowest
     template <class T>
     constexpr operator T() const
     {
-        return numeric_lowest<T>();
+        return numeric_lowest<vec_type<T>>();
     }
 };

@@ -99,7 +99,7 @@ struct highest
     template <class T>
     constexpr operator T() const
     {
-        return numeric_max<T>();
+        return numeric_max<vec_type<T>>();
     }
 };

 } // namespace migraphx
...
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp (+19 −0)

@@ -224,6 +224,18 @@ struct block
             idx.local_stride(x.get_shape().elements(), [&](auto j) { f(x[j], xs[j]...); });
         });
     }

+    template <class Input>
+    constexpr auto elements() const
+    {
+        using reduce_type        = decltype(slicer(Input{}));
+        using value_type         = typename Input::type;
+        constexpr auto relements = get_shape_c<reduce_type>{}.elements();
+        if constexpr(vec_size<value_type>() > 1)
+            return relements * vec_size<value_type>();
+        else
+            return relements;
+    }
 };

 template <class Slicer>

@@ -281,6 +293,13 @@ struct lane
             }
         });
     }

+    template <class Input>
+    constexpr auto elements() const
+    {
+        using reduce_type = decltype(slicer(Input{}));
+        return get_shape_c<reduce_type>{}.elements();
+    }
 };

 template <class Slicer>
...
src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp (+5 −3)

@@ -192,9 +192,13 @@ struct common_type<T, U, Us...>
 template <class... Ts>
 using common_type_t = typename common_type<Ts...>::type;

+#define MIGRAPHX_REQUIRES(...) class = enable_if_t<__VA_ARGS__>
+
 constexpr unsigned long int_max(unsigned long n) { return (1u << (n * 8)) - 1; }

-template <class T>
+template <class T,
+          MIGRAPHX_REQUIRES(is_integral<T>{} or is_floating_point<T>{} or is_same<T, migraphx::half>{})>
 constexpr T numeric_max()
 {
     if constexpr(is_integral<T>{})

@@ -230,8 +234,6 @@ constexpr T numeric_lowest()
     }
 }

-#define MIGRAPHX_REQUIRES(...) class = enable_if_t<__VA_ARGS__>
-
 } // namespace migraphx

 #endif
src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp (+1 −1)

@@ -175,7 +175,7 @@ template <class T, class Op>
 constexpr auto vec_reduce(T x, Op op)
 {
     if constexpr(vec_size<T>() < 2)
-        return x;
+        return vec_type<T>{x};
     else
     {
         vec_type<T> result = x[0];
...
src/targets/gpu/prefuse_ops.cpp (+49 −37)

@@ -24,12 +24,53 @@
 #include <migraphx/gpu/prefuse_ops.hpp>
 #include <migraphx/match/layernorm.hpp>
 #include <migraphx/make_op.hpp>
+#include <migraphx/register_op.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {

 namespace {

+template <class Derived, std::size_t N>
+struct layernorm_base
+{
+    shape compute_shape(std::vector<shape> inputs, std::vector<module_ref> mods) const
+    {
+        std::size_t nargs = 1;
+        if(not mods.empty())
+        {
+            auto* pm = mods.front();
+            nargs    = pm->get_parameter_names().size();
+        }
+        check_shapes{inputs, static_cast<const Derived&>(*this)}.has(nargs + N);
+        auto s = inputs.at(0);
+        if(s.scalar())
+        {
+            return s;
+        }
+        else if(s.broadcasted())
+        {
+            return {s.type(), s.lens()};
+        }
+        else
+        {
+            return s.with_lens(s.lens());
+        }
+    }
+};
+
+struct layernorm : layernorm_base<layernorm, 0>
+{
+    std::string name() const { return "gpu::prelayernorm"; }
+};
+MIGRAPHX_REGISTER_OP(layernorm);
+
+struct add_layernorm : layernorm_base<add_layernorm, 1>
+{
+    std::string name() const { return "gpu::preadd_layernorm"; }
+};
+MIGRAPHX_REGISTER_OP(add_layernorm);
+
 struct find_layernorm
 {
     auto matcher() const { return match::layernorm(); }

@@ -39,59 +80,30 @@ struct find_layernorm
         auto ins   = r.result;
         auto x_ins = r.instructions["x"];

-        if(not x_ins->get_shape().standard())
-            x_ins = m.insert_instruction(ins, make_op("contiguous"), x_ins);
-        auto relements = x_ins->get_shape().lens().back();
-
-        if(relements > 1024 or (relements % 4 != 0 and relements > 256))
-            return;
-
-        auto a = m.insert_instruction(
-            ins, make_op("hip::allocate", {{"shape", to_value(x_ins->get_shape())}}));
-        m.replace_instruction(ins, make_op("gpu::layernorm"), x_ins, a);
+        m.replace_instruction(ins, layernorm{}, x_ins);
     }
 };

-struct find_triaddlayernorm
+struct find_add_layernorm
 {
     auto matcher() const
     {
-        auto add1 =
-            match::name("add")(match::none_of(match::is_constant()),
-                               match::args(match::any().bind("z1"), match::any().bind("z2")));
-        auto add2 = match::name("add")(match::either_arg(0, 1)(add1, match::any().bind("z3")));
-        return match::layernorm()(match::var("x")(add2));
+        return match::layernorm()(match::var("x")(match::name("add").bind("add")));
     }

     void apply(module& m, const match::matcher_result& r) const
     {
         auto ins = r.result;
-        auto x_ins = r.instructions["z1"];
-        auto y_ins = r.instructions["z2"];
-        auto z_ins = r.instructions["z3"];
-
-        for(auto* pins : {&x_ins, &y_ins, &z_ins})
-        {
-            if(not(*pins)->get_shape().standard())
-                *pins = m.insert_instruction(ins, make_op("contiguous"), *pins);
-        }
-        auto relements = x_ins->get_shape().lens().back();
-
-        if(relements > 1024 or (relements % 4 != 0 and relements > 256))
-            return;
-
-        auto a = m.insert_instruction(
-            ins, make_op("hip::allocate", {{"shape", to_value(x_ins->get_shape())}}));
-        m.replace_instruction(ins, make_op("gpu::triadd_layernorm"), x_ins, y_ins, z_ins, a);
+        auto add_ins = r.instructions["add"];
+        m.replace_instruction(ins, add_layernorm{}, add_ins->inputs());
     }
 };

 } // namespace

 void prefuse_ops::apply(module& m) const
 {
-    match::find_matches(m, find_triaddlayernorm{}, find_layernorm{});
+    match::find_matches(m, find_add_layernorm{}, find_layernorm{});
 }

 } // namespace gpu
...
test/include/test.hpp (+12 −9)

@@ -108,15 +108,7 @@ struct function
 };

 template <class Stream, class Iterator>
-inline Stream& stream_range(Stream& s, Iterator start, Iterator last)
-{
-    if(start != last)
-    {
-        s << *start;
-        std::for_each(std::next(start), last, [&](auto&& x) { s << ", " << x; });
-    }
-    return s;
-}
+Stream& stream_range(Stream& s, Iterator start, Iterator last);

 template <class Stream>
 inline Stream& operator<<(Stream& s, std::nullptr_t)

@@ -136,6 +128,17 @@ inline auto operator<<(Stream& s, const Range& v) -> decltype(stream_range(s, v.
     return s;
 }

+template <class Stream, class Iterator>
+inline Stream& stream_range(Stream& s, Iterator start, Iterator last)
+{
+    if(start != last)
+    {
+        s << *start;
+        std::for_each(std::next(start), last, [&](auto&& x) { s << ", " << x; });
+    }
+    return s;
+}
+
 template <class T>
 const T& get_value(const T& x)
 {
...
test/simplify_reshapes_test.cpp (+87 −0)

@@ -39,6 +39,15 @@ void run_pass(migraphx::module& m)
     migraphx::run_passes(m, {migraphx::simplify_reshapes{}, migraphx::dead_code_elimination{}});
 }

+inline std::vector<std::vector<std::size_t>> to_lens(const std::vector<migraphx::shape>& shapes)
+{
+    std::vector<std::vector<std::size_t>> result;
+    std::transform(shapes.begin(), shapes.end(), std::back_inserter(result), [&](const auto& s) {
+        return s.lens();
+    });
+    return result;
+}
+
 TEST_CASE(double_contig)
 {
     migraphx::program p;

@@ -1275,4 +1284,82 @@ TEST_CASE(transpose_slice_single_transpose)
     EXPECT(m1 == m2);
 }

+TEST_CASE(transpose_slice_non_packed_axis)
+{
+    migraphx::module m1;
+    {
+        auto x = m1.add_parameter("x", {migraphx::shape::float_type, {2, 384, 36, 64}});
+        auto transpose =
+            m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), x);
+        auto slice = m1.add_instruction(
+            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {12}}}), transpose);
+        auto sqrt = m1.add_instruction(migraphx::make_op("sqrt"), slice);
+        m1.add_return({sqrt});
+    }
+    auto output_shapes = m1.get_output_shapes();
+    run_pass(m1);
+    EXPECT(m1.get_output_shapes() == output_shapes);
+
+    migraphx::module m2;
+    {
+        auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 384, 36, 64}});
+        auto unsqueeze = m2.add_instruction(
+            migraphx::make_op("unsqueeze", {{"axes", {2}}, {"steps", {12}}}), x);
+        auto transpose = m2.add_instruction(
+            migraphx::make_op("transpose", {{"permutation", {3, 0, 2, 1, 4}}}), unsqueeze);
+        auto slice = m2.add_instruction(
+            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), transpose);
+        auto squeeze = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), slice);
+        auto sqrt    = m2.add_instruction(migraphx::make_op("sqrt"), squeeze);
+        m2.add_return({sqrt});
+    }
+    EXPECT(m1 == m2);
+}
+
+TEST_CASE(transpose_slice_non_packed_multi_axis)
+{
+    migraphx::module m1;
+    {
+        auto x = m1.add_parameter("x", {migraphx::shape::float_type, {2, 384, 36, 64}});
+        auto transpose =
+            m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), x);
+        auto slice1 = m1.add_instruction(
+            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {12}}}), transpose);
+        auto slice2 = m1.add_instruction(
+            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {12}}, {"ends", {24}}}), transpose);
+        auto transpose2 = m1.add_instruction(
+            migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), slice2);
+        auto slice3 = m1.add_instruction(
+            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {24}}, {"ends", {36}}}), transpose);
+        m1.add_return({slice1, transpose2, slice3});
+    }
+    auto output_shapes = m1.get_output_shapes();
+    run_pass(m1);
+    EXPECT(to_lens(m1.get_output_shapes()) == to_lens(output_shapes));
+
+    migraphx::module m2;
+    {
+        auto x = m2.add_parameter("x", {migraphx::shape::float_type, {2, 384, 36, 64}});
+        auto unsqueeze = m2.add_instruction(
+            migraphx::make_op("unsqueeze", {{"axes", {2}}, {"steps", {12}}}), x);
+        auto transpose = m2.add_instruction(
+            migraphx::make_op("transpose", {{"permutation", {3, 0, 2, 1, 4}}}), unsqueeze);
+        auto slice1 = m2.add_instruction(
+            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {0}}, {"ends", {1}}}), transpose);
+        auto squeeze1 = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), slice1);
+        auto slice2   = m2.add_instruction(
+            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {1}}, {"ends", {2}}}), transpose);
+        auto squeeze2 = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), slice2);
+        auto transpose2 = m2.add_instruction(
+            migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), squeeze2);
+        auto slice3 = m2.add_instruction(
+            migraphx::make_op("slice", {{"axes", {0}}, {"starts", {2}}, {"ends", {3}}}), transpose);
+        auto squeeze3 = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {0}}}), slice3);
+        m2.add_return({squeeze1, transpose2, squeeze3});
+    }
+    EXPECT(m1.sort() == m2.sort());
+}
+
 int main(int argc, const char* argv[]) { test::run(argc, argv); }
test/verify/test_layernorm.cpp (+1 −1)

@@ -68,7 +68,7 @@ struct test_layernorm : verify_program<test_layernorm>
 {
     migraphx::program p;
     auto* mm = p.get_main_module();
-    std::vector<size_t> dims = {1, 1, 5};
+    std::vector<size_t> dims = {1, 2, 5};
     auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, dims});
     add_layernorm(*mm, x, dims);
     return p;
...
test/verify/test_softmax_large1.cpp (new file, +43 −0)

/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/common.hpp>
struct test_softmax_large1 : verify_program<test_softmax_large1>
{
    migraphx::program create_program() const
    {
        migraphx::program p;
        auto* mm   = p.get_main_module();
        auto x     = mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 4}});
        auto large = mm->add_literal({migraphx::shape{migraphx::shape::float_type}, {10000}});
        auto add   = migraphx::add_common_op(*mm, migraphx::make_op("add"), {x, large});
        mm->add_instruction(migraphx::make_op("softmax", {{"axis", -1}}), add);
        return p;
    }
};