gaoqiong / composable_kernel_ROCM · Commits

Commit 78314c0b
authored Feb 12, 2025 by mtgu0705

    init b preshuffle dequant in VGPR.

parent 518551b1

Showing 3 changed files with 141 additions and 37 deletions:

    include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp   +35 -7
    include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp        +33 -10
    include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp                +73 -20
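Reviewer note on the overall change: with this commit the preshuffled B operand is no longer expanded to the compute type on its way through LDS. The raw packed-int4 (pk_i4_t) data is instead kept in VGPRs as BDataType and dequantized register-to-register into the compute type just before the MFMA stage. As a minimal standalone illustration of the packed format this relies on (not CK code; the function name and nibble order are assumptions for the example), two signed 4-bit values share one byte and are sign-extended on unpack:

    // Standalone sketch: how two signed 4-bit values packed in one byte
    // (the pk_i4_t idea) expand to two wider values during dequantization.
    #include <cstdint>
    #include <cstdio>

    // Assumption for illustration: low nibble is element 0, high nibble is
    // element 1, both interpreted as signed 4-bit integers in [-8, 7].
    static void unpack_pk_i4(uint8_t packed, int8_t out[2])
    {
        out[0] = static_cast<int8_t>(packed << 4) >> 4; // sign-extend low nibble
        out[1] = static_cast<int8_t>(packed) >> 4;      // sign-extend high nibble
    }

    int main()
    {
        uint8_t packed = 0x9F; // high nibble 0x9 (-7), low nibble 0xF (-1)
        int8_t v[2];
        unpack_pk_i4(packed, v);
        std::printf("%d %d\n", v[0], v[1]); // prints: -1 -7
        return 0;
    }

Keeping the packed bytes in registers halves the register and bandwidth cost of staging B, at the price of an explicit VGPR-to-VGPR dequantization pass, which the hunks below wire into the pipeline.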
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp
@@ -222,6 +222,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
               typename BBlockTransfer,
               typename BGridBuffer,
               typename BBlockBuffer,
+              typename BThreadTransfer,
               typename BBlockTransferStep,
               typename CThreadBuffer>
     __device__ void Run(const AGridDesc& a_grid_desc,
@@ -235,6 +236,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
                         const BGridBuffer& b_grid_buf,
                         BBlockBuffer& b_block_buf,
                         const BBlockTransferStep& b_block_copy_step,
+                        BThreadTransfer& b_thread_dequant_copy,
                         CThreadBuffer& c_thread_buf,
                         index_t num_loop) const
     {
@@ -242,12 +244,17 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
         __builtin_amdgcn_sched_barrier(0);
         auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
             a_thread_desc_.GetElementSpaceSize());
-        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
+        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
+            b_thread_desc_.GetElementSpaceSize());
+        auto b_thread_dequant_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
             b_thread_desc_.GetElementSpaceSize());
         StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
         constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);
+        StaticallyIndexedArray<decltype(b_thread_dequant_buf), Number<2>{}> b_thread_dequant_bufs;

         // Global prefetch A1 B1
         a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
         b_blockwise_copy.Run(b_grid_desc,
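Reviewer note: b_thread_buf now holds raw BDataType while the new ComputeDataType buffer receives the dequantized values, and both are duplicated via StaticallyIndexedArray<..., Number<2>{}> so the pipeline can fill one buffer while the MFMAs consume the other. A minimal standalone sketch of that ping-pong pattern (names and buffer sizes invented for the example, not CK code):

    // Standalone sketch of the double-buffer ("ping-pong") pattern behind
    // b_thread_bufs / b_thread_dequant_bufs.
    #include <array>
    #include <cstdio>

    template <typename T, int N>
    struct ThreadBuffer { std::array<T, N> data{}; };

    int main()
    {
        // Two raw buffers (packed B data) and two dequantized buffers, so the
        // copy for iteration i+1 can overlap the MFMA reads of iteration i.
        ThreadBuffer<unsigned char, 16> raw[2];
        ThreadBuffer<float, 32>         dequant[2];

        for(int i = 0; i < 4; ++i)
        {
            int cur = i % 2;       // buffer consumed by the compute stage
            int nxt = (i + 1) % 2; // buffer being filled/dequantized meanwhile
            std::printf("iter %d: compute on buf %d, fill buf %d\n", i, cur, nxt);
            (void)raw[nxt];
            (void)dequant[cur];
        }
        return 0;
    }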
@@ -279,6 +286,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
                     a_thread_buf);
             });
         });

+        // B VGPR->VGPR dequant
+        b_thread_dequant_copy.Run(b_block_desc_n0_n1_k0_k1,
+                                  b_block_origin_idx,
+                                  b_thread_bufs(I0),
+                                  b_thread_desc_,
+                                  make_tuple(I0, I0, I0, I0),
+                                  b_thread_dequant_bufs(I0));

         // Initialize C
         c_thread_buf.Clear();
@@ -316,9 +330,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
                                 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                     make_tuple(m0, I0, I0, k0, I0, ik))>{}];
                             b_thread_vec.template AsType<ComputeDataType>()(ik) =
-                                b_thread_bufs[mfma_reg_buf]
+                                b_thread_dequant_bufs[mfma_reg_buf]
                                     [Number<b_thread_desc_.CalculateOffset(
                                         make_tuple(n0, I0, k0, ik))>{}];
                         });

                         using mfma_input_type =
@@ -348,6 +362,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
                                 a_thread_buf);
                         });
                     });

+                    // B VGPR->VGPR dequant
+                    b_thread_dequant_copy.Run(b_block_desc_n0_n1_k0_k1,
+                                              b_block_origin_idx,
+                                              b_thread_bufs(mfma_reg_buf),
+                                              b_thread_desc_,
+                                              make_tuple(I0, I0, I0, I0),
+                                              b_thread_dequant_bufs(mfma_reg_buf));

                     HotLoopScheduler();
                     __builtin_amdgcn_sched_barrier(0);
@@ -382,7 +403,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
                             a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                 make_tuple(m0, I0, I0, k0, I0, ik))>{}];
                         b_thread_vec.template AsType<ComputeDataType>()(ik) =
-                            b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
+                            b_thread_dequant_bufs[I0][Number<b_thread_desc_.CalculateOffset(
                                 make_tuple(n0, I0, k0, ik))>{}];
                     });
@@ -411,6 +432,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
                     a_thread_buf);
             });
         });

+        // B VGPR->VGPR dequant
+        b_thread_dequant_copy.Run(b_block_desc_n0_n1_k0_k1,
+                                  b_block_origin_idx,
+                                  b_thread_bufs(I1),
+                                  b_thread_desc_,
+                                  make_tuple(I0, I0, I0, I0),
+                                  b_thread_dequant_bufs(I1));

         __builtin_amdgcn_sched_barrier(0);
@@ -425,7 +453,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
                             a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                 make_tuple(m0, I0, I0, k0, I0, ik))>{}];
                         b_thread_vec.template AsType<ComputeDataType>()(ik) =
-                            b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
+                            b_thread_dequant_bufs[I1][Number<b_thread_desc_.CalculateOffset(
                                 make_tuple(n0, I0, k0, ik))>{}];
                     });
@@ -458,7 +486,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
                             a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                 make_tuple(m0, I0, I0, k0, I0, ik))>{}];
                         b_thread_vec.template AsType<ComputeDataType>()(ik) =
-                            b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
+                            b_thread_dequant_bufs[I0][Number<b_thread_desc_.CalculateOffset(
                                 make_tuple(n0, I0, k0, ik))>{}];
                     });
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp
@@ -1134,7 +1134,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
             p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

         const AElementwiseOperation a_element_op{};
-        // const BElementwiseOperation b_element_op{};
+        const BElementwiseOperation b_element_op{};
         const CElementwiseOperation c_element_op{};

         // divide block work by [M, N]
@@ -1205,8 +1205,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
         auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2<
             BDataType,
-            // BDataType,
-            BDataType,
+            ADataType,
             decltype(b_grid_desc_bpreshuffled),
             decltype(b_block_desc_bk0_n_bk1),
             Sequence<Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}>,
@@ -1220,18 +1219,24 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
             0, 0, KPack * (get_thread_local_1d_id() % warpSize)));

+        // B: VGPR->VGPR dequantization
+        auto b_thread_dequant_copy = ThreadwiseTensorSliceTransfer_StaticToStatic<
+            BDataType,
+            ComputeTypeA,
+            decltype(b_block_desc_bk0_n_bk1),
+            decltype(b_block_desc_bk0_n_bk1),
+            tensor_operation::element_wise::PassThrough,
+            Sequence<Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}>,
+            Sequence<0, 1, 2, 3>,
+            3,
+            BK1Number>(b_element_op);

         // LDS allocation for A and B: be careful of alignment
         // Cast after lds
         auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
             static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());

-        // auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-        //     reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) + a_block_space_size_aligned *
-        //     sizeof(ADataType) /
-        //     APackedSize),
-        //     b_block_desc_bk0_n_bk1.GetElementSpaceSize());

         constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
         constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0);
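Reviewer note: ThreadwiseTensorSliceTransfer_StaticToStatic is the register-to-register copy that performs the dequantization here; it walks a compile-time slice, applies the elementwise op, and converts each element from BDataType to ComputeTypeA. A simplified standalone sketch of that idea (not the CK implementation; the container types, function name, and PassThrough lambda are assumptions for illustration):

    // Standalone sketch of a "static-to-static" threadwise transfer: walk a
    // fixed-size slice and convert each element from the stored register type
    // to the compute type, applying an elementwise op along the way.
    #include <array>
    #include <cstddef>
    #include <cstdio>

    template <typename Src, typename Dst, std::size_t N, typename Op>
    void static_to_static_transfer(const std::array<Src, N>& src,
                                   std::array<Dst, N>& dst,
                                   Op op)
    {
        for(std::size_t i = 0; i < N; ++i)
            dst[i] = op(static_cast<Dst>(src[i])); // element op + type convert
    }

    int main()
    {
        std::array<signed char, 4> raw{-1, 2, -3, 4}; // stand-in for int4 data
        std::array<float, 4> dequant{};
        static_to_static_transfer(raw, dequant, [](float x) { return x; }); // PassThrough
        std::printf("%g %g %g %g\n", dequant[0], dequant[1], dequant[2], dequant[3]);
        return 0;
    }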
@@ -1255,6 +1260,9 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
                                                b_grid_buf,
                                                b_block_buf,
                                                b_block_slice_copy_step,
+                                               // B: VGPR->VGPR dequantization
+                                               b_thread_dequant_copy,
                                                c_thread_buf,
                                                num_k_block_main_loop);
@@ -1514,7 +1522,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
             p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

         const AElementwiseOperation a_element_op{};
-        // const BElementwiseOperation b_element_op{};
+        const BElementwiseOperation b_element_op{};
         const CElementwiseOperation c_element_op{};

         // divide block work by [M, N]
@@ -1604,6 +1612,18 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
             0, 0, KPack * (get_thread_local_1d_id() % warpSize)));

+        // B: VGPR->VGPR dequantization
+        auto b_thread_dequant_copy = ThreadwiseTensorSliceTransfer_StaticToStatic<
+            BDataType,
+            ComputeTypeA,
+            decltype(b_block_desc_bk0_n_bk1),
+            decltype(b_block_desc_bk0_n_bk1),
+            tensor_operation::element_wise::PassThrough,
+            Sequence<Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}>,
+            Sequence<0, 1, 2, 3>,
+            3,
+            BK1Number>(b_element_op);

         // LDS allocation for A and B: be careful of alignment
         auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
             static_cast<ADataType*>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
@@ -1636,6 +1656,9 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
                                                b_grid_buf,
                                                b_block_bufs,
                                                b_block_slice_copy_step,
+                                               // B: VGPR->VGPR dequantization
+                                               b_thread_dequant_copy,
                                                c_thread_buf,
                                                num_k_block_main_loop);
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
@@ -287,6 +287,7 @@ struct ThreadwiseTensorSliceTransfer_v2
         // loop over tensor and copy
         constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();

+#if 0
         if constexpr(is_same<remove_cvref_t<SrcData>, pk_i4_t>::value)
         {
             static_for<0, num_access, 1>{}([&](auto idx_1d) {
@@ -352,12 +353,13 @@ struct ThreadwiseTensorSliceTransfer_v2
             });
         }
         else
+#endif
         {
             static_for<0, num_access, 1>{}([&](auto idx_1d) {
-                typename vector_type_maker<SrcData, SrcScalarPerVector>::type src_vector;
+                typename vector_type_maker<SrcData, SrcScalarPerVector / PackedSize>::type src_vector;

                 using src_vector_t =
-                    typename vector_type_maker<SrcData, SrcScalarPerVector>::type::type;
+                    typename vector_type_maker<SrcData, SrcScalarPerVector / PackedSize>::type::type;

                 constexpr auto src_data_idx = SpaceFillingCurve::GetIndex(idx_1d);

                 const bool is_src_valid =
@@ -365,24 +367,24 @@ struct ThreadwiseTensorSliceTransfer_v2
                 // copy data from src_buf into src_vector
                 src_vector.template AsType<src_vector_t>()(Number<0>{}) =
-                    src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid);
+                    src_buf.template Get<src_vector_t>(src_coord_.GetOffset() / PackedSize, is_src_valid);

                 // copy data from src_vector into dst_buf
-                static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
+                static_for<0, SrcScalarPerVector / PackedSize, 1>{}([&](auto i) {
                     constexpr index_t dst_offset =
                         dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx +
                                                  i * src_scalar_step_in_vector);

                     if constexpr(InvalidElementAsNaN)
                     {
-                        dst_buf(Number<dst_offset>{}) =
+                        dst_buf(Number<dst_offset / PackedSize>{}) =
                             is_src_valid
                                 ? type_convert<DstData>(src_vector.template AsType<SrcData>()[i])
                                 : NumericLimits<DstData>::QuietNaN();
                     }
                     else
                     {
-                        dst_buf(Number<dst_offset>{}) =
+                        dst_buf(Number<dst_offset / PackedSize>{}) =
                             type_convert<DstData>(src_vector.template AsType<SrcData>()[i]);
                     }
                 });
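Reviewer note: the recurring `/ PackedSize` corrections account for the buffer being addressed in stored (packed) units while the descriptors count logical elements; with two int4 values per byte, element offsets halve when indexing storage. A minimal standalone illustration of that bookkeeping (constants invented for the example):

    // Standalone sketch of packed addressing: with two 4-bit elements per
    // stored byte (PackedSize == 2), offsets computed in "logical element"
    // units must be divided by PackedSize before indexing the byte buffer.
    #include <cstdio>

    int main()
    {
        constexpr int PackedSize     = 2;  // two 4-bit elements per stored byte
        constexpr int logical_offset = 10; // offset in elements
        constexpr int stored_offset  = logical_offset / PackedSize; // offset in bytes
        std::printf("element %d lives in byte %d\n", logical_offset, stored_offset);
        return 0;
    }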
@@ -1544,6 +1546,13 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
     using Index = MultiIndex<nDim>;

+    static constexpr index_t PackedSize = []() {
+        if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
+            return 2;
+        else
+            return 1;
+    }();

     __device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic(
         const ElementwiseOperation& element_op)
         : element_op_{element_op}
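Reviewer note: the immediately-invoked constexpr lambda makes PackedSize equal 2 only for pk_i4_t sources and 1 otherwise, so the divisions in the hunks above collapse at compile time for regular types. The same pattern, reduced to a standalone compilable C++17 sketch (struct and member names invented; CK's pk_i4_t is stubbed):

    // Standalone sketch of the immediately-invoked constexpr lambda used to
    // derive PackedSize from the source data type.
    #include <type_traits>

    struct pk_i4_t { signed char storage; }; // stand-in for ck::pk_i4_t

    template <typename SrcData>
    struct TransferSketch
    {
        static constexpr int PackedSize = []() {
            if constexpr(std::is_same_v<std::remove_cv_t<SrcData>, pk_i4_t>)
                return 2; // two 4-bit values per stored byte
            else
                return 1; // one element per stored unit
        }();
    };

    static_assert(TransferSketch<pk_i4_t>::PackedSize == 2);
    static_assert(TransferSketch<float>::PackedSize == 1);

    int main() { return 0; }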
@@ -1598,26 +1607,70 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
         constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();

-        static_for<0, num_access, 1>{}([&](auto idx_1d) {
-            constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d);
-
-            // copy data from src_buf into dst_vector
-            static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
-                constexpr index_t src_offset = src_desc.CalculateOffset(
-                    src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
-
-                constexpr index_t dst_offset = dst_desc.CalculateOffset(
-                    dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
-
-                DstData v;
-
-                // apply element-wise operation
-                element_op_(v, src_buf[Number<src_offset>{}]);
-
-                // apply type convert
-                dst_buf(Number<dst_offset>{}) = v;
-            });
-        });
+        if constexpr(is_same<remove_cvref_t<SrcData>, pk_i4_t>::value)
+        {
+            static_for<0, num_access, 1>{}([&](auto idx_1d) {
+                typename vector_type_maker<SrcData, DstScalarPerVector / PackedSize>::type
+                    src_tmp_vector;
+
+                constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d);
+
+                // copy data from src_buf into src_tmp_vector
+                static_for<0, DstScalarPerVector / PackedSize, 1>{}([&](auto i) {
+                    constexpr index_t src_offset = src_desc.CalculateOffset(
+                        src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
+
+                    src_tmp_vector.template AsType<SrcData>()(i) =
+                        src_buf[Number<src_offset / PackedSize>{}];
+                });
+
+                // copy data from src_tmp_vector to dst_tmp_vector (cast data from SrcData to
+                // DstData)
+                vector_type_maker_t<DstData, DstScalarPerVector> dst_tmp_vector;
+
+                constexpr index_t pack_size = 8;
+
+                static_assert(DstScalarPerVector % pack_size == 0, "");
+
+                using src_v_t = typename vector_type_maker_t<SrcData, pack_size / PackedSize>::type;
+                using dst_v_t = typename vector_type_maker_t<DstData, pack_size>::type;
+
+                static_for<0, DstScalarPerVector / pack_size, 1>{}([&](auto i) {
+                    ck::tensor_operation::element_wise::PassThroughPack8{}(
+                        dst_tmp_vector.template AsType<dst_v_t>()(i),
+                        src_tmp_vector.template AsType<src_v_t>()[i]);
+                });
+
+                // copy data from dst_tmp_vector into dst_buf
+                static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
+                    constexpr index_t dst_offset = dst_desc.CalculateOffset(
+                        dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
+
+                    dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
+                });
+            });
+        }
+        else
+        {
+            static_for<0, num_access, 1>{}([&](auto idx_1d) {
+                constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d);
+
+                // copy data from src_buf into dst_vector
+                static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
+                    constexpr index_t src_offset = src_desc.CalculateOffset(
+                        src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
+
+                    constexpr index_t dst_offset = dst_desc.CalculateOffset(
+                        dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
+
+                    DstData v;
+
+                    // apply element-wise operation
+                    element_op_(v, src_buf[Number<src_offset>{}]);
+
+                    // apply type convert
+                    dst_buf(Number<dst_offset>{}) = v;
+                });
+            });
+        }
     }

     ElementwiseOperation element_op_;
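Reviewer note on the pk_i4_t branch above: PassThroughPack8 converts a group of eight logical elements per call (four packed bytes widen to eight compute values), which is why the code asserts DstScalarPerVector % 8 == 0 before the conversion loop. A standalone sketch of the pack-of-8 expansion (function name and nibble layout are assumptions for illustration, not CK's actual implementation):

    // Standalone sketch of the pack-of-8 idea behind PassThroughPack8:
    // 4 packed-int4 bytes expand into 8 dequantized values per call.
    #include <cstdint>
    #include <cstdio>

    static void pass_through_pack8(const uint8_t src[4], float dst[8])
    {
        for(int b = 0; b < 4; ++b)
        {
            dst[2 * b]     = static_cast<int8_t>(src[b] << 4) >> 4; // low nibble
            dst[2 * b + 1] = static_cast<int8_t>(src[b]) >> 4;      // high nibble
        }
    }

    int main()
    {
        const uint8_t packed[4] = {0x10, 0x32, 0x54, 0x76};
        float out[8];
        pass_through_pack8(packed, out);
        for(float v : out)
            std::printf("%g ", v); // prints: 0 1 2 3 4 5 6 7
        std::printf("\n");
        return 0;
    }

Converting in groups of eight lets the real implementation use wide vector types for both sides of the cast instead of scalar per-element conversion.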