Commit 6711780a authored by Artur Wojcik

Merge branch 'develop' into uif2-initial

parents c0563b9e d1abf06f
@@ -31,6 +31,14 @@
 #include <migraphx/kernels/debug.hpp>
 #include <migraphx/kernels/functional.hpp>
 
+#ifdef __clang__
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wreserved-identifier"
+extern "C" __device__ size_t __ockl_get_enqueued_local_size(uint); // NOLINT
+extern "C" __device__ size_t __ockl_get_local_size(uint); // NOLINT
+#pragma clang diagnostic pop
+#endif
+
 namespace migraphx {
 
 #if defined(MIGRAPHX_NGLOBAL) && defined(MIGRAPHX_NLOCAL)
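The pragma block above is needed because identifiers beginning with a double underscore are reserved, so clang's -Wreserved-identifier fires on these declarations even though the names belong to ROCm's OCKL device library and must be spelled exactly this way.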
@@ -49,39 +57,33 @@ inline __device__ __attribute__((const)) index_int compute_global_size()
 #endif
 }
 
-// We can't just use blockDim.x to get the local size since it's broken on hip
-// when global is not divisible by local size. In this case, we calculate the
-// size for the last group.
+#ifdef MIGRAPHX_NGROUP
+// If global is divisible by local then local can be a const
+#if(MIGRAPHX_NGLOBAL % MIGRAPHX_NLOCAL == 0) || (MIGRAPHX_NGROUP == 1)
+#define MIGRAPHX_HAS_CONST_LOCAL 1
+#endif
+#endif
+
 inline __device__ __attribute__((const)) index_int compute_local_size()
 {
-#ifdef MIGRAPHX_NLOCAL
-    const auto nlocal = MIGRAPHX_NLOCAL;
-#else
-    const auto nlocal = blockDim.x; // NOLINT
-#endif
-#ifdef MIGRAPHX_NGROUP
-    const auto ngroup = MIGRAPHX_NGROUP;
-#else
-    const auto ngroup = gridDim.x; // NOLINT
-#endif
-    const auto group_id = blockIdx.x; // NOLINT
-    const auto nglobal  = compute_global_size();
-    if(group_id == ngroup - 1)
-    {
-        return 1 + (nglobal - 1) % nlocal;
-    }
-    else
-    {
-        return nlocal; // NOLINT
-    }
+#ifdef MIGRAPHX_HAS_CONST_LOCAL
+    return MIGRAPHX_NLOCAL;
+#else
+    // Returns the block size. For a non-uniform block, this returns the size of that
+    // non-uniform block.
+    return __ockl_get_local_size(0); // NOLINT
+#endif
 }
 
-#ifdef MIGRAPHX_NGROUP
-// If global is divisible by local then local can be a const
-#if(MIGRAPHX_NGLOBAL % MIGRAPHX_NLOCAL == 0) || (MIGRAPHX_NGROUP == 1)
-#define MIGRAPHX_HAS_CONST_LOCAL 1
-#endif
-#endif
+inline __device__ __attribute__((const)) index_int compute_max_local_size()
+{
+#ifdef MIGRAPHX_NLOCAL
+    return MIGRAPHX_NLOCAL;
+#else
+    // Returns the block size. When the workgroup has a non-uniform block, this returns the
+    // size of the uniform block.
+    return __ockl_get_enqueued_local_size(0); // NOLINT
+#endif
+}
 
 struct index
 {
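The distinction the two new functions encode: when the global size is not divisible by the local size, HIP launches a trailing partial workgroup, and __ockl_get_local_size(0) reports the trimmed size inside that block, while __ockl_get_enqueued_local_size(0) always reports the launch size. A minimal host-side sketch of the arithmetic (the values nglobal = 1000 and nlocal = 256 are made up for illustration):

#include <cstdio>

int main()
{
    const unsigned nglobal = 1000; // total work-items
    const unsigned nlocal  = 256;  // requested workgroup size
    const unsigned ngroup  = (nglobal + nlocal - 1) / nlocal; // 4 groups

    for(unsigned group_id = 0; group_id < ngroup; group_id++)
    {
        // What __ockl_get_enqueued_local_size(0) reports: always the uniform launch size.
        unsigned max_local = nlocal;
        // What __ockl_get_local_size(0) reports: trimmed in the last, partial group.
        unsigned local = (group_id == ngroup - 1) ? 1 + (nglobal - 1) % nlocal : nlocal;
        std::printf("group %u: local=%u max_local=%u\n", group_id, local, max_local);
    }
    return 0;
}

Group 3 prints local=232 since 1000 = 3 * 256 + 232; the removed blockDim.x-based code computed exactly this by hand, which the OCKL builtins now do per block.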
@@ -126,8 +128,8 @@ struct index
 #else
     __device__ index_int max_nlocal() const
     {
-        MIGRAPHX_ASSERT(blockDim.x > 0);
-        return blockDim.x;
+        MIGRAPHX_ASSERT(compute_max_local_size() > 0);
+        return compute_max_local_size();
     }
 #endif
@@ -249,7 +251,8 @@ struct index
 #endif
 
 inline __device__ __attribute__((const)) index make_index()
 {
-    return index{blockIdx.x * blockDim.x + threadIdx.x, threadIdx.x, blockIdx.x}; // NOLINT
+    return index{
+        blockIdx.x * compute_max_local_size() + threadIdx.x, threadIdx.x, blockIdx.x}; // NOLINT
 }
 
 } // namespace migraphx
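Using compute_max_local_size() in make_index matters for the same partial-group case: with nlocal = 256 trimmed to 232 in the last of four groups, thread 0 of group 3 must get global id 3 * 256 = 768. Multiplying by the trimmed size would give 3 * 232 = 696 and overlap the ids of earlier groups, so the stride between groups has to be the uniform (enqueued) size.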
......
@@ -24,9 +24,8 @@
 #include <migraphx/permutation.hpp>
 #include <migraphx/gpu/prefuse_ops.hpp>
 #if !defined(_MSC_VER)
+#include <migraphx/gpu/gemm_softmax_gemm.hpp>
 #include <migraphx/match/layernorm.hpp>
-#include <migraphx/check_shapes.hpp>
-#include <migraphx/make_op.hpp>
 #include <migraphx/register_op.hpp>
 #endif
 #include <migraphx/pass_manager.hpp>
@@ -126,6 +125,60 @@ struct find_add_layernorm
         m.replace_instruction(ins, add_layernorm{op.epsilon}, add_ins->inputs());
     }
 };
 
+struct pre_gemm_softmax_gemm : gemm_softmax_gemm
+{
+    std::string name() const { return "gpu::pre_gemm_softmax_gemm"; }
+};
+MIGRAPHX_REGISTER_OP(pre_gemm_softmax_gemm);
+
+MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
+{
+    if(ins->name() != "dot")
+        return false;
+    if(not pre_gemm_softmax_gemm::is_ck_supported_type(ins->get_shape().type()))
+        return false;
+    return true;
+}
+
+struct find_gemm_softmax_gemm
+{
+    auto matcher() const
+    {
+        auto gemm1 =
+            match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1")));
+        auto mul = match::name("mul")(
+            match::nargs(2), match::either_arg(0, 1)(match::is_constant().bind("scale"), gemm1));
+        auto softmax = match::name("softmax")(match::arg(0)(mul)).bind("softmax");
+
+        return match::name("dot")(is_ck_gemm().bind("gemm2"))(match::arg(0)(softmax));
+    }
+
+    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
+    {
+        auto ins       = r.result;
+        auto gemm2_ins = r.instructions["gemm2"];
+        auto gemm1_ins = r.instructions["gemm1"];
+        auto scale_lit = r.instructions["scale"];
+
+        float scale = 1.0;
+        scale_lit->eval().visit([&](const auto s) {
+            // CK only supports single-valued scale
+            if(std::all_of(
+                   s.begin() + 1, s.end(), [&](auto v) { return float_equal(v, s.front()); }))
+                scale = s.front();
+            else
+                return;
+        });
+
+        auto inputs = gemm1_ins->inputs(); // A, B
+        inputs.push_back(gemm2_ins->inputs().back()); // B1
+
+        mpm.get_module().replace_instruction(
+            ins, pre_gemm_softmax_gemm{gemm2_ins->get_operator(), scale}, inputs);
+    }
+};
+
 } // namespace
 #endif
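For reference, the subgraph find_gemm_softmax_gemm targets is the scaled-attention chain dot -> mul(constant scale) -> softmax -> dot. A hypothetical construction of a program containing that chain (the shapes, the 0.125 scale, and the parameter names are invented for illustration; is_ck_gemm only accepts CK-supported dot types such as half):

#include <migraphx/program.hpp>
#include <migraphx/module.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/literal.hpp>
#include <migraphx/make_op.hpp>

int main()
{
    migraphx::program p;
    auto* mm = p.get_main_module();

    // Attention-like tensor shape, invented for illustration.
    migraphx::shape s{migraphx::shape::half_type, {1, 12, 256, 256}};
    auto a  = mm->add_parameter("a", s);
    auto b  = mm->add_parameter("b", s);
    auto b1 = mm->add_parameter("b1", s);

    // Single-valued constant scale, as the CK kernel requires.
    auto scale = mm->add_literal(
        migraphx::literal{migraphx::shape{migraphx::shape::half_type}, {0.125f}});
    auto scale_b = mm->add_instruction(
        migraphx::make_op("multibroadcast", {{"out_lens", s.lens()}}), scale);

    auto gemm1  = mm->add_instruction(migraphx::make_op("dot"), a, b);          // "gemm1"
    auto scaled = mm->add_instruction(migraphx::make_op("mul"), scale_b, gemm1);
    auto probs  = mm->add_instruction(migraphx::make_op("softmax", {{"axis", -1}}), scaled);
    mm->add_instruction(migraphx::make_op("dot"), probs, b1);                   // "gemm2"
    return 0;
}

When the pass fires (it is gated behind the MIGRAPHX_ENABLE_CK flag, per the apply() hunk below), the four matched instructions collapse into a single gpu::pre_gemm_softmax_gemm taking {A, B, B1}, with the scale folded into the op.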
@@ -135,6 +188,10 @@ void prefuse_ops::apply(module_pass_manager& mpm) const
     match::find_matches(mpm.get_module(), find_layernorm{});
     mpm.run_pass(dead_code_elimination{});
     match::find_matches(mpm.get_module(), find_add_layernorm{});
+    if(enabled(MIGRAPHX_ENABLE_CK{}))
+        match::find_matches(mpm, find_gemm_softmax_gemm{});
+#else
+    (void)mpm;
 #endif
 }
......
@@ -41,8 +41,7 @@ std::vector<argument> generate_arguments(const std::vector<shape>& shapes, unsigned
 }
 
 using milliseconds = std::chrono::duration<double, std::milli>;
-std::pair<double, double>
-time_op(context& ictx, operation op, const std::vector<shape>& inputs, int n)
+double time_op(context& ictx, operation op, const std::vector<shape>& inputs, int n)
 {
     // TODO: Use std::ref
@@ -51,21 +50,19 @@ time_op(context& ictx, operation op, const std::vector<shape>& inputs, int n)
     auto output = op.compute_shape(inputs);
     op.finalize(ctx, output, inputs);
     auto args = generate_arguments(inputs);
-    auto run = [&] {
-        op.compute(ctx, output, args);
-        ctx.finish();
-    };
-    gctx.enable_perf_measurement();
+    auto start = context::create_event_for_timing();
+    auto stop  = context::create_event_for_timing();
+    auto run   = [&] { op.compute(ctx, output, args); };
     run();
-    double host_time   = 0.0;
-    double device_time = 0.0;
+    gctx.get_stream().record(start.get());
     for(auto i : range(n))
     {
         (void)i;
-        host_time += time<milliseconds>(run);
-        device_time += gctx.get_elapsed_ms();
+        run();
     }
-    return std::make_pair(host_time / n, device_time / n);
+    gctx.get_stream().record(stop.get());
+    gctx.finish();
+    return context::get_elapsed_ms(start.get(), stop.get()) / n;
 }
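The new scheme brackets all n runs with one pair of device events instead of timing each run on the host around a ctx.finish() sync, so host-side jitter and per-iteration synchronization overhead drop out of the average. A minimal sketch of the same pattern in raw HIP (time_kernel_ms and its arguments are hypothetical; MIGraphX wraps this machinery in context::create_event_for_timing and get_elapsed_ms):

#include <hip/hip_runtime.h>

template <class F>
double time_kernel_ms(hipStream_t stream, F run, int n)
{
    hipEvent_t start;
    hipEvent_t stop;
    hipEventCreate(&start);
    hipEventCreate(&stop);
    run();                     // warm-up run, excluded from the measurement
    hipEventRecord(start, stream);
    for(int i = 0; i < n; i++)
        run();                 // enqueue only; no per-iteration host sync
    hipEventRecord(stop, stream);
    hipEventSynchronize(stop); // a single sync once everything is enqueued
    float ms = 0;
    hipEventElapsedTime(&ms, start, stop);
    hipEventDestroy(start);
    hipEventDestroy(stop);
    return ms / n;             // average device time per run
}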
} // namespace gpu } // namespace gpu
......
@@ -55,7 +55,7 @@ struct allocate
                       const migraphx::shape& output_shape,
                       const std::vector<migraphx::argument>&) const
     {
-        return {output_shape};
+        return migraphx::argument{output_shape};
     }
 };
......
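The repeated test change from return {output_shape}; to return migraphx::argument{output_shape}; (here and in the hunks below) suggests the argument-from-shape conversion can no longer be reached through copy-list-initialization, most likely because the constructor is explicit. A stand-alone sketch of the language rule, with stand-in types since the reason is an inference rather than something shown in the diff:

struct shape_like {};

struct argument_like
{
    // Explicit, like the migraphx::argument constructor is assumed to be.
    explicit argument_like(const shape_like&) {}
};

argument_like make(const shape_like& s)
{
    // return {s};           // ill-formed: copy-list-init cannot call an explicit constructor
    return argument_like{s}; // OK: direct-list-init names the type, so explicit is allowed
}

int main()
{
    shape_like s;
    make(s);
    return 0;
}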
@@ -60,7 +60,7 @@ struct concat
                       const migraphx::shape& output_shape,
                       const std::vector<migraphx::argument>&) const
     {
-        return {output_shape};
+        return migraphx::argument{output_shape};
     }
 };
@@ -104,7 +104,7 @@ struct allocate
                       const migraphx::shape& output_shape,
                       const std::vector<migraphx::argument>&) const
     {
-        return {output_shape};
+        return migraphx::argument{output_shape};
     }
 };
......
@@ -34,7 +34,8 @@
 void run_pass(migraphx::program& p)
 {
-    migraphx::run_passes(p, {migraphx::gpu::fuse_mlir{}, migraphx::dead_code_elimination{}});
+    migraphx::run_passes(
+        p, {migraphx::gpu::fuse_mlir{.enable_extra = true}, migraphx::dead_code_elimination{}});
 }
 
 template <class F>
@@ -151,7 +152,6 @@ TEST_CASE(int_quant_dot_tanh_fails)
 int main(int argc, const char* argv[])
 {
-    if(migraphx::gpu::mlir_enabled())
-        test::run(argc, argv);
+    test::run(argc, argv);
     return 0;
 }
@@ -55,7 +55,7 @@ struct allocate
                       const migraphx::shape& output_shape,
                       const std::vector<migraphx::argument>&) const
     {
-        return {output_shape};
+        return migraphx::argument{output_shape};
     }
 };
......
@@ -57,7 +57,7 @@ struct normalize_test_op
                       const migraphx::shape& output_shape,
                       const std::vector<migraphx::argument>&) const
     {
-        return {output_shape};
+        return migraphx::argument{output_shape};
     }
 };
......
-6d7bc2a097a1a08541cd0d4628831c79ab8092d5
+635d3faa3b3908d2806d009dc6872152cfcfcdda
[Binary ONNX protobuf test files added: group_norm_3d_half_test, group_norm_3d_test, group_norm_4d_half_test, and group_norm_4d_test. Each defines a single GroupNormalization node with inputs x, scale, and bias, output y, and a num_groups attribute; the half-precision variants additionally set an epsilon attribute.]