Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
928cb435
Commit
928cb435
authored
Jul 23, 2019
by
Paul
Browse files
Refactor nary
parent
784dc2aa
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
117 additions
and
30 deletions
+117
-30
src/targets/gpu/device/include/migraphx/gpu/device/nary.hpp
src/targets/gpu/device/include/migraphx/gpu/device/nary.hpp
+117
-30
No files found.
src/targets/gpu/device/include/migraphx/gpu/device/nary.hpp
View file @
928cb435
...
@@ -118,6 +118,49 @@ void nary_broadcast_impl(hipStream_t stream, F f, argument result, argument barg
...
@@ -118,6 +118,49 @@ void nary_broadcast_impl(hipStream_t stream, F f, argument result, argument barg
});
});
}
}
/// Apply `f` elementwise over `args...` together with two operands that are
/// broadcast along a single axis, writing the result into `result`.
///
/// For every output element `i` the kernel evaluates
/// `f(inputs[i]..., b1, b2)`, where `b1`/`b2` are the entries of `barg1` /
/// `barg2` selected by the position of `i` along the broadcast axis.
///
/// The broadcast axis is derived from `barg1` only (first non-zero stride).
/// NOTE(review): `barg2` is assumed to be broadcast along the same axis with
/// the same length -- confirm at the call sites.
///
/// @param stream HIP stream the kernel is launched on
/// @param f      callable invoked as `f(inputs..., b1, b2)`
/// @param result output argument; its shape defines the iteration space
/// @param barg1  first broadcast operand
/// @param barg2  second broadcast operand
/// @param args   remaining operands, indexed with the same linear index as `result`
template <class F, class... Arguments>
void nary_double_broadcast_impl(
    hipStream_t stream, F f, argument result, argument barg1, argument barg2, Arguments... args)
{
    const auto& output_shape = result.get_shape();
    const auto& b_shape      = barg1.get_shape();
    // Index of the first axis with a non-zero stride in the broadcast operand
    auto bdim =
        std::distance(b_shape.strides().begin(),
                      std::find_if(b_shape.strides().begin(), b_shape.strides().end(), [](auto x) {
                          return x != 0;
                      }));
    auto bdim_len         = output_shape.lens()[bdim];
    auto bdim_stride      = output_shape.strides()[bdim];
    auto bdim_next_stride = bdim_stride * bdim_len;

    // Both operands are staged in the single 2048-element buffer below, so
    // each of them may occupy at most half of it.
    assert(2 * bdim_len <= 2048);

    const std::size_t nlocal  = 1024;
    const std::size_t nglobal = 256 * nlocal;
    std::size_t nelements     = result.get_shape().elements();
    hip_visit_all(result, barg1, barg2, args...)(
        [&](auto output, auto binput1, auto binput2, auto... inputs) {
            using type = typename decltype(output)::value_type;
            launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
                MIGRAPHX_DEVICE_SHARED type buffer[2048];
                // Load both bias vectors into LDS: barg1 occupies
                // [0, bdim_len) and barg2 occupies [bdim_len, 2 * bdim_len).
                // Both sources are read from their own offset 0; only the
                // destination of the second operand is shifted.
                for(size_t i = idx.local; i < bdim_len; i += nlocal)
                {
                    buffer[i]            = binput1.data()[i];
                    buffer[i + bdim_len] = binput2.data()[i];
                }
                // Staging must be complete before any thread reads the buffer
                __syncthreads();
                // Process the data
                for(size_t i = idx.global; i < nelements; i += nglobal)
                {
                    // Position of element i along the broadcast axis
                    auto bidx        = (i % bdim_next_stride) / bdim_stride;
                    auto b1          = buffer[bidx];
                    auto b2          = buffer[bidx + bdim_len];
                    output.data()[i] = f(inputs.data()[i]..., b1, b2);
                }
            });
        });
}
template
<
class
F
,
class
...
Arguments
>
template
<
class
F
,
class
...
Arguments
>
void
nary_standard_vec_impl
(
hipStream_t
stream
,
F
f
,
argument
result
,
Arguments
...
args
)
void
nary_standard_vec_impl
(
hipStream_t
stream
,
F
f
,
argument
result
,
Arguments
...
args
)
{
{
...
@@ -176,46 +219,90 @@ auto nary_standard(hipStream_t stream, argument result, Arguments... args)
...
@@ -176,46 +219,90 @@ auto nary_standard(hipStream_t stream, argument result, Arguments... args)
return
[
=
](
auto
f
)
{
nary_standard_impl
(
stream
,
f
,
result
,
args
...);
};
return
[
=
](
auto
f
)
{
nary_standard_impl
(
stream
,
f
,
result
,
args
...);
};
}
}
template
<
class
...
Arguments
>
template
<
class
...
Arguments
>
auto
nary
(
hipStream_t
stream
,
argument
result
)
// Report whether `barg` qualifies for the shared-buffer broadcast kernels.
//
// Returns true only when every argument in `args` is standard and has the
// same shape as `result`, `barg` is broadcasted but not scalar, exactly one
// of its strides is non-zero, and the length of that axis is at most 2048.
//
// `divisible_by_4` is an output flag: it is always reset to false first and
// is set only on the successful path, to whether the axis length, the axis
// stride and the element count of the first argument are all multiples of 4.
bool broadcastable(bool& divisible_by_4, argument result, argument barg, Arguments... args)
{
    divisible_by_4 = false;

    const auto& out_shape = result.get_shape();
    auto bshape           = barg.get_shape();

    const bool standard =
        all_of({args.get_shape()...}, [](const shape& s) { return s.standard(); });
    const bool same_shapes =
        all_of({args.get_shape()...}, [&](const shape& s) { return s == out_shape; });

    if(not(standard and same_shapes))
        return false;
    if(not bshape.broadcasted() or bshape.scalar())
        return false;

    auto is_nonzero      = [](auto v) { return v != 0; };
    const auto& bstrides = bshape.strides();

    // Locate the broadcast axis: the first stride that is not zero
    auto first_nz    = std::find_if(bstrides.begin(), bstrides.end(), is_nonzero);
    auto axis        = std::distance(bstrides.begin(), first_nz);
    auto axis_len    = out_shape.lens()[axis];
    auto axis_stride = out_shape.strides()[axis];
    assert(bshape.lens()[axis] == axis_len);

    // Reject axes that are too long, or operands with a second non-zero stride
    if(axis_len > 2048 or not std::none_of(std::next(first_nz), bstrides.end(), is_nonzero))
        return false;

    divisible_by_4 = (axis_len % 4 == 0) and (axis_stride % 4 == 0) and
                     (front_args(args...).get_shape().elements() % 4 == 0);
    return true;
}
// Overload for the case where no arguments remain besides the result and the
// candidate broadcast operand: never broadcastable.
// Clears `divisible_by_4` and returns false; both `argument` parameters are
// unnamed and unused.
inline bool broadcastable(bool& divisible_by_4, argument, argument)
{
    divisible_by_4 = false;
    return false;
}
// Nullary
// Returns a callable that, given `f`, forwards to
// nary_standard_impl(stream, f, result). `stream` and `result` are captured
// by copy.
inline auto nary(hipStream_t stream, argument result)
{
    // A single return statement: returning two separate lambda expressions
    // would give the deduced `auto` return type two different closure types.
    return [=](auto f) { nary_standard_impl(stream, f, result); };
}
// Unary
// Builds a launcher for a single input: calling the returned object with a
// function `op` forwards to nary_impl(stream, op, result, arg). All
// parameters are captured by copy.
inline auto nary(hipStream_t stream, argument result, argument arg)
{
    auto launcher = [=](auto op) { nary_impl(stream, op, result, arg); };
    return launcher;
}
// Binary
// Builds a launcher for two inputs. When invoked with a function `op`, it
// asks broadcastable() whether `barg` can use a broadcast kernel:
//   - not broadcastable           -> nary_impl(stream, op, result, arg, barg)
//   - broadcastable, multiple of 4 -> nary_broadcast_vec_impl(stream, op, result, barg, arg)
//   - broadcastable otherwise      -> nary_broadcast_impl(stream, op, result, barg, arg)
// Note the broadcast kernels take `barg` before `arg`.
inline auto nary(hipStream_t stream, argument result, argument arg, argument barg)
{
    return [=](auto op) {
        bool vec4 = false;
        if(not broadcastable(vec4, result, barg, arg))
        {
            // Generic path
            nary_impl(stream, op, result, arg, barg);
            return;
        }

        if(vec4)
        {
            nary_broadcast_vec_impl(stream, op, result, barg, arg);
        }
        else
        {
            nary_broadcast_impl(stream, op, result, barg, arg);
        }
    };
}
template
<
class
...
Arguments
>
template
<
class
...
Arguments
>
auto
nary
(
hipStream_t
stream
,
argument
result
,
Arguments
...
args
)
auto
nary
(
hipStream_t
stream
,
argument
result
,
Arguments
...
args
)
{
{
return
[
=
](
auto
f
)
{
return
[
=
](
auto
f
)
{
auto
barg
=
back_args
(
args
...);
auto
barg
1
=
back_args
(
args
...);
bool
fallback
=
pop_back_args
(
args
...)([
&
](
auto
&&
...
args2
)
{
bool
fallback
=
pop_back_args
(
args
...)([
&
](
auto
&&
...
args2
)
{
auto
bshape
=
barg
.
get_shape
();
bool
divisible_by_4
=
false
;
const
bool
standard
=
if
(
broadcastable
(
divisible_by_4
,
result
,
barg1
,
args2
...))
all_of
({
args2
.
get_shape
()...},
[](
const
shape
&
s
)
{
return
s
.
standard
();
});
const
bool
same_shapes
=
all_of
(
{
args2
.
get_shape
()...},
[
&
](
const
shape
&
s
)
{
return
s
==
result
.
get_shape
();
});
// TODO: Check result and args shape is the same
if
(
standard
and
same_shapes
and
bshape
.
broadcasted
()
and
not
bshape
.
scalar
())
{
{
auto
not_zero
=
[](
auto
x
)
{
return
x
!=
0
;
};
if
(
divisible_by_4
)
const
auto
&
strides
=
bshape
.
strides
();
nary_broadcast_vec_impl
(
stream
,
f
,
result
,
barg1
,
args2
...);
auto
b_it
=
std
::
find_if
(
strides
.
begin
(),
strides
.
end
(),
not_zero
);
else
auto
b_idx
=
std
::
distance
(
strides
.
begin
(),
b_it
);
nary_broadcast_impl
(
stream
,
f
,
result
,
barg1
,
args2
...);
auto
b_len
=
result
.
get_shape
().
lens
()[
b_idx
];
return
false
;
auto
b_stride
=
result
.
get_shape
().
strides
()[
b_idx
];
assert
(
bshape
.
lens
()[
b_idx
]
==
b_len
);
if
(
b_len
<=
2048
and
std
::
none_of
(
std
::
next
(
b_it
),
strides
.
end
(),
not_zero
))
{
const
bool
divisible_by_4
=
(
b_len
%
4
==
0
)
and
(
b_stride
%
4
==
0
)
and
(
front_args
(
args
...).
get_shape
().
elements
()
%
4
==
0
);
if
(
divisible_by_4
)
nary_broadcast_vec_impl
(
stream
,
f
,
result
,
barg
,
args2
...);
else
nary_broadcast_impl
(
stream
,
f
,
result
,
barg
,
args2
...);
return
false
;
}
}
}
return
true
;
return
true
;
});
});
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment