Commit d650e6a2 authored by Jon Barker, committed by Jared Casper

replace custom layer_norm_cuda with Apex layer_norm_cuda

parent 8dbd0757
@@ -54,7 +54,7 @@ def load(args):
                         '-U__CUDA_NO_HALF_CONVERSIONS__',
                         '--expt-relaxed-constexpr',
                         '--expt-extended-lambda']

    # Upper triangular softmax.
    sources=[srcpath / 'scaled_upper_triang_masked_softmax.cpp',
             srcpath / 'scaled_upper_triang_masked_softmax_cuda.cu']
@@ -74,29 +74,6 @@ def load(args):
    scaled_softmax_cuda = _cpp_extention_load_helper(
        "scaled_softmax_cuda", sources, extra_cuda_flags)
-    # =================================
-    # Mixed precision fused layer norm.
-    # =================================
-    extra_hopper_flags = ['-U__CUDA_NO_HALF_OPERATORS__',
-                          '-U__CUDA_NO_HALF_CONVERSIONS__']
-    extra_cuda_flags = ['-maxrregcount=50']
-    sources=[srcpath / 'layer_norm_cuda.cpp',
-             srcpath / 'layer_norm_cuda_kernel.cu']
-    fused_mix_prec_layer_norm_cuda = _cpp_extention_load_helper(
-        "fused_mix_prec_layer_norm_cuda", sources, extra_cuda_flags + extra_hopper_flags)
-
-    # =================================
-    # Fused gradient accumulation to weight gradient computation of linear layer
-    # =================================
-    if args.gradient_accumulation_fusion:
-        sources=[srcpath / 'fused_weight_gradient_dense.cpp',
-                 srcpath / 'fused_weight_gradient_dense.cu']
-        fused_dense_cuda = _cpp_extention_load_helper(
-            "fused_dense_cuda", sources, extra_hopper_flags)
def _get_cuda_bare_metal_version(cuda_dir):
    raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"],
...
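The block removed above built two site-local extensions: the mixed-precision fused layer norm (fused_mix_prec_layer_norm_cuda) and, when args.gradient_accumulation_fusion is set, the fused weight-gradient GEMM (fused_dense_cuda). After this commit the layer norm kernel is expected to come from Apex instead, as the fused_layer_norm.py hunk further down shows. A minimal sketch of the replacement import path, assuming Apex was installed with its prebuilt CUDA extensions (e.g. the --cuda_ext build):

# Sketch only: assumes an Apex install that ships its prebuilt CUDA extensions.
import importlib

# The autograd wrapper now comes from Apex rather than a local definition.
from apex.normalization.fused_layer_norm import FusedLayerNormAffineFunction

# Prebuilt Apex extension module; replaces the locally built
# fused_mix_prec_layer_norm_cuda module.
fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")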
#include <torch/torch.h>
#include <torch/extension.h>
#include <vector>

#include <stdio.h>

#include "type_shim.h"

template <typename T>
int wgrad_gemm_accum_fp32_cuda(T *input, T *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);

void wgrad_gemm_accum_fp32(const at::Tensor input, const at::Tensor d_output, at::Tensor d_weight) {
    at::Tensor input_2d, d_output_2d;
    // input tensor: collapse all leading dims into the first dim
    auto in_sizes = input.sizes();
    if (input.dim() > 2) {
        input_2d = input.view({-1, in_sizes[in_sizes.size() - 1]});
    } else {
        input_2d = input;
    }
    // d_output tensor: collapse all leading dims into the first dim
    auto d_out_sizes = d_output.sizes();
    if (d_output.dim() > 2) {
        d_output_2d = d_output.view({-1, d_out_sizes[d_out_sizes.size() - 1]});
    } else {
        d_output_2d = d_output;
    }

    // hidden_dim is the collapsed batch*sequence row count; in_dim and out_dim
    // are the input and output feature sizes of the linear layer.
    int hidden_dim = input_2d.size(0);
    int in_dim = input_2d.size(1);
    int out_dim = d_weight.size(0);

    DISPATCH_HALF_BFLOAT_AND_FLOAT(input_2d.scalar_type(), "wgrad_gemm_accum_fp32",
        int result = wgrad_gemm_accum_fp32_cuda<scalar_t>(
            input_2d.data_ptr<scalar_t>(),
            d_output_2d.data_ptr<scalar_t>(),
            d_weight.data_ptr<float>(),
            in_dim,
            hidden_dim,
            out_dim);
    );
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("wgrad_gemm_accum_fp32", &wgrad_gemm_accum_fp32, "wgrad gemm accum in fp32");
}
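For orientation, here is a hedged sketch of how this binding would be driven from Python once the extension is built under the name fused_dense_cuda (the name used in the removed build step above). The shapes follow the dimension bookkeeping in wgrad_gemm_accum_fp32; the fp32 d_weight buffer is accumulated into in place.

# Illustrative only; assumes the extension above is importable as fused_dense_cuda
# and exposes the pybind11 entry point wgrad_gemm_accum_fp32.
import torch
import fused_dense_cuda

seq, batch, in_dim, out_dim = 4, 2, 1024, 4096

inp = torch.randn(seq, batch, in_dim, dtype=torch.half, device="cuda")     # collapsed to (seq*batch, in_dim)
d_out = torch.randn(seq, batch, out_dim, dtype=torch.half, device="cuda")  # collapsed to (seq*batch, out_dim)
d_weight = torch.zeros(out_dim, in_dim, dtype=torch.float, device="cuda")  # fp32 grad buffer, updated in place

fused_dense_cuda.wgrad_gemm_accum_fp32(inp, d_out, d_weight)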
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <torch/torch.h>

/* Includes, cuda */
#include <cublas_v2.h>
#include <cuda_runtime.h>

// BF16 Tensor core wrapper around cublas GEMMEx
cublasStatus_t gemmex_wrapper(
    cublasHandle_t handle,
    cublasOperation_t transa, cublasOperation_t transb,
    int m, int n, int k,
    const float* alpha,
    at::BFloat16* A, int lda,
    at::BFloat16* B, int ldb,
    const float* beta,
    float* C, int ldc) {
  return cublasGemmEx(
      handle, transa, transb, m, n, k,
      alpha,
      A, CUDA_R_16BF, lda,
      B, CUDA_R_16BF, ldb,
      beta,
      C, CUDA_R_32F, ldc,
      CUDA_R_32F,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}

// FP16 Tensor core wrapper around cublas GEMMEx
cublasStatus_t gemmex_wrapper(
    cublasHandle_t handle,
    cublasOperation_t transa, cublasOperation_t transb,
    int m, int n, int k,
    const float* alpha,
    at::Half* A, int lda,
    at::Half* B, int ldb,
    const float* beta,
    float* C, int ldc) {
  return cublasGemmEx(
      handle, transa, transb, m, n, k,
      alpha,
      A, CUDA_R_16F, lda,
      B, CUDA_R_16F, ldb,
      beta,
      C, CUDA_R_32F, ldc,
      CUDA_R_32F,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}

// FP32 Tensor core wrapper around cublas GEMMEx
cublasStatus_t gemmex_wrapper(
    cublasHandle_t handle,
    cublasOperation_t transa, cublasOperation_t transb,
    int m, int n, int k,
    const float* alpha,
    float* A, int lda,
    float* B, int ldb,
    const float* beta,
    float* C, int ldc) {
  return cublasGemmEx(
      handle, transa, transb, m, n, k,
      alpha,
      A, CUDA_R_32F, lda,
      B, CUDA_R_32F, ldb,
      beta,
      C, CUDA_R_32F, ldc,
      CUDA_R_32F,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}

// Accumulates the weight gradient of a linear layer into an fp32 buffer:
// in row-major terms, d_weight += d_output^T @ input, with the product
// computed and accumulated in fp32 (alpha = beta = 1.0).
template <typename T>
int wgrad_gemm_accum_fp32_cuda(T *input, T *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim) {
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream;
  cublasGetStream(handle, &stream);
  const float alpha = 1.0;
  const float beta = 1.0;
  int status = 1;

  status = gemmex_wrapper(
      handle,
      CUBLAS_OP_N, CUBLAS_OP_T,
      in_dim, out_dim, hidden_dim,
      &alpha,
      input, in_dim,
      d_output, out_dim,
      &beta,
      d_weight, in_dim);
  return status;
}

template int wgrad_gemm_accum_fp32_cuda<at::Half>(at::Half *input, at::Half *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
template int wgrad_gemm_accum_fp32_cuda<at::BFloat16>(at::BFloat16 *input, at::BFloat16 *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
template int wgrad_gemm_accum_fp32_cuda<float>(float *input, float *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
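Unfolding the column-major arguments: with op(A)=N, op(B)=T, m=in_dim, n=out_dim, k=hidden_dim, the call computes, in row-major PyTorch terms, d_weight += d_output_2d.T @ input_2d, with the product formed and accumulated in fp32 (alpha = beta = 1.0). A small reference computation, offered as an illustration of the math rather than of the kernel itself:

import torch

hidden, in_dim, out_dim = 8, 16, 32   # hidden = collapsed batch * sequence rows

inp = torch.randn(hidden, in_dim, dtype=torch.half)
d_out = torch.randn(hidden, out_dim, dtype=torch.half)
d_weight = torch.randn(out_dim, in_dim, dtype=torch.float)

# What the cublasGemmEx call above amounts to: accumulate the fp32 product of
# d_output^T and input into the existing fp32 weight-gradient buffer.
expected = d_weight + d_out.float().T @ inp.float()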
/* Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. */

/* This code is copied from NVIDIA apex:
 *     https://github.com/NVIDIA/apex
 * with minor changes. */

#include <torch/extension.h>
#include <vector>
#include <cassert>
#include <sstream>  // for std::stringstream used in the error paths below
#include "compat.h"
namespace {

// Collapse the input shape into (n1, n2): n2 is the product of the trailing
// dims given by normalized_shape, n1 is the product of the remaining leading dims.
void compute_n1_n2(
    at::Tensor input,
    at::IntArrayRef normalized_shape,
    int& n1,
    int& n2) {
  int idiff = input.ndimension() - normalized_shape.size();
  n2 = 1;
  for (int i = 0; i < (int)normalized_shape.size(); ++i) {
    assert( input.sizes()[i+idiff] == normalized_shape[i] );
    n2 *= normalized_shape[i];
  }
  n1 = 1;
  for (int i = 0; i < idiff; ++i) {
    n1 *= input.sizes()[i];
  }
}

void check_args(
    at::IntArrayRef normalized_shape,
    at::Tensor gamma,
    at::Tensor beta) {
  TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape));
  TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape));
}

void check_args(
    at::Tensor input,
    at::IntArrayRef normalized_shape,
    int& n1,
    int& n2) {
  int64_t normalized_ndim = normalized_shape.size();

  if (normalized_ndim < 1) {
    std::stringstream ss;
    ss << "Expected normalized_shape to be at least 1-dimensional, i.e., "
       << "containing at least one element, but got normalized_shape="
       << normalized_shape;
    throw std::runtime_error(ss.str());
  }

  auto input_shape = input.sizes();
  auto input_ndim = input.dim();

  if (input_ndim < normalized_ndim ||
      !input_shape.slice(input_ndim - normalized_ndim).equals(normalized_shape)) {
    std::stringstream ss;
    ss << "Given normalized_shape=" << normalized_shape
       << ", expected input with shape [*";
    for (auto size : normalized_shape) {
      ss << ", " << size;
    }
    ss << "], but got input of size" << input_shape;
    throw std::runtime_error(ss.str());
  }

  compute_n1_n2(input, normalized_shape, n1, n2);
}

void check_args(
    at::Tensor input,
    at::IntArrayRef normalized_shape,
    at::Tensor gamma,
    at::Tensor beta,
    int& n1,
    int& n2) {
  check_args(input, normalized_shape, n1, n2);
  check_args(normalized_shape, gamma, beta);
}

}  // anonymous namespace
void cuda_layer_norm(
    at::Tensor* output,
    at::Tensor* mean,
    at::Tensor* invvar,
    at::Tensor* input,
    int n1,
    int n2,
    at::IntArrayRef normalized_shape,
    at::Tensor* gamma,
    at::Tensor* beta,
    double epsilon);

#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

std::vector<at::Tensor> layer_norm_affine(
    at::Tensor input,
    at::IntArrayRef normalized_shape,
    at::Tensor gamma,
    at::Tensor beta,
    double epsilon) {

  CHECK_INPUT(input);
  CHECK_INPUT(gamma);
  CHECK_INPUT(beta);
  int n1, n2;
  check_args(input, normalized_shape, gamma, beta, n1, n2);

  at::Tensor output = at::empty_like(
      input, gamma.options().dtype(gamma.scalar_type()));
  at::Tensor mean = at::empty(
      {n1}, input.options().dtype(at::ScalarType::Float));
  at::Tensor invvar = at::empty_like(mean);

  cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2,
                  normalized_shape, &gamma, &beta, epsilon);

  return {output, mean, invvar};
}

void cuda_layer_norm_gradient(
    at::Tensor* dout,
    at::Tensor* mean,
    at::Tensor* invvar,
    at::Tensor* input,
    int n1,
    int n2,
    at::IntArrayRef normalized_shape,
    at::Tensor* gamma,
    at::Tensor* beta,
    double epsilon,
    at::Tensor* grad_input,
    at::Tensor* grad_gamma,
    at::Tensor* grad_beta);

std::vector<at::Tensor> layer_norm_gradient_affine(
    at::Tensor dout,
    at::Tensor mean,
    at::Tensor invvar,
    at::Tensor input,
    at::IntArrayRef normalized_shape,
    at::Tensor gamma,
    at::Tensor beta,
    double epsilon) {

  CHECK_INPUT(dout);
  CHECK_INPUT(mean);
  CHECK_INPUT(invvar);
  CHECK_INPUT(input);
  CHECK_INPUT(gamma);
  CHECK_INPUT(beta);
  int n1, n2;
  check_args(input, normalized_shape, gamma, beta, n1, n2);

  at::Tensor grad_input = at::empty_like(input);
  at::Tensor grad_gamma = at::empty_like(gamma);
  at::Tensor grad_beta = at::empty_like(beta);

  cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2,
                           normalized_shape, &gamma, &beta, epsilon,
                           &grad_input, &grad_gamma, &grad_beta);

  return {grad_input, grad_gamma, grad_beta};
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward_affine", &layer_norm_affine,
        "LayerNorm forward (CUDA)");
  m.def("backward_affine", &layer_norm_gradient_affine,
        "LayerNorm backward (CUDA)");
}
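These bindings mirror the interface of Apex's fused_layer_norm_cuda module, which is what makes the swap in fused_layer_norm.py below mechanical. A hedged sketch of the calling convention, matching the autograd function removed further down (names are illustrative and exact signatures may differ slightly across Apex versions): forward_affine returns the normalized output plus the per-row mean and invvar, both of length n1, the product of the leading (non-normalized) dims, and backward_affine consumes them.

# Sketch only; assumes an extension importable as fused_layer_norm_cuda that
# exposes the forward_affine/backward_affine entry points bound above.
import torch
import fused_layer_norm_cuda

hidden = 1024
eps = 1e-5
x = torch.randn(4, 2, hidden, dtype=torch.half, device="cuda")
weight = torch.ones(hidden, dtype=torch.half, device="cuda")
bias = torch.zeros(hidden, dtype=torch.half, device="cuda")

# n1 = 4 * 2 rows, each normalized over n2 = hidden elements.
out, mean, invvar = fused_layer_norm_cuda.forward_affine(
    x.contiguous(), (hidden,), weight, bias, eps)

grad_in, grad_w, grad_b = fused_layer_norm_cuda.backward_affine(
    torch.ones_like(out), mean, invvar,
    x.contiguous(), (hidden,), weight, bias, eps)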
@@ -11,7 +11,7 @@ from megatron.fused_kernels import load
def test_load_fused_kernels():
    try:
-        import fused_mix_prec_layer_norm_cuda
+        import fused_layer_norm_cuda
        import scaled_masked_softmax_cuda
        import scaled_upper_triang_masked_softmax_cuda
        import torch
@@ -21,7 +21,6 @@ def test_load_fused_kernels():
        print("[Fail] load_fused_kernels")
        raise e

def test_fused_softmax():
    bert = BertModel.from_pretrained("bert-base-cased").cuda().half()
    tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
@@ -328,7 +327,7 @@ def test_masked_softmax_backward():
def test_allmasked_softmax_forward():
    import scaled_masked_softmax_cuda

    batch = 2
    attn = 16
@@ -345,7 +344,7 @@ def test_allmasked_softmax_forward():
def test_allmasked_softmax_backward():
    import scaled_masked_softmax_cuda

    batch = 2
    attn = 16
    scale_t = torch.tensor([1.0])
...
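The import block above is the substance of this test: after the rename, successfully importing fused_layer_norm_cuda (now provided by Apex) alongside the two softmax extensions counts as a pass. A simplified, hedged restatement of what the hunk exercises:

def test_load_fused_kernels():
    # Simplified sketch of the smoke test: importing the extensions is the test.
    try:
        import fused_layer_norm_cuda                    # provided by Apex after this change
        import scaled_masked_softmax_cuda
        import scaled_upper_triang_masked_softmax_cuda
        import torch
        print("[Success] load_fused_kernels")
    except ImportError as e:
        print("[Fail] load_fused_kernels")
        raise e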
@@ -18,40 +18,11 @@ try:
except:
    HAVE_PERSIST_LAYER_NORM = False

-global fused_mix_prec_layer_norm_cuda
-fused_mix_prec_layer_norm_cuda = None
-
-
-class FusedLayerNormAffineFunction(torch.autograd.Function):
-
-    @staticmethod
-    def forward(ctx, input, weight, bias, normalized_shape, eps):
-        ctx.normalized_shape = normalized_shape
-        ctx.eps = eps
-        input_ = input.contiguous()
-        weight_ = weight.contiguous()
-        bias_ = bias.contiguous()
-        output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(
-            input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
-        ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
-        return output
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        input_, weight_, bias_, mean, invvar = ctx.saved_tensors
-        grad_input = grad_weight = grad_bias = None
-        grad_input, grad_weight, grad_bias \
-            = fused_mix_prec_layer_norm_cuda.backward_affine(
-                grad_output.contiguous(), mean, invvar,
-                input_, ctx.normalized_shape,
-                weight_, bias_, ctx.eps)
-        return grad_input, grad_weight, grad_bias, None, None
+from apex.normalization.fused_layer_norm import FusedLayerNormAffineFunction
+
+global fused_layer_norm_cuda
+fused_layer_norm_cuda = None

class MixedFusedLayerNorm(torch.nn.Module):

@@ -64,9 +35,8 @@ class MixedFusedLayerNorm(torch.nn.Module):
        self.apply_layernorm_1p = apply_layernorm_1p

-        global fused_mix_prec_layer_norm_cuda
-        fused_mix_prec_layer_norm_cuda = importlib.import_module(
-            "fused_mix_prec_layer_norm_cuda")
+        global fused_layer_norm_cuda
+        fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")

        # List of hiddens sizes supported in the persistent layer norm kernel
        # If the hidden size is not supported, fall back to the non-persistent
@@ -87,7 +57,7 @@ class MixedFusedLayerNorm(torch.nn.Module):
        self.reset_parameters()
        self.no_persist_layer_norm = no_persist_layer_norm
        self.sequence_parallel = sequence_parallel

        # set sequence parallelism flag on weight and bias parameters
        setattr(self.weight, 'sequence_parallel', self.sequence_parallel)
        setattr(self.bias, 'sequence_parallel', self.sequence_parallel)
...
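With the custom autograd function gone, the forward path of MixedFusedLayerNorm can delegate directly to the Apex implementation imported at the top of the file. A minimal sketch of the non-persistent branch, assuming Apex's FusedLayerNormAffineFunction keeps the same (input, weight, bias, normalized_shape, eps) argument order as the removed custom function (newer Apex releases add optional arguments):

# Hypothetical helper for the non-persistent branch only; the real module also
# handles the persistent kernel, sequence parallelism, and the layernorm-1p variant.
from apex.normalization.fused_layer_norm import FusedLayerNormAffineFunction

def _non_persist_layer_norm(x, weight, bias, normalized_shape, eps):
    # Delegates to Apex's fused CUDA kernel through its autograd.Function wrapper.
    return FusedLayerNormAffineFunction.apply(x, weight, bias, normalized_shape, eps)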