Commit 8e25feb0 authored by PanZezhong's avatar PanZezhong
Browse files

issue/821 添加squeeze算子,完善unsqueeze算子测试

parent 215d1932
...@@ -168,6 +168,19 @@ public: ...@@ -168,6 +168,19 @@ public:
/// View APIs /// View APIs
/// ///
/**
* Returns a new tensor with a dimension of size one removed at the specified position.
* Throws runtime_error if the dimension to be removed is not of size 1.
*
* @param dim The dimension index to remove
* @return A new tensor with the removed dimension
*
* Example:
* // For a 3D tensor with shape [1, 3, 4], squeeze at dim 0 results in shape [3, 4]
* tensor->squeeze(0);
*/
Tensor squeeze(size_t dim) const;
/** /**
* Returns a new tensor with a dimension of size one inserted at the specified position. * Returns a new tensor with a dimension of size one inserted at the specified position.
* The returned tensor shares the same underlying storage with the original tensor. * The returned tensor shares the same underlying storage with the original tensor.
......
...@@ -45,6 +45,8 @@ from infinicore.ops.matmul import matmul ...@@ -45,6 +45,8 @@ from infinicore.ops.matmul import matmul
from infinicore.ops.mul import mul from infinicore.ops.mul import mul
from infinicore.ops.narrow import narrow from infinicore.ops.narrow import narrow
from infinicore.ops.rearrange import rearrange from infinicore.ops.rearrange import rearrange
from infinicore.ops.squeeze import squeeze
from infinicore.ops.unsqueeze import unsqueeze
from infinicore.tensor import ( from infinicore.tensor import (
Tensor, Tensor,
empty, empty,
...@@ -104,6 +106,8 @@ __all__ = [ ...@@ -104,6 +106,8 @@ __all__ = [
"matmul", "matmul",
"mul", "mul",
"narrow", "narrow",
"squeeze",
"unsqueeze",
"rearrange", "rearrange",
"empty", "empty",
"empty_like", "empty_like",
......
from infinicore.tensor import Tensor
def squeeze(input: Tensor, dim: int) -> Tensor:
    """Return a new Tensor with the size-one dimension ``dim`` removed.

    Thin wrapper: delegates to the underlying C++ tensor's ``squeeze`` and
    wraps the result back into a Python ``Tensor``.
    """
    squeezed = input._underlying.squeeze(dim)
    return Tensor(squeezed)
from infinicore.tensor import Tensor
def unsqueeze(input: Tensor, dim: int) -> Tensor:
    """Return a new Tensor with a size-one dimension inserted at ``dim``.

    Thin wrapper: delegates to the underlying C++ tensor's ``unsqueeze`` and
    wraps the result back into a Python ``Tensor``.
    """
    expanded = input._underlying.unsqueeze(dim)
    return Tensor(expanded)
...@@ -92,6 +92,12 @@ class Tensor: ...@@ -92,6 +92,12 @@ class Tensor:
def view(self, shape): def view(self, shape):
return Tensor(self._underlying.view(shape)) return Tensor(self._underlying.view(shape))
def squeeze(self, dim):
    """Return a new Tensor with size-one dimension ``dim`` removed.

    Delegates to the module-level ``infinicore.squeeze``.
    """
    return infinicore.squeeze(self, dim)
def unsqueeze(self, dim):
    """Return a new Tensor with a size-one dimension inserted at ``dim``.

    Delegates to the module-level ``infinicore.unsqueeze``.
    """
    return infinicore.unsqueeze(self, dim)
def debug(self, filename=None): def debug(self, filename=None):
"""Print tensor data or save to file for debugging """Print tensor data or save to file for debugging
......
...@@ -16,25 +16,27 @@ inline void bind(py::module &m) { ...@@ -16,25 +16,27 @@ inline void bind(py::module &m) {
.def_property_readonly("ndim", [](const Tensor &tensor) { return tensor->ndim(); }) .def_property_readonly("ndim", [](const Tensor &tensor) { return tensor->ndim(); })
.def_property_readonly("dtype", [](const Tensor &tensor) { return tensor->dtype(); }) .def_property_readonly("dtype", [](const Tensor &tensor) { return tensor->dtype(); })
.def_property_readonly("device", [](const Tensor &tensor) { return tensor->device(); }) .def_property_readonly("device", [](const Tensor &tensor) { return tensor->device(); })
.def("data_ptr", [](const Tensor &tensor) { return reinterpret_cast<std::uintptr_t>(tensor->data()); }) .def("data_ptr", [](const Tensor &tensor) { return reinterpret_cast<std::uintptr_t>(tensor->data()); })
.def("size", [](const Tensor &tensor, std::size_t dim) { return tensor->size(dim); }) .def("size", [](const Tensor &tensor, std::size_t dim) { return tensor->size(dim); })
.def("stride", [](const Tensor &tensor, std::size_t dim) { return tensor->stride(dim); }) .def("stride", [](const Tensor &tensor, std::size_t dim) { return tensor->stride(dim); })
.def("numel", [](const Tensor &tensor) { return tensor->numel(); }) .def("numel", [](const Tensor &tensor) { return tensor->numel(); })
.def("is_contiguous", [](const Tensor &tensor) { return tensor->is_contiguous(); }) .def("is_contiguous", [](const Tensor &tensor) { return tensor->is_contiguous(); })
.def("is_pinned", [](const Tensor &tensor) { return tensor->is_pinned(); }) .def("is_pinned", [](const Tensor &tensor) { return tensor->is_pinned(); })
.def("info", [](const Tensor &tensor) { return tensor->info(); }) .def("info", [](const Tensor &tensor) { return tensor->info(); })
.def("debug", [](const Tensor &tensor) { return tensor->debug(); }) .def("debug", [](const Tensor &tensor) { return tensor->debug(); })
.def("debug", [](const Tensor &tensor, const std::string &filename) { return tensor->debug(filename); }) .def("debug", [](const Tensor &tensor, const std::string &filename) { return tensor->debug(filename); })
.def("copy_", [](Tensor &tensor, const Tensor &other) { tensor->copy_from(other); }) .def("copy_", [](Tensor &tensor, const Tensor &other) { tensor->copy_from(other); })
.def("to", [](const Tensor &tensor, const Device &device) { return tensor->to(device); }) .def("to", [](const Tensor &tensor, const Device &device) { return tensor->to(device); })
.def("as_strided", [](const Tensor &tensor, const Shape &shape, const Strides &strides) { return tensor->as_strided(shape, strides); })
.def("contiguous", [](const Tensor &tensor) { return tensor->contiguous(); }) .def("contiguous", [](const Tensor &tensor) { return tensor->contiguous(); })
.def("as_strided", [](const Tensor &tensor, const Shape &shape, const Strides &strides) { return tensor->as_strided(shape, strides); })
.def("narrow", [](const Tensor &tensor, std::size_t dim, std::size_t start, std::size_t length) { return tensor->narrow({{dim, start, length}}); }) .def("narrow", [](const Tensor &tensor, std::size_t dim, std::size_t start, std::size_t length) { return tensor->narrow({{dim, start, length}}); })
.def("permute", [](const Tensor &tensor, const Shape &dims) { return tensor->permute(dims); }) .def("permute", [](const Tensor &tensor, const Shape &dims) { return tensor->permute(dims); })
.def("view", [](const Tensor &tensor, const Shape &shape) { return tensor->view(shape); }); .def("view", [](const Tensor &tensor, const Shape &shape) { return tensor->view(shape); })
.def("unsqueeze", [](const Tensor &tensor, std::size_t dim) { return tensor->unsqueeze(dim); })
.def("squeeze", [](const Tensor &tensor, std::size_t dim) { return tensor->squeeze(dim); });
m.def("empty", &Tensor::empty, m.def("empty", &Tensor::empty,
py::arg("shape"), py::arg("shape"),
......
...@@ -6,6 +6,23 @@ ...@@ -6,6 +6,23 @@
#include <stdexcept> #include <stdexcept>
namespace infinicore { namespace infinicore {
Tensor TensorImpl::squeeze(size_t dim) const {
    // Validate the index before touching the metadata vectors: the original
    // code indexed meta_.shape[dim] unchecked, which is undefined behavior
    // when dim >= ndim.
    if (dim >= meta_.shape.size()) {
        spdlog::error("Dimension {} is out of range for squeeze operation on {}.", dim, this->info());
        throw std::runtime_error("Invalid squeeze operation on tensor.");
    }
    // Only a size-1 dimension may be removed (a non-1 dimension is an error
    // here, matching the documented contract of Tensor::squeeze).
    if (meta_.shape[dim] != 1) {
        spdlog::error("Dimension {} is not of size 1 for squeeze operation on {}.", dim, this->info());
        throw std::runtime_error("Invalid squeeze operation on tensor.");
    }
    // Drop the dimension from both shape and strides; the element count is
    // unchanged, so the result can alias the original storage.
    Shape new_shape = meta_.shape;
    new_shape.erase(new_shape.begin() + dim);
    Strides new_strides = meta_.strides;
    new_strides.erase(new_strides.begin() + dim);
    auto tensor_impl = std::make_shared<TensorImpl>(new_shape, new_strides, meta_.dtype);
    // Share the underlying buffer: squeeze is a view, not a copy.
    tensor_impl->data_ = data_;
    return Tensor(tensor_impl);
}
Tensor TensorImpl::unsqueeze(size_t dim) const { Tensor TensorImpl::unsqueeze(size_t dim) const {
// Create new shape with dimension of size one inserted at dim // Create new shape with dimension of size one inserted at dim
Shape new_shape = meta_.shape; Shape new_shape = meta_.shape;
......
#ifndef INFINIUTILS_H #ifndef INFINIUTILS_H
#define INFINIUTILS_H #define INFINIUTILS_H
#include "infinicore.h"
#include "utils/custom_types.h" #include "utils/custom_types.h"
#include "utils/rearrange.h" #include "utils/rearrange.h"
......
...@@ -3,8 +3,19 @@ ...@@ -3,8 +3,19 @@
#include <iostream> #include <iostream>
#include <tuple> #include <tuple>
#include "../utils.h"
#include "infini_status_string.h" #include "infini_status_string.h"
// Generic check helper: if CONDITION evaluates to false, print a diagnostic
// (the stringified condition plus function, file and line) to stderr, then
// execute ACTION (e.g. `return SOME_STATUS`). Wrapped in do/while(0) so the
// macro expands to a single statement and composes safely with if/else.
#define CHECK_OR_DO(CONDITION, ACTION) \
    do { \
        if (!(CONDITION)) { \
            std::cerr << "Check Failed: `(" << #CONDITION << ")` is False" \
                      << " from " << __func__ \
                      << " at " << __FILE__ << ":" << __LINE__ << std::endl; \
            { ACTION; } \
        } \
    } while (0)
#define CHECK_OR_RETURN(CONDITION, ERROR) \ #define CHECK_OR_RETURN(CONDITION, ERROR) \
do { \ do { \
if (!(CONDITION)) { \ if (!(CONDITION)) { \
...@@ -33,17 +44,19 @@ ...@@ -33,17 +44,19 @@
std::cerr << "Error: " << infini_status_string(api_result_) << std::endl; \ std::cerr << "Error: " << infini_status_string(api_result_) << std::endl; \
return api_result_) return api_result_)
#define CHECK_DTYPE(DT, ...) \ #define CHECK_DTYPE(DT, ...) \
do { \ do { \
auto found_supported_dtype = false; \ auto dtype_is_supported = false; \
for (auto dt : {__VA_ARGS__}) { \ for (auto dt : {__VA_ARGS__}) { \
if (dt == DT) { \ if (dt == DT) { \
found_supported_dtype = true; \ dtype_is_supported = true; \
break; \ break; \
} \ } \
} \ } \
CHECK_API_OR(found_supported_dtype, true, \ CHECK_OR_DO(dtype_is_supported, \
return INFINI_STATUS_BAD_TENSOR_DTYPE); \ { std::cerr << "Unsupported dtype: " << \
infiniDtypeToString(DT) << ". "; \
return INFINI_STATUS_BAD_TENSOR_DTYPE; }); \
} while (0) } while (0)
#define CHECK_DTYPE_ANY_INT(DT) \ #define CHECK_DTYPE_ANY_INT(DT) \
......
...@@ -363,7 +363,7 @@ def rearrange_tensor(tensor, new_strides): ...@@ -363,7 +363,7 @@ def rearrange_tensor(tensor, new_strides):
left = 0 left = 0
right = 0 right = 0
for i in range(len(shape)): for i in range(len(shape)):
if new_strides[i] > 0: if new_strides[i] >= 0:
new_size[i] = (shape[i] - 1) * new_strides[i] + 1 new_size[i] = (shape[i] - 1) * new_strides[i] + 1
right += new_strides[i] * (shape[i] - 1) right += new_strides[i] * (shape[i] - 1)
else: # TODO: Support negative strides in the future else: # TODO: Support negative strides in the future
......
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
from framework.base import BaseOperatorTest, TensorSpec, TestCase
from framework.runner import GenericTestRunner
from framework.utils import is_broadcast
# ==============================================================================
# Operator-specific configuration
# ==============================================================================
# Test cases format: (shape, strides, dim)
# `strides` of None means a contiguous tensor; `dim` is the axis to squeeze.
_TEST_CASES_DATA = [
    # Basic contiguous cases
    ((1, 1, 1), None, 1),
    ((1, 1, 1), None, 0),
    ((1, 2, 4), None, 0),
    # Explicitly strided inputs (including a 0-stride axis)
    ((2, 1, 4), (4, 0, 1), 1),
    ((1, 4, 1, 32), (32, 32, 32, 1), 2),
]
# Tolerance configuration — zero tolerance: squeeze must not change values,
# so results are compared for exact equality at every dtype.
_TOLERANCE_MAP = {
    infinicore.float16: {"atol": 0, "rtol": 0},
    infinicore.float32: {"atol": 0, "rtol": 0},
    infinicore.bfloat16: {"atol": 0, "rtol": 0},
}
# Data types to test
_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
def parse_test_cases():
    """
    Build TestCase objects for the squeeze operator.

    Crosses every (shape, strides, dim) entry in _TEST_CASES_DATA with every
    dtype in _TENSOR_DTYPES.

    Returns:
        list[TestCase]: cases comparing the operator output against PyTorch.
    """
    test_cases = []
    # Tuple unpacking instead of positional indexing into `data`.
    for shape, strides, dim in _TEST_CASES_DATA:
        # Generate test cases for all data types
        for dtype in _TENSOR_DTYPES:
            tolerance = _TOLERANCE_MAP.get(dtype, {"atol": 0, "rtol": 0})
            # Create typed tensor specs
            a_spec = TensorSpec.from_tensor(shape, strides, dtype)
            test_cases.append(
                TestCase(
                    inputs=[a_spec, dim],
                    kwargs={},
                    output_spec=None,
                    comparison_target=None,  # Compare output
                    tolerance=tolerance,
                    description="squeeze",  # plain string: no interpolation needed
                )
            )
    return test_cases
class OpTest(BaseOperatorTest):
    """squeeze operator test with simplified implementation"""

    def __init__(self):
        # Register under the operator name the framework uses for reporting.
        super().__init__("squeeze")

    def get_test_cases(self):
        # All cases are built once from the module-level configuration tables.
        return parse_test_cases()

    def torch_operator(self, *args, **kwargs):
        """PyTorch squeeze implementation (reference result)."""
        return torch.squeeze(*args, **kwargs)

    def infinicore_operator(self, *args, **kwargs):
        """InfiniCore squeeze implementation (implementation under test)."""
        return infinicore.squeeze(*args, **kwargs)
def main():
    """Entry point: run the squeeze test suite and exit with its status."""
    GenericTestRunner(OpTest).run_and_exit()


if __name__ == "__main__":
    main()
import sys
import os
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import infinicore
from framework.base import BaseOperatorTest, TensorSpec, TestCase
from framework.runner import GenericTestRunner
from framework.utils import is_broadcast
# ==============================================================================
# Operator-specific configuration
# ==============================================================================
# Test cases format: (shape, strides, dim)
# `strides` of None means a contiguous tensor; `dim` is the insertion axis.
_TEST_CASES_DATA = [
    # Basic contiguous cases
    ((1, 1, 1), None, 1),
    ((1, 1, 1), None, 0),
    ((1, 2, 4), None, 0),
    # Explicitly strided inputs (including a 0-stride axis)
    ((2, 1, 4), (4, 0, 1), 1),
    ((1, 4, 1, 32), (32, 32, 32, 1), 2),
]
# Tolerance configuration — zero tolerance: unsqueeze must not change values,
# so results are compared for exact equality at every dtype.
_TOLERANCE_MAP = {
    infinicore.float16: {"atol": 0, "rtol": 0},
    infinicore.float32: {"atol": 0, "rtol": 0},
    infinicore.bfloat16: {"atol": 0, "rtol": 0},
}
# Data types to test
_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
def parse_test_cases():
    """
    Build TestCase objects for the unsqueeze operator.

    Crosses every (shape, strides, dim) entry in _TEST_CASES_DATA with every
    dtype in _TENSOR_DTYPES.

    Returns:
        list[TestCase]: cases comparing the operator output against PyTorch.
    """
    test_cases = []
    # Tuple unpacking instead of positional indexing into `data`.
    for shape, strides, dim in _TEST_CASES_DATA:
        # Generate test cases for all data types
        for dtype in _TENSOR_DTYPES:
            tolerance = _TOLERANCE_MAP.get(dtype, {"atol": 0, "rtol": 0})
            # Create typed tensor specs
            a_spec = TensorSpec.from_tensor(shape, strides, dtype)
            test_cases.append(
                TestCase(
                    inputs=[a_spec, dim],
                    kwargs={},
                    output_spec=None,
                    comparison_target=None,  # Compare output
                    tolerance=tolerance,
                    description="unsqueeze",  # plain string: no interpolation needed
                )
            )
    return test_cases
class OpTest(BaseOperatorTest):
    """unsqueeze operator test with simplified implementation"""

    def __init__(self):
        # Register under the operator name the framework uses for reporting.
        super().__init__("unsqueeze")

    def get_test_cases(self):
        # All cases are built once from the module-level configuration tables.
        return parse_test_cases()

    def torch_operator(self, *args, **kwargs):
        """PyTorch unsqueeze implementation (reference result)."""
        return torch.unsqueeze(*args, **kwargs)

    def infinicore_operator(self, *args, **kwargs):
        """InfiniCore unsqueeze implementation (implementation under test)."""
        return infinicore.unsqueeze(*args, **kwargs)
def main():
    """Entry point: run the unsqueeze test suite and exit with its status."""
    GenericTestRunner(OpTest).run_and_exit()


if __name__ == "__main__":
    main()
...@@ -296,7 +296,7 @@ def rearrange_tensor(tensor, new_strides): ...@@ -296,7 +296,7 @@ def rearrange_tensor(tensor, new_strides):
left = 0 left = 0
right = 0 right = 0
for i in range(len(shape)): for i in range(len(shape)):
if new_strides[i] > 0: if new_strides[i] >= 0:
new_size[i] = (shape[i] - 1) * new_strides[i] + 1 new_size[i] = (shape[i] - 1) * new_strides[i] + 1
right += new_strides[i] * (shape[i] - 1) right += new_strides[i] * (shape[i] - 1)
else: # TODO: Support negative strides in the future else: # TODO: Support negative strides in the future
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment