Commit 4a1ec815 authored by coderfeli's avatar coderfeli
Browse files

add bypass logic and build

parent 19b7c131
......@@ -516,10 +516,6 @@ include_directories(BEFORE
)
SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV")
if(BUILD_DEV)
add_compile_options(-Werror)
add_compile_options(-Weverything)
endif()
message("CMAKE_CXX_FLAGS: ${CMAKE_CXX_FLAGS}")
if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang")
......
......@@ -66,7 +66,6 @@ else()
-Wunreachable-code
-Wunused
-Wno-reserved-identifier
-Werror
-Wno-option-ignored
-Wsign-compare
-Wno-extra-semi-stmt
......
......@@ -63,6 +63,40 @@ struct MultiplyMultiply
}
};
void reshapeBuffer(char* buffer, int N, int K, char* output) {
const int KRepeat = 2;
const int NRepeat = 3;
const int KLane = 4;
const int NLane = 5;
const int KPack = 6;
int N0 = N / (NRepeat * NLane);
int K0 = K / (KRepeat * KLane * KPack);
for (int n = 0; n < N; ++n) {
for (int k = 0; k < K; ++k) {
int n0 = n / (NRepeat * NLane);
int k0 = k / (KRepeat * KLane * KPack);
int nRel = n % (NRepeat * NLane);
int kRel = k % (KRepeat * KLane * KPack);
int nIndex = nRel / NLane;
int kIndex = kRel / (KLane * KPack);
int nLaneIndex = nRel % NLane;
int kLaneIndex = (kRel % (KLane * KPack)) / KPack;
int kPackIndex = kRel % KPack;
int outputIndex = (n0 * K0 + k0) * KRepeat * NRepeat * KLane * NLane * KPack
+ nIndex * KRepeat * KLane * KPack
+ kIndex * KLane * KPack
+ nLaneIndex * KPack
+ kLaneIndex * KPack
+ kPackIndex;
output[outputIndex] = buffer[n * K + k];
}
}
}
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AElementOp = PassThrough;
......@@ -77,10 +111,13 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
///######| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
///######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S<C, D0, D1>|
///###### RRR
///< Row, Row, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 256, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>;
///###### RCR
< Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>;
// kernel 1: 256->32x128x128
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>;
< Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
// kernel 2: 128->32x128x128
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>;
// clang-format on
int main(int argc, char* argv[])
......
......@@ -305,11 +305,11 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Local prefill 1
// // Local prefill 1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
// b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
// Global prefetch 2
// // Global prefetch 2
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
......@@ -330,14 +330,14 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
make_tuple(m0, I0, k0, I0),
a_thread_buf);
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k0 * BMmaKStride>{}),
b_block_buf,
b_thread_desc_,
make_tuple(n0, I0, k0, I0),
b_thread_buf);
});
// static_for<0, NRepeat, 1>{}([&](auto n0) {
// b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
// make_tuple(n0, I0, I0, Number<k0 * BMmaKStride>{}),
// b_block_buf,
// b_thread_desc_,
// make_tuple(n0, I0, k0, I0),
// b_thread_buf);
// });
});
__builtin_amdgcn_sched_barrier(0);
......@@ -351,7 +351,7 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
// b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
......@@ -400,14 +400,14 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
make_tuple(m0, I0, k0, I0),
a_thread_buf);
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k0 * BMmaKStride>{}),
b_block_buf,
b_thread_desc_,
make_tuple(n0, I0, k0, I0),
b_thread_buf);
});
// static_for<0, NRepeat, 1>{}([&](auto n0) {
// b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
// make_tuple(n0, I0, I0, Number<k0 * BMmaKStride>{}),
// b_block_buf,
// b_thread_desc_,
// make_tuple(n0, I0, k0, I0),
// b_thread_buf);
// });
});
HotLoopScheduler();
......@@ -455,7 +455,7 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
protected:
using Base::a_thread_copy_;
using Base::a_thread_desc_;
using Base::b_thread_copy_;
// using Base::b_thread_copy_;
using Base::b_thread_desc_;
using Base::c_thread_desc_;
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment