Commit 694057a7 authored by rocking's avatar rocking
Browse files

Rewrite the device interface and rename some var

parent 982c85a3
...@@ -38,7 +38,8 @@ using CShuffleDataType = F32; ...@@ -38,7 +38,8 @@ using CShuffleDataType = F32;
using D0DataType = F16; using D0DataType = F16;
using D1DataType = F16; using D1DataType = F16;
using DsDataType = ck::Tuple<D0DataType, D1DataType>; using DsDataType = ck::Tuple<D0DataType, D1DataType>;
using EDataType = F16; using GammaDataType = F16;
using BetaDataType = F16;
using HDataType = F16; using HDataType = F16;
// Layout // Layout
...@@ -59,11 +60,11 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa ...@@ -59,11 +60,11 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa
// clang-format off // clang-format off
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDLayernorm_Xdl_CShuffle using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleDLayernorm_Xdl_CShuffle
//######| ALayout| BLayout| DsLayout| ELayout| HLayout| AData| BData| AccData| CShuffle| DsData| EData| HData| A| B| CDE| H| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| ReduceThreadTransfer| | //######| ALayout| BLayout| DsLayout| HLayout| AData| BData| AccData| CShuffle| DsData| GammaData| BetaData| HData| A| B| CDE| H| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| PostShuffle| PostShuffle|
//######| | | | | | Type| Type| Type| DataType| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ReduceThreadTransfer| //######| | | | | Type| Type| Type| DataType| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector|
//######| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _MPerBlock_NPerBlock| ScalarPerVector| //######| | | | | | | | | | | | | Operation| Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _M_N| _M_N|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NPerBlock| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, ELayout, HLayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, HDataType, AElementOp, BElementOp, CDEElementOp, HElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<64, 4>, 4>; < ALayout, BLayout, DsLayout, HLayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, GammaDataType, BetaDataType, HDataType, AElementOp, BElementOp, CDEElementOp, HElementOp, GemmDefault, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<64, 4>, 4, S<8, 32>, S<1, 8>, 1, 8, 8, 8, 8, 1>;
// clang-format on // clang-format on
auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) { auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
...@@ -96,33 +97,39 @@ int main() ...@@ -96,33 +97,39 @@ int main()
ck::index_t StrideB = 1024; ck::index_t StrideB = 1024;
ck::index_t StrideD0 = 0; ck::index_t StrideD0 = 0;
ck::index_t StrideD1 = 1024; ck::index_t StrideD1 = 1024;
ck::index_t StrideE = 1024;
ck::index_t StrideH = 1024; ck::index_t StrideH = 1024;
// TODO - gamma and beta float epsilon = 1e-5;
Tensor<ADataType> a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{})); Tensor<ADataType> a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{})); Tensor<BDataType> b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{}));
Tensor<D0DataType> d0_n(f_host_tensor_descriptor1d(N, 1)); Tensor<D0DataType> d0_n(f_host_tensor_descriptor1d(N, 1));
Tensor<D1DataType> d1_m_n(f_host_tensor_descriptor2d(M, N, StrideD1, D1Layout{})); Tensor<D1DataType> d1_m_n(f_host_tensor_descriptor2d(M, N, StrideD1, D1Layout{}));
Tensor<EDataType> e_m_n(f_host_tensor_descriptor2d(M, N, StrideE, ELayout{})); Tensor<GammaDataType> gamma_n(f_host_tensor_descriptor1d(N, 1));
Tensor<BetaDataType> beta_n(f_host_tensor_descriptor1d(N, 1));
Tensor<HDataType> h_m_n(f_host_tensor_descriptor2d(M, N, StrideH, HLayout{})); Tensor<HDataType> h_m_n(f_host_tensor_descriptor2d(M, N, StrideH, HLayout{}));
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-1, 1}); a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-1, 1});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-1, 1}); b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-1, 1});
d0_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{-1, 1}); d0_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{-1, 1});
d1_m_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{-1, 1}); d1_m_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{-1, 1});
gamma_n.GenerateTensorValue(GeneratorTensor_3<GammaDataType>{-1, 1});
beta_n.GenerateTensorValue(GeneratorTensor_3<BetaDataType>{-1, 1});
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem d0_device_buf(sizeof(D0DataType) * d0_n.mDesc.GetElementSpaceSize()); DeviceMem d0_device_buf(sizeof(D0DataType) * d0_n.mDesc.GetElementSpaceSize());
DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize()); DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_m_n.mDesc.GetElementSpaceSize()); DeviceMem gamma_device_buf(sizeof(GammaDataType) * gamma_n.mDesc.GetElementSpaceSize());
DeviceMem beta_device_buf(sizeof(BetaDataType) * beta_n.mDesc.GetElementSpaceSize());
DeviceMem h_device_buf(sizeof(HDataType) * h_m_n.mDesc.GetElementSpaceSize()); DeviceMem h_device_buf(sizeof(HDataType) * h_m_n.mDesc.GetElementSpaceSize());
a_device_buf.ToDevice(a_m_k.mData.data()); a_device_buf.ToDevice(a_m_k.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data()); b_device_buf.ToDevice(b_k_n.mData.data());
d0_device_buf.ToDevice(d0_n.mData.data()); d0_device_buf.ToDevice(d0_n.mData.data());
d1_device_buf.ToDevice(d1_m_n.mData.data()); d1_device_buf.ToDevice(d1_m_n.mData.data());
gamma_device_buf.ToDevice(gamma_n.mData.data());
beta_device_buf.ToDevice(beta_n.mData.data());
auto a_element_op = AElementOp{}; auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{}; auto b_element_op = BElementOp{};
...@@ -135,7 +142,8 @@ int main() ...@@ -135,7 +142,8 @@ int main()
device_op.MakeArgument(a_device_buf.GetDeviceBuffer(), device_op.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(), b_device_buf.GetDeviceBuffer(),
{d0_device_buf.GetDeviceBuffer(), d1_device_buf.GetDeviceBuffer()}, {d0_device_buf.GetDeviceBuffer(), d1_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(), gamma_device_buf.GetDeviceBuffer(),
beta_device_buf.GetDeviceBuffer(),
h_device_buf.GetDeviceBuffer(), h_device_buf.GetDeviceBuffer(),
M, M,
N, N,
...@@ -143,8 +151,8 @@ int main() ...@@ -143,8 +151,8 @@ int main()
StrideA, StrideA,
StrideB, StrideB,
{StrideD0, StrideD1}, {StrideD0, StrideD1},
StrideE,
StrideH, StrideH,
epsilon,
a_element_op, a_element_op,
b_element_op, b_element_op,
cde_element_op, cde_element_op,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment