gaoqiong / composable_kernel · Commits

Commit bc5d7b6a, authored Aug 02, 2023 by Jing Zhang

    remove zero

Parent: 5ca6b1f8

3 changed files with 264 additions and 98 deletions:

    example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp                               +24  -16
    include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp         +44  -7
    include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp    +196 -75
example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp
@@ -222,11 +222,17 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
     auto argument = gemm.MakeArgument(
         p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, c_element_op);
 
-    DeviceMem gemm_desc_workspace(gemm.GetWorkSpaceSize(&argument));
+    std::size_t grouped_gemm_kernel_args_buf_size =
+        grouped_gemm_kernel_args_.size() * sizeof(GroupedGemmKernelArgument);
 
-    hip_check_error(hipMemcpy(gemm_desc_workspace.GetDeviceBuffer(),
+    DeviceMem gemm_arg_dev_mem(grouped_gemm_kernel_args_buf_size);
+    DeviceMem gemm_workspace_dev(gemm.GetWorkSpaceSize(&argument));
+
+    gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer());
+
+    hip_check_error(hipMemcpy(gemm_arg_dev_mem.GetDeviceBuffer(),
                               grouped_gemm_kernel_args_.data(),
-                              gemm.GetWorkSpaceSize(&argument),
+                              grouped_gemm_kernel_args_buf_size,
                               hipMemcpyHostToDevice));
 
     if(!gemm.IsSupportedArgument(argument))
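The device-side argument buffer is now sized from the number of per-group kernel arguments rather than from GetWorkSpaceSize, because the workspace is repurposed further down for barrier counters. A minimal, hypothetical HIP host sketch of the same packing pattern (the struct, sizes, and error handling are illustrative, not the example's real types):

#include <hip/hip_runtime.h>
#include <cstddef>
#include <cstdio>
#include <vector>

// Hypothetical per-group argument blob, standing in for GroupedGemmKernelArgument.
struct KernelArg
{
    const void* a_ptr;
    const void* b_ptr;
    void* e_ptr;
    int M, N, K;
};

int main()
{
    std::vector<KernelArg> host_args(16); // one entry per GEMM group

    // Size the device buffer by element count, not by the op's workspace size.
    const std::size_t args_bytes = host_args.size() * sizeof(KernelArg);

    void* dev_args = nullptr;
    if(hipMalloc(&dev_args, args_bytes) != hipSuccess)
    {
        std::puts("hipMalloc failed");
        return 1;
    }

    // Upload the packed per-group arguments once before launching the kernel.
    if(hipMemcpy(dev_args, host_args.data(), args_bytes, hipMemcpyHostToDevice) != hipSuccess)
    {
        std::puts("hipMemcpy failed");
        return 1;
    }

    hipFree(dev_args);
    return 0;
}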
@@ -236,11 +242,21 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
             "not support this GEMM problem");
     }
 
-    gemm.SetDeviceKernelArgs(argument, gemm_desc_workspace.GetDeviceBuffer());
+    gemm.SetDeviceKernelArgs(argument, gemm_arg_dev_mem.GetDeviceBuffer());
+    gemm.SetKBatch(argument, config.k_batch);
 
     invoker.Run(argument, StreamConfig{nullptr, false});
 
+    if(config.time_kernel)
+    {
+        float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
+
+        float tflops     = static_cast<float>(flop) / 1.E9 / ave_time;
+        float gb_per_sec = num_btype / 1.E6 / ave_time;
+
+        std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
+                  << " GB/s, " << gemm.GetTypeString() << std::endl;
+    }
+
     bool pass = true;
 
     if(config.do_verification)
     {
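With ave_time measured in milliseconds, flop / 1.E9 / ave_time yields TFLOP/s and num_btype / 1.E6 / ave_time yields GB/s. A small self-contained check of that arithmetic (the FLOP and byte counts below are made-up placeholders, not values from the example):

#include <cstdio>

int main()
{
    // Placeholder problem: pretend the grouped GEMM performed 2.0e11 FLOPs,
    // moved 6.0e8 bytes, and the timed run averaged 2.5 ms.
    const double flop      = 2.0e11;
    const double num_btype = 6.0e8;
    const double ave_time  = 2.5; // milliseconds

    // FLOP / 1e9 gives GFLOP; dividing by milliseconds gives GFLOP/ms == TFLOP/s.
    const double tflops = flop / 1.0e9 / ave_time;

    // bytes / 1e6 gives MB; dividing by milliseconds gives MB/ms == GB/s.
    const double gb_per_sec = num_btype / 1.0e6 / ave_time;

    std::printf("Perf: %.3f ms, %.3f TFlops, %.3f GB/s\n", ave_time, tflops, gb_per_sec);
    return 0;
}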
@@ -273,16 +289,6 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
         }
     }
 
-    if(config.time_kernel)
-    {
-        float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
-
-        float tflops     = static_cast<float>(flop) / 1.E9 / ave_time;
-        float gb_per_sec = num_btype / 1.E6 / ave_time;
-
-        std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
-                  << " GB/s, " << gemm.GetTypeString() << std::endl;
-    }
-
     return pass;
 }
@@ -293,8 +299,10 @@ int main(int argc, char* argv[])
     problem_size.group_count = 16;
 
-    problem_size.Ms = {
-        167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 204, 184, 168, 156, 168, 148};
+    // problem_size.Ms = {
+    //     167, 183, 177, 181, 153, 139, 156, 173, 163, 150, 180, 184, 168, 156, 168, 148};
+
+    problem_size.Ms = {0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0};
 
     for(int i = 0; i < problem_size.group_count; i++)
     {
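The Ms list now makes most of the 16 groups empty (M == 0). A quick host-side check of what those sizes add up to, using the sum-of-M quantity that the device op derives below (plain arithmetic, no CK types):

#include <algorithm>
#include <cstdio>
#include <numeric>
#include <vector>

int main()
{
    // The Ms now used by the example: most groups are empty.
    const std::vector<int> Ms = {0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0};

    const int sum_of_m  = std::accumulate(Ms.begin(), Ms.end(), 0);
    const int non_empty = static_cast<int>(
        std::count_if(Ms.begin(), Ms.end(), [](int m) { return m > 0; }));

    std::printf("group_count = %zu, sum_of_m = %d, non-empty groups = %d\n",
                Ms.size(), sum_of_m, non_empty);
    return 0;
}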
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp
#pragma once
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
@@ -41,6 +40,8 @@ __global__ void
         __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
         kernel_grouped_gemm_xdl_fixed_nk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
+                                         uint32_t* barrier_count,
+                                         const index_t barrier_size_grp,
                                          const index_t group_count,
                                          const index_t grid_size_grp,
                                          const index_t KBatch,
@@ -95,13 +96,27 @@ __global__ void
             p_ds_grid_(i) = static_cast<const DDataType*>(gemm_desc_ptr[group_id].p_ds_grid[i]);
         });
 
-        index_t id_off = 0;
+        index_t id_off   = 0;
+        index_t id_local = get_block_1d_id() - BlockStart;
+
+        const index_t mn_blocks = local_grid_size / KBatch;
+
+        __shared__ index_t k_id_start, k_id_finished;
+
+        ignore = barrier_count;
+        ignore = k_id_start;
+        ignore = k_id_finished;
 
-        while((get_block_1d_id() - BlockStart + id_off) < local_grid_size)
+        while(id_local < local_grid_size)
         {
             const auto block_2_etile_map =
                 GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart, id_off);
 
+            auto barrier_count_start =
+                barrier_count + group_id * barrier_size_grp * 2 + id_local % mn_blocks;
+            auto barrier_count_finished = barrier_count + group_id * barrier_size_grp * 2 +
+                                          barrier_size_grp + id_local % mn_blocks;
+
             GridwiseGemm::template Run<HasMainKBlockLoop,
                                        EGlobalMemoryDataOperation,
                                        GemmSpec,
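Each group gets a contiguous slab of 2 * barrier_size_grp counters in the workspace: the first barrier_size_grp entries act as "start" counters and the next barrier_size_grp as "finished" counters, one pair per output (M, N) tile. A small host-side sketch of that indexing, mirroring the pointer arithmetic added above (names and numbers are illustrative):

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

struct BarrierSlots
{
    std::size_t start;
    std::size_t finished;
};

// group_id:         which GEMM group this block works on
// barrier_size_grp: number of (M, N) output tiles reserved per group
// id_local:         block index inside the group's sub-grid
// mn_blocks:        local_grid_size / KBatch, i.e. tiles covered per k-slice
BarrierSlots barrier_slots(int group_id, int barrier_size_grp, int id_local, int mn_blocks)
{
    const std::size_t base = static_cast<std::size_t>(group_id) * barrier_size_grp * 2;
    const std::size_t tile = static_cast<std::size_t>(id_local % mn_blocks);
    return {base + tile, base + barrier_size_grp + tile};
}

int main()
{
    // Hypothetical sizes: 4 tiles per group, 2 groups.
    const int barrier_size_grp = 4, mn_blocks = 4;
    std::vector<uint32_t> workspace(2 * 2 * barrier_size_grp, 0);

    const auto s = barrier_slots(/*group_id=*/1, barrier_size_grp, /*id_local=*/6, mn_blocks);
    std::printf("start counter at %zu, finished counter at %zu (workspace of %zu counters)\n",
                s.start, s.finished, workspace.size());
    return 0;
}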
@@ -113,6 +128,8 @@ __global__ void
                                        p_ds_grid_,
                                        gemm_desc_ptr[group_id].p_e_grid,
                                        p_shared,
+                                       barrier_count_start,
+                                       barrier_count_finished,
                                        a_element_op,
                                        b_element_op,
                                        c_element_op,
@@ -127,6 +144,7 @@ __global__ void
                                        block_2_etile_map);
 
             id_off += grid_size_grp;
+            id_local += grid_size_grp;
         }
 #else
     ignore = gemm_descs_const;
@@ -430,8 +448,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
             const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch_};
 
             grid_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n);
 
-            grid_size_ = grid_size_grp_ * group_count_;
             grid_size_ = grid_size_grp_ * group_count_;
         }
 
         Argument(std::vector<const void*>& p_As,
@@ -568,6 +585,14 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
                 group_id++;
             }
+
+            const auto e_grid_desc_sum_m_n =
+                GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(
+                    sum_of_m, gemm_desc_kernel_arg_[0].N_, gemm_desc_kernel_arg_[0].StrideE_);
+
+            const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_sum_m_n, 1};
+
+            barrier_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_sum_m_n);
         }
 
         // private:
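barrier_size_grp_ is the tile count of an E grid built over the summed M of all groups with k-batch 1, which appears to upper-bound the number of output tiles any single group can produce. A rough sketch of that count under a simple MPerBlock x NPerBlock tiling (the real Block2ETileMap may pad or round differently, so treat this as an approximation):

#include <cstdio>

// ceil-division helper
constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

// Approximates Block2ETileMap{e_grid_desc_sum_m_n, 1}.CalculateGridSize(...):
// one workgroup per (MPerBlock x NPerBlock) tile of a sum_of_m x N output.
int approx_barrier_size_grp(int sum_of_m, int N, int MPerBlock, int NPerBlock)
{
    return ceil_div(sum_of_m, MPerBlock) * ceil_div(N, NPerBlock);
}

int main()
{
    // Hypothetical numbers: Ms sum to 2, N = 768, 128x128 output tiles.
    const int tiles = approx_barrier_size_grp(/*sum_of_m=*/2, /*N=*/768, 128, 128);
    std::printf("approx barrier_size_grp = %d tiles -> %d uint32 counters per group\n",
                tiles, 2 * tiles);
    return 0;
}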
@@ -585,6 +610,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
         index_t grid_size_;
         index_t grid_size_grp_;
+        index_t barrier_size_grp_;
 
         index_t sum_of_m;
 
         index_t k_batch_;
@@ -642,6 +668,8 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
                 dim3(BlockSize),
                 0,
                 cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev),
+                reinterpret_cast<uint32_t*>(arg.p_workspace_),
+                arg.barrier_size_grp_,
                 arg.gemm_desc_kernel_arg_.size(),
                 arg.grid_size_grp_,
                 arg.k_batch_,
@@ -808,8 +836,17 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
     size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
     {
-        return dynamic_cast<const Argument*>(p_arg)->group_count_ *
-               sizeof(GroupedGemmKernelArgument<NumDTensor>);
+        auto arg = *dynamic_cast<const Argument*>(p_arg);
+
+        return arg.group_count_ * (arg.barrier_size_grp_ * 2) * sizeof(uint32_t);
     }
 
     void SetWorkSpacePointer(BaseArgument* p_arg, void* p_workspace) const override
     {
         auto p_arg_          = dynamic_cast<Argument*>(p_arg);
         p_arg_->p_workspace_ = p_workspace;
+
+        hip_check_error(hipMemset(p_workspace, 0, GetWorkSpaceSize(p_arg)));
     }
 
+    static void SetKBatch(Argument& arg, index_t k_batch) { arg.UpdateKBatch(k_batch); }
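The workspace reported by GetWorkSpaceSize no longer holds the packed kernel arguments; it now stores the barrier counters, two uint32_t per tile per group, and SetWorkSpacePointer zero-fills it with hipMemset. A worked example of the size arithmetic, with made-up counts:

#include <cstddef>
#include <cstdint>
#include <cstdio>

// Mirrors the new GetWorkSpaceSize computation:
// group_count * (barrier_size_grp * 2) * sizeof(uint32_t)
std::size_t barrier_workspace_bytes(std::size_t group_count, std::size_t barrier_size_grp)
{
    return group_count * (barrier_size_grp * 2) * sizeof(std::uint32_t);
}

int main()
{
    // Hypothetical: 16 groups, 6 output tiles reserved per group.
    const std::size_t bytes = barrier_workspace_bytes(16, 6);
    std::printf("barrier workspace: %zu bytes\n", bytes); // 16 * 12 * 4 = 768
    return 0;
}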
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp
@@ -475,6 +475,9 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
                                DsGridPointer p_ds_grid,
                                EDataType* __restrict__ p_e_grid,
                                void* __restrict__ p_shared,
+                               uint32_t* barrier_count_start,
+                               uint32_t* barrier_count_finished,
+                               const index_t KBatch,
                                const AElementwiseOperation& a_element_op,
                                const BElementwiseOperation& b_element_op,
                                const CDEElementwiseOperation_& cde_element_op,
@@ -492,17 +495,6 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
         const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_b_grid, b_grid_desc_kbatch_bk0_n_bk1.GetElementSpaceSize());
 
-        const auto ds_grid_buf = generate_tuple(
-            [&](auto i) {
-                return make_dynamic_buffer<AddressSpaceEnum::Global>(
-                    p_ds_grid[i],
-                    ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
-            },
-            Number<NumDTensor_>{});
-
-        auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
-
         // divide block work by [M, N]
         const auto block_work_idx =
             block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
@@ -661,8 +653,38 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
             c_thread_buf,
             num_k_block_main_loop);
 
+        // ignore = barrier_count_start;
+        // ignore = barrier_count_finished;
+        // ignore = KBatch;
+
+        __shared__ index_t k_id_start_shared;
+
+        if(threadIdx.x == 0)
+        {
+            const auto k_id_start_t = atomicAdd(barrier_count_start, 1);
+
+            k_id_start_shared = k_id_start_t;
+
+            if(k_id_start_t > 0)
+            {
+                while(__atomic_load_n(barrier_count_finished, __ATOMIC_RELAXED) == 0) {}
+            }
+        }
+
+        __syncthreads();
+
         // shuffle C and write out
         {
+            const auto ds_grid_buf = generate_tuple(
+                [&](auto i) {
+                    return make_dynamic_buffer<AddressSpaceEnum::Global>(
+                        p_ds_grid[i],
+                        ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
+                },
+                Number<NumDTensor_>{});
+
+            auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+                p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
+
             static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                               NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
                           "wrong!");
@@ -799,36 +821,6 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
                 },
                 Number<NumDTensor_>{}));
 
-            // blockwise copy C/D/E between LDS and global
-            auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7<
-                ThisThreadBlock,
-                decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType_{})),
-                Tuple<EDataType>,
-                decltype(c_ds_desc_refs),
-                decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
-                CDEElementwiseOperation_,
-                Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
-                                                                            // support arbitray type
-                Sequence<1,
-                         CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
-                         1,
-                         CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
-                CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
-                Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
-                Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
-                3,                    // index_t VectorDim,
-                CDEShuffleBlockTransferScalarPerVector_NPerBlock,
-                sequence_merge_t<Sequence<true>,
-                                 uniform_sequence_gen_t<NumDTensor_, false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
-                Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
-                {c_ds_desc_refs,
-                 idx_c_ds_block_begin,
-                 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
-                 make_tuple(make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0)),
-                 cde_element_op};
-
             // space filling curve for threadwise C in VGPR before shuffle
             constexpr auto sfc_c_vgpr = SpaceFillingCurve<Sequence<MXdlPerWave,
                                                                    NXdlPerWave,
                                                                    1,
                                                                    1,
                                                                    M2,
                                                                    1,
                                                                    M4,
                                                                    1>,
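The single ThreadGroupTensorSliceTransfer_v7 instance is removed here because the global-memory data operation is a compile-time template parameter: in the next hunk the first k-slice to reach a tile plainly sets the E tile, while later slices apply the configured EGlobalMemoryDataOperation, so each branch instantiates its own copy object. A tiny stand-alone illustration of selecting behavior through an enum template parameter (not CK's real classes, just the general C++ pattern):

#include <cstdio>

enum class MemOp { Set, Add };

// Behavior is fixed at compile time by the MemOp template argument, analogous to
// the Sequence<static_cast<index_t>(...)> parameter of the CK transfer class.
template <MemOp Op>
void store(float* dst, float val)
{
    if constexpr(Op == MemOp::Set) { *dst = val; }  // first k-slice overwrites
    else                           { *dst += val; } // later k-slices accumulate
}

int main()
{
    float e = 123.0f;              // stale value left over from a previous use
    store<MemOp::Set>(&e, 1.5f);   // first partial result replaces the stale data
    store<MemOp::Add>(&e, 2.5f);   // subsequent partial results accumulate
    std::printf("e = %.1f\n", e);  // 4.0
    return 0;
}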
@@ -855,45 +847,166 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
             static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
 
-            static_for<0, num_access, 1>{}([&](auto access_id) {
-                // make sure it's safe to write to LDS
-                block_sync_lds();
-
-                // each thread write its data from VGPR to LDS
-                c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
-                                              sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
-                                              c_thread_buf,
-                                              c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
-                                              c_shuffle_block_buf);
-
-                // make sure it's safe to read from LDS
-                block_sync_lds();
-
-                // each block copy its data from LDS to global
-                cde_block_copy_lds_and_global.Run(
-                    c_ds_desc_refs,
-                    c_ds_buf_refs,
-                    tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
-                    tie(e_grid_buf));
-
-                if constexpr(access_id < num_access - 1)
-                {
-                    constexpr auto cde_lds_and_global_step = sfc_cde_block.GetForwardStep(access_id);
-
-                    // move on Ds
-                    static_for<0, NumDTensor_, 1>{}([&](auto i) {
-                        cde_block_copy_lds_and_global.MoveSrcSliceWindow(
-                            c_ds_desc_refs, i + I1, cde_lds_and_global_step);
-                    });
-
-                    // move on E
-                    cde_block_copy_lds_and_global.MoveDstSliceWindow(
-                        tie(e_grid_desc_mblock_mperblock_nblock_nperblock), I0, cde_lds_and_global_step);
-                }
-            });
+            if(k_id_start_shared == 0)
+            {
+                // blockwise copy C/D/E between LDS and global
+                auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7<
+                    ThisThreadBlock,
+                    decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType_{})),
+                    Tuple<EDataType>,
+                    decltype(c_ds_desc_refs),
+                    decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
+                    CDEElementwiseOperation_,
+                    Sequence<static_cast<index_t>(InMemoryDataOperationEnum::Set)>, // FIXME: make
+                                                                                    // Sequence
+                                                                                    // support
+                                                                                    // arbitray type
+                    Sequence<1,
+                             CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
+                             1,
+                             CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
+                    CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
+                    Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
+                    Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
+                    3,                    // index_t VectorDim,
+                    CDEShuffleBlockTransferScalarPerVector_NPerBlock,
+                    sequence_merge_t<
+                        Sequence<true>,
+                        uniform_sequence_gen_t<NumDTensor_, false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
+                    Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
+                    {c_ds_desc_refs,
+                     idx_c_ds_block_begin,
+                     tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
+                     make_tuple(make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0)),
+                     cde_element_op};
+
+                static_for<0, num_access, 1>{}([&](auto access_id) {
+                    // make sure it's safe to write to LDS
+                    block_sync_lds();
+
+                    // each thread write its data from VGPR to LDS
+                    c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+                                                  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
+                                                  c_thread_buf,
+                                                  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+                                                  c_shuffle_block_buf);
+
+                    // make sure it's safe to read from LDS
+                    block_sync_lds();
+
+                    // each block copy its data from LDS to global
+                    cde_block_copy_lds_and_global.Run(
+                        c_ds_desc_refs,
+                        c_ds_buf_refs,
+                        tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
+                        tie(e_grid_buf));
+
+                    if constexpr(access_id < num_access - 1)
+                    {
+                        constexpr auto cde_lds_and_global_step =
+                            sfc_cde_block.GetForwardStep(access_id);
+
+                        // move on Ds
+                        static_for<0, NumDTensor_, 1>{}([&](auto i) {
+                            cde_block_copy_lds_and_global.MoveSrcSliceWindow(
+                                c_ds_desc_refs, i + I1, cde_lds_and_global_step);
+                        });
+
+                        // move on E
+                        cde_block_copy_lds_and_global.MoveDstSliceWindow(
+                            tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
+                            I0,
+                            cde_lds_and_global_step);
+                    }
+                });
+            }
+            else
+            {
+                // blockwise copy C/D/E between LDS and global
+                auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7<
+                    ThisThreadBlock,
+                    decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType_{})),
+                    Tuple<EDataType>,
+                    decltype(c_ds_desc_refs),
+                    decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
+                    CDEElementwiseOperation_,
+                    Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make
+                                                                                // Sequence support
+                                                                                // arbitray type
+                    Sequence<1,
+                             CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
+                             1,
+                             CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
+                    CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
+                    Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
+                    Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
+                    3,                    // index_t VectorDim,
+                    CDEShuffleBlockTransferScalarPerVector_NPerBlock,
+                    sequence_merge_t<
+                        Sequence<true>,
+                        uniform_sequence_gen_t<NumDTensor_, false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
+                    Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
+                    {c_ds_desc_refs,
+                     idx_c_ds_block_begin,
+                     tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
+                     make_tuple(make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0)),
+                     cde_element_op};
+
+                static_for<0, num_access, 1>{}([&](auto access_id) {
+                    // make sure it's safe to write to LDS
+                    block_sync_lds();
+
+                    // each thread write its data from VGPR to LDS
+                    c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+                                                  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
+                                                  c_thread_buf,
+                                                  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
+                                                  c_shuffle_block_buf);
+
+                    // make sure it's safe to read from LDS
+                    block_sync_lds();
+
+                    // each block copy its data from LDS to global
+                    cde_block_copy_lds_and_global.Run(
+                        c_ds_desc_refs,
+                        c_ds_buf_refs,
+                        tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
+                        tie(e_grid_buf));
+
+                    if constexpr(access_id < num_access - 1)
+                    {
+                        constexpr auto cde_lds_and_global_step =
+                            sfc_cde_block.GetForwardStep(access_id);
+
+                        // move on Ds
+                        static_for<0, NumDTensor_, 1>{}([&](auto i) {
+                            cde_block_copy_lds_and_global.MoveSrcSliceWindow(
+                                c_ds_desc_refs, i + I1, cde_lds_and_global_step);
+                        });
+
+                        // move on E
+                        cde_block_copy_lds_and_global.MoveDstSliceWindow(
+                            tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
+                            I0,
+                            cde_lds_and_global_step);
+                    }
+                });
+            }
+
+            __syncthreads();
+
+            if(threadIdx.x == 0)
+            {
+                index_t k_id_finished_t = atomicAdd(barrier_count_finished, 1);
+
+                if(k_id_finished_t == KBatch - 1)
+                {
+                    *barrier_count_start    = 0;
+                    *barrier_count_finished = 0;
+                }
+            }
         }
     }
 }
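Taken together, the added logic reads as a per-tile split-k completion protocol: the first block to claim a tile (its atomicAdd on the start counter returns 0) writes E with a plain Set, every later block spins until the finished counter becomes nonzero and then applies EGlobalMemoryDataOperation, and the block that brings the finished count up to KBatch resets both counters for the next launch. A simplified host-side model of that ordering with std::atomic and threads, one thread per k-slice of a single tile (an illustration only, not the GPU code; the kernel itself uses a relaxed load plus GPU-specific ordering, while the model uses an acquire load to stay data-race-free):

#include <atomic>
#include <cstdio>
#include <thread>
#include <vector>

int main()
{
    constexpr unsigned KBatch = 4;
    std::atomic<unsigned> start{0}, finished{0};
    std::atomic<float> e{123.0f}; // the output tile, holding stale data before the "launch"

    auto k_slice_block = [&](float partial) {
        const unsigned k_id = start.fetch_add(1);
        if(k_id == 0)
        {
            e.store(partial); // first arrival overwrites whatever was there ("Set")
        }
        else
        {
            // later arrivals wait until the first block has published its result
            while(finished.load(std::memory_order_acquire) == 0) {}
            float cur = e.load();
            while(!e.compare_exchange_weak(cur, cur + partial)) {} // accumulate
        }

        const unsigned done = finished.fetch_add(1);
        if(done == KBatch - 1) // the last k-slice resets the counters for reuse
        {
            start.store(0);
            finished.store(0);
        }
    };

    std::vector<std::thread> blocks;
    for(unsigned k = 0; k < KBatch; ++k)
        blocks.emplace_back(k_slice_block, 1.0f); // each k-slice contributes 1.0
    for(auto& t : blocks)
        t.join();

    std::printf("tile value = %.1f (expected %.1f)\n", e.load(), static_cast<float>(KBatch));
    return 0;
}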
@@ -910,6 +1023,8 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
                                DsGridPointer p_ds_grid,
                                void* __restrict__ p_e_grid_,
                                void* __restrict__ p_shared,
+                               uint32_t* barrier_count_start,
+                               uint32_t* barrier_count_finished,
                                const AElementwiseOperation& a_element_op,
                                const BElementwiseOperation& b_element_op,
                                const CDEElementwiseOperation& cde_element_op,
@@ -977,6 +1092,9 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
                 p_ds_grid,
                 p_e_grid,
                 p_shared,
+                barrier_count_start,
+                barrier_count_finished,
+                KBatch,
                 a_element_op,
                 b_element_op,
                 cde_element_op,
@@ -994,6 +1112,9 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
                 p_ds_grid,
                 p_e_grid,
                 p_shared,
+                barrier_count_start,
+                barrier_count_finished,
+                KBatch,
                 a_element_op,
                 b_element_op,
                 ck::tensor_operation::element_wise::PassThrough{},