Commit 3e398429 authored by wooway777
Browse files

issue/99 - relocated pybind contents

parent a1f6e517
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/numpy.h>
#include "../../cache/kv_cache.hpp"
#include "../../debug_utils/hooks.hpp"
#include "../../llama/llama.hpp"
#include "../../llama/llama_attention.hpp"
#include "../../models/debug_utils/hooks.hpp"
#include "../../models/llama/llama.hpp"
#include "../../models/llama/llama_attention.hpp"
#include "infinicore/device.hpp"
#include "infinicore/tensor.hpp"
#include "infinicore/nn/module.hpp"
#include "infinicore/tensor.hpp"
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
namespace py = pybind11;
using infinicore::Device;
......@@ -22,7 +22,8 @@ inline void bind_llama(py::module &m) {
// Bind HookRegistry
py::class_<HookRegistry, std::shared_ptr<HookRegistry>>(m, "HookRegistry")
.def(py::init<>())
.def("register_hook", [](HookRegistry &self, const std::string &name, py::object callback) {
.def(
"register_hook", [](HookRegistry &self, const std::string &name, py::object callback) {
// Convert Python callable to C++ function
self.register_hook(name, [callback](const std::string &hook_name, const infinicore::Tensor &tensor, int layer_idx) {
try {
......@@ -33,7 +34,8 @@ inline void bind_llama(py::module &m) {
throw;
}
});
}, py::arg("name"), py::arg("callback"))
},
py::arg("name"), py::arg("callback"))
.def("clear", &HookRegistry::clear)
.def("has_hooks", &HookRegistry::has_hooks);
......@@ -167,7 +169,8 @@ inline void bind_llama(py::module &m) {
}
}
return std::make_shared<LlamaForCausalLM>(config, device, dtype);
}), py::arg("config"), py::arg("device"), py::arg("dtype") = py::none())
}),
py::arg("config"), py::arg("device"), py::arg("dtype") = py::none())
.def("state_dict", [](const LlamaForCausalLM &model) {
// Convert state_dict to Python dict with shape information
auto state_dict = model.state_dict();
......@@ -181,7 +184,8 @@ inline void bind_llama(py::module &m) {
}
return result;
})
.def("get_parameter", [](const LlamaForCausalLM &model, const std::string &name) {
.def(
"get_parameter", [](const LlamaForCausalLM &model, const std::string &name) {
// Get actual tensor parameter by name
auto state_dict = model.state_dict();
auto it = state_dict.find(name);
......@@ -191,8 +195,10 @@ inline void bind_llama(py::module &m) {
return tensor;
}
throw std::runtime_error("Parameter '" + name + "' not found in model");
}, py::arg("name"))
.def("load_state_dict", [convert_to_tensor](LlamaForCausalLM &model, py::dict state_dict, const Device &device) {
},
py::arg("name"))
.def(
"load_state_dict", [convert_to_tensor](LlamaForCausalLM &model, py::dict state_dict, const Device &device) {
// Convert Python dict to C++ state_dict
std::unordered_map<std::string, infinicore::Tensor> cpp_state_dict;
for (auto item : state_dict) {
......@@ -201,9 +207,11 @@ inline void bind_llama(py::module &m) {
cpp_state_dict.emplace(key, convert_to_tensor(value, device));
}
model.load_state_dict(cpp_state_dict);
}, py::arg("state_dict"), py::arg("device"))
},
py::arg("state_dict"), py::arg("device"))
.def("config", &LlamaForCausalLM::config, py::return_value_policy::reference_internal)
.def("forward", [convert_to_tensor](const LlamaForCausalLM &model, py::object input_ids, py::object position_ids, py::object kv_caches = py::none()) {
.def(
"forward", [convert_to_tensor](const LlamaForCausalLM &model, py::object input_ids, py::object position_ids, py::object kv_caches = py::none()) {
// Helper to extract C++ tensor from Python object
auto get_tensor = [convert_to_tensor](py::object obj) -> infinicore::Tensor {
// If it's already a Python InfiniCore tensor wrapper, extract underlying
......@@ -240,7 +248,8 @@ inline void bind_llama(py::module &m) {
std::vector<void *> *kv_caches_ptr = nullptr;
return model.forward(infini_input_ids, infini_position_ids, kv_caches_ptr);
}, py::arg("input_ids"), py::arg("position_ids"), py::arg("kv_caches") = py::none());
},
py::arg("input_ids"), py::arg("position_ids"), py::arg("kv_caches") = py::none());
}
} // namespace infinilm::models::llama
......@@ -42,7 +42,7 @@ target("_infinilm")
local INFINI_ROOT = os.getenv("INFINI_ROOT") or (os.getenv(is_host("windows") and "HOMEPATH" or "HOME") .. "/.infini")
add_includedirs("csrc", { public = false })
add_includedirs("csrc/models/pybind11", { public = false })
add_includedirs("csrc/pybind11", { public = false })
add_includedirs("include", { public = false })
add_includedirs(INFINI_ROOT.."/include", { public = true })
-- spdlog is already included globally via add_includedirs at the top
......@@ -52,7 +52,7 @@ target("_infinilm")
-- Add Llama model files
add_files("csrc/models/*/*.cpp")
add_files("csrc/models/pybind11/models.cc")
add_files("csrc/pybind11/bindings.cc")
set_installdir("python/infinilm")
target_end()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment