Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
a207db1d
Commit
a207db1d
authored
Apr 01, 2025
by
yuguo
Browse files
Merge branch 'main' of
https://github.com/NVIDIA/TransformerEngine
parents
fbee8990
69365f88
Changes
101
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1773 additions
and
2684 deletions
+1773
-2684
transformer_engine/jax/csrc/extensions/ffi.h
transformer_engine/jax/csrc/extensions/ffi.h
+25
-0
transformer_engine/jax/csrc/extensions/gemm.cpp
transformer_engine/jax/csrc/extensions/gemm.cpp
+214
-0
transformer_engine/jax/csrc/extensions/misc.h
transformer_engine/jax/csrc/extensions/misc.h
+6
-0
transformer_engine/jax/csrc/extensions/normalization.cpp
transformer_engine/jax/csrc/extensions/normalization.cpp
+141
-476
transformer_engine/jax/csrc/extensions/packing.cpp
transformer_engine/jax/csrc/extensions/packing.cpp
+0
-77
transformer_engine/jax/csrc/extensions/pybind.cpp
transformer_engine/jax/csrc/extensions/pybind.cpp
+38
-79
transformer_engine/jax/csrc/extensions/quantization.cpp
transformer_engine/jax/csrc/extensions/quantization.cpp
+106
-48
transformer_engine/jax/csrc/extensions/softmax.cpp
transformer_engine/jax/csrc/extensions/softmax.cpp
+0
-97
transformer_engine/jax/csrc/extensions/transpose.cpp
transformer_engine/jax/csrc/extensions/transpose.cpp
+0
-289
transformer_engine/jax/dense.py
transformer_engine/jax/dense.py
+302
-0
transformer_engine/jax/dot.py
transformer_engine/jax/dot.py
+0
-242
transformer_engine/jax/flax/__init__.py
transformer_engine/jax/flax/__init__.py
+1
-2
transformer_engine/jax/flax/module.py
transformer_engine/jax/flax/module.py
+158
-165
transformer_engine/jax/flax/transformer.py
transformer_engine/jax/flax/transformer.py
+5
-2
transformer_engine/jax/fp8.py
transformer_engine/jax/fp8.py
+0
-427
transformer_engine/jax/layernorm.py
transformer_engine/jax/layernorm.py
+95
-342
transformer_engine/jax/layernorm_dense.py
transformer_engine/jax/layernorm_dense.py
+309
-0
transformer_engine/jax/layernorm_mlp.py
transformer_engine/jax/layernorm_mlp.py
+260
-438
transformer_engine/jax/quantize/__init__.py
transformer_engine/jax/quantize/__init__.py
+17
-0
transformer_engine/jax/quantize/dequantizer.py
transformer_engine/jax/quantize/dequantizer.py
+96
-0
No files found.
transformer_engine/jax/csrc/extensions/ffi.h
View file @
a207db1d
...
@@ -81,5 +81,30 @@ inline size_t product(const xla::ffi::Span<const int64_t>& data, size_t start_id
...
@@ -81,5 +81,30 @@ inline size_t product(const xla::ffi::Span<const int64_t>& data, size_t start_id
std
::
multiplies
<
size_t
>
());
std
::
multiplies
<
size_t
>
());
}
}
inline
static
size_t
te_dtype_bytes
(
const
DType
&
type
)
{
switch
(
type
)
{
case
DType
::
kByte
:
return
1
;
case
DType
::
kInt32
:
return
4
;
case
DType
::
kInt64
:
return
8
;
case
DType
::
kFloat32
:
return
4
;
case
DType
::
kFloat16
:
return
2
;
case
DType
::
kBFloat16
:
return
2
;
case
DType
::
kFloat8E5M2
:
return
1
;
case
DType
::
kFloat8E4M3
:
return
1
;
case
DType
::
kFloat8E8M0
:
return
1
;
default:
NVTE_ERROR
(
"Unsupported DType: "
,
static_cast
<
int
>
(
type
));
}
}
}
// namespace jax
}
// namespace jax
}
// namespace transformer_engine
}
// namespace transformer_engine
transformer_engine/jax/csrc/extensions/gemm.cpp
0 → 100644
View file @
a207db1d
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "transformer_engine/gemm.h"
#include <memory>
#include "common/util/cuda_runtime.h"
#include "common/util/system.h"
#include "extensions.h"
#include "xla/ffi/api/c_api.h"
namespace
transformer_engine
{
namespace
jax
{
constexpr
static
size_t
MXFP8_BLOCK_SIZE
=
32
;
// Note: we only support TN-GEMM for now (TN in cuBLASLt == NT in JAX)
Error_Type
GroupedGemmImpl
(
uint8_t
*
lhs_ptr
,
const
DType
&
lhs_dtype
,
uint8_t
*
lhs_sinv_ptr
,
const
DType
&
lhs_sinv_dtype
,
uint8_t
*
rhs_ptr
,
const
DType
&
rhs_dtype
,
uint8_t
*
rhs_sinv_ptr
,
const
DType
&
rhs_sinv_dtype
,
uint8_t
*
bias_ptr
,
const
DType
&
bias_dtype
,
uint8_t
*
out_ptr
,
const
DType
&
out_dtype
,
uint8_t
*
workspace_ptr
,
const
size_t
workspace_size
,
size_t
num_gemms
,
int32_t
*
dim_list_ptr
,
const
int64_t
&
scaling_mode
,
cudaStream_t
stream
)
{
size_t
lhs_dtype_bytes
=
te_dtype_bytes
(
lhs_dtype
);
size_t
rhs_dtype_bytes
=
te_dtype_bytes
(
rhs_dtype
);
size_t
lhs_sinv_dtype_bytes
=
te_dtype_bytes
(
lhs_sinv_dtype
);
size_t
rhs_sinv_dtype_bytes
=
te_dtype_bytes
(
rhs_sinv_dtype
);
size_t
bias_dtype_bytes
=
te_dtype_bytes
(
bias_dtype
);
size_t
out_dtype_bytes
=
te_dtype_bytes
(
out_dtype
);
NVTE_CHECK
(
lhs_dtype_bytes
==
rhs_dtype_bytes
,
"sizeof(lhs_dtype) != sizeof(rhs_dtype)"
);
NVTE_CHECK
(
lhs_sinv_dtype_bytes
==
rhs_sinv_dtype_bytes
,
"sizeof(lhs_sinv_dtype) != sizeof(rhs_sinv_dtype)"
);
size_t
dim_list_bytes
=
sizeof
(
int32_t
)
*
3
*
num_gemms
;
std
::
unique_ptr
<
int32_t
[]
>
dim_list_host
=
std
::
make_unique
<
int32_t
[]
>
(
3
*
num_gemms
);
cudaMemcpyAsync
(
dim_list_host
.
get
(),
dim_list_ptr
,
dim_list_bytes
,
cudaMemcpyDeviceToHost
,
stream
);
// Note: This may break cudaGraph.
cudaStreamSynchronize
(
stream
);
// Notes on matrix layouts and transpose:
// Jax uses row-major layout, on entering this function, each input matrix pair:
// A: row-major with size [m, k],
// B: row-major with size [n, k], needs transpose,
// on exiting this function, JAX expect:
// C: row-major with size [m, n].
// cuBLAS uses column-major layout, in this view, each input matrix pair:
// A: column-major with size [k, m], needs transpose,
// B: column-major with size [k, n].
// If we call cuBLAS GEMM for A * B, the output will be:
// C: column-major with size [m, n] --> row-major with size [n, m].
// To make the output compatible with JAX, we need to swap A and B in cuBLAS GEMM call.
bool
trans_lhs
=
true
;
bool
trans_rhs
=
false
;
auto
num_math_sm
=
cuda
::
sm_count
()
-
getenv
<
int
>
(
"NVTE_EXT_MARGIN_SM"
,
0
);
bool
grad
=
false
;
bool
accumulate
=
false
;
bool
use_split_accumulator
=
false
;
// These lists are to keep the TensorWrapper objects alive
std
::
vector
<
TensorWrapper
>
lhs_wrapper_list
;
std
::
vector
<
TensorWrapper
>
rhs_wrapper_list
;
std
::
vector
<
TensorWrapper
>
bias_wrapper_list
;
std
::
vector
<
TensorWrapper
>
pre_gelu_wrapper_list
;
std
::
vector
<
TensorWrapper
>
out_wrapper_list
;
std
::
vector
<
TensorWrapper
>
workspace_wrapper_list
;
// These lists are the actual NVTETensor (void *) lists for multi-stream GEMM
std
::
vector
<
NVTETensor
>
lhs_list
;
std
::
vector
<
NVTETensor
>
rhs_list
;
std
::
vector
<
NVTETensor
>
bias_list
;
std
::
vector
<
NVTETensor
>
pre_gelu_list
;
std
::
vector
<
NVTETensor
>
out_list
;
std
::
vector
<
NVTETensor
>
workspace_list
;
for
(
int
i
=
0
;
i
<
num_gemms
;
i
++
)
{
size_t
m
=
dim_list_host
[
i
*
3
];
size_t
n
=
dim_list_host
[
i
*
3
+
1
];
size_t
k
=
dim_list_host
[
i
*
3
+
2
];
auto
lhs_shape
=
std
::
vector
<
size_t
>
{
m
,
k
};
auto
rhs_shape
=
std
::
vector
<
size_t
>
{
n
,
k
};
auto
out_shape
=
std
::
vector
<
size_t
>
{
n
,
m
};
auto
lhs_sinv_shape
=
std
::
vector
<
size_t
>
{
1
,
1
};
auto
rhs_sinv_shape
=
std
::
vector
<
size_t
>
{
1
,
1
};
if
(
scaling_mode
==
NVTE_NO_SCALING
||
scaling_mode
==
NVTE_DELAYED_TENSOR_SCALING
)
{
auto
lhs_i
=
TensorWrapper
(
static_cast
<
void
*>
(
lhs_ptr
),
lhs_shape
,
lhs_dtype
,
nullptr
,
nullptr
,
reinterpret_cast
<
float
*>
(
lhs_sinv_ptr
));
auto
rhs_i
=
TensorWrapper
(
static_cast
<
void
*>
(
rhs_ptr
),
rhs_shape
,
rhs_dtype
,
nullptr
,
nullptr
,
reinterpret_cast
<
float
*>
(
rhs_sinv_ptr
));
lhs_wrapper_list
.
push_back
(
std
::
move
(
lhs_i
));
rhs_wrapper_list
.
push_back
(
std
::
move
(
rhs_i
));
}
else
if
(
scaling_mode
==
NVTE_MXFP8_1D_SCALING
)
{
NVTE_CHECK
(
k
%
MXFP8_BLOCK_SIZE
==
0
,
"MXFP8 K-dim being divisble by %d (got %d)"
,
MXFP8_BLOCK_SIZE
,
k
);
size_t
sinv_k
=
k
/
MXFP8_BLOCK_SIZE
;
lhs_sinv_shape
[
0
]
=
m
;
lhs_sinv_shape
[
1
]
=
sinv_k
;
rhs_sinv_shape
[
0
]
=
n
;
rhs_sinv_shape
[
1
]
=
sinv_k
;
// Note: the scale_inv array should have been swizzled in Python before lowering
TensorWrapper
lhs_i
(
NVTE_MXFP8_1D_SCALING
);
TensorWrapper
rhs_i
(
NVTE_MXFP8_1D_SCALING
);
lhs_i
.
set_rowwise_data
(
static_cast
<
void
*>
(
lhs_ptr
),
lhs_dtype
,
lhs_shape
);
rhs_i
.
set_rowwise_data
(
static_cast
<
void
*>
(
rhs_ptr
),
rhs_dtype
,
rhs_shape
);
lhs_i
.
set_rowwise_scale_inv
(
static_cast
<
void
*>
(
lhs_sinv_ptr
),
DType
::
kFloat8E8M0
,
lhs_sinv_shape
);
rhs_i
.
set_rowwise_scale_inv
(
static_cast
<
void
*>
(
rhs_sinv_ptr
),
DType
::
kFloat8E8M0
,
rhs_sinv_shape
);
lhs_wrapper_list
.
push_back
(
std
::
move
(
lhs_i
));
rhs_wrapper_list
.
push_back
(
std
::
move
(
rhs_i
));
}
else
{
NVTE_ERROR
(
"Unsupported scaling mode: "
,
scaling_mode
);
}
auto
out_i
=
TensorWrapper
(
static_cast
<
void
*>
(
out_ptr
),
out_shape
,
out_dtype
);
lhs_ptr
+=
m
*
k
*
lhs_dtype_bytes
;
rhs_ptr
+=
n
*
k
*
rhs_dtype_bytes
;
out_ptr
+=
m
*
n
*
out_dtype_bytes
;
lhs_sinv_ptr
+=
lhs_sinv_shape
[
0
]
*
lhs_sinv_shape
[
1
]
*
lhs_sinv_dtype_bytes
;
rhs_sinv_ptr
+=
rhs_sinv_shape
[
0
]
*
rhs_sinv_shape
[
1
]
*
rhs_sinv_dtype_bytes
;
void
*
pre_gelu_ptr
=
nullptr
;
auto
bias_shape
=
std
::
vector
<
size_t
>
{
0
};
auto
pre_gelu_shape
=
std
::
vector
<
size_t
>
{
0
};
if
(
bias_ptr
!=
nullptr
)
bias_shape
[
0
]
=
n
;
auto
bias_i
=
TensorWrapper
(
bias_ptr
,
bias_shape
,
bias_dtype
);
if
(
bias_ptr
!=
nullptr
)
bias_ptr
+=
n
*
bias_dtype_bytes
;
auto
pre_gelu_i
=
TensorWrapper
(
pre_gelu_ptr
,
pre_gelu_shape
,
out_dtype
);
out_wrapper_list
.
push_back
(
std
::
move
(
out_i
));
bias_wrapper_list
.
push_back
(
std
::
move
(
bias_i
));
pre_gelu_wrapper_list
.
push_back
(
std
::
move
(
pre_gelu_i
));
lhs_list
.
push_back
(
lhs_wrapper_list
.
back
().
data
());
rhs_list
.
push_back
(
rhs_wrapper_list
.
back
().
data
());
bias_list
.
push_back
(
bias_wrapper_list
.
back
().
data
());
pre_gelu_list
.
push_back
(
pre_gelu_wrapper_list
.
back
().
data
());
out_list
.
push_back
(
out_wrapper_list
.
back
().
data
());
}
auto
workspace_shape
=
std
::
vector
<
size_t
>
{
workspace_size
};
for
(
int
i
=
0
;
i
<
num_streams
;
i
++
)
{
auto
workspace_i
=
TensorWrapper
(
static_cast
<
void
*>
(
workspace_ptr
),
workspace_shape
,
DType
::
kByte
);
workspace_wrapper_list
.
push_back
(
std
::
move
(
workspace_i
));
workspace_list
.
push_back
(
workspace_wrapper_list
.
back
().
data
());
workspace_ptr
+=
workspace_size
;
}
nvte_multi_stream_cublas_gemm
(
rhs_list
.
data
(),
lhs_list
.
data
(),
out_list
.
data
(),
bias_list
.
data
(),
pre_gelu_list
.
data
(),
num_gemms
,
trans_lhs
,
trans_rhs
,
grad
,
workspace_list
.
data
(),
accumulate
,
use_split_accumulator
,
num_math_sm
,
stream
);
return
ffi_with_cuda_error_check
();
}
Error_Type
GroupedGemmFFI
(
cudaStream_t
stream
,
Buffer_Type
lhs_flatten
,
Buffer_Type
lhs_sinv_flatten
,
Buffer_Type
rhs_flatten
,
Buffer_Type
rhs_sinv_flatten
,
Buffer_Type
bias_flatten
,
Buffer_Type
dim_list
,
Result_Type
out_flatten
,
Result_Type
workspace_flatten
,
int64_t
num_gemms
,
int64_t
scaling_mode
)
{
// Inputs
auto
lhs_ptr
=
reinterpret_cast
<
uint8_t
*>
(
lhs_flatten
.
untyped_data
());
auto
rhs_ptr
=
reinterpret_cast
<
uint8_t
*>
(
rhs_flatten
.
untyped_data
());
auto
lhs_sinv_ptr
=
reinterpret_cast
<
uint8_t
*>
(
lhs_sinv_flatten
.
untyped_data
());
auto
rhs_sinv_ptr
=
reinterpret_cast
<
uint8_t
*>
(
rhs_sinv_flatten
.
untyped_data
());
auto
bias_ptr
=
reinterpret_cast
<
uint8_t
*>
(
bias_flatten
.
untyped_data
());
auto
dim_list_ptr
=
reinterpret_cast
<
int32_t
*>
(
dim_list
.
untyped_data
());
auto
lhs_dtype
=
convert_ffi_datatype_to_te_dtype
(
lhs_flatten
.
element_type
());
auto
rhs_dtype
=
convert_ffi_datatype_to_te_dtype
(
rhs_flatten
.
element_type
());
auto
lhs_sinv_dtype
=
convert_ffi_datatype_to_te_dtype
(
lhs_sinv_flatten
.
element_type
());
auto
rhs_sinv_dtype
=
convert_ffi_datatype_to_te_dtype
(
rhs_sinv_flatten
.
element_type
());
auto
bias_dtype
=
convert_ffi_datatype_to_te_dtype
(
bias_flatten
.
element_type
());
// Outputs
auto
out_ptr
=
reinterpret_cast
<
uint8_t
*>
(
out_flatten
->
untyped_data
());
auto
out_dtype
=
convert_ffi_datatype_to_te_dtype
(
out_flatten
->
element_type
());
auto
workspace_ptr
=
reinterpret_cast
<
uint8_t
*>
(
workspace_flatten
->
untyped_data
());
auto
workspace_size
=
workspace_flatten
->
dimensions
().
back
()
/
num_streams
;
return
GroupedGemmImpl
(
lhs_ptr
,
lhs_dtype
,
lhs_sinv_ptr
,
lhs_sinv_dtype
,
rhs_ptr
,
rhs_dtype
,
rhs_sinv_ptr
,
rhs_sinv_dtype
,
bias_ptr
,
bias_dtype
,
out_ptr
,
out_dtype
,
workspace_ptr
,
workspace_size
,
num_gemms
,
dim_list_ptr
,
scaling_mode
,
stream
);
}
XLA_FFI_DEFINE_HANDLER_SYMBOL
(
GroupedGemmHandler
,
GroupedGemmFFI
,
FFI
::
Bind
()
.
Ctx
<
FFI_Stream_Type
>
()
// stream
.
Arg
<
Buffer_Type
>
()
// lhs_flatten
.
Arg
<
Buffer_Type
>
()
// lhs_sinv_flatten
.
Arg
<
Buffer_Type
>
()
// rhs_flatten
.
Arg
<
Buffer_Type
>
()
// rhs_sinv_flatten
.
Arg
<
Buffer_Type
>
()
// bias_flatten
.
Arg
<
Buffer_Type
>
()
// dim_list
.
Ret
<
Buffer_Type
>
()
// out_flatten
.
Ret
<
Buffer_Type
>
()
// workspace_flatten
.
Attr
<
int64_t
>
(
"num_gemms"
)
.
Attr
<
int64_t
>
(
"scaling_mode"
),
FFI_CudaGraph_Traits
);
}
// namespace jax
}
// namespace transformer_engine
transformer_engine/jax/csrc/extensions/misc.h
View file @
a207db1d
...
@@ -34,5 +34,11 @@ inline size_t product(const std::vector<size_t> &shape) {
...
@@ -34,5 +34,11 @@ inline size_t product(const std::vector<size_t> &shape) {
return
ret
;
return
ret
;
}
}
enum
class
QuantizeAxis
{
ROWWISE
,
COLWISE
,
ROWWISE_COLWISE
,
};
}
// namespace jax
}
// namespace jax
}
// namespace transformer_engine
}
// namespace transformer_engine
transformer_engine/jax/csrc/extensions/normalization.cpp
View file @
a207db1d
...
@@ -5,15 +5,18 @@
...
@@ -5,15 +5,18 @@
************************************************************************/
************************************************************************/
#include "transformer_engine/normalization.h"
#include "transformer_engine/normalization.h"
#include <cuda_runtime.h>
#include "extensions.h"
#include "extensions.h"
namespace
transformer_engine
{
namespace
transformer_engine
{
namespace
jax
{
namespace
jax
{
pybind11
::
tuple
GetLayerNormForwardWorkspaceSizes
(
size_t
batch_size
,
size_t
hidden_size
,
pybind11
::
tuple
GetNormForwardWorkspaceSizes
(
size_t
batch_size
,
size_t
hidden_size
,
DType
in_dtype
,
DType
in_dtype
,
DType
w_dtype
,
DType
out_dtype
,
DType
w_dtype
,
DType
out_dtype
,
bool
is_layer_norm
,
bool
zero_centered_gamma
,
NVTE_Norm_Type
norm_type
,
int
scaling_mode
,
float
eps
,
int
sm_margin
)
{
bool
zero_centered_gamma
,
float
epsilon
,
int
sm_margin
,
bool
is_training
)
{
auto
input_shape
=
std
::
vector
<
size_t
>
{
batch_size
,
hidden_size
};
auto
input_shape
=
std
::
vector
<
size_t
>
{
batch_size
,
hidden_size
};
auto
weight_shape
=
std
::
vector
<
size_t
>
{
hidden_size
};
auto
weight_shape
=
std
::
vector
<
size_t
>
{
hidden_size
};
auto
intermediates_shape
=
std
::
vector
<
size_t
>
{
batch_size
};
auto
intermediates_shape
=
std
::
vector
<
size_t
>
{
batch_size
};
...
@@ -21,23 +24,32 @@ pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidd
...
@@ -21,23 +24,32 @@ pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidd
// empty tensor wrappers are okay just to get workspace size
// empty tensor wrappers are okay just to get workspace size
auto
input_tensor
=
TensorWrapper
(
nullptr
,
input_shape
,
in_dtype
);
auto
input_tensor
=
TensorWrapper
(
nullptr
,
input_shape
,
in_dtype
);
auto
gamma_tensor
=
TensorWrapper
(
nullptr
,
weight_shape
,
in_dtype
);
auto
gamma_tensor
=
TensorWrapper
(
nullptr
,
weight_shape
,
in_dtype
);
auto
output_tensor
=
TensorWrapper
(
nullptr
,
input_shape
,
out_dtype
);
auto
rsigma_tensor
=
TensorWrapper
(
nullptr
,
intermediates_shape
,
DType
::
kFloat32
);
auto
rsigma_tensor
=
TensorWrapper
(
nullptr
,
intermediates_shape
,
DType
::
kFloat32
);
auto
_scaling_mode
=
static_cast
<
NVTEScalingMode
>
(
scaling_mode
);
auto
output_tensor
=
TensorWrapper
(
_scaling_mode
);
output_tensor
.
set_rowwise_data
(
nullptr
,
out_dtype
,
input_shape
);
// WAR: NVTE Norms query the is_training from whereas columwise_data is allocated
if
(
is_training
&&
_scaling_mode
==
NVTE_MXFP8_1D_SCALING
)
{
int
temp
=
1
;
output_tensor
.
set_columnwise_data
(
static_cast
<
void
*>
(
&
temp
),
out_dtype
,
input_shape
);
}
// dummy tensor wrappers that will carry workspace size info later
// dummy tensor wrappers that will carry workspace size info later
TensorWrapper
dummy_work_tensor
;
TensorWrapper
dummy_work_tensor
;
auto
num_sm
=
cudaDevicePropertiesManager
::
Instance
().
GetMultiProcessorCount
()
-
sm_margin
;
auto
num_sm
=
cudaDevicePropertiesManager
::
Instance
().
GetMultiProcessorCount
()
-
sm_margin
;
if
(
is_l
ayer
_n
orm
)
{
if
(
norm_type
==
NVTE_Norm_Type
::
L
ayer
N
orm
)
{
auto
beta_tensor
=
TensorWrapper
(
nullptr
,
weight_shape
,
w_dtype
);
auto
beta_tensor
=
TensorWrapper
(
nullptr
,
weight_shape
,
w_dtype
);
auto
mu_tensor
=
TensorWrapper
(
nullptr
,
intermediates_shape
,
DType
::
kFloat32
);
auto
mu_tensor
=
TensorWrapper
(
nullptr
,
intermediates_shape
,
DType
::
kFloat32
);
nvte_layernorm_fwd
(
input_tensor
.
data
(),
gamma_tensor
.
data
(),
beta_tensor
.
data
(),
eps
,
nvte_layernorm_fwd
(
input_tensor
.
data
(),
gamma_tensor
.
data
(),
beta_tensor
.
data
(),
eps
ilon
,
output_tensor
.
data
(),
mu_tensor
.
data
(),
rsigma_tensor
.
data
(),
output_tensor
.
data
(),
mu_tensor
.
data
(),
rsigma_tensor
.
data
(),
dummy_work_tensor
.
data
(),
num_sm
,
zero_centered_gamma
,
nullptr
);
dummy_work_tensor
.
data
(),
num_sm
,
zero_centered_gamma
,
nullptr
);
}
else
{
}
else
{
// TODO(Phuong): Verify and remove this check
NVTE_CHECK
(
scaling_mode
!=
NVTEScalingMode
::
NVTE_DELAYED_TENSOR_SCALING
||
!
zero_centered_gamma
,
NVTE_CHECK
(
!
zero_centered_gamma
,
"rmsnorm doesn't support zero_centered_gamma."
);
"rmsnorm doesn't support zero_centered_gamma."
);
nvte_rmsnorm_fwd
(
input_tensor
.
data
(),
gamma_tensor
.
data
(),
eps
,
output_tensor
.
data
(),
nvte_rmsnorm_fwd
(
input_tensor
.
data
(),
gamma_tensor
.
data
(),
eps
ilon
,
output_tensor
.
data
(),
rsigma_tensor
.
data
(),
dummy_work_tensor
.
data
(),
num_sm
,
zero_centered_gamma
,
rsigma_tensor
.
data
(),
dummy_work_tensor
.
data
(),
num_sm
,
zero_centered_gamma
,
nullptr
);
nullptr
);
}
}
...
@@ -46,232 +58,125 @@ pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidd
...
@@ -46,232 +58,125 @@ pybind11::tuple GetLayerNormForwardWorkspaceSizes(size_t batch_size, size_t hidd
return
pybind11
::
make_tuple
(
std
::
make_pair
(
work_shape
,
dummy_work_tensor
.
dtype
()));
return
pybind11
::
make_tuple
(
std
::
make_pair
(
work_shape
,
dummy_work_tensor
.
dtype
()));
}
}
void
LayerNormForwardImpl
(
size_t
batch_size
,
size_t
hidden_size
,
size_t
workspace_size
,
Error_Type
NormForwardFFI
(
cudaStream_t
stream
,
Buffer_Type
x_buf
,
Buffer_Type
scale_buf
,
bool
zero_centered_gamma
,
float
eps
,
void
*
input
,
DType
in_dtype
,
Buffer_Type
gamma_buf
,
Buffer_Type
beta_buf
,
Result_Type
output_buf
,
void
*
weight
,
DType
w_dtype
,
void
*
bias
,
void
*
output
,
DType
out_dtype
,
Result_Type
colwise_output_buf
,
Result_Type
scale_inv_buf
,
void
*
workspace
,
DType
work_dtype
,
void
*
mu
,
void
*
rsigma
,
float
*
amax
,
Result_Type
colwise_scale_inv_buf
,
Result_Type
amax_buf
,
float
*
scale
,
float
*
scale_inv
,
int
sm_margin
,
cudaStream_t
stream
)
{
Result_Type
mu_buf
,
Result_Type
rsigma_buf
,
Result_Type
wkspace_buf
,
int
norm_type
,
bool
zero_centered_gamma
,
double
epsilon
,
int64_t
sm_margin
,
int
scaling_mode
,
bool
is_2x
)
{
auto
in_dtype
=
convert_ffi_datatype_to_te_dtype
(
x_buf
.
element_type
());
auto
out_dtype
=
convert_ffi_datatype_to_te_dtype
(
output_buf
->
element_type
());
auto
w_dtype
=
convert_ffi_datatype_to_te_dtype
(
gamma_buf
.
element_type
());
auto
wkspace_dtype
=
convert_ffi_datatype_to_te_dtype
(
wkspace_buf
->
element_type
());
auto
*
input
=
x_buf
.
untyped_data
();
auto
*
scale
=
reinterpret_cast
<
float
*>
(
scale_buf
.
untyped_data
());
auto
*
gamma
=
gamma_buf
.
untyped_data
();
auto
*
beta
=
beta_buf
.
untyped_data
();
auto
*
output
=
output_buf
->
untyped_data
();
auto
*
rsigma
=
rsigma_buf
->
untyped_data
();
auto
*
mu
=
mu_buf
->
untyped_data
();
auto
*
amax
=
reinterpret_cast
<
float
*>
(
amax_buf
->
untyped_data
());
auto
*
workspace
=
wkspace_buf
->
untyped_data
();
auto
_scaling_mode
=
static_cast
<
NVTEScalingMode
>
(
scaling_mode
);
auto
_norm_type
=
static_cast
<
NVTE_Norm_Type
>
(
norm_type
);
auto
_is_2x
=
static_cast
<
bool
>
(
is_2x
);
auto
x_size
=
product
(
x_buf
.
dimensions
());
auto
gamma_size
=
product
(
gamma_buf
.
dimensions
());
auto
workspace_size
=
product
(
wkspace_buf
->
dimensions
());
auto
hidden_size
=
gamma_size
;
auto
batch_size
=
x_size
/
gamma_size
;
float
_epsilon
=
static_cast
<
float
>
(
epsilon
);
int
_sm_margin
=
static_cast
<
int
>
(
sm_margin
);
auto
input_shape
=
std
::
vector
<
size_t
>
{
batch_size
,
hidden_size
};
auto
input_shape
=
std
::
vector
<
size_t
>
{
batch_size
,
hidden_size
};
auto
weight
_shape
=
std
::
vector
<
size_t
>
{
hidden_size
};
auto
gamma
_shape
=
std
::
vector
<
size_t
>
{
hidden_size
};
auto
intermediates_shape
=
std
::
vector
<
size_t
>
{
batch_size
};
auto
intermediates_shape
=
std
::
vector
<
size_t
>
{
batch_size
};
auto
workspace_shape
=
std
::
vector
<
size_t
>
{
workspace_size
};
auto
workspace_shape
=
std
::
vector
<
size_t
>
{
workspace_size
};
auto
is_layer_norm
=
(
bias
)
?
true
:
false
;
auto
input_tensor
=
TensorWrapper
(
input
,
input_shape
,
in_dtype
);
auto
input_tensor
=
TensorWrapper
(
input
,
input_shape
,
in_dtype
);
auto
gamma_tensor
=
TensorWrapper
(
weight
,
weight
_shape
,
in_dtype
);
auto
gamma_tensor
=
TensorWrapper
(
gamma
,
gamma
_shape
,
in_dtype
);
// assume output dtype = input dtype
// If we need mixed I/O precision in the future, we need an additional
// parameter for output type
auto
output_tensor
=
TensorWrapper
(
output
,
input_shape
,
out_dtype
,
amax
,
scale
,
scale_inv
);
auto
rsigma_tensor
=
TensorWrapper
(
rsigma
,
intermediates_shape
,
DType
::
kFloat32
);
auto
rsigma_tensor
=
TensorWrapper
(
rsigma
,
intermediates_shape
,
DType
::
kFloat32
);
auto
num_sm
=
cudaDevicePropertiesManager
::
Instance
().
GetMultiProcessorCount
()
-
_sm_margin
;
auto
workspace_tensor
=
TensorWrapper
(
workspace
,
workspace_shape
,
wkspace_dtype
);
auto
num_sm
=
cudaDevicePropertiesManager
::
Instance
().
GetMultiProcessorCount
()
-
sm_margin
;
auto
output_tensor
=
TensorWrapper
(
_scaling_mode
);
output_tensor
.
set_rowwise_data
(
output
,
static_cast
<
DType
>
(
out_dtype
),
input_shape
);
auto
workspace_tensor
=
TensorWrapper
(
workspace
,
workspace_shape
,
work_dtype
);
if
(
is_fp8_dtype
(
out_dtype
))
{
output_tensor
.
set_rowwise_scale_inv
(
scale_inv_buf
->
untyped_data
(),
convert_ffi_datatype_to_te_dtype
(
scale_inv_buf
->
element_type
()),
std
::
vector
<
size_t
>
{
product
(
scale_inv_buf
->
dimensions
(),
0
,
scale_inv_buf
->
dimensions
().
size
()
-
1
),
scale_inv_buf
->
dimensions
().
back
()});
}
if
(
_scaling_mode
==
NVTE_DELAYED_TENSOR_SCALING
&&
is_fp8_dtype
(
out_dtype
))
{
output_tensor
.
set_scale
(
scale
,
DType
::
kFloat32
,
std
::
vector
<
size_t
>
{
1
});
cudaMemsetAsync
(
amax
,
0
,
sizeof
(
float
),
stream
);
output_tensor
.
set_amax
(
amax
,
DType
::
kFloat32
,
std
::
vector
<
size_t
>
{
1
});
}
if
(
_is_2x
)
{
output_tensor
.
set_columnwise_data
(
colwise_output_buf
->
untyped_data
(),
static_cast
<
DType
>
(
out_dtype
),
input_shape
);
output_tensor
.
set_columnwise_scale_inv
(
colwise_scale_inv_buf
->
untyped_data
(),
convert_ffi_datatype_to_te_dtype
(
colwise_scale_inv_buf
->
element_type
()),
std
::
vector
<
size_t
>
{
product
(
colwise_scale_inv_buf
->
dimensions
(),
0
,
colwise_scale_inv_buf
->
dimensions
().
size
()
-
1
),
colwise_scale_inv_buf
->
dimensions
().
back
()});
}
if
(
is_layer_n
orm
)
{
if
(
_norm_type
==
NVTE_Norm_Type
::
LayerN
orm
)
{
auto
beta_tensor
=
TensorWrapper
(
b
ias
,
weight
_shape
,
w_dtype
);
auto
beta_tensor
=
TensorWrapper
(
b
eta
,
gamma
_shape
,
w_dtype
);
auto
mu_tensor
=
TensorWrapper
(
mu
,
intermediates_shape
,
DType
::
kFloat32
);
auto
mu_tensor
=
TensorWrapper
(
mu
,
intermediates_shape
,
DType
::
kFloat32
);
nvte_layernorm_fwd
(
input_tensor
.
data
(),
gamma_tensor
.
data
(),
beta_tensor
.
data
(),
eps
,
nvte_layernorm_fwd
(
input_tensor
.
data
(),
gamma_tensor
.
data
(),
beta_tensor
.
data
(),
_
eps
ilon
,
output_tensor
.
data
(),
mu_tensor
.
data
(),
rsigma_tensor
.
data
(),
output_tensor
.
data
(),
mu_tensor
.
data
(),
rsigma_tensor
.
data
(),
workspace_tensor
.
data
(),
num_sm
,
zero_centered_gamma
,
stream
);
workspace_tensor
.
data
(),
num_sm
,
zero_centered_gamma
,
stream
);
}
else
{
}
else
{
NVTE_CHECK
(
!
zero_centered_gamma
,
"rmsnorm doesn't support zero_centered_gamma."
);
NVTE_CHECK
(
scaling_mode
!=
NVTEScalingMode
::
NVTE_DELAYED_TENSOR_SCALING
||
!
zero_centered_gamma
,
nvte_rmsnorm_fwd
(
input_tensor
.
data
(),
gamma_tensor
.
data
(),
eps
,
output_tensor
.
data
(),
"rmsnorm doesn't support zero_centered_gamma."
);
nvte_rmsnorm_fwd
(
input_tensor
.
data
(),
gamma_tensor
.
data
(),
_epsilon
,
output_tensor
.
data
(),
rsigma_tensor
.
data
(),
workspace_tensor
.
data
(),
num_sm
,
zero_centered_gamma
,
rsigma_tensor
.
data
(),
workspace_tensor
.
data
(),
num_sm
,
zero_centered_gamma
,
stream
);
stream
);
}
}
}
Error_Type
LayerNormForwardImplFFI
(
cudaStream_t
stream
,
Buffer_Type
*
x_buf
,
Buffer_Type
*
gamma_buf
,
Buffer_Type
*
beta_buf
,
Buffer_Type
*
amax_buf
,
Buffer_Type
*
scale_buf
,
Buffer_Type
*
scale_inv_buf
,
Result_Type
*
output_buf
,
Result_Type
*
mu_buf
,
Result_Type
*
rsigma_buf
,
Result_Type
*
amax_out_buf
,
Result_Type
*
wkspace_buf
,
bool
zero_centered_gamma
,
double
eps_
,
int64_t
sm_margin_
,
bool
is_layer_norm
,
bool
is_fp8
)
{
auto
in_dtype
=
convert_ffi_datatype_to_te_dtype
((
*
x_buf
).
element_type
());
auto
w_dtype
=
convert_ffi_datatype_to_te_dtype
((
*
gamma_buf
).
element_type
());
auto
wkspace_dtype
=
convert_ffi_datatype_to_te_dtype
((
*
wkspace_buf
)
->
element_type
());
auto
*
input
=
x_buf
->
untyped_data
();
auto
*
weight
=
gamma_buf
->
untyped_data
();
auto
*
output
=
(
*
output_buf
)
->
untyped_data
();
auto
*
rsigma
=
(
*
rsigma_buf
)
->
untyped_data
();
auto
*
workspace
=
(
*
wkspace_buf
)
->
untyped_data
();
void
*
bias
=
nullptr
;
void
*
mu
=
nullptr
;
if
(
is_layer_norm
)
{
bias
=
beta_buf
->
untyped_data
();
mu
=
(
*
mu_buf
)
->
untyped_data
();
}
float
*
amax
=
nullptr
;
float
*
scale
=
nullptr
;
float
*
scale_inv
=
nullptr
;
void
*
amax_out
=
nullptr
;
auto
out_dtype
=
in_dtype
;
if
(
is_fp8
)
{
amax
=
reinterpret_cast
<
float
*>
(
amax_buf
->
untyped_data
());
scale
=
reinterpret_cast
<
float
*>
(
scale_buf
->
untyped_data
());
scale_inv
=
reinterpret_cast
<
float
*>
(
scale_inv_buf
->
untyped_data
());
amax_out
=
(
*
amax_out_buf
)
->
untyped_data
();
NVTE_CHECK
(
amax_out
==
amax
,
"amax not bound to amax_out in TE/JAX LayerNormForward primitive"
);
out_dtype
=
DType
::
kFloat8E4M3
;
}
auto
x_size
=
product
(
x_buf
->
dimensions
());
auto
gamma_size
=
product
(
gamma_buf
->
dimensions
());
auto
wkspace_size
=
product
((
*
wkspace_buf
)
->
dimensions
());
auto
hidden_size
=
gamma_size
;
auto
batch_size
=
x_size
/
gamma_size
;
float
eps
=
static_cast
<
float
>
(
eps_
);
int
sm_margin
=
static_cast
<
int
>
(
sm_margin_
);
LayerNormForwardImpl
(
batch_size
,
hidden_size
,
wkspace_size
,
zero_centered_gamma
,
eps
,
input
,
in_dtype
,
weight
,
w_dtype
,
bias
,
output
,
out_dtype
,
workspace
,
wkspace_dtype
,
mu
,
rsigma
,
amax
,
scale
,
scale_inv
,
sm_margin
,
stream
);
return
ffi_with_cuda_error_check
();
return
ffi_with_cuda_error_check
();
}
}
Error_Type
LayerNormForwardFP8FFI
(
cudaStream_t
stream
,
Buffer_Type
x_buf
,
Buffer_Type
gamma_buf
,
XLA_FFI_DEFINE_HANDLER_SYMBOL
(
NormForwardHandler
,
NormForwardFFI
,
Buffer_Type
beta_buf
,
Buffer_Type
amax_buf
,
Buffer_Type
scale_buf
,
Buffer_Type
scale_inv_buf
,
Result_Type
output_buf
,
Result_Type
mu_buf
,
Result_Type
rsigma_buf
,
Result_Type
amax_out_buf
,
Result_Type
wkspace_buf
,
bool
zero_centered_gamma
,
double
eps_
,
int64_t
sm_margin_
)
{
return
LayerNormForwardImplFFI
(
stream
,
&
x_buf
,
&
gamma_buf
,
&
beta_buf
,
&
amax_buf
,
&
scale_buf
,
&
scale_inv_buf
,
&
output_buf
,
&
mu_buf
,
&
rsigma_buf
,
&
amax_out_buf
,
&
wkspace_buf
,
zero_centered_gamma
,
eps_
,
sm_margin_
,
true
,
// is_layer_norm
true
// is_fp8
);
}
XLA_FFI_DEFINE_HANDLER_SYMBOL
(
LayerNormForwardFP8Handler
,
LayerNormForwardFP8FFI
,
FFI
::
Bind
()
FFI
::
Bind
()
.
Ctx
<
FFI_Stream_Type
>
()
// stream
.
Ctx
<
FFI_Stream_Type
>
()
// stream
.
Arg
<
Buffer_Type
>
()
// x
.
Arg
<
Buffer_Type
>
()
// x
.
Arg
<
Buffer_Type
>
()
// gamma
.
Arg
<
Buffer_Type
>
()
// beta
.
Arg
<
Buffer_Type
>
()
// amax
.
Arg
<
Buffer_Type
>
()
// scale
.
Arg
<
Buffer_Type
>
()
// scale
.
Arg
<
Buffer_Type
>
()
// scale_inv
.
Ret
<
Buffer_Type
>
()
// output
.
Ret
<
Buffer_Type
>
()
// mu
.
Ret
<
Buffer_Type
>
()
// rsigma
.
Ret
<
Buffer_Type
>
()
// amax_out
.
Ret
<
Buffer_Type
>
()
// wkspace
.
Attr
<
bool
>
(
"zero_centered_gamma"
)
.
Attr
<
double
>
(
"eps"
)
.
Attr
<
int64_t
>
(
"sm_margin"
),
FFI_CudaGraph_Traits
);
Error_Type
LayerNormForwardFFI
(
cudaStream_t
stream
,
Buffer_Type
x_buf
,
Buffer_Type
gamma_buf
,
Buffer_Type
beta_buf
,
Result_Type
output_buf
,
Result_Type
mu_buf
,
Result_Type
rsigma_buf
,
Result_Type
wkspace_buf
,
bool
zero_centered_gamma
,
double
eps_
,
int64_t
sm_margin_
)
{
return
LayerNormForwardImplFFI
(
stream
,
&
x_buf
,
&
gamma_buf
,
&
beta_buf
,
nullptr
,
// amax_buf
nullptr
,
// scale_buf,
nullptr
,
// scale_inv_buf,
&
output_buf
,
&
mu_buf
,
&
rsigma_buf
,
nullptr
,
// amax_out_buf,
&
wkspace_buf
,
zero_centered_gamma
,
eps_
,
sm_margin_
,
true
,
// is_layer_norm
false
// is_fp8
);
}
XLA_FFI_DEFINE_HANDLER_SYMBOL
(
LayerNormForwardHandler
,
LayerNormForwardFFI
,
FFI
::
Bind
()
.
Ctx
<
FFI_Stream_Type
>
()
// stream
.
Arg
<
Buffer_Type
>
()
// x
.
Arg
<
Buffer_Type
>
()
// gamma
.
Arg
<
Buffer_Type
>
()
// gamma
.
Arg
<
Buffer_Type
>
()
// beta
.
Arg
<
Buffer_Type
>
()
// beta
.
Ret
<
Buffer_Type
>
()
// output
.
Ret
<
Buffer_Type
>
()
// output
.
Ret
<
Buffer_Type
>
()
// colwise_output
.
Ret
<
Buffer_Type
>
()
// scale_inv
.
Ret
<
Buffer_Type
>
()
// colwise_scale_inv
.
Ret
<
Buffer_Type
>
()
// amax
.
Ret
<
Buffer_Type
>
()
// mu
.
Ret
<
Buffer_Type
>
()
// mu
.
Ret
<
Buffer_Type
>
()
// rsigma
.
Ret
<
Buffer_Type
>
()
// rsigma
.
Ret
<
Buffer_Type
>
()
// wkspace
.
Ret
<
Buffer_Type
>
()
// wkspace
.
Attr
<
int64_t
>
(
"norm_type"
)
.
Attr
<
bool
>
(
"zero_centered_gamma"
)
.
Attr
<
bool
>
(
"zero_centered_gamma"
)
.
Attr
<
double
>
(
"eps"
)
.
Attr
<
double
>
(
"epsilon"
)
.
Attr
<
int64_t
>
(
"sm_margin"
),
.
Attr
<
int64_t
>
(
"sm_margin"
)
FFI_CudaGraph_Traits
);
.
Attr
<
int64_t
>
(
"scaling_mode"
)
.
Attr
<
bool
>
(
"is_2x"
),
Error_Type
RMSNormForwardFP8FFI
(
cudaStream_t
stream
,
Buffer_Type
x_buf
,
Buffer_Type
gamma_buf
,
Buffer_Type
amax_buf
,
Buffer_Type
scale_buf
,
Buffer_Type
scale_inv_buf
,
Result_Type
output_buf
,
Result_Type
rsigma_buf
,
Result_Type
amax_out_buf
,
Result_Type
wkspace_buf
,
bool
zero_centered_gamma
,
double
eps_
,
int64_t
sm_margin_
)
{
return
LayerNormForwardImplFFI
(
stream
,
&
x_buf
,
&
gamma_buf
,
nullptr
,
// beta_buf,
&
amax_buf
,
&
scale_buf
,
&
scale_inv_buf
,
&
output_buf
,
nullptr
,
// mu_buf,
&
rsigma_buf
,
&
amax_out_buf
,
&
wkspace_buf
,
zero_centered_gamma
,
eps_
,
sm_margin_
,
false
,
// is_layer_norm
true
// is_fp8
);
}
XLA_FFI_DEFINE_HANDLER_SYMBOL
(
RMSNormForwardFP8Handler
,
RMSNormForwardFP8FFI
,
FFI
::
Bind
()
.
Ctx
<
FFI_Stream_Type
>
()
// stream
.
Arg
<
Buffer_Type
>
()
// x
.
Arg
<
Buffer_Type
>
()
// gamma
.
Arg
<
Buffer_Type
>
()
// amax
.
Arg
<
Buffer_Type
>
()
// scale
.
Arg
<
Buffer_Type
>
()
// scale_inv
.
Ret
<
Buffer_Type
>
()
// output
.
Ret
<
Buffer_Type
>
()
// rsigma
.
Ret
<
Buffer_Type
>
()
// amax_out
.
Ret
<
Buffer_Type
>
()
// wkspace
.
Attr
<
bool
>
(
"zero_centered_gamma"
)
.
Attr
<
double
>
(
"eps"
)
.
Attr
<
int64_t
>
(
"sm_margin"
),
FFI_CudaGraph_Traits
);
Error_Type
RMSNormForwardFFI
(
cudaStream_t
stream
,
Buffer_Type
x_buf
,
Buffer_Type
gamma_buf
,
Result_Type
output_buf
,
Result_Type
rsigma_buf
,
Result_Type
wkspace_buf
,
bool
zero_centered_gamma
,
double
eps_
,
int64_t
sm_margin_
)
{
return
LayerNormForwardImplFFI
(
stream
,
&
x_buf
,
&
gamma_buf
,
nullptr
,
// beta_buf,
nullptr
,
// amax_buf,
nullptr
,
// scale_buf,
nullptr
,
// scale_inv_buf,
&
output_buf
,
nullptr
,
// mu_buf,
&
rsigma_buf
,
nullptr
,
// amax_out_buf,
&
wkspace_buf
,
zero_centered_gamma
,
eps_
,
sm_margin_
,
false
,
// is_layer_norm
false
// is_fp8
);
}
XLA_FFI_DEFINE_HANDLER_SYMBOL
(
RMSNormForwardHandler
,
RMSNormForwardFFI
,
FFI
::
Bind
()
.
Ctx
<
FFI_Stream_Type
>
()
// stream
.
Arg
<
Buffer_Type
>
()
// x
.
Arg
<
Buffer_Type
>
()
// gamma
.
Ret
<
Buffer_Type
>
()
// output
.
Ret
<
Buffer_Type
>
()
// rsigma
.
Ret
<
Buffer_Type
>
()
// wkspace
.
Attr
<
bool
>
(
"zero_centered_gamma"
)
.
Attr
<
double
>
(
"eps"
)
.
Attr
<
int64_t
>
(
"sm_margin"
),
FFI_CudaGraph_Traits
);
FFI_CudaGraph_Traits
);
pybind11
::
tuple
GetLayerNormBackwardWorkspaceSizes
(
size_t
batch_size
,
size_t
hidden_size
,
pybind11
::
tuple
GetNormBackwardWorkspaceSizes
(
size_t
batch_size
,
size_t
hidden_size
,
DType
in_dtype
,
DType
in_dtype
,
DType
w_dtype
,
DType
w_dtype
,
NVTE_Norm_Type
norm_type
,
bool
is_layer_norm
,
bool
zero_centered_gamma
,
bool
zero_centered_gamma
,
int
sm_margin
)
{
float
eps
,
int
sm_margin
)
{
auto
input_shape
=
std
::
vector
<
size_t
>
{
batch_size
,
hidden_size
};
auto
input_shape
=
std
::
vector
<
size_t
>
{
batch_size
,
hidden_size
};
auto
weight_shape
=
std
::
vector
<
size_t
>
{
hidden_size
};
auto
weight_shape
=
std
::
vector
<
size_t
>
{
hidden_size
};
auto
intermediates_shape
=
std
::
vector
<
size_t
>
{
batch_size
};
auto
intermediates_shape
=
std
::
vector
<
size_t
>
{
batch_size
};
...
@@ -289,7 +194,7 @@ pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hid
...
@@ -289,7 +194,7 @@ pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hid
TensorWrapper
dummy_work_tensor
;
TensorWrapper
dummy_work_tensor
;
auto
num_sm
=
cudaDevicePropertiesManager
::
Instance
().
GetMultiProcessorCount
()
-
sm_margin
;
auto
num_sm
=
cudaDevicePropertiesManager
::
Instance
().
GetMultiProcessorCount
()
-
sm_margin
;
if
(
is_l
ayer
_n
orm
)
{
if
(
norm_type
==
NVTE_Norm_Type
::
L
ayer
N
orm
)
{
auto
mu_tensor
=
TensorWrapper
(
nullptr
,
intermediates_shape
,
intermediates_dtype
);
auto
mu_tensor
=
TensorWrapper
(
nullptr
,
intermediates_shape
,
intermediates_dtype
);
auto
dbeta_tensor
=
TensorWrapper
(
nullptr
,
weight_shape
,
w_dtype
);
auto
dbeta_tensor
=
TensorWrapper
(
nullptr
,
weight_shape
,
w_dtype
);
...
@@ -309,16 +214,37 @@ pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hid
...
@@ -309,16 +214,37 @@ pybind11::tuple GetLayerNormBackwardWorkspaceSizes(size_t batch_size, size_t hid
return
pybind11
::
make_tuple
(
std
::
make_pair
(
work_shape
,
dummy_work_tensor
.
dtype
()));
return
pybind11
::
make_tuple
(
std
::
make_pair
(
work_shape
,
dummy_work_tensor
.
dtype
()));
}
}
void
LayerNormBackwardImpl
(
size_t
batch_size
,
size_t
hidden_size
,
size_t
wkspace_size
,
Error_Type
NormBackwardFFI
(
cudaStream_t
stream
,
Buffer_Type
dz_buf
,
Buffer_Type
x_buf
,
bool
zero_centered_gamma
,
float
eps
,
void
*
input
,
DType
in_dtype
,
Buffer_Type
mu_buf
,
Buffer_Type
rsigma_buf
,
Buffer_Type
gamma_buf
,
void
*
weight
,
DType
w_dtype
,
void
*
ograd
,
void
*
workspace
,
Result_Type
xgrad_buf
,
Result_Type
wgrad_buf
,
Result_Type
dbeta_buf
,
DType
wkspace_dtype
,
void
*
mu
,
void
*
rsigma
,
void
*
xgrad
,
void
*
wgrad
,
Result_Type
wkspace_buf
,
int64_t
norm_type
,
bool
zero_centered_gamma
,
void
*
dbeta
,
int
sm_margin
,
cudaStream_t
stream
)
{
int64_t
sm_margin
)
{
auto
in_dtype
=
convert_ffi_datatype_to_te_dtype
(
x_buf
.
element_type
());
auto
w_dtype
=
convert_ffi_datatype_to_te_dtype
(
gamma_buf
.
element_type
());
auto
wkspace_dtype
=
convert_ffi_datatype_to_te_dtype
(
wkspace_buf
->
element_type
());
auto
*
ograd
=
dz_buf
.
untyped_data
();
auto
*
input
=
x_buf
.
untyped_data
();
void
*
mu
=
mu_buf
.
untyped_data
();
auto
*
rsigma
=
rsigma_buf
.
untyped_data
();
auto
*
gamma
=
gamma_buf
.
untyped_data
();
auto
*
xgrad
=
xgrad_buf
->
untyped_data
();
auto
*
wgrad
=
wgrad_buf
->
untyped_data
();
void
*
dbeta
=
dbeta_buf
->
untyped_data
();
auto
*
workspace
=
wkspace_buf
->
untyped_data
();
auto
x_size
=
product
(
x_buf
.
dimensions
());
auto
gamma_size
=
product
(
gamma_buf
.
dimensions
());
auto
wkspace_size
=
product
(
wkspace_buf
->
dimensions
());
auto
hidden_size
=
gamma_size
;
auto
batch_size
=
x_size
/
gamma_size
;
int
_sm_margin
=
static_cast
<
int
>
(
sm_margin
);
auto
input_shape
=
std
::
vector
<
size_t
>
{
batch_size
,
hidden_size
};
auto
input_shape
=
std
::
vector
<
size_t
>
{
batch_size
,
hidden_size
};
auto
weight_shape
=
std
::
vector
<
size_t
>
{
hidden_size
};
auto
weight_shape
=
std
::
vector
<
size_t
>
{
hidden_size
};
auto
intermediates_shape
=
std
::
vector
<
size_t
>
{
batch_size
};
auto
intermediates_shape
=
std
::
vector
<
size_t
>
{
batch_size
};
auto
intermediates_dtype
=
DType
::
kFloat32
;
auto
intermediates_dtype
=
DType
::
kFloat32
;
auto
is_layer_norm
=
(
dbeta
)
?
true
:
false
;
// assume input type = output type
// assume input type = output type
auto
*
grad_output
=
ograd
;
auto
*
grad_output
=
ograd
;
...
@@ -327,19 +253,18 @@ void LayerNormBackwardImpl(size_t batch_size, size_t hidden_size, size_t wkspace
...
@@ -327,19 +253,18 @@ void LayerNormBackwardImpl(size_t batch_size, size_t hidden_size, size_t wkspace
auto
rsigma_tensor
=
TensorWrapper
(
rsigma
,
intermediates_shape
,
intermediates_dtype
);
auto
rsigma_tensor
=
TensorWrapper
(
rsigma
,
intermediates_shape
,
intermediates_dtype
);
auto
*
x
=
input
;
auto
x_tensor
=
TensorWrapper
(
input
,
input_shape
,
x_dtype
);
auto
x_tensor
=
TensorWrapper
(
x
,
input_shape
,
x_dtype
);
auto
gamma_tensor
=
TensorWrapper
(
weight
,
weight_shape
,
w_dtype
);
auto
gamma_tensor
=
TensorWrapper
(
gamma
,
weight_shape
,
w_dtype
);
auto
xgrad_tensor
=
TensorWrapper
(
xgrad
,
input_shape
,
x_dtype
);
auto
xgrad_tensor
=
TensorWrapper
(
xgrad
,
input_shape
,
x_dtype
);
auto
wgrad_tensor
=
TensorWrapper
(
wgrad
,
weight_shape
,
w_dtype
);
auto
wgrad_tensor
=
TensorWrapper
(
wgrad
,
weight_shape
,
w_dtype
);
auto
num_sm
=
cudaDevicePropertiesManager
::
Instance
().
GetMultiProcessorCount
()
-
sm_margin
;
auto
num_sm
=
cudaDevicePropertiesManager
::
Instance
().
GetMultiProcessorCount
()
-
_
sm_margin
;
auto
workspace_shape
=
std
::
vector
<
size_t
>
{
wkspace_size
};
auto
workspace_shape
=
std
::
vector
<
size_t
>
{
wkspace_size
};
auto
workspace_tensor
=
TensorWrapper
(
workspace
,
workspace_shape
,
wkspace_dtype
);
auto
workspace_tensor
=
TensorWrapper
(
workspace
,
workspace_shape
,
wkspace_dtype
);
if
(
is_layer_n
orm
)
{
if
(
static_cast
<
NVTE_Norm_Type
>
(
norm_type
)
==
NVTE_Norm_Type
::
LayerN
orm
)
{
auto
mu_tensor
=
TensorWrapper
(
mu
,
intermediates_shape
,
intermediates_dtype
);
auto
mu_tensor
=
TensorWrapper
(
mu
,
intermediates_shape
,
intermediates_dtype
);
auto
dbeta_tensor
=
TensorWrapper
(
dbeta
,
weight_shape
,
w_dtype
);
auto
dbeta_tensor
=
TensorWrapper
(
dbeta
,
weight_shape
,
w_dtype
);
...
@@ -353,61 +278,11 @@ void LayerNormBackwardImpl(size_t batch_size, size_t hidden_size, size_t wkspace
...
@@ -353,61 +278,11 @@ void LayerNormBackwardImpl(size_t batch_size, size_t hidden_size, size_t wkspace
xgrad_tensor
.
data
(),
wgrad_tensor
.
data
(),
workspace_tensor
.
data
(),
num_sm
,
xgrad_tensor
.
data
(),
wgrad_tensor
.
data
(),
workspace_tensor
.
data
(),
num_sm
,
zero_centered_gamma
,
stream
);
zero_centered_gamma
,
stream
);
}
}
}
Error_Type
LayerNormBackwardImplFFI
(
cudaStream_t
stream
,
Buffer_Type
*
dz_buf
,
Buffer_Type
*
x_buf
,
Buffer_Type
*
mu_buf
,
Buffer_Type
*
rsigma_buf
,
Buffer_Type
*
gamma_buf
,
Result_Type
*
xgrad_buf
,
Result_Type
*
wgrad_buf
,
Result_Type
*
dbeta_buf
,
Result_Type
*
wkspace_buf
,
bool
zero_centered_gamma
,
double
eps_
,
int64_t
sm_margin_
,
bool
is_layer_norm
)
{
auto
in_dtype
=
convert_ffi_datatype_to_te_dtype
(
x_buf
->
element_type
());
auto
w_dtype
=
convert_ffi_datatype_to_te_dtype
(
gamma_buf
->
element_type
());
auto
wkspace_dtype
=
convert_ffi_datatype_to_te_dtype
((
*
wkspace_buf
)
->
element_type
());
auto
*
ograd
=
dz_buf
->
untyped_data
();
auto
*
rsigma
=
rsigma_buf
->
untyped_data
();
auto
*
input
=
x_buf
->
untyped_data
();
auto
*
weight
=
gamma_buf
->
untyped_data
();
auto
*
xgrad
=
(
*
xgrad_buf
)
->
untyped_data
();
auto
*
wgrad
=
(
*
wgrad_buf
)
->
untyped_data
();
auto
*
workspace
=
(
*
wkspace_buf
)
->
untyped_data
();
void
*
mu
=
nullptr
;
void
*
dbeta
=
nullptr
;
if
(
is_layer_norm
)
{
mu
=
(
*
mu_buf
).
untyped_data
();
dbeta
=
(
*
dbeta_buf
)
->
untyped_data
();
}
auto
x_size
=
product
(
x_buf
->
dimensions
());
auto
gamma_size
=
product
(
gamma_buf
->
dimensions
());
auto
wkspace_size
=
product
((
*
wkspace_buf
)
->
dimensions
());
auto
hidden_size
=
gamma_size
;
auto
batch_size
=
x_size
/
gamma_size
;
float
eps
=
static_cast
<
float
>
(
eps_
);
int
sm_margin
=
static_cast
<
int
>
(
sm_margin_
);
LayerNormBackwardImpl
(
batch_size
,
hidden_size
,
wkspace_size
,
zero_centered_gamma
,
eps
,
input
,
in_dtype
,
weight
,
w_dtype
,
ograd
,
workspace
,
wkspace_dtype
,
mu
,
rsigma
,
xgrad
,
wgrad
,
dbeta
,
sm_margin
,
stream
);
return
ffi_with_cuda_error_check
();
return
ffi_with_cuda_error_check
();
}
}
Error_Type
LayerNormBackwardFFI
(
cudaStream_t
stream
,
Buffer_Type
dz_buf
,
Buffer_Type
x_buf
,
XLA_FFI_DEFINE_HANDLER_SYMBOL
(
NormBackwardHandler
,
NormBackwardFFI
,
Buffer_Type
mu_buf
,
Buffer_Type
rsigma_buf
,
Buffer_Type
gamma_buf
,
Result_Type
xgrad_buf
,
Result_Type
wgrad_buf
,
Result_Type
dbeta_buf
,
Result_Type
wkspace_buf
,
bool
zero_centered_gamma
,
double
eps_
,
int64_t
sm_margin_
)
{
return
LayerNormBackwardImplFFI
(
stream
,
&
dz_buf
,
&
x_buf
,
&
mu_buf
,
&
rsigma_buf
,
&
gamma_buf
,
&
xgrad_buf
,
&
wgrad_buf
,
&
dbeta_buf
,
&
wkspace_buf
,
zero_centered_gamma
,
eps_
,
sm_margin_
,
true
// is_layer_norm
);
}
XLA_FFI_DEFINE_HANDLER_SYMBOL
(
LayerNormBackwardHandler
,
LayerNormBackwardFFI
,
FFI
::
Bind
()
FFI
::
Bind
()
.
Ctx
<
FFI_Stream_Type
>
()
// stream
.
Ctx
<
FFI_Stream_Type
>
()
// stream
.
Arg
<
Buffer_Type
>
()
// dz
.
Arg
<
Buffer_Type
>
()
// dz
...
@@ -419,220 +294,10 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(LayerNormBackwardHandler, LayerNormBackwardFFI,
...
@@ -419,220 +294,10 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(LayerNormBackwardHandler, LayerNormBackwardFFI,
.
Ret
<
Buffer_Type
>
()
// wgrad
.
Ret
<
Buffer_Type
>
()
// wgrad
.
Ret
<
Buffer_Type
>
()
// dbeta
.
Ret
<
Buffer_Type
>
()
// dbeta
.
Ret
<
Buffer_Type
>
()
// wkspace
.
Ret
<
Buffer_Type
>
()
// wkspace
.
Attr
<
int64_t
>
(
"norm_type"
)
.
Attr
<
bool
>
(
"zero_centered_gamma"
)
.
Attr
<
bool
>
(
"zero_centered_gamma"
)
.
Attr
<
double
>
(
"eps"
)
.
Attr
<
int64_t
>
(
"sm_margin"
),
.
Attr
<
int64_t
>
(
"sm_margin"
),
FFI_CudaGraph_Traits
);
FFI_CudaGraph_Traits
);
Error_Type
RMSNormBackwardFFI
(
cudaStream_t
stream
,
Buffer_Type
dz_buf
,
Buffer_Type
x_buf
,
Buffer_Type
rsigma_buf
,
Buffer_Type
gamma_buf
,
Result_Type
xgrad_buf
,
Result_Type
wgrad_buf
,
Result_Type
wkspace_buf
,
bool
zero_centered_gamma
,
double
eps_
,
int64_t
sm_margin_
)
{
return
LayerNormBackwardImplFFI
(
stream
,
&
dz_buf
,
&
x_buf
,
nullptr
,
// mu_buf
&
rsigma_buf
,
&
gamma_buf
,
&
xgrad_buf
,
&
wgrad_buf
,
nullptr
,
// dbeta_buf,
&
wkspace_buf
,
zero_centered_gamma
,
eps_
,
sm_margin_
,
false
// is_layer_norm
);
}
XLA_FFI_DEFINE_HANDLER_SYMBOL
(
RMSNormBackwardHandler
,
RMSNormBackwardFFI
,
FFI
::
Bind
()
.
Ctx
<
FFI_Stream_Type
>
()
// stream
.
Arg
<
Buffer_Type
>
()
// dz
.
Arg
<
Buffer_Type
>
()
// x
.
Arg
<
Buffer_Type
>
()
// rsigma
.
Arg
<
Buffer_Type
>
()
// gamma
.
Ret
<
Buffer_Type
>
()
// xgrad
.
Ret
<
Buffer_Type
>
()
// wgrad
.
Ret
<
Buffer_Type
>
()
// wkspace
.
Attr
<
bool
>
(
"zero_centered_gamma"
)
.
Attr
<
double
>
(
"eps"
)
.
Attr
<
int64_t
>
(
"sm_margin"
),
FFI_CudaGraph_Traits
);
void
LayerNormForwardFP8
(
cudaStream_t
stream
,
void
**
buffers
,
const
char
*
opaque
,
size_t
opaque_len
)
{
auto
*
input
=
buffers
[
0
];
auto
*
weight
=
buffers
[
1
];
auto
*
bias
=
buffers
[
2
];
auto
*
amax
=
reinterpret_cast
<
float
*>
(
buffers
[
3
]);
auto
*
scale
=
reinterpret_cast
<
float
*>
(
buffers
[
4
]);
auto
*
scale_inv
=
reinterpret_cast
<
float
*>
(
buffers
[
5
]);
auto
*
output
=
buffers
[
6
];
auto
*
mu
=
buffers
[
7
];
auto
*
rsigma
=
buffers
[
8
];
auto
*
amax_out
=
buffers
[
9
];
auto
*
workspace
=
buffers
[
10
];
NVTE_CHECK
(
amax_out
==
amax
,
"amax not bound to amax_out in TE/JAX LayerNormForwardFP8 primitive"
);
const
auto
&
desc
=
*
UnpackOpaque
<
CustomCallNormDescriptor
>
(
opaque
,
opaque_len
);
auto
batch_size
=
desc
.
batch_size
;
auto
hidden_size
=
desc
.
hidden_size
;
auto
wkspace_size
=
desc
.
wkspace_size
;
auto
in_dtype
=
desc
.
x_dtype
;
auto
w_dtype
=
desc
.
w_dtype
;
auto
wkspace_dtype
=
desc
.
wkspace_dtype
;
auto
eps
=
desc
.
eps
;
auto
zero_centered_gamma
=
desc
.
zero_centered_gamma
;
auto
sm_margin
=
desc
.
sm_margin
;
auto
out_dtype
=
DType
::
kFloat8E4M3
;
LayerNormForwardImpl
(
batch_size
,
hidden_size
,
wkspace_size
,
zero_centered_gamma
,
eps
,
input
,
in_dtype
,
weight
,
w_dtype
,
bias
,
output
,
out_dtype
,
workspace
,
wkspace_dtype
,
mu
,
rsigma
,
amax
,
scale
,
scale_inv
,
sm_margin
,
stream
);
}
void
LayerNormForward
(
cudaStream_t
stream
,
void
**
buffers
,
const
char
*
opaque
,
size_t
opaque_len
)
{
auto
*
input
=
buffers
[
0
];
auto
*
weight
=
buffers
[
1
];
auto
*
bias
=
buffers
[
2
];
auto
*
output
=
buffers
[
3
];
auto
*
mu
=
buffers
[
4
];
auto
*
rsigma
=
buffers
[
5
];
auto
*
workspace
=
buffers
[
6
];
float
*
amax
=
nullptr
;
float
*
scale
=
nullptr
;
float
*
scale_inv
=
nullptr
;
const
auto
&
desc
=
*
UnpackOpaque
<
CustomCallNormDescriptor
>
(
opaque
,
opaque_len
);
auto
batch_size
=
desc
.
batch_size
;
auto
hidden_size
=
desc
.
hidden_size
;
auto
wkspace_size
=
desc
.
wkspace_size
;
auto
in_dtype
=
desc
.
x_dtype
;
auto
w_dtype
=
desc
.
w_dtype
;
auto
wkspace_dtype
=
desc
.
wkspace_dtype
;
auto
eps
=
desc
.
eps
;
auto
out_dtype
=
in_dtype
;
auto
zero_centered_gamma
=
desc
.
zero_centered_gamma
;
auto
sm_margin
=
desc
.
sm_margin
;
LayerNormForwardImpl
(
batch_size
,
hidden_size
,
wkspace_size
,
zero_centered_gamma
,
eps
,
input
,
in_dtype
,
weight
,
w_dtype
,
bias
,
output
,
out_dtype
,
workspace
,
wkspace_dtype
,
mu
,
rsigma
,
amax
,
scale
,
scale_inv
,
sm_margin
,
stream
);
}
void
LayerNormBackward
(
cudaStream_t
stream
,
void
**
buffers
,
const
char
*
opaque
,
size_t
opaque_len
)
{
const
auto
&
desc
=
*
UnpackOpaque
<
CustomCallNormDescriptor
>
(
opaque
,
opaque_len
);
auto
batch_size
=
desc
.
batch_size
;
auto
hidden_size
=
desc
.
hidden_size
;
auto
wkspace_size
=
desc
.
wkspace_size
;
auto
in_dtype
=
desc
.
x_dtype
;
auto
w_dtype
=
desc
.
w_dtype
;
auto
wkspace_dtype
=
desc
.
wkspace_dtype
;
auto
eps
=
desc
.
eps
;
auto
zero_centered_gamma
=
desc
.
zero_centered_gamma
;
auto
sm_margin
=
desc
.
sm_margin
;
auto
*
ograd
=
buffers
[
0
];
auto
*
mu
=
buffers
[
1
];
auto
*
rsigma
=
buffers
[
2
];
auto
*
input
=
buffers
[
3
];
auto
*
weight
=
buffers
[
4
];
auto
*
xgrad
=
buffers
[
5
];
auto
*
wgrad
=
buffers
[
6
];
auto
*
dbeta
=
buffers
[
7
];
auto
*
workspace
=
buffers
[
8
];
LayerNormBackwardImpl
(
batch_size
,
hidden_size
,
wkspace_size
,
zero_centered_gamma
,
eps
,
input
,
in_dtype
,
weight
,
w_dtype
,
ograd
,
workspace
,
wkspace_dtype
,
mu
,
rsigma
,
xgrad
,
wgrad
,
dbeta
,
sm_margin
,
stream
);
}
void
RMSNormForwardFP8
(
cudaStream_t
stream
,
void
**
buffers
,
const
char
*
opaque
,
size_t
opaque_len
)
{
auto
*
input
=
buffers
[
0
];
auto
*
weight
=
buffers
[
1
];
auto
*
amax
=
reinterpret_cast
<
float
*>
(
buffers
[
2
]);
auto
*
scale
=
reinterpret_cast
<
float
*>
(
buffers
[
3
]);
auto
*
scale_inv
=
reinterpret_cast
<
float
*>
(
buffers
[
4
]);
auto
*
output
=
buffers
[
5
];
auto
*
rsigma
=
buffers
[
6
];
auto
*
amax_out
=
buffers
[
7
];
auto
*
workspace
=
buffers
[
8
];
NVTE_CHECK
(
amax_out
==
amax
,
"amax not bound to amax_out in TE/JAX RSMNormForwardFP8 primitive."
);
void
*
bias
=
nullptr
;
void
*
mu
=
nullptr
;
const
auto
&
desc
=
*
UnpackOpaque
<
CustomCallNormDescriptor
>
(
opaque
,
opaque_len
);
auto
batch_size
=
desc
.
batch_size
;
auto
hidden_size
=
desc
.
hidden_size
;
auto
wkspace_size
=
desc
.
wkspace_size
;
auto
in_dtype
=
desc
.
x_dtype
;
auto
w_dtype
=
desc
.
w_dtype
;
auto
wkspace_dtype
=
desc
.
wkspace_dtype
;
auto
eps
=
desc
.
eps
;
auto
zero_centered_gamma
=
desc
.
zero_centered_gamma
;
auto
sm_margin
=
desc
.
sm_margin
;
auto
out_dtype
=
DType
::
kFloat8E4M3
;
LayerNormForwardImpl
(
batch_size
,
hidden_size
,
wkspace_size
,
zero_centered_gamma
,
eps
,
input
,
in_dtype
,
weight
,
w_dtype
,
bias
,
output
,
out_dtype
,
workspace
,
wkspace_dtype
,
mu
,
rsigma
,
amax
,
scale
,
scale_inv
,
sm_margin
,
stream
);
}
void
RMSNormForward
(
cudaStream_t
stream
,
void
**
buffers
,
const
char
*
opaque
,
size_t
opaque_len
)
{
auto
*
input
=
buffers
[
0
];
auto
*
weight
=
buffers
[
1
];
auto
*
output
=
buffers
[
2
];
auto
*
rsigma
=
buffers
[
3
];
auto
*
workspace
=
buffers
[
4
];
void
*
bias
=
nullptr
;
void
*
mu
=
nullptr
;
float
*
amax
=
nullptr
;
float
*
scale
=
nullptr
;
float
*
scale_inv
=
nullptr
;
const
auto
&
desc
=
*
UnpackOpaque
<
CustomCallNormDescriptor
>
(
opaque
,
opaque_len
);
auto
batch_size
=
desc
.
batch_size
;
auto
hidden_size
=
desc
.
hidden_size
;
auto
wkspace_size
=
desc
.
wkspace_size
;
auto
in_dtype
=
desc
.
x_dtype
;
auto
w_dtype
=
desc
.
w_dtype
;
auto
wkspace_dtype
=
desc
.
wkspace_dtype
;
auto
eps
=
desc
.
eps
;
auto
zero_centered_gamma
=
desc
.
zero_centered_gamma
;
auto
sm_margin
=
desc
.
sm_margin
;
auto
out_dtype
=
in_dtype
;
LayerNormForwardImpl
(
batch_size
,
hidden_size
,
wkspace_size
,
zero_centered_gamma
,
eps
,
input
,
in_dtype
,
weight
,
w_dtype
,
bias
,
output
,
out_dtype
,
workspace
,
wkspace_dtype
,
mu
,
rsigma
,
amax
,
scale
,
scale_inv
,
sm_margin
,
stream
);
}
void
RMSNormBackward
(
cudaStream_t
stream
,
void
**
buffers
,
const
char
*
opaque
,
size_t
opaque_len
)
{
auto
*
ograd
=
buffers
[
0
];
auto
*
rsigma
=
buffers
[
1
];
auto
*
input
=
buffers
[
2
];
auto
*
weight
=
buffers
[
3
];
auto
*
xgrad
=
buffers
[
4
];
auto
*
wgrad
=
buffers
[
5
];
auto
*
workspace
=
buffers
[
6
];
void
*
mu
=
nullptr
;
void
*
dbeta
=
nullptr
;
const
auto
&
desc
=
*
UnpackOpaque
<
CustomCallNormDescriptor
>
(
opaque
,
opaque_len
);
auto
batch_size
=
desc
.
batch_size
;
auto
hidden_size
=
desc
.
hidden_size
;
auto
wkspace_size
=
desc
.
wkspace_size
;
auto
in_dtype
=
desc
.
x_dtype
;
auto
w_dtype
=
desc
.
w_dtype
;
auto
wkspace_dtype
=
desc
.
wkspace_dtype
;
auto
eps
=
desc
.
eps
;
auto
zero_centered_gamma
=
desc
.
zero_centered_gamma
;
auto
sm_margin
=
desc
.
sm_margin
;
LayerNormBackwardImpl
(
batch_size
,
hidden_size
,
wkspace_size
,
zero_centered_gamma
,
eps
,
input
,
in_dtype
,
weight
,
w_dtype
,
ograd
,
workspace
,
wkspace_dtype
,
mu
,
rsigma
,
xgrad
,
wgrad
,
dbeta
,
sm_margin
,
stream
);
}
}
// namespace jax
}
// namespace jax
}
// namespace transformer_engine
}
// namespace transformer_engine
transformer_engine/jax/csrc/extensions/packing.cpp
deleted
100644 → 0
View file @
fbee8990
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "extensions.h"
namespace
transformer_engine
{
namespace
jax
{
pybind11
::
bytes
PackCustomCallCommonDescriptor
(
const
std
::
vector
<
size_t
>
&
shape
,
DType
in_dtype
,
DType
out_dtype
,
size_t
act_enum
)
{
CustomCallCommonDescriptor
desc
{};
desc
.
shape
.
from_vector
(
shape
);
desc
.
in_dtype
=
in_dtype
;
desc
.
out_dtype
=
out_dtype
;
desc
.
act_enum
=
act_enum
;
return
PackOpaque
(
desc
);
}
pybind11
::
bytes
PackCustomCallCommonWkDescriptor
(
const
std
::
vector
<
size_t
>
&
shape
,
const
std
::
vector
<
size_t
>
&
wkshape
,
DType
in_dtype
,
DType
out_dtype
,
DType
wk_dtype
,
size_t
act_enum
)
{
CustomCallCommonWkDescriptor
desc
{};
desc
.
shape
.
from_vector
(
shape
);
desc
.
wkshape
.
from_vector
(
wkshape
);
desc
.
in_dtype
=
in_dtype
;
desc
.
out_dtype
=
out_dtype
;
desc
.
wk_dtype
=
wk_dtype
;
desc
.
act_enum
=
act_enum
;
return
PackOpaque
(
desc
);
}
pybind11
::
bytes
PackCustomCallNormDescriptor
(
size_t
batch_size
,
size_t
hidden_size
,
size_t
wkspace_size
,
DType
x_dtype
,
DType
w_dtype
,
DType
wkspace_dtype
,
bool
zero_centered_gamma
,
float
eps
,
int
sm_margin
)
{
CustomCallNormDescriptor
desc
{};
desc
.
batch_size
=
batch_size
;
desc
.
hidden_size
=
hidden_size
;
desc
.
wkspace_size
=
wkspace_size
;
desc
.
x_dtype
=
x_dtype
;
desc
.
w_dtype
=
w_dtype
;
desc
.
wkspace_dtype
=
wkspace_dtype
;
desc
.
zero_centered_gamma
=
zero_centered_gamma
;
desc
.
eps
=
eps
;
desc
.
sm_margin
=
sm_margin
;
return
PackOpaque
(
desc
);
}
pybind11
::
bytes
PackCustomCallSoftmaxDescriptor
(
size_t
batch_size
,
size_t
padding_size
,
size_t
head_dim
,
size_t
q_seqlen
,
size_t
k_seqlen
,
DType
dtype
,
float
scale_factor
)
{
return
PackOpaque
(
SoftmaxDescriptor
{
batch_size
,
padding_size
,
head_dim
,
q_seqlen
,
k_seqlen
,
dtype
,
scale_factor
});
}
pybind11
::
bytes
PackCustomCallFusedAttnDescriptor
(
size_t
input_batch
,
size_t
bias_batch
,
size_t
q_max_seqlen
,
size_t
kv_max_seqlen
,
size_t
attn_heads
,
size_t
num_gqa_groups
,
size_t
bias_heads
,
size_t
head_dim
,
size_t
max_segments_per_seq
,
size_t
wkspace_size
,
float
scaling_factor
,
float
dropout_probability
,
NVTE_Bias_Type
bias_type
,
NVTE_Mask_Type
mask_type
,
NVTE_QKV_Layout
qkv_layout
,
DType
dtype
,
DType
wkspace_dtype
,
bool
is_training
,
bool
deterministic
,
int64_t
window_size_left
,
int64_t
window_size_right
)
{
return
PackOpaque
(
CustomCallFusedAttnDescriptor
{
input_batch
,
bias_batch
,
q_max_seqlen
,
kv_max_seqlen
,
attn_heads
,
num_gqa_groups
,
bias_heads
,
head_dim
,
max_segments_per_seq
,
wkspace_size
,
scaling_factor
,
dropout_probability
,
bias_type
,
mask_type
,
qkv_layout
,
dtype
,
wkspace_dtype
,
is_training
,
deterministic
,
window_size_left
,
window_size_right
});
}
}
// namespace jax
}
// namespace transformer_engine
transformer_engine/jax/csrc/extensions/pybind.cpp
View file @
a207db1d
...
@@ -9,11 +9,6 @@
...
@@ -9,11 +9,6 @@
namespace
transformer_engine
{
namespace
transformer_engine
{
namespace
jax
{
namespace
jax
{
template
<
typename
T
>
pybind11
::
capsule
EncapsulateFunction
(
T
*
fn
)
{
return
pybind11
::
capsule
(
reinterpret_cast
<
void
*>
(
fn
),
"xla._CUSTOM_CALL_TARGET"
);
}
template
<
typename
T
>
template
<
typename
T
>
pybind11
::
capsule
EncapsulateFFI
(
T
*
fn
)
{
pybind11
::
capsule
EncapsulateFFI
(
T
*
fn
)
{
static_assert
(
std
::
is_invocable_r_v
<
XLA_FFI_Error
*
,
T
,
XLA_FFI_CallFrame
*>
,
static_assert
(
std
::
is_invocable_r_v
<
XLA_FFI_Error
*
,
T
,
XLA_FFI_CallFrame
*>
,
...
@@ -23,49 +18,13 @@ pybind11::capsule EncapsulateFFI(T *fn) {
...
@@ -23,49 +18,13 @@ pybind11::capsule EncapsulateFFI(T *fn) {
pybind11
::
dict
Registrations
()
{
pybind11
::
dict
Registrations
()
{
pybind11
::
dict
dict
;
pybind11
::
dict
dict
;
dict
[
"te_transpose"
]
=
EncapsulateFunction
(
Transpose
);
dict
[
"te_cast_transpose"
]
=
EncapsulateFunction
(
CastTranspose
);
dict
[
"te_act_lu"
]
=
EncapsulateFunction
(
ActLu
);
dict
[
"te_act_lu_fp8"
]
=
EncapsulateFunction
(
ActLuFP8
);
dict
[
"te_dact_lu"
]
=
EncapsulateFunction
(
DActLu
);
dict
[
"te_dbias_cast_transpose"
]
=
EncapsulateFunction
(
DBiasCastTranspose
);
dict
[
"te_dact_lu_dbias_cast_transpose"
]
=
EncapsulateFunction
(
DActLuDBiasCastTranspose
);
dict
[
"te_dgated_act_lu_cast_transpose"
]
=
EncapsulateFunction
(
DGatedActLuCastTranspose
);
dict
[
"te_layernorm_forward"
]
=
EncapsulateFunction
(
LayerNormForward
);
dict
[
"te_layernorm_forward_fp8"
]
=
EncapsulateFunction
(
LayerNormForwardFP8
);
dict
[
"te_layernorm_backward"
]
=
EncapsulateFunction
(
LayerNormBackward
);
dict
[
"te_rmsnorm_forward"
]
=
EncapsulateFunction
(
RMSNormForward
);
dict
[
"te_rmsnorm_forward_fp8"
]
=
EncapsulateFunction
(
RMSNormForwardFP8
);
dict
[
"te_rmsnorm_backward"
]
=
EncapsulateFunction
(
RMSNormBackward
);
dict
[
"te_quantize"
]
=
EncapsulateFunction
(
Quantize
);
dict
[
"te_dequantize"
]
=
EncapsulateFunction
(
Dequantize
);
dict
[
"te_scaled_softmax_forward"
]
=
EncapsulateFunction
(
ScaledSoftmaxForward
);
dict
[
"te_scaled_softmax_backward"
]
=
EncapsulateFunction
(
ScaledSoftmaxBackward
);
dict
[
"te_scaled_masked_softmax_forward"
]
=
EncapsulateFunction
(
ScaledMaskedSoftmaxForward
);
dict
[
"te_scaled_masked_softmax_backward"
]
=
EncapsulateFunction
(
ScaledMaskedSoftmaxBackward
);
dict
[
"te_scaled_upper_triang_masked_softmax_forward"
]
=
EncapsulateFunction
(
ScaledUpperTriangMaskedSoftmaxForward
);
dict
[
"te_scaled_upper_triang_masked_softmax_backward"
]
=
EncapsulateFunction
(
ScaledUpperTriangMaskedSoftmaxBackward
);
dict
[
"te_fused_attn_forward"
]
=
EncapsulateFunction
(
FusedAttnForward
);
dict
[
"te_fused_attn_backward"
]
=
EncapsulateFunction
(
FusedAttnBackward
);
// Transpose
dict
[
"te_transpose_ffi"
]
=
EncapsulateFFI
(
TransposeHandler
);
dict
[
"te_cast_transpose_ffi"
]
=
EncapsulateFFI
(
CastTransposeHandler
);
dict
[
"te_dbias_cast_transpose_ffi"
]
=
EncapsulateFFI
(
DBiasCastTransposeHandler
);
// Activation
// Activation
dict
[
"te_act_lu_ffi"
]
=
EncapsulateFFI
(
ActLuHandler
);
dict
[
"te_act_lu_ffi"
]
=
EncapsulateFFI
(
ActLuHandler
);
dict
[
"te_act_lu_fp8_ffi"
]
=
EncapsulateFFI
(
ActLuFP8Handler
);
dict
[
"te_dact_dbias_quantize_ffi"
]
=
EncapsulateFFI
(
DActLuDBiasQuantizeHandler
);
dict
[
"te_dact_lu_ffi"
]
=
EncapsulateFFI
(
DActLuHandler
);
dict
[
"te_dact_lu_dbias_cast_transpose_ffi"
]
=
EncapsulateFFI
(
DActLuDBiasCastTransposeHandler
);
dict
[
"te_dgated_act_lu_cast_transpose_ffi"
]
=
EncapsulateFFI
(
DGatedActLuCastTransposeHandler
);
// Quantization
// Quantization
dict
[
"te_quantize_ffi"
]
=
EncapsulateFFI
(
QuantizeHandler
);
dict
[
"te_
dbias_
quantize_ffi"
]
=
EncapsulateFFI
(
DBias
QuantizeHandler
);
dict
[
"te_dequantize_ffi"
]
=
EncapsulateFFI
(
DequantizeHandler
);
dict
[
"te_dequantize_ffi"
]
=
EncapsulateFFI
(
DequantizeHandler
);
// Softmax
// Softmax
...
@@ -80,58 +39,40 @@ pybind11::dict Registrations() {
...
@@ -80,58 +39,40 @@ pybind11::dict Registrations() {
EncapsulateFFI
(
ScaledUpperTriangMaskedSoftmaxBackwardHandler
);
EncapsulateFFI
(
ScaledUpperTriangMaskedSoftmaxBackwardHandler
);
// Normalization
// Normalization
dict
[
"te_layernorm_forward_ffi"
]
=
dict
[
"te_norm_forward_ffi"
]
=
pybind11
::
dict
(
pybind11
::
arg
(
"prepare"
)
=
EncapsulateFFI
(
CudnnHandleInitHandler
),
pybind11
::
arg
(
"execute"
)
=
EncapsulateFFI
(
LayerNormForwardHandler
));
dict
[
"te_layernorm_forward_fp8_ffi"
]
=
pybind11
::
dict
(
pybind11
::
arg
(
"prepare"
)
=
EncapsulateFFI
(
CudnnHandleInitHandler
),
pybind11
::
arg
(
"execute"
)
=
EncapsulateFFI
(
LayerNormForwardFP8Handler
));
dict
[
"te_layernorm_backward_ffi"
]
=
pybind11
::
dict
(
pybind11
::
arg
(
"prepare"
)
=
EncapsulateFFI
(
CudnnHandleInitHandler
),
pybind11
::
arg
(
"execute"
)
=
EncapsulateFFI
(
LayerNormBackwardHandler
));
dict
[
"te_rmsnorm_forward_ffi"
]
=
pybind11
::
dict
(
pybind11
::
arg
(
"prepare"
)
=
EncapsulateFFI
(
CudnnHandleInitHandler
),
pybind11
::
arg
(
"execute"
)
=
EncapsulateFFI
(
RMSNormForwardHandler
));
dict
[
"te_rmsnorm_forward_fp8_ffi"
]
=
pybind11
::
dict
(
pybind11
::
arg
(
"prepare"
)
=
EncapsulateFFI
(
CudnnHandleInitHandler
),
pybind11
::
dict
(
pybind11
::
arg
(
"prepare"
)
=
EncapsulateFFI
(
CudnnHandleInitHandler
),
pybind11
::
arg
(
"execute"
)
=
EncapsulateFFI
(
RMS
NormForward
FP8
Handler
));
pybind11
::
arg
(
"execute"
)
=
EncapsulateFFI
(
NormForwardHandler
));
dict
[
"te_
rms
norm_backward_ffi"
]
=
dict
[
"te_norm_backward_ffi"
]
=
pybind11
::
dict
(
pybind11
::
arg
(
"prepare"
)
=
EncapsulateFFI
(
CudnnHandleInitHandler
),
pybind11
::
dict
(
pybind11
::
arg
(
"prepare"
)
=
EncapsulateFFI
(
CudnnHandleInitHandler
),
pybind11
::
arg
(
"execute"
)
=
EncapsulateFFI
(
RMS
NormBackwardHandler
));
pybind11
::
arg
(
"execute"
)
=
EncapsulateFFI
(
NormBackwardHandler
));
// Attention
// Attention
pybind11
::
dict
fused_attn_forward_ffi
;
dict
[
"te_fused_attn_forward_ffi"
]
=
fused_attn_forward_ffi
[
"prepare"
]
=
EncapsulateFFI
(
CudnnHandleInitHandler
);
pybind11
::
dict
(
pybind11
::
arg
(
"prepare"
)
=
EncapsulateFFI
(
CudnnHandleInitHandler
),
fused_attn_forward_ffi
[
"execute"
]
=
EncapsulateFFI
(
FusedAttnForwardHandler
);
pybind11
::
arg
(
"execute"
)
=
EncapsulateFFI
(
FusedAttnForwardHandler
));
dict
[
"te_fused_attn_forward_ffi"
]
=
fused_attn_forward_ffi
;
dict
[
"te_fused_attn_backward_ffi"
]
=
pybind11
::
dict
(
pybind11
::
arg
(
"prepare"
)
=
EncapsulateFFI
(
CudnnHandleInitHandler
),
pybind11
::
arg
(
"execute"
)
=
EncapsulateFFI
(
FusedAttnBackwardHandler
));
pybind11
::
dict
fused_attn_backward_ffi
;
// Grouped GEMM
fused_attn_backward_ffi
[
"prepare"
]
=
EncapsulateFFI
(
CudnnHandleInitHandler
);
dict
[
"te_grouped_gemm_ffi"
]
=
fused_attn_backward_ffi
[
"execut
e"
]
=
EncapsulateFFI
(
FusedAttnBackward
Handler
)
;
pybind11
::
dict
(
pybind11
::
arg
(
"prepar
e"
)
=
EncapsulateFFI
(
CublasHandleInit
Handler
)
,
dict
[
"te_fused_attn_backward_ffi"
]
=
fused_attn_backward_ffi
;
pybind11
::
arg
(
"execute"
)
=
EncapsulateFFI
(
GroupedGemmHandler
))
;
return
dict
;
return
dict
;
}
}
PYBIND11_MODULE
(
transformer_engine_jax
,
m
)
{
PYBIND11_MODULE
(
transformer_engine_jax
,
m
)
{
m
.
def
(
"registrations"
,
&
Registrations
);
m
.
def
(
"registrations"
,
&
Registrations
);
m
.
def
(
"pack_common_descriptor"
,
&
PackCustomCallCommonDescriptor
,
pybind11
::
arg
(),
pybind11
::
arg
(),
pybind11
::
arg
(),
pybind11
::
arg
(
"act_num"
)
=
0
);
m
.
def
(
"pack_common_wk_descriptor"
,
&
PackCustomCallCommonWkDescriptor
,
pybind11
::
arg
(),
pybind11
::
arg
(),
pybind11
::
arg
(),
pybind11
::
arg
(),
pybind11
::
arg
(),
pybind11
::
arg
(
"act_num"
)
=
0
);
m
.
def
(
"pack_norm_descriptor"
,
&
PackCustomCallNormDescriptor
);
m
.
def
(
"pack_softmax_descriptor"
,
&
PackCustomCallSoftmaxDescriptor
);
m
.
def
(
"pack_fused_attn_descriptor"
,
&
PackCustomCallFusedAttnDescriptor
);
m
.
def
(
"get_fused_attn_backend"
,
&
GetFusedAttnBackend
);
m
.
def
(
"get_fused_attn_backend"
,
&
GetFusedAttnBackend
);
m
.
def
(
"get_cuda_version"
,
&
GetCudaRuntimeVersion
);
m
.
def
(
"get_cuda_version"
,
&
GetCudaRuntimeVersion
);
m
.
def
(
"get_cudnn_version"
,
&
GetCudnnRuntimeVersion
);
m
.
def
(
"get_cudnn_version"
,
&
GetCudnnRuntimeVersion
);
m
.
def
(
"get_device_compute_capability"
,
&
GetDeviceComputeCapability
);
m
.
def
(
"get_device_compute_capability"
,
&
GetDeviceComputeCapability
);
m
.
def
(
"get_cublasLt_version"
,
&
cublasLtGetVersion
);
m
.
def
(
"get_cublasLt_version"
,
&
cublasLtGetVersion
);
m
.
def
(
"get_dact_dbias_
ct
_workspace_sizes"
,
&
GetDActDBias
CastTranspos
eWorkspaceSizes
);
m
.
def
(
"get_dact_dbias_
quantize
_workspace_sizes"
,
&
GetDActDBias
Quantiz
eWorkspaceSizes
);
m
.
def
(
"get_dbias_
ct
_workspace_sizes"
,
&
GetDBias
CastTranspos
eWorkspaceSizes
);
m
.
def
(
"get_dbias_
quantize
_workspace_sizes"
,
&
GetDBias
Quantiz
eWorkspaceSizes
);
m
.
def
(
"get_
layer
norm_fwd_workspace_sizes"
,
&
Get
Layer
NormForwardWorkspaceSizes
);
m
.
def
(
"get_norm_fwd_workspace_sizes"
,
&
GetNormForwardWorkspaceSizes
);
m
.
def
(
"get_
layer
norm_bwd_workspace_sizes"
,
&
Get
Layer
NormBackwardWorkspaceSizes
);
m
.
def
(
"get_norm_bwd_workspace_sizes"
,
&
GetNormBackwardWorkspaceSizes
);
m
.
def
(
"get_fused_attn_fwd_workspace_sizes"
,
&
GetFusedAttnForwardWorkspaceSizes
);
m
.
def
(
"get_fused_attn_fwd_workspace_sizes"
,
&
GetFusedAttnForwardWorkspaceSizes
);
m
.
def
(
"get_fused_attn_bwd_workspace_sizes"
,
&
GetFusedAttnBackwardWorkspaceSizes
);
m
.
def
(
"get_fused_attn_bwd_workspace_sizes"
,
&
GetFusedAttnBackwardWorkspaceSizes
);
m
.
def
(
"nvte_get_qkv_format"
,
&
nvte_get_qkv_format
);
m
.
def
(
"nvte_get_qkv_format"
,
&
nvte_get_qkv_format
);
...
@@ -191,6 +132,24 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
...
@@ -191,6 +132,24 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
.
value
(
"NVTE_F16_max512_seqlen"
,
NVTE_Fused_Attn_Backend
::
NVTE_F16_max512_seqlen
)
.
value
(
"NVTE_F16_max512_seqlen"
,
NVTE_Fused_Attn_Backend
::
NVTE_F16_max512_seqlen
)
.
value
(
"NVTE_F16_arbitrary_seqlen"
,
NVTE_Fused_Attn_Backend
::
NVTE_F16_arbitrary_seqlen
)
.
value
(
"NVTE_F16_arbitrary_seqlen"
,
NVTE_Fused_Attn_Backend
::
NVTE_F16_arbitrary_seqlen
)
.
value
(
"NVTE_FP8"
,
NVTE_Fused_Attn_Backend
::
NVTE_FP8
);
.
value
(
"NVTE_FP8"
,
NVTE_Fused_Attn_Backend
::
NVTE_FP8
);
pybind11
::
enum_
<
NVTE_Norm_Type
>
(
m
,
"NVTE_Norm_Type"
,
pybind11
::
module_local
())
.
value
(
"LayerNorm"
,
NVTE_Norm_Type
::
LayerNorm
)
.
value
(
"RMSNorm"
,
NVTE_Norm_Type
::
RMSNorm
)
.
export_values
();
pybind11
::
enum_
<
NVTEScalingMode
>
(
m
,
"NVTE_Scaling_Mode"
,
pybind11
::
module_local
())
.
value
(
"NVTE_DELAYED_TENSOR_SCALING"
,
NVTEScalingMode
::
NVTE_DELAYED_TENSOR_SCALING
)
.
value
(
"NVTE_MXFP8_1D_SCALING"
,
NVTEScalingMode
::
NVTE_MXFP8_1D_SCALING
)
.
value
(
"NVTE_INVALID_SCALING"
,
NVTEScalingMode
::
NVTE_MXFP8_1D_SCALING
)
.
export_values
();
pybind11
::
enum_
<
transformer_engine
::
jax
::
QuantizeAxis
>
(
m
,
"QuantizeAxis"
,
pybind11
::
module_local
())
.
value
(
"ROWWISE"
,
transformer_engine
::
jax
::
QuantizeAxis
::
ROWWISE
)
.
value
(
"COLWISE"
,
transformer_engine
::
jax
::
QuantizeAxis
::
COLWISE
)
.
value
(
"ROWWISE_COLWISE"
,
transformer_engine
::
jax
::
QuantizeAxis
::
ROWWISE_COLWISE
)
.
export_values
();
}
}
}
// namespace jax
}
// namespace jax
...
...
transformer_engine/jax/csrc/extensions/quantization.cpp
View file @
a207db1d
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
*
*
* See LICENSE for license information.
* See LICENSE for license information.
************************************************************************/
************************************************************************/
#include <cuda_runtime.h>
#include "extensions.h"
#include "extensions.h"
#include "transformer_engine/cast.h"
#include "transformer_engine/cast.h"
...
@@ -11,74 +12,131 @@
...
@@ -11,74 +12,131 @@
namespace
transformer_engine
{
namespace
transformer_engine
{
namespace
jax
{
namespace
jax
{
void
Quantize
(
cudaStream_t
stream
,
void
**
buffers
,
const
char
*
opaque
,
size_t
opaque_len
)
{
pybind11
::
tuple
GetDBiasQuantizeWorkspaceSizes
(
size_t
batch_size
,
size_t
hidden_size
,
auto
*
input
=
buffers
[
0
];
DType
in_dtype
,
DType
out_dtype
)
{
auto
*
amax
=
reinterpret_cast
<
float
*>
(
buffers
[
1
]);
auto
input_shape
=
std
::
vector
<
size_t
>
{
batch_size
,
hidden_size
};
auto
*
scale
=
reinterpret_cast
<
float
*>
(
buffers
[
2
]);
auto
output_shape
=
std
::
vector
<
size_t
>
{
batch_size
,
hidden_size
};
auto
*
scale_inv
=
reinterpret_cast
<
float
*>
(
buffers
[
3
]);
auto
output_trans_shape
=
std
::
vector
<
size_t
>
{
hidden_size
,
batch_size
};
auto
*
output
=
buffers
[
4
];
auto
dbias_shape
=
std
::
vector
<
size_t
>
{
hidden_size
};
auto
*
amax_out
=
reinterpret_cast
<
float
*>
(
buffers
[
5
]);
NVTE_CHECK
(
amax
==
amax_out
,
"amax not bound to amax_out in TE/JAX Quantize primitive."
);
// Evil hack to specify TE impl
// Note: nvte_quantize_dbias chooses its internal impl based on what
const
auto
&
desc
=
*
UnpackOpaque
<
CustomCallCommonDescriptor
>
(
opaque
,
opaque_len
);
// pointers are allocated, e.g. whether to output with column-wise
auto
shape
=
desc
.
shape
.
to_vector
();
// data. However, we don't have access to any allocated buffers in
auto
input_tensor
=
TensorWrapper
(
input
,
shape
,
desc
.
in_dtype
);
// this function. We pass a dummy pointer as a workaround.
auto
output_tensor
=
TensorWrapper
(
output
,
shape
,
desc
.
out_dtype
,
amax_out
,
scale
,
scale_inv
);
int
temp
=
0
;
nvte_quantize
(
input_tensor
.
data
(),
output_tensor
.
data
(),
stream
);
auto
input_tensor
=
TensorWrapper
(
reinterpret_cast
<
void
*>
(
&
temp
),
input_shape
,
in_dtype
);
auto
output_tensor
=
TensorWrapper
(
reinterpret_cast
<
void
*>
(
&
temp
),
output_shape
,
out_dtype
);
output_tensor
.
set_columnwise_data
(
reinterpret_cast
<
void
*>
(
&
temp
),
out_dtype
,
output_trans_shape
);
auto
dbias_tensor
=
TensorWrapper
(
reinterpret_cast
<
void
*>
(
&
temp
),
dbias_shape
,
in_dtype
);
TensorWrapper
dummy_workspace
;
nvte_quantize_dbias
(
input_tensor
.
data
(),
output_tensor
.
data
(),
dbias_tensor
.
data
(),
dummy_workspace
.
data
(),
nullptr
);
auto
work_shape
=
MakeShapeVector
(
dummy_workspace
.
shape
());
return
pybind11
::
make_tuple
(
std
::
make_pair
(
work_shape
,
dummy_workspace
.
dtype
()));
}
}
Error_Type
QuantizeFFI
(
cudaStream_t
stream
,
Buffer_Type
input_buf
,
Buffer_Type
amax_buf
,
Error_Type
DBiasQuantizeFFI
(
cudaStream_t
stream
,
Buffer_Type
input_buf
,
Buffer_Type
scale_buf
,
Buffer_Type
scale_buf
,
Buffer_Type
scale_inv_buf
,
Result_Type
output_buf
,
Result_Type
output_buf
,
Result_Type
output_trans_buf
,
Result_Type
amax_out_buf
)
{
Result_Type
scale_inv_buf
,
Result_Type
trans_scale_inv_buf
,
Result_Type
amax_out_buf
,
Result_Type
dbias_buf
,
Result_Type
workspace_buf
,
int64_t
scaling_mode_enum
,
int64_t
quantize_axis_enum
,
bool
is_dbias
)
{
auto
in_dtype
=
convert_ffi_datatype_to_te_dtype
(
input_buf
.
element_type
());
auto
in_dtype
=
convert_ffi_datatype_to_te_dtype
(
input_buf
.
element_type
());
auto
out_dtype
=
convert_ffi_datatype_to_te_dtype
(
output_buf
->
element_type
());
auto
out_dtype
=
convert_ffi_datatype_to_te_dtype
(
output_buf
->
element_type
());
auto
workspace_dtype
=
convert_ffi_datatype_to_te_dtype
(
workspace_buf
->
element_type
());
NVTE_CHECK
(
is_fp8_dtype
(
out_dtype
),
"Output datatype must be FP8 for quantization."
);
auto
*
input
=
input_buf
.
untyped_data
();
auto
*
input
=
input_buf
.
untyped_data
();
auto
*
amax
=
reinterpret_cast
<
float
*>
(
amax_buf
.
untyped_data
());
auto
*
scal
e
=
reinterpret_cast
<
float
*>
(
scale_buf
.
untyped_data
()
);
auto
scal
ing_mode
=
static_cast
<
NVTEScalingMode
>
(
scaling_mode_enum
);
auto
*
scale_inv
=
reinterpret_cast
<
float
*>
(
scale_inv_buf
.
untyped_data
()
);
auto
const
quantize_axis
=
static_cast
<
QuantizeAxis
>
(
quantize_axis_enum
);
auto
*
output
=
output_buf
->
untyped_data
();
auto
*
output
=
output_buf
->
untyped_data
();
auto
*
amax_out
=
reinterpret_cast
<
float
*>
(
amax_out_buf
->
untyped_data
());
auto
*
output_trans
=
output_trans_buf
->
untyped_data
();
NVTE_CHECK
(
amax
==
amax_out
,
"amax not bound to amax_out in TE/JAX Quantize primitive."
);
auto
*
dbias
=
dbias_buf
->
untyped_data
();
void
*
workspace
=
workspace_buf
->
untyped_data
();
auto
input_dims
=
input_buf
.
dimensions
();
auto
input_dims
=
input_buf
.
dimensions
();
std
::
vector
<
size_t
>
shape
(
input_dims
.
begin
(),
input_dims
.
end
());
auto
workspace_dims
=
workspace_buf
->
dimensions
();
auto
input_tensor
=
TensorWrapper
(
input
,
shape
,
in_dtype
);
auto
m
=
product
(
input_dims
,
0
,
input_dims
.
size
()
-
1
);
auto
output_tensor
=
TensorWrapper
(
output
,
shape
,
out_dtype
,
amax_out
,
scale
,
scale_inv
);
auto
n
=
input_dims
.
back
();
auto
input_shape
=
std
::
vector
<
size_t
>
{
m
,
n
};
nvte_quantize
(
input_tensor
.
data
(),
output_tensor
.
data
(),
stream
);
auto
output_shape
=
std
::
vector
<
size_t
>
{
m
,
n
};
auto
output_trans_shape
=
std
::
vector
<
size_t
>
{
n
,
m
};
auto
dbias_shape
=
std
::
vector
<
size_t
>
{
n
};
std
::
vector
<
size_t
>
workspace_shape
{
workspace_dims
.
begin
(),
workspace_dims
.
end
()};
auto
input_tensor
=
TensorWrapper
(
input
,
input_shape
,
in_dtype
);
auto
output_tensor
=
TensorWrapper
(
scaling_mode
);
if
(
quantize_axis
==
QuantizeAxis
::
ROWWISE
||
quantize_axis
==
QuantizeAxis
::
ROWWISE_COLWISE
)
{
output_tensor
.
set_rowwise_data
(
output
,
out_dtype
,
output_shape
);
output_tensor
.
set_rowwise_scale_inv
(
scale_inv_buf
->
untyped_data
(),
convert_ffi_datatype_to_te_dtype
(
scale_inv_buf
->
element_type
()),
std
::
vector
<
size_t
>
{
product
(
scale_inv_buf
->
dimensions
(),
0
,
scale_inv_buf
->
dimensions
().
size
()
-
1
),
scale_inv_buf
->
dimensions
().
back
()});
}
if
(
scaling_mode
==
NVTE_DELAYED_TENSOR_SCALING
)
{
float
*
scale
=
reinterpret_cast
<
float
*>
(
scale_buf
.
untyped_data
());
float
*
amax_out
=
reinterpret_cast
<
float
*>
(
amax_out_buf
->
untyped_data
());
NVTE_CHECK
(
scale
!=
nullptr
,
"scale must be provided for delayed tensor scaling"
);
NVTE_CHECK
(
amax_out
!=
nullptr
,
"amax must be provided for delayed tensor scaling"
);
output_tensor
.
set_scale
(
scale
,
DType
::
kFloat32
,
std
::
vector
<
size_t
>
{
1
});
cudaMemsetAsync
(
amax_out
,
0
,
sizeof
(
float
),
stream
);
output_tensor
.
set_amax
(
amax_out
,
DType
::
kFloat32
,
std
::
vector
<
size_t
>
{
1
});
}
if
(
quantize_axis
==
QuantizeAxis
::
COLWISE
||
quantize_axis
==
QuantizeAxis
::
ROWWISE_COLWISE
)
{
output_tensor
.
set_columnwise_data
(
output_trans
,
out_dtype
,
output_trans_shape
);
// For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling
auto
&
colwise_scale_inv_buf
=
(
scaling_mode
==
NVTE_DELAYED_TENSOR_SCALING
)
?
scale_inv_buf
:
trans_scale_inv_buf
;
output_tensor
.
set_columnwise_scale_inv
(
colwise_scale_inv_buf
->
untyped_data
(),
convert_ffi_datatype_to_te_dtype
(
colwise_scale_inv_buf
->
element_type
()),
std
::
vector
<
size_t
>
{
product
(
colwise_scale_inv_buf
->
dimensions
(),
0
,
colwise_scale_inv_buf
->
dimensions
().
size
()
-
1
),
colwise_scale_inv_buf
->
dimensions
().
back
()});
}
auto
dbias_tensor
=
TensorWrapper
(
dbias
,
dbias_shape
,
in_dtype
);
auto
workspace_tensor
=
TensorWrapper
(
workspace
,
workspace_shape
,
workspace_dtype
);
if
(
is_dbias
)
{
nvte_quantize_dbias
(
input_tensor
.
data
(),
output_tensor
.
data
(),
dbias_tensor
.
data
(),
workspace_tensor
.
data
(),
stream
);
}
else
{
nvte_quantize
(
input_tensor
.
data
(),
output_tensor
.
data
(),
stream
);
}
return
ffi_with_cuda_error_check
();
return
ffi_with_cuda_error_check
();
}
}
XLA_FFI_DEFINE_HANDLER_SYMBOL
(
QuantizeHandler
,
QuantizeFFI
,
XLA_FFI_DEFINE_HANDLER_SYMBOL
(
DBias
QuantizeHandler
,
DBias
QuantizeFFI
,
FFI
::
Bind
()
FFI
::
Bind
()
.
Ctx
<
FFI_Stream_Type
>
()
// stream
.
Ctx
<
FFI_Stream_Type
>
()
// stream
.
Arg
<
Buffer_Type
>
()
// input
.
Arg
<
Buffer_Type
>
()
// input
.
Arg
<
Buffer_Type
>
()
// amax
.
Arg
<
Buffer_Type
>
()
// scale
.
Arg
<
Buffer_Type
>
()
// scale
.
Arg
<
Buffer_Type
>
()
// scale_inv
.
Ret
<
Buffer_Type
>
()
// output
.
Ret
<
Buffer_Type
>
()
// output
.
Ret
<
Buffer_Type
>
(),
// amax_out
.
Ret
<
Buffer_Type
>
()
// colwise output
.
Ret
<
Buffer_Type
>
()
// scale_inv
.
Ret
<
Buffer_Type
>
()
// scale_inv colwise
.
Ret
<
Buffer_Type
>
()
// amax
.
Ret
<
Buffer_Type
>
()
// dbias
.
Ret
<
Buffer_Type
>
()
// wkspace
.
Attr
<
int64_t
>
(
"scaling_mode"
)
.
Attr
<
int64_t
>
(
"q_axis"
)
.
Attr
<
bool
>
(
"is_dbias"
),
FFI_CudaGraph_Traits
);
FFI_CudaGraph_Traits
);
void
Dequantize
(
cudaStream_t
stream
,
void
**
buffers
,
const
char
*
opaque
,
size_t
opaque_len
)
{
auto
*
input
=
buffers
[
0
];
auto
*
amax
=
reinterpret_cast
<
float
*>
(
buffers
[
1
]);
auto
*
scale
=
reinterpret_cast
<
float
*>
(
buffers
[
2
]);
auto
*
scale_inv
=
reinterpret_cast
<
float
*>
(
buffers
[
3
]);
auto
*
output
=
buffers
[
4
];
const
auto
&
desc
=
*
UnpackOpaque
<
CustomCallCommonDescriptor
>
(
opaque
,
opaque_len
);
auto
shape
=
desc
.
shape
.
to_vector
();
auto
input_tensor
=
TensorWrapper
(
input
,
shape
,
desc
.
in_dtype
,
amax
,
scale
,
scale_inv
);
auto
output_tensor
=
TensorWrapper
(
output
,
shape
,
desc
.
out_dtype
);
nvte_dequantize
(
input_tensor
.
data
(),
output_tensor
.
data
(),
stream
);
}
Error_Type
DequantizeFFI
(
cudaStream_t
stream
,
Buffer_Type
input_buf
,
Buffer_Type
amax_buf
,
Error_Type
DequantizeFFI
(
cudaStream_t
stream
,
Buffer_Type
input_buf
,
Buffer_Type
amax_buf
,
Buffer_Type
scale_buf
,
Buffer_Type
scale_inv_buf
,
Result_Type
output_buf
)
{
Buffer_Type
scale_buf
,
Buffer_Type
scale_inv_buf
,
Result_Type
output_buf
)
{
auto
in_dtype
=
convert_ffi_datatype_to_te_dtype
(
input_buf
.
element_type
());
auto
in_dtype
=
convert_ffi_datatype_to_te_dtype
(
input_buf
.
element_type
());
...
...
transformer_engine/jax/csrc/extensions/softmax.cpp
View file @
a207db1d
...
@@ -12,103 +12,6 @@
...
@@ -12,103 +12,6 @@
namespace
transformer_engine
{
namespace
transformer_engine
{
namespace
jax
{
namespace
jax
{
void
ScaledSoftmaxForward
(
cudaStream_t
stream
,
void
**
buffers
,
const
char
*
opaque
,
size_t
opaque_len
)
{
auto
*
input
=
buffers
[
0
];
auto
*
output
=
buffers
[
1
];
const
auto
&
desc
=
*
UnpackOpaque
<
SoftmaxDescriptor
>
(
opaque
,
opaque_len
);
auto
shape
=
std
::
vector
<
size_t
>
{
desc
.
batch_size
,
desc
.
head_dim
,
desc
.
q_seqlen
,
desc
.
k_seqlen
};
auto
dtype
=
desc
.
dtype
;
auto
input_tensor
=
TensorWrapper
(
input
,
shape
,
dtype
);
auto
output_tensor
=
TensorWrapper
(
output
,
shape
,
dtype
);
nvte_scaled_softmax_forward
(
input_tensor
.
data
(),
output_tensor
.
data
(),
desc
.
scale_factor
,
stream
);
}
void
ScaledSoftmaxBackward
(
cudaStream_t
stream
,
void
**
buffers
,
const
char
*
opaque
,
size_t
opaque_len
)
{
auto
*
grad_output
=
buffers
[
0
];
auto
*
softmax_output
=
buffers
[
1
];
auto
*
dgrad
=
buffers
[
2
];
const
auto
&
desc
=
*
UnpackOpaque
<
SoftmaxDescriptor
>
(
opaque
,
opaque_len
);
auto
shape
=
std
::
vector
<
size_t
>
{
desc
.
batch_size
,
desc
.
head_dim
,
desc
.
q_seqlen
,
desc
.
k_seqlen
};
auto
dtype
=
desc
.
dtype
;
auto
grad_output_tensor
=
TensorWrapper
(
grad_output
,
shape
,
dtype
);
auto
softmax_output_tensor
=
TensorWrapper
(
softmax_output
,
shape
,
dtype
);
auto
dgrad_tensor
=
TensorWrapper
(
dgrad
,
shape
,
dtype
);
nvte_scaled_softmax_backward
(
grad_output_tensor
.
data
(),
softmax_output_tensor
.
data
(),
dgrad_tensor
.
data
(),
desc
.
scale_factor
,
stream
);
}
void
ScaledMaskedSoftmaxForward
(
cudaStream_t
stream
,
void
**
buffers
,
const
char
*
opaque
,
size_t
opaque_len
)
{
auto
*
input
=
buffers
[
0
];
auto
*
mask
=
buffers
[
1
];
auto
*
output
=
buffers
[
2
];
const
auto
&
desc
=
*
UnpackOpaque
<
SoftmaxDescriptor
>
(
opaque
,
opaque_len
);
auto
io_shape
=
std
::
vector
<
size_t
>
{
desc
.
batch_size
,
desc
.
head_dim
,
desc
.
q_seqlen
,
desc
.
k_seqlen
};
auto
mask_shape
=
std
::
vector
<
size_t
>
{
desc
.
padding_size
,
1
,
desc
.
q_seqlen
,
desc
.
k_seqlen
};
auto
dtype
=
desc
.
dtype
;
auto
input_tensor
=
TensorWrapper
(
input
,
io_shape
,
dtype
);
// Mask would be casted to uint8_t
auto
mask_tensor
=
TensorWrapper
(
mask
,
mask_shape
,
DType
::
kByte
);
auto
output_tensor
=
TensorWrapper
(
output
,
io_shape
,
dtype
);
nvte_scaled_masked_softmax_forward
(
input_tensor
.
data
(),
mask_tensor
.
data
(),
output_tensor
.
data
(),
desc
.
scale_factor
,
stream
);
}
void
ScaledMaskedSoftmaxBackward
(
cudaStream_t
stream
,
void
**
buffers
,
const
char
*
opaque
,
size_t
opaque_len
)
{
// The backward of ScaledMaskedSoftmax is equivalent to ScaledSoftmax.
ScaledSoftmaxBackward
(
stream
,
buffers
,
opaque
,
opaque_len
);
}
void
ScaledUpperTriangMaskedSoftmaxForward
(
cudaStream_t
stream
,
void
**
buffers
,
const
char
*
opaque
,
size_t
opaque_len
)
{
auto
*
input
=
buffers
[
0
];
auto
*
output
=
buffers
[
1
];
const
auto
&
desc
=
*
UnpackOpaque
<
SoftmaxDescriptor
>
(
opaque
,
opaque_len
);
auto
attn_batch
=
desc
.
batch_size
*
desc
.
head_dim
;
auto
shape
=
std
::
vector
<
size_t
>
{
attn_batch
,
desc
.
q_seqlen
,
desc
.
k_seqlen
};
auto
dtype
=
desc
.
dtype
;
auto
input_tensor
=
TensorWrapper
(
input
,
shape
,
dtype
);
auto
output_tensor
=
TensorWrapper
(
output
,
shape
,
dtype
);
nvte_scaled_upper_triang_masked_softmax_forward
(
input_tensor
.
data
(),
output_tensor
.
data
(),
desc
.
scale_factor
,
stream
);
}
void
ScaledUpperTriangMaskedSoftmaxBackward
(
cudaStream_t
stream
,
void
**
buffers
,
const
char
*
opaque
,
size_t
opaque_len
)
{
auto
*
grad_output
=
buffers
[
0
];
auto
*
softmax_output
=
buffers
[
1
];
auto
*
dgrad
=
buffers
[
2
];
const
auto
&
desc
=
*
UnpackOpaque
<
SoftmaxDescriptor
>
(
opaque
,
opaque_len
);
auto
attn_batch
=
desc
.
batch_size
*
desc
.
head_dim
;
auto
shape
=
std
::
vector
<
size_t
>
{
attn_batch
,
desc
.
q_seqlen
,
desc
.
k_seqlen
};
auto
dtype
=
desc
.
dtype
;
auto
grad_output_tensor
=
TensorWrapper
(
grad_output
,
shape
,
dtype
);
auto
softmax_output_tensor
=
TensorWrapper
(
softmax_output
,
shape
,
dtype
);
auto
dgrad_tensor
=
TensorWrapper
(
dgrad
,
shape
,
dtype
);
nvte_scaled_upper_triang_masked_softmax_backward
(
grad_output_tensor
.
data
(),
softmax_output_tensor
.
data
(),
dgrad_tensor
.
data
(),
desc
.
scale_factor
,
stream
);
}
#define SOFTMAX_COMMON_BLOCK(tensor_buf) \
#define SOFTMAX_COMMON_BLOCK(tensor_buf) \
auto dtype = convert_ffi_datatype_to_te_dtype((tensor_buf).element_type()); \
auto dtype = convert_ffi_datatype_to_te_dtype((tensor_buf).element_type()); \
auto tensor_dims = (tensor_buf).dimensions(); \
auto tensor_dims = (tensor_buf).dimensions(); \
...
...
transformer_engine/jax/csrc/extensions/transpose.cpp
deleted
100644 → 0
View file @
fbee8990
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "transformer_engine/transpose.h"
#include "extensions.h"
#include "transformer_engine/cast.h"
#include "xla/ffi/api/ffi.h"
namespace
transformer_engine
{
namespace
jax
{
void
TransposeImpl
(
void
*
input
,
size_t
rows
,
size_t
cols
,
DType
dtype
,
cudaStream_t
stream
,
void
*
output
)
{
auto
input_shape
=
std
::
vector
<
size_t
>
{
rows
,
cols
};
auto
output_shape
=
std
::
vector
<
size_t
>
{
cols
,
rows
};
auto
input_tensor
=
TensorWrapper
(
input
,
input_shape
,
dtype
);
auto
transposed_tensor
=
TensorWrapper
(
output
,
output_shape
,
dtype
);
nvte_transpose
(
input_tensor
.
data
(),
transposed_tensor
.
data
(),
stream
);
}
void
Transpose
(
cudaStream_t
stream
,
void
**
buffers
,
const
char
*
opaque
,
size_t
opaque_len
)
{
void
*
input
=
buffers
[
0
];
void
*
output
=
buffers
[
1
];
const
auto
&
desc
=
*
UnpackOpaque
<
CustomCallCommonDescriptor
>
(
opaque
,
opaque_len
);
auto
rows
=
desc
.
shape
.
dims
[
0
];
auto
cols
=
desc
.
shape
.
dims
[
1
];
assert
(
desc
.
in_dtype
==
desc
.
out_dtype
);
auto
dtype
=
desc
.
out_dtype
;
TransposeImpl
(
input
,
rows
,
cols
,
dtype
,
stream
,
output
);
}
Error_Type
TransposeFFI
(
cudaStream_t
stream
,
Buffer_Type
input_buf
,
Result_Type
output_buf
,
int64_t
transpose_axis
)
{
auto
in_dtype
=
convert_ffi_datatype_to_te_dtype
(
input_buf
.
element_type
());
auto
out_dtype
=
convert_ffi_datatype_to_te_dtype
(
output_buf
->
element_type
());
void
*
input
=
input_buf
.
untyped_data
();
void
*
output
=
output_buf
->
untyped_data
();
auto
input_dims
=
input_buf
.
dimensions
();
if
(
transpose_axis
<
0
)
transpose_axis
+=
input_dims
.
size
();
auto
m
=
product
(
input_dims
,
0
,
transpose_axis
);
auto
n
=
product
(
input_dims
,
transpose_axis
,
input_dims
.
size
());
auto
input_shape
=
std
::
vector
<
size_t
>
{
m
,
n
};
auto
output_shape
=
std
::
vector
<
size_t
>
{
n
,
m
};
auto
input_tensor
=
TensorWrapper
(
input
,
input_shape
,
in_dtype
);
auto
output_tensor
=
TensorWrapper
(
output
,
output_shape
,
out_dtype
);
nvte_transpose
(
input_tensor
.
data
(),
output_tensor
.
data
(),
stream
);
return
ffi_with_cuda_error_check
();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL
(
TransposeHandler
,
TransposeFFI
,
FFI
::
Bind
()
.
Ctx
<
FFI_Stream_Type
>
()
// stream
.
Arg
<
Buffer_Type
>
()
// input
.
Ret
<
Buffer_Type
>
()
// output
.
Attr
<
int64_t
>
(
"transpose_axis"
),
FFI_CudaGraph_Traits
);
void
CastTranspose
(
cudaStream_t
stream
,
void
**
buffers
,
const
char
*
opaque
,
size_t
opaque_len
)
{
auto
*
input
=
buffers
[
0
];
float
*
amax
=
reinterpret_cast
<
float
*>
(
buffers
[
1
]);
float
*
scale
=
reinterpret_cast
<
float
*>
(
buffers
[
2
]);
float
*
scale_inv
=
reinterpret_cast
<
float
*>
(
buffers
[
3
]);
auto
*
input_cast
=
buffers
[
4
];
auto
*
input_cast_trans
=
buffers
[
5
];
float
*
amax_out
=
reinterpret_cast
<
float
*>
(
buffers
[
6
]);
NVTE_CHECK
(
amax
==
amax_out
,
"amax not bound to amax_out in TE/JAX CastTranspose primitive."
);
const
auto
&
desc
=
*
UnpackOpaque
<
CustomCallCommonDescriptor
>
(
opaque
,
opaque_len
);
if
(
!
use_fp8
(
desc
.
out_dtype
))
{
scale
=
nullptr
;
scale_inv
=
nullptr
;
amax_out
=
nullptr
;
}
auto
m
=
desc
.
shape
.
dims
[
0
];
auto
n
=
desc
.
shape
.
dims
[
1
];
auto
input_shape
=
std
::
vector
<
size_t
>
{
m
,
n
};
auto
input_trans_shape
=
std
::
vector
<
size_t
>
{
n
,
m
};
auto
input_tensor
=
TensorWrapper
(
input
,
input_shape
,
desc
.
in_dtype
);
auto
output_tensor
=
TensorWrapper
(
input_cast
,
input_shape
,
desc
.
out_dtype
,
amax_out
,
scale
,
scale_inv
);
output_tensor
.
set_columnwise_data
(
input_cast_trans
,
desc
.
out_dtype
,
input_trans_shape
);
output_tensor
.
set_columnwise_scale_inv
(
scale_inv
,
DType
::
kFloat32
,
std
::
vector
<
size_t
>
{
1
});
nvte_quantize
(
input_tensor
.
data
(),
output_tensor
.
data
(),
stream
);
}
Error_Type
CastTransposeFFI
(
cudaStream_t
stream
,
Buffer_Type
input_buf
,
Buffer_Type
amax_buf
,
Buffer_Type
scale_buf
,
Buffer_Type
scale_inv_buf
,
Result_Type
output_buf
,
Result_Type
output_trans_buf
,
Result_Type
amax_out_buf
,
int64_t
transpose_axis
)
{
auto
in_dtype
=
convert_ffi_datatype_to_te_dtype
(
input_buf
.
element_type
());
auto
out_dtype
=
convert_ffi_datatype_to_te_dtype
(
output_buf
->
element_type
());
auto
*
input
=
input_buf
.
untyped_data
();
float
*
amax
=
reinterpret_cast
<
float
*>
(
amax_buf
.
untyped_data
());
float
*
scale
=
reinterpret_cast
<
float
*>
(
scale_buf
.
untyped_data
());
float
*
scale_inv
=
reinterpret_cast
<
float
*>
(
scale_inv_buf
.
untyped_data
());
auto
*
output
=
output_buf
->
untyped_data
();
auto
*
output_trans
=
output_trans_buf
->
untyped_data
();
float
*
amax_out
=
reinterpret_cast
<
float
*>
(
amax_out_buf
->
untyped_data
());
NVTE_CHECK
(
amax
==
amax_out
,
"amax not bound to amax_out in TE/JAX CastTranspose primitive."
);
if
(
!
use_fp8
(
out_dtype
))
{
scale
=
nullptr
;
scale_inv
=
nullptr
;
amax_out
=
nullptr
;
}
auto
input_dims
=
input_buf
.
dimensions
();
if
(
transpose_axis
<
0
)
transpose_axis
+=
input_dims
.
size
();
auto
m
=
product
(
input_dims
,
0
,
transpose_axis
);
auto
n
=
product
(
input_dims
,
transpose_axis
,
input_dims
.
size
());
auto
input_shape
=
std
::
vector
<
size_t
>
{
m
,
n
};
auto
output_shape
=
input_shape
;
auto
output_trans_shape
=
std
::
vector
<
size_t
>
{
n
,
m
};
auto
input_tensor
=
TensorWrapper
(
input
,
input_shape
,
in_dtype
);
auto
output_tensor
=
TensorWrapper
(
output
,
output_shape
,
out_dtype
,
amax_out
,
scale
,
scale_inv
);
output_tensor
.
set_columnwise_data
(
output_trans
,
out_dtype
,
output_trans_shape
);
output_tensor
.
set_columnwise_scale_inv
(
scale_inv
,
DType
::
kFloat32
,
std
::
vector
<
size_t
>
{
1
});
nvte_quantize
(
input_tensor
.
data
(),
output_tensor
.
data
(),
stream
);
return
ffi_with_cuda_error_check
();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL
(
CastTransposeHandler
,
CastTransposeFFI
,
FFI
::
Bind
()
.
Ctx
<
FFI_Stream_Type
>
()
// stream
.
Arg
<
Buffer_Type
>
()
// input
.
Arg
<
Buffer_Type
>
()
// amax
.
Arg
<
Buffer_Type
>
()
// scale
.
Arg
<
Buffer_Type
>
()
// scale_inv
.
Ret
<
Buffer_Type
>
()
// output
.
Ret
<
Buffer_Type
>
()
// output_trans
.
Ret
<
Buffer_Type
>
()
// amax_out
.
Attr
<
int64_t
>
(
"transpose_axis"
),
FFI_CudaGraph_Traits
);
pybind11
::
tuple
GetDBiasCastTransposeWorkspaceSizes
(
size_t
batch_size
,
size_t
hidden_size
,
DType
in_dtype
,
DType
out_dtype
)
{
auto
input_shape
=
std
::
vector
<
size_t
>
{
batch_size
,
hidden_size
};
auto
output_shape
=
std
::
vector
<
size_t
>
{
batch_size
,
hidden_size
};
auto
output_trans_shape
=
std
::
vector
<
size_t
>
{
hidden_size
,
batch_size
};
auto
dbias_shape
=
std
::
vector
<
size_t
>
{
hidden_size
};
// Evil hack to specify TE impl
// Note: nvte_quantize_dbias chooses its internal impl based on what
// pointers are allocated, e.g. whether to output with column-wise
// data. However, we don't have access to any allocated buffers in
// this function. We pass a dummy pointer as a workaround.
int
temp
=
0
;
auto
input_tensor
=
TensorWrapper
(
reinterpret_cast
<
void
*>
(
&
temp
),
input_shape
,
in_dtype
);
auto
output_tensor
=
TensorWrapper
(
reinterpret_cast
<
void
*>
(
&
temp
),
output_shape
,
out_dtype
);
output_tensor
.
set_columnwise_data
(
reinterpret_cast
<
void
*>
(
&
temp
),
out_dtype
,
output_trans_shape
);
auto
dbias_tensor
=
TensorWrapper
(
reinterpret_cast
<
void
*>
(
&
temp
),
dbias_shape
,
in_dtype
);
TensorWrapper
dummy_workspace
;
nvte_quantize_dbias
(
input_tensor
.
data
(),
output_tensor
.
data
(),
dbias_tensor
.
data
(),
dummy_workspace
.
data
(),
nullptr
);
auto
work_shape
=
MakeShapeVector
(
dummy_workspace
.
shape
());
return
pybind11
::
make_tuple
(
std
::
make_pair
(
work_shape
,
dummy_workspace
.
dtype
()));
}
void
DBiasCastTranspose
(
cudaStream_t
stream
,
void
**
buffers
,
const
char
*
opaque
,
size_t
opaque_len
)
{
auto
*
input
=
buffers
[
0
];
float
*
amax
=
reinterpret_cast
<
float
*>
(
buffers
[
1
]);
float
*
scale
=
reinterpret_cast
<
float
*>
(
buffers
[
2
]);
float
*
scale_inv
=
reinterpret_cast
<
float
*>
(
buffers
[
3
]);
auto
*
output
=
buffers
[
4
];
auto
*
output_trans
=
buffers
[
5
];
auto
*
dbias
=
buffers
[
6
];
float
*
amax_out
=
reinterpret_cast
<
float
*>
(
buffers
[
7
]);
void
*
workspace_ptr
=
buffers
[
8
];
const
auto
&
desc
=
*
UnpackOpaque
<
CustomCallCommonWkDescriptor
>
(
opaque
,
opaque_len
);
NVTE_CHECK
(
amax
==
amax_out
,
"amax not bound to amax_out in TE/JAX DBiasCastTranspose primitive."
);
if
(
!
use_fp8
(
desc
.
out_dtype
))
{
scale
=
nullptr
;
scale_inv
=
nullptr
;
amax_out
=
nullptr
;
}
auto
m
=
desc
.
shape
.
dims
[
0
];
auto
n
=
desc
.
shape
.
dims
[
1
];
auto
input_shape
=
std
::
vector
<
size_t
>
{
m
,
n
};
auto
output_shape
=
std
::
vector
<
size_t
>
{
m
,
n
};
auto
output_trans_shape
=
std
::
vector
<
size_t
>
{
n
,
m
};
auto
dbias_shape
=
std
::
vector
<
size_t
>
{
n
};
auto
input_tensor
=
TensorWrapper
(
input
,
input_shape
,
desc
.
in_dtype
);
auto
output_tensor
=
TensorWrapper
(
output
,
output_shape
,
desc
.
out_dtype
,
amax_out
,
scale
,
scale_inv
);
output_tensor
.
set_columnwise_data
(
output_trans
,
desc
.
out_dtype
,
output_trans_shape
);
output_tensor
.
set_columnwise_scale_inv
(
scale_inv
,
DType
::
kFloat32
,
std
::
vector
<
size_t
>
{
1
});
auto
dbias_tensor
=
TensorWrapper
(
dbias
,
dbias_shape
,
desc
.
in_dtype
);
auto
workspace
=
TensorWrapper
(
workspace_ptr
,
desc
.
wkshape
.
to_vector
(),
desc
.
wk_dtype
);
nvte_quantize_dbias
(
input_tensor
.
data
(),
output_tensor
.
data
(),
dbias_tensor
.
data
(),
workspace
.
data
(),
stream
);
}
Error_Type
DBiasCastTransposeFFI
(
cudaStream_t
stream
,
Buffer_Type
input_buf
,
Buffer_Type
amax_buf
,
Buffer_Type
scale_buf
,
Buffer_Type
scale_inv_buf
,
Result_Type
output_buf
,
Result_Type
output_trans_buf
,
Result_Type
dbias_buf
,
Result_Type
amax_out_buf
,
Result_Type
workspace_buf
,
int64_t
transpose_axis
)
{
auto
in_dtype
=
convert_ffi_datatype_to_te_dtype
(
input_buf
.
element_type
());
auto
out_dtype
=
convert_ffi_datatype_to_te_dtype
(
output_buf
->
element_type
());
auto
workspace_dtype
=
convert_ffi_datatype_to_te_dtype
(
workspace_buf
->
element_type
());
auto
*
input
=
input_buf
.
untyped_data
();
float
*
amax
=
reinterpret_cast
<
float
*>
(
amax_buf
.
untyped_data
());
float
*
scale
=
reinterpret_cast
<
float
*>
(
scale_buf
.
untyped_data
());
float
*
scale_inv
=
reinterpret_cast
<
float
*>
(
scale_inv_buf
.
untyped_data
());
auto
*
output
=
output_buf
->
untyped_data
();
auto
*
output_trans
=
output_trans_buf
->
untyped_data
();
auto
*
dbias
=
dbias_buf
->
untyped_data
();
float
*
amax_out
=
reinterpret_cast
<
float
*>
(
amax_out_buf
->
untyped_data
());
void
*
workspace
=
workspace_buf
->
untyped_data
();
NVTE_CHECK
(
amax
==
amax_out
,
"amax not bound to amax_out in TE/JAX DBiasCastTranspose primitive."
);
if
(
!
use_fp8
(
out_dtype
))
{
scale
=
nullptr
;
scale_inv
=
nullptr
;
amax_out
=
nullptr
;
}
auto
input_dims
=
input_buf
.
dimensions
();
auto
workspace_dims
=
workspace_buf
->
dimensions
();
if
(
transpose_axis
<
0
)
transpose_axis
+=
input_dims
.
size
();
auto
m
=
product
(
input_dims
,
0
,
transpose_axis
);
auto
n
=
product
(
input_dims
,
transpose_axis
,
input_dims
.
size
());
auto
input_shape
=
std
::
vector
<
size_t
>
{
m
,
n
};
auto
output_shape
=
std
::
vector
<
size_t
>
{
m
,
n
};
auto
output_trans_shape
=
std
::
vector
<
size_t
>
{
n
,
m
};
auto
dbias_shape
=
std
::
vector
<
size_t
>
{
n
};
std
::
vector
<
size_t
>
workspace_shape
(
workspace_dims
.
begin
(),
workspace_dims
.
end
());
auto
input_tensor
=
TensorWrapper
(
input
,
input_shape
,
in_dtype
);
auto
output_tensor
=
TensorWrapper
(
output
,
output_shape
,
out_dtype
,
amax_out
,
scale
,
scale_inv
);
output_tensor
.
set_columnwise_data
(
output_trans
,
out_dtype
,
output_trans_shape
);
output_tensor
.
set_columnwise_scale_inv
(
scale_inv
,
DType
::
kFloat32
,
std
::
vector
<
size_t
>
{
1
});
auto
dbias_tensor
=
TensorWrapper
(
dbias
,
dbias_shape
,
in_dtype
);
auto
workspace_tensor
=
TensorWrapper
(
workspace
,
workspace_shape
,
workspace_dtype
);
nvte_quantize_dbias
(
input_tensor
.
data
(),
output_tensor
.
data
(),
dbias_tensor
.
data
(),
workspace_tensor
.
data
(),
stream
);
return
ffi_with_cuda_error_check
();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL
(
DBiasCastTransposeHandler
,
DBiasCastTransposeFFI
,
FFI
::
Bind
()
.
Ctx
<
FFI_Stream_Type
>
()
// stream
.
Arg
<
Buffer_Type
>
()
// input
.
Arg
<
Buffer_Type
>
()
// amax
.
Arg
<
Buffer_Type
>
()
// scale
.
Arg
<
Buffer_Type
>
()
// scale_inv
.
Ret
<
Buffer_Type
>
()
// output
.
Ret
<
Buffer_Type
>
()
// output_trans
.
Ret
<
Buffer_Type
>
()
// dbias
.
Ret
<
Buffer_Type
>
()
// amax_out
.
Ret
<
Buffer_Type
>
()
// workspace
.
Attr
<
int64_t
>
(
"transpose_axis"
),
FFI_CudaGraph_Traits
);
}
// namespace jax
}
// namespace transformer_engine
transformer_engine/jax/dense.py
0 → 100644
View file @
a207db1d
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Dense layer transformation operations for Transformer Engine in JAX.
This module provides optimized dense layer transformation operations for transformer
architectures, including support for quantization and automatic differentiation.
It implements matrix multiplication with optional bias addition and supports
customizable contracting dimensions for flexible tensor operations.
"""
from
typing
import
Tuple
,
Sequence
from
functools
import
partial
import
jax
import
jax.numpy
as
jnp
from
.
import
cpp_extensions
as
tex
from
.quantize
import
QuantizerSet
,
noop_quantizer_set
def
dense
(
x
:
jnp
.
ndarray
,
kernel
:
jnp
.
ndarray
,
bias
:
jnp
.
ndarray
=
None
,
contracting_dims
:
Tuple
[
Sequence
[
int
],
Sequence
[
int
]]
=
((
1
,),
(
0
,)),
quantizer_set
:
QuantizerSet
=
noop_quantizer_set
,
):
"""Perform dense layer transformation with optional quantization.
This function implements matrix multiplication with optional bias addition,
supporting quantization and custom contracting dimensions. It's optimized
for transformer architectures and supports automatic differentiation.
Args:
x: Input tensor
kernel: Weight matrix for the dense layer transformation
bias: Optional bias tensor to add after the transformation
contracting_dims: Tuple of sequences specifying which dimensions to contract
quantizer_set: QuantizerSet which contains quantizers for different tensor types
Returns:
Transformed output tensor
"""
# Remove when tex.quantize() can handle quantizer=None
if
quantizer_set
==
noop_quantizer_set
:
output
=
tex
.
gemm
(
x
,
kernel
,
contracting_dims
)
if
bias
is
not
None
:
bias_new_shape
=
(
1
,)
*
(
output
.
ndim
-
bias
.
ndim
)
+
bias
.
shape
output
+=
jnp
.
reshape
(
bias
,
bias_new_shape
)
else
:
output
=
_dense
(
x
,
kernel
,
bias
,
contracting_dims
,
quantizer_set
)
return
output
@partial(jax.custom_vjp, nondiff_argnums=(3,))
def _dense(x, kernel, bias, contracting_dims, quantizer_set):
    """Internal implementation of the dense transformation with a custom VJP.

    Thin entry point for the custom vector-Jacobian product: it simply runs the
    forward rule and discards the saved context.

    Args:
        x: Input tensor
        kernel: Weight matrix
        bias: Optional bias tensor
        contracting_dims: Contracting dimensions specification (non-differentiable)
        quantizer_set: QuantizerSet which contains quantizers for different tensor types

    Returns:
        Transformed output tensor
    """
    out, _ = _dense_fwd_rule(x, kernel, bias, contracting_dims, quantizer_set)
    return out
def _dense_fwd_rule(x, kernel, bias, contracting_dims, quantizer_set):
    """Forward pass rule for the dense layer transformation.

    Quantizes the input and the kernel, runs the forward GEMM, adds the
    (optional) bias, and stashes everything the backward pass needs.

    Args:
        x: Input tensor
        kernel: Weight matrix
        bias: Optional bias tensor
        contracting_dims: Contracting dimensions specification
        quantizer_set: QuantizerSet which contains quantizers for different tensor types

    Returns:
        Tuple of (output, context) for backward pass
    """
    x_cdims, k_cdims = contracting_dims

    q_x = tex.quantize(x, quantizer_set.x)
    q_kernel = tex.quantize(kernel, quantizer_set.kernel)

    # GEMM NN: row-wise input against col-wise kernel.
    output = tex.gemm(
        q_x.get_rowwise_tensor(),
        q_kernel.get_colwise_tensor(),
        (x_cdims, k_cdims),
    )

    use_bias = bias is not None
    if use_bias:
        # Left-pad bias with singleton dims so it broadcasts over the output.
        output += jnp.reshape(bias, (1,) * (output.ndim - bias.ndim) + bias.shape)

    # Save the alternate-layout tensors only when the quantizer produced both
    # layouts (is_2x2x); the backward pass consumes them for dgrad/wgrad.
    ctx = (
        q_x.get_colwise_tensor() if quantizer_set.x.is_2x2x() else None,
        q_kernel.get_rowwise_tensor() if quantizer_set.kernel.is_2x2x() else None,
        x.shape,
        kernel.shape,
        use_bias,
        quantizer_set,
    )
    return output, ctx
def _dense_bwd_rule(contracting_dims, ctx, grad):  # pylint: disable=unused-argument
    """Backward pass rule for the dense layer transformation.

    Args:
        contracting_dims: Contracting dimensions used in the forward pass
        ctx: Context saved by the forward rule
        grad: Gradient from upstream

    Returns:
        Tuple of gradients with respect to (x, kernel, bias, quantizer_set)
    """
    fwd_x_cdims, fwd_k_cdims = contracting_dims
    (
        colwise_casted_x,
        rowwise_casted_kernel,
        x_shape,
        kernel_shape,
        use_bias,
        quantizer_set,
    ) = ctx

    casted_grad, dbias = tex.quantize_dbias(
        grad, is_dbias=use_bias, quantizer=quantizer_set.dgrad
    )

    # GEMM NT (dgrad): grad's trailing dims correspond to the kernel's output
    # dims; offset them by the rank difference between grad and the kernel.
    g_cdims = tuple(
        range(grad.ndim - len(kernel_shape) + len(fwd_k_cdims), grad.ndim)
    )
    # Kernel's non-contracting dims from the forward pass.
    k_cdims = tuple(
        dim for dim in range(len(kernel_shape)) if dim not in fwd_k_cdims
    )
    dgrad = tex.gemm(
        casted_grad.get_rowwise_tensor(),
        rowwise_casted_kernel,
        (g_cdims, k_cdims),
    )

    # GEMM TN (wgrad): contract over x's non-contracting (leading) dims, which
    # line up with the same leading dims of grad.
    g_cdims = x_cdims = tuple(range(0, len(x_shape) - len(fwd_x_cdims)))
    wgrad = tex.gemm(
        colwise_casted_x, casted_grad.get_colwise_tensor(), (x_cdims, g_cdims)
    )

    return dgrad, wgrad, dbias, quantizer_set
# Register the custom forward/backward rules for _dense's custom VJP.
_dense.defvjp(_dense_fwd_rule, _dense_bwd_rule)
def grouped_dense(
    x_list,
    kernel_list,
    bias_list,
    contracting_dims_list,
    quantizer_set_list=None,
):
    """
    Perform grouped_dense layer transformation with optional quantization.
    """
    # Delegate straight to the custom-VJP implementation.
    return _grouped_dense(
        x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list
    )
@partial(jax.custom_vjp, nondiff_argnums=(3,))
def _grouped_dense(x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list):
    """Custom-VJP entry point for the grouped dense transformation.

    Runs the forward rule and drops the saved context; the contracting-dims
    list (argument 3) is treated as non-differentiable.
    """
    outputs, _ = _grouped_dense_fwd_rule(
        x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list
    )
    return outputs
def _grouped_dense_fwd_rule(
    x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list
):
    """Forward rule for the grouped dense transformation.

    For each group i, optionally quantizes x_list[i] and kernel_list[i], then
    performs one grouped GEMM over all groups at once.

    Args:
        x_list: List of input tensors, one per group
        kernel_list: List of weight matrices, one per group
        bias_list: Optional list of bias tensors (None disables bias)
        contracting_dims_list: Per-group contracting dimensions
        quantizer_set_list: Per-group QuantizerSets, or None to skip quantization

    Returns:
        Tuple of (output_list, context) for the backward pass
    """
    use_bias = bias_list is not None
    output_list = []
    x_rowwise_list = []
    x_colwise_list = []
    kernel_colwise_list = []
    kernel_rowwise_list = []
    x_shape_list = []
    kernel_shape_list = []
    if quantizer_set_list is None:
        # No quantization: the same (unquantized) tensors serve as both the
        # row-wise and col-wise operands.
        x_rowwise_list = x_list
        x_colwise_list = x_list
        kernel_colwise_list = kernel_list
        kernel_rowwise_list = kernel_list
        x_shape_list = [x.shape for x in x_list]
        kernel_shape_list = [kernel.shape for kernel in kernel_list]
    else:
        for i in range(len(x_list)):  # pylint: disable=consider-using-enumerate
            q_x = tex.quantize(x_list[i], quantizer_set_list[i].x)
            q_kernel = tex.quantize(kernel_list[i], quantizer_set_list[i].kernel)
            x_rowwise_list.append(q_x.get_rowwise_tensor())
            x_colwise_list.append(q_x.get_colwise_tensor())
            kernel_colwise_list.append(q_kernel.get_colwise_tensor())
            kernel_rowwise_list.append(q_kernel.get_rowwise_tensor())
            # NOTE(review): shapes here come from the quantized tensor's
            # `.data` attribute, while the unquantized branch records
            # `x.shape` directly — presumably these agree for row-wise
            # layouts; confirm for all scaling modes.
            x_shape_list.append(x_rowwise_list[-1].data.shape)
            kernel_shape_list.append(kernel_rowwise_list[-1].data.shape)
    # Single grouped GEMM over all groups (bias fused in when provided).
    output_list = tex.grouped_gemm(
        x_rowwise_list, kernel_colwise_list, contracting_dims_list, bias_list
    )
    ctx = (
        x_colwise_list,
        kernel_rowwise_list,
        x_shape_list,
        kernel_shape_list,
        use_bias,
        quantizer_set_list,
    )
    return output_list, ctx
def _grouped_dense_bwd_rule(contracting_dims_list, ctx, grad_list):
    """Backward rule for the grouped dense transformation.

    Computes per-group contracting dimensions for the dgrad (NT) and wgrad (TN)
    GEMMs, then performs two grouped GEMMs across all groups.

    Args:
        contracting_dims_list: Per-group forward contracting dimensions
        ctx: Context saved by the forward rule
        grad_list: Per-group upstream gradients

    Returns:
        Tuple of (dgrad_list, wgrad_list, dbias_list, quantizer_set_list)
    """
    (
        colwise_x_list,
        rowwise_kernel_list,
        x_shape_list,
        kernel_shape_list,
        use_bias,
        quantizer_set_list,
    ) = ctx
    group_size = len(grad_list)
    dbias_list = []
    grad_rowwise_list = []
    grad_colwise_list = []
    dgrad_contracting_dims_list = []
    wgrad_contracting_dims_list = []
    for i in range(group_size):
        grad = grad_list[i]
        x_shape = x_shape_list[i]
        kernel_shape = kernel_shape_list[i]
        fwd_contracting_dims = contracting_dims_list[i]
        if quantizer_set_list is None:
            # NOTE(review): `casted_grad` is assigned but unused in this
            # branch (the raw `grad` is appended below), and dbias is
            # computed unconditionally even when use_bias is False —
            # presumably harmless, but confirm intent.
            casted_grad = grad
            dbias = tex.quantization._jax_dbias(grad)
            grad_rowwise_list.append(grad)
            grad_colwise_list.append(grad)
        else:
            quantizer_set = quantizer_set_list[i]
            casted_grad, dbias = tex.quantize_dbias(
                grad, is_dbias=use_bias, quantizer=quantizer_set.dgrad
            )
            grad_rowwise_list.append(casted_grad.get_rowwise_tensor())
            grad_colwise_list.append(casted_grad.get_colwise_tensor())
        dbias_list.append(dbias)

        # GEMM NT (dgrad): grad's trailing dims (offset by the rank difference
        # vs the kernel) contract against the kernel's non-contracting dims.
        fwd_x_contracting_dims, fwd_k_contracting_dims = fwd_contracting_dims
        g_contracting_dim = tuple(
            range(grad.ndim - len(kernel_shape) + len(fwd_k_contracting_dims), grad.ndim)
        )
        k_contracting_dim = tuple(
            dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims
        )
        dgrad_contracting_dims = (g_contracting_dim, k_contracting_dim)
        dgrad_contracting_dims_list.append(dgrad_contracting_dims)

        # GEMM TN (wgrad): contract over x's non-contracting (leading) dims,
        # which align with grad's leading dims.
        g_contracting_dim = x_contracting_dim = tuple(
            range(0, len(x_shape) - len(fwd_x_contracting_dims))
        )
        wgrad_contracting_dims = (x_contracting_dim, g_contracting_dim)
        wgrad_contracting_dims_list.append(wgrad_contracting_dims)

    dgrad_list = tex.grouped_gemm(
        grad_rowwise_list, rowwise_kernel_list, dgrad_contracting_dims_list
    )
    wgrad_list = tex.grouped_gemm(
        colwise_x_list, grad_colwise_list, wgrad_contracting_dims_list
    )
    return dgrad_list, wgrad_list, dbias_list, quantizer_set_list
# Register the custom forward/backward rules for _grouped_dense's custom VJP.
_grouped_dense.defvjp(_grouped_dense_fwd_rule, _grouped_dense_bwd_rule)
transformer_engine/jax/dot.py
deleted
100644 → 0
View file @
fbee8990
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX te modules"""
from
typing
import
List
,
Tuple
,
Sequence
from
functools
import
partial
import
jax
import
jax.numpy
as
jnp
from
.
import
cpp_extensions
as
tex
from
.fp8
import
FP8Helper
,
FP8MetaPackage
Precision
=
jax
.
lax
.
Precision
def
type_safe_dot_general
(
x
,
kernel
,
fp8_meta_pkg
:
FP8MetaPackage
=
None
,
contracting_dims
:
Tuple
[
Sequence
[
int
],
Sequence
[
int
]]
=
((
1
,),
(
0
,)),
)
->
jnp
.
ndarray
:
"""
Type safe dot_general, including FP8.
"""
if
fp8_meta_pkg
is
None
:
assert
x
.
dtype
==
kernel
.
dtype
,
f
"lhs dtype =
{
x
.
dtype
}
, rhs dtype =
{
kernel
.
dtype
}
"
return
jax
.
lax
.
dot_general
(
x
,
kernel
,
(
contracting_dims
,
((),
())))
amax_list
=
fp8_meta_pkg
.
amax_list
scale_list
=
fp8_meta_pkg
.
scale_list
fwd_dtype
=
FP8Helper
.
FWD_DTYPE
bwd_dtype
=
FP8Helper
.
BWD_DTYPE
return
_fp8_dot
(
x
,
kernel
,
amax_list
,
scale_list
,
fwd_dtype
,
bwd_dtype
,
contracting_dims
)
def quantize(x, q_dtype, scale):
    """
    Quantize ``x`` into ``q_dtype`` using ``scale``; also return the current
    amax (max absolute value of ``x``) for delayed-scaling bookkeeping.
    """
    new_amax = jnp.max(jnp.abs(x)).astype(scale.dtype)
    # Clamp the scaled values to the representable range of the target dtype.
    limit = (jnp.finfo(q_dtype).max).astype(x.dtype)
    scaled = x * scale.astype(x.dtype)
    return jnp.clip(scaled, -limit, limit).astype(q_dtype), new_amax
def dequantize(x, dq_dtype, scale_inv):
    """
    Dequantize ``x`` back to ``dq_dtype`` by multiplying with ``scale_inv``.
    """
    widened = x.astype(dq_dtype)
    return widened * scale_inv.astype(dq_dtype)
# Apply jit to guarantee correctness of FP8 GEMM.
@
partial
(
jax
.
jit
,
static_argnums
=
(
4
,
5
,
6
))
def
fp8_dot_impl
(
q_lhs
:
jnp
.
ndarray
,
q_rhs
:
jnp
.
ndarray
,
lhs_scale_inv
:
jnp
.
ndarray
,
rhs_scale_inv
:
jnp
.
ndarray
,
ctype
:
jnp
.
dtype
,
# computing type
contracting_dims
:
Tuple
[
Sequence
[
int
],
Sequence
[
int
]],
precision
:
Precision
=
None
,
):
"""
FP8 GEMM for XLA pattern match
"""
dim_nums
=
(
contracting_dims
,
((),
()))
lhs
=
dequantize
(
q_lhs
,
ctype
,
lhs_scale_inv
)
rhs
=
dequantize
(
q_rhs
,
ctype
,
rhs_scale_inv
)
return
jax
.
lax
.
dot_general
(
lhs
,
rhs
,
dim_nums
,
precision
=
precision
)
def get_precision_of_fp8_dot(enable_2xACC: bool):
    """
    Get the lax Precision to use for the FP8 dot: HIGHEST when double
    accumulation is enabled, DEFAULT otherwise.
    """
    if enable_2xACC:
        return jax.lax.Precision.HIGHEST
    return jax.lax.Precision.DEFAULT
@
partial
(
jax
.
custom_vjp
,
nondiff_argnums
=
(
4
,
5
,
6
))
def
_fp8_dot
(
x
:
jnp
.
ndarray
,
kernel
:
jnp
.
ndarray
,
amax_list
:
List
[
jnp
.
ndarray
],
scale_list
:
List
[
jnp
.
ndarray
],
fwd_dtype
:
jnp
.
dtype
,
bwd_dtype
:
jnp
.
dtype
,
contracting_dims
:
Tuple
[
Sequence
[
int
],
Sequence
[
int
]],
):
output
,
_
=
_fp8_dot_fwd_rule
(
x
,
kernel
,
amax_list
,
scale_list
,
fwd_dtype
,
bwd_dtype
,
contracting_dims
)
return
output
def
_fp8_dot_fwd_rule
(
x
,
kernel
,
amax_list
,
scale_list
,
fwd_dtype
,
bwd_dtype
,
# pylint: disable=unused-argument
contracting_dims
,
):
maybe_fm32_to_fp32
,
maybe_fp32_to_fm32
=
FP8Helper
.
generate_fp8_meta_dtype_converter_pair
(
*
amax_list
,
*
scale_list
)
amax_list
=
maybe_fm32_to_fp32
(
*
amax_list
)
scale_list
=
maybe_fm32_to_fp32
(
*
scale_list
)
lhs_contracting_dims
,
rhs_contracting_dims
=
contracting_dims
x_shape_suf
=
x
.
shape
[
min
(
lhs_contracting_dims
)
:]
kernel_shape_pre
=
kernel
.
shape
[:
max
(
rhs_contracting_dims
)
+
1
]
assert
x_shape_suf
==
kernel_shape_pre
fp8_dtype_list
=
[
fwd_dtype
,
fwd_dtype
,
bwd_dtype
]
scale_list
,
scale_inv_list
=
FP8MetaPackage
.
update_fp8_scale
(
amax_list
,
scale_list
,
fp8_dtype_list
)
amax_list
=
FP8MetaPackage
.
update_amax_list
(
amax_list
)
x_scale
=
scale_list
[
FP8MetaPackage
.
INPUT_IDX
]
x_scale_inv
=
scale_inv_list
[
FP8MetaPackage
.
INPUT_IDX
]
# Note (Ming Huang): Use native cast to allow XLA handle tranpose for avoiding
# unnecessary copy to break FP8 GEMM pattern matching.
casted_x
,
updated_x_amax
=
quantize
(
x
,
fwd_dtype
,
x_scale
)
kernel_scale
=
scale_list
[
FP8MetaPackage
.
WEIGHT_IDX
]
kernel_scale_inv
=
scale_inv_list
[
FP8MetaPackage
.
WEIGHT_IDX
]
# Note (Ming Huang): Use native cast to allow XLA handle tranpose for avoiding
# unnecessary copy to break FP8 GEMM pattern matching.
casted_kernel
,
updated_kernel_amax
=
quantize
(
kernel
,
fwd_dtype
,
kernel_scale
)
output
=
fp8_dot_impl
(
casted_x
,
casted_kernel
,
x_scale_inv
,
kernel_scale_inv
,
x
.
dtype
,
(
lhs_contracting_dims
,
rhs_contracting_dims
),
get_precision_of_fp8_dot
(
FP8Helper
.
FP8_2X_ACC_FPROP
),
)
ctx
=
(
casted_x
,
casted_kernel
,
amax_list
,
scale_list
,
scale_inv_list
,
updated_x_amax
,
updated_kernel_amax
,
x
.
shape
,
kernel
.
shape
,
maybe_fp32_to_fm32
,
)
return
output
,
ctx
def
_fp8_dot_bwd_rule
(
fwd_dtype
,
bwd_dtype
,
contracting_dims
,
ctx
,
grad
):
# pylint: disable=unused-argument
lhs_contracting_dims
,
rhs_contracting_dims
=
contracting_dims
(
casted_x
,
casted_kernel
,
amax_list
,
scale_list
,
scale_inv_list
,
updated_x_amax
,
updated_kernel_amax
,
x_shape
,
kernel_shape
,
maybe_fp32_to_fm32
,
)
=
ctx
grad_amax
=
amax_list
[
FP8MetaPackage
.
GRAD_IDX
][
0
:
1
]
grad_scale
=
scale_list
[
FP8MetaPackage
.
GRAD_IDX
]
grad_scale_inv
=
scale_inv_list
[
FP8MetaPackage
.
GRAD_IDX
]
casted_grad
,
casted_grad_t
,
updated_grad_amax
=
tex
.
cast_transpose
(
grad
,
grad_amax
,
grad_scale
,
grad_scale_inv
,
bwd_dtype
,
static_axis_boundary
=-
1
,
transpose_axis_boundary
=
min
(
lhs_contracting_dims
),
)
x_constracting_dim
=
tuple
(
range
(
0
,
len
(
x_shape
)
-
len
(
lhs_contracting_dims
)))
gt_constracting_dim
=
tuple
(
range
(
grad
.
ndim
-
len
(
x_constracting_dim
),
grad
.
ndim
))
x_scale_inv
=
scale_inv_list
[
FP8MetaPackage
.
INPUT_IDX
]
wgrad
=
fp8_dot_impl
(
casted_x
,
casted_grad_t
,
x_scale_inv
,
grad_scale_inv
,
grad
.
dtype
,
(
x_constracting_dim
,
gt_constracting_dim
),
get_precision_of_fp8_dot
(
FP8Helper
.
FP8_2X_ACC_WGRAD
),
)
g_constracting_dim
=
tuple
(
range
(
grad
.
ndim
-
len
(
kernel_shape
)
+
len
(
rhs_contracting_dims
),
grad
.
ndim
)
)
k_constracting_dim
=
tuple
(
range
(
len
(
rhs_contracting_dims
),
len
(
kernel_shape
)))
kernel_scale_inv
=
scale_inv_list
[
FP8MetaPackage
.
WEIGHT_IDX
]
dgrad
=
fp8_dot_impl
(
casted_grad
,
casted_kernel
,
grad_scale_inv
,
kernel_scale_inv
,
grad
.
dtype
,
(
g_constracting_dim
,
k_constracting_dim
),
get_precision_of_fp8_dot
(
FP8Helper
.
FP8_2X_ACC_DGRAD
),
)
amax_list
[
FP8MetaPackage
.
INPUT_IDX
]
=
(
amax_list
[
FP8MetaPackage
.
INPUT_IDX
].
at
[
0
].
set
(
updated_x_amax
)
)
amax_list
[
FP8MetaPackage
.
WEIGHT_IDX
]
=
(
amax_list
[
FP8MetaPackage
.
WEIGHT_IDX
].
at
[
0
].
set
(
updated_kernel_amax
)
)
amax_list
[
FP8MetaPackage
.
GRAD_IDX
]
=
(
amax_list
[
FP8MetaPackage
.
GRAD_IDX
].
at
[
0
].
set
(
updated_grad_amax
[
0
])
)
amax_list
=
maybe_fp32_to_fm32
(
*
amax_list
)
scale_list
=
maybe_fp32_to_fm32
(
*
scale_list
)
return
dgrad
,
wgrad
,
amax_list
,
scale_list
_fp8_dot
.
defvjp
(
_fp8_dot_fwd_rule
,
_fp8_dot_bwd_rule
)
transformer_engine/jax/flax/__init__.py
View file @
a207db1d
...
@@ -3,7 +3,7 @@
...
@@ -3,7 +3,7 @@
# See LICENSE for license information.
# See LICENSE for license information.
"""Transformer Engine bindings for JAX"""
"""Transformer Engine bindings for JAX"""
from
.module
import
DenseGeneral
,
LayerNorm
from
.module
import
DenseGeneral
,
LayerNorm
from
.module
import
LayerNormDenseGeneral
,
LayerNormMLP
,
TransformerEngineBase
from
.module
import
LayerNormDenseGeneral
,
LayerNormMLP
from
.transformer
import
extend_logical_axis_rules
from
.transformer
import
extend_logical_axis_rules
from
.transformer
import
DotProductAttention
,
MultiHeadAttention
,
RelativePositionBiases
from
.transformer
import
DotProductAttention
,
MultiHeadAttention
,
RelativePositionBiases
from
.transformer
import
TransformerLayer
,
TransformerLayerType
from
.transformer
import
TransformerLayer
,
TransformerLayerType
...
@@ -13,7 +13,6 @@ __all__ = [
...
@@ -13,7 +13,6 @@ __all__ = [
"LayerNorm"
,
"LayerNorm"
,
"LayerNormDenseGeneral"
,
"LayerNormDenseGeneral"
,
"LayerNormMLP"
,
"LayerNormMLP"
,
"TransformerEngineBase"
,
"extend_logical_axis_rules"
,
"extend_logical_axis_rules"
,
"DotProductAttention"
,
"DotProductAttention"
,
"MultiHeadAttention"
,
"MultiHeadAttention"
,
...
...
transformer_engine/jax/flax/module.py
View file @
a207db1d
...
@@ -4,7 +4,7 @@
...
@@ -4,7 +4,7 @@
"""
"""
Wrapper module for Transformer related layers with FP8 support.
Wrapper module for Transformer related layers with FP8 support.
"""
"""
import
functools
from
functools
import
reduce
import
operator
import
operator
from
typing
import
Any
,
Callable
,
Iterable
,
List
,
Sequence
,
Tuple
,
Union
from
typing
import
Any
,
Callable
,
Iterable
,
List
,
Sequence
,
Tuple
,
Union
...
@@ -17,14 +17,17 @@ from jax import nn as jax_nn
...
@@ -17,14 +17,17 @@ from jax import nn as jax_nn
from
jax
import
random
as
jax_random
from
jax
import
random
as
jax_random
from
jax.ad_checkpoint
import
checkpoint_name
from
jax.ad_checkpoint
import
checkpoint_name
from
..dot
import
type_safe_dot_general
from
..dense
import
dense
from
..fp8
import
FP8Helper
,
FP8MetaPackage
from
..layernorm
import
canonicalize_layernorm_type
from
..layernorm
import
canonicalize_norm_type
from
..layernorm
import
layernorm
,
layernorm_fp8_dot
from
..layernorm
import
layernorm
from
..layernorm_mlp
import
fused_layernorm_fp8_mlp
,
activation_lu
from
..layernorm_dense
import
layernorm_dense
from
..layernorm_mlp
import
layernorm_mlp
from
..activation
import
activation
from
..softmax
import
softmax
,
SoftmaxType
from
..softmax
import
softmax
,
SoftmaxType
from
..sharding
import
with_sharding_constraint_by_logical_axes
from
..sharding
import
with_sharding_constraint_by_logical_axes
from
..cpp_extensions
import
is_softmax_kernel_available
from
..cpp_extensions
import
is_softmax_kernel_available
from
..quantize
import
QuantizerFactory
,
QuantizeConfig
,
QuantizeMeta
,
QuantizeMetaSet
,
ScalingMode
PRNGKey
=
Any
PRNGKey
=
Any
Shape
=
Tuple
[
int
,
...]
Shape
=
Tuple
[
int
,
...]
...
@@ -57,17 +60,24 @@ def _obtain_default_layernorm_scale_init_if_need(original_init, zero_centered_ga
...
@@ -57,17 +60,24 @@ def _obtain_default_layernorm_scale_init_if_need(original_init, zero_centered_ga
def
_create_layernorm_parameters
(
def
_create_layernorm_parameters
(
layernorm_type
,
shape
,
scale_init
,
scale_axes
,
bias_init
,
bias_axes
,
input_dtype
,
dtype
norm_type
,
shape
,
scale_init
,
scale_axes
,
bias_init
,
bias_axes
,
input_dtype
,
dtype
,
):
):
scale
=
nn_partitioning
.
param_with_axes
(
"scale"
,
scale_init
,
shape
,
dtype
,
axes
=
scale_axes
)
scale
=
nn_partitioning
.
param_with_axes
(
"scale"
,
scale_init
,
shape
,
dtype
,
axes
=
scale_axes
)
scale
=
scale
.
astype
(
input_dtype
)
scale
=
scale
.
astype
(
input_dtype
)
layer
norm_type
=
canonicalize_
layer
norm_type
(
layer
norm_type
)
norm_type
=
canonicalize_norm_type
(
norm_type
)
if
layer
norm_type
==
"layernorm"
:
if
norm_type
==
"layernorm"
:
bias
=
nn_partitioning
.
param_with_axes
(
"ln_bias"
,
bias_init
,
shape
,
dtype
,
axes
=
bias_axes
)
bias
=
nn_partitioning
.
param_with_axes
(
"ln_bias"
,
bias_init
,
shape
,
dtype
,
axes
=
bias_axes
)
bias
=
bias
.
astype
(
input_dtype
)
bias
=
jnp
.
asarray
(
bias
,
input_dtype
)
else
:
else
:
assert
layer
norm_type
==
"rmsnorm"
assert
norm_type
==
"rmsnorm"
bias
=
None
bias
=
None
return
scale
,
bias
return
scale
,
bias
...
@@ -315,7 +325,7 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods
...
@@ -315,7 +325,7 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods
x
,
x
,
scale
,
scale
,
ln_bias
,
ln_bias
,
layer
norm_type
=
self
.
layernorm_type
,
norm_type
=
self
.
layernorm_type
,
zero_centered_gamma
=
self
.
zero_centered_gamma
,
zero_centered_gamma
=
self
.
zero_centered_gamma
,
epsilon
=
self
.
epsilon
,
epsilon
=
self
.
epsilon
,
)
)
...
@@ -328,49 +338,44 @@ class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-method
...
@@ -328,49 +338,44 @@ class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-method
Base class of transformer engine
Base class of transformer engine
"""
"""
@
staticmethod
def
generate_quantizer_set
(
self
,
postfix
:
str
=
""
):
def
generate_fp8_meta_set
(
postfix
:
str
)
->
FP8MetaPackage
:
"""
"""
Generate a set of FP8 meta for a GEMM.
Generate a set of FP8 meta for a GEMM.
"""
"""
input_name_post_fix
=
f
"_i_
{
postfix
}
"
def
generate_quantize_meta
(
quantizer_name
:
str
):
weight_name_post_fix
=
f
"_w_
{
postfix
}
"
scale
=
self
.
variable
(
grad_name_post_fix
=
f
"_g_
{
postfix
}
"
QuantizeConfig
.
COLLECTION_NAME
,
f
"
{
quantizer_name
}{
postfix
}
_scale"
,
def
generate_a_set
(
target_postfix
):
amax
=
nn_partitioning
.
variable_with_axes
(
FP8Helper
.
FP8_COLLECTION_NAME
,
f
"
{
FP8Helper
.
FP8_AMAX_NAME
}{
target_postfix
}
"
,
jnp
.
zeros
,
(
FP8Helper
.
AMAX_HISTORY_LEN
,),
jnp
.
float32
,
axes
=
(
None
,),
)
scale
=
nn_partitioning
.
variable_with_axes
(
FP8Helper
.
FP8_COLLECTION_NAME
,
f
"
{
FP8Helper
.
FP8_SCALE_NAME
}{
target_postfix
}
"
,
jnp
.
ones
,
jnp
.
ones
,
(
1
,),
(
1
,),
jnp
.
float32
,
jnp
.
float32
,
axes
=
(
None
,),
).
value
)
amax_history
=
self
.
variable
(
QuantizeConfig
.
COLLECTION_NAME
,
return
amax
.
value
,
scale
.
value
f
"
{
quantizer_name
}{
postfix
}
_amax_history"
,
jnp
.
zeros
,
input_amax
,
input_scale
=
generate_a_set
(
input_name_post_fix
)
(
QuantizeConfig
.
AMAX_HISTORY_LEN
,),
weight_amax
,
weight_scale
=
generate_a_set
(
weight_name_post_fix
)
jnp
.
float32
,
grad_amax
,
grad_scale
=
generate_a_set
(
grad_name_post_fix
)
).
value
return
QuantizeMeta
(
scale
=
scale
,
amax_history
=
amax_history
)
if
QuantizeConfig
.
SCALING_MODE
==
ScalingMode
.
NVTE_DELAYED_TENSOR_SCALING
:
x_meta
=
generate_quantize_meta
(
"x"
)
kernel_meta
=
generate_quantize_meta
(
"kernel"
)
grad_meta
=
generate_quantize_meta
(
"grad"
)
quantize_meta_set
=
QuantizeMetaSet
(
x
=
x_meta
,
kernel
=
kernel_meta
,
grad
=
grad_meta
)
kwargs
=
{
"quantize_meta_set"
:
quantize_meta_set
}
else
:
kwargs
=
{}
return
FP8MetaPackage
(
quantizer_set
=
QuantizerFactory
.
create_set
(
**
kwargs
)
input_amax
,
input_scale
,
weight_amax
,
weight_scale
,
grad_amax
,
grad_scale
return
quantizer_set
)
class
DenseGeneral
(
TransformerEngineBase
):
class
DenseGeneral
(
TransformerEngineBase
):
r
"""
r
"""
Applies a
linea
r transformation to the incoming data :math:`y = xA^T + b`.
Applies a
dense laye
r transformation to the incoming data :math:`y = xA^T + b`.
Parameters
Parameters
----------
----------
...
@@ -392,7 +397,7 @@ class DenseGeneral(TransformerEngineBase):
...
@@ -392,7 +397,7 @@ class DenseGeneral(TransformerEngineBase):
The name of axes used to shard bias with a corresponding mesh,
The name of axes used to shard bias with a corresponding mesh,
only used when :attr:`use_bias=True`.
only used when :attr:`use_bias=True`.
enable_low_rank_adaptation: bool, default = False
enable_low_rank_adaptation: bool, default = False
Indicate whether to enable low rank adaptation for each
linear
layer.
Indicate whether to enable low rank adaptation for each
dense
layer.
low_rank_adaptation_dim: int, default = 32
low_rank_adaptation_dim: int, default = 32
The dimension for low rank adaptation, only used when
The dimension for low rank adaptation, only used when
:attr:`enable_low_rank_adaptation=True`
:attr:`enable_low_rank_adaptation=True`
...
@@ -435,7 +440,7 @@ class DenseGeneral(TransformerEngineBase):
...
@@ -435,7 +440,7 @@ class DenseGeneral(TransformerEngineBase):
@
nn
.
compact
@
nn
.
compact
def
__call__
(
self
,
inputs
:
Array
)
->
Array
:
def
__call__
(
self
,
inputs
:
Array
)
->
Array
:
"""
"""
Apply the
linea
r transformation to the input.
Apply the
dense laye
r transformation to the input.
Parameters
Parameters
----------
----------
...
@@ -455,28 +460,29 @@ class DenseGeneral(TransformerEngineBase):
...
@@ -455,28 +460,29 @@ class DenseGeneral(TransformerEngineBase):
axis
=
_normalize_axes
(
axis
,
inputs
.
ndim
)
axis
=
_normalize_axes
(
axis
,
inputs
.
ndim
)
kernel_shape
=
tuple
(
inputs
.
shape
[
ax
]
for
ax
in
axis
)
+
features
kernel_shape
=
tuple
(
inputs
.
shape
[
ax
]
for
ax
in
axis
)
+
features
kernel_param_shape
=
(
np
.
prod
([
inputs
.
shape
[
ax
]
for
ax
in
axis
]),)
+
features
kernel
=
nn_partitioning
.
param_with_axes
(
kernel
=
nn_partitioning
.
param_with_axes
(
"kernel"
,
self
.
kernel_init
,
kernel_shape
,
self
.
dtype
,
axes
=
self
.
kernel_axes
"kernel"
,
self
.
kernel_init
,
kernel_shape
,
self
.
dtype
,
axes
=
self
.
kernel_axes
)
)
if
not
FP8Helper
.
is_fp8_enabled
():
if
not
QuantizeConfig
.
is_fp8_enabled
():
kernel
=
kernel
.
astype
(
input_dtype
)
kernel
=
kernel
.
astype
(
input_dtype
)
kernel_compute_shape
=
(
reduce
(
operator
.
mul
,
[
inputs
.
shape
[
ax
]
for
ax
in
axis
],
1
),
reduce
(
operator
.
mul
,
features
,
1
),
)
kernel
=
jnp
.
reshape
(
kernel
,
kernel_compute_shape
)
if
self
.
use_bias
:
if
self
.
use_bias
:
bias
=
nn_partitioning
.
param_with_axes
(
bias
=
nn_partitioning
.
param_with_axes
(
"bias"
,
self
.
bias_init
,
features
,
self
.
dtype
,
axes
=
self
.
bias_axes
"bias"
,
self
.
bias_init
,
features
,
self
.
dtype
,
axes
=
self
.
bias_axes
)
)
bias
=
bias
.
astype
(
input_dtype
)
bias
=
bias
.
reshape
(
kernel_compute_shape
[
-
1
]).
astype
(
input_dtype
)
else
:
else
:
bias
=
None
bias
=
None
quantizer_set
=
self
.
generate_quantizer_set
()
contract_ind
=
tuple
(
range
(
0
,
len
(
axis
)))
contract_ind
=
tuple
(
range
(
0
,
len
(
axis
)))
fp8_meta_pkg
=
None
y
=
dense
(
if
FP8Helper
.
is_fp8_enabled
():
inputs
,
kernel
,
contracting_dims
=
(
axis
,
contract_ind
),
quantizer_set
=
quantizer_set
fp8_meta_pkg
=
TransformerEngineBase
.
generate_fp8_meta_set
(
"0"
)
y
=
type_safe_dot_general
(
inputs
,
kernel
,
fp8_meta_pkg
=
fp8_meta_pkg
,
contracting_dims
=
(
axis
,
contract_ind
)
)
)
if
self
.
enable_low_rank_adaptation
:
if
self
.
enable_low_rank_adaptation
:
...
@@ -486,7 +492,7 @@ class DenseGeneral(TransformerEngineBase):
...
@@ -486,7 +492,7 @@ class DenseGeneral(TransformerEngineBase):
self
.
low_rank_adaptation_dim
,
self
.
low_rank_adaptation_dim
,
)
)
lora_a_kernel_init_shape
=
(
lora_a_kernel_init_shape
=
(
kernel_
param
_shape
[
0
],
kernel_
compute
_shape
[
0
],
*
features
[:
-
1
],
*
features
[:
-
1
],
self
.
low_rank_adaptation_dim
,
self
.
low_rank_adaptation_dim
,
)
)
...
@@ -521,19 +527,20 @@ class DenseGeneral(TransformerEngineBase):
...
@@ -521,19 +527,20 @@ class DenseGeneral(TransformerEngineBase):
y
+=
jnp
.
reshape
(
bias
,
bias_shape
)
y
+=
jnp
.
reshape
(
bias
,
bias_shape
)
assert
y
.
dtype
==
input_dtype
assert
y
.
dtype
==
input_dtype
y
=
y
.
reshape
(
*
inputs
.
shape
[:
self
.
axis
],
*
features
)
return
y
return
y
class
LayerNormDenseGeneral
(
TransformerEngineBase
):
class
LayerNormDenseGeneral
(
TransformerEngineBase
):
r
"""
r
"""
Applies layer normalization followed by
linea
r transformation to the incoming data.
Applies layer normalization followed by
dense laye
r transformation to the incoming data.
Parameters
Parameters
----------
----------
features : Union[Iterable[int], int]
features : Union[Iterable[int], int]
The hidden size of each output sample.
The hidden size of each output sample.
enable_layernorm: bool, default = True
enable_layernorm: bool, default = True
Indicate whether to enable layer normalization before
linea
r transformation.
Indicate whether to enable layer normalization before
dense laye
r transformation.
layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
Indicate the type of layer normalization.
Indicate the type of layer normalization.
epsilon : float, default = 1e-6
epsilon : float, default = 1e-6
...
@@ -582,7 +589,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
...
@@ -582,7 +589,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
Indicate whether to return the output of layer normalization.
Indicate whether to return the output of layer normalization.
If set False, return None as the second tensor in outputs.
If set False, return None as the second tensor in outputs.
enable_low_rank_adaptation: bool, default = False
enable_low_rank_adaptation: bool, default = False
Indicate whether to enable low rank adaptation for each
linear
layer.
Indicate whether to enable low rank adaptation for each
dense
layer.
low_rank_adaptation_dim: int, default = 32
low_rank_adaptation_dim: int, default = 32
The dimension for low rank adaptation, only used when
The dimension for low rank adaptation, only used when
:attr:`enable_low_rank_adaptation=True`
:attr:`enable_low_rank_adaptation=True`
...
@@ -650,12 +657,13 @@ class LayerNormDenseGeneral(TransformerEngineBase):
...
@@ -650,12 +657,13 @@ class LayerNormDenseGeneral(TransformerEngineBase):
self
.
scale_init
,
self
.
scale_init
,
self
.
zero_centered_gamma
,
self
.
zero_centered_gamma
,
)
)
self
.
quantizer_set
=
QuantizerFactory
.
create_set
()
super
().
__post_init__
()
super
().
__post_init__
()
@
nn
.
compact
@
nn
.
compact
def
__call__
(
self
,
inputs
:
Array
)
->
Array
:
def
__call__
(
self
,
inputs
:
Array
)
->
Array
:
"""
"""
Apply layer normalization to the input followed by a
linea
r transformation.
Apply layer normalization to the input followed by a
dense laye
r transformation.
Parameters
Parameters
----------
----------
...
@@ -674,8 +682,10 @@ class LayerNormDenseGeneral(TransformerEngineBase):
...
@@ -674,8 +682,10 @@ class LayerNormDenseGeneral(TransformerEngineBase):
input_dtype
=
inputs
.
dtype
input_dtype
=
inputs
.
dtype
ln_output
=
None
ln_output
=
None
quantizer_set
=
self
.
generate_quantizer_set
()
fuse_layernorm
=
(
fuse_layernorm
=
(
FP8Helper
.
is_fp8_enabled
()
QuantizeConfig
.
is_fp8_enabled
()
and
not
self
.
return_layernorm_output
and
not
self
.
return_layernorm_output
and
self
.
enable_layernorm
and
self
.
enable_layernorm
)
)
...
@@ -702,7 +712,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
...
@@ -702,7 +712,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
inputs
,
inputs
,
scale
,
scale
,
ln_bias
,
ln_bias
,
layer
norm_type
=
self
.
layernorm_type
,
norm_type
=
self
.
layernorm_type
,
zero_centered_gamma
=
self
.
zero_centered_gamma
,
zero_centered_gamma
=
self
.
zero_centered_gamma
,
epsilon
=
self
.
epsilon
,
epsilon
=
self
.
epsilon
,
)
)
...
@@ -722,37 +732,35 @@ class LayerNormDenseGeneral(TransformerEngineBase):
...
@@ -722,37 +732,35 @@ class LayerNormDenseGeneral(TransformerEngineBase):
axis
=
_normalize_axes
(
axis
,
y
.
ndim
)
axis
=
_normalize_axes
(
axis
,
y
.
ndim
)
kernel_shape
=
tuple
(
y
.
shape
[
ax
]
for
ax
in
axis
)
+
features
kernel_shape
=
tuple
(
y
.
shape
[
ax
]
for
ax
in
axis
)
+
features
kernel_param_shape
=
(
np
.
prod
([
inputs
.
shape
[
ax
]
for
ax
in
axis
]),)
+
features
kernel
=
nn_partitioning
.
param_with_axes
(
kernel
=
nn_partitioning
.
param_with_axes
(
"kernel"
,
self
.
kernel_init
,
kernel_shape
,
self
.
dtype
,
axes
=
self
.
kernel_axes
"kernel"
,
self
.
kernel_init
,
kernel_shape
,
self
.
dtype
,
axes
=
self
.
kernel_axes
)
)
if
not
FP8Helper
.
is_fp8_enabled
():
if
not
QuantizeConfig
.
is_fp8_enabled
():
kernel
=
kernel
.
astype
(
input_dtype
)
kernel
=
kernel
.
astype
(
input_dtype
)
kernel_compute_shape
=
(
reduce
(
operator
.
mul
,
[
inputs
.
shape
[
ax
]
for
ax
in
axis
],
1
),
reduce
(
operator
.
mul
,
features
,
1
),
)
kernel
=
jnp
.
reshape
(
kernel
,
kernel_compute_shape
)
contract_ind
=
tuple
(
range
(
0
,
len
(
axis
)))
contract_ind
=
tuple
(
range
(
0
,
len
(
axis
)))
fp8_meta_pkg
=
None
if
FP8Helper
.
is_fp8_enabled
():
fp8_meta_pkg
=
TransformerEngineBase
.
generate_fp8_meta_set
(
"0"
)
if
fuse_layernorm
:
if
fuse_layernorm
:
z
=
layernorm_
fp8_dot
(
z
=
layernorm_
dense
(
y
,
y
,
kernel
,
kernel
,
scale
,
scale
,
ln_bias
,
ln_bias
,
fp8_meta_pkg
,
norm_type
=
self
.
layernorm_type
,
self
.
layernorm_type
,
zero_centered_gamma
=
self
.
zero_centered_gamma
,
zero_centered_gamma
=
self
.
zero_centered_gamma
,
epsilon
=
self
.
epsilon
,
epsilon
=
self
.
epsilon
,
layernorm_input_axes
=
self
.
layernorm_input_axes
,
layernorm_input_axes
=
self
.
layernorm_input_axes
,
dot_input_axes
=
self
.
dot_input_axes
,
dot_input_axes
=
self
.
dot_input_axes
,
quantizer_set
=
quantizer_set
,
)
)
else
:
else
:
y
=
with_sharding_constraint_by_logical_axes
(
y
,
self
.
dot_input_axes
)
y
=
with_sharding_constraint_by_logical_axes
(
y
,
self
.
dot_input_axes
)
z
=
type_safe_dot_general
(
z
=
dense
(
y
,
kernel
,
contracting_dims
=
(
axis
,
contract_ind
),
quantizer_set
=
quantizer_set
)
y
,
kernel
,
fp8_meta_pkg
=
fp8_meta_pkg
,
contracting_dims
=
(
axis
,
contract_ind
)
)
if
self
.
enable_low_rank_adaptation
:
if
self
.
enable_low_rank_adaptation
:
lora_a_kernel_shape
=
(
lora_a_kernel_shape
=
(
...
@@ -761,7 +769,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
...
@@ -761,7 +769,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
self
.
low_rank_adaptation_dim
,
self
.
low_rank_adaptation_dim
,
)
)
lora_a_kernel_init_shape
=
(
lora_a_kernel_init_shape
=
(
kernel_
param
_shape
[
0
],
kernel_
compute
_shape
[
0
],
*
features
[:
-
1
],
*
features
[:
-
1
],
self
.
low_rank_adaptation_dim
,
self
.
low_rank_adaptation_dim
,
)
)
...
@@ -796,7 +804,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
...
@@ -796,7 +804,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
bias
=
nn_partitioning
.
param_with_axes
(
bias
=
nn_partitioning
.
param_with_axes
(
"bias"
,
self
.
bias_init
,
features
,
self
.
dtype
,
axes
=
self
.
bias_axes
"bias"
,
self
.
bias_init
,
features
,
self
.
dtype
,
axes
=
self
.
bias_axes
)
)
bias
=
bias
.
astype
(
input_dtype
)
bias
=
bias
.
reshape
(
kernel_compute_shape
[
-
1
]).
astype
(
input_dtype
)
if
bias
is
not
None
:
if
bias
is
not
None
:
bias_shape
=
(
1
,)
*
(
z
.
ndim
-
bias
.
ndim
)
+
bias
.
shape
bias_shape
=
(
1
,)
*
(
z
.
ndim
-
bias
.
ndim
)
+
bias
.
shape
...
@@ -805,21 +813,22 @@ class LayerNormDenseGeneral(TransformerEngineBase):
...
@@ -805,21 +813,22 @@ class LayerNormDenseGeneral(TransformerEngineBase):
if
self
.
depth_scaling
is
not
None
:
if
self
.
depth_scaling
is
not
None
:
z
=
z
/
self
.
depth_scaling
z
=
z
/
self
.
depth_scaling
assert
z
.
dtype
==
input_dtype
assert
z
.
dtype
==
input_dtype
,
f
"output_dtype=
{
z
.
dtype
}
, input_dtype=
{
input_dtype
}
"
z
=
z
.
reshape
(
*
inputs
.
shape
[:
self
.
axis
],
*
features
)
return
z
,
ln_output
# dense_output, layer_norm_output
return
z
,
ln_output
# dense_output, layer_norm_output
class
LayerNormMLP
(
TransformerEngineBase
):
class
LayerNormMLP
(
TransformerEngineBase
):
r
"""
r
"""
Applies layer normalization on the input followed by the MLP module,
Applies layer normalization on the input followed by the MLP module,
consisting of 2 successive
linea
r transformations, separated by given activations.
consisting of 2 successive
dense laye
r transformations, separated by given activations.
Parameters
Parameters
----------
----------
intermediate_dim: int, default = 2048
intermediate_dim: int, default = 2048
Intermediate size to which input samples are projected.
Intermediate size to which input samples are projected.
enable_layernorm: bool, default = True
enable_layernorm: bool, default = True
Indicate whether to enable layer normalization before
linea
r transformation.
Indicate whether to enable layer normalization before
dense laye
r transformation.
layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
layernorm_type : {'layernorm', 'rmsnorm'}, default = 'layernorm'
Indicate the type of layer normalization.
Indicate the type of layer normalization.
epsilon : float, default = 1e-6
epsilon : float, default = 1e-6
...
@@ -851,14 +860,14 @@ class LayerNormMLP(TransformerEngineBase):
...
@@ -851,14 +860,14 @@ class LayerNormMLP(TransformerEngineBase):
Only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
Only used when :attr:`enable_layernorm=True` and :attr:`layernorm_type='layernorm'`.
kernel_init : Initializer, default =
kernel_init : Initializer, default =
flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal')
Used for initializing the weights of both
linea
r transformations.
Used for initializing the weights of both
dense laye
r transformations.
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
kernel_axes_1 : Tuple[str, ...], default = ('embed', 'act', 'mlp')
kernel_axes_1 : Tuple[str, ...], default = ('embed', 'act', 'mlp')
The name of axes used to shard the weights with a corresponding mesh for
The name of axes used to shard the weights with a corresponding mesh for
the weight of the first
linea
r transformation
s
.
the weight of the first
dense laye
r transformation.
kernel_axes_2 : Tuple[str, ...], default = ('mlp', 'embed')
kernel_axes_2 : Tuple[str, ...], default = ('mlp', 'embed')
The name of axes used to shard the weights with a corresponding mesh for
The name of axes used to shard the weights with a corresponding mesh for
the weight of the second
linea
r transformation
s
.
the weight of the second
dense laye
r transformation.
use_bias: bool, default = False
use_bias: bool, default = False
Indicate whether to enable bias shifting.
Indicate whether to enable bias shifting.
If set to False, the layer will not learn an additive bias.
If set to False, the layer will not learn an additive bias.
...
@@ -867,17 +876,17 @@ class LayerNormMLP(TransformerEngineBase):
...
@@ -867,17 +876,17 @@ class LayerNormMLP(TransformerEngineBase):
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype).
bias_axes_1: Tuple[str, ...], default = ('mlp',)
bias_axes_1: Tuple[str, ...], default = ('mlp',)
The name of axes used to shard bias with a corresponding mesh for
The name of axes used to shard bias with a corresponding mesh for
the weight of the first
linea
r transformation
s
.
the weight of the first
dense laye
r transformation.
Only used when :attr:`use_bias=True`.
Only used when :attr:`use_bias=True`.
bias_axes_2: Tuple[str, ...], default = ('embed',)
bias_axes_2: Tuple[str, ...], default = ('embed',)
The name of axes used to shard bias with a corresponding mesh for
The name of axes used to shard bias with a corresponding mesh for
the weight of the second
linea
r transformation
s
.
the weight of the second
dense laye
r transformation.
Only used when :attr:`use_bias=True`.
Only used when :attr:`use_bias=True`.
return_layernorm_output: bool, default = True
return_layernorm_output: bool, default = True
Indicate whether to return the output of layer normalization.
Indicate whether to return the output of layer normalization.
If set False, return None as the second tensor in outputs.
If set False, return None as the second tensor in outputs.
activations: Sequence[Union[str, Callable]], default = ('relu',)
activations: Sequence[Union[str, Callable]], default = ('relu',)
The sequence of activation functions to apply after the first
linea
r transformation.
The sequence of activation functions to apply after the first
dense laye
r transformation.
Each activation has its own transformation layer.
Each activation has its own transformation layer.
intermediate_dropout_rng_name: str, default = 'dropout'
intermediate_dropout_rng_name: str, default = 'dropout'
The key in given RNGs via flax.linen.Module.apply that for generating Dropout masks.
The key in given RNGs via flax.linen.Module.apply that for generating Dropout masks.
...
@@ -886,7 +895,7 @@ class LayerNormMLP(TransformerEngineBase):
...
@@ -886,7 +895,7 @@ class LayerNormMLP(TransformerEngineBase):
intermediate_hidden_dropout_dims: Sequence[int], default = ()
intermediate_hidden_dropout_dims: Sequence[int], default = ()
Dimensions that will share the same dropout mask for hidden
Dimensions that will share the same dropout mask for hidden
enable_low_rank_adaptation: bool, default = False
enable_low_rank_adaptation: bool, default = False
Indicate whether to enable low rank adaptation for each
linear
layer.
Indicate whether to enable low rank adaptation for each
dense
layer.
low_rank_adaptation_dim: int, default = 32
low_rank_adaptation_dim: int, default = 32
The dimension for low rank adaptation, only used when
The dimension for low rank adaptation, only used when
:attr:`enable_low_rank_adaptation=True`.
:attr:`enable_low_rank_adaptation=True`.
...
@@ -980,12 +989,16 @@ class LayerNormMLP(TransformerEngineBase):
...
@@ -980,12 +989,16 @@ class LayerNormMLP(TransformerEngineBase):
The output tensors of layer normalization.
The output tensors of layer normalization.
If :attr:`return_layernorm_output=False`, then this would be None.
If :attr:`return_layernorm_output=False`, then this would be None.
"""
"""
ffn1_quantizer_set
=
self
.
generate_quantizer_set
(
"_0"
)
ffn2_quantizer_set
=
self
.
generate_quantizer_set
(
"_1"
)
input_dtype
=
inputs
.
dtype
input_dtype
=
inputs
.
dtype
ln_output
=
None
ln_output
=
None
# TODO(Phuong): use fuse_layernorm for high-precision
# when NoOpQuantizer and Tensor are implemented
fuse_layernorm
=
(
fuse_layernorm
=
(
FP8Helper
.
is_fp8_enabled
()
QuantizeConfig
.
is_fp8_enabled
()
and
not
self
.
return_layernorm_output
and
not
self
.
return_layernorm_output
and
self
.
enable_layernorm
and
self
.
enable_layernorm
)
)
...
@@ -1012,7 +1025,6 @@ class LayerNormMLP(TransformerEngineBase):
...
@@ -1012,7 +1025,6 @@ class LayerNormMLP(TransformerEngineBase):
use_fused_layernorm_mlp
=
(
use_fused_layernorm_mlp
=
(
fuse_layernorm
and
is_act_implemented
and
self
.
intermediate_dropout_rate
<
1e-3
fuse_layernorm
and
is_act_implemented
and
self
.
intermediate_dropout_rate
<
1e-3
)
)
# LayerNorm
# LayerNorm
if
self
.
enable_layernorm
:
if
self
.
enable_layernorm
:
assert
self
.
axis
==
-
1
# Only support axis == -1 at this moment
assert
self
.
axis
==
-
1
# Only support axis == -1 at this moment
...
@@ -1036,7 +1048,7 @@ class LayerNormMLP(TransformerEngineBase):
...
@@ -1036,7 +1048,7 @@ class LayerNormMLP(TransformerEngineBase):
inputs
,
inputs
,
scale
,
scale
,
ln_bias
,
ln_bias
,
layer
norm_type
=
self
.
layernorm_type
,
norm_type
=
self
.
layernorm_type
,
zero_centered_gamma
=
self
.
zero_centered_gamma
,
zero_centered_gamma
=
self
.
zero_centered_gamma
,
epsilon
=
self
.
epsilon
,
epsilon
=
self
.
epsilon
,
)
)
...
@@ -1056,18 +1068,9 @@ class LayerNormMLP(TransformerEngineBase):
...
@@ -1056,18 +1068,9 @@ class LayerNormMLP(TransformerEngineBase):
kernels
.
append
(
self
.
kernel_init
(
init_key
,
*
init_args
))
kernels
.
append
(
self
.
kernel_init
(
init_key
,
*
init_args
))
return
jnp
.
stack
(
kernels
,
axis
=
stack_axis
,
dtype
=
self
.
dtype
)
return
jnp
.
stack
(
kernels
,
axis
=
stack_axis
,
dtype
=
self
.
dtype
)
wi_fp8_meta_pkg
=
None
wo_fp8_meta_pkg
=
None
if
FP8Helper
.
is_fp8_enabled
():
wi_fp8_meta_pkg
=
TransformerEngineBase
.
generate_fp8_meta_set
(
"0"
)
wo_fp8_meta_pkg
=
TransformerEngineBase
.
generate_fp8_meta_set
(
"1"
)
num_activations
=
len
(
normalized_acts
)
num_activations
=
len
(
normalized_acts
)
axis
=
_canonicalize_tuple
(
self
.
axis
)
axis
=
_canonicalize_tuple
(
self
.
axis
)
axis
=
_normalize_axes
(
axis
,
y
.
ndim
)
axis
=
_normalize_axes
(
axis
,
y
.
ndim
)
intermediate_dim
=
_canonicalize_tuple
((
num_activations
,
self
.
intermediate_dim
))
kernel_1_shape
=
tuple
(
y
.
shape
[
ax
]
for
ax
in
axis
)
+
intermediate_dim
kernel_1_each_shape
=
(
np
.
prod
([
y
.
shape
[
ax
]
for
ax
in
axis
]),
self
.
intermediate_dim
)
kernel_1_each_shape
=
(
np
.
prod
([
y
.
shape
[
ax
]
for
ax
in
axis
]),
self
.
intermediate_dim
)
kernel_1
=
nn_partitioning
.
param_with_axes
(
kernel_1
=
nn_partitioning
.
param_with_axes
(
"wi_kernel"
,
"wi_kernel"
,
...
@@ -1078,98 +1081,109 @@ class LayerNormMLP(TransformerEngineBase):
...
@@ -1078,98 +1081,109 @@ class LayerNormMLP(TransformerEngineBase):
self
.
dtype
,
self
.
dtype
,
axes
=
self
.
kernel_axes_1
,
axes
=
self
.
kernel_axes_1
,
)
)
kernel_1
=
jnp
.
reshape
(
kernel_1
,
kernel_1_shape
)
kernel_1_compute_shape
=
(
if
not
FP8Helper
.
is_fp8_enabled
():
reduce
(
operator
.
mul
,
[
y
.
shape
[
ax
]
for
ax
in
axis
],
1
),
num_activations
*
self
.
intermediate_dim
,
)
kernel_1
=
jnp
.
reshape
(
kernel_1
,
kernel_1_compute_shape
)
if
not
QuantizeConfig
.
is_fp8_enabled
():
kernel_1
=
kernel_1
.
astype
(
input_dtype
)
kernel_1
=
kernel_1
.
astype
(
input_dtype
)
hidden_size
=
inputs
.
shape
[
-
1
]
hidden_size
=
inputs
.
shape
[
-
1
]
hidden_size_tuple
=
_canonicalize_tuple
(
hidden_size
)
hidden_size_tuple
=
_canonicalize_tuple
(
hidden_size
)
kernel_2_shape
=
(
self
.
intermediate_dim
,)
+
hidden_size_tuple
kernel_2_shape
=
(
self
.
intermediate_dim
,)
+
hidden_size_tuple
kernel_2_param_shape
=
(
self
.
intermediate_dim
,
np
.
prod
(
hidden_size_tuple
))
kernel_2
=
nn_partitioning
.
param_with_axes
(
kernel_2
=
nn_partitioning
.
param_with_axes
(
"wo_kernel"
,
"wo_kernel"
,
self
.
kernel_init
,
self
.
kernel_init
,
kernel_2_
param_
shape
,
kernel_2_shape
,
self
.
dtype
,
self
.
dtype
,
axes
=
self
.
kernel_axes_2
,
axes
=
self
.
kernel_axes_2
,
)
)
kernel_2
=
jnp
.
reshape
(
kernel_2
,
kernel_2_shape
)
kernel_2_compute_shape
=
(
if
not
FP8Helper
.
is_fp8_enabled
():
self
.
intermediate_dim
,
reduce
(
operator
.
mul
,
hidden_size_tuple
,
1
),
)
kernel_2
=
jnp
.
reshape
(
kernel_2
,
kernel_2_compute_shape
)
if
not
QuantizeConfig
.
is_fp8_enabled
():
kernel_2
=
kernel_2
.
astype
(
input_dtype
)
kernel_2
=
kernel_2
.
astype
(
input_dtype
)
contract_ind
=
tuple
(
range
(
0
,
len
(
axis
)))
contract_ind
=
tuple
(
range
(
0
,
len
(
axis
)))
if
self
.
use_bias
:
bias_1_shape
=
num_activations
*
self
.
intermediate_dim
bias_1
=
nn_partitioning
.
param_with_axes
(
"wi_bias"
,
self
.
bias_init
,
bias_1_shape
,
self
.
dtype
,
axes
=
self
.
bias_axes_1
,
)
bias_1
=
bias_1
.
reshape
(
kernel_1_compute_shape
[
-
1
]).
astype
(
input_dtype
)
bias_2_shape
=
(
hidden_size
,)
bias_2
=
nn_partitioning
.
param_with_axes
(
"wo_bias"
,
self
.
bias_init
,
bias_2_shape
,
self
.
dtype
,
axes
=
self
.
bias_axes_2
,
)
bias_2
=
bias_2
.
reshape
(
kernel_2_compute_shape
[
-
1
]).
astype
(
input_dtype
)
else
:
bias_1
=
None
bias_2
=
None
ffn1_ckpt_name
=
"ffn1"
ffn1_ckpt_name
=
"ffn1"
ffn2_ckpt_name
=
"ffn2"
ffn2_ckpt_name
=
"ffn2"
if
use_fused_layernorm_mlp
:
if
use_fused_layernorm_mlp
:
assert
self
.
axis
==
-
1
# Only support axis = =-1 at this moment
assert
self
.
axis
==
-
1
# Only support axis = =-1 at this moment
if
self
.
use_bias
:
out
=
layernorm_mlp
(
bias_1_shape
=
intermediate_dim
bias_1
=
nn_partitioning
.
param_with_axes
(
"wi_bias"
,
self
.
bias_init
,
bias_1_shape
,
self
.
dtype
,
axes
=
self
.
bias_axes_1
,
)
bias_1
=
bias_1
.
astype
(
input_dtype
)
bias_2_shape
=
(
hidden_size
,)
bias_2
=
nn_partitioning
.
param_with_axes
(
"wo_bias"
,
self
.
bias_init
,
bias_2_shape
,
self
.
dtype
,
axes
=
self
.
bias_axes_2
,
)
bias_2
=
bias_2
.
astype
(
input_dtype
)
else
:
bias_1
=
None
bias_2
=
None
out
=
fused_layernorm_fp8_mlp
(
y
,
y
,
scale
,
scale
,
ln_bias
,
ln_bias
,
[
kernel_1
,
kernel_2
],
[
kernel_1
,
kernel_2
],
[
bias_1
,
bias_2
],
[
bias_1
,
bias_2
],
[
wi_fp8_meta_pkg
,
wo_fp8_meta_pkg
],
self
.
layernorm_type
,
self
.
layernorm_type
,
zero_centered_gamma
=
self
.
zero_centered_gamma
,
zero_centered_gamma
=
self
.
zero_centered_gamma
,
epsilon
=
self
.
epsilon
,
epsilon
=
self
.
epsilon
,
layer
norm_input_axes
=
self
.
layernorm_input_axes
,
norm_input_axes
=
self
.
layernorm_input_axes
,
dot_1_input_axes
=
self
.
dot_1_input_axes
,
dot_1_input_axes
=
self
.
dot_1_input_axes
,
dot_2_input_axes
=
self
.
dot_2_input_axes
,
dot_2_input_axes
=
self
.
dot_2_input_axes
,
ffn1_ckpt_name
=
ffn1_ckpt_name
,
ffn1_ckpt_name
=
ffn1_ckpt_name
,
ffn2_ckpt_name
=
ffn2_ckpt_name
,
ffn2_ckpt_name
=
ffn2_ckpt_name
,
activation_type
=
normalized_acts
,
activation_type
=
normalized_acts
,
use_bias
=
self
.
use_bias
,
quantizer_sets
=
(
ffn1_quantizer_set
,
ffn2_quantizer_set
)
,
)
)
out
=
out
.
reshape
(
*
inputs
.
shape
[:
self
.
axis
],
*
hidden_size_tuple
)
else
:
# not use_fused_ln_geglu_mlp
else
:
# not use_fused_ln_geglu_mlp
# DenseGeneral 1
# DenseGeneral 1
if
fuse_layernorm
:
if
fuse_layernorm
:
x
=
layernorm_
fp8_dot
(
x
=
layernorm_
dense
(
y
,
y
,
kernel_1
,
kernel_1
,
scale
,
scale
,
ln_bias
,
ln_bias
,
wi_fp8_meta_pkg
,
norm_type
=
self
.
layernorm_type
,
self
.
layernorm_type
,
zero_centered_gamma
=
self
.
zero_centered_gamma
,
zero_centered_gamma
=
self
.
zero_centered_gamma
,
epsilon
=
self
.
epsilon
,
epsilon
=
self
.
epsilon
,
layernorm_input_axes
=
self
.
layernorm_input_axes
,
layernorm_input_axes
=
self
.
layernorm_input_axes
,
dot_input_axes
=
self
.
dot_1_input_axes
,
dot_input_axes
=
self
.
dot_1_input_axes
,
quantizer_set
=
ffn1_quantizer_set
,
)
)
else
:
else
:
y
=
with_sharding_constraint_by_logical_axes
(
y
,
self
.
dot_1_input_axes
)
y
=
with_sharding_constraint_by_logical_axes
(
y
,
self
.
dot_1_input_axes
)
x
=
type_safe_dot_general
(
x
=
dense
(
y
,
kernel_1
,
fp8_meta_pkg
=
wi_fp8_meta_pkg
,
contracting_dims
=
(
axis
,
contract_ind
)
y
,
kernel_1
,
contracting_dims
=
(
axis
,
contract_ind
),
quantizer_set
=
ffn1_quantizer_set
,
)
)
if
self
.
enable_low_rank_adaptation
:
if
self
.
enable_low_rank_adaptation
:
wi_lora_a_kernel_shape
=
(
wi_lora_a_kernel_shape
=
(
*
kernel_1_
shape
[:
len
(
axis
)
],
kernel_1_
compute_shape
[
0
],
num_activations
,
num_activations
,
self
.
low_rank_adaptation_dim
,
self
.
low_rank_adaptation_dim
,
)
)
...
@@ -1187,7 +1201,7 @@ class LayerNormMLP(TransformerEngineBase):
...
@@ -1187,7 +1201,7 @@ class LayerNormMLP(TransformerEngineBase):
"wi_lora_a_kernel"
,
"wi_lora_a_kernel"
,
kernel_1_init
,
kernel_1_init
,
num_activations
,
num_activations
,
-
2
,
-
1
,
wi_lora_a_kernel_init_each_shape
,
wi_lora_a_kernel_init_each_shape
,
self
.
dtype
,
self
.
dtype
,
axes
=
wi_lora_a_kernel_axes
,
axes
=
wi_lora_a_kernel_axes
,
...
@@ -1213,37 +1227,25 @@ class LayerNormMLP(TransformerEngineBase):
...
@@ -1213,37 +1227,25 @@ class LayerNormMLP(TransformerEngineBase):
x
+=
_apply_low_rank_adaptation
(
x
+=
_apply_low_rank_adaptation
(
y
,
y
,
axis
,
axis
,
intermediate_dim
,
num_activations
*
self
.
intermediate_dim
,
wi_lora_a_kernel
,
wi_lora_a_kernel
,
wi_lora_b_kernel
,
wi_lora_b_kernel
,
self
.
low_rank_adaptation_alpha
,
self
.
low_rank_adaptation_alpha
,
)
)
bias_1
=
None
if
self
.
use_bias
:
if
self
.
use_bias
:
bias_1
=
nn_partitioning
.
param_with_axes
(
"wi_bias"
,
self
.
bias_init
,
intermediate_dim
,
self
.
dtype
,
axes
=
self
.
bias_axes_1
,
)
bias_1_shape
=
(
1
,)
*
(
x
.
ndim
-
bias_1
.
ndim
)
+
bias_1
.
shape
bias_1
=
bias_1
.
astype
(
input_dtype
)
x
+=
jnp
.
reshape
(
bias_1
,
bias_1_shape
)
x
+=
jnp
.
reshape
(
bias_1
,
bias_1_shape
)
x
=
checkpoint_name
(
x
,
ffn1_ckpt_name
)
x
=
checkpoint_name
(
x
,
ffn1_ckpt_name
)
if
is_act_implemented
:
if
is_act_implemented
:
z
=
activation
_lu
(
x
,
normalized_acts
)
z
=
activation
(
x
,
normalized_acts
)
else
:
else
:
activations
=
[]
activations
=
[]
x
=
jnp
.
split
(
x
,
num_activations
,
axis
=-
2
)
x
=
jnp
.
split
(
x
,
num_activations
,
axis
=-
1
)
for
idx
,
act_fn
in
enumerate
(
normalized_acts
):
for
idx
,
act_fn
in
enumerate
(
normalized_acts
):
x_i
=
_convert_to_activation_function
(
act_fn
)(
x
[
idx
])
x_i
=
_convert_to_activation_function
(
act_fn
)(
x
[
idx
])
activations
.
append
(
x_i
)
activations
.
append
(
x_i
)
z
=
functools
.
reduce
(
operator
.
mul
,
activations
)
z
=
reduce
(
operator
.
mul
,
activations
)
# Remove act axis
z
=
jnp
.
reshape
(
z
,
(
*
z
.
shape
[:
-
2
],
-
1
))
z
=
z
.
astype
(
input_dtype
)
z
=
z
.
astype
(
input_dtype
)
z
=
nn
.
Dropout
(
z
=
nn
.
Dropout
(
...
@@ -1256,8 +1258,8 @@ class LayerNormMLP(TransformerEngineBase):
...
@@ -1256,8 +1258,8 @@ class LayerNormMLP(TransformerEngineBase):
z
=
z
.
astype
(
input_dtype
)
z
=
z
.
astype
(
input_dtype
)
# DenseGeneral 2
# DenseGeneral 2
out
=
type_safe_dot_general
(
out
=
dense
(
z
,
kernel_2
,
fp8_meta_pkg
=
wo_fp8_meta_pkg
,
contracting_dims
=
(
axis
,
contract_ind
)
z
,
kernel_2
,
contracting_dims
=
(
axis
,
contract_ind
)
,
quantizer_set
=
ffn2_quantizer_set
)
)
if
self
.
enable_low_rank_adaptation
:
if
self
.
enable_low_rank_adaptation
:
...
@@ -1292,16 +1294,7 @@ class LayerNormMLP(TransformerEngineBase):
...
@@ -1292,16 +1294,7 @@ class LayerNormMLP(TransformerEngineBase):
self
.
low_rank_adaptation_alpha
,
self
.
low_rank_adaptation_alpha
,
)
)
bias_2
=
None
if
self
.
use_bias
:
if
self
.
use_bias
:
bias_2
=
nn_partitioning
.
param_with_axes
(
"wo_bias"
,
self
.
bias_init
,
(
hidden_size
,),
self
.
dtype
,
axes
=
self
.
bias_axes_2
,
)
bias_2
=
bias_2
.
astype
(
input_dtype
)
out
+=
jnp
.
reshape
(
bias_2
,
(
1
,)
*
(
out
.
ndim
-
1
)
+
(
-
1
,))
out
+=
jnp
.
reshape
(
bias_2
,
(
1
,)
*
(
out
.
ndim
-
1
)
+
(
-
1
,))
out
=
checkpoint_name
(
out
,
ffn2_ckpt_name
)
out
=
checkpoint_name
(
out
,
ffn2_ckpt_name
)
...
...
transformer_engine/jax/flax/transformer.py
View file @
a207db1d
...
@@ -638,7 +638,9 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
...
@@ -638,7 +638,9 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods
else
:
else
:
assert
qkv_layout
.
is_separate
()
assert
qkv_layout
.
is_separate
()
assert
sequence_descriptor
is
None
or
isinstance
(
sequence_descriptor
,
jnp
.
ndarray
)
assert
sequence_descriptor
is
None
or
isinstance
(
sequence_descriptor
,
(
jnp
.
ndarray
,
np
.
ndarray
)
)
x
=
_UnfusedDotProductAttention
(
x
=
_UnfusedDotProductAttention
(
attention_dropout
=
self
.
attention_dropout
,
attention_dropout
=
self
.
attention_dropout
,
...
@@ -928,7 +930,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
...
@@ -928,7 +930,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
Optimization parameters
Optimization parameters
-----------------------
-----------------------
dtype: jax.numpy.dtype, default = jax.numpy.float32
dtype: jax.numpy.dtype, default = jax.numpy.float32
The data type used
for computation
.
The data type used
to allocate the initial parameters
.
fuse_qkv_params: bool, default = True
fuse_qkv_params: bool, default = True
If set to True, this module exposes a single fused
If set to True, this module exposes a single fused
parameter for query-key-value for self-attention and key-value for
parameter for query-key-value for self-attention and key-value for
...
@@ -1788,6 +1790,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
...
@@ -1788,6 +1790,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
outputs: jax.numpy.ndarray
outputs: jax.numpy.ndarray
Output tensors.
Output tensors.
"""
"""
input_dtype
=
inputs
.
dtype
input_dtype
=
inputs
.
dtype
assert
(
assert
(
self
.
layer_type
in
TransformerLayerType
self
.
layer_type
in
TransformerLayerType
...
...
transformer_engine/jax/fp8.py
deleted
100644 → 0
View file @
fbee8990
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Helper module for fp8 meta management
"""
from
contextlib
import
contextmanager
from
enum
import
Enum
from
functools
import
partial
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Union
import
jax
import
jax.numpy
as
jnp
from
flax.core.frozen_dict
import
FrozenDict
from
flax.linen
import
fp8_ops
from
transformer_engine_jax
import
DType
from
transformer_engine_jax
import
get_cublasLt_version
from
transformer_engine_jax
import
(
get_cuda_version
,
get_device_compute_capability
,
)
from
transformer_engine.common.recipe
import
DelayedScaling
,
Format
from
transformer_engine.jax.sharding
import
global_shard_guard
from
transformer_engine.jax.sharding
import
MeshResource
_is_fp8_available
=
None
_reason_for_no_fp8
=
""
Collection
=
Union
[
Dict
,
FrozenDict
]
def
_check_fp8_support
(
gpu_id
)
->
Tuple
[
bool
,
str
]:
"""Return if fp8 support is available"""
gpu_arch
=
get_device_compute_capability
(
gpu_id
)
if
gpu_arch
>=
90
:
# hopper and above
return
True
,
""
if
gpu_arch
<
89
:
# pre-ada
return
False
,
"Device compute capability 8.9 or higher required for FP8 execution."
if
get_cublasLt_version
()
<
120103
:
return
False
,
"CublasLt version 12.1.3.x or higher required for FP8 execution on Ada."
if
get_cuda_version
()
<
12010
:
return
False
,
"Cuda version 12.1 or higher required for FP8 execution on Ada."
return
True
,
""
def
is_fp8_available
(
gpu_id
=
None
)
->
Tuple
[
bool
,
str
]:
"""Return if fp8 support is available"""
if
gpu_id
is
not
None
:
return
_check_fp8_support
(
gpu_id
)
global
_is_fp8_available
,
_reason_for_no_fp8
if
_is_fp8_available
is
None
:
_is_fp8_available
=
True
# JAX doesn't provide the local GPU id.
for
local_gpu_id
in
range
(
len
(
jax
.
local_devices
())):
ret
,
msg
=
_check_fp8_support
(
local_gpu_id
)
if
ret
is
False
:
_is_fp8_available
=
ret
_reason_for_no_fp8
=
msg
break
return
_is_fp8_available
,
_reason_for_no_fp8
def
_format2dtypes
(
format_
:
Format
):
if
format_
==
Format
.
E4M3
:
return
jnp
.
float8_e4m3fn
,
jnp
.
float8_e4m3fn
if
format_
==
Format
.
E5M2
:
return
jnp
.
float8_e5m2
,
jnp
.
float8_e5m2
if
format_
==
Format
.
HYBRID
:
return
jnp
.
float8_e4m3fn
,
jnp
.
float8_e5m2
return
jnp
.
bfloat16
,
jnp
.
bfloat16
# fm32 is a custom dtype to specify the "add" rules as max operation.
# This is typically used in Pipeline Parallelism + "MiconBatching > 1",
# which is implemented via nn.scan. Without this custom dtype, nn.scan
# would sum gradients from all micro-batches, and this is not the expected
# behavior for FP8 meta. Instead, the summation of FP8 meta gradients should
# be "MAX".
FlaxFloatMeta32
=
fp8_ops
.
fm32
class
FP8MetaPackage
:
"""
A container that contains all required meta data for FP8
"""
NUM_OF_META
:
int
=
3
INPUT_IDX
:
int
=
0
WEIGHT_IDX
:
int
=
1
GRAD_IDX
:
int
=
2
def
__init__
(
self
,
input_amax
:
jnp
.
ndarray
,
input_scale
:
jnp
.
ndarray
,
weight_amax
:
jnp
.
ndarray
,
weight_scale
:
jnp
.
ndarray
,
grad_amax
:
jnp
.
ndarray
,
grad_scale
:
jnp
.
ndarray
,
)
->
None
:
self
.
_amax_list
=
[
None
]
*
FP8MetaPackage
.
NUM_OF_META
self
.
_scale_list
=
[
None
]
*
FP8MetaPackage
.
NUM_OF_META
self
.
_amax_list
[
FP8MetaPackage
.
INPUT_IDX
]
=
input_amax
self
.
_scale_list
[
FP8MetaPackage
.
INPUT_IDX
]
=
input_scale
self
.
_amax_list
[
FP8MetaPackage
.
WEIGHT_IDX
]
=
weight_amax
self
.
_scale_list
[
FP8MetaPackage
.
WEIGHT_IDX
]
=
weight_scale
self
.
_amax_list
[
FP8MetaPackage
.
GRAD_IDX
]
=
grad_amax
self
.
_scale_list
[
FP8MetaPackage
.
GRAD_IDX
]
=
grad_scale
@
property
def
amax_list
(
self
)
->
List
[
jnp
.
ndarray
]:
"""
Get the amax list of this package.
"""
return
self
.
_amax_list
@
property
def
scale_list
(
self
)
->
List
[
jnp
.
ndarray
]:
"""
Get the scale list of this package.
"""
return
self
.
_scale_list
@
staticmethod
def
update_amax_list
(
amax_list
:
List
[
jnp
.
ndarray
])
->
jnp
.
ndarray
:
"""
Update the amax history list
"""
updated_amax_list
=
[
FP8Helper
.
update_amax_history
(
amax
)
for
amax
in
amax_list
]
return
updated_amax_list
@
staticmethod
def
update_fp8_scale
(
amax_list
:
List
[
jnp
.
ndarray
],
scale_list
:
List
[
jnp
.
ndarray
],
fp8_dtype_list
:
List
[
DType
]
)
->
Tuple
[
List
[
jnp
.
ndarray
],
List
[
jnp
.
ndarray
]]:
"""
Get update scale and scale_inv list
"""
update_scale_list
=
[]
update_scale_inv_list
=
[]
for
amax
,
scale
,
fp8_dtype
in
zip
(
amax_list
,
scale_list
,
fp8_dtype_list
):
upadted_scale
,
updated_scale_inv
=
FP8Helper
.
update_fp8_scale
(
amax
,
scale
,
fp8_dtype
)
update_scale_list
.
append
(
upadted_scale
)
update_scale_inv_list
.
append
(
updated_scale_inv
)
return
update_scale_list
,
update_scale_inv_list
class
AmaxComputeAlgo
(
Enum
):
"""AmaxComputeAlgo."""
MAX
=
"max"
MOST_RECENT
=
"most_recent"
NVTE_FP8_COLLECTION_NAME
=
"fp8_metas"
class
FP8Helper
:
"""
FP8 helper to manage the FP8 meta
"""
INITIALIZED
=
False
MARGIN
:
float
=
0.0
FP8_FORMAT
:
Format
=
Format
.
HYBRID
FWD_DTYPE
:
DType
=
_format2dtypes
(
Format
.
HYBRID
)[
0
]
BWD_DTYPE
:
DType
=
_format2dtypes
(
Format
.
HYBRID
)[
1
]
AMAX_HISTORY_LEN
:
int
=
1024
AMAX_COMPUTE_ALGO
:
AmaxComputeAlgo
=
AmaxComputeAlgo
.
MAX
FP8_COLLECTION_NAME
:
str
=
NVTE_FP8_COLLECTION_NAME
FP8_AMAX_NAME
:
str
=
"amax"
FP8_SCALE_NAME
:
str
=
"scale"
FP8_2X_ACC_FPROP
:
bool
=
False
FP8_2X_ACC_DGRAD
:
bool
=
True
FP8_2X_ACC_WGRAD
:
bool
=
True
@
staticmethod
def
is_fp8_enabled
():
"""
Indicate if fp8 training is enable or not.
"""
return
FP8Helper
.
INITIALIZED
@
staticmethod
def
initialize
(
margin
:
float
=
0.0
,
fp8_format
:
Format
=
Format
.
HYBRID
,
amax_history_len
:
int
=
1
,
amax_compute_algo
:
AmaxComputeAlgo
=
AmaxComputeAlgo
.
MAX
,
)
->
None
:
"""
Initialize the FP8 meta
"""
FP8Helper
.
INITIALIZED
=
True
FP8Helper
.
MARGIN
=
margin
FP8Helper
.
FP8_FORMAT
=
fp8_format
FP8Helper
.
FWD_DTYPE
,
FP8Helper
.
BWD_DTYPE
=
_format2dtypes
(
FP8Helper
.
FP8_FORMAT
)
FP8Helper
.
AMAX_HISTORY_LEN
=
amax_history_len
FP8Helper
.
AMAX_COMPUTE_ALGO
=
amax_compute_algo
FP8Helper
.
FP8_2X_ACC_FPROP
=
False
FP8Helper
.
FP8_2X_ACC_DGRAD
=
True
FP8Helper
.
FP8_2X_ACC_WGRAD
=
True
@
staticmethod
def
finalize
()
->
None
:
"""
FP8 helper finalize
"""
FP8Helper
.
INITIALIZED
=
False
FP8Helper
.
MARGIN
=
0.0
FP8Helper
.
FP8_FORMAT
=
Format
.
HYBRID
FP8Helper
.
FWD_DTYPE
,
FP8Helper
.
BWD_DTYPE
=
_format2dtypes
(
FP8Helper
.
FP8_FORMAT
)
FP8Helper
.
AMAX_HISTORY_LEN
=
1024
FP8Helper
.
AMAX_COMPUTE_ALGO
=
AmaxComputeAlgo
.
MAX
@
staticmethod
def
update_collections
(
new
:
Collection
,
original
:
Collection
)
->
Collection
:
"""
Update the collections
"""
assert
isinstance
(
original
,
(
dict
,
FrozenDict
))
assert
isinstance
(
new
,
(
dict
,
FrozenDict
))
frozen_original
=
FrozenDict
(
original
)
if
not
isinstance
(
original
,
FrozenDict
)
else
original
for
key
in
new
:
if
key
in
frozen_original
:
frozen_original
,
_
=
frozen_original
.
pop
(
key
)
new_coll
=
FrozenDict
({
**
new
,
**
frozen_original
})
if
not
isinstance
(
original
,
FrozenDict
):
new_coll
=
new_coll
.
unfreeze
()
return
new_coll
@
staticmethod
def
generate_fp8_meta_dtype_converter_pair
(
*
args
):
"""
Generate a pair of conversion fun in-between fm32 and fp32.
"""
def
identical_fun
(
*
metas
):
return
list
(
metas
)
def
fm32_to_fp32_fun
(
*
metas
):
for
meta
in
metas
:
assert
meta
.
dtype
==
FlaxFloatMeta32
return
[
jax
.
lax
.
convert_element_type
(
meta
,
jnp
.
float32
)
for
meta
in
metas
]
def
fp32_to_fm32_fun
(
*
metas
):
for
meta
in
metas
:
assert
meta
.
dtype
==
jnp
.
float32
return
[
jax
.
lax
.
convert_element_type
(
meta
,
FlaxFloatMeta32
)
for
meta
in
metas
]
# Make functions to be a vaild JAX type
partial_identical_fun
=
jax
.
tree_util
.
Partial
(
identical_fun
)
partial_fm32_to_fp32_fun
=
jax
.
tree_util
.
Partial
(
fm32_to_fp32_fun
)
partial_fp32_to_fm32_fun
=
jax
.
tree_util
.
Partial
(
fp32_to_fm32_fun
)
if
len
(
args
)
<
1
:
return
partial_identical_fun
,
partial_identical_fun
original_dtype
=
args
[
0
].
dtype
for
arg
in
args
:
assert
arg
.
dtype
==
original_dtype
if
original_dtype
==
FlaxFloatMeta32
:
return
partial_fm32_to_fp32_fun
,
partial_fp32_to_fm32_fun
return
partial_identical_fun
,
partial_identical_fun
@
staticmethod
@
jax
.
jit
def
update_amax_history
(
amax
:
jnp
.
ndarray
)
->
jnp
.
ndarray
:
"""
Update the amax history
"""
updated_amax
=
jnp
.
roll
(
amax
,
-
1
,
-
1
)
updated_amax
=
updated_amax
.
at
[
0
].
set
(
0
)
return
updated_amax
@
staticmethod
@
partial
(
jax
.
jit
,
static_argnums
=
(
2
,))
def
update_fp8_scale
(
amax
:
jnp
.
ndarray
,
scale
:
jnp
.
ndarray
,
fp8_dtype
:
DType
)
->
jnp
.
ndarray
:
"""
Calculate fp8 scale and scale_inv based on given amax.
"""
fp8_max
=
jnp
.
astype
(
jnp
.
finfo
(
fp8_dtype
).
max
,
jnp
.
float32
)
if
FP8Helper
.
AMAX_COMPUTE_ALGO
is
AmaxComputeAlgo
.
MAX
:
amax
=
jnp
.
max
(
amax
,
axis
=-
1
,
keepdims
=
True
)
else
:
amax
=
amax
[
0
:
1
]
sf
=
(
fp8_max
/
amax
)
/
(
2
**
FP8Helper
.
MARGIN
)
sf
=
jnp
.
where
(
amax
>
0.0
,
sf
,
scale
)
sf
=
jnp
.
where
(
jnp
.
isfinite
(
amax
),
sf
,
scale
)
scale
=
sf
scale_inv
=
1
/
sf
return
scale
,
scale_inv
@contextmanager
def fp8_autocast(
    enabled: bool = False,
    fp8_recipe: Optional[DelayedScaling] = None,
    mesh_resource: Optional[MeshResource] = None,
) -> None:
    r"""
    Context manager for FP8 usage.

    .. code-block:: python

        mesh_shape = (4, 2)
        dp_mesh_axis_name = 'data_parallel'
        tp_mesh_axis_name = 'tensor_parallel'
        devices = np.asarray(jax.devices()).reshape(*mesh_shape)

        with maps.Mesh(devices, (dp_mesh_axis_name, tp_mesh_axis_name)):
            mesh_resource=MeshResource(dp_mesh_axis_name, tp_mesh_axis_name)

            with fp8_autocast(enabled=True, mesh_resource=mesh_resource):
                rules = extend_logical_axis_rules(tuple())
                transformer = TransformerLayer()

                with partitioning.axis_rules(rules):
                    pjit(transformer.init, ...)(...)

    .. note::
        We only support :attr:`margin`, :attr:`fp8_format`, :attr:`amax_history_len`,
        and :attr:`amax_compute_algo` (with value 'max' and 'most_recent') in
        recipe.DelayedScaling currently. Other parameters in recipe.DelayedScaling
        will trigger an assertion.

    Parameters
    ----------
    enabled: bool, default = False
        Whether or not to enable fp8
    fp8_recipe: recipe.DelayedScaling, default = None
        Recipe used for FP8 training.
    mesh_resource: MeshResource, default = None
        Specify the mesh axes for data and tensor parallelism to shard along.
        If set to None, then no data or tensor parallelism will be used.
    """
    if fp8_recipe is None:
        fp8_recipe = DelayedScaling()

    # Only a subset of DelayedScaling options is supported on the JAX path;
    # reject configurations that would otherwise be silently ignored.
    assert fp8_recipe.amax_compute_algo in [
        "max",
        "most_recent",
    ], "DelayedScaling amax_compute_algo only supports max and most_recent with TE/JAX."
    assert (
        fp8_recipe.scaling_factor_compute_algo is None
    ), "DelayedScaling scaling_factor_compute_algo isn't supported by TE/JAX."
    assert fp8_recipe.reduce_amax, "DelayedScaling reduce_amax should be enabled for TE/JAX."

    if mesh_resource is None:
        mesh_resource = MeshResource()

    try:
        # Keep the sharding mesh resource active for the entire context body.
        with global_shard_guard(mesh_resource):
            if enabled:
                fp8_available, reason_for_no_fp8 = is_fp8_available()
                assert fp8_available, reason_for_no_fp8

                # Map the recipe's string algorithm name to the internal enum;
                # MOST_RECENT is the default.
                amax_compute_algo = AmaxComputeAlgo.MOST_RECENT
                if fp8_recipe.amax_compute_algo == "max":
                    amax_compute_algo = AmaxComputeAlgo.MAX

                FP8Helper.initialize(
                    margin=fp8_recipe.margin,
                    fp8_format=fp8_recipe.fp8_format,
                    amax_history_len=fp8_recipe.amax_history_len,
                    amax_compute_algo=amax_compute_algo,
                )
            yield
    finally:
        # Always reset the global FP8 state, even if the body raised.
        FP8Helper.finalize()
# Function Wrappers
def update_collections(new: Collection, original: Collection) -> FrozenDict:
    r"""
    Merge new data into a Flax Collection.

    Collection = [dict, flax.core.frozen_dict.FrozenDict]

    Parameters
    ----------
    new: Collection
        A collection that includes new data.
    original: Collection
        The base collection.

    Returns
    -------
    outputs : Collection
        The updated collection.
    """
    merged = FP8Helper.update_collections(new, original)
    return merged
def get_delayed_scaling():
    r"""
    Obtain an instance of DelayedScaling which is set via fp8_autocast.

    .. note::
        We only store :attr:`margin`, :attr:`fp8_format`, :attr:`amax_history_len`
        , and :attr:`amax_compute_algo` via fp8_autocast. Other parameters in
        recipe.DelayedScaling would be returned as the default values.

    Returns
    -------
    delay_scaling : DelayedScaling
        an instance of DelayedScaling which is set via fp8_autocast.
    """
    # Translate the internal enum back to the recipe's string spelling.
    if FP8Helper.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX:
        algo_name = "max"
    else:
        algo_name = "most_recent"
    return DelayedScaling(
        margin=int(FP8Helper.MARGIN),
        fp8_format=FP8Helper.FP8_FORMAT,
        amax_history_len=FP8Helper.AMAX_HISTORY_LEN,
        amax_compute_algo=algo_name,
    )
transformer_engine/jax/layernorm.py
View file @
a207db1d
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
"""JAX layernorm modules"""
"""Layer normalization operations for Transformer Engine in JAX.
This module provides optimized layer normalization operations for transformer
architectures, including support for different normalization types and quantization.
It implements various normalization strategies like LayerNorm and RMSNorm, with
optional zero-centered gamma and epsilon parameters.
"""
from
functools
import
partial
from
functools
import
partial
from
typing
import
List
,
Tuple
import
jax
import
jax
import
jax.numpy
as
jnp
import
jax.numpy
as
jnp
from
.
import
cpp_extensions
as
tex
from
.
import
cpp_extensions
as
tex
from
.dot
import
fp8_dot_impl
,
get_precision_of_fp8_dot
from
.fp8
import
FP8Helper
,
FP8MetaPackage
from
.sharding
import
with_sharding_constraint_by_logical_axes
from
.quantize
import
(
ScaledTensor
,
Quantizer
,
)
def
canonicalize_layernorm_type
(
x
):
"""
def
canonicalize_norm_type
(
x
):
Canonicalize the layernorm type
"""Convert normalization type string to canonical form.
Args:
x: Input normalization type string
Returns:
Canonicalized normalization type string
"""
"""
canonicalized
=
x
.
lower
().
strip
().
replace
(
"-"
,
""
).
replace
(
"_"
,
""
)
canonicalized
=
x
.
lower
().
strip
().
replace
(
"-"
,
""
).
replace
(
"_"
,
""
)
assert
canonicalized
in
[
"layernorm"
,
"rmsnorm"
]
assert
canonicalized
in
[
"layernorm"
,
"rmsnorm"
]
...
@@ -25,365 +37,106 @@ def canonicalize_layernorm_type(x):
...
@@ -25,365 +37,106 @@ def canonicalize_layernorm_type(x):
def layernorm(
    x: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    norm_type: str,
    zero_centered_gamma: bool = False,
    epsilon: float = 1e-6,
    quantizer: Quantizer = None,
):
    """Apply layer normalization with optional quantization.

    This function implements layer normalization with support for different
    normalization types and optional quantization. It normalizes the input
    tensor using the provided gamma and beta parameters.

    Args:
        x: Input tensor to normalize
        gamma: Scale parameter for normalization
        beta: Shift parameter for normalization
        norm_type: Type of normalization to apply
        zero_centered_gamma: Whether to use zero-centered gamma
        epsilon: Small constant for numerical stability
        quantizer: Optional quantizer for quantizing the output

    Returns:
        Normalized output tensor
    """
    output = _layernorm(x, gamma, beta, norm_type, zero_centered_gamma, epsilon, quantizer)
    return output


@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5))
def _layernorm(x, gamma, beta, norm_type: str, zero_centered_gamma, epsilon, quantizer):
    """Internal implementation of layer normalization with custom VJP.

    This function implements the core layer normalization logic with support
    for custom vector-Jacobian product (VJP) for automatic differentiation.

    Args:
        x: Input tensor
        gamma: Scale parameter
        beta: Shift parameter
        norm_type: Type of normalization
        zero_centered_gamma: Whether to use zero-centered gamma
        epsilon: Small constant for numerical stability
        quantizer: Optional quantizer

    Returns:
        Normalized tensor
    """
    output, _ = _layernorm_fwd_rule(
        x, gamma, beta, norm_type, zero_centered_gamma, epsilon, quantizer
    )
    return output


def _layernorm_fwd_rule(x, gamma, beta, norm_type: str, zero_centered_gamma, epsilon, quantizer):
    """Forward pass rule for layer normalization.

    Args:
        x: Input tensor
        gamma: Scale parameter
        beta: Shift parameter
        norm_type: Type of normalization
        zero_centered_gamma: Whether to use zero-centered gamma
        epsilon: Small constant for numerical stability
        quantizer: Optional quantizer

    Returns:
        Tuple of (output, context) for backward pass
    """
    norm_type = canonicalize_norm_type(norm_type)
    output, mu, rsigma = tex.normalization_fwd(
        x, gamma, beta, zero_centered_gamma, epsilon, norm_type, quantizer
    )
    # A quantized result is returned in higher precision to the caller.
    if isinstance(output, ScaledTensor):
        output = output.dequantize()
    return output, (x, mu, rsigma, gamma, beta, quantizer)


def _layernorm_bwd_rule(norm_type, zero_centered_gamma, epsilon, ctx, dz):
    """Backward pass rule for layer normalization.

    Args:
        norm_type: Type of normalization
        zero_centered_gamma: Whether to use zero-centered gamma
        epsilon: Small constant for numerical stability
        ctx: Context from forward pass
        dz: Gradient from upstream

    Returns:
        Tuple of gradients with respect to inputs
    """
    x, mu, rsigma, gamma, beta, quantizer = ctx
    dx, dgamma, dbeta = tex.normalization_bwd(
        dz, x, mu, rsigma, gamma, beta, zero_centered_gamma, epsilon, norm_type
    )
    # quantizer is a differentiable positional argument, so a placeholder
    # gradient slot is returned for it as well.
    return dx, dgamma, dbeta, quantizer


_layernorm.defvjp(_layernorm_fwd_rule, _layernorm_bwd_rule)
transformer_engine/jax/layernorm_dense.py
0 → 100644
View file @
a207db1d
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Fused Layer normalization and dense layer transformation operations for Transformer Engine in JAX.
This module provides optimized implementations of layer normalization followed by
dense layer transformation (GEMM) operations, which are commonly used in transformer
architectures. It supports various normalization types, quantization, and
distributed training through sharding constraints.
"""
from
functools
import
partial
from
typing
import
Tuple
import
jax
import
jax.numpy
as
jnp
from
.
import
cpp_extensions
as
tex
from
.quantize
import
(
QuantizerSet
,
noop_quantizer_set
,
with_sharding_constraint_by_logical_axes
,
)
def layernorm_dense(
    x: jnp.ndarray,
    kernel: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    bias: jnp.ndarray = None,
    norm_type: str = "layernorm",
    zero_centered_gamma: bool = False,
    epsilon: float = 1e-6,
    # The logic axes of sharding constraint to the layernorm input.
    layernorm_input_axes: Tuple[str, ...] = None,
    # The logic axes of sharding constraint to the dot input.
    dot_input_axes: Tuple[str, ...] = None,
    quantizer_set: QuantizerSet = noop_quantizer_set,
) -> jnp.ndarray:
    """Normalize ``x`` and feed the result through a dense (GEMM) layer.

    Performs, in order:
    1. Layer normalization: (x - mean) / sqrt(var + epsilon) * gamma + beta
    2. Linear transformation: y = x * kernel + bias

    Args:
        x: Input tensor with shape [batch..., hidden_in]
        kernel: Weight matrix with shape [hidden_in, hidden_out]
        gamma: Scale parameter for normalization with shape [hidden_in]
        beta: Bias parameter for normalization with shape [hidden_in]
        bias: Optional bias for the dense transformation with shape [hidden_out]
        norm_type: Type of normalization ("layernorm" or "rmsnorm")
        zero_centered_gamma: Whether to use zero-centered gamma for normalization
        epsilon: Small constant for numerical stability in normalization
        layernorm_input_axes: Logical axes for sharding the layernorm input
        dot_input_axes: Logical axes for sharding the matrix multiplication input
        quantizer_set: Set of quantizers for different tensor types

    Returns:
        Output tensor with shape [batch..., hidden_out]

    Note:
        - For RMSNorm (norm_type="rmsnorm"), beta must be None and
          zero_centered_gamma must be False
        - Differentiation goes through JAX's custom VJP
        - Quantization is applied to both the normalized input and kernel
    """
    return _layernorm_dense(
        x,
        kernel,
        gamma,
        beta,
        bias,
        norm_type,
        zero_centered_gamma,
        epsilon,
        layernorm_input_axes,
        dot_input_axes,
        quantizer_set,
    )
@partial(
    jax.custom_vjp,
    nondiff_argnums=(
        5,
        6,
        7,
        8,
        9,
    ),
)
def _layernorm_dense(
    x: jnp.ndarray,
    kernel: jnp.ndarray,
    gamma: jnp.ndarray,
    beta: jnp.ndarray,
    bias: jnp.ndarray,
    norm_type: str,
    zero_centered_gamma: bool,
    epsilon: float,
    layernorm_input_axes: Tuple[str, ...],
    dot_input_axes: Tuple[str, ...],
    quantizer_set,
):
    """Differentiable core of ``layernorm_dense`` (custom VJP).

    Runs the forward rule and discards the saved residuals; JAX wires the
    matching backward rule in through ``defvjp``. See the forward/backward
    rules for the actual normalization, quantization, and GEMM logic.
    """
    result, _ctx = _layernorm_dense_fwd_rule(
        x,
        kernel,
        gamma,
        beta,
        bias,
        norm_type,
        zero_centered_gamma,
        epsilon,
        layernorm_input_axes,
        dot_input_axes,
        quantizer_set,
    )
    return result
def _layernorm_dense_fwd_rule(
    x,
    kernel,
    gamma,
    beta,
    bias,
    norm_type,
    zero_centered_gamma,
    epsilon,
    layernorm_input_axes,
    dot_input_axes,
    quantizer_set,
):
    """Forward pass rule for layernorm_dense.

    Implements the forward pass computation including:
    1. Layer normalization with quantization
    2. Matrix multiplication with quantized kernel
    3. Optional bias addition
    4. Sharding constraints

    Returns:
        Tuple of (output, context) for automatic differentiation
    """
    # Contract the last axis of x against the first axis of the kernel.
    x_contracting_dims = (len(x.shape) - 1,)
    k_contracting_dims = (0,)
    assert x.shape[-1] == kernel.shape[0]
    assert len(kernel.shape) == 2  # Otherwise need to merge dims in quantize

    x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes)

    # Fused normalization + quantization; mu/rsigma are kept for backward.
    # (mu is presumably None for rmsnorm — confirm against normalization_fwd.)
    casted_ln_out, mu, rsigma = tex.normalization_fwd(
        x,
        gamma,
        beta,
        zero_centered_gamma,
        epsilon,
        norm_type,
        quantizer_set.x,
    )

    # Kernel in (hidden_in, hidden_out...)
    casted_kernel = tex.quantize(kernel, quantizer_set.kernel)

    casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes)

    # NN GEMM
    # (batch..., hidden_in) x (hidden_in, hidden_out...)
    output = tex.gemm(
        casted_ln_out.get_rowwise_tensor(),
        casted_kernel.get_colwise_tensor(),
        (x_contracting_dims, k_contracting_dims),
    )

    use_bias = bias is not None
    if use_bias:
        # Broadcast the bias across leading batch dimensions.
        bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape
        output += jnp.reshape(bias, bias_new_shape)

    # Residuals for the backward rule. The alternate-layout tensors are only
    # kept when the quantizer produced both layouts (is_2x2x) — the backward
    # GEMMs need them; otherwise store None to save memory.
    ctx = (
        casted_ln_out.get_colwise_tensor() if quantizer_set.x.is_2x2x() else None,
        casted_kernel.get_rowwise_tensor() if quantizer_set.kernel.is_2x2x() else None,
        x.shape,
        kernel.shape,
        mu,
        rsigma,
        x,
        gamma,
        beta,
        x_contracting_dims,
        k_contracting_dims,
        use_bias,
        quantizer_set,
    )

    return output, ctx
def _layernorm_dense_bwd_rule(
    norm_type,
    zero_centered_gamma,
    epsilon,
    layernorm_input_axes,
    dot_input_axes,  # pylint: disable=unused-argument
    ctx,
    grad,
):
    """Backward pass rule for layernorm_dense.

    Implements the backward pass computation including:
    1. Gradient computation for matrix multiplication
    2. Gradient computation for layer normalization
    3. Gradient computation for bias terms
    4. Proper handling of quantization

    Returns:
        Tuple of gradients for all input parameters
    """
    # Unpack residuals saved by the forward rule.
    (
        colwise_casted_ln_out,
        rowwise_casted_kernel,
        x_shape,
        kernel_shape,
        mu,
        rsigma,
        x,
        gamma,
        beta,
        x_contracting_dims_in_fwd,
        k_contracting_dims_in_fwd,
        use_bias,
        quantizer_set,
    ) = ctx

    grad = with_sharding_constraint_by_logical_axes(grad, dot_input_axes)

    # Quantize the incoming gradient; dbias is fused into the same kernel
    # when the forward pass used a bias.
    casted_grad, dbias = tex.quantize_dbias(
        grad, is_dbias=use_bias, quantizer=quantizer_set.dgrad
    )

    # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim
    g_constracting_dim = tuple(
        range(grad.ndim - len(kernel_shape) + len(k_contracting_dims_in_fwd), grad.ndim)
    )
    # k_non_contracting_dims
    # (note: "constracting" is a pre-existing spelling of "contracting")
    k_constracting_dim = tuple(
        dim for dim in range(len(kernel_shape)) if dim not in k_contracting_dims_in_fwd
    )

    # NT GEMM
    dgrad = tex.gemm(
        casted_grad.get_rowwise_tensor(),
        rowwise_casted_kernel,
        (g_constracting_dim, k_constracting_dim),
    )

    dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes)

    # Batch dimensions of x and grad line up, so both sides contract over
    # the same leading range.
    g_constracting_dim = x_constracting_dim = tuple(
        range(0, len(x_shape) - len(x_contracting_dims_in_fwd))
    )

    # TN GEMM
    wgrad = tex.gemm(
        colwise_casted_ln_out,
        casted_grad.get_colwise_tensor(),
        (x_constracting_dim, g_constracting_dim),
    )

    # Propagate dgrad through the normalization to get dx/dgamma/dbeta.
    dx, dgamma, dbeta = tex.normalization_bwd(
        dgrad,
        x,
        mu,
        rsigma,
        gamma,
        beta,
        zero_centered_gamma=zero_centered_gamma,
        epsilon=epsilon,
        norm_type=norm_type,
    )

    return dx, wgrad, dgamma, dbeta, dbias, quantizer_set


_layernorm_dense.defvjp(_layernorm_dense_fwd_rule, _layernorm_dense_bwd_rule)
transformer_engine/jax/layernorm_mlp.py
View file @
a207db1d
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
#
# See LICENSE for license information.
# See LICENSE for license information.
"""JAX MLP modules"""
"""Multi-layer perceptron (MLP) operations with layer normalization for Transformer Engine in JAX.
This module provides optimized implementations of MLP blocks commonly used in transformer
architectures. Each MLP block consists of:
1. Layer normalization
2. First dense layer transformation (GEMM1) with bias and activation
3. Second dense layer transformation (GEMM2) with bias
The implementation supports various normalization types, activation functions,
quantization, and distributed training through sharding constraints.
"""
from
typing
import
List
,
Tuple
,
Sequence
,
Union
,
Callable
from
typing
import
List
,
Tuple
,
Sequence
,
Union
,
Callable
from
functools
import
partial
from
functools
import
partial
...
@@ -11,92 +21,81 @@ import jax.numpy as jnp
...
@@ -11,92 +21,81 @@ import jax.numpy as jnp
from
jax.ad_checkpoint
import
checkpoint_name
from
jax.ad_checkpoint
import
checkpoint_name
from
.
import
cpp_extensions
as
tex
from
.
import
cpp_extensions
as
tex
from
.dot
import
fp8_dot_impl
,
get_precision_of_fp8_dot
,
quantize
from
.layernorm
import
canonicalize_norm_type
from
.layernorm
import
canonicalize_layernorm_type
from
.quantize
import
with_sharding_constraint_by_logical_axes
,
QuantizerSet
,
noop_quantizer_set
from
.fp8
import
FP8Helper
,
FP8MetaPackage
from
.sharding
import
with_sharding_constraint_by_logical_axes
def activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]):
    """Apply the configured activation unit to ``x``.

    When more than one activation is supplied, the input is expected to
    stack two branches along its second-to-last axis (e.g. Linear + GeLU).
    """
    if len(activation_type) > 1:
        # Linear + GeLU
        assert x.shape[-2] == 2
    return _activation_lu(x, activation_type)
@partial(jax.custom_vjp, nondiff_argnums=(1,))
def _activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]):
    """Custom-VJP wrapper around the activation forward rule."""
    result, _ctx = _activation_lu_fwd_rule(x, activation_type)
    return result
def _activation_lu_fwd_rule(x, activation_type):
    """Forward rule: run the activation kernel and save ``x`` for backward."""
    fwd_output = tex.act_lu(x, activation_type)
    return fwd_output, (x,)
def _activation_lu_bwd_rule(activation_type, ctx, g):
    """Backward rule: apply the activation derivative to the upstream gradient.

    Args:
        activation_type: Non-differentiable activation spec from the forward call.
        ctx: Saved residuals — the original input ``x``.
        g: Upstream gradient w.r.t. the activation output.

    Returns:
        One-tuple containing the gradient w.r.t. ``x``.
    """
    (x,) = ctx
    # Guard: the upstream gradient must match the saved input's dtype.
    assert x.dtype == g.dtype
    dx = tex.dact_lu(g, x, activation_type)
    # Restore the (possibly stacked) input shape.
    dx = jnp.reshape(dx, x.shape)
    return (dx,)


_activation_lu.defvjp(_activation_lu_fwd_rule, _activation_lu_bwd_rule)
def
layernorm_mlp
(
def
fused_layernorm_fp8_mlp
(
x
:
jnp
.
ndarray
,
x
:
jnp
.
ndarray
,
gamma
:
jnp
.
ndarray
,
gamma
:
jnp
.
ndarray
,
beta
:
jnp
.
ndarray
,
beta
:
jnp
.
ndarray
,
kernels
:
List
[
jnp
.
ndarray
],
kernels
:
List
[
jnp
.
ndarray
],
biases
:
List
[
jnp
.
ndarray
],
biases
:
List
[
jnp
.
ndarray
],
fp8_meta_pkgs
:
List
[
FP8MetaPackage
],
norm_type
:
str
,
layernorm_type
:
str
,
zero_centered_gamma
:
bool
=
False
,
zero_centered_gamma
:
bool
=
False
,
epsilon
:
float
=
1e-6
,
epsilon
:
float
=
1e-6
,
layer
norm_input_axes
:
Tuple
[
str
,
...]
=
None
,
norm_input_axes
:
Tuple
[
str
,
...]
=
None
,
dot_1_input_axes
:
Tuple
[
str
,
...]
=
None
,
dot_1_input_axes
:
Tuple
[
str
,
...]
=
None
,
dot_2_input_axes
:
Tuple
[
str
,
...]
=
None
,
dot_2_input_axes
:
Tuple
[
str
,
...]
=
None
,
ffn1_ckpt_name
:
str
=
"ffn1"
,
ffn1_ckpt_name
:
str
=
"ffn1"
,
ffn2_ckpt_name
:
str
=
"ffn2"
,
ffn2_ckpt_name
:
str
=
"ffn2"
,
activation_type
:
Sequence
[
Union
[
str
,
Callable
]]
=
(
"gelu"
,),
activation_type
:
Sequence
[
Union
[
str
,
Callable
]]
=
(
"gelu"
,),
use_bias
:
bool
=
True
,
quantizer_sets
:
Tuple
[
QuantizerSet
]
=
(
noop_quantizer_set
,
noop_quantizer_set
)
,
)
->
jnp
.
ndarray
:
)
->
jnp
.
ndarray
:
"""Apply layer normalization followed by MLP block.
This function implements the following sequence of operations:
1. Layer normalization: (x - mean) / sqrt(var + epsilon) * gamma + beta
2. First dense layer transformation: y1 = x * kernel1 + bias1
3. Activation function: y2 = activation(y1)
4. Second dense layer transformation: y3 = y2 * kernel2 + bias2
Args:
x: Input tensor with shape [batch..., hidden_in]
gamma: Scale parameter for normalization with shape [hidden_in]
beta: Bias parameter for normalization with shape [hidden_in]
kernels: List of two weight matrices:
- kernel1: [hidden_in, intermediate]
- kernel2: [intermediate, hidden_in]
biases: List of two bias terms:
- bias1: [intermediate]
- bias2: [hidden_in]
norm_type: Type of normalization ("layernorm" or "rmsnorm")
zero_centered_gamma: Whether to use zero-centered gamma for normalization
epsilon: Small constant for numerical stability in normalization
norm_input_axes: Logical axes for sharding the layernorm input
dot_1_input_axes: Logical axes for sharding the first matrix multiplication
dot_2_input_axes: Logical axes for sharding the second matrix multiplication
ffn1_ckpt_name: Name for checkpointing the first feed-forward network
ffn2_ckpt_name: Name for checkpointing the second feed-forward network
activation_type: Activation function(s) to apply after the first dense layer transformation
quantizer_sets: Tuple of two quantizer sets for the two dense layer transformations
Returns:
Output tensor with shape [batch..., hidden_in]
Note:
- For RMSNorm (norm_type="rmsnorm"), beta must be None and zero_centered_gamma
must be False
- The function supports automatic differentiation through JAX's custom VJP
- Quantization is applied to both dense layer transformations
- Checkpointing is applied to both feed-forward networks for memory efficiency
"""
"""
Layernorm + GEMM1 + bias + activation + GEMM2 + bias
"""
assert
len
(
kernels
)
==
2
assert
len
(
kernels
)
==
2
assert
len
(
fp8_meta_pkgs
)
==
len
(
kernels
)
kernel_1
=
kernels
[
0
]
kernel_1
=
kernels
[
0
]
kernel_2
=
kernels
[
1
]
kernel_2
=
kernels
[
1
]
bias_1
=
biases
[
0
]
bias_1
=
biases
[
0
]
bias_2
=
biases
[
1
]
bias_2
=
biases
[
1
]
amax_list_1
=
fp8_meta_pkgs
[
0
].
amax_list
amax_list_2
=
fp8_meta_pkgs
[
1
].
amax_list
scale_list_1
=
fp8_meta_pkgs
[
0
].
scale_list
scale_list_2
=
fp8_meta_pkgs
[
1
].
scale_list
fwd_dtype
=
FP8Helper
.
FWD_DTYPE
norm_type
=
canonicalize_norm_type
(
norm_type
)
bwd_dtype
=
FP8Helper
.
BWD_DTYPE
if
norm_type
==
"rmsnorm"
:
assert
beta
is
None
,
"beta should be None if norm_type is 'rmsnorm'"
layernorm_type
=
canonicalize_layernorm_type
(
layernorm_type
)
if
layernorm_type
==
"rmsnorm"
:
assert
beta
is
None
,
"beta should be None if layernorm_type is 'rmsnorm'"
assert
(
assert
(
not
zero_centered_gamma
not
zero_centered_gamma
),
"zero_centered_gamma is not supported if
layer
norm_type is 'rmsnorm'"
),
"zero_centered_gamma is not supported if norm_type is 'rmsnorm'"
output
=
_fused
_layernorm_
fp8_
mlp
(
output
=
_layernorm_mlp
(
x
,
x
,
gamma
,
gamma
,
beta
,
beta
,
...
@@ -104,28 +103,22 @@ def fused_layernorm_fp8_mlp(
...
@@ -104,28 +103,22 @@ def fused_layernorm_fp8_mlp(
kernel_2
,
kernel_2
,
bias_1
,
bias_1
,
bias_2
,
bias_2
,
amax_list_1
,
norm_type
,
amax_list_2
,
scale_list_1
,
scale_list_2
,
fwd_dtype
,
bwd_dtype
,
layernorm_type
,
zero_centered_gamma
,
zero_centered_gamma
,
epsilon
,
epsilon
,
layer
norm_input_axes
,
norm_input_axes
,
dot_1_input_axes
,
dot_1_input_axes
,
dot_2_input_axes
,
dot_2_input_axes
,
ffn1_ckpt_name
,
ffn1_ckpt_name
,
ffn2_ckpt_name
,
ffn2_ckpt_name
,
activation_type
,
activation_type
,
use_bia
s
,
quantizer_set
s
,
)
)
return
output
return
output
@
partial
(
jax
.
custom_vjp
,
nondiff_argnums
=
(
11
,
12
,
13
,
1
4
,
1
5
,
16
,
1
7
,
1
8
,
1
9
,
20
,
21
,
22
))
@
partial
(
jax
.
custom_vjp
,
nondiff_argnums
=
(
7
,
8
,
9
,
1
0
,
1
1
,
1
2
,
1
3
,
1
4
,
15
))
def
_fused
_layernorm_
fp8_
mlp
(
def
_layernorm_mlp
(
x
:
jnp
.
ndarray
,
x
:
jnp
.
ndarray
,
gamma
:
jnp
.
ndarray
,
gamma
:
jnp
.
ndarray
,
beta
:
jnp
.
ndarray
,
beta
:
jnp
.
ndarray
,
...
@@ -133,24 +126,46 @@ def _fused_layernorm_fp8_mlp(
...
@@ -133,24 +126,46 @@ def _fused_layernorm_fp8_mlp(
kernel_2
:
jnp
.
ndarray
,
kernel_2
:
jnp
.
ndarray
,
bias_1
:
jnp
.
ndarray
,
bias_1
:
jnp
.
ndarray
,
bias_2
:
jnp
.
ndarray
,
bias_2
:
jnp
.
ndarray
,
amax_list_1
:
List
[
jnp
.
ndarray
],
norm_type
:
str
,
amax_list_2
:
List
[
jnp
.
ndarray
],
scale_list_1
:
List
[
jnp
.
ndarray
],
scale_list_2
:
List
[
jnp
.
ndarray
],
fwd_dtype
:
jnp
.
dtype
,
bwd_dtype
:
jnp
.
dtype
,
layernorm_type
:
str
,
zero_centered_gamma
:
bool
,
zero_centered_gamma
:
bool
,
epsilon
:
float
,
epsilon
:
float
,
layer
norm_input_axes
:
Tuple
[
str
,
...],
norm_input_axes
:
Tuple
[
str
,
...],
dot_1_input_axes
:
Tuple
[
str
,
...],
dot_1_input_axes
:
Tuple
[
str
,
...],
dot_2_input_axes
:
Tuple
[
str
,
...],
dot_2_input_axes
:
Tuple
[
str
,
...],
ffn1_ckpt_name
:
str
,
ffn1_ckpt_name
:
str
,
ffn2_ckpt_name
:
str
,
ffn2_ckpt_name
:
str
,
activation_type
:
Sequence
[
Union
[
str
,
Callable
]],
activation_type
:
Sequence
[
Union
[
str
,
Callable
]],
use_bias
:
bool
,
quantizer_sets
,
):
):
output
,
_
=
_fused_layernorm_fp8_mlp_fwd_rule
(
"""Internal implementation of layernorm_mlp with custom VJP.
This function implements the forward pass of layernorm_mlp with support for
automatic differentiation. It handles the normalization, dense layer transformations,
activation, and quantization operations.
Args:
x: Input tensor
gamma: Scale parameter for normalization
beta: Bias parameter for normalization
kernel_1: First weight matrix
kernel_2: Second weight matrix
bias_1: First bias term
bias_2: Second bias term
norm_type: Type of normalization
zero_centered_gamma: Whether to use zero-centered gamma
epsilon: Small constant for numerical stability
norm_input_axes: Logical axes for layernorm sharding
dot_1_input_axes: Logical axes for first matrix multiplication sharding
dot_2_input_axes: Logical axes for second matrix multiplication sharding
ffn1_ckpt_name: Name for first feed-forward network checkpointing
ffn2_ckpt_name: Name for second feed-forward network checkpointing
activation_type: Activation function(s)
quantizer_sets: Tuple of quantizer sets
Returns:
Output tensor from the combined operations
"""
output
,
_
=
_layernorm_mlp_fwd_rule
(
x
,
x
,
gamma
,
gamma
,
beta
,
beta
,
...
@@ -158,27 +173,21 @@ def _fused_layernorm_fp8_mlp(
...
@@ -158,27 +173,21 @@ def _fused_layernorm_fp8_mlp(
kernel_2
,
kernel_2
,
bias_1
,
bias_1
,
bias_2
,
bias_2
,
amax_list_1
,
norm_type
,
amax_list_2
,
scale_list_1
,
scale_list_2
,
fwd_dtype
,
bwd_dtype
,
layernorm_type
,
zero_centered_gamma
,
zero_centered_gamma
,
epsilon
,
epsilon
,
layer
norm_input_axes
,
norm_input_axes
,
dot_1_input_axes
,
dot_1_input_axes
,
dot_2_input_axes
,
dot_2_input_axes
,
ffn1_ckpt_name
,
ffn1_ckpt_name
,
ffn2_ckpt_name
,
ffn2_ckpt_name
,
activation_type
,
activation_type
,
use_bia
s
,
quantizer_set
s
,
)
)
return
output
return
output
def
_fused
_layernorm_
fp8_
mlp_fwd_rule
(
def
_layernorm_mlp_fwd_rule
(
x
,
x
,
gamma
,
gamma
,
beta
,
beta
,
...
@@ -186,444 +195,257 @@ def _fused_layernorm_fp8_mlp_fwd_rule(
...
@@ -186,444 +195,257 @@ def _fused_layernorm_fp8_mlp_fwd_rule(
kernel_2
,
kernel_2
,
bias_1
,
bias_1
,
bias_2
,
bias_2
,
amax_list_1
,
norm_type
,
amax_list_2
,
scale_list_1
,
scale_list_2
,
fwd_dtype
,
bwd_dtype
,
# pylint: disable=unused-argument
layernorm_type
,
zero_centered_gamma
,
zero_centered_gamma
,
epsilon
,
epsilon
,
layer
norm_input_axes
,
norm_input_axes
,
dot_1_input_axes
,
dot_1_input_axes
,
dot_2_input_axes
,
dot_2_input_axes
,
ffn1_ckpt_name
,
ffn1_ckpt_name
,
ffn2_ckpt_name
,
ffn2_ckpt_name
,
activation_type
,
activation_type
,
use_bia
s
,
quantizer_set
s
,
):
):
"""Forward pass rule for layernorm_mlp.
Implements the forward pass computation including:
1. Layer normalization with quantization
2. First matrix multiplication with quantized kernel
3. Activation function application
4. Second matrix multiplication with quantized kernel
5. Optional bias additions
6. Sharding constraints
7. Checkpointing for memory efficiency
Returns:
Tuple of (output, context) for automatic differentiation
"""
ffn1_quantizer_set
,
ffn2_quantizer_set
=
quantizer_sets
# x should be in shape of (batch..., hidden)
# x should be in shape of (batch..., hidden)
# Kernel_1 should be in shape of (Hidden_in, 1, Hidden_out)
# Kernel_1 should be in shape of (hidden_in, activation_len * intermediate)
# Kernel_2 should be in shape of (Hidden_in, Hidden_out)
# Kernel_2 should be in shape of (intermediate, hidden_in)
assert
len
(
kernel_1
.
shape
)
==
3
assert
len
(
kernel_1
.
shape
)
==
2
assert
kernel_1
.
shape
[
-
2
]
==
len
(
activation_type
)
assert
len
(
kernel_2
.
shape
)
==
2
assert
len
(
kernel_2
.
shape
)
==
2
assert
kernel_1
.
shape
[
1
]
==
kernel_2
.
shape
[
0
]
*
len
(
activation_type
)
x_contracting_dims
=
(
len
(
x
.
shape
)
-
1
,)
x_contracting_dims
=
(
len
(
x
.
shape
)
-
1
,)
xt_batch_dims
=
tuple
(
range
(
1
,
x
.
ndim
)
)
k_contracting_dims
=
(
0
,
)
assert
x
.
shape
[
x_contracting_dims
[
0
]]
==
kernel_1
.
shape
[
0
]
assert
x
.
shape
[
x_contracting_dims
[
0
]]
==
kernel_1
.
shape
[
k_contracting_dims
[
0
]
]
assert
kernel_1
.
shape
[
-
1
]
==
kernel_2
.
shape
[
0
]
assert
kernel_1
.
shape
[
1
]
==
len
(
activation_type
)
*
kernel_2
.
shape
[
0
]
maybe_fm32_to_fp32
,
maybe_fp32_to_fm32
=
FP8Helper
.
generate_fp8_meta_dtype_converter_pair
(
use_bias_1
=
bias_1
is
not
None
*
amax_list_1
,
*
scale_list_1
,
*
amax_list_2
,
*
scale_list_2
use_bias_2
=
bias_1
is
not
None
)
amax_list_1
=
maybe_fm32_to_fp32
(
*
amax_list_1
)
x
=
with_sharding_constraint_by_logical_axes
(
x
,
norm_input_axes
)
scale_list_1
=
maybe_fm32_to_fp32
(
*
scale_list_1
)
amax_list_2
=
maybe_fm32_to_fp32
(
*
amax_list_2
)
casted_ln_out
,
mu
,
rsigma
=
tex
.
normalization_fwd
(
scale_list_2
=
maybe_fm32_to_fp32
(
*
scale_list_2
)
x
,
gamma
,
fp8_dtype_list
=
[
fwd_dtype
,
fwd_dtype
,
bwd_dtype
]
beta
,
scale_list_1
,
scale_inv_list_1
=
FP8MetaPackage
.
update_fp8_scale
(
zero_centered_gamma
,
amax_list_1
,
scale_list_1
,
fp8_dtype_list
epsilon
,
)
norm_type
,
amax_list_1
=
FP8MetaPackage
.
update_amax_list
(
amax_list_1
)
quantizer
=
ffn1_quantizer_set
.
x
,
scale_list_2
,
scale_inv_list_2
=
FP8MetaPackage
.
update_fp8_scale
(
amax_list_2
,
scale_list_2
,
fp8_dtype_list
)
amax_list_2
=
FP8MetaPackage
.
update_amax_list
(
amax_list_2
)
x_amax
=
amax_list_1
[
FP8MetaPackage
.
INPUT_IDX
][
0
:
1
]
x_scale
=
scale_list_1
[
FP8MetaPackage
.
INPUT_IDX
]
x_scale_inv
=
scale_inv_list_1
[
FP8MetaPackage
.
INPUT_IDX
]
x
=
with_sharding_constraint_by_logical_axes
(
x
,
layernorm_input_axes
)
if
layernorm_type
==
"layernorm"
:
ln_out
,
mu
,
rsigma
,
updated_x_amax
=
tex
.
layernorm_fwd_fp8
(
x
,
gamma
,
beta
,
x_amax
,
x_scale
,
x_scale_inv
,
out_dtype
=
fwd_dtype
,
zero_centered_gamma
=
zero_centered_gamma
,
epsilon
=
epsilon
,
)
else
:
assert
(
not
zero_centered_gamma
),
"zero_centered_gamma is not supported if layernorm_type is 'rmsnorm'"
ln_out
,
rsigma
,
updated_x_amax
=
tex
.
rmsnorm_fwd_fp8
(
x
,
gamma
,
x_amax
,
x_scale
,
x_scale_inv
,
out_dtype
=
fwd_dtype
,
epsilon
=
epsilon
)
mu
=
None
assert
x
.
shape
==
ln_out
.
shape
kernel_1_amax
=
amax_list_1
[
FP8MetaPackage
.
WEIGHT_IDX
][
0
:
1
]
kernel_1_scale
=
scale_list_1
[
FP8MetaPackage
.
WEIGHT_IDX
]
kernel_1_scale_inv
=
scale_inv_list_1
[
FP8MetaPackage
.
WEIGHT_IDX
]
# Note (Ming Huang): Use cast only to let XLA handle the transpose, avoiding an
# unnecessary copy that would break FP8 GEMM pattern matching.
casted_kernel_1
,
updated_kernel_1_amax
=
tex
.
cast_fp8
(
kernel_1
,
kernel_1_amax
,
kernel_1_scale
,
kernel_1_scale_inv
,
fwd_dtype
)
)
ln_out
=
with_sharding_constraint_by_logical_axes
(
ln_out
,
dot_1_input_axes
)
casted_kernel_1
=
tex
.
quantize
(
kernel_1
,
quantizer
=
ffn1_quantizer_set
.
kernel
)
casted_ln_out
=
with_sharding_constraint_by_logical_axes
(
casted_ln_out
,
dot_1_input_axes
)
# NN GEMM
# (batch..., hidden_in) x (hidden_in, hidden_out)
# (batch..., hidden_in) x (hidden_in, hidden_out)
dot_1_output
=
fp8_dot_impl
(
dot_1_output
=
tex
.
gemm
(
ln_out
,
casted_ln_out
.
get_rowwise_tensor
(),
casted_kernel_1
,
casted_kernel_1
.
get_colwise_tensor
(),
x_scale_inv
,
(
x_contracting_dims
,
k_contracting_dims
),
kernel_1_scale_inv
,
x
.
dtype
,
(
x_contracting_dims
,
(
0
,)),
get_precision_of_fp8_dot
(
FP8Helper
.
FP8_2X_ACC_FPROP
),
)
)
if
use_bias
:
if
use_bias
_1
:
bias_1_shape
=
bias_1
.
shape
bias_1_shape
=
bias_1
.
shape
bias_1_new_shape
=
(
1
,)
*
(
dot_1_output
.
ndim
-
bias_1
.
ndim
)
+
bias_1_shape
bias_1_new_shape
=
(
1
,)
*
(
dot_1_output
.
ndim
-
bias_1
.
ndim
)
+
bias_1_shape
dot_1_output
+=
jnp
.
reshape
(
bias_1
,
bias_1_new_shape
)
dot_1_output
+=
jnp
.
reshape
(
bias_1
,
bias_1_new_shape
)
else
:
bias_1_shape
=
None
dot_1_output
=
checkpoint_name
(
dot_1_output
,
ffn1_ckpt_name
)
activation_lu_out_amax
=
amax_list_2
[
FP8MetaPackage
.
INPUT_IDX
][
0
:
1
]
dot_1_output
=
checkpoint_name
(
dot_1_output
,
ffn1_ckpt_name
)
activation_lu_out_scale
=
scale_list_2
[
FP8MetaPackage
.
INPUT_IDX
]
activation_lu_out_scale_inv
=
scale_inv_list_2
[
FP8MetaPackage
.
INPUT_IDX
]
# (batch..., hidden_in) -> (batch..., hidden)
# (batch..., hidden_in) -> (batch..., hidden)
casted_activation_lu_out
,
updated_activation_lu_amax
=
tex
.
act_lu_fp8
(
casted_act_out
=
tex
.
act_lu
(
dot_1_output
,
activation_type
,
quantizer
=
ffn2_quantizer_set
.
x
)
dot_1_output
,
activation_lu_out_amax
,
activation_lu_out_scale
,
activation_lu_out_scale_inv
,
fwd_dtype
,
activation_type
,
)
casted_activation_lu_out
=
with_sharding_constraint_by_logical_axes
(
casted_act_out
=
with_sharding_constraint_by_logical_axes
(
casted_act_out
,
dot_2_input_axes
)
casted_activation_lu_out
,
dot_2_input_axes
)
kernel_2_scale
=
scale_list_2
[
FP8MetaPackage
.
WEIGHT_IDX
]
casted_kernel_2
=
tex
.
quantize
(
kernel_2
,
quantizer
=
ffn2_quantizer_set
.
kernel
)
kernel_2_scale_inv
=
scale_inv_list_2
[
FP8MetaPackage
.
WEIGHT_IDX
]
# Note (Ming Huang): Use a native cast to let XLA handle the transpose, avoiding an
# unnecessary copy that would break FP8 GEMM pattern matching.
casted_kernel_2
,
updated_kernel_2_amax
=
quantize
(
kernel_2
,
fwd_dtype
,
kernel_2_scale
)
# NN GEMM
# (batch..., hidden_in) x (hidden_out, hidden_in)
# (batch..., hidden_in) x (hidden_out, hidden_in)
dot_2_output
=
fp8_dot_impl
(
dot_2_output
=
tex
.
gemm
(
casted_activation_lu_out
,
casted_act_out
.
get_rowwise_tensor
(),
casted_kernel_2
,
casted_kernel_2
.
get_colwise_tensor
(),
activation_lu_out_scale_inv
,
(
x_contracting_dims
,
k_contracting_dims
),
kernel_2_scale_inv
,
x
.
dtype
,
(
x_contracting_dims
,
(
0
,)),
get_precision_of_fp8_dot
(
FP8Helper
.
FP8_2X_ACC_FPROP
),
)
)
if
use_bias
:
if
use_bias
_2
:
bias_2_shape
=
bias_2
.
shape
bias_2_shape
=
bias_2
.
shape
bias_2_new_shape
=
(
1
,)
*
(
dot_2_output
.
ndim
-
bias_2
.
ndim
)
+
bias_2_shape
bias_2_new_shape
=
(
1
,)
*
(
dot_2_output
.
ndim
-
bias_2
.
ndim
)
+
bias_2_shape
dot_2_output
+=
jnp
.
reshape
(
bias_2
,
bias_2_new_shape
)
dot_2_output
+=
jnp
.
reshape
(
bias_2
,
bias_2_new_shape
)
else
:
bias_2_shape
=
None
dot_2_output
=
checkpoint_name
(
dot_2_output
,
ffn2_ckpt_name
)
dot_2_output
=
checkpoint_name
(
dot_2_output
,
ffn2_ckpt_name
)
ctx
=
(
ctx
=
(
x
,
x
,
ln_out
,
mu
,
mu
,
rsigma
,
rsigma
,
gamma
,
gamma
,
beta
,
beta
,
casted_ln_out
.
get_colwise_tensor
(),
casted_kernel_1
.
get_rowwise_tensor
(),
dot_1_output
,
dot_1_output
,
casted_activation_lu_out
,
casted_act_out
.
get_colwise_tensor
(),
casted_kernel_1
,
casted_kernel_2
.
get_rowwise_tensor
(),
casted_kernel_2
,
amax_list_1
,
amax_list_2
,
scale_list_1
,
scale_list_2
,
scale_inv_list_1
,
scale_inv_list_2
,
updated_x_amax
,
updated_activation_lu_amax
,
updated_kernel_1_amax
,
updated_kernel_2_amax
,
x_contracting_dims
,
x_contracting_dims
,
xt_batch_dims
,
k_contracting_dims
,
bias_1_shape
,
kernel_1
.
shape
,
bias_2_shape
,
kernel_2
.
shape
,
maybe_fp32_to_fm32
,
use_bias_1
,
use_bias_2
,
quantizer_sets
,
)
)
return
dot_2_output
,
ctx
return
dot_2_output
,
ctx
def
_fused_layernorm_fp8_mlp_bwd_rule
(
def
_layernorm_mlp_bwd_rule
(
fwd_dtype
,
# pylint: disable=unused-argument
norm_type
,
bwd_dtype
,
layernorm_type
,
zero_centered_gamma
,
zero_centered_gamma
,
epsilon
,
epsilon
,
layer
norm_input_axes
,
norm_input_axes
,
dot_1_input_axes
,
dot_1_input_axes
,
dot_2_input_axes
,
dot_2_input_axes
,
ffn1_ckpt_name
,
# pylint: disable=unused-argument
ffn1_ckpt_name
,
# pylint: disable=unused-argument
ffn2_ckpt_name
,
# pylint: disable=unused-argument
ffn2_ckpt_name
,
# pylint: disable=unused-argument
activation_type
,
activation_type
,
use_bias
,
ctx
,
ctx
,
grad
,
grad
,
):
):
"""Backward pass rule for layernorm_mlp.
Implements the backward pass computation including:
1. Gradient computation for second matrix multiplication
2. Gradient computation for activation function
3. Gradient computation for first matrix multiplication
4. Gradient computation for layer normalization
5. Gradient computation for bias terms
6. Proper handling of quantization
Returns:
Tuple of gradients for all input parameters
"""
(
(
x
,
x
,
ln_out
,
mu
,
mu
,
rsigma
,
rsigma
,
gamma
,
gamma
,
beta
,
beta
,
colwise_casted_ln_out
,
rowwise_casted_kernel_1
,
dot_1_output
,
dot_1_output
,
casted_activation_lu_out
,
colwise_casted_act_out
,
casted_kernel_1
,
rowwise_casted_kernel_2
,
casted_kernel_2
,
x_contracting_dims_in_fwd
,
amax_list_1
,
k_contracting_dims_in_fwd
,
amax_list_2
,
kernel_1_shape
,
scale_list_1
,
kernel_2_shape
,
scale_list_2
,
use_bias_1
,
scale_inv_list_1
,
use_bias_2
,
scale_inv_list_2
,
quantizer_sets
,
updated_x_amax
,
updated_activation_lu_amax
,
updated_kernel_1_amax
,
updated_kernel_2_amax
,
x_contracting_dims
,
xt_batch_dims
,
bias_1_shape
,
bias_2_shape
,
maybe_fp32_to_fm32
,
)
=
ctx
)
=
ctx
grad_amax
=
amax_list_2
[
FP8MetaPackage
.
GRAD_IDX
][
0
:
1
]
ffn1_quantizer_set
,
ffn2_quantizer_set
=
quantizer_sets
grad_scale
=
scale_list_2
[
FP8MetaPackage
.
GRAD_IDX
]
grad_scale_inv
=
scale_inv_list_2
[
FP8MetaPackage
.
GRAD_IDX
]
# Since the sharding of outputs should be the same as dot_1's input
# Since the sharding of outputs should be the same as dot_1's input
grad
=
with_sharding_constraint_by_logical_axes
(
grad
,
dot_1_input_axes
)
grad
=
with_sharding_constraint_by_logical_axes
(
grad
,
dot_1_input_axes
)
if
use_bias
:
casted_grad
,
casted_grad_t
,
dbias_2
,
updated_grad_amax
=
tex
.
dbias_cast_transpose
(
casted_grad
,
dbias_2
=
tex
.
quantize_dbias
(
grad
,
grad
,
is_dbias
=
use_bias_2
,
quantizer
=
ffn1_quantizer_set
.
dgrad
grad_amax
,
grad_scale
,
grad_scale_inv
,
bwd_dtype
,
static_axis_boundary
=-
1
,
transpose_axis_boundary
=-
1
,
)
dbias_2
=
jnp
.
reshape
(
dbias_2
,
bias_2_shape
)
else
:
casted_grad
,
casted_grad_t
,
updated_grad_amax
=
tex
.
cast_transpose
(
grad
,
grad_amax
,
grad_scale
,
grad_scale_inv
,
bwd_dtype
,
static_axis_boundary
=-
1
,
transpose_axis_boundary
=-
1
,
)
dbias_2
=
None
casted_activation_lu_out_t
=
tex
.
transpose
(
casted_activation_lu_out
,
static_axis_boundary
=-
1
,
transpose_axis_boundary
=-
1
)
)
# (hidden, batch...,) x (hidden, batch...)
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim
gemm2_x_scale_inv
=
scale_inv_list_2
[
FP8MetaPackage
.
INPUT_IDX
]
g_constracting_dim_2
=
tuple
(
wgrad_2
=
fp8_dot_impl
(
range
(
grad
.
ndim
-
len
(
kernel_2_shape
)
+
len
(
k_contracting_dims_in_fwd
),
grad
.
ndim
)
casted_activation_lu_out_t
,
)
casted_grad_t
,
# k_non_contracting_dims
gemm2_x_scale_inv
,
k_constracting_dim_2
=
tuple
(
grad_scale_inv
,
dim
for
dim
in
range
(
len
(
kernel_2_shape
))
if
dim
not
in
k_contracting_dims_in_fwd
grad
.
dtype
,
(
xt_batch_dims
,
xt_batch_dims
),
get_precision_of_fp8_dot
(
FP8Helper
.
FP8_2X_ACC_WGRAD
),
)
)
# NT GEMM
# (batch..., hidden_out) x (hidden_in, hidden_out)
# (batch..., hidden_out) x (hidden_in, hidden_out)
kernel_2_scale_inv
=
scale_inv_list_2
[
FP8MetaPackage
.
WEIGHT_IDX
]
dgrad_2
=
tex
.
gemm
(
dgrad_2
=
fp8_dot_impl
(
casted_grad
.
get_rowwise_tensor
(),
casted_grad
,
rowwise_casted_kernel_2
,
casted_kernel_2
,
(
g_constracting_dim_2
,
k_constracting_dim_2
),
grad_scale_inv
,
kernel_2_scale_inv
,
grad
.
dtype
,
(
x_contracting_dims
,
(
1
,)),
get_precision_of_fp8_dot
(
FP8Helper
.
FP8_2X_ACC_DGRAD
),
)
)
dgrad_2
=
with_sharding_constraint_by_logical_axes
(
dgrad_2
,
dot_2_input_axes
)
dgrad_2
=
with_sharding_constraint_by_logical_axes
(
dgrad_2
,
dot_2_input_axes
)
dactivation_lu_amax
=
amax_list_1
[
FP8MetaPackage
.
GRAD_IDX
][
0
:
1
]
x_constracting_dim
=
g_constracting_dim
=
tuple
(
dactivation_lu_scale
=
scale_list_1
[
FP8MetaPackage
.
GRAD_IDX
]
range
(
0
,
len
(
x
.
shape
)
-
len
(
x_contracting_dims_in_fwd
))
dactivation_lu_scale_inv
=
scale_inv_list_1
[
FP8MetaPackage
.
GRAD_IDX
]
if
len
(
activation_type
)
>
1
:
# if gated
if
use_bias
:
dactivation_lu
=
tex
.
dact_lu
(
dgrad_2
,
dot_1_output
,
activation_type
)
casted_dactivation_lu
,
casted_dactivation_lu_t
,
dbias_1
,
updated_dactivation_lu_amax
=
(
tex
.
dbias_cast_transpose
(
dactivation_lu
,
dactivation_lu_amax
,
dactivation_lu_scale
,
dactivation_lu_scale_inv
,
bwd_dtype
,
static_axis_boundary
=-
1
,
transpose_axis_boundary
=-
2
,
)
)
dbias_1
=
jnp
.
reshape
(
dbias_1
,
bias_1_shape
)
else
:
casted_dactivation_lu
,
casted_dactivation_lu_t
,
updated_dactivation_lu_amax
=
(
tex
.
dgated_act_lu_cast_transpose
(
dgrad_2
,
dot_1_output
,
dactivation_lu_amax
,
dactivation_lu_scale
,
dactivation_lu_scale_inv
,
bwd_dtype
,
static_axis_boundary
=-
1
,
activation_type
=
activation_type
,
)
)
dbias_1
=
None
else
:
if
use_bias
:
casted_dactivation_lu
,
casted_dactivation_lu_t
,
dbias_1
,
updated_dactivation_lu_amax
=
(
tex
.
dact_lu_dbias_cast_transpose
(
dgrad_2
,
dot_1_output
,
dactivation_lu_amax
,
dactivation_lu_scale
,
dactivation_lu_scale_inv
,
bwd_dtype
,
static_axis_boundary
=-
1
,
activation_type
=
activation_type
,
)
)
dbias_1
=
jnp
.
reshape
(
dbias_1
,
bias_1_shape
)
else
:
dactivation_lu
=
tex
.
dact_lu
(
dgrad_2
,
dot_1_output
,
activation_type
)
casted_dactivation_lu
,
casted_dactivation_lu_t
,
updated_dactivation_lu_amax
=
(
tex
.
cast_transpose
(
dactivation_lu
,
dactivation_lu_amax
,
dactivation_lu_scale
,
dactivation_lu_scale_inv
,
bwd_dtype
,
static_axis_boundary
=-
1
,
transpose_axis_boundary
=-
2
,
)
)
dbias_1
=
None
ln_out_t
=
tex
.
transpose
(
ln_out
,
static_axis_boundary
=-
1
,
transpose_axis_boundary
=-
1
)
# (hidden, batch...) x (hidden, batch...)
gemm1_x_scale_inv
=
scale_inv_list_1
[
FP8MetaPackage
.
INPUT_IDX
]
xt_batch_dims_2
=
tuple
(
i
+
1
for
i
in
xt_batch_dims
)
wgrad_1
=
fp8_dot_impl
(
ln_out_t
,
casted_dactivation_lu_t
,
gemm1_x_scale_inv
,
dactivation_lu_scale_inv
,
grad
.
dtype
,
(
xt_batch_dims
,
xt_batch_dims_2
),
get_precision_of_fp8_dot
(
FP8Helper
.
FP8_2X_ACC_WGRAD
),
)
)
x_contracting_dims
=
(
# TN GEMM
(
min
(
x_contracting_dims
),)
+
tuple
(
i
+
1
for
i
in
x_contracting_dims
),
# (hidden, batch...,) x (hidden, batch...)
(
1
,
2
),
wgrad_2
=
tex
.
gemm
(
)
colwise_casted_act_out
,
kernel_1_scale_inv
=
scale_inv_list_1
[
FP8MetaPackage
.
WEIGHT_IDX
]
casted_grad
.
get_colwise_tensor
(),
dgrad_1
=
fp8_dot_impl
(
(
x_constracting_dim
,
g_constracting_dim
),
casted_dactivation_lu
,
casted_kernel_1
,
dactivation_lu_scale_inv
,
kernel_1_scale_inv
,
grad
.
dtype
,
x_contracting_dims
,
get_precision_of_fp8_dot
(
FP8Helper
.
FP8_2X_ACC_DGRAD
),
)
)
dgrad_1
=
with_sharding_constraint_by_logical_axes
(
dgrad_1
,
layernorm_input_axes
)
casted_dact_out
,
dbias_1
=
tex
.
quantize_dact_dbias
(
dgrad_2
,
if
layernorm_type
==
"layernorm"
:
dot_1_output
,
dx
,
dgamma
,
dbeta
=
tex
.
layernorm_bwd
(
activation_type
=
activation_type
,
dgrad_1
,
is_dbias
=
use_bias_1
,
x
,
quantizer
=
ffn2_quantizer_set
.
dgrad
,
mu
,
rsigma
,
gamma
,
beta
,
zero_centered_gamma
=
zero_centered_gamma
,
epsilon
=
epsilon
,
)
else
:
assert
(
not
zero_centered_gamma
),
"zero_centered_gamma is not supported if layernorm_type is 'rmsnorm'"
dx
,
dgamma
=
tex
.
rmsnorm_bwd
(
dgrad_1
,
x
,
rsigma
,
gamma
,
epsilon
=
epsilon
)
dbeta
=
None
amax_list_1
[
FP8MetaPackage
.
INPUT_IDX
]
=
(
amax_list_1
[
FP8MetaPackage
.
INPUT_IDX
].
at
[
0
].
set
(
updated_x_amax
[
0
])
)
amax_list_1
[
FP8MetaPackage
.
WEIGHT_IDX
]
=
(
amax_list_1
[
FP8MetaPackage
.
WEIGHT_IDX
].
at
[
0
].
set
(
updated_kernel_1_amax
[
0
])
)
)
amax_list_1
[
FP8MetaPackage
.
GRAD_IDX
]
=
(
amax_list_1
[
FP8MetaPackage
.
GRAD_IDX
].
at
[
0
].
set
(
updated_dactivation_lu_amax
[
0
])
# k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim
g_constracting_dim_1
=
tuple
(
range
(
dgrad_2
.
ndim
-
len
(
kernel_1_shape
)
+
len
(
k_contracting_dims_in_fwd
),
dgrad_2
.
ndim
)
)
)
amax_list_2
[
FP8MetaPackage
.
INPUT_IDX
]
=
(
# k_non_contracting_dims
amax_list_2
[
FP8MetaPackage
.
INPUT_IDX
].
at
[
0
].
set
(
updated_activation_lu_amax
[
0
])
k_constracting_dim_1
=
tuple
(
dim
for
dim
in
range
(
len
(
kernel_1_shape
))
if
dim
not
in
k_contracting_dims_in_fwd
)
)
amax_list_2
[
FP8MetaPackage
.
WEIGHT_IDX
]
=
(
amax_list_2
[
FP8MetaPackage
.
WEIGHT_IDX
].
at
[
0
].
set
(
updated_kernel_2_amax
)
# NT GEMM
dgrad_1
=
tex
.
gemm
(
casted_dact_out
.
get_rowwise_tensor
(),
rowwise_casted_kernel_1
,
(
g_constracting_dim_1
,
k_constracting_dim_1
),
)
)
amax_list_2
[
FP8MetaPackage
.
GRAD_IDX
]
=
(
amax_list_2
[
FP8MetaPackage
.
GRAD_IDX
].
at
[
0
].
set
(
updated_grad_amax
[
0
])
dgrad_1
=
with_sharding_constraint_by_logical_axes
(
dgrad_1
,
norm_input_axes
)
# TN GEMM
# (hidden, batch...) x (hidden, batch...)
wgrad_1
=
tex
.
gemm
(
colwise_casted_ln_out
,
casted_dact_out
.
get_colwise_tensor
(),
(
x_constracting_dim
,
g_constracting_dim
),
)
)
amax_list_1
=
maybe_fp32_to_fm32
(
*
amax_list_1
)
dx
,
dgamma
,
dbeta
=
tex
.
normalization_bwd
(
scale_list_1
=
maybe_fp32_to_fm32
(
*
scale_list_1
)
dgrad_1
,
amax_list_2
=
maybe_fp32_to_fm32
(
*
amax_list_2
)
x
,
scale_list_2
=
maybe_fp32_to_fm32
(
*
scale_list_2
)
mu
,
rsigma
,
return
(
gamma
,
dx
,
beta
,
dgamma
,
zero_centered_gamma
=
zero_centered_gamma
,
dbeta
,
epsilon
=
epsilon
,
wgrad_1
,
norm_type
=
norm_type
,
wgrad_2
,
dbias_1
,
dbias_2
,
amax_list_1
,
amax_list_2
,
scale_list_1
,
scale_list_2
,
)
)
return
(
dx
,
dgamma
,
dbeta
,
wgrad_1
,
wgrad_2
,
dbias_1
,
dbias_2
,
quantizer_sets
)
_fused_layernorm_fp8_mlp
.
defvjp
(
_layernorm_mlp
.
defvjp
(
_layernorm_mlp_fwd_rule
,
_layernorm_mlp_bwd_rule
)
_fused_layernorm_fp8_mlp_fwd_rule
,
_fused_layernorm_fp8_mlp_bwd_rule
)
transformer_engine/jax/quantize/__init__.py
0 → 100644
View file @
a207db1d
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Python interface for quantization helpers.
This module provides a high-level interface for tensor quantization in JAX,
including support for various scaling modes and quantization strategies.
It exports all the necessary classes and functions from the underlying
implementation modules.
"""
from
.tensor
import
*
from
.quantizer
import
*
from
.dequantizer
import
*
from
.scaling_modes
import
*
from
.metadata
import
*
from
.helper
import
*
transformer_engine/jax/quantize/dequantizer.py
0 → 100644
View file @
a207db1d
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Dequantization utilities for TE/JAX.
This module provides utilities for dequantizing tensors that have been quantized
using various scaling modes, including delayed scaling and block scaling.
"""
import
jax
import
jax.numpy
as
jnp
from
.scaling_modes
import
ScalingMode
__all__
=
[
"Dequantizer"
]
class Dequantizer:
    """Encapsulation class for dequantization helpers.

    This class provides static methods for dequantizing tensors that have been
    quantized using different scaling modes. It supports both delayed scaling
    (one scale per tensor) and block scaling (one E8M0 scale per block) modes.
    """

    @staticmethod
    def _dq_func_tensor_scaling(scaled_tensor):
        """Dequantize a tensor using delayed scaling.

        This function dequantizes a tensor that was quantized using delayed scaling
        by multiplying the quantized data with the inverse scaling factor.

        Args:
            scaled_tensor: The quantized tensor to dequantize. Must expose
                ``data``, ``scale_inv``, and ``dq_dtype`` attributes.

        Returns:
            The dequantized tensor in ``scaled_tensor.dq_dtype``
        """
        # Multiply in float32 for accuracy, then cast to the requested dq_dtype.
        return jnp.asarray(
            scaled_tensor.data.astype(jnp.float32) * scaled_tensor.scale_inv.astype(jnp.float32),
            scaled_tensor.dq_dtype,
        )

    @staticmethod
    def _dq_func_block_scaling(scaled_tensor):
        """Dequantize a tensor using block scaling.

        This function dequantizes a tensor that was quantized using block scaling
        by applying the inverse scaling factor to each block of data.

        Args:
            scaled_tensor: The quantized tensor to dequantize. Must expose
                ``data``, ``scale_inv``, ``scaling_mode``, ``is_colwise``, and
                ``dq_dtype`` attributes.

        Returns:
            The dequantized tensor in ``scaled_tensor.dq_dtype``, with the same
            shape as ``scaled_tensor.data``
        """
        data = scaled_tensor.data.astype(jnp.float32)
        data_shape = data.shape
        # Reinterpret the stored scales as raw uint8 bytes: each byte is an
        # E8M0 biased exponent, decoded below as 2**(scale - 127).
        scale = scaled_tensor.scale_inv.view(jnp.uint8).astype(jnp.float32)
        # Unpadded scale shape for this layout (row-wise vs col-wise usage).
        scale_shape = scaled_tensor.scaling_mode.get_scale_shape(
            scaled_tensor.data.shape, scaled_tensor.is_colwise, is_padded=False
        )
        scale = jax.lax.slice(
            scale, [0] * len(scale_shape), scale_shape
        )  # slice out the padding
        # Split each of the last two data dims into (num_blocks, block_size) so
        # that one scale can broadcast over every element of its block.
        data = data.reshape(
            *data_shape[:-2],
            scale_shape[-2],
            int(data_shape[-2] / scale_shape[-2]),
            scale_shape[-1],
            int(data_shape[-1] / scale_shape[-1]),
        )
        # Insert singleton axes at the block-size positions for broadcasting.
        scale = jnp.expand_dims(scale, axis=(-1, -3))
        # E8M0 does not have a bit for sign. So 0 - 127 represent negative numbers.
        return jnp.asarray(data * jnp.power(2, scale - 127), scaled_tensor.dq_dtype).reshape(
            data_shape
        )

    # Dispatch table: scaling mode -> dequantization helper.
    funcs = {
        ScalingMode.NVTE_DELAYED_TENSOR_SCALING: _dq_func_tensor_scaling,
        ScalingMode.NVTE_MXFP8_1D_SCALING: _dq_func_block_scaling,
    }

    @staticmethod
    def dequantize(scaled_tensor):
        """Dequantize a scaled tensor using the appropriate scaling mode.

        This method selects the appropriate dequantization function based on the
        scaling mode used for quantization and applies it to the tensor.

        Args:
            scaled_tensor: The quantized tensor to dequantize

        Returns:
            The dequantized tensor in the specified data type

        Raises:
            KeyError: if ``scaled_tensor.scaling_mode`` has no registered helper.
        """
        dq_func = Dequantizer.funcs[scaled_tensor.scaling_mode]
        return dq_func(scaled_tensor)
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment