gaoqiong / MIGraphX · Commits

Commit ff3bd8e6, authored May 12, 2021 by Khalique Ahmed

manual merge

Parents: 32b69ceb, c310bc5c
Changes: 83 files
Showing 20 changed files with 285 additions and 474 deletions (+285, −474).
CMakeLists.txt                                                  +7    −1
Dockerfile                                                      +3    −1
Jenkinsfile                                                     +6    −0
examples/python_bert_squad_example/BERT-Squad.ipynb             +2    −3
examples/python_bert_squad_example/README.md                    +9    −5
examples/python_bert_squad_example/bert-squad-migraphx.py       +2    −3
examples/python_bert_squad_example/requirements_bertsquad.txt   +3    −2
examples/python_bert_squad_example/run_onnx_squad.py            +20   −13
examples/python_bert_squad_example/tokenization.py              +0    −397
requirements.txt                                                +1    −1
src/CMakeLists.txt                                              +2    −0
src/api/api.cpp                                                 +2    −0
src/api/include/migraphx/migraphx.h                             +1    −0
src/argument.cpp                                                +144  −0
src/dead_code_elimination.cpp                                   +11   −9
src/decompose.cpp                                               +23   −14
src/eliminate_data_type.cpp                                     +2    −0
src/include/migraphx/algorithm.hpp                              +14   −0
src/include/migraphx/argument.hpp                               +30   −24
src/include/migraphx/dead_code_elimination.hpp                  +3    −1
CMakeLists.txt

@@ -60,7 +60,11 @@ endif()
 set(MIGRAPHX_ENABLE_CPU Off CACHE BOOL "")
 set(CMAKE_CXX_STANDARD_DEFAULT "")
-add_compile_options(-std=c++14)
+if(CMAKE_CXX_COMPILER_ID MATCHES "Clang")
+    add_compile_options(-std=c++17)
+else()
+    add_compile_options(-std=c++14)
+endif()
 list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake)
 include(EnableCompilerWarnings)

@@ -187,6 +191,8 @@ rocm_enable_cppcheck(
        definePrefix:*test/include/test.hpp
        useSmartPointer:*src/api/api.cpp
        useSmartPointer:*make_shared_array.hpp
+       constParameter:*src/targets/gpu/*.cpp
+       constParameter:*src/targets/gpu/*.hpp
     FORCE
     INCONCLUSIVE
     RULE_FILE
Dockerfile

@@ -74,7 +74,7 @@ RUN cget -p $PREFIX install facebook/zstd@v1.4.5 -X subdir -DCMAKE_DIR=build/cma
 RUN cget -p $PREFIX install ccache@v4.1

 # Install newer cmake for onnx runtime
-RUN cget -p /opt/cmake install kitware/cmake@v3.13.0
+RUN cget -p /opt/cmake install kitware/cmake@v3.13.4

 ARG ONNXRUNTIME_REPO=https://github.com/Microsoft/onnxruntime
 ARG ONNXRUNTIME_BRANCH=master

@@ -86,6 +86,8 @@ RUN git clone --single-branch --branch ${ONNXRUNTIME_BRANCH} --recursive ${ONNXR
 ADD tools/build_and_test_onnxrt.sh /onnxruntime/build_and_test_onnxrt.sh
+
+RUN PATH=/opt/cmake/bin:$PATH cget -p /usr/local install ROCmSoftwarePlatform/llvm-project-mlir@02078ce236ad90e3aec04c0c770ef5bfc99e49c2

 ENV MIOPEN_FIND_DB_PATH=/tmp/miopen/find-db
 ENV MIOPEN_USER_DB_PATH=/tmp/miopen/user-db
 ENV LD_LIBRARY_PATH=$PREFIX/lib
Jenkinsfile

@@ -94,6 +94,12 @@ rocmtest clang_debug: rocmnode('vega') { cmake_build ->
         cmake_build("/opt/rocm/llvm/bin/clang++", "-DCMAKE_BUILD_TYPE=release")
         stash includes: 'build/*.deb', name: 'migraphx-package'
     }
+}, mlir_debug: rocmnode('vega') { cmake_build ->
+    stage('MLIR Debug') {
+        def sanitizers = "undefined"
+        def debug_flags = "-g -O2 -fsanitize=${sanitizers} -fno-sanitize-recover=${sanitizers}"
+        cmake_build("/opt/rocm/llvm/bin/clang++", "-DCMAKE_BUILD_TYPE=debug -DMIGRAPHX_ENABLE_PYTHON=Off -DMIGRAPHX_ENABLE_MLIR=On -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}'")
+    }
 }

 def onnxnode(name, body) {
examples/python_bert_squad_example/BERT-Squad.ipynb

@@ -43,7 +43,7 @@
     "from os import path\n",
     "import sys\n",
     "\n",
-    "import tokenization\n",
+    "import tokenizers\n",
     "from run_onnx_squad import *\n",
     "\n",
     "import migraphx"

@@ -137,8 +137,7 @@
    "outputs": [],
    "source": [
     "vocab_file = os.path.join('uncased_L-12_H-768_A-12', 'vocab.txt')\n",
-    "tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file,\n",
-    "                                       do_lower_case=True)"
+    "tokenizer = tokenizers.BertWordPieceTokenizer(vocab_file)"
    ]
   },
   {
examples/python_bert_squad_example/README.md

@@ -7,21 +7,25 @@ There are two ways to run the example:
 # Steps
 1) Install MIGraphX to your environment. Please follow the steps to build MIGraphX given at https://github.com/ROCmSoftwarePlatform/AMDMIGraphX
-2) Install the requirements file
+2) Upgrade your pip3 to latest version
 ```
-pip3 install -r requirements_migraphx.txt
+pip3 install --upgrade pip
 ```
-3) Install `unzip` and fetch the uncased file (vocabulary):
+3) Install the requirements file
+```
+pip3 install -r requirements_bertsquad.txt
+```
+4) Install `unzip` and fetch the uncased file (vocabulary):
 ```
 apt-get install unzip
 wget -q https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip
 unzip uncased_L-12_H-768_A-12.zip
 ```
-4) Get BERT ONNX model (bertsquad-10.onnx):
+5) Get BERT ONNX model (bertsquad-10.onnx):
 ```
 wget https://github.com/onnx/models/raw/master/text/machine_comprehension/bert-squad/model/bertsquad-10.onnx
 ```
-5) Run the inference, it will compile and run the model on three questions and small data provided in `inputs.json`:
+6) Run the inference, it will compile and run the model on three questions and small data provided in `inputs.json`:
 ```
 python3 bert-squad-migraphx.py
 ```
examples/python_bert_squad_example/bert-squad-migraphx.py

@@ -5,7 +5,7 @@ import os.path
 from os import path
 import sys

-import tokenization
+import tokenizers
 from run_onnx_squad import *

 import migraphx

@@ -30,8 +30,7 @@ n_best_size = 20
 max_answer_length = 30

 vocab_file = os.path.join('uncased_L-12_H-768_A-12', 'vocab.txt')
-tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file,
-                                       do_lower_case=True)
+tokenizer = tokenizers.BertWordPieceTokenizer(vocab_file)

 # Use convert_examples_to_features method from run_onnx_squad to get parameters from the input
 input_ids, input_mask, segment_ids, extra_data = convert_examples_to_features(
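Note on the change above: the example drops the bundled `tokenization.py` helper in favor of the HuggingFace `tokenizers` package. A minimal sketch of the replacement API, assuming `tokenizers` is installed and the vocabulary from the README has been unzipped next to the script (the path and sample question are illustrative only):

```python
# Sketch only; not part of the commit.
from tokenizers import BertWordPieceTokenizer

tokenizer = BertWordPieceTokenizer("uncased_L-12_H-768_A-12/vocab.txt")
encoding = tokenizer.encode("Where does MIGraphX run?")
print(encoding.tokens)  # wordpiece strings, including the [CLS]/[SEP] specials
print(encoding.ids)     # the corresponding vocabulary ids
```

Lower-casing is the default for the BERT wordpiece tokenizer in this package, which appears to be why the explicit `do_lower_case=True` argument disappears.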
examples/python_bert_squad_example/requirements_bertsquad.txt (mode 100644 → 100755)

-tensorflow==1.14
+tensorflow==2.4.0
 onnxruntime
+tokenizers
\ No newline at end of file
examples/python_bert_squad_example/run_onnx_squad.py (mode 100644 → 100755)

@@ -38,7 +38,8 @@ from timeit import default_timer as timer
 import numpy as np
 import onnxruntime as onnxrt
 import six
-import tokenization
+from tokenizers import BertWordPieceTokenizer
+from tokenizers import pre_tokenizers

 RawResult = collections.namedtuple("RawResult",
                                    ["unique_id", "start_logits", "end_logits"])

@@ -70,9 +71,8 @@ class SquadExample(object):
     def __repr__(self):
         s = []
-        s.append("qas_id: %s" % (tokenization.printable_text(self.qas_id)))
-        s.append("question_text: %s" %
-                 (tokenization.printable_text(self.question_text)))
+        s.append("qas_id: %s" % (self.qas_id))
+        s.append("question_text: %s" % (self.question_text))
         s.append("doc_tokens: [%s]" % (" ".join(self.doc_tokens)))
         if self.start_position:
             s.append("start_position: %d" % (self.start_position))

@@ -130,7 +130,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
     unique_id = 0
     for (example_index, example) in enumerate(examples):
-        query_tokens = tokenizer.tokenize(example.question_text)
+        query_tokens = tokenizer.encode(example.question_text)
         if len(query_tokens) > max_query_length:
             query_tokens = query_tokens[0:max_query_length]

@@ -140,8 +140,8 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
         all_doc_tokens = []
         for (i, token) in enumerate(example.doc_tokens):
             orig_to_tok_index.append(len(all_doc_tokens))
-            sub_tokens = tokenizer.tokenize(token)
-            for sub_token in sub_tokens:
+            sub_tokens = tokenizer.encode(token, add_special_tokens=False)
+            for sub_token in sub_tokens.tokens:
                 tok_to_orig_index.append(i)
                 all_doc_tokens.append(sub_token)

@@ -172,7 +172,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
             segment_ids = []
             tokens.append("[CLS]")
             segment_ids.append(0)
-            for token in query_tokens:
+            for token in query_tokens.tokens:
                 tokens.append(token)
                 segment_ids.append(0)
             tokens.append("[SEP]")

@@ -192,7 +192,9 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
             tokens.append("[SEP]")
             segment_ids.append(1)

-            input_ids = tokenizer.convert_tokens_to_ids(tokens)
+            input_ids = []
+            for token in tokens:
+                input_ids.append(tokenizer.token_to_id(token))

             # The mask has 1 for real tokens and 0 for padding tokens. Only real
             # tokens are attended to.

@@ -437,9 +439,15 @@ def get_final_text(pred_text, orig_text, do_lower_case):
     # and `pred_text`, and check if they are the same length. If they are
     # NOT the same length, the heuristic has failed. If they are the same
     # length, we assume the characters are one-to-one aligned.
-    tokenizer = tokenization.BasicTokenizer(do_lower_case=do_lower_case)
-    tok_text = " ".join(tokenizer.tokenize(orig_text))
+    tokenizer = pre_tokenizers.Sequence(
+        [pre_tokenizers.Whitespace(), pre_tokenizers.Punctuation()])
+    tok_text = []
+    for item in tokenizer.pre_tokenize_str(orig_text):
+        tok_text.append(item[0])
+    tok_text = " ".join(tok_text)

     start_position = tok_text.find(pred_text)
     if start_position == -1:

@@ -559,8 +567,7 @@ def main():
     sess_options = onnxrt.SessionOptions()
     sess_options.session_log_verbosity_level = args.log
-    tokenizer = tokenization.FullTokenizer(vocab_file=args.vocab_file,
-                                           do_lower_case=True)
+    tokenizer = BertWordPieceTokenizer(vocab_file)
     eval_examples = read_squad_examples(input_file=args.predict_file)

     input_ids, input_mask, segment_ids, extra_data = \
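The hunks above swap three `tokenization` helpers for their `tokenizers` counterparts: `tokenize()` becomes `encode()` (whose result exposes `.tokens`), `convert_tokens_to_ids()` becomes per-token `token_to_id()`, and `BasicTokenizer` becomes a `pre_tokenizers.Sequence`. A small sketch of those three calls, assuming the HuggingFace `tokenizers` package and an illustrative vocab path:

```python
# Sketch only; vocab path and sample strings are placeholders.
from tokenizers import BertWordPieceTokenizer, pre_tokenizers

tok = BertWordPieceTokenizer("uncased_L-12_H-768_A-12/vocab.txt")

enc = tok.encode("unaffable", add_special_tokens=False)
print(enc.tokens)                                # wordpiece sub-tokens of the word
print([tok.token_to_id(t) for t in enc.tokens])  # replaces convert_tokens_to_ids

pre = pre_tokenizers.Sequence([pre_tokenizers.Whitespace(), pre_tokenizers.Punctuation()])
pieces = pre.pre_tokenize_str("John's house.")   # list of (text, (start, end)) spans
print([text for text, _ in pieces])              # roughly what BasicTokenizer produced
```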
examples/python_bert_squad_example/tokenization.py (deleted, mode 100644 → 0)

Removed content:

# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import re
import unicodedata
import six
import tensorflow as tf


def validate_case_matches_checkpoint(do_lower_case, init_checkpoint):
    """Checks whether the casing config is consistent with the checkpoint name."""
    # The casing has to be passed in by the user and there is no explicit check
    # as to whether it matches the checkpoint. The casing information probably
    # should have been stored in the bert_config.json file, but it's not, so
    # we have to heuristically detect it to validate.
    if not init_checkpoint:
        return
    m = re.match("^.*?([A-Za-z0-9_-]+)/bert_model.ckpt", init_checkpoint)
    if m is None:
        return
    model_name = m.group(1)
    lower_models = [
        "uncased_L-24_H-1024_A-16", "uncased_L-12_H-768_A-12",
        "multilingual_L-12_H-768_A-12", "chinese_L-12_H-768_A-12"
    ]
    cased_models = [
        "cased_L-12_H-768_A-12", "cased_L-24_H-1024_A-16",
        "multi_cased_L-12_H-768_A-12"
    ]
    is_bad_config = False
    if model_name in lower_models and not do_lower_case:
        is_bad_config = True
        actual_flag = "False"
        case_name = "lowercased"
        opposite_flag = "True"
    if model_name in cased_models and do_lower_case:
        is_bad_config = True
        actual_flag = "True"
        case_name = "cased"
        opposite_flag = "False"
    if is_bad_config:
        raise ValueError(
            "You passed in `--do_lower_case=%s` with `--init_checkpoint=%s`. "
            "However, `%s` seems to be a %s model, so you "
            "should pass in `--do_lower_case=%s` so that the fine-tuning matches "
            "how the model was pre-training. If this error is wrong, please "
            "just comment out this check." %
            (actual_flag, init_checkpoint, model_name, case_name, opposite_flag))


def convert_to_unicode(text):
    """Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
    if six.PY3:
        if isinstance(text, str):
            return text
        elif isinstance(text, bytes):
            return text.decode("utf-8", "ignore")
        else:
            raise ValueError("Unsupported string type: %s" % (type(text)))
    elif six.PY2:
        if isinstance(text, str):
            return text.decode("utf-8", "ignore")
        elif isinstance(text, unicode):
            return text
        else:
            raise ValueError("Unsupported string type: %s" % (type(text)))
    else:
        raise ValueError("Not running on Python2 or Python 3?")


def printable_text(text):
    """Returns text encoded in a way suitable for print or `tf.logging`."""
    # These functions want `str` for both Python2 and Python3, but in one case
    # it's a Unicode string and in the other it's a byte string.
    if six.PY3:
        if isinstance(text, str):
            return text
        elif isinstance(text, bytes):
            return text.decode("utf-8", "ignore")
        else:
            raise ValueError("Unsupported string type: %s" % (type(text)))
    elif six.PY2:
        if isinstance(text, str):
            return text
        elif isinstance(text, unicode):
            return text.encode("utf-8")
        else:
            raise ValueError("Unsupported string type: %s" % (type(text)))
    else:
        raise ValueError("Not running on Python2 or Python 3?")


def load_vocab(vocab_file):
    """Loads a vocabulary file into a dictionary."""
    vocab = collections.OrderedDict()
    index = 0
    with tf.gfile.GFile(vocab_file, "r") as reader:
        while True:
            token = convert_to_unicode(reader.readline())
            if not token:
                break
            token = token.strip()
            vocab[token] = index
            index += 1
    return vocab


def convert_by_vocab(vocab, items):
    """Converts a sequence of [tokens|ids] using the vocab."""
    output = []
    for item in items:
        output.append(vocab[item])
    return output


def convert_tokens_to_ids(vocab, tokens):
    return convert_by_vocab(vocab, tokens)


def convert_ids_to_tokens(inv_vocab, ids):
    return convert_by_vocab(inv_vocab, ids)


def whitespace_tokenize(text):
    """Runs basic whitespace cleaning and splitting on a piece of text."""
    text = text.strip()
    if not text:
        return []
    tokens = text.split()
    return tokens


class FullTokenizer(object):
    """Runs end-to-end tokenziation."""

    def __init__(self, vocab_file, do_lower_case=True):
        self.vocab = load_vocab(vocab_file)
        self.inv_vocab = {v: k for k, v in self.vocab.items()}
        self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)

    def tokenize(self, text):
        split_tokens = []
        for token in self.basic_tokenizer.tokenize(text):
            for sub_token in self.wordpiece_tokenizer.tokenize(token):
                split_tokens.append(sub_token)
        return split_tokens

    def convert_tokens_to_ids(self, tokens):
        return convert_by_vocab(self.vocab, tokens)

    def convert_ids_to_tokens(self, ids):
        return convert_by_vocab(self.inv_vocab, ids)


class BasicTokenizer(object):
    """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""

    def __init__(self, do_lower_case=True):
        """Constructs a BasicTokenizer.

        Args:
          do_lower_case: Whether to lower case the input.
        """
        self.do_lower_case = do_lower_case

    def tokenize(self, text):
        """Tokenizes a piece of text."""
        text = convert_to_unicode(text)
        text = self._clean_text(text)

        # This was added on November 1st, 2018 for the multilingual and Chinese
        # models. This is also applied to the English models now, but it doesn't
        # matter since the English models were not trained on any Chinese data
        # and generally don't have any Chinese data in them (there are Chinese
        # characters in the vocabulary because Wikipedia does have some Chinese
        # words in the English Wikipedia.).
        text = self._tokenize_chinese_chars(text)

        orig_tokens = whitespace_tokenize(text)
        split_tokens = []
        for token in orig_tokens:
            if self.do_lower_case:
                token = token.lower()
                token = self._run_strip_accents(token)
            split_tokens.extend(self._run_split_on_punc(token))

        output_tokens = whitespace_tokenize(" ".join(split_tokens))
        return output_tokens

    def _run_strip_accents(self, text):
        """Strips accents from a piece of text."""
        text = unicodedata.normalize("NFD", text)
        output = []
        for char in text:
            cat = unicodedata.category(char)
            if cat == "Mn":
                continue
            output.append(char)
        return "".join(output)

    def _run_split_on_punc(self, text):
        """Splits punctuation on a piece of text."""
        chars = list(text)
        i = 0
        start_new_word = True
        output = []
        while i < len(chars):
            char = chars[i]
            if _is_punctuation(char):
                output.append([char])
                start_new_word = True
            else:
                if start_new_word:
                    output.append([])
                start_new_word = False
                output[-1].append(char)
            i += 1
        return ["".join(x) for x in output]

    def _tokenize_chinese_chars(self, text):
        """Adds whitespace around any CJK character."""
        output = []
        for char in text:
            cp = ord(char)
            if self._is_chinese_char(cp):
                output.append(" ")
                output.append(char)
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)

    def _is_chinese_char(self, cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # This defines a "chinese character" as anything in the CJK Unicode block:
        #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        #
        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
        # despite its name. The modern Korean Hangul alphabet is a different block,
        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
        # space-separated words, so they are not treated specially and handled
        # like the all of the other languages.
        if ((cp >= 0x4E00 and cp <= 0x9FFF) or  #
                (cp >= 0x3400 and cp <= 0x4DBF) or  #
                (cp >= 0x20000 and cp <= 0x2A6DF) or  #
                (cp >= 0x2A700 and cp <= 0x2B73F) or  #
                (cp >= 0x2B740 and cp <= 0x2B81F) or  #
                (cp >= 0x2B820 and cp <= 0x2CEAF) or
                (cp >= 0xF900 and cp <= 0xFAFF) or  #
                (cp >= 0x2F800 and cp <= 0x2FA1F)):  #
            return True
        return False

    def _clean_text(self, text):
        """Performs invalid character removal and whitespace cleanup on text."""
        output = []
        for char in text:
            cp = ord(char)
            if cp == 0 or cp == 0xfffd or _is_control(char):
                continue
            if _is_whitespace(char):
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)


class WordpieceTokenizer(object):
    """Runs WordPiece tokenziation."""

    def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
        self.vocab = vocab
        self.unk_token = unk_token
        self.max_input_chars_per_word = max_input_chars_per_word

    def tokenize(self, text):
        """Tokenizes a piece of text into its word pieces.

        This uses a greedy longest-match-first algorithm to perform tokenization
        using the given vocabulary.

        For example:
          input = "unaffable"
          output = ["un", "##aff", "##able"]

        Args:
          text: A single token or whitespace separated tokens. This should have
            already been passed through `BasicTokenizer.

        Returns:
          A list of wordpiece tokens.
        """
        text = convert_to_unicode(text)

        output_tokens = []
        for token in whitespace_tokenize(text):
            chars = list(token)
            if len(chars) > self.max_input_chars_per_word:
                output_tokens.append(self.unk_token)
                continue

            is_bad = False
            start = 0
            sub_tokens = []
            while start < len(chars):
                end = len(chars)
                cur_substr = None
                while start < end:
                    substr = "".join(chars[start:end])
                    if start > 0:
                        substr = "##" + substr
                    if substr in self.vocab:
                        cur_substr = substr
                        break
                    end -= 1
                if cur_substr is None:
                    is_bad = True
                    break
                sub_tokens.append(cur_substr)
                start = end

            if is_bad:
                output_tokens.append(self.unk_token)
            else:
                output_tokens.extend(sub_tokens)

        return output_tokens


def _is_whitespace(char):
    """Checks whether `chars` is a whitespace character."""
    # \t, \n, and \r are technically contorl characters but we treat them
    # as whitespace since they are generally considered as such.
    if char == " " or char == "\t" or char == "\n" or char == "\r":
        return True
    cat = unicodedata.category(char)
    if cat == "Zs":
        return True
    return False


def _is_control(char):
    """Checks whether `chars` is a control character."""
    # These are technically control characters but we count them as whitespace
    # characters.
    if char == "\t" or char == "\n" or char == "\r":
        return False
    cat = unicodedata.category(char)
    if cat in ("Cc", "Cf"):
        return True
    return False


def _is_punctuation(char):
    """Checks whether `chars` is a punctuation character."""
    cp = ord(char)
    # We treat all non-letter/number ASCII as punctuation.
    # Characters such as "^", "$", and "`" are not in the Unicode
    # Punctuation class but we treat them as punctuation anyways, for
    # consistency.
    if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64)
            or (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
        return True
    cat = unicodedata.category(char)
    if cat.startswith("P"):
        return True
    return False
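For reference, the core of the deleted file is the greedy longest-match-first WordPiece loop documented in `WordpieceTokenizer.tokenize`. A standalone sketch of that algorithm, using a toy vocabulary rather than the real BERT vocab:

```python
def wordpiece(token, vocab, unk="[UNK]"):
    """Greedy longest-match-first split, mirroring the deleted WordpieceTokenizer."""
    pieces, start = [], 0
    while start < len(token):
        end, cur = len(token), None
        while start < end:
            sub = token[start:end]
            if start > 0:
                sub = "##" + sub      # continuation pieces carry the '##' prefix
            if sub in vocab:
                cur = sub
                break
            end -= 1
        if cur is None:               # nothing matched: the whole word becomes [UNK]
            return [unk]
        pieces.append(cur)
        start = end
    return pieces

print(wordpiece("unaffable", {"un", "##aff", "##able"}))  # ['un', '##aff', '##able']
```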
requirements.txt (+1 −1)
src/CMakeLists.txt

@@ -7,6 +7,7 @@ include(CheckCXXLinkerFlag)
 add_library(migraphx
     adjust_allocation.cpp
     analyze_streams.cpp
+    argument.cpp
     auto_contiguous.cpp
     eliminate_common_subexpression.cpp
     decompose.cpp

@@ -121,6 +122,7 @@ register_migraphx_ops(
     pad
     pooling
     pow
+    prefix_scan_sum
     prelu
     quant_convolution
     quant_dot
src/api/api.cpp

@@ -49,6 +49,7 @@ shape::type_t to_shape_type(migraphx_shape_datatype_t t)
 {
     switch(t)
     {
+    case migraphx_shape_tuple_type: return shape::tuple_type;
 #define MIGRAPHX_DETAIL_SHAPE_CASE_CONVERT(x, y) \
     case migraphx_shape_##x: return shape::x;
         MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_DETAIL_SHAPE_CASE_CONVERT)

@@ -61,6 +62,7 @@ migraphx_shape_datatype_t to_shape_type(shape::type_t t)
 {
     switch(t)
     {
+    case shape::tuple_type: return migraphx_shape_tuple_type;
 #define MIGRAPHX_DETAIL_SHAPE_CASE_CONVERT(x, y) \
     case shape::x: return migraphx_shape_##x;
         MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_DETAIL_SHAPE_CASE_CONVERT)
src/api/include/migraphx/migraphx.h

@@ -36,6 +36,7 @@ typedef enum {
 #define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) migraphx_shape_##x,
 /// An enum to represent the different data type inputs
 typedef enum {
+    migraphx_shape_tuple_type,
     MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES)
 } migraphx_shape_datatype_t;
 #undef MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES
src/argument.cpp (new file, mode 0 → 100755)

#include <migraphx/argument.hpp>
#include <migraphx/functional.hpp>
#include <unordered_map>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

argument::argument(const shape& s) : m_shape(s)
{
    auto buffer = make_shared_array<char>(s.bytes());
    m_data      = {[=]() mutable { return buffer.get(); }};
}

argument::argument(shape s, std::nullptr_t)
    : m_shape(std::move(s)), m_data({[] { return nullptr; }})
{
}

argument::argument(const shape& s, const argument::data_t& d) : m_shape(s), m_data(d) {}

argument argument::load(const shape& s, char* buffer)
{
    if(s.type() != shape::tuple_type)
        return argument{s, buffer};
    // Collect all shapes
    std::unordered_map<std::size_t, shape> shapes;
    {
        // cppcheck-suppress variableScope
        std::size_t i = 0;
        fix([&](auto self, auto ss) {
            if(ss.sub_shapes().empty())
            {
                shapes[i] = ss;
                i++;
            }
            else
            {
                for(auto&& child : ss.sub_shapes())
                    self(child);
            }
        })(s);
    }
    // Sort by type size
    std::vector<std::size_t> order(shapes.size());
    std::iota(order.begin(), order.end(), 0);
    std::sort(order.begin(), order.end(), by(std::greater<>{}, [&](auto i) {
                  return shapes[i].type_size();
              }));
    // Compute offsets
    std::unordered_map<std::size_t, std::size_t> offsets;
    std::size_t offset = 0;
    for(auto i : order)
    {
        offsets[i] = offset;
        offset += shapes[i].bytes();
    }
    assert(offset == s.bytes());
    // cppcheck-suppress variableScope
    std::size_t i = 0;
    return fix<argument>([&](auto self, auto ss) {
        if(ss.sub_shapes().empty())
        {
            argument r{shapes[i], buffer + offsets[i]};
            i++;
            return r;
        }
        std::vector<argument> subs;
        std::transform(ss.sub_shapes().begin(),
                       ss.sub_shapes().end(),
                       std::back_inserter(subs),
                       [&](auto child) { return self(child); });
        return argument{subs};
    })(s);
}

std::vector<shape> to_shapes(const std::vector<argument>& args)
{
    std::vector<shape> shapes;
    std::transform(args.begin(), args.end(), std::back_inserter(shapes), [](auto&& arg) {
        return arg.get_shape();
    });
    return shapes;
}

argument::argument(const std::vector<argument>& args)
    : m_shape(to_shapes(args)), m_data(data_t::from_args(args))
{
}

char* argument::data() const
{
    assert(m_shape.type() != shape::tuple_type);
    assert(not this->empty());
    return m_data.get();
}

bool argument::empty() const { return not m_data.get and m_data.sub.empty(); }

const shape& argument::get_shape() const { return this->m_shape; }

argument argument::reshape(const shape& s) const { return {s, this->m_data}; }

argument::data_t argument::data_t::share() const
{
    data_t result;
    if(this->get)
    {
        auto self  = std::make_shared<data_t>(*this);
        result.get = [self]() mutable { return self->get(); };
    }
    std::transform(sub.begin(), sub.end(), std::back_inserter(result.sub), [](const auto& d) {
        return d.share();
    });
    return result;
}

argument::data_t argument::data_t::from_args(const std::vector<argument>& args)
{
    data_t result;
    std::transform(args.begin(), args.end(), std::back_inserter(result.sub), [](auto&& arg) {
        return arg.m_data;
    });
    return result;
}

argument argument::share() const { return {m_shape, m_data.share()}; }

std::vector<argument> argument::get_sub_objects() const
{
    std::vector<argument> result;
    assert(m_shape.sub_shapes().size() == m_data.sub.size());
    std::transform(m_shape.sub_shapes().begin(),
                   m_shape.sub_shapes().end(),
                   m_data.sub.begin(),
                   std::back_inserter(result),
                   [](auto&& s, auto&& d) { return argument{s, d}; });
    return result;
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
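The interesting part of the new file is `argument::load`, which maps a tuple shape onto a single buffer: leaf sub-shapes are ordered by descending element size and then assigned consecutive byte offsets. A small Python sketch of that offset computation, using made-up sub-shapes (the C++ code works on `migraphx::shape` objects, not tuples like these):

```python
# (label, element_size_in_bytes, total_bytes) for each leaf of a hypothetical tuple
subs = [("int8_buf", 1, 5), ("float_mat", 4, 48), ("half_vec", 2, 6)]

# Sort indices by element size, largest first, as argument::load does.
order = sorted(range(len(subs)), key=lambda i: subs[i][1], reverse=True)

offsets, offset = {}, 0
for i in order:
    offsets[i] = offset
    offset += subs[i][2]

# Larger element types land at the front, so each sub-buffer starts at an
# offset that is a multiple of its own element size (the apparent reason
# for the sort).
print(offsets, offset)  # {1: 0, 2: 48, 0: 54} and 59 total bytes
```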
src/dead_code_elimination.cpp

@@ -29,14 +29,16 @@ std::ptrdiff_t bidistance(const Range& r, Iterator start, Iterator last)
     return -n;
 }

-void dead_code_elimination::apply(module& p) const
+void dead_code_elimination::apply(program& p) const { p.remove_unused_modules(); }
+
+void dead_code_elimination::apply(module& m) const
 {
-    auto last = std::prev(p.end());
-    for(auto ins : iterator_for(p))
+    auto last = std::prev(m.end());
+    for(auto ins : iterator_for(m))
     {
         // Skip the first instruction, since we always process the previous
         // instruction
-        if(ins == p.begin())
+        if(ins == m.begin())
             continue;
         const auto i = std::prev(ins);
         // Skip the last instruction

@@ -46,9 +48,9 @@ void dead_code_elimination::apply(module& p) const
         if(i->get_shape().elements() == 0 and i->name().front() != '@' and
            i->name() != "undefined" and i->name() != "identity")
             continue;
-        assert(bidistance(p, i, last) > 0);
+        assert(bidistance(m, i, last) > 0);
         fix([&](auto self, auto leaf) {
-            if(not p.has_instruction(leaf))
+            if(not m.has_instruction(leaf))
                 return;
             if(leaf->outputs().empty())

@@ -56,15 +58,15 @@ void dead_code_elimination::apply(module& p) const
                 std::unordered_set<instruction_ref> args(leaf->inputs().begin(),
                                                          leaf->inputs().end());
                 leaf->clear_arguments();
-                assert(bidistance(p, last, leaf) < 0);
+                assert(bidistance(m, last, leaf) < 0);
                 assert(leaf != ins);
-                p.move_instruction(leaf, p.end());
+                m.move_instruction(leaf, m.end());
                 for(auto arg : args)
                     self(arg);
             }
         })(i);
     }
-    p.remove_instructions(std::next(last), p.end());
+    m.remove_instructions(std::next(last), m.end());
 }

 } // namespace MIGRAPHX_INLINE_NS
src/decompose.cpp (mode 100755 → 100644)

@@ -12,35 +12,44 @@
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

 namespace {

+struct alpha_beta
+{
+    float alpha = 0.0;
+    float beta  = 0.0;
+};
+
+alpha_beta get_alpha_beta(const operation& op)
+{
+    auto v = op.to_value();
+    return {v.at("alpha").to<float>(), v.at("beta").to<float>()};
+}
+
 struct find_dot_add
 {
-    auto matcher() const { return match::name("dot")(match::nargs(3)); }
+    auto matcher() const { return match::name("dot", "quant_dot")(match::nargs(3)); }

     void apply(module& p, const match::matcher_result& r) const
     {
         auto ins = r.result;
-        auto dot = any_cast<op::dot>(ins->get_operator());
+        auto dot = get_alpha_beta(ins->get_operator());
+        if(not float_equal(dot.beta, 1) and
+           not contains({shape::float_type, shape::half_type, shape::double_type},
+                        ins->get_shape().type()))
+            return;
         auto a_ins = ins->inputs()[0];
         auto b_ins = ins->inputs()[1];
         if(not float_equal(dot.alpha, 1))
         {
-            auto alpha = p.add_literal(literal{shape{ins->get_shape().type()}, {dot.alpha}});
+            auto alpha = p.add_literal(literal{shape{a_ins->get_shape().type()}, {dot.alpha}});
             auto alpha_broadcast = p.insert_instruction(
                 ins,
                 make_op("multibroadcast", {{"output_lens", a_ins->get_shape().lens()}}),
                 alpha);
             a_ins = p.insert_instruction(ins, make_op("mul"), a_ins, alpha_broadcast);
         }
-        auto dot_ins = p.insert_instruction(ins, make_op("dot", {{"beta", 0}}), a_ins, b_ins);
+        auto dot_ins = p.insert_instruction(ins, make_op(ins->name(), {{"beta", 0}}), a_ins, b_ins);
         auto c_ins = ins->inputs()[2];
         if(not float_equal(dot.beta, 1))
         {
-            auto beta = p.add_literal(literal{shape{ins->get_shape().type()}, {dot.beta}});
+            auto beta = p.add_literal(literal{shape{c_ins->get_shape().type()}, {dot.beta}});
             auto beta_broadcast = p.insert_instruction(
                 ins, make_op("multibroadcast", {{"output_lens", ins->get_shape().lens()}}), beta);
             c_ins = p.insert_instruction(ins, make_op("mul"), c_ins, beta_broadcast);

@@ -51,24 +60,24 @@ struct find_dot_add
 struct find_dot_alpha
 {
-    auto matcher() const { return match::name("dot")(match::nargs(2)); }
+    auto matcher() const { return match::name("dot", "quant_dot")(match::nargs(2)); }

     void apply(module& p, const match::matcher_result& r) const
     {
         auto ins = r.result;
-        auto dot = any_cast<op::dot>(ins->get_operator());
+        auto dot = get_alpha_beta(ins->get_operator());
         auto a_ins = ins->inputs()[0];
         auto b_ins = ins->inputs()[1];
         if(not float_equal(dot.alpha, 1))
         {
-            auto alpha = p.add_literal(literal{shape{ins->get_shape().type()}, {dot.alpha}});
+            auto alpha = p.add_literal(literal{shape{a_ins->get_shape().type()}, {dot.alpha}});
             auto alpha_broadcast = p.insert_instruction(
                 ins,
                 make_op("multibroadcast", {{"output_lens", a_ins->get_shape().lens()}}),
                 alpha);
             a_ins = p.insert_instruction(ins, make_op("mul"), a_ins, alpha_broadcast);
         }
-        p.replace_instruction(ins, make_op("dot", {{"beta", 0}}), a_ins, b_ins);
+        p.replace_instruction(ins, make_op(ins->name(), {{"beta", 0}}), a_ins, b_ins);
     }
 };
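What these matchers rely on is the usual GEMM identity: a `dot` carrying `alpha`/`beta` attributes can be rewritten as an explicit scale of A (and of C), followed by a plain `dot` with `beta = 0`. A quick numpy check of that equivalence, with illustrative shapes and scalars:

```python
import numpy as np

A = np.random.rand(2, 3).astype(np.float32)
B = np.random.rand(3, 4).astype(np.float32)
C = np.random.rand(2, 4).astype(np.float32)
alpha, beta = 0.5, 2.0

fused      = alpha * (A @ B) + beta * C      # dot with alpha/beta attributes
decomposed = ((alpha * A) @ B) + (beta * C)  # broadcast-mul, dot(beta=0), mul, add

assert np.allclose(fused, decomposed)
```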
src/eliminate_data_type.cpp

@@ -13,6 +13,8 @@ void eliminate_data_type::apply(module& m) const
 {
     if(ins->name()[0] == '@')
         continue;
+    if(ins->name() == "convert")
+        continue;
     auto inputs = ins->inputs();
     std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto i) {
         if(types.count(i->get_shape().type()) == 0)
100755 → 100644
View file @
ff3bd8e6
...
@@ -7,6 +7,20 @@
...
@@ -7,6 +7,20 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
template
<
class
Iterator
,
class
Output
,
class
Predicate
,
class
F
>
void
transform_if
(
Iterator
start
,
Iterator
last
,
Output
out
,
Predicate
pred
,
F
f
)
{
while
(
start
!=
last
)
{
if
(
pred
(
*
start
))
{
*
out
=
f
(
*
start
);
++
out
;
}
++
start
;
}
}
template
<
class
Iterator
,
class
Output
,
class
Predicate
>
template
<
class
Iterator
,
class
Output
,
class
Predicate
>
void
group_by
(
Iterator
start
,
Iterator
last
,
Output
out
,
Predicate
pred
)
void
group_by
(
Iterator
start
,
Iterator
last
,
Output
out
,
Predicate
pred
)
{
{
...
...
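The new `transform_if` helper is a fused filter-plus-map over an iterator range; in Python terms it behaves like a conditional comprehension. A sketch with an arbitrary predicate and transform:

```python
values = [1, 2, 3, 4, 5]

# transform_if(values.begin(), values.end(), back_inserter(out), pred, f)
out = []
for v in values:           # iterate [start, last)
    if v % 2 == 0:         # pred(*start)
        out.append(v * 2)  # *out = f(*start); ++out

assert out == [v * 2 for v in values if v % 2 == 0]  # == [4, 8]
```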
src/include/migraphx/argument.hpp

@@ -8,6 +8,7 @@
 #include <functional>
 #include <utility>

+// clang-format off
 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

@@ -20,57 +21,61 @@ inline namespace MIGRAPHX_INLINE_NS {
 */
 struct argument : raw_data<argument>
 {
-    argument() {}
+    argument() = default;

-    argument(const shape& s) : m_shape(s)
-    {
-        auto buffer = make_shared_array<char>(s.bytes());
-        data        = [=]() mutable { return buffer.get(); };
-    }
+    argument(const shape& s);

     template <class F, MIGRAPHX_REQUIRES(std::is_pointer<decltype(std::declval<F>()())>{})>
     argument(shape s, F d)
-        : data([f = std::move(d)]() mutable { return reinterpret_cast<char*>(f()); }),
-          m_shape(std::move(s))
+        : m_shape(std::move(s)),
+          m_data({[f = std::move(d)]() mutable { return reinterpret_cast<char*>(f()); }})
     {
     }

     template <class T>
     argument(shape s, T* d)
-        : data([d] { return reinterpret_cast<char*>(d); }), m_shape(std::move(s))
+        : m_shape(std::move(s)), m_data({[d] { return reinterpret_cast<char*>(d); }})
     {
     }

     template <class T>
     argument(shape s, std::shared_ptr<T> d)
-        : data([d] { return reinterpret_cast<char*>(d.get()); }), m_shape(std::move(s))
+        : m_shape(std::move(s)), m_data({[d] { return reinterpret_cast<char*>(d.get()); }})
     {
     }

-    argument(shape s, std::nullptr_t) : data([] { return nullptr; }), m_shape(std::move(s)) {}
+    argument(shape s, std::nullptr_t);
+
+    argument(const std::vector<argument>& args);
+
+    static argument load(const shape& s, char* buffer);

     /// Provides a raw pointer to the data
-    std::function<char*()> data = nullptr;
+    char* data() const;

     /// Whether data is available
-    bool empty() const { return not data; }
+    bool empty() const;

-    const shape& get_shape() const { return this->m_shape; }
+    const shape& get_shape() const;

-    argument reshape(const shape& s) const
-    {
-        argument self = *this;
-        return {s, [=]() mutable { return self.data(); }};
-    }
+    argument reshape(const shape& s) const;

     /// Make copy of the argument that is always sharing the data
-    argument share() const
-    {
-        auto self = std::make_shared<argument>(*this);
-        return {m_shape, [self]() mutable { return self->data(); }};
-    }
+    argument share() const;
+
+    std::vector<argument> get_sub_objects() const;

     private:
+    struct data_t
+    {
+        std::function<char*()> get = nullptr;
+        std::vector<data_t> sub    = {};
+        data_t share() const;
+        static data_t from_args(const std::vector<argument>& args);
+    };
+    argument(const shape& s, const data_t& d);
     shape m_shape;
+    data_t m_data{};
 };

 void migraphx_to_value(value& v, const argument& a);

@@ -78,5 +83,6 @@ void migraphx_from_value(const value& v, argument& a);

 } // namespace MIGRAPHX_INLINE_NS
 } // namespace migraphx
+// clang-format on

 #endif
src/include/migraphx/dead_code_elimination.hpp

@@ -9,6 +9,7 @@ namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

 struct module;
+struct program;

 /**
  * Remove instructions where the output is not used.

@@ -16,7 +17,8 @@ struct module;
 struct dead_code_elimination
 {
     std::string name() const { return "dead_code_elimination"; }
-    void apply(module& p) const;
+    void apply(module& m) const;
+    void apply(program& p) const;
 };

 } // namespace MIGRAPHX_INLINE_NS