"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "274c772b385003c4bef3bbe90b29de3dd2287dd7"
Commit 2bca512e authored by charlie
Browse files

progress

parent 6c41008a
......@@ -38,10 +38,10 @@ struct gelu_erf_matcher
F f;
/// Builds the matcher for the erf(...) factor: an "erf" instruction (with used_once()) whose
/// argument 0 (also with used_once()) is either a "mul" involving M_SQRT1_2 or a "div" by
/// M_SQRT2. The operand that is not the constant is bound as "x".
// NOTE(review): these lines come from a rendered diff, so the previous and the revised form of
// each statement appear back to back (mul_1_sqrt_2 is declared twice, div_sqrt_2 has two
// initializer lines). Only one form of each can exist in the real header.
// NOTE(review): has_value is shown elsewhere in this listing with the signature
// (T x, std::size_t atol_mult = 10, std::size_t rtol_mult = 10). Under that signature the
// remaining `1e-3` arguments no longer act as a tolerance; converted to std::size_t they
// become an atol_mult of 0. Confirm and drop them so all four has_value calls agree.
auto erf_fn() const
{
auto mul_1_sqrt_2 = f("mul")(either_arg(0, 1)(none_of(has_value(M_SQRT1_2, 1e-3)).bind("x"),
has_value(M_SQRT1_2, 1e-3)));
auto mul_1_sqrt_2 = f("mul")(
either_arg(0, 1)(none_of(has_value(M_SQRT1_2)).bind("x"), has_value(M_SQRT1_2, 1e-3)));
auto div_sqrt_2 =
f("div")(args(none_of(has_value(M_SQRT2, 1e-3)).bind("x"), has_value(M_SQRT2, 1e-3)));
f("div")(args(none_of(has_value(M_SQRT2, 1e-3)).bind("x"), has_value(M_SQRT2)));
// Accept either spelling of x / sqrt(2) as the erf argument.
return f("erf")(used_once(), arg(0)(used_once(), any_of(mul_1_sqrt_2, div_sqrt_2)));
}
......
......@@ -36,23 +36,34 @@ template <class F>
/// Collection of matchers for the tanh formulation of GELU.
/// `f` is a callable that is invoked with an operator name (f("mul"), f("tanh"), ...) and whose
/// result is then called with sub-matchers.
// NOTE(review): this span is a rendered diff; inside tanh_fn and around matcher()/matcher_v1()
// the previous and the revised lines appear back to back, so the text is not valid C++ as
// shown. Only one form of each can exist in the real header.
struct gelu_tanh_matcher
{
F f;
/// x ^ 3
/// Only argument 1 (the exponent, 3.0f) is constrained; argument 0 is left unconstrained.
auto pow_fn() const { return f("pow")(used_once(), arg(1)(has_value(3.0f))); }
/// tanh( sqrt(2/M_PI) * (x + 0.044715 * x ^ 3) )
auto tanh_fn() const
{
// Previous form: a single nested expression, with an explicit 1e-3 on the sqrt(M_2_PI) constant.
return f("tanh")(
used_once(),
arg(0)(f("mul")(either_arg(0, 1)(has_value(sqrt(M_2_PI), 1e-3),
f("add")(any_arg(0, 1)(f("mul")(either_arg(0, 1)(
has_value(0.044715f), pow_fn()))))))));
// Revised form: the same nesting built from named pieces, inside out; the constants now
// use has_value's default arguments.
auto mul_const_pow = f("mul")(either_arg(0, 1)(has_value(0.044715f), pow_fn()));
// any_arg: only one operand of the add has to be the 0.044715 * x^3 product.
auto add_any_mul = f("add")(any_arg(0, 1)(mul_const_pow));
auto either_SQRT2RPI_add = either_arg(0, 1)(has_value(sqrt(M_2_PI)), add_any_mul);
return f("tanh")(used_once(), arg(0)(f("mul")(either_SQRT2RPI_add)));
}
/// x * (0.5? + 0.5 * tanh( sqrt(2/M_PI) * (x? + 0.044715 * x? ^ 3) ) )
/// <item>? question mark means it doesn't explicitly match that item (anything will work)
// NOTE(review): unlike the previous matcher() below, the outer "mul" here carries no
// used_once(). The gelu_tanh(F) wrapper shown in this listing calls matcher_v1, not this.
auto matcher_v0() const
{
auto mul_half_tanh = f("mul")(either_arg(0, 1)(has_value(0.5f), tanh_fn()));
auto add_any_mul = f("add")(any_arg(0, 1)(mul_half_tanh));
return f("mul")(either_arg(0, 1)(any().bind("x"), add_any_mul));
}
auto matcher() const
/// x * 0.5 * (1.0 + tanh( sqrt(2/M_PI) * (x + 0.044715 * x ^ 3) ) )
// NOTE(review): "x" is bound only in the 0.5 * x factor. Nothing in the visible matchers ties
// the operands inside tanh_fn (the add's other operand, pow's argument 0) to that instruction.
auto matcher_v1() const
{
// Previous matcher() body.
return f("mul")(used_once(),
either_arg(0, 1)(any().bind("x"),
f("add")(any_arg(0, 1)(f("mul")(
either_arg(0, 1)(has_value(0.5f), tanh_fn()))))));
// Revised body: (1.0 + tanh(...)) and (0.5 * x), multiplied in either operand order.
auto add_one_tanh = f("add")(used_once(), either_arg(0, 1)(has_value(1.0), tanh_fn()));
auto mul_half_x = f("mul")(used_once(), either_arg(0, 1)(has_value(0.5), any().bind("x")));
return f("mul")(either_arg(0, 1)(mul_half_x, add_one_tanh));
}
};
} // namespace detail
......@@ -60,7 +71,7 @@ struct gelu_tanh_matcher
/// Returns the tanh-GELU matcher built by detail::gelu_tanh_matcher<F> from `f`.
// NOTE(review): rendered diff -- the first return is the previous line and the second is its
// replacement. As written here the matcher_v1() line would be unreachable; only one of the two
// can exist in the real header.
template <class F>
auto gelu_tanh(F f)
{
return detail::gelu_tanh_matcher<F>{f}.matcher();
return detail::gelu_tanh_matcher<F>{f}.matcher_v1();
}
inline auto gelu_tanh()
......
......@@ -864,7 +864,7 @@ auto skip_broadcasts_transposes_contiguous(Ms... ms)
}
template <class T>
inline auto has_value(T x, float tolerance = 1e-6)
inline auto has_value(T x, std::size_t atol_mult = 10, std::size_t rtol_mult = 10)
{
return skip_broadcasts_converts(make_basic_pred_matcher([=](instruction_ref ins) {
if(ins->name() != "@literal")
......@@ -874,8 +874,13 @@ inline auto has_value(T x, float tolerance = 1e-6)
return false;
bool b = false;
// Sets b when every element of the visited literal view is within tolerance of x.
// NOTE(review): rendered diff -- the first if(std::all_of(...)) (fixed `tolerance`) is the
// previous form, the second one (epsilon-scaled) is its replacement.
l.visit([&](auto v) {
if(std::all_of(
v.begin(), v.end(), [&](auto val) { return std::fabs(val - x) < tolerance; }))
// cast to the literal's data type before comparing
using type = typename decltype(v)::value_type;
// Machine epsilon of the literal's element type; std::numeric_limits<T>::epsilon() is 0 for
// integer types.
auto eps = std::numeric_limits<type>::epsilon();
// Allowed error = atol_mult * eps (absolute part) + rtol_mult * eps * |val| (relative part).
// NOTE(review): the comparison is a strict `<`, so whenever the right-hand side is 0 (integer
// element type, or atol_mult == rtol_mult == 0 as in the has_value(0.0f, 0, 0) call in this
// listing) not even an exact match can pass. `<=` looks intended -- confirm.
if(std::all_of(v.begin(), v.end(), [&](auto val) {
return std::fabs(val - static_cast<type>(x)) <
(atol_mult * eps + rtol_mult * eps * std::fabs(val));
}))
b = true;
});
return b;
......
......@@ -70,6 +70,7 @@ struct find_tanh_fast_gelu
void apply(module& m, const match::matcher_result& r) const
{
/*
auto ins = r.result;
auto x = r.instructions["x"];
auto sqrt_2_rpi = m.add_literal(
......@@ -89,6 +90,18 @@ struct find_tanh_fast_gelu
auto cdf = insert_common_op(m, ins, make_op("div"), {one, e});
auto y = m.insert_instruction(ins, make_op("mul"), x, cdf);
m.replace_instruction(ins, y);
*/
auto ins = r.result;
auto x = r.instructions["x"];
auto sqrt1_2 = m.add_literal(literal{shape{x->get_shape().type()}, {M_SQRT1_2}});
auto one = m.add_literal(literal{shape{x->get_shape().type()}, {1.0f}});
auto one_half = m.add_literal(literal{shape{x->get_shape().type()}, {0.5f}});
auto a = insert_common_op(m, ins, make_op("mul"), {x, sqrt1_2});
auto erf = m.insert_instruction(ins, make_op("erf"), a);
auto add_erf = insert_common_op(m, ins, make_op("add"), {one, erf});
auto b = insert_common_op(m, ins, make_op("mul"), {one_half, add_erf});
auto y = m.insert_instruction(ins, make_op("mul"), x, b);
m.replace_instruction(ins, y);
}
};
......
......@@ -1217,7 +1217,7 @@ struct find_unit_ops
auto div_1 =
match::name("div")(match::args(match::any().bind("x"), match::has_value(1.0f)));
auto add_0 = match::name("add")(
match::either_arg(0, 1)(match::has_value(0.0f, 1e-12), match::any().bind("x")));
match::either_arg(0, 1)(match::has_value(0.0f, 0, 0), match::any().bind("x")));
auto sub_0 =
match::name("sub")(match::args(match::any().bind("x"), match::has_value(0.0f)));
return match::any_of(mul_1, div_1, add_0, sub_0);
......
......@@ -83,7 +83,8 @@ struct miopen_apply
assert(mod != nullptr);
assert(pass != nullptr);
compute_fp32 = get_compute_fp32_flag();
// compute_fp32 = get_compute_fp32_flag();
compute_fp32 = true;
offload_copy = (mod == mpm->get_root_module()) ? pass->offload_copy : false;
add_generic_op("contiguous");
......
......@@ -122,4 +122,49 @@ TEST_CASE(non_bias_gelu)
EXPECT(m1 == m2);
}
// Verifies that rewrite_gelu (followed by dead-code elimination) turns a half-precision
// tanh-form GELU graph into the erf-form graph built in m2.
TEST_CASE(tanh_gelu_distilgpt2_fp16)
{
    // The literal values are the ones observed in the distilgpt2_fp16 model; they are
    // intentionally not the exact mathematical constants.
    migraphx::shape input_shape{migraphx::shape::half_type, {2, 4, 8}};
    migraphx::shape scalar_shape{migraphx::shape::half_type};

    // Input graph: (tanh((x^3 * c + x) * k) + 1) * (x * 0.5)
    migraphx::module m1;
    {
        auto input       = m1.add_parameter("x", input_shape);
        auto cubic_coeff = m1.add_literal(migraphx::literal{scalar_shape, {0.044708251953125}});
        auto tanh_scale  = m1.add_literal(migraphx::literal{scalar_shape, {0.7978515625}});
        auto lit_one     = m1.add_literal(migraphx::literal{scalar_shape, {1.0f}});
        auto lit_half    = m1.add_literal(migraphx::literal{scalar_shape, {0.5f}});
        auto lit_three   = m1.add_literal(migraphx::literal{scalar_shape, {3.0f}});
        auto cubed       = add_common_op(m1, migraphx::make_op("pow"), {input, lit_three});
        auto cubic_term  = add_common_op(m1, migraphx::make_op("mul"), {cubed, cubic_coeff});
        auto inner_sum   = m1.add_instruction(migraphx::make_op("add"), {cubic_term, input});
        auto tanh_arg    = add_common_op(m1, migraphx::make_op("mul"), {inner_sum, tanh_scale});
        auto tanh_out    = m1.add_instruction(migraphx::make_op("tanh"), tanh_arg);
        auto tanh_plus_1 = add_common_op(m1, migraphx::make_op("add"), {tanh_out, lit_one});
        auto half_input  = add_common_op(m1, migraphx::make_op("mul"), {input, lit_half});
        auto result      = m1.add_instruction(migraphx::make_op("mul"), {tanh_plus_1, half_input});
        m1.add_return({result});
    }

    // Run the pass under test, then strip whatever it left unused.
    migraphx::rewrite_gelu pass;
    pass.apply(m1);
    migraphx::dead_code_elimination dce;
    dce.apply(m1);

    // Expected graph: x * (0.5 * (1 + erf(x * M_SQRT1_2)))
    migraphx::module m2;
    {
        auto input      = m2.add_parameter("x", input_shape);
        auto inv_sqrt2  = m2.add_literal(migraphx::literal{scalar_shape, {M_SQRT1_2}});
        auto lit_one    = m2.add_literal(migraphx::literal{scalar_shape, {1.0f}});
        auto lit_half   = m2.add_literal(migraphx::literal{scalar_shape, {0.5f}});
        auto scaled     = add_common_op(m2, migraphx::make_op("mul"), {input, inv_sqrt2});
        auto erf_out    = m2.add_instruction(migraphx::make_op("erf"), scaled);
        auto erf_plus_1 = add_common_op(m2, migraphx::make_op("add"), {lit_one, erf_out});
        auto cdf        = add_common_op(m2, migraphx::make_op("mul"), {lit_half, erf_plus_1});
        auto result     = m2.add_instruction(migraphx::make_op("mul"), input, cdf);
        m2.add_return({result});
    }

    EXPECT(m1 == m2);
}
// Test-binary entry point: hands the command-line arguments to the test runner.
int main(int argc, const char* argv[])
{
    test::run(argc, argv);
}
......@@ -136,9 +136,13 @@ def check_correctness(gold_outputs,
if verbose:
with np.printoptions(threshold=np.inf):
print('\nOutput {} is incorrect ...'.format(i))
print('Expected value: \n{}\n'.format(gold_outputs[i]))
print('\n......\n')
print('Actual value: \n{}\n'.format(outputs[i]))
#print('Expected value: \n{}'.format(gold_outputs[i]))
#print('\n......\n')
#print('Actual value: \n{}\n'.format(outputs[i]))
diff = gold_outputs[i] - outputs[i]
#print(f'Difference: {diff}')
max_diff = np.max(np.abs(diff))
print(f'Max Difference: {max_diff}')
else:
print('Outputs do not match')
break
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment