// strided_batched_gemm.h
#include <iostream>
#include <vector>

#include <cuda.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>

// rocBLAS types and entry points used below; on older ROCm the path is <rocblas.h>
#include <rocblas/rocblas.h>
#include <climits> // INT_MAX

//#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>

//#include "cutlass/cutlass.h"
//#include "cutlass/gemm/gemm.h"
//#include "cutlass/gemm/wmma_gemm_traits.h"

// symbol to be automatically resolved by PyTorch libs

// Defaults for rocblas_gemm_strided_batched_ex: fp16 storage for A/B/C/D with
// fp32 accumulation, the standard algorithm, and no extra flags.
rocblas_datatype a_type       = rocblas_datatype_f16_r;
rocblas_datatype b_type       = rocblas_datatype_f16_r;
rocblas_datatype c_type       = rocblas_datatype_f16_r;
rocblas_datatype d_type       = rocblas_datatype_f16_r;
rocblas_datatype compute_type = rocblas_datatype_f32_r;

rocblas_gemm_algo algo           = rocblas_gemm_algo_standard;
int32_t           solution_index = 0;
rocblas_int       flags          = 0;


cublasOperation_t convertTransToCublasOperation(char trans) {
  if (trans == 't')
    return CUBLAS_OP_T;
  else if (trans == 'n')
    return CUBLAS_OP_N;
  else if (trans == 'c')
    return CUBLAS_OP_C;
  else {
    AT_ERROR("trans must be one of: t, n, c");
    return CUBLAS_OP_T; // unreachable; silences missing-return warnings
  }
}

void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k,
                               float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
                               float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD,
                               long batchCount, rocblas_gemm_algo algo, rocblas_int flags) {
    cublasOperation_t opa = convertTransToCublasOperation(transa);
    cublasOperation_t opb = convertTransToCublasOperation(transb);

    cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
    cudaStream_t   stream = at::cuda::getCurrentCUDAStream().stream();
    cublasSetStream(handle, stream);
    float fAlpha = alpha;
    float fBeta = beta;
    //THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
    TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle,
                                     opa, opb, (int)m, (int)n, (int)k,
                                     (void*)&fAlpha, a, a_type, (int)lda, strideA,
                                     b, b_type, (int)ldb, strideB,
                                     (void*)&fBeta, c, c_type, (int)ldc, strideC,
                                     d, d_type, (int)ldd, strideD,
                                     (int)batchCount, compute_type, algo, solution_index, flags));
}
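
// Note on the _ex interface: unlike the legacy three-matrix GEMM, rocBLAS's
// rocblas_gemm_strided_batched_ex reads C and writes a separate output D,
// computing D = alpha * op(A) * op(B) + beta * C. Passing d == c with
// ldd == ldc and strideD == strideC recovers the usual in-place behavior.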

void gemm_switch_fp32accum(char transa, char transb, long m, long n, long k,
                           float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
                           float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_int flags) {
  auto stream = c10::cuda::getCurrentCUDAStream();
  // The CUDA original used the 8-byte alignment of lda/ldb/ldc to choose
  // between hand-tuned kernels and cuBLAS; on ROCm every supported layout
  // reduces to the same rocBLAS call, so the duplicate branches are collapsed.
  if        ( (transa == 't') && (transb == 'n') ) {
    RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, algo, flags);
  } else if ( (transa == 'n') && (transb == 'n') ) {
    RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, algo, flags);
  } else if ( (transa == 'n') && (transb == 't') ) {
    RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, algo, flags);
  } else {
    AT_ASSERTM(false, "unsupported transa/transb combination; expected (t,n), (n,n) or (n,t)");
  }
}

void adjustLdLevel3(char transa, char transb, int64_t m, int64_t n, int64_t k,
                    int64_t *lda, int64_t *ldb, int64_t *ldc) {
  int transa_ = ((transa == 't') || (transa == 'T'));
  int transb_ = ((transb == 't') || (transb == 'T'));

  // Note: leading dimensions are generally checked to be > 0 and at least as
  // big as the result requires (even if the value won't be used).
  if (n <= 1)
    *ldc = std::max<int64_t>(m, 1);

  if (transa_) {
    if (m <= 1)
      *lda = std::max<int64_t>(k, 1);
  } else {
    if (k <= 1)
      *lda = std::max<int64_t>(m, 1);
  }

  if (transb_) {
    if (k <= 1)
      *ldb = std::max<int64_t>(n, 1);
  } else {
    if (n <= 1)
      *ldb = std::max<int64_t>(k, 1);
  }
}
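
// Illustrative call (hypothetical values, a sketch only): a caller that
// sliced a single column (n == 1) may hand over degenerate leading
// dimensions; adjustLdLevel3 clamps them to legal minima:
//
//   int64_t lda = 64, ldb = 1, ldc = 1;
//   adjustLdLevel3('n', 'n', /*m=*/64, /*n=*/1, /*k=*/64, &lda, &ldb, &ldc);
//   // ldc -> max(m, 1) == 64; ldb -> max(k, 1) == 64; lda is unchanged.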

void HgemmStridedBatched(char transa, char transb, long m,
                         long n, long k, float alpha, const half *a, long lda,
                         long strideA, const half *b, long ldb, long strideB,
                         float beta, half *c, long ldc, long strideC,
                         half *d, long ldd, long strideD, long batchCount) {

  if ((m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) ||
      (ldb >= INT_MAX) || (ldc >= INT_MAX) || (ldd >= INT_MAX) ||
      (batchCount >= INT_MAX)) {
    AT_ERROR("HgemmStridedBatched only supports m, n, k, lda, ldb, ldc, ldd, "
             "batchCount with the bound [val] <= ",
             INT_MAX);
  }

  adjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);

  // gemm_switch_fp32accum(transa, transb, m, n, k, alpha, a, lda, strideA,
  //                       b, ldb, strideB, beta, c, ldc, strideC, batchCount);
  gemm_switch_fp32accum(transa, transb, m, n, k, alpha, a, lda, strideA,
                        b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, flags);
}
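
// Usage sketch (assumptions only: the tensor names, shapes, and stride
// bookkeeping below are hypothetical, not part of this header). BLAS is
// column-major, so a row-major C = A * B is computed as C^T = B^T * A^T by
// viewing each contiguous row-major buffer as its column-major transpose.
// The output pointer is passed for both c and d, making the call in-place.
//
//   auto a   = at::randn({batch, m, k}, at::device(at::kCUDA).dtype(at::kHalf));
//   auto b   = at::randn({batch, k, n}, a.options());
//   auto out = at::empty({batch, m, n}, a.options());
//   HgemmStridedBatched('n', 'n', /*m=*/n, /*n=*/m, /*k=*/k, /*alpha=*/1.0f,
//                       reinterpret_cast<const half*>(b.data_ptr<at::Half>()), /*lda=*/n, /*strideA=*/k * n,
//                       reinterpret_cast<const half*>(a.data_ptr<at::Half>()), /*ldb=*/k, /*strideB=*/m * k,
//                       /*beta=*/0.0f,
//                       reinterpret_cast<half*>(out.data_ptr<at::Half>()), /*ldc=*/n, /*strideC=*/m * n,
//                       reinterpret_cast<half*>(out.data_ptr<at::Half>()), /*ldd=*/n, /*strideD=*/m * n,
//                       batch);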
