jerrrrry / infinicore · Commit 06dcc067
Authored Jan 14, 2026 by PanZezhong
Parent: 180674dc

issue/920 RoPE supports longrope

Showing 2 changed files with 73 additions and 11 deletions:
  include/infinicore/nn/rope.hpp  (+50, -6)
  src/infinicore/nn/rope.cc       (+23, -5)
include/infinicore/nn/rope.hpp
@@ -17,6 +17,47 @@ public:
         GPT_NEOX = 1, // GPT-NeoX style RoPE algorithm (First half dimensions for sin, second half for cos)
     };
 
+    enum class ScalingType {
+        DEFAULT = 0,  // Default RoPE
+        LONGROPE = 1  // Long-RoPE
+    };
+
+    class ScalingConfig {
+    public:
+        virtual ~ScalingConfig() = default;
+        ScalingType type() const { return type_; }
+
+    protected:
+        ScalingType type_ = ScalingType::DEFAULT;
+        ScalingConfig(ScalingType type) : type_(type) {}
+    };
+
+    // longrope scaling
+    class LongRopeConfig : public ScalingConfig {
+    protected:
+        std::vector<float> short_factor_;
+        std::vector<float> long_factor_;
+        size_t original_max_position_embeddings_;
+        float factor_;
+
+    public:
+        LongRopeConfig(std::vector<float> short_factor,
+                       std::vector<float> long_factor,
+                       size_t original_max_position_embeddings,
+                       float factor = 1.0f)
+            : ScalingConfig(ScalingType::LONGROPE),
+              short_factor_(short_factor),
+              long_factor_(long_factor),
+              original_max_position_embeddings_(original_max_position_embeddings),
+              factor_(factor == 1.0f ? 1.0f : std::sqrt(1 + std::log(factor) / std::log(original_max_position_embeddings))) {}
+        ~LongRopeConfig() override = default;
+
+        size_t original_max_position_embeddings() const { return original_max_position_embeddings_; }
+        const std::vector<float> &short_factor() const { return short_factor_; }
+        const std::vector<float> &long_factor() const { return long_factor_; }
+        float factor() const { return factor_; }
+    };
 
     /**
      * @brief Construct a RoPE layer
      *
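Note: the factor_ member stored by LongRopeConfig above is the LongRoPE attention-scaling term: when the context is extended by factor, the cached sin/cos tables are later rescaled by sqrt(1 + ln(factor) / ln(original_max_position_embeddings)). A minimal sketch of that arithmetic, using hypothetical numbers (a 4096-token original context extended 32x) that are not taken from this commit:

#include <cmath>
#include <cstdio>

int main() {
    // Hypothetical values for illustration only: 4096-token original context
    // extended to 131072 tokens, i.e. factor = 131072 / 4096 = 32.
    float factor = 32.0f;
    float original_max_position_embeddings = 4096.0f;

    // Same expression as the LongRopeConfig constructor above.
    float attention_factor =
        std::sqrt(1.0f + std::log(factor) / std::log(original_max_position_embeddings));

    std::printf("attention factor = %.4f\n", attention_factor); // ~1.19
    return 0;
}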
@@ -26,13 +67,15 @@ public:
      * @param algo RoPE algorithm type (default: Algo::GPT_J)
      * @param dtype Data type for sin/cos cache (default: DataType::F32)
      * @param device Device to create the cache on
+     * @param scaling RoPE scaling type (default: nullptr)
      */
     RoPE(size_t head_dim,
          size_t max_seq_len,
          double theta = 10000.0,
          Algo algo = Algo::GPT_J,
          const DataType &dtype = DataType::F32,
-         const Device &device = Device());
+         const Device &device = Device(),
+         std::shared_ptr<ScalingConfig> scaling = nullptr);
 
     /**
      * @brief Forward pass: apply RoPE to a tensor
@@ -88,11 +131,12 @@ protected:
 private:
     void initialize_cache();
 
     size_t head_dim_;    // Dimension of each attention head
     size_t max_seq_len_; // Maximum sequence length
     double theta_;       // Base frequency for rotary embeddings
     Algo algo_;          // RoPE algorithm type
     DataType dtype_;     // Data type for cache tables
+    std::shared_ptr<ScalingConfig> scaling_; // RoPE scaling type
 };
 
 } // namespace infinicore::nn
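For context, a hypothetical caller-side sketch of the new scaling parameter. The helper name, the include path, and every numeric value are placeholders rather than part of this commit, and it assumes the scaling types are nested inside RoPE as the "public:" hunk context above suggests:

#include <memory>
#include <vector>

#include "infinicore/nn/rope.hpp" // assuming include/ is on the include path

using namespace infinicore;

// Hypothetical helper, not part of this commit.
void build_longrope_layer() {
    const size_t head_dim = 128;
    const size_t max_seq_len = 131072;

    // One rescale entry per cached frequency (head_dim / 2 entries);
    // real models ship their own short/long factor lists.
    std::vector<float> short_factor(head_dim / 2, 1.0f);
    std::vector<float> long_factor(head_dim / 2, 4.0f);

    auto scaling = std::make_shared<nn::RoPE::LongRopeConfig>(
        short_factor,
        long_factor,
        /*original_max_position_embeddings=*/4096,
        /*factor=*/32.0f);

    // Passing nullptr (the default) keeps the pre-existing, unscaled behaviour.
    nn::RoPE rope(head_dim,
                  max_seq_len,
                  /*theta=*/10000.0,
                  nn::RoPE::Algo::GPT_J,
                  DataType::F32,
                  Device(),
                  scaling);
}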
src/infinicore/nn/rope.cc
@@ -16,12 +16,14 @@ RoPE::RoPE(size_t head_dim,
            double theta,
            Algo algo,
            const DataType &dtype,
-           const Device &device)
+           const Device &device,
+           std::shared_ptr<ScalingConfig> scaling)
     : head_dim_(head_dim),
       max_seq_len_(max_seq_len),
       theta_(theta),
       algo_(algo),
-      dtype_(dtype) {
+      dtype_(dtype),
+      scaling_(scaling) {
     if (head_dim % 2 != 0) {
         throw std::invalid_argument("head_dim must be even for RoPE, got " + std::to_string(head_dim));
     }
@@ -54,14 +56,30 @@ void RoPE::initialize_cache() {
         for (size_t j = 0; j < cache_dim; j++) {
             // GPT-J style inverse frequency: theta^(-2j/head_dim)
             // Compute directly in float to avoid double->float casting
-            float inv_freq = 1.0f / std::pow(static_cast<float>(theta_), 2.0f * static_cast<float>(j) / static_cast<float>(head_dim_));
+            float inv_freq;
+            float table_factor = 1.0f;
+            if (scaling_ == nullptr) {
+                inv_freq = 1.0f / std::pow(static_cast<float>(theta_), 2.0f * static_cast<float>(j) / static_cast<float>(head_dim_));
+            } else if (scaling_->type() == ScalingType::LONGROPE) {
+                std::shared_ptr<LongRopeConfig> lr = std::dynamic_pointer_cast<LongRopeConfig>(scaling_);
+                table_factor = lr->factor();
+                float _ext;
+                if (pos < lr->original_max_position_embeddings()) {
+                    _ext = lr->short_factor()[j];
+                } else {
+                    _ext = lr->long_factor()[j];
+                }
+                inv_freq = 1.0f / (_ext * std::pow(static_cast<float>(theta_), 2.0f * static_cast<float>(j) / static_cast<float>(head_dim_)));
+            } else {
+                inv_freq = 1.0f / std::pow(static_cast<float>(theta_), 2.0f * static_cast<float>(j) / static_cast<float>(head_dim_));
+            }
 
             // Compute angle: position * inverse_frequency
             float angle = static_cast<float>(pos) * inv_freq;
 
             // Compute sin and cos directly on float
-            sin_data[pos * cache_dim + j] = std::sin(angle);
-            cos_data[pos * cache_dim + j] = std::cos(angle);
+            sin_data[pos * cache_dim + j] = std::sin(angle) * table_factor;
+            cos_data[pos * cache_dim + j] = std::cos(angle) * table_factor;
         }
     }
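To make the cache update above concrete: for a position beyond original_max_position_embeddings the frequency is stretched by the per-dimension long_factor, and the resulting sin/cos entries are rescaled by the attention factor. A standalone sketch of one such entry, with all numbers chosen purely for illustration (none come from this commit):

#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    // Illustrative configuration (placeholders, not from the commit).
    const float theta = 10000.0f;
    const size_t head_dim = 128;
    const size_t original_max = 4096; // original_max_position_embeddings
    const float attention_factor =
        std::sqrt(1.0f + std::log(32.0f) / std::log(4096.0f)); // lr->factor() for factor = 32
    const std::vector<float> short_factor(head_dim / 2, 1.0f);
    const std::vector<float> long_factor(head_dim / 2, 4.0f);

    const size_t pos = 8192; // beyond the original context, so long_factor applies
    const size_t j = 4;      // cached frequency index in [0, head_dim / 2)

    // Mirror of the LONGROPE branch in initialize_cache().
    float ext = (pos < original_max) ? short_factor[j] : long_factor[j];
    float inv_freq = 1.0f / (ext * std::pow(theta, 2.0f * static_cast<float>(j)
                                                       / static_cast<float>(head_dim)));
    float angle = static_cast<float>(pos) * inv_freq;

    std::printf("sin entry = %f, cos entry = %f\n",
                std::sin(angle) * attention_factor,
                std::cos(angle) * attention_factor);
    return 0;
}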