Unverified Commit cdce626e authored by Jiacheng Huang's avatar Jiacheng Huang Committed by GitHub
Browse files

issue/134: 统一模型配置 (Unify model configuration)

parent faa5d405
#include "infer_engine.hpp"
#include "../models/llama/llama_config.hpp"
#include "spdlog/spdlog.h"
namespace infinilm::engine {
......@@ -8,7 +7,7 @@ namespace infinilm::engine {
// Constructor
//------------------------------------------------------
InferEngine::InferEngine(
const std::any &config,
const InfinilmModel::Config &config,
const distributed::DistConfig &distributed_config,
infinicore::Device::Type device_type,
const cache::CacheConfig &cache_config) // Changed parameter
......@@ -24,8 +23,8 @@ InferEngine::InferEngine(
// Try to extract model configuration to override default cache parameters if needed
try {
if (config.type() == typeid(models::llama::LlamaConfig)) {
const auto &llama_config = std::any_cast<models::llama::LlamaConfig>(config);
if (const auto llama_config_ptr = dynamic_cast<const models::llama::LlamaConfig *>(&config)) {
const auto &llama_config = *llama_config_ptr;
cache_config_.num_layers = llama_config.num_hidden_layers;
cache_config_.max_kv_cache_length = llama_config.max_position_embeddings;
......
#pragma once
#include "../models/llama/llama_config.hpp"
#include "distributed/distributed.hpp"
#include "infinicore/tensor.hpp"
#include "rank_worker.hpp"
......@@ -13,7 +14,7 @@ class InferEngine {
public:
// Updated constructor: accept CacheConfig instead of CacheType
InferEngine(
const std::any &config,
const InfinilmModel::Config &config,
const distributed::DistConfig &distributed_config = distributed::DistConfig(),
infinicore::Device::Type device_type = infinicore::context::getDevice().getType(),
const cache::CacheConfig &cache_config = cache::CacheConfig());
......@@ -44,7 +45,7 @@ public:
protected:
std::vector<std::unique_ptr<RankWorker>> workers_;
distributed::CommunicationGroup communication_group_;
std::any model_config_;
const InfinilmModel::Config &model_config_;
cache::CacheConfig cache_config_;
};
......
......@@ -8,7 +8,7 @@
namespace infinilm::engine {
RankWorker::RankWorker(const std::any &model_config,
RankWorker::RankWorker(const InfinilmModel::Config &model_config,
const distributed::RankInfo &rank_info,
const cache::CacheConfig &cache_config)
: model_config_(model_config),
......
......@@ -24,7 +24,7 @@ class RankWorker {
};
public:
RankWorker(const std::any &model_config,
RankWorker(const InfinilmModel::Config &model_config,
const distributed::RankInfo &rank_info,
const cache::CacheConfig &cache_config);
......@@ -60,7 +60,7 @@ private:
private:
// Worker properties
std::any model_config_;
const InfinilmModel::Config &model_config_;
distributed::RankInfo rank_info_;
std::shared_ptr<InfinilmModel> model_;
std::shared_ptr<cache::DynamicCache> cache_ptr_;
......
......@@ -9,6 +9,12 @@
namespace infinilm {
class InfinilmModel : public infinicore::nn::Module {
public:
// Polymorphic base for per-model configuration structs (e.g. LlamaConfig
// derives from this so a single `const Config &` can be passed through
// InferEngine / RankWorker / InfinilmModelFactory and recovered with
// dynamic_cast).
struct Config {
    // Identifies the concrete model family this configuration describes.
    std::string model_type;
    // Virtual destructor makes the type polymorphic, enabling safe
    // dynamic_cast dispatch and deletion through a base pointer.
    virtual ~Config() = default;
};
virtual ~InfinilmModel() = default;
virtual infinicore::Tensor forward(std::vector<std::any>) const = 0;
// Optional: reset cache; default no-op for models without cache
......
......@@ -5,6 +5,8 @@
#include <string>
#include <vector>
#include "../infinilm_model.hpp"
namespace infinilm::models::llama {
/**
......@@ -13,7 +15,7 @@ namespace infinilm::models::llama {
* This struct holds all hyperparameters needed to construct a Llama model.
* It follows the same structure as HuggingFace's LlamaConfig.
*/
struct LlamaConfig {
struct LlamaConfig : public InfinilmModel::Config {
// Vocabulary and embedding
size_t vocab_size = 32000; // Vocabulary size
size_t hidden_size = 4096; // Hidden dimension size
......
......@@ -3,12 +3,12 @@
namespace infinilm {
std::shared_ptr<InfinilmModel> InfinilmModelFactory::createModel(
const std::any &config,
const InfinilmModel::Config &config,
engine::distributed::RankInfo rank_info,
std::shared_ptr<cache::DynamicCache> cache_ptr) {
if (config.type() == typeid(models::llama::LlamaConfig)) {
const auto &llama_config = std::any_cast<models::llama::LlamaConfig>(config);
if (const auto llama_config_ptr = dynamic_cast<const models::llama::LlamaConfig *>(&config)) {
const auto &llama_config = *llama_config_ptr;
auto model = std::make_shared<models::llama::LlamaForCausalLM>(
llama_config, rank_info.device, infinicore::DataType::BF16, rank_info);
......
......@@ -7,6 +7,6 @@
namespace infinilm {
class InfinilmModelFactory {
public:
static std::shared_ptr<InfinilmModel> createModel(const std::any &config, engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(), std::shared_ptr<cache::DynamicCache> cache_ptr = nullptr);
static std::shared_ptr<InfinilmModel> createModel(const InfinilmModel::Config &config, engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(), std::shared_ptr<cache::DynamicCache> cache_ptr = nullptr);
};
} // namespace infinilm
......@@ -85,11 +85,11 @@ namespace infinilm::engine {
inline void bind_infer_engine(py::module &m) {
py::class_<InferEngine, std::shared_ptr<InferEngine>>(m, "InferEngine")
.def(py::init([](const infinilm::models::llama::LlamaConfig &cfg,
.def(py::init([](const InfinilmModel::Config &cfg,
const infinilm::engine::distributed::DistConfig &dist,
infinicore::Device::Type dev,
const infinilm::cache::CacheConfig &cache_config) {
return new InferEngine(std::any(cfg), dist, dev, cache_config);
return new InferEngine(cfg, dist, dev, cache_config);
}),
py::arg("config"),
py::arg("distributed_config") = distributed::DistConfig(),
......
......@@ -39,9 +39,11 @@ inline void bind_llama(py::module &m) {
.def("clear", &HookRegistry::clear)
.def("has_hooks", &HookRegistry::has_hooks);
py::class_<InfinilmModel::Config> config(m, "Config");
// Bind LlamaConfig
py::class_<LlamaConfig> config(m, "LlamaConfig");
config
py::class_<LlamaConfig, InfinilmModel::Config> llama_config(m, "LlamaConfig");
llama_config
.def(py::init<>())
.def_readwrite("vocab_size", &LlamaConfig::vocab_size)
.def_readwrite("hidden_size", &LlamaConfig::hidden_size)
......
from ....generation.utils import GenerationMixin
import infinicore
from infinilm.models.llama.configuration_llama import LlamaConfig as _LlamaConfig
from infinilm.models.llama.configuration_llama import LlamaConfig
from infinilm.lib import _infinilm
from infinilm.distributed import DistConfig
import json
......@@ -9,141 +9,6 @@ from typing import Optional, Union
from collections import OrderedDict
class LlamaConfig:
    """Llama model configuration adapter for C++ bindings.

    This class wraps configuration_llama.LlamaConfig and provides
    a _underlying property that creates the C++ config object.

    Automatically detects and handles both regular Llama models and Jiuge models
    (fm9g7b, fm9g, minicpm) with appropriate defaults and validation.

    Attribute access is delegated to the wrapped Python config; the C++
    config object is built lazily on first access to `_underlying` and
    invalidated whenever a (non-underscore) attribute is set.
    """

    def __init__(self, config_dict=None, **kwargs):
        """Create LlamaConfig from dictionary or keyword arguments.

        config_dict may be an existing _LlamaConfig instance (wrapped
        as-is), a plain dict (merged with kwargs, kwargs winning on
        duplicate keys), or None (kwargs only).
        """
        # Use the Python config from configuration_llama
        if isinstance(config_dict, _LlamaConfig):
            self._python_config = config_dict
        else:
            if config_dict is not None and isinstance(config_dict, dict):
                merged = {**config_dict, **kwargs}
            else:
                merged = kwargs
            self._python_config = _LlamaConfig(**merged)

        # Lazy initialization of C++ config
        self._cpp_config = None

    def __getattr__(self, name):
        """Delegate attribute access to Python config"""
        # Only called when normal lookup fails, so underscore-prefixed
        # attributes set via super().__setattr__ are found first.
        return getattr(self._python_config, name)

    def __setattr__(self, name, value):
        """Delegate attribute setting to Python config"""
        # Underscore-prefixed names (including _python_config and
        # _cpp_config themselves) are stored on the adapter directly.
        if name.startswith("_"):
            super().__setattr__(name, value)
        else:
            if hasattr(self, "_python_config"):
                setattr(self._python_config, name, value)
                # Invalidate C++ config cache when Python config changes
                self._cpp_config = None
            else:
                # Wrapped config not created yet (early in __init__);
                # fall back to normal attribute storage.
                super().__setattr__(name, value)

    @property
    def _underlying(self):
        """Get underlying C++ config object, creating it if needed"""
        if self._cpp_config is None:
            self._cpp_config = _infinilm.LlamaConfig()

            # Copy attributes from Python config to C++ config
            # (best-effort: only names the C++ config also exposes,
            # skipping callables and anything that raises on access).
            for key in dir(self._python_config):
                if key.startswith("_"):
                    continue
                try:
                    value = getattr(self._python_config, key)
                    if hasattr(self._cpp_config, key) and not callable(value):
                        setattr(self._cpp_config, key, value)
                except (AttributeError, TypeError):
                    pass

            # Handle num_key_value_heads with validation: missing or 0
            # means "no grouped-query attention", i.e. one KV head per
            # attention head.
            python_num_kv_heads = getattr(
                self._python_config, "num_key_value_heads", None
            )
            if python_num_kv_heads is None or python_num_kv_heads == 0:
                self._cpp_config.num_key_value_heads = (
                    self._cpp_config.num_attention_heads
                )
            else:
                self._cpp_config.num_key_value_heads = python_num_kv_heads

            # Handle head_dim with validation (critical for GEMM operations)
            python_head_dim = getattr(self._python_config, "head_dim", None)
            if python_head_dim is None or python_head_dim == 0:
                # Compute from hidden_size and num_attention_heads
                if (
                    self._cpp_config.hidden_size > 0
                    and self._cpp_config.num_attention_heads > 0
                ):
                    computed_head_dim = (
                        self._cpp_config.hidden_size
                        // self._cpp_config.num_attention_heads
                    )
                    self._cpp_config.head_dim = computed_head_dim
                else:
                    raise ValueError(
                        f"Cannot compute head_dim: hidden_size={self._cpp_config.hidden_size}, "
                        f"num_attention_heads={self._cpp_config.num_attention_heads}"
                    )
            else:
                # Use from Python config
                self._cpp_config.head_dim = python_head_dim
                # Validate it matches expected value (warn but allow for flexibility)
                if (
                    self._cpp_config.hidden_size > 0
                    and self._cpp_config.num_attention_heads > 0
                ):
                    expected_head_dim = (
                        self._cpp_config.hidden_size
                        // self._cpp_config.num_attention_heads
                    )
                    if self._cpp_config.head_dim != expected_head_dim:
                        import warnings

                        warnings.warn(
                            f"head_dim ({self._cpp_config.head_dim}) != hidden_size/num_attention_heads ({expected_head_dim}). "
                            f"Using head_dim from config."
                        )

            # Ensure vocab_size is set (explicit handling)
            if hasattr(self._python_config, "vocab_size"):
                self._cpp_config.vocab_size = self._python_config.vocab_size

            # Validate config after setting all values (especially important for jiuge models)
            if not self._cpp_config.validate():
                raise ValueError(
                    "C++ LlamaConfig validation failed. Check config values."
                )

            # Log key config values for debugging (especially useful for jiuge models)
            import logging

            logger = logging.getLogger(__name__)
            logger.info(
                f"LlamaConfig ({self._python_config.model_type}) C++ LlamaConfig created: vocab_size={self._cpp_config.vocab_size}, "
                f"hidden_size={self._cpp_config.hidden_size}, "
                f"num_attention_heads={self._cpp_config.num_attention_heads}, "
                f"num_key_value_heads={self._cpp_config.num_key_value_heads}, "
                f"head_dim={self._cpp_config.head_dim}, "
                f"kv_dim={self._cpp_config.kv_dim()}, "
                f"attention_bias={self._cpp_config.attention_bias}, "
                f"attention_output_bias={self._cpp_config.attention_output_bias}"
            )

        return self._cpp_config
class LlamaForCausalLM(GenerationMixin):
"""Llama model for causal language modeling"""
......@@ -186,7 +51,7 @@ class LlamaForCausalLM(GenerationMixin):
# config._underlying, device._underlying, dtype
# )
self._model = _infinilm.InferEngine(
config._underlying, distributed_config._underlying, device._underlying.type
config, distributed_config._underlying, device._underlying.type
)
def reset_cache(self, batch_size: int, pos: int = 0, initial_capacity: int = 1024):
......@@ -279,5 +144,5 @@ class LlamaForCausalLM(GenerationMixin):
config_dict = json.load(f)
# LlamaConfig automatically detects and handles jiuge models
config = LlamaConfig(config_dict)
config = LlamaConfig(**config_dict)
return cls(config, device=device, dtype=dtype, **kwargs)
......@@ -15,10 +15,12 @@
"""LLaMA model configuration"""
from infinilm.lib import _infinilm
from ...configuration_utils import PretrainedConfig
class LlamaConfig(PretrainedConfig):
class LlamaConfig(PretrainedConfig, _infinilm.LlamaConfig):
r"""
This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
......@@ -166,7 +168,7 @@ class LlamaConfig(PretrainedConfig):
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=None,
pad_token_id=-1,
bos_token_id=1,
eos_token_id=2,
pretraining_tp=1,
......@@ -179,6 +181,8 @@ class LlamaConfig(PretrainedConfig):
head_dim=None,
**kwargs,
):
_infinilm.LlamaConfig.__init__(self)
# ---
self.model_type = "llama"
self.name_or_path = ""
......@@ -221,7 +225,7 @@ class LlamaConfig(PretrainedConfig):
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
# rope_config_validation(self)
super().__init__(
PretrainedConfig.__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment