gaoqiong / composable_kernel / Commits / eceea10a

Commit eceea10a, authored Aug 03, 2022 by Anthony Chang

    clean up

parent 4ee34028

Showing 10 changed files with 95 additions and 268 deletions (+95 -268)
example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_fp16.cpp                           +6   -17
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp                       +0   -16
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp     +1   -15
include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp      +6   -23
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp   +78  -145
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp           +2   -34
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp                                  +0   -11
include/ck/utility/static_buffer.hpp                                                  +1   -0
include/ck/utility/tuple_helper.hpp                                                   +0   -6
library/include/ck/library/utility/check_err.hpp                                      +1   -1
example/31_batched_gemm_gemm/batched_gemm_gemm_xdl_fp16.cpp
@@ -216,21 +216,19 @@ int main(int argc, char* argv[])
    {
    case 0: break;
    case 1:
        a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-0.5, 0.5});
        b0_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
        b1_n_o.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
        a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
        b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5});
        b1_n_o.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-5, 5});
        break;
    case 2:
        a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
        b0_k_n.GenerateTensorValue(GeneratorTensor_1<B0DataType>{1});
        b1_n_o.GenerateTensorValue(GeneratorTensor_2<B1DataType>{-5, 5});
        a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
        b0_k_n.GenerateTensorValue(GeneratorTensor_3<B0DataType>{0.0, 1.0});
        b1_n_o.GenerateTensorValue(GeneratorTensor_3<B1DataType>{-0.5, 0.5});
        break;
    default:
        a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
        // b0_k_n.GenerateTensorValue(GeneratorTensor_1<B0DataType>{1});
        b0_k_n.GenerateTensorValue(GeneratorTensor_Sequential<1>{});
        b1_n_o.GenerateTensorValue(GeneratorTensor_Diagonal<B1DataType>{});
        // b1_n_o.GenerateTensorValue(GeneratorTensor_Sequential<0>{});
    }

    DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
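Note: the switch above selects between several tensor-initialization functors. As a rough guide only, a minimal standalone sketch of the pattern their constructor arguments suggest (a constant fill, an integer range, and a real range); the names and semantics here are illustrative assumptions, not the library's implementation of GeneratorTensor_1/2/3.

// Illustrative sketch only (assumed semantics of constant / integer-range / real-range fills).
#include <cstdlib>
#include <vector>

template <typename T>
struct ConstantGen // cf. GeneratorTensor_1<T>{value}
{
    T value;
    T operator()() const { return value; }
};

template <typename T>
struct IntRangeGen // cf. GeneratorTensor_2<T>{min, max}: integers in [min, max)
{
    int min, max;
    T operator()() const { return static_cast<T>(min + std::rand() % (max - min)); }
};

template <typename T>
struct RealRangeGen // cf. GeneratorTensor_3<T>{min, max}: reals in [min, max)
{
    float min, max;
    T operator()() const
    {
        const float u = static_cast<float>(std::rand()) / RAND_MAX;
        return static_cast<T>(min + u * (max - min));
    }
};

template <typename T, typename Gen>
void fill(std::vector<T>& data, Gen gen)
{
    for(auto& x : data) { x = gen(); }
}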
@@ -308,15 +306,6 @@ int main(int argc, char* argv[])
    ref_gemm1_invoker.Run(ref_gemm1_argument);

    // LogRangeAsType<float>(std::cout << "a_m_k: ", a_m_k.mData, ",") << std::endl;
    // LogRangeAsType<float>(std::cout << "b0_k_n : ", b0_k_n.mData, ",") << std::endl;
    // LogRangeAsType<float>(std::cout << "b1_n_o : ", b1_n_o.mData, ",") << std::endl;
    // LogRangeAsType<float>(std::cout << "c_m_o_device_result : ", c_m_o_device_result.mData, ",") << std::endl;

    std::cout << "b0_k_n(0, 0) = " << (float)b0_k_n(0, 0)
              << ", b0_k_n(1, 0) = " << (float)b0_k_n(1, 0)
              << ", b0_k_n(0, 1) = " << (float)b0_k_n(0, 1)
              << ", b0_k_n(1, 1) = " << (float)b0_k_n(1, 1) << std::endl;

    return ck::utils::check_err(c_m_o_device_result.mData, c_m_o_host_result.mData) ? 0 : 1;
}
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
@@ -158,22 +158,6 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
                                   Tuple4 b_origin = CalculateBThreadOriginDataIndex())
        : a_thread_copy_(a_origin), b_thread_copy_(b_origin)
    {
#if 0
        if(!TransposeC && hipThreadIdx_x % 32 < 8)
        {
            printf("bid %zd tid %zd, a_mma = %d, %d, %d, %d, b_mma = %d, %d, %d, %d\n",
                   hipBlockIdx_x,
                   hipThreadIdx_x,
                   a_origin[Number<0>{}],
                   a_origin[Number<1>{}],
                   a_origin[Number<2>{}],
                   a_origin[Number<3>{}],
                   b_origin[Number<0>{}],
                   b_origin[Number<1>{}],
                   b_origin[Number<2>{}],
                   b_origin[Number<3>{}]);
        }
#endif
        static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(),
                      "wrong! Desc should be known at compile-time");
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp
@@ -81,21 +81,7 @@ struct ThreadGroupTensorSliceTransfer_v4r1
                make_multi_index(ThreadGroup::GetThreadId()));

            const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;

#if 0
            if (std::is_same<Sequence<16,64,2>, BlockSliceLengths>::value)
            {
                auto s = src_block_slice_origin + thread_data_idx_begin;
                auto d = dst_block_slice_origin + thread_data_idx_begin;
                printf("bid %zd tid %zd, src origin %d %d %d, dst origin %d %d %d\n",
                       hipBlockIdx_x, hipThreadIdx_x,
                       s[Number<0>{}],
                       s[Number<1>{}],
                       s[Number<2>{}],
                       d[Number<0>{}],
                       d[Number<1>{}],
                       d[Number<2>{}]);
            }
#endif

            threadwise_transfer_.SetSrcSliceOrigin(src_desc,
                                                   src_block_slice_origin + thread_data_idx_begin);
            threadwise_transfer_.SetDstSliceOrigin(dst_desc,
include/ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp
@@ -162,7 +162,7 @@ template <typename ALayout,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
          LoopScheduler LoopSched = make_default_loop_scheduler()>
          LoopScheduler LoopSched = LoopScheduler::Default>
struct DeviceGemmGemm_Xdl_CShuffle : public BaseOperator // TODO ANT: inherit from DeviceGemmGemm subtype
{
    using DeviceOp = DeviceGemmGemm_Xdl_CShuffle;
@@ -553,7 +553,7 @@ struct DeviceGemmGemm_Xdl_CShuffle : public BaseOperator // TODO ANT: inherit fr
        ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_AK1,
        false,
        true,
        ABlockLdsExtraM,
        BBlockTransferThreadClusterLengths_BK0_N_BK1,
        BBlockTransferThreadClusterArrangeOrder,
@@ -561,7 +561,7 @@ struct DeviceGemmGemm_Xdl_CShuffle : public BaseOperator // TODO ANT: inherit fr
        BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_BK1,
        false,
        true,
        BBlockLdsExtraN,
        B1BlockTransferThreadClusterLengths_BK0_N_BK1,
        B1BlockTransferThreadClusterArrangeOrder,
@@ -655,24 +655,6 @@ struct DeviceGemmGemm_Xdl_CShuffle : public BaseOperator // TODO ANT: inherit fr
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
#if 0
            {
                std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
                          << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
                          << arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
                          << arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;

                std::cout << "arg.b_grid_desc_bk0_n_bk1_{"
                          << arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", "
                          << arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
                          << arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;

                std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
                          << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
            }
#endif

            // TODO ANT: block id to ctilemap should infer acc0tile map
            if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_,
                                            arg.b_grid_desc_bk0_n_bk1_,
                                            arg.b1_grid_desc_bk0_n_bk1_,
@@ -685,7 +667,7 @@ struct DeviceGemmGemm_Xdl_CShuffle : public BaseOperator // TODO ANT: inherit fr
            const index_t grid_size =
                arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.batch_count_;

            // TODO ANT: K for gemm1
            // Gemm0_K
            const auto K =
                arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
@@ -728,7 +710,8 @@ struct DeviceGemmGemm_Xdl_CShuffle : public BaseOperator // TODO ANT: inherit fr
                                                  arg.compute_base_ptr_of_batch_);
            };

            // TODO ANT: handle tail loops for gemm0 & gemm1
            // Gemm1_K is split into Gemm1_K0/K1 where K1 is known at compile time, so we only need
            // to concern Gemm0's loop
            if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
            {
                ave_time = launch_kernel(integral_constant<bool, true>{});
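Note: the new comment explains why only Gemm0's K loop drives the kernel selection here. A minimal standalone sketch of the kind of decision CalculateHasMainKBlockLoop feeds, under the assumption that the pipeline only needs a main loop when more than one KPerBlock tile is traversed (the real predicate lives in the gridwise pipeline and is not shown in this diff).

#include <cassert>

// Illustrative only: pick a kernel instantiation based on whether K spans
// more than one KPerBlock tile (assumed predicate).
constexpr bool has_main_k_block_loop(int K, int KPerBlock)
{
    // e.g. K = 4096, KPerBlock = 32 -> 128 K-blocks -> main loop is needed
    return (K / KPerBlock) > 1;
}

static_assert(has_main_k_block_loop(4096, 32), "multiple K blocks -> main loop");
static_assert(!has_main_k_block_loop(32, 32), "single K block -> no main loop");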
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp
@@ -50,7 +50,7 @@ template <typename FloatAB,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          bool AThreadTransferSrcResetCoordinateAfterRun, // ignored
          index_t ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
@@ -58,7 +58,7 @@ template <typename FloatAB,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          bool BThreadTransferSrcResetCoordinateAfterRun, // ignored
          index_t BBlockLdsExtraN,
          typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
          typename B1BlockTransferThreadClusterArrangeOrder,
@@ -75,6 +75,9 @@ template <typename FloatAB,
          LoopScheduler LoopSched>
struct GridwiseBatchedGemmGemm_Xdl_CShuffle
{
    static_assert(LoopSched == LoopScheduler::Default,
                  "Non-default loop scheduler is currently not supported");

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
@@ -91,8 +94,6 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
    static constexpr auto AK1 = Number<AK1Value>{};
    static constexpr auto BK1 = Number<BK1Value>{};

    // Gemm1
    static constexpr auto AccK1 = Number<4>{}; // TODO ANT: get from mfma_type.mfma_group_size
    static constexpr auto AccK0 = Number<NPerBlock / AccK1.value>{};
    static constexpr auto B1K0 = Number<Gemm1KPerBlock / B1K1Value>{};
    static constexpr auto B1K1 = Number<B1K1Value>{};
@@ -148,7 +149,6 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
    MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
    {
        constexpr index_t Gemm1NWaves = Gemm1NPerBlock / (Gemm1NXdlPerWave * NPerXdl);
        // Sequence<Gemm1NXdlPerWave, Gemm1NWaves, NPerXdl>{}.foo(); // <2, 1, 32>
        return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<Gemm1NXdlPerWave, Gemm1NWaves, NPerXdl>(
            BBlockDesc_BK0_N_BK1{});
    }
@@ -169,18 +169,6 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
            make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
    }

    // template <typename BlockwiseGemm>
    // __host__ __device__ static constexpr auto
    // GetAccBlockDescriptor_AK0PerBlock_MPerBlock_AK1(const BlockwiseGemm& blockwise_gemm)
    // {
    //     constexpr auto acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
    //         blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
    //     return make_naive_tensor_descriptor(
    //         make_tuple(B1K0, Number<Gemm1NPerBlock>{}, B1BK1),
    //         make_tuple(Number<Gemm1NPerBlock + B1BlockLdsExtraN>{} * B1K1, B1K1, I1));
    // }

    __host__ __device__ static constexpr auto GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1()
    {
        // B1 matrix in LDS memory, dst of blockwise copy
@@ -266,26 +254,21 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
            return false;
        }

        if(!(NPerBlock % Gemm1KPerBlock == 0))
        {
            return false;
        }

        // check gridwise gemm pipeline
        // check gemm0 gridwise gemm pipeline
        const auto num_gemm0_k_loop = K / KPerBlock;
        if(!GridwiseGemmPipe::IsSupported(num_gemm0_k_loop))
        {
            return false;
        }

        const auto num_gemm1_k_inner_loop = NPerBlock / Gemm1KPerBlock;
        if(!GridwiseGemmPipe::IsSupported(num_gemm1_k_inner_loop))
        // check gemm1 gridwise gemm pipeline
        if(!(NPerBlock % Gemm1KPerBlock == 0))
        {
            return false;
        }

        const auto num_gemm1_k_outer_loop = N / NPerBlock;
        if(!GridwiseGemmPipe::IsSupported(num_gemm1_k_outer_loop))
        const auto num_gemm1_k_inner_loop = NPerBlock / Gemm1KPerBlock;
        if(!GridwiseGemmPipe::IsSupported(num_gemm1_k_inner_loop))
        {
            return false;
        }
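Note: the reordered checks above amount to a divisibility requirement plus per-pipeline loop-count support. A compact host-side sketch of that logic, assuming only that IsSupported requires a positive trip count (its real condition is pipeline-specific and not shown here).

// Illustrative, simplified sketch of the validity logic in this hunk.
bool check_validity_sketch(int K, int KPerBlock, int NPerBlock, int Gemm1KPerBlock)
{
    // NPerBlock must tile evenly into Gemm1KPerBlock chunks
    if(NPerBlock % Gemm1KPerBlock != 0)
        return false;

    const int num_gemm0_k_loop       = K / KPerBlock;
    const int num_gemm1_k_inner_loop = NPerBlock / Gemm1KPerBlock;

    // placeholder predicate standing in for GridwiseGemmPipe::IsSupported
    auto is_supported = [](int num_loop) { return num_loop >= 1; };

    return is_supported(num_gemm0_k_loop) && is_supported(num_gemm1_k_inner_loop);
}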
@@ -301,7 +284,6 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
        return true;
    }

    // TODO ANT: also consider gemm1 loop
    __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
    {
        const index_t num_loop = K / KPerBlock;
@@ -395,11 +377,6 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
        // B matrix in LDS memory, dst of blockwise copy
        constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        // for n in N0: // gemm1 summation loop
        //   for k in K0: // gemm0 summation loop
        //     acc0 += A[m][k] * B0[k][n] // acc0[m][n]
        //   acc1 += acc0 * B1[n][o] // acc1[m][o]

        //
        // set up Gemm0
        //
@@ -425,8 +402,8 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
            ABlockTransferDstScalarPerVector_AK1,
            1,
            1,
            true, // TODO ANT: check if false
            true,
            true, // SrcResetCoord
            true, // DstResetCoord
            NumGemmKPrefetchStage>(a_grid_desc_ak0_m_ak1,
                                   make_multi_index(0, m_block_data_idx_on_grid, 0),
@@ -456,8 +433,8 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
            BBlockTransferDstScalarPerVector_BK1,
            1,
            1,
            true, // TODO ANT: check if false
            true,
            true, // SrcResetCoord
            true, // DstResetCoord
            NumGemmKPrefetchStage>(b_grid_desc_bk0_n_bk1,
                                   make_multi_index(0, 0, 0), // will loop over GemmN dimension
@@ -466,12 +443,17 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
            make_multi_index(0, 0, 0),
            tensor_operation::element_wise::PassThrough{});

        // Fused Gemm+Gemm pipeline
        // for n in N0:
        //   for k in K0:
        //     acc[m][n] += A[m][k] * B0[k][n]
        //   acc1[m][o] += acc[m][n] * B1[n][o]

        // sanity check
        constexpr index_t KPack =
            math::max(math::lcm(AK1, BK1),
                      MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);

        // TODO ANT: to refactor: blockwise gemm output layout
        // TODO ANT: interwave scheduling
        auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
                                                                                  FloatAB,
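Note: the new comment block spells out the fused computation the kernel implements. Below is a plain scalar host reference of that same math (C = (A * B0) * B1, with the first product accumulated per n), purely illustrative and unrelated to the kernel's tiling, LDS staging, or XDLOPS path.

#include <vector>

// Scalar reference of the fused Gemm+Gemm described in the comment:
//   acc[m][n] = sum_k A[m][k] * B0[k][n]
//   C[m][o]   = sum_n acc[m][n] * B1[n][o]
void fused_gemm_gemm_reference(const std::vector<float>& A,  // M x K
                               const std::vector<float>& B0, // K x N
                               const std::vector<float>& B1, // N x O
                               std::vector<float>& C,        // M x O
                               int M, int K, int N, int O)
{
    std::vector<float> acc(static_cast<size_t>(M) * N, 0.f);

    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
            for(int k = 0; k < K; ++k)
                acc[m * N + n] += A[m * K + k] * B0[k * N + n];

    C.assign(static_cast<size_t>(M) * O, 0.f);
    for(int m = 0; m < M; ++m)
        for(int o = 0; o < O; ++o)
            for(int n = 0; n < N; ++n)
                C[m * O + o] += acc[m * N + n] * B1[n * O + o];
}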
@@ -509,8 +491,9 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
        const auto b_block_reset_copy_step =
            make_multi_index(-b_grid_desc_bk0_n_bk1.GetLength(I0), NPerBlock, 0);

        // gridwise GEMM pipeline
        // Only supports LoopScheduler::Default
        const auto gridwise_gemm_pipeline =
            GridwiseGemmPipeline_v1_Selector<NumGemmKPrefetchStage, LoopSched>();
            GridwiseGemmPipeline_v1_Selector<NumGemmKPrefetchStage, LoopScheduler::Default>();

        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
            (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
@@ -520,7 +503,7 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
        // set up Gemm1
        //

        // Acc matrix threadwise copy: AccVGPR to VGPR and downcast to A data type
        // Acc matrix threadwise copy: AccVGPR to VGPR and downcast to XDL input data type
        constexpr auto acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
            blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
@@ -533,44 +516,46 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
        constexpr auto n3 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
        constexpr auto n4 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);

        constexpr auto a1_block_slice_copy_step = make_multi_index(Gemm1KPerBlock / n4, 0, 0);
        constexpr auto b1_block_slice_copy_step = make_multi_index(Gemm1KPerBlock / B1K1, 0, 0);

        // acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 to acc_thread_desc_k0_m_k1
        // n0_n1_n2_n3 -> k0
        // m0_m1_m2 -> m
        // n4 -> k1
        // NOTE: had to use merge_v3 or will spit out compilation errors
        constexpr auto acc_thread_desc_k0_m_k1 = transform_tensor_descriptor(
            acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
            make_tuple(make_merge_transform_v3_division_mod(make_tuple(n0, n1, n2, n3)),
            // NOTE: had to use merge_v3 or it will spit out weird errors
            make_tuple(make_merge_transform_v3_division_mod(make_tuple(n0, n1, n2, n3)),
                       make_merge_transform_v3_division_mod(make_tuple(m0, m1, m2)),
                       make_pass_through_transform(n4)),
            make_tuple(Sequence<1, 3, 5, 6>{}, Sequence<0, 2, 4>{}, Sequence<7>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

        // A1 thread descriptor for iterating Acc thread descriptor
        // n2 num_groups_per_blk, n3 num_input_blks, n4 group_size // FIXME ANT: use block desc N3 instead of hardcoding
        constexpr auto A1ThreadSlice =
            make_tuple(Number<Gemm1KPerBlock / n4 / 2>{}, Number<m0 * m1 * m2>{}, Number<n4>{});
        constexpr index_t A1K0 = A1ThreadSlice[I0];
        constexpr index_t A1K1 = A1ThreadSlice[I2];
        // A1 matrix in AccVGPR
        // N2 num_groups_per_blk, N3 num_input_blks, N4 group_size
        constexpr auto AccN3 =
            blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLength(I6);
        constexpr auto A1ThreadSlice_K0_M_K1 = make_tuple(
            Number<Gemm1KPerBlock / n4 / AccN3>{}, Number<m0 * m1 * m2>{}, Number<n4>{});
        constexpr auto A1ThreadSliceK0 = A1ThreadSlice_K0_M_K1[I0];
        constexpr auto A1ThreadSliceM  = A1ThreadSlice_K0_M_K1[I1];
        constexpr auto A1ThreadSliceK1 = A1ThreadSlice_K0_M_K1[I2];
        constexpr auto a1_thread_desc_k0_m_k1 = make_naive_tensor_descriptor(
            A1ThreadSlice,
            make_tuple(A1ThreadSlice[I1] * A1ThreadSlice[I2], A1ThreadSlice[I2], I1));
            // make_tuple(Number<A1K0>{}, Number<m0 * m1 * m2>{}, Number<n4>{}).foo(); // <8, 1, 4>
            A1ThreadSlice_K0_M_K1,
            make_tuple(A1ThreadSliceM * A1ThreadSliceK1, A1ThreadSliceK1, I1));

        // B1 matrix in LDS memory, dst of blockwise copy
        constexpr auto b1_block_desc_bk0_n_bk1 = GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        // A1 matrix blockwise copy
        // actually a threadwise copy. this variant needs to support RunRead() and RunWrite()
        // TODO ANT: real blockwise copy from c_block_desc to c_thread_desc
        // FIXME: this cannot copy from static_buffer to static_buffer because v3r1 uses integer offset
        // which is useless against static_buffer because it requires integral constant
        auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_v1r3_Static<FloatGemmAcc,
        auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic<
            FloatGemmAcc,
            FloatAB,
            decltype(acc_thread_desc_k0_m_k1),
            decltype(a1_thread_desc_k0_m_k1),
            Sequence<A1K0, m0 * m1 * m2, A1K1>,
            Sequence<A1ThreadSliceK0, A1ThreadSliceM, A1ThreadSliceK1>,
            Sequence<1, 0, 2>,
            2,
            n4>{};
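Note: the new comments describe collapsing (n0, n1, n2, n3) into a single k0 index and (m0, m1, m2) into m. A tiny standalone illustration of that division/mod folding follows; it is independent of the library's descriptor machinery and only shows the indexing idea behind make_merge_transform_v3_division_mod.

#include <array>

// Merge a 4-d index (i0, i1, i2, i3) over lengths (N0, N1, N2, N3) into one
// linear index, and recover it again with division/mod (illustrative only).
constexpr int merge4(std::array<int, 4> i, std::array<int, 4> n)
{
    return ((i[0] * n[1] + i[1]) * n[2] + i[2]) * n[3] + i[3];
}

constexpr std::array<int, 4> unmerge4(int linear, std::array<int, 4> n)
{
    std::array<int, 4> i{};
    i[3] = linear % n[3]; linear /= n[3];
    i[2] = linear % n[2]; linear /= n[2];
    i[1] = linear % n[1]; linear /= n[1];
    i[0] = linear;
    return i;
}

// e.g. index (1, 0, 2, 3) over lengths (2, 2, 4, 4) folds to 43
static_assert(merge4({1, 0, 2, 3}, {2, 2, 4, 4}) == 43, "");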
@@ -596,8 +581,8 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
            B1BlockTransferDstScalarPerVector_BK1,
            1,
            1,
            true, // TODO ANT: check if false
            true,
            B1ThreadTransferSrcResetCoordinateAfterRun,
            true, // DstResetCoord
            NumGemmKPrefetchStage>(b1_grid_desc_bk0_n_bk1,
                                   make_multi_index(0, n_block_data_idx_on_grid, 0),
@@ -637,19 +622,19 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
            false,
            Gemm1KPack, // AMmaKStride
            Gemm1KPack * XdlopsGemm<FloatAB, MPerXdl, NPerXdl, Gemm1KPack, false>{}.K0PerXdlops>{
                make_tuple(0, 0, 0, 0)}; // TransposeC
                make_tuple(0, 0, 0, 0)}; // TransposeC

        auto c_thread_buf = gemm1_blockwise_gemm.GetCThreadBuffer();

        const index_t num_gemm1_k_block_outer_loop =
            b_grid_desc_bk0_n_bk1.GetLength(I1) / NPerBlock;
        const index_t num_gemm1_k_block_outer_loop =
            b_grid_desc_bk0_n_bk1.GetLength(I1) / NPerBlock;
        constexpr index_t num_gemm1_k_block_inner_loop = NPerBlock / Gemm1KPerBlock;

        // Initialize C
        c_thread_buf.Clear();

        // gemm1 K loop
        index_t gemm1_k_block_outer_index = 0;
        // j loop
        do
        {
            // gemm0
@@ -668,88 +653,40 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
                                                          blockwise_gemm,
                                                          acc_thread_buf,
                                                          num_k_block_main_loop);
#if 0
            if(hipThreadIdx_x == 0)
                printf("gemm1_k_block_outer_index %d, num_gemm1_k_block_outer_loop %d\n",
                       gemm1_k_block_outer_index,
                       num_gemm1_k_block_outer_loop);
#endif
#if 0
            if (hipBlockIdx_x == 0 && hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 8) {
                static_for<0, acc_thread_buf.Size(), 1>{}([&](auto I) {
                    printf("bid %zd tid %zd, acc[%d] = %f\n", hipBlockIdx_x, hipThreadIdx_x, I.value, acc_thread_buf[I]);
                });
            }
#endif

            // gemm1
            {
                // TODO: explore using dynamic buffer for a1 thread buffer
                // For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(),
                // RunWrite(), and MoveSliceWindow(). But it is impossible to implement given that
                // the A1 source buffer is static buffer holding the output of first GEMM and
                // requires constexpr offset by design. Therefore, we pass tensor coordinate offset
                // explicitly in Run() below.

                // preload data into LDS
                // FIXME ANT: do not need a1 copy here?
                // a1_blockwise_copy.Run(acc_thread_desc_k0_m_k1,
                //                       make_tuple(I0, I0, I0),
                //                       acc_thread_buf,
                //                       a1_thread_desc_k0_m_k1,
                //                       make_tuple(I0, I0, I0),
                //                       a1_thread_buf
                //                       );
#if 0
                if (hipThreadIdx_x % 32 < 4) {
                    static_for<0, a1_thread_buf.Size(), 1>{}([&](auto I) {
                        printf("bid %zd tid %zd, iter %d, a1[%d] = %f\n", hipBlockIdx_x, hipThreadIdx_x, 0, I.value, (float)a1_thread_buf[I]);
                    });
                }
#endif
                b1_blockwise_copy.RunRead(b1_grid_desc_bk0_n_bk1, b1_grid_buf);

                // TODO ANT: how to access static buffer while using tensor coordinate?
                // a1_blockwise_copy.MoveSrcSliceWindow(acc_thread_desc_k0_m_k1,
                //                                      a1_block_slice_copy_step);
                b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_bk0_n_bk1,
                                                     b1_block_slice_copy_step);

                b1_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf);
#if 0
                if (hipBlockIdx_x == 0)
                {
                    debug::print_shared(b1_block_buf.p_data_, index_t(b1_block_desc_bk0_n_bk1.GetElementSpaceSize()));
                }
#endif
                // main body
                if constexpr(num_gemm1_k_block_inner_loop > 1)
                {
                    static_for<0, num_gemm1_k_block_inner_loop - 1, 1>{}([&](auto i) {
                        a1_blockwise_copy.Run(acc_thread_desc_k0_m_k1,
                                              make_tuple(Number<i * A1K0>{}, I0, I0),
                                              make_tuple(Number<i * A1ThreadSliceK0>{}, I0, I0),
                                              acc_thread_buf,
                                              a1_thread_desc_k0_m_k1,
                                              make_tuple(I0, I0, I0),
                                              a1_thread_buf);
#if 0
                        if (hipBlockIdx_x == 0 && hipThreadIdx_x % 32 < 8) {
                            static_for<0, a1_thread_buf.Size(), 1>{}([&](auto I) {
                                printf("bid %zd tid %zd, iter %d, a1[%d] = %f\n", hipBlockIdx_x, hipThreadIdx_x, i.value, I.value, (float)a1_thread_buf[I]);
                            });
                        }
#endif
                                              a1_thread_buf);

                        b1_blockwise_copy.RunRead(b1_grid_desc_bk0_n_bk1, b1_grid_buf);

                        block_sync_lds();

                        gemm1_blockwise_gemm.Run(a1_thread_buf, b1_block_buf, c_thread_buf);
#if 0
                        if (hipThreadIdx_x % 32 < 8) {
                            static_for<0, c_thread_buf.Size(), 1>{}([&](auto I) {
                                printf("bid %zd tid %zd, iter %d, c[%d] = %f\n", hipBlockIdx_x, hipThreadIdx_x, i.value, I.value, c_thread_buf[I]);
                            });
                        }
#endif

                        block_sync_lds();

                        // a1_blockwise_copy.MoveSrcSliceWindow(acc_thread_desc_k0_m_k1,
                        //                                      a1_block_slice_copy_step);
                        b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_bk0_n_bk1,
                                                             b1_block_slice_copy_step);
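Note: the comment above explains that the A1 source is a static buffer whose offsets must be integral constants, which is why the copy's source offset is passed as a compile-time value. A toy sketch of that constraint, unrolling a copy over an index sequence instead of advancing a runtime coordinate (names and structure are illustrative, not the library's transfer class).

#include <cstddef>
#include <utility>

// Copy Count elements starting at a compile-time SrcOffset: every index used
// on the "static" side is an integral constant known to the compiler.
template <std::size_t SrcOffset, typename T, std::size_t N, std::size_t M, std::size_t... Is>
constexpr void copy_static_impl(const T (&src)[N], T (&dst)[M], std::index_sequence<Is...>)
{
    ((dst[Is] = src[SrcOffset + Is]), ...); // unrolled; no runtime offset arithmetic
}

template <std::size_t SrcOffset, std::size_t Count, typename T, std::size_t N, std::size_t M>
constexpr void copy_static(const T (&src)[N], T (&dst)[M])
{
    static_assert(SrcOffset + Count <= N && Count <= M, "slice out of range");
    copy_static_impl<SrcOffset>(src, dst, std::make_index_sequence<Count>{});
}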
@@ -758,8 +695,10 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
                }

                // tail
                {
                    a1_blockwise_copy.Run(acc_thread_desc_k0_m_k1,
                                          make_tuple(Number<(num_gemm1_k_block_inner_loop - 1) * A1K0>{}, I0, I0),
                    a1_blockwise_copy.Run(acc_thread_desc_k0_m_k1,
                                          make_tuple(Number<(num_gemm1_k_block_inner_loop - 1) * A1ThreadSliceK0>{}, I0, I0),
                                          acc_thread_buf,
                                          a1_thread_desc_k0_m_k1,
                                          make_tuple(I0, I0, I0),
@@ -769,19 +708,13 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
                    gemm1_blockwise_gemm.Run(a1_thread_buf, b1_block_buf, c_thread_buf);
                }
            } // end gemm1
#if 0
            if (hipThreadIdx_x % 32 < 8) {
                static_for<0, c_thread_buf.Size(), 1>{}([&](auto I) {
                    printf("bid %zd tid %zd, iter %d, c[%d] = %f\n", hipBlockIdx_x, hipThreadIdx_x, num_gemm1_k_block_inner_loop - 1, I.value, c_thread_buf[I]);
                });
            }
#endif
            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_ak0_m_ak1,
                                                a_block_reset_copy_step); // rewind K
            b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_bk0_n_bk1,
                                                b_block_reset_copy_step); // rewind K and step N

            // don't need to rewind b1
            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_ak0_m_ak1,
                                                a_block_reset_copy_step); // rewind K
            b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_bk0_n_bk1,
                                                b_block_reset_copy_step); // rewind K and step N
        } while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop
        } while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop

        // shuffle C and write out
        {
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
@@ -1145,10 +1145,6 @@ struct ThreadwiseTensorSliceTransfer_v4
                src_desc, src_data_coord);

            // copy data from src_buf into src_tmp_vector
#if 0
            src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
                src_buf.template Get<src_vector_t>(src_data_coord.GetOffset(), is_src_valid);
#else
            if constexpr(SrcBuffer::IsDynamicBuffer())
            {
                src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
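Note: this change replaces an #if 0/#else with an if constexpr branch on SrcBuffer::IsDynamicBuffer() (the trait added to static_buffer.hpp in this same commit). A toy version of that compile-time dispatch, with two stand-in buffer types whose names are purely illustrative.

// Toy illustration of dispatching on a buffer trait at compile time.
struct ToyDynamicBuffer
{
    static constexpr bool IsDynamicBuffer() { return true; }
};

struct ToyStaticBuffer
{
    static constexpr bool IsDynamicBuffer() { return false; }
};

template <typename Buffer>
int load_element(const Buffer&)
{
    if constexpr(Buffer::IsDynamicBuffer())
    {
        return 1; // e.g. a vectorized load through a runtime offset
    }
    else
    {
        return 2; // e.g. an element-wise copy through compile-time offsets
    }
}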
@@ -1164,33 +1160,7 @@ struct ThreadwiseTensorSliceTransfer_v4
                    src_tmp_vector.template AsType<SrcData>()(i) = src_buf[Number<src_offset>{}];
                });

                // if constexpr(StaticBufferTupleOfVector)
                // {
                //     // constexpr auto offset_nd = SrcRefToOriginDisplacement{} + data_to_origin_disp_idx;
                //     // // offset_nd.foo();
                //     // constexpr auto offset = src_desc.CalculateOffset(offset_nd);
                //     // src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
                //     //     src_buf.template GetAsType<src_vector_t>(Number<offset>{});
                //     static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
                //         // constexpr auto src_offset_nd = src_ref_to_origin_disp_idx +
                //         //     data_to_origin_disp_idx + i * src_scalar_step_in_vector;
                //         // constexpr auto src_offset = src_desc.CalculateOffset(src_offset_nd);
                //         constexpr auto src_offset = src_desc.CalculateOffset(SrcRefToOriginDisplacement{});
                //         // SrcData s = src_buf[Number<src_offset>{}];
                //         SrcData s = src_buf[Number<0>{}];
                //         // apply type convert
                //         src_tmp_vector.template AsType<SrcData>()(i) = s;
                //     });
                // }
                // else
                // {
                //     src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
                //         src_buf.template Get<src_vector_t>(src_data_coord.GetOffset(),
                //                                            is_src_valid);
                // }
            }
#endif

            // copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
            // DstData)
            vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;
@@ -1236,16 +1206,14 @@ template <typename SrcData,
          typename DimAccessOrder,
          index_t DstVectorDim,
          index_t DstScalarPerVector,
          // InMemoryDataOperationEnum DstInMemOp,
          // index_t DstScalarStrideInVector,
          typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                             bool>::type = false>
struct ThreadwiseTensorSliceTransfer_v1r3_Static
struct ThreadwiseTensorSliceTransfer_StaticToStatic
{
    static constexpr index_t nDim = SliceLengths::Size();

    using Index = MultiIndex<nDim>;

    __device__ constexpr ThreadwiseTensorSliceTransfer_v1r3_Static()
    __device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic()
    {
        static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                      "wrong! Desc need to known at compile-time");
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
@@ -30,17 +30,6 @@ enum struct MfmaInstr
    mfma_f64_16x16x4f64
};

// template <typename T, bool TransposeC>
// struct mfma_base_type
// {
//     template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
//     __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
//     {
//         if constexpr (!TransposeC) T::run(a, b, reg_c);
//         else T::run(b, a, reg_c);
//     }
// };

template <MfmaInstr instr>
struct mfma_type;
include/ck/utility/static_buffer.hpp
@@ -72,6 +72,7 @@ struct StaticBufferTupleOfVector
    __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }

    __host__ __device__ static constexpr index_t Size() { return s_per_buf; };

    // Get S
    // i is offset of S
    template <index_t I>
include/ck/utility/tuple_helper.hpp
@@ -78,10 +78,4 @@ __host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y,
        f, x, y, z, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{});
}

template <index_t... Is>
__host__ __device__ constexpr Tuple<Number<Is>...> to_tuple(Sequence<Is...>)
{
    return Tuple<Number<Is>...>(Number<Is>{}...);
}

} // namespace ck
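Note: the to_tuple helper shown in this hunk turns a compile-time Sequence into a Tuple of Number<> constants. For reference, the same idea expressed with standard-library types only; this is an illustrative analogue, not the ck implementation.

#include <tuple>
#include <type_traits>
#include <utility>

// Turn an index sequence into a tuple of integral-constant types.
template <std::size_t... Is>
constexpr auto to_tuple_std(std::index_sequence<Is...>)
{
    return std::tuple<std::integral_constant<std::size_t, Is>...>{};
}

// Usage: to_tuple_std(std::make_index_sequence<3>{}) has three elements,
// holding the compile-time values 0, 1, 2.
static_assert(std::tuple_size<decltype(to_tuple_std(std::make_index_sequence<3>{}))>::value == 3, "");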
library/include/ck/library/utility/check_err.hpp
@@ -134,7 +134,7 @@ check_err(const std::vector<T>& out,
        {
            max_err = err > max_err ? err : max_err;
            err_count++;
            if(err_count < 128)
            if(err_count < 5)
            {
                std::cout << msg << std::setw(12) << std::setprecision(7) << " out[" << i
                          << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
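Note: the one-line change above caps the mismatch printout at the first 5 errors instead of 128. A self-contained sketch of the surrounding pattern (track the max error, count mismatches, print only the first few); the tolerance logic is simplified to an absolute threshold and is not the library's check_err.

#include <cmath>
#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

// Simplified check_err-style comparison: absolute tolerance only, print at
// most the first 5 mismatches, report the max error at the end.
bool check_err_sketch(const std::vector<float>& out,
                      const std::vector<float>& ref,
                      const std::string& msg,
                      float atol = 1e-3f)
{
    float max_err   = 0.f;
    int   err_count = 0;

    for(std::size_t i = 0; i < out.size(); ++i)
    {
        const float err = std::fabs(out[i] - ref[i]);
        if(err > atol)
        {
            max_err = err > max_err ? err : max_err;
            err_count++;
            if(err_count < 5)
            {
                std::cout << msg << " out[" << i << "] != ref[" << i << "]: "
                          << out[i] << " != " << ref[i] << std::endl;
            }
        }
    }

    if(err_count > 0)
        std::cout << "max err: " << max_err << ", " << err_count << " mismatches" << std::endl;

    return err_count == 0;
}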