Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinicore
Commits
7a48b0de
Unverified
Commit
7a48b0de
authored
Dec 09, 2025
by
thatPepe
Committed by
GitHub
Dec 09, 2025
Browse files
Merge pull request #728 from InfiniTensor/issue/722
issue/722 - adjusted cuda rearrange for shape (8, 4, 20, 64)
parents
be4d3d89
0eb27e6e
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
63 additions
and
141 deletions
+63
-141
src/infiniop/ops/rearrange/nvidia/rearrange_nvidia.cu
src/infiniop/ops/rearrange/nvidia/rearrange_nvidia.cu
+63
-141
No files found.
src/infiniop/ops/rearrange/nvidia/rearrange_nvidia.cu
View file @
7a48b0de
...
...
@@ -83,9 +83,7 @@ utils::Result<RearrangeParams> prepareRearrangeParams(const utils::RearrangeMeta
// 获取更适合GPU处理的单元大小,这里使用2的幂次方
auto
meta_result
=
original_meta
.
distributeUnit
({
32
,
16
,
8
,
4
,
2
,
1
});
CHECK_RESULT
(
meta_result
);
const
utils
::
RearrangeMeta
&
meta
=
meta_result
.
take
();
// 获取维度信息
...
...
@@ -125,153 +123,79 @@ utils::Result<RearrangeParams> prepareRearrangeParams(const utils::RearrangeMeta
prev_idx_stride
=
idx_strides
[
i
];
}
// 计算src_strides的降序排序索引,类似于Rust版本中的src_strides_desc_idx
std
::
vector
<
size_t
>
src_strides_desc_idx
(
ndim
);
for
(
size_t
i
=
0
;
i
<
ndim
;
++
i
)
{
src_strides_desc_idx
[
i
]
=
i
;
}
std
::
sort
(
src_strides_desc_idx
.
begin
(),
src_strides_desc_idx
.
end
(),
[
&
dims
](
size_t
a
,
size_t
b
)
{
return
std
::
abs
(
dims
[
a
].
src_stride
)
>
std
::
abs
(
dims
[
b
].
src_stride
);
});
// 根据最大线程数选择block和grid维度
const
size_t
block_size
=
max_threads
;
std
::
vector
<
bool
>
block_dim_choose
(
ndim
,
false
);
std
::
vector
<
SplitDim
>
split_dims
;
// 初始化计数器
size_t
block_elements
=
1
;
size_t
block_src_elements
=
1
;
size_t
block_dst_elements
=
1
;
size_t
src_choose_idx
=
ndim
;
size_t
dst_choose_idx
=
ndim
;
// 用于存储分割维度信息
std
::
vector
<
SplitDim
>
split_dims
;
std
::
vector
<
size_t
>
dim_order
(
ndim
);
for
(
size_t
i
=
0
;
i
<
ndim
;
++
i
)
{
dim_order
[
i
]
=
i
;
}
// 按src_stride升序排序,贪心选择
std
::
sort
(
dim_order
.
begin
(),
dim_order
.
end
(),
[
&
dims
](
size_t
a
,
size_t
b
)
{
return
std
::
abs
(
dims
[
a
].
src_stride
)
<
std
::
abs
(
dims
[
b
].
src_stride
);
});
// 维度选择循环
while
(
src_choose_idx
>
0
&&
dst_choose_idx
>
0
)
{
// 获取当前需要处理的维度索引
size_t
src_idx
=
src_strides_desc_idx
[
src_choose_idx
-
1
];
size_t
dst_idx
=
dst_choose_idx
-
1
;
if
(
src_idx
==
dst_idx
)
{
// 源和目标维度相同,可以一起处理
size_t
idx
=
src_idx
;
size_t
len
=
shape
[
idx
];
// 检查是否可以将此维度完全添加到block中
if
(
block_elements
*
len
<=
block_size
)
{
// 选择此维度
block_dim_choose
[
idx
]
=
true
;
block_elements
*=
len
;
block_src_elements
*=
len
;
block_dst_elements
*=
len
;
src_choose_idx
--
;
dst_choose_idx
--
;
}
else
{
// 需要分割此维度
size_t
num_per_block
=
block_size
/
block_elements
;
// 确保num_per_block > 0且len >= num_per_block
if
(
num_per_block
>
0
&&
len
>=
num_per_block
&&
num_per_block
>
1
)
{
size_t
num_per_grid
=
(
len
+
num_per_block
-
1
)
/
num_per_block
;
// 向上取整
SplitDim
split_dim
=
{
idx
,
// choose_idx
num_per_block
,
// num_per_block
num_per_grid
,
// num_per_grid
0
,
// array_struct_idx_block (待更新)
0
,
// array_struct_idx_grid (待更新)
len
// 原始维度长度
};
split_dims
.
push_back
(
split_dim
);
}
break
;
for
(
size_t
i
=
0
;
i
<
ndim
;
++
i
)
{
size_t
dim_idx
=
dim_order
[
i
];
size_t
dim_len
=
shape
[
dim_idx
];
if
(
block_elements
*
dim_len
<=
(
size_t
)
max_threads
)
{
block_dim_choose
[
dim_idx
]
=
true
;
block_elements
*=
dim_len
;
}
else
if
(
block_elements
>
1
&&
dim_len
>
1
)
{
// 需要分割此维度
size_t
num_per_block
=
std
::
min
(
dim_len
,
(
size_t
)
max_threads
/
block_elements
);
if
(
num_per_block
>
0
)
{
size_t
num_per_grid
=
(
dim_len
+
num_per_block
-
1
)
/
num_per_block
;
SplitDim
split_dim
=
{
dim_idx
,
// choose_idx
num_per_block
,
// num_per_block
num_per_grid
,
// num_per_grid
0
,
// array_struct_idx_block (待更新)
0
,
// array_struct_idx_grid (待更新)
dim_len
// original dimension length
};
split_dims
.
push_back
(
split_dim
);
block_elements
*=
num_per_block
;
}
}
else
{
// 源和目标维度不同,需要分别处理
// 计算块比例
double
src_div_dst
=
static_cast
<
double
>
(
block_src_elements
)
/
block_dst_elements
;
double
src_num_per_block
=
std
::
sqrt
(
block_size
/
(
double
)
block_elements
/
src_div_dst
);
double
dst_num_per_block
=
src_num_per_block
*
src_div_dst
;
size_t
src_current_dim_len
=
shape
[
src_idx
];
size_t
dst_current_dim_len
=
shape
[
dst_idx
];
if
(
static_cast
<
double
>
(
src_current_dim_len
)
<
src_num_per_block
)
{
// 源维度可以完全添加到block
block_dim_choose
[
src_idx
]
=
true
;
block_elements
*=
src_current_dim_len
;
block_src_elements
*=
src_current_dim_len
;
src_choose_idx
--
;
}
else
if
(
static_cast
<
double
>
(
dst_current_dim_len
)
<
dst_num_per_block
)
{
// 目标维度可以完全添加到block
block_dim_choose
[
dst_idx
]
=
true
;
block_elements
*=
dst_current_dim_len
;
block_dst_elements
*=
dst_current_dim_len
;
dst_choose_idx
--
;
}
else
{
// 需要分割源和目标维度
size_t
src_num_per_block_int
=
static_cast
<
size_t
>
(
std
::
floor
(
src_num_per_block
));
size_t
dst_num_per_block_int
=
static_cast
<
size_t
>
(
std
::
floor
(
dst_num_per_block
));
// 计算网格尺寸
size_t
src_num_per_grid
=
(
src_current_dim_len
+
src_num_per_block_int
-
1
)
/
src_num_per_block_int
;
// 向上取整
size_t
dst_num_per_grid
=
(
dst_current_dim_len
+
dst_num_per_block_int
-
1
)
/
dst_num_per_block_int
;
// 向上取整
// 处理源维度
if
(
src_num_per_block_int
>
1
)
{
if
(
src_num_per_grid
==
1
)
{
// 可以完全放入块
block_dim_choose
[
src_idx
]
=
true
;
block_elements
*=
src_current_dim_len
;
block_src_elements
*=
src_current_dim_len
;
src_choose_idx
--
;
}
else
{
// 需要分割
SplitDim
split_dim
=
{
src_idx
,
// choose_idx
src_num_per_block_int
,
// num_per_block
src_num_per_grid
,
// num_per_grid
0
,
// array_struct_idx_block (待更新)
0
,
// array_struct_idx_grid (待更新)
src_current_dim_len
// 原始维度长度
};
split_dims
.
push_back
(
split_dim
);
}
}
break
;
}
}
// 处理目标维度
if
(
dst_num_per_block_int
>
1
)
{
if
(
dst_num_per_grid
==
1
)
{
// 可以完全放入块
block_dim_choose
[
dst_idx
]
=
true
;
block_elements
*=
dst_current_dim_len
;
block_dst_elements
*=
dst_current_dim_len
;
dst_choose_idx
--
;
}
else
{
// 需要分割
SplitDim
split_dim
=
{
dst_idx
,
// choose_idx
dst_num_per_block_int
,
// num_per_block
dst_num_per_grid
,
// num_per_grid
0
,
// array_struct_idx_block (待更新)
0
,
// array_struct_idx_grid (待更新)
dst_current_dim_len
// 原始维度长度
};
split_dims
.
push_back
(
split_dim
);
}
}
if
(
block_elements
==
1
&&
ndim
>
0
)
{
size_t
dim_idx
=
dim_order
[
0
];
size_t
dim_len
=
shape
[
dim_idx
];
break
;
}
if
(
dim_len
<=
(
size_t
)
max_threads
)
{
block_dim_choose
[
dim_idx
]
=
true
;
block_elements
=
dim_len
;
}
else
{
// 需要分割
size_t
num_per_block
=
std
::
min
(
dim_len
,
(
size_t
)
max_threads
);
size_t
num_per_grid
=
(
dim_len
+
num_per_block
-
1
)
/
num_per_block
;
SplitDim
split_dim
=
{
dim_idx
,
num_per_block
,
num_per_grid
,
0
,
0
,
dim_len
};
split_dims
.
push_back
(
split_dim
);
block_elements
=
num_per_block
;
}
}
// 准备block维度相关参数
size_t
block_dim
=
0
;
size_t
block_len_total
=
1
;
size_t
block_len_total
=
block_elements
;
std
::
vector
<
ARRAY_TYPE_SIZE
>
block_len
;
std
::
vector
<
ARRAY_TYPE_STRIDE
>
src_block_stride
;
...
...
@@ -288,7 +212,6 @@ utils::Result<RearrangeParams> prepareRearrangeParams(const utils::RearrangeMeta
src_block_stride
.
push_back
(
dims
[
i
].
src_stride
);
dst_block_stride
.
push_back
(
dims
[
i
].
dst_stride
);
block_dim
+=
1
;
block_len_total
*=
shape
[
i
];
}
// 处理分割维度的block部分
...
...
@@ -299,7 +222,6 @@ utils::Result<RearrangeParams> prepareRearrangeParams(const utils::RearrangeMeta
dst_block_stride
.
push_back
(
dims
[
i
].
dst_stride
);
split_dims
[
j
].
array_struct_idx_block
=
static_cast
<
int
>
(
block_dim
);
block_dim
+=
1
;
block_len_total
*=
split_dims
[
j
].
num_per_block
;
}
}
}
...
...
@@ -317,6 +239,7 @@ utils::Result<RearrangeParams> prepareRearrangeParams(const utils::RearrangeMeta
src_grid_stride
.
push_back
(
dims
[
i
].
src_stride
*
split_dims
[
j
].
num_per_block
);
dst_grid_stride
.
push_back
(
dims
[
i
].
dst_stride
*
split_dims
[
j
].
num_per_block
);
split_dims
[
j
].
array_struct_idx_grid
=
static_cast
<
int
>
(
grid_len
.
size
()
-
1
);
break
;
}
}
...
...
@@ -350,6 +273,10 @@ utils::Result<RearrangeParams> prepareRearrangeParams(const utils::RearrangeMeta
constraint
.
grid_div_block
=
split_dims
[
i
].
num_per_block
;
constraint
.
total_len
=
split_dims
[
i
].
dim_len
;
constraints
.
push_back
(
constraint
);
if
(
constraints
.
size
()
>=
2
)
{
break
;
}
}
// 设置参数
...
...
@@ -437,11 +364,6 @@ infiniStatus_t Descriptor::calculate(
// 如果没有维度,直接进行内存拷贝
if
(
_meta
.
ndim
()
==
0
)
{
auto
err
=
cudaMemcpyAsync
(
y
,
x
,
_meta
.
unit
(),
cudaMemcpyDeviceToDevice
,
cuda_stream
);
if
(
err
!=
cudaSuccess
)
{
return
INFINI_STATUS_INTERNAL_ERROR
;
}
CHECK_OR_RETURN
(
cudaMemcpyAsync
(
y
,
x
,
_meta
.
unit
(),
cudaMemcpyDeviceToDevice
,
cuda_stream
)
==
cudaSuccess
,
INFINI_STATUS_INTERNAL_ERROR
);
return
INFINI_STATUS_SUCCESS
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment