Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
baac1dab
"docs/zh_CN/FeatureEngineering/GBDTSelector.rst" did not exist on "abc221589c65d75b494407c60a81ca87c3020463"
Commit
baac1dab
authored
May 24, 2023
by
Alan Turner
Browse files
Merge remote-tracking branch 'origin/develop' into ck-host-lib
parents
830dff7a
77042e30
Changes
299
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
225 additions
and
90 deletions
+225
-90
src/targets/gpu/include/migraphx/gpu/compile_hip_code_object.hpp
...gets/gpu/include/migraphx/gpu/compile_hip_code_object.hpp
+2
-0
src/targets/gpu/include/migraphx/gpu/context.hpp
src/targets/gpu/include/migraphx/gpu/context.hpp
+9
-2
src/targets/gpu/include/migraphx/gpu/convolution.hpp
src/targets/gpu/include/migraphx/gpu/convolution.hpp
+4
-4
src/targets/gpu/include/migraphx/gpu/fuse_mlir.hpp
src/targets/gpu/include/migraphx/gpu/fuse_mlir.hpp
+2
-0
src/targets/gpu/include/migraphx/gpu/hip.hpp
src/targets/gpu/include/migraphx/gpu/hip.hpp
+22
-5
src/targets/gpu/include/migraphx/gpu/lowering.hpp
src/targets/gpu/include/migraphx/gpu/lowering.hpp
+7
-0
src/targets/gpu/include/migraphx/gpu/miopen.hpp
src/targets/gpu/include/migraphx/gpu/miopen.hpp
+11
-4
src/targets/gpu/include/migraphx/gpu/name.hpp
src/targets/gpu/include/migraphx/gpu/name.hpp
+0
-1
src/targets/gpu/include/migraphx/gpu/oper.hpp
src/targets/gpu/include/migraphx/gpu/oper.hpp
+0
-1
src/targets/gpu/include/migraphx/gpu/prefix_scan_sum.hpp
src/targets/gpu/include/migraphx/gpu/prefix_scan_sum.hpp
+0
-1
src/targets/gpu/include/migraphx/gpu/reduce_op.hpp
src/targets/gpu/include/migraphx/gpu/reduce_op.hpp
+0
-1
src/targets/gpu/include/migraphx/gpu/target.hpp
src/targets/gpu/include/migraphx/gpu/target.hpp
+0
-1
src/targets/gpu/jit/concat.cpp
src/targets/gpu/jit/concat.cpp
+3
-1
src/targets/gpu/jit/mlir.cpp
src/targets/gpu/jit/mlir.cpp
+1
-1
src/targets/gpu/jit/reduce.cpp
src/targets/gpu/jit/reduce.cpp
+135
-45
src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp
...rgets/gpu/kernels/include/migraphx/kernels/functional.hpp
+8
-0
src/targets/gpu/kernels/include/migraphx/kernels/hip.hpp
src/targets/gpu/kernels/include/migraphx/kernels/hip.hpp
+0
-8
src/targets/gpu/kernels/include/migraphx/kernels/index.hpp
src/targets/gpu/kernels/include/migraphx/kernels/index.hpp
+6
-0
src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp
...argets/gpu/kernels/include/migraphx/kernels/layernorm.hpp
+10
-2
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
+5
-13
No files found.
src/targets/gpu/include/migraphx/gpu/compile_hip_code_object.hpp
View file @
baac1dab
...
...
@@ -72,6 +72,8 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option
std
::
size_t
compute_block_size
(
std
::
size_t
n
,
std
::
size_t
max_block_size
=
1024
);
std
::
string
generate_make_shape
(
const
shape
&
s
);
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
...
...
src/targets/gpu/include/migraphx/gpu/context.hpp
View file @
baac1dab
...
...
@@ -30,6 +30,7 @@
#include <migraphx/gpu/hip.hpp>
#include <migraphx/env.hpp>
#include <migraphx/config.hpp>
#include <migraphx/gpu/device_name.hpp>
#include <unordered_map>
#include <memory>
...
...
@@ -215,6 +216,10 @@ struct context
return
*
current_device
;
}
bool
get_exhaustive_tune_flag
()
const
{
return
exhaustive_tune
;
}
void
set_exhaustive_tune_flag
(
bool
t
)
{
exhaustive_tune
=
t
;
}
hip_device
::
stream
&
get_stream
()
{
return
get_current_device
().
get_stream
();
}
hip_device
::
stream
&
get_stream
(
std
::
size_t
n
)
{
return
get_current_device
().
get_stream
(
n
);
}
...
...
@@ -273,7 +278,8 @@ struct context
auto
v_streams
=
v
.
at
(
"streams"
);
std
::
size_t
n_streams
=
v_streams
.
without_key
().
to
<
std
::
size_t
>
();
this
->
current_device
=
std
::
make_shared
<
hip_device
>
(
0
,
n_streams
);
auto
device
=
get_device_id
();
this
->
current_device
=
std
::
make_shared
<
hip_device
>
(
device
,
n_streams
);
}
void
wait_for
(
any_ptr
queue
)
...
...
@@ -336,7 +342,8 @@ struct context
// TODO: Make this a vector to support multiple devices
std
::
shared_ptr
<
hip_device
>
current_device
;
std
::
vector
<
shared
<
hip_event_ptr
>>
events
;
bool
measure_perf
=
false
;
bool
exhaustive_tune
=
false
;
bool
measure_perf
=
false
;
// for event perf timing
shared
<
hip_event_ptr
>
start_event
=
nullptr
;
shared
<
hip_event_ptr
>
stop_event
=
nullptr
;
...
...
src/targets/gpu/include/migraphx/gpu/convolution.hpp
View file @
baac1dab
...
...
@@ -27,7 +27,6 @@
#include <migraphx/shape.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/operation.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/op/identity.hpp>
#include <migraphx/op/convolution.hpp>
...
...
@@ -175,8 +174,9 @@ struct miopen_convolution
auto
*
miopen_stream_handle
=
ctx
.
get_stream
().
get_miopen
();
solution_ptr
=
find_solution
(
miopen_stream_handle
,
conv_problem
.
get
());
auto
status
=
miopenGetSolutionWorkspaceSize
(
solution_ptr
.
get
(),
&
workspace_size
);
solution_ptr
=
find_solution
(
miopen_stream_handle
,
conv_problem
.
get
(),
ctx
.
get_exhaustive_tune_flag
());
auto
status
=
miopenGetSolutionWorkspaceSize
(
solution_ptr
.
get
(),
&
workspace_size
);
if
(
status
!=
miopenStatusSuccess
)
MIGRAPHX_THROW
(
"MIOpen"
+
op
.
name
()
+
" : failed to get solution's workspace size"
);
...
...
@@ -233,7 +233,7 @@ struct miopen_convolution
&
perf
,
workspace
.
implicit
(),
workspace_size
,
false
);
ctx
.
get_exhaustive_tune_flag
()
);
if
(
status
!=
miopenStatusSuccess
)
MIGRAPHX_THROW
(
"MIOpen "
+
op
.
name
()
+
" : find convolution failed"
);
algo
=
perf
.
fwd_algo
;
...
...
src/targets/gpu/include/migraphx/gpu/fuse_mlir.hpp
View file @
baac1dab
...
...
@@ -34,6 +34,8 @@ struct module_pass_manager;
namespace
gpu
{
bool
mlir_enabled
();
struct
fuse_mlir
{
context
*
ctx
=
nullptr
;
...
...
src/targets/gpu/include/migraphx/gpu/hip.hpp
View file @
baac1dab
...
...
@@ -29,6 +29,7 @@
#include <migraphx/literal.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/dyn_output.hpp>
#include <utility>
namespace
migraphx
{
...
...
@@ -98,6 +99,13 @@ struct hip_sync_stream
return
{};
return
args
.
front
();
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
args
)
const
{
if
(
args
.
empty
())
return
-
1
;
return
0
;
}
};
struct
hip_copy_to_gpu
...
...
@@ -105,7 +113,7 @@ struct hip_copy_to_gpu
std
::
string
name
()
const
{
return
"hip::copy_to_gpu"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
,
2
).
same_type
();
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
,
2
).
same_type
();
return
inputs
.
at
(
0
);
}
argument
compute
(
context
&
ctx
,
const
shape
&
,
const
std
::
vector
<
argument
>&
args
)
const
...
...
@@ -114,6 +122,10 @@ struct hip_copy_to_gpu
if
(
args
.
size
()
==
1
)
return
input
;
argument
result
=
args
[
1
].
share
();
if
(
result
.
get_shape
().
dynamic
())
{
result
=
result
.
reshape
(
args
[
0
].
get_shape
());
}
gpu_copy
(
ctx
,
input
,
result
);
// Associate the input since it was registered with hip
return
{
result
.
get_shape
(),
[
input
,
result
]()
mutable
{
return
result
.
data
();
}};
...
...
@@ -131,19 +143,24 @@ struct hip_copy_from_gpu
std
::
string
name
()
const
{
return
"hip::copy_from_gpu"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
,
2
).
same_type
();
check_shapes
{
inputs
,
*
this
,
true
}.
has
(
1
,
2
).
same_type
();
return
inputs
.
at
(
0
);
}
argument
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
compute
(
context
&
ctx
,
const
dyn_output
&
dyn_out
,
const
std
::
vector
<
argument
>&
args
)
const
{
if
(
args
.
size
()
==
1
)
{
argument
result
=
allocate_gpu
(
out
put_shape
,
true
);
argument
result
=
allocate_gpu
(
dyn_out
.
com
put
ed
_shape
,
true
);
gpu_copy
(
ctx
,
args
[
0
],
result
);
return
result
;
}
copy_from_gpu
(
ctx
,
args
[
0
],
args
[
1
]);
argument
input
=
args
[
0
].
share
();
if
(
input
.
get_shape
().
dynamic
())
{
input
=
input
.
reshape
(
args
[
1
].
get_shape
());
}
copy_from_gpu
(
ctx
,
input
,
args
[
1
]);
return
args
[
1
];
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
args
)
const
...
...
src/targets/gpu/include/migraphx/gpu/lowering.hpp
View file @
baac1dab
...
...
@@ -33,6 +33,13 @@ inline namespace MIGRAPHX_INLINE_NS {
struct
module
;
namespace
gpu
{
/**
* Compiler pass that makes GPU-specific instruction changes.
* * Copies to and from the device if `offload_copy` is true.
* * Maps instructions to their GPU-specific counterparts.
* * Inserts `allocate` instructions before GPU operators.
*/
struct
lowering
{
context
*
ctx
;
...
...
src/targets/gpu/include/migraphx/gpu/miopen.hpp
View file @
baac1dab
...
...
@@ -75,12 +75,19 @@ using miopen_find_options = MIGRAPHX_MANAGE_PTR(miopenFindOptions_t, miopenDestr
using
miopen_problem
=
MIGRAPHX_MANAGE_PTR
(
miopenProblem_t
,
miopenDestroyProblem
);
using
miopen_solution
=
MIGRAPHX_MANAGE_PTR
(
miopenSolution_t
,
miopenDestroySolution
);
inline
miopen_solution
find_solution
(
miopenHandle_t
handle
,
miopenProblem_t
problem
)
inline
miopen_solution
find_solution
(
miopenHandle_t
handle
,
miopenProblem_t
problem
,
bool
tune
=
false
)
{
miopenSolution_t
solution
;
size_t
found
=
0
;
auto
status
=
miopenFindSolutions
(
handle
,
problem
,
nullptr
,
&
solution
,
&
found
,
1
);
auto
result
=
miopen_solution
{
solution
};
size_t
found
=
0
;
miopen_find_options
fo
=
nullptr
;
if
(
tune
)
{
fo
=
make_obj
<
miopen_find_options
>
(
&
miopenCreateFindOptions
);
miopenSetFindOptionTuning
(
fo
.
get
(),
1
);
}
auto
status
=
miopenFindSolutions
(
handle
,
problem
,
fo
.
get
(),
&
solution
,
&
found
,
1
);
auto
result
=
miopen_solution
{
solution
};
if
(
status
!=
miopenStatusSuccess
or
found
==
0
)
MIGRAPHX_THROW
(
"MIOpen miopenFindSolutions failed"
);
return
result
;
...
...
src/targets/gpu/include/migraphx/gpu/name.hpp
View file @
baac1dab
...
...
@@ -56,7 +56,6 @@ struct oper
return
name
.
substr
(
pos_ns
+
2
);
}
}
return
"unknown_operator_name"
;
}
};
...
...
src/targets/gpu/include/migraphx/gpu/oper.hpp
View file @
baac1dab
...
...
@@ -31,7 +31,6 @@
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/type_name.hpp>
#include <utility>
#include <iostream>
...
...
src/targets/gpu/include/migraphx/gpu/prefix_scan_sum.hpp
View file @
baac1dab
...
...
@@ -33,7 +33,6 @@
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <migraphx/type_name.hpp>
#include <utility>
#include <iostream>
...
...
src/targets/gpu/include/migraphx/gpu/reduce_op.hpp
View file @
baac1dab
...
...
@@ -31,7 +31,6 @@
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <migraphx/type_name.hpp>
#include <utility>
#include <iostream>
...
...
src/targets/gpu/include/migraphx/gpu/target.hpp
View file @
baac1dab
...
...
@@ -37,7 +37,6 @@ struct target
std
::
string
name
()
const
;
std
::
vector
<
pass
>
get_passes
(
migraphx
::
context
&
gctx
,
const
compile_options
&
options
)
const
;
migraphx
::
context
get_context
()
const
;
argument
copy_to
(
const
argument
&
arg
)
const
;
argument
copy_from
(
const
argument
&
arg
)
const
;
argument
allocate
(
const
shape
&
s
)
const
;
...
...
src/targets/gpu/jit/concat.cpp
View file @
baac1dab
...
...
@@ -78,7 +78,9 @@ struct concat_compiler : compiler<concat_compiler>
options
.
params
=
"-Wno-float-equal"
;
options
.
kernel_name
=
v
.
get
(
"kernel"
,
"concat_kernel"
);
auto
axis
=
find_fast_axis
(
options
.
inputs
);
auto
vec
=
vectorize
::
elements
(
ctx
,
axis
,
options
.
inputs
);
vectorize
vec
{};
if
(
axis
!=
v
.
at
(
"axis"
).
to
<
std
::
size_t
>
())
vec
=
vectorize
::
elements
(
ctx
,
axis
,
options
.
inputs
);
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
get_concat_elements
(
options
.
inputs
)
/
vec
.
size
,
256
));
auto
src
=
interpolate_string
(
...
...
src/targets/gpu/jit/mlir.cpp
View file @
baac1dab
...
...
@@ -32,7 +32,7 @@ namespace gpu {
struct
mlir_compiler
:
compiler
<
mlir_compiler
>
{
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"gpu::mlir_
conv
"
};
}
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"gpu::mlir_
op
"
};
}
operation
compile_op
(
context
&
,
const
std
::
vector
<
shape
>&
,
const
value
&
)
const
{
return
{};
}
...
...
src/targets/gpu/jit/reduce.cpp
View file @
baac1dab
...
...
@@ -60,15 +60,6 @@ __global__ void reduce_kernel(void* input_p, void* output_p)
)__migraphx__"
;
static
std
::
size_t
get_reduce_elements
(
const
std
::
vector
<
shape
>&
inputs
)
{
return
inputs
.
front
().
elements
()
/
inputs
.
back
().
elements
();
}
static
std
::
size_t
get_reduce_elements
(
const
std
::
vector
<
instruction_ref
>&
inputs
)
{
return
get_reduce_elements
(
to_shapes
(
inputs
));
}
static
std
::
vector
<
std
::
size_t
>
get_reduce_lens
(
const
std
::
vector
<
std
::
size_t
>&
input_lens
,
const
std
::
vector
<
std
::
size_t
>&
output_lens
)
{
...
...
@@ -86,9 +77,28 @@ static std::vector<std::size_t> get_reduce_lens(const std::vector<std::size_t>&
return
reduce_lens
;
}
static
std
::
string
get_reduce_algo
(
const
std
::
vector
<
shape
>&
inputs
)
template
<
class
T
>
static
shape
get_reduced_shape
(
const
shape
&
s
,
const
std
::
vector
<
T
>&
axes
)
{
auto
lens
=
s
.
lens
();
std
::
fill
(
lens
.
begin
(),
lens
.
end
(),
1
);
for
(
const
auto
&
axis
:
axes
)
lens
[
axis
]
=
s
.
lens
()[
axis
];
return
shape
{
s
.
type
(),
lens
};
}
template
<
class
T
>
static
shape
get_output_shape
(
const
shape
&
s
,
const
std
::
vector
<
T
>&
axes
)
{
auto
lens
=
s
.
lens
();
for
(
const
auto
&
axis
:
axes
)
lens
[
axis
]
=
1
;
return
shape
{
s
.
type
(),
lens
};
}
template
<
class
ReduceLens
>
static
std
::
string
get_reduce_algo
(
const
std
::
vector
<
shape
>&
inputs
,
ReduceLens
rlens
)
{
auto
rlens
=
get_reduce_lens
(
inputs
.
front
().
lens
(),
inputs
.
back
().
lens
());
const
auto
init
=
std
::
numeric_limits
<
std
::
size_t
>::
max
();
// The minimum stride
auto
min_stride
=
std
::
inner_product
(
...
...
@@ -103,11 +113,27 @@ static std::string get_reduce_algo(const std::vector<shape>& inputs)
return
"block"
;
}
struct
reduce_compiler
:
compiler
<
reduce_compiler
>
static
std
::
string
get_reduce_algo
(
const
std
::
vector
<
shape
>&
inputs
)
{
auto
rlens
=
get_reduce_lens
(
inputs
.
front
().
lens
(),
inputs
.
back
().
lens
());
return
get_reduce_algo
(
inputs
,
rlens
);
}
struct
simple_reduce_compiler
:
compiler
<
simple_reduce_compiler
>
{
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"reduce"
,
"reduce_sum"
,
"reduce_mean"
,
"reduce_max"
,
"reduce_min"
,
"reduce_prod"
};
return
{
"simple_reduce"
,
"reduce_sum"
,
"reduce_mean"
,
"reduce_max"
,
"reduce_min"
,
"reduce_prod"
};
}
static
std
::
size_t
get_reduce_elements
(
const
std
::
vector
<
shape
>&
inputs
)
{
return
inputs
.
front
().
elements
()
/
inputs
.
back
().
elements
();
}
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
...
...
@@ -127,7 +153,7 @@ struct reduce_compiler : compiler<reduce_compiler>
vec
=
vectorize
::
elements
(
ctx
,
faxis
,
options
.
virtual_inputs
);
auto
relements
=
get_reduce_elements
(
options
.
virtual_inputs
)
/
vec
.
size
;
auto
block_size
=
compute_block_size
(
relements
,
256
);
if
(
relements
>
block_size
*
256
)
if
(
relements
>
=
block_size
*
256
)
algo
=
"block_large"
;
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
nelements
*
block_size
,
256
),
block_size
);
...
...
@@ -157,44 +183,108 @@ struct reduce_compiler : compiler<reduce_compiler>
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
const
{
value
v
=
value
::
object
{};
if
(
op
.
name
()
==
"reduce_sum"
)
{
v
[
"reduction"
]
=
"op::sum{}"
;
}
else
if
(
op
.
name
()
==
"reduce_mean"
)
{
auto
reduce_elements
=
get_reduce_elements
(
ins
->
inputs
());
auto
reduce_type
=
ins
->
inputs
().
front
()
->
get_shape
().
type
();
v
[
"reduction"
]
=
"op::sum{}"
;
std
::
string
mean
=
"op::mean<"
+
std
::
to_string
(
reduce_elements
)
+
">{}"
;
// Use float accumulator when reduction size is too large for half
if
(
reduce_type
==
shape
::
half_type
and
reduce_elements
>
16384
)
v
[
"read"
]
=
"compose("
+
mean
+
", op::convert_to<float>{})"
;
else
if
(
contains
({
shape
::
float_type
,
shape
::
half_type
,
shape
::
double_type
},
reduce_type
))
v
[
"read"
]
=
mean
;
else
v
[
"write"
]
=
mean
;
}
else
if
(
op
.
name
()
==
"reduce_max"
)
{
v
[
"reduction"
]
=
"op::max{}"
;
v
[
"init"
]
=
"lowest{}"
;
}
else
if
(
op
.
name
()
==
"reduce_min"
)
reduce_op
r
{};
r
.
set
(
ins
,
op
);
v
[
"reduction"
]
=
r
.
reduction
;
v
[
"read"
]
=
r
.
read
;
v
[
"write"
]
=
r
.
write
;
v
[
"init"
]
=
r
.
init
;
return
replace
(
compile_op
(
ctx
,
to_shapes
(
ins
->
inputs
()),
v
));
}
};
static
const
char
*
const
fused_reduce_kernel
=
R"__migraphx__(
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/reduce.hpp>
#include <migraphx/kernels/pointwise.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <args.hpp>
namespace migraphx {
${preamble}
extern "C" {
MIGRAPHX_GLOBAL void ${kernel}(${params})
{
transform_args(make_tensors(), rotate_last(), ${transformers})(${args})([](auto y, auto... xs) {
fused_reduce<reduce::${algo}, ${reduced}>(y, partial(${lambda})(xs...));
});
}
}
} // namespace migraphx
)__migraphx__"
;
struct
fused_reduce_compiler
:
compiler
<
fused_reduce_compiler
>
{
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"fused_reduce"
};
}
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
{
auto
axes
=
v
.
at
(
"axes"
).
to_vector
<
std
::
size_t
>
();
auto
virtual_inputs
=
inputs
;
virtual_inputs
.
push_back
(
get_reduced_shape
(
inputs
.
front
(),
axes
));
virtual_inputs
.
push_back
(
get_output_shape
(
inputs
.
front
(),
axes
));
virtual_inputs
=
reduce_dims
(
virtual_inputs
);
auto
reduce_output_shape
=
virtual_inputs
.
back
();
virtual_inputs
.
pop_back
();
auto
reduction_shape
=
virtual_inputs
.
back
();
virtual_inputs
.
pop_back
();
hip_compile_options
options
;
options
.
inputs
=
inputs
;
options
.
output
=
inputs
.
back
();
options
.
virtual_inputs
=
virtual_inputs
;
auto
faxis
=
find_fast_axis
({
options
.
virtual_inputs
.
front
()});
vectorize
vec
{};
auto
nelements
=
reduce_output_shape
.
elements
();
auto
algo
=
v
.
get
(
"algo"
,
get_reduce_algo
(
options
.
virtual_inputs
,
reduction_shape
.
lens
()));
if
(
algo
==
"block"
)
{
v
[
"reduction"
]
=
"op::min{}"
;
v
[
"init"
]
=
"highest{}"
;
// Vectorize if the axis is a reduction axis
if
(
reduce_output_shape
.
lens
()[
faxis
]
==
1
)
vec
=
vectorize
::
elements
(
ctx
,
faxis
,
options
.
virtual_inputs
);
auto
relements
=
reduction_shape
.
elements
()
/
vec
.
size
;
auto
block_size
=
compute_block_size
(
relements
,
256
);
if
(
relements
>=
block_size
*
256
)
algo
=
"block_large"
;
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
nelements
*
block_size
,
256
),
block_size
);
}
else
if
(
op
.
name
()
==
"reduce_prod
"
)
else
if
(
algo
==
"lane
"
)
{
v
[
"reduction"
]
=
"op::product{}"
;
v
[
"init"
]
=
"1"
;
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
nelements
,
256
));
}
else
{
MIGRAPHX_THROW
(
"Un
supported reduce"
);
MIGRAPHX_THROW
(
"Un
known reduce algo: "
+
algo
);
}
options
.
kernel_name
=
v
.
get
(
"kernel"
,
"reduce_kernel"
);
auto
src
=
interpolate_string
(
fused_reduce_kernel
,
{{
"kernel"
,
options
.
kernel_name
},
{
"params"
,
enum_params
(
inputs
.
size
(),
"void * private_p"
)},
{
"args"
,
enum_params
(
inputs
.
size
(),
"private_p"
)},
{
"algo"
,
algo
},
{
"reduced"
,
"decltype("
+
generate_make_shape
(
reduce_output_shape
)
+
")"
},
{
"lambda"
,
v
.
at
(
"lambda"
).
to
<
std
::
string
>
()},
{
"transformers"
,
make_transformer_args
(
vec
)},
{
"preamble"
,
v
.
get
(
"preamble"
,
std
::
string
{})}});
options
.
params
+=
"-Wno-float-equal"
;
return
compile_hip_code_object
(
src
,
options
);
}
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
const
{
assert
(
not
ins
->
module_inputs
().
empty
());
auto
v
=
op
.
to_value
();
auto
*
rm
=
ins
->
module_inputs
().
front
();
v
[
"preamble"
]
=
generate_reduce
(
*
rm
,
"fused_reduce_op"
);
v
[
"lambda"
]
=
"MIGRAPHX_LIFT(fused_reduce_op)"
;
v
[
"kernel"
]
=
generate_name_from_ops
(
*
rm
)
+
"_kernel"
;
return
replace
(
compile_op
(
ctx
,
to_shapes
(
ins
->
inputs
()),
v
));
}
};
...
...
src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp
View file @
baac1dab
...
...
@@ -204,6 +204,14 @@ constexpr auto compose(Fs... fs)
})(
fs
...);
}
template
<
class
F
>
constexpr
auto
partial
(
F
f
)
{
return
[
=
](
auto
...
xs
)
{
return
[
=
](
auto
&&
...
ys
)
{
return
f
(
xs
...,
static_cast
<
decltype
(
ys
)
>
(
ys
)...);
};
};
}
template
<
class
...
Ts
>
constexpr
auto
pack
(
Ts
...
xs
)
{
...
...
src/targets/gpu/kernels/include/migraphx/kernels/hip.hpp
View file @
baac1dab
...
...
@@ -25,17 +25,9 @@
#define MIGRAPHX_GUARD_KERNELS_HIP_HPP
#ifndef MIGRAPHX_USE_HIPRTC
// Workaround macro redefinition issue with clang tidy
#if defined(__HIP_PLATFORM_HCC__) && defined(MIGRAPHX_USE_CLANG_TIDY)
#undef __HIP_PLATFORM_HCC__ // NOLINT
#endif
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <hip/math_functions.h>
#include <hip/hip_math_constants.h>
#elif defined(MIGRAPHX_ENABLE_HIPRTC_WORKAROUNDS)
#include <hip/hip_common.h>
#include <hip/hip_math_constants.h>
#endif
#endif // MIGRAPHX_GUARD_KERNELS_HIP_HPP
src/targets/gpu/kernels/include/migraphx/kernels/index.hpp
View file @
baac1dab
...
...
@@ -241,6 +241,12 @@ struct index
}
};
#ifdef MIGRAPHX_NLOCAL
#define MIGRAPHX_GLOBAL \
__global__ __attribute__((amdgpu_flat_work_group_size(MIGRAPHX_NLOCAL, MIGRAPHX_NLOCAL)))
#else
#define MIGRAPHX_GLOBAL __global__
#endif
inline
__device__
__attribute__
((
const
))
index
make_index
()
{
return
index
{
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
,
threadIdx
.
x
,
blockIdx
.
x
};
// NOLINT
...
...
src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp
View file @
baac1dab
...
...
@@ -48,12 +48,20 @@ __device__ void generic_binary_layernorm(
{
using
block
=
reduce
::
auto_block
<
reduce
::
reduce_elements_with_axis
<
Input1
,
Axis
>
()
>
;
using
reduce_output
=
reduce
::
with_axis
<
Input1
,
Axis
>
;
block
::
template
run
<
reduce_output
>([
&
](
auto
,
auto
r
)
{
auto
input
=
r
.
inner
([
&
](
auto
x1
,
auto
x2
)
{
return
op
(
x1
,
x2
);
})(
input1
,
input2
);
using
value_type
=
typename
Input1
::
type
;
constexpr
auto
relements
=
r
.
template
elements
<
Input1
>();
constexpr
auto
relements
=
r
.
template
elements
<
Input1
>();
constexpr
auto
relements_r
=
vec_type
<
value_type
>
{
1.0
/
relements
};
auto
relements_rsqrt
=
sqrt
(
relements_r
);
auto
means
=
r
.
reduce
(
op
::
sum
{},
make_array
<
vec_type
<
value_type
>>
(
0
,
0
),
[
&
](
auto
x
)
{
return
make_array
(
x
,
x
*
x
)
*
vec_type
<
value_type
>
{
1.0
/
relements
};
auto
x_out
=
x
*
relements_r
;
// dividing x by sqrt(relements) before squaring allows computing higher values
// before overflow in low precision
auto
x2_sqrt
=
x
*
relements_rsqrt
;
return
make_array
(
x_out
,
x2_sqrt
*
x2_sqrt
);
})(
input
);
auto
mean_x
=
means
[
0
];
...
...
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
View file @
baac1dab
...
...
@@ -138,7 +138,7 @@ MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, floor, ::hfloor)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
isnan
,
::
__hisnan
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
log
,
::
hlog
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
rsqrt
,
::
hrsqrt
)
//
MIGRAPHX_DEVICE_MATH_FOR(migraphx::half, sin, ::hsin)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
sin
,
::
hsin
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
sqrt
,
::
hsqrt
)
// Use float to compute half overload
...
...
@@ -161,8 +161,7 @@ MIGRAPHX_DEVICE_MATH_HALF(fmod, ::fmod)
// Map math functions to hip half2 functions
// The half2 type is defined in include/hip/amd_detail/hip_fp16_gcc.h and is 2 16-bit floats
// packed into a 32-bit number. See include/hip/amd_detail/hip_fp16_math_fwd.h for the HIP names
// Most but not all of these math ops have operators of the same names. Ones not yet implemented
// at this time are: exp2, exp10, log2, log10, isinf
// Most but not all of these math ops have operators of the same names.
MIGRAPHX_DEVICE_MATH_HALF2
(
abs
,
::
__habs2
)
MIGRAPHX_DEVICE_MATH_HALF2
(
ceil
,
::
h2ceil
)
MIGRAPHX_DEVICE_MATH_HALF2
(
cos
,
::
h2cos
)
...
...
@@ -176,7 +175,7 @@ MIGRAPHX_DEVICE_MATH_HALF2(log, ::h2log)
MIGRAPHX_DEVICE_MATH_HALF2
(
log10
,
::
h2log10
)
MIGRAPHX_DEVICE_MATH_HALF2
(
log2
,
::
h2log2
)
MIGRAPHX_DEVICE_MATH_HALF2
(
rsqrt
,
::
h2rsqrt
)
//
MIGRAPHX_DEVICE_MATH_HALF2(sin, ::h2sin)
MIGRAPHX_DEVICE_MATH_HALF2
(
sin
,
::
h2sin
)
MIGRAPHX_DEVICE_MATH_HALF2
(
sqrt
,
::
h2sqrt
)
template
<
class
T
,
class
U
>
...
...
@@ -189,7 +188,8 @@ MIGRAPHX_DEVICE_MATH_BINARY_FOR(float, max, ::max)
MIGRAPHX_DEVICE_MATH_BINARY_FOR
(
float
,
min
,
::
min
)
MIGRAPHX_DEVICE_MATH_BINARY_FOR
(
double
,
max
,
::
max
)
MIGRAPHX_DEVICE_MATH_BINARY_FOR
(
double
,
min
,
::
min
)
// Add overloads for half that calls the float version
// Add overloads for half that calls the float version, this should use "hmax" and "hmin" once
// perf CI docker is upgraded to rocm-5.5
MIGRAPHX_DEVICE_MATH_BINARY_FOR
(
migraphx
::
half
,
max
,
::
fmaxf
)
MIGRAPHX_DEVICE_MATH_BINARY_FOR
(
migraphx
::
half
,
min
,
::
fminf
)
...
...
@@ -217,14 +217,6 @@ constexpr auto min(const T& a, const U& b)
return
min
<
common_type_t
<
T
,
U
>>
(
a
,
b
);
}
// Sin for half is broken on hip, so use cos instead
template
<
class
T
,
MIGRAPHX_REQUIRES
(
is_same
<
vec_type
<
T
>,
half
>
{})
>
constexpr
T
sin
(
T
x
)
{
constexpr
const
T
shift
=
HIP_PIO2_F
;
return
migraphx
::
cos
(
shift
-
x
);
}
MIGRAPHX_DEVICE_MATH_VEC
(
abs
)
MIGRAPHX_DEVICE_MATH_VEC
(
acos
)
MIGRAPHX_DEVICE_MATH_VEC
(
acosh
)
...
...
Prev
1
…
6
7
8
9
10
11
12
13
14
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment