Commit 49bac8ce authored by mtgu0705's avatar mtgu0705
Browse files

Added b preshuffle pipeline v3 support.

parent a0432459
......@@ -28,9 +28,9 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
static constexpr bool PermuteA = false;
static constexpr bool PermuteB = false;
static constexpr ck::index_t KPerBlock = 128;
// clang-format off
#if 0
using DeviceGemmV2Instance =
ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3_BPreshuffle<
ALayout, BLayout, CLayout,
......@@ -38,7 +38,7 @@ using DeviceGemmV2Instance =
AElementOp, BElementOp, CElementOp, GemmDefault,
256,
128, 128,
KPerBlock, 16, 32,
256, 16, 32,
32, 32,
4, 1,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
......@@ -48,6 +48,25 @@ using DeviceGemmV2Instance =
1, 1, S<1, 32, 1, 8>, 4,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, F8, F8, PermuteA, PermuteB>;
#else
using DeviceGemmV2Instance =
ck::tensor_operation::device::DeviceGemm_Xdl_CShuffleV3_BPreshuffle<
ALayout, BLayout, CLayout,
ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CElementOp, GemmDefault,
256,
256, 256,
128, 16, 32,
32, 32,
4, 4,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 16, 16, 0,
S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>,
2, 32, 32, 0,
1, 1, S<1, 32, 1, 8>, 8,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, F8, F8, PermuteA, PermuteB>;
#endif
// clang-format on
template <typename ProblemType>
......
......@@ -510,10 +510,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
__builtin_amdgcn_sched_barrier(0);
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
b_thread_desc_.GetElementSpaceSize());
auto b_thread_dequant_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
b_thread_desc_.GetElementSpaceSize());
StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
StaticallyIndexedArray<decltype(b_thread_dequant_buf), Number<2>{}> b_thread_dequant_bufs;
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);
// Global prefetch A1 B1
......@@ -545,6 +548,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
make_tuple(I0, I0, I0, k0, I0, I0),
a_thread_buf);
});
// B VGPR->VGPR dequant
b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(I0),
b_thread_desc_,
make_tuple(I0, I0, I0, I0),
b_thread_dequant_bufs(I0));
// Initialize C
c_thread_buf.Clear();
......@@ -594,7 +604,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
I0,
ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[mfma_reg_buf]
b_thread_dequant_bufs[mfma_reg_buf]
[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
......@@ -633,6 +643,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
I0),
a_thread_buf);
});
// B VGPR->VGPR dequant
b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(local_read_buf),
b_thread_desc_,
make_tuple(I0, I0, I0, I0),
b_thread_dequant_bufs(local_read_buf));
}
else
{
......@@ -652,6 +669,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
I0),
a_thread_buf);
});
// B VGPR->VGPR dequant
b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(mfma_reg_buf),
b_thread_desc_,
make_tuple(I0, I0, I0, I0),
b_thread_dequant_bufs(mfma_reg_buf));
}
HotLoopScheduler(m0);
......@@ -691,7 +715,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0 % 2, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
b_thread_dequant_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
......@@ -720,6 +744,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0),
a_thread_buf);
});
// B VGPR->VGPR dequant
b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(I1),
b_thread_desc_,
make_tuple(I0, I0, I0, I0),
b_thread_dequant_bufs(I1));
}
else
{
......@@ -732,6 +763,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0),
a_thread_buf);
});
// B VGPR->VGPR dequant
b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(I0),
b_thread_desc_,
make_tuple(I0, I0, I0, I0),
b_thread_dequant_bufs(I0));
}
EpilogueScheduler_1(m0);
......@@ -748,7 +786,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
(m0 + HotloopLocalBufSwitch) % 2, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
b_thread_dequant_bufs[I1][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
......@@ -776,6 +814,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
Number<(m0 + 1 + HotloopLocalBufSwitch) % 2>{}, I0, I0, k0, I0, I0),
a_thread_buf);
});
// B VGPR->VGPR dequant
b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(I1),
b_thread_desc_,
make_tuple(I0, I0, I0, I0),
b_thread_dequant_bufs(I1));
EpilogueScheduler_2();
}
......@@ -797,7 +842,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0 % 2, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
b_thread_dequant_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
......@@ -823,6 +868,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
make_tuple(Number<(m0 + 1) % 2>{}, I0, I0, k0, I0, I0),
a_thread_buf);
});
// B VGPR->VGPR dequant
b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(I0),
b_thread_desc_,
make_tuple(I0, I0, I0, I0),
b_thread_dequant_bufs(I0));
EpilogueScheduler_2();
}
......@@ -855,6 +907,22 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v3<BlockGemmPipelineScheduler::I
static constexpr BTileDesc b_block_desc_n0_n1_k0_k1;
using Base::c_thread_desc_;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using BThreadDequantCopy = ThreadwiseTensorSliceTransfer_StaticToStatic<
BDataType,
ComputeDataType,
decltype(b_block_desc_n0_n1_k0_k1),
decltype(b_block_desc_n0_n1_k0_k1),
tensor_operation::element_wise::PassThrough,
Sequence<Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}>,
Sequence<1, 2, 0, 3>,
3,
KPack>;
const PassThrough b_element_op{};
BThreadDequantCopy b_thread_dequant_copy_{b_element_op};
};
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment