Commit 8166d875 authored by rocking

Modify tests, device instances, and the client example: gamma/beta strides become broadcast strides {0, 1}, and the layernorm instances gain GammaSrcVectorDim and BetaSrcVectorDim template parameters.

parent 12673f3f
@@ -81,8 +81,8 @@ int main(int argc, char* argv[])
     auto argument_ptr = op_ptr->MakeArgumentPointer({M, N},      // lengths
                                                     {Stride, 1}, // xStrides
-                                                    {1},         // gammaStrides
-                                                    {1},         // betaStrides
+                                                    {0, 1},      // gammaStrides
+                                                    {0, 1},      // betaStrides
                                                     {Stride, 1}, // yStrides
                                                     {1},         // reduceDims
                                                     1e-4,
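The change above replaces gamma/beta strides {1} with {0, 1}: gamma and beta are now described as rank-2 tensors whose leading (M) stride is 0, so the same length-N vector is broadcast across every row of the [M, N] input. A minimal standalone sketch of what a zero stride means for indexing (plain C++, independent of the CK tensor descriptors):

```cpp
#include <cstddef>
#include <vector>

// Linear offset of element (m, n) in a 2-D view with strides {s0, s1}.
// With s0 == 0 every row m resolves to the same underlying data, which
// is how a single length-N gamma/beta vector is shared by all M rows.
std::size_t offset(std::size_t m, std::size_t n, std::size_t s0, std::size_t s1)
{
    return m * s0 + n * s1;
}

int main()
{
    std::vector<float> gamma{1.f, 2.f, 3.f};           // length N = 3
    float a = gamma[offset(0, 1, /*s0=*/0, /*s1=*/1)]; // row 0 reads gamma[1]
    float b = gamma[offset(5, 1, /*s0=*/0, /*s1=*/1)]; // row 5 reads the same gamma[1]
    return a == b ? 0 : 1;                             // always 0
}
```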
@@ -20,18 +20,18 @@ using Pass = ck::tensor_operation::element_wise::PassThrough;
 template <index_t Rank, index_t Reduce>
 using device_layernorm_f16_instances = std::tuple<
     // clang-format off
-    // XDataType, GammaDataType, BetaDataType, AccDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize>
-    DeviceLayernormImpl<F16, F16, F16, F32, F16, Pass, Rank, Reduce, 256, 8, 32, 1, 8, 1, 1, 1, 1, 1>, // fallback kernel
-    DeviceLayernormImpl<F16, F16, F16, F32, F16, Pass, Rank, Reduce, 256, 8, 32, 1, 8, 1, 2, 2, 2, 2>, // fallback kernel
-    DeviceLayernormImpl<F16, F16, F16, F32, F16, Pass, Rank, Reduce, 256, 8, 32, 1, 8, 1, 4, 4, 4, 4>, // fallback kernel
-    DeviceLayernormImpl<F16, F16, F16, F32, F16, Pass, Rank, Reduce, 256, 8, 32, 1, 8, 1, 8, 8, 8, 8>,
-    DeviceLayernormImpl<F16, F16, F16, F32, F16, Pass, Rank, Reduce, 256, 4, 64, 1, 8, 1, 8, 8, 8, 8>,
-    DeviceLayernormImpl<F16, F16, F16, F32, F16, Pass, Rank, Reduce, 256, 2, 128, 1, 8, 1, 8, 8, 8, 8>,
-    DeviceLayernormImpl<F16, F16, F16, F32, F16, Pass, Rank, Reduce, 256, 2, 128, 1, 16, 1, 8, 8, 8, 8>,
-    DeviceLayernormImpl<F16, F16, F16, F32, F16, Pass, Rank, Reduce, 256, 2, 128, 1, 32, 1, 8, 8, 8, 8>,
-    DeviceLayernormImpl<F16, F16, F16, F32, F16, Pass, Rank, Reduce, 256, 1, 256, 1, 8, 1, 8, 8, 8, 8>,
-    DeviceLayernormImpl<F16, F16, F16, F32, F16, Pass, Rank, Reduce, 256, 1, 256, 1, 16, 1, 8, 8, 8, 8>,
-    DeviceLayernormImpl<F16, F16, F16, F32, F16, Pass, Rank, Reduce, 256, 1, 256, 1, 32, 1, 8, 8, 8, 8>
+    // XDataType, GammaDataType, BetaDataType, AccDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize>
+    DeviceLayernormImpl<F16, F16, F16, F32, F16, Pass, Rank, Reduce, 256, 8, 32, 1, 8, 1, 1, 1, 1, 1, 1, 1>, // fallback kernel
+    DeviceLayernormImpl<F16, F16, F16, F32, F16, Pass, Rank, Reduce, 256, 8, 32, 1, 8, 1, 2, 1, 2, 1, 2, 2>, // fallback kernel
+    DeviceLayernormImpl<F16, F16, F16, F32, F16, Pass, Rank, Reduce, 256, 8, 32, 1, 8, 1, 4, 1, 4, 1, 4, 4>, // fallback kernel
+    DeviceLayernormImpl<F16, F16, F16, F32, F16, Pass, Rank, Reduce, 256, 8, 32, 1, 8, 1, 8, 1, 8, 1, 8, 8>,
+    DeviceLayernormImpl<F16, F16, F16, F32, F16, Pass, Rank, Reduce, 256, 4, 64, 1, 8, 1, 8, 1, 8, 1, 8, 8>,
+    DeviceLayernormImpl<F16, F16, F16, F32, F16, Pass, Rank, Reduce, 256, 2, 128, 1, 8, 1, 8, 1, 8, 1, 8, 8>,
+    DeviceLayernormImpl<F16, F16, F16, F32, F16, Pass, Rank, Reduce, 256, 2, 128, 1, 16, 1, 8, 1, 8, 1, 8, 8>,
+    DeviceLayernormImpl<F16, F16, F16, F32, F16, Pass, Rank, Reduce, 256, 2, 128, 1, 32, 1, 8, 1, 8, 1, 8, 8>,
+    DeviceLayernormImpl<F16, F16, F16, F32, F16, Pass, Rank, Reduce, 256, 1, 256, 1, 8, 1, 8, 1, 8, 1, 8, 8>,
+    DeviceLayernormImpl<F16, F16, F16, F32, F16, Pass, Rank, Reduce, 256, 1, 256, 1, 16, 1, 8, 1, 8, 1, 8, 8>,
+    DeviceLayernormImpl<F16, F16, F16, F32, F16, Pass, Rank, Reduce, 256, 1, 256, 1, 32, 1, 8, 1, 8, 1, 8, 8>
     // clang-format on
     >;
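Every instance gains two parameters, GammaSrcVectorDim and BetaSrcVectorDim, so the gamma/beta load direction is specified independently of XYSrcVectorDim; all instances set them to 1, the reduce dimension, which is the only dimension in which the broadcast gamma/beta tensors have stride 1. A hedged sketch of the contiguity/divisibility constraint a vectorized load along a chosen dimension has to satisfy (hypothetical helper, not a CK API):

```cpp
#include <cstdint>

// Hypothetical illustration: loading vector_size elements at a time along a
// dimension is only legal if that dimension is contiguous (stride == 1) and
// its length divides evenly into whole vectors.
bool is_valid_vector_access(std::int64_t length, std::int64_t stride, std::int64_t vector_size)
{
    return stride == 1 && length % vector_size == 0;
}

// Example: gamma viewed with lengths {M, N} and strides {0, 1} can use
// vector size 8 along dim 1 whenever N % 8 == 0; along dim 0 (stride 0)
// no vectorization is possible, hence GammaSrcVectorDim = 1 here.
```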
@@ -20,16 +20,16 @@ template <index_t Rank, index_t Reduce>
 using device_layernorm_f32_instances = std::tuple<
     // clang-format off
     // XDataType, GammaDataType, BetaDataType, AccDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize>
-    DeviceLayernormImpl<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 8, 32, 1, 8, 1, 1, 1, 1, 1>, // fallback kernel
-    DeviceLayernormImpl<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 8, 32, 1, 8, 1, 2, 2, 2, 2>, // fallback kernel
-    DeviceLayernormImpl<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 8, 32, 1, 8, 1, 4, 4, 4, 4>,
-    DeviceLayernormImpl<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 4, 64, 1, 8, 1, 4, 4, 4, 4>,
-    DeviceLayernormImpl<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 2, 128, 1, 8, 1, 4, 4, 4, 4>,
-    DeviceLayernormImpl<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 2, 128, 1, 16, 1, 4, 4, 4, 4>,
-    DeviceLayernormImpl<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 2, 128, 1, 32, 1, 4, 4, 4, 4>,
-    DeviceLayernormImpl<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 1, 256, 1, 8, 1, 4, 4, 4, 4>,
-    DeviceLayernormImpl<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 1, 256, 1, 16, 1, 4, 4, 4, 4>,
-    DeviceLayernormImpl<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 1, 256, 1, 32, 1, 4, 4, 4, 4>
+    DeviceLayernormImpl<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 8, 32, 1, 8, 1, 1, 1, 1, 1, 1, 1>, // fallback kernel
+    DeviceLayernormImpl<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 8, 32, 1, 8, 1, 2, 1, 2, 1, 2, 2>, // fallback kernel
+    DeviceLayernormImpl<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 8, 32, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
+    DeviceLayernormImpl<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 4, 64, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
+    DeviceLayernormImpl<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 2, 128, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
+    DeviceLayernormImpl<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 2, 128, 1, 16, 1, 4, 1, 4, 1, 4, 4>,
+    DeviceLayernormImpl<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 2, 128, 1, 32, 1, 4, 1, 4, 1, 4, 4>,
+    DeviceLayernormImpl<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 1, 256, 1, 8, 1, 4, 1, 4, 1, 4, 4>,
+    DeviceLayernormImpl<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 1, 256, 1, 16, 1, 4, 1, 4, 1, 4, 4>,
+    DeviceLayernormImpl<F32, F32, F32, F32, F32, Pass, Rank, Reduce, 256, 1, 256, 1, 32, 1, 4, 1, 4, 1, 4, 4>
     // clang-format on
     >;
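The F32 list caps the vector sizes at 4 where the F16 list uses 8; both choices amount to the same 16 bytes per vectorized access, which is the likely rationale (an observation, not stated in the commit):

```cpp
#include <cstdint>

using half_stand_in = std::uint16_t; // 2-byte stand-in for ck::half_t

static_assert(8 * sizeof(half_stand_in) == 16, "8 x f16 -> 16-byte access");
static_assert(4 * sizeof(float) == 16, "4 x f32 -> 16-byte access");

int main() { return 0; }
```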
@@ -14,15 +14,15 @@ class TestLayernormFP16 : public ck::TestLayernorm<Tuple>
 // clang-format off
 using KernelTypes = ::testing::Types<
-    // XDataType, GammaDataType, BetaDataType, AccDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize>
-    std::tuple<ck::half_t, ck::half_t, ck::half_t, float, ck::half_t, I<2>, I<1>, I<256>, I<8>, I<32>, I<1>, I<8>, I<1>, I<8>, I<8>, I<8>, I<8>>,
-    std::tuple<ck::half_t, ck::half_t, ck::half_t, float, ck::half_t, I<2>, I<1>, I<256>, I<8>, I<32>, I<2>, I<8>, I<1>, I<8>, I<8>, I<8>, I<8>>,
-    std::tuple<ck::half_t, ck::half_t, ck::half_t, float, ck::half_t, I<2>, I<1>, I<256>, I<4>, I<64>, I<1>, I<8>, I<1>, I<8>, I<8>, I<8>, I<8>>,
-    std::tuple<ck::half_t, ck::half_t, ck::half_t, float, ck::half_t, I<2>, I<1>, I<256>, I<4>, I<64>, I<2>, I<8>, I<1>, I<8>, I<8>, I<8>, I<8>>,
-    std::tuple<ck::half_t, ck::half_t, ck::half_t, float, ck::half_t, I<2>, I<1>, I<256>, I<2>, I<128>, I<1>, I<8>, I<1>, I<8>, I<8>, I<8>, I<8>>,
-    std::tuple<ck::half_t, ck::half_t, ck::half_t, float, ck::half_t, I<2>, I<1>, I<256>, I<2>, I<128>, I<2>, I<8>, I<1>, I<8>, I<8>, I<8>, I<8>>,
-    std::tuple<ck::half_t, ck::half_t, ck::half_t, float, ck::half_t, I<2>, I<1>, I<256>, I<1>, I<256>, I<1>, I<8>, I<1>, I<8>, I<8>, I<8>, I<8>>,
-    std::tuple<ck::half_t, ck::half_t, ck::half_t, float, ck::half_t, I<2>, I<1>, I<256>, I<1>, I<256>, I<2>, I<8>, I<1>, I<8>, I<8>, I<8>, I<8>>
+    // XDataType, GammaDataType, BetaDataType, AccDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize>
+    std::tuple<ck::half_t, ck::half_t, ck::half_t, float, ck::half_t, I<2>, I<1>, I<256>, I<8>, I<32>, I<1>, I<8>, I<1>, I<8>, I<1>, I<8>, I<1>, I<8>, I<8>>,
+    std::tuple<ck::half_t, ck::half_t, ck::half_t, float, ck::half_t, I<2>, I<1>, I<256>, I<8>, I<32>, I<2>, I<8>, I<1>, I<8>, I<1>, I<8>, I<1>, I<8>, I<8>>,
+    std::tuple<ck::half_t, ck::half_t, ck::half_t, float, ck::half_t, I<2>, I<1>, I<256>, I<4>, I<64>, I<1>, I<8>, I<1>, I<8>, I<1>, I<8>, I<1>, I<8>, I<8>>,
+    std::tuple<ck::half_t, ck::half_t, ck::half_t, float, ck::half_t, I<2>, I<1>, I<256>, I<4>, I<64>, I<2>, I<8>, I<1>, I<8>, I<1>, I<8>, I<1>, I<8>, I<8>>,
+    std::tuple<ck::half_t, ck::half_t, ck::half_t, float, ck::half_t, I<2>, I<1>, I<256>, I<2>, I<128>, I<1>, I<8>, I<1>, I<8>, I<1>, I<8>, I<1>, I<8>, I<8>>,
+    std::tuple<ck::half_t, ck::half_t, ck::half_t, float, ck::half_t, I<2>, I<1>, I<256>, I<2>, I<128>, I<2>, I<8>, I<1>, I<8>, I<1>, I<8>, I<1>, I<8>, I<8>>,
+    std::tuple<ck::half_t, ck::half_t, ck::half_t, float, ck::half_t, I<2>, I<1>, I<256>, I<1>, I<256>, I<1>, I<8>, I<1>, I<8>, I<1>, I<8>, I<1>, I<8>, I<8>>,
+    std::tuple<ck::half_t, ck::half_t, ck::half_t, float, ck::half_t, I<2>, I<1>, I<256>, I<1>, I<256>, I<2>, I<8>, I<1>, I<8>, I<1>, I<8>, I<1>, I<8>, I<8>>
     >;
 // clang-format on
 TYPED_TEST_SUITE(TestLayernormFP16, KernelTypes);
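The I<...> wrappers lift integers to types so a full kernel configuration can ride through ::testing::Types and be read back via std::tuple_element_t<...>{}.value, as the fixture below does. A minimal sketch, assuming I behaves like std::integral_constant over ck::index_t (consistent with the .value accesses, though the exact alias lives elsewhere in the repo):

```cpp
#include <tuple>
#include <type_traits>

using index_t = int; // stand-in for ck::index_t

template <index_t N>
using I = std::integral_constant<index_t, N>;

// A configuration travels as a type...
using Cfg = std::tuple<I<256>, I<8>, I<32>>;

// ...and each field is recovered as a compile-time value.
static_assert(std::tuple_element_t<0, Cfg>{}.value == 256, "BlockSize");
static_assert(std::tuple_element_t<1, Cfg>{}.value == 8, "MThreadClusterSize");
```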
@@ -14,15 +14,15 @@ class TestLayernormFP32 : public ck::TestLayernorm<Tuple>
 // clang-format off
 using KernelTypes = ::testing::Types<
-    // XDataType, GammaDataType, BetaDataType, AccDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorSize, BetaSrcVectorSize, YDstVectorSize>
-    std::tuple<float, float, float, float, float, I<2>, I<1>, I<256>, I<8>, I<32>, I<1>, I<8>, I<1>, I<4>, I<4>, I<4>, I<4>>,
-    std::tuple<float, float, float, float, float, I<2>, I<1>, I<256>, I<8>, I<32>, I<2>, I<8>, I<1>, I<4>, I<4>, I<4>, I<4>>,
-    std::tuple<float, float, float, float, float, I<2>, I<1>, I<256>, I<4>, I<64>, I<1>, I<8>, I<1>, I<4>, I<4>, I<4>, I<4>>,
-    std::tuple<float, float, float, float, float, I<2>, I<1>, I<256>, I<4>, I<64>, I<2>, I<8>, I<1>, I<4>, I<4>, I<4>, I<4>>,
-    std::tuple<float, float, float, float, float, I<2>, I<1>, I<256>, I<2>, I<128>, I<1>, I<8>, I<1>, I<4>, I<4>, I<4>, I<4>>,
-    std::tuple<float, float, float, float, float, I<2>, I<1>, I<256>, I<2>, I<128>, I<2>, I<8>, I<1>, I<4>, I<4>, I<4>, I<4>>,
-    std::tuple<float, float, float, float, float, I<2>, I<1>, I<256>, I<1>, I<256>, I<1>, I<8>, I<1>, I<4>, I<4>, I<4>, I<4>>,
-    std::tuple<float, float, float, float, float, I<2>, I<1>, I<256>, I<1>, I<256>, I<2>, I<8>, I<1>, I<4>, I<4>, I<4>, I<4>>
+    // XDataType, GammaDataType, BetaDataType, AccDataType, YDataType, Rank, NumReduceDim, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYSrcVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, YDstVectorSize>
+    std::tuple<float, float, float, float, float, I<2>, I<1>, I<256>, I<8>, I<32>, I<1>, I<8>, I<1>, I<4>, I<1>, I<4>, I<1>, I<4>, I<4>>,
+    std::tuple<float, float, float, float, float, I<2>, I<1>, I<256>, I<8>, I<32>, I<2>, I<8>, I<1>, I<4>, I<1>, I<4>, I<1>, I<4>, I<4>>,
+    std::tuple<float, float, float, float, float, I<2>, I<1>, I<256>, I<4>, I<64>, I<1>, I<8>, I<1>, I<4>, I<1>, I<4>, I<1>, I<4>, I<4>>,
+    std::tuple<float, float, float, float, float, I<2>, I<1>, I<256>, I<4>, I<64>, I<2>, I<8>, I<1>, I<4>, I<1>, I<4>, I<1>, I<4>, I<4>>,
+    std::tuple<float, float, float, float, float, I<2>, I<1>, I<256>, I<2>, I<128>, I<1>, I<8>, I<1>, I<4>, I<1>, I<4>, I<1>, I<4>, I<4>>,
+    std::tuple<float, float, float, float, float, I<2>, I<1>, I<256>, I<2>, I<128>, I<2>, I<8>, I<1>, I<4>, I<1>, I<4>, I<1>, I<4>, I<4>>,
+    std::tuple<float, float, float, float, float, I<2>, I<1>, I<256>, I<1>, I<256>, I<1>, I<8>, I<1>, I<4>, I<1>, I<4>, I<1>, I<4>, I<4>>,
+    std::tuple<float, float, float, float, float, I<2>, I<1>, I<256>, I<1>, I<256>, I<2>, I<8>, I<1>, I<4>, I<1>, I<4>, I<1>, I<4>, I<4>>
     >;
 // clang-format on
 TYPED_TEST_SUITE(TestLayernormFP32, KernelTypes);
@@ -48,9 +48,11 @@ class TestLayernorm : public ::testing::Test
     static constexpr index_t KThreadSliceSize = std::tuple_element_t<11, Tuple>{}.value;
     static constexpr index_t XYSrcVectorDim = std::tuple_element_t<12, Tuple>{}.value;
     static constexpr index_t XSrcVectorSize = std::tuple_element_t<13, Tuple>{}.value;
-    static constexpr index_t GammaSrcVectorSize = std::tuple_element_t<14, Tuple>{}.value;
-    static constexpr index_t BetaSrcVectorSize = std::tuple_element_t<15, Tuple>{}.value;
-    static constexpr index_t YDstVectorSize = std::tuple_element_t<16, Tuple>{}.value;
+    static constexpr index_t GammaSrcVectorDim = std::tuple_element_t<14, Tuple>{}.value;
+    static constexpr index_t GammaSrcVectorSize = std::tuple_element_t<15, Tuple>{}.value;
+    static constexpr index_t BetaSrcVectorDim = std::tuple_element_t<16, Tuple>{}.value;
+    static constexpr index_t BetaSrcVectorSize = std::tuple_element_t<17, Tuple>{}.value;
+    static constexpr index_t YDstVectorSize = std::tuple_element_t<18, Tuple>{}.value;

     using PassThrough = ck::tensor_operation::element_wise::PassThrough;
@@ -78,7 +80,9 @@ class TestLayernorm : public ::testing::Test
                             KThreadSliceSize,
                             XYSrcVectorDim,
                             XSrcVectorSize,
+                            GammaSrcVectorDim,
                             GammaSrcVectorSize,
+                            BetaSrcVectorDim,
                             BetaSrcVectorSize,
                             YDstVectorSize>;
@@ -115,10 +119,8 @@ class TestLayernorm : public ::testing::Test
         auto argument_ptr = device_instance.MakeArgumentPointer(
             lengths,
             std::vector<ck::index_t>{x.mDesc.GetStrides().begin(), x.mDesc.GetStrides().end()},
-            std::vector<ck::index_t>{gamma.mDesc.GetStrides().begin(),
-                                     gamma.mDesc.GetStrides().end()},
-            std::vector<ck::index_t>{beta.mDesc.GetStrides().begin(),
-                                     beta.mDesc.GetStrides().end()},
+            {0, 1},
+            {0, 1},
             std::vector<ck::index_t>{y.mDesc.GetStrides().begin(), y.mDesc.GetStrides().end()},
             reduceDims,
             1e-4,
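With the broadcast strides in place, what the test verifies is standard row-wise layernorm: each length-N row of x is normalized with its own mean and variance, then scaled by the shared gamma and shifted by the shared beta. A host-side reference sketch of that semantics (plain C++ with hypothetical names, not the repo's verification code):

```cpp
#include <cmath>
#include <vector>

// y[m][n] = gamma[n] * (x[m][n] - mean_m) / sqrt(var_m + eps) + beta[n]
void reference_layernorm(const std::vector<float>& x,     // M x N, row-major
                         const std::vector<float>& gamma, // length N, broadcast over rows
                         const std::vector<float>& beta,  // length N, broadcast over rows
                         std::vector<float>& y,
                         int M,
                         int N,
                         float eps = 1e-4f) // matches the 1e-4 epsilon passed above
{
    for(int m = 0; m < M; ++m)
    {
        float mean = 0.f;
        for(int n = 0; n < N; ++n)
            mean += x[m * N + n];
        mean /= N;

        float var = 0.f;
        for(int n = 0; n < N; ++n)
        {
            const float d = x[m * N + n] - mean;
            var += d * d;
        }
        var /= N;

        for(int n = 0; n < N; ++n)
            y[m * N + n] =
                gamma[n] * (x[m * N + n] - mean) / std::sqrt(var + eps) + beta[n];
    }
}
```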