Commit e238ace2 authored by PanZezhong, committed by thatPepe

issue/709 nn::RoPE forward supports in-place; add unsqueeze

parent b6f8f8c3
#pragma once
-#include "module.hpp"
#include "../context/context.hpp"
#include "../tensor.hpp"
+#include "module.hpp"
#include <memory>

namespace infinicore::nn {
@@ -39,6 +39,7 @@ public:
 *
 * @param x Input tensor of shape (..., head_dim) where ... is any number of dimensions
 * @param pos Position IDs tensor of shape (*,) typically [seq_len] or [batch, seq_len]
+ * @param in_place If true, modify input tensor in place (default: false)
 * @return Rotated tensor with same shape as input
 *
 * Applies rotary position embeddings to the input tensor.
@@ -49,7 +50,7 @@ public:
 * - [batch, seq_len, num_heads, head_dim]
 * - [seq_len, head_dim]
 */
-Tensor forward(const Tensor &x, const Tensor &pos) const;
+Tensor forward(const Tensor &x, const Tensor &pos, bool in_place = false) const;
// Module information
size_t head_dim() const { return head_dim_; }
@@ -69,11 +70,11 @@ protected:
private:
    void initialize_cache();

    size_t head_dim_;    // Dimension of each attention head
    size_t max_seq_len_; // Maximum sequence length
    double theta_;       // Base frequency for rotary embeddings
    Algo algo_;          // RoPE algorithm type
    DataType dtype_;     // Data type for cache tables
};

} // namespace infinicore::nn
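A minimal usage sketch of the new flag. The RoPE constructor arguments and the tensors q and pos below are illustrative assumptions; only the forward signature comes from this commit.

    // q: [seq_len, num_heads, head_dim], pos: [seq_len]  (hypothetical tensors)
    infinicore::nn::RoPE rope(/* head_dim, max_seq_len, theta, ... -- ctor not shown in this diff */);

    // Default path: allocates and returns a new rotated tensor; q is left untouched.
    Tensor q_rot = rope.forward(q, pos);

    // In-place path: the rotation is written back into q's storage and the
    // returned handle aliases that same buffer.
    Tensor q_same = rope.forward(q, pos, /*in_place=*/true);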
@@ -166,6 +166,21 @@ public:
/// View APIs
///
+/**
+ * Returns a new tensor with a dimension of size one inserted at the specified position.
+ * The returned tensor shares the same underlying storage with the original tensor.
+ *
+ * @param dim The dimension index at which to insert the new dimension
+ * @return A new tensor with the added dimension
+ *
+ * Example:
+ * // For a 2D tensor with shape [3, 4], unsqueeze at dim 0 results in shape [1, 3, 4]
+ * // unsqueeze at dim 1 results in shape [3, 1, 4]
+ * // unsqueeze at dim 2 results in shape [3, 4, 1]
+ * tensor->unsqueeze(0);
+ */
+Tensor unsqueeze(size_t dim) const;
/**
 * Returns a new tensor that is a narrowed version of the current tensor.
 * The returned tensor shares the same underlying storage with the original tensor.
...
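A hedged sketch of how the new unsqueeze can be used, for example to lift a 1-D position tensor to a batched layout; only the declaration above is from this commit, the surrounding calls are assumptions.

    // pos: shape [seq_len]
    Tensor pos_b = pos->unsqueeze(0);   // shape [1, seq_len], shares pos's storage
    Tensor pos_c = pos_b->unsqueeze(2); // shape [1, seq_len, 1], still no copy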
@@ -112,9 +112,13 @@ void RoPE::initialize_cache() {
    }
}
-Tensor RoPE::forward(const Tensor &x, const Tensor &pos) const {
-    // Delegate to InfiniCore op (backed by InfiniRT/InfiniOP)
-    // Validation is handled by the op layer
+Tensor RoPE::forward(const Tensor &x, const Tensor &pos, bool in_place) const {
+    if (in_place) {
+        Tensor y = Tensor(x);
+        op::rope_(y, x, pos, sin_cache_, cos_cache_, algo_);
+        return y;
+    }
    return op::rope(x, pos, sin_cache_, cos_cache_, algo_);
}
...
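A note on the in-place branch above, offered as a hedged reading rather than documented behavior: Tensor y = Tensor(x) appears to create a second handle onto x's existing storage, and op::rope_ (the in-place counterpart of op::rope) writes the rotated values into that storage, so the caller's input is overwritten and the returned tensor aliases it. A sketch, reusing the hypothetical rope, x, and pos from earlier:

    Tensor out   = rope.forward(x, pos);       // new allocation, x preserved
    Tensor alias = rope.forward(x, pos, true); // x's buffer now holds the rotated values
    // `alias` and `x` refer to the same underlying storage after the second call.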
@@ -5,6 +5,21 @@
#include <spdlog/spdlog.h>

namespace infinicore {
+Tensor TensorImpl::unsqueeze(size_t dim) const {
+    // Create new shape with dimension of size one inserted at dim
+    Shape new_shape = meta_.shape;
+    new_shape.insert(new_shape.begin() + dim, 1);
+
+    // Create new strides with stride of zero for the new dimension
+    Strides new_strides = meta_.strides;
+    new_strides.insert(new_strides.begin() + dim, 0);
+
+    auto tensor_impl = std::make_shared<TensorImpl>(new_shape, new_strides, meta_.dtype);
+    tensor_impl->data_ = data_;
+    return Tensor(tensor_impl);
+}
Tensor TensorImpl::narrow(const std::vector<TensorSliceParams> &slices) const {
    // Create new shape and calculate offset
    Shape new_shape = meta_.shape;
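A worked example of what the unsqueeze implementation above produces, assuming element-count strides on a contiguous tensor; the new size-one axis gets stride 0, which is harmless because that axis is never stepped over:

    // contiguous tensor t: shape [3, 4], strides [4, 1]
    // t->unsqueeze(0)  ->  shape [1, 3, 4], strides [0, 4, 1]
    // t->unsqueeze(1)  ->  shape [3, 1, 4], strides [4, 0, 1]
    // t->unsqueeze(2)  ->  shape [3, 4, 1], strides [4, 1, 0]
    // The new TensorImpl reuses data_, so no element is copied.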
@@ -95,12 +110,16 @@ Tensor TensorImpl::view(const Shape &new_shape) const {
    for (size_t i = 0; i < new_shape.size(); ++i) {
        // Find which merged dimension contains this new dimension
        while (new_shape[i] > remaining_size) {
-            assert(++merged_idx < merged_shape.size());
+            if (++merged_idx >= merged_shape.size()) {
+                throw std::runtime_error("Incompatible shape for view operation.");
+            }
            current_stride = merged_strides[merged_idx];
            remaining_size = merged_shape[merged_idx];
        }
-        assert(remaining_size % new_shape[i] == 0);
+        if (remaining_size % new_shape[i] != 0) {
+            throw std::runtime_error("Incompatible shape for view operation.");
+        }
        new_strides[i] = current_stride * (remaining_size / new_shape[i]);
        remaining_size /= new_shape[i];
...
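With the change above, an incompatible view now surfaces as a catchable exception in release builds as well, instead of an assert that disappears under NDEBUG. A hedged handling sketch (the tensor t and the shape bad_shape are hypothetical, not part of this commit):

    try {
        Tensor flat = t->view(bad_shape); // shape not reachable from t's layout
    } catch (const std::runtime_error &e) {
        // e.what(): "Incompatible shape for view operation."
        // Fall back: materialize a contiguous copy first, or report the error.
    }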