// test_transpose.cu (2.87 KB) — committed by Przemek Tredak.
/*************************************************************************
 * Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <transformer_engine/transpose.h>
#include <transformer_engine/logging.h>
#include <gtest/gtest.h>
#include <cuda_runtime.h>
#include <cuda_bf16.h>
#include <memory>
#include <iostream>
#include <iomanip>
#include <random>
#include <cstring>
#include "../test_common.h"

using namespace transformer_engine;

namespace {

// CPU reference transpose. Reads `data` as an N x H row-major matrix and
// writes its H x N transpose into `output`:
//   output[j * N + i] = data[i * H + j]
// Both buffers must hold at least N * H elements and must not overlap (the
// loop reads `data` while writing `output`, so in-place use is not supported).
template <typename Type>
void compute_ref(const Type *data,  Type *output,
                 const size_t N, const size_t H) {
  for (size_t i = 0; i < N; ++i) {
    for (size_t j = 0; j < H; ++j) {
      output[j * N + i] = data[i * H + j];
    }
  }
}

// Fills an N x H input tensor of element type `Type` via fillUniform, runs
// nvte_transpose into an H x N output tensor, and compares the result against
// compute_ref using the tolerances returned by getTolerances for the dtype.
// Fatal gtest assertions (ASSERT_*) abort this function on CUDA errors.
template <typename Type>
void performTest(const size_t N, const size_t H) {
  using namespace test;

  DType dtype = TypeInfo<Type>::dtype;

  Tensor input({ N, H }, dtype);
  Tensor output({ H, N }, dtype);

  std::unique_ptr<Type[]> ref_output = std::make_unique<Type[]>(N * H);

  fillUniform(&input);

  // NOTE(review): third argument is presumably the CUDA stream (0 = default
  // stream) — confirm against transformer_engine/transpose.h.
  nvte_transpose(input.data(), output.data(), 0);

  // The CPU reference is computed before synchronizing so that it can overlap
  // with any device work that is still in flight.
  compute_ref<Type>(input.cpu_dptr<Type>(), ref_output.get(), N, H);

  // Check the synchronization result itself (reports asynchronous execution
  // failures), then any launch error still recorded by the runtime.
  cudaError_t err = cudaDeviceSynchronize();
  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
  err = cudaGetLastError();
  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

  auto [atol, rtol] = getTolerances(dtype);
  compareResults("output", output, ref_output.get(), atol, rtol);
}

// Input shapes exercised by the suite, as (N, H) pairs: each case transposes
// an N x H input into an H x N output.
std::vector<std::pair<size_t, size_t>> test_cases = {{2048, 12288},
                                                     {768, 1024},
                                                     {256, 65536},
                                                     {65536, 128},
                                                     {256, 256},
                                                     {120, 2080},
                                                     {8, 8}};
}  // namespace

// Value-parameterized fixture. Each parameter pairs an element DType with an
// (N, H) input shape; the fixture itself adds no state or setup.
class TTestSuite
    : public ::testing::TestWithParam<
          std::tuple<transformer_engine::DType, std::pair<size_t, size_t>>> {};

// Dispatches performTest on the C++ element type matching the parameter's
// DType, using the parameter's (N, H) shape.
TEST_P(TTestSuite, TestTranspose) {
  using namespace transformer_engine;
  using namespace test;

  const auto &param = GetParam();
  const DType type = std::get<0>(param);
  const size_t rows = std::get<1>(param).first;
  const size_t cols = std::get<1>(param).second;

  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(type, T,
    performTest<T>(rows, cols);
  );
}



// Instantiates the suite over the Cartesian product of every floating-point
// dtype in test::all_fp_types and every shape in test_cases.
INSTANTIATE_TEST_SUITE_P(
    OperatorTest, TTestSuite,
    ::testing::Combine(::testing::ValuesIn(test::all_fp_types),
                       ::testing::ValuesIn(test_cases)),
    // Generated test names have the form <typeName>X<N>X<H>.
    [](const testing::TestParamInfo<TTestSuite::ParamType>& info) {
      const auto &shape = std::get<1>(info.param);
      std::string label = test::typeName(std::get<0>(info.param));
      label += "X" + std::to_string(shape.first);
      label += "X" + std::to_string(shape.second);
      return label;
    });