Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
ab3e5a92
Commit
ab3e5a92
authored
May 09, 2025
by
yuguo
Browse files
Merge commit '
04c730c0
' of...
Merge commit '
04c730c0
' of
https://github.com/NVIDIA/TransformerEngine
parents
a8d19fd9
04c730c0
Changes
174
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1647 additions
and
667 deletions
+1647
-667
transformer_engine/common/util/cast.cu
transformer_engine/common/util/cast.cu
+22
-10
transformer_engine/common/util/cast_gated_kernels.cuh
transformer_engine/common/util/cast_gated_kernels.cuh
+0
-6
transformer_engine/common/util/cast_kernels.cuh
transformer_engine/common/util/cast_kernels.cuh
+39
-6
transformer_engine/common/util/dequantize_kernels.cuh
transformer_engine/common/util/dequantize_kernels.cuh
+2
-3
transformer_engine/common/util/ptx.cuh
transformer_engine/common/util/ptx.cuh
+30
-25
transformer_engine/debug/__init__.py
transformer_engine/debug/__init__.py
+11
-0
transformer_engine/debug/pytorch/__init__.py
transformer_engine/debug/pytorch/__init__.py
+3
-0
transformer_engine/debug/pytorch/debug_quantization.py
transformer_engine/debug/pytorch/debug_quantization.py
+528
-0
transformer_engine/debug/pytorch/debug_state.py
transformer_engine/debug/pytorch/debug_state.py
+68
-0
transformer_engine/debug/pytorch/utils.py
transformer_engine/debug/pytorch/utils.py
+10
-0
transformer_engine/jax/__init__.py
transformer_engine/jax/__init__.py
+6
-2
transformer_engine/jax/activation.py
transformer_engine/jax/activation.py
+0
-1
transformer_engine/jax/cpp_extensions/activation.py
transformer_engine/jax/cpp_extensions/activation.py
+304
-197
transformer_engine/jax/cpp_extensions/attention.py
transformer_engine/jax/cpp_extensions/attention.py
+43
-1
transformer_engine/jax/cpp_extensions/base.py
transformer_engine/jax/cpp_extensions/base.py
+12
-1
transformer_engine/jax/cpp_extensions/gemm.py
transformer_engine/jax/cpp_extensions/gemm.py
+100
-149
transformer_engine/jax/cpp_extensions/misc.py
transformer_engine/jax/cpp_extensions/misc.py
+25
-19
transformer_engine/jax/cpp_extensions/normalization.py
transformer_engine/jax/cpp_extensions/normalization.py
+173
-103
transformer_engine/jax/cpp_extensions/quantization.py
transformer_engine/jax/cpp_extensions/quantization.py
+220
-135
transformer_engine/jax/cpp_extensions/softmax.py
transformer_engine/jax/cpp_extensions/softmax.py
+51
-9
No files found.
transformer_engine/common/util/cast.cu
View file @
ab3e5a92
...
@@ -37,8 +37,8 @@ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t strea
...
@@ -37,8 +37,8 @@ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t strea
constexpr
NVTETensor
workspace
=
nullptr
;
constexpr
NVTETensor
workspace
=
nullptr
;
constexpr
const
NVTETensor
grad
=
nullptr
;
constexpr
const
NVTETensor
grad
=
nullptr
;
detail
::
quantize_helper
<
IS_DBIAS
,
IS_DACT
,
IS_ACT
,
Empty
,
nullptr
>
(
input
,
grad
,
nullptr
,
output
,
detail
::
quantize_helper
<
IS_DBIAS
,
IS_DACT
,
IS_ACT
,
Empty
,
nullptr
>
(
input
,
grad
,
output
,
dbias
,
dbias
,
workspace
,
stream
);
workspace
,
nullptr
,
stream
);
}
}
void
nvte_quantize_noop
(
const
NVTETensor
input
,
NVTETensor
output
,
NVTETensor
noop
,
void
nvte_quantize_noop
(
const
NVTETensor
input
,
NVTETensor
output
,
NVTETensor
noop
,
...
@@ -46,6 +46,18 @@ void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor no
...
@@ -46,6 +46,18 @@ void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor no
NVTE_API_CALL
(
nvte_quantize_noop
);
NVTE_API_CALL
(
nvte_quantize_noop
);
using
namespace
transformer_engine
;
using
namespace
transformer_engine
;
// Create config with noop tensor
QuantizationConfig
quant_config
;
quant_config
.
noop_tensor
=
noop
;
nvte_quantize_v2
(
input
,
output
,
reinterpret_cast
<
NVTEQuantizationConfig
>
(
&
quant_config
),
stream
);
}
void
nvte_quantize_v2
(
const
NVTETensor
input
,
NVTETensor
output
,
const
NVTEQuantizationConfig
quant_config
,
cudaStream_t
stream
)
{
NVTE_API_CALL
(
nvte_quantize_v2
);
using
namespace
transformer_engine
;
constexpr
bool
IS_DBIAS
=
false
;
constexpr
bool
IS_DBIAS
=
false
;
constexpr
bool
IS_DACT
=
false
;
constexpr
bool
IS_DACT
=
false
;
constexpr
bool
IS_ACT
=
false
;
constexpr
bool
IS_ACT
=
false
;
...
@@ -53,8 +65,8 @@ void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor no
...
@@ -53,8 +65,8 @@ void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor no
constexpr
NVTETensor
workspace
=
nullptr
;
constexpr
NVTETensor
workspace
=
nullptr
;
constexpr
const
NVTETensor
grad
=
nullptr
;
constexpr
const
NVTETensor
grad
=
nullptr
;
detail
::
quantize_helper
<
IS_DBIAS
,
IS_DACT
,
IS_ACT
,
Empty
,
nullptr
>
(
input
,
grad
,
noop
,
output
,
detail
::
quantize_helper
<
IS_DBIAS
,
IS_DACT
,
IS_ACT
,
Empty
,
nullptr
>
(
dbias
,
workspace
,
stream
);
input
,
grad
,
output
,
dbias
,
workspace
,
quant_config
,
stream
);
}
}
void
nvte_quantize_dbias
(
const
NVTETensor
input
,
NVTETensor
output
,
NVTETensor
dbias
,
void
nvte_quantize_dbias
(
const
NVTETensor
input
,
NVTETensor
output
,
NVTETensor
dbias
,
...
@@ -68,7 +80,7 @@ void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor d
...
@@ -68,7 +80,7 @@ void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor d
constexpr
const
NVTETensor
activation_input
=
nullptr
;
constexpr
const
NVTETensor
activation_input
=
nullptr
;
detail
::
quantize_helper
<
IS_DBIAS
,
IS_DACT
,
IS_ACT
,
Empty
,
nullptr
>
(
detail
::
quantize_helper
<
IS_DBIAS
,
IS_DACT
,
IS_ACT
,
Empty
,
nullptr
>
(
activation_input
,
input
,
nullptr
,
output
,
dbias
,
workspace
,
stream
);
activation_input
,
input
,
output
,
dbias
,
workspace
,
nullptr
,
stream
);
}
}
void
nvte_quantize_dbias_dgelu
(
const
NVTETensor
input
,
const
NVTETensor
activation_input
,
void
nvte_quantize_dbias_dgelu
(
const
NVTETensor
input
,
const
NVTETensor
activation_input
,
...
@@ -82,7 +94,7 @@ void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activati
...
@@ -82,7 +94,7 @@ void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activati
constexpr
bool
IS_ACT
=
false
;
constexpr
bool
IS_ACT
=
false
;
detail
::
quantize_helper
<
IS_DBIAS
,
IS_DACT
,
IS_ACT
,
Empty
,
dgelu
<
fp32
,
fp32
>>
(
detail
::
quantize_helper
<
IS_DBIAS
,
IS_DACT
,
IS_ACT
,
Empty
,
dgelu
<
fp32
,
fp32
>>
(
activation_input
,
input
,
nullptr
,
output
,
dbias
,
workspace
,
stream
);
activation_input
,
input
,
output
,
dbias
,
workspace
,
nullptr
,
stream
);
}
}
void
nvte_quantize_dbias_dsilu
(
const
NVTETensor
input
,
const
NVTETensor
activation_input
,
void
nvte_quantize_dbias_dsilu
(
const
NVTETensor
input
,
const
NVTETensor
activation_input
,
...
@@ -96,7 +108,7 @@ void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activati
...
@@ -96,7 +108,7 @@ void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activati
constexpr
bool
IS_ACT
=
false
;
constexpr
bool
IS_ACT
=
false
;
detail
::
quantize_helper
<
IS_DBIAS
,
IS_DACT
,
IS_ACT
,
Empty
,
dsilu
<
fp32
,
fp32
>>
(
detail
::
quantize_helper
<
IS_DBIAS
,
IS_DACT
,
IS_ACT
,
Empty
,
dsilu
<
fp32
,
fp32
>>
(
activation_input
,
input
,
nullptr
,
output
,
dbias
,
workspace
,
stream
);
activation_input
,
input
,
output
,
dbias
,
workspace
,
nullptr
,
stream
);
}
}
void
nvte_quantize_dbias_drelu
(
const
NVTETensor
input
,
const
NVTETensor
activation_input
,
void
nvte_quantize_dbias_drelu
(
const
NVTETensor
input
,
const
NVTETensor
activation_input
,
...
@@ -110,7 +122,7 @@ void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activati
...
@@ -110,7 +122,7 @@ void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activati
constexpr
bool
IS_ACT
=
false
;
constexpr
bool
IS_ACT
=
false
;
detail
::
quantize_helper
<
IS_DBIAS
,
IS_DACT
,
IS_ACT
,
Empty
,
drelu
<
fp32
,
fp32
>>
(
detail
::
quantize_helper
<
IS_DBIAS
,
IS_DACT
,
IS_ACT
,
Empty
,
drelu
<
fp32
,
fp32
>>
(
activation_input
,
input
,
nullptr
,
output
,
dbias
,
workspace
,
stream
);
activation_input
,
input
,
output
,
dbias
,
workspace
,
nullptr
,
stream
);
}
}
void
nvte_quantize_dbias_dqgelu
(
const
NVTETensor
input
,
const
NVTETensor
activation_input
,
void
nvte_quantize_dbias_dqgelu
(
const
NVTETensor
input
,
const
NVTETensor
activation_input
,
...
@@ -124,7 +136,7 @@ void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activat
...
@@ -124,7 +136,7 @@ void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activat
constexpr
bool
IS_ACT
=
false
;
constexpr
bool
IS_ACT
=
false
;
detail
::
quantize_helper
<
IS_DBIAS
,
IS_DACT
,
IS_ACT
,
Empty
,
dqgelu
<
fp32
,
fp32
>>
(
detail
::
quantize_helper
<
IS_DBIAS
,
IS_DACT
,
IS_ACT
,
Empty
,
dqgelu
<
fp32
,
fp32
>>
(
activation_input
,
input
,
nullptr
,
output
,
dbias
,
workspace
,
stream
);
activation_input
,
input
,
output
,
dbias
,
workspace
,
nullptr
,
stream
);
}
}
void
nvte_quantize_dbias_dsrelu
(
const
NVTETensor
input
,
const
NVTETensor
activation_input
,
void
nvte_quantize_dbias_dsrelu
(
const
NVTETensor
input
,
const
NVTETensor
activation_input
,
...
@@ -138,7 +150,7 @@ void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activat
...
@@ -138,7 +150,7 @@ void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activat
constexpr
bool
IS_ACT
=
false
;
constexpr
bool
IS_ACT
=
false
;
detail
::
quantize_helper
<
IS_DBIAS
,
IS_DACT
,
IS_ACT
,
Empty
,
dsrelu
<
fp32
,
fp32
>>
(
detail
::
quantize_helper
<
IS_DBIAS
,
IS_DACT
,
IS_ACT
,
Empty
,
dsrelu
<
fp32
,
fp32
>>
(
activation_input
,
input
,
nullptr
,
output
,
dbias
,
workspace
,
stream
);
activation_input
,
input
,
output
,
dbias
,
workspace
,
nullptr
,
stream
);
}
}
void
nvte_dequantize
(
const
NVTETensor
input
,
NVTETensor
output
,
cudaStream_t
stream
)
{
void
nvte_dequantize
(
const
NVTETensor
input
,
NVTETensor
output
,
cudaStream_t
stream
)
{
...
...
transformer_engine/common/util/cast_gated_kernels.cuh
View file @
ab3e5a92
...
@@ -99,8 +99,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
...
@@ -99,8 +99,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
constexpr
size_t
in_mem
=
in_act_mem
+
in_gate_mem
;
constexpr
size_t
in_mem
=
in_act_mem
+
in_gate_mem
;
constexpr
size_t
out_act_mem
=
buff_size_aligned_out
;
constexpr
size_t
out_act_mem
=
buff_size_aligned_out
;
constexpr
size_t
out_gate_mem
=
buff_size_aligned_out
;
constexpr
size_t
out_mem
=
out_act_mem
+
out_gate_mem
;
// const size_t in_transaction_size = grad_mem + in_mem;
// const size_t in_transaction_size = grad_mem + in_mem;
constexpr
size_t
in_transaction_size
=
buff_elems
*
sizeof
(
IType
);
constexpr
size_t
in_transaction_size
=
buff_elems
*
sizeof
(
IType
);
...
@@ -111,7 +109,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
...
@@ -111,7 +109,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
IType
*
in_gate_sh
=
reinterpret_cast
<
IType
*>
(
dshmem
+
grad_mem
+
in_act_mem
);
IType
*
in_gate_sh
=
reinterpret_cast
<
IType
*>
(
dshmem
+
grad_mem
+
in_act_mem
);
OType
*
out_act_sh
=
reinterpret_cast
<
OType
*>
(
dshmem
+
grad_mem
+
in_mem
);
OType
*
out_act_sh
=
reinterpret_cast
<
OType
*>
(
dshmem
+
grad_mem
+
in_mem
);
OType
*
out_gate_sh
=
reinterpret_cast
<
OType
*>
(
dshmem
+
grad_mem
+
in_mem
+
out_act_mem
);
OType
*
out_gate_sh
=
reinterpret_cast
<
OType
*>
(
dshmem
+
grad_mem
+
in_mem
+
out_act_mem
);
// uint64_t *mbar = reinterpret_cast<uint64_t *>(dshmem + grad_mem + in_mem + out_mem);
const
uint64_t
*
TMAP_grad_in
=
reinterpret_cast
<
const
uint64_t
*>
(
&
tensor_map_grad
);
const
uint64_t
*
TMAP_grad_in
=
reinterpret_cast
<
const
uint64_t
*>
(
&
tensor_map_grad
);
const
uint64_t
*
TMAP_in_act
=
reinterpret_cast
<
const
uint64_t
*>
(
&
tensor_map_input_act
);
const
uint64_t
*
TMAP_in_act
=
reinterpret_cast
<
const
uint64_t
*>
(
&
tensor_map_input_act
);
...
@@ -294,7 +291,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
...
@@ -294,7 +291,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
constexpr
bool
USE_ROWWISE_SCALING
=
SCALE_DIM_X
>
1
;
constexpr
bool
USE_ROWWISE_SCALING
=
SCALE_DIM_X
>
1
;
constexpr
bool
USE_COLWISE_SCALING
=
SCALE_DIM_Y
>
1
;
constexpr
bool
USE_COLWISE_SCALING
=
SCALE_DIM_Y
>
1
;
constexpr
bool
COMPUTE_IN_ROWWISE_SECTION
=
!
USE_COLWISE_SCALING
;
constexpr
size_t
SCALES_ROWWISE_PER_CHUNK_Y
=
CHUNK_DIM_Y
;
// 128
constexpr
size_t
SCALES_ROWWISE_PER_CHUNK_Y
=
CHUNK_DIM_Y
;
// 128
constexpr
size_t
SCALES_ROWWISE_PER_CHUNK_X
=
CHUNK_DIM_X
/
SCALE_DIM_X
;
// 4 = 128 / 32
constexpr
size_t
SCALES_ROWWISE_PER_CHUNK_X
=
CHUNK_DIM_X
/
SCALE_DIM_X
;
// 4 = 128 / 32
...
@@ -839,8 +835,6 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
...
@@ -839,8 +835,6 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
size_t
scale_stride_rowwise
=
USE_ROWWISE_SCALING
?
output
->
scale_inv
.
shape
[
1
]
:
1
;
size_t
scale_stride_rowwise
=
USE_ROWWISE_SCALING
?
output
->
scale_inv
.
shape
[
1
]
:
1
;
size_t
scale_stride_colwise
=
USE_COLWISE_SCALING
?
output
->
columnwise_scale_inv
.
shape
[
1
]
:
1
;
size_t
scale_stride_colwise
=
USE_COLWISE_SCALING
?
output
->
columnwise_scale_inv
.
shape
[
1
]
:
1
;
float
*
const
amax_ptr
=
reinterpret_cast
<
float
*>
(
output
->
amax
.
dptr
);
e8m0_t
*
const
scales_rowwise_ptr
=
e8m0_t
*
const
scales_rowwise_ptr
=
USE_ROWWISE_SCALING
?
reinterpret_cast
<
e8m0_t
*>
(
output
->
scale_inv
.
dptr
)
:
nullptr
;
USE_ROWWISE_SCALING
?
reinterpret_cast
<
e8m0_t
*>
(
output
->
scale_inv
.
dptr
)
:
nullptr
;
e8m0_t
*
const
scales_colwise_ptr
=
e8m0_t
*
const
scales_colwise_ptr
=
...
...
transformer_engine/common/util/cast_kernels.cuh
View file @
ab3e5a92
...
@@ -145,7 +145,6 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
...
@@ -145,7 +145,6 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
OType
out_colwise_sh
[
MXFP8_BUFFERS_NUM
][
MXFP8_SHMEM_DIM_Y
][
MXFP8_SHMEM_DIM_X
];
OType
out_colwise_sh
[
MXFP8_BUFFERS_NUM
][
MXFP8_SHMEM_DIM_Y
][
MXFP8_SHMEM_DIM_X
];
constexpr
int
shmem_buff_size
=
sizeof
(
in_sh
)
/
MXFP8_BUFFERS_NUM
;
constexpr
int
shmem_buff_size
=
sizeof
(
in_sh
)
/
MXFP8_BUFFERS_NUM
;
constexpr
int
transaction_size
=
shmem_buff_size
*
(
IS_DACT
?
2
:
1
);
const
bool
is_master_thread
=
(
threadIdx
.
x
==
0
);
const
bool
is_master_thread
=
(
threadIdx
.
x
==
0
);
...
@@ -518,7 +517,6 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK)
...
@@ -518,7 +517,6 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK)
__shared__
alignas
(
128
)
OType
out_sh
[
FP8_BUFFERS_NUM
][
FP8_SHMEM_DIM_Y
][
FP8_SHMEM_DIM_X
];
__shared__
alignas
(
128
)
OType
out_sh
[
FP8_BUFFERS_NUM
][
FP8_SHMEM_DIM_Y
][
FP8_SHMEM_DIM_X
];
constexpr
int
shmem_buff_size
=
sizeof
(
in_sh
)
/
FP8_BUFFERS_NUM
;
constexpr
int
shmem_buff_size
=
sizeof
(
in_sh
)
/
FP8_BUFFERS_NUM
;
constexpr
int
transaction_size
=
shmem_buff_size
*
(
IS_DACT
?
2
:
1
);
const
bool
is_master_thread
=
(
threadIdx
.
x
==
0
);
const
bool
is_master_thread
=
(
threadIdx
.
x
==
0
);
...
@@ -940,7 +938,6 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
...
@@ -940,7 +938,6 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
bool
use_colwise_scaling
=
output
->
has_columnwise_data
();
bool
use_colwise_scaling
=
output
->
has_columnwise_data
();
checkCuDriverContext
(
stream
);
checkCuDriverContext
(
stream
);
NVTE_CHECK
(
input
.
has_data
(),
"Cannot quantize tensor without rowwise data."
);
NVTE_CHECK
(
input
.
has_data
(),
"Cannot quantize tensor without rowwise data."
);
const
auto
&
input_shape
=
input
.
data
.
shape
;
NVTE_CHECK
(
is_fp8_dtype
(
output
->
dtype
()),
"Output must have FP8 type."
);
NVTE_CHECK
(
is_fp8_dtype
(
output
->
dtype
()),
"Output must have FP8 type."
);
if
(
use_rowwise_scaling
)
{
if
(
use_rowwise_scaling
)
{
...
@@ -1250,9 +1247,9 @@ namespace detail {
...
@@ -1250,9 +1247,9 @@ namespace detail {
template
<
bool
IS_DBIAS
,
bool
IS_DACT
,
bool
IS_ACT
,
typename
ParamOP
,
template
<
bool
IS_DBIAS
,
bool
IS_DACT
,
bool
IS_ACT
,
typename
ParamOP
,
float
(
*
OP
)(
float
,
const
ParamOP
&
)>
float
(
*
OP
)(
float
,
const
ParamOP
&
)>
void
quantize_helper
(
const
NVTETensor
input
,
const
NVTETensor
grad
,
const
NVTETensor
noop
,
void
quantize_helper
(
const
NVTETensor
input
,
const
NVTETensor
grad
,
NVTETensor
output
,
NVTETensor
output
,
NVTETensor
dbias
,
NVTETensor
workspace
,
NVTETensor
dbias
,
NVTETensor
workspace
,
cudaStream_t
stream
)
{
const
NVTEQuantizationConfig
quant_config
,
cudaStream_t
stream
)
{
const
Tensor
*
input_tensor
;
const
Tensor
*
input_tensor
;
const
Tensor
*
activation_input_tensor
;
const
Tensor
*
activation_input_tensor
;
if
constexpr
(
IS_DBIAS
||
IS_DACT
)
{
if
constexpr
(
IS_DBIAS
||
IS_DACT
)
{
...
@@ -1267,6 +1264,12 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETe
...
@@ -1267,6 +1264,12 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETe
auto
output_tensor
=
reinterpret_cast
<
Tensor
*>
(
output
);
auto
output_tensor
=
reinterpret_cast
<
Tensor
*>
(
output
);
auto
dbias_tensor
=
reinterpret_cast
<
Tensor
*>
(
dbias
);
auto
dbias_tensor
=
reinterpret_cast
<
Tensor
*>
(
dbias
);
auto
workspace_tensor
=
reinterpret_cast
<
Tensor
*>
(
workspace
);
auto
workspace_tensor
=
reinterpret_cast
<
Tensor
*>
(
workspace
);
const
QuantizationConfig
*
quant_config_cpp
=
reinterpret_cast
<
const
QuantizationConfig
*>
(
quant_config
);
// extract noop tensor from quant_config_cpp if it's not null
const
NVTETensor
noop
=
quant_config_cpp
?
quant_config_cpp
->
noop_tensor
:
nullptr
;
const
auto
noop_tensor
=
noop
!=
nullptr
?
*
(
reinterpret_cast
<
const
Tensor
*>
(
noop
))
:
Tensor
();
const
auto
noop_tensor
=
noop
!=
nullptr
?
*
(
reinterpret_cast
<
const
Tensor
*>
(
noop
))
:
Tensor
();
switch
(
output_tensor
->
scaling_mode
)
{
switch
(
output_tensor
->
scaling_mode
)
{
...
@@ -1294,6 +1297,36 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETe
...
@@ -1294,6 +1297,36 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, const NVTETe
workspace_tensor
,
stream
);
workspace_tensor
,
stream
);
break
;
break
;
}
}
case
NVTE_BLOCK_SCALING_2D
:
{
// TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support.
NVTE_CHECK
((
!
IS_DBIAS
&&
!
IS_DACT
&&
!
IS_ACT
),
"IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_2D"
);
bool
force_pow_2_scales
=
quant_config_cpp
?
quant_config_cpp
->
force_pow_2_scales
:
true
;
float
epsilon
=
quant_config_cpp
?
quant_config_cpp
->
amax_epsilon
:
0.0
f
;
quantize_transpose_square_blockwise
(
input_tensor
->
data
,
output_tensor
->
scale_inv
,
output_tensor
->
columnwise_scale_inv
,
output_tensor
->
data
,
output_tensor
->
columnwise_data
,
epsilon
,
/*return_transpose=*/
output_tensor
->
has_columnwise_data
(),
force_pow_2_scales
,
stream
);
break
;
}
case
NVTE_BLOCK_SCALING_1D
:
{
// TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support.
NVTE_CHECK
((
!
IS_DBIAS
&&
!
IS_DACT
&&
!
IS_ACT
),
"IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_1D"
);
bool
force_pow_2_scales
=
quant_config_cpp
?
quant_config_cpp
->
force_pow_2_scales
:
false
;
float
epsilon
=
quant_config_cpp
?
quant_config_cpp
->
amax_epsilon
:
0.0
f
;
FP8BlockwiseRowwiseOption
rowwise_option
=
output_tensor
->
has_data
()
?
FP8BlockwiseRowwiseOption
::
ROWWISE
:
FP8BlockwiseRowwiseOption
::
NONE
;
FP8BlockwiseColumnwiseOption
columnwise_option
=
output_tensor
->
has_columnwise_data
()
?
FP8BlockwiseColumnwiseOption
::
COLUMNWISE_TRANSPOSE
:
FP8BlockwiseColumnwiseOption
::
NONE
;
quantize_transpose_vector_blockwise
(
input_tensor
->
data
,
output_tensor
->
scale_inv
,
output_tensor
->
columnwise_scale_inv
,
output_tensor
->
data
,
output_tensor
->
columnwise_data
,
epsilon
,
rowwise_option
,
columnwise_option
,
force_pow_2_scales
,
stream
);
break
;
}
default:
default:
NVTE_ERROR
(
"Not implemented scaling mode: "
+
to_string
(
output_tensor
->
scaling_mode
)
+
"."
);
NVTE_ERROR
(
"Not implemented scaling mode: "
+
to_string
(
output_tensor
->
scaling_mode
)
+
"."
);
}
}
...
...
transformer_engine/common/util/dequantize_kernels.cuh
View file @
ab3e5a92
...
@@ -59,7 +59,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
...
@@ -59,7 +59,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
const
size_t
scales_stride
)
{
const
size_t
scales_stride
)
{
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
constexpr
bool
USE_ROWWISE_SCALING
=
SCALE_DIM_X
>
1
;
constexpr
bool
USE_ROWWISE_SCALING
=
SCALE_DIM_X
>
1
;
constexpr
bool
USE_COLWISE_SCALING
=
SCALE_DIM_Y
>
1
;
constexpr
size_t
SCALES_ROWWISE_PER_CHUNK_Y
=
CHUNK_DIM_Y
;
// 128
constexpr
size_t
SCALES_ROWWISE_PER_CHUNK_Y
=
CHUNK_DIM_Y
;
// 128
constexpr
size_t
SCALES_ROWWISE_PER_CHUNK_X
=
CHUNK_DIM_X
/
SCALE_DIM_X
;
// 4 = 128 / 32
constexpr
size_t
SCALES_ROWWISE_PER_CHUNK_X
=
CHUNK_DIM_X
/
SCALE_DIM_X
;
// 4 = 128 / 32
...
@@ -68,8 +67,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
...
@@ -68,8 +67,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
constexpr
size_t
SCALES_COLWISE_PER_CHUNK_X
=
CHUNK_DIM_X
;
// 128
constexpr
size_t
SCALES_COLWISE_PER_CHUNK_X
=
CHUNK_DIM_X
;
// 128
constexpr
size_t
THREADS_PER_SCALE_X_ROWWISE
=
constexpr
size_t
THREADS_PER_SCALE_X_ROWWISE
=
DIVUP
(
SCALE_DIM_X
,
ELEMS_PER_THREAD
);
// 2 = 32 / 16
DIVUP
(
SCALE_DIM_X
,
ELEMS_PER_THREAD
);
// 2 = 32 / 16
constexpr
size_t
SUBWARP_WIDTH
=
THREADS_PER_SCALE_X_ROWWISE
;
// 2
const
int
chunk_offset_Y
=
blockIdx
.
y
*
CHUNK_DIM_Y
;
const
int
chunk_offset_Y
=
blockIdx
.
y
*
CHUNK_DIM_Y
;
const
int
chunk_offset_X
=
blockIdx
.
x
*
CHUNK_DIM_X
;
const
int
chunk_offset_X
=
blockIdx
.
x
*
CHUNK_DIM_X
;
...
@@ -357,6 +355,7 @@ void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream)
...
@@ -357,6 +355,7 @@ void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream)
NVTE_ERROR
(
"MXFP8 Dequantization is NOT supported by architectures < 10.0"
);
NVTE_ERROR
(
"MXFP8 Dequantization is NOT supported by architectures < 10.0"
);
}
}
}
else
{
}
else
{
// TODO(kwyss): Move dequantization code from torch to C++ for NVTE_BLOCK_SCALING
NVTE_ERROR
(
"Not implemented scaling mode: "
+
to_string
(
input
.
scaling_mode
)
+
"."
);
NVTE_ERROR
(
"Not implemented scaling mode: "
+
to_string
(
input
.
scaling_mode
)
+
"."
);
}
}
}
}
...
...
transformer_engine/common/util/ptx.cuh
View file @
ab3e5a92
...
@@ -83,6 +83,29 @@ __device__ __forceinline__ void cp_async_bulk_tensor_2d_global_to_shared(
...
@@ -83,6 +83,29 @@ __device__ __forceinline__ void cp_async_bulk_tensor_2d_global_to_shared(
:
"memory"
);
:
"memory"
);
}
}
__device__
__forceinline__
bool
mbarrier_try_wait_parity
(
uint32_t
mbar_ptr
,
const
uint32_t
parity
)
{
uint32_t
waitComplete
;
asm
volatile
(
"{
\n\t
.reg .pred P_OUT;
\n\t
"
"mbarrier.try_wait.parity.shared::cta.b64 P_OUT, [%1], %2;
\n\t
"
"selp.b32 %0, 1, 0, P_OUT;
\n
"
"}"
:
"=r"
(
waitComplete
)
:
"r"
(
mbar_ptr
),
"r"
(
parity
)
:
"memory"
);
return
static_cast
<
bool
>
(
waitComplete
);
}
__device__
__forceinline__
void
mbarrier_wait_parity
(
uint64_t
*
mbar
,
const
uint32_t
parity
)
{
uint32_t
mbar_ptr
=
__cvta_generic_to_shared
(
mbar
);
while
(
!
mbarrier_try_wait_parity
(
mbar_ptr
,
parity
))
{
}
}
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor
// shared::cta -> global
// shared::cta -> global
__device__
__forceinline__
void
cp_async_bulk_tensor_1d_shared_to_global
(
uint64_t
*
dst_global_ptr
,
__device__
__forceinline__
void
cp_async_bulk_tensor_1d_shared_to_global
(
uint64_t
*
dst_global_ptr
,
...
@@ -106,30 +129,6 @@ __device__ __forceinline__ void cp_async_bulk_tensor_2d_shared_to_global(
...
@@ -106,30 +129,6 @@ __device__ __forceinline__ void cp_async_bulk_tensor_2d_shared_to_global(
:
"memory"
);
:
"memory"
);
}
}
__device__
__forceinline__
bool
mbarrier_try_wait_parity
(
uint32_t
mbar_ptr
,
const
uint32_t
parity
)
{
uint32_t
waitComplete
;
asm
volatile
(
"{
\n\t
.reg .pred P_OUT;
\n\t
"
"mbarrier.try_wait.parity.shared::cta.b64 P_OUT, [%1], %2;
\n\t
"
"selp.b32 %0, 1, 0, P_OUT;
\n
"
"}"
:
"=r"
(
waitComplete
)
:
"r"
(
mbar_ptr
),
"r"
(
parity
)
:
"memory"
);
return
static_cast
<
bool
>
(
waitComplete
);
}
__device__
__forceinline__
void
mbarrier_wait_parity
(
uint64_t
*
mbar
,
const
uint32_t
parity
)
{
uint32_t
mbar_ptr
=
__cvta_generic_to_shared
(
mbar
);
while
(
!
mbarrier_try_wait_parity
(
mbar_ptr
,
parity
))
{
}
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-commit-group
__device__
__forceinline__
void
cp_async_bulk_commit_group
()
{
asm
volatile
(
"cp.async.bulk.commit_group;"
);
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group
__device__
__forceinline__
void
cp_async_bulk_wait_group
()
{
__device__
__forceinline__
void
cp_async_bulk_wait_group
()
{
asm
volatile
(
"cp.async.bulk.wait_group 0;"
);
asm
volatile
(
"cp.async.bulk.wait_group 0;"
);
...
@@ -158,13 +157,19 @@ __device__ __forceinline__ void cp_async_bulk_wait_group_read<4>() {
...
@@ -158,13 +157,19 @@ __device__ __forceinline__ void cp_async_bulk_wait_group_read<4>() {
asm
volatile
(
"cp.async.bulk.wait_group.read 4;"
);
asm
volatile
(
"cp.async.bulk.wait_group.read 4;"
);
}
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-commit-group
__device__
__forceinline__
void
cp_async_bulk_commit_group
()
{
asm
volatile
(
"cp.async.bulk.commit_group;"
);
}
// Proxy fence (bi-directional):
// Proxy fence (bi-directional):
__device__
__forceinline__
void
fence_proxy_async
()
{
asm
volatile
(
"fence.proxy.async;"
);
}
__device__
__forceinline__
void
fence_proxy_async
()
{
asm
volatile
(
"fence.proxy.async;"
);
}
__device__
__forceinline__
void
fence_proxy_async_shared_cta
()
{
__device__
__forceinline__
void
fence_proxy_async_shared_cta
()
{
asm
volatile
(
"fence.proxy.async.shared::cta;"
);
asm
volatile
(
"fence.proxy.async.shared::cta;"
);
}
}
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >=
10
00)
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >=
9
00)
}
// namespace ptx
}
// namespace ptx
...
...
transformer_engine/debug/__init__.py
0 → 100644
View file @
ab3e5a92
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Top level package for numerical debugging."""
try:
    from . import pytorch
    from .pytorch.debug_state import set_weight_tensor_tp_group_reduce
except ImportError:
    # Debug tooling depends on optional packages (e.g. torch / nvdlfw_inspect).
    # Keep the top-level `transformer_engine.debug` package importable when
    # those dependencies are missing; the debug features simply stay unavailable.
    pass
transformer_engine/debug/pytorch/__init__.py
0 → 100644
View file @
ab3e5a92
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
transformer_engine/debug/pytorch/debug_quantization.py
0 → 100644
View file @
ab3e5a92
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
This file contains DebugQuantizer and DebugQuantizedTensor objects,
which are wrappers over Quantizer and QuantizedTensor.
These wrappers add logic related to debugging, using the nvdlfw_inspect package.
"""
from
__future__
import
annotations
from
typing
import
Optional
,
Tuple
,
Iterable
,
Union
import
torch
import
transformer_engine_torch
as
tex
from
transformer_engine.pytorch.tensor.quantized_tensor
import
(
QuantizedTensor
,
Quantizer
,
prepare_for_saving
,
restore_from_saved
,
)
# Alias for convenient access to raw ATen ops.
aten = torch.ops.aten

# Maps a tensor name to the pair [rowwise_gemm_name, columnwise_gemm_name]:
# the GEMM that consumes the tensor in rowwise layout and the GEMM that
# consumes it in columnwise layout. None means the tensor has no second GEMM.
_tensor_to_gemm_names_map = {
    "weight": ["fprop", "dgrad"],
    "activation": ["fprop", "wgrad"],
    "output": ["fprop", None],
    "gradient": ["dgrad", "wgrad"],
    "wgrad": ["wgrad", None],
    "dgrad": ["dgrad", None],
}

# Plan labels describing what DebugQuantizer will do for a tensor/GEMM pair.
API_CALL_MODIFY = "modify_tensor()"
STANDARD_FP8_QUANTIZE = "FP8 Quantize"
HIGH_PRECISION = "High Precision"
class DebugQuantizer(Quantizer):
    """
    DebugQuantizer is a Quantizer object used for debugging with nvidia-dlframework-inspect.
    It allows adding custom calls inside the quantization process - which enables modifying tensors
    or gathering tensor stats.
    """

    def __init__(
        self,
        layer_name: str,
        tensor_name: str,
        parent_quantizer: Optional[Quantizer],
        tp_group: torch.distributed.ProcessGroup,
    ):
        # Imported lazily so that the debug package stays importable when
        # nvdlfw_inspect is not installed.
        import nvdlfw_inspect.api as debug_api

        super().__init__(rowwise=True, columnwise=True)
        self.layer_name = layer_name
        self.tensor_name = tensor_name
        self.parent_quantizer = parent_quantizer
        self.tp_group = tp_group
        # used in inspect_tensor calls
        self.iteration = debug_api.DEBUG_MANAGER._trainer_iteration_count
        self.rowwise_gemm_name, self.columnwise_gemm_name = _tensor_to_gemm_names_map[tensor_name]

        # The values of the inspect_tensor_enabled, inspect_tensor_postquantize_enabled,
        # rowwise_tensor_plan, and columnwise_tensor_plan are computed.
        # These fields indicate the path where API calls will be inserted.
        #
        # inspect_tensor*_enabled are bool fields,
        # indicating whether some feature will need to run inspect_tensor_* calls.
        #
        # *_tensor_plan are one of [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE, HIGH_PRECISION]
        # determining what will happen when the quantizer is used for that tensor.
        self.output_tensor = tensor_name in ["output", "wgrad", "dgrad"]
        if self.output_tensor:
            self.inspect_tensor_enabled, self.rowwise_tensor_plan = (
                self.get_plans_for_output_tensors()
            )
        else:
            (
                self.inspect_tensor_enabled,
                self.inspect_tensor_postquantize_enabled_rowwise,
                self.inspect_tensor_postquantize_enabled_columnwise,
            ) = self.get_enabled_look_at_tensors()
            self.rowwise_tensor_plan, self.columnwise_tensor_plan = self.get_tensors_plan()

        self.log_messages_about_plans()

    def get_plans_for_output_tensors(self) -> Tuple[bool, str]:
        """
        Returns tuple (inspect_tensor_enabled: bool, plan: str). Plan is one of the
        API_CALL_MODIFY or HIGH_PRECISION, because debug quantizer does not support
        gemm output in FP8.
        """
        import nvdlfw_inspect.api as debug_api

        inspect_tensor_enabled = debug_api.transformer_engine.inspect_tensor_enabled(
            layer_name=self.layer_name, tensor_name=self.tensor_name, iteration=self.iteration
        )
        modify_enabled = debug_api.transformer_engine.modify_tensor_enabled(
            layer_name=self.layer_name,
            gemm=self.rowwise_gemm_name,
            tensor_name=self.tensor_name,
            iteration=self.iteration,
        )
        plan = API_CALL_MODIFY if modify_enabled else HIGH_PRECISION
        return inspect_tensor_enabled, plan

    def get_enabled_look_at_tensors(self):
        """
        Returns a tuple of booleans determining which functions look_at_tensor_*(...) should be called.
        """
        import nvdlfw_inspect.api as debug_api

        inspect_tensor_enabled = debug_api.transformer_engine.inspect_tensor_enabled(
            layer_name=self.layer_name, tensor_name=self.tensor_name, iteration=self.iteration
        )
        inspect_tensor_postquantize_enabled_rowwise = (
            debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
                layer_name=self.layer_name,
                tensor_name=self.tensor_name,
                iteration=self.iteration,
                gemm=self.rowwise_gemm_name,
            )
        )
        inspect_tensor_postquantize_enabled_columnwise = (
            debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
                layer_name=self.layer_name,
                tensor_name=self.tensor_name,
                iteration=self.iteration,
                gemm=self.columnwise_gemm_name,
            )
        )
        return (
            inspect_tensor_enabled,
            inspect_tensor_postquantize_enabled_rowwise,
            inspect_tensor_postquantize_enabled_columnwise,
        )

    def get_tensors_plan(self):
        """
        Returns (rowwise_plan, columnwise_plan). Each element of the tuple is one of
        API_CALL_MODIFY, STANDARD_FP8_QUANTIZE, or HIGH_PRECISION, indicating the behavior
        of this quantizer with respect to these tensors.
        """
        import nvdlfw_inspect.api as debug_api

        rowwise_plan = None
        columnwise_plan = None

        # Priority: modify_tensor() > FP8 quantize (only with a parent quantizer
        # and fp8_gemm_enabled) > high precision fallback.
        modify_rowwise = debug_api.transformer_engine.modify_tensor_enabled(
            layer_name=self.layer_name,
            gemm=self.rowwise_gemm_name,
            tensor_name=self.tensor_name,
            iteration=self.iteration,
        )
        if modify_rowwise:
            rowwise_plan = API_CALL_MODIFY
        else:
            if self.parent_quantizer is not None:
                fp8_quantize = debug_api.transformer_engine.fp8_gemm_enabled(
                    layer_name=self.layer_name,
                    gemm=self.rowwise_gemm_name,
                    iteration=self.iteration,
                )
                if fp8_quantize:
                    rowwise_plan = STANDARD_FP8_QUANTIZE
        if rowwise_plan is None:
            rowwise_plan = HIGH_PRECISION

        # Same decision tree for the columnwise GEMM, if the tensor has one.
        if self.columnwise_gemm_name is not None:
            modify_columnwise = debug_api.transformer_engine.modify_tensor_enabled(
                layer_name=self.layer_name,
                gemm=self.columnwise_gemm_name,
                tensor_name=self.tensor_name,
                iteration=self.iteration,
            )
            if modify_columnwise:
                columnwise_plan = API_CALL_MODIFY
            else:
                if self.parent_quantizer is not None:
                    fp8_quantize = debug_api.transformer_engine.fp8_gemm_enabled(
                        layer_name=self.layer_name,
                        gemm=self.columnwise_gemm_name,
                        iteration=self.iteration,
                    )
                    if fp8_quantize:
                        columnwise_plan = STANDARD_FP8_QUANTIZE
        if columnwise_plan is None:
            columnwise_plan = HIGH_PRECISION

        return rowwise_plan, columnwise_plan

    def log_messages_about_plans(self):
        """
        Logs the messages about the plans for each of the tensors.
        """
        import nvdlfw_inspect.api as debug_api

        debug_api.log_message(
            f"Tensor: {self.tensor_name}, gemm {self.rowwise_gemm_name} -"
            f" {self.rowwise_tensor_plan}",
            layer_name=self.layer_name,
            extra_cachable_args=(self.rowwise_gemm_name, self.tensor_name),
        )
        debug_api.log_message(
            f"Tensor: {self.tensor_name}, gemm {self.columnwise_gemm_name} -"
            f" {self.columnwise_tensor_plan}",
            layer_name=self.layer_name,
            extra_cachable_args=(self.columnwise_gemm_name, self.tensor_name),
        )

    def _call_inspect_tensor_api(self, tensor, rowwise_gemm_tensor=None, columnwise_gemm_tensor=None):
        # Dispatches inspect_tensor / inspect_tensor_postquantize calls for the
        # tensors that this quantizer's plans actually produced.
        import nvdlfw_inspect.api as debug_api

        args = {
            "layer_name": self.layer_name,
            "tensor": tensor,
            "tensor_name": self.tensor_name,
            "iteration": debug_api.DEBUG_MANAGER._trainer_iteration_count,
            "tp_group": self.tp_group,
        }
        if tensor is not None and self.inspect_tensor_enabled:
            debug_api.transformer_engine.inspect_tensor(**args)
        # Output tensors have no post-quantize stage.
        if self.output_tensor:
            return
        if (
            self.rowwise_tensor_plan in [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE]
            and self.inspect_tensor_postquantize_enabled_rowwise
        ):
            args["tensor"] = rowwise_gemm_tensor
            args["rowwise"] = True
            debug_api.transformer_engine.inspect_tensor_postquantize(**args)
        if (
            self.columnwise_tensor_plan in [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE]
            and self.inspect_tensor_postquantize_enabled_columnwise
        ):
            args["tensor"] = columnwise_gemm_tensor
            args["rowwise"] = False
            debug_api.transformer_engine.inspect_tensor_postquantize(**args)

    def quantize(
        self,
        tensor: torch.Tensor,
        *,
        out: Optional[Union[torch.Tensor, DebugQuantizedTensor]] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """Returns DebugQuantizedTensor object."""
        import nvdlfw_inspect.api as debug_api

        assert not self.output_tensor
        if out is not None:
            # NOTE(review): `self` (the quantizer) is passed as the destination
            # argument of update_quantized(), whose second parameter is documented
            # as a QuantizedTensor; `out` looks like the intended argument — verify.
            return self.update_quantized(tensor, self)

        # 1. If there is fp8 quantization in at least one of the gemms,
        # the quantization using the self.parent_quantizer is performed.
        # rowwise gemm corresponds to the rowwise_usage in fp8, similarly with columnwise
        rowwise_gemm_quantize = (
            self.rowwise_usage and self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE
        )
        columnwise_gemm_quantize = (
            self.columnwise_usage and self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE
        )
        if columnwise_gemm_quantize and not rowwise_gemm_quantize:
            rowwise_gemm_quantize = True  # only columnwise quantization not implemented

        rowwise_gemm_tensor, columnwise_gemm_tensor = None, None

        if STANDARD_FP8_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]:
            self.parent_quantizer.set_usage(
                rowwise=True,
                columnwise=columnwise_gemm_quantize,  # columnwise usage only is not supported
            )
            quantized_tensor = self.parent_quantizer(tensor)
            # if both rowwise_tensor_plan and columnwise_tensor_plan need to be in fp8,
            # one tensor with columnwise=True and rowwise=True is computed
            # and both rowwise_tensor_plan and columnwise_tensor_plan point to it.
            if self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE:
                rowwise_gemm_tensor = quantized_tensor
            if self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE:
                columnwise_gemm_tensor = quantized_tensor

        # 2. modify_tensor() is called, if it is used.
        if self.columnwise_tensor_plan == API_CALL_MODIFY:
            columnwise_gemm_tensor = debug_api.transformer_engine.modify_tensor(
                layer_name=self.layer_name,
                tensor_name=self.tensor_name,
                gemm=self.columnwise_gemm_name,
                tensor=tensor,
                default_quantizer=self.parent_quantizer,
                iteration=self.iteration,
                dtype=dtype,
            )
            if columnwise_gemm_tensor.dtype != dtype:
                raise ValueError("Dtype does not match the output of the modify_tensor call")
        if self.rowwise_tensor_plan == API_CALL_MODIFY:
            rowwise_gemm_tensor = debug_api.transformer_engine.modify_tensor(
                layer_name=self.layer_name,
                tensor_name=self.tensor_name,
                gemm=self.rowwise_gemm_name,
                tensor=tensor,
                default_quantizer=self.parent_quantizer,
                iteration=self.iteration,
                dtype=dtype,
            )
            if rowwise_gemm_tensor.dtype != dtype:
                raise ValueError("Dtype does not match the output of the modify_tensor call")

        # 3. If some tensors still are not defined we use high precision tensor.
        if self.rowwise_tensor_plan == HIGH_PRECISION:
            rowwise_gemm_tensor = tensor.to(dtype)
        if self.columnwise_tensor_plan == HIGH_PRECISION:
            columnwise_gemm_tensor = tensor.to(dtype)

        self._call_inspect_tensor_api(tensor, rowwise_gemm_tensor, columnwise_gemm_tensor)

        # sometimes we may want to return simple tensor with only rowwise_gemm
        if self.tensor_name in ["wgrad", "dgrad", "output"]:
            return rowwise_gemm_tensor

        return DebugQuantizedTensor(
            rowwise_gemm_tensor=rowwise_gemm_tensor,
            columnwise_gemm_tensor=columnwise_gemm_tensor,
            quantizer=self,
            layer_name=self.layer_name,
            tensor_name=self.tensor_name,
        )

    def process_gemm_output(self, tensor: torch.Tensor):
        """This call is invoked after the gemm to inspect and modify the output tensor."""
        import nvdlfw_inspect.api as debug_api

        assert self.parent_quantizer is None, "FP8 output is not supported for debug=True."
        assert self.output_tensor
        tensor_to_gemm = {"output": "fprop", "wgrad": "wgrad", "dgrad": "dgrad"}
        if self.rowwise_tensor_plan == API_CALL_MODIFY:
            tensor = debug_api.transformer_engine.modify_tensor(
                layer_name=self.layer_name,
                gemm=tensor_to_gemm[self.tensor_name],
                tensor_name=self.tensor_name,
                tensor=tensor,
                iteration=self.iteration,
                default_quantizer=self.parent_quantizer,
            )
        self._call_inspect_tensor_api(tensor)
        return tensor

    def make_empty(
        self,
        shape: Iterable[int],
        *,
        dtype: torch.dtype = torch.float32,
        device: Optional[torch.device] = None,
    ) -> QuantizedTensor:
        """Override make_empty() from Quantizer class."""
        # Delegate to the wrapped quantizer when there is one; otherwise a plain
        # high-precision torch tensor is sufficient.
        if self.parent_quantizer is not None:
            return self.parent_quantizer.make_empty(shape, dtype=dtype, device=device)
        return torch.empty(shape, dtype=dtype, device=device)

    def calibrate(self, tensor: torch.Tensor):
        """Calibration override, should not be invoked."""
        raise RuntimeError("[NVTORCH-INSPECT ERROR] Calibration with debug is not supported")

    def update_quantized(
        self,
        src: torch.Tensor,
        dst: QuantizedTensor,
        *,
        noop_flag: Optional[torch.Tensor] = None,
    ) -> QuantizedTensor:
        """Update quantized tensor - used in weight caching."""
        import nvdlfw_inspect.api as debug_api

        assert noop_flag is None, "CUDA Graphs are not supported with debug=True!"

        # When both plans share a single FP8 tensor, quantizing the rowwise one
        # already updates the columnwise one; track that to avoid a second pass.
        updated_rowwise_gemm = False
        if self.parent_quantizer is not None:
            if (
                dst.rowwise_gemm_tensor is not None
                and self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE
            ):
                if hasattr(dst.rowwise_gemm_tensor, "quantize_"):
                    dst.rowwise_gemm_tensor.quantize_(src, noop_flag=None)
                else:
                    tex.quantize(src, self.parent_quantizer, dst.rowwise_gemm_tensor, None)
                updated_rowwise_gemm = True
            if (
                dst.columnwise_gemm_tensor is not None
                and self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE
                and not updated_rowwise_gemm
            ):
                if hasattr(dst.columnwise_gemm_tensor, "quantize_"):
                    dst.columnwise_gemm_tensor.quantize_(src, noop_flag=None)
                else:
                    tex.quantize(src, self.parent_quantizer, dst.columnwise_gemm_tensor, None)

        if self.columnwise_tensor_plan == API_CALL_MODIFY:
            out = debug_api.transformer_engine.modify_tensor(
                layer_name=self.layer_name,
                tensor_name=self.tensor_name,
                gemm=self.columnwise_gemm_name,
                tensor=src,
                default_quantizer=self.parent_quantizer,
                out=dst.columnwise_gemm_tensor,
                iteration=self.iteration,
            )
            assert out is None, (
                "API call debug_api.transformer_engine.modify_tensor with out != None should"
                " return None"
            )
        if self.rowwise_tensor_plan == API_CALL_MODIFY:
            debug_api.transformer_engine.modify_tensor(
                layer_name=self.layer_name,
                tensor_name=self.tensor_name,
                gemm=self.rowwise_gemm_name,
                tensor=src,
                default_quantizer=self.parent_quantizer,
                out=dst.rowwise_gemm_tensor,
                iteration=self.iteration,
            )
        if self.rowwise_tensor_plan == HIGH_PRECISION:
            dst.rowwise_gemm_tensor.copy_(src)
        if self.columnwise_tensor_plan == HIGH_PRECISION:
            # if they are the same tensor object, it is sufficient to update one
            if dst.columnwise_gemm_tensor is not dst.rowwise_gemm_tensor:
                dst.columnwise_gemm_tensor.copy_(src)

        self._call_inspect_tensor_api(src, dst.rowwise_gemm_tensor, dst.columnwise_gemm_tensor)

    def any_feature_enabled(self) -> bool:
        """Returns bool if there is at least one API call enabled."""
        if self.output_tensor:
            return self.inspect_tensor_enabled or self.rowwise_tensor_plan == API_CALL_MODIFY
        if (
            self.inspect_tensor_enabled
            or self.inspect_tensor_postquantize_enabled_rowwise
            or self.inspect_tensor_postquantize_enabled_columnwise
            or self.rowwise_tensor_plan == API_CALL_MODIFY
            or self.columnwise_tensor_plan == API_CALL_MODIFY
        ):
            return True
        # A parent quantizer combined with a non-FP8 plan means the debug
        # quantizer changes behavior relative to plain FP8 quantization.
        if self.parent_quantizer is not None:
            if self.rowwise_tensor_plan != STANDARD_FP8_QUANTIZE:
                return True
            if self.columnwise_tensor_plan != STANDARD_FP8_QUANTIZE:
                return True
        return False
class
DebugQuantizedTensor
:
"""
Class containing quantized tensors after debug. Depending on configuration
it can contain one or two different objects. These objects can be accessed by the method
get_tensor().
"""
    def __init__(
        self,
        rowwise_gemm_tensor,
        columnwise_gemm_tensor,
        quantizer,
        layer_name=None,
        tensor_name=None,
    ):
        # Tensor consumed by the rowwise GEMM; may be the same object as the
        # columnwise one when a single quantized tensor serves both GEMMs.
        self.rowwise_gemm_tensor = rowwise_gemm_tensor
        # Tensor consumed by the columnwise GEMM (or None/shared — see above).
        self.columnwise_gemm_tensor = columnwise_gemm_tensor
        # The DebugQuantizer that produced this object; used by quantize_().
        self.quantizer = quantizer
        self._layer_name = layer_name
        self._tensor_name = tensor_name
    def prepare_for_saving(self):
        """Prepare for saving method override.

        Collapses the two GEMM tensors into one entry when they are the same
        object, then delegates to the module-level prepare_for_saving().
        Returns (tensor_list, self) for use with torch autograd save_for_backward.
        """
        self.tensors_to_save = (
            [self.rowwise_gemm_tensor, self.columnwise_gemm_tensor]
            if self.rowwise_gemm_tensor is not self.columnwise_gemm_tensor
            else [self.rowwise_gemm_tensor]
        )
        tensor_list, tensor_objects_list = prepare_for_saving(*self.tensors_to_save)
        # Keep the tensor *objects* so restore_from_saved() can re-attach data.
        self.tensors_to_save = tensor_objects_list
        # pylint: disable=unbalanced-tuple-unpacking
        return tensor_list, self
    def restore_from_saved(self, tensors):
        """Restore from saved method override.

        Re-attaches raw saved tensors to the tensor objects recorded in
        prepare_for_saving(); returns the remaining (unconsumed) saved tensors.
        """
        tensor_objects_list, saved_tensors = restore_from_saved(
            self.tensors_to_save,
            tensors,
            return_saved_tensors=True,
        )
        if len(tensor_objects_list) == 2:
            # pylint: disable=unbalanced-tuple-unpacking
            self.rowwise_gemm_tensor, self.columnwise_gemm_tensor = tensor_objects_list
        else:
            # Only one tensor was saved: rowwise and columnwise share the object.
            self.rowwise_gemm_tensor = tensor_objects_list[0]
            self.columnwise_gemm_tensor = self.rowwise_gemm_tensor
        return saved_tensors
    def quantize_(self, tensor, *, noop_flag=None):
        """quantize_ method override — updates this object in place from ``tensor``
        by delegating to the owning DebugQuantizer's update_quantized()."""
        assert noop_flag is None, "CUDA Graphs are not supported with debug=True!"
        self.quantizer.update_quantized(tensor, self)
def
dequantize
(
self
,
*
,
dtype
=
None
):
""" " dequantize method override"""
if
dtype
is
None
:
dtype
=
self
.
rowwise_gemm_tensor
.
dtype
return
self
.
rowwise_gemm_tensor
.
dequantize
().
to
(
dtype
)
def
get_tensor
(
self
,
transpose
:
bool
):
"""Is used in the python gemm() to get tensor or transpose of the tensor."""
return
self
.
rowwise_gemm_tensor
if
not
transpose
else
self
.
columnwise_gemm_tensor
def
size
(
self
):
"""Size of the tensor."""
return
self
.
rowwise_gemm_tensor
.
size
()
def
update_usage
(
self
,
rowwise_usage
:
bool
,
columnwise_usage
:
bool
):
"""Update usage of the tensor."""
transformer_engine/debug/pytorch/debug_state.py
0 → 100644
View file @
ab3e5a92
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Managing the state of all the debugged layers.
"""
import
sys
class TEDebugState:
    """
    A class to manage the state of debug layers.
    """

    # Counter handed out by get_layer_count() for auto-named layers.
    layer_count = 1
    # Registry of layers that have been initialized.
    layers_initialized = {}
    # Whether weight-tensor stats are reduced across the TP group.
    weight_tensor_tp_group_reduce = True
    # Tri-state: None = not yet determined, True/False after initialize().
    debug_enabled = None

    @classmethod
    def initialize(cls):
        """
        If debug_api module is initialized, then sets cls.debug_enabled to True.
        """
        # Only look at debug_api if nvdlfw_inspect was actually imported;
        # this keeps TE usable without the debug package installed.
        if "nvdlfw_inspect" in sys.modules:
            import nvdlfw_inspect.api as debug_api

            if cls.debug_enabled is False and debug_api.DEBUG_MANAGER is not None:
                # This method is invoked when initializing TE modules.
                # If this error is thrown, it means that some TE module had been initialized before
                # debug_api was initialized, and now a new TE module is being initialized.
                # This is likely to be a bug.
                raise RuntimeError(
                    "[nv_dlfw_inspect] nv_dlfw_inspect module should be initialized before"
                    " initialization of the first TE module"
                )
            cls.debug_enabled = debug_api.DEBUG_MANAGER is not None

    @classmethod
    def _reset(cls):
        """Resets layer count and stats buffers."""
        from ..features.utils.stats_buffer import STATS_BUFFERS

        STATS_BUFFERS.reset()
        cls.debug_enabled = None
        # NOTE(review): cls.layer_count is not reset here although the docstring
        # says "Resets layer count" — confirm whether that is intentional.
        cls.layers_initialized.clear()

    @classmethod
    def get_layer_count(cls):
        """
        Layer counter is used when layer names are not provided to modules by the user.
        """
        lc = cls.layer_count
        cls.layer_count += 1
        return lc

    @classmethod
    def set_weight_tensor_tp_group_reduce(cls, enabled):
        """Sets weight tensor reduction mode."""
        cls.weight_tensor_tp_group_reduce = enabled
def set_weight_tensor_tp_group_reduce(enabled):
    """Sets weight tensor reduction mode.

    Module-level convenience wrapper around
    TEDebugState.set_weight_tensor_tp_group_reduce(); re-exported from
    transformer_engine.debug.
    """
    TEDebugState.set_weight_tensor_tp_group_reduce(enabled)
transformer_engine/debug/pytorch/utils.py
0 → 100644
View file @
ab3e5a92
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Utils functions for the debug module."""
def any_feature_enabled(quantizers):
    """Returns True if at least one API call is made from DebugQuantizer."""
    for quantizer in quantizers:
        if quantizer.any_feature_enabled():
            return True
    return False
transformer_engine/jax/__init__.py
View file @
ab3e5a92
...
@@ -83,7 +83,8 @@ _load_library()
...
@@ -83,7 +83,8 @@ _load_library()
from
.
import
flax
from
.
import
flax
from
.
import
quantize
from
.
import
quantize
from
.quantize
import
fp8_autocast
from
.quantize
import
fp8_autocast
,
update_collections
,
get_delayed_scaling
from
.quantize
import
NVTE_FP8_COLLECTION_NAME
from
.sharding
import
MeshResource
from
.sharding
import
MeshResource
from
.sharding
import
MajorShardingType
,
ShardingResource
,
ShardingType
from
.sharding
import
MajorShardingType
,
ShardingResource
,
ShardingType
...
@@ -101,11 +102,14 @@ ShardingResource = deprecate_wrapper(
...
@@ -101,11 +102,14 @@ ShardingResource = deprecate_wrapper(
)
)
__all__
=
[
__all__
=
[
"NVTE_FP8_COLLECTION_NAME"
,
"fp8_autocast"
,
"fp8_autocast"
,
"update_collections"
,
"get_delayed_scaling"
,
"MeshResource"
,
"MeshResource"
,
"MajorShardingType"
,
"MajorShardingType"
,
"ShardingResource"
,
"ShardingResource"
,
"ShardingType"
,
"ShardingType"
,
"flax"
,
"flax"
,
"
praxis
"
,
"
quantize
"
,
]
]
transformer_engine/jax/activation.py
View file @
ab3e5a92
...
@@ -91,7 +91,6 @@ def _activation_bwd_rule(activation_type, ctx, g):
...
@@ -91,7 +91,6 @@ def _activation_bwd_rule(activation_type, ctx, g):
(
x
,
_
)
=
ctx
(
x
,
_
)
=
ctx
assert
x
.
dtype
==
g
.
dtype
assert
x
.
dtype
==
g
.
dtype
dx
=
tex
.
dact_lu
(
g
,
x
,
activation_type
)
dx
=
tex
.
dact_lu
(
g
,
x
,
activation_type
)
dx
=
jnp
.
reshape
(
dx
,
x
.
shape
)
return
(
dx
,
None
)
return
(
dx
,
None
)
...
...
transformer_engine/jax/cpp_extensions/activation.py
View file @
ab3e5a92
...
@@ -10,6 +10,7 @@ from packaging import version
...
@@ -10,6 +10,7 @@ from packaging import version
import
jax
import
jax
import
jax.numpy
as
jnp
import
jax.numpy
as
jnp
from
jax
import
dtypes
from
jax
import
dtypes
from
jax.experimental.custom_partitioning
import
SdyShardingRule
from
jax.sharding
import
PartitionSpec
from
jax.sharding
import
PartitionSpec
import
transformer_engine_jax
import
transformer_engine_jax
...
@@ -26,12 +27,12 @@ from .misc import (
...
@@ -26,12 +27,12 @@ from .misc import (
should_apply_1x_fused_dbias_war_for_arch_l_100
,
should_apply_1x_fused_dbias_war_for_arch_l_100
,
NamedSharding
,
NamedSharding
,
)
)
from
.quantization
import
_jax_
quantize_dbias
,
_jax_
dbias
,
quantize_dbias
from
.quantization
import
_jax_dbias
,
_
quantize_dbias
_impl
from
..sharding
import
all_reduce_max_along_all_axes_except_PP
,
all_reduce_sum_along_dp_fsdp
from
..sharding
import
all_reduce_max_along_all_axes_except_PP
,
all_reduce_sum_along_dp_fsdp
from
..quantize
import
ScaledTensor
,
ScaledTensorFactory
from
..quantize
import
ScaledTensor
,
ScaledTensorFactory
from
..quantize
import
(
from
..quantize
import
(
Quantizer
,
Quantizer
,
Quantize
Axis
,
Quantize
Layout
,
DelayedScaleQuantizer
,
DelayedScaleQuantizer
,
ScalingMode
,
ScalingMode
,
)
)
...
@@ -110,41 +111,31 @@ class ActLuPrimitive(BasePrimitive):
...
@@ -110,41 +111,31 @@ class ActLuPrimitive(BasePrimitive):
"""
"""
te_act_lu_p abstract
te_act_lu_p abstract
"""
"""
del
act_enum
,
act_len
,
scale_shapes
del
act_enum
,
scale_shapes
dtype
=
dtypes
.
canonicalize_dtype
(
x_aval
.
dtype
)
dtype
=
dtypes
.
canonicalize_dtype
(
x_aval
.
dtype
)
assert
dtype
in
[
jnp
.
float32
,
jnp
.
float16
,
jnp
.
bfloat16
]
assert
dtype
in
[
jnp
.
float32
,
jnp
.
float16
,
jnp
.
bfloat16
]
assert
scale_aval
is
None
or
scale_aval
.
dtype
==
jnp
.
float32
assert
scale_aval
is
None
or
scale_aval
.
dtype
==
jnp
.
float32
assert
x_aval
.
shape
[
-
2
]
==
act_len
,
(
out_shape
=
(
"activation input should be replicated by act_len in the -2 axis, got input shape"
*
x_aval
.
shape
[:
-
2
],
f
"
{
x_aval
.
shape
}
and act_len
{
act_len
}
"
1
,
x_aval
.
shape
[
-
1
],
)
)
out_shape
=
(
*
x_aval
.
shape
[:
-
2
],
x_aval
.
shape
[
-
1
])
# Exclude act dim
out_aval
=
x_aval
.
update
(
shape
=
out_shape
,
dtype
=
out_dtype
)
out_aval
=
x_aval
.
update
(
shape
=
out_shape
,
dtype
=
out_dtype
)
updated_amax_aval
=
jax
.
core
.
ShapedArray
(
shape
=
(
1
,),
dtype
=
jnp
.
float32
)
updated_amax_aval
=
jax
.
core
.
ShapedArray
(
shape
=
(
1
,),
dtype
=
jnp
.
float32
)
rowwise_scale_inv_shape
,
colwise_scale_inv_shape
=
ScalingMode
(
rowwise_scale_inv_shape
,
colwise_scale_inv_shape
=
ScalingMode
(
scaling_mode
scaling_mode
).
get_scale_shape_2x
(
out_shape
[:
-
2
]
+
(
out_shape
[
-
1
],),
is_padded
=
not
is_outer
)
).
get_scale_shape_2x
(
out_shape
,
is_padded
=
not
is_outer
,
flatten_axis
=-
1
)
if
not
is_2x
:
if
len
(
rowwise_scale_inv_shape
)
>
1
:
out_shape
=
(
1
,)
rowwise_scale_inv_shape
=
(
colwise_scale_inv_shape
=
(
1
,)
rowwise_scale_inv_shape
[:
-
1
]
+
(
1
,)
+
rowwise_scale_inv_shape
[
-
1
:]
colwise_out_aval
=
jax
.
core
.
ShapedArray
(
shape
=
out_shape
,
dtype
=
out_dtype
)
)
if
len
(
colwise_scale_inv_shape
)
>
1
:
colwise_scale_inv_shape
=
(
colwise_scale_inv_shape
[:
-
1
]
+
(
1
,)
+
colwise_scale_inv_shape
[
-
1
:]
)
scale_inv_aval
=
jax
.
core
.
ShapedArray
(
shape
=
rowwise_scale_inv_shape
,
dtype
=
scale_dtype
)
scale_inv_aval
=
jax
.
core
.
ShapedArray
(
shape
=
rowwise_scale_inv_shape
,
dtype
=
scale_dtype
)
colwise_scale_inv_aval
=
jax
.
core
.
ShapedArray
(
colwise_out_aval
=
jax
.
core
.
ShapedArray
(
shape
=
(
1
,),
dtype
=
out_dtype
)
shape
=
colwise_scale_inv_shape
,
dtype
=
scale_dtype
colwise_scale_inv_aval
=
jax
.
core
.
ShapedArray
(
shape
=
(
1
,),
dtype
=
scale_dtype
)
)
if
is_2x
:
colwise_out_aval
=
jax
.
core
.
ShapedArray
(
shape
=
out_shape
,
dtype
=
out_dtype
)
colwise_scale_inv_aval
=
jax
.
core
.
ShapedArray
(
shape
=
colwise_scale_inv_shape
,
dtype
=
scale_dtype
)
return
out_aval
,
colwise_out_aval
,
scale_inv_aval
,
colwise_scale_inv_aval
,
updated_amax_aval
return
out_aval
,
colwise_out_aval
,
scale_inv_aval
,
colwise_scale_inv_aval
,
updated_amax_aval
...
@@ -172,7 +163,7 @@ class ActLuPrimitive(BasePrimitive):
...
@@ -172,7 +163,7 @@ class ActLuPrimitive(BasePrimitive):
assert
scale_aval
is
None
or
scale_aval
.
dtype
==
jnp
.
float32
assert
scale_aval
is
None
or
scale_aval
.
dtype
==
jnp
.
float32
out
=
ffi
.
ffi_lowering
(
ActLuPrimitive
.
name
)(
out
=
ffi
.
ffi_lowering
(
ActLuPrimitive
.
name
)(
ctx
,
x
,
scale
,
act_enum
=
act_enum
,
scaling_mode
=
scaling_mode
,
is_2x
=
is_2x
ctx
,
x
,
scale
,
act_enum
=
act_enum
,
scaling_mode
=
scaling_mode
.
value
,
is_2x
=
is_2x
)
)
return
out
return
out
...
@@ -211,15 +202,8 @@ class ActLuPrimitive(BasePrimitive):
...
@@ -211,15 +202,8 @@ class ActLuPrimitive(BasePrimitive):
)
)
rowwise_scale_inv_shape
,
colwise_scale_inv_shape
=
ScalingMode
(
rowwise_scale_inv_shape
,
colwise_scale_inv_shape
=
ScalingMode
(
scaling_mode
scaling_mode
).
get_scale_shape_2x
(
out
.
shape
[:
-
2
]
+
(
out
.
shape
[
-
1
],),
is_padded
=
False
)
).
get_scale_shape_2x
(
out
.
shape
,
is_padded
=
False
,
flatten_axis
=-
1
)
if
scaling_mode
==
ScalingMode
.
NVTE_MXFP8_1D_SCALING
.
value
:
# Slice out padding for MXFP8, noop for DelayedScaling
rowwise_scale_inv_shape
=
(
rowwise_scale_inv_shape
[:
-
1
]
+
(
1
,)
+
rowwise_scale_inv_shape
[
-
1
:]
)
if
is_2x
:
colwise_scale_inv_shape
=
(
colwise_scale_inv_shape
[:
-
1
]
+
(
1
,)
+
colwise_scale_inv_shape
[
-
1
:]
)
scale_inv
=
jax
.
lax
.
slice
(
scale_inv
=
jax
.
lax
.
slice
(
scale_inv
,
[
0
]
*
len
(
rowwise_scale_inv_shape
),
rowwise_scale_inv_shape
scale_inv
,
[
0
]
*
len
(
rowwise_scale_inv_shape
),
rowwise_scale_inv_shape
)
)
...
@@ -227,6 +211,7 @@ class ActLuPrimitive(BasePrimitive):
...
@@ -227,6 +211,7 @@ class ActLuPrimitive(BasePrimitive):
colwise_scale_inv
=
jax
.
lax
.
slice
(
colwise_scale_inv
=
jax
.
lax
.
slice
(
colwise_scale_inv
,
[
0
]
*
len
(
colwise_scale_inv_shape
),
colwise_scale_inv_shape
colwise_scale_inv
,
[
0
]
*
len
(
colwise_scale_inv_shape
),
colwise_scale_inv_shape
)
)
return
out
,
colwise_out
,
scale_inv
,
colwise_scale_inv
,
updated_amax
return
out
,
colwise_out
,
scale_inv
,
colwise_scale_inv
,
updated_amax
@
staticmethod
@
staticmethod
...
@@ -292,11 +277,14 @@ class ActLuPrimitive(BasePrimitive):
...
@@ -292,11 +277,14 @@ class ActLuPrimitive(BasePrimitive):
is_outer
,
is_outer
,
)
# Unused.
)
# Unused.
x_spec
=
get_padded_spec
(
arg_infos
[
0
])
x_spec
=
get_padded_spec
(
arg_infos
[
0
])
out_spec
=
(
*
x_spec
[:
-
2
],
None
,
x_spec
[
-
2
])
scale_spec
=
get_padded_spec
(
arg_infos
[
1
])
out_spec
=
(
*
x_spec
[:
-
2
],
x_spec
[
-
1
])
out_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
out_spec
),
desc
=
"ActLuPrimitive.out"
)
out_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
out_spec
),
desc
=
"ActLuPrimitive.out"
)
if
is_2x
:
if
is_2x
:
if
scaling_mode
==
ScalingMode
.
NVTE_
DELAYED_TENSOR_SCALING
.
value
:
if
scaling_mode
==
ScalingMode
.
DELAYED_TENSOR_SCALING
.
value
:
colwise_out_spec
=
multidim_transpose
(
out_spec
)
colwise_out_spec
=
multidim_transpose
(
out_spec
,
transpose_axis
=-
1
)
else
:
else
:
colwise_out_spec
=
out_spec
colwise_out_spec
=
out_spec
else
:
else
:
...
@@ -304,18 +292,24 @@ class ActLuPrimitive(BasePrimitive):
...
@@ -304,18 +292,24 @@ class ActLuPrimitive(BasePrimitive):
colwise_out_sharding
=
NamedSharding
(
colwise_out_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
colwise_out_spec
),
desc
=
"ActLuPrimitive.colwise_out"
mesh
,
PartitionSpec
(
*
colwise_out_spec
),
desc
=
"ActLuPrimitive.colwise_out"
)
)
scale_inv_spec
=
amax_spec
=
colwise_scale_inv_spec
=
(
None
,)
if
scaling_mode
==
ScalingMode
.
DELAYED_TENSOR_SCALING
.
value
:
scale_inv_spec
=
amax_spec
=
scale_spec
elif
scaling_mode
==
ScalingMode
.
MXFP8_1D_SCALING
.
value
:
scale_inv_spec
=
out_spec
if
is_2x
:
colwise_scale_inv_spec
=
scale_inv_spec
scale_inv_sharding
=
NamedSharding
(
scale_inv_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
get_padded_spec
(
arg_infos
[
1
])
),
desc
=
"ActLuPrimitive.scale_inv"
mesh
,
PartitionSpec
(
*
scale_inv_spec
),
desc
=
"ActLuPrimitive.scale_inv"
)
)
amax_sharding
=
scale_inv_sharding
.
duplicate_with_new_description
(
"ActLuPrimitive.amax"
)
amax_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
amax_spec
),
desc
=
"ActLuPrimitive.amax"
)
colwise_scale_inv_sharding
=
NamedSharding
(
if
scaling_mode
==
ScalingMode
.
NVTE_MXFP8_1D_SCALING
.
value
:
mesh
,
PartitionSpec
(
*
colwise_scale_inv_spec
),
desc
=
"ActLuPrimitive.colwise_scale_inv"
scale_inv_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
out_spec
),
desc
=
"ActLuPrimitive.scale_inv"
)
colwise_scale_inv_sharding
=
scale_inv_sharding
.
duplicate_with_new_description
(
"ActLuPrimitive.colwise_scale_inv"
)
)
return
(
return
(
out_sharding
,
out_sharding
,
colwise_out_sharding
,
colwise_out_sharding
,
...
@@ -340,14 +334,14 @@ class ActLuPrimitive(BasePrimitive):
...
@@ -340,14 +334,14 @@ class ActLuPrimitive(BasePrimitive):
):
):
del
result_infos
,
is_outer
# Unused.
del
result_infos
,
is_outer
# Unused.
x_spec
=
get_padded_spec
(
arg_infos
[
0
])
x_spec
=
get_padded_spec
(
arg_infos
[
0
])
out_spec
=
(
*
x_spec
[:
-
1
],
x_spec
[
-
1
])
scale_spec
=
get_padded_spec
(
arg_infos
[
1
])
if
act_len
==
2
and
x_spec
[
-
1
]
is
None
:
# Ensure last axis is partitioned and not the gating axis
out_spec
=
(
*
x_spec
[:
-
2
],
x_spec
[
-
1
])
out_spec
=
(
*
x_spec
[:
-
2
],
None
,
x_spec
[
-
2
])
out_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
out_spec
),
desc
=
"ActLuPrimitive.out"
)
out_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
out_spec
),
desc
=
"ActLuPrimitive.out"
)
if
is_2x
:
if
is_2x
:
if
scaling_mode
==
ScalingMode
.
NVTE_
DELAYED_TENSOR_SCALING
.
value
:
if
scaling_mode
==
ScalingMode
.
DELAYED_TENSOR_SCALING
.
value
:
colwise_out_spec
=
multidim_transpose
(
out_spec
)
colwise_out_spec
=
multidim_transpose
(
out_spec
,
transpose_axis
=-
1
)
else
:
else
:
colwise_out_spec
=
out_spec
colwise_out_spec
=
out_spec
else
:
else
:
...
@@ -355,21 +349,25 @@ class ActLuPrimitive(BasePrimitive):
...
@@ -355,21 +349,25 @@ class ActLuPrimitive(BasePrimitive):
colwise_out_sharding
=
NamedSharding
(
colwise_out_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
colwise_out_spec
),
desc
=
"ActLuPrimitive.colwise_out"
mesh
,
PartitionSpec
(
*
colwise_out_spec
),
desc
=
"ActLuPrimitive.colwise_out"
)
)
scale_inv_spec
=
amax_spec
=
colwise_scale_inv_spec
=
(
None
,)
if
scaling_mode
==
ScalingMode
.
DELAYED_TENSOR_SCALING
.
value
:
scale_inv_spec
=
amax_spec
=
scale_spec
elif
scaling_mode
==
ScalingMode
.
MXFP8_1D_SCALING
.
value
:
scale_inv_spec
=
out_spec
if
is_2x
:
colwise_scale_inv_spec
=
scale_inv_spec
scale_inv_sharding
=
NamedSharding
(
scale_inv_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
get_padded_spec
(
arg_infos
[
1
])
),
desc
=
"ActLuPrimitive.scale_inv"
mesh
,
PartitionSpec
(
*
scale_inv_spec
),
desc
=
"ActLuPrimitive.scale_inv"
)
)
amax_sharding
=
scale_inv_sharding
.
duplicate_with_new_description
(
"ActLuPrimitive.amax"
)
amax_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
amax_spec
),
desc
=
"ActLuPrimitive.amax"
)
colwise_scale_inv_sharding
=
NamedSharding
(
if
scaling_mode
==
ScalingMode
.
NVTE_MXFP8_1D_SCALING
.
value
:
mesh
,
PartitionSpec
(
*
colwise_scale_inv_spec
),
desc
=
"ActLuPrimitive.colwise_scale_inv"
scale_inv_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
out_spec
),
desc
=
"ActLuPrimitive.scale_inv"
)
colwise_scale_inv_sharding
=
scale_inv_sharding
.
duplicate_with_new_description
(
"ActLuPrimitive.colwise_scale_inv"
)
)
arg_shardings
=
list
(
arg_i
.
sharding
for
arg_i
in
arg_infos
)
arg_shardings
[
0
]
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
out_spec
))
arg_shardings
=
tuple
(
arg_i
.
sharding
for
arg_i
in
arg_infos
)
arg_shardings
=
tuple
(
arg_shardings
)
out_shardings
=
(
out_shardings
=
(
out_sharding
,
out_sharding
,
colwise_out_sharding
,
colwise_out_sharding
,
...
@@ -394,7 +392,7 @@ class ActLuPrimitive(BasePrimitive):
...
@@ -394,7 +392,7 @@ class ActLuPrimitive(BasePrimitive):
)
)
)
)
if
scaling_mode
==
ScalingMode
.
NVTE_
DELAYED_TENSOR_SCALING
.
value
:
if
scaling_mode
==
ScalingMode
.
DELAYED_TENSOR_SCALING
.
value
:
global_updated_amax
=
all_reduce_max_along_all_axes_except_PP
(
local_amax
,
mesh
)
global_updated_amax
=
all_reduce_max_along_all_axes_except_PP
(
local_amax
,
mesh
)
else
:
else
:
global_updated_amax
=
local_amax
global_updated_amax
=
local_amax
...
@@ -409,10 +407,59 @@ class ActLuPrimitive(BasePrimitive):
...
@@ -409,10 +407,59 @@ class ActLuPrimitive(BasePrimitive):
return
mesh
,
sharded_impl
,
out_shardings
,
arg_shardings
return
mesh
,
sharded_impl
,
out_shardings
,
arg_shardings
@
staticmethod
def
shardy_sharding_rule
(
out_dtype
,
act_enum
,
act_len
,
scaling_mode
,
is_2x
,
scale_dtype
,
scale_shapes
,
is_outer
,
mesh
,
value_types
,
result_types
,
):
del
out_dtype
,
act_enum
,
act_len
,
scale_dtype
,
scale_shapes
,
is_outer
,
mesh
,
result_types
x_rank
=
len
(
value_types
[
0
].
shape
)
scale_rules
=
ScalingMode
(
scaling_mode
).
get_shardy_sharding_rules
(
x_rank
-
1
,
unique_var
=
"i"
,
flatten_axis
=-
2
)
x_axes
=
scale_rules
.
input_spec
+
(
f
"x
{
x_rank
-
1
}
"
,)
out
=
(
*
x_axes
[:
-
2
],
x_axes
[
-
1
])
scale_inv
=
scale_rules
.
rowwise_rule
colwise_scale_inv
=
scale_rules
.
colwise_rule
if
is_2x
:
if
scaling_mode
==
ScalingMode
.
DELAYED_TENSOR_SCALING
.
value
:
colwise_out
=
tuple
(
multidim_transpose
(
x_axes
,
static_axis_boundary
=-
1
,
transpose_axis
=-
2
)
)
else
:
colwise_out
=
out
else
:
colwise_out
=
(
"j"
,)
colwise_scale_inv
=
(
"k"
,)
# amax is always a unit tensor.
amax
=
(
"l"
,)
return
SdyShardingRule
(
(
x_axes
,
"…1"
,
),
(
out
,
colwise_out
,
scale_inv
,
colwise_scale_inv
,
amax
),
**
scale_rules
.
factor_sizes
,
)
register_primitive
(
ActLuPrimitive
)
register_primitive
(
ActLuPrimitive
)
# TODO(Jeremy): replace is_2x with q_layout
class
DActLuDBiasQuantizePrimitive
(
BasePrimitive
):
class
DActLuDBiasQuantizePrimitive
(
BasePrimitive
):
"""
"""
DActLu DBias Cast Transpose Primitive
DActLu DBias Cast Transpose Primitive
...
@@ -445,42 +492,41 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
...
@@ -445,42 +492,41 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
te_dact_dbias_quantize_p abstract
te_dact_dbias_quantize_p abstract
"""
"""
del
act_enum
,
scale_shapes
del
act_enum
,
scale_shapes
dtype
=
dtypes
.
canonicalize_dtype
(
dz_aval
.
dtype
)
dz_dtype
=
dtypes
.
canonicalize_dtype
(
dz_aval
.
dtype
)
assert
dtype
in
[
jnp
.
float32
,
jnp
.
float16
,
jnp
.
bfloat16
]
assert
dz_dtype
in
[
jnp
.
float32
,
jnp
.
float16
,
jnp
.
bfloat16
]
assert
x_aval
.
dtype
==
dtype
assert
x_aval
.
dtype
==
dz_dtype
assert
x_aval
.
shape
[
-
2
]
==
act_len
,
(
"activation input should be replicated by act_len in the -2 axis, got input shape"
f
"
{
x_aval
.
shape
}
and act_len
{
act_len
}
"
)
assert
scale_aval
.
dtype
==
jnp
.
float32
assert
scale_aval
.
dtype
==
jnp
.
float32
ir_hidden_size
=
dz_aval
.
shape
[
-
1
]
ir_hidden_size
=
dz_aval
.
shape
[
-
1
]
gi_hidden_size
=
x_aval
.
shape
[
-
1
]
gi_hidden_size
=
act_len
*
x_aval
.
shape
[
-
1
]
assert
act_len
*
ir_hidden_size
==
gi_hidden_size
assert
act_len
*
ir_hidden_size
==
gi_hidden_size
out_shape
=
x_aval
.
shape
out_shape
=
x_aval
.
shape
out_aval
=
x_aval
.
update
(
shape
=
out_shape
,
dtype
=
out_dtype
)
out_aval
=
x_aval
.
update
(
shape
=
out_shape
,
dtype
=
out_dtype
)
updated_amax_aval
=
jax
.
core
.
ShapedArray
(
shape
=
(
1
,),
dtype
=
jnp
.
float32
)
updated_amax_aval
=
jax
.
core
.
ShapedArray
(
shape
=
(
1
,),
dtype
=
jnp
.
float32
)
rowwise_scale_inv_shape
,
colwise_scale_inv_shape
=
ScalingMode
(
rowwise_scale_inv_shape
,
colwise_scale_inv_shape
=
ScalingMode
(
scaling_mode
scaling_mode
).
get_scale_shape_2x
(
x_aval
.
shape
,
is_padded
=
not
is_outer
)
).
get_scale_shape_2x
(
x_aval
.
shape
,
is_padded
=
not
is_outer
,
flatten_axis
=-
2
)
scale_inv_aval
=
jax
.
core
.
ShapedArray
(
shape
=
rowwise_scale_inv_shape
,
dtype
=
scale_dtype
)
colwise_out_aval
=
jax
.
core
.
ShapedArray
(
shape
=
(
1
,),
dtype
=
jnp
.
float32
)
colwise_scale_inv_aval
=
jax
.
core
.
ShapedArray
(
shape
=
(
1
,),
dtype
=
scale_dtype
)
dbias_aval
=
jax
.
core
.
ShapedArray
(
shape
=
(
1
,),
dtype
=
jnp
.
float32
)
wkspace_aval
=
jax
.
core
.
ShapedArray
(
shape
=
(
1
,),
dtype
=
jnp
.
float32
)
if
is_2x
:
if
is_2x
:
# Don't transpose output for MXFP8
if
scaling_mode
==
ScalingMode
.
DELAYED_TENSOR_SCALING
.
value
:
if
scaling_mode
==
ScalingMode
.
NVTE_MXFP8_1D_SCALING
.
value
:
colwise_out_shape
=
multidim_transpose
(
out_shape
,
transpose_axis
=-
2
)
t_shape
=
out_shape
else
:
else
:
t_shape
=
multidim_transpose
(
out_shape
)
colwise_out_shape
=
out_shape
colwise_out_aval
=
x_aval
.
update
(
shape
=
t_shape
,
dtype
=
out_dtype
)
else
:
colwise_scale_inv_aval
=
jax
.
core
.
ShapedArray
(
colwise_out_shape
=
(
1
,)
shape
=
colwise_scale_inv_shape
,
dtype
=
scale_dtype
colwise_scale_inv_shape
=
(
1
,)
)
colwise_out_aval
=
jax
.
core
.
ShapedArray
(
shape
=
colwise_out_shape
,
dtype
=
out_dtype
)
scale_inv_aval
=
jax
.
core
.
ShapedArray
(
shape
=
rowwise_scale_inv_shape
,
dtype
=
scale_dtype
)
colwise_scale_inv_aval
=
jax
.
core
.
ShapedArray
(
shape
=
colwise_scale_inv_shape
,
dtype
=
scale_dtype
)
if
is_dbias
:
if
is_dbias
:
dbias_shape
=
gi_hidden_size
dbias_shape
=
(
act_len
,
ir_hidden_size
)
dbias_aval
=
x_aval
.
update
(
shape
=
dbias_shape
,
dtype
=
dtype
)
(
wkspace_info
,)
=
transformer_engine_jax
.
get_dact_dbias_quantize_workspace_sizes
(
(
wkspace_info
,)
=
transformer_engine_jax
.
get_dact_dbias_quantize_workspace_sizes
(
x_aval
.
size
//
gi_hidden_size
,
x_aval
.
size
//
gi_hidden_size
,
gi_hidden_size
,
gi_hidden_size
,
...
@@ -489,9 +535,14 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
...
@@ -489,9 +535,14 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
scaling_mode
,
scaling_mode
,
is_2x
,
is_2x
,
)
)
wkspace_aval
=
x_aval
.
update
(
wkspace_shape
=
wkspace_info
[
0
]
shape
=
wkspace_info
[
0
],
dtype
=
te_dtype_to_jax_dtype
(
wkspace_info
[
1
])
wkspace_dtype
=
te_dtype_to_jax_dtype
(
wkspace_info
[
1
])
)
else
:
dbias_shape
=
(
1
,)
wkspace_shape
=
(
1
,)
wkspace_dtype
=
jnp
.
float32
dbias_aval
=
jax
.
core
.
ShapedArray
(
shape
=
dbias_shape
,
dtype
=
dz_dtype
)
wkspace_aval
=
jax
.
core
.
ShapedArray
(
shape
=
wkspace_shape
,
dtype
=
wkspace_dtype
)
return
(
return
(
out_aval
,
out_aval
,
...
@@ -543,7 +594,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
...
@@ -543,7 +594,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
dz
,
dz
,
x
,
x
,
scale
,
scale
,
scaling_mode
=
scaling_mode
,
scaling_mode
=
scaling_mode
.
value
,
is_2x
=
is_2x
,
is_2x
=
is_2x
,
is_dbias
=
is_dbias
,
is_dbias
=
is_dbias
,
act_enum
=
int
(
act_enum
),
act_enum
=
int
(
act_enum
),
...
@@ -587,23 +638,16 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
...
@@ -587,23 +638,16 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
)
)
rowwise_scale_inv_shape
,
colwise_scale_inv_shape
=
ScalingMode
(
rowwise_scale_inv_shape
,
colwise_scale_inv_shape
=
ScalingMode
(
scaling_mode
scaling_mode
).
get_scale_shape_2x
(
x
.
shape
,
is_padded
=
False
)
).
get_scale_shape_2x
(
x
.
shape
,
is_padded
=
False
,
flatten_axis
=-
2
)
if
scaling_mode
==
ScalingMode
.
NVTE_MXFP8_1D_SCALING
.
value
:
# Slice out padding for MXFP8, noop for DelayedScaling
scale_inv
=
jax
.
lax
.
slice
(
scale_inv
=
jax
.
lax
.
slice
(
scale_inv
,
[
0
]
*
len
(
rowwise_scale_inv_shape
),
rowwise_scale_inv_shape
scale_inv
,
[
0
]
*
len
(
rowwise_scale_inv_shape
),
rowwise_scale_inv_shape
)
if
is_2x
:
colwise_scale_inv
=
jax
.
lax
.
slice
(
colwise_scale_inv
,
[
0
]
*
len
(
colwise_scale_inv_shape
),
colwise_scale_inv_shape
)
)
if
is_2x
:
return
out
,
colwise_out
,
scale_inv
,
colwise_scale_inv
,
updated_amax
,
dbias
colwise_scale_inv
=
jax
.
lax
.
slice
(
colwise_scale_inv
,
[
0
]
*
len
(
colwise_scale_inv_shape
),
colwise_scale_inv_shape
)
return
(
out
,
colwise_out
,
scale_inv
,
colwise_scale_inv
,
updated_amax
,
dbias
,
)
# Exclude wkspace
@
staticmethod
@
staticmethod
def
batcher
(
def
batcher
(
...
@@ -670,15 +714,16 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
...
@@ -670,15 +714,16 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
result_infos
,
result_infos
,
):
):
del
out_dtype
,
result_infos
,
act_enum
del
out_dtype
,
result_infos
,
act_enum
del
scale_dtype
,
scale_shapes
,
is_dbias
,
act_len
,
is_outer
del
scale_dtype
,
scale_shapes
,
act_len
,
is_outer
x_spec
=
get_padded_spec
(
arg_infos
[
1
])
x_spec
=
get_padded_spec
(
arg_infos
[
1
])
scale_spec
=
get_padded_spec
(
arg_infos
[
2
])
out_sharding
=
NamedSharding
(
out_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
x_spec
),
desc
=
"DActLuDBiasQuantizePrimitive.out"
mesh
,
PartitionSpec
(
*
x_spec
),
desc
=
"DActLuDBiasQuantizePrimitive.out"
)
)
if
is_2x
:
if
is_2x
:
if
scaling_mode
==
ScalingMode
.
NVTE_
DELAYED_TENSOR_SCALING
.
value
:
if
scaling_mode
==
ScalingMode
.
DELAYED_TENSOR_SCALING
.
value
:
colwise_x_spec
=
multidim_transpose
(
x_spec
)
colwise_x_spec
=
multidim_transpose
(
x_spec
,
transpose_axis
=-
2
)
else
:
else
:
colwise_x_spec
=
x_spec
colwise_x_spec
=
x_spec
else
:
else
:
...
@@ -687,23 +732,32 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
...
@@ -687,23 +732,32 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
mesh
,
PartitionSpec
(
*
colwise_x_spec
),
desc
=
"DActLuDBiasQuantizePrimitive.colwise_out"
mesh
,
PartitionSpec
(
*
colwise_x_spec
),
desc
=
"DActLuDBiasQuantizePrimitive.colwise_out"
)
)
dbias_shaprding
=
NamedSharding
(
dbias_spec
=
x_spec
[
-
2
:]
if
is_dbias
else
(
None
,)
dbias_sharding
=
NamedSharding
(
mesh
,
mesh
,
PartitionSpec
(
x
_spec
[
-
1
]
),
PartitionSpec
(
*
dbias
_spec
),
desc
=
"DActLuDBiasQuantizePrimitive.dbias"
,
desc
=
"DActLuDBiasQuantizePrimitive.dbias"
,
)
)
scale_inv_spec
=
amax_spec
=
colwise_scale_inv_spec
=
(
None
,)
if
scaling_mode
==
ScalingMode
.
DELAYED_TENSOR_SCALING
.
value
:
scale_inv_spec
=
amax_spec
=
scale_spec
elif
scaling_mode
==
ScalingMode
.
MXFP8_1D_SCALING
.
value
:
scale_inv_spec
=
x_spec
if
is_2x
:
colwise_scale_inv_spec
=
scale_inv_spec
scale_inv_sharding
=
NamedSharding
(
scale_inv_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
None
),
desc
=
"DActLuDBiasQuantizePrimitive.scale_inv"
mesh
,
PartitionSpec
(
*
scale_inv_spec
),
desc
=
"DActLuDBiasQuantizePrimitive.scale_inv"
)
)
amax_sharding
=
NamedSharding
(
amax_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
None
),
desc
=
"DActLuDBiasQuantizePrimitive.amax"
mesh
,
PartitionSpec
(
*
amax_spec
),
desc
=
"DActLuDBiasQuantizePrimitive.amax"
)
)
if
scaling_mode
==
ScalingMode
.
NVTE_MXFP8_1D_SCALING
.
value
:
colwise_scale_inv_sharding
=
NamedSharding
(
scale_inv_sharding
=
NamedSharding
(
mesh
,
mesh
,
PartitionSpec
(
*
x_spec
),
desc
=
"DActLuDBiasQuantizePrimitive.scale_inv"
PartitionSpec
(
*
colwise_scale_inv_spec
),
)
desc
=
"DActLuDBiasQuantizePrimitive.colwise_scale_inv"
,
colwise_scale_inv_sharding
=
scale_inv_sharding
.
duplicate_with_new_description
(
"DActLuDBiasQuantizePrimitive.colwise_scale_inv"
)
)
return
(
return
(
out_sharding
,
out_sharding
,
...
@@ -711,7 +765,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
...
@@ -711,7 +765,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
scale_inv_sharding
,
scale_inv_sharding
,
colwise_scale_inv_sharding
,
colwise_scale_inv_sharding
,
amax_sharding
,
amax_sharding
,
dbias_sha
p
rding
,
dbias_sharding
,
)
)
@
staticmethod
@
staticmethod
...
@@ -731,10 +785,15 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
...
@@ -731,10 +785,15 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
):
):
del
result_infos
,
is_outer
del
result_infos
,
is_outer
x_spec
=
get_padded_spec
(
arg_infos
[
1
])
x_spec
=
get_padded_spec
(
arg_infos
[
1
])
out_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
x_spec
),
desc
=
"out"
)
scale_spec
=
get_padded_spec
(
arg_infos
[
2
])
out_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
x_spec
),
desc
=
"DActLuDBiasQuantizePrimitive.out"
)
if
is_2x
:
if
is_2x
:
if
scaling_mode
==
ScalingMode
.
NVTE_
DELAYED_TENSOR_SCALING
.
value
:
if
scaling_mode
==
ScalingMode
.
DELAYED_TENSOR_SCALING
.
value
:
colwise_x_spec
=
multidim_transpose
(
x_spec
)
colwise_x_spec
=
multidim_transpose
(
x_spec
,
transpose_axis
=-
2
)
else
:
else
:
colwise_x_spec
=
x_spec
colwise_x_spec
=
x_spec
else
:
else
:
...
@@ -743,38 +802,39 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
...
@@ -743,38 +802,39 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
mesh
,
PartitionSpec
(
*
colwise_x_spec
),
desc
=
"DActLuDBiasQuantizePrimitive.colwise_out"
mesh
,
PartitionSpec
(
*
colwise_x_spec
),
desc
=
"DActLuDBiasQuantizePrimitive.colwise_out"
)
)
dbias_shaprding
=
NamedSharding
(
dbias_spec
=
x_spec
[
-
2
:]
if
is_dbias
else
(
None
,)
dbias_sharding
=
NamedSharding
(
mesh
,
mesh
,
PartitionSpec
(
x
_spec
[
-
1
]
),
PartitionSpec
(
*
dbias
_spec
),
desc
=
"DActLuDBiasQuantizePrimitive.dbias"
,
desc
=
"DActLuDBiasQuantizePrimitive.dbias"
,
)
)
scale_inv_spec
=
amax_spec
=
colwise_scale_inv_spec
=
(
None
,)
if
scaling_mode
==
ScalingMode
.
DELAYED_TENSOR_SCALING
.
value
:
scale_inv_spec
=
amax_spec
=
scale_spec
elif
scaling_mode
==
ScalingMode
.
MXFP8_1D_SCALING
.
value
:
scale_inv_spec
=
x_spec
if
is_2x
:
colwise_scale_inv_spec
=
scale_inv_spec
scale_inv_sharding
=
NamedSharding
(
scale_inv_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
None
),
desc
=
"DActLuDBiasQuantizePrimitive.scale_inv"
mesh
,
PartitionSpec
(
*
scale_inv_spec
),
desc
=
"ActLuPrimitive.scale_inv"
)
amax_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
None
),
desc
=
"DActLuDBiasQuantizePrimitive.amax"
)
)
if
scaling_mode
==
ScalingMode
.
NVTE_MXFP8_1D_SCALING
.
value
:
amax_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
amax_spec
),
desc
=
"ActLuPrimitive.amax"
)
scale_inv_sharding
=
NamedSharding
(
colwise_scale_inv_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
x_spec
),
desc
=
"DActLuDBiasQuantizePrimitive.scale_inv"
mesh
,
PartitionSpec
(
*
colwise_scale_inv_spec
),
desc
=
"ActLuPrimitive.colwise_scale_inv"
)
colwise_scale_inv_sharding
=
scale_inv_sharding
.
duplicate_with_new_description
(
"DActLuDBiasQuantizePrimitive.colwise_scale_inv"
)
)
arg_shardings
=
tuple
(
arg_i
.
sharding
for
arg_i
in
arg_infos
)
arg_shardings
=
tuple
(
arg_i
.
sharding
for
arg_i
in
arg_infos
)
arg_shardings
=
(
arg_shardings
[
1
],
arg_shardings
[
1
],
*
arg_shardings
[
2
:],
)
# dz and x are the same
out_shardings
=
(
out_shardings
=
(
out_sharding
,
out_sharding
,
colwise_out_sharding
,
colwise_out_sharding
,
scale_inv_sharding
,
scale_inv_sharding
,
colwise_scale_inv_sharding
,
colwise_scale_inv_sharding
,
amax_sharding
,
amax_sharding
,
dbias_sha
p
rding
,
dbias_sharding
,
)
)
def
sharded_impl
(
dz
,
x
,
scale
):
def
sharded_impl
(
dz
,
x
,
scale
):
...
@@ -799,7 +859,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
...
@@ -799,7 +859,7 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
else
:
else
:
global_dbias
=
local_dbias
global_dbias
=
local_dbias
if
scaling_mode
==
ScalingMode
.
NVTE_
DELAYED_TENSOR_SCALING
.
value
:
if
scaling_mode
==
ScalingMode
.
DELAYED_TENSOR_SCALING
.
value
:
global_updated_amax
=
all_reduce_max_along_all_axes_except_PP
(
local_amax
,
mesh
)
global_updated_amax
=
all_reduce_max_along_all_axes_except_PP
(
local_amax
,
mesh
)
else
:
else
:
global_updated_amax
=
local_amax
global_updated_amax
=
local_amax
...
@@ -808,6 +868,46 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
...
@@ -808,6 +868,46 @@ class DActLuDBiasQuantizePrimitive(BasePrimitive):
return
mesh
,
sharded_impl
,
out_shardings
,
arg_shardings
return
mesh
,
sharded_impl
,
out_shardings
,
arg_shardings
@
staticmethod
def
shardy_sharding_rule
(
out_dtype
,
scaling_mode
,
is_2x
,
scale_dtype
,
scale_shapes
,
is_dbias
,
act_enum
,
act_len
,
is_outer
,
mesh
,
value_types
,
result_types
,
):
del
out_dtype
,
scale_dtype
,
scale_shapes
,
act_enum
,
act_len
,
is_outer
,
mesh
,
result_types
x_rank
=
len
(
value_types
[
1
].
shape
)
scale_rules
=
ScalingMode
(
scaling_mode
).
get_shardy_sharding_rules
(
x_rank
,
unique_var
=
"i"
,
flatten_axis
=-
2
)
x_axes
=
scale_rules
.
input_spec
out
=
x_axes
if
is_2x
:
if
scaling_mode
==
ScalingMode
.
DELAYED_TENSOR_SCALING
.
value
:
colwise_out
=
tuple
(
multidim_transpose
(
x_axes
,
transpose_axis
=-
2
))
else
:
colwise_out
=
tuple
(
x_axes
)
else
:
colwise_out
=
(
"j"
,)
dbias
=
x_axes
[
-
2
:]
if
is_dbias
else
(
"k"
,)
amax
=
(
"…4"
,)
return
SdyShardingRule
(
((
"…0"
,),
tuple
(
x_axes
),
(
"…2"
,)),
(
out
,
colwise_out
,
scale_rules
.
rowwise_rule
,
scale_rules
.
colwise_rule
,
amax
,
dbias
),
**
scale_rules
.
factor_sizes
,
)
register_primitive
(
DActLuDBiasQuantizePrimitive
)
register_primitive
(
DActLuDBiasQuantizePrimitive
)
...
@@ -816,14 +916,21 @@ def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[jnp.ndarray, S
...
@@ -816,14 +916,21 @@ def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[jnp.ndarray, S
"""
"""
JAX native activation implementation
JAX native activation implementation
"""
"""
x
=
jnp
.
split
(
inputs
,
len
(
activation_type
),
axis
=-
1
)
act_len
=
len
(
activation_type
)
assert
inputs
.
shape
[
-
2
]
==
act_len
,
(
"activation input should be replicated by act_len in the -2 axis, got input shape"
f
"
{
inputs
.
shape
}
and act_len
{
act_len
}
"
)
x
=
jnp
.
split
(
inputs
,
act_len
,
axis
=-
2
)
acts
=
[]
acts
=
[]
for
idx
,
act_fn
in
enumerate
(
activation_type
):
for
idx
,
act_fn
in
enumerate
(
activation_type
):
x_i
=
_convert_to_activation_function
(
act_fn
)(
x
[
idx
])
x_i
=
_convert_to_activation_function
(
act_fn
)(
x
[
idx
])
acts
.
append
(
x_i
)
acts
.
append
(
x_i
)
x
=
reduce
(
operator
.
mul
,
acts
)
x
=
reduce
(
operator
.
mul
,
acts
)
x
=
jnp
.
squeeze
(
x
,
axis
=-
2
)
if
quantizer
:
if
quantizer
:
return
quantizer
.
quantize
(
x
)
return
quantizer
.
quantize
(
x
,
flatten_axis
=-
1
)
return
x
return
x
...
@@ -837,6 +944,12 @@ def _jax_quantize_dact_dbias(
...
@@ -837,6 +944,12 @@ def _jax_quantize_dact_dbias(
"""
"""
JAX implementation of dact_lu and dbias with optional quantization
JAX implementation of dact_lu and dbias with optional quantization
"""
"""
act_len
=
len
(
activation_type
)
assert
x
.
shape
[
-
2
]
==
act_len
,
(
"activation input should be replicated by act_len in the -2 axis, got input shape"
f
"
{
x
.
shape
}
and act_len
{
act_len
}
"
)
_
,
vjp_func
=
jax
.
vjp
(
_
,
vjp_func
=
jax
.
vjp
(
partial
(
_jax_act_lu
,
activation_type
=
activation_type
),
x
.
astype
(
jnp
.
float32
)
partial
(
_jax_act_lu
,
activation_type
=
activation_type
),
x
.
astype
(
jnp
.
float32
)
)
)
...
@@ -844,10 +957,10 @@ def _jax_quantize_dact_dbias(
...
@@ -844,10 +957,10 @@ def _jax_quantize_dact_dbias(
dbias
=
None
dbias
=
None
if
is_dbias
:
if
is_dbias
:
dbias
=
_jax_dbias
(
dx
).
as
type
(
x
.
dtype
)
dbias
=
_jax_dbias
(
dx
,
d
type
=
x
.
dtype
,
flatten_axis
=-
2
)
if
quantizer
is
not
None
:
if
quantizer
is
not
None
:
dx
=
quantizer
.
quantize
(
dx
,
dq_dtype
=
x
.
dtype
)
dx
=
quantizer
.
quantize
(
dx
,
dq_dtype
=
x
.
dtype
,
flatten_axis
=-
2
)
else
:
else
:
dx
=
dx
.
astype
(
x
.
dtype
)
dx
=
dx
.
astype
(
x
.
dtype
)
...
@@ -863,6 +976,7 @@ def act_lu(
...
@@ -863,6 +976,7 @@ def act_lu(
Args:
Args:
x: Input tensor to be processed.
x: Input tensor to be processed.
Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations
activation_type: Type of activation function to apply.
activation_type: Type of activation function to apply.
quantizer: Optional quantizer for FP8 quantization of the output.
quantizer: Optional quantizer for FP8 quantization of the output.
...
@@ -873,12 +987,17 @@ def act_lu(
...
@@ -873,12 +987,17 @@ def act_lu(
A ScaledTensor containing the quantized activated input.
A ScaledTensor containing the quantized activated input.
"""
"""
act_type_id
=
ActivationEnum
[
activation_type
].
value
act_type_id
=
ActivationEnum
[
activation_type
].
value
act_len
=
len
(
activation_type
)
assert
x
.
shape
[
-
2
]
==
act_len
,
(
"activation input should be replicated by act_len in the -2 axis, got input shape"
f
"
{
x
.
shape
}
and act_len
{
act_len
}
"
)
if
not
ActLuPrimitive
.
enabled
():
if
not
ActLuPrimitive
.
enabled
():
return
_jax_act_lu
(
x
,
activation_type
,
quantizer
)
return
_jax_act_lu
(
x
,
activation_type
,
quantizer
)
# TE/common does not support colwise-only quantization yet
# TE/common does not support colwise-only quantization yet
if
quantizer
is
not
None
and
quantizer
.
q_
axis
==
Quantize
Axis
.
COLWISE
:
if
quantizer
is
not
None
and
quantizer
.
q_
layout
==
Quantize
Layout
.
COLWISE
:
return
_jax_act_lu
(
x
,
activation_type
,
quantizer
)
return
_jax_act_lu
(
x
,
activation_type
,
quantizer
)
# TE/common does not support 2x quantization for DelayedScaling yet
# TE/common does not support 2x quantization for DelayedScaling yet
...
@@ -889,17 +1008,16 @@ def act_lu(
...
@@ -889,17 +1008,16 @@ def act_lu(
return
war_output
return
war_output
scale
=
jnp
.
empty
((
1
,),
jnp
.
float32
)
scale
=
jnp
.
empty
((
1
,),
jnp
.
float32
)
output_shape
=
(
*
x
.
shape
[:
-
1
],
x
.
shape
[
-
1
]
//
len
(
activation_type
)
)
output_shape
=
(
*
x
.
shape
[:
-
2
],
x
.
shape
[
-
1
])
if
quantizer
is
None
:
if
quantizer
is
None
:
x
=
x
.
reshape
((
-
1
,
len
(
activation_type
),
x
.
shape
[
-
1
]
//
len
(
activation_type
)))
out
,
_
,
_
,
_
,
_
=
ActLuPrimitive
.
outer_primitive
.
bind
(
out
,
_
,
_
,
_
,
_
=
ActLuPrimitive
.
outer_primitive
.
bind
(
x
,
x
,
scale
,
scale
,
out_dtype
=
x
.
dtype
,
out_dtype
=
x
.
dtype
,
act_enum
=
act_type_id
,
act_enum
=
act_type_id
,
act_len
=
len
(
activation_type
)
,
act_len
=
act_len
,
scaling_mode
=
ScalingMode
.
N
VTE_DELAYED_TENSOR
_SCALING
.
value
,
scaling_mode
=
ScalingMode
.
N
O
_SCALING
.
value
,
is_2x
=
False
,
is_2x
=
False
,
scale_dtype
=
jnp
.
float32
,
scale_dtype
=
jnp
.
float32
,
scale_shapes
=
((),
()),
scale_shapes
=
((),
()),
...
@@ -911,7 +1029,6 @@ def act_lu(
...
@@ -911,7 +1029,6 @@ def act_lu(
if
isinstance
(
quantizer
,
DelayedScaleQuantizer
):
if
isinstance
(
quantizer
,
DelayedScaleQuantizer
):
scale
=
quantizer
.
scale
scale
=
quantizer
.
scale
x
=
x
.
reshape
((
*
x
.
shape
[:
-
1
],
len
(
activation_type
),
x
.
shape
[
-
1
]
//
len
(
activation_type
)))
(
(
rowwise_casted_output
,
rowwise_casted_output
,
colwise_casted_output
,
colwise_casted_output
,
...
@@ -923,25 +1040,15 @@ def act_lu(
...
@@ -923,25 +1040,15 @@ def act_lu(
scale
,
scale
,
out_dtype
=
quantizer
.
q_dtype
,
out_dtype
=
quantizer
.
q_dtype
,
act_enum
=
act_type_id
,
act_enum
=
act_type_id
,
act_len
=
len
(
activation_type
)
,
act_len
=
act_len
,
scaling_mode
=
quantizer
.
scaling_mode
.
value
,
scaling_mode
=
quantizer
.
scaling_mode
.
value
,
is_2x
=
quantizer
.
is_2x2x
(),
is_2x
=
quantizer
.
is_2x2x
(),
scale_dtype
=
quantizer
.
get_scale_dtype
(),
scale_dtype
=
quantizer
.
get_scale_dtype
(),
scale_shapes
=
quantizer
.
get_scale_shapes
(
output_shape
),
# output does not have act axis
scale_shapes
=
quantizer
.
get_scale_shapes
(
output_shape
,
flatten_axis
=-
1
),
is_outer
=
True
,
is_outer
=
True
,
)
)
rowwise_casted_output
=
rowwise_casted_output
.
reshape
(
output_shape
)
if
len
(
rowwise_scale_inv
.
shape
)
>
1
:
rowwise_scale_inv
=
jnp
.
squeeze
(
rowwise_scale_inv
,
axis
=-
2
)
# Remove act axis
if
quantizer
.
q_axis
in
(
QuantizeAxis
.
COLWISE
,
QuantizeAxis
.
ROWWISE_COLWISE
):
colwise_output_shape
=
output_shape
if
quantizer
.
scaling_mode
==
ScalingMode
.
NVTE_DELAYED_TENSOR_SCALING
:
colwise_output_shape
=
multidim_transpose
(
output_shape
)
colwise_casted_output
=
colwise_casted_output
.
reshape
(
colwise_output_shape
)
if
len
(
colwise_scale_inv
.
shape
)
>
1
:
colwise_scale_inv
=
jnp
.
squeeze
(
colwise_scale_inv
,
axis
=-
2
)
# Remove act axis
quantizer
.
update
(
updated_amax
)
quantizer
.
update
(
updated_amax
)
return
ScaledTensorFactory
.
create
(
return
ScaledTensorFactory
.
create
(
...
@@ -951,8 +1058,8 @@ def act_lu(
...
@@ -951,8 +1058,8 @@ def act_lu(
colwise_scale_inv
=
colwise_scale_inv
,
colwise_scale_inv
=
colwise_scale_inv
,
scaling_mode
=
quantizer
.
scaling_mode
,
scaling_mode
=
quantizer
.
scaling_mode
,
dq_dtype
=
x
.
dtype
,
dq_dtype
=
x
.
dtype
,
q_
axis
=
quantizer
.
q_
axis
,
q_
layout
=
quantizer
.
q_
layout
,
layout
=
quantizer
.
get_layout
(),
data_
layout
=
quantizer
.
get_
data_
layout
(),
)
)
...
@@ -968,7 +1075,7 @@ def quantize_dact_dbias(
...
@@ -968,7 +1075,7 @@ def quantize_dact_dbias(
Args:
Args:
dz: Gradient of the output with respect to the activation output.
dz: Gradient of the output with respect to the activation output.
x: Input tensor that was processed by the forward pass.
x: Input tensor that was processed by the forward pass.
Shape: (..., ACT_DIM
*
K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations
Shape: (..., ACT_DIM
,
K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations
activation_type: Type of activation function used in the forward pass. Defaults to ("gelu",).
activation_type: Type of activation function used in the forward pass. Defaults to ("gelu",).
is_dbias: If True, compute bias gradient. Defaults to True.
is_dbias: If True, compute bias gradient. Defaults to True.
quantizer: Optional quantizer for FP8 quantization of the output.
quantizer: Optional quantizer for FP8 quantization of the output.
...
@@ -979,21 +1086,25 @@ def quantize_dact_dbias(
...
@@ -979,21 +1086,25 @@ def quantize_dact_dbias(
- The gradient of the activation with respect to the bias.
- The gradient of the activation with respect to the bias.
"""
"""
act_len
=
len
(
activation_type
)
assert
x
.
shape
[
-
2
]
==
act_len
,
(
"activation input should be replicated by act_len in the -2 axis, got input shape"
f
"
{
x
.
shape
}
and act_len
{
act_len
}
"
)
if
not
DActLuDBiasQuantizePrimitive
.
enabled
():
if
not
DActLuDBiasQuantizePrimitive
.
enabled
():
return
_jax_quantize_dact_dbias
(
dz
,
x
,
activation_type
,
is_dbias
,
quantizer
)
return
_jax_quantize_dact_dbias
(
dz
,
x
,
activation_type
,
is_dbias
,
quantizer
)
# TE/common does not support colwise-only quantization yet
# TE/common does not support colwise-only quantization yet
if
quantizer
is
not
None
and
quantizer
.
q_
axis
==
Quantize
Axis
.
COLWISE
:
if
quantizer
is
not
None
and
quantizer
.
q_
layout
==
Quantize
Layout
.
COLWISE
:
return
_jax_quantize_dact_dbias
(
dz
,
x
,
activation_type
,
is_dbias
,
quantizer
)
return
_jax_quantize_dact_dbias
(
dz
,
x
,
activation_type
,
is_dbias
,
quantizer
)
# TE/common does not support 1x dact_dbias_quantize on arch < 100 yet
# TE/common does not support 1x dact_dbias_quantize on arch < 100 yet
if
should_apply_1x_fused_dbias_war_for_arch_l_100
(
is_dbias
=
is_dbias
,
quantizer
=
quantizer
):
if
should_apply_1x_fused_dbias_war_for_arch_l_100
(
is_dbias
=
is_dbias
,
quantizer
=
quantizer
):
out
,
_
=
quantize_dact_dbias
(
out
=
dact_lu
(
dz
,
x
,
activation_type
,
quantizer
=
None
)
dz
=
dz
,
x
=
x
,
activation_type
=
activation_type
,
is_dbias
=
False
,
quantizer
=
None
return
_quantize_dbias_impl
(
out
,
quantizer
,
is_dbias
=
True
,
flatten_axis
=-
2
)
)
return
quantize_dbias
(
out
,
is_dbias
=
True
,
quantizer
=
quantizer
)
is_gated
=
len
(
activation_type
)
==
2
is_gated
=
act_len
==
2
# TE/common does not support DelayedScaling2x for gated-act yet
# TE/common does not support DelayedScaling2x for gated-act yet
if
is_gated
:
if
is_gated
:
war_output
=
try_apply_delayed_scaling_2x_war
(
war_output
=
try_apply_delayed_scaling_2x_war
(
...
@@ -1003,6 +1114,7 @@ def quantize_dact_dbias(
...
@@ -1003,6 +1114,7 @@ def quantize_dact_dbias(
activation_type
=
activation_type
,
activation_type
=
activation_type
,
is_dbias
=
is_dbias
,
is_dbias
=
is_dbias
,
quantizer
=
quantizer
,
quantizer
=
quantizer
,
flatten_axis
=-
2
,
)
)
if
war_output
is
not
None
:
if
war_output
is
not
None
:
return
war_output
return
war_output
...
@@ -1019,18 +1131,18 @@ def quantize_dact_dbias(
...
@@ -1019,18 +1131,18 @@ def quantize_dact_dbias(
# outputs float32 for dbias accumulation
# outputs float32 for dbias accumulation
out_dtype
=
(
jnp
.
float32
if
is_dbias
else
x
.
dtype
),
out_dtype
=
(
jnp
.
float32
if
is_dbias
else
x
.
dtype
),
# default value for no scaling, TE/common ignore this value when scale is unset
# default value for no scaling, TE/common ignore this value when scale is unset
scaling_mode
=
ScalingMode
.
N
VTE_DELAYED_TENSOR
_SCALING
.
value
,
scaling_mode
=
ScalingMode
.
N
O
_SCALING
.
value
,
is_2x
=
False
,
# unused
is_2x
=
False
,
# unused
scale_dtype
=
jnp
.
float32
,
# unused
scale_dtype
=
jnp
.
float32
,
# unused
scale_shapes
=
((),
()),
# unused
scale_shapes
=
((),
()),
# unused
is_dbias
=
False
,
is_dbias
=
False
,
act_enum
=
act_type_id
,
act_enum
=
act_type_id
,
act_len
=
len
(
activation_type
)
,
act_len
=
act_len
,
is_outer
=
True
,
is_outer
=
True
,
)
)
dbias
=
None
dbias
=
None
if
is_dbias
:
if
is_dbias
:
dbias
=
_jax_dbias
(
output
).
as
type
(
x
.
dtype
)
dbias
=
_jax_dbias
(
output
,
d
type
=
x
.
dtype
,
flatten_axis
=-
2
)
return
output
.
astype
(
x
.
dtype
),
dbias
return
output
.
astype
(
x
.
dtype
),
dbias
if
isinstance
(
quantizer
,
DelayedScaleQuantizer
):
if
isinstance
(
quantizer
,
DelayedScaleQuantizer
):
...
@@ -1041,16 +1153,9 @@ def quantize_dact_dbias(
...
@@ -1041,16 +1153,9 @@ def quantize_dact_dbias(
dgated
=
dact_lu
(
dgated
=
dact_lu
(
dz
.
astype
(
jnp
.
float32
),
x
.
astype
(
jnp
.
float32
),
activation_type
=
activation_type
dz
.
astype
(
jnp
.
float32
),
x
.
astype
(
jnp
.
float32
),
activation_type
=
activation_type
)
)
# TODO(Jeremy): Debug - TE's quantize_dbias produced nans in this case for distributed layernorm_mlp tests
out
,
dbias
=
_quantize_dbias_impl
(
if
quantizer
.
scaling_mode
==
ScalingMode
.
NVTE_MXFP8_1D_SCALING
:
dgated
,
quantizer
,
is_dbias
=
True
,
dq_dtype
=
x
.
dtype
,
flatten_axis
=-
2
out
,
dbias
=
_jax_quantize_dbias
(
dgated
,
quantizer
=
quantizer
,
dq_dtype
=
x
.
dtype
)
)
else
:
out
,
dbias
=
quantize_dbias
(
dgated
,
quantizer
=
quantizer
,
is_dbias
=
True
,
dq_dtype
=
x
.
dtype
,
)
return
out
,
dbias
return
out
,
dbias
out_shape
=
x
.
shape
out_shape
=
x
.
shape
...
@@ -1070,15 +1175,16 @@ def quantize_dact_dbias(
...
@@ -1070,15 +1175,16 @@ def quantize_dact_dbias(
scaling_mode
=
quantizer
.
scaling_mode
.
value
,
scaling_mode
=
quantizer
.
scaling_mode
.
value
,
is_2x
=
quantizer
.
is_2x2x
(),
is_2x
=
quantizer
.
is_2x2x
(),
scale_dtype
=
quantizer
.
get_scale_dtype
(),
scale_dtype
=
quantizer
.
get_scale_dtype
(),
scale_shapes
=
quantizer
.
get_scale_shapes
(
out_shape
),
# output has act axis
scale_shapes
=
quantizer
.
get_scale_shapes
(
out_shape
,
flatten_axis
=-
2
),
is_dbias
=
is_dbias
,
is_dbias
=
is_dbias
,
act_enum
=
act_type_id
,
act_enum
=
act_type_id
,
act_len
=
len
(
activation_type
)
,
act_len
=
act_len
,
is_outer
=
True
,
is_outer
=
True
,
)
)
# For DelayedScaling transpose, the scale buffer is shared for both rowwise and colwise
# For DelayedScaling transpose, the scale buffer is shared for both rowwise and colwise
if
quantizer
.
scaling_mode
==
ScalingMode
.
NVTE_
DELAYED_TENSOR_SCALING
and
quantizer
.
is_2x2x
():
if
quantizer
.
scaling_mode
==
ScalingMode
.
DELAYED_TENSOR_SCALING
and
quantizer
.
is_2x2x
():
colwise_scale_inv
=
rowwise_scale_inv
colwise_scale_inv
=
rowwise_scale_inv
quantizer
.
update
(
updated_amax
)
quantizer
.
update
(
updated_amax
)
...
@@ -1090,8 +1196,9 @@ def quantize_dact_dbias(
...
@@ -1090,8 +1196,9 @@ def quantize_dact_dbias(
colwise_scale_inv
=
colwise_scale_inv
,
colwise_scale_inv
=
colwise_scale_inv
,
scaling_mode
=
quantizer
.
scaling_mode
,
scaling_mode
=
quantizer
.
scaling_mode
,
dq_dtype
=
x
.
dtype
,
dq_dtype
=
x
.
dtype
,
q_axis
=
quantizer
.
q_axis
,
q_layout
=
quantizer
.
q_layout
,
layout
=
quantizer
.
get_layout
(),
data_layout
=
quantizer
.
get_data_layout
(),
flatten_axis
=-
2
,
# as output has act axis
)
)
return
out
,
dbias
return
out
,
dbias
...
...
transformer_engine/jax/cpp_extensions/attention.py
View file @
ab3e5a92
...
@@ -14,6 +14,7 @@ import jax
...
@@ -14,6 +14,7 @@ import jax
import
jax.numpy
as
jnp
import
jax.numpy
as
jnp
from
jax
import
dtypes
,
lax
from
jax
import
dtypes
,
lax
from
jax.sharding
import
PartitionSpec
,
NamedSharding
from
jax.sharding
import
PartitionSpec
,
NamedSharding
from
jax.experimental.custom_partitioning
import
SdyShardingRule
import
transformer_engine_jax
import
transformer_engine_jax
from
transformer_engine_jax
import
NVTE_Fused_Attn_Backend
from
transformer_engine_jax
import
NVTE_Fused_Attn_Backend
...
@@ -42,6 +43,7 @@ from ..sharding import (
...
@@ -42,6 +43,7 @@ from ..sharding import (
get_mesh_axis_rank
,
get_mesh_axis_rank
,
get_all_mesh_axes
,
get_all_mesh_axes
,
num_of_devices
,
num_of_devices
,
with_sharding_constraint
,
)
)
...
@@ -618,6 +620,35 @@ class FusedAttnFwdPrimitive(BasePrimitive):
...
@@ -618,6 +620,35 @@ class FusedAttnFwdPrimitive(BasePrimitive):
impl
=
partial
(
FusedAttnFwdPrimitive
.
impl
,
config
=
config
)
impl
=
partial
(
FusedAttnFwdPrimitive
.
impl
,
config
=
config
)
return
mesh
,
impl
,
out_shardings
,
arg_shardings
return
mesh
,
impl
,
out_shardings
,
arg_shardings
@
staticmethod
def
shardy_sharding_rule
(
config
,
mesh
,
value_types
,
result_types
):
del
mesh
,
result_types
# Keep in sync with `infer_sharding_from_operands`.
# We only need the first input. Fill up the rest with placeholders.
input_spec
=
[(
f
"…
{
x
}
"
,)
for
x
in
range
(
len
(
value_types
))]
# The RNG state sharding cannot be expressed as a Shardy rule. We use with_sharding_constraint
# instead. This has to happen outside of the primitive, see `fused_attn_fwd`.
rng_sharding
=
(
f
"…
{
len
(
value_types
)
}
"
,)
if
config
.
qkv_layout
.
is_qkvpacked
():
input_spec
[
0
]
=
(
"…0"
,
"seqlen"
,
"three"
,
"head"
,
"hidden"
)
elif
config
.
qkv_layout
.
is_kvpacked
()
or
config
.
qkv_layout
.
is_separate
():
input_spec
[
0
]
=
(
"…0"
,
"seqlen"
,
"head"
,
"hidden"
)
else
:
raise
ValueError
(
f
"Unsupported
{
config
.
qkv_layout
=
}
"
)
is_packed_softmax
=
get_cudnn_version
()
>=
(
9
,
6
,
0
)
and
config
.
qkv_layout
.
is_thd
()
out_sharding
=
(
"…0"
,
"seqlen"
,
"head"
,
"hidden"
)
if
is_packed_softmax
:
softmax_aux_sharding
=
(
"…0"
,
"seqlen"
,
"head"
,
"i"
)
else
:
softmax_aux_sharding
=
(
"…0"
,
"head"
,
"seqlen"
,
"i"
)
return
SdyShardingRule
(
tuple
(
input_spec
),
(
out_sharding
,
softmax_aux_sharding
,
rng_sharding
)
)
register_primitive
(
FusedAttnFwdPrimitive
)
register_primitive
(
FusedAttnFwdPrimitive
)
...
@@ -998,6 +1029,15 @@ class FusedAttnBwdPrimitive(BasePrimitive):
...
@@ -998,6 +1029,15 @@ class FusedAttnBwdPrimitive(BasePrimitive):
return
mesh
,
sharded_impl
,
out_shardings
,
arg_shardings
return
mesh
,
sharded_impl
,
out_shardings
,
arg_shardings
@
staticmethod
def
shardy_sharding_rule
(
config
,
mesh
,
value_types
,
result_types
):
del
config
,
mesh
# We only care about the four first arguments.
# Keep in sync with `infer_sharding_from_operands`.
input_spec
=
tuple
((
f
"…
{
x
}
"
,)
for
x
in
range
(
len
(
value_types
)))
output_spec
=
tuple
((
f
"…
{
x
}
"
,)
for
x
in
range
(
len
(
result_types
)))
return
SdyShardingRule
(
input_spec
,
output_spec
)
register_primitive
(
FusedAttnBwdPrimitive
)
register_primitive
(
FusedAttnBwdPrimitive
)
...
@@ -2436,13 +2476,15 @@ def fused_attn_fwd(
...
@@ -2436,13 +2476,15 @@ def fused_attn_fwd(
primitive
=
FusedRingAttnFwdPrimitive
.
outer_primitive
primitive
=
FusedRingAttnFwdPrimitive
.
outer_primitive
seq_desc_flatten
,
_
=
jax
.
tree
.
flatten
(
sequence_descriptor
)
seq_desc_flatten
,
_
=
jax
.
tree
.
flatten
(
sequence_descriptor
)
return
primitive
.
bind
(
output
,
softmax_aux
,
rng_state
=
primitive
.
bind
(
*
qkv_for_primitive
,
*
qkv_for_primitive
,
bias
,
bias
,
seed
,
seed
,
*
seq_desc_flatten
,
*
seq_desc_flatten
,
config
=
fused_config
,
config
=
fused_config
,
)
)
rng_state
=
with_sharding_constraint
(
rng_state
,
PartitionSpec
(
get_all_mesh_axes
(),
None
))
return
(
output
,
softmax_aux
,
rng_state
)
def
fused_attn_bwd
(
def
fused_attn_bwd
(
...
...
transformer_engine/jax/cpp_extensions/base.py
View file @
ab3e5a92
...
@@ -98,6 +98,15 @@ class BasePrimitive(metaclass=ABCMeta):
...
@@ -98,6 +98,15 @@ class BasePrimitive(metaclass=ABCMeta):
"""
"""
return
NotImplemented
return
NotImplemented
@
staticmethod
@
abstractmethod
def
shardy_sharding_rule
(
*
args
):
"""
Returns the sharding rule for this primitive.
"""
del
args
return
"... -> ..."
def
register_primitive
(
cls
):
def
register_primitive
(
cls
):
"""
"""
...
@@ -123,7 +132,9 @@ def register_primitive(cls):
...
@@ -123,7 +132,9 @@ def register_primitive(cls):
batching
.
primitive_batchers
[
outer_p
]
=
cls
.
batcher
batching
.
primitive_batchers
[
outer_p
]
=
cls
.
batcher
outer_p_lower
=
custom_partitioning
(
cls
.
impl
,
static_argnums
=
cls
.
impl_static_args
)
outer_p_lower
=
custom_partitioning
(
cls
.
impl
,
static_argnums
=
cls
.
impl_static_args
)
outer_p_lower
.
def_partition
(
outer_p_lower
.
def_partition
(
infer_sharding_from_operands
=
cls
.
infer_sharding_from_operands
,
partition
=
cls
.
partition
infer_sharding_from_operands
=
cls
.
infer_sharding_from_operands
,
partition
=
cls
.
partition
,
sharding_rule
=
cls
.
shardy_sharding_rule
,
)
)
mlir
.
register_lowering
(
mlir
.
register_lowering
(
outer_p
,
mlir
.
lower_fun
(
outer_p_lower
,
multiple_results
=
cls
.
multiple_results
)
outer_p
,
mlir
.
lower_fun
(
outer_p_lower
,
multiple_results
=
cls
.
multiple_results
)
...
...
transformer_engine/jax/cpp_extensions/gemm.py
View file @
ab3e5a92
...
@@ -6,9 +6,9 @@
...
@@ -6,9 +6,9 @@
from
typing
import
Tuple
,
Sequence
,
Union
,
Dict
,
List
from
typing
import
Tuple
,
Sequence
,
Union
,
Dict
,
List
from
functools
import
partial
,
reduce
from
functools
import
partial
,
reduce
import
operator
import
operator
from
transformer_engine_jax
import
get_device_compute_capability
import
jax
import
jax
import
jax.numpy
as
jnp
import
jax.numpy
as
jnp
from
transformer_engine_jax
import
get_device_compute_capability
from
.base
import
BasePrimitive
,
register_primitive
from
.base
import
BasePrimitive
,
register_primitive
...
@@ -41,32 +41,45 @@ class GroupedGemmPrimitive(BasePrimitive):
...
@@ -41,32 +41,45 @@ class GroupedGemmPrimitive(BasePrimitive):
name
=
"te_grouped_gemm_ffi"
name
=
"te_grouped_gemm_ffi"
multiple_results
=
True
multiple_results
=
True
impl_static_args
=
(
6
,
7
,
8
,
9
)
impl_static_args
=
()
inner_primitive
=
None
inner_primitive
=
None
outer_primitive
=
None
outer_primitive
=
None
@
staticmethod
@
staticmethod
def
abstract
(
def
abstract
(
*
args
,
num_gemms
,
scaling_mode
,
out_dtype
,
has_bias
):
lhs_contig_aval
,
"""
lhs_scale_contig_aval
,
Args:
rhs_contig_aval
,
*args: Size num_gemms * 4 or num_gemms * 5 depending on has_bias:
rhs_scale_contig_aval
,
args[ 0 : num_gemms] are the lhs tensors,
bias_contig_aval
,
args[ num_gemms : 2*num_gemms] are the rhs tensors,
dim_list_aval
,
args[2*num_gemms : 3*num_gemms] are the lhs scale_inv tensors,
*
,
args[3*num_gemms : 4*num_gemms] are the rhs scale_inv tensors,
num_gemms
,
args[4*num_gemms : 5*num_gemms] are the bias tensors if has_bias is True.
scaling_mode
,
num_gemms: Number of GEMM operations to perform.
out_dtype
,
scaling_mode: Scaling mode for the GEMM operations.
out_flat_size
,
out_dtype: Data type of the output tensors.
):
has_bias: Boolean indicating if bias tensors are provided.
del
lhs_contig_aval
,
lhs_scale_contig_aval
del
rhs_contig_aval
,
rhs_scale_contig_aval
Returns:
del
bias_contig_aval
,
dim_list_aval
A tuple of ShapedArray objects of size num_gemms+1:
del
num_gemms
,
scaling_mode
ret[0 : num_gemms]: GEMM output tensors,
out_flat_aval
=
jax
.
core
.
ShapedArray
(
shape
=
(
out_flat_size
,),
dtype
=
out_dtype
)
ret[num_gemms]:workspace tensor.
wkspace_size
=
get_cublas_workspace_size_bytes
()
*
num_cublas_streams
"""
wkspace_aval
=
jax
.
core
.
ShapedArray
(
shape
=
(
wkspace_size
,),
dtype
=
jnp
.
uint8
)
del
scaling_mode
return
(
out_flat_aval
,
wkspace_aval
)
expected_num_args
=
5
*
num_gemms
if
has_bias
else
4
*
num_gemms
assert
(
len
(
args
)
==
expected_num_args
),
f
"Expected
{
expected_num_args
}
input arguments, but got
{
len
(
args
)
}
"
A_list
=
args
[
0
:
num_gemms
]
B_list
=
args
[
num_gemms
:
2
*
num_gemms
]
# A and B have shapes [1, m, k] and [1, n, k]
out_list_aval
=
tuple
(
jax
.
core
.
ShapedArray
((
A
.
shape
[
1
],
B
.
shape
[
1
]),
dtype
=
out_dtype
)
for
A
,
B
in
zip
(
A_list
,
B_list
)
)
workspace_size
=
get_cublas_workspace_size_bytes
()
*
num_cublas_streams
workspace_aval
=
jax
.
core
.
ShapedArray
(
shape
=
(
workspace_size
,),
dtype
=
jnp
.
uint8
)
return
(
*
out_list_aval
,
workspace_aval
)
@
staticmethod
@
staticmethod
def
outer_abstract
(
*
args
,
**
kwargs
):
def
outer_abstract
(
*
args
,
**
kwargs
):
...
@@ -74,60 +87,27 @@ class GroupedGemmPrimitive(BasePrimitive):
...
@@ -74,60 +87,27 @@ class GroupedGemmPrimitive(BasePrimitive):
return
out_aval
return
out_aval
@
staticmethod
@
staticmethod
def
lowering
(
def
lowering
(
ctx
,
*
args
,
num_gemms
,
scaling_mode
,
out_dtype
,
has_bias
):
ctx
,
del
out_dtype
lhs_contig
,
lhs_scale_inv_contig
,
rhs_contig
,
rhs_scale_inv_contig
,
bias_contig
,
dim_list
,
*
,
num_gemms
,
scaling_mode
,
out_dtype
,
out_flat_size
,
)
->
jnp
.
ndarray
:
del
out_dtype
,
out_flat_size
return
jax
.
ffi
.
ffi_lowering
(
GroupedGemmPrimitive
.
name
)(
return
jax
.
ffi
.
ffi_lowering
(
GroupedGemmPrimitive
.
name
)(
ctx
,
ctx
,
lhs_contig
,
*
args
,
lhs_scale_inv_contig
,
rhs_contig
,
rhs_scale_inv_contig
,
bias_contig
,
dim_list
,
num_gemms
=
num_gemms
,
num_gemms
=
num_gemms
,
scaling_mode
=
int
(
scaling_mode
),
scaling_mode
=
int
(
scaling_mode
),
has_bias
=
has_bias
,
)
)
@
staticmethod
@
staticmethod
def
impl
(
def
impl
(
*
args
,
num_gemms
,
scaling_mode
,
out_dtype
,
has_bias
):
lhs_contig
,
lhs_scale_inv_contig
,
rhs_contig
,
rhs_scale_inv_contig
,
bias_contig
,
dim_list
,
num_gemms
,
scaling_mode
,
out_dtype
,
out_flat_size
,
)
->
jnp
.
ndarray
:
assert
GroupedGemmPrimitive
.
inner_primitive
is
not
None
assert
GroupedGemmPrimitive
.
inner_primitive
is
not
None
out
=
GroupedGemmPrimitive
.
inner_primitive
.
bind
(
out
=
GroupedGemmPrimitive
.
inner_primitive
.
bind
(
lhs_contig
,
*
args
,
lhs_scale_inv_contig
,
rhs_contig
,
rhs_scale_inv_contig
,
bias_contig
,
dim_list
,
num_gemms
=
num_gemms
,
num_gemms
=
num_gemms
,
scaling_mode
=
scaling_mode
.
value
,
scaling_mode
=
scaling_mode
.
value
,
out_dtype
=
out_dtype
,
out_dtype
=
out_dtype
,
out_flat_size
=
out_flat_size
,
has_bias
=
has_bias
,
)
)
return
out
[
0
]
# out is [out_
fla
t, wkspace], only return out_
fla
t
return
out
[
:
-
1
]
# out is [out_
lis
t, wkspace], only return out_
lis
t
register_primitive
(
GroupedGemmPrimitive
)
register_primitive
(
GroupedGemmPrimitive
)
...
@@ -183,10 +163,9 @@ def __jitted_jax_gemm_delayed_scaling_fp8(lhs, rhs, lhs_dn, rhs_dn, precision):
...
@@ -183,10 +163,9 @@ def __jitted_jax_gemm_delayed_scaling_fp8(lhs, rhs, lhs_dn, rhs_dn, precision):
# Reshape + Transpose
# Reshape + Transpose
# [..., M, K] -> [B, M, K]
# [..., M, K] -> [B, M, K]
# [..., K, M] -> [B, M, K]
# [..., K, M] -> [B, M, K]
lhs_3d
=
_shape_normalization
(
lhs_dq
,
lhs_dn
,
lhs
.
layout
==
"N"
)
lhs_3d
=
_shape_normalization
(
lhs_dq
,
lhs_dn
,
lhs
.
data_
layout
==
"N"
)
rhs_3d
=
_shape_normalization
(
rhs_dq
,
rhs_dn
,
rhs
.
layout
==
"T"
)
rhs_3d
=
_shape_normalization
(
rhs_dq
,
rhs_dn
,
rhs
.
data_
layout
==
"T"
)
# _shape_normalization ensures contracting_dims=2 and batch_dims=0
dim_nums
=
(((
2
,),
(
2
,)),
((
0
,),
(
0
,)))
dim_nums
=
(((
2
,),
(
2
,)),
((
0
,),
(
0
,)))
out_3d
=
jax
.
lax
.
dot_general
(
out_3d
=
jax
.
lax
.
dot_general
(
lhs_3d
,
rhs_3d
,
dim_nums
,
precision
=
precision
,
preferred_element_type
=
lhs
.
dq_dtype
lhs_3d
,
rhs_3d
,
dim_nums
,
precision
=
precision
,
preferred_element_type
=
lhs
.
dq_dtype
...
@@ -199,13 +178,13 @@ def _jax_gemm_delayed_scaling_fp8(
...
@@ -199,13 +178,13 @@ def _jax_gemm_delayed_scaling_fp8(
):
):
"""FP8 GEMM for XLA pattern match"""
"""FP8 GEMM for XLA pattern match"""
assert
(
assert
(
rhs
.
scaling_mode
==
ScalingMode
.
NVTE_
DELAYED_TENSOR_SCALING
rhs
.
scaling_mode
==
ScalingMode
.
DELAYED_TENSOR_SCALING
),
"rhs does not have delayed tensor scaling mode"
),
"rhs does not have delayed tensor scaling mode"
(
lhs_contract
,
rhs_contract
),
(
lhs_batch
,
rhs_batch
)
=
dim_nums
(
lhs_contract
,
rhs_contract
),
(
lhs_batch
,
rhs_batch
)
=
dim_nums
if
lhs
.
layout
==
"T"
:
if
lhs
.
data_
layout
==
"T"
:
lhs_contract
=
tuple
((
lhs
.
data
.
ndim
-
1
-
i
)
%
lhs
.
data
.
ndim
for
i
in
lhs_contract
)
lhs_contract
=
tuple
((
lhs
.
data
.
ndim
-
1
-
i
)
%
lhs
.
data
.
ndim
for
i
in
lhs_contract
)
if
rhs
.
layout
==
"T"
:
if
rhs
.
data_
layout
==
"T"
:
rhs_contract
=
tuple
((
rhs
.
data
.
ndim
-
1
-
i
)
%
rhs
.
data
.
ndim
for
i
in
rhs_contract
)
rhs_contract
=
tuple
((
rhs
.
data
.
ndim
-
1
-
i
)
%
rhs
.
data
.
ndim
for
i
in
rhs_contract
)
lhs_dn
=
(
lhs_contract
,
lhs_batch
)
lhs_dn
=
(
lhs_contract
,
lhs_batch
)
...
@@ -231,7 +210,7 @@ def _jax_gemm_mxfp8_1d(
...
@@ -231,7 +210,7 @@ def _jax_gemm_mxfp8_1d(
JAX GEMM for MXFP8 via scaled_matmul
JAX GEMM for MXFP8 via scaled_matmul
"""
"""
assert
(
assert
(
rhs
.
scaling_mode
==
ScalingMode
.
NVTE_
MXFP8_1D_SCALING
rhs
.
scaling_mode
==
ScalingMode
.
MXFP8_1D_SCALING
),
"rhs does not have MXFP8 1D scaling mode"
),
"rhs does not have MXFP8 1D scaling mode"
from
jax._src.cudnn.scaled_matmul_stablehlo
import
scaled_matmul_wrapper
from
jax._src.cudnn.scaled_matmul_stablehlo
import
scaled_matmul_wrapper
...
@@ -292,10 +271,10 @@ def _jax_gemm(
...
@@ -292,10 +271,10 @@ def _jax_gemm(
def
_jax_gemm_fp8_impl
(
lhs
,
rhs
):
def
_jax_gemm_fp8_impl
(
lhs
,
rhs
):
if
lhs
.
scaling_mode
==
ScalingMode
.
NVTE_
DELAYED_TENSOR_SCALING
:
if
lhs
.
scaling_mode
==
ScalingMode
.
DELAYED_TENSOR_SCALING
:
return
_jax_gemm_delayed_scaling_fp8
(
lhs
,
rhs
,
dim_nums
)
return
_jax_gemm_delayed_scaling_fp8
(
lhs
,
rhs
,
dim_nums
)
if
lhs
.
scaling_mode
==
ScalingMode
.
NVTE_
MXFP8_1D_SCALING
:
if
lhs
.
scaling_mode
==
ScalingMode
.
MXFP8_1D_SCALING
:
return
_jax_gemm_mxfp8_1d
(
lhs
,
rhs
,
dim_nums
)
return
_jax_gemm_mxfp8_1d
(
lhs
,
rhs
,
dim_nums
)
raise
NotImplementedError
(
"Unsupported ScalingMode: {lhs.scaling_mode}"
)
raise
NotImplementedError
(
"Unsupported ScalingMode: {lhs.scaling_mode}"
)
...
@@ -367,6 +346,7 @@ def swizzled_scale(scales):
...
@@ -367,6 +346,7 @@ def swizzled_scale(scales):
rows
,
cols
=
scales
.
shape
rows
,
cols
=
scales
.
shape
scales
=
scales
.
reshape
(
rows
//
128
,
4
,
32
,
cols
//
4
,
4
)
scales
=
scales
.
reshape
(
rows
//
128
,
4
,
32
,
cols
//
4
,
4
)
scales
=
jnp
.
transpose
(
scales
,
(
0
,
3
,
2
,
1
,
4
))
scales
=
jnp
.
transpose
(
scales
,
(
0
,
3
,
2
,
1
,
4
))
scales
=
scales
.
reshape
(
rows
,
cols
)
return
scales
return
scales
...
@@ -381,18 +361,12 @@ def grouped_gemm(
...
@@ -381,18 +361,12 @@ def grouped_gemm(
len
(
lhs_list
)
==
len
(
rhs_list
)
==
len
(
contracting_dims_list
)
len
(
lhs_list
)
==
len
(
rhs_list
)
==
len
(
contracting_dims_list
)
),
"lhs_list, rhs_list, contracting_dims_list must have the same length"
),
"lhs_list, rhs_list, contracting_dims_list must have the same length"
# Flatten inputs and save their shapes
num_gemms
=
len
(
lhs_list
)
out_flat_size
=
0
dims
=
[]
lhs_contig_
=
[]
rhs_contig_
=
[]
lhs_scale_inv_contig_
=
[]
rhs_scale_inv_contig_
=
[]
bias_contig_
=
[]
out_offsets
=
[]
remain_shape_list
=
[]
num_gemms
=
len
(
lhs_list
)
num_gemms
=
len
(
lhs_list
)
lhs_list_
=
[]
rhs_list_
=
[]
lhs_sinv_list_
=
[]
rhs_sinv_list_
=
[]
bias_list_
=
[]
for
i
in
range
(
num_gemms
):
for
i
in
range
(
num_gemms
):
lhs
=
lhs_list
[
i
]
lhs
=
lhs_list
[
i
]
rhs
=
rhs_list
[
i
]
rhs
=
rhs_list
[
i
]
...
@@ -403,20 +377,20 @@ def grouped_gemm(
...
@@ -403,20 +377,20 @@ def grouped_gemm(
lhs_shape
=
lhs
.
data
.
shape
lhs_shape
=
lhs
.
data
.
shape
rhs_shape
=
rhs
.
data
.
shape
rhs_shape
=
rhs
.
data
.
shape
out_dtype
=
lhs
.
dq_dtype
out_dtype
=
lhs
.
dq_dtype
# For ScaledTensors and
NVTE_
DELAYED_TENSOR_SCALING, need to handle internal layout
# For ScaledTensors and DELAYED_TENSOR_SCALING, need to handle internal
data_
layout
if
lhs
.
scaling_mode
==
ScalingMode
.
NVTE_
DELAYED_TENSOR_SCALING
:
if
lhs
.
scaling_mode
==
ScalingMode
.
DELAYED_TENSOR_SCALING
:
assert
not
(
assert
not
(
lhs
.
data
.
dtype
==
jnp
.
float8_e5m2
and
rhs
.
data
.
dtype
==
jnp
.
float8_e5m2
lhs
.
data
.
dtype
==
jnp
.
float8_e5m2
and
rhs
.
data
.
dtype
==
jnp
.
float8_e5m2
),
"FP8 GEMM does not support E5M2 * E5M2"
),
"FP8 GEMM does not support E5M2 * E5M2"
((
lhs_contract_dim
,),
(
rhs_contract_dim
,))
=
contracting_dims
((
lhs_contract_dim
,),
(
rhs_contract_dim
,))
=
contracting_dims
if
lhs
.
layout
==
"T"
:
if
lhs
.
data_
layout
==
"T"
:
lhs_contract_dim
=
(
lhs_contract_dim
-
1
)
%
lhs
.
data
.
ndim
lhs_contract_dim
=
(
lhs_contract_dim
-
1
)
%
lhs
.
data
.
ndim
if
rhs
.
layout
==
"T"
:
if
rhs
.
data_
layout
==
"T"
:
rhs_contract_dim
=
(
rhs_contract_dim
-
1
)
%
rhs
.
data
.
ndim
rhs_contract_dim
=
(
rhs_contract_dim
-
1
)
%
rhs
.
data
.
ndim
dim_nums
=
((
lhs_contract_dim
,),
(
rhs_contract_dim
,)),
((),
())
dim_nums
=
((
lhs_contract_dim
,),
(
rhs_contract_dim
,)),
((),
())
else
:
else
:
# For jnp.ndarray, only consider contracting_dims, layout is always NN
# For jnp.ndarray, only consider contracting_dims,
data_
layout is always NN
scaling_mode
=
ScalingMode
.
NVTE_
NO_SCALING
scaling_mode
=
ScalingMode
.
NO_SCALING
lhs_shape
=
lhs
.
shape
lhs_shape
=
lhs
.
shape
rhs_shape
=
rhs
.
shape
rhs_shape
=
rhs
.
shape
out_dtype
=
lhs
.
dtype
out_dtype
=
lhs
.
dtype
...
@@ -428,24 +402,25 @@ def grouped_gemm(
...
@@ -428,24 +402,25 @@ def grouped_gemm(
lhs_remain_shape
=
_calculate_remaining_shape
(
lhs_shape
,
lhs_contract
)
lhs_remain_shape
=
_calculate_remaining_shape
(
lhs_shape
,
lhs_contract
)
rhs_remain_shape
=
_calculate_remaining_shape
(
rhs_shape
,
rhs_contract
)
rhs_remain_shape
=
_calculate_remaining_shape
(
rhs_shape
,
rhs_contract
)
if
scaling_mode
==
ScalingMode
.
NVTE_NO_SCALING
:
# Note: do not squeeze() for {lhs, rhs}_3d, it will trigger a D2D memcpy
if
scaling_mode
==
ScalingMode
.
NO_SCALING
:
lhs_3d
=
_shape_normalization
(
lhs
,
lhs_dn
)
lhs_3d
=
_shape_normalization
(
lhs
,
lhs_dn
)
rhs_3d
=
_shape_normalization
(
rhs
,
rhs_dn
)
rhs_3d
=
_shape_normalization
(
rhs
,
rhs_dn
)
elif
scaling_mode
==
ScalingMode
.
NVTE_
DELAYED_TENSOR_SCALING
:
elif
scaling_mode
==
ScalingMode
.
DELAYED_TENSOR_SCALING
:
lhs_3d
=
_shape_normalization
(
lhs
.
data
,
lhs_dn
,
lhs
.
layout
==
"N"
)
lhs_3d
=
_shape_normalization
(
lhs
.
data
,
lhs_dn
,
lhs
.
data_
layout
==
"N"
)
rhs_3d
=
_shape_normalization
(
rhs
.
data
,
rhs_dn
,
rhs
.
layout
==
"T"
)
rhs_3d
=
_shape_normalization
(
rhs
.
data
,
rhs_dn
,
rhs
.
data_
layout
==
"T"
)
elif
scaling_mode
==
ScalingMode
.
NVTE_
MXFP8_1D_SCALING
:
elif
scaling_mode
==
ScalingMode
.
MXFP8_1D_SCALING
:
lhs_3d
=
_shape_normalization
(
lhs
.
data
,
lhs_dn
)
lhs_3d
=
_shape_normalization
(
lhs
.
data
,
lhs_dn
)
rhs_3d
=
_shape_normalization
(
rhs
.
data
,
rhs_dn
)
rhs_3d
=
_shape_normalization
(
rhs
.
data
,
rhs_dn
)
lhs_scale_inv
=
_shape_normalization
(
lhs
.
scale_inv
,
lhs_dn
)
lhs_scale_inv
=
_shape_normalization
(
lhs
.
scale_inv
,
lhs_dn
)
rhs_scale_inv
=
_shape_normalization
(
rhs
.
scale_inv
,
rhs_dn
)
rhs_scale_inv
=
_shape_normalization
(
rhs
.
scale_inv
,
rhs_dn
)
# swizzled_scale requires a matrix
lhs_scale_inv
=
swizzled_scale
(
lhs_scale_inv
.
squeeze
())
lhs_scale_inv
=
swizzled_scale
(
lhs_scale_inv
.
squeeze
())
rhs_scale_inv
=
swizzled_scale
(
rhs_scale_inv
.
squeeze
())
rhs_scale_inv
=
swizzled_scale
(
rhs_scale_inv
.
squeeze
())
else
:
else
:
raise
NotImplementedError
(
"Unsupported ScalingMode: {scaling_mode}"
)
raise
NotImplementedError
(
"Unsupported ScalingMode: {scaling_mode}"
)
# Note: if _shape_normalization() is updated to support non-TN, need to update here
# Note: already_transposed doesn't matter for the output shape
# already_transposed doesn't matter for the output shape
# x.shape = [B, D1, D2]
# x.shape = [B, D1, D2]
# contracting_dims = (2, ) --> output.shape = [1, B * D1, D2]
# contracting_dims = (2, ) --> output.shape = [1, B * D1, D2]
# contracting_dims = (0, 1, ) --> output.shape = [1, D2, B * D1]
# contracting_dims = (0, 1, ) --> output.shape = [1, D2, B * D1]
...
@@ -456,61 +431,37 @@ def grouped_gemm(
...
@@ -456,61 +431,37 @@ def grouped_gemm(
bn
=
rhs_remain_shape
[
0
]
bn
=
rhs_remain_shape
[
0
]
kl
=
lhs_3d
.
shape
[
-
1
]
kl
=
lhs_3d
.
shape
[
-
1
]
kr
=
rhs_3d
.
shape
[
-
1
]
kr
=
rhs_3d
.
shape
[
-
1
]
remain_shape_list
.
append
(((
bm
,),
(
bn
,)))
assert
kl
==
kr
,
f
"After shape normalization, contracting dim size mismatch:
{
kl
}
!=
{
kr
}
"
assert
kl
==
kr
,
f
"lhs_3d.shape[-1] (
{
kl
}
) != rhs_3d.shape[-1] (
{
kr
}
)"
if
(
bm
%
16
!=
0
)
or
(
bn
%
16
!=
0
)
or
(
kl
%
16
!=
0
):
k
=
kl
print
(
"grouped_gemm input pair {i} has invalid problem shape for lowering: "
)
print
(
f
"m =
{
bm
}
, n =
{
bn
}
, k =
{
kl
}
; "
)
if
(
bm
%
16
!=
0
)
or
(
bn
%
16
!=
0
)
or
(
k
%
16
!=
0
):
print
(
"cuBLAS requires the problem shapes being multiples of 16"
)
print
(
f
"grouped_gemm input pair
{
i
}
has invalid problem shape for lowering: "
)
assert
(
bm
%
16
==
0
)
and
(
bn
%
16
==
0
)
and
(
kl
%
16
==
0
)
print
(
f
"m =
{
bm
}
, n =
{
bn
}
, k =
{
k
}
; cuBLAS requires the problem shapes being multiples"
lhs_list_
.
append
(
lhs_3d
)
" of 16"
rhs_list_
.
append
(
rhs_3d
)
)
if
scaling_mode
==
ScalingMode
.
NO_SCALING
:
assert
bm
%
16
==
0
and
bn
%
16
==
0
and
k
%
16
==
0
lhs_sinv_list_
.
append
(
jnp
.
ones
(
1
,
dtype
=
jnp
.
float32
))
rhs_sinv_list_
.
append
(
jnp
.
ones
(
1
,
dtype
=
jnp
.
float32
))
dims
.
append
((
bm
,
bn
,
k
))
if
scaling_mode
==
ScalingMode
.
DELAYED_TENSOR_SCALING
:
lhs_contig_
.
append
(
lhs_3d
.
reshape
(
-
1
))
lhs_sinv_list_
.
append
(
lhs
.
scale_inv
)
rhs_contig_
.
append
(
rhs_3d
.
reshape
(
-
1
))
rhs_sinv_list_
.
append
(
rhs
.
scale_inv
)
if
scaling_mode
==
ScalingMode
.
NVTE_NO_SCALING
:
if
scaling_mode
==
ScalingMode
.
MXFP8_1D_SCALING
:
lhs_scale_inv_contig_
.
append
(
jnp
.
ones
(
1
,
dtype
=
jnp
.
float32
))
lhs_sinv_list_
.
append
(
lhs_scale_inv
)
rhs_scale_inv_contig_
.
append
(
jnp
.
ones
(
1
,
dtype
=
jnp
.
float32
))
rhs_sinv_list_
.
append
(
rhs_scale_inv
)
if
scaling_mode
==
ScalingMode
.
NVTE_DELAYED_TENSOR_SCALING
:
lhs_scale_inv_contig_
.
append
(
lhs
.
scale_inv
.
reshape
(
-
1
))
rhs_scale_inv_contig_
.
append
(
rhs
.
scale_inv
.
reshape
(
-
1
))
if
scaling_mode
==
ScalingMode
.
NVTE_MXFP8_1D_SCALING
:
lhs_scale_inv_contig_
.
append
(
lhs_scale_inv
.
reshape
(
-
1
))
rhs_scale_inv_contig_
.
append
(
rhs_scale_inv
.
reshape
(
-
1
))
if
bias_list
is
not
None
:
if
bias_list
is
not
None
:
bias_contig_
.
append
(
bias_list
[
i
].
reshape
(
-
1
))
bias_list_
.
append
(
bias_list
[
i
])
out_flat_size
+=
bm
*
bn
out_offsets
.
append
(
out_flat_size
)
out_list
=
GroupedGemmPrimitive
.
outer_primitive
.
bind
(
*
lhs_list_
,
lhs_contig
=
jnp
.
concatenate
(
lhs_contig_
)
*
rhs_list_
,
rhs_contig
=
jnp
.
concatenate
(
rhs_contig_
)
*
lhs_sinv_list_
,
lhs_scale_inv_contig
=
jnp
.
concatenate
(
lhs_scale_inv_contig_
)
*
rhs_sinv_list_
,
rhs_scale_inv_contig
=
jnp
.
concatenate
(
rhs_scale_inv_contig_
)
*
bias_list_
,
bias_contig
=
jnp
.
empty
(
0
)
if
bias_list
is
None
else
jnp
.
concatenate
(
bias_contig_
)
dim_list
=
jnp
.
array
(
dims
,
dtype
=
jnp
.
int32
)
# Perform batched GEMM on flattened inputs
out_contig
=
GroupedGemmPrimitive
.
outer_primitive
.
bind
(
lhs_contig
,
lhs_scale_inv_contig
,
rhs_contig
,
rhs_scale_inv_contig
,
bias_contig
,
dim_list
,
num_gemms
=
num_gemms
,
num_gemms
=
num_gemms
,
scaling_mode
=
scaling_mode
,
scaling_mode
=
scaling_mode
,
out_dtype
=
out_dtype
,
out_dtype
=
out_dtype
,
out_flat_size
=
out_flat_size
,
has_bias
=
1
if
bias_list
is
not
None
else
0
,
)
)
# Split the output back into tensors
return
out_list
out_offsets
=
jnp
.
array
(
out_offsets
)
out_flat_list
=
jnp
.
split
(
out_contig
,
out_offsets
[:
-
1
])
out_tensors
=
[]
for
out_flat
,
(
lhs_remain_shape
,
rhs_remain_shape
)
in
zip
(
out_flat_list
,
remain_shape_list
):
out_tensors
.
append
(
out_flat
.
reshape
(
*
lhs_remain_shape
,
*
rhs_remain_shape
))
return
out_tensors
transformer_engine/jax/cpp_extensions/misc.py
View file @
ab3e5a92
...
@@ -19,7 +19,7 @@ from jax.interpreters.mlir import dtype_to_ir_type
...
@@ -19,7 +19,7 @@ from jax.interpreters.mlir import dtype_to_ir_type
import
transformer_engine_jax
import
transformer_engine_jax
from
..sharding
import
get_padded_spec
as
te_get_padded_spec
from
..sharding
import
get_padded_spec
as
te_get_padded_spec
from
..quantize
import
ScalingMode
,
ScaledTensorFactory
,
Quantize
Axis
from
..quantize
import
ScalingMode
,
ScaledTensorFactory
,
Quantize
Layout
TEDType
=
transformer_engine_jax
.
DType
TEDType
=
transformer_engine_jax
.
DType
...
@@ -107,37 +107,37 @@ def normalize_axis_boundary(axis, ndim):
...
@@ -107,37 +107,37 @@ def normalize_axis_boundary(axis, ndim):
return
axis
if
axis
>=
0
else
ndim
+
axis
return
axis
if
axis
>=
0
else
ndim
+
axis
def
multidim_transpose
(
shape
,
static_axis_boundary
=-
1
,
transpose_axis
_boundary
=-
1
):
def
multidim_transpose
(
shape
,
static_axis_boundary
=-
1
,
transpose_axis
=-
1
):
"""
"""
te_cast_transpose_p multi-dims transpose
te_cast_transpose_p multi-dims transpose
static_axis_boundary: int, Indicate those axes <= static_axis_boundary would not be
static_axis_boundary: int, Indicate those axes <= static_axis_boundary would not be
involved into transpose, -1 means all axes involve into transpose.
involved into transpose, -1 means all axes involve into transpose.
transpose_axis
_boundary
: int, Indicate how to split multi-dimensions tensors to 2D matrix for
transpose_axis: int, Indicate how to split multi-dimensions tensors to 2D matrix for
transpose. Note, transpose_axis
_boundary
should be greater than static_axis_boundary
transpose. Note, transpose_axis should be greater than static_axis_boundary
examples:
examples:
X in shape (dim0, dim1, dim2, dim3, dim4)
X in shape (dim0, dim1, dim2, dim3, dim4)
static_axis_boundary == -1, transpose_axis
_boundary
== 2
static_axis_boundary == -1, transpose_axis == 2
Xt = (dim2, dim3, dim4, dim0, dim1)
Xt = (dim2, dim3, dim4, dim0, dim1)
static_axis_boundary == 0, transpose_axis
_boundary
== 2
static_axis_boundary == 0, transpose_axis == 2
Xt = (dim0, dim2, dim3, dim4, dim1)
Xt = (dim0, dim2, dim3, dim4, dim1)
static_axis_boundary == 0, transpose_axis
_boundary
== 3
static_axis_boundary == 0, transpose_axis == 3
Xt = (dim0, dim3, dim4, dim1. dim2)
Xt = (dim0, dim3, dim4, dim1. dim2)
"""
"""
if
static_axis_boundary
<
0
:
if
static_axis_boundary
<
0
:
static_axis_boundary
=
-
1
# means no static axes
static_axis_boundary
=
-
1
# means no static axes
assert
static_axis_boundary
<
len
(
shape
)
-
2
# at least 2 remaining for transpose.
assert
static_axis_boundary
<
len
(
shape
)
-
2
# at least 2 remaining for transpose.
transpose_start_idx
=
static_axis_boundary
+
1
transpose_start_idx
=
static_axis_boundary
+
1
transpose_axis
_boundary
=
normalize_axis_boundary
(
transpose_axis
_boundary
,
len
(
shape
))
transpose_axis
=
normalize_axis_boundary
(
transpose_axis
,
len
(
shape
))
assert
transpose_start_idx
<
transpose_axis
_boundary
assert
transpose_start_idx
<
transpose_axis
return
(
return
(
*
shape
[:
transpose_start_idx
],
*
shape
[:
transpose_start_idx
],
*
shape
[
transpose_axis
_boundary
:],
*
shape
[
transpose_axis
:],
*
shape
[
transpose_start_idx
:
transpose_axis
_boundary
],
*
shape
[
transpose_start_idx
:
transpose_axis
],
)
)
...
@@ -195,13 +195,13 @@ def should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias: bool = False, quant
...
@@ -195,13 +195,13 @@ def should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias: bool = False, quant
break
break
return
(
return
(
quantizer
is
not
None
quantizer
is
not
None
and
quantizer
.
q_
axis
==
Quantize
Axis
.
ROWWISE
and
quantizer
.
q_
layout
==
Quantize
Layout
.
ROWWISE
and
arch_l_100
and
arch_l_100
and
is_dbias
and
is_dbias
)
)
def
try_apply_delayed_scaling_2x_war
(
f
,
*
args
,
quantizer
=
None
,
**
kwargs
):
def
try_apply_delayed_scaling_2x_war
(
f
,
*
args
,
quantizer
=
None
,
flatten_axis
=-
1
,
**
kwargs
):
"""
"""
Applies a workaround for delayed scaling 2x and can be used when the TE common kernels do not yet support 2x delayed scaling.
Applies a workaround for delayed scaling 2x and can be used when the TE common kernels do not yet support 2x delayed scaling.
It will call the given function 'f' with the given arguments and quantizer as 1x and calculate the colwise output by transposing result.
It will call the given function 'f' with the given arguments and quantizer as 1x and calculate the colwise output by transposing result.
...
@@ -216,7 +216,7 @@ def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, **kwargs):
...
@@ -216,7 +216,7 @@ def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, **kwargs):
"""
"""
should_apply_war
=
(
should_apply_war
=
(
quantizer
is
not
None
quantizer
is
not
None
and
quantizer
.
scaling_mode
==
ScalingMode
.
NVTE_
DELAYED_TENSOR_SCALING
and
quantizer
.
scaling_mode
==
ScalingMode
.
DELAYED_TENSOR_SCALING
and
quantizer
.
is_2x2x
()
and
quantizer
.
is_2x2x
()
)
)
if
not
should_apply_war
:
if
not
should_apply_war
:
...
@@ -224,14 +224,19 @@ def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, **kwargs):
...
@@ -224,14 +224,19 @@ def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, **kwargs):
# 2x is not supported by TE kernels for delayed scaling
# 2x is not supported by TE kernels for delayed scaling
# so revert to 1x and transpose in JAX
# so revert to 1x and transpose in JAX
quantizer
.
q_
axis
=
Quantize
Axis
.
ROWWISE
quantizer
.
q_
layout
=
Quantize
Layout
.
ROWWISE
rowwise
=
f
(
*
args
,
**
kwargs
,
quantizer
=
quantizer
)
rowwise
=
f
(
*
args
,
**
kwargs
,
quantizer
=
quantizer
)
other_outputs
=
None
other_outputs
=
None
if
isinstance
(
rowwise
,
tuple
):
if
isinstance
(
rowwise
,
tuple
):
other_outputs
=
rowwise
[
1
:]
other_outputs
=
rowwise
[
1
:]
rowwise
=
rowwise
[
0
]
rowwise
=
rowwise
[
0
]
quantizer
.
q_axis
=
QuantizeAxis
.
ROWWISE_COLWISE
quantizer
.
q_layout
=
QuantizeLayout
.
ROWWISE_COLWISE
colwise_data
=
jnp
.
transpose
(
rowwise
.
data
,
(
-
1
,
*
range
(
rowwise
.
data
.
ndim
-
1
)))
if
flatten_axis
<
0
:
flatten_axis
+=
rowwise
.
data
.
ndim
assert
0
<
flatten_axis
<
rowwise
.
data
.
ndim
,
"flatten_axis is out of bounds"
colwise_data
=
jnp
.
transpose
(
rowwise
.
data
,
(
*
range
(
flatten_axis
,
rowwise
.
data
.
ndim
),
*
range
(
flatten_axis
))
)
output_2x
=
ScaledTensorFactory
.
create
(
output_2x
=
ScaledTensorFactory
.
create
(
data
=
rowwise
.
data
,
data
=
rowwise
.
data
,
scale_inv
=
rowwise
.
scale_inv
,
scale_inv
=
rowwise
.
scale_inv
,
...
@@ -239,8 +244,9 @@ def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, **kwargs):
...
@@ -239,8 +244,9 @@ def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, **kwargs):
colwise_scale_inv
=
rowwise
.
scale_inv
,
colwise_scale_inv
=
rowwise
.
scale_inv
,
scaling_mode
=
quantizer
.
scaling_mode
,
scaling_mode
=
quantizer
.
scaling_mode
,
dq_dtype
=
rowwise
.
dq_dtype
,
dq_dtype
=
rowwise
.
dq_dtype
,
q_axis
=
QuantizeAxis
.
ROWWISE_COLWISE
,
q_layout
=
QuantizeLayout
.
ROWWISE_COLWISE
,
layout
=
quantizer
.
get_layout
(),
data_layout
=
quantizer
.
get_data_layout
(),
flatten_axis
=
flatten_axis
,
)
)
if
other_outputs
is
not
None
:
if
other_outputs
is
not
None
:
return
(
output_2x
,)
+
other_outputs
return
(
output_2x
,)
+
other_outputs
...
...
transformer_engine/jax/cpp_extensions/normalization.py
View file @
ab3e5a92
...
@@ -12,6 +12,7 @@ from packaging import version
...
@@ -12,6 +12,7 @@ from packaging import version
import
jax
import
jax
import
jax.numpy
as
jnp
import
jax.numpy
as
jnp
from
jax
import
dtypes
from
jax
import
dtypes
from
jax.experimental.custom_partitioning
import
SdyShardingRule
from
jax.interpreters.mlir
import
ir
from
jax.interpreters.mlir
import
ir
from
jax.sharding
import
PartitionSpec
from
jax.sharding
import
PartitionSpec
...
@@ -30,7 +31,7 @@ from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_a
...
@@ -30,7 +31,7 @@ from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_a
from
..quantize
import
ScaledTensor
,
ScaledTensorFactory
from
..quantize
import
ScaledTensor
,
ScaledTensorFactory
from
..quantize
import
(
from
..quantize
import
(
Quantizer
,
Quantizer
,
Quantize
Axis
,
Quantize
Layout
,
DelayedScaleQuantizer
,
DelayedScaleQuantizer
,
ScalingMode
,
ScalingMode
,
)
)
...
@@ -63,6 +64,27 @@ def get_backward_sm_margin():
...
@@ -63,6 +64,27 @@ def get_backward_sm_margin():
return
int
(
os
.
getenv
(
"NVTE_BWD_LAYERNORM_SM_MARGIN"
,
"0"
))
return
int
(
os
.
getenv
(
"NVTE_BWD_LAYERNORM_SM_MARGIN"
,
"0"
))
@
cache
def
is_norm_fwd_cudnn_enabled
(
scaling_mode
:
ScalingMode
)
->
bool
:
"""Retrieves whether CuDNN norm fwd is enabled."""
# MXFP8_1D_SCALING always uses CuDNN currently
return
(
int
(
os
.
getenv
(
"NVTE_NORM_FWD_USE_CUDNN"
,
"0"
))
==
1
or
scaling_mode
==
ScalingMode
.
MXFP8_1D_SCALING
)
@
cache
def
is_norm_zero_centered_gamma_in_weight_dtype
(
scaling_mode
:
ScalingMode
)
->
bool
:
"""Retrieves whether norm should compute `gamma += 1.0` for zero-centered gamma
in weight dtype as opposed to compute dtype."""
if
not
is_norm_fwd_cudnn_enabled
(
scaling_mode
):
# If CuDNN is not enabled, we use the TE backend which uses the compute dtype not weight dtype
# Remove this when TE supports gamma += 1.0 in weight dtype
return
False
return
int
(
os
.
getenv
(
"NVTE_ZERO_CENTERED_GAMMA_IN_WTYPE"
,
"0"
))
==
1
class
NormFwdPrimitive
(
BasePrimitive
):
class
NormFwdPrimitive
(
BasePrimitive
):
"""
"""
Layer Normalization Forward FP8 Primitive
Layer Normalization Forward FP8 Primitive
...
@@ -105,6 +127,26 @@ class NormFwdPrimitive(BasePrimitive):
...
@@ -105,6 +127,26 @@ class NormFwdPrimitive(BasePrimitive):
if
norm_type
==
NVTE_Norm_Type
.
LayerNorm
:
if
norm_type
==
NVTE_Norm_Type
.
LayerNorm
:
assert
gamma_aval
.
size
==
beta_aval
.
size
assert
gamma_aval
.
size
==
beta_aval
.
size
out_aval
=
x_aval
.
update
(
shape
=
x_aval
.
shape
,
dtype
=
out_dtype
)
mu_aval
=
rsigma_aval
=
out_aval
.
update
(
shape
=
out_aval
.
shape
[:
-
1
],
dtype
=
mu_rsigama_dtype
)
if
norm_type
==
NVTE_Norm_Type
.
RMSNorm
:
mu_aval
=
mu_aval
.
update
(
shape
=
(
1
,))
updated_amax_aval
=
jax
.
core
.
ShapedArray
(
shape
=
(
1
,),
dtype
=
jnp
.
float32
)
colwise_out_shape
=
x_aval
.
shape
if
is_2x
else
(
1
,)
colwise_out_aval
=
jax
.
core
.
ShapedArray
(
shape
=
colwise_out_shape
,
dtype
=
out_dtype
)
rowwise_scale_inv_shape
,
colwise_scale_inv_shape
=
ScalingMode
(
scaling_mode
).
get_scale_shape_2x
(
x_aval
.
shape
,
is_padded
=
not
is_outer
)
scale_inv_aval
=
jax
.
core
.
ShapedArray
(
shape
=
rowwise_scale_inv_shape
,
dtype
=
scale_dtype
)
colwise_scale_inv_shape
=
colwise_scale_inv_shape
if
is_2x
else
(
1
,)
colwise_scale_inv_aval
=
jax
.
core
.
ShapedArray
(
shape
=
colwise_scale_inv_shape
,
dtype
=
scale_dtype
)
(
wkspace_info
,)
=
transformer_engine_jax
.
get_norm_fwd_workspace_sizes
(
(
wkspace_info
,)
=
transformer_engine_jax
.
get_norm_fwd_workspace_sizes
(
x_aval
.
size
//
gamma_aval
.
size
,
# batch size
x_aval
.
size
//
gamma_aval
.
size
,
# batch size
gamma_aval
.
size
,
# hidden size
gamma_aval
.
size
,
# hidden size
...
@@ -112,33 +154,13 @@ class NormFwdPrimitive(BasePrimitive):
...
@@ -112,33 +154,13 @@ class NormFwdPrimitive(BasePrimitive):
jax_dtype_to_te_dtype
(
gamma_aval
.
dtype
),
# wtype
jax_dtype_to_te_dtype
(
gamma_aval
.
dtype
),
# wtype
jax_dtype_to_te_dtype
(
out_dtype
),
jax_dtype_to_te_dtype
(
out_dtype
),
norm_type
,
norm_type
,
scaling_mode
.
value
,
scaling_mode
,
zero_centered_gamma
,
zero_centered_gamma
,
epsilon
,
epsilon
,
get_forward_sm_margin
(),
get_forward_sm_margin
(),
is_2x
,
is_2x
,
)
)
wkspace_aval
=
jax
.
core
.
ShapedArray
(
out_aval
=
x_aval
.
update
(
shape
=
x_aval
.
shape
,
dtype
=
out_dtype
)
mu_aval
=
rsigma_aval
=
out_aval
.
update
(
shape
=
out_aval
.
shape
[:
-
1
],
dtype
=
mu_rsigama_dtype
)
if
norm_type
==
NVTE_Norm_Type
.
RMSNorm
:
mu_aval
=
mu_aval
.
update
(
shape
=
(
1
,))
rowwise_scale_inv_shape
,
colwise_scale_inv_shape
=
scaling_mode
.
get_scale_shape_2x
(
x_aval
.
shape
,
is_padded
=
not
is_outer
)
scale_inv_aval
=
jax
.
core
.
ShapedArray
(
shape
=
rowwise_scale_inv_shape
,
dtype
=
scale_dtype
)
colwise_scale_inv_aval
=
jax
.
core
.
ShapedArray
(
shape
=
colwise_scale_inv_shape
,
dtype
=
scale_dtype
)
colwise_out_aval
=
jax
.
core
.
ShapedArray
(
shape
=
x_aval
.
shape
if
is_2x
else
(
1
,),
dtype
=
out_dtype
)
updated_amax_aval
=
jax
.
core
.
ShapedArray
(
shape
=
(
1
,),
dtype
=
jnp
.
float32
)
wkspace_aval
=
x_aval
.
update
(
shape
=
wkspace_info
[
0
],
dtype
=
te_dtype_to_jax_dtype
(
wkspace_info
[
1
])
shape
=
wkspace_info
[
0
],
dtype
=
te_dtype_to_jax_dtype
(
wkspace_info
[
1
])
)
)
...
@@ -274,17 +296,17 @@ class NormFwdPrimitive(BasePrimitive):
...
@@ -274,17 +296,17 @@ class NormFwdPrimitive(BasePrimitive):
scale_shapes
=
scale_shapes
,
scale_shapes
=
scale_shapes
,
is_outer
=
False
,
is_outer
=
False
,
)
)
rowwise_scale_inv_shape
,
colwise_scale_inv_shape
=
scaling_mode
.
get_scale_shape_2x
(
rowwise_scale_inv_shape
,
colwise_scale_inv_shape
=
ScalingMode
(
x
.
shape
,
is_padded
=
False
scaling_mode
).
get_scale_shape_2x
(
x
.
shape
,
is_padded
=
False
)
# slice out padding for mxfp8, noop for DelayedScaling
scale_inv
=
scale_inv
.
flatten
()[:
reduce
(
operator
.
mul
,
rowwise_scale_inv_shape
,
1
)].
reshape
(
rowwise_scale_inv_shape
)
)
if
scaling_mode
==
ScalingMode
.
NVTE_MXFP8_1D_SCALING
:
if
is_2x
:
scale_inv
=
scale_inv
.
flatten
()[
colwise_scale_inv
=
colwise_scale_inv
.
flatten
()[
:
reduce
(
operator
.
mul
,
rowwise_scale_inv_shape
)
:
reduce
(
operator
.
mul
,
colwise_scale_inv_shape
,
1
)
].
reshape
(
rowwise_scale_inv_shape
)
].
reshape
(
colwise_scale_inv_shape
)
if
is_2x
:
colwise_scale_inv
=
colwise_scale_inv
.
flatten
()[
:
reduce
(
operator
.
mul
,
colwise_scale_inv_shape
)
].
reshape
(
colwise_scale_inv_shape
)
return
(
return
(
out
,
out
,
colwise_out
,
colwise_out
,
...
@@ -364,6 +386,8 @@ class NormFwdPrimitive(BasePrimitive):
...
@@ -364,6 +386,8 @@ class NormFwdPrimitive(BasePrimitive):
del
zero_centered_gamma
,
epsilon
,
out_dtype
,
result_infos
del
zero_centered_gamma
,
epsilon
,
out_dtype
,
result_infos
del
scale_dtype
,
scale_shapes
,
is_outer
del
scale_dtype
,
scale_shapes
,
is_outer
x_spec
=
get_padded_spec
(
arg_infos
[
0
])
x_spec
=
get_padded_spec
(
arg_infos
[
0
])
scale_spec
=
get_padded_spec
(
arg_infos
[
1
])
out_spec
=
(
*
x_spec
[:
-
1
],
None
)
if
x_spec
[
-
1
]
is
not
None
:
if
x_spec
[
-
1
]
is
not
None
:
warnings
.
warn
(
warnings
.
warn
(
f
"Does not support to shard hidden dim in
{
NormFwdPrimitive
.
name
}
! "
f
"Does not support to shard hidden dim in
{
NormFwdPrimitive
.
name
}
! "
...
@@ -371,34 +395,27 @@ class NormFwdPrimitive(BasePrimitive):
...
@@ -371,34 +395,27 @@ class NormFwdPrimitive(BasePrimitive):
"and hurt performance."
"and hurt performance."
)
)
out_sharding
=
NamedSharding
(
out_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
out_spec
),
desc
=
"NormFwdPrimitive.out"
)
mesh
,
PartitionSpec
(
*
x_spec
[:
-
1
],
None
),
desc
=
"NormFwdPrimitive.out"
colwise_out_spec
=
out_spec
if
is_2x
else
(
None
,)
colwise_out_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
colwise_out_spec
),
desc
=
"NormFwdPrimitive.colwise_out"
)
)
if
is_2x
:
colwise_out_sharding
=
out_sharding
.
duplicate_with_new_description
(
"NormFwdPrimitive.colwise_out"
)
else
:
colwise_out_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
None
),
desc
=
"NormFwdPrimitive.colwise_out"
)
rsigma_sharding
=
NamedSharding
(
rsigma_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
x_spec
[:
-
1
]),
desc
=
"NormFwdPrimitive.rsigma"
mesh
,
PartitionSpec
(
*
x_spec
[:
-
1
]),
desc
=
"NormFwdPrimitive.rsigma"
)
)
mu_sharding
=
rsigma_sharding
.
duplicate_with_new_description
(
"NormFwdPrimitive.mu"
)
mu_spec
=
x_spec
[:
-
1
]
if
norm_type
==
NVTE_Norm_Type
.
LayerNorm
else
(
None
,)
if
norm_type
==
NVTE_Norm_Type
.
RMSNorm
:
mu_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
mu_spec
),
desc
=
"NormFwdPrimitive.mu"
)
mu_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
None
),
desc
=
"NormFwdPrimitive.mu"
)
scale_inv_spec
=
amax_spec
=
(
None
,)
if
scaling_mode
==
ScalingMode
.
DELAYED_TENSOR_SCALING
.
value
:
scale_inv_spec
=
amax_spec
=
scale_spec
elif
scaling_mode
==
ScalingMode
.
MXFP8_1D_SCALING
.
value
:
scale_inv_spec
=
out_spec
scale_inv_sharding
=
NamedSharding
(
scale_inv_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
get_padded_spec
(
arg_infos
[
1
])
),
desc
=
"NormFwdPrimitive.scale_inv"
mesh
,
PartitionSpec
(
*
scale_inv_spec
),
desc
=
"NormFwdPrimitive.scale_inv"
)
)
if
scaling_mode
==
ScalingMode
.
NVTE_MXFP8_1D_SCALING
:
amax_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
amax_spec
),
desc
=
"NormFwdPrimitive.amax"
)
scale_inv_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
x_spec
),
desc
=
"NormFwdPrimitive.scale_inv"
)
amax_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
None
),
desc
=
"NormFwdPrimitive.amax"
)
output
=
(
output
=
(
out_sharding
,
out_sharding
,
colwise_out_sharding
,
colwise_out_sharding
,
...
@@ -427,8 +444,11 @@ class NormFwdPrimitive(BasePrimitive):
...
@@ -427,8 +444,11 @@ class NormFwdPrimitive(BasePrimitive):
):
):
del
result_infos
,
is_outer
del
result_infos
,
is_outer
x_spec
=
get_padded_spec
(
arg_infos
[
0
])
x_spec
=
get_padded_spec
(
arg_infos
[
0
])
scale_spec
=
get_padded_spec
(
arg_infos
[
1
])
g_spec
=
get_padded_spec
(
arg_infos
[
2
])
g_spec
=
get_padded_spec
(
arg_infos
[
2
])
b_spec
=
get_padded_spec
(
arg_infos
[
3
])
b_spec
=
get_padded_spec
(
arg_infos
[
3
])
out_spec
=
(
*
x_spec
[:
-
1
],
None
)
if
x_spec
[
-
1
]
is
not
None
:
if
x_spec
[
-
1
]
is
not
None
:
warnings
.
warn
(
warnings
.
warn
(
f
"Does not support to shard hidden dim in
{
NormFwdPrimitive
.
name
}
! "
f
"Does not support to shard hidden dim in
{
NormFwdPrimitive
.
name
}
! "
...
@@ -445,43 +465,30 @@ class NormFwdPrimitive(BasePrimitive):
...
@@ -445,43 +465,30 @@ class NormFwdPrimitive(BasePrimitive):
f
"
{
NormFwdPrimitive
.
name
}
does not support sharding of parameter beta "
f
"
{
NormFwdPrimitive
.
name
}
does not support sharding of parameter beta "
"Enforcing no sharding of parameters hidden dim! "
"Enforcing no sharding of parameters hidden dim! "
)
)
x_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
x_spec
[:
-
1
],
None
),
desc
=
"NormFwdPrimitive.x"
)
g_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
None
),
desc
=
"NormFwdPrimitive.gamma"
)
b_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
None
),
desc
=
"NormFwdPrimitive.beta"
)
out_sharding
=
x_sharding
.
duplicate_with_new_description
(
"NormFwdPrimitive.out"
)
if
is_2x
:
colwise_out_sharding
=
out_sharding
.
duplicate_with_new_description
(
"NormFwdPrimitive.colwise_out"
)
else
:
colwise_out_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
None
),
desc
=
"NormFwdPrimitive.colwise_out"
)
out_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
out_spec
),
desc
=
"NormFwdPrimitive.out"
)
colwise_out_spec
=
out_spec
if
is_2x
else
(
None
,)
colwise_out_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
colwise_out_spec
),
desc
=
"NormFwdPrimitive.colwise_out"
)
rsigma_sharding
=
NamedSharding
(
rsigma_sharding
=
NamedSharding
(
mesh
,
mesh
,
PartitionSpec
(
*
x_spec
[:
-
1
]),
desc
=
"NormFwdPrimitive.rsigma"
PartitionSpec
(
*
get_padded_spec
(
arg_infos
[
0
])[:
-
1
]),
desc
=
"NormFwdPrimitive.rsigma"
,
)
)
mu_sharding
=
rsigma_sharding
.
duplicate_with_new_description
(
"NormFwdPrimitive.mu"
)
mu_spec
=
x_spec
[:
-
1
]
if
norm_type
==
NVTE_Norm_Type
.
LayerNorm
else
(
None
,)
if
norm_type
==
NVTE_Norm_Type
.
RMSNorm
:
mu_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
mu_spec
),
desc
=
"NormFwdPrimitive.mu"
)
mu_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
None
),
desc
=
"NormFwdPrimitive.mu"
)
scale_sharding
=
NamedSharding
(
scale_inv_spec
=
amax_spec
=
(
None
,)
mesh
,
PartitionSpec
(
*
get_padded_spec
(
arg_infos
[
1
])),
desc
=
"NormFwdPrimitive.scale"
if
scaling_mode
==
ScalingMode
.
DELAYED_TENSOR_SCALING
.
value
:
)
scale_inv_spec
=
amax_spec
=
scale_spec
scale_inv_sharding
=
scale_sharding
.
duplicate_with_new_description
(
elif
scaling_mode
==
ScalingMode
.
MXFP8_1D_SCALING
.
value
:
"NormFwdPrimitive.scale_inv"
scale_inv_spec
=
out_spec
scale_inv_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
scale_inv_spec
),
desc
=
"NormFwdPrimitive.scale_inv"
)
)
amax_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
None
),
desc
=
"NormFwdPrimitive.amax"
)
amax_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
amax_spec
),
desc
=
"NormFwdPrimitive.amax"
)
if
scaling_mode
==
ScalingMode
.
NVTE_MXFP8_1D_SCALING
:
scale_inv_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
x_spec
),
desc
=
"NormFwdPrimitive.scale_inv"
)
arg_shardings
=
(
x_sharding
,
scale_sharding
,
g_sharding
,
b_sharding
)
arg_shardings
=
tuple
(
arg_i
.
sharding
for
arg_i
in
arg_infos
)
out_shardings
=
(
out_shardings
=
(
out_sharding
,
out_sharding
,
colwise_out_sharding
,
colwise_out_sharding
,
...
@@ -517,7 +524,7 @@ class NormFwdPrimitive(BasePrimitive):
...
@@ -517,7 +524,7 @@ class NormFwdPrimitive(BasePrimitive):
scale_shapes
=
scale_shapes
,
scale_shapes
=
scale_shapes
,
is_outer
=
True
,
is_outer
=
True
,
)
)
if
scaling_mode
==
ScalingMode
.
NVTE_
DELAYED_TENSOR_SCALING
:
if
scaling_mode
==
ScalingMode
.
DELAYED_TENSOR_SCALING
.
value
:
global_updated_amax
=
all_reduce_max_along_all_axes_except_PP
(
local_amax
,
mesh
)
global_updated_amax
=
all_reduce_max_along_all_axes_except_PP
(
local_amax
,
mesh
)
else
:
else
:
global_updated_amax
=
local_amax
global_updated_amax
=
local_amax
...
@@ -534,6 +541,57 @@ class NormFwdPrimitive(BasePrimitive):
...
@@ -534,6 +541,57 @@ class NormFwdPrimitive(BasePrimitive):
return
mesh
,
sharded_impl
,
out_shardings
,
arg_shardings
return
mesh
,
sharded_impl
,
out_shardings
,
arg_shardings
@
staticmethod
def
shardy_sharding_rule
(
norm_type
,
zero_centered_gamma
,
epsilon
,
out_dtype
,
scaling_mode
,
is_2x
,
scale_dtype
,
scale_shapes
,
is_outer
,
mesh
,
value_types
,
result_types
,
):
del
(
zero_centered_gamma
,
epsilon
,
out_dtype
,
scale_dtype
,
scale_shapes
,
is_outer
,
mesh
,
result_types
,
)
scale_rules
=
ScalingMode
(
scaling_mode
).
get_shardy_sharding_rules
(
len
(
value_types
[
0
].
shape
),
unique_var
=
"i"
,
flatten_axis
=-
1
)
x_axes
=
scale_rules
.
input_spec
out
=
x_axes
[:
-
1
]
+
(
"k"
,)
colwise_out
=
out
if
is_2x
else
(
"…4"
,)
rsigma
=
x_axes
[:
-
1
]
mu
=
(
"…5"
,)
if
norm_type
==
NVTE_Norm_Type
.
RMSNorm
else
rsigma
amax
=
(
"…6"
,)
return
SdyShardingRule
(
(
x_axes
,
(
"…1"
,),
(
"…2"
,),
(
"…3"
,)),
(
out
,
colwise_out
,
scale_rules
.
rowwise_rule
,
scale_rules
.
colwise_rule
,
amax
,
mu
,
rsigma
,
),
**
scale_rules
.
factor_sizes
,
)
register_primitive
(
NormFwdPrimitive
)
register_primitive
(
NormFwdPrimitive
)
...
@@ -737,6 +795,11 @@ class NormBwdPrimitive(BasePrimitive):
...
@@ -737,6 +795,11 @@ class NormBwdPrimitive(BasePrimitive):
return
mesh
,
sharded_impl
,
out_shardings
,
arg_shardings
return
mesh
,
sharded_impl
,
out_shardings
,
arg_shardings
@
staticmethod
def
shardy_sharding_rule
(
*
args
):
del
args
return
"...0, ...1 i, ...2, ...3, ...4 -> ...1 j, k, l"
register_primitive
(
NormBwdPrimitive
)
register_primitive
(
NormBwdPrimitive
)
...
@@ -746,6 +809,10 @@ def _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer=None)
...
@@ -746,6 +809,10 @@ def _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer=None)
JAX native layernorm implementation
JAX native layernorm implementation
"""
"""
x_
=
jnp
.
asarray
(
x
,
jnp
.
float32
)
x_
=
jnp
.
asarray
(
x
,
jnp
.
float32
)
if
not
is_norm_zero_centered_gamma_in_weight_dtype
(
quantizer
.
scaling_mode
if
quantizer
else
ScalingMode
.
NO_SCALING
):
gamma
=
gamma
.
astype
(
jnp
.
float32
)
mean
=
jnp
.
mean
(
x_
,
axis
=-
1
,
keepdims
=
True
)
mean
=
jnp
.
mean
(
x_
,
axis
=-
1
,
keepdims
=
True
)
var
=
jnp
.
mean
(
jnp
.
square
(
x_
-
mean
),
axis
=-
1
,
keepdims
=
True
)
var
=
jnp
.
mean
(
jnp
.
square
(
x_
-
mean
),
axis
=-
1
,
keepdims
=
True
)
rsigma
=
jax
.
lax
.
rsqrt
(
var
+
epsilon
)
rsigma
=
jax
.
lax
.
rsqrt
(
var
+
epsilon
)
...
@@ -767,6 +834,10 @@ def _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer=None):
...
@@ -767,6 +834,10 @@ def _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer=None):
JAX native rmsnorm implementation
JAX native rmsnorm implementation
"""
"""
x_
=
jnp
.
asarray
(
x
,
jnp
.
float32
)
x_
=
jnp
.
asarray
(
x
,
jnp
.
float32
)
if
not
is_norm_zero_centered_gamma_in_weight_dtype
(
quantizer
.
scaling_mode
if
quantizer
else
ScalingMode
.
NO_SCALING
):
gamma
=
gamma
.
astype
(
jnp
.
float32
)
var
=
jnp
.
mean
(
jnp
.
square
(
x_
),
axis
=-
1
,
keepdims
=
True
)
var
=
jnp
.
mean
(
jnp
.
square
(
x_
),
axis
=-
1
,
keepdims
=
True
)
rsigma
=
jax
.
lax
.
rsqrt
(
var
+
epsilon
)
rsigma
=
jax
.
lax
.
rsqrt
(
var
+
epsilon
)
normed_input
=
x_
*
rsigma
normed_input
=
x_
*
rsigma
...
@@ -816,7 +887,7 @@ def layernorm_fwd(
...
@@ -816,7 +887,7 @@ def layernorm_fwd(
return
_jax_layernorm
(
x
,
gamma
,
beta
,
zero_centered_gamma
,
epsilon
,
quantizer
)
return
_jax_layernorm
(
x
,
gamma
,
beta
,
zero_centered_gamma
,
epsilon
,
quantizer
)
# TE/common does not support normalization with colwise only quantization yet
# TE/common does not support normalization with colwise only quantization yet
if
quantizer
is
not
None
and
quantizer
.
q_
axis
==
Quantize
Axis
.
COLWISE
:
if
quantizer
is
not
None
and
quantizer
.
q_
layout
==
Quantize
Layout
.
COLWISE
:
return
_jax_layernorm
(
x
,
gamma
,
beta
,
zero_centered_gamma
,
epsilon
,
quantizer
)
return
_jax_layernorm
(
x
,
gamma
,
beta
,
zero_centered_gamma
,
epsilon
,
quantizer
)
scale
=
(
scale
=
(
...
@@ -824,7 +895,6 @@ def layernorm_fwd(
...
@@ -824,7 +895,6 @@ def layernorm_fwd(
if
isinstance
(
quantizer
,
DelayedScaleQuantizer
)
if
isinstance
(
quantizer
,
DelayedScaleQuantizer
)
else
jnp
.
ones
((
1
,),
dtype
=
jnp
.
float32
)
else
jnp
.
ones
((
1
,),
dtype
=
jnp
.
float32
)
)
)
if
quantizer
is
None
:
if
quantizer
is
None
:
output
,
_
,
_
,
_
,
_
,
mu
,
rsigma
=
NormFwdPrimitive
.
outer_primitive
.
bind
(
output
,
_
,
_
,
_
,
_
,
mu
,
rsigma
=
NormFwdPrimitive
.
outer_primitive
.
bind
(
x
,
x
,
...
@@ -835,7 +905,7 @@ def layernorm_fwd(
...
@@ -835,7 +905,7 @@ def layernorm_fwd(
zero_centered_gamma
=
zero_centered_gamma
,
zero_centered_gamma
=
zero_centered_gamma
,
epsilon
=
epsilon
,
epsilon
=
epsilon
,
out_dtype
=
x
.
dtype
,
out_dtype
=
x
.
dtype
,
scaling_mode
=
ScalingMode
.
N
VTE_DELAYED_TENSOR
_SCALING
,
scaling_mode
=
ScalingMode
.
N
O
_SCALING
.
value
,
is_2x
=
False
,
is_2x
=
False
,
scale_dtype
=
jnp
.
float32
,
scale_dtype
=
jnp
.
float32
,
scale_shapes
=
((
1
,),
(
1
,)),
scale_shapes
=
((
1
,),
(
1
,)),
...
@@ -845,7 +915,7 @@ def layernorm_fwd(
...
@@ -845,7 +915,7 @@ def layernorm_fwd(
is_2x2x
=
quantizer
.
is_2x2x
()
is_2x2x
=
quantizer
.
is_2x2x
()
# TE/common normalization doesn't support 2x delayed scaling
# TE/common normalization doesn't support 2x delayed scaling
if
quantizer
.
is_2x2x
()
and
quantizer
.
scaling_mode
==
ScalingMode
.
NVTE_
DELAYED_TENSOR_SCALING
:
if
quantizer
.
is_2x2x
()
and
quantizer
.
scaling_mode
==
ScalingMode
.
DELAYED_TENSOR_SCALING
:
is_2x2x
=
False
is_2x2x
=
False
(
(
rowwise_casted_output
,
rowwise_casted_output
,
...
@@ -864,7 +934,7 @@ def layernorm_fwd(
...
@@ -864,7 +934,7 @@ def layernorm_fwd(
zero_centered_gamma
=
zero_centered_gamma
,
zero_centered_gamma
=
zero_centered_gamma
,
epsilon
=
epsilon
,
epsilon
=
epsilon
,
out_dtype
=
quantizer
.
q_dtype
,
out_dtype
=
quantizer
.
q_dtype
,
scaling_mode
=
quantizer
.
scaling_mode
,
scaling_mode
=
quantizer
.
scaling_mode
.
value
,
is_2x
=
is_2x2x
,
is_2x
=
is_2x2x
,
scale_dtype
=
quantizer
.
get_scale_dtype
(),
scale_dtype
=
quantizer
.
get_scale_dtype
(),
scale_shapes
=
quantizer
.
get_scale_shapes
(
x
.
shape
),
scale_shapes
=
quantizer
.
get_scale_shapes
(
x
.
shape
),
...
@@ -873,7 +943,7 @@ def layernorm_fwd(
...
@@ -873,7 +943,7 @@ def layernorm_fwd(
quantizer
.
update
(
updated_amax
)
quantizer
.
update
(
updated_amax
)
# TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose
# TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose
if
quantizer
.
is_2x2x
()
and
quantizer
.
scaling_mode
==
ScalingMode
.
NVTE_
DELAYED_TENSOR_SCALING
:
if
quantizer
.
is_2x2x
()
and
quantizer
.
scaling_mode
==
ScalingMode
.
DELAYED_TENSOR_SCALING
:
colwise_casted_output
=
jnp
.
transpose
(
colwise_casted_output
=
jnp
.
transpose
(
rowwise_casted_output
,
(
-
1
,
*
range
(
rowwise_casted_output
.
ndim
-
1
))
rowwise_casted_output
,
(
-
1
,
*
range
(
rowwise_casted_output
.
ndim
-
1
))
)
)
...
@@ -882,7 +952,7 @@ def layernorm_fwd(
...
@@ -882,7 +952,7 @@ def layernorm_fwd(
# cuDNN MXFP8 Norm does not support padding but we enforced padded scale inputs for nvte APIs.
# cuDNN MXFP8 Norm does not support padding but we enforced padded scale inputs for nvte APIs.
# So here we need to slice out the zero tail and reshape it to the unpadded scale shape.
# So here we need to slice out the zero tail and reshape it to the unpadded scale shape.
# The ScaledTensorFactory takes care of padding when creating the ScaledTensor
# The ScaledTensorFactory takes care of padding when creating the ScaledTensor
if
quantizer
.
scaling_mode
==
ScalingMode
.
NVTE_
MXFP8_1D_SCALING
:
if
quantizer
.
scaling_mode
==
ScalingMode
.
MXFP8_1D_SCALING
:
rowwise_unpadded_shape
,
colwise_unpadded_shape
=
quantizer
.
get_scale_shapes
(
rowwise_unpadded_shape
,
colwise_unpadded_shape
=
quantizer
.
get_scale_shapes
(
x
.
shape
,
is_padded
=
False
x
.
shape
,
is_padded
=
False
)
)
...
@@ -900,8 +970,8 @@ def layernorm_fwd(
...
@@ -900,8 +970,8 @@ def layernorm_fwd(
colwise_scale_inv
=
colwise_scale_inv
,
colwise_scale_inv
=
colwise_scale_inv
,
scaling_mode
=
quantizer
.
scaling_mode
,
scaling_mode
=
quantizer
.
scaling_mode
,
dq_dtype
=
x
.
dtype
,
dq_dtype
=
x
.
dtype
,
q_
axis
=
quantizer
.
q_
axis
,
q_
layout
=
quantizer
.
q_
layout
,
layout
=
quantizer
.
get_layout
(),
data_
layout
=
quantizer
.
get_
data_
layout
(),
)
)
return
scaled_tensor
,
mu
,
rsigma
return
scaled_tensor
,
mu
,
rsigma
...
@@ -997,7 +1067,7 @@ def rmsnorm_fwd(
...
@@ -997,7 +1067,7 @@ def rmsnorm_fwd(
return
_jax_rmsnorm
(
x
,
gamma
,
zero_centered_gamma
,
epsilon
,
quantizer
)
return
_jax_rmsnorm
(
x
,
gamma
,
zero_centered_gamma
,
epsilon
,
quantizer
)
# TE/common does not support normalization with colwise only quantization yet
# TE/common does not support normalization with colwise only quantization yet
if
quantizer
is
not
None
and
quantizer
.
q_
axis
==
Quantize
Axis
.
COLWISE
:
if
quantizer
is
not
None
and
quantizer
.
q_
layout
==
Quantize
Layout
.
COLWISE
:
return
_jax_rmsnorm
(
x
,
gamma
,
zero_centered_gamma
,
epsilon
,
quantizer
)
return
_jax_rmsnorm
(
x
,
gamma
,
zero_centered_gamma
,
epsilon
,
quantizer
)
scale
=
(
scale
=
(
...
@@ -1017,7 +1087,7 @@ def rmsnorm_fwd(
...
@@ -1017,7 +1087,7 @@ def rmsnorm_fwd(
zero_centered_gamma
=
zero_centered_gamma
,
zero_centered_gamma
=
zero_centered_gamma
,
epsilon
=
epsilon
,
epsilon
=
epsilon
,
out_dtype
=
x
.
dtype
,
out_dtype
=
x
.
dtype
,
scaling_mode
=
ScalingMode
.
N
VTE_DELAYED_TENSOR
_SCALING
,
scaling_mode
=
ScalingMode
.
N
O
_SCALING
.
value
,
is_2x
=
False
,
is_2x
=
False
,
scale_dtype
=
jnp
.
float32
,
scale_dtype
=
jnp
.
float32
,
scale_shapes
=
((),
()),
scale_shapes
=
((),
()),
...
@@ -1027,7 +1097,7 @@ def rmsnorm_fwd(
...
@@ -1027,7 +1097,7 @@ def rmsnorm_fwd(
is_2x2x
=
quantizer
.
is_2x2x
()
is_2x2x
=
quantizer
.
is_2x2x
()
# TE/common normalization doesn't support 2x delayed scaling
# TE/common normalization doesn't support 2x delayed scaling
if
quantizer
.
is_2x2x
()
and
quantizer
.
scaling_mode
==
ScalingMode
.
NVTE_
DELAYED_TENSOR_SCALING
:
if
quantizer
.
is_2x2x
()
and
quantizer
.
scaling_mode
==
ScalingMode
.
DELAYED_TENSOR_SCALING
:
is_2x2x
=
False
is_2x2x
=
False
(
(
rowwise_casted_output
,
rowwise_casted_output
,
...
@@ -1046,7 +1116,7 @@ def rmsnorm_fwd(
...
@@ -1046,7 +1116,7 @@ def rmsnorm_fwd(
zero_centered_gamma
=
zero_centered_gamma
,
zero_centered_gamma
=
zero_centered_gamma
,
epsilon
=
epsilon
,
epsilon
=
epsilon
,
out_dtype
=
quantizer
.
q_dtype
,
out_dtype
=
quantizer
.
q_dtype
,
scaling_mode
=
quantizer
.
scaling_mode
,
scaling_mode
=
quantizer
.
scaling_mode
.
value
,
is_2x
=
is_2x2x
,
is_2x
=
is_2x2x
,
scale_dtype
=
quantizer
.
get_scale_dtype
(),
scale_dtype
=
quantizer
.
get_scale_dtype
(),
scale_shapes
=
quantizer
.
get_scale_shapes
(
x
.
shape
),
scale_shapes
=
quantizer
.
get_scale_shapes
(
x
.
shape
),
...
@@ -1055,7 +1125,7 @@ def rmsnorm_fwd(
...
@@ -1055,7 +1125,7 @@ def rmsnorm_fwd(
quantizer
.
update
(
updated_amax
)
quantizer
.
update
(
updated_amax
)
# TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose
# TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose
if
quantizer
.
is_2x2x
()
and
quantizer
.
scaling_mode
==
ScalingMode
.
NVTE_
DELAYED_TENSOR_SCALING
:
if
quantizer
.
is_2x2x
()
and
quantizer
.
scaling_mode
==
ScalingMode
.
DELAYED_TENSOR_SCALING
:
colwise_casted_output
=
jnp
.
transpose
(
colwise_casted_output
=
jnp
.
transpose
(
rowwise_casted_output
,
(
-
1
,
*
range
(
rowwise_casted_output
.
ndim
-
1
))
rowwise_casted_output
,
(
-
1
,
*
range
(
rowwise_casted_output
.
ndim
-
1
))
)
)
...
@@ -1064,7 +1134,7 @@ def rmsnorm_fwd(
...
@@ -1064,7 +1134,7 @@ def rmsnorm_fwd(
# cuDNN MXFP8 Norm does not support padding but we enforced padded scale inputs for nvte APIs.
# cuDNN MXFP8 Norm does not support padding but we enforced padded scale inputs for nvte APIs.
# So here we need to slice out the zero tail and reshape it to the unpadded scale shape.
# So here we need to slice out the zero tail and reshape it to the unpadded scale shape.
# The ScaledTensorFactory takes care of padding when creating the ScaledTensor
# The ScaledTensorFactory takes care of padding when creating the ScaledTensor
if
quantizer
.
scaling_mode
==
ScalingMode
.
NVTE_
MXFP8_1D_SCALING
:
if
quantizer
.
scaling_mode
==
ScalingMode
.
MXFP8_1D_SCALING
:
rowwise_unpadded_shape
,
colwise_unpadded_shape
=
quantizer
.
get_scale_shapes
(
rowwise_unpadded_shape
,
colwise_unpadded_shape
=
quantizer
.
get_scale_shapes
(
x
.
shape
,
is_padded
=
False
x
.
shape
,
is_padded
=
False
)
)
...
@@ -1082,8 +1152,8 @@ def rmsnorm_fwd(
...
@@ -1082,8 +1152,8 @@ def rmsnorm_fwd(
colwise_scale_inv
=
colwise_scale_inv
,
colwise_scale_inv
=
colwise_scale_inv
,
scaling_mode
=
quantizer
.
scaling_mode
,
scaling_mode
=
quantizer
.
scaling_mode
,
dq_dtype
=
x
.
dtype
,
dq_dtype
=
x
.
dtype
,
q_
axis
=
quantizer
.
q_
axis
,
q_
layout
=
quantizer
.
q_
layout
,
layout
=
quantizer
.
get_layout
(),
data_
layout
=
quantizer
.
get_
data_
layout
(),
)
)
return
scaled_tensor
,
rsigma
return
scaled_tensor
,
rsigma
...
...
transformer_engine/jax/cpp_extensions/quantization.py
View file @
ab3e5a92
...
@@ -2,12 +2,15 @@
...
@@ -2,12 +2,15 @@
#
#
# See LICENSE for license information.
# See LICENSE for license information.
"""JAX/TE custom ops for quantization"""
"""JAX/TE custom ops for quantization"""
import
operator
from
functools
import
reduce
from
typing
import
Tuple
,
Optional
from
typing
import
Tuple
,
Optional
from
packaging
import
version
from
packaging
import
version
import
jax
import
jax
import
jax.numpy
as
jnp
import
jax.numpy
as
jnp
from
jax
import
dtypes
from
jax
import
dtypes
from
jax.experimental.custom_partitioning
import
SdyShardingRule
from
jax.sharding
import
PartitionSpec
from
jax.sharding
import
PartitionSpec
import
transformer_engine_jax
import
transformer_engine_jax
...
@@ -24,7 +27,7 @@ from .misc import (
...
@@ -24,7 +27,7 @@ from .misc import (
)
)
from
..sharding
import
all_reduce_max_along_all_axes_except_PP
,
all_reduce_sum_along_dp_fsdp
from
..sharding
import
all_reduce_max_along_all_axes_except_PP
,
all_reduce_sum_along_dp_fsdp
from
..quantize
import
ScaledTensor2x
,
ScaledTensor
,
ScaledTensorFactory
from
..quantize
import
ScaledTensor2x
,
ScaledTensor
,
ScaledTensorFactory
from
..quantize
import
Quantizer
,
Quantize
Axis
,
DelayedScaleQuantizer
,
ScalingMode
from
..quantize
import
Quantizer
,
Quantize
Layout
,
DelayedScaleQuantizer
,
ScalingMode
if
version
.
parse
(
jax
.
__version__
)
>=
version
.
parse
(
"0.5.0"
):
if
version
.
parse
(
jax
.
__version__
)
>=
version
.
parse
(
"0.5.0"
):
from
jax
import
ffi
# pylint: disable=ungrouped-imports
from
jax
import
ffi
# pylint: disable=ungrouped-imports
...
@@ -50,7 +53,8 @@ class DBiasQuantizePrimitive(BasePrimitive):
...
@@ -50,7 +53,8 @@ class DBiasQuantizePrimitive(BasePrimitive):
6
,
6
,
7
,
7
,
8
,
8
,
)
# out_dtype, scaling_mode, q_axis, scale_dtype, scale_shapes, is_dbias, is_outer
9
,
)
# out_dtype, scaling_mode, q_layout, flatten_axis, scale_dtype, scale_shapes, is_dbias, is_outer
inner_primitive
=
None
inner_primitive
=
None
outer_primitive
=
None
outer_primitive
=
None
...
@@ -61,7 +65,8 @@ class DBiasQuantizePrimitive(BasePrimitive):
...
@@ -61,7 +65,8 @@ class DBiasQuantizePrimitive(BasePrimitive):
*
,
*
,
out_dtype
,
out_dtype
,
scaling_mode
,
scaling_mode
,
q_axis
,
q_layout
,
flatten_axis
,
scale_dtype
,
scale_dtype
,
scale_shapes
,
scale_shapes
,
is_dbias
,
is_dbias
,
...
@@ -73,49 +78,56 @@ class DBiasQuantizePrimitive(BasePrimitive):
...
@@ -73,49 +78,56 @@ class DBiasQuantizePrimitive(BasePrimitive):
del
scale_shapes
del
scale_shapes
dtype
=
dtypes
.
canonicalize_dtype
(
x_aval
.
dtype
)
dtype
=
dtypes
.
canonicalize_dtype
(
x_aval
.
dtype
)
assert
dtype
in
[
jnp
.
float32
,
jnp
.
float16
,
jnp
.
bfloat16
]
assert
dtype
in
[
jnp
.
float32
,
jnp
.
float16
,
jnp
.
bfloat16
]
out_shape
=
x_aval
.
shape
assert
scale_aval
is
None
or
scale_aval
.
dtype
==
jnp
.
float32
assert
scale_aval
is
None
or
scale_aval
.
dtype
==
jnp
.
float32
rowwise_out_aval
=
jax
.
core
.
ShapedArray
(
shape
=
(
1
,),
dtype
=
out_dtype
)
if
q_layout
in
(
QuantizeLayout
.
ROWWISE
.
value
,
QuantizeLayout
.
ROWWISE_COLWISE
.
value
):
rowwise_out_shape
=
out_shape
if
q_axis
in
(
QuantizeAxis
.
ROWWISE
.
value
,
QuantizeAxis
.
ROWWISE_COLWISE
.
value
):
else
:
rowwise_out_aval
=
x_aval
.
update
(
shape
=
x_aval
.
shape
,
dtype
=
out_dtype
)
rowwise_out_shape
=
(
1
,)
rowwise_out_aval
=
jax
.
core
.
ShapedArray
(
shape
=
rowwise_out_shape
,
dtype
=
out_dtype
)
updated_amax_aval
=
jax
.
core
.
ShapedArray
(
shape
=
(
1
,),
dtype
=
jnp
.
float32
)
updated_amax_aval
=
jax
.
core
.
ShapedArray
(
shape
=
(
1
,),
dtype
=
jnp
.
float32
)
rowwise_scale_inv_shape
,
colwise_scale_inv_shape
=
ScalingMode
(
rowwise_scale_inv_shape
,
colwise_scale_inv_shape
=
ScalingMode
(
scaling_mode
scaling_mode
).
get_scale_shape_2x
(
x_aval
.
shape
,
is_padded
=
not
is_outer
)
).
get_scale_shape_2x
(
x_aval
.
shape
,
is_padded
=
not
is_outer
,
flatten_axis
=
flatten_axis
)
if
q_layout
in
(
QuantizeLayout
.
COLWISE
.
value
,
QuantizeLayout
.
ROWWISE_COLWISE
.
value
):
if
scaling_mode
==
ScalingMode
.
DELAYED_TENSOR_SCALING
.
value
:
colwise_out_shape
=
multidim_transpose
(
out_shape
,
transpose_axis
=
flatten_axis
)
else
:
colwise_out_shape
=
out_shape
else
:
colwise_out_shape
=
(
1
,)
colwise_scale_inv_shape
=
(
1
,)
colwise_out_aval
=
jax
.
core
.
ShapedArray
(
shape
=
colwise_out_shape
,
dtype
=
out_dtype
)
scale_inv_aval
=
jax
.
core
.
ShapedArray
(
shape
=
rowwise_scale_inv_shape
,
dtype
=
scale_dtype
)
scale_inv_aval
=
jax
.
core
.
ShapedArray
(
shape
=
rowwise_scale_inv_shape
,
dtype
=
scale_dtype
)
colwise_scale_inv_aval
=
jax
.
core
.
ShapedArray
(
colwise_out_aval
=
jax
.
core
.
ShapedArray
(
shape
=
(
1
,),
dtype
=
out_dtype
)
shape
=
colwise_scale_inv_shape
,
dtype
=
scale_dtype
colwise_scale_inv_aval
=
jax
.
core
.
ShapedArray
(
shape
=
(
1
,),
dtype
=
scale_dtype
)
)
dbias_aval
=
jax
.
core
.
ShapedArray
(
shape
=
(
1
,),
dtype
=
jnp
.
float32
)
wkspace_aval
=
jax
.
core
.
ShapedArray
(
shape
=
(
1
,),
dtype
=
jnp
.
float32
)
if
q_axis
in
(
QuantizeAxis
.
COLWISE
.
value
,
QuantizeAxis
.
ROWWISE_COLWISE
.
value
):
t_shape
=
multidim_transpose
(
x_aval
.
shape
)
if
scaling_mode
==
ScalingMode
.
NVTE_MXFP8_1D_SCALING
.
value
:
# Don't transpose output for MXFP8
t_shape
=
x_aval
.
shape
colwise_out_aval
=
x_aval
.
update
(
shape
=
t_shape
,
dtype
=
out_dtype
)
colwise_scale_inv_aval
=
jax
.
core
.
ShapedArray
(
shape
=
colwise_scale_inv_shape
,
dtype
=
scale_dtype
)
if
is_dbias
:
if
is_dbias
:
gi_hidden_size
=
x_aval
.
shape
[
-
1
]
dbias_shape
=
x_aval
.
shape
[
flatten_axis
:]
dbias_shape
=
(
gi_hidden_size
,)
gi_hidden_size
=
reduce
(
operator
.
mul
,
x_aval
.
shape
[
flatten_axis
:],
1
)
dbias_aval
=
x_aval
.
update
(
shape
=
dbias_shape
,
dtype
=
dtype
)
(
wkspace_info
,)
=
transformer_engine_jax
.
get_dbias_quantize_workspace_sizes
(
(
wkspace_info
,)
=
transformer_engine_jax
.
get_dbias_quantize_workspace_sizes
(
x_aval
.
size
//
gi_hidden_size
,
x_aval
.
size
//
gi_hidden_size
,
gi_hidden_size
,
gi_hidden_size
,
jax_dtype_to_te_dtype
(
x_aval
.
dtype
),
jax_dtype_to_te_dtype
(
x_aval
.
dtype
),
jax_dtype_to_te_dtype
(
out_dtype
),
jax_dtype_to_te_dtype
(
out_dtype
),
scaling_mode
,
QuantizeLayout
(
q_layout
),
# For now until we have auto-decoding for QuantizeLayout enum
)
)
wkspace_aval
=
x_aval
.
update
(
wkspace_shape
=
wkspace_info
[
0
]
shape
=
wkspace_info
[
0
],
dtype
=
te_dtype_to_jax_dtype
(
wkspace_info
[
1
])
wkspace_dtype
=
te_dtype_to_jax_dtype
(
wkspace_info
[
1
])
)
else
:
dbias_shape
=
(
1
,)
wkspace_shape
=
(
1
,)
wkspace_dtype
=
jnp
.
float32
dbias_aval
=
jax
.
core
.
ShapedArray
(
shape
=
dbias_shape
,
dtype
=
dtype
)
wkspace_aval
=
jax
.
core
.
ShapedArray
(
shape
=
wkspace_shape
,
dtype
=
wkspace_dtype
)
return
(
return
(
rowwise_out_aval
,
rowwise_out_aval
,
...
@@ -151,7 +163,8 @@ class DBiasQuantizePrimitive(BasePrimitive):
...
@@ -151,7 +163,8 @@ class DBiasQuantizePrimitive(BasePrimitive):
*
,
*
,
out_dtype
,
out_dtype
,
scaling_mode
,
scaling_mode
,
q_axis
,
q_layout
,
flatten_axis
,
scale_dtype
,
scale_dtype
,
scale_shapes
,
scale_shapes
,
is_dbias
,
is_dbias
,
...
@@ -168,8 +181,9 @@ class DBiasQuantizePrimitive(BasePrimitive):
...
@@ -168,8 +181,9 @@ class DBiasQuantizePrimitive(BasePrimitive):
ctx
,
ctx
,
x
,
x
,
scale
,
scale
,
scaling_mode
=
scaling_mode
,
scaling_mode
=
scaling_mode
.
value
,
q_axis
=
q_axis
,
q_layout
=
q_layout
,
flatten_axis
=
flatten_axis
,
is_dbias
=
is_dbias
,
is_dbias
=
is_dbias
,
)
)
...
@@ -179,7 +193,8 @@ class DBiasQuantizePrimitive(BasePrimitive):
...
@@ -179,7 +193,8 @@ class DBiasQuantizePrimitive(BasePrimitive):
scale
,
scale
,
out_dtype
,
out_dtype
,
scaling_mode
,
scaling_mode
,
q_axis
,
q_layout
,
flatten_axis
,
scale_dtype
,
scale_dtype
,
scale_shapes
,
scale_shapes
,
is_dbias
,
is_dbias
,
...
@@ -203,7 +218,8 @@ class DBiasQuantizePrimitive(BasePrimitive):
...
@@ -203,7 +218,8 @@ class DBiasQuantizePrimitive(BasePrimitive):
scale
,
scale
,
out_dtype
=
out_dtype
,
out_dtype
=
out_dtype
,
scaling_mode
=
scaling_mode
,
scaling_mode
=
scaling_mode
,
q_axis
=
q_axis
,
q_layout
=
q_layout
,
flatten_axis
=
flatten_axis
,
scale_dtype
=
scale_dtype
,
scale_dtype
=
scale_dtype
,
scale_shapes
=
scale_shapes
,
scale_shapes
=
scale_shapes
,
is_dbias
=
is_dbias
,
is_dbias
=
is_dbias
,
...
@@ -211,16 +227,14 @@ class DBiasQuantizePrimitive(BasePrimitive):
...
@@ -211,16 +227,14 @@ class DBiasQuantizePrimitive(BasePrimitive):
)
)
rowwise_scale_inv_shape
,
colwise_scale_inv_shape
=
ScalingMode
(
rowwise_scale_inv_shape
,
colwise_scale_inv_shape
=
ScalingMode
(
scaling_mode
scaling_mode
).
get_scale_shape_2x
(
x
.
shape
,
is_padded
=
False
)
).
get_scale_shape_2x
(
x
.
shape
,
is_padded
=
False
,
flatten_axis
=
flatten_axis
)
if
scaling_mode
==
ScalingMode
.
NVTE_MXFP8_1D_SCALING
.
value
:
scale_inv
=
jax
.
lax
.
slice
(
if
q_axis
in
(
QuantizeAxis
.
ROWWISE
.
value
,
QuantizeAxis
.
ROWWISE_COLWISE
.
value
):
scale_inv
,
[
0
]
*
len
(
rowwise_scale_inv_shape
),
rowwise_scale_inv_shape
scale_inv
=
jax
.
lax
.
slice
(
)
scale_inv
,
[
0
]
*
len
(
rowwise_scale_inv_shape
),
rowwise_scale_inv_shape
if
q_layout
in
(
QuantizeLayout
.
COLWISE
.
value
,
QuantizeLayout
.
ROWWISE_COLWISE
.
value
):
)
colwise_scale_inv
=
jax
.
lax
.
slice
(
if
q_axis
in
(
QuantizeAxis
.
COLWISE
.
value
,
QuantizeAxis
.
ROWWISE_COLWISE
.
value
):
colwise_scale_inv
,
[
0
]
*
len
(
colwise_scale_inv_shape
),
colwise_scale_inv_shape
colwise_scale_inv
=
jax
.
lax
.
slice
(
)
colwise_scale_inv
,
[
0
]
*
len
(
colwise_scale_inv_shape
),
colwise_scale_inv_shape
)
return
(
return
(
out
,
out
,
colwise_out
,
colwise_out
,
...
@@ -237,7 +251,8 @@ class DBiasQuantizePrimitive(BasePrimitive):
...
@@ -237,7 +251,8 @@ class DBiasQuantizePrimitive(BasePrimitive):
*
,
*
,
out_dtype
,
out_dtype
,
scaling_mode
,
scaling_mode
,
q_axis
,
q_layout
,
flatten_axis
,
scale_dtype
,
scale_dtype
,
scale_shapes
,
scale_shapes
,
is_dbias
,
is_dbias
,
...
@@ -260,7 +275,8 @@ class DBiasQuantizePrimitive(BasePrimitive):
...
@@ -260,7 +275,8 @@ class DBiasQuantizePrimitive(BasePrimitive):
scale
,
scale
,
out_dtype
=
out_dtype
,
out_dtype
=
out_dtype
,
scaling_mode
=
scaling_mode
,
scaling_mode
=
scaling_mode
,
q_axis
=
q_axis
,
q_layout
=
q_layout
,
flatten_axis
=
flatten_axis
,
scale_dtype
=
scale_dtype
,
scale_dtype
=
scale_dtype
,
scale_shapes
=
scale_shapes
,
scale_shapes
=
scale_shapes
,
is_dbias
=
is_dbias
,
is_dbias
=
is_dbias
,
...
@@ -272,7 +288,8 @@ class DBiasQuantizePrimitive(BasePrimitive):
...
@@ -272,7 +288,8 @@ class DBiasQuantizePrimitive(BasePrimitive):
def
infer_sharding_from_operands
(
def
infer_sharding_from_operands
(
out_dtype
,
out_dtype
,
scaling_mode
,
scaling_mode
,
q_axis
,
q_layout
,
flatten_axis
,
scale_dtype
,
scale_dtype
,
scale_shapes
,
scale_shapes
,
is_dbias
,
is_dbias
,
...
@@ -281,16 +298,17 @@ class DBiasQuantizePrimitive(BasePrimitive):
...
@@ -281,16 +298,17 @@ class DBiasQuantizePrimitive(BasePrimitive):
arg_infos
,
arg_infos
,
result_infos
,
result_infos
,
):
):
del
(
out_dtype
,
result_infos
,
scale_dtype
,
scale_shapes
,
is_dbias
,
is_outer
)
# Unused.
del
(
out_dtype
,
result_infos
,
scale_dtype
,
scale_shapes
,
is_outer
)
# Unused.
x_spec
=
get_padded_spec
(
arg_infos
[
0
])
x_spec
=
get_padded_spec
(
arg_infos
[
0
])
scale_spec
=
get_padded_spec
(
arg_infos
[
1
])
out_sharding
=
NamedSharding
(
out_sharding
=
NamedSharding
(
mesh
,
mesh
,
PartitionSpec
(
*
x_spec
[:
-
1
],
x_spec
[
-
1
]
),
PartitionSpec
(
*
x_spec
),
desc
=
"DBiasQuantizePrimitive.out_sharding"
,
desc
=
"DBiasQuantizePrimitive.out_sharding"
,
)
)
if
q_
axis
in
(
Quantize
Axis
.
COLWISE
.
value
,
Quantize
Axis
.
ROWWISE_COLWISE
.
value
):
if
q_
layout
in
(
Quantize
Layout
.
COLWISE
.
value
,
Quantize
Layout
.
ROWWISE_COLWISE
.
value
):
if
scaling_mode
==
ScalingMode
.
NVTE_
DELAYED_TENSOR_SCALING
.
value
:
if
scaling_mode
==
ScalingMode
.
DELAYED_TENSOR_SCALING
.
value
:
colwise_out_spec
=
multidim_transpose
(
x_spec
)
colwise_out_spec
=
multidim_transpose
(
x_spec
,
transpose_axis
=
flatten_axis
)
else
:
else
:
colwise_out_spec
=
x_spec
colwise_out_spec
=
x_spec
else
:
else
:
...
@@ -300,26 +318,35 @@ class DBiasQuantizePrimitive(BasePrimitive):
...
@@ -300,26 +318,35 @@ class DBiasQuantizePrimitive(BasePrimitive):
PartitionSpec
(
*
colwise_out_spec
),
PartitionSpec
(
*
colwise_out_spec
),
desc
=
"DBiasQuantizePrimitive.colwise_out_sharding"
,
desc
=
"DBiasQuantizePrimitive.colwise_out_sharding"
,
)
)
scale_inv_sharding
=
NamedSharding
(
dbias_spec
=
x_spec
[
flatten_axis
:]
if
is_dbias
else
(
None
,)
dbias_sharding
=
NamedSharding
(
mesh
,
mesh
,
PartitionSpec
(
*
get_padded_spec
(
arg_infos
[
1
])
),
PartitionSpec
(
*
dbias_spec
),
desc
=
"DBiasQuantizePrimitive.
scale_
in
v
"
,
desc
=
"DBiasQuantizePrimitive.
dbias_shard
in
g
"
,
)
)
amax_sharding
=
scale_inv_sharding
.
duplicate_with_new_description
(
desc
=
"DBiasQuantizePrimitive.amax_sharding"
scale_inv_spec
=
amax_spec
=
colwise_scale_inv_spec
=
(
None
,)
if
scaling_mode
==
ScalingMode
.
DELAYED_TENSOR_SCALING
.
value
:
scale_inv_spec
=
amax_spec
=
scale_spec
elif
scaling_mode
==
ScalingMode
.
MXFP8_1D_SCALING
.
value
:
scale_inv_spec
=
x_spec
if
q_layout
in
(
QuantizeLayout
.
COLWISE
.
value
,
QuantizeLayout
.
ROWWISE_COLWISE
.
value
):
colwise_scale_inv_spec
=
scale_inv_spec
scale_inv_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
scale_inv_spec
),
desc
=
"DBiasQuantizePrimitive.scale_inv"
)
)
if
scaling_mode
==
ScalingMode
.
NVTE_MXFP8_1D_SCALING
.
value
:
amax_sharding
=
NamedSharding
(
scale_inv_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
amax_spec
),
desc
=
"DBiasQuantizePrimitive.amax"
mesh
,
PartitionSpec
(
*
x_spec
),
desc
=
"DBiasQuantizePrimitive.scale_inv"
)
colwise_scale_inv_sharding
=
scale_inv_sharding
.
duplicate_with_new_description
(
"DBiasQuantizePrimitive.colwise_scale_inv"
)
)
dbias
_sharding
=
NamedSharding
(
colwise_scale_inv
_sharding
=
NamedSharding
(
mesh
,
mesh
,
PartitionSpec
(
x
_spec
[
-
1
]
),
PartitionSpec
(
*
colwise_scale_inv
_spec
),
desc
=
"DBiasQuantizePrimitive.
dbias_shard
in
g
"
,
desc
=
"DBiasQuantizePrimitive.
colwise_scale_
in
v
"
,
)
)
return
(
return
(
out_sharding
,
out_sharding
,
colwise_out_sharding
,
colwise_out_sharding
,
...
@@ -333,7 +360,8 @@ class DBiasQuantizePrimitive(BasePrimitive):
...
@@ -333,7 +360,8 @@ class DBiasQuantizePrimitive(BasePrimitive):
def
partition
(
def
partition
(
out_dtype
,
out_dtype
,
scaling_mode
,
scaling_mode
,
q_axis
,
q_layout
,
flatten_axis
,
scale_dtype
,
scale_dtype
,
scale_shapes
,
scale_shapes
,
is_dbias
,
is_dbias
,
...
@@ -344,14 +372,15 @@ class DBiasQuantizePrimitive(BasePrimitive):
...
@@ -344,14 +372,15 @@ class DBiasQuantizePrimitive(BasePrimitive):
):
):
del
result_infos
,
is_outer
del
result_infos
,
is_outer
x_spec
=
get_padded_spec
(
arg_infos
[
0
])
x_spec
=
get_padded_spec
(
arg_infos
[
0
])
scale_spec
=
get_padded_spec
(
arg_infos
[
1
])
out_sharding
=
NamedSharding
(
out_sharding
=
NamedSharding
(
mesh
,
mesh
,
PartitionSpec
(
*
x_spec
[:
-
1
],
x_spec
[
-
1
]
),
PartitionSpec
(
*
x_spec
),
desc
=
"DBiasQuantizePrimitive.out_sharding"
,
desc
=
"DBiasQuantizePrimitive.out_sharding"
,
)
)
if
q_
axis
in
(
Quantize
Axis
.
COLWISE
.
value
,
Quantize
Axis
.
ROWWISE_COLWISE
.
value
):
if
q_
layout
in
(
Quantize
Layout
.
COLWISE
.
value
,
Quantize
Layout
.
ROWWISE_COLWISE
.
value
):
if
scaling_mode
==
ScalingMode
.
NVTE_
DELAYED_TENSOR_SCALING
.
value
:
if
scaling_mode
==
ScalingMode
.
DELAYED_TENSOR_SCALING
.
value
:
colwise_out_spec
=
multidim_transpose
(
x_spec
)
colwise_out_spec
=
multidim_transpose
(
x_spec
,
transpose_axis
=
flatten_axis
)
else
:
else
:
colwise_out_spec
=
x_spec
colwise_out_spec
=
x_spec
else
:
else
:
...
@@ -361,26 +390,35 @@ class DBiasQuantizePrimitive(BasePrimitive):
...
@@ -361,26 +390,35 @@ class DBiasQuantizePrimitive(BasePrimitive):
PartitionSpec
(
*
colwise_out_spec
),
PartitionSpec
(
*
colwise_out_spec
),
desc
=
"DBiasQuantizePrimitive.colwise_out_sharding"
,
desc
=
"DBiasQuantizePrimitive.colwise_out_sharding"
,
)
)
scale_inv_sharding
=
NamedSharding
(
dbias_spec
=
x_spec
[
flatten_axis
:]
if
is_dbias
else
(
None
,)
dbias_sharding
=
NamedSharding
(
mesh
,
mesh
,
PartitionSpec
(
*
get_padded_spec
(
arg_infos
[
1
])
),
PartitionSpec
(
*
dbias_spec
),
desc
=
"DBiasQuantizePrimitive.
scale_
in
v
"
,
desc
=
"DBiasQuantizePrimitive.
dbias_shard
in
g
"
,
)
)
amax_sharding
=
scale_inv_sharding
.
duplicate_with_new_description
(
desc
=
"DBiasQuantizePrimitive.amax_sharding"
scale_inv_spec
=
amax_spec
=
colwise_scale_inv_spec
=
(
None
,)
if
scaling_mode
==
ScalingMode
.
DELAYED_TENSOR_SCALING
.
value
:
scale_inv_spec
=
amax_spec
=
scale_spec
elif
scaling_mode
==
ScalingMode
.
MXFP8_1D_SCALING
.
value
:
scale_inv_spec
=
x_spec
if
q_layout
in
(
QuantizeLayout
.
COLWISE
.
value
,
QuantizeLayout
.
ROWWISE_COLWISE
.
value
):
colwise_scale_inv_spec
=
scale_inv_spec
scale_inv_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
scale_inv_spec
),
desc
=
"DBiasQuantizePrimitive.scale_inv"
)
)
if
scaling_mode
==
ScalingMode
.
NVTE_MXFP8_1D_SCALING
.
value
:
amax_sharding
=
NamedSharding
(
scale_inv_sharding
=
NamedSharding
(
mesh
,
PartitionSpec
(
*
amax_spec
),
desc
=
"DBiasQuantizePrimitive.amax"
mesh
,
PartitionSpec
(
*
x_spec
),
desc
=
"DBiasQuantizePrimitive.scale_inv"
)
colwise_scale_inv_sharding
=
scale_inv_sharding
.
duplicate_with_new_description
(
"DBiasQuantizePrimitive.colwise_scale_inv"
)
)
dbias
_sharding
=
NamedSharding
(
colwise_scale_inv
_sharding
=
NamedSharding
(
mesh
,
mesh
,
PartitionSpec
(
x
_spec
[
-
1
]
),
PartitionSpec
(
*
colwise_scale_inv
_spec
),
desc
=
"DBiasQuantizePrimitive.
dbias_shard
in
g
"
,
desc
=
"DBiasQuantizePrimitive.
colwise_scale_
in
v
"
,
)
)
arg_shardings
=
tuple
(
arg_i
.
sharding
for
arg_i
in
arg_infos
)
arg_shardings
=
tuple
(
arg_i
.
sharding
for
arg_i
in
arg_infos
)
out_shardings
=
(
out_shardings
=
(
out_sharding
,
out_sharding
,
...
@@ -404,14 +442,15 @@ class DBiasQuantizePrimitive(BasePrimitive):
...
@@ -404,14 +442,15 @@ class DBiasQuantizePrimitive(BasePrimitive):
scale
,
scale
,
out_dtype
=
out_dtype
,
out_dtype
=
out_dtype
,
scaling_mode
=
scaling_mode
,
scaling_mode
=
scaling_mode
,
q_axis
=
q_axis
,
q_layout
=
q_layout
,
flatten_axis
=
flatten_axis
,
scale_dtype
=
scale_dtype
,
scale_dtype
=
scale_dtype
,
scale_shapes
=
scale_shapes
,
scale_shapes
=
scale_shapes
,
is_dbias
=
is_dbias
,
is_dbias
=
is_dbias
,
is_outer
=
True
,
is_outer
=
True
,
)
)
if
scaling_mode
==
ScalingMode
.
NVTE_
DELAYED_TENSOR_SCALING
.
value
:
if
scaling_mode
==
ScalingMode
.
DELAYED_TENSOR_SCALING
.
value
:
global_updated_amax
=
all_reduce_max_along_all_axes_except_PP
(
local_amax
,
mesh
)
global_updated_amax
=
all_reduce_max_along_all_axes_except_PP
(
local_amax
,
mesh
)
else
:
else
:
global_updated_amax
=
local_amax
global_updated_amax
=
local_amax
...
@@ -432,53 +471,91 @@ class DBiasQuantizePrimitive(BasePrimitive):
...
@@ -432,53 +471,91 @@ class DBiasQuantizePrimitive(BasePrimitive):
return
mesh
,
sharded_impl
,
out_shardings
,
arg_shardings
return
mesh
,
sharded_impl
,
out_shardings
,
arg_shardings
@
staticmethod
def
shardy_sharding_rule
(
out_dtype
,
scaling_mode
,
q_layout
,
flatten_axis
,
scale_dtype
,
scale_shapes
,
is_dbias
,
is_outer
,
mesh
,
value_types
,
result_types
,
):
del
out_dtype
,
scale_dtype
,
scale_shapes
,
is_outer
,
mesh
,
result_types
scale_rules
=
ScalingMode
(
scaling_mode
).
get_shardy_sharding_rules
(
len
(
value_types
[
0
].
shape
),
unique_var
=
"i"
,
flatten_axis
=
flatten_axis
)
x_axes
=
scale_rules
.
input_spec
colwise_scale_inv
=
scale_rules
.
colwise_rule
out
=
x_axes
if
q_layout
in
(
QuantizeLayout
.
COLWISE
.
value
,
QuantizeLayout
.
ROWWISE_COLWISE
.
value
):
if
scaling_mode
==
ScalingMode
.
DELAYED_TENSOR_SCALING
.
value
:
colwise_out
=
tuple
(
multidim_transpose
(
x_axes
,
transpose_axis
=
flatten_axis
))
else
:
colwise_out
=
x_axes
else
:
colwise_out
=
(
"j"
,)
colwise_scale_inv
=
(
"k"
,)
dbias
=
x_axes
[
flatten_axis
:]
if
is_dbias
else
(
"l"
,)
amax
=
(
"m"
,)
return
SdyShardingRule
(
(
x_axes
,
(
"…1"
,)),
(
out
,
colwise_out
,
scale_rules
.
rowwise_rule
,
colwise_scale_inv
,
amax
,
dbias
),
**
scale_rules
.
factor_sizes
,
)
register_primitive
(
DBiasQuantizePrimitive
)
register_primitive
(
DBiasQuantizePrimitive
)
def
_jax_quantize
(
x
,
quantizer
:
Quantizer
=
None
,
dq_dtype
:
Optional
[
jnp
.
dtype
]
=
None
):
def
_jax_quantize
(
x
,
quantizer
:
Quantizer
=
None
,
dq_dtype
:
Optional
[
jnp
.
dtype
]
=
None
,
flatten_axis
:
int
=
-
1
):
if
quantizer
is
None
:
if
quantizer
is
None
:
return
x
return
x
return
quantizer
.
quantize
(
x
,
dq_dtype
=
dq_dtype
)
return
quantizer
.
quantize
(
x
,
dq_dtype
=
dq_dtype
,
flatten_axis
=
flatten_axis
)
def
_jax_dbias
(
dx
:
jnp
.
ndarray
):
def
_jax_dbias
(
dx
:
jnp
.
ndarray
,
dtype
=
None
,
flatten_axis
:
int
=
-
1
):
assert
flatten_axis
<
0
dtype
=
dtype
or
dx
.
dtype
dbias
=
jnp
.
sum
(
dbias
=
jnp
.
sum
(
dx
,
dx
.
astype
(
jnp
.
float32
)
,
axis
=
tuple
(
range
(
dx
.
ndim
-
1
)),
axis
=
tuple
(
range
(
dx
.
ndim
+
flatten_axis
)),
keepdims
=
False
,
keepdims
=
False
,
)
)
dbias
=
dbias
.
ravel
()
# C++ function returns an 1D array for dbias
return
dbias
.
astype
(
dtype
)
return
dbias
def
_jax_quantize_dbias
(
def
_jax_quantize_dbias
(
x
,
x
,
quantizer
:
Quantizer
=
None
,
quantizer
:
Quantizer
=
None
,
dq_dtype
:
Optional
[
jnp
.
dtype
]
=
None
,
dq_dtype
:
Optional
[
jnp
.
dtype
]
=
None
,
flatten_axis
:
int
=
-
1
,
):
):
if
quantizer
is
None
:
if
quantizer
is
None
:
return
x
,
None
return
x
,
None
return
quantizer
.
quantize
(
x
,
dq_dtype
=
dq_dtype
),
_jax_dbias
(
x
)
return
(
quantizer
.
quantize
(
x
,
dq_dtype
=
dq_dtype
,
flatten_axis
=
flatten_axis
),
_jax_dbias
(
x
,
dtype
=
dq_dtype
,
flatten_axis
=
flatten_axis
),
def
_jax_dbias
(
dx
:
jnp
.
ndarray
,
):
dbias
=
jnp
.
sum
(
dx
.
astype
(
jnp
.
float32
),
axis
=
tuple
(
range
(
dx
.
ndim
-
1
)),
keepdims
=
False
,
)
)
dbias
=
dbias
.
ravel
()
# C++ function returns an 1D array for dbias
return
dbias
.
astype
(
dx
.
dtype
)
def
_quantize_impl
(
def
_quantize_
dbias_
impl
(
x
:
jnp
.
ndarray
,
x
:
jnp
.
ndarray
,
quantizer
:
Quantizer
,
quantizer
:
Quantizer
,
is_dbias
:
bool
=
False
,
is_dbias
:
bool
=
False
,
dq_dtype
:
Optional
[
jnp
.
dtype
]
=
None
,
dq_dtype
:
Optional
[
jnp
.
dtype
]
=
None
,
flatten_axis
:
int
=
-
1
,
)
->
Tuple
[
ScaledTensor2x
,
jnp
.
ndarray
]:
)
->
Tuple
[
ScaledTensor2x
,
jnp
.
ndarray
]:
"""
"""
Cast wrapper
Cast wrapper
...
@@ -488,40 +565,51 @@ def _quantize_impl(
...
@@ -488,40 +565,51 @@ def _quantize_impl(
quantizer
is
not
None
quantizer
is
not
None
),
"quantizer must be provided if dq_dtype is provided"
),
"quantizer must be provided if dq_dtype is provided"
dq_dtype
=
dq_dtype
or
x
.
dtype
if
not
DBiasQuantizePrimitive
.
enabled
():
if
not
DBiasQuantizePrimitive
.
enabled
():
if
is_dbias
:
if
is_dbias
:
return
_jax_quantize_dbias
(
return
_jax_quantize_dbias
(
x
,
x
,
quantizer
=
quantizer
,
quantizer
=
quantizer
,
dq_dtype
=
dq_dtype
,
dq_dtype
=
dq_dtype
,
flatten_axis
=
flatten_axis
,
)
)
return
_jax_quantize
(
x
,
quantizer
=
quantizer
,
dq_dtype
=
dq_dtype
),
None
return
(
_jax_quantize
(
x
,
quantizer
=
quantizer
,
dq_dtype
=
dq_dtype
,
flatten_axis
=
flatten_axis
),
None
,
)
# TE/common doesn't support colwise only quantization yet
# TE/common doesn't support colwise only quantization yet
if
quantizer
is
not
None
and
quantizer
.
q_
axis
==
Quantize
Axis
.
COLWISE
:
if
quantizer
is
not
None
and
quantizer
.
q_
layout
==
Quantize
Layout
.
COLWISE
:
if
is_dbias
:
if
is_dbias
:
return
_jax_quantize_dbias
(
return
_jax_quantize_dbias
(
x
,
x
,
quantizer
=
quantizer
,
quantizer
=
quantizer
,
dq_dtype
=
dq_dtype
,
dq_dtype
=
dq_dtype
,
flatten_axis
=
flatten_axis
,
)
)
return
_jax_quantize
(
x
,
quantizer
=
quantizer
,
dq_dtype
=
dq_dtype
),
None
return
(
_jax_quantize
(
x
,
quantizer
=
quantizer
,
dq_dtype
=
dq_dtype
,
flatten_axis
=
flatten_axis
),
None
,
)
scale
=
jnp
.
empty
((),
jnp
.
float32
)
scale
=
jnp
.
empty
((),
jnp
.
float32
)
# TE/common dbias_quantize does not support 1x on arch < 100
# TE/common dbias_quantize does not support 1x on arch < 100
if
should_apply_1x_fused_dbias_war_for_arch_l_100
(
is_dbias
=
is_dbias
,
quantizer
=
quantizer
):
if
should_apply_1x_fused_dbias_war_for_arch_l_100
(
is_dbias
=
is_dbias
,
quantizer
=
quantizer
):
out
,
_
=
_quantize_impl
(
out
,
_
=
_quantize_
dbias_
impl
(
x
=
x
,
x
=
x
,
is_dbias
=
False
,
is_dbias
=
False
,
quantizer
=
quantizer
,
quantizer
=
quantizer
,
dq_dtype
=
dq_dtype
,
dq_dtype
=
dq_dtype
,
flatten_axis
=
flatten_axis
,
)
)
dbias
=
_jax_dbias
(
x
)
dbias
=
_jax_dbias
(
x
,
dtype
=
dq_dtype
,
flatten_axis
=
flatten_axis
)
return
out
,
dbias
return
out
,
dbias
if
quantizer
is
None
:
if
quantizer
is
None
:
if
is_dbias
:
if
is_dbias
:
return
x
,
_jax_dbias
(
x
)
return
x
,
_jax_dbias
(
x
,
dtype
=
dq_dtype
,
flatten_axis
=
flatten_axis
)
return
x
,
None
return
x
,
None
if
isinstance
(
quantizer
,
DelayedScaleQuantizer
):
if
isinstance
(
quantizer
,
DelayedScaleQuantizer
):
...
@@ -539,14 +627,15 @@ def _quantize_impl(
...
@@ -539,14 +627,15 @@ def _quantize_impl(
scale
,
scale
,
out_dtype
=
quantizer
.
q_dtype
,
out_dtype
=
quantizer
.
q_dtype
,
scaling_mode
=
quantizer
.
scaling_mode
.
value
,
scaling_mode
=
quantizer
.
scaling_mode
.
value
,
q_axis
=
quantizer
.
q_axis
.
value
,
q_layout
=
quantizer
.
q_layout
.
value
,
flatten_axis
=
flatten_axis
,
scale_dtype
=
quantizer
.
get_scale_dtype
(),
scale_dtype
=
quantizer
.
get_scale_dtype
(),
scale_shapes
=
quantizer
.
get_scale_shapes
(
x
.
shape
),
scale_shapes
=
quantizer
.
get_scale_shapes
(
x
.
shape
,
flatten_axis
=
flatten_axis
),
is_dbias
=
is_dbias
,
is_dbias
=
is_dbias
,
is_outer
=
True
,
is_outer
=
True
,
)
)
# For DelayedScaling2x, the scale buffer is shared between rowwise and colwise
# For DelayedScaling2x, the scale buffer is shared between rowwise and colwise
if
quantizer
.
scaling_mode
==
ScalingMode
.
NVTE_
DELAYED_TENSOR_SCALING
and
quantizer
.
is_2x2x
():
if
quantizer
.
scaling_mode
==
ScalingMode
.
DELAYED_TENSOR_SCALING
and
quantizer
.
is_2x2x
():
colwise_scale_inv
=
rowwise_scale_inv
colwise_scale_inv
=
rowwise_scale_inv
quantizer
.
update
(
updated_amax
)
quantizer
.
update
(
updated_amax
)
...
@@ -557,18 +646,18 @@ def _quantize_impl(
...
@@ -557,18 +646,18 @@ def _quantize_impl(
colwise_data
=
colwise_casted_output
,
colwise_data
=
colwise_casted_output
,
colwise_scale_inv
=
colwise_scale_inv
,
colwise_scale_inv
=
colwise_scale_inv
,
scaling_mode
=
quantizer
.
scaling_mode
,
scaling_mode
=
quantizer
.
scaling_mode
,
dq_dtype
=
dq_dtype
if
dq_dtype
is
not
None
else
x
.
dtype
,
dq_dtype
=
dq_dtype
,
q_axis
=
quantizer
.
q_axis
,
q_layout
=
quantizer
.
q_layout
,
layout
=
quantizer
.
get_layout
(),
data_layout
=
quantizer
.
get_data_layout
(),
flatten_axis
=
flatten_axis
,
)
)
return
out
,
dbias
return
out
,
dbias
.
astype
(
dq_dtype
)
# TODO(Phuong): do not expose dq_dtype to users
def
quantize
(
def
quantize
(
x
:
jnp
.
ndarray
,
x
:
jnp
.
ndarray
,
quantizer
:
Quantizer
,
quantizer
:
Quantizer
,
dq_dtype
:
Optional
[
jnp
.
dtype
]
=
None
,
flatten_axis
:
int
=
-
1
,
)
->
Tuple
[
ScaledTensor
]:
)
->
Tuple
[
ScaledTensor
]:
"""Quantize input tensor according to the quantizer.
"""Quantize input tensor according to the quantizer.
...
@@ -576,26 +665,25 @@ def quantize(
...
@@ -576,26 +665,25 @@ def quantize(
x: Input tensor to be quantized.
x: Input tensor to be quantized.
Shape: (..., K) where K is the hidden size.
Shape: (..., K) where K is the hidden size.
quantizer: Quantizer for FP8 quantization of the output.
quantizer: Quantizer for FP8 quantization of the output.
dq_dtype: Optional dtype
for
de
quantization.
flatten_axis: The quantization axis in which input data can be flattened to 2D
for quantization.
If None, uses the same dtype as the input tensor
.
Defaults to -1
.
Returns:
Returns:
A ScaledTensor containing the quantized input tensor.
A ScaledTensor containing the quantized input tensor.
"""
"""
out
,
_
=
_quantize_impl
(
out
,
_
=
_quantize_
dbias_
impl
(
x
,
x
,
quantizer
=
quantizer
,
quantizer
=
quantizer
,
dq_dtype
=
dq_dtype
,
flatten_axis
=
flatten_axis
,
)
)
return
out
return
out
# TODO(Phuong): do not expose dq_dtype to users
def
quantize_dbias
(
def
quantize_dbias
(
dz
:
jnp
.
ndarray
,
dz
:
jnp
.
ndarray
,
quantizer
:
Quantizer
,
quantizer
:
Quantizer
,
is_dbias
:
bool
=
True
,
is_dbias
:
bool
=
True
,
dq_dtype
:
Optional
[
jnp
.
dtype
]
=
None
,
flatten_axis
:
int
=
-
1
,
)
->
Tuple
[
ScaledTensor2x
,
jnp
.
ndarray
]:
)
->
Tuple
[
ScaledTensor2x
,
jnp
.
ndarray
]:
"""Quantize input tensor and compute bias gradient.
"""Quantize input tensor and compute bias gradient.
...
@@ -604,8 +692,8 @@ def quantize_dbias(
...
@@ -604,8 +692,8 @@ def quantize_dbias(
Shape: (..., K) where K is the hidden size.
Shape: (..., K) where K is the hidden size.
quantizer: Quantizer for FP8 quantization of the output.
quantizer: Quantizer for FP8 quantization of the output.
is_dbias: If True, compute bias gradient. Defaults to True.
is_dbias: If True, compute bias gradient. Defaults to True.
dq_dtype: Optional dtype
for
de
quantization.
flatten_axis: The quantization axis in which input data can be flattened to 2D
for quantization.
If None, uses the same dtype as the input tensor
.
Defaults to -1
.
Returns:
Returns:
A tuple containing:
A tuple containing:
...
@@ -614,9 +702,6 @@ def quantize_dbias(
...
@@ -614,9 +702,6 @@ def quantize_dbias(
- The bias gradient tensor.
- The bias gradient tensor.
Shape: (K,) or empty if is_dbias is False.
Shape: (K,) or empty if is_dbias is False.
"""
"""
return
_quantize_impl
(
return
_quantize_dbias_impl
(
dz
,
dz
,
quantizer
=
quantizer
,
is_dbias
=
is_dbias
,
flatten_axis
=
flatten_axis
quantizer
=
quantizer
,
is_dbias
=
is_dbias
,
dq_dtype
=
dq_dtype
,
)
)
transformer_engine/jax/cpp_extensions/softmax.py
View file @
ab3e5a92
...
@@ -31,6 +31,9 @@ __all__ = [
...
@@ -31,6 +31,9 @@ __all__ = [
"scaled_upper_triang_masked_softmax_fwd"
,
"scaled_upper_triang_masked_softmax_fwd"
,
"scaled_upper_triang_masked_softmax_bwd"
,
"scaled_upper_triang_masked_softmax_bwd"
,
"is_softmax_kernel_available"
,
"is_softmax_kernel_available"
,
"jax_scaled_softmax"
,
"jax_scaled_masked_softmax"
,
"jax_scaled_upper_triang_masked_softmax"
,
]
]
...
@@ -330,6 +333,11 @@ class ScaledSoftmaxFwdPrimitive(SoftmaxPrimitive):
...
@@ -330,6 +333,11 @@ class ScaledSoftmaxFwdPrimitive(SoftmaxPrimitive):
ScaledSoftmaxFwdPrimitive
.
impl
,
scale_factor
,
mesh
,
arg_infos
,
result_infos
ScaledSoftmaxFwdPrimitive
.
impl
,
scale_factor
,
mesh
,
arg_infos
,
result_infos
)
)
@
staticmethod
def
shardy_sharding_rule
(
*
args
):
del
args
return
"... -> ..."
register_primitive
(
ScaledSoftmaxFwdPrimitive
)
register_primitive
(
ScaledSoftmaxFwdPrimitive
)
...
@@ -400,6 +408,11 @@ class ScaledSoftmaxBwdPrimitive(SoftmaxPrimitive):
...
@@ -400,6 +408,11 @@ class ScaledSoftmaxBwdPrimitive(SoftmaxPrimitive):
ScaledSoftmaxBwdPrimitive
.
impl
,
scale_factor
,
mesh
,
arg_infos
,
result_infos
ScaledSoftmaxBwdPrimitive
.
impl
,
scale_factor
,
mesh
,
arg_infos
,
result_infos
)
)
@
staticmethod
def
shardy_sharding_rule
(
*
args
):
del
args
return
"..., ... -> ..."
register_primitive
(
ScaledSoftmaxBwdPrimitive
)
register_primitive
(
ScaledSoftmaxBwdPrimitive
)
...
@@ -412,7 +425,7 @@ def scaled_softmax_bwd(
...
@@ -412,7 +425,7 @@ def scaled_softmax_bwd(
Return FP16/BF16 tensor
Return FP16/BF16 tensor
"""
"""
if
not
ScaledSoftmaxBwdPrimitive
.
enabled
():
if
not
ScaledSoftmaxBwdPrimitive
.
enabled
():
_
,
vjp_func
=
jax
.
vjp
(
partial
(
_
jax_scaled_softmax
,
scale_factor
=
scale_factor
),
logits
)
_
,
vjp_func
=
jax
.
vjp
(
partial
(
jax_scaled_softmax
,
scale_factor
=
scale_factor
),
logits
)
return
vjp_func
(
dz
)[
0
]
return
vjp_func
(
dz
)[
0
]
return
ScaledSoftmaxBwdPrimitive
.
outer_primitive
.
bind
(
return
ScaledSoftmaxBwdPrimitive
.
outer_primitive
.
bind
(
...
@@ -525,6 +538,11 @@ class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
...
@@ -525,6 +538,11 @@ class ScaledMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
ScaledMaskedSoftmaxFwdPrimitive
.
impl
,
scale_factor
,
mesh
,
arg_infos
,
result_infos
ScaledMaskedSoftmaxFwdPrimitive
.
impl
,
scale_factor
,
mesh
,
arg_infos
,
result_infos
)
)
@
staticmethod
def
shardy_sharding_rule
(
*
args
):
del
args
return
"...1, ...2 -> ...1"
register_primitive
(
ScaledMaskedSoftmaxFwdPrimitive
)
register_primitive
(
ScaledMaskedSoftmaxFwdPrimitive
)
...
@@ -596,6 +614,11 @@ class ScaledMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
...
@@ -596,6 +614,11 @@ class ScaledMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
ScaledMaskedSoftmaxBwdPrimitive
.
impl
,
scale_factor
,
mesh
,
arg_infos
,
result_infos
ScaledMaskedSoftmaxBwdPrimitive
.
impl
,
scale_factor
,
mesh
,
arg_infos
,
result_infos
)
)
@
staticmethod
def
shardy_sharding_rule
(
*
args
):
del
args
return
"..., ... -> ..."
register_primitive
(
ScaledMaskedSoftmaxBwdPrimitive
)
register_primitive
(
ScaledMaskedSoftmaxBwdPrimitive
)
...
@@ -682,6 +705,11 @@ class ScaledUpperTriangMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
...
@@ -682,6 +705,11 @@ class ScaledUpperTriangMaskedSoftmaxFwdPrimitive(SoftmaxPrimitive):
result_infos
,
result_infos
,
)
)
@
staticmethod
def
shardy_sharding_rule
(
*
args
):
del
args
return
"... -> ..."
register_primitive
(
ScaledUpperTriangMaskedSoftmaxFwdPrimitive
)
register_primitive
(
ScaledUpperTriangMaskedSoftmaxFwdPrimitive
)
...
@@ -761,15 +789,26 @@ class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
...
@@ -761,15 +789,26 @@ class ScaledUpperTriangMaskedSoftmaxBwdPrimitive(SoftmaxPrimitive):
result_infos
,
result_infos
,
)
)
@
staticmethod
def
shardy_sharding_rule
(
*
args
):
del
args
return
"..., ... -> ..."
# Register the backward upper-triangular masked-softmax primitive exactly once
# (the scraped diff duplicated this line; double registration is wrong).
register_primitive(ScaledUpperTriangMaskedSoftmaxBwdPrimitive)
def jax_scaled_softmax(logits: jnp.ndarray, scale_factor: float):
    """Pure-JAX reference implementation of scaled softmax.

    Multiplies ``logits`` by ``scale_factor`` and applies ``jax.nn.softmax``
    (softmax over the last axis, jax.nn.softmax's default).
    """
    scaled_logits = logits * scale_factor
    return jax.nn.softmax(scaled_logits)
def
_jax_scaled_masked_softmax
(
logits
:
jnp
.
ndarray
,
mask
:
jnp
.
ndarray
,
scale_factor
:
float
):
def
jax_scaled_masked_softmax
(
logits
:
jnp
.
ndarray
,
mask
:
jnp
.
ndarray
,
scale_factor
:
float
):
"""
JAX based implementation of scaled and masked softmax
"""
if
mask
is
not
None
:
if
mask
is
not
None
:
logits
+=
jax
.
lax
.
select
(
logits
+=
jax
.
lax
.
select
(
mask
>
0
,
mask
>
0
,
...
@@ -779,7 +818,10 @@ def _jax_scaled_masked_softmax(logits: jnp.ndarray, mask: jnp.ndarray, scale_fac
...
@@ -779,7 +818,10 @@ def _jax_scaled_masked_softmax(logits: jnp.ndarray, mask: jnp.ndarray, scale_fac
return
jax
.
nn
.
softmax
(
logits
*
scale_factor
)
return
jax
.
nn
.
softmax
(
logits
*
scale_factor
)
def
_jax_scaled_upper_triang_masked_softmax
(
logits
:
jnp
.
ndarray
,
scale_factor
:
float
):
def
jax_scaled_upper_triang_masked_softmax
(
logits
:
jnp
.
ndarray
,
scale_factor
:
float
):
"""
JAX based implementation of scaled and upper triangle masked softmax
"""
mask
=
1
-
jnp
.
tril
(
jnp
.
ones_like
(
logits
))
mask
=
1
-
jnp
.
tril
(
jnp
.
ones_like
(
logits
))
logits
+=
jax
.
lax
.
select
(
logits
+=
jax
.
lax
.
select
(
mask
>
0
,
mask
>
0
,
...
@@ -795,7 +837,7 @@ def scaled_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray:
...
@@ -795,7 +837,7 @@ def scaled_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray:
def scaled_softmax_fwd(logits: jnp.ndarray, scale_factor: float) -> jnp.ndarray:
    """
    scaled softmax (forward pass).

    Dispatches to the fused TE primitive when it is enabled; otherwise falls
    back to the pure-JAX reference implementation.

    Return FP16/BF16 tensor
    """
    if ScaledSoftmaxFwdPrimitive.enabled():
        return ScaledSoftmaxFwdPrimitive.outer_primitive.bind(logits, scale_factor=scale_factor)
    # Framework-native fallback path.
    return jax_scaled_softmax(logits, scale_factor)
...
@@ -807,7 +849,7 @@ def scaled_masked_softmax_fwd(
...
@@ -807,7 +849,7 @@ def scaled_masked_softmax_fwd(
def scaled_masked_softmax_fwd(
    logits: jnp.ndarray, mask: jnp.ndarray, scale_factor: float
) -> jnp.ndarray:
    """
    scaled masked softmax (forward pass).

    Dispatches to the fused TE primitive when it is enabled; otherwise falls
    back to the pure-JAX reference implementation.

    Return FP16/BF16 tensor
    """
    if ScaledMaskedSoftmaxFwdPrimitive.enabled():
        return ScaledMaskedSoftmaxFwdPrimitive.outer_primitive.bind(
            logits, mask, scale_factor=scale_factor
        )
    # Framework-native fallback path.
    return jax_scaled_masked_softmax(logits, mask, scale_factor)
...
@@ -826,7 +868,7 @@ def scaled_masked_softmax_bwd(
...
@@ -826,7 +868,7 @@ def scaled_masked_softmax_bwd(
"""
"""
if
not
ScaledMaskedSoftmaxBwdPrimitive
.
enabled
():
if
not
ScaledMaskedSoftmaxBwdPrimitive
.
enabled
():
_
,
vjp_func
=
jax
.
vjp
(
_
,
vjp_func
=
jax
.
vjp
(
partial
(
_
jax_scaled_masked_softmax
,
scale_factor
=
scale_factor
),
logits
,
mask
partial
(
jax_scaled_masked_softmax
,
scale_factor
=
scale_factor
),
logits
,
mask
)
)
return
vjp_func
(
dz
)[
0
]
return
vjp_func
(
dz
)[
0
]
return
ScaledMaskedSoftmaxBwdPrimitive
.
outer_primitive
.
bind
(
return
ScaledMaskedSoftmaxBwdPrimitive
.
outer_primitive
.
bind
(
...
@@ -840,7 +882,7 @@ def scaled_upper_triang_masked_softmax_fwd(logits: jnp.ndarray, scale_factor: fl
...
@@ -840,7 +882,7 @@ def scaled_upper_triang_masked_softmax_fwd(logits: jnp.ndarray, scale_factor: fl
def scaled_upper_triang_masked_softmax_fwd(
    logits: jnp.ndarray, scale_factor: float
) -> jnp.ndarray:
    """
    scaled upper-triangular masked softmax (forward pass).

    Dispatches to the fused TE primitive when it is enabled; otherwise falls
    back to the pure-JAX reference implementation.

    Return FP16/BF16 tensor
    """
    if ScaledUpperTriangMaskedSoftmaxFwdPrimitive.enabled():
        return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.outer_primitive.bind(
            logits, scale_factor=scale_factor
        )
    # Framework-native fallback path.
    return jax_scaled_upper_triang_masked_softmax(logits, scale_factor)
...
@@ -855,7 +897,7 @@ def scaled_upper_triang_masked_softmax_bwd(
...
@@ -855,7 +897,7 @@ def scaled_upper_triang_masked_softmax_bwd(
"""
"""
if
not
ScaledUpperTriangMaskedSoftmaxBwdPrimitive
.
enabled
():
if
not
ScaledUpperTriangMaskedSoftmaxBwdPrimitive
.
enabled
():
_
,
vjp_func
=
jax
.
vjp
(
_
,
vjp_func
=
jax
.
vjp
(
partial
(
_
jax_scaled_upper_triang_masked_softmax
,
scale_factor
=
scale_factor
),
logits
partial
(
jax_scaled_upper_triang_masked_softmax
,
scale_factor
=
scale_factor
),
logits
)
)
return
vjp_func
(
dz
)[
0
]
return
vjp_func
(
dz
)[
0
]
return
ScaledUpperTriangMaskedSoftmaxBwdPrimitive
.
outer_primitive
.
bind
(
return
ScaledUpperTriangMaskedSoftmaxBwdPrimitive
.
outer_primitive
.
bind
(
...
...
Prev
1
2
3
4
5
6
7
8
9
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment