Commit 3e398429 authored by wooway777's avatar wooway777
Browse files

issue/99 - relocated pybind contents

parent a1f6e517
#pragma once #pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/numpy.h>
#include "../../cache/kv_cache.hpp" #include "../../cache/kv_cache.hpp"
#include "../../debug_utils/hooks.hpp" #include "../../models/debug_utils/hooks.hpp"
#include "../../llama/llama.hpp" #include "../../models/llama/llama.hpp"
#include "../../llama/llama_attention.hpp" #include "../../models/llama/llama_attention.hpp"
#include "infinicore/device.hpp" #include "infinicore/device.hpp"
#include "infinicore/tensor.hpp"
#include "infinicore/nn/module.hpp" #include "infinicore/nn/module.hpp"
#include "infinicore/tensor.hpp"
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
namespace py = pybind11; namespace py = pybind11;
using infinicore::Device; using infinicore::Device;
...@@ -22,7 +22,8 @@ inline void bind_llama(py::module &m) { ...@@ -22,7 +22,8 @@ inline void bind_llama(py::module &m) {
// Bind HookRegistry // Bind HookRegistry
py::class_<HookRegistry, std::shared_ptr<HookRegistry>>(m, "HookRegistry") py::class_<HookRegistry, std::shared_ptr<HookRegistry>>(m, "HookRegistry")
.def(py::init<>()) .def(py::init<>())
.def("register_hook", [](HookRegistry &self, const std::string &name, py::object callback) { .def(
"register_hook", [](HookRegistry &self, const std::string &name, py::object callback) {
// Convert Python callable to C++ function // Convert Python callable to C++ function
self.register_hook(name, [callback](const std::string &hook_name, const infinicore::Tensor &tensor, int layer_idx) { self.register_hook(name, [callback](const std::string &hook_name, const infinicore::Tensor &tensor, int layer_idx) {
try { try {
...@@ -33,7 +34,8 @@ inline void bind_llama(py::module &m) { ...@@ -33,7 +34,8 @@ inline void bind_llama(py::module &m) {
throw; throw;
} }
}); });
}, py::arg("name"), py::arg("callback")) },
py::arg("name"), py::arg("callback"))
.def("clear", &HookRegistry::clear) .def("clear", &HookRegistry::clear)
.def("has_hooks", &HookRegistry::has_hooks); .def("has_hooks", &HookRegistry::has_hooks);
...@@ -167,7 +169,8 @@ inline void bind_llama(py::module &m) { ...@@ -167,7 +169,8 @@ inline void bind_llama(py::module &m) {
} }
} }
return std::make_shared<LlamaForCausalLM>(config, device, dtype); return std::make_shared<LlamaForCausalLM>(config, device, dtype);
}), py::arg("config"), py::arg("device"), py::arg("dtype") = py::none()) }),
py::arg("config"), py::arg("device"), py::arg("dtype") = py::none())
.def("state_dict", [](const LlamaForCausalLM &model) { .def("state_dict", [](const LlamaForCausalLM &model) {
// Convert state_dict to Python dict with shape information // Convert state_dict to Python dict with shape information
auto state_dict = model.state_dict(); auto state_dict = model.state_dict();
...@@ -181,7 +184,8 @@ inline void bind_llama(py::module &m) { ...@@ -181,7 +184,8 @@ inline void bind_llama(py::module &m) {
} }
return result; return result;
}) })
.def("get_parameter", [](const LlamaForCausalLM &model, const std::string &name) { .def(
"get_parameter", [](const LlamaForCausalLM &model, const std::string &name) {
// Get actual tensor parameter by name // Get actual tensor parameter by name
auto state_dict = model.state_dict(); auto state_dict = model.state_dict();
auto it = state_dict.find(name); auto it = state_dict.find(name);
...@@ -191,8 +195,10 @@ inline void bind_llama(py::module &m) { ...@@ -191,8 +195,10 @@ inline void bind_llama(py::module &m) {
return tensor; return tensor;
} }
throw std::runtime_error("Parameter '" + name + "' not found in model"); throw std::runtime_error("Parameter '" + name + "' not found in model");
}, py::arg("name")) },
.def("load_state_dict", [convert_to_tensor](LlamaForCausalLM &model, py::dict state_dict, const Device &device) { py::arg("name"))
.def(
"load_state_dict", [convert_to_tensor](LlamaForCausalLM &model, py::dict state_dict, const Device &device) {
// Convert Python dict to C++ state_dict // Convert Python dict to C++ state_dict
std::unordered_map<std::string, infinicore::Tensor> cpp_state_dict; std::unordered_map<std::string, infinicore::Tensor> cpp_state_dict;
for (auto item : state_dict) { for (auto item : state_dict) {
...@@ -201,9 +207,11 @@ inline void bind_llama(py::module &m) { ...@@ -201,9 +207,11 @@ inline void bind_llama(py::module &m) {
cpp_state_dict.emplace(key, convert_to_tensor(value, device)); cpp_state_dict.emplace(key, convert_to_tensor(value, device));
} }
model.load_state_dict(cpp_state_dict); model.load_state_dict(cpp_state_dict);
}, py::arg("state_dict"), py::arg("device")) },
py::arg("state_dict"), py::arg("device"))
.def("config", &LlamaForCausalLM::config, py::return_value_policy::reference_internal) .def("config", &LlamaForCausalLM::config, py::return_value_policy::reference_internal)
.def("forward", [convert_to_tensor](const LlamaForCausalLM &model, py::object input_ids, py::object position_ids, py::object kv_caches = py::none()) { .def(
"forward", [convert_to_tensor](const LlamaForCausalLM &model, py::object input_ids, py::object position_ids, py::object kv_caches = py::none()) {
// Helper to extract C++ tensor from Python object // Helper to extract C++ tensor from Python object
auto get_tensor = [convert_to_tensor](py::object obj) -> infinicore::Tensor { auto get_tensor = [convert_to_tensor](py::object obj) -> infinicore::Tensor {
// If it's already a Python InfiniCore tensor wrapper, extract underlying // If it's already a Python InfiniCore tensor wrapper, extract underlying
...@@ -240,7 +248,8 @@ inline void bind_llama(py::module &m) { ...@@ -240,7 +248,8 @@ inline void bind_llama(py::module &m) {
std::vector<void *> *kv_caches_ptr = nullptr; std::vector<void *> *kv_caches_ptr = nullptr;
return model.forward(infini_input_ids, infini_position_ids, kv_caches_ptr); return model.forward(infini_input_ids, infini_position_ids, kv_caches_ptr);
}, py::arg("input_ids"), py::arg("position_ids"), py::arg("kv_caches") = py::none()); },
py::arg("input_ids"), py::arg("position_ids"), py::arg("kv_caches") = py::none());
} }
} // namespace infinilm::models::llama } // namespace infinilm::models::llama
...@@ -42,7 +42,7 @@ target("_infinilm") ...@@ -42,7 +42,7 @@ target("_infinilm")
local INFINI_ROOT = os.getenv("INFINI_ROOT") or (os.getenv(is_host("windows") and "HOMEPATH" or "HOME") .. "/.infini") local INFINI_ROOT = os.getenv("INFINI_ROOT") or (os.getenv(is_host("windows") and "HOMEPATH" or "HOME") .. "/.infini")
add_includedirs("csrc", { public = false }) add_includedirs("csrc", { public = false })
add_includedirs("csrc/models/pybind11", { public = false }) add_includedirs("csrc/pybind11", { public = false })
add_includedirs("include", { public = false }) add_includedirs("include", { public = false })
add_includedirs(INFINI_ROOT.."/include", { public = true }) add_includedirs(INFINI_ROOT.."/include", { public = true })
-- spdlog is already included globally via add_includedirs at the top -- spdlog is already included globally via add_includedirs at the top
...@@ -52,7 +52,7 @@ target("_infinilm") ...@@ -52,7 +52,7 @@ target("_infinilm")
-- Add Llama model files -- Add Llama model files
add_files("csrc/models/*/*.cpp") add_files("csrc/models/*/*.cpp")
add_files("csrc/models/pybind11/models.cc") add_files("csrc/pybind11/bindings.cc")
set_installdir("python/infinilm") set_installdir("python/infinilm")
target_end() target_end()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment