Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
bb666690
Commit
bb666690
authored
Aug 28, 2018
by
Paul
Browse files
Fix clang tidy complexity issue
parent
26a78750
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
67 additions
and
63 deletions
+67
-63
src/targets/gpu/device/include/migraph/gpu/device/nary.hpp
src/targets/gpu/device/include/migraph/gpu/device/nary.hpp
+67
-63
No files found.
src/targets/gpu/device/include/migraph/gpu/device/nary.hpp
View file @
bb666690
...
...
@@ -51,15 +51,9 @@ auto nary_nonstandard_impl(F f, argument result, Arguments... args)
});
}
template
<
class
...
Arguments
>
auto
nary_nonstandard
(
argument
result
,
Arguments
...
args
)
{
return
[
=
](
auto
f
)
{
return
nary_nonstandard_impl
(
f
,
result
,
args
...);
};
}
inline
auto
binary_broadcast_vec
(
const
argument
&
result
,
const
argument
&
arg1
,
const
argument
&
arg2
)
template
<
class
F
>
void
binary_broadcast_vec_impl
(
F
f
,
const
argument
&
result
,
const
argument
&
arg1
,
const
argument
&
arg2
)
{
return
[
=
](
auto
f
)
{
const
auto
&
output_shape
=
result
.
get_shape
();
const
auto
&
b_shape
=
arg2
.
get_shape
();
auto
bdim
=
std
::
distance
(
b_shape
.
strides
().
begin
(),
...
...
@@ -106,57 +100,54 @@ inline auto binary_broadcast_vec(const argument& result, const argument& arg1, c
}
});
});
};
}
inline
auto
binary_broadcast
(
const
argument
&
result
,
const
argument
&
arg1
,
const
argument
&
arg2
)
template
<
class
F
>
void
binary_broadcast_impl
(
F
f
,
const
argument
&
result
,
const
argument
&
arg1
,
const
argument
&
arg2
)
{
return
[
=
](
auto
f
)
{
const
auto
&
output_shape
=
result
.
get_shape
();
const
auto
&
b_shape
=
arg2
.
get_shape
();
auto
bdim
=
std
::
distance
(
b_shape
.
strides
().
begin
(),
std
::
find_if
(
b_shape
.
strides
().
begin
(),
b_shape
.
strides
().
end
(),
[](
auto
x
)
{
return
x
!=
0
;
}));
auto
bdim_len
=
output_shape
.
lens
()[
bdim
];
auto
bdim_stride
=
output_shape
.
strides
()[
bdim
];
auto
bdim_next_stride
=
bdim_stride
*
bdim_len
;
const
auto
&
output_shape
=
result
.
get_shape
();
const
auto
&
b_shape
=
arg2
.
get_shape
();
auto
bdim
=
std
::
distance
(
b_shape
.
strides
().
begin
(),
std
::
find_if
(
b_shape
.
strides
().
begin
(),
b_shape
.
strides
().
end
(),
[](
auto
x
)
{
return
x
!=
0
;
}));
auto
bdim_len
=
output_shape
.
lens
()[
bdim
];
auto
bdim_stride
=
output_shape
.
strides
()[
bdim
];
auto
bdim_next_stride
=
bdim_stride
*
bdim_len
;
visit_all
(
result
,
arg1
,
arg2
)([
&
](
auto
output
,
auto
input1
,
auto
input2
)
{
using
type
=
std
::
remove_cv_t
<
typename
decltype
(
output
)
::
value_type
>
;
auto
*
xp
=
input1
.
data
();
auto
*
yp
=
input2
.
data
();
auto
*
outp
=
output
.
data
();
visit_all
(
result
,
arg1
,
arg2
)([
&
](
auto
output
,
auto
input1
,
auto
input2
)
{
using
type
=
std
::
remove_cv_t
<
typename
decltype
(
output
)
::
value_type
>
;
auto
*
xp
=
input1
.
data
();
auto
*
yp
=
input2
.
data
();
auto
*
outp
=
output
.
data
();
const
std
::
size_t
nlocal
=
1024
;
const
std
::
size_t
nglobal
=
256
*
nlocal
;
const
std
::
size_t
n
=
output
.
size
();
const
std
::
size_t
nlocal
=
1024
;
const
std
::
size_t
nglobal
=
256
*
nlocal
;
const
std
::
size_t
n
=
output
.
size
();
launch
(
nglobal
,
nlocal
)([
=
](
auto
idx
)
__device__
{
MIGRAPH_DEVICE_SHARED
type
buffer
[
2048
];
// Load bias into LDS
for
(
size_t
i
=
idx
.
local
;
i
<
bdim_len
;
i
+=
nlocal
)
{
buffer
[
i
]
=
yp
[
i
];
}
__syncthreads
();
// Process the data
for
(
size_t
i
=
idx
.
global
;
i
<
n
;
i
+=
nglobal
)
{
auto
bidx
=
(
i
%
bdim_next_stride
)
/
bdim_stride
;
auto
b
=
buffer
[
bidx
];
type
x
=
xp
[
i
];
outp
[
i
]
=
f
(
x
,
b
);
}
});
launch
(
nglobal
,
nlocal
)([
=
](
auto
idx
)
__device__
{
MIGRAPH_DEVICE_SHARED
type
buffer
[
2048
];
// Load bias into LDS
for
(
size_t
i
=
idx
.
local
;
i
<
bdim_len
;
i
+=
nlocal
)
{
buffer
[
i
]
=
yp
[
i
];
}
__syncthreads
();
// Process the data
for
(
size_t
i
=
idx
.
global
;
i
<
n
;
i
+=
nglobal
)
{
auto
bidx
=
(
i
%
bdim_next_stride
)
/
bdim_stride
;
auto
b
=
buffer
[
bidx
];
type
x
=
xp
[
i
];
outp
[
i
]
=
f
(
x
,
b
);
}
});
};
}
)
;
}
template
<
class
...
Arguments
>
auto
nary_standard_vec
(
argument
result
,
Arguments
...
args
)
template
<
class
F
,
class
...
Arguments
>
void
nary_standard_vec
_impl
(
F
f
,
argument
result
,
Arguments
...
args
)
{
return
[
=
](
auto
f
)
{
// assert(x.get_shape().elements() == y.get_shape().elements());
const
auto
&
output_shape
=
result
.
get_shape
();
visit_all
(
result
,
args
...)([
&
](
auto
output
,
auto
...
inputs
)
{
...
...
@@ -177,13 +168,11 @@ auto nary_standard_vec(argument result, Arguments... args)
outp
[
i
]
=
out
;
});
});
};
}
template
<
class
...
Arguments
>
auto
nary_standard
(
argument
result
,
Arguments
...
args
)
template
<
class
F
,
class
...
Arguments
>
void
nary_standard
_impl
(
F
f
,
argument
result
,
Arguments
...
args
)
{
return
[
=
](
auto
f
)
{
// assert(x.get_shape().elements() == y.get_shape().elements());
const
auto
&
output_shape
=
result
.
get_shape
();
visit_all
(
result
,
args
...)([
&
](
auto
output
,
auto
...
inputs
)
{
...
...
@@ -192,29 +181,44 @@ auto nary_standard(argument result, Arguments... args)
gs_launch
(
output_shape
.
elements
())(
[
=
](
auto
i
)
{
data
([
&
](
auto
...
xps
)
{
outp
[
i
]
=
f
(
xps
[
i
]...);
});
});
});
};
}
template
<
class
...
Arguments
>
auto
nary_impl
(
argument
result
,
Arguments
...
args
)
template
<
class
F
,
class
...
Arguments
>
void
nary_impl
(
F
f
,
argument
result
,
Arguments
...
args
)
{
return
[
=
](
auto
f
)
{
bool
standard
=
all_of
({
args
.
get_shape
()...},
[](
const
shape
&
s
)
{
return
s
.
standard
();
});
bool
packed
=
all_of
({
args
.
get_shape
()...},
[](
const
shape
&
s
)
{
return
s
.
packed
();
});
bool
same_shapes
=
all_of
({
args
.
get_shape
()...},
[
&
](
const
shape
&
s
)
{
return
s
==
result
.
get_shape
();
});
if
(
standard
or
(
packed
and
same_shapes
))
nary_standard
(
result
,
args
...)
(
f
)
;
nary_standard
_impl
(
f
,
result
,
args
...);
else
nary_nonstandard
(
result
,
args
...)(
f
);
nary_nonstandard_impl
(
f
,
result
,
args
...);
}
template
<
class
...
Arguments
>
auto
nary_nonstandard
(
argument
result
,
Arguments
...
args
)
{
return
[
=
](
auto
f
)
{
nary_nonstandard_impl
(
f
,
result
,
args
...);
};
}
template
<
class
...
Arguments
>
auto
nary_standard
(
argument
result
,
Arguments
...
args
)
{
return
[
=
](
auto
f
)
{
nary_standard_impl
(
f
,
result
,
args
...);
};
}
template
<
class
...
Arguments
>
auto
nary
(
argument
result
,
Arguments
...
args
)
{
return
nary_impl
(
result
,
args
...);
return
[
=
](
auto
f
)
{
nary_impl
(
f
,
result
,
args
...);
};
}
inline
auto
nary
(
const
argument
&
result
,
const
argument
&
arg1
,
const
argument
&
arg2
)
...
...
@@ -235,13 +239,13 @@ inline auto nary(const argument& result, const argument& arg1, const argument& a
const
bool
divisible_by_4
=
(
b_len
%
4
==
0
)
and
(
b_stride
%
4
==
0
)
and
(
arg1
.
get_shape
().
elements
()
%
4
==
0
);
if
(
divisible_by_4
)
binary_broadcast_vec
(
result
,
arg1
,
arg2
)
(
f
)
;
binary_broadcast_vec
_impl
(
f
,
result
,
arg1
,
arg2
);
else
binary_broadcast
(
result
,
arg1
,
arg2
)
(
f
)
;
binary_broadcast
_impl
(
f
,
result
,
arg1
,
arg2
);
return
;
}
}
nary_impl
(
result
,
arg1
,
arg2
)
(
f
)
;
nary_impl
(
f
,
result
,
arg1
,
arg2
);
};
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment