// utils.cpp (1019 bytes)
/*************************************************************************
 * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/
#include <cuda_runtime_api.h>
#include <cassert>

#include "utils.h"

namespace transformer_engine {
namespace jax {

/*
 * Query the version of the CUDA runtime library used by this process.
 *
 * Returns the integer written by cudaRuntimeGetVersion, which the CUDA runtime
 * API documents as 1000 * major + 10 * minor (e.g. 12020 for CUDA 12.2).
 * A failing runtime call is reported through NVTE_CHECK_CUDA (defined outside
 * this file; presumably raises an error -- see utils.h).
 */
int GetCudaRuntimeVersion() {
    int ver = 0;
    NVTE_CHECK_CUDA(cudaRuntimeGetVersion(&ver));
    return ver;
}

/*
 * Return the compute capability of CUDA device `gpu_id` encoded as
 * major * 10 + minor (e.g. 80 for compute capability 8.0, 90 for 9.0).
 *
 * gpu_id: CUDA device ordinal; must satisfy 0 <= gpu_id < cudaGetDeviceCount().
 *
 * The range check below is a debug-only assert. In NDEBUG builds an invalid
 * ordinal is not caught here. Instead, cudaDeviceGetAttribute rejects it
 * (cudaErrorInvalidDevice per the CUDA runtime docs), and that failure is
 * reported through NVTE_CHECK_CUDA.
 */
int GetDeviceComputeCapability(int gpu_id) {
    int max_num_gpu = 0;
    NVTE_CHECK_CUDA(cudaGetDeviceCount(&max_num_gpu));
    // A negative ordinal is just as invalid as one past the end; check both bounds.
    assert(gpu_id >= 0 && gpu_id < max_num_gpu);

    int major = 0;
    NVTE_CHECK_CUDA(cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, gpu_id));

    int minor = 0;
    NVTE_CHECK_CUDA(cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, gpu_id));

    // Pack as a two-digit SM number (major.minor -> major*10 + minor).
    int gpu_arch = major * 10 + minor;
    return gpu_arch;
}

}  // namespace jax
}  // namespace transformer_engine