"...resnet50_tensorflow.git" did not exist on "25efe03e7a59ce66e7da9c6823364d16ea2ca8de"
Commit 2c056624 authored by coderfeli
Browse files

fix tail

parent 174b46b0
...@@ -66,32 +66,36 @@ struct MultiplyMultiply ...@@ -66,32 +66,36 @@ struct MultiplyMultiply
void preShuffleBuffer(const FP8* src, int N, int K, FP8* dst) { void preShuffleBuffer(const FP8* src, int N, int K, FP8* dst) {
const int NRepeat = 1; const int NRepeat = 1;
const int KRepeat = 4; const int KRepeat = 4;
const int NWave = 4;
const int KLane = 2; const int KLane = 2;
const int NLane = 128; const int NLane = 32;
const int KPack = 16; const int KPack = 16;
int N0 = N / (NRepeat * NLane); int N0 = N / (NRepeat * NLane * NWave);
int K0 = K / (KRepeat * KLane * KPack); int K0 = K / (KRepeat * KLane * KPack);
int tempn, tempk; int tempn, tempk;
for (int n = 0; n < N; ++n) { for (int n = 0; n < N; ++n) {
for (int k = 0; k < K; ++k) { for (int k = 0; k < K; ++k) {
int n0 = n / (NRepeat * NLane); int n0 = n / (NRepeat * NLane * NWave);
int k0 = k / (KRepeat * KLane * KPack); int k0 = k / (KRepeat * KLane * KPack);
tempn = n % (NRepeat * NLane); tempn = n % (NRepeat * NLane * NWave);
tempk = k % (KRepeat * KLane * KPack); tempk = k % (KRepeat * KLane * KPack);
int n1 = tempn / NLane; int n1 = tempn / (NLane * NWave);
int k1 = tempk / (KLane * KPack); int k1 = tempk / (KLane * KPack);
int n2 = n1 % NLane; tempn = tempn % (NLane * NWave);
tempk = tempk % (KLane * KPack); tempk = tempk % (KLane * KPack);
int n2 = tempn / NLane;
int k2 = tempk / KPack; int k2 = tempk / KPack;
int n3 = tempn % NLane;
int k3 = tempk % KPack; int k3 = tempk % KPack;
int outputIndex = n0 * KPack * NLane * KLane * KRepeat * NRepeat * K0 int outputIndex = n0 * KPack * NLane * KLane * NWave * KRepeat * NRepeat * K0
+ k0 * KPack * NLane * KLane * KRepeat * NRepeat + k0 * KPack * NLane * KLane * NWave * KRepeat * NRepeat
+ n1 * KPack * NLane * KLane * KRepeat + n1 * KPack * NLane * KLane * NWave * KRepeat
+ k1 * KPack * NLane * KLane + k1 * KPack * NLane * KLane * NWave
+ n2 * KPack * NLane * KLane
+ k2 * KPack * NLane + k2 * KPack * NLane
+ n2 * KPack + n3 * KPack
+ k3; + k3;
dst[outputIndex] = src[n * K + k]; dst[outputIndex] = src[n * K + k];
...@@ -269,7 +273,7 @@ int main(int argc, char* argv[]) ...@@ -269,7 +273,7 @@ int main(int argc, char* argv[])
"not support this GEMM problem"); "not support this GEMM problem");
} }
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 20, 50}); float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 1});
std::size_t flop = std::size_t(2) * M * N * K; std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_btype = std::size_t num_btype =
......
...@@ -357,7 +357,7 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave, ...@@ -357,7 +357,7 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
make_tuple(m0, I0, k0, ik))>{}]; make_tuple(m0, I0, k0, ik))>{}];
// if(threadIdx.x==0) { // if(threadIdx.x==0) {
// printf("%f, %f; ", type_convert<float>(a_thread_vec.template AsType<ComputeDataType>()(ik)), ype_convert<float>(b_thread_vec.template AsType<ComputeDataType>()(ik))); // printf("%f, %f; ", type_convert<float>(a_thread_vec.template AsType<ComputeDataType>()(ik)), type_convert<float>(b_thread_vec.template AsType<ComputeDataType>()(ik)));
// } // }
}); });
...@@ -451,6 +451,11 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave, ...@@ -451,6 +451,11 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
// tail // tail
if constexpr(TailNum == TailNumber::Full) if constexpr(TailNum == TailNumber::Full)
{ {
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf1);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, Number<1>{});
static_for<0, KRepeat, 1>{}([&](auto k0) { static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, NRepeat, 1>{}([&](auto n0) {
...@@ -462,6 +467,48 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave, ...@@ -462,6 +467,48 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
a_thread_buf[Number<a_thread_desc_.CalculateOffset( a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}]; make_tuple(m0, I0, k0, ik))>{}];
}); });
using mfma_input_type =
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
xdlops_gemm.Run(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k0 * AMmaKStride>{}),
a_block_buf1,
a_thread_desc_,
make_tuple(m0, I0, k0, I0),
a_thread_buf);
});
});
__builtin_amdgcn_sched_barrier(0);
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec =
b_blockwise_copy.template GetSrcThreadScratchIdx<Sequence<0, k0, 0>, Number<1>{}>();
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
});
using mfma_input_type = using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type; typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment