gaoqiong / MIGraphX · Commit 70d9faf7 (unverified)
Authored Dec 13, 2023 by Chris Austen; committed by GitHub on Dec 13, 2023

    Merge branch 'develop' into mi200

Parents: a56c531c, a60bdb67
Changes: 442 files changed in total; showing 20 changed files with 525 additions and 543 deletions (+525, -543).
src/targets/gpu/lowering.cpp        +7    -13
src/targets/gpu/mlir.cpp            +30   -48
src/targets/gpu/pack_int8_args.cpp  +0    -225
src/targets/gpu/pad.cpp             +0    -46
src/targets/gpu/prefuse_ops.cpp     +35   -10
src/targets/gpu/rocblas.cpp         +4    -8
src/targets/gpu/target.cpp          +27   -2
src/targets/ref/CMakeLists.txt      +1    -5
src/targets/ref/gemm.cpp            +0    -157
src/targets/ref/lowering.cpp        +6    -17
src/tf/CMakeLists.txt               +9    -2
src/tmp_dir.cpp                     +11   -1
src/verify_args.cpp                 +0    -1
src/version.h.in                    +1    -1
test/CMakeLists.txt                 +6    -2
test/api/CMakeLists.txt             +6    -4
test/api/test_cpu.cpp               +25   -0
test/api/test_gpu.cpp               +55   -0
test/float_equal.cpp                +11   -1
test/fp8e4m3fn.cpp                  +291  -0
src/targets/gpu/lowering.cpp (+7, -13)

@@ -61,9 +61,8 @@ struct miopen_apply
     const lowering* pass = nullptr;
     std::unordered_map<std::string, std::function<instruction_ref(instruction_ref)>> apply_map{};
     instruction_ref last{};
     bool offload_copy   = false;
-    bool int8_x4_format = true;
     bool compute_fp32   = false;

     context& get_context() const
     {
@@ -84,10 +83,8 @@ struct miopen_apply
         assert(mod != nullptr);
         assert(pass != nullptr);
-        auto& ctx      = get_context();
         compute_fp32   = get_compute_fp32_flag();
-        int8_x4_format = get_int8_x4_format(ctx);
         offload_copy   = (mod == mpm->get_root_module()) ? pass->offload_copy : false;

         add_generic_op("contiguous");
         add_extend_op("argmax");
@@ -231,18 +228,15 @@ struct miopen_apply
             assert(refs.size() == 2);
             auto output = insert_allocation(ins, ins->get_shape());
             refs.push_back(output);
             return mod->replace_instruction(
-                ins, rocblas_gemm<Op>{Op{}, 1, 0, int8_x4_format, compute_fp32}, refs);
+                ins, rocblas_gemm<Op>{Op{}, 1, 0, compute_fp32}, refs);
         });
     }

     void add_convolution_op(const std::string& name)
     {
         apply_map.emplace(name, [=](instruction_ref ins) {
-            operation conv = make_op("gpu::" + name,
-                                     {{"op", ins->get_operator().to_value()},
-                                      {"int8_x4_format", int8_x4_format}});
+            operation conv = make_op(
+                "gpu::" + name, {{"op", ins->get_operator().to_value()}});
             auto output = insert_allocation(ins, ins->get_shape());

             return mod->replace_instruction(ins,
                                             make_op("gpu::miopen_op", {{"op", to_value(conv)}}),
src/targets/gpu/mlir.cpp (+30, -48)

@@ -37,7 +37,7 @@
 #include <mlir-c/Pass.h>
 #include <mlir-c/Support.h>
 #include <mutex>
-#if !defined(MLIR_MIGRAPHX_DIALECT_API_VERSION) || MLIR_MIGRAPHX_DIALECT_API_VERSION != 3
+#if !defined(MLIR_MIGRAPHX_DIALECT_API_VERSION) || MLIR_MIGRAPHX_DIALECT_API_VERSION != 4
 #warning "Incompatible version of rocMLIR library used, disabling"
 // Only undefine when not using cppcheck
 #ifndef CPPCHECK
@@ -73,6 +73,7 @@ namespace gpu {
 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_MLIR);
 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNE_EXHAUSTIVE);
+MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNE_LIMIT);
 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNING_DB);
 MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_MLIR_TUNING_CFG);
@@ -299,6 +300,8 @@ struct mlir_program
             result = mlirF32TypeGet(ctx.get());
         else if(as.type_enum() == shape::half_type)
             result = mlirF16TypeGet(ctx.get());
+        else if(as.type_enum() == shape::fp8e4m3fnuz_type)
+            result = mlirFloat8E4M3FNUZTypeGet(ctx.get());
         else if(as.type_enum() == shape::double_type)
             result = mlirF64TypeGet(ctx.get());
         else if(as.is_integral())
@@ -318,31 +321,30 @@ struct mlir_program
         return result;
     }

-    MlirType make_tensor(const shape& s) const
+    MlirType make_mlir_shaped(const shape& s) const
     {
-        if(not s.standard())
-            MIGRAPHX_THROW("MLIR expects all tensors to be in standard shape");
         if(s.dynamic())
             MIGRAPHX_THROW("MLIR does not support dynamic shapes");
         std::vector<int64_t> lens(s.lens().begin(), s.lens().end());
-        return mlirRankedTensorTypeGet(
-            lens.size(), lens.data(), make_type(s.type()), mlirAttributeGetNull());
+        std::vector<int64_t> strides(s.strides().begin(), s.strides().end());
+        return rocmlirMIXRShapedTypeGet(
+            lens.size(), lens.data(), strides.data(), make_type(s.type()));
     }

     template <class Range>
-    std::vector<MlirType> make_tensors(const Range& r)
+    std::vector<MlirType> make_mlir_shapeds(const Range& r)
     {
         std::vector<MlirType> result;
         std::transform(r.begin(), r.end(), std::back_inserter(result), [&](const auto& s) {
-            return make_tensor(s);
+            return make_mlir_shaped(s);
         });
         return result;
     }

     MlirType make_function_type(const std::vector<shape>& inputs, const std::vector<shape>& outputs)
     {
-        auto in  = make_tensors(inputs);
-        auto out = make_tensors(outputs);
+        auto in  = make_mlir_shapeds(inputs);
+        auto out = make_mlir_shapeds(outputs);
         return mlirFunctionTypeGet(ctx.get(), in.size(), in.data(), out.size(), out.data());
     }
@@ -504,11 +506,7 @@ struct mlir_program
     mlir_operation_state& add_results(const std::vector<shape>& outputs)
     {
-        std::vector<shape> reshaped(outputs.size());
-        std::transform(outputs.begin(), outputs.end(), reshaped.begin(), [](const shape& r) {
-            return shape{r.type(), r.lens()};
-        });
-        auto x = prog->make_tensors(reshaped);
+        auto x = prog->make_mlir_shapeds(outputs);
         if(not x.empty())
         {
             mlirOperationStateAddResults(&op_state, x.size(), x.data());
@@ -581,7 +579,7 @@ struct mlir_program
         std::vector<shape> outputs = m.get_output_shapes();
         std::vector<MlirLocation> arg_locs(inputs.size(), location);
-        auto body_inputs   = make_tensors(inputs);
+        auto body_inputs   = make_mlir_shapeds(inputs);
         mlir_region region = mlirRegionCreate();
         mlir_block fbody = mlirBlockCreate(body_inputs.size(), body_inputs.data(), arg_locs.data());
         MlirBlock result = fbody.get();
@@ -607,7 +605,7 @@ struct mlir_program
             return "func.return";
         if(ins->name() == "@literal")
         {
-            return "tosa.const";
+            return "migraphx.literal";
         }
         return "migraphx." + ins->name();
     }
@@ -666,7 +664,8 @@ struct mlir_program
             if(ins->name() == "@literal")
             {
                 literal r = ins->get_literal();
-                MlirType tensor_type = make_tensor(ins->get_shape());
+                MlirType shaped_type = make_mlir_shaped(ins->get_shape());
+                MlirType tensor_type = rocmlirMIXRShapedTypeAsTensor(shaped_type);
                 MlirAttribute mlir_value_attr =
                     mlirDenseElementsAttrRawBufferGet(tensor_type, r.get_shape().bytes(), r.data());
                 ops.add_attributes({{"value", mlir_value_attr}});
@@ -796,7 +795,9 @@ struct mlir_program
         if(enabled(MIGRAPHX_MLIR_TUNE_EXHAUSTIVE{}))
             tuning_mode = RocmlirTuningParamSetKindExhaustive;
         mlir_tuning_space params{mlirRockTuningSpaceCreate(mmodule.get(), tuning_mode)};
-        for(auto i : range(mlirRockTuningGetNumParams(params.get())))
+        const auto limit =
+            value_of(MIGRAPHX_MLIR_TUNE_LIMIT{}, std::numeric_limits<std::size_t>::max());
+        for(auto i : range(std::min<std::size_t>(limit, mlirRockTuningGetNumParams(params.get()))))
         {
             mlir_tuning_param param{mlirRockTuningParamCreate()};
             if(not mlirRockTuningParamGet(params.get(), i, param.get()))
@@ -942,35 +943,7 @@ void adjust_param_shapes(module& m, const std::vector<shape>& inputs)
         auto param = m.get_parameter(name);
         if(input.standard())
             continue;
-        auto lens    = input.lens();
-        auto strides = input.strides();
-        std::vector<operation> ops;
-        if(input.transposed())
-        {
-            auto perm  = find_permutation(input);
-            auto iperm = invert_permutation(perm);
-            lens       = reorder_dims(lens, iperm);
-            strides    = reorder_dims(strides, iperm);
-            ops.push_back(make_op("transpose", {{"permutation", perm}}));
-        }
-        if(input.broadcasted())
-        {
-            std::transform(lens.begin(),
-                           lens.end(),
-                           strides.begin(),
-                           lens.begin(),
-                           [](auto len, auto stride) -> std::size_t {
-                               if(stride == 0)
-                                   return 1;
-                               return len;
-                           });
-            ops.push_back(make_op("multibroadcast", {{"out_lens", input.lens()}}));
-        }
-        auto new_param = std::accumulate(
-            ops.begin(),
-            ops.end(),
-            m.add_parameter(name + ".0", shape{input.type(), lens}),
-            [&](auto x, auto op) { return m.insert_instruction(param, op, x); });
+        auto new_param = m.add_parameter(name + ".0", input);
         m.replace_instruction(param, new_param);
         m.remove_instruction(param);
     }
@@ -1032,6 +1005,15 @@ tuning_config get_tuning_config_mlir(const context& migraphx_ctx,
     mlir_program mp;
     mp.set_gpu_properties(migraphx_ctx);
     mp.parse(m);
+    const bool trace = enabled(MIGRAPHX_TRACE_MLIR{});
+    static std::mutex mutex;
+    if(trace)
+    {
+        const std::lock_guard<std::mutex> lock(mutex);
+        auto mod_op = mlirModuleGetOperation(mp.mmodule.get());
+        std::cout << mlir_print(&mlirOperationPrint, mod_op) << std::endl;
+    }
     return mp.get_tuning_config(exhaustive);
 }
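A note on the new MIGRAPHX_MLIR_TUNE_LIMIT knob, as I read the hunk above: the limit defaults to std::numeric_limits<std::size_t>::max(), so unless the environment variable is set the loop still visits every candidate returned by mlirRockTuningGetNumParams; setting it simply truncates the tuning search. A minimal sketch of the same clamping pattern (the function name and parameters below are illustrative stand-ins, not MIGraphX API):

#include <algorithm>
#include <cstddef>
#include <limits>

// Hypothetical stand-in for the loop bound built from
// value_of(MIGRAPHX_MLIR_TUNE_LIMIT{}, max) and mlirRockTuningGetNumParams(...).
std::size_t tuning_candidates(std::size_t env_limit_or_max, std::size_t num_params)
{
    // Env var unset: env_limit_or_max is SIZE_MAX, min() keeps num_params (old behaviour).
    // Env var set: the number of tuning configurations tried is capped.
    return std::min<std::size_t>(env_limit_or_max, num_params);
}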
src/targets/gpu/pack_int8_args.cpp (deleted, 100644 → 0; -225 — the whole file below was removed)

/*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
#include <iterator>
#include <migraphx/gpu/pack_int8_args.hpp>
#include <migraphx/gpu/int8_gemm_pack.hpp>
#include <migraphx/gpu/int8_conv_pack.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/program.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/permutation.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

static instruction_ref pad_ins(module& m, instruction_ref ins, int offset)
{
    auto s        = ins->get_shape();
    auto lens     = s.lens();
    auto k        = lens[lens.size() + offset];
    auto pad_k    = (k + 3) / 4 * 4;
    auto pad_lens = lens;
    pad_lens[lens.size() + offset] = pad_k;
    auto ret_ins = ins;
    if(pad_k != k)
    {
        std::vector<int64_t> pad_dims(lens.size() * 2, 0);
        pad_dims[lens.size() + offset] = pad_k - k;
        shape ps{s.type(), pad_lens};
        auto ins_out =
            m.insert_instruction(ins, make_op("hip::allocate", {{"shape", to_value(ps)}}));
        auto pad = make_op("pad", {{"pads", pad_dims}});
        ret_ins =
            m.insert_instruction(std::next(ins), make_op("gpu::pad", pad.to_value()), ins, ins_out);
    }
    return ret_ins;
}

static std::vector<instruction_ref> pad_inputs(module& m, instruction_ref ins)
{
    std::vector<instruction_ref> ret_inputs;
    auto inputs = ins->inputs();
    auto in0    = inputs.at(0);
    auto sa     = in0->get_shape();
    bool transa = sa.transposed();
    if(transa)
    {
        auto perm = find_permutation(sa);
        auto val  = in0->get_operator().to_value();
        if(val.contains("dims"))
        {
            int offset = static_cast<int>(perm.back()) - static_cast<int>(perm.size());
            auto t_in  = in0->inputs().front();
            auto p_in  = pad_ins(m, t_in, offset);
            auto dims  = val.at("dims").to_vector<int64_t>();
            auto r_in =
                m.insert_instruction(ins, make_op("transpose", {{"permutation", dims}}), p_in);
            ret_inputs.push_back(r_in);
        }
        else
        {
            shape cs{in0->get_shape().type(), in0->get_shape().lens()};
            auto con_out =
                m.insert_instruction(ins, make_op("hip::allocate", {{"shape", to_value(cs)}}));
            auto cin0 = m.insert_instruction(ins, make_op("gpu::contiguous"), in0, con_out);
            ret_inputs.push_back(pad_ins(m, cin0, -1));
        }
    }
    else
    {
        ret_inputs.push_back(pad_ins(m, in0, -1));
    }

    auto in1    = inputs.at(1);
    auto sb     = in1->get_shape();
    bool transb = sb.transposed();
    if(transb)
    {
        auto perm = find_permutation(sb);
        auto val  = in1->get_operator().to_value();
        if(val.contains("dims"))
        {
            int offset =
                static_cast<int>(perm[perm.size() - 2]) - static_cast<int>(perm.size());
            auto t_in = in1->inputs().front();
            auto p_in = pad_ins(m, t_in, offset);
            auto dims = val.at("dims").to_vector<int64_t>();
            auto r_in =
                m.insert_instruction(ins, make_op("transpose", {{"permutation", dims}}), p_in);
            ret_inputs.push_back(r_in);
        }
        else
        {
            shape cs{in1->get_shape().type(), in1->get_shape().lens()};
            auto con_out =
                m.insert_instruction(ins, make_op("hip::allocate", {{"shape", to_value(cs)}}));
            auto cin1 = m.insert_instruction(ins, make_op("gpu::contiguous"), in1, con_out);
            ret_inputs.push_back(pad_ins(m, cin1, -2));
        }
    }
    else
    {
        ret_inputs.push_back(pad_ins(m, in1, -2));
    }
    std::copy(inputs.begin() + 2, inputs.end(), std::back_inserter(ret_inputs));
    return ret_inputs;
}

void pack_int8_args::apply(module& m) const
{
    for(auto ins : iterator_for(m))
    {
        if(ins->name() == "gpu::quant_gemm")
        {
            auto val = ins->get_operator().to_value();
            assert(val.contains("int8_x4_format"));
            if(not val.at("int8_x4_format").to<bool>())
            {
                continue;
            }
            auto inputs = ins->inputs();
            auto lens   = inputs.at(0)->get_shape().lens();
            // gemm need the k to be multiple of 4, so need packing that dimension
            auto old_inputs = inputs;
            if((lens.back() % 4) != 0)
            {
                inputs = pad_inputs(m, ins);
            }
            bool transa = inputs[0]->get_shape().transposed();
            bool transb = inputs[1]->get_shape().transposed();
            if(not transb)
            {
                auto packed_b = m.insert_instruction(
                    ins, make_op("hip::allocate", {{"shape", to_value(inputs[1]->get_shape())}}));
                auto output_b = m.insert_instruction(
                    ins, make_op("gpu::int8_gemm_pack_a"), {inputs[1], packed_b});
                inputs[1] = output_b;
            }
            if(transa)
            {
                auto packed_a = m.insert_instruction(
                    ins, make_op("hip::allocate", {{"shape", to_value(inputs[0]->get_shape())}}));
                auto output_a = m.insert_instruction(
                    ins, make_op("gpu::int8_gemm_pack_b"), {inputs[0], packed_a});
                inputs[0] = output_a;
            }
            if(inputs != old_inputs)
            {
                m.replace_instruction(ins, ins->get_operator(), inputs);
            }
        }
        else if(ins->name() == "gpu::quant_convolution")
        {
            auto val = ins->get_operator().to_value();
            if(not val.at("int8_x4_format").to<bool>())
            {
                continue;
            }
            auto inputs   = ins->inputs();
            auto packed_x = m.insert_instruction(
                ins,
                make_op("hip::allocate",
                        {{"shape", to_value(pack_int8_shape(inputs[0]->get_shape()))}}));
            auto output_x =
                m.insert_instruction(ins, make_op("gpu::int8_conv_pack"), {inputs[0], packed_x});
            instruction::replace_argument(ins, inputs[0], output_x);

            auto packed_w = m.insert_instruction(
                ins,
                make_op("hip::allocate",
                        {{"shape", to_value(pack_int8_shape(inputs[1]->get_shape()))}}));
            auto output_w =
                m.insert_instruction(ins, make_op("gpu::int8_conv_pack"), {inputs[1], packed_w});
            instruction::replace_argument(ins, inputs[1], output_w);
        }
    }
}

shape pack_int8_args::pack_int8_shape(const shape& s) const
{
    if(s.type() != shape::int8_type)
    {
        MIGRAPHX_THROW("PACK_INT8_ARGS: only process int8_type");
    }
    auto lens    = s.lens();
    auto strides = s.strides();
    lens[1]      = (lens[1] + 3) / 4 * 4;
    strides[0]   = strides[1] * lens[1];

    return {s.type(), lens, strides};
}

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
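The comment in apply() above ("gemm need the k to be multiple of 4, so need packing that dimension") is the whole reason this now-removed pass existed: the rocBLAS int8x4 packed layout requires the reduction dimension k to be a multiple of 4, so pad_ins rounded it up with (k + 3) / 4 * 4 and zero-padded the tensor when needed. A minimal sketch of that round-up with example values (the helper name is mine, not from the file):

#include <cassert>
#include <cstddef>

// Round k up to the next multiple of 4, as pad_ins did for the int8x4 layout.
std::size_t pad_to_multiple_of_4(std::size_t k) { return (k + 3) / 4 * 4; }

int main()
{
    assert(pad_to_multiple_of_4(8) == 8);   // already aligned: no padding inserted
    assert(pad_to_multiple_of_4(10) == 12); // two zero columns of padding would be added
    assert(pad_to_multiple_of_4(1) == 4);
}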
src/targets/gpu/pad.cpp (deleted, 100644 → 0; -46 — the whole file below was removed)

/*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
 * (standard MIT license text, as in the other files in this commit)
 */
#include <migraphx/gpu/pad.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device/pad.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

shape hip_pad::compute_shape(std::vector<shape> inputs) const
{
    inputs.pop_back();
    check_shapes{inputs, *this}.has(1).standard();
    return op.compute_shape(inputs);
}

argument hip_pad::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
    return device::pad(ctx.get_stream().get(), args.back(), args.front(), op.value, op.pads);
}

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
src/targets/gpu/prefuse_ops.cpp (+35, -10)

@@ -28,7 +28,10 @@
 #include <migraphx/register_op.hpp>
 #include <migraphx/pass_manager.hpp>
 #include <migraphx/dead_code_elimination.hpp>
+#ifdef MIGRAPHX_USE_COMPOSABLEKERNEL
 #include <migraphx/gpu/ck.hpp>
+#endif
+#include <migraphx/gpu/fuse_mlir.hpp>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
@@ -128,26 +131,49 @@ struct pre_gemm_softmax_gemm : gemm_softmax_gemm
 };
 MIGRAPHX_REGISTER_OP(pre_gemm_softmax_gemm);

-MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
+auto is_ck_gemm()
 {
-    if(ins->name() != "dot")
-        return false;
-    if(not pre_gemm_softmax_gemm::is_ck_supported_type(ins->get_shape().type()))
-        return false;
-    return true;
+    return match::make_basic_pred_matcher([=](instruction_ref ins) {
+#ifdef MIGRAPHX_USE_COMPOSABLEKERNEL
+        if(not enabled(MIGRAPHX_ENABLE_CK{}))
+            return false;
+        if(ins->name() != "dot")
+            return false;
+        if(not pre_gemm_softmax_gemm::is_ck_supported_type(ins->get_shape().type()))
+            return false;
+        return true;
+#else
+        (void)ins;
+        return false;
+#endif
+    });
+}
+
+auto is_mlir_gemm()
+{
+    return match::make_basic_pred_matcher([=](instruction_ref ins) {
+        if(not mlir_attention_enabled())
+            return false;
+        if(ins->name() != "dot")
+            return false;
+        return std::all_of(ins->inputs().begin(), ins->inputs().end(), [&](auto i) {
+            return pre_gemm_softmax_gemm::is_mlir_supported_type(i->get_shape().type());
+        });
+    });
 }

 struct find_gemm_softmax_gemm
 {
     auto matcher() const
     {
         auto gemm1 = match::skip(match::name("contiguous"))(
-            match::name("dot")(is_ck_gemm().bind("gemm1")));
+            match::name("dot")(match::any_of(is_ck_gemm(), is_mlir_gemm()).bind("gemm1")));
         auto mul = match::name("mul")(
             match::nargs(2), match::either_arg(0, 1)(match::is_constant().bind("scale"), gemm1));
         auto softmax = match::name("softmax")(match::arg(0)(mul)).bind("softmax");

-        return match::name("dot")(is_ck_gemm().bind("gemm2"))(match::arg(0)(softmax));
+        return match::name("dot")(
+            match::any_of(is_ck_gemm(), is_mlir_gemm()).bind("gemm2"))(match::arg(0)(softmax));
     }

     void apply(module_pass_manager& mpm, const match::matcher_result& r) const
@@ -182,8 +208,7 @@ void prefuse_ops::apply(module_pass_manager& mpm) const
     match::find_matches(mpm.get_module(), find_layernorm{});
     mpm.run_pass(dead_code_elimination{});
     match::find_matches(mpm.get_module(), find_add_layernorm{});
-    if(enabled(MIGRAPHX_ENABLE_CK{}))
-        match::find_matches(mpm, find_gemm_softmax_gemm{});
+    match::find_matches(mpm, find_gemm_softmax_gemm{});
 }

 } // namespace gpu
src/targets/gpu/rocblas.cpp (+4, -8)

@@ -53,19 +53,15 @@ bool get_compute_fp32_flag()
     return (starts_with(device_name, "gfx9") and device_name >= "gfx908");
 }

-bool get_int8_x4_format(context& ctx)
+bool rocblas_fp8_available()
 {
-#if ROCBLAS_VERSION_MAJOR >= 3
-    (void)(ctx);
-    return false;
+#ifndef MIGRAPHX_USE_ROCBLAS_FP8_API
+    return false;
 #else
-    // int8x4 packed format is only available starting from rocblas-v2.38 and it is deprecated in
-    // v3.0 and will be removed in v4.0
-    rocblas_gemm_flags flag;
-    rocblas_query_int8_layout_flag(ctx.get_stream().get_rocblas(), &flag);
-    return flag == rocblas_gemm_flags_pack_int8x4;
+    return gfx_has_fp8_intrinsics();
 #endif
 }

 } // namespace gpu
 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
src/targets/gpu/target.cpp (+27, -2)

@@ -63,7 -63,6 @@
 #include <migraphx/gpu/fuse_ops.hpp>
 #include <migraphx/gpu/prefuse_ops.hpp>
 #include <migraphx/gpu/lowering.hpp>
-#include <migraphx/gpu/pack_int8_args.hpp>
 #include <migraphx/gpu/schedule_model.hpp>
 #include <migraphx/gpu/sync_device.hpp>
 #include <migraphx/gpu/target.hpp>
@@ -99,12 +98,37 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
     ctx.set_exhaustive_tune_flag(options.exhaustive_tune);
     std::set<shape::type_t> unsupported_types(shape::types().begin(), shape::types().end());
     unsupported_types.erase(shape::type_t::float_type);
+    unsupported_types.erase(shape::type_t::fp8e4m3fnuz_type);
     unsupported_types.erase(shape::type_t::half_type);
     unsupported_types.erase(shape::type_t::bool_type);
     unsupported_types.erase(shape::type_t::int8_type);
     unsupported_types.erase(shape::type_t::uint8_type);
     unsupported_types.erase(shape::type_t::int32_type);
     unsupported_types.erase(shape::type_t::tuple_type);
+    // whiltelist supported Ops for the FP8
+    std::set<std::string> unsupported_fp8_ops = {};
+    if(not gpu::rocblas_fp8_available())
+    {
+        unsupported_fp8_ops.insert("dot");
+        unsupported_fp8_ops.insert("quant_dot");
+    }
+    // MIOpen doesn't have support for fp8 pooling yet.
+    unsupported_fp8_ops.insert("pooling");
+    if(not gpu::gfx_has_fp8_intrinsics())
+    {
+        unsupported_fp8_ops.insert("convolution");
+        unsupported_fp8_ops.insert("quant_convolution");
+    }
+    // add all device kernels
+    unsupported_fp8_ops.insert("logsoftmax");
+    unsupported_fp8_ops.insert("nonzero");
+    unsupported_fp8_ops.insert("prefix_scan_sum");
+    unsupported_fp8_ops.insert("scatter_none");
+    unsupported_fp8_ops.insert("topk");
+    unsupported_fp8_ops.insert("rnn_var_sl_shift_output");
+    unsupported_fp8_ops.insert("multinomial");
+    unsupported_fp8_ops.insert("argmax");
+    unsupported_fp8_ops.insert("argmin");
     // clang-format off
     return
     {
@@ -136,6 +160,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
         prefuse_ops{},
         dead_code_elimination{},
         auto_contiguous{},
+        eliminate_data_type{{migraphx::shape::fp8e4m3fnuz_type}, shape::float_type, unsupported_fp8_ops},
+        dead_code_elimination{},
         optimize_module{},
         fuse_pointwise{},
         dead_code_elimination{},
@@ -154,7 +180,6 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
         dead_code_elimination{},
         compile_miopen{&gctx},
         dead_code_elimination{},
-        pack_int8_args{},
-        dead_code_elimination{},
         fuse_ops{&ctx, options.fast_math},
         dead_code_elimination{},
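As I read the new pass arguments, eliminate_data_type{{fp8e4m3fnuz_type}, shape::float_type, unsupported_fp8_ops} converts fp8e4m3fnuz to float only around the ops listed in unsupported_fp8_ops, so fp8 stays native wherever rocBLAS, MIOpen, or the device kernels can handle it and falls back to float elsewhere. A rough sketch of that per-op decision, under the assumption that the pass works op-by-op (the helper below is illustrative, not MIGraphX API):

#include <set>
#include <string>

// Hypothetical helper mirroring my reading of the pass configuration: an op keeps
// fp8e4m3fnuz operands only if its name is not in the unsupported set.
bool keeps_fp8(const std::string& op_name, const std::set<std::string>& unsupported_fp8_ops)
{
    return unsupported_fp8_ops.count(op_name) == 0;
}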
src/targets/ref/CMakeLists.txt (+1, -5)

@@ -25,17 +25,13 @@
 add_library(migraphx_ref
     target.cpp
     lowering.cpp
-    gemm.cpp
 )
 set_target_properties(migraphx_ref PROPERTIES EXPORT_NAME ref)
 rocm_set_soversion(migraphx_ref ${MIGRAPHX_SO_VERSION})
-find_path(BLAZE_INCLUDE blaze/Blaze.h)
 rocm_clang_tidy_check(migraphx_ref)
-target_link_libraries(migraphx_ref PRIVATE Threads::Threads)
 target_link_libraries(migraphx_ref PUBLIC migraphx)
-target_include_directories(migraphx_ref PRIVATE ${BLAZE_INCLUDE})
-target_compile_definitions(migraphx_ref PRIVATE -DBLAZE_USE_CPP_THREADS)
 migraphx_generate_export_header(migraphx_ref)
src/targets/ref/gemm.cpp (deleted, 100644 → 0; -157 — the whole file below was removed)

/*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
 * (standard MIT license text, as in the other files in this commit)
 */
#include <migraphx/ref/gemm.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/requires.hpp>
#include <migraphx/par_for.hpp>
#include <blaze/math/CustomMatrix.h>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace ref {

template <class T>
using matrix = blaze::CustomMatrix<T, blaze::unaligned, blaze::unpadded>; // NOLINT

template <class T>
static auto make_mat(tensor_view<T> x)
{
    const auto& s = x.get_shape();
    // assert(s.lens().size() == 2);
    std::size_t n_dims = s.lens().size();
    std::size_t dim_0  = n_dims - 2;
    std::size_t dim_1  = n_dims - 1;
    if(s.transposed())
        return matrix<T>{x.data(), s.lens()[dim_1], s.lens()[dim_0], s.strides()[dim_1]};
    return matrix<T>{x.data(), s.lens()[dim_0], s.lens()[dim_1], s.strides()[dim_0]};
}

template <class T, class F>
static void visit_mat(tensor_view<T> x, F f)
{
    auto mat = make_mat(x);
    if(x.get_shape().transposed())
        f(blaze::trans(mat));
    else
        f(mat);
}

template <class T>
struct is_fast_gemm_type : std::false_type
{
};

template <>
struct is_fast_gemm_type<float> : std::true_type
{
};

template <class T, class F>
void migemm_impl(
    tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha, F beta, std::true_type)
{
    visit_mat(amat, [&](const auto& a) {
        visit_mat(bmat, [&](const auto& b) {
            auto c = make_mat(cmat);
            c      = beta * c;
            // This is a simple optimization to avoid
            // compute A * B if alpha is 0.0
            if(alpha != 0.0)
            {
                c = c + alpha * a * b;
            }
        });
    });
}

template <class T, class F>
void migemm_impl(
    tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha, F beta, std::false_type)
{
    std::size_t n_dims = cmat.get_shape().lens().size();
    std::size_t dim_0  = n_dims - 2;
    std::size_t dim_1  = n_dims - 1;
    auto k             = amat.get_shape().lens()[dim_1];

    assert(amat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_0]);
    assert(cmat.get_shape().lens()[dim_0] == amat.get_shape().lens()[dim_0]);
    assert(cmat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_1]);

    auto cs = cmat.get_shape();
    par_for(cs.elements(), [&](auto i) {
        auto c_idx = cs.multi(i);
        auto a_idx = c_idx;
        auto b_idx = c_idx;
        double s   = 0.0;
        dfor(k)([&](auto kk) {
            a_idx[dim_1] = b_idx[dim_0] = kk;
            s += amat(a_idx.begin(), a_idx.end()) * bmat(b_idx.begin(), b_idx.end());
        });
        cmat(c_idx.begin(), c_idx.end()) = alpha * s + cmat(c_idx.begin(), c_idx.end()) * beta;
    });
}

template <class T, class F>
void migemm_impl(tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha, F beta)
{
    auto lens      = amat.get_shape().lens();
    bool batch_mul = std::accumulate(lens.rbegin() + 2,
                                     lens.rend(),
                                     std::size_t{1},
                                     std::multiplies<std::size_t>()) == 1;
    if(batch_mul)
    {
        migemm_impl(cmat, amat, bmat, alpha, beta, is_fast_gemm_type<T>{});
    }
    else
    {
        migemm_impl(cmat, amat, bmat, alpha, beta, std::false_type{});
    }
}

template <class F>
void migemm_tpl(const argument& c_arg, const argument& a_arg, const argument& b_arg, F alpha, F beta)
{
    visit_all(c_arg, a_arg, b_arg)(
        [&](auto cmat, auto amat, auto bmat) { migemm_impl(cmat, amat, bmat, alpha, beta); });
}

void migemm(
    const argument& c_arg, const argument& a_arg, const argument& b_arg, float alpha, float beta)
{
    migemm_tpl(c_arg, a_arg, b_arg, alpha, beta);
}

void migemm(
    const argument& c_arg, const argument& a_arg, const argument& b_arg, int32_t alpha, int32_t beta)
{
    migemm_tpl(c_arg, a_arg, b_arg, alpha, beta);
}

} // namespace ref
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
src/targets/ref/lowering.cpp (+6, -17)

@@ -44,7 -44,6 @@
 #include <migraphx/iterator_for.hpp>
 #include <migraphx/par_dfor.hpp>
 #include <migraphx/clamp.hpp>
-#include <migraphx/ref/gemm.hpp>
 #include <migraphx/register_op.hpp>
 #include <migraphx/make_op.hpp>
 #include <migraphx/tune_axis.hpp>
@@ -283,8 +282,8 @@ struct ref_gemm
     argument compute(context&, const dyn_output& dyn_out, std::vector<argument> args) const
     {
         argument result{dyn_out.computed_shape};
-        migemm(result, args[0], args[1], 1.0f, 0.0f);
-
+        visit_all(result, args[0], args[1])(
+            [&](auto cmat, auto amat, auto bmat) { gemm(cmat, amat, bmat, 1.0f, 0.0f); });
         return result;
     }
 };
@@ -306,24 +305,14 @@ struct ref_quant_gemm
     argument compute(context&, const shape& output_shape, std::vector<argument> args) const
     {
         argument result{output_shape};
-        // first, convert the args[0] and args[1] from int8_t to int32_t
-        argument arg_0{{shape::int32_type, {args.at(0).get_shape().lens()}}};
-        argument arg_1{{shape::int32_type, {args.at(1).get_shape().lens()}}};
-        arg_0.visit([&](auto output) {
-            args.at(0).visit(
-                [&](auto input) { std::copy(input.begin(), input.end(), output.begin()); });
-        });
-        arg_1.visit([&](auto output) {
-            args.at(1).visit(
-                [&](auto input) { std::copy(input.begin(), input.end(), output.begin()); });
-        });
-
-        migemm(result, arg_0, arg_1, int32_t{1}, int32_t{0});
-
+        result.visit([&](auto cmat) {
+            visit_all(args.at(0), args.at(1))(
+                [&](auto amat, auto bmat) { return gemm(cmat, amat, bmat, 1.0f, 0.0f); });
+        });
         return result;
     }
 };

 MIGRAPHX_REGISTER_OP(ref_gemm)
 template <class Op>
src/tf/CMakeLists.txt (+9, -2)

@@ -38,7 +38,11 @@ protobuf_generate_cpp(
 )
 add_library(tf-proto STATIC ${PROTO_SRCS})
 target_include_directories(tf-proto SYSTEM PUBLIC ${CMAKE_CURRENT_BINARY_DIR} ${PROTOBUF_INCLUDE_DIR})
-target_compile_options(tf-proto PRIVATE -w)
+if(MSVC)
+    target_compile_options(tf-proto PRIVATE /w)
+else()
+    target_compile_options(tf-proto PRIVATE -w)
+endif()
 target_link_libraries(tf-proto PRIVATE ${PROTOBUF_LIBRARY})
 set_target_properties(tf-proto PROPERTIES POSITION_INDEPENDENT_CODE On)
@@ -49,7 +53,10 @@ target_include_directories(migraphx_tf PRIVATE include)
 set_target_properties(migraphx_tf PROPERTIES EXPORT_NAME tf)
 rocm_set_soversion(migraphx_tf ${MIGRAPHX_SO_VERSION})
 rocm_clang_tidy_check(migraphx_tf)
-target_link_libraries(migraphx_tf PRIVATE tf-proto "-Wl,--exclude-libs,ALL")
+target_link_libraries(migraphx_tf PRIVATE tf-proto)
+if(NOT WIN32)
+    target_link_libraries(migraphx_tf PRIVATE "-Wl,--exclude-libs,ALL")
+endif()
 target_link_libraries(migraphx_tf PUBLIC migraphx)
 rocm_install_targets(
src/tmp_dir.cpp (+11, -1)

@@ -31,8 +31,18 @@
 #include <sstream>
 #include <iostream>
 #include <string>
-#include <sys/types.h>
+#ifdef _WIN32
+// cppcheck-suppress definePrefix
+#define WIN32_LEAN_AND_MEAN
+#include <Windows.h>
+#undef getpid
+// cppcheck-suppress [definePrefix, defineUpperCase]
+#define getpid _getpid
+#else
 #include <unistd.h>
+#include <sys/types.h>
+#endif

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
src/verify_args.cpp (+0, -1)

@@ -88,7 +88,6 @@ bool verify_args(const std::string& name,
             if(target_nan_idx >= 0)
                 std::cout << "Non finite number found in target at " << target_nan_idx << ": "
                           << target[target_nan_idx] << std::endl;
-            std::cout << "MIGraphX verification passed successfully." << std::endl;
         }
     });
     return passed;
src/version.h.in (+1, -1)

@@ -25,5 +25,5 @@
 #define MIGRAPHX_VERSION_MAJOR @PROJECT_VERSION_MAJOR@
 #define MIGRAPHX_VERSION_MINOR @PROJECT_VERSION_MINOR@
 #define MIGRAPHX_VERSION_PATCH @PROJECT_VERSION_PATCH@
-#define MIGRAPHX_VERSION_TWEAK @PROJECT_VERSION_TWEAK@
+#define MIGRAPHX_VERSION_TWEAK "@PROJECT_VERSION_TWEAK@"
 // clang-format on
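The only change here is quoting the tweak value. My reading (an assumption, not stated in the diff) is that the tweak field can hold a non-numeric value such as a git hash, which only stays a usable token once it expands to a string literal. Hypothetical expansions, for illustration only:

/* before: #define MIGRAPHX_VERSION_TWEAK e5978d5     -> not a valid C++ expression */
/* after:  #define MIGRAPHX_VERSION_TWEAK "e5978d5"   -> an ordinary string literal */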
test/CMakeLists.txt (+6, -2)

@@ -25,7 +25,7 @@
 cmake_policy(SET CMP0057 NEW)

 find_package(Threads REQUIRED)
-rocm_test_link_libraries(Threads::Threads migraphx migraphx_ref migraphx_onnx migraphx_tf)
+rocm_test_link_libraries(Threads::Threads migraphx migraphx_onnx migraphx_tf)
 rocm_test_include_directories(include)
 set(MIGRAPHX_DISABLE_LARGE_BUFFER_TESTS Off CACHE BOOL "")
@@ -146,7 +146,11 @@ endfunction()
 function(test_headers PREFIX)
     file(GLOB HEADERS CONFIGURE_DEPENDS ${ARGN})
+    if(NOT MIGRAPHX_USE_COMPOSABLEKERNEL)
+        list(REMOVE_ITEM HEADERS ${CMAKE_SOURCE_DIR}/src/targets/gpu/include/migraphx/gpu/ck.hpp)
+    endif()
+    list(REMOVE_ITEM HEADERS ${CMAKE_SOURCE_DIR}/src/include/migraphx/float8_impl.hpp)
     foreach(HEADER ${HEADERS})
         file(RELATIVE_PATH HEADER_REL ${CMAKE_SOURCE_DIR} ${HEADER})
         string(MAKE_C_IDENTIFIER ${HEADER_REL} TEST_NAME)
test/api/CMakeLists.txt (+6, -4)

@@ -30,6 +30,9 @@ function(add_api_test TEST_NAME TEST_SRC TEST_DIR)
     add_test(NAME ${NAME} COMMAND $<TARGET_FILE:${NAME}> WORKING_DIRECTORY ${TEST_DIR})
     add_dependencies(tests ${NAME})
     add_dependencies(check ${NAME})
+    if(WIN32)
+        target_compile_definitions(${NAME} PRIVATE _CRT_SECURE_NO_WARNINGS)
+    endif()
 endfunction()

 # Workaround: C file dont work with clang-tidy right now, need a fix in rocm-cmake
@@ -41,6 +44,9 @@ function(add_c_api_test TEST_NAME TEST_SRC TEST_DIR)
     add_test(NAME ${NAME} COMMAND $<TARGET_FILE:${NAME}> WORKING_DIRECTORY ${TEST_DIR})
     add_dependencies(tests ${NAME})
     add_dependencies(check ${NAME})
+    if(WIN32)
+        target_compile_definitions(${NAME} PRIVATE _CRT_SECURE_NO_WARNINGS)
+    endif()
 endfunction()

 add_api_test(array_base test_array_base.cpp ${TEST_ONNX_DIR})
@@ -57,10 +63,6 @@ add_api_test(custom_op test_custom_op.cpp ${TEST_ONNX_DIR})
 add_api_test(tf_parser test_tf_parser.cpp ${TEST_TF_DIR})
 # GPU-based tests
 if(MIGRAPHX_ENABLE_GPU)
-    list(APPEND CMAKE_PREFIX_PATH /opt/rocm)
-    find_package(hip)
     add_api_test(gpu test_gpu.cpp ${TEST_ONNX_DIR})
-    target_link_libraries(test_api_gpu)
     add_api_test(custom_op_gpu test_custom_op_gpu.cpp ${TEST_ONNX_DIR})
-    target_link_libraries(test_api_custom_op_gpu)
 endif()
test/api/test_cpu.cpp (+25, -0)

@@ -198,4 +198,29 @@ TEST_CASE(set_loop_default_iter_num)
     EXPECT(out_shapes[1].lengths() == out_lens1);
 }

+TEST_CASE(set_loop_limit_iterations)
+{
+    migraphx::onnx_options option;
+    option.set_default_loop_iterations(15);
+    option.set_limit_loop_iterations(10);
+    auto p          = migraphx::parse_onnx("loop_default_test.onnx", option);
+    auto out_shapes = p.get_output_shapes();
+    std::vector<std::size_t> out_lens0 = {1};
+    EXPECT(out_shapes[0].lengths() == out_lens0);
+    std::vector<std::size_t> out_lens1 = {10, 1};
+    EXPECT(out_shapes[1].lengths() == out_lens1);
+}
+
+TEST_CASE(set_loop_limit_iterations2)
+{
+    migraphx::onnx_options option;
+    option.set_limit_loop_iterations(10);
+    auto p          = migraphx::parse_onnx("loop_test_implicit_tripcnt.onnx", option);
+    auto out_shapes = p.get_output_shapes();
+    std::vector<std::size_t> out_lens0 = {1};
+    EXPECT(out_shapes[0].lengths() == out_lens0);
+    std::vector<std::size_t> out_lens1 = {10, 1};
+    EXPECT(out_shapes[1].lengths() == out_lens1);
+}
+
 int main(int argc, const char* argv[]) { test::run(argc, argv); }
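Both new cases pin down how the loop-iteration limit interacts with the declared or default trip count: with set_default_loop_iterations(15) but set_limit_loop_iterations(10), and likewise with an implicit trip count, the scan-output dimension comes out as 10. A compressed sketch of the semantics these expectations imply (my reading, not MIGraphX API):

#include <algorithm>
#include <cstdint>

// The scan-output length appears to be the declared/default trip count clamped by the limit.
int64_t effective_trip_count(int64_t declared_or_default, int64_t limit)
{
    return std::min(declared_or_default, limit);
}
// e.g. effective_trip_count(15, 10) == 10, matching out_lens1 = {10, 1} above.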
test/api/test_gpu.cpp (+55, -0)

@@ -317,4 +317,59 @@ TEST_CASE(loop_test)
     }
 }

+TEST_CASE(loop_test_limit_max_iter)
+{
+    auto run_prog = [&](int64_t limit_max_iterations) {
+        migraphx::onnx_options parse_options;
+        parse_options.set_limit_loop_iterations(limit_max_iterations);
+        auto p = migraphx::parse_onnx("loop_test_implicit_tripcnt.onnx", parse_options);
+        auto shapes_before = p.get_output_shapes();
+        migraphx::compile_options options;
+        options.set_offload_copy();
+        p.compile(migraphx::target("gpu"), options);
+        auto shapes_after = p.get_output_shapes();
+        CHECK(shapes_before.size() == 2);
+        CHECK(bool{shapes_before.front() == shapes_after.front()});
+
+        migraphx::program_parameters pp;
+        auto param_shapes     = p.get_parameter_shapes();
+        auto aas              = param_shapes["a"];
+        std::vector<float> xd = {1.0f};
+        pp.add("a", migraphx::argument(aas, xd.data()));
+        auto bbs              = param_shapes["b"];
+        std::vector<float> yd = {2.0};
+        pp.add("b", migraphx::argument(bbs, yd.data()));
+        auto cs   = param_shapes["keep_going_cond"];
+        bool cond = true;
+        pp.add("keep_going_cond", migraphx::argument(cs, &cond));
+
+        auto outputs = p.eval(pp);
+        auto output  = outputs[0];
+        std::vector<std::vector<float>> ret;
+        ret.push_back(output.as_vector<float>());
+        output = outputs[1];
+        ret.push_back(output.as_vector<float>());
+        return ret;
+    };
+    {
+        auto result_vector       = run_prog(5);
+        std::vector<float> gold0 = {2.0f};
+        EXPECT(result_vector.at(0) == gold0);
+        std::vector<float> gold1 = {-2, 4, 0, 0, 0};
+        EXPECT(result_vector.at(1) == gold1);
+    }
+    {
+        auto result_vector       = run_prog(20);
+        std::vector<float> gold0 = {2.0f};
+        EXPECT(result_vector.at(0) == gold0);
+        std::vector<float> gold1 = {-2, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
+        EXPECT(result_vector.at(1) == gold1);
+    }
+}
+
 int main(int argc, const char* argv[]) { test::run(argc, argv); }
test/float_equal.cpp (+11, -1)

@@ -22,6 +22,7 @@
  * THE SOFTWARE.
  */
 #include <migraphx/float_equal.hpp>
+#include <migraphx/float8.hpp>
 #include <migraphx/half.hpp>
 #include "test.hpp"
@@ -53,7 +54,7 @@ auto test_float_equal(T x, U y)
 template <class T, class U>
 void test_equality()
 {
-    auto x1 = T(0.1);
+    auto x1 = T(0.125);
     auto x2 = U(0.0);
     auto x3 = U(1.0);
     EXPECT(test_float_equal(x1, x1));
@@ -71,8 +72,12 @@ void test_equality()
 TEST_CASE_REGISTER(test_equality<double, float>);
 TEST_CASE_REGISTER(test_equality<double, int>);
 TEST_CASE_REGISTER(test_equality<double, migraphx::half>);
+TEST_CASE_REGISTER(test_equality<double, migraphx::fp8::fp8e4m3fnuz>);
 TEST_CASE_REGISTER(test_equality<float, int>);
+TEST_CASE_REGISTER(test_equality<float, migraphx::fp8::fp8e4m3fnuz>);
 TEST_CASE_REGISTER(test_equality<migraphx::half, int>);
+TEST_CASE_REGISTER(test_equality<migraphx::half, migraphx::fp8::fp8e4m3fnuz>);
+TEST_CASE_REGISTER(test_equality<migraphx::fp8::fp8e4m3fnuz, int>);

 template <class T, class U>
 void test_limits()
@@ -110,8 +115,13 @@ void test_limits()
 TEST_CASE_REGISTER(test_limits<double, float>);
 TEST_CASE_REGISTER(test_limits<double, int>);
 TEST_CASE_REGISTER(test_limits<double, migraphx::half>);
+TEST_CASE_REGISTER(test_limits<double, migraphx::fp8::fp8e4m3fnuz>);
 TEST_CASE_REGISTER(test_limits<float, int>);
+TEST_CASE_REGISTER(test_limits<float, migraphx::fp8::fp8e4m3fnuz>);
 TEST_CASE_REGISTER(test_limits<int, migraphx::half>);
+TEST_CASE_REGISTER(test_limits<int, migraphx::fp8::fp8e4m3fnuz>);
+TEST_CASE_REGISTER(test_limits<migraphx::fp8::fp8e4m3fnuz, migraphx::half>);
 #ifndef _WIN32
 // On Windows, types int and long have the same min and max values.
 TEST_CASE_REGISTER(test_limits<long, int>);
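The constant change from 0.1 to 0.125 looks deliberate: 0.125 is 2^-3, so it is exactly representable in every type now registered, including migraphx::fp8::fp8e4m3fnuz, whereas 0.1 has no exact binary representation and would round differently at different widths, making cross-type equality checks meaningless. A quick self-contained way to see the difference (plain C++, no MIGraphX types needed):

#include <cassert>

int main()
{
    // 0.125 = 2^-3 survives a double -> float -> double round trip exactly...
    assert(static_cast<double>(static_cast<float>(0.125)) == 0.125);
    // ...while narrowing 0.1 to float changes its value.
    assert(static_cast<double>(static_cast<float>(0.1)) != 0.1);
}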
test/fp8e4m3fn.cpp
0 → 100644
View file @
70d9faf7
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <cmath>
#include <migraphx/float_equal.hpp>
#include <migraphx/float8.hpp>
#include <migraphx/half.hpp>
#include <migraphx/ranges.hpp>
#include "test.hpp"
#include <limits>
float fp8e4m3fn_to_fp32_value(uint8_t input)
{
    constexpr std::array<float, 256> e4m3fnuz_lut = {
        0.0, 0.001953125, 0.00390625, 0.005859375, 0.0078125, 0.009765625, 0.01171875, 0.013671875,
        0.015625, 0.017578125, 0.01953125, 0.021484375, 0.0234375, 0.025390625, 0.02734375, 0.029296875,
        0.03125, 0.03515625, 0.0390625, 0.04296875, 0.046875, 0.05078125, 0.0546875, 0.05859375,
        0.0625, 0.0703125, 0.078125, 0.0859375, 0.09375, 0.1015625, 0.109375, 0.1171875,
        0.125, 0.140625, 0.15625, 0.171875, 0.1875, 0.203125, 0.21875, 0.234375,
        0.25, 0.28125, 0.3125, 0.34375, 0.375, 0.40625, 0.4375, 0.46875,
        0.5, 0.5625, 0.625, 0.6875, 0.75, 0.8125, 0.875, 0.9375,
        1.0, 1.125, 1.25, 1.375, 1.5, 1.625, 1.75, 1.875,
        2.0, 2.25, 2.5, 2.75, 3.0, 3.25, 3.5, 3.75,
        4.0, 4.5, 5.0, 5.5, 6.0, 6.5, 7.0, 7.5,
        8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
        16.0, 18.0, 20.0, 22.0, 24.0, 26.0, 28.0, 30.0,
        32.0, 36.0, 40.0, 44.0, 48.0, 52.0, 56.0, 60.0,
        64.0, 72.0, 80.0, 88.0, 96.0, 104.0, 112.0, 120.0,
        128.0, 144.0, 160.0, 176.0, 192.0, 208.0, 224.0, 240.0,
        256.0, 288.0, 320.0, 352.0, 384.0, 416.0, 448.0, std::numeric_limits<float>::quiet_NaN(),
        -0.0, -0.001953125, -0.00390625, -0.005859375, -0.0078125, -0.009765625, -0.01171875, -0.013671875,
        -0.015625, -0.017578125, -0.01953125, -0.021484375, -0.0234375, -0.025390625, -0.02734375, -0.029296875,
        -0.03125, -0.03515625, -0.0390625, -0.04296875, -0.046875, -0.05078125, -0.0546875, -0.05859375,
        -0.0625, -0.0703125, -0.078125, -0.0859375, -0.09375, -0.1015625, -0.109375, -0.1171875,
        -0.125, -0.140625, -0.15625, -0.171875, -0.1875, -0.203125, -0.21875, -0.234375,
        -0.25, -0.28125, -0.3125, -0.34375, -0.375, -0.40625, -0.4375, -0.46875,
        -0.5, -0.5625, -0.625, -0.6875, -0.75, -0.8125, -0.875, -0.9375,
        -1.0, -1.125, -1.25, -1.375, -1.5, -1.625, -1.75, -1.875,
        -2.0, -2.25, -2.5, -2.75, -3.0, -3.25, -3.5, -3.75,
        -4.0, -4.5, -5.0, -5.5, -6.0, -6.5, -7.0, -7.5,
        -8.0, -9.0, -10.0, -11.0, -12.0, -13.0, -14.0, -15.0,
        -16.0, -18.0, -20.0, -22.0, -24.0, -26.0, -28.0, -30.0,
        -32.0, -36.0, -40.0, -44.0, -48.0, -52.0, -56.0, -60.0,
        -64.0, -72.0, -80.0, -88.0, -96.0, -104.0, -112.0, -120.0,
        -128.0, -144.0, -160.0, -176.0, -192.0, -208.0, -224.0, -240.0,
        -256.0, -288.0, -320.0, -352.0, -384.0, -416.0, -448.0, std::numeric_limits<float>::quiet_NaN(),
    };
    return e4m3fnuz_lut[input];
}
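
// For reference: the table above follows directly from the fp8 e4m3fn encoding
// (1 sign bit, 4 exponent bits with bias 7, 3 mantissa bits; exponent == 0 is
// subnormal, and only the bit patterns 0x7f/0xff encode NaN). The helper below is
// a minimal sketch, not part of this file, that reproduces the table entries.
inline float decode_e4m3fn_sketch(uint8_t bits)
{
    const int sign     = (bits >> 7) & 0x1;
    const int exponent = (bits >> 3) & 0xf;
    const int mantissa = bits & 0x7;
    if(exponent == 0xf and mantissa == 0x7) // 0x7f / 0xff are the only NaN encodings
        return std::numeric_limits<float>::quiet_NaN();
    const float magnitude =
        (exponent == 0) ? std::ldexp(mantissa / 8.0f, -6)                      // subnormal
                        : std::ldexp(1.0f + mantissa / 8.0f, exponent - 7);    // normal
    return (sign == 1) ? -magnitude : magnitude;
}
// e.g. decode_e4m3fn_sketch(0x38) == 1.0f and decode_e4m3fn_sketch(0x7e) == 448.0f,
// matching e4m3fnuz_lut[0x38] and e4m3fnuz_lut[0x7e].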
TEST_CASE(test_fp8_cast_to_float)
{
    std::vector<uint8_t> bit_vals(256);
    std::iota(bit_vals.begin(), bit_vals.end(), 0);
    EXPECT(bool{std::all_of(bit_vals.begin(), bit_vals.end(), [](uint8_t bit_val) {
        migraphx::fp8::fp8e4m3fn fp8_val(bit_val, migraphx::fp8::fp8e4m3fn::from_bits());
        if(std::isnan(float(fp8_val)) and std::isnan(fp8e4m3fn_to_fp32_value(bit_val)))
        {
            return true;
        }
        return migraphx::float_equal(float(fp8_val), fp8e4m3fn_to_fp32_value(bit_val));
    })});
}
TEST_CASE(test_fp8_cast_from_float)
{
    std::unordered_map<float, uint8_t> test_vals = {
        {{512, 0x7e},        {-512, 0xfe},         {448, 0x7e},        {-448, 0xfe},
         {256, 0x78},        {-256, 0xf8},         {240, 0x77},        {-240, 0xf7},
         {1e-07, 0x0},       {1e+07, 0x7e},        {1, 0x38},          {-1, 0xb8},
         {0.1, 0x1d},        {0.11, 0x1e},         {0.111, 0x1e},      {0.1111, 0x1e},
         {-0.1, 0x9d},       {-0.11, 0x9e},        {-0.111, 0x9e},     {-0.1111, 0x9e},
         {0.2, 0x25},        {2, 0x40},            {20, 0x5a},         {200, 0x74},
         {-0.2, 0xa5},       {-2, 0xc0},           {-20, 0xda},        {-200, 0xf4},
         {0.5, 0x30},        {-0.5, 0xb0},         {1.17549e-38, 0x0}, {1.4013e-45, 0x0},
         {0.0078125, 0x4},   {-0.0078125, 0x84},   {0.000976562, 0x0}, {-0.000976562, 0x80},
         {0.000488281, 0x0}, {-0.000488281, 0x80}}};
    EXPECT(bool{std::all_of(test_vals.begin(), test_vals.end(), [](const auto sample) {
        return migraphx::float_equal(
            migraphx::fp8::fp8e4m3fn(sample.first),
            migraphx::fp8::fp8e4m3fn(sample.second, migraphx::fp8::fp8e4m3fn::from_bits()));
    })});
}
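
// For reference: the expected encodings above follow from round-to-nearest plus
// saturation. e4m3fn has no infinity, so out-of-range inputs such as 512 and 1e+07
// saturate to the largest finite value 448 (0x7e, or 0xfe when negative). A value
// like 0.11 falls between the representable neighbours 0.109375 (0x1e) and
// 0.1171875 (0x1f) and rounds to the nearer one, 0x1e, while inputs below half the
// smallest subnormal (0.001953125), such as 1e-07, round to zero.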
TEST_CASE(test_positive_zero)
{
    float zero = 0.0;
    migraphx::fp8::fp8e4m3fn fp8_zero(zero);
    EXPECT(fp8_zero.is_zero());
    EXPECT(migraphx::float_equal(zero, float(fp8_zero)));
}

TEST_CASE(test_negative_zero)
{
    float nzero = -0.0;
    migraphx::fp8::fp8e4m3fn fp8_nzero(nzero);
    EXPECT(fp8_nzero.is_zero());
    // negative zero is preserved for fp8e4m3fn
    EXPECT(migraphx::float_equal(nzero, float(fp8_nzero)));
}

TEST_CASE(test_pos_zero_eq_neg_zero)
{
    float nzero = -0.0;
    float pzero = 0.0;
    migraphx::fp8::fp8e5m2 fp8_nzero(nzero);
    migraphx::fp8::fp8e5m2 fp8_pzero(pzero);
    EXPECT(fp8_nzero == fp8_pzero);
}

TEST_CASE(test_nan_1)
{
    float fnan = std::numeric_limits<float>::quiet_NaN();
    migraphx::fp8::fp8e4m3fn fp8_nan(fnan);
    EXPECT(fp8_nan.is_nan());
    EXPECT(std::isnan(fp8_nan));
}

TEST_CASE(test_nan_2)
{
    auto fnan = std::numeric_limits<migraphx::fp8::fp8e4m3fn>::quiet_NaN();
    migraphx::fp8::fp8e4m3fn fp8_nan(fnan.data, migraphx::fp8::fp8e4m3fn::from_bits());
    EXPECT(fp8_nan.is_nan());
    EXPECT(std::isnan(fp8_nan));
    EXPECT(std::isnan(float(fp8_nan)));
}

TEST_CASE(test_infinity_1)
{
    float finf = std::numeric_limits<float>::infinity();
    // no inf in fp8e4m3fn, it gets clipped to max()
    migraphx::fp8::fp8e4m3fn fp8_max(finf);
    EXPECT(fp8_max == std::numeric_limits<migraphx::fp8::fp8e4m3fn>::max());
}

TEST_CASE(test_infinity_2)
{
    // neg inf
    float finf = -1.0 * std::numeric_limits<float>::infinity();
    // no inf in fp8e4m3fn, it gets clipped to lowest
    migraphx::fp8::fp8e4m3fn fp8_lowest(finf);
    EXPECT(bool{fp8_lowest == std::numeric_limits<migraphx::fp8::fp8e4m3fn>::lowest()});
}
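
// For reference: the "fn" (finite) variant of e4m3 reserves no bit patterns for
// infinity; its only non-finite values are the two NaN encodings. That is why
// converting float +/-infinity saturates to the largest finite magnitudes, +448 and
// -448, rather than producing an infinity or a NaN, which the two cases above check.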
TEST_CASE(test_numeric_max_1)
{
    float fmax = std::numeric_limits<float>::max();
    migraphx::fp8::fp8e4m3fn fp8_max(fmax);
    EXPECT(fp8_max == std::numeric_limits<migraphx::fp8::fp8e4m3fn>::max());
}

TEST_CASE(test_numeric_max_2)
{
    // gets clipped to max
    float fmax = 2 * std::numeric_limits<migraphx::fp8::fp8e4m3fn>::max();
    migraphx::fp8::fp8e4m3fn fp8_max(fmax);
    EXPECT(fp8_max == std::numeric_limits<migraphx::fp8::fp8e4m3fn>::max());
}

TEST_CASE(test_numeric_lowest_1)
{
    float flowest = std::numeric_limits<float>::lowest();
    migraphx::fp8::fp8e4m3fn fp8_lowest(flowest);
    EXPECT(fp8_lowest == std::numeric_limits<migraphx::fp8::fp8e4m3fn>::lowest());
}

TEST_CASE(test_numeric_lowest_2)
{
    // gets clipped to lowest
    float fmin = 2.0 * std::numeric_limits<migraphx::fp8::fp8e4m3fn>::lowest();
    migraphx::fp8::fp8e4m3fn fp8_lowest(fmin);
    EXPECT(fp8_lowest == std::numeric_limits<migraphx::fp8::fp8e4m3fn>::lowest());
}

TEST_CASE(test_max_eq_lowest)
{
    EXPECT(migraphx::float_equal(std::numeric_limits<migraphx::fp8::fp8e4m3fn>::lowest(),
                                 -1 * std::numeric_limits<migraphx::fp8::fp8e4m3fn>::max()));
}

TEST_CASE(test_isfinite)
{
    EXPECT(std::isfinite(migraphx::fp8::fp8e4m3fn(0.0)));
    EXPECT(std::isfinite(migraphx::fp8::fp8e4m3fn(-0.0)));
    EXPECT(not std::isfinite(
        migraphx::fp8::fp8e4m3fn(std::numeric_limits<migraphx::fp8::fp8e4m3fn>::quiet_NaN())));
}

TEST_CASE(test_no_infinity)
{
    EXPECT(not bool{std::numeric_limits<migraphx::fp8::fp8e4m3fn>::has_infinity});
}
TEST_CASE(test_binary_ops)
{
    auto a = migraphx::fp8::fp8e4m3fn(-1.0);
    auto b = migraphx::fp8::fp8e4m3fn(1.0);
    auto c = migraphx::fp8::fp8e4m3fn(0.0);
    auto d = migraphx::fp8::fp8e4m3fn(-0.0);
    EXPECT(migraphx::float_equal((c + d), c));
    EXPECT(migraphx::float_equal((c + d), d));
    EXPECT(migraphx::float_equal((a + b), c));
    EXPECT(migraphx::float_equal((a + b), d));

    auto e = migraphx::fp8::fp8e4m3fn(10.0);
    auto f = migraphx::fp8::fp8e4m3fn(-10.0);
    EXPECT(bool{e > f});
    EXPECT(bool{f < e});
    EXPECT(bool{f <= e});
    EXPECT(bool{e >= f});
    EXPECT(bool{e <= e});
    EXPECT(bool{f >= f});
    EXPECT(not migraphx::float_equal(f, e));
}

TEST_CASE(test_fabs)
{
    auto a = migraphx::fp8::fp8e4m3fn(-1.0);
    auto b = migraphx::fp8::fp8e4m3fn(1.0);
    EXPECT(migraphx::float_equal(b, migraphx::fp8::fabs(a)));
}

TEST_CASE(test_stream_op)
{
    auto a = migraphx::fp8::fp8e4m3fn(-1.0);
    std::stringstream ss;
    ss << a;
    EXPECT(std::string("-1") == ss.str());
    ss = std::stringstream();
    auto b = std::numeric_limits<migraphx::fp8::fp8e4m3fn>::quiet_NaN();
    ss << b;
    EXPECT(std::string("nan") == ss.str());
}

int main(int argc, const char* argv[]) { test::run(argc, argv); }