Commit 2c056624 authored by coderfeli

fix tail

parent 174b46b0
@@ -66,32 +66,36 @@ struct MultiplyMultiply
 void preShuffleBuffer(const FP8* src, int N, int K, FP8* dst) {
     const int NRepeat = 1;
     const int KRepeat = 4;
+    const int NWave = 4;
     const int KLane = 2;
-    const int NLane = 128;
+    const int NLane = 32;
     const int KPack = 16;
-    int N0 = N / (NRepeat * NLane);
+    int N0 = N / (NRepeat * NLane * NWave);
     int K0 = K / (KRepeat * KLane * KPack);
     int tempn, tempk;
     for (int n = 0; n < N; ++n) {
         for (int k = 0; k < K; ++k) {
-            int n0 = n / (NRepeat * NLane);
+            int n0 = n / (NRepeat * NLane * NWave);
             int k0 = k / (KRepeat * KLane * KPack);
-            tempn = n % (NRepeat * NLane);
+            tempn = n % (NRepeat * NLane * NWave);
             tempk = k % (KRepeat * KLane * KPack);
-            int n1 = tempn / NLane;
+            int n1 = tempn / (NLane * NWave);
             int k1 = tempk / (KLane * KPack);
-            int n2 = n1 % NLane;
+            tempn = tempn % (NLane * NWave);
             tempk = tempk % (KLane * KPack);
+            int n2 = tempn / NLane;
             int k2 = tempk / KPack;
+            int n3 = tempn % NLane;
             int k3 = tempk % KPack;
-            int outputIndex = n0 * KPack * NLane * KLane * KRepeat * NRepeat * K0
-                            + k0 * KPack * NLane * KLane * KRepeat * NRepeat
-                            + n1 * KPack * NLane * KLane * KRepeat
-                            + k1 * KPack * NLane * KLane
+            int outputIndex = n0 * KPack * NLane * KLane * NWave * KRepeat * NRepeat * K0
+                            + k0 * KPack * NLane * KLane * NWave * KRepeat * NRepeat
+                            + n1 * KPack * NLane * KLane * NWave * KRepeat
+                            + k1 * KPack * NLane * KLane * NWave
+                            + n2 * KPack * NLane * KLane
                             + k2 * KPack * NLane
-                            + n2 * KPack
+                            + n3 * KPack
                             + k3;
             dst[outputIndex] = src[n * K + k];
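Review note: the shuffle now splits n into four digits (n0 block, n1 repeat, n2 wave, n3 lane) instead of three, with NLane dropping from 128 to 32 and a new NWave = 4 factor. Since each digit's stride in outputIndex equals the product of the inner digits' ranges, the map should be a permutation of the N×K elements. A hypothetical standalone harness (not part of the commit; it only replays the index math, so no FP8 type is needed) can assert that:

```cpp
// Sanity check for the new preShuffleBuffer indexing. Sizes are the smallest
// multiples of the tile factors: N = NRepeat * NWave * NLane = 128,
// K = KRepeat * KLane * KPack = 128.
#include <cassert>
#include <cstdio>
#include <vector>

int main() {
    const int NRepeat = 1, KRepeat = 4, NWave = 4, KLane = 2, NLane = 32, KPack = 16;
    const int N  = NRepeat * NWave * NLane;
    const int K  = KRepeat * KLane * KPack;
    const int K0 = K / (KRepeat * KLane * KPack);
    std::vector<int> hits(N * K, 0);
    for (int n = 0; n < N; ++n) {
        for (int k = 0; k < K; ++k) {
            int n0 = n / (NRepeat * NLane * NWave);
            int k0 = k / (KRepeat * KLane * KPack);
            int tempn = n % (NRepeat * NLane * NWave);
            int tempk = k % (KRepeat * KLane * KPack);
            int n1 = tempn / (NLane * NWave); // repeat digit
            int k1 = tempk / (KLane * KPack);
            tempn = tempn % (NLane * NWave);
            tempk = tempk % (KLane * KPack);
            int n2 = tempn / NLane;           // wave digit
            int k2 = tempk / KPack;
            int n3 = tempn % NLane;           // lane digit
            int k3 = tempk % KPack;
            int outputIndex = n0 * KPack * NLane * KLane * NWave * KRepeat * NRepeat * K0
                            + k0 * KPack * NLane * KLane * NWave * KRepeat * NRepeat
                            + n1 * KPack * NLane * KLane * NWave * KRepeat
                            + k1 * KPack * NLane * KLane * NWave
                            + n2 * KPack * NLane * KLane
                            + k2 * KPack * NLane
                            + n3 * KPack
                            + k3;
            assert(0 <= outputIndex && outputIndex < N * K);
            ++hits[outputIndex];
        }
    }
    for (int i = 0; i < N * K; ++i)
        assert(hits[i] == 1); // every destination slot written exactly once
    std::puts("preShuffleBuffer index map is a bijection");
}
```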
@@ -269,7 +273,7 @@ int main(int argc, char* argv[])
                         "not support this GEMM problem");
     }
-    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 20, 50});
+    float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 1});
     std::size_t flop = std::size_t(2) * M * N * K;
     std::size_t num_btype =
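Review note: judging by the two trailing integers, the StreamConfig change presumably drops the warm-up and timed-repeat counts from 20/50 to 0/1 (the exact field order in this fork's StreamConfig is an assumption here), trading measurement quality for debug turnaround. The trade-off is the usual one for kernel timing, sketched with plain std::chrono in place of CK's HIP-event-based invoker:

```cpp
// Generic warm-up/repeat timing pattern (illustration only, not the CK API).
// Warm-up runs stabilize clocks and caches before the timed repeats are
// averaged; {0, 1} is fast but noisy, {20, 50} is slow but stable.
#include <chrono>
#include <cstdio>
#include <functional>

float time_kernel_ms(const std::function<void()>& kernel, int cold_iters, int n_repeat) {
    for (int i = 0; i < cold_iters; ++i) kernel(); // warm-up, not timed
    auto t0 = std::chrono::steady_clock::now();
    for (int i = 0; i < n_repeat; ++i) kernel();   // timed repeats
    auto t1 = std::chrono::steady_clock::now();
    return std::chrono::duration<float, std::milli>(t1 - t0).count() / n_repeat;
}

int main() {
    volatile long sink = 0;
    auto dummy = [&] { for (long i = 0; i < 1000000; ++i) sink = sink + i; };
    std::printf("benchmark-grade: %.3f ms/iter\n", time_kernel_ms(dummy, 20, 50));
    std::printf("debug-grade:     %.3f ms/iter\n", time_kernel_ms(dummy, 0, 1));
}
```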
@@ -357,7 +357,7 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
                                 make_tuple(m0, I0, k0, ik))>{}];
                         // if(threadIdx.x==0) {
-                        //     printf("%f, %f; ", type_convert<float>(a_thread_vec.template AsType<ComputeDataType>()(ik)), ype_convert<float>(b_thread_vec.template AsType<ComputeDataType>()(ik)));
+                        //     printf("%f, %f; ", type_convert<float>(a_thread_vec.template AsType<ComputeDataType>()(ik)), type_convert<float>(b_thread_vec.template AsType<ComputeDataType>()(ik)));
                         // }
                     });
@@ -451,6 +451,11 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
         // tail
         if constexpr(TailNum == TailNumber::Full)
         {
+            a_blockwise_copy.RunWrite(a_block_desc, a_block_buf1);
+            b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, Number<1>{});
             static_for<0, KRepeat, 1>{}([&](auto k0) {
                 static_for<0, MRepeat, 1>{}([&](auto m0) {
                     static_for<0, NRepeat, 1>{}([&](auto n0) {
@@ -462,6 +467,48 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
                             a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                 make_tuple(m0, I0, k0, ik))>{}];
                         });
                         using mfma_input_type =
                             typename vector_type<ComputeDataType,
                                                  xdlops_gemm.K1PerXdlops>::type;
+                        constexpr index_t c_offset =
+                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
+                        xdlops_gemm.Run(
+                            a_thread_vec.template AsType<mfma_input_type>(),
+                            b_thread_vec.template AsType<mfma_input_type>(),
+                            c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                    });
+                });
+            });
+            block_sync_lds();
+            static_for<0, KRepeat, 1>{}([&](auto k0) {
+                static_for<0, MRepeat, 1>{}([&](auto m0) {
+                    a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
+                                       make_tuple(m0, I0, I0, Number<k0 * AMmaKStride>{}),
+                                       a_block_buf1,
+                                       a_thread_desc_,
+                                       make_tuple(m0, I0, k0, I0),
+                                       a_thread_buf);
+                });
+            });
+            __builtin_amdgcn_sched_barrier(0);
+            static_for<0, KRepeat, 1>{}([&](auto k0) {
+                static_for<0, MRepeat, 1>{}([&](auto m0) {
+                    static_for<0, NRepeat, 1>{}([&](auto n0) {
+                        vector_type<ComputeDataType, KPack> a_thread_vec;
+                        vector_type<ComputeDataType, KPack> b_thread_vec =
+                            b_blockwise_copy.template GetSrcThreadScratchIdx<Sequence<0, k0, 0>, Number<1>{}>();
+                        static_for<0, KPack, 1>{}([&](auto ik) {
+                            a_thread_vec.template AsType<ComputeDataType>()(ik) =
+                                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
+                                    make_tuple(m0, I0, k0, ik))>{}];
+                        });
+                        using mfma_input_type =
+                            typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
......
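Review note on the tail fix (the last two hunks): the Full tail now drains both stages of the double buffer. The final A block is staged into a_block_buf1 and the matching B block fetched into scratch slot Number<1>{} while the already-loaded registers are multiplied out; after block_sync_lds() and the a_thread_copy_ pass from a_block_buf1, a second MFMA round consumes that final block. A toy model of this schedule (hypothetical plain C++, with scalar accumulation standing in for the MFMA loop nest):

```cpp
// Two-buffer pipeline drain, modeled on the fixed Full-tail path.
#include <cstdio>
#include <vector>

int main() {
    const int KBlocks = 5, KPerBlock = 4;
    std::vector<float> a(KBlocks * KPerBlock, 1.0f);
    float lds[2][KPerBlock]; // the two LDS buffers (a_block_buf / a_block_buf1 analogue)
    float acc = 0.0f;
    auto stage = [&](int buf, int blk) {  // RunWrite/RunRead analogue
        for (int i = 0; i < KPerBlock; ++i) lds[buf][i] = a[blk * KPerBlock + i];
    };
    auto compute = [&](int buf) {         // xdlops_gemm.Run loop-nest analogue
        for (int i = 0; i < KPerBlock; ++i) acc += lds[buf][i];
    };
    int cur = 0;
    stage(cur, 0);                        // prologue: prefetch block 0
    for (int b = 0; b + 2 < KBlocks; ++b) { // hot loop: all but the last two blocks
        stage(cur ^ 1, b + 1);
        compute(cur);
        cur ^= 1;                         // block_sync_lds() + buffer swap analogue
    }
    // Full tail after the fix: stage the final block into the second buffer,
    // multiply out the current one, synchronize/swap, then multiply out the
    // final one. Before the fix, this last staged block was the part at risk
    // of being skipped.
    stage(cur ^ 1, KBlocks - 1);
    compute(cur);
    cur ^= 1;                             // sync + a_thread_copy_ from buf1 analogue
    compute(cur);
    std::printf("acc = %.1f (expected %d)\n", acc, KBlocks * KPerBlock);
}
```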