gaoqiong / composable_kernel_ROCM · Commits

Commit 809a0c97, authored Feb 12, 2025 by mtgu0705

fp8xint4 bpreshuffle function pass

parent 8c0e03ba
Showing 3 changed files with 146 additions and 77 deletions (+146 -77):

  example/01_gemm/gemm_xdl_fp8_pk_i4_bpreshuffle_v3.cpp                                 +33  -43
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp    +2   -2
  include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp          +111  -32
example/01_gemm/gemm_xdl_fp8_pk_i4_bpreshuffle_v3.cpp
@@ -58,7 +58,7 @@ using CElementOp = PassThrough;
 static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
 static constexpr bool PermuteA = false;
-static constexpr bool PermuteB = true;
+static constexpr bool PermuteB = false;
 static constexpr ck::index_t KPerBlock = 128;
 // clang-format off
@@ -131,7 +131,6 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
     Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
     Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
     Tensor<BDataType> b_k_n_preshuffled(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
-    Tensor<BDataType> b_k_n_permute(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
     switch(config.init_method)
     {
@@ -161,51 +160,42 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
     std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
     std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
+    std::cout << "b_k_n_preshuffled:" << b_k_n_preshuffled.mDesc << std::endl;
     std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl;
+    // std::cout << "a_m_K size: " << sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()
+    // << std::endl;
+    // std::cout << "BDataType size: " << sizeof(BDataType) << std::endl;
+    // std::cout << "b_k_n size: " << sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()
+    // << std::endl;
+    // std::cout << "c_m_n size: " << sizeof(CDataType) * c_m_n_host_result.mDesc.GetElementSpaceSize()
+    // << std::endl;
 
     DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
-    DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize());
+    DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_preshuffled.mDesc.GetElementSpaceSize());
     DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
 
     // do GEMM
     auto gemm = DeviceGemmV2Instance{};
 
-    int NperXdl = gemm.GetPreShuffleParameters();
-    // weight pre-shuffle
-    preShuffleBuffer(b_k_n.mData.data(), b_k_n_preshuffled.mData.data(), N, K, NperXdl);
-
-    // weight permute
-    if constexpr(PermuteB)
-    {
-        int K1 = KPerBlock;
-        int K0 = K / KPerBlock;
-
-        // int K0, N, K1
-        for(int j = 0; j < K0; j++)
-        {
-            for(int i = 0; i < N; i++)
-            {
-                for(int jj = 0; jj < K1; jj++)
-                {
-                    b_k_n_permute(j * N * K1 + i * K1 + jj) = b_k_n_preshuffled(i * K + (j * K1 + jj));
-                }
-            }
-        }
-    }
-    else
-    {
-        for(int i = 0; i < N; i++)
-        {
-            for(int j = 0; j < K; j++)
-            {
-                b_k_n_permute(i * K + j) = b_k_n_preshuffled(i * K + j);
-            }
-        }
-    }
+    int KPack = 32; // int4 -> 32, fp8 -> 16, fp16 -> 8
+    int NLane = gemm.GetPreShuffleParameters();
+    int KLane = 64 / NLane;
+    int K0    = K / (KLane * KPack);
+
+    // K -> K0 KLane KPack
+    // N -> N0 NLane
+    // N, K -> N0 K0 KLane NLane KPack
+    int tempk;
+    for(int n = 0; n < N; ++n)
+    {
+        for(int k = 0; k < K; ++k)
+        {
+            int n0 = n / NLane;
+            int n1 = n % NLane;
+
+            int k0 = k / (KLane * KPack);
+            tempk  = k % (KLane * KPack);
+            int k1 = tempk / KPack;
+            int k2 = tempk % KPack;
+
+            int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
+                              k1 * KPack * NLane + n1 * KPack + k2;
+
+            b_k_n_preshuffled(outputIndex) = b_k_n(n * K + k);
+        }
+    }
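The new host-side loop replaces the preShuffleBuffer helper plus the PermuteB-gated permute: it maps each (n, k) element of B directly into the N0 K0 KLane NLane KPack layout that the preshuffled kernel expects. Since the KLane/KPack/K0 strides form a mixed-radix decomposition of N x K, the mapping is a bijection. A quick standalone check of that claim (not part of the commit; N, K, NLane, and KPack are illustrative assumptions):

#include <cassert>
#include <vector>

int main()
{
    const int N = 64, K = 256;
    const int KPack = 32, NLane = 16, KLane = 64 / NLane;
    const int K0 = K / (KLane * KPack);

    std::vector<int> hits(N * K, 0);
    for(int n = 0; n < N; ++n)
        for(int k = 0; k < K; ++k)
        {
            int n0 = n / NLane, n1 = n % NLane;
            int k0 = k / (KLane * KPack);
            int tempk = k % (KLane * KPack);
            int k1 = tempk / KPack, k2 = tempk % KPack;
            int idx = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
                      k1 * KPack * NLane + n1 * KPack + k2;
            assert(idx >= 0 && idx < N * K);
            ++hits[idx];
        }
    for(int h : hits)
        assert(h == 1); // every destination slot written exactly once
    return 0;
}

Because KLane * KPack divides K and NLane divides N, every source element lands on a unique destination slot, so the shuffle can be done in a single pass over b_k_n.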
@@ -218,7 +208,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
                 for(int k = 0; k < 4; k++)
                 {
-                    int i4x2 = b_k_n_permute(j + k * 2, i).data;
+                    int i4x2 = b_k_n_preshuffled(j + k * 2, i).data;
                     input[k * 2 + 0] = (i4x2 >> 4) & 0xf;
                     input[k * 2 + 1] = (i4x2 >> 0) & 0xf;
                 }
@@ -229,7 +219,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
                     int lo   = input[0];
                     int i4x2 = (hi << 4) | lo;
-                    b_k_n_permute(j + 0, i) = i4x2;
+                    b_k_n_preshuffled(j + 0, i) = i4x2;
                 }
                 {
@@ -237,7 +227,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
                     int lo   = input[4];
                     int i4x2 = (hi << 4) | lo;
-                    b_k_n_permute(j + 2, i) = i4x2;
+                    b_k_n_preshuffled(j + 2, i) = i4x2;
                 }
                 {
@@ -245,7 +235,7 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
                     int lo   = input[1];
                     int i4x2 = (hi << 4) | lo;
-                    b_k_n_permute(j + 4, i) = i4x2;
+                    b_k_n_preshuffled(j + 4, i) = i4x2;
                 }
                 {
@@ -253,13 +243,13 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
                     int lo   = input[5];
                     int i4x2 = (hi << 4) | lo;
-                    b_k_n_permute(j + 6, i) = i4x2;
+                    b_k_n_preshuffled(j + 6, i) = i4x2;
                 }
             }
         }
     }
 
     a_m_k_device_buf.ToDevice(a_m_k.mData.data());
-    b_k_n_device_buf.ToDevice(b_k_n_permute.mData.data());
+    b_k_n_device_buf.ToDevice(b_k_n_preshuffled.mData.data());
     DeviceMem workspace;
 
     auto a_element_op = AElementOp{};
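These hunks only rename the destination tensor (b_k_n_permute becomes b_k_n_preshuffled); the nibble handling itself is unchanged. For reference, the pk_i4 idiom they rely on, shown as a tiny standalone program (an illustration, not CK code):

#include <cassert>
#include <cstdint>

int main()
{
    uint8_t i4x2 = (0xA << 4) | 0x3; // hi nibble = 0xA, lo nibble = 0x3

    // unpack, matching (i4x2 >> 4) & 0xf and (i4x2 >> 0) & 0xf above
    int hi = (i4x2 >> 4) & 0xf;
    int lo = (i4x2 >> 0) & 0xf;
    assert(hi == 0xA && lo == 0x3);

    // re-pack, e.g. after permuting nibbles between bytes
    uint8_t repacked = static_cast<uint8_t>((hi << 4) | lo);
    assert(repacked == i4x2);
    return 0;
}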
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp
@@ -1205,7 +1205,8 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
         auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2<
             BDataType,
-            BDataType,
+            // BDataType,
+            ADataType,
             decltype(b_grid_desc_bpreshuffled),
             decltype(b_block_desc_bk0_n_bk1),
             Sequence<Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}>,
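Changing the transfer's destination type from BDataType to ADataType makes ThreadwiseTensorSliceTransfer_v2 convert the packed-int4 weights to the fp8 compute type as it copies, so downstream stages see ADataType (compare the "Cast after lds" comment in the next hunk). A minimal standalone sketch of the convert-on-copy idea (the types and helper here are illustrative stand-ins, not CK's API):

#include <array>
#include <cstddef>
#include <cstdint>

struct pk_i4 { std::uint8_t data; }; // stand-in: two 4-bit values per byte

// Unpack each pk_i4 byte into two DstData elements, mirroring a
// transfer whose SrcData != DstData.
template <typename DstData, std::size_t NPacked>
std::array<DstData, 2 * NPacked> convert_on_copy(const std::array<pk_i4, NPacked>& src)
{
    std::array<DstData, 2 * NPacked> dst{};
    for(std::size_t i = 0; i < NPacked; ++i)
    {
        dst[2 * i + 0] = static_cast<DstData>((src[i].data >> 4) & 0xf);
        dst[2 * i + 1] = static_cast<DstData>((src[i].data >> 0) & 0xf);
    }
    return dst;
}

For example, convert_on_copy<float>(bytes) yields twice as many elements as packed bytes, which is the same 2:1 expansion the PackedSize constant below encodes.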
@@ -1221,7 +1222,6 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
         // LDS allocation for A and B: be careful of alignment
         // Cast after lds
         auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
             static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
@@ -224,6 +224,13 @@ struct ThreadwiseTensorSliceTransfer_v2
     using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
 
+    static constexpr index_t PackedSize = []() {
+        if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
+            return 2;
+        else
+            return 1;
+    }();
+
     __device__ constexpr ThreadwiseTensorSliceTransfer_v2(const SrcDesc& src_desc,
                                                           const Index& src_slice_origin_idx)
         : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin_idx))
@@ -232,6 +239,11 @@ struct ThreadwiseTensorSliceTransfer_v2
                       "wrong! SrcDesc need to known at compile-time");
 
         static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0,
                       "wrong! Not divisible");
 
+        if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
+        {
+            static_assert(SrcScalarPerVector % PackedSize == 0, "pk data N cannot be 1");
+        }
     }
 
     __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
@@ -275,48 +287,115 @@ struct ThreadwiseTensorSliceTransfer_v2
         // loop over tensor and copy
         constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
 
-        static_for<0, num_access, 1>{}([&](auto idx_1d) {
-            typename vector_type_maker<SrcData, SrcScalarPerVector>::type src_vector;
-
-            using src_vector_t =
-                typename vector_type_maker<SrcData, SrcScalarPerVector>::type::type;
-
-            constexpr auto src_data_idx = SpaceFillingCurve::GetIndex(idx_1d);
-
-            const bool is_src_valid =
-                coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
-
-            // copy data from src_buf into src_vector
-            src_vector.template AsType<src_vector_t>()(Number<0>{}) =
-                src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid);
-
-            // copy data from src_vector into dst_buf
-            static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
-                constexpr index_t dst_offset = dst_desc.CalculateOffset(
-                    to_multi_index(dst_slice_origin_idx) + src_data_idx +
-                    i * src_scalar_step_in_vector);
-
-                if constexpr(InvalidElementAsNaN)
-                {
-                    dst_buf(Number<dst_offset>{}) =
-                        is_src_valid
-                            ? type_convert<DstData>(src_vector.template AsType<SrcData>()[i])
-                            : NumericLimits<DstData>::QuietNaN();
-                }
-                else
-                {
-                    dst_buf(Number<dst_offset>{}) =
-                        type_convert<DstData>(src_vector.template AsType<SrcData>()[i]);
-                }
-            });
-
-            if constexpr(idx_1d.value != num_access - 1)
-            {
-                constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
-
-                move_tensor_coordinate(
-                    src_desc, src_coord_, make_tensor_coordinate_step(src_desc, forward_step));
-            }
-        });
+        if constexpr(is_same<remove_cvref_t<SrcData>, pk_i4_t>::value)
+        {
+            static_for<0, num_access, 1>{}([&](auto idx_1d) {
+                typename vector_type_maker<SrcData, SrcScalarPerVector / PackedSize>::type
+                    src_tmp_vector;
+
+                using src_vector_t = typename decltype(src_tmp_vector)::type;
+
+                constexpr auto src_data_idx = SpaceFillingCurve::GetIndex(idx_1d);
+
+                const bool is_src_valid =
+                    coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc,
+                                                                                src_coord_);
+
+                // copy data from src_buf into src_tmp_vector
+                src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
+                    src_buf.template Get<src_vector_t>(src_coord_.GetOffset() / PackedSize,
+                                                       is_src_valid);
+
+                // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
+                // DstData)
+                vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;
+
+                constexpr index_t pack_size = 8;
+
+                static_assert(SrcScalarPerVector % pack_size == 0, "");
+
+                using src_v_t = typename vector_type_maker_t<SrcData, pack_size / PackedSize>::type;
+                using dst_v_t = typename vector_type_maker_t<DstData, pack_size>::type;
+
+                static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) {
+                    ck::tensor_operation::element_wise::PassThroughPack8{}(
+                        dst_tmp_vector.template AsType<dst_v_t>()(i),
+                        src_tmp_vector.template AsType<src_v_t>()[i]);
+                });
+
+                // copy data from dst_tmp_vector into dst_buf
+                static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
+                    constexpr index_t dst_offset = dst_desc.CalculateOffset(
+                        to_multi_index(dst_slice_origin_idx) + src_data_idx +
+                        i * src_scalar_step_in_vector);
+
+                    if constexpr(InvalidElementAsNaN)
+                    {
+                        dst_buf(Number<dst_offset>{}) =
+                            is_src_valid ? dst_tmp_vector.template AsType<DstData>()[i]
+                                         : NumericLimits<DstData>::QuietNaN();
+                    }
+                    else
+                    {
+                        dst_buf(Number<dst_offset>{}) =
+                            dst_tmp_vector.template AsType<DstData>()[i];
+                        // type_convert<DstData>(src_vector.template AsType<SrcData>()[i]);
+                    }
+                });
+
+                if constexpr(idx_1d.value != num_access - 1)
+                {
+                    constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
+
+                    move_tensor_coordinate(
+                        src_desc, src_coord_, make_tensor_coordinate_step(src_desc, forward_step));
+                }
+            });
+        }
+        else
+        {
+            static_for<0, num_access, 1>{}([&](auto idx_1d) {
+                typename vector_type_maker<SrcData, SrcScalarPerVector>::type src_vector;
+
+                using src_vector_t =
+                    typename vector_type_maker<SrcData, SrcScalarPerVector>::type::type;
+
+                constexpr auto src_data_idx = SpaceFillingCurve::GetIndex(idx_1d);
+
+                const bool is_src_valid =
+                    coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc,
+                                                                                src_coord_);
+
+                // copy data from src_buf into src_vector
+                src_vector.template AsType<src_vector_t>()(Number<0>{}) =
+                    src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid);
+
+                // copy data from src_vector into dst_buf
+                static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
+                    constexpr index_t dst_offset = dst_desc.CalculateOffset(
+                        to_multi_index(dst_slice_origin_idx) + src_data_idx +
+                        i * src_scalar_step_in_vector);
+
+                    if constexpr(InvalidElementAsNaN)
+                    {
+                        dst_buf(Number<dst_offset>{}) =
+                            is_src_valid
+                                ? type_convert<DstData>(src_vector.template AsType<SrcData>()[i])
+                                : NumericLimits<DstData>::QuietNaN();
+                    }
+                    else
+                    {
+                        dst_buf(Number<dst_offset>{}) =
+                            type_convert<DstData>(src_vector.template AsType<SrcData>()[i]);
+                    }
+                });
+
+                if constexpr(idx_1d.value != num_access - 1)
+                {
+                    constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
+
+                    move_tensor_coordinate(
+                        src_desc, src_coord_, make_tensor_coordinate_step(src_desc, forward_step));
+                }
+            });
+        }
 
         // move src coordinate back to slice origin (or not)
         if constexpr(SrcResetCoordinateAfterRun)
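The pk_i4_t branch added above splits each access into three stages: one vector load of SrcScalarPerVector / PackedSize packed bytes into src_tmp_vector, a widening pass that converts pack_size = 8 elements per step via PassThroughPack8, and a final scatter of the converted elements into dst_buf. A standalone model of that flow (plain C++, not CK API; float stands in for DstData and the hi-then-lo nibble order is an assumption):

#include <array>
#include <cstdint>

constexpr int SrcScalarPerVector = 16; // logical elements per access
constexpr int PackedSize         = 2;  // int4 elements per byte
constexpr int pack_size          = 8;  // elements converted per inner step

static_assert(SrcScalarPerVector % pack_size == 0, "");

void run(const uint8_t* src_buf, float* dst_buf)
{
    // stage 1: one "vector load" of SrcScalarPerVector / PackedSize bytes
    std::array<uint8_t, SrcScalarPerVector / PackedSize> src_tmp{};
    for(int b = 0; b < SrcScalarPerVector / PackedSize; ++b)
        src_tmp[b] = src_buf[b];

    // stage 2: widen pack_size elements (= pack_size / PackedSize bytes) per step
    std::array<float, SrcScalarPerVector> dst_tmp{};
    for(int i = 0; i < SrcScalarPerVector / pack_size; ++i)
        for(int j = 0; j < pack_size / PackedSize; ++j)
        {
            uint8_t byte = src_tmp[i * (pack_size / PackedSize) + j];
            dst_tmp[i * pack_size + 2 * j + 0] = float((byte >> 4) & 0xf);
            dst_tmp[i * pack_size + 2 * j + 1] = float((byte >> 0) & 0xf);
        }

    // stage 3: scatter the converted elements into the destination
    for(int i = 0; i < SrcScalarPerVector; ++i)
        dst_buf[i] = dst_tmp[i];
}

Note how the source offset is divided by PackedSize in stage 1, mirroring src_coord_.GetOffset() / PackedSize in the real code: coordinates count logical int4 elements, while the buffer is addressed in bytes.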