Commit 54f44e62 authored by coderfeli's avatar coderfeli
Browse files

fix brepeat, kloop and lds two buffer; works ok now

parent 2c056624
......@@ -24,6 +24,11 @@
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
// using I8 = int8_t;
// using I32 = int;
using F16 = ck::half_t;
using FP8 = ck::f8_t;
using F32 = float;
......@@ -54,25 +59,139 @@ struct MultiplyMultiply
operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
template <>
__host__ __device__ constexpr void operator()<ck::half_t, float, float, float>(
ck::half_t& e, const float& c, const float& d0, const float& d1) const
__host__ __device__ constexpr void operator()<F16, float, float, float>(
F16& e, const float& c, const float& d0, const float& d1) const
{
const float x0_f = c * d0 * d1;
e = ck::type_convert<F16>(x0_f);
}
template <>
__host__ __device__ constexpr void operator()<ck::half_t, int, float, float>(
ck::half_t& e, const int& c, const float& d0, const float& d1) const
{
const float x0_f =
ck::type_convert<float>(c) * ck::type_convert<float>(d0) * ck::type_convert<float>(d1);
e = ck::type_convert<ck::half_t>(x0_f);
}
};
// struct MultiplyMultiply
// {
// template <typename E, typename C, typename D0, typename D1>
// __host__ __device__ constexpr void
// operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
// template <>
// __host__ __device__ constexpr void operator()<ck::half_t, float, float, float>(
// ck::half_t& e, const float& c, const float& d0, const float& d1) const
// {
// const float x0_f = c * d0 * d1;
// e = ck::type_convert<ck::half_t>(x0_f);
// }
// template <>
// __host__ __device__ constexpr void operator()<ck::half_t, int, float, float>(
// ck::half_t& e, const int& c, const float& d0, const float& d1) const
// {
// const float x0_f =
// ck::type_convert<float>(c) * ck::type_convert<float>(d0) * ck::type_convert<float>(d1);
// e = ck::type_convert<ck::half_t>(x0_f);
// }
// template <>
// __host__ __device__ constexpr void operator()<ck::bhalf_t, int, float, float>(
// ck::bhalf_t& e, const int& c, const float& d0, const float& d1) const
// {
// const float x0_f =
// ck::type_convert<float>(c) * ck::type_convert<float>(d0) * ck::type_convert<float>(d1);
// e = ck::type_convert<ck::bhalf_t>(x0_f);
// }
// };
// void reinit2(FP8* dst, int N, int K) {
// for (int n = 0; n < N; ++n) {
// int kinit = 0;
// for (int k = 0; k < K; k+=1) {
// // dst[n * K + k] = n;
// if(k>0 && k%128==0){
// kinit += 1;
// }
// dst[n * K + k] = k % 128 + kinit;//rand() % 5 - 2;
// }
// }
// }
// void reinit(FP8* dst, int N, int K) {
// for (int n = 0; n < N; ++n) {
// for (int k = 0; k < K; k+=1) {
// dst[n * K + k] = ck::type_convert<FP8>(float(1));
// }
// }
// }
void dump(FP8* dst, int N, int K) {
for (int n = 0; n < N; ++n) {
for (int k = 0; k < K; ++k) {
printf("%.1f,", ck::type_convert<float>(dst[n * K + k]));
}
printf("\n");
}
}
// void preShuffleBuffer(const FP8* src, int N, int K, FP8* dst) {
// const int NRepeat = 1;
// const int KRepeat = 8;
// const int NWave = 4;
// const int KLane = 2;
// const int NLane = 32;
// const int KPack = 16;
// int K0 = K / (KRepeat * KLane * KPack);
// int tempn, tempk;
// for (int n = 0; n < N; ++n) {
// for (int k = 0; k < K; ++k) {
// int n0 = n / (NRepeat * NLane * NWave);
// int k0 = k / (KRepeat * KLane * KPack);
// tempn = n % (NRepeat * NLane * NWave);
// tempk = k % (KRepeat * KLane * KPack);
// int n1 = tempn / (NLane * NWave);
// int k1 = tempk / (KLane * KPack);
// tempn = tempn % (NLane * NWave);
// tempk = tempk % (KLane * KPack);
// int n2 = tempn / NLane;
// int k2 = tempk / KPack;
// int n3 = tempn % NLane;
// int k3 = tempk % KPack;
// int outputIndex = n0 * KPack * NLane * KLane * NWave * KRepeat * NRepeat * K0
// + k0 * KPack * NLane * KLane * NWave * KRepeat * NRepeat
// + n1 * KPack * NLane * KLane * NWave * KRepeat
// + k1 * KPack * NLane * KLane * NWave
// + n2 * KPack * NLane * KLane
// + k2 * KPack * NLane
// + n3 * KPack
// + k3;
// dst[outputIndex] = src[n * K + k];
// }
// }
// }
void preShuffleBuffer(const FP8* src, int N, int K, FP8* dst) {
const int NRepeat = 1;
const int KRepeat = 4;
const int KRepeat = 8;
const int NWave = 4;
const int KLane = 2;
const int NLane = 32;
const int KPack = 16;
int N0 = N / (NRepeat * NLane * NWave);
int K0 = K / (KRepeat * KLane * KPack);
// K -> src: K0 KLane KRepeat KPack -> dst: K0 KRpeat KLane KPack, move klane inner to make all lanes contiguous
// N -> N0 NRepeat NWave NLane // todo : is NRepeat outer or inner? now it's 1
int tempn, tempk;
for (int n = 0; n < N; ++n) {
for (int k = 0; k < K; ++k) {
......@@ -80,21 +199,22 @@ void preShuffleBuffer(const FP8* src, int N, int K, FP8* dst) {
int k0 = k / (KRepeat * KLane * KPack);
tempn = n % (NRepeat * NLane * NWave);
tempk = k % (KRepeat * KLane * KPack);
int n1 = tempn / (NLane * NWave);
int k1 = tempk / (KLane * KPack);
int k1 = tempk / (KRepeat * KPack); // Klane
tempn = tempn % (NLane * NWave);
tempk = tempk % (KLane * KPack);
tempk = tempk % (KRepeat * KPack);
int n2 = tempn / NLane;
int k2 = tempk / KPack;
int n3 = tempn % NLane;
int k3 = tempk % KPack;
int k2 = tempk / KPack; // KRepeat
int n3 = tempn % NLane;
int k3 = tempk % KPack; // Kpack
int outputIndex = n0 * KPack * NLane * KLane * NWave * KRepeat * NRepeat * K0
+ k0 * KPack * NLane * KLane * NWave * KRepeat * NRepeat
+ n1 * KPack * NLane * KLane * NWave * KRepeat
+ k1 * KPack * NLane * KLane * NWave
+ k2 * KPack * NLane * KLane * NWave //switch k1, k2
+ n2 * KPack * NLane * KLane
+ k2 * KPack * NLane
+ k1 * KPack * NLane
+ n3 * KPack
+ k3;
......@@ -102,7 +222,6 @@ void preShuffleBuffer(const FP8* src, int N, int K, FP8* dst) {
}
}
}
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AElementOp = PassThrough;
......@@ -120,6 +239,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
///###### RCR
// kernel 1: 256->32x128x128
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>;
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
< Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
// kernel 2: 128->32x128x128
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>;
......@@ -215,8 +335,8 @@ int main(int argc, char* argv[])
case 1:
a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{0, 2});
d0_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{0, 2});
d1_m_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{0, 2});
d0_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 2});
d1_m_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 2});
break;
default:
a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
......@@ -229,9 +349,12 @@ int main(int argc, char* argv[])
DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize());
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
// reinit2(a0_m_k.mData.data(), M, K);
// reinit2(b0_k_n.mData.data(), N, K);
preShuffleBuffer(b0_k_n.mData.data(), N, K, b0_preshuffled.mData.data());
// dump(b0_preshuffled.mData.data(), N, K);
a0_device_buf.ToDevice(a0_m_k.mData.data());
// b0_device_buf.ToDevice(b0_preshuffled.mData.data());
b0_device_buf.ToDevice(b0_preshuffled.mData.data());
d0_device_buf.ToDevice(d0_m_n.mData.data());
d1_device_buf.ToDevice(d1_m_n.mData.data());
......@@ -273,7 +396,7 @@ int main(int argc, char* argv[])
"not support this GEMM problem");
}
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 1});
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype =
......@@ -288,7 +411,7 @@ int main(int argc, char* argv[])
if(do_verification)
{
invoker.Run(argument, StreamConfig{nullptr, false});
invoker.Run(argument, StreamConfig{nullptr, false, 0, 1, 1});
e_device_buf.FromDevice(e_m_n_device_result.mData.data());
......
......@@ -328,7 +328,6 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
a_thread_buf);
});
});
__builtin_amdgcn_sched_barrier(0);
// main body
......@@ -355,15 +354,8 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
// if(threadIdx.x==0) {
// printf("%f, %f; ", type_convert<float>(a_thread_vec.template AsType<ComputeDataType>()(ik)), type_convert<float>(b_thread_vec.template AsType<ComputeDataType>()(ik)));
// }
});
// if(threadIdx.x==0) {
// printf("\n");
// }
using mfma_input_type =
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
......@@ -442,20 +434,17 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
a_thread_buf);
});
});
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
i += 2;
} while(i < (num_loop - 1));
} while(i < (num_loop - 2));
}
// tail
if constexpr(TailNum == TailNumber::Full)
{
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf1);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, Number<1>{});
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
......
......@@ -130,10 +130,11 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
static constexpr auto AK1Number = Number<AK1Value>{};
static constexpr auto BK1Number = Number<BK1Value>{};
static constexpr auto BlockSizeNumber = Number<BlockSize>{};
static constexpr index_t NLane = 128;
static constexpr index_t NLane = 32;
static constexpr index_t NWave = 4;
static constexpr index_t KLane = 2;
static constexpr index_t KRepeat = 4;
static_assert(NLane * KLane == BlockSize);
static constexpr index_t KRepeat = 8;
static_assert(NLane * NWave * KLane == BlockSize);
static constexpr index_t NumDTensor = DsDataType::Size();
......@@ -173,11 +174,11 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
__host__ __device__ static auto CalculateBN0Shuffled(index_t N)
{
return math::integer_least_multiple(N, NLane);
return math::integer_divide_ceil(N, NLane * NWave);
}
__host__ __device__ static auto CalculateBK0Shuffled(index_t K, index_t KBatch)
{
return math::integer_least_multiple(K, KLane * KPack * KBatch);
return math::integer_divide_ceil(K, KLane * KPack * KBatch);
}
__host__ __device__ static auto CalculateKPadded(index_t K)
......@@ -1296,8 +1297,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
__builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_n_id * NPerBlock) / NLane;
__builtin_amdgcn_readfirstlane(block_n_id * (NPerBlock / NLane / NWave)) ;
// lds max alignment
constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment