Commit 54f44e62 authored by coderfeli's avatar coderfeli
Browse files

fix BRepeat, K-loop, and LDS double buffering; works correctly now

parent 2c056624
...@@ -24,6 +24,11 @@ ...@@ -24,6 +24,11 @@
template <ck::index_t... Is> template <ck::index_t... Is>
using S = ck::Sequence<Is...>; using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F32 = float;
// using I8 = int8_t;
// using I32 = int;
using F16 = ck::half_t; using F16 = ck::half_t;
using FP8 = ck::f8_t; using FP8 = ck::f8_t;
using F32 = float; using F32 = float;
...@@ -54,25 +59,139 @@ struct MultiplyMultiply ...@@ -54,25 +59,139 @@ struct MultiplyMultiply
operator()(E& e, const C& c, const D0& d0, const D1& d1) const; operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
template <> template <>
__host__ __device__ constexpr void operator()<ck::half_t, float, float, float>( __host__ __device__ constexpr void operator()<F16, float, float, float>(
ck::half_t& e, const float& c, const float& d0, const float& d1) const F16& e, const float& c, const float& d0, const float& d1) const
{ {
const float x0_f = c * d0 * d1; const float x0_f = c * d0 * d1;
e = ck::type_convert<F16>(x0_f);
}
template <>
__host__ __device__ constexpr void operator()<ck::half_t, int, float, float>(
ck::half_t& e, const int& c, const float& d0, const float& d1) const
{
const float x0_f =
ck::type_convert<float>(c) * ck::type_convert<float>(d0) * ck::type_convert<float>(d1);
e = ck::type_convert<ck::half_t>(x0_f); e = ck::type_convert<ck::half_t>(x0_f);
} }
}; };
// struct MultiplyMultiply
// {
// template <typename E, typename C, typename D0, typename D1>
// __host__ __device__ constexpr void
// operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
// template <>
// __host__ __device__ constexpr void operator()<ck::half_t, float, float, float>(
// ck::half_t& e, const float& c, const float& d0, const float& d1) const
// {
// const float x0_f = c * d0 * d1;
// e = ck::type_convert<ck::half_t>(x0_f);
// }
// template <>
// __host__ __device__ constexpr void operator()<ck::half_t, int, float, float>(
// ck::half_t& e, const int& c, const float& d0, const float& d1) const
// {
// const float x0_f =
// ck::type_convert<float>(c) * ck::type_convert<float>(d0) * ck::type_convert<float>(d1);
// e = ck::type_convert<ck::half_t>(x0_f);
// }
// template <>
// __host__ __device__ constexpr void operator()<ck::bhalf_t, int, float, float>(
// ck::bhalf_t& e, const int& c, const float& d0, const float& d1) const
// {
// const float x0_f =
// ck::type_convert<float>(c) * ck::type_convert<float>(d0) * ck::type_convert<float>(d1);
// e = ck::type_convert<ck::bhalf_t>(x0_f);
// }
// };
// void reinit2(FP8* dst, int N, int K) {
// for (int n = 0; n < N; ++n) {
// int kinit = 0;
// for (int k = 0; k < K; k+=1) {
// // dst[n * K + k] = n;
// if(k>0 && k%128==0){
// kinit += 1;
// }
// dst[n * K + k] = k % 128 + kinit;//rand() % 5 - 2;
// }
// }
// }
// void reinit(FP8* dst, int N, int K) {
// for (int n = 0; n < N; ++n) {
// for (int k = 0; k < K; k+=1) {
// dst[n * K + k] = ck::type_convert<FP8>(float(1));
// }
// }
// }
// Debug helper: print an N x K buffer of FP8 values as comma-separated
// floats (one row of K values per line). The buffer is read-only, so the
// pointer is const — callers passing a non-const FP8* are unaffected.
// NOTE(review): %.1f may round away low-order FP8 detail; bump precision
// if exact values matter when debugging.
void dump(const FP8* src, int N, int K) {
    for (int n = 0; n < N; ++n) {
        for (int k = 0; k < K; ++k) {
            // ck::type_convert widens FP8 -> float for printing.
            printf("%.1f,", ck::type_convert<float>(src[n * K + k]));
        }
        printf("\n");  // end of row n
    }
}
// void preShuffleBuffer(const FP8* src, int N, int K, FP8* dst) {
// const int NRepeat = 1;
// const int KRepeat = 8;
// const int NWave = 4;
// const int KLane = 2;
// const int NLane = 32;
// const int KPack = 16;
// int K0 = K / (KRepeat * KLane * KPack);
// int tempn, tempk;
// for (int n = 0; n < N; ++n) {
// for (int k = 0; k < K; ++k) {
// int n0 = n / (NRepeat * NLane * NWave);
// int k0 = k / (KRepeat * KLane * KPack);
// tempn = n % (NRepeat * NLane * NWave);
// tempk = k % (KRepeat * KLane * KPack);
// int n1 = tempn / (NLane * NWave);
// int k1 = tempk / (KLane * KPack);
// tempn = tempn % (NLane * NWave);
// tempk = tempk % (KLane * KPack);
// int n2 = tempn / NLane;
// int k2 = tempk / KPack;
// int n3 = tempn % NLane;
// int k3 = tempk % KPack;
// int outputIndex = n0 * KPack * NLane * KLane * NWave * KRepeat * NRepeat * K0
// + k0 * KPack * NLane * KLane * NWave * KRepeat * NRepeat
// + n1 * KPack * NLane * KLane * NWave * KRepeat
// + k1 * KPack * NLane * KLane * NWave
// + n2 * KPack * NLane * KLane
// + k2 * KPack * NLane
// + n3 * KPack
// + k3;
// dst[outputIndex] = src[n * K + k];
// }
// }
// }
void preShuffleBuffer(const FP8* src, int N, int K, FP8* dst) { void preShuffleBuffer(const FP8* src, int N, int K, FP8* dst) {
const int NRepeat = 1; const int NRepeat = 1;
const int KRepeat = 4; const int KRepeat = 8;
const int NWave = 4; const int NWave = 4;
const int KLane = 2; const int KLane = 2;
const int NLane = 32; const int NLane = 32;
const int KPack = 16; const int KPack = 16;
int N0 = N / (NRepeat * NLane * NWave);
int K0 = K / (KRepeat * KLane * KPack); int K0 = K / (KRepeat * KLane * KPack);
// K -> src: K0 KLane KRepeat KPack -> dst: K0 KRepeat KLane KPack, move KLane inner to make all lanes contiguous
// N -> N0 NRepeat NWave NLane // todo : is NRepeat outer or inner? now it's 1
int tempn, tempk; int tempn, tempk;
for (int n = 0; n < N; ++n) { for (int n = 0; n < N; ++n) {
for (int k = 0; k < K; ++k) { for (int k = 0; k < K; ++k) {
...@@ -80,21 +199,22 @@ void preShuffleBuffer(const FP8* src, int N, int K, FP8* dst) { ...@@ -80,21 +199,22 @@ void preShuffleBuffer(const FP8* src, int N, int K, FP8* dst) {
int k0 = k / (KRepeat * KLane * KPack); int k0 = k / (KRepeat * KLane * KPack);
tempn = n % (NRepeat * NLane * NWave); tempn = n % (NRepeat * NLane * NWave);
tempk = k % (KRepeat * KLane * KPack); tempk = k % (KRepeat * KLane * KPack);
int n1 = tempn / (NLane * NWave); int n1 = tempn / (NLane * NWave);
int k1 = tempk / (KLane * KPack); int k1 = tempk / (KRepeat * KPack); // Klane
tempn = tempn % (NLane * NWave); tempn = tempn % (NLane * NWave);
tempk = tempk % (KLane * KPack); tempk = tempk % (KRepeat * KPack);
int n2 = tempn / NLane; int n2 = tempn / NLane;
int k2 = tempk / KPack; int k2 = tempk / KPack; // KRepeat
int n3 = tempn % NLane; int n3 = tempn % NLane;
int k3 = tempk % KPack; int k3 = tempk % KPack; // Kpack
int outputIndex = n0 * KPack * NLane * KLane * NWave * KRepeat * NRepeat * K0 int outputIndex = n0 * KPack * NLane * KLane * NWave * KRepeat * NRepeat * K0
+ k0 * KPack * NLane * KLane * NWave * KRepeat * NRepeat + k0 * KPack * NLane * KLane * NWave * KRepeat * NRepeat
+ n1 * KPack * NLane * KLane * NWave * KRepeat + n1 * KPack * NLane * KLane * NWave * KRepeat
+ k1 * KPack * NLane * KLane * NWave + k2 * KPack * NLane * KLane * NWave //switch k1, k2
+ n2 * KPack * NLane * KLane + n2 * KPack * NLane * KLane
+ k2 * KPack * NLane + k1 * KPack * NLane
+ n3 * KPack + n3 * KPack
+ k3; + k3;
...@@ -102,7 +222,6 @@ void preShuffleBuffer(const FP8* src, int N, int K, FP8* dst) { ...@@ -102,7 +222,6 @@ void preShuffleBuffer(const FP8* src, int N, int K, FP8* dst) {
} }
} }
} }
using PassThrough = ck::tensor_operation::element_wise::PassThrough; using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AElementOp = PassThrough; using AElementOp = PassThrough;
...@@ -120,6 +239,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu ...@@ -120,6 +239,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
///###### RCR ///###### RCR
// kernel 1: 256->32x128x128 // kernel 1: 256->32x128x128
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>; // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>;
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
< Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>; < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
// kernel 2: 128->32x128x128 // kernel 2: 128->32x128x128
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>; // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>;
...@@ -215,8 +335,8 @@ int main(int argc, char* argv[]) ...@@ -215,8 +335,8 @@ int main(int argc, char* argv[])
case 1: case 1:
a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2}); a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{0, 2}); b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{0, 2});
d0_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{0, 2}); d0_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-2, 2});
d1_m_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{0, 2}); d1_m_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-2, 2});
break; break;
default: default:
a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0}); a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
...@@ -229,9 +349,12 @@ int main(int argc, char* argv[]) ...@@ -229,9 +349,12 @@ int main(int argc, char* argv[])
DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize()); DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize());
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize()); DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());
// reinit2(a0_m_k.mData.data(), M, K);
// reinit2(b0_k_n.mData.data(), N, K);
preShuffleBuffer(b0_k_n.mData.data(), N, K, b0_preshuffled.mData.data()); preShuffleBuffer(b0_k_n.mData.data(), N, K, b0_preshuffled.mData.data());
// dump(b0_preshuffled.mData.data(), N, K);
a0_device_buf.ToDevice(a0_m_k.mData.data()); a0_device_buf.ToDevice(a0_m_k.mData.data());
// b0_device_buf.ToDevice(b0_preshuffled.mData.data());
b0_device_buf.ToDevice(b0_preshuffled.mData.data()); b0_device_buf.ToDevice(b0_preshuffled.mData.data());
d0_device_buf.ToDevice(d0_m_n.mData.data()); d0_device_buf.ToDevice(d0_m_n.mData.data());
d1_device_buf.ToDevice(d1_m_n.mData.data()); d1_device_buf.ToDevice(d1_m_n.mData.data());
...@@ -273,7 +396,7 @@ int main(int argc, char* argv[]) ...@@ -273,7 +396,7 @@ int main(int argc, char* argv[])
"not support this GEMM problem"); "not support this GEMM problem");
} }
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 1}); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * M * N * K; std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = std::size_t num_btype =
...@@ -288,7 +411,7 @@ int main(int argc, char* argv[]) ...@@ -288,7 +411,7 @@ int main(int argc, char* argv[])
if(do_verification) if(do_verification)
{ {
invoker.Run(argument, StreamConfig{nullptr, false}); invoker.Run(argument, StreamConfig{nullptr, false, 0, 1, 1});
e_device_buf.FromDevice(e_m_n_device_result.mData.data()); e_device_buf.FromDevice(e_m_n_device_result.mData.data());
......
...@@ -328,7 +328,6 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave, ...@@ -328,7 +328,6 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
a_thread_buf); a_thread_buf);
}); });
}); });
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
// main body // main body
...@@ -355,15 +354,8 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave, ...@@ -355,15 +354,8 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
a_thread_vec.template AsType<ComputeDataType>()(ik) = a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset( a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}]; make_tuple(m0, I0, k0, ik))>{}];
// if(threadIdx.x==0) {
// printf("%f, %f; ", type_convert<float>(a_thread_vec.template AsType<ComputeDataType>()(ik)), type_convert<float>(b_thread_vec.template AsType<ComputeDataType>()(ik)));
// }
}); });
// if(threadIdx.x==0) {
// printf("\n");
// }
using mfma_input_type = using mfma_input_type =
typename vector_type<ComputeDataType, typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type; xdlops_gemm.K1PerXdlops>::type;
...@@ -442,20 +434,17 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave, ...@@ -442,20 +434,17 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
a_thread_buf); a_thread_buf);
}); });
}); });
HotLoopScheduler(); HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
i += 2; i += 2;
} while(i < (num_loop - 1)); } while(i < (num_loop - 2));
} }
// tail // tail
if constexpr(TailNum == TailNumber::Full) if constexpr(TailNum == TailNumber::Full)
{ {
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf1); a_blockwise_copy.RunWrite(a_block_desc, a_block_buf1);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, Number<1>{}); b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, Number<1>{});
static_for<0, KRepeat, 1>{}([&](auto k0) { static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
......
...@@ -130,10 +130,11 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 ...@@ -130,10 +130,11 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
static constexpr auto AK1Number = Number<AK1Value>{}; static constexpr auto AK1Number = Number<AK1Value>{};
static constexpr auto BK1Number = Number<BK1Value>{}; static constexpr auto BK1Number = Number<BK1Value>{};
static constexpr auto BlockSizeNumber = Number<BlockSize>{}; static constexpr auto BlockSizeNumber = Number<BlockSize>{};
static constexpr index_t NLane = 128; static constexpr index_t NLane = 32;
static constexpr index_t NWave = 4;
static constexpr index_t KLane = 2; static constexpr index_t KLane = 2;
static constexpr index_t KRepeat = 4; static constexpr index_t KRepeat = 8;
static_assert(NLane * KLane == BlockSize); static_assert(NLane * NWave * KLane == BlockSize);
static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr index_t NumDTensor = DsDataType::Size();
...@@ -173,11 +174,11 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 ...@@ -173,11 +174,11 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
__host__ __device__ static auto CalculateBN0Shuffled(index_t N) __host__ __device__ static auto CalculateBN0Shuffled(index_t N)
{ {
return math::integer_least_multiple(N, NLane); return math::integer_divide_ceil(N, NLane * NWave);
} }
__host__ __device__ static auto CalculateBK0Shuffled(index_t K, index_t KBatch) __host__ __device__ static auto CalculateBK0Shuffled(index_t K, index_t KBatch)
{ {
return math::integer_least_multiple(K, KLane * KPack * KBatch); return math::integer_divide_ceil(K, KLane * KPack * KBatch);
} }
__host__ __device__ static auto CalculateKPadded(index_t K) __host__ __device__ static auto CalculateKPadded(index_t K)
...@@ -1296,8 +1297,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3 ...@@ -1296,8 +1297,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
__builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
const index_t n_block_data_idx_on_grid = const index_t n_block_data_idx_on_grid =
__builtin_amdgcn_readfirstlane(block_n_id * NPerBlock) / NLane; __builtin_amdgcn_readfirstlane(block_n_id * (NPerBlock / NLane / NWave)) ;
// lds max alignment // lds max alignment
constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment