Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinilm
Commits
de3e6b95
Unverified
Commit
de3e6b95
authored
Jan 16, 2026
by
Haojie Wang
Committed by
GitHub
Jan 16, 2026
Browse files
Merge pull request #187 from InfiniTensor/issue/186
issue/186 support longrope
parents
c1a3ab29
fc454c77
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
137 additions
and
65 deletions
+137
-65
csrc/models/llama/llama_config.hpp
csrc/models/llama/llama_config.hpp
+25
-21
csrc/models/llama/llama_model.cpp
csrc/models/llama/llama_model.cpp
+1
-1
csrc/pybind11/models/llama.hpp
csrc/pybind11/models/llama.hpp
+86
-2
test/bench/test_benchmark.py
test/bench/test_benchmark.py
+25
-41
No files found.
csrc/models/llama/llama_config.hpp
View file @
de3e6b95
...
@@ -7,6 +7,8 @@
...
@@ -7,6 +7,8 @@
#include "../infinilm_model.hpp"
#include "../infinilm_model.hpp"
#include <infinicore/nn/rope.hpp>
namespace
infinilm
::
models
::
llama
{
namespace
infinilm
::
models
::
llama
{
/**
/**
...
@@ -20,41 +22,43 @@ struct LlamaConfig : public InfinilmModel::Config {
...
@@ -20,41 +22,43 @@ struct LlamaConfig : public InfinilmModel::Config {
infinicore
::
DataType
dtype
=
infinicore
::
DataType
::
F32
;
infinicore
::
DataType
dtype
=
infinicore
::
DataType
::
F32
;
// Vocabulary and embedding
// Vocabulary and embedding
size_t
vocab_size
=
32000
;
// Vocabulary size
size_t
vocab_size
=
32000
;
// Vocabulary size
size_t
hidden_size
=
4096
;
// Hidden dimension size
size_t
hidden_size
=
4096
;
// Hidden dimension size
size_t
intermediate_size
=
11008
;
// MLP intermediate dimension
size_t
intermediate_size
=
11008
;
// MLP intermediate dimension
// Architecture
// Architecture
size_t
num_hidden_layers
=
32
;
// Number of decoder layers
size_t
num_hidden_layers
=
32
;
// Number of decoder layers
size_t
num_attention_heads
=
32
;
// Number of attention heads
size_t
num_attention_heads
=
32
;
// Number of attention heads
size_t
num_key_value_heads
=
32
;
// Number of key-value heads (for GQA)
size_t
num_key_value_heads
=
32
;
// Number of key-value heads (for GQA)
size_t
head_dim
=
128
;
// Attention head dimension (hidden_size / num_attention_heads)
size_t
head_dim
=
128
;
// Attention head dimension (hidden_size / num_attention_heads)
// Position embeddings
// Position embeddings
size_t
max_position_embeddings
=
2048
;
// Maximum sequence length
size_t
max_position_embeddings
=
2048
;
// Maximum sequence length
double
rope_theta
=
10000.0
;
// RoPE base frequency
double
rope_theta
=
10000.0
;
// RoPE base frequency
std
::
shared_ptr
<
infinicore
::
nn
::
RoPE
::
ScalingConfig
>
rope_scaling
=
nullptr
;
// RoPE scaling type
// Normalization
// Normalization
double
rms_norm_eps
=
1e-6
;
// RMSNorm epsilon
double
rms_norm_eps
=
1e-6
;
// RMSNorm epsilon
// Activation
// Activation
std
::
string
hidden_act
=
"silu"
;
// Activation function (typically "silu")
std
::
string
hidden_act
=
"silu"
;
// Activation function (typically "silu")
std
::
string
model_type
=
"llama"
;
// Model type identifier (matches HF configs)
std
::
string
model_type
=
"llama"
;
// Model type identifier (matches HF configs)
// Optional features
// Optional features
bool
use_cache
=
true
;
// Whether to use KV cache
bool
use_cache
=
true
;
// Whether to use KV cache
bool
attention_bias
=
true
;
// Whether to use bias in Q/K/V projections (default true for 9G7B compatibility)
bool
attention_bias
=
true
;
// Whether to use bias in Q/K/V projections (default true for 9G7B compatibility)
bool
attention_output_bias
=
false
;
// Whether to use bias in output projection (o_proj)
bool
attention_output_bias
=
false
;
// Whether to use bias in output projection (o_proj)
bool
mlp_bias
=
false
;
// Whether to use bias in MLP projections
bool
mlp_bias
=
false
;
// Whether to use bias in MLP projections
bool
tie_word_embeddings
=
false
;
// Whether to tie input/output embeddings
bool
tie_word_embeddings
=
false
;
// Whether to tie input/output embeddings
// Training/initialization parameters
// Training/initialization parameters
double
attention_dropout
=
0.0
;
// Dropout ratio for attention probabilities
double
attention_dropout
=
0.0
;
// Dropout ratio for attention probabilities
double
initializer_range
=
0.02
;
// Standard deviation for weight initialization
double
initializer_range
=
0.02
;
// Standard deviation for weight initialization
size_t
pretraining_tp
=
1
;
// Tensor parallelism rank used during pretraining
size_t
pretraining_tp
=
1
;
// Tensor parallelism rank used during pretraining
// Model metadata
// Model metadata
std
::
string
name_or_path
=
""
;
// Model name or path identifier
std
::
string
name_or_path
=
""
;
// Model name or path identifier
// Token IDs
// Token IDs
int64_t
pad_token_id
=
-
1
;
// Padding token ID (optional)
int64_t
pad_token_id
=
-
1
;
// Padding token ID (optional)
...
...
csrc/models/llama/llama_model.cpp
View file @
de3e6b95
...
@@ -34,7 +34,7 @@ LlamaModel::LlamaModel(const LlamaConfig &config,
...
@@ -34,7 +34,7 @@ LlamaModel::LlamaModel(const LlamaConfig &config,
// Use GPT-J-style inverse frequencies (default) and GPT_NEOX rotation pairing
// Use GPT-J-style inverse frequencies (default) and GPT_NEOX rotation pairing
INFINICORE_NN_MODULE_INIT
(
rotary_emb
,
config
.
head_dim
,
config
.
max_position_embeddings
,
INFINICORE_NN_MODULE_INIT
(
rotary_emb
,
config
.
head_dim
,
config
.
max_position_embeddings
,
config
.
rope_theta
,
infinicore
::
nn
::
RoPE
::
Algo
::
GPT_NEOX
,
config
.
rope_theta
,
infinicore
::
nn
::
RoPE
::
Algo
::
GPT_NEOX
,
dtype
,
device
);
dtype
,
device
,
config
.
rope_scaling
);
for
(
auto
&
layer
:
layers_
)
{
for
(
auto
&
layer
:
layers_
)
{
if
(
layer
)
{
if
(
layer
)
{
...
...
csrc/pybind11/models/llama.hpp
View file @
de3e6b95
...
@@ -6,6 +6,7 @@
...
@@ -6,6 +6,7 @@
#include "../../models/llama/llama_attention.hpp"
#include "../../models/llama/llama_attention.hpp"
#include "infinicore/device.hpp"
#include "infinicore/device.hpp"
#include "infinicore/nn/module.hpp"
#include "infinicore/nn/module.hpp"
#include "infinicore/nn/rope.hpp"
#include "infinicore/tensor.hpp"
#include "infinicore/tensor.hpp"
#include <pybind11/numpy.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/pybind11.h>
...
@@ -69,7 +70,8 @@ inline void bind_llama(py::module &m) {
...
@@ -69,7 +70,8 @@ inline void bind_llama(py::module &m) {
.
def_readwrite
(
"pretraining_tp"
,
&
LlamaConfig
::
pretraining_tp
)
.
def_readwrite
(
"pretraining_tp"
,
&
LlamaConfig
::
pretraining_tp
)
.
def_readwrite
(
"name_or_path"
,
&
LlamaConfig
::
name_or_path
)
.
def_readwrite
(
"name_or_path"
,
&
LlamaConfig
::
name_or_path
)
.
def_readwrite
(
"pad_token_id"
,
&
LlamaConfig
::
pad_token_id
)
.
def_readwrite
(
"pad_token_id"
,
&
LlamaConfig
::
pad_token_id
)
.
def_property
(
"bos_token_id"
,
[](
const
LlamaConfig
&
self
)
{
.
def_property
(
"bos_token_id"
,
[](
const
LlamaConfig
&
self
)
{
// Always return as list to match Python config format
// Always return as list to match Python config format
return
py
::
cast
(
self
.
bos_token_id
);
},
[](
LlamaConfig
&
self
,
py
::
object
value
)
{
return
py
::
cast
(
self
.
bos_token_id
);
},
[](
LlamaConfig
&
self
,
py
::
object
value
)
{
// Accept both single int and list
// Accept both single int and list
...
@@ -80,7 +82,8 @@ inline void bind_llama(py::module &m) {
...
@@ -80,7 +82,8 @@ inline void bind_llama(py::module &m) {
}
else
{
}
else
{
throw
py
::
type_error
(
"bos_token_id must be int or list of ints"
);
throw
py
::
type_error
(
"bos_token_id must be int or list of ints"
);
}
})
}
})
.
def_property
(
"eos_token_id"
,
[](
const
LlamaConfig
&
self
)
{
.
def_property
(
"eos_token_id"
,
[](
const
LlamaConfig
&
self
)
{
// Always return as list to match Python config format
// Always return as list to match Python config format
return
py
::
cast
(
self
.
eos_token_id
);
},
[](
LlamaConfig
&
self
,
py
::
object
value
)
{
return
py
::
cast
(
self
.
eos_token_id
);
},
[](
LlamaConfig
&
self
,
py
::
object
value
)
{
// Accept both single int and list
// Accept both single int and list
...
@@ -91,6 +94,86 @@ inline void bind_llama(py::module &m) {
...
@@ -91,6 +94,86 @@ inline void bind_llama(py::module &m) {
}
else
{
}
else
{
throw
py
::
type_error
(
"eos_token_id must be int or list of ints"
);
throw
py
::
type_error
(
"eos_token_id must be int or list of ints"
);
}
})
}
})
.
def_property
(
"rope_scaling"
,
// ---------- getter ----------
[](
const
LlamaConfig
&
self
)
->
py
::
object
{
if
(
!
self
.
rope_scaling
)
{
return
py
::
none
();
}
using
ScalingConfig
=
infinicore
::
nn
::
RoPE
::
ScalingConfig
;
using
LongRopeConfig
=
infinicore
::
nn
::
RoPE
::
LongRopeConfig
;
py
::
dict
d
;
if
(
auto
*
lr
=
dynamic_cast
<
const
LongRopeConfig
*>
(
self
.
rope_scaling
.
get
()))
{
d
[
"type"
]
=
"longrope"
;
d
[
"rope_type"
]
=
"longrope"
;
d
[
"factor"
]
=
lr
->
factor
();
d
[
"original_max_position_embeddings"
]
=
lr
->
original_max_position_embeddings
();
d
[
"short_factor"
]
=
lr
->
short_factor
();
d
[
"long_factor"
]
=
lr
->
long_factor
();
}
else
{
throw
std
::
runtime_error
(
"Unknown RoPE scaling type"
);
}
return
std
::
move
(
d
);
},
// ---------- setter ----------
[](
LlamaConfig
&
self
,
py
::
object
value
)
{
if
(
value
.
is_none
())
{
self
.
rope_scaling
.
reset
();
return
;
}
if
(
!
py
::
isinstance
<
py
::
dict
>
(
value
))
{
throw
py
::
type_error
(
"rope_scaling must be a dict or None"
);
}
py
::
dict
d
=
value
.
cast
<
py
::
dict
>
();
auto
get_str
=
[
&
](
const
char
*
k
)
{
if
(
!
d
.
contains
(
k
))
{
throw
py
::
key_error
(
k
);
}
return
py
::
cast
<
std
::
string
>
(
d
[
k
]);
};
std
::
string
type
=
d
.
contains
(
"rope_type"
)
?
py
::
cast
<
std
::
string
>
(
d
[
"rope_type"
])
:
get_str
(
"type"
);
if
(
type
==
"longrope"
)
{
using
LongRopeConfig
=
infinicore
::
nn
::
RoPE
::
LongRopeConfig
;
if
(
!
d
.
contains
(
"short_factor"
)
||
!
d
.
contains
(
"long_factor"
)
||
!
d
.
contains
(
"original_max_position_embeddings"
))
{
throw
py
::
value_error
(
"longrope requires short_factor, long_factor, "
"original_max_position_embeddings"
);
}
std
::
vector
<
float
>
short_factor
=
py
::
cast
<
std
::
vector
<
float
>>
(
d
[
"short_factor"
]);
std
::
vector
<
float
>
long_factor
=
py
::
cast
<
std
::
vector
<
float
>>
(
d
[
"long_factor"
]);
size_t
original_max_position_embeddings
=
py
::
cast
<
size_t
>
(
d
[
"original_max_position_embeddings"
]);
float
factor
=
1.0
f
;
if
(
d
.
contains
(
"factor"
))
{
factor
=
py
::
cast
<
float
>
(
d
[
"factor"
]);
}
self
.
rope_scaling
=
std
::
make_shared
<
LongRopeConfig
>
(
std
::
move
(
short_factor
),
std
::
move
(
long_factor
),
original_max_position_embeddings
,
factor
);
}
else
{
throw
py
::
value_error
(
"Unsupported rope_scaling type: "
+
type
);
}
})
.
def
(
"validate"
,
&
LlamaConfig
::
validate
)
.
def
(
"validate"
,
&
LlamaConfig
::
validate
)
.
def
(
"kv_dim"
,
&
LlamaConfig
::
kv_dim
)
.
def
(
"kv_dim"
,
&
LlamaConfig
::
kv_dim
)
// Add __dir__ to make attributes discoverable via dir() in Python
// Add __dir__ to make attributes discoverable via dir() in Python
...
@@ -108,6 +191,7 @@ inline void bind_llama(py::module &m) {
...
@@ -108,6 +191,7 @@ inline void bind_llama(py::module &m) {
dir_list
.
append
(
"hidden_act"
);
dir_list
.
append
(
"hidden_act"
);
dir_list
.
append
(
"model_type"
);
dir_list
.
append
(
"model_type"
);
dir_list
.
append
(
"rope_theta"
);
dir_list
.
append
(
"rope_theta"
);
dir_list
.
append
(
"rope_scaling"
);
dir_list
.
append
(
"attention_bias"
);
dir_list
.
append
(
"attention_bias"
);
dir_list
.
append
(
"attention_output_bias"
);
dir_list
.
append
(
"attention_output_bias"
);
dir_list
.
append
(
"mlp_bias"
);
dir_list
.
append
(
"mlp_bias"
);
...
...
test/bench/test_benchmark.py
View file @
de3e6b95
...
@@ -368,7 +368,7 @@ def render_ceval(_tokenizer, conversation):
...
@@ -368,7 +368,7 @@ def render_ceval(_tokenizer, conversation):
def
render_mmlu
(
_tokenizer
,
question
,
choices
):
def
render_mmlu
(
_tokenizer
,
question
,
choices
):
"""Render MMLU question and choices to input content"""
"""Render MMLU question and choices to input content"""
choices_text
=
"
\n
"
.
join
(
choices_text
=
"
\n
"
.
join
(
[
f
"
{
chr
(
65
+
i
)
}
.
{
choice
}
"
for
i
,
choice
in
enumerate
(
choices
)]
[
f
"
{
chr
(
65
+
i
)
}
.
{
choice
}
"
for
i
,
choice
in
enumerate
(
choices
)]
)
)
instruction
=
(
instruction
=
(
"You are a multiple-choice question solver. "
"You are a multiple-choice question solver. "
...
@@ -924,7 +924,9 @@ def test():
...
@@ -924,7 +924,9 @@ def test():
splits_to_load
=
(
splits_to_load
=
(
[
"test"
]
[
"test"
]
if
split
==
"test"
if
split
==
"test"
else
[
"validation"
]
if
split
==
"val"
else
[
"validation"
,
"test"
]
else
[
"validation"
]
if
split
==
"val"
else
[
"validation"
,
"test"
]
)
)
# Load each subject individually from hardcoded list, excluding "all"
# Load each subject individually from hardcoded list, excluding "all"
for
subject_name
in
mmlu_subjects
:
for
subject_name
in
mmlu_subjects
:
...
@@ -946,7 +948,9 @@ def test():
...
@@ -946,7 +948,9 @@ def test():
splits_to_load
=
(
splits_to_load
=
(
[
"test"
]
[
"test"
]
if
split
==
"test"
if
split
==
"test"
else
[
"validation"
]
if
split
==
"val"
else
[
"validation"
,
"test"
]
else
[
"validation"
]
if
split
==
"val"
else
[
"validation"
,
"test"
]
)
)
records
=
[]
records
=
[]
for
sp
in
splits_to_load
:
for
sp
in
splits_to_load
:
...
@@ -980,14 +984,13 @@ def test():
...
@@ -980,14 +984,13 @@ def test():
all_results
=
[]
all_results
=
[]
for
subj
in
subject_list
:
for
subj
in
subject_list
:
print
(
f
"
\n
{
'='
*
60
}
"
)
print
(
f
"
\n
{
'='
*
60
}
"
)
print
(
f
"Evaluating subject:
{
subj
}
"
)
print
(
f
"Evaluating subject:
{
subj
}
"
)
print
(
f
"
{
'='
*
60
}
\n
"
)
print
(
f
"
{
'='
*
60
}
\n
"
)
try
:
try
:
samples
,
actual_subj_name
=
load_subject_samples
(
subj
)
samples
,
actual_subj_name
=
load_subject_samples
(
subj
)
print
(
f
"Loaded
{
len
(
samples
)
}
samples for subject:
{
actual_subj_name
}
"
)
print
(
f
"Loaded
{
len
(
samples
)
}
samples for subject:
{
actual_subj_name
}
"
)
# Limit number of samples if specified
# Limit number of samples if specified
if
num_samples
is
not
None
and
num_samples
>
0
:
if
num_samples
is
not
None
and
num_samples
>
0
:
original_count
=
len
(
samples
)
original_count
=
len
(
samples
)
...
@@ -996,37 +999,9 @@ def test():
...
@@ -996,37 +999,9 @@ def test():
f
"Limited to
{
len
(
samples
)
}
samples for validation (from
{
original_count
}
total)"
f
"Limited to
{
len
(
samples
)
}
samples for validation (from
{
original_count
}
total)"
)
)
# Test with first sample if available
if
len
(
samples
)
==
0
:
if
len
(
samples
)
>
0
:
print
(
f
"No samples found for subject:
{
actual_subj_name
}
"
)
sample
=
samples
[
0
]
continue
if
benchmark
==
"ceval"
:
input_content
=
f
"'question':
{
sample
[
'question'
]
}
,'A':
{
sample
[
'A'
]
}
, 'B':
{
sample
[
'B'
]
}
, 'C':
{
sample
[
'C'
]
}
,'D':
{
sample
[
'D'
]
}
。"
test_conversation
=
[
{
"role"
:
"system"
,
"content"
:
"请从question的A,B,C,D四个选项中选择正确的选项。例如,标准答案:A。"
,
},
{
"role"
:
"user"
,
"content"
:
input_content
},
]
test_output
=
model
.
generate
(
test_conversation
,
max_steps
=
max_new_tokens
,
topp_
=
1.0
,
topk_
=
1
,
temperature_
=
1.0
,
)
elif
benchmark
==
"mmlu"
:
question
=
sample
[
"question"
]
choices
=
sample
[
"choices"
]
test_output
=
model
.
generate
(
question
,
choices
,
max_steps
=
max_new_tokens
,
topp_
=
1.0
,
topk_
=
1
,
temperature_
=
1.0
,
)
print
(
f
"
\n
Test output:
{
test_output
}
\n
"
)
# Evaluate samples for this subject
# Evaluate samples for this subject
result
=
evaluate_samples
(
result
=
evaluate_samples
(
...
@@ -1044,13 +1019,22 @@ def test():
...
@@ -1044,13 +1019,22 @@ def test():
model
.
destroy_model_instance
()
model
.
destroy_model_instance
()
# Calculate overall results
# Calculate overall results
print
(
f
"
\n
{
'='
*
60
}
"
)
print
(
"OVERALL RESULTS"
)
print
(
f
"
{
'='
*
60
}
"
)
if
len
(
all_results
)
==
0
:
print
(
"No tests were run."
)
return
elif
len
(
all_results
)
>
1
:
for
r
in
all_results
:
print
(
f
"Subject '
{
r
[
'subject'
]
}
':
{
r
[
'correct'
]
}
/
{
r
[
'total'
]
}
=
{
r
[
'accuracy'
]:.
2
%
}
"
)
overall_correct
=
sum
(
r
[
"correct"
]
for
r
in
all_results
)
overall_correct
=
sum
(
r
[
"correct"
]
for
r
in
all_results
)
overall_total
=
sum
(
r
[
"total"
]
for
r
in
all_results
)
overall_total
=
sum
(
r
[
"total"
]
for
r
in
all_results
)
overall_accuracy
=
overall_correct
/
overall_total
if
overall_total
>
0
else
0.0
overall_accuracy
=
overall_correct
/
overall_total
if
overall_total
>
0
else
0.0
print
(
f
"
\n
{
'='
*
60
}
"
)
print
(
f
"
{
'='
*
60
}
"
)
print
(
"OVERALL RESULTS"
)
print
(
f
"
{
'='
*
60
}
"
)
if
benchmark
==
"ceval"
:
if
benchmark
==
"ceval"
:
print
(
print
(
f
"Overall 成绩:
{
overall_correct
}
/
{
overall_total
}
=
{
overall_accuracy
:.
2
%
}
"
f
"Overall 成绩:
{
overall_correct
}
/
{
overall_total
}
=
{
overall_accuracy
:.
2
%
}
"
...
@@ -1062,7 +1046,7 @@ def test():
...
@@ -1062,7 +1046,7 @@ def test():
print
(
f
"Total Latency:
{
TOTAL_TIME
}
seconds"
)
print
(
f
"Total Latency:
{
TOTAL_TIME
}
seconds"
)
print
(
f
"Total Tokens Processed:
{
TOTAL_TOKENS
}
tokens"
)
print
(
f
"Total Tokens Processed:
{
TOTAL_TOKENS
}
tokens"
)
print
(
f
"Overall Throughput:
{
TOTAL_TOKENS
/
TOTAL_TIME
:.
2
f
}
tokens/s"
)
print
(
f
"Overall Throughput:
{
TOTAL_TOKENS
/
TOTAL_TIME
:.
2
f
}
tokens/s"
)
# Write CSV if output path is specified
# Write CSV if output path is specified
if
output_csv
:
if
output_csv
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment