Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
d5ade1e7
Commit
d5ade1e7
authored
Jul 05, 2019
by
Paul
Browse files
Merge branch 'develop' into tf-transpose
parents
02c28d6a
3db703df
Changes
34
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
587 additions
and
154 deletions
+587
-154
src/include/migraphx/generate.hpp
src/include/migraphx/generate.hpp
+2
-2
src/include/migraphx/literal.hpp
src/include/migraphx/literal.hpp
+1
-0
src/include/migraphx/op/multibroadcast.hpp
src/include/migraphx/op/multibroadcast.hpp
+17
-3
src/include/migraphx/op/reduce_sum.hpp
src/include/migraphx/op/reduce_sum.hpp
+56
-0
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+1
-0
src/include/migraphx/shape.hpp
src/include/migraphx/shape.hpp
+2
-0
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+9
-1
src/pass_manager.cpp
src/pass_manager.cpp
+0
-1
src/rewrite_rnn.cpp
src/rewrite_rnn.cpp
+14
-1
src/shape.cpp
src/shape.cpp
+18
-0
src/targets/cpu/lowering.cpp
src/targets/cpu/lowering.cpp
+64
-53
src/targets/gpu/CMakeLists.txt
src/targets/gpu/CMakeLists.txt
+2
-0
src/targets/gpu/device/include/migraphx/gpu/device/array.hpp
src/targets/gpu/device/include/migraphx/gpu/device/array.hpp
+8
-0
src/targets/gpu/device/include/migraphx/gpu/device/launch.hpp
...targets/gpu/device/include/migraphx/gpu/device/launch.hpp
+42
-10
src/targets/gpu/device/include/migraphx/gpu/device/reduce.hpp
...targets/gpu/device/include/migraphx/gpu/device/reduce.hpp
+248
-0
src/targets/gpu/device/include/migraphx/gpu/device/shape.hpp
src/targets/gpu/device/include/migraphx/gpu/device/shape.hpp
+16
-0
src/targets/gpu/device/include/migraphx/gpu/device/types.hpp
src/targets/gpu/device/include/migraphx/gpu/device/types.hpp
+3
-3
src/targets/gpu/device/logsoftmax.cpp
src/targets/gpu/device/logsoftmax.cpp
+34
-41
src/targets/gpu/device/reduce_sum.cpp
src/targets/gpu/device/reduce_sum.cpp
+17
-0
src/targets/gpu/device/softmax.cpp
src/targets/gpu/device/softmax.cpp
+33
-39
No files found.
src/include/migraphx/generate.hpp
View file @
d5ade1e7
...
@@ -25,7 +25,7 @@ constexpr T normalize(unsigned long z)
...
@@ -25,7 +25,7 @@ constexpr T normalize(unsigned long z)
template
<
class
T
,
MIGRAPHX_REQUIRES
(
is_signed
<
T
>{}
and
not
is_floating_point
<
T
>
{})
>
template
<
class
T
,
MIGRAPHX_REQUIRES
(
is_signed
<
T
>{}
and
not
is_floating_point
<
T
>
{})
>
constexpr
T
normalize
(
unsigned
long
z
)
constexpr
T
normalize
(
unsigned
long
z
)
{
{
const
auto
max
=
std
::
numeric_limits
<
T
>::
max
();
const
auto
max
=
std
::
numeric_limits
<
T
>::
max
()
/
64
;
const
auto
half_max
=
max
/
2
;
const
auto
half_max
=
max
/
2
;
return
half_max
-
(
z
%
max
);
return
half_max
-
(
z
%
max
);
}
}
...
@@ -33,7 +33,7 @@ constexpr T normalize(unsigned long z)
...
@@ -33,7 +33,7 @@ constexpr T normalize(unsigned long z)
template
<
class
T
,
MIGRAPHX_REQUIRES
(
not
is_signed
<
T
>{}
and
std
::
is_integral
<
T
>
{})
>
template
<
class
T
,
MIGRAPHX_REQUIRES
(
not
is_signed
<
T
>{}
and
std
::
is_integral
<
T
>
{})
>
constexpr
T
normalize
(
unsigned
long
z
)
constexpr
T
normalize
(
unsigned
long
z
)
{
{
const
auto
max
=
std
::
numeric_limits
<
T
>::
max
();
const
auto
max
=
std
::
numeric_limits
<
T
>::
max
()
/
64
;
return
z
%
max
;
return
z
%
max
;
}
}
...
...
src/include/migraphx/literal.hpp
View file @
d5ade1e7
...
@@ -79,6 +79,7 @@ struct literal : raw_data<literal>
...
@@ -79,6 +79,7 @@ struct literal : raw_data<literal>
template
<
class
Iterator
>
template
<
class
Iterator
>
void
fill
(
Iterator
start
,
Iterator
end
)
void
fill
(
Iterator
start
,
Iterator
end
)
{
{
assert
(
std
::
distance
(
start
,
end
)
==
m_shape
.
elements
());
if
(
m_shape
.
standard
())
if
(
m_shape
.
standard
())
{
{
m_shape
.
visit_type
([
&
](
auto
as
)
{
std
::
copy
(
start
,
end
,
as
.
from
(
buffer
.
get
()));
});
m_shape
.
visit_type
([
&
](
auto
as
)
{
std
::
copy
(
start
,
end
,
as
.
from
(
buffer
.
get
()));
});
...
...
src/include/migraphx/op/multibroadcast.hpp
View file @
d5ade1e7
...
@@ -35,14 +35,28 @@ struct multibroadcast
...
@@ -35,14 +35,28 @@ struct multibroadcast
auto
input
=
inputs
.
at
(
0
);
auto
input
=
inputs
.
at
(
0
);
if
(
input
.
lens
().
empty
())
if
(
input
.
lens
().
empty
())
MIGRAPHX_THROW
(
"inputs dimensions should be > 0"
);
{
MIGRAPHX_THROW
(
"MULTIBROADCAST: inputs dimensions should be > 0"
);
}
if
(
input
.
lens
().
size
()
>
output_lens
.
size
())
if
(
input
.
lens
().
size
()
>
output_lens
.
size
())
MIGRAPHX_THROW
(
"inputs dimensions should <= output size"
);
{
MIGRAPHX_THROW
(
"MULTIBROADCAST: inputs dimensions should <= output size"
);
}
std
::
vector
<
size_t
>
bcast_strides
(
output_lens
.
size
(),
0
);
auto
offset
=
output_lens
.
size
()
-
input
.
lens
().
size
();
auto
offset
=
output_lens
.
size
()
-
input
.
lens
().
size
();
for
(
std
::
ptrdiff_t
i
=
input
.
lens
().
size
()
-
1
;
i
>=
0
;
i
--
)
for
(
std
::
ptrdiff_t
i
=
input
.
lens
().
size
()
-
1
;
i
>=
0
;
i
--
)
{
if
(
output_lens
[
i
+
offset
]
!=
input
.
lens
()[
i
]
and
input
.
lens
()[
i
]
!=
1
)
{
MIGRAPHX_THROW
(
"MULTIBROADCAST: input shape {"
+
to_string_range
(
input
.
lens
())
+
"} cannot be broadcasted to {"
+
to_string_range
(
output_lens
)
+
"}!"
);
}
}
std
::
vector
<
size_t
>
bcast_strides
(
output_lens
.
size
(),
0
);
for
(
std
::
ptrdiff_t
i
=
input
.
lens
().
size
()
-
1
;
i
>=
0
;
i
--
)
{
{
if
(
output_lens
[
i
+
offset
]
==
input
.
lens
()[
i
])
if
(
output_lens
[
i
+
offset
]
==
input
.
lens
()[
i
])
{
{
...
...
src/include/migraphx/op/reduce_sum.hpp
0 → 100644
View file @
d5ade1e7
#ifndef MIGRAPHX_GUARD_OPERATORS_SUM_HPP
#define MIGRAPHX_GUARD_OPERATORS_SUM_HPP
#include <migraphx/check_shapes.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <vector>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
op
{
struct
reduce_sum
{
std
::
vector
<
std
::
size_t
>
axes
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
axes
,
"axes"
));
}
std
::
string
name
()
const
{
return
"reduce_sum"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
1
);
auto
s
=
inputs
.
at
(
0
);
auto
lens
=
s
.
lens
();
for
(
auto
axis
:
axes
)
lens
[
axis
]
=
1
;
return
{
s
.
type
(),
lens
};
}
argument
compute
(
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
output_shape
};
visit_all
(
result
,
args
[
0
])([
&
](
auto
output
,
auto
input
)
{
shape_for_each
(
input
.
get_shape
(),
[
&
](
auto
&&
in_idx
)
{
auto
out_idx
=
in_idx
;
for
(
auto
axis
:
axes
)
out_idx
[
axis
]
=
0
;
output
(
out_idx
.
begin
(),
out_idx
.
end
())
+=
input
(
in_idx
.
begin
(),
in_idx
.
end
());
});
});
return
result
;
}
};
}
// namespace op
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/include/migraphx/operators.hpp
View file @
d5ade1e7
...
@@ -42,6 +42,7 @@
...
@@ -42,6 +42,7 @@
#include <migraphx/op/outline.hpp>
#include <migraphx/op/outline.hpp>
#include <migraphx/op/pad.hpp>
#include <migraphx/op/pad.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/reduce_sum.hpp>
#include <migraphx/op/relu.hpp>
#include <migraphx/op/relu.hpp>
#include <migraphx/op/reshape.hpp>
#include <migraphx/op/reshape.hpp>
#include <migraphx/op/rnn.hpp>
#include <migraphx/op/rnn.hpp>
...
...
src/include/migraphx/shape.hpp
View file @
d5ade1e7
...
@@ -99,6 +99,8 @@ struct shape
...
@@ -99,6 +99,8 @@ struct shape
/// Map element index to space index
/// Map element index to space index
std
::
size_t
index
(
std
::
size_t
i
)
const
;
std
::
size_t
index
(
std
::
size_t
i
)
const
;
std
::
vector
<
std
::
size_t
>
multi
(
std
::
size_t
i
)
const
;
/// Returns true if the shape is packed with no padding
/// Returns true if the shape is packed with no padding
bool
packed
()
const
;
bool
packed
()
const
;
/// Returns true is the shape has been transposed. That is the strides are not in descending
/// Returns true is the shape has been transposed. That is the strides are not in descending
...
...
src/onnx/onnx.cpp
View file @
d5ade1e7
...
@@ -182,7 +182,15 @@ struct onnx_parser
...
@@ -182,7 +182,15 @@ struct onnx_parser
s0
.
end
(),
s0
.
end
(),
s1
.
begin
()
+
offset
,
s1
.
begin
()
+
offset
,
out_lens
.
begin
()
+
offset
,
out_lens
.
begin
()
+
offset
,
[](
auto
a
,
auto
b
)
{
return
std
::
max
(
a
,
b
);
});
[
&
](
auto
a
,
auto
b
)
{
if
(
a
!=
b
and
a
!=
1
and
b
!=
1
)
{
MIGRAPHX_THROW
(
"COMPUTE_BROADCASTLEN: shape {"
+
to_string_range
(
s0
)
+
"} and {"
+
to_string_range
(
s1
)
+
"} mismatch!"
);
}
return
std
::
max
(
a
,
b
);
});
return
out_lens
;
return
out_lens
;
}
}
...
...
src/pass_manager.cpp
View file @
d5ade1e7
...
@@ -2,7 +2,6 @@
...
@@ -2,7 +2,6 @@
#include <migraphx/pass_manager.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/target.hpp>
#include <migraphx/target.hpp>
#include <migraphx/env.hpp>
#include <migraphx/env.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/ranges.hpp>
...
...
src/rewrite_rnn.cpp
View file @
d5ade1e7
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/rewrite_rnn.hpp>
#include <migraphx/program.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/op/add.hpp>
#include <migraphx/op/broadcast.hpp>
#include <migraphx/op/concat.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/gru.hpp>
#include <migraphx/op/lstm.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/op/rnn.hpp>
#include <migraphx/op/rnn_last_output.hpp>
#include <migraphx/op/slice.hpp>
#include <migraphx/op/squeeze.hpp>
#include <migraphx/op/sub.hpp>
#include <migraphx/op/transpose.hpp>
#include <migraphx/op/unsqueeze.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/op/common.hpp>
#include <migraphx/op/common.hpp>
...
...
src/shape.cpp
View file @
d5ade1e7
...
@@ -138,6 +138,24 @@ std::size_t shape::index(std::size_t i) const
...
@@ -138,6 +138,24 @@ std::size_t shape::index(std::size_t i) const
return
result
;
return
result
;
}
}
}
}
std
::
vector
<
std
::
size_t
>
shape
::
multi
(
std
::
size_t
i
)
const
{
assert
(
this
->
standard
());
std
::
vector
<
std
::
size_t
>
indices
(
lens
().
size
());
std
::
transform
(
strides
().
begin
(),
strides
().
end
(),
lens
().
begin
(),
indices
.
begin
(),
[
&
](
std
::
size_t
stride
,
std
::
size_t
len
)
{
assert
(
len
>
0
and
stride
>
0
);
return
(
i
/
stride
)
%
len
;
});
return
indices
;
}
bool
shape
::
packed
()
const
{
return
this
->
elements
()
==
this
->
element_space
();
}
bool
shape
::
packed
()
const
{
return
this
->
elements
()
==
this
->
element_space
();
}
bool
shape
::
transposed
()
const
bool
shape
::
transposed
()
const
...
...
src/targets/cpu/lowering.cpp
View file @
d5ade1e7
...
@@ -2,7 +2,17 @@
...
@@ -2,7 +2,17 @@
#include <migraphx/cpu/lowering.hpp>
#include <migraphx/cpu/lowering.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/operators.hpp>
#include <migraphx/op/batch_norm.hpp>
#include <migraphx/op/convolution.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/elu.hpp>
#include <migraphx/op/im2col.hpp>
#include <migraphx/op/leaky_relu.hpp>
#include <migraphx/op/logsoftmax.hpp>
#include <migraphx/op/lrn.hpp>
#include <migraphx/op/pad.hpp>
#include <migraphx/op/pooling.hpp>
#include <migraphx/op/softmax.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/par_dfor.hpp>
#include <migraphx/par_dfor.hpp>
...
@@ -529,18 +539,11 @@ struct cpu_softmax
...
@@ -529,18 +539,11 @@ struct cpu_softmax
std
::
string
name
()
const
{
return
"cpu::softmax"
;
}
std
::
string
name
()
const
{
return
"cpu::softmax"
;
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
return
op
.
compute_shape
(
inputs
);
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
return
op
.
compute_shape
(
inputs
);
}
template
<
typename
T
>
std
::
size_t
compute_batch_index
(
T
idx
,
shape
&
batch_shape
,
int
axis
)
const
{
idx
[
axis
]
=
0
;
return
batch_shape
.
index
(
idx
);
}
argument
compute
(
context
&
,
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
context
&
,
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
argument
result
{
output_shape
};
argument
result
{
output_shape
};
auto
batch_lens
=
output_shape
.
lens
();
auto
batch_lens
=
output_shape
.
lens
();
std
::
size_t
n_dims
=
batch_lens
[
op
.
axis
];
batch_lens
[
op
.
axis
]
=
1
;
batch_lens
[
op
.
axis
]
=
1
;
shape
batch_shape
{
shape
::
int32_type
,
batch_lens
};
shape
batch_shape
{
shape
::
int32_type
,
batch_lens
};
...
@@ -548,26 +551,33 @@ struct cpu_softmax
...
@@ -548,26 +551,33 @@ struct cpu_softmax
using
value_type
=
typename
decltype
(
input
)
::
value_type
;
using
value_type
=
typename
decltype
(
input
)
::
value_type
;
std
::
vector
<
value_type
>
batch_max
(
batch_shape
.
elements
(),
std
::
vector
<
value_type
>
batch_max
(
batch_shape
.
elements
(),
std
::
numeric_limits
<
value_type
>::
lowest
());
std
::
numeric_limits
<
value_type
>::
lowest
());
shape_for_each
(
output_shape
,
[
&
](
auto
idx
)
{
std
::
vector
<
value_type
>
batch_sum
(
batch_shape
.
elements
(),
value_type
(
0
));
auto
index
=
this
->
compute_batch_index
(
idx
,
batch_shape
,
op
.
axis
);
par_for
(
batch_shape
.
elements
(),
[
&
](
auto
i
)
{
batch_max
[
index
]
=
std
::
max
(
batch_max
[
index
],
input
(
idx
.
begin
(),
idx
.
end
()));
auto
idx
=
batch_shape
.
multi
(
i
);
});
for
(
std
::
size_t
j
=
0
;
j
<
n_dims
;
++
j
)
{
idx
[
op
.
axis
]
=
j
;
batch_max
[
i
]
=
std
::
max
(
batch_max
[
i
],
input
(
idx
.
begin
(),
idx
.
end
()));
}
shape_for_each
(
output_shape
,
[
&
](
auto
idx
)
{
for
(
std
::
size_t
j
=
0
;
j
<
n_dims
;
++
j
)
auto
index
=
this
->
compute_batch_index
(
idx
,
batch_shape
,
op
.
axis
);
{
output
(
idx
.
begin
(),
idx
.
end
())
=
idx
[
op
.
axis
]
=
j
;
std
::
exp
(
input
(
idx
.
begin
(),
idx
.
end
())
-
batch_max
[
index
]);
std
::
size_t
index
=
output_shape
.
index
(
idx
);
});
output
[
index
]
=
std
::
exp
(
input
[
index
]
-
batch_max
[
i
]);
}
std
::
vector
<
value_type
>
batch_sum
(
batch_shape
.
elements
(),
value_type
(
0
));
for
(
std
::
size_t
j
=
0
;
j
<
n_dims
;
++
j
)
shape_for_each
(
output_shape
,
[
&
](
auto
idx
)
{
{
auto
index
=
this
->
compute_batch_index
(
idx
,
batch_shape
,
op
.
axis
)
;
idx
[
op
.
axis
]
=
j
;
batch_sum
[
i
ndex
]
+=
output
(
idx
.
begin
(),
idx
.
end
());
batch_sum
[
i
]
+=
output
(
idx
.
begin
(),
idx
.
end
());
});
}
shape_for_each
(
output_shape
,
[
&
](
auto
idx
)
{
for
(
std
::
size_t
j
=
0
;
j
<
n_dims
;
++
j
)
auto
index
=
this
->
compute_batch_index
(
idx
,
batch_shape
,
op
.
axis
);
{
output
(
idx
.
begin
(),
idx
.
end
())
/=
batch_sum
[
index
];
idx
[
op
.
axis
]
=
j
;
output
(
idx
.
begin
(),
idx
.
end
())
/=
batch_sum
[
i
];
}
});
});
});
});
...
@@ -587,49 +597,50 @@ struct cpu_logsoftmax
...
@@ -587,49 +597,50 @@ struct cpu_logsoftmax
std
::
string
name
()
const
{
return
"cpu::logsoftmax"
;
}
std
::
string
name
()
const
{
return
"cpu::logsoftmax"
;
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
return
op
.
compute_shape
(
inputs
);
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
return
op
.
compute_shape
(
inputs
);
}
template
<
typename
T
>
std
::
size_t
compute_batch_index
(
T
idx
,
const
shape
&
batch_shape
,
int
axis
)
const
{
idx
[
axis
]
=
0
;
return
batch_shape
.
index
(
idx
);
}
argument
compute
(
context
&
,
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
context
&
,
const
shape
&
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
argument
result
{
output_shape
};
argument
result
{
output_shape
};
auto
batch_lens
=
output_shape
.
lens
();
auto
batch_lens
=
output_shape
.
lens
();
std
::
size_t
n_dims
=
batch_lens
[
op
.
axis
];
batch_lens
[
op
.
axis
]
=
1
;
batch_lens
[
op
.
axis
]
=
1
;
shape
batch_shape
{
shape
::
int32_type
,
batch_lens
};
shape
batch_shape
{
shape
::
int32_type
,
batch_lens
};
// use a parallel implementation to acheive better performance
// one thread for one batch
visit_all
(
result
,
args
[
0
])([
&
](
auto
output
,
auto
input
)
{
visit_all
(
result
,
args
[
0
])([
&
](
auto
output
,
auto
input
)
{
using
value_type
=
typename
decltype
(
input
)
::
value_type
;
using
value_type
=
typename
decltype
(
input
)
::
value_type
;
std
::
vector
<
value_type
>
batch_max
(
batch_shape
.
elements
(),
std
::
vector
<
value_type
>
batch_max
(
batch_shape
.
elements
(),
std
::
numeric_limits
<
value_type
>::
lowest
());
std
::
numeric_limits
<
value_type
>::
lowest
());
shape_for_each
(
output_shape
,
[
&
](
auto
idx
)
{
std
::
vector
<
value_type
>
batch_sum
(
batch_shape
.
elements
(),
value_type
(
0
));
auto
index
=
this
->
compute_batch_index
(
idx
,
batch_shape
,
op
.
axis
);
batch_max
[
index
]
=
std
::
max
(
batch_max
[
index
],
input
(
idx
.
begin
(),
idx
.
end
()));
});
shape_for_each
(
output_shape
,
[
&
](
auto
idx
)
{
par_for
(
batch_shape
.
elements
(),
[
&
](
auto
i
)
{
auto
index
=
this
->
compute_batch_index
(
idx
,
batch_shape
,
op
.
axis
);
auto
idx
=
batch_shape
.
multi
(
i
);
output
(
idx
.
begin
(),
idx
.
end
())
=
input
(
idx
.
begin
(),
idx
.
end
())
-
batch_max
[
index
];
for
(
std
::
size_t
j
=
0
;
j
<
n_dims
;
++
j
)
});
{
idx
[
op
.
axis
]
=
j
;
batch_max
[
i
]
=
std
::
max
(
batch_max
[
i
],
input
(
idx
.
begin
(),
idx
.
end
()));
}
std
::
vector
<
value_type
>
batch_sum
(
batch_shape
.
elements
(),
value_type
(
0
));
for
(
std
::
size_t
j
=
0
;
j
<
n_dims
;
++
j
)
shape_for_each
(
output_shape
,
[
&
](
auto
idx
)
{
{
auto
index
=
this
->
compute_batch_index
(
idx
,
batch_shape
,
op
.
axis
);
idx
[
op
.
axis
]
=
j
;
batch_sum
[
index
]
+=
std
::
exp
(
output
(
idx
.
begin
(),
idx
.
end
()));
std
::
size_t
index
=
output_shape
.
index
(
idx
);
});
output
[
index
]
=
input
[
index
]
-
batch_max
[
i
];
}
for
(
std
::
size_t
j
=
0
;
j
<
n_dims
;
++
j
)
{
idx
[
op
.
axis
]
=
j
;
batch_sum
[
i
]
+=
std
::
exp
(
output
(
idx
.
begin
(),
idx
.
end
()));
}
for
(
std
::
size_t
i
=
0
;
i
<
batch_sum
.
size
();
++
i
)
{
batch_sum
[
i
]
=
std
::
log
(
batch_sum
[
i
]);
batch_sum
[
i
]
=
std
::
log
(
batch_sum
[
i
]);
}
shape_for_each
(
output_shape
,
[
&
](
auto
idx
)
{
for
(
std
::
size_t
j
=
0
;
j
<
n_dims
;
++
j
)
auto
index
=
this
->
compute_batch_index
(
idx
,
batch_shape
,
op
.
axis
);
{
output
(
idx
.
begin
(),
idx
.
end
())
-=
batch_sum
[
index
];
idx
[
op
.
axis
]
=
j
;
output
(
idx
.
begin
(),
idx
.
end
())
-=
batch_sum
[
i
];
}
});
});
});
});
...
...
src/targets/gpu/CMakeLists.txt
View file @
d5ade1e7
...
@@ -35,6 +35,7 @@ add_library(migraphx_device
...
@@ -35,6 +35,7 @@ add_library(migraphx_device
device/gather.cpp
device/gather.cpp
device/sub.cpp
device/sub.cpp
device/clip.cpp
device/clip.cpp
device/reduce_sum.cpp
)
)
set_target_properties
(
migraphx_device PROPERTIES EXPORT_NAME device
)
set_target_properties
(
migraphx_device PROPERTIES EXPORT_NAME device
)
rocm_clang_tidy_check
(
migraphx_device
)
rocm_clang_tidy_check
(
migraphx_device
)
...
@@ -70,6 +71,7 @@ add_library(migraphx_gpu
...
@@ -70,6 +71,7 @@ add_library(migraphx_gpu
schedule_model.cpp
schedule_model.cpp
adjust_allocation.cpp
adjust_allocation.cpp
clip.cpp
clip.cpp
reduce_sum.cpp
)
)
set_target_properties
(
migraphx_gpu PROPERTIES EXPORT_NAME gpu
)
set_target_properties
(
migraphx_gpu PROPERTIES EXPORT_NAME gpu
)
rocm_clang_tidy_check
(
migraphx_gpu
)
rocm_clang_tidy_check
(
migraphx_gpu
)
...
...
src/targets/gpu/device/include/migraphx/gpu/device/array.hpp
View file @
d5ade1e7
...
@@ -50,6 +50,14 @@ struct hip_array
...
@@ -50,6 +50,14 @@ struct hip_array
result
[
i
]
=
x
[
i
]
*
y
[
i
];
result
[
i
]
=
x
[
i
]
*
y
[
i
];
return
result
;
return
result
;
}
}
friend
MIGRAPHX_DEVICE_CONSTEXPR
hip_array
operator
+
(
const
hip_array
&
x
,
const
hip_array
&
y
)
{
hip_array
result
{};
for
(
std
::
size_t
i
=
0
;
i
<
N
;
i
++
)
result
[
i
]
=
x
[
i
]
+
y
[
i
];
return
result
;
}
};
};
}
// namespace device
}
// namespace device
...
...
src/targets/gpu/device/include/migraphx/gpu/device/launch.hpp
View file @
d5ade1e7
...
@@ -11,9 +11,33 @@ namespace device {
...
@@ -11,9 +11,33 @@ namespace device {
struct
index
struct
index
{
{
std
::
size_t
global
;
std
::
size_t
global
=
0
;
std
::
size_t
local
;
std
::
size_t
local
=
0
;
std
::
size_t
group
;
std
::
size_t
group
=
0
;
__device__
std
::
size_t
nglobal
()
const
{
return
blockDim
.
x
*
gridDim
.
x
;
}
// NOLINT
__device__
std
::
size_t
nlocal
()
const
{
return
blockDim
.
x
;
}
// NOLINT
template
<
class
F
>
__device__
void
global_stride
(
std
::
size_t
n
,
F
f
)
const
{
const
auto
stride
=
nglobal
();
for
(
std
::
size_t
i
=
global
;
i
<
n
;
i
+=
stride
)
{
f
(
i
);
}
}
template
<
class
F
>
__device__
void
local_stride
(
std
::
size_t
n
,
F
f
)
const
{
const
auto
stride
=
nlocal
();
for
(
std
::
size_t
i
=
local
;
i
<
n
;
i
+=
stride
)
{
f
(
i
);
}
}
};
};
template
<
class
F
>
template
<
class
F
>
...
@@ -35,18 +59,26 @@ inline auto launch(hipStream_t stream, std::size_t global, std::size_t local)
...
@@ -35,18 +59,26 @@ inline auto launch(hipStream_t stream, std::size_t global, std::size_t local)
};
};
}
}
template
<
class
F
>
__host__
__device__
auto
gs_invoke
(
F
&&
f
,
std
::
size_t
i
,
index
idx
)
->
decltype
(
f
(
i
,
idx
))
{
return
f
(
i
,
idx
);
}
template
<
class
F
>
__host__
__device__
auto
gs_invoke
(
F
&&
f
,
std
::
size_t
i
,
index
)
->
decltype
(
f
(
i
))
{
return
f
(
i
);
}
inline
auto
gs_launch
(
hipStream_t
stream
,
std
::
size_t
n
,
std
::
size_t
local
=
1024
)
inline
auto
gs_launch
(
hipStream_t
stream
,
std
::
size_t
n
,
std
::
size_t
local
=
1024
)
{
{
std
::
size_t
groups
=
1
+
n
/
local
;
std
::
size_t
groups
=
(
n
+
local
-
1
)
/
local
;
std
::
size_t
nglobal
=
std
::
min
<
std
::
size_t
>
(
256
,
groups
)
*
local
;
std
::
size_t
nglobal
=
std
::
min
<
std
::
size_t
>
(
256
,
groups
)
*
local
;
return
[
=
](
auto
f
)
{
return
[
=
](
auto
f
)
{
launch
(
stream
,
nglobal
,
local
)([
=
](
auto
idx
)
{
launch
(
stream
,
nglobal
,
local
)(
for
(
size_t
i
=
idx
.
global
;
i
<
n
;
i
+=
nglobal
)
[
=
](
auto
idx
)
{
idx
.
global_stride
(
n
,
[
&
](
auto
i
)
{
gs_invoke
(
f
,
i
,
idx
);
});
});
{
f
(
i
);
}
});
};
};
}
}
...
...
src/targets/gpu/device/include/migraphx/gpu/device/reduce.hpp
0 → 100644
View file @
d5ade1e7
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_REDUCE_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_REDUCE_HPP
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/visit.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
namespace
device
{
struct
sum
{
template
<
class
T
,
class
U
>
MIGRAPHX_DEVICE_CONSTEXPR
auto
operator
()(
T
x
,
U
y
)
const
{
return
x
+
y
;
}
};
struct
id
{
template
<
class
T
>
MIGRAPHX_DEVICE_CONSTEXPR
auto
operator
()(
T
x
)
const
{
return
x
;
}
};
struct
max
{
template
<
class
T
,
class
U
>
MIGRAPHX_DEVICE_CONSTEXPR
auto
operator
()(
T
x
,
U
y
)
const
{
return
x
>
y
?
x
:
y
;
}
};
struct
min
{
template
<
class
T
,
class
U
>
MIGRAPHX_DEVICE_CONSTEXPR
auto
operator
()(
T
x
,
U
y
)
const
{
return
x
<
y
?
x
:
y
;
}
};
struct
lowest
{
template
<
class
T
>
operator
T
()
const
{
return
device_cast
(
std
::
numeric_limits
<
host_type
<
T
>>::
lowest
());
}
};
struct
highest
{
template
<
class
T
>
operator
T
()
const
{
return
device_cast
(
std
::
numeric_limits
<
host_type
<
T
>>::
max
());
}
};
#ifdef MIGRAPHX_NO_DPP
template
<
std
::
size_t
N
,
class
Op
,
class
T
,
class
F
>
__device__
auto
block_reduce
(
index
idx
,
Op
op
,
T
init
,
std
::
size_t
n
,
F
f
)
{
using
type
=
decltype
(
f
(
idx
.
local
));
MIGRAPHX_DEVICE_SHARED
type
buffer
[
N
];
type
x
=
init
;
idx
.
local_stride
(
n
,
[
&
](
auto
i
)
{
x
=
op
(
x
,
f
(
i
));
});
buffer
[
idx
.
local
]
=
x
;
__syncthreads
();
for
(
std
::
size_t
s
=
1
;
s
<
idx
.
nlocal
();
s
*=
2
)
{
const
std
::
size_t
index
=
2
*
s
*
idx
.
local
;
if
(
index
+
s
<
idx
.
nlocal
())
{
buffer
[
index
]
=
op
(
buffer
[
index
],
buffer
[
index
+
s
]);
}
__syncthreads
();
}
return
buffer
[
0
];
}
#else
constexpr
unsigned
int
dpp_row_shr
(
unsigned
int
x
)
{
return
0x110u
|
x
;
}
constexpr
unsigned
int
dpp_row_bcast
(
unsigned
int
x
)
{
unsigned
int
y
=
0
;
switch
(
x
)
{
case
15
:
y
=
0x142
;
break
;
case
31
:
y
=
0x143
;
break
;
default:
throw
std
::
runtime_error
(
"Unknown bcast"
);
}
return
y
;
}
template
<
unsigned
int
DppCtrl
,
unsigned
int
RowMask
=
0xf
,
unsigned
int
BankMask
=
0xf
,
bool
BoundCtrl
=
false
,
class
T
>
__device__
T
dpp_mov
(
T
&
x
)
{
static
const
std
::
size_t
n
=
sizeof
(
T
)
<
4
?
1
:
sizeof
(
T
)
/
4
;
union
type
{
uint32_t
reg
[
n
];
T
data
;
};
type
output
{};
type
input
{};
// cppcheck-suppress unreadVariable
input
.
data
=
x
;
for
(
std
::
size_t
i
=
0
;
i
<
n
;
i
++
)
{
output
.
reg
[
i
]
=
__llvm_amdgcn_move_dpp
(
input
.
reg
[
i
],
DppCtrl
,
RowMask
,
BankMask
,
BoundCtrl
);
}
return
output
.
data
;
}
template
<
class
T
,
class
Op
>
__device__
void
dpp_reduce
(
T
&
in
,
Op
op
)
{
T
out
;
out
=
dpp_mov
<
dpp_row_shr
(
1
)
>
(
in
);
in
=
op
(
in
,
out
);
out
=
dpp_mov
<
dpp_row_shr
(
2
)
>
(
in
);
in
=
op
(
in
,
out
);
out
=
dpp_mov
<
dpp_row_shr
(
4
),
0xf
,
0xe
>
(
in
);
in
=
op
(
in
,
out
);
out
=
dpp_mov
<
dpp_row_shr
(
8
),
0xf
,
0xc
>
(
in
);
in
=
op
(
in
,
out
);
out
=
dpp_mov
<
dpp_row_bcast
(
15
),
0xa
>
(
in
);
in
=
op
(
in
,
out
);
out
=
dpp_mov
<
dpp_row_bcast
(
31
),
0xc
>
(
in
);
in
=
op
(
in
,
out
);
}
__device__
inline
void
dpp_reduce
(
float
&
x
,
sum
)
{
#ifdef MIGRAPHX_USE_CLANG_TIDY
(
void
)
x
;
#else
__asm__
volatile
(
"s_nop 4
\n
"
"v_add_f32 %0 %0 %0 row_shr:1
\n
"
"s_nop 1
\n
"
"v_add_f32 %0 %0 %0 row_shr:2
\n
"
"s_nop 1
\n
"
"v_add_f32 %0 %0 %0 row_shr:4 bank_mask:0xe
\n
"
"s_nop 1
\n
"
"v_add_f32 %0 %0 %0 row_shr:8 bank_mask:0xc
\n
"
"s_nop 1
\n
"
"v_add_f32 %0 %0 %0 row_bcast:15 row_mask:0xa
\n
"
"s_nop 1
\n
"
"v_add_f32 %0 %0 %0 row_bcast:31 row_mask:0xc
\n
"
"s_nop 1
\n
"
:
"=v"
(
x
)
:
"0"
(
x
));
#endif
}
template
<
std
::
size_t
N
,
class
Op
,
class
T
,
class
F
>
__device__
auto
block_reduce
(
index
idx
,
Op
op
,
T
init
,
std
::
size_t
n
,
F
f
)
{
using
type
=
decltype
(
f
(
idx
.
local
));
MIGRAPHX_DEVICE_SHARED
type
buffer
[
N
/
64
];
type
x
=
init
;
idx
.
local_stride
(
n
,
[
&
](
auto
i
)
{
x
=
op
(
x
,
f
(
i
));
});
dpp_reduce
(
x
,
op
);
const
auto
ldsidx
=
idx
.
local
/
64
;
if
((
idx
.
local
%
64
)
==
63
)
{
buffer
[
ldsidx
]
=
x
;
}
__syncthreads
();
type
y
=
init
;
for
(
std
::
size_t
i
=
0
;
i
<
idx
.
nlocal
()
/
64
;
i
++
)
{
y
=
op
(
y
,
buffer
[
i
]);
}
return
y
;
}
#endif
constexpr
std
::
size_t
compute_block_size
(
std
::
size_t
n
,
std
::
size_t
max_block_size
)
{
size_t
block_size
=
64
;
while
(
block_size
<
max_block_size
and
block_size
<
n
)
block_size
*=
2
;
return
block_size
;
}
template
<
class
Op
,
class
T
,
class
Input
,
class
Output
>
void
reduce
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
Op
op
,
T
init
,
Input
read_input
,
Output
read_output
)
{
auto
&&
output_shape
=
result
.
get_shape
();
auto
&&
input_shape
=
arg
.
get_shape
();
std
::
vector
<
std
::
size_t
>
reduce_lens
;
std
::
transform
(
output_shape
.
lens
().
begin
(),
output_shape
.
lens
().
end
(),
input_shape
.
lens
().
begin
(),
std
::
back_inserter
(
reduce_lens
),
[](
auto
x
,
auto
y
)
->
std
::
size_t
{
if
(
x
==
y
)
return
1
;
else
return
y
;
});
shape
reduce_slice
{
output_shape
.
type
(),
reduce_lens
};
hip_visit_all
(
result
,
arg
,
reduce_slice
)([
&
](
auto
output
,
auto
input
,
auto
reduce_shape
)
{
auto
nelements
=
result
.
get_shape
().
elements
();
auto
relements
=
reduce_slice
.
elements
();
const
std
::
size_t
max_block_size
=
256
;
const
std
::
size_t
block_size
=
compute_block_size
(
relements
,
max_block_size
);
gs_launch
(
stream
,
nelements
*
block_size
,
block_size
)([
=
](
auto
i
,
auto
idx
)
__device__
{
const
auto
out_idx
=
i
/
block_size
;
auto
base_idx
=
output
.
get_shape
().
multi
(
out_idx
);
auto
r
=
block_reduce
<
max_block_size
>
(
idx
,
op
,
init
,
relements
,
[
&
](
auto
j
)
__device__
{
auto
reduce_idx
=
reduce_shape
.
multi
(
j
);
return
read_input
(
input
[
reduce_idx
+
base_idx
]);
});
if
(
idx
.
local
==
0
)
output
.
data
()[
out_idx
]
=
read_output
(
r
);
});
});
}
}
// namespace device
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif
src/targets/gpu/device/include/migraphx/gpu/device/shape.hpp
View file @
d5ade1e7
...
@@ -73,6 +73,22 @@ struct hip_shape
...
@@ -73,6 +73,22 @@ struct hip_shape
}
}
return
result
;
return
result
;
}
}
MIGRAPHX_DEVICE_CONSTEXPR
hip_index
carry
(
hip_index
result
)
const
{
std
::
ptrdiff_t
rem
=
0
;
for
(
std
::
ptrdiff_t
i
=
result
.
size
()
-
1
;
i
>=
0
;
i
--
)
{
auto
z
=
result
[
i
]
+
rem
;
rem
=
z
-
std
::
ptrdiff_t
(
lens
[
i
])
+
1
;
if
(
rem
>
0
)
z
-=
rem
;
else
rem
=
0
;
result
[
i
]
=
z
;
}
return
result
;
}
};
};
template
<
std
::
size_t
N
>
template
<
std
::
size_t
N
>
...
...
src/targets/gpu/device/include/migraphx/gpu/device/types.hpp
View file @
d5ade1e7
...
@@ -91,7 +91,7 @@ using device_type = typename detail::device_type<T>::type;
...
@@ -91,7 +91,7 @@ using device_type = typename detail::device_type<T>::type;
template
<
class
T
>
template
<
class
T
>
host_type
<
T
>
host_cast
(
T
x
)
host_type
<
T
>
host_cast
(
T
x
)
{
{
return
reinterpret_cast
<
host_type
<
T
>>
(
x
);
return
reinterpret_cast
<
const
host_type
<
T
>
&
>
(
x
);
}
}
template
<
class
T
>
template
<
class
T
>
...
@@ -119,13 +119,13 @@ tensor_view<device_type<T>> device_cast(tensor_view<T> x)
...
@@ -119,13 +119,13 @@ tensor_view<device_type<T>> device_cast(tensor_view<T> x)
}
}
template
<
class
T
>
template
<
class
T
>
T
to_hip_type
(
T
x
)
__device__
__host__
T
to_hip_type
(
T
x
)
{
{
return
x
;
return
x
;
}
}
// Hip doens't support __fp16
// Hip doens't support __fp16
inline
float
to_hip_type
(
gpu_half
x
)
{
return
x
;
}
inline
__device__
__host__
float
to_hip_type
(
gpu_half
x
)
{
return
x
;
}
}
// namespace device
}
// namespace device
}
// namespace gpu
}
// namespace gpu
...
...
src/targets/gpu/device/logsoftmax.cpp
View file @
d5ade1e7
#include <migraphx/shape.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/gpu/device/logsoftmax.hpp>
#include <migraphx/gpu/device/logsoftmax.hpp>
#include <migraphx/gpu/device/reduce.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/device/types.hpp>
...
@@ -11,53 +12,45 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -11,53 +12,45 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace
gpu
{
namespace
gpu
{
namespace
device
{
namespace
device
{
argument
logsoftmax
(
hipStream_t
stream
,
argument
result
,
argument
arg
,
int
axis
)
void
logsoftmax
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int
axis
)
{
{
auto
lens
=
result
.
get_shape
().
lens
();
auto
lens
=
result
.
get_shape
().
lens
();
auto
batch_lens
=
lens
;
auto
num_in_batch
=
lens
[
axis
];
std
::
size_t
batch_item_num
=
lens
[
axis
];
auto
batch_lens
=
lens
;
batch_lens
[
axis
]
=
1
;
batch_lens
[
axis
]
=
1
;
migraphx
::
shape
batch_shape
{
result
.
get_shape
().
type
(),
batch_lens
};
shape
batch_shape
{
result
.
get_shape
().
type
(),
batch_lens
};
hip_visit_all
(
result
,
arg
,
batch_shape
)([
&
](
auto
output
,
auto
input
,
auto
batch
)
{
hip_visit_all
(
result
,
arg
,
batch_shape
)([
&
](
auto
output
,
auto
input
,
auto
batch
)
{
const
std
::
size_t
max_block_size
=
256
;
// each thread is for one item in the batch
const
std
::
size_t
block_size
=
compute_block_size
(
batch_item_num
,
max_block_size
);
gs_launch
(
stream
,
batch_shape
.
elements
())([
=
](
auto
i
)
{
gs_launch
(
stream
,
auto
batch_idx
=
batch
.
multi
(
i
);
batch_shape
.
elements
()
*
block_size
,
auto
data_idx
=
batch_idx
;
block_size
)([
=
](
auto
i
,
auto
idx
)
__device__
{
auto
data_idx
=
batch
.
multi
(
i
/
block_size
);
// get max
using
type
=
device_type
<
std
::
remove_cv_t
<
typename
decltype
(
input
)
::
value_type
>>
;
auto
batch_max
=
input
[
batch_idx
];
type
init
=
lowest
();
for
(
std
::
size_t
j
=
1
;
j
<
num_in_batch
;
++
j
)
{
auto
batch_max
=
block_reduce
<
max_block_size
>
(
data_idx
[
axis
]
=
j
;
idx
,
max
{},
init
,
batch_item_num
,
[
&
](
auto
j
)
__device__
{
batch_max
=
std
::
max
(
to_hip_type
(
batch_max
),
to_hip_type
(
input
[
data_idx
]));
data_idx
[
axis
]
=
j
;
}
return
input
[
data_idx
];
});
for
(
std
::
size_t
j
=
0
;
j
<
num_in_batch
;
++
j
)
{
auto
batch_sum
=
block_reduce
<
max_block_size
>
(
idx
,
sum
{},
0
,
batch_item_num
,
[
&
](
auto
j
)
__device__
{
data_idx
[
axis
]
=
j
;
auto
val
=
input
[
data_idx
]
-
batch_max
;
return
::
exp
(
to_hip_type
(
val
));
});
auto
log_batch_sum
=
::
log
(
to_hip_type
(
batch_sum
))
+
batch_max
;
idx
.
local_stride
(
batch_item_num
,
[
&
](
auto
j
)
{
data_idx
[
axis
]
=
j
;
data_idx
[
axis
]
=
j
;
output
[
data_idx
]
=
input
[
data_idx
]
-
batch_max
;
output
[
data_idx
]
=
input
[
data_idx
]
-
log_batch_sum
;
}
});
auto
batch_sum
=
::
exp
(
to_hip_type
(
output
[
batch_idx
]));
for
(
std
::
size_t
j
=
1
;
j
<
num_in_batch
;
++
j
)
{
data_idx
[
axis
]
=
j
;
batch_sum
+=
::
exp
(
to_hip_type
(
output
[
data_idx
]));
}
batch_sum
=
::
log
(
to_hip_type
(
batch_sum
));
for
(
std
::
size_t
j
=
0
;
j
<
num_in_batch
;
++
j
)
{
data_idx
[
axis
]
=
j
;
output
[
data_idx
]
-=
batch_sum
;
}
});
});
});
});
return
result
;
}
}
}
// namespace device
}
// namespace device
...
...
src/targets/gpu/device/reduce_sum.cpp
0 → 100644
View file @
d5ade1e7
#include <migraphx/gpu/device/reduce_sum.hpp>
#include <migraphx/gpu/device/reduce.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
namespace
device
{
void
reduce_sum
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
)
{
reduce
(
stream
,
result
,
arg
,
sum
{},
0
,
id
{},
id
{});
}
}
// namespace device
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/gpu/device/softmax.cpp
View file @
d5ade1e7
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
#include <migraphx/argument.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/gpu/device/softmax.hpp>
#include <migraphx/gpu/device/softmax.hpp>
#include <migraphx/gpu/device/reduce.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/tensor.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/launch.hpp>
#include <migraphx/gpu/device/types.hpp>
#include <migraphx/gpu/device/types.hpp>
...
@@ -12,51 +13,44 @@ inline namespace MIGRAPHX_INLINE_NS {
...
@@ -12,51 +13,44 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace
gpu
{
namespace
gpu
{
namespace
device
{
namespace
device
{
argument
softmax
(
hipStream_t
stream
,
argument
result
,
argument
arg
,
int
axis
)
void
softmax
(
hipStream_t
stream
,
const
argument
&
result
,
const
argument
&
arg
,
int
axis
)
{
{
auto
lens
=
result
.
get_shape
().
lens
();
auto
lens
=
result
.
get_shape
().
lens
();
auto
batch_lens
=
lens
;
auto
batch_lens
=
lens
;
size_t
n_dims
=
lens
[
axis
];
std
::
size_t
batch_item_num
=
lens
[
axis
];
batch_lens
[
axis
]
=
1
;
batch_lens
[
axis
]
=
1
;
shape
batch_shape
{
result
.
get_shape
().
type
(),
batch_lens
};
migraphx
::
shape
batch_shape
{
result
.
get_shape
().
type
(),
batch_lens
};
hip_visit_all
(
result
,
arg
,
batch_shape
)([
&
](
auto
output
,
auto
input
,
auto
batch
)
{
hip_visit_all
(
result
,
arg
,
batch_shape
)([
&
](
auto
output
,
auto
input
,
auto
batch
)
{
const
std
::
size_t
max_block_size
=
256
;
// each thread is for one item in the batch
const
std
::
size_t
block_size
=
compute_block_size
(
batch_item_num
,
max_block_size
);
gs_launch
(
stream
,
batch_shape
.
elements
())([
=
](
auto
i
)
{
gs_launch
(
stream
,
auto
batch_idx
=
batch
.
multi
(
i
);
batch_shape
.
elements
()
*
block_size
,
auto
data_idx
=
batch_idx
;
block_size
)([
=
](
auto
i
,
auto
idx
)
__device__
{
auto
data_idx
=
batch
.
multi
(
i
/
block_size
);
// get max
using
type
=
device_type
<
std
::
remove_cv_t
<
typename
decltype
(
input
)
::
value_type
>>
;
auto
batch_max
=
input
[
batch_idx
];
type
init
=
lowest
();
for
(
std
::
size_t
j
=
1
;
j
<
n_dims
;
++
j
)
{
auto
batch_max
=
block_reduce
<
max_block_size
>
(
data_idx
[
axis
]
=
j
;
idx
,
max
{},
init
,
batch_item_num
,
[
&
](
auto
j
)
__device__
{
batch_max
=
std
::
max
(
to_hip_type
(
batch_max
),
to_hip_type
(
input
[
data_idx
]));
data_idx
[
axis
]
=
j
;
}
return
input
[
data_idx
];
});
for
(
std
::
size_t
j
=
0
;
j
<
n_dims
;
++
j
)
{
auto
batch_sum
=
block_reduce
<
max_block_size
>
(
idx
,
sum
{},
0
,
batch_item_num
,
[
&
](
auto
j
)
__device__
{
data_idx
[
axis
]
=
j
;
auto
val
=
input
[
data_idx
]
-
batch_max
;
return
::
exp
(
to_hip_type
(
val
));
});
idx
.
local_stride
(
batch_item_num
,
[
&
](
auto
j
)
{
data_idx
[
axis
]
=
j
;
data_idx
[
axis
]
=
j
;
output
[
data_idx
]
=
exp
(
to_hip_type
(
input
[
data_idx
]
-
batch_max
));
auto
val
=
input
[
data_idx
]
-
batch_max
;
}
output
[
data_idx
]
=
::
exp
(
to_hip_type
(
val
))
/
batch_sum
;
});
auto
batch_sum
=
output
[
batch_idx
];
for
(
std
::
size_t
j
=
1
;
j
<
n_dims
;
++
j
)
{
data_idx
[
axis
]
=
j
;
batch_sum
+=
output
[
data_idx
];
}
for
(
std
::
size_t
j
=
0
;
j
<
n_dims
;
++
j
)
{
data_idx
[
axis
]
=
j
;
output
[
data_idx
]
=
output
[
data_idx
]
/
batch_sum
;
}
});
});
});
});
return
result
;
}
}
}
// namespace device
}
// namespace device
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment