Commit 1857761a authored by ozturkosu

Merge branch 'muozturk_sk_padding' of https://github.com/ROCm/composable_kernel into muozturk_sk_padding
parents 4c64fa6d 715ffa67
@@ -138,6 +138,7 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
if(stream_config.log_level_ > 0)
{
arg.Print();
+GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print();
}
if(!GridwiseGemm::CheckValidity(arg))
@@ -733,7 +734,9 @@ struct DeviceGemm_Xdl_CShuffleV3 : public DeviceGemmV2<ALayout,
<< "BlkGemmPipelineVersion: "
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
<< "BlkGemmPipelinePrefetchStages: "
-<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages;
+<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages << ", "
+<< "Kpack: "
+<< GridwiseGemm::BlockwiseGemmPipe::AMmaKStride;
// clang-format on
return str.str();
...
@@ -224,12 +224,12 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
}();
// Pad both M and K to be multiples of the block sizes
const auto a_grid_desc_m_k =
    transform_tensor_descriptor(a_grid_desc_mraw_kraw,
                                make_tuple(make_right_pad_transform(M, MPad - M),
                                           make_right_pad_transform(K, KPad - K)),
                                make_tuple(Sequence<0>{}, Sequence<1>{}),
                                make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
a_grid_desc_m_k,
@@ -322,14 +322,14 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
}
}();
// Pad both N and K to be multiples of the block sizes
const auto b_grid_desc_n_k =
    transform_tensor_descriptor(b_grid_desc_nraw_kraw,
                                make_tuple(make_right_pad_transform(N, NPad - N),
                                           make_right_pad_transform(K, KPad - K)),
                                make_tuple(Sequence<0>{}, Sequence<1>{}),
                                make_tuple(Sequence<0>{}, Sequence<1>{}));
const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
b_grid_desc_n_k,
@@ -990,7 +990,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) &&
!(is_same<tensor_layout::gemm::RowMajor, ALayout>::value))
{
if(!(karg.M % MPerBlock == 0))
@@ -1008,7 +1008,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) &&
(is_same<tensor_layout::gemm::RowMajor, BLayout>::value))
{
if(!(karg.N % NPerBlock == 0))
@@ -1075,7 +1075,6 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
<< ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
@@ -1093,9 +1092,9 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
std::cout << "Arg N (" << karg.N
<< ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
<< BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
return false;
}
}
@@ -1110,7 +1109,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
<< BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__ << std::endl;
}
return false;
}
}
@@ -1128,7 +1127,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
}
return false;
}
}
@@ -1145,7 +1144,7 @@ struct GridwiseGemm_xdl_cshuffle_streamk_v3
<< __FILE__ << ":" << __LINE__ << ", in function: " << __func__
<< std::endl;
}
return false;
}
}
...
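// Standalone illustration (not part of the diff above): the padding hunks round M, N and K up
// to the per-block tile sizes before make_right_pad_transform is applied. The tile sizes and
// problem sizes below are assumed example values, not taken from this commit.
#include <cstdio>

constexpr int pad_to_multiple(int dim, int per_block)
{
    // Same rounding that produces the MPad/NPad/KPad values fed into the right-pad transforms.
    return ((dim + per_block - 1) / per_block) * per_block;
}

int main()
{
    constexpr int MPerBlock = 128, NPerBlock = 128, KPerBlock = 64; // assumed tile sizes
    const int M = 1000, N = 1000, K = 520;                          // arbitrary raw sizes
    std::printf("MPad %d NPad %d KPad %d\n",
                pad_to_multiple(M, MPerBlock),  // 1024
                pad_to_multiple(N, NPerBlock),  // 1024
                pad_to_multiple(K, KPerBlock)); // 576
    return 0;
}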
@@ -37,7 +37,7 @@ __global__ void
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
-auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
+auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
@@ -70,7 +70,7 @@ __global__ void
__shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
__shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
-auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
+auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
@@ -638,45 +638,45 @@ struct GridwiseGemm_xdl_cshuffle_v3
struct SplitKBatchOffset
{
-__device__ SplitKBatchOffset(Argument& karg)
+__device__ SplitKBatchOffset(Argument& karg, index_t k_id)
{
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
-a_k_split_offset = blockIdx.z * karg.KRead / APackedSize;
+a_k_split_offset = k_id * karg.KRead / APackedSize;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
-a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA;
+a_k_split_offset = k_id * karg.KRead * karg.StrideA;
}
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
-b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB;
+b_k_split_offset = k_id * karg.KRead * karg.StrideB;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
if constexpr(!PermuteB)
{
-b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize;
+b_k_split_offset = k_id * karg.KRead / BPackedSize;
}
else
{
const int k0_offset = karg.KRead * karg.N;
-b_k_split_offset = blockIdx.z * k0_offset / BPackedSize;
+b_k_split_offset = k_id * k0_offset / BPackedSize;
}
}
// Calculate B scale offset
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
-scale_k_split_offset = blockIdx.z * (karg.KRead / ScaleBlockK) * karg.StrideB;
+scale_k_split_offset = k_id * (karg.KRead / ScaleBlockK) * karg.StrideB;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
-scale_k_split_offset = blockIdx.z * (karg.KRead / ScaleBlockK);
+scale_k_split_offset = k_id * (karg.KRead / ScaleBlockK);
}
-if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
+if(k_id < (karg.KBatch - 1))
{
karg.K = karg.KRead;
}
@@ -687,7 +687,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
if(karg.IsReduceAdd())
{
-c_reduce_offset = blockIdx.z * karg.M * karg.N;
+c_reduce_offset = k_id * karg.M * karg.N;
}
else
{
...
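// Standalone illustration (not part of the diff above): the offset arithmetic done by
// SplitKBatchOffset for each k_id, shown for a row-major A, a non-permuted column-major B and a
// row-major B. KRead, StrideB and the packed-size divisions are simplified; all values below are
// assumed examples.
#include <cstdio>

int main()
{
    const int K = 4096, KBatch = 4, KRead = K / KBatch; // each batch reads KRead of the K dim
    const int N = 2048, StrideB = N;                    // assumed stride for the row-major-B case

    for(int k_id = 0; k_id < KBatch; ++k_id)
    {
        const long a_off_rowmajor = static_cast<long>(k_id) * KRead;           // k_id * KRead
        const long b_off_colmajor = static_cast<long>(k_id) * KRead;           // !PermuteB case
        const long b_off_rowmajor = static_cast<long>(k_id) * KRead * StrideB; // row-major B
        // Batches before the last one see exactly KRead (the if(k_id < KBatch - 1) branch above).
        std::printf("k_id %d: a_off %ld, b_off(col) %ld, b_off(row) %ld, K seen %d\n",
                    k_id, a_off_rowmajor, b_off_colmajor, b_off_rowmajor, KRead);
    }
    return 0;
}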
@@ -90,14 +90,22 @@ struct BlockwiseGemmXdlops_pipeline_hotloop_inst
KPerXDL);
printf(" A/B buffer load inst: %d, %d\n A/B LDS write inst: %d, %d\n A/B LDS read inst: "
-"%d, %d\n C MFMA inst: %d\n",
+"%d, %d\n C MFMA inst: %d\n"
+"A/B LDS read width: %d, %d, A/B LDS write width: %d, %d, A/B buffer load width: "
+"%d/ %d\n",
A_Buffer_Load_Inst_Num,
B_Buffer_Load_Inst_Num,
A_LDS_Write_Inst_Num,
B_LDS_Write_Inst_Num,
A_LDS_Read_Inst_Num,
B_LDS_Read_Inst_Num,
-C_MFMA_Inst_Num);
+C_MFMA_Inst_Num,
+A_LDS_Read_Width,
+B_LDS_Read_Width,
+ALDSWriteWidth,
+BLDSWriteWidth,
+ABufferLoadWidth,
+BBufferLoadWidth);
}
};
...
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/numeric/vector_type.hpp"
@@ -8,16 +8,75 @@
namespace ck_tile {
-CK_TILE_HOST_DEVICE bf16_t add_bf16_t(const bf16_t& a, const bf16_t& b)
+template <typename T, typename ComputeType>
+CK_TILE_HOST_DEVICE T add(const T& a, const T& b)
{
-return type_convert<bf16_t>(type_convert<float>(a) + type_convert<float>(b));
+return type_convert<T>(type_convert<ComputeType>(a) + type_convert<ComputeType>(b));
}
CK_TILE_HOST_DEVICE bf16x2_t add_bf16x2_t(const bf16x2_t& a, const bf16x2_t& b)
{
bf16x2_t rtn;
-rtn[0] = add_bf16_t(a[0], b[0]);
-rtn[1] = add_bf16_t(a[1], b[1]);
+rtn[0] = add<bf16_t, float>(a[0], b[0]);
+rtn[1] = add<bf16_t, float>(a[1], b[1]);
return rtn;
}
CK_TILE_HOST_DEVICE bf16x4_t add_bf16x4_t(const bf16x4_t& a, const bf16x4_t& b)
{
bf16x4_t rtn;
rtn[0] = add<bf16_t, float>(a[0], b[0]);
rtn[1] = add<bf16_t, float>(a[1], b[1]);
rtn[2] = add<bf16_t, float>(a[2], b[2]);
rtn[3] = add<bf16_t, float>(a[3], b[3]);
return rtn;
}
CK_TILE_HOST_DEVICE fp8x4_t add_fp8x4_t(const fp8x4_t& a, const fp8x4_t& b)
{
fp8x4_t rtn;
rtn[0] = add<fp8_t, float>(a[0], b[0]);
rtn[1] = add<fp8_t, float>(a[1], b[1]);
rtn[2] = add<fp8_t, float>(a[2], b[2]);
rtn[3] = add<fp8_t, float>(a[3], b[3]);
return rtn;
}
CK_TILE_HOST_DEVICE fp8x8_t add_fp8x8_t(const fp8x8_t& a, const fp8x8_t& b)
{
fp8x8_t rtn;
rtn[0] = add<fp8_t, float>(a[0], b[0]);
rtn[1] = add<fp8_t, float>(a[1], b[1]);
rtn[2] = add<fp8_t, float>(a[2], b[2]);
rtn[3] = add<fp8_t, float>(a[3], b[3]);
rtn[4] = add<fp8_t, float>(a[4], b[4]);
rtn[5] = add<fp8_t, float>(a[5], b[5]);
rtn[6] = add<fp8_t, float>(a[6], b[6]);
rtn[7] = add<fp8_t, float>(a[7], b[7]);
return rtn;
}
CK_TILE_HOST_DEVICE bf8x4_t add_bf8x4_t(const bf8x4_t& a, const bf8x4_t& b)
{
bf8x4_t rtn;
rtn[0] = add<bf8_t, float>(a[0], b[0]);
rtn[1] = add<bf8_t, float>(a[1], b[1]);
rtn[2] = add<bf8_t, float>(a[2], b[2]);
rtn[3] = add<bf8_t, float>(a[3], b[3]);
return rtn;
}
CK_TILE_HOST_DEVICE bf8x8_t add_bf8x8_t(const bf8x8_t& a, const bf8x8_t& b)
{
bf8x8_t rtn;
rtn[0] = add<bf8_t, float>(a[0], b[0]);
rtn[1] = add<bf8_t, float>(a[1], b[1]);
rtn[2] = add<bf8_t, float>(a[2], b[2]);
rtn[3] = add<bf8_t, float>(a[3], b[3]);
rtn[4] = add<bf8_t, float>(a[4], b[4]);
rtn[5] = add<bf8_t, float>(a[5], b[5]);
rtn[6] = add<bf8_t, float>(a[6], b[6]);
rtn[7] = add<bf8_t, float>(a[7], b[7]);
return rtn;
}
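// Standalone illustration (not part of the diff above): every add_*x*_t helper follows the same
// convert-to-ComputeType, add, convert-back pattern as the templated add<T, ComputeType>. A
// host-only generalization over std::array shows that pattern in one place.
#include <array>
#include <cstdio>

template <typename T, typename ComputeType, std::size_t N>
std::array<T, N> add_lanes(const std::array<T, N>& a, const std::array<T, N>& b)
{
    std::array<T, N> r{};
    for(std::size_t i = 0; i < N; ++i)
    {
        // Widen each lane, add, then narrow back -- mirrors add<bf16_t, float>(a[i], b[i]).
        r[i] = static_cast<T>(static_cast<ComputeType>(a[i]) + static_cast<ComputeType>(b[i]));
    }
    return r;
}

int main()
{
    const std::array<float, 4> a{1.f, 2.f, 3.f, 4.f}, b{0.5f, 0.5f, 0.5f, 0.5f};
    const auto c = add_lanes<float, double, 4>(a, b);
    std::printf("%g %g %g %g\n", c[0], c[1], c[2], c[3]);
    return 0;
}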
@@ -59,6 +118,192 @@ CK_TILE_DEVICE void atomic_add<bf16x2_t>(bf16x2_t* p_dst, const bf16x2_t& x)
} while(cur_v.u32 != old_v);
}
template <>
CK_TILE_DEVICE void atomic_add<bf16x4_t>(bf16x4_t* p_dst, bf16x4_t const& x)
{
// Union to treat the pointer as either bf16x4_t* or uint64_t*:
union U64BF164_ADDR
{
uint64_t* u64_a;
bf16x4_t* bf164_a;
};
// Union to treat the data as either bf16x4_t or 64-bit integer
union U64BF164
{
uint64_t u64;
bf16x4_t bf164;
};
U64BF164_ADDR addr;
addr.bf164_a = p_dst; // interpret p_dst as a 64-bit location
// First read (non-atomic) of the old value
U64BF164 cur_v;
cur_v.u64 = *addr.u64_a;
U64BF164 new_v_union;
uint64_t old_v, new_v;
do
{
// old 64 bits
old_v = cur_v.u64;
// Add elementwise in bf16
new_v_union.bf164 = add_bf16x4_t(cur_v.bf164, x);
new_v = new_v_union.u64;
// Attempt the 64-bit CAS
cur_v.u64 = atomicCAS(addr.u64_a, old_v, new_v);
} while(cur_v.u64 != old_v);
}
template <>
CK_TILE_DEVICE void atomic_add<fp8x4_t>(fp8x4_t* p_dst, const fp8x4_t& x)
{
union U32FP84_ADDR
{
uint32_t* u32_a;
fp8x4_t* fp84_a;
};
union U32FP84
{
uint32_t u32;
fp8x4_t fp84;
};
U32FP84_ADDR dword_addr;
U32FP84 cur_v;
U32FP84 new_;
uint32_t old_v, new_v;
dword_addr.fp84_a = p_dst;
cur_v.u32 = *dword_addr.u32_a;
do
{
old_v = cur_v.u32;
new_.fp84 = add_fp8x4_t(cur_v.fp84, x);
new_v = new_.u32;
cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
} while(cur_v.u32 != old_v);
}
template <>
CK_TILE_DEVICE void atomic_add<bf8x4_t>(bf8x4_t* p_dst, const bf8x4_t& x)
{
union U32BF84_ADDR
{
uint32_t* u32_a;
bf8x4_t* bf84_a;
};
union U32BF84
{
uint32_t u32;
bf8x4_t bf84;
};
U32BF84_ADDR dword_addr;
U32BF84 cur_v;
U32BF84 new_;
uint32_t old_v, new_v;
dword_addr.bf84_a = p_dst;
cur_v.u32 = *dword_addr.u32_a;
do
{
old_v = cur_v.u32;
new_.bf84 = add_bf8x4_t(cur_v.bf84, x);
new_v = new_.u32;
cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
} while(cur_v.u32 != old_v);
}
//
// Atomic add for fp8x8_t
//
template <>
CK_TILE_DEVICE void atomic_add<fp8x8_t>(fp8x8_t* p_dst, fp8x8_t const& x)
{
// Union for addressing 64 bits as either "fp8x8_t" or a 64-bit integer.
union U64FP88_ADDR
{
uint64_t* u64_a; // pointer to 64-bit integer
fp8x8_t* fp88_a; // pointer to fp8x8_t
};
union U64FP88
{
uint64_t u64;
fp8x8_t fp88;
};
U64FP88_ADDR dword_addr;
U64FP88 cur_v;
U64FP88 new_v_union;
uint64_t old_v, new_v;
// Point to the destination as both fp8x8_t* and uint64_t*.
dword_addr.fp88_a = p_dst;
// Initial read of 64 bits from memory
cur_v.u64 = *dword_addr.u64_a;
do
{
old_v = cur_v.u64;
// Add each fp8 element using your add_fp8x8_t(...) routine
new_v_union.fp88 = add_fp8x8_t(cur_v.fp88, x);
new_v = new_v_union.u64;
// Attempt 64-bit CAS
cur_v.u64 = atomicCAS(dword_addr.u64_a, old_v, new_v);
} while(cur_v.u64 != old_v);
}
//
// Atomic add for bf8x8_t
//
template <>
CK_TILE_DEVICE void atomic_add<bf8x8_t>(bf8x8_t* p_dst, bf8x8_t const& x)
{
union U64BF88_ADDR
{
uint64_t* u64_a;
bf8x8_t* bf88_a;
};
union U64BF88
{
uint64_t u64;
bf8x8_t bf88;
};
U64BF88_ADDR dword_addr;
U64BF88 cur_v;
U64BF88 new_v_union;
uint64_t old_v, new_v;
dword_addr.bf88_a = p_dst;
// Read the original 64 bits
cur_v.u64 = *dword_addr.u64_a;
do
{
old_v = cur_v.u64;
// Add each bf8 element using your add_bf8x8_t(...) routine
new_v_union.bf88 = add_bf8x8_t(cur_v.bf88, x);
new_v = new_v_union.u64;
// 64-bit CAS loop
cur_v.u64 = atomicCAS(dword_addr.u64_a, old_v, new_v);
} while(cur_v.u64 != old_v);
}
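// Standalone illustration (not part of the diff above): each specialization emulates a packed
// atomic add with a read / modify / compare-and-swap loop over the enclosing 32- or 64-bit word.
// The host-side sketch below shows the same pattern with std::atomic for two packed 16-bit lanes;
// the device code instead relies on atomicCAS and unions for the type punning.
#include <atomic>
#include <cstdint>
#include <cstdio>

void atomic_add_packed_u16x2(std::atomic<uint32_t>& dst, uint16_t x0, uint16_t x1)
{
    uint32_t old_v = dst.load(std::memory_order_relaxed);
    uint32_t new_v;
    do
    {
        // Build the new packed word from the last observed value.
        const uint16_t lo = static_cast<uint16_t>((old_v & 0xFFFFu) + x0);
        const uint16_t hi = static_cast<uint16_t>((old_v >> 16) + x1);
        new_v = (static_cast<uint32_t>(hi) << 16) | lo;
        // On failure old_v is refreshed with the current value, like re-reading atomicCAS's return.
    } while(!dst.compare_exchange_weak(old_v, new_v, std::memory_order_relaxed));
}

int main()
{
    std::atomic<uint32_t> word{0};
    atomic_add_packed_u16x2(word, 3, 5);
    atomic_add_packed_u16x2(word, 1, 1);
    std::printf("lo=%u hi=%u\n", word.load() & 0xFFFFu, word.load() >> 16); // lo=4 hi=6
    return 0;
}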
template <typename T, index_t N>
CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
{
@@ -66,8 +311,10 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
(std::is_same<T, uint32_t>::value && (N == 1)) ||
(std::is_same<T, float>::value && (N == 1 || N == 2)) ||
(std::is_same<T, double>::value && (N == 1 || N == 2)) ||
-(std::is_same<T, bf16_t>::value && (N == 2 || N == 4)),
-"wrong! not implemented");
+(std::is_same<T, bf16_t>::value && (N == 2 || N == 4 || N == 8)) ||
+(std::is_same<T, fp8_t>::value && (N == 4 || N == 8 || N == 16)) ||
+(std::is_same<T, bf8_t>::value && (N == 4 || N == 8 || N == 16)),
+"The granularity of the thread buffer is unsupported on the hardware!");
constexpr auto I0 = number<0>{};
constexpr auto I1 = number<1>{};
@@ -118,9 +365,45 @@ CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
}
else if constexpr(N == 4)
{
-atomic_add(c_style_pointer_cast<bf16x2_t*>(p_dst), x.template get_as<bf16x2_t>()[I0]);
-atomic_add(c_style_pointer_cast<bf16x2_t*>(p_dst) + 1,
-x.template get_as<bf16x2_t>()[I1]);
+atomic_add(c_style_pointer_cast<bf16x4_t*>(p_dst), x.template get_as<bf16x4_t>()[I0]);
+}
+else if constexpr(N == 8)
{
atomic_add(c_style_pointer_cast<bf16x4_t*>(p_dst), x.template get_as<bf16x4_t>()[I0]);
atomic_add(c_style_pointer_cast<bf16x4_t*>(p_dst) + 1,
x.template get_as<bf16x4_t>()[I1]);
}
}
else if constexpr(std::is_same<T, fp8_t>::value)
{
if constexpr(N == 4)
{
atomic_add(c_style_pointer_cast<fp8x4_t*>(p_dst), x.template get_as<fp8x4_t>()[I0]);
}
if constexpr(N == 8)
{
atomic_add(c_style_pointer_cast<fp8x8_t*>(p_dst), x.template get_as<fp8x8_t>()[I0]);
}
if constexpr(N == 16)
{
atomic_add(c_style_pointer_cast<fp8x8_t*>(p_dst), x.template get_as<fp8x8_t>()[I0]);
atomic_add(c_style_pointer_cast<fp8x8_t*>(p_dst) + 1, x.template get_as<fp8x8_t>()[I1]);
}
}
else if constexpr(std::is_same<T, bf8_t>::value)
{
if constexpr(N == 4)
{
atomic_add(c_style_pointer_cast<bf8x4_t*>(p_dst), x.template get_as<bf8x4_t>()[I0]);
}
if constexpr(N == 8)
{
atomic_add(c_style_pointer_cast<bf8x8_t*>(p_dst), x.template get_as<bf8x8_t>()[I0]);
}
if constexpr(N == 16)
{
atomic_add(c_style_pointer_cast<bf8x8_t*>(p_dst), x.template get_as<bf8x8_t>()[I0]);
atomic_add(c_style_pointer_cast<bf8x8_t*>(p_dst) + 1, x.template get_as<bf8x8_t>()[I1]);
}
}
}
...
@@ -20,6 +20,7 @@
#include "ck_tile/host/reference/reference_batched_masking.hpp"
#include "ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp"
#include "ck_tile/host/reference/reference_batched_softmax.hpp"
+#include "ck_tile/host/reference/reference_batched_transpose.hpp"
#include "ck_tile/host/reference/reference_elementwise.hpp"
#include "ck_tile/host/reference/reference_fused_moe.hpp"
#include "ck_tile/host/reference/reference_gemm.hpp"
@@ -34,4 +35,3 @@
#include "ck_tile/host/reference/reference_topk.hpp"
#include "ck_tile/host/stream_config.hpp"
#include "ck_tile/host/timer.hpp"
-#include "ck_tile/host/reference/reference_batched_transpose.hpp"
@@ -22,13 +22,14 @@ template <typename ComputeDataType, typename OutDataType, typename AccDataType =
double get_relative_threshold(const int number_of_accumulations = 1)
{
using F8 = ck_tile::fp8_t;
+using BF8 = ck_tile::bf8_t;
using F16 = ck_tile::half_t;
using BF16 = ck_tile::bf16_t;
using F32 = float;
using I8 = int8_t;
using I32 = int32_t;
-static_assert(is_any_of<ComputeDataType, F8, F16, BF16, F32, I8, I32, int>::value,
+static_assert(is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
"Warning: Unhandled ComputeDataType for setting up the relative threshold!");
double compute_error = 0;
@@ -41,7 +42,7 @@ double get_relative_threshold(const int number_of_accumulations = 1)
compute_error = std::pow(2, -numeric_traits<ComputeDataType>::mant) * 0.5;
}
-static_assert(is_any_of<OutDataType, F8, F16, BF16, F32, I8, I32, int>::value,
+static_assert(is_any_of<OutDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
"Warning: Unhandled OutDataType for setting up the relative threshold!");
double output_error = 0;
@@ -55,7 +56,7 @@ double get_relative_threshold(const int number_of_accumulations = 1)
}
double midway_error = std::max(compute_error, output_error);
-static_assert(is_any_of<AccDataType, F8, F16, BF16, F32, I8, I32, int>::value,
+static_assert(is_any_of<AccDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
"Warning: Unhandled AccDataType for setting up the relative threshold!");
double acc_error = 0;
@@ -74,13 +75,14 @@ template <typename ComputeDataType, typename OutDataType, typename AccDataType =
double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations = 1)
{
using F8 = ck_tile::fp8_t;
+using BF8 = ck_tile::bf8_t;
using F16 = ck_tile::half_t;
using BF16 = ck_tile::bf16_t;
using F32 = float;
using I8 = int8_t;
using I32 = int32_t;
-static_assert(is_any_of<ComputeDataType, F8, F16, BF16, F32, I8, I32, int>::value,
+static_assert(is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
"Warning: Unhandled ComputeDataType for setting up the absolute threshold!");
auto expo = std::log2(std::abs(max_possible_num));
@@ -94,7 +96,7 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
compute_error = std::pow(2, expo - numeric_traits<ComputeDataType>::mant) * 0.5;
}
-static_assert(is_any_of<OutDataType, F8, F16, BF16, F32, I8, I32, int>::value,
+static_assert(is_any_of<OutDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
"Warning: Unhandled OutDataType for setting up the absolute threshold!");
double output_error = 0;
@@ -108,7 +110,7 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
}
double midway_error = std::max(compute_error, output_error);
-static_assert(is_any_of<AccDataType, F8, F16, BF16, F32, I8, I32, int>::value,
+static_assert(is_any_of<AccDataType, F8, BF8, F16, BF16, F32, I8, I32, int>::value,
"Warning: Unhandled AccDataType for setting up the absolute threshold!");
double acc_error = 0;
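// Standalone illustration (not part of the diff above): the relative threshold used here is half
// a unit in the last place, pow(2, -mant) * 0.5. The mantissa widths below are the usual
// fp16 / bf16 / fp8(e4m3) / bf8(e5m2) values and are assumptions of this sketch, not values read
// from numeric_traits.
#include <cmath>
#include <cstdio>

int main()
{
    const struct { const char* name; int mant; } types[] = {
        {"fp16", 10}, {"bf16", 7}, {"fp8", 3}, {"bf8", 2}};
    for(const auto& t : types)
    {
        const double rel = std::pow(2.0, -t.mant) * 0.5; // same formula as compute_error above
        std::printf("%-4s relative threshold = %g\n", t.name, rel);
    }
    return 0;
}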
@@ -501,7 +503,11 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
}
if(!res)
{
-std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err << std::endl;
+const float error_percent =
+static_cast<float>(err_count) / static_cast<float>(out.size()) * 100.f;
+std::cerr << "max err: " << max_err;
+std::cerr << ", number of errors: " << err_count;
+std::cerr << ", " << error_percent << "% wrong values" << std::endl;
}
return res;
}
...
@@ -80,13 +80,14 @@ __global__ void naive_gemm_kernel(ADataType* A,
int b_index = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
? col * strideB + k
: k * strideB + col;
-acc += static_cast<AccDataType>(A[a_index]) * static_cast<AccDataType>(B[b_index]);
+acc += ck_tile::type_convert<AccDataType>(A[a_index]) *
+ck_tile::type_convert<AccDataType>(B[b_index]);
}
int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
? row * strideC + col
: col * strideC + row;
-C[c_index] = acc;
+C[c_index] = ck_tile::type_convert<CDataType>(acc);
}
}
...
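// Standalone illustration (not part of the diff above): the same indexing convention as
// naive_gemm_kernel, written as a host loop. Row-major A, B and C and float/double types are
// assumed purely for brevity; the kernel itself is templated on layouts and data types.
#include <cstdio>
#include <vector>

void naive_gemm_host(const std::vector<float>& A, const std::vector<float>& B,
                     std::vector<float>& C, int M, int N, int K)
{
    for(int row = 0; row < M; ++row)
        for(int col = 0; col < N; ++col)
        {
            double acc = 0.0; // accumulate in a wider type, convert once when storing C
            for(int k = 0; k < K; ++k)
                acc += static_cast<double>(A[row * K + k]) * static_cast<double>(B[k * N + col]);
            C[row * N + col] = static_cast<float>(acc);
        }
}

int main()
{
    const int M = 2, N = 2, K = 3;
    const std::vector<float> A{1, 2, 3, 4, 5, 6}, B{1, 0, 0, 1, 1, 1};
    std::vector<float> C(M * N);
    naive_gemm_host(A, B, C, M, N, K);
    std::printf("%g %g\n%g %g\n", C[0], C[1], C[2], C[3]); // 4 5 / 10 11
    return 0;
}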
// SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
@@ -77,6 +77,7 @@ struct CShuffleEpilogue
*
* @return The vector store size for C tensor.
*/
+template <typename ODataType>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
{
constexpr index_t MaxVectorStoreSize = 16;
@@ -142,7 +143,7 @@ struct CShuffleEpilogue
TileDistributionEncodingPattern2D<kBlockSize,
kMPerIteration,
kNPerIteration,
-GetVectorSizeC(),
+GetVectorSizeC<ODataType>(),
tile_distribution_pattern::thread_raked>;
constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution();
...
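// Standalone illustration (not part of the diff above): making GetVectorSizeC a template lets the
// store width follow the output element size. The sketch assumes only the 16-byte
// MaxVectorStoreSize shown above; the real epilogue applies further constraints from the tile shape.
#include <cstdio>

constexpr int vector_size_c(int bytes_per_element, int max_vector_store_bytes = 16)
{
    return max_vector_store_bytes / bytes_per_element; // elements per 16-byte store
}

int main()
{
    std::printf("fp32: %d, fp16/bf16: %d, fp8/bf8: %d elements per store\n",
                vector_size_c(4), vector_size_c(2), vector_size_c(1)); // 4, 8, 16
    return 0;
}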
@@ -79,7 +79,10 @@ struct BlockUniversalGemmAsBsCr
// TODO: Should we have two policies? Interwave & Intrawave ??
static constexpr index_t InterWaveSchedulingMacClusters = 1;
-static constexpr index_t KPack = WarpGemm::kKPerThread;
+// should be at least equal to: WarpGemm::Impl::kABKPerLane
+// and the question is how to assess upper limit or exact value?
+// TODO: Should we introduce AK1/BK1 parameters ?
+static constexpr index_t KPack = 8;
static constexpr index_t KPerThread = KIterPerWarp * KPack;
static constexpr index_t KRepeat = KPerThread / KPack;
};
...
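// Standalone illustration (not part of the diff above): with KPack now hard-coded, the two derived
// constants reduce as below; KIterPerWarp is an assumed example value. The comment in the diff
// notes KPack should be at least WarpGemm::Impl::kABKPerLane, which a static_assert could enforce.
#include <cstdio>

int main()
{
    constexpr int KPack        = 8; // value hard-coded in the diff above
    constexpr int KIterPerWarp = 4; // assumed example value
    constexpr int KPerThread   = KIterPerWarp * KPack;
    constexpr int KRepeat      = KPerThread / KPack;
    static_assert(KRepeat == KIterPerWarp, "KRepeat collapses back to KIterPerWarp");
    std::printf("KPerThread=%d KRepeat=%d\n", KPerThread, KRepeat);
    return 0;
}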
@@ -159,7 +159,7 @@ struct GemmKernel
CK_TILE_HOST static bool IsSupportedArgument(const GemmKernelArgs& kargs)
{
-if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
+if constexpr(EpiloguePipeline::template GetVectorSizeC<CDataType>() % 2 != 0 &&
is_any_of<CDataType, fp16_t, bf16_t>::value)
{
if(kargs.k_batch != 1)
@@ -240,7 +240,7 @@ struct GemmKernel
<< std::endl;
return false;
}
-if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
+if(kargs.N % EpiloguePipeline::template GetVectorSizeC<CDataType>() != 0)
{
std::cerr << "N is not a multiple of vector load size for C tensor!" << std::endl;
return false;
@@ -255,7 +255,7 @@ struct GemmKernel
<< std::endl;
return false;
}
-if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
+if(kargs.M % EpiloguePipeline::template GetVectorSizeC<CDataType>() != 0)
{
std::cerr << "M is not a multiple of vector load size for C tensor!" << std::endl;
return false;
@@ -321,7 +321,7 @@ struct GemmKernel
c_ptr,
make_tuple(kargs.M, kargs.N),
make_tuple(kargs.stride_C, 1),
-number<EpiloguePipeline::GetVectorSizeC()>{},
+number<EpiloguePipeline::template GetVectorSizeC<CDataType>()>{},
number<1>{});
}
else
@@ -519,7 +519,7 @@ struct GemmKernel
{
// Do not compile in case where we have unsupported
// VectorSizeC & data type configuration.
-if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
+if constexpr(!(EpiloguePipeline::template GetVectorSizeC<CDataType>() % 2 != 0 &&
is_any_of<CDataType, fp16_t, bf16_t>::value))
{
RunGemm<memory_operation_enum::atomic_add>(
...
@@ -3,6 +3,9 @@
#pragma once
+#include <string>
+#include <sstream>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
@@ -83,6 +86,56 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
return Policy::template GetSmemSize<Problem>();
}
CK_TILE_HOST static std::string Print()
{
constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM;
constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN;
constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK;
constexpr index_t WaveSize = 64;
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
// Below should be equal to AK1|BK1
constexpr index_t A_LDS_Read_Width = Policy::template GetSmemPackA<Problem>();
constexpr index_t B_LDS_Read_Width = Policy::template GetSmemPackB<Problem>();
constexpr index_t A_LDS_Write_Width = Policy::template GetSmemPackA<Problem>();
constexpr index_t B_LDS_Write_Width = Policy::template GetSmemPackB<Problem>();
constexpr index_t A_Buffer_Load_Inst_Num =
MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA());
constexpr index_t B_Buffer_Load_Inst_Num =
NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB());
constexpr index_t A_LDS_Write_Inst_Num =
MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width);
constexpr index_t B_LDS_Write_Inst_Num =
NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width);
constexpr index_t A_LDS_Read_Inst_Num =
WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width);
constexpr index_t B_LDS_Read_Inst_Num =
WaveNumM * MPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width);
constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
(BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL);
auto str = std::stringstream{};
str << "A/B vector size: " << GetVectorSizeA() << ", " << GetVectorSizeB() << "\n"
<< "A/B LDS read/write width: " << A_LDS_Read_Width << ", " << B_LDS_Read_Width << "\n"
<< "A/B buffer load inst: " << A_Buffer_Load_Inst_Num << ", " << B_Buffer_Load_Inst_Num
<< "\n"
<< "A/B LDS write inst: " << A_LDS_Write_Inst_Num << ", " << B_LDS_Write_Inst_Num
<< "\n"
<< "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n"
<< "C MFMA inst: " << C_MFMA_Inst_Num << "\n"
<< "KPack: " << BlockGemm::Traits::KPack << "\n"
<< "PrefetchStages: " << PrefetchStages << "\n";
return str.str();
}
template <GemmPipelineScheduler Scheduler>
struct PipelineImpl : public PipelineImplBase
{
@@ -95,29 +148,35 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
{
-constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(I0{});
-constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(I1{});
-constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(I2{});
+constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM;
+constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN;
+constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK;
constexpr index_t WaveSize = 64;
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
-constexpr index_t A_LDS_Read_Width = KPerXDL;
-constexpr index_t B_LDS_Read_Width = KPerXDL;
+// Below should be equal to AK1|BK1
+constexpr index_t A_LDS_Read_Width = Policy::template GetSmemPackA<Problem>();
+constexpr index_t B_LDS_Read_Width = Policy::template GetSmemPackB<Problem>();
+constexpr index_t A_LDS_Write_Width = Policy::template GetSmemPackA<Problem>();
+constexpr index_t B_LDS_Write_Width = Policy::template GetSmemPackB<Problem>();
constexpr index_t A_Buffer_Load_Inst_Num =
MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA());
constexpr index_t B_Buffer_Load_Inst_Num =
NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB());
-constexpr index_t A_LDS_Write_Inst_Num = MPerBlock * KPerBlock / (BlockSize * KPerXDL);
-constexpr index_t B_LDS_Write_Inst_Num = NPerBlock * KPerBlock / (BlockSize * KPerXDL);
+constexpr index_t A_LDS_Write_Inst_Num =
+MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width);
+constexpr index_t B_LDS_Write_Inst_Num =
+NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width);
constexpr index_t A_LDS_Read_Inst_Num =
-WaveNumN * MPerBlock * KPerBlock / (BlockSize * KPerXDL);
+WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width);
constexpr index_t B_LDS_Read_Inst_Num =
-WaveNumM * MPerBlock * KPerBlock / (BlockSize * KPerXDL);
+WaveNumM * MPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width);
constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
(BlockSize / WaveSize) /
...
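// Standalone illustration (not part of the diff above): one concrete evaluation of the
// instruction-count formulas used by Print() and HotLoopScheduler. All tile parameters below are
// assumed example values, not the configuration of any instance in this commit.
#include <cstdio>

int main()
{
    constexpr int BlockSize = 256, WaveSize = 64;
    constexpr int MPerBlock = 128, NPerBlock = 128, KPerBlock = 64;
    constexpr int MPerXDL = 32, NPerXDL = 32, KPerXDL = 16;
    constexpr int WaveNumM = 2, WaveNumN = 2;
    constexpr int VectorSizeA = 8, VectorSizeB = 8;
    constexpr int A_LDS_RW = 8, B_LDS_RW = 8; // LDS read/write widths ("should equal AK1|BK1")

    constexpr int A_Buffer_Load = MPerBlock * KPerBlock / (BlockSize * VectorSizeA); // 4
    constexpr int B_Buffer_Load = NPerBlock * KPerBlock / (BlockSize * VectorSizeB); // 4
    constexpr int A_LDS_Write   = MPerBlock * KPerBlock / (BlockSize * A_LDS_RW);    // 4
    constexpr int B_LDS_Write   = NPerBlock * KPerBlock / (BlockSize * B_LDS_RW);    // 4
    constexpr int A_LDS_Read    = WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_RW); // 8
    constexpr int B_LDS_Read    = WaveNumM * MPerBlock * KPerBlock / (BlockSize * B_LDS_RW); // 8 (formula as printed above)
    constexpr int C_MFMA        = MPerBlock * NPerBlock * KPerBlock / (BlockSize / WaveSize)
                                  / (MPerXDL * NPerXDL * KPerXDL); // 16

    std::printf("buffer load A/B: %d %d, LDS write: %d %d, LDS read: %d %d, MFMA: %d\n",
                A_Buffer_Load, B_Buffer_Load, A_LDS_Write, B_LDS_Write, A_LDS_Read, B_LDS_Read,
                C_MFMA);
    return 0;
}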
@@ -185,7 +185,6 @@ struct UniversalGemmPipelineAgBgCrPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
template <typename ADataType,
typename BDataType,
typename CDataType,
typename AccDataType,
typename ScaleDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA>
struct ReferenceMXGemm : public device::BaseOperator
{
// Argument
struct Argument : public device::BaseArgument
{
Argument(const Tensor<ADataType>& a_m_k,
const Tensor<ScaleDataType>& a_m_kblock_scales,
const Tensor<BDataType>& b_k_n,
const Tensor<ScaleDataType>& b_kblock_n_scales,
Tensor<CDataType>& c_m_n,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
: a_m_k_{a_m_k},
a_m_kblock_scales_{a_m_kblock_scales},
b_k_n_{b_k_n},
b_kblock_n_scales_{b_kblock_n_scales},
c_m_n_{c_m_n},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op}
{
}
const Tensor<ADataType>& a_m_k_;
const Tensor<ScaleDataType>& a_m_kblock_scales_;
const Tensor<BDataType>& b_k_n_;
const Tensor<ScaleDataType>& b_kblock_n_scales_;
Tensor<CDataType>& c_m_n_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
};
// Invoker
struct Invoker : public device::BaseInvoker
{
using Argument = ReferenceMXGemm::Argument;
float Run(const Argument& arg)
{
using GemmInstance = ck::tensor_operation::host::ReferenceGemm<ComputeTypeA,
ComputeTypeB,
CDataType,
AccDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
ComputeTypeA,
ComputeTypeB>;
Tensor<ComputeTypeA> a_m_k_scaled(arg.a_m_k_.mDesc);
Tensor<ComputeTypeB> b_k_n_scaled(arg.b_k_n_.mDesc);
const auto M = arg.a_m_k_.mDesc.GetLengths()[0];
const auto N = arg.b_k_n_.mDesc.GetLengths()[1];
const auto K = arg.a_m_k_.mDesc.GetLengths()[1];
const auto SCALE_BLOCK = K / arg.a_m_kblock_scales_.mDesc.GetLengths()[1];
for(size_t m = 0; m < M; m++)
{
for(size_t k = 0; k < K; k++)
{
a_m_k_scaled(m, k) =
type_convert<ComputeTypeA>(arg.a_m_k_(m, k)) *
type_convert<ComputeTypeA>(arg.a_m_kblock_scales_(m, k / SCALE_BLOCK));
}
}
for(size_t n = 0; n < N; n++)
{
for(size_t k = 0; k < K; k++)
{
b_k_n_scaled(k, n) =
type_convert<ComputeTypeB>(arg.b_k_n_(k, n)) *
type_convert<ComputeTypeB>(arg.b_kblock_n_scales_(k / SCALE_BLOCK, n));
}
}
auto ref_gemm = GemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(a_m_k_scaled,
b_k_n_scaled,
arg.c_m_n_,
arg.a_element_op_,
arg.b_element_op_,
arg.c_element_op_);
ref_invoker.Run(ref_argument);
return 0;
}
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
static auto MakeArgument(const Tensor<ADataType>& a_m_k,
const Tensor<ScaleDataType>& a_m_kblock_scales,
const Tensor<BDataType>& b_k_n,
const Tensor<ScaleDataType>& b_kblock_n_scales,
Tensor<CDataType>& c_m_n,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
return Argument{a_m_k,
a_m_kblock_scales,
b_k_n,
b_kblock_n_scales,
c_m_n,
a_element_op,
b_element_op,
c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceMXGemm"
<< std::endl;
// clang-format on
return str.str();
}
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
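// Standalone illustration (not part of the diff above): ReferenceMXGemm dequantizes A and B with
// one scale per SCALE_BLOCK elements along K and then runs a plain GEMM. The restatement below
// uses plain floats and assumed example shapes.
#include <cstdio>
#include <vector>

int main()
{
    const int M = 2, N = 2, K = 4, SCALE_BLOCK = 2; // K / SCALE_BLOCK scales per row / column

    const std::vector<float> a(M * K, 1.0f);                   // a_m_k, row-major
    const std::vector<float> a_scales{2.0f, 0.5f, 1.0f, 4.0f}; // a_m_kblock_scales, row-major
    const std::vector<float> b(K * N, 1.0f);                   // b_k_n, row-major
    const std::vector<float> b_scales{1.0f, 1.0f, 3.0f, 3.0f}; // b_kblock_n_scales, row-major

    std::vector<float> c(M * N, 0.0f);
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
            for(int k = 0; k < K; ++k)
            {
                // Same per-element scaling as a_m_k_scaled / b_k_n_scaled in the reference above.
                const float a_scaled = a[m * K + k] * a_scales[m * (K / SCALE_BLOCK) + k / SCALE_BLOCK];
                const float b_scaled = b[k * N + n] * b_scales[(k / SCALE_BLOCK) * N + n];
                c[m * N + n] += a_scaled * b_scaled;
            }

    std::printf("%g %g\n%g %g\n", c[0], c[1], c[2], c[3]); // 7 7 / 26 26
    return 0;
}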
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl_fpAintB_b_scale.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include <memory>
#include <vector>
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
#if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_FP8))
void add_device_batched_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmV2BScale<Row,
Col,
Row,
F16,
I4,
F16,
F16,
1,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances);
#endif
template <typename ADataType,
typename BDataType,
typename BScaleDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout,
index_t ScaleBlockK>
struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceBatchedGemmV2BScale<
ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
BScaleDataType,
CDataType,
1,
ScaleBlockK,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>>
{
using DeviceOp = DeviceBatchedGemmV2BScale<ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
BScaleDataType,
CDataType,
1,
ScaleBlockK,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, pk_i4_t> &&
is_same_v<CDataType, half_t>)
{
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
is_same_v<CLayout, Row>)
{
add_device_batched_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instances(
op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
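// Usage sketch (not part of the diff above; it assumes the CK headers included by this file are
// available, plus <cstdio>): querying the instances registered through the factory specialization
// defined here. GetTypeString() is assumed to come from the device operator base class.
#include <cstdio>

inline void list_batched_gemm_b_scale_instances()
{
    using namespace ck::tensor_operation;
    using DeviceOp = device::DeviceBatchedGemmV2BScale<ck::tensor_layout::gemm::RowMajor,
                                                       ck::tensor_layout::gemm::ColumnMajor,
                                                       ck::tensor_layout::gemm::RowMajor,
                                                       ck::half_t,  // ADataType
                                                       ck::pk_i4_t, // BDataType
                                                       ck::half_t,  // BScaleDataType
                                                       ck::half_t,  // CDataType
                                                       1,           // fixed to 1 in this header
                                                       128,         // ScaleBlockK
                                                       element_wise::PassThrough,
                                                       element_wise::PassThrough,
                                                       element_wise::PassThrough>;

    const auto op_ptrs = device::instance::DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
    std::printf("found %zu batched B-scale GEMM instance(s)\n", op_ptrs.size());
    for(const auto& op : op_ptrs)
        std::printf("  %s\n", op->GetTypeString().c_str()); // assumed BaseOperator API
}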
# ONLY XDL_KERNELS
set(BATCHED_GEMM_B_SCALE_INSTANCES)
list(APPEND BATCHED_GEMM_B_SCALE_INSTANCES
device_batched_gemm_b_scale_xdl_f16_i4_f16/device_batched_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp
)
set_source_files_properties(device_batched_gemm_b_scale_xdl_f16_i4_f16/device_batched_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
add_instance_library(device_batched_gemm_b_scale_instance ${BATCHED_GEMM_B_SCALE_INSTANCES})
\ No newline at end of file
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "device_batched_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_batched_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instances(
std::vector<std::unique_ptr<DeviceBatchedGemmV2BScale<Row,
Col,
Row,
F16,
I4,
F16,
F16,
1,
128,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances,
device_batched_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_instances<Intrawave,
GemmDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck