import time
import os
import configparser
import argparse
import json
import re
import asyncio

from loguru import logger
from aiohttp import web
from transformers import AutoModelForCausalLM, AutoTokenizer


# Placeholder tokens the model may emit, mapped to canned replies; the Chinese
# support content below is runtime data and is kept verbatim.
COMMON = {
    "<光合组织登记网址>": "https://www.hieco.com.cn/partner?from=timeline",
    "<官网>": "https://www.sugon.com/after_sale/policy?sh=1",
    "<平台联系方式>": "1、访问官网,根据您所在地地址联系平台人员,网址地址:https://www.sugon.com/about/contact;\n2、点击人工客服进行咨询;\n3、请您拨打中科曙光服务热线400-810-0466联系人工进行咨询。",
    "<购买与维修的咨询方法>": "1、确定付费处理,可以微信搜索'sugon中科曙光服务'小程序,选择'在线报修'业务\n2、先了解价格,可以微信搜索'sugon中科曙光服务'小程序,选择'其他咨询'业务\n3、请您拨打中科曙光服务热线400-810-0466",
    "<服务器续保流程>": "1、微信搜索'sugon中科曙光服务'小程序,选择'延保与登记'业务\n2、点击人工客服进行登记\n3、请您拨打中科曙光服务热线400-810-0466根据语音提示选择维保与购买",
    "<XC内外网OS网盘链接>": "【腾讯文档】XC内外网OS网盘链接:https://docs.qq.com/sheet/DTWtXbU1BZHJvWkJm",
    "<W360-G30机器,安装Win7使用的镜像链接>": "W360-G30机器,安装Win7使用的镜像链接:https://pan.baidu.com/s/1SjHqCP6kJ9KzdJEBZDEynw;提取码:x6m4",
    "<麒麟系统搜狗输入法下载链接>": "软件下载链接(百度云盘):链接:https://pan.baidu.com/s/18Iluvs4BOAfFET0yFMBeLQ,提取码:bhkf",
    "<X660 G45 GPU服务器拆解视频网盘链接>": "链接: https://pan.baidu.com/s/1RkRGh4XY1T2oYftGnjLp4w;提取码: v2qi",
    "<DS800,SANTRICITY存储IBM版本模拟器网盘链接>": "链接:https://pan.baidu.com/s/1euG9HGbPfrVbThEB8BX76g;提取码:o2ya",
    "<E80-D312(X680-G55)风冷整机组装说明下载链接>": "链接:https://pan.baidu.com/s/17KDpm-Z9lp01WGp9sQaQ4w;提取码:0802",
    "<X680 G55 风冷相关资料下载链接>": "链接:https://pan.baidu.com/s/1KQ-hxUIbTWNkc0xzrEQLjg;提取码:0802",
    "<R620 G51刷写EEPROM下载>": "下载链接如下:http://10.2.68.104/tools/bytedance/eeprom/",
    "<X7450A0服务器售后培训文件网盘链接>": "网盘下载:https://pan.baidu.com/s/1tZJIf_IeQLOWsvuOawhslQ?pwd=kgf1;提取码:kgf1",
    "<福昕阅读器补丁链接>": "补丁链接: https://pan.baidu.com/s/1QJQ1kHRplhhFly-vxJquFQ,提取码: aupx1",
    "<W330-H35A_22DB4/W3335HA安装win7网盘链接>": "硬盘链接: https://pan.baidu.com/s/1fDdGPH15mXiw0J-fMmLt6Q提取码: k97i",
    "<X680 G55服务器售后培训资料网盘链接>": "云盘连接下载:链接:https://pan.baidu.com/s/1gaok13DvNddtkmk6Q-qLYg?pwd=xyhb提取码:xyhb",
    "<展厅管理员>": "北京-穆淑娟18001053012\n天津-马书跃15720934870\n昆山-关天琪15304169908\n成都-贾小芳18613216313\n重庆-李子艺17347743273\n安阳-郭永军15824623085\n桐乡-李梦瑶18086537055\n青岛-陶祉伊15318733259",
    "<线上预约展厅>": "北京、天津、昆山、成都、重庆、安阳、桐乡、青岛",
    "<马华>": "联系人:马华,电话:13761751980,邮箱:china@pinbang.com",
    "<梁静>": "联系人:梁静,电话:18917566297,邮箱:ing.liang@omaten.com",
    "<徐斌>": "联系人:徐斌,电话:13671166044,邮箱:244898943@qq.com",
    "<俞晓枫>": "联系人:俞晓枫,电话13750869272,邮箱:857233013@qq.com",
    "<刘广鹏>": "联系人:刘广鹏,电话13321992411,邮箱:liuguangpeng@pinbang.com",
    "<马英伟>": "联系人:马英伟,电话:13260021849,邮箱:13260021849@163.com",
    "<杨洋>": "联系人:杨洋,电话15801203938,邮箱bing523888@163.com",
    "<展会合规要求>": "1.展品内容:展品内容需符合公司合规要求,展示内容需经过法务合规审查。\n2.文字材料内容:文字材料内容需符合公司合规要求,展示内容需经过法务合规审查。\n3.展品标签:展品标签内容需符合公司合规要求。\n4.礼品内容:礼品内容需符合公司合规要求。\n5.视频内容:视频内容需符合公司合规要求,展示内容需经过法务合规审查。\n6.讲解词内容:讲解词内容需符合公司合规要求,展示内容需经过法务合规审查。\n7.现场发放材料:现场发放的材料内容需符合公司合规要求。\n8.展示内容:整体展示内容需要经过法务合规审查。",
    "<展会质量>": "1.了解展会的组织者背景、往届展会的评价以及提供的服务支持,确保展会的专业性和高效性。\n.了解展会的规模、参观人数、行业影响力等因素,以判断展会是否能够提供足够的曝光度和商机。\n3.关注同行业其他竞争对手是否参展,以及他们的展位布置、展示内容等信息,以便制定自己的参展策略。\n4.展会的日期是否与公司的其他重要活动冲突,以及举办地点是否便于客户和合作伙伴的参观。\n5.销售部门会询问展会方提供的宣传渠道和推广服务,以及如何利用这些资源来提升公司及产品的知名度。\n6.记录展会期间的重要领导参观、商机线索、合作洽谈、公司拜访预约等信息,跟进后续商业机会。",
    "<摊位费规则>": "根据展位面积大小,支付相应费用。\n展位照明费:支付展位内的照明服务费。\n展位保安费:支付展位内的保安服务费。\n展位网络使用费:支付展位内网络使用的费用。\n展位电源使用费:支付展位内电源使用的费用。",
    "<展会主题要求>": "展会主题的确定需要符合公司产品和服务业务范围,以确保能够吸引目标客户群体。因此,确定展会主题时,需要考虑以下因素:\n专业性:展会的主题应确保专业性,符合行业特点和目标客户的需求。\n目标客户群体:展会的主题定位应考虑目标客户群体,确保能够吸引他们的兴趣。\n业务重点:展会的主题应突出公司的业务重点和优势,以便更好地推广公司的核心产品或服务。\n行业影响力:展会的主题定位需要考虑行业的最新发展趋势,以凸显公司的行业地位和影响力。\n往届展会经验:可以参考往届展会的主题定位,总结经验教训,以确定本届展会的主题。\n市场部意见:在确定展会主题时,应听取市场部的意见,确保主题符合公司的整体市场战略。\n领导意见:还需要考虑公司领导的意见,以确保展会主题符合公司的战略发展方向。",
    "<办理展商证注意事项>": "人员范围:除公司领导和同事需要办理展商证外,展会运营工作人员也需要办理。\n提前准备:展商证的办理需要提前进行,以确保摄影师、摄像师等工作人员可以提前入场进行布置。\n办理流程:需要熟悉展商证的办理流程,准备好相关材料,如身份证件等。\n数量需求:需要评估所需的展商证数量,避免数量不足或过多的情况。\n有效期限:展商证的有效期限需要注意,避免在展期内过期。\n存放安全:办理完的展商证需要妥善保管,避免丢失或被他人使用。\n使用规范:使用展商证时需要遵守展会相关规定,不得转让给他人使用。\n回收处理:展会结束后,需要及时回收展商证,避免泄露相关信息。",
    "<项目单价要求>": "请注意:无论是否年框供应商,项目单价都不得超过采购部制定的“2024常见活动项目标准单价”,此报价仅可内部使用,严禁外传",
    "<年框供应商细节表格>": "在线表格https://kdocs.cn/l/camwZE63frNw",
    "<年框供应商流程>": "1.需求方发出项目需求(大型项目需比稿)\n2.外协根据项目需求报价,提供需求方“预算单”(按照基准单价报价,如有发现不按单价情况,解除合同不再使用)\n3.需求方确认预算价格,并提交OA市场活动申请\n4.外协现场执行\n5.需求方现场验收,并签署验收单(物料、设备、人员等实际清单)\n6.外协出具结算单(金额与验收单一致,加盖公章)、结案报告、年框合同,作为报销凭证\n7.外协请需求方项目负责人填写“满意度调研表”(如无,会影响年度评价)\n8.需求方项目经理提交报销",
    "<市场活动结案报告内容>": "1.项目简介(时间、地点、参与人数等);2.最终会议安排;3.活动各环节现场图片;4.费用相关证明材料(如执行人员、物料照片);5.活动成效汇总;6.活动原始照片/视频网络链接",
    "<展板设计选择>": "1.去OA文档中心查找一些设计模板; 2. 联系专业的活动服务公司来协助设计",
    "<餐费标准>": "一般地区的餐饮费用规定为不超过300元/人(一顿正餐),特殊地区则为不超过400元/人(一顿正餐),特殊地区的具体规定请参照公司的《差旅费管理制度》",
    "":"",
}

def build_history_messages(prompt, history, system: str = None):
    '''Flatten (user, assistant) history pairs plus the new prompt into chat messages.'''
    history_messages = []
    if system is not None and len(system) > 0:
        history_messages.append({'role': 'system', 'content': system})
    for item in history:
        history_messages.append({'role': 'user', 'content': item[0]})
        history_messages.append({'role': 'assistant', 'content': item[1]})
    history_messages.append({'role': 'user', 'content': prompt})
    return history_messages
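
# A minimal usage sketch (illustrative values, not from the original source):
#   build_history_messages('今天呢?', [('你好', '有什么可以帮您?')], system='你是客服助手')
#   -> [{'role': 'system', 'content': '你是客服助手'},
#       {'role': 'user', 'content': '你好'},
#       {'role': 'assistant', 'content': '有什么可以帮您?'},
#       {'role': 'user', 'content': '今天呢?'}]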


def substitution(output_text):
    # Replace a placeholder token such as <官网> with its canned reply from COMMON.
    if isinstance(output_text, list):
        output_text = output_text[0]

    parts = re.split(r'.*(<.*>).*', output_text, flags=re.M | re.I)
    if len(parts) > 1:
        obj = parts[1]
        replace_str = COMMON.get(obj)
        if replace_str:
            output_text = output_text.replace(obj, replace_str)
            logger.info(f"{obj} replaced with {replace_str}, after: {output_text}")
    return output_text
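
# A minimal sketch of the substitution behaviour (illustrative input; the mapping
# comes from COMMON above):
#   substitution('请访问<官网>查询保修政策')
#   -> '请访问https://www.sugon.com/after_sale/policy?sh=1查询保修政策'
# Text without a <...> placeholder passes through unchanged.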


class LLMInference:

    def __init__(self,
                 model,
                 tokenizer,
                 device: str = 'cuda',
                 ) -> None:

        self.device = device
        self.model = model
        self.tokenizer = tokenizer

    def generate_response(self, prompt, history=[]):
        '''Run one full chat turn; returns (output_text, error).'''
        output_text = ''
        error = ''
        time_start = time.time()

        try:
            messages = build_history_messages(prompt, history)
            output_text = self.chat(messages)

        except Exception as e:
            error = str(e)
            logger.error(error)

        time_finish = time.time()

        logger.debug('output_text:{} \ntimecost {} '.format(output_text,
            time_finish - time_start))

        return output_text, error
    def chat(self, messages, history=[]):
        '''Single-turn Q&A (history is accepted for interface compatibility but unused).'''
        logger.info("****************** in chat ******************")
        try:
            # transformers
            input_ids = self.tokenizer.apply_chat_template(
                messages, add_generation_prompt=True, return_tensors="pt").to(self.device)
            outputs = self.model.generate(
                input_ids,
                max_new_tokens=1024,
            )

            # Keep only the newly generated tokens, then decode.
            response = outputs[0][input_ids.shape[-1]:]
            generated_text = self.tokenizer.decode(response, skip_special_tokens=True)

            output_text = substitution(generated_text)
            logger.info(f"using transformers, output_text {output_text}")
            return output_text
        except Exception as e:
            logger.error(f"chat inference failed, {e}")
    def chat_stream(self, messages, history=[]):
        '''Streaming chat service.'''
        # HuggingFace
        logger.info("****************** in chat stream *****************")
        current_length = 0

        logger.info(f"stream_chat messages {messages}")
        # stream_chat is a model-specific streaming API (e.g. ChatGLM); with
        # return_past_key_values=True it yields (response, history, past_key_values).
        for response, _, _ in self.model.stream_chat(self.tokenizer, messages, history=history,
                                                     max_length=1024,
                                                     past_key_values=None,
                                                     return_past_key_values=True):
            output_text = response[current_length:]
            output_text = substitution(output_text)
            logger.info(f"using transformers chat_stream, Prompt: {messages!r}, Generated text: {output_text!r}")

            yield output_text
            current_length = len(response)


def init_model(model_path, use_vllm=False, tensor_parallel_size=1):
    # Initialise the tokenizer plus either a vLLM async engine or a HuggingFace model.
    logger.info("Starting to initialise the LLM")
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    if use_vllm:
        from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
        sampling_params = SamplingParams(temperature=1,
                                         top_p=0.95,
                                         max_tokens=1024,
                                         early_stopping=False,
                                         stop_token_ids=[tokenizer.eos_token_id]
                                         )
        # Basic vLLM engine configuration
        args = AsyncEngineArgs(model_path)
        args.worker_use_ray = False
        args.engine_use_ray = False
        args.tokenizer = model_path
        args.tensor_parallel_size = tensor_parallel_size
        args.trust_remote_code = True
        args.enforce_eager = True
        args.max_model_len = 1024
        args.dtype = 'float16'
        # Load the model as an async engine
        engine = AsyncLLMEngine.from_engine_args(args)
        return engine, tokenizer, sampling_params
    else:
        # huggingface
        model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).half().cuda().eval()
        return model, tokenizer, None


def hf_inference(bind_port, model, tokenizer, stream_chat):
    '''Start the HuggingFace web server: accept HTTP requests and answer them with the local LLM.'''
    llm_infer = LLMInference(model, tokenizer)

    async def inference(request):
        start = time.time()
        input_json = await request.json()

        prompt = input_json['query']
        history = input_json.get('history', [])

        messages = [{"role": "user", "content": prompt}]
        logger.info("****************** use transformers ******************")

        if stream_chat:
            # chat_stream is a generator; run it in a worker thread and join the
            # deltas (true incremental streaming would need web.StreamResponse).
            text = await asyncio.to_thread(
                lambda: ''.join(llm_infer.chat_stream(messages=messages, history=history)))
        else:
            text = await asyncio.to_thread(llm_infer.chat, messages=messages, history=history)

        logger.debug('Question: {} Answer: {} \ntimecost {} '.format(prompt, text, time.time() - start))
        return web.json_response({'text': text})

    app = web.Application()
    app.add_routes([web.post('/hf_inference', inference)])
    web.run_app(app, host='0.0.0.0', port=bind_port)
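
# A request sketch for the HF endpoint (bind_port comes from config.ini; payload
# shape assumed from the handler above):
#   curl -X POST http://127.0.0.1:<bind_port>/hf_inference \
#        -H 'Content-Type: application/json' \
#        -d '{"query": "写一首诗", "history": []}'
#   -> {"text": "..."}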


def vllm_inference(bind_port, model, tokenizer, sampling_params, stream_chat):
    '''Start the vLLM web server: accept HTTP requests and answer them with the local LLM engine.'''
    import uuid

    async def inference(request):
        start = time.time()
        input_json = await request.json()

        prompt = input_json['query']
        # history = input_json['history']

        messages = [{"role": "user", "content": prompt}]

        logger.info("****************** use vllm ******************")
        # Render the conversation through the chat template
        input_text = tokenizer.apply_chat_template(
                            messages, tokenize=False, add_generation_prompt=True)
        logger.info(f"The input_text is {input_text}")
        assert model is not None
        request_id = str(uuid.uuid4().hex)
        results_generator = model.generate(input_text, sampling_params=sampling_params, request_id=request_id)

        # Streaming case: chunked aiohttp response, one NUL-terminated JSON object per update.
        if stream_chat:
            logger.info("****************** in chat stream *****************")
            resp = web.StreamResponse()
            await resp.prepare(request)
            async for request_output in results_generator:
                text_outputs = [output.text for output in request_output.outputs]
                ret = {"text": text_outputs}
                await resp.write((json.dumps(ret, ensure_ascii=False) + "\0").encode("utf-8"))
            await resp.write_eof()
            return resp

        # Non-streaming case
        logger.info("****************** in chat ******************")
        final_output = None
        async for request_output in results_generator:
            # if await request.is_disconnected():
            #     # Abort the request if the client disconnects.
            #     await engine.abort(request_id)
            #     return Response(status_code=499)
            final_output = request_output

        assert final_output is not None

        text = [output.text for output in final_output.outputs]
        end = time.time()
        output_text = substitution(text)
        logger.debug('Question: {} Answer: {} \ntimecost {} '.format(prompt, text, end - start))
        return web.json_response({'text': output_text})

    app = web.Application()
    app.add_routes([web.post('/vllm_inference', inference)])
    web.run_app(app, host='0.0.0.0', port=bind_port)
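
# A request sketch for the vLLM endpoint (payload shape assumed from the handler
# above; history is currently ignored):
#   curl -X POST http://127.0.0.1:<bind_port>/vllm_inference \
#        -H 'Content-Type: application/json' \
#        -d '{"query": "写一首诗"}'
# Non-streaming -> {"text": "..."}; streaming -> NUL-terminated JSON chunks of partial text.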


def infer_test(args):
    config = configparser.ConfigParser()
    config.read(args.config_path)

    model_path = config['llm']['local_llm_path']
    use_vllm = config.getboolean('llm', 'use_vllm')
    tensor_parallel_size = config.getint('llm', 'tensor_parallel_size')
    stream_chat = config.getboolean('llm', 'stream_chat')
    logger.info(f"Get params: model_path {model_path}, use_vllm {use_vllm}, tensor_parallel_size {tensor_parallel_size}, stream_chat {stream_chat}")

    model, tokenizer, _ = init_model(model_path, use_vllm, tensor_parallel_size)
    llm_infer = LLMInference(model, tokenizer)

    time_first = time.time()
    output_text = llm_infer.chat([{"role": "user", "content": args.query}])
    time_second = time.time()
    logger.debug('Question: {} Answer: {} \ntimecost {} '.format(
        args.query, output_text, time_second - time_first))


def set_envs(dcu_ids):
    try:
        os.environ["CUDA_VISIBLE_DEVICES"] = dcu_ids
        logger.info(f"Set environment variable CUDA_VISIBLE_DEVICES to {dcu_ids}")
    except Exception as e:
        logger.error(f"Failed to set CUDA_VISIBLE_DEVICES to {dcu_ids!r}: {e}")
        raise ValueError(str(e))


def parse_args():
    '''Command-line arguments.'''
    parser = argparse.ArgumentParser(
        description='Local LLM inference service.')
    parser.add_argument(
        '--config_path',
        default='../config.ini',
        help='Path to the config file.')
    parser.add_argument(
        '--query',
        default='写一首诗',
        help='Question to ask.')
    parser.add_argument(
        '--DCU_ID',
        type=str,
        default='6',
        help='DCU card IDs, comma-separated, e.g. "0,1,2"')
    args = parser.parse_args()
    return args
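
# A config.ini sketch matching the keys read in main() below (values are
# illustrative assumptions, not from the original source):
#   [default]
#   bind_port = 8080
#
#   [llm]
#   local_llm_path = /path/to/local/model
#   use_vllm = false
#   tensor_parallel_size = 1
#   stream_chat = false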


def main():
    args = parse_args()
    set_envs(args.DCU_ID)
    # configs
    config = configparser.ConfigParser()
    config.read(args.config_path)
    bind_port = int(config['default']['bind_port'])
    model_path = config['llm']['local_llm_path']
    use_vllm = config.getboolean('llm', 'use_vllm')
    tensor_parallel_size = config.getint('llm', 'tensor_parallel_size')
    stream_chat = config.getboolean('llm', 'stream_chat')
    logger.info(f"Get params: model_path {model_path}, use_vllm {use_vllm}, tensor_parallel_size {tensor_parallel_size}, stream_chat {stream_chat}")

    model, tokenizer, sampling_params = init_model(model_path, use_vllm, tensor_parallel_size)
    if use_vllm:
        vllm_inference(bind_port, model, tokenizer, sampling_params, stream_chat)
    else:
        hf_inference(bind_port, model, tokenizer, stream_chat)
    # infer_test(args)


if __name__ == '__main__':
    main()
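
# Launch sketch (arguments as defined in parse_args above):
#   python inferencer.py --config_path ../config.ini --DCU_ID "0,1"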