Unverified commit cdce626e authored by Jiacheng Huang, committed by GitHub

issue/134: Unify model configuration

parent faa5d405
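
This commit replaces the type-erased std::any model configuration with a polymorphic InfinilmModel::Config base class that concrete configurations such as LlamaConfig derive from; the engine, workers, and factory now take `const InfinilmModel::Config &` and recover the concrete type with dynamic_cast instead of typeid/std::any_cast. A minimal standalone sketch of that pattern (simplified names outside the project's real namespaces, illustrative only):

// Sketch only: mirrors the pattern used in infinilm_model.hpp / infinilm_model_factory.cpp,
// but is self-contained and compilable on its own.
#include <cstddef>
#include <iostream>
#include <string>

struct Config {                   // stands in for InfinilmModel::Config
    std::string model_type;
    virtual ~Config() = default;  // the virtual destructor makes the base polymorphic,
};                                // which is what enables the dynamic_cast below

struct LlamaConfig : Config {     // stands in for models::llama::LlamaConfig
    std::size_t num_hidden_layers = 32;
    std::size_t max_position_embeddings = 4096;
};

// Before: the engine took std::any, compared config.type() against typeid(...),
// then used std::any_cast. Now it takes const Config & and downcasts when needed.
void configure_cache(const Config &config) {
    if (const auto *llama = dynamic_cast<const LlamaConfig *>(&config)) {
        std::cout << "num_layers=" << llama->num_hidden_layers
                  << " max_kv_cache_length=" << llama->max_position_embeddings << "\n";
    }
}

int main() {
    LlamaConfig cfg;
    configure_cache(cfg); // passed by const reference; no copy into a std::any
}
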
#include "infer_engine.hpp" #include "infer_engine.hpp"
#include "../models/llama/llama_config.hpp"
#include "spdlog/spdlog.h" #include "spdlog/spdlog.h"
namespace infinilm::engine { namespace infinilm::engine {
...@@ -8,7 +7,7 @@ namespace infinilm::engine { ...@@ -8,7 +7,7 @@ namespace infinilm::engine {
// Constructor // Constructor
//------------------------------------------------------ //------------------------------------------------------
InferEngine::InferEngine( InferEngine::InferEngine(
const std::any &config, const InfinilmModel::Config &config,
const distributed::DistConfig &distributed_config, const distributed::DistConfig &distributed_config,
infinicore::Device::Type device_type, infinicore::Device::Type device_type,
const cache::CacheConfig &cache_config) // Changed parameter const cache::CacheConfig &cache_config) // Changed parameter
...@@ -24,8 +23,8 @@ InferEngine::InferEngine( ...@@ -24,8 +23,8 @@ InferEngine::InferEngine(
// Try to extract model configuration to override default cache parameters if needed // Try to extract model configuration to override default cache parameters if needed
try { try {
if (config.type() == typeid(models::llama::LlamaConfig)) { if (const auto llama_config_ptr = dynamic_cast<const models::llama::LlamaConfig *>(&config)) {
const auto &llama_config = std::any_cast<models::llama::LlamaConfig>(config); const auto &llama_config = *llama_config_ptr;
cache_config_.num_layers = llama_config.num_hidden_layers; cache_config_.num_layers = llama_config.num_hidden_layers;
cache_config_.max_kv_cache_length = llama_config.max_position_embeddings; cache_config_.max_kv_cache_length = llama_config.max_position_embeddings;
......
 #pragma once
+#include "../models/llama/llama_config.hpp"
 #include "distributed/distributed.hpp"
 #include "infinicore/tensor.hpp"
 #include "rank_worker.hpp"
@@ -13,7 +14,7 @@ class InferEngine {
 public:
     // Updated constructor: accept CacheConfig instead of CacheType
     InferEngine(
-        const std::any &config,
+        const InfinilmModel::Config &config,
         const distributed::DistConfig &distributed_config = distributed::DistConfig(),
         infinicore::Device::Type device_type = infinicore::context::getDevice().getType(),
         const cache::CacheConfig &cache_config = cache::CacheConfig());
@@ -44,7 +45,7 @@
 protected:
     std::vector<std::unique_ptr<RankWorker>> workers_;
     distributed::CommunicationGroup communication_group_;
-    std::any model_config_;
+    const InfinilmModel::Config &model_config_;
     cache::CacheConfig cache_config_;
 };
...
@@ -8,7 +8,7 @@
 namespace infinilm::engine {
-RankWorker::RankWorker(const std::any &model_config,
+RankWorker::RankWorker(const InfinilmModel::Config &model_config,
                        const distributed::RankInfo &rank_info,
                        const cache::CacheConfig &cache_config)
     : model_config_(model_config),
...
@@ -24,7 +24,7 @@ class RankWorker {
     };
 public:
-    RankWorker(const std::any &model_config,
+    RankWorker(const InfinilmModel::Config &model_config,
                const distributed::RankInfo &rank_info,
                const cache::CacheConfig &cache_config);
@@ -60,7 +60,7 @@
 private:
     // Worker properties
-    std::any model_config_;
+    const InfinilmModel::Config &model_config_;
     distributed::RankInfo rank_info_;
     std::shared_ptr<InfinilmModel> model_;
     std::shared_ptr<cache::DynamicCache> cache_ptr_;
...
@@ -9,6 +9,12 @@
 namespace infinilm {
 class InfinilmModel : public infinicore::nn::Module {
 public:
+    struct Config {
+        std::string model_type;
+        virtual ~Config() = default;
+    };
+
     virtual ~InfinilmModel() = default;
     virtual infinicore::Tensor forward(std::vector<std::any>) const = 0;
     // Optional: reset cache; default no-op for models without cache
...
@@ -5,6 +5,8 @@
 #include <string>
 #include <vector>
+#include "../infinilm_model.hpp"
 namespace infinilm::models::llama {
 /**
@@ -13,7 +15,7 @@ namespace infinilm::models::llama {
  * This struct holds all hyperparameters needed to construct a Llama model.
  * It follows the same structure as HuggingFace's LlamaConfig.
  */
-struct LlamaConfig {
+struct LlamaConfig : public InfinilmModel::Config {
     // Vocabulary and embedding
     size_t vocab_size = 32000;  // Vocabulary size
     size_t hidden_size = 4096;  // Hidden dimension size
...
@@ -3,12 +3,12 @@
 namespace infinilm {
 std::shared_ptr<InfinilmModel> InfinilmModelFactory::createModel(
-    const std::any &config,
+    const InfinilmModel::Config &config,
     engine::distributed::RankInfo rank_info,
     std::shared_ptr<cache::DynamicCache> cache_ptr) {
-    if (config.type() == typeid(models::llama::LlamaConfig)) {
-        const auto &llama_config = std::any_cast<models::llama::LlamaConfig>(config);
+    if (const auto llama_config_ptr = dynamic_cast<const models::llama::LlamaConfig *>(&config)) {
+        const auto &llama_config = *llama_config_ptr;
         auto model = std::make_shared<models::llama::LlamaForCausalLM>(
             llama_config, rank_info.device, infinicore::DataType::BF16, rank_info);
...
@@ -7,6 +7,6 @@
 namespace infinilm {
 class InfinilmModelFactory {
 public:
-    static std::shared_ptr<InfinilmModel> createModel(const std::any &config, engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(), std::shared_ptr<cache::DynamicCache> cache_ptr = nullptr);
+    static std::shared_ptr<InfinilmModel> createModel(const InfinilmModel::Config &config, engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(), std::shared_ptr<cache::DynamicCache> cache_ptr = nullptr);
 };
 } // namespace infinilm
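
Seen from the caller's side, the new factory and engine signatures take the configuration by const reference, and both InferEngine and RankWorker store that reference (`const InfinilmModel::Config &model_config_`), so the configuration object has to outlive them. A hedged usage sketch, with header names used as placeholders for the real project paths:

// Illustrative only; the includes are assumptions, and both calls just demonstrate the API.
#include "infer_engine.hpp"            // InferEngine
#include "infinilm_model_factory.hpp"  // InfinilmModelFactory
#include "llama_config.hpp"            // models::llama::LlamaConfig

int main() {
    infinilm::models::llama::LlamaConfig cfg;   // derives from InfinilmModel::Config
    cfg.num_hidden_layers = 32;
    cfg.max_position_embeddings = 8192;

    // Factory path: dispatches on the dynamic type of the reference and builds a
    // LlamaForCausalLM when it sees a LlamaConfig (remaining arguments use their defaults).
    auto model = infinilm::InfinilmModelFactory::createModel(cfg);

    // Engine path: the engine (and each RankWorker) keeps a const reference to cfg
    // rather than copying it into a std::any, so cfg must outlive the engine.
    infinilm::engine::InferEngine engine(cfg);
}
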
@@ -85,11 +85,11 @@ namespace infinilm::engine {
 inline void bind_infer_engine(py::module &m) {
     py::class_<InferEngine, std::shared_ptr<InferEngine>>(m, "InferEngine")
-        .def(py::init([](const infinilm::models::llama::LlamaConfig &cfg,
+        .def(py::init([](const InfinilmModel::Config &cfg,
                          const infinilm::engine::distributed::DistConfig &dist,
                          infinicore::Device::Type dev,
                          const infinilm::cache::CacheConfig &cache_config) {
-                 return new InferEngine(std::any(cfg), dist, dev, cache_config);
+                 return new InferEngine(cfg, dist, dev, cache_config);
             }),
             py::arg("config"),
             py::arg("distributed_config") = distributed::DistConfig(),
...
@@ -39,9 +39,11 @@ inline void bind_llama(py::module &m) {
         .def("clear", &HookRegistry::clear)
         .def("has_hooks", &HookRegistry::has_hooks);
+    py::class_<InfinilmModel::Config> config(m, "Config");
     // Bind LlamaConfig
-    py::class_<LlamaConfig> config(m, "LlamaConfig");
-    config
+    py::class_<LlamaConfig, InfinilmModel::Config> llama_config(m, "LlamaConfig");
+    llama_config
         .def(py::init<>())
         .def_readwrite("vocab_size", &LlamaConfig::vocab_size)
         .def_readwrite("hidden_size", &LlamaConfig::hidden_size)
...
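
On the binding side, registering the base class (`py::class_<InfinilmModel::Config>`) and naming it as the C++ base of the LlamaConfig binding is what lets pybind11 pass a Python-constructed LlamaConfig to the InferEngine constructor, which now takes `const InfinilmModel::Config &`. A reduced, self-contained sketch of that mechanism (module name and helper function are invented for illustration):

#include <cstddef>
#include <pybind11/pybind11.h>
namespace py = pybind11;

struct Config { virtual ~Config() = default; };                   // stand-in for InfinilmModel::Config
struct LlamaConfig : Config { std::size_t hidden_size = 4096; };  // stand-in for the Llama config

// Stand-in for a bound function (like the InferEngine constructor) that accepts the base type.
std::size_t hidden(const Config &c) {
    const auto *llama = dynamic_cast<const LlamaConfig *>(&c);
    return llama ? llama->hidden_size : 0;
}

PYBIND11_MODULE(demo, m) {
    py::class_<Config>(m, "Config");                    // base must be registered
    py::class_<LlamaConfig, Config>(m, "LlamaConfig")   // derived declares its C++ base
        .def(py::init<>())
        .def_readwrite("hidden_size", &LlamaConfig::hidden_size);
    m.def("hidden", &hidden);  // a Python LlamaConfig converts to const Config & here
}
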
 from ....generation.utils import GenerationMixin
 import infinicore
-from infinilm.models.llama.configuration_llama import LlamaConfig as _LlamaConfig
+from infinilm.models.llama.configuration_llama import LlamaConfig
 from infinilm.lib import _infinilm
 from infinilm.distributed import DistConfig
 import json
@@ -9,141 +9,6 @@ from typing import Optional, Union
 from collections import OrderedDict
-class LlamaConfig:
-    """Llama model configuration adapter for C++ bindings.
-
-    This class wraps configuration_llama.LlamaConfig and provides
-    a _underlying property that creates the C++ config object.
-
-    Automatically detects and handles both regular Llama models and Jiuge models
-    (fm9g7b, fm9g, minicpm) with appropriate defaults and validation.
-    """
-
-    def __init__(self, config_dict=None, **kwargs):
-        """Create LlamaConfig from dictionary or keyword arguments"""
-        # Use the Python config from configuration_llama
-        if isinstance(config_dict, _LlamaConfig):
-            self._python_config = config_dict
-        else:
-            if config_dict is not None and isinstance(config_dict, dict):
-                merged = {**config_dict, **kwargs}
-            else:
-                merged = kwargs
-            self._python_config = _LlamaConfig(**merged)
-        # Lazy initialization of C++ config
-        self._cpp_config = None
-
-    def __getattr__(self, name):
-        """Delegate attribute access to Python config"""
-        return getattr(self._python_config, name)
-
-    def __setattr__(self, name, value):
-        """Delegate attribute setting to Python config"""
-        if name.startswith("_"):
-            super().__setattr__(name, value)
-        else:
-            if hasattr(self, "_python_config"):
-                setattr(self._python_config, name, value)
-                # Invalidate C++ config cache when Python config changes
-                self._cpp_config = None
-            else:
-                super().__setattr__(name, value)
-
-    @property
-    def _underlying(self):
-        """Get underlying C++ config object, creating it if needed"""
-        if self._cpp_config is None:
-            self._cpp_config = _infinilm.LlamaConfig()
-            # Copy attributes from Python config to C++ config
-            for key in dir(self._python_config):
-                if key.startswith("_"):
-                    continue
-                try:
-                    value = getattr(self._python_config, key)
-                    if hasattr(self._cpp_config, key) and not callable(value):
-                        setattr(self._cpp_config, key, value)
-                except (AttributeError, TypeError):
-                    pass
-            # Handle num_key_value_heads with validation
-            python_num_kv_heads = getattr(
-                self._python_config, "num_key_value_heads", None
-            )
-            if python_num_kv_heads is None or python_num_kv_heads == 0:
-                self._cpp_config.num_key_value_heads = (
-                    self._cpp_config.num_attention_heads
-                )
-            else:
-                self._cpp_config.num_key_value_heads = python_num_kv_heads
-            # Handle head_dim with validation (critical for GEMM operations)
-            python_head_dim = getattr(self._python_config, "head_dim", None)
-            if python_head_dim is None or python_head_dim == 0:
-                # Compute from hidden_size and num_attention_heads
-                if (
-                    self._cpp_config.hidden_size > 0
-                    and self._cpp_config.num_attention_heads > 0
-                ):
-                    computed_head_dim = (
-                        self._cpp_config.hidden_size
-                        // self._cpp_config.num_attention_heads
-                    )
-                    self._cpp_config.head_dim = computed_head_dim
-                else:
-                    raise ValueError(
-                        f"Cannot compute head_dim: hidden_size={self._cpp_config.hidden_size}, "
-                        f"num_attention_heads={self._cpp_config.num_attention_heads}"
-                    )
-            else:
-                # Use from Python config
-                self._cpp_config.head_dim = python_head_dim
-                # Validate it matches expected value (warn but allow for flexibility)
-                if (
-                    self._cpp_config.hidden_size > 0
-                    and self._cpp_config.num_attention_heads > 0
-                ):
-                    expected_head_dim = (
-                        self._cpp_config.hidden_size
-                        // self._cpp_config.num_attention_heads
-                    )
-                    if self._cpp_config.head_dim != expected_head_dim:
-                        import warnings
-                        warnings.warn(
-                            f"head_dim ({self._cpp_config.head_dim}) != hidden_size/num_attention_heads ({expected_head_dim}). "
-                            f"Using head_dim from config."
-                        )
-            # Ensure vocab_size is set (explicit handling)
-            if hasattr(self._python_config, "vocab_size"):
-                self._cpp_config.vocab_size = self._python_config.vocab_size
-            # Validate config after setting all values (especially important for jiuge models)
-            if not self._cpp_config.validate():
-                raise ValueError(
-                    "C++ LlamaConfig validation failed. Check config values."
-                )
-            # Log key config values for debugging (especially useful for jiuge models)
-            import logging
-            logger = logging.getLogger(__name__)
-            logger.info(
-                f"LlamaConfig ({self._python_config.model_type}) C++ LlamaConfig created: vocab_size={self._cpp_config.vocab_size}, "
-                f"hidden_size={self._cpp_config.hidden_size}, "
-                f"num_attention_heads={self._cpp_config.num_attention_heads}, "
-                f"num_key_value_heads={self._cpp_config.num_key_value_heads}, "
-                f"head_dim={self._cpp_config.head_dim}, "
-                f"kv_dim={self._cpp_config.kv_dim()}, "
-                f"attention_bias={self._cpp_config.attention_bias}, "
-                f"attention_output_bias={self._cpp_config.attention_output_bias}"
-            )
-        return self._cpp_config
 class LlamaForCausalLM(GenerationMixin):
     """Llama model for causal language modeling"""
@@ -186,7 +51,7 @@ class LlamaForCausalLM(GenerationMixin):
         # config._underlying, device._underlying, dtype
         # )
         self._model = _infinilm.InferEngine(
-            config._underlying, distributed_config._underlying, device._underlying.type
+            config, distributed_config._underlying, device._underlying.type
         )
     def reset_cache(self, batch_size: int, pos: int = 0, initial_capacity: int = 1024):
@@ -279,5 +144,5 @@ class LlamaForCausalLM(GenerationMixin):
             config_dict = json.load(f)
         # LlamaConfig automatically detects and handles jiuge models
-        config = LlamaConfig(config_dict)
+        config = LlamaConfig(**config_dict)
         return cls(config, device=device, dtype=dtype, **kwargs)
@@ -15,10 +15,12 @@
 """LLaMA model configuration"""
+from infinilm.lib import _infinilm
 from ...configuration_utils import PretrainedConfig
-class LlamaConfig(PretrainedConfig):
+class LlamaConfig(PretrainedConfig, _infinilm.LlamaConfig):
     r"""
     This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA
     model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
@@ -166,7 +168,7 @@ class LlamaConfig(PretrainedConfig):
         initializer_range=0.02,
         rms_norm_eps=1e-6,
         use_cache=True,
-        pad_token_id=None,
+        pad_token_id=-1,
         bos_token_id=1,
         eos_token_id=2,
         pretraining_tp=1,
@@ -179,6 +181,8 @@
         head_dim=None,
         **kwargs,
     ):
+        _infinilm.LlamaConfig.__init__(self)
         # ---
         self.model_type = "llama"
         self.name_or_path = ""
@@ -221,7 +225,7 @@
             self.rope_scaling["rope_type"] = self.rope_scaling["type"]
         # rope_config_validation(self)
-        super().__init__(
+        PretrainedConfig.__init__(
             pad_token_id=pad_token_id,
             bos_token_id=bos_token_id,
             eos_token_id=eos_token_id,
...