Commit 06dcc067 authored by PanZezhong

issue/920 RoPE supports longrope

parent 180674dc
...@@ -17,6 +17,47 @@ public:
GPT_NEOX = 1, // GPT-NeoX style RoPE algorithm (First half dimensions for sin, second half for cos)
};
enum class ScalingType {
DEFAULT = 0, // Default RoPE
LONGROPE = 1 // Long-RoPE
};
class ScalingConfig {
public:
virtual ~ScalingConfig() = default;
ScalingType type() const { return type_; }
protected:
ScalingType type_ = ScalingType::DEFAULT;
ScalingConfig(ScalingType type) : type_(type) {}
};
// longrope scaling
class LongRopeConfig : public ScalingConfig {
protected:
std::vector<float> short_factor_;
std::vector<float> long_factor_;
size_t original_max_position_embeddings_;
float factor_;
public:
LongRopeConfig(
std::vector<float> short_factor,
std::vector<float> long_factor,
size_t original_max_position_embeddings,
float factor = 1.0f)
: ScalingConfig(ScalingType::LONGROPE),
short_factor_(short_factor),
long_factor_(long_factor),
original_max_position_embeddings_(original_max_position_embeddings),
factor_(factor == 1.0f ? 1.0f : std::sqrt(1 + std::log(factor) / std::log(original_max_position_embeddings))) {}
~LongRopeConfig() override = default;
size_t original_max_position_embeddings() const { return original_max_position_embeddings_; }
const std::vector<float> &short_factor() const { return short_factor_; }
const std::vector<float> &long_factor() const { return long_factor_; }
float factor() const { return factor_; }
};
/**
* @brief Construct a RoPE layer
*
...@@ -26,13 +67,15 @@ public:
* @param algo RoPE algorithm type (default: Algo::GPT_J)
* @param dtype Data type for sin/cos cache (default: DataType::F32)
* @param device Device to create the cache on
* @param scaling Optional RoPE scaling configuration (default: nullptr, i.e. no scaling)
*/
RoPE(size_t head_dim,
size_t max_seq_len,
double theta = 10000.0,
Algo algo = Algo::GPT_J,
const DataType &dtype = DataType::F32,
const Device &device = Device(),
std::shared_ptr<ScalingConfig> scaling = nullptr);
/**
* @brief Forward pass: apply RoPE to a tensor
...@@ -93,6 +136,7 @@ private:
double theta_; // Base frequency for rotary embeddings
Algo algo_; // RoPE algorithm type
DataType dtype_; // Data type for cache tables
std::shared_ptr<ScalingConfig> scaling_; // RoPE scaling configuration (nullptr = default RoPE)
};
} // namespace infinicore::nn
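
For reference, a minimal usage sketch of the extended constructor (not part of this commit). It assumes ScalingConfig / LongRopeConfig are nested inside the RoPE class, as the hunk context suggests, and that DataType and Device are visible via the infinicore::nn namespace; the head dimension, sequence lengths, and factor values are made up for illustration. Each factor vector is expected to hold one entry per rotary dimension (head_dim / 2), matching the per-dimension indexing in initialize_cache().

#include <memory>
#include <vector>

using namespace infinicore::nn;

void example_longrope() {
    constexpr size_t head_dim = 64;                        // hypothetical head dimension
    std::vector<float> short_factor(head_dim / 2, 1.0f);   // used while pos < original_max_position_embeddings
    std::vector<float> long_factor(head_dim / 2, 4.0f);    // used for extended positions

    auto scaling = std::make_shared<RoPE::LongRopeConfig>(
        short_factor,
        long_factor,
        /*original_max_position_embeddings=*/4096,
        /*factor=*/32.0f);                                  // context extension ratio

    RoPE rope(head_dim,
              /*max_seq_len=*/131072,
              /*theta=*/10000.0,
              RoPE::Algo::GPT_J,
              DataType::F32,
              Device(),
              scaling);
    // rope.forward(...) would now apply LongRoPE-scaled rotary embeddings.
}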
...@@ -16,12 +16,14 @@ RoPE::RoPE(size_t head_dim,
double theta,
Algo algo,
const DataType &dtype,
const Device &device,
std::shared_ptr<ScalingConfig> scaling)
: head_dim_(head_dim),
max_seq_len_(max_seq_len),
theta_(theta),
algo_(algo),
dtype_(dtype),
scaling_(scaling) {
if (head_dim % 2 != 0) {
throw std::invalid_argument("head_dim must be even for RoPE, got " + std::to_string(head_dim));
}
...@@ -54,14 +56,30 @@ void RoPE::initialize_cache() {
for (size_t j = 0; j < cache_dim; j++) {
// GPT-J style inverse frequency: theta^(-2j/head_dim)
// Compute directly in float to avoid double->float casting
float inv_freq;
float table_factor = 1.0f;
if (scaling_ == nullptr) {
inv_freq = 1.0f / std::pow(static_cast<float>(theta_), 2.0f * static_cast<float>(j) / static_cast<float>(head_dim_));
} else if (scaling_->type() == ScalingType::LONGROPE) {
std::shared_ptr<LongRopeConfig> lr = std::dynamic_pointer_cast<LongRopeConfig>(scaling_);
table_factor = lr->factor();
float _ext;
if (pos < lr->original_max_position_embeddings()) {
_ext = lr->short_factor()[j];
} else {
_ext = lr->long_factor()[j];
}
inv_freq = 1.0f / (_ext * std::pow(static_cast<float>(theta_), 2.0f * static_cast<float>(j) / static_cast<float>(head_dim_)));
} else {
inv_freq = 1.0f / std::pow(static_cast<float>(theta_), 2.0f * static_cast<float>(j) / static_cast<float>(head_dim_));
}
// Compute angle: position * inverse_frequency
float angle = static_cast<float>(pos) * inv_freq;
// Compute sin and cos directly on float
sin_data[pos * cache_dim + j] = std::sin(angle) * table_factor;
cos_data[pos * cache_dim + j] = std::cos(angle) * table_factor;
}
}
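
Note on the scaling math in this hunk: for each rotary dimension j, the LongRoPE path divides the base inverse frequency theta^(-2j/head_dim) by the per-dimension short factor when pos < original_max_position_embeddings and by the long factor otherwise, and every cached sin/cos entry is additionally multiplied by the attention scaling factor computed in LongRopeConfig's constructor, sqrt(1 + ln(factor) / ln(original_max_position_embeddings)) (or 1.0 when factor == 1). With the hypothetical values from the sketch above (factor = 32, original_max_position_embeddings = 4096), that scale would be sqrt(1 + ln 32 / ln 4096) ≈ sqrt(1.417) ≈ 1.19.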