Unverified Commit d2f6d04a authored by Chaitanya Sri Krishna Lolla's avatar Chaitanya Sri Krishna Lolla Committed by GitHub
Browse files

Enable mlp_cuda extension. (#28)

* enable mlp cuda

* add setup changes and tests

* skip the unit tests

* updated conditions for empty array

* removed hip platform conditions
parent 4116ed66
...@@ -4,6 +4,14 @@ ...@@ -4,6 +4,14 @@
#include <stdio.h> #include <stdio.h>
int SizeTToInt(size_t data)
{
if (data > std::numeric_limits<int>::max()) {
throw std::runtime_error("Invalid cast.");
}
return static_cast<int>(data);
}
size_t get_mlp_reserved_space(int batch_size, int num_layers, const int* output_features); size_t get_mlp_reserved_space(int batch_size, int num_layers, const int* output_features);
template <typename T> template <typename T>
...@@ -62,7 +70,7 @@ std::vector<at::Tensor> mlp_forward(int use_bias, int activation, std::vector<at ...@@ -62,7 +70,7 @@ std::vector<at::Tensor> mlp_forward(int use_bias, int activation, std::vector<at
// create output/workspace tensor // create output/workspace tensor
// TODO(deyuf): just get buffer? // TODO(deyuf): just get buffer?
auto out = at::empty({batch_size, output_features.back()}, inputs[0].type()); auto out = at::empty({batch_size, output_features.back()}, inputs[0].type());
auto reserved_space = at::empty({reserved_size}, inputs[0].type()); auto reserved_space = at::empty({SizeTToInt(reserved_size)}, inputs[0].type());
AT_DISPATCH_FLOATING_TYPES_AND_HALF(inputs[0].type(), "mlp_forward", [&] { AT_DISPATCH_FLOATING_TYPES_AND_HALF(inputs[0].type(), "mlp_forward", [&] {
std::vector<scalar_t*> w_ptr; std::vector<scalar_t*> w_ptr;
...@@ -134,7 +142,7 @@ std::vector<at::Tensor> mlp_backward( ...@@ -134,7 +142,7 @@ std::vector<at::Tensor> mlp_backward(
get_mlp_bp_workspace_in_bytes<scalar_t>(batch_size, num_layers, output_features.data()); get_mlp_bp_workspace_in_bytes<scalar_t>(batch_size, num_layers, output_features.data());
// auto work_space = at::empty({work_size*4}, at::kByte); // auto work_space = at::empty({work_size*4}, at::kByte);
auto work_space = at::empty({work_size / sizeof(scalar_t)}, inputs[0].type()); auto work_space = at::empty({SizeTToInt(work_size / sizeof(scalar_t))}, inputs[0].type());
auto result = mlp_bp<scalar_t>( auto result = mlp_bp<scalar_t>(
inputs[0].data_ptr<scalar_t>(), inputs[0].data_ptr<scalar_t>(),
......
...@@ -67,6 +67,33 @@ cublasStatus_t mlp_gemm( ...@@ -67,6 +67,33 @@ cublasStatus_t mlp_gemm(
const float* beta, const float* beta,
double* C, double* C,
int ldc) { int ldc) {
#ifdef __HIP_PLATFORM_HCC__
return rocblas_gemm_ex(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
rocblas_datatype_f64_r,
lda,
B,
rocblas_datatype_f64_r,
ldb,
beta,
C,
rocblas_datatype_f64_r,
ldc,
C,
rocblas_datatype_f64_r,
ldc,
rocblas_datatype_f64_r,
rocblas_gemm_algo_standard,
0,
0);
#else
return cublasGemmEx( return cublasGemmEx(
handle, handle,
transa, transa,
...@@ -87,6 +114,7 @@ cublasStatus_t mlp_gemm( ...@@ -87,6 +114,7 @@ cublasStatus_t mlp_gemm(
ldc, ldc,
CUDA_R_64F, CUDA_R_64F,
CUBLAS_GEMM_DEFAULT); CUBLAS_GEMM_DEFAULT);
#endif
} }
// FP32 Wrapper around cublas GEMMEx // FP32 Wrapper around cublas GEMMEx
...@@ -105,6 +133,34 @@ cublasStatus_t mlp_gemm( ...@@ -105,6 +133,34 @@ cublasStatus_t mlp_gemm(
const float* beta, const float* beta,
float* C, float* C,
int ldc) { int ldc) {
#ifdef __HIP_PLATFORM_HCC__
return rocblas_gemm_ex(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
rocblas_datatype_f32_r,
lda,
B,
rocblas_datatype_f32_r,
ldb,
beta,
C,
rocblas_datatype_f32_r,
ldc,
C,
rocblas_datatype_f32_r,
ldc,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard,
0,
0);
#else
return cublasGemmEx( return cublasGemmEx(
handle, handle,
transa, transa,
...@@ -125,6 +181,7 @@ cublasStatus_t mlp_gemm( ...@@ -125,6 +181,7 @@ cublasStatus_t mlp_gemm(
ldc, ldc,
CUDA_R_32F, CUDA_R_32F,
CUBLAS_GEMM_DEFAULT); CUBLAS_GEMM_DEFAULT);
#endif
} }
// FP16 Tensor core wrapper around cublas GEMMEx // FP16 Tensor core wrapper around cublas GEMMEx
...@@ -143,6 +200,33 @@ cublasStatus_t mlp_gemm( ...@@ -143,6 +200,33 @@ cublasStatus_t mlp_gemm(
float* beta, float* beta,
at::Half* C, at::Half* C,
int ldc) { int ldc) {
#ifdef __HIP_PLATFORM_HCC__
return rocblas_gemm_ex(
handle,
transa,
transb,
m,
n,
k,
alpha,
A,
rocblas_datatype_f16_r,
lda,
B,
rocblas_datatype_f16_r,
ldb,
beta,
C,
rocblas_datatype_f16_r,
ldc,
C,
rocblas_datatype_f16_r,
ldc,
rocblas_datatype_f32_r,
rocblas_gemm_algo_standard,
0,
0);
#else
return cublasGemmEx( return cublasGemmEx(
handle, handle,
transa, transa,
...@@ -163,6 +247,7 @@ cublasStatus_t mlp_gemm( ...@@ -163,6 +247,7 @@ cublasStatus_t mlp_gemm(
ldc, ldc,
CUDA_R_32F, CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP); CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#endif
} }
// Bias ADD. Assume input X is [features x batch size], column major. // Bias ADD. Assume input X is [features x batch size], column major.
......
...@@ -223,7 +223,13 @@ if "--cuda_ext" in sys.argv: ...@@ -223,7 +223,13 @@ if "--cuda_ext" in sys.argv:
extra_compile_args={'cxx': ['-O3'] + version_dependent_macros, extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
'nvcc':['-O3'] + version_dependent_macros})) 'nvcc':['-O3'] + version_dependent_macros}))
else: else:
print ("INFO: Skipping MLP extension") print ("INFO: Building MLP extension")
ext_modules.append(
CUDAExtension(name='mlp_cuda',
sources=['csrc/mlp.cpp',
'csrc/hip/mlp_hip.hip'],
extra_compile_args={'cxx' : ['-O3'] + version_dependent_macros,
'nvcc' : []}))
if "--bnp" in sys.argv: if "--bnp" in sys.argv:
from torch.utils.cpp_extension import CUDAExtension from torch.utils.cpp_extension import CUDAExtension
......
...@@ -7,6 +7,7 @@ import torch ...@@ -7,6 +7,7 @@ import torch
from torch import nn from torch import nn
from apex.mlp import MLP from apex.mlp import MLP
from apex.testing.common_utils import skipIfRocm
batch_size = 1024 batch_size = 1024
mlp_sizes = [480, 1024, 1024, 512, 256, 1] mlp_sizes = [480, 1024, 1024, 512, 256, 1]
...@@ -17,6 +18,7 @@ class TestMLP(unittest.TestCase): ...@@ -17,6 +18,7 @@ class TestMLP(unittest.TestCase):
def test_creation(self): def test_creation(self):
MLP(mlp_sizes) MLP(mlp_sizes)
@skipIfRocm
def test_numeric(self): def test_numeric(self):
mlp = MLP(mlp_sizes).cuda() mlp = MLP(mlp_sizes).cuda()
...@@ -51,6 +53,7 @@ class TestMLP(unittest.TestCase): ...@@ -51,6 +53,7 @@ class TestMLP(unittest.TestCase):
ref_mlp[0].bias.grad.detach().cpu().numpy(), ref_mlp[0].bias.grad.detach().cpu().numpy(),
atol=1e-7, rtol=1e-5) atol=1e-7, rtol=1e-5)
@skipIfRocm
def test_no_bias(self): def test_no_bias(self):
for use_activation in ['none', 'relu', 'sigmoid']: for use_activation in ['none', 'relu', 'sigmoid']:
mlp = MLP(mlp_sizes, bias=False, activation=use_activation).cuda() mlp = MLP(mlp_sizes, bias=False, activation=use_activation).cuda()
...@@ -88,6 +91,7 @@ class TestMLP(unittest.TestCase): ...@@ -88,6 +91,7 @@ class TestMLP(unittest.TestCase):
ref_mlp[0].weight.grad.detach().cpu().numpy(), ref_mlp[0].weight.grad.detach().cpu().numpy(),
atol=1e-7, rtol=100) atol=1e-7, rtol=100)
@skipIfRocm
def test_with_bias(self): def test_with_bias(self):
for use_activation in ['none', 'relu', 'sigmoid']: for use_activation in ['none', 'relu', 'sigmoid']:
mlp = MLP(mlp_sizes, bias=True, activation=use_activation).cuda() mlp = MLP(mlp_sizes, bias=True, activation=use_activation).cuda()
...@@ -130,6 +134,7 @@ class TestMLP(unittest.TestCase): ...@@ -130,6 +134,7 @@ class TestMLP(unittest.TestCase):
ref_mlp[0].bias.grad.detach().cpu().numpy(), ref_mlp[0].bias.grad.detach().cpu().numpy(),
atol=1e-7, rtol=1e-5) atol=1e-7, rtol=1e-5)
@skipIfRocm
def test_no_grad(self): def test_no_grad(self):
mlp = MLP(mlp_sizes).cuda() mlp = MLP(mlp_sizes).cuda()
...@@ -160,7 +165,7 @@ class TestMLP(unittest.TestCase): ...@@ -160,7 +165,7 @@ class TestMLP(unittest.TestCase):
ref_mlp[0].weight.grad.detach().cpu().numpy(), ref_mlp[0].weight.grad.detach().cpu().numpy(),
atol=1e-7, rtol=1e-5) atol=1e-7, rtol=1e-5)
@skipIfRocm
def test_performance_half(self): def test_performance_half(self):
mlp = MLP(mlp_sizes).cuda().half() mlp = MLP(mlp_sizes).cuda().half()
......
...@@ -9,7 +9,6 @@ ROCM_BLACKLIST = [ ...@@ -9,7 +9,6 @@ ROCM_BLACKLIST = [
'run_fused_layer_norm', 'run_fused_layer_norm',
'run_pyprof_nvtx', 'run_pyprof_nvtx',
'run_pyprof_data', 'run_pyprof_data',
'run_mlp'
] ]
runner = unittest.TextTestRunner(verbosity=2) runner = unittest.TextTestRunner(verbosity=2)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment