#pragma once
#include <iostream>
#include <vector>

#include <cuda.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>

// rocBLAS declares rocblas_gemm_strided_batched_ex and the rocblas_* types used below.
#include <rocblas/rocblas.h>

//#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>

//#include "cutlass/cutlass.h"
//#include "cutlass/gemm/gemm.h"
//#include "cutlass/gemm/wmma_gemm_traits.h"

// symbol to be automatically resolved by PyTorch libs

/*
rocblas_datatype  a_type         = rocblas_datatype_f16_r; // OK
rocblas_datatype  b_type         = rocblas_datatype_f16_r; // OK
rocblas_datatype  c_type         = rocblas_datatype_f16_r; // OK
rocblas_datatype  d_type         = rocblas_datatype_f16_r;
rocblas_datatype  compute_type   = rocblas_datatype_f32_r;

rocblas_gemm_algo algo           = rocblas_gemm_algo_standard;
int32_t           solution_index = 0;
rocblas_int       flags          = 0;
*/

namespace {
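
// Map a BLAS transpose character ('t', 'n', 'c') to the corresponding
// cublasOperation_t value; any other character is an error.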
cublasOperation_t convertTransToCublasOperation(char trans) {
  if (trans == 't')
    return CUBLAS_OP_T;
  else if (trans == 'n')
    return CUBLAS_OP_N;
  else if (trans == 'c')
    return CUBLAS_OP_C;
  else {
    AT_ERROR("trans must be one of: t, n, c");
    return CUBLAS_OP_T;
  }
}

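// Thin wrapper around rocblas_gemm_strided_batched_ex: half-precision A/B/C/D
// with float accumulation (compute_type f32), launched on the current PyTorch
// CUDA stream.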
void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k,
                    float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
                    float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_gemm_algo algo, rocblas_int flags) {
    cublasOperation_t opa = convertTransToCublasOperation(transa);
    cublasOperation_t opb = convertTransToCublasOperation(transb);

    cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
    cudaStream_t   stream = at::cuda::getCurrentCUDAStream().stream();
    cublasSetStream(handle, stream);

    float fAlpha = alpha;
    float fBeta = beta;
    //THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
    TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle,
                                     opa, opb, (int)m, (int)n, (int)k,
                                     (void*)&fAlpha, a, rocblas_datatype_f16_r /*a_type*/, (int)lda, strideA,
                                     b, rocblas_datatype_f16_r /*b_type*/, (int)ldb, strideB,
                                     (void*)&fBeta, c, rocblas_datatype_f16_r /*c_type*/, (int)ldc, strideC,
                                     d, rocblas_datatype_f16_r /*d_type*/, (int)ldd, strideD,
                                     (int)batchCount, rocblas_datatype_f32_r /*compute_type*/, algo, 0 /*solution_index*/, flags));
}

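// Dispatch on the (transa, transb) combination. The leading-dimension
// alignment test (lda/ldb/ldc multiples of 8 half elements) is kept from the
// CUTLASS-based CUDA path, but both branches currently select the same
// rocblas_gemm_algo_standard kernel.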
void gemm_switch_fp32accum(char transa, char transb, long m, long n, long k,
                           float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
                           float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_int flags) {
  auto stream = c10::cuda::getCurrentCUDAStream();
  if        ( (transa == 't') && (transb == 'n') ) {
    if      (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
    else                                                   { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
  } else if ( (transa == 'n') && (transb == 'n') ) {
    if      (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
    else                                                   { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
  } else if ( (transa == 'n') && (transb == 't') ) {
    if      (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
    else                                                   { RocblasStridedBatchedGemm(transa, transb, m, n, k, alpha, a, lda, strideA, b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, rocblas_gemm_algo_standard, flags); }
  } else {
    AT_ASSERTM(false, "TransA and TransB are invalid");
  }
}

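// Clamp degenerate leading dimensions (when m, n, or k is <= 1) to legal
// values before handing them to rocBLAS.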
void adjustLdLevel3(char transa, char transb, int64_t m, int64_t n, int64_t k,
                    int64_t *lda, int64_t *ldb, int64_t *ldc) {
  int transa_ = ((transa == 't') || (transa == 'T'));
  int transb_ = ((transb == 't') || (transb == 'T'));

  // Note: leading dimensions are generally checked to be > 0 and at least as
  // big as the result requires (even if the value won't be used).
  if (n <= 1)
    *ldc = std::max<int64_t>(m, 1);

  if (transa_) {
    if (m <= 1)
      *lda = std::max<int64_t>(k, 1);
  } else {
    if (k <= 1)
      *lda = std::max<int64_t>(m, 1);
  }

  if (transb_) {
    if (k <= 1)
      *ldb = std::max<int64_t>(n, 1);
  } else {
    if (n <= 1)
      *ldb = std::max<int64_t>(k, 1);
  }
}

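// Public entry point: verify all integer arguments fit in an int (rocBLAS
// takes int sizes), fix up degenerate leading dimensions, then dispatch to
// the fp32-accumulation GEMM above.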
void HgemmStridedBatched(char transa, char transb, long m,
                         long n, long k, float alpha, const half *a, long lda,
                         long strideA, const half *b, long ldb, long strideB,
                         float beta, half *c, long ldc, long strideC,
                         half *d, long ldd, long strideD, long batchCount) {
  if ((m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) ||
      (ldb >= INT_MAX) || (ldc >= INT_MAX) || (ldd >= INT_MAX) ||
      (batchCount >= INT_MAX)) {
    AT_ERROR("HgemmStridedBatched only supports m, n, k, lda, ldb, ldc, ldd and "
             "batchCount with values less than ", INT_MAX);
  }

  adjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);

  // gemm_switch_fp32accum(transa, transb, m, n, k, alpha, a, lda, strideA,
  //                       b, ldb, strideB, beta, c, ldc, strideC, batchCount);
  gemm_switch_fp32accum(transa, transb, m, n, k, alpha, a, lda, strideA,
                        b, ldb, strideB, beta, c, ldc, strideC, d, ldd, strideD, batchCount, 0 /*flags*/);
}

} // namespace
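
// Hypothetical usage sketch (illustrative only, not part of this header):
// a batched D = alpha * op(A) * op(B) + beta * C in fp16 with fp32
// accumulation, using column-major conventions and in-place D == C. The
// variable names, shapes, and strides below are assumptions for illustration.
//
//   // a: k x m per batch ('t'), b: k x n per batch ('n'), c/d: m x n per batch
//   HgemmStridedBatched('t', 'n', m, n, k,
//                       1.f /*alpha*/, a, /*lda=*/k, /*strideA=*/k * m,
//                       b, /*ldb=*/k, /*strideB=*/k * n,
//                       0.f /*beta*/, c, /*ldc=*/m, /*strideC=*/m * n,
//                       c /*d (in place)*/, /*ldd=*/m, /*strideD=*/m * n,
//                       batchCount);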