Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
bf0a4713
Unverified
Commit
bf0a4713
authored
May 24, 2022
by
Paul Fultz II
Committed by
GitHub
May 24, 2022
Browse files
Improve applicable batched gemms (#1214)
* Improve applicable batched gemms for bert
parent
150d6d20
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
86 additions
and
35 deletions
+86
-35
src/reduce_dims.cpp
src/reduce_dims.cpp
+18
-4
src/targets/gpu/gemm_impl.cpp
src/targets/gpu/gemm_impl.cpp
+38
-6
src/targets/gpu/include/migraphx/gpu/gemm.hpp
src/targets/gpu/include/migraphx/gpu/gemm.hpp
+7
-25
test/reduce_dims.cpp
test/reduce_dims.cpp
+23
-0
No files found.
src/reduce_dims.cpp
View file @
bf0a4713
...
@@ -16,10 +16,8 @@ bool reduce_dim(std::vector<shape>& shapes, std::size_t n)
...
@@ -16,10 +16,8 @@ bool reduce_dim(std::vector<shape>& shapes, std::size_t n)
auto
bstride
=
s
.
strides
()[
n
+
1
];
auto
bstride
=
s
.
strides
()[
n
+
1
];
auto
blen
=
s
.
lens
()[
n
+
1
];
auto
blen
=
s
.
lens
()[
n
+
1
];
if
(
astride
==
bstride
*
blen
)
if
(
astride
==
bstride
*
blen
or
alen
==
1
)
{
new_lens
.
push_back
(
alen
*
blen
);
new_lens
.
push_back
(
alen
*
blen
);
}
}
}
if
(
new_lens
.
size
()
!=
shapes
.
size
())
if
(
new_lens
.
size
()
!=
shapes
.
size
())
return
false
;
return
false
;
...
@@ -37,10 +35,25 @@ bool reduce_dim(std::vector<shape>& shapes, std::size_t n)
...
@@ -37,10 +35,25 @@ bool reduce_dim(std::vector<shape>& shapes, std::size_t n)
return
true
;
return
true
;
}
}
// Removes the last dimension from every shape when that dimension has
// length 1 in all of them.
//
// The shapes are modified only if each one has at least two dimensions and
// its final length equals 1; otherwise the vector is left untouched. When
// the condition holds, each shape is rebuilt from its element type and its
// lens/strides with the final entry removed.
void reduce_dim1(std::vector<shape>& shapes)
{
    const auto has_trailing_one = [](const auto& s) {
        const auto& dims = s.lens();
        return dims.size() >= 2 and dims.back() == 1;
    };
    // A single shape that cannot lose its last dimension blocks the rewrite
    // for all of them.
    if(not std::all_of(shapes.begin(), shapes.end(), has_trailing_one))
        return;

    for(auto& s : shapes)
    {
        auto dims  = s.lens();
        auto steps = s.strides();
        dims.pop_back();
        steps.pop_back();
        s = shape{s.type(), dims, steps};
    }
}
std
::
size_t
reduce_dim_all
(
std
::
vector
<
shape
>&
shapes
,
std
::
size_t
n
)
std
::
size_t
reduce_dim_all
(
std
::
vector
<
shape
>&
shapes
,
std
::
size_t
n
)
{
{
while
(
reduce_dim
(
shapes
,
n
)
and
n
<
shapes
.
size
())
{}
while
(
reduce_dim
(
shapes
,
n
)
and
n
<
shapes
.
size
())
{}
return
n
+
1
;
return
n
+
1
;
}
}
void
reduce_dim_all
(
std
::
vector
<
shape
>&
shapes
)
void
reduce_dim_all
(
std
::
vector
<
shape
>&
shapes
)
...
@@ -48,6 +61,7 @@ void reduce_dim_all(std::vector<shape>& shapes)
...
@@ -48,6 +61,7 @@ void reduce_dim_all(std::vector<shape>& shapes)
std
::
size_t
n
=
0
;
std
::
size_t
n
=
0
;
while
(
n
<
shapes
.
front
().
lens
().
size
()
-
1
)
while
(
n
<
shapes
.
front
().
lens
().
size
()
-
1
)
n
=
reduce_dim_all
(
shapes
,
n
);
n
=
reduce_dim_all
(
shapes
,
n
);
reduce_dim1
(
shapes
);
}
}
std
::
vector
<
std
::
size_t
>
base_lens
(
const
std
::
vector
<
shape
>&
shapes
)
std
::
vector
<
std
::
size_t
>
base_lens
(
const
std
::
vector
<
shape
>&
shapes
)
...
...
src/targets/gpu/gemm_impl.cpp
View file @
bf0a4713
#include <rocblas.h>
#include <rocblas.h>
#include <migraphx/gpu/gemm_impl.hpp>
#include <migraphx/gpu/gemm_impl.hpp>
#include <migraphx/reduce_dims.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -27,6 +28,22 @@ rocblas_datatype get_type(shape::type_t type)
...
@@ -27,6 +28,22 @@ rocblas_datatype get_type(shape::type_t type)
MIGRAPHX_THROW
(
"ROCBLAS_GEMM: data type not supported!"
);
MIGRAPHX_THROW
(
"ROCBLAS_GEMM: data type not supported!"
);
}
}
void
blas_shape
(
const
shape
&
s
)
{
if
(
s
.
lens
().
size
()
<
2
)
return
;
if
(
std
::
none_of
(
s
.
strides
().
end
()
-
2
,
s
.
strides
().
end
(),
[
&
](
auto
i
)
{
return
i
==
1
;
}))
MIGRAPHX_THROW
(
"GPU_GEMM: needs to have one matrix stride as 1"
);
if
(
s
.
lens
().
size
()
<
3
)
return
;
shape
batch_shape
{
s
.
type
(),
{
s
.
lens
().
begin
(),
s
.
lens
().
end
()
-
2
},
{
s
.
strides
().
begin
(),
s
.
strides
().
end
()
-
2
}};
auto
batch_shapes
=
reduce_dims
({
batch_shape
});
if
(
batch_shapes
.
front
().
lens
().
size
()
!=
1
)
MIGRAPHX_THROW
(
"GPU_GEMM: Batch dimension is not collapsible"
);
}
template
<
class
R
,
class
...
Ts
,
class
...
Us
>
template
<
class
R
,
class
...
Ts
,
class
...
Us
>
R
rocblas_invoke
(
R
(
*
f
)(
Ts
...),
Us
...
xs
)
R
rocblas_invoke
(
R
(
*
f
)(
Ts
...),
Us
...
xs
)
{
{
...
@@ -36,6 +53,18 @@ R rocblas_invoke(R (*f)(Ts...), Us... xs)
...
@@ -36,6 +53,18 @@ R rocblas_invoke(R (*f)(Ts...), Us... xs)
return
f
(
xs
...,
nullptr
,
nullptr
);
return
f
(
xs
...,
nullptr
,
nullptr
);
}
}
// Returns true only when the shape reports itself as transposed AND its
// innermost stride is not 1. A shape flagged as transposed whose last
// stride is still 1 is therefore reported as not transposed.
static bool is_transposed(const shape& s)
{
    // Short-circuit: the strides are inspected only for transposed shapes.
    return s.transposed() and s.strides().back() != 1;
}
// Returns the stride of the dimension located third from the end of the
// argument's shape.
//
// NOTE(review): the index is computed as size() - 3, so this presumably
// expects at least three dimensions -- confirm that callers guarantee it.
// NOTE(review): the stride is converted to rocblas_int on return; confirm
// that the values fit.
static rocblas_int get_batch_stride(const argument& a)
{
    const auto& arg_shape = a.get_shape();
    const auto& strides   = arg_shape.strides();
    return strides[strides.size() - 3];
}
template
<
class
T
>
template
<
class
T
>
void
gemm_impl
(
context
&
ctx
,
void
gemm_impl
(
context
&
ctx
,
const
shape
&
output_shape
,
const
shape
&
output_shape
,
...
@@ -45,8 +74,8 @@ void gemm_impl(context& ctx,
...
@@ -45,8 +74,8 @@ void gemm_impl(context& ctx,
bool
int8_x4_format
,
bool
int8_x4_format
,
bool
compute_fp32
)
bool
compute_fp32
)
{
{
bool
transa
=
args
[
0
].
get_shape
()
.
transposed
(
);
bool
transa
=
is_transposed
(
args
[
0
].
get_shape
());
bool
transb
=
args
[
1
].
get_shape
()
.
transposed
(
);
bool
transb
=
is_transposed
(
args
[
1
].
get_shape
());
auto
n_dim
=
output_shape
.
lens
().
size
();
auto
n_dim
=
output_shape
.
lens
().
size
();
auto
dim_1
=
n_dim
-
1
;
auto
dim_1
=
n_dim
-
1
;
auto
dim_0
=
n_dim
-
2
;
auto
dim_0
=
n_dim
-
2
;
...
@@ -142,6 +171,9 @@ void gemm_impl(context& ctx,
...
@@ -142,6 +171,9 @@ void gemm_impl(context& ctx,
}
}
else
else
{
{
auto
a_stride
=
get_batch_stride
(
args
[
0
]);
auto
b_stride
=
get_batch_stride
(
args
[
1
]);
auto
c_stride
=
get_batch_stride
(
args
[
2
]);
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex
,
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex
,
ctx
.
get_stream
().
get_rocblas
(),
ctx
.
get_stream
().
get_rocblas
(),
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
...
@@ -153,20 +185,20 @@ void gemm_impl(context& ctx,
...
@@ -153,20 +185,20 @@ void gemm_impl(context& ctx,
to_pointer
(
args
.
at
(
1
)),
to_pointer
(
args
.
at
(
1
)),
arg_type
,
arg_type
,
ldb
,
ldb
,
k
*
n
,
b_stride
,
to_pointer
(
args
.
at
(
0
)),
to_pointer
(
args
.
at
(
0
)),
arg_type
,
arg_type
,
lda
,
lda
,
m
*
k
,
a_stride
,
beta_v
,
beta_v
,
to_pointer
(
args
[
2
]),
to_pointer
(
args
[
2
]),
output_type
,
output_type
,
ldc
,
ldc
,
m
*
n
,
c_stride
,
is_3inputs
?
to_pointer
(
args
[
3
])
:
to_pointer
(
args
[
2
]),
is_3inputs
?
to_pointer
(
args
[
3
])
:
to_pointer
(
args
[
2
]),
output_type
,
output_type
,
ldc
,
ldc
,
m
*
n
,
c_stride
,
num_matrices
,
num_matrices
,
compute_type
,
compute_type
,
rocblas_gemm_algo_standard
,
rocblas_gemm_algo_standard
,
...
...
src/targets/gpu/include/migraphx/gpu/gemm.hpp
View file @
bf0a4713
...
@@ -18,6 +18,8 @@ namespace gpu {
...
@@ -18,6 +18,8 @@ namespace gpu {
struct
context
;
struct
context
;
void
blas_shape
(
const
shape
&
s
);
template
<
class
Op
>
template
<
class
Op
>
struct
rocblas_gemm
struct
rocblas_gemm
{
{
...
@@ -50,13 +52,14 @@ struct rocblas_gemm
...
@@ -50,13 +52,14 @@ struct rocblas_gemm
std
::
vector
<
shape
>
in_shapes
(
inputs
);
std
::
vector
<
shape
>
in_shapes
(
inputs
);
in_shapes
.
pop_back
();
in_shapes
.
pop_back
();
check_shapes
{
in_shapes
,
*
this
}.
not_broadcasted
();
check_shapes
{
in_shapes
,
*
this
}.
not_broadcasted
();
b
atch_not_transposed
(
inputs
[
0
].
strides
()
);
b
las_shape
(
inputs
[
0
]
);
b
atch_not_transposed
(
inputs
[
1
].
strides
()
);
b
las_shape
(
inputs
[
1
]
);
// if gemm and add are fused
// if gemm and add are fused
if
(
not
float_equal
(
beta
,
0
)
)
if
(
in_shapes
.
size
()
>
2
)
{
{
auto
cmat_shape
=
in_shapes
.
back
();
auto
cmat_shape
=
in_shapes
.
back
();
in_shapes
.
pop_back
();
in_shapes
.
pop_back
();
blas_shape
(
cmat_shape
);
auto
op_out_shape
=
op
.
compute_shape
(
in_shapes
);
auto
op_out_shape
=
op
.
compute_shape
(
in_shapes
);
if
(
cmat_shape
.
lens
()
!=
op_out_shape
.
lens
())
if
(
cmat_shape
.
lens
()
!=
op_out_shape
.
lens
())
{
{
...
@@ -71,6 +74,7 @@ struct rocblas_gemm
...
@@ -71,6 +74,7 @@ struct rocblas_gemm
to_string
(
cmat_shape
.
type
())
+
to_string
(
cmat_shape
.
type
())
+
", it must be: "
+
to_string
(
op_out_shape
.
type
()));
", it must be: "
+
to_string
(
op_out_shape
.
type
()));
}
}
return
op_out_shape
;
}
}
return
op
.
compute_shape
(
in_shapes
);
return
op
.
compute_shape
(
in_shapes
);
...
@@ -96,28 +100,6 @@ struct rocblas_gemm
...
@@ -96,28 +100,6 @@ struct rocblas_gemm
return
args
.
back
();
return
args
.
back
();
}
}
// Checks the batch strides (every stride except the last two) against the
// matrix strides (the last two) and throws when they look transposed.
//
// Stride vectors with two or fewer entries are accepted without checks.
// With matrix_size = max of the last two strides:
//  - throws "... are transposed!" when every batch stride is smaller than
//    matrix_size;
//  - throws "... is transposed!" when some adjacent pair of batch strides
//    (i, j) has i < j, or either of them is smaller than matrix_size.
void batch_not_transposed(const std::vector<std::size_t>& strides) const
{
    if(strides.size() < 3)
        return;

    // [batch_first, batch_last) covers the batch strides; batch_last also
    // points at the first of the two matrix strides.
    const auto batch_first = strides.begin();
    const auto batch_last  = strides.end() - 2;
    const auto matrix_size = std::max(*batch_last, *(batch_last + 1));

    const auto below_matrix = [&](auto st) { return st < matrix_size; };
    if(std::all_of(batch_first, batch_last, below_matrix))
    {
        MIGRAPHX_THROW("GPU_GEMM: matrix size and batch size {" + to_string_range(strides) +
                       "} are transposed!");
    }

    const auto misordered = [&](auto outer, auto inner) {
        return outer < inner or below_matrix(outer) or below_matrix(inner);
    };
    if(std::adjacent_find(batch_first, batch_last, misordered) != batch_last)
    {
        MIGRAPHX_THROW("GPU_GEMM: batch size {" + to_string_range(strides) +
                       "} is transposed!");
    }
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
shapes
)
const
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
shapes
)
const
{
{
return
shapes
.
size
()
-
1
;
return
shapes
.
size
()
-
1
;
...
...
test/reduce_dims.cpp
View file @
bf0a4713
...
@@ -109,6 +109,29 @@ TEST_CASE(transposed1)
...
@@ -109,6 +109,29 @@ TEST_CASE(transposed1)
EXPECT
(
eshapes
==
rshapes
);
EXPECT
(
eshapes
==
rshapes
);
}
}
// A non-packed 2-D shape whose leading dimension has length 1 is expected
// to reduce to a 1-D shape that keeps the remaining length and stride.
TEST_CASE(non_packed_empty1)
{
    const std::vector<migraphx::shape> ishapes{make_shape({1, 12}, {589824, 64})};
    const std::vector<migraphx::shape> eshapes{make_shape({12}, {64})};
    const auto rshapes = migraphx::reduce_dims(ishapes);
    EXPECT(eshapes == rshapes);
}
// A non-packed 2-D shape whose trailing dimension has length 1 is expected
// to reduce to a 1-D shape that keeps the leading length and stride.
TEST_CASE(non_packed_empty2)
{
    const std::vector<migraphx::shape> ishapes{make_shape({12, 1}, {64, 589824})};
    const std::vector<migraphx::shape> eshapes{make_shape({12}, {64})};
    const auto rshapes = migraphx::reduce_dims(ishapes);
    EXPECT(eshapes == rshapes);
}
// A shape with a single dimension of length 1 is expected to come back
// from reduce_dims unchanged.
TEST_CASE(single_dim)
{
    const std::vector<migraphx::shape> ishapes{make_shape({1}, {1})};
    const auto rshapes = migraphx::reduce_dims(ishapes);
    EXPECT(ishapes == rshapes);
}
TEST_CASE
(
empty
)
TEST_CASE
(
empty
)
{
{
auto
rshapes
=
migraphx
::
reduce_dims
({});
auto
rshapes
=
migraphx
::
reduce_dims
({});
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment