Unverified Commit 5b5ff780 authored by Jiacheng Huang's avatar Jiacheng Huang Committed by GitHub
Browse files

issue/135: 统一 `InferEngine::forward` 和 `Model::forward` 接口 (Unify the `InferEngine::forward` and `Model::forward` interfaces)

parent 3a1c0a28
...@@ -73,20 +73,21 @@ std::vector<std::unordered_map<std::string, infinicore::nn::Parameter>> InferEng ...@@ -73,20 +73,21 @@ std::vector<std::unordered_map<std::string, infinicore::nn::Parameter>> InferEng
} }
//------------------------------------------------------ //------------------------------------------------------
// generate // forward
//------------------------------------------------------ //------------------------------------------------------
infinicore::Tensor InferEngine::generate(const infinicore::Tensor &input_ids, InferEngine::Output InferEngine::forward(const InferEngine::Input &input) {
const infinicore::Tensor &position_ids) { const auto &[input_ids, position_ids] = input;
// Trigger each worker to run inference // Trigger each worker to run inference
for (auto &worker : workers_) { for (auto &worker : workers_) {
worker->run(std::vector<std::any>({input_ids, position_ids})); worker->run({input_ids, position_ids});
} }
// Wait for all workers // Wait for all workers
for (auto &worker : workers_) { for (auto &worker : workers_) {
worker->wait(); worker->wait();
} }
return workers_[0]->get_output(); return {workers_[0]->get_output().logits};
} }
//------------------------------------------------------ //------------------------------------------------------
......
...@@ -12,6 +12,16 @@ namespace infinilm::engine { ...@@ -12,6 +12,16 @@ namespace infinilm::engine {
class InferEngine { class InferEngine {
public: public:
struct Input {
infinicore::Tensor input_ids;
infinicore::Tensor position_ids;
};
struct Output {
infinicore::Tensor logits;
};
// Updated constructor: accept CacheConfig instead of CacheType // Updated constructor: accept CacheConfig instead of CacheType
InferEngine( InferEngine(
const InfinilmModel::Config &config, const InfinilmModel::Config &config,
...@@ -26,8 +36,7 @@ public: ...@@ -26,8 +36,7 @@ public:
std::vector<std::unordered_map<std::string, infinicore::nn::Parameter>> state_dict(); std::vector<std::unordered_map<std::string, infinicore::nn::Parameter>> state_dict();
// Run a single forward pass on all workers and return the outputs from all ranks // Run a single forward pass on all workers and return the outputs from all ranks
infinicore::Tensor generate(const infinicore::Tensor &input_ids, Output forward(const Input &input);
const infinicore::Tensor &position_ids);
// Reset the internal cache pos in all workers (clears state between generations) // Reset the internal cache pos in all workers (clears state between generations)
void reset_cache(size_t pos = 0); void reset_cache(size_t pos = 0);
......
...@@ -86,7 +86,7 @@ std::unordered_map<std::string, infinicore::nn::Parameter> RankWorker::state_dic ...@@ -86,7 +86,7 @@ std::unordered_map<std::string, infinicore::nn::Parameter> RankWorker::state_dic
//------------------------------------------------------ //------------------------------------------------------
// run -- asynchronous // run -- asynchronous
//------------------------------------------------------ //------------------------------------------------------
void RankWorker::run(const std::vector<std::any> &args) { void RankWorker::run(const InfinilmModel::Input &args) {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
if (should_exit_) { if (should_exit_) {
...@@ -164,7 +164,7 @@ void RankWorker::close() { ...@@ -164,7 +164,7 @@ void RankWorker::close() {
//------------------------------------------------------ //------------------------------------------------------
// get_output (thread safe) // get_output (thread safe)
//------------------------------------------------------ //------------------------------------------------------
infinicore::Tensor RankWorker::get_output() { InfinilmModel::Output RankWorker::get_output() {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
return output_; return output_;
} }
...@@ -194,7 +194,7 @@ void RankWorker::thread_loop() { ...@@ -194,7 +194,7 @@ void RankWorker::thread_loop() {
Command local_cmd = Command::INIT; Command local_cmd = Command::INIT;
std::string local_param_name; std::string local_param_name;
infinicore::Tensor local_param; infinicore::Tensor local_param;
std::vector<std::any> local_args; InfinilmModel::Input local_args;
size_t local_reset_pos = 0; size_t local_reset_pos = 0;
cache::CacheConfig local_reset_config; cache::CacheConfig local_reset_config;
......
...@@ -36,7 +36,7 @@ public: ...@@ -36,7 +36,7 @@ public:
std::unordered_map<std::string, infinicore::nn::Parameter> state_dict(); std::unordered_map<std::string, infinicore::nn::Parameter> state_dict();
// Submit a run (forward) job. // Submit a run (forward) job.
void run(const std::vector<std::any> &args); void run(const InfinilmModel::Input &args);
// Reset the internal cache in the model (clears state between generations) // Reset the internal cache in the model (clears state between generations)
void reset_cache(size_t pos = 0); void reset_cache(size_t pos = 0);
...@@ -51,7 +51,7 @@ public: ...@@ -51,7 +51,7 @@ public:
void close(); void close();
// Thread-safe accessor for last output produced by RUN. // Thread-safe accessor for last output produced by RUN.
infinicore::Tensor get_output(); InfinilmModel::Output get_output();
std::string info() const; std::string info() const;
...@@ -77,12 +77,12 @@ private: ...@@ -77,12 +77,12 @@ private:
// Task payloads (protected by mutex) // Task payloads (protected by mutex)
std::string pending_param_name_; std::string pending_param_name_;
infinicore::Tensor pending_param_; infinicore::Tensor pending_param_;
std::vector<std::any> pending_args_; InfinilmModel::Input pending_args_;
size_t pending_reset_pos_ = 0; size_t pending_reset_pos_ = 0;
cache::CacheConfig pending_cache_config_; cache::CacheConfig pending_cache_config_;
// Output (protected by mutex) // Output (protected by mutex)
infinicore::Tensor output_; InfinilmModel::Output output_;
// Thread sync // Thread sync
std::thread thread_; std::thread thread_;
......
...@@ -15,8 +15,24 @@ public: ...@@ -15,8 +15,24 @@ public:
virtual ~Config() = default; virtual ~Config() = default;
}; };
/// Aggregated inputs for a single `forward` call.
struct Input {
/// Token IDs tensor of shape `[batch, seq_len]`.
infinicore::Tensor input_ids;
/// Position IDs tensor of shape `[batch, seq_len]` or `[seq_len]`.
infinicore::Tensor position_ids;
/// Optional model-level KV cache for incremental decoding. Defaults to `nullptr`.
/// NOTE(review): type-erased, non-owning `void *` — the concrete cache type and
/// its lifetime/ownership are presumably defined by the model implementation; confirm.
void *kv_cache = nullptr;
};
/// Result of a single `forward` call.
struct Output {
/// Output tensor of shape [batch, seq_len, vocab_size].
infinicore::Tensor logits;
};
virtual ~InfinilmModel() = default; virtual ~InfinilmModel() = default;
virtual infinicore::Tensor forward(std::vector<std::any>) const = 0; virtual Output forward(const Input &input) const = 0;
// Optional: reset cache; default no-op for models without cache // Optional: reset cache; default no-op for models without cache
virtual void reset_cache(size_t pos = 0) {} virtual void reset_cache(size_t pos = 0) {}
virtual void reset_cache(const cache::CacheConfig &new_config, size_t pos = 0) = 0; virtual void reset_cache(const cache::CacheConfig &new_config, size_t pos = 0) = 0;
......
...@@ -24,9 +24,9 @@ LlamaForCausalLM::LlamaForCausalLM(const LlamaConfig &config, ...@@ -24,9 +24,9 @@ LlamaForCausalLM::LlamaForCausalLM(const LlamaConfig &config,
dtype, device); dtype, device);
} }
infinicore::Tensor LlamaForCausalLM::forward(const infinicore::Tensor &input_ids, LlamaForCausalLM::Output LlamaForCausalLM::forward(const Input &input) const {
const infinicore::Tensor &position_ids, const auto &[input_ids, position_ids, kv_cache] = input;
void *kv_cache) const {
// 1. Forward through base model to get hidden states // 1. Forward through base model to get hidden states
auto position_ids_device = position_ids->to(device_); auto position_ids_device = position_ids->to(device_);
auto hidden_states = model_->forward(input_ids, position_ids_device, kv_cache); auto hidden_states = model_->forward(input_ids, position_ids_device, kv_cache);
...@@ -34,25 +34,7 @@ infinicore::Tensor LlamaForCausalLM::forward(const infinicore::Tensor &input_ids ...@@ -34,25 +34,7 @@ infinicore::Tensor LlamaForCausalLM::forward(const infinicore::Tensor &input_ids
// 2. Apply language modeling head to get logits // 2. Apply language modeling head to get logits
auto logits = lm_head_->forward(hidden_states); auto logits = lm_head_->forward(hidden_states);
return logits; return {logits};
}
infinicore::Tensor LlamaForCausalLM::forward(std::vector<std::any> args) const {
if (args.size() < 2) {
throw std::invalid_argument("LlamaForCausalLM::forward requires at least 2 arguments: input_ids and position_ids");
}
// Extract input tensors from args
const auto &input_ids = std::any_cast<const infinicore::Tensor &>(args[0]);
const auto &position_ids = std::any_cast<const infinicore::Tensor &>(args[1]);
// Optional KV caches
std::vector<void *> *kv_caches = nullptr;
if (args.size() >= 3) {
kv_caches = std::any_cast<std::vector<void *> *>(args[2]);
}
return forward(input_ids, position_ids, kv_caches);
} }
void LlamaForCausalLM::reset_cache(size_t pos) { void LlamaForCausalLM::reset_cache(size_t pos) {
......
...@@ -37,16 +37,10 @@ public: ...@@ -37,16 +37,10 @@ public:
/** /**
* @brief Forward pass: compute language modeling logits * @brief Forward pass: compute language modeling logits
* *
* @param input_ids Token IDs tensor of shape [batch, seq_len] * @param input Encapsulated input tensors and other parameters
* @param position_ids Position IDs tensor of shape [batch, seq_len] or [seq_len] * @return Output structure containing the result
* @param kv_cache Optional model-level KV cache for incremental decoding
* @return Logits tensor of shape [batch, seq_len, vocab_size]
*/ */
infinicore::Tensor forward(const infinicore::Tensor &input_ids, Output forward(const Input &input) const;
const infinicore::Tensor &position_ids,
void *kv_cache = nullptr) const;
infinicore::Tensor forward(std::vector<std::any> args) const override;
// Reset internal cache position // Reset internal cache position
void reset_cache(size_t pos = 0) override; void reset_cache(size_t pos = 0) override;
......
...@@ -84,7 +84,8 @@ inline void bind_dist_config(py::module &m) { ...@@ -84,7 +84,8 @@ inline void bind_dist_config(py::module &m) {
namespace infinilm::engine { namespace infinilm::engine {
inline void bind_infer_engine(py::module &m) { inline void bind_infer_engine(py::module &m) {
py::class_<InferEngine, std::shared_ptr<InferEngine>>(m, "InferEngine") py::class_<InferEngine, std::shared_ptr<InferEngine>> infer_engine(m, "InferEngine");
infer_engine
.def(py::init([](const InfinilmModel::Config &cfg, .def(py::init([](const InfinilmModel::Config &cfg,
const infinilm::engine::distributed::DistConfig &dist, const infinilm::engine::distributed::DistConfig &dist,
infinicore::Device::Type dev, infinicore::Device::Type dev,
...@@ -109,17 +110,22 @@ inline void bind_infer_engine(py::module &m) { ...@@ -109,17 +110,22 @@ inline void bind_infer_engine(py::module &m) {
} }
return state_dict_tp_all; return state_dict_tp_all;
}) })
.def( .def("forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments")
"generate", [](InferEngine &self, py::object input_ids, py::object position_ids) -> infinicore::Tensor {
return self.generate(input_ids.cast<infinicore::Tensor>(), position_ids.cast<infinicore::Tensor>());
},
"Run inference on all ranks with arbitrary arguments")
.def("reset_cache", py::overload_cast<size_t>(&InferEngine::reset_cache), py::arg("pos") = 0, "Reset the internal cache in all workers to a specific position") .def("reset_cache", py::overload_cast<size_t>(&InferEngine::reset_cache), py::arg("pos") = 0, "Reset the internal cache in all workers to a specific position")
.def("reset_cache", py::overload_cast<const cache::CacheConfig &, size_t>(&InferEngine::reset_cache), py::arg("cache_config"), py::arg("pos") = 0, "Reset cache with new KV configuration") .def("reset_cache", py::overload_cast<const cache::CacheConfig &, size_t>(&InferEngine::reset_cache), py::arg("cache_config"), py::arg("pos") = 0, "Reset cache with new KV configuration")
.def("get_cache_config", &InferEngine::get_cache_config, "Get current KV configuration") .def("get_cache_config", &InferEngine::get_cache_config, "Get current KV configuration")
.def("__repr__", [](const InferEngine &self) { .def("__repr__", [](const InferEngine &self) {
return "<InferEngine: " + std::string(self.get_dist_config()) + ">"; return "<InferEngine: " + std::string(self.get_dist_config()) + ">";
}); });
py::class_<InferEngine::Input>(infer_engine, "Input")
.def(py::init([](const infinicore::Tensor &input_ids, const infinicore::Tensor &position_ids) {
return new InferEngine::Input{input_ids, position_ids};
}),
py::arg("input_ids"), py::arg("position_ids"));
py::class_<InferEngine::Output>(infer_engine, "Output")
.def_readwrite("logits", &InferEngine::Output::logits, "Output tensor");
} }
} // namespace infinilm::engine } // namespace infinilm::engine
...@@ -123,81 +123,6 @@ inline void bind_llama(py::module &m) { ...@@ -123,81 +123,6 @@ inline void bind_llama(py::module &m) {
return dir_list; }); return dir_list; });
// Note: Device is already bound in InfiniCore bindings, so we don't need to bind it here // Note: Device is already bound in InfiniCore bindings, so we don't need to bind it here
// Bind LlamaForCausalLM
py::class_<LlamaForCausalLM, std::shared_ptr<LlamaForCausalLM>>(m, "LlamaForCausalLM")
.def("state_dict", [](const LlamaForCausalLM &model) {
// Return a dictionary containing references to the whole state of the module.
auto state_dict = model.state_dict();
py::dict result;
for (const auto &[name, param] : state_dict) {
result[py::cast(name)] = infinicore::Tensor(param);
}
return result;
})
.def("get_parameter", [](const LlamaForCausalLM &model, const std::string &name) {
// Get actual tensor parameter by name
auto state_dict = model.state_dict();
auto it = state_dict.find(name);
if (it != state_dict.end()) {
// Parameter inherits from Tensor, cast to Tensor for pybind11
const infinicore::Tensor &tensor = it->second;
return tensor;
}
throw std::runtime_error("Parameter '" + name + "' not found in model"); }, py::arg("name"))
.def("load_state_dict", [](LlamaForCausalLM &model, py::dict state_dict) {
// Convert Python dict to C++ state_dict
std::unordered_map<std::string, infinicore::Tensor> cpp_state_dict;
for (auto item : state_dict) {
std::string key = item.first.cast<std::string>();
py::object value = item.second.cast<py::object>();
// Extract InfiniCore tensor from Python object
infinicore::Tensor tensor;
if (py::hasattr(value, "_underlying")) {
tensor = value.attr("_underlying").cast<infinicore::Tensor>();
} else {
tensor = value.cast<infinicore::Tensor>();
}
cpp_state_dict.emplace(key, tensor);
}
model.load_state_dict(cpp_state_dict); }, py::arg("state_dict"))
.def("config", &LlamaForCausalLM::config, py::return_value_policy::reference_internal)
.def("reset_cache", [](const LlamaForCausalLM &model, size_t pos = 0) {
// Reset the internal cache to prevent state from persisting between generations
model.model().reset_cache(pos); }, py::arg("pos") = 0, "Reset the internal cache to a specific position (clears state between generations)")
.def("forward", [](const LlamaForCausalLM &model, py::object input_ids, py::object position_ids, py::object kv_cache = py::none()) {
// Helper to extract C++ tensor from Python InfiniCore tensor
auto get_tensor = [](py::object obj) -> infinicore::Tensor {
// If it's already a Python InfiniCore tensor wrapper, extract underlying
if (py::hasattr(obj, "_underlying")) {
return obj.attr("_underlying").cast<infinicore::Tensor>();
}
// Try direct cast (in case it's already a C++ tensor)
return obj.cast<infinicore::Tensor>();
};
// Extract InfiniCore tensors from Python objects
auto infini_input_ids = get_tensor(input_ids);
auto infini_position_ids = get_tensor(position_ids);
// Handle kv_cache if provided (model-level DynamicCache)
void *kv_cache_ptr = nullptr;
if (!kv_cache.is_none()) {
// Try to extract DynamicCache from Python object
if (py::hasattr(kv_cache, "_underlying")) {
kv_cache_ptr = kv_cache.attr("_underlying").cast<void *>();
} else {
// Try direct cast
try {
kv_cache_ptr = kv_cache.cast<void *>();
} catch (...) {
// If conversion fails, pass nullptr (cache will be ignored)
kv_cache_ptr = nullptr;
}
}
}
return model.forward(infini_input_ids, infini_position_ids, kv_cache_ptr); }, py::arg("input_ids"), py::arg("position_ids"), py::arg("kv_caches") = py::none());
} }
} // namespace infinilm::models::llama } // namespace infinilm::models::llama
...@@ -108,10 +108,9 @@ class LlamaForCausalLM(GenerationMixin): ...@@ -108,10 +108,9 @@ class LlamaForCausalLM(GenerationMixin):
# self._model.forward(input_ids, position_ids, kv_caches) # self._model.forward(input_ids, position_ids, kv_caches)
# ) # )
return infinicore.Tensor( return infinicore.Tensor(
self._model.generate( self._model.forward(
input_ids._underlying, self._model.Input(input_ids._underlying, position_ids._underlying)
position_ids._underlying, ).logits
)
) )
def __call__(self, input_ids, position_ids, *args, **kwargs): def __call__(self, input_ids, position_ids, *args, **kwargs):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment