# model parallel + pipeline parallel demo
import time

import oneflow as flow
from omegaconf import DictConfig

from libai.utils import distributed as dist
from projects.GLM.configs.glm_inference import cfg
from projects.GLM.modeling_glm import GLMForConditionalGeneration
from projects.GLM.tokenizer.glm_tokenizer import GLMChineseTokenzier
from projects.GLM.utils.glm_loader import GLMLoaderHuggerFace

# Simply configure the parallelism scheme: 2-way tensor parallel x 2-way pipeline parallel.
parallel_config = DictConfig(
    dict(
        data_parallel_size=1,
        tensor_parallel_size=2,
        pipeline_parallel_size=2,
        pipeline_num_layers=2 * 24,
    )
)
dist.setup_dist_util(parallel_config)

tokenizer = GLMChineseTokenzier.from_pretrained("glm-10b-chinese")
sbp = dist.get_nd_sbp([flow.sbp.broadcast, flow.sbp.broadcast])
placement = dist.get_layer_placement(0)

loader = GLMLoaderHuggerFace(
    GLMForConditionalGeneration,
    cfg,
    "glm-10b-chinese",
    embedding_dropout_prob=0,
    attention_dropout_prob=0,
    output_dropout_prob=0,
)
if dist.is_main_process():
    print("Please wait, loading the model...")
model = loader.load()

question = ""
while True:
    # Read the question on the main process, then broadcast it to the other ranks.
    if dist.is_main_process():
        print("Input:")
        question = input("> ")
    else:
        question = None
    question = dist.broadcast_py_object(question, src=0)
    dist.synchronize()

    if question.lower() == "quit":
        break

    # "回答:" ("Answer:") and [gMASK] follow the GLM generation prompt format.
    input_ids = tokenizer.encode(
        [question + " 回答: [gMASK]"],
        return_tensors="of",
    )
    inputs = {"input_ids": input_ids, "attention_mask": flow.ones(input_ids.size())}
    inputs = tokenizer.build_inputs_for_generation(inputs, max_gen_length=128)

    if dist.is_main_process():
        print("Generating...")
    # start_t = time.time()
    outputs = model.generate(
        inputs=inputs["input_ids"].to_global(sbp=sbp, placement=placement),
        position_ids=inputs["position_ids"].to_global(sbp=sbp, placement=placement),
        generation_attention_mask=inputs["generation_attention_mask"].to_global(
            sbp=sbp, placement=placement
        ),
        max_length=128,
    )
    # end_t = time.time()
    # if dist.is_main_process():
    #     print("model.generate took %s s" % (end_t - start_t))

    res = tokenizer.decode(outputs[0])
    if dist.is_main_process():
        print("> " + res)

if dist.is_main_process():
    print("> Goodbye")
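
# A minimal launch sketch, assuming the script is saved as demo.py (a placeholder
# name) and that OneFlow's distributed launcher module is available in your install.
# With tensor_parallel_size=2 and pipeline_parallel_size=2, the demo expects
# 2 * 2 = 4 processes, e.g. one per GPU on a single 4-GPU node:
#
#   python3 -m oneflow.distributed.launch --nproc_per_node 4 demo.py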