Commit b5d7745d authored by flyingdown

merge mirror master

parents 03204b84 3ba7192d
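
The hunks below consistently replace the hipBLAS shim style of GEMM call, where rocblas_gemm_ex was wrapped in rocBLASStatusToHIPStatus and the transpose enums were converted with hipOperationToRocOperation, with direct rocBLAS calls. As a point of reference, here is a minimal sketch of the direct call style on a plain FP32 GEMM; the handle, device pointers, and dimensions are assumed to exist elsewhere and are not taken from the diff.

// Sketch only (not part of the commit): the direct rocBLAS call style adopted below,
// shown on a plain FP32 GEMM. Assumes a valid handle and device buffers dA, dB, dC.
#include <rocblas/rocblas.h>

rocblas_status example_direct_gemm(rocblas_handle handle,
                                   const float* dA, const float* dB, float* dC,
                                   rocblas_int m, rocblas_int n, rocblas_int k) {
  const float alpha = 1.0f;
  const float beta  = 0.0f;
  // Previously the code wrapped this call as
  //   rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
  //                            hipOperationToRocOperation(CUBLAS_OP_N), ...))
  // so it could be checked as a hipBLAS status; the merged code calls it directly.
  return rocblas_gemm_ex(handle,
                         rocblas_operation_none, rocblas_operation_none,
                         m, n, k,
                         &alpha,
                         dA, rocblas_datatype_f32_r, m,   // A: m x k, lda = m
                         dB, rocblas_datatype_f32_r, k,   // B: k x n, ldb = k
                         &beta,
                         dC, rocblas_datatype_f32_r, m,   // C: m x n, ldc = m
                         dC, rocblas_datatype_f32_r, m,   // D aliases C
                         rocblas_datatype_f32_r,          // compute type
                         rocblas_gemm_algo_standard,
                         0 /*solution_index*/, 0 /*flags*/);
}
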
@@ -85,9 +85,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Input Linear Fwd
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
output_lin_dim,
batches,
embed_dim,
@@ -108,7 +108,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
flags));
// MatMul1 of Dot-Product Attention Plus scaling by 1/Sqrt(head size)
gemm_switch_fp32accum( a_layout_t,
@@ -188,9 +188,9 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
flags);
// Output Linear
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_T),
hipOperationToRocOperation(CUBLAS_OP_N),
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T,
CUBLAS_OP_N,
embed_dim,
batches,
embed_dim,
@@ -211,7 +211,7 @@ std::vector<torch::Tensor> fwd_cuda(bool use_time_mask, bool is_training,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
flags));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {input_lin_results, softmax_results, dropout_results,
@@ -289,107 +289,9 @@ std::vector<torch::Tensor> bwd_cuda(
#endif
// Output Linear Dgrad
if (use_fp16) {
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches,
embed_dim,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(output_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_lin_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
embed_dim,
batches,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(matmul2_results.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(output_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(output_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// MatMul2 Dgrad1
gemm_switch_fp32accum( a_layout_t,
b_layout_n,
k_seq_len,
q_seq_len,
head_dim,
h_alpha,
static_cast<const half*>(v_lin_results_ptr),
lead_dim,
batch_stride,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
h_beta,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
attn_batches,
flags);
// Matmul2 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
h_alpha,
static_cast<const half*>(output_lin_grads.data_ptr()),
head_dim*attn_batches,
head_dim,
static_cast<const half*>(dropout_results.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
h_beta,
v_lin_grads_ptr,
lead_dim,
batch_stride,
v_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
} else {
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
batches,
embed_dim,
@@ -410,12 +312,12 @@ std::vector<torch::Tensor> bwd_cuda(
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
flags));
// Output Linear Wgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
embed_dim,
batches,
@@ -484,8 +386,6 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches,
flags);
}
// Apply Dropout Mask and Scale by Dropout Probability
apex_masked_scale_cuda<at::Half,float,uint32_t>(
static_cast<at::Half const*>(matmul2_grads.data_ptr()),
@@ -504,104 +404,6 @@ std::vector<torch::Tensor> bwd_cuda(
assert(softmax_success);
// Matmul1 Dgrad1
if (use_fp16) {
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
q_seq_len,
k_seq_len,
h_scale,
k_lin_results_ptr,
lead_dim,
batch_stride,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
h_beta,
q_lin_grads_ptr,
lead_dim,
batch_stride,
q_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Matmul1 Dgrad2
gemm_switch_fp32accum( a_layout_n,
b_layout_t,
head_dim,
k_seq_len,
q_seq_len,
h_scale,
q_lin_results_ptr,
lead_dim,
batch_stride,
static_cast<half*>(matmul2_grads.data_ptr()),
k_seq_len,
k_seq_len*q_seq_len,
h_beta,
k_lin_grads_ptr,
lead_dim,
batch_stride,
k_lin_grads_ptr,
lead_dim,
batch_stride,
attn_batches,
flags);
// Input Linear Dgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
embed_dim,
batches,
output_lin_dim,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(input_weights.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r,
output_lin_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
// Input Linear Wgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
embed_dim,
output_lin_dim,
batches,
static_cast<const void*>(&h_alpha),
static_cast<const void*>(inputs.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<const void*>(q_lin_grads_ptr),
rocblas_datatype_f16_r,
output_lin_dim,
static_cast<const void*>(&h_beta),
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
static_cast<void*>(input_weight_grads.data_ptr()),
rocblas_datatype_f16_r,
embed_dim,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
} else {
gemm_switch_fp32accum( a_layout_n,
b_layout_n,
head_dim,
@@ -648,9 +450,9 @@ std::vector<torch::Tensor> bwd_cuda(
flags);
// Input Linear Dgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_N),
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_N,
embed_dim,
batches,
output_lin_dim,
@@ -671,12 +473,12 @@ std::vector<torch::Tensor> bwd_cuda(
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags)));
flags));
// Input Linear Wgrad
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_ex((rocblas_handle) handle,
hipOperationToRocOperation(CUBLAS_OP_N),
hipOperationToRocOperation(CUBLAS_OP_T),
TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N,
CUBLAS_OP_T,
embed_dim,
output_lin_dim,
batches,
@@ -698,8 +500,6 @@ std::vector<torch::Tensor> bwd_cuda(
rocblas_gemm_algo_standard /*algo*/,
0 /*solution_index*/,
flags));
}
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return {
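
The forward and backward kernels diffed above drive their input and output linear projections through rocblas_gemm_ex on half-precision operands, transposing the weight matrix and accumulating either in FP16 (the h_alpha/h_beta branch) or in FP32 (the branch that passes rocblas_datatype_f32_r as the compute type). A minimal sketch of the FP16-in, FP32-accumulate variant follows; the names dW, dX, dY and the dimension variables are illustrative stand-ins, not identifiers from the file.

// Sketch only: FP16 operands with FP32 accumulation, weight matrix transposed,
// matching the "output = W^T * x" convention of the linear projections above.
// All names (dW, dX, dY, out_dim, in_dim, batches) are illustrative stand-ins.
#include <hip/hip_fp16.h>
#include <rocblas/rocblas.h>

rocblas_status linear_forward_sketch(rocblas_handle handle,
                                     const half* dW,   // weights, in_dim x out_dim (column-major)
                                     const half* dX,   // activations, in_dim x batches
                                     half* dY,         // result, out_dim x batches
                                     rocblas_int out_dim, rocblas_int in_dim, rocblas_int batches) {
  const float alpha = 1.0f;   // alpha/beta are FP32 because the compute type is FP32
  const float beta  = 0.0f;
  return rocblas_gemm_ex(handle,
                         rocblas_operation_transpose,     // W^T
                         rocblas_operation_none,
                         out_dim, batches, in_dim,
                         &alpha,
                         dW, rocblas_datatype_f16_r, in_dim,
                         dX, rocblas_datatype_f16_r, in_dim,
                         &beta,
                         dY, rocblas_datatype_f16_r, out_dim,
                         dY, rocblas_datatype_f16_r, out_dim,  // D aliases C
                         rocblas_datatype_f32_r,               // accumulate in FP32
                         rocblas_gemm_algo_standard,
                         0 /*solution_index*/, 0 /*flags*/);
}
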
@@ -7,8 +7,6 @@
//#include <cuda_profiler_api.h>
#include <cuda_runtime.h>
#include <rocblas/rocblas.h>
//#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
@@ -47,52 +45,6 @@ cublasOperation_t convertTransToCublasOperation(char trans) {
}
}
// needed to work around calling rocblas API instead of hipblas API
static rocblas_operation hipOperationToRocOperation(hipblasOperation_t op)
{
switch(op)
{
case HIPBLAS_OP_N:
return rocblas_operation_none;
case HIPBLAS_OP_T:
return rocblas_operation_transpose;
case HIPBLAS_OP_C:
return rocblas_operation_conjugate_transpose;
}
AT_ERROR("HIPBLAS_STATUS_INVALID_ENUM");
}
static hipblasStatus_t rocBLASStatusToHIPStatus(rocblas_status error)
{
switch(error)
{
case rocblas_status_size_unchanged:
case rocblas_status_size_increased:
case rocblas_status_success:
case rocblas_status_continue:
return HIPBLAS_STATUS_SUCCESS;
case rocblas_status_invalid_handle:
return HIPBLAS_STATUS_NOT_INITIALIZED;
case rocblas_status_not_implemented:
case rocblas_status_excluded_from_build:
return HIPBLAS_STATUS_NOT_SUPPORTED;
case rocblas_status_invalid_pointer:
case rocblas_status_invalid_size:
case rocblas_status_invalid_value:
case rocblas_status_size_query_mismatch:
return HIPBLAS_STATUS_INVALID_VALUE;
case rocblas_status_memory_error:
return HIPBLAS_STATUS_ALLOC_FAILED;
case rocblas_status_internal_error:
case rocblas_status_perf_degraded:
case rocblas_status_check_numerics_fail:
return HIPBLAS_STATUS_INTERNAL_ERROR;
case rocblas_status_arch_mismatch:
return HIPBLAS_STATUS_ARCH_MISMATCH;
}
AT_ERROR("HIPBLAS_STATUS_INVALID_ENUM");
}
void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k,
float alpha, const half *a, long lda, long strideA, const half *b, long ldb, long strideB,
float beta, half *c, long ldc, long strideC, half *d, long ldd, long strideD, long batchCount, rocblas_gemm_algo algo, rocblas_int flags) {
@@ -105,13 +57,13 @@ void RocblasStridedBatchedGemm(char transa, char transb, long m, long n, long k,
float fAlpha = alpha;
float fBeta = beta;
//THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
TORCH_CUDABLAS_CHECK(rocBLASStatusToHIPStatus(rocblas_gemm_strided_batched_ex((rocblas_handle)handle,
hipOperationToRocOperation(opa), hipOperationToRocOperation(opb), (int)m, (int)n, (int)k,
TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle,
opa, opb, (int)m, (int)n, (int)k,
(void*)&fAlpha, a, rocblas_datatype_f16_r /*a_type*/, (int)lda, strideA,
b, rocblas_datatype_f16_r /*b_type*/, (int)ldb, strideB,
(void*)&fBeta, c, rocblas_datatype_f16_r /*c_type*/, (int)ldc, strideC,
d, rocblas_datatype_f16_r /*d_type*/, int(ldd), strideD,
(int)batchCount, rocblas_datatype_f32_r /*compute_type*/, algo, 0 /*solution_index*/, flags)));
(int)batchCount, rocblas_datatype_f32_r /*compute_type*/, algo, 0 /*solution_index*/, flags));
}
void gemm_switch_fp32accum(char transa, char transb, long m, long n, long k,
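
RocblasStridedBatchedGemm above now issues rocblas_gemm_strided_batched_ex directly, with FP16 matrix data and an FP32 compute type; this is the call that batches the per-head attention matmuls behind gemm_switch_fp32accum. A stripped-down sketch of that call is shown below; it assumes densely packed matrices, one per batch entry, and all names are placeholders rather than values from the diff.

// Sketch only: FP16 batched GEMM with FP32 accumulation via
// rocblas_gemm_strided_batched_ex, the call RocblasStridedBatchedGemm issues above.
// Names and strides are placeholders; matrices are assumed densely packed per batch.
#include <hip/hip_fp16.h>
#include <rocblas/rocblas.h>

rocblas_status batched_attention_gemm_sketch(rocblas_handle handle,
                                             const half* dA, const half* dB, half* dC,
                                             rocblas_int m, rocblas_int n, rocblas_int k,
                                             rocblas_int batch_count) {
  const float alpha = 1.0f;
  const float beta  = 0.0f;
  const rocblas_stride strideA = static_cast<rocblas_stride>(m) * k;  // one A matrix per batch
  const rocblas_stride strideB = static_cast<rocblas_stride>(k) * n;
  const rocblas_stride strideC = static_cast<rocblas_stride>(m) * n;
  return rocblas_gemm_strided_batched_ex(handle,
                                         rocblas_operation_none, rocblas_operation_none,
                                         m, n, k,
                                         &alpha,
                                         dA, rocblas_datatype_f16_r, m, strideA,
                                         dB, rocblas_datatype_f16_r, k, strideB,
                                         &beta,
                                         dC, rocblas_datatype_f16_r, m, strideC,
                                         dC, rocblas_datatype_f16_r, m, strideC,  // D aliases C
                                         batch_count,
                                         rocblas_datatype_f32_r,                  // accumulate in FP32
                                         rocblas_gemm_algo_standard,
                                         0 /*solution_index*/, 0 /*flags*/);
}
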
@@ -10,22 +10,10 @@
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include "utils.h"
#include <rocblas/rocblas.h>
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
// includes cublaslt
#include <cublasLt.h>
#endif
// until we use hipblas v2
// hipify correctly maps things like CUDA_R_16F to HIP_R_16F,
// however hipblas v1 is still using its custom type
#define HIP_R_64F HIPBLAS_R_64F
#define HIP_R_32F HIPBLAS_R_32F
#define HIP_R_16F HIPBLAS_R_16F
// FP64 Wrapper around cublas GEMMEx
cublasStatus_t gemm_bias(
cublasHandle_t handle,
@@ -42,6 +30,33 @@ cublasStatus_t gemm_bias(
const float* beta,
double* C,
int ldc) {
#ifdef __HIP_PLATFORM_HCC__
return rocblas_gemm_ex(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
rocblas_datatype_f64_r,
lda,
B,
rocblas_datatype_f64_r,
ldb,
beta,
C,
rocblas_datatype_f64_r,
ldc,
C,
rocblas_datatype_f64_r,
ldc,
rocblas_datatype_f64_r,
rocblas_gemm_algo_standard,
0,
0);
#else
return cublasGemmEx(
handle,
transa,
@@ -62,6 +77,7 @@ cublasStatus_t gemm_bias(
ldc,
CUDA_R_64F,
CUBLAS_GEMM_DEFAULT);
#endif
}
// FP32 Wrapper around cublas GEMMEx
@@ -80,6 +96,34 @@ cublasStatus_t gemm_bias(
const float* beta,
float* C,
int ldc) {
#ifdef __HIP_PLATFORM_HCC__
return rocblas_gemm_ex(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
rocblas_datatype_f32_r,
lda,
B,
rocblas_datatype_f32_r,
ldb,
beta,
C,
rocblas_datatype_f32_r,
ldc,
C,
rocblas_datatype_f32_r,
ldc,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard,
0,
0);
#else
return cublasGemmEx(
handle,
transa,
@@ -100,6 +144,7 @@ cublasStatus_t gemm_bias(
ldc,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT);
#endif
}
// FP16 Tensor core wrapper around cublas GEMMEx
@@ -118,6 +163,7 @@ cublasStatus_t gemm_bias(
const float* beta,
at::Half* C,
int ldc) {
<<<<<<< HEAD
if (parseEnvVarFlag("APEX_ROCBLAS_GEMM_ALLOW_HALF")) {
half h_alpha = __float2half(*alpha);
half h_beta = __float2half(*beta);
@@ -163,6 +209,56 @@ cublasStatus_t gemm_bias(
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}
=======
#ifdef __HIP_PLATFORM_HCC__
return rocblas_gemm_ex(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
rocblas_datatype_f16_r,
lda,
B,
rocblas_datatype_f16_r,
ldb,
beta,
C,
rocblas_datatype_f16_r,
ldc,
C,
rocblas_datatype_f16_r,
ldc,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard,
0,
0);
#else
return cublasGemmEx(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
CUDA_R_16F,
lda,
B,
CUDA_R_16F,
ldb,
beta,
C,
CUDA_R_16F,
ldc,
CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#endif
>>>>>>> mirror/master
}
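
The gemm_bias overloads diffed above keep a single signature and choose the backend at compile time: rocblas_gemm_ex when __HIP_PLATFORM_HCC__ is defined, cublasGemmEx otherwise. A condensed sketch of that dispatch for the FP32 case follows; unlike the hipified source, it spells out the platform-specific handle and enum types explicitly, and the function name gemm_fp32_sketch is illustrative.

// Sketch only: compile-time backend dispatch in the style of the gemm_bias
// overloads above, condensed to the FP32 case. In the diffed code the CUDA
// spellings are mapped to ROCm equivalents by hipify; here each branch uses
// its native types so the sketch compiles as written on either platform.
#ifdef __HIP_PLATFORM_HCC__
#include <rocblas/rocblas.h>

rocblas_status gemm_fp32_sketch(rocblas_handle handle,
                                rocblas_operation transa, rocblas_operation transb,
                                int m, int n, int k,
                                const float* alpha, const float* A, int lda,
                                const float* B, int ldb,
                                const float* beta, float* C, int ldc) {
  return rocblas_gemm_ex(handle, transa, transb, m, n, k,
                         alpha, A, rocblas_datatype_f32_r, lda,
                         B, rocblas_datatype_f32_r, ldb,
                         beta, C, rocblas_datatype_f32_r, ldc,
                         C, rocblas_datatype_f32_r, ldc,   // D aliases C
                         rocblas_datatype_f32_r,
                         rocblas_gemm_algo_standard, 0, 0);
}
#else
#include <cublas_v2.h>

cublasStatus_t gemm_fp32_sketch(cublasHandle_t handle,
                                cublasOperation_t transa, cublasOperation_t transb,
                                int m, int n, int k,
                                const float* alpha, const float* A, int lda,
                                const float* B, int ldb,
                                const float* beta, float* C, int ldc) {
  return cublasGemmEx(handle, transa, transb, m, n, k,
                      alpha, A, CUDA_R_32F, lda,
                      B, CUDA_R_32F, ldb,
                      beta, C, CUDA_R_32F, ldc,
                      CUDA_R_32F, CUBLAS_GEMM_DEFAULT);
}
#endif
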
@@ -13,8 +13,6 @@
#include <cuda_runtime.h>
#include "utils.h"
#include <rocblas/rocblas.h>
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION >= 11000
// includes cublaslt
#include <cublasLt.h>
@@ -62,52 +60,6 @@ __device__ __inline__ float sigmoid(float a) {
return (retf);
}
// needed to work around calling rocblas API instead of hipblas API
static rocblas_operation hipOperationToRocOperation(hipblasOperation_t op)
{
switch(op)
{
case HIPBLAS_OP_N:
return rocblas_operation_none;
case HIPBLAS_OP_T:
return rocblas_operation_transpose;
case HIPBLAS_OP_C:
return rocblas_operation_conjugate_transpose;
}
AT_ERROR("HIPBLAS_STATUS_INVALID_ENUM");
}
static hipblasStatus_t rocBLASStatusToHIPStatus(rocblas_status error)
{
switch(error)
{
case rocblas_status_size_unchanged:
case rocblas_status_size_increased:
case rocblas_status_success:
case rocblas_status_continue:
return HIPBLAS_STATUS_SUCCESS;
case rocblas_status_invalid_handle:
return HIPBLAS_STATUS_NOT_INITIALIZED;
case rocblas_status_not_implemented:
case rocblas_status_excluded_from_build:
return HIPBLAS_STATUS_NOT_SUPPORTED;
case rocblas_status_invalid_pointer:
case rocblas_status_invalid_size:
case rocblas_status_invalid_value:
case rocblas_status_size_query_mismatch:
return HIPBLAS_STATUS_INVALID_VALUE;
case rocblas_status_memory_error:
return HIPBLAS_STATUS_ALLOC_FAILED;
case rocblas_status_internal_error:
case rocblas_status_perf_degraded:
case rocblas_status_check_numerics_fail:
return HIPBLAS_STATUS_INTERNAL_ERROR;
case rocblas_status_arch_mismatch:
return HIPBLAS_STATUS_ARCH_MISMATCH;
}
AT_ERROR("HIPBLAS_STATUS_INVALID_ENUM");
}
// FP64 Wrapper around cublas GEMMEx
cublasStatus_t mlp_gemm(
cublasHandle_t handle,
@@ -126,10 +78,10 @@ cublasStatus_t mlp_gemm(
int ldc,
int flag) {
#ifdef __HIP_PLATFORM_HCC__
return rocBLASStatusToHIPStatus(rocblas_gemm_ex(
(rocblas_handle) handle,
hipOperationToRocOperation(transa),
hipOperationToRocOperation(transb),
return rocblas_gemm_ex(
handle,
transa,
transb,
m,
n,
k,
@@ -150,7 +102,7 @@ cublasStatus_t mlp_gemm(
rocblas_datatype_f64_r,
rocblas_gemm_algo_standard,
0,
flag));
flag);
#else
return cublasGemmEx(
handle,
@@ -193,10 +145,10 @@ cublasStatus_t mlp_gemm(
int ldc,
int flag) {
#ifdef __HIP_PLATFORM_HCC__
return rocBLASStatusToHIPStatus(rocblas_gemm_ex(
(rocblas_handle) handle,
hipOperationToRocOperation(transa),
hipOperationToRocOperation(transb),
return rocblas_gemm_ex(
handle,
transa,
transb,
m,
n,
k,
@@ -217,7 +169,7 @@ cublasStatus_t mlp_gemm(
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard,
0,
flag));
flag);
#else
return cublasGemmEx(
@@ -261,39 +213,10 @@ cublasStatus_t mlp_gemm(
int ldc,
int flag) {
#ifdef __HIP_PLATFORM_HCC__
if (parseEnvVarFlag("APEX_ROCBLAS_GEMM_ALLOW_HALF")) {
half h_alpha = __float2half(*alpha);
half h_beta = __float2half(*beta);
return rocBLASStatusToHIPStatus(rocblas_gemm_ex(
(rocblas_handle) handle,
hipOperationToRocOperation(transa),
hipOperationToRocOperation(transb),
m,
n,
k,
/* alpha */ &h_alpha,
A,
rocblas_datatype_f16_r,
lda,
B,
rocblas_datatype_f16_r,
ldb,
/* beta */ &h_beta,
C,
rocblas_datatype_f16_r,
ldc,
C,
rocblas_datatype_f16_r,
ldc,
/* rocblas_datatype_f32_r */ rocblas_datatype_f16_r,
rocblas_gemm_algo_standard,
0,
flag);
} else {
return rocBLASStatusToHIPStatus(rocblas_gemm_ex(
(rocblas_handle) handle,
hipOperationToRocOperation(transa),
hipOperationToRocOperation(transb),
return rocblas_gemm_ex(
handle,
transa,
transb,
m,
n,
k,
@@ -315,7 +238,6 @@ cublasStatus_t mlp_gemm(
rocblas_gemm_algo_standard,
0,
flag);
}
#else
return cublasGemmEx(
handle,
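
For completeness, here is a small self-contained host program exercising the direct rocblas_gemm_ex path end to end: it creates a handle, stages tiny FP32 matrices on the device, runs the GEMM, and copies the result back. This is a usage sketch for the API the diffed wrappers sit on, not code from this commit, and it omits error handling on the HIP calls for brevity.

// Sketch only: end-to-end host usage of rocblas_gemm_ex (not part of the commit).
// Computes C = A * B for small column-major FP32 matrices; expected result [23 31; 34 46].
#include <hip/hip_runtime.h>
#include <rocblas/rocblas.h>
#include <cstdio>
#include <vector>

int main() {
  const rocblas_int m = 2, n = 2, k = 2;
  std::vector<float> hA = {1, 2, 3, 4};      // 2x2, column-major: [1 3; 2 4]
  std::vector<float> hB = {5, 6, 7, 8};      // 2x2, column-major: [5 7; 6 8]
  std::vector<float> hC(m * n, 0.0f);

  float *dA = nullptr, *dB = nullptr, *dC = nullptr;
  hipMalloc(reinterpret_cast<void**>(&dA), hA.size() * sizeof(float));
  hipMalloc(reinterpret_cast<void**>(&dB), hB.size() * sizeof(float));
  hipMalloc(reinterpret_cast<void**>(&dC), hC.size() * sizeof(float));
  hipMemcpy(dA, hA.data(), hA.size() * sizeof(float), hipMemcpyHostToDevice);
  hipMemcpy(dB, hB.data(), hB.size() * sizeof(float), hipMemcpyHostToDevice);
  hipMemcpy(dC, hC.data(), hC.size() * sizeof(float), hipMemcpyHostToDevice);

  rocblas_handle handle;
  rocblas_create_handle(&handle);

  const float alpha = 1.0f, beta = 0.0f;
  rocblas_status st = rocblas_gemm_ex(handle,
                                      rocblas_operation_none, rocblas_operation_none,
                                      m, n, k,
                                      &alpha,
                                      dA, rocblas_datatype_f32_r, m,
                                      dB, rocblas_datatype_f32_r, k,
                                      &beta,
                                      dC, rocblas_datatype_f32_r, m,
                                      dC, rocblas_datatype_f32_r, m,
                                      rocblas_datatype_f32_r,
                                      rocblas_gemm_algo_standard,
                                      0 /*solution_index*/, 0 /*flags*/);

  hipMemcpy(hC.data(), dC, hC.size() * sizeof(float), hipMemcpyDeviceToHost);
  std::printf("status=%d C=[%g %g; %g %g]\n", (int)st, hC[0], hC[2], hC[1], hC[3]);

  rocblas_destroy_handle(handle);
  hipFree(dA); hipFree(dB); hipFree(dC);
  return 0;
}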