Commit c0ef46ff authored by mtgu0705
Browse files

Move the B thread dequant copy from the gridwise GEMM into the blockwise GEMM pipeline.

parent a316dff9
...@@ -222,7 +222,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I ...@@ -222,7 +222,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
typename BBlockTransfer, typename BBlockTransfer,
typename BGridBuffer, typename BGridBuffer,
typename BBlockBuffer, typename BBlockBuffer,
typename BThreadTransfer,
typename BBlockTransferStep, typename BBlockTransferStep,
typename CThreadBuffer> typename CThreadBuffer>
__device__ void Run(const AGridDesc& a_grid_desc, __device__ void Run(const AGridDesc& a_grid_desc,
...@@ -236,7 +235,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I ...@@ -236,7 +235,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
const BGridBuffer& b_grid_buf, const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf, BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step, const BBlockTransferStep& b_block_copy_step,
BThreadTransfer& b_thread_dequant_copy, // BThreadTransfer& b_thread_dequant_copy,
CThreadBuffer& c_thread_buf, CThreadBuffer& c_thread_buf,
index_t num_loop) const index_t num_loop) const
{ {
...@@ -287,7 +286,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I ...@@ -287,7 +286,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
}); });
}); });
// B VGPR->VGPR dequant // B VGPR->VGPR dequant
b_thread_dequant_copy.Run(b_block_desc_n0_n1_k0_k1, b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1,
b_block_origin_idx, b_block_origin_idx,
b_thread_bufs(I0), b_thread_bufs(I0),
b_thread_desc_, b_thread_desc_,
...@@ -362,12 +361,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I ...@@ -362,12 +361,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
}); });
}); });
// B VGPR->VGPR dequant // B VGPR->VGPR dequant
b_thread_dequant_copy.Run(b_block_desc_n0_n1_k0_k1, b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1,
b_block_origin_idx, b_block_origin_idx,
b_thread_bufs(local_read_buf), b_thread_bufs(local_read_buf),
b_thread_desc_, b_thread_desc_,
make_tuple(I0, I0, I0, I0), make_tuple(I0, I0, I0, I0),
b_thread_dequant_bufs(local_read_buf)); b_thread_dequant_bufs(local_read_buf));
HotLoopScheduler(); HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
...@@ -432,7 +431,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I ...@@ -432,7 +431,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
}); });
}); });
// B VGPR->VGPR dequant // B VGPR->VGPR dequant
b_thread_dequant_copy.Run(b_block_desc_n0_n1_k0_k1, b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1,
b_block_origin_idx, b_block_origin_idx,
b_thread_bufs(I1), b_thread_bufs(I1),
b_thread_desc_, b_thread_desc_,
...@@ -528,6 +527,22 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I ...@@ -528,6 +527,22 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
static constexpr BTileDesc b_block_desc_n0_n1_k0_k1; static constexpr BTileDesc b_block_desc_n0_n1_k0_k1;
using Base::c_thread_desc_; using Base::c_thread_desc_;
// Element-wise op handed to the dequant copy below.
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// Thread-level static-to-static transfer type behind the "B VGPR->VGPR dequant"
// calls in Run(). It is instantiated with BDataType and ComputeDataType as its
// first two type arguments (presumably source and destination element types --
// TODO confirm against ThreadwiseTensorSliceTransfer_StaticToStatic).
using BThreadDequantCopy = ThreadwiseTensorSliceTransfer_StaticToStatic<
BDataType,
ComputeDataType,
// The same descriptor type is passed for both descriptor arguments.
decltype(b_block_desc_n0_n1_k0_k1),
decltype(b_block_desc_n0_n1_k0_k1),
tensor_operation::element_wise::PassThrough,
// NOTE(review): presumably the per-thread slice lengths -- confirm.
Sequence<Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}>,
// NOTE(review): presumably the dimension access order -- confirm.
Sequence<1, 2, 0, 3>,
// NOTE(review): the trailing `3` and `KPack` look like the vectorized
// dimension index and the per-access vector width -- confirm.
3,
KPack>;
// Owned by the blockwise pipeline itself, so callers of Run() no longer need to
// build and pass a dequant-copy object.
const PassThrough b_element_op{};
BThreadDequantCopy b_thread_dequant_copy_{b_element_op};
}; };
} // namespace ck } // namespace ck
...@@ -1134,7 +1134,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle ...@@ -1134,7 +1134,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
const AElementwiseOperation a_element_op{}; const AElementwiseOperation a_element_op{};
const BElementwiseOperation b_element_op{}; // const BElementwiseOperation b_element_op{};
const CElementwiseOperation c_element_op{}; const CElementwiseOperation c_element_op{};
// divide block work by [M, N] // divide block work by [M, N]
...@@ -1219,18 +1219,6 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle ...@@ -1219,18 +1219,6 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
0, 0,
KPack * (get_thread_local_1d_id() % warpSize))); KPack * (get_thread_local_1d_id() % warpSize)));
// B: VGPR->VGPR dequantization
auto b_thread_dequant_copy = ThreadwiseTensorSliceTransfer_StaticToStatic<
BDataType,
ComputeTypeA,
decltype(b_block_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
tensor_operation::element_wise::PassThrough,
Sequence<Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}>,
Sequence<0, 1, 2, 3>,
3,
BK1Number>(b_element_op);
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
// Cast after lds // Cast after lds
...@@ -1260,9 +1248,6 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle ...@@ -1260,9 +1248,6 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
b_grid_buf, b_grid_buf,
b_block_buf, b_block_buf,
b_block_slice_copy_step, b_block_slice_copy_step,
// B: VGPR->VGPR dequantization
b_thread_dequant_copy,
c_thread_buf, c_thread_buf,
num_k_block_main_loop); num_k_block_main_loop);
...@@ -1522,7 +1507,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle ...@@ -1522,7 +1507,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
const AElementwiseOperation a_element_op{}; const AElementwiseOperation a_element_op{};
const BElementwiseOperation b_element_op{}; // const BElementwiseOperation b_element_op{};
const CElementwiseOperation c_element_op{}; const CElementwiseOperation c_element_op{};
// divide block work by [M, N] // divide block work by [M, N]
...@@ -1612,18 +1597,6 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle ...@@ -1612,18 +1597,6 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
0, 0,
KPack * (get_thread_local_1d_id() % warpSize))); KPack * (get_thread_local_1d_id() % warpSize)));
// B: VGPR->VGPR dequantization
auto b_thread_dequant_copy = ThreadwiseTensorSliceTransfer_StaticToStatic<
BDataType,
ComputeTypeA,
decltype(b_block_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
tensor_operation::element_wise::PassThrough,
Sequence<Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}>,
Sequence<1, 2, 0, 3>,
3,
BK1Number>(b_element_op);
// LDS allocation for A and B: be careful of alignment // LDS allocation for A and B: be careful of alignment
auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>( auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ADataType*>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); static_cast<ADataType*>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
...@@ -1656,9 +1629,6 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle ...@@ -1656,9 +1629,6 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
b_grid_buf, b_grid_buf,
b_block_bufs, b_block_bufs,
b_block_slice_copy_step, b_block_slice_copy_step,
// B: VGPR->VGPR dequantization
b_thread_dequant_copy,
c_thread_buf, c_thread_buf,
num_k_block_main_loop); num_k_block_main_loop);
......
...@@ -1573,7 +1573,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic ...@@ -1573,7 +1573,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
const SrcBuffer& src_buf, const SrcBuffer& src_buf,
const DstDesc&, const DstDesc&,
const DstSliceOriginIdx&, const DstSliceOriginIdx&,
DstBuffer& dst_buf) DstBuffer& dst_buf) const
{ {
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
"wrong! Desc need to known at compile-time"); "wrong! Desc need to known at compile-time");
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment