Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
928cb435
Commit
928cb435
authored
Jul 23, 2019
by
Paul
Browse files
Refactor nary
parent
784dc2aa
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
117 additions
and
30 deletions
+117
-30
src/targets/gpu/device/include/migraphx/gpu/device/nary.hpp
src/targets/gpu/device/include/migraphx/gpu/device/nary.hpp
+117
-30
No files found.
src/targets/gpu/device/include/migraphx/gpu/device/nary.hpp
View file @
928cb435
...
@@ -118,6 +118,49 @@ void nary_broadcast_impl(hipStream_t stream, F f, argument result, argument barg
...
@@ -118,6 +118,49 @@ void nary_broadcast_impl(hipStream_t stream, F f, argument result, argument barg
});
});
}
}
/// Element-wise kernel launcher for the case where two inputs (`barg1` and
/// `barg2`) are broadcast along one axis of `result`.
///
/// The broadcast axis is taken to be the first axis of `barg1` whose stride is
/// non-zero. Both broadcast inputs are first copied into one per-block staging
/// buffer (`barg1` in slots [0, bdim_len), `barg2` in slots
/// [bdim_len, 2 * bdim_len)), and every output element is then computed as
/// `f(inputs[i]..., b1, b2)`.
///
/// @param stream  HIP stream the kernel is launched on.
/// @param f       Callable applied per element; receives the non-broadcast
///                inputs first, followed by the two broadcast values.
/// @param result  Output argument; its shape defines the iteration space.
/// @param barg1   First broadcast input; its strides select the broadcast axis.
/// @param barg2   Second broadcast input.
///                NOTE(review): assumed to be broadcast along the same axis
///                and with the same length as `barg1` - confirm at call sites.
/// @param args    Remaining inputs, indexed with the same linear index as
///                `result`.
///
/// Precondition: `2 * bdim_len <= 2048`, because both broadcast vectors share
/// the single 2048-element staging buffer.
template <class F, class... Arguments>
void nary_double_broadcast_impl(
    hipStream_t stream, F f, argument result, argument barg1, argument barg2, Arguments... args)
{
    const auto& output_shape = result.get_shape();
    const auto& b_shape      = barg1.get_shape();
    const auto& b_strides    = b_shape.strides();
    // Index of the first axis with a non-zero stride: the broadcast axis.
    auto bdim = std::distance(
        b_strides.begin(),
        std::find_if(b_strides.begin(), b_strides.end(), [](auto x) { return x != 0; }));
    auto bdim_len         = output_shape.lens()[bdim];
    auto bdim_stride      = output_shape.strides()[bdim];
    auto bdim_next_stride = bdim_stride * bdim_len;

    // Both broadcast vectors are staged back-to-back in the same buffer.
    assert(2 * bdim_len <= 2048);

    const std::size_t nlocal  = 1024;
    const std::size_t nglobal = 256 * nlocal;
    std::size_t nelements     = result.get_shape().elements();
    hip_visit_all(result, barg1, barg2, args...)(
        [&](auto output, auto binput1, auto binput2, auto... inputs) {
            using type = typename decltype(output)::value_type;
            launch(stream, nglobal, nlocal)([=](auto idx) __device__ {
                MIGRAPHX_DEVICE_SHARED type buffer[2048];
                // Load both biases into LDS: first vector in the lower half...
                for(size_t i = idx.local; i < bdim_len; i += nlocal)
                {
                    buffer[i] = binput1.data()[i];
                }
                // ...second vector right after it. The source index is `i`:
                // only the destination slot is offset by bdim_len. Indexing the
                // source with `i + bdim_len` would read past the end of binput2.
                for(size_t i = idx.local; i < bdim_len; i += nlocal)
                {
                    buffer[i + bdim_len] = binput2.data()[i];
                }
                // All staged values must be visible before any thread reads them.
                __syncthreads();
                // Process the data
                for(size_t i = idx.global; i < nelements; i += nglobal)
                {
                    // Position of element i along the broadcast axis.
                    auto bidx        = (i % bdim_next_stride) / bdim_stride;
                    auto b1          = buffer[bidx];
                    auto b2          = buffer[bidx + bdim_len];
                    output.data()[i] = f(inputs.data()[i]..., b1, b2);
                }
            });
        });
}
template
<
class
F
,
class
...
Arguments
>
template
<
class
F
,
class
...
Arguments
>
void
nary_standard_vec_impl
(
hipStream_t
stream
,
F
f
,
argument
result
,
Arguments
...
args
)
void
nary_standard_vec_impl
(
hipStream_t
stream
,
F
f
,
argument
result
,
Arguments
...
args
)
{
{
...
@@ -176,46 +219,90 @@ auto nary_standard(hipStream_t stream, argument result, Arguments... args)
...
@@ -176,46 +219,90 @@ auto nary_standard(hipStream_t stream, argument result, Arguments... args)
return
[
=
](
auto
f
)
{
nary_standard_impl
(
stream
,
f
,
result
,
args
...);
};
return
[
=
](
auto
f
)
{
nary_standard_impl
(
stream
,
f
,
result
,
args
...);
};
}
}
template
<
class
...
Arguments
>
template
<
class
...
Arguments
>
auto
nary
(
hipStream_t
stream
,
argument
result
)
bool
broadcastable
(
bool
&
divisible_by_4
,
argument
result
,
argument
barg
,
Arguments
...
args
)
{
divisible_by_4
=
false
;
auto
bshape
=
barg
.
get_shape
();
const
bool
standard
=
all_of
({
args
.
get_shape
()...},
[](
const
shape
&
s
)
{
return
s
.
standard
();
});
const
bool
same_shapes
=
all_of
(
{
args
.
get_shape
()...},
[
&
](
const
shape
&
s
)
{
return
s
==
result
.
get_shape
();
});
// TODO: Check result and args shape is the same
if
(
standard
and
same_shapes
and
bshape
.
broadcasted
()
and
not
bshape
.
scalar
())
{
auto
not_zero
=
[](
auto
x
)
{
return
x
!=
0
;
};
const
auto
&
strides
=
bshape
.
strides
();
auto
b_it
=
std
::
find_if
(
strides
.
begin
(),
strides
.
end
(),
not_zero
);
auto
b_idx
=
std
::
distance
(
strides
.
begin
(),
b_it
);
auto
b_len
=
result
.
get_shape
().
lens
()[
b_idx
];
auto
b_stride
=
result
.
get_shape
().
strides
()[
b_idx
];
assert
(
bshape
.
lens
()[
b_idx
]
==
b_len
);
if
(
b_len
<=
2048
and
std
::
none_of
(
std
::
next
(
b_it
),
strides
.
end
(),
not_zero
))
{
divisible_by_4
=
(
b_len
%
4
==
0
)
and
(
b_stride
%
4
==
0
)
and
(
front_args
(
args
...).
get_shape
().
elements
()
%
4
==
0
);
return
true
;
}
}
return
false
;
}
/// Fallback overload for the case with no non-broadcast inputs: it always
/// reports "not broadcastable" and clears the vectorization flag.
///
/// @param divisible_by_4  Out-parameter; always set to false.
/// @return Always false.
inline bool broadcastable(bool& divisible_by_4, argument, argument)
{
    const bool can_broadcast = false;
    divisible_by_4           = can_broadcast;
    return can_broadcast;
}
// Nullary
/// Build a launcher for an operation that takes no inputs.
///
/// @param stream  HIP stream handed to the implementation.
/// @param result  Output argument.
/// @return A callable that accepts the operation `f` and forwards it, with
///         `stream` and `result`, to `nary_standard_impl`.
inline auto nary(hipStream_t stream, argument result)
{
    auto launcher = [=](auto op) { nary_standard_impl(stream, op, result); };
    return launcher;
}
// Unary
/// Build a launcher for an operation that takes a single input.
///
/// @param stream  HIP stream handed to the implementation.
/// @param result  Output argument.
/// @param arg     The single input argument.
/// @return A callable that accepts the operation `f` and forwards it, with
///         `stream`, `result` and `arg`, to `nary_impl`.
inline auto nary(hipStream_t stream, argument result, argument arg)
{
    auto launcher = [=](auto op) { nary_impl(stream, op, result, arg); };
    return launcher;
}
// Binary
/// Build a launcher for an operation that takes two inputs, where the second
/// one (`barg`) may be a broadcast operand.
///
/// When invoked with the operation `f`, the returned callable asks
/// `broadcastable` whether `barg` qualifies for the broadcast kernels:
///  - not broadcastable: `nary_impl(stream, f, result, arg, barg)`;
///  - broadcastable and the divisible-by-4 flag is set:
///    `nary_broadcast_vec_impl(stream, f, result, barg, arg)`;
///  - broadcastable otherwise:
///    `nary_broadcast_impl(stream, f, result, barg, arg)`.
///
/// @param stream  HIP stream handed to the implementation.
/// @param result  Output argument.
/// @param arg     First input.
/// @param barg    Second input, checked for broadcasting.
inline auto nary(hipStream_t stream, argument result, argument arg, argument barg)
{
    return [=](auto op) {
        bool use_vec_kernel = false;
        if(not broadcastable(use_vec_kernel, result, barg, arg))
        {
            // Generic path: arguments keep their original order.
            nary_impl(stream, op, result, arg, barg);
            return;
        }

        // Broadcast kernels take the broadcast operand before the others.
        if(use_vec_kernel)
            nary_broadcast_vec_impl(stream, op, result, barg, arg);
        else
            nary_broadcast_impl(stream, op, result, barg, arg);
    };
}
template
<
class
...
Arguments
>
template
<
class
...
Arguments
>
auto
nary
(
hipStream_t
stream
,
argument
result
,
Arguments
...
args
)
auto
nary
(
hipStream_t
stream
,
argument
result
,
Arguments
...
args
)
{
{
return
[
=
](
auto
f
)
{
return
[
=
](
auto
f
)
{
auto
barg
=
back_args
(
args
...);
auto
barg
1
=
back_args
(
args
...);
bool
fallback
=
pop_back_args
(
args
...)([
&
](
auto
&&
...
args2
)
{
bool
fallback
=
pop_back_args
(
args
...)([
&
](
auto
&&
...
args2
)
{
auto
bshape
=
barg
.
get_shape
();
bool
divisible_by_4
=
false
;
const
bool
standard
=
if
(
broadcastable
(
divisible_by_4
,
result
,
barg1
,
args2
...))
all_of
({
args2
.
get_shape
()...},
[](
const
shape
&
s
)
{
return
s
.
standard
();
});
const
bool
same_shapes
=
all_of
(
{
args2
.
get_shape
()...},
[
&
](
const
shape
&
s
)
{
return
s
==
result
.
get_shape
();
});
// TODO: Check result and args shape is the same
if
(
standard
and
same_shapes
and
bshape
.
broadcasted
()
and
not
bshape
.
scalar
())
{
{
auto
not_zero
=
[](
auto
x
)
{
return
x
!=
0
;
};
if
(
divisible_by_4
)
const
auto
&
strides
=
bshape
.
strides
();
nary_broadcast_vec_impl
(
stream
,
f
,
result
,
barg1
,
args2
...);
auto
b_it
=
std
::
find_if
(
strides
.
begin
(),
strides
.
end
(),
not_zero
);
else
auto
b_idx
=
std
::
distance
(
strides
.
begin
(),
b_it
);
nary_broadcast_impl
(
stream
,
f
,
result
,
barg1
,
args2
...);
auto
b_len
=
result
.
get_shape
().
lens
()[
b_idx
];
return
false
;
auto
b_stride
=
result
.
get_shape
().
strides
()[
b_idx
];
assert
(
bshape
.
lens
()[
b_idx
]
==
b_len
);
if
(
b_len
<=
2048
and
std
::
none_of
(
std
::
next
(
b_it
),
strides
.
end
(),
not_zero
))
{
const
bool
divisible_by_4
=
(
b_len
%
4
==
0
)
and
(
b_stride
%
4
==
0
)
and
(
front_args
(
args
...).
get_shape
().
elements
()
%
4
==
0
);
if
(
divisible_by_4
)
nary_broadcast_vec_impl
(
stream
,
f
,
result
,
barg
,
args2
...);
else
nary_broadcast_impl
(
stream
,
f
,
result
,
barg
,
args2
...);
return
false
;
}
}
}
return
true
;
return
true
;
});
});
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment