"experiments/vscode:/vscode.git/clone" did not exist on "d77ff487c9a51b58c979457c37e1f6edadf99fa5"
Commit 42517524 authored by Jing Zhang

add splitK+bias

parent bc45fd98
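For context, a minimal host-side sketch of what this commit targets: a split-K GEMM with a single bias tensor D0, i.e. E = (sum over k-batches of A_k * B_k) + D0, with the bias added exactly once. This is reference arithmetic only, not CK code; the function name, layouts, and parameters below are illustrative assumptions that mirror the example's Row-major A and Col-major B.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Reference arithmetic only: E = (sum over k-batches of A_k * B_k) + D0.
// The bias D0 is applied exactly once, which mirrors the kernel change below,
// where only the kbatch_id == 0 slice runs the bias element-wise op and the
// remaining slices merely accumulate their partial products into E.
void gemm_splitk_bias_ref(const std::vector<float>& a,  // M x K, row-major
                          const std::vector<float>& b,  // N x K storage (column-major B)
                          const std::vector<float>& d0, // M x N bias
                          std::vector<float>& e,        // M x N output
                          std::size_t M, std::size_t N, std::size_t K,
                          std::size_t k_batch)
{
    const std::size_t k_per_batch = (K + k_batch - 1) / k_batch;
    e.assign(M * N, 0.f);

    for(std::size_t kb = 0; kb < k_batch; ++kb) // one slice per k-batch
    {
        const std::size_t k_begin = kb * k_per_batch;
        const std::size_t k_end   = std::min(K, k_begin + k_per_batch);

        for(std::size_t m = 0; m < M; ++m)
            for(std::size_t n = 0; n < N; ++n)
            {
                float acc = 0.f;
                for(std::size_t k = k_begin; k < k_end; ++k)
                    acc += a[m * K + k] * b[n * K + k];

                // bias is applied by exactly one slice; the others only accumulate
                e[m * N + n] += acc + (kb == 0 ? d0[m * N + n] : 0.f);
            }
    }
}
```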
@@ -38,7 +38,7 @@ using AccDataType = F32;
 using CShuffleDataType = F32;
 using D0DataType = F32;
 using DsDataType = ck::Tuple<D0DataType>;
-using EDataType = F16;
+using EDataType = F32;
 using ALayout = Row;
 using BLayout = Col;
@@ -59,7 +59,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Xdl_F
 //######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
 //######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
 //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
-< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
+< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>;
 // clang-format on
 struct ProblemSize final
@@ -171,7 +171,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
 b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
 }
-d0_tensors[i].GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
+d0_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential<1>{});
 }
 using GroupedGemmKernelArgument = ck::tensor_operation::device::GroupedGemmKernelArgument<1>;
@@ -254,7 +254,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co
 }
 gemm.SetDeviceKernelArgs(argument, gemm_desc_workspace.GetDeviceBuffer());
-gemm.SetKBatch(argument, 1);
+gemm.SetKBatch(argument, 2);
 invoker.Run(argument, StreamConfig{nullptr, false});
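The example-side changes switch EDataType to F32, generate the bias values sequentially (deterministic data for verification), and set the k-batch to 2 so the split-K path is actually exercised. The accompanying 8 -> 4 change in the last template argument of the device instance (the CDE shuffle ScalarPerVector) is consistent with keeping the vectorized E store at the same byte width once the output element grows from 2 to 4 bytes; this is only my reading of the tuning change, and the snippet below is just the byte arithmetic, not the real constraint check inside CK's transfer classes.

```cpp
#include <cstdint>

// Byte-width arithmetic behind the 8 -> 4 ScalarPerVector change (assumption:
// the store vector stays 16 bytes wide). F16 stand-in: any 2-byte type.
using F16 = std::uint16_t;

static_assert(8 * sizeof(F16)   == 16, "8 x F16 elements = 16-byte vector");
static_assert(4 * sizeof(float) == 16, "4 x F32 elements = 16-byte vector");
```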
@@ -462,10 +462,13 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
 template <bool HasMainKBlockLoop,
 InMemoryDataOperationEnum EGlobalMemoryDataOperation,
+index_t NumDTensor_,
+typename DsDataType_,
 typename AGridDesc_KBatch_AK0_M_AK1,
 typename BGridDesc_KBatch_BK0_N_BK1,
 typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
 typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
+typename CDEElementwiseOperation_,
 typename Block2ETileMap>
 __device__ static void Run(const ABDataType* __restrict__ p_a_grid,
 const ABDataType* __restrict__ p_b_grid,
@@ -474,7 +477,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
 void* __restrict__ p_shared,
 const AElementwiseOperation& a_element_op,
 const BElementwiseOperation& b_element_op,
-const CDEElementwiseOperation& cde_element_op,
+const CDEElementwiseOperation_& cde_element_op,
 const AGridDesc_KBatch_AK0_M_AK1& a_grid_desc_kbatch_ak0_m_ak1,
 const BGridDesc_KBatch_BK0_N_BK1& b_grid_desc_kbatch_bk0_n_bk1,
 const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
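The inner Run body is now templated on NumDTensor_, DsDataType_ and CDEElementwiseOperation_ instead of relying on the class-level NumDTensor/DsDataType/CDEElementwiseOperation, so one body can be instantiated both with the bias (D) tensors and with an empty D tuple plus a pass-through op (see the dispatch at the end of the diff). A minimal standalone sketch of that pattern, assuming nothing about CK's types; PassThrough/AddBias/epilogue here are illustrative names, not CK's:

```cpp
#include <cstdio>
#include <tuple>

// Sketch (not CK code) of templating an epilogue on the tuple of extra D
// operands and on the elementwise operation, so the same body compiles both
// with and without the bias, selected at the call site.
struct PassThrough
{
    void operator()(float& e, float c) const { e = c; }
};

struct AddBias
{
    void operator()(float& e, float c, float d0) const { e = c + d0; }
};

template <typename DsTuple, typename CDEOp>
void epilogue(float& e, float c, const DsTuple& ds, const CDEOp& op)
{
    // expand C followed by every D operand into the elementwise op
    std::apply([&](auto... d) { op(e, c, d...); }, ds);
}

int main()
{
    float e = 0.f;
    epilogue(e, 1.5f, std::make_tuple(0.25f), AddBias{}); // bias path (k-batch slice 0)
    std::printf("with bias: %f\n", e);
    epilogue(e, 1.5f, std::tuple<>{}, PassThrough{});      // no-bias path (other slices)
    std::printf("without bias: %f\n", e);
    return 0;
}
```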
@@ -495,7 +498,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
 p_ds_grid[i],
 ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
 },
-Number<NumDTensor>{});
+Number<NumDTensor_>{});
 auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
@@ -777,7 +780,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
 generate_tie(
 [&](auto i) -> const auto& // return type should be reference
 { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
-Number<NumDTensor>{}));
+Number<NumDTensor_>{}));
 // tuple of reference to C/Ds tensor descriptors
 const auto c_ds_buf_refs = concat_tuple_of_reference(
@@ -785,7 +788,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
 generate_tie(
 [&](auto i) -> const auto& // return type should be reference
 { return ds_grid_buf[i]; },
-Number<NumDTensor>{}));
+Number<NumDTensor_>{}));
 // tuple of starting index of C/Ds blockwise copy
 const auto idx_c_ds_block_begin = container_concat(
@@ -794,16 +797,16 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
 [&](auto) {
 return make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0);
 },
-Number<NumDTensor>{}));
+Number<NumDTensor_>{}));
 // blockwise copy C/D/E between LDS and global
 auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7<
 ThisThreadBlock,
-decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
+decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType_{})),
 Tuple<EDataType>,
 decltype(c_ds_desc_refs),
 decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
-CDEElementwiseOperation,
+CDEElementwiseOperation_,
 Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
 // support arbitray type
 Sequence<1,
@@ -817,7 +820,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
 CDEShuffleBlockTransferScalarPerVector_NPerBlock,
 sequence_merge_t<
 Sequence<true>,
-uniform_sequence_gen_t<NumDTensor,
+uniform_sequence_gen_t<NumDTensor_,
 false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
 Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
 {c_ds_desc_refs,
@@ -879,7 +882,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
 sfc_cde_block.GetForwardStep(access_id);
 // move on Ds
-static_for<0, NumDTensor, 1>{}([&](auto i) {
+static_for<0, NumDTensor_, 1>{}([&](auto i) {
 cde_block_copy_lds_and_global.MoveSrcSliceWindow(
 c_ds_desc_refs, i + I1, cde_lds_and_global_step);
 });
@@ -961,7 +964,14 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
 const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
 MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n);
-Run<HasMainKBlockLoop, EGlobalMemoryDataOperation>(
+const auto block_work_idx =
+block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
+const index_t kbatch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
+if(kbatch_id == 0)
+{
+Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, NumDTensor, DsDataType>(
 p_a_grid,
 p_b_grid,
 p_ds_grid,
@@ -976,6 +986,24 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
 e_grid_desc_mblock_mperblock_nblock_nperblock,
 block_2_etile_map);
 }
+else
+{
+Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, 0, Tuple<>>(
+p_a_grid,
+p_b_grid,
+p_ds_grid,
+p_e_grid,
+p_shared,
+a_element_op,
+b_element_op,
+ck::tensor_operation::element_wise::PassThrough{},
+a_grid_desc_kbatch_ak0_m_ak1,
+b_grid_desc_kbatch_bk0_n_bk1,
+ds_grid_desc_mblock_mperblock_nblock_nperblock,
+e_grid_desc_mblock_mperblock_nblock_nperblock,
+block_2_etile_map);
+}
+}
 };
 } // namespace ck
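The dispatch added above instantiates the inner Run with the class-level NumDTensor/DsDataType and the real CDE elementwise op only for the kbatch_id == 0 slice; every other slice is instantiated with zero D tensors and a PassThrough op. Since in split-K all slices reduce into the same E buffer (the memory operation is selected by the EGlobalMemoryDataOperation template parameter), applying the bias in every slice would count it k_batch times. A small host-side illustration of that hazard, with made-up numbers:

```cpp
#include <cassert>

// Why slices other than kbatch_id == 0 must use a pass-through epilogue:
// every slice accumulates into the same E element, so a bias applied per
// slice would be counted k_batch times. Purely illustrative values.
int main()
{
    const float c_partial[2] = {1.0f, 2.0f}; // per-slice partial dot products
    const float d0           = 0.5f;         // bias
    const int   k_batch      = 2;

    float e_correct = 0.f;
    float e_buggy   = 0.f;
    for(int kb = 0; kb < k_batch; ++kb)
    {
        e_correct += c_partial[kb] + (kb == 0 ? d0 : 0.f); // bias added once
        e_buggy   += c_partial[kb] + d0;                   // bias added every slice
    }
    assert(e_correct == 3.5f);
    assert(e_buggy == 4.0f); // bias double-counted
    return 0;
}
```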