gaoqiong / composable_kernel, commit eceea10a

clean up

Authored Aug 03, 2022 by Anthony Chang
Parent: 4ee34028

Showing 10 changed files with 95 additions and 268 deletions (+95 -268)
example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_fp16.cpp                          +6   -17
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp                      +0   -16
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp    +1   -15
include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp     +6   -23
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp  +78  -145
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp          +2   -34
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp                                 +0   -11
include/ck/utility/static_buffer.hpp                                                 +1   -0
include/ck/utility/tuple_helper.hpp                                                  +0   -6
library/include/ck/library/utility/check_err.hpp                                     +1   -1
example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_fp16.cpp

@@ -216,21 +216,19 @@ int main(int argc, char* argv[])
     {
     case 0: break;
     case 1:
-        a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-0.5, 0.5});
-        b0_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
-        b1_n_o.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
+        a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
+        b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
+        b1_n_o.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-5, 5});
         break;
     case 2:
-        a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
-        b0_k_n.GenerateTensorValue(GeneratorTensor_1<B0DataType>{1});
-        b1_n_o.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-5, 5});
+        a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
+        b0_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
+        b1_n_o.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
         break;
     default:
         a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
-        // b0_k_n.GenerateTensorValue(GeneratorTensor_1<B0DataType>{1});
         b0_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
         b1_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
-        // b1_n_o.GenerateTensorValue(GeneratorTensor_Sequential<0>{});
     }

     DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());

@@ -308,15 +306,6 @@ int main(int argc, char* argv[])
     ref_gemm1_invoker.Run(ref_gemm1_argument);

-    // LogRangeAsType<float>(std::cout << "a_m_k: ", a_m_k.mData, ",") << std::endl;
-    // LogRangeAsType<float>(std::cout << "b0_k_n : ", b0_k_n.mData, ",") << std::endl;
-    // LogRangeAsType<float>(std::cout << "b1_n_o : ", b1_n_o.mData, ",") << std::endl;
-    // LogRangeAsType<float>(std::cout << "c_m_o_device_result : ", c_m_o_device_result.mData, ",") << std::endl;
-    std::cout << "b0_k_n(0, 0) = " << (float)b0_k_n(0, 0)
-              << ", b0_k_n(1, 0) = " << (float)b0_k_n(1, 0)
-              << ", b0_k_n(0, 1) = " << (float)b0_k_n(0, 1)
-              << ", b0_k_n(1, 1) = " << (float)b0_k_n(1, 1) << std::endl;
-
     return ck::utils::check_err(c_m_o_device_result.mData, c_m_o_host_result.mData) ? 0 : 1;
 }
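The init_method cases above now use small random integer values for case 1 and uniform real values for case 2. As a rough standalone illustration (assuming, as the usage suggests, that GeneratorTensor_2<T>{lo, hi} yields random integers in [lo, hi) and GeneratorTensor_3<T>{lo, hi} yields uniform reals in [lo, hi); the functors below are stand-ins, not CK's implementations):

    #include <cstdio>
    #include <cstdlib>
    #include <vector>

    // stand-in for GeneratorTensor_2<T>{lo, hi}: random integers in [lo, hi)
    struct IntUniformGen
    {
        int lo, hi;
        float operator()() const { return float(lo + std::rand() % (hi - lo)); }
    };

    // stand-in for GeneratorTensor_3<T>{lo, hi}: uniform reals in [lo, hi)
    struct RealUniformGen
    {
        float lo, hi;
        float operator()() const { return lo + (hi - lo) * float(std::rand()) / float(RAND_MAX); }
    };

    int main()
    {
        std::vector<float> a_m_k(16), b0_k_n(16), b1_n_o(16);
        IntUniformGen gen2{-5, 5};       // init_method == 1 style
        RealUniformGen gen3{0.0f, 1.0f}; // init_method == 2 style
        for(auto& v : a_m_k)  v = gen2();
        for(auto& v : b0_k_n) v = gen2();
        for(auto& v : b1_n_o) v = gen3();
        std::printf("a_m_k[0] = %f, b1_n_o[0] = %f\n", a_m_k[0], b1_n_o[0]);
        return 0;
    }

Integer-valued inputs keep fp16 accumulation exact, which makes device-versus-host comparison less sensitive to rounding.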
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp

@@ -158,22 +158,6 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
                                   Tuple4 b_origin = CalculateBThreadOriginDataIndex())
         : a_thread_copy_(a_origin), b_thread_copy_(b_origin)
     {
-#if 0
-        if(!TransposeC && hipThreadIdx_x % 32 < 8)
-        {
-            printf("bid %zd tid %zd, a_mma = %d, %d, %d, %d, b_mma = %d, %d, %d, %d\n",
-                   hipBlockIdx_x,
-                   hipThreadIdx_x,
-                   a_origin[Number<0>{}],
-                   a_origin[Number<1>{}],
-                   a_origin[Number<2>{}],
-                   a_origin[Number<3>{}],
-                   b_origin[Number<0>{}],
-                   b_origin[Number<1>{}],
-                   b_origin[Number<2>{}],
-                   b_origin[Number<3>{}]);
-        }
-#endif
         static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(),
                       "wrong! Desc should be known at compile-time");
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp

@@ -81,21 +81,7 @@ struct ThreadGroupTensorSliceTransfer_v4r1
                 make_multi_index(ThreadGroup::GetThreadId()));

             const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;

-#if 0
-            if (std::is_same<Sequence<16,64,2>, BlockSliceLengths>::value)
-            {
-                auto s = src_block_slice_origin + thread_data_idx_begin;
-                auto d = dst_block_slice_origin + thread_data_idx_begin;
-                printf("bid %zd tid %zd, src origin %d %d %d, dst origin %d %d %d\n",
-                    hipBlockIdx_x, hipThreadIdx_x,
-                    s[Number<0>{}],
-                    s[Number<1>{}],
-                    s[Number<2>{}],
-                    d[Number<0>{}],
-                    d[Number<1>{}],
-                    d[Number<2>{}]);
-            }
-#endif
             threadwise_transfer_.SetSrcSliceOrigin(src_desc,
                                                    src_block_slice_origin + thread_data_idx_begin);
             threadwise_transfer_.SetDstSliceOrigin(dst_desc,
include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp

@@ -162,7 +162,7 @@ template <typename ALayout,
           index_t CShuffleNXdlPerWavePerShuffle,
           typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
-          LoopScheduler LoopSched = make_default_loop_scheduler()>
+          LoopScheduler LoopSched = LoopScheduler::Default>
 struct DeviceGemmGemm_Xdl_CShuffle : public BaseOperator // TODO ANT: inherit from DeviceGemmGemm subtype
 {
     using DeviceOp = DeviceGemmGemm_Xdl_CShuffle;

@@ -553,7 +553,7 @@ struct DeviceGemmGemm_Xdl_CShuffle : public BaseOperator // TODO ANT: inherit fr
         ABlockTransferSrcVectorDim,
         ABlockTransferSrcScalarPerVector,
         ABlockTransferDstScalarPerVector_AK1,
-        false,
+        true,
         ABlockLdsExtraM,
         BBlockTransferThreadClusterLengths_BK0_N_BK1,
         BBlockTransferThreadClusterArrangeOrder,

@@ -561,7 +561,7 @@ struct DeviceGemmGemm_Xdl_CShuffle : public BaseOperator // TODO ANT: inherit fr
        BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_BK1,
-       false,
+       true,
        BBlockLdsExtraN,
        B1BlockTransferThreadClusterLengths_BK0_N_BK1,
        B1BlockTransferThreadClusterArrangeOrder,

@@ -655,24 +655,6 @@ struct DeviceGemmGemm_Xdl_CShuffle : public BaseOperator // TODO ANT: inherit fr
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
-#if 0
-            {
-                std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
-                          << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
-                          << arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
-                          << arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;
-                std::cout << "arg.b_grid_desc_bk0_n_bk1_{"
-                          << arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", "
-                          << arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
-                          << arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;
-                std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
-                          << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
-            }
-#endif
-            // TODO ANT: block id to ctilemap should infer acc0tile map
            if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
                                            arg.b_grid_desc_bk0_n_bk1_,
                                            arg.b1_grid_desc_bk0_n_bk1_,

@@ -685,7 +667,7 @@ struct DeviceGemmGemm_Xdl_CShuffle : public BaseOperator // TODO ANT: inherit fr
            const index_t grid_size =
                arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.batch_count_;

-            // TODO ANT: K for gemm1
+            // Gemm0_K
            const auto K =
                arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);

@@ -728,7 +710,8 @@ struct DeviceGemmGemm_Xdl_CShuffle : public BaseOperator // TODO ANT: inherit fr
                                              arg.compute_base_ptr_of_batch_);
            };

-            // TODO ANT: handle tail loops for gemm0 & gemm1
+            // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
+            // to concern Gemm0's loop
            if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
            {
                ave_time = launch_kernel(integral_constant<bool, true>{});
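The Run() path above derives the launch size from the C tile map times the batch count and reads the gemm0 reduction length K back out of the AK0/AK1 descriptor. A host-side sketch of that arithmetic (the tile sizes and the "more than one K block" rule below are illustrative assumptions, not values taken from this header):

    #include <cstdio>

    int main()
    {
        const int M = 1024, N = 1024, batch_count = 4;
        const int MPerBlock = 256, NPerBlock = 128, KPerBlock = 32;

        const int AK0 = 64, AK1 = 8;  // K is carried in split form, K = AK0 * AK1
        const int K   = AK0 * AK1;

        const int tiles = ((M + MPerBlock - 1) / MPerBlock) * ((N + NPerBlock - 1) / NPerBlock);
        const int grid_size = tiles * batch_count; // one C tile map, replicated per batch

        const bool has_main_k_block_loop = (K / KPerBlock) > 1; // assumed rule

        std::printf("grid_size = %d, K = %d, main K loop = %d\n", grid_size, K, has_main_k_block_loop);
        return 0;
    }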
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp

@@ -50,7 +50,7 @@ template <typename FloatAB,
           index_t ABlockTransferSrcVectorDim,
           index_t ABlockTransferSrcScalarPerVector,
           index_t ABlockTransferDstScalarPerVector_AK1,
-          bool AThreadTransferSrcResetCoordinateAfterRun,
+          bool AThreadTransferSrcResetCoordinateAfterRun, // ignored
           index_t ABlockLdsExtraM,
           typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
           typename BBlockTransferThreadClusterArrangeOrder,

@@ -58,7 +58,7 @@ template <typename FloatAB,
           index_t BBlockTransferSrcVectorDim,
           index_t BBlockTransferSrcScalarPerVector,
           index_t BBlockTransferDstScalarPerVector_BK1,
-          bool BThreadTransferSrcResetCoordinateAfterRun,
+          bool BThreadTransferSrcResetCoordinateAfterRun, // ignored
           index_t BBlockLdsExtraN,
           typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
           typename B1BlockTransferThreadClusterArrangeOrder,

@@ -75,6 +75,9 @@ template <typename FloatAB,
           LoopScheduler LoopSched>
 struct GridwiseBatchedGemmGemm_Xdl_CShuffle
 {
+    static_assert(LoopSched == LoopScheduler::Default,
+                  "Non-default loop scheduler is currently not supported");
+
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
     static constexpr auto I2 = Number<2>{};

@@ -91,8 +94,6 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
     static constexpr auto AK1 = Number<AK1Value>{};
     static constexpr auto BK1 = Number<BK1Value>{};

     // Gemm1
-    static constexpr auto AccK1 = Number<4>{}; // TODO ANT: get from mfma_type.mfma_group_size
-    static constexpr auto AccK0 = Number<NPerBlock / AccK1.value>{};
     static constexpr auto B1K0 = Number<Gemm1KPerBlock / B1K1Value>{};
     static constexpr auto B1K1 = Number<B1K1Value>{};

@@ -148,7 +149,6 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
     MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
     {
         constexpr index_t Gemm1NWaves = Gemm1NPerBlock / (Gemm1NXdlPerWave * NPerXdl);
-        // Sequence<Gemm1NXdlPerWave, Gemm1NWaves, NPerXdl>{}.foo(); // <2, 1, 32>
         return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<Gemm1NXdlPerWave, Gemm1NWaves, NPerXdl>(
             BBlockDesc_BK0_N_BK1{});
     }

@@ -169,18 +169,6 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
             make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
     }

-    // template <typename BlockwiseGemm>
-    // __host__ __device__ static constexpr auto
-    // GetAccBlockDescriptor_AK0PerBlock_MPerBlock_AK1(const BlockwiseGemm& blockwise_gemm)
-    // {
-    //     constexpr auto acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
-    //         blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
-    //     return make_naive_tensor_descriptor(
-    //         make_tuple(B1K0, Number<Gemm1NPerBlock>{}, B1BK1),
-    //         make_tuple(Number<Gemm1NPerBlock + B1BlockLdsExtraN>{} * B1K1, B1K1, I1));
-    // }
-
     __host__ __device__ static constexpr auto GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1()
     {
         // B1 matrix in LDS memory, dst of blockwise copy

@@ -266,26 +254,21 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
             return false;
         }

-        if(!(NPerBlock % Gemm1KPerBlock == 0))
-        {
-            return false;
-        }
-
-        // check gridwise gemm pipeline
+        // check gemm0 gridwise gemm pipeline
         const auto num_gemm0_k_loop = K / KPerBlock;
         if(!GridwiseGemmPipe::IsSupported(num_gemm0_k_loop))
         {
             return false;
         }

-        const auto num_gemm1_k_inner_loop = NPerBlock / Gemm1KPerBlock;
-        if(!GridwiseGemmPipe::IsSupported(num_gemm1_k_inner_loop))
+        // check gemm1 gridwise gemm pipeline
+        if(!(NPerBlock % Gemm1KPerBlock == 0))
         {
             return false;
         }

-        const auto num_gemm1_k_outer_loop = N / NPerBlock;
-        if(!GridwiseGemmPipe::IsSupported(num_gemm1_k_outer_loop))
+        const auto num_gemm1_k_inner_loop = NPerBlock / Gemm1KPerBlock;
+        if(!GridwiseGemmPipe::IsSupported(num_gemm1_k_inner_loop))
         {
             return false;
         }
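The reworked CheckValidity() above now gates gemm0's K loop and gemm1's inner loop separately and drops the old N / NPerBlock outer-loop check. A compressed sketch of the new control flow (IsSupported() below is only a placeholder for GridwiseGemmPipe::IsSupported(), whose real rule lives in the pipeline class):

    #include <cstdio>

    static bool IsSupported(int num_loop) { return num_loop >= 1; } // placeholder rule

    static bool CheckValidity(int K, int KPerBlock, int NPerBlock, int Gemm1KPerBlock)
    {
        // check gemm0 gridwise gemm pipeline
        if(!IsSupported(K / KPerBlock))
            return false;

        // check gemm1 gridwise gemm pipeline: NPerBlock is gemm1's summation length per tile
        if(NPerBlock % Gemm1KPerBlock != 0)
            return false;
        if(!IsSupported(NPerBlock / Gemm1KPerBlock))
            return false;

        return true; // the N / NPerBlock outer count is no longer checked here
    }

    int main()
    {
        std::printf("valid = %d\n", CheckValidity(/*K=*/512, /*KPerBlock=*/32,
                                                  /*NPerBlock=*/128, /*Gemm1KPerBlock=*/32));
        return 0;
    }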
@@ -301,7 +284,6 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
         return true;
     }

-    // TODO ANT: also consider gemm1 loop
     __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
     {
         const index_t num_loop = K / KPerBlock;

@@ -395,11 +377,6 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
         // B matrix in LDS memory, dst of blockwise copy
         constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

-        // for n in N0: // gemm1 summation loop
-        //   for k in K0: // gemm0 summation loop
-        //     acc0 += A[m][k] * B0[k][n] // acc0[m][n]
-        //   acc1 += acc0 * B1[n][o] // acc1[m][o]
-
         //
         // set up Gemm0
         //

@@ -425,8 +402,8 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
             ABlockTransferDstScalarPerVector_AK1,
             1,
             1,
-            true, // TODO ANT: check if false
-            true,
+            true, // SrcResetCoord
+            true, // DstResetCoord
             NumGemmKPrefetchStage>(
                 a_grid_desc_ak0_m_ak1,
                 make_multi_index(0, m_block_data_idx_on_grid, 0),

@@ -456,8 +433,8 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
             BBlockTransferDstScalarPerVector_BK1,
             1,
             1,
-            true, // TODO ANT: check if false
-            true,
+            true, // SrcResetCoord
+            true, // DstResetCoord
             NumGemmKPrefetchStage>(
                 b_grid_desc_bk0_n_bk1,
                 make_multi_index(0, 0, 0), // will loop over GemmN dimension

@@ -466,12 +443,17 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
                 make_multi_index(0, 0, 0),
                 tensor_operation::element_wise::PassThrough{});

+        // Fused Gemm+Gemm pipeline
+        // for n in N0:
+        //   for k in K0:
+        //     acc[m][n] += A[m][k] * B0[k][n]
+        //   acc1[m][o] += acc[m][n] * B1[n][o]
+
         // sanity check
         constexpr index_t KPack = math::max(
             math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);

         // TODO ANT: to refactor: blockwise gemm output layout
-        // TODO ANT: interwave scheduling
         auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<
             BlockSize,
             FloatAB,
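The added "Fused Gemm+Gemm pipeline" comment describes the computation the kernel performs: the output column index n of the first GEMM doubles as the summation index of the second. A plain host reference of that fusion (standalone row-major buffers, unrelated to the kernel's tiling or data types):

    #include <cstdio>
    #include <vector>

    int main()
    {
        const int M = 4, N = 6, K = 8, O = 5;
        std::vector<float> A(M * K, 0.01f), B0(K * N, 0.02f), B1(N * O, 0.03f);
        std::vector<float> C(M * O, 0.0f); // C = (A * B0) * B1

        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)        // gemm1 summation loop (the outer "j loop")
            {
                float acc0 = 0.0f;
                for(int k = 0; k < K; ++k)    // gemm0 summation loop
                    acc0 += A[m * K + k] * B0[k * N + n];
                for(int o = 0; o < O; ++o)    // feed acc0 into the second GEMM on the fly
                    C[m * O + o] += acc0 * B1[n * O + o];
            }

        std::printf("C[0] = %f\n", C[0]);
        return 0;
    }

The point of the fusion is that the full M x N intermediate never has to be written to global memory; only one acc0 column per (m, n) pair is alive at a time.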
@@ -509,8 +491,9 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
         const auto b_block_reset_copy_step =
             make_multi_index(-b_grid_desc_bk0_n_bk1.GetLength(I0), NPerBlock, 0);

         // gridwise GEMM pipeline
+        // Only supports LoopScheduler::Default
         const auto gridwise_gemm_pipeline =
-            GridwiseGemmPipeline_v1_Selector<NumGemmKPrefetchStage, LoopSched>();
+            GridwiseGemmPipeline_v1_Selector<NumGemmKPrefetchStage, LoopScheduler::Default>();

         const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
             (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /

@@ -520,7 +503,7 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
         //
         // set up Gemm1
         //

-        // Acc matrix threadwise copy: AccVGPR to VGPR and downcast to A data type
+        // Acc matrix threadwise copy: AccVGPR to VGPR and downcast to XDL input data type
         constexpr auto acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
             blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();

@@ -533,47 +516,49 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
         constexpr auto n3 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
         constexpr auto n4 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);

-        constexpr auto a1_block_slice_copy_step = make_multi_index(Gemm1KPerBlock / n4, 0, 0);
         constexpr auto b1_block_slice_copy_step = make_multi_index(Gemm1KPerBlock / B1K1, 0, 0);

         // acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 to acc_thread_desc_k0_m_k1
         // n0_n1_n2_n3 -> k0
         // m0_m1_m2 -> m
         // n4 -> k1
-        // NOTE: had to use merge_v3 or will spit out compilation errors
         constexpr auto acc_thread_desc_k0_m_k1 = transform_tensor_descriptor(
             acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
-            make_tuple(make_merge_transform_v3_division_mod(make_tuple(n0, n1, n2, n3)),
+            // NOTE: had to use merge_v3 or it will spit out weird errors
+            make_tuple(make_merge_transform_v3_division_mod(make_tuple(n0, n1, n2, n3)),
                        make_merge_transform_v3_division_mod(make_tuple(m0, m1, m2)),
                        make_pass_through_transform(n4)),
             make_tuple(Sequence<1, 3, 5, 6>{}, Sequence<0, 2, 4>{}, Sequence<7>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

-        // A1 thread descriptor for iterating Acc thread descriptor
-        // n2 num_groups_per_blk, n3 num_input_blks, n4 group_size // FIXME ANT: use block desc N3 instead of hardcoding
-        constexpr auto A1ThreadSlice =
-            make_tuple(Number<Gemm1KPerBlock / n4 / 2>{}, Number<m0 * m1 * m2>{}, Number<n4>{});
-        constexpr index_t A1K0 = A1ThreadSlice[I0];
-        constexpr index_t A1K1 = A1ThreadSlice[I2];
+        // A1 matrix in AccVGPR
+        // N2 num_groups_per_blk, N3 num_input_blks, N4 group_size
+        constexpr auto AccN3 =
+            blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLength(I6);
+
+        constexpr auto A1ThreadSlice_K0_M_K1 = make_tuple(
+            Number<Gemm1KPerBlock / n4 / AccN3>{}, Number<m0 * m1 * m2>{}, Number<n4>{});
+
+        constexpr auto A1ThreadSliceK0 = A1ThreadSlice_K0_M_K1[I0];
+        constexpr auto A1ThreadSliceM  = A1ThreadSlice_K0_M_K1[I1];
+        constexpr auto A1ThreadSliceK1 = A1ThreadSlice_K0_M_K1[I2];
         constexpr auto a1_thread_desc_k0_m_k1 = make_naive_tensor_descriptor(
-            A1ThreadSlice,
-            make_tuple(A1ThreadSlice[I1] * A1ThreadSlice[I2], A1ThreadSlice[I2], I1));
-        // make_tuple(Number<A1K0>{}, Number<m0 * m1 * m2>{}, Number<n4>{}).foo(); // <8, 1, 4>
+            A1ThreadSlice_K0_M_K1,
+            make_tuple(A1ThreadSliceM * A1ThreadSliceK1, A1ThreadSliceK1, I1));

         // B1 matrix in LDS memory, dst of blockwise copy
         constexpr auto b1_block_desc_bk0_n_bk1 = GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();

         // A1 matrix blockwise copy
-        // actually a threadwise copy. this variant needs to support RunRead() and RunWrite()
-        // TODO ANT: real blockwise copy from c_block_desc to c_thread_desc
-        // FIXME: this cannot copy from static_buffer to static_buffer because v3r1 uses integer offset
-        // which is useless against static_buffer because it requires integral constant
-        auto a1_blockwise_copy =
-            ThreadwiseTensorSliceTransfer_v1r3_Static<FloatGemmAcc,
-                                                      FloatAB,
-                                                      decltype(acc_thread_desc_k0_m_k1),
-                                                      decltype(a1_thread_desc_k0_m_k1),
-                                                      Sequence<A1K0, m0 * m1 * m2, A1K1>,
-                                                      Sequence<1, 0, 2>,
-                                                      2,
-                                                      n4>{};
+        auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic<
+            FloatGemmAcc,
+            FloatAB,
+            decltype(acc_thread_desc_k0_m_k1),
+            decltype(a1_thread_desc_k0_m_k1),
+            Sequence<A1ThreadSliceK0, A1ThreadSliceM, A1ThreadSliceK1>,
+            Sequence<1, 0, 2>,
+            2,
+            n4>{};

         // B1 matrix blockwise copy
         auto b1_blockwise_copy =
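The block above re-views each thread's Acc registers, laid out as (M0, N0, M1, N1, M2, N2, N3, N4), as a (K0, M, K1) slice that feeds the second GEMM: N0..N3 merge into K0, M0..M2 into M, and N4 becomes K1, with row-major strides (M*K1, K1, 1). A small index-arithmetic sketch of that re-indexing (the extents below are made-up compile-time numbers, not the kernel's):

    #include <cstdio>

    int main()
    {
        constexpr int m0 = 1, m1 = 1, m2 = 1;
        constexpr int n0 = 2, n1 = 1, n2 = 2, n3 = 2, n4 = 4;

        constexpr int K0 = n0 * n1 * n2 * n3; // n0_n1_n2_n3 -> k0
        constexpr int M  = m0 * m1 * m2;      // m0_m1_m2    -> m
        constexpr int K1 = n4;                // n4          -> k1

        // naive descriptor strides for a (K0, M, K1) slice, innermost K1 contiguous
        constexpr int stride_k0 = M * K1;
        constexpr int stride_m  = K1;
        constexpr int stride_k1 = 1;

        // offset of element (k0, m, k1) in the flattened per-thread buffer
        const int k0 = 3, m = 0, k1 = 2;
        const int offset = k0 * stride_k0 + m * stride_m + k1 * stride_k1;
        std::printf("K0=%d M=%d K1=%d, offset(3,0,2)=%d\n", K0, M, K1, offset);
        return 0;
    }

Replacing the hardcoded divisor 2 with AccN3 read from the block descriptor is what the removed FIXME had asked for.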
@@ -596,8 +581,8 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
             B1BlockTransferDstScalarPerVector_BK1,
             1,
             1,
-            true, // TODO ANT: check if false
-            true,
+            B1ThreadTransferSrcResetCoordinateAfterRun,
+            true, // DstResetCoord
             NumGemmKPrefetchStage>(
                 b1_grid_desc_bk0_n_bk1,
                 make_multi_index(0, n_block_data_idx_on_grid, 0),

@@ -637,19 +622,19 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
             false,
             Gemm1KPack, // AMmaKStride
             Gemm1KPack * XdlopsGemm<FloatAB, MPerXdl, NPerXdl, Gemm1KPack, false>{}.K0PerXdlops>{
-            make_tuple(0, 0, 0, 0)
-        }; // TransposeC
+            make_tuple(0, 0, 0, 0)}; // TransposeC

         auto c_thread_buf = gemm1_blockwise_gemm.GetCThreadBuffer();

         const index_t num_gemm1_k_block_outer_loop =
             b_grid_desc_bk0_n_bk1.GetLength(I1) / NPerBlock;
         constexpr index_t num_gemm1_k_block_inner_loop = NPerBlock / Gemm1KPerBlock;

         // Initialize C
         c_thread_buf.Clear();

-        // gemm1 K loop
         index_t gemm1_k_block_outer_index = 0;
+        // j loop
         do
         {
             // gemm0

@@ -668,88 +653,40 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
                                                                blockwise_gemm,
                                                                acc_thread_buf,
                                                                num_k_block_main_loop);
-#if 0
-            if(hipThreadIdx_x == 0)
-            printf("gemm1_k_block_outer_index %d, num_gemm1_k_block_outer_loop %d\n",
-                gemm1_k_block_outer_index,
-                num_gemm1_k_block_outer_loop);
-#endif
-#if 0
-            if (hipBlockIdx_x == 0 && hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 8) {
-                static_for<0, acc_thread_buf.Size(), 1>{}([&](auto I) {
-                    printf("bid %zd tid %zd, acc[%d] = %f\n", hipBlockIdx_x, hipThreadIdx_x, I.value, acc_thread_buf[I]);
-                });
-            }
-#endif

             // gemm1
             {
+                // TODO: explore using dynamic buffer for a1 thread buffer
+                // For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(),
+                // RunWrite(), and MoveSliceWindow(). But it is impossible to implement given that
+                // the A1 source buffer is static buffer holding the output of first GEMM and
+                // requires constexpr offset by design. Therefore, we pass tensor coordinate offset
+                // explicitly in Run() below.

                 // preload data into LDS
-                // FIXME ANT: do not need a1 copy here?
-                // a1_blockwise_copy.Run(acc_thread_desc_k0_m_k1,
-                //                       make_tuple(I0, I0, I0),
-                //                       acc_thread_buf,
-                //                       a1_thread_desc_k0_m_k1,
-                //                       make_tuple(I0, I0, I0),
-                //                       a1_thread_buf
-                //                       );
-#if 0
-                if (hipThreadIdx_x % 32 < 4) {
-                    static_for<0, a1_thread_buf.Size(), 1>{}([&](auto I) {
-                        printf("bid %zd tid %zd, iter %d, a1[%d] = %f\n", hipBlockIdx_x, hipThreadIdx_x, 0, I.value, (float)a1_thread_buf[I]);
-                    });
-                }
-#endif
                 b1_blockwise_copy.RunRead(b1_grid_desc_bk0_n_bk1, b1_grid_buf);

-                // TODO ANT: how to access static buffer while using tensor coordinate?
-                // a1_blockwise_copy.MoveSrcSliceWindow(acc_thread_desc_k0_m_k1,
-                //                                      a1_block_slice_copy_step);
                 b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_bk0_n_bk1,
                                                      b1_block_slice_copy_step);

                 b1_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf);
-#if 0
-                if (hipBlockIdx_x == 0)
-                {
-                    debug::print_shared(b1_block_buf.p_data_, index_t(b1_block_desc_bk0_n_bk1.GetElementSpaceSize()));
-                }
-#endif

                 // main body
                 if constexpr(num_gemm1_k_block_inner_loop > 1)
                 {
                     static_for<0, num_gemm1_k_block_inner_loop - 1, 1>{}([&](auto i) {
                         a1_blockwise_copy.Run(acc_thread_desc_k0_m_k1,
-                                              make_tuple(Number<i * A1K0>{}, I0, I0),
+                                              make_tuple(Number<i * A1ThreadSliceK0>{}, I0, I0),
                                               acc_thread_buf,
                                               a1_thread_desc_k0_m_k1,
                                               make_tuple(I0, I0, I0),
                                               a1_thread_buf);
-#if 0
-                        if (hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 8) {
-                            static_for<0, a1_thread_buf.Size(), 1>{}([&](auto I) {
-                                printf("bid %zd tid %zd, iter %d, a1[%d] = %f\n", hipBlockIdx_x, hipThreadIdx_x, i.value, I.value, (float)a1_thread_buf[I]);
-                            });
-                        }
-#endif
                         b1_blockwise_copy.RunRead(b1_grid_desc_bk0_n_bk1, b1_grid_buf);

                         block_sync_lds();

                         gemm1_blockwise_gemm.Run(a1_thread_buf, b1_block_buf, c_thread_buf);
-#if 0
-                        if (hipThreadIdx_x % 32 < 8) {
-                            static_for<0, c_thread_buf.Size(), 1>{}([&](auto I) {
-                                printf("bid %zd tid %zd, iter %d, c[%d] = %f\n", hipBlockIdx_x, hipThreadIdx_x, i.value, I.value, c_thread_buf[I]);
-                            });
-                        }
-#endif

                         block_sync_lds();

-                        // a1_blockwise_copy.MoveSrcSliceWindow(acc_thread_desc_k0_m_k1,
-                        //                                      a1_block_slice_copy_step);
                         b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_bk0_n_bk1,
                                                              b1_block_slice_copy_step);

@@ -758,30 +695,26 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
                 }

                 // tail
                 {
-                    a1_blockwise_copy.Run(acc_thread_desc_k0_m_k1,
-                                          make_tuple(Number<(num_gemm1_k_block_inner_loop - 1) * A1K0>{}, I0, I0),
-                                          acc_thread_buf,
-                                          a1_thread_desc_k0_m_k1,
-                                          make_tuple(I0, I0, I0),
-                                          a1_thread_buf);
+                    a1_blockwise_copy.Run(
+                        acc_thread_desc_k0_m_k1,
+                        make_tuple(
+                            Number<(num_gemm1_k_block_inner_loop - 1) * A1ThreadSliceK0>{}, I0, I0),
+                        acc_thread_buf,
+                        a1_thread_desc_k0_m_k1,
+                        make_tuple(I0, I0, I0),
+                        a1_thread_buf);

                     block_sync_lds();

                     gemm1_blockwise_gemm.Run(a1_thread_buf, b1_block_buf, c_thread_buf);
                 }
             } // end gemm1

-#if 0
-            if (hipThreadIdx_x % 32 < 8) {
-                static_for<0, c_thread_buf.Size(), 1>{}([&](auto I) {
-                    printf("bid %zd tid %zd, iter %d, c[%d] = %f\n", hipBlockIdx_x, hipThreadIdx_x, num_gemm1_k_block_inner_loop - 1, I.value, c_thread_buf[I]);
-                });
-            }
-#endif
             a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_ak0_m_ak1,
                                                 a_block_reset_copy_step); // rewind K
             b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_bk0_n_bk1,
                                                 b_block_reset_copy_step); // rewind K and step N
-            // don't need to rewind b1
         } while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop

         // shuffle C and write out
         {
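Taken together, the loop above preloads a B1 tile, runs num_gemm1_k_block_inner_loop - 1 main-body steps plus one tail step, and rewinds the A/B copy windows before the next outer iteration. A control-flow sketch (the helper functions are placeholders for the RunRead/RunWrite/Run calls, and the loop counts are example values):

    #include <cstdio>

    static void gemm0()           { std::printf("  run gemm0, produce acc\n"); }
    static void preload_b1()      { std::printf("  preload B1 tile into LDS\n"); }
    static void gemm1_step(int i) { std::printf("  gemm1 inner step %d\n", i); }
    static void rewind_a_b()      { std::printf("  rewind A (K) and step B (N)\n"); }

    int main()
    {
        const int num_outer = 2, num_inner = 4; // e.g. N/NPerBlock and NPerBlock/Gemm1KPerBlock
        int outer = 0;
        do
        {
            gemm0();
            preload_b1();
            for(int i = 0; i < num_inner - 1; ++i) // main body
                gemm1_step(i);
            gemm1_step(num_inner - 1);             // tail
            rewind_a_b();
        } while(++outer < num_outer);
        return 0;
    }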
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp

@@ -1145,10 +1145,6 @@ struct ThreadwiseTensorSliceTransfer_v4
                 src_desc, src_data_coord);

             // copy data from src_buf into src_tmp_vector
-#if 0
-            src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
-                src_buf.template Get<src_vector_t>(src_data_coord.GetOffset(), is_src_valid);
-#else
             if constexpr(SrcBuffer::IsDynamicBuffer())
             {
                 src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =

@@ -1164,33 +1160,7 @@ struct ThreadwiseTensorSliceTransfer_v4
                     src_tmp_vector.template AsType<SrcData>()(i) =
                         src_buf[Number<src_offset>{}];
                 });
-
-                // if constexpr(StaticBufferTupleOfVector)
-                // {
-                //     // constexpr auto offset_nd = SrcRefToOriginDisplacement{} + data_to_origin_disp_idx;
-                //     // // offset_nd.foo();
-                //     // constexpr auto offset = src_desc.CalculateOffset(offset_nd);
-                //     // src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
-                //     //     src_buf.template GetAsType<src_vector_t>(Number<offset>{});
-                //     static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
-                //         // constexpr auto src_offset_nd = src_ref_to_origin_disp_idx +
-                //         //     data_to_origin_disp_idx + i * src_scalar_step_in_vector;
-                //         // constexpr auto src_offset = src_desc.CalculateOffset(src_offset_nd);
-                //         constexpr auto src_offset = src_desc.CalculateOffset(SrcRefToOriginDisplacement{});
-                //         // SrcData s = src_buf[Number<src_offset>{}];
-                //         SrcData s = src_buf[Number<0>{}];
-                //         // apply type convert
-                //         src_tmp_vector.template AsType<SrcData>()(i) = s;
-                //     });
-                // }
-                // else
-                // {
-                //     src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
-                //         src_buf.template Get<src_vector_t>(src_data_coord.GetOffset(),
-                //         is_src_valid);
-                // }
             }
-#endif

             // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
             // DstData)
             vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;

@@ -1236,16 +1206,14 @@ template <typename SrcData,
           typename DimAccessOrder,
           index_t DstVectorDim,
           index_t DstScalarPerVector,
-          // InMemoryDataOperationEnum DstInMemOp,
-          // index_t DstScalarStrideInVector,
           typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                              bool>::type = false>
-struct ThreadwiseTensorSliceTransfer_v1r3_Static
+struct ThreadwiseTensorSliceTransfer_StaticToStatic
 {
     static constexpr index_t nDim = SliceLengths::Size();

     using Index = MultiIndex<nDim>;

-    __device__ constexpr ThreadwiseTensorSliceTransfer_v1r3_Static()
+    __device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic()
     {
         static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                       "wrong! Desc need to known at compile-time");
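The renamed ThreadwiseTensorSliceTransfer_StaticToStatic copies between two buffers whose shapes are fully known at compile time, converting the element type on the way. A minimal standalone sketch of that idea (plain std::array instead of CK descriptors and static buffers; the offset handling is deliberately simplified):

    #include <array>
    #include <cstddef>
    #include <cstdio>

    // Compile-time-sized source and destination; the copy casts SrcData to DstData.
    template <typename SrcData, typename DstData, int K0, int M, int K1>
    struct StaticToStaticCopy
    {
        void Run(const std::array<SrcData, K0 * M * K1>& src,
                 int src_k0_origin, // slice origin along K0, chosen by the caller
                 std::array<DstData, K0 * M * K1>& dst) const
        {
            for(int k0 = 0; k0 < K0; ++k0)
                for(int m = 0; m < M; ++m)
                    for(int k1 = 0; k1 < K1; ++k1)
                    {
                        const int s = ((src_k0_origin + k0) % K0) * M * K1 + m * K1 + k1;
                        const int d = k0 * M * K1 + m * K1 + k1;
                        dst[d] = static_cast<DstData>(src[s]); // type convert during the copy
                    }
        }
    };

    int main()
    {
        std::array<double, 2 * 1 * 4> acc{}; // higher-precision accumulator values
        std::array<float, 2 * 1 * 4> a1{};   // narrower thread buffer for the next GEMM
        for(std::size_t i = 0; i < acc.size(); ++i)
            acc[i] = 0.25 * double(i);
        StaticToStaticCopy<double, float, 2, 1, 4>{}.Run(acc, 0, a1);
        std::printf("a1[5] = %f\n", a1[5]);
        return 0;
    }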
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp

@@ -30,17 +30,6 @@ enum struct MfmaInstr
     mfma_f64_16x16x4f64
 };

-// template <typename T, bool TransposeC>
-// struct mfma_base_type
-// {
-//     template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
-//     __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
-//     {
-//         if constexpr (!TransposeC) T::run(a, b, reg_c);
-//         else T::run(b, a, reg_c);
-//     }
-// };
-
 template <MfmaInstr instr>
 struct mfma_type;
include/ck/utility/static_buffer.hpp

@@ -72,6 +72,7 @@ struct StaticBufferTupleOfVector
     __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
+
     __host__ __device__ static constexpr index_t Size() { return s_per_buf; };

     // Get S
     // i is offset of S
     template <index_t I>
include/ck/utility/tuple_helper.hpp

@@ -78,10 +78,4 @@ __host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y,
         f, x, y, z, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{});
 }

-template <index_t... Is>
-__host__ __device__ constexpr Tuple<Number<Is>...> to_tuple(Sequence<Is...>)
-{
-    return Tuple<Number<Is>...>(Number<Is>{}...);
-}
-
 } // namespace ck
library/include/ck/library/utility/check_err.hpp

@@ -134,7 +134,7 @@ check_err(const std::vector<T>& out,
         {
             max_err = err > max_err ? err : max_err;
             err_count++;
-            if(err_count < 128)
+            if(err_count < 5)
             {
                 std::cout << msg << std::setw(12) << std::setprecision(7) << " out[" << i
                           << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
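With the threshold lowered from 128 to 5, check_err still counts every mismatch but prints only the first few. A standalone sketch of that reporting behaviour (not CK's actual check_err signature):

    #include <cmath>
    #include <cstdio>
    #include <vector>

    static bool check_err(const std::vector<float>& out, const std::vector<float>& ref,
                          float tol = 1e-3f, int max_report = 5)
    {
        int err_count = 0;
        for(std::size_t i = 0; i < out.size(); ++i)
        {
            const float err = std::fabs(out[i] - ref[i]);
            if(err > tol)
            {
                ++err_count;
                if(err_count < max_report) // cap the console output, keep counting
                    std::printf("out[%zu] != ref[%zu]: %f != %f\n", i, i, out[i], ref[i]);
            }
        }
        return err_count == 0;
    }

    int main()
    {
        std::vector<float> out(100, 1.0f), ref(100, 2.0f);
        std::printf("pass = %d\n", check_err(out, ref));
        return 0;
    }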