"...search-o1_pytorch.git" did not exist on "b6edc328feb927001555db66813a6c73fa4ccd59"
Commit 279e2eaf authored by mtgu0705's avatar mtgu0705
Browse files

move b thread dequant copy to blockwise.

parent 52c018b7
......@@ -222,7 +222,6 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BThreadTransfer,
typename BBlockTransferStep,
typename CThreadBuffer>
__device__ void Run(const AGridDesc& a_grid_desc,
......@@ -236,7 +235,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
BThreadTransfer& b_thread_dequant_copy,
// BThreadTransfer& b_thread_dequant_copy,
CThreadBuffer& c_thread_buf,
index_t num_loop) const
{
......@@ -287,7 +286,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
});
});
// B VGPR->VGPR dequant
b_thread_dequant_copy.Run(b_block_desc_n0_n1_k0_k1,
b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(I0),
b_thread_desc_,
......@@ -363,12 +362,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
});
});
// B VGPR->VGPR dequant
b_thread_dequant_copy.Run(b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(local_read_buf),
b_thread_desc_,
make_tuple(I0, I0, I0, I0),
b_thread_dequant_bufs(local_read_buf));
b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(local_read_buf),
b_thread_desc_,
make_tuple(I0, I0, I0, I0),
b_thread_dequant_bufs(local_read_buf));
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
......@@ -433,7 +432,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
});
});
// B VGPR->VGPR dequant
b_thread_dequant_copy.Run(b_block_desc_n0_n1_k0_k1,
b_thread_dequant_copy_.Run(b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(I1),
b_thread_desc_,
......@@ -529,6 +528,22 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
static constexpr BTileDesc b_block_desc_n0_n1_k0_k1;
using Base::c_thread_desc_;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using BThreadDequantCopy = ThreadwiseTensorSliceTransfer_StaticToStatic<
BDataType,
ComputeDataType,
decltype(b_block_desc_n0_n1_k0_k1),
decltype(b_block_desc_n0_n1_k0_k1),
tensor_operation::element_wise::PassThrough,
Sequence<Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}>,
Sequence<1, 2, 0, 3>,
3,
KPack>;
const PassThrough b_element_op{};
BThreadDequantCopy b_thread_dequant_copy_{b_element_op};
};
} // namespace ck
......@@ -1134,7 +1134,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
const AElementwiseOperation a_element_op{};
const BElementwiseOperation b_element_op{};
// const BElementwiseOperation b_element_op{};
const CElementwiseOperation c_element_op{};
// divide block work by [M, N]
......@@ -1219,18 +1219,6 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
0,
KPack * (get_thread_local_1d_id() % warpSize)));
// B: VGRP->VGPR dequantization
auto b_thread_dequant_copy = ThreadwiseTensorSliceTransfer_StaticToStatic<
BDataType,
ComputeTypeA,
decltype(b_block_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
tensor_operation::element_wise::PassThrough,
Sequence<Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}>,
Sequence<0, 1, 2, 3>,
3,
BK1Number>(b_element_op);
// LDS allocation for A and B: be careful of alignment
// Cast after lds
......@@ -1260,9 +1248,6 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
b_grid_buf,
b_block_buf,
b_block_slice_copy_step,
// B: VGRP->VGPR dequantization
b_thread_dequant_copy,
c_thread_buf,
num_k_block_main_loop);
......@@ -1522,7 +1507,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
const AElementwiseOperation a_element_op{};
const BElementwiseOperation b_element_op{};
// const BElementwiseOperation b_element_op{};
const CElementwiseOperation c_element_op{};
// divide block work by [M, N]
......@@ -1612,18 +1597,6 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
0,
KPack * (get_thread_local_1d_id() % warpSize)));
// B: VGRP->VGPR dequantization
auto b_thread_dequant_copy = ThreadwiseTensorSliceTransfer_StaticToStatic<
BDataType,
ComputeTypeA,
decltype(b_block_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
tensor_operation::element_wise::PassThrough,
Sequence<Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}>,
Sequence<1, 2, 0, 3>,
3,
BK1Number>(b_element_op);
// LDS allocation for A and B: be careful of alignment
auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<ADataType*>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
......@@ -1656,9 +1629,6 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
b_grid_buf,
b_block_bufs,
b_block_slice_copy_step,
// B: VGRP->VGPR dequantization
b_thread_dequant_copy,
c_thread_buf,
num_k_block_main_loop);
......
......@@ -1573,7 +1573,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic
const SrcBuffer& src_buf,
const DstDesc&,
const DstSliceOriginIdx&,
DstBuffer& dst_buf)
DstBuffer& dst_buf) const
{
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
"wrong! Desc need to known at compile-time");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment