Commit 2fc2a189 authored by rocking
Browse files

Refine naming

parent 3bd10f4e
...@@ -57,7 +57,7 @@ using UnarySquareElementOp = ...@@ -57,7 +57,7 @@ using UnarySquareElementOp =
using DxsInElementOp = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>; using DxsInElementOp = ck::Tuple<UnaryIdenticElementOp, UnarySquareElementOp>;
using DxsOutElementOp = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>; using DxsOutElementOp = ck::Tuple<UnaryDivElementOp, UnaryDivElementOp>;
using DGlobalMemOp = using DxsGlobalMemOp =
ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd, ck::InMemoryDataOperationEnumSequence<ck::InMemoryDataOperationEnum::AtomicAdd,
ck::InMemoryDataOperationEnum::AtomicAdd>; ck::InMemoryDataOperationEnum::AtomicAdd>;
...@@ -70,7 +70,7 @@ using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_ ...@@ -70,7 +70,7 @@ using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_
//######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector| //######| | | | Type| Type| Type| DataType| DataType| DataType| Type Tuple| Elementwise| Elementwise| Elementwise| Reduce| | | MemoryData| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MPerBlock| ScalarPerVector| ThreadClusterLengths| SrcDstScalarPerVector| SrcDstScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock| //######| | | | | | | | | | | Operation| Operation| Operation| Operation| | | Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| _MPerBlock_NPerBlock| _NPerBlock| _MPerBlock|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DxsReduceOp, DxsInElementOp, DxsOutElementOp, DGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>; < Row, Col, Row, F16, F16, F16, F32, F32, F32, DPtrsGlobal, AElementOp, BElementOp, CElementOp, DxsReduceOp, DxsInElementOp, DxsOutElementOp, DxsGlobalMemOp, GemmSpecialization, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8, S<64, 4>, 4, 1>;
// clang-format on // clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType, using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,
...@@ -274,10 +274,10 @@ int main() ...@@ -274,10 +274,10 @@ int main()
reduceMeanSquare_device_buf.SetZero(); reduceMeanSquare_device_buf.SetZero();
// Prepare LayerNorm // Prepare LayerNorm
auto layerNorm = DeviceNormalizeInstance{}; auto normalize = DeviceNormalizeInstance{};
auto layerNorm_invoker_ptr = layerNorm.MakeInvokerPointer(); auto normalize_invoker_ptr = normalize.MakeInvokerPointer();
auto layerNorm_argument = auto normalize_argument =
layerNorm.MakeArgumentPointer(c_device_buf.GetDeviceBuffer(), normalize.MakeArgumentPointer(c_device_buf.GetDeviceBuffer(),
reduceMean_device_buf.GetDeviceBuffer(), reduceMean_device_buf.GetDeviceBuffer(),
reduceMeanSquare_device_buf.GetDeviceBuffer(), reduceMeanSquare_device_buf.GetDeviceBuffer(),
gamma_device_buf.GetDeviceBuffer(), gamma_device_buf.GetDeviceBuffer(),
...@@ -292,7 +292,7 @@ int main() ...@@ -292,7 +292,7 @@ int main()
{StrideC, 1}, {StrideC, 1},
NormalizeFunctor{}); NormalizeFunctor{});
if(!layerNorm.IsSupportedArgument(layerNorm_argument.get())) if(!normalize.IsSupportedArgument(normalize_argument.get()))
{ {
throw std::runtime_error("The runtime parameters seems not supported by the " throw std::runtime_error("The runtime parameters seems not supported by the "
"Device5AryElementwise_Xdl_CShuffle instance, exiting!"); "Device5AryElementwise_Xdl_CShuffle instance, exiting!");
...@@ -300,7 +300,7 @@ int main() ...@@ -300,7 +300,7 @@ int main()
// run kernel // run kernel
gemmReduce_invoker.Run(gemmReduce_argument, StreamConfig{nullptr, time_kernel}); gemmReduce_invoker.Run(gemmReduce_argument, StreamConfig{nullptr, time_kernel});
layerNorm_invoker_ptr->Run(layerNorm_argument.get(), StreamConfig{nullptr, time_kernel}); normalize_invoker_ptr->Run(normalize_argument.get(), StreamConfig{nullptr, time_kernel});
bool pass = true; bool pass = true;
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment