Commit 65386ce5 authored by Shucai Xiao's avatar Shucai Xiao
Browse files

backup latest changes for the layernorm_half2 branch

parent ec205c54
...@@ -37,7 +37,7 @@ struct squeeze ...@@ -37,7 +37,7 @@ struct squeeze
std::string name() const { return "squeeze"; } std::string name() const { return "squeeze"; }
shape normalize_compute_shape(std::vector<shape> inputs) const shape normalize_compute_shape(std::vector<shape> inputs) const
{ {
check_shapes{inputs, *this}.has(1); check_shapes{inputs, *this}.has(1).standard();
auto input_shape = inputs[0]; auto input_shape = inputs[0];
auto type = input_shape.type(); auto type = input_shape.type();
auto old_lens = input_shape.lens(); auto old_lens = input_shape.lens();
......
...@@ -24,7 +24,7 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_NARY); ...@@ -24,7 +24,7 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_NARY);
if(enabled(MIGRAPHX_TRACE_NARY{})) \ if(enabled(MIGRAPHX_TRACE_NARY{})) \
std::cout << "nary device function: " << __PRETTY_FUNCTION__ << std::endl; std::cout << "nary device function: " << __PRETTY_FUNCTION__ << std::endl;
static index_int group_num_global = (1 << 20); static index_int group_num_global = (1 << 8);
template <class... Ts> template <class... Ts>
constexpr auto pack(Ts... xs) constexpr auto pack(Ts... xs)
......
...@@ -147,7 +147,7 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in ...@@ -147,7 +147,7 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
migraphx::shape batch_shape{result.get_shape().type(), batch_lens}; migraphx::shape batch_shape{result.get_shape().type(), batch_lens};
hip_visit_all(result, arg, batch_shape)([&](auto output, auto input, auto batch) { hip_visit_all(result, arg, batch_shape)([&](auto output, auto input, auto batch) {
const index_int max_block_size = 1024; const index_int max_block_size = 128;
const index_int block_size = compute_block_size(batch_item_num, max_block_size); const index_int block_size = compute_block_size(batch_item_num, max_block_size);
using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>; using type = device_type<std::remove_cv_t<typename decltype(input)::value_type>>;
type init = lowest(); type init = lowest();
...@@ -157,9 +157,10 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in ...@@ -157,9 +157,10 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
auto in_type = result.get_shape().type(); auto in_type = result.get_shape().type();
if(in_type == shape::half_type and batch_item_num <= 1024) if(in_type == shape::half_type and batch_item_num <= 1024)
{ {
auto half2_block_size = compute_block_size(batch_item_num, 1024);
int block_num = batch_shape.elements(); int block_num = batch_shape.elements();
int shared_size = batch_item_num * 2 * result.get_shape().type_size(); int shared_size = batch_item_num * 2 * result.get_shape().type_size();
auto half2_block_size = block_size / 4; half2_block_size = half2_block_size / 4;
softmax_kernel<<<block_num, half2_block_size, shared_size, stream>>>( softmax_kernel<<<block_num, half2_block_size, shared_size, stream>>>(
arg.data(), batch_item_num, half2_block_size, result.data()); arg.data(), batch_item_num, half2_block_size, result.data());
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment