Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
74bd6d61
"vscode:/vscode.git/clone" did not exist on "627d8ef35a6da8ad268b5197e3045ccdfb4ac684"
Commit
74bd6d61
authored
Sep 17, 2022
by
Paul
Browse files
Merge branch 'jit-concat' into jit-concat-pointwise
parents
b30c3408
8109aac8
Changes
44
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
111 additions
and
91 deletions
+111
-91
CMakeLists.txt
CMakeLists.txt
+1
-1
Dockerfile
Dockerfile
+1
-1
doc/src/reference/py.rst
doc/src/reference/py.rst
+16
-1
src/include/migraphx/op/fmod.hpp
src/include/migraphx/op/fmod.hpp
+0
-9
src/include/migraphx/op/mod.hpp
src/include/migraphx/op/mod.hpp
+0
-9
src/py/migraphx_py.cpp
src/py/migraphx_py.cpp
+6
-5
src/simplify_algebra.cpp
src/simplify_algebra.cpp
+18
-3
src/simplify_reshapes.cpp
src/simplify_reshapes.cpp
+39
-0
src/targets/gpu/CMakeLists.txt
src/targets/gpu/CMakeLists.txt
+4
-19
src/targets/gpu/compile_gen.cpp
src/targets/gpu/compile_gen.cpp
+10
-3
src/targets/gpu/fuse_ops.cpp
src/targets/gpu/fuse_ops.cpp
+2
-5
src/targets/gpu/gemm_impl.cpp
src/targets/gpu/gemm_impl.cpp
+6
-1
src/targets/gpu/include/migraphx/gpu/gather.hpp
src/targets/gpu/include/migraphx/gpu/gather.hpp
+1
-1
src/targets/gpu/include/migraphx/gpu/int8_conv_pack.hpp
src/targets/gpu/include/migraphx/gpu/int8_conv_pack.hpp
+0
-1
src/targets/gpu/include/migraphx/gpu/int8_gemm_pack.hpp
src/targets/gpu/include/migraphx/gpu/int8_gemm_pack.hpp
+0
-1
src/targets/gpu/include/migraphx/gpu/logsoftmax.hpp
src/targets/gpu/include/migraphx/gpu/logsoftmax.hpp
+2
-14
src/targets/gpu/include/migraphx/gpu/lrn.hpp
src/targets/gpu/include/migraphx/gpu/lrn.hpp
+1
-1
src/targets/gpu/include/migraphx/gpu/prefuse_ops.hpp
src/targets/gpu/include/migraphx/gpu/prefuse_ops.hpp
+1
-1
src/targets/gpu/include/migraphx/gpu/reverse.hpp
src/targets/gpu/include/migraphx/gpu/reverse.hpp
+1
-1
src/targets/gpu/include/migraphx/gpu/softmax.hpp
src/targets/gpu/include/migraphx/gpu/softmax.hpp
+2
-14
No files found.
CMakeLists.txt
View file @
74bd6d61
...
@@ -63,7 +63,7 @@ set(CMAKE_EXTRA_INCLUDE_FILES)
...
@@ -63,7 +63,7 @@ set(CMAKE_EXTRA_INCLUDE_FILES)
include
(
ROCMSetupVersion
)
include
(
ROCMSetupVersion
)
rocm_setup_version
(
VERSION 2.
3
)
rocm_setup_version
(
VERSION 2.
4
)
set
(
MIGRAPHX_SO_VERSION
${
PROJECT_VERSION_MAJOR
}
.
${
PROJECT_VERSION_MINOR
}
)
set
(
MIGRAPHX_SO_VERSION
${
PROJECT_VERSION_MAJOR
}
.
${
PROJECT_VERSION_MINOR
}
)
option
(
BUILD_SHARED_LIBS
"Build as a shared library"
ON
)
option
(
BUILD_SHARED_LIBS
"Build as a shared library"
ON
)
...
...
Dockerfile
View file @
74bd6d61
...
@@ -86,7 +86,7 @@ RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXR
...
@@ -86,7 +86,7 @@ RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXR
ADD
tools/build_and_test_onnxrt.sh /onnxruntime/build_and_test_onnxrt.sh
ADD
tools/build_and_test_onnxrt.sh /onnxruntime/build_and_test_onnxrt.sh
RUN
cget
-p
/usr/local
install
ROCmSoftwarePlatform/llvm-project-mlir@
d2cb9e580550e92ab75a0a417e7a4abd02a24ed
f
-DBUILD_MIXR_TARGET
=
On
RUN
cget
-p
/usr/local
install
ROCmSoftwarePlatform/llvm-project-mlir@
e8e77eb16be413d301ea8509726d47f265d9011
f
-DBUILD_MIXR_TARGET
=
On
ENV
MIOPEN_FIND_DB_PATH=/tmp/miopen/find-db
ENV
MIOPEN_FIND_DB_PATH=/tmp/miopen/find-db
ENV
MIOPEN_USER_DB_PATH=/tmp/miopen/user-db
ENV
MIOPEN_USER_DB_PATH=/tmp/miopen/user-db
...
...
doc/src/reference/py.rst
View file @
74bd6d61
...
@@ -84,6 +84,12 @@ argument
...
@@ -84,6 +84,12 @@ argument
Construct an argument from a python buffer. This can include numpy arrays.
Construct an argument from a python buffer. This can include numpy arrays.
.. py:method:: data_ptr()
Returns the address to the underlying argument data.
:rtype: int
.. py:method:: get_shape()
.. py:method:: get_shape()
Returns the shape of the argument.
Returns the shape of the argument.
...
@@ -113,7 +119,16 @@ argument
...
@@ -113,7 +119,16 @@ argument
:param shape s: Shape of argument to fill.
:param shape s: Shape of argument to fill.
:param int value: Value to fill in the argument.
:param int value: Value to fill in the argument.
:rtype argument
:rtype: argument
.. py:function:: argument_from_pointer(shape, address)
Create argument from data stored in given address without copy.
:param shape shape: Shape of the data stored in address.
:param long address: Memory address of data from another source
:rtype: argument
target
target
------
------
...
...
src/include/migraphx/op/fmod.hpp
View file @
74bd6d61
...
@@ -24,17 +24,8 @@
...
@@ -24,17 +24,8 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_FMOD_HPP
#ifndef MIGRAPHX_GUARD_OPERATORS_FMOD_HPP
#define MIGRAPHX_GUARD_OPERATORS_FMOD_HPP
#define MIGRAPHX_GUARD_OPERATORS_FMOD_HPP
#include <array>
#include <migraphx/op/binary.hpp>
#include <migraphx/op/binary.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <cmath>
#include <utility>
#include <type_traits>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
src/include/migraphx/op/mod.hpp
View file @
74bd6d61
...
@@ -24,17 +24,8 @@
...
@@ -24,17 +24,8 @@
#ifndef MIGRAPHX_GUARD_OPERATORS_MOD_HPP
#ifndef MIGRAPHX_GUARD_OPERATORS_MOD_HPP
#define MIGRAPHX_GUARD_OPERATORS_MOD_HPP
#define MIGRAPHX_GUARD_OPERATORS_MOD_HPP
#include <array>
#include <migraphx/op/binary.hpp>
#include <migraphx/op/binary.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/streamutils.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <cmath>
#include <cmath>
#include <utility>
#include <type_traits>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
src/py/migraphx_py.cpp
View file @
74bd6d61
...
@@ -264,12 +264,13 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
...
@@ -264,12 +264,13 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)
py
::
class_
<
migraphx
::
argument
>
(
m
,
"argument"
,
py
::
buffer_protocol
())
py
::
class_
<
migraphx
::
argument
>
(
m
,
"argument"
,
py
::
buffer_protocol
())
.
def_buffer
([](
migraphx
::
argument
&
x
)
->
py
::
buffer_info
{
return
to_buffer_info
(
x
);
})
.
def_buffer
([](
migraphx
::
argument
&
x
)
->
py
::
buffer_info
{
return
to_buffer_info
(
x
);
})
.
def
(
"__init__"
,
.
def
(
py
::
init
([](
py
::
buffer
b
)
{
[](
migraphx
::
argument
&
x
,
py
::
buffer
b
)
{
py
::
buffer_info
info
=
b
.
request
();
py
::
buffer_info
info
=
b
.
request
();
return
migraphx
::
argument
(
to_shape
(
info
),
info
.
ptr
);
new
(
&
x
)
migraphx
::
argument
(
to_shape
(
info
),
info
.
ptr
);
}))
})
.
def
(
"get_shape"
,
&
migraphx
::
argument
::
get_shape
)
.
def
(
"get_shape"
,
&
migraphx
::
argument
::
get_shape
)
.
def
(
"data_ptr"
,
[](
migraphx
::
argument
&
x
)
{
return
reinterpret_cast
<
std
::
uintptr_t
>
(
x
.
data
());
})
.
def
(
"tolist"
,
.
def
(
"tolist"
,
[](
migraphx
::
argument
&
x
)
{
[](
migraphx
::
argument
&
x
)
{
py
::
list
l
{
x
.
get_shape
().
elements
()};
py
::
list
l
{
x
.
get_shape
().
elements
()};
...
...
src/simplify_algebra.cpp
View file @
74bd6d61
...
@@ -1001,20 +1001,35 @@ struct find_split_reshape
...
@@ -1001,20 +1001,35 @@ struct find_split_reshape
auto
rsp_lens
=
rsp
->
get_shape
().
lens
();
auto
rsp_lens
=
rsp
->
get_shape
().
lens
();
auto
rsp_strides
=
rsp
->
get_shape
().
strides
();
auto
rsp_strides
=
rsp
->
get_shape
().
strides
();
rsp_strides
.
insert
(
rsp_strides
.
begin
(),
rsp_strides
[
0
]
*
rsp_lens
[
0
]);
rsp_strides
.
insert
(
rsp_strides
.
begin
(),
rsp_strides
[
0
]
*
rsp_lens
[
0
]);
auto
ait
=
std
::
find
(
rsp_strides
.
begin
(),
rsp_strides
.
end
(),
slc_dim_size
);
auto
ait
=
std
::
find
(
rsp_strides
.
begin
(),
rsp_strides
.
end
(),
slc_dim_size
);
int
rsp_axis
=
-
1
;
if
(
ait
==
rsp_strides
.
end
())
if
(
ait
==
rsp_strides
.
end
())
{
{
return
;
return
;
}
}
int
rsp_axis
=
std
::
distance
(
rsp_strides
.
begin
(),
ait
);
else
if
(
ait
==
rsp_strides
.
end
()
-
1
)
{
// edge case
// slice_dim == 1, in that case it could match with last stride of 1.
// it should accumulate lengths from last dim in that case. discount 1 to avoid going
// out of bounds.
assert
(
slc_dim_size
==
1
);
rsp_axis
=
std
::
distance
(
rsp_strides
.
begin
(),
ait
)
-
1
;
}
else
{
rsp_axis
=
std
::
distance
(
rsp_strides
.
begin
(),
ait
);
}
// calculate reshape output shape
// calculate reshape output shape
std
::
vector
<
int64_t
>
vec_dims
(
vec_rsp
.
size
());
std
::
vector
<
int64_t
>
vec_dims
(
vec_rsp
.
size
());
std
::
transform
(
vec_rsp
.
begin
(),
vec_rsp
.
end
(),
vec_dims
.
begin
(),
[
&
](
auto
is
)
{
std
::
transform
(
vec_rsp
.
begin
(),
vec_rsp
.
end
(),
vec_dims
.
begin
(),
[
&
](
auto
is
)
{
return
is
->
get_shape
().
lens
()[
rsp_axis
];
return
is
->
get_shape
().
lens
()[
rsp_axis
];
});
});
std
::
vector
<
int64_t
>
rsp_out_lens
(
rsp_lens
.
begin
(),
rsp_lens
.
end
());
std
::
vector
<
int64_t
>
rsp_out_lens
(
rsp_lens
.
begin
(),
rsp_lens
.
end
());
rsp_out_lens
[
rsp_axis
]
=
std
::
accumulate
(
vec_dims
.
begin
(),
vec_dims
.
end
(),
std
::
int64_t
{
0
});
rsp_out_lens
[
rsp_axis
]
=
std
::
accumulate
(
vec_dims
.
begin
(),
vec_dims
.
end
(),
std
::
int64_t
{
0
});
// insert the reshape instruction and add contiguous if needed
// insert the reshape instruction and add contiguous if needed
...
...
src/simplify_reshapes.cpp
View file @
74bd6d61
...
@@ -271,6 +271,44 @@ struct find_nested_slice
...
@@ -271,6 +271,44 @@ struct find_nested_slice
}
}
};
};
struct
find_concat_multibroadcasts
{
auto
matcher
()
const
{
return
match
::
name
(
"concat"
)(
match
::
all_of
[
match
::
inputs
()](
match
::
name
(
"multibroadcast"
)));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
mr
)
const
{
auto
ins
=
mr
.
result
;
auto
op
=
any_cast
<
op
::
concat
>
(
ins
->
get_operator
());
auto
out_lens
=
ins
->
get_shape
().
lens
();
auto
inputs
=
ins
->
inputs
();
auto
in_strides
=
inputs
.
front
()
->
get_shape
().
strides
();
// Only apply when concat axis is not a broadcasted dimension
if
(
std
::
any_of
(
inputs
.
begin
(),
inputs
.
end
(),
[
&
](
auto
i
)
{
return
i
->
get_shape
().
strides
()[
op
.
axis
]
==
0
;
}))
{
return
;
}
// Use inputs of multibroadcast ops as inputs to new concat op
std
::
transform
(
inputs
.
begin
(),
inputs
.
end
(),
inputs
.
begin
(),
[](
auto
i
)
{
return
i
->
inputs
().
front
();
});
// Reduce axis by number of leading broadcasted dimensions
if
(
inputs
.
front
()
->
get_shape
().
lens
().
size
()
<
out_lens
.
size
())
op
.
axis
-=
std
::
count
(
in_strides
.
begin
(),
in_strides
.
begin
()
+
op
.
axis
,
0
);
auto
concat
=
m
.
insert_instruction
(
ins
,
op
,
inputs
);
m
.
replace_instruction
(
ins
,
migraphx
::
make_op
(
"multibroadcast"
,
{{
"out_lens"
,
out_lens
}}),
concat
);
}
};
struct
find_concat_transpose
struct
find_concat_transpose
{
{
auto
matcher
()
const
auto
matcher
()
const
...
@@ -764,6 +802,7 @@ void simplify_reshapes::apply(module& m) const
...
@@ -764,6 +802,7 @@ void simplify_reshapes::apply(module& m) const
find_reshaper
{},
find_reshaper
{},
find_transpose
{},
find_transpose
{},
find_concat_transpose
{},
find_concat_transpose
{},
find_concat_multibroadcasts
{},
find_nested_convert
{},
find_nested_convert
{},
find_nested_slice
{},
find_nested_slice
{},
find_nested_concat
{},
find_nested_concat
{},
...
...
src/targets/gpu/CMakeLists.txt
View file @
74bd6d61
...
@@ -322,26 +322,11 @@ message(STATUS "extractkernel: ${MIGRAPHX_EXTRACT_KERNEL}")
...
@@ -322,26 +322,11 @@ message(STATUS "extractkernel: ${MIGRAPHX_EXTRACT_KERNEL}")
set
(
MIGRAPHX_ENABLE_MLIR OFF CACHE BOOL
""
)
set
(
MIGRAPHX_ENABLE_MLIR OFF CACHE BOOL
""
)
if
(
MIGRAPHX_ENABLE_MLIR
)
if
(
MIGRAPHX_ENABLE_MLIR
)
find_library
(
MLIRAPI_LIBRARY MLIRMIOpen
# Find package rocMLIR
PATH_SUFFIXES
find_package
(
rocMLIR 1.0.0 CONFIG REQUIRED
)
# Workaournd broken mlir install
message
(
STATUS
"Build with rocMLIR::rockCompiler
${
rocMLIR_VERSION
}
"
)
lib/ lib/lib
)
# REQUIRED is not supported before cmake 3.18
if
(
NOT MLIRAPI_LIBRARY
)
message
(
FATAL_ERROR
"libMLIRMIOpen not found"
)
else
()
message
(
STATUS
"Build with libMLIRMIOpen: "
${
MLIRAPI_LIBRARY
}
)
endif
()
find_path
(
MLIRAPI_HEADERS NAMES mlir-c/Dialect/MIGraphX.h
)
# Workaround MLIR broken installation
find_path
(
MLIRAPI_HEADERS2 NAMES mlir-c/Registration.h
PATH_SUFFIXES
include/external/include external/include
)
target_compile_definitions
(
migraphx_gpu PRIVATE
"-DMIGRAPHX_MLIR"
)
target_compile_definitions
(
migraphx_gpu PRIVATE
"-DMIGRAPHX_MLIR"
)
target_include_directories
(
migraphx_gpu SYSTEM PRIVATE
${
MLIRAPI_HEADERS
}
${
MLIRAPI_HEADERS2
}
)
target_link_libraries
(
migraphx_gpu PUBLIC rocMLIR::rockCompiler
)
target_link_libraries
(
migraphx_gpu PUBLIC
${
MLIRAPI_LIBRARY
}
)
endif
()
endif
()
set
(
MIGRAPHX_USE_HIPRTC OFF CACHE BOOL
""
)
set
(
MIGRAPHX_USE_HIPRTC OFF CACHE BOOL
""
)
...
...
src/targets/gpu/compile_gen.cpp
View file @
74bd6d61
...
@@ -61,12 +61,19 @@ vectorize vectorize::elements(std::size_t axis, const std::vector<shape>& inputs
...
@@ -61,12 +61,19 @@ vectorize vectorize::elements(std::size_t axis, const std::vector<shape>& inputs
[
&
](
const
auto
&
input
)
->
std
::
size_t
{
[
&
](
const
auto
&
input
)
->
std
::
size_t
{
auto
stride
=
input
.
strides
()[
axis
];
auto
stride
=
input
.
strides
()[
axis
];
auto
len
=
input
.
lens
()[
axis
];
auto
len
=
input
.
lens
()[
axis
];
if
(
stride
!=
0
and
stride
!=
1
)
if
(
not
contains
({
0
,
1
},
stride
)
)
return
1
;
return
1
;
if
(
len
==
1
and
input
.
elements
()
>
sizes
.
front
())
if
(
len
==
1
and
input
.
elements
()
>
sizes
.
front
())
return
sizes
.
front
();
return
sizes
.
front
();
auto
it
=
std
::
find_if
(
auto
it
=
std
::
find_if
(
sizes
.
begin
(),
sizes
.
end
(),
[
&
](
auto
vsize
)
{
sizes
.
begin
(),
sizes
.
end
(),
[
&
](
auto
i
)
{
return
(
len
%
i
)
==
0
;
});
// The len is divisible by the size and all the strides are divisible by
// the size
return
(
len
%
vsize
)
==
0
and
std
::
all_of
(
input
.
strides
().
begin
(),
input
.
strides
().
end
(),
[
&
](
auto
i
)
{
return
contains
({
0
,
1
},
i
)
or
i
%
vsize
==
0
;
});
});
if
(
it
!=
sizes
.
end
())
if
(
it
!=
sizes
.
end
())
return
*
it
;
return
*
it
;
return
1
;
return
1
;
...
...
src/targets/gpu/fuse_ops.cpp
View file @
74bd6d61
...
@@ -26,7 +26,6 @@
...
@@ -26,7 +26,6 @@
#include <migraphx/gpu/fuse_ops.hpp>
#include <migraphx/gpu/fuse_ops.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/clip.hpp>
#include <migraphx/gpu/convolution.hpp>
#include <migraphx/gpu/convolution.hpp>
#include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/oper.hpp>
...
@@ -50,8 +49,6 @@
...
@@ -50,8 +49,6 @@
#include <migraphx/array.hpp>
#include <migraphx/array.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/op/clip.hpp>
#include <migraphx/op/contiguous.hpp>
#include <cmath>
#include <cmath>
#include <set>
#include <set>
...
@@ -262,7 +259,7 @@ struct hip_add_relu : binary_device<hip_add_relu, &device::add_relu>
...
@@ -262,7 +259,7 @@ struct hip_add_relu : binary_device<hip_add_relu, &device::add_relu>
};
};
MIGRAPHX_REGISTER_OP
(
hip_add_relu
)
MIGRAPHX_REGISTER_OP
(
hip_add_relu
)
struct
hip_add_sigmoid
:
binary_device
<
hip_add_
relu
,
&
device
::
add_sigmoid
>
struct
hip_add_sigmoid
:
binary_device
<
hip_add_
sigmoid
,
&
device
::
add_sigmoid
>
{
{
};
};
MIGRAPHX_REGISTER_OP
(
hip_add_sigmoid
)
MIGRAPHX_REGISTER_OP
(
hip_add_sigmoid
)
...
@@ -1036,7 +1033,7 @@ struct find_gemm_pointwise
...
@@ -1036,7 +1033,7 @@ struct find_gemm_pointwise
// const-fold input if not standard shape since rocblas can't handle it
// const-fold input if not standard shape since rocblas can't handle it
if
(
not
c_ins
->
get_shape
().
standard
())
if
(
not
c_ins
->
get_shape
().
standard
())
{
{
auto
c
=
op
::
contiguous
{}
;
auto
c
=
make_op
(
"
contiguous
"
)
;
auto
l
=
c
.
compute
(
c
.
compute_shape
({
c_ins
->
get_shape
()}),
{
c_ins
->
eval
()});
auto
l
=
c
.
compute
(
c
.
compute_shape
({
c_ins
->
get_shape
()}),
{
c_ins
->
eval
()});
c_ins
=
m
.
add_literal
(
l
.
get_shape
(),
l
.
data
());
c_ins
=
m
.
add_literal
(
l
.
get_shape
(),
l
.
data
());
}
}
...
...
src/targets/gpu/gemm_impl.cpp
View file @
74bd6d61
...
@@ -176,8 +176,13 @@ void gemm_impl(context& ctx,
...
@@ -176,8 +176,13 @@ void gemm_impl(context& ctx,
auto
num_matrices
=
std
::
accumulate
(
auto
num_matrices
=
std
::
accumulate
(
out_lens
.
rbegin
()
+
2
,
out_lens
.
rend
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
out_lens
.
rbegin
()
+
2
,
out_lens
.
rend
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
if
(
num_matrices
==
1
)
if
(
num_matrices
==
1
or
(
num_matrices
>
1
and
get_batch_stride
(
args
[
1
])
==
0
)
)
{
{
// If the batch dimension of B is broadcasted, then we can
// multiply m by the batch_size and use rocblas_gemm_ex
// instead of rocblas_gemm_strided_batched_ex.
m
*=
num_matrices
;
// the rocblas_gemm API handles inputs and output matrices as
// the rocblas_gemm API handles inputs and output matrices as
// column-major format. When doing a C = A * B, we actually do
// column-major format. When doing a C = A * B, we actually do
// C^T = (B^T) * (A^T). That is the reason we input args[1] as
// C^T = (B^T) * (A^T). That is the reason we input args[1] as
...
...
src/targets/gpu/include/migraphx/gpu/gather.hpp
View file @
74bd6d61
...
@@ -27,7 +27,7 @@
...
@@ -27,7 +27,7 @@
#include <migraphx/argument.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/op/gather.hpp>
#include <migraphx/op/gather.hpp>
#include <migraphx/gpu/
miopen
.hpp>
#include <migraphx/gpu/
context
.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
src/targets/gpu/include/migraphx/gpu/int8_conv_pack.hpp
View file @
74bd6d61
...
@@ -25,7 +25,6 @@
...
@@ -25,7 +25,6 @@
#define MIGRAPHX_GUARD_RTGLIB_INT8_CONV_PACK_HPP
#define MIGRAPHX_GUARD_RTGLIB_INT8_CONV_PACK_HPP
#include <migraphx/argument.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <utility>
#include <utility>
...
...
src/targets/gpu/include/migraphx/gpu/int8_gemm_pack.hpp
View file @
74bd6d61
...
@@ -25,7 +25,6 @@
...
@@ -25,7 +25,6 @@
#define MIGRAPHX_GUARD_RTGLIB_INT8_GEMM_PACK_HPP
#define MIGRAPHX_GUARD_RTGLIB_INT8_GEMM_PACK_HPP
#include <migraphx/argument.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/op/quant_dot.hpp>
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <utility>
#include <utility>
...
...
src/targets/gpu/include/migraphx/gpu/logsoftmax.hpp
View file @
74bd6d61
...
@@ -24,22 +24,10 @@
...
@@ -24,22 +24,10 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_LOGSOFTMAX_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_LOGSOFTMAX_HPP
#define MIGRAPHX_GUARD_RTGLIB_LOGSOFTMAX_HPP
#define MIGRAPHX_GUARD_RTGLIB_LOGSOFTMAX_HPP
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op/logsoftmax.hpp>
#include <migraphx/op/logsoftmax.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/config.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/gpu/device/contiguous.hpp>
#include <migraphx/gpu/device/add.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/context.hpp>
#include <utility>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
src/targets/gpu/include/migraphx/gpu/lrn.hpp
View file @
74bd6d61
...
@@ -26,7 +26,7 @@
...
@@ -26,7 +26,7 @@
#include <migraphx/shape.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/gpu/
miopen
.hpp>
#include <migraphx/gpu/
context
.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
src/targets/gpu/include/migraphx/gpu/prefuse_ops.hpp
View file @
74bd6d61
...
@@ -25,7 +25,7 @@
...
@@ -25,7 +25,7 @@
#define MIGRAPHX_GUARD_GPU_PREFUSE_OPS_HPP
#define MIGRAPHX_GUARD_GPU_PREFUSE_OPS_HPP
#include <migraphx/config.hpp>
#include <migraphx/config.hpp>
#include <
migraphx/gpu/context.hpp
>
#include <
string
>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
src/targets/gpu/include/migraphx/gpu/reverse.hpp
View file @
74bd6d61
...
@@ -27,7 +27,7 @@
...
@@ -27,7 +27,7 @@
#include <migraphx/argument.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/op/reverse.hpp>
#include <migraphx/op/reverse.hpp>
#include <migraphx/gpu/
miopen
.hpp>
#include <migraphx/gpu/
context
.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
src/targets/gpu/include/migraphx/gpu/softmax.hpp
View file @
74bd6d61
...
@@ -24,22 +24,10 @@
...
@@ -24,22 +24,10 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_SOFTMAX_HPP
#ifndef MIGRAPHX_GUARD_RTGLIB_SOFTMAX_HPP
#define MIGRAPHX_GUARD_RTGLIB_SOFTMAX_HPP
#define MIGRAPHX_GUARD_RTGLIB_SOFTMAX_HPP
#include <migraphx/gpu/lowering.hpp>
#include <migraphx/manage_ptr.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/op/softmax.hpp>
#include <migraphx/op/softmax.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/reflect.hpp>
#include <migraphx/config.hpp>
#include <migraphx/gpu/miopen.hpp>
#include <migraphx/gpu/hip.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/gpu/device/contiguous.hpp>
#include <migraphx/gpu/device/add.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/gpu/rocblas.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/context.hpp>
#include <utility>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment