gaoqiong / MIGraphX
Commit b4ecca3e, authored Dec 05, 2023 by Attila Dusnoki

Add Llama-2 example

parent a09dc502
Showing 7 changed files with 727 additions and 0 deletions
examples/transformers/README.md                             +3    -0
examples/transformers/python_llama2/README.md               +69   -0
examples/transformers/python_llama2/gradio_app.py           +60   -0
examples/transformers/python_llama2/gradio_reqirements.txt  +25   -0
examples/transformers/python_llama2/llama2.ipynb            +361  -0
examples/transformers/python_llama2/requirements.txt        +27   -0
examples/transformers/python_llama2/txtgen.py               +182  -0
examples/transformers/README.md
0 → 100644
# Transformers Inference Examples

- [Python Llama-2](./python_llama2)
examples/transformers/python_llama2/README.md
0 → 100644
# Llama-2

This version was tested with the [rocm 5.7](https://github.com/ROCmSoftwarePlatform/AMDMIGraphX/tree/rocm-5.7.0) revision.
## Jupyter notebook

There is a dedicated step-by-step notebook. See [llama2.ipynb](./llama2.ipynb).
## Console application

To run the console application, follow the steps below.

Set up the Python environment:
```bash
# this will require the python venv to be installed (e.g. apt install python3.8-venv)
python3 -m venv ll2_venv
. ll2_venv/bin/activate
```
```bash
pip install -r requirements.txt
```
Use the MIGraphX Python module:

```bash
export PYTHONPATH=/opt/rocm/lib:$PYTHONPATH
```
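As an optional sanity check (not part of the original steps), you can confirm that the module now resolves from that path; if the import fails, re-check the ROCm installation and the exported `PYTHONPATH`:

```bash
# optional: verify that the MIGraphX Python module is importable
python3 -c "import migraphx; print(migraphx.__file__)"
```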
Llama-2 requires logging in to Hugging Face to access the models:

```bash
huggingface-cli login
```
Get the models with optimum:

```bash
optimum-cli export onnx --model meta-llama/Llama-2-7b-chat-hf models/llama-2-7b-chat-hf --task text-generation --framework pt --library transformers --no-post-process
```
*Note: `models/llama-2-7b-chat-hf` will be used in the scripts.*
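To confirm the export before moving on, you can list the output directory; the exact file names depend on the optimum version, so treat this as a rough check rather than a guaranteed layout:

```bash
# rough check of the optimum export output; file names may vary between optimum versions
ls models/llama-2-7b-chat-hf/
# expect an ONNX graph (model.onnx), its weight data, and the tokenizer/config files
```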
Run the text-generation script with the following example prompt:
```bash
python txtgen.py --prompt "Where is Szeged?" --log-process
```
*Note: The first run will compile the models and cache them to make subsequent runs faster.*
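The cache location follows from `load_mgx_model` in `txtgen.py`, which saves the compiled program next to the exported model as `model-<max_seq_len>.mxr`. With the default `--max-seq-len` of 1024 you can check for it like this:

```bash
# compiled MIGraphX program cached by txtgen.py (default --max-seq-len is 1024)
ls -lh models/llama-2-7b-chat-hf/model-1024.mxr
```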
## Gradio application

Note: requires the `Console application` setup above to work.
Install the Gradio dependencies:

```bash
pip install -r gradio_requirements.txt
```
Usage:

```bash
python gradio_app.py
```
This will load the models (which can take several minutes) and, once the setup is ready, start a server on `http://127.0.0.1:7860`.
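If you need to reach the app from another machine or move it off the default port, Gradio's standard environment variables can be set at launch time (a hedged suggestion; `demo.launch()` in `gradio_app.py` otherwise uses the defaults):

```bash
# optional: bind to all interfaces and use a different port via Gradio's env vars
GRADIO_SERVER_NAME=0.0.0.0 GRADIO_SERVER_PORT=7861 python gradio_app.py
```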
examples/transformers/python_llama2/gradio_app.py
0 → 100644
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
import gradio as gr
from txtgen import Llama2MGX


def main():
    # Note: This will load the models, which can take several minutes
    llama = Llama2MGX(1024)

    def gr_wrapper(prompt):
        if prompt == "":
            return "Please provide a prompt."
        input_ids = llama.tokenize(prompt)
        result = llama.generate(input_ids)
        # trim input prompt from result
        result = result[len(prompt) + 2:]
        return result

    with gr.Blocks() as demo:
        gr.Markdown(
            "Start typing below and then click **Run** to see the output.")
        inp = gr.Textbox(placeholder="Type something here...",
                         label="Input prompt")
        btn = gr.Button("Run!")
        out = gr.Textbox(placeholder="The result will be displayed here",
                         label="Response")
        btn.click(fn=gr_wrapper, inputs=inp, outputs=out)

    demo.launch()


if __name__ == "__main__":
    main()
examples/transformers/python_llama2/gradio_reqirements.txt
0 → 100644
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
-f requirements.txt
gradio
\ No newline at end of file
examples/transformers/python_llama2/llama2.ipynb
0 → 100644
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#####################################################################################\n",
"# The MIT License (MIT)\n",
"#\n",
"# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.\n",
"#\n",
"# Permission is hereby granted, free of charge, to any person obtaining a copy\n",
"# of this software and associated documentation files (the \"Software\"), to deal\n",
"# in the Software without restriction, including without limitation the rights\n",
"# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n",
"# copies of the Software, and to permit persons to whom the Software is\n",
"# furnished to do so, subject to the following conditions:\n",
"#\n",
"# The above copyright notice and this permission notice shall be included in\n",
"# all copies or substantial portions of the Software.\n",
"#\n",
"# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n",
"# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n",
"# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n",
"# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n",
"# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n",
"# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN\n",
"# THE SOFTWARE.\n",
"#####################################################################################"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Llama-2\n",
"\n",
"The following example will show how to run `Llama-2` with `MIGraphX`.\n",
"\n",
"Install the required dependencies."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Install dependencies\n",
"%pip install accelerate huggingface_hub[cli] optimum[onnxruntime] transformers"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will use optimum to generate the onnx files.\n",
"But first, we need to login into huggingface to access it"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Please be careful and don't publish your token anywhere\n",
"!huggingface-cli login --token YOUR_TOKEN # from https://huggingface.co/settings/tokens"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we can export the models."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!optimum-cli export onnx --model meta-llama/Llama-2-7b-chat-hf models/llama-2-7b-chat-hf --task text-generation --framework pt --library transformers --no-post-process"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, it is time to load these models with python.\n",
"\n",
"First, we make sure that MIGraphX module is found in the python path."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"mgx_lib_path = \"/opt/rocm/lib/\" # or \"/code/AMDMIGraphX/build/lib/\"\n",
"if mgx_lib_path not in sys.path:\n",
" sys.path.append(mgx_lib_path)\n",
"import migraphx as mgx"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, a helper method to load and cache the models.\n",
"\n",
"This will use the `models/llama-2-7b-chat-hf` path. If you changed it, make sure to update here as well."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"# helper for model loading\n",
"def load_mgx_model(max_seq_len, shapes):\n",
" file = f\"models/llama-2-7b-chat-hf/model\"\n",
" print(f\"Loading {max_seq_len} seq-len version model from {file}\")\n",
" if os.path.isfile(f\"{file}-{max_seq_len}.mxr\"):\n",
" print(\"Found mxr, loading it...\")\n",
" model = mgx.load(f\"{file}-{max_seq_len}.mxr\", format=\"msgpack\")\n",
" elif os.path.isfile(f\"{file}.onnx\"):\n",
" print(\"Parsing from onnx file...\")\n",
" model = mgx.parse_onnx(f\"{file}.onnx\", map_input_dims=shapes)\n",
" model.compile(mgx.get_target(\"gpu\"))\n",
" print(\"Saving model to mxr file...\")\n",
" mgx.save(model, f\"{file}-{max_seq_len}.mxr\", format=\"msgpack\")\n",
" else:\n",
" print(\"No model found. Please download it and re-try.\")\n",
" sys.exit(1)\n",
" return model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"With that, we can load the models. This could take several minutes.\n",
"\n",
"We set the maximum sequence length at load time, if you change it, please reload the model as well."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"max_seq_len = 1024\n",
"decoder_model = load_mgx_model(\n",
" max_seq_len, {\n",
" \"input_ids\": [1, max_seq_len],\n",
" \"attention_mask\": [1, max_seq_len],\n",
" \"position_ids\": [1, max_seq_len]\n",
" })"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Import the remaining packages."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers import LlamaTokenizer\n",
"import numpy as np"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Time to load the tokenizer from the original source."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model_id = \"meta-llama/Llama-2-7b-chat-hf\"\n",
"tokenizer = LlamaTokenizer.from_pretrained(model_id)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next, we will define all the steps one by one, to make the last step short and simple.\n",
"\n",
"The first step will be to tokenize the user prompt."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def tokenize(prompt):\n",
" return tokenizer(prompt, return_tensors=\"np\").input_ids"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Next step will be to convert it to match the model input.\n",
"\n",
"We will generate the attention mask and positions as well."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_input_features_for_input_ids(input_ids):\n",
" input_ids_len = len(input_ids[0])\n",
" padding_len = max_seq_len - input_ids_len\n",
" input_ids = np.hstack([input_ids, np.zeros(\n",
" (1, padding_len))]).astype(np.int64)\n",
" # 0 masked | 1 un-masked\n",
" attention_mask = np.array([1] * input_ids_len + [0] * padding_len).astype(\n",
" np.int64)\n",
" attention_mask = attention_mask[np.newaxis]\n",
" position_ids = np.arange(0, max_seq_len, dtype=np.int64)\n",
" position_ids = position_ids[np.newaxis]\n",
"\n",
" return (input_ids, attention_mask, position_ids)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will use these in the decoding step."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def decode_step(input_ids, attention_mask, position_ids):\n",
" return np.array(\n",
" decoder_model.run({\n",
" \"input_ids\": input_ids,\n",
" \"attention_mask\": attention_mask,\n",
" \"position_ids\": position_ids\n",
" })[0])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The generated tokens will be decoded with the tokenizer."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def decode_tokens(generated_tokens):\n",
" return ''.join(tokenizer.decode(generated_tokens,\n",
" skip_special_tokens=True))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally the text generation part.\n",
"\n",
"With each decoding step, we will get the probabilities for the next token. We greedily get best match, add it to the decoded tokens and unmask it.\n",
"\n",
"If the token is end-of-sequence, we finished with the generation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from IPython.display import clear_output\n",
"\n",
"def generate(input_ids):\n",
" start_timestep = len(input_ids[0]) - 1\n",
" input_ids, attention_mask, position_ids = get_input_features_for_input_ids(\n",
" input_ids)\n",
"\n",
" for timestep in range(start_timestep, max_seq_len):\n",
" # get logits for the current timestep\n",
" logits = decode_step(input_ids, attention_mask, position_ids)\n",
" # greedily get the highest probable token\n",
" new_token = np.argmax(logits[0][timestep])\n",
"\n",
" # add it to the tokens and unmask it\n",
" input_ids[0][timestep + 1] = new_token\n",
" attention_mask[0][timestep + 1] = 1\n",
"\n",
" decoded_tokens = decode_tokens(input_ids[0][:timestep+2])\n",
" clear_output(wait=True)\n",
" print(decoded_tokens)\n",
"\n",
" if new_token == tokenizer.eos_token_id:\n",
" break"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"And now, to put everything together and run the whole pipeline:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"prompt = \"Where is Szeged?\"\n",
"input_ids = tokenize(prompt)\n",
"generate(input_ids)"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
examples/transformers/python_llama2/requirements.txt
0 → 100644
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
accelerate
huggingface_hub[cli]
optimum[onnxruntime]
transformers
\ No newline at end of file
examples/transformers/python_llama2/txtgen.py
0 → 100644
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
from argparse import ArgumentParser
from transformers import LlamaTokenizer
import numpy as np
import migraphx as mgx
import os
import sys
import time
from functools import wraps


# measurement helper
def measure(fn):
    @wraps(fn)
    def measure_ms(*args, **kwargs):
        start_time = time.perf_counter_ns()
        result = fn(*args, **kwargs)
        end_time = time.perf_counter_ns()
        print(
            f"Elapsed time for {fn.__name__}: {(end_time - start_time) * 1e-6:.4f} ms\n"
        )
        return result

    return measure_ms


def get_args():
    parser = ArgumentParser()
    parser.add_argument(
        "-p",
        "--prompt",
        type=str,
        required=True,
        help="Input prompt",
    )
    parser.add_argument(
        "-l",
        "--log-process",
        action="store_true",
        help="Print the current state of the generation.",
    )
    parser.add_argument(
        "-s",
        "--max-seq-len",
        type=int,
        choices=[256, 512, 1024, 2048, 4096],
        default=1024,
        help="Max sequence length the model can handle")
    return parser.parse_args()


class Llama2MGX():
    def __init__(self, max_seq_len=1024):
        model_id = "meta-llama/Llama-2-7b-chat-hf"
        self.max_seq_len = max_seq_len

        print("Load mgx model")
        self.model = Llama2MGX.load_mgx_model(
            max_seq_len, {
                "input_ids": [1, max_seq_len],
                "attention_mask": [1, max_seq_len],
                "position_ids": [1, max_seq_len]
            })

        print(f"Load AutoTokenizer model from {model_id}")
        self.tokenizer = LlamaTokenizer.from_pretrained(model_id)

    @staticmethod
    @measure
    def load_mgx_model(max_seq_len, shapes):
        file = "models/llama-2-7b-chat-hf/model"
        print(f"Loading {max_seq_len} seq-len version model from {file}")
        if os.path.isfile(f"{file}-{max_seq_len}.mxr"):
            print("Found mxr, loading it...")
            model = mgx.load(f"{file}-{max_seq_len}.mxr", format="msgpack")
        elif os.path.isfile(f"{file}.onnx"):
            print("Parsing from onnx file...")
            model = mgx.parse_onnx(f"{file}.onnx", map_input_dims=shapes)
            model.compile(mgx.get_target("gpu"))
            print("Saving model to mxr file...")
            mgx.save(model, f"{file}-{max_seq_len}.mxr", format="msgpack")
        else:
            print("No model found. Please download it and re-try.")
            sys.exit(1)
        return model

    @measure
    def tokenize(self, prompt):
        return self.tokenizer(prompt, return_tensors="np").input_ids

    @measure
    def get_input_features_for_input_ids(self, input_ids):
        input_ids_len = len(input_ids[0])
        padding_len = self.max_seq_len - input_ids_len
        input_ids = np.hstack([input_ids, np.zeros(
            (1, padding_len))]).astype(np.int64)
        # 0 masked | 1 un-masked
        attention_mask = np.array([1] * input_ids_len +
                                  [0] * padding_len).astype(np.int64)
        attention_mask = attention_mask[np.newaxis]
        position_ids = np.arange(0, self.max_seq_len, dtype=np.int64)
        position_ids = position_ids[np.newaxis]
        return (input_ids, attention_mask, position_ids)

    @measure
    def decode_step(self, input_ids, attention_mask, position_ids):
        return np.array(
            self.model.run({
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "position_ids": position_ids
            })[0])

    @measure
    def decode_tokens(self, generated_tokens, skip_special_tokens=True):
        return ''.join(
            self.tokenizer.decode(generated_tokens,
                                  skip_special_tokens=skip_special_tokens))

    @measure
    def generate(self, input_ids, log_process=False):
        start_timestep = len(input_ids[0]) - 1
        end_timestep = self.max_seq_len
        input_ids, attention_mask, position_ids = self.get_input_features_for_input_ids(
            input_ids)

        print("Generating response...")
        for timestep in range(start_timestep, self.max_seq_len):
            # get logits for the current timestep
            logits = self.decode_step(input_ids, attention_mask, position_ids)
            # greedily get the highest probable token
            new_token = np.argmax(logits[0][timestep])

            # add it to the tokens and unmask it
            input_ids[0][timestep + 1] = new_token
            attention_mask[0][timestep + 1] = 1

            if log_process:
                print(self.decode_tokens(input_ids[0][:timestep + 2]))

            if new_token == self.tokenizer.eos_token_id:
                end_timestep = timestep + 1
                break

        return self.decode_tokens(input_ids[0][:end_timestep + 1])


if __name__ == "__main__":
    args = get_args()
    llama = Llama2MGX(args.max_seq_len)

    print(f"Call tokenizer with \"{args.prompt}\"")
    input_ids = llama.tokenize(args.prompt)

    result = llama.generate(input_ids, log_process=args.log_process)
    print(f"Result text: {result}")