Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
6f768035
Commit
6f768035
authored
Dec 06, 2023
by
Umang Yadav
Browse files
Merge branch 'rocblas_mlir_fp8' into miopen_fp8
parents
da7717ce
b2542239
Changes
208
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
589 additions
and
356 deletions
+589
-356
src/targets/gpu/device_name.cpp
src/targets/gpu/device_name.cpp
+6
-0
src/targets/gpu/driver/precompile_op.cpp
src/targets/gpu/driver/precompile_op.cpp
+84
-0
src/targets/gpu/fuse_mlir.cpp
src/targets/gpu/fuse_mlir.cpp
+291
-174
src/targets/gpu/gemm_impl.cpp
src/targets/gpu/gemm_impl.cpp
+26
-14
src/targets/gpu/include/migraphx/gpu/compile_hip.hpp
src/targets/gpu/include/migraphx/gpu/compile_hip.hpp
+3
-3
src/targets/gpu/include/migraphx/gpu/device/pad.hpp
src/targets/gpu/include/migraphx/gpu/device/pad.hpp
+0
-48
src/targets/gpu/include/migraphx/gpu/device_name.hpp
src/targets/gpu/include/migraphx/gpu/device_name.hpp
+2
-0
src/targets/gpu/include/migraphx/gpu/fuse_mlir.hpp
src/targets/gpu/include/migraphx/gpu/fuse_mlir.hpp
+2
-1
src/targets/gpu/include/migraphx/gpu/gemm_softmax_gemm.hpp
src/targets/gpu/include/migraphx/gpu/gemm_softmax_gemm.hpp
+4
-0
src/targets/gpu/include/migraphx/gpu/miopen.hpp
src/targets/gpu/include/migraphx/gpu/miopen.hpp
+6
-0
src/targets/gpu/jit/reduce.cpp
src/targets/gpu/jit/reduce.cpp
+14
-15
src/targets/gpu/jit/scatter.hpp
src/targets/gpu/jit/scatter.hpp
+78
-0
src/targets/gpu/jit/scatternd.cpp
src/targets/gpu/jit/scatternd.cpp
+8
-37
src/targets/gpu/kernels/include/migraphx/kernels/bit_cast.hpp
...targets/gpu/kernels/include/migraphx/kernels/bit_cast.hpp
+5
-1
src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp
src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp
+39
-41
src/targets/gpu/kernels/include/migraphx/kernels/gathernd.hpp
...targets/gpu/kernels/include/migraphx/kernels/gathernd.hpp
+13
-13
src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp
...argets/gpu/kernels/include/migraphx/kernels/layernorm.hpp
+3
-4
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
+1
-1
src/targets/gpu/kernels/include/migraphx/kernels/ops.hpp
src/targets/gpu/kernels/include/migraphx/kernels/ops.hpp
+1
-1
src/targets/gpu/kernels/include/migraphx/kernels/pad.hpp
src/targets/gpu/kernels/include/migraphx/kernels/pad.hpp
+3
-3
No files found.
src/targets/gpu/device_name.cpp
View file @
6f768035
...
...
@@ -49,6 +49,12 @@ std::string get_device_name()
return
props
.
gcnArchName
;
}
bool
gfx_has_fp8_intrinsics
()
{
const
auto
device_name
=
trim
(
split_string
(
get_device_name
(),
':'
).
front
());
return
(
starts_with
(
device_name
,
"gfx9"
)
and
device_name
>=
"gfx940"
);
}
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/
eliminate_fp8
.cpp
→
src/
targets/gpu/driver/precompile_op
.cpp
View file @
6f768035
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -21,49 +21,64 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "migraphx/serialize.hpp"
#include <iterator>
#include <utility>
#include <migraphx/eliminate_fp8.hpp>
#include <migraphx/gpu/driver/action.hpp>
#include <migraphx/gpu/time_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/gpu/compile_ops.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/ranges.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
namespace
driver
{
void
eliminate_fp8
::
apply
(
module
&
m
)
const
struct
precompile_op
:
action
<
precompile_op
>
{
for
(
auto
ins
:
iterator_for
(
m
)
)
static
program
create_preop_program
(
const
operation
&
preop
,
std
::
vector
<
shape
>
inputs
)
{
if
(
not
contains
(
op_names
,
ins
->
name
())
or
ins
->
get_shape
().
type
()
!=
migraphx
::
shape
::
fp8e4m3fnuz_type
)
continue
;
migraphx
::
shape
::
type_t
orig_type
=
ins
->
get_shape
().
type
();
std
::
vector
<
instruction_ref
>
orig_inputs
=
ins
->
inputs
();
std
::
vector
<
instruction_ref
>
new_inputs
;
std
::
transform
(
orig_inputs
.
begin
(),
orig_inputs
.
end
(),
std
::
back_inserter
(
new_inputs
),
[
&
](
const
auto
&
i
)
{
return
m
.
insert_instruction
(
ins
,
migraphx
::
make_op
(
"convert"
,
{{
"target_type"
,
migraphx
::
to_value
(
target_type
)}}),
i
);
});
program
p
;
auto
*
mm
=
p
.
get_main_module
();
std
::
vector
<
instruction_ref
>
args
;
inputs
.
pop_back
();
transform
(
inputs
,
range
(
inputs
.
size
()),
std
::
back_inserter
(
args
),
[
&
](
auto
input
,
auto
i
)
{
return
mm
->
add_parameter
(
"x"
+
std
::
to_string
(
i
),
input
);
});
mm
->
add_instruction
(
preop
,
args
);
return
p
;
}
auto
new_ins
=
m
.
insert_instruction
(
ins
,
ins
->
get_operator
(),
{
new_inputs
});
auto
convert_back_ins
=
m
.
insert_instruction
(
ins
,
migraphx
::
make_op
(
"convert"
,
{{
"target_type"
,
migraphx
::
to_value
(
orig_type
)}}),
new_ins
);
m
.
replace_instruction
(
ins
,
convert_back_ins
);
static
operation
get_code_object
(
const
program
&
p
)
{
MIGRAPHX_TIDY_CONST
auto
*
mm
=
p
.
get_main_module
();
auto
it
=
std
::
find_if
(
mm
->
begin
(),
mm
->
end
(),
[](
const
auto
&
ins
)
{
return
(
ins
.
name
()
==
"gpu::code_object"
);
});
if
(
it
==
mm
->
end
())
MIGRAPHX_THROW
(
"Failed to create code object"
);
return
it
->
get_operator
();
}
static
void
apply
(
const
parser
&
p
,
const
value
&
v
)
{
context
ctx
;
auto
inputs
=
p
.
parse_shapes
(
v
.
at
(
"inputs"
));
auto
name
=
v
.
at
(
"name"
).
to
<
std
::
string
>
();
auto
preop
=
make_op
(
name
);
if
(
v
.
contains
(
"fields"
))
preop
.
from_value
(
v
.
at
(
"fields"
));
bool
exhaustive
=
v
.
get
(
"exhaustive"
,
false
);
auto
prog
=
create_preop_program
(
preop
,
inputs
);
run_passes
(
prog
,
{
lowering
{},
compile_ops
{
&
ctx
,
exhaustive
}});
auto
op
=
get_code_object
(
prog
);
auto
t
=
time_op
(
ctx
,
op
,
inputs
,
p
.
get
(
v
,
"iterations"
,
100
));
std
::
cout
<<
preop
<<
": "
<<
t
<<
"ms"
<<
std
::
endl
;
}
}
}
;
}
// namespace driver
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/gpu/fuse_mlir.cpp
View file @
6f768035
...
...
@@ -38,6 +38,18 @@ namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_ENABLE_EXTRA_MLIR
);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_DISABLE_MLIR
);
/**
* @brief Declares a new MIGraphX environment variable which forces to generate
* only specific MLIR operations.
*
* The variable, if defined, forces MIGraphX to use only specific operations
* with MLIR regardless of the underlying GPU architecture. The variable accepts
* a list of operations separated by comma. The variable recognizes the following
* operations: "fused", "convolution", "dot". If the variable is not defined MIGraphX
* will decide by itself which operations to delegate to MLIR. The variable is
* intended to be primarily used by rocMLIR developers.
*/
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_MLIR_USE_SPECIFIC_OPS
);
bool
mlir_enabled
()
{
...
...
@@ -49,6 +61,26 @@ bool mlir_enabled()
#endif
}
static
bool
is_requested
(
std
::
string_view
option
,
bool
fallback
=
false
)
{
auto
string_value
=
string_value_of
(
MIGRAPHX_MLIR_USE_SPECIFIC_OPS
{},
""
);
if
(
string_value
.
empty
())
return
fallback
;
const
auto
options
=
split_string
(
string_value
,
','
);
return
contains
(
options
,
option
);
}
bool
mlir_attention_enabled
()
{
#ifdef MIGRAPHX_MLIR
if
(
not
mlir_enabled
())
return
false
;
return
is_requested
(
"attention"
);
#else
return
false
;
#endif
}
#ifdef MIGRAPHX_MLIR
struct
mlir_op
...
...
@@ -62,41 +94,27 @@ struct mlir_op
return
pack
(
f
(
self
.
op
,
"op"
));
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
,
const
std
::
vector
<
module_ref
>&
mods
)
const
shape
compute_shape
(
const
std
::
vector
<
shape
>
&
inputs
,
const
std
::
vector
<
module_ref
>&
mods
)
const
{
module_ref
mod
=
mods
[
0
];
check_shapes
{
inputs
,
*
this
}.
packed_or_broadcasted
();
if
(
mods
.
size
()
!=
1
)
MIGRAPHX_THROW
(
"should have one submodule."
);
if
(
inputs
.
size
()
<
2
)
MIGRAPHX_THROW
(
"should have at least two inputs."
);
module_ref
mod
=
mods
[
0
];
auto
type
=
mod
->
get_output_shapes
().
front
().
type
();
auto
type
=
mod
->
get_output_shapes
().
front
().
type
();
std
::
unordered_map
<
instruction_ref
,
shape
>
ins_shapes
;
size_t
param_cnt
=
0
;
std
::
vector
<
std
::
string
>
names
=
mod
->
get_parameter_names
();
std
::
sort
(
names
.
begin
(),
names
.
end
());
for
(
const
std
::
string
&
param_name
:
names
)
{
ins_shapes
[
mod
->
get_parameter
(
param_name
)]
=
inputs
[
param_cnt
++
];
}
for
(
auto
ins
:
iterator_for
(
*
mod
))
{
if
(
ins
->
name
()
==
"@param"
)
{
continue
;
}
if
(
ins
->
name
()
==
"@literal"
)
if
(
ins
->
name
()
==
"@literal"
or
ins
->
name
()
==
"@param"
)
{
ins_shapes
[
ins
]
=
ins
->
get_shape
();
continue
;
}
if
(
ins
->
name
()
==
"@return"
)
{
auto
s
=
ins_shapes
[
ins
->
inputs
().
at
(
0
)].
with_type
(
type
);
if
(
not
s
.
standard
())
MIGRAPHX_THROW
(
"MLIR doesnt support non-standard output"
);
return
s
;
return
ins_shapes
[
ins
->
inputs
().
at
(
0
)].
with_type
(
type
);
}
std
::
vector
<
shape
>
input_shapes
;
input_shapes
.
resize
(
ins
->
inputs
().
size
());
...
...
@@ -112,38 +130,55 @@ struct mlir_op
MIGRAPHX_REGISTER_OP
(
mlir_op
);
namespace
{
std
::
tuple
<
instruction_ref
,
std
::
vector
<
operation
>>
get_fusable_input_op_stream
(
instruction_ref
lower_input
)
{
instruction_ref
upper_input
=
lower_input
;
std
::
vector
<
operation
>
op_stream
;
while
(
contains
({
"slice"
,
"transpose"
,
"multibroadcast"
,
"broadcast"
,
"contiguous"
,
"reshape"
,
"squeeze"
,
"flatten"
,
"unsqueeze"
},
upper_input
->
name
()))
{
operation
op
=
upper_input
->
get_operator
();
if
(
contains
({
"squeeze"
,
"flatten"
,
"unsqueeze"
},
upper_input
->
name
()))
{
op
=
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
upper_input
->
get_shape
().
lens
()}});
}
op_stream
.
push_back
(
op
);
upper_input
=
upper_input
->
inputs
().
at
(
0
);
}
return
{
upper_input
,
op_stream
};
}
std
::
tuple
<
instruction_ref
,
std
::
vector
<
instruction_ref
>>
fuse_input_ops_and_gemm_based_op
(
module_ref
mm
,
instruction_ref
gemm_based_op
)
fuse_input_ops_and_gemm_based_op
(
module_ref
mm
,
const
std
::
vector
<
instruction_ref
>&
gemm_based_op_inputs
,
const
operation
&
gemm_based_op
)
{
std
::
vector
<
instruction_ref
>
top_inputs
;
std
::
vector
<
instruction_ref
>
imm_inputs
;
size_t
input_cnt
=
0
;
for
(
instruction_ref
input
:
gemm_based_op
->
inputs
()
)
for
(
instruction_ref
input
:
gemm_based_op
_
inputs
)
{
std
::
vector
<
operation
>
op_stream
;
while
(
contains
(
{
"slice"
,
"transpose"
,
"contiguous"
,
"reshape"
,
"squeeze"
,
"flatten"
,
"unsqueeze"
},
input
->
name
()))
{
operation
op
=
input
->
get_operator
();
if
(
contains
({
"squeeze"
,
"flatten"
,
"unsqueeze"
},
input
->
name
()))
{
op
=
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
input
->
get_shape
().
lens
()}});
}
op_stream
.
push_back
(
op
);
input
=
input
->
inputs
().
at
(
0
);
}
top_inputs
.
push_back
(
input
);
auto
[
upper_input
,
op_stream
]
=
get_fusable_input_op_stream
(
input
);
top_inputs
.
push_back
(
upper_input
);
instruction_ref
prev_input
=
mm
->
add_parameter
(
"y"
+
std
::
to_string
(
input_cnt
++
),
input
->
get_shape
());
mm
->
add_parameter
(
"y"
+
std
::
to_string
(
input_cnt
++
),
upper_
input
->
get_shape
());
for
(
const
auto
&
op
:
reverse
(
op_stream
))
{
prev_input
=
mm
->
add_instruction
(
op
,
{
prev_input
});
}
imm_inputs
.
push_back
(
prev_input
);
}
instruction_ref
new_gemm_based_op
=
mm
->
add_instruction
(
gemm_based_op
->
get_operator
(),
imm_inputs
);
instruction_ref
new_gemm_based_op
=
mm
->
add_instruction
(
gemm_based_op
,
imm_inputs
);
return
{
new_gemm_based_op
,
top_inputs
};
}
...
...
@@ -183,6 +218,7 @@ auto is_mlir_conv(mlir_mode mode)
return
false
;
if
(
ins
->
name
()
!=
"convolution"
and
ins
->
name
()
!=
"quant_convolution"
)
return
false
;
auto
input_arg_t
=
ins
->
inputs
().
front
()
->
get_shape
().
type
();
value
v
=
ins
->
get_operator
().
to_value
();
auto
group
=
v
.
at
(
"group"
).
to
<
int
>
();
if
(
group
!=
1
)
...
...
@@ -190,6 +226,10 @@ auto is_mlir_conv(mlir_mode mode)
// Avoid MLIR assertion: Index < Length && "Invalid index!"
if
(
ins
->
get_shape
().
lens
().
size
()
!=
4
)
return
false
;
if
(
ins
->
get_shape
().
type
()
==
shape
::
fp8e4m3fnuz_type
)
return
true
;
if
(
ins
->
get_shape
().
type
()
==
shape
::
float_type
and
input_arg_t
==
shape
::
fp8e4m3fnuz_type
)
return
true
;
if
(
ins
->
get_shape
().
type
()
==
shape
::
int8_type
)
return
true
;
if
(
mode
==
mlir_mode
::
int8
)
...
...
@@ -205,103 +245,140 @@ auto is_mlir_conv(mlir_mode mode)
});
}
struct
find_mlir_fused_ops
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
create_param_map_with_literals
(
module_ref
mm
,
const
module
*
pm
,
const
shape
&
shape
)
{
mlir_mode
conv_mode
=
mlir_mode
::
none
;
mlir_mode
dot_mode
=
mlir_mode
::
none
;
auto
matcher
()
const
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
ins_map
;
for
(
auto
ins
:
iterator_for
(
*
pm
))
{
auto
dot_or_conv
=
match
::
skip
(
match
::
name
(
"contiguous"
))(
match
::
any_of
(
is_mlir_dot
(
dot_mode
),
is_mlir_conv
(
conv_mode
)).
bind
(
"gemm_based_op"
));
return
match
::
name
(
"pointwise"
)(
match
::
any_of
[
match
::
inputs
()](
dot_or_conv
.
bind
(
"x"
)));
}
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
create_param_map_with_literals
(
module_ref
mm
,
const
module
*
pm
,
const
shape
&
shape
)
const
{
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
ins_map
;
for
(
auto
ins
:
iterator_for
(
*
pm
))
if
(
ins
->
name
()
!=
"@literal"
)
{
if
(
ins
->
name
()
!=
"@literal"
)
{
continue
;
}
literal
r
=
ins
->
get_literal
();
instruction_ref
literal
=
mm
->
add_literal
(
r
);
instruction_ref
mbcast
=
mm
->
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
shape
.
lens
()}}),
literal
);
ins_map
[
ins
]
=
mbcast
;
continue
;
}
return
ins_map
;
literal
r
=
ins
->
get_literal
();
instruction_ref
literal
=
mm
->
add_literal
(
r
);
instruction_ref
mbcast
=
mm
->
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
shape
.
lens
()}}),
literal
);
ins_map
[
ins
]
=
mbcast
;
}
return
ins_map
;
}
// Whitelist supported fusion options, including imposing type constraints
// for cases where MLIR only supports an operation (usually a pointwise function)
// on particular types.
bool
is_pointwise_op_supported_by_mlir
(
const
instruction
&
i
)
const
std
::
vector
<
instruction_ref
>
fold_pointwise_mod
(
instruction_ref
pm_ins
,
module_ref
parent_mod
,
const
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>&
ins_map
)
{
auto
*
pm
=
pm_ins
->
module_inputs
().
front
();
auto
names
=
pm
->
get_parameter_names
();
std
::
sort
(
names
.
begin
(),
names
.
end
());
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
param_map
=
create_param_map_with_literals
(
parent_mod
,
pm
,
pm_ins
->
get_shape
());
std
::
transform
(
names
.
begin
(),
names
.
end
(),
pm_ins
->
inputs
().
begin
(),
std
::
inserter
(
param_map
,
param_map
.
end
()),
[
&
](
auto
name
,
auto
input
)
{
if
(
ins_map
.
count
(
input
))
return
std
::
make_pair
(
pm
->
get_parameter
(
name
),
ins_map
.
at
(
input
));
return
std
::
make_pair
(
pm
->
get_parameter
(
name
),
parent_mod
->
add_parameter
(
name
,
input
->
get_shape
()));
});
return
parent_mod
->
insert_instructions
(
parent_mod
->
end
(),
pm
,
param_map
);
}
// Whitelist supported fusion options, including imposing type constraints
// for cases where MLIR only supports an operation (usually a pointwise function)
// on particular types.
bool
is_pointwise_op_supported_by_mlir
(
const
instruction
&
i
)
{
using
type_t
=
shape
::
type_t
;
const
auto
&
name
=
i
.
name
();
const
auto
result_type
=
i
.
get_shape
().
type
();
const
std
::
initializer_list
<
type_t
>
allowed_types
=
{
type_t
::
float_type
,
type_t
::
half_type
,
type_t
::
fp8e4m3fnuz_type
,
type_t
::
int8_type
,
type_t
::
int32_type
,
type_t
::
bool_type
};
// Preliminary type check.
if
(
not
contains
(
allowed_types
,
result_type
))
{
return
false
;
}
const
std
::
initializer_list
<
std
::
string
>
any_type_ops
=
{
"@literal"
,
"@param"
,
"@return"
};
const
std
::
initializer_list
<
std
::
string
>
no_bool_ops
=
{
"convolution"
,
"quant_convolution"
,
"dot"
,
"quant_dot"
,
"add"
,
"clip"
,
"relu"
,
"sub"
,
"mul"
,
"div"
,
"pow"
,
"where"
,
"quantizelinear"
,
"dequantizelinear"
,
"abs"
,
"neg"
,
};
const
std
::
initializer_list
<
std
::
string
>
fp_only_ops
=
{
"ceil"
,
"erf"
,
"exp"
,
"floor"
,
"log"
,
"recip"
,
"rsqrt"
,
"sigmoid"
,
"softmax"
,
"tanh"
,
};
bool
is_float
=
contains
({
type_t
::
float_type
,
type_t
::
half_type
,
type_t
::
fp8e4m3fnuz_type
},
result_type
);
if
(
contains
(
any_type_ops
,
name
))
return
true
;
if
(
result_type
!=
type_t
::
bool_type
and
contains
(
no_bool_ops
,
name
))
return
true
;
if
(
is_float
and
contains
(
fp_only_ops
,
name
))
return
true
;
// Only conversions between floating types are known to be unambigiously
// supported.
if
(
is_float
and
name
==
"convert"
)
{
using
type_t
=
shape
::
type_t
;
const
auto
&
name
=
i
.
name
();
const
auto
result_type
=
i
.
get_shape
().
type
();
const
std
::
initializer_list
<
type_t
>
allowed_types
=
{
type_t
::
float_type
,
type_t
::
half_type
,
type_t
::
int8_type
,
type_t
::
fp8e4m3fnuz_type
,
type_t
::
int32_type
,
type_t
::
bool_type
};
// Preliminary type check.
if
(
not
contains
(
allowed_types
,
result_type
))
if
(
result_type
==
shape
::
fp8e4m3fnuz_type
)
{
return
false
;
}
const
std
::
initializer_list
<
std
::
string
>
any_type_ops
=
{
"@literal"
,
"@param"
,
"@return"
};
const
std
::
initializer_list
<
std
::
string
>
no_bool_ops
=
{
"convolution"
,
"quant_convolution"
,
"dot"
,
"quant_dot"
,
"add"
,
"clip"
,
"relu"
,
"sub"
,
"mul"
,
"div"
,
"pow"
,
"where"
,
"quantizelinear"
,
"dequantizelinear"
,
"abs"
,
"neg"
,
};
const
std
::
initializer_list
<
std
::
string
>
fp_only_ops
=
{
"ceil"
,
"erf"
,
"exp"
,
"floor"
,
"log"
,
"recip"
,
"rsqrt"
,
"sigmoid"
,
"softmax"
,
"tanh"
,
};
bool
is_float
=
contains
({
type_t
::
float_type
,
type_t
::
half_type
,
type_t
::
fp8e4m3fnuz_type
},
result_type
);
if
(
contains
(
any_type_ops
,
name
))
return
true
;
if
(
result_type
!=
type_t
::
bool_type
and
contains
(
no_bool_ops
,
name
))
return
true
;
if
(
is_float
and
contains
(
fp_only_ops
,
name
))
return
true
;
// Only conversions between floating types are known to be unambigiously
// supported.
if
(
is_float
and
name
==
"convert"
)
{
return
std
::
all_of
(
i
.
inputs
().
begin
(),
i
.
inputs
().
end
(),
[](
const
auto
&
arg
)
{
return
contains
({
type_t
::
float_type
,
type_t
::
half_type
},
arg
->
get_shape
().
type
());
});
}
}
// else
return
std
::
all_of
(
i
.
inputs
().
begin
(),
i
.
inputs
().
end
(),
[](
const
auto
&
arg
)
{
return
contains
({
type_t
::
float_type
,
type_t
::
half_type
},
arg
->
get_shape
().
type
());
});
}
return
false
;
}
MIGRAPHX_PRED_MATCHER
(
mlir_pointwise
,
instruction_ref
ins
)
{
if
(
ins
->
name
()
!=
"pointwise"
)
return
false
;
auto
*
pm
=
ins
->
module_inputs
().
front
();
return
std
::
all_of
(
pm
->
begin
(),
pm
->
end
(),
[
&
](
const
auto
&
i
)
{
return
is_pointwise_op_supported_by_mlir
(
i
);
});
}
struct
find_mlir_fused_ops
{
mlir_mode
conv_mode
=
mlir_mode
::
none
;
mlir_mode
dot_mode
=
mlir_mode
::
none
;
auto
matcher
()
const
{
auto
dot_or_conv
=
match
::
skip
(
match
::
name
(
"contiguous"
))(
match
::
any_of
(
is_mlir_dot
(
dot_mode
),
is_mlir_conv
(
conv_mode
)).
bind
(
"gemm_based_op"
));
return
mlir_pointwise
()(
match
::
any_of
[
match
::
inputs
()](
dot_or_conv
.
bind
(
"x"
)));
}
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
...
...
@@ -311,29 +388,12 @@ struct find_mlir_fused_ops
auto
x_ins
=
r
.
instructions
[
"x"
];
// input after contiguous
auto
*
pm
=
ins
->
module_inputs
().
front
();
auto
names
=
pm
->
get_parameter_names
();
// Whitelist pointwise operators.
if
(
std
::
any_of
(
pm
->
begin
(),
pm
->
end
(),
[
&
](
const
auto
&
i
)
{
return
not
is_pointwise_op_supported_by_mlir
(
i
);
}))
return
;
std
::
sort
(
names
.
begin
(),
names
.
end
());
module_ref
mm
=
mpm
.
create_module
(
"mlir_"
+
pm
->
name
());
mm
->
set_bypass
();
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
param_map
=
create_param_map_with_literals
(
mm
,
pm
,
gemm_based_op
->
get_shape
());
auto
[
anchor_op
,
top_inputs
]
=
fuse_input_ops_and_gemm_based_op
(
mm
,
gemm_based_op
);
std
::
transform
(
names
.
begin
(),
names
.
end
(),
ins
->
inputs
().
begin
(),
std
::
inserter
(
param_map
,
param_map
.
end
()),
[
&
,
&
anchor
=
anchor_op
](
auto
name
,
auto
input
)
{
if
(
input
==
x_ins
)
return
std
::
make_pair
(
pm
->
get_parameter
(
name
),
anchor
);
return
std
::
make_pair
(
pm
->
get_parameter
(
name
),
mm
->
add_parameter
(
name
,
input
->
get_shape
()));
});
mm
->
add_return
(
mm
->
insert_instructions
(
mm
->
end
(),
pm
,
param_map
));
auto
[
anchor_op
,
top_inputs
]
=
fuse_input_ops_and_gemm_based_op
(
mm
,
gemm_based_op
->
inputs
(),
gemm_based_op
->
get_operator
());
mm
->
add_return
(
fold_pointwise_mod
(
ins
,
mm
,
{{
x_ins
,
anchor_op
}}));
std
::
vector
<
instruction_ref
>
inputs
;
std
::
copy_if
(
ins
->
inputs
().
begin
(),
...
...
@@ -351,54 +411,104 @@ struct find_mlir_standalone_op
{
mlir_mode
mode
=
mlir_mode
::
none
;
auto
matcher
()
const
{
return
Matcher
(
mode
);
}
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
{
auto
conv
_based_op
=
r
.
result
;
// enable only for fp32/fp16/i8 types
if
(
std
::
any_of
(
conv
_based_op
->
inputs
().
begin
(),
conv
_based_op
->
inputs
().
end
(),
[
&
](
auto
i
)
{
auto
gemm
_based_op
=
r
.
result
;
// enable only for fp32/fp16/i8
/fp8
types
if
(
std
::
any_of
(
gemm
_based_op
->
inputs
().
begin
(),
gemm
_based_op
->
inputs
().
end
(),
[
&
](
auto
i
)
{
return
not
contains
({
shape
::
type_t
::
float_type
,
shape
::
type_t
::
half_type
,
shape
::
type_t
::
fp8e4m3fnuz
_type
,
shape
::
type_t
::
int8
_type
},
shape
::
type_t
::
int8
_type
,
shape
::
type_t
::
fp8e4m3fnuz
_type
},
i
->
get_shape
().
type
());
}))
return
;
static
size_t
counter
=
0
;
module_ref
mm
=
mpm
.
create_module
(
"mlir_"
+
conv
_based_op
->
name
()
+
std
::
to_string
(
counter
++
));
mpm
.
create_module
(
"mlir_"
+
gemm
_based_op
->
name
()
+
std
::
to_string
(
counter
++
));
mm
->
set_bypass
();
auto
[
anchor_op
,
top_inputs
]
=
fuse_input_ops_and_gemm_based_op
(
mm
,
conv_based_op
);
auto
[
anchor_op
,
top_inputs
]
=
fuse_input_ops_and_gemm_based_op
(
mm
,
gemm_based_op
->
inputs
(),
gemm_based_op
->
get_operator
());
mm
->
add_return
({
anchor_op
});
mpm
.
get_module
().
replace_instruction
(
conv
_based_op
,
mlir_op
{
conv
_based_op
->
get_operator
()},
top_inputs
,
{
mm
});
gemm
_based_op
,
mlir_op
{
gemm
_based_op
->
get_operator
()},
top_inputs
,
{
mm
});
}
};
using
find_mlir_standalone_convolution_op
=
find_mlir_standalone_op
<&
is_mlir_conv
>
;
using
find_mlir_standalone_dot_op
=
find_mlir_standalone_op
<&
is_mlir_dot
>
;
/**
* @brief Declares a new MIGraphX environment variable which forces to generate
* only specific MLIR operations.
*
* The variable, if defined, forces MIGraphX to use only specific operations
* with MLIR regardless of the underlying GPU architecture. The variable accepts
* a list of operations separated by comma. The variable recognizes the following
* operations: "fused", "convolution", "dot". If the variable is not defined MIGraphX
* will decide by itself which operations to delegate to MLIR. The variable is
* intended to be primarily used by rocMLIR developers.
*/
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_MLIR_USE_SPECIFIC_OPS
);
struct
find_mlir_standalone_attention_op
{
auto
matcher
()
const
{
return
match
::
name
(
"gpu::pre_gemm_softmax_gemm"
).
bind
(
"gemm_softmax_gemm"
);
}
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
{
static
size_t
counter
=
0
;
module_ref
mm
=
mpm
.
create_module
(
"mlir_"
+
std
::
to_string
(
counter
++
));
auto
gemm_softmax_gemm
=
r
.
instructions
[
"gemm_softmax_gemm"
];
std
::
vector
<
instruction_ref
>
inputs
;
mm
->
set_bypass
();
bool
is_requested
(
std
::
string_view
option
,
bool
fallback
=
false
)
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
ins_map
;
auto
gemm0_inputs
=
gemm_softmax_gemm
->
inputs
();
gemm0_inputs
.
pop_back
();
auto
[
gemm0
,
top_gemm0_inputs
]
=
fuse_input_ops_and_gemm_based_op
(
mm
,
gemm0_inputs
,
make_op
(
"dot"
));
inputs
.
insert
(
inputs
.
begin
(),
top_gemm0_inputs
.
begin
(),
top_gemm0_inputs
.
end
());
// handle scale
auto
v
=
gemm_softmax_gemm
->
get_operator
().
to_value
();
assert
(
v
.
contains
(
"scale"
));
auto
scale
=
v
.
at
(
"scale"
).
to
<
float
>
();
auto
scale_lit
=
mm
->
add_literal
(
literal
{
shape
{
gemm0
->
get_shape
().
type
()},
{
scale
}});
instruction_ref
scale_lit_mbcast
=
mm
->
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
gemm0
->
get_shape
().
lens
()}}),
scale_lit
);
auto
scaled_gemm0
=
mm
->
add_instruction
(
make_op
(
"mul"
),
gemm0
,
scale_lit_mbcast
);
auto
softmax
=
mm
->
add_instruction
(
make_op
(
"softmax"
,
{{
"axis"
,
gemm0
->
get_shape
().
lens
().
size
()
-
1
}}),
scaled_gemm0
);
auto
[
old_upper_v
,
upper_v_op_stream
]
=
get_fusable_input_op_stream
(
gemm_softmax_gemm
->
inputs
()[
2
]);
instruction_ref
new_upper_v
=
mm
->
add_parameter
(
"z"
,
old_upper_v
->
get_shape
());
for
(
const
auto
&
op
:
reverse
(
upper_v_op_stream
))
{
new_upper_v
=
mm
->
add_instruction
(
op
,
{
new_upper_v
});
}
inputs
.
push_back
(
old_upper_v
);
auto
gemm1
=
mm
->
add_instruction
(
make_op
(
"dot"
),
{
softmax
,
new_upper_v
});
ins_map
[
gemm_softmax_gemm
]
=
gemm1
;
auto
ins_to_replace
=
gemm1
;
auto
ins_to_be_replaced
=
gemm_softmax_gemm
;
if
(
r
.
instructions
.
find
(
"trailing_pm"
)
!=
r
.
instructions
.
end
())
{
ins_to_replace
=
fold_pointwise_mod
(
r
.
instructions
[
"trailing_pm"
],
mm
,
ins_map
)[
0
];
std
::
copy_if
(
r
.
instructions
[
"trailing_pm"
]
->
inputs
().
begin
(),
r
.
instructions
[
"trailing_pm"
]
->
inputs
().
end
(),
std
::
back_inserter
(
inputs
),
[
&
](
auto
input
)
{
return
input
!=
gemm_softmax_gemm
;
});
ins_to_be_replaced
=
r
.
instructions
[
"trailing_pm"
];
}
mm
->
add_return
({
ins_to_replace
});
mpm
.
get_module
().
replace_instruction
(
ins_to_be_replaced
,
mlir_op
{
gemm1
->
get_operator
()},
inputs
,
{
mm
});
}
};
struct
find_mlir_attention_fused_ops
:
public
find_mlir_standalone_attention_op
{
auto
string_value
=
string_value_of
(
MIGRAPHX_MLIR_USE_SPECIFIC_OPS
{},
""
);
if
(
string_value
.
empty
())
return
fallback
;
const
auto
options
=
split_string
(
string_value
,
','
);
return
contains
(
options
,
option
);
}
auto
matcher
()
const
{
auto
standalone_matcher
=
find_mlir_standalone_attention_op
::
matcher
();
return
mlir_pointwise
()(
match
::
any_of
[
match
::
inputs
()](
standalone_matcher
).
bind
(
"trailing_pm"
));
;
}
};
}
// namespace
#endif // MIGRAPHX_MLIR
...
...
@@ -420,13 +530,20 @@ void fuse_mlir::apply(module_pass_manager& mpm) const
mlir_mode
mode
=
(
enabled
(
MIGRAPHX_ENABLE_EXTRA_MLIR
{})
or
enable_extra
)
?
mlir_mode
::
fast
:
mlir_mode
::
none
;
// Attention offloads; default disabled
if
(
mlir_attention_enabled
())
{
match
::
find_matches
(
mpm
,
find_mlir_attention_fused_ops
{});
match
::
find_matches
(
mpm
,
find_mlir_standalone_attention_op
{});
}
match
::
find_matches
(
mpm
,
find_mlir_fused_ops
{.
conv_mode
=
get_mode
(
"fused"
,
mlir_mode
::
fast
),
.
dot_mode
=
get_mode
(
"fused"
,
mode
)});
match
::
find_matches
(
mpm
,
find_mlir_standalone_convolution_op
{
get_mode
(
"convolution"
,
mlir_mode
::
int8
)},
find_mlir_standalone_convolution_op
{
get_mode
(
"convolution"
,
mlir_mode
::
all
)},
find_mlir_standalone_dot_op
{
get_mode
(
"dot"
,
mlir_mode
::
none
)});
#else
(
void
)
mpm
;
...
...
src/targets/gpu/gemm_impl.cpp
View file @
6f768035
...
...
@@ -22,6 +22,7 @@
* THE SOFTWARE.
*/
#include <rocblas/internal/rocblas-types.h>
#include <rocblas/rocblas.h>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/gemm_impl.hpp>
...
...
@@ -36,6 +37,20 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
/*
Regular rocBLAS API takes compute_type as `rocblas_datatype` enum value v/s "ex3" BETA API takes it
as `rocblas_computetype` enum value. `rb_compute_type` is faciliator to implictly cast integer enum
value to required type that can be used inside `common_args` generator.
*/
struct
rb_compute_type
{
int
type
=
0
;
rb_compute_type
(
rocblas_datatype
t
)
:
type
(
static_cast
<
int
>
(
t
))
{}
rb_compute_type
(
rocblas_computetype
t
)
:
type
(
static_cast
<
int
>
(
t
))
{}
operator
rocblas_datatype
()
const
{
return
static_cast
<
rocblas_datatype
>
(
type
);
}
operator
rocblas_computetype
()
const
{
return
static_cast
<
rocblas_computetype
>
(
type
);
}
};
// Convert rocBLAS datatypes to equivalent Migraphx data types
rocblas_datatype
get_type
(
shape
::
type_t
type
)
{
...
...
@@ -185,12 +200,17 @@ struct gemm_impl
{
output_type
=
rocblas_datatype_i32_r
;
}
compute_type
=
output_type
;
compute_type
=
rb_compute_type
{
output_type
}
;
if
(
compute_fp32
)
{
if
(
arg_type
==
rocblas_datatype_f16_r
)
compute_type
=
rocblas_datatype_f32_r
;
}
if
(
arg_type
==
rocblas_datatype_f8_r
)
{
assert
(
get_type
(
input_shapes
[
1
].
type
())
==
rocblas_datatype_f8_r
);
compute_type
=
rocblas_compute_type_f32
;
}
auto
a_lens
=
input_shapes
[
0
].
lens
();
auto
b_lens
=
input_shapes
[
1
].
lens
();
...
...
@@ -230,7 +250,6 @@ struct gemm_impl
auto
common_args
=
create_strided_batched_args_common
(
ctx
,
input_args
);
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex3
,
common_args
,
rocblas_compute_type_f32
,
rocblas_gemm_algo_standard
,
solution_idx
,
gemm_flags
);
...
...
@@ -240,7 +259,6 @@ struct gemm_impl
auto
common_args
=
create_gemm_ex_args_common
(
ctx
,
input_args
);
rocblas_invoke
(
&
rocblas_gemm_ex3
,
common_args
,
rocblas_compute_type_f32
,
rocblas_gemm_algo_standard
,
solution_idx
,
gemm_flags
);
...
...
@@ -254,7 +272,6 @@ struct gemm_impl
auto
common_args
=
create_strided_batched_args_common
(
ctx
,
input_args
);
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex
,
common_args
,
compute_type
,
rocblas_gemm_algo_solution_index
,
solution_idx
,
gemm_flags
);
...
...
@@ -264,7 +281,6 @@ struct gemm_impl
auto
common_args
=
create_gemm_ex_args_common
(
ctx
,
input_args
);
rocblas_invoke
(
&
rocblas_gemm_ex
,
common_args
,
compute_type
,
rocblas_gemm_algo_solution_index
,
solution_idx
,
gemm_flags
);
...
...
@@ -304,7 +320,6 @@ struct gemm_impl
auto
common_args
=
create_strided_batched_args_common
(
ctx
,
input_args
);
check_valid
=
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex
,
common_args
,
compute_type
,
rocblas_gemm_algo_solution_index
,
solution_idx
,
rocblas_gemm_flags_check_solution_index
);
...
...
@@ -314,7 +329,6 @@ struct gemm_impl
auto
common_args
=
create_gemm_ex_args_common
(
ctx
,
input_args
);
check_valid
=
rocblas_invoke
(
&
rocblas_gemm_ex
,
common_args
,
compute_type
,
rocblas_gemm_algo_solution_index
,
solution_idx
,
rocblas_gemm_flags_check_solution_index
);
...
...
@@ -365,7 +379,8 @@ struct gemm_impl
output_type
,
ldd
,
d_stride
,
num_matrices
);
num_matrices
,
compute_type
);
}
/**
* Helper method to create that subset of a long rocBLAS argument list that is common
...
...
@@ -398,7 +413,8 @@ struct gemm_impl
ldc
,
is_3inputs
?
args
[
3
].
data
()
:
args
[
2
].
data
(),
output_type
,
ldd
);
ldd
,
compute_type
);
}
#ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
...
...
@@ -428,7 +444,6 @@ struct gemm_impl
auto
common_args
=
create_strided_batched_args_common
(
ctx
,
input_args
);
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex_get_solutions
,
common_args
,
compute_type
,
rocblas_gemm_algo_solution_index
,
gemm_flags
,
nullptr
,
...
...
@@ -438,7 +453,6 @@ struct gemm_impl
auto
common_sol_args
=
create_strided_batched_args_common
(
ctx
,
input_args
);
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex_get_solutions
,
common_sol_args
,
compute_type
,
rocblas_gemm_algo_solution_index
,
gemm_flags
,
solution_indices
.
data
(),
...
...
@@ -449,7 +463,6 @@ struct gemm_impl
auto
common_args
=
create_gemm_ex_args_common
(
ctx
,
input_args
);
rocblas_invoke
(
&
rocblas_gemm_ex_get_solutions
,
common_args
,
compute_type
,
rocblas_gemm_algo_solution_index
,
gemm_flags
,
nullptr
,
...
...
@@ -459,7 +472,6 @@ struct gemm_impl
auto
common_sol_args
=
create_gemm_ex_args_common
(
ctx
,
input_args
);
rocblas_invoke
(
&
rocblas_gemm_ex_get_solutions
,
common_sol_args
,
compute_type
,
rocblas_gemm_algo_solution_index
,
gemm_flags
,
solution_indices
.
data
(),
...
...
@@ -521,7 +533,7 @@ struct gemm_impl
rocblas_int
c_stride
=
0
;
rocblas_int
d_stride
=
0
;
rocblas_datatype
arg_type
=
rocblas_datatype_f32_r
;
r
ocblas_data
type
compute_type
=
rocblas_datatype_f32_r
;
r
b_compute_
type
compute_type
=
rocblas_datatype_f32_r
;
rocblas_datatype
output_type
=
rocblas_datatype_f32_r
;
bool
strided_batched
=
true
;
bool
is_3inputs
=
true
;
...
...
src/targets/gpu/include/migraphx/gpu/compile_hip.hpp
View file @
6f768035
...
...
@@ -58,10 +58,10 @@ struct hiprtc_src_file
MIGRAPHX_GPU_EXPORT
bool
hip_has_flags
(
const
std
::
vector
<
std
::
string
>&
flags
);
MIGRAPHX_GPU_EXPORT
std
::
vector
<
std
::
vector
<
char
>>
compile_hip_src_with_hiprtc
(
std
::
vector
<
hiprtc_src_file
>
srcs
,
std
::
string
params
,
const
std
::
string
&
arch
);
std
::
vector
<
hiprtc_src_file
>
srcs
,
const
std
::
string
&
params
,
const
std
::
string
&
arch
);
MIGRAPHX_GPU_EXPORT
std
::
vector
<
std
::
vector
<
char
>>
compile_hip_src
(
const
std
::
vector
<
src_file
>&
srcs
,
std
::
string
params
,
const
std
::
string
&
arch
);
MIGRAPHX_GPU_EXPORT
std
::
vector
<
std
::
vector
<
char
>>
compile_hip_src
(
const
std
::
vector
<
src_file
>&
srcs
,
const
std
::
string
&
params
,
const
std
::
string
&
arch
);
MIGRAPHX_GPU_EXPORT
std
::
string
enum_params
(
std
::
size_t
count
,
std
::
string
param
);
...
...
src/targets/gpu/include/migraphx/gpu/device/pad.hpp
deleted
100644 → 0
View file @
da7717ce
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_PAD_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_PAD_HPP
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/config.hpp>
#include <hip/hip_runtime_api.h>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
namespace
device
{
argument
MIGRAPHX_DEVICE_EXPORT
pad
(
hipStream_t
stream
,
argument
result
,
argument
arg1
,
float
value
,
std
::
vector
<
std
::
int64_t
>
pads
);
}
// namespace device
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/targets/gpu/include/migraphx/gpu/device_name.hpp
View file @
6f768035
...
...
@@ -37,6 +37,8 @@ MIGRAPHX_GPU_EXPORT std::string get_device_name();
MIGRAPHX_GPU_EXPORT
int
get_device_id
();
MIGRAPHX_GPU_EXPORT
bool
gfx_has_fp8_intrinsics
();
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
...
...
src/targets/gpu/include/migraphx/gpu/fuse_mlir.hpp
View file @
6f768035
...
...
@@ -34,10 +34,11 @@ struct module_pass_manager;
namespace
gpu
{
MIGRAPHX_GPU_EXPORT
bool
mlir_enabled
();
MIGRAPHX_GPU_EXPORT
bool
mlir_attention_enabled
();
struct
MIGRAPHX_GPU_EXPORT
fuse_mlir
{
context
*
ctx
=
nullptr
;
context
*
ctx
=
nullptr
;
bool
enable_extra
=
false
;
std
::
string
name
()
const
{
return
"gpu::fuse_mlir"
;
}
void
apply
(
module_pass_manager
&
mpm
)
const
;
...
...
src/targets/gpu/include/migraphx/gpu/gemm_softmax_gemm.hpp
View file @
6f768035
...
...
@@ -66,6 +66,10 @@ struct gemm_softmax_gemm
}
static
bool
is_ck_supported_type
(
shape
::
type_t
t
)
{
return
contains
({
shape
::
half_type
},
t
);
}
static
bool
is_mlir_supported_type
(
shape
::
type_t
t
)
{
return
contains
({
shape
::
type_t
::
float_type
,
shape
::
half_type
},
t
);
}
};
}
// namespace gpu
...
...
src/targets/gpu/include/migraphx/gpu/miopen.hpp
View file @
6f768035
...
...
@@ -217,6 +217,12 @@ inline pooling_descriptor make_pooling(const migraphx::op::pooling& op)
ss
<<
op
.
mode
;
MIGRAPHX_THROW
(
ss
.
str
());
}
if
(
not
std
::
all_of
(
op
.
dilations
.
cbegin
(),
op
.
dilations
.
cend
(),
[](
std
::
size_t
d
)
{
return
d
==
1
;
}))
{
MIGRAPHX_THROW
(
"Unsupported dilations for pooling: ["
+
to_string_range
(
op
.
dilations
)
+
"]"
);
}
auto
p
=
make_obj
<
pooling_descriptor
>
(
&
miopenCreatePoolingDescriptor
);
int
kdims
=
op
.
kdims
();
...
...
src/targets/gpu/jit/reduce.cpp
View file @
6f768035
...
...
@@ -146,7 +146,6 @@ struct simple_reduce_compiler : compiler<simple_reduce_compiler>
vectorize
vec
{};
auto
nelements
=
options
.
virtual_inputs
.
back
().
elements
();
auto
algo
=
v
.
get
(
"algo"
,
get_reduce_algo
(
options
.
virtual_inputs
));
if
(
algo
==
"block"
)
{
// Vectorize if the axis is a reduction axis
...
...
@@ -170,13 +169,13 @@ struct simple_reduce_compiler : compiler<simple_reduce_compiler>
options
.
kernel_name
=
"reduce_kernel"
;
std
::
string
identity
=
"[](auto x) { return x; }"
;
auto
src
=
interpolate_string
(
simple_reduce_kernel
,
{{
"reduction"
,
v
.
at
(
"reduction"
).
to
<
std
::
string
>
()},
{
"init"
,
v
.
get
(
"init"
,
std
::
string
{
"0"
})},
{
"read"
,
v
.
get
(
"read"
,
identity
)},
{
"write"
,
v
.
get
(
"write"
,
identity
)},
{
"algo"
,
algo
},
{
"transformers"
,
make_transformer_args
(
vec
)},
{
"preamble"
,
v
.
get
(
"preamble"
,
std
::
string
{})}});
{{
"reduction"
,
v
.
at
(
"reduction"
).
to
<
std
::
string
>
()},
{
"init"
,
v
.
get
(
"init"
,
std
::
string
{
"0"
})},
{
"read"
,
v
.
get
(
"read"
,
identity
)},
{
"write"
,
v
.
get
(
"write"
,
identity
)},
{
"algo"
,
algo
},
{
"transformers"
,
make_transformer_args
(
vec
)},
{
"preamble"
,
v
.
get
(
"preamble"
,
std
::
string
{})}});
options
.
params
+=
"-Wno-float-equal"
;
return
compile_hip_code_object
(
src
,
options
);
}
...
...
@@ -267,13 +266,13 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler>
auto
src
=
interpolate_string
(
fused_reduce_kernel
,
{{
"kernel"
,
options
.
kernel_name
},
{
"params"
,
enum_params
(
inputs
.
size
(),
"void * private_p"
)},
{
"args"
,
enum_params
(
inputs
.
size
(),
"private_p"
)},
{
"algo"
,
algo
},
{
"reduced"
,
"decltype("
+
generate_make_shape
(
reduce_output_shape
)
+
")"
},
{
"lambda"
,
v
.
at
(
"lambda"
).
to
<
std
::
string
>
()},
{
"transformers"
,
make_transformer_args
(
vec
)},
{
"preamble"
,
v
.
get
(
"preamble"
,
std
::
string
{})}});
{
"params"
,
enum_params
(
inputs
.
size
(),
"void * private_p"
)},
{
"args"
,
enum_params
(
inputs
.
size
(),
"private_p"
)},
{
"algo"
,
algo
},
{
"reduced"
,
"decltype("
+
generate_make_shape
(
reduce_output_shape
)
+
")"
},
{
"lambda"
,
v
.
at
(
"lambda"
).
to
<
std
::
string
>
()},
{
"transformers"
,
make_transformer_args
(
vec
)},
{
"preamble"
,
v
.
get
(
"preamble"
,
std
::
string
{})}});
options
.
params
+=
"-Wno-float-equal"
;
return
compile_hip_code_object
(
src
,
options
);
}
...
...
src/targets/gpu/
include/migraphx/gpu/pad
.hpp
→
src/targets/gpu/
jit/scatter
.hpp
View file @
6f768035
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-202
2
Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-202
3
Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
...
...
@@ -21,41 +21,58 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_
RTGLIB_PAD
_HPP
#define MIGRAPHX_GUARD_
RTGLIB_PAD
_HPP
#ifndef MIGRAPHX_GUARD_
JIT_SCATTER
_HPP
#define MIGRAPHX_GUARD_
JIT_SCATTER
_HPP
#include <migraphx/argument.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/op/pad.hpp>
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
struct
context
;
struct
hip_pad
template
<
typename
Derived
>
struct
scatter_compiler
:
compiler
<
Derived
>
{
op
::
pad
op
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
const
{
return
migraphx
::
reflect
(
self
.
op
,
f
);
const
auto
inputs
=
to_shapes
(
std
::
vector
<
instruction_ref
>
{
ins
->
inputs
().
begin
()
+
1
,
ins
->
inputs
().
end
()});
hip_compile_options
options
;
options
.
set_launch_params
(
op
.
to_value
(),
compute_global_for
(
ctx
,
inputs
.
at
(
1
).
elements
()));
options
.
inputs
=
inputs
;
options
.
output
=
inputs
.
back
();
options
.
kernel_name
=
derived
().
get_kernel_name
(
op
);
options
.
virtual_inputs
=
inputs
;
// The compiler protests the inequality comparison in assign_mul when pertaining to floating
// point, despite it making sense in the context. Thus the warning removal.
options
.
params
+=
"-Wno-float-equal"
;
const
auto
src
=
derived
().
make_interpolated_string
(
op
);
return
prepend_copy_data_to_output
(
compile_hip_code_object
(
src
,
options
));
}
std
::
string
name
()
const
{
return
"gpu::pad"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
;
argument
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
;
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
shapes
)
const
compiler_replace
prepend_copy_data_to_output
(
const
operation
&
co
)
const
{
return
shapes
.
size
()
-
1
;
return
{
co
,
[](
module
&
m
,
instruction_ref
ins
,
const
operation
&
op
)
{
auto
args
=
ins
->
inputs
();
args
.
back
()
=
m
.
insert_instruction
(
ins
,
make_op
(
"hip::copy"
),
args
.
front
(),
args
.
back
());
args
.
erase
(
args
.
begin
());
return
m
.
replace_instruction
(
ins
,
op
,
args
);
}};
}
std
::
string
get_kernel_name
(
const
operation
&
op
)
const
{
return
op
.
name
()
+
"_kernel"
;
}
const
Derived
&
derived
()
const
{
return
static_cast
<
const
Derived
&>
(
*
this
);
}
};
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/targets/gpu/jit/scatternd.cpp
View file @
6f768035
...
...
@@ -21,11 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include "scatter.hpp"
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -55,46 +51,21 @@ MIGRAPHX_GLOBAL void scatternd_kernel(void* in_indices, void* in_updates, void*
)__migraphx__"
;
struct
scatternd_compiler
:
compiler
<
scatternd_compiler
>
struct
scatternd_compiler
:
scatter_
compiler
<
scatternd_compiler
>
{
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"scatternd_none"
,
"scatternd_add"
,
"scatternd_mul"
};
return
{
"scatternd_none"
,
"scatternd_add"
,
"scatternd_mul"
,
"scatternd_min"
,
"scatternd_max"
};
}
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
std
::
string
make_interpolated_string
(
const
operation
&
op
)
const
{
hip_compile_options
options
;
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
inputs
.
at
(
1
).
elements
()));
options
.
inputs
=
inputs
;
options
.
output
=
inputs
.
back
();
options
.
kernel_name
=
"scatternd_kernel"
;
options
.
virtual_inputs
=
inputs
;
auto
reduction
=
"assign_"
+
v
.
get
(
"reduction"
,
std
::
string
{
"none"
});
auto
src
=
interpolate_string
(
scatternd_kernel
,
{{
"reduction"
,
reduction
}});
return
compile_hip_code_object
(
src
,
options
);
const
auto
reduction
=
op
.
name
().
substr
(
std
::
char_traits
<
char
>::
length
(
"scatternd_"
));
return
interpolate_string
(
scatternd_kernel
,
{{
"reduction"
,
"assign_"
+
reduction
}});
}
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
const
{
assert
(
starts_with
(
op
.
name
(),
"scatternd_"
));
auto
reduction
=
op
.
name
().
substr
(
10
);
return
insert
(
compile_op
(
ctx
,
to_shapes
(
std
::
vector
<
instruction_ref
>
{
ins
->
inputs
().
begin
()
+
1
,
ins
->
inputs
().
end
()}),
{{
"reduction"
,
reduction
}}));
}
compiler_replace
insert
(
const
operation
&
co
)
const
{
return
{
co
,
[](
module
&
m
,
instruction_ref
ins
,
const
operation
&
op
)
{
auto
args
=
ins
->
inputs
();
args
.
back
()
=
m
.
insert_instruction
(
ins
,
make_op
(
"hip::copy"
),
args
.
front
(),
args
.
back
());
args
.
erase
(
args
.
begin
());
return
m
.
replace_instruction
(
ins
,
op
,
args
);
}};
}
std
::
string
get_kernel_name
(
const
operation
&
)
const
{
return
"scatternd_kernel"
;
}
};
}
// namespace gpu
...
...
src/targets/gpu/kernels/include/migraphx/kernels/bit_cast.hpp
View file @
6f768035
...
...
@@ -22,8 +22,12 @@
#ifndef MIGRAPHX_GUARD_KERNELS_BITCAST_HPP
#define MIGRAPHX_GUARD_KERNELS_BITCAST_HPP
#include <migraphx/kernels/type_traits.hpp>
namespace
migraphx
{
template
<
typename
To
,
typename
From
>
template
<
typename
To
,
typename
From
,
MIGRAPHX_REQUIRES
(
is_trivially_copyable
<
To
>{}
and
is_trivially_copyable
<
From
>
{})
>
inline
constexpr
To
bit_cast
(
From
fr
)
noexcept
{
static_assert
(
sizeof
(
To
)
==
sizeof
(
From
));
...
...
src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp
View file @
6f768035
...
...
@@ -365,15 +365,6 @@ struct float8
inline
__device__
constexpr
float8
&
operator
=
(
const
float8
&
rhs
)
=
default
;
inline
__device__
constexpr
float8
&
operator
=
(
float8
&&
rhs
)
noexcept
=
default
;
inline
__device__
constexpr
bool
operator
==
(
const
float8
&
rhs
)
const
{
if
(
rhs
.
is_nan
()
or
rhs
.
is_inf
()
or
this
->
is_nan
()
or
this
->
is_inf
())
return
false
;
else
if
((
rhs
.
is_zero
()
and
this
->
is_zero
())
or
(
this
->
data
==
rhs
.
data
))
return
true
;
return
false
;
}
inline
__device__
constexpr
bool
operator
<
(
const
float8
&
rhs
)
const
{
const
auto
we
=
static_cast
<
float
>
(
*
this
);
...
...
@@ -403,12 +394,20 @@ using fp8e5m2fnuz = float8<migraphx::fp8::f8_type::bf8, true>;
}
// NOLINTNEXTLINE
#define MIGRAPHX_FP8_FABS(T) \
inline constexpr __device__ T fabs(T v) \
{ \
/*NOLINTNEXTLINE*/
\
v.data = v.data & 0x7f; \
return v; \
#define MIGRAPHX_FP8_OTHER_OPS(T) \
inline constexpr __device__ T fabs(T v) \
{ \
/*NOLINTNEXTLINE*/
\
v.data = v.data & 0x7f; \
return v; \
} \
inline __device__ constexpr bool operator==(const T& lhs, const T& rhs) \
{ \
if(rhs.is_nan() or rhs.is_inf() or lhs.is_nan() or lhs.is_inf()) \
return false; \
else if((rhs.is_zero() and lhs.is_zero()) or (lhs.data == rhs.data)) \
return true; \
return false; \
}
// NOLINTNEXTLINE
...
...
@@ -417,11 +416,10 @@ using fp8e5m2fnuz = float8<migraphx::fp8::f8_type::bf8, true>;
MIGRAPHX_FP8_BINARY_OP(-, T, T) \
MIGRAPHX_FP8_BINARY_OP(/, T, T) \
MIGRAPHX_FP8_BINARY_OP(+, T, T) \
MIGRAPHX_FP8_BINARY_OP(==, T, bool) \
MIGRAPHX_FP8_BINARY_OP(>=, T, bool) \
MIGRAPHX_FP8_BINARY_OP(<=, T, bool) \
MIGRAPHX_FP8_BINARY_OP(!=, T, bool) \
MIGRAPHX_FP8_
FAB
S(T)
MIGRAPHX_FP8_
OTHER_OP
S(T)
MIGRAPHX_FP8_GEN_OP_OVERLOADS
(
fp8e5m2
)
MIGRAPHX_FP8_GEN_OP_OVERLOADS
(
fp8e5m2fnuz
)
...
...
@@ -447,7 +445,7 @@ class numeric_limits<fp8e4m3fnuz>
{
return
fp8e4m3fnuz
(
0x7F
,
fp8e4m3fnuz
::
from_bits
());
}
// this is min value that is not DeNorm. DeNorm min is 0x01
// this is min value that is not DeNorm
alized(DeNorm)
. DeNorm min is 0x01
static
constexpr
__device__
fp8e4m3fnuz
min
()
{
return
fp8e4m3fnuz
(
0x08
,
fp8e4m3fnuz
::
from_bits
());
...
...
@@ -475,7 +473,7 @@ class numeric_limits<fp8e4m3fn>
}
static
constexpr
__device__
fp8e4m3fn
max
()
{
return
fp8e4m3fn
(
0x7E
,
fp8e4m3fn
::
from_bits
());
}
// this is min value that is not DeNorm. DeNorm min is 0x01
// this is min value that is not DeNorm
alized(DeNorm)
. DeNorm min is 0x01
static
constexpr
__device__
fp8e4m3fn
min
()
{
return
fp8e4m3fn
(
0x08
,
fp8e4m3fn
::
from_bits
());
}
static
constexpr
__device__
fp8e4m3fn
lowest
()
...
...
@@ -503,8 +501,7 @@ class numeric_limits<fp8e5m2fnuz>
{
return
fp8e5m2fnuz
(
0x7F
,
fp8e5m2fnuz
::
from_bits
());
}
// this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make
// this distinction. For the floating points we would end up using lowest most of the times.
// this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01.
static
constexpr
__device__
fp8e5m2fnuz
min
()
{
return
fp8e5m2fnuz
(
0x4
,
fp8e5m2fnuz
::
from_bits
());
...
...
@@ -529,8 +526,7 @@ class numeric_limits<fp8e5m2>
}
static
constexpr
__device__
fp8e5m2
max
()
{
return
fp8e5m2
(
0x7B
,
fp8e5m2
::
from_bits
());
}
// this is min value that is not DeNorm. DeNorm min is 0x01. I am not sure if we want to make
// this distinction. For the floating points we would end up using lowest most of the times.
// this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01.
static
constexpr
__device__
fp8e5m2
min
()
{
return
fp8e5m2
(
0x4
,
fp8e5m2
::
from_bits
());
}
static
constexpr
__device__
fp8e5m2
lowest
()
{
return
fp8e5m2
(
0xFB
,
fp8e5m2
::
from_bits
());
}
...
...
@@ -539,24 +535,26 @@ class numeric_limits<fp8e5m2>
};
}
// namespace fp8
// NOLINTNEXTLINE
#define MIGRAPHX_FP8_MIN_MAX(T) \
template <> \
constexpr T numeric_max<T, void>() \
{ \
return fp8::numeric_limits<T>::max(); \
} \
template <> \
constexpr T numeric_lowest<T>() \
{ \
return fp8::numeric_limits<T>::lowest(); \
}
MIGRAPHX_FP8_MIN_MAX
(
fp8
::
fp8e4m3fnuz
);
MIGRAPHX_FP8_MIN_MAX
(
fp8
::
fp8e5m2fnuz
);
MIGRAPHX_FP8_MIN_MAX
(
fp8
::
fp8e4m3fn
);
MIGRAPHX_FP8_MIN_MAX
(
fp8
::
fp8e5m2
);
template
<
class
T
,
MIGRAPHX_REQUIRES
(
is_same
<
T
,
fp8
::
fp8e4m3fnuz
>{}
or
is_same
<
T
,
fp8
::
fp8e5m2fnuz
>
{}
or
is_same
<
T
,
fp8
::
fp8e4m3fn
>
{}
or
is_same
<
T
,
fp8
::
fp8e5m2
>
{})
>
constexpr
T
numeric_max
(
migraphx
::
fp8
::
f8_type
unused
=
migraphx
::
fp8
::
f8_type
::
fp8
)
{
// unused parameter is added to make this numeric_max different overload definition
// compared to numeric_max defined in type_traits.hpp
(
void
)(
unused
);
return
fp8
::
numeric_limits
<
T
>::
max
();
}
template
<
class
T
,
MIGRAPHX_REQUIRES
(
is_same
<
T
,
fp8
::
fp8e4m3fnuz
>{}
or
is_same
<
T
,
fp8
::
fp8e5m2fnuz
>
{}
or
is_same
<
T
,
fp8
::
fp8e4m3fn
>
{}
or
is_same
<
T
,
fp8
::
fp8e5m2
>
{})
>
constexpr
T
numeric_lowest
(
migraphx
::
fp8
::
f8_type
unused
=
migraphx
::
fp8
::
f8_type
::
fp8
)
{
// unused parameter is added to make this numeric_lowest different overload definition
// compared to numeric_lowest defined in type_traits.hpp
(
void
)(
unused
);
return
fp8
::
numeric_limits
<
T
>::
lowest
();
}
}
// namespace migraphx
// =================================================================================================
#if defined(__clang__)
...
...
src/targets/gpu/kernels/include/migraphx/kernels/gathernd.hpp
View file @
6f768035
...
...
@@ -53,35 +53,35 @@ __device__ void gathernd(const T& data_t, const U& indices_t, const V& output_t,
auto
indices_shape_lens
=
indices_shape
.
lens
;
auto
data_shape_lens
=
data_shape
.
lens
;
auto
num_slice_dims
=
indices_shape_lens
.
back
();
std
::
size_t
num_slices
=
size_t
num_slices
=
accumulate
(
indices_shape_lens
.
begin
(),
indices_shape_lens
.
end
()
-
1
,
1
,
op
::
product
{});
std
::
size_t
slice_size
=
accumulate
(
data_shape_lens
.
begin
()
+
num_slice_dims
+
batch_dims
,
data_shape_lens
.
end
(),
1
,
op
::
product
{});
const
std
::
size_t
num_batches
=
size_t
slice_size
=
accumulate
(
data_shape_lens
.
begin
()
+
num_slice_dims
+
batch_dims
,
data_shape_lens
.
end
(),
1
,
op
::
product
{});
const
size_t
num_batches
=
accumulate
(
data_shape_lens
.
begin
(),
data_shape_lens
.
begin
()
+
batch_dims
,
1
,
op
::
product
{});
const
std
::
size_t
data_batch_stride
=
const
size_t
data_batch_stride
=
accumulate
(
data_shape_lens
.
begin
()
+
batch_dims
,
data_shape_lens
.
end
(),
1
,
op
::
product
{});
const
auto
num_slices_per_batch
=
num_slices
/
num_batches
;
ind
.
global_stride
(
output_shape
.
elements
(),
[
&
](
auto
i
)
{
const
auto
*
indices_ptr
=
indices_t
.
data
();
const
std
::
size_t
j
=
i
/
slice_size
;
const
std
::
size_t
batch_idx
=
j
/
num_slices_per_batch
;
const
size_t
j
=
i
/
slice_size
;
const
size_t
batch_idx
=
j
/
num_slices_per_batch
;
auto
*
slice_indices
=
indices_ptr
+
(
j
*
num_slice_dims
);
std
::
size_t
relative_slice_offset
=
0
;
for
(
std
::
size_t
idx
=
0
;
idx
<
num_slice_dims
;
++
idx
)
size_t
relative_slice_offset
=
0
;
for
(
size_t
idx
=
0
;
idx
<
num_slice_dims
;
++
idx
)
{
int64_t
index
=
slice_indices
[
idx
];
const
std
::
size_t
input_dim_idx
=
batch_dims
+
idx
;
const
size_t
input_dim_idx
=
batch_dims
+
idx
;
const
auto
input_dim
=
data_shape_lens
[
input_dim_idx
];
MIGRAPHX_ASSERT
(
index
>=
-
static_cast
<
int64_t
>
(
input_dim
)
and
index
<
static_cast
<
int64_t
>
(
input_dim
));
if
(
index
<
0
)
index
+=
input_dim
;
std
::
size_t
size_from_slice_dims
=
size_t
size_from_slice_dims
=
accumulate
(
data_shape_lens
.
begin
()
+
batch_dims
+
idx
+
1
,
data_shape_lens
.
begin
()
+
batch_dims
+
num_slice_dims
,
slice_size
,
...
...
src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp
View file @
6f768035
...
...
@@ -54,12 +54,11 @@ __device__ void generic_binary_layernorm(
using
value_type
=
typename
Input1
::
type
;
using
vec_value_type
=
vec_type
<
value_type
>
;
constexpr
auto
relements
=
r
.
template
elements
<
Input1
>();
constexpr
auto
relements_r
=
static_cast
<
vec_value_type
>
(
1.0
/
relements
)
;
constexpr
auto
relements_r
=
vec_value_type
{
1.0
/
relements
}
;
auto
relements_rsqrt
=
sqrt
(
relements_r
);
auto
means
=
r
.
reduce
(
op
::
sum
{},
make_array
<
vec_value_type
>
(
static_cast
<
vec_value_type
>
(
0
),
static_cast
<
vec_value_type
>
(
0
)),
make_array
<
vec_value_type
>
(
vec_value_type
{
0
},
vec_value_type
{
0
}),
[
&
](
auto
x
)
{
auto
x_out
=
x
*
relements_r
;
// dividing x by sqrt(relements) before squaring allows computing
...
...
@@ -71,7 +70,7 @@ __device__ void generic_binary_layernorm(
auto
mean_x
=
means
[
0
];
auto
mean_x2
=
means
[
1
];
auto
variance
=
mean_x2
-
(
mean_x
*
mean_x
);
value_type
eps_val
=
static_cast
<
value_type
>
(
eps
);
value_type
eps_val
=
implicit_conversion
(
eps
);
r
.
inner
([
&
](
auto
&
y
,
auto
x
,
auto
...
xs
)
{
auto
m
=
x
-
mean_x
;
...
...
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
View file @
6f768035
...
...
@@ -290,7 +290,7 @@ MIGRAPHX_DEVICE_MATH_VEC(where)
template
<
class
T
,
class
U
>
constexpr
auto
convert
(
U
v
)
{
return
vec_transform
(
v
)([](
auto
x
)
{
return
static_cast
<
T
>
(
x
);
});
return
vec_transform
(
v
)([](
auto
x
)
->
T
{
return
static_cast
<
T
>
(
x
);
});
}
}
// namespace migraphx
...
...
src/targets/gpu/kernels/include/migraphx/kernels/ops.hpp
View file @
6f768035
...
...
@@ -118,7 +118,7 @@ struct highest
template
<
class
T
>
constexpr
operator
T
()
const
{
return
numeric_max
<
vec_type
<
T
>
,
void
>
();
return
numeric_max
<
vec_type
<
T
>>
();
}
};
}
// namespace migraphx
...
...
src/targets/gpu/kernels/include/migraphx/kernels/pad.hpp
View file @
6f768035
...
...
@@ -28,6 +28,7 @@
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
#include <migraphx/kernels/ranges.hpp>
#include <migraphx/kernels/vec.hpp>
namespace
migraphx
{
...
...
@@ -39,7 +40,6 @@ __device__ void pad(const index& idx,
const
PadVal
&
pad_val
)
{
auto
output_shape
=
output
.
get_shape
();
using
otype
=
typename
Output
::
type
;
idx
.
global_stride
(
output_shape
.
elements
(),
[
&
](
auto
i
)
{
// 1. get current multi-index for output
// 2. get the size of the input to determine input boundaries
...
...
@@ -54,9 +54,9 @@ __device__ void pad(const index& idx,
if
(
any_of
(
range_multi
.
begin
(),
range_multi
.
end
(),
[
&
](
auto
j
)
{
return
multi
[
j
]
<
offsets
[
j
]
or
input_idx
[
j
]
>=
input_bounds
[
j
];
}))
output
[
multi
]
=
otype
(
pad_val
);
output
[
multi
]
=
implicit_conversion
(
pad_val
);
else
output
[
multi
]
=
otype
(
input
[
input_idx
]);
output
[
multi
]
=
implicit_conversion
(
input
[
input_idx
]);
});
}
...
...
Prev
1
2
3
4
5
6
7
8
…
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment