BUILD 2.54 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
# Package-wide settings: license kind, and public visibility for every target
# that does not set its own `visibility`.
licenses(["notice"])

package(default_visibility = ["//visibility:public"])

# C++ library built from sgnn_projection.cc / sgnn_projection.h. Its
# tensorflow/lite kernel deps suggest it implements a TFLite op kernel — confirm.
# Deps are listed in buildifier's canonical sort order.
cc_library(
    name = "sgnn_projection",
    srcs = ["sgnn_projection.cc"],
    hdrs = ["sgnn_projection.h"],
    deps = [
        "@farmhash_archive//:farmhash",
        "@flatbuffers",
        "@org_tensorflow//tensorflow/lite:context",
        "@org_tensorflow//tensorflow/lite:string_util",
        "@org_tensorflow//tensorflow/lite/kernels:kernel_util",
        "@org_tensorflow//tensorflow/lite/kernels/internal:tensor",
    ],
)

# Op-resolver library wrapping :sgnn_projection on top of the TFLite framework.
cc_library(
    name = "sgnn_projection_op_resolver",
    srcs = ["sgnn_projection_op_resolver.cc"],
    hdrs = ["sgnn_projection_op_resolver.h"],
    # Same as the package default_visibility, but pins this target public
    # even if that default is later narrowed.
    visibility = ["//visibility:public"],
    deps = [
        ":sgnn_projection",
        "@org_tensorflow//tensorflow/lite:framework",
    ],
    # Force this library's objects into every dependent binary even when no
    # symbol is referenced directly; presumably required because registration
    # happens via static initializers — confirm against the .cc file.
    alwayslink = True,
)

# gtest-based unit test for :sgnn_projection, using the TFLite kernel test
# utilities. Deps are listed in buildifier's canonical sort order.
cc_test(
    name = "sgnn_projection_test",
    srcs = ["sgnn_projection_test.cc"],
    deps = [
        ":sgnn_projection",
        "@com_google_googletest//:gtest_main",
        "@flatbuffers",
        "@org_tensorflow//tensorflow/lite:string_util",
        "@org_tensorflow//tensorflow/lite/kernels:test_util",
        "@org_tensorflow//tensorflow/lite/schema:schema_fbs",
    ],
)

# Python library built from sgnn.py. The `# package …` / `# Expect …` lines
# below stand in for dependencies that are installed outside Bazel.
py_library(
    name = "sgnn",
    srcs = ["sgnn.py"],
    srcs_version = "PY3",
    deps = [
        # package tensorflow
        "@org_tflite_support//tensorflow_lite_support/custom_ops/python:tflite_text_api",
        # Expect tensorflow text installed
    ],
)

# Python test for :sgnn. Pinned to Python 3 to match the other Python
# executables in this package (:train, :run_tflite) and :sgnn's PY3 srcs_version.
py_test(
    name = "sgnn_test",
    srcs = ["sgnn_test.py"],
    python_version = "PY3",
    deps = [
        ":sgnn",
        # package tensorflow
        # Expect tensorflow text installed
    ],
)

# Python 3 entry point train.py, built on :sgnn. TensorFlow and
# TensorFlow Datasets come from the environment (see the marker comments).
py_binary(
    name = "train",
    srcs = ["train.py"],
    main = "train.py",
    python_version = "PY3",
    deps = [
        ":sgnn",
        # package tensorflow
        # package tensorflow_datasets
    ],
)

# Python 3 entry point run_tflite.py. It links the ngrams and whitespace
# tokenizer op resolvers from tflite-support. numpy, TFLite, and TF Text are
# expected in the environment (see the marker comments).
py_binary(
    name = "run_tflite",
    srcs = [
        "run_tflite.py",
    ],
    main = "run_tflite.py",
    python_version = "PY3",
    deps = [
        # Expect numpy installed
        # package TFLite flex delegate
        # package TFLite interpreter
        "@org_tflite_support//tensorflow_lite_support/custom_ops/kernel:ngrams_op_resolver",
        "@org_tflite_support//tensorflow_lite_support/custom_ops/kernel:whitespace_tokenizer_op_resolver",
        # Expect tensorflow text installed
    ],
)

# Placeholder targets (no srcs) for Python packages that must be pip-installed
# rather than built by Bazel. No target in this file lists them in deps.

# pip install numpy
py_library(
    name = "expect_numpy_installed",
)

# pip install tensorflow-text
py_library(
    name = "expect_tensorflow_text_installed",
)