Unverified commit 42f9d47d, authored by PanZezhong1725 and committed by GitHub
Browse files

Merge pull request #100 from InfiniTensor/issue/99

issue/99 - relocated pybind contents
parents a1f6e517 3e398429
#pragma once
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/numpy.h>
#include "../../cache/kv_cache.hpp"
#include "../../debug_utils/hooks.hpp"
#include "../../llama/llama.hpp"
#include "../../llama/llama_attention.hpp"
#include "../../models/debug_utils/hooks.hpp"
#include "../../models/llama/llama.hpp"
#include "../../models/llama/llama_attention.hpp"
#include "infinicore/device.hpp"
#include "infinicore/tensor.hpp"
#include "infinicore/nn/module.hpp"
#include "infinicore/tensor.hpp"
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
namespace py = pybind11;
using infinicore::Device;
......@@ -22,18 +22,20 @@ inline void bind_llama(py::module &m) {
// Bind HookRegistry
py::class_<HookRegistry, std::shared_ptr<HookRegistry>>(m, "HookRegistry")
.def(py::init<>())
.def("register_hook", [](HookRegistry &self, const std::string &name, py::object callback) {
// Convert Python callable to C++ function
self.register_hook(name, [callback](const std::string &hook_name, const infinicore::Tensor &tensor, int layer_idx) {
try {
// Call Python callback with hook name, tensor, and layer index
callback(hook_name, tensor, layer_idx);
} catch (const py::error_already_set &e) {
// Re-raise Python exception
throw;
}
});
}, py::arg("name"), py::arg("callback"))
.def(
"register_hook", [](HookRegistry &self, const std::string &name, py::object callback) {
// Convert Python callable to C++ function
self.register_hook(name, [callback](const std::string &hook_name, const infinicore::Tensor &tensor, int layer_idx) {
try {
// Call Python callback with hook name, tensor, and layer index
callback(hook_name, tensor, layer_idx);
} catch (const py::error_already_set &e) {
// Re-raise Python exception
throw;
}
});
},
py::arg("name"), py::arg("callback"))
.def("clear", &HookRegistry::clear)
.def("has_hooks", &HookRegistry::has_hooks);
......@@ -157,17 +159,18 @@ inline void bind_llama(py::module &m) {
// Bind LlamaForCausalLM
py::class_<LlamaForCausalLM, std::shared_ptr<LlamaForCausalLM>>(m, "LlamaForCausalLM")
.def(py::init([](const LlamaConfig &config, const Device &device, py::object dtype_obj) {
infinicore::DataType dtype = infinicore::DataType::F32;
if (!dtype_obj.is_none()) {
// Extract dtype from Python object
if (py::hasattr(dtype_obj, "_underlying")) {
dtype = dtype_obj.attr("_underlying").cast<infinicore::DataType>();
} else {
dtype = dtype_obj.cast<infinicore::DataType>();
}
}
return std::make_shared<LlamaForCausalLM>(config, device, dtype);
}), py::arg("config"), py::arg("device"), py::arg("dtype") = py::none())
infinicore::DataType dtype = infinicore::DataType::F32;
if (!dtype_obj.is_none()) {
// Extract dtype from Python object
if (py::hasattr(dtype_obj, "_underlying")) {
dtype = dtype_obj.attr("_underlying").cast<infinicore::DataType>();
} else {
dtype = dtype_obj.cast<infinicore::DataType>();
}
}
return std::make_shared<LlamaForCausalLM>(config, device, dtype);
}),
py::arg("config"), py::arg("device"), py::arg("dtype") = py::none())
.def("state_dict", [](const LlamaForCausalLM &model) {
// Convert state_dict to Python dict with shape information
auto state_dict = model.state_dict();
......@@ -181,66 +184,72 @@ inline void bind_llama(py::module &m) {
}
return result;
})
.def("get_parameter", [](const LlamaForCausalLM &model, const std::string &name) {
// Get actual tensor parameter by name
auto state_dict = model.state_dict();
auto it = state_dict.find(name);
if (it != state_dict.end()) {
// Parameter inherits from Tensor, cast to Tensor for pybind11
const infinicore::Tensor &tensor = it->second;
return tensor;
}
throw std::runtime_error("Parameter '" + name + "' not found in model");
}, py::arg("name"))
.def("load_state_dict", [convert_to_tensor](LlamaForCausalLM &model, py::dict state_dict, const Device &device) {
// Convert Python dict to C++ state_dict
std::unordered_map<std::string, infinicore::Tensor> cpp_state_dict;
for (auto item : state_dict) {
std::string key = item.first.cast<std::string>();
py::object value = item.second.cast<py::object>();
cpp_state_dict.emplace(key, convert_to_tensor(value, device));
}
model.load_state_dict(cpp_state_dict);
}, py::arg("state_dict"), py::arg("device"))
.def("config", &LlamaForCausalLM::config, py::return_value_policy::reference_internal)
.def("forward", [convert_to_tensor](const LlamaForCausalLM &model, py::object input_ids, py::object position_ids, py::object kv_caches = py::none()) {
// Helper to extract C++ tensor from Python object
auto get_tensor = [convert_to_tensor](py::object obj) -> infinicore::Tensor {
// If it's already a Python InfiniCore tensor wrapper, extract underlying
if (py::hasattr(obj, "_underlying")) {
return obj.attr("_underlying").cast<infinicore::Tensor>();
.def(
"get_parameter", [](const LlamaForCausalLM &model, const std::string &name) {
// Get actual tensor parameter by name
auto state_dict = model.state_dict();
auto it = state_dict.find(name);
if (it != state_dict.end()) {
// Parameter inherits from Tensor, cast to Tensor for pybind11
const infinicore::Tensor &tensor = it->second;
return tensor;
}
throw std::runtime_error("Parameter '" + name + "' not found in model");
},
py::arg("name"))
.def(
"load_state_dict", [convert_to_tensor](LlamaForCausalLM &model, py::dict state_dict, const Device &device) {
// Convert Python dict to C++ state_dict
std::unordered_map<std::string, infinicore::Tensor> cpp_state_dict;
for (auto item : state_dict) {
std::string key = item.first.cast<std::string>();
py::object value = item.second.cast<py::object>();
cpp_state_dict.emplace(key, convert_to_tensor(value, device));
}
// Try direct cast (in case it's already a C++ tensor)
try {
return obj.cast<infinicore::Tensor>();
} catch (const py::cast_error &) {
// Extract device from first tensor for conversion
Device device = Device(Device::Type::CPU, 0);
if (py::hasattr(obj, "device")) {
try {
auto py_device = obj.attr("device");
if (py::hasattr(py_device, "_underlying")) {
device = py_device.attr("_underlying").cast<Device>();
} else {
device = py_device.cast<Device>();
model.load_state_dict(cpp_state_dict);
},
py::arg("state_dict"), py::arg("device"))
.def("config", &LlamaForCausalLM::config, py::return_value_policy::reference_internal)
.def(
"forward", [convert_to_tensor](const LlamaForCausalLM &model, py::object input_ids, py::object position_ids, py::object kv_caches = py::none()) {
// Helper to extract C++ tensor from Python object
auto get_tensor = [convert_to_tensor](py::object obj) -> infinicore::Tensor {
// If it's already a Python InfiniCore tensor wrapper, extract underlying
if (py::hasattr(obj, "_underlying")) {
return obj.attr("_underlying").cast<infinicore::Tensor>();
}
// Try direct cast (in case it's already a C++ tensor)
try {
return obj.cast<infinicore::Tensor>();
} catch (const py::cast_error &) {
// Extract device from first tensor for conversion
Device device = Device(Device::Type::CPU, 0);
if (py::hasattr(obj, "device")) {
try {
auto py_device = obj.attr("device");
if (py::hasattr(py_device, "_underlying")) {
device = py_device.attr("_underlying").cast<Device>();
} else {
device = py_device.cast<Device>();
}
} catch (...) {
// Keep default CPU device
}
} catch (...) {
// Keep default CPU device
}
return convert_to_tensor(obj, device);
}
return convert_to_tensor(obj, device);
}
};
};
// Convert Python tensors to C++ tensors
auto infini_input_ids = get_tensor(input_ids);
auto infini_position_ids = get_tensor(position_ids);
// Convert Python tensors to C++ tensors
auto infini_input_ids = get_tensor(input_ids);
auto infini_position_ids = get_tensor(position_ids);
// Handle kv_caches if provided
std::vector<void *> *kv_caches_ptr = nullptr;
// Handle kv_caches if provided
std::vector<void *> *kv_caches_ptr = nullptr;
return model.forward(infini_input_ids, infini_position_ids, kv_caches_ptr);
}, py::arg("input_ids"), py::arg("position_ids"), py::arg("kv_caches") = py::none());
return model.forward(infini_input_ids, infini_position_ids, kv_caches_ptr);
},
py::arg("input_ids"), py::arg("position_ids"), py::arg("kv_caches") = py::none());
}
} // namespace infinilm::models::llama
......@@ -42,7 +42,7 @@ target("_infinilm")
local INFINI_ROOT = os.getenv("INFINI_ROOT") or (os.getenv(is_host("windows") and "HOMEPATH" or "HOME") .. "/.infini")
add_includedirs("csrc", { public = false })
add_includedirs("csrc/models/pybind11", { public = false })
add_includedirs("csrc/pybind11", { public = false })
add_includedirs("include", { public = false })
add_includedirs(INFINI_ROOT.."/include", { public = true })
-- spdlog is already included globally via add_includedirs at the top
......@@ -52,7 +52,7 @@ target("_infinilm")
-- Add Llama model files
add_files("csrc/models/*/*.cpp")
add_files("csrc/models/pybind11/models.cc")
add_files("csrc/pybind11/bindings.cc")
set_installdir("python/infinilm")
target_end()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment