Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
f58137f2
Commit
f58137f2
authored
May 29, 2023
by
carlushuang
Browse files
fix several bug
parent
8eaed8b3
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
123 additions
and
50 deletions
+123
-50
example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp
example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp
+6
-3
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1r2.hpp
...n/gpu/block/thread_group_tensor_slice_transfer_v6r1r2.hpp
+24
-4
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
+6
-2
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp
...ensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp
+43
-34
include/ck/utility/amd_buffer_addressing.hpp
include/ck/utility/amd_buffer_addressing.hpp
+2
-2
library/include/ck/library/utility/device_memory.hpp
library/include/ck/library/utility/device_memory.hpp
+2
-1
library/src/utility/device_memory.cpp
library/src/utility/device_memory.cpp
+33
-4
profiler/include/profiler/profile_gemm_streamk_impl.hpp
profiler/include/profiler/profile_gemm_streamk_impl.hpp
+7
-0
No files found.
example/35_splitK_gemm/splitK_gemm_xdl_fp16.cpp
View file @
f58137f2
...
@@ -35,14 +35,15 @@ using AccDataType = F32;
...
@@ -35,14 +35,15 @@ using AccDataType = F32;
using
CDataType
=
F16
;
using
CDataType
=
F16
;
using
ALayout
=
Row
;
using
ALayout
=
Row
;
using
BLayout
=
Col
;
using
BLayout
=
Row
;
using
CLayout
=
Row
;
using
CLayout
=
Row
;
using
AElementOp
=
PassThrough
;
using
AElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
using
CElementOp
=
PassThrough
;
using
CElementOp
=
PassThrough
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
// static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static
constexpr
auto
GemmMNPadding
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
;
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmXdlSplitKCShuffle
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmXdlSplitKCShuffle
// clang-format off
// clang-format off
...
@@ -50,7 +51,9 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShu
...
@@ -50,7 +51,9 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShu
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
ALayout
,
BLayout
,
CLayout
,
AElementOp
,
BElementOp
,
CElementOp
,
GemmDefault
,
256
,
256
,
128
,
4
,
8
,
32
,
32
,
4
,
2
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
64
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
3
,
8
,
8
,
true
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
>
;
// < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, true, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 3, 8, 8, true, 1, 1, S<1, 32, 1, 8>, 8>;
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
ALayout
,
BLayout
,
CLayout
,
AElementOp
,
BElementOp
,
CElementOp
,
GemmMNPadding
,
128
,
32
,
64
,
4
,
8
,
32
,
32
,
1
,
1
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
2
,
1
,
3
>
,
S
<
0
,
2
,
1
,
3
>
,
3
,
8
,
8
,
true
,
S
<
1
,
4
,
32
,
1
>
,
S
<
0
,
1
,
3
,
2
>
,
S
<
0
,
1
,
3
,
2
>
,
2
,
2
,
8
,
true
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
>
;
// clang-format on
// clang-format on
#include "run_splitK_gemm_example.inc"
#include "run_splitK_gemm_example.inc"
...
...
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1r2.hpp
View file @
f58137f2
...
@@ -111,14 +111,34 @@ struct ThreadGroupTensorSliceTransfer_v6r1r2
...
@@ -111,14 +111,34 @@ struct ThreadGroupTensorSliceTransfer_v6r1r2
}
}
}
}
__device__
void
SetSrcSliceOrigin
(
const
SrcDesc
&
src_desc
,
const
Index
&
src_slice_origin
_idx
)
__device__
void
SetSrcSliceOrigin
(
const
SrcDesc
&
src_desc
,
const
Index
&
src_
block_
slice_origin
)
{
{
threadwise_transfer_
.
SetSrcSliceOrigin
(
src_desc
,
src_slice_origin_idx
);
if
(
ThreadGroup
::
GetNumOfThread
()
==
thread_cluster_desc_
.
GetElementSize
()
or
ThreadGroup
::
GetThreadId
()
<
thread_cluster_desc_
.
GetElementSize
())
{
const
auto
thread_cluster_idx
=
thread_cluster_desc_
.
CalculateBottomIndex
(
make_multi_index
(
ThreadGroup
::
GetThreadId
()));
const
auto
thread_data_idx_begin
=
thread_cluster_idx
*
thread_slice_lengths
;
threadwise_transfer_
.
SetSrcSliceOrigin
(
src_desc
,
src_block_slice_origin
+
thread_data_idx_begin
);
}
}
}
__device__
void
SetDstSliceOrigin
(
const
DstDesc
&
dst_desc
,
const
Index
&
dst_slice_origin
_idx
)
__device__
void
SetDstSliceOrigin
(
const
DstDesc
&
dst_desc
,
const
Index
&
dst_
block_
slice_origin
)
{
{
threadwise_transfer_
.
SetDstSliceOrigin
(
dst_desc
,
dst_slice_origin_idx
);
if
(
ThreadGroup
::
GetNumOfThread
()
==
thread_cluster_desc_
.
GetElementSize
()
or
ThreadGroup
::
GetThreadId
()
<
thread_cluster_desc_
.
GetElementSize
())
{
const
auto
thread_cluster_idx
=
thread_cluster_desc_
.
CalculateBottomIndex
(
make_multi_index
(
ThreadGroup
::
GetThreadId
()));
const
auto
thread_data_idx_begin
=
thread_cluster_idx
*
thread_slice_lengths
;
threadwise_transfer_
.
SetDstSliceOrigin
(
dst_desc
,
dst_block_slice_origin
+
thread_data_idx_begin
);
}
}
}
private:
private:
...
...
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
View file @
f58137f2
...
@@ -833,7 +833,8 @@ struct BlockToCTileMap_GemmStreamK
...
@@ -833,7 +833,8 @@ struct BlockToCTileMap_GemmStreamK
printf
(
"cu:%d, occupancy:%d, grids:%d, num_tiles:%d, dp_tiles:%d, sk_num_big_blocks:%d, "
printf
(
"cu:%d, occupancy:%d, grids:%d, num_tiles:%d, dp_tiles:%d, sk_num_big_blocks:%d, "
"sk_num_blocks:%d, "
"sk_num_blocks:%d, "
"sk_total_iters:%d, dp_start_block_idx:%d, dp_iters_per_block:%d, dp_num_blocks:%d, "
"sk_total_iters:%d, dp_start_block_idx:%d, dp_iters_per_block:%d, dp_num_blocks:%d, "
"k_iters_per_tile:%d, k_iters_per_big_block:%d
\n
"
,
"k_iters_per_tile:%d, k_iters_per_big_block:%d, reduction_start_block_idx:%u, "
"sk_tiles:%u, workspace(acc float):%u
\n
"
,
num_cu
,
num_cu
,
occupancy
,
occupancy
,
get_grid_dims
().
x
,
get_grid_dims
().
x
,
...
@@ -846,7 +847,10 @@ struct BlockToCTileMap_GemmStreamK
...
@@ -846,7 +847,10 @@ struct BlockToCTileMap_GemmStreamK
dp_iters_per_block
,
dp_iters_per_block
,
dp_num_blocks
,
dp_num_blocks
,
k_iters_per_tile
.
get
(),
k_iters_per_tile
.
get
(),
k_iters_per_big_block
);
k_iters_per_big_block
,
reduction_start_block_idx
,
get_sk_tiles
(),
get_workspace_size
(
sizeof
(
float
)));
}
}
__host__
__device__
uint32_t
get_sk_total_iters
()
const
__host__
__device__
uint32_t
get_sk_total_iters
()
const
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_streamk.hpp
View file @
f58137f2
...
@@ -23,9 +23,8 @@ namespace ck {
...
@@ -23,9 +23,8 @@ namespace ck {
template
<
typename
GridwiseGemm
>
template
<
typename
GridwiseGemm
>
__global__
void
__global__
void
#if CK_USE_LAUNCH_BOUNDS
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
1
)
#endif
#endif
// kernel_gemm_xdlops_streamk(typename GridwiseGemm::Argument karg)
kernel_gemm_xdlops_streamk
(
const
typename
GridwiseGemm
::
FloatAB
*
p_a_grid
,
kernel_gemm_xdlops_streamk
(
const
typename
GridwiseGemm
::
FloatAB
*
p_a_grid
,
const
typename
GridwiseGemm
::
FloatAB
*
p_b_grid
,
const
typename
GridwiseGemm
::
FloatAB
*
p_b_grid
,
typename
GridwiseGemm
::
FloatC
*
p_c_grid
,
typename
GridwiseGemm
::
FloatC
*
p_c_grid
,
...
@@ -43,7 +42,6 @@ __global__ void
...
@@ -43,7 +42,6 @@ __global__ void
__shared__
uint8_t
p_shared
[
shared_size
];
__shared__
uint8_t
p_shared
[
shared_size
];
// GridwiseGemm::Run(karg, static_cast<void*>(p_shared));
GridwiseGemm
::
Run
(
p_a_grid
,
GridwiseGemm
::
Run
(
p_a_grid
,
p_b_grid
,
p_b_grid
,
p_c_grid
,
p_c_grid
,
...
@@ -549,6 +547,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
...
@@ -549,6 +547,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
{
{
// descriptors
// descriptors
constexpr
auto
cluster_length_reduce
=
GetClusterLengthReduction
();
constexpr
auto
cluster_length_reduce
=
GetClusterLengthReduction
();
constexpr
auto
reduce_desc
=
make_cluster_descriptor
(
cluster_length_reduce
);
const
auto
reduce_thread_cluster_idx
=
reduce_desc
.
CalculateBottomIndex
(
make_multi_index
(
get_thread_local_1d_id
()));
const
auto
thread_m_cluster_id
=
reduce_thread_cluster_idx
[
I0
];
const
auto
thread_n_cluster_id
=
reduce_thread_cluster_idx
[
I1
];
constexpr
auto
MReduceIters
=
constexpr
auto
MReduceIters
=
math
::
integer_divide_ceil
(
Number
<
MPerBlock
>
{},
cluster_length_reduce
.
At
(
I0
));
math
::
integer_divide_ceil
(
Number
<
MPerBlock
>
{},
cluster_length_reduce
.
At
(
I0
));
...
@@ -560,13 +563,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
...
@@ -560,13 +563,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
// make_tuple(Number<CBlockTransferScalarPerVector_NWaveNPerXDL>{}));
// make_tuple(Number<CBlockTransferScalarPerVector_NWaveNPerXDL>{}));
constexpr
auto
acc_thread_buf_load_desc
=
make_naive_tensor_descriptor_packed
(
constexpr
auto
acc_thread_buf_load_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
1
>
{}
,
Number
<
CBlockTransferScalarPerVector_NWaveNPerXDL
>
{}));
make_tuple
(
I1
,
Number
<
CBlockTransferScalarPerVector_NWaveNPerXDL
>
{}));
constexpr
auto
acc_thread_buf_store_desc
=
make_naive_tensor_descriptor_packed
(
constexpr
auto
acc_thread_buf_store_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
1
>
{},
make_tuple
(
I1
,
I1
,
I1
,
Number
<
CBlockTransferScalarPerVector_NWaveNPerXDL
>
{}));
Number
<
1
>
{},
Number
<
1
>
{},
Number
<
CBlockTransferScalarPerVector_NWaveNPerXDL
>
{}));
constexpr
auto
c_partial_acc_block_m_n
=
GetPartialAccBlockDescriptor
();
constexpr
auto
c_partial_acc_block_m_n
=
GetPartialAccBlockDescriptor
();
...
@@ -627,7 +627,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
...
@@ -627,7 +627,10 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
CBlockTransferScalarPerVector_NWaveNPerXDL
,
// SrcScalarPerVector,
CBlockTransferScalarPerVector_NWaveNPerXDL
,
// SrcScalarPerVector,
1
,
// SrcScalarStrideInVector,
1
,
// SrcScalarStrideInVector,
false
// SrcResetCoordinateAfterRun,
false
// SrcResetCoordinateAfterRun,
>
{
c_partial_acc_block_m_n
,
make_multi_index
(
0
,
0
)};
>
{
c_partial_acc_block_m_n
,
make_multi_index
(
thread_m_cluster_id
,
thread_n_cluster_id
*
CBlockTransferScalarPerVector_NWaveNPerXDL
)};
auto
acc_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
auto
acc_store
=
ThreadwiseTensorSliceTransfer_v1r3
<
FloatAcc
,
// SrcData,
FloatAcc
,
// SrcData,
...
@@ -635,18 +638,19 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
...
@@ -635,18 +638,19 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
decltype
(
acc_thread_buf_store_desc
),
// SrcDesc,
decltype
(
acc_thread_buf_store_desc
),
// SrcDesc,
decltype
(
c_grid_desc_mblock_mperblock_nblock_nperblock
),
// DstDesc,
decltype
(
c_grid_desc_mblock_mperblock_nblock_nperblock
),
// DstDesc,
CElementwiseOperation
,
// ElementwiseOperation,
CElementwiseOperation
,
// ElementwiseOperation,
Sequence
<
0
,
0
,
0
,
CBlockTransferScalarPerVector_NWaveNPerXDL
>
,
// SliceLengths,
Sequence
<
1
,
1
,
1
,
CBlockTransferScalarPerVector_NWaveNPerXDL
>
,
// SliceLengths,
Sequence
<
0
,
1
,
2
,
3
>
,
// DimAccessOrder,
Sequence
<
0
,
1
,
2
,
3
>
,
// DimAccessOrder,
2
,
// DstVectorDim,
3
,
// DstVectorDim,
CBlockTransferScalarPerVector_NWaveNPerXDL
,
// DstScalarPerVector,
CBlockTransferScalarPerVector_NWaveNPerXDL
,
// DstScalarPerVector,
InMemoryDataOperationEnum
::
Set
,
// InMemoryDataOperationEnum DstInMemOp,
InMemoryDataOperationEnum
::
Set
,
// InMemoryDataOperationEnum DstInMemOp,
3
,
// DstScalarStrideInVector,
1
,
// DstScalarStrideInVector,
false
// DstResetCoordinateAfterRun,
false
// DstResetCoordinateAfterRun,
>
{
c_grid_desc_mblock_mperblock_nblock_nperblock
,
>
{
c_grid_desc_mblock_mperblock_nblock_nperblock
,
make_multi_index
(
__builtin_amdgcn_readfirstlane
(
spatial_idx
[
I0
]),
make_multi_index
(
__builtin_amdgcn_readfirstlane
(
spatial_idx
[
I0
]),
0
,
thread_m_cluster_id
,
__builtin_amdgcn_readfirstlane
(
spatial_idx
[
I1
]),
__builtin_amdgcn_readfirstlane
(
spatial_idx
[
I1
]),
0
),
thread_n_cluster_id
*
CBlockTransferScalarPerVector_NWaveNPerXDL
),
CElementwiseOperation
{}};
CElementwiseOperation
{}};
// block synchronization
// block synchronization
...
@@ -659,8 +663,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
...
@@ -659,8 +663,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
static_for
<
0
,
NReduceIters
,
1
>
{}([
&
](
auto
i_n_reduce
)
{
static_for
<
0
,
NReduceIters
,
1
>
{}([
&
](
auto
i_n_reduce
)
{
for
(
auto
i
=
tile_acc_offset_start
;
i
<
tile_acc_offset_end
;
i
++
)
for
(
auto
i
=
tile_acc_offset_start
;
i
<
tile_acc_offset_end
;
i
++
)
{
{
auto
c_partial_acc_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
auto
c_partial_acc_buf
=
static_cast
<
FloatAcc
*>
(
p_workspace
)
+
i
,
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
,
amd_buffer_coherence_bits
::
glc
>
(
reinterpret_cast
<
FloatAcc
*>
(
p_workspace
)
+
i
*
c_partial_acc_block_m_n
.
GetElementSpaceSize
(),
c_partial_acc_block_m_n
.
GetElementSpaceSize
());
c_partial_acc_block_m_n
.
GetElementSpaceSize
());
acc_load
.
Run
(
c_partial_acc_block_m_n
,
acc_load
.
Run
(
c_partial_acc_block_m_n
,
...
@@ -850,12 +857,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
...
@@ -850,12 +857,14 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
GetCBlockDescriptor_MShuffleRepeat_MPerShuffle_NShuffleRepeat_NPerShuffle
();
GetCBlockDescriptor_MShuffleRepeat_MPerShuffle_NShuffleRepeat_NPerShuffle
();
auto
c_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
auto
c_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
static
_cast
<
FloatCShuffle
*>
(
p_shared_block
),
reinterpret
_cast
<
FloatCShuffle
*>
(
p_shared_block
),
c_block_desc_mblock_mpershuffle_nblock_npershuffle
.
GetElementSpaceSize
());
c_block_desc_mblock_mpershuffle_nblock_npershuffle
.
GetElementSpaceSize
());
auto
c_partial_acc_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
auto
c_partial_acc_buf
=
static_cast
<
FloatAcc
*>
(
p_workspace
)
+
block_acc_offset
,
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
,
amd_buffer_coherence_bits
::
glc
>
(
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle
.
GetElementSpaceSize
());
reinterpret_cast
<
FloatAcc
*>
(
p_workspace
)
+
block_acc_offset
,
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle
.
GetElementSpaceSize
());
constexpr
auto
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
transform_tensor_descriptor
(
constexpr
auto
c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
transform_tensor_descriptor
(
c_block_desc_mblock_mpershuffle_nblock_npershuffle
,
c_block_desc_mblock_mpershuffle_nblock_npershuffle
,
...
@@ -984,7 +993,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
...
@@ -984,7 +993,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_streamk
3
,
// index_t VectorDim,
3
,
// index_t VectorDim,
CBlockTransferScalarPerVector_NWaveNPerXDL
,
// index_t ScalarPerVector,
CBlockTransferScalarPerVector_NWaveNPerXDL
,
// index_t ScalarPerVector,
true
,
// bool ThreadTransferSrcResetCoordinateAfterRun,
true
,
// bool ThreadTransferSrcResetCoordinateAfterRun,
fals
e
>
// bool ThreadTransferDstResetCoordinateAfterRun
tru
e
>
// bool ThreadTransferDstResetCoordinateAfterRun
{
c_block_desc_mblock_mpershuffle_nblock_npershuffle
,
{
c_block_desc_mblock_mpershuffle_nblock_npershuffle
,
make_multi_index
(
0
,
0
,
0
,
0
),
make_multi_index
(
0
,
0
,
0
,
0
),
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle
,
c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle
,
...
...
include/ck/utility/amd_buffer_addressing.hpp
View file @
f58137f2
...
@@ -685,12 +685,12 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
...
@@ -685,12 +685,12 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
dst_wave_buffer_resource
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
dst_wave_addr_offset
,
0
);
static_cast
<
index_t
>
(
coherence
)
);
llvm_amdgcn_raw_buffer_store_fp32x4
(
tmp
.
AsType
<
float4_t
>
()[
Number
<
1
>
{}],
llvm_amdgcn_raw_buffer_store_fp32x4
(
tmp
.
AsType
<
float4_t
>
()[
Number
<
1
>
{}],
dst_wave_buffer_resource
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_thread_addr_offset
,
dst_wave_addr_offset
+
4
*
sizeof
(
float
),
dst_wave_addr_offset
+
4
*
sizeof
(
float
),
0
);
static_cast
<
index_t
>
(
coherence
)
);
}
}
}
}
else
if
constexpr
(
is_same
<
T
,
half_t
>::
value
)
else
if
constexpr
(
is_same
<
T
,
half_t
>::
value
)
...
...
library/include/ck/library/utility/device_memory.hpp
View file @
f58137f2
...
@@ -20,8 +20,9 @@ __global__ void set_buffer_value(T* p, T x, uint64_t buffer_element_size)
...
@@ -20,8 +20,9 @@ __global__ void set_buffer_value(T* p, T x, uint64_t buffer_element_size)
*/
*/
struct
DeviceMem
struct
DeviceMem
{
{
DeviceMem
()
=
delete
;
DeviceMem
()
:
mpDeviceBuf
(
nullptr
),
mMemSize
(
0
)
{}
DeviceMem
(
std
::
size_t
mem_size
);
DeviceMem
(
std
::
size_t
mem_size
);
void
Realloc
(
std
::
size_t
mem_size
);
void
*
GetDeviceBuffer
()
const
;
void
*
GetDeviceBuffer
()
const
;
std
::
size_t
GetBufferSize
()
const
;
std
::
size_t
GetBufferSize
()
const
;
void
ToDevice
(
const
void
*
p
)
const
;
void
ToDevice
(
const
void
*
p
)
const
;
...
...
library/src/utility/device_memory.cpp
View file @
f58137f2
...
@@ -10,20 +10,49 @@ DeviceMem::DeviceMem(std::size_t mem_size) : mMemSize(mem_size)
...
@@ -10,20 +10,49 @@ DeviceMem::DeviceMem(std::size_t mem_size) : mMemSize(mem_size)
hip_check_error
(
hipMalloc
(
static_cast
<
void
**>
(
&
mpDeviceBuf
),
mMemSize
));
hip_check_error
(
hipMalloc
(
static_cast
<
void
**>
(
&
mpDeviceBuf
),
mMemSize
));
}
}
void
DeviceMem
::
Realloc
(
std
::
size_t
mem_size
)
{
if
(
mpDeviceBuf
)
{
hip_check_error
(
hipFree
(
mpDeviceBuf
));
}
mMemSize
=
mem_size
;
hip_check_error
(
hipMalloc
(
static_cast
<
void
**>
(
&
mpDeviceBuf
),
mMemSize
));
}
void
*
DeviceMem
::
GetDeviceBuffer
()
const
{
return
mpDeviceBuf
;
}
void
*
DeviceMem
::
GetDeviceBuffer
()
const
{
return
mpDeviceBuf
;
}
std
::
size_t
DeviceMem
::
GetBufferSize
()
const
{
return
mMemSize
;
}
std
::
size_t
DeviceMem
::
GetBufferSize
()
const
{
return
mMemSize
;
}
void
DeviceMem
::
ToDevice
(
const
void
*
p
)
const
void
DeviceMem
::
ToDevice
(
const
void
*
p
)
const
{
{
hip_check_error
(
hipMemcpy
(
mpDeviceBuf
,
const_cast
<
void
*>
(
p
),
mMemSize
,
hipMemcpyHostToDevice
));
if
(
mpDeviceBuf
)
{
hip_check_error
(
hipMemcpy
(
mpDeviceBuf
,
const_cast
<
void
*>
(
p
),
mMemSize
,
hipMemcpyHostToDevice
));
}
}
}
void
DeviceMem
::
FromDevice
(
void
*
p
)
const
void
DeviceMem
::
FromDevice
(
void
*
p
)
const
{
{
if
(
mpDeviceBuf
)
{
hip_check_error
(
hipMemcpy
(
p
,
mpDeviceBuf
,
mMemSize
,
hipMemcpyDeviceToHost
));
hip_check_error
(
hipMemcpy
(
p
,
mpDeviceBuf
,
mMemSize
,
hipMemcpyDeviceToHost
));
}
}
}
void
DeviceMem
::
SetZero
()
const
{
hip_check_error
(
hipMemset
(
mpDeviceBuf
,
0
,
mMemSize
));
}
void
DeviceMem
::
SetZero
()
const
{
if
(
mpDeviceBuf
)
{
hip_check_error
(
hipMemset
(
mpDeviceBuf
,
0
,
mMemSize
));
}
}
DeviceMem
::~
DeviceMem
()
{
hip_check_error
(
hipFree
(
mpDeviceBuf
));
}
DeviceMem
::~
DeviceMem
()
{
if
(
mpDeviceBuf
)
{
hip_check_error
(
hipFree
(
mpDeviceBuf
));
}
}
profiler/include/profiler/profile_gemm_streamk_impl.hpp
View file @
f58137f2
...
@@ -155,6 +155,13 @@ bool profile_gemm_streamk_impl(int do_verification,
...
@@ -155,6 +155,13 @@ bool profile_gemm_streamk_impl(int do_verification,
b_element_op
,
b_element_op
,
c_element_op
,
c_element_op
,
NumSKBlocks
);
NumSKBlocks
);
DeviceMem
workspace
;
std
::
size_t
workspace_size
=
op_ptr
->
GetWorkSpaceSize
(
argument_ptr
);
if
(
workspace_size
!=
0
)
{
workspace
.
Realloc
(
workspace_size
);
op_ptr
->
SetWorkSpacePointer
(
argument_ptr
,
workspace
.
GetDeviceBuffer
());
}
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment