Commit 8c0e03ba authored by mtgu0705
Browse files

General fix.

parent f0fba871
......@@ -22,7 +22,7 @@ using CLayout = Row;
void preShuffleBuffer(const I4* src, I4* dst, int N, int K, int NXdl)
{
int KPack = 32;
int KPack = 32; // int4 -> 32, fp8 -> 16, fp16 -> 8
int NLane = NXdl;
int KLane = 64 / NLane;
......@@ -174,7 +174,10 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize());
DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
int NperXdl = GetPreShuffleParameters;
// do GEMM
auto gemm = DeviceGemmV2Instance{};
int NperXdl = gemm.GetPreShuffleParameters();
preShuffleBuffer(b_k_n.mData.data(), b_k_n_preshuffled.mData.data(), N, K, NperXdl);
// weight permute
......@@ -263,8 +266,6 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
// do GEMM
auto gemm = DeviceGemmV2Instance{};
auto invoker = gemm.MakeInvoker();
float ave_time = 0;
......
......@@ -11,6 +11,15 @@
namespace ck {
// Selects a blockwise GEMM pipeline variant. It is the type of the first
// template parameter (BlkGemmPipelineVer) of the template declared directly
// after this enum. The enumerators carry no explicit values, so they are
// numbered 0..4 in declaration order.
enum struct BlockGemmPipelineVersion
{
v1, // Naive
v2, // Mem
v3, // Comp
v4, // Comp, double lds buffer
v5, // Comp, double global prefetch register buffer
};
template <BlockGemmPipelineVersion BlkGemmPipelineVer,
BlockGemmPipelineScheduler BlkGemmPipeSche,
index_t BlockSize,
......
......@@ -46,7 +46,8 @@ struct BlockwiseGemmXdlops_pipeline_base
static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0);
static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0);
static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2);
static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2);
// static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2);
static constexpr index_t B_K1 = BBlockTransferSrcScalarPerVector;
static constexpr auto xdlops_gemm =
XdlopsGemm<ComputeDataType, MPerXDL, NPerXDL, KPack, ComputeDataType, TransposeC>{};
......@@ -56,6 +57,7 @@ struct BlockwiseGemmXdlops_pipeline_base
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
static constexpr index_t KRepeat = KPerThread / KPack;
static constexpr index_t KPerInnerLoop = KPack;
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
......@@ -112,6 +114,17 @@ struct BlockwiseGemmXdlops_pipeline_base
return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPerThread * xdlops_a_idx[I0]);
}
// Builds the six-component origin index of the calling thread's A data.
//
// Returned tuple, in order:
//   (0, M-direction wave id, xdlops A index [I1], 0, xdlops A index [I0], 0)
//
// The wave id is component I0 of GetWaveIdx(); the two lane-dependent
// components come from xdlops_gemm.CalculateAThreadOriginDataIndex().
// Unlike the four-component tuple returned just above this function, the I0
// component is passed through unscaled (not multiplied by KPerThread) and is
// surrounded by two extra zero components.
__device__ static auto CalculateAThreadOriginDataIndex6D()
{
// Position of this thread's wave within the block; only the M component is used.
const auto block_wave_idx = GetWaveIdx();
const auto m_wave_id = block_wave_idx[I0];

// Per-lane A origin as reported by the xdlops instruction wrapper.
const auto a_lane_origin = xdlops_gemm.CalculateAThreadOriginDataIndex();

return make_tuple(0, m_wave_id, a_lane_origin[I1], 0, a_lane_origin[I0], 0);
}
__device__ static auto CalculateBThreadOriginDataIndex()
{
const auto wave_idx = GetWaveIdx();
......
......@@ -142,8 +142,10 @@ struct DeviceGemmV2BPreshuffle : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
virtual bool GetPermuteA() = 0;
virtual bool GetPermuteB() = 0;
virtual ck::index_t GetKPerBlock() = 0;
virtual int GetPreShuffleParameters() = 0;
};
} // namespace device
......
......@@ -329,6 +329,7 @@ struct DeviceGemm_Xdl_CShuffleV3_BPreshuffle : public DeviceGemmV2BPreshuffle<AL
{
throw std::runtime_error("Only support pipeline ver v1, v2, v3 now!");
}
}
#if 0
else
{
......
......@@ -1134,7 +1134,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
const AElementwiseOperation a_element_op{};
const BElementwiseOperation b_element_op{};
// const BElementwiseOperation b_element_op{};
const CElementwiseOperation c_element_op{};
// divide block work by [M, N]
......@@ -1514,7 +1514,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
const AElementwiseOperation a_element_op{};
const BElementwiseOperation b_element_op{};
// const BElementwiseOperation b_element_op{};
const CElementwiseOperation c_element_op{};
// divide block work by [M, N]
......@@ -1614,7 +1614,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(, 0, KRepeat, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0);
// Blockwise GEMM pipeline
static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment