Commit 6711780a authored by Artur Wojcik's avatar Artur Wojcik
Browse files

Merge branch 'develop' into uif2-initial

parents c0563b9e d1abf06f
......@@ -31,6 +31,14 @@
#include <migraphx/kernels/debug.hpp>
#include <migraphx/kernels/functional.hpp>
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wreserved-identifier"
extern "C" __device__ size_t __ockl_get_enqueued_local_size(uint); // NOLINT
extern "C" __device__ size_t __ockl_get_local_size(uint); // NOLINT
#pragma clang diagnostic pop
#endif
namespace migraphx {
#if defined(MIGRAPHX_NGLOBAL) && defined(MIGRAPHX_NLOCAL)
......@@ -45,43 +53,37 @@ inline __device__ __attribute__((const)) index_int compute_global_size()
// This actualy works even when global is not divisible by local size.
// This doesnt actually do a multiplicatiosn. Instead it calls a device
// function to get the global size, which is why it works.
return blockDim.x * gridDim.x; // NOLINT
return blockDim.x * gridDim.x; // NOLINT
#endif
}
// We cant just use blockDim.x to get the local size since its broken on hip
// when global is not divisible by local size. In this case, we calulate the
// size for the last group.
#ifdef MIGRAPHX_NGROUP
// If global is divisible by local then local can be a const
#if(MIGRAPHX_NGLOBAL % MIGRAPHX_NLOCAL == 0) || (MIGRAPHX_NGROUP == 1)
#define MIGRAPHX_HAS_CONST_LOCAL 1
#endif
#endif
inline __device__ __attribute__((const)) index_int compute_local_size()
{
#ifdef MIGRAPHX_NLOCAL
const auto nlocal = MIGRAPHX_NLOCAL;
#else
const auto nlocal = blockDim.x; // NOLINT
#endif
#ifdef MIGRAPHX_NGROUP
const auto ngroup = MIGRAPHX_NGROUP;
#ifdef MIGRAPHX_HAS_CONST_LOCAL
return MIGRAPHX_NLOCAL;
#else
const auto ngroup = gridDim.x; // NOLINT
// Returns block size. For the non-uniform block it returns the size of the non-uniform block.
return __ockl_get_local_size(0); // NOLINT
#endif
const auto group_id = blockIdx.x; // NOLINT
const auto nglobal = compute_global_size();
if(group_id == ngroup - 1)
{
return 1 + (nglobal - 1) % nlocal;
}
else
{
return nlocal; // NOLINT
}
}
#ifdef MIGRAPHX_NGROUP
// If global is divisible by local then local can be a const
#if(MIGRAPHX_NGLOBAL % MIGRAPHX_NLOCAL == 0) || (MIGRAPHX_NGROUP == 1)
#define MIGRAPHX_HAS_CONST_LOCAL 1
#endif
inline __device__ __attribute__((const)) index_int compute_max_local_size()
{
#ifdef MIGRAPHX_LOCAL
return MIGRAPHX_NLOCAL;
#else
// Returns the block size. When workgrop has non-uniform block, this returns size of the uniform
// block.
return __ockl_get_enqueued_local_size(0); // NOLINT
#endif
}
struct index
{
......@@ -126,8 +128,8 @@ struct index
#else
__device__ index_int max_nlocal() const
{
MIGRAPHX_ASSERT(blockDim.x > 0);
return blockDim.x;
MIGRAPHX_ASSERT(compute_max_local_size() > 0);
return compute_max_local_size();
}
#endif
......@@ -249,7 +251,8 @@ struct index
#endif
inline __device__ __attribute__((const)) index make_index()
{
return index{blockIdx.x * blockDim.x + threadIdx.x, threadIdx.x, blockIdx.x}; // NOLINT
return index{
blockIdx.x * compute_max_local_size() + threadIdx.x, threadIdx.x, blockIdx.x}; // NOLINT
}
} // namespace migraphx
......
......@@ -24,9 +24,8 @@
#include <migraphx/permutation.hpp>
#include <migraphx/gpu/prefuse_ops.hpp>
#if !defined(_MSC_VER)
#include <migraphx/gpu/gemm_softmax_gemm.hpp>
#include <migraphx/match/layernorm.hpp>
#include <migraphx/check_shapes.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
#endif
#include <migraphx/pass_manager.hpp>
......@@ -126,6 +125,60 @@ struct find_add_layernorm
m.replace_instruction(ins, add_layernorm{op.epsilon}, add_ins->inputs());
}
};
struct pre_gemm_softmax_gemm : gemm_softmax_gemm
{
std::string name() const { return "gpu::pre_gemm_softmax_gemm"; }
};
MIGRAPHX_REGISTER_OP(pre_gemm_softmax_gemm);
MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
{
if(ins->name() != "dot")
return false;
if(not pre_gemm_softmax_gemm::is_ck_supported_type(ins->get_shape().type()))
return false;
return true;
}
struct find_gemm_softmax_gemm
{
auto matcher() const
{
auto gemm1 =
match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1")));
auto mul = match::name("mul")(
match::nargs(2), match::either_arg(0, 1)(match::is_constant().bind("scale"), gemm1));
auto softmax = match::name("softmax")(match::arg(0)(mul)).bind("softmax");
return match::name("dot")(is_ck_gemm().bind("gemm2"))(match::arg(0)(softmax));
}
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto ins = r.result;
auto gemm2_ins = r.instructions["gemm2"];
auto gemm1_ins = r.instructions["gemm1"];
auto scale_lit = r.instructions["scale"];
float scale = 1.0;
scale_lit->eval().visit([&](const auto s) {
// CK only supports single-valued scale
if(std::all_of(
s.begin() + 1, s.end(), [&](auto v) { return float_equal(v, s.front()); }))
scale = s.front();
else
return;
});
auto inputs = gemm1_ins->inputs(); // A, B
inputs.push_back(gemm2_ins->inputs().back()); // B1
mpm.get_module().replace_instruction(
ins, pre_gemm_softmax_gemm{gemm2_ins->get_operator(), scale}, inputs);
}
};
} // namespace
#endif
......@@ -135,6 +188,10 @@ void prefuse_ops::apply(module_pass_manager& mpm) const
match::find_matches(mpm.get_module(), find_layernorm{});
mpm.run_pass(dead_code_elimination{});
match::find_matches(mpm.get_module(), find_add_layernorm{});
if(enabled(MIGRAPHX_ENABLE_CK{}))
match::find_matches(mpm, find_gemm_softmax_gemm{});
#else
(void)mpm;
#endif
}
......
......@@ -41,8 +41,7 @@ std::vector<argument> generate_arguments(const std::vector<shape>& shapes, unsig
}
using milliseconds = std::chrono::duration<double, std::milli>;
std::pair<double, double>
time_op(context& ictx, operation op, const std::vector<shape>& inputs, int n)
double time_op(context& ictx, operation op, const std::vector<shape>& inputs, int n)
{
// TODO: Use std::ref
......@@ -51,21 +50,19 @@ time_op(context& ictx, operation op, const std::vector<shape>& inputs, int n)
auto output = op.compute_shape(inputs);
op.finalize(ctx, output, inputs);
auto args = generate_arguments(inputs);
auto run = [&] {
op.compute(ctx, output, args);
ctx.finish();
};
gctx.enable_perf_measurement();
auto start = context::create_event_for_timing();
auto stop = context::create_event_for_timing();
auto run = [&] { op.compute(ctx, output, args); };
run();
double host_time = 0.0;
double device_time = 0.0;
gctx.get_stream().record(start.get());
for(auto i : range(n))
{
(void)i;
host_time += time<milliseconds>(run);
device_time += gctx.get_elapsed_ms();
run();
}
return std::make_pair(host_time / n, device_time / n);
gctx.get_stream().record(stop.get());
gctx.finish();
return context::get_elapsed_ms(start.get(), stop.get()) / n;
}
} // namespace gpu
......
......@@ -55,7 +55,7 @@ struct allocate
const migraphx::shape& output_shape,
const std::vector<migraphx::argument>&) const
{
return {output_shape};
return migraphx::argument{output_shape};
}
};
......
......@@ -60,7 +60,7 @@ struct concat
const migraphx::shape& output_shape,
const std::vector<migraphx::argument>&) const
{
return {output_shape};
return migraphx::argument{output_shape};
}
};
......@@ -104,7 +104,7 @@ struct allocate
const migraphx::shape& output_shape,
const std::vector<migraphx::argument>&) const
{
return {output_shape};
return migraphx::argument{output_shape};
}
};
......
......@@ -34,7 +34,8 @@
void run_pass(migraphx::program& p)
{
migraphx::run_passes(p, {migraphx::gpu::fuse_mlir{}, migraphx::dead_code_elimination{}});
migraphx::run_passes(
p, {migraphx::gpu::fuse_mlir{.enable_extra = true}, migraphx::dead_code_elimination{}});
}
template <class F>
......@@ -151,7 +152,6 @@ TEST_CASE(int_quant_dot_tanh_fails)
int main(int argc, const char* argv[])
{
if(migraphx::gpu::mlir_enabled())
test::run(argc, argv);
test::run(argc, argv);
return 0;
}
......@@ -55,7 +55,7 @@ struct allocate
const migraphx::shape& output_shape,
const std::vector<migraphx::argument>&) const
{
return {output_shape};
return migraphx::argument{output_shape};
}
};
......
......@@ -57,7 +57,7 @@ struct normalize_test_op
const migraphx::shape& output_shape,
const std::vector<migraphx::argument>&) const
{
return {output_shape};
return migraphx::argument{output_shape};
}
};
......
6d7bc2a097a1a08541cd0d4628831c79ab8092d5
635d3faa3b3908d2806d009dc6872152cfcfcdda
This diff is collapsed.
group_norm_3d_half_test:
M
x
scale
biasy"GroupNormalization*
epsilon'7*
num_groupsgroup_norm_3d_half_testZ
x




Z
scale


Z
bias


b
y




B
\ No newline at end of file
 group_norm_3d_test:
:
x
scale
biasy"GroupNormalization*
num_groupsgroup_norm_3d_testZ
x



Z
scale

Z
bias

b
y



B
\ No newline at end of file
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
 )group_norm_invalid_input_count_error_test:
4
x
scaley"GroupNormalization*
num_groups)group_norm_invalid_input_count_error_testZ
x




Z
scale

b
y




B
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment