"git@developer.sourcefind.cn:change/sglang.git" did not exist on "298509008451f7861a848d477829d5816eef12cd"
Commit 7afec310 authored by liucong's avatar liucong
Browse files

修改GPT2模型的最大shape

parent 741ac4ae
Doc/Images/GPT_03.png

31.4 KB | W: | H:

Doc/Images/GPT_03.png

15.7 KB | W: | H:

Doc/Images/GPT_03.png
Doc/Images/GPT_03.png
Doc/Images/GPT_03.png
Doc/Images/GPT_03.png
  • 2-up
  • Swipe
  • Onion skin
Doc/Images/GPT_04.png

35.8 KB | W: | H:

Doc/Images/GPT_04.png

35.3 KB | W: | H:

Doc/Images/GPT_04.png
Doc/Images/GPT_04.png
Doc/Images/GPT_04.png
Doc/Images/GPT_04.png
  • 2-up
  • Swipe
  • Onion skin
...@@ -9,7 +9,7 @@ vocab_file = os.path.join('../../../Resource/Models/NLP/GPT2', 'vocab_shici.txt' ...@@ -9,7 +9,7 @@ vocab_file = os.path.join('../../../Resource/Models/NLP/GPT2', 'vocab_shici.txt'
tokenizer = BertTokenizerFast(vocab_file, sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]") tokenizer = BertTokenizerFast(vocab_file, sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]")
# 设置最大输入shape # 设置最大输入shape
maxInput={"input":[1,1024]} maxInput={"input":[1,1000]}
# 加载模型 # 加载模型
print("INFO: Parsing and compiling the model") print("INFO: Parsing and compiling the model")
......
...@@ -116,7 +116,7 @@ export MIGRAPHX_DYNAMIC_SHAPE=1 ...@@ -116,7 +116,7 @@ export MIGRAPHX_DYNAMIC_SHAPE=1
如下所示,采用交互式界面,通过输入开头诗词,GPT2模型可以推理出后续的诗句。 如下所示,采用交互式界面,通过输入开头诗词,GPT2模型可以推理出后续的诗句。
<img src="./Doc/Images/GPT_04.png" style="zoom:90%;" align=middle> <img src="./Doc/Images/GPT_04.png" style="zoom:100%;" align=middle>
## 历史版本 ## 历史版本
......
...@@ -46,7 +46,7 @@ ErrorCode GPT2::Initialize(InitializationParameterOfNLP initParamOfNLPGPT2) ...@@ -46,7 +46,7 @@ ErrorCode GPT2::Initialize(InitializationParameterOfNLP initParamOfNLPGPT2)
// 设置最大输入shape // 设置最大输入shape
migraphx::onnx_options onnx_options; migraphx::onnx_options onnx_options;
onnx_options.map_input_dims["input"]={1,1024}; onnx_options.map_input_dims["input"]={1,1000};
// 加载模型 // 加载模型
if(Exists(modelPath)==false) if(Exists(modelPath)==false)
...@@ -70,12 +70,7 @@ ErrorCode GPT2::Initialize(InitializationParameterOfNLP initParamOfNLPGPT2) ...@@ -70,12 +70,7 @@ ErrorCode GPT2::Initialize(InitializationParameterOfNLP initParamOfNLPGPT2)
options.device_id=0; // 设置GPU设备,默认为0号设备 options.device_id=0; // 设置GPU设备,默认为0号设备
options.offload_copy=true; // 设置offload_copy options.offload_copy=true; // 设置offload_copy
net.compile(gpuTarget,options); net.compile(gpuTarget,options);
LOG_INFO(logFile,"succeed to compile model: %s\n",GetFileName(modelPath).c_str()); LOG_INFO(logFile,"succeed to compile model: %s\n",GetFileName(modelPath).c_str());
// Run once by itself
migraphx::parameter_map inputData;
inputData[inputName]=migraphx::generate_argument(inputShape);
net.eval(inputData);
return SUCCESS; return SUCCESS;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment