Commit ff3bd8e6 authored by Khalique Ahmed

manual merge

parents 32b69ceb c310bc5c
@@ -60,7 +60,11 @@ endif()
set(MIGRAPHX_ENABLE_CPU Off CACHE BOOL "")
set(CMAKE_CXX_STANDARD_DEFAULT "")
-add_compile_options(-std=c++14)
+if(CMAKE_CXX_COMPILER_ID MATCHES "Clang")
+    add_compile_options(-std=c++17)
+else()
+    add_compile_options(-std=c++14)
+endif()
list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
include(EnableCompilerWarnings)
@@ -187,6 +191,8 @@ rocm_enable_cppcheck(
definePrefix:*test/include/test.hpp
useSmartPointer:*src/api/api.cpp
useSmartPointer:*make_shared_array.hpp
+constParameter:*src/targets/gpu/*.cpp
+constParameter:*src/targets/gpu/*.hpp
FORCE
INCONCLUSIVE
RULE_FILE
......
@@ -74,7 +74,7 @@ RUN cget -p $PREFIX install facebook/zstd@v1.4.5 -X subdir -DCMAKE_DIR=build/cma
RUN cget -p $PREFIX install ccache@v4.1
# Install newer cmake for onnx runtime
-RUN cget -p /opt/cmake install kitware/cmake@v3.13.0
+RUN cget -p /opt/cmake install kitware/cmake@v3.13.4
ARG ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime
ARG ONNXRUNTIME_BRANCH=master
@@ -86,6 +86,8 @@ RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXR
ADD tools/build_and_test_onnxrt.sh /onnxruntime/build_and_test_onnxrt.sh
+RUN PATH=/opt/cmake/bin:$PATH cget -p /usr/local install ROCmSoftwarePlatform/llvm-project-mlir@02078ce236ad90e3aec04c0c770ef5bfc99e49c2
ENV MIOPEN_FIND_DB_PATH=/tmp/miopen/find-db
ENV MIOPEN_USER_DB_PATH=/tmp/miopen/user-db
ENV LD_LIBRARY_PATH=$PREFIX/lib
......
@@ -94,6 +94,12 @@ rocmtest clang_debug: rocmnode('vega') { cmake_build ->
        cmake_build("/opt/rocm/llvm/bin/clang++", "-DCMAKE_BUILD_TYPE=release")
        stash includes: 'build/*.deb', name: 'migraphx-package'
    }
+}, mlir_debug: rocmnode('vega') { cmake_build ->
+    stage('MLIR Debug') {
+        def sanitizers = "undefined"
+        def debug_flags = "-g -O2 -fsanitize=${sanitizers} -fno-sanitize-recover=${sanitizers}"
+        cmake_build("/opt/rocm/llvm/bin/clang++", "-DCMAKE_BUILD_TYPE=debug -DMIGRAPHX_ENABLE_PYTHON=Off -DMIGRAPHX_ENABLE_MLIR=On -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}'")
+    }
}
def onnxnode(name, body) {
......
@@ -43,7 +43,7 @@
"from os import path\n",
"import sys\n",
"\n",
-"import tokenization\n",
+"import tokenizers\n",
"from run_onnx_squad import *\n",
"\n",
"import migraphx"
@@ -137,8 +137,7 @@
"outputs": [],
"source": [
"vocab_file = os.path.join('uncased_L-12_H-768_A-12', 'vocab.txt')\n",
-"tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file,\n",
-" do_lower_case=True)"
+"tokenizer = tokenizers.BertWordPieceTokenizer(vocab_file)"
]
},
{
......
@@ -7,21 +7,25 @@ There are two ways to run the example:
# Steps
1) Install MIGraphX to your environment. Please follow the steps to build MIGraphX given at https://github.com/ROCmSoftwarePlatform/AMDMIGraphX
-2) Install the requirements file
+2) Upgrade your pip3 to the latest version
```
-pip3 install -r requirements_migraphx.txt
+pip3 install --upgrade pip
```
-3) Install `unzip` and fetch the uncased file (vocabulary):
+3) Install the requirements file
+```
+pip3 install -r requirements_bertsquad.txt
+```
+4) Install `unzip` and fetch the uncased file (vocabulary):
```
apt-get install unzip
wget -q https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip
unzip uncased_L-12_H-768_A-12.zip
```
-4) Get the BERT ONNX model (bertsquad-10.onnx):
+5) Get the BERT ONNX model (bertsquad-10.onnx):
```
wget https://github.com/onnx/models/raw/master/text/machine_comprehension/bert-squad/model/bertsquad-10.onnx
```
-5) Run the inference; it will compile and run the model on three questions and the small data set provided in `inputs.json`:
+6) Run the inference; it will compile and run the model on three questions and the small data set provided in `inputs.json`:
```
python3 bert-squad-migraphx.py
```
......
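The run step above follows the usual MIGraphX Python workflow: parse the ONNX file, compile it for the GPU target, then feed the preprocessed SQuAD tensors to `run`. The sketch below is illustrative only, not the contents of `bert-squad-migraphx.py`; it substitutes dummy zero inputs for the real preprocessing, and the exact Python API names can vary between MIGraphX releases.
```
# Illustrative sketch of the MIGraphX inference flow (not the actual script).
import numpy as np
import migraphx

prog = migraphx.parse_onnx("bertsquad-10.onnx")   # load the ONNX model
prog.compile(migraphx.get_target("gpu"))          # compile for the ROCm GPU target

# Placeholder inputs shaped like the model parameters; the real script feeds the
# tensors produced by convert_examples_to_features() from run_onnx_squad.py, and
# the dtype must match each parameter (the BERT token inputs are integer tensors).
params = {}
for name, param_shape in prog.get_parameter_shapes().items():
    params[name] = migraphx.argument(np.zeros(param_shape.lens(), dtype=np.int64))

outputs = prog.run(params)                        # e.g. start/end logits
```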
@@ -5,7 +5,7 @@ import os.path
from os import path
import sys
-import tokenization
+import tokenizers
from run_onnx_squad import *
import migraphx
@@ -30,8 +30,7 @@ n_best_size = 20
max_answer_length = 30
vocab_file = os.path.join('uncased_L-12_H-768_A-12', 'vocab.txt')
-tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file,
-do_lower_case=True)
+tokenizer = tokenizers.BertWordPieceTokenizer(vocab_file)
# Use convert_examples_to_features method from run_onnx_squad to get parameters from the input
input_ids, input_mask, segment_ids, extra_data = convert_examples_to_features(
......
-tensorflow==1.14
+tensorflow==2.4.0
onnxruntime
+tokenizers
\ No newline at end of file
@@ -38,7 +38,8 @@ from timeit import default_timer as timer
import numpy as np
import onnxruntime as onnxrt
import six
-import tokenization
+from tokenizers import BertWordPieceTokenizer
+from tokenizers import pre_tokenizers
RawResult = collections.namedtuple("RawResult",
["unique_id", "start_logits", "end_logits"])
@@ -70,9 +71,8 @@ class SquadExample(object):
def __repr__(self):
s = []
-s.append("qas_id: %s" % (tokenization.printable_text(self.qas_id)))
-s.append("question_text: %s" %
-(tokenization.printable_text(self.question_text)))
+s.append("qas_id: %s" % (self.qas_id))
+s.append("question_text: %s" % (self.question_text))
s.append("doc_tokens: [%s]" % (" ".join(self.doc_tokens)))
if self.start_position:
s.append("start_position: %d" % (self.start_position))
@@ -130,7 +130,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
unique_id = 0
for (example_index, example) in enumerate(examples):
-query_tokens = tokenizer.tokenize(example.question_text)
+query_tokens = tokenizer.encode(example.question_text)
if len(query_tokens) > max_query_length:
query_tokens = query_tokens[0:max_query_length]
@@ -140,8 +140,8 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
all_doc_tokens = []
for (i, token) in enumerate(example.doc_tokens):
orig_to_tok_index.append(len(all_doc_tokens))
-sub_tokens = tokenizer.tokenize(token)
-for sub_token in sub_tokens:
+sub_tokens = tokenizer.encode(token, add_special_tokens=False)
+for sub_token in sub_tokens.tokens:
tok_to_orig_index.append(i)
all_doc_tokens.append(sub_token)
@@ -172,7 +172,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
segment_ids = []
tokens.append("[CLS]")
segment_ids.append(0)
-for token in query_tokens:
+for token in query_tokens.tokens:
tokens.append(token)
segment_ids.append(0)
tokens.append("[SEP]")
@@ -192,7 +192,9 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
tokens.append("[SEP]")
segment_ids.append(1)
-input_ids = tokenizer.convert_tokens_to_ids(tokens)
+input_ids = []
+for token in tokens:
+    input_ids.append(tokenizer.token_to_id(token))
# The mask has 1 for real tokens and 0 for padding tokens. Only real
# tokens are attended to.
@@ -437,9 +439,15 @@ def get_final_text(pred_text, orig_text, do_lower_case):
# and `pred_text`, and check if they are the same length. If they are
# NOT the same length, the heuristic has failed. If they are the same
# length, we assume the characters are one-to-one aligned.
-tokenizer = tokenization.BasicTokenizer(do_lower_case=do_lower_case)
+tokenizer = pre_tokenizers.Sequence(
+    [pre_tokenizers.Whitespace(),
+     pre_tokenizers.Punctuation()])
-tok_text = " ".join(tokenizer.tokenize(orig_text))
+tok_text = []
+for item in tokenizer.pre_tokenize_str(orig_text):
+    tok_text.append(item[0])
+tok_text = " ".join(tok_text)
start_position = tok_text.find(pred_text)
if start_position == -1:
@@ -559,8 +567,7 @@ def main():
sess_options = onnxrt.SessionOptions()
sess_options.session_log_verbosity_level = args.log
-tokenizer = tokenization.FullTokenizer(vocab_file=args.vocab_file,
-do_lower_case=True)
+tokenizer = BertWordPieceTokenizer(vocab_file)
eval_examples = read_squad_examples(input_file=args.predict_file)
input_ids, input_mask, segment_ids, extra_data = \
......
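The hunk above swaps the bundled Google BERT `tokenization.py` helper (shown in full below) for the Hugging Face `tokenizers` package. As a rough map between the two APIs, the sketch below assumes the `tokenizers` 0.x interface used in this commit and the vocabulary file fetched in the README steps; the sample sentences and printed output are illustrative only.
```
from tokenizers import BertWordPieceTokenizer, pre_tokenizers

tokenizer = BertWordPieceTokenizer("uncased_L-12_H-768_A-12/vocab.txt")

# FullTokenizer.tokenize(text)             -> tokenizer.encode(text).tokens
enc = tokenizer.encode("What is MIGraphX?", add_special_tokens=False)
print(enc.tokens)

# FullTokenizer.convert_tokens_to_ids(...) -> token_to_id() per token (or enc.ids)
print([tokenizer.token_to_id(t) for t in enc.tokens])

# BasicTokenizer(do_lower_case=...)        -> whitespace + punctuation pre-tokenizers
basic = pre_tokenizers.Sequence([pre_tokenizers.Whitespace(), pre_tokenizers.Punctuation()])
print([piece for piece, _span in basic.pre_tokenize_str("Hello, world!")])
```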
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import re
import unicodedata
import six
import tensorflow as tf
def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
"""Checks whether the casing config is consistent with the checkpoint name."""
# The casing has to be passed in by the user and there is no explicit check
# as to whether it matches the checkpoint. The casing information probably
# should have been stored in the bert_config.json file, but it's not, so
# we have to heuristically detect it to validate.
if not init_checkpoint:
return
m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint)
if m is None:
return
model_name = m.group(1)
lower_models = [
"uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
"multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
]
cased_models = [
"cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
"multi_cased_L-12_H-768_A-12"
]
is_bad_config = False
if model_name in lower_models and not do_lower_case:
is_bad_config = True
actual_flag = "False"
case_name = "lowercased"
opposite_flag = "True"
if model_name in cased_models and do_lower_case:
is_bad_config = True
actual_flag = "True"
case_name = "cased"
opposite_flag = "False"
if is_bad_config:
raise ValueError(
"You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
"However, `%s` seems to be a %s model, so you "
"should pass in `--do_lower_case=%s` so that the fine-tuning matches "
"how the model was pre-training. If this error is wrong, please "
"just comment out this check." %
(actual_flag, init_checkpoint, model_name, case_name,
opposite_flag))
def convert_to_unicode(text):
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
if six.PY3:
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
elif six.PY2:
if isinstance(text, str):
return text.decode("utf-8", "ignore")
elif isinstance(text, unicode):
return text
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
else:
raise ValueError("Not running on Python2 or Python 3?")
def printable_text(text):
"""Returns text encoded in a way suitable for print or `tf.logging`."""
# These functions want `str` for both Python2 and Python3, but in one case
# it's a Unicode string and in the other it's a byte string.
if six.PY3:
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
elif six.PY2:
if isinstance(text, str):
return text
elif isinstance(text, unicode):
return text.encode("utf-8")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
else:
raise ValueError("Not running on Python2 or Python 3?")
def load_vocab(vocab_file):
"""Loads a vocabulary file into a dictionary."""
vocab = collections.OrderedDict()
index = 0
with tf.gfile.GFile(vocab_file, "r") as reader:
while True:
token = convert_to_unicode(reader.readline())
if not token:
break
token = token.strip()
vocab[token] = index
index += 1
return vocab
def convert_by_vocab(vocab, items):
"""Converts a sequence of [tokens|ids] using the vocab."""
output = []
for item in items:
output.append(vocab[item])
return output
def convert_tokens_to_ids(vocab, tokens):
return convert_by_vocab(vocab, tokens)
def convert_ids_to_tokens(inv_vocab, ids):
return convert_by_vocab(inv_vocab, ids)
def whitespace_tokenize(text):
"""Runs basic whitespace cleaning and splitting on a piece of text."""
text = text.strip()
if not text:
return []
tokens = text.split()
return tokens
class FullTokenizer(object):
"""Runs end-to-end tokenziation."""
def __init__(self, vocab_file, do_lower_case=True):
self.vocab = load_vocab(vocab_file)
self.inv_vocab = {v: k for k, v in self.vocab.items()}
self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
def tokenize(self, text):
split_tokens = []
for token in self.basic_tokenizer.tokenize(text):
for sub_token in self.wordpiece_tokenizer.tokenize(token):
split_tokens.append(sub_token)
return split_tokens
def convert_tokens_to_ids(self, tokens):
return convert_by_vocab(self.vocab, tokens)
def convert_ids_to_tokens(self, ids):
return convert_by_vocab(self.inv_vocab, ids)
class BasicTokenizer(object):
"""Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
def __init__(self, do_lower_case=True):
"""Constructs a BasicTokenizer.
Args:
do_lower_case: Whether to lower case the input.
"""
self.do_lower_case = do_lower_case
def tokenize(self, text):
"""Tokenizes a piece of text."""
text = convert_to_unicode(text)
text = self._clean_text(text)
# This was added on November 1st, 2018 for the multilingual and Chinese
# models. This is also applied to the English models now, but it doesn't
# matter since the English models were not trained on any Chinese data
# and generally don't have any Chinese data in them (there are Chinese
# characters in the vocabulary because Wikipedia does have some Chinese
# words in the English Wikipedia.).
text = self._tokenize_chinese_chars(text)
orig_tokens = whitespace_tokenize(text)
split_tokens = []
for token in orig_tokens:
if self.do_lower_case:
token = token.lower()
token = self._run_strip_accents(token)
split_tokens.extend(self._run_split_on_punc(token))
output_tokens = whitespace_tokenize(" ".join(split_tokens))
return output_tokens
def _run_strip_accents(self, text):
"""Strips accents from a piece of text."""
text = unicodedata.normalize("NFD", text)
output = []
for char in text:
cat = unicodedata.category(char)
if cat == "Mn":
continue
output.append(char)
return "".join(output)
def _run_split_on_punc(self, text):
"""Splits punctuation on a piece of text."""
chars = list(text)
i = 0
start_new_word = True
output = []
while i < len(chars):
char = chars[i]
if _is_punctuation(char):
output.append([char])
start_new_word = True
else:
if start_new_word:
output.append([])
start_new_word = False
output[-1].append(char)
i += 1
return ["".join(x) for x in output]
def _tokenize_chinese_chars(self, text):
"""Adds whitespace around any CJK character."""
output = []
for char in text:
cp = ord(char)
if self._is_chinese_char(cp):
output.append(" ")
output.append(char)
output.append(" ")
else:
output.append(char)
return "".join(output)
def _is_chinese_char(self, cp):
"""Checks whether CP is the codepoint of a CJK character."""
# This defines a "chinese character" as anything in the CJK Unicode block:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
#
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
# despite its name. The modern Korean Hangul alphabet is a different block,
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
# space-separated words, so they are not treated specially and handled
# like all of the other languages.
if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
(cp >= 0x3400 and cp <= 0x4DBF) or #
(cp >= 0x20000 and cp <= 0x2A6DF) or #
(cp >= 0x2A700 and cp <= 0x2B73F) or #
(cp >= 0x2B740 and cp <= 0x2B81F) or #
(cp >= 0x2B820 and cp <= 0x2CEAF) or
(cp >= 0xF900 and cp <= 0xFAFF) or #
(cp >= 0x2F800 and cp <= 0x2FA1F)): #
return True
return False
def _clean_text(self, text):
"""Performs invalid character removal and whitespace cleanup on text."""
output = []
for char in text:
cp = ord(char)
if cp == 0 or cp == 0xfffd or _is_control(char):
continue
if _is_whitespace(char):
output.append(" ")
else:
output.append(char)
return "".join(output)
class WordpieceTokenizer(object):
"""Runs WordPiece tokenziation."""
def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
self.vocab = vocab
self.unk_token = unk_token
self.max_input_chars_per_word = max_input_chars_per_word
def tokenize(self, text):
"""Tokenizes a piece of text into its word pieces.
This uses a greedy longest-match-first algorithm to perform tokenization
using the given vocabulary.
For example:
input = "unaffable"
output = ["un", "##aff", "##able"]
Args:
text: A single token or whitespace separated tokens. This should have
already been passed through `BasicTokenizer`.
Returns:
A list of wordpiece tokens.
"""
text = convert_to_unicode(text)
output_tokens = []
for token in whitespace_tokenize(text):
chars = list(token)
if len(chars) > self.max_input_chars_per_word:
output_tokens.append(self.unk_token)
continue
is_bad = False
start = 0
sub_tokens = []
while start < len(chars):
end = len(chars)
cur_substr = None
while start < end:
substr = "".join(chars[start:end])
if start > 0:
substr = "##" + substr
if substr in self.vocab:
cur_substr = substr
break
end -= 1
if cur_substr is None:
is_bad = True
break
sub_tokens.append(cur_substr)
start = end
if is_bad:
output_tokens.append(self.unk_token)
else:
output_tokens.extend(sub_tokens)
return output_tokens
def _is_whitespace(char):
"""Checks whether `chars` is a whitespace character."""
# \t, \n, and \r are technically control characters but we treat them
# as whitespace since they are generally considered as such.
if char == " " or char == "\t" or char == "\n" or char == "\r":
return True
cat = unicodedata.category(char)
if cat == "Zs":
return True
return False
def _is_control(char):
"""Checks whether `chars` is a control character."""
# These are technically control characters but we count them as whitespace
# characters.
if char == "\t" or char == "\n" or char == "\r":
return False
cat = unicodedata.category(char)
if cat in ("Cc", "Cf"):
return True
return False
def _is_punctuation(char):
"""Checks whether `chars` is a punctuation character."""
cp = ord(char)
# We treat all non-letter/number ASCII as punctuation.
# Characters such as "^", "$", and "`" are not in the Unicode
# Punctuation class but we treat them as punctuation anyways, for
# consistency.
if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64)
or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
return True
cat = unicodedata.category(char)
if cat.startswith("P"):
return True
return False
@@ -6,4 +6,4 @@ nlohmann/json@v3.8.0
blaze,https://bitbucket.org/blaze-lib/blaze/get/f0755dea0e03.tar.gz -X header -DHEADER_DIR=blaze
half,https://github.com/pfultz2/half/archive/1.12.0.tar.gz -X header -H sha256:0a08660b68abb176ebc2a0cdf8de46e3182a7f46c66443bb80dbfaaec98cf969
pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build
-msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off
\ No newline at end of file
+msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off
@@ -7,6 +7,7 @@ include(CheckCXXLinkerFlag)
add_library(migraphx
adjust_allocation.cpp
analyze_streams.cpp
+argument.cpp
auto_contiguous.cpp
eliminate_common_subexpression.cpp
decompose.cpp
@@ -121,6 +122,7 @@ register_migraphx_ops(
pad
pooling
pow
+prefix_scan_sum
prelu
quant_convolution
quant_dot
......
@@ -49,6 +49,7 @@ shape::type_t to_shape_type(migraphx_shape_datatype_t t)
{
switch(t)
{
+case migraphx_shape_tuple_type: return shape::tuple_type;
#define MIGRAPHX_DETAIL_SHAPE_CASE_CONVERT(x, y) \
case migraphx_shape_##x: return shape::x;
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_DETAIL_SHAPE_CASE_CONVERT)
@@ -61,6 +62,7 @@ migraphx_shape_datatype_t to_shape_type(shape::type_t t)
{
switch(t)
{
+case shape::tuple_type: return migraphx_shape_tuple_type;
#define MIGRAPHX_DETAIL_SHAPE_CASE_CONVERT(x, y) \
case shape::x: return migraphx_shape_##x;
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_DETAIL_SHAPE_CASE_CONVERT)
......
@@ -36,6 +36,7 @@ typedef enum {
#define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) migraphx_shape_##x,
/// An enum to represent the different data type inputs
typedef enum {
+migraphx_shape_tuple_type,
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES)
} migraphx_shape_datatype_t;
#undef MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES
......
#include <migraphx/argument.hpp>
#include <migraphx/functional.hpp>
#include <unordered_map>
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
argument::argument(const shape& s) : m_shape(s)
{
auto buffer = make_shared_array<char>(s.bytes());
m_data = {[=]() mutable { return buffer.get(); }};
}
argument::argument(shape s, std::nullptr_t)
: m_shape(std::move(s)), m_data({[] { return nullptr; }})
{
}
argument::argument(const shape& s, const argument::data_t& d) : m_shape(s), m_data(d) {}
argument argument::load(const shape& s, char* buffer)
{
if(s.type() != shape::tuple_type)
return argument{s, buffer};
// Collect all shapes
std::unordered_map<std::size_t, shape> shapes;
{
// cppcheck-suppress variableScope
std::size_t i = 0;
fix([&](auto self, auto ss) {
if(ss.sub_shapes().empty())
{
shapes[i] = ss;
i++;
}
else
{
for(auto&& child : ss.sub_shapes())
self(child);
}
})(s);
}
// Sort by type size
std::vector<std::size_t> order(shapes.size());
std::iota(order.begin(), order.end(), 0);
std::sort(order.begin(), order.end(), by(std::greater<>{}, [&](auto i) {
return shapes[i].type_size();
}));
// Compute offsets
std::unordered_map<std::size_t, std::size_t> offsets;
std::size_t offset = 0;
for(auto i : order)
{
offsets[i] = offset;
offset += shapes[i].bytes();
}
assert(offset == s.bytes());
// cppcheck-suppress variableScope
std::size_t i = 0;
return fix<argument>([&](auto self, auto ss) {
if(ss.sub_shapes().empty())
{
argument r{shapes[i], buffer + offsets[i]};
i++;
return r;
}
std::vector<argument> subs;
std::transform(ss.sub_shapes().begin(),
ss.sub_shapes().end(),
std::back_inserter(subs),
[&](auto child) { return self(child); });
return argument{subs};
})(s);
}
std::vector<shape> to_shapes(const std::vector<argument>& args)
{
std::vector<shape> shapes;
std::transform(args.begin(), args.end(), std::back_inserter(shapes), [](auto&& arg) {
return arg.get_shape();
});
return shapes;
}
argument::argument(const std::vector<argument>& args)
: m_shape(to_shapes(args)), m_data(data_t::from_args(args))
{
}
char* argument::data() const
{
assert(m_shape.type() != shape::tuple_type);
assert(not this->empty());
return m_data.get();
}
bool argument::empty() const { return not m_data.get and m_data.sub.empty(); }
const shape& argument::get_shape() const { return this->m_shape; }
argument argument::reshape(const shape& s) const { return {s, this->m_data}; }
argument::data_t argument::data_t::share() const
{
data_t result;
if(this->get)
{
auto self = std::make_shared<data_t>(*this);
result.get = [self]() mutable { return self->get(); };
}
std::transform(sub.begin(), sub.end(), std::back_inserter(result.sub), [](const auto& d) {
return d.share();
});
return result;
}
argument::data_t argument::data_t::from_args(const std::vector<argument>& args)
{
data_t result;
std::transform(args.begin(), args.end(), std::back_inserter(result.sub), [](auto&& arg) {
return arg.m_data;
});
return result;
}
argument argument::share() const { return {m_shape, m_data.share()}; }
std::vector<argument> argument::get_sub_objects() const
{
std::vector<argument> result;
assert(m_shape.sub_shapes().size() == m_data.sub.size());
std::transform(m_shape.sub_shapes().begin(),
m_shape.sub_shapes().end(),
m_data.sub.begin(),
std::back_inserter(result),
[](auto&& s, auto&& d) {
return argument{s, d};
});
return result;
}
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
@@ -29,14 +29,16 @@ std::ptrdiff_t bidistance(const Range& r, Iterator start, Iterator last)
return -n;
}
-void dead_code_elimination::apply(module& p) const
+void dead_code_elimination::apply(program& p) const { p.remove_unused_modules(); }
+void dead_code_elimination::apply(module& m) const
{
-auto last = std::prev(p.end());
-for(auto ins : iterator_for(p))
+auto last = std::prev(m.end());
+for(auto ins : iterator_for(m))
{
// Skip the first instruction, since we always process the previous
// instruction
-if(ins == p.begin())
+if(ins == m.begin())
continue;
const auto i = std::prev(ins);
// Skip the last instruction
@@ -46,9 +48,9 @@ void dead_code_elimination::apply(module& p) const
if(i->get_shape().elements() == 0 and i->name().front() != '@' and
i->name() != "undefined" and i->name() != "identity")
continue;
-assert(bidistance(p, i, last) > 0);
+assert(bidistance(m, i, last) > 0);
fix([&](auto self, auto leaf) {
-if(not p.has_instruction(leaf))
+if(not m.has_instruction(leaf))
return;
if(leaf->outputs().empty())
@@ -56,15 +58,15 @@ void dead_code_elimination::apply(module& p) const
std::unordered_set<instruction_ref> args(leaf->inputs().begin(),
leaf->inputs().end());
leaf->clear_arguments();
-assert(bidistance(p, last, leaf) < 0);
+assert(bidistance(m, last, leaf) < 0);
assert(leaf != ins);
-p.move_instruction(leaf, p.end());
+m.move_instruction(leaf, m.end());
for(auto arg : args)
self(arg);
}
})(i);
}
-p.remove_instructions(std::next(last), p.end());
+m.remove_instructions(std::next(last), m.end());
}
} // namespace MIGRAPHX_INLINE_NS
......
@@ -12,35 +12,44 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace {
+struct alpha_beta
+{
+    float alpha = 0.0;
+    float beta = 0.0;
+};
+alpha_beta get_alpha_beta(const operation& op)
+{
+    auto v = op.to_value();
+    return {v.at("alpha").to<float>(), v.at("beta").to<float>()};
+}
struct find_dot_add
{
-auto matcher() const { return match::name("dot")(match::nargs(3)); }
+auto matcher() const { return match::name("dot", "quant_dot")(match::nargs(3)); }
void apply(module& p, const match::matcher_result& r) const
{
auto ins = r.result;
-auto dot = any_cast<op::dot>(ins->get_operator());
-if(not float_equal(dot.beta, 1) and
-not contains({shape::float_type, shape::half_type, shape::double_type},
-ins->get_shape().type()))
-return;
+auto dot = get_alpha_beta(ins->get_operator());
auto a_ins = ins->inputs()[0];
auto b_ins = ins->inputs()[1];
if(not float_equal(dot.alpha, 1))
{
-auto alpha = p.add_literal(literal{shape{ins->get_shape().type()}, {dot.alpha}});
+auto alpha = p.add_literal(literal{shape{a_ins->get_shape().type()}, {dot.alpha}});
auto alpha_broadcast = p.insert_instruction(
ins,
make_op("multibroadcast", {{"output_lens", a_ins->get_shape().lens()}}),
alpha);
a_ins = p.insert_instruction(ins, make_op("mul"), a_ins, alpha_broadcast);
}
-auto dot_ins = p.insert_instruction(ins, make_op("dot", {{"beta", 0}}), a_ins, b_ins);
+auto dot_ins = p.insert_instruction(ins, make_op(ins->name(), {{"beta", 0}}), a_ins, b_ins);
auto c_ins = ins->inputs()[2];
if(not float_equal(dot.beta, 1))
{
-auto beta = p.add_literal(literal{shape{ins->get_shape().type()}, {dot.beta}});
+auto beta = p.add_literal(literal{shape{c_ins->get_shape().type()}, {dot.beta}});
auto beta_broadcast = p.insert_instruction(
ins, make_op("multibroadcast", {{"output_lens", ins->get_shape().lens()}}), beta);
c_ins = p.insert_instruction(ins, make_op("mul"), c_ins, beta_broadcast);
@@ -51,24 +60,24 @@ struct find_dot_add
struct find_dot_alpha
{
-auto matcher() const { return match::name("dot")(match::nargs(2)); }
+auto matcher() const { return match::name("dot", "quant_dot")(match::nargs(2)); }
void apply(module& p, const match::matcher_result& r) const
{
auto ins = r.result;
-auto dot = any_cast<op::dot>(ins->get_operator());
+auto dot = get_alpha_beta(ins->get_operator());
auto a_ins = ins->inputs()[0];
auto b_ins = ins->inputs()[1];
if(not float_equal(dot.alpha, 1))
{
-auto alpha = p.add_literal(literal{shape{ins->get_shape().type()}, {dot.alpha}});
+auto alpha = p.add_literal(literal{shape{a_ins->get_shape().type()}, {dot.alpha}});
auto alpha_broadcast = p.insert_instruction(
ins,
make_op("multibroadcast", {{"output_lens", a_ins->get_shape().lens()}}),
alpha);
a_ins = p.insert_instruction(ins, make_op("mul"), a_ins, alpha_broadcast);
}
-p.replace_instruction(ins, make_op("dot", {{"beta", 0}}), a_ins, b_ins);
+p.replace_instruction(ins, make_op(ins->name(), {{"beta", 0}}), a_ins, b_ins);
}
};
......
@@ -13,6 +13,8 @@ void eliminate_data_type::apply(module& m) const
{
if(ins->name()[0] == '@')
continue;
+if(ins->name() == "convert")
+continue;
auto inputs = ins->inputs();
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto i) {
if(types.count(i->get_shape().type()) == 0)
......
@@ -7,6 +7,20 @@
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
+template <class Iterator, class Output, class Predicate, class F>
+void transform_if(Iterator start, Iterator last, Output out, Predicate pred, F f)
+{
+    while(start != last)
+    {
+        if(pred(*start))
+        {
+            *out = f(*start);
+            ++out;
+        }
+        ++start;
+    }
+}
template <class Iterator, class Output, class Predicate>
void group_by(Iterator start, Iterator last, Output out, Predicate pred)
{
......
@@ -8,6 +8,7 @@
#include <functional>
#include <utility>
+// clang-format off
namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
@@ -20,57 +21,61 @@ inline namespace MIGRAPHX_INLINE_NS {
*/
struct argument : raw_data<argument>
{
-argument() {}
-argument(const shape& s) : m_shape(s)
-{
-auto buffer = make_shared_array<char>(s.bytes());
-data = [=]() mutable { return buffer.get(); };
-}
+argument() = default;
+argument(const shape& s);
template <class F, MIGRAPHX_REQUIRES(std::is_pointer<decltype(std::declval<F>()())>{})>
argument(shape s, F d)
-: data([f = std::move(d)]() mutable { return reinterpret_cast<char*>(f()); }),
-m_shape(std::move(s))
+: m_shape(std::move(s)),
+m_data({[f = std::move(d)]() mutable { return reinterpret_cast<char*>(f()); }})
{
}
template <class T>
argument(shape s, T* d)
-: data([d] { return reinterpret_cast<char*>(d); }), m_shape(std::move(s))
+: m_shape(std::move(s)), m_data({[d] { return reinterpret_cast<char*>(d); }})
{
}
template <class T>
argument(shape s, std::shared_ptr<T> d)
-: data([d] { return reinterpret_cast<char*>(d.get()); }), m_shape(std::move(s))
+: m_shape(std::move(s)), m_data({[d] { return reinterpret_cast<char*>(d.get()); }})
{
}
-argument(shape s, std::nullptr_t) : data([] { return nullptr; }), m_shape(std::move(s)) {}
+argument(shape s, std::nullptr_t);
+argument(const std::vector<argument>& args);
+static argument load(const shape& s, char* buffer);
/// Provides a raw pointer to the data
-std::function<char*()> data = nullptr;
+char* data() const;
/// Whether data is available
-bool empty() const { return not data; }
+bool empty() const;
-const shape& get_shape() const { return this->m_shape; }
+const shape& get_shape() const;
-argument reshape(const shape& s) const
-{
-argument self = *this;
-return {s, [=]() mutable { return self.data(); }};
-}
+argument reshape(const shape& s) const;
/// Make copy of the argument that is always sharing the data
-argument share() const
-{
-auto self = std::make_shared<argument>(*this);
-return {m_shape, [self]() mutable { return self->data(); }};
-}
+argument share() const;
+std::vector<argument> get_sub_objects() const;
private:
+struct data_t
+{
+std::function<char*()> get = nullptr;
+std::vector<data_t> sub = {};
+data_t share() const;
+static data_t from_args(const std::vector<argument>& args);
+};
+argument(const shape& s, const data_t& d);
shape m_shape;
+data_t m_data{};
};
void migraphx_to_value(value& v, const argument& a);
@@ -78,5 +83,6 @@ void migraphx_from_value(const value& v, argument& a);
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
+// clang-format on
#endif
@@ -9,6 +9,7 @@ namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
struct module;
+struct program;
/**
* Remove instructions where the output is not used.
@@ -16,7 +17,8 @@ struct module;
struct dead_code_elimination
{
std::string name() const { return "dead_code_elimination"; }
-void apply(module& p) const;
+void apply(module& m) const;
+void apply(program& p) const;
};
} // namespace MIGRAPHX_INLINE_NS
......