Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
f8a75f8a
Commit
f8a75f8a
authored
Dec 07, 2023
by
Paul
Browse files
Merge
parents
74448ed6
d00fdf6e
Changes
242
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1536 additions
and
277 deletions
+1536
-277
src/targets/gpu/compile_gen.cpp
src/targets/gpu/compile_gen.cpp
+10
-0
src/targets/gpu/compile_hip.cpp
src/targets/gpu/compile_hip.cpp
+40
-25
src/targets/gpu/compile_hip_code_object.cpp
src/targets/gpu/compile_hip_code_object.cpp
+1
-1
src/targets/gpu/device/include/migraphx/gpu/device/scan.hpp
src/targets/gpu/device/include/migraphx/gpu/device/scan.hpp
+15
-7
src/targets/gpu/device/include/migraphx/gpu/device/types.hpp
src/targets/gpu/device/include/migraphx/gpu/device/types.hpp
+12
-12
src/targets/gpu/driver/precompile_op.cpp
src/targets/gpu/driver/precompile_op.cpp
+84
-0
src/targets/gpu/fuse_mlir.cpp
src/targets/gpu/fuse_mlir.cpp
+271
-166
src/targets/gpu/gemm_impl.cpp
src/targets/gpu/gemm_impl.cpp
+1
-0
src/targets/gpu/include/migraphx/gpu/compile_hip.hpp
src/targets/gpu/include/migraphx/gpu/compile_hip.hpp
+3
-3
src/targets/gpu/include/migraphx/gpu/fuse_mlir.hpp
src/targets/gpu/include/migraphx/gpu/fuse_mlir.hpp
+2
-1
src/targets/gpu/include/migraphx/gpu/gemm_softmax_gemm.hpp
src/targets/gpu/include/migraphx/gpu/gemm_softmax_gemm.hpp
+4
-0
src/targets/gpu/include/migraphx/gpu/miopen.hpp
src/targets/gpu/include/migraphx/gpu/miopen.hpp
+6
-0
src/targets/gpu/jit/scatter.hpp
src/targets/gpu/jit/scatter.hpp
+78
-0
src/targets/gpu/jit/scatternd.cpp
src/targets/gpu/jit/scatternd.cpp
+8
-37
src/targets/gpu/kernels/include/migraphx/kernels/bit_cast.hpp
...targets/gpu/kernels/include/migraphx/kernels/bit_cast.hpp
+37
-0
src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp
src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp
+568
-0
src/targets/gpu/kernels/include/migraphx/kernels/float8_impl.hpp
...gets/gpu/kernels/include/migraphx/kernels/float8_impl.hpp
+331
-0
src/targets/gpu/kernels/include/migraphx/kernels/gathernd.hpp
...targets/gpu/kernels/include/migraphx/kernels/gathernd.hpp
+13
-13
src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp
...argets/gpu/kernels/include/migraphx/kernels/layernorm.hpp
+12
-9
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
+40
-3
No files found.
src/targets/gpu/compile_gen.cpp
View file @
f8a75f8a
...
...
@@ -54,6 +54,11 @@ vectorize vectorize::elements(std::size_t axis,
const
std
::
vector
<
shape
>&
inputs
,
const
std
::
vector
<
std
::
size_t
>&
sizes
)
{
// disable vectorization for fp8 types
if
(
std
::
any_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
](
auto
ishape
)
{
return
ishape
.
type
()
==
migraphx
::
shape
::
fp8e4m3fnuz_type
;
}))
return
{
1
,
axis
};
if
(
std
::
all_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
](
const
auto
&
s
)
{
return
s
.
lens
()[
axis
]
==
1
;
}))
return
{
1
,
axis
};
...
...
@@ -86,6 +91,11 @@ vectorize vectorize::elements(std::size_t axis,
vectorize
vectorize
::
elements
(
context
&
ctx
,
std
::
size_t
axis
,
const
std
::
vector
<
shape
>&
inputs
)
{
// disable vectorization for fp8 types
if
(
std
::
any_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
](
auto
ishape
)
{
return
ishape
.
type
()
==
migraphx
::
shape
::
fp8e4m3fnuz_type
;
}))
return
{
1
,
axis
};
if
(
inputs
.
empty
())
return
{
1
,
axis
};
std
::
size_t
n
=
std
::
max_element
(
inputs
.
begin
(),
...
...
src/targets/gpu/compile_hip.cpp
View file @
f8a75f8a
...
...
@@ -194,7 +194,7 @@ struct hiprtc_program
};
std
::
vector
<
std
::
vector
<
char
>>
compile_hip_src_with_hiprtc
(
std
::
vector
<
hiprtc_src_file
>
srcs
,
std
::
string
params
,
const
std
::
string
&
params
,
const
std
::
string
&
arch
)
{
hiprtc_program
prog
(
std
::
move
(
srcs
));
...
...
@@ -238,8 +238,9 @@ bool hip_has_flags(const std::vector<std::string>& flags)
}
}
std
::
vector
<
std
::
vector
<
char
>>
compile_hip_src
(
const
std
::
vector
<
src_file
>&
srcs
,
std
::
string
params
,
const
std
::
string
&
arch
)
std
::
vector
<
std
::
vector
<
char
>>
compile_hip_src
(
const
std
::
vector
<
src_file
>&
srcs
,
const
std
::
string
&
params
,
const
std
::
string
&
arch
)
{
std
::
vector
<
hiprtc_src_file
>
hsrcs
{
srcs
.
begin
(),
srcs
.
end
()};
if
(
enabled
(
MIGRAPHX_GPU_DUMP_SRC
{}))
...
...
@@ -251,10 +252,21 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
std
::
cout
<<
std
::
string
(
src
.
content
)
<<
std
::
endl
;
}
}
auto
fname
=
fs
::
path
{
"migraphx-hiprtc-driver"
};
#ifdef _WIN32
fname
.
replace_extension
(
".exe"
);
#endif
auto
p
=
dynamic_loader
::
path
(
&
compile_hip_src_with_hiprtc
);
auto
driver
=
p
.
parent_path
()
.
parent_path
()
/
"bin"
/
"migraphx-hiprtc-driver"
;
auto
driver
=
p
.
parent_path
()
/
fname
;
if
(
fs
::
exists
(
driver
))
bool
found
=
fs
::
exists
(
driver
);
if
(
not
found
)
{
driver
=
p
.
parent_path
().
parent_path
()
/
"bin"
/
fname
;
found
=
fs
::
exists
(
driver
);
}
if
(
found
)
{
value
v
;
v
[
"srcs"
]
=
to_value
(
hsrcs
);
...
...
@@ -270,13 +282,13 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
if
(
fs
::
exists
(
out
))
return
{
read_buffer
(
out
.
string
())};
}
return
compile_hip_src_with_hiprtc
(
std
::
move
(
hsrcs
),
std
::
move
(
params
)
,
arch
);
return
compile_hip_src_with_hiprtc
(
std
::
move
(
hsrcs
),
params
,
arch
);
}
#else // MIGRAPHX_USE_HIPRTC
std
::
vector
<
std
::
vector
<
char
>>
compile_hip_src_with_hiprtc
(
std
::
vector
<
hiprtc_src_file
>
,
// NOLINT
std
::
string
,
// NOLINT
const
std
::
string
&
,
// NOLINT
const
std
::
string
&
)
{
MIGRAPHX_THROW
(
"Not using hiprtc"
);
...
...
@@ -305,29 +317,15 @@ src_compiler assemble(src_compiler compiler)
return
compiler
;
}
std
::
vector
<
std
::
vector
<
char
>>
compile_hip_src
(
const
std
::
vector
<
src_file
>&
srcs
,
std
::
string
params
,
const
std
::
string
&
arch
)
std
::
vector
<
std
::
vector
<
char
>>
compile_hip_src
(
const
std
::
vector
<
src_file
>&
srcs
,
const
std
::
string
&
params
,
const
std
::
string
&
arch
)
{
assert
(
not
srcs
.
empty
());
if
(
not
is_hip_clang_compiler
())
MIGRAPHX_THROW
(
"Unknown hip compiler: "
MIGRAPHX_HIP_COMPILER
);
if
(
params
.
find
(
"-std="
)
==
std
::
string
::
npos
)
params
+=
" --std=c++17"
;
params
+=
" -fno-gpu-rdc"
;
if
(
enabled
(
MIGRAPHX_GPU_DEBUG_SYM
{}))
params
+=
" -g"
;
params
+=
" -c"
;
params
+=
" --offload-arch="
+
arch
;
params
+=
" --cuda-device-only"
;
params
+=
" -O"
+
string_value_of
(
MIGRAPHX_GPU_OPTIMIZE
{},
"3"
)
+
" "
;
if
(
enabled
(
MIGRAPHX_GPU_DEBUG
{}))
params
+=
" -DMIGRAPHX_DEBUG"
;
params
+=
" -Wno-unused-command-line-argument -Wno-cuda-compat "
;
params
+=
MIGRAPHX_HIP_COMPILER_FLAGS
;
src_compiler
compiler
;
compiler
.
flags
=
params
;
compiler
.
compiler
=
MIGRAPHX_HIP_COMPILER
;
...
...
@@ -335,6 +333,23 @@ compile_hip_src(const std::vector<src_file>& srcs, std::string params, const std
if
(
has_compiler_launcher
())
compiler
.
launcher
=
MIGRAPHX_HIP_COMPILER_LAUNCHER
;
#endif
if
(
params
.
find
(
"-std="
)
==
std
::
string
::
npos
)
compiler
.
flags
+=
" --std=c++17"
;
compiler
.
flags
+=
" -fno-gpu-rdc"
;
if
(
enabled
(
MIGRAPHX_GPU_DEBUG_SYM
{}))
compiler
.
flags
+=
" -g"
;
compiler
.
flags
+=
" -c"
;
compiler
.
flags
+=
" --offload-arch="
+
arch
;
compiler
.
flags
+=
" --cuda-device-only"
;
compiler
.
flags
+=
" -O"
+
string_value_of
(
MIGRAPHX_GPU_OPTIMIZE
{},
"3"
)
+
" "
;
if
(
enabled
(
MIGRAPHX_GPU_DEBUG
{}))
compiler
.
flags
+=
" -DMIGRAPHX_DEBUG"
;
compiler
.
flags
+=
" -Wno-unused-command-line-argument -Wno-cuda-compat "
;
compiler
.
flags
+=
MIGRAPHX_HIP_COMPILER_FLAGS
;
if
(
enabled
(
MIGRAPHX_GPU_DUMP_SRC
{}))
{
for
(
const
auto
&
src
:
srcs
)
...
...
src/targets/gpu/compile_hip_code_object.cpp
View file @
f8a75f8a
...
...
@@ -200,7 +200,7 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option
options
.
params
+=
" "
+
join_strings
(
compiler_warnings
(),
" "
);
options
.
params
+=
" -ftemplate-backtrace-limit=0"
;
options
.
params
+=
" -Werror"
;
auto
cos
=
compile_hip_src
(
srcs
,
std
::
move
(
options
.
params
)
,
get_device_name
());
auto
cos
=
compile_hip_src
(
srcs
,
options
.
params
,
get_device_name
());
if
(
cos
.
size
()
!=
1
)
MIGRAPHX_THROW
(
"No code object"
);
return
code_object_op
{
value
::
binary
{
cos
.
front
()},
...
...
src/targets/gpu/device/include/migraphx/gpu/device/scan.hpp
View file @
f8a75f8a
...
...
@@ -43,24 +43,32 @@ template <index_int N,
__device__
void
block_scan
(
index
idx
,
Op
op
,
T
init
,
ForStride
fs
,
Input
input
,
Output
output
)
{
using
type
=
decltype
(
input
(
deduce_for_stride
(
fs
)));
MIGRAPHX_DEVICE_SHARED
type
buffer
[
N
];
MIGRAPHX_DEVICE_SHARED
type
buffer
[
2
][
N
];
type
x
=
init
;
fs
([
&
](
auto
i
)
{
index_int
iout
=
0
;
index_int
iin
=
1
;
if
(
idx
.
local
==
0
)
buffer
[
idx
.
local
]
=
op
(
input
(
i
),
x
);
buffer
[
iout
][
idx
.
local
]
=
op
(
input
(
i
),
x
);
else
buffer
[
idx
.
local
]
=
input
(
i
);
buffer
[
iout
][
idx
.
local
]
=
input
(
i
);
__syncthreads
();
for
(
index_int
s
=
1
;
s
<
idx
.
nlocal
();
s
*=
2
)
{
if
(
idx
.
local
+
s
<
idx
.
nlocal
())
iout
=
1
-
iout
;
iin
=
1
-
iin
;
if
(
idx
.
local
>=
s
)
{
buffer
[
idx
.
local
+
s
]
=
op
(
buffer
[
idx
.
local
],
buffer
[
idx
.
local
+
s
]);
buffer
[
iout
][
idx
.
local
]
=
op
(
buffer
[
iin
][
idx
.
local
],
buffer
[
iin
][
idx
.
local
-
s
]);
}
else
{
buffer
[
iout
][
idx
.
local
]
=
buffer
[
iin
][
idx
.
local
];
}
__syncthreads
();
}
x
=
buffer
[
idx
.
nlocal
()
-
1
];
output
(
i
,
buffer
[
idx
.
local
]);
x
=
buffer
[
iout
][
idx
.
nlocal
()
-
1
];
output
(
i
,
buffer
[
iout
][
idx
.
local
]);
});
}
...
...
src/targets/gpu/device/include/migraphx/gpu/device/types.hpp
View file @
f8a75f8a
...
...
@@ -146,20 +146,20 @@ __device__ __host__ T to_hip_type(T x)
// Hip doens't support __fp16
inline
__device__
__host__
float
to_hip_type
(
gpu_half
x
)
{
return
x
;
}
#define MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR(trait, T) \
template <class X> \
struct trait : std::trait<X> \
{ \
}; \
\
template <> \
struct trait<T> : std::true_type \
{ \
#define MIGRAPHX_
DEVICE_
DETAIL_EXTEND_TRAIT_FOR(trait, T) \
template <class X>
\
struct trait : std::trait<X>
\
{
\
};
\
\
template <>
\
struct trait<T> : std::true_type
\
{
\
};
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_floating_point
,
__fp16
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_signed
,
__fp16
)
MIGRAPHX_DETAIL_EXTEND_TRAIT_FOR
(
is_arithmetic
,
__fp16
)
MIGRAPHX_
DEVICE_
DETAIL_EXTEND_TRAIT_FOR
(
is_floating_point
,
__fp16
)
MIGRAPHX_
DEVICE_
DETAIL_EXTEND_TRAIT_FOR
(
is_signed
,
__fp16
)
MIGRAPHX_
DEVICE_
DETAIL_EXTEND_TRAIT_FOR
(
is_arithmetic
,
__fp16
)
}
// namespace device
}
// namespace gpu
...
...
src/targets/gpu/driver/precompile_op.cpp
0 → 100644
View file @
f8a75f8a
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/driver/action.hpp>
#include <migraphx/gpu/time_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/gpu/compile_ops.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
namespace
driver
{
struct
precompile_op
:
action
<
precompile_op
>
{
static
program
create_preop_program
(
const
operation
&
preop
,
std
::
vector
<
shape
>
inputs
)
{
program
p
;
auto
*
mm
=
p
.
get_main_module
();
std
::
vector
<
instruction_ref
>
args
;
inputs
.
pop_back
();
transform
(
inputs
,
range
(
inputs
.
size
()),
std
::
back_inserter
(
args
),
[
&
](
auto
input
,
auto
i
)
{
return
mm
->
add_parameter
(
"x"
+
std
::
to_string
(
i
),
input
);
});
mm
->
add_instruction
(
preop
,
args
);
return
p
;
}
static
operation
get_code_object
(
const
program
&
p
)
{
MIGRAPHX_TIDY_CONST
auto
*
mm
=
p
.
get_main_module
();
auto
it
=
std
::
find_if
(
mm
->
begin
(),
mm
->
end
(),
[](
const
auto
&
ins
)
{
return
(
ins
.
name
()
==
"gpu::code_object"
);
});
if
(
it
==
mm
->
end
())
MIGRAPHX_THROW
(
"Failed to create code object"
);
return
it
->
get_operator
();
}
static
void
apply
(
const
parser
&
p
,
const
value
&
v
)
{
context
ctx
;
auto
inputs
=
p
.
parse_shapes
(
v
.
at
(
"inputs"
));
auto
name
=
v
.
at
(
"name"
).
to
<
std
::
string
>
();
auto
preop
=
make_op
(
name
);
if
(
v
.
contains
(
"fields"
))
preop
.
from_value
(
v
.
at
(
"fields"
));
bool
exhaustive
=
v
.
get
(
"exhaustive"
,
false
);
auto
prog
=
create_preop_program
(
preop
,
inputs
);
run_passes
(
prog
,
{
lowering
{},
compile_ops
{
&
ctx
,
exhaustive
}});
auto
op
=
get_code_object
(
prog
);
auto
t
=
time_op
(
ctx
,
op
,
inputs
,
p
.
get
(
v
,
"iterations"
,
100
));
std
::
cout
<<
preop
<<
": "
<<
t
<<
"ms"
<<
std
::
endl
;
}
};
}
// namespace driver
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/gpu/fuse_mlir.cpp
View file @
f8a75f8a
...
...
@@ -38,6 +38,18 @@ namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_ENABLE_EXTRA_MLIR
);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_DISABLE_MLIR
);
/**
* @brief Declares a new MIGraphX environment variable which forces to generate
* only specific MLIR operations.
*
* The variable, if defined, forces MIGraphX to use only specific operations
* with MLIR regardless of the underlying GPU architecture. The variable accepts
* a list of operations separated by comma. The variable recognizes the following
* operations: "fused", "convolution", "dot". If the variable is not defined MIGraphX
* will decide by itself which operations to delegate to MLIR. The variable is
* intended to be primarily used by rocMLIR developers.
*/
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_MLIR_USE_SPECIFIC_OPS
);
bool
mlir_enabled
()
{
...
...
@@ -49,6 +61,26 @@ bool mlir_enabled()
#endif
}
static
bool
is_requested
(
std
::
string_view
option
,
bool
fallback
=
false
)
{
auto
string_value
=
string_value_of
(
MIGRAPHX_MLIR_USE_SPECIFIC_OPS
{},
""
);
if
(
string_value
.
empty
())
return
fallback
;
const
auto
options
=
split_string
(
string_value
,
','
);
return
contains
(
options
,
option
);
}
bool
mlir_attention_enabled
()
{
#ifdef MIGRAPHX_MLIR
if
(
not
mlir_enabled
())
return
false
;
return
is_requested
(
"attention"
);
#else
return
false
;
#endif
}
#ifdef MIGRAPHX_MLIR
struct
mlir_op
...
...
@@ -62,31 +94,20 @@ struct mlir_op
return
pack
(
f
(
self
.
op
,
"op"
));
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
,
const
std
::
vector
<
module_ref
>&
mods
)
const
shape
compute_shape
(
const
std
::
vector
<
shape
>
&
inputs
,
const
std
::
vector
<
module_ref
>&
mods
)
const
{
module_ref
mod
=
mods
[
0
];
check_shapes
{
inputs
,
*
this
}.
packed_or_broadcasted
();
if
(
mods
.
size
()
!=
1
)
MIGRAPHX_THROW
(
"should have one submodule."
);
if
(
inputs
.
size
()
<
2
)
MIGRAPHX_THROW
(
"should have at least two inputs."
);
module_ref
mod
=
mods
[
0
];
auto
type
=
mod
->
get_output_shapes
().
front
().
type
();
auto
type
=
mod
->
get_output_shapes
().
front
().
type
();
std
::
unordered_map
<
instruction_ref
,
shape
>
ins_shapes
;
size_t
param_cnt
=
0
;
std
::
vector
<
std
::
string
>
names
=
mod
->
get_parameter_names
();
std
::
sort
(
names
.
begin
(),
names
.
end
());
for
(
const
std
::
string
&
param_name
:
names
)
{
ins_shapes
[
mod
->
get_parameter
(
param_name
)]
=
inputs
[
param_cnt
++
];
}
for
(
auto
ins
:
iterator_for
(
*
mod
))
{
if
(
ins
->
name
()
==
"@param"
)
{
continue
;
}
if
(
ins
->
name
()
==
"@literal"
)
if
(
ins
->
name
()
==
"@literal"
or
ins
->
name
()
==
"@param"
)
{
ins_shapes
[
ins
]
=
ins
->
get_shape
();
continue
;
...
...
@@ -112,38 +133,48 @@ struct mlir_op
MIGRAPHX_REGISTER_OP
(
mlir_op
);
namespace
{
std
::
tuple
<
instruction_ref
,
std
::
vector
<
operation
>>
get_fusable_input_op_stream
(
instruction_ref
lower_input
)
{
instruction_ref
upper_input
=
lower_input
;
std
::
vector
<
operation
>
op_stream
;
while
(
contains
({
"slice"
,
"transpose"
,
"contiguous"
,
"reshape"
,
"squeeze"
,
"flatten"
,
"unsqueeze"
},
upper_input
->
name
()))
{
operation
op
=
upper_input
->
get_operator
();
if
(
contains
({
"squeeze"
,
"flatten"
,
"unsqueeze"
},
upper_input
->
name
()))
{
op
=
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
upper_input
->
get_shape
().
lens
()}});
}
op_stream
.
push_back
(
op
);
upper_input
=
upper_input
->
inputs
().
at
(
0
);
}
return
{
upper_input
,
op_stream
};
}
std
::
tuple
<
instruction_ref
,
std
::
vector
<
instruction_ref
>>
fuse_input_ops_and_gemm_based_op
(
module_ref
mm
,
instruction_ref
gemm_based_op
)
fuse_input_ops_and_gemm_based_op
(
module_ref
mm
,
const
std
::
vector
<
instruction_ref
>&
gemm_based_op_inputs
,
const
operation
&
gemm_based_op
)
{
std
::
vector
<
instruction_ref
>
top_inputs
;
std
::
vector
<
instruction_ref
>
imm_inputs
;
size_t
input_cnt
=
0
;
for
(
instruction_ref
input
:
gemm_based_op
->
inputs
()
)
for
(
instruction_ref
input
:
gemm_based_op
_
inputs
)
{
std
::
vector
<
operation
>
op_stream
;
while
(
contains
(
{
"slice"
,
"transpose"
,
"contiguous"
,
"reshape"
,
"squeeze"
,
"flatten"
,
"unsqueeze"
},
input
->
name
()))
{
operation
op
=
input
->
get_operator
();
if
(
contains
({
"squeeze"
,
"flatten"
,
"unsqueeze"
},
input
->
name
()))
{
op
=
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
input
->
get_shape
().
lens
()}});
}
op_stream
.
push_back
(
op
);
input
=
input
->
inputs
().
at
(
0
);
}
top_inputs
.
push_back
(
input
);
auto
[
upper_input
,
op_stream
]
=
get_fusable_input_op_stream
(
input
);
top_inputs
.
push_back
(
upper_input
);
instruction_ref
prev_input
=
mm
->
add_parameter
(
"y"
+
std
::
to_string
(
input_cnt
++
),
input
->
get_shape
());
mm
->
add_parameter
(
"y"
+
std
::
to_string
(
input_cnt
++
),
upper_
input
->
get_shape
());
for
(
const
auto
&
op
:
reverse
(
op_stream
))
{
prev_input
=
mm
->
add_instruction
(
op
,
{
prev_input
});
}
imm_inputs
.
push_back
(
prev_input
);
}
instruction_ref
new_gemm_based_op
=
mm
->
add_instruction
(
gemm_based_op
->
get_operator
(),
imm_inputs
);
instruction_ref
new_gemm_based_op
=
mm
->
add_instruction
(
gemm_based_op
,
imm_inputs
);
return
{
new_gemm_based_op
,
top_inputs
};
}
...
...
@@ -205,102 +236,135 @@ auto is_mlir_conv(mlir_mode mode)
});
}
struct
find_mlir_fused_ops
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
create_param_map_with_literals
(
module_ref
mm
,
const
module
*
pm
,
const
shape
&
shape
)
{
mlir_mode
conv_mode
=
mlir_mode
::
none
;
mlir_mode
dot_mode
=
mlir_mode
::
none
;
auto
matcher
()
const
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
ins_map
;
for
(
auto
ins
:
iterator_for
(
*
pm
))
{
auto
dot_or_conv
=
match
::
skip
(
match
::
name
(
"contiguous"
))(
match
::
any_of
(
is_mlir_dot
(
dot_mode
),
is_mlir_conv
(
conv_mode
)).
bind
(
"gemm_based_op"
));
return
match
::
name
(
"pointwise"
)(
match
::
any_of
[
match
::
inputs
()](
dot_or_conv
.
bind
(
"x"
)));
}
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
create_param_map_with_literals
(
module_ref
mm
,
const
module
*
pm
,
const
shape
&
shape
)
const
{
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
ins_map
;
for
(
auto
ins
:
iterator_for
(
*
pm
))
if
(
ins
->
name
()
!=
"@literal"
)
{
if
(
ins
->
name
()
!=
"@literal"
)
{
continue
;
}
literal
r
=
ins
->
get_literal
();
instruction_ref
literal
=
mm
->
add_literal
(
r
);
instruction_ref
mbcast
=
mm
->
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
shape
.
lens
()}}),
literal
);
ins_map
[
ins
]
=
mbcast
;
continue
;
}
return
ins_map
;
literal
r
=
ins
->
get_literal
();
instruction_ref
literal
=
mm
->
add_literal
(
r
);
instruction_ref
mbcast
=
mm
->
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
shape
.
lens
()}}),
literal
);
ins_map
[
ins
]
=
mbcast
;
}
return
ins_map
;
}
// Whitelist supported fusion options, including imposing type constraints
// for cases where MLIR only supports an operation (usually a pointwise function)
// on particular types.
bool
is_pointwise_op_supported_by_mlir
(
const
instruction
&
i
)
const
std
::
vector
<
instruction_ref
>
fold_pointwise_mod
(
instruction_ref
pm_ins
,
module_ref
parent_mod
,
const
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>&
ins_map
)
{
auto
*
pm
=
pm_ins
->
module_inputs
().
front
();
auto
names
=
pm
->
get_parameter_names
();
std
::
sort
(
names
.
begin
(),
names
.
end
());
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
param_map
=
create_param_map_with_literals
(
parent_mod
,
pm
,
pm_ins
->
get_shape
());
std
::
transform
(
names
.
begin
(),
names
.
end
(),
pm_ins
->
inputs
().
begin
(),
std
::
inserter
(
param_map
,
param_map
.
end
()),
[
&
](
auto
name
,
auto
input
)
{
if
(
ins_map
.
count
(
input
))
return
std
::
make_pair
(
pm
->
get_parameter
(
name
),
ins_map
.
at
(
input
));
return
std
::
make_pair
(
pm
->
get_parameter
(
name
),
parent_mod
->
add_parameter
(
name
,
input
->
get_shape
()));
});
return
parent_mod
->
insert_instructions
(
parent_mod
->
end
(),
pm
,
param_map
);
}
// Whitelist supported fusion options, including imposing type constraints
// for cases where MLIR only supports an operation (usually a pointwise function)
// on particular types.
bool
is_pointwise_op_supported_by_mlir
(
const
instruction
&
i
)
{
using
type_t
=
shape
::
type_t
;
const
auto
&
name
=
i
.
name
();
const
auto
result_type
=
i
.
get_shape
().
type
();
const
std
::
initializer_list
<
type_t
>
allowed_types
=
{
type_t
::
float_type
,
type_t
::
half_type
,
type_t
::
int8_type
,
type_t
::
int32_type
,
type_t
::
bool_type
};
// Preliminary type check.
if
(
not
contains
(
allowed_types
,
result_type
))
{
using
type_t
=
shape
::
type_t
;
const
auto
&
name
=
i
.
name
();
const
auto
result_type
=
i
.
get_shape
().
type
();
const
std
::
initializer_list
<
type_t
>
allowed_types
=
{
type_t
::
float_type
,
type_t
::
half_type
,
type_t
::
int8_type
,
type_t
::
int32_type
,
type_t
::
bool_type
};
// Preliminary type check.
if
(
not
contains
(
allowed_types
,
result_type
))
{
return
false
;
}
const
std
::
initializer_list
<
std
::
string
>
any_type_ops
=
{
"@literal"
,
"@param"
,
"@return"
};
const
std
::
initializer_list
<
std
::
string
>
no_bool_ops
=
{
"convolution"
,
"quant_convolution"
,
"dot"
,
"quant_dot"
,
"add"
,
"clip"
,
"relu"
,
"sub"
,
"mul"
,
"div"
,
"pow"
,
"where"
,
"quantizelinear"
,
"dequantizelinear"
,
"abs"
,
"neg"
,
};
const
std
::
initializer_list
<
std
::
string
>
fp_only_ops
=
{
"ceil"
,
"erf"
,
"exp"
,
"floor"
,
"log"
,
"recip"
,
"rsqrt"
,
"sigmoid"
,
"softmax"
,
"tanh"
,
};
bool
is_float
=
contains
({
type_t
::
float_type
,
type_t
::
half_type
},
result_type
);
if
(
contains
(
any_type_ops
,
name
))
return
true
;
if
(
result_type
!=
type_t
::
bool_type
and
contains
(
no_bool_ops
,
name
))
return
true
;
if
(
is_float
and
contains
(
fp_only_ops
,
name
))
return
true
;
// Only conversions between floating types are known to be unambigiously
// supported.
if
(
is_float
and
name
==
"convert"
)
{
return
std
::
all_of
(
i
.
inputs
().
begin
(),
i
.
inputs
().
end
(),
[](
const
auto
&
arg
)
{
return
contains
({
type_t
::
float_type
,
type_t
::
half_type
},
arg
->
get_shape
().
type
());
});
}
return
false
;
}
const
std
::
initializer_list
<
std
::
string
>
any_type_ops
=
{
"@literal"
,
"@param"
,
"@return"
};
const
std
::
initializer_list
<
std
::
string
>
no_bool_ops
=
{
"convolution"
,
"quant_convolution"
,
"dot"
,
"quant_dot"
,
"add"
,
"clip"
,
"relu"
,
"sub"
,
"mul"
,
"div"
,
"pow"
,
"where"
,
"quantizelinear"
,
"dequantizelinear"
,
"abs"
,
"neg"
,
};
const
std
::
initializer_list
<
std
::
string
>
fp_only_ops
=
{
"ceil"
,
"erf"
,
"exp"
,
"floor"
,
"log"
,
"recip"
,
"rsqrt"
,
"sigmoid"
,
"softmax"
,
"tanh"
,
};
bool
is_float
=
contains
({
type_t
::
float_type
,
type_t
::
half_type
},
result_type
);
if
(
contains
(
any_type_ops
,
name
))
return
true
;
if
(
result_type
!=
type_t
::
bool_type
and
contains
(
no_bool_ops
,
name
))
return
true
;
if
(
is_float
and
contains
(
fp_only_ops
,
name
))
return
true
;
// Only conversions between floating types are known to be unambigiously
// supported.
if
(
is_float
and
name
==
"convert"
)
{
return
std
::
all_of
(
i
.
inputs
().
begin
(),
i
.
inputs
().
end
(),
[](
const
auto
&
arg
)
{
return
contains
({
type_t
::
float_type
,
type_t
::
half_type
},
arg
->
get_shape
().
type
());
});
}
return
false
;
}
MIGRAPHX_PRED_MATCHER
(
mlir_pointwise
,
instruction_ref
ins
)
{
if
(
ins
->
name
()
!=
"pointwise"
)
return
false
;
auto
*
pm
=
ins
->
module_inputs
().
front
();
return
std
::
all_of
(
pm
->
begin
(),
pm
->
end
(),
[
&
](
const
auto
&
i
)
{
return
is_pointwise_op_supported_by_mlir
(
i
);
});
}
struct
find_mlir_fused_ops
{
mlir_mode
conv_mode
=
mlir_mode
::
none
;
mlir_mode
dot_mode
=
mlir_mode
::
none
;
auto
matcher
()
const
{
auto
dot_or_conv
=
match
::
skip
(
match
::
name
(
"contiguous"
))(
match
::
any_of
(
is_mlir_dot
(
dot_mode
),
is_mlir_conv
(
conv_mode
)).
bind
(
"gemm_based_op"
));
return
mlir_pointwise
()(
match
::
any_of
[
match
::
inputs
()](
dot_or_conv
.
bind
(
"x"
)));
}
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
{
...
...
@@ -309,29 +373,12 @@ struct find_mlir_fused_ops
auto
x_ins
=
r
.
instructions
[
"x"
];
// input after contiguous
auto
*
pm
=
ins
->
module_inputs
().
front
();
auto
names
=
pm
->
get_parameter_names
();
// Whitelist pointwise operators.
if
(
std
::
any_of
(
pm
->
begin
(),
pm
->
end
(),
[
&
](
const
auto
&
i
)
{
return
not
is_pointwise_op_supported_by_mlir
(
i
);
}))
return
;
std
::
sort
(
names
.
begin
(),
names
.
end
());
module_ref
mm
=
mpm
.
create_module
(
"mlir_"
+
pm
->
name
());
mm
->
set_bypass
();
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
param_map
=
create_param_map_with_literals
(
mm
,
pm
,
gemm_based_op
->
get_shape
());
auto
[
anchor_op
,
top_inputs
]
=
fuse_input_ops_and_gemm_based_op
(
mm
,
gemm_based_op
);
std
::
transform
(
names
.
begin
(),
names
.
end
(),
ins
->
inputs
().
begin
(),
std
::
inserter
(
param_map
,
param_map
.
end
()),
[
&
,
&
anchor
=
anchor_op
](
auto
name
,
auto
input
)
{
if
(
input
==
x_ins
)
return
std
::
make_pair
(
pm
->
get_parameter
(
name
),
anchor
);
return
std
::
make_pair
(
pm
->
get_parameter
(
name
),
mm
->
add_parameter
(
name
,
input
->
get_shape
()));
});
mm
->
add_return
(
mm
->
insert_instructions
(
mm
->
end
(),
pm
,
param_map
));
auto
[
anchor_op
,
top_inputs
]
=
fuse_input_ops_and_gemm_based_op
(
mm
,
gemm_based_op
->
inputs
(),
gemm_based_op
->
get_operator
());
mm
->
add_return
(
fold_pointwise_mod
(
ins
,
mm
,
{{
x_ins
,
anchor_op
}}));
std
::
vector
<
instruction_ref
>
inputs
;
std
::
copy_if
(
ins
->
inputs
().
begin
(),
...
...
@@ -349,52 +396,103 @@ struct find_mlir_standalone_op
{
mlir_mode
mode
=
mlir_mode
::
none
;
auto
matcher
()
const
{
return
Matcher
(
mode
);
}
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
{
auto
conv_based_op
=
r
.
result
;
auto
gemm_based_op
=
r
.
result
;
//
// enable only for fp32/fp16/i8 types
if
(
std
::
any_of
(
conv
_based_op
->
inputs
().
begin
(),
conv
_based_op
->
inputs
().
end
(),
[
&
](
auto
i
)
{
if
(
std
::
any_of
(
gemm
_based_op
->
inputs
().
begin
(),
gemm
_based_op
->
inputs
().
end
(),
[
&
](
auto
i
)
{
return
not
contains
(
{
shape
::
type_t
::
float_type
,
shape
::
type_t
::
half_type
,
shape
::
type_t
::
int8_type
},
i
->
get_shape
().
type
());
}))
return
;
static
size_t
counter
=
0
;
module_ref
mm
=
mpm
.
create_module
(
"mlir_"
+
conv
_based_op
->
name
()
+
std
::
to_string
(
counter
++
));
mpm
.
create_module
(
"mlir_"
+
gemm
_based_op
->
name
()
+
std
::
to_string
(
counter
++
));
mm
->
set_bypass
();
auto
[
anchor_op
,
top_inputs
]
=
fuse_input_ops_and_gemm_based_op
(
mm
,
conv_based_op
);
auto
[
anchor_op
,
top_inputs
]
=
fuse_input_ops_and_gemm_based_op
(
mm
,
gemm_based_op
->
inputs
(),
gemm_based_op
->
get_operator
());
mm
->
add_return
({
anchor_op
});
mpm
.
get_module
().
replace_instruction
(
conv
_based_op
,
mlir_op
{
conv
_based_op
->
get_operator
()},
top_inputs
,
{
mm
});
gemm
_based_op
,
mlir_op
{
gemm
_based_op
->
get_operator
()},
top_inputs
,
{
mm
});
}
};
using
find_mlir_standalone_convolution_op
=
find_mlir_standalone_op
<&
is_mlir_conv
>
;
using
find_mlir_standalone_dot_op
=
find_mlir_standalone_op
<&
is_mlir_dot
>
;
/**
* @brief Declares a new MIGraphX environment variable which forces to generate
* only specific MLIR operations.
*
* The variable, if defined, forces MIGraphX to use only specific operations
* with MLIR regardless of the underlying GPU architecture. The variable accepts
* a list of operations separated by comma. The variable recognizes the following
* operations: "fused", "convolution", "dot". If the variable is not defined MIGraphX
* will decide by itself which operations to delegate to MLIR. The variable is
* intended to be primarily used by rocMLIR developers.
*/
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_MLIR_USE_SPECIFIC_OPS
);
struct
find_mlir_standalone_attention_op
{
auto
matcher
()
const
{
return
match
::
name
(
"gpu::pre_gemm_softmax_gemm"
).
bind
(
"gemm_softmax_gemm"
);
}
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
{
static
size_t
counter
=
0
;
module_ref
mm
=
mpm
.
create_module
(
"mlir_"
+
std
::
to_string
(
counter
++
));
auto
gemm_softmax_gemm
=
r
.
instructions
[
"gemm_softmax_gemm"
];
std
::
vector
<
instruction_ref
>
inputs
;
mm
->
set_bypass
();
bool
is_requested
(
std
::
string_view
option
,
bool
fallback
=
false
)
std
::
unordered_map
<
instruction_ref
,
instruction_ref
>
ins_map
;
auto
gemm0_inputs
=
gemm_softmax_gemm
->
inputs
();
gemm0_inputs
.
pop_back
();
auto
[
gemm0
,
top_gemm0_inputs
]
=
fuse_input_ops_and_gemm_based_op
(
mm
,
gemm0_inputs
,
make_op
(
"dot"
));
inputs
.
insert
(
inputs
.
begin
(),
top_gemm0_inputs
.
begin
(),
top_gemm0_inputs
.
end
());
// handle scale
auto
v
=
gemm_softmax_gemm
->
get_operator
().
to_value
();
assert
(
v
.
contains
(
"scale"
));
auto
scale
=
v
.
at
(
"scale"
).
to
<
float
>
();
auto
scale_lit
=
mm
->
add_literal
(
literal
{
shape
{
gemm0
->
get_shape
().
type
()},
{
scale
}});
instruction_ref
scale_lit_mbcast
=
mm
->
add_instruction
(
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
gemm0
->
get_shape
().
lens
()}}),
scale_lit
);
auto
scaled_gemm0
=
mm
->
add_instruction
(
make_op
(
"mul"
),
gemm0
,
scale_lit_mbcast
);
auto
softmax
=
mm
->
add_instruction
(
make_op
(
"softmax"
,
{{
"axis"
,
gemm0
->
get_shape
().
lens
().
size
()
-
1
}}),
scaled_gemm0
);
auto
[
old_upper_v
,
upper_v_op_stream
]
=
get_fusable_input_op_stream
(
gemm_softmax_gemm
->
inputs
()[
2
]);
instruction_ref
new_upper_v
=
mm
->
add_parameter
(
"z"
,
old_upper_v
->
get_shape
());
for
(
const
auto
&
op
:
reverse
(
upper_v_op_stream
))
{
new_upper_v
=
mm
->
add_instruction
(
op
,
{
new_upper_v
});
}
inputs
.
push_back
(
old_upper_v
);
auto
gemm1
=
mm
->
add_instruction
(
make_op
(
"dot"
),
{
softmax
,
new_upper_v
});
ins_map
[
gemm_softmax_gemm
]
=
gemm1
;
auto
ins_to_replace
=
gemm1
;
auto
ins_to_be_replaced
=
gemm_softmax_gemm
;
if
(
r
.
instructions
.
find
(
"trailing_pm"
)
!=
r
.
instructions
.
end
())
{
ins_to_replace
=
fold_pointwise_mod
(
r
.
instructions
[
"trailing_pm"
],
mm
,
ins_map
)[
0
];
std
::
copy_if
(
r
.
instructions
[
"trailing_pm"
]
->
inputs
().
begin
(),
r
.
instructions
[
"trailing_pm"
]
->
inputs
().
end
(),
std
::
back_inserter
(
inputs
),
[
&
](
auto
input
)
{
return
input
!=
gemm_softmax_gemm
;
});
ins_to_be_replaced
=
r
.
instructions
[
"trailing_pm"
];
}
mm
->
add_return
({
ins_to_replace
});
mpm
.
get_module
().
replace_instruction
(
ins_to_be_replaced
,
mlir_op
{
gemm1
->
get_operator
()},
inputs
,
{
mm
});
}
};
struct
find_mlir_attention_fused_ops
:
public
find_mlir_standalone_attention_op
{
auto
string_value
=
string_value_of
(
MIGRAPHX_MLIR_USE_SPECIFIC_OPS
{},
""
);
if
(
string_value
.
empty
())
return
fallback
;
const
auto
options
=
split_string
(
string_value
,
','
);
return
contains
(
options
,
option
);
}
auto
matcher
()
const
{
auto
standalone_matcher
=
find_mlir_standalone_attention_op
::
matcher
();
return
mlir_pointwise
()(
match
::
any_of
[
match
::
inputs
()](
standalone_matcher
).
bind
(
"trailing_pm"
));
;
}
};
}
// namespace
#endif // MIGRAPHX_MLIR
...
...
@@ -416,6 +514,13 @@ void fuse_mlir::apply(module_pass_manager& mpm) const
mlir_mode
mode
=
(
enabled
(
MIGRAPHX_ENABLE_EXTRA_MLIR
{})
or
enable_extra
)
?
mlir_mode
::
fast
:
mlir_mode
::
none
;
// Attention offloads; default disabled
if
(
mlir_attention_enabled
())
{
match
::
find_matches
(
mpm
,
find_mlir_attention_fused_ops
{});
match
::
find_matches
(
mpm
,
find_mlir_standalone_attention_op
{});
}
match
::
find_matches
(
mpm
,
find_mlir_fused_ops
{.
conv_mode
=
get_mode
(
"fused"
,
mlir_mode
::
fast
),
.
dot_mode
=
get_mode
(
"fused"
,
mode
)});
...
...
src/targets/gpu/gemm_impl.cpp
View file @
f8a75f8a
...
...
@@ -46,6 +46,7 @@ rocblas_datatype get_type(shape::type_t type)
case
shape
::
uint8_type
:
return
rocblas_datatype_u8_r
;
case
shape
::
int32_type
:
return
rocblas_datatype_i32_r
;
case
shape
::
uint32_type
:
return
rocblas_datatype_u32_r
;
case
shape
::
fp8e4m3fnuz_type
:
case
shape
::
tuple_type
:
case
shape
::
bool_type
:
case
shape
::
uint16_type
:
...
...
src/targets/gpu/include/migraphx/gpu/compile_hip.hpp
View file @
f8a75f8a
...
...
@@ -58,10 +58,10 @@ struct hiprtc_src_file
MIGRAPHX_GPU_EXPORT
bool
hip_has_flags
(
const
std
::
vector
<
std
::
string
>&
flags
);
MIGRAPHX_GPU_EXPORT
std
::
vector
<
std
::
vector
<
char
>>
compile_hip_src_with_hiprtc
(
std
::
vector
<
hiprtc_src_file
>
srcs
,
std
::
string
params
,
const
std
::
string
&
arch
);
std
::
vector
<
hiprtc_src_file
>
srcs
,
const
std
::
string
&
params
,
const
std
::
string
&
arch
);
MIGRAPHX_GPU_EXPORT
std
::
vector
<
std
::
vector
<
char
>>
compile_hip_src
(
const
std
::
vector
<
src_file
>&
srcs
,
std
::
string
params
,
const
std
::
string
&
arch
);
MIGRAPHX_GPU_EXPORT
std
::
vector
<
std
::
vector
<
char
>>
compile_hip_src
(
const
std
::
vector
<
src_file
>&
srcs
,
const
std
::
string
&
params
,
const
std
::
string
&
arch
);
MIGRAPHX_GPU_EXPORT
std
::
string
enum_params
(
std
::
size_t
count
,
std
::
string
param
);
...
...
src/targets/gpu/include/migraphx/gpu/fuse_mlir.hpp
View file @
f8a75f8a
...
...
@@ -34,10 +34,11 @@ struct module_pass_manager;
namespace
gpu
{
MIGRAPHX_GPU_EXPORT
bool
mlir_enabled
();
MIGRAPHX_GPU_EXPORT
bool
mlir_attention_enabled
();
struct
MIGRAPHX_GPU_EXPORT
fuse_mlir
{
context
*
ctx
=
nullptr
;
context
*
ctx
=
nullptr
;
bool
enable_extra
=
false
;
std
::
string
name
()
const
{
return
"gpu::fuse_mlir"
;
}
void
apply
(
module_pass_manager
&
mpm
)
const
;
...
...
src/targets/gpu/include/migraphx/gpu/gemm_softmax_gemm.hpp
View file @
f8a75f8a
...
...
@@ -66,6 +66,10 @@ struct gemm_softmax_gemm
}
static
bool
is_ck_supported_type
(
shape
::
type_t
t
)
{
return
contains
({
shape
::
half_type
},
t
);
}
static
bool
is_mlir_supported_type
(
shape
::
type_t
t
)
{
return
contains
({
shape
::
type_t
::
float_type
,
shape
::
half_type
},
t
);
}
};
}
// namespace gpu
...
...
src/targets/gpu/include/migraphx/gpu/miopen.hpp
View file @
f8a75f8a
...
...
@@ -211,6 +211,12 @@ inline pooling_descriptor make_pooling(const migraphx::op::pooling& op)
ss
<<
op
.
mode
;
MIGRAPHX_THROW
(
ss
.
str
());
}
if
(
not
std
::
all_of
(
op
.
dilations
.
cbegin
(),
op
.
dilations
.
cend
(),
[](
std
::
size_t
d
)
{
return
d
==
1
;
}))
{
MIGRAPHX_THROW
(
"Unsupported dilations for pooling: ["
+
to_string_range
(
op
.
dilations
)
+
"]"
);
}
auto
p
=
make_obj
<
pooling_descriptor
>
(
&
miopenCreatePoolingDescriptor
);
int
kdims
=
op
.
kdims
();
...
...
src/targets/gpu/jit/scatter.hpp
0 → 100644
View file @
f8a75f8a
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_JIT_SCATTER_HPP
#define MIGRAPHX_GUARD_JIT_SCATTER_HPP
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
template
<
typename
Derived
>
struct
scatter_compiler
:
compiler
<
Derived
>
{
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
const
{
const
auto
inputs
=
to_shapes
(
std
::
vector
<
instruction_ref
>
{
ins
->
inputs
().
begin
()
+
1
,
ins
->
inputs
().
end
()});
hip_compile_options
options
;
options
.
set_launch_params
(
op
.
to_value
(),
compute_global_for
(
ctx
,
inputs
.
at
(
1
).
elements
()));
options
.
inputs
=
inputs
;
options
.
output
=
inputs
.
back
();
options
.
kernel_name
=
derived
().
get_kernel_name
(
op
);
options
.
virtual_inputs
=
inputs
;
// The compiler protests the inequality comparison in assign_mul when pertaining to floating
// point, despite it making sense in the context. Thus the warning removal.
options
.
params
+=
"-Wno-float-equal"
;
const
auto
src
=
derived
().
make_interpolated_string
(
op
);
return
prepend_copy_data_to_output
(
compile_hip_code_object
(
src
,
options
));
}
compiler_replace
prepend_copy_data_to_output
(
const
operation
&
co
)
const
{
return
{
co
,
[](
module
&
m
,
instruction_ref
ins
,
const
operation
&
op
)
{
auto
args
=
ins
->
inputs
();
args
.
back
()
=
m
.
insert_instruction
(
ins
,
make_op
(
"hip::copy"
),
args
.
front
(),
args
.
back
());
args
.
erase
(
args
.
begin
());
return
m
.
replace_instruction
(
ins
,
op
,
args
);
}};
}
std
::
string
get_kernel_name
(
const
operation
&
op
)
const
{
return
op
.
name
()
+
"_kernel"
;
}
const
Derived
&
derived
()
const
{
return
static_cast
<
const
Derived
&>
(
*
this
);
}
};
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/targets/gpu/jit/scatternd.cpp
View file @
f8a75f8a
...
...
@@ -21,11 +21,7 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include "scatter.hpp"
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -55,46 +51,21 @@ MIGRAPHX_GLOBAL void scatternd_kernel(void* in_indices, void* in_updates, void*
)__migraphx__"
;
struct
scatternd_compiler
:
compiler
<
scatternd_compiler
>
struct
scatternd_compiler
:
scatter_
compiler
<
scatternd_compiler
>
{
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"scatternd_none"
,
"scatternd_add"
,
"scatternd_mul"
};
return
{
"scatternd_none"
,
"scatternd_add"
,
"scatternd_mul"
,
"scatternd_min"
,
"scatternd_max"
};
}
operation
compile_op
(
context
&
ctx
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
std
::
string
make_interpolated_string
(
const
operation
&
op
)
const
{
hip_compile_options
options
;
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
inputs
.
at
(
1
).
elements
()));
options
.
inputs
=
inputs
;
options
.
output
=
inputs
.
back
();
options
.
kernel_name
=
"scatternd_kernel"
;
options
.
virtual_inputs
=
inputs
;
auto
reduction
=
"assign_"
+
v
.
get
(
"reduction"
,
std
::
string
{
"none"
});
auto
src
=
interpolate_string
(
scatternd_kernel
,
{{
"reduction"
,
reduction
}});
return
compile_hip_code_object
(
src
,
options
);
const
auto
reduction
=
op
.
name
().
substr
(
std
::
char_traits
<
char
>::
length
(
"scatternd_"
));
return
interpolate_string
(
scatternd_kernel
,
{{
"reduction"
,
"assign_"
+
reduction
}});
}
compiler_replace
compile
(
context
&
ctx
,
instruction_ref
ins
,
const
operation
&
op
)
const
{
assert
(
starts_with
(
op
.
name
(),
"scatternd_"
));
auto
reduction
=
op
.
name
().
substr
(
10
);
return
insert
(
compile_op
(
ctx
,
to_shapes
(
std
::
vector
<
instruction_ref
>
{
ins
->
inputs
().
begin
()
+
1
,
ins
->
inputs
().
end
()}),
{{
"reduction"
,
reduction
}}));
}
compiler_replace
insert
(
const
operation
&
co
)
const
{
return
{
co
,
[](
module
&
m
,
instruction_ref
ins
,
const
operation
&
op
)
{
auto
args
=
ins
->
inputs
();
args
.
back
()
=
m
.
insert_instruction
(
ins
,
make_op
(
"hip::copy"
),
args
.
front
(),
args
.
back
());
args
.
erase
(
args
.
begin
());
return
m
.
replace_instruction
(
ins
,
op
,
args
);
}};
}
std
::
string
get_kernel_name
(
const
operation
&
)
const
{
return
"scatternd_kernel"
;
}
};
}
// namespace gpu
...
...
src/targets/gpu/kernels/include/migraphx/kernels/bit_cast.hpp
0 → 100644
View file @
f8a75f8a
/* ************************************************************************
* Copyright (C) 2016-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell cop-
* ies of the Software, and to permit persons to whom the Software is furnished
* to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IM-
* PLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
* FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
* COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
* IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNE-
* CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*
* ************************************************************************ */
#ifndef MIGRAPHX_GUARD_KERNELS_BITCAST_HPP
#define MIGRAPHX_GUARD_KERNELS_BITCAST_HPP
#include <migraphx/kernels/type_traits.hpp>
namespace
migraphx
{
template
<
typename
To
,
typename
From
,
MIGRAPHX_REQUIRES
(
is_trivially_copyable
<
To
>{}
and
is_trivially_copyable
<
From
>
{})
>
inline
constexpr
To
bit_cast
(
From
fr
)
noexcept
{
static_assert
(
sizeof
(
To
)
==
sizeof
(
From
));
return
__builtin_bit_cast
(
To
,
fr
);
}
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_BITCAST_HPP
src/targets/gpu/kernels/include/migraphx/kernels/float8.hpp
0 → 100644
View file @
f8a75f8a
/* ************************************************************************
* Copyright (C) 2016-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell cop-
* ies of the Software, and to permit persons to whom the Software is furnished
* to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IM-
* PLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
* FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
* COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
* IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNE-
* CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*
* ************************************************************************ */
#ifndef MIGRAPHX_GUARD_KERNELS_FLOAT8_HPP
#define MIGRAPHX_GUARD_KERNELS_FLOAT8_HPP
#if defined(__clang__)
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wfloat-equal"
#pragma clang diagnostic ignored "-Wc++20-extensions" // required for "asm" inside constexpr
#endif // __clang__
// We are clipping in down conversion by default
#define MIGRAPHX_F8_DOWNCAST_CLIPPING 1 // NOLINT
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/type_traits.hpp>
#include <migraphx/kernels/float8_impl.hpp>
namespace
migraphx
{
namespace
fp8
{
enum
class
rounding_mode
{
standard
,
// standard rounding is doing RNE -- round to nearest even
stochastic
};
enum
class
f8_type
{
bf8
=
0
,
// s1e5m2
fp8
=
1
// s1e4m3
};
template
<
typename
T
>
class
numeric_limits
;
template
<
migraphx
::
fp8
::
f8_type
T
=
migraphx
::
fp8
::
f8_type
::
fp8
,
bool
FNUZ
=
true
>
struct
float8
{
uint8_t
data
;
// default constructor
__device__
constexpr
float8
()
=
default
;
// default copy constructor
__device__
constexpr
float8
(
const
float8
&
y
)
=
default
;
struct
from_bits_t
{
};
static
constexpr
__device__
from_bits_t
from_bits
()
{
return
from_bits_t
();
}
__device__
explicit
constexpr
float8
(
uint8_t
bits
,
from_bits_t
)
:
data
(
bits
)
{}
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// device specific optimized F8 down-conversion code
template
<
bool
stochastic_rounding
=
false
>
static
__device__
uint8_t
cast_to_f8_from_f32
(
float
v
,
uint32_t
rng
=
0
)
{
uint8_t
i8data
=
0x00
;
union
{
float
fval
;
uint32_t
i32val
;
uint8_t
i8val
[
4
];
// NOTE: not endian independent
}
val
;
uint32_t
ival
=
0
;
val
.
fval
=
v
;
#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING
if
constexpr
(
T
==
migraphx
::
fp8
::
f8_type
::
fp8
)
{
if
((
val
.
i32val
&
0x7F800000
)
!=
0x7F800000
)
/// propagate NAN/INF, no clipping
val
.
fval
=
__builtin_amdgcn_fmed3f
(
val
.
fval
,
240.0
,
-
240.0
);
}
else
{
if
((
val
.
i32val
&
0x7F800000
)
!=
0x7F800000
)
// propagate NAN/INF, no clipping
val
.
fval
=
__builtin_amdgcn_fmed3f
(
val
.
fval
,
57344.0
,
-
57344.0
);
}
#endif
if
(
stochastic_rounding
)
{
if
constexpr
(
T
==
migraphx
::
fp8
::
f8_type
::
fp8
)
{
ival
=
__builtin_amdgcn_cvt_sr_fp8_f32
(
val
.
fval
,
rng
,
ival
,
0
);
// 0 pos
}
else
{
ival
=
__builtin_amdgcn_cvt_sr_bf8_f32
(
val
.
fval
,
rng
,
ival
,
0
);
// 0 pos
}
}
else
// RNE CVT
{
if
constexpr
(
T
==
migraphx
::
fp8
::
f8_type
::
fp8
)
{
ival
=
__builtin_amdgcn_cvt_pk_fp8_f32
(
val
.
fval
,
val
.
fval
,
ival
,
false
);
// false -> WORD0
}
else
{
ival
=
__builtin_amdgcn_cvt_pk_bf8_f32
(
val
.
fval
,
val
.
fval
,
ival
,
false
);
// false -> WORD0}
}
}
val
.
i32val
=
ival
;
i8data
=
val
.
i8val
[
0
];
// little endian
return
i8data
;
}
#endif // __gfx940__
// constructor from float
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// NOTE: ON-DEVICE... always optimal bias
explicit
constexpr
__device__
float8
(
const
float
v
,
migraphx
::
fp8
::
rounding_mode
rm
=
migraphx
::
fp8
::
rounding_mode
::
standard
,
uint32_t
rng
=
0
)
{
if
(
__builtin_is_constant_evaluated
())
{
if
constexpr
(
T
==
migraphx
::
fp8
::
f8_type
::
fp8
)
{
#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING
data
=
migraphx
::
fp8
::
impl
::
cast_to_f8
<
3
,
4
,
float
,
FNUZ
/*negative_zero_nan*/
,
true
/*clip*/
>
(
v
,
(
rm
==
migraphx
::
fp8
::
rounding_mode
::
stochastic
),
rng
);
#else // MIGRAPHX_F8_DOWNCAST_CLIPPING
data
=
migraphx
::
fp8
::
impl
::
cast_to_f8
<
3
,
4
,
float
,
FNUZ
/*negative_zero_nan*/
,
false
/*clip*/
>
(
v
,
(
rm
==
migraphx
::
fp8
::
rounding_mode
::
stochastic
),
rng
);
#endif // MIGRAPHX_F8_DOWNCAST_CLIPPING
}
else
{
#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING
data
=
migraphx
::
fp8
::
impl
::
cast_to_f8
<
2
,
5
,
float
,
FNUZ
/*negative_zero_nan*/
,
true
/*clip*/
>
(
v
,
(
rm
==
migraphx
::
fp8
::
rounding_mode
::
stochastic
),
rng
);
#else // MIGRAPHX_F8_DOWNCAST_CLIPPING
data
=
migraphx
::
fp8
::
impl
::
cast_to_f8
<
2
,
5
,
float
,
FNUZ
/*negative_zero_nan*/
,
false
/*clip*/
>
(
v
,
(
rm
==
migraphx
::
fp8
::
rounding_mode
::
stochastic
),
rng
);
#endif // MIGRAPHX_FP8_DOWNCAST_CLIPPING}
}
}
else
{
// runtime branch, use cast_to_f8_from_f32 if want to avoid it
if
(
rm
==
migraphx
::
fp8
::
rounding_mode
::
stochastic
)
data
=
cast_to_f8_from_f32
<
true
>
(
v
,
rng
);
else
data
=
cast_to_f8_from_f32
<
false
>
(
v
);
}
}
#else
// DEVICE for non-gfx940 using s/w simulation
explicit
constexpr
__device__
float8
(
const
float
v
,
migraphx
::
fp8
::
rounding_mode
rm
=
migraphx
::
fp8
::
rounding_mode
::
standard
,
uint32_t
rng
=
0
)
{
if
constexpr
(
T
==
migraphx
::
fp8
::
f8_type
::
fp8
)
{
#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING
data
=
migraphx
::
fp8
::
impl
::
cast_to_f8
<
3
,
4
,
float
,
FNUZ
/*negative_zero_nan*/
,
true
/*clip*/
>
(
v
,
(
rm
==
migraphx
::
fp8
::
rounding_mode
::
stochastic
),
rng
);
#else // MIGRAPHX_F8_DOWNCAST_CLIPPING
data
=
migraphx
::
fp8
::
impl
::
cast_to_f8
<
3
,
4
,
float
,
FNUZ
/*negative_zero_nan*/
,
false
/*clip*/
>
(
v
,
(
rm
==
migraphx
::
fp8
::
rounding_mode
::
stochastic
),
rng
);
#endif // MIGRAPHX_F8_DOWNCAST_CLIPPING
}
else
{
#ifdef MIGRAPHX_F8_DOWNCAST_CLIPPING
data
=
migraphx
::
fp8
::
impl
::
cast_to_f8
<
2
,
5
,
float
,
FNUZ
/*negative_zero_nan*/
,
true
/*clip*/
>
(
v
,
(
rm
==
migraphx
::
fp8
::
rounding_mode
::
stochastic
),
rng
);
#else // MIGRAPHX_F8_DOWNCAST_CLIPPING
data
=
migraphx
::
fp8
::
impl
::
cast_to_f8
<
2
,
5
,
float
,
FNUZ
/*negative_zero_nan*/
,
false
/*clip*/
>
(
v
,
(
rm
==
migraphx
::
fp8
::
rounding_mode
::
stochastic
),
rng
);
#endif // MIGRAPHX_FP8_DOWNCAST_CLIPPING}
}
}
#endif // __gfx940___
// Constructor from half
explicit
constexpr
__device__
float8
(
const
_Float16
v
,
rounding_mode
rm
=
rounding_mode
::
standard
,
uint32_t
rng
=
0
)
:
float8
(
static_cast
<
float
>
(
v
),
rm
,
rng
)
{
}
// constructor from int
explicit
constexpr
__device__
float8
(
const
int
v
,
rounding_mode
rm
=
rounding_mode
::
standard
,
uint32_t
rng
=
0
)
:
float8
(
static_cast
<
float
>
(
v
),
rm
,
rng
)
{
}
// constructor from uint
explicit
constexpr
__device__
float8
(
const
uint32_t
v
,
rounding_mode
rm
=
rounding_mode
::
standard
,
uint32_t
rng
=
0
)
:
float8
(
static_cast
<
float
>
(
v
),
rm
,
rng
)
{
}
// constructor from double
explicit
constexpr
__device__
float8
(
const
double
v
,
rounding_mode
rm
=
rounding_mode
::
standard
,
uint32_t
rng
=
0
)
:
float8
(
static_cast
<
float
>
(
v
),
rm
,
rng
)
{
}
// constructor from bool
explicit
constexpr
__device__
float8
(
const
bool
v
,
rounding_mode
rm
=
rounding_mode
::
standard
,
uint32_t
rng
=
0
)
:
float8
(
static_cast
<
float
>
(
v
),
rm
,
rng
)
{
}
// convert to float
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) // NOLINT
// upcast using device specific intrinsic
inline
constexpr
__device__
operator
float
()
const
{
if
(
__builtin_is_constant_evaluated
())
{
if
constexpr
(
T
==
migraphx
::
fp8
::
f8_type
::
fp8
)
{
return
migraphx
::
fp8
::
impl
::
cast_from_f8
<
3
,
4
,
float
,
FNUZ
/*negative_zero_nan*/
>
(
data
);
}
// else
return
migraphx
::
fp8
::
impl
::
cast_from_f8
<
2
,
5
,
float
,
FNUZ
/*negative_zero_nan*/
>
(
data
);
}
else
{
float
fval
=
0
;
uint32_t
i32val
=
static_cast
<
uint32_t
>
(
data
);
// upcast
if
constexpr
(
T
==
migraphx
::
fp8
::
f8_type
::
fp8
)
{
__asm__
volatile
(
"v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0"
:
"=v"
(
fval
)
:
"v"
(
i32val
));
}
else
{
__asm__
volatile
(
"v_cvt_f32_bf8 %0, %1 src0_sel:BYTE_0"
:
"=v"
(
fval
)
:
"v"
(
i32val
));
}
return
fval
;
}
}
#else // non gfx940
inline
constexpr
__device__
operator
float
()
const
{
if
constexpr
(
T
==
migraphx
::
fp8
::
f8_type
::
fp8
)
{
return
migraphx
::
fp8
::
impl
::
cast_from_f8
<
3
,
4
,
float
,
FNUZ
/*negative_zero_nan*/
>
(
data
);
}
// else
return
migraphx
::
fp8
::
impl
::
cast_from_f8
<
2
,
5
,
float
,
FNUZ
/*negative_zero_nan*/
>
(
data
);
}
#endif
inline
constexpr
explicit
__device__
operator
bool
()
const
{
return
not
is_zero
();
}
// check for zero
inline
__device__
constexpr
bool
is_zero
()
const
{
if
constexpr
(
FNUZ
)
{
return
data
==
0x00
;
}
else
{
return
(
data
==
0x00
)
||
(
data
==
0x80
);
}
}
// check for nan
inline
__device__
constexpr
bool
is_nan
()
const
{
if
constexpr
(
FNUZ
)
{
return
data
==
0x80
;
}
else
{
if
(
T
==
migraphx
::
fp8
::
f8_type
::
bf8
)
{
return
(
data
==
0x7D
)
or
(
data
==
0x7E
)
or
(
data
==
0x7F
)
or
(
data
==
0xFD
)
or
(
data
==
0xFE
)
or
(
data
==
0xFF
);
}
else
{
return
(
data
==
0x7F
)
or
(
data
==
0xFF
);
}
}
}
// check for inf
inline
__device__
constexpr
bool
is_inf
()
const
{
if
constexpr
(
FNUZ
)
{
return
data
==
0x80
;
}
else
{
if
(
T
==
migraphx
::
fp8
::
f8_type
::
bf8
)
{
return
(
data
==
0x7C
)
or
(
data
==
0xFC
);
}
else
{
// no infinities in e4m3fn, represent them as NaNs
return
(
data
==
0x7F
)
or
(
data
==
0xFF
);
}
}
}
// NOLINTNEXTLINE
#define MIGRAPHX_FP8_SHORT_UNARY_OP(unary_op, binary_op) \
constexpr float8& __device__ operator unary_op(const float8& rhs) \
{ \
const auto tmp = static_cast<float>(*this) binary_op static_cast<float>(rhs); \
*this = static_cast<float8>(tmp); \
return *this; \
} \
constexpr float8& __device__ operator unary_op(const float& rhs) \
{ \
const auto tmp = static_cast<float>(*this) binary_op static_cast<float>(rhs); \
*this = static_cast<float8>(tmp); \
return *this; \
}
MIGRAPHX_FP8_SHORT_UNARY_OP
(
*=
,
*
)
MIGRAPHX_FP8_SHORT_UNARY_OP
(
-=
,
-
)
MIGRAPHX_FP8_SHORT_UNARY_OP
(
+=
,
+
)
MIGRAPHX_FP8_SHORT_UNARY_OP
(
/=
,
/
)
inline
__device__
constexpr
float8
&
operator
=
(
const
float8
&
rhs
)
=
default
;
inline
__device__
constexpr
float8
&
operator
=
(
float8
&&
rhs
)
noexcept
=
default
;
inline
__device__
constexpr
bool
operator
<
(
const
float8
&
rhs
)
const
{
const
auto
we
=
static_cast
<
float
>
(
*
this
);
const
auto
them
=
static_cast
<
float
>
(
rhs
);
return
we
<
them
;
}
inline
__device__
constexpr
bool
operator
>
(
const
float8
&
rhs
)
const
{
const
auto
we
=
static_cast
<
float
>
(
*
this
);
const
auto
them
=
static_cast
<
float
>
(
rhs
);
return
we
>
them
;
}
};
// https://onnx.ai/onnx/technical/float8.html
using
fp8e4m3fn
=
float8
<
migraphx
::
fp8
::
f8_type
::
fp8
,
false
>
;
using
fp8e5m2
=
float8
<
migraphx
::
fp8
::
f8_type
::
bf8
,
false
>
;
using
fp8e4m3fnuz
=
float8
<
migraphx
::
fp8
::
f8_type
::
fp8
,
true
>
;
using
fp8e5m2fnuz
=
float8
<
migraphx
::
fp8
::
f8_type
::
bf8
,
true
>
;
// NOLINTNEXTLINE
#define MIGRAPHX_FP8_BINARY_OP(binary_op, T, U) \
inline constexpr U __device__ operator binary_op(const T& lhs, const T& rhs) \
{ \
return U(static_cast<float>(lhs) binary_op static_cast<float>(rhs)); \
}
// NOLINTNEXTLINE
#define MIGRAPHX_FP8_OTHER_OPS(T) \
inline constexpr __device__ T fabs(T v) \
{ \
/*NOLINTNEXTLINE*/
\
v.data = v.data & 0x7f; \
return v; \
} \
inline __device__ constexpr bool operator==(const T& lhs, const T& rhs) \
{ \
if(rhs.is_nan() or rhs.is_inf() or lhs.is_nan() or lhs.is_inf()) \
return false; \
else if((rhs.is_zero() and lhs.is_zero()) or (lhs.data == rhs.data)) \
return true; \
return false; \
}
// NOLINTNEXTLINE
#define MIGRAPHX_FP8_GEN_OP_OVERLOADS(T) \
MIGRAPHX_FP8_BINARY_OP(*, T, T) \
MIGRAPHX_FP8_BINARY_OP(-, T, T) \
MIGRAPHX_FP8_BINARY_OP(/, T, T) \
MIGRAPHX_FP8_BINARY_OP(+, T, T) \
MIGRAPHX_FP8_BINARY_OP(>=, T, bool) \
MIGRAPHX_FP8_BINARY_OP(<=, T, bool) \
MIGRAPHX_FP8_BINARY_OP(!=, T, bool) \
MIGRAPHX_FP8_OTHER_OPS(T)
MIGRAPHX_FP8_GEN_OP_OVERLOADS
(
fp8e5m2
)
MIGRAPHX_FP8_GEN_OP_OVERLOADS
(
fp8e5m2fnuz
)
MIGRAPHX_FP8_GEN_OP_OVERLOADS
(
fp8e4m3fn
)
MIGRAPHX_FP8_GEN_OP_OVERLOADS
(
fp8e4m3fnuz
)
template
<
>
class
numeric_limits
<
fp8e4m3fnuz
>
{
public:
static
constexpr
bool
has_infinity
=
false
;
static
constexpr
__device__
fp8e4m3fnuz
epsilon
()
{
return
fp8e4m3fnuz
(
0x28
,
fp8e4m3fnuz
::
from_bits
());
}
// NOLINTNEXTLINE
static
constexpr
__device__
fp8e4m3fnuz
quiet_NaN
()
{
return
fp8e4m3fnuz
(
0x80
,
fp8e4m3fnuz
::
from_bits
());
}
static
constexpr
__device__
fp8e4m3fnuz
max
()
{
return
fp8e4m3fnuz
(
0x7F
,
fp8e4m3fnuz
::
from_bits
());
}
// this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01
static
constexpr
__device__
fp8e4m3fnuz
min
()
{
return
fp8e4m3fnuz
(
0x08
,
fp8e4m3fnuz
::
from_bits
());
}
static
constexpr
__device__
fp8e4m3fnuz
lowest
()
{
return
fp8e4m3fnuz
(
0xFF
,
fp8e4m3fnuz
::
from_bits
());
}
};
template
<
>
class
numeric_limits
<
fp8e4m3fn
>
{
public:
static
constexpr
bool
has_infinity
=
false
;
static
constexpr
__device__
fp8e4m3fn
epsilon
()
{
return
fp8e4m3fn
(
0x20
,
fp8e4m3fn
::
from_bits
());
}
// NOLINTNEXTLINE
static
constexpr
__device__
fp8e4m3fn
quiet_NaN
()
{
return
fp8e4m3fn
(
0x7F
,
fp8e4m3fn
::
from_bits
());
}
static
constexpr
__device__
fp8e4m3fn
max
()
{
return
fp8e4m3fn
(
0x7E
,
fp8e4m3fn
::
from_bits
());
}
// this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01
static
constexpr
__device__
fp8e4m3fn
min
()
{
return
fp8e4m3fn
(
0x08
,
fp8e4m3fn
::
from_bits
());
}
static
constexpr
__device__
fp8e4m3fn
lowest
()
{
return
fp8e4m3fn
(
0xFE
,
fp8e4m3fn
::
from_bits
());
}
};
template
<
>
class
numeric_limits
<
fp8e5m2fnuz
>
{
public:
static
constexpr
bool
has_infinity
=
false
;
static
constexpr
__device__
fp8e5m2fnuz
epsilon
()
{
return
fp8e5m2fnuz
(
0x34
,
fp8e5m2fnuz
::
from_bits
());
}
static
constexpr
__device__
fp8e5m2fnuz
quiet_NaN
()
// NOLINT
{
return
fp8e5m2fnuz
(
0x80
,
fp8e5m2fnuz
::
from_bits
());
}
static
constexpr
__device__
fp8e5m2fnuz
max
()
{
return
fp8e5m2fnuz
(
0x7F
,
fp8e5m2fnuz
::
from_bits
());
}
// this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01. I am not sure if we
// want to make this distinction. For the floating points we would end up using lowest most of
// the times.
static
constexpr
__device__
fp8e5m2fnuz
min
()
{
return
fp8e5m2fnuz
(
0x4
,
fp8e5m2fnuz
::
from_bits
());
}
static
constexpr
__device__
fp8e5m2fnuz
lowest
()
{
return
fp8e5m2fnuz
(
0xFF
,
fp8e5m2fnuz
::
from_bits
());
}
};
template
<
>
class
numeric_limits
<
fp8e5m2
>
{
public:
static
constexpr
bool
has_infinity
=
true
;
static
constexpr
__device__
fp8e5m2
epsilon
()
{
return
fp8e5m2
(
0x34
,
fp8e5m2
::
from_bits
());
}
// 7D, 7E, 7F are positive NaNs and FD, FE, FF are negative NaNs
static
constexpr
__device__
fp8e5m2
quiet_NaN
()
// NOLINT
{
return
fp8e5m2
(
0xFF
,
fp8e5m2
::
from_bits
());
}
static
constexpr
__device__
fp8e5m2
max
()
{
return
fp8e5m2
(
0x7B
,
fp8e5m2
::
from_bits
());
}
// this is min value that is not DeNormalized(DeNorm). DeNorm min is 0x01. I am not sure if we
// want to make this distinction. For the floating points we would end up using lowest most of
// the times.
static
constexpr
__device__
fp8e5m2
min
()
{
return
fp8e5m2
(
0x4
,
fp8e5m2
::
from_bits
());
}
static
constexpr
__device__
fp8e5m2
lowest
()
{
return
fp8e5m2
(
0xFB
,
fp8e5m2
::
from_bits
());
}
// 7C and FC both are infinity
static
constexpr
__device__
fp8e5m2
infinity
()
{
return
fp8e5m2
(
0x7C
,
fp8e5m2
::
from_bits
());
}
};
}
// namespace fp8
template
<
class
T
,
MIGRAPHX_REQUIRES
(
is_same
<
T
,
fp8
::
fp8e4m3fnuz
>{}
or
is_same
<
T
,
fp8
::
fp8e5m2fnuz
>
{}
or
is_same
<
T
,
fp8
::
fp8e4m3fn
>
{}
or
is_same
<
T
,
fp8
::
fp8e5m2
>
{})
>
constexpr
T
numeric_max
(
migraphx
::
fp8
::
f8_type
unused
=
migraphx
::
fp8
::
f8_type
::
fp8
)
{
// unused parameter is added to make this numeric_max different overload definition
// compared to numeric_max defined in type_traits.hpp
(
void
)(
unused
);
return
fp8
::
numeric_limits
<
T
>::
max
();
}
template
<
class
T
,
MIGRAPHX_REQUIRES
(
is_same
<
T
,
fp8
::
fp8e4m3fnuz
>{}
or
is_same
<
T
,
fp8
::
fp8e5m2fnuz
>
{}
or
is_same
<
T
,
fp8
::
fp8e4m3fn
>
{}
or
is_same
<
T
,
fp8
::
fp8e5m2
>
{})
>
constexpr
T
numeric_lowest
(
migraphx
::
fp8
::
f8_type
unused
=
migraphx
::
fp8
::
f8_type
::
fp8
)
{
// unused parameter is added to make this numeric_lowest different overload definition
// compared to numeric_lowest defined in type_traits.hpp
(
void
)(
unused
);
return
fp8
::
numeric_limits
<
T
>::
lowest
();
}
}
// namespace migraphx
// =================================================================================================
#if defined(__clang__)
#pragma clang diagnostic pop
#endif // __clang__
#endif // MIGRAPHX_GUARD_KERNELS_FLOAT8_HPP
src/targets/gpu/kernels/include/migraphx/kernels/float8_impl.hpp
0 → 100644
View file @
f8a75f8a
/* ************************************************************************
* Copyright (C) 2016-2023 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell cop-
* ies of the Software, and to permit persons to whom the Software is furnished
* to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IM-
* PLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
* FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
* COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
* IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNE-
* CTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*
* ************************************************************************ */
#ifndef MIGRAPHX_GUARD_KERNELS_FP8_IMPL_HPP
#define MIGRAPHX_GUARD_KERNELS_FP8_IMPL_HPP
#include <migraphx/kernels/bit_cast.hpp>
#include <migraphx/kernels/type_traits.hpp>
#if defined(__clang__)
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wreserved-identifier"
#endif
namespace
migraphx
{
namespace
fp8
{
namespace
impl
{
// NOLINTBEGIN
template
<
int
Wm
,
int
We
,
typename
T
,
bool
NegativeZeroNan
,
bool
Clip
>
__device__
constexpr
uint8_t
cast_to_f8
(
T
f_x
,
bool
stoch
=
false
,
uint32_t
rng
=
0
)
{
constexpr
bool
is_float
=
true
;
// half is not supported for now
constexpr
bool
is_half
=
false
;
static_assert
(
Wm
+
We
==
7
,
"Wm+We==7"
);
static_assert
(
is_float
or
is_half
,
"Only float can be cast to f8"
);
const
uint32_t
mfmt
=
(
sizeof
(
T
)
==
4
)
?
23
:
10
;
typename
migraphx
::
conditional_t
<
sizeof
(
T
)
==
2
,
uint16_t
,
uint32_t
>
x
;
if
constexpr
(
sizeof
(
T
)
==
4
)
x
=
migraphx
::
bit_cast
<
uint32_t
>
(
f_x
);
else
x
=
migraphx
::
bit_cast
<
uint16_t
>
(
f_x
);
uint32_t
head
=
0
;
uint32_t
mantissa
=
0
;
int
exponent
=
0
;
uint32_t
bias
=
0
;
uint32_t
sign
=
0
;
if
constexpr
(
sizeof
(
T
)
==
4
)
{
head
=
x
&
0xFF800000
;
mantissa
=
x
&
0x7FFFFF
;
exponent
=
(
head
>>
23
)
&
0xFF
;
sign
=
head
>>
31
;
bias
=
127
;
}
else
{
head
=
x
&
0xFC00
;
mantissa
=
x
&
0x3FF
;
exponent
=
(
head
>>
10
)
&
0x1F
;
sign
=
head
>>
15
;
bias
=
15
;
}
uint32_t
signed_inf
=
(
sign
<<
7
)
+
(((
1
<<
We
)
-
1
)
<<
Wm
);
uint32_t
signed_all_ones
=
(
sign
<<
7
)
+
((((
1
<<
We
)
-
1
)
<<
Wm
)
+
((
1
<<
Wm
)
-
1
));
// Calcualte maximum singed value FLT_MAX, FLT_MIN
uint32_t
signed_max
=
signed_all_ones
;
if
(
not
NegativeZeroNan
)
signed_max
=
(
Wm
==
2
)
?
(
signed_max
-
4
)
:
(
signed_max
-
1
);
// Deal with inf and NaNs
if
(
NegativeZeroNan
)
// For the FNUZ cases, it is simple just return NaNs
{
if
((
sizeof
(
T
)
==
4
and
((
x
&
0x7F800000
)
==
0x7F800000
))
or
(
sizeof
(
T
)
==
2
and
((
x
&
0x7C00
)
==
0x7C00
)))
return
0x80
;
}
else
{
// calculate most common NaN mantissa for FP8, which is all Ones in binary
uint32_t
nan_mantissa
=
1
;
for
(
auto
i
=
1
;
i
<
Wm
;
++
i
)
{
nan_mantissa
|=
(
nan_mantissa
<<
1
);
}
if
((
sizeof
(
T
)
==
4
and
((
x
&
0x7F800000
)
==
0x7F800000
))
or
(
sizeof
(
T
)
==
2
and
((
x
&
0x7C00
)
==
0x7C00
)))
{
// infinity
if
(
mantissa
==
0
)
{
if
(
sign
==
0
)
return
(
Wm
==
2
)
?
0x7B
:
0x7E
;
else
return
(
Wm
==
2
)
?
0xFB
:
0xFE
;
}
else
// NaNs
return
signed_inf
+
nan_mantissa
;
}
}
// handle positive zero
if
(
x
==
0
)
return
0
;
// handle negative zero
else
if
((
sizeof
(
T
)
==
4
and
x
==
0x80000000
)
or
(
sizeof
(
T
)
==
2
and
x
==
0x8000
))
{
return
NegativeZeroNan
?
0
:
0x80
;
// For FNUZ types neg zero is just positive zero
}
/* First need to check if it is normal or denorm as there is a difference of implict 1
Then need to adjust the exponent to align with the F8 exponent, in the meanwhile, shift
The mantissa. Then for stochastic rounding, add rng to mantissa and truncate. And for
RNE, no need to add rng. Then probably need to check whether there is carry and adjust
exponent and mantissa again*/
// For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent bits
const
int
f8_bias
=
(
1
<<
(
We
-
1u
))
-
1
+
(
NegativeZeroNan
?
1
:
0
);
const
int
f8_denormal_act_exponent
=
1
-
f8_bias
;
// actual exponent of f8 denormal
/* act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
f8_exponent is the converted f8 exponent with bias encoding
exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
the difference needs to be adjusted and mantissa shifted*/
int
act_exponent
=
0
;
int
f8_exponent
=
0
;
int
exponent_diff
=
0
;
if
(
exponent
==
0
and
mantissa
!=
0
)
{
// fp32/fp16 is in denormal.
/* fp32 denormal is below 2^-127 so it is usually not a concern here, we mostly concern fp16
here. In this case, f8 is usually in denormal. But there could be exceptions. fp16 denormal
has exponent bias 15 while bf8 with FNUZ has exponent bias 16. It means that there are some
numbers in fp16 denormal but they are bf8 (FNUZ) normals - smallest bf8 (FNUZ) normal is
2^-15. fp16 numbers where exponent==0 (actual exponent -14) and highest bit of mantissa is 1
are bf8 (FNUZ) normal. In this case, the fp16 mantissa should be shift left by 1 */
act_exponent
=
1
-
bias
;
exponent_diff
=
f8_denormal_act_exponent
-
act_exponent
;
// actual exponent is exponent-bias+1 as it is denormal
}
else
{
// fp32/fp16 is normal with implicit 1
act_exponent
=
exponent
-
bias
;
if
(
act_exponent
<=
f8_denormal_act_exponent
)
{
/* This is the case where fp32/fp16 is normal but it is in f8 denormal range.
For example fp8 FNUZ mode, denormal exponent is -7, but if the fp32/fp16
actual exponent is -7, it is actually larger due to the implict 1,
Therefore it needs to be adjust to -6 and mantissa shift right by 1.
So for fp32/fp16, exponent -8 is the cut point to convert to fp8 FNUZ */
exponent_diff
=
f8_denormal_act_exponent
-
act_exponent
;
}
else
{
// both fp32/fp16 and f8 are in normal range
exponent_diff
=
0
;
// exponent_diff=0 does not mean there is no difference for this case,
// act_exponent could be larger. Just that it does not need shift mantissa
}
mantissa
+=
(
1
<<
mfmt
);
// Add the implicit 1 into mantissa
}
// need to know whether the number is right in the middle of two adjacent fp8 numbers. use max
// value of 31 to avoid undefined behaviour
bool
midpoint
=
(
mantissa
&
((
1u
<<
(
mfmt
-
Wm
+
exponent_diff
))
-
1
))
==
(
1u
<<
(
mfmt
-
Wm
+
exponent_diff
-
1
));
/* This part is a bit tricky. The judgment of whether it is a tie needs to be done before we
shift right as shift right could rip off some residual part and make something not midpoint look
like midpoint. For example, the fp16 number 0x1002 (0 00100 0000000010), it is larger than
midpoint, but after shift right by 4 bits, it would look like midpoint.
*/
if
(
exponent_diff
>
0
)
mantissa
>>=
exponent_diff
;
else
if
(
exponent_diff
==
-
1
)
mantissa
<<=
-
exponent_diff
;
bool
implicit_one
=
mantissa
&
(
1
<<
mfmt
);
// if there is no implict 1, it means the f8 is denormal and need to adjust to denorm exponent
f8_exponent
=
(
act_exponent
+
exponent_diff
)
/*actual f8 exponent*/
+
f8_bias
-
(
implicit_one
?
0
:
1
);
// Now we have the exponent and mantissa adjusted
uint32_t
drop_mask
=
(
1
<<
(
mfmt
-
Wm
))
-
1
;
bool
odd
=
mantissa
&
(
1
<<
(
mfmt
-
Wm
));
// if the least significant bit that is not truncated is 1
/*
This part is doing rounding by adding mantissa part that is going to get dropped.
e.g. if the dropped part for less than 0.5 than it would round down.
if the dropped part is more than 0.5 then it would round up by rolling carry to LSB of retained
mantissa.
For the mid point when bit pattern is like this for Odd: `xy1:10000000` for Odd and
`xy0:10000000` for the Even. where `:` is delimiter for dropped v/s retained part.
For the odd case :
this will add xy1:10000000 + 000:10000000 which would roll over carry to LSB of retained
part making it RNE.
For the even case : this will add xy0:10000000 + 000:01111111 which would
round down and keep number Even
*/
mantissa
+=
(
stoch
?
rng
:
(
midpoint
?
(
odd
?
mantissa
:
mantissa
-
1
)
:
mantissa
))
&
drop_mask
;
// Now we deal with overflow
if
(
f8_exponent
==
0
and
((
1
<<
mfmt
)
&
mantissa
))
{
f8_exponent
=
1
;
// denormal overflow to become normal, promote exponent
}
else
if
((
1
<<
(
mfmt
+
1
))
&
mantissa
)
{
mantissa
>>=
1
;
f8_exponent
++
;
}
mantissa
>>=
(
mfmt
-
Wm
);
// above range: quantize to maximum possible float of the same sign
// for e5m2 case, max_exp is 14, since exp = 15 is reserved for Infs and Nans
const
int
max_exp
=
(
1
<<
We
)
-
((
NegativeZeroNan
or
Wm
==
3
)
?
1
:
2
);
if
(
f8_exponent
>
max_exp
)
{
if
(
Clip
)
return
signed_max
;
else
{
// https://onnx.ai/onnx/technical/float8.html#cast
if
(
NegativeZeroNan
)
return
0x80
;
else
return
(
Wm
==
2
)
?
signed_inf
:
signed_all_ones
;
}
}
if
(
f8_exponent
==
0
and
mantissa
==
0
)
return
NegativeZeroNan
?
0
:
(
sign
<<
7
);
mantissa
&=
(
1
<<
Wm
)
-
1
;
return
(
sign
<<
7
)
|
(
f8_exponent
<<
Wm
)
|
mantissa
;
}
// NOLINTEND
template
<
int
Wm
,
int
We
,
typename
T
,
bool
NegativeZeroNan
>
__device__
constexpr
T
cast_from_f8
(
uint8_t
x
)
{
// half is not supported for now
constexpr
bool
is_half
=
false
;
constexpr
bool
is_float
=
true
;
static_assert
(
is_float
or
is_half
,
"Only float are supported"
);
constexpr
int
weo
=
is_half
?
5
:
8
;
constexpr
int
wmo
=
is_half
?
10
:
(
is_float
?
23
:
7
);
// NOLINTNEXTLINE
T
f_inf
,
f_neg_inf
,
f_nan
,
f_neg0
;
if
constexpr
(
is_float
)
{
const
uint32_t
if_inf
=
0x7F800000
;
const
uint32_t
if_neg_inf
=
0xFF800000
;
const
uint32_t
if_nan
=
0x7F800001
;
const
uint32_t
if_neg0
=
0x80000000
;
f_inf
=
migraphx
::
bit_cast
<
float
>
(
if_inf
);
f_neg_inf
=
migraphx
::
bit_cast
<
float
>
(
if_neg_inf
);
f_nan
=
migraphx
::
bit_cast
<
float
>
(
if_nan
);
f_neg0
=
migraphx
::
bit_cast
<
float
>
(
if_neg0
);
}
if
(
x
==
0
)
return
0
;
uint32_t
sign
=
x
>>
7
;
// NOLINT
uint32_t
mantissa
=
x
&
((
1
<<
Wm
)
-
1
);
// NOLINT
int
exponent
=
(
x
&
0x7F
)
>>
Wm
;
// NOLINT
if
(
NegativeZeroNan
)
{
if
(
x
==
0x80
)
return
f_nan
;
}
else
{
if
(
x
==
0x80
)
return
f_neg0
;
if
(
exponent
==
((
1
<<
We
)
-
1
)
and
Wm
==
2
)
// NOLINT
return
(
mantissa
==
0
)
?
(
sign
?
f_neg_inf
:
f_inf
)
:
f_nan
;
else
if
(
Wm
==
3
and
(
x
==
0x7F
or
x
==
0xFF
))
return
f_nan
;
}
typename
migraphx
::
conditional_t
<
sizeof
(
T
)
==
2
,
uint16_t
,
uint32_t
>
retval
;
const
int
exp_low_cutoff
=
(
1
<<
(
weo
-
1
))
-
(
1
<<
(
We
-
1
))
+
1
-
(
NegativeZeroNan
?
1
:
0
);
// NOLINT
// subnormal input
if
(
exponent
==
0
)
{
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
int
sh
=
1
+
__builtin_clz
(
mantissa
)
-
(
32
-
Wm
);
mantissa
<<=
sh
;
// NOLINT
exponent
+=
1
-
sh
;
mantissa
&=
((
1
<<
Wm
)
-
1
);
// NOLINT
}
exponent
+=
exp_low_cutoff
-
1
;
mantissa
<<=
wmo
-
Wm
;
// NOLINT
// subnormal output (occurs when T=half, We=5, negative_zero_nan=true)
if
(
exponent
<=
0
)
{
mantissa
|=
1
<<
wmo
;
// NOLINT
mantissa
>>=
1
-
exponent
;
// NOLINT
exponent
=
0
;
}
if
(
sizeof
(
T
)
==
2
)
retval
=
(
sign
<<
15
)
|
(
exponent
<<
10
)
|
mantissa
;
// NOLINT
else
retval
=
(
sign
<<
31
)
|
(
exponent
<<
23
)
|
mantissa
;
// NOLINT
return
migraphx
::
bit_cast
<
T
>
(
retval
);
}
}
// namespace impl
}
// namespace fp8
}
// namespace migraphx
#if defined(__clang__)
#pragma clang diagnostic pop
#endif
#endif // MIGRAPHX_GUARD_KERNELS_FP8_IMPL_HPP
src/targets/gpu/kernels/include/migraphx/kernels/gathernd.hpp
View file @
f8a75f8a
...
...
@@ -53,35 +53,35 @@ __device__ void gathernd(const T& data_t, const U& indices_t, const V& output_t,
auto
indices_shape_lens
=
indices_shape
.
lens
;
auto
data_shape_lens
=
data_shape
.
lens
;
auto
num_slice_dims
=
indices_shape_lens
.
back
();
std
::
size_t
num_slices
=
size_t
num_slices
=
accumulate
(
indices_shape_lens
.
begin
(),
indices_shape_lens
.
end
()
-
1
,
1
,
op
::
product
{});
std
::
size_t
slice_size
=
accumulate
(
data_shape_lens
.
begin
()
+
num_slice_dims
+
batch_dims
,
data_shape_lens
.
end
(),
1
,
op
::
product
{});
const
std
::
size_t
num_batches
=
size_t
slice_size
=
accumulate
(
data_shape_lens
.
begin
()
+
num_slice_dims
+
batch_dims
,
data_shape_lens
.
end
(),
1
,
op
::
product
{});
const
size_t
num_batches
=
accumulate
(
data_shape_lens
.
begin
(),
data_shape_lens
.
begin
()
+
batch_dims
,
1
,
op
::
product
{});
const
std
::
size_t
data_batch_stride
=
const
size_t
data_batch_stride
=
accumulate
(
data_shape_lens
.
begin
()
+
batch_dims
,
data_shape_lens
.
end
(),
1
,
op
::
product
{});
const
auto
num_slices_per_batch
=
num_slices
/
num_batches
;
ind
.
global_stride
(
output_shape
.
elements
(),
[
&
](
auto
i
)
{
const
auto
*
indices_ptr
=
indices_t
.
data
();
const
std
::
size_t
j
=
i
/
slice_size
;
const
std
::
size_t
batch_idx
=
j
/
num_slices_per_batch
;
const
size_t
j
=
i
/
slice_size
;
const
size_t
batch_idx
=
j
/
num_slices_per_batch
;
auto
*
slice_indices
=
indices_ptr
+
(
j
*
num_slice_dims
);
std
::
size_t
relative_slice_offset
=
0
;
for
(
std
::
size_t
idx
=
0
;
idx
<
num_slice_dims
;
++
idx
)
size_t
relative_slice_offset
=
0
;
for
(
size_t
idx
=
0
;
idx
<
num_slice_dims
;
++
idx
)
{
int64_t
index
=
slice_indices
[
idx
];
const
std
::
size_t
input_dim_idx
=
batch_dims
+
idx
;
const
size_t
input_dim_idx
=
batch_dims
+
idx
;
const
auto
input_dim
=
data_shape_lens
[
input_dim_idx
];
MIGRAPHX_ASSERT
(
index
>=
-
static_cast
<
int64_t
>
(
input_dim
)
and
index
<
static_cast
<
int64_t
>
(
input_dim
));
if
(
index
<
0
)
index
+=
input_dim
;
std
::
size_t
size_from_slice_dims
=
size_t
size_from_slice_dims
=
accumulate
(
data_shape_lens
.
begin
()
+
batch_dims
+
idx
+
1
,
data_shape_lens
.
begin
()
+
batch_dims
+
num_slice_dims
,
slice_size
,
...
...
src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp
View file @
f8a75f8a
...
...
@@ -52,22 +52,25 @@ __device__ void generic_binary_layernorm(
block
::
template
run
<
reduce_output
>([
&
](
auto
,
auto
r
)
{
auto
input
=
r
.
inner
([
&
](
auto
x1
,
auto
x2
)
{
return
op
(
x1
,
x2
);
})(
input1
,
input2
);
using
value_type
=
typename
Input1
::
type
;
using
vec_value_type
=
vec_type
<
value_type
>
;
constexpr
auto
relements
=
r
.
template
elements
<
Input1
>();
constexpr
auto
relements_r
=
vec_
type
<
value_type
>
{
1.0
/
relements
};
constexpr
auto
relements_r
=
vec_value_type
{
1.0
/
relements
};
auto
relements_rsqrt
=
sqrt
(
relements_r
);
auto
means
=
r
.
reduce
(
op
::
sum
{},
make_array
<
vec_type
<
value_type
>>
(
0
,
0
),
[
&
](
auto
x
)
{
auto
x_out
=
x
*
relements_r
;
// dividing x by sqrt(relements) before squaring allows computing higher values
// before overflow in low precision
auto
x2_sqrt
=
x
*
relements_rsqrt
;
return
make_array
(
x_out
,
x2_sqrt
*
x2_sqrt
);
})(
input
);
auto
means
=
r
.
reduce
(
op
::
sum
{},
make_array
<
vec_value_type
>
(
vec_value_type
{
0
},
vec_value_type
{
0
}),
[
&
](
auto
x
)
{
auto
x_out
=
x
*
relements_r
;
// dividing x by sqrt(relements) before squaring allows computing
// higher values before overflow in low precision
auto
x2_sqrt
=
x
*
relements_rsqrt
;
return
make_array
(
x_out
,
x2_sqrt
*
x2_sqrt
);
})(
input
);
auto
mean_x
=
means
[
0
];
auto
mean_x2
=
means
[
1
];
auto
variance
=
mean_x2
-
(
mean_x
*
mean_x
);
value_type
eps_val
=
eps
;
//
implicit
conversion
for
eps
value_type
eps_val
=
implicit
_
conversion
(
eps
);
r
.
inner
([
&
](
auto
&
y
,
auto
x
,
auto
...
xs
)
{
auto
m
=
x
-
mean_x
;
...
...
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
View file @
f8a75f8a
...
...
@@ -29,11 +29,15 @@
#include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/type_traits.hpp>
#include <migraphx/kernels/hip.hpp>
#include <migraphx/kernels/float8.hpp>
namespace
migraphx
{
namespace
math
{
constexpr
float
as_float
(
migraphx
::
half
x
)
{
return
x
;
}
constexpr
float
as_float
(
migraphx
::
fp8
::
fp8e4m3fnuz
x
)
{
return
x
;
}
template
<
class
T
>
constexpr
T
as_float
(
T
x
)
{
...
...
@@ -57,14 +61,14 @@ constexpr T as_float(T x)
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_FOR(type, name, fname) \
template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())> \
auto __device__ name(type x, Ts... xs)->type
\
auto __device__ name(type x, Ts... xs)
->
type \
{ \
return fname(x, xs...); \
}
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_BINARY_FOR(type, name, fname) \
inline auto __device__ name(type x, type y)->type { return fname(x, y); }
inline auto __device__ name(type x, type y)
->
type { return fname(x, y); }
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_HALF(name, fname) \
...
...
@@ -72,6 +76,12 @@ constexpr T as_float(T x)
auto __device__ name(migraphx::half x, Ts... xs) \
MIGRAPHX_RETURNS(fname(math::as_float(x), math::as_float(xs)...))
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_FP8(name, fname) \
template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())> \
auto __device__ name(migraphx::fp8::fp8e4m3fnuz x, Ts... xs) MIGRAPHX_RETURNS( \
migraphx::fp8::fp8e4m3fnuz(fname(math::as_float(x), math::as_float(xs)...)))
// Template with two overloads for math functions, one for half2 type and one for more generic
// <half, N> vectorization where N is 4 or another even number.
...
...
@@ -162,6 +172,33 @@ MIGRAPHX_DEVICE_MATH_HALF(tan, ::tan)
MIGRAPHX_DEVICE_MATH_HALF
(
tanh
,
::
tanh
)
MIGRAPHX_DEVICE_MATH_HALF
(
fmod
,
::
fmod
)
// use float to compute fp8 overload
MIGRAPHX_DEVICE_MATH_FP8
(
abs
,
::
abs
)
MIGRAPHX_DEVICE_MATH_FP8
(
acos
,
::
acos
)
MIGRAPHX_DEVICE_MATH_FP8
(
acosh
,
::
acosh
)
MIGRAPHX_DEVICE_MATH_FP8
(
asin
,
::
asin
)
MIGRAPHX_DEVICE_MATH_FP8
(
asinh
,
::
asinh
)
MIGRAPHX_DEVICE_MATH_FP8
(
atan
,
::
atan
)
MIGRAPHX_DEVICE_MATH_FP8
(
atanh
,
::
atanh
)
MIGRAPHX_DEVICE_MATH_FP8
(
ceil
,
::
ceil
)
MIGRAPHX_DEVICE_MATH_FP8
(
cos
,
::
cos
)
MIGRAPHX_DEVICE_MATH_FP8
(
cosh
,
::
cosh
)
MIGRAPHX_DEVICE_MATH_FP8
(
erf
,
::
erf
)
MIGRAPHX_DEVICE_MATH_FP8
(
exp
,
::
exp
)
MIGRAPHX_DEVICE_MATH_FP8
(
floor
,
::
floor
)
MIGRAPHX_DEVICE_MATH_FP8
(
isnan
,
::
isnan
)
MIGRAPHX_DEVICE_MATH_FP8
(
log
,
::
log
)
MIGRAPHX_DEVICE_MATH_FP8
(
pow
,
::
pow
)
MIGRAPHX_DEVICE_MATH_FP8
(
remainder
,
::
remainder
)
MIGRAPHX_DEVICE_MATH_FP8
(
round
,
::
round
)
MIGRAPHX_DEVICE_MATH_FP8
(
rsqrt
,
::
rsqrt
)
MIGRAPHX_DEVICE_MATH_FP8
(
sin
,
::
sin
)
MIGRAPHX_DEVICE_MATH_FP8
(
sinh
,
::
sinh
)
MIGRAPHX_DEVICE_MATH_FP8
(
sqrt
,
::
sqrt
)
MIGRAPHX_DEVICE_MATH_FP8
(
tan
,
::
tan
)
MIGRAPHX_DEVICE_MATH_FP8
(
tanh
,
::
tanh
)
MIGRAPHX_DEVICE_MATH_FP8
(
fmod
,
::
fmod
)
// Map math functions to hip half2 functions
// The half2 type is defined in include/hip/amd_detail/hip_fp16_gcc.h and is 2 16-bit floats
// packed into a 32-bit number. See include/hip/amd_detail/hip_fp16_math_fwd.h for the HIP names
...
...
@@ -253,7 +290,7 @@ MIGRAPHX_DEVICE_MATH_VEC(where)
template
<
class
T
,
class
U
>
constexpr
auto
convert
(
U
v
)
{
return
vec_transform
(
v
)([](
auto
x
)
->
T
{
return
x
;
});
return
vec_transform
(
v
)([](
auto
x
)
->
T
{
return
static_cast
<
T
>
(
x
)
;
});
}
}
// namespace migraphx
...
...
Prev
1
2
3
4
5
6
7
8
9
…
13
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment