OpenDAS / Megatron-LM / Commits

Commit 1016e98a, authored Feb 18, 2023 by zhuww
megatron-lm0.3.2 based on dtk-22.10
Parent: 6c3f6c7b

Changes: 241. Showing 20 changed files with 3506 additions and 0 deletions (+3506, -0).
Changed files:

  megatron/fused_kernels/build/scaled_softmax.o                                 (+0, -0)
  megatron/fused_kernels/build/scaled_softmax_cuda.so                           (+0, -0)
  megatron/fused_kernels/build/scaled_softmax_hip.cuda.o                        (+0, -0)
  megatron/fused_kernels/build/scaled_upper_triang_masked_softmax.o             (+0, -0)
  megatron/fused_kernels/build/scaled_upper_triang_masked_softmax_cuda.so       (+0, -0)
  megatron/fused_kernels/build/scaled_upper_triang_masked_softmax_hip.cuda.o    (+0, -0)
  megatron/fused_kernels/compat.h                                               (+31, -0)
  megatron/fused_kernels/fused_weight_gradient_dense.cpp                        (+47, -0)
  megatron/fused_kernels/fused_weight_gradient_dense.cu                         (+157, -0)
  megatron/fused_kernels/layer_norm_cuda.cpp                                    (+201, -0)
  megatron/fused_kernels/layer_norm_cuda_kernel.cu                              (+832, -0)
  megatron/fused_kernels/layer_norm_hip_kernel.hip                              (+833, -0)
  megatron/fused_kernels/scaled_masked_softmax.cpp                              (+97, -0)
  megatron/fused_kernels/scaled_masked_softmax.h                                (+717, -0)
  megatron/fused_kernels/scaled_masked_softmax_cuda.cu                          (+117, -0)
  megatron/fused_kernels/scaled_masked_softmax_hip.hip                          (+118, -0)
  megatron/fused_kernels/scaled_softmax.cpp                                     (+75, -0)
  megatron/fused_kernels/scaled_softmax_cuda.cu                                 (+104, -0)
  megatron/fused_kernels/scaled_softmax_hip.hip                                 (+105, -0)
  megatron/fused_kernels/scaled_upper_triang_masked_softmax.cpp                 (+72, -0)
megatron/fused_kernels/build/scaled_softmax.o (new file, mode 100644): File added
megatron/fused_kernels/build/scaled_softmax_cuda.so (new file, mode 100644): File added
megatron/fused_kernels/build/scaled_softmax_hip.cuda.o (new file, mode 100644): File added
megatron/fused_kernels/build/scaled_upper_triang_masked_softmax.o (new file, mode 100644): File added
megatron/fused_kernels/build/scaled_upper_triang_masked_softmax_cuda.so (new file, mode 100644): File added
megatron/fused_kernels/build/scaled_upper_triang_masked_softmax_hip.cuda.o (new file, mode 100644): File added
megatron/fused_kernels/compat.h (new file, mode 100644):
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*This code is copied fron NVIDIA apex:
* https://github.com/NVIDIA/apex
* with minor changes. */
#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif
#ifdef VERSION_GE_1_3
#define DATA_PTR data_ptr
#else
#define DATA_PTR data
#endif
megatron/fused_kernels/fused_weight_gradient_dense.cpp (new file, mode 100644):
#include <torch/torch.h>
#include <torch/extension.h>
#include <vector>
#include <stdio.h>
#include "type_shim.h"
template <typename T>
int wgrad_gemm_accum_fp32_cuda(T *input, T *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);

void wgrad_gemm_accum_fp32(const at::Tensor input, const at::Tensor d_output, at::Tensor d_weight) {
    at::Tensor input_2d, d_output_2d;
    // input tensor: collapse to the first dim
    auto in_sizes = input.sizes();
    if (input.dim() > 2) {
        input_2d = input.view({-1, in_sizes[in_sizes.size() - 1]});
    } else {
        input_2d = input;
    }
    // d_output tensor: collapse to the first dim
    auto d_out_sizes = d_output.sizes();
    if (d_output.dim() > 2) {
        d_output_2d = d_output.view({-1, d_out_sizes[d_out_sizes.size() - 1]});
    } else {
        d_output_2d = d_output;
    }

    int hidden_dim = input_2d.size(0);
    int in_dim = input_2d.size(1);
    int out_dim = d_weight.size(0);

    DISPATCH_HALF_BFLOAT_AND_FLOAT(input_2d.scalar_type(), "wgrad_gemm_accum_fp32",
        int result = wgrad_gemm_accum_fp32_cuda<scalar_t>(
            input_2d.data_ptr<scalar_t>(),
            d_output_2d.data_ptr<scalar_t>(),
            d_weight.data_ptr<float>(),
            in_dim,
            hidden_dim,
            out_dim);
    );
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("wgrad_gemm_accum_fp32", &wgrad_gemm_accum_fp32, "wgrad gemm accum in fp32");
}
megatron/fused_kernels/fused_weight_gradient_dense.cu (new file, mode 100644):
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <torch/torch.h>
/* Includes, cuda */
#include <cublas_v2.h>
#include <cuda_runtime.h>
// BF16 Tensor core wrapper around cublas GEMMEx
cublasStatus_t gemmex_wrapper(
    cublasHandle_t handle,
    cublasOperation_t transa, cublasOperation_t transb,
    int m, int n, int k,
    const float* alpha,
    at::BFloat16* A, int lda,
    at::BFloat16* B, int ldb,
    const float* beta,
    float* C, int ldc) {
  return cublasGemmEx(
      handle, transa, transb, m, n, k,
      alpha,
      A, CUDA_R_16BF, lda,
      B, CUDA_R_16BF, ldb,
      beta,
      C, CUDA_R_32F, ldc,
      CUDA_R_32F,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}

// FP16 Tensor core wrapper around cublas GEMMEx
cublasStatus_t gemmex_wrapper(
    cublasHandle_t handle,
    cublasOperation_t transa, cublasOperation_t transb,
    int m, int n, int k,
    const float* alpha,
    at::Half* A, int lda,
    at::Half* B, int ldb,
    const float* beta,
    float* C, int ldc) {
  return cublasGemmEx(
      handle, transa, transb, m, n, k,
      alpha,
      A, CUDA_R_16F, lda,
      B, CUDA_R_16F, ldb,
      beta,
      C, CUDA_R_32F, ldc,
      CUDA_R_32F,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}

// FP32 Tensor core wrapper around cublas GEMMEx
cublasStatus_t gemmex_wrapper(
    cublasHandle_t handle,
    cublasOperation_t transa, cublasOperation_t transb,
    int m, int n, int k,
    const float* alpha,
    float* A, int lda,
    float* B, int ldb,
    const float* beta,
    float* C, int ldc) {
  return cublasGemmEx(
      handle, transa, transb, m, n, k,
      alpha,
      A, CUDA_R_32F, lda,
      B, CUDA_R_32F, ldb,
      beta,
      C, CUDA_R_32F, ldc,
      CUDA_R_32F,
      CUBLAS_GEMM_DEFAULT_TENSOR_OP);
}

template <typename T>
int wgrad_gemm_accum_fp32_cuda(T *input, T *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim) {
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cudaStream_t stream;
  cublasGetStream(handle, &stream);
  const float alpha = 1.0;
  const float beta  = 1.0;
  int status = 1;

  status = gemmex_wrapper(
      handle,
      CUBLAS_OP_N,
      CUBLAS_OP_T,
      in_dim,
      out_dim,
      hidden_dim,
      &alpha,
      input,
      in_dim,
      d_output,
      out_dim,
      &beta,
      d_weight,
      in_dim);
  return status;
}

template int wgrad_gemm_accum_fp32_cuda<at::Half>(at::Half *input, at::Half *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);

template int wgrad_gemm_accum_fp32_cuda<at::BFloat16>(at::BFloat16 *input, at::BFloat16 *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);

template int wgrad_gemm_accum_fp32_cuda<float>(float *input, float *d_output, float *d_weight, int in_dim, int hidden_dim, int out_dim);
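For reference, the GEMM launched by wgrad_gemm_accum_fp32_cuda corresponds to the usual dense-layer weight gradient accumulated into an fp32 buffer. A sketch of the math, assuming the standard PyTorch Linear convention with W of shape [out_dim, in_dim] (the forward convention itself is an assumption; only the backward GEMM is shown in the commit):

Y = X W^{\top}, \qquad
\nabla_W L \mathrel{+}= (\nabla_Y L)^{\top} X = dY^{\top} X,

where X is the input collapsed to [hidden_dim, in_dim] and dY the collapsed output gradient [hidden_dim, out_dim]; setting \alpha = \beta = 1 in cublasGemmEx gives the "+=" accumulation into the fp32 d_weight buffer.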
megatron/fused_kernels/layer_norm_cuda.cpp (new file, mode 100644):
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*This code is copied fron NVIDIA apex:
* https://github.com/NVIDIA/apex
* with minor changes. */
#include <torch/extension.h>
#include <vector>
#include <cassert>
#include "compat.h"
namespace {

void compute_n1_n2(
    at::Tensor input,
    at::IntArrayRef normalized_shape,
    int& n1,
    int& n2)
{
  int idiff = input.ndimension() - normalized_shape.size();
  n2 = 1;
  for (int i = 0; i < (int)normalized_shape.size(); ++i) {
    assert(input.sizes()[i + idiff] == normalized_shape[i]);
    n2 *= normalized_shape[i];
  }
  n1 = 1;
  for (int i = 0; i < idiff; ++i) {
    n1 *= input.sizes()[i];
  }
}

void check_args(
    at::IntArrayRef normalized_shape,
    at::Tensor gamma,
    at::Tensor beta)
{
  TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape));
  TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape));
}

void check_args(
    at::Tensor input,
    at::IntArrayRef normalized_shape,
    int& n1,
    int& n2)
{
  int64_t normalized_ndim = normalized_shape.size();

  if (normalized_ndim < 1) {
    std::stringstream ss;
    ss << "Expected normalized_shape to be at least 1-dimensional, i.e., "
       << "containing at least one element, but got normalized_shape="
       << normalized_shape;
    throw std::runtime_error(ss.str());
  }

  auto input_shape = input.sizes();
  auto input_ndim = input.dim();

  if (input_ndim < normalized_ndim ||
      !input_shape.slice(input_ndim - normalized_ndim).equals(normalized_shape)) {
    std::stringstream ss;
    ss << "Given normalized_shape=" << normalized_shape
       << ", expected input with shape [*";
    for (auto size : normalized_shape) {
      ss << ", " << size;
    }
    ss << "], but got input of size" << input_shape;
    throw std::runtime_error(ss.str());
  }

  compute_n1_n2(input, normalized_shape, n1, n2);
}

void check_args(
    at::Tensor input,
    at::IntArrayRef normalized_shape,
    at::Tensor gamma,
    at::Tensor beta,
    int& n1,
    int& n2)
{
  check_args(input, normalized_shape, n1, n2);
  check_args(normalized_shape, gamma, beta);
}

}

void cuda_layer_norm(
    at::Tensor* output,
    at::Tensor* mean,
    at::Tensor* invvar,
    at::Tensor* input,
    int n1,
    int n2,
    at::IntArrayRef normalized_shape,
    at::Tensor* gamma,
    at::Tensor* beta,
    double epsilon);

#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

std::vector<at::Tensor> layer_norm_affine(
    at::Tensor input,
    at::IntArrayRef normalized_shape,
    at::Tensor gamma,
    at::Tensor beta,
    double epsilon)
{
  CHECK_INPUT(input);
  CHECK_INPUT(gamma);
  CHECK_INPUT(beta);
  int n1, n2;
  check_args(input, normalized_shape, gamma, beta, n1, n2);

  at::Tensor output = at::empty_like(
      input, gamma.options().dtype(gamma.scalar_type()));
  at::Tensor mean = at::empty(
      {n1}, input.options().dtype(at::ScalarType::Float));
  at::Tensor invvar = at::empty_like(mean);

  cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2,
      normalized_shape, &gamma, &beta, epsilon);

  return {output, mean, invvar};
}

void cuda_layer_norm_gradient(
    at::Tensor* dout,
    at::Tensor* mean,
    at::Tensor* invvar,
    at::Tensor* input,
    int n1,
    int n2,
    at::IntArrayRef normalized_shape,
    at::Tensor* gamma,
    at::Tensor* beta,
    double epsilon,
    at::Tensor* grad_input,
    at::Tensor* grad_gamma,
    at::Tensor* grad_beta);

std::vector<at::Tensor> layer_norm_gradient_affine(
    at::Tensor dout,
    at::Tensor mean,
    at::Tensor invvar,
    at::Tensor input,
    at::IntArrayRef normalized_shape,
    at::Tensor gamma,
    at::Tensor beta,
    double epsilon)
{
  CHECK_INPUT(dout);
  CHECK_INPUT(mean);
  CHECK_INPUT(invvar);
  CHECK_INPUT(input);
  CHECK_INPUT(gamma);
  CHECK_INPUT(beta);
  int n1, n2;
  check_args(input, normalized_shape, gamma, beta, n1, n2);

  at::Tensor grad_input = at::empty_like(input);
  at::Tensor grad_gamma = at::empty_like(gamma);
  at::Tensor grad_beta = at::empty_like(beta);

  cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2,
      normalized_shape, &gamma, &beta, epsilon,
      &grad_input, &grad_gamma, &grad_beta);

  return {grad_input, grad_gamma, grad_beta};
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward_affine", &layer_norm_affine, "LayerNorm forward (CUDA)");
  m.def("backward_affine", &layer_norm_gradient_affine, "LayerNorm backward (CUDA)");
}
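The two bindings above mirror the apex-style fused layer norm interface. A hypothetical end-to-end sketch of how they might be built and called (module name, shapes, and the lack of autograd wiring are illustrative and not part of the commit; it assumes type_shim.h and compat.h are reachable from the build):

# Hypothetical usage sketch; names and shapes are illustrative only.
import torch
from torch.utils.cpp_extension import load

fused_layer_norm_cuda = load(
    name="fused_layer_norm_cuda",
    sources=[
        "megatron/fused_kernels/layer_norm_cuda.cpp",
        "megatron/fused_kernels/layer_norm_cuda_kernel.cu",
    ],
)

hidden = 1024
x = torch.randn(8, 512, hidden, device="cuda").contiguous()
weight = torch.ones(hidden, device="cuda")   # gamma
bias = torch.zeros(hidden, device="cuda")    # beta
eps = 1e-5

# forward_affine returns the normalized output plus the per-row mean and
# inverse standard deviation that the backward kernel reuses.
output, mean, invvar = fused_layer_norm_cuda.forward_affine(
    x, [hidden], weight, bias, eps)

grad_output = torch.randn_like(output).contiguous()
grad_input, grad_weight, grad_bias = fused_layer_norm_cuda.backward_affine(
    grad_output, mean, invvar, x, [hidden], weight, bias, eps)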
megatron/fused_kernels/layer_norm_cuda_kernel.cu (new file, mode 100644):
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*This code is copied fron NVIDIA apex:
* https://github.com/NVIDIA/apex
* with minor changes. */
#include "ATen/ATen.h"
#include "ATen/AccumulateType.h"
#include "ATen/cuda/CUDAContext.h"
#include "ATen/cuda/DeviceUtils.cuh"
#include <cuda.h>
#include <cuda_runtime.h>
#include "type_shim.h"
template<typename U> __device__
void cuWelfordOnlineSum(
  const U curr,
  U& mu,
  U& sigma2,
  U& count)
{
  count = count + U(1);
  U delta = curr - mu;
  U lmean = mu + delta / count;
  mu = lmean;
  U delta2 = curr - lmean;
  sigma2 = sigma2 + delta * delta2;
}

template<typename U> __device__
void cuChanOnlineSum(
  const U muB,
  const U sigma2B,
  const U countB,
  U& mu,
  U& sigma2,
  U& count)
{
  U delta = muB - mu;
  U nA = count;
  U nB = countB;
  count = count + countB;
  U nX = count;
  if (nX > U(0)) {
    nA = nA / nX;
    nB = nB / nX;
    mu = nA*mu + nB*muB;
    sigma2 = sigma2 + sigma2B + delta * delta * nA * nB * nX;
  } else {
    mu = U(0);
    sigma2 = U(0);
  }
}

template<typename T, typename U> __device__
void cuWelfordMuSigma2(
  const T* __restrict__ vals,
  const int n1,
  const int n2,
  const int i1,
  U& mu,
  U& sigma2,
  U* buf)
{
  // Assumptions:
  // 1) blockDim.x == warpSize
  // 2) Tensor is contiguous
  // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.
  //
  // compute variance and mean over n2
  U count = U(0);
  mu = U(0);
  sigma2 = U(0);
  if (i1 < n1) {
    // one warp normalizes one n1 index,
    // synchronization is implicit
    // initialize with standard Welford algorithm
    const int numx = blockDim.x * blockDim.y;
    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
    const T* lvals = vals + i1*n2;
    int l = 4*thrx;
    for (; l+3 < n2; l+=4*numx) {
      for (int k = 0; k < 4; ++k) {
        U curr = static_cast<U>(lvals[l+k]);
        cuWelfordOnlineSum<U>(curr,mu,sigma2,count);
      }
    }
    for (; l < n2; ++l) {
      U curr = static_cast<U>(lvals[l]);
      cuWelfordOnlineSum<U>(curr,mu,sigma2,count);
    }
    // intra-warp reductions
    for (int l = 0; l <= 4; ++l) {
      int srcLaneB = (threadIdx.x+(1<<l))&31;
      U muB = WARP_SHFL(mu, srcLaneB);
      U countB = WARP_SHFL(count, srcLaneB);
      U sigma2B = WARP_SHFL(sigma2, srcLaneB);
      cuChanOnlineSum<U>(muB,sigma2B,countB,mu,sigma2,count);
    }
    // threadIdx.x == 0 has correct values for each warp
    // inter-warp reductions
    if (blockDim.y > 1) {
      U* ubuf = (U*)buf;
      U* ibuf = (U*)(ubuf + blockDim.y);
      for (int offset = blockDim.y/2; offset > 0; offset /= 2) {
        // upper half of warps write to shared
        if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) {
          const int wrt_y = threadIdx.y - offset;
          ubuf[2*wrt_y] = mu;
          ubuf[2*wrt_y+1] = sigma2;
          ibuf[wrt_y] = count;
        }
        __syncthreads();
        // lower half merges
        if (threadIdx.x == 0 && threadIdx.y < offset) {
          U muB = ubuf[2*threadIdx.y];
          U sigma2B = ubuf[2*threadIdx.y+1];
          U countB = ibuf[threadIdx.y];
          cuChanOnlineSum<U>(muB,sigma2B,countB,mu,sigma2,count);
        }
        __syncthreads();
      }
      // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values
      if (threadIdx.x == 0 && threadIdx.y == 0) {
        ubuf[0] = mu;
        ubuf[1] = sigma2;
      }
      __syncthreads();
      mu = ubuf[0];
      sigma2 = ubuf[1]/U(n2);
      // don't care about final value of count, we know count == n2
    } else {
      mu = WARP_SHFL(mu, 0);
      sigma2 = WARP_SHFL(sigma2/U(n2), 0);
    }
  }
}

template<> __device__
void cuWelfordMuSigma2(
  const at::Half* __restrict__ vals,
  const int n1,
  const int n2,
  const int i1,
  float& mu,
  float& sigma2,
  float* buf)
{
  // Assumptions:
  // 1) blockDim.x == warpSize
  // 2) Tensor is contiguous
  // 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.
  //
  // compute variance and mean over n2
  float count = 0.0f;
  mu = float(0);
  sigma2 = float(0);
  if (i1 < n1) {
    // one warp normalizes one n1 index,
    // synchronization is implicit
    // initialize with standard Welford algorithm
    const int numx = blockDim.x * blockDim.y;
    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
    const at::Half* lvals = vals + i1*n2;
    int l = 8*thrx;
    if ((((size_t)lvals)&3) != 0) {
      // 16 bit alignment
      // first thread consumes first point
      if (thrx == 0) {
        float curr = static_cast<float>(lvals[0]);
        cuWelfordOnlineSum(curr,mu,sigma2,count);
      }
      ++l;
    }
    // at this point, lvals[l] are 32 bit aligned for all threads.
    for (; l+7 < n2; l+=8*numx) {
      for (int k = 0; k < 8; k+=2) {
        float2 curr = __half22float2(*((__half2*)(lvals+l+k)));
        cuWelfordOnlineSum(curr.x,mu,sigma2,count);
        cuWelfordOnlineSum(curr.y,mu,sigma2,count);
      }
    }
    for (; l < n2; ++l) {
      float curr = static_cast<float>(lvals[l]);
      cuWelfordOnlineSum(curr,mu,sigma2,count);
    }
    // intra-warp reductions
    for (int l = 0; l <= 4; ++l) {
      int srcLaneB = (threadIdx.x+(1<<l))&31;
      float muB = WARP_SHFL(mu, srcLaneB);
      float countB = WARP_SHFL(count, srcLaneB);
      float sigma2B = WARP_SHFL(sigma2, srcLaneB);
      cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count);
    }
    // threadIdx.x == 0 has correct values for each warp
    // inter-warp reductions
    if (blockDim.y > 1) {
      float* ubuf = (float*)buf;
      float* ibuf = (float*)(ubuf + blockDim.y);
      for (int offset = blockDim.y/2; offset > 0; offset /= 2) {
        // upper half of warps write to shared
        if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) {
          const int wrt_y = threadIdx.y - offset;
          ubuf[2*wrt_y] = mu;
          ubuf[2*wrt_y+1] = sigma2;
          ibuf[wrt_y] = count;
        }
        __syncthreads();
        // lower half merges
        if (threadIdx.x == 0 && threadIdx.y < offset) {
          float muB = ubuf[2*threadIdx.y];
          float sigma2B = ubuf[2*threadIdx.y+1];
          float countB = ibuf[threadIdx.y];
          cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count);
        }
        __syncthreads();
      }
      // threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values
      if (threadIdx.x == 0 && threadIdx.y == 0) {
        ubuf[0] = mu;
        ubuf[1] = sigma2;
      }
      __syncthreads();
      mu = ubuf[0];
      sigma2 = ubuf[1]/float(n2);
      // don't care about final value of count, we know count == n2
    } else {
      mu = WARP_SHFL(mu, 0);
      sigma2 = WARP_SHFL(sigma2/float(n2), 0);
    }
  }
}

template<typename U> __device__ U rsqrt(U v) {
  return U(1) / sqrt(v);
}
template<> __device__ float rsqrt(float v) {
  return rsqrtf(v);
}
template<> __device__ double rsqrt(double v) {
  return rsqrt(v);
}

namespace {
// This is the un-specialized struct. Note that we prevent instantiation of this
// struct by putting an undefined symbol in the function body so it won't compile.
// template <typename T>
// struct SharedMemory
// {
//     // Ensure that we won't compile any un-specialized types
//     __device__ T *getPointer()
//     {
//         extern __device__ void error(void);
//         error();
//         return NULL;
//     }
// };
// https://github.com/NVIDIA/apex/issues/246
template <typename T>
struct SharedMemory;

template <>
struct SharedMemory <float>
{
    __device__ float *getPointer()
    {
        extern __shared__ float s_float[];
        return s_float;
    }
};
}

template<typename T, typename U, typename V> __global__
void cuApplyLayerNorm(
  V* __restrict__ output_vals,
  U* __restrict__ mean,
  U* __restrict__ invvar,
  const T* __restrict__ vals,
  const int n1,
  const int n2,
  const U epsilon,
  const V* __restrict__ gamma,
  const V* __restrict__ beta
  )
{
  // Assumptions:
  // 1) blockDim.x == warpSize
  // 2) Tensors are contiguous
  //
  for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
    SharedMemory<U> shared;
    U* buf = shared.getPointer();
    U mu,sigma2;
    cuWelfordMuSigma2(vals,n1,n2,i1,mu,sigma2,buf);
    const T* lvals = vals + i1*n2;
    V* ovals = output_vals + i1*n2;
    U c_invvar = rsqrt(sigma2 + epsilon);
    const int numx = blockDim.x * blockDim.y;
    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
    if (gamma != NULL && beta != NULL) {
      for (int i = thrx; i < n2; i+=numx) {
        U curr = static_cast<U>(lvals[i]);
        ovals[i] = gamma[i] * static_cast<V>(c_invvar * (curr - mu)) + beta[i];
      }
    } else {
      for (int i = thrx; i < n2; i+=numx) {
        U curr = static_cast<U>(lvals[i]);
        ovals[i] = static_cast<V>(c_invvar * (curr - mu));
      }
    }
    if (threadIdx.x == 0 && threadIdx.y == 0) {
      mean[i1] = mu;
      invvar[i1] = c_invvar;
    }
    __syncthreads();
  }
}

template<typename T, typename U, typename V> __device__
void cuLoadWriteStridedInputs(
    const int i1_block,
    const int thr_load_row_off,
    const int thr_load_col_off,
    const int i2_off,
    const int row_stride,
    U* warp_buf1,
    U* warp_buf2,
    const T* input,
    const V* dout,
    const int i1_end,
    const int n2,
    const U* __restrict__ mean,
    const U* __restrict__ invvar
    )
{
  int i1 = i1_block+thr_load_row_off;
  if (i1 < i1_end) {
    U curr_mean = mean[i1];
    U curr_invvar = invvar[i1];
    for (int k = 0; k < blockDim.y; ++k) {
      int i2 = i2_off + k;
      int load_idx = i1*n2+i2;
      int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
      if (i2<n2) {
        U curr_input = static_cast<U>(input[load_idx]);
        U curr_dout = static_cast<U>(dout[load_idx]);
        warp_buf1[write_idx] = curr_dout;
        warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar;
      } else {
        warp_buf1[write_idx] = U(0);
        warp_buf2[write_idx] = U(0);
      }
    }
  } else {
    for (int k = 0; k < blockDim.y; ++k) {
      int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
      warp_buf1[write_idx] = U(0);
      warp_buf2[write_idx] = U(0);
    }
  }
}

template<typename T, typename U, typename V> __device__
void cuLoadAddStridedInputs(
    const int i1_block,
    const int thr_load_row_off,
    const int thr_load_col_off,
    const int i2_off,
    const int row_stride,
    U* warp_buf1,
    U* warp_buf2,
    const T* input,
    const V* dout,
    const int i1_end,
    const int n2,
    const U* __restrict__ mean,
    const U* __restrict__ invvar
    )
{
  int i1 = i1_block+thr_load_row_off;
  if (i1 < i1_end) {
    U curr_mean = mean[i1];
    U curr_invvar = invvar[i1];
    for (int k = 0; k < blockDim.y; ++k) {
      int i2 = i2_off + k;
      int load_idx = i1*n2+i2;
      int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
      if (i2<n2) {
        U curr_input = static_cast<U>(input[load_idx]);
        U curr_dout = static_cast<U>(dout[load_idx]);
        warp_buf1[write_idx] += curr_dout;
        warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar;
      }
    }
  }
}

template<typename T, typename U, typename V> __global__
void cuComputePartGradGammaBeta(
    const V* __restrict__ dout,
    const T* __restrict__ input,
    const int n1,
    const int n2,
    const U* __restrict__ mean,
    const U* __restrict__ invvar,
    U epsilon,
    U* part_grad_gamma,
    U* part_grad_beta)
{
  const int numsegs_n1 = (n1+blockDim.y*blockDim.y-1) / (blockDim.y*blockDim.y);
  const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y;
  const int i1_beg = blockIdx.y * segs_per_block * blockDim.y*blockDim.y;
  const int i1_beg_plus_one = (blockIdx.y+1) * segs_per_block * blockDim.y*blockDim.y;
  const int i1_end = i1_beg_plus_one < n1 ? i1_beg_plus_one : n1;
  const int row_stride = blockDim.x+1;
  const int thr_load_col_off = (threadIdx.x*blockDim.y)&(blockDim.x-1);
  const int thr_load_row_off = (threadIdx.x*blockDim.y)/blockDim.x + threadIdx.y*blockDim.y;
  const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off;
  SharedMemory<U> shared;
  U* buf = shared.getPointer(); // buf has at least blockDim.x * blockDim.y * blockDim.y + (blockDim.y - 1)*(blockDim.x/blockDim.y) elements
  U* warp_buf1 = (U*)buf;
  U* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride;
  // compute partial sums from strided inputs
  // do this to increase number of loads in flight
  cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar);
  for (int i1_block = i1_beg+blockDim.y*blockDim.y; i1_block < i1_end; i1_block+=blockDim.y*blockDim.y) {
    cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar);
  }
  __syncthreads();
  // inter-warp reductions
  // sum within each warp
  U acc1 = U(0);
  U acc2 = U(0);
  for (int k = 0; k < blockDim.y; ++k) {
    int row1 = threadIdx.y + k*blockDim.y;
    int idx1 = row1*row_stride + threadIdx.x;
    acc1 += warp_buf1[idx1];
    acc2 += warp_buf2[idx1];
  }
  warp_buf1[threadIdx.y*row_stride+threadIdx.x] = acc1;
  warp_buf2[threadIdx.y*row_stride+threadIdx.x] = acc2;
  __syncthreads();
  // sum all warps
  for (int offset = blockDim.y/2; offset > 1; offset /= 2) {
    if (threadIdx.y < offset) {
      int row1 = threadIdx.y;
      int row2 = threadIdx.y + offset;
      int idx1 = row1*row_stride + threadIdx.x;
      int idx2 = row2*row_stride + threadIdx.x;
      warp_buf1[idx1] += warp_buf1[idx2];
      warp_buf2[idx1] += warp_buf2[idx2];
    }
    __syncthreads();
  }
  int i2 = blockIdx.x * blockDim.x + threadIdx.x;
  if (threadIdx.y == 0 && i2 < n2) {
    int row1 = threadIdx.y;
    int row2 = threadIdx.y + 1;
    int idx1 = row1*row_stride + threadIdx.x;
    int idx2 = row2*row_stride + threadIdx.x;
    part_grad_beta[blockIdx.y*n2+i2] = warp_buf1[idx1] + warp_buf1[idx2];
    part_grad_gamma[blockIdx.y*n2+i2] = warp_buf2[idx1] + warp_buf2[idx2];
  }
}

template<typename U, typename V> __global__
void cuComputeGradGammaBeta(
    const U* part_grad_gamma,
    const U* part_grad_beta,
    const int part_size,
    const int n1,
    const int n2,
    V* grad_gamma,
    V* grad_beta)
{
  // sum partial gradients for gamma and beta
  SharedMemory<U> shared;
  U* buf = shared.getPointer();
  int i2 = blockIdx.x * blockDim.x + threadIdx.x;
  if (i2 < n2) {
    // each warp does sequential reductions until reduced part_size is num_warps
    int num_warp_reductions = part_size / blockDim.y;
    U sum_gamma = U(0);
    U sum_beta = U(0);
    const U* part_grad_gamma_ptr = part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2;
    const U* part_grad_beta_ptr = part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2;
    for (int warp_offset = 0; warp_offset < num_warp_reductions; ++warp_offset) {
      sum_gamma += part_grad_gamma_ptr[warp_offset*n2];
      sum_beta += part_grad_beta_ptr[warp_offset*n2];
    }
    // inter-warp reductions
    const int nbsize3 = blockDim.x * blockDim.y / 2;
    for (int offset = blockDim.y/2; offset >= 1; offset /= 2) {
      // top half write to shared memory
      if (threadIdx.y >= offset && threadIdx.y < 2*offset) {
        const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
        buf[write_idx] = sum_gamma;
        buf[write_idx+nbsize3] = sum_beta;
      }
      __syncthreads();
      // bottom half sums
      if (threadIdx.y < offset) {
        const int read_idx = threadIdx.y * blockDim.x + threadIdx.x;
        sum_gamma += buf[read_idx];
        sum_beta += buf[read_idx+nbsize3];
      }
      __syncthreads();
    }
    // write out fully summed gradients
    if (threadIdx.y == 0) {
      grad_gamma[i2] = sum_gamma;
      grad_beta[i2] = sum_beta;
    }
  }
}

template<typename T, typename U, typename V> __global__
void cuComputeGradInput(
    const V* __restrict__ dout,
    const T* __restrict__ input,
    const int n1,
    const int n2,
    const U* __restrict__ mean,
    const U* __restrict__ invvar,
    U epsilon,
    const V* gamma,
    T* grad_input)
{
  for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
    U sum_loss1 = U(0);
    U sum_loss2 = U(0);
    const U c_mean = mean[i1];
    const U c_invvar = invvar[i1];
    const T* k_input = input + i1*n2;
    const V* k_dout = dout + i1*n2;
    const int numx = blockDim.x * blockDim.y;
    const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
    if (gamma != NULL) {
      int l = 4*thrx;
      for (; l+3 < n2; l+=4*numx) {
        for (int k = 0; k < 4; ++k) {
          const U c_h = static_cast<U>(k_input[l+k]);
          const U c_loss = static_cast<U>(k_dout[l+k]);
          sum_loss1 += c_loss * gamma[l+k];
          sum_loss2 += c_loss * gamma[l+k] * (c_h - c_mean) * c_invvar;
        }
      }
      for (; l < n2; ++l) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        sum_loss1 += c_loss * gamma[l];
        sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar;
      }
    } else {
      int l = 4*thrx;
      for (; l+3 < n2; l+=4*numx) {
        for (int k = 0; k < 4; ++k) {
          const U c_h = static_cast<U>(k_input[l+k]);
          const U c_loss = static_cast<U>(k_dout[l+k]);
          sum_loss1 += c_loss;
          sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
        }
      }
      for (; l < n2; ++l) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        sum_loss1 += c_loss;
        sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
      }
    }
    // intra-warp reductions
    for (int mask = blockDim.x/2; mask > 0; mask /= 2) {
      sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask);
      sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask);
    }
    // inter-warp reductions
    if (blockDim.y > 1) {
      SharedMemory<U> shared;
      U* buf = shared.getPointer();
      for (int offset = blockDim.y/2; offset > 0; offset /= 2) {
        // upper half of warps write to shared
        if (threadIdx.y >= offset && threadIdx.y < 2*offset) {
          const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
          buf[2*wrt_i] = sum_loss1;
          buf[2*wrt_i+1] = sum_loss2;
        }
        __syncthreads();
        // lower half merges
        if (threadIdx.y < offset) {
          const int read_i = threadIdx.y * blockDim.x + threadIdx.x;
          sum_loss1 += buf[2*read_i];
          sum_loss2 += buf[2*read_i+1];
        }
        __syncthreads();
      }
      if (threadIdx.y == 0) {
        buf[2*threadIdx.x] = sum_loss1;
        buf[2*threadIdx.x+1] = sum_loss2;
      }
      __syncthreads();
      if (threadIdx.y != 0) {
        sum_loss1 = buf[2*threadIdx.x];
        sum_loss2 = buf[2*threadIdx.x+1];
      }
    }
    // all threads now have the two sums over l
    U fH = (U)n2;
    U term1 = (U(1) / fH) * c_invvar;
    T* k_grad_input = grad_input + i1*n2;
    if (gamma != NULL) {
      for (int l = thrx; l < n2; l+=numx) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        U f_grad_input = fH * c_loss * gamma[l];
        f_grad_input -= sum_loss1;
        f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
        f_grad_input *= term1;
        k_grad_input[l] = static_cast<T>(f_grad_input);
      }
    } else {
      for (int l = thrx; l < n2; l+=numx) {
        const U c_h = static_cast<U>(k_input[l]);
        const U c_loss = static_cast<U>(k_dout[l]);
        U f_grad_input = fH * c_loss;
        f_grad_input -= sum_loss1;
        f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
        f_grad_input *= term1;
        k_grad_input[l] = static_cast<T>(f_grad_input);
      }
    }
    // prevent race where buf is written again before reads are done
    __syncthreads();
  }
}

template<typename T, typename U, typename V>
void HostApplyLayerNorm(
    V* output,
    U* mean,
    U* invvar,
    const T* input,
    int n1,
    int n2,
    double epsilon,
    const V* gamma,
    const V* beta
    )
{
  auto stream = at::cuda::getCurrentCUDAStream().stream();
  const dim3 threads(32,4,1);
  const uint64_t maxGridY =
    at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
  const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1);
  int nshared =
      threads.y > 1 ?
      threads.y*sizeof(U)+(threads.y/2)*sizeof(U) :
      0;
  cuApplyLayerNorm<<<blocks, threads, nshared, stream>>>(
      output,
      mean,
      invvar,
      input,
      n1,n2,
      U(epsilon),
      gamma,beta);
}

void cuda_layer_norm(
    at::Tensor* output,
    at::Tensor* mean,
    at::Tensor* invvar,
    at::Tensor* input,
    int n1,
    int n2,
#ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
#else
    at::IntList normalized_shape,
#endif
    at::Tensor* gamma,
    at::Tensor* beta,
    double epsilon)
{
  using namespace at;
  DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
      input->scalar_type(), output->scalar_type(), "cuda_layer_norm_kernel",
      HostApplyLayerNorm(
          output->DATA_PTR<scalar_t_out>(),
          mean->DATA_PTR<float>(),
          invvar->DATA_PTR<float>(),
          input->DATA_PTR<scalar_t_in>(),
          n1,n2,
          epsilon,
          gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
          beta != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL);
    )
}

template<typename T, typename U, typename V>
void HostLayerNormGradient(
    const V* dout,
    const U* mean,
    const U* invvar,
    at::Tensor* input,
    int n1,
    int n2,
    const V* gamma,
    const V* beta,
    double epsilon,
    T* grad_input,
    V* grad_gamma,
    V* grad_beta
    )
{
  auto stream = at::cuda::getCurrentCUDAStream().stream();

  if (gamma != NULL && beta != NULL) {
    // compute grad_gamma(j) and grad_beta(j)
    const int part_size = 16;
    const dim3 threads2(32,4,1);
    const dim3 blocks2((n2+threads2.x-1)/threads2.x,part_size,1);
    const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y *
        (threads2.x + 1);
    const int nshared2_b = threads2.x * threads2.y * sizeof(U);
    const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
    at::Tensor part_grad_gamma = at::empty(
        {part_size,n2}, input->options().dtype(at::ScalarType::Float));
    at::Tensor part_grad_beta = at::empty_like(part_grad_gamma);
    cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>(
        dout,
        input->DATA_PTR<T>(),
        n1,n2,
        mean,
        invvar,
        U(epsilon),
        part_grad_gamma.DATA_PTR<U>(),
        part_grad_beta.DATA_PTR<U>());

    const dim3 threads3(32,8,1);
    const dim3 blocks3((n2+threads2.x-1)/threads2.x,1,1);
    const int nshared3 = threads3.x * threads3.y * sizeof(U);
    cuComputeGradGammaBeta<<<blocks3, threads3, nshared3, stream>>>(
        part_grad_gamma.DATA_PTR<U>(),
        part_grad_beta.DATA_PTR<U>(),
        part_size,
        n1,n2,
        grad_gamma,
        grad_beta);
  }

  // compute grad_input
  const uint64_t maxGridY =
    at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
  const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1);
  const dim3 threads1(32,4,1);
  int nshared =
      threads1.y > 1 ?
      threads1.y*threads1.x*sizeof(U) :
      0;
  cuComputeGradInput<<<blocks1, threads1, nshared, stream>>>(
      dout,
      input->DATA_PTR<T>(),
      n1,n2,
      mean,
      invvar,
      U(epsilon),
      gamma,
      grad_input);
}

void cuda_layer_norm_gradient(
    at::Tensor* dout,
    at::Tensor* mean,
    at::Tensor* invvar,
    at::Tensor* input,
    int n1,
    int n2,
#ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
#else
    at::IntList normalized_shape,
#endif
    at::Tensor* gamma,
    at::Tensor* beta,
    double epsilon,
    at::Tensor* grad_input,
    at::Tensor* grad_gamma,
    at::Tensor* grad_beta)
{
  using namespace at;
  DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
      input->scalar_type(), gamma->scalar_type(),
      "cuda_layer_norm_gradient_kernel",
      HostLayerNormGradient(
          dout->DATA_PTR<scalar_t_out>(),
          mean->DATA_PTR<float>(),
          invvar->DATA_PTR<float>(),
          input,
          n1,n2,
          // TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta
          // if gamma Tensor is NULL on input.
          gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
          gamma != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL,
          epsilon,
          grad_input->DATA_PTR<scalar_t_in>(),
          gamma != NULL ? grad_gamma->DATA_PTR<scalar_t_out>() : NULL,
          gamma != NULL ? grad_beta->DATA_PTR<scalar_t_out>() : NULL);
    )
}
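cuWelfordOnlineSum and cuChanOnlineSum in the kernel above implement, respectively, Welford's single-pass update and Chan et al.'s parallel merge of partial statistics. A sketch of the formulas the code follows, with n, \mu, M_2 the running count, mean, and sum of squared deviations (so \sigma^2 = M_2 / n, taken at the end when n = n2):

Welford update for a new value x:
n \leftarrow n + 1, \quad \delta = x - \mu, \quad \mu \leftarrow \mu + \delta / n, \quad M_2 \leftarrow M_2 + \delta\,(x - \mu)

Chan merge of two partitions A and B (the shuffle/shared-memory reduction steps):
n = n_A + n_B, \quad \mu = \frac{n_A \mu_A + n_B \mu_B}{n}, \quad M_2 = M_{2,A} + M_{2,B} + \delta^2 \frac{n_A n_B}{n}, \quad \delta = \mu_B - \mu_A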
megatron/fused_kernels/layer_norm_hip_kernel.hip (new file, mode 100644):
// !!! This is a file automatically generated by hipify!!!
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*This code is copied fron NVIDIA apex:
* https://github.com/NVIDIA/apex
* with minor changes. */
#include "ATen/ATen.h"
#include "ATen/AccumulateType.h"
#include "ATen/hip/HIPContext.h"
#include "ATen/hip/DeviceUtils.cuh"
#include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>
#include "type_shim.h"
template<typename U> __device__
void cuWelfordOnlineSum(
const U curr,
U& mu,
U& sigma2,
U& count)
{
count = count + U(1);
U delta = curr - mu;
U lmean = mu + delta / count;
mu = lmean;
U delta2 = curr - lmean;
sigma2 = sigma2 + delta * delta2;
}
template<typename U> __device__
void cuChanOnlineSum(
const U muB,
const U sigma2B,
const U countB,
U& mu,
U& sigma2,
U& count)
{
U delta = muB - mu;
U nA = count;
U nB = countB;
count = count + countB;
U nX = count;
if (nX > U(0)) {
nA = nA / nX;
nB = nB / nX;
mu = nA*mu + nB*muB;
sigma2 = sigma2 + sigma2B + delta * delta * nA * nB * nX;
} else {
mu = U(0);
sigma2 = U(0);
}
}
template<typename T, typename U> __device__
void cuWelfordMuSigma2(
const T* __restrict__ vals,
const int n1,
const int n2,
const int i1,
U& mu,
U& sigma2,
U* buf)
{
// Assumptions:
// 1) blockDim.x == warpSize
// 2) Tensor is contiguous
// 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.
//
// compute variance and mean over n2
U count = U(0);
mu= U(0);
sigma2 = U(0);
if (i1 < n1) {
// one warp normalizes one n1 index,
// synchronization is implicit
// initialize with standard Welford algorithm
const int numx = blockDim.x * blockDim.y;
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
const T* lvals = vals + i1*n2;
int l = 4*thrx;
for (; l+3 < n2; l+=4*numx) {
for (int k = 0; k < 4; ++k) {
U curr = static_cast<U>(lvals[l+k]);
cuWelfordOnlineSum<U>(curr,mu,sigma2,count);
}
}
for (; l < n2; ++l) {
U curr = static_cast<U>(lvals[l]);
cuWelfordOnlineSum<U>(curr,mu,sigma2,count);
}
// intra-warp reductions
for (int l = 0; l <= 4; ++l) {
int srcLaneB = (threadIdx.x+(1<<l))&31;
U muB = WARP_SHFL(mu, srcLaneB);
U countB = WARP_SHFL(count, srcLaneB);
U sigma2B = WARP_SHFL(sigma2, srcLaneB);
cuChanOnlineSum<U>(muB,sigma2B,countB,mu,sigma2,count);
}
// threadIdx.x == 0 has correct values for each warp
// inter-warp reductions
if (blockDim.y > 1) {
U* ubuf = (U*)buf;
U* ibuf = (U*)(ubuf + blockDim.y);
for (int offset = blockDim.y/2; offset > 0; offset /= 2) {
// upper half of warps write to shared
if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) {
const int wrt_y = threadIdx.y - offset;
ubuf[2*wrt_y] = mu;
ubuf[2*wrt_y+1] = sigma2;
ibuf[wrt_y] = count;
}
__syncthreads();
// lower half merges
if (threadIdx.x == 0 && threadIdx.y < offset) {
U muB = ubuf[2*threadIdx.y];
U sigma2B = ubuf[2*threadIdx.y+1];
U countB = ibuf[threadIdx.y];
cuChanOnlineSum<U>(muB,sigma2B,countB,mu,sigma2,count);
}
__syncthreads();
}
// threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values
if (threadIdx.x == 0 && threadIdx.y == 0) {
ubuf[0] = mu;
ubuf[1] = sigma2;
}
__syncthreads();
mu = ubuf[0];
sigma2 = ubuf[1]/U(n2);
// don't care about final value of count, we know count == n2
} else {
mu = WARP_SHFL(mu, 0);
sigma2 = WARP_SHFL(sigma2/U(n2), 0);
}
}
}
template<> __device__
void cuWelfordMuSigma2(
const at::Half* __restrict__ vals,
const int n1,
const int n2,
const int i1,
float& mu,
float& sigma2,
float* buf)
{
// Assumptions:
// 1) blockDim.x == warpSize
// 2) Tensor is contiguous
// 3) 2*blockDim.y*sizeof(U)+blockDim.y*sizeof(int) shared memory available.
//
// compute variance and mean over n2
float count = 0.0f;
mu= float(0);
sigma2 = float(0);
if (i1 < n1) {
// one warp normalizes one n1 index,
// synchronization is implicit
// initialize with standard Welford algorithm
const int numx = blockDim.x * blockDim.y;
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
const at::Half* lvals = vals + i1*n2;
int l = 8*thrx;
if ((((size_t)lvals)&3) != 0) {
// 16 bit alignment
// first thread consumes first point
if (thrx == 0) {
float curr = static_cast<float>(lvals[0]);
cuWelfordOnlineSum(curr,mu,sigma2,count);
}
++l;
}
// at this point, lvals[l] are 32 bit aligned for all threads.
for (; l+7 < n2; l+=8*numx) {
for (int k = 0; k < 8; k+=2) {
float2 curr = __half22float2(*((__half2*)(lvals+l+k)));
cuWelfordOnlineSum(curr.x,mu,sigma2,count);
cuWelfordOnlineSum(curr.y,mu,sigma2,count);
}
}
for (; l < n2; ++l) {
float curr = static_cast<float>(lvals[l]);
cuWelfordOnlineSum(curr,mu,sigma2,count);
}
// intra-warp reductions
for (int l = 0; l <= 4; ++l) {
int srcLaneB = (threadIdx.x+(1<<l))&31;
float muB = WARP_SHFL(mu, srcLaneB);
float countB = WARP_SHFL(count, srcLaneB);
float sigma2B = WARP_SHFL(sigma2, srcLaneB);
cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count);
}
// threadIdx.x == 0 has correct values for each warp
// inter-warp reductions
if (blockDim.y > 1) {
float* ubuf = (float*)buf;
float* ibuf = (float*)(ubuf + blockDim.y);
for (int offset = blockDim.y/2; offset > 0; offset /= 2) {
// upper half of warps write to shared
if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2*offset) {
const int wrt_y = threadIdx.y - offset;
ubuf[2*wrt_y] = mu;
ubuf[2*wrt_y+1] = sigma2;
ibuf[wrt_y] = count;
}
__syncthreads();
// lower half merges
if (threadIdx.x == 0 && threadIdx.y < offset) {
float muB = ubuf[2*threadIdx.y];
float sigma2B = ubuf[2*threadIdx.y+1];
float countB = ibuf[threadIdx.y];
cuChanOnlineSum(muB,sigma2B,countB,mu,sigma2,count);
}
__syncthreads();
}
// threadIdx.x = 0 && threadIdx.y == 0 only thread that has correct values
if (threadIdx.x == 0 && threadIdx.y == 0) {
ubuf[0] = mu;
ubuf[1] = sigma2;
}
__syncthreads();
mu = ubuf[0];
sigma2 = ubuf[1]/float(n2);
// don't care about final value of count, we know count == n2
} else {
mu = WARP_SHFL(mu, 0);
sigma2 = WARP_SHFL(sigma2/float(n2), 0);
}
}
}
template<typename U> __device__ U rsqrt(U v) {
return U(1) / sqrt(v);
}
template<> __device__ float rsqrt(float v) {
return rsqrtf(v);
}
template<> __device__ double rsqrt(double v) {
return rsqrt(v);
}
namespace {
// This is the un-specialized struct. Note that we prevent instantiation of this
// struct by putting an undefined symbol in the function body so it won't compile.
// template <typename T>
// struct SharedMemory
// {
// // Ensure that we won't compile any un-specialized types
// __device__ T *getPointer()
// {
// extern __device__ void error(void);
// error();
// return NULL;
// }
// };
// https://github.com/NVIDIA/apex/issues/246
template <typename T>
struct SharedMemory;
template <>
struct SharedMemory <float>
{
__device__ float *getPointer()
{
HIP_DYNAMIC_SHARED( float, s_float)
return s_float;
}
};
}
template<typename T, typename U, typename V> __global__
void cuApplyLayerNorm(
V* __restrict__ output_vals,
U* __restrict__ mean,
U* __restrict__ invvar,
const T* __restrict__ vals,
const int n1,
const int n2,
const U epsilon,
const V* __restrict__ gamma,
const V* __restrict__ beta
)
{
// Assumptions:
// 1) blockDim.x == warpSize
// 2) Tensors are contiguous
//
for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
SharedMemory<U> shared;
U* buf = shared.getPointer();
U mu,sigma2;
cuWelfordMuSigma2(vals,n1,n2,i1,mu,sigma2,buf);
const T* lvals = vals + i1*n2;
V* ovals = output_vals + i1*n2;
U c_invvar = rsqrt(sigma2 + epsilon);
const int numx = blockDim.x * blockDim.y;
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
if (gamma != NULL && beta != NULL) {
for (int i = thrx; i < n2; i+=numx) {
U curr = static_cast<U>(lvals[i]);
ovals[i] = gamma[i] * static_cast<V>(c_invvar * (curr - mu)) + beta[i];
}
} else {
for (int i = thrx; i < n2; i+=numx) {
U curr = static_cast<U>(lvals[i]);
ovals[i] = static_cast<V>(c_invvar * (curr - mu));
}
}
if (threadIdx.x == 0 && threadIdx.y == 0) {
mean[i1] = mu;
invvar[i1] = c_invvar;
}
__syncthreads();
}
}
template<typename T, typename U, typename V> __device__
void cuLoadWriteStridedInputs(
const int i1_block,
const int thr_load_row_off,
const int thr_load_col_off,
const int i2_off,
const int row_stride,
U* warp_buf1,
U* warp_buf2,
const T* input,
const V* dout,
const int i1_end,
const int n2,
const U* __restrict__ mean,
const U* __restrict__ invvar
)
{
int i1 = i1_block+thr_load_row_off;
if (i1 < i1_end) {
U curr_mean = mean[i1];
U curr_invvar = invvar[i1];
for (int k = 0; k < blockDim.y; ++k) {
int i2 = i2_off + k;
int load_idx = i1*n2+i2;
int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
if (i2<n2) {
U curr_input = static_cast<U>(input[load_idx]);
U curr_dout = static_cast<U>(dout[load_idx]);
warp_buf1[write_idx] = curr_dout;
warp_buf2[write_idx] = curr_dout * (curr_input - curr_mean) * curr_invvar;
} else {
warp_buf1[write_idx] = U(0);
warp_buf2[write_idx] = U(0);
}
}
} else {
for (int k = 0; k < blockDim.y; ++k) {
int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
warp_buf1[write_idx] = U(0);
warp_buf2[write_idx] = U(0);
}
}
}
template<typename T, typename U, typename V> __device__
void cuLoadAddStridedInputs(
const int i1_block,
const int thr_load_row_off,
const int thr_load_col_off,
const int i2_off,
const int row_stride,
U* warp_buf1,
U* warp_buf2,
const T* input,
const V* dout,
const int i1_end,
const int n2,
const U* __restrict__ mean,
const U* __restrict__ invvar
)
{
int i1 = i1_block+thr_load_row_off;
if (i1 < i1_end) {
U curr_mean = mean[i1];
U curr_invvar = invvar[i1];
for (int k = 0; k < blockDim.y; ++k) {
int i2 = i2_off + k;
int load_idx = i1*n2+i2;
int write_idx = thr_load_row_off*row_stride+thr_load_col_off+k;
if (i2<n2) {
U curr_input = static_cast<U>(input[load_idx]);
U curr_dout = static_cast<U>(dout[load_idx]);
warp_buf1[write_idx] += curr_dout;
warp_buf2[write_idx] += curr_dout * (curr_input - curr_mean) * curr_invvar;
}
}
}
}
template<typename T, typename U, typename V> __global__
void cuComputePartGradGammaBeta(
const V* __restrict__ dout,
const T* __restrict__ input,
const int n1,
const int n2,
const U* __restrict__ mean,
const U* __restrict__ invvar,
U epsilon,
U* part_grad_gamma,
U* part_grad_beta)
{
const int numsegs_n1 = (n1+blockDim.y*blockDim.y-1) / (blockDim.y*blockDim.y);
const int segs_per_block = (numsegs_n1 + gridDim.y - 1) / gridDim.y;
const int i1_beg = blockIdx.y * segs_per_block * blockDim.y*blockDim.y;
const int i1_beg_plus_one = (blockIdx.y+1) * segs_per_block * blockDim.y*blockDim.y;
const int i1_end = i1_beg_plus_one < n1 ? i1_beg_plus_one : n1;
const int row_stride = blockDim.x+1;
const int thr_load_col_off = (threadIdx.x*blockDim.y)&(blockDim.x-1);
const int thr_load_row_off = (threadIdx.x*blockDim.y)/blockDim.x + threadIdx.y*blockDim.y;
const int i2_off = blockIdx.x * blockDim.x + thr_load_col_off;
SharedMemory<U> shared;
U* buf = shared.getPointer(); // buf has at least blockDim.x * blockDim.y * blockDim.y + (blockDim.y - 1)*(blockDim.x/blockDim.y) elements
U* warp_buf1 = (U*)buf;
U* warp_buf2 = warp_buf1 + blockDim.y * blockDim.y * row_stride;
// compute partial sums from strided inputs
// do this to increase number of loads in flight
cuLoadWriteStridedInputs(i1_beg,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar);
for (int i1_block = i1_beg+blockDim.y*blockDim.y; i1_block < i1_end; i1_block+=blockDim.y*blockDim.y) {
cuLoadAddStridedInputs(i1_block,thr_load_row_off,thr_load_col_off,i2_off,row_stride,warp_buf1,warp_buf2,input,dout,i1_end,n2,mean,invvar);
}
__syncthreads();
// inter-warp reductions
// sum within each warp
U acc1 = U(0);
U acc2 = U(0);
for (int k = 0; k < blockDim.y; ++k) {
int row1 = threadIdx.y + k*blockDim.y;
int idx1 = row1*row_stride + threadIdx.x;
acc1 += warp_buf1[idx1];
acc2 += warp_buf2[idx1];
}
warp_buf1[threadIdx.y*row_stride+threadIdx.x] = acc1;
warp_buf2[threadIdx.y*row_stride+threadIdx.x] = acc2;
__syncthreads();
// sum all warps
for (int offset = blockDim.y/2; offset > 1; offset /= 2) {
if (threadIdx.y < offset) {
int row1 = threadIdx.y;
int row2 = threadIdx.y + offset;
int idx1 = row1*row_stride + threadIdx.x;
int idx2 = row2*row_stride + threadIdx.x;
warp_buf1[idx1] += warp_buf1[idx2];
warp_buf2[idx1] += warp_buf2[idx2];
}
__syncthreads();
}
int i2 = blockIdx.x * blockDim.x + threadIdx.x;
if (threadIdx.y == 0 && i2 < n2) {
int row1 = threadIdx.y;
int row2 = threadIdx.y + 1;
int idx1 = row1*row_stride + threadIdx.x;
int idx2 = row2*row_stride + threadIdx.x;
part_grad_beta[blockIdx.y*n2+i2] = warp_buf1[idx1] + warp_buf1[idx2];
part_grad_gamma[blockIdx.y*n2+i2] = warp_buf2[idx1] + warp_buf2[idx2];
}
}
template<typename U, typename V> __global__
void cuComputeGradGammaBeta(
const U* part_grad_gamma,
const U* part_grad_beta,
const int part_size,
const int n1,
const int n2,
V* grad_gamma,
V* grad_beta)
{
// sum partial gradients for gamma and beta
SharedMemory<U> shared;
U* buf = shared.getPointer();
int i2 = blockIdx.x * blockDim.x + threadIdx.x;
if (i2 < n2) {
// each warp does sequential reductions until reduced part_size is num_warps
int num_warp_reductions = part_size / blockDim.y;
U sum_gamma = U(0);
U sum_beta = U(0);
const U* part_grad_gamma_ptr = part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2;
const U* part_grad_beta_ptr = part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2;
for (int warp_offset = 0; warp_offset < num_warp_reductions; ++warp_offset) {
sum_gamma += part_grad_gamma_ptr[warp_offset*n2];
sum_beta += part_grad_beta_ptr[warp_offset*n2];
}
// inter-warp reductions
const int nbsize3 = blockDim.x * blockDim.y / 2;
for (int offset = blockDim.y/2; offset >= 1; offset /= 2) {
// top half write to shared memory
if (threadIdx.y >= offset && threadIdx.y < 2*offset) {
const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
buf[write_idx] = sum_gamma;
buf[write_idx+nbsize3] = sum_beta;
}
__syncthreads();
// bottom half sums
if (threadIdx.y < offset) {
const int read_idx = threadIdx.y * blockDim.x + threadIdx.x;
sum_gamma += buf[read_idx];
sum_beta += buf[read_idx+nbsize3];
}
__syncthreads();
}
// write out fully summed gradients
if (threadIdx.y == 0) {
grad_gamma[i2] = sum_gamma;
grad_beta[i2] = sum_beta;
}
}
}
template<typename T, typename U, typename V> __global__
void cuComputeGradInput(
const V* __restrict__ dout,
const T* __restrict__ input,
const int n1,
const int n2,
const U* __restrict__ mean,
const U* __restrict__ invvar,
U epsilon,
const V* gamma,
T* grad_input)
{
for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
U sum_loss1 = U(0);
U sum_loss2 = U(0);
const U c_mean = mean[i1];
const U c_invvar = invvar[i1];
const T* k_input = input + i1*n2;
const V* k_dout = dout + i1*n2;
const int numx = blockDim.x * blockDim.y;
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
if (gamma != NULL) {
int l = 4*thrx;
for (; l+3 < n2; l+=4*numx) {
for (int k = 0; k < 4; ++k) {
const U c_h = static_cast<U>(k_input[l+k]);
const U c_loss = static_cast<U>(k_dout[l+k]);
sum_loss1 += c_loss * gamma[l+k];
sum_loss2 += c_loss * gamma[l+k] * (c_h - c_mean) * c_invvar;
}
}
for (; l < n2; ++l) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
sum_loss1 += c_loss * gamma[l];
sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar;
}
} else {
int l = 4*thrx;
for (; l+3 < n2; l+=4*numx) {
for (int k = 0; k < 4; ++k) {
const U c_h = static_cast<U>(k_input[l+k]);
const U c_loss = static_cast<U>(k_dout[l+k]);
sum_loss1 += c_loss;
sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
}
}
for (; l < n2; ++l) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
sum_loss1 += c_loss;
sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
}
}
// intra-warp reductions
for (int mask = blockDim.x/2; mask > 0; mask /= 2) {
sum_loss1 += WARP_SHFL_XOR(sum_loss1, mask);
sum_loss2 += WARP_SHFL_XOR(sum_loss2, mask);
}
// inter-warp reductions
if (blockDim.y > 1) {
SharedMemory<U> shared;
U* buf = shared.getPointer();
for (int offset = blockDim.y/2; offset > 0; offset /= 2) {
// upper half of warps write to shared
if (threadIdx.y >= offset && threadIdx.y < 2*offset) {
const int wrt_i = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
buf[2*wrt_i] = sum_loss1;
buf[2*wrt_i+1] = sum_loss2;
}
__syncthreads();
// lower half merges
if (threadIdx.y < offset) {
const int read_i = threadIdx.y * blockDim.x + threadIdx.x;
sum_loss1 += buf[2*read_i];
sum_loss2 += buf[2*read_i+1];
}
__syncthreads();
}
if (threadIdx.y == 0) {
buf[2*threadIdx.x] = sum_loss1;
buf[2*threadIdx.x+1] = sum_loss2;
}
__syncthreads();
if (threadIdx.y !=0) {
sum_loss1 = buf[2*threadIdx.x];
sum_loss2 = buf[2*threadIdx.x+1];
}
}
// all threads now have the two sums over l
U fH = (U)n2;
U term1 = (U(1) / fH) * c_invvar;
T* k_grad_input = grad_input + i1*n2;
if (gamma != NULL) {
for (int l = thrx; l < n2; l+=numx) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
U f_grad_input = fH * c_loss * gamma[l];
f_grad_input -= sum_loss1;
f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
f_grad_input *= term1;
k_grad_input[l] = static_cast<T>(f_grad_input);
}
} else {
for (int l = thrx; l < n2; l+=numx) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
U f_grad_input = fH * c_loss;
f_grad_input -= sum_loss1;
f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
f_grad_input *= term1;
k_grad_input[l] = static_cast<T>(f_grad_input);
}
}
// prevent race where buf is written again before reads are done
__syncthreads();
}
}
template<typename T, typename U, typename V>
void HostApplyLayerNorm(
V* output,
U* mean,
U* invvar,
const T* input,
int n1,
int n2,
double epsilon,
const V* gamma,
const V* beta
)
{
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
const dim3 threads(32,4,1);
const uint64_t maxGridY =
at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
const dim3 blocks(1, ::min((uint64_t)n1, maxGridY), 1);
int nshared =
threads.y > 1 ?
threads.y*sizeof(U)+(threads.y/2)*sizeof(U) :
0;
hipLaunchKernelGGL(( cuApplyLayerNorm), dim3(blocks), dim3(threads), nshared, stream,
output,
mean,
invvar,
input,
n1,n2,
U(epsilon),
gamma,beta);
}
void cuda_layer_norm(
at::Tensor* output,
at::Tensor* mean,
at::Tensor* invvar,
at::Tensor* input,
int n1,
int n2,
#ifdef VERSION_GE_1_1
at::IntArrayRef normalized_shape,
#else
at::IntList normalized_shape,
#endif
at::Tensor* gamma,
at::Tensor* beta,
double epsilon)
{
using namespace at;
DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
input->scalar_type(), output->scalar_type(), "cuda_layer_norm_kernel",
HostApplyLayerNorm(
output->DATA_PTR<scalar_t_out>(),
mean->DATA_PTR<float>(),
invvar->DATA_PTR<float>(),
input->DATA_PTR<scalar_t_in>(),
n1,n2,
epsilon,
gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
beta != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL);
)
}
template<typename T, typename U, typename V>
void HostLayerNormGradient(
const V* dout,
const U* mean,
const U* invvar,
at::Tensor* input,
int n1,
int n2,
const V* gamma,
const V* beta,
double epsilon,
T* grad_input,
V* grad_gamma,
V* grad_beta
)
{
auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
if (gamma != NULL && beta != NULL) {
// compute grad_gamma(j) and grad_beta(j)
const int part_size = 16;
const dim3 threads2(32,4,1);
const dim3 blocks2((n2+threads2.x-1)/threads2.x,part_size,1);
const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y *
(threads2.x + 1);
const int nshared2_b = threads2.x * threads2.y * sizeof(U);
const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
at::Tensor part_grad_gamma = at::empty(
{part_size,n2}, input->options().dtype(at::ScalarType::Float));
at::Tensor part_grad_beta = at::empty_like(part_grad_gamma);
hipLaunchKernelGGL(( cuComputePartGradGammaBeta), dim3(blocks2), dim3(threads2), nshared2, stream,
dout,
input->DATA_PTR<T>(),
n1,n2,
mean,
invvar,
U(epsilon),
part_grad_gamma.DATA_PTR<U>(),
part_grad_beta.DATA_PTR<U>());
const dim3 threads3(32,8,1);
const dim3 blocks3((n2+threads2.x-1)/threads2.x,1,1);
const int nshared3 = threads3.x * threads3.y * sizeof(U);
hipLaunchKernelGGL(( cuComputeGradGammaBeta), dim3(blocks3), dim3(threads3), nshared3, stream,
part_grad_gamma.DATA_PTR<U>(),
part_grad_beta.DATA_PTR<U>(),
part_size,
n1,n2,
grad_gamma,
grad_beta);
}
// compute grad_input
const uint64_t maxGridY =
at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
const dim3 blocks1(1, ::min((uint64_t)n1, maxGridY), 1);
const dim3 threads1(32,4,1);
int nshared =
threads1.y > 1 ?
threads1.y*threads1.x*sizeof(U) :
0;
hipLaunchKernelGGL(( cuComputeGradInput), dim3(blocks1), dim3(threads1), nshared, stream,
dout,
input->DATA_PTR<T>(),
n1,n2,
mean,
invvar,
U(epsilon),
gamma,
grad_input);
}
void cuda_layer_norm_gradient(
at::Tensor* dout,
at::Tensor* mean,
at::Tensor* invvar,
at::Tensor* input,
int n1,
int n2,
#ifdef VERSION_GE_1_1
at::IntArrayRef normalized_shape,
#else
at::IntList normalized_shape,
#endif
at::Tensor* gamma,
at::Tensor* beta,
double epsilon,
at::Tensor* grad_input,
at::Tensor* grad_gamma,
at::Tensor* grad_beta)
{
using namespace at;
DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
input->scalar_type(), gamma->scalar_type(),
"cuda_layer_norm_gradient_kernel",
HostLayerNormGradient(
dout->DATA_PTR<scalar_t_out>(),
mean->DATA_PTR<float>(),
invvar->DATA_PTR<float>(),
input,
n1,n2,
// TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta
// if gamma Tensor is NULL on input.
gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
gamma != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL,
epsilon,
grad_input->DATA_PTR<scalar_t_in>(),
gamma != NULL ? grad_gamma->DATA_PTR<scalar_t_out>() : NULL,
gamma != NULL ? grad_beta->DATA_PTR<scalar_t_out>() : NULL);
)
}
megatron/fused_kernels/scaled_masked_softmax.cpp
0 → 100644
View file @
1016e98a
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <hip/hip_fp16.h>
#include <torch/extension.h>
#include <vector>
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_masked_softmax {

torch::Tensor fwd_cuda(
    torch::Tensor const& input,
    torch::Tensor const& mask,
    float scale_factor);

torch::Tensor bwd_cuda(
    torch::Tensor const& output_grads,
    torch::Tensor const& softmax_results,
    float scale_factor);

int get_batch_per_block_cuda(
    int query_seq_len,
    int key_seq_len,
    int batches,
    int attn_heads);

torch::Tensor fwd(
    torch::Tensor const& input,
    torch::Tensor const& mask,
    float scale_factor)
{
  AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
  AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
             (input.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");
  AT_ASSERTM(mask.dim() == 4, "expected 4D tensor");

  return fwd_cuda(input, mask, scale_factor);
}

torch::Tensor bwd(
    torch::Tensor const& output_grads,
    torch::Tensor const& softmax_results,
    float scale_factor)
{
  AT_ASSERTM(output_grads.dim() == 4, "expected 4D tensor");
  AT_ASSERTM(softmax_results.dim() == 4, "expected 4D tensor");
  AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
             (output_grads.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");
  AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
             (softmax_results.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");

  return bwd_cuda(output_grads, softmax_results, scale_factor);
}

int get_batch_per_block(
    int query_seq_len,
    int key_seq_len,
    int batches,
    int attn_heads)
{
  return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads);
}

} // end namespace scaled_masked_softmax
} // end namespace fused_softmax
} // end namespace multihead_attn

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward",
        &multihead_attn::fused_softmax::scaled_masked_softmax::fwd,
        "Self Multihead Attention scaled, time masked softmax -- Forward.");
  m.def("backward",
        &multihead_attn::fused_softmax::scaled_masked_softmax::bwd,
        "Self Multihead Attention scaled, time masked softmax -- Backward.");
  m.def("get_batch_per_block",
        &multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block,
        "Return Batch per block size.");
}
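Once compiled, these bindings are driven from Python by Megatron's fused-softmax wrappers. A minimal usage sketch, assuming the extension was built under the name scaled_masked_softmax_cuda and a GPU build of PyTorch is available; shapes and values are illustrative only.

import torch
import scaled_masked_softmax_cuda  # module name assumed; built by megatron/fused_kernels

b, heads, sq, sk = 2, 16, 128, 128
inputs = torch.randn(b, heads, sq, sk, dtype=torch.half, device="cuda")
mask = torch.zeros(b, 1, sq, sk, dtype=torch.uint8, device="cuda")  # 1 = masked out

probs = scaled_masked_softmax_cuda.forward(inputs, mask, 1.0)
grads = scaled_masked_softmax_cuda.backward(probs.clone(), probs, 1.0)  # in-place on first arg
print(scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, heads))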
megatron/fused_kernels/scaled_masked_softmax.h
0 → 100644
View file @
1016e98a
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <assert.h>
#include <hip/hip_fp16.h>
#include <cfloat>
#include <limits>
#include <stdint.h>
#include <hip/hip_fp16.h>
#include <c10/macros/Macros.h>
namespace {

template <typename Datatype, int ELEMENTS_PER_LDG>
__device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);

template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); }

template <>
__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst, const c10::Half *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); }

template <>
__device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }

template <>
__device__ __inline__ void copy_vector<uint8_t, 4>(uint8_t *dst, const uint8_t *src) { *((half2*) dst) = *((half2*) src); }
int log2_ceil(int value) {
    int log2_value = 0;
    while ((1 << log2_value) < value) ++log2_value;
    return log2_value;
}
template<typename T>
struct Add {
    __device__ __forceinline__ T operator()(T a, T b) const {
        return a + b;
    }
};

template<typename T>
struct Max {
    __device__ __forceinline__ T operator()(T a, T b) const {
        return a < b ? b : a;
    }
};

template <typename T>
__device__ __forceinline__ T WARP_SHFL_XOR_NATIVE(T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff)
{
#if HIP_VERSION >= 9000
    return __shfl_xor_sync(mask, value, laneMask, width);
#else
    return __shfl_xor(value, laneMask, width);
#endif
}
template <typename acc_t, int WARP_BATCH, int WARP_SIZE, template<typename> class ReduceOp>
__device__ __forceinline__ void warp_reduce(acc_t* sum) {
    ReduceOp<acc_t> r;
    #pragma unroll
    for (int offset = WARP_SIZE / 2; offset > 0; offset /= 2) {
        #pragma unroll
        for (int i = 0; i < WARP_BATCH; ++i) {
            acc_t b = WARP_SHFL_XOR_NATIVE(sum[i], offset, WARP_SIZE);
            sum[i] = r(sum[i], b);
        }
    }
}
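warp_reduce combines the per-lane partials with an XOR-shuffle butterfly, so after log2(WARP_SIZE) exchange steps every lane holds the full reduction without touching shared memory. A small host-side Python sketch of the same exchange pattern, purely illustrative:

def butterfly_reduce(vals, op=lambda a, b: a + b):
    # vals: one value per lane; emulates the WARP_SHFL_XOR_NATIVE-based reduction.
    n = len(vals)          # assumed to be a power of two, like WARP_SIZE
    offset = n // 2
    while offset > 0:
        # each step exchanges current values simultaneously, like a warp shuffle
        vals = [op(vals[i], vals[i ^ offset]) for i in range(n)]
        offset //= 2
    return vals            # every "lane" now holds the same reduced value

print(butterfly_reduce(list(range(8))))  # all entries equal 28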
/*
 * Extended softmax (from native aten pytorch) with following additional features
 * 1) input scaling
 */
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_softmax_warp_forward(
    output_t *dst,
    const input_t *src,
    const acc_t scale,
    int micro_batch_size,
    int element_count)
{
    // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
    // warp_size of method warp_softmax_forward_kernel.
    constexpr int next_power_of_two = 1 << log2_elements;
    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
    constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;

    // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
    // gridDim/blockIdx = (seq_len, attn_heads, batches)
    int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z)) + threadIdx.y) * WARP_BATCH;

    // micro_batch_size might not be a multiple of WARP_BATCH. Check how
    // many batches have to be computed within this WARP.
    int local_batches = micro_batch_size - first_batch;
    if (local_batches > WARP_BATCH)
        local_batches = WARP_BATCH;

    // there might be multiple batches per warp. compute the index within the batch
    int local_idx = threadIdx.x;

    src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
    dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;

    // load data from global memory
    acc_t elements[WARP_BATCH][WARP_ITERATIONS];
    input_t temp_data[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        int batch_element_count = (i >= local_batches) ? 0 : element_count;

        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;

            if (element_index < batch_element_count) {
                int itr_idx = i * element_count + it * WARP_SIZE;
                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + itr_idx);

                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    elements[i][it + element] = (acc_t) temp_data[element] * scale;
                }
            } else {
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
                }
            }
        }
    }

    // compute max_value
    acc_t max_value[WARP_BATCH];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        max_value[i] = elements[i][0];
        #pragma unroll
        for (int it = 1; it < WARP_ITERATIONS; ++it) {
            max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);

    acc_t sum[WARP_BATCH] { 0.0f };
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; ++it) {
            elements[i][it] = std::exp((elements[i][it] - max_value[i]));
            sum[i] += elements[i][it];
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

    // store result
    output_t out[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        if (i >= local_batches)
            break;
        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
            if (element_index < element_count) {
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    out[element] = elements[i][it + element] / sum[i];
                }
                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);
            } else {
                break;
            }
        }
    }
}
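Per row, this kernel computes softmax(scale * x) with the usual max-subtraction for numerical stability. A NumPy reference that can be used to sanity-check a row; the helper name is assumed, not part of the source.

import numpy as np

def scaled_softmax_ref(x, scale):
    # x: (..., element_count); mirrors the per-row math of scaled_softmax_warp_forward.
    z = x.astype(np.float32) * scale
    z -= z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)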
/*
 * Extended softmax (from native aten pytorch) with following additional features
 * 1) input scaling
 * 2) Explicit masking
 */
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_masked_softmax_warp_forward(
    output_t *dst,
    const input_t *src,
    const uint8_t *mask,
    const acc_t scale,
    int micro_batch_size,
    int element_count,
    int pad_batches)
{
    // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
    // warp_size of method warp_softmax_forward_kernel.
    constexpr int next_power_of_two = 1 << log2_elements;
    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
    constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;

    // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
    // gridDim/blockIdx = (seq_len, attn_heads, batches)
    int first_batch = (blockDim.y * (blockIdx.x + gridDim.x * (blockIdx.y + gridDim.y * blockIdx.z)) + threadIdx.y) * WARP_BATCH;
    int pad_first_batch = 0;
    if (pad_batches != 1) { // bert style
        pad_first_batch = (blockDim.y * (blockIdx.x + gridDim.x * blockIdx.z) + threadIdx.y) * WARP_BATCH;
    } else { // gpt2 style
        pad_first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;
    }

    // micro_batch_size might not be a multiple of WARP_BATCH. Check how
    // many batches have to be computed within this WARP.
    int local_batches = micro_batch_size - first_batch;
    if (local_batches > WARP_BATCH)
        local_batches = WARP_BATCH;

    // there might be multiple batches per warp. compute the index within the batch
    int local_idx = threadIdx.x;

    src += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
    dst += first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
    mask += pad_first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;

    // load data from global memory
    acc_t elements[WARP_BATCH][WARP_ITERATIONS];
    input_t temp_data[ELEMENTS_PER_LDG_STG];
    uint8_t temp_mask[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        int batch_element_count = (i >= local_batches) ? 0 : element_count;

        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;

            if (element_index < batch_element_count) {
                int itr_idx = i * element_count + it * WARP_SIZE;
                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_data, src + itr_idx);
                copy_vector<uint8_t, ELEMENTS_PER_LDG_STG>(temp_mask, mask + itr_idx);

                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    if (temp_mask[element] != 1) {
                        elements[i][it + element] = (acc_t) temp_data[element] * scale;
                    } else {
                        elements[i][it + element] = -10000.0;
                    }
                }
            } else {
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    elements[i][it + element] = -std::numeric_limits<acc_t>::infinity();
                }
            }
        }
    }

    // compute max_value
    acc_t max_value[WARP_BATCH];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        max_value[i] = elements[i][0];
        #pragma unroll
        for (int it = 1; it < WARP_ITERATIONS; ++it) {
            max_value[i] = (max_value[i] > elements[i][it]) ? max_value[i] : elements[i][it];
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Max>(max_value);

    acc_t sum[WARP_BATCH] { 0.0f };
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; ++it) {
            elements[i][it] = std::exp((elements[i][it] - max_value[i]));
            sum[i] += elements[i][it];
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

    // store result
    output_t out[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        if (i >= local_batches)
            break;
        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
            if (element_index < element_count) {
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    out[element] = elements[i][it + element] / sum[i];
                }
                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(dst + i * element_count + it * WARP_SIZE, out);
            } else {
                break;
            }
        }
    }
}
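The masked variant differs only in that positions with mask == 1 are pushed to a large negative logit (-10000.0) before the row-wise softmax, while out-of-range elements get -inf. A NumPy reference under the same convention; the helper name is assumed, not part of the source.

import numpy as np

def scaled_masked_softmax_ref(x, mask, scale):
    # mask == 1 marks positions to suppress, matching the -10000.0 fill above.
    z = x.astype(np.float32) * scale
    z = np.where(mask.astype(bool), -10000.0, z)
    z -= z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)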
template <typename input_t, typename output_t, typename acc_t, int log2_elements>
__global__ void scaled_masked_softmax_warp_backward(
    output_t *gradInput,
    input_t *grad,
    const input_t *output,
    acc_t scale,
    int micro_batch_size,
    int element_count)
{
    // WARP_SIZE and WARP_BATCH must match the return values batches_per_warp and
    // warp_size of method warp_softmax_backward_kernel.
    constexpr int next_power_of_two = 1 << log2_elements;
    constexpr int WARP_SIZE = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    constexpr int WARP_ITERATIONS = next_power_of_two / WARP_SIZE;
    constexpr int WARP_BATCH = (next_power_of_two <= 128) ? 2 : 1;
    constexpr int ELEMENTS_PER_LDG_STG = (WARP_ITERATIONS < 4) ? 1 : 4;

    // blockDim/threadIdx = (WARP_SIZE, WARPS_PER_BLOCK, )
    // gridDim/blockIdx = (seq_len, attn_heads, batches)
    int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * WARP_BATCH;

    // micro_batch_size might not be a multiple of WARP_BATCH. Check how
    // many batches have to be computed within this WARP.
    int local_batches = micro_batch_size - first_batch;
    if (local_batches > WARP_BATCH)
        local_batches = WARP_BATCH;

    // there might be multiple batches per warp. compute the index within the batch
    int local_idx = threadIdx.x;

    // the first element to process by the current thread
    int thread_offset = first_batch * element_count + ELEMENTS_PER_LDG_STG * local_idx;
    grad += thread_offset;
    output += thread_offset;
    gradInput += thread_offset;

    // load data from global memory
    acc_t grad_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
    acc_t output_reg[WARP_BATCH][WARP_ITERATIONS] { 0.0f };
    input_t temp_grad[ELEMENTS_PER_LDG_STG];
    input_t temp_output[ELEMENTS_PER_LDG_STG];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        int batch_element_count = (i >= local_batches) ? 0 : element_count;

        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
            if (element_index < batch_element_count) {
                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_grad, grad + i * element_count + it * WARP_SIZE);
                copy_vector<input_t, ELEMENTS_PER_LDG_STG>(temp_output, output + i * element_count + it * WARP_SIZE);

                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    output_reg[i][it + element] = (acc_t) temp_output[element];
                }
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    grad_reg[i][it + element] = (acc_t) temp_grad[element] * output_reg[i][it + element];
                }
            }
        }
    }

    acc_t sum[WARP_BATCH];
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        sum[i] = grad_reg[i][0];
        #pragma unroll
        for (int it = 1; it < WARP_ITERATIONS; ++it) {
            sum[i] += grad_reg[i][it];
        }
    }
    warp_reduce<acc_t, WARP_BATCH, WARP_SIZE, Add>(sum);

    // store result
    #pragma unroll
    for (int i = 0; i < WARP_BATCH; ++i) {
        if (i >= local_batches)
            break;
        #pragma unroll
        for (int it = 0; it < WARP_ITERATIONS; it += ELEMENTS_PER_LDG_STG) {
            int element_index = ELEMENTS_PER_LDG_STG * local_idx + it * WARP_SIZE;
            if (element_index < element_count) {
                // compute gradients
                output_t out[ELEMENTS_PER_LDG_STG];
                #pragma unroll
                for (int element = 0; element < ELEMENTS_PER_LDG_STG; ++element) {
                    out[element] = (output_t)(scale * (grad_reg[i][it + element] - output_reg[i][it + element] * sum[i]));
                }
                copy_vector<output_t, ELEMENTS_PER_LDG_STG>(gradInput + i * element_count + it * WARP_SIZE, out);
            }
        }
    }
}
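The backward kernel applies the softmax Jacobian in registers: with y the saved probabilities and dy the incoming gradient, it produces scale * (y * dy - y * sum(y * dy)) per row. A NumPy reference of that formula; the helper name is assumed, not part of the source.

import numpy as np

def scaled_masked_softmax_bwd_ref(dy, y, scale):
    # y: softmax outputs, dy: upstream grads; mirrors grad_reg / sum in the kernel above.
    prod = dy.astype(np.float32) * y.astype(np.float32)   # grad_reg = dy * y
    row_sum = prod.sum(axis=-1, keepdims=True)            # sum[i]
    return scale * (prod - y.astype(np.float32) * row_sum)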
} // end of anonymous namespace

int get_batch_per_block(int query_seq_len, int key_seq_len, int batches, int attn_heads) {
    int log2_elements = log2_ceil(key_seq_len);
    const int next_power_of_two = 1 << log2_elements;

    int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;
    int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

    constexpr int threads_per_block = 128;
    int warps_per_block = (threads_per_block / warp_size);
    int batches_per_block = warps_per_block * batches_per_warp;

    return batches_per_block;
}
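As a worked example of this sizing rule: for key_seq_len = 1024, next_power_of_two is 1024, warp_size is capped at C10_WARP_SIZE, batches_per_warp is 1, so a 128-thread block covers 128 / warp_size rows. A direct Python mirror of the arithmetic; the warp-size cap is passed in explicitly here because it differs between NVIDIA warps (32) and AMD wavefronts (64).

import math

def get_batch_per_block_py(key_seq_len, warp_size_cap=64):
    # warp_size_cap plays the role of C10_WARP_SIZE (assumed 64 on AMD, 32 on NVIDIA).
    next_power_of_two = 1 << (0 if key_seq_len <= 1 else math.ceil(math.log2(key_seq_len)))
    warp_size = min(next_power_of_two, warp_size_cap)
    batches_per_warp = 2 if next_power_of_two <= 128 else 1
    threads_per_block = 128
    return (threads_per_block // warp_size) * batches_per_warp

print(get_batch_per_block_py(1024))  # -> 2 with 64-wide wavefronts, 4 with 32-wide warps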
template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_softmax_forward(
    output_t *dst,
    const input_t *src,
    const input_t scale,
    int query_seq_len,
    int key_seq_len,
    int batches,
    int attn_heads)
{
    TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 4096);
    if (key_seq_len == 0) {
        return;
    } else {
        int log2_elements = log2_ceil(key_seq_len);
        const int next_power_of_two = 1 << log2_elements;
        int batch_count = batches * attn_heads * query_seq_len;

        // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
        int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

        // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

        // use 128 threads per block to maximize gpu utilization
        constexpr int threads_per_block = 128;

        int warps_per_block = (threads_per_block / warp_size);
        int batches_per_block = warps_per_block * batches_per_warp;
        TORCH_INTERNAL_ASSERT(query_seq_len % batches_per_block == 0);
        dim3 blocks(query_seq_len / batches_per_block, attn_heads, batches);
        dim3 threads(warp_size, warps_per_block, 1);
        // Launch code would be more elegant if C++ supported FOR CONSTEXPR
        switch (log2_elements) {
            case 0: // 1
                scaled_softmax_warp_forward<input_t, output_t, acc_t, 0><<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, scale, batch_count, key_seq_len);
                break;
            case 1: // 2
                scaled_softmax_warp_forward<input_t, output_t, acc_t, 1><<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, scale, batch_count, key_seq_len);
                break;
            case 2: // 4
                scaled_softmax_warp_forward<input_t, output_t, acc_t, 2><<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, scale, batch_count, key_seq_len);
                break;
            case 3: // 8
                scaled_softmax_warp_forward<input_t, output_t, acc_t, 3><<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, scale, batch_count, key_seq_len);
                break;
            case 4: // 16
                scaled_softmax_warp_forward<input_t, output_t, acc_t, 4><<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, scale, batch_count, key_seq_len);
                break;
            case 5: // 32
                scaled_softmax_warp_forward<input_t, output_t, acc_t, 5><<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, scale, batch_count, key_seq_len);
                break;
            case 6: // 64
                scaled_softmax_warp_forward<input_t, output_t, acc_t, 6><<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, scale, batch_count, key_seq_len);
                break;
            case 7: // 128
                scaled_softmax_warp_forward<input_t, output_t, acc_t, 7><<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, scale, batch_count, key_seq_len);
                break;
            case 8: // 256
                scaled_softmax_warp_forward<input_t, output_t, acc_t, 8><<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, scale, batch_count, key_seq_len);
                break;
            case 9: // 512
                scaled_softmax_warp_forward<input_t, output_t, acc_t, 9><<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, scale, batch_count, key_seq_len);
                break;
            case 10: // 1024
                scaled_softmax_warp_forward<input_t, output_t, acc_t, 10><<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, scale, batch_count, key_seq_len);
                break;
            case 11: // 2048
                scaled_softmax_warp_forward<input_t, output_t, acc_t, 11><<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, scale, batch_count, key_seq_len);
                break;
            case 12: // 4096
                scaled_softmax_warp_forward<input_t, output_t, acc_t, 12><<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, scale, batch_count, key_seq_len);
                break;
            default:
                break;
        }
    }
}

template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_masked_softmax_forward(
    output_t *dst,
    const input_t *src,
    const uint8_t *mask,
    const input_t scale,
    int query_seq_len,
    int key_seq_len,
    int batches,
    int attn_heads,
    int pad_batches)
{
    TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 4096);
    if (key_seq_len == 0) {
        return;
    } else {
        int log2_elements = log2_ceil(key_seq_len);
        const int next_power_of_two = 1 << log2_elements;
        int batch_count = batches * attn_heads * query_seq_len;

        // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_forward.
        int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

        // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_forward.
        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

        // use 128 threads per block to maximize gpu utilization
        constexpr int threads_per_block = 128;

        int warps_per_block = (threads_per_block / warp_size);
        int batches_per_block = warps_per_block * batches_per_warp;
        TORCH_INTERNAL_ASSERT(query_seq_len % batches_per_block == 0);
        dim3 blocks(query_seq_len / batches_per_block, attn_heads, batches);
        dim3 threads(warp_size, warps_per_block, 1);
        // Launch code would be more elegant if C++ supported FOR CONSTEXPR
        switch (log2_elements) {
            case 0: // 1
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 0><<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 1: // 2
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 1><<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 2: // 4
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 2><<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 3: // 8
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 3><<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 4: // 16
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 4><<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 5: // 32
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 5><<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 6: // 64
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 6><<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 7: // 128
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 7><<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 8: // 256
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 8><<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 9: // 512
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 9><<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 10: // 1024
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 10><<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 11: // 2048
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 11><<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            case 12: // 4096
                scaled_masked_softmax_warp_forward<input_t, output_t, acc_t, 12><<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(dst, src, mask, scale, batch_count, key_seq_len, pad_batches);
                break;
            default:
                break;
        }
    }
}

template<typename input_t, typename output_t, typename acc_t>
void dispatch_scaled_masked_softmax_backward(
    output_t *grad_input,
    input_t *grad,
    const input_t *output,
    const acc_t scale,
    int query_seq_len,
    int key_seq_len,
    int batches,
    int attn_heads)
{
    TORCH_INTERNAL_ASSERT(key_seq_len >= 0 && key_seq_len <= 4096);
    if (key_seq_len == 0) {
        return;
    } else {
        int log2_elements = log2_ceil(key_seq_len);
        const int next_power_of_two = 1 << log2_elements;
        int batch_count = batches * attn_heads * query_seq_len;

        // This value must match the WARP_SIZE constexpr value computed inside softmax_warp_backward.
        int warp_size = (next_power_of_two < C10_WARP_SIZE) ? next_power_of_two : C10_WARP_SIZE;

        // This value must match the WARP_BATCH constexpr value computed inside softmax_warp_backward.
        int batches_per_warp = (next_power_of_two <= 128) ? 2 : 1;

        // use 128 threads per block to maximize gpu utilization
        constexpr int threads_per_block = 128;

        int warps_per_block = (threads_per_block / warp_size);
        int batches_per_block = warps_per_block * batches_per_warp;
        int blocks = batch_count / batches_per_block;
        dim3 threads(warp_size, warps_per_block, 1);
        // Launch code would be more elegant if C++ supported FOR CONSTEXPR
        switch (log2_elements) {
            case 0: // 1
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 0><<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 1: // 2
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 1><<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 2: // 4
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 2><<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 3: // 8
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 3><<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 4: // 16
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 4><<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 5: // 32
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 5><<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 6: // 64
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 6><<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 7: // 128
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 7><<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 8: // 256
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 8><<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 9: // 512
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 9><<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 10: // 1024
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 10><<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 11: // 2048
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 11><<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            case 12: // 4096
                scaled_masked_softmax_warp_backward<input_t, output_t, acc_t, 12><<<blocks, threads, 0, at::hip::getCurrentHIPStream()>>>(grad_input, grad, output, scale, batch_count, key_seq_len);
                break;
            default:
                break;
        }
    }
}
megatron/fused_kernels/scaled_masked_softmax_cuda.cu
0 → 100644
View file @
1016e98a
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "scaled_masked_softmax.h"
#include "type_shim.h"
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_masked_softmax {

int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads){
    return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads);
}

torch::Tensor fwd_cuda(
    torch::Tensor const& input,
    torch::Tensor const& mask,
    float scale_factor)
{
  // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
  const int batches = input.size(0);
  const int pad_batches = mask.size(0);
  const int attn_heads = input.size(1);
  const int query_seq_len = input.size(2);
  const int key_seq_len = input.size(3);
  TORCH_INTERNAL_ASSERT(key_seq_len <= 4096);
  TORCH_INTERNAL_ASSERT(query_seq_len > 1);
  TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches);
  TORCH_INTERNAL_ASSERT(mask.size(1) == 1);
  TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len);
  TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len);

  // Output
  auto act_options = input.options().requires_grad(false);
  torch::Tensor softmax_results =
      torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);

  // Softmax Intermediate Result Ptr
  void* input_ptr = static_cast<void*>(input.data_ptr());
  void* mask_ptr = static_cast<void*>(mask.data_ptr());
  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());

  DISPATCH_HALF_AND_BFLOAT(
      input.scalar_type(),
      "dispatch_scaled_masked_softmax_forward",
      dispatch_scaled_masked_softmax_forward<scalar_t, scalar_t, float>(
          reinterpret_cast<scalar_t*>(softmax_results_ptr),
          reinterpret_cast<const scalar_t*>(input_ptr),
          reinterpret_cast<const uint8_t*>(mask_ptr),
          scale_factor,
          query_seq_len,
          key_seq_len,
          batches,
          attn_heads,
          pad_batches);
  );
  return softmax_results;
}

torch::Tensor bwd_cuda(
    torch::Tensor const& output_grads_,
    torch::Tensor const& softmax_results_,
    float scale_factor)
{
  auto output_grads = output_grads_.contiguous();
  auto softmax_results = softmax_results_.contiguous();

  // output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
  const int batches = output_grads.size(0);
  const int attn_heads = output_grads.size(1);
  const int query_seq_len = output_grads.size(2);
  const int key_seq_len = output_grads.size(3);

  void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());

  // Softmax Grad
  DISPATCH_HALF_AND_BFLOAT(
      output_grads_.scalar_type(),
      "dispatch_scaled_masked_softmax_backward",
      dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(
          reinterpret_cast<scalar_t*>(output_grads_ptr),
          reinterpret_cast<scalar_t*>(output_grads_ptr),
          reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
          scale_factor,
          query_seq_len,
          key_seq_len,
          batches,
          attn_heads);
  );

  // backward pass is completely in-place
  return output_grads;
}
}
}
}
megatron/fused_kernels/scaled_masked_softmax_hip.hip
0 → 100644
View file @
1016e98a
// !!! This is a file automatically generated by hipify!!!
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/ATen.h>
#include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
//#include <cuda_profiler_api.h>
#include <ATen/hip/HIPContext.h>
#include <torch/extension.h>
#include "scaled_masked_softmax.h"
#include "type_shim.h"
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_masked_softmax {
int get_batch_per_block_cuda(int query_seq_len, int key_seq_len, int batches, int attn_heads){
return get_batch_per_block(query_seq_len, key_seq_len, batches, attn_heads);
}
torch::Tensor fwd_cuda(
torch::Tensor const& input,
torch::Tensor const& mask,
float scale_factor)
{
// input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
const int batches = input.size(0);
const int pad_batches = mask.size(0);
const int attn_heads = input.size(1);
const int query_seq_len = input.size(2);
const int key_seq_len = input.size(3);
TORCH_INTERNAL_ASSERT(key_seq_len <= 4096);
TORCH_INTERNAL_ASSERT(query_seq_len > 1);
TORCH_INTERNAL_ASSERT(pad_batches == 1 || pad_batches == batches);
TORCH_INTERNAL_ASSERT(mask.size(1) == 1);
TORCH_INTERNAL_ASSERT(mask.size(2) == query_seq_len);
TORCH_INTERNAL_ASSERT(mask.size(3) == key_seq_len);
// Output
auto act_options = input.options().requires_grad(false);
torch::Tensor softmax_results =
torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);
// Softmax Intermediate Result Ptr
void* input_ptr = static_cast<void*>(input.data_ptr());
void* mask_ptr = static_cast<void*>(mask.data_ptr());
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
DISPATCH_HALF_AND_BFLOAT(
input.scalar_type(),
"dispatch_scaled_masked_softmax_forward",
dispatch_scaled_masked_softmax_forward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(softmax_results_ptr),
reinterpret_cast<const scalar_t*>(input_ptr),
reinterpret_cast<const uint8_t*>(mask_ptr),
scale_factor,
query_seq_len,
key_seq_len,
batches,
attn_heads,
pad_batches);
);
return softmax_results;
}
torch::Tensor bwd_cuda(
torch::Tensor const& output_grads_,
torch::Tensor const& softmax_results_,
float scale_factor) {
auto output_grads = output_grads_.contiguous();
auto softmax_results = softmax_results_.contiguous();
//output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
const int batches = output_grads.size(0);
const int attn_heads = output_grads.size(1);
const int query_seq_len = output_grads.size(2);
const int key_seq_len = output_grads.size(3);
void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
//Softmax Grad
DISPATCH_HALF_AND_BFLOAT(
output_grads_.scalar_type(),
"dispatch_scaled_masked_softmax_backward",
dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
scale_factor,
query_seq_len,
key_seq_len,
batches,
attn_heads);
);
//backward pass is completely in-place
return output_grads;
}
}
}
}
megatron/fused_kernels/scaled_softmax.cpp
0 → 100644
View file @
1016e98a
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <hip/hip_fp16.h>
#include <torch/extension.h>
#include <vector>
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_softmax {

torch::Tensor fwd_cuda(
    torch::Tensor const& input,
    float scale_factor);

torch::Tensor bwd_cuda(
    torch::Tensor const& output_grads,
    torch::Tensor const& softmax_results,
    float scale_factor);

torch::Tensor fwd(
    torch::Tensor const& input,
    float scale_factor)
{
  AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
  AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
             (input.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");

  return fwd_cuda(input, scale_factor);
}

torch::Tensor bwd(
    torch::Tensor const& output_grads,
    torch::Tensor const& softmax_results,
    float scale_factor)
{
  AT_ASSERTM(output_grads.dim() == 4, "expected 4D tensor");
  AT_ASSERTM(softmax_results.dim() == 4, "expected 4D tensor");
  AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
             (output_grads.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");
  AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
             (softmax_results.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");

  return bwd_cuda(output_grads, softmax_results, scale_factor);
}

} // end namespace scaled_softmax
} // end namespace fused_softmax
} // end namespace multihead_attn

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward",
        &multihead_attn::fused_softmax::scaled_softmax::fwd,
        "Self Multihead Attention scaled, softmax -- Forward.");
  m.def("backward",
        &multihead_attn::fused_softmax::scaled_softmax::bwd,
        "Self Multihead Attention scaled, softmax -- Backward.");
}
megatron/fused_kernels/scaled_softmax_cuda.cu
0 → 100644
View file @
1016e98a
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "scaled_masked_softmax.h"
#include "type_shim.h"
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_softmax {

torch::Tensor fwd_cuda(
    torch::Tensor const& input,
    float scale_factor)
{
  // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
  const int batches = input.size(0);
  const int attn_heads = input.size(1);
  const int query_seq_len = input.size(2);
  const int key_seq_len = input.size(3);
  TORCH_INTERNAL_ASSERT(key_seq_len <= 4096);
  TORCH_INTERNAL_ASSERT(query_seq_len > 1);

  // Output
  auto act_options = input.options().requires_grad(false);
  torch::Tensor softmax_results =
      torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);

  // Softmax Intermediate Result Ptr
  void* input_ptr = static_cast<void*>(input.data_ptr());
  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());

  DISPATCH_HALF_AND_BFLOAT(
      input.scalar_type(),
      "dispatch_scaled_softmax_forward",
      dispatch_scaled_softmax_forward<scalar_t, scalar_t, float>(
          reinterpret_cast<scalar_t*>(softmax_results_ptr),
          reinterpret_cast<const scalar_t*>(input_ptr),
          scale_factor,
          query_seq_len,
          key_seq_len,
          batches,
          attn_heads);
  );
  return softmax_results;
}

torch::Tensor bwd_cuda(
    torch::Tensor const& output_grads_,
    torch::Tensor const& softmax_results_,
    float scale_factor)
{
  auto output_grads = output_grads_.contiguous();
  auto softmax_results = softmax_results_.contiguous();

  // output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
  const int batches = output_grads.size(0);
  const int attn_heads = output_grads.size(1);
  const int query_seq_len = output_grads.size(2);
  const int key_seq_len = output_grads.size(3);

  void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());

  // Softmax Grad
  DISPATCH_HALF_AND_BFLOAT(
      output_grads_.scalar_type(),
      "dispatch_scaled_masked_softmax_backward",
      dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(
          reinterpret_cast<scalar_t*>(output_grads_ptr),
          reinterpret_cast<scalar_t*>(output_grads_ptr),
          reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
          scale_factor,
          query_seq_len,
          key_seq_len,
          batches,
          attn_heads);
  );

  // backward pass is completely in-place
  return output_grads;
}
}
}
}
megatron/fused_kernels/scaled_softmax_hip.hip
0 → 100644
View file @
1016e98a
// !!! This is a file automatically generated by hipify!!!
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/ATen.h>
#include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
//#include <cuda_profiler_api.h>
#include <ATen/hip/HIPContext.h>
#include <torch/extension.h>
#include "scaled_masked_softmax.h"
#include "type_shim.h"
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_softmax {
torch::Tensor fwd_cuda(
torch::Tensor const& input,
float scale_factor)
{
// input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
const int batches = input.size(0);
const int attn_heads = input.size(1);
const int query_seq_len = input.size(2);
const int key_seq_len = input.size(3);
TORCH_INTERNAL_ASSERT(key_seq_len <= 4096);
TORCH_INTERNAL_ASSERT(query_seq_len > 1);
// Output
auto act_options = input.options().requires_grad(false);
torch::Tensor softmax_results =
torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);
// Softmax Intermediate Result Ptr
void* input_ptr = static_cast<void*>(input.data_ptr());
void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());
DISPATCH_HALF_AND_BFLOAT(
input.scalar_type(),
"dispatch_scaled_softmax_forward",
dispatch_scaled_softmax_forward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(softmax_results_ptr),
reinterpret_cast<const scalar_t*>(input_ptr),
scale_factor,
query_seq_len,
key_seq_len,
batches,
attn_heads);
);
return softmax_results;
}
torch::Tensor bwd_cuda(
torch::Tensor const& output_grads_,
torch::Tensor const& softmax_results_,
float scale_factor) {
auto output_grads = output_grads_.contiguous();
auto softmax_results = softmax_results_.contiguous();
//output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
const int batches = output_grads.size(0);
const int attn_heads = output_grads.size(1);
const int query_seq_len = output_grads.size(2);
const int key_seq_len = output_grads.size(3);
void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());
//Softmax Grad
DISPATCH_HALF_AND_BFLOAT(
output_grads_.scalar_type(),
"dispatch_scaled_masked_softmax_backward",
dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t*>(output_grads_ptr),
reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
scale_factor,
query_seq_len,
key_seq_len,
batches,
attn_heads);
);
//backward pass is completely in-place
return output_grads;
}
}
}
}
megatron/fused_kernels/scaled_upper_triang_masked_softmax.cpp
0 → 100644
View file @
1016e98a
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <hip/hip_fp16.h>
#include <torch/extension.h>
#include <vector>
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_upper_triang_masked_softmax {

torch::Tensor fwd_cuda(
    torch::Tensor const& input,
    float scale_factor);

torch::Tensor bwd_cuda(
    torch::Tensor const& output_grads,
    torch::Tensor const& softmax_results,
    float scale_factor);

torch::Tensor fwd(
    torch::Tensor const& input,
    float scale_factor)
{
  AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
  AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
             (input.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");

  return fwd_cuda(input, scale_factor);
}

torch::Tensor bwd(
    torch::Tensor const& output_grads,
    torch::Tensor const& softmax_results,
    float scale_factor)
{
  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");
  AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
             (output_grads.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");
  AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
             (softmax_results.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");

  return bwd_cuda(output_grads, softmax_results, scale_factor);
}

} // end namespace scaled_upper_triang_masked_softmax
} // end namespace fused_softmax
} // end namespace multihead_attn

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward",
        &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd,
        "Self Multihead Attention scaled, time masked softmax -- Forward.");
  m.def("backward",
        &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd,
        "Self Multihead Attention scaled, time masked softmax -- Backward.");
}
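Unlike the explicitly masked kernel, this upper-triangular variant takes a 3D input of shape [attn_batches, sq, sk] and applies the causal mask implicitly. A minimal usage sketch, assuming the extension was built under the name scaled_upper_triang_masked_softmax_cuda; names and shapes are illustrative only.

import torch
import scaled_upper_triang_masked_softmax_cuda  # module name assumed; built by megatron/fused_kernels

attn_batches, sq, sk = 32, 128, 128   # 3D input; the causal (upper-triangular) mask is implicit
logits = torch.randn(attn_batches, sq, sk, dtype=torch.half, device="cuda")

probs = scaled_upper_triang_masked_softmax_cuda.forward(logits, 1.0)
grads = scaled_upper_triang_masked_softmax_cuda.backward(probs.clone(), probs, 1.0)  # in-place on first arg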