gaoqiong / MIGraphX / Commits

Commit a98d86d9
Authored Dec 05, 2023 by Umang Yadav

Merge branch 'rocblas_fp8' into rocblas_mlir_fp8

Parents: a3d4b013, 7e80f627

Showing 13 changed files with 99 additions and 357 deletions (+99 −357)
src/targets/gpu/gemm_impl.cpp (+26 −14)
src/targets/gpu/include/migraphx/gpu/device/gather.hpp (+0 −44)
src/targets/gpu/include/migraphx/gpu/device/pad.hpp (+0 −48)
src/targets/gpu/include/migraphx/gpu/gather.hpp (+0 −62)
src/targets/gpu/include/migraphx/gpu/pad.hpp (+0 −61)
src/targets/gpu/mlir.cpp (+15 −47)
src/targets/gpu/pad.cpp (+0 −46)
src/targets/gpu/target.cpp (+1 −2)
test/gpu/mlir.cpp (+52 −29)
test/onnx/.onnxrt-commit (+1 −1)
test/verify/gemm_2args_mm_8.cpp (+1 −1)
test/verify/gemm_add_broadcast2.cpp (+1 −1)
tools/format.py (+2 −1)
src/targets/gpu/gemm_impl.cpp
@@ -22,6 +22,7 @@
  * THE SOFTWARE.
  */
+#include <rocblas/internal/rocblas-types.h>
 #include <rocblas/rocblas.h>
 #include <migraphx/gpu/rocblas.hpp>
 #include <migraphx/gpu/gemm_impl.hpp>
@@ -36,6 +37,20 @@ namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 
+/*
+The regular rocBLAS API takes compute_type as a `rocblas_datatype` enum value, whereas the "ex3"
+BETA API takes it as a `rocblas_computetype` enum value. `rb_compute_type` is a facilitator that
+implicitly casts the integer enum value to whichever type is required inside the `common_args`
+generator.
+*/
+struct rb_compute_type
+{
+    int type = 0;
+    rb_compute_type(rocblas_datatype t) : type(static_cast<int>(t)) {}
+    rb_compute_type(rocblas_computetype t) : type(static_cast<int>(t)) {}
+    operator rocblas_datatype() const { return static_cast<rocblas_datatype>(type); }
+    operator rocblas_computetype() const { return static_cast<rocblas_computetype>(type); }
+};
+
 // Convert rocBLAS datatypes to equivalent Migraphx data types
 rocblas_datatype get_type(shape::type_t type)
 {
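Why the two implicit conversions: the same stored value has to feed both the regular rocblas_gemm_*_ex entry points, which take rocblas_datatype, and the "ex3" BETA entry points, which take rocblas_computetype. A minimal self-contained sketch of the pattern, using stand-in enums with made-up values since the rocBLAS headers are not assumed here:

    #include <iostream>

    // Stand-ins for rocblas_datatype / rocblas_computetype (values hypothetical).
    enum demo_datatype { demo_f32_r = 151 };
    enum demo_computetype { demo_compute_f32 = 300 };

    // Same shape as rb_compute_type above: store an int, convert implicitly both ways.
    struct demo_compute_type
    {
        int type = 0;
        demo_compute_type(demo_datatype t) : type(static_cast<int>(t)) {}
        demo_compute_type(demo_computetype t) : type(static_cast<int>(t)) {}
        operator demo_datatype() const { return static_cast<demo_datatype>(type); }
        operator demo_computetype() const { return static_cast<demo_computetype>(type); }
    };

    // One consumer per API flavor.
    void gemm_ex(demo_datatype ct) { std::cout << "ex: " << ct << "\n"; }
    void gemm_ex3(demo_computetype ct) { std::cout << "ex3: " << ct << "\n"; }

    int main()
    {
        demo_compute_type ct = demo_f32_r;
        gemm_ex(ct);                                   // converts to demo_datatype
        gemm_ex3(demo_compute_type{demo_compute_f32}); // converts to demo_computetype
    }

Storing a plain int and converting on the way out is what lets the later hunks drop the explicit compute-type argument at each call site.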
@@ -185,12 +200,17 @@ struct gemm_impl
         {
             output_type = rocblas_datatype_i32_r;
         }
-        compute_type = output_type;
+        compute_type = rb_compute_type{output_type};
         if(compute_fp32)
         {
             if(arg_type == rocblas_datatype_f16_r)
                 compute_type = rocblas_datatype_f32_r;
         }
+        if(arg_type == rocblas_datatype_f8_r)
+        {
+            assert(get_type(input_shapes[1].type()) == rocblas_datatype_f8_r);
+            compute_type = rocblas_compute_type_f32;
+        }
         auto a_lens = input_shapes[0].lens();
         auto b_lens = input_shapes[1].lens();
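Restating the selection this hunk introduces as a standalone function (a sketch that mirrors the member logic above, not a drop-in; it assumes the rocBLAS headers included in the first hunk):

    // fp16 inputs may be computed in fp32 when compute_fp32 is set (regular API enum);
    // fp8 inputs always compute in fp32 via the "ex3" BETA API enum.
    rb_compute_type select_compute_type(rocblas_datatype arg_type,
                                        rocblas_datatype output_type,
                                        bool compute_fp32)
    {
        rb_compute_type compute_type{output_type};
        if(compute_fp32 and arg_type == rocblas_datatype_f16_r)
            compute_type = rocblas_datatype_f32_r;
        if(arg_type == rocblas_datatype_f8_r)
            compute_type = rocblas_compute_type_f32;
        return compute_type;
    }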
@@ -230,7 +250,6 @@ struct gemm_impl
         auto common_args = create_strided_batched_args_common(ctx, input_args);
         rocblas_invoke(&rocblas_gemm_strided_batched_ex3,
                        common_args,
-                       rocblas_compute_type_f32,
                        rocblas_gemm_algo_standard,
                        solution_idx,
                        gemm_flags);
@@ -240,7 +259,6 @@ struct gemm_impl
         auto common_args = create_gemm_ex_args_common(ctx, input_args);
         rocblas_invoke(&rocblas_gemm_ex3,
                        common_args,
-                       rocblas_compute_type_f32,
                        rocblas_gemm_algo_standard,
                        solution_idx,
                        gemm_flags);
@@ -254,7 +272,6 @@ struct gemm_impl
         auto common_args = create_strided_batched_args_common(ctx, input_args);
         rocblas_invoke(&rocblas_gemm_strided_batched_ex,
                        common_args,
-                       compute_type,
                        rocblas_gemm_algo_solution_index,
                        solution_idx,
                        gemm_flags);
@@ -264,7 +281,6 @@ struct gemm_impl
         auto common_args = create_gemm_ex_args_common(ctx, input_args);
         rocblas_invoke(&rocblas_gemm_ex,
                        common_args,
-                       compute_type,
                        rocblas_gemm_algo_solution_index,
                        solution_idx,
                        gemm_flags);
@@ -304,7 +320,6 @@ struct gemm_impl
         auto common_args = create_strided_batched_args_common(ctx, input_args);
         check_valid = rocblas_invoke(&rocblas_gemm_strided_batched_ex,
                                      common_args,
-                                     compute_type,
                                      rocblas_gemm_algo_solution_index,
                                      solution_idx,
                                      rocblas_gemm_flags_check_solution_index);
@@ -314,7 +329,6 @@ struct gemm_impl
         auto common_args = create_gemm_ex_args_common(ctx, input_args);
         check_valid = rocblas_invoke(&rocblas_gemm_ex,
                                      common_args,
-                                     compute_type,
                                      rocblas_gemm_algo_solution_index,
                                      solution_idx,
                                      rocblas_gemm_flags_check_solution_index);
@@ -365,7 +379,8 @@ struct gemm_impl
                     output_type,
                     ldd,
                     d_stride,
-                    num_matrices);
+                    num_matrices,
+                    compute_type);
     }

     /**
      * Helper method to create that subset of a long rocBLAS argument list that is common
@@ -398,7 +413,8 @@ struct gemm_impl
                     ldc,
                     is_3inputs ? args[3].data() : args[2].data(),
                     output_type,
-                    ldd);
+                    ldd,
+                    compute_type);
     }

 #ifdef MIGRAPHX_USE_ROCBLAS_TUNING_API
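These two helper hunks are why every rocblas_invoke call site in the surrounding hunks loses one argument: compute_type is appended once, where the common argument pack is built. A toy self-contained illustration of that refactor shape (std::tuple and std::apply stand in for MIGraphX's pack/invoke machinery; all names here are hypothetical):

    #include <iostream>
    #include <tuple>

    // Forward a pre-built pack of common arguments, then any per-call extras.
    template <class F, class Tuple, class... Extra>
    void invoke_with(F f, const Tuple& common, Extra... extra)
    {
        std::apply([&](auto... args) { f(args..., extra...); }, common);
    }

    // Stand-in for a rocBLAS entry point.
    void fake_gemm(int m, int n, int compute_type, int algo)
    {
        std::cout << m << "x" << n << " compute=" << compute_type << " algo=" << algo << "\n";
    }

    int main()
    {
        // After the change: compute_type lives in the common pack...
        auto common_args = std::make_tuple(64, 32, /*compute_type=*/151);
        // ...so call sites pass only the per-call extras.
        invoke_with(fake_gemm, common_args, /*algo=*/0);
    }

Folding the value into the pack keeps the rb_compute_type conversion in exactly one place.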
@@ -428,7 +444,6 @@ struct gemm_impl
         auto common_args = create_strided_batched_args_common(ctx, input_args);
         rocblas_invoke(&rocblas_gemm_strided_batched_ex_get_solutions,
                        common_args,
-                       compute_type,
                        rocblas_gemm_algo_solution_index,
                        gemm_flags,
                        nullptr,
@@ -438,7 +453,6 @@ struct gemm_impl
         auto common_sol_args = create_strided_batched_args_common(ctx, input_args);
         rocblas_invoke(&rocblas_gemm_strided_batched_ex_get_solutions,
                        common_sol_args,
-                       compute_type,
                        rocblas_gemm_algo_solution_index,
                        gemm_flags,
                        solution_indices.data(),
@@ -449,7 +463,6 @@ struct gemm_impl
         auto common_args = create_gemm_ex_args_common(ctx, input_args);
         rocblas_invoke(&rocblas_gemm_ex_get_solutions,
                        common_args,
-                       compute_type,
                        rocblas_gemm_algo_solution_index,
                        gemm_flags,
                        nullptr,
@@ -459,7 +472,6 @@ struct gemm_impl
         auto common_sol_args = create_gemm_ex_args_common(ctx, input_args);
         rocblas_invoke(&rocblas_gemm_ex_get_solutions,
                        common_sol_args,
-                       compute_type,
                        rocblas_gemm_algo_solution_index,
                        gemm_flags,
                        solution_indices.data(),
@@ -521,7 +533,7 @@ struct gemm_impl
     rocblas_int c_stride          = 0;
     rocblas_int d_stride          = 0;
     rocblas_datatype arg_type     = rocblas_datatype_f32_r;
-    rocblas_datatype compute_type = rocblas_datatype_f32_r;
+    rb_compute_type compute_type  = rocblas_datatype_f32_r;
     rocblas_datatype output_type  = rocblas_datatype_f32_r;
     bool strided_batched          = true;
     bool is_3inputs               = true;
src/targets/gpu/include/migraphx/gpu/device/gather.hpp (deleted, 100644 → 0)
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_GATHER_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_GATHER_HPP

#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/config.hpp>
#include <hip/hip_runtime_api.h>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

argument MIGRAPHX_DEVICE_EXPORT
gather(hipStream_t stream, argument result, argument arg1, argument arg2, int64_t axis);

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
src/targets/gpu/include/migraphx/gpu/device/pad.hpp (deleted, 100644 → 0)
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_PAD_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_PAD_HPP

#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/config.hpp>
#include <hip/hip_runtime_api.h>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {

argument MIGRAPHX_DEVICE_EXPORT
pad(hipStream_t stream, argument result, argument arg1, float value, std::vector<std::int64_t> pads);

} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
src/targets/gpu/include/migraphx/gpu/gather.hpp (deleted, 100644 → 0)
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_GATHER_HPP
#define MIGRAPHX_GUARD_RTGLIB_GATHER_HPP

#include <migraphx/argument.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/op/gather.hpp>
#include <migraphx/gpu/context.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

struct context;

struct hip_gather
{
    op::gather op;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::reflect(self.op, f);
    }

    std::string name() const { return "gpu::gather"; }
    shape compute_shape(std::vector<shape> inputs) const;
    argument
    compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
    {
        return shapes.size() - 1;
    }
};

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
src/targets/gpu/include/migraphx/gpu/pad.hpp (deleted, 100644 → 0)
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_PAD_HPP
#define MIGRAPHX_GUARD_RTGLIB_PAD_HPP

#include <migraphx/argument.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/op/pad.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

struct context;

struct hip_pad
{
    op::pad op;

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return migraphx::reflect(self.op, f);
    }

    std::string name() const { return "gpu::pad"; }
    shape compute_shape(std::vector<shape> inputs) const;
    argument
    compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
    std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
    {
        return shapes.size() - 1;
    }
};

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
src/targets/gpu/mlir.cpp
@@ -37,7 +37,7 @@
 #include <mlir-c/Pass.h>
 #include <mlir-c/Support.h>
 #include <mutex>
-#if !defined(MLIR_MIGRAPHX_DIALECT_API_VERSION) || MLIR_MIGRAPHX_DIALECT_API_VERSION != 3
+#if !defined(MLIR_MIGRAPHX_DIALECT_API_VERSION) || MLIR_MIGRAPHX_DIALECT_API_VERSION != 4
 #warning "Incompatible version of rocMLIR library used, disabling"
 // Only undefine when not using cppcheck
 #ifndef CPPCHECK
@@ -321,31 +321,30 @@ struct mlir_program
         return result;
     }

-    MlirType make_tensor(const shape& s) const
+    MlirType make_mlir_shaped(const shape& s) const
     {
         if(not s.standard())
             MIGRAPHX_THROW("MLIR expects all tensors to be in standard shape");
         if(s.dynamic())
             MIGRAPHX_THROW("MLIR does not support dynamic shapes");
         std::vector<int64_t> lens(s.lens().begin(), s.lens().end());
-        return mlirRankedTensorTypeGet(
-            lens.size(), lens.data(), make_type(s.type()), mlirAttributeGetNull());
+        std::vector<int64_t> strides(s.strides().begin(), s.strides().end());
+        return rocmlirMIXRShapedTypeGet(lens.size(), lens.data(), strides.data(), make_type(s.type()));
     }

     template <class Range>
-    std::vector<MlirType> make_tensors(const Range& r)
+    std::vector<MlirType> make_mlir_shapeds(const Range& r)
     {
         std::vector<MlirType> result;
         std::transform(r.begin(), r.end(), std::back_inserter(result), [&](const auto& s) {
-            return make_tensor(s);
+            return make_mlir_shaped(s);
         });
         return result;
     }

     MlirType make_function_type(const std::vector<shape>& inputs, const std::vector<shape>& outputs)
     {
-        auto in  = make_tensors(inputs);
-        auto out = make_tensors(outputs);
+        auto in  = make_mlir_shapeds(inputs);
+        auto out = make_mlir_shapeds(outputs);
         return mlirFunctionTypeGet(ctx.get(), in.size(), in.data(), out.size(), out.data());
     }
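The visible effect of this switch is in the type strings asserted by test/gpu/mlir.cpp below: a type now records strides next to dimensions. A small self-contained sketch of that textual form (format_shaped is a hypothetical helper for illustration only; the real rendering lives inside rocMLIR):

    #include <cstdint>
    #include <iostream>
    #include <sstream>
    #include <string>
    #include <vector>

    // Render lens/strides the way the expected-output strings below show them,
    // e.g. !migraphx.shaped<2x8x3x3xf32, 72x9x3x1>.
    std::string format_shaped(const std::vector<std::int64_t>& lens,
                              const std::vector<std::int64_t>& strides,
                              const std::string& elem)
    {
        std::ostringstream ss;
        ss << "!migraphx.shaped<";
        for(auto len : lens)
            ss << len << "x"; // each dimension is followed by 'x', then the element type
        ss << elem << ", ";
        for(std::size_t i = 0; i < strides.size(); ++i)
            ss << strides[i] << (i + 1 < strides.size() ? "x" : "");
        ss << ">";
        return ss.str();
    }

    int main()
    {
        // Standard (NCHW) strides:
        std::cout << format_shaped({2, 8, 3, 3}, {72, 9, 3, 1}, "f32") << "\n";
        // NHWC strides, as in the conv_nhwc test added below:
        std::cout << format_shaped({2, 8, 3, 3}, {72, 1, 24, 8}, "f32") << "\n";
    }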
@@ -507,11 +506,7 @@ struct mlir_program
     mlir_operation_state& add_results(const std::vector<shape>& outputs)
     {
-        std::vector<shape> reshaped(outputs.size());
-        std::transform(outputs.begin(), outputs.end(), reshaped.begin(), [](const shape& r) {
-            return shape{r.type(), r.lens()};
-        });
-        auto x = prog->make_tensors(reshaped);
+        auto x = prog->make_mlir_shapeds(outputs);
         if(not x.empty())
         {
             mlirOperationStateAddResults(&op_state, x.size(), x.data());
@@ -584,7 +579,7 @@ struct mlir_program
         std::vector<shape> outputs = m.get_output_shapes();
         std::vector<MlirLocation> arg_locs(inputs.size(), location);
-        auto body_inputs   = make_tensors(inputs);
+        auto body_inputs   = make_mlir_shapeds(inputs);
         mlir_region region = mlirRegionCreate();
         mlir_block fbody = mlirBlockCreate(body_inputs.size(), body_inputs.data(), arg_locs.data());
         MlirBlock result = fbody.get();
@@ -610,7 +605,7 @@ struct mlir_program
             return "func.return";
         if(ins->name() == "@literal")
         {
-            return "tosa.const";
+            return "migraphx.literal";
         }
         return "migraphx." + ins->name();
     }
@@ -669,7 +664,8 @@ struct mlir_program
         if(ins->name() == "@literal")
         {
             literal r = ins->get_literal();
-            MlirType tensor_type = make_tensor(ins->get_shape());
+            MlirType shaped_type = make_mlir_shaped(ins->get_shape());
+            MlirType tensor_type = rocmlirMIXRShapedTypeAsTensor(shaped_type);
             MlirAttribute mlir_value_attr =
                 mlirDenseElementsAttrRawBufferGet(tensor_type, r.get_shape().bytes(), r.data());
             ops.add_attributes({{"value", mlir_value_attr}});
@@ -947,35 +943,7 @@ void adjust_param_shapes(module& m, const std::vector<shape>& inputs)
         auto param = m.get_parameter(name);
-        if(input.standard())
-            continue;
-        auto lens    = input.lens();
-        auto strides = input.strides();
-        std::vector<operation> ops;
-        if(input.transposed())
-        {
-            auto perm  = find_permutation(input);
-            auto iperm = invert_permutation(perm);
-            lens       = reorder_dims(lens, iperm);
-            strides    = reorder_dims(strides, iperm);
-            ops.push_back(make_op("transpose", {{"permutation", perm}}));
-        }
-        if(input.broadcasted())
-        {
-            std::transform(lens.begin(),
-                           lens.end(),
-                           strides.begin(),
-                           lens.begin(),
-                           [](auto len, auto stride) -> std::size_t {
-                               if(stride == 0)
-                                   return 1;
-                               return len;
-                           });
-            ops.push_back(make_op("multibroadcast", {{"out_lens", input.lens()}}));
-        }
-        auto new_param = std::accumulate(
-            ops.begin(),
-            ops.end(),
-            m.add_parameter(name + ".0", shape{input.type(), lens}),
-            [&](auto x, auto op) { return m.insert_instruction(param, op, x); });
+        auto new_param = m.add_parameter(name + ".0", input);
         m.replace_instruction(param, new_param);
         m.remove_instruction(param);
     }
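The deleted transpose/multibroadcast rewriting existed only because ranked tensor types could not carry strides. With the shaped type they can, so a non-standard input becomes a parameter directly. A short sketch under that assumption (header paths per the MIGraphX source tree; the function name is hypothetical):

    #include <migraphx/module.hpp>
    #include <migraphx/shape.hpp>

    // An NHWC-strided input (as in the conv_nhwc test below) is now representable
    // as-is; the removed code instead rebuilt it from a standard parameter plus
    // transpose/multibroadcast instructions.
    void add_nhwc_param(migraphx::module& m)
    {
        migraphx::shape input{migraphx::shape::float_type, {1, 8, 4, 4}, {128, 1, 32, 8}};
        m.add_parameter("x.0", input); // strides survive into the MLIR type
    }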
src/targets/gpu/pad.cpp (deleted, 100644 → 0)
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/gpu/pad.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device/pad.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

shape hip_pad::compute_shape(std::vector<shape> inputs) const
{
    inputs.pop_back();
    check_shapes{inputs, *this}.has(1).standard();
    return op.compute_shape(inputs);
}

argument hip_pad::compute(context& ctx, const shape&, const std::vector<argument>& args) const
{
    return device::pad(ctx.get_stream().get(), args.back(), args.front(), op.value, op.pads);
}

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
src/targets/gpu/target.cpp
@@ -52,7 +52,6 @@
 #include <migraphx/simplify_qdq.hpp>
 #include <migraphx/simplify_reshapes.hpp>
 #include <migraphx/split_single_dyn_dim.hpp>
-#include <migraphx/eliminate_fp8.hpp>
 #include <migraphx/gpu/allocation_model.hpp>
 #include <migraphx/gpu/compile_miopen.hpp>
 #include <migraphx/gpu/compile_ops.hpp>
@@ -150,7 +149,7 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
         prefuse_ops{},
         dead_code_elimination{},
         auto_contiguous{},
-        eliminate_fp8{unsupported_fp8_ops},
+        eliminate_data_type{{migraphx::shape::fp8e4m3fnuz_type}, shape::float_type, unsupported_fp8_ops},
         dead_code_elimination{},
         optimize_module{},
         fuse_pointwise{},
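The fp8-specific pass gives way to the generic eliminate_data_type pass, parameterized with a type list, a target type, and the set of ops that need the fallback. A hedged sketch of the same construction in isolation (the header path and the std::set<std::string> op-list type are assumptions here):

    #include <migraphx/eliminate_data_type.hpp> // header path assumed from the pass name
    #include <migraphx/shape.hpp>
    #include <set>
    #include <string>

    // Rewrites fp8e4m3fnuz computation to float for the ops named in
    // unsupported_fp8_ops, matching the pass-list entry above.
    migraphx::eliminate_data_type
    make_fp8_fallback_pass(const std::set<std::string>& unsupported_fp8_ops)
    {
        return {{migraphx::shape::fp8e4m3fnuz_type}, migraphx::shape::float_type, unsupported_fp8_ops};
    }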
test/gpu/mlir.cpp
@@ -141,9 +141,9 @@ TEST_CASE(conv)
 {
     const std::string mlir_output = R"__migraphx__(
 module {
-  func.func @mlir_convolution(%arg0: tensor<2x8x3x3xf32>, %arg1: tensor<1x8x4x4xf32>) -> tensor<1x2x2x2xf32> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
-    %0 = migraphx.convolution(%arg1, %arg0) {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : (tensor<1x8x4x4xf32>, tensor<2x8x3x3xf32>) -> tensor<1x2x2x2xf32>
-    return %0 : tensor<1x2x2x2xf32>
+  func.func @mlir_convolution(%arg0: !migraphx.shaped<2x8x3x3xf32, 72x9x3x1>, %arg1: !migraphx.shaped<1x8x4x4xf32, 128x16x4x1>) -> !migraphx.shaped<1x2x2x2xf32, 8x4x2x1> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
+    %0 = migraphx.convolution %arg1, %arg0 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x8x4x4xf32, 128x16x4x1>, <2x8x3x3xf32, 72x9x3x1> -> <1x2x2x2xf32, 8x4x2x1>
+    return %0 : !migraphx.shaped<1x2x2x2xf32, 8x4x2x1>
   }
 }
 )__migraphx__";
@@ -160,15 +160,38 @@
     EXPECT(verify_mlir(m));
 }
 
+TEST_CASE(conv_nhwc)
+{
+    const std::string mlir_output = R"__migraphx__(
+module {
+  func.func @mlir_convolution(%arg0: !migraphx.shaped<2x8x3x3xf32, 72x1x24x8>, %arg1: !migraphx.shaped<1x8x4x4xf32, 128x1x32x8>) -> !migraphx.shaped<1x2x2x2xf32, 8x1x4x2> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
+    %0 = migraphx.convolution %arg1, %arg0 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x8x4x4xf32, 128x1x32x8>, <2x8x3x3xf32, 72x1x24x8> -> <1x2x2x2xf32, 8x1x4x2>
+    return %0 : !migraphx.shaped<1x2x2x2xf32, 8x1x4x2>
+  }
+}
+)__migraphx__";
+    migraphx::module m;
+    auto x    = m.add_parameter("x", {migraphx::shape::float_type, {1, 8, 4, 4}, {128, 1, 32, 8}});
+    auto w    = m.add_parameter("w", {migraphx::shape::float_type, {2, 8, 3, 3}, {72, 1, 24, 8}});
+    auto conv = m.add_instruction(migraphx::make_op("convolution"), x, w);
+    m.add_return({conv});
+    auto s = migraphx::gpu::dump_mlir(m);
+    // Skip test if MLIR is not enabled
+    if(s.empty())
+        return;
+    CHECK(encode(s) == encode(mlir_output));
+    EXPECT(verify_mlir(m));
+}
+
 TEST_CASE(conv_add_relu)
 {
     const std::string mlir_output = R"__migraphx__(
 module {
-  func.func @mlir_convolution_add_relu(%arg0: tensor<1x2x2x2xf32>, %arg1: tensor<2x8x3x3xf32>, %arg2: tensor<1x8x4x4xf32>) -> tensor<1x2x2x2xf32> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
-    %0 = migraphx.convolution(%arg2, %arg1) {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : (tensor<1x8x4x4xf32>, tensor<2x8x3x3xf32>) -> tensor<1x2x2x2xf32>
-    %1 = migraphx.add(%0, %arg0) : (tensor<1x2x2x2xf32>, tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32>
-    %2 = migraphx.relu(%1) : (tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32>
-    return %2 : tensor<1x2x2x2xf32>
+  func.func @mlir_convolution_add_relu(%arg0: !migraphx.shaped<1x2x2x2xf32, 8x4x2x1>, %arg1: !migraphx.shaped<2x8x3x3xf32, 72x9x3x1>, %arg2: !migraphx.shaped<1x8x4x4xf32, 128x16x4x1>) -> !migraphx.shaped<1x2x2x2xf32, 8x4x2x1> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
+    %0 = migraphx.convolution %arg2, %arg1 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x8x4x4xf32, 128x16x4x1>, <2x8x3x3xf32, 72x9x3x1> -> <1x2x2x2xf32, 8x4x2x1>
+    %1 = migraphx.add %0, %arg0 : <1x2x2x2xf32, 8x4x2x1>, <1x2x2x2xf32, 8x4x2x1> -> <1x2x2x2xf32, 8x4x2x1>
+    %2 = migraphx.relu %1 : <1x2x2x2xf32, 8x4x2x1> -> <1x2x2x2xf32, 8x4x2x1>
+    return %2 : !migraphx.shaped<1x2x2x2xf32, 8x4x2x1>
   }
 }
 )__migraphx__";
@@ -192,10 +215,10 @@ TEST_CASE(quant_dot_add)
 {
     const std::string mlir_output = R"__migraphx__(
 module {
-  func.func @mlir_quant_dot_add(%arg0: tensor<1x5x4xi8>, %arg1: tensor<1x4x3xi8>, %arg2: tensor<1x5x3xi32>) -> tensor<1x5x3xi32> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
-    %0 = migraphx.quant_dot(%arg0, %arg1) : (tensor<1x5x4xi8>, tensor<1x4x3xi8>) -> tensor<1x5x3xi32>
-    %1 = migraphx.add(%0, %arg2) : (tensor<1x5x3xi32>, tensor<1x5x3xi32>) -> tensor<1x5x3xi32>
-    return %1 : tensor<1x5x3xi32>
+  func.func @mlir_quant_dot_add(%arg0: !migraphx.shaped<1x5x4xi8, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xi8, 12x3x1>, %arg2: !migraphx.shaped<1x5x3xi32, 15x3x1>) -> !migraphx.shaped<1x5x3xi32, 15x3x1> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
+    %0 = migraphx.quant_dot %arg0, %arg1 : <1x5x4xi8, 20x4x1>, <1x4x3xi8, 12x3x1> -> <1x5x3xi32, 15x3x1>
+    %1 = migraphx.add %0, %arg2 : <1x5x3xi32, 15x3x1>, <1x5x3xi32, 15x3x1> -> <1x5x3xi32, 15x3x1>
+    return %1 : !migraphx.shaped<1x5x3xi32, 15x3x1>
   }
 }
 )__migraphx__";
@@ -219,10 +242,10 @@ TEST_CASE(dot_add)
 {
     const std::string mlir_output = R"__migraphx__(
 module {
-  func.func @mlir_dot_add(%arg0: tensor<1x5x4xf32>, %arg1: tensor<1x4x3xf32>, %arg2: tensor<1x5x3xf32>) -> tensor<1x5x3xf32> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
-    %0 = migraphx.dot(%arg0, %arg1) : (tensor<1x5x4xf32>, tensor<1x4x3xf32>) -> tensor<1x5x3xf32>
-    %1 = migraphx.add(%0, %arg2) : (tensor<1x5x3xf32>, tensor<1x5x3xf32>) -> tensor<1x5x3xf32>
-    return %1 : tensor<1x5x3xf32>
+  func.func @mlir_dot_add(%arg0: !migraphx.shaped<1x5x4xf32, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xf32, 12x3x1>, %arg2: !migraphx.shaped<1x5x3xf32, 15x3x1>) -> !migraphx.shaped<1x5x3xf32, 15x3x1> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
+    %0 = migraphx.dot %arg0, %arg1 : <1x5x4xf32, 20x4x1>, <1x4x3xf32, 12x3x1> -> <1x5x3xf32, 15x3x1>
+    %1 = migraphx.add %0, %arg2 : <1x5x3xf32, 15x3x1>, <1x5x3xf32, 15x3x1> -> <1x5x3xf32, 15x3x1>
+    return %1 : !migraphx.shaped<1x5x3xf32, 15x3x1>
   }
 }
 )__migraphx__";
@@ -245,11 +268,11 @@ TEST_CASE(conv_int8_dequantize_quantize)
 {
     const std::string mlir_output = R"__migraphx__(
 module {
-  func.func @mlir_quant_convolution_dequantizelinear_quantizelinear(%arg0: tensor<2x8x3x3xi8>, %arg1: tensor<1x8x4x4xi8>, %arg2: tensor<1x2x2x2xf32>, %arg3: tensor<1x2x2x2xi32>) -> tensor<1x2x2x2xi32> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
-    %0 = migraphx.quant_convolution(%arg1, %arg0) {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : (tensor<1x8x4x4xi8>, tensor<2x8x3x3xi8>) -> tensor<1x2x2x2xi32>
-    %1 = migraphx.dequantizelinear(%0, %arg2, %arg3) : (tensor<1x2x2x2xi32>, tensor<1x2x2x2xf32>, tensor<1x2x2x2xi32>) -> tensor<1x2x2x2xf32>
-    %2 = migraphx.quantizelinear(%1, %arg2, %arg3) : (tensor<1x2x2x2xf32>, tensor<1x2x2x2xf32>, tensor<1x2x2x2xi32>) -> tensor<1x2x2x2xi32>
-    return %2 : tensor<1x2x2x2xi32>
+  func.func @mlir_quant_convolution_dequantizelinear_quantizelinear(%arg0: !migraphx.shaped<2x8x3x3xi8, 72x9x3x1>, %arg1: !migraphx.shaped<1x8x4x4xi8, 128x16x4x1>, %arg2: !migraphx.shaped<1x2x2x2xf32, 8x4x2x1>, %arg3: !migraphx.shaped<1x2x2x2xi32, 8x4x2x1>) -> !migraphx.shaped<1x2x2x2xi32, 8x4x2x1> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
+    %0 = migraphx.quant_convolution %arg1, %arg0 {dilation = [1, 1], group = 1 : i64, padding = [0, 0, 0, 0], padding_mode = 0 : i64, stride = [1, 1]} : <1x8x4x4xi8, 128x16x4x1>, <2x8x3x3xi8, 72x9x3x1> -> <1x2x2x2xi32, 8x4x2x1>
+    %1 = migraphx.dequantizelinear %0, %arg2, %arg3 : <1x2x2x2xi32, 8x4x2x1>, <1x2x2x2xf32, 8x4x2x1>, !migraphx.shaped<1x2x2x2xi32, 8x4x2x1> -> <1x2x2x2xf32, 8x4x2x1>
+    %2 = migraphx.quantizelinear %1, %arg2, %arg3 : <1x2x2x2xf32, 8x4x2x1>, <1x2x2x2xf32, 8x4x2x1>, !migraphx.shaped<1x2x2x2xi32, 8x4x2x1> -> <1x2x2x2xi32, 8x4x2x1>
+    return %2 : !migraphx.shaped<1x2x2x2xi32, 8x4x2x1>
   }
 }
 )__migraphx__";
@@ -278,10 +301,10 @@ TEST_CASE(dot_convert)
 {
     const std::string mlir_output = R"__migraphx__(
 module {
-  func.func @mlir_dot_convert(%arg0: tensor<1x5x4xf32>, %arg1: tensor<1x4x3xf32>) -> tensor<1x5x3xf16> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
-    %0 = migraphx.dot(%arg0, %arg1) : (tensor<1x5x4xf32>, tensor<1x4x3xf32>) -> tensor<1x5x3xf32>
-    %1 = migraphx.convert(%0) {target_type = 1 : i64} : (tensor<1x5x3xf32>) -> tensor<1x5x3xf16>
-    return %1 : tensor<1x5x3xf16>
+  func.func @mlir_dot_convert(%arg0: !migraphx.shaped<1x5x4xf32, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xf32, 12x3x1>) -> !migraphx.shaped<1x5x3xf16, 15x3x1> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
+    %0 = migraphx.dot %arg0, %arg1 : <1x5x4xf32, 20x4x1>, <1x4x3xf32, 12x3x1> -> <1x5x3xf32, 15x3x1>
+    %1 = migraphx.convert %0 {target_type = 1 : i64} : <1x5x3xf32, 15x3x1> to <1x5x3xf16, 15x3x1>
+    return %1 : !migraphx.shaped<1x5x3xf16, 15x3x1>
   }
 }
 )__migraphx__";
@@ -304,10 +327,10 @@ TEST_CASE(dot_where)
 {
     const std::string mlir_output = R"__migraphx__(
 module {
-  func.func @mlir_dot_where(%arg0: tensor<1x5x4xf32>, %arg1: tensor<1x4x3xf32>, %arg2: tensor<1x5x3xi8>, %arg3: tensor<1x5x3xf32>) -> tensor<1x5x3xf32> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
-    %0 = migraphx.dot(%arg0, %arg1) : (tensor<1x5x4xf32>, tensor<1x4x3xf32>) -> tensor<1x5x3xf32>
-    %1 = migraphx.where(%arg2, %0, %arg3) : (tensor<1x5x3xi8>, tensor<1x5x3xf32>, tensor<1x5x3xf32>) -> tensor<1x5x3xf32>
-    return %1 : tensor<1x5x3xf32>
+  func.func @mlir_dot_where(%arg0: !migraphx.shaped<1x5x4xf32, 20x4x1>, %arg1: !migraphx.shaped<1x4x3xf32, 12x3x1>, %arg2: !migraphx.shaped<1x5x3xi8, 15x3x1>, %arg3: !migraphx.shaped<1x5x3xf32, 15x3x1>) -> !migraphx.shaped<1x5x3xf32, 15x3x1> attributes {arch = "", kernel = "mixr", num_cu = 0 : i64} {
+    %0 = migraphx.dot %arg0, %arg1 : <1x5x4xf32, 20x4x1>, <1x4x3xf32, 12x3x1> -> <1x5x3xf32, 15x3x1>
+    %1 = migraphx.where %arg2, %0, %arg3 : <1x5x3xi8, 15x3x1>, <1x5x3xf32, 15x3x1>, <1x5x3xf32, 15x3x1> -> <1x5x3xf32, 15x3x1>
+    return %1 : !migraphx.shaped<1x5x3xf32, 15x3x1>
   }
 }
 )__migraphx__";
test/onnx/.onnxrt-commit

@@ -1 +1 @@
-a5537f2f563d4975c7e6121a7eb260bbbfd9455a
+d69842226b47e5336568103541b071447caeb9bf
test/verify/gemm_2args_mm_8.cpp

@@ -48,5 +48,5 @@ struct gemm_2args_mm_8 : verify_program<gemm_2args_mm_8<DType>>
 };
 
 template struct gemm_2args_mm_8<migraphx::shape::float_type>;
-template struct gemm_2args_mm_8<migraphx::shape::half_type>;
+// template struct gemm_2args_mm_8<migraphx::shape::half_type>;
 template struct gemm_2args_mm_8<migraphx::shape::fp8e4m3fnuz_type>;
test/verify/gemm_add_broadcast2.cpp

@@ -51,5 +51,5 @@ struct gemm_add_broadcast2 : verify_program<gemm_add_broadcast2<DType>>
 };
 
 template struct gemm_add_broadcast2<migraphx::shape::float_type>;
-template struct gemm_add_broadcast2<migraphx::shape::half_type>;
+// template struct gemm_add_broadcast2<migraphx::shape::half_type>;
 template struct gemm_add_broadcast2<migraphx::shape::fp8e4m3fnuz_type>;
tools/format.py

@@ -63,7 +63,8 @@ def clang_format(against, apply=False, path=CLANG_FORMAT_PATH):
         print(f"{git_clang_format} not installed. Skipping format.")
         return
     diff_flag = "" if apply else "--diff"
-    run(f"{git_clang_format} --binary {clang_format} {diff_flag} {base}")
+    run(f"{git_clang_format} --extensions c,cpp,hpp,h,cl,hip,in --binary {clang_format} {diff_flag} {base}"
+        )


 def get_files_changed(against, ext=('py')):