"llm/ggml.c" did not exist on "fd4792ec56965a9c8564c3d88212c29a0378583d"
Commit 84ba492c authored by Shucai Xiao's avatar Shucai Xiao
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into round_operator

parents 824d9c5f 3ab91a79
...@@ -33,6 +33,7 @@ add_library(migraphx_device ...@@ -33,6 +33,7 @@ add_library(migraphx_device
device/contiguous.cpp device/contiguous.cpp
device/logsoftmax.cpp device/logsoftmax.cpp
device/softmax.cpp device/softmax.cpp
device/sigmoid.cpp
device/convert.cpp device/convert.cpp
device/mul.cpp device/mul.cpp
device/concat.cpp device/concat.cpp
...@@ -78,7 +79,6 @@ add_library(migraphx_gpu ...@@ -78,7 +79,6 @@ add_library(migraphx_gpu
batchnorm.cpp batchnorm.cpp
write_literals.cpp write_literals.cpp
rocblas.cpp rocblas.cpp
sigmoid.cpp
abs.cpp abs.cpp
elu.cpp elu.cpp
pad.cpp pad.cpp
......
...@@ -245,8 +245,7 @@ void reduce_standard_impl(hipStream_t stream, ...@@ -245,8 +245,7 @@ void reduce_standard_impl(hipStream_t stream,
T init, T init,
Input read_input, Input read_input,
Output read_output, Output read_output,
std::size_t relements, std::size_t relements)
std::size_t stride)
{ {
hip_visit_all(result, arg)([&](auto output, auto input) { hip_visit_all(result, arg)([&](auto output, auto input) {
auto nelements = result.get_shape().elements(); auto nelements = result.get_shape().elements();
...@@ -255,7 +254,7 @@ void reduce_standard_impl(hipStream_t stream, ...@@ -255,7 +254,7 @@ void reduce_standard_impl(hipStream_t stream,
const std::size_t block_size = compute_block_size(relements, max_block_size); const std::size_t block_size = compute_block_size(relements, max_block_size);
gs_launch(stream, nelements * block_size, block_size)([=](auto i, auto idx) __device__ { gs_launch(stream, nelements * block_size, block_size)([=](auto i, auto idx) __device__ {
const auto out_idx = i / block_size; const auto out_idx = i / block_size;
const auto base_idx = out_idx * stride; const auto base_idx = out_idx * relements;
auto r = block_reduce<max_block_size>(idx, op, init, relements, [&](auto j) __device__ { auto r = block_reduce<max_block_size>(idx, op, init, relements, [&](auto j) __device__ {
return read_input(input.data()[base_idx + j]); return read_input(input.data()[base_idx + j]);
}); });
...@@ -276,25 +275,15 @@ void reduce(hipStream_t stream, ...@@ -276,25 +275,15 @@ void reduce(hipStream_t stream,
{ {
auto&& output_shape = result.get_shape(); auto&& output_shape = result.get_shape();
auto&& input_shape = arg.get_shape(); auto&& input_shape = arg.get_shape();
assert(output_shape.lens().size() == input_shape.lens().size());
if(input_shape.standard() and output_shape.standard() and if(input_shape.standard() and output_shape.standard() and
output_shape.lens().back() != input_shape.lens().back() and output_shape.lens().back() != input_shape.lens().back() and
std::equal(output_shape.lens().begin(), std::equal(output_shape.lens().begin(),
std::prev(output_shape.lens().end()), std::prev(output_shape.lens().end()),
input_shape.lens().begin())) input_shape.lens().begin()))
{ {
std::size_t stride = std::accumulate(input_shape.strides().begin(), reduce_standard_impl(
input_shape.strides().end(), stream, result, arg, op, init, read_input, read_output, input_shape.lens().back());
1,
std::multiplies<size_t>());
reduce_standard_impl(stream,
result,
arg,
op,
init,
read_input,
read_output,
input_shape.lens().back(),
stride);
} }
else else
{ {
......
#include <migraphx/gpu/device/sigmoid.hpp>
#include <migraphx/gpu/device/nary.hpp>
#include <migraphx/gpu/device/types.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
// Elementwise logistic sigmoid: for every element x of `arg`, writes
// 1/(1+exp(-x)) into the corresponding element of `result`. The kernel is
// enqueued on `stream` via the generic n-ary elementwise launcher.
void sigmoid(hipStream_t stream, const argument& result, const argument& arg)
{
    nary(stream, result, arg)([](auto x) {
        const auto neg_x = to_hip_type(-x);
        return 1.f / (1.f + ::exp(neg_x));
    });
}
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#ifndef MIGRAPHX_GUARD_RTGLIB_DEVICE_SIGMOID_HPP
#define MIGRAPHX_GUARD_RTGLIB_DEVICE_SIGMOID_HPP
#include <migraphx/argument.hpp>
#include <migraphx/config.hpp>
#include <hip/hip_runtime_api.h>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
namespace device {
// Elementwise sigmoid of `arg` written into `result`, launched on `stream`.
// `result` must be preallocated with a shape compatible with `arg`
// (presumably identical — confirm against the device implementation's launcher).
void sigmoid(hipStream_t stream, const argument& result, const argument& arg);
} // namespace device
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif
#ifndef MIGRAPHX_GUARD_RTGLIB_SIGMOID_HPP #ifndef MIGRAPHX_GUARD_RTGLIB_SIGMOID_HPP
#define MIGRAPHX_GUARD_RTGLIB_SIGMOID_HPP #define MIGRAPHX_GUARD_RTGLIB_SIGMOID_HPP
#include <migraphx/shape.hpp> #include <migraphx/gpu/oper.hpp>
#include <migraphx/gpu/miopen.hpp> #include <migraphx/gpu/device/sigmoid.hpp>
namespace migraphx { namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS { inline namespace MIGRAPHX_INLINE_NS {
namespace gpu { namespace gpu {
struct context; struct hip_sigmoid : unary_device<hip_sigmoid, device::sigmoid>
struct miopen_sigmoid
{ {
shared<activation_descriptor> ad;
template <class Self, class F>
static auto reflect(Self& self, F f)
{
return gpu::reflect(self.ad.get(), f);
}
std::string name() const { return "gpu::sigmoid"; }
shape compute_shape(const std::vector<shape>& inputs) const;
argument
compute(context& ctx, const shape& output_shape, const std::vector<argument>& args) const;
std::ptrdiff_t output_alias(const std::vector<shape>& shapes) const
{
return shapes.size() - 1;
}
}; };
} // namespace gpu } // namespace gpu
......
...@@ -87,7 +87,6 @@ struct miopen_apply ...@@ -87,7 +87,6 @@ struct miopen_apply
void init() void init()
{ {
this->last = instruction::get_output_alias(std::prev(prog->end())); this->last = instruction::get_output_alias(std::prev(prog->end()));
add_miopen_simple_op<miopen_sigmoid>("sigmoid", make_sigmoid);
add_miopen_simple_op<miopen_abs>("abs", make_abs); add_miopen_simple_op<miopen_abs>("abs", make_abs);
add_miopen_extend_op<miopen_leaky_relu, op::leaky_relu>("leaky_relu", make_leaky_relu); add_miopen_extend_op<miopen_leaky_relu, op::leaky_relu>("leaky_relu", make_leaky_relu);
...@@ -118,6 +117,7 @@ struct miopen_apply ...@@ -118,6 +117,7 @@ struct miopen_apply
add_generic_op<hip_sqdiff>("sqdiff"); add_generic_op<hip_sqdiff>("sqdiff");
add_generic_op<hip_relu>("relu"); add_generic_op<hip_relu>("relu");
add_generic_op<hip_sign>("sign"); add_generic_op<hip_sign>("sign");
add_generic_op<hip_sigmoid>("sigmoid");
add_extend_op<miopen_gemm, op::dot>("dot"); add_extend_op<miopen_gemm, op::dot>("dot");
add_extend_op<rocblas_quant_gemm, op::quant_dot>("quant_dot"); add_extend_op<rocblas_quant_gemm, op::quant_dot>("quant_dot");
......
#include <migraphx/gpu/sigmoid.hpp>
#include <migraphx/gpu/context.hpp>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {
// Validates the op's inputs and reports the output shape.
// Expects exactly two arguments (input tensor, preallocated output buffer),
// neither broadcast; the result shape is that of the output buffer.
shape miopen_sigmoid::compute_shape(const std::vector<shape>& inputs) const
{
    check_shapes{inputs, *this}.has(2).not_broadcasted();
    return inputs.back();
}
// Runs MIOpen's activation forward pass (configured as sigmoid via the
// activation descriptor `ad`) on args[0], writing into args[1].
// Returns args[1]: the op aliases its output to its last argument.
argument miopen_sigmoid::compute(context& ctx,
                                 const shape& output_shape,
                                 const std::vector<argument>& args) const
{
    // MIOpen blends as y = alpha*activation(x) + beta*y; alpha=1, beta=0
    // gives the plain activation result.
    float alpha = 1;
    float beta = 0;
    auto x_desc = make_tensor(args[0].get_shape());
    auto y_desc = make_tensor(output_shape);
    // NOTE(review): the miopenStatus_t return value is discarded — a failed
    // launch is silent here; consider checking it upstream.
    miopenActivationForward(ctx.get_stream().get_miopen(),
                            ad.get(),
                            &alpha,
                            x_desc.get(),
                            args[0].implicit(),
                            &beta,
                            y_desc.get(),
                            args[1].implicit());
    return args[1];
}
} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
...@@ -3792,6 +3792,18 @@ struct test_reduce_mean : verify_program<test_reduce_mean> ...@@ -3792,6 +3792,18 @@ struct test_reduce_mean : verify_program<test_reduce_mean>
}; };
}; };
// Verifies reduce_mean over the innermost axis of a {1, 128, 768}
// float tensor against the reference implementation.
struct test_reduce_mean2 : verify_program<test_reduce_mean2>
{
    migraphx::program create_program() const
    {
        migraphx::program p;
        auto input = p.add_parameter(
            "x", migraphx::shape{migraphx::shape::float_type, {1, 128, 768}});
        p.add_instruction(migraphx::op::reduce_mean{{2}}, input);
        return p;
    }
};
struct test_reduce_mean_int : verify_program<test_reduce_mean_int> struct test_reduce_mean_int : verify_program<test_reduce_mean_int>
{ {
migraphx::program create_program() const migraphx::program create_program() const
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment