Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
6459204c
Commit
6459204c
authored
Aug 28, 2018
by
Paul
Browse files
Fix broadcast tensor op
parent
3b04798c
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
101 additions
and
82 deletions
+101
-82
src/include/migraph/generate.hpp
src/include/migraph/generate.hpp
+1
-0
src/targets/gpu/device/add_relu.cpp
src/targets/gpu/device/add_relu.cpp
+1
-1
src/targets/gpu/device/include/migraph/gpu/device/nary.hpp
src/targets/gpu/device/include/migraph/gpu/device/nary.hpp
+84
-81
test/gpu/miopen.cpp
test/gpu/miopen.cpp
+15
-0
No files found.
src/include/migraph/generate.hpp
View file @
6459204c
...
@@ -80,6 +80,7 @@ std::vector<T> generate_tensor_data(const migraph::shape& s, unsigned long seed
...
@@ -80,6 +80,7 @@ std::vector<T> generate_tensor_data(const migraph::shape& s, unsigned long seed
std
::
vector
<
T
>
result
(
s
.
elements
());
std
::
vector
<
T
>
result
(
s
.
elements
());
std
::
generate
(
result
.
begin
(),
result
.
end
(),
xorshf96_generator
<
T
>
{
seed
});
std
::
generate
(
result
.
begin
(),
result
.
end
(),
xorshf96_generator
<
T
>
{
seed
});
// std::generate(result.begin(), result.end(), [&]{ return seed % 7; });
// std::generate(result.begin(), result.end(), [&]{ return seed % 7; });
// std::generate(result.begin(), result.end(), []{ return 1; });
return
result
;
return
result
;
}
}
...
...
src/targets/gpu/device/add_relu.cpp
View file @
6459204c
...
@@ -8,7 +8,7 @@ namespace device {
...
@@ -8,7 +8,7 @@ namespace device {
void
add_relu
(
argument
result
,
argument
arg1
,
argument
arg2
)
void
add_relu
(
argument
result
,
argument
arg1
,
argument
arg2
)
{
{
nary
(
std
::
move
(
result
),
std
::
move
(
arg1
),
std
::
move
(
arg2
))(
nary
(
std
::
move
(
result
),
std
::
move
(
arg1
),
std
::
move
(
arg2
))(
[](
auto
x
,
auto
y
)
{
return
max
(
0
,
x
+
y
);
});
[](
auto
x
,
auto
y
)
{
return
std
::
max
<
decltype
(
x
+
y
)
>
(
0
,
x
+
y
);
});
}
}
}
// namespace device
}
// namespace device
...
...
src/targets/gpu/device/include/migraph/gpu/device/nary.hpp
View file @
6459204c
...
@@ -14,25 +14,15 @@ template <class T>
...
@@ -14,25 +14,15 @@ template <class T>
using
vec4
=
T
__attribute__
((
ext_vector_type
(
4
)));
using
vec4
=
T
__attribute__
((
ext_vector_type
(
4
)));
template
<
class
T
>
template
<
class
T
>
vec4
<
T
>*
as_vec4
(
T
*
x
)
__device__
__host__
vec4
<
T
>*
as_vec4
(
T
*
x
)
{
{
std
::
uintptr_t
a
=
reinterpret_cast
<
std
::
uintptr_t
>
(
x
);
if
(
a
%
32
!=
0
)
throw
std
::
runtime_error
(
"Memory not aligned for vector operations"
);
return
reinterpret_cast
<
vec4
<
T
>*>
(
x
);
return
reinterpret_cast
<
vec4
<
T
>*>
(
x
);
// return (vec4<T>*)(x);
}
}
template
<
class
T
>
template
<
class
T
>
vec4
<
T
>
vec4_load
(
T
*
x
,
size_t
i
)
__device__
__host__
T
*
as_pointer
(
vec4
<
T
>*
x
)
{
{
vec4
<
T
>
result
;
return
reinterpret_cast
<
T
*>
(
x
);
auto
n
=
i
*
4
;
result
[
0
]
=
x
[
n
+
0
];
result
[
1
]
=
x
[
n
+
1
];
result
[
2
]
=
x
[
n
+
2
];
result
[
3
]
=
x
[
n
+
3
];
return
result
;
}
}
template
<
class
...
Ts
>
template
<
class
...
Ts
>
...
@@ -67,46 +57,21 @@ auto nary_nonstandard(argument result, Arguments... args)
...
@@ -67,46 +57,21 @@ auto nary_nonstandard(argument result, Arguments... args)
return
[
=
](
auto
f
)
{
return
nary_nonstandard_impl
(
f
,
result
,
args
...);
};
return
[
=
](
auto
f
)
{
return
nary_nonstandard_impl
(
f
,
result
,
args
...);
};
}
}
inline
auto
binary_broadcast
(
argument
result
,
argument
arg1
,
argument
arg2
)
inline
auto
binary_broadcast
_vec
(
argument
result
,
argument
arg1
,
argument
arg2
)
{
{
return
[
=
](
auto
f
)
{
return
[
=
](
auto
f
)
{
//
const auto& output_shape = result.get_shape();
const
auto
&
output_shape
=
result
.
get_shape
();
const
auto
&
b_shape
=
arg2
.
get_shape
();
const
auto
&
b_shape
=
arg2
.
get_shape
();
auto
bdim
=
std
::
distance
(
b_shape
.
strides
().
begin
(),
auto
bdim
=
std
::
distance
(
b_shape
.
strides
().
begin
(),
std
::
find_if
(
b_shape
.
strides
().
begin
(),
std
::
find_if
(
b_shape
.
strides
().
begin
(),
b_shape
.
strides
().
end
(),
b_shape
.
strides
().
end
(),
[](
auto
x
)
{
return
x
!=
0
;
}));
[](
auto
x
)
{
return
x
!=
0
;
}));
auto
bdim_len
=
b_shape
.
lens
()[
bdim
];
auto
bdim_len
=
output_shape
.
lens
()[
bdim
];
auto
bdim_stride
=
output_shape
.
strides
()[
bdim
];
auto
bdim_next_stride
=
bdim_stride
*
bdim_len
;
visit_all
(
result
,
arg1
,
arg2
)([
&
](
auto
output
,
auto
input1
,
auto
input2
)
{
visit_all
(
result
,
arg1
,
arg2
)([
&
](
auto
output
,
auto
input1
,
auto
input2
)
{
using
type
=
std
::
remove_cv_t
<
typename
decltype
(
output
)
::
value_type
>
;
using
type
=
std
::
remove_cv_t
<
typename
decltype
(
output
)
::
value_type
>
;
#if 1
auto
*
xp
=
input1
.
data
();
auto
*
yp
=
input2
.
data
();
auto
*
outp
=
output
.
data
();
const
std
::
size_t
nlocal
=
1024
;
const
std
::
size_t
nglobal
=
256
*
nlocal
;
const
std
::
size_t
n
=
output
.
size
();
launch
(
nglobal
,
nlocal
)([
=
](
auto
idx
)
__device__
{
__shared__
type
buffer
[
2048
];
// Load bias into LDS
for
(
size_t
i
=
idx
.
local
;
i
<
bdim_len
;
i
+=
nlocal
)
{
buffer
[
i
]
=
yp
[
i
];
}
__syncthreads
();
// Process the data
for
(
size_t
i
=
idx
.
global
;
i
<
n
;
i
+=
nglobal
)
{
auto
bidx
=
i
%
bdim_len
;
auto
b
=
buffer
[
bidx
];
type
x
=
xp
[
i
];
outp
[
i
]
=
f
(
x
,
b
);
}
});
#else
auto
*
xp
=
as_vec4
(
input1
.
data
());
auto
*
xp
=
as_vec4
(
input1
.
data
());
auto
*
yp
=
as_vec4
(
input2
.
data
());
auto
*
yp
=
as_vec4
(
input2
.
data
());
auto
*
outp
=
as_vec4
(
output
.
data
());
auto
*
outp
=
as_vec4
(
output
.
data
());
...
@@ -115,63 +80,86 @@ inline auto binary_broadcast(argument result, argument arg1, argument arg2)
...
@@ -115,63 +80,86 @@ inline auto binary_broadcast(argument result, argument arg1, argument arg2)
const
std
::
size_t
nlocal
=
1024
;
const
std
::
size_t
nlocal
=
1024
;
const
std
::
size_t
nglobal
=
256
*
nlocal
;
const
std
::
size_t
nglobal
=
256
*
nlocal
;
const
std
::
size_t
n
=
output
.
size
()
/
vec_size
;
const
std
::
size_t
n
=
output
.
size
()
/
vec_size
;
const
std
::
size_t
rem
=
output
.
size
()
%
vec_size
;
const
std
::
size_t
bdim_vec_len
=
bdim_len
/
vec_size
;
const
std
::
size_t
bdim_vec_len
=
bdim_len
/
vec_size
;
const
std
::
size_t
bdim_vec_rem
=
bdim_len
%
vec_size
;
launch
(
nglobal
,
nlocal
)([
=
](
auto
idx
)
__device__
{
launch
(
nglobal
,
nlocal
)([
=
](
auto
idx
)
__device__
{
__shared__
vec4
<
type
>
buffer
[
2048
];
__shared__
vec4
<
type
>
buffer
[
2048
/
vec_size
];
// Load bias into LDS
// Load bias into LDS
for
(
size_t
i
=
idx
.
local
;
i
<
bdim_vec_len
;
i
+=
nlocal
)
for
(
size_t
i
=
idx
.
local
;
i
<
bdim_vec_len
;
i
+=
nlocal
)
{
{
buffer
[
i
]
=
yp
[
i
];
buffer
[
i
]
=
yp
[
i
];
}
}
// Load remaining bias data
for
(
size_t
i
=
idx
.
local
;
i
<
bdim_vec_rem
;
i
+=
nlocal
)
{
buffer
[
bdim_vec_len
][
i
]
=
yp
[
bdim_vec_len
][
i
];
}
for
(
size_t
i
=
idx
.
local
;
i
<
(
vec_size
-
bdim_vec_rem
);
i
+=
nlocal
)
{
buffer
[
bdim_vec_len
][
i
]
=
yp
[
0
][
i
];
}
__syncthreads
();
__syncthreads
();
auto
*
bp
=
as_pointer
(
buffer
);
// Process the data
// Process the data
for
(
size_t
i
=
idx
.
global
;
i
<
n
;
i
+=
nglobal
)
for
(
size_t
i
=
idx
.
global
;
i
<
n
;
i
+=
nglobal
)
{
{
auto
bidx
=
bdim_vec_len
==
0
?
0
:
i
%
bdim_vec_len
;
auto
bidx
=
((
i
*
vec_size
)
%
bdim_next_stride
)
/
bdim_stride
;
auto
b
=
b
uffer
[
bidx
];
auto
b
=
b
p
[
bidx
];
vec4
<
type
>
x
=
xp
[
i
];
vec4
<
type
>
x
=
xp
[
i
];
vec4
<
type
>
out
=
outp
[
i
];
vec4
<
type
>
out
=
outp
[
i
];
for
(
std
::
size_t
j
=
0
;
j
<
vec_size
;
j
++
)
for
(
std
::
size_t
j
=
0
;
j
<
vec_size
;
j
++
)
{
{
out
[
j
]
=
f
(
x
[
j
],
b
[
j
]
);
out
[
j
]
=
f
(
x
[
j
],
b
);
}
}
outp
[
i
]
=
out
;
outp
[
i
]
=
out
;
}
}
for
(
size_t
i
=
idx
.
global
;
i
<
rem
;
i
+=
nglobal
)
});
});
};
}
inline
auto
binary_broadcast
(
argument
result
,
argument
arg1
,
argument
arg2
)
{
return
[
=
](
auto
f
)
{
const
auto
&
output_shape
=
result
.
get_shape
();
const
auto
&
b_shape
=
arg2
.
get_shape
();
auto
bdim
=
std
::
distance
(
b_shape
.
strides
().
begin
(),
std
::
find_if
(
b_shape
.
strides
().
begin
(),
b_shape
.
strides
().
end
(),
[](
auto
x
)
{
return
x
!=
0
;
}));
auto
bdim_len
=
output_shape
.
lens
()[
bdim
];
auto
bdim_stride
=
output_shape
.
strides
()[
bdim
];
auto
bdim_next_stride
=
bdim_stride
*
bdim_len
;
visit_all
(
result
,
arg1
,
arg2
)([
&
](
auto
output
,
auto
input1
,
auto
input2
)
{
using
type
=
std
::
remove_cv_t
<
typename
decltype
(
output
)
::
value_type
>
;
auto
*
xp
=
input1
.
data
();
auto
*
yp
=
input2
.
data
();
auto
*
outp
=
output
.
data
();
const
std
::
size_t
nlocal
=
1024
;
const
std
::
size_t
nglobal
=
256
*
nlocal
;
const
std
::
size_t
n
=
output
.
size
();
launch
(
nglobal
,
nlocal
)([
=
](
auto
idx
)
__device__
{
__shared__
type
buffer
[
2048
];
// Load bias into LDS
for
(
size_t
i
=
idx
.
local
;
i
<
bdim_len
;
i
+=
nlocal
)
{
buffer
[
i
]
=
yp
[
i
];
}
__syncthreads
();
// Process the data
for
(
size_t
i
=
idx
.
global
;
i
<
n
;
i
+=
nglobal
)
{
{
outp
[
n
][
i
]
=
f
(
xp
[
n
][
i
],
buffer
[
bdim_vec_len
][
i
]);
auto
bidx
=
(
i
%
bdim_next_stride
)
/
bdim_stride
;
auto
b
=
buffer
[
bidx
];
type
x
=
xp
[
i
];
outp
[
i
]
=
f
(
x
,
b
);
}
}
});
});
#endif
});
});
};
};
}
}
template
<
class
...
Arguments
>
template
<
class
...
Arguments
>
auto
nary_standard
(
argument
result
,
Arguments
...
args
)
auto
nary_standard
_vec
(
argument
result
,
Arguments
...
args
)
{
{
return
[
=
](
auto
f
)
{
return
[
=
](
auto
f
)
{
// assert(x.get_shape().elements() == y.get_shape().elements());
// assert(x.get_shape().elements() == y.get_shape().elements());
const
auto
&
output_shape
=
result
.
get_shape
();
const
auto
&
output_shape
=
result
.
get_shape
();
visit_all
(
result
,
args
...)([
&
](
auto
output
,
auto
...
inputs
)
{
visit_all
(
result
,
args
...)([
&
](
auto
output
,
auto
...
inputs
)
{
#if 1
auto
data
=
pack
(
inputs
.
data
()...);
auto
*
outp
=
output
.
data
();
gs_launch
(
output_shape
.
elements
())(
[
=
](
auto
i
)
{
data
([
&
](
auto
...
xps
)
{
outp
[
i
]
=
f
(
xps
[
i
]...);
});
});
#else
using
type
=
std
::
remove_cv_t
<
typename
decltype
(
output
)
::
value_type
>
;
using
type
=
std
::
remove_cv_t
<
typename
decltype
(
output
)
::
value_type
>
;
const
std
::
size_t
vec_size
=
4
;
const
std
::
size_t
vec_size
=
4
;
auto
data
=
pack_vec4
(
inputs
.
data
()...);
auto
data
=
pack_vec4
(
inputs
.
data
()...);
...
@@ -188,7 +176,21 @@ auto nary_standard(argument result, Arguments... args)
...
@@ -188,7 +176,21 @@ auto nary_standard(argument result, Arguments... args)
i
);
i
);
outp
[
i
]
=
out
;
outp
[
i
]
=
out
;
});
});
#endif
});
};
}
template
<
class
...
Arguments
>
auto
nary_standard
(
argument
result
,
Arguments
...
args
)
{
return
[
=
](
auto
f
)
{
// assert(x.get_shape().elements() == y.get_shape().elements());
const
auto
&
output_shape
=
result
.
get_shape
();
visit_all
(
result
,
args
...)([
&
](
auto
output
,
auto
...
inputs
)
{
auto
data
=
pack
(
inputs
.
data
()...);
auto
*
outp
=
output
.
data
();
gs_launch
(
output_shape
.
elements
())(
[
=
](
auto
i
)
{
data
([
&
](
auto
...
xps
)
{
outp
[
i
]
=
f
(
xps
[
i
]...);
});
});
});
});
};
};
}
}
...
@@ -219,22 +221,23 @@ inline auto nary(argument result, argument arg1, argument arg2)
...
@@ -219,22 +221,23 @@ inline auto nary(argument result, argument arg1, argument arg2)
{
{
return
[
=
](
auto
f
)
{
return
[
=
](
auto
f
)
{
// TODO: Check result and arg1 shape is the same
// TODO: Check result and arg1 shape is the same
if
(
arg1
.
get_shape
().
standard
()
and
arg2
.
get_shape
().
broadcasted
()
and
if
(
arg1
.
get_shape
().
standard
()
and
arg2
.
get_shape
().
broadcasted
())
std
::
count_if
(
arg2
.
get_shape
().
strides
().
begin
(),
arg2
.
get_shape
().
strides
().
end
(),
[](
auto
x
)
{
return
x
!=
0
;
})
==
1
)
{
{
auto
not_zero
=
[](
auto
x
)
{
return
x
!=
0
;
};
auto
not_zero
=
[](
auto
x
)
{
return
x
!=
0
;
};
const
auto
&
strides
=
arg2
.
get_shape
().
strides
();
const
auto
&
strides
=
arg2
.
get_shape
().
strides
();
auto
stride
_it
=
std
::
find_if
(
strides
.
begin
(),
strides
.
end
(),
not_zero
);
auto
b
_it
=
std
::
find_if
(
strides
.
begin
(),
strides
.
end
(),
not_zero
);
auto
stride
_idx
=
std
::
distance
(
strides
.
begin
(),
stride
_it
);
auto
b
_idx
=
std
::
distance
(
strides
.
begin
(),
b
_it
);
auto
stride
_len
=
arg2
.
get_shape
().
lens
()[
stride
_idx
];
auto
b
_len
=
result
.
get_shape
().
lens
()[
b
_idx
];
// TODO: Dont require disibility by 4
auto
b_stride
=
result
.
get_shape
().
strides
()[
b_idx
];
bool
divisible_by_4
=
(
stride_len
%
4
==
0
)
and
(
arg
1
.
get_shape
().
e
le
ments
()
%
4
==
0
);
assert
(
arg
2
.
get_shape
().
le
ns
()[
b_idx
]
==
b_len
);
if
(
divisible_by_4
and
stride
_len
<=
2048
and
if
(
b
_len
<=
2048
and
std
::
none_of
(
std
::
next
(
stride
_it
),
strides
.
end
(),
not_zero
))
std
::
none_of
(
std
::
next
(
b
_it
),
strides
.
end
(),
not_zero
))
{
{
binary_broadcast
(
result
,
arg1
,
arg2
)(
f
);
const
bool
divisible_by_4
=
(
b_len
%
4
==
0
)
and
(
b_stride
%
4
==
0
)
and
(
arg1
.
get_shape
().
elements
()
%
4
==
0
);
if
(
divisible_by_4
)
binary_broadcast_vec
(
result
,
arg1
,
arg2
)(
f
);
else
binary_broadcast
(
result
,
arg1
,
arg2
)(
f
);
return
;
return
;
}
}
}
}
...
...
test/gpu/miopen.cpp
View file @
6459204c
...
@@ -260,6 +260,20 @@ struct test_add_broadcast4
...
@@ -260,6 +260,20 @@ struct test_add_broadcast4
}
}
};
};
struct
test_add_broadcast5
{
migraph
::
program
create_program
()
const
{
migraph
::
program
p
;
migraph
::
shape
s
{
migraph
::
shape
::
float_type
,
{
3
}};
auto
x
=
p
.
add_parameter
(
"x"
,
{
migraph
::
shape
::
float_type
,
{
2
,
4
,
8
}});
auto
y
=
p
.
add_parameter
(
"y"
,
{
migraph
::
shape
::
float_type
,
{
4
}});
auto
by
=
p
.
add_instruction
(
migraph
::
broadcast
{
1
},
x
,
y
);
p
.
add_instruction
(
migraph
::
add
{},
x
,
by
);
return
p
;
}
};
struct
test_conv_relu
struct
test_conv_relu
{
{
migraph
::
program
create_program
()
const
migraph
::
program
create_program
()
const
...
@@ -471,6 +485,7 @@ int main()
...
@@ -471,6 +485,7 @@ int main()
verify_program
<
test_add_broadcast2
>
();
verify_program
<
test_add_broadcast2
>
();
verify_program
<
test_add_broadcast3
>
();
verify_program
<
test_add_broadcast3
>
();
verify_program
<
test_add_broadcast4
>
();
verify_program
<
test_add_broadcast4
>
();
verify_program
<
test_add_broadcast5
>
();
verify_program
<
test_conv_relu
>
();
verify_program
<
test_conv_relu
>
();
verify_program
<
test_add_relu
>
();
verify_program
<
test_add_relu
>
();
verify_program
<
test_conv_pooling
>
();
verify_program
<
test_conv_pooling
>
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment