// Bind InferEngine::forward. Takes the engine's Input aggregate and returns
// its Output; the lambda simply delegates to the C++ method so pybind11 can
// resolve the overload unambiguously.
.def(
    "forward",
    [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output {
        return self.forward(input);
    },
    "Run inference on all ranks with arbitrary arguments")
"forward",[](InferEngine&self,constInferEngine::Input&input)->InferEngine::Output{returnself.forward(input);},"Run inference on all ranks with arbitrary arguments")
"Load a parameter tensor into all workers (each worker picks its shard)")
// Bind InferEngine::state_dict. The C++ side returns one state dict per
// tensor-parallel rank; convert each to a Python dict of name -> Tensor and
// collect them into a Python list (one entry per rank).
.def("state_dict", [](InferEngine &self) {
    py::list state_dict_tp_all;
    for (const auto &state_dict_tp : self.state_dict()) {
        py::dict result;
        // Wrap each parameter in infinicore::Tensor so Python receives the
        // project's tensor type rather than the raw parameter handle.
        for (const auto &[name, param] : state_dict_tp) {
            result[py::cast(name)] = infinicore::Tensor(param);
        }
        state_dict_tp_all.append(result);
    }
    return state_dict_tp_all;
})
})
.def("__repr__",[](constInferEngine&self){
.def("forward",[](InferEngine&self,constInferEngine::Input&input)->InferEngine::Output{returnself.forward(input);},"Run inference on all ranks with arbitrary arguments")
# Ideally this would be solved by upgrading transformers. However, doing so causes a version mismatch between transformers and the MLU PyTorch build on devices with a Phytium CPU, so a separate code branch is used temporarily.