Commit 2c63b5cd authored by wangxj

Upgrade to version 0.12

parent c271aaae
Pipeline #2451 passed
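For orientation, a client call against the updated /api endpoint shown in the diff below might look like the following sketch. The host and port are placeholders, and every JSON field except "prompts" (the only key this hunk explicitly checks) is an assumption inferred from the variable names used in the handler.

import requests

# Hypothetical host/port; only "prompts" is confirmed by the handler below, the other fields are assumptions.
resp = requests.put(
    "http://localhost:5000/api",
    json={
        "prompts": ["Megatron-LM is"],
        "tokens_to_generate": 32,
        "logprobs": True,
        "temperature": 1.0,
        "top_k": 1,
    },
)
print(resp.json()["text"])          # prompt + generated continuation per prompt
print(resp.json().get("segments"))  # per-token segments; returned because return_segments=True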
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import datetime
import json
import os
import sys
from flask import Flask, request, jsonify
from flask_restful import Resource, Api
-from megatron.inference.text_generation import generate_and_post_process
-from megatron.inference.text_generation import beam_search_and_post_process
+from megatron.core.inference.sampling_params import SamplingParams
+from megatron.inference.endpoints.common import send_do_generate, send_do_beam_search, LOCK
+from megatron.inference.endpoints.completions import MegatronCompletions
+from megatron.inference.text_generation import beam_search_and_post_process

sys.path.append(
    os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir))
)
class MegatronGenerate(Resource):
-    def __init__(self, model):
-        self.model = model
+    def __init__(self, engine, args):
+        self.engine = engine
+        self.args = args

    def put(self):
        if not "prompts" in request.get_json():
@@ -188,43 +195,39 @@ class MegatronGenerate(Resource):
                    )
                else:
                    send_do_generate()  # Tell other ranks we're doing generate
-                    result = generate_and_post_process(
-                        self.model,
-                        prompts=prompts,
-                        tokens_to_generate=tokens_to_generate,
-                        return_output_log_probs=logprobs,
-                        top_k_sampling=top_k,
-                        top_p_sampling=top_p,
-                        top_p_decay=top_p_decay,
-                        top_p_bound=top_p_bound,
+                    sampling_params = SamplingParams(
                        temperature=temperature,
-                        add_BOS=add_BOS,
-                        use_eod_token_for_early_termination=True,
-                        stop_on_double_eol=stop_on_double_eol,
-                        stop_on_eol=stop_on_eol,
-                        prevent_newline_after_colon=prevent_newline_after_colon,
-                        random_seed=random_seed,
+                        top_k=top_k,
+                        top_p=top_p,
+                        return_segments=True,
+                        return_log_probs=logprobs,
+                        num_tokens_to_generate=tokens_to_generate,
                    )
+                    result = list(
+                        self.engine.generate(
+                            prompts=prompts, common_inference_params=sampling_params
+                        )
+                    )
+                    response_dict = {"text": [x.prompt + x.generated_text for x in result]}
+                    if sampling_params.return_log_probs:
+                        response_logprobs = [x.prompt_log_probs + x.generated_log_probs for x in result]
+                        response_dict["logprobs"] = response_logprobs
+                    if sampling_params.return_segments:
+                        response_dict["segments"] = [x.segments for x in result]
-                    response, response_seg, response_logprobs = result[:3]
-                    response = {
-                        "text": response,
-                        "segments": response_seg,
-                        "logprobs": response_logprobs,
-                    }
-                    return jsonify(response)
+                    return jsonify(response_dict)
            except ValueError as ve:
                return ve.args[0]

        print("end time: ", datetime.datetime.now())
class MegatronServer(object):
-    def __init__(self, model):
+    def __init__(self, model, args=None):
        self.app = Flask(__name__, static_url_path='')
        api = Api(self.app)
-        api.add_resource(MegatronGenerate, '/api', resource_class_args=[model])
+        api.add_resource(MegatronGenerate, '/api', resource_class_args=[model, args])
+        api.add_resource(MegatronCompletions, '/completions', resource_class_args=[model])

    def run(self, url, port):
......
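A minimal launch sketch under the new constructor signature; `engine` and `args` are assumed to come from the usual Megatron server setup (engine build and argument parsing), and the URL and port are placeholders. Note that MegatronServer still names its first parameter `model`, but it is forwarded to MegatronGenerate as the inference engine whose generate() method is called above.

# Sketch only; the names and values below are assumptions, not part of this commit.
server = MegatronServer(engine, args)  # engine: Megatron inference engine, args: parsed Megatron args
server.run("0.0.0.0", port=5000)       # matches def run(self, url, port) in the diff above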
File mode changed from 100755 to 100644 (19 files)