Commit a16adb42 authored by turneram's avatar turneram
Browse files

Fix layernorm verify test

parent d41f0d66
...@@ -40,8 +40,9 @@ struct transposectx_compiler : compiler<transposectx_compiler> ...@@ -40,8 +40,9 @@ struct transposectx_compiler : compiler<transposectx_compiler>
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{ {
hip_compile_options options; hip_compile_options options;
auto h = inputs.front().lens().back();
options.set_launch_params( options.set_launch_params(
v, compute_global_for(ctx, inputs.back().elements()), inputs.front().lens().back()); v, compute_global_for(ctx, inputs.back().elements(), h), h);
options.output = inputs.back(); options.output = inputs.back();
options.inputs = inputs; options.inputs = inputs;
options.kernel_name = "transposectx_kernel"; options.kernel_name = "transposectx_kernel";
...@@ -79,8 +80,9 @@ struct transposeqkv_compiler : compiler<transposeqkv_compiler> ...@@ -79,8 +80,9 @@ struct transposeqkv_compiler : compiler<transposeqkv_compiler>
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
{ {
hip_compile_options options; hip_compile_options options;
auto h = inputs.front().lens().back();
options.set_launch_params( options.set_launch_params(
v, compute_global_for(ctx, inputs.back().elements()), inputs.front().lens().back()); v, compute_global_for(ctx, inputs.back().elements(), h), h);
options.output = inputs.back(); options.output = inputs.back();
options.inputs = inputs; options.inputs = inputs;
options.kernel_name = "transposeqkv_kernel"; options.kernel_name = "transposeqkv_kernel";
......
...@@ -11,7 +11,7 @@ struct test_layernorm : verify_program<test_layernorm> ...@@ -11,7 +11,7 @@ struct test_layernorm : verify_program<test_layernorm>
auto* mm = p.get_main_module(); auto* mm = p.get_main_module();
auto x = auto x =
mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 384, 768}}); mm->add_parameter("x", migraphx::shape{migraphx::shape::float_type, {1, 384, 768}});
mm->add_instruction(migraphx::make_op("layernorm", {{"axis", -1}}), x); mm->add_instruction(migraphx::make_op("layernorm", {{"axis", -1}, {"epsilon", 1e-12}}), x);
return p; return p;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment