Commit 3e398429 authored by wooway777
Browse files

issue/99 - relocated pybind contents

parent a1f6e517
#pragma once #pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/numpy.h>
#include "../../cache/kv_cache.hpp" #include "../../cache/kv_cache.hpp"
#include "../../debug_utils/hooks.hpp" #include "../../models/debug_utils/hooks.hpp"
#include "../../llama/llama.hpp" #include "../../models/llama/llama.hpp"
#include "../../llama/llama_attention.hpp" #include "../../models/llama/llama_attention.hpp"
#include "infinicore/device.hpp" #include "infinicore/device.hpp"
#include "infinicore/tensor.hpp"
#include "infinicore/nn/module.hpp" #include "infinicore/nn/module.hpp"
#include "infinicore/tensor.hpp"
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
namespace py = pybind11; namespace py = pybind11;
using infinicore::Device; using infinicore::Device;
...@@ -22,18 +22,20 @@ inline void bind_llama(py::module &m) { ...@@ -22,18 +22,20 @@ inline void bind_llama(py::module &m) {
// Bind HookRegistry // Bind HookRegistry
py::class_<HookRegistry, std::shared_ptr<HookRegistry>>(m, "HookRegistry") py::class_<HookRegistry, std::shared_ptr<HookRegistry>>(m, "HookRegistry")
.def(py::init<>()) .def(py::init<>())
.def("register_hook", [](HookRegistry &self, const std::string &name, py::object callback) { .def(
// Convert Python callable to C++ function "register_hook", [](HookRegistry &self, const std::string &name, py::object callback) {
self.register_hook(name, [callback](const std::string &hook_name, const infinicore::Tensor &tensor, int layer_idx) { // Convert Python callable to C++ function
try { self.register_hook(name, [callback](const std::string &hook_name, const infinicore::Tensor &tensor, int layer_idx) {
// Call Python callback with hook name, tensor, and layer index try {
callback(hook_name, tensor, layer_idx); // Call Python callback with hook name, tensor, and layer index
} catch (const py::error_already_set &e) { callback(hook_name, tensor, layer_idx);
// Re-raise Python exception } catch (const py::error_already_set &e) {
throw; // Re-raise Python exception
} throw;
}); }
}, py::arg("name"), py::arg("callback")) });
},
py::arg("name"), py::arg("callback"))
.def("clear", &HookRegistry::clear) .def("clear", &HookRegistry::clear)
.def("has_hooks", &HookRegistry::has_hooks); .def("has_hooks", &HookRegistry::has_hooks);
...@@ -157,17 +159,18 @@ inline void bind_llama(py::module &m) { ...@@ -157,17 +159,18 @@ inline void bind_llama(py::module &m) {
// Bind LlamaForCausalLM // Bind LlamaForCausalLM
py::class_<LlamaForCausalLM, std::shared_ptr<LlamaForCausalLM>>(m, "LlamaForCausalLM") py::class_<LlamaForCausalLM, std::shared_ptr<LlamaForCausalLM>>(m, "LlamaForCausalLM")
.def(py::init([](const LlamaConfig &config, const Device &device, py::object dtype_obj) { .def(py::init([](const LlamaConfig &config, const Device &device, py::object dtype_obj) {
infinicore::DataType dtype = infinicore::DataType::F32; infinicore::DataType dtype = infinicore::DataType::F32;
if (!dtype_obj.is_none()) { if (!dtype_obj.is_none()) {
// Extract dtype from Python object // Extract dtype from Python object
if (py::hasattr(dtype_obj, "_underlying")) { if (py::hasattr(dtype_obj, "_underlying")) {
dtype = dtype_obj.attr("_underlying").cast<infinicore::DataType>(); dtype = dtype_obj.attr("_underlying").cast<infinicore::DataType>();
} else { } else {
dtype = dtype_obj.cast<infinicore::DataType>(); dtype = dtype_obj.cast<infinicore::DataType>();
} }
} }
return std::make_shared<LlamaForCausalLM>(config, device, dtype); return std::make_shared<LlamaForCausalLM>(config, device, dtype);
}), py::arg("config"), py::arg("device"), py::arg("dtype") = py::none()) }),
py::arg("config"), py::arg("device"), py::arg("dtype") = py::none())
.def("state_dict", [](const LlamaForCausalLM &model) { .def("state_dict", [](const LlamaForCausalLM &model) {
// Convert state_dict to Python dict with shape information // Convert state_dict to Python dict with shape information
auto state_dict = model.state_dict(); auto state_dict = model.state_dict();
...@@ -181,66 +184,72 @@ inline void bind_llama(py::module &m) { ...@@ -181,66 +184,72 @@ inline void bind_llama(py::module &m) {
} }
return result; return result;
}) })
.def("get_parameter", [](const LlamaForCausalLM &model, const std::string &name) { .def(
// Get actual tensor parameter by name "get_parameter", [](const LlamaForCausalLM &model, const std::string &name) {
auto state_dict = model.state_dict(); // Get actual tensor parameter by name
auto it = state_dict.find(name); auto state_dict = model.state_dict();
if (it != state_dict.end()) { auto it = state_dict.find(name);
// Parameter inherits from Tensor, cast to Tensor for pybind11 if (it != state_dict.end()) {
const infinicore::Tensor &tensor = it->second; // Parameter inherits from Tensor, cast to Tensor for pybind11
return tensor; const infinicore::Tensor &tensor = it->second;
} return tensor;
throw std::runtime_error("Parameter '" + name + "' not found in model"); }
}, py::arg("name")) throw std::runtime_error("Parameter '" + name + "' not found in model");
.def("load_state_dict", [convert_to_tensor](LlamaForCausalLM &model, py::dict state_dict, const Device &device) { },
// Convert Python dict to C++ state_dict py::arg("name"))
std::unordered_map<std::string, infinicore::Tensor> cpp_state_dict; .def(
for (auto item : state_dict) { "load_state_dict", [convert_to_tensor](LlamaForCausalLM &model, py::dict state_dict, const Device &device) {
std::string key = item.first.cast<std::string>(); // Convert Python dict to C++ state_dict
py::object value = item.second.cast<py::object>(); std::unordered_map<std::string, infinicore::Tensor> cpp_state_dict;
cpp_state_dict.emplace(key, convert_to_tensor(value, device)); for (auto item : state_dict) {
} std::string key = item.first.cast<std::string>();
model.load_state_dict(cpp_state_dict); py::object value = item.second.cast<py::object>();
}, py::arg("state_dict"), py::arg("device")) cpp_state_dict.emplace(key, convert_to_tensor(value, device));
.def("config", &LlamaForCausalLM::config, py::return_value_policy::reference_internal)
.def("forward", [convert_to_tensor](const LlamaForCausalLM &model, py::object input_ids, py::object position_ids, py::object kv_caches = py::none()) {
// Helper to extract C++ tensor from Python object
auto get_tensor = [convert_to_tensor](py::object obj) -> infinicore::Tensor {
// If it's already a Python InfiniCore tensor wrapper, extract underlying
if (py::hasattr(obj, "_underlying")) {
return obj.attr("_underlying").cast<infinicore::Tensor>();
} }
// Try direct cast (in case it's already a C++ tensor) model.load_state_dict(cpp_state_dict);
try { },
return obj.cast<infinicore::Tensor>(); py::arg("state_dict"), py::arg("device"))
} catch (const py::cast_error &) { .def("config", &LlamaForCausalLM::config, py::return_value_policy::reference_internal)
// Extract device from first tensor for conversion .def(
Device device = Device(Device::Type::CPU, 0); "forward", [convert_to_tensor](const LlamaForCausalLM &model, py::object input_ids, py::object position_ids, py::object kv_caches = py::none()) {
if (py::hasattr(obj, "device")) { // Helper to extract C++ tensor from Python object
try { auto get_tensor = [convert_to_tensor](py::object obj) -> infinicore::Tensor {
auto py_device = obj.attr("device"); // If it's already a Python InfiniCore tensor wrapper, extract underlying
if (py::hasattr(py_device, "_underlying")) { if (py::hasattr(obj, "_underlying")) {
device = py_device.attr("_underlying").cast<Device>(); return obj.attr("_underlying").cast<infinicore::Tensor>();
} else { }
device = py_device.cast<Device>(); // Try direct cast (in case it's already a C++ tensor)
try {
return obj.cast<infinicore::Tensor>();
} catch (const py::cast_error &) {
// Extract device from first tensor for conversion
Device device = Device(Device::Type::CPU, 0);
if (py::hasattr(obj, "device")) {
try {
auto py_device = obj.attr("device");
if (py::hasattr(py_device, "_underlying")) {
device = py_device.attr("_underlying").cast<Device>();
} else {
device = py_device.cast<Device>();
}
} catch (...) {
// Keep default CPU device
} }
} catch (...) {
// Keep default CPU device
} }
return convert_to_tensor(obj, device);
} }
return convert_to_tensor(obj, device); };
}
};
// Convert Python tensors to C++ tensors // Convert Python tensors to C++ tensors
auto infini_input_ids = get_tensor(input_ids); auto infini_input_ids = get_tensor(input_ids);
auto infini_position_ids = get_tensor(position_ids); auto infini_position_ids = get_tensor(position_ids);
// Handle kv_caches if provided // Handle kv_caches if provided
std::vector<void *> *kv_caches_ptr = nullptr; std::vector<void *> *kv_caches_ptr = nullptr;
return model.forward(infini_input_ids, infini_position_ids, kv_caches_ptr); return model.forward(infini_input_ids, infini_position_ids, kv_caches_ptr);
}, py::arg("input_ids"), py::arg("position_ids"), py::arg("kv_caches") = py::none()); },
py::arg("input_ids"), py::arg("position_ids"), py::arg("kv_caches") = py::none());
} }
} // namespace infinilm::models::llama } // namespace infinilm::models::llama
...@@ -42,7 +42,7 @@ target("_infinilm") ...@@ -42,7 +42,7 @@ target("_infinilm")
local INFINI_ROOT = os.getenv("INFINI_ROOT") or (os.getenv(is_host("windows") and "HOMEPATH" or "HOME") .. "/.infini") local INFINI_ROOT = os.getenv("INFINI_ROOT") or (os.getenv(is_host("windows") and "HOMEPATH" or "HOME") .. "/.infini")
add_includedirs("csrc", { public = false }) add_includedirs("csrc", { public = false })
add_includedirs("csrc/models/pybind11", { public = false }) add_includedirs("csrc/pybind11", { public = false })
add_includedirs("include", { public = false }) add_includedirs("include", { public = false })
add_includedirs(INFINI_ROOT.."/include", { public = true }) add_includedirs(INFINI_ROOT.."/include", { public = true })
-- spdlog is already included globally via add_includedirs at the top -- spdlog is already included globally via add_includedirs at the top
...@@ -52,7 +52,7 @@ target("_infinilm") ...@@ -52,7 +52,7 @@ target("_infinilm")
-- Add Llama model files -- Add Llama model files
add_files("csrc/models/*/*.cpp") add_files("csrc/models/*/*.cpp")
add_files("csrc/models/pybind11/models.cc") add_files("csrc/pybind11/bindings.cc")
set_installdir("python/infinilm") set_installdir("python/infinilm")
target_end() target_end()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment