hip_compat.h 1.46 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
#pragma once

/**
 * HIP 到 CUDA 的兼容层
 * 
 * 使用 nvcc 编译时,自动将 HIP API 映射到 CUDA API
 * 使用 hipcc 编译时,使用原生 HIP 头文件
 */

#if defined(__NVCC__) || defined(__CUDACC__)

#include <cuda_runtime.h>
#include <cuda_bf16.h>
#include <stdio.h>

// Runtime API 映射
#define hipMalloc cudaMalloc
#define hipFree cudaFree
#define hipMemcpy cudaMemcpy
#define hipMemcpyHostToDevice cudaMemcpyHostToDevice
#define hipMemcpyDeviceToHost cudaMemcpyDeviceToHost
#define hipMemset cudaMemset
#define hipDeviceSynchronize cudaDeviceSynchronize
#define hipGetDeviceProperties cudaGetDeviceProperties
#define hipGetErrorString cudaGetErrorString

// Event API 映射
#define hipEvent_t cudaEvent_t
#define hipEventCreate cudaEventCreate
#define hipEventDestroy cudaEventDestroy
#define hipEventRecord cudaEventRecord
#define hipEventSynchronize cudaEventSynchronize
#define hipEventElapsedTime cudaEventElapsedTime

// 数据类型映射
#define hipDeviceProp_t cudaDeviceProp
#define hipError_t cudaError_t
#define hipSuccess cudaSuccess
// CUDA 使用 __nv_bfloat16,HIP 使用 hip_bfloat16
typedef __nv_bfloat16 hip_bfloat16;

// Shuffle 指令映射
// CUDA 9.0+ 需要使用带 _sync 后缀的版本,并传入 warp mask
// 0xffffffff 表示整个 warp 的所有线程都参与
#ifndef __shfl_down
#define __shfl_down(val, offset) __shfl_down_sync(0xffffffff, val, offset)
#endif

#else

#include <hip/hip_runtime.h>
#include <hip/hip_bfloat16.h>

#endif