OpenDAS / TransformerEngine

Commit f8c2af4c, authored May 21, 2025 by yuguo

Merge commit '1d903f5e' of https://github.com/NVIDIA/TransformerEngine

Parents: e92773a3, 1d903f5e
Changes: 211
Showing 20 changed files with 377 additions and 337 deletions (+377, -337):

transformer_engine/jax/csrc/extensions/attention.cpp       +26   -26
transformer_engine/jax/csrc/extensions/gemm.cpp            +2    -1
transformer_engine/jax/csrc/extensions/misc.h              +4    -0
transformer_engine/jax/csrc/extensions/normalization.cpp   +10   -3
transformer_engine/jax/csrc/extensions/pybind.cpp          +1    -0
transformer_engine/jax/csrc/extensions/quantization.cpp    +12   -6
transformer_engine/jax/csrc/extensions/utils.cpp           +28   -0
transformer_engine/jax/csrc/extensions/utils.h             +0    -32
transformer_engine/jax/csrc/utils.cu                       +0    -75
transformer_engine/jax/dense.py                            +4    -3
transformer_engine/jax/flax/module.py                      +77   -70
transformer_engine/jax/flax/transformer.py                 +2    -4
transformer_engine/jax/layernorm_mlp.py                    +6    -11
transformer_engine/jax/quantize/dequantizer.py             +1    -0
transformer_engine/jax/quantize/helper.py                  +30   -2
transformer_engine/jax/quantize/quantizer.py               +105  -34
transformer_engine/jax/quantize/scaling_modes.py           +32   -3
transformer_engine/jax/setup.py                            +12   -0
transformer_engine/jax/sharding.py                         +20   -8
transformer_engine/pytorch/__init__.py                     +5    -59
transformer_engine/jax/csrc/extensions/attention.cpp

@@ -196,10 +196,10 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes(
     if (cudnn_runtime_version >= 90300) {                                                      \
       num_segments = input_batch * max_segments_per_seq;                                       \
     } else {                                                                                   \
-      size_t runtime_num_segments_q =                                                          \
-          GetRuntimeNumSegments(q_cu_seqlens, workspace, input_batch * q_max_seqlen, stream);  \
-      size_t runtime_num_segments_kv =                                                         \
-          GetRuntimeNumSegments(kv_cu_seqlens, workspace, input_batch * kv_max_seqlen, stream); \
+      size_t runtime_num_segments_q = nvte_get_runtime_num_segments(                           \
+          q_cu_seqlens, workspace, input_batch * q_max_seqlen, stream);                        \
+      size_t runtime_num_segments_kv = nvte_get_runtime_num_segments(                          \
+          kv_cu_seqlens, workspace, input_batch * kv_max_seqlen, stream);                      \
       NVTE_CHECK(runtime_num_segments_q == runtime_num_segments_kv);                           \
       NVTE_CHECK(runtime_num_segments_q <= input_batch * max_segments_per_seq);                \
       num_segments = runtime_num_segments_q;                                                   \

@@ -248,7 +248,7 @@ static void FusedAttnForwardImpl(
       static_cast<NVTEDType>(dtype), static_cast<NVTEDType>(dtype), qkv_layout, bias_type,
       mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen,
       head_dim, head_dim, window_size_left, window_size_right);
-  PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);
+  nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream);

   /* Auxiliary tensors (to be propagated to the backward pass later) */
   NVTETensorPack aux_output_tensors;
transformer_engine/jax/csrc/extensions/gemm.cpp

@@ -108,7 +108,8 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Variadic_Buffer_Type input_list,
     auto rhs_sinv_shape = std::vector<size_t>{1, 1};
-    if (scaling_mode == JAXX_Scaling_Mode::NO_SCALING ||
-        scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
+    if (scaling_mode == JAXX_Scaling_Mode::NO_SCALING ||
+        scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
+        scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING) {
       float *amax_dptr = nullptr;
       float *scale_dptr = nullptr;
       auto lhs_i_ = TensorWrapper(lhs_ptr, lhs_shape, lhs_dtype, amax_dptr, scale_dptr,
transformer_engine/jax/csrc/extensions/misc.h

@@ -44,6 +44,7 @@ enum class JAXX_Scaling_Mode : int64_t {
   NO_SCALING = 0,
   DELAYED_TENSOR_SCALING = 1,
   MXFP8_1D_SCALING = 2,
+  CURRENT_TENSOR_SCALING = 3,
 };

 static NVTEScalingMode get_nvte_scaling_mode(const JAXX_Scaling_Mode &mode) {

@@ -57,6 +58,9 @@ static NVTEScalingMode get_nvte_scaling_mode(const JAXX_Scaling_Mode &mode) {
     case JAXX_Scaling_Mode::MXFP8_1D_SCALING:
       return NVTEScalingMode::NVTE_MXFP8_1D_SCALING;
       break;
+    case JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING:
+      return NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING;
+      break;
     default:
       NVTE_ERROR("Invalid Scaling Mode ", static_cast<int>(mode));
       break;
transformer_engine/jax/csrc/extensions/normalization.cpp

@@ -24,7 +24,7 @@ pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_si
   // empty tensor wrappers are okay just to get workspace size
   auto input_tensor = TensorWrapper(nullptr, input_shape, in_dtype);
-  auto gamma_tensor = TensorWrapper(nullptr, weight_shape, in_dtype);
+  auto gamma_tensor = TensorWrapper(nullptr, weight_shape, w_dtype);
   auto rsigma_tensor = TensorWrapper(nullptr, intermediates_shape, DType::kFloat32);

   auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));

@@ -98,7 +98,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
   auto workspace_shape = std::vector<size_t>{workspace_size};
   auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
-  auto gamma_tensor = TensorWrapper(gamma, gamma_shape, in_dtype);
+  auto gamma_tensor = TensorWrapper(gamma, gamma_shape, w_dtype);
   auto rsigma_tensor = TensorWrapper(rsigma, intermediates_shape, DType::kFloat32);

   auto num_sm = cudaDevicePropertiesManager::Instance().GetMultiProcessorCount() - _sm_margin;

@@ -107,6 +107,11 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
   auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
   output_tensor.set_rowwise_data(output, static_cast<DType>(out_dtype), input_shape);

+  NVTE_CHECK(scaling_mode != JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING,
+             "Current tensor scaling does not support fused operations yet. Please call this primitive "
+             "in higher-precision then quantize with current scaling.");
+
   if (is_fp8_dtype(out_dtype)) {
     output_tensor.set_rowwise_scale_inv(scale_inv_buf->untyped_data(),

@@ -118,7 +123,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
   if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING && is_fp8_dtype(out_dtype)) {
     output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
-    cudaMemsetAsync(amax, 0, sizeof(float), stream);
+    nvte_memset(amax, 0, sizeof(float), stream);
     output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
   }

@@ -134,6 +139,8 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc
   }

   if (_norm_type == NVTE_Norm_Type::LayerNorm) {
+    NVTE_CHECK(w_dtype == convert_ffi_datatype_to_te_dtype(beta_buf.element_type()),
+               "gamma and beta must have the same data type.");
     auto beta_tensor = TensorWrapper(beta, gamma_shape, w_dtype);
     auto mu_tensor = TensorWrapper(mu, intermediates_shape, DType::kFloat32);
transformer_engine/jax/csrc/extensions/pybind.cpp

@@ -142,6 +142,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) {
       .value("NO_SCALING", JAXX_Scaling_Mode::NO_SCALING)
       .value("DELAYED_TENSOR_SCALING", JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
       .value("MXFP8_1D_SCALING", JAXX_Scaling_Mode::MXFP8_1D_SCALING)
+      .value("CURRENT_TENSOR_SCALING", JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING)
       .export_values();

   pybind11::enum_<transformer_engine::jax::QuantizeLayout>(m, "QuantizeLayout",
transformer_engine/jax/csrc/extensions/quantization.cpp

@@ -7,6 +7,7 @@
 #include "extensions.h"
 #include "transformer_engine/cast.h"
+#include "transformer_engine/recipe.h"
 #include "xla/ffi/api/c_api.h"

 namespace transformer_engine {

@@ -107,18 +108,21 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
   auto input_tensor = TensorWrapper(input, input_shape, in_dtype);

   auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode));
+  bool const is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ||
+                                 scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING;
   if (quantize_layout == QuantizeLayout::ROWWISE ||
       quantize_layout == QuantizeLayout::ROWWISE_COLWISE) {
     output_tensor.set_rowwise_data(output, out_dtype, output_shape);

     if (is_fp8_dtype(out_dtype)) {
-      if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
+      if (is_tensor_scaling) {
         float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
         float *amax = reinterpret_cast<float *>(amax_buf->untyped_data());
+        NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling");
+        NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling");
         output_tensor.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
-        cudaMemsetAsync(amax, 0, sizeof(float), stream);
+        nvte_memset(amax, 0, sizeof(float), stream);
         output_tensor.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
         output_tensor.set_rowwise_scale_inv(scale_inv_buf->untyped_data(),

@@ -142,11 +146,9 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
                                : output_shape;
     output_tensor.set_columnwise_data(output_trans, out_dtype, tmp_shape);

     // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling
-    auto &tmp_buf = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING)
-                        ? scale_inv_buf
-                        : colwise_scale_inv_buf;
+    auto &tmp_buf = is_tensor_scaling ? scale_inv_buf : colwise_scale_inv_buf;

-    if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) {
+    if (is_tensor_scaling) {
       output_tensor.set_columnwise_scale_inv(
           tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()),
           std::vector<size_t>{1});

@@ -159,6 +161,10 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T
     }
   }

+  if (scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING) {
+    output_tensor.set_amax(nullptr, DType::kFloat32, std::vector<size_t>{1});
+  }
+
   auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype);
   auto workspace_tensor = TensorWrapper(workspace, workspace_shape, workspace_dtype);
transformer_engine/pytorch/csrc/extensions/util.cpp → transformer_engine/jax/csrc/extensions/utils.cpp

@@ -3,12 +3,26 @@
  *
  * See LICENSE for license information.
  ************************************************************************/
-#include "util.h"
+#include "utils.h"

-#include "ATen/cuda/CUDAContextLight.h"
+#include <cuda_runtime_api.h>
+#include <cassert>

-bool non_tn_fp8_gemm_supported() {
-  int major = at::cuda::getCurrentDeviceProperties()->major;
-  return major >= 10;
-}
+#include "common/util/cuda_runtime.h"
+
+namespace transformer_engine {
+namespace jax {
+
+int GetCudaRuntimeVersion() {
+  int ver = 0;
+  NVTE_CHECK_CUDA(cudaRuntimeGetVersion(&ver));
+  return ver;
+}
+
+size_t GetCudnnRuntimeVersion() { return cudnnGetVersion(); }
+
+int GetDeviceComputeCapability(int gpu_id) { return transformer_engine::cuda::sm_arch(gpu_id); }
+
+}  // namespace jax
+}  // namespace transformer_engine
transformer_engine/jax/csrc/utils.h → transformer_engine/jax/csrc/extensions/utils.h

@@ -4,9 +4,6 @@
  * See LICENSE for license information.
  ************************************************************************/
 #ifndef TRANSFORMER_ENGINE_JAX_CSRC_UTILS_H_
 #define TRANSFORMER_ENGINE_JAX_CSRC_UTILS_H_

-#include <pybind11/pybind11.h>
-
 #include <transformer_engine/fused_attn.h>

@@ -25,12 +22,6 @@ int GetCudaRuntimeVersion();
 size_t GetCudnnRuntimeVersion();
 int GetDeviceComputeCapability(int gpu_id);

-void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q_max_seqlen,
-                           size_t kv_max_seqlen, NVTE_Fused_Attn_Backend backend,
-                           cudaStream_t stream);
-
-uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cudaStream_t stream);
-
 class cudaDevicePropertiesManager {
  public:
  static cudaDevicePropertiesManager &Instance() {

@@ -63,28 +54,5 @@ class cudaDevicePropertiesManager {
   cudaDeviceProp prop_;
 };

-class FusedAttnOffsetManager {
- public:
-  static FusedAttnOffsetManager &Instance() {
-    static thread_local FusedAttnOffsetManager instance;
-    return instance;
-  }
-
-  size_t GetAndUpdateOffset(size_t increment) {
-    size_t ret = offset_;
-    offset_ += increment;
-    return ret;
-  }
-
-  FusedAttnOffsetManager(FusedAttnOffsetManager const &) = delete;
-  void operator=(FusedAttnOffsetManager const &) = delete;
-
- private:
-  FusedAttnOffsetManager() {}
-  size_t offset_ = 0;
-};
-
 }  // namespace jax
 }  // namespace transformer_engine

 #endif  // TRANSFORMER_ENGINE_JAX_CSRC_UTILS_H_
transformer_engine/jax/csrc/utils.cu (deleted, 100644 → 0)

/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/
#include <cuda_runtime_api.h>

#include <cassert>

#include "common/util/cuda_runtime.h"
#include "utils.h"

namespace transformer_engine {
namespace jax {

int GetCudaRuntimeVersion() {
  int ver = 0;
  NVTE_CHECK_CUDA(cudaRuntimeGetVersion(&ver));
  return ver;
}

size_t GetCudnnRuntimeVersion() { return cudnnGetVersion(); }

int GetDeviceComputeCapability(int gpu_id) { return transformer_engine::cuda::sm_arch(gpu_id); }

__global__ void populate_rng_state_kernel(int64_t *rng_state_dst, const int64_t *const seed,
                                          int64_t offset) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid > 0) return;
  rng_state_dst[0] = seed[0];
  rng_state_dst[1] = offset;
}

void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q_max_seqlen,
                           size_t kv_max_seqlen, NVTE_Fused_Attn_Backend backend,
                           cudaStream_t stream) {
  size_t increment = 0;
  if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) {
    increment = 16;
  } else {
    constexpr int threads_per_cta = 128;
    increment = (q_max_seqlen * kv_max_seqlen + threads_per_cta - 1) / threads_per_cta;
  }
  auto offset = FusedAttnOffsetManager::Instance().GetAndUpdateOffset(increment);
  populate_rng_state_kernel<<<1, 1, 0, stream>>>(reinterpret_cast<int64_t *>(rng_state_dst),
                                                 reinterpret_cast<const int64_t *>(seed), offset);
  NVTE_CHECK_CUDA(cudaGetLastError());
}

__global__ void get_runtime_num_segments_kernel(int32_t *cu_seqlen, size_t len, uint32_t *out) {
  int tid = blockDim.x * blockIdx.x + threadIdx.x;
  if (tid >= len) return;
  if (cu_seqlen[tid] > 0) {
    // atomicAdd only support 32 bits dtype
    atomicAdd(out, 1);
  }
}

uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cudaStream_t stream) {
  // workspace size requires 4 bytes
  uint32_t *dout = static_cast<uint32_t *>(workspace);
  uint32_t hout{};
  cudaMemsetAsync(dout, 0, sizeof(uint32_t), stream);
  constexpr int threads = 128;
  const int blocks = (len - 1) / threads + 1;
  get_runtime_num_segments_kernel<<<blocks, threads, 0, stream>>>(
      static_cast<int32_t *>(cu_seqlen), len, dout);
  cudaMemcpyAsync(&hout, dout, sizeof(uint32_t), cudaMemcpyDeviceToHost, stream);
  cudaStreamSynchronize(stream);
  return hout;
}

}  // namespace jax
}  // namespace transformer_engine
transformer_engine/jax/dense.py

@@ -49,6 +49,7 @@ def dense(
    """
    # Remove when tex.quantize() can handle quantizer=None
    if quantizer_set == noop_quantizer_set:
+        x = with_sharding_constraint_by_logical_axes(x, input_axes)
        output = tex.gemm(x, kernel, contracting_dims)
        if bias is not None:
            bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape

@@ -183,6 +184,7 @@ def _dense_bwd_rule(
 _dense.defvjp(_dense_fwd_rule, _dense_bwd_rule)


+"""
 def grouped_dense(
     x_list,
     kernel_list,

@@ -190,10 +192,8 @@ def grouped_dense(
     contracting_dims_list,
     quantizer_set_list=None,
 ):
-    """
-    Perform grouped_dense layer transformation with optional quantization.
-    """
+    # Perform grouped_dense layer transformation with optional quantization.
     output_list = _grouped_dense(
         x_list, kernel_list, bias_list, contracting_dims_list, quantizer_set_list
     )

@@ -315,3 +315,4 @@ def _grouped_dense_bwd_rule(contracting_dims_list, ctx, grad_list):
 _grouped_dense.defvjp(_grouped_dense_fwd_rule, _grouped_dense_bwd_rule)
+"""
transformer_engine/jax/flax/module.py

@@ -11,7 +11,6 @@ from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union
 import numpy as np
 import jax.numpy as jnp
 from flax import linen as nn
-from flax.linen import partitioning as nn_partitioning
 from jax import lax
 from jax import random as jax_random
 from jax.ad_checkpoint import checkpoint_name

@@ -65,6 +64,7 @@ def _obtain_default_layernorm_scale_init_if_need(original_init, zero_centered_ga
 def _create_layernorm_parameters(
+    module,
     norm_type,
     shape,
     scale_init,

@@ -74,13 +74,21 @@ def _create_layernorm_parameters(
     input_dtype,
     dtype,
 ):
-    scale = nn_partitioning.param_with_axes("scale", scale_init, shape, dtype, axes=scale_axes)
-    scale = scale.astype(input_dtype)
+    scale = module.param(
+        "scale",
+        nn.with_logical_partitioning(scale_init, scale_axes),
+        shape,
+        dtype,
+    ).astype(input_dtype)

     norm_type = canonicalize_norm_type(norm_type)
     if norm_type == "layernorm":
-        bias = nn_partitioning.param_with_axes("ln_bias", bias_init, shape, dtype, axes=bias_axes)
-        bias = jnp.asarray(bias, input_dtype)
+        bias = module.param(
+            "ln_bias",
+            nn.with_logical_partitioning(bias_init, bias_axes),
+            shape,
+            dtype,
+        ).astype(input_dtype)
     else:
         assert norm_type == "rmsnorm"
         bias = None

@@ -308,6 +316,7 @@ class LayerNorm(nn.Module):  # pylint: disable=too-few-public-methods
         features = x.shape[-1]
         scale, ln_bias = _create_layernorm_parameters(
+            self,
             self.layernorm_type,
             (features,),
             self.scale_init,

@@ -467,16 +476,22 @@ class DenseGeneral(TransformerEngineBase):
                 "Expected len(kernel_shape) to match len(kernel_axes),"
                 f" got kernel_shape {kernel_shape} and kernel_axes {self.kernel_axes}"
             )
-        kernel = nn_partitioning.param_with_axes(
-            "kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes
+        kernel = self.param(
+            "kernel",
+            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes),
+            kernel_shape,
+            self.dtype,
         )
         if not QuantizeConfig.is_fp8_enabled():
             kernel = kernel.astype(input_dtype)

         if self.use_bias:
-            bias = nn_partitioning.param_with_axes(
-                "bias", self.bias_init, features, self.dtype, axes=self.bias_axes
+            bias = self.param(
+                "bias",
+                nn.with_logical_partitioning(self.bias_init, self.bias_axes),
+                features,
+                self.dtype,
             ).astype(input_dtype)
         else:
             bias = None

@@ -499,25 +514,21 @@ class DenseGeneral(TransformerEngineBase):
                 self.low_rank_adaptation_dim,
             )
             lora_a_kernel_axes = (None,) * len(lora_a_kernel_shape)
-            lora_a_kernel = nn_partitioning.param_with_axes(
+            lora_a_kernel = self.param(
                 "lora_a_kernel",
-                self.kernel_init,
+                nn.with_logical_partitioning(self.kernel_init, lora_a_kernel_axes),
                 lora_a_kernel_shape,
                 self.dtype,
-                axes=lora_a_kernel_axes,
-            )
-            lora_a_kernel = lora_a_kernel.astype(input_dtype)
+            ).astype(input_dtype)

             lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
             lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
-            lora_b_kernel = nn_partitioning.param_with_axes(
+            lora_b_kernel = self.param(
                 "lora_b_kernel",
-                nn.initializers.zeros,
+                nn.with_logical_partitioning(nn.initializers.zeros, lora_b_kernel_axes),
                 lora_b_kernel_shape,
                 self.dtype,
-                axes=lora_b_kernel_axes,
-            )
-            lora_b_kernel = lora_b_kernel.astype(input_dtype)
+            ).astype(input_dtype)

             y += _apply_low_rank_adaptation(
                 inputs, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha

@@ -695,6 +706,7 @@ class LayerNormDenseGeneral(TransformerEngineBase):
         inputs = with_sharding_constraint_by_logical_axes(inputs, self.layernorm_input_axes)
         features = inputs.shape[-1]
         scale, ln_bias = _create_layernorm_parameters(
+            self,
             self.layernorm_type,
             (features,),
             self.scale_init,

@@ -730,8 +742,11 @@ class LayerNormDenseGeneral(TransformerEngineBase):
             axis = _normalize_axes(axis, y.ndim)

             kernel_shape = (np.prod([inputs.shape[ax] for ax in axis]),) + features
-            kernel = nn_partitioning.param_with_axes(
-                "kernel", self.kernel_init, kernel_shape, self.dtype, axes=self.kernel_axes
+            kernel = self.param(
+                "kernel",
+                nn.with_logical_partitioning(self.kernel_init, self.kernel_axes),
+                kernel_shape,
+                self.dtype,
             )
             if not QuantizeConfig.is_fp8_enabled():
                 kernel = kernel.astype(input_dtype)

@@ -770,25 +785,21 @@ class LayerNormDenseGeneral(TransformerEngineBase):
                 self.low_rank_adaptation_dim,
             )
             lora_a_kernel_axes = (None,) * len(lora_a_kernel_shape)
-            lora_a_kernel = nn_partitioning.param_with_axes(
+            lora_a_kernel = self.param(
                 "lora_a_kernel",
-                self.kernel_init,
+                nn.with_logical_partitioning(self.kernel_init, lora_a_kernel_axes),
                 lora_a_kernel_shape,
                 self.dtype,
-                axes=lora_a_kernel_axes,
-            )
-            lora_a_kernel = lora_a_kernel.astype(input_dtype)
+            ).astype(input_dtype)

             lora_b_kernel_shape = (*features[:-1], self.low_rank_adaptation_dim, features[-1])
             lora_b_kernel_axes = (None,) * len(lora_b_kernel_shape)
-            lora_b_kernel = nn_partitioning.param_with_axes(
+            lora_b_kernel = self.param(
                 "lora_b_kernel",
-                nn.initializers.zeros,
+                nn.with_logical_partitioning(nn.initializers.zeros, lora_b_kernel_axes),
                 lora_b_kernel_shape,
                 self.dtype,
-                axes=lora_b_kernel_axes,
-            )
-            lora_b_kernel = lora_b_kernel.astype(input_dtype)
+            ).astype(input_dtype)

             z += _apply_low_rank_adaptation(
                 y, axis, features, lora_a_kernel, lora_b_kernel, self.low_rank_adaptation_alpha

@@ -796,8 +807,11 @@ class LayerNormDenseGeneral(TransformerEngineBase):
         bias = None
         if self.use_bias:
-            bias = nn_partitioning.param_with_axes(
-                "bias", self.bias_init, features, self.dtype, axes=self.bias_axes
+            bias = self.param(
+                "bias",
+                nn.with_logical_partitioning(self.bias_init, self.bias_axes),
+                features,
+                self.dtype,
             ).astype(input_dtype)

         if bias is not None:

@@ -1028,6 +1042,7 @@ class LayerNormMLP(TransformerEngineBase):
         features = inputs.shape[-1]
         scale, ln_bias = _create_layernorm_parameters(
+            self,
             self.layernorm_type,
             (features,),
             self.scale_init,

@@ -1067,14 +1082,13 @@ class LayerNormMLP(TransformerEngineBase):
         axis = _canonicalize_tuple(self.axis)
         axis = _normalize_axes(axis, y.ndim)

         kernel_1_each_shape = (np.prod([inputs.shape[ax] for ax in axis]), self.intermediate_dim)
-        kernel_1 = nn_partitioning.param_with_axes(
+        kernel_1 = self.param(
             "wi_kernel",
-            kernel_1_init,
+            nn.with_logical_partitioning(kernel_1_init, self.kernel_axes_1),
             num_activations,
             -2,
             kernel_1_each_shape,
             self.dtype,
-            axes=self.kernel_axes_1,
         )
         if not QuantizeConfig.is_fp8_enabled():

@@ -1083,12 +1097,11 @@ class LayerNormMLP(TransformerEngineBase):
         hidden_size = inputs.shape[-1]
         hidden_size_tuple = _canonicalize_tuple(hidden_size)
         kernel_2_shape = (self.intermediate_dim,) + hidden_size_tuple
-        kernel_2 = nn_partitioning.param_with_axes(
+        kernel_2 = self.param(
             "wo_kernel",
-            self.kernel_init,
+            nn.with_logical_partitioning(self.kernel_init, self.kernel_axes_2),
             kernel_2_shape,
             self.dtype,
-            axes=self.kernel_axes_2,
         )
         if not QuantizeConfig.is_fp8_enabled():
             kernel_2 = kernel_2.astype(input_dtype)

@@ -1097,21 +1110,19 @@ class LayerNormMLP(TransformerEngineBase):
         if self.use_bias:
             bias_1_shape = (num_activations, self.intermediate_dim)
-            bias_1 = nn_partitioning.param_with_axes(
+            bias_1 = self.param(
                 "wi_bias",
-                self.bias_init,
+                nn.with_logical_partitioning(self.bias_init, self.bias_axes_1),
                 bias_1_shape,
                 self.dtype,
-                axes=self.bias_axes_1,
             ).astype(input_dtype)

             bias_2_shape = (hidden_size,)
-            bias_2 = nn_partitioning.param_with_axes(
+            bias_2 = self.param(
                 "wo_bias",
-                self.bias_init,
+                nn.with_logical_partitioning(self.bias_init, self.bias_axes_2),
                 bias_2_shape,
                 self.dtype,
-                axes=self.bias_axes_2,
             ).astype(input_dtype)
         else:
             bias_1 = None

@@ -1168,9 +1179,13 @@ class LayerNormMLP(TransformerEngineBase):
             kernel_axes=self.kernel_axes_1,
             quantizer_set=ffn1_quantizer_set,
         )
-        dot_1_output_axes = (
-            *get_non_contracting_logical_axes(y.ndim, self.dot_1_input_axes, axis),
-            *get_non_contracting_logical_axes(kernel_1.ndim, self.kernel_axes_1, contract_ind),
-        )
-        x = with_sharding_constraint_by_logical_axes(x, dot_1_output_axes)
+        if self.dot_1_input_axes is not None and self.kernel_axes_1 is not None:
+            dot_1_output_axes = (
+                *get_non_contracting_logical_axes(y.ndim, self.dot_1_input_axes, axis),
+                *get_non_contracting_logical_axes(kernel_1.ndim, self.kernel_axes_1, contract_ind),
+            )
+            x = with_sharding_constraint_by_logical_axes(x, dot_1_output_axes)

@@ -1180,16 +1195,14 @@ class LayerNormMLP(TransformerEngineBase):
                 self.low_rank_adaptation_dim,
             )
             wi_lora_a_kernel_axes = (None,) * (len(wi_lora_a_kernel_each_shape) + 1)
-            wi_lora_a_kernel = nn_partitioning.param_with_axes(
+            wi_lora_a_kernel = self.param(
                 "wi_lora_a_kernel",
-                kernel_1_init,
+                nn.with_logical_partitioning(kernel_1_init, wi_lora_a_kernel_axes),
                 num_activations,
                 -2,
                 wi_lora_a_kernel_each_shape,
                 self.dtype,
-                axes=wi_lora_a_kernel_axes,
-            )
-            wi_lora_a_kernel = wi_lora_a_kernel.astype(input_dtype)
+            ).astype(input_dtype)

             wi_lora_b_kernel_shape = (
                 num_activations,

@@ -1197,14 +1210,12 @@ class LayerNormMLP(TransformerEngineBase):
                 self.intermediate_dim,
             )
             wi_lora_b_kernel_axes = (None,) * len(wi_lora_b_kernel_shape)
-            wi_lora_b_kernel = nn_partitioning.param_with_axes(
+            wi_lora_b_kernel = self.param(
                 "wi_lora_b_kernel",
-                nn.initializers.zeros,
+                nn.with_logical_partitioning(nn.initializers.zeros, wi_lora_b_kernel_axes),
                 wi_lora_b_kernel_shape,
                 self.dtype,
-                axes=wi_lora_b_kernel_axes,
-            )
-            wi_lora_b_kernel = wi_lora_b_kernel.astype(input_dtype)
+            ).astype(input_dtype)

             x += _apply_low_rank_adaptation(
                 y,

@@ -1253,25 +1264,21 @@ class LayerNormMLP(TransformerEngineBase):
         if self.enable_low_rank_adaptation:
             wo_lora_a_kernel_shape = (self.intermediate_dim, self.low_rank_adaptation_dim)
             wo_lora_a_kernel_axes = (None,) * len(wo_lora_a_kernel_shape)
-            wo_lora_a_kernel = nn_partitioning.param_with_axes(
+            wo_lora_a_kernel = self.param(
                 "wo_lora_a_kernel",
-                self.kernel_init,
+                nn.with_logical_partitioning(self.kernel_init, wo_lora_a_kernel_axes),
                 wo_lora_a_kernel_shape,
                 self.dtype,
-                axes=wo_lora_a_kernel_axes,
-            )
-            wo_lora_a_kernel = wo_lora_a_kernel.astype(input_dtype)
+            ).astype(input_dtype)

             wo_lora_b_kernel_shape = (self.low_rank_adaptation_dim, hidden_size)
             wo_lora_b_kernel_axes = (None,) * len(wo_lora_b_kernel_shape)
-            wo_lora_b_kernel = nn_partitioning.param_with_axes(
+            wo_lora_b_kernel = self.param(
                 "wo_lora_b_kernel",
-                nn.initializers.zeros,
+                nn.with_logical_partitioning(nn.initializers.zeros, wo_lora_b_kernel_axes),
                 wo_lora_b_kernel_shape,
                 self.dtype,
-                axes=wo_lora_b_kernel_axes,
-            )
-            wo_lora_b_kernel = wo_lora_b_kernel.astype(input_dtype)
+            ).astype(input_dtype)

             out += _apply_low_rank_adaptation(
                 z,
transformer_engine/jax/flax/transformer.py

@@ -15,7 +15,6 @@ import jax
 import jax.numpy as jnp
 import numpy as np
 from flax import linen as nn
-from flax.linen import partitioning as nn_partitioning
 from flax.linen.attention import combine_masks
 from jax import nn as jax_nn
 from jax import random as jax_random

@@ -1503,12 +1502,11 @@ class RelativePositionBiases(nn.Module):  # pylint: disable=too-few-public-metho
             rp_bucket += np.where(rpb_is_small, negative_rp, rpb_val_if_large)

         # Compute relative attention bias
-        relative_attention_bias = nn_partitioning.param_with_axes(
+        relative_attention_bias = self.param(
             "rel_embedding",
-            self.embedding_init,
+            nn.with_logical_partitioning(self.embedding_init, self.embedding_axes),
             (self.num_attention_heads, self.num_buckets),
             self.dtype,
-            axes=self.embedding_axes,
         )
         relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype)
transformer_engine/jax/layernorm_mlp.py

@@ -275,6 +275,7 @@ def _layernorm_mlp_fwd_rule(
         (x_contracting_dims, k_contracting_dims),
     )

+    if dot_1_input_axes is not None and kernel_1_axes is not None:
        dot_1_output_axes = (
            *get_non_contracting_logical_axes(x.ndim, dot_1_input_axes, x_contracting_dims),
            *get_non_contracting_logical_axes(kernel_1.ndim, kernel_1_axes, k_contracting_dims),

@@ -303,12 +304,6 @@ def _layernorm_mlp_fwd_rule(
         (x_contracting_dims, k_contracting_dims),
     )

-    dot_2_output_axes = (
-        *get_non_contracting_logical_axes(x.ndim, dot_2_input_axes, x_contracting_dims),
-        *get_non_contracting_logical_axes(kernel_2.ndim, None, k_contracting_dims),
-    )
-    dot_2_output = with_sharding_constraint_by_logical_axes(dot_2_output, dot_2_output_axes)
-
     if use_bias_2:
         bias_2_shape = bias_2.shape
         bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape
transformer_engine/jax/quantize/dequantizer.py

@@ -85,6 +85,7 @@ class Dequantizer:
     funcs = {
         ScalingMode.DELAYED_TENSOR_SCALING: _dq_func_tensor_scaling,
+        ScalingMode.CURRENT_TENSOR_SCALING: _dq_func_tensor_scaling,
         ScalingMode.MXFP8_1D_SCALING: _dq_func_block_scaling,
     }
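Both per-tensor modes can reuse `_dq_func_tensor_scaling` because tensor-scaled dequantization is just a broadcast multiply by the stored inverse scale; only how the scale was produced (amax history vs. current amax) differs. A minimal sketch of that shared path, with an illustrative function name rather than the body from this diff:

    import jax.numpy as jnp

    def dequantize_tensor_scaled(data, scale_inv, dq_dtype):
        # One float32 scale_inv covers the whole tensor, so delayed and
        # current scaling dequantize identically.
        return (data.astype(jnp.float32) * scale_inv).astype(dq_dtype)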
transformer_engine/jax/quantize/helper.py

@@ -94,7 +94,7 @@ def _check_fp8_support(scaling_mode, gpu_id) -> Tuple[bool, str]:
         A tuple of (bool, str) indicating support and any error message
     """
     gpu_arch = get_device_compute_capability(gpu_id)
-    if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING:
+    if scaling_mode.is_tensor_scaling():
         return _check_delayed_scaling_fp8_support(gpu_arch)
     if scaling_mode == ScalingMode.MXFP8_1D_SCALING:
         return _check_block_scaling_fp8_support(gpu_arch)

@@ -182,6 +182,8 @@ def _get_scaling_mode(fp8_recipe: recipe.Recipe) -> ScalingMode:
         return ScalingMode.DELAYED_TENSOR_SCALING
     if isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
         return ScalingMode.MXFP8_1D_SCALING
+    if isinstance(fp8_recipe, recipe.Float8CurrentScaling):
+        return ScalingMode.CURRENT_TENSOR_SCALING
     raise ValueError("Invalid fp8_recipe!")

@@ -240,7 +242,7 @@ class QuantizeConfig:
             fp8_recipe: The FP8 recipe to use for initialization
         """
         cls.INITIALIZED = True
-        cls.MARGIN = fp8_recipe.margin
+        cls.MARGIN = fp8_recipe.margin if "margin" in dir(fp8_recipe) else 0.0
         cls.FP8_FORMAT = fp8_recipe.fp8_format
         cls.FWD_DTYPE, cls.BWD_DTYPE = _format2dtypes(cls.FP8_FORMAT)
         cls.SCALING_MODE = _get_scaling_mode(fp8_recipe)

@@ -309,6 +311,30 @@ class DelayedScalingQuantizeConfig:
         QuantizeConfig.finalize()


+class CurrentScalingQuantizeConfig:
+    """Configuration class for current scaling FP8 recipe.
+
+    This class provides specific initialization and finalization for current scaling
+    FP8 quantization mode.
+    """
+
+    @staticmethod
+    def initialize(fp8_recipe: recipe.Recipe) -> None:
+        """Initialize current scaling FP8 configuration.
+
+        Args:
+            fp8_recipe: The FP8 recipe to use for initialization
+        """
+        cls = QuantizeConfig
+        cls.initialize(fp8_recipe)
+        cls.AMAX_HISTORY_LEN = 0
+
+    @staticmethod
+    def finalize() -> None:
+        """Reset the current scaling configuration."""
+        QuantizeConfig.finalize()
+
+
 class BlockScalingQuantizeConfig:
     """Configuration class for block scaling FP8 recipe.

@@ -385,6 +411,8 @@ def fp8_autocast(
     Config = DelayedScalingQuantizeConfig
     if isinstance(fp8_recipe, recipe.MXFP8BlockScaling):
         Config = BlockScalingQuantizeConfig
+    if isinstance(fp8_recipe, recipe.Float8CurrentScaling):
+        Config = CurrentScalingQuantizeConfig

     try:
         with global_shard_guard(mesh_resource):
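With `_get_scaling_mode` and `fp8_autocast` both taught about `recipe.Float8CurrentScaling`, enabling the new mode should amount to a one-line recipe swap. A usage sketch, assuming the usual `transformer_engine.jax` entry points:

    from transformer_engine.common import recipe
    import transformer_engine.jax as te

    # Float8CurrentScaling maps to ScalingMode.CURRENT_TENSOR_SCALING and selects
    # CurrentScalingQuantizeConfig, which sets AMAX_HISTORY_LEN = 0 (no history).
    with te.fp8_autocast(enabled=True, fp8_recipe=recipe.Float8CurrentScaling()):
        ...  # FP8 forward/backward with scales computed from the current tensor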
transformer_engine/jax/quantize/quantizer.py

@@ -27,13 +27,36 @@ __all__ = [
     "QuantizeLayout",
     "Quantizer",
     "QuantizerSet",
+    "CurrentScaleQuantizer",
     "DelayedScaleQuantizer",
     "BlockScaleQuantizer",
     "QuantizerFactory",
     "noop_quantizer_set",
+    "compute_scale_from_amax",
 ]


+def compute_scale_from_amax(
+    amax: jnp.ndarray, q_dtype: jnp.dtype, scale: Optional[jnp.ndarray] = None
+) -> jnp.ndarray:
+    """Compute scale from amax value.
+
+    Args:
+        amax: Maximum absolute value of the tensor
+        q_dtype: Quantization data type
+
+    Returns:
+        Scale value
+    """
+    fp8_max = jnp.astype(jnp.finfo(q_dtype).max, jnp.float32)
+    if scale is None:
+        scale = jnp.ones((1,))
+    sf = (fp8_max / amax) / (2**QuantizeConfig.MARGIN)
+    sf = jnp.where(amax > 0.0, sf, scale)
+    sf = jnp.where(jnp.isfinite(amax), sf, scale)
+    return sf
+
+
 @register_pytree_node_class
 @dataclass
 class Quantizer(ABC):

@@ -159,37 +182,19 @@ class Quantizer(ABC):
 @register_pytree_node_class
 @dataclass
-class DelayedScaleQuantizer(Quantizer):
-    """Quantizer implementation using delayed scaling.
+class CurrentScaleQuantizer(Quantizer):
+    """Quantizer implementation using current scaling.

-    This quantizer uses delayed scaling mode with float32 scales and maintains
-    a history of maximum absolute values for dynamic scaling.
+    This quantizer uses current scaling mode with float32 scales

     Attributes:
         scaling_mode: Set to NVTE_DELAYED_TENSOR_SCALING
         q_layout: Quantization axis (default: ROWWISE_COLWISE)
-        scale: Current scaling factor
-        amax_history: History of maximum absolute values
     """

-    scaling_mode: ScalingMode = ScalingMode.DELAYED_TENSOR_SCALING
+    scaling_mode: ScalingMode = ScalingMode.CURRENT_TENSOR_SCALING
     q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE
-    scale: jnp.ndarray = field(default_factory=lambda: jnp.ones((1,), jnp.float32))
-    amax_history: jnp.ndarray = field(
-        default_factory=lambda: jnp.zeros((QuantizeConfig.AMAX_HISTORY_LEN,), jnp.float32)
-    )
-
-    def tree_flatten(self):
-        """Flatten the quantizer for JAX tree operations.
-
-        Returns:
-            Tuple of (children, aux_data) for tree operations
-        """
-        children = (self.scale, self.amax_history)
-        aux_data = (self.q_dtype, self.scaling_mode, self.q_layout)
-        return (children, aux_data)

     def get_data_layout(self) -> str:
         """Get the data data_layout string.

@@ -217,15 +222,18 @@ class DelayedScaleQuantizer(Quantizer):
             x: Input tensor to quantize
             is_colwise: Whether to use column-wise quantization
             dq_dtype: Data type for dequantized values
+            flatten_axis: The quantization axis for the tensor

         Returns:
             A ScaledTensor1x containing the quantized data
         """
         dq_dtype = dq_dtype if dq_dtype is not None else x.dtype
-        compute_dtype = self.scale.dtype
+        compute_dtype = jnp.float32
         dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype)
-        scaled_x = x.astype(compute_dtype) * self.scale
+        amax = jnp.max(jnp.abs(x)).reshape((1,)).astype(compute_dtype)
+        fp8_max = jnp.astype(jnp.finfo(self.q_dtype).max, jnp.float32)
+        scale = (fp8_max / amax) / (2**QuantizeConfig.MARGIN)
+        scaled_x = x.astype(compute_dtype) * scale

         # quantize() in the old dot.py do this way, leave this code block here for future debugging
         # compute_dtype = x.dtype

@@ -233,8 +241,7 @@ class DelayedScaleQuantizer(Quantizer):
         # scaled_x = x * self.scale.astype(compute_dtype)
         clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max).astype(self.q_dtype)
-        scale_inv = 1.0 / self.scale
-        self.update(jnp.max(jnp.abs(x)).reshape((1,)))
+        scale_inv = 1.0 / scale
         return ScaledTensorFactory.create_1x(
             data=clipped_scaled_x,
             scale_inv=scale_inv,

@@ -294,6 +301,75 @@ class DelayedScaleQuantizer(Quantizer):
             return colwise_tensor
         return rowwise_tensor


+@register_pytree_node_class
+@dataclass
+class DelayedScaleQuantizer(CurrentScaleQuantizer):
+    """Quantizer implementation using delayed scaling.
+
+    This quantizer uses delayed scaling mode with float32 scales and maintains
+    a history of maximum absolute values for dynamic scaling.
+
+    Attributes:
+        scaling_mode: Set to NVTE_DELAYED_TENSOR_SCALING
+        q_layout: Quantization axis (default: ROWWISE_COLWISE)
+        scale: Current scaling factor
+        amax_history: History of maximum absolute values
+    """
+
+    scaling_mode: ScalingMode = ScalingMode.DELAYED_TENSOR_SCALING
+    q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE
+    scale: jnp.ndarray = field(default_factory=lambda: jnp.ones((1,), jnp.float32))
+    amax_history: jnp.ndarray = field(
+        default_factory=lambda: jnp.zeros((QuantizeConfig.AMAX_HISTORY_LEN,), jnp.float32)
+    )
+
+    def tree_flatten(self):
+        """Flatten the quantizer for JAX tree operations.
+
+        Returns:
+            Tuple of (children, aux_data) for tree operations
+        """
+        children = (self.scale, self.amax_history)
+        aux_data = (self.q_dtype, self.scaling_mode, self.q_layout)
+        return (children, aux_data)
+
+    def _quantize_func(
+        self, x: jnp.ndarray, is_colwise=False, dq_dtype=None, flatten_axis=-1
+    ) -> ScaledTensor1x:
+        """Quantize function helper for delayed scaling FP8.
+
+        Args:
+            x: Input tensor to quantize
+            is_colwise: Whether to use column-wise quantization
+            dq_dtype: Data type for dequantized values
+            flatten_axis: The quantization axis for the tensor
+
+        Returns:
+            A ScaledTensor1x containing the quantized data
+        """
+        dq_dtype = dq_dtype if dq_dtype is not None else x.dtype
+        compute_dtype = jnp.float32
+        dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype)
+        scaled_x = x.astype(compute_dtype) * self.scale
+
+        # quantize() in the old dot.py do this way, leave this code block here for future debugging
+        # compute_dtype = x.dtype
+        # dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype)
+        # scaled_x = x * self.scale.astype(compute_dtype)
+        clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max).astype(self.q_dtype)
+        scale_inv = 1.0 / self.scale
+        self.update(jnp.max(jnp.abs(x)).reshape((1,)))
+        return ScaledTensorFactory.create_1x(
+            data=clipped_scaled_x,
+            scale_inv=scale_inv,
+            scaling_mode=self.scaling_mode,
+            dq_dtype=dq_dtype,
+            flatten_axis=flatten_axis,
+        )
+
     @staticmethod
     @jax.jit
     def _update_amax_history(amax_history, new_amax):

@@ -323,18 +399,12 @@ class DelayedScaleQuantizer(Quantizer):
             Updated scale value
         """
         # 2. Calculate the current scale
-        fp8_max = jnp.astype(jnp.finfo(q_dtype).max, jnp.float32)
-
         if QuantizeConfig.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX:
             amax = jnp.max(amax_history, axis=-1, keepdims=True)
         else:
             amax = amax_history[0:1]

-        sf = (fp8_max / amax) / (2**QuantizeConfig.MARGIN)
-        sf = jnp.where(amax > 0.0, sf, scale)
-        sf = jnp.where(jnp.isfinite(amax), sf, scale)
-        scale = scale.at[0].set(sf[0])
-        return scale
+        return compute_scale_from_amax(amax, q_dtype, scale=scale)

     @staticmethod
     @jax.jit

@@ -531,6 +601,7 @@ class QuantizerFactory:
     quantizer_type_map = {
         ScalingMode.DELAYED_TENSOR_SCALING: DelayedScaleQuantizer,
+        ScalingMode.CURRENT_TENSOR_SCALING: CurrentScaleQuantizer,
         ScalingMode.MXFP8_1D_SCALING: BlockScaleQuantizer,
     }
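The factored-out `compute_scale_from_amax` makes the scale rule easy to check by hand: for E4M3 with fp8_max = 448 and MARGIN = 0, amax = 3.5 gives scale = 448 / 3.5 = 128 and scale_inv = 1/128. A standalone sketch of that arithmetic using only jax.numpy:

    import jax.numpy as jnp

    amax = jnp.asarray([3.5], jnp.float32)
    fp8_max = jnp.asarray(jnp.finfo(jnp.float8_e4m3fn).max, jnp.float32)  # 448.0
    scale = (fp8_max / amax) / (2**0)  # MARGIN = 0 for illustration -> 128.0
    scale_inv = 1.0 / scale           # stored alongside the quantized data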
transformer_engine/jax/quantize/scaling_modes.py

@@ -95,10 +95,10 @@ class ScalingModeMetadataImpl(ABC):
     """


-class DelayedScalingModeMetadataImpl(ScalingModeMetadataImpl):
-    """Implementation for delayed scaling mode.
+class CurrentScalingModeMetadataImpl(ScalingModeMetadataImpl):
+    """Implementation for current scaling mode.

-    This implementation provides metadata for delayed scaling mode, including scale data type and shape.
+    This implementation provides metadata for current scaling mode, including scale data type and shape.
     """

     def get_scale_dtype(self) -> jnp.dtype:

@@ -148,6 +148,13 @@ class DelayedScalingModeMetadataImpl(ScalingModeMetadataImpl):
         return QuantizeShardyRules(input_spec, (unique_var,), (unique_var,), {})


+class DelayedScalingModeMetadataImpl(CurrentScalingModeMetadataImpl):
+    """Implementation for delayed scaling mode.
+
+    This implementation provides metadata for delayed scaling mode, including scale data type and shape.
+    """
+
+
 class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl):
     """Implementation for block scaling mode.

@@ -317,12 +324,14 @@ class ScalingMode(Enum):
     This class defines the available scaling modes for tensor quantization:
     - DELAYED_TENSOR_SCALING: Uses delayed scaling with FP8 data type and float32 scales
     - MXFP8_1D_SCALING: Uses block-based scaling with FP8 data type and E8M0 scales
+    - CURRENT_TENSOR_SCALING: Uses current scaling with FP8 data type and float32 scales
     - NO_SCALING: No scaling applied
     """

     NO_SCALING = JAXX_Scaling_Mode.NO_SCALING
     DELAYED_TENSOR_SCALING = JAXX_Scaling_Mode.DELAYED_TENSOR_SCALING
     MXFP8_1D_SCALING = JAXX_Scaling_Mode.MXFP8_1D_SCALING
+    CURRENT_TENSOR_SCALING = JAXX_Scaling_Mode.CURRENT_TENSOR_SCALING

     def _get_impl(self) -> ScalingModeMetadataImpl:
         """Get the implementation for this scaling mode.

@@ -395,6 +404,25 @@ class ScalingMode(Enum):
         """
         return self._get_impl().get_shardy_sharding_rules(input_rank, unique_var, flatten_axis)

+    def is_tensor_scaling(self) -> bool:
+        """Check if this scaling mode is per-tensor scaling.
+
+        Returns:
+            True if the scaling mode is tensor scaling, False otherwise
+        """
+        return self in (
+            ScalingMode.DELAYED_TENSOR_SCALING,
+            ScalingMode.CURRENT_TENSOR_SCALING,
+        )
+
+    def is_1d_block_scaling(self) -> bool:
+        """Check if this scaling mode is 1D block scaling.
+
+        Returns:
+            True if the scaling mode is 1D block scaling, False otherwise
+        """
+        return self == ScalingMode.MXFP8_1D_SCALING
+
     def __eq__(self, other):
         """Compare this scaling mode with another.

@@ -434,5 +462,6 @@ SCALING_MODES_TO_IMPL: Dict[ScalingMode, ScalingModeMetadataImpl] = {
     ScalingMode.DELAYED_TENSOR_SCALING: DelayedScalingModeMetadataImpl(),
     ScalingMode.MXFP8_1D_SCALING: BlockScalingModeMetadataImpl(block_dims=(1, 32)),
+    # WAR
+    ScalingMode.CURRENT_TENSOR_SCALING: CurrentScalingModeMetadataImpl(),
     ScalingMode.NO_SCALING: DelayedScalingModeMetadataImpl(),
 }
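The two new predicates let callers branch on scaling families instead of enumerating modes, as `_check_fp8_support` now does. A quick sketch of the expected behavior, assuming `ScalingMode` is importable from `transformer_engine.jax.quantize`:

    from transformer_engine.jax.quantize import ScalingMode

    # Both per-tensor modes answer True; the block-scaled mode answers False.
    assert ScalingMode.DELAYED_TENSOR_SCALING.is_tensor_scaling()
    assert ScalingMode.CURRENT_TENSOR_SCALING.is_tensor_scaling()
    assert not ScalingMode.MXFP8_1D_SCALING.is_tensor_scaling()
    assert ScalingMode.MXFP8_1D_SCALING.is_1d_block_scaling()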
transformer_engine/jax/setup.py

@@ -84,6 +84,7 @@ if __name__ == "__main__":
     The script requires JAX to be installed for building.
     It will raise a RuntimeError if JAX is not available.
     """
+    # Extensions
     common_headers_dir = "common_headers"
     copy_common_headers(current_file_path.parent, str(current_file_path / common_headers_dir))

@@ -100,6 +101,17 @@ if __name__ == "__main__":
         description="Transformer acceleration library - Jax Lib",
         ext_modules=ext_modules,
         cmdclass={"build_ext": CMakeBuildExtension},
+        setup_requires=[
+            "jax[cuda12]",
+            "flax>=0.7.1",
+            "nvidia-cuda-runtime-cu12",
+            "nvidia-cublas-cu12",
+            "nvidia-cudnn-cu12",
+            "nvidia-cuda-cccl-cu12",
+            "nvidia-cuda-nvcc-cu12",
+            "nvidia-nvtx-cu12",
+            "nvidia-cuda-nvrtc-cu12",
+        ],
         install_requires=["jax", "flax>=0.7.1"],
         tests_require=["numpy"],
     )
transformer_engine/jax/sharding.py

@@ -13,7 +13,7 @@ import os
 from contextlib import contextmanager
 from dataclasses import dataclass
 from enum import Enum
-from typing import Callable
+from typing import Callable, Optional

 from jax.interpreters import pxla
 import jax
 import jax.numpy as jnp

@@ -112,9 +112,22 @@ def with_sharding_constraint(x: jnp.array, pspec: PartitionSpec):
     return jax.lax.with_sharding_constraint(x, pspec)


-def with_sharding_constraint_by_logical_axes(x: jnp.array, logical_axis_names: tuple | list):
+def with_sharding_constraint_by_logical_axes(
+    x: jnp.array, logical_axis_names: Optional[tuple | list]
+):
     """
     A wrapper function to jax.lax.with_sharding_constraint to accept logical axes.

+    If logical_axis_names = None, this means no sharding constraint is applied.
+
+    If logical_axis_names = (None, None, ...), this means a sharding constraint is
+    applied and the tensor is replicated across all devices.
+
     Args:
         x: Input tensor to apply sharding constraint
         logical_axis_names: Logical axis names to apply sharding constraint

     Returns:
         Tensor with sharding constraint applied, or the original tensor if no logical axes are provided.
     """
     if not logical_axis_names:
         return x

@@ -321,7 +334,9 @@ class ShardingType(Enum):
     DP_TP_ROW = (MajorShardingType.DPTP, "dp_tp_row")


-def get_non_contracting_logical_axes(ndim, logical_axes, contracting_dims):
+def get_non_contracting_logical_axes(
+    ndim, logical_axes: tuple[Optional[str]], contracting_dims
+) -> tuple[Optional[str]]:
     """Get logical axes for non-contracting dimensions.

     Args:

@@ -332,11 +347,8 @@ def get_non_contracting_logical_axes(ndim, logical_axes, contracting_dims):
     Returns:
         Tuple of logical axes for non-contracting dimensions.
     """
-    if not logical_axes:
-        logical_axes = (None,) * ndim
-    elif len(logical_axes) < ndim:
-        logical_axes = logical_axes + (None,) * (ndim - len(logical_axes))
-    assert len(logical_axes) == ndim
+    assert logical_axes is not None, "Logical axes must be a tuple and cannot be None."
+    assert len(logical_axes) == ndim, "Logical axes must match the number of dimensions."

     non_contracting_dims = [i for i in range(ndim) if i not in contracting_dims]
     non_contracting_logical_axes = tuple(logical_axes[i] for i in non_contracting_dims)
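The new Optional annotation matches the documented distinction between passing None and passing a tuple of Nones. A short sketch of the two cases on a hypothetical tensor x:

    # No constraint at all: the compiler is free to choose any sharding.
    x = with_sharding_constraint_by_logical_axes(x, None)

    # Explicit constraint: dim 0 follows the "batch" logical axis, dim 1 is replicated.
    x = with_sharding_constraint_by_logical_axes(x, ("batch", None))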
transformer_engine/pytorch/__init__.py

@@ -4,22 +4,14 @@

 """Transformer Engine bindings for pyTorch"""

-# pylint: disable=wrong-import-position,wrong-import-order
+# pylint: disable=wrong-import-position

-import logging
 import functools
-import sys
-import importlib
-import importlib.util
-from importlib.metadata import version
 from packaging.version import Version as PkgVersion

 import torch

-from transformer_engine.common import get_te_path, is_package_installed
-from transformer_engine.common import _get_sys_extension
-
-_logger = logging.getLogger(__name__)
+from transformer_engine.common import load_framework_extension


 @functools.lru_cache(maxsize=None)

@@ -28,57 +20,10 @@ def torch_version() -> tuple[int, ...]:
     return PkgVersion(str(torch.__version__)).release


-def _load_library():
-    """Load shared library with Transformer Engine C extensions"""
-    module_name = "transformer_engine_torch"
-
-    if is_package_installed(module_name):
-        assert is_package_installed("transformer_engine"), "Could not find `transformer-engine`."
-        assert is_package_installed(
-            "transformer_engine_cu12"
-        ), "Could not find `transformer-engine-cu12`."
-        assert (
-            version(module_name)
-            == version("transformer-engine")
-            == version("transformer-engine-cu12")
-        ), (
-            "TransformerEngine package version mismatch. Found"
-            f" {module_name} v{version(module_name)}, transformer-engine"
-            f" v{version('transformer-engine')}, and transformer-engine-cu12"
-            f" v{version('transformer-engine-cu12')}. Install transformer-engine using "
-            "'pip3 install transformer-engine[pytorch]==VERSION'"
-        )
-
-    if is_package_installed("transformer-engine-cu12"):
-        if not is_package_installed(module_name):
-            _logger.info(
-                "Could not find package %s. Install transformer-engine using "
-                "'pip3 install transformer-engine[pytorch]==VERSION'",
-                module_name,
-            )
-
-    extension = _get_sys_extension()
-    try:
-        so_dir = get_te_path() / "transformer_engine"
-        so_path = next(so_dir.glob(f"{module_name}.*.{extension}"))
-    except StopIteration:
-        try:
-            so_dir = get_te_path() / "transformer_engine" / "wheel_lib"
-            so_path = next(so_dir.glob(f"{module_name}.*.{extension}"))
-        except StopIteration:
-            so_dir = get_te_path()
-            so_path = next(so_dir.glob(f"{module_name}.*.{extension}"))
-
-    spec = importlib.util.spec_from_file_location(module_name, so_path)
-    solib = importlib.util.module_from_spec(spec)
-    sys.modules[module_name] = solib
-    spec.loader.exec_module(solib)
-
-
 assert torch_version() >= (2, 1), f"Minimum torch version 2.1 required. Found {torch_version()}."

-_load_library()
+load_framework_extension("torch")

 from transformer_engine.pytorch.module import LayerNormLinear
 from transformer_engine.pytorch.module import Linear
 from transformer_engine.pytorch.module import LayerNormMLP

@@ -90,7 +35,8 @@ from transformer_engine.pytorch.module import initialize_ub
 from transformer_engine.pytorch.module import destroy_ub
 from transformer_engine.pytorch.attention import DotProductAttention
 from transformer_engine.pytorch.attention import MultiheadAttention
-from transformer_engine.pytorch.attention import InferenceParams
+from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
 from transformer_engine.pytorch.attention import RotaryPositionEmbedding
 from transformer_engine.pytorch.transformer import TransformerLayer
 from transformer_engine.pytorch.permutation import (
     moe_permute,
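The roughly fifty-line loader collapses into one call because the lookup, version-check, and import logic now live in `transformer_engine.common`, shared across frameworks. A heavily simplified sketch of the idea only (hypothetical body; the real `load_framework_extension` does much more):

    import importlib

    def load_framework_extension(framework: str):
        # Conceptually: import the framework-specific C-extension module,
        # e.g. transformer_engine_torch for framework="torch".
        return importlib.import_module(f"transformer_engine_{framework}")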