OpenDAS / Megatron-LM, commit 0fa7175f

Authored Mar 19, 2021 by Mohammad Shoeybi; committed by Jared Casper, Mar 19, 2021.

    Bfloat fused softmax + fused layer norm

parent d9b1c681

Showing 18 changed files with 282 additions and 517 deletions (+282 −517)
megatron/arguments.py                                               +1   −2
megatron/fused_kernels/__init__.py                                  +5   −6
megatron/fused_kernels/layer_norm_cuda.cpp                          +31  −50
megatron/fused_kernels/layer_norm_cuda_kernel.cu                    +33  −33
megatron/fused_kernels/scaled_masked_softmax.cpp                    +9   −6
megatron/fused_kernels/scaled_masked_softmax.h                      +8   −2
megatron/fused_kernels/scaled_masked_softmax_cuda.cu                +28  −19
megatron/fused_kernels/scaled_upper_triang_masked_softmax.cpp       +10  −7
megatron/fused_kernels/scaled_upper_triang_masked_softmax.h         +16  −5
megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu   +25  −15
megatron/fused_kernels/type_shim.h                                  +69  −205
megatron/model/__init__.py                                          +1   −17
megatron/model/bert_model.py                                        +1   −2
megatron/model/fused_layer_norm.py                                  +28  −117
megatron/model/fused_softmax.py                                     +13  −5
megatron/model/transformer.py                                       +3   −13
megatron/optimizer/__init__.py                                      +1   −3
megatron/training.py                                                +0   −10
megatron/arguments.py

@@ -133,8 +133,7 @@ def parse_args(extra_args_provider=None, defaults={},
     if args.bf16:
         assert not args.fp16
         args.params_dtype = torch.bfloat16
-        # No fusion is support for bfloat for now
-        assert not args.masked_softmax_fusion
+        # Jitting fusion is not supported for bfloat for now
         assert not args.bias_gelu_fusion
         assert not args.bias_dropout_fusion
...
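Note: the hunk above relaxes the bf16 checks so that the fused masked softmax may now be used with bfloat16, while the jitted bias fusions remain disallowed. A minimal, self-contained sketch of that validation logic (plain Python with a SimpleNamespace standing in for the parsed args; not the repo's parse_args):

# Hedged sketch of the post-commit bf16 checks; the helper name and the
# SimpleNamespace args object are illustrative, not part of Megatron-LM.
import torch
from types import SimpleNamespace

def validate_bf16_args(args):
    if args.bf16:
        assert not args.fp16, 'fp16 and bf16 are mutually exclusive'
        args.params_dtype = torch.bfloat16
        # Jitted fusions are still unsupported for bfloat16.
        assert not args.bias_gelu_fusion
        assert not args.bias_dropout_fusion
    return args

args = SimpleNamespace(bf16=True, fp16=False, params_dtype=torch.float32,
                       masked_softmax_fusion=True,  # no longer asserted against
                       bias_gelu_fusion=False, bias_dropout_fusion=False)
print(validate_bf16_args(args).params_dtype)  # torch.bfloat16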
megatron/fused_kernels/__init__.py

@@ -82,12 +82,11 @@ def load(args):
     # =================================
     # Mixed precision fused layer norm.
     # =================================

-    if args.fp32_residual_connection:
-        extra_cuda_flags = ['-maxrregcount=50']
-        sources=[srcpath / 'layer_norm_cuda.cpp',
-                 srcpath / 'layer_norm_cuda_kernel.cu']
-        fused_mix_prec_layer_norm_cuda = _cpp_extention_load_helper(
-            "fused_mix_prec_layer_norm_cuda", sources, extra_cuda_flags)
+    extra_cuda_flags = ['-maxrregcount=50']
+    sources=[srcpath / 'layer_norm_cuda.cpp',
+             srcpath / 'layer_norm_cuda_kernel.cu']
+    fused_mix_prec_layer_norm_cuda = _cpp_extention_load_helper(
+        "fused_mix_prec_layer_norm_cuda", sources, extra_cuda_flags)


 def _get_cuda_bare_metal_version(cuda_dir):
...
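For context on what `_cpp_extention_load_helper` does with these sources: the fused kernels are JIT-compiled through PyTorch's C++ extension loader. A hedged sketch of that build step (the flag mirrors the hunk above, but the exact helper arguments and build options used by the repo may differ):

# Illustrative only: building fused_mix_prec_layer_norm_cuda with
# torch.utils.cpp_extension.load. Paths and extra flags are assumptions.
from pathlib import Path
from torch.utils.cpp_extension import load

srcpath = Path('megatron/fused_kernels')
fused_mix_prec_layer_norm_cuda = load(
    name='fused_mix_prec_layer_norm_cuda',
    sources=[str(srcpath / 'layer_norm_cuda.cpp'),
             str(srcpath / 'layer_norm_cuda_kernel.cu')],
    extra_cuda_cflags=['-maxrregcount=50'],
    verbose=True)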
megatron/fused_kernels/layer_norm_cuda.cpp

@@ -24,12 +24,12 @@
 #include "compat.h"

 namespace {
 void compute_n1_n2(
     at::Tensor input,
     at::IntArrayRef normalized_shape,
     int& n1,
     int& n2)
 {
     int idiff = input.ndimension() - normalized_shape.size();
     n2 = 1;
     for (int i = 0;  i < (int)normalized_shape.size();  ++i) {
...
@@ -118,39 +118,33 @@ void cuda_layer_norm(
 #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
 #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

-std::vector<at::Tensor> layer_norm(
-    at::Tensor input,
-    at::IntArrayRef normalized_shape,
-    double epsilon) {
-  CHECK_INPUT(input);
-  int n1, n2;
-  check_args(input, normalized_shape, n1, n2);
-  at::Tensor output = at::empty_like(input);
-  at::Tensor mean = at::empty({n1}, input.options().dtype(
-      input.scalar_type() == at::ScalarType::Half ? at::ScalarType::Float : input.scalar_type()));
-  at::Tensor invvar = at::empty_like(mean);
-  cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2,
-      normalized_shape, NULL, NULL, epsilon);
-  return {output, mean, invvar};
-}

 std::vector<at::Tensor> layer_norm_affine(
     at::Tensor input,
     at::IntArrayRef normalized_shape,
     at::Tensor gamma,
     at::Tensor beta,
     double epsilon) {
   CHECK_INPUT(input);
   CHECK_INPUT(gamma);
   CHECK_INPUT(beta);
   int n1, n2;
   check_args(input, normalized_shape, gamma, beta, n1, n2);
-  at::Tensor output = at::empty_like(input, input.options().dtype(at::ScalarType::Half));
-  at::Tensor mean = at::empty({n1}, input.options().dtype(
-      input.scalar_type() == at::ScalarType::Half ? at::ScalarType::Float : input.scalar_type()));
+  at::Tensor output = at::empty_like(
+      input, gamma.options().dtype(gamma.scalar_type()));
+  at::Tensor mean = at::empty(
+      {n1}, input.options().dtype(at::ScalarType::Float));
   at::Tensor invvar = at::empty_like(mean);
   cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2,
       normalized_shape, &gamma, &beta, epsilon);
   return {output, mean, invvar};
 }

 void cuda_layer_norm_gradient(
     at::Tensor* dout,
     at::Tensor* mean,
...
@@ -167,25 +161,6 @@ void cuda_layer_norm_gradient(
     at::Tensor* grad_beta
     );

-at::Tensor layer_norm_gradient(
-    at::Tensor dout,
-    at::Tensor mean,
-    at::Tensor invvar,
-    at::Tensor input,
-    at::IntArrayRef normalized_shape,
-    double epsilon) {
-  CHECK_INPUT(dout);
-  CHECK_INPUT(mean);
-  CHECK_INPUT(invvar);
-  CHECK_INPUT(input);
-  int n1, n2;
-  check_args(input, normalized_shape, n1, n2);
-  at::Tensor grad_input = at::empty_like(input);
-  cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2,
-      normalized_shape, NULL, NULL, epsilon,
-      &grad_input, NULL, NULL);
-  return grad_input;
-}

 std::vector<at::Tensor> layer_norm_gradient_affine(
     at::Tensor dout,
     at::Tensor mean,
...
@@ -195,26 +170,32 @@ std::vector<at::Tensor> layer_norm_gradient_affine(
     at::Tensor gamma,
     at::Tensor beta,
     double epsilon) {
   CHECK_INPUT(dout);
   CHECK_INPUT(mean);
   CHECK_INPUT(invvar);
   CHECK_INPUT(input);
   CHECK_INPUT(gamma);
   CHECK_INPUT(beta);
   int n1, n2;
   check_args(input, normalized_shape, gamma, beta, n1, n2);
   at::Tensor grad_input = at::empty_like(input);
   at::Tensor grad_gamma = at::empty_like(gamma);
   at::Tensor grad_beta = at::empty_like(beta);
-  cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2, normalized_shape, &gamma, &beta, epsilon,
-      &grad_input, &grad_gamma, &grad_beta);
+  cuda_layer_norm_gradient(&dout, &mean, &invvar, &input, n1, n2,
+      normalized_shape, &gamma, &beta, epsilon,
+      &grad_input, &grad_gamma, &grad_beta);
   return {grad_input, grad_gamma, grad_beta};
 }

 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("forward_affine", &layer_norm_affine, "LayerNorm forward (CUDA)");
-  m.def("forward", &layer_norm, "LayerNorm forward (CUDA)");
-  m.def("backward_affine", &layer_norm_gradient_affine, "LayerNorm backward (CUDA)");
-  m.def("backward", &layer_norm_gradient, "LayerNorm backward (CUDA)");
+  m.def("forward_affine", &layer_norm_affine,
+        "LayerNorm forward (CUDA)");
+  m.def("backward_affine", &layer_norm_gradient_affine,
+        "LayerNorm backward (CUDA)");
 }
megatron/fused_kernels/layer_norm_cuda_kernel.cu

@@ -285,15 +285,6 @@ struct SharedMemory <float>
     }
 };

-template <>
-struct SharedMemory <double>
-{
-    __device__ double *getPointer()
-    {
-        extern __shared__ double s_double[];
-        return s_double;
-    }
-};
 }

 template<typename T, typename U, typename V> __global__
...
@@ -656,6 +647,9 @@ void cuComputeGradInput(
     }
 }

 template<typename T, typename U, typename V>
 void HostApplyLayerNorm(
     V* output,
...
@@ -671,7 +665,8 @@ void HostApplyLayerNorm(
 {
     auto stream = at::cuda::getCurrentCUDAStream().stream();
     const dim3 threads(32,4,1);
     const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
     const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1);
     int nshared =
         threads.y > 1 ?
...
@@ -687,6 +682,7 @@ void HostApplyLayerNorm(
           gamma,beta);
 }

 void cuda_layer_norm(
     at::Tensor* output,
     at::Tensor* mean,
...
@@ -704,21 +700,21 @@ void cuda_layer_norm(
     double epsilon)
 {
     using namespace at;
-    DISPATCH_DOUBLE_FLOAT_AND_HALF(input->scalar_type(), 0, "layer_norm_cuda_kernel",
-        using accscalar_t = at::acc_type<scalar_t_0, true>;
-        using output_t = at::Half;
+    DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
+        input->scalar_type(), output->scalar_type(), "cuda_layer_norm_kernel",
         HostApplyLayerNorm(
-            output->DATA_PTR<output_t>(),
-            mean->DATA_PTR<accscalar_t>(),
-            invvar->DATA_PTR<accscalar_t>(),
-            input->DATA_PTR<scalar_t_0>(),
+            output->DATA_PTR<scalar_t_out>(),
+            mean->DATA_PTR<float>(),
+            invvar->DATA_PTR<float>(),
+            input->DATA_PTR<scalar_t_in>(),
             n1,n2,
             epsilon,
-            gamma != NULL ? gamma->DATA_PTR<output_t>() : NULL,
-            beta != NULL ? beta->DATA_PTR<output_t>() : NULL);
+            gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
+            beta != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL);
       )
 }

 template<typename T, typename U, typename V>
 void HostLayerNormGradient(
     const V* dout,
...
@@ -742,10 +738,12 @@ void HostLayerNormGradient(
     const int part_size = 16;
     const dim3 threads2(32,4,1);
     const dim3 blocks2((n2+threads2.x-1)/threads2.x,part_size,1);
     const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1);
     const int nshared2_b = threads2.x * threads2.y * sizeof(U);
     const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
-    at::Tensor part_grad_gamma = at::empty({part_size,n2}, input->options().dtype(
-        input->scalar_type() == at::ScalarType::Half ? at::ScalarType::Float : input->scalar_type()));
+    at::Tensor part_grad_gamma = at::empty(
+        {part_size,n2}, input->options().dtype(at::ScalarType::Float));
     at::Tensor part_grad_beta = at::empty_like(part_grad_gamma);
     cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>(
                     dout,
...
@@ -770,7 +768,8 @@ void HostLayerNormGradient(
     }

     // compute grad_input
     const uint64_t maxGridY = at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
     const dim3 blocks1(1, std::min((uint64_t)n1, maxGridY), 1);
     const dim3 threads1(32,4,1);
     int nshared =
...
@@ -788,6 +787,7 @@ void HostLayerNormGradient(
             grad_input);
 }

 void cuda_layer_norm_gradient(
     at::Tensor* dout,
     at::Tensor* mean,
...
@@ -808,22 +808,22 @@ void cuda_layer_norm_gradient(
     at::Tensor* grad_beta)
 {
     using namespace at;
-    DISPATCH_FLOAT_AND_HALF(input->scalar_type(), 0, "cuComputeGradInput",
-        using accscalar_t = at::acc_type<scalar_t_0, true>;
-        using output_t = at::Half;
+    DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(
+        input->scalar_type(), gamma->scalar_type(),
+        "cuda_layer_norm_gradient_kernel",
         HostLayerNormGradient(
-            dout->DATA_PTR<output_t>(),
-            mean->DATA_PTR<accscalar_t>(),
-            invvar->DATA_PTR<accscalar_t>(),
+            dout->DATA_PTR<scalar_t_out>(),
+            mean->DATA_PTR<float>(),
+            invvar->DATA_PTR<float>(),
             input,
             n1,n2,
             // TMJ pass NULL argument for gamma, beta, grad_gamma and grad_beta
             // if gamma Tensor is NULL on input.
-            gamma != NULL ? gamma->DATA_PTR<output_t>() : NULL,
-            gamma != NULL ? beta->DATA_PTR<output_t>() : NULL,
+            gamma != NULL ? gamma->DATA_PTR<scalar_t_out>() : NULL,
+            gamma != NULL ? beta->DATA_PTR<scalar_t_out>() : NULL,
             epsilon,
-            grad_input->DATA_PTR<scalar_t_0>(),
-            gamma != NULL ? grad_gamma->DATA_PTR<output_t>() : NULL,
-            gamma != NULL ? grad_beta->DATA_PTR<output_t>() : NULL);
+            grad_input->DATA_PTR<scalar_t_in>(),
+            gamma != NULL ? grad_gamma->DATA_PTR<scalar_t_out>() : NULL,
+            gamma != NULL ? grad_beta->DATA_PTR<scalar_t_out>() : NULL);
       )
 }
megatron/fused_kernels/scaled_masked_softmax.cpp

@@ -37,8 +37,9 @@ torch::Tensor fwd(
     torch::Tensor const& mask,
     float scale_factor) {
   AT_ASSERTM(input.dim() == 4, "expected 4D tensor");
-  AT_ASSERTM(input.scalar_type() == at::ScalarType::Half,
-             "Only HALF is supported");
+  AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
+             (input.scalar_type() == at::ScalarType::BFloat16),
+             "Only fp16 and bf16 are supported");
   AT_ASSERTM(mask.dim() == 4, "expected 4D tensor");

   return fwd_cuda(input, mask, scale_factor);
...
@@ -52,10 +53,12 @@ torch::Tensor bwd(
   AT_ASSERTM(output_grads.dim() == 4, "expected 3D tensor");
   AT_ASSERTM(softmax_results.dim() == 4, "expected 3D tensor");

-  AT_ASSERTM(output_grads.scalar_type() == at::ScalarType::Half,
-             "Only HALF is supported");
-  AT_ASSERTM(softmax_results.scalar_type() == at::ScalarType::Half,
-             "Only HALF is supported");
+  AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
+             (output_grads.scalar_type() == at::ScalarType::BFloat16),
+             "Only fp16 and bf16 are supported");
+  AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
+             (softmax_results.scalar_type() == at::ScalarType::BFloat16),
+             "Only fp16 and bf16 are supported");

   return bwd_cuda(output_grads, softmax_results, scale_factor);
 }
...
megatron/fused_kernels/scaled_masked_softmax.h

@@ -30,10 +30,16 @@ template <typename Datatype, int ELEMENTS_PER_LDG>
 __device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);

 template <>
-__device__ __inline__ void copy_vector<__half, 1>(__half *dst, const __half *src) { *dst = *src; }
+__device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; }

 template <>
-__device__ __inline__ void copy_vector<__half, 4>(__half *dst, const __half *src) { *((float2*) dst) = *((float2*) src); }
+__device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); }
+
+template <>
+__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst, const c10::Half *src) { *dst = *src; }
+
+template <>
+__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); }

 template <>
 __device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }
...
megatron/fused_kernels/scaled_masked_softmax_cuda.cu

@@ -22,6 +22,7 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
 #include "scaled_masked_softmax.h"
+#include "type_shim.h"

 namespace multihead_attn {
 namespace fused_softmax {
...
@@ -55,16 +56,20 @@ torch::Tensor fwd_cuda(
   void* mask_ptr = static_cast<void*>(mask.data_ptr());
   void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());

-  dispatch_scaled_masked_softmax_forward<half, half, float>(
-      reinterpret_cast<half*>(softmax_results_ptr),
-      reinterpret_cast<const half*>(input_ptr),
-      reinterpret_cast<const uint8_t*>(mask_ptr),
-      scale_factor,
-      query_seq_len,
-      key_seq_len,
-      batches,
-      attn_heads,
-      pad_batches);
+  DISPATCH_HALF_AND_BFLOAT(
+      input.scalar_type(),
+      "dispatch_scaled_masked_softmax_forward",
+      dispatch_scaled_masked_softmax_forward<scalar_t, scalar_t, float>(
+          reinterpret_cast<scalar_t*>(softmax_results_ptr),
+          reinterpret_cast<const scalar_t*>(input_ptr),
+          reinterpret_cast<const uint8_t*>(mask_ptr),
+          scale_factor,
+          query_seq_len,
+          key_seq_len,
+          batches,
+          attn_heads,
+          pad_batches);
+  );
   return softmax_results;
 }
...
@@ -85,15 +90,19 @@ torch::Tensor bwd_cuda(
   void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());

   //Softmax Grad
-  dispatch_scaled_masked_softmax_backward<half, half, float>(
-      reinterpret_cast<half*>(output_grads_ptr),
-      reinterpret_cast<half*>(output_grads_ptr),
-      reinterpret_cast<half const*>(softmax_results.data_ptr()),
-      scale_factor,
-      query_seq_len,
-      key_seq_len,
-      batches,
-      attn_heads);
+  DISPATCH_HALF_AND_BFLOAT(
+      output_grads_.scalar_type(),
+      "dispatch_scaled_masked_softmax_backward",
+      dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(
+          reinterpret_cast<scalar_t*>(output_grads_ptr),
+          reinterpret_cast<scalar_t*>(output_grads_ptr),
+          reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
+          scale_factor,
+          query_seq_len,
+          key_seq_len,
+          batches,
+          attn_heads);
+  );

   //backward pass is completely in-place
   return output_grads;
...
megatron/fused_kernels/scaled_upper_triang_masked_softmax.cpp

@@ -33,8 +33,9 @@ torch::Tensor bwd_cuda(
 torch::Tensor fwd(torch::Tensor const& input, float scale_factor) {
   AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
-  AT_ASSERTM(input.scalar_type() == at::ScalarType::Half,
-             "Only HALF is supported");
+  AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
+             (input.scalar_type() == at::ScalarType::BFloat16),
+             "Only fp16 and bf16 are supported");

   return fwd_cuda(input, scale_factor);
 }
...
@@ -47,10 +48,12 @@ torch::Tensor bwd(
   AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
   AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");

-  AT_ASSERTM(output_grads.scalar_type() == at::ScalarType::Half,
-             "Only HALF is supported");
-  AT_ASSERTM(softmax_results.scalar_type() == at::ScalarType::Half,
-             "Only HALF is supported");
+  AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
+             (output_grads.scalar_type() == at::ScalarType::BFloat16),
+             "Only fp16 and bf16 are supported");
+  AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
+             (softmax_results.scalar_type() == at::ScalarType::BFloat16),
+             "Only fp16 and bf16 are supported");

   return bwd_cuda(output_grads, softmax_results, scale_factor);
 }
...
@@ -61,7 +64,7 @@ torch::Tensor bwd(
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("forward",
         &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd,
         "Self Multihead Attention scaled, time masked softmax -- Forward.");
   m.def("backward",
         &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd,
...
megatron/fused_kernels/scaled_upper_triang_masked_softmax.h

@@ -21,7 +21,6 @@
 #include <cfloat>
 #include <limits>
 #include <stdint.h>
-#include <cuda_fp16.h>
 #include <c10/macros/Macros.h>

 namespace {
...
@@ -30,10 +29,16 @@ template <typename Datatype, int ELEMENTS_PER_LDG>
 __device__ __inline__ void copy_vector(Datatype *dst, const Datatype *src);

 template <>
-__device__ __inline__ void copy_vector<__half, 1>(__half *dst, const __half *src) { *dst = *src; }
+__device__ __inline__ void copy_vector<c10::BFloat16, 1>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *dst = *src; }
+
+template <>
+__device__ __inline__ void copy_vector<c10::BFloat16, 4>(c10::BFloat16 *dst, const c10::BFloat16 *src) { *((float2*) dst) = *((float2*) src); }
+
+template <>
+__device__ __inline__ void copy_vector<c10::Half, 1>(c10::Half *dst, const c10::Half *src) { *dst = *src; }

 template <>
-__device__ __inline__ void copy_vector<__half, 4>(__half *dst, const __half *src) { *((float2*) dst) = *((float2*) src); }
+__device__ __inline__ void copy_vector<c10::Half, 4>(c10::Half *dst, const c10::Half *src) { *((float2*) dst) = *((float2*) src); }

 template <>
 __device__ __inline__ void copy_vector<uint8_t, 1>(uint8_t *dst, const uint8_t *src) { *dst = *src; }
...
@@ -45,10 +50,16 @@ template <typename Datatype, int ELEMENTS_PER_LDG>
 __device__ __inline__ void copy_zero_vector(Datatype *dst);

 template <>
-__device__ __inline__ void copy_zero_vector<__half, 1>(__half *dst) { *dst = 0.0; }
+__device__ __inline__ void copy_zero_vector<c10::BFloat16, 1>(c10::BFloat16 *dst) { *dst = 0.0; }
+
+template <>
+__device__ __inline__ void copy_zero_vector<c10::BFloat16, 4>(c10::BFloat16 *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); }
+
+template <>
+__device__ __inline__ void copy_zero_vector<c10::Half, 1>(c10::Half *dst) { *dst = 0.0; }

 template <>
-__device__ __inline__ void copy_zero_vector<__half, 4>(__half *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); }
+__device__ __inline__ void copy_zero_vector<c10::Half, 4>(c10::Half *dst) { *((float2*) dst) = make_float2(0.0f, 0.0f); }

 int log2_ceil(int value) {
...
megatron/fused_kernels/scaled_upper_triang_masked_softmax_cuda.cu

@@ -22,6 +22,7 @@
 #include <ATen/cuda/CUDAContext.h>
 #include <torch/extension.h>
 #include "scaled_upper_triang_masked_softmax.h"
+#include "type_shim.h"

 namespace multihead_attn {
 namespace fused_softmax {
...
@@ -45,15 +46,20 @@ torch::Tensor fwd_cuda(
   void* input_ptr = static_cast<void*>(input.data_ptr());
   void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());

-  dispatch_scaled_upper_triang_masked_softmax_forward<half, half, float>(
-      reinterpret_cast<half*>(softmax_results_ptr),
-      reinterpret_cast<const half*>(input_ptr),
-      scale_factor,
-      seq_len,
-      seq_len,
-      attn_batches);
+  DISPATCH_HALF_AND_BFLOAT(
+      input.scalar_type(),
+      "dispatch_scaled_upper_triang_masked_softmax_forward",
+      dispatch_scaled_upper_triang_masked_softmax_forward<scalar_t, scalar_t, float>(
+          reinterpret_cast<scalar_t*>(softmax_results_ptr),
+          reinterpret_cast<const scalar_t*>(input_ptr),
+          scale_factor,
+          seq_len,
+          seq_len,
+          attn_batches);
+  );
   return softmax_results;
 }

 torch::Tensor bwd_cuda(
     torch::Tensor const& output_grads_,
...
@@ -71,14 +77,18 @@ torch::Tensor bwd_cuda(
   void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());

   //Softmax Grad
-  dispatch_scaled_upper_triang_masked_softmax_backward<half, half, float>(
-      reinterpret_cast<half*>(output_grads_ptr),
-      reinterpret_cast<half*>(output_grads_ptr),
-      reinterpret_cast<half const*>(softmax_results.data_ptr()),
-      scale_factor,
-      seq_len,
-      seq_len,
-      attn_batches);
+  DISPATCH_HALF_AND_BFLOAT(
+      output_grads_.scalar_type(),
+      "dispatch_scaled_upper_triang_masked_softmax_backward",
+      dispatch_scaled_upper_triang_masked_softmax_backward<scalar_t, scalar_t, float>(
+          reinterpret_cast<scalar_t*>(output_grads_ptr),
+          reinterpret_cast<scalar_t*>(output_grads_ptr),
+          reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
+          scale_factor,
+          seq_len,
+          seq_len,
+          attn_batches);
+  );

   //backward pass is completely in-place
   return output_grads;
...
megatron/fused_kernels/type_shim.h

@@ -14,214 +14,78 @@
  * limitations under the License.
  */

+/*This code is copied fron NVIDIA apex:
+ *     https://github.com/NVIDIA/apex
+ * with minor changes. */

 #include <ATen/ATen.h>
 #include "compat.h"

-// Forward/backward compatiblity hack around
-// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
-// pending more future-proof guidance from upstream.
-// struct TypeShim
-// {
-//   const at::Type& payload;
-//   TypeShim(const at::Type& type) : payload(type) {}
-//   // Enable trivial conversion to a const at::Type& for pre-3aeb78
-//   operator const at::Type&(){ return payload; };
-//   // Enable dispatch switch statements to take *this directly for post-3aeb78
-//   //operator at::ScalarType(){ return payload.; };
-// };

-#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
+#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...) \
   switch(TYPE) \
   { \
-    case at::ScalarType::Float: \
+    case at::ScalarType::Half: \
     { \
-      using scalar_t_##LEVEL = float; \
+      using scalar_t = at::Half; \
       __VA_ARGS__; \
       break; \
     } \
-    case at::ScalarType::Half: \
+    case at::ScalarType::BFloat16: \
    { \
-      using scalar_t_##LEVEL = at::Half; \
+      using scalar_t = at::BFloat16; \
       __VA_ARGS__; \
       break; \
     } \
     default: \
       AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
   }

+#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
+  switch(TYPEIN) \
+  { \
+    case at::ScalarType::Float: \
+    { \
+      using scalar_t_in = float; \
+      switch(TYPEOUT) \
+      { \
+        case at::ScalarType::Float: \
+        { \
+          using scalar_t_out = float; \
+          __VA_ARGS__; \
+          break; \
+        } \
+        case at::ScalarType::Half: \
+        { \
+          using scalar_t_out = at::Half; \
+          __VA_ARGS__; \
+          break; \
+        } \
+        case at::ScalarType::BFloat16: \
+        { \
+          using scalar_t_out = at::BFloat16; \
+          __VA_ARGS__; \
+          break; \
+        } \
+        default: \
+          AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'"); \
+      } \
+      break; \
+    } \
+    case at::ScalarType::Half: \
+    { \
+      using scalar_t_in = at::Half; \
+      using scalar_t_out = at::Half; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    case at::ScalarType::BFloat16: \
+    { \
+      using scalar_t_in = at::BFloat16; \
+      using scalar_t_out = at::BFloat16; \
+      __VA_ARGS__; \
+      break; \
+    } \
+    default: \
+      AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'"); \
+  }

-#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...) \
-  switch(TYPE) \
-  { \
-    case at::ScalarType::Float: \
-    { \
-      using scalar_t_##LEVEL = float; \
-      __VA_ARGS__; \
-      break; \
-    } \
-    case at::ScalarType::Half: \
-    { \
-      using scalar_t_##LEVEL = at::Half; \
-      __VA_ARGS__; \
-      break; \
-    } \
-    case at::ScalarType::Byte: \
-    { \
-      using scalar_t_##LEVEL = uint8_t; \
-      __VA_ARGS__; \
-      break; \
-    } \
-    default: \
-      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
-  }

-#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
-  switch(TYPE) \
-  { \
-    case at::ScalarType::Double: \
-    { \
-      using scalar_t_##LEVEL = double; \
-      __VA_ARGS__; \
-      break; \
-    } \
-    case at::ScalarType::Float: \
-    { \
-      using scalar_t_##LEVEL = float; \
-      __VA_ARGS__; \
-      break; \
-    } \
-    case at::ScalarType::Half: \
-    { \
-      using scalar_t_##LEVEL = at::Half; \
-      __VA_ARGS__; \
-      break; \
-    } \
-    default: \
-      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
-  }

-#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
-  switch(TYPE) \
-  { \
-    case at::ScalarType::Double: \
-    { \
-      using scalar_t_##LEVEL = double; \
-      __VA_ARGS__; \
-      break; \
-    } \
-    case at::ScalarType::Float: \
-    { \
-      using scalar_t_##LEVEL = float; \
-      __VA_ARGS__; \
-      break; \
-    } \
-    default: \
-      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
-  }

-template<typename T>
-__device__ __forceinline__ T reduce_block_into_lanes
-  (T *x,
-   T val,
-   int lanes=1,
-   bool share_result=false) // lanes is intended to be <= 32.
-{
-  int tid = threadIdx.x + threadIdx.y*blockDim.x;
-  int blockSize = blockDim.x*blockDim.y; // blockSize is intended to be a multiple of 32.
-
-  if(blockSize >= 64)
-  {
-    x[tid] = val;
-    __syncthreads();
-  }
-
-  #pragma unroll
-  for(int i = (blockSize >> 1); i >= 64; i >>= 1)
-  {
-    if(tid < i)
-      x[tid] = x[tid] + x[tid+i];
-    __syncthreads();
-  }
-
-  T final;
-
-  if(tid < 32)
-  {
-    if(blockSize >= 64)
-      final = x[tid] + x[tid+32];
-    else
-      final = val;
-    // __SYNCWARP();
-
-    #pragma unroll
-    for(int i = 16; i >= lanes; i >>= 1)
-      final = final + __shfl_down_sync(0xffffffff, final, i);
-  }
-
-  if(share_result)
-  {
-    if(tid < lanes)
-      x[tid] = final; // EpilogueOp
-    // Make sure the smem result is visible to all warps.
-    __syncthreads();
-  }
-
-  return final;
-}
-
-template<typename T>
-__device__ __forceinline__ T reduce_block_into_lanes_max_op
-  (T *x,
-   T val,
-   int lanes=1,
-   bool share_result=false) // lanes is intended to be <= 32.
-{
-  int tid = threadIdx.x + threadIdx.y*blockDim.x;
-  int blockSize = blockDim.x*blockDim.y; // blockSize is intended to be a multiple of 32.
-
-  if(blockSize >= 64)
-  {
-    x[tid] = val;
-    __syncthreads();
-  }
-
-  #pragma unroll
-  for(int i = (blockSize >> 1); i >= 64; i >>= 1)
-  {
-    if(tid < i)
-      x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid+i]));
-    __syncthreads();
-  }
-
-  T final;
-
-  if(tid < 32)
-  {
-    if(blockSize >= 64)
-      final = fmaxf(fabsf(x[tid]), fabsf(x[tid+32]));
-    else
-      final = val;
-    // __SYNCWARP();
-
-    #pragma unroll
-    for(int i = 16; i >= lanes; i >>= 1)
-      final = fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
-  }
-
-  if(share_result)
-  {
-    if(tid < lanes)
-      x[tid] = final; // EpilogueOp
-    // Make sure the smem result is visible to all warps.
-    __syncthreads();
-  }
-
-  return final;
-}
megatron/model/__init__.py

@@ -13,23 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-_LAYER_NORM = None
-
-def import_layernorm(fp32_residual_connection, bf16):
-    global _LAYER_NORM
-    if not _LAYER_NORM:
-        if bf16:
-            from torch.nn import LayerNorm
-        elif fp32_residual_connection:
-            from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm
-        else:
-            from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm
-        _LAYER_NORM = LayerNorm
-    return _LAYER_NORM
+from .fused_layer_norm import MixedFusedLayerNorm as LayerNorm

 from .distributed import *
 from .bert_model import (BertModel,
...
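With the runtime selector gone, call sites pick up the mixed-precision implementation through a plain import. A short usage sketch (hypothetical sizes; runnable only with the repo installed and the fused kernels built):

# Before: LayerNorm = import_layernorm(args.fp32_residual_connection, args.bf16)
# After: one import resolves straight to MixedFusedLayerNorm.
import torch
from megatron.model import LayerNorm

# bf16 weights and bf16 activations, as produced by Float16Module with --bf16.
layernorm = LayerNorm(1024, eps=1e-5).cuda().bfloat16()
x = torch.randn(8, 1024, device='cuda', dtype=torch.bfloat16)
y = layernorm(x)   # bf16 in, bf16 out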
megatron/model/bert_model.py

@@ -22,7 +22,7 @@ from megatron import mpu
 from megatron.model.enums import AttnMaskType
 from megatron.model.language_model import parallel_lm_logits
 from megatron.model.language_model import get_language_model
-from megatron.model import import_layernorm
+from megatron.model import LayerNorm
 from megatron.model.utils import openai_gelu, erf_gelu
 from megatron.model.utils import get_linear_layer
 from megatron.model.utils import init_method_normal
...
@@ -78,7 +78,6 @@ class BertLMHead(MegatronModule):
         self.parallel_output = parallel_output

         self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
-        LayerNorm = import_layernorm(args.fp32_residual_connection, args.bf16)
         self.layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
         self.gelu = torch.nn.functional.gelu
         if args.openai_gelu:
...
megatron/model/fused_layer_norm.py

@@ -15,29 +15,23 @@
 """This code is copied fron NVIDIA apex:
       https://github.com/NVIDIA/apex
-   with minor changes. """
+   with some changes. """

-import math
-import torch
 import numbers
+import torch
 from torch.nn.parameter import Parameter
 from torch.nn import init
-from torch.nn import functional as F
 import importlib

-global fused_layer_norm_cuda
-fused_layer_norm_cuda = None
 global fused_mix_prec_layer_norm_cuda
 fused_mix_prec_layer_norm_cuda = None


 class FusedLayerNormAffineFunction(torch.autograd.Function):

   @staticmethod
   def forward(ctx, input, weight, bias, normalized_shape, eps):

-    global fused_mix_prec_layer_norm_cuda
-    if fused_mix_prec_layer_norm_cuda is None:
-        fused_mix_prec_layer_norm_cuda = importlib.import_module(
-          "fused_mix_prec_layer_norm_cuda")
-
     ctx.normalized_shape = normalized_shape
     ctx.eps = eps
     input_ = input.contiguous()
...
@@ -46,134 +40,51 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
     output, mean, invvar = fused_mix_prec_layer_norm_cuda.forward_affine(
         input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
     ctx.save_for_backward(input_, weight_, bias_, mean, invvar)

     return output


   @staticmethod
   def backward(ctx, grad_output):

     input_, weight_, bias_, mean, invvar = ctx.saved_tensors
     grad_input = grad_weight = grad_bias = None
-    grad_input, grad_weight, grad_bias = fused_mix_prec_layer_norm_cuda.backward_affine(
+    grad_input, grad_weight, grad_bias \
+      = fused_mix_prec_layer_norm_cuda.backward_affine(
         grad_output.contiguous(), mean, invvar,
         input_, ctx.normalized_shape,
         weight_, bias_, ctx.eps)

     return grad_input, grad_weight, grad_bias, None, None


-class FusedLayerNormFunction(torch.autograd.Function):
-
-  @staticmethod
-  def forward(ctx, input, normalized_shape, eps):
-
-    global fused_layer_norm_cuda
-    if fused_layer_norm_cuda is None:
-        fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
-    ctx.normalized_shape = normalized_shape
-    ctx.eps = eps
-    input_ = input.contiguous()
-    output, mean, invvar = fused_layer_norm_cuda.forward(
-        input_, ctx.normalized_shape, ctx.eps)
-    ctx.save_for_backward(input_, mean, invvar)
-
-    return output
-
-
-  @staticmethod
-  def backward(ctx, grad_output):
-
-    input_, mean, invvar = ctx.saved_tensors
-    grad_input = None
-    grad_input = fused_layer_norm_cuda.backward(
-        grad_output.contiguous(), mean, invvar,
-        input_, ctx.normalized_shape, ctx.eps)
-
-    return grad_input, None, None
-
-
-def fused_layer_norm_affine(input, normalized_shape, weight, bias, eps=1e-6):
-    return FusedLayerNormAffineFunction.apply(
-        input, weight, bias, normalized_shape, eps)
-
-
-def fused_layer_norm(input, normalized_shape, eps=1e-6):
-    return FusedLayerNormFunction.apply(input, normalized_shape, eps)
-
-
 class MixedFusedLayerNorm(torch.nn.Module):

-  r"""Applies Layer Normalization over a mini-batch of inputs as described in
-  the paper `Layer Normalization`_ .
...
-  .. _`Layer Normalization`: https://arxiv.org/abs/1607.06450
-  """
-
-  def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True):
+  def __init__(self, normalized_shape, eps=1e-5):
     super(MixedFusedLayerNorm, self).__init__()

-    global fused_layer_norm_cuda
-    fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda")
     global fused_mix_prec_layer_norm_cuda
     fused_mix_prec_layer_norm_cuda = importlib.import_module(
       "fused_mix_prec_layer_norm_cuda")

     if isinstance(normalized_shape, numbers.Integral):
         normalized_shape = (normalized_shape,)
     self.normalized_shape = torch.Size(normalized_shape)
     self.eps = eps
-    self.elementwise_affine = elementwise_affine
-    if self.elementwise_affine:
-        self.weight = Parameter(torch.Tensor(*normalized_shape))
-        self.bias = Parameter(torch.Tensor(*normalized_shape))
-    else:
-        self.register_parameter('weight', None)
-        self.register_parameter('bias', None)
+    self.weight = Parameter(torch.Tensor(*normalized_shape))
+    self.bias = Parameter(torch.Tensor(*normalized_shape))
     self.reset_parameters()


   def reset_parameters(self):
-    if self.elementwise_affine:
-        init.ones_(self.weight)
-        init.zeros_(self.bias)
+    init.ones_(self.weight)
+    init.zeros_(self.bias)


   def forward(self, input):
-    if not input.is_cuda:
-        return F.layer_norm(
-            input, self.normalized_shape, self.weight, self.bias, self.eps)
-    if self.elementwise_affine:
-        return FusedLayerNormAffineFunction.apply(
-            input, self.weight, self.bias, self.normalized_shape, self.eps)
-    else:
-        return FusedLayerNormFunction.apply(input, self.normalized_shape, self.eps)
-
-
-  def extra_repr(self):
-    return '{normalized_shape}, eps={eps}, ' \
-        'elementwise_affine={elementwise_affine}'.format(**self.__dict__)
+    return FusedLayerNormAffineFunction.apply(
+      input, self.weight, self.bias, self.normalized_shape, self.eps)
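The trimmed module above always routes through the affine CUDA kernel, whose forward (see layer_norm_cuda.cpp earlier in this commit) computes the statistics in fp32 and takes the output dtype from the affine weight. A pure-PyTorch sketch of that contract, for illustration only (not the fused kernel):

import torch

def mixed_prec_layer_norm(x, weight, bias, eps=1e-5):
    # Statistics in fp32 regardless of the input dtype.
    x32 = x.float()
    mean = x32.mean(dim=-1, keepdim=True)
    var = x32.var(dim=-1, unbiased=False, keepdim=True)
    y = (x32 - mean) / torch.sqrt(var + eps)
    # Output dtype follows the weight, mirroring gamma.scalar_type() in the kernel.
    return (y * weight.float() + bias.float()).to(weight.dtype)

x = torch.randn(4, 1024)                      # e.g. an fp32 residual stream
w = torch.ones(1024, dtype=torch.bfloat16)
b = torch.zeros(1024, dtype=torch.bfloat16)
print(mixed_prec_layer_norm(x, w, b).dtype)   # torch.bfloat16

This dtype behavior is also why the explicit .bfloat16() casts removed from transformer.py and the layer-norm fp32 workaround removed from training.py below are no longer needed.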
megatron/model/fused_softmax.py

@@ -96,6 +96,7 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
     def __init__(
         self,
         input_in_fp16,
+        input_in_bf16,
         attn_mask_type,
         scaled_masked_softmax_fusion,
         mask_func,
...
@@ -104,6 +105,10 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
     ):
         super(FusedScaleMaskSoftmax, self).__init__()
         self.input_in_fp16 = input_in_fp16
+        self.input_in_bf16 = input_in_bf16
+        assert not (self.input_in_fp16 and self.input_in_bf16), \
+            'both fp16 and bf16 flags cannot be active at the same time.'
+        self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
         self.attn_mask_type = attn_mask_type
         self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion
         self.mask_func = mask_func
...
@@ -128,8 +133,8 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
             query_seq_len % 4 == 0 and attn_batch_size % 4 == 0

         # invoke custom kernel
-        if self.input_in_fp16 and mask is not None and \
+        if self.input_in_float16 and mask is not None and \
             custom_kernel_constraint and self.scaled_masked_softmax_fusion:
             scale = self.scale if self.scale is not None else 1.0
             if self.attn_mask_type == AttnMaskType.causal:
...
@@ -142,7 +147,7 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
                 assert self.attn_mask_type == AttnMaskType.padding
                 probs = ScaledMaskedSoftmax.apply(input, mask, scale)
         else:
-            if self.input_in_fp16 and self.softmax_in_fp32:
+            if self.input_in_float16 and self.softmax_in_fp32:
                 input = input.float()

             if self.scale is not None:
...
@@ -150,7 +155,10 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
             mask_output = self.mask_func(input, mask) if mask is not None else input
             probs = torch.nn.Softmax(dim=-1)(mask_output)

-            if self.input_in_fp16 and self.softmax_in_fp32:
-                probs = probs.half()
+            if self.input_in_float16 and self.softmax_in_fp32:
+                if self.input_in_fp16:
+                    probs = probs.half()
+                else:
+                    probs = probs.bfloat16()

         return probs
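The fallback branch above now distinguishes the two 16-bit formats when casting the fp32 softmax result back. A minimal sketch of that dtype round-trip in plain PyTorch (illustrative, not the FusedScaleMaskSoftmax module itself):

import torch

def softmax_fallback(scores, input_in_fp16, input_in_bf16, softmax_in_fp32=True):
    input_in_float16 = input_in_fp16 or input_in_bf16
    x = scores.float() if (input_in_float16 and softmax_in_fp32) else scores
    probs = torch.nn.Softmax(dim=-1)(x)
    if input_in_float16 and softmax_in_fp32:
        probs = probs.half() if input_in_fp16 else probs.bfloat16()
    return probs

scores = torch.randn(2, 8, 16, 16, dtype=torch.bfloat16)
print(softmax_fallback(scores, input_in_fp16=False, input_in_bf16=True).dtype)
# torch.bfloat16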
megatron/model/transformer.py

@@ -22,7 +22,7 @@ from megatron import get_args
 from megatron import mpu
 from .module import MegatronModule
 from megatron.model.enums import AttnMaskType, LayerType, AttnType
-from megatron.model import import_layernorm
+from megatron.model import LayerNorm
 from megatron.model.fused_softmax import FusedScaleMaskSoftmax
 from megatron.model.fused_bias_gelu import bias_gelu_impl
 from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
...
@@ -116,6 +116,7 @@ class ParallelAttention(MegatronModule):
         super(ParallelAttention, self).__init__()
         args = get_args()
         self.fp16 = args.fp16
+        self.bf16 = args.bf16

         self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
         self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
...
@@ -164,7 +165,7 @@ class ParallelAttention(MegatronModule):
             self.norm_factor *= coeff

         self.scale_mask_softmax = FusedScaleMaskSoftmax(
             self.fp16,
+            self.bf16,
             self.attn_mask_type,
             args.masked_softmax_fusion,
             attention_mask_func,
...
@@ -401,7 +402,6 @@ class ParallelTransformerLayer(MegatronModule):
         self.fp32_residual_connection = args.fp32_residual_connection

         # Layernorm on the input data.
-        LayerNorm = import_layernorm(self.fp32_residual_connection, self.bf16)
         self.input_layernorm = LayerNorm(
             args.hidden_size,
             eps=args.layernorm_epsilon)
...
@@ -443,8 +443,6 @@ class ParallelTransformerLayer(MegatronModule):
         # Layer norm at the beginning of the transformer layer.
         layernorm_output = self.input_layernorm(hidden_states)
-        if self.bf16 and self.fp32_residual_connection:
-            layernorm_output = layernorm_output.bfloat16()

         # Self attention.
         attention_output, attention_bias = \
             self.self_attention(layernorm_output,
...
@@ -483,8 +481,6 @@ class ParallelTransformerLayer(MegatronModule):
         # Layer norm post the self attention.
         layernorm_output = self.post_attention_layernorm(layernorm_input)
-        if self.bf16 and self.fp32_residual_connection:
-            layernorm_output = layernorm_output.bfloat16()

         if self.layer_type == LayerType.decoder:
             attention_output, attention_bias = \
...
@@ -507,8 +503,6 @@ class ParallelTransformerLayer(MegatronModule):
         # Layer norm post the decoder attention
         layernorm_output = self.post_inter_attention_layernorm(layernorm_input)
-        if self.bf16 and self.fp32_residual_connection:
-            layernorm_output = layernorm_output.bfloat16()

         # MLP.
         mlp_output, mlp_bias = self.mlp(layernorm_output)
...
@@ -588,8 +582,6 @@ class ParallelTransformer(MegatronModule):
         if mpu.is_pipeline_last_stage():
             # Final layer norm before output.
-            LayerNorm = import_layernorm(self.fp32_residual_connection, self.bf16)
             self.final_layernorm = LayerNorm(
                 args.hidden_size,
                 eps=args.layernorm_epsilon)
...
@@ -676,8 +668,6 @@ class ParallelTransformer(MegatronModule):
             # Reverting data format change [s b h] --> [b s h].
             hidden_states = hidden_states.transpose(0, 1).contiguous()
             output = self.final_layernorm(hidden_states)
-            if self.bf16 and self.fp32_residual_connection:
-                output = output.bfloat16()
         else:
             output = hidden_states
             if get_key_value:
...
megatron/optimizer/__init__.py

@@ -17,7 +17,7 @@ from apex.optimizers import FusedAdam as Adam
 from apex.optimizers import FusedSGD as SGD

 from megatron import get_args
-from megatron.model import import_layernorm
+from megatron.model import LayerNorm

 from .grad_scaler import ConstantGradScaler, DynamicGradScaler
 from .optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer
...
@@ -27,8 +27,6 @@ def _get_params_for_weight_decay_optimization(modules):
     """Divide params into with-weight-decay and without-weight-decay groups.
     Layernorms and baises will have no weight decay but the rest will.
     """
-    args = get_args()
-    LayerNorm = import_layernorm(args.fp32_residual_connection, args.bf16)

     weight_decay_params = {'params': []}
     no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
...
megatron/training.py

@@ -224,16 +224,6 @@ def get_model(model_provider_func):
     # Fp16 conversion.
     if args.fp16 or args.bf16:
         model = [Float16Module(model_module, args) for model_module in model]

-    # For now, the layer norm does not support input float32 and outut bf16.
-    # For this, we move layernorm parameters to fp32 and cast output of the
-    # layernorm operation back to bf16.
-    if args.bf16 and args.fp32_residual_connection:
-        from megatron.model import import_layernorm
-        LayerNorm = import_layernorm(args.fp32_residual_connection, args.bf16)
-        for model_ in model:
-            for module_ in model_.modules():
-                if isinstance(module_, LayerNorm):
-                    module_.float()

     if args.DDP_impl == 'torch':
         i = torch.cuda.current_device()
...