Commit 2c63b5cd authored by wangxj

Upgrade to version 0.12

parent c271aaae
Pipeline #2451 passed
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
 import datetime
 import json
+import os
+import sys
 from flask import Flask, request, jsonify
 from flask_restful import Resource, Api
-from megatron.inference.text_generation import generate_and_post_process
-from megatron.inference.text_generation import beam_search_and_post_process
+from megatron.core.inference.sampling_params import SamplingParams
 from megatron.inference.endpoints.common import send_do_generate, send_do_beam_search, LOCK
 from megatron.inference.endpoints.completions import MegatronCompletions
+from megatron.inference.text_generation import beam_search_and_post_process
+
+sys.path.append(
+    os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir))
+)
 
 class MegatronGenerate(Resource):
-    def __init__(self, model):
-        self.model = model
+    def __init__(self, engine, args):
+        self.engine = engine
+        self.args = args
 
     def put(self):
         if not "prompts" in request.get_json():
@@ -188,43 +195,39 @@ class MegatronGenerate(Resource):
                     )
                 else:
                     send_do_generate() # Tell other ranks we're doing generate
-                    result = generate_and_post_process(
-                        self.model,
-                        prompts=prompts,
-                        tokens_to_generate=tokens_to_generate,
-                        return_output_log_probs=logprobs,
-                        top_k_sampling=top_k,
-                        top_p_sampling=top_p,
-                        top_p_decay=top_p_decay,
-                        top_p_bound=top_p_bound,
-                        temperature=temperature,
-                        add_BOS=add_BOS,
-                        use_eod_token_for_early_termination=True,
-                        stop_on_double_eol=stop_on_double_eol,
-                        stop_on_eol=stop_on_eol,
-                        prevent_newline_after_colon=prevent_newline_after_colon,
-                        random_seed=random_seed,
-                    )
-
-                    response, response_seg, response_logprobs = result[:3]
-                    response = {
-                        "text": response,
-                        "segments": response_seg,
-                        "logprobs": response_logprobs,
-                    }
-                    return jsonify(response)
+                    sampling_params = SamplingParams(
+                        temperature=temperature,
+                        top_k=top_k,
+                        top_p=top_p,
+                        return_segments=True,
+                        return_log_probs=logprobs,
+                        num_tokens_to_generate=tokens_to_generate,
+                    )
+
+                    result = list(
+                        self.engine.generate(
+                            prompts=prompts, common_inference_params=sampling_params
+                        )
+                    )
+                    response_dict = {"text": [x.prompt + x.generated_text for x in result]}
+                    if sampling_params.return_log_probs:
+                        response_logprobs = [x.prompt_log_probs + x.generated_log_probs for x in result]
+                        response_dict["logprobs"] = response_logprobs
+                    if sampling_params.return_segments:
+                        response_dict["segments"] = [x.segments for x in result]
+
+                    return jsonify(response_dict)
             except ValueError as ve:
                 return ve.args[0]
-        print("end time: ", datetime.datetime.now())
 
 class MegatronServer(object):
-    def __init__(self, model):
+    def __init__(self, model, args=None):
         self.app = Flask(__name__, static_url_path='')
         api = Api(self.app)
-        api.add_resource(MegatronGenerate, '/api', resource_class_args=[model])
+        api.add_resource(MegatronGenerate, '/api', resource_class_args=[model, args])
         api.add_resource(MegatronCompletions, '/completions', resource_class_args=[model])
 
     def run(self, url, port):
...
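For orientation, below is a minimal, hypothetical smoke-test sketch of what the upgraded server now expects from its first constructor argument: an object whose generate() accepts prompts and common_inference_params and returns results carrying the prompt, generated_text, log-prob, and segment fields read by the new put() handler. MockEngine and the module path megatron.inference.text_generation_server are assumptions for illustration, not part of this commit; a real deployment would pass Megatron's own inference engine.

# Hypothetical wiring sketch (not part of this commit): MockEngine only mirrors the
# attributes the new MegatronGenerate.put() handler reads from each generate() result.
from dataclasses import dataclass, field

from megatron.inference.text_generation_server import MegatronServer  # module path assumed


@dataclass
class MockResult:
    prompt: str
    generated_text: str = " ... generated text ..."
    prompt_log_probs: list = field(default_factory=list)
    generated_log_probs: list = field(default_factory=list)
    segments: list = field(default_factory=list)


class MockEngine:
    def generate(self, prompts, common_inference_params):
        # The handler wraps this call in list(...), so any iterable of results works.
        return [MockResult(prompt=p) for p in prompts]


if __name__ == "__main__":
    # Matches the new MegatronServer(model, args=None) signature and run(url, port).
    server = MegatronServer(MockEngine(), args=None)
    server.run("0.0.0.0", 5000)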
File mode changed from 100755 to 100644 for 19 other files in this commit (file names not shown in this view).
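Finally, a client-side sketch of calling the upgraded '/api' endpoint, which still accepts a PUT with a JSON body. The field names mirror what the handler reads in the diff (prompts, tokens_to_generate, logprobs, temperature); the host, port, and the full set of accepted fields are assumptions about a local deployment, not something this commit defines.

# Hypothetical client call against a locally running server; host/port are assumptions.
import requests

resp = requests.put(
    "http://localhost:5000/api",
    json={
        "prompts": ["Megatron upgraded to 0.12"],
        "tokens_to_generate": 32,
        "logprobs": True,
        "temperature": 1.0,
    },
    headers={"Content-Type": "application/json"},
)
print(resp.json()["text"])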