Unverified commit 13a4154a, authored by Jiacheng Huang and committed by GitHub
Browse files

issue/150 移除推理接口中的 `dtype` 传递 (remove passing `dtype` through the inference interfaces)

parent 91c06fd9
...@@ -16,7 +16,6 @@ namespace infinilm::models::llama { ...@@ -16,7 +16,6 @@ namespace infinilm::models::llama {
LlamaAttention::LlamaAttention(const LlamaConfig &config, LlamaAttention::LlamaAttention(const LlamaConfig &config,
const infinicore::Device &device, const infinicore::Device &device,
size_t layer_idx, size_t layer_idx,
infinicore::DataType dtype,
engine::distributed::RankInfo rank_info) engine::distributed::RankInfo rank_info)
: layer_idx_(layer_idx), : layer_idx_(layer_idx),
hidden_size_(config.hidden_size), hidden_size_(config.hidden_size),
...@@ -27,6 +26,7 @@ LlamaAttention::LlamaAttention(const LlamaConfig &config, ...@@ -27,6 +26,7 @@ LlamaAttention::LlamaAttention(const LlamaConfig &config,
use_bias_(config.attention_bias), use_bias_(config.attention_bias),
use_output_bias_(config.attention_output_bias), use_output_bias_(config.attention_output_bias),
max_position_embeddings_(config.max_position_embeddings), rank_info_(rank_info) { max_position_embeddings_(config.max_position_embeddings), rank_info_(rank_info) {
const auto &dtype{config.dtype};
int tp_rank = rank_info.tp_rank; int tp_rank = rank_info.tp_rank;
int tp_size = rank_info.tp_size; int tp_size = rank_info.tp_size;
......
...@@ -38,7 +38,6 @@ public: ...@@ -38,7 +38,6 @@ public:
LlamaAttention(const LlamaConfig &config, LlamaAttention(const LlamaConfig &config,
const infinicore::Device &device, const infinicore::Device &device,
size_t layer_idx, size_t layer_idx,
infinicore::DataType dtype = infinicore::DataType::F32,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo()); engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
/** /**
......
...@@ -16,6 +16,9 @@ namespace infinilm::models::llama { ...@@ -16,6 +16,9 @@ namespace infinilm::models::llama {
* It follows the same structure as HuggingFace's LlamaConfig. * It follows the same structure as HuggingFace's LlamaConfig.
*/ */
struct LlamaConfig : public InfinilmModel::Config { struct LlamaConfig : public InfinilmModel::Config {
// Data type
infinicore::DataType dtype = infinicore::DataType::F32;
// Vocabulary and embedding // Vocabulary and embedding
size_t vocab_size = 32000; // Vocabulary size size_t vocab_size = 32000; // Vocabulary size
size_t hidden_size = 4096; // Hidden dimension size size_t hidden_size = 4096; // Hidden dimension size
......
...@@ -7,17 +7,18 @@ namespace infinilm::models::llama { ...@@ -7,17 +7,18 @@ namespace infinilm::models::llama {
LlamaDecoderLayer::LlamaDecoderLayer(const LlamaConfig &config, LlamaDecoderLayer::LlamaDecoderLayer(const LlamaConfig &config,
const infinicore::Device &device, const infinicore::Device &device,
size_t layer_idx, size_t layer_idx,
infinicore::DataType dtype, engine::distributed::RankInfo rank_info) : layer_idx_(layer_idx), rank_info_(rank_info) {
engine::distributed::RankInfo rank_info) : layer_idx_(layer_idx) , rank_info_(rank_info){ const auto &dtype{config.dtype};
// Initialize layer normalization layers
// Initialize layer normalization layers
INFINICORE_NN_MODULE_INIT(input_layernorm, config.hidden_size, config.rms_norm_eps, INFINICORE_NN_MODULE_INIT(input_layernorm, config.hidden_size, config.rms_norm_eps,
dtype, device); dtype, device);
INFINICORE_NN_MODULE_INIT(post_attention_layernorm, config.hidden_size, config.rms_norm_eps, INFINICORE_NN_MODULE_INIT(post_attention_layernorm, config.hidden_size, config.rms_norm_eps,
dtype, device); dtype, device);
// Initialize attention and MLP modules // Initialize attention and MLP modules
INFINICORE_NN_MODULE_INIT(self_attn, config, device, layer_idx, dtype, rank_info_); INFINICORE_NN_MODULE_INIT(self_attn, config, device, layer_idx, rank_info_);
INFINICORE_NN_MODULE_INIT(mlp, config, device, dtype, rank_info_); INFINICORE_NN_MODULE_INIT(mlp, config, device, rank_info_);
} }
infinicore::Tensor LlamaDecoderLayer::forward(const infinicore::Tensor &hidden_states, infinicore::Tensor LlamaDecoderLayer::forward(const infinicore::Tensor &hidden_states,
......
...@@ -36,7 +36,6 @@ public: ...@@ -36,7 +36,6 @@ public:
LlamaDecoderLayer(const LlamaConfig &config, LlamaDecoderLayer(const LlamaConfig &config,
const infinicore::Device &device, const infinicore::Device &device,
size_t layer_idx, size_t layer_idx,
infinicore::DataType dtype = infinicore::DataType::F32,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo()); engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
/** /**
......
...@@ -8,14 +8,15 @@ namespace infinilm::models::llama { ...@@ -8,14 +8,15 @@ namespace infinilm::models::llama {
LlamaForCausalLM::LlamaForCausalLM(const LlamaConfig &config, LlamaForCausalLM::LlamaForCausalLM(const LlamaConfig &config,
const infinicore::Device &device, const infinicore::Device &device,
infinicore::DataType dtype,
engine::distributed::RankInfo rank_info) { engine::distributed::RankInfo rank_info) {
// Initialize module's device_ member // Initialize module's device_ member
device_ = device; device_ = device;
const auto &dtype{config.dtype};
// Initialize base model // Initialize base model
INFINICORE_NN_MODULE_INIT(model, config, device, dtype, rank_info); INFINICORE_NN_MODULE_INIT(model, config, device, rank_info);
// Initialize language modeling head // Initialize language modeling head
// Note: If tie_word_embeddings is true, we would share weights with embed_tokens // Note: If tie_word_embeddings is true, we would share weights with embed_tokens
......
...@@ -27,11 +27,9 @@ public: ...@@ -27,11 +27,9 @@ public:
* *
* @param config Model configuration * @param config Model configuration
* @param device Device to create tensors on * @param device Device to create tensors on
* @param dtype Optional data type for model parameters (defaults to BF16)
*/ */
LlamaForCausalLM(const LlamaConfig &config, LlamaForCausalLM(const LlamaConfig &config,
const infinicore::Device &device, const infinicore::Device &device,
infinicore::DataType dtype = infinicore::DataType::BF16,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo()); engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
/** /**
......
...@@ -6,11 +6,11 @@ namespace infinilm::models::llama { ...@@ -6,11 +6,11 @@ namespace infinilm::models::llama {
LlamaMLP::LlamaMLP(const LlamaConfig &config, LlamaMLP::LlamaMLP(const LlamaConfig &config,
const infinicore::Device &device, const infinicore::Device &device,
infinicore::DataType dtype,
engine::distributed::RankInfo rank_info) engine::distributed::RankInfo rank_info)
: hidden_size_(config.hidden_size), : hidden_size_(config.hidden_size),
intermediate_size_(config.intermediate_size), intermediate_size_(config.intermediate_size),
use_bias_(config.mlp_bias), rank_info_(rank_info) { use_bias_(config.mlp_bias), rank_info_(rank_info) {
const auto &dtype{config.dtype};
int tp_rank = rank_info.tp_rank; int tp_rank = rank_info.tp_rank;
int tp_size = rank_info.tp_size; int tp_size = rank_info.tp_size;
......
...@@ -35,7 +35,6 @@ public: ...@@ -35,7 +35,6 @@ public:
*/ */
LlamaMLP(const LlamaConfig &config, LlamaMLP(const LlamaConfig &config,
const infinicore::Device &device, const infinicore::Device &device,
infinicore::DataType dtype = infinicore::DataType::F32,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo()); engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
/** /**
......
...@@ -9,9 +9,10 @@ namespace infinilm::models::llama { ...@@ -9,9 +9,10 @@ namespace infinilm::models::llama {
LlamaModel::LlamaModel(const LlamaConfig &config, LlamaModel::LlamaModel(const LlamaConfig &config,
const infinicore::Device &device, const infinicore::Device &device,
infinicore::DataType dtype,
engine::distributed::RankInfo rank_info) engine::distributed::RankInfo rank_info)
: config_(config) { : config_(config) {
const auto &dtype{config.dtype};
// Initialize token embeddings // Initialize token embeddings
INFINICORE_NN_MODULE_INIT(embed_tokens, config.vocab_size, config.hidden_size, INFINICORE_NN_MODULE_INIT(embed_tokens, config.vocab_size, config.hidden_size,
std::nullopt, dtype, device); std::nullopt, dtype, device);
...@@ -23,7 +24,7 @@ LlamaModel::LlamaModel(const LlamaConfig &config, ...@@ -23,7 +24,7 @@ LlamaModel::LlamaModel(const LlamaConfig &config,
layers_.reserve(config.num_hidden_layers); layers_.reserve(config.num_hidden_layers);
for (size_t i = 0; i < config.num_hidden_layers; ++i) { for (size_t i = 0; i < config.num_hidden_layers; ++i) {
layers_.push_back(this->register_module<LlamaDecoderLayer>( layers_.push_back(this->register_module<LlamaDecoderLayer>(
"layers." + std::to_string(i), config, device, i, dtype, rank_info)); "layers." + std::to_string(i), config, device, i, rank_info));
} }
// Initialize final layer normalization // Initialize final layer normalization
......
...@@ -40,7 +40,6 @@ public: ...@@ -40,7 +40,6 @@ public:
*/ */
LlamaModel(const LlamaConfig &config, LlamaModel(const LlamaConfig &config,
const infinicore::Device &device, const infinicore::Device &device,
infinicore::DataType dtype = infinicore::DataType::F32,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo()); engine::distributed::RankInfo rank_info = engine::distributed::RankInfo());
/** /**
......
...@@ -10,7 +10,7 @@ std::shared_ptr<InfinilmModel> InfinilmModelFactory::createModel( ...@@ -10,7 +10,7 @@ std::shared_ptr<InfinilmModel> InfinilmModelFactory::createModel(
if (const auto llama_config_ptr = dynamic_cast<const models::llama::LlamaConfig *>(&config)) { if (const auto llama_config_ptr = dynamic_cast<const models::llama::LlamaConfig *>(&config)) {
const auto &llama_config = *llama_config_ptr; const auto &llama_config = *llama_config_ptr;
auto model = std::make_shared<models::llama::LlamaForCausalLM>( auto model = std::make_shared<models::llama::LlamaForCausalLM>(
llama_config, rank_info.device, infinicore::DataType::BF16, rank_info); llama_config, rank_info.device, rank_info);
if (cache_ptr != nullptr) { if (cache_ptr != nullptr) {
model->model().set_external_cache(cache_ptr); model->model().set_external_cache(cache_ptr);
......
...@@ -45,6 +45,8 @@ inline void bind_llama(py::module &m) { ...@@ -45,6 +45,8 @@ inline void bind_llama(py::module &m) {
py::class_<LlamaConfig, InfinilmModel::Config> llama_config(m, "LlamaConfig"); py::class_<LlamaConfig, InfinilmModel::Config> llama_config(m, "LlamaConfig");
llama_config llama_config
.def(py::init<>()) .def(py::init<>())
// TODO: Change this to `dtype` after updating InfiniCore pybind11 exposing mechanism.
.def_readwrite("_dtype", &LlamaConfig::dtype)
.def_readwrite("vocab_size", &LlamaConfig::vocab_size) .def_readwrite("vocab_size", &LlamaConfig::vocab_size)
.def_readwrite("hidden_size", &LlamaConfig::hidden_size) .def_readwrite("hidden_size", &LlamaConfig::hidden_size)
.def_readwrite("intermediate_size", &LlamaConfig::intermediate_size) .def_readwrite("intermediate_size", &LlamaConfig::intermediate_size)
......
...@@ -141,14 +141,6 @@ def get_args(): ...@@ -141,14 +141,6 @@ def get_args():
required=True, required=True,
help="model path", help="model path",
) )
parser.add_argument(
"--dtype",
type=str,
default="bfloat16",
help="bfloat16",
)
parser.add_argument( parser.add_argument(
"--batch-size", "--batch-size",
type=parse_list, type=parse_list,
...@@ -195,7 +187,6 @@ class TestModel: ...@@ -195,7 +187,6 @@ class TestModel:
def __init__( def __init__(
self, self,
model_path, model_path,
infini_dtype=infinicore.bfloat16,
infini_device=infinicore.device("cpu", 0), infini_device=infinicore.device("cpu", 0),
tp=1, tp=1,
) -> None: ) -> None:
...@@ -206,7 +197,6 @@ class TestModel: ...@@ -206,7 +197,6 @@ class TestModel:
model = infinilm.AutoLlamaModel.from_pretrained( model = infinilm.AutoLlamaModel.from_pretrained(
model_path, model_path,
device=infini_device, device=infini_device,
dtype=infini_dtype,
backend="cpp", backend="cpp",
distributed_config=DistConfig(tp), distributed_config=DistConfig(tp),
) )
...@@ -214,7 +204,7 @@ class TestModel: ...@@ -214,7 +204,7 @@ class TestModel:
# ---------------------------------------------------------------------------- # # ---------------------------------------------------------------------------- #
# 加载权重 # 加载权重
# ---------------------------------------------------------------------------- # # ---------------------------------------------------------------------------- #
load_model_state_dict_by_file(model, model_path, dtype=infini_dtype) load_model_state_dict_by_file(model, model_path, dtype=model.config.dtype)
# ---------------------------------------------------------------------------- # # ---------------------------------------------------------------------------- #
# 创建 tokenizer # 创建 tokenizer
...@@ -289,14 +279,6 @@ if __name__ == "__main__": ...@@ -289,14 +279,6 @@ if __name__ == "__main__":
model_path = args.model model_path = args.model
infini_device = infinicore.device(device_str, 0) infini_device = infinicore.device(device_str, 0)
if args.dtype == "float32":
infini_dtype = infinicore.float32
elif args.dtype == "bfloat16":
infini_dtype = infinicore.bfloat16
elif args.dtype == "float16":
infini_dtype = infinicore.float16
else:
raise ValueError(f"Unsupported dtype: {args.dtype}")
tp = args.tensor_parallel_size tp = args.tensor_parallel_size
...@@ -321,7 +303,6 @@ if __name__ == "__main__": ...@@ -321,7 +303,6 @@ if __name__ == "__main__":
test = TestModel( test = TestModel(
model_path, model_path,
infini_dtype=infini_dtype,
infini_device=infini_device, infini_device=infini_device,
tp=tp, tp=tp,
) )
......
...@@ -58,12 +58,6 @@ def get_args(): ...@@ -58,12 +58,6 @@ def get_args():
default="cpp", default="cpp",
help="python or cpp model", help="python or cpp model",
) )
parser.add_argument(
"--dtype",
type=str,
default="bfloat16",
help="float32, float16, bfloat16",
)
parser.add_argument( parser.add_argument(
"--batch-size", "--batch-size",
type=int, type=int,
...@@ -90,7 +84,6 @@ def test( ...@@ -90,7 +84,6 @@ def test(
prompts: str | list[str], prompts: str | list[str],
model_path, model_path,
max_new_tokens=100, max_new_tokens=100,
infini_dtype=infinicore.bfloat16,
infini_device=infinicore.device("cpu", 0), infini_device=infinicore.device("cpu", 0),
backend="python", backend="python",
tp=1, tp=1,
...@@ -102,7 +95,6 @@ def test( ...@@ -102,7 +95,6 @@ def test(
model = infinilm.AutoLlamaModel.from_pretrained( model = infinilm.AutoLlamaModel.from_pretrained(
model_path, model_path,
device=infini_device, device=infini_device,
dtype=infini_dtype,
backend=backend, backend=backend,
distributed_config=DistConfig(tp), distributed_config=DistConfig(tp),
) )
...@@ -110,7 +102,7 @@ def test( ...@@ -110,7 +102,7 @@ def test(
# ---------------------------------------------------------------------------- # # ---------------------------------------------------------------------------- #
# 加载权重 # 加载权重
# ---------------------------------------------------------------------------- # # ---------------------------------------------------------------------------- #
load_model_state_dict_by_file(model, model_path, dtype=infini_dtype) load_model_state_dict_by_file(model, model_path, dtype=model.config.dtype)
# ---------------------------------------------------------------------------- # # ---------------------------------------------------------------------------- #
# 创建 tokenizer # 创建 tokenizer
...@@ -203,21 +195,12 @@ if __name__ == "__main__": ...@@ -203,21 +195,12 @@ if __name__ == "__main__":
tp = args.tp tp = args.tp
infini_device = infinicore.device(device_str, 0) infini_device = infinicore.device(device_str, 0)
if args.dtype == "float32":
infini_dtype = infinicore.float32
elif args.dtype == "bfloat16":
infini_dtype = infinicore.bfloat16
elif args.dtype == "float16":
infini_dtype = infinicore.float16
else:
raise ValueError(f"Unsupported dtype: {args.dtype}")
test( test(
prompts, prompts,
model_path, model_path,
max_new_tokens, max_new_tokens,
infini_device=infini_device, infini_device=infini_device,
infini_dtype=infini_dtype,
backend=backend, backend=backend,
tp=tp, tp=tp,
) )
...@@ -57,12 +57,6 @@ def get_args(): ...@@ -57,12 +57,6 @@ def get_args():
default="python", default="python",
help="python or cpp model", help="python or cpp model",
) )
parser.add_argument(
"--dtype",
type=str,
default="float32",
help="float32, float16, bfloat16",
)
parser.add_argument( parser.add_argument(
"--batch_size", "--batch_size",
type=int, type=int,
...@@ -83,7 +77,6 @@ def test( ...@@ -83,7 +77,6 @@ def test(
prompts: str | list[str], prompts: str | list[str],
model_path, model_path,
max_new_tokens=100, max_new_tokens=100,
infini_dtype=infinicore.bfloat16,
infini_device=infinicore.device("cpu", 0), infini_device=infinicore.device("cpu", 0),
backend="python", backend="python",
): ):
...@@ -94,7 +87,6 @@ def test( ...@@ -94,7 +87,6 @@ def test(
model = infinilm.AutoLlamaModel.from_pretrained( model = infinilm.AutoLlamaModel.from_pretrained(
model_path, model_path,
device=infini_device, device=infini_device,
dtype=infini_dtype,
backend=backend, backend=backend,
) )
...@@ -104,7 +96,7 @@ def test( ...@@ -104,7 +96,7 @@ def test(
model_param_infini = get_model_state_dict( model_param_infini = get_model_state_dict(
model_path, model_path,
device=infini_device, device=infini_device,
dtype=infini_dtype, dtype=model.config.dtype,
) )
model.load_state_dict(model_param_infini, strict=True) model.load_state_dict(model_param_infini, strict=True)
...@@ -201,20 +193,11 @@ if __name__ == "__main__": ...@@ -201,20 +193,11 @@ if __name__ == "__main__":
backend = args.backend backend = args.backend
infini_device = infinicore.device(device_str, 0) infini_device = infinicore.device(device_str, 0)
if args.dtype == "float32":
infini_dtype = infinicore.float32
elif args.dtype == "bfloat16":
infini_dtype = infinicore.bfloat16
elif args.dtype == "float16":
infini_dtype = infinicore.float16
else:
raise ValueError(f"Unsupported dtype: {args.dtype}")
test( test(
prompts, prompts,
model_path, model_path,
max_new_tokens, max_new_tokens,
infini_device=infini_device, infini_device=infini_device,
infini_dtype=infini_dtype,
backend=backend, backend=backend,
) )
...@@ -30,7 +30,6 @@ class AutoLlamaModel: ...@@ -30,7 +30,6 @@ class AutoLlamaModel:
instance = modeling_llama.LlamaForCausalLM.from_pretrained( instance = modeling_llama.LlamaForCausalLM.from_pretrained(
model_path, model_path,
device=device, device=device,
dtype=dtype,
**kwargs, **kwargs,
) )
...@@ -45,7 +44,6 @@ class AutoLlamaModel: ...@@ -45,7 +44,6 @@ class AutoLlamaModel:
instance = cpp.LlamaForCausalLM.from_pretrained( instance = cpp.LlamaForCausalLM.from_pretrained(
model_path, model_path,
device=device, device=device,
dtype=dtype,
**kwargs, **kwargs,
) )
else: else:
......
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
"""LLaMA model configuration""" """LLaMA model configuration"""
import infinicore
from infinilm.lib import _infinilm from infinilm.lib import _infinilm
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
...@@ -179,6 +181,7 @@ class LlamaConfig(PretrainedConfig, _infinilm.LlamaConfig): ...@@ -179,6 +181,7 @@ class LlamaConfig(PretrainedConfig, _infinilm.LlamaConfig):
attention_dropout=0.0, attention_dropout=0.0,
mlp_bias=False, mlp_bias=False,
head_dim=None, head_dim=None,
torch_dtype=None,
**kwargs, **kwargs,
): ):
_infinilm.LlamaConfig.__init__(self) _infinilm.LlamaConfig.__init__(self)
...@@ -225,6 +228,12 @@ class LlamaConfig(PretrainedConfig, _infinilm.LlamaConfig): ...@@ -225,6 +228,12 @@ class LlamaConfig(PretrainedConfig, _infinilm.LlamaConfig):
self.rope_scaling["rope_type"] = self.rope_scaling["type"] self.rope_scaling["rope_type"] = self.rope_scaling["type"]
# rope_config_validation(self) # rope_config_validation(self)
if torch_dtype in {"float32", "bfloat16", "float16"}:
self.dtype = getattr(infinicore, torch_dtype)
self._dtype = self.dtype._underlying
else:
raise ValueError(f"Unsupported dtype: {torch_dtype}")
PretrainedConfig.__init__( PretrainedConfig.__init__(
pad_token_id=pad_token_id, pad_token_id=pad_token_id,
bos_token_id=bos_token_id, bos_token_id=bos_token_id,
......
...@@ -98,15 +98,16 @@ class LlamaMLP(infinicore.nn.Module): ...@@ -98,15 +98,16 @@ class LlamaMLP(infinicore.nn.Module):
hidden_size = config.hidden_size hidden_size = config.hidden_size
intermediate_size = config.intermediate_size intermediate_size = config.intermediate_size
mlp_bias = config.mlp_bias mlp_bias = config.mlp_bias
dtype = config.dtype
self.gate_proj = infinicore.nn.Linear( self.gate_proj = infinicore.nn.Linear(
hidden_size, intermediate_size, bias=mlp_bias, **kwargs hidden_size, intermediate_size, bias=mlp_bias, dtype=dtype, **kwargs
) )
self.up_proj = infinicore.nn.Linear( self.up_proj = infinicore.nn.Linear(
hidden_size, intermediate_size, bias=mlp_bias, **kwargs hidden_size, intermediate_size, bias=mlp_bias, dtype=dtype, **kwargs
) )
self.down_proj = infinicore.nn.Linear( self.down_proj = infinicore.nn.Linear(
intermediate_size, hidden_size, bias=mlp_bias, **kwargs intermediate_size, hidden_size, bias=mlp_bias, dtype=dtype, **kwargs
) )
self.act_fn = infinicore.nn.functional.silu self.act_fn = infinicore.nn.functional.silu
...@@ -133,10 +134,13 @@ class LlamaAttention(infinicore.nn.Module): ...@@ -133,10 +134,13 @@ class LlamaAttention(infinicore.nn.Module):
self.scaling = self.head_dim**-0.5 self.scaling = self.head_dim**-0.5
dtype = config.dtype
self.q_proj = infinicore.nn.Linear( self.q_proj = infinicore.nn.Linear(
self.hidden_size, self.hidden_size,
self.num_attention_heads * self.head_dim, self.num_attention_heads * self.head_dim,
bias=attention_bias, bias=attention_bias,
dtype=dtype,
**kwargs, **kwargs,
) )
...@@ -144,6 +148,7 @@ class LlamaAttention(infinicore.nn.Module): ...@@ -144,6 +148,7 @@ class LlamaAttention(infinicore.nn.Module):
self.hidden_size, self.hidden_size,
self.num_key_value_heads * self.head_dim, self.num_key_value_heads * self.head_dim,
bias=attention_bias, bias=attention_bias,
dtype=dtype,
**kwargs, **kwargs,
) )
...@@ -151,6 +156,7 @@ class LlamaAttention(infinicore.nn.Module): ...@@ -151,6 +156,7 @@ class LlamaAttention(infinicore.nn.Module):
self.hidden_size, self.hidden_size,
self.num_key_value_heads * self.head_dim, self.num_key_value_heads * self.head_dim,
bias=attention_bias, bias=attention_bias,
dtype=dtype,
**kwargs, **kwargs,
) )
...@@ -158,6 +164,7 @@ class LlamaAttention(infinicore.nn.Module): ...@@ -158,6 +164,7 @@ class LlamaAttention(infinicore.nn.Module):
self.num_attention_heads * self.head_dim, self.num_attention_heads * self.head_dim,
self.hidden_size, self.hidden_size,
bias=False, bias=False,
dtype=dtype,
**kwargs, **kwargs,
) )
...@@ -258,13 +265,16 @@ class LlamaDecoderLayer(infinicore.nn.Module): ...@@ -258,13 +265,16 @@ class LlamaDecoderLayer(infinicore.nn.Module):
super().__init__() super().__init__()
hidden_size = config.hidden_size hidden_size = config.hidden_size
rms_norm_eps = config.rms_norm_eps rms_norm_eps = config.rms_norm_eps
dtype = config.dtype
self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx, **kwargs) self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx, **kwargs)
self.mlp = LlamaMLP(config=config, **kwargs) self.mlp = LlamaMLP(config=config, **kwargs)
self.input_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps, **kwargs) self.input_layernorm = LlamaRMSNorm(
hidden_size, eps=rms_norm_eps, dtype=dtype, **kwargs
)
self.post_attention_layernorm = LlamaRMSNorm( self.post_attention_layernorm = LlamaRMSNorm(
hidden_size, eps=rms_norm_eps, **kwargs hidden_size, eps=rms_norm_eps, dtype=dtype, **kwargs
) )
def forward( def forward(
...@@ -317,7 +327,7 @@ class LlamaModel(infinicore.nn.Module): ...@@ -317,7 +327,7 @@ class LlamaModel(infinicore.nn.Module):
) )
self.embed_tokens = infinicore.nn.Embedding( self.embed_tokens = infinicore.nn.Embedding(
config.vocab_size, config.hidden_size, **kwargs config.vocab_size, config.hidden_size, dtype=config.dtype, **kwargs
) )
self.layers = infinicore.nn.ModuleList( self.layers = infinicore.nn.ModuleList(
...@@ -326,12 +336,15 @@ class LlamaModel(infinicore.nn.Module): ...@@ -326,12 +336,15 @@ class LlamaModel(infinicore.nn.Module):
for layer_idx in range(config.num_hidden_layers) for layer_idx in range(config.num_hidden_layers)
] ]
) )
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps, **kwargs) self.norm = LlamaRMSNorm(
config.hidden_size, eps=config.rms_norm_eps, dtype=config.dtype, **kwargs
)
self.rope_instance = infinicore.nn.RoPE( self.rope_instance = infinicore.nn.RoPE(
max_position_embeddings=config.max_position_embeddings, max_position_embeddings=config.max_position_embeddings,
rope_theta=config.rope_theta, rope_theta=config.rope_theta,
head_dim=head_dim, head_dim=head_dim,
dtype=config.dtype,
**kwargs, **kwargs,
) )
...@@ -394,6 +407,7 @@ class LlamaForCausalLM(infinicore.nn.Module, GenerationMixin): ...@@ -394,6 +407,7 @@ class LlamaForCausalLM(infinicore.nn.Module, GenerationMixin):
config.hidden_size, config.hidden_size,
config.vocab_size, config.vocab_size,
bias=False, bias=False,
dtype=config.dtype,
**kwargs, **kwargs,
) )
self.device = kwargs.get("device", infinicore.device("cpu")) self.device = kwargs.get("device", infinicore.device("cpu"))
...@@ -420,7 +434,6 @@ class LlamaForCausalLM(infinicore.nn.Module, GenerationMixin): ...@@ -420,7 +434,6 @@ class LlamaForCausalLM(infinicore.nn.Module, GenerationMixin):
cls, cls,
model_path: Optional[Union[str, os.PathLike]], model_path: Optional[Union[str, os.PathLike]],
device: infinicore.device, device: infinicore.device,
dtype=infinicore.dtype,
): ):
def load_config_json(dir_path_: str): def load_config_json(dir_path_: str):
with open(os.path.join(dir_path_, "config.json"), "r") as f: with open(os.path.join(dir_path_, "config.json"), "r") as f:
...@@ -430,4 +443,4 @@ class LlamaForCausalLM(infinicore.nn.Module, GenerationMixin): ...@@ -430,4 +443,4 @@ class LlamaForCausalLM(infinicore.nn.Module, GenerationMixin):
config_dict = load_config_json(os.path.join(model_path)) config_dict = load_config_json(os.path.join(model_path))
config = LlamaConfig(**config_dict) config = LlamaConfig(**config_dict)
return LlamaForCausalLM(config, device=device, dtype=dtype) return LlamaForCausalLM(config, device=device)
...@@ -77,7 +77,6 @@ class InfiniLMBenchmark(BaseBenchmark): ...@@ -77,7 +77,6 @@ class InfiniLMBenchmark(BaseBenchmark):
# When CUDA_VISIBLE_DEVICES=5 is set, CUDA only sees device 5 as device 0 # When CUDA_VISIBLE_DEVICES=5 is set, CUDA only sees device 5 as device 0
# So device index 0 will automatically map to the first visible device # So device index 0 will automatically map to the first visible device
self.device = infinicore.device(device_name, 0) self.device = infinicore.device(device_name, 0)
self.dtype = infinicore.bfloat16
# Load config and tokenizer # Load config and tokenizer
with open(os.path.join(model_dir_path, "config.json"), "r") as f: with open(os.path.join(model_dir_path, "config.json"), "r") as f:
...@@ -117,7 +116,6 @@ class InfiniLMBenchmark(BaseBenchmark): ...@@ -117,7 +116,6 @@ class InfiniLMBenchmark(BaseBenchmark):
self.model = AutoLlamaModel.from_pretrained( self.model = AutoLlamaModel.from_pretrained(
model_dir_path, model_dir_path,
device=self.device, device=self.device,
dtype=self.dtype,
backend=backend, backend=backend,
distributed_config=DistConfig(ndev), distributed_config=DistConfig(ndev),
) )
...@@ -130,7 +128,7 @@ class InfiniLMBenchmark(BaseBenchmark): ...@@ -130,7 +128,7 @@ class InfiniLMBenchmark(BaseBenchmark):
load_model_state_dict_by_file( load_model_state_dict_by_file(
self.model, self.model,
model_dir_path, model_dir_path,
dtype=self.dtype, dtype=self.model.config.dtype,
) )
print("Model loaded successfully") print("Model loaded successfully")
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment