"git@developer.sourcefind.cn:kecinstone/2024-pra-vllm.git" did not exist on "ba84b8728a8d0a766a636b30661836c30b17fbe6"
test_nvrtc.cpp 2.91 KB
Newer Older
Tim Moon's avatar
Tim Moon committed
1
/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <stdexcept>
#include <vector>

#include <gtest/gtest.h>

#include "util/rtc.h"

using namespace transformer_engine;

// Exercises rtc::KernelManager: kernels are unavailable (and launching them
// throws std::runtime_error) until compiled, and two kernels that share the
// entry point name "my_kernel" can be compiled under different labels and
// launched independently. Each launch writes two values into a small device
// buffer, which is copied back to the host and checked.
TEST(UtilTest, NVRTC) {
  if (!rtc::is_enabled()) {
    GTEST_SKIP() << "NVRTC not enabled, skipping tests";
  }

  // GPU data buffer (two ints) and its host-side mirror
  std::vector<int> host_buffer(2);
  const auto buffer_bytes = host_buffer.size() * sizeof(int);
  int *device_buffer = nullptr;
  // Allocation failure is fatal: nothing below is meaningful without the
  // buffer, and there is nothing to free yet.
  ASSERT_EQ(cudaMalloc(reinterpret_cast<void**>(&device_buffer), buffer_bytes),
            cudaSuccess);
  // Non-fatal on purpose so the cudaFree at the end of the test still runs.
  EXPECT_EQ(cudaMemset(device_buffer, 0, buffer_bytes), cudaSuccess);

  // CUDA kernel implementations
  // First kernel views the buffer as a single int2 and writes (123, -456).
  const char code1[] = R"code(
#include <cuda_runtime.h>
__global__ void my_kernel(int2 *data) {
  data->x = 123;
  data->y = -456;
}
)code";
  // Second kernel includes "utils.cuh" and writes (789, 12) as uint32_t.
  const char code2[] = R"code(
#include "utils.cuh"
__global__ void my_kernel(uint32_t *data) {
  data[0] = 789;
  data[1] = 12;
}
)code";

  // Make sure kernels are not available
  // NOTE(review): the launch arguments "1, 1, 0, 0" are presumably grid size,
  // block size, shared memory bytes, and stream -- confirm against util/rtc.h.
  auto& nvrtc_manager = rtc::KernelManager::instance();
  EXPECT_FALSE(nvrtc_manager.is_compiled("my gtest kernel1"));
  EXPECT_FALSE(nvrtc_manager.is_compiled("my gtest kernel2"));
  EXPECT_THROW(nvrtc_manager.launch("my gtest kernel1", 1, 1, 0, 0,
                                    device_buffer),
               std::runtime_error);
  EXPECT_THROW(nvrtc_manager.launch("my gtest kernel2", 1, 1, 0, 0,
                                    device_buffer),
               std::runtime_error);

  // Compile and run first kernel
  EXPECT_NO_THROW(nvrtc_manager.compile("my gtest kernel1",
                                        "my_kernel",
                                        code1,
                                        "test_nvrtc_kernel1.cu"));
  EXPECT_TRUE(nvrtc_manager.is_compiled("my gtest kernel1"));
  EXPECT_FALSE(nvrtc_manager.is_compiled("my gtest kernel2"));
  EXPECT_NO_THROW(nvrtc_manager.launch("my gtest kernel1", 1, 1, 0, 0,
                                       device_buffer));
  EXPECT_EQ(cudaMemcpy(host_buffer.data(), device_buffer, buffer_bytes,
                       cudaMemcpyDeviceToHost),
            cudaSuccess);
  EXPECT_EQ(host_buffer[0], 123);
  EXPECT_EQ(host_buffer[1], -456);

  // Compile and run second kernel
  EXPECT_NO_THROW(nvrtc_manager.compile("my gtest kernel2",
                                        "my_kernel",
                                        code2,
                                        "test_nvrtc_kernel2.cu"));
  EXPECT_TRUE(nvrtc_manager.is_compiled("my gtest kernel1"));
  EXPECT_TRUE(nvrtc_manager.is_compiled("my gtest kernel2"));
  EXPECT_NO_THROW(nvrtc_manager.launch("my gtest kernel2", 1, 1, 0, 0, device_buffer));
  EXPECT_EQ(cudaMemcpy(host_buffer.data(), device_buffer, buffer_bytes,
                       cudaMemcpyDeviceToHost),
            cudaSuccess);
  EXPECT_EQ(host_buffer[0], 789);
  EXPECT_EQ(host_buffer[1], 12);

  // Release the device allocation. Every check after the allocation is a
  // non-fatal EXPECT_*, so control always reaches this point.
  EXPECT_EQ(cudaFree(device_buffer), cudaSuccess);
}