Commit bee790ec authored by mtgu0705

Init B preshuffle dequant in VGPR.

parent ed89a238
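With this change the B operand stays in its packed storage type (pk_i4_t, two signed 4-bit values per byte) on its way into VGPRs and is only converted to the compute type register-to-register, right before it feeds the MFMA. A minimal host-side sketch of that nibble unpacking, assuming for illustration that the low nibble is the first element (this is just the idea, not the CK implementation):

```cpp
#include <cstdint>
#include <cstdio>

// Illustrative stand-in for pk_i4_t: one byte holds two signed 4-bit values.
struct pk_i4
{
    std::uint8_t data;
};

// The "dequant in VGPR" step, modelled on the host: unpack one packed byte
// into two floats. A real kernel does this per thread, in registers, right
// before the values feed the MFMA input vectors.
inline void dequant_pk_i4(pk_i4 p, float& even, float& odd)
{
    // Sign-extend each nibble from 4 bits; low-nibble-first is an assumption
    // made only for this illustration.
    const std::int8_t lo = static_cast<std::int8_t>(static_cast<std::uint8_t>(p.data << 4)) >> 4;
    const std::int8_t hi = static_cast<std::int8_t>(p.data) >> 4;
    even = static_cast<float>(lo);
    odd  = static_cast<float>(hi);
}

int main()
{
    pk_i4 packed{0x9C}; // high nibble 0x9 (-7), low nibble 0xC (-4)
    float even, odd;
    dequant_pk_i4(packed, even, odd);
    std::printf("even=%g odd=%g\n", even, odd); // even=-4 odd=-7
    return 0;
}
```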
......@@ -222,6 +222,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BThreadTransfer,
typename BBlockTransferStep,
typename CThreadBuffer>
__device__ void Run(const AGridDesc& a_grid_desc,
......@@ -235,6 +236,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
BThreadTransfer& b_thread_dequant_copy,
CThreadBuffer& c_thread_buf,
index_t num_loop) const
{
......@@ -242,12 +244,17 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
__builtin_amdgcn_sched_barrier(0);
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
b_thread_desc_.GetElementSpaceSize());
auto b_thread_dequant_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
b_thread_desc_.GetElementSpaceSize());
StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);
StaticallyIndexedArray<decltype(b_thread_dequant_buf), Number<2>{}> b_thread_dequant_bufs;
// Global prefetch A1 B1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
b_blockwise_copy.Run(b_grid_desc,
......@@ -279,6 +286,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
a_thread_buf);
});
});
// B VGPR->VGPR dequant
b_thread_dequant_copy.Run(b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(I0),
b_thread_desc_,
make_tuple(I0, I0, I0, I0),
b_thread_dequant_bufs(I0));
// Initialize C
c_thread_buf.Clear();
......@@ -316,9 +330,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[mfma_reg_buf]
[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
b_thread_dequant_bufs[mfma_reg_buf]
[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
......@@ -347,6 +361,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
a_thread_buf);
});
});
// B VGPR->VGPR dequant
b_thread_dequant_copy.Run(b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(mfma_reg_buf),
b_thread_desc_,
make_tuple(I0, I0, I0, I0),
b_thread_dequant_bufs(mfma_reg_buf));
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
......@@ -381,7 +402,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
b_thread_dequant_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
......@@ -410,6 +431,13 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
a_thread_buf);
});
});
// B VGPR->VGPR dequant
b_thread_dequant_copy.Run(b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(I1),
b_thread_desc_,
make_tuple(I0, I0, I0, I0),
b_thread_dequant_bufs(I1));
__builtin_amdgcn_sched_barrier(0);
......@@ -424,7 +452,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
b_thread_dequant_bufs[I1][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
......@@ -457,7 +485,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
b_thread_dequant_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
......
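The pipeline above now keeps two buffer pairs per ping-pong stage: b_thread_bufs holds the raw packed BDataType tiles and b_thread_dequant_bufs holds the ComputeDataType output of b_thread_dequant_copy.Run(...), which is the only data the MFMA loop reads. A schematic, hedged model of that structure (plain C++ with stand-in names, not CK code):

```cpp
#include <array>
#include <cstddef>
#include <cstdio>

// Schematic model of the structure the pipeline now uses: raw packed B tiles
// land in one pair of thread buffers, a register-to-register dequant fills a
// second pair, and the math loop reads only the dequantized pair.
using PackedTile  = std::array<unsigned char, 8>; // stand-in for b_thread_bufs[i]  (BDataType)
using DequantTile = std::array<float, 8>;         // stand-in for b_thread_dequant_bufs[i]

static PackedTile load_tile(int k) // stand-in for b_blockwise_copy.Run (global -> VGPR)
{
    PackedTile t{};
    t.fill(static_cast<unsigned char>(k + 1));
    return t;
}

static DequantTile dequant(const PackedTile& p) // stand-in for b_thread_dequant_copy.Run
{
    DequantTile d{};
    for(std::size_t i = 0; i < p.size(); ++i)
        d[i] = static_cast<float>(p[i]); // real code converts pk_i4 -> compute type
    return d;
}

int main()
{
    std::array<PackedTile, 2>  b_bufs{};         // ping / pong raw buffers
    std::array<DequantTile, 2> b_dequant_bufs{}; // ping / pong dequantized buffers
    constexpr int num_loop = 4;

    b_bufs[0]         = load_tile(0);       // prologue: prefetch tile 0 ...
    b_dequant_bufs[0] = dequant(b_bufs[0]); // ... and dequantize it in registers

    float acc = 0.f;
    for(int k = 1; k < num_loop; ++k)
    {
        const int cur = (k - 1) % 2;
        const int nxt = k % 2;
        b_bufs[nxt]         = load_tile(k);         // prefetch the next tile
        b_dequant_bufs[nxt] = dequant(b_bufs[nxt]); // VGPR->VGPR dequant of the next tile
        for(float v : b_dequant_bufs[cur])          // "MFMA" consumes the current tile
            acc += v;
    }
    for(float v : b_dequant_bufs[(num_loop - 1) % 2]) // epilogue: last prefetched tile
        acc += v;
    std::printf("acc = %g\n", acc);
    return 0;
}
```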
......@@ -1134,7 +1134,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
const AElementwiseOperation a_element_op{};
// const BElementwiseOperation b_element_op{};
const BElementwiseOperation b_element_op{};
const CElementwiseOperation c_element_op{};
// divide block work by [M, N]
......@@ -1205,8 +1205,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2<
BDataType,
// BDataType,
ADataType,
BDataType,
decltype(b_grid_desc_bpreshuffled),
decltype(b_block_desc_bk0_n_bk1),
Sequence<Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}>,
......@@ -1220,18 +1219,24 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
0,
KPack * (get_thread_local_1d_id() % warpSize)));
// B: VGPR->VGPR dequantization
auto b_thread_dequant_copy = ThreadwiseTensorSliceTransfer_StaticToStatic<
BDataType,
ComputeTypeA,
decltype(b_block_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
tensor_operation::element_wise::PassThrough,
Sequence<Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}>,
Sequence<0, 1, 2, 3>,
3,
BK1Number>(b_element_op);
// LDS allocation for A and B: be careful of alignment
// Cast after lds
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
// auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
// reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) + a_block_space_size_aligned *
// sizeof(ADataType) /
// APackedSize),
// b_block_desc_bk0_n_bk1.GetElementSpaceSize());
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0);
......@@ -1255,6 +1260,9 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
b_grid_buf,
b_block_buf,
b_block_slice_copy_step,
// B: VGPR->VGPR dequantization
b_thread_dequant_copy,
c_thread_buf,
num_k_block_main_loop);
......@@ -1514,7 +1522,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
const AElementwiseOperation a_element_op{};
// const BElementwiseOperation b_element_op{};
const BElementwiseOperation b_element_op{};
const CElementwiseOperation c_element_op{};
// divide block work by [M, N]
......@@ -1604,6 +1612,18 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
0,
KPack * (get_thread_local_1d_id() % warpSize)));
// B: VGPR->VGPR dequantization
auto b_thread_dequant_copy = ThreadwiseTensorSliceTransfer_StaticToStatic<
BDataType,
ComputeTypeA,
decltype(b_block_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
tensor_operation::element_wise::PassThrough,
Sequence<Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}>,
Sequence<0, 1, 2, 3>,
3,
BK1Number>(b_element_op);
// LDS allocation for A and B: be careful of alignment
auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ADataType*>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
......@@ -1636,6 +1656,9 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
b_grid_buf,
b_block_bufs,
b_block_slice_copy_step,
// B: VGPR->VGPR dequantization
b_thread_dequant_copy,
c_thread_buf,
num_k_block_main_loop);
......
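Both Run(...) entry points above now construct a ThreadwiseTensorSliceTransfer_StaticToStatic with SrcData = BDataType, DstData = ComputeTypeA and a PassThrough element op, and hand it to the pipeline as b_thread_dequant_copy. A much-simplified sketch of the role such a register-resident static-to-static copy plays (hypothetical names, not the CK API):

```cpp
#include <array>
#include <cstdio>

// Simplified sketch of b_thread_dequant_copy: a thread-local, compile-time-sized
// copy that converts SrcData to DstData element by element through an
// elementwise op, entirely in registers.
struct PassThrough
{
    template <typename Dst, typename Src>
    void operator()(Dst& y, const Src& x) const { y = static_cast<Dst>(x); }
};

template <typename SrcData, typename DstData, int N, typename ElementOp>
struct StaticToStaticCopy
{
    explicit StaticToStaticCopy(ElementOp op) : op_(op) {}

    void Run(const std::array<SrcData, N>& src, std::array<DstData, N>& dst) const
    {
        for(int i = 0; i < N; ++i)
            op_(dst[i], src[i]); // per-element convert; real CK walks a space-filling curve
    }

    ElementOp op_;
};

int main()
{
    std::array<signed char, 8> b_thread_buf{1, -2, 3, -4, 5, -6, 7, -8}; // "BDataType" values
    std::array<float, 8>       b_thread_dequant_buf{};                   // "ComputeTypeA" values

    StaticToStaticCopy<signed char, float, 8, PassThrough> b_thread_dequant_copy{PassThrough{}};
    b_thread_dequant_copy.Run(b_thread_buf, b_thread_dequant_buf);

    for(float v : b_thread_dequant_buf)
        std::printf("%g ", v);
    std::printf("\n");
    return 0;
}
```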
......@@ -287,6 +287,7 @@ struct ThreadwiseTensorSliceTransfer_v2
// loop over tensor and copy
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
#if 0
if constexpr(is_same<remove_cvref_t<SrcData>, pk_i4_t>::value)
{
static_for<0, num_access, 1>{}([&](auto idx_1d) {
......@@ -352,12 +353,13 @@ struct ThreadwiseTensorSliceTransfer_v2
});
}
else
#endif
{
static_for<0, num_access, 1>{}([&](auto idx_1d) {
typename vector_type_maker<SrcData, SrcScalarPerVector>::type src_vector;
typename vector_type_maker<SrcData, SrcScalarPerVector / PackedSize>::type src_vector;
using src_vector_t =
typename vector_type_maker<SrcData, SrcScalarPerVector>::type::type;
typename vector_type_maker<SrcData, SrcScalarPerVector / PackedSize>::type::type;
constexpr auto src_data_idx = SpaceFillingCurve::GetIndex(idx_1d);
const bool is_src_valid =
......@@ -365,24 +367,24 @@ struct ThreadwiseTensorSliceTransfer_v2
// copy data from src_buf into src_vector
src_vector.template AsType<src_vector_t>()(Number<0>{}) =
src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid);
src_buf.template Get<src_vector_t>(src_coord_.GetOffset() / PackedSize, is_src_valid);
// copy data from src_vector into dst_buf
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
static_for<0, SrcScalarPerVector / PackedSize, 1>{}([&](auto i) {
constexpr index_t dst_offset =
dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx +
i * src_scalar_step_in_vector);
if constexpr(InvalidElementAsNaN)
{
dst_buf(Number<dst_offset>{}) =
dst_buf(Number<dst_offset / PackedSize>{}) =
is_src_valid
? type_convert<DstData>(src_vector.template AsType<SrcData>()[i])
: NumericLimits<DstData>::QuietNaN();
}
else
{
dst_buf(Number<dst_offset>{}) =
dst_buf(Number<dst_offset / PackedSize>{}) =
type_convert<DstData>(src_vector.template AsType<SrcData>()[i]);
}
});
......@@ -1544,6 +1546,13 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
using Index = MultiIndex<nDim>;
static constexpr index_t PackedSize = []() {
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
return 2;
else
return 1;
}();
__device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic(
const ElementwiseOperation& element_op)
: element_op_{element_op}
......@@ -1598,26 +1607,70 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
if constexpr(is_same<remove_cvref_t<SrcData>, pk_i4_t>::value)
{
static_for<0, num_access, 1>{}([&](auto idx_1d) {
typename vector_type_maker<SrcData, DstScalarPerVector / PackedSize>::type src_tmp_vector;
constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d);
// copy data from src_buf into src_tmp_vector
static_for<0, DstScalarPerVector / PackedSize, 1>{}([&](auto i) {
constexpr index_t src_offset = src_desc.CalculateOffset(
src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
src_tmp_vector.template AsType<SrcData>()(i) = src_buf[Number<src_offset / PackedSize>{}];
});
// copy data from src_tmp_vector to dst_tmp_vector (cast data from SrcData to
// DstData)
vector_type_maker_t<DstData, DstScalarPerVector> dst_tmp_vector;
constexpr index_t pack_size = 8;
static_assert(DstScalarPerVector % pack_size == 0, "");
using src_v_t = typename vector_type_maker_t<SrcData, pack_size / PackedSize>::type;
using dst_v_t = typename vector_type_maker_t<DstData, pack_size>::type;
static_for<0, DstScalarPerVector / pack_size, 1>{}([&](auto i) {
ck::tensor_operation::element_wise::PassThroughPack8{}(
dst_tmp_vector.template AsType<dst_v_t>()(i),
src_tmp_vector.template AsType<src_v_t>()[i]);
});
// copy data from dst_tmp_vector into dst_buf
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
constexpr index_t dst_offset = dst_desc.CalculateOffset(
dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
});
});
}
else
{
static_for<0, num_access, 1>{}([&](auto idx_1d) {
constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d);
// copy data from src_buf into dst_vector
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
constexpr index_t src_offset = src_desc.CalculateOffset(
src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
constexpr index_t dst_offset = dst_desc.CalculateOffset(
dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
DstData v;
// apply element-wise operation
element_op_(v, src_buf[Number<src_offset>{}]);
// apply type convert
dst_buf(Number<dst_offset>{}) = v;
});
});
}
}
ElementwiseOperation element_op_;
......
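Because pk_i4_t packs two logical elements per byte, the transfers above divide vector lengths and buffer offsets by PackedSize (2), and the pk_i4 branch converts pack_size = 8 logical elements (four packed bytes) per PassThroughPack8 call. A standalone sketch of that bookkeeping, with the low-nibble-first layout assumed only for illustration:

```cpp
#include <cstdint>
#include <cstdio>

// Sketch of the PackedSize bookkeeping: a buffer of N logical int4 elements
// occupies N / PackedSize bytes, logical offset i maps to byte i / PackedSize,
// and "Pack8" converts 8 logical elements (4 bytes) at a time, mirroring the
// DstScalarPerVector / pack_size grouping above.
constexpr int PackedSize = 2;

static float unpack_i4(const std::uint8_t* bytes, int logical_idx)
{
    const std::uint8_t byte = bytes[logical_idx / PackedSize];
    const int nib = (logical_idx % PackedSize == 0) ? (byte & 0xF) : (byte >> 4);
    return static_cast<float>(nib >= 8 ? nib - 16 : nib); // sign-extend 4 -> 32 bits
}

static void pass_through_pack8(const std::uint8_t* src4bytes, float* dst8)
{
    for(int i = 0; i < 8; ++i)
        dst8[i] = unpack_i4(src4bytes, i);
}

int main()
{
    constexpr int DstScalarPerVector = 16;             // logical elements per access
    std::uint8_t src[DstScalarPerVector / PackedSize]; // 8 packed bytes hold 16 int4 values
    for(int i = 0; i < DstScalarPerVector / PackedSize; ++i)
        src[i] = static_cast<std::uint8_t>(0x10 * i + i); // both nibbles = i

    float dst[DstScalarPerVector];
    constexpr int pack_size = 8;
    static_assert(DstScalarPerVector % pack_size == 0, "");
    for(int g = 0; g < DstScalarPerVector / pack_size; ++g)
        pass_through_pack8(src + (g * pack_size) / PackedSize, dst + g * pack_size);

    for(float v : dst)
        std::printf("%g ", v);
    std::printf("\n");
    return 0;
}
```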