OpenDAS / TransformerEngine / Commits / 1f4c3979

Unverified commit 1f4c3979, authored Oct 11, 2023 by Kirthi Shankar Sivamani; committed by GitHub, Oct 11, 2023.

    Remove TF support (#467)

    Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

Parent: 2574a1ca
Changes: 30. Showing 10 changed files with 0 additions and 4724 deletions (+0 -4724):

    transformer_engine/tensorflow/__init__.py              +0  -18
    transformer_engine/tensorflow/constants.py             +0  -28
    transformer_engine/tensorflow/csrc/extensions.cu       +0  -1169
    transformer_engine/tensorflow/csrc/get_stream_op.cpp   +0  -34
    transformer_engine/tensorflow/fp8.py                   +0  -266
    transformer_engine/tensorflow/jit.py                   +0  -102
    transformer_engine/tensorflow/module.py                +0  -2027
    transformer_engine/tensorflow/softmax.py               +0  -197
    transformer_engine/tensorflow/transformer.py           +0  -856
    transformer_engine/tensorflow/utils.py                 +0  -27
transformer_engine/tensorflow/__init__.py  deleted  100644 → 0

# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Transformer Engine bindings for Tensorflow"""
from transformer_engine.common.recipe import DelayedScaling
from transformer_engine.common.recipe import Format
from .constants import TE_DType
from .fp8 import fp8_autocast
from .module import Dense
from .module import LayerNorm
from .module import LayerNormDense
from .module import LayerNormMLP
from .module import get_stream_id
from .transformer import MultiHeadAttention
from .transformer import TransformerLayer
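For context on what this removal takes away, a minimal usage sketch of the exports above. This is not from the diff: it assumes the package builds and imports as transformer_engine.tensorflow, that Dense takes a units argument, and that a GPU with FP8 support is available; the recipe fields are illustrative.

# Minimal sketch of the removed TF API (assumptions noted above).
import tensorflow as tf
import transformer_engine.tensorflow as te
from transformer_engine.common.recipe import DelayedScaling, Format

layer = te.Dense(1024)                               # hypothetical ctor signature
recipe = DelayedScaling(margin=0, fp8_format=Format.E4M3)
x = tf.random.uniform((16, 1024))                    # both dims divisible by 16
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    y = layer(x)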
transformer_engine/tensorflow/constants.py  deleted  100644 → 0

# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Enums for e2e transformer"""
import tensorflow as tf
import transformer_engine_tensorflow as tex

"""
This is a map: tf.dtype -> int
Used for passing dtypes into cuda
extension. Has one to one mapping
with enum in transformer_engine.h
"""
TE_DType = {
    tf.int8: tex.DType.kByte,
    tf.int32: tex.DType.kInt32,
    tf.float32: tex.DType.kFloat32,
    tf.half: tex.DType.kFloat16,
    tf.bfloat16: tex.DType.kBFloat16,
}

AttnMaskTypes = ("causal", "padding")

AttnTypes = ("self", "cross")

LayerTypes = ("encoder", "decoder")
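As a sanity check of the mapping above, a small sketch; it assumes the built extension module is importable and is otherwise illustrative only.

# Sketch: translate a TF dtype to its NVTE enum before calling into the CUDA
# extension (TE_DType and tex as defined above; requires the built extension).
import tensorflow as tf
import transformer_engine_tensorflow as tex
from transformer_engine.tensorflow.constants import TE_DType

assert TE_DType[tf.bfloat16] == tex.DType.kBFloat16
assert TE_DType[tf.float32] == tex.DType.kFloat32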
transformer_engine/tensorflow/csrc/extensions.cu  deleted  100644 → 0

/*************************************************************************
 * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <pybind11/pybind11.h>

#include <string>

#include "common/include/transformer_engine/activation.h"
#include "common/include/transformer_engine/cast.h"
#include "common/include/transformer_engine/gemm.h"
#include "common/include/transformer_engine/layer_norm.h"
#include "common/include/transformer_engine/softmax.h"
#include "common/include/transformer_engine/transformer_engine.h"
#include "common/include/transformer_engine/transpose.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/c/tf_status_internal.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/common_runtime/gpu/gpu_process_state.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/python/eager/pywrap_tensor.h"
#include "tensorflow/python/eager/pywrap_tfe.h"
#include "tensorflow/python/lib/core/pybind11_lib.h"

namespace transformer_engine {

// Used as named indices on the `scale`, `scale_inv`,
// and `amax` tensors in the `FP8TensorMeta` class.
enum FP8FwdTensors {
  GEMM1_INPUT = 0,
  GEMM1_WEIGHT = 1,
  GEMM2_INPUT = 2,
  GEMM2_WEIGHT = 3
};

// Used as named indices on the `scale`, `scale_inv`,
// and `amax` tensors in the `FP8TensorMeta` class.
enum FP8BwdTensors {
  GRAD_OUTPUT1 = 0,
  GRAD_OUTPUT2 = 1
};

}  // namespace transformer_engine
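These enum values act as element offsets into the flat per-layer FP8 metadata tensors (scale, scale_inv, amax), one slot per GEMM operand. A small sketch of that indexing, using the Python bindings registered later in this file (illustrative; real meta tensors live on the GPU):

# Sketch: FP8FwdTensors values index into flat float32 meta tensors.
import tensorflow as tf
import transformer_engine_tensorflow as tex

scale = tf.ones((4,))                          # GEMM1/GEMM2 x input/weight
offset = int(tex.FP8FwdTensors.GEMM1_WEIGHT)   # == 1
weight_scale = scale[offset]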
namespace {

void CheckTensorIsOnGPU(TFE_TensorHandle* tensor, TF_Status* status) {
  const char* device_type = TFE_TensorHandleDeviceType(tensor, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  CHECK_EQ(std::string(device_type), std::string("GPU"))
      << "Tensor must be on the GPU, but got device_type=" << device_type;
}

std::vector<size_t> TensorShapeAsVector(TFE_TensorHandle* tensor, TF_Status* status) {
  std::vector<size_t> shape(TFE_TensorHandleNumDims(tensor, status));
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  for (int i = 0; i < shape.size(); ++i) {
    shape[i] = TFE_TensorHandleDim(tensor, i, status);
    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  }
  return shape;
}

transformer_engine::DType GetNVTEDataType(TF_DataType t) {
  switch (t) {
    case TF_HALF:
      return transformer_engine::DType::kFloat16;
    case TF_FLOAT:
      return transformer_engine::DType::kFloat32;
    case TF_BFLOAT16:
      return transformer_engine::DType::kBFloat16;
    case TF_BOOL:
    case TF_INT8:
      return transformer_engine::DType::kByte;
    case TF_INT32:
      return transformer_engine::DType::kInt32;
    default:
      CHECK(false) << "TF dtype is not supported: " << t;
  }
}

TF_DataType GetTFDataType(transformer_engine::DType t) {
  switch (t) {
    case transformer_engine::DType::kFloat16:
      return TF_HALF;
    case transformer_engine::DType::kFloat32:
      return TF_FLOAT;
    case transformer_engine::DType::kBFloat16:
      return TF_BFLOAT16;
    case transformer_engine::DType::kByte:
    case transformer_engine::DType::kFloat8E4M3:
    case transformer_engine::DType::kFloat8E5M2:
      return TF_INT8;
    case transformer_engine::DType::kInt32:
      return TF_INT32;
    default:
      CHECK(false) << "NVTE dtype is not supported: " << static_cast<int>(t);
  }
}

void* TFE_TensorHandleDevicePointerNoSync(TFE_TensorHandle* h, TF_Status* status) {
  if (h == nullptr) {
    status->status = tensorflow::errors::InvalidArgument("Invalid handle");
    return nullptr;
  }
  tensorflow::ImmediateExecutionTensorHandle* unwrapped_handle = tensorflow::unwrap(h);
  // TODO(b/175427838): It would be nice to be able to use tensorflow::isa here.
  if (tensorflow::CustomDeviceTensorHandle::classof(unwrapped_handle)) {
    return tensorflow::down_cast<tensorflow::CustomDeviceTensorHandle*>(unwrapped_handle)
        ->DevicePointer();
  }
  // TODO(b/175427838): It would be nice to be able to use tensorflow::isa here.
  if (!tensorflow::TensorHandle::classof(unwrapped_handle)) {
    status->status = tensorflow::errors::InvalidArgument("Invalid handle");
    return nullptr;
  }
  tensorflow::TensorHandle* handle = tensorflow::TensorHandleFromInterface(unwrapped_handle);
  if (handle->Type() != tensorflow::TensorHandle::LOCAL) {
    status->status = tensorflow::errors::InvalidArgument(
        "TFE_TensorHandleDevicePointer may not be called on a ", handle->TypeString(),
        " tensor handle.");
    return nullptr;
  }
  const tensorflow::Tensor* tensor;
  status->status = handle->Tensor(&tensor);
  if (!status->status.ok()) {
    return nullptr;
  }
  return const_cast<void*>(static_cast<const void*>(tensor->tensor_data().data()));
}

// We assume the dptr is float when applying the offset. The offset is only
// meaningful for the amax/scale/scale_inv tensors.
void* GetDevicePtr(const pybind11::handle& handle, int offset = 0) {
  if (offset == -1) return nullptr;
  CHECK(EagerTensor_CheckExact(handle.ptr())) << "EagerTensor required!";
  auto in_eager = EagerTensor_Handle(handle.ptr());
  auto status = TF_NewStatus();
  CheckTensorIsOnGPU(in_eager, status);
  void* in_dptr = nullptr;
  if (in_eager) {
    in_dptr = TFE_TensorHandleDevicePointerNoSync(in_eager, status);
    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  }
  TF_DeleteStatus(status);
  return reinterpret_cast<float*>(in_dptr) + offset;
}

std::vector<size_t> GetShape(const pybind11::handle& handle) {
  TFE_TensorHandle* in_eager = EagerTensor_Handle(handle.ptr());
  TF_Status* status = TF_NewStatus();
  std::vector<size_t> shape = TensorShapeAsVector(in_eager, status);
  TF_DeleteStatus(status);
  return shape;
}

transformer_engine::DType GetDataType(const pybind11::handle& handle) {
  TFE_TensorHandle* in_eager = EagerTensor_Handle(handle.ptr());
  auto tf_itype = TFE_TensorHandleDataType(in_eager);
  return GetNVTEDataType(tf_itype);
}

transformer_engine::TensorWrapper MakeNVTETensor(
    void* data_ptr, const std::vector<size_t>& shape, const transformer_engine::DType type,
    void* amax_ptr = nullptr, void* scale_ptr = nullptr, void* scale_inv_ptr = nullptr) {
  return transformer_engine::TensorWrapper(data_ptr, shape, type,
                                           reinterpret_cast<float*>(amax_ptr),
                                           reinterpret_cast<float*>(scale_ptr),
                                           reinterpret_cast<float*>(scale_inv_ptr));
}

tensorflow::Allocator* GetAllocator() {
  static tensorflow::Allocator* allocator = nullptr;
  if (allocator == nullptr) {
    tensorflow::GPUOptions gpu_options;
    tsl::TfDeviceId device_id(0);
    allocator = tensorflow::GPUProcessState::singleton()->GetGPUAllocator(
        gpu_options, device_id, /*total_bytes=*/1, /*peer_gpu_ids=*/{});
  }
  return allocator;
}

TFE_Context* GetContext(TF_Status* status) {
  // Cache TF context.
  static TFE_Context* context = nullptr;
  if (context == nullptr) {
    TFE_ContextOptions* opts = TFE_NewContextOptions();
    // Current TF-TE only supports a single GPU. Here we need to manually set
    // the GPU number to 1 in case of the multi-GPU environment. Otherwise, the
    // TF will still traverse all the valid GPUs (to get stream priority ranges)
    // and eventually cudaSetDevice to the last one (This logic is defined in
    // BaseGPUDeviceFactory::CreateDevices). This would cause the other pybind
    // functions to be dispatched onto other GPUs, leading to bad results.
    auto* device_count = opts->session_options.options.config.mutable_device_count();
    device_count->insert({"GPU", 1});
    context = TFE_NewContext(opts, status);
  }
  return context;
}

void Deallocator(void* data, size_t unused, void* tensor_handle) {
  GetAllocator()->DeallocateRaw(data);
}

void* AllocateSpace(const std::vector<size_t>& shape, transformer_engine::DType te_dtype,
                    cudaStream_t stream = 0, bool init_to_zeros = false) {
  auto dtype = GetTFDataType(te_dtype);
  // Allocate GPU memory.
  size_t num_bytes = TF_DataTypeSize(dtype);
  for (int i = 0; i < shape.size(); ++i) num_bytes *= shape[i];
  void* data = GetAllocator()->AllocateRaw(tensorflow::Allocator::kAllocatorAlignment, num_bytes);
  if (init_to_zeros) {
    CHECK_EQ(cudaMemsetAsync(data, 0, num_bytes, stream), cudaSuccess);
  }
  return data;
}

TFE_TensorHandle* CreateTensor(void* data, const std::vector<size_t>& shape,
                               transformer_engine::DType te_dtype) {
  auto dtype = GetTFDataType(te_dtype);
  size_t num_bytes = TF_DataTypeSize(dtype);
  for (int i = 0; i < shape.size(); ++i) num_bytes *= shape[i];

  TF_Status* status = TF_NewStatus();
  TFE_Context* ctx = GetContext(status);

  // Get first GPU device name.
  TF_DeviceList* devices = TFE_ContextListDevices(ctx, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  int num_devices = TF_DeviceListCount(devices);
  const char* device_name = nullptr;
  for (int i = 0; i < num_devices; ++i) {
    const char* name = TF_DeviceListName(devices, i, status);
    CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
    if (std::string(name).find("GPU") != std::string::npos) {
      device_name = name;
      break;
    }
  }
  CHECK_NE(device_name, nullptr) << "No GPU device found.";

  std::vector<int64_t> shape64(shape.size());
  std::transform(shape.cbegin(), shape.cend(), shape64.begin(),
                 [](const size_t& v) { return static_cast<int64_t>(v); });

  TFE_TensorHandle* tensor = TFE_NewTensorHandleFromDeviceMemory(
      ctx, device_name, dtype, shape64.data(), shape64.size(), data, num_bytes, &Deallocator,
      nullptr, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_DeleteStatus(status);
  return tensor;
}
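CreateTensor hands raw allocator memory to TFE_NewTensorHandleFromDeviceMemory with Deallocator as its destructor, so outputs from the bindings below behave like ordinary eager tensors that return their buffer to TF's GPU allocator when the last Python reference dies. A sketch of that lifetime, assuming the built extension and a GPU; the non-FP8 te_gelu path is used because it ignores the meta-tensor handles:

# Sketch (illustrative; -1 offset means "no FP8 metadata", see GetDevicePtr).
import tensorflow as tf
import transformer_engine_tensorflow as tex

with tf.device("/GPU:0"):
    x = tf.random.uniform((32, 64))
    y = tex.te_gelu(x, None, tex.DType.kFloat32, None, None, -1, 0)
del y   # buffer handed back to the TF GPU allocator via Deallocator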
void dispatch_cast_transpose_fusion(
    void* input,  // i
    const std::vector<size_t>& input_shape, const transformer_engine::DType input_type,
    void* scale,  // i
    const std::vector<size_t>& scale_shape, const transformer_engine::DType scale_type,
    void* output_cast,  // o
    const std::vector<size_t>& output_cast_shape, const transformer_engine::DType output_cast_type,
    void* output_transpose,  // o
    const std::vector<size_t>& output_transpose_shape,
    const transformer_engine::DType output_transpose_type,
    void* amax,  // o
    const std::vector<size_t>& amax_shape, const transformer_engine::DType amax_type,
    void* scale_inv,  // o
    const std::vector<size_t>& scale_inv_shape, const transformer_engine::DType scale_inv_type,
    cudaStream_t stream) {
  auto input_cu = MakeNVTETensor(input, input_shape, input_type);
  auto output_cast_cu =
      MakeNVTETensor(output_cast, output_cast_shape, output_cast_type, amax, scale, scale_inv);
  auto output_transpose_cu = MakeNVTETensor(output_transpose, output_transpose_shape,
                                            output_transpose_type, amax, scale, scale_inv);
  nvte_cast_transpose(input_cu.data(), output_cast_cu.data(), output_transpose_cu.data(), stream);
}

void dispatch_transpose(
    void* input,  // i
    const std::vector<size_t>& input_shape, const transformer_engine::DType input_type,
    void* output,  // o
    const std::vector<size_t>& output_shape, const transformer_engine::DType output_type,
    cudaStream_t stream) {
  auto input_cu = MakeNVTETensor(input, input_shape, input_type);
  auto output_cu = MakeNVTETensor(output, output_shape, output_type);
  nvte_transpose(input_cu.data(), output_cu.data(), stream);
}

void dispatch_bgrad_cast_transpose_fusion(
    void* input,  // i
    const std::vector<size_t>& input_shape, const transformer_engine::DType input_type,
    void* scale,  // i
    const std::vector<size_t>& scale_shape, const transformer_engine::DType scale_type,
    void* cast_output,  // o
    const std::vector<size_t>& cast_output_shape, const transformer_engine::DType cast_output_type,
    void* transposed_output,  // o
    const std::vector<size_t>& transposed_output_shape,
    const transformer_engine::DType transposed_output_type,
    void* amax,  // o
    const std::vector<size_t>& amax_shape, const transformer_engine::DType amax_type,
    void* dbias,  // o
    const std::vector<size_t>& dbias_shape, const transformer_engine::DType dbias_type,
    void* scale_inv,  // o
    const std::vector<size_t>& scale_inv_shape, const transformer_engine::DType scale_inv_type,
    cudaStream_t stream) {
  auto input_cu = MakeNVTETensor(input, input_shape, input_type);
  auto cast_output_cu =
      MakeNVTETensor(cast_output, cast_output_shape, cast_output_type, amax, scale, scale_inv);
  auto transposed_output_cu = MakeNVTETensor(transposed_output, transposed_output_shape,
                                             transposed_output_type, amax, scale, scale_inv);
  auto dbias_cu = MakeNVTETensor(dbias, dbias_shape, dbias_type);

  transformer_engine::TensorWrapper workspace;
  nvte_cast_transpose_dbias(input_cu.data(), cast_output_cu.data(), transposed_output_cu.data(),
                            dbias_cu.data(), workspace.data(), stream);

  // Fill workspace
  auto w_s = workspace.shape();
  std::vector<size_t> w_shape_vec{w_s.data, w_s.data + w_s.ndim};
  void* workspace_ptr = AllocateSpace(w_shape_vec, workspace.dtype());
  workspace = MakeNVTETensor(workspace_ptr, w_shape_vec, workspace.dtype());

  nvte_cast_transpose_dbias(input_cu.data(), cast_output_cu.data(), transposed_output_cu.data(),
                            dbias_cu.data(), workspace.data(), stream);
}

void dispatch_layernorm(
    void* input,  // i
    const std::vector<size_t>& input_shape, const transformer_engine::DType input_type,
    void* gamma,  // i
    const std::vector<size_t>& gamma_shape, const transformer_engine::DType gamma_type,
    void* beta,  // i
    const std::vector<size_t>& beta_shape, const transformer_engine::DType beta_type,
    void* scale,  // i
    const std::vector<size_t>& scale_shape, const transformer_engine::DType scale_type,
    const float epsilon,  // i
    void* z,  // o
    const std::vector<size_t>& z_shape, const transformer_engine::DType z_type,
    void* mu,  // o
    const std::vector<size_t>& mu_shape, const transformer_engine::DType mu_type,
    void* rsigma,  // o
    const std::vector<size_t>& rsigma_shape, const transformer_engine::DType rsigma_type,
    void* amax,  // o
    const std::vector<size_t>& amax_shape, const transformer_engine::DType amax_type,
    void* scale_inv,  // o
    const std::vector<size_t>& scale_inv_shape, const transformer_engine::DType scale_inv_type,
    const int multiProcessorCount, cudaStream_t stream) {
  auto input_cu = MakeNVTETensor(input, input_shape, input_type);
  auto gamma_cu = MakeNVTETensor(gamma, gamma_shape, gamma_type);
  auto beta_cu = MakeNVTETensor(beta, beta_shape, beta_type);
  auto z_cu = MakeNVTETensor(z, z_shape, z_type, amax, scale, scale_inv);
  auto mu_cu = MakeNVTETensor(mu, mu_shape, mu_type);
  auto rsigma_cu = MakeNVTETensor(rsigma, rsigma_shape, rsigma_type);

  transformer_engine::TensorWrapper workspace, barrier;

  // This call populates workspace and barrier tensors with the required config
  nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), epsilon, z_cu.data(),
                     mu_cu.data(), rsigma_cu.data(), stream, multiProcessorCount,
                     workspace.data(), barrier.data());

  // Fill workspace and barrier
  auto w_s = workspace.shape();
  auto b_s = barrier.shape();
  std::vector<size_t> w_shape_vec{w_s.data, w_s.data + w_s.ndim};
  std::vector<size_t> b_shape_vec{b_s.data, b_s.data + b_s.ndim};
  void* workspace_ptr = AllocateSpace(w_shape_vec, workspace.dtype());
  void* barrier_ptr = AllocateSpace(b_shape_vec, barrier.dtype(), stream, true);
  workspace = MakeNVTETensor(workspace_ptr, w_shape_vec, workspace.dtype());
  barrier = MakeNVTETensor(barrier_ptr, b_shape_vec, barrier.dtype());

  // Actual call to fwd kernel
  nvte_layernorm_fwd(input_cu.data(), gamma_cu.data(), beta_cu.data(), epsilon, z_cu.data(),
                     mu_cu.data(), rsigma_cu.data(), stream, multiProcessorCount,
                     workspace.data(), barrier.data());
}
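dispatch_layernorm and the dbias fusions above follow the NVTE workspace convention: a first kernel call with empty workspace (and barrier) wrappers only reports the buffer shapes and dtypes required, the caller allocates them, and a second, identical call does the real work. A toy sketch of the protocol in plain Python (all names illustrative, no real kernels involved):

# Toy stand-in for an NVTE kernel using the query-then-run protocol.
def nvte_style_kernel(data, workspace):
    if workspace["buf"] is None:        # query pass: only report requirements
        workspace["shape"] = (len(data),)
        return None
    return [2.0 * x for x in data]      # real pass: workspace is usable now

ws = {"buf": None, "shape": None}
nvte_style_kernel([1.0, 2.0, 3.0], ws)          # first call fills ws["shape"]
ws["buf"] = [0.0] * ws["shape"][0]              # allocate what was requested
out = nvte_style_kernel([1.0, 2.0, 3.0], ws)    # second call computes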
void dispatch_gelu(
    void* input,  // i
    const std::vector<size_t>& input_shape, const transformer_engine::DType input_type,
    void* scale,  // i
    const std::vector<size_t>& scale_shape, const transformer_engine::DType scale_type,
    void* output,  // o
    const std::vector<size_t>& output_shape, const transformer_engine::DType output_type,
    void* amax,  // o
    const std::vector<size_t>& amax_shape, const transformer_engine::DType amax_type,
    void* scale_inv,  // o
    const std::vector<size_t>& scale_inv_shape, const transformer_engine::DType scale_inv_type,
    cudaStream_t stream) {
  auto input_cu = MakeNVTETensor(input, input_shape, input_type);
  auto output_cu = MakeNVTETensor(output, output_shape, output_type, amax, scale, scale_inv);
  nvte_gelu(input_cu.data(), output_cu.data(), stream);
}

void dispatch_bgrad_dgelu_cast_transpose_fusion(
    void* input,  // i
    const std::vector<size_t>& input_shape, const transformer_engine::DType input_type,
    void* gelu_input,  // i
    const std::vector<size_t>& gelu_input_shape, const transformer_engine::DType gelu_input_type,
    void* scale,  // i
    const std::vector<size_t>& scale_shape, const transformer_engine::DType scale_type,
    void* cast_output,  // o
    const std::vector<size_t>& cast_output_shape, const transformer_engine::DType cast_output_type,
    void* transposed_output,  // o
    const std::vector<size_t>& transposed_output_shape,
    const transformer_engine::DType transposed_output_type,
    void* amax,  // o
    const std::vector<size_t>& amax_shape, const transformer_engine::DType amax_type,
    void* dbias,  // o
    const std::vector<size_t>& dbias_shape, const transformer_engine::DType dbias_type,
    void* scale_inv,  // o
    const std::vector<size_t>& scale_inv_shape, const transformer_engine::DType scale_inv_type,
    cudaStream_t stream) {
  auto gelu_input_cu = MakeNVTETensor(gelu_input, gelu_input_shape, gelu_input_type);
  auto input_cu = MakeNVTETensor(input, input_shape, input_type);
  auto cast_output_cu =
      MakeNVTETensor(cast_output, cast_output_shape, cast_output_type, amax, scale, scale_inv);
  auto transposed_output_cu = MakeNVTETensor(transposed_output, transposed_output_shape,
                                             transposed_output_type, amax, scale, scale_inv);
  auto dbias_cu = MakeNVTETensor(dbias, dbias_shape, dbias_type);

  transformer_engine::TensorWrapper workspace;
  nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(), cast_output_cu.data(),
                                  transposed_output_cu.data(), dbias_cu.data(), workspace.data(),
                                  stream);

  // Fill workspace
  auto w_s = workspace.shape();
  std::vector<size_t> w_shape_vec{w_s.data, w_s.data + w_s.ndim};
  void* workspace_ptr = AllocateSpace(w_shape_vec, workspace.dtype());
  workspace = MakeNVTETensor(workspace_ptr, w_shape_vec, workspace.dtype());

  nvte_cast_transpose_dbias_dgelu(input_cu.data(), gelu_input_cu.data(), cast_output_cu.data(),
                                  transposed_output_cu.data(), dbias_cu.data(), workspace.data(),
                                  stream);
}

TFE_TensorHandle* GetTFETensorHandle(const pybind11::handle tensor) {
  CHECK(EagerTensor_CheckExact(tensor.ptr())) << "All inputs must be EagerTensors.";
  return EagerTensor_Handle(tensor.ptr());
}

int GetDeviceMultiProcessorCount() {
  static int count = [] {
    cudaDeviceProp properties;
    // Get current device
    int device = -1;
    CHECK_EQ(cudaGetDevice(&device), cudaSuccess) << "Got invalid GPU" << device;
    CHECK_EQ(cudaGetDeviceProperties(&properties, device), cudaSuccess);
    return properties.multiProcessorCount;
  }();
  return count;
}

py::object TFE_Py_TeGemm_wrapper(
    const pybind11::handle& a_mat, const pybind11::handle& a_scale_inv,
    const transformer_engine::DType atype, const int a_offset,
    const pybind11::handle& b_mat, const pybind11::handle& b_scale_inv,
    const transformer_engine::DType btype, const int b_offset,
    const pybind11::handle& workspace, const bool use_bias, const pybind11::handle& bias,
    const bool use_gelu, const pybind11::handle& gelu_input, const bool transa,
    const bool transb, const bool grad, const bool accumulate,
    const bool use_split_accumulate, const transformer_engine::DType otype,
    const int64_t stream_id) {
  using namespace transformer_engine;
  std::vector<size_t> a_shape = GetShape(a_mat);
  std::vector<size_t> b_shape = GetShape(b_mat);
  CHECK_EQ(a_shape.size(), 2);
  CHECK_EQ(b_shape.size(), 2);
  std::vector<size_t> d_shape{transb ? b_shape[1] : b_shape[0],
                              transa ? a_shape[0] : a_shape[1]};

  auto a_tensor = MakeNVTETensor(GetDevicePtr(a_mat), a_shape, atype, nullptr, nullptr,
                                 GetDevicePtr(a_scale_inv, a_offset));
  auto b_tensor = MakeNVTETensor(GetDevicePtr(b_mat), b_shape, btype, nullptr, nullptr,
                                 GetDevicePtr(b_scale_inv, b_offset));

  NVTEShape empty_shape;
  TensorWrapper bias_tensor(nullptr, empty_shape, DType::kBFloat16);
  if (use_bias) {
    bias_tensor = MakeNVTETensor(GetDevicePtr(bias), GetShape(bias), GetDataType(bias));
  }

  TensorWrapper gelu_input_tensor(nullptr, empty_shape, DType::kBFloat16);
  void* gelu_input_ptr = nullptr;
  if (use_gelu && !grad) {
    gelu_input_ptr = AllocateSpace(d_shape, otype);
    gelu_input_tensor = MakeNVTETensor(gelu_input_ptr, d_shape, otype);
  } else if (use_gelu) {
    gelu_input_tensor =
        MakeNVTETensor(GetDevicePtr(gelu_input), GetShape(gelu_input), GetDataType(gelu_input));
  }

  auto workspace_tensor =
      MakeNVTETensor(GetDevicePtr(workspace), GetShape(workspace), GetDataType(workspace));

  void* d_ptr = AllocateSpace(d_shape, otype);
  auto d_tensor = MakeNVTETensor(d_ptr, d_shape, otype);

  cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_id);
  nvte_cublas_gemm(a_tensor.data(), b_tensor.data(), d_tensor.data(), bias_tensor.data(),
                   gelu_input_tensor.data(), transa, transb, grad, workspace_tensor.data(),
                   accumulate, use_split_accumulate, 0, stream);

  auto d_eager = CreateTensor(d_ptr, d_shape, otype);
  if (use_gelu && !grad) {
    auto gelu_input_eager = CreateTensor(gelu_input_ptr, d_shape, otype);
    PyObject* result(PyList_New(2));
    PyList_SET_ITEM(result, 0, EagerTensorFromHandle(d_eager));
    PyList_SET_ITEM(result, 1, EagerTensorFromHandle(gelu_input_eager));
    return tensorflow::PyoOrThrow(result);
  }
  return tensorflow::PyoOrThrow(EagerTensorFromHandle(d_eager));
}

}  // end namespace
PYBIND11_MODULE(transformer_engine_tensorflow, m) {
  py::enum_<transformer_engine::DType>(m, "DType")
      .value("kByte", transformer_engine::DType::kByte)
      .value("kInt32", transformer_engine::DType::kInt32)
      .value("kFloat32", transformer_engine::DType::kFloat32)
      .value("kFloat16", transformer_engine::DType::kFloat16)
      .value("kBFloat16", transformer_engine::DType::kBFloat16)
      .value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3)
      .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2);

  py::enum_<transformer_engine::FP8FwdTensors>(m, "FP8FwdTensors", py::arithmetic())
      .value("GEMM1_INPUT", transformer_engine::FP8FwdTensors::GEMM1_INPUT)
      .value("GEMM1_WEIGHT", transformer_engine::FP8FwdTensors::GEMM1_WEIGHT)
      .value("GEMM2_INPUT", transformer_engine::FP8FwdTensors::GEMM2_INPUT)
      .value("GEMM2_WEIGHT", transformer_engine::FP8FwdTensors::GEMM2_WEIGHT);

  py::enum_<transformer_engine::FP8BwdTensors>(m, "FP8BwdTensors", py::arithmetic())
      .value("GRAD_OUTPUT1", transformer_engine::FP8BwdTensors::GRAD_OUTPUT1)
      .value("GRAD_OUTPUT2", transformer_engine::FP8BwdTensors::GRAD_OUTPUT2);

  m.def("cast_to_fp8", [](const pybind11::handle& input, const pybind11::handle& scale,
                          const transformer_engine::DType otype, const pybind11::handle& amax,
                          const pybind11::handle& scale_inv, const int offset,
                          const int64_t stream_id) {
    std::vector<size_t> shape_c = GetShape(input);
    CHECK_EQ(shape_c.size(), 2);
    auto input_tensor = MakeNVTETensor(GetDevicePtr(input), shape_c, GetDataType(input));
    void* out_c_ptr = AllocateSpace(shape_c, otype);
    auto output_tensor =
        MakeNVTETensor(out_c_ptr, shape_c, otype, GetDevicePtr(amax, offset),
                       GetDevicePtr(scale, offset), GetDevicePtr(scale_inv, offset));
    cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_id);
    nvte_fp8_quantize(input_tensor.data(), output_tensor.data(), stream);
    auto out_c_eager = CreateTensor(out_c_ptr, shape_c, otype);
    return tensorflow::PyoOrThrow(EagerTensorFromHandle(out_c_eager));
  });
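A sketch of driving cast_to_fp8 from eager TF. It assumes a GPU build of the extension; the removed Python layer obtained stream_id from get_stream_id() (see get_stream_op.cpp below), while here the default stream 0 is used and shapes are illustrative:

# Sketch (requires the built extension and a GPU; tensors must live on the GPU).
import tensorflow as tf
import transformer_engine_tensorflow as tex

with tf.device("/GPU:0"):
    x = tf.random.uniform((32, 64))       # input must be 2D
    amax = tf.zeros((4,))
    scale = tf.ones((4,))
    scale_inv = tf.ones((4,))
    offset = int(tex.FP8FwdTensors.GEMM1_INPUT)
    stream_id = 0                          # default CUDA stream
    x_fp8 = tex.cast_to_fp8(x, scale, tex.DType.kFloat8E4M3,
                            amax, scale_inv, offset, stream_id)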
  m.def("cast_from_fp8", [](const pybind11::handle& input, const pybind11::handle& scale_inv,
                            const transformer_engine::DType itype,
                            const transformer_engine::DType otype, const int offset,
                            const int64_t stream_id) {
    std::vector<size_t> shape_c = GetShape(input);
    CHECK_EQ(shape_c.size(), 2);
    auto input_tensor = MakeNVTETensor(GetDevicePtr(input), shape_c, itype, nullptr, nullptr,
                                       GetDevicePtr(scale_inv, offset));
    void* out_ptr = AllocateSpace(shape_c, otype);
    auto output_tensor = MakeNVTETensor(out_ptr, shape_c, otype);
    cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_id);
    nvte_fp8_dequantize(input_tensor.data(), output_tensor.data(), stream);
    auto out_eager = CreateTensor(out_ptr, shape_c, otype);
    return tensorflow::PyoOrThrow(EagerTensorFromHandle(out_eager));
  });

  m.def("fp8_cast_transpose_fused", [](const pybind11::handle& input,
                                       const pybind11::handle& scale,
                                       const transformer_engine::DType otype,
                                       const pybind11::handle& amax,
                                       const pybind11::handle& scale_inv, const int offset,
                                       const int64_t stream_id) {
    using namespace transformer_engine;
    std::vector<size_t> shape_c = GetShape(input);
    CHECK_EQ(shape_c.size(), 2);
    std::vector<size_t> shape_t{shape_c[1], shape_c[0]};
    void* out_c_ptr = AllocateSpace(shape_c, otype);
    void* out_t_ptr = AllocateSpace(shape_t, otype);
    cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_id);
    dispatch_cast_transpose_fusion(
        GetDevicePtr(input), shape_c, GetDataType(input),
        GetDevicePtr(scale, offset), {1}, DType::kFloat32,
        out_c_ptr, shape_c, otype,
        out_t_ptr, shape_t, otype,
        GetDevicePtr(amax, offset), {1}, DType::kFloat32,
        GetDevicePtr(scale_inv, offset), {1}, DType::kFloat32,
        stream);
    auto out_c_eager = CreateTensor(out_c_ptr, shape_c, otype);
    auto out_t_eager = CreateTensor(out_t_ptr, shape_t, otype);
    PyObject* result(PyList_New(2));
    PyList_SET_ITEM(result, 0, EagerTensorFromHandle(out_c_eager));
    PyList_SET_ITEM(result, 1, EagerTensorFromHandle(out_t_eager));
    return tensorflow::PyoOrThrow(result);
  });

  m.def("fp8_transpose", [](const pybind11::handle& input, transformer_engine::DType otype,
                            const int64_t stream_id) {
    std::vector<size_t> shape_c = GetShape(input);
    CHECK_EQ(shape_c.size(), 2);
    std::vector<size_t> shape_t{shape_c[1], shape_c[0]};
    void* out_t_ptr = AllocateSpace(shape_t, otype);
    cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_id);
    dispatch_transpose(GetDevicePtr(input), shape_c, otype, out_t_ptr, shape_t, otype, stream);
    TFE_TensorHandle* out_t_eager = CreateTensor(out_t_ptr, shape_t, otype);
    return tensorflow::PyoOrThrow(EagerTensorFromHandle(out_t_eager));
  });

  m.def("fp8_cast_transpose_bgrad_fused", [](const pybind11::handle& grad_out,
                                             const pybind11::handle& scale,
                                             const transformer_engine::DType otype,
                                             const pybind11::handle& amax,
                                             const pybind11::handle& scale_inv,
                                             const int offset, const int64_t stream_id) {
    using namespace transformer_engine;
    std::vector<size_t> shape_c = GetShape(grad_out);
    CHECK_EQ(shape_c.size(), 2);
    std::vector<size_t> shape_t{shape_c[1], shape_c[0]};
    std::vector<size_t> shape_b{shape_c[1]};
    auto itype = GetDataType(grad_out);
    void* grad_bias_ptr = AllocateSpace(shape_b, itype);
    void* grad_out_c_ptr = AllocateSpace(shape_c, otype);
    void* grad_out_t_ptr = AllocateSpace(shape_t, otype);
    cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_id);
    dispatch_bgrad_cast_transpose_fusion(
        GetDevicePtr(grad_out), shape_c, itype,
        GetDevicePtr(scale, offset), {1}, DType::kFloat32,
        grad_out_c_ptr, shape_c, otype,
        grad_out_t_ptr, shape_t, otype,
        GetDevicePtr(amax, offset), {1}, DType::kFloat32,
        grad_bias_ptr, shape_b, itype,
        GetDevicePtr(scale_inv, offset), {1}, DType::kFloat32,
        stream);
    auto grad_bias_eager = CreateTensor(grad_bias_ptr, shape_b, itype);
    auto grad_out_c_eager = CreateTensor(grad_out_c_ptr, shape_c, otype);
    auto grad_out_t_eager = CreateTensor(grad_out_t_ptr, shape_t, otype);
    PyObject* result(PyList_New(3));
    PyList_SET_ITEM(result, 0, EagerTensorFromHandle(grad_bias_eager));
    PyList_SET_ITEM(result, 1, EagerTensorFromHandle(grad_out_c_eager));
    PyList_SET_ITEM(result, 2, EagerTensorFromHandle(grad_out_t_eager));
    return tensorflow::PyoOrThrow(result);
  });

  m.def("te_gemm", [](const pybind11::handle& a_mat, const pybind11::handle& a_scale_inv,
                      const transformer_engine::DType atype, const int a_offset,
                      const pybind11::handle& b_mat, const pybind11::handle& b_scale_inv,
                      const transformer_engine::DType btype, const int b_offset,
                      const pybind11::handle& workspace, const bool use_bias,
                      const pybind11::handle& bias, const bool use_gelu,
                      const pybind11::handle& gelu_input, const bool transa, const bool transb,
                      const bool grad, const bool accumulate, const bool use_split_accumulate,
                      const transformer_engine::DType otype, const int64_t stream_id) {
    return TFE_Py_TeGemm_wrapper(a_mat, a_scale_inv, atype, a_offset, b_mat, b_scale_inv, btype,
                                 b_offset, workspace, use_bias, bias, use_gelu, gelu_input,
                                 transa, transb, grad, accumulate, use_split_accumulate, otype,
                                 stream_id);
  });
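The output-shape rule in TFE_Py_TeGemm_wrapper deserves spelling out: in the row-major (TF) view the wrapper computes D = op(B) @ op(A), where op transposes its operand when the corresponding flag is set, hence d_shape = {transb ? b[1] : b[0], transa ? a[0] : a[1]}. In Python terms:

# Sketch of the output-shape rule above (derived from d_shape in the wrapper).
def te_gemm_out_shape(a_shape, b_shape, transa, transb):
    rows = b_shape[1] if transb else b_shape[0]   # rows of op(B)
    cols = a_shape[0] if transa else a_shape[1]   # cols of op(A)
    return (rows, cols)

# A: (K, N) = (64, 128), B: (M, K) = (32, 64)  ->  D: (M, N) = (32, 128)
assert te_gemm_out_shape((64, 128), (32, 64), False, False) == (32, 128)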
  m.def("layernorm_fwd", [](const pybind11::handle& input, const pybind11::handle& gamma,
                            const pybind11::handle& beta, float eps, const int64_t stream_id) {
    using namespace transformer_engine;
    std::vector<size_t> shape_c = GetShape(input);
    CHECK_EQ(shape_c.size(), 2);
    std::vector<size_t> shape_g{shape_c[1]};
    std::vector<size_t> shape_m{shape_c[0]};
    auto itype = GetDataType(input);
    auto mtype = DType::kFloat32;
    void* ln_out_ptr = AllocateSpace(shape_c, itype);
    void* mu_ptr = AllocateSpace(shape_m, mtype);
    void* rsigma_ptr = AllocateSpace(shape_m, mtype);
    cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_id);
    dispatch_layernorm(GetDevicePtr(input), shape_c, itype,
                       GetDevicePtr(gamma), shape_g, itype,
                       GetDevicePtr(beta), shape_g, itype,
                       nullptr, {1}, mtype,
                       eps,
                       ln_out_ptr, shape_c, itype,
                       mu_ptr, shape_m, mtype,
                       rsigma_ptr, shape_m, mtype,
                       nullptr, {1}, mtype,
                       nullptr, {1}, mtype,
                       GetDeviceMultiProcessorCount(), stream);
    auto ln_out_eager = CreateTensor(ln_out_ptr, shape_c, itype);
    auto mu_eager = CreateTensor(mu_ptr, shape_m, mtype);
    auto rsigma_eager = CreateTensor(rsigma_ptr, shape_m, mtype);
    PyObject* result(PyList_New(3));
    PyList_SET_ITEM(result, 0, EagerTensorFromHandle(ln_out_eager));
    PyList_SET_ITEM(result, 1, EagerTensorFromHandle(mu_eager));
    PyList_SET_ITEM(result, 2, EagerTensorFromHandle(rsigma_eager));
    return tensorflow::PyoOrThrow(result);
  });

  m.def("layernorm_fwd_fp8", [](const pybind11::handle& input, const pybind11::handle& gamma,
                                const pybind11::handle& beta, float eps,
                                const pybind11::handle& scale,
                                const transformer_engine::DType otype,
                                const pybind11::handle& amax, const pybind11::handle& scale_inv,
                                const int offset, const int64_t stream_id) {
    using namespace transformer_engine;
    std::vector<size_t> shape_c = GetShape(input);
    CHECK_EQ(shape_c.size(), 2);
    std::vector<size_t> shape_g{shape_c[1]};
    std::vector<size_t> shape_m{shape_c[0]};
    auto itype = GetDataType(input);
    auto mtype = DType::kFloat32;
    void* ln_out_ptr = AllocateSpace(shape_c, otype);
    void* mu_ptr = AllocateSpace(shape_m, mtype);
    void* rsigma_ptr = AllocateSpace(shape_m, mtype);
    cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_id);
    dispatch_layernorm(GetDevicePtr(input), shape_c, itype,
                       GetDevicePtr(gamma), shape_g, itype,
                       GetDevicePtr(beta), shape_g, itype,
                       GetDevicePtr(scale, offset), {1}, DType::kFloat32,
                       eps,
                       ln_out_ptr, shape_c, otype,
                       mu_ptr, shape_m, mtype,
                       rsigma_ptr, shape_m, mtype,
                       GetDevicePtr(amax, offset), {1}, DType::kFloat32,
                       GetDevicePtr(scale_inv, offset), {1}, DType::kFloat32,
                       GetDeviceMultiProcessorCount(), stream);
    auto ln_out_eager = CreateTensor(ln_out_ptr, shape_c, otype);
    auto mu_eager = CreateTensor(mu_ptr, shape_m, mtype);
    auto rsigma_eager = CreateTensor(rsigma_ptr, shape_m, mtype);
    PyObject* result(PyList_New(3));
    PyList_SET_ITEM(result, 0, EagerTensorFromHandle(ln_out_eager));
    PyList_SET_ITEM(result, 1, EagerTensorFromHandle(mu_eager));
    PyList_SET_ITEM(result, 2, EagerTensorFromHandle(rsigma_eager));
    return tensorflow::PyoOrThrow(result);
  });
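Both layernorm bindings return a three-element list [y, mu, rsigma]; the fp32 mu/rsigma statistics are cached for layernorm_bwd below. A usage sketch, assuming a GPU build of the extension and illustrative shapes:

# Sketch (requires the built extension and a GPU).
import tensorflow as tf
import transformer_engine_tensorflow as tex

with tf.device("/GPU:0"):
    x = tf.random.uniform((32, 64))
    gamma = tf.ones((64,))
    beta = tf.zeros((64,))
    y, mu, rsigma = tex.layernorm_fwd(x, gamma, beta, 1e-5, 0)  # eps, stream 0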
  m.def("layernorm_bwd", [](const pybind11::handle& dz, const pybind11::handle& x,
                            const pybind11::handle& mu, const pybind11::handle& rsigma,
                            const pybind11::handle& gamma, const int64_t stream_id) {
    using namespace transformer_engine;
    std::vector<size_t> shape_x = GetShape(x);
    CHECK_EQ(shape_x.size(), 2);
    std::vector<size_t> shape_g{shape_x[1]};
    std::vector<size_t> shape_m{shape_x[0]};
    auto xtype = GetDataType(x);
    auto gtype = GetDataType(gamma);
    auto mtype = GetDataType(mu);
    void* dx_ptr = AllocateSpace(shape_x, xtype);
    void* dgamma_ptr = AllocateSpace(shape_g, gtype);
    void* dbeta_ptr = AllocateSpace(shape_g, gtype);
    auto x_tensor = MakeNVTETensor(GetDevicePtr(x), shape_x, xtype);
    auto gamma_tensor = MakeNVTETensor(GetDevicePtr(gamma), shape_g, gtype);
    auto dz_tensor = MakeNVTETensor(GetDevicePtr(dz), shape_x, xtype);
    auto mu_tensor = MakeNVTETensor(GetDevicePtr(mu), shape_m, mtype);
    auto rsigma_tensor = MakeNVTETensor(GetDevicePtr(rsigma), shape_m, mtype);
    auto dx_tensor = MakeNVTETensor(dx_ptr, shape_x, xtype);
    auto dgamma_tensor = MakeNVTETensor(dgamma_ptr, shape_g, gtype);
    auto dbeta_tensor = MakeNVTETensor(dbeta_ptr, shape_g, gtype);

    TensorWrapper workspace, barrier, dgamma_part, dbeta_part;

    // This call populates tensors with the required config.
    cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_id);
    nvte_layernorm_bwd(dz_tensor.data(), x_tensor.data(), mu_tensor.data(),
                       rsigma_tensor.data(), gamma_tensor.data(), dx_tensor.data(),
                       dgamma_tensor.data(), dbeta_tensor.data(), dgamma_part.data(),
                       dbeta_part.data(), stream, GetDeviceMultiProcessorCount(),
                       workspace.data(), barrier.data());

    // Alloc space for Tensors.
    auto w_s = workspace.shape();
    auto b_s = barrier.shape();
    auto dg_s = dgamma_part.shape();
    auto db_s = dbeta_part.shape();
    std::vector<size_t> w_shape_vec{w_s.data, w_s.data + w_s.ndim};
    std::vector<size_t> b_shape_vec{b_s.data, b_s.data + b_s.ndim};
    std::vector<size_t> dg_shape_vec{dg_s.data, dg_s.data + dg_s.ndim};
    std::vector<size_t> db_shape_vec{db_s.data, db_s.data + db_s.ndim};
    void* workspace_ptr = AllocateSpace(w_shape_vec, workspace.dtype());
    void* barrier_ptr = AllocateSpace(b_shape_vec, barrier.dtype(), stream, true);
    void* dgamma_part_ptr = AllocateSpace(dg_shape_vec, dgamma_part.dtype());
    void* dbeta_part_ptr = AllocateSpace(db_shape_vec, dbeta_part.dtype());
    workspace = MakeNVTETensor(workspace_ptr, w_shape_vec, workspace.dtype());
    barrier = MakeNVTETensor(barrier_ptr, b_shape_vec, barrier.dtype());
    dgamma_part = MakeNVTETensor(dgamma_part_ptr, dg_shape_vec, dgamma_part.dtype());
    dbeta_part = MakeNVTETensor(dbeta_part_ptr, db_shape_vec, dbeta_part.dtype());

    // Actual call to bwd kernel.
    nvte_layernorm_bwd(dz_tensor.data(), x_tensor.data(), mu_tensor.data(),
                       rsigma_tensor.data(), gamma_tensor.data(), dx_tensor.data(),
                       dgamma_tensor.data(), dbeta_tensor.data(), dgamma_part.data(),
                       dbeta_part.data(), stream, GetDeviceMultiProcessorCount(),
                       workspace.data(), barrier.data());

    auto dx_eager = CreateTensor(dx_ptr, shape_x, xtype);
    auto dgamma_eager = CreateTensor(dgamma_ptr, shape_g, gtype);
    auto dbeta_eager = CreateTensor(dbeta_ptr, shape_g, gtype);
    PyObject* result(PyList_New(3));
    PyList_SET_ITEM(result, 0, EagerTensorFromHandle(dx_eager));
    PyList_SET_ITEM(result, 1, EagerTensorFromHandle(dgamma_eager));
    PyList_SET_ITEM(result, 2, EagerTensorFromHandle(dbeta_eager));
    return tensorflow::PyoOrThrow(result);
  });

  m.def("te_gelu", [](const pybind11::handle& input, const pybind11::handle& scale,
                      const transformer_engine::DType otype, const pybind11::handle& amax,
                      const pybind11::handle& scale_inv, const int offset,
                      const int64_t stream_id) {
    using namespace transformer_engine;
    std::vector<size_t> shape_c = GetShape(input);
    CHECK_EQ(shape_c.size(), 2);
    void* out_ptr = AllocateSpace(shape_c, otype);
    auto itype = GetDataType(input);
    void* scale_ptr = nullptr;
    void* amax_ptr = nullptr;
    void* scale_inv_ptr = nullptr;
    if (otype == DType::kFloat8E4M3 || otype == DType::kFloat8E5M2) {
      scale_ptr = GetDevicePtr(scale, offset);
      amax_ptr = GetDevicePtr(amax, offset);
      scale_inv_ptr = GetDevicePtr(scale_inv, offset);
    }
    cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_id);
    dispatch_gelu(GetDevicePtr(input), shape_c, itype,
                  scale_ptr, {1}, DType::kFloat32,
                  out_ptr, shape_c, otype,
                  amax_ptr, {1}, DType::kFloat32,
                  scale_inv_ptr, {1}, DType::kFloat32,
                  stream);
    auto out_eager = CreateTensor(out_ptr, shape_c, otype);
    return tensorflow::PyoOrThrow(EagerTensorFromHandle(out_eager));
  });

  m.def("fp8_fused_cast_transpose_bgrad_dgelu", [](const pybind11::handle& grad_output,
                                                   const pybind11::handle& gelu_input,
                                                   const pybind11::handle& scale,
                                                   const transformer_engine::DType otype,
                                                   const pybind11::handle& amax,
                                                   const pybind11::handle& scale_inv,
                                                   const int offset, const int64_t stream_id) {
    using namespace transformer_engine;
    std::vector<size_t> shape_c = GetShape(grad_output);
    CHECK_EQ(shape_c.size(), 2);
    std::vector<size_t> shape_t{shape_c[1], shape_c[0]};
    std::vector<size_t> shape_b{shape_c[1]};
    auto itype = GetDataType(grad_output);
    void* grad_bias_ptr = AllocateSpace(shape_b, itype);
    void* dgelu_c_ptr = AllocateSpace(shape_c, otype);
    void* dgelu_t_ptr = AllocateSpace(shape_t, otype);
    cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_id);
    dispatch_bgrad_dgelu_cast_transpose_fusion(
        GetDevicePtr(grad_output), shape_c, itype,
        GetDevicePtr(gelu_input), shape_c, itype,
        GetDevicePtr(scale, offset), {1}, DType::kFloat32,
        dgelu_c_ptr, shape_c, otype,
        dgelu_t_ptr, shape_t, otype,
        GetDevicePtr(amax, offset), {1}, DType::kFloat32,
        grad_bias_ptr, shape_b, itype,
        GetDevicePtr(scale_inv, offset), {1}, DType::kFloat32,
        stream);
    auto grad_bias_eager = CreateTensor(grad_bias_ptr, shape_b, itype);
    auto dgelu_c_eager = CreateTensor(dgelu_c_ptr, shape_c, otype);
    auto dgelu_t_eager = CreateTensor(dgelu_t_ptr, shape_t, otype);
    PyObject* result(PyList_New(3));
    PyList_SET_ITEM(result, 0, EagerTensorFromHandle(grad_bias_eager));
    PyList_SET_ITEM(result, 1, EagerTensorFromHandle(dgelu_c_eager));
    PyList_SET_ITEM(result, 2, EagerTensorFromHandle(dgelu_t_eager));
    return tensorflow::PyoOrThrow(result);
  });
  m.def("scaled_upper_triang_masked_softmax_forward",
        [](const pybind11::handle& input, const float scale_factor, const int64_t stream_id) {
    using namespace transformer_engine;
    std::vector<size_t> shape_in = GetShape(input);
    CHECK_EQ(shape_in.size(), 3);
    auto itype = GetDataType(input);
    CHECK(itype == DType::kFloat16 || itype == DType::kBFloat16);
    const size_t attn_batches = shape_in[0];
    const size_t seq_len = shape_in[1];
    CHECK_LE(seq_len, 2048);
    auto input_cu = MakeNVTETensor(GetDevicePtr(input), shape_in, itype);
    void* softmax_ptr = AllocateSpace(shape_in, itype);
    auto softmax_results_cu = MakeNVTETensor(softmax_ptr, shape_in, itype);
    cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_id);
    nvte_scaled_upper_triang_masked_softmax_forward(input_cu.data(), softmax_results_cu.data(),
                                                    scale_factor, stream);
    auto softmax_results_eager = CreateTensor(softmax_ptr, shape_in, itype);
    return tensorflow::PyoOrThrow(EagerTensorFromHandle(softmax_results_eager));
  });

  m.def("scaled_upper_triang_masked_softmax_backward",
        [](const pybind11::handle& dy, const pybind11::handle& y, const float scale_factor,
           const int64_t stream_id) {
    using namespace transformer_engine;
    std::vector<size_t> shape_dy = GetShape(dy);
    std::vector<size_t> shape_y = GetShape(y);
    CHECK_EQ(shape_dy.size(), 3);
    CHECK_EQ(shape_y.size(), 3);
    auto dytype = GetDataType(dy);
    auto ytype = GetDataType(y);
    CHECK(dytype == DType::kFloat16 || dytype == DType::kBFloat16);
    CHECK(ytype == DType::kFloat16 || ytype == DType::kBFloat16);
    CHECK_EQ(shape_dy[1], shape_dy[2]);
    auto dy_cu = MakeNVTETensor(GetDevicePtr(dy), shape_dy, dytype);
    auto y_cu = MakeNVTETensor(GetDevicePtr(y), shape_y, ytype);
    void* dx_ptr = AllocateSpace(shape_dy, dytype);
    auto dx_cu = MakeNVTETensor(dx_ptr, shape_dy, dytype);
    cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_id);
    nvte_scaled_upper_triang_masked_softmax_backward(dy_cu.data(), y_cu.data(), dx_cu.data(),
                                                     scale_factor, stream);
    auto dx_eager = CreateTensor(dx_ptr, shape_dy, dytype);
    return tensorflow::PyoOrThrow(EagerTensorFromHandle(dx_eager));
  });

  m.def("scaled_masked_softmax_forward",
        [](const pybind11::handle& x, const pybind11::handle& mask, const float scale_factor,
           const int64_t stream_id) {
    using namespace transformer_engine;
    std::vector<size_t> shape_x = GetShape(x);
    std::vector<size_t> shape_m = GetShape(mask);
    CHECK_EQ(shape_x.size(), 4) << "expected 4D tensor";
    CHECK_EQ(shape_m.size(), 4) << "expected 4D tensor";
    auto xtype = GetDataType(x);
    auto mtype = GetDataType(mask);
    CHECK(xtype == DType::kFloat16 || xtype == DType::kBFloat16)
        << "Only fp16 and bf16 are supported";
    CHECK(mtype == DType::kByte) << "Only bool are supported for mask";
    const size_t batches = shape_x[0];
    const size_t pad_batches = shape_m[0];
    const size_t attn_heads = shape_x[1];
    const size_t query_seq_len = shape_x[2];
    const size_t key_seq_len = shape_x[3];
    CHECK_LE(key_seq_len, 4096);
    CHECK_GT(query_seq_len, 1);
    CHECK(pad_batches == 1 || pad_batches == batches);
    CHECK_EQ(shape_m[1], 1);
    CHECK(shape_m[2] == query_seq_len);
    CHECK(shape_m[3] == key_seq_len);
    void* softmax_ptr = AllocateSpace(shape_x, xtype);
    auto softmax_results_cu = MakeNVTETensor(softmax_ptr, shape_x, xtype);
    auto input_cu = MakeNVTETensor(GetDevicePtr(x), shape_x, xtype);
    auto mask_cu = MakeNVTETensor(GetDevicePtr(mask), shape_m, mtype);
    cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_id);
    nvte_scaled_masked_softmax_forward(input_cu.data(), mask_cu.data(),
                                       softmax_results_cu.data(), scale_factor, stream);
    auto softmax_results_eager = CreateTensor(softmax_ptr, shape_x, xtype);
    return tensorflow::PyoOrThrow(EagerTensorFromHandle(softmax_results_eager));
  });

  m.def("scaled_masked_softmax_backward",
        [](const pybind11::handle& dy, const pybind11::handle& y, const float scale_factor,
           const int64_t stream_id) {
    using namespace transformer_engine;
    std::vector<size_t> shape_dy = GetShape(dy);
    std::vector<size_t> shape_y = GetShape(y);
    CHECK_EQ(shape_dy.size(), 4) << "expected 4D tensor";
    CHECK_EQ(shape_y.size(), 4) << "expected 4D tensor";
    auto dytype = GetDataType(dy);
    auto ytype = GetDataType(y);
    CHECK(dytype == DType::kFloat16 || dytype == DType::kBFloat16)
        << "Only fp16 and bf16 are supported";
    CHECK(ytype == DType::kFloat16 || ytype == DType::kBFloat16)
        << "Only fp16 and bf16 are supported";
    auto dy_cu = MakeNVTETensor(GetDevicePtr(dy), shape_dy, dytype);
    auto y_cu = MakeNVTETensor(GetDevicePtr(y), shape_y, ytype);
    void* dx_ptr = AllocateSpace(shape_dy, dytype);
    auto dx_cu = MakeNVTETensor(dx_ptr, shape_dy, dytype);
    cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_id);
    nvte_scaled_masked_softmax_backward(dy_cu.data(), y_cu.data(), dx_cu.data(), scale_factor,
                                        stream);
    auto dx_eager = CreateTensor(dx_ptr, shape_dy, dytype);
    return tensorflow::PyoOrThrow(EagerTensorFromHandle(dx_eager));
  });

  m.def("scaled_softmax_forward",
        [](const pybind11::handle& x, const float scale_factor, const int64_t stream_id) {
    using namespace transformer_engine;
    std::vector<size_t> shape_x = GetShape(x);
    CHECK_EQ(shape_x.size(), 4) << "expected 4D tensor";
    auto xtype = GetDataType(x);
    CHECK(xtype == DType::kFloat16 || xtype == DType::kBFloat16)
        << "Only fp16 and bf16 are supported";
    const size_t batches = shape_x[0];
    const size_t attn_heads = shape_x[1];
    const size_t query_seq_len = shape_x[2];
    const size_t key_seq_len = shape_x[3];
    CHECK_LE(key_seq_len, 4096);
    CHECK_GT(query_seq_len, 1);
    void* softmax_ptr = AllocateSpace(shape_x, xtype);
    auto softmax_results_cu = MakeNVTETensor(softmax_ptr, shape_x, xtype);
    auto input_cu = MakeNVTETensor(GetDevicePtr(x), shape_x, xtype);
    cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_id);
    nvte_scaled_softmax_forward(input_cu.data(), softmax_results_cu.data(), scale_factor,
                                stream);
    auto softmax_results_eager = CreateTensor(softmax_ptr, shape_x, xtype);
    return tensorflow::PyoOrThrow(EagerTensorFromHandle(softmax_results_eager));
  });

  m.def("scaled_softmax_backward",
        [](const pybind11::handle& dy, const pybind11::handle& y, const float scale_factor,
           const int64_t stream_id) {
    using namespace transformer_engine;
    std::vector<size_t> shape_dy = GetShape(dy);
    std::vector<size_t> shape_y = GetShape(y);
    CHECK_EQ(shape_dy.size(), 4) << "expected 4D tensor";
    CHECK_EQ(shape_y.size(), 4) << "expected 4D tensor";
    auto dytype = GetDataType(dy);
    auto ytype = GetDataType(y);
    CHECK(dytype == DType::kFloat16 || dytype == DType::kBFloat16)
        << "Only fp16 and bf16 are supported";
    CHECK(ytype == DType::kFloat16 || ytype == DType::kBFloat16)
        << "Only fp16 and bf16 are supported";
    auto dy_cu = MakeNVTETensor(GetDevicePtr(dy), shape_dy, dytype);
    auto y_cu = MakeNVTETensor(GetDevicePtr(y), shape_y, ytype);
    void* dx_ptr = AllocateSpace(shape_dy, dytype);
    auto dx_cu = MakeNVTETensor(dx_ptr, shape_dy, dytype);
    cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_id);
    nvte_scaled_softmax_backward(dy_cu.data(), y_cu.data(), dx_cu.data(), scale_factor, stream);
    auto dx_eager = CreateTensor(dx_ptr, shape_dy, dytype);
    return tensorflow::PyoOrThrow(EagerTensorFromHandle(dx_eager));
  });
}
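Rounding out the module, the fused softmax bindings enforce the constraints checked above: 4D fp16/bf16 score tensors [b, h, sq, sk] with sk <= 4096 and sq > 1, plus (for the masked variant) a 4D byte mask [b or 1, 1, sq, sk]. A sketch, assuming a GPU build; the mask convention (nonzero byte = masked position, Megatron-style) is assumed, not stated in this diff:

# Sketch (requires the built extension and a GPU).
import tensorflow as tf
import transformer_engine_tensorflow as tex

with tf.device("/GPU:0"):
    b, h, sq, sk = 2, 8, 128, 128
    scores = tf.random.uniform((b, h, sq, sk), dtype=tf.float16)
    mask = tf.zeros((b, 1, sq, sk), dtype=tf.int8)   # nothing masked (assumed)
    probs = tex.scaled_masked_softmax_forward(scores, mask, 1.0, 0)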
transformer_engine/tensorflow/csrc/get_stream_op.cpp  deleted  100644 → 0

/*************************************************************************
 * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_stream.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"

namespace tensorflow {

class GetStreamOp : public OpKernel {
 public:
  explicit GetStreamOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* ctx) override {
    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output("stream_id", {1}, &output));
    auto vec = output->vec<int64_t>();
    se::Stream* stream = ctx->op_device_context()->stream();
    auto gpu_stream = se::gpu::AsGpuStreamValue(stream);
    vec(0) = static_cast<int64_t>(reinterpret_cast<uintptr_t>(gpu_stream));
  }
};

REGISTER_OP("GetStream")
    .Output("stream_id: int64")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP_NO_GRADIENT("GetStream");

REGISTER_KERNEL_BUILDER(Name("GetStream").Device(DEVICE_GPU).HostMemory("stream_id"),
                        GetStreamOp);

}  // namespace tensorflow
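The op returns the raw cudaStream_t of TF's GPU device context as an int64 in host memory; the pybind wrappers in extensions.cu reinterpret that integer back into a stream. In the removed package this was wrapped by get_stream_id in module.py; a sketch of driving the raw op directly, with a hypothetical library path:

# Sketch (hypothetical .so path; the packaged wrapper was get_stream_id).
import tensorflow as tf

lib = tf.load_op_library("./get_stream_op.so")
with tf.device("/GPU:0"):
    stream_id = int(lib.get_stream()[0])   # raw cudaStream_t as int64
# stream_id can then be passed to the extensions.cu bindings above.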
transformer_engine/tensorflow/fp8.py
deleted
100644 → 0
View file @
2574a1ca
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""FP8 utilies for TransformerEngine"""
from
contextlib
import
contextmanager
from
typing
import
Generator
,
Optional
,
Dict
,
Any
import
tensorflow
as
tf
import
transformer_engine_tensorflow
as
tex
from
transformer_engine.common.recipe
import
DelayedScaling
,
Format
_FP8_ENABLED
=
False
_FP8_RECIPE
=
None
_FP8_DISTRIBUTED_GROUP
=
None
_IS_FIRST_FP8_MODULE
=
False
_FP8_AUTOCAST_COUNTER
=
0
_FP8_CURRENT_CONTEXT_ID
=
0
_FP8_AUTOCAST_DEPTH
=
0
_global_fp8_buffer
=
{}
_amax_forward_global_reduce_func
=
lambda
:
None
_buffer_delete_key_fwd
=
None
_buffer_delete_key_bwd
=
None
def
get_meta_tensor_key
(
forward
:
bool
=
True
)
->
str
:
"""Returns scaling key in `fp8_meta`."""
if
forward
:
return
"scaling_fwd"
return
"scaling_bwd"
def
get_autocast_key
(
forward
:
bool
=
True
)
->
str
:
"""Returns module position key in `fp8_meta`."""
if
forward
:
return
"autocast_id_fwd"
return
"autocast_id_bwd"
def
get_amax_buffer_key
(
fp8_meta
:
Dict
[
str
,
Any
],
forward
:
bool
=
True
)
->
str
:
"""Return a key in `_global_fp8_buffer` for the AMAX storage."""
if
forward
:
return
f
"FWD_AMAX_
{
fp8_meta
[
'autocast_id_fwd'
]
}
"
return
f
"BWD_AMAX_
{
fp8_meta
[
'autocast_id_bwd'
]
}
"
def
set_amax_buffer_key_deletion
(
fp8_meta
:
Dict
[
str
,
Any
],
forward
:
bool
=
True
)
->
None
:
"""Delete this amax key from global buffer during autocast end."""
if
get_autocast_key
(
forward
=
forward
)
not
in
fp8_meta
:
return
global
_buffer_delete_key_fwd
,
_buffer_delete_key_bwd
if
forward
:
_buffer_delete_key_fwd
=
get_amax_buffer_key
(
fp8_meta
,
forward
=
forward
)
else
:
_buffer_delete_key_bwd
=
get_amax_buffer_key
(
fp8_meta
,
forward
=
forward
)
def
get_default_fp8_recipe
():
"""FP8 recipe if not provided by user
Margin = 0, interval = 1, E4M3
"""
return
DelayedScaling
()
@contextmanager
def fp8_autocast(
    enabled: bool = False,
    fp8_recipe: Optional[DelayedScaling] = None,
) -> Generator[None, None, None]:
    """
    Context manager for FP8 usage.

    .. code-block:: python

        with fp8_autocast(enabled=True):
            out = model(inp)

    .. note::

        Support for FP8 in the Dense layer of Transformer Engine is currently
        limited to tensors with shapes where both dimensions are divisible by
        16. In terms of the input to the full Transformer network, this
        typically requires padding the sequence length to be a multiple of 16.

    Parameters
    ----------
    enabled: bool, default = `False`
        whether or not to enable fp8
    fp8_recipe: recipe.DelayedScaling, default = `None`
        recipe used for FP8 training.
    """
    global _FP8_ENABLED, _FP8_RECIPE, _FP8_DISTRIBUTED_GROUP, _FP8_AUTOCAST_DEPTH
    global _IS_FIRST_FP8_MODULE, _FP8_AUTOCAST_COUNTER
    global _global_fp8_buffer, _buffer_delete_key_fwd
    fp8_state = (_FP8_ENABLED, _FP8_RECIPE, _FP8_DISTRIBUTED_GROUP)
    try:
        _FP8_ENABLED = enabled
        _FP8_RECIPE = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe

        if _FP8_AUTOCAST_DEPTH == 0:
            _IS_FIRST_FP8_MODULE = True
            _FP8_AUTOCAST_COUNTER += 1
        _FP8_AUTOCAST_DEPTH += 1

        yield
    finally:
        _FP8_ENABLED, _FP8_RECIPE, _FP8_DISTRIBUTED_GROUP = fp8_state
        _IS_FIRST_FP8_MODULE = False
        _FP8_AUTOCAST_DEPTH -= 1

        if _FP8_AUTOCAST_DEPTH == 0:
            if callable(_amax_forward_global_reduce_func):
                _amax_forward_global_reduce_func()
            delete_key_from_amax_buffer(forward=True)
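For reference, a minimal sketch of driving this context manager with a non-default recipe; the layer choice, shapes, and recipe values below are illustrative assumptions, not part of this file:

# Sketch: fp8_autocast with a custom DelayedScaling recipe (illustrative).
import tensorflow as tf
import transformer_engine.tensorflow as te
from transformer_engine.common.recipe import DelayedScaling, Format

recipe = DelayedScaling(
    margin=0,                   # extra headroom subtracted in the exponent
    fp8_format=Format.HYBRID,   # E4M3 for fprop tensors, E5M2 for grads
    amax_history_len=16,        # window used by the "max" algorithm
    amax_compute_algo="max",
)

layer = te.Dense(64)            # both GEMM dims must be divisible by 16
x = tf.random.normal([32, 64])

with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    y = layer(x)                # runs the GEMM in FP8 under this recipe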
def get_fp8_context_id() -> int:
    """Returns an ID for the current FP8 context."""
    return _FP8_CURRENT_CONTEXT_ID


def set_fp8_context_id(ctx_id: int) -> None:
    """Sets the current FP8 context."""
    global _FP8_CURRENT_CONTEXT_ID
    _FP8_CURRENT_CONTEXT_ID = ctx_id


def new_fp8_context_id() -> int:
    """Returns global autocast counter as a proxy to be used
    as the autocast ID for FP8 modules.
    """
    return _FP8_AUTOCAST_COUNTER


def is_fp8_enabled():
    """Is FP8 enabled"""
    return _FP8_ENABLED


def is_first_fp8_module():
    """Returns `True` only the first time when called multiple
    times from within the same `fp8_autocast` context.
    """
    global _IS_FIRST_FP8_MODULE
    tmp = _IS_FIRST_FP8_MODULE
    _IS_FIRST_FP8_MODULE = False
    return tmp


def get_fp8_recipe():
    """Return the fp8 recipe"""
    return _FP8_RECIPE
def _default_sf_compute(amax, scale, fp8_max, margin):
    """Default function to convert amax to scaling factor."""
    sf = (fp8_max / amax) / (2**margin)
    sf = tf.where(amax > 0.0, sf, scale)
    sf = tf.where(tf.math.is_finite(amax), sf, scale)
    return sf
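A quick worked example of the conversion above, using the E4M3 format whose representable maximum is 448 (the concrete amax values are illustrative):

# Worked example of _default_sf_compute (illustrative values only).
import tensorflow as tf

fp8_max = 448.0                       # E4M3 representable maximum
amax = tf.constant([2.0, 0.0])        # observed per-tensor maxima
prev_scale = tf.constant([1.0, 1.0])  # fallback when amax is 0 or non-finite
margin = 0

sf = (fp8_max / amax) / (2.0**margin)      # -> [224.0, inf]
sf = tf.where(amax > 0.0, sf, prev_scale)  # -> [224.0, 1.0]
sf = tf.where(tf.math.is_finite(amax), sf, prev_scale)

# Scaling by sf maps the observed range onto the FP8 range:
# 2.0 * 224.0 == 448.0, exactly the E4M3 maximum.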
def _roll_and_zero_out(amax_history):
    """Update amax history and set next amax to zero."""
    amax_history = tf.roll(amax_history, -1, 0)
    zeros = tf.zeros(shape=amax_history[0].shape)
    updated = tf.tensor_scatter_nd_update(amax_history, [[0]], [zeros])
    return updated


@tf.function(jit_compile=True)
def _reduce_max_and_default_sf_compute(amax_history, scale, fp8_max, margin):
    """Get amax using max algorithm and compute scaling factor."""
    amax = tf.reduce_max(amax_history, axis=0)
    sf = _default_sf_compute(amax, scale, fp8_max, margin)
    updated = _roll_and_zero_out(amax_history)
    return updated, sf


@tf.function(jit_compile=True)
def _most_recent_and_default_sf_compute(amax_history, scale, fp8_max, margin):
    """Get amax using most-recent algorithm and compute scaling factor."""
    amax = amax_history[0]
    sf = _default_sf_compute(amax, scale, fp8_max, margin)
    updated = _roll_and_zero_out(amax_history)
    return updated, sf
def fused_amax_and_scale_update(
    amax_history: tf.Variable,
    scale: tf.Variable,
    scale_inv: tf.Variable,
    fp8_max: float,
    margin: int,
    amax_compute_algo: str,
):
    """Amax to scale conversion."""
    if amax_compute_algo == "max":
        updated, sf = _reduce_max_and_default_sf_compute(
            amax_history, scale, fp8_max, margin
        )
    else:
        assert amax_compute_algo == "most_recent"
        updated, sf = _most_recent_and_default_sf_compute(
            amax_history, scale, fp8_max, margin
        )
    amax_history.assign(updated)
    scale.assign(sf)
    scale_inv.assign(1.0 / sf)
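A self-contained sketch of one delayed-scaling step through the function above; the two-tensor setup, history length, and recorded maxima are illustrative assumptions:

# Minimal delayed-scaling update step (illustrative shapes/values).
import tensorflow as tf

history_len, num_tensors = 4, 2
amax_history = tf.Variable(tf.zeros([history_len, num_tensors]), trainable=False)
scale = tf.Variable(tf.ones([num_tensors]), trainable=False)
scale_inv = tf.Variable(tf.ones([num_tensors]), trainable=False)

# Pretend a kernel recorded these maxima for the current step.
amax_history[0].assign([3.5, 0.25])

fused_amax_and_scale_update(
    amax_history, scale, scale_inv,
    fp8_max=448.0, margin=0, amax_compute_algo="max",
)
# scale is now [128.0, 1792.0]; scale_inv holds the reciprocals, and row 0
# of amax_history has been rolled out and zeroed for the next step.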
def amax_and_scale_update(
    fp8_meta: Dict[str, Any],
    fwd_update: bool,
) -> None:
    """Updates fp8 amaxes/scales for fwd | bwd."""
    amax_compute = fp8_meta["recipe"].amax_compute_algo
    sf_compute = fp8_meta["recipe"].scaling_factor_compute_algo
    fp8_meta_tensor_key = "scaling_fwd" if fwd_update else "scaling_bwd"
    fp8_max_key = "fp8_max_fwd" if fwd_update else "fp8_max_bwd"

    if not callable(amax_compute) and sf_compute is None:
        fused_amax_and_scale_update(
            fp8_meta[fp8_meta_tensor_key]["amax_history"],
            fp8_meta[fp8_meta_tensor_key]["scale"],
            fp8_meta[fp8_meta_tensor_key]["scale_inv"],
            fp8_meta[fp8_max_key],
            fp8_meta["recipe"].margin,
            fp8_meta["recipe"].amax_compute_algo,
        )
    else:
        raise ValueError(
            "We only support the fp8 recipe with 'max' or 'most_recent' "
            "amax_compute_algo and default scaling_factor_compute_algo at this "
            "moment."
        )
def get_fp8_te_dtype(fp8_recipe: DelayedScaling, fprop_tensor: bool = True):
    """Get fp8 data type according to recipe and tensor"""
    if fp8_recipe.fp8_format == Format.E4M3 or (
        fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor
    ):
        return tex.DType.kFloat8E4M3
    return tex.DType.kFloat8E5M2
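The mapping is worth spelling out; a sketch that enumerates it (assumes the compiled `tex` extension is importable in this module's context):

# The recipe-to-dtype mapping above, enumerated (sketch).
from transformer_engine.common.recipe import DelayedScaling, Format

hybrid = DelayedScaling(fp8_format=Format.HYBRID)
e4m3 = DelayedScaling(fp8_format=Format.E4M3)

assert get_fp8_te_dtype(e4m3, fprop_tensor=True) == tex.DType.kFloat8E4M3
assert get_fp8_te_dtype(e4m3, fprop_tensor=False) == tex.DType.kFloat8E4M3
assert get_fp8_te_dtype(hybrid, fprop_tensor=True) == tex.DType.kFloat8E4M3
# Only HYBRID gradients fall through to the wider-exponent E5M2 format.
assert get_fp8_te_dtype(hybrid, fprop_tensor=False) == tex.DType.kFloat8E5M2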
def delete_key_from_amax_buffer(forward: bool = True) -> None:
    """Delete the key from global amax buffer."""
    global _global_fp8_buffer, _buffer_delete_key_fwd, _buffer_delete_key_bwd
    if forward:
        if (
            _buffer_delete_key_fwd is not None
            and _buffer_delete_key_fwd in _global_fp8_buffer
        ):
            del _global_fp8_buffer[_buffer_delete_key_fwd]
    else:
        if (
            _buffer_delete_key_bwd is not None
            and _buffer_delete_key_bwd in _global_fp8_buffer
        ):
            del _global_fp8_buffer[_buffer_delete_key_bwd]
transformer_engine/tensorflow/jit.py  deleted 100644 → 0  View file @ 2574a1ca
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""XLA functions and JIT utilities"""
from typing import Callable

import tensorflow as tf


@tf.function(jit_compile=True)
def _bgrad_dgelu_fused(grad_output, inp):
    """Bgrad-Dgelu fused"""
    x = inp
    tanh_out = tf.math.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    # sqrt(2/pi) * 3 * 0.044715 -> 0.1070322243
    ff = 0.5 * x * (
        (1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)
    ) + 0.5 * (1 + tanh_out)
    dgelu = ff * grad_output
    bgrad = tf.math.reduce_sum(dgelu, axis=0)
    return bgrad, dgelu


def bgrad_dgelu_fused(grad_output, inp):
    """Bgrad-Dgelu fused"""
    return _bgrad_dgelu_fused(grad_output, inp)
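The closed-form derivative above corresponds to the tanh approximation of GELU, so it can be cross-checked against TensorFlow's autodiff; a sketch of that check (the shapes and tolerance are assumptions):

# Cross-check the hand-written dgelu against autodiff (sketch).
import tensorflow as tf

x = tf.random.normal([8, 16])
dy = tf.ones_like(x)

with tf.GradientTape() as tape:
    tape.watch(x)
    y = tf.nn.gelu(x, approximate=True)  # same tanh approximation
auto_dx = tape.gradient(y, x)

bgrad, dgelu = bgrad_dgelu_fused(dy, x)
tf.debugging.assert_near(auto_dx, dgelu, atol=1e-4)
# bgrad is simply dgelu summed over the batch axis (the bias gradient).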
def bias_dropout_add(
    x: tf.Tensor,
    bias: tf.Variable,
    residual: tf.Tensor,
    prob: float,
    training: bool,
) -> tf.Tensor:
    """dropout(inp + bias) + residual"""
    # TODO(kaixih): Use stateless_dropout and specify the seed mainly for
    # debugging purpose. Should allow random seed.
    out = (
        tf.nn.experimental.stateless_dropout(x + bias, rate=prob, seed=[1, 0])
        if training
        else x + bias
    )
    out = residual + out
    return out


def get_bias_dropout_add(training: bool) -> Callable:
    """bias_dropout_add based on training or not"""

    def _bias_dropout_add(x, bias, residual, prob):
        return bias_dropout_add(x, bias, residual, prob, training)

    return _bias_dropout_add
@tf.function(jit_compile=True)
def bias_dropout_add_fused_train_(
    x: tf.Tensor,
    bias: tf.Variable,
    residual: tf.Tensor,
    prob: float,
) -> tf.Tensor:
    """Jit fused bias_dropout_add for training"""
    return bias_dropout_add(x, bias, residual, prob, True)


def bias_dropout_add_fused_train(
    x: tf.Tensor,
    bias: tf.Variable,
    residual: tf.Tensor,
    prob: float,
) -> tf.Tensor:
    """Jit fused bias_dropout_add for training"""
    return bias_dropout_add_fused_train_(x, bias, residual, prob)


@tf.function(jit_compile=True)
def bias_dropout_add_fused_inference_(
    x: tf.Tensor,
    bias: tf.Variable,
    residual: tf.Tensor,
    prob: float,
) -> tf.Tensor:
    """Jit fused bias_dropout_add for inference"""
    return bias_dropout_add(x, bias, residual, prob, False)


def bias_dropout_add_fused_inference(
    x: tf.Tensor,
    bias: tf.Variable,
    residual: tf.Tensor,
    prob: float,
) -> tf.Tensor:
    """Jit fused bias_dropout_add for inference"""
    return bias_dropout_add_fused_inference_(x, bias, residual, prob)
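A minimal sketch of how these helpers compose inside a residual block; the shapes and dropout rate are illustrative assumptions:

# Sketch: fused bias + dropout + residual-add in a transformer block.
import tensorflow as tf

hidden = 64
x = tf.random.normal([32, hidden])         # e.g. attention output
bias = tf.Variable(tf.zeros([hidden]))     # its projection bias
residual = tf.random.normal([32, hidden])  # the skip connection

# Training path: XLA-compiled dropout(x + bias) + residual.
out_train = bias_dropout_add_fused_train(x, bias, residual, 0.1)

# Inference path: same graph without dropout.
out_eval = bias_dropout_add_fused_inference(x, bias, residual, 0.1)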
transformer_engine/tensorflow/module.py  deleted 100644 → 0  View file @ 2574a1ca
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Top level Transformer Engine TensorFlow modules"""
from typing import Union, Callable

from keras import backend, layers, initializers
import tensorflow as tf
import transformer_engine_tensorflow as tex

from .constants import TE_DType
from .fp8 import (
    is_fp8_enabled,
    get_fp8_recipe,
    get_default_fp8_recipe,
    get_fp8_te_dtype,
    is_first_fp8_module,
    new_fp8_context_id,
    get_fp8_context_id,
    set_fp8_context_id,
    amax_and_scale_update,
    set_amax_buffer_key_deletion,
    get_meta_tensor_key,
)
from .jit import (
    bgrad_dgelu_fused,
)

stream_lib = tf.load_op_library(
    tf.compat.v1.resource_loader.get_path_to_datafile(
        tf.sysconfig.get_lib() + "/../lib_get_stream.so"
    )
)


def get_stream_id():
    """Get stream index for GPU tasks."""
    return stream_lib.get_stream().numpy()[0]
_2X_ACC_FPROP = False
_2X_ACC_DGRAD = True
_2X_ACC_WGRAD = True

_cublas_workspace = None


def get_workspace():
    """Returns workspace for cublas."""
    global _cublas_workspace
    if _cublas_workspace is None:
        _cublas_workspace = tf.zeros([33_554_432], dtype=tf.int8)  # 32 MiB
    return _cublas_workspace


def get_autocast_bias(dtype, bias_var, use_bias, use_fp8):
    """Get casted bias for fp8 gemm."""
    if not use_bias:
        return None
    # We need to pass the EagerTensor instead of Variable when calling into the
    # pybind functions. So, we use value() for the explicit conversion.
    bias = bias_var.value()
    if dtype == "float16":
        bias = tf.cast(bias, dtype)
    if use_fp8 and bias.dtype == tf.float32:
        bias = tf.cast(bias, dtype=tf.bfloat16)
    return bias
def get_init_method(user_input, default_init_method):
    """Get initializer method for variables."""
    if user_input is None:
        return default_init_method
    if callable(user_input):
        return user_input
    assert isinstance(user_input, str)
    return initializers.get(user_input)
def cast_to_fp8_wrapper(x, fp8_meta, amax_index, fwd, output_dtype, stream_id):
    """Wrapper to call the tex.cast_to_fp8."""
    scaling_key = get_meta_tensor_key(fwd)
    scale = fp8_meta[scaling_key]["scale"].value()
    amax = fp8_meta[scaling_key]["amax_history"].value()
    scale_inv = fp8_meta[scaling_key]["scale_inv"].value()
    x_fp8 = tex.cast_to_fp8(
        x, scale, output_dtype, amax, scale_inv, amax_index, stream_id
    )
    return x_fp8


def cast_from_fp8_wrapper(x, fp8_meta, amax_index, fwd, idtype, odtype, sid):
    """Wrapper to call the tex.cast_from_fp8."""
    scaling_key = "scaling_fwd" if fwd else "scaling_bwd"
    scale_inv = fp8_meta[scaling_key]["scale_inv"].value()
    x_fp8 = tex.cast_from_fp8(x, scale_inv, idtype, odtype, amax_index, sid)
    return x_fp8
def fp8_cast_transpose_fused_wrapper(x, fp8_meta, amax_index, fwd, output_dtype, sid):
    """Wrapper to call the tex.fp8_cast_transpose_fused."""
    scaling_key = get_meta_tensor_key(fwd)
    scale = fp8_meta[scaling_key]["scale"].value()
    amax = fp8_meta[scaling_key]["amax_history"].value()
    scale_inv = fp8_meta[scaling_key]["scale_inv"].value()
    x_fp8, x_t_fp8 = tex.fp8_cast_transpose_fused(
        x, scale, output_dtype, amax, scale_inv, amax_index, sid
    )
    return x_fp8, x_t_fp8


def fp8_cast_transpose_bgrad_fused_wrapper(
    x, fp8_meta, amax_index, fwd, output_dtype, sid
):
    """Wrapper to call the tex.fp8_cast_transpose_bgrad_fused."""
    scaling_key = get_meta_tensor_key(fwd)
    scale = fp8_meta[scaling_key]["scale"].value()
    amax = fp8_meta[scaling_key]["amax_history"].value()
    scale_inv = fp8_meta[scaling_key]["scale_inv"].value()
    grad_bias, grad_fp8, grad_t_fp8 = tex.fp8_cast_transpose_bgrad_fused(
        x, scale, output_dtype, amax, scale_inv, amax_index, sid
    )
    return grad_bias, grad_fp8, grad_t_fp8
def fp8_cast_transpose_bgrad_dgelu_fused_wrapper(
    dy, x, fp8_meta, amax_index, fwd, output_dtype, sid
):
    """Wrapper to call the tex.fp8_fused_cast_transpose_bgrad_dgelu."""
    scaling_key = get_meta_tensor_key(fwd)
    scale = fp8_meta[scaling_key]["scale"].value()
    amax = fp8_meta[scaling_key]["amax_history"].value()
    scale_inv = fp8_meta[scaling_key]["scale_inv"].value()
    dbias, dgelu_c, dgelu_t = tex.fp8_fused_cast_transpose_bgrad_dgelu(
        dy, x, scale, output_dtype, amax, scale_inv, amax_index, sid
    )
    return dbias, dgelu_c, dgelu_t


def fp8_gelu_wrapper(x, fp8_meta, amax_index, fwd, output_dtype, sid):
    """Wrapper to call the tex.te_gelu."""
    scaling_key = get_meta_tensor_key(fwd)
    scale = fp8_meta[scaling_key]["scale"].value()
    amax = fp8_meta[scaling_key]["amax_history"].value()
    scale_inv = fp8_meta[scaling_key]["scale_inv"].value()
    y_fp8 = tex.te_gelu(x, scale, output_dtype, amax, scale_inv, amax_index, sid)
    return y_fp8
def matmul_wrapper(
    inp,
    weight,
    mode,
    output_dtype,
    sid,
    use_bias=False,
    bias=None,
    grad=False,
    gelu=False,
    gelu_input=None,
):
    """Wrapper to call the tex.te_gemm for the non-fp8 gemm."""
    A = inp
    B = weight
    A_dtype, B_dtype = TE_DType[A.dtype], TE_DType[B.dtype]
    A_offset, B_offset = -1, -1
    if mode in ("fwd", "fc1_fwd", "fc2_fwd"):
        transA, transB = False, False
    elif mode in ("bwd_input", "fc1_bwd_input", "fc2_bwd_input"):
        transA, transB = False, True
    elif mode in ("bwd_weight", "fc1_bwd_weight", "fc2_bwd_weight"):
        transA, transB = True, False
    return tex.te_gemm(
        B,
        None,
        B_dtype,
        B_offset,
        A,
        None,
        A_dtype,
        A_offset,
        get_workspace(),
        use_bias,
        bias,
        gelu,
        gelu_input,
        transB,
        transA,
        grad,
        False,  # accumulate
        False,  # use_split_accumulator
        TE_DType[output_dtype],
        sid,
    )
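The three mode families above reduce to the standard dense-layer identities; a small sketch that spells them out with plain TensorFlow ops (purely illustrative, not the cuBLASLt path taken by te_gemm):

# The transA/transB choices above correspond to these identities (sketch).
import tensorflow as tf

x = tf.random.normal([8, 4])   # input  [batch, in_features]
w = tf.random.normal([4, 6])   # kernel [in_features, out_features]
dy = tf.random.normal([8, 6])  # upstream gradient

y = tf.matmul(x, w)                       # "fwd":        no transpose
dx = tf.matmul(dy, w, transpose_b=True)   # "bwd_input":  dy @ w^T
dw = tf.matmul(x, dy, transpose_a=True)   # "bwd_weight": x^T @ dy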
def fp8_matmul_wrapper(
    inp,
    weight,
    fp8_meta,
    mode,
    A_dtype,
    B_dtype,
    output_dtype,
    use_split_accumulate,
    sid,
    use_bias=False,
    bias=None,
):
    """Wrapper to call the tex.te_gemm for the fp8 gemm."""
    A = inp
    B = weight
    if mode in ("fwd", "fc1_fwd"):
        A_scale_inv = fp8_meta["scaling_fwd"]["scale_inv"].value()
        A_offset = tex.FP8FwdTensors.GEMM1_INPUT
        B_scale_inv = fp8_meta["scaling_fwd"]["scale_inv"].value()
        B_offset = tex.FP8FwdTensors.GEMM1_WEIGHT
    elif mode == "fc2_fwd":
        A_scale_inv = fp8_meta["scaling_fwd"]["scale_inv"].value()
        A_offset = tex.FP8FwdTensors.GEMM2_INPUT
        B_scale_inv = fp8_meta["scaling_fwd"]["scale_inv"].value()
        B_offset = tex.FP8FwdTensors.GEMM2_WEIGHT
    elif mode == "bwd_input":
        A_scale_inv = fp8_meta["scaling_bwd"]["scale_inv"].value()
        A_offset = tex.FP8BwdTensors.GRAD_OUTPUT1
        B_scale_inv = fp8_meta["scaling_fwd"]["scale_inv"].value()
        B_offset = tex.FP8FwdTensors.GEMM1_WEIGHT
    elif mode == "fc1_bwd_input":
        A_scale_inv = fp8_meta["scaling_bwd"]["scale_inv"].value()
        A_offset = tex.FP8BwdTensors.GRAD_OUTPUT2
        B_scale_inv = fp8_meta["scaling_fwd"]["scale_inv"].value()
        B_offset = tex.FP8FwdTensors.GEMM1_WEIGHT
    elif mode == "fc2_bwd_input":
        A_scale_inv = fp8_meta["scaling_bwd"]["scale_inv"].value()
        A_offset = tex.FP8BwdTensors.GRAD_OUTPUT1
        B_scale_inv = fp8_meta["scaling_fwd"]["scale_inv"].value()
        B_offset = tex.FP8FwdTensors.GEMM2_WEIGHT
    elif mode == "bwd_weight":
        A_scale_inv = fp8_meta["scaling_fwd"]["scale_inv"].value()
        A_offset = tex.FP8FwdTensors.GEMM1_INPUT
        B_scale_inv = fp8_meta["scaling_bwd"]["scale_inv"].value()
        B_offset = tex.FP8BwdTensors.GRAD_OUTPUT1
    elif mode == "fc2_bwd_weight":
        A_scale_inv = fp8_meta["scaling_fwd"]["scale_inv"].value()
        A_offset = tex.FP8FwdTensors.GEMM2_INPUT
        B_scale_inv = fp8_meta["scaling_bwd"]["scale_inv"].value()
        B_offset = tex.FP8BwdTensors.GRAD_OUTPUT1
    elif mode == "fc1_bwd_weight":
        A_scale_inv = fp8_meta["scaling_fwd"]["scale_inv"].value()
        A_offset = tex.FP8FwdTensors.GEMM1_INPUT
        B_scale_inv = fp8_meta["scaling_bwd"]["scale_inv"].value()
        B_offset = tex.FP8BwdTensors.GRAD_OUTPUT2
    return tex.te_gemm(
        B,
        B_scale_inv,
        B_dtype,
        B_offset,
        A,
        A_scale_inv,
        A_dtype,
        A_offset,
        get_workspace(),
        use_bias,
        bias,
        False,  # use_gelu
        None,   # gelu_input
        True,   # transa
        False,  # transb
        False,  # grad
        False,  # accumulate
        use_split_accumulate,
        TE_DType[output_dtype],
        sid,
    )
def layernorm_fwd_fp8_wrapper(
    x, ln_gamma, ln_beta, epsilon, fp8_meta, amax_index, output_dtype, sid
):
    """Wrapper to call the tex.layernorm_fwd_fp8."""
    scaling_key = "scaling_fwd"
    scale = fp8_meta[scaling_key]["scale"].value()
    amax = fp8_meta[scaling_key]["amax_history"].value()
    scale_inv = fp8_meta[scaling_key]["scale_inv"].value()
    ln_out, mu, rsigma = tex.layernorm_fwd_fp8(
        x,
        ln_gamma,
        ln_beta,
        epsilon,
        scale,
        output_dtype,
        amax,
        scale_inv,
        amax_index,
        sid,
    )
    return ln_out, mu, rsigma


# The DelayedScaling object is not supported in TF autograd. So, to avoid
# passing this object to the custom gradient function, we only extract the
# useful information.
def get_recipe_attrs(recipe):
    """Get attributes from the recipe."""
    fp8_dtype_fwd = get_fp8_te_dtype(recipe, fprop_tensor=True)
    fp8_dtype_bwd = get_fp8_te_dtype(recipe, fprop_tensor=False)
    override_linear_precision = recipe.override_linear_precision
    return (fp8_dtype_fwd, fp8_dtype_bwd, override_linear_precision)
# TransformerEngineBaseModule is a mixin class and its init function will pass
# through all the positional and keyword arguments to other subclasses. Make
# sure this class is inherited first.
class TransformerEngineBaseModule:
    """Base TE module."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # fp8 related
        self.fp8 = False
        self.fp8_meta = {}
        self.fp8_meta["recipe"] = get_default_fp8_recipe()
        self.fp8_meta_tensors_initialized = False
        self.fp8_weight_shapes = []
        self.stream_id = get_stream_id()
    def set_meta_tensor(self, fwd):
        """Init scales and amaxes for fwd | bwd."""
        fp8_meta_tensor_key = "scaling_fwd" if fwd else "scaling_bwd"
        num_fp8_tensors = (
            self.fp8_meta["num_gemms"] * 2 if fwd else self.fp8_meta["num_gemms"]
        )

        self.fp8_meta[fp8_meta_tensor_key] = {}
        self.fp8_meta[fp8_meta_tensor_key]["scale"] = tf.Variable(
            tf.ones((num_fp8_tensors), dtype=tf.float32), trainable=False
        )
        self.fp8_meta[fp8_meta_tensor_key]["scale_inv"] = tf.Variable(
            tf.ones((num_fp8_tensors), dtype=tf.float32), trainable=False
        )
        self.fp8_meta[fp8_meta_tensor_key]["amax_history"] = tf.Variable(
            tf.zeros(
                (self.fp8_meta["recipe"].amax_history_len, num_fp8_tensors),
                dtype=tf.float32,
            ),
            trainable=False,
        )
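To make the bookkeeping concrete, a sketch of what this allocates for a module with two GEMMs; the counts follow directly from the code above, the history length is whatever the recipe specifies:

# Shape sketch for set_meta_tensor with num_gemms == 2 (illustrative).
# Forward tracks one scale per GEMM input *and* per GEMM weight:
#   scaling_fwd["scale"]        -> shape (4,)
#   scaling_fwd["scale_inv"]    -> shape (4,)
#   scaling_fwd["amax_history"] -> shape (amax_history_len, 4)
# Backward tracks only the gradient tensors, one per GEMM:
#   scaling_bwd["scale"]        -> shape (2,)
#   scaling_bwd["amax_history"] -> shape (amax_history_len, 2)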
    def init_fp8_meta_tensors(self):
        """Init scales and amaxes."""
        # Checkpoint loaded
        if self.fp8_meta_tensors_initialized:
            return

        self.set_meta_tensor(True)
        self.set_meta_tensor(False)

    def fp8_init(self, num_gemms=1):
        """Initialize fp8 related metadata and tensors during fprop."""
        if not is_fp8_enabled():
            self.fp8 = False
            return

        # FP8 is already enabled and recipe is the same, don't do anything.
        if self.fp8 and get_fp8_recipe() == self.fp8_meta["recipe"]:
            return

        # Set FP8, recipe, and other FP8 metadata
        self.fp8 = True
        self.fp8_meta["recipe"] = get_fp8_recipe()
        self.fp8_meta["num_gemms"] = num_gemms

        # Set FP8_MAX per tensor according to recipe
        fp8_format_val = self.fp8_meta["recipe"].fp8_format.value
        self.fp8_meta["fp8_max_fwd"] = fp8_format_val.max_fwd
        self.fp8_meta["fp8_max_bwd"] = fp8_format_val.max_bwd

        # Allocate scales and amaxes
        self.init_fp8_meta_tensors()
    def pre_forward(self, training, num_gemms=1):
        """Checks and prep for FWD."""
        self.fp8_init(num_gemms=num_gemms)

        if self.fp8:
            if self.fp8_meta.get("update_amax_and_scale_fwd", False):
                # Previous iteration was grad_enabled
                amax_and_scale_update(self.fp8_meta, True)
                set_amax_buffer_key_deletion(self.fp8_meta, forward=True)

            if training:
                self.fp8_meta["first_module"] = is_first_fp8_module()

                if self.fp8_meta["first_module"]:
                    self.fp8_meta["autocast_id_fwd"] = new_fp8_context_id()
                    set_fp8_context_id(self.fp8_meta["autocast_id_fwd"])
                else:
                    self.fp8_meta["autocast_id_fwd"] = get_fp8_context_id()

                self.fp8_meta["update_amax_and_scale_fwd"] = True

                # Create an empty tensor as a placeholder for the backprop to
                # correctly know how many tensors to autograd.
                self.fp8_meta["autocast_id_bwd"] = -1
            else:
                self.fp8_meta["update_amax_and_scale_fwd"] = False

    def pre_backward(self):
        """Checks and prep for BWD."""
        # From previous iteration
        amax_and_scale_update(self.fp8_meta, False)
        set_amax_buffer_key_deletion(self.fp8_meta, forward=False)
class Dense(TransformerEngineBaseModule, layers.Layer):
    """
    Applies a linear transformation to the incoming data :math:`y = xW + b`

    On NVIDIA GPUs it is a drop-in replacement for `tf.keras.layers.Dense`.

    Parameters
    ----------
    units : int
        size of each output sample.
    use_bias : bool, default = `True`
        if set to `False`, the layer will not learn an additive bias.
    kernel_initializer: Callable, default = `None`
        used for initializing weights in the following way:
        `kernel_initializer(weight)`. When set to `None`, defaults to
        `tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.023)`.
    bias_initializer: Callable, default = `None`
        used for initializing biases in the following way:
        `bias_initializer(weight)`. When set to `None`, defaults to `zeros`.

    Parallelism parameters
    ----------------------
    skip_weight_param_allocation: bool, default = `False`
        if set to `True`, weight parameter is not allocated and must be passed
        as a keyword argument `kernel` during the forward pass.

    Optimization parameters
    -----------------------
    return_bias : bool, default = `False`
        when set to `True`, this module will not apply the additive bias
        itself, but instead return the bias value during the forward pass
        together with the output of the linear transformation :math:`y = xW`.
        This is useful when the bias addition can be fused to subsequent
        operations.
    """
    def __init__(
        self,
        units: int,
        use_bias: bool = True,
        return_bias: bool = False,
        kernel_initializer: Union[Callable, str, None] = None,
        bias_initializer: Union[Callable, str, None] = None,
        skip_weight_param_allocation: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.units = units
        self.use_bias = use_bias
        self.return_bias = return_bias
        self.kernel_initializer = get_init_method(
            kernel_initializer, initializers.RandomNormal(mean=0.0, stddev=0.023)
        )
        self.bias_initializer = get_init_method(
            bias_initializer, initializers.get("zeros")
        )
        self.skip_weight_param_allocation = skip_weight_param_allocation
    def build(self, input_shape):
        """One-time allocation of the variables."""
        input_shape = tf.TensorShape(input_shape)
        last_dim = tf.compat.dimension_value(input_shape[-1])
        if last_dim is None:
            raise ValueError(
                "The last dimension of the inputs to a Dense layer should be "
                f"defined. Found None. Full input shape received: {input_shape}"
            )

        self.kernel = None
        self.bias = None
        if not self.skip_weight_param_allocation:
            self.kernel = self.add_weight(
                name="kernel",
                shape=(last_dim, self.units),
                initializer=self.kernel_initializer,
                trainable=True,
            )

            if self.use_bias or self.return_bias:
                self.bias = self.add_weight(
                    name="bias",
                    shape=(self.units,),
                    initializer=self.bias_initializer,
                    trainable=True,
                )

        # fp8 related
        self.fp8_weight_shapes.append((last_dim, self.units))

        self.built = True
    def _get_training_value(self, training=None):
        if training is None:
            training = backend.learning_phase()
        if isinstance(training, int):
            training = bool(training)
        if not self.trainable:
            # When the layer is not trainable, it overrides the value passed
            # from model.
            training = False
        return training
    def non_fp8_matmul(
        self,
        inp: tf.Tensor,
        kernel_var: tf.Variable,
        bias_var: Union[tf.Variable, None] = None,
    ):
        """Prep fwd+bwd non-fp8 matmul."""

        @tf.custom_gradient
        def non_fp8_matmul_func(x):
            # We need to pass the EagerTensor instead of Variable when calling
            # into the pybind functions. So, we use value() for the explicit
            # conversion.
            kernel_val = kernel_var.value()
            bias = get_autocast_bias(
                self.compute_dtype,
                bias_var,
                self.use_bias,
                use_fp8=False,
            )

            output_dtype = self._compute_dtype_object
            outputs = matmul_wrapper(
                x,
                kernel_val,
                "fwd",
                output_dtype,
                self.stream_id,
                self.use_bias,
                bias,
            )

            def grad_fn(upstream, variables=None):
                grad_x = matmul_wrapper(
                    upstream,
                    kernel_val,
                    "bwd_input",
                    output_dtype,
                    self.stream_id,
                )
                grad_weight = matmul_wrapper(
                    x, upstream, "bwd_weight", output_dtype, self.stream_id
                )
                if self.use_bias:
                    grad_bias = tf.math.reduce_sum(upstream, axis=0)

                grad_inputs = [grad_x]
                grad_vars = []
                for v in variables:
                    if v.name.endswith("bias:0") and self.use_bias:
                        grad_vars.append(grad_bias)
                    elif v.name.endswith("kernel:0"):
                        grad_vars.append(grad_weight)
                return grad_inputs, grad_vars

            return outputs, grad_fn

        return non_fp8_matmul_func(inp)
    def fp8_matmul(
        self,
        inp: tf.Tensor,
        kernel_var: tf.Variable,
        bias_var: Union[tf.Variable, None] = None,
    ):
        """Prep fwd+bwd fp8 matmul."""
        fp8_meta = self.fp8_meta
        fp8_dtype_fwd, fp8_dtype_bwd, override_linear_precision = \
            get_recipe_attrs(fp8_meta["recipe"])

        @tf.custom_gradient
        def fp8_matmul_func(x):
            # We need to pass the EagerTensor instead of Variable when calling
            # into the pybind functions. So, we use value() for the explicit
            # conversion.
            kernel_val = kernel_var.value()
            bias = get_autocast_bias(
                self.compute_dtype,
                bias_var,
                self.use_bias,
                use_fp8=True,
            )

            if not override_linear_precision.wgrad:
                x_fp8, x_t_fp8 = fp8_cast_transpose_fused_wrapper(
                    x,
                    fp8_meta,
                    tex.FP8FwdTensors.GEMM1_INPUT,
                    True,
                    fp8_dtype_fwd,
                    self.stream_id,
                )
            else:
                x_fp8 = cast_to_fp8_wrapper(
                    x,
                    fp8_meta,
                    tex.FP8FwdTensors.GEMM1_INPUT,
                    True,
                    fp8_dtype_fwd,
                    self.stream_id,
                )

            weight_fp8, weight_t_fp8 = fp8_cast_transpose_fused_wrapper(
                kernel_val,
                fp8_meta,
                tex.FP8FwdTensors.GEMM1_WEIGHT,
                True,
                fp8_dtype_fwd,
                self.stream_id,
            )

            output_dtype = self._compute_dtype_object
            outputs = fp8_matmul_wrapper(
                x_fp8,
                weight_t_fp8,
                fp8_meta,
                "fwd",
                fp8_dtype_fwd,
                fp8_dtype_fwd,
                output_dtype,
                _2X_ACC_FPROP,
                self.stream_id,
                self.use_bias,
                bias,
            )

            def grad_fn(upstream, variables=None):
                self.pre_backward()
                if self.use_bias:
                    (
                        grad_bias,
                        grad_fp8,
                        grad_t_fp8,
                    ) = fp8_cast_transpose_bgrad_fused_wrapper(
                        upstream,
                        fp8_meta,
                        tex.FP8BwdTensors.GRAD_OUTPUT1,
                        False,
                        fp8_dtype_bwd,
                        self.stream_id,
                    )
                else:
                    if not override_linear_precision.wgrad:
                        grad_fp8, grad_t_fp8 = fp8_cast_transpose_fused_wrapper(
                            upstream,
                            fp8_meta,
                            tex.FP8BwdTensors.GRAD_OUTPUT1,
                            False,
                            fp8_dtype_bwd,
                            self.stream_id,
                        )
                    else:
                        grad_fp8 = cast_to_fp8_wrapper(
                            upstream,
                            fp8_meta,
                            tex.FP8BwdTensors.GRAD_OUTPUT1,
                            False,
                            fp8_dtype_bwd,
                            self.stream_id,
                        )

                grad_x = fp8_matmul_wrapper(
                    grad_fp8,
                    weight_fp8,
                    fp8_meta,
                    "bwd_input",
                    fp8_dtype_bwd,
                    fp8_dtype_fwd,
                    output_dtype,
                    _2X_ACC_DGRAD,
                    self.stream_id,
                )

                if not override_linear_precision.wgrad:
                    grad_weight = fp8_matmul_wrapper(
                        x_t_fp8,
                        grad_t_fp8,
                        fp8_meta,
                        "bwd_weight",
                        fp8_dtype_fwd,
                        fp8_dtype_bwd,
                        output_dtype,
                        _2X_ACC_WGRAD,
                        self.stream_id,
                    )
                else:
                    grad_weight = matmul_wrapper(
                        x, upstream, "bwd_weight", output_dtype, self.stream_id
                    )

                grad_inputs = [grad_x]
                grad_vars = []
                for v in variables:
                    if v.name.endswith("bias:0") and self.use_bias:
                        grad_vars.append(grad_bias)
                    elif v.name.endswith("kernel:0"):
                        grad_vars.append(grad_weight)
                return grad_inputs, grad_vars

            return outputs, grad_fn

        return fp8_matmul_func(inp)
    def call(
        self,
        inputs,
        kernel=None,
        bias=None,
        training=None,
    ):
        """
        Apply the linear transformation to the input.

        Parameters
        ----------
        inputs : tf.Tensor
            Input tensor.
        kernel : tf.Variable, default = None
            An optional weight tensor for the module. This argument is
            compulsory if module is initialized with
            `skip_weight_param_allocation=True`
        bias : tf.Variable, default = None
            An optional bias tensor for the module. This argument is compulsory
            if module is initialized with `skip_weight_param_allocation=True`
            and one of `use_bias` or `return_bias` is `True`.
        training : {True, False, None}, default = None
            Whether this is in the training context.
        """
        # self.pre_forward needs to be called outside the following branch,
        # since it will set the self.fp8 if the autocast is detected.
        training = self._get_training_value(training)
        self.pre_forward(training)

        kernel_var = (
            kernel if self.skip_weight_param_allocation else self.kernel
        )
        bias_var = bias if self.skip_weight_param_allocation else self.bias
        if kernel_var is None:
            raise ValueError("No valid kernel is provided")

        inputmat = tf.reshape(inputs, shape=(-1, inputs.shape[-1]))

        if self.fp8:
            outputmat = self.fp8_matmul(inputmat, kernel_var, bias_var)
        else:
            outputmat = self.non_fp8_matmul(inputmat, kernel_var, bias_var)

        outputs = tf.reshape(
            outputmat, shape=(-1, *inputs.shape[1:-1], outputmat.shape[-1])
        )

        if self.return_bias:
            return outputs, bias_var
        return outputs
    def get_config(self):
        """Returns the config of the layer."""
        config = super().get_config()
        config.update(
            {
                "units": self.units,
                "use_bias": self.use_bias,
                "kernel_initializer": initializers.serialize(
                    self.kernel_initializer
                ),
                "bias_initializer": initializers.serialize(
                    self.bias_initializer
                ),
                "skip_weight_param_allocation":
                    self.skip_weight_param_allocation,
            }
        )
        return config
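A minimal end-to-end sketch of this layer's two paths; the shapes, loss, and gradient handling below are illustrative assumptions:

# Sketch: one training step with Dense, with and without FP8 (illustrative).
import tensorflow as tf
import transformer_engine.tensorflow as te

layer = te.Dense(128, use_bias=True)
x = tf.random.normal([64, 128])  # both dims divisible by 16 for FP8

with tf.GradientTape() as tape:
    with te.fp8_autocast(enabled=True):  # drop this context to take the
        y = layer(x, training=True)      # plain non-fp8 matmul path
    loss = tf.reduce_sum(y)

grads = tape.gradient(loss, layer.trainable_variables)
# grad_fn above routes these to kernel:0 / bias:0 by variable-name suffix.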
class LayerNorm(layers.Layer):
    """
    Applies Layer Normalization over a mini-batch of inputs.

    Parameters
    ----------
    epsilon : float, default = 1e-3
        a value added to the denominator of layer normalization for numerical
        stability.
    gamma_initializer: Callable, default = `None`
        used for initializing LayerNorm gamma in the following way:
        `gamma_initializer(weight)`. When set to `None`, defaults to `ones`.
    beta_initializer: Callable, default = `None`
        used for initializing LayerNorm beta in the following way:
        `beta_initializer(weight)`. When set to `None`, defaults to `zeros`.
    """

    def __init__(
        self,
        epsilon=1e-3,
        gamma_initializer="ones",
        beta_initializer="zeros",
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.epsilon = epsilon
        self.beta_initializer = initializers.get(beta_initializer)
        self.gamma_initializer = initializers.get(gamma_initializer)
        self.stream = get_stream_id()
    def build(self, input_shape):
        """One-time allocation of the variables."""
        input_shape = tf.TensorShape(input_shape)
        last_dim = tf.compat.dimension_value(input_shape[-1])
        if last_dim is None:
            raise ValueError(
                "The last dimension of the inputs to a LayerNorm layer should "
                f"be defined. Found None. Full input shape received: "
                f"{input_shape}"
            )
        self.gamma = self.add_weight(
            name="gamma",
            shape=(last_dim,),
            initializer=self.gamma_initializer,
            trainable=True,
        )
        self.beta = self.add_weight(
            name="beta",
            shape=(last_dim,),
            initializer=self.beta_initializer,
            trainable=True,
        )

        self.built = True
    @tf.custom_gradient
    def layernorm(self, inp: tf.Tensor):
        """Prep fwd+bwd non-fp8 layernorm."""
        gamma = self.gamma.value()
        ln_out, mu, rsigma = tex.layernorm_fwd(
            inp, gamma, self.beta.value(), self.epsilon, self.stream
        )

        def grad_fn(upstream, variables=None):
            # pylint: disable=unused-argument
            dxmat, dgamma, dbeta = tex.layernorm_bwd(
                upstream, inp, mu, rsigma, gamma, self.stream
            )
            grad_inputs = [tf.reshape(dxmat, inp.shape)]
            grad_vars = [dgamma, dbeta]
            return grad_inputs, grad_vars

        return ln_out, grad_fn
    def call(self, inputs):
        """LayerNorm FWD"""
        inputmat = tf.reshape(inputs, shape=(-1, inputs.shape[-1]))
        outputmat = self.layernorm(inputmat)
        outputs = tf.reshape(outputmat, shape=inputs.shape)
        return outputs
    def get_config(self):
        """Returns the config of the layer."""
        config = super().get_config()
        config.update(
            {
                "epsilon": self.epsilon,
                "gamma_initializer": initializers.serialize(
                    self.gamma_initializer
                ),
                "beta_initializer": initializers.serialize(
                    self.beta_initializer
                ),
            }
        )
        return config
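For completeness, a small sketch of the layer in isolation; the shapes are illustrative:

# Sketch: the custom-gradient LayerNorm as a drop-in Keras layer.
import tensorflow as tf
import transformer_engine.tensorflow as te

ln = te.LayerNorm(epsilon=1e-3)
x = tf.random.normal([8, 16, 256])  # (batch, seq, hidden)

with tf.GradientTape() as tape:
    y = ln(x)                       # flattens to 2D, runs tex.layernorm_fwd
    loss = tf.reduce_sum(y)

dgamma, dbeta = tape.gradient(loss, [ln.gamma, ln.beta])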
class LayerNormDense(TransformerEngineBaseModule, layers.Layer):
    """
    Applies layer normalization followed by linear transformation to the
    incoming data.

    Parameters
    ----------
    units : int
        size of each output sample.
    epsilon : float, default = 1e-3
        a value added to the denominator of layer normalization for numerical
        stability.
    use_bias : bool, default = `True`
        if set to `False`, the layer will not learn an additive bias.
    gamma_initializer: Callable, default = `None`
        used for initializing LayerNorm gamma in the following way:
        `gamma_initializer(weight)`. When set to `None`, defaults to `ones`.
    beta_initializer: Callable, default = `None`
        used for initializing LayerNorm beta in the following way:
        `beta_initializer(weight)`. When set to `None`, defaults to `zeros`.
    kernel_initializer : Callable, default = `None`
        used for initializing GEMM weights in the following way:
        `kernel_initializer(weight)`. When set to `None`, defaults to
        `tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.023)`.
    bias_initializer : Callable, default = `None`
        used for initializing GEMM bias in the following way:
        `bias_initializer(weight)`. When set to `None`, defaults to `zeros`.
    return_layernorm_output : bool, default = `False`
        if set to `True`, output of layernorm is returned from the forward
        together with the output of the linear transformation.
        Example use case: residual connection for transformer module is taken
        post layernorm.

    Parallelism parameters
    ----------------------
    skip_weight_param_allocation: bool, default = `False`
        if set to `True`, weight parameter is not allocated and must be passed
        as a keyword argument `kernel` during the forward pass.

    Optimization parameters
    -----------------------
    return_bias : bool, default = `False`
        when set to `True`, this module will not apply the additive bias
        itself, but instead return the bias value during the forward pass
        together with the output of the linear transformation :math:`y = xW`.
        This is useful when the bias addition can be fused to subsequent
        operations.
    """
    def __init__(
        self,
        units,
        epsilon=1e-3,
        gamma_initializer: Union[Callable, str, None] = None,
        beta_initializer: Union[Callable, str, None] = None,
        return_layernorm_output=False,
        use_bias=True,
        return_bias=False,
        kernel_initializer: Union[Callable, str, None] = None,
        bias_initializer: Union[Callable, str, None] = None,
        skip_weight_param_allocation=False,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.units = units
        self.epsilon = epsilon
        self.gamma_initializer = get_init_method(
            gamma_initializer, initializers.get("ones")
        )
        self.beta_initializer = get_init_method(
            beta_initializer, initializers.get("zeros")
        )
        self.return_layernorm_output = return_layernorm_output
        self.use_bias = use_bias
        self.return_bias = return_bias
        self.kernel_initializer = get_init_method(
            kernel_initializer, initializers.RandomNormal(mean=0.0, stddev=0.023)
        )
        self.bias_initializer = get_init_method(
            bias_initializer, initializers.get("zeros")
        )
        self.skip_weight_param_allocation = skip_weight_param_allocation
    def build(self, input_shape):
        """One-time allocation of the variables."""
        input_shape = tf.TensorShape(input_shape)
        last_dim = tf.compat.dimension_value(input_shape[-1])
        if last_dim is None:
            raise ValueError(
                "The last dimension of the inputs to a Dense layer should be "
                f"defined. Found None. Full input shape received: {input_shape}"
            )

        self.gamma = self.add_weight(
            name="gamma",
            shape=(last_dim,),
            initializer=self.gamma_initializer,
            trainable=True,
        )
        self.beta = self.add_weight(
            name="beta",
            shape=(last_dim,),
            initializer=self.beta_initializer,
            trainable=True,
        )

        self.kernel = None
        self.bias = None
        if not self.skip_weight_param_allocation:
            self.kernel = self.add_weight(
                name="kernel",
                shape=(last_dim, self.units),
                initializer=self.kernel_initializer,
                trainable=True,
            )

            if self.use_bias or self.return_bias:
                self.bias = self.add_weight(
                    name="bias",
                    shape=(self.units,),
                    initializer=self.bias_initializer,
                    trainable=True,
                )

        # fp8 related
        self.fp8_weight_shapes.append((last_dim, self.units))

        self.built = True
    def _get_training_value(self, training=None):
        if training is None:
            training = backend.learning_phase()
        if isinstance(training, int):
            training = bool(training)
        if not self.trainable:
            # When the layer is not trainable, it overrides the value passed
            # from model.
            training = False
        return training
    def non_fp8_layernorm_matmul(
        self,
        inp: tf.Tensor,
        gamma_var: tf.Variable,
        beta_var: tf.Variable,
        kernel_var: tf.Variable,
        bias_var: Union[tf.Variable, None] = None,
    ):
        """Prep fwd+bwd non-fp8 layernorm followed by matmul."""

        @tf.custom_gradient
        def non_fp8_layernorm_matmul_func(x):
            # We need to pass the EagerTensor instead of Variable when calling
            # into the pybind functions. So, we use value() for the explicit
            # conversion.
            kernel_val = kernel_var.value()
            gamma_val = gamma_var.value()
            beta_val = beta_var.value()
            ln_out, mu, rsigma = tex.layernorm_fwd(
                x, gamma_val, beta_val, self.epsilon, self.stream_id
            )

            bias = get_autocast_bias(
                self.compute_dtype,
                bias_var,
                self.use_bias,
                use_fp8=False,
            )

            output_dtype = self._compute_dtype_object
            outputs = matmul_wrapper(
                ln_out,
                kernel_val,
                "fwd",
                output_dtype,
                self.stream_id,
                self.use_bias,
                bias,
            )

            def grad_fn(*upstream, variables=None):
                grad_x = matmul_wrapper(
                    upstream[0],
                    kernel_val,
                    "bwd_input",
                    output_dtype,
                    self.stream_id,
                )
                grad_weight = matmul_wrapper(
                    ln_out,
                    upstream[0],
                    "bwd_weight",
                    output_dtype,
                    self.stream_id,
                )
                if self.use_bias:
                    grad_bias = tf.math.reduce_sum(upstream[0], axis=0)

                if self.return_layernorm_output:
                    assert len(upstream) == 2
                    grad_x = grad_x + upstream[1]
                dxmat, dgamma, dbeta = tex.layernorm_bwd(
                    grad_x, x, mu, rsigma, gamma_val, self.stream_id
                )

                grad_inputs = [dxmat]
                grad_vars = []
                for v in variables:
                    if v.name.endswith("gamma:0"):
                        grad_vars.append(dgamma)
                    elif v.name.endswith("bias:0") and self.use_bias:
                        grad_vars.append(grad_bias)
                    elif v.name.endswith("kernel:0"):
                        grad_vars.append(grad_weight)
                    elif v.name.endswith("beta:0"):
                        grad_vars.append(dbeta)
                return grad_inputs, grad_vars

            if self.return_layernorm_output:
                return (outputs, ln_out), grad_fn
            return outputs, grad_fn

        return non_fp8_layernorm_matmul_func(inp)
    def fp8_layernorm_matmul(
        self,
        inp: tf.Tensor,
        gamma_var: tf.Variable,
        beta_var: tf.Variable,
        kernel_var: tf.Variable,
        bias_var: Union[tf.Variable, None] = None,
    ):
        """Prep fwd+bwd fp8 layernorm followed by matmul."""
        fp8_meta = self.fp8_meta
        fp8_dtype_fwd, fp8_dtype_bwd, override_linear_precision = \
            get_recipe_attrs(fp8_meta["recipe"])

        @tf.custom_gradient
        def fp8_layernorm_matmul_func(x):
            # We need to pass the EagerTensor instead of Variable when calling
            # into the pybind functions. So, we use value() for the explicit
            # conversion.
            kernel_val = kernel_var.value()
            gamma_val = gamma_var.value()
            beta_val = beta_var.value()
            if not self.return_layernorm_output:
                ln_out, mu, rsigma = layernorm_fwd_fp8_wrapper(
                    x,
                    gamma_val,
                    beta_val,
                    self.epsilon,
                    fp8_meta,
                    tex.FP8FwdTensors.GEMM1_INPUT,
                    fp8_dtype_fwd,
                    self.stream_id,
                )
            else:
                ln_out_return, mu, rsigma = tex.layernorm_fwd(
                    x, gamma_val, beta_val, self.epsilon, self.stream_id
                )
                ln_out = cast_to_fp8_wrapper(
                    ln_out_return,
                    fp8_meta,
                    tex.FP8FwdTensors.GEMM1_INPUT,
                    True,
                    fp8_dtype_fwd,
                    self.stream_id,
                )

            bias = get_autocast_bias(
                self.compute_dtype,
                bias_var,
                self.use_bias,
                use_fp8=True,
            )

            weight_fp8, weight_t_fp8 = fp8_cast_transpose_fused_wrapper(
                kernel_val,
                fp8_meta,
                tex.FP8FwdTensors.GEMM1_WEIGHT,
                True,
                fp8_dtype_fwd,
                self.stream_id,
            )

            output_dtype = self._compute_dtype_object
            outputs = fp8_matmul_wrapper(
                ln_out,
                weight_t_fp8,
                fp8_meta,
                "fwd",
                fp8_dtype_fwd,
                fp8_dtype_fwd,
                output_dtype,
                _2X_ACC_FPROP,
                self.stream_id,
                self.use_bias,
                bias,
            )

            def grad_fn(*upstream, variables=None):
                self.pre_backward()
                if self.use_bias:
                    (
                        grad_bias,
                        grad_fp8,
                        grad_t_fp8,
                    ) = fp8_cast_transpose_bgrad_fused_wrapper(
                        upstream[0],
                        fp8_meta,
                        tex.FP8BwdTensors.GRAD_OUTPUT1,
                        False,
                        fp8_dtype_bwd,
                        self.stream_id,
                    )
                else:
                    if not override_linear_precision.wgrad:
                        grad_fp8, grad_t_fp8 = fp8_cast_transpose_fused_wrapper(
                            upstream[0],
                            fp8_meta,
                            tex.FP8BwdTensors.GRAD_OUTPUT1,
                            False,
                            fp8_dtype_bwd,
                            self.stream_id,
                        )
                    else:
                        grad_fp8 = cast_to_fp8_wrapper(
                            upstream[0],
                            fp8_meta,
                            tex.FP8BwdTensors.GRAD_OUTPUT1,
                            False,
                            fp8_dtype_bwd,
                            self.stream_id,
                        )

                grad_x = fp8_matmul_wrapper(
                    grad_fp8,
                    weight_fp8,
                    fp8_meta,
                    "bwd_input",
                    fp8_dtype_bwd,
                    fp8_dtype_fwd,
                    output_dtype,
                    _2X_ACC_DGRAD,
                    self.stream_id,
                )

                if not override_linear_precision.wgrad:
                    ln_out_t = tex.fp8_transpose(
                        ln_out, fp8_dtype_fwd, self.stream_id
                    )
                    grad_weight = fp8_matmul_wrapper(
                        ln_out_t,
                        grad_t_fp8,
                        fp8_meta,
                        "bwd_weight",
                        fp8_dtype_fwd,
                        fp8_dtype_bwd,
                        output_dtype,
                        _2X_ACC_WGRAD,
                        self.stream_id,
                    )
                else:
                    ln_out_c = cast_from_fp8_wrapper(
                        ln_out,
                        fp8_meta,
                        tex.FP8FwdTensors.GEMM1_INPUT,
                        True,
                        fp8_dtype_fwd,
                        TE_DType[x.dtype],
                        self.stream_id,
                    )
                    grad_weight = matmul_wrapper(
                        ln_out_c,
                        upstream[0],
                        "bwd_weight",
                        output_dtype,
                        self.stream_id,
                    )

                if self.return_layernorm_output:
                    assert len(upstream) == 2
                    grad_x = grad_x + upstream[1]
                dxmat, dgamma, dbeta = tex.layernorm_bwd(
                    grad_x, x, mu, rsigma, gamma_val, self.stream_id
                )

                grad_inputs = [dxmat]
                grad_vars = []
                for v in variables:
                    if v.name.endswith("gamma:0"):
                        grad_vars.append(dgamma)
                    elif v.name.endswith("bias:0") and self.use_bias:
                        grad_vars.append(grad_bias)
                    elif v.name.endswith("kernel:0"):
                        grad_vars.append(grad_weight)
                    elif v.name.endswith("beta:0"):
                        grad_vars.append(dbeta)
                return grad_inputs, grad_vars

            if self.return_layernorm_output:
                return (outputs, ln_out_return), grad_fn
            return outputs, grad_fn

        return fp8_layernorm_matmul_func(inp)
    def call(
        self,
        inputs,
        kernel=None,
        bias=None,
        training=None,
    ):
        """
        Apply layer normalization to the input followed by a linear
        transformation.

        Parameters
        ----------
        inputs : tf.Tensor
            Input tensor.
        kernel : tf.Variable, default = None
            An optional weight tensor for the module. This argument is
            compulsory if module is initialized with
            `skip_weight_param_allocation=True`
        bias : tf.Variable, default = None
            An optional bias tensor for the module. This argument is compulsory
            if module is initialized with `skip_weight_param_allocation=True`
            and one of `use_bias` or `return_bias` is `True`.
        training : {True, False, None}, default = None
            Whether this is in the training context.
        """
        # self.pre_forward needs to be called outside the following branch,
        # since it has side effects to set the self.fp8 if the autocast is
        # detected.
        training = self._get_training_value(training)
        self.pre_forward(training)

        kernel_var = (
            kernel if self.skip_weight_param_allocation else self.kernel
        )
        bias_var = bias if self.skip_weight_param_allocation else self.bias
        if kernel_var is None:
            raise ValueError("No valid kernel is provided")

        inputmat = tf.reshape(inputs, shape=(-1, inputs.shape[-1]))

        if self.fp8:
            outputs = self.fp8_layernorm_matmul(
                inputmat, self.gamma, self.beta, kernel_var, bias_var
            )
        else:
            outputs = self.non_fp8_layernorm_matmul(
                inputmat, self.gamma, self.beta, kernel_var, bias_var
            )

        if self.return_layernorm_output:
            outputmat, ln_outputmat = outputs
        else:
            outputmat = outputs

        outputs = tf.reshape(
            outputmat, shape=(-1, *inputs.shape[1:-1], outputmat.shape[-1])
        )

        if self.return_bias:
            if self.return_layernorm_output:
                ln_outputs = tf.reshape(ln_outputmat, shape=inputs.shape)
                return (outputs, bias_var, ln_outputs)
            return outputs, bias_var
        if self.return_layernorm_output:
            ln_outputs = tf.reshape(ln_outputmat, shape=inputs.shape)
            return (outputs, ln_outputs)
        return outputs
    def get_config(self):
        """Returns the config of the layer."""
        config = super().get_config()
        config.update(
            {
                "units": self.units,
                "epsilon": self.epsilon,
                "gamma_initializer": initializers.serialize(
                    self.gamma_initializer
                ),
                "beta_initializer": initializers.serialize(
                    self.beta_initializer
                ),
                "return_layernorm_output": self.return_layernorm_output,
                "use_bias": self.use_bias,
                "kernel_initializer": initializers.serialize(
                    self.kernel_initializer
                ),
                "bias_initializer": initializers.serialize(
                    self.bias_initializer
                ),
                "skip_weight_param_allocation":
                    self.skip_weight_param_allocation,
            }
        )
        return config
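A small sketch of the fused residual pattern this layer enables; the QKV sizing is an illustrative assumption:

# Sketch: taking the transformer residual from the post-layernorm output.
import tensorflow as tf
import transformer_engine.tensorflow as te

qkv_proj = te.LayerNormDense(3 * 256, return_layernorm_output=True)
x = tf.random.normal([8, 16, 256])

qkv, ln_out = qkv_proj(x, training=True)
# qkv feeds attention; ln_out is the post-LN activation, available for a
# residual connection without recomputing the normalization.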
class LayerNormMLP(TransformerEngineBaseModule, layers.Layer):
    """
    Applies layer normalization on the input followed by the MLP module,
    consisting of 2 successive linear transformations, separated by the GeLU
    activation.

    Parameters
    ----------
    units : int
        size of each input sample.
    ffn_units : int
        intermediate size to which input samples are projected.
    epsilon : float, default = 1e-3
        a value added to the denominator of layer normalization for numerical
        stability.
    gamma_initializer: Callable, default = `None`
        used for initializing LayerNorm gamma in the following way:
        `gamma_initializer(weight)`. When set to `None`, defaults to `ones`.
    beta_initializer: Callable, default = `None`
        used for initializing LayerNorm beta in the following way:
        `beta_initializer(weight)`. When set to `None`, defaults to `zeros`.
    use_bias : bool, default = `True`
        if set to `False`, the FC2 layer will not learn an additive bias.
    kernel_initializer: Callable, default = `None`
        used for initializing FC1 weights in the following way:
        `kernel_initializer(weight)`. When set to `None`, defaults to
        `tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.023)`.
    ffn_kernel_initializer: Callable, default = `None`
        used for initializing FC2 weights in the following way:
        `ffn_kernel_initializer(weight)`. When set to `None`, defaults to
        `tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.023)`.
    return_layernorm_output : bool, default = `False`
        if set to `True`, output of layernorm is returned from the forward
        together with the output of the linear transformation.
        Example use case: residual connection for transformer module is taken
        post layernorm.
    bias_initializer: Callable, default = `None`
        used for initializing FC1 and FC2 bias in the following way:
        `bias_initializer(weight)`. When set to `None`, defaults to `zeros`.

    Optimization parameters
    -----------------------
    return_bias : bool, default = `False`
        when set to `True`, this module will not apply the additive bias
        itself, but instead return the bias value during the forward pass
        together with the output of the linear transformation :math:`y = xW`.
        This is useful when the bias addition can be fused to subsequent
        operations.
    """
    def __init__(
        self,
        units: int,
        ffn_units: int,
        epsilon: float = 1e-3,
        gamma_initializer: Union[Callable, str, None] = None,
        beta_initializer: Union[Callable, str, None] = None,
        return_layernorm_output: bool = False,
        use_bias: bool = True,
        return_bias: bool = False,
        kernel_initializer: Union[Callable, str, None] = None,
        ffn_kernel_initializer: Union[Callable, str, None] = None,
        bias_initializer: Union[Callable, str, None] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.fc1_units = units
        self.fc2_units = ffn_units
        self.epsilon = epsilon
        self.gamma_initializer = get_init_method(
            gamma_initializer, initializers.get("ones")
        )
        self.beta_initializer = get_init_method(
            beta_initializer, initializers.get("zeros")
        )
        self.return_layernorm_output = return_layernorm_output
        self.use_bias = use_bias
        self.return_bias = return_bias
        self.kernel1_initializer = get_init_method(
            kernel_initializer, initializers.RandomNormal(mean=0.0, stddev=0.023)
        )
        self.kernel2_initializer = get_init_method(
            ffn_kernel_initializer,
            initializers.RandomNormal(mean=0.0, stddev=0.023),
        )
        self.bias_initializer = get_init_method(
            bias_initializer, initializers.get("zeros")
        )
    def build(self, input_shape):
        """One-time allocation of the variables."""
        input_shape = tf.TensorShape(input_shape)
        last_dim = tf.compat.dimension_value(input_shape[-1])
        if last_dim is None:
            raise ValueError(
                "The last dimension of the inputs to a Dense layer should be "
                f"defined. Found None. Full input shape received: {input_shape}"
            )

        self.gamma = self.add_weight(
            name="gamma",
            shape=(last_dim,),
            initializer=self.gamma_initializer,
            trainable=True,
        )
        self.beta = self.add_weight(
            name="beta",
            shape=(last_dim,),
            initializer=self.beta_initializer,
            trainable=True,
        )
        self.fc1_kernel = self.add_weight(
            name="fc1_kernel",
            shape=(last_dim, self.fc1_units),
            initializer=self.kernel1_initializer,
            trainable=True,
        )
        self.fc1_bias = self.add_weight(
            name="fc1_bias",
            shape=(self.fc1_units,),
            initializer=self.bias_initializer,
            trainable=True,
        )
        # fp8 related
        self.fp8_weight_shapes.append((last_dim, self.fc1_units))

        self.fc2_kernel = self.add_weight(
            name="fc2_kernel",
            shape=(self.fc1_units, self.fc2_units),
            initializer=self.kernel2_initializer,
            trainable=True,
        )
        self.fc2_bias = None
        if self.use_bias or self.return_bias:
            self.fc2_bias = self.add_weight(
                name="fc2_bias",
                shape=(self.fc2_units,),
                initializer=self.bias_initializer,
                trainable=True,
            )
        # fp8 related
        self.fp8_weight_shapes.append((self.fc1_units, self.fc2_units))

        self.built = True
    def _get_training_value(self, training=None):
        if training is None:
            training = backend.learning_phase()
        if isinstance(training, int):
            training = bool(training)
        if not self.trainable:
            # When the layer is not trainable, it overrides the value passed
            # from model.
            training = False
        return training
    def non_fp8_layernorm_mlp(
        self,
        inp: tf.Tensor,
        gamma_var: tf.Variable,
        beta_var: tf.Variable,
        fc1_kernel_var: tf.Variable,
        fc1_bias_var: tf.Variable,
        fc2_kernel_var: tf.Variable,
        fc2_bias_var: Union[tf.Variable, None] = None,
    ):
        """Prep fwd+bwd non-fp8 layernorm followed by mlp."""

        @tf.custom_gradient
        def non_fp8_layernorm_mlp_func(x):
            # We need to pass the EagerTensor instead of Variable when calling
            # into the pybind functions. So, we use value() for the explicit
            # conversion.
            fc1_kernel_val = fc1_kernel_var.value()
            fc2_kernel_val = fc2_kernel_var.value()
            gamma_val = gamma_var.value()
            beta_val = beta_var.value()
            ln_out, mu, rsigma = tex.layernorm_fwd(
                x, gamma_val, beta_val, self.epsilon, self.stream_id
            )

            fc1_bias = get_autocast_bias(
                self.compute_dtype,
                fc1_bias_var,
                use_bias=True,
                use_fp8=False,
            )
            fc2_bias = get_autocast_bias(
                self.compute_dtype,
                fc2_bias_var,
                self.use_bias,
                use_fp8=False,
            )

            output_dtype = self._compute_dtype_object
            # TODO(kaixih): Ideally, we should set gelu=True to fuse the gelu
            # in cuBlasLt calls. However, it seems it is slower than the
            # unfused version. Fix this when cuBlasLt improves the issue.
            fc1_out = matmul_wrapper(
                ln_out,
                fc1_kernel_val,
                "fc1_fwd",
                output_dtype,
                self.stream_id,
                use_bias=True,
                bias=fc1_bias,
            )
            gelu_out = tex.te_gelu(
                fc1_out,
                None,
                TE_DType[output_dtype],
                None,
                None,
                0,
                self.stream_id,
            )
            fc2_out = matmul_wrapper(
                gelu_out,
                fc2_kernel_val,
                "fc2_fwd",
                output_dtype,
                self.stream_id,
                use_bias=self.use_bias,
                bias=fc2_bias,
            )

            def grad_fn(*upstream, variables=None):
                fc2_dgrad = matmul_wrapper(
                    upstream[0],
                    fc2_kernel_val,
                    "fc2_bwd_input",
                    output_dtype,
                    self.stream_id,
                    grad=True,
                    gelu=True,
                    gelu_input=fc1_out,
                )
                fc2_wgrad = matmul_wrapper(
                    gelu_out,
                    upstream[0],
                    "bwd_weight",
                    output_dtype,
                    self.stream_id,
                )
                if self.use_bias:
                    fc2_bias_grad = tf.math.reduce_sum(upstream[0], axis=0)

                dgelu = fc2_dgrad

                fc1_dgrad = matmul_wrapper(
                    dgelu,
                    fc1_kernel_val,
                    "fc1_bwd_input",
                    output_dtype,
                    self.stream_id,
                )
                fc1_wgrad = matmul_wrapper(
                    ln_out, dgelu, "bwd_weight", output_dtype, self.stream_id
                )
                fc1_bias_grad = tf.math.reduce_sum(dgelu, axis=0)

                d_ln_out = fc1_dgrad
                if self.return_layernorm_output:
                    assert len(upstream) == 2
                    d_ln_out = d_ln_out + upstream[1]
                dxmat, dgamma, dbeta = tex.layernorm_bwd(
                    d_ln_out, x, mu, rsigma, gamma_val, self.stream_id
                )

                grad_inputs = [dxmat]
                grad_vars = []
                for v in variables:
                    if v.name.endswith("gamma:0"):
                        grad_vars.append(dgamma)
                    elif v.name.endswith("fc1_kernel:0"):
                        grad_vars.append(fc1_wgrad)
                    elif v.name.endswith("fc1_bias:0"):
                        grad_vars.append(fc1_bias_grad)
                    elif v.name.endswith("fc2_kernel:0"):
                        grad_vars.append(fc2_wgrad)
                    elif v.name.endswith("fc2_bias:0") and self.use_bias:
                        grad_vars.append(fc2_bias_grad)
                    elif v.name.endswith("beta:0"):
                        grad_vars.append(dbeta)
                return grad_inputs, grad_vars

            if self.return_layernorm_output:
                return (fc2_out, ln_out), grad_fn
            return fc2_out, grad_fn

        return non_fp8_layernorm_mlp_func(inp)
    def fp8_layernorm_mlp(
        self,
        inp: tf.Tensor,
        gamma_var: tf.Variable,
        beta_var: tf.Variable,
        fc1_kernel_var: tf.Variable,
        fc1_bias_var: tf.Variable,
        fc2_kernel_var: tf.Variable,
        fc2_bias_var: Union[tf.Variable, None] = None,
    ):
        """Prep fwd+bwd fp8 layernorm followed by mlp."""
        fp8_meta = self.fp8_meta
        fp8_dtype_fwd, fp8_dtype_bwd, override_linear_precision = \
            get_recipe_attrs(fp8_meta["recipe"])

        @tf.custom_gradient
        def fp8_layernorm_mlp_func(x):
            # We need to pass the EagerTensor instead of Variable when calling
            # into the pybind functions. So, we use value() for the explicit
            # conversion.
            fc1_kernel_val = fc1_kernel_var.value()
            fc2_kernel_val = fc2_kernel_var.value()
            gamma_val = gamma_var.value()
            beta_val = beta_var.value()

            if not self.return_layernorm_output:
                ln_out, mu, rsigma = layernorm_fwd_fp8_wrapper(
                    x,
                    gamma_val,
                    beta_val,
                    self.epsilon,
                    fp8_meta,
                    tex.FP8FwdTensors.GEMM1_INPUT,
                    fp8_dtype_fwd,
                    self.stream_id,
                )
            else:
                ln_out_return, mu, rsigma = tex.layernorm_fwd(
                    x, gamma_val, beta_val, self.epsilon, self.stream_id)
                ln_out = cast_to_fp8_wrapper(
                    ln_out_return,
                    fp8_meta,
                    tex.FP8FwdTensors.GEMM1_INPUT,
                    True,
                    fp8_dtype_fwd,
                    self.stream_id,
                )

            fc1_bias = get_autocast_bias(
                self.compute_dtype,
                fc1_bias_var,
                use_bias=True,
                use_fp8=True,
            )
            fc2_bias = get_autocast_bias(
                self.compute_dtype,
                fc2_bias_var,
                self.use_bias,
                use_fp8=True,
            )

            fc1_weight_fp8, fc1_weight_t_fp8 = fp8_cast_transpose_fused_wrapper(
                fc1_kernel_val,
                fp8_meta,
                tex.FP8FwdTensors.GEMM1_WEIGHT,
                True,
                fp8_dtype_fwd,
                self.stream_id,
            )
            fc2_weight_fp8, fc2_weight_t_fp8 = fp8_cast_transpose_fused_wrapper(
                fc2_kernel_val,
                fp8_meta,
                tex.FP8FwdTensors.GEMM2_WEIGHT,
                True,
                fp8_dtype_fwd,
                self.stream_id,
            )

            output_dtype = self._compute_dtype_object
            fc1_out = fp8_matmul_wrapper(
                ln_out,
                fc1_weight_t_fp8,
                fp8_meta,
                "fc1_fwd",
                fp8_dtype_fwd,
                fp8_dtype_fwd,
                output_dtype,
                _2X_ACC_FPROP,
                self.stream_id,
                use_bias=True,
                bias=fc1_bias,
            )
            gelu_out = fp8_gelu_wrapper(
                fc1_out,
                fp8_meta,
                tex.FP8FwdTensors.GEMM2_INPUT,
                True,
                fp8_dtype_fwd,
                self.stream_id,
            )
            fc2_out = fp8_matmul_wrapper(
                gelu_out,
                fc2_weight_t_fp8,
                fp8_meta,
                "fc2_fwd",
                fp8_dtype_fwd,
                fp8_dtype_fwd,
                output_dtype,
                _2X_ACC_FPROP,
                self.stream_id,
                use_bias=self.use_bias,
                bias=fc2_bias,
            )

            def grad_fn(*upstream, variables=None):
                self.pre_backward()
                if self.use_bias:
                    (
                        fc2_bias_grad,
                        grad_fp8,
                        grad_t_fp8,
                    ) = fp8_cast_transpose_bgrad_fused_wrapper(
                        upstream[0],
                        fp8_meta,
                        tex.FP8BwdTensors.GRAD_OUTPUT1,
                        False,
                        fp8_dtype_bwd,
                        self.stream_id,
                    )
                else:
                    if not override_linear_precision.wgrad:
                        grad_fp8, grad_t_fp8 = fp8_cast_transpose_fused_wrapper(
                            upstream[0],
                            fp8_meta,
                            tex.FP8BwdTensors.GRAD_OUTPUT1,
                            False,
                            fp8_dtype_bwd,
                            self.stream_id,
                        )
                    else:
                        grad_fp8 = cast_to_fp8_wrapper(
                            upstream[0],
                            fp8_meta,
                            tex.FP8BwdTensors.GRAD_OUTPUT1,
                            False,
                            fp8_dtype_bwd,
                            self.stream_id,
                        )

                fc2_dgrad = fp8_matmul_wrapper(
                    grad_fp8,
                    fc2_weight_fp8,
                    fp8_meta,
                    "fc2_bwd_input",
                    fp8_dtype_bwd,
                    fp8_dtype_fwd,
                    output_dtype,
                    _2X_ACC_DGRAD,
                    self.stream_id,
                )

                if not override_linear_precision.wgrad:
                    gelu_out_t = tex.fp8_transpose(
                        gelu_out, fp8_dtype_fwd, self.stream_id)
                    fc2_wgrad = fp8_matmul_wrapper(
                        gelu_out_t,
                        grad_t_fp8,
                        fp8_meta,
                        "fc2_bwd_weight",
                        fp8_dtype_fwd,
                        fp8_dtype_bwd,
                        output_dtype,
                        _2X_ACC_WGRAD,
                        self.stream_id,
                    )
                    (
                        fc1_bias_grad,
                        dgelu,
                        dgelu_t,
                    ) = fp8_cast_transpose_bgrad_dgelu_fused_wrapper(
                        fc2_dgrad,
                        fc1_out,
                        fp8_meta,
                        tex.FP8BwdTensors.GRAD_OUTPUT2,
                        False,
                        fp8_dtype_bwd,
                        self.stream_id,
                    )
                else:
                    gelu_out_c = cast_from_fp8_wrapper(
                        gelu_out,
                        fp8_meta,
                        tex.FP8FwdTensors.GEMM2_INPUT,
                        True,
                        fp8_dtype_fwd,
                        TE_DType[x.dtype],
                        self.stream_id,
                    )
                    fc2_wgrad = matmul_wrapper(
                        gelu_out_c,
                        upstream[0],
                        "bwd_weight",
                        output_dtype,
                        self.stream_id,
                    )
                    # Different from the PyTorch implementation, fc1_out
                    # already has the bias added, so we don't need to pass
                    # fc1_bias here.
                    fc1_bias_grad, dgelu_no_fp8 = bgrad_dgelu_fused(
                        fc2_dgrad, fc1_out)
                    dgelu = cast_to_fp8_wrapper(
                        dgelu_no_fp8,
                        fp8_meta,
                        tex.FP8BwdTensors.GRAD_OUTPUT2,
                        False,
                        fp8_dtype_bwd,
                        self.stream_id,
                    )
                    dgelu_t = None

                fc1_dgrad = fp8_matmul_wrapper(
                    dgelu,
                    fc1_weight_fp8,
                    fp8_meta,
                    "fc1_bwd_input",
                    fp8_dtype_bwd,
                    fp8_dtype_fwd,
                    output_dtype,
                    _2X_ACC_DGRAD,
                    self.stream_id,
                )
                if not override_linear_precision.wgrad:
                    ln_out_t = tex.fp8_transpose(
                        ln_out, fp8_dtype_fwd, self.stream_id)
                    fc1_wgrad = fp8_matmul_wrapper(
                        ln_out_t,
                        dgelu_t,
                        fp8_meta,
                        "fc1_bwd_weight",
                        fp8_dtype_fwd,
                        fp8_dtype_bwd,
                        output_dtype,
                        _2X_ACC_WGRAD,
                        self.stream_id,
                    )
                else:
                    ln_out_c = cast_from_fp8_wrapper(
                        ln_out,
                        fp8_meta,
                        tex.FP8FwdTensors.GEMM1_INPUT,
                        True,
                        fp8_dtype_fwd,
                        TE_DType[x.dtype],
                        self.stream_id,
                    )
                    fc1_wgrad = matmul_wrapper(
                        ln_out_c,
                        dgelu_no_fp8,
                        "bwd_weight",
                        output_dtype,
                        self.stream_id,
                    )

                d_ln_out = fc1_dgrad
                if self.return_layernorm_output:
                    assert len(upstream) == 2
                    d_ln_out = d_ln_out + upstream[1]

                dxmat, dgamma, dbeta = tex.layernorm_bwd(
                    d_ln_out, x, mu, rsigma, gamma_val, self.stream_id)

                grad_inputs = [dxmat]
                grad_vars = []
                for v in variables:
                    if v.name.endswith("gamma:0"):
                        grad_vars.append(dgamma)
                    elif v.name.endswith("fc1_kernel:0"):
                        grad_vars.append(fc1_wgrad)
                    elif v.name.endswith("fc1_bias:0"):
                        grad_vars.append(fc1_bias_grad)
                    elif v.name.endswith("fc2_kernel:0"):
                        grad_vars.append(fc2_wgrad)
                    elif v.name.endswith("fc2_bias:0") and self.use_bias:
                        grad_vars.append(fc2_bias_grad)
                    elif v.name.endswith("beta:0"):
                        grad_vars.append(dbeta)
                return grad_inputs, grad_vars

            if self.return_layernorm_output:
                return (fc2_out, ln_out_return), grad_fn
            return fc2_out, grad_fn

        return fp8_layernorm_mlp_func(inp)
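
The grad_fn above matches gradients to trainable variables by suffix of the variable name, using the variables keyword that tf.custom_gradient passes to the gradient function. For reference, a minimal standalone sketch of that pattern with plain TF ops (scale_by_w and demo_kernel are hypothetical names, not part of this module):

import tensorflow as tf

w = tf.Variable(tf.ones((2, 2)), name="demo_kernel")

@tf.custom_gradient
def scale_by_w(x):
    y = tf.matmul(x, w.value())

    def grad_fn(*upstream, variables=None):
        dx = tf.matmul(upstream[0], tf.transpose(w.value()))
        grad_vars = []
        for v in variables:  # match grads to variables by name, as above
            if v.name.endswith("demo_kernel:0"):
                grad_vars.append(tf.matmul(tf.transpose(x), upstream[0]))
        return [dx], grad_vars

    return y, grad_fn

x = tf.ones((3, 2))
with tf.GradientTape() as tape:
    out = scale_by_w(x)
print(tape.gradient(out, w))  # each entry is 3.0 (= x^T @ upstream)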
    def call(
        self,
        inputs,
        training=None,
    ):
        """
        Apply layer normalization to the input followed by a feedforward
        network (MLP Block).

        Parameters
        ----------
        inputs : tf.Tensor
            Input tensor.
        training : {True, False, None}, default = None
            Whether this is in the training context.
        """
        # self.pre_forward needs to be called outside the following branch,
        # since it has the side effect of setting self.fp8 when the autocast
        # is detected.
        training = self._get_training_value(training)
        self.pre_forward(training, num_gemms=2)

        inputmat = tf.reshape(inputs, shape=(-1, inputs.shape[-1]))
        if self.fp8:
            outputs = self.fp8_layernorm_mlp(
                inputmat,
                self.gamma,
                self.beta,
                self.fc1_kernel,
                self.fc1_bias,
                self.fc2_kernel,
                self.fc2_bias,
            )
        else:
            outputs = self.non_fp8_layernorm_mlp(
                inputmat,
                self.gamma,
                self.beta,
                self.fc1_kernel,
                self.fc1_bias,
                self.fc2_kernel,
                self.fc2_bias,
            )

        if self.return_layernorm_output:
            outputmat, ln_outputmat = outputs
        else:
            outputmat = outputs

        outputs = tf.reshape(
            outputmat,
            shape=(-1, *inputs.shape[1:-1], outputmat.shape[-1]))

        if self.return_bias:
            if self.return_layernorm_output:
                ln_outputs = tf.reshape(ln_outputmat, shape=inputs.shape)
                return (outputs, self.fc2_bias, ln_outputs)
            return outputs, self.fc2_bias
        if self.return_layernorm_output:
            ln_outputs = tf.reshape(ln_outputmat, shape=inputs.shape)
            return (outputs, ln_outputs)
        return outputs
    def get_config(self):
        """Returns the config of the layer."""
        config = super().get_config()
        config.update(
            {
                "hidden_size": self.fc1_units,
                "ffn_hidden_size": self.fc2_units,
                "epsilon": self.epsilon,
                "gamma_init_method": initializers.serialize(
                    self.gamma_initializer),
                "beta_init_method": initializers.serialize(
                    self.beta_initializer),
                "return_layernorm_output": self.return_layernorm_output,
                "use_bias": self.use_bias,
                "init_method": initializers.serialize(
                    self.kernel1_initializer),
                "output_layer_init_method": initializers.serialize(
                    self.kernel2_initializer),
                "bias_init_method": initializers.serialize(
                    self.bias_initializer),
            }
        )
        return config
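
For reference, a minimal usage sketch of this layer under FP8 autocasting. The constructor keywords are inferred from get_config above, and the fp8_autocast/DelayedScaling signatures are assumed to mirror Transformer Engine's PyTorch API, so treat this as a sketch rather than a verified snippet:

import tensorflow as tf
import transformer_engine.tensorflow as te
from transformer_engine.common.recipe import DelayedScaling, Format

mlp = te.LayerNormMLP(hidden_size=768, ffn_hidden_size=3072, use_bias=True)
x = tf.random.normal((8, 128, 768))

recipe = DelayedScaling(fp8_format=Format.HYBRID)
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    y = mlp(x, training=True)  # routes through fp8_layernorm_mlp above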
transformer_engine/tensorflow/softmax.py
deleted 100644 → 0
View file @ 2574a1ca
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fused scaled masked softmax functions"""
from typing import Callable
import os

import transformer_engine_tensorflow as tex
import tensorflow as tf

from .module import get_stream_id

THREADS_PER_WARP = 32
THREADS_PER_BLOCK = 128

_default_causal_mask = {}


def _get_default_causal_mask(sq: int) -> tf.Tensor:
    """Return the causal upper triangular mask for softmax input"""
    if sq not in _default_causal_mask:
        # In TF, the mask specifies 1 to keep and 0 to mask. In "causal" mask
        # mode, we compute the softmax of the lower triangular.
        mask_operator = tf.linalg.LinearOperatorLowerTriangular(
            tf.ones((sq, sq), dtype=tf.bool))
        mask = mask_operator.to_dense()
        _default_causal_mask[sq] = mask
    return _default_causal_mask[sq]
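
For reference, the same boolean mask can be built with plain TF ops, which makes the keep/mask convention easy to inspect:

import tensorflow as tf

sq = 3
causal = tf.cast(tf.linalg.band_part(tf.ones((sq, sq)), -1, 0), tf.bool)
print(causal.numpy())
# [[ True False False]
#  [ True  True False]
#  [ True  True  True]]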
class FusedScaleMaskSoftmax(tf.keras.Model):
    """
    fused operation: scaling + mask + softmax

    Arguments:
        attn_mask_type: attention mask type (padding or causal)
        mask_func: mask function to be applied.
        softmax_in_fp32: if true, softmax is performed at fp32 precision.
        scale: scaling factor used in input tensor scaling.
    """

    def __init__(
        self,
        attn_mask_type: str,
        mask_func: Callable,
        softmax_in_fp32: bool,
        scale: float,
    ) -> None:
        super().__init__()
        self.attn_mask_type = attn_mask_type
        self.scaled_masked_softmax_fusion = bool(
            int(os.getenv("NVTE_MASKED_SOFTMAX_FUSION", "1"))
        )
        self.mask_func = mask_func
        self.softmax_in_fp32 = softmax_in_fp32
        self.scale = scale
        self.stream = get_stream_id()

        assert (
            self.scale is None or softmax_in_fp32
        ), "softmax should be in fp32 when scaled"

    def __call__(self, inp: tf.Tensor, mask: tf.Tensor) -> tf.Tensor:
        """FusedScaleMaskSoftmax fprop"""
        # [b, np, sq, sk]
        assert len(inp.shape) == 4
        self.input_in_fp16 = inp.dtype == tf.float16
        self.input_in_bf16 = inp.dtype == tf.bfloat16
        self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16

        if self.is_kernel_available(*inp.shape):
            return self.forward_fused_softmax(inp, mask)
        return self.forward_tf_softmax(inp, mask)

    def is_kernel_available(self, b: int, np: int, sq: int, sk: int) -> bool:
        """Check FusedScaleMaskSoftmax kernel availability based on size"""
        attn_batches = b * np
        if (
            self.scaled_masked_softmax_fusion  # user wants to fuse
            and self.input_in_float16  # input must be fp16
            and 16 < sk <= 4096  # sk must be within (16, 4096]
            and sq % 4 == 0  # sq must be a multiple of 4
            and attn_batches % 4 == 0  # np * b must be a multiple of 4
        ):
            if 0 <= sk <= 4096:
                batch_per_block = self.get_batch_per_block(int(sk))
                if self.attn_mask_type == "causal":
                    if attn_batches % batch_per_block == 0:
                        return True
                else:
                    if sq % batch_per_block == 0:
                        return True
        return False
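
As a worked example of the check: with b = 8, np = 16, and sq = sk = 512 in fp16, attn_batches = 128; sk falls within (16, 4096], sq % 4 == 0, and attn_batches % 4 == 0. get_batch_per_block(512) evaluates to 4 (see the sketch after this class), so the causal branch returns True because 128 % 4 == 0 and the fused kernel is taken; if any condition fails, the pure-TF path forward_tf_softmax is used instead.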
    @tf.custom_gradient
    def scaled_masked_softmax(self, x: tf.Tensor, mask: tf.Tensor,
                              scale: float):
        """Scaled masked softmax."""
        y = tex.scaled_masked_softmax_forward(x, mask, scale, self.stream)

        def grad_fn(upstream):
            dx = tex.scaled_masked_softmax_backward(
                upstream, y, scale, self.stream)
            return dx, None, None

        return y, grad_fn

    @tf.custom_gradient
    def scaled_softmax(self, x: tf.Tensor, scale: float):
        """Scaled softmax."""
        y = tex.scaled_softmax_forward(x, scale, self.stream)

        def grad_fn(upstream):
            dx = tex.scaled_softmax_backward(upstream, y, scale, self.stream)
            return dx, None

        return y, grad_fn

    @tf.custom_gradient
    def scaled_upper_triang_masked_softmax(self, x: tf.Tensor, scale: float):
        """Scaled upper triangular masked softmax."""
        y = tex.scaled_upper_triang_masked_softmax_forward(
            x, scale, self.stream)

        def grad_fn(upstream):
            dx = tex.scaled_upper_triang_masked_softmax_backward(
                upstream, y, scale, self.stream)
            return dx, None

        return y, grad_fn
    def forward_fused_softmax(
        self,
        inp: tf.Tensor,
        mask: tf.Tensor,
    ) -> tf.Tensor:
        """Fused masked softmax kernel"""
        sq, sk = inp.shape[2], inp.shape[3]
        scale = self.scale if self.scale is not None else 1.0

        if self.attn_mask_type == "causal":
            assert sq == sk, "causal mask is only for self attention"

            # input is 3D tensor (attn_batches, sq, sk)
            original_shape = inp.shape
            inp = tf.reshape(inp, (-1, sq, sk))
            probs = self.scaled_upper_triang_masked_softmax(inp, scale)
            # Restore the original [b, np, sq, sk] view (inp was reassigned
            # to the 3D view above, so we keep the shape from before).
            return tf.reshape(probs, original_shape)

        # input is 4D tensor (b, np, sq, sk)
        if mask is not None:
            # The mask defined in the TE kernels is different from TF's. In
            # TE, the mask specifies 1 to mask out and 0 to keep.
            mask = tf.math.logical_not(mask)
            ndims = len(mask.shape)
            assert ndims <= 4, "mask ndims should be <= 4"
            if len(mask.shape) < 4:
                # Broadcast the first dims of mask to match the input ndims.
                broadcast_shape = [1] * (4 - ndims) + mask.shape[:]
                mask = tf.reshape(mask, broadcast_shape)
            return self.scaled_masked_softmax(inp, mask, scale)
        return self.scaled_softmax(inp, scale)
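
The mask-convention flip above is easy to get backwards, so here is a tiny standalone check (plain TF, no TE kernels involved):

import tensorflow as tf

tf_mask = tf.constant([[True, True, False]])  # TF: True means attend
te_mask = tf.math.logical_not(tf_mask)        # TE: True (1) means mask out
print(te_mask.numpy())  # [[False False  True]]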
    def forward_tf_softmax(self, inp: tf.Tensor, mask: tf.Tensor) -> tf.Tensor:
        """Framework softmax"""
        if self.input_in_float16 and self.softmax_in_fp32:
            inp = tf.cast(inp, tf.float32)

        if self.scale is not None:
            inp = inp * self.scale

        if self.attn_mask_type == "causal":
            mask = _get_default_causal_mask(inp.shape[2])

        mask_output = self.mask_func(inp, mask) if mask is not None else inp
        probs = tf.nn.softmax(mask_output, axis=-1)

        if self.input_in_float16 and self.softmax_in_fp32:
            if self.input_in_fp16:
                probs = tf.cast(probs, tf.half)
            else:
                probs = tf.cast(probs, tf.bfloat16)

        return probs
    @staticmethod
    def get_batch_per_block(key_seq_len: int) -> int:
        """Softmax utility"""
        pow2 = 1 << (key_seq_len - 1).bit_length()
        warp_size = pow2 if pow2 < THREADS_PER_WARP else THREADS_PER_WARP
        batches_per_warp = 2 if pow2 <= 128 else 1
        # Integer division so the result matches the declared int return type
        # (THREADS_PER_BLOCK is always divisible by warp_size here).
        warps_per_block = THREADS_PER_BLOCK // warp_size
        batches_per_block = warps_per_block * batches_per_warp
        return batches_per_block
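
Since get_batch_per_block is pure integer arithmetic, it can be sanity-checked standalone; a sketch that restates the formula with the constants folded in:

def batches_per_block(key_seq_len: int) -> int:
    pow2 = 1 << (key_seq_len - 1).bit_length()    # next power of two
    warp_size = min(pow2, 32)                     # THREADS_PER_WARP
    batches_per_warp = 2 if pow2 <= 128 else 1
    return (128 // warp_size) * batches_per_warp  # THREADS_PER_BLOCK // warp_size

for sk in (16, 64, 128, 512, 2048, 4096):
    print(sk, batches_per_block(sk))
# 16 -> 16, 64 -> 8, 128 -> 8, 512 -> 4, 2048 -> 4, 4096 -> 4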
transformer_engine/tensorflow/transformer.py
deleted 100644 → 0
View file @ 2574a1ca
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Transformer."""
from contextlib import nullcontext
from typing import Callable, Optional, Tuple, Union
import os

from keras import backend, layers, initializers
import tensorflow as tf

from transformer_engine.tensorflow.module import (
    LayerNorm,
    LayerNormDense,
    LayerNormMLP,
    Dense,
)
from transformer_engine.tensorflow.softmax import FusedScaleMaskSoftmax
from transformer_engine.tensorflow.constants import (
    AttnMaskTypes,
    AttnTypes,
    LayerTypes,
)
from transformer_engine.tensorflow.utils import (
    divide,
    attention_mask_func,
)
from transformer_engine.tensorflow.jit import (
    get_bias_dropout_add,
    bias_dropout_add_fused_train,
    bias_dropout_add_fused_inference,
)
class CoreAttention(tf.keras.Model):
    # pylint: disable=too-few-public-methods
    """Parallel attention w/o QKV and Proj Gemms
    BMM1 -> softmax + dropout -> BMM2
    """

    def __init__(
        self,
        num_attention_heads: int,
        kv_channels: int,
        attention_dropout: float,
        layer_number: Optional[int] = None,
        apply_query_key_layer_scaling: bool = True,
        attention_softmax_in_fp32: bool = False,
        attn_mask_type: str = "causal",
    ) -> None:
        super().__init__()
        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = attention_softmax_in_fp32
        if layer_number is None:
            self.apply_query_key_layer_scaling = False
        else:
            self.layer_number = max(1, layer_number)
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.attn_mask_type = attn_mask_type
        projection_size = kv_channels * num_attention_heads
        assert (
            attn_mask_type in AttnMaskTypes
        ), f"attn_mask_type {attn_mask_type} not supported"

        # Per attention head and per partition values.
        self.hidden_size_per_partition = divide(projection_size, 1)
        self.hidden_size_per_attention_head = divide(
            projection_size, num_attention_heads
        )
        self.attention_dropout_ctx = nullcontext

        coeff = None
        self.norm_factor = tf.math.sqrt(
            float(self.hidden_size_per_attention_head))
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff

        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            self.attn_mask_type,
            attention_mask_func,
            self.attention_softmax_in_fp32,
            coeff,
        )

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different numbers of parallel partitions, but
        # on average it should not be partition dependent.
        self.attention_dropout = layers.Dropout(attention_dropout)

    def __call__(
        self,
        query_layer: tf.Tensor,
        key_layer: tf.Tensor,
        value_layer: tf.Tensor,
        attention_mask: tf.Tensor,
    ) -> tf.Tensor:
        """core attention fprop"""
        # [b, np, sq, sk]
        output_size = (
            query_layer.shape[1],
            query_layer.shape[2],
            query_layer.shape[0],
            key_layer.shape[0],
        )

        # [sq, b, np, hn] -> [sq, b * np, hn]
        new_q_shape = (output_size[2], output_size[0] * output_size[1], -1)
        query_layer = tf.reshape(query_layer, new_q_shape)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        new_k_shape = (output_size[3], output_size[0] * output_size[1], -1)
        key_layer = tf.reshape(key_layer, new_k_shape)

        norm_factor = self._maybe_cast_inputs(self.norm_factor)
        # Raw attention scores. [b * np, sq, sk]
        matmul_result = (
            tf.matmul(
                tf.transpose(query_layer, perm=(1, 0, 2)),  # [b * np, sq, hn]
                tf.transpose(key_layer, perm=(1, 2, 0)),  # [b * np, hn, sk]
            )
            / norm_factor
        )

        # change view to [b, np, sq, sk]
        attention_scores = tf.reshape(matmul_result, output_size)

        # attention scores and attention mask [b, np, sq, sk]
        attention_probs = self.scale_mask_softmax(
            attention_scores, attention_mask)

        # This is actually dropping out entire tokens to attend to, which
        # might seem a bit unusual, but is taken from the original
        # Transformer paper.
        with self.attention_dropout_ctx():
            attention_probs = self.attention_dropout(attention_probs)

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]
        output_size = (
            value_layer.shape[1],
            value_layer.shape[2],
            query_layer.shape[0],
            value_layer.shape[3],
        )

        # change view [sk, b * np, hn]
        new_v_shape = (
            value_layer.shape[0], output_size[0] * output_size[1], -1)
        value_layer = tf.reshape(value_layer, new_v_shape)

        # change view [b * np, sq, sk]
        new_attn_shape = (
            output_size[0] * output_size[1], output_size[2], -1)
        attention_probs = tf.reshape(attention_probs, new_attn_shape)

        # matmul: [b * np, sq, hn]
        context_layer = tf.matmul(
            attention_probs,  # [b * np, sq, sk]
            tf.transpose(value_layer, perm=(1, 0, 2)),  # [b * np, sk, hn]
        )

        # change view [b, np, sq, hn]
        context_layer = tf.reshape(context_layer, output_size)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = tf.transpose(context_layer, perm=(2, 0, 1, 3))

        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = (
            *context_layer.shape[:-2],
            self.hidden_size_per_partition,
        )
        context_layer = tf.reshape(context_layer, new_context_layer_shape)

        return context_layer
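
For a shape-level sanity check, the BMM1 -> softmax -> BMM2 flow above can be restated with plain TF ops (mask, dropout, and the fused kernel omitted; sizes are illustrative):

import tensorflow as tf

sq, sk, b, np_, hn = 4, 4, 2, 3, 8
q = tf.random.normal((sq, b, np_, hn))
k = tf.random.normal((sk, b, np_, hn))
v = tf.random.normal((sk, b, np_, hn))

q2 = tf.reshape(q, (sq, b * np_, hn))
k2 = tf.reshape(k, (sk, b * np_, hn))
scores = tf.matmul(tf.transpose(q2, (1, 0, 2)),
                   tf.transpose(k2, (1, 2, 0))) / (float(hn) ** 0.5)
probs = tf.nn.softmax(tf.reshape(scores, (b, np_, sq, sk)), axis=-1)
probs = tf.reshape(probs, (b * np_, sq, sk))
v2 = tf.reshape(v, (sk, b * np_, hn))
ctx = tf.matmul(probs, tf.transpose(v2, (1, 0, 2)))   # [b*np, sq, hn]
ctx = tf.transpose(tf.reshape(ctx, (b, np_, sq, hn)), (2, 0, 1, 3))
ctx = tf.reshape(ctx, (sq, b, np_ * hn))              # [sq, b, hp]
print(ctx.shape)  # (4, 2, 24)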
class MultiHeadAttention(layers.Layer):
    """Parallel attention w/ QKV and Proj Gemms
    BMM1 -> softmax + dropout -> BMM2
    """

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        kv_channels: int,
        attention_dropout: float,
        layernorm_epsilon: float = 1e-3,
        init_method: Optional[Callable] = None,
        output_layer_init_method: Optional[Callable] = None,
        layer_number: Optional[int] = None,
        apply_query_key_layer_scaling: bool = True,
        attention_softmax_in_fp32: bool = False,
        attn_mask_type: str = "causal",
        return_layernorm_output: bool = False,
        input_layernorm: bool = False,
        attention_type: str = "self",
        fuse_qkv_params: bool = False,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.layer_number = (layer_number,)
        self.input_layernorm = input_layernorm
        self.attention_type = attention_type
        self.return_layernorm_output = return_layernorm_output
        self.init_method = init_method
        self.fuse_qkv_params = fuse_qkv_params
        # We only support zero-initializer for bias weights.
        self.bias_initializer = initializers.get("zeros")

        assert (
            attention_type in AttnTypes
        ), f"attention_type {attention_type} not supported"

        self.hidden_size_per_attention_head = kv_channels
        self.num_attention_heads_per_partition = divide(
            num_attention_heads, 1)

        if self.attention_type == "self":
            if self.input_layernorm:
                self.layernorm_qkv = LayerNormDense(
                    3 * hidden_size,
                    epsilon=layernorm_epsilon,
                    kernel_initializer=init_method,
                    use_bias=True,
                    return_bias=False,
                    return_layernorm_output=return_layernorm_output,
                    skip_weight_param_allocation=not fuse_qkv_params,
                )
            else:
                self.qkv = Dense(
                    3 * hidden_size,
                    kernel_initializer=init_method,
                    use_bias=True,
                    return_bias=False,
                    skip_weight_param_allocation=not fuse_qkv_params,
                )
        else:
            if self.input_layernorm:
                self.layernorm_query = LayerNormDense(
                    hidden_size,
                    epsilon=layernorm_epsilon,
                    kernel_initializer=init_method,
                    use_bias=True,
                    return_bias=False,
                    return_layernorm_output=return_layernorm_output,
                    skip_weight_param_allocation=not fuse_qkv_params,
                )
            else:
                self.query_layer = Dense(
                    hidden_size,
                    kernel_initializer=init_method,
                    use_bias=True,
                    return_bias=False,
                    skip_weight_param_allocation=not fuse_qkv_params,
                )
            self.key_value = Dense(
                2 * hidden_size,
                kernel_initializer=init_method,
                use_bias=True,
                return_bias=False,
                skip_weight_param_allocation=not fuse_qkv_params,
            )

        # Core Self attention.
        self.core_attention = CoreAttention(
            num_attention_heads,
            kv_channels,
            attention_dropout,
            layer_number=layer_number,
            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
            attention_softmax_in_fp32=attention_softmax_in_fp32,
            attn_mask_type=attn_mask_type,
        )

        # Linear
        self.proj = Dense(
            hidden_size,
            kernel_initializer=output_layer_init_method,
            use_bias=False,
            return_bias=True,
        )

    def build(self, input_shape):
        """One-time allocation of the variables."""
        input_shape = tf.TensorShape(input_shape)
        last_dim = tf.compat.dimension_value(input_shape[-1])
        if last_dim is None:
            raise ValueError(
                "The last dimension of the inputs to a Dense layer should be "
                f"defined. Found None. Full input shape received: "
                f"{input_shape}"
            )
        if not self.fuse_qkv_params:
            self.set_qkv_params(
                last_dim,
                3 * self.hidden_size,
                use_bias=True,
            )

    def set_qkv_params(
        self,
        in_features,
        out_features,
        use_bias: bool = False,
    ) -> None:
        """Initialize separate Parameters for query, key, and value tensors."""
        assert (
            out_features % 3 == 0
        ), f"3 way QKV split with dimension {out_features} not possible."
        qkv_dim = out_features // 3
        if self.attention_type == "self":
            self.qkv_weight = self.add_weight(
                name="qkv_kernel",
                shape=(in_features, out_features),
                initializer=self.init_method,
                trainable=True,
            )
            self.qkv_bias = None
            if use_bias:
                self.qkv_bias = self.add_weight(
                    name="qkv_bias",
                    shape=(out_features,),
                    initializer=self.bias_initializer,
                    trainable=True,
                )
        else:
            self.q_weight = self.add_weight(
                name="q_kernel",
                shape=(in_features, qkv_dim),
                initializer=self.init_method,
                trainable=True,
            )
            self.kv_weight = self.add_weight(
                name="kv_kernel",
                shape=(in_features, 2 * qkv_dim),
                initializer=self.init_method,
                trainable=True,
            )
            self.q_bias = None
            self.kv_bias = None
            if use_bias:
                self.q_bias = self.add_weight(
                    name="q_bias",
                    shape=(qkv_dim,),
                    initializer=self.bias_initializer,
                    trainable=True,
                )
                self.kv_bias = self.add_weight(
                    name="kv_bias",
                    shape=(2 * qkv_dim,),
                    initializer=self.bias_initializer,
                    trainable=True,
                )

    def _get_training_value(self, training=None):
        if training is None:
            training = backend.learning_phase()
        if isinstance(training, int):
            training = bool(training)
        if not self.trainable:
            # When the layer is not trainable, it overrides the value passed
            # from the model.
            training = False
        return training

    def call(
        self,
        hidden_states: tf.Tensor,
        attention_mask: tf.Tensor,
        encoder_output: Optional[tf.Tensor] = None,
        training: bool = None,
    ) -> Tuple[Union[tf.Tensor, None], ...]:
        """MultiHeadAttention FWD"""
        training = self._get_training_value(training)
        # hidden_states: [sq, b, h]
        if attention_mask is not None:
            assert (
                attention_mask.dtype == tf.bool
            ), "Attention mask must be a boolean tensor"

        # =====================
        # Query, Key, and Value
        # =====================
        if self.attention_type == "self":
            qkv_weight = self.qkv_weight if not self.fuse_qkv_params else None
            qkv_bias = self.qkv_bias if not self.fuse_qkv_params else None
            # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
            if self.input_layernorm:
                layernorm_qkv_outputs = self.layernorm_qkv(
                    hidden_states,
                    kernel=qkv_weight,
                    bias=qkv_bias,
                    training=training,
                )
                if self.return_layernorm_output:
                    mixed_x_layer, layernorm_output = layernorm_qkv_outputs
                else:
                    mixed_x_layer = layernorm_qkv_outputs
            else:
                mixed_x_layer = self.qkv(
                    hidden_states,
                    kernel=qkv_weight,
                    bias=qkv_bias,
                    training=training,
                )

            # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
            new_tensor_shape = (
                *mixed_x_layer.shape[:-1],
                self.num_attention_heads_per_partition,
                3 * self.hidden_size_per_attention_head,
            )
            mixed_x_layer = tf.reshape(mixed_x_layer, new_tensor_shape)

            # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
            query_layer, key_layer, value_layer = tf.split(
                mixed_x_layer, num_or_size_splits=3, axis=-1)
        else:
            kv_weight = self.kv_weight if not self.fuse_qkv_params else None
            kv_bias = self.kv_bias if not self.fuse_qkv_params else None
            # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
            mixed_kv_layer = self.key_value(
                encoder_output,
                kernel=kv_weight,
                bias=kv_bias,
                training=training,
            )

            # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
            new_tensor_shape = (
                *mixed_kv_layer.shape[:-1],
                self.num_attention_heads_per_partition,
                2 * self.hidden_size_per_attention_head,
            )
            mixed_kv_layer = tf.reshape(mixed_kv_layer, new_tensor_shape)

            # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
            key_layer, value_layer = tf.split(
                mixed_kv_layer, num_or_size_splits=2, axis=-1)

            # Attention head [sq, b, h] --> [sq, b, hp]
            if self.input_layernorm:
                layernorm_query_outputs = self.layernorm_query(
                    hidden_states,
                    kernel=self.q_weight,
                    bias=self.q_bias,
                    training=training,
                )
                if self.return_layernorm_output:
                    query_layer, layernorm_output = layernorm_query_outputs
                else:
                    query_layer = layernorm_query_outputs
            else:
                query_layer = self.query_layer(
                    hidden_states,
                    kernel=self.q_weight,
                    bias=self.q_bias,
                    training=training,
                )

            # [sq, b, hp] --> [sq, b, np, hn]
            new_tensor_shape = (
                *query_layer.shape[:-1],
                self.num_attention_heads_per_partition,
                self.hidden_size_per_attention_head,
            )
            query_layer = tf.reshape(query_layer, new_tensor_shape)

        # ==================================
        # core attention computation
        # ==================================
        context_layer = self.core_attention(
            query_layer, key_layer, value_layer, attention_mask)

        # =================
        # Output. [sq, b, h]
        # =================
        attention_output, attention_bias = self.proj(
            context_layer, training=training,
        )

        if self.input_layernorm and self.return_layernorm_output:
            return attention_output, attention_bias, layernorm_output
        return attention_output, attention_bias
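
A minimal self-attention usage sketch (sizes are illustrative; note the [sq, b, h] input layout and that proj returns its bias separately because return_bias=True):

import tensorflow as tf

mha = MultiHeadAttention(
    hidden_size=256,
    num_attention_heads=8,
    kv_channels=32,          # 256 / 8
    attention_dropout=0.1,
    input_layernorm=True,
    attention_type="self",
)
x = tf.random.normal((128, 4, 256))  # [sq, b, h]
out, proj_bias = mha(x, attention_mask=None, training=False)
print(out.shape)  # (128, 4, 256)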
class DropPath(tf.keras.Model):
    # pylint: disable=too-few-public-methods
    """Drop paths (Stochastic Depth) per sample (when applied in main path of
    residual blocks).
    """

    def __init__(self, drop_prob: float = 0.0) -> None:
        super().__init__()
        self.drop_prob = drop_prob

    def __call__(self, hidden_state: tf.Tensor, training: bool) -> tf.Tensor:
        """DropPath FWD"""
        if self.drop_prob == 0.0 or not training:
            return hidden_state
        keep_prob = 1 - self.drop_prob
        # work with diff dim tensors, not just 2D ConvNets
        shape = (hidden_state.shape[0],) + (1,) * (len(hidden_state.shape) - 1)
        # TODO(kaixih): We set the seed mainly for debugging purposes. Should
        # allow users to turn it off.
        random_tensor = tf.random.stateless_uniform(shape, seed=[1, 0])
        random_mask = tf.cast(
            random_tensor <= keep_prob, dtype=hidden_state.dtype)
        output = (hidden_state / keep_prob) * random_mask
        return output
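
The division by keep_prob preserves the expected activation: E[(x / keep_prob) * Bernoulli(keep_prob)] = x. A quick check of the per-sample behavior:

import tensorflow as tf

drop = DropPath(drop_prob=0.25)
h = tf.ones((8, 16, 32))
out = drop(h, training=True)
# Each of the 8 samples is either all zeros or all 1/0.75 ≈ 1.333, so the
# expectation of out equals h. With training=False the input passes through.
print(tf.unique(tf.reshape(out, [-1]))[0].numpy())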
class TransformerLayer(tf.keras.Model):
    # pylint: disable=too-few-public-methods
    """
    TransformerLayer is made up of an attention block and a feedforward
    network (MLP). This standard layer is based on the paper
    "Attention Is All You Need".

    Parameters
    ----------
    hidden_size : int
        size of each input sample.
    ffn_hidden_size : int
        intermediate size to which input samples are projected.
    num_attention_heads : int
        number of attention heads in the transformer layer.
    layernorm_epsilon : float, default = 1e-5
        a value added to the denominator of layer normalization for
        numerical stability.
    hidden_dropout: float, default = 0.1
        dropout probability for the dropout op after the FC2 layer.
    attention_dropout: float, default = 0.1
        dropout probability for the dropout op during multi-head attention.
    init_method : Callable, default = `None`
        used for initializing weights of QKV and FC1 weights in the
        following way: `init_method(weight)`. When set to `None`, defaults
        to `tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.023)`.
    output_layer_init_method : Callable, default = `None`
        used for initializing weights of PROJ and FC2 in the following way:
        `output_layer_init_method(weight)`. When set to `None`, defaults to
        `tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.023)`.
    apply_residual_connection_post_layernorm : bool, default = `False`
        if set to `True`, residual connections are taken from the output of
        layer norm (default is taken from input of layer norm)
    layer_number: int, default = `None`
        layer number of the current `TransformerLayer` when multiple such
        modules are concatenated to form a transformer block.
    apply_query_key_layer_scaling: bool, default = `True`
        apply query-key layer scaling during BMM1 by a factor of
        `layer_number`
    output_layernorm: bool, default = `False`
        if set to `True`, layer normalization is applied on the output side,
        after the final dropout-add. default behavior is to apply layer
        normalization on the input side, before the QKV transformation.
    attention_softmax_in_fp32: bool, default = `False`
        if set to `True`, softmax is executed in tf.float32 dtype (single
        precision)
    layer_type: {'encoder', 'decoder'}, default = `encoder`
        if set to `decoder`, an additional cross-attn block is added after
        self-attn. This can be used for structures like `T5` Transformer in
        conjunction with the `encoder` option.
    kv_channels: int, default = `None`
        number of key-value channels. defaults to
        `hidden_size / num_attention_heads` if `None`.
    self_attn_mask_type: {'causal', 'padding'}, default = `causal`
        type of attention mask passed into softmax operation.

    Optimization parameters
    -----------------------
    drop_path_rate: float, default = 0.0
        when > 0.0, applies stochastic depth per sample in the main path of
        the residual block.
    fuse_qkv_params: bool, default = `False`
        if set to `True`, `TransformerLayer` module exposes a single fused
        parameter for query-key-value. This enables optimizations such as
        QKV fusion without concatenations/splits.
    """

    def __init__(
        self,
        hidden_size: int,
        ffn_hidden_size: int,
        num_attention_heads: int,
        layernorm_epsilon: float = 1e-5,
        hidden_dropout: float = 0.1,
        attention_dropout: float = 0.1,
        init_method: Optional[Callable] = None,
        output_layer_init_method: Optional[Callable] = None,
        layer_number: Optional[int] = None,
        kv_channels: Optional[int] = None,
        self_attn_mask_type: str = "causal",
        apply_query_key_layer_scaling: bool = True,
        attention_softmax_in_fp32: bool = False,
        apply_residual_connection_post_layernorm: bool = False,
        output_layernorm: bool = False,
        layer_type: str = "encoder",
        drop_path_rate: float = 0.0,
        fuse_qkv_params: bool = False,
    ) -> None:
        super().__init__()

        bias_dropout_fusion = \
            bool(int(os.getenv("NVTE_BIAS_DROPOUT_FUSION", "1")))
        self.layer_number = layer_number
        self.output_layernorm = output_layernorm
        self.layer_type = layer_type
        self.apply_residual_connection_post_layernorm = (
            apply_residual_connection_post_layernorm
        )
        assert (
            self_attn_mask_type in AttnMaskTypes
        ), f"self_attn_mask_type {self_attn_mask_type} not supported"
        assert layer_type in LayerTypes, \
            f"layer_type {layer_type} not supported"

        self.kv_channels = (
            kv_channels if kv_channels
            else (hidden_size // num_attention_heads)
        )

        if init_method is None:
            init_method = initializers.RandomNormal(mean=0.0, stddev=0.023)
        if output_layer_init_method is None:
            output_layer_init_method = initializers.RandomNormal(
                mean=0.0, stddev=0.023)

        attention_args = (
            hidden_size,
            num_attention_heads,
            self.kv_channels,
            attention_dropout,
            layernorm_epsilon,
            init_method,
            output_layer_init_method,
        )
        common_attention_kwargs = {
            "layer_number": layer_number,
            "apply_query_key_layer_scaling": apply_query_key_layer_scaling,
            "attention_softmax_in_fp32": attention_softmax_in_fp32,
            "return_layernorm_output":
                apply_residual_connection_post_layernorm,
            "fuse_qkv_params": fuse_qkv_params,
        }

        self.self_attention = MultiHeadAttention(
            *attention_args,
            **common_attention_kwargs,
            attn_mask_type=self_attn_mask_type,
            input_layernorm=not output_layernorm,
            attention_type="self",
        )

        if layer_type == "decoder":
            self.inter_attention = MultiHeadAttention(
                *attention_args,
                **common_attention_kwargs,
                attn_mask_type="padding",
                input_layernorm=True,
                attention_type="cross",
            )

        # LayerNorm -> gelu(Linear + Bias) -> Linear
        self.layernorm_mlp = LayerNormMLP(
            hidden_size,
            ffn_hidden_size,
            epsilon=layernorm_epsilon,
            kernel_initializer=init_method,
            ffn_kernel_initializer=output_layer_init_method,
            use_bias=False,
            return_bias=True,
            return_layernorm_output=apply_residual_connection_post_layernorm,
        )

        self.hidden_dropout = hidden_dropout
        self.bias_dropout_fusion = bias_dropout_fusion
        self.drop_path = (
            DropPath(drop_path_rate) if drop_path_rate > 0.0 else None
        )

        if self.output_layernorm:
            self.layernorm = LayerNorm(
                epsilon=layernorm_epsilon,
            )
    def _get_training_value(self, training=None):
        if training is None:
            training = backend.learning_phase()
        if isinstance(training, int):
            training = bool(training)
        if not self.trainable:
            # When the layer is not trainable, it overrides the value passed
            # from the model.
            training = False
        return training

    def __call__(
        self,
        hidden_states: tf.Tensor,
        attention_mask: tf.Tensor,
        encoder_output: Optional[tf.Tensor] = None,
        enc_dec_attn_mask: Optional[tf.Tensor] = None,
        training: bool = None,
    ) -> tf.Tensor:
        """
        Transformer Layer: attention block and a feedforward network (MLP)

        Parameters
        ----------
        hidden_states : tf.Tensor
            Input tensor.
        attention_mask : tf.Tensor
            Boolean tensor used to mask out self-attention softmax input.
        encoder_output : tf.Tensor
            Output of the encoder block to be fed into the decoder block if
            using `layer_type="decoder"`.
        enc_dec_attn_mask : tf.Tensor
            Boolean tensor used to mask out inter-attention softmax input if
            using `layer_type="decoder"`.
        """
        if attention_mask is not None:
            assert (
                attention_mask.dtype == tf.bool
            ), "Attention mask must be a boolean tensor"

        # Theoretically, the input dtype can be handled by the autocast
        # during the layer call. However, we may use the input
        # (hidden_states) in the residual connection before the layer is
        # called. So, we convert it ahead of time. As for the other input
        # (encoder_output), we can leave the conversion to the
        # inter_attention layer, since it won't be used in the residual
        # connection.
        hidden_states = self._maybe_cast_inputs(hidden_states)

        # Self attention.
        self_attention_outputs = self.self_attention(
            hidden_states,
            attention_mask,
            training=training,
        )

        if (
            self.apply_residual_connection_post_layernorm
            and not self.output_layernorm
        ):
            attention_output, attention_bias, residual = \
                self_attention_outputs
        else:
            attention_output, attention_bias = self_attention_outputs
            residual = hidden_states

        # Set BDA func.
        if self.bias_dropout_fusion:
            if training:
                bias_dropout_add_func = bias_dropout_add_fused_train
            else:
                bias_dropout_add_func = bias_dropout_add_fused_inference
        else:
            bias_dropout_add_func = get_bias_dropout_add(training)

        # Bias dropout add.
        attention_bias = tf.cast(attention_bias, dtype=self.compute_dtype)
        if self.drop_path is None:
            bda_output = bias_dropout_add_func(
                attention_output,
                attention_bias,
                residual,
                self.hidden_dropout,
            )
        else:
            # TODO(kaixih): Use stateless_dropout and specify the seed
            # mainly for debugging purposes. Should allow random seed.
            out = (
                tf.nn.experimental.stateless_dropout(
                    attention_output + attention_bias,
                    rate=self.hidden_dropout,
                    seed=[1, 0],
                )
                if training
                else attention_output + attention_bias
            )
            bda_output = residual + self.drop_path(out, training)

        # Cross attention.
        if self.layer_type == "decoder":
            inter_attention_outputs = self.inter_attention(
                bda_output,
                enc_dec_attn_mask,
                encoder_output=encoder_output,
                training=training,
            )
            if self.apply_residual_connection_post_layernorm:
                attention_output, attention_bias, residual = \
                    inter_attention_outputs
            else:
                attention_output, attention_bias = inter_attention_outputs
                residual = bda_output

            attention_bias = tf.cast(
                attention_bias, dtype=self.compute_dtype)
            bda_output = bias_dropout_add_func(
                attention_output,
                attention_bias,
                residual,
                self.hidden_dropout,
            )

        # MLP.
        mlp_outputs = self.layernorm_mlp(
            bda_output, training=training,
        )
        if self.apply_residual_connection_post_layernorm:
            mlp_output, mlp_bias, residual = mlp_outputs
        else:
            mlp_output, mlp_bias = mlp_outputs
            residual = bda_output

        # Bias dropout add.
        mlp_bias = tf.cast(mlp_bias, dtype=self.compute_dtype)
        if self.drop_path is None:
            output = bias_dropout_add_func(
                mlp_output,
                mlp_bias,
                residual,
                self.hidden_dropout,
            )
        else:
            # TODO(kaixih): Use stateless_dropout and specify the seed
            # mainly for debugging purposes. Should allow random seed.
            output = (
                tf.nn.experimental.stateless_dropout(
                    mlp_output + mlp_bias,
                    rate=self.hidden_dropout,
                    seed=[1, 0],
                )
                if training
                else mlp_output + mlp_bias
            )
            output = residual + self.drop_path(output, training)

        # For BERT like architectures.
        if self.output_layernorm:
            output = self.layernorm(output)

        # output: [b, s, h]
        return output
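
Putting it together, a minimal encoder-layer sketch (sizes are illustrative; the layer expects [sq, b, h] inputs and a boolean attention mask or None):

import tensorflow as tf

layer = TransformerLayer(
    hidden_size=256,
    ffn_hidden_size=1024,
    num_attention_heads=8,
)
x = tf.random.normal((64, 2, 256))  # [sq, b, h]
y = layer(x, attention_mask=None, training=False)
print(y.shape)  # (64, 2, 256)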
transformer_engine/tensorflow/utils.py
deleted 100644 → 0
View file @ 2574a1ca
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Utility functions for Transformer Engine modules"""
import tensorflow as tf


def attention_mask_func(
    attention_scores: tf.Tensor, attention_mask: tf.Tensor
) -> tf.Tensor:
    """Get attention mask"""
    return tf.where(attention_mask, attention_scores, -10000.0)


def ensure_divisibility(numerator: int, denominator: int) -> None:
    """Ensure that numerator is divisible by the denominator."""
    assert (
        numerator % denominator == 0
    ), f"{numerator} is not divisible by {denominator}"


def divide(numerator: int, denominator: int) -> int:
    """Ensure that numerator is divisible by the denominator and return
    the division value."""
    ensure_divisibility(numerator, denominator)
    return numerator // denominator
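
Both helpers are trivial to exercise:

import tensorflow as tf

scores = tf.zeros((1, 1, 2, 2))
mask = tf.constant([[[[True, False], [True, True]]]])
print(attention_mask_func(scores, mask)[0, 0])
# [[     0. -10000.]
#  [     0.      0.]]
print(divide(768, 12))  # 64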