"...composable_kernel.git" did not exist on "d66421fe34f2b69de7fe53876a7eb5dea4f3fd9f"
Commit 74ef5021 authored by aska-0096's avatar aska-0096
Browse files

tempsave

parent 3f9dbcac
add_example_executable(example_gemm_multiply_multiply_xdl_fp8 gemm_multiply_multiply_xdl_fp8.cpp)
add_example_executable(example_gemm_multiply_multiply_xdl_fp8_ab_scale gemm_multiply_multiply_xdl_fp8_ab_scale.cpp)
target_compile_options(example_gemm_multiply_multiply_xdl_fp8 PRIVATE -mllvm -greedy-reverse-local-assignment=1 -save-temps=$PWD -Wno-gnu-line-marker)
add_example_executable(example_gemm_add_add_xdl_fp16 gemm_add_add_xdl_fp16.cpp)
add_example_executable(example_gemm_multiply_multiply_xdl_int8 gemm_multiply_multiply_xdl_int8.cpp)
\ No newline at end of file
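For context on the extra compile options attached to the FP8 example, a hedged reading (assumed semantics of standard clang/LLVM flags, not stated in the commit):

# Notes on the flags above (assumed semantics, not stated in the commit):
# -mllvm -greedy-reverse-local-assignment=1  forwards an option to LLVM's greedy
#   register allocator, presumably to probe register-assignment effects on the
#   FP8 kernel's register pressure.
# -save-temps=$PWD  keeps intermediate compilation artifacts (preprocessed
#   source, generated assembly) in the working directory for inspection.
# -Wno-gnu-line-marker  silences warnings triggered by line markers in the
#   saved preprocessed output.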
@@ -57,14 +57,16 @@ struct MultiplyMultiply
    operator()(E& e, const C& c, const D0& d0, const D1& d1) const;

    template <>
    __host__ __device__ constexpr void operator()<F16, float, float, float>(F16& e,
                                                                            const float& c,
                                                                            const float& d0,
                                                                            const float& d1) const
    {
        const float x0_f = c * d0 * d1;
        e = ck::type_convert<F16>(x0_f);
    }

    template <>
    __host__ __device__ constexpr void operator()<ck::half_t, int, float, float>(
        ck::half_t& e, const int& c, const float& d0, const float& d1) const
@@ -74,44 +76,43 @@ struct MultiplyMultiply
        e = ck::type_convert<ck::half_t>(x0_f);
    }
};
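For reference, a minimal host-side sketch of what this epilogue computes per output element (not part of the commit; CK utility header paths are assumed):

// Minimal sketch of the MultiplyMultiply epilogue on the host (assumes CK's
// half_t and type_convert utilities; header locations are assumptions).
#include "ck/utility/data_type.hpp"    // assumed location of ck::half_t
#include "ck/utility/type_convert.hpp" // assumed location of ck::type_convert

int main()
{
    using F16 = ck::half_t;
    const float c = 3.0f, d0 = 0.5f, d1 = 2.0f; // accumulator and two scales
    F16 e = ck::type_convert<F16>(c * d0 * d1); // e = convert<F16>(3.0f)
    return ck::type_convert<float>(e) == 3.0f ? 0 : 1;
}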
void preShuffleBuffer(const FP8* src, int N, int K, FP8* dst)
{
    const int NRepeat = 4;
    const int KRepeat = 4;
    const int NWave   = 2;
    const int KLane   = 2;
    const int NLane   = 32;
    const int KPack   = 16;
    int K0 = K / (KRepeat * KLane * KPack);
    // K -> src: K0 KLane KRepeat KPack -> dst: K0 KRepeat KLane KPack; move KLane inner to make all
    // lanes contiguous. N -> N0 NRepeat NWave NLane // todo: is NRepeat outer or inner? now it's 1
    int tempn, tempk;
    for(int n = 0; n < N; ++n)
    {
        for(int k = 0; k < K; ++k)
        {
            int n0 = n / (NRepeat * NLane * NWave);
            int k0 = k / (KRepeat * KLane * KPack);
            tempn  = n % (NRepeat * NLane * NWave);
            tempk  = k % (KRepeat * KLane * KPack);
            int n1 = tempn / (NLane * NWave);
            int k1 = tempk / (KRepeat * KPack); // KLane
            tempn  = tempn % (NLane * NWave);
            tempk  = tempk % (KRepeat * KPack);
            int n2 = tempn / NLane;
            int k2 = tempk / KPack; // KRepeat
            int n3 = tempn % NLane;
            int k3 = tempk % KPack; // KPack
            int outputIndex = n0 * KPack * NLane * KLane * NWave * KRepeat * NRepeat * K0 +
                              k0 * KPack * NLane * KLane * NWave * KRepeat * NRepeat +
                              n1 * KPack * NLane * KLane * NWave * KRepeat +
                              k2 * KPack * NLane * KLane * NWave // switch k1, k2
                              + n2 * KPack * NLane * KLane + k1 * KPack * NLane + n3 * KPack + k3;
            dst[outputIndex] = src[n * K + k];
        }
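To make the remap concrete: with the constants above, one shuffled tile covers N = NRepeat * NWave * NLane = 4 * 2 * 32 = 256 rows and K = KRepeat * KLane * KPack = 4 * 2 * 16 = 128 columns, and outputIndex is a mixed-radix repack of (n1, k2, n2, k1, n3, k3). A small self-contained check, as a sketch (assumes preShuffleBuffer above is in scope and FP8 is a one-byte type):

// Sketch: verify preShuffleBuffer preserves the multiset of bytes over one
// tile, i.e. every source element lands in some destination slot (assumes
// the function above is in scope; FP8 stood in by a 1-byte type).
#include <cassert>
#include <vector>

void checkPreShuffleIsPermutation()
{
    const int N = 256, K = 128; // exactly one shuffled tile, see text above
    std::vector<FP8> src(N * K), dst(N * K, FP8(0));
    for(int i = 0; i < N * K; ++i)
        src[i] = FP8(i % 251); // arbitrary non-trivial pattern
    preShuffleBuffer(src.data(), N, K, dst.data());
    // a bijective index remap preserves the histogram of byte values
    std::vector<int> histSrc(256, 0), histDst(256, 0);
    for(int i = 0; i < N * K; ++i)
    {
        ++histSrc[src[i]];
        ++histDst[dst[i]];
    }
    assert(histSrc == histDst);
}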
@@ -136,7 +137,16 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
// kernel 1: 256->32x128x128
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>;
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
    < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
      AElementOp, BElementOp, CDEElementOp, GemmSpec, 256,
      256, 256, 128,
      16, 16,
      32, 32,
      4, 4,
      S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
      S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
      1, 1, S<1, 32, 1, 8>, S<8, 8, 1>,
      ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
// kernel 2: 128->32x128x128 // kernel 2: 128->32x128x128
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>; // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>;
@@ -213,7 +223,8 @@ int main(int argc, char* argv[])
    Tensor<A0DataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{}));
    Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
    Tensor<B0DataType> b0_preshuffled(
        f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); // use layout only for size
    Tensor<D0DataType> d0_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{}));
    Tensor<D1DataType> d1_m_n(f_host_tensor_descriptor(M, N, StrideD, D1Layout{}));
    Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
...
...@@ -59,25 +59,25 @@ template <index_t BlockSize, ...@@ -59,25 +59,25 @@ template <index_t BlockSize,
// ,bool TransposeC //disable transposec right now... // ,bool TransposeC //disable transposec right now...
> >
struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intrawave, struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intrawave,
BlockSize, BlockSize,
ADataType, ADataType,
BDataType, BDataType,
ComputeDataType, ComputeDataType,
AccDataType, AccDataType,
ATileDesc, ATileDesc,
BTileDesc, BTileDesc,
AMmaTileDesc, AMmaTileDesc,
BMmaTileDesc, BMmaTileDesc,
ABlockTransferSrcScalarPerVector, ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
MPerBlock, MPerBlock,
NPerBlock, NPerBlock,
KPerBlock, KPerBlock,
MPerXDL, MPerXDL,
NPerXDL, NPerXDL,
MRepeat, MRepeat,
NRepeat, NRepeat,
KPack> KPack>
: BlockwiseGemmXdlops_pipeline_base<BlockSize, : BlockwiseGemmXdlops_pipeline_base<BlockSize,
ADataType, ADataType,
BDataType, BDataType,
@@ -348,14 +348,15 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        vector_type<ComputeDataType, KPack> a_thread_vec;
                        vector_type<ComputeDataType, KPack> b_thread_vec =
                            b_blockwise_copy.template GetSrcThreadScratchIdx<Sequence<0, k0, 0>,
                                                                             Number<0>{}>();
                        static_for<0, KPack, 1>{}([&](auto ik) {
                            a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                    make_tuple(m0, I0, k0, ik))>{}];
                        });
                        using mfma_input_type =
                            typename vector_type<ComputeDataType,
                                                 xdlops_gemm.K1PerXdlops>::type;
@@ -399,8 +400,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        vector_type<ComputeDataType, KPack> a_thread_vec;
                        vector_type<ComputeDataType, KPack> b_thread_vec =
                            b_blockwise_copy.template GetSrcThreadScratchIdx<Sequence<0, k0, 0>,
                                                                             Number<1>{}>();
                        static_for<0, KPack, 1>{}([&](auto ik) {
                            a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
@@ -449,25 +451,24 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
                    static_for<0, MRepeat, 1>{}([&](auto m0) {
                        static_for<0, NRepeat, 1>{}([&](auto n0) {
                            vector_type<ComputeDataType, KPack> a_thread_vec;
                            vector_type<ComputeDataType, KPack> b_thread_vec =
                                b_blockwise_copy
                                    .template GetSrcThreadScratchIdx<Sequence<0, k0, 0>, Number<0>{}>();
                            static_for<0, KPack, 1>{}([&](auto ik) {
                                a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                    a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                        make_tuple(m0, I0, k0, ik))>{}];
                            });
                            using mfma_input_type =
                                typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
                            constexpr index_t c_offset =
                                c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
                            xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                                            b_thread_vec.template AsType<mfma_input_type>(),
                                            c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                        });
                    });
                });
@@ -477,11 +478,11 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
            static_for<0, KRepeat, 1>{}([&](auto k0) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                       make_tuple(m0, I0, I0, Number<k0 * AMmaKStride>{}),
                                       a_block_buf1,
                                       a_thread_desc_,
                                       make_tuple(m0, I0, k0, I0),
                                       a_thread_buf);
                });
            });
@@ -491,8 +492,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle<BlockGemmPipelineScheduler::Intr
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        vector_type<ComputeDataType, KPack> a_thread_vec;
                        vector_type<ComputeDataType, KPack> b_thread_vec =
                            b_blockwise_copy
                                .template GetSrcThreadScratchIdx<Sequence<0, k0, 0>, Number<1>{}>();
                        static_for<0, KPack, 1>{}([&](auto ik) {
                            a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
...
@@ -112,7 +112,7 @@ struct ThreadGroupTensorSliceTransfer_v4r1
    template <typename SeqIdx, index_t ThreadScratchId = 0>
    __device__ constexpr auto GetSrcThreadScratchIdx()
    {
        return threadwise_transfer_.template GetSrcThreadScratchIdx<SeqIdx, ThreadScratchId>();
    }

    template <typename SrcBuffer, index_t ThreadScratchId = 0>
...
...@@ -67,55 +67,57 @@ template <typename ALayout, ...@@ -67,55 +67,57 @@ template <typename ALayout,
typename ComputeTypeA = CDataType, typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA, typename ComputeTypeB = ComputeTypeA,
typename LDSTypeA = ComputeTypeA, typename LDSTypeA = ComputeTypeA,
typename LDSTypeB = ComputeTypeB> typename LDSTypeB = ComputeTypeB>
struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle : public DeviceGemmMultiD_Xdl_CShuffle_V3<ALayout, struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle
BLayout, : public DeviceGemmMultiD_Xdl_CShuffle_V3<
DsLayout, ALayout,
CLayout, BLayout,
ADataType, DsLayout,
BDataType, CLayout,
DsDataType, ADataType,
CDataType, BDataType,
GemmAccDataType, DsDataType,
CShuffleDataType, CDataType,
AElementwiseOperation, GemmAccDataType,
BElementwiseOperation, CShuffleDataType,
CElementwiseOperation, AElementwiseOperation,
GemmSpec, BElementwiseOperation,
BlockSize, CElementwiseOperation,
MPerBlock, GemmSpec,
NPerBlock, BlockSize,
KPerBlock, MPerBlock,
AK1, NPerBlock,
BK1, KPerBlock,
MPerXDL, AK1,
NPerXDL, BK1,
MXdlPerWave, MPerXDL,
NXdlPerWave, NPerXDL,
ABlockTransferThreadClusterLengths_AK0_M_AK1, MXdlPerWave,
ABlockTransferThreadClusterArrangeOrder, NXdlPerWave,
ABlockTransferSrcAccessOrder, ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferSrcVectorDim, ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcScalarPerVector, ABlockTransferSrcAccessOrder,
ABlockTransferDstScalarPerVector_AK1, ABlockTransferSrcVectorDim,
ABlockLdsExtraM, ABlockTransferSrcScalarPerVector,
BBlockTransferThreadClusterLengths_BK0_N_BK1, ABlockTransferDstScalarPerVector_AK1,
BBlockTransferThreadClusterArrangeOrder, ABlockLdsExtraM,
BBlockTransferSrcAccessOrder, BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferSrcVectorDim, BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcScalarPerVector, BBlockTransferSrcAccessOrder,
BBlockTransferDstScalarPerVector_BK1, BBlockTransferSrcVectorDim,
BBlockLdsExtraN, BBlockTransferSrcScalarPerVector,
CShuffleMXdlPerWavePerShuffle, BBlockTransferDstScalarPerVector_BK1,
CShuffleNXdlPerWavePerShuffle, BBlockLdsExtraN,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleMXdlPerWavePerShuffle,
CDEShuffleBlockTransferScalarPerVectors, CShuffleNXdlPerWavePerShuffle,
BlkGemmPipeSched, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
BlkGemmPipelineVer, CDEShuffleBlockTransferScalarPerVectors,
ComputeTypeA, BlkGemmPipeSched,
ComputeTypeB, BlkGemmPipelineVer,
LDSTypeA, ComputeTypeA,
LDSTypeB> ComputeTypeB,
LDSTypeA,
LDSTypeB>
{ {
static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr index_t NumDTensor = DsDataType::Size();
@@ -172,7 +174,6 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle : public DeviceGemmMultiD_Xd
                                        LDSTypeA,
                                        LDSTypeB>;
    using Argument = typename GridwiseGemm::Argument;
    // Invoker
@@ -267,7 +268,9 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle : public DeviceGemmMultiD_Xd
            constexpr index_t minimum_occupancy =
                BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
            // static_assert(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 &&
            // has_main_k_block_loop, "only impl BlockGemmPipelineVersion::v3 and has mainloop right
            // now");
            if(has_main_k_block_loop)
            {
                // Tail number always full
@@ -284,11 +287,11 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle : public DeviceGemmMultiD_Xd
                }
                else
                {
                    const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle<
                        GridwiseGemm,
                        true,
                        InMemoryDataOperationEnum::Set,
                        minimum_occupancy>;
                    Run(kernel);
                }
            }
@@ -298,7 +301,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle : public DeviceGemmMultiD_Xd
                }
            }
            else
            {
                if(arg.KBatch > 1)
                {
                    const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle<
@@ -310,11 +313,11 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle : public DeviceGemmMultiD_Xd
                }
                else
                {
                    const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle<
                        GridwiseGemm,
                        false,
                        InMemoryDataOperationEnum::Set,
                        minimum_occupancy>;
                    Run(kernel);
                }
            }
@@ -437,4 +440,4 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle : public DeviceGemmMultiD_Xd
} // namespace device
} // namespace tensor_operation
} // namespace ck
\ No newline at end of file
@@ -21,6 +21,12 @@ struct ColumnMajor : public BaseTensorLayout
{
    static constexpr const char* name = "ColumnMajor";
};
struct MFMA : public BaseTensorLayout
{
static constexpr const char* name = "MFMA";
};
} // namespace gemm
namespace convolution {
...
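The new MFMA layout tag above is a plain compile-time marker; a minimal sketch of how such a tag is typically consumed (hypothetical helper, not from the commit, which only adds the tag and leaves a TODO in the instance factory below to wire it up):

// Sketch: compile-time dispatch on the layout tag (hypothetical helper;
// MFMA-tiled B buffers would need the preShuffleBuffer() transform shown
// earlier in this commit).
#include <type_traits>
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"

template <typename BLayout>
constexpr bool is_preshuffled_b_layout_v =
    std::is_same_v<BLayout, ck::tensor_layout::gemm::MFMA>;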
@@ -283,6 +283,16 @@ struct MultiplyMultiply
e = ck::type_convert<ck::half_t>(x0_f);
}
template <>
__host__ __device__ constexpr void operator()<ck::half_t, int, float, float>(
ck::half_t& e, const int& c, const float& d0, const float& d1) const
{
const float x0_f =
ck::type_convert<float>(c) * ck::type_convert<float>(d0) * ck::type_convert<float>(d1);
e = ck::type_convert<ck::half_t>(x0_f);
}
template <> template <>
__host__ __device__ constexpr void operator()<ck::bhalf_t, int, float, float>( __host__ __device__ constexpr void operator()<ck::bhalf_t, int, float, float>(
ck::bhalf_t& e, const int& c, const float& d0, const float& d1) const ck::bhalf_t& e, const int& c, const float& d0, const float& d1) const
......
@@ -126,15 +126,15 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
        CDEShuffleBlockTransferScalarPerVectors{}[I0];
    // K1 should be Number<...>
    static constexpr auto AK0Number       = Number<KPerBlock / AK1Value>{};
    static constexpr auto BK0Number       = Number<KPerBlock / BK1Value>{};
    static constexpr auto AK1Number       = Number<AK1Value>{};
    static constexpr auto BK1Number       = Number<BK1Value>{};
    static constexpr auto BlockSizeNumber = Number<BlockSize>{};

    static constexpr index_t NLane   = 32;
    static constexpr index_t NWave   = 4;
    static constexpr index_t KLane   = 2;
    static constexpr index_t KRepeat = 8;
    static_assert(NLane * NWave * KLane == BlockSize);

    static constexpr index_t NumDTensor = DsDataType::Size();
@@ -323,10 +323,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
    {
        constexpr index_t NKSWIZZLE_V = BlockSize * KPack;
        constexpr index_t NKSWIZZLE_N = Number<NKSWIZZLE_V>{};
        return make_naive_tensor_descriptor(make_tuple(N0, K0, NKSWIZZLE_N),
                                            make_tuple(K0 * NKSWIZZLE_V, NKSWIZZLE_N, I1));
    }

    __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
@@ -956,29 +954,30 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
        return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
    }

    using BlockwiseGemmPipe =
        remove_cvref_t<decltype(BlockwiseGemmXdlops_pipeline_bpreshuffle<
                                BlkGemmPipeSched,
                                BlockSize,
                                LDSTypeA,
                                LDSTypeB,
                                ComputeTypeA,
                                AccDataType,
                                decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()),
                                decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()),
                                decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(
                                    GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())),
                                decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(
                                    GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())),
                                ABlockTransferSrcScalarPerVector,
                                BBlockTransferSrcScalarPerVector,
                                MPerBlock,
                                NPerBlock,
                                KPerBlock,
                                MPerXdl,
                                NPerXdl,
                                MXdlPerWave,
                                NXdlPerWave,
                                KPack>{})>;

    __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        // LDS allocation for A and B: be careful of alignment
@@ -1260,8 +1259,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
    {
        const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
            problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
        const auto b_grid_desc_bpreshuffled =
            MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled);
        const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
            problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
@@ -1295,8 +1294,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
            __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_n_id * (NPerBlock / NLane / NWave));

        // lds max alignment
        constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
@@ -1340,35 +1339,34 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
        // using BThreadClusterLengths = Sequence<1, 1, BlockSize>;
        // using BBlockTransferClusterArrangeOrder = Sequence<0, 1, 2>;

        // B matrix blockwise copy
        auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
            ThisThreadBlock,
            BElementwiseOperation,
            ck::tensor_operation::element_wise::PassThrough,
            InMemoryDataOperationEnum::Set,
            Sequence<1, KRepeat, KPack * BlockSize>,
            Sequence<1, 1, BlockSize>, // BThreadClusterLengths,
            Sequence<0, 1, 2>,         // BBlockTransferClusterArrangeOrder,
            BDataType,
            LDSTypeB,
            decltype(b_grid_desc_bpreshuffled),
            decltype(b_block_desc_bk0_n_bk1),
            Sequence<0, 1, 2>, // BBlockTransferSrcAccessOrder,
            Sequence<0, 1, 2>,
            BBlockTransferSrcVectorDim,
            2,
            BBlockTransferSrcScalarPerVector,
            BBlockTransferDstScalarPerVector_BK1,
            1,
            1,
            BThreadTransferSrcResetCoordinateAfterRun,
            true,
            2>(b_grid_desc_bpreshuffled,
               make_multi_index(n_block_data_idx_on_grid, 0, 0),
               b_element_op,
               b_block_desc_bk0_n_bk1,
               make_multi_index(0, 0, 0),
               ck::tensor_operation::element_wise::PassThrough{});
        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
@@ -1673,7 +1671,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
            });
        }
    }
};
} // namespace ck
@@ -268,12 +268,13 @@ struct ThreadwiseTensorSliceTransfer_v3r1
    }

    template <typename SeqIdx, index_t ThreadScratchId = 0>
    __device__ constexpr auto
    GetSrcThreadScratchIdx(Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
    {
        using vector_t = typename vector_type_maker<SrcData, SrcScalarPerVector>::type::type;
        return src_thread_scratch_tuple_(thread_scratch_id).template GetAsType<vector_t>(SeqIdx{});
    }

    template <index_t ThreadScratchId>
    __device__ void
    TransferDataFromSrcThreadScratchToDstThreadScratch(Number<ThreadScratchId> thread_scratch_id)
...
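Taken together with the pipeline changes above, the new accessor removes the LDS round trip for the preshuffled B operand; a hedged reading of the data path (an interpretation of this diff, not code from it):

// B data path with GetSrcThreadScratchIdx (interpretation of this diff):
//
//   global B --(ThreadGroupTensorSliceTransfer_v4r1 load)--> src thread scratch
//   src scratch --(GetSrcThreadScratchIdx<Sequence<0, k0, 0>>())--> b_thread_vec
//   b_thread_vec --------------(xdlops_gemm.Run)------------------> c_thread_buf
//
// i.e. each thread reads its KPack-wide B vector straight out of the transfer
// object's register-resident scratch instead of staging it through LDS.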
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#if 0
#if(defined(CK_ENABLE_F16) || defined(CK_ENABLE_FP8))
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p1_padding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p2_padding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_p3_padding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
F16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
#endif
#endif
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_padding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_padding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_padding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances);
#endif
template <typename ADataType,
typename BDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMultipleDSplitK<
ALayout,
BLayout,
Tuple<Row, Col>,
CLayout,
ADataType,
BDataType,
Tuple<F32, F32>,
CDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::MultiplyMultiply>>
{
using DeviceOp =
DeviceGemmMultipleDSplitK<ALayout,
BLayout,
Tuple<Row, Col>,
CLayout,
ADataType,
BDataType,
Tuple<F32, F32>,
CDataType,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::MultiplyMultiply>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
// TODO: Add MFMA layout into tensor layout
#if 0
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, f8_t> &&
is_same_v<CDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_comp_default_instances(
op_ptrs);
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_comp_kpadding_instances(
op_ptrs);
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_mem_v1_default_instances(
op_ptrs);
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_mem_v1_kpadding_instances(
op_ptrs);
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_mem_v2_default_instances(
op_ptrs);
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_f16_mk_mfma_mn_mem_v2_kpadding_instances(
op_ptrs);
}
}
#endif
#endif
#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
if constexpr(is_same_v<ADataType, f8_t> && is_same_v<BDataType, f8_t> &&
is_same_v<CDataType, bhalf_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instances(
op_ptrs);
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_padding_instances(
op_ptrs);
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instances(
op_ptrs);
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_padding_instances(
op_ptrs);
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instances(
op_ptrs);
add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_padding_instances(
op_ptrs);
}
}
#endif
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
# ONLY XDL_KERNELS
set(GEMM_MULTIPLY_MULTIPLY_WEIGHT_PRESHUFFLE_INSTANCES)
list(APPEND GEMM_MULTIPLY_MULTIPLY_WEIGHT_PRESHUFFLE_INSTANCES
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance.cpp
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_padding_instance.cpp
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instance.cpp
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_padding_instance.cpp
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instance.cpp
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_padding_instance.cpp
)
set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_padding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_padding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16/device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_padding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
add_instance_library(device_gemm_multiply_multiply_weight_preshuffle_instance ${GEMM_MULTIPLY_MULTIPLY_WEIGHT_PRESHUFFLE_INSTANCES})
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_instances<
GemmDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_padding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p1_instances<
GemmKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_instances<
GemmDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_padding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p2_instances<
GemmKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_default_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_instances<
GemmDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_padding_instances(
std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
Col,
Tuple<Row, Col>,
Row,
F8,
F8,
Tuple<F32, F32>,
BF16,
PassThrough,
PassThrough,
MultiplyMultiply>>>& instances)
{
add_device_operation_instances(
instances,
device_gemm_multiply_multiply_weight_preshuffle_xdl_f8_f8_bf16_mk_mfma_mn_p3_instances<
GemmKPadding>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iomanip>
#include <iostream>
#include <typeinfo>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_b_preshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply_weight_preshuffle.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
namespace ck {
namespace profiler {
template <typename ADataType,
typename BDataType,
typename ComputeDataType,
typename AccDataType,
typename D0DataType,
typename D1DataType,
typename EDataType,
typename ALayout,
typename BLayout,
typename D0Layout,
typename D1Layout,
typename ELayout>
bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification,
int init_method,
bool do_log,
bool time_kernel,
int M,
int N,
int K,
int StrideA,
int StrideB,
int StrideD0,
int StrideD1,
int StrideE,
int KBatch,
int n_warmup,
int n_iter,
uint64_t rotating = 0)
{
bool pass = true;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<D0DataType> d0_m_n(f_host_tensor_descriptor(M, N, StrideD0, D0Layout{}));
Tensor<D1DataType> d1_m_n(f_host_tensor_descriptor(M, N, StrideD1, D1Layout{}));
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
Tensor<EDataType> e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
int total_gemm_needed =
a_m_k.GetElementSpaceSizeInBytes() + b_k_n.GetElementSpaceSizeInBytes() +
d0_m_n.GetElementSpaceSizeInBytes() + d1_m_n.GetElementSpaceSizeInBytes();
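// Assumption: `rotating` is a byte budget; rotating_count below is how many
// copies of the GEMM working set fit in that budget (capped by n_iter), so the
// benchmark can rotate buffers across timed iterations to limit cache reuse.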
int rotating_count = std::max(
1,
std::min(n_iter,
static_cast<int>(std::ceil(static_cast<double>(rotating) / total_gemm_needed))));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "d0_m_n: " << d0_m_n.mDesc << std::endl;
std::cout << "d1_m_n: " << d1_m_n.mDesc << std::endl;
std::cout << "e_m_n: " << e_m_n_device_result.mDesc << std::endl;
std::cout << "rotating count: " << rotating_count << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-1, 2});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-1, 2});
d0_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-5, 5});
d1_m_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-1, 1});
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{0.0, 1.0});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
d0_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{0.0, 1.0});
d1_m_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{0.0, 1.0});
}
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using MultiplyMultiply = ck::tensor_operation::element_wise::MultiplyMultiply;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = MultiplyMultiply;
const auto a_element_op = AElementOp{};
const auto b_element_op = BElementOp{};
const auto c_element_op = CElementOp{};
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize());
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
d0_device_buf.ToDevice(d0_m_n.mData.data());
d1_device_buf.ToDevice(d1_m_n.mData.data());
using DeviceOp =
ck::tensor_operation::device::DeviceGemmMultipleDSplitK<ALayout,
BLayout,
ck::Tuple<D0Layout, D1Layout>,
ELayout,
ADataType,
BDataType,
ck::Tuple<D0DataType, D1DataType>,
EDataType,
AElementOp,
BElementOp,
CElementOp>;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
// Run reference GEMM
if(do_verification)
{
Tensor<AccDataType> c_m_n({M, N});
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
BDataType,
AccDataType,
AccDataType,
AElementOp,
BElementOp,
PassThrough,
ComputeDataType>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument =
ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{});
ref_invoker.Run(ref_argument);
for(int m = 0; m < M; ++m)
{
for(int n = 0; n < N; ++n)
{
c_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_m_n(m, n), d1_m_n(m, n));
}
}
}
std::string best_op_name;
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
float best_kbatch = 0;
// profile device GEMM instances
for(auto& op_ptr : op_ptrs)
{
// TODO: Shuffle the weight
// ...
std::vector<int> kbatch_list = {1, 2, 4, 8, 16};
if(KBatch > 0)
{
kbatch_list = {KBatch};
}
for(std::size_t i = 0; i < kbatch_list.size(); i++)
{
auto kbatch_curr = kbatch_list[i];
auto argument_ptr = op_ptr->MakeArgumentPointer(
static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
std::array<const void*, 2>{d0_device_buf.GetDeviceBuffer(),
d1_device_buf.GetDeviceBuffer()},
static_cast<EDataType*>(c_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
StrideB,
std::array<ck::index_t, 2>{StrideD0, StrideD1},
StrideE,
kbatch_curr,
a_element_op,
b_element_op,
c_element_op);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
// re-init C to zero before profiling next kernel
c_device_buf.SetZero();
invoker_ptr->Run(argument_ptr.get(),
StreamConfig{nullptr, false, 0, n_warmup, n_iter});
if(do_verification)
{
c_device_buf.FromDevice(e_m_n_device_result.mData.data());
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_INT8
// set softer tolerances for fp8
if constexpr((is_same_v<ADataType, f8_t> || is_same_v<BDataType, f8_t> ||
is_same_v<EDataType, f8_t>) ||
(is_same_v<ADataType, int8_t> || is_same_v<BDataType, int8_t> ||
is_same_v<EDataType, int8_t>))
{
std::string msg = "Error: Incorrect results!";
double rtol = 1e-1;
double atol = 1e-1;
pass = pass & ck::utils::check_err(
e_m_n_device_result, e_m_n_host_result, msg, rtol, atol);
}
else
{
#endif
pass = pass & ck::utils::check_err(e_m_n_device_result, e_m_n_host_result);
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_INT8
}
#endif
if(do_log)
{
LogRangeAsType<float>(std::cout << "a : ", a_m_k.mData, ",") << std::endl;
LogRangeAsType<float>(std::cout << "b: ", b_k_n.mData, ",") << std::endl;
LogRangeAsType<float>(
std::cout << "c_host : ", e_m_n_host_result.mData, ",")
<< std::endl;
LogRangeAsType<float>(
std::cout << "c_device: ", e_m_n_device_result.mData, ",")
<< std::endl;
}
}
std::string op_name = op_ptr->GetTypeString();
float ave_time = invoker_ptr->Run(argument_ptr.get(),
StreamConfig{nullptr,
time_kernel,
0,
n_warmup,
n_iter,
rotating_count > 1,
rotating_count});
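// Perf accounting: a GEMM does 2*M*N*K flops; traffic counts one read of A
// and B and one write of E. ave_time is in ms, so flop / 1e9 / ms = TFLOPS
// and bytes / 1e6 / ms = GB/s.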
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(EDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops
<< " TFlops, " << gb_per_sec << " GB/s, " << op_name << ", KBatch "
<< kbatch_curr << std::endl;
if(tflops > best_tflops && ave_time > 1e-10)
{
best_op_name = op_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
best_kbatch = kbatch_curr;
}
}
else
{
std::cout << op_ptr->GetTypeString() << " does not support this problem"
<< std::endl;
}
}
}
if constexpr(is_same<EDataType, float>::value)
{
std::cout << "Best Perf for datatype = f32";
}
else if constexpr(is_same<EDataType, half_t>::value)
{
std::cout << "Best Perf for datatype = f16";
}
else if constexpr(is_same<EDataType, bhalf_t>::value)
{
std::cout << "Best Perf for datatype = bf16";
}
else if constexpr(is_same<EDataType, int8_t>::value)
{
std::cout << "Best Perf for datatype = int8";
}
if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value)
{
std::cout << " ALayout = RowMajor";
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value)
{
std::cout << " ALayout = ColumnMajor";
}
if constexpr(is_same<BLayout, tensor_layout::gemm::RowMajor>::value)
{
std::cout << " BLayout = RowMajor";
}
else if constexpr(is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value)
{
std::cout << " BLayout = ColumnMajor";
}
std::cout << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA
<< " StrideB = " << StrideB << " StrideE = " << StrideE << " KBatch = " << best_kbatch
<< " : " << best_ave_time << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec
<< " GB/s, " << best_op_name << std::endl;
return pass;
}
} // namespace profiler
} // namespace ck
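A minimal caller sketch for this profiler entry point (hypothetical include path, sizes, and strides; the template parameter order follows the declaration above):

// Sketch: profiling f8 x f8 -> bf16 with two f32 scale tensors (hypothetical
// include path and problem sizes; parameter order per the template above).
#include "profile_gemm_multiply_multiply_weight_preshuffle_impl.hpp"

int main()
{
    using Row = ck::tensor_layout::gemm::RowMajor;
    using Col = ck::tensor_layout::gemm::ColumnMajor;
    const int M = 256, N = 256, K = 512;
    const bool pass = ck::profiler::profile_gemm_multiply_multiply_weight_preshuffle_impl<
        ck::f8_t, ck::f8_t, ck::f8_t, float, float, float, ck::bhalf_t,
        Row, Col, Row, Row, Row>(
        /*do_verification=*/1, /*init_method=*/1, /*do_log=*/false, /*time_kernel=*/true,
        M, N, K,
        /*StrideA=*/K, /*StrideB=*/K, /*StrideD0=*/0, /*StrideD1=*/0, /*StrideE=*/N,
        /*KBatch=*/1, /*n_warmup=*/5, /*n_iter=*/20);
    return pass ? 0 : 1;
}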
...@@ -50,7 +50,8 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") ...@@ -50,7 +50,8 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
# endif() # endif()
# list(APPEND PROFILER_SOURCES profile_gemm_multiply_add.cpp) # list(APPEND PROFILER_SOURCES profile_gemm_multiply_add.cpp)
# if(SUPPORTED_GPU_TARGETS MATCHES "gfx94") # if(SUPPORTED_GPU_TARGETS MATCHES "gfx94")
list(APPEND PROFILER_SOURCES profile_gemm_multiply_multiply.cpp) # list(APPEND PROFILER_SOURCES profile_gemm_multiply_multiply.cpp)
list(APPEND PROFILER_SOURCES profile_gemm_multiply_multiply_weight_preshuffle.cpp)
# list(APPEND PROFILER_SOURCES profile_gemm_ab_scale.cpp) # list(APPEND PROFILER_SOURCES profile_gemm_ab_scale.cpp)
# endif() # endif()
# list(APPEND PROFILER_SOURCES profile_batched_gemm.cpp) # list(APPEND PROFILER_SOURCES profile_batched_gemm.cpp)
...@@ -136,7 +137,8 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") ...@@ -136,7 +137,8 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance) # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_batched_gemm_reduce_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance) # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_add_instance)
# if(SUPPORTED_GPU_TARGETS MATCHES "gfx94") # if(SUPPORTED_GPU_TARGETS MATCHES "gfx94")
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_multiply_instance) # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_multiply_instance)
target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_multiply_multiply_weight_preshuffle_instance)
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_ab_scale_instance) # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_ab_scale_instance)
# endif() # endif()
# target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_splitk_instance) # target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_gemm_splitk_instance)
......