"Load a parameter tensor into all workers (each worker picks its shard)")
.def("state_dict",[](InferEngine&self){
// Return a dictionary containing references to the whole state of the module.
autostate_dict=self.state_dict();
py::liststate_dict_tp_all;
for(constauto&state_dict_tp:self.state_dict()){
py::dictresult;
for(constauto&[name,param]:state_dict){
for(constauto&[name,param]:state_dict_tp){
result[py::cast(name)]=infinicore::Tensor(param);
}
returnresult;
state_dict_tp_all.append(result);
}
returnstate_dict_tp_all;
})
.def("generate",[](InferEngine&self,py::objectinput_ids,py::objectposition_ids)->infinicore::Tensor{returnself.generate(input_ids.cast<infinicore::Tensor>(),position_ids.cast<infinicore::Tensor>());},"Run inference on all ranks with arbitrary arguments")
.def("reset_cache",&InferEngine::reset_cache,
py::arg("pos")=0,py::arg("async")=false,
"Reset the internal cache in all workers to a specific position (clears state between generations). "
.def("reset_cache",&InferEngine::reset_cache,py::arg("pos")=0,py::arg("async")=false,"Reset the internal cache in all workers to a specific position (clears state between generations). "
"By default, this is synchronous. If async=True, this becomes asynchronous (unstable - use with caution).");