Commit 824cfb81 authored by liucong's avatar liucong
Browse files

修改gpt2工程格式

parent eccad09e
......@@ -16,18 +16,10 @@ print("INFO: Parsing and compiling the model")
model = migraphx.parse_onnx("../Resource/GPT2_shici.onnx", map_input_dims=maxInput)
# 获取模型输入/输出节点信息
print("inputs:")
inputs = model.get_inputs()
for key,value in inputs.items():
print("{}:{}".format(key,value))
print("outputs:")
outputs = model.get_outputs()
for key,value in outputs.items():
print("{}:{}".format(key,value))
inputName="input"
inputShape=inputs[inputName].lens()
inputName = model.get_parameter_names()[0]
inputShape = inputs[inputName].lens()
# 编译
model.compile(t=migraphx.get_target("gpu"), device_id=0)
......
......@@ -29,10 +29,10 @@ ErrorCode GPT2::Initialize()
// 设置最大输入shape
migraphx::onnx_options onnx_options;
onnx_options.map_input_dims["input"]={1,1000};
onnx_options.map_input_dims["input"] = {1,1000};
// 加载模型
if(Exists(modelPath)==false)
if(!Exists(modelPath))
{
LOG_ERROR(stdout,"%s not exist!\n",modelPath.c_str());
return MODEL_NOT_EXIST;
......@@ -41,18 +41,8 @@ ErrorCode GPT2::Initialize()
LOG_INFO(stdout,"succeed to load model: %s\n",GetFileName(modelPath).c_str());
// 获取模型输入/输出节点信息
std::cout<<"inputs:"<<std::endl;
std::unordered_map<std::string, migraphx::shape> inputs=net.get_inputs();
for(auto i:inputs)
{
std::cout<<i.first<<":"<<i.second<<std::endl;
}
std::cout<<"outputs:"<<std::endl;
std::unordered_map<std::string, migraphx::shape> outputs=net.get_outputs();
for(auto i:outputs)
{
std::cout<<i.first<<":"<<i.second<<std::endl;
}
std::unordered_map<std::string, migraphx::shape> inputs = net.get_inputs();
std::unordered_map<std::string, migraphx::shape> outputs = net.get_outputs();
inputName=inputs.begin()->first;
inputShape=inputs.begin()->second;
......@@ -61,8 +51,8 @@ ErrorCode GPT2::Initialize()
// 编译模型
migraphx::compile_options options;
options.device_id=0; // 设置GPU设备,默认为0号设备
options.offload_copy=true; // 设置offload_copy
options.device_id = 0; // 设置GPU设备,默认为0号设备
options.offload_copy = true; // 设置offload_copy
net.compile(gpuTarget,options);
LOG_INFO(stdout,"succeed to compile model: %s\n",GetFileName(modelPath).c_str());
......@@ -120,11 +110,11 @@ long unsigned int GPT2::Inference(const std::vector<long unsigned int> &input_id
}
ErrorCode GPT2::Preprocessing(cuBERT::FullTokenizer tokenizer,
char *question,
std::vector<long unsigned int> &input_id)
char *question,
std::vector<long unsigned int> &input_id)
{
// 分词操作
int max_seq_length =1000;
int max_seq_length = 1000;
std::vector<std::string> tokens_question;
tokens_question.reserve(max_seq_length);
tokenizer.tokenize(question, &tokens_question, max_seq_length);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment