Commit 6aeb4198 authored by Alan Turner's avatar Alan Turner
Browse files

Merge

parent e7f7ea10
...@@ -69,8 +69,7 @@ struct find_gemm_softmax_gemm_gemm ...@@ -69,8 +69,7 @@ struct find_gemm_softmax_gemm_gemm
auto gemm1 = auto gemm1 =
match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1"))); match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1")));
auto mul = match::name("mul")(match::any_of[match::inputs()](gemm1)).bind("scale"); auto mul = match::name("mul")(match::any_of[match::inputs()](gemm1)).bind("scale");
auto add = match::name("add")(match::any_of[match::inputs()](mul)); auto softmax = match::name("softmax")(match::any_of[match::inputs()](mul)).bind("softmax");
auto softmax = match::name("softmax")(match::any_of[match::inputs()](add)).bind("softmax");
return match::name("dot")(is_ck_gemm().bind("gemm2"))( return match::name("dot")(is_ck_gemm().bind("gemm2"))(
match::any_of[match::inputs()](softmax)); match::any_of[match::inputs()](softmax));
} }
......
...@@ -28,6 +28,7 @@ import tensorflow as tf ...@@ -28,6 +28,7 @@ import tensorflow as tf
def tf_test(op_test): def tf_test(op_test):
def run_test(): def run_test():
g1 = tf.Graph() g1 = tf.Graph()
op_test(g1) op_test(g1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment