Commit d49d6928 authored by Chao Liu's avatar Chao Liu
Browse files

update example

parent 6c2b74de
...@@ -40,17 +40,40 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough; ...@@ -40,17 +40,40 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough; using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CDEElementOp = ck::tensor_operation::element_wise::Bilinear; using CDEElementOp = ck::tensor_operation::element_wise::Bilinear;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
// clang-format off // clang-format off
using DeviceOpInstance = ck::tensor_operation::device:: using DeviceOpInstanceKKNN = ck::tensor_operation::device::
//############################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//############################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| //#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//############################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| //#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>; DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>;
using DeviceOpInstanceKNNN = ck::tensor_operation::device::
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 4, 1, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>;
using DeviceOpInstanceMKNN = ck::tensor_operation::device::
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 4, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4>;
using DeviceOpInstanceMNNN = ck::tensor_operation::device::
//#####################################| NumDimM| NumDimN| NumDimK| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
//#####################################| | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//#####################################| | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceContractionMultipleD_Xdl_CShuffle< NumDimM, NumDimN, NumDimK, F32, F32, F32, F32, DsDataType, F32, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 256, 128, 16, 1, 1, 32, 32, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 1, 0, 1, 1, S<1, 16, 1, 16>, 4>;
// clang-format on // clang-format on
using DeviceOpInstance = DeviceOpInstanceKKNN;
// hardcoded for NumDimM == NumDimN == NumDimK == 2 // hardcoded for NumDimM == NumDimN == NumDimK == 2
template <ck::index_t NumDimM, template <ck::index_t NumDimM,
ck::index_t NumDimN, ck::index_t NumDimN,
......
...@@ -163,6 +163,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -163,6 +163,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{}; static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
// Assume: A[M0, M1, M2, ..., K0, K1, K2, ...] // Assume: A[M0, M1, M2, ..., K0, K1, K2, ...]
static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_ms_ks_lengths_vec, static auto MakeAGridDescriptor_AK0_M_AK1(const std::vector<index_t>& a_ms_ks_lengths_vec,
...@@ -593,17 +594,11 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -593,17 +594,11 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
a_element_op_{a_element_op}, a_element_op_{a_element_op},
b_element_op_{b_element_op}, b_element_op_{b_element_op},
cde_element_op_{cde_element_op}, cde_element_op_{cde_element_op},
a_mz_length_{},
a_mz_stride_{}, a_mz_stride_{},
a_kz_length_{},
a_kz_stride_{}, a_kz_stride_{},
b_nz_length_{},
b_nz_stride_{}, b_nz_stride_{},
b_kz_length_{},
b_kz_stride_{}, b_kz_stride_{},
e_mz_length_{}, ds_nz_stride_{},
e_mz_stride_{},
e_nz_length_{},
e_nz_stride_{} e_nz_stride_{}
{ {
if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_, if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
...@@ -630,31 +625,17 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -630,31 +625,17 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
} }
// for sanity check of vector memory access // for sanity check of vector memory access
a_mz_length_ = a_ms_ns_lengths[NumDimM - 1];
a_mz_stride_ = a_ms_ks_strides[NumDimM - 1]; a_mz_stride_ = a_ms_ks_strides[NumDimM - 1];
a_kz_length_ = a_ms_ns_lengths[NumDimM + NumDimK - 1];
a_kz_stride_ = a_ms_ks_strides[NumDimM + NumDimK - 1]; a_kz_stride_ = a_ms_ks_strides[NumDimM + NumDimK - 1];
b_nz_length_ = b_ns_ks_lengths[NumDimN - 1];
b_nz_stride_ = b_ns_ks_strides[NumDimN - 1]; b_nz_stride_ = b_ns_ks_strides[NumDimN - 1];
b_kz_length_ = b_ns_ks_lengths[NumDimN + NumDimK - 1];
b_kz_stride_ = b_ns_ks_strides[NumDimN + NumDimK - 1]; b_kz_stride_ = b_ns_ks_strides[NumDimN + NumDimK - 1];
for(index_t i = 0; i < NumDTensor; ++i) for(index_t i = 0; i < NumDTensor; ++i)
{ {
ds_mz_length_[i] = ds_ms_ns_lengths[i][NumDimM - 1];
ds_mz_stride_[i] = ds_ms_ns_strides[i][NumDimM - 1];
ds_nz_length_[i] = ds_ms_ns_lengths[i][NumDimM + NumDimN - 1];
ds_nz_stride_[i] = ds_ms_ns_strides[i][NumDimM + NumDimN - 1]; ds_nz_stride_[i] = ds_ms_ns_strides[i][NumDimM + NumDimN - 1];
} }
e_mz_length_ = e_ms_ns_lengths[NumDimM - 1];
e_mz_stride_ = e_ms_ns_strides[NumDimM - 1];
e_nz_length_ = e_ms_ns_lengths[NumDimM + NumDimN - 1];
e_nz_stride_ = e_ms_ns_strides[NumDimM + NumDimN - 1]; e_nz_stride_ = e_ms_ns_strides[NumDimM + NumDimN - 1];
} }
...@@ -685,24 +666,14 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -685,24 +666,14 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
BElementwiseOperation b_element_op_; BElementwiseOperation b_element_op_;
CDEElementwiseOperation cde_element_op_; CDEElementwiseOperation cde_element_op_;
// These are last M/N/K dimensions of A/B/Ds/E // Strides for the last M/N/K dimensions of A/B/Ds/E
// for sanity check of vector load/store // for sanity check of vector load/store
// FIXME: add check for D0, D1, ...
index_t a_mz_length_;
index_t a_mz_stride_; index_t a_mz_stride_;
index_t a_kz_length_;
index_t a_kz_stride_; index_t a_kz_stride_;
index_t b_nz_length_;
index_t b_nz_stride_; index_t b_nz_stride_;
index_t b_kz_length_;
index_t b_kz_stride_; index_t b_kz_stride_;
std::array<index_t, NumDTensor> ds_mz_length_;
std::array<index_t, NumDTensor> ds_mz_stride_;
std::array<index_t, NumDTensor> ds_nz_length_;
std::array<index_t, NumDTensor> ds_nz_stride_; std::array<index_t, NumDTensor> ds_nz_stride_;
index_t e_mz_length_;
index_t e_mz_stride_; index_t e_mz_stride_;
index_t e_nz_length_;
index_t e_nz_stride_; index_t e_nz_stride_;
}; };
...@@ -826,7 +797,73 @@ struct DeviceContractionMultipleD_Xdl_CShuffle ...@@ -826,7 +797,73 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
return false; return false;
} }
// FIXME check vector load/store // check vector access
static_assert((ABlockTransferSrcVectorDim == 1 || ABlockTransferSrcVectorDim == 2) &&
(BBlockTransferSrcVectorDim == 1 || BBlockTransferSrcVectorDim == 2),
"wrong!");
// vector memory access of A: could be on M or AK1 dimension
if constexpr(ABlockTransferSrcVectorDim == 1)
{
if(!(arg.a_mz_stride_ == 1 &&
arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) % ABlockTransferSrcScalarPerVector == 0))
{
return false;
}
}
else
{
if(!(arg.a_kz_stride_ == 1 &&
arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0))
{
return false;
}
}
// vector memory access of B: could be on N or BK1 dimension
if constexpr(BBlockTransferSrcVectorDim == 1)
{
if(!(arg.b_nz_stride_ == 1 &&
arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) % BBlockTransferSrcScalarPerVector == 0))
{
return false;
}
}
else
{
if(!(arg.b_kz_stride_ == 1 &&
arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) % BBlockTransferSrcScalarPerVector == 0))
{
return false;
}
}
// vector memory access of Ds: always on NPerBlock dimension
bool valid_d_access = true;
static_for<0, NumDTensor, 1>{}([&](auto i) {
if(!(arg.ds_nz_stride_[i] == 1 &&
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_[i].GetLength(I3) %
CDEBlockTransferScalarPerVector_NPerBlock ==
0))
{
valid_d_access = false;
}
});
if(valid_d_access == false)
{
return false;
}
// vector memory access of E: always on NPerBlock dimension
if(!(arg.e_nz_stride_ == 1 &&
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_.GetLength(I3) %
CDEBlockTransferScalarPerVector_NPerBlock ==
0))
{
return false;
}
return true; return true;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment