Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
d60364a3
Commit
d60364a3
authored
May 10, 2022
by
Paul
Browse files
Consolidate the vecotrize and preload
parent
5e5ed37a
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
169 additions
and
87 deletions
+169
-87
src/targets/gpu/CMakeLists.txt
src/targets/gpu/CMakeLists.txt
+1
-0
src/targets/gpu/compile_gen.cpp
src/targets/gpu/compile_gen.cpp
+103
-0
src/targets/gpu/compile_hip_code_object.cpp
src/targets/gpu/compile_hip_code_object.cpp
+8
-0
src/targets/gpu/include/migraphx/gpu/compile_gen.hpp
src/targets/gpu/include/migraphx/gpu/compile_gen.hpp
+46
-0
src/targets/gpu/include/migraphx/gpu/compile_hip_code_object.hpp
...gets/gpu/include/migraphx/gpu/compile_hip_code_object.hpp
+2
-0
src/targets/gpu/jit/pointwise.cpp
src/targets/gpu/jit/pointwise.cpp
+9
-79
src/targets/gpu/jit/reduce.cpp
src/targets/gpu/jit/reduce.cpp
+0
-8
No files found.
src/targets/gpu/CMakeLists.txt
View file @
d60364a3
...
@@ -131,6 +131,7 @@ add_library(migraphx_gpu
...
@@ -131,6 +131,7 @@ add_library(migraphx_gpu
clip.cpp
clip.cpp
code_object_op.cpp
code_object_op.cpp
compile_ops.cpp
compile_ops.cpp
compile_gen.cpp
compile_hip.cpp
compile_hip.cpp
compile_hip_code_object.cpp
compile_hip_code_object.cpp
compiler.cpp
compiler.cpp
...
...
src/targets/gpu/compile_gen.cpp
0 → 100644
View file @
d60364a3
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/stringutils.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
namespace
gen
{
static
std
::
vector
<
std
::
size_t
>
vector_sizes
(
const
std
::
vector
<
shape
>&
inputs
)
{
// If all inputs is half then only use half2
if
(
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[](
const
auto
&
s
)
{
return
s
.
type
()
==
shape
::
half_type
;
}))
return
{
2
};
return
{
4
,
2
};
}
vectorize
vectorize
::
elements
(
std
::
size_t
axis
,
const
std
::
vector
<
shape
>&
inputs
)
{
auto
sizes
=
vector_sizes
(
inputs
);
std
::
vector
<
std
::
size_t
>
max_vec_size
;
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
std
::
back_inserter
(
max_vec_size
),
[
&
](
const
auto
&
input
)
->
std
::
size_t
{
auto
stride
=
input
.
strides
()[
axis
];
auto
len
=
input
.
lens
()[
axis
];
if
(
stride
!=
0
and
stride
!=
1
)
return
1
;
auto
it
=
std
::
find_if
(
sizes
.
begin
(),
sizes
.
end
(),
[
&
](
auto
i
)
{
return
(
len
%
i
)
==
0
;
});
if
(
it
!=
sizes
.
end
())
return
*
it
;
return
1
;
});
return
{
*
std
::
min_element
(
max_vec_size
.
begin
(),
max_vec_size
.
end
()),
axis
};
}
std
::
string
vectorize
::
str
()
const
{
return
"vectorize<"
+
to_string
(
size
)
+
", "
+
to_string
(
axis
)
+
">()"
;
}
preload
preload
::
broadcasts
(
std
::
size_t
axis
,
const
std
::
vector
<
shape
>&
inputs
)
{
const
std
::
size_t
max_lds_bytes
=
4096
;
std
::
vector
<
bool
>
result
;
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
std
::
back_inserter
(
result
),
[
&
](
const
shape
&
input
)
{
return
input
.
strides
()[
axis
]
==
0
;
});
auto
bytes
=
std
::
inner_product
(
inputs
.
begin
(),
inputs
.
end
(),
result
.
begin
(),
std
::
size_t
{
0
},
std
::
plus
<>
{},
[](
const
shape
&
s
,
bool
b
)
->
std
::
size_t
{
if
(
b
)
return
s
.
bytes
();
return
0
;
});
if
(
bytes
<
max_lds_bytes
)
return
{
result
};
// TODO: Try to partially preload items
std
::
fill
(
result
.
begin
(),
result
.
end
(),
false
);
return
{
result
};
}
std
::
string
preload
::
str
()
const
{
std
::
vector
<
std
::
string
>
bool_strs
;
std
::
transform
(
args
.
begin
(),
std
::
prev
(
args
.
end
()),
std
::
back_inserter
(
bool_strs
),
[](
bool
b
)
{
if
(
b
)
return
"true"
;
return
"false"
;
});
return
"auto_preload<false, "
+
join_strings
(
bool_strs
,
", "
)
+
">(idx)"
;
}
bool
preload
::
is_preloading
()
const
{
return
std
::
accumulate
(
args
.
begin
(),
args
.
end
(),
false
,
std
::
logical_or
<>
{});
}
std
::
size_t
find_fast_axis
(
const
std
::
vector
<
shape
>&
inputs
)
{
auto
permutation
=
find_permutation
(
inputs
);
auto
it
=
std
::
max_element
(
permutation
.
begin
(),
permutation
.
end
());
return
it
-
permutation
.
begin
();
}
std
::
string
make_transformer_args
(
std
::
vector
<
std
::
string
>
transformers
)
{
return
join_strings
(
std
::
move
(
transformers
),
", "
);
}
}
// namespace gen
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/gpu/compile_hip_code_object.cpp
View file @
d60364a3
...
@@ -119,6 +119,14 @@ compute_global_for(context& ctx, std::size_t n, std::size_t over)
...
@@ -119,6 +119,14 @@ compute_global_for(context& ctx, std::size_t n, std::size_t over)
};
};
}
}
std
::
size_t
compute_block_size
(
std
::
size_t
n
,
std
::
size_t
max_block_size
)
{
size_t
block_size
=
128
;
while
(
block_size
<=
max_block_size
and
block_size
<=
n
)
block_size
*=
2
;
return
block_size
/
2
;
}
operation
compile_hip_code_object
(
const
std
::
string
&
content
,
hip_compile_options
options
)
operation
compile_hip_code_object
(
const
std
::
string
&
content
,
hip_compile_options
options
)
{
{
std
::
vector
<
src_file
>
srcs
;
std
::
vector
<
src_file
>
srcs
;
...
...
src/targets/gpu/include/migraphx/gpu/compile_gen.hpp
0 → 100644
View file @
d60364a3
#ifndef MIGRAPHX_GUARD_GPU_COMPILE_GEN_HPP
#define MIGRAPHX_GUARD_GPU_COMPILE_GEN_HPP
#include <migraphx/config.hpp>
#include <string>
#include <unordered_map>
#include <vector>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
shape
;
namespace
gpu
{
namespace
gen
{
struct
vectorize
{
std
::
size_t
size
;
std
::
size_t
axis
;
static
vectorize
elements
(
std
::
size_t
axis
,
const
std
::
vector
<
shape
>&
inputs
);
std
::
string
str
()
const
;
};
struct
preload
{
std
::
vector
<
bool
>
args
;
static
preload
broadcasts
(
std
::
size_t
axis
,
const
std
::
vector
<
shape
>&
inputs
);
bool
is_preloading
()
const
;
std
::
string
str
()
const
;
};
std
::
size_t
find_fast_axis
(
const
std
::
vector
<
shape
>&
inputs
);
std
::
string
make_transformer_args
(
std
::
vector
<
std
::
string
>
transformers
);
template
<
class
...
Ts
>
std
::
string
make_transformer_args
(
Ts
...
xs
)
{
return
make_transformer_args
({
xs
.
str
()...});
}
}
// namespace gen
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_COMPILE_GEN_HPP
src/targets/gpu/include/migraphx/gpu/compile_hip_code_object.hpp
View file @
d60364a3
...
@@ -46,6 +46,8 @@ compute_global_for(context& ctx, std::size_t n, std::size_t over = 1);
...
@@ -46,6 +46,8 @@ compute_global_for(context& ctx, std::size_t n, std::size_t over = 1);
operation
compile_hip_code_object
(
const
std
::
string
&
content
,
hip_compile_options
options
);
operation
compile_hip_code_object
(
const
std
::
string
&
content
,
hip_compile_options
options
);
std
::
size_t
compute_block_size
(
std
::
size_t
n
,
std
::
size_t
max_block_size
=
1024
);
}
// namespace gpu
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
}
// namespace migraphx
...
...
src/targets/gpu/jit/pointwise.cpp
View file @
d60364a3
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/cpp_generator.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/ranges.hpp>
...
@@ -17,6 +18,8 @@ namespace migraphx {
...
@@ -17,6 +18,8 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
namespace
gpu
{
using
namespace
migraphx
::
gpu
::
gen
;
static
const
char
*
const
pointwise_kernel
=
R"__migraphx__(
static
const
char
*
const
pointwise_kernel
=
R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/pointwise.hpp>
#include <migraphx/kernels/pointwise.hpp>
...
@@ -30,7 +33,7 @@ extern "C" {
...
@@ -30,7 +33,7 @@ extern "C" {
__global__ void kernel(${params})
__global__ void kernel(${params})
{
{
auto idx = make_index();
auto idx = make_index();
pointwise(idx,
auto_preload<${preloads}>(idx), vectorize<${vec_size}, ${axis}>()
)(${lambda}, ${args});
pointwise(idx,
${transformers}
)(${lambda}, ${args});
}
}
}
}
...
@@ -50,75 +53,6 @@ struct pointwise_compiler : compiler<pointwise_compiler>
...
@@ -50,75 +53,6 @@ struct pointwise_compiler : compiler<pointwise_compiler>
else
else
return
1
;
return
1
;
}
}
static
std
::
size_t
find_fast_axis
(
const
std
::
vector
<
shape
>&
inputs
)
{
auto
permutation
=
find_permutation
(
inputs
);
auto
it
=
std
::
max_element
(
permutation
.
begin
(),
permutation
.
end
());
return
it
-
permutation
.
begin
();
}
static
std
::
vector
<
bool
>
preload
(
std
::
size_t
axis
,
const
std
::
vector
<
shape
>&
inputs
)
{
const
std
::
size_t
max_lds_bytes
=
4096
;
std
::
vector
<
bool
>
result
;
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
std
::
back_inserter
(
result
),
[
&
](
const
shape
&
input
)
{
return
input
.
strides
()[
axis
]
==
0
;
});
auto
bytes
=
std
::
inner_product
(
inputs
.
begin
(),
inputs
.
end
(),
result
.
begin
(),
std
::
size_t
{
0
},
std
::
plus
<>
{},
[](
const
shape
&
s
,
bool
b
)
->
std
::
size_t
{
if
(
b
)
return
s
.
bytes
();
return
0
;
});
if
(
bytes
<
max_lds_bytes
)
return
result
;
// TODO: Try to partially preload items
std
::
fill
(
result
.
begin
(),
result
.
end
(),
false
);
return
result
;
}
static
std
::
string
preload_str
(
const
std
::
vector
<
bool
>&
bs
)
{
std
::
vector
<
std
::
string
>
bool_strs
;
std
::
transform
(
bs
.
begin
(),
std
::
prev
(
bs
.
end
()),
std
::
back_inserter
(
bool_strs
),
[](
bool
b
)
{
if
(
b
)
return
"true"
;
return
"false"
;
});
return
"false, "
+
join_strings
(
bool_strs
,
", "
);
}
static
std
::
vector
<
std
::
size_t
>
vector_sizes
(
const
std
::
vector
<
shape
>&
inputs
)
{
// If all inputs is half then only use half2
if
(
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[](
const
auto
&
s
)
{
return
s
.
type
()
==
shape
::
half_type
;
}))
return
{
2
};
return
{
4
,
2
};
}
static
auto
vectorize_elements
(
std
::
size_t
axis
,
const
std
::
vector
<
shape
>&
inputs
)
{
auto
sizes
=
vector_sizes
(
inputs
);
std
::
vector
<
std
::
size_t
>
max_vec_size
;
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
std
::
back_inserter
(
max_vec_size
),
[
&
](
const
auto
&
input
)
->
std
::
size_t
{
auto
stride
=
input
.
strides
()[
axis
];
auto
len
=
input
.
lens
()[
axis
];
if
(
stride
!=
0
and
stride
!=
1
)
return
1
;
auto
it
=
std
::
find_if
(
sizes
.
begin
(),
sizes
.
end
(),
[
&
](
auto
i
)
{
return
(
len
%
i
)
==
0
;
});
if
(
it
!=
sizes
.
end
())
return
*
it
;
return
1
;
});
return
*
std
::
min_element
(
max_vec_size
.
begin
(),
max_vec_size
.
end
());
}
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
{
{
hip_compile_options
options
;
hip_compile_options
options
;
...
@@ -127,21 +61,17 @@ struct pointwise_compiler : compiler<pointwise_compiler>
...
@@ -127,21 +61,17 @@ struct pointwise_compiler : compiler<pointwise_compiler>
options
.
virtual_inputs
=
reduce_dims
(
inputs
);
options
.
virtual_inputs
=
reduce_dims
(
inputs
);
options
.
params
=
"-Wno-float-equal"
;
options
.
params
=
"-Wno-float-equal"
;
auto
axis
=
find_fast_axis
(
options
.
virtual_inputs
);
auto
axis
=
find_fast_axis
(
options
.
virtual_inputs
);
auto
vec_size
=
vectorize_elements
(
axis
,
options
.
virtual_inputs
);
auto
vec
=
vectorize
::
elements
(
axis
,
options
.
virtual_inputs
);
auto
preloads
=
preload
(
axis
,
options
.
virtual_inputs
);
auto
preloads
=
preload
::
broadcasts
(
axis
,
inputs
);
auto
is_preloading
=
std
::
accumulate
(
preloads
.
begin
(),
preloads
.
end
(),
false
,
std
::
logical_or
<>
{});
options
.
set_launch_params
(
v
,
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
compute_global_for
(
ctx
,
options
.
output
.
elements
()
/
vec
_
size
,
options
.
output
.
elements
()
/
vec
.
size
,
oversubscribe_if
(
not
is_preloading
)));
oversubscribe_if
(
not
preloads
.
is_preloading
()
)));
auto
src
=
interpolate_string
(
pointwise_kernel
,
auto
src
=
interpolate_string
(
pointwise_kernel
,
{{
"params"
,
enum_params
(
inputs
.
size
(),
"void * private_p"
)},
{{
"params"
,
enum_params
(
inputs
.
size
(),
"void * private_p"
)},
{
"args"
,
enum_params
(
inputs
.
size
(),
"private_p"
)},
{
"args"
,
enum_params
(
inputs
.
size
(),
"private_p"
)},
{
"lambda"
,
v
.
at
(
"lambda"
).
to
<
std
::
string
>
()},
{
"lambda"
,
v
.
at
(
"lambda"
).
to
<
std
::
string
>
()},
{
"vec_size"
,
std
::
to_string
(
vec_size
)},
{
"transformers"
,
make_transformer_args
(
preloads
,
vec
)},
{
"axis"
,
std
::
to_string
(
axis
)},
{
"preloads"
,
preload_str
(
preloads
)},
{
"preamble"
,
v
.
get
(
"preamble"
,
std
::
string
{})}});
{
"preamble"
,
v
.
get
(
"preamble"
,
std
::
string
{})}});
return
compile_hip_code_object
(
src
,
options
);
return
compile_hip_code_object
(
src
,
options
);
}
}
...
...
src/targets/gpu/jit/reduce.cpp
View file @
d60364a3
...
@@ -40,14 +40,6 @@ __global__ void kernel(void* input_p, void* output_p)
...
@@ -40,14 +40,6 @@ __global__ void kernel(void* input_p, void* output_p)
)__migraphx__"
;
)__migraphx__"
;
constexpr
std
::
size_t
compute_block_size
(
std
::
size_t
n
,
std
::
size_t
max_block_size
=
1024
)
{
size_t
block_size
=
128
;
while
(
block_size
<=
max_block_size
and
block_size
<=
n
)
block_size
*=
2
;
return
block_size
/
2
;
}
static
std
::
size_t
get_reduce_elements
(
const
std
::
vector
<
shape
>&
inputs
)
static
std
::
size_t
get_reduce_elements
(
const
std
::
vector
<
shape
>&
inputs
)
{
{
return
inputs
.
front
().
elements
()
/
inputs
.
back
().
elements
();
return
inputs
.
front
().
elements
()
/
inputs
.
back
().
elements
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment