Commit b5d7745d authored by flyingdown's avatar flyingdown
Browse files

merge mirror master

parents 03204b84 3ba7192d
...@@ -7,8 +7,6 @@ ...@@ -7,8 +7,6 @@
//#include <cuda_profiler_api.h> //#include <cuda_profiler_api.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <rocblas/rocblas.h>
//#include <ATen/ATen.h> //#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h> #include <ATen/cuda/Exceptions.h>
...@@ -47,52 +45,6 @@ cublasOperation_t convertTransToCublasOperation(char trans) { ...@@ -47,52 +45,6 @@ cublasOperation_t convertTransToCublasOperation(char trans) {
} }
} }
// needed to work around calling rocblas API instead of hipblas API
static rocblas_operation hipOperationToRocOperation(hipblasOperation_t op)
{
switch(op)
{
case HIPBLAS_OP_N:
return rocblas_operation_none;
case HIPBLAS_OP_T:
return rocblas_operation_transpose;
case HIPBLAS_OP_C:
return rocblas_operation_conjugate_transpose;
}
AT_ERROR("HIPBLAS_STATUS_INVALID_ENUM");
}
static hipblasStatus_t rocBLASStatusToHIPStatus(rocblas_status error)
{
switch(error)
{
case rocblas_status_size_unchanged:
case rocblas_status_size_increased:
case rocblas_status_success:
case rocblas_status_continue:
return HIPBLAS_STATUS_SUCCESS;
case rocblas_status_invalid_handle:
return HIPBLAS_STATUS_NOT_INITIALIZED;
case rocblas_status_not_implemented:
case rocblas_status_excluded_from_build:
return HIPBLAS_STATUS_NOT_SUPPORTED;
case rocblas_status_invalid_pointer:
case rocblas_status_invalid_size:
case rocblas_status_invalid_value:
case rocblas_status_size_query_mismatch:
return HIPBLAS_STATUS_INVALID_VALUE;
case rocblas_status_memory_error:
return HIPBLAS_STATUS_ALLOC_FAILED;
case rocblas_status_internal_error:
case rocblas_status_perf_degraded:
case rocblas_status_check_numerics_fail:
return HIPBLAS_STATUS_INTERNAL_ERROR;
case rocblas_status_arch_mismatch:
return HIPBLAS_STATUS_ARCH_MISMATCH;
}
AT_ERROR("HIPBLAS_STATUS_INVALID_ENUM");
}
void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k, void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k,
float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB, float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_gemm_algo algo, rocblas_int flags) { float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_gemm_algo algo, rocblas_int flags) {
...@@ -105,13 +57,13 @@ void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k, ...@@ -105,13 +57,13 @@ void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k,
float fAlpha = alpha; float fAlpha = alpha;
float fBeta = beta; float fBeta = beta;
//THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); //THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_strided_batched_ex((rocblas_handle)handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle,
hipOperationToRocOperation(opa), hipOperationToRocOperation(opb), (int)m, (int)n, (int)k, opa, opb, (int)m, (int)n, (int)k,
(void*)&fAlpha, a, rocblas_datatype_f16_r /*a_type*/, (int)lda, strideA, (void*)&fAlpha, a, rocblas_datatype_f16_r /*a_type*/, (int)lda, strideA,
b, rocblas_datatype_f16_r /*b_type*/, (int)ldb, strideB, b, rocblas_datatype_f16_r /*b_type*/, (int)ldb, strideB,
(void*)&fBeta, c, rocblas_datatype_f16_r /*c_type*/, (int)ldc, strideC, (void*)&fBeta, c, rocblas_datatype_f16_r /*c_type*/, (int)ldc, strideC,
d, rocblas_datatype_f16_r /*d_type*/, int(ldd), strideD, d, rocblas_datatype_f16_r /*d_type*/, int(ldd), strideD,
(int)batchCount, rocblas_datatype_f32_r /*compute_type*/, algo, 0 /*solution_index*/, flags))); (int)batchCount, rocblas_datatype_f32_r /*compute_type*/, algo, 0 /*solution_index*/, flags));
} }
void gemm_switch_fp32accum(char transa, char transb, long m, long n, long k, void gemm_switch_fp32accum(char transa, char transb, long m, long n, long k,
......
...@@ -10,22 +10,10 @@ ...@@ -10,22 +10,10 @@
#include <cublas_v2.h> #include <cublas_v2.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include "utils.h"
#include <rocblas/rocblas.h>
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000 #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
// includes cublaslt // includes cublaslt
#include <cublasLt.h> #include <cublasLt.h>
#endif #endif
// until we use hipblas v2
// hipify correctly maps things like CUDA_R_16F to HIP_R_16F,
// however hipblas v1 is still using its custom type
#define HIP_R_64F HIPBLAS_R_64F
#define HIP_R_32F HIPBLAS_R_32F
#define HIP_R_16F HIPBLAS_R_16F
// FP64 Wrapper around cublas GEMMEx // FP64 Wrapper around cublas GEMMEx
cublasStatus_t gemm_bias( cublasStatus_t gemm_bias(
cublasHandle_t handle, cublasHandle_t handle,
...@@ -42,6 +30,33 @@ cublasStatus_t gemm_bias( ...@@ -42,6 +30,33 @@ cublasStatus_t gemm_bias(
const float* beta, const float* beta,
double* C, double* C,
int ldc) { int ldc) {
#ifdef __HIP_PLATFORM_HCC__
return rocblas_gemm_ex(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
rocblas_datatype_f64_r,
lda,
B,
rocblas_datatype_f64_r,
ldb,
beta,
C,
rocblas_datatype_f64_r,
ldc,
C,
rocblas_datatype_f64_r,
ldc,
rocblas_datatype_f64_r,
rocblas_gemm_algo_standard,
0,
0);
#else
return cublasGemmEx( return cublasGemmEx(
handle, handle,
transa, transa,
...@@ -62,6 +77,7 @@ cublasStatus_t gemm_bias( ...@@ -62,6 +77,7 @@ cublasStatus_t gemm_bias(
ldc, ldc,
CUDA_R_64F, CUDA_R_64F,
CUBLAS_GEMM_DEFAULT); CUBLAS_GEMM_DEFAULT);
#endif
} }
// FP32 Wrapper around cublas GEMMEx // FP32 Wrapper around cublas GEMMEx
...@@ -80,6 +96,34 @@ cublasStatus_t gemm_bias( ...@@ -80,6 +96,34 @@ cublasStatus_t gemm_bias(
const float* beta, const float* beta,
float* C, float* C,
int ldc) { int ldc) {
#ifdef __HIP_PLATFORM_HCC__
return rocblas_gemm_ex(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
rocblas_datatype_f32_r,
lda,
B,
rocblas_datatype_f32_r,
ldb,
beta,
C,
rocblas_datatype_f32_r,
ldc,
C,
rocblas_datatype_f32_r,
ldc,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard,
0,
0);
#else
return cublasGemmEx( return cublasGemmEx(
handle, handle,
transa, transa,
...@@ -100,6 +144,7 @@ cublasStatus_t gemm_bias( ...@@ -100,6 +144,7 @@ cublasStatus_t gemm_bias(
ldc, ldc,
CUDA_R_32F, CUDA_R_32F,
CUBLAS_GEMM_DEFAULT); CUBLAS_GEMM_DEFAULT);
#endif
} }
// FP16 Tensor core wrapper around cublas GEMMEx // FP16 Tensor core wrapper around cublas GEMMEx
...@@ -118,6 +163,7 @@ cublasStatus_t gemm_bias( ...@@ -118,6 +163,7 @@ cublasStatus_t gemm_bias(
const float* beta, const float* beta,
at::Half* C, at::Half* C,
int ldc) { int ldc) {
<<<<<<< HEAD
if (parseEnvVarFlag("APEX_ROCBLAS_GEMM_ALLOW_HALF")) { if (parseEnvVarFlag("APEX_ROCBLAS_GEMM_ALLOW_HALF")) {
half h_alpha = __float2half(*alpha); half h_alpha = __float2half(*alpha);
half h_beta = __float2half(*beta); half h_beta = __float2half(*beta);
...@@ -163,6 +209,56 @@ cublasStatus_t gemm_bias( ...@@ -163,6 +209,56 @@ cublasStatus_t gemm_bias(
CUDA_R_32F, CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP); CUBLAS_GEMM_DEFAULT_TENSOR_OP);
} }
=======
#ifdef __HIP_PLATFORM_HCC__
return rocblas_gemm_ex(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
rocblas_datatype_f16_r,
lda,
B,
rocblas_datatype_f16_r,
ldb,
beta,
C,
rocblas_datatype_f16_r,
ldc,
C,
rocblas_datatype_f16_r,
ldc,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard,
0,
0);
#else
return cublasGemmEx(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
CUDA_R_16F,
lda,
B,
CUDA_R_16F,
ldb,
beta,
C,
CUDA_R_16F,
ldc,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#endif
>>>>>>> mirror/master
} }
......
...@@ -13,8 +13,6 @@ ...@@ -13,8 +13,6 @@
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include "utils.h" #include "utils.h"
#include <rocblas/rocblas.h>
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000 #if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
// includes cublaslt // includes cublaslt
#include <cublasLt.h> #include <cublasLt.h>
...@@ -62,52 +60,6 @@ __device__ __inline__ float sigmoid(float a) { ...@@ -62,52 +60,6 @@ __device__ __inline__ float sigmoid(float a) {
return (retf); return (retf);
} }
// needed to work around calling rocblas API instead of hipblas API
static rocblas_operation hipOperationToRocOperation(hipblasOperation_t op)
{
switch(op)
{
case HIPBLAS_OP_N:
return rocblas_operation_none;
case HIPBLAS_OP_T:
return rocblas_operation_transpose;
case HIPBLAS_OP_C:
return rocblas_operation_conjugate_transpose;
}
AT_ERROR("HIPBLAS_STATUS_INVALID_ENUM");
}
static hipblasStatus_t rocBLASStatusToHIPStatus(rocblas_status error)
{
switch(error)
{
case rocblas_status_size_unchanged:
case rocblas_status_size_increased:
case rocblas_status_success:
case rocblas_status_continue:
return HIPBLAS_STATUS_SUCCESS;
case rocblas_status_invalid_handle:
return HIPBLAS_STATUS_NOT_INITIALIZED;
case rocblas_status_not_implemented:
case rocblas_status_excluded_from_build:
return HIPBLAS_STATUS_NOT_SUPPORTED;
case rocblas_status_invalid_pointer:
case rocblas_status_invalid_size:
case rocblas_status_invalid_value:
case rocblas_status_size_query_mismatch:
return HIPBLAS_STATUS_INVALID_VALUE;
case rocblas_status_memory_error:
return HIPBLAS_STATUS_ALLOC_FAILED;
case rocblas_status_internal_error:
case rocblas_status_perf_degraded:
case rocblas_status_check_numerics_fail:
return HIPBLAS_STATUS_INTERNAL_ERROR;
case rocblas_status_arch_mismatch:
return HIPBLAS_STATUS_ARCH_MISMATCH;
}
AT_ERROR("HIPBLAS_STATUS_INVALID_ENUM");
}
// FP64 Wrapper around cublas GEMMEx // FP64 Wrapper around cublas GEMMEx
cublasStatus_t mlp_gemm( cublasStatus_t mlp_gemm(
cublasHandle_t handle, cublasHandle_t handle,
...@@ -126,10 +78,10 @@ cublasStatus_t mlp_gemm( ...@@ -126,10 +78,10 @@ cublasStatus_t mlp_gemm(
int ldc, int ldc,
int flag) { int flag) {
#ifdef __HIP_PLATFORM_HCC__ #ifdef __HIP_PLATFORM_HCC__
return rocBLASStatusToHIPStatus(rocblas_gemm_ex( return rocblas_gemm_ex(
(rocblas_handle) handle, handle,
hipOperationToRocOperation(transa), transa,
hipOperationToRocOperation(transb), transb,
m, m,
n, n,
k, k,
...@@ -150,7 +102,7 @@ cublasStatus_t mlp_gemm( ...@@ -150,7 +102,7 @@ cublasStatus_t mlp_gemm(
rocblas_datatype_f64_r, rocblas_datatype_f64_r,
rocblas_gemm_algo_standard, rocblas_gemm_algo_standard,
0, 0,
flag)); flag);
#else #else
return cublasGemmEx( return cublasGemmEx(
handle, handle,
...@@ -193,10 +145,10 @@ cublasStatus_t mlp_gemm( ...@@ -193,10 +145,10 @@ cublasStatus_t mlp_gemm(
int ldc, int ldc,
int flag) { int flag) {
#ifdef __HIP_PLATFORM_HCC__ #ifdef __HIP_PLATFORM_HCC__
return rocBLASStatusToHIPStatus(rocblas_gemm_ex( return rocblas_gemm_ex(
(rocblas_handle) handle, handle,
hipOperationToRocOperation(transa), transa,
hipOperationToRocOperation(transb), transb,
m, m,
n, n,
k, k,
...@@ -217,7 +169,7 @@ cublasStatus_t mlp_gemm( ...@@ -217,7 +169,7 @@ cublasStatus_t mlp_gemm(
rocblas_datatype_f32_r, rocblas_datatype_f32_r,
rocblas_gemm_algo_standard, rocblas_gemm_algo_standard,
0, 0,
flag)); flag);
#else #else
return cublasGemmEx( return cublasGemmEx(
...@@ -261,61 +213,31 @@ cublasStatus_t mlp_gemm( ...@@ -261,61 +213,31 @@ cublasStatus_t mlp_gemm(
int ldc, int ldc,
int flag) { int flag) {
#ifdef __HIP_PLATFORM_HCC__ #ifdef __HIP_PLATFORM_HCC__
if (parseEnvVarFlag("APEX_ROCBLAS_GEMM_ALLOW_HALF")) { return rocblas_gemm_ex(
half h_alpha = __float2half(*alpha); handle,
half h_beta = __float2half(*beta); transa,
return rocBLASStatusToHIPStatus(rocblas_gemm_ex( transb,
(rocblas_handle) handle, m,
hipOperationToRocOperation(transa), n,
hipOperationToRocOperation(transb), k,
m, alpha,
n, A,
k, rocblas_datatype_f16_r,
/* alpha */ &h_alpha, lda,
A, B,
rocblas_datatype_f16_r, rocblas_datatype_f16_r,
lda, ldb,
B, beta,
rocblas_datatype_f16_r, C,
ldb, rocblas_datatype_f16_r,
/* beta */ &h_beta, ldc,
C, C,
rocblas_datatype_f16_r, rocblas_datatype_f16_r,
ldc, ldc,
C, rocblas_datatype_f32_r,
rocblas_datatype_f16_r, rocblas_gemm_algo_standard,
ldc, 0,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r, flag);
rocblas_gemm_algo_standard,
0,
flag);
} else {
return rocBLASStatusToHIPStatus(rocblas_gemm_ex(
(rocblas_handle) handle,
hipOperationToRocOperation(transa),
hipOperationToRocOperation(transb),
m,
n,
k,
alpha,
A,
rocblas_datatype_f16_r,
lda,
B,
rocblas_datatype_f16_r,
ldb,
beta,
C,
rocblas_datatype_f16_r,
ldc,
C,
rocblas_datatype_f16_r,
ldc,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard,
0,
flag);
}
#else #else
return cublasGemmEx( return cublasGemmEx(
handle, handle,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment