Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
5d6880e5
"vscode:/vscode.git/clone" did not exist on "75d73d38159a6dbe247ed499bb08a044eb207397"
Commit
5d6880e5
authored
Jun 07, 2022
by
Paul
Browse files
Merge branch 'jit-vector-softmax' into bert-opt
parents
349249c1
84327e69
Changes
18
Hide whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
395 additions
and
108 deletions
+395
-108
src/module.cpp
src/module.cpp
+8
-0
src/targets/gpu/CMakeLists.txt
src/targets/gpu/CMakeLists.txt
+1
-0
src/targets/gpu/compile_gen.cpp
src/targets/gpu/compile_gen.cpp
+108
-0
src/targets/gpu/compile_hip_code_object.cpp
src/targets/gpu/compile_hip_code_object.cpp
+13
-0
src/targets/gpu/include/migraphx/gpu/compile_gen.hpp
src/targets/gpu/include/migraphx/gpu/compile_gen.hpp
+46
-0
src/targets/gpu/include/migraphx/gpu/compile_hip_code_object.hpp
...gets/gpu/include/migraphx/gpu/compile_hip_code_object.hpp
+2
-0
src/targets/gpu/jit/pointwise.cpp
src/targets/gpu/jit/pointwise.cpp
+13
-82
src/targets/gpu/jit/reduce.cpp
src/targets/gpu/jit/reduce.cpp
+27
-20
src/targets/gpu/jit/softmax.cpp
src/targets/gpu/jit/softmax.cpp
+84
-0
src/targets/gpu/kernel.cpp
src/targets/gpu/kernel.cpp
+2
-0
src/targets/gpu/kernels/include/migraphx/kernels/array.hpp
src/targets/gpu/kernels/include/migraphx/kernels/array.hpp
+8
-0
src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp
...rgets/gpu/kernels/include/migraphx/kernels/functional.hpp
+1
-1
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
+40
-3
src/targets/gpu/kernels/include/migraphx/kernels/shape.hpp
src/targets/gpu/kernels/include/migraphx/kernels/shape.hpp
+1
-0
src/targets/gpu/kernels/include/migraphx/kernels/softmax.hpp
src/targets/gpu/kernels/include/migraphx/kernels/softmax.hpp
+22
-0
src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp
src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp
+16
-0
src/targets/gpu/kernels/include/migraphx/kernels/vectorize.hpp
...argets/gpu/kernels/include/migraphx/kernels/vectorize.hpp
+3
-1
src/targets/gpu/lowering.cpp
src/targets/gpu/lowering.cpp
+0
-1
No files found.
src/module.cpp
100755 → 100644
View file @
5d6880e5
...
@@ -22,6 +22,8 @@
...
@@ -22,6 +22,8 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_TRACE_FINALIZE
)
struct
module_impl
struct
module_impl
{
{
// A list is used to keep references to an instruction stable
// A list is used to keep references to an instruction stable
...
@@ -553,8 +555,14 @@ instruction_ref module::find_dangling_reference() const
...
@@ -553,8 +555,14 @@ instruction_ref module::find_dangling_reference() const
void
module
::
finalize
(
context
&
ctx
)
void
module
::
finalize
(
context
&
ctx
)
{
{
const
bool
trace
=
enabled
(
MIGRAPHX_TRACE_FINALIZE
{});
for
(
auto
ins
:
iterator_for
(
*
this
))
for
(
auto
ins
:
iterator_for
(
*
this
))
{
{
if
(
trace
)
{
std
::
cout
<<
"Finalize: "
;
this
->
debug_print
(
ins
);
}
ins
->
finalize
(
ctx
);
ins
->
finalize
(
ctx
);
for
(
const
auto
&
smod
:
ins
->
module_inputs
())
for
(
const
auto
&
smod
:
ins
->
module_inputs
())
{
{
...
...
src/targets/gpu/CMakeLists.txt
View file @
5d6880e5
...
@@ -131,6 +131,7 @@ add_library(migraphx_gpu
...
@@ -131,6 +131,7 @@ add_library(migraphx_gpu
clip.cpp
clip.cpp
code_object_op.cpp
code_object_op.cpp
compile_ops.cpp
compile_ops.cpp
compile_gen.cpp
compile_hip.cpp
compile_hip.cpp
compile_hip_code_object.cpp
compile_hip_code_object.cpp
compiler.cpp
compiler.cpp
...
...
src/targets/gpu/compile_gen.cpp
0 → 100644
View file @
5d6880e5
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/stringutils.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
namespace
gen
{
static
std
::
vector
<
std
::
size_t
>
vector_sizes
(
const
std
::
vector
<
shape
>&
inputs
)
{
// If all inputs is half then only use half2
if
(
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[](
const
auto
&
s
)
{
return
s
.
type
()
==
shape
::
half_type
;
}))
return
{
2
};
return
{
4
,
2
};
}
vectorize
vectorize
::
elements
(
std
::
size_t
axis
,
const
std
::
vector
<
shape
>&
inputs
)
{
if
(
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
](
const
auto
&
s
)
{
return
s
.
lens
()[
axis
]
==
1
;
}))
return
{
1
,
axis
};
auto
sizes
=
vector_sizes
(
inputs
);
std
::
vector
<
std
::
size_t
>
max_vec_size
;
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
std
::
back_inserter
(
max_vec_size
),
[
&
](
const
auto
&
input
)
->
std
::
size_t
{
auto
stride
=
input
.
strides
()[
axis
];
auto
len
=
input
.
lens
()[
axis
];
if
(
stride
!=
0
and
stride
!=
1
)
return
1
;
if
(
len
==
1
and
input
.
elements
()
>
sizes
.
front
())
return
sizes
.
front
();
auto
it
=
std
::
find_if
(
sizes
.
begin
(),
sizes
.
end
(),
[
&
](
auto
i
)
{
return
(
len
%
i
)
==
0
;
});
if
(
it
!=
sizes
.
end
())
return
*
it
;
return
1
;
});
return
{
*
std
::
min_element
(
max_vec_size
.
begin
(),
max_vec_size
.
end
()),
axis
};
}
std
::
string
vectorize
::
str
()
const
{
return
"vectorize<"
+
to_string
(
size
)
+
", "
+
to_string
(
axis
)
+
">()"
;
}
preload
preload
::
broadcasts
(
std
::
size_t
axis
,
const
std
::
vector
<
shape
>&
inputs
)
{
const
std
::
size_t
max_lds_bytes
=
4096
;
std
::
vector
<
bool
>
result
;
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
std
::
back_inserter
(
result
),
[
&
](
const
shape
&
input
)
{
return
input
.
strides
()[
axis
]
==
0
;
});
auto
bytes
=
std
::
inner_product
(
inputs
.
begin
(),
inputs
.
end
(),
result
.
begin
(),
std
::
size_t
{
0
},
std
::
plus
<>
{},
[](
const
shape
&
s
,
bool
b
)
->
std
::
size_t
{
if
(
b
)
return
s
.
bytes
();
return
0
;
});
if
(
bytes
<
max_lds_bytes
)
return
{
result
};
// TODO: Try to partially preload items
std
::
fill
(
result
.
begin
(),
result
.
end
(),
false
);
return
{
result
};
}
std
::
string
preload
::
str
()
const
{
std
::
vector
<
std
::
string
>
bool_strs
;
std
::
transform
(
args
.
begin
(),
std
::
prev
(
args
.
end
()),
std
::
back_inserter
(
bool_strs
),
[](
bool
b
)
{
if
(
b
)
return
"true"
;
return
"false"
;
});
return
"auto_preload<false, "
+
join_strings
(
bool_strs
,
", "
)
+
">(idx)"
;
}
bool
preload
::
is_preloading
()
const
{
return
std
::
accumulate
(
args
.
begin
(),
args
.
end
(),
false
,
std
::
logical_or
<>
{});
}
std
::
size_t
find_fast_axis
(
const
std
::
vector
<
shape
>&
inputs
)
{
auto
permutation
=
find_permutation
(
inputs
);
auto
it
=
std
::
max_element
(
permutation
.
begin
(),
permutation
.
end
());
return
it
-
permutation
.
begin
();
}
std
::
string
make_transformer_args
(
std
::
vector
<
std
::
string
>
transformers
)
{
return
join_strings
(
std
::
move
(
transformers
),
", "
);
}
}
// namespace gen
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/gpu/compile_hip_code_object.cpp
View file @
5d6880e5
...
@@ -119,8 +119,21 @@ compute_global_for(context& ctx, std::size_t n, std::size_t over)
...
@@ -119,8 +119,21 @@ compute_global_for(context& ctx, std::size_t n, std::size_t over)
};
};
}
}
std
::
size_t
compute_block_size
(
std
::
size_t
n
,
std
::
size_t
max_block_size
)
{
size_t
block_size
=
128
;
while
(
block_size
<=
max_block_size
and
block_size
<=
n
)
block_size
*=
2
;
return
block_size
/
2
;
}
operation
compile_hip_code_object
(
const
std
::
string
&
content
,
hip_compile_options
options
)
operation
compile_hip_code_object
(
const
std
::
string
&
content
,
hip_compile_options
options
)
{
{
assert
(
options
.
global
>
0
);
assert
(
options
.
local
>
0
);
assert
(
not
options
.
inputs
.
empty
());
assert
(
options
.
inputs
.
size
()
==
options
.
virtual_inputs
.
size
()
or
options
.
virtual_inputs
.
empty
());
std
::
vector
<
src_file
>
srcs
;
std
::
vector
<
src_file
>
srcs
;
std
::
transform
(
migraphx_kernels
().
begin
(),
std
::
transform
(
migraphx_kernels
().
begin
(),
migraphx_kernels
().
end
(),
migraphx_kernels
().
end
(),
...
...
src/targets/gpu/include/migraphx/gpu/compile_gen.hpp
0 → 100644
View file @
5d6880e5
#ifndef MIGRAPHX_GUARD_GPU_COMPILE_GEN_HPP
#define MIGRAPHX_GUARD_GPU_COMPILE_GEN_HPP
#include <migraphx/config.hpp>
#include <string>
#include <unordered_map>
#include <vector>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
shape
;
namespace
gpu
{
namespace
gen
{
struct
vectorize
{
std
::
size_t
size
=
1
;
std
::
size_t
axis
=
0
;
static
vectorize
elements
(
std
::
size_t
axis
,
const
std
::
vector
<
shape
>&
inputs
);
std
::
string
str
()
const
;
};
struct
preload
{
std
::
vector
<
bool
>
args
=
{};
static
preload
broadcasts
(
std
::
size_t
axis
,
const
std
::
vector
<
shape
>&
inputs
);
bool
is_preloading
()
const
;
std
::
string
str
()
const
;
};
std
::
size_t
find_fast_axis
(
const
std
::
vector
<
shape
>&
inputs
);
std
::
string
make_transformer_args
(
std
::
vector
<
std
::
string
>
transformers
);
template
<
class
...
Ts
>
std
::
string
make_transformer_args
(
Ts
...
xs
)
{
return
make_transformer_args
({
xs
.
str
()...});
}
}
// namespace gen
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_COMPILE_GEN_HPP
src/targets/gpu/include/migraphx/gpu/compile_hip_code_object.hpp
View file @
5d6880e5
...
@@ -46,6 +46,8 @@ compute_global_for(context& ctx, std::size_t n, std::size_t over = 1);
...
@@ -46,6 +46,8 @@ compute_global_for(context& ctx, std::size_t n, std::size_t over = 1);
operation
compile_hip_code_object
(
const
std
::
string
&
content
,
hip_compile_options
options
);
operation
compile_hip_code_object
(
const
std
::
string
&
content
,
hip_compile_options
options
);
std
::
size_t
compute_block_size
(
std
::
size_t
n
,
std
::
size_t
max_block_size
=
1024
);
}
// namespace gpu
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
src/targets/gpu/jit/pointwise.cpp
View file @
5d6880e5
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/ranges.hpp>
...
@@ -17,6 +18,8 @@ namespace migraphx {
...
@@ -17,6 +18,8 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
namespace
gpu
{
using
namespace
migraphx
::
gpu
::
gen
;
// NOLINT
static
const
char
*
const
pointwise_kernel
=
R"__migraphx__(
static
const
char
*
const
pointwise_kernel
=
R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/pointwise.hpp>
#include <migraphx/kernels/pointwise.hpp>
...
@@ -30,7 +33,7 @@ extern "C" {
...
@@ -30,7 +33,7 @@ extern "C" {
__global__ void ${kernel}(${params})
__global__ void ${kernel}(${params})
{
{
auto idx = make_index();
auto idx = make_index();
pointwise(idx,
auto_preload<${preloads}>(idx), vectorize<${vec_size}, ${axis}>()
)(${lambda}, ${args});
pointwise(idx,
${transformers}
)(${lambda}, ${args});
}
}
}
}
...
@@ -62,75 +65,6 @@ struct pointwise_compiler : compiler<pointwise_compiler>
...
@@ -62,75 +65,6 @@ struct pointwise_compiler : compiler<pointwise_compiler>
else
else
return
1
;
return
1
;
}
}
static
std
::
size_t
find_fast_axis
(
const
std
::
vector
<
shape
>&
inputs
)
{
auto
permutation
=
find_permutation
(
inputs
);
auto
it
=
std
::
max_element
(
permutation
.
begin
(),
permutation
.
end
());
return
it
-
permutation
.
begin
();
}
static
std
::
vector
<
bool
>
preload
(
std
::
size_t
axis
,
const
std
::
vector
<
shape
>&
inputs
)
{
const
std
::
size_t
max_lds_bytes
=
4096
;
std
::
vector
<
bool
>
result
;
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
std
::
back_inserter
(
result
),
[
&
](
const
shape
&
input
)
{
return
input
.
strides
()[
axis
]
==
0
;
});
auto
bytes
=
std
::
inner_product
(
inputs
.
begin
(),
inputs
.
end
(),
result
.
begin
(),
std
::
size_t
{
0
},
std
::
plus
<>
{},
[](
const
shape
&
s
,
bool
b
)
->
std
::
size_t
{
if
(
b
)
return
s
.
bytes
();
return
0
;
});
if
(
bytes
<
max_lds_bytes
)
return
result
;
// TODO: Try to partially preload items
std
::
fill
(
result
.
begin
(),
result
.
end
(),
false
);
return
result
;
}
static
std
::
string
preload_str
(
const
std
::
vector
<
bool
>&
bs
)
{
std
::
vector
<
std
::
string
>
bool_strs
;
std
::
transform
(
bs
.
begin
(),
std
::
prev
(
bs
.
end
()),
std
::
back_inserter
(
bool_strs
),
[](
bool
b
)
{
if
(
b
)
return
"true"
;
return
"false"
;
});
return
"false, "
+
join_strings
(
bool_strs
,
", "
);
}
static
std
::
vector
<
std
::
size_t
>
vector_sizes
(
const
std
::
vector
<
shape
>&
inputs
)
{
// If all inputs is half then only use half2
if
(
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[](
const
auto
&
s
)
{
return
s
.
type
()
==
shape
::
half_type
;
}))
return
{
2
};
return
{
4
,
2
};
}
static
auto
vectorize_elements
(
std
::
size_t
axis
,
const
std
::
vector
<
shape
>&
inputs
)
{
auto
sizes
=
vector_sizes
(
inputs
);
std
::
vector
<
std
::
size_t
>
max_vec_size
;
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
std
::
back_inserter
(
max_vec_size
),
[
&
](
const
auto
&
input
)
->
std
::
size_t
{
auto
stride
=
input
.
strides
()[
axis
];
auto
len
=
input
.
lens
()[
axis
];
if
(
stride
!=
0
and
stride
!=
1
)
return
1
;
auto
it
=
std
::
find_if
(
sizes
.
begin
(),
sizes
.
end
(),
[
&
](
auto
i
)
{
return
(
len
%
i
)
==
0
;
});
if
(
it
!=
sizes
.
end
())
return
*
it
;
return
1
;
});
return
*
std
::
min_element
(
max_vec_size
.
begin
(),
max_vec_size
.
end
());
}
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
{
{
hip_compile_options
options
;
hip_compile_options
options
;
...
@@ -139,23 +73,20 @@ struct pointwise_compiler : compiler<pointwise_compiler>
...
@@ -139,23 +73,20 @@ struct pointwise_compiler : compiler<pointwise_compiler>
options
.
virtual_inputs
=
reduce_dims
(
inputs
);
options
.
virtual_inputs
=
reduce_dims
(
inputs
);
options
.
params
=
"-Wno-float-equal"
;
options
.
params
=
"-Wno-float-equal"
;
auto
axis
=
find_fast_axis
(
options
.
virtual_inputs
);
auto
axis
=
find_fast_axis
(
options
.
virtual_inputs
);
auto
vec_size
=
vectorize_elements
(
axis
,
options
.
virtual_inputs
);
auto
vec
=
vectorize
::
elements
(
axis
,
options
.
virtual_inputs
);
auto
preloads
=
preload
(
axis
,
options
.
virtual_inputs
);
auto
preloads
=
preload
::
broadcasts
(
axis
,
options
.
virtual_inputs
);
auto
is_preloading
=
options
.
kernel_name
=
v
.
get
(
"kernel"
,
"kernel"
);
std
::
accumulate
(
preloads
.
begin
(),
preloads
.
end
(),
false
,
std
::
logical_or
<>
{});
options
.
set_launch_params
(
options
.
kernel_name
=
v
.
get
(
"kernel"
,
"kernel"
);
v
,
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
compute_global_for
(
ctx
,
options
.
output
.
elements
()
/
vec
.
size
,
options
.
output
.
elements
()
/
vec_size
,
oversubscribe_if
(
not
preloads
.
is_preloading
())));
oversubscribe_if
(
not
is_preloading
)));
auto
src
=
interpolate_string
(
pointwise_kernel
,
auto
src
=
interpolate_string
(
pointwise_kernel
,
{{
"kernel"
,
options
.
kernel_name
},
{{
"kernel"
,
options
.
kernel_name
},
{
"params"
,
enum_params
(
inputs
.
size
(),
"void * private_p"
)},
{
"params"
,
enum_params
(
inputs
.
size
(),
"void * private_p"
)},
{
"args"
,
enum_params
(
inputs
.
size
(),
"private_p"
)},
{
"args"
,
enum_params
(
inputs
.
size
(),
"private_p"
)},
{
"lambda"
,
v
.
at
(
"lambda"
).
to
<
std
::
string
>
()},
{
"lambda"
,
v
.
at
(
"lambda"
).
to
<
std
::
string
>
()},
{
"vec_size"
,
std
::
to_string
(
vec_size
)},
{
"transformers"
,
make_transformer_args
(
preloads
,
vec
)},
{
"axis"
,
std
::
to_string
(
axis
)},
{
"preloads"
,
preload_str
(
preloads
)},
{
"preamble"
,
v
.
get
(
"preamble"
,
std
::
string
{})}});
{
"preamble"
,
v
.
get
(
"preamble"
,
std
::
string
{})}});
return
compile_hip_code_object
(
src
,
options
);
return
compile_hip_code_object
(
src
,
options
);
}
}
...
...
src/targets/gpu/jit/reduce.cpp
View file @
5d6880e5
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/ranges.hpp>
...
@@ -16,9 +17,12 @@ namespace migraphx {
...
@@ -16,9 +17,12 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
namespace
gpu
{
using
namespace
migraphx
::
gpu
::
gen
;
// NOLINT
static
const
char
*
const
simple_reduce_kernel
=
R"__migraphx__(
static
const
char
*
const
simple_reduce_kernel
=
R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/reduce.hpp>
#include <migraphx/kernels/reduce.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <args.hpp>
#include <args.hpp>
namespace migraphx {
namespace migraphx {
...
@@ -26,9 +30,10 @@ namespace migraphx {
...
@@ -26,9 +30,10 @@ namespace migraphx {
${preamble}
${preamble}
extern "C" {
extern "C" {
__global__ void kernel(void* input_p, void* output_p)
__global__ void
reduce_
kernel(void* input_p, void* output_p)
{
{
make_tensors()(input_p, output_p)([](auto input, auto output) {
transform_args(make_tensors(), ${transformers})(input_p, output_p)([](auto input, auto output) {
simple_reduce<reduce::${algo}>(${reduction}, ${init}, input, output, ${read}, ${write});
simple_reduce<reduce::${algo}>(${reduction}, ${init}, input, output, ${read}, ${write});
});
});
...
@@ -40,14 +45,6 @@ __global__ void kernel(void* input_p, void* output_p)
...
@@ -40,14 +45,6 @@ __global__ void kernel(void* input_p, void* output_p)
)__migraphx__"
;
)__migraphx__"
;
constexpr
std
::
size_t
compute_block_size
(
std
::
size_t
n
,
std
::
size_t
max_block_size
=
1024
)
{
size_t
block_size
=
128
;
while
(
block_size
<=
max_block_size
and
block_size
<=
n
)
block_size
*=
2
;
return
block_size
/
2
;
}
static
std
::
size_t
get_reduce_elements
(
const
std
::
vector
<
shape
>&
inputs
)
static
std
::
size_t
get_reduce_elements
(
const
std
::
vector
<
shape
>&
inputs
)
{
{
return
inputs
.
front
().
elements
()
/
inputs
.
back
().
elements
();
return
inputs
.
front
().
elements
()
/
inputs
.
back
().
elements
();
...
@@ -101,32 +98,42 @@ struct reduce_compiler : compiler<reduce_compiler>
...
@@ -101,32 +98,42 @@ struct reduce_compiler : compiler<reduce_compiler>
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
{
{
hip_compile_options
options
;
hip_compile_options
options
;
auto
reduce_elements
=
get_reduce_elements
(
inputs
);
options
.
inputs
=
inputs
;
auto
algo
=
v
.
get
(
"algo"
,
get_reduce_algo
(
inputs
));
options
.
output
=
inputs
.
back
();
options
.
virtual_inputs
=
reduce_dims
(
inputs
);
auto
faxis
=
find_fast_axis
({
options
.
virtual_inputs
.
front
()});
vectorize
vec
{};
// Vectorize if the axis is a reduction axis
if
(
options
.
virtual_inputs
.
back
().
lens
()[
faxis
]
==
1
)
{
vec
=
vectorize
::
elements
(
faxis
,
options
.
virtual_inputs
);
}
auto
relements
=
get_reduce_elements
(
options
.
virtual_inputs
)
/
vec
.
size
;
auto
nelements
=
options
.
virtual_inputs
.
back
().
elements
();
auto
algo
=
v
.
get
(
"algo"
,
get_reduce_algo
(
options
.
virtual_inputs
));
if
(
algo
==
"block"
)
if
(
algo
==
"block"
)
{
{
auto
block_size
=
compute_block_size
(
r
educe_
elements
,
256
);
auto
block_size
=
compute_block_size
(
relements
,
256
);
options
.
set_launch_params
(
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
inputs
.
back
().
elements
()
*
block_size
,
256
),
block_size
);
v
,
compute_global_for
(
ctx
,
n
elements
*
block_size
,
256
),
block_size
);
}
}
else
if
(
algo
==
"lane"
)
else
if
(
algo
==
"lane"
)
{
{
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
inputs
.
back
().
elements
()
,
256
));
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
n
elements
,
256
));
}
}
else
else
{
{
MIGRAPHX_THROW
(
"Unknown reduce algo: "
+
algo
);
MIGRAPHX_THROW
(
"Unknown reduce algo: "
+
algo
);
}
}
options
.
inputs
=
inputs
;
options
.
kernel_name
=
"reduce_kernel"
;
options
.
output
=
inputs
.
back
();
std
::
string
identity
=
"[](auto x) { return x; }"
;
options
.
virtual_inputs
=
reduce_dims
(
inputs
);
auto
src
=
interpolate_string
(
simple_reduce_kernel
,
std
::
string
identity
=
"[](auto x) { return x; }"
;
auto
src
=
interpolate_string
(
simple_reduce_kernel
,
{{
"reduction"
,
v
.
at
(
"reduction"
).
to
<
std
::
string
>
()},
{{
"reduction"
,
v
.
at
(
"reduction"
).
to
<
std
::
string
>
()},
{
"init"
,
v
.
get
(
"init"
,
std
::
string
{
"0"
})},
{
"init"
,
v
.
get
(
"init"
,
std
::
string
{
"0"
})},
{
"read"
,
v
.
get
(
"read"
,
identity
)},
{
"read"
,
v
.
get
(
"read"
,
identity
)},
{
"write"
,
v
.
get
(
"write"
,
identity
)},
{
"write"
,
v
.
get
(
"write"
,
identity
)},
{
"algo"
,
algo
},
{
"algo"
,
algo
},
{
"transformers"
,
make_transformer_args
(
vec
)},
{
"preamble"
,
v
.
get
(
"preamble"
,
std
::
string
{})}});
{
"preamble"
,
v
.
get
(
"preamble"
,
std
::
string
{})}});
options
.
params
+=
"-Wno-float-equal"
;
options
.
params
+=
"-Wno-float-equal"
;
return
compile_hip_code_object
(
src
,
options
);
return
compile_hip_code_object
(
src
,
options
);
...
...
src/targets/gpu/jit/softmax.cpp
0 → 100644
View file @
5d6880e5
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/module.hpp>
#include <migraphx/pass_manager.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
using
namespace
migraphx
::
gpu
::
gen
;
// NOLINT
static
const
char
*
const
softmax_kernel
=
R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/softmax.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <args.hpp>
namespace migraphx {
extern "C" {
__global__ void softmax_kernel(void* input_p, void* output_p)
{
transform_args(make_tensors(), ${transformers})(input_p, output_p)([](auto input, auto output) {
softmax<${axis}>(input, output);
});
}
}
} // namespace migraphx
)__migraphx__"
;
struct
softmax_compiler
:
compiler
<
softmax_compiler
>
{
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"softmax"
};
}
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
{
// TODO: Use reduce_dims
auto
axis
=
v
.
at
(
"axis"
).
to
<
int64_t
>
();
auto
faxis
=
find_fast_axis
({
inputs
.
front
()});
vectorize
vec
{};
// Vectorize if the axis is a reduction axis
if
(
inputs
.
back
().
lens
()[
faxis
]
==
1
)
{
vec
=
vectorize
::
elements
(
faxis
,
inputs
);
}
auto
relements
=
inputs
[
0
].
lens
()[
axis
]
/
vec
.
size
;
auto
nelements
=
inputs
.
back
().
elements
()
/
relements
;
auto
block_size
=
compute_block_size
(
relements
,
256
);
hip_compile_options
options
;
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
nelements
*
block_size
,
256
),
block_size
);
options
.
output
=
inputs
.
back
();
options
.
inputs
=
inputs
;
options
.
kernel_name
=
"softmax_kernel"
;
auto
src
=
interpolate_string
(
softmax_kernel
,
{{
"transformers"
,
make_transformer_args
(
vec
)},
{
"axis"
,
to_string
(
axis
)}});
return
compile_hip_code_object
(
src
,
options
);
}
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
const
{
return
replace
(
compile_op
(
ctx
,
to_shapes
(
ins
->
inputs
()),
op
.
to_value
()));
}
};
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/gpu/kernel.cpp
View file @
5d6880e5
...
@@ -59,6 +59,8 @@ void launch_kernel(hipFunction_t fun,
...
@@ -59,6 +59,8 @@ void launch_kernel(hipFunction_t fun,
void
*
kernargs
,
void
*
kernargs
,
std
::
size_t
size
)
std
::
size_t
size
)
{
{
assert
(
global
>
0
);
assert
(
local
>
0
);
void
*
config
[]
=
{
void
*
config
[]
=
{
// HIP_LAUNCH_PARAM_* are macros that do horrible things
// HIP_LAUNCH_PARAM_* are macros that do horrible things
#ifdef MIGRAPHX_USE_CLANG_TIDY
#ifdef MIGRAPHX_USE_CLANG_TIDY
...
...
src/targets/gpu/kernels/include/migraphx/kernels/array.hpp
View file @
5d6880e5
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/type_traits.hpp>
#include <migraphx/kernels/type_traits.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/debug.hpp>
#include <migraphx/kernels/debug.hpp>
namespace
migraphx
{
namespace
migraphx
{
...
@@ -190,6 +191,13 @@ constexpr auto transform(integral_const_array<T, Xs...>, F f)
...
@@ -190,6 +191,13 @@ constexpr auto transform(integral_const_array<T, Xs...>, F f)
return
integral_const_array
<
T
,
f
(
Xs
)...
>
{};
return
integral_const_array
<
T
,
f
(
Xs
)...
>
{};
}
}
template
<
class
T
,
T
...
Xs
,
class
F
>
constexpr
auto
transform_i
(
integral_const_array
<
T
,
Xs
...
>
,
F
f
)
{
return
sequence_c
<
sizeof
...(
Xs
)
>
(
[
=
](
auto
...
is
)
{
return
integral_const_array
<
T
,
f
(
Xs
,
is
)...
>
{};
});
}
template
<
class
T
,
T
...
Xs
,
class
U
,
U
...
Ys
,
class
F
>
template
<
class
T
,
T
...
Xs
,
class
U
,
U
...
Ys
,
class
F
>
constexpr
auto
transform
(
integral_const_array
<
T
,
Xs
...
>
,
integral_const_array
<
U
,
Ys
...
>
,
F
f
)
constexpr
auto
transform
(
integral_const_array
<
T
,
Xs
...
>
,
integral_const_array
<
U
,
Ys
...
>
,
F
f
)
{
{
...
...
src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp
View file @
5d6880e5
#ifndef MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP
#ifndef MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP
#define MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP
#define MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP
#include <migraphx/kernels/
array
.hpp>
#include <migraphx/kernels/
integral_constant
.hpp>
// NOLINTNEXTLINE
// NOLINTNEXTLINE
#define MIGRAPHX_RETURNS(...) \
#define MIGRAPHX_RETURNS(...) \
...
...
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
View file @
5d6880e5
...
@@ -152,6 +152,21 @@ constexpr auto sliced(Slicer slicer, F f)
...
@@ -152,6 +152,21 @@ constexpr auto sliced(Slicer slicer, F f)
};
};
}
}
template
<
class
Input
,
index_int
Axis
>
constexpr
auto
compute_reduce_axis
()
{
constexpr
auto
lens
=
transform_i
(
get_shape_c
<
Input
>
{}.
lens
,
[](
index_int
x
,
index_int
i
)
->
index_int
{
if
(
i
==
Axis
)
return
1
;
return
x
;
});
return
make_shape
(
lens
,
get_shape_c
<
Input
>
{}.
strides
);
}
template
<
class
Input
,
index_int
Axis
>
using
with_axis
=
decltype
(
compute_reduce_axis
<
Input
,
Axis
>
());
struct
block
struct
block
{
{
template
<
class
Slicer
>
template
<
class
Slicer
>
...
@@ -163,9 +178,12 @@ struct block
...
@@ -163,9 +178,12 @@ struct block
__device__
auto
reduce
(
Op
op
,
T
init
,
Read
read
)
const
__device__
auto
reduce
(
Op
op
,
T
init
,
Read
read
)
const
{
{
return
sliced
(
slicer
,
[
=
](
auto
x
,
auto
...
xs
)
{
return
sliced
(
slicer
,
[
=
](
auto
x
,
auto
...
xs
)
{
return
block_reduce
(
idx
,
op
,
init
,
x
.
get_shape
().
elements
(),
[
&
](
auto
j
)
{
return
vec_reduce
(
block_reduce
(
idx
,
return
read
(
x
[
j
],
xs
[
j
]...);
op
,
});
init
,
x
.
get_shape
().
elements
(),
[
&
](
auto
j
)
{
return
read
(
x
[
j
],
xs
[
j
]...);
}),
op
);
});
});
}
}
...
@@ -175,6 +193,14 @@ struct block
...
@@ -175,6 +193,14 @@ struct block
if
(
idx
.
local
==
0
)
if
(
idx
.
local
==
0
)
f
();
f
();
}
}
template
<
class
F
>
__device__
auto
inner
(
F
f
)
const
{
return
sliced
(
slicer
,
[
=
](
auto
x
,
auto
...
xs
)
{
idx
.
local_stride
(
x
.
get_shape
().
elements
(),
[
&
](
auto
j
)
{
f
(
x
[
j
],
xs
[
j
]...);
});
});
}
};
};
template
<
class
Slicer
>
template
<
class
Slicer
>
...
@@ -221,6 +247,17 @@ struct lane
...
@@ -221,6 +247,17 @@ struct lane
{
{
f
();
f
();
}
}
template
<
class
F
>
__device__
auto
inner
(
F
f
)
const
{
return
sliced
(
slicer
,
[
=
](
auto
x
,
auto
...
xs
)
{
for
(
index_int
j
=
0
;
j
<
x
.
get_shape
().
elements
();
j
++
)
{
f
(
x
[
j
],
xs
[
j
]...);
}
});
}
};
};
template
<
class
Slicer
>
template
<
class
Slicer
>
...
...
src/targets/gpu/kernels/include/migraphx/kernels/shape.hpp
View file @
5d6880e5
...
@@ -9,6 +9,7 @@ namespace migraphx {
...
@@ -9,6 +9,7 @@ namespace migraphx {
template
<
class
Lens
,
class
Strides
>
template
<
class
Lens
,
class
Strides
>
struct
shape
struct
shape
{
{
using
shape_type
=
shape
;
using
index_array
=
typename
Lens
::
base_array
;
using
index_array
=
typename
Lens
::
base_array
;
Lens
lens
=
{};
Lens
lens
=
{};
Strides
strides
=
{};
Strides
strides
=
{};
...
...
src/targets/gpu/kernels/include/migraphx/kernels/softmax.hpp
0 → 100644
View file @
5d6880e5
#ifndef MIGRAPHX_GUARD_KERNELS_SOFTMAX_HPP
#define MIGRAPHX_GUARD_KERNELS_SOFTMAX_HPP
#include <migraphx/kernels/reduce.hpp>
#include <migraphx/kernels/ops.hpp>
namespace
migraphx
{
template
<
index_int
Axis
,
class
Input
,
class
Output
>
__device__
void
softmax
(
Input
input
,
Output
output
)
{
reduce
::
block
::
run
<
reduce
::
with_axis
<
Input
,
Axis
>>
([
&
](
auto
,
auto
r
)
{
auto
batch_max
=
r
.
reduce
(
op
::
max
{},
lowest
{},
op
::
id
{})(
input
);
auto
batch_sum
=
r
.
reduce
(
op
::
sum
{},
0
,
[
&
](
auto
x
)
{
return
migraphx
::
exp
(
x
-
batch_max
);
})(
input
);
r
.
inner
([
&
](
auto
&
y
,
auto
x
)
{
y
=
migraphx
::
exp
(
x
-
batch_max
)
/
batch_sum
;
})(
output
,
input
);
});
}
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_SOFTMAX_HPP
src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp
View file @
5d6880e5
...
@@ -4,6 +4,8 @@
...
@@ -4,6 +4,8 @@
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/type_traits.hpp>
#include <migraphx/kernels/debug.hpp>
namespace
migraphx
{
namespace
migraphx
{
...
@@ -146,5 +148,19 @@ constexpr auto vec_packed_transform(Ts... xs)
...
@@ -146,5 +148,19 @@ constexpr auto vec_packed_transform(Ts... xs)
};
};
}
}
template
<
class
T
,
class
Op
>
constexpr
auto
vec_reduce
(
T
x
,
Op
op
)
{
if
constexpr
(
vec_size
<
T
>
()
<
2
)
return
x
;
else
{
vec_type
<
T
>
result
=
x
[
0
];
for
(
int
i
=
1
;
i
<
vec_size
<
T
>
();
i
++
)
result
=
op
(
result
,
x
[
i
]);
return
result
;
}
}
}
// namespace migraphx
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_VEC_HPP
#endif // MIGRAPHX_GUARD_KERNELS_VEC_HPP
src/targets/gpu/kernels/include/migraphx/kernels/vectorize.hpp
View file @
5d6880e5
...
@@ -213,7 +213,9 @@ template <index_int N, index_int Axis, class T>
...
@@ -213,7 +213,9 @@ template <index_int N, index_int Axis, class T>
__device__
__host__
auto
vectorize_tensor
(
T
x
)
__device__
__host__
auto
vectorize_tensor
(
T
x
)
{
{
constexpr
auto
shape
=
get_shape_c
<
T
>
{};
constexpr
auto
shape
=
get_shape_c
<
T
>
{};
if
constexpr
(
shape
.
strides
[
Axis
]
==
0
)
if
constexpr
(
shape
.
lens
[
Axis
]
==
1
)
return
x
;
else
if
constexpr
(
shape
.
strides
[
Axis
]
==
0
)
return
tensor_step
<
N
>
(
x
,
_c
<
Axis
>
);
return
tensor_step
<
N
>
(
x
,
_c
<
Axis
>
);
else
else
return
as_vec
<
N
>
(
x
,
_c
<
Axis
>
);
return
as_vec
<
N
>
(
x
,
_c
<
Axis
>
);
...
...
src/targets/gpu/lowering.cpp
View file @
5d6880e5
...
@@ -186,7 +186,6 @@ struct miopen_apply
...
@@ -186,7 +186,6 @@ struct miopen_apply
add_extend_op
(
"rnn_var_sl_shift_output"
);
add_extend_op
(
"rnn_var_sl_shift_output"
);
add_extend_op
(
"rnn_var_sl_shift_sequence"
);
add_extend_op
(
"rnn_var_sl_shift_sequence"
);
add_extend_op
(
"scatter_none"
);
add_extend_op
(
"scatter_none"
);
add_extend_op
(
"softmax"
);
add_extend_op
(
"topk"
);
add_extend_op
(
"topk"
);
add_batch_norm_inference_op
();
add_batch_norm_inference_op
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment