"include/ck/utility/integral_constant.hpp" did not exist on "238d58c2f5947246a3e62f72db2b175b2e948554"
Commit 174b46b0 authored by coderfeli

add cpu shuffle

parent e6f5a78b
@@ -63,36 +63,38 @@ struct MultiplyMultiply
}
};
void reshapeBuffer(char* buffer, int N, int K, char* output) {
const int KRepeat = 2;
const int NRepeat = 3;
const int KLane = 4;
const int NLane = 5;
const int KPack = 6;
void preShuffleBuffer(const FP8* src, int N, int K, FP8* dst) {
const int NRepeat = 1;
const int KRepeat = 4;
const int KLane = 2;
const int NLane = 128;
const int KPack = 16;
int N0 = N / (NRepeat * NLane);
int K0 = K / (KRepeat * KLane * KPack);
int tempn, tempk;
for (int n = 0; n < N; ++n) {
for (int k = 0; k < K; ++k) {
int n0 = n / (NRepeat * NLane);
int k0 = k / (KRepeat * KLane * KPack);
int nRel = n % (NRepeat * NLane);
int kRel = k % (KRepeat * KLane * KPack);
int nIndex = nRel / NLane;
int kIndex = kRel / (KLane * KPack);
int nLaneIndex = nRel % NLane;
int kLaneIndex = (kRel % (KLane * KPack)) / KPack;
int kPackIndex = kRel % KPack;
int outputIndex = (n0 * K0 + k0) * KRepeat * NRepeat * KLane * NLane * KPack
+ nIndex * KRepeat * KLane * KPack
+ kIndex * KLane * KPack
+ nLaneIndex * KPack
+ kLaneIndex * KPack
+ kPackIndex;
output[outputIndex] = buffer[n * K + k];
tempn = n % (NRepeat * NLane);
tempk = k % (KRepeat * KLane * KPack);
int n1 = tempn / NLane;
int k1 = tempk / (KLane * KPack);
int n2 = tempn % NLane;
tempk = tempk % (KLane * KPack);
int k2 = tempk / KPack;
int k3 = tempk % KPack;
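// linearize as [N0, K0, NRepeat, KRepeat, KLane, NLane, KPack], KPack fastest-varying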
int outputIndex = n0 * KPack * NLane * KLane * KRepeat * NRepeat * K0
+ k0 * KPack * NLane * KLane * KRepeat * NRepeat
+ n1 * KPack * NLane * KLane * KRepeat
+ k1 * KPack * NLane * KLane
+ k2 * KPack * NLane
+ n2 * KPack
+ k3;
dst[outputIndex] = src[n * K + k];
}
}
}
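For reference, this is how the new routine is meant to be driven on the host. The snippet below is a hypothetical usage sketch, not part of the commit: FP8 is a one-byte stand-in for the example's f8 type, and the 128-wide tile sizes follow from the constants above (NRepeat * NLane = 128 along N, KRepeat * KLane * KPack = 128 along K), so N and K are assumed to be multiples of 128.

// Hypothetical usage sketch (not part of the commit).
#include <cassert>
#include <vector>

using FP8 = unsigned char;                                      // 1-byte stand-in for the example's f8 type
void preShuffleBuffer(const FP8* src, int N, int K, FP8* dst);  // as defined above

void shuffleB(int N, int K)
{
    assert(N % 128 == 0 && K % 128 == 0);                 // tile sizes implied by the constants above
    std::vector<FP8> b(static_cast<size_t>(N) * K);       // row-major [N, K] source weights
    std::vector<FP8> b_shuffled(b.size());                // same element count, blocked layout
    preShuffleBuffer(b.data(), N, K, b_shuffled.data());
    // b_shuffled now follows the [N0, K0, NRepeat, KRepeat, KLane, NLane, KPack]
    // order that the preshuffled-B GEMM kernel reads from global memory.
}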
@@ -191,6 +193,7 @@ int main(int argc, char* argv[])
Tensor<A0DataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{}));
Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
Tensor<B0DataType> b0_preshuffled(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); // use layout only for size
Tensor<D0DataType> d0_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{}));
Tensor<D1DataType> d1_m_n(f_host_tensor_descriptor(M, N, StrideD, D1Layout{}));
Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
@@ -217,15 +220,15 @@ int main(int argc, char* argv[])
d0_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{-0.5, 0.5});
d1_m_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{-0.5, 0.5});
}
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize());
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize());
DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize());
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
preShuffleBuffer(b0_k_n.mData.data(), N, K, b0_preshuffled.mData.data());
a0_device_buf.ToDevice(a0_m_k.mData.data());
b0_device_buf.ToDevice(b0_k_n.mData.data());
b0_device_buf.ToDevice(b0_preshuffled.mData.data());
d0_device_buf.ToDevice(d0_m_n.mData.data());
d1_device_buf.ToDevice(d1_m_n.mData.data());
e_device_buf.ToDevice(e_m_n_device_result.mData.data());
......
@@ -131,7 +131,7 @@ struct GeneratorTensor_2<ck::f8_t>
template <typename... Is>
ck::f8_t operator()(Is...)
{
float tmp = (std::rand() % (max_value - min_value)) + min_value;
float tmp = 1;
return ck::type_convert<ck::f8_t>(tmp);
}
};
......
@@ -281,7 +281,8 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
ABlockBuffer& a_block_buf0,
ABlockBuffer& a_block_buf1,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
@@ -306,7 +307,7 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// // Local prefill 1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf0);
// // Global prefetch 2
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
@@ -321,19 +322,11 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k0 * AMmaKStride>{}),
a_block_buf,
a_block_buf0,
a_thread_desc_,
make_tuple(m0, I0, k0, I0),
a_thread_buf);
});
// static_for<0, NRepeat, 1>{}([&](auto n0) {
// b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
// make_tuple(n0, I0, I0, Number<k0 * BMmaKStride>{}),
// b_block_buf,
// b_thread_desc_,
// make_tuple(n0, I0, k0, I0),
// b_thread_buf);
// });
});
__builtin_amdgcn_sched_barrier(0);
@@ -344,9 +337,7 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
index_t i = 0;
do
{
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf1);
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, Number<1>{});
@@ -364,8 +355,15 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
// if(threadIdx.x==0) {
// printf("%f, %f; ", type_convert<float>(a_thread_vec.template AsType<ComputeDataType>()(ik)), ype_convert<float>(b_thread_vec.template AsType<ComputeDataType>()(ik)));
// }
});
// if(threadIdx.x==0) {
// printf("\n");
// }
using mfma_input_type =
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
@@ -387,7 +385,7 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k0 * AMmaKStride>{}),
a_block_buf,
a_block_buf1,
a_thread_desc_,
make_tuple(m0, I0, k0, I0),
a_thread_buf);
@@ -397,10 +395,7 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf0);
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, Number<0>{});
@@ -441,7 +436,7 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k0 * AMmaKStride>{}),
a_block_buf,
a_block_buf0,
a_thread_desc_,
make_tuple(m0, I0, k0, I0),
a_thread_buf);
......
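The pipeline changes above turn the single a_block_buf into an a_block_buf0 / a_block_buf1 pair: the prologue fills buffer 0, the hot loop writes the next K tile into buffer 1 while the MFMAs read buffer 0, then the roles swap, which is why the block_sync_lds calls before the RunWrite are dropped. The snippet below is only a host-side analogy of that ping-pong pattern (illustrative C++, not the CK API): one buffer is consumed while the other is staged, so producer and consumer never touch the same storage within an iteration.

// Host-side analogy of the 2-LDS ping-pong pipeline (illustrative only).
#include <array>
#include <cstdio>
#include <numeric>

constexpr int TileSize = 4;
constexpr int NumTiles = 6;

static void stage(std::array<int, TileSize>& dst, int tile)    // plays the role of RunWrite (global -> LDS)
{
    std::iota(dst.begin(), dst.end(), tile * TileSize);
}

static int consume(const std::array<int, TileSize>& src)       // plays the role of the MFMA reads (LDS -> registers)
{
    return std::accumulate(src.begin(), src.end(), 0);
}

int main()
{
    std::array<int, TileSize> buf0{}, buf1{};
    stage(buf0, 0);                                  // prologue: prefill buffer 0
    int acc = 0;
    for(int t = 0; t < NumTiles; ++t)
    {
        auto& cur  = (t % 2 == 0) ? buf0 : buf1;     // buffer consumed this iteration
        auto& next = (t % 2 == 0) ? buf1 : buf0;     // buffer written this iteration
        if(t + 1 < NumTiles)
            stage(next, t + 1);                      // overlaps with the consumption of `cur`
        acc += consume(cur);
    }
    std::printf("sum over all tiles = %d\n", acc);   // 0 + 1 + ... + 23 = 276
    return 0;
}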
@@ -486,52 +486,52 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK<ALayo
// Tail number could be Odd or Even
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
if(arg.KBatch > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds<
GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds<
GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
// if(arg.KBatch > 1)
// {
// if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
// {
// const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds<
// GridwiseGemm,
// true,
// InMemoryDataOperationEnum::AtomicAdd,
// minimum_occupancy,
// TailNumber::Odd>;
// Run(kernel);
// }
// else
// {
// const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds<
// GridwiseGemm,
// true,
// InMemoryDataOperationEnum::AtomicAdd,
// minimum_occupancy,
// TailNumber::Even>;
// Run(kernel);
// }
// }
// else
// {
// if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
// {
// const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds<
// GridwiseGemm,
// true,
// InMemoryDataOperationEnum::Set,
// minimum_occupancy,
// TailNumber::Odd>;
// Run(kernel);
// }
// else
// {
// const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds<
// GridwiseGemm,
// true,
// InMemoryDataOperationEnum::Set,
// minimum_occupancy,
// TailNumber::Even>;
// Run(kernel);
// }
// }
}
else
{
......
@@ -40,6 +40,7 @@ __global__ void
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
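// second LDS chunk so the A pipeline can ping-pong writes and reads between two buffers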
__shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
@@ -49,42 +50,7 @@ __global__ void
karg.p_ds_grid,
karg.p_c_grid,
p_shared,
karg,
karg.a_element_op,
karg.b_element_op,
karg.c_element_op);
#else
ignore = karg;
#endif // end of if (defined(__gfx9__))
}
template <typename GridwiseGemm,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
index_t MinimumOccupancy = 1,
TailNumber TailNum = TailNumber::Full>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
// __attribute__((amdgpu_waves_per_eu(1, 1)))
kernel_gemm_xdl_cshuffle_v3_multi_d_2lds(typename GridwiseGemm::Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
// Passing two LDS pointers is the key to telling the compiler that ds_read/ds_write
// can operate on different LDS chunks at the same time without an ordering dependency
__shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
__shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_ds_grid,
karg.p_c_grid,
p_shared_0,
p_shared_1,
p_shared1,
karg,
karg.a_element_op,
karg.b_element_op,
@@ -1256,6 +1222,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
DsGridPointer& p_ds_grid,
CDataType* p_c_grid,
void* p_shared,
void* p_shared1,
const Problem& problem,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
@@ -1268,6 +1235,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
p_ds_grid,
p_c_grid,
p_shared,
p_shared1,
problem,
a_element_op,
b_element_op,
@@ -1284,6 +1252,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
DsGridPointer& p_ds_grid,
CDataType* p_c_grid,
void* p_shared,
void* p_shared1,
const Problem& problem,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
@@ -1409,6 +1378,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
// Cast after lds
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<LDSTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
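// second A LDS view, backed by p_shared1, used by the ping-pong hot loop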
auto a_block_buf1 = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<LDSTypeA*>(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<LDSTypeB*>(p_shared) +
@@ -1432,6 +1403,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_buf1,
a_block_slice_copy_step,
b_grid_desc_bpreshuffled,
b_block_desc_bk0_n_bk1,
......