Commit 770fa304 authored by dongcl's avatar dongcl
Browse files

修改mtp

parent 8096abd4
...@@ -30,33 +30,3 @@ def is_flux_min_version(version, check_equality=True): ...@@ -30,33 +30,3 @@ def is_flux_min_version(version, check_equality=True):
if check_equality: if check_equality:
return get_flux_version() >= PkgVersion(version) return get_flux_version() >= PkgVersion(version)
return get_flux_version() > PkgVersion(version) return get_flux_version() > PkgVersion(version)
def tensor_slide(
tensor: Optional[torch.Tensor],
num_slice: int,
dims: Union[int, List[int]] = -1,
step: int = 1,
return_first=False,
) -> List[Union[torch.Tensor, None]]:
"""通用滑动窗口函数,支持任意维度"""
if tensor is None:
# return `List[None]` to avoid NoneType Error
return [None] * (num_slice + 1)
if num_slice == 0:
return [tensor]
window_size = tensor.shape[-1] - num_slice
dims = [dims] if isinstance(dims, int) else sorted(dims, reverse=True)
# 连续多维度滑动
slices = []
for i in range(0, tensor.size(dims[-1]) - window_size + 1, step):
slice_obj = [slice(None)] * tensor.dim()
for dim in dims:
slice_obj[dim] = slice(i, i + window_size)
slices.append(tensor[tuple(slice_obj)])
if return_first:
return slices
return slices
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment