Unverified commit 9b0b89c5, authored by thatPepe, committed by GitHub

Merge pull request #592 from InfiniTensor/issue/591

issue/591: add infinicore.narrow
parents 5028ea42 16854aed
@@ -16,7 +16,7 @@ Device getDevice();
 size_t getDeviceCount(Device::Type type);
 infinirtStream_t getStream();
-infiniopHandle_t getInfiniopHandle();
+infiniopHandle_t getInfiniopHandle(Device device);
 void syncStream();
 void syncDevice();
...
@@ -31,6 +31,7 @@ from infinicore.ops.add import add
 from infinicore.ops.attention import attention
 from infinicore.ops.matmul import matmul
 from infinicore.ops.mul import mul
+from infinicore.ops.narrow import narrow
 from infinicore.ops.rearrange import rearrange
 from infinicore.tensor import (
     Tensor,
@@ -79,6 +80,7 @@ __all__ = [
     "attention",
     "matmul",
     "mul",
+    "narrow",
     "rearrange",
     "empty",
     "empty_like",
...
from infinicore.tensor import Tensor


def narrow(input: Tensor, dim: int, start: int, length: int) -> Tensor:
    return Tensor(input._underlying.narrow(dim, start, length))
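For readers new to the API: `narrow` follows `torch.narrow` semantics, returning a tensor restricted to `length` elements of dimension `dim`, starting at index `start`. A minimal usage sketch (the `empty` call and its placement defaults are assumed torch-like, not part of this diff):

import infinicore

x = infinicore.empty((5, 3, 2), dtype=infinicore.float32)
y = infinicore.narrow(x, 1, 0, 2)  # slice dim 1 to elements [0, 2)
# y has shape (5, 2, 2)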
@@ -56,8 +56,8 @@ class Tensor:
     def is_contiguous(self):
         return self._underlying.is_contiguous()

-    def is_is_pinned(self):
-        return self._underlying.is_is_pinned()
+    def is_pinned(self):
+        return self._underlying.is_pinned()

     def copy_(self, src):
         self._underlying.copy_(src._underlying)
@@ -67,12 +67,12 @@ class Tensor:
             self._underlying.to(*tuple(arg._underlying for arg in args), **kwargs)
         )

+    def as_strided(self, size, stride):
+        return Tensor(self._underlying.as_strided(size, stride))
+
     def contiguous(self):
         return Tensor(self._underlying.contiguous())

-    def as_strided(self, size, stride):
-        return Tensor(self._underlying.as_strided(size, stride))
-
     def permute(self, dims):
         return Tensor(self._underlying.permute(dims))
...
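A quick sketch of the two touched methods: `is_pinned` fixes the accidental double-prefixed name `is_is_pinned`, and `as_strided` is simply moved to keep the view methods together. Usage (the `empty` constructor and torch-like size/stride semantics are assumptions, not from this diff):

import infinicore

t = infinicore.empty((2, 3), dtype=infinicore.float32)
t.is_pinned()                     # renamed from the misspelled is_is_pinned
v = t.as_strided((3, 2), (1, 3))  # view with explicit size and stride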
@@ -99,7 +99,13 @@ infinirtStream_t getStream() {
     return ContextImpl::singleton().getCurrentRuntime()->stream();
 }

-infiniopHandle_t getInfiniopHandle() {
+infiniopHandle_t getInfiniopHandle(Device device) {
+    if (device.getType() == Device::Type::CPU) {
+        return ContextImpl::singleton().getCpuRuntime()->infiniopHandle();
+    }
+    if (device != getDevice()) {
+        throw std::runtime_error("Requested device doesn't match current runtime.");
+    }
     return ContextImpl::singleton().getCurrentRuntime()->infiniopHandle();
 }
...
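The handle lookup is now device-aware: the CPU handle is always reachable, and requesting any other device that is not the current runtime throws. Each op call site below passes the device of its output tensor, so the descriptor is created on the handle matching where the kernel will run. A hedged Python-level sketch of the visible behavior (tensor placement and `add`'s exact signature are assumptions, not taken from this diff):

import infinicore

a = infinicore.empty((2, 4), dtype=infinicore.float32)  # assumed to live on the CPU
b = infinicore.empty((2, 4), dtype=infinicore.float32)
c = infinicore.add(a, b)  # handle resolved from c's device; a CPU op works
                          # even while a different device runtime is current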
@@ -28,7 +28,7 @@ void calculate(Tensor c, Tensor a, Tensor b) {
     if (!desc_opt) {
         INFINICORE_CHECK_ERROR(infiniopCreateAddDescriptor(
-            context::getInfiniopHandle(), &desc,
+            context::getInfiniopHandle(c->device()), &desc,
             c->desc(), a->desc(), b->desc()));
         cache.put(seed, desc);
     } else {
...
@@ -28,7 +28,7 @@ void calculate(Tensor out, Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor
     if (!desc_opt) {
         INFINICORE_CHECK_ERROR(infiniopCreateAttentionDescriptor(
-            context::getInfiniopHandle(), &desc,
+            context::getInfiniopHandle(out->device()), &desc,
             out->desc(), q->desc(), k->desc(), v->desc(),
             k_cache->desc(), v_cache->desc(), pos));
         cache.put(seed, desc);
...
@@ -28,7 +28,7 @@ void calculate(Tensor output, Tensor input) {
     if (!desc_opt) {
         INFINICORE_CHECK_ERROR(infiniopCreateCausalSoftmaxDescriptor(
-            context::getInfiniopHandle(), &desc,
+            context::getInfiniopHandle(output->device()), &desc,
             output->desc(), input->desc()));
         cache.put(seed, desc);
     } else {
...
@@ -28,7 +28,7 @@ void calculate(Tensor c, Tensor a, Tensor b, float alpha, float beta) {
     if (!desc_opt) {
         INFINICORE_CHECK_ERROR(infiniopCreateGemmDescriptor(
-            context::getInfiniopHandle(), &desc,
+            context::getInfiniopHandle(c->device()), &desc,
             c->desc(), a->desc(), b->desc()));
         cache.put(seed, desc);
     } else {
...
@@ -28,7 +28,7 @@ void calculate(Tensor c, Tensor a, Tensor b) {
     if (!desc_opt) {
         INFINICORE_CHECK_ERROR(infiniopCreateMulDescriptor(
-            context::getInfiniopHandle(), &desc,
+            context::getInfiniopHandle(c->device()), &desc,
             c->desc(), a->desc(), b->desc()));
         cache.put(seed, desc);
     } else {
...
@@ -27,7 +27,7 @@ void calculate(Tensor y, Tensor x) {
     infiniopRearrangeDescriptor_t desc = nullptr;
     if (!desc_opt) {
-        INFINICORE_CHECK_ERROR(infiniopCreateRearrangeDescriptor(context::getInfiniopHandle(), &desc, y->desc(), x->desc()));
+        INFINICORE_CHECK_ERROR(infiniopCreateRearrangeDescriptor(context::getInfiniopHandle(y->device()), &desc, y->desc(), x->desc()));
         cache.put(seed, desc);
     } else {
         desc = *desc_opt;
...
@@ -28,7 +28,7 @@ void calculate(Tensor y, Tensor x, Tensor weight, float epsilon) {
     if (!desc_opt) {
         INFINICORE_CHECK_ERROR(infiniopCreateRMSNormDescriptor(
-            context::getInfiniopHandle(), &desc,
+            context::getInfiniopHandle(y->device()), &desc,
             y->desc(), x->desc(), weight->desc(), epsilon));
         cache.put(seed, desc);
     } else {
...
@@ -42,7 +42,7 @@ void calculate(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &s
     if (!desc_opt) {
         INFINICORE_CHECK_ERROR(infiniopCreateRoPEDescriptor(
-            context::getInfiniopHandle(), &desc,
+            context::getInfiniopHandle(x_out->device()), &desc,
             x_out->desc(), x->desc(),
             pos->desc(), sin_cache->desc(), cos_cache->desc(),
             infiniop_algo));
...
@@ -28,7 +28,7 @@ void calculate(Tensor output, Tensor input) {
     if (!desc_opt) {
         INFINICORE_CHECK_ERROR(infiniopCreateSiluDescriptor(
-            context::getInfiniopHandle(), &desc,
+            context::getInfiniopHandle(output->device()), &desc,
             output->desc(), input->desc()));
         cache.put(seed, desc);
     } else {
...
@@ -28,7 +28,7 @@ void calculate(Tensor c, Tensor a, Tensor b) {
     if (!desc_opt) {
         INFINICORE_CHECK_ERROR(infiniopCreateSwiGLUDescriptor(
-            context::getInfiniopHandle(), &desc,
+            context::getInfiniopHandle(c->device()), &desc,
             c->desc(), a->desc(), b->desc()));
         cache.put(seed, desc);
     } else {
...
@@ -32,7 +32,8 @@ inline void bind(py::module &m) {
         .def("to", [](const Tensor &tensor, const Device &device) { return tensor->to(device); })
         .def("as_strided", [](const Tensor &tensor, const Shape &shape, const Strides &strides) { return tensor->as_strided(shape, strides); })
         .def("contiguous", [](const Tensor &tensor) { return tensor->contiguous(); })
+        .def("narrow", [](const Tensor &tensor, std::size_t dim, std::size_t start, std::size_t length) { return tensor->narrow({{dim, start, length}}); })
         .def("permute", [](const Tensor &tensor, const Shape &dims) { return tensor->permute(dims); })
         .def("view", [](const Tensor &tensor, const Shape &shape) { return tensor->view(shape); });
...
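Note the shape of the C++ call: `tensor->narrow({{dim, start, length}})` takes a list of slice triples, while the binding exposes the single-slice form that the Python wrapper consumes. From Python, the two spellings below are therefore equivalent (a minimal sketch; the `empty` call is an assumed constructor):

import infinicore
from infinicore import Tensor

t = infinicore.empty((4, 4), dtype=infinicore.float32)
y1 = infinicore.narrow(t, 0, 1, 2)          # public wrapper
y2 = Tensor(t._underlying.narrow(0, 1, 2))  # what the wrapper does internally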
import os
import sys

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

import torch

import infinicore
from framework.base import BaseOperatorTest, TensorSpec, TestCase
from framework.runner import GenericTestRunner
from framework.utils import is_broadcast

# ==============================================================================
# Operator-specific configuration
# ==============================================================================

# Test cases format: (shape, dim, start, length)
_TEST_CASES_DATA = [
    # Basic cases
    ((2, 4), 0, 0, 1),
    ((2, 4), 1, 1, 1),
    ((5, 3, 2), 1, 0, 3),
    ((5, 3, 2), 0, 1, 3),
    # Narrow dim 2 (size 1024) down to its final element
    ((4, 4, 1024, 32), 2, 1023, 1),
]
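# Expected shapes follow torch.narrow: the sliced dimension shrinks to `length`.
# A quick reference check for the last case above (illustration only):
#     torch.narrow(torch.empty(4, 4, 1024, 32), 2, 1023, 1).shape == (4, 4, 1, 32)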
# Tolerance configuration: narrow only slices, so results must match exactly.
_TOLERANCE_MAP = {
    infinicore.float16: {"atol": 0, "rtol": 0},
    infinicore.float32: {"atol": 0, "rtol": 0},
    infinicore.bfloat16: {"atol": 0, "rtol": 0},
}

# Data types to test
_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]


def parse_test_cases():
    """
    Parse the test case data and return a list of TestCase objects.

    Each test case carries all the information needed for execution and
    validation across every tested data type.
    """
    test_cases = []
    for shape, dim, start, length in _TEST_CASES_DATA:
        # Generate test cases for all data types
        for dtype in _TENSOR_DTYPES:
            tolerance = _TOLERANCE_MAP.get(dtype, {"atol": 0, "rtol": 0})
            # Create typed tensor specs
            a_spec = TensorSpec.from_tensor(shape, None, dtype)
            test_cases.append(
                TestCase(
                    inputs=[a_spec, dim, start, length],
                    kwargs={},
                    output_spec=None,
                    comparison_target=None,  # Compare output
                    tolerance=tolerance,
                    description="Narrow",
                )
            )
    return test_cases


class OpTest(BaseOperatorTest):
    """Narrow operator test with simplified implementation."""

    def __init__(self):
        super().__init__("Narrow")

    def get_test_cases(self):
        return parse_test_cases()

    def torch_operator(self, *args, **kwargs):
        """PyTorch narrow implementation."""
        return torch.narrow(*args, **kwargs)

    def infinicore_operator(self, *args, **kwargs):
        """InfiniCore narrow implementation."""
        return infinicore.narrow(*args, **kwargs)


def main():
    """Main entry point."""
    runner = GenericTestRunner(OpTest)
    runner.run_and_exit()


if __name__ == "__main__":
    main()