Commit e238ace2 authored by PanZezhong's avatar PanZezhong Committed by thatPepe
Browse files

issue/709 nn::RoPE forward支持inplace,添加unsqueeze

parent b6f8f8c3
#pragma once
#include "module.hpp"
#include "../context/context.hpp"
#include "../tensor.hpp"
#include "module.hpp"
#include <memory>
namespace infinicore::nn {
......@@ -39,6 +39,7 @@ public:
*
* @param x Input tensor of shape (..., head_dim) where ... is any number of dimensions
* @param pos Position IDs tensor of shape (*,) typically [seq_len] or [batch, seq_len]
* @param in_place If true, modify input tensor in place (default: false)
* @return Rotated tensor with same shape as input
*
* Applies rotary position embeddings to the input tensor.
......@@ -49,7 +50,7 @@ public:
* - [batch, seq_len, num_heads, head_dim]
* - [seq_len, head_dim]
*/
Tensor forward(const Tensor &x, const Tensor &pos) const;
Tensor forward(const Tensor &x, const Tensor &pos, bool in_place = false) const;
// Module information
// Read-only accessor: dimension of each attention head (head_dim_).
size_t head_dim() const { return head_dim_; }
......@@ -69,11 +70,11 @@ protected:
private:
void initialize_cache();
size_t head_dim_; // Dimension of each attention head
size_t max_seq_len_; // Maximum sequence length
double theta_; // Base frequency for rotary embeddings
Algo algo_; // RoPE algorithm type
DataType dtype_; // Data type for cache tables
size_t head_dim_; // Dimension of each attention head
size_t max_seq_len_; // Maximum sequence length
double theta_; // Base frequency for rotary embeddings
Algo algo_; // RoPE algorithm type
DataType dtype_; // Data type for cache tables
};
} // namespace infinicore::nn
......@@ -166,6 +166,21 @@ public:
/// View APIs
///
/**
* Returns a new tensor with a dimension of size one inserted at the specified position.
* The returned tensor shares the same underlying storage with the original tensor.
*
* @param dim The dimension index at which to insert the new dimension
* @return A new tensor with the added dimension
*
* Example:
* // For a 2D tensor with shape [3, 4], unsqueeze at dim 0 results in shape [1, 3, 4]
* // unsqueeze at dim 1 results in shape [3, 1, 4]
* // unsqueeze at dim 2 results in shape [3, 4, 1]
* tensor->unsqueeze(0);
*/
Tensor unsqueeze(size_t dim) const;
/**
* Returns a new tensor that is a narrowed version of the current tensor.
* The returned tensor shares the same underlying storage with the original tensor.
......
......@@ -112,9 +112,13 @@ void RoPE::initialize_cache() {
}
}
/**
 * Apply rotary position embeddings to `x` at positions `pos`.
 *
 * @param x        Input tensor of shape (..., head_dim)
 * @param pos      Position IDs tensor, typically [seq_len] or [batch, seq_len]
 * @param in_place If true, rotate `x`'s storage in place (default: false)
 * @return The rotated tensor; when `in_place` is true this aliases `x`'s storage
 *
 * Delegates to the InfiniCore op layer (backed by InfiniRT/InfiniOP);
 * shape/dtype validation is handled there, not here.
 */
Tensor RoPE::forward(const Tensor &x, const Tensor &pos, bool in_place) const {
    if (in_place) {
        // Tensor is a shared handle, so `y` aliases `x`'s storage; rope_
        // writes its result into `y`, mutating `x` as well.
        // NOTE(review): presumably op::rope_ supports aliased in/out
        // buffers — confirm with the op layer.
        Tensor y = Tensor(x);
        op::rope_(y, x, pos, sin_cache_, cos_cache_, algo_);
        return y;
    }
    // Out-of-place path: op::rope allocates and returns a fresh tensor.
    return op::rope(x, pos, sin_cache_, cos_cache_, algo_);
}
......
......@@ -5,6 +5,21 @@
#include <spdlog/spdlog.h>
namespace infinicore {
/**
 * Return a view with a dimension of size one inserted at index `dim`.
 * The returned tensor shares the same underlying storage as this tensor.
 *
 * @param dim Insertion index; must satisfy 0 <= dim <= ndim
 * @return New tensor view with the extra size-1 dimension
 * @throws std::runtime_error if `dim` is out of range
 */
Tensor TensorImpl::unsqueeze(size_t dim) const {
    const size_t ndim = meta_.shape.size();
    // Guard the insertion point: inserting past end() would be UB.
    if (dim > ndim) {
        throw std::runtime_error("Dimension out of range for unsqueeze operation.");
    }
    // Insert the new size-1 dimension into the shape.
    Shape new_shape = meta_.shape;
    new_shape.insert(new_shape.begin() + dim, 1);
    // A size-1 dimension never contributes to addressing, but give it the
    // conventional stride (elements spanned by the dimension it displaces,
    // or 1 when appended last) rather than 0, so that downstream
    // contiguity-based logic (e.g. view's dim merging) keeps working.
    using stride_t = Strides::value_type;
    const stride_t new_dim_stride =
        (dim == ndim) ? stride_t{1}
                      : static_cast<stride_t>(meta_.shape[dim]) * meta_.strides[dim];
    Strides new_strides = meta_.strides;
    new_strides.insert(new_strides.begin() + dim, new_dim_stride);
    auto tensor_impl = std::make_shared<TensorImpl>(new_shape, new_strides, meta_.dtype);
    tensor_impl->data_ = data_; // share storage; offset is carried by data_ unchanged
    return Tensor(tensor_impl);
}
Tensor TensorImpl::narrow(const std::vector<TensorSliceParams> &slices) const {
// Create new shape and calculate offset
Shape new_shape = meta_.shape;
......@@ -95,12 +110,16 @@ Tensor TensorImpl::view(const Shape &new_shape) const {
for (size_t i = 0; i < new_shape.size(); ++i) {
// Find which merged dimension contains this new dimension
while (new_shape[i] > remaining_size) {
assert(++merged_idx < merged_shape.size());
if (++merged_idx >= merged_shape.size()) {
throw std::runtime_error("Incompatible shape for view operation.");
}
current_stride = merged_strides[merged_idx];
remaining_size = merged_shape[merged_idx];
}
assert(remaining_size % new_shape[i] == 0);
if (remaining_size % new_shape[i] != 0) {
throw std::runtime_error("Incompatible shape for view operation.");
};
new_strides[i] = current_stride * (remaining_size / new_shape[i]);
remaining_size /= new_shape[i];
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment