"...resnet50_tensorflow.git" did not exist on "eb8adfe68c95b5e6860eb84ec7bb6555317d2200"
Commit 2bca512e authored by charlie's avatar charlie
Browse files

progress

parent 6c41008a
...@@ -38,10 +38,10 @@ struct gelu_erf_matcher ...@@ -38,10 +38,10 @@ struct gelu_erf_matcher
F f; F f;
auto erf_fn() const auto erf_fn() const
{ {
auto mul_1_sqrt_2 = f("mul")(either_arg(0, 1)(none_of(has_value(M_SQRT1_2, 1e-3)).bind("x"), auto mul_1_sqrt_2 = f("mul")(
has_value(M_SQRT1_2, 1e-3))); either_arg(0, 1)(none_of(has_value(M_SQRT1_2)).bind("x"), has_value(M_SQRT1_2, 1e-3)));
auto div_sqrt_2 = auto div_sqrt_2 =
f("div")(args(none_of(has_value(M_SQRT2, 1e-3)).bind("x"), has_value(M_SQRT2, 1e-3))); f("div")(args(none_of(has_value(M_SQRT2, 1e-3)).bind("x"), has_value(M_SQRT2)));
return f("erf")(used_once(), arg(0)(used_once(), any_of(mul_1_sqrt_2, div_sqrt_2))); return f("erf")(used_once(), arg(0)(used_once(), any_of(mul_1_sqrt_2, div_sqrt_2)));
} }
......
...@@ -36,23 +36,34 @@ template <class F> ...@@ -36,23 +36,34 @@ template <class F>
struct gelu_tanh_matcher struct gelu_tanh_matcher
{ {
F f; F f;
/// x ^ 3
auto pow_fn() const { return f("pow")(used_once(), arg(1)(has_value(3.0f))); } auto pow_fn() const { return f("pow")(used_once(), arg(1)(has_value(3.0f))); }
/// tanh( sqrt(2/M_PI) * (x + 0.044715 * x ^ 3 )
auto tanh_fn() const auto tanh_fn() const
{ {
return f("tanh")( auto mul_const_pow = f("mul")(either_arg(0, 1)(has_value(0.044715f), pow_fn()));
used_once(), auto add_any_mul = f("add")(any_arg(0, 1)(mul_const_pow));
arg(0)(f("mul")(either_arg(0, 1)(has_value(sqrt(M_2_PI), 1e-3), auto either_SQRT2RPI_add = either_arg(0, 1)(has_value(sqrt(M_2_PI)), add_any_mul);
f("add")(any_arg(0, 1)(f("mul")(either_arg(0, 1)( return f("tanh")(used_once(), arg(0)(f("mul")(either_SQRT2RPI_add)));
has_value(0.044715f), pow_fn())))))))); }
/// x * (0.5? + 0.5 * tanh( sqrt(2/M_PI) * (x? + 0.044715 * x? ^ 3) ) )
/// <item>? question mark means it doesn't explicitly match that item (anything will work)
auto matcher_v0() const
{
auto mul_half_tanh = f("mul")(either_arg(0, 1)(has_value(0.5f), tanh_fn()));
auto add_any_mul = f("add")(any_arg(0, 1)(mul_half_tanh));
return f("mul")(either_arg(0, 1)(any().bind("x"), add_any_mul));
} }
auto matcher() const /// x * 0.5 * (1.0 + tanh( sqrt(2/M_PI) * (x + 0.044715 * x ^ 3) ) )
auto matcher_v1() const
{ {
return f("mul")(used_once(), auto add_one_tanh = f("add")(used_once(), either_arg(0, 1)(has_value(1.0), tanh_fn()));
either_arg(0, 1)(any().bind("x"), auto mul_half_x = f("mul")(used_once(), either_arg(0, 1)(has_value(0.5), any().bind("x")));
f("add")(any_arg(0, 1)(f("mul")( return f("mul")(either_arg(0, 1)(mul_half_x, add_one_tanh));
either_arg(0, 1)(has_value(0.5f), tanh_fn()))))));
} }
}; };
} // namespace detail } // namespace detail
...@@ -60,7 +71,7 @@ struct gelu_tanh_matcher ...@@ -60,7 +71,7 @@ struct gelu_tanh_matcher
template <class F> template <class F>
auto gelu_tanh(F f) auto gelu_tanh(F f)
{ {
return detail::gelu_tanh_matcher<F>{f}.matcher(); return detail::gelu_tanh_matcher<F>{f}.matcher_v1();
} }
inline auto gelu_tanh() inline auto gelu_tanh()
......
...@@ -864,7 +864,7 @@ auto skip_broadcasts_transposes_contiguous(Ms... ms) ...@@ -864,7 +864,7 @@ auto skip_broadcasts_transposes_contiguous(Ms... ms)
} }
template <class T> template <class T>
inline auto has_value(T x, float tolerance = 1e-6) inline auto has_value(T x, std::size_t atol_mult = 10, std::size_t rtol_mult = 10)
{ {
return skip_broadcasts_converts(make_basic_pred_matcher([=](instruction_ref ins) { return skip_broadcasts_converts(make_basic_pred_matcher([=](instruction_ref ins) {
if(ins->name() != "@literal") if(ins->name() != "@literal")
...@@ -874,8 +874,13 @@ inline auto has_value(T x, float tolerance = 1e-6) ...@@ -874,8 +874,13 @@ inline auto has_value(T x, float tolerance = 1e-6)
return false; return false;
bool b = false; bool b = false;
l.visit([&](auto v) { l.visit([&](auto v) {
if(std::all_of( // cast to the literal's data type before comparing
v.begin(), v.end(), [&](auto val) { return std::fabs(val - x) < tolerance; })) using type = typename decltype(v)::value_type;
auto eps = std::numeric_limits<type>::epsilon();
if(std::all_of(v.begin(), v.end(), [&](auto val) {
return std::fabs(val - static_cast<type>(x)) <
(atol_mult * eps + rtol_mult * eps * std::fabs(val));
}))
b = true; b = true;
}); });
return b; return b;
......
...@@ -70,6 +70,7 @@ struct find_tanh_fast_gelu ...@@ -70,6 +70,7 @@ struct find_tanh_fast_gelu
void apply(module& m, const match::matcher_result& r) const void apply(module& m, const match::matcher_result& r) const
{ {
/*
auto ins = r.result; auto ins = r.result;
auto x = r.instructions["x"]; auto x = r.instructions["x"];
auto sqrt_2_rpi = m.add_literal( auto sqrt_2_rpi = m.add_literal(
...@@ -89,6 +90,18 @@ struct find_tanh_fast_gelu ...@@ -89,6 +90,18 @@ struct find_tanh_fast_gelu
auto cdf = insert_common_op(m, ins, make_op("div"), {one, e}); auto cdf = insert_common_op(m, ins, make_op("div"), {one, e});
auto y = m.insert_instruction(ins, make_op("mul"), x, cdf); auto y = m.insert_instruction(ins, make_op("mul"), x, cdf);
m.replace_instruction(ins, y); m.replace_instruction(ins, y);
*/
auto ins = r.result;
auto x = r.instructions["x"];
auto sqrt1_2 = m.add_literal(literal{shape{x->get_shape().type()}, {M_SQRT1_2}});
auto one = m.add_literal(literal{shape{x->get_shape().type()}, {1.0f}});
auto one_half = m.add_literal(literal{shape{x->get_shape().type()}, {0.5f}});
auto a = insert_common_op(m, ins, make_op("mul"), {x, sqrt1_2});
auto erf = m.insert_instruction(ins, make_op("erf"), a);
auto add_erf = insert_common_op(m, ins, make_op("add"), {one, erf});
auto b = insert_common_op(m, ins, make_op("mul"), {one_half, add_erf});
auto y = m.insert_instruction(ins, make_op("mul"), x, b);
m.replace_instruction(ins, y);
} }
}; };
......
...@@ -1217,7 +1217,7 @@ struct find_unit_ops ...@@ -1217,7 +1217,7 @@ struct find_unit_ops
auto div_1 = auto div_1 =
match::name("div")(match::args(match::any().bind("x"), match::has_value(1.0f))); match::name("div")(match::args(match::any().bind("x"), match::has_value(1.0f)));
auto add_0 = match::name("add")( auto add_0 = match::name("add")(
match::either_arg(0, 1)(match::has_value(0.0f, 1e-12), match::any().bind("x"))); match::either_arg(0, 1)(match::has_value(0.0f, 0, 0), match::any().bind("x")));
auto sub_0 = auto sub_0 =
match::name("sub")(match::args(match::any().bind("x"), match::has_value(0.0f))); match::name("sub")(match::args(match::any().bind("x"), match::has_value(0.0f)));
return match::any_of(mul_1, div_1, add_0, sub_0); return match::any_of(mul_1, div_1, add_0, sub_0);
......
...@@ -83,7 +83,8 @@ struct miopen_apply ...@@ -83,7 +83,8 @@ struct miopen_apply
assert(mod != nullptr); assert(mod != nullptr);
assert(pass != nullptr); assert(pass != nullptr);
compute_fp32 = get_compute_fp32_flag(); // compute_fp32 = get_compute_fp32_flag();
compute_fp32 = true;
offload_copy = (mod == mpm->get_root_module()) ? pass->offload_copy : false; offload_copy = (mod == mpm->get_root_module()) ? pass->offload_copy : false;
add_generic_op("contiguous"); add_generic_op("contiguous");
......
...@@ -122,4 +122,49 @@ TEST_CASE(non_bias_gelu) ...@@ -122,4 +122,49 @@ TEST_CASE(non_bias_gelu)
EXPECT(m1 == m2); EXPECT(m1 == m2);
} }
TEST_CASE(tanh_gelu_distilgpt2_fp16)
{
// Uses constant values seen in the distilgpt2_fp16 model, note how they're not exactly right
migraphx::shape s1{migraphx::shape::half_type, {2, 4, 8}};
migraphx::shape s2{migraphx::shape::half_type};
migraphx::module m1;
{
auto x = m1.add_parameter("x", s1);
auto fit_const = m1.add_literal(migraphx::literal{s2, {0.044708251953125}});
auto sqrt_2_rpi = m1.add_literal(migraphx::literal{s2, {0.7978515625}});
auto one = m1.add_literal(migraphx::literal{s2, {1.0f}});
auto one_half = m1.add_literal(migraphx::literal{s2, {0.5f}});
auto three = m1.add_literal(migraphx::literal{s2, {3.0f}});
auto pow0 = add_common_op(m1, migraphx::make_op("pow"), {x, three});
auto mul0 = add_common_op(m1, migraphx::make_op("mul"), {pow0, fit_const});
auto add0 = m1.add_instruction(migraphx::make_op("add"), {mul0, x});
auto mul1 = add_common_op(m1, migraphx::make_op("mul"), {add0, sqrt_2_rpi});
auto tanh0 = m1.add_instruction(migraphx::make_op("tanh"), mul1);
auto add1 = add_common_op(m1, migraphx::make_op("add"), {tanh0, one});
auto mul2 = add_common_op(m1, migraphx::make_op("mul"), {x, one_half});
auto y = m1.add_instruction(migraphx::make_op("mul"), {add1, mul2});
m1.add_return({y});
}
migraphx::rewrite_gelu pass;
pass.apply(m1);
migraphx::dead_code_elimination dce;
dce.apply(m1);
migraphx::module m2;
{
auto x = m2.add_parameter("x", s1);
auto sqrt1_2 = m2.add_literal(migraphx::literal{s2, {M_SQRT1_2}});
auto one = m2.add_literal(migraphx::literal{s2, {1.0f}});
auto one_half = m2.add_literal(migraphx::literal{s2, {0.5f}});
auto a = add_common_op(m2, migraphx::make_op("mul"), {x, sqrt1_2});
auto erf = m2.add_instruction(migraphx::make_op("erf"), a);
auto add_erf = add_common_op(m2, migraphx::make_op("add"), {one, erf});
auto b = add_common_op(m2, migraphx::make_op("mul"), {one_half, add_erf});
auto y = m2.add_instruction(migraphx::make_op("mul"), x, b);
m2.add_return({y});
}
EXPECT(m1 == m2);
}
int main(int argc, const char* argv[]) { test::run(argc, argv); } int main(int argc, const char* argv[]) { test::run(argc, argv); }
...@@ -136,9 +136,13 @@ def check_correctness(gold_outputs, ...@@ -136,9 +136,13 @@ def check_correctness(gold_outputs,
if verbose: if verbose:
with np.printoptions(threshold=np.inf): with np.printoptions(threshold=np.inf):
print('\nOutput {} is incorrect ...'.format(i)) print('\nOutput {} is incorrect ...'.format(i))
print('Expected value: \n{}\n'.format(gold_outputs[i])) #print('Expected value: \n{}'.format(gold_outputs[i]))
print('\n......\n') #print('\n......\n')
print('Actual value: \n{}\n'.format(outputs[i])) #print('Actual value: \n{}\n'.format(outputs[i]))
diff = gold_outputs[i] - outputs[i]
#print(f'Difference: {diff}')
max_diff = np.max(np.abs(diff))
print(f'Max Difference: {max_diff}')
else: else:
print('Outputs do not match') print('Outputs do not match')
break break
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment