Unverified Commit 5b5ff780 authored by Jiacheng Huang's avatar Jiacheng Huang Committed by GitHub
Browse files

issue/135: 统一 `InferEngine::forward` 和 `Model::forward` 接口

parent 3a1c0a28
......@@ -73,20 +73,21 @@ std::vector<std::unordered_map<std::string, infinicore::nn::Parameter>> InferEng
}
//------------------------------------------------------
// generate
// forward
//------------------------------------------------------
infinicore::Tensor InferEngine::generate(const infinicore::Tensor &input_ids,
const infinicore::Tensor &position_ids) {
InferEngine::Output InferEngine::forward(const InferEngine::Input &input) {
const auto &[input_ids, position_ids] = input;
// Trigger each worker to run inference
for (auto &worker : workers_) {
worker->run(std::vector<std::any>({input_ids, position_ids}));
worker->run({input_ids, position_ids});
}
// Wait for all workers
for (auto &worker : workers_) {
worker->wait();
}
return workers_[0]->get_output();
return {workers_[0]->get_output().logits};
}
//------------------------------------------------------
......
......@@ -12,6 +12,16 @@ namespace infinilm::engine {
class InferEngine {
public:
/// Bundled inputs for InferEngine::forward.
struct Input {
// Token IDs tensor; presumably shape [batch, seq_len] — mirrors InfinilmModel::Input; confirm.
infinicore::Tensor input_ids;
// Position IDs tensor, forwarded unchanged to each rank's model.
infinicore::Tensor position_ids;
};
/// Bundled outputs of InferEngine::forward.
struct Output {
// Logits tensor taken from worker rank 0.
infinicore::Tensor logits;
};
// Updated constructor: accept CacheConfig instead of CacheType
InferEngine(
const InfinilmModel::Config &config,
......@@ -26,8 +36,7 @@ public:
std::vector<std::unordered_map<std::string, infinicore::nn::Parameter>> state_dict();
// Run a single forward pass on all workers and return the outputs from all ranks
infinicore::Tensor generate(const infinicore::Tensor &input_ids,
const infinicore::Tensor &position_ids);
Output forward(const Input &input);
// Reset the internal cache pos in all workers (clears state between generations)
void reset_cache(size_t pos = 0);
......
......@@ -86,7 +86,7 @@ std::unordered_map<std::string, infinicore::nn::Parameter> RankWorker::state_dic
//------------------------------------------------------
// run -- asynchronous
//------------------------------------------------------
void RankWorker::run(const std::vector<std::any> &args) {
void RankWorker::run(const InfinilmModel::Input &args) {
std::lock_guard<std::mutex> lock(mutex_);
if (should_exit_) {
......@@ -164,7 +164,7 @@ void RankWorker::close() {
//------------------------------------------------------
// get_output (thread safe)
//------------------------------------------------------
infinicore::Tensor RankWorker::get_output() {
InfinilmModel::Output RankWorker::get_output() {
std::lock_guard<std::mutex> lock(mutex_);
return output_;
}
......@@ -194,7 +194,7 @@ void RankWorker::thread_loop() {
Command local_cmd = Command::INIT;
std::string local_param_name;
infinicore::Tensor local_param;
std::vector<std::any> local_args;
InfinilmModel::Input local_args;
size_t local_reset_pos = 0;
cache::CacheConfig local_reset_config;
......
......@@ -36,7 +36,7 @@ public:
std::unordered_map<std::string, infinicore::nn::Parameter> state_dict();
// Submit a run (forward) job.
void run(const std::vector<std::any> &args);
void run(const InfinilmModel::Input &args);
// Reset the internal cache in the model (clears state between generations)
void reset_cache(size_t pos = 0);
......@@ -51,7 +51,7 @@ public:
void close();
// Thread-safe accessor for last output produced by RUN.
infinicore::Tensor get_output();
InfinilmModel::Output get_output();
std::string info() const;
......@@ -77,12 +77,12 @@ private:
// Task payloads (protected by mutex)
std::string pending_param_name_;
infinicore::Tensor pending_param_;
std::vector<std::any> pending_args_;
InfinilmModel::Input pending_args_;
size_t pending_reset_pos_ = 0;
cache::CacheConfig pending_cache_config_;
// Output (protected by mutex)
infinicore::Tensor output_;
InfinilmModel::Output output_;
// Thread sync
std::thread thread_;
......
......@@ -15,8 +15,24 @@ public:
virtual ~Config() = default;
};
/// Bundled inputs for InfinilmModel::forward.
struct Input {
/// Token IDs tensor of shape `[batch, seq_len]`.
infinicore::Tensor input_ids;
/// Position IDs tensor of shape `[batch, seq_len]` or `[seq_len]`.
infinicore::Tensor position_ids;
/// Optional model-level KV cache for incremental decoding. Defaults to `nullptr`.
/// NOTE(review): untyped `void *` — callers and models must agree on the
/// concrete cache type out of band; a typed handle would be safer. Confirm.
void *kv_cache = nullptr;
};
/// Bundled outputs of InfinilmModel::forward.
struct Output {
/// Output tensor of shape [batch, seq_len, vocab_size].
infinicore::Tensor logits;
};
virtual ~InfinilmModel() = default;
virtual infinicore::Tensor forward(std::vector<std::any>) const = 0;
virtual Output forward(const Input &input) const = 0;
// Optional: reset cache; default no-op for models without cache
virtual void reset_cache(size_t pos = 0) {}
virtual void reset_cache(const cache::CacheConfig &new_config, size_t pos = 0) = 0;
......
......@@ -24,9 +24,9 @@ LlamaForCausalLM::LlamaForCausalLM(const LlamaConfig &config,
dtype, device);
}
infinicore::Tensor LlamaForCausalLM::forward(const infinicore::Tensor &input_ids,
const infinicore::Tensor &position_ids,
void *kv_cache) const {
LlamaForCausalLM::Output LlamaForCausalLM::forward(const Input &input) const {
const auto &[input_ids, position_ids, kv_cache] = input;
// 1. Forward through base model to get hidden states
auto position_ids_device = position_ids->to(device_);
auto hidden_states = model_->forward(input_ids, position_ids_device, kv_cache);
......@@ -34,25 +34,7 @@ infinicore::Tensor LlamaForCausalLM::forward(const infinicore::Tensor &input_ids
// 2. Apply language modeling head to get logits
auto logits = lm_head_->forward(hidden_states);
return logits;
}
infinicore::Tensor LlamaForCausalLM::forward(std::vector<std::any> args) const {
if (args.size() < 2) {
throw std::invalid_argument("LlamaForCausalLM::forward requires at least 2 arguments: input_ids and position_ids");
}
// Extract input tensors from args
const auto &input_ids = std::any_cast<const infinicore::Tensor &>(args[0]);
const auto &position_ids = std::any_cast<const infinicore::Tensor &>(args[1]);
// Optional KV caches
std::vector<void *> *kv_caches = nullptr;
if (args.size() >= 3) {
kv_caches = std::any_cast<std::vector<void *> *>(args[2]);
}
return forward(input_ids, position_ids, kv_caches);
return {logits};
}
void LlamaForCausalLM::reset_cache(size_t pos) {
......
......@@ -37,16 +37,10 @@ public:
/**
* @brief Forward pass: compute language modeling logits
*
* @param input_ids Token IDs tensor of shape [batch, seq_len]
* @param position_ids Position IDs tensor of shape [batch, seq_len] or [seq_len]
* @param kv_cache Optional model-level KV cache for incremental decoding
* @return Logits tensor of shape [batch, seq_len, vocab_size]
* @param input Encapsulated input tensors and other parameters
* @return Output structure containing the result
*/
infinicore::Tensor forward(const infinicore::Tensor &input_ids,
const infinicore::Tensor &position_ids,
void *kv_cache = nullptr) const;
infinicore::Tensor forward(std::vector<std::any> args) const override;
Output forward(const Input &input) const;
// Reset internal cache position
void reset_cache(size_t pos = 0) override;
......
......@@ -84,7 +84,8 @@ inline void bind_dist_config(py::module &m) {
namespace infinilm::engine {
inline void bind_infer_engine(py::module &m) {
py::class_<InferEngine, std::shared_ptr<InferEngine>>(m, "InferEngine")
py::class_<InferEngine, std::shared_ptr<InferEngine>> infer_engine(m, "InferEngine");
infer_engine
.def(py::init([](const InfinilmModel::Config &cfg,
const infinilm::engine::distributed::DistConfig &dist,
infinicore::Device::Type dev,
......@@ -109,17 +110,22 @@ inline void bind_infer_engine(py::module &m) {
}
return state_dict_tp_all;
})
.def(
"generate", [](InferEngine &self, py::object input_ids, py::object position_ids) -> infinicore::Tensor {
return self.generate(input_ids.cast<infinicore::Tensor>(), position_ids.cast<infinicore::Tensor>());
},
"Run inference on all ranks with arbitrary arguments")
.def("forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments")
.def("reset_cache", py::overload_cast<size_t>(&InferEngine::reset_cache), py::arg("pos") = 0, "Reset the internal cache in all workers to a specific position")
.def("reset_cache", py::overload_cast<const cache::CacheConfig &, size_t>(&InferEngine::reset_cache), py::arg("cache_config"), py::arg("pos") = 0, "Reset cache with new KV configuration")
.def("get_cache_config", &InferEngine::get_cache_config, "Get current KV configuration")
.def("__repr__", [](const InferEngine &self) {
return "<InferEngine: " + std::string(self.get_dist_config()) + ">";
});
py::class_<InferEngine::Input>(infer_engine, "Input")
.def(py::init([](const infinicore::Tensor &input_ids, const infinicore::Tensor &position_ids) {
return new InferEngine::Input{input_ids, position_ids};
}),
py::arg("input_ids"), py::arg("position_ids"));
py::class_<InferEngine::Output>(infer_engine, "Output")
.def_readwrite("logits", &InferEngine::Output::logits, "Output tensor");
}
} // namespace infinilm::engine
......@@ -123,81 +123,6 @@ inline void bind_llama(py::module &m) {
return dir_list; });
// Note: Device is already bound in InfiniCore bindings, so we don't need to bind it here
// Bind LlamaForCausalLM
py::class_<LlamaForCausalLM, std::shared_ptr<LlamaForCausalLM>>(m, "LlamaForCausalLM")
.def("state_dict", [](const LlamaForCausalLM &model) {
// Return a dictionary containing references to the whole state of the module.
auto state_dict = model.state_dict();
py::dict result;
for (const auto &[name, param] : state_dict) {
result[py::cast(name)] = infinicore::Tensor(param);
}
return result;
})
.def("get_parameter", [](const LlamaForCausalLM &model, const std::string &name) {
// Get actual tensor parameter by name
auto state_dict = model.state_dict();
auto it = state_dict.find(name);
if (it != state_dict.end()) {
// Parameter inherits from Tensor, cast to Tensor for pybind11
const infinicore::Tensor &tensor = it->second;
return tensor;
}
throw std::runtime_error("Parameter '" + name + "' not found in model"); }, py::arg("name"))
.def("load_state_dict", [](LlamaForCausalLM &model, py::dict state_dict) {
// Convert Python dict to C++ state_dict
std::unordered_map<std::string, infinicore::Tensor> cpp_state_dict;
for (auto item : state_dict) {
std::string key = item.first.cast<std::string>();
py::object value = item.second.cast<py::object>();
// Extract InfiniCore tensor from Python object
infinicore::Tensor tensor;
if (py::hasattr(value, "_underlying")) {
tensor = value.attr("_underlying").cast<infinicore::Tensor>();
} else {
tensor = value.cast<infinicore::Tensor>();
}
cpp_state_dict.emplace(key, tensor);
}
model.load_state_dict(cpp_state_dict); }, py::arg("state_dict"))
.def("config", &LlamaForCausalLM::config, py::return_value_policy::reference_internal)
.def("reset_cache", [](const LlamaForCausalLM &model, size_t pos = 0) {
// Reset the internal cache to prevent state from persisting between generations
model.model().reset_cache(pos); }, py::arg("pos") = 0, "Reset the internal cache to a specific position (clears state between generations)")
.def("forward", [](const LlamaForCausalLM &model, py::object input_ids, py::object position_ids, py::object kv_cache = py::none()) {
// Helper to extract C++ tensor from Python InfiniCore tensor
auto get_tensor = [](py::object obj) -> infinicore::Tensor {
// If it's already a Python InfiniCore tensor wrapper, extract underlying
if (py::hasattr(obj, "_underlying")) {
return obj.attr("_underlying").cast<infinicore::Tensor>();
}
// Try direct cast (in case it's already a C++ tensor)
return obj.cast<infinicore::Tensor>();
};
// Extract InfiniCore tensors from Python objects
auto infini_input_ids = get_tensor(input_ids);
auto infini_position_ids = get_tensor(position_ids);
// Handle kv_cache if provided (model-level DynamicCache)
void *kv_cache_ptr = nullptr;
if (!kv_cache.is_none()) {
// Try to extract DynamicCache from Python object
if (py::hasattr(kv_cache, "_underlying")) {
kv_cache_ptr = kv_cache.attr("_underlying").cast<void *>();
} else {
// Try direct cast
try {
kv_cache_ptr = kv_cache.cast<void *>();
} catch (...) {
// If conversion fails, pass nullptr (cache will be ignored)
kv_cache_ptr = nullptr;
}
}
}
return model.forward(infini_input_ids, infini_position_ids, kv_cache_ptr); }, py::arg("input_ids"), py::arg("position_ids"), py::arg("kv_caches") = py::none());
}
} // namespace infinilm::models::llama
......@@ -108,10 +108,9 @@ class LlamaForCausalLM(GenerationMixin):
# self._model.forward(input_ids, position_ids, kv_caches)
# )
return infinicore.Tensor(
self._model.generate(
input_ids._underlying,
position_ids._underlying,
)
self._model.forward(
self._model.Input(input_ids._underlying, position_ids._underlying)
).logits
)
def __call__(self, input_ids, position_ids, *args, **kwargs):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment